## Task Description

Apply Proximal Policy Optimization to fine-tune the pretrained GPT-2 Medium model so that it excels at simple addition problems with operands from 0 to 100 (or 0 to 50). Define a reward that reflects arithmetic correctness, train the model with the PPO algorithm, and track the key loss components throughout training. Finally, visualize how the policy loss, value loss, and entropy loss evolve over the course of optimization.

## Code

In [1]:
%load_ext autoreload
%autoreload 2

In [None]:
%pip install transformers trl==0.11.3 wandb



In [3]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
import random
from torch.utils.data import Dataset
from tqdm import tqdm
import re
import math

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")

In [5]:
model_name = "gpt2-medium"

In [6]:
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLMWithValueHead.from_pretrained(model_name)
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(model_name)



In [7]:
MIN_OPERAND = 0
MAX_OPERAND = 50
DATASET_SIZE = 20000

class AdditionDataset(Dataset):
    def __init__(self, tokenizer, min_operand=0, max_operand=100, num_samples=1000):
        self.tokenizer = tokenizer
        self.min_operand = min_operand
        self.max_operand = max_operand
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # `idx` is ignored: a fresh random problem is generated on every access
        num1 = random.randint(self.min_operand, self.max_operand)
        num2 = random.randint(self.min_operand, self.max_operand)
        correct_sum = num1 + num2

        prompt = f"{num1} + {num2} = "

        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").squeeze(0)

        return {
            "query": prompt,
            "input_ids": input_ids,
            "answer_str": str(correct_sum)
        }

def collator(data):
    return dict((key, [d[key] for d in data]) for key in data[0])
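The custom `collator` keeps every field as a Python list instead of stacking it into one tensor, which matches `PPOTrainer.step`, since it expects lists of 1-D query and response tensors. A minimal illustration with two hypothetical samples, where made-up token ID lists stand in for the real `input_ids` tensors:

```python
def collator(data):
    # Same logic as the notebook's collator: list of per-sample dicts -> dict of per-field lists
    return {key: [d[key] for d in data] for key in data[0]}

# Hypothetical samples; the token IDs are made up, real ones are 1-D tensors from the tokenizer
samples = [
    {"query": "3 + 4 = ", "input_ids": [101, 102, 103], "answer_str": "7"},
    {"query": "12 + 9 = ", "input_ids": [104, 105, 106], "answer_str": "21"},
]
batch = collator(samples)
print(batch["query"])       # ['3 + 4 = ', '12 + 9 = ']
print(batch["answer_str"])  # ['7', '21']
```

This is why the training loop can iterate over `batch["input_ids"]` one query at a time and index `batch["answer_str"][i]` per sample.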

In [8]:
dataset = AdditionDataset(tokenizer, min_operand=MIN_OPERAND, max_operand=MAX_OPERAND, num_samples=DATASET_SIZE)

In [9]:
config = PPOConfig(
    model_name=model_name,
    learning_rate=1.41e-5,
    log_with="wandb",
    batch_size=32,
    mini_batch_size=16
)

ppo_trainer = PPOTrainer(
    config, model, ref_model, tokenizer, dataset=dataset, data_collator=collator
)

[34m[1mwandb[0m: Currently logged in as: [33mdadra102[0m ([33mdadra102-heinrich-heine-university-d-sseldorf[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin

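With these settings one pass over the dataset is 20000 / 32 = 625 PPO steps, which matches the 625 iterations in the training progress bar. Each step reuses its batch for several optimization epochs split into minibatches. A back-of-the-envelope sketch, assuming trl 0.11's defaults of `ppo_epochs=4` and `gradient_accumulation_steps=1` (neither is set in the config above, so check `config.ppo_epochs` in your install) and ignoring any early stopping:

```python
DATASET_SIZE = 20000
BATCH_SIZE = 32        # PPOConfig.batch_size
MINI_BATCH_SIZE = 16   # PPOConfig.mini_batch_size
PPO_EPOCHS = 4         # assumed trl default, not set explicitly above

ppo_steps = DATASET_SIZE // BATCH_SIZE                  # batches per pass over the dataset
minibatches_per_epoch = BATCH_SIZE // MINI_BATCH_SIZE   # minibatches per PPO epoch
updates_per_step = PPO_EPOCHS * minibatches_per_epoch   # optimizer updates per PPO step
total_updates = ppo_steps * updates_per_step

print(ppo_steps, updates_per_step, total_updates)  # 625 8 5000
```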

In [10]:
def extract_single_number_from_str(text: str) -> int | None:
    """
    Extracts a number from a string.
    If the string contains multiple or no numbers, return None.
    """
    numbers_found = re.findall(r"\b\d+\b", text)
    if len(numbers_found) == 1:
        return int(numbers_found[0])
    else:
        return None

def calculate_reward(correct_num, generated_str):
    # Allow text in response like: 'The answer is 10'
    answer_num = extract_single_number_from_str(generated_str)

    if answer_num is None:
        return -0.5 # Small penalty for no or multiple numbers
    else:
        # Smooth reward: 10 for an exact answer, decaying exponentially with the absolute error
        error = float(abs(answer_num - correct_num))
        raw_reward = 10.0 * math.exp(-0.3 * error)
        return max(0.0, raw_reward)
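To make the shape of this reward concrete, the sketch below re-implements both functions (same logic, with the parameter renamed and the clamp dropped) and prints the reward against the absolute error for a made-up target of 40. The reward is 10 for an exact answer, about 7.41 when off by one and below 0.5 from an error of 10 onward. Since `math.exp` is always positive, the clamp in `max(0.0, raw_reward)` never applies, and any completion with exactly one number, however wrong, outscores the -0.5 given to completions with zero or several numbers.

```python
import math
import re

def extract_single_number_from_str(text):
    """Return the only integer in `text`, or None if there are zero or several."""
    numbers_found = re.findall(r"\b\d+\b", text)
    return int(numbers_found[0]) if len(numbers_found) == 1 else None

def calculate_reward(correct_num, generated_str):
    answer_num = extract_single_number_from_str(generated_str)
    if answer_num is None:
        return -0.5  # no number, or more than one number
    error = float(abs(answer_num - correct_num))
    return 10.0 * math.exp(-0.3 * error)  # always > 0, so no clamp is needed

# Reward as a function of the absolute error, for a made-up correct sum of 40
for err in (0, 1, 2, 5, 10, 20):
    print(err, round(calculate_reward(40, f"The answer is {40 + err}"), 3))
# Prints:
# 0 10.0
# 1 7.408
# 2 5.488
# 5 2.231
# 10 0.498
# 20 0.025

print(calculate_reward(40, "forty"))     # -0.5 (no digits)
print(calculate_reward(40, "40 or 41"))  # -0.5 (two numbers)
```

One consequence of the exponential shape: far-off answers (an error of 20 versus 80) score almost the same, so the reward mainly distinguishes answers that are already close to the correct sum.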

In [11]:
generation_kwargs = {
    "min_length": -1,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id,
    "max_new_tokens": 5,  # sums have at most three digits, so a few new tokens suffice
}

mean_rewards = []
for step, batch in enumerate(tqdm(ppo_trainer.dataloader)):
    query_tensors = batch["input_ids"]

    #### Get response from GPT-2
    response_tensors = []
    for query in query_tensors:
        # For a single 1-D query, generate() returns prompt + completion with shape (1, seq_len)
        query_response = ppo_trainer.generate(query, **generation_kwargs).squeeze()

        # Keep only the newly generated tokens
        response_tensors.append(query_response[len(query):])

    batch["response"] = [tokenizer.decode(r) for r in response_tensors]

    # Calculate rewards
    rewards = []
    for i in range(len(batch["query"])): # Iterate through each sample in the batch
        correct_num = int(batch["answer_str"][i])
        generated_output_str = batch["response"][i]
        current_reward_val = calculate_reward(correct_num, generated_output_str)
        rewards.append(torch.tensor(current_reward_val, dtype=torch.float, device=ppo_trainer.accelerator.device))

    #### Run PPO step
    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
    #### Logging
    ppo_trainer.log_stats(stats, batch, rewards)
    mean_rewards.append(torch.mean(torch.stack(rewards)).item())

The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
100%|██████████| 625/625 [59:53<00:00,  5.75s/it]

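The task asks for the policy loss, value loss and entropy over training, but the loop above keeps only `mean_rewards` locally; the losses reach wandb through `log_stats`. A minimal sketch for also keeping them in the notebook is to call a helper like the hypothetical `record_stats` below right after `ppo_trainer.step(...)`. The key names are an assumption based on what trl 0.11 logs (`ppo/loss/policy`, `ppo/loss/value`, `ppo/policy/entropy`); print `stats.keys()` to confirm them in your version. In trl 0.11 the PPO loss is the policy loss plus the weighted value loss, so the entropy is a logged diagnostic rather than a term being optimized.

```python
from collections import defaultdict

# Assumed trl 0.11 stat keys; confirm with print(stats.keys())
LOSS_KEYS = {
    "policy_loss": "ppo/loss/policy",
    "value_loss": "ppo/loss/value",
    "entropy": "ppo/policy/entropy",
}

def to_float(value):
    """Convert a float, NumPy scalar/array or tensor to a Python float (averaging if it has several entries)."""
    if hasattr(value, "mean"):
        value = value.mean()
    return float(value)

def record_stats(history, stats, keys=LOSS_KEYS):
    """Append the selected entries of one PPO `stats` dict to `history` (name -> list of floats)."""
    for name, key in keys.items():
        if key in stats:
            history[name].append(to_float(stats[key]))
    return history

# Hypothetical wiring inside the training loop:
#     loss_history = defaultdict(list)   # once, before the loop
#     stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
#     record_stats(loss_history, stats)

# Stand-alone check with made-up stats dicts
loss_history = defaultdict(list)
record_stats(loss_history, {"ppo/loss/policy": -0.02, "ppo/loss/value": 1.5, "ppo/policy/entropy": 2.0})
record_stats(loss_history, {"ppo/loss/policy": 0.01, "ppo/loss/value": 0.5, "ppo/policy/entropy": 1.0})
print(dict(loss_history))
# {'policy_loss': [-0.02, 0.01], 'value_loss': [1.5, 0.5], 'entropy': [2.0, 1.0]}
```

Each list can then be plotted with `plt.plot(loss_history["policy_loss"])` in the same way as `mean_rewards`.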

## Visualization

After one pass over the 625 batches (about an hour), I ran out of compute on Google Colab before the model had learned to add reliably. Here is a screenshot of the run's charts from wandb:

<img src="https://drive.google.com/uc?id=1Ewoirnf0n2VWylbqsHVkny0CFzVSy-kc">

In [None]:
import matplotlib.pyplot as plt
plt.figure(figsize=(12, 6))
plt.plot(mean_rewards, label="Mean Reward per Batch")
plt.xlabel("Training Steps (Batches)")
plt.ylabel("Mean Reward")
plt.title("Mean Reward During PPO Training")
plt.legend()
plt.grid(True)
plt.show()
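Each point in this plot is the mean over only 32 samples, so the curve is noisy. A trailing moving average makes the trend easier to read; the sketch below is a plain-Python helper whose output can be plotted next to `mean_rewards` (or any other per-step series). The window of 25 batches is an arbitrary choice.

```python
def moving_average(values, window=25):
    """Trailing moving average; the first window-1 points average over what is available so far."""
    smoothed = []
    running_sum = 0.0
    for i, v in enumerate(values):
        running_sum += v
        if i >= window:
            running_sum -= values[i - window]  # drop the value that left the window
        smoothed.append(running_sum / min(i + 1, window))
    return smoothed

# Example on a short made-up series
print(moving_average([0.0, 2.0, 4.0, 6.0], window=2))  # [0.0, 1.0, 3.0, 5.0]

# In the notebook: plt.plot(moving_average(mean_rewards), label="Moving average (25 batches)")
```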

In [12]:
model.eval()

test_prompts = [
    "10 + 5 = ",
    "25 + 30 = ",
    "7 + 88 = ",
    "50 + 49 = ",
    "1 + 1 = "
]

inference_generation_kwargs = {
    "min_length": -1,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id,
    "max_new_tokens": 10,
}

with torch.no_grad():
    for prompt_text in test_prompts:
        print(f"Prompt: {prompt_text}")

        input_ids = tokenizer.encode(prompt_text, return_tensors="pt").to(device)

        output_sequences = model.generate(
            input_ids=input_ids,
            **inference_generation_kwargs
        )

        generated_text_full = tokenizer.decode(output_sequences[0], skip_special_tokens=True)

        generated_text_answer_only = generated_text_full[len(prompt_text):]

        print(f"Generated: {generated_text_answer_only.strip()}")
        print("-" * 20)

Prompt: 10 + 5 = 
Generated: 16 of the way to the previous exhaust port
--------------------
Prompt: 25 + 30 = 
Generated: 56
Evaluation sample, it should
--------------------
Prompt: 7 + 88 = 
Generated: 15
If you have done your calculations,
--------------------
Prompt: 50 + 49 = 
Generated: 94% of the world's Shar校
--------------------
Prompt: 1 + 1 = 
Generated: 2. In the N +1, it
--------------------
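Reading these samples by eye, only `1 + 1 = ` begins with the correct sum, `10 + 5` and `25 + 30` are off by one, and every completion continues with unrelated text. Note also that `calculate_reward` would score the correct completion `2. In the N +1, it` at -0.5, because it contains two numbers. A stricter exact-match check condenses this into one accuracy figure; the sketch below assumes that only the leading integer of a completion counts as the answer (a choice made here, not part of the training reward), and the five pairs are copied from the output above.

```python
import re

def leading_int(text):
    """Return the integer the completion starts with (after whitespace), or None."""
    match = re.match(r"\s*(\d+)", text)
    return int(match.group(1)) if match else None

def exact_match_accuracy(pairs):
    """pairs: (prompt like 'a + b = ', completion); fraction whose leading integer equals a + b."""
    correct = 0
    for prompt, completion in pairs:
        a, b = (int(x) for x in re.findall(r"\d+", prompt))
        correct += leading_int(completion) == a + b
    return correct / len(pairs)

# Prompts and (stripped) generations from the run above
logged = [
    ("10 + 5 = ", "16 of the way to the previous exhaust port"),
    ("25 + 30 = ", "56\nEvaluation sample, it should"),
    ("7 + 88 = ", "15\nIf you have done your calculations,"),
    ("50 + 49 = ", "94% of the world's Shar校"),
    ("1 + 1 = ", "2. In the N +1, it"),
]
print(exact_match_accuracy(logged))  # 0.2
```

Because the test cell samples (`do_sample=True` with `top_k=0`, `top_p=1.0`), these five completions are a single random draw; a larger prompt set, or greedy decoding with `do_sample=False`, would give a steadier estimate.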
