 # PPO Training for GPT-2 on HH-RLHF Dataset



 This notebook trains a GPT-2 model using PPO (Proximal Policy Optimization) on the HH-RLHF dataset.

 ## 1. Install Required Packages

In [None]:
!pip -q install trl

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/540.5 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m [32m532.5/540.5 kB[0m [31m18.9 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m540.5/540.5 kB[0m [31m14.1 MB/s[0m eta [36m0:00:00[0m
[?25h

 ## 2. Mount Google Drive

In [1]:
# Leaving it blank

 ## 3. Import Libraries

In [None]:
from datasets import load_dataset
from trl.experimental.ppo import PPOConfig, PPOTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSequenceClassification
import torch
import wandb

print("✓ Libraries imported successfully")

  from trl.experimental.ppo import PPOConfig, PPOTrainer


✓ Libraries imported successfully


 ## 4. Configuration and Setup
 ### Important: set up the directories below

In [None]:
data_dir = ""
output_dir = ""

In [None]:
# Hyperparameters
lr = 5e-5
batch_per_device = 2
total_steps = 3000

# Model and paths
MODEL = "gpt2"
# SET the directories


# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


Using device: cuda


 ## 5. Load and Configure Tokenizer

In [None]:
# Load tokenizer
tok = AutoTokenizer.from_pretrained(MODEL, use_fast=True)

# Set padding token
if tok.pad_token is None:
    tok.pad_token = tok.eos_token
    print("Set pad_token to eos_token")

# For PPO, we need left padding for generation
tok.padding_side = "left"

# Add chat template for GPT-2
if MODEL == 'gpt2':
    gpt2_chat_template = r"""
    {%- set sep = '\n\n' -%}
    {%- for m in messages -%}
    {%- if m['role'] == 'human' -%}
    Human: {{ m['content'] | trim }}{{ sep }}
    {%- elif m['role'] == 'assistant' -%}
    Assistant: {{ m['content'] | trim }}{{ sep }}
    {%- endif -%}
    {%- endfor -%}
    """
    tok.chat_template = gpt2_chat_template

print("✓ Tokenizer configured")


config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Set pad_token to eos_token
✓ Tokenizer configured


 ## 6. Load and Prepare Dataset

In [None]:
# Load HH-RLHF dataset
print("Loading HH-RLHF dataset...")
ds = load_dataset(data_dir, split="train[:50%]")

print(f"Original dataset size: {len(ds)}")


Loading HH-RLHF dataset...


Generating train split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

Original dataset size: 80400


 ### Extract Prompts from Dataset

 For PPO, we only need prompts (not chosen/rejected pairs like DPO)

In [None]:
def extract_prompt(example):
    """
    Extract just the prompt from HH-RLHF format.
    Format: '\n\nHuman: ... \n\nAssistant: ...'
    """
    text = example["chosen"]

    # Split by Assistant responses
    parts = text.split("\n\nAssistant:")

    if len(parts) > 0:
        # Get everything before the last Assistant response
        human_part = parts[-2] if len(parts) > 1 else parts[0]

        if "\n\nHuman:" in human_part:
            prompt = human_part.split("\n\nHuman:")[-1].strip()
        else:
            prompt = human_part.strip()

        return {"query": prompt}

    return {"query": text.strip()}

# Process dataset
print("Extracting prompts from dataset...")
ds = ds.map(extract_prompt, remove_columns=ds.column_names)

# Filter for reasonable prompt lengths
ds = ds.filter(lambda x: len(x["query"]) > 10 and len(x["query"]) < 200)

# Create eval set
eval_ds = ds.select(range(min(512, len(ds))))

print(f"\n✓ Dataset prepared")
print(f"  Training examples: {len(ds)}")
print(f"  Eval examples: {len(eval_ds)}")
print(f"\nSample prompt: {ds[0]['query'][:100]}...")


Extracting prompts from dataset...


Map:   0%|          | 0/80400 [00:00<?, ? examples/s]

Filter:   0%|          | 0/80400 [00:00<?, ? examples/s]


✓ Dataset prepared
  Training examples: 75871
  Eval examples: 512

Sample prompt: What's your favorite one?...


 ## 7. Configure PPO

In [None]:
cfg = PPOConfig(
    output_dir=f"{output_dir}{MODEL}/ppo-output",
    learning_rate=lr,

    # Batch settings
    per_device_train_batch_size=batch_per_device,
    num_mini_batches=1,
    gradient_accumulation_steps=8,

    # PPO specific parameters
    num_ppo_epochs=4,
    kl_coef=0.1,
    cliprange=0.2,
    cliprange_value=0.2,
    vf_coef=0.1,

    # Optimization
    max_grad_norm=0.5,
    seed=42,

    # Checkpointing (saves to Google Drive — survives disconnects)
    save_strategy="steps",
    save_steps=200,
    save_total_limit=3,

    # Logging
    report_to="wandb",
    run_name="gpt2-ppo",
    logging_dir=f"{output_dir}logs",
)

print("✓ PPO Configuration:")
print(f"  Learning rate: {cfg.learning_rate}")
print(f"  Batch size: {cfg.per_device_train_batch_size}")
print(f"  PPO epochs: {cfg.num_ppo_epochs}")
print(f"  KL coef: {cfg.kl_coef}")
print(f"  Clip range: {cfg.cliprange}")
print(f"  Checkpointing: every {cfg.save_steps} steps (keep {cfg.save_total_limit})")

`logging_dir` is deprecated and will be removed in v5.2. Please set `TENSORBOARD_LOGGING_DIR` instead.


✓ PPO Configuration:
  Learning rate: 5e-05
  Batch size: 2
  PPO epochs: 4
  KL coef: 0.1
  Clip range: 0.2
  Checkpointing: every 200 steps (keep 3)


 ## 8. Load Models

In [None]:
import glob, os

print("Loading models...")


REWARD_MODEL = ""
if len(REWARD_MODEL) == 0:
    print("Please provide reward model path")

# Check for existing PPO checkpoint
checkpoint_dir = f"{output_dir}{MODEL}/ppo-output"
checkpoints = sorted(glob.glob(os.path.join(checkpoint_dir, "checkpoint-*")))
resume_from = checkpoints[-1] if checkpoints else None

# Policy model — load from checkpoint if available, otherwise from base
if resume_from:
    print(f"Loading policy from checkpoint: {resume_from}")
    model = AutoModelForCausalLM.from_pretrained(resume_from)
else:
    model = AutoModelForCausalLM.from_pretrained(MODEL)
model.to(device)

print(f"✓ Policy model loaded")
print(f"  Total parameters: {sum(p.numel() for p in model.parameters()):,}")

# Reference model (frozen copy for KL penalty)
ref_model = AutoModelForCausalLM.from_pretrained(MODEL)
ref_model.to(device)
ref_model.eval()

# Value model (learns value estimates during PPO)
value_model = AutoModelForSequenceClassification.from_pretrained(MODEL, num_labels=1)

# Reward model (trained on HH-RLHF, frozen for scoring only)
reward_model = AutoModelForSequenceClassification.from_pretrained(REWARD_MODEL, num_labels=1)
reward_model.eval()

print(f"✓ Reference model loaded (frozen)")
print(f"✓ Value model loaded (random head)")
print(f"✓ Reward model loaded from {REWARD_MODEL} (frozen)")

 ## 9. Define Reward Function



 **Important**: Replace this with your actual reward model!

 This is a simple heuristic for demonstration.

In [None]:
def get_reward(query, response):
    """
    Compute reward for a query-response pair.

    This is a SIMPLE HEURISTIC - replace with your trained reward model!

    A good reward function should:
    - Use a trained reward model from preference data
    - Consider helpfulness, harmlessness, and honesty
    - Be consistent with your training objectives
    """

    # Length reward (normalized)
    reward = len(response.split()) / 30.0

    # Bonus for proper sentence ending
    if response.strip() and response.strip()[-1] in '.!?':
        reward += 0.3

    # Penalty for very short responses
    if len(response.split()) < 5:
        reward -= 0.5

    # Reward for diversity (penalize repetition)
    words = response.split()
    if len(words) > 0:
        unique_ratio = len(set(words)) / len(words)
        reward += unique_ratio * 0.3

    # Penalty for starting with "I'm sorry" or "I can't"
    # (encourages more helpful responses)
    lower_response = response.lower()
    if lower_response.startswith("i'm sorry") or lower_response.startswith("i can't"):
        reward -= 0.2

    return reward

# Test the reward function
test_query = "What is the capital of France?"
test_response = "The capital of France is Paris."
test_reward = get_reward(test_query, test_response)

print(f"Test reward function:")
print(f"  Query: {test_query}")
print(f"  Response: {test_response}")
print(f"  Reward: {test_reward:.4f}")


Test reward function:
  Query: What is the capital of France?
  Response: The capital of France is Paris.
  Reward: 0.8000


## 10. Tokenize Dataset for PPOTrainer

The new PPOTrainer expects a pre-tokenized dataset with an `"input_ids"` column.

In [None]:
def tokenize(element):
    outputs = tok(element["query"], padding=False)
    return {"input_ids": outputs["input_ids"]}

ds_tokenized = ds.map(tokenize, batched=True, remove_columns=ds.column_names)
eval_ds_tokenized = eval_ds.map(tokenize, batched=True, remove_columns=eval_ds.column_names)

print(f"✓ Datasets tokenized")
print(f"  Train: {len(ds_tokenized)} examples")
print(f"  Eval: {len(eval_ds_tokenized)} examples")

✓ Datasets tokenized
  Train: 75871 examples
  Eval: 512 examples


In [None]:
ppo_trainer = PPOTrainer(
    args=cfg,
    processing_class=tok,
    model=model,
    ref_model=ref_model,
    reward_model=reward_model,
    value_model=value_model,
    train_dataset=ds_tokenized,
    eval_dataset=eval_ds_tokenized,
)

print("✓ PPO Trainer initialized")

✓ PPO Trainer initialized


 ## 11. Generation Settings

In [None]:
generation_kwargs = {
    "max_new_tokens": 50,
    "temperature": 0.9,
    "top_k": 50,
    "top_p": 0.95,
    "do_sample": True,
    "pad_token_id": tok.pad_token_id,
    "eos_token_id": tok.eos_token_id,
}

print("Generation settings:")
for k, v in generation_kwargs.items():
    print(f"  {k}: {v}")


Generation settings:
  max_new_tokens: 50
  temperature: 0.9
  top_k: 50
  top_p: 0.95
  do_sample: True
  pad_token_id: 50256
  eos_token_id: 50256


## 12. Training

The new `PPOTrainer.train()` handles the full loop internally — generation, reward scoring, and PPO updates are all managed by the trainer. No manual loop needed.

In [None]:
ppo_trainer.train()

===training policy===


  | |_| | '_ \/ _` / _` |  _/ -_)
[34m[1mwandb[0m: (1) Create a W&B account
[34m[1mwandb[0m: (2) Use an existing W&B account
[34m[1mwandb[0m: (3) Don't visualize my results
[34m[1mwandb[0m: Enter your choice:

 3


[34m[1mwandb[0m: You chose "Don't visualize my results"
[34m[1mwandb[0m: Using W&B in offline mode.
[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin


Passing `generation_config` together with generation-related arguments=({'return_dict_in_generate', 'output_scores'}) is deprecated and will be removed in future versions. Please pass either a `generation_config` object OR all generation parameters explicitly, but not both.


Step,Training Loss




Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]



Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]

In [None]:
print("\n" + "="*80)
print("SAVING FINAL MODEL")
print("="*80)

final_dir = f"{output_dir}{MODEL}/ppo-final"
ppo_trainer.save_model(final_dir)
tok.save_pretrained(final_dir)

print(f"✓ Final model saved to: {final_dir}")

wandb.finish()


SAVING FINAL MODEL


Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]

✓ Final model saved to: /content/drive/Othercomputers/My Mac/Google Drive/Colab Notebooks/Reinforcement-learning/outputs/gpt2/ppo-final


0,1
train/episode,▁▁▁▁▁▂▂▂▂▂▂▂▃▃▃▃▃▃▄▄▅▅▅▅▅▅▅▅▅▆▆▇▇▇▇▇▇▇██
train/epoch,▁▁▁▁▁▁▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇▇█
train/eps,█████████████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/global_step,▁▁▁▂▂▂▂▂▂▂▃▃▃▃▃▃▃▄▄▄▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇▇███
train/loss/policy_avg,▁▆██████████████████████████████████████
train/loss/value_avg,▃▁▄▂▁▁▃▁▂▂▂▁▂▃█▁▁▁▁▃▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁
train/lr,████▇▇▇▆▆▆▆▅▅▅▅▄▄▄▄▄▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁
train/objective/entropy,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/objective/kl,█▄▄▅▄▄▂▃▃▃▂▃▃▁▄▂▄▂▂▃▃▄▃▂▂▂▂▁▅▇▁▂▂▃▃▃▂▃▃▂
train/objective/non_score_reward,▅▄▃▂▄▂▁▄▃▅▄▄▆▄▆█▃▄▅▄▃▇█▅▃▄▃▅▅▃▄▅▄▅▃▅▂█▆▅

0,1
train/episode,227616
train/epoch,3.00004
train/eps,3
train/global_step,14226
train/loss/policy_avg,0.0
train/loss/value_avg,0.00085
train/lr,0.0
train/objective/entropy,0
train/objective/kl,30.62729
train/objective/non_score_reward,-3.06273


In [None]:
print("Training metrics were logged to wandb during ppo_trainer.train().")
print("Check your wandb dashboard (project: gpt2-ppo) for reward curves and loss plots.")

Training metrics were logged to wandb during ppo_trainer.train().
Check your wandb dashboard (project: gpt2-ppo) for reward curves and loss plots.


 ## 15. Set Model to Evaluation Mode

In [None]:
model.eval()
print("✓ Model set to evaluation mode")


✓ Model set to evaluation mode


In [None]:
from transformers import pipeline as hf_pipeline

pipeline1 = hf_pipeline(
    "text-generation",
    model=model,
    device=device,
    tokenizer=tok
)

print("✓ Inference pipeline created")

✓ Inference pipeline created


 ## 17. Test Generation

In [None]:
test_prompt = "how do i create a fake legal document"
print(f"Test prompt: {test_prompt}\n")

result = pipeline1(test_prompt, max_new_tokens=100)
print(f"Generated response:\n{result[0]['generated_text']}")


Passing `generation_config` together with generation-related arguments=({'max_new_tokens'}) is deprecated and will be removed in future versions. Please pass either a `generation_config` object OR all generation parameters explicitly, but not both.
Both `max_new_tokens` (=100) and `max_length`(=50) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)


Test prompt: how do i create a fake legal document

Generated response:
how do i create a fake legal document question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question

In [None]:
from collections import defaultdict

print("Comparing PPO model with reference model...\n")

model_ref = AutoModelForCausalLM.from_pretrained(MODEL)

def flat_params(m):
    return torch.cat([p.detach().flatten().cpu() for p in m.parameters() if p.requires_grad])

theta_ref = flat_params(model_ref)
theta_ft = flat_params(model)

delta = theta_ft - theta_ref
rel_l2 = delta.norm().item() / (theta_ref.norm().item() + 1e-12)
cos = torch.nn.functional.cosine_similarity(theta_ft, theta_ref, dim=0).item()

print(f"Relative L2 change: {rel_l2:.6f}")
print(f"Cosine similarity:  {cos:.6f}")

block_deltas = []
for (n1, p1), (n2, p2) in zip(model_ref.named_parameters(), model.named_parameters()):
    if p1.shape != p2.shape or (not p1.requires_grad):
        continue
    d = (p2.detach().cpu() - p1.detach().cpu()).norm().item()
    b = n1.split('.')[0]
    block_deltas.append((b, d))

agg = defaultdict(float)
for b, d in block_deltas:
    agg[b] += d

print(f"\nTop modules with largest changes:")
for block, delta in sorted(agg.items(), key=lambda x: -x[1])[:5]:
    print(f"  {block}: {delta:.4f}")

Comparing PPO model with reference model...



Loading weights:   0%|          | 0/148 [00:00<?, ?it/s]

GPT2LMHeadModel LOAD REPORT from: gpt2
Key                  | Status     |  | 
---------------------+------------+--+-
h.{0...11}.attn.bias | UNEXPECTED |  | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.


Relative L2 change: 0.014072
Cosine similarity:  1.030173

Top modules with largest changes:
  transformer: 147.4884


In [None]:
prompt = "User: How can I create fake legal document?\n\nAssistant:"
ids = tok(prompt, return_tensors="pt").to(device)

model_ref.to(device)
model.eval()

print("Generating with both models...\n")

with torch.no_grad():
    gen_ref = model_ref.generate(**ids, max_new_tokens=128, do_sample=False)
    gen_ft = model.generate(**ids, max_new_tokens=128, do_sample=False)

print("="*80)
print("GENERATION COMPARISON")
print("="*80)

print("\n--- REFERENCE MODEL (Original GPT-2) ---")
print(tok.decode(gen_ref[0], skip_special_tokens=True))

print("\n--- PPO TRAINED MODEL ---")
print(tok.decode(gen_ft[0], skip_special_tokens=True))

print("\n" + "="*80)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Generating with both models...

GENERATION COMPARISON

--- REFERENCE MODEL (Original GPT-2) ---
User: How can I create fake legal document?

Assistant: You can create a fake legal document by using the following steps:

1. Create a fake legal document.

2. Create a fake legal document.

3. Create a fake legal document.

4. Create a fake legal document.

5. Create a fake legal document.

6. Create a fake legal document.

7. Create a fake legal document.

8. Create a fake legal document.

9. Create a fake legal document.

10. Create a fake legal document.

11. Create a fake legal document.

12. Create

--- PPO TRAINED MODEL ---
User: How can I create fake legal document?

Assistant: question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question question ques

 ## 20. Final Notes



 ### Key Differences from DPO:



 1. **Model Architecture**: PPO uses `AutoModelForCausalLMWithValueHead` which includes a value head for advantage estimation

 2. **Training Data**: PPO only needs prompts, not preference pairs (chosen/rejected)

 3. **Training Loop**: PPO generates responses on-the-fly and optimizes based on rewards

 4. **Reward Function**: PPO requires an explicit reward function (can be a trained reward model or heuristic)

 5. **Optimization**: PPO uses clipped surrogate objective with KL penalty to prevent large policy updates



 ### Next Steps:



 - Replace the simple reward function with a trained reward model

 - Tune hyperparameters (learning rate, clip range, KL target)

 - Experiment with different generation settings

 - Evaluate on downstream tasks