# 06: RLHF PPO Training Loop
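This notebook demonstrates a single PPO update for Reinforcement Learning from Human Feedback (RLHF). It loads a policy model, a reference model initialized from the same checkpoint, and a trained reward model, then runs one PPO training step on a small batch of prompts extended with synthetic (random-token) responses.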

In [None]:
!pip install torch transformers

In [1]:
import sys
import os

# Put the parent directory on the import path so that sibling packages can be
# imported from this notebook (presumably the `models` and `rlhf` packages used
# below live there -- this assumes the kernel's working directory is the
# notebook's own folder).
PROJECT_ROOT = os.path.abspath("..")

# Guard the append so that re-running this cell does not keep adding duplicate
# entries to sys.path.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [2]:
import torch
from transformers import AutoTokenizer
from models.decoder_only import GPTStyleDecoder
from rlhf.reward_model import RewardModel
from rlhf.ppo_trainer import PPOTrainer

## Load tokenizer and models

In [4]:
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# The GPT-2 tokenizer ships without a padding token, so register one for batching.
# NOTE(review): presumably this mirrors the tokenizer setup used when
# reward_model.pt was trained (the reward model is sized with len(tokenizer)) -- confirm.
tokenizer.add_special_tokens({'pad_token': '[PAD]'})

device = "cuda" if torch.cuda.is_available() else "cpu"

# The policy and the reference model share one architecture; keep the
# hyperparameters in a single place so the two cannot drift apart.
# NOTE(review): tokenizer.vocab_size does not count tokens added via
# add_special_tokens, whereas len(tokenizer) does. The policy/reference models
# are therefore built without a slot for [PAD], while the reward model has one.
# The sizes are kept as-is because they must match the saved checkpoints, but
# the pad id is presumably out of range for the policy/reference embeddings --
# confirm before feeding them padded batches.
decoder_config = dict(
    vocab_size=tokenizer.vocab_size,
    embed_dim=768,
    depth=6,
    heads=12,
    ff_dim=2048,
)
policy_model = GPTStyleDecoder(**decoder_config).to(device)
ref_model = GPTStyleDecoder(**decoder_config).to(device)
reward_model = RewardModel(vocab_size=len(tokenizer)).to(device)

# Read the decoder checkpoint from disk once and initialize both models from it.
# weights_only=True restricts unpickling to tensors and plain containers, which
# is all a state dict needs and avoids executing arbitrary pickled code.
decoder_state = torch.load("gpt_decoder_trained.pt", map_location=device, weights_only=True)
policy_model.load_state_dict(decoder_state)
ref_model.load_state_dict(decoder_state)
reward_model.load_state_dict(torch.load("reward_model.pt", map_location=device, weights_only=True))

<All keys matched successfully>

## Initialize PPO Trainer

In [5]:
# Learning rate handed to the PPO trainer.
PPO_LR = 1e-5

# Bundle the policy, reference and reward models into a single trainer object.
ppo = PPOTrainer(
    policy_model,
    ref_model,
    reward_model,
    lr=PPO_LR,
)

## Simulate a batch of prompts and responses

In [6]:
# Number of random tokens appended to every prompt as a stand-in "response".
RESPONSE_LEN = 20
# Seed for the synthetic responses so that re-running the cell gives the same batch.
RESPONSE_SEED = 0

prompts = ["The meaning of life is", "A good day starts with"]

# Tokenize each prompt separately (one 1-D tensor of ids per prompt), then pad
# them to a common length so they can be stacked into one batch tensor.
# NOTE(review): the padding id belongs to the [PAD] token added to the tokenizer,
# which is presumably outside the policy model's vocabulary (it is sized with
# tokenizer.vocab_size) -- confirm before using prompts of different token lengths.
prompt_token_ids = [tokenizer(p, return_tensors="pt")["input_ids"][0] for p in prompts]
input_ids = torch.nn.utils.rnn.pad_sequence(
    prompt_token_ids, batch_first=True, padding_value=tokenizer.pad_token_id
).to(device)

# Simulate responses by appending RESPONSE_LEN uniformly random token ids to
# each prompt. The batch dimension is taken from input_ids rather than being
# hard-coded, so the cell keeps working when the prompt list changes. A
# dedicated CPU generator keeps the draw reproducible without touching the
# global RNG state.
rng = torch.Generator().manual_seed(RESPONSE_SEED)
random_continuation = torch.randint(
    0, tokenizer.vocab_size, (input_ids.size(0), RESPONSE_LEN), generator=rng
).to(device)
responses = torch.cat([input_ids, random_continuation], dim=1)

## Perform one PPO training step

In [7]:
# Run a single PPO update on the simulated batch and keep the returned metrics.
ppo_metrics = ppo.train_step(input_ids, responses)

# Report every metric on its own line, rounded to four decimal places.
print("PPO Training Metrics:")
for metric_name, metric_value in ppo_metrics.items():
    rounded = format(metric_value, ".4f")
    print(f"{metric_name}: {rounded}")

PPO Training Metrics:
loss: 0.0000
policy_loss: -0.0000
value_loss: 0.0000
rewards_mean: -2.3325
advantages_mean: 0.0000
