In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up before each cell executes, without a kernel restart.
%load_ext autoreload
%autoreload 2

import sys
import os
# Make the parent of the current working directory importable (presumably the
# repository root that holds the `minrl` package imported by later cells).
# The membership guard keeps this cell idempotent: re-running it does not
# pile up duplicate entries in sys.path.
_parent_dir = os.path.abspath("..")
if _parent_dir not in sys.path:
    sys.path.append(_parent_dir)

In [None]:
from minrl.algorithms import (
    rollout
)
from minrl.tasks.connections import ConnectionsDataset, connections_reward_func
from minrl.constants import GEMMA_3_1B
from transformers import AutoTokenizer
from vllm import LLM

# Connections dataset, "eval" split, constructed with host="notebook".
dataset = ConnectionsDataset(host="notebook", split="eval")

# Tokenizer for the checkpoint identified by GEMMA_3_1B.
tokenizer = AutoTokenizer.from_pretrained(GEMMA_3_1B)

In [None]:
# vLLM engine for the GEMMA_3_1B checkpoint, used for generation.
# gpu_memory_utilization=0.5 caps vLLM at half of the GPU memory -- presumably
# to leave room for the Hugging Face copy of the model loaded in a later cell.
# NOTE(review): enforce_eager=True disables CUDA graph capture, so
# max_seq_len_to_capture likely has no effect here -- confirm against the
# installed vLLM version.
vllm_model = LLM(
    model=GEMMA_3_1B,
    max_model_len=1024,
    max_seq_len_to_capture=1024,
    gpu_memory_utilization=0.5,
    enable_prefix_caching=True,
    enforce_eager=True,
)


In [None]:
import torch
from transformers import AutoModelForCausalLM
# Hugging Face copy of the GEMMA_3_1B checkpoint, loaded in bfloat16 with
# device placement delegated to device_map="auto".
model = AutoModelForCausalLM.from_pretrained(
    GEMMA_3_1B,
    dtype=torch.bfloat16,
    device_map="auto",
)

In [None]:
# Display the tokenizer's end-of-sequence token id (bare last expression, so
# the notebook renders it as the cell output). The same value is passed as
# pad_token_id to process_batch in a later cell.
tokenizer.eos_token_id

In [None]:
# Small working batch: the slice [:4] of the eval dataset.
batch = dataset[:4]

# One initial conversation per sample; the sample's position within the batch
# is handed over as the second argument.
conversations = [
    dataset.initial_conversation(sample, position)
    for position, sample in enumerate(batch)
]

# Run the rollout over the batch using the vLLM engine and the Connections
# reward function.
# NOTE(review): the meaning of the positional values 1, 512, 4, 1 is fixed by
# rollout's signature, which is not visible here -- consider passing them by
# keyword once the parameter names are confirmed.
episodes = rollout(
    1, 512, tokenizer, 4, 1, conversations, batch,
    vllm_model=vllm_model,
    reward_function=connections_reward_func,
)

In [None]:
from minrl.algorithms import process_batch
from minrl.trainer import get_available_device

# Run the rolled-out episodes through the Hugging Face model via process_batch
# and unpack its five results.
# The device is taken from the model itself: it was loaded with
# device_map="auto", so hard-coding "cuda" could place the batch on a
# different device than the model's weights.
# NOTE(review): EOS is reused as the padding id here -- confirm this matches
# what the training loop uses.
logprobs, target_masks, batch_rewards_t, batch_entropy, n_target_tokens = (
    process_batch(
        model=model,
        episodes=episodes,
        tokenizer=tokenizer,
        pad_token_id=tokenizer.eos_token_id,
        device=model.device,
    )
)
