# Integration with TorchRL

In [1]:
import gymnasium as gym
from lm_human_preferences.env.rlhf_env import RLHFEnv
from lm_human_preferences.data.base import QueryData
from lm_human_preferences.lm.reward import RewardModel
from transformers import AutoTokenizer
import torch
import gymnasium as gym
from torchrl.envs import GymEnv

DEVICE = 'cpu'
ENV_ID = 'RLHFEnv-v0'
MODEL_NAME = 'openai-community/gpt2'

def global_init_env(env_id=None, model_name=None, device=None):
    """Register the RLHF environment with gymnasium.

    The reward model and the query dataset are loaded here, once, and handed to
    ``RLHFEnv`` through the registration kwargs.

    Args:
        env_id: Gym id to register. Defaults to ``ENV_ID``.
        model_name: HF model id used both as the env's reference model and to
            build the dataset tokenizer, so the two cannot drift apart.
            Defaults to ``MODEL_NAME``.
        device: Device forwarded to the env. Defaults to ``DEVICE`` (read at
            call time, so editing ``DEVICE`` and re-calling takes effect).
    """
    env_id = ENV_ID if env_id is None else env_id
    model_name = MODEL_NAME if model_name is None else model_name
    device = DEVICE if device is None else device

    # Re-running this cell would otherwise register the same id again and make
    # gymnasium warn about overriding it; replace the entry cleanly instead.
    gym.envs.registry.pop(env_id, None)
    gym.envs.register(
        id=env_id,
        entry_point='lm_human_preferences.env.rlhf_env:RLHFEnv',
        kwargs={
            'ref_model_name': model_name,
            'reward_model': RewardModel.from_pretrained('../models/reward_model'),
            'dataset': QueryData.from_openai_format(
                AutoTokenizer.from_pretrained(model_name),
                '../data/descriptiveness_offline_5k'
            ),
            'kl_coef': 0.01,
            'max_generation': 64,
            'device': device,
            'seed': 42
        }
    )

global_init_env()
# Categorical encoding: actions are token indices instead of one-hot vectors.
base_env = GymEnv(ENV_ID, categorical_action_encoding=True, device=DEVICE)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Smoke test: one reset -> random action -> step transition, printing the
# tensordict after each stage.
tensordict = base_env.reset()
print(tensordict)
# rand_action fills the tensordict it is given with a random "action" entry
# and returns it, so the print above must stay before this call.
tensordict_with_action = base_env.rand_action(tensordict)
print(tensordict_with_action)
# step() returns the tensordict with the env's response nested under "next".
step_tensordict = base_env.step(tensordict_with_action)
print(step_tensordict)

TensorDict(
    fields={
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        observation: Tensor(shape=torch.Size([54, 50257]), device=cpu, dtype=torch.int64, is_shared=False),
        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        observation: Tensor(shape=torch.Size([54, 50257]), device=cpu, dtype=torch.int64, is_shared=False),
        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False

In [3]:
print(tensordict_with_action['action']) # YOU HAVE TO KNOW THAT THIS OUTPUT SPACE ITSELF RATHER THAN ID
print(tensordict_with_action['action'].argmax())

tensor([0, 0, 0,  ..., 0, 0, 0], device='cuda:0')
tensor(30212, device='cuda:0')


In [4]:
from torchrl.envs import step_mdp

# Shift the MDP one step: every ("next", ...) entry is promoted to the root so
# the resulting tensordict describes the state reached after the step.
data = step_mdp(tensordict=step_tensordict)
print(data)

TensorDict(
    fields={
        done: Tensor(shape=torch.Size([1]), device=cuda:0, dtype=torch.bool, is_shared=True),
        observation: Tensor(shape=torch.Size([55, 50257]), device=cuda:0, dtype=torch.int64, is_shared=True),
        terminated: Tensor(shape=torch.Size([1]), device=cuda:0, dtype=torch.bool, is_shared=True),
        truncated: Tensor(shape=torch.Size([1]), device=cuda:0, dtype=torch.bool, is_shared=True)},
    batch_size=torch.Size([]),
    device=cuda:0,
    is_shared=True)


# Rollout Test

In [5]:
# Ten steps with random actions (no policy supplied); shown as the cell output.
random_rollout = base_env.rollout(max_steps=10)
random_rollout

LazyStackedTensorDict(
    fields={
        action: Tensor(shape=torch.Size([10, 50257]), device=cuda:0, dtype=torch.int64, is_shared=True),
        done: Tensor(shape=torch.Size([10, 1]), device=cuda:0, dtype=torch.bool, is_shared=True),
        next: LazyStackedTensorDict(
            fields={
                done: Tensor(shape=torch.Size([10, 1]), device=cuda:0, dtype=torch.bool, is_shared=True),
                observation: Tensor(shape=torch.Size([10, -1, 50257]), device=cuda:0, dtype=torch.int64, is_shared=True),
                reward: Tensor(shape=torch.Size([10, 1]), device=cuda:0, dtype=torch.float32, is_shared=True),
                terminated: Tensor(shape=torch.Size([10, 1]), device=cuda:0, dtype=torch.bool, is_shared=True),
                truncated: Tensor(shape=torch.Size([10, 1]), device=cuda:0, dtype=torch.bool, is_shared=True)},
            exclusive_fields={
            },
            batch_size=torch.Size([10]),
            device=cuda:0,
            is_shared=True

In [16]:
rollout = base_env.rollout(max_steps=10)
tokenizer = base_env.env.tokenizer
pad_token = tokenizer.pad_token_id
# Rectangular, left-padded view of the ragged rollout:
# (steps, max_seq_len, vocab_size).
obs_rollout = rollout.get('observation', as_padded_tensor=True, padding_side='left', padding_value=pad_token)
print(obs_rollout.shape)

# Decode each step from its own, unpadded observation. In the padded stack the
# filler rows hold one constant value, so argmax would map them to token 0 and
# inject spurious text.
for step_idx, step_td in enumerate(rollout):
    token_ids = step_td['observation'].argmax(dim=-1)  # one-hot -> token ids
    print(f"STEP : {step_idx}")
    print(tokenizer.decode(token_ids))

STEP : 0
 Hun and Sweetie, the gray-haired ancient waitresses that owed this place during daylight hours shuffled around, bringing everyone sodas and water. And Sal, Pete's business partner and the daytime cook, stepped out with plates and plates of food. Everybody was celebrating.
STEP : 1
 Hun and Sweetie, the gray-haired ancient waitresses that owed this place during daylight hours shuffled around, bringing everyone sodas and water. And Sal, Pete's business partner and the daytime cook, stepped out with plates and plates of food. Everybody was celebrating.olic
STEP : 2
 Hun and Sweetie, the gray-haired ancient waitresses that owed this place during daylight hours shuffled around, bringing everyone sodas and water. And Sal, Pete's business partner and the daytime cook, stepped out with plates and plates of food. Everybody was celebrating.olicrf
STEP : 3
 Hun and Sweetie, the gray-haired ancient waitresses that owed this place during daylight hours shuffled around, bringing everyone s

In [23]:
from torchrl.envs.utils import check_env_specs

check_env_specs(env=base_env)  # compare a short rollout against the declared specs

2025-05-01 15:49:01,082 [torchrl][INFO] check_env_specs succeeded!


In [22]:
# Show each spec the env exposes; action_spec is derived from input_spec.
for label, attr_name in (
    ("observation_spec:", "observation_spec"),
    ("reward_spec:", "reward_spec"),
    ("input_spec:", "input_spec"),
    ("action_spec (as defined by input_spec):", "action_spec"),
):
    print(label, getattr(base_env, attr_name))

observation_spec: Composite(
    observation: OneHot(
        shape=torch.Size([-1, 50257]),
        space=CategoricalBox(n=50257),
        device=cuda:0,
        dtype=torch.int64,
        domain=discrete),
    device=cuda:0,
    shape=torch.Size([]))
reward_spec: UnboundedContinuous(
    shape=torch.Size([1]),
    space=ContinuousBox(
        low=Tensor(shape=torch.Size([1]), device=cuda:0, dtype=torch.float32, contiguous=True),
        high=Tensor(shape=torch.Size([1]), device=cuda:0, dtype=torch.float32, contiguous=True)),
    device=cuda:0,
    dtype=torch.float32,
    domain=continuous)
input_spec: Composite(
    full_state_spec: Composite(
    ,
        device=cuda:0,
        shape=torch.Size([])),
    full_action_spec: Composite(
        action: OneHot(
            shape=torch.Size([50257]),
            space=CategoricalBox(n=50257),
            device=cuda:0,
            dtype=torch.int64,
            domain=discrete),
        device=cuda:0,
        shape=torch.Size([])),
    

# Transformer Policy and Value Function

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSequenceClassification
from lm_human_preferences.lm.reward import RewardModel
from torchrl.envs import GymEnv
import torch
import gymnasium as gym
from lm_human_preferences.env.rlhf_env import RLHFEnv
from lm_human_preferences.data.base import QueryData
from lm_human_preferences.lm.reward import RewardModel
from transformers import AutoTokenizer

MODEL_NAME = 'openai-community/gpt2'
DEVICE = 'cuda'

# Tokenizer shared by the policy/value wrappers. Padding reuses the EOS id
# (presumably because the GPT-2 tokenizer has no pad token — confirm).
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token_id = tokenizer.eos_token_id
# Policy network: the pretrained causal LM.
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(DEVICE)
# Critic: GPT-2 with a single-output head ("score"), newly initialised per the
# load warning. The transformer backbone is frozen, so only the head trains.
value_model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=1).to(DEVICE)
value_model.transformer.requires_grad_(False)
value_model.score.requires_grad_(True)
# Give the classification model a pad id as well (presumably used to locate
# the last non-pad token it pools from — confirm).
value_model.config.pad_token_id = value_model.config.eos_token_id
# NOTE(review): not moved to DEVICE here; presumably the env handles placement.
reward_model = RewardModel.from_pretrained("../models/reward_model")

ENV_ID = 'RLHFEnv-v0'

def global_init_env(env_id=None, model_name=None, device=None, reward=None):
    """Register the RLHF environment with gymnasium.

    Args:
        env_id: Gym id to register. Defaults to ``ENV_ID``.
        model_name: HF model id used both as the env's reference model and to
            build the dataset tokenizer. Defaults to ``MODEL_NAME``.
        device: Device forwarded to the env. Defaults to ``DEVICE`` (read at
            call time).
        reward: Reward model handed to the env. Defaults to the notebook's
            global ``reward_model``.
    """
    env_id = ENV_ID if env_id is None else env_id
    model_name = MODEL_NAME if model_name is None else model_name
    device = DEVICE if device is None else device
    reward = reward_model if reward is None else reward

    # Replace any earlier registration of the same id (e.g. from re-running a
    # cell) instead of letting gymnasium warn about overriding it.
    gym.envs.registry.pop(env_id, None)
    gym.envs.register(
        id=env_id,
        entry_point='lm_human_preferences.env.rlhf_env:RLHFEnv',
        kwargs={
            'ref_model_name': model_name,
            'reward_model': reward,
            'dataset': QueryData.from_openai_format(
                AutoTokenizer.from_pretrained(model_name),
                '../data/descriptiveness_offline_5k'
            ),
            'kl_coef': 0.01,
            'max_generation': 64,
            'device': device,
            'seed': 42
        }
    )

global_init_env()
# No categorical_action_encoding here: actions are one-hot vectors, matching the
# OneHotCategorical actor built below.
base_env = GymEnv(ENV_ID, device=DEVICE)

torchdict = base_env.reset()

  from .autonotebook import tqdm as notebook_tqdm
Some weights of GPT2ForSequenceClassification were not initialized from the model checkpoint at openai-community/gpt2 and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [2]:
class TransformerPolicy(torch.nn.Module):
    """Map one-hot token observations to next-token logits with a causal LM.

    Observations arrive as one-hot tensors ``[seq_len, vocab]`` or
    ``[batch, seq_len, vocab]``. They are converted back to token ids, run
    through ``model``, and the logits of the final position are returned.

    Args:
        model: Causal LM whose output exposes ``.logits`` of shape
            ``[batch, seq_len, vocab]``.
        pad_token_id: Token id treated as padding. When ``None``, the notebook's
            global ``tokenizer.pad_token_id`` is looked up at call time.
    """

    def __init__(self, model, pad_token_id=None):
        super().__init__()
        self.model = model
        self.pad_token_id = pad_token_id

    def forward(self, input_ids, attention_mask=None):
        """Return last-position logits.

        Args:
            input_ids: One-hot observations, ``[seq_len, vocab]`` or
                ``[batch, seq_len, vocab]``. Unbatched input gets a batch dim.
            attention_mask: Optional ``[seq_len]`` / ``[batch, seq_len]`` mask.
                When omitted it is derived from the observation: pad tokens
                and all-zero one-hot rows are masked out.

        Returns:
            ``[batch, vocab]`` logits (``batch == 1`` for unbatched input).
        """
        if input_ids.ndim == 2:
            input_ids = input_ids.unsqueeze(dim=0)
            if attention_mask is not None and attention_mask.ndim == 1:
                attention_mask = attention_mask.unsqueeze(dim=0)

        pad_id = self.pad_token_id if self.pad_token_id is not None else tokenizer.pad_token_id
        token_ids = input_ids.argmax(dim=-1)  # [batch, seq_len]
        if attention_mask is None:
            # All-zero rows come from padding ragged stacks. argmax turns them
            # into token 0, so they must be masked together with pad tokens.
            attention_mask = (token_ids != pad_id) & (input_ids.sum(dim=-1) != 0)
        attention_mask = attention_mask.long()
        # Count positions only over real tokens so that left padding does not
        # shift the first real token away from position 0.
        position_ids = (attention_mask.cumsum(dim=-1) - 1).clamp(min=0)
        logits = self.model(
            input_ids=token_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
        ).logits  # [batch, seq_len, vocab]
        # Inputs are expected to be left-padded, so the last position is the
        # last real token.
        return logits[:, -1, :]  # [batch, vocab]

from tensordict.nn import TensorDictModule
from torchrl.modules import ProbabilisticActor
from torchrl.modules.distributions.discrete import OneHotCategorical

# Deterministic part: one-hot observation -> last-position logits.
policy_module = TensorDictModule(
    TransformerPolicy(model),
    in_keys=["observation"],  # one-hot token sequence produced by the env
    out_keys=["logits"],      # consumed by the distribution below
)

# Stochastic part: sample a one-hot token from those logits. OneHotCategorical
# matches the env's one-hot action spec (categorical_action_encoding unset).
actor = ProbabilisticActor(
    module=policy_module,
    spec=base_env.action_spec,
    in_keys=["logits"],
    out_keys=["action"],
    distribution_class=OneHotCategorical,
    return_log_prob=True,
)

torchdict = actor(torchdict)

In [None]:
# GPT-2 backbone of the LM-head model. The attribute is ``transformer``;
# ``model.transformers`` raises AttributeError.
model.transformer

GPT2Model(
  (wte): Embedding(50257, 768)
  (wpe): Embedding(1024, 768)
  (drop): Dropout(p=0.1, inplace=False)
  (h): ModuleList(
    (0-11): 12 x GPT2Block(
      (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): GPT2Attention(
        (c_attn): Conv1D(nf=2304, nx=768)
        (c_proj): Conv1D(nf=768, nx=768)
        (attn_dropout): Dropout(p=0.1, inplace=False)
        (resid_dropout): Dropout(p=0.1, inplace=False)
      )
      (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (mlp): GPT2MLP(
        (c_fc): Conv1D(nf=3072, nx=768)
        (c_proj): Conv1D(nf=768, nx=3072)
        (act): NewGELUActivation()
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)

In [4]:
from torchrl.modules import ValueOperator

class TransformerValue(torch.nn.Module):
    """Map one-hot token observations to a scalar state value.

    Args:
        model: Sequence-classification model with a single label whose output
            exposes ``.logits`` of shape ``[batch, 1]``.
        pad_token_id: Token id treated as padding. When ``None``, the notebook's
            global ``tokenizer.pad_token_id`` is looked up at call time.
    """

    def __init__(self, model, pad_token_id=None):
        super().__init__()
        self.model = model
        self.pad_token_id = pad_token_id

    def forward(self, input_ids, attention_mask=None):
        """Return the value estimate.

        Args:
            input_ids: One-hot observations, ``[seq_len, vocab]`` or
                ``[batch, seq_len, vocab]``.
            attention_mask: Optional ``[seq_len]`` / ``[batch, seq_len]`` mask.
                When omitted, pad tokens and all-zero one-hot rows are masked.

        Returns:
            ``[1]`` for unbatched input, ``[batch, 1]`` for batched input. The
            trailing dim is kept so values line up with ``[batch, 1]`` rewards.
        """
        unbatched = input_ids.ndim == 2
        if unbatched:
            input_ids = input_ids.unsqueeze(dim=0)
            if attention_mask is not None and attention_mask.ndim == 1:
                attention_mask = attention_mask.unsqueeze(dim=0)

        pad_id = self.pad_token_id if self.pad_token_id is not None else tokenizer.pad_token_id
        token_ids = input_ids.argmax(dim=-1)  # [batch, seq_len]
        if attention_mask is None:
            # All-zero rows are padding from ragged stacks (argmax -> token 0).
            attention_mask = (token_ids != pad_id) & (input_ids.sum(dim=-1) != 0)
        attention_mask = attention_mask.long()
        # Positions count only real tokens, so left padding is not shifted in.
        position_ids = (attention_mask.cumsum(dim=-1) - 1).clamp(min=0)
        logits = self.model(
            input_ids=token_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
        ).logits  # [batch, 1]

        values = logits[:, -1:]  # slice (not index) to keep the value dim: [batch, 1]
        return values[0] if unbatched else values

# Critic wrapper: reads the one-hot "observation", writes "state_value"
# (the key ValueOperator uses by default; spelled out for clarity).
value_module = ValueOperator(
    TransformerValue(value_model),
    in_keys=["observation"],
    out_keys=["state_value"],
)

In [None]:
from typing import Any
from torchrl.collectors import SyncDataCollector
from torchrl.data.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers import LazyTensorStorage, ReplayBuffer, SamplerWithoutReplacement

# Batch sizes shared by the collector and the LR schedule.
frames_per_batch = 5
total_frames = 15

# This is a sequential collector
collector = SyncDataCollector(
    base_env,
    actor,
    frames_per_batch=frames_per_batch,
    total_frames=total_frames,
    split_trajs=False,
    device=DEVICE,
)

replay_buffer = ReplayBuffer(
    storage=LazyTensorStorage(max_size=30),
    sampler=SamplerWithoutReplacement(),
)

from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE

advantage_module = GAE(
    gamma=.99, lmbda=1, value_network=value_module, average_gae=True
)

loss_module = ClipPPOLoss(
    # PPO needs the *probabilistic* actor: the loss rebuilds the action
    # distribution from it for the log-prob ratio and the entropy bonus.
    # ``policy_module`` alone only emits logits.
    actor_network=actor,
    critic_network=value_module,
    clip_epsilon=0.2,
    entropy_bonus=True,
    entropy_coef=0.01,
    critic_coef=1.0,
    loss_critic_type="smooth_l1",
)

optim = torch.optim.Adam(loss_module.parameters(), lr=3e-4)
# The scheduler is stepped once per collected batch, so anneal over exactly
# that many steps instead of an unrelated constant.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optim, total_frames // frames_per_batch, 0.0
)

In [8]:
import torch.nn.functional as F

def pad_observation_and_next_observation(obs_data, next_obs_data, pad_token_id=None, padding_side="left"):
    """Bring observation and next-observation one-hot batches to a common length.

    Both inputs are one-hot tensors ``[steps, seq_len, vocab]`` in which padding
    rows are all zeros (e.g. from ``get(..., as_padded_tensor=True,
    padding_value=0)``). Zero rows are replaced by the pad token, then the
    shorter tensor is padded so both share the same ``seq_len``.

    Args:
        obs_data: One-hot observations ``[steps, seq_len_a, vocab]``.
        next_obs_data: One-hot next observations ``[steps, seq_len_b, vocab]``.
        pad_token_id: Padding id. When ``None``, the global
            ``tokenizer.pad_token_id`` is used.
        padding_side: ``"left"`` (default) keeps the last real token in the final
            position, which the policy/value wrappers read with ``[:, -1]``.
            ``"right"`` reproduces the previous behaviour.

    Returns:
        ``(one_hot, next_one_hot)``, both ``[steps, max_seq_len, vocab]`` int64.

    Raises:
        ValueError: if ``padding_side`` is neither ``"left"`` nor ``"right"``.
    """
    if padding_side not in ("left", "right"):
        raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side!r}")
    pad_id = tokenizer.pad_token_id if pad_token_id is None else pad_token_id
    # Width of the one-hot encoding in the data itself.
    num_classes = obs_data.size(-1)

    def _to_ids(one_hot):
        # argmax maps all-zero padding rows to 0, so overwrite them with pad_id.
        ids = one_hot.argmax(dim=-1)  # (step, seq_len)
        ids[one_hot.sum(dim=-1) == 0] = pad_id
        return ids

    token_ids = _to_ids(obs_data)
    next_token_ids = _to_ids(next_obs_data)
    max_seq_len = max(token_ids.size(1), next_token_ids.size(1))

    def _pad_to_max(ids):
        missing = max_seq_len - ids.size(1)
        if missing <= 0:
            return ids
        widths = (missing, 0) if padding_side == "left" else (0, missing)
        return F.pad(ids, widths, value=pad_id)

    one_hot = F.one_hot(_pad_to_max(token_ids), num_classes=num_classes)
    next_one_hot = F.one_hot(_pad_to_max(next_token_ids), num_classes=num_classes)
    return one_hot, next_one_hot

In [9]:
# Pull one batch from the collector and make its ragged one-hot observations
# rectangular: left-pad each stack with zero rows, then align obs / next-obs.
tensor_data = next(iter(collector))

padding_kwargs = dict(as_padded_tensor=True, padding_value=0, padding_side="left")
raw_obs = tensor_data.get('observation', **padding_kwargs)
raw_next_obs = tensor_data['next'].get('observation', **padding_kwargs)

tensor_data['observation'], tensor_data['next']['observation'] = pad_observation_and_next_observation(
    raw_obs, raw_next_obs
)

  action = int(action)


In [10]:
sample_idx = 2  # index of the collected step to inspect
observation_keys = ('observation', ('next', 'observation'))
for key in observation_keys:
    print(tokenizer.decode(tensor_data[key].argmax(dim=-1)[sample_idx]))
for key in observation_keys:
    print(tensor_data[key].shape)

<|endoftext|><|endoftext|>  Chris, you too."
Mrs. Browley threw her husband an impressed look.
"That sounds like a wonderful idea, Priscilla."  Mr. Browley stated.  "Mason, Hanna, Chris?"
"That sounds good to me."  Hanna replied. He was<|endoftext|>
<|endoftext|><|endoftext|>  Chris, you too."
Mrs. Browley threw her husband an impressed look.
"That sounds like a wonderful idea, Priscilla."  Mr. Browley stated.  "Mason, Hanna, Chris?"
"That sounds good to me."  Hanna replied. He was no
torch.Size([5, 64, 50257])
torch.Size([5, 64, 50257])


In [218]:
import torch  
import torch.nn as nn  
  
class CustomGAE(nn.Module):
    """Generalized Advantage Estimation computed with a plain reverse loop (no ``vmap``).

    Follows Schulman et al., "High-Dimensional Continuous Control Using
    Generalized Advantage Estimation", https://arxiv.org/abs/1506.02438
    """

    def __init__(self, gamma: float, lmbda: float, normalize_advantages: bool = False):
        """Initialize the GAE module.

        Args:
            gamma (float): Discount factor for future rewards.
            lmbda (float): GAE lambda controlling the bias/variance trade-off.
            normalize_advantages (bool): Standardise advantages to zero mean and
                unit std (skipped when there is only one element). Default: False.
        """
        super().__init__()
        self.register_buffer("gamma", torch.tensor(gamma))
        self.register_buffer("lmbda", torch.tensor(lmbda))
        self.normalize_advantages = normalize_advantages

    def forward(
        self,
        state_value: torch.Tensor,
        next_state_value: torch.Tensor,
        reward: torch.Tensor,
        done: torch.Tensor,
        terminated: torch.Tensor = None,
    ):
        """Compute GAE advantages and value targets.

        All tensors share one shape ``[*batch, time_steps, 1]``; the time axis is
        the second-to-last dimension.

        Args:
            state_value: V(s_t).
            next_state_value: V(s_{t+1}).
            reward: Reward received at step t.
            done: Bool, True where the episode ends at step t (cuts the trace).
            terminated: Bool, True where s_{t+1} is terminal (no bootstrap).
                Defaults to a copy of ``done``.

        Returns:
            tuple: ``(advantages, value_targets)``; value targets are computed
            from the un-normalised advantages.

        Raises:
            RuntimeError: if the input shapes differ.
        """
        if terminated is None:
            terminated = done.clone()

        shapes = [t.shape for t in (state_value, next_state_value, reward, done, terminated)]
        if any(shape != shapes[0] for shape in shapes[1:]):
            raise RuntimeError(
                "All input tensors (value, reward and done states) must share a unique shape. "
                f"Got state_value={tuple(state_value.shape)}, "
                f"next_state_value={tuple(next_state_value.shape)}, "
                f"reward={tuple(reward.shape)}, done={tuple(done.shape)}, "
                f"terminated={tuple(terminated.shape)}."
            )

        # Local copies on the inputs' device. The registered buffers are not
        # reassigned inside forward; moving the module is the caller's job.
        device = state_value.device
        gamma = self.gamma.to(device)
        lmbda = self.lmbda.to(device)

        not_done = (~done).int()
        not_terminated = (~terminated).int()
        *_, time_steps, _ = not_done.shape

        # TD residual: delta_t = r_t + gamma * (1 - terminated_t) * V(s_{t+1}) - V(s_t)
        delta = reward + (gamma * not_terminated * next_state_value) - state_value
        # Recursion factor; zero wherever an episode ends, which cuts the trace.
        discount = lmbda * gamma * not_done

        advantage = torch.zeros_like(state_value)
        # Running GAE for the step after the current one (zero past the end).
        # Built from the shape so an empty time axis is handled.
        running = delta.new_zeros(delta.shape[:-2] + delta.shape[-1:])
        for t in reversed(range(time_steps)):
            running = delta[..., t, :] + discount[..., t, :] * running
            advantage[..., t, :] = running

        # Value targets use the raw advantages: V_target = A + V.
        value_target = advantage + state_value

        # A single element has an undefined (NaN) std; skip normalisation then.
        if self.normalize_advantages and advantage.numel() > 1:
            advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-8)

        return advantage, value_target

gae = CustomGAE(gamma=0.99, lmbda=0.95, normalize_advantages=True)
# The done/terminated flags that stop bootstrapping and cut the GAE trace
# describe the transition s_t -> s_{t+1}, so they are read from "next" (as
# torchrl's own GAE does). The root-level flags describe s_t instead.
gae(
    state_value=tensor_data['state_value'],
    next_state_value=tensor_data['next']['state_value'],
    reward=tensor_data['next']['reward'],
    done=tensor_data['next']['done'],
    terminated=tensor_data['next']['terminated'],
)

RuntimeError: All input tensors (value, reward and done states) must share a unique shape.

In [None]:
from collections import defaultdict

from torchrl.envs.utils import ExplorationType, set_exploration_type
from tqdm.auto import tqdm

# Loop hyper-parameters. The original cell used these names without defining
# them anywhere in the notebook.
num_epochs = 4          # optimisation passes over each collected batch
sub_batch_size = 5      # minibatch size drawn from the replay buffer
max_grad_norm = 1.0     # gradient-norm clipping threshold
eval_max_steps = 1000   # cap on evaluation rollout length; the env may finish earlier

logs = defaultdict(list)
pbar = tqdm(total=collector.total_frames)
eval_str = ""

# Iterate until the collector has produced the total number of frames it was
# configured for.
for i, tensordict_data in enumerate(collector):
    n_frames = tensordict_data.numel()
    for _ in range(num_epochs):
        # Advantages depend on the critic, which is updated in the inner loop,
        # so they are recomputed every epoch.
        advantage_module(tensordict_data)
        data_view = tensordict_data.reshape(-1)
        # Replace rather than append, so minibatches only hold this epoch's advantages.
        replay_buffer.empty()
        replay_buffer.extend(data_view.cpu())
        # At least one minibatch, even when a batch is smaller than sub_batch_size.
        for _ in range(max(1, n_frames // sub_batch_size)):
            subdata = replay_buffer.sample(sub_batch_size)
            loss_vals = loss_module(subdata.to(DEVICE))
            loss_value = (
                loss_vals["loss_objective"]
                + loss_vals["loss_critic"]
                + loss_vals["loss_entropy"]
            )

            # Optimization: backward, grad clipping and optimization step
            loss_value.backward()
            # Bounding the gradient norm is good practice, not strictly required.
            torch.nn.utils.clip_grad_norm_(loss_module.parameters(), max_grad_norm)
            optim.step()
            optim.zero_grad()

    logs["reward"].append(tensordict_data["next", "reward"].mean().item())
    pbar.update(n_frames)
    cum_reward_str = (
        f"average reward={logs['reward'][-1]: 4.4f} (init={logs['reward'][0]: 4.4f})"
    )
    # "step_count" only exists when the env is wrapped in a StepCounter transform.
    if "step_count" in tensordict_data.keys():
        logs["step_count"].append(tensordict_data["step_count"].max().item())
        stepcount_str = f"step count (max): {logs['step_count'][-1]}"
    else:
        stepcount_str = "step count (max): n/a"
    logs["lr"].append(optim.param_groups[0]["lr"])
    lr_str = f"lr policy: {logs['lr'][-1]: 4.4f}"
    if i % 10 == 0:
        # Evaluate every 10 batches without exploration. The probabilistic
        # ``actor`` is required here: ``policy_module`` only produces logits,
        # not actions.
        with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
            eval_rollout = base_env.rollout(eval_max_steps, actor)
            logs["eval reward"].append(eval_rollout["next", "reward"].mean().item())
            logs["eval reward (sum)"].append(
                eval_rollout["next", "reward"].sum().item()
            )
            eval_str = (
                f"eval cumulative reward: {logs['eval reward (sum)'][-1]: 4.4f} "
                f"(init: {logs['eval reward (sum)'][0]: 4.4f})"
            )
            if "step_count" in eval_rollout.keys():
                logs["eval step_count"].append(eval_rollout["step_count"].max().item())
                eval_str += f", eval step-count: {logs['eval step_count'][-1]}"
            del eval_rollout
    pbar.set_description(", ".join([eval_str, cum_reward_str, stepcount_str, lr_str]))

    # Cosine LR schedule, stepped once per collected batch (optional for PPO).
    scheduler.step()

pbar.close()

LazyStackedTensorDict(
    fields={
        action: Tensor(shape=torch.Size([30, 50257]), device=cpu, dtype=torch.int64, is_shared=False),
        collector: LazyStackedTensorDict(
            fields={
                traj_ids: Tensor(shape=torch.Size([30]), device=cpu, dtype=torch.int64, is_shared=False)},
            exclusive_fields={
            },
            batch_size=torch.Size([30]),
            device=cpu,
            is_shared=False,
            stack_dim=0),
        done: Tensor(shape=torch.Size([30, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        logits: Tensor(shape=torch.Size([30, 50257]), device=cpu, dtype=torch.float32, is_shared=False),
        next: LazyStackedTensorDict(
            fields={
                done: Tensor(shape=torch.Size([30, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([30, -1, 50257]), device=cpu, dtype=torch.int64, is_shared=False),
                reward: Tensor(shape=torch

In [None]:
import matplotlib.pyplot as plt

# Training and evaluation curves. A panel stays empty if its metric was never
# logged (e.g. no StepCounter -> no step counts).
panels = [
    ("reward", "training rewards (average)", "collected batch"),
    ("step_count", "Max step count (training)", "collected batch"),
    ("eval reward (sum)", "Return (test)", "evaluation round"),
    ("eval step_count", "Max step count (test)", "evaluation round"),
]
fig, axes = plt.subplots(2, 2, figsize=(10, 10))
for ax, (key, title, xlabel) in zip(axes.flat, panels):
    ax.plot(logs.get(key, []))
    ax.set_title(title)
    ax.set_xlabel(xlabel)
fig.tight_layout()
plt.show()


# Custom PyTorch-Only Implementation