# Week 2: RL Training Methods for Language Models

**MSc Course -- Generative Models in Finance**

This notebook walks through the key training paradigms used to align and improve language models after pretraining:

| Section | Method | Key Idea |
|---------|--------|----------|
| 1 | **SFT** | Supervised fine-tuning on instruction-response pairs |
| 2 | **REINFORCE** | Classic policy-gradient with a simple reward heuristic |
| 3 | **PPO** | Proximal Policy Optimisation with GAE and clipped objective |
| 4 | **GRPO** | Group Relative Policy Optimisation (critic-free) |
| 5 | **Bradley-Terry Reward Model** | Learn a reward model from preference data |

All models are tiny (~1M parameters) and train on CPU in under 2 minutes.

**Dependencies:** `torch`, `numpy`, `matplotlib`

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import copy
import math
from torch.utils.data import DataLoader, Dataset

# Seed the PyTorch and NumPy random number generators.
torch.manual_seed(42)
np.random.seed(42)

# Every model and tensor in this notebook is placed on this device (CPU).
DEVICE = torch.device("cpu")
print(f"Using device: {DEVICE}")
print(f"PyTorch version: {torch.__version__}")

Using device: cpu
PyTorch version: 2.6.0


## Shared Infrastructure: Tiny GPT-2 Model and Tokeniser

We build a minimal GPT-2-style transformer from scratch. It has:
- A small vocabulary (character-level + special tokens)
- 4 transformer layers, 4 attention heads, embedding dim 128
- ~1M parameters

This is shared across all experiments.

In [None]:
# ---------------------------------------------------------------------------
# Simple character-level tokeniser
# ---------------------------------------------------------------------------
class CharTokeniser:
    """Map printable ASCII characters to integer ids and back.

    Ids 0-3 are reserved for the special tokens ``<pad>``, ``<bos>``,
    ``<eos>`` and ``<unk>``; the 95 printable ASCII characters (code points
    32-126) follow in code-point order starting at id 4.
    """

    def __init__(self):
        self.special = {"<pad>": 0, "<bos>": 1, "<eos>": 2, "<unk>": 3}
        first_char_id = len(self.special)
        # Special tokens first, then one entry per printable ASCII character.
        self.char2id = dict(self.special)
        self.char2id.update(
            {chr(code): first_char_id + (code - 32) for code in range(32, 127)}
        )
        self.id2char = {token_id: token for token, token_id in self.char2id.items()}
        self.vocab_size = len(self.char2id)
        self.pad_id = self.special["<pad>"]
        self.bos_id = self.special["<bos>"]
        self.eos_id = self.special["<eos>"]

    def encode(self, text: str, add_bos=True, add_eos=True) -> list[int]:
        """Return the ids for *text*, optionally wrapped in ``<bos>``/``<eos>``.

        Characters outside the vocabulary are mapped to the ``<unk>`` id.
        """
        unk_id = self.special["<unk>"]
        prefix = [self.bos_id] if add_bos else []
        suffix = [self.eos_id] if add_eos else []
        return prefix + [self.char2id.get(ch, unk_id) for ch in text] + suffix

    def decode(self, ids: list[int]) -> str:
        """Return the text for *ids*, dropping special tokens and unknown ids."""
        tokens = (self.id2char.get(token_id, "") for token_id in ids)
        return "".join(tok for tok in tokens if tok not in self.special)


# Build the tokeniser instance used by the rest of the notebook and print an
# encode/decode round trip as a sanity check.
tokeniser = CharTokeniser()
VOCAB_SIZE = tokeniser.vocab_size  # passed to every TinyGPT2 constructed below
print(f"Vocab size: {VOCAB_SIZE}")
print(f"Example encode: {tokeniser.encode('hello')}")
print(f"Example decode: {tokeniser.decode(tokeniser.encode('hello'))}")

In [None]:
# ---------------------------------------------------------------------------
# Tiny GPT-2 Model
# ---------------------------------------------------------------------------
class CausalSelfAttention(nn.Module):
    """Multi-head self-attention where a position only attends to itself and earlier positions."""

    def __init__(self, d_model: int, n_heads: int, max_len: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        # One fused projection yields queries, keys and values together.
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        # Lower-triangular ones of shape (1, 1, max_len, max_len); zeros mark
        # the (query, key) pairs that must not attend.
        lower_triangle = torch.ones(max_len, max_len).tril()
        self.register_buffer("mask", lower_triangle[None, None, :, :])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply causal attention to ``x`` of shape (B, T, C) and return the same shape."""
        batch, seq_len, width = x.shape

        def split_heads(t: torch.Tensor) -> torch.Tensor:
            # (B, T, C) -> (B, n_heads, T, head_dim)
            return t.view(batch, seq_len, self.n_heads, self.head_dim).transpose(1, 2)

        q, k, v = (split_heads(part) for part in self.qkv(x).chunk(3, dim=-1))

        # Scaled dot-product scores; disallowed positions get -inf so that
        # softmax assigns them zero weight.
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        blocked = self.mask[:, :, :seq_len, :seq_len] == 0
        weights = F.softmax(scores.masked_fill(blocked, float("-inf")), dim=-1)
        weights = self.dropout(weights)

        # Merge the heads back into a single (B, T, C) tensor.
        merged = (weights @ v).transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.proj(merged)


class TransformerBlock(nn.Module):
    """Pre-LayerNorm transformer block: attention then MLP, each added back to its input."""

    def __init__(self, d_model: int, n_heads: int, max_len: int, dropout: float = 0.1):
        super().__init__()
        hidden_width = 4 * d_model
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = CausalSelfAttention(d_model, n_heads, max_len, dropout)
        self.ln2 = nn.LayerNorm(d_model)
        feed_forward = [
            nn.Linear(d_model, hidden_width),
            nn.GELU(),
            nn.Linear(hidden_width, d_model),
            nn.Dropout(dropout),
        ]
        self.mlp = nn.Sequential(*feed_forward)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` after the attention and MLP residual updates."""
        after_attention = x + self.attn(self.ln1(x))
        return after_attention + self.mlp(self.ln2(after_attention))


class TinyGPT2(nn.Module):
    """Small decoder-only transformer language model in the style of GPT-2.

    Token and learned position embeddings feed a stack of ``TransformerBlock``
    layers, a final LayerNorm and an output projection whose weight matrix is
    shared with the token embedding.
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 128,
        n_heads: int = 4,
        n_layers: int = 4,
        max_len: int = 128,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.max_len = max_len
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_len, d_model)
        self.drop = nn.Dropout(dropout)
        layers = []
        for _ in range(n_layers):
            layers.append(TransformerBlock(d_model, n_heads, max_len, dropout))
        self.blocks = nn.ModuleList(layers)
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)
        # Share one parameter between the input embedding and output projection.
        self.head.weight = self.tok_emb.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Draw Linear/Embedding weights from N(0, 0.02^2) and zero Linear biases."""
        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return
        nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            nn.init.zeros_(module.bias)

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        """Map token ids of shape (B, T) to next-token logits of shape (B, T, vocab_size)."""
        _, seq_len = idx.shape
        assert seq_len <= self.max_len, f"Sequence length {seq_len} > max_len {self.max_len}"
        positions = torch.arange(0, seq_len, device=idx.device).unsqueeze(0)  # (1, T)
        hidden = self.drop(self.tok_emb(idx) + self.pos_emb(positions))
        for layer in self.blocks:
            hidden = layer(hidden)
        return self.head(self.ln_f(hidden))

    @torch.no_grad()
    def generate(
        self, idx: torch.Tensor, max_new_tokens: int, temperature: float = 1.0
    ) -> torch.Tensor:
        """Append ``max_new_tokens`` sampled tokens to ``idx`` and return the result.

        Each step feeds at most the last ``max_len`` tokens to the model,
        divides the final-position logits by ``temperature`` and samples one
        token per row. Sampling always runs for the full ``max_new_tokens``
        steps; there is no early stop.
        """
        tokens = idx
        for _ in range(max_new_tokens):
            window = tokens[:, -self.max_len :]
            last_logits = self(window)[:, -1, :]
            sampled = torch.multinomial(
                F.softmax(last_logits / temperature, dim=-1), num_samples=1
            )
            tokens = torch.cat([tokens, sampled], dim=1)
        return tokens


def count_parameters(model: nn.Module) -> int:
    """Return the number of scalar entries across the model's trainable parameters."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


# Instantiate and inspect
# A throwaway model with the default hyperparameters is built only to print
# its trainable-parameter count, then deleted.
_test_model = TinyGPT2(VOCAB_SIZE)
print(f"TinyGPT2 parameters: {count_parameters(_test_model):,}")
del _test_model

---
## 1. Supervised Fine-Tuning (SFT)

SFT is the first step after pretraining. We take a set of **(instruction, response)** pairs and
fine-tune the model with **teacher-forced cross-entropy loss**: at each position $t$ the model
predicts the next token given the ground-truth prefix.

$$\mathcal{L}_{\text{SFT}} = -\frac{1}{T}\sum_{t=1}^{T} \log p_\theta(y_t \mid y_{<t})$$

We train on a small synthetic dataset of instruction-response pairs relevant to finance.

In [None]:
# ---------------------------------------------------------------------------
# 1a. SFT Dataset
# ---------------------------------------------------------------------------
# Hand-written (instruction, response) string pairs on finance terms.
# SFTDataset below joins each pair into one "Q: ... A: ..." training string.
SFT_DATA = [
    ("What is a bond?", "A bond is a fixed-income instrument representing a loan made by an investor to a borrower."),
    ("Define stock.", "A stock is a security that represents ownership of a fraction of a corporation."),
    ("What is volatility?", "Volatility measures the degree of variation in a trading price over time."),
    ("Explain diversification.", "Diversification is a risk management strategy mixing a variety of investments."),
    ("What is a derivative?", "A derivative is a financial contract whose value depends on an underlying asset."),
    ("Define yield.", "Yield is the income return on an investment, typically expressed as a percentage."),
    ("What is liquidity?", "Liquidity refers to how quickly an asset can be converted to cash without loss."),
    ("Explain hedging.", "Hedging is an investment to reduce the risk of adverse price movements in an asset."),
    ("What is a put option?", "A put option gives the holder the right to sell an asset at a specified price."),
    ("Define alpha.", "Alpha is the excess return of an investment relative to its benchmark index."),
    ("What is a call option?", "A call option gives the holder the right to buy an asset at a specified price."),
    ("Explain leverage.", "Leverage is the use of borrowed capital to increase the potential return on investment."),
    ("What is beta?", "Beta measures the sensitivity of an asset's returns to the overall market returns."),
    ("Define portfolio.", "A portfolio is a collection of financial investments like stocks, bonds, and cash."),
    ("What is arbitrage?", "Arbitrage is the simultaneous purchase and sale of an asset to profit from price differences."),
    ("Explain short selling.", "Short selling involves borrowing shares and selling them, hoping to buy back at lower price."),
]


class SFTDataset(Dataset):
    """Dataset of tokenised ``"Q: <instruction> A: <response>"`` strings.

    Each sample is a 1-D ``torch.long`` tensor of token ids wrapped in
    ``<bos>``/``<eos>`` and cut to at most ``max_len`` ids (so a truncated
    sample loses its trailing ``<eos>``).
    """

    def __init__(self, data, tokeniser, max_len=128):
        def to_sample(instruction, response):
            token_ids = tokeniser.encode(
                f"Q: {instruction} A: {response}", add_bos=True, add_eos=True
            )
            return torch.tensor(token_ids[:max_len], dtype=torch.long)

        self.samples = [to_sample(question, answer) for question, answer in data]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]


def collate_pad(batch, pad_id=0):
    """Stack variable-length 1-D id tensors into one (len(batch), longest) long tensor.

    Shorter sequences are right-padded with ``pad_id``.
    """
    lengths = [len(seq) for seq in batch]
    out = torch.full((len(batch), max(lengths)), pad_id, dtype=torch.long)
    for row, (seq, length) in enumerate(zip(batch, lengths)):
        out[row, :length] = seq
    return out


# Tokenise the pairs and serve them in shuffled batches of 8; collate_pad
# right-pads each batch to its longest sample with id 0.
sft_dataset = SFTDataset(SFT_DATA, tokeniser)
sft_loader = DataLoader(sft_dataset, batch_size=8, shuffle=True, collate_fn=collate_pad)
print(f"SFT dataset size: {len(sft_dataset)} samples")
print(f"Example (decoded): {tokeniser.decode(sft_dataset[0].tolist())}")

In [None]:
# ---------------------------------------------------------------------------
# 1b. SFT Training Loop
# ---------------------------------------------------------------------------
# The model starts from random initialisation here: no pretrained checkpoint
# is loaded in this cell.
sft_model = TinyGPT2(VOCAB_SIZE).to(DEVICE)
sft_optimiser = torch.optim.AdamW(sft_model.parameters(), lr=3e-4, weight_decay=0.01)

NUM_SFT_EPOCHS = 60
sft_losses = []  # one entry per epoch: mean of the batch losses

sft_model.train()  # training mode, so the dropout layers are active
for epoch in range(NUM_SFT_EPOCHS):
    epoch_loss = 0.0
    n_batches = 0
    for batch in sft_loader:
        batch = batch.to(DEVICE)
        # Teacher forcing: input = tokens[:-1], target = tokens[1:]
        inputs = batch[:, :-1]
        targets = batch[:, 1:]
        logits = sft_model(inputs)  # (B, T-1, V)
        # Flatten for cross-entropy; ignore padding (id=0)
        # Only pad targets are excluded, so the "Q: ..." prompt characters
        # contribute to the loss as well as the answer.
        loss = F.cross_entropy(
            logits.reshape(-1, VOCAB_SIZE),
            targets.reshape(-1),
            ignore_index=tokeniser.pad_id,
        )
        sft_optimiser.zero_grad()
        loss.backward()
        # Rescale gradients so their global norm is at most 1.0.
        torch.nn.utils.clip_grad_norm_(sft_model.parameters(), 1.0)
        sft_optimiser.step()
        epoch_loss += loss.item()
        n_batches += 1
    avg_loss = epoch_loss / n_batches
    sft_losses.append(avg_loss)
    # Log every 10th epoch.
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1:3d}/{NUM_SFT_EPOCHS}  loss={avg_loss:.4f}")

print("\nSFT training complete.")

In [None]:
# ---------------------------------------------------------------------------
# 1c. SFT Training Loss Curve and Sample Generation
# ---------------------------------------------------------------------------
# Plot the per-epoch mean loss recorded in sft_losses.
fig, ax = plt.subplots(1, 1, figsize=(7, 3.5))
ax.plot(sft_losses, linewidth=1.5)
ax.set_xlabel("Epoch")
ax.set_ylabel("Cross-Entropy Loss")
ax.set_title("SFT Training Loss")
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

# Generate a sample
sft_model.eval()  # evaluation mode: dropout is disabled while sampling
# The prompt uses the same "Q: ... A:" layout as the training strings and is
# encoded without <eos> so the model continues it.
prompt = "Q: What is a bond? A:"
prompt_ids = torch.tensor([tokeniser.encode(prompt, add_bos=True, add_eos=False)], device=DEVICE)
gen_ids = sft_model.generate(prompt_ids, max_new_tokens=60, temperature=0.7)
print(f"Prompt:    {prompt}")
# gen_ids holds the prompt followed by the sampled tokens, so the decoded
# string starts with the prompt text.
print(f"Generated: {tokeniser.decode(gen_ids[0].tolist())}")

---
## 2. REINFORCE (Policy Gradient)

REINFORCE treats the language model as a **policy** $\pi_\theta$ that generates token sequences
(actions). A scalar **reward** $R$ is assigned to each complete sequence.

The policy gradient estimator is:

$$\nabla_\theta J = \mathbb{E}_{\tau \sim \pi_\theta}\left[\sum_{t=1}^{T} \nabla_\theta \log \pi_\theta(a_t \mid s_t) \cdot (R(\tau) - b)\right]$$

where $b$ is an optional **baseline** (e.g. the running mean reward) to reduce variance.

**Reward heuristic:** We use a simple reward that encourages the model to generate sequences
that (a) contain a target word like "profit" and (b) have moderate length (penalising too short
or too long outputs).

In [None]:
# ---------------------------------------------------------------------------
# 2a. Reward function and rollout generation
# ---------------------------------------------------------------------------
TARGET_WORD = "profit"


def compute_reward(text: str) -> float:
    """Score *text* with a hand-written heuristic.

    The score is the sum of three terms:
      * +2.0 when the lower-cased text contains ``TARGET_WORD``;
      * -0.02 for every character the length differs from 40;
      * +0.5 when the text, ignoring surrounding whitespace, ends with ".".
    """
    keyword_bonus = 2.0 if TARGET_WORD in text.lower() else 0.0
    length_penalty = 0.02 * abs(len(text) - 40)
    period_bonus = 0.5 if text.strip().endswith(".") else 0.0
    return keyword_bonus - length_penalty + period_bonus


def generate_with_logprobs(
    model: TinyGPT2,
    prompt_ids: torch.Tensor,
    max_new_tokens: int,
    temperature: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Sample a continuation and record the log-probability of every sampled token.

    The model is switched to eval mode and the forward passes run under
    ``torch.no_grad()``, so the returned log-probabilities carry no gradient.

    Returns:
        A pair ``(ids, log_probs)``: ``ids`` is the prompt followed by the
        sampled tokens, shape (B, T + max_new_tokens); ``log_probs`` has
        shape (B, max_new_tokens).
    """
    model.eval()
    tokens = prompt_ids.clone()
    step_log_probs = []
    for _ in range(max_new_tokens):
        # The model only ever sees its last max_len tokens.
        context = tokens[:, -model.max_len :]
        with torch.no_grad():
            next_logits = model(context)[:, -1, :] / temperature
        sampler = torch.distributions.Categorical(F.softmax(next_logits, dim=-1))
        choice = sampler.sample()
        step_log_probs.append(sampler.log_prob(choice))
        tokens = torch.cat([tokens, choice.unsqueeze(1)], dim=1)
    return tokens, torch.stack(step_log_probs, dim=1)  # log-probs: (B, max_new_tokens)


# Quick test
# Smoke-test generate_with_logprobs on a model initialised from the SFT weights.
rl_model = TinyGPT2(VOCAB_SIZE).to(DEVICE)
# Copy SFT weights as starting point
rl_model.load_state_dict(sft_model.state_dict())
# Leading-underscore names mark throwaway values used only by this check.
_p = torch.tensor([tokeniser.encode("Q: Explain profit. A:", add_bos=True, add_eos=False)], device=DEVICE)
_gen, _lp = generate_with_logprobs(rl_model, _p, max_new_tokens=20)
# _gen contains the prompt followed by the 20 sampled tokens.
print(f"Generated: {tokeniser.decode(_gen[0].tolist())}")
print(f"Log-prob shape: {_lp.shape}, sum: {_lp.sum().item():.3f}")

In [None]:
# ---------------------------------------------------------------------------
# 2b. REINFORCE step: sample a batch and build the loss (optional baseline)
# ---------------------------------------------------------------------------
def reinforce_step(
    model: TinyGPT2,
    prompt_ids: torch.Tensor,
    max_new_tokens: int,
    batch_size: int,
    temperature: float = 1.0,
    baseline: float = 0.0,
) -> tuple[torch.Tensor, float, list[float]]:
    """Sample a batch of completions and build the REINFORCE loss for them.

    This function does not call ``backward()`` or touch an optimiser; the
    caller does. As a side effect the model is put into eval mode so that
    dropout is off while sampling.

    Args:
        model: policy to sample from; gradients flow through its log-probs.
        prompt_ids: prompt token ids of shape (1, prompt_len), repeated for
            every sample in the batch.
        max_new_tokens: number of tokens sampled per sequence.
        batch_size: number of sequences sampled from the prompt.
        temperature: divisor applied to the logits before sampling.
        baseline: scalar subtracted from every reward.

    Returns:
        ``(loss, mean_reward, rewards)`` where ``loss`` is a scalar tensor
        attached to the autograd graph, ``mean_reward`` is the batch-mean
        reward as a float and ``rewards`` lists the per-sample rewards.
    """
    prompt_len = prompt_ids.shape[1]
    # Expand prompt for batch
    prompts = prompt_ids.expand(batch_size, -1)

    # Generate with gradients through log-probs
    model.eval()  # keep dropout off for generation
    all_log_probs = []
    idx = prompts.clone()
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -model.max_len :]
        logits = model(idx_cond)[:, -1, :] / temperature
        probs = F.softmax(logits, dim=-1)
        dist = torch.distributions.Categorical(probs)
        next_tok = dist.sample()
        all_log_probs.append(dist.log_prob(next_tok))
        idx = torch.cat([idx, next_tok.unsqueeze(1)], dim=1)

    log_probs = torch.stack(all_log_probs, dim=1)  # (B, max_new_tokens)

    # Score only the sampled completion. Scoring prompt + completion would
    # let the prompt text itself satisfy the target-word check and would
    # count the prompt's characters in the length term.
    rewards = [
        compute_reward(tokeniser.decode(row[prompt_len:].tolist())) for row in idx
    ]
    rewards_t = torch.tensor(rewards, dtype=torch.float32, device=idx.device)

    # REINFORCE loss: -E[log_prob * (R - baseline)]
    advantages = rewards_t - baseline
    per_sample_loss = -(log_probs.sum(dim=1) * advantages)
    loss = per_sample_loss.mean()

    return loss, rewards_t.mean().item(), rewards


print("REINFORCE step function defined.")

In [None]:
# ---------------------------------------------------------------------------
# 2c. Train with REINFORCE -- compare with/without baseline
# ---------------------------------------------------------------------------
def train_reinforce(use_baseline: bool, n_steps: int = 80, batch_size: int = 16, lr: float = 1e-4):
    """Fine-tune a copy of the SFT model with REINFORCE on one fixed prompt.

    Args:
        use_baseline: when true, subtract an exponential moving average of
            past mean rewards from each reward; otherwise subtract 0.
        n_steps: number of optimiser steps.
        batch_size: number of sequences sampled per step.
        lr: Adam learning rate.

    Returns:
        ``(reward_history, grad_norm_history, model)``: the per-step mean
        rewards, the per-step gradient norms reported by
        ``clip_grad_norm_``, and the trained model.
    """
    policy = TinyGPT2(VOCAB_SIZE).to(DEVICE)
    policy.load_state_dict(sft_model.state_dict())
    adam = torch.optim.Adam(policy.parameters(), lr=lr)

    encoded_prompt = tokeniser.encode("Q: Explain profit. A:", add_bos=True, add_eos=False)
    prompt_ids = torch.tensor([encoded_prompt], device=DEVICE)

    tag = "baseline" if use_baseline else "no baseline"
    reward_history, grad_norm_history = [], []
    ema_reward = 0.0  # moving average of mean rewards, starts at zero

    for step in range(n_steps):
        loss, mean_r, _ = reinforce_step(
            policy,
            prompt_ids,
            max_new_tokens=30,
            batch_size=batch_size,
            temperature=1.0,
            baseline=ema_reward if use_baseline else 0.0,
        )
        adam.zero_grad()
        loss.backward()
        # clip_grad_norm_ returns the total gradient norm; this is what is logged.
        total_norm = torch.nn.utils.clip_grad_norm_(policy.parameters(), 5.0)
        grad_norm_history.append(total_norm.item())
        adam.step()

        reward_history.append(mean_r)
        # The average is updated on every step, whether or not it is used.
        ema_reward = 0.95 * ema_reward + 0.05 * mean_r

        if (step + 1) % 20 == 0:
            print(f"  [{tag}] step {step+1:3d}  reward={mean_r:.3f}  grad_norm={total_norm:.3f}")

    return reward_history, grad_norm_history, policy


# Run both variants; each call starts from a fresh copy of the SFT weights.
# The model from the no-baseline run is discarded, the baseline-trained one
# is kept as reinforce_model.
print("Training REINFORCE without baseline...")
rewards_no_bl, grads_no_bl, _ = train_reinforce(use_baseline=False)
print("\nTraining REINFORCE with baseline...")
rewards_bl, grads_bl, reinforce_model = train_reinforce(use_baseline=True)
print("\nDone.")

In [None]:
# ---------------------------------------------------------------------------
# 2d. Visualise REINFORCE results
# ---------------------------------------------------------------------------
# Left panel: mean reward per step; right panel: gradient norm per step,
# each for the run without and with the baseline.
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Reward curves
axes[0].plot(rewards_no_bl, alpha=0.7, label="No baseline")
axes[0].plot(rewards_bl, alpha=0.7, label="With baseline")
axes[0].set_xlabel("Step")
axes[0].set_ylabel("Mean Reward")
axes[0].set_title("REINFORCE: Reward over Training")
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Gradient norm comparison
axes[1].plot(grads_no_bl, alpha=0.7, label="No baseline")
axes[1].plot(grads_bl, alpha=0.7, label="With baseline")
axes[1].set_xlabel("Step")
axes[1].set_ylabel("Gradient Norm")
axes[1].set_title("REINFORCE: Gradient Variance")
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

std_no_bl = np.std(grads_no_bl)
std_bl = np.std(grads_bl)
print(f"\nGrad norm std (no baseline):   {std_no_bl:.3f}")
print(f"Grad norm std (with baseline): {std_bl:.3f}")
# Report the direction that was actually measured instead of always claiming
# a reduction, and skip the ratio when the denominator is zero.
if std_no_bl > 0:
    reduction = (1 - std_bl / std_no_bl) * 100
    if reduction >= 0:
        print(f"-> The baseline reduces gradient norm std by {reduction:.0f}%.")
    else:
        print(f"-> The baseline increases gradient norm std by {-reduction:.0f}% in this run.")
else:
    reduction = float("nan")
    print("-> Grad norm std without baseline is zero; the relative change is undefined.")

In [None]:
# ---------------------------------------------------------------------------
# 2e. Show sample generations from the REINFORCE-trained model
# ---------------------------------------------------------------------------
# Sample five completions from the REINFORCE-trained model and print each
# with its reward.
reinforce_model.eval()
prompt = "Q: Explain profit. A:"
prompt_ids = torch.tensor(
    [tokeniser.encode(prompt, add_bos=True, add_eos=False)], device=DEVICE
)
_prompt_len = prompt_ids.shape[1]
print("Sample generations from REINFORCE model:\n")
for i in range(5):
    gen = reinforce_model.generate(prompt_ids, max_new_tokens=40, temperature=0.8)
    text = tokeniser.decode(gen[0].tolist())
    # Score only the sampled completion: the prompt already contains the
    # target word, so scoring prompt + completion would grant the +2.0
    # bonus to every sample and count the prompt in the length term.
    _completion = tokeniser.decode(gen[0, _prompt_len:].tolist())
    r = compute_reward(_completion)
    print(f"  [{i+1}] (R={r:.2f}) {text}")

---
## 3. PPO from Scratch

Proximal Policy Optimisation (PPO) improves on vanilla REINFORCE with:

1. **Clipped surrogate objective** -- prevents the policy from changing too much:
   $$L^{\text{CLIP}} = \mathbb{E}_t\left[\min\left(r_t(\theta) \hat{A}_t,\; \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon)\hat{A}_t\right)\right]$$
   where $r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{\text{old}}}(a_t|s_t)}$

2. **Generalised Advantage Estimation (GAE)** -- a biased but lower-variance advantage estimator:
   $$\hat{A}_t^{\text{GAE}} = \sum_{l=0}^{T-t-1} (\gamma\lambda)^l \delta_{t+l}, \quad \delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)$$

3. **Value function** -- a critic that estimates the expected return from each state.

4. **Multiple minibatch epochs** over the same rollout data.

In [None]:
# ---------------------------------------------------------------------------
# 3a. Value Network (Critic)
# ---------------------------------------------------------------------------
class ValueHead(nn.Module):
    """Small MLP (Linear -> Tanh -> Linear) that turns each hidden state into one scalar."""

    def __init__(self, d_model: int = 128):
        super().__init__()
        layers = [nn.Linear(d_model, d_model), nn.Tanh(), nn.Linear(d_model, 1)]
        self.head = nn.Sequential(*layers)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Map hidden states of shape (B, T, d_model) to values of shape (B, T)."""
        per_position = self.head(hidden_states)  # trailing dimension has size 1
        return per_position.squeeze(-1)


class ActorCritic(nn.Module):
    """Policy and value estimator sharing one ``TinyGPT2`` backbone.

    The backbone's final hidden states feed both its own language-model head
    (the actor) and a ``ValueHead`` (the critic).
    """

    def __init__(self, vocab_size: int, d_model: int = 128, **kwargs):
        super().__init__()
        self.backbone = TinyGPT2(vocab_size, d_model=d_model, **kwargs)
        self.value_head = ValueHead(d_model)

    def forward(self, idx: torch.Tensor):
        """Return ``(logits, values)`` with shapes (B, T, V) and (B, T) for ids of shape (B, T)."""
        lm = self.backbone
        _, seq_len = idx.shape
        # Re-run the backbone layer by layer to reach the hidden states that
        # come just before its output projection.
        positions = torch.arange(0, seq_len, device=idx.device).unsqueeze(0)
        hidden = lm.drop(lm.tok_emb(idx) + lm.pos_emb(positions))
        for layer in lm.blocks:
            hidden = layer(hidden)
        hidden = lm.ln_f(hidden)  # (B, T, d_model)
        return lm.head(hidden), self.value_head(hidden)

    def get_logits(self, idx: torch.Tensor) -> torch.Tensor:
        """Return only the logits from ``forward``."""
        return self(idx)[0]


# Quick check
# Build a throwaway ActorCritic only to print its trainable-parameter count
# (backbone plus value head), then delete it.
_ac = ActorCritic(VOCAB_SIZE)
print(f"ActorCritic parameters: {count_parameters(_ac):,}")
del _ac

In [None]:
# ---------------------------------------------------------------------------
# 3b. Rollout collection
# ---------------------------------------------------------------------------
@torch.no_grad()
def collect_rollouts(
    actor_critic: ActorCritic,
    prompt_ids: torch.Tensor,
    batch_size: int,
    max_new_tokens: int,
    temperature: float = 1.0,
) -> dict:
    """Sample completions and record what a PPO update needs.

    Runs entirely under ``torch.no_grad()``, so the stored log-probs and
    values carry no gradient.

    Args:
        actor_critic: model returning ``(logits, values)`` for a batch of ids.
        prompt_ids: prompt token ids of shape (1, prompt_len), repeated for
            every sample in the batch.
        batch_size: number of sequences sampled from the prompt.
        max_new_tokens: number of tokens sampled per sequence.
        temperature: divisor applied to the logits before sampling.

    Returns:
        A dict with keys ``sequences``, ``actions``, ``log_probs``,
        ``values``, ``token_rewards``, ``total_rewards`` and ``prompt_len``
        (shapes are noted next to each entry below).
    """
    # NOTE(review): this function does not change the model's train/eval mode.
    # If the caller leaves the model in training mode, dropout is active
    # while sampling -- confirm that is intended.
    prompts = prompt_ids.expand(batch_size, -1)
    prompt_len = prompt_ids.shape[1]

    idx = prompts.clone()
    all_log_probs = []
    all_values = []
    all_actions = []

    for _ in range(max_new_tokens):
        idx_cond = idx[:, -actor_critic.backbone.max_len :]
        logits, values = actor_critic(idx_cond)
        # The last position gives the next-token distribution and the value
        # estimate for the sequence so far, before the new token is sampled.
        last_logits = logits[:, -1, :] / temperature
        last_value = values[:, -1]  # (B,)
        probs = F.softmax(last_logits, dim=-1)
        dist = torch.distributions.Categorical(probs)
        action = dist.sample()
        log_prob = dist.log_prob(action)

        all_actions.append(action)
        all_log_probs.append(log_prob)
        all_values.append(last_value)

        idx = torch.cat([idx, action.unsqueeze(1)], dim=1)

    # One scalar reward per sequence, scored on the sampled completion only.
    # Scoring prompt + completion would let the prompt text itself satisfy
    # the target-word check and would count the prompt in the length term.
    rewards_list = [
        compute_reward(tokeniser.decode(row[prompt_len:].tolist())) for row in idx
    ]
    total_rewards = torch.tensor(rewards_list, device=idx.device)  # (B,)

    # Token-level reward: zero at every generated position except the last,
    # which receives the whole sequence reward.
    token_rewards = torch.zeros(batch_size, max_new_tokens, device=idx.device)
    token_rewards[:, -1] = total_rewards

    return {
        "sequences": idx,  # (B, prompt_len + max_new_tokens)
        "actions": torch.stack(all_actions, dim=1),  # (B, max_new_tokens)
        "log_probs": torch.stack(all_log_probs, dim=1),  # (B, max_new_tokens)
        "values": torch.stack(all_values, dim=1),  # (B, max_new_tokens)
        "token_rewards": token_rewards,  # (B, max_new_tokens)
        "total_rewards": total_rewards,  # (B,)
        "prompt_len": prompt_len,
    }


print("Rollout collection function defined.")

In [None]:
# ---------------------------------------------------------------------------
# 3c. GAE Computation
# ---------------------------------------------------------------------------
def compute_gae(
    token_rewards: torch.Tensor,
    values: torch.Tensor,
    gamma: float = 1.0,
    lam: float = 0.95,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute GAE advantages and the matching value targets.

    Walks the time axis backwards using
    ``delta_t = r_t + gamma * V_{t+1} - V_t`` and
    ``A_t = delta_t + gamma * lam * A_{t+1}``, with the value after the
    final step taken to be zero.

    Args:
        token_rewards: (B, T) reward received at each step.
        values: (B, T) critic estimate at each step.
        gamma: discount factor.
        lam: GAE lambda.

    Returns:
        ``(advantages, returns)``, both (B, T), where
        ``returns = advantages + values``.
    """
    batch, horizon = token_rewards.shape
    device = token_rewards.device
    advantages = torch.zeros_like(token_rewards)
    running_advantage = torch.zeros(batch, device=device)
    value_after = torch.zeros(batch, device=device)  # nothing follows the last step

    for step in range(horizon - 1, -1, -1):
        td_error = token_rewards[:, step] + gamma * value_after - values[:, step]
        running_advantage = td_error + gamma * lam * running_advantage
        advantages[:, step] = running_advantage
        # This step's value is the "next" value for the step before it.
        value_after = values[:, step]

    return advantages, advantages + values


# Quick test
# One three-step episode whose only reward arrives at the final step.
_r = torch.tensor([[0.0, 0.0, 1.0]])
_v = torch.tensor([[0.1, 0.2, 0.3]])
_adv, _ret = compute_gae(_r, _v, gamma=0.99, lam=0.95)
print(f"Test rewards: {_r}")
print(f"Test values:  {_v}")
print(f"GAE advantages: {_adv}")
print(f"Returns:        {_ret}")

In [None]:
# ---------------------------------------------------------------------------
# 3d. PPO Clipped Objective
# ---------------------------------------------------------------------------
def ppo_loss(
    new_log_probs: torch.Tensor,
    old_log_probs: torch.Tensor,
    advantages: torch.Tensor,
    clip_eps: float = 0.2,
) -> torch.Tensor:
    """Return the negated PPO clipped surrogate, averaged over all elements.

    With ``ratio = exp(new_log_probs - old_log_probs)`` the per-element
    objective is ``min(ratio * A, clip(ratio, 1 - eps, 1 + eps) * A)``.
    The sign is flipped so that minimising the result maximises the objective.
    """
    ratio = (new_log_probs - old_log_probs).exp()
    unclipped = ratio * advantages
    clipped = ratio.clamp(1.0 - clip_eps, 1.0 + clip_eps) * advantages
    # The element-wise minimum keeps the more pessimistic of the two estimates.
    return torch.min(unclipped, clipped).mean().neg()


def value_loss(predicted_values: torch.Tensor, returns: torch.Tensor) -> torch.Tensor:
    """Return the mean squared error between the critic's predictions and the return targets."""
    return F.mse_loss(input=predicted_values, target=returns, reduction="mean")


print("PPO loss functions defined.")

In [None]:
# ---------------------------------------------------------------------------
# 3e. PPO update step with minibatch epochs
# ---------------------------------------------------------------------------
def ppo_update(
    actor_critic: ActorCritic,
    optimiser: torch.optim.Optimizer,
    rollout: dict,
    n_epochs: int = 4,
    minibatch_size: int = 8,
    clip_eps: float = 0.2,
    vf_coef: float = 0.5,
    entropy_coef: float = 0.01,
    gamma: float = 1.0,
    lam: float = 0.95,
) -> dict:
    """Run several epochs of shuffled-minibatch PPO updates on one rollout batch.

    Advantages and returns are computed once from the rollout with GAE, the
    advantages are standardised over the whole batch, and then every epoch
    visits the rollouts in a fresh random order.

    Args:
        actor_critic: Module called as ``actor_critic(sequences)`` and expected
            to return ``(logits, values)``.
        optimiser: Optimiser stepped once per minibatch.
        rollout: Dict with keys ``"sequences"``, ``"actions"``,
            ``"log_probs"``, ``"values"``, ``"token_rewards"`` and
            ``"prompt_len"``.
            NOTE(review): the stored tensors are used as constants here;
            presumably they were collected without gradient tracking --
            confirm in ``collect_rollouts``.
        n_epochs: Number of passes over the rollout batch.
        minibatch_size: Rollouts per gradient step (the last minibatch of an
            epoch may be smaller).
        clip_eps: Clipping range handed to ``ppo_loss``.
        vf_coef: Weight of the value loss in the combined objective.
        entropy_coef: Weight of the entropy bonus (subtracted from the loss).
        gamma: Discount factor passed to ``compute_gae``.
        lam: GAE lambda passed to ``compute_gae``.

    Returns:
        ``{"policy_loss": ..., "value_loss": ...}``, each averaged over all
        gradient steps taken (0.0 when no step was taken).
    """
    seqs = rollout["sequences"]
    taken = rollout["actions"]
    behaviour_lp = rollout["log_probs"]
    rollout_values = rollout["values"]
    step_rewards = rollout["token_rewards"]
    n_prompt = rollout["prompt_len"]

    # Advantage / return targets stay fixed for every epoch below.
    adv, ret = compute_gae(step_rewards, rollout_values, gamma=gamma, lam=lam)
    adv = (adv - adv.mean()) / (adv.std() + 1e-8)

    n_rollouts = seqs.shape[0]
    n_gen = taken.shape[1]
    # The output at position p predicts the token at p + 1, so the generated
    # tokens (which start at index n_prompt) are scored by the outputs at
    # positions n_prompt - 1 ... n_prompt + n_gen - 2.
    gen_span = slice(n_prompt - 1, n_prompt - 1 + n_gen)

    policy_loss_sum = 0.0
    value_loss_sum = 0.0
    steps_taken = 0

    for _ in range(n_epochs):
        order = torch.randperm(n_rollouts)
        for lo in range(0, n_rollouts, minibatch_size):
            rows = order[lo : lo + minibatch_size]

            # Full-sequence forward pass, then keep only the generated span.
            logits, values = actor_critic(seqs[rows])
            logits = logits[:, gen_span, :]
            values = values[:, gen_span]

            # Log-probability of each sampled action under the current policy.
            logp = F.log_softmax(logits, dim=-1)
            new_lp = torch.gather(logp, 2, taken[rows].unsqueeze(-1)).squeeze(-1)

            # Mean per-token entropy of the current policy.
            entropy = -((F.softmax(logits, dim=-1) * logp).sum(-1).mean())

            pi_term = ppo_loss(new_lp, behaviour_lp[rows], adv[rows], clip_eps=clip_eps)
            vf_term = value_loss(values, ret[rows])
            objective = pi_term + vf_coef * vf_term - entropy_coef * entropy

            optimiser.zero_grad()
            objective.backward()
            torch.nn.utils.clip_grad_norm_(actor_critic.parameters(), 1.0)
            optimiser.step()

            policy_loss_sum += pi_term.item()
            value_loss_sum += vf_term.item()
            steps_taken += 1

    denominator = max(steps_taken, 1)
    return {
        "policy_loss": policy_loss_sum / denominator,
        "value_loss": value_loss_sum / denominator,
    }


print("PPO update function defined.")

In [None]:
# ---------------------------------------------------------------------------
# 3f. PPO Training Loop
# ---------------------------------------------------------------------------
# Build the actor-critic and warm-start its language-model backbone from the
# SFT model, so only the backbone's weights are copied here.
ppo_ac = ActorCritic(VOCAB_SIZE).to(DEVICE)
# Initialise from SFT weights for the backbone
ppo_ac.backbone.load_state_dict(sft_model.state_dict())

# One optimiser over every ActorCritic parameter.
ppo_optimiser = torch.optim.Adam(ppo_ac.parameters(), lr=5e-5)

# Single fixed prompt; the outer list gives prompt_ids a leading batch
# dimension of 1. No <eos> is appended so the model continues the text.
prompt_text = "Q: Explain profit. A:"
prompt_ids = torch.tensor(
    [tokeniser.encode(prompt_text, add_bos=True, add_eos=False)], device=DEVICE
)

PPO_STEPS = 80  # number of collect-then-update iterations
ROLLOUT_BATCH = 32  # rollouts sampled per iteration
GEN_LEN = 30  # tokens generated per rollout

# Per-iteration metrics, consumed by the plotting cells further down.
ppo_reward_history = []
ppo_policy_loss_history = []
ppo_value_loss_history = []

print("Starting PPO training...\n")
for step in range(PPO_STEPS):
    # 1. Collect rollouts
    rollout = collect_rollouts(
        ppo_ac, prompt_ids, batch_size=ROLLOUT_BATCH,
        max_new_tokens=GEN_LEN, temperature=1.0,
    )
    # The reward logged for this step is measured before the update below.
    mean_reward = rollout["total_rewards"].mean().item()
    ppo_reward_history.append(mean_reward)

    # 2. PPO update
    losses = ppo_update(
        ppo_ac, ppo_optimiser, rollout,
        n_epochs=4, minibatch_size=8, clip_eps=0.2,
        vf_coef=0.5, entropy_coef=0.01,
    )
    ppo_policy_loss_history.append(losses["policy_loss"])
    ppo_value_loss_history.append(losses["value_loss"])

    # Progress line every 10 iterations.
    if (step + 1) % 10 == 0:
        print(
            f"Step {step+1:3d}/{PPO_STEPS}  "
            f"reward={mean_reward:.3f}  "
            f"pi_loss={losses['policy_loss']:.4f}  "
            f"v_loss={losses['value_loss']:.4f}"
        )

print("\nPPO training complete.")

In [None]:
# ---------------------------------------------------------------------------
# 3g. PPO Training Plots
# ---------------------------------------------------------------------------
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# One row per panel: (series, y-axis label, title, extra line styling).
# The first panel passes no colour so it keeps matplotlib's default.
ppo_panels = [
    (ppo_reward_history, "Mean Reward", "PPO: Reward over Training", {}),
    (ppo_policy_loss_history, "Policy Loss", "PPO: Clipped Policy Loss",
     {"color": "tab:orange"}),
    (ppo_value_loss_history, "Value Loss", "PPO: Value Function Loss",
     {"color": "tab:green"}),
]

for panel_ax, (series, y_label, panel_title, line_style) in zip(axes, ppo_panels):
    panel_ax.plot(series, linewidth=1.5, **line_style)
    panel_ax.set_xlabel("Step")
    panel_ax.set_ylabel(y_label)
    panel_ax.set_title(panel_title)
    panel_ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# ---------------------------------------------------------------------------
# 3h. PPO Sample Generations
# ---------------------------------------------------------------------------
# Switch to inference mode and print five sampled continuations of the
# training prompt, each prefixed with its heuristic reward.
ppo_ac.eval()
print("Sample generations from PPO-trained model:\n")
with torch.no_grad():
    for i in range(5):
        gen = ppo_ac.backbone.generate(
            prompt_ids,
            max_new_tokens=40,
            temperature=0.8,
        )
        text = tokeniser.decode(gen[0].tolist())
        r = compute_reward(text)
        print(f"  [{i+1}] (R={r:.2f}) {text}")

---
## 4. Group Relative Policy Optimisation (GRPO)

GRPO (introduced in DeepSeekMath and later used to train DeepSeek-R1) is a **critic-free** alternative to PPO. For each prompt:

1. Generate a **group** of $G$ responses from the current policy.
2. Score each response with a reward function.
3. Compute **group-normalised advantages**: $A_i = \frac{R_i - \mu_G}{\sigma_G}$
4. Update the policy with a clipped objective (like PPO) but using these advantages.

Key benefit: no value network needed, reducing complexity and training instability.

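Additionally, GRPO includes a KL penalty to keep the policy $\pi_\theta$ close to a frozen reference policy $\pi_{\text{ref}}$:
$$\mathcal{L}_{\text{GRPO}} = -\frac{1}{G}\sum_{i=1}^{G}\frac{1}{|o_i|}\sum_{t=1}^{|o_i|}\left[\min(r_t A_i,\; \text{clip}(r_t, 1-\epsilon, 1+\epsilon) A_i) - \beta\, D_{KL}\!\left[\pi_\theta \,\|\, \pi_{\text{ref}}\right]\right]$$
where $r_t = \pi_\theta(o_{i,t} \mid q, o_{i,<t}) \,/\, \pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,<t})$ is the per-token importance ratio, $o_i$ is the $i$-th sampled response to prompt $q$, and $\beta$ sets the strength of the KL penalty.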

In [None]:
# ---------------------------------------------------------------------------
# 4a. GRPO: Generate group and compute group-normalised advantages
# ---------------------------------------------------------------------------
def grpo_generate_group(
    model: TinyGPT2,
    prompt_ids: torch.Tensor,
    group_size: int,
    max_new_tokens: int,
    temperature: float = 1.0,
) -> dict:
    """Sample a group of responses and compute group-normalised advantages.

    Args:
        model: Policy to sample from. It is switched to eval mode and is NOT
            switched back afterwards.
        prompt_ids: Token ids of the prompt; expanded along dim 0 to
            ``group_size`` rows (presumably shape ``(1, prompt_len)`` --
            confirm against the caller).
        group_size: Number of responses G sampled for the prompt.
        max_new_tokens: Number of tokens sampled per response.
        temperature: Logits are divided by this before sampling. The stored
            log-probs are those of the temperature-scaled distribution.

    Returns:
        Dict with
        ``"sequences"``      prompt + generated tokens, one row per response,
        ``"actions"``        sampled tokens, ``(G, max_new_tokens)``,
        ``"old_log_probs"``  log-prob of each sampled token at sampling time,
        ``"advantages"``     ``(R_i - mean) / (std + 1e-8)``, shape ``(G,)``,
        ``"rewards"``        raw reward per response, shape ``(G,)``,
        ``"prompt_len"``     number of prompt tokens.
    """
    prompts = prompt_ids.expand(group_size, -1)
    prompt_len = prompt_ids.shape[1]

    # Sampling is pure data collection: the forward passes run under no_grad
    # and the stored log-probs act as constants ("old" policy) in the update.
    model.eval()
    idx = prompts.clone()
    all_log_probs = []
    all_actions = []

    for _ in range(max_new_tokens):
        # Keep only the most recent tokens that fit in the context window.
        idx_cond = idx[:, -model.max_len :]
        with torch.no_grad():
            logits = model(idx_cond)[:, -1, :] / temperature
        probs = F.softmax(logits, dim=-1)
        dist = torch.distributions.Categorical(probs)
        action = dist.sample()
        log_prob = dist.log_prob(action)
        all_actions.append(action)
        all_log_probs.append(log_prob)
        idx = torch.cat([idx, action.unsqueeze(1)], dim=1)

    # Score each full sequence (prompt + response) with the reward heuristic.
    rewards = [compute_reward(tokeniser.decode(row.tolist())) for row in idx]
    rewards_t = torch.tensor(rewards, dtype=torch.float32, device=DEVICE)

    # Group-normalised advantages: A_i = (R_i - mean) / std.
    # The sample standard deviation of a single element is NaN, which would
    # turn every advantage (and then every gradient) into NaN. A group of one
    # carries no relative signal, so its advantage is left at exactly zero.
    centred = rewards_t - rewards_t.mean()
    if rewards_t.numel() > 1:
        advantages = centred / (rewards_t.std() + 1e-8)  # (G,)
    else:
        advantages = centred

    return {
        "sequences": idx,
        "actions": torch.stack(all_actions, dim=1),
        "old_log_probs": torch.stack(all_log_probs, dim=1),
        "advantages": advantages,
        "rewards": rewards_t,
        "prompt_len": prompt_len,
    }


print("GRPO group generation function defined.")

In [None]:
# ---------------------------------------------------------------------------
# 4b. GRPO Update Step
# ---------------------------------------------------------------------------
def grpo_update(
    model: "TinyGPT2",
    ref_model: "TinyGPT2",
    optimiser: torch.optim.Optimizer,
    group_data: dict,
    clip_eps: float = 0.2,
    beta_kl: float = 0.04,
    n_epochs: int = 2,
) -> float:
    """GRPO policy update with clipped objective and KL penalty.

    Args:
        model: Policy being optimised; called as ``model(sequences)`` and
            expected to return logits of shape ``(G, total_len, V)``.
        ref_model: Reference policy for the KL penalty. It is evaluated once,
            without gradients, so it is assumed to be frozen and
            deterministic (eval mode).
        optimiser: Optimiser over ``model``'s parameters; stepped once per
            epoch on the whole group.
        group_data: Dict with ``"sequences"``, ``"actions"`` ``(G, T_gen)``,
            ``"old_log_probs"`` ``(G, T_gen)``, ``"advantages"`` ``(G,)`` and
            ``"prompt_len"``.
        clip_eps: Clipping range for the importance ratio.
        beta_kl: Weight of the KL penalty.
        n_epochs: Number of gradient steps taken on this group.

    Returns:
        The loss averaged over the epochs, or 0.0 when ``n_epochs <= 0`` (in
        which case nothing is updated).
    """
    if n_epochs <= 0:
        # Nothing to do; also avoids dividing by zero in the average below.
        return 0.0

    sequences = group_data["sequences"]
    actions = group_data["actions"]
    old_log_probs = group_data["old_log_probs"]
    advantages = group_data["advantages"]  # (G,)
    prompt_len = group_data["prompt_len"]
    G, T_gen = actions.shape

    # Logits at position p predict the token at p + 1, so the generated
    # tokens are scored by positions prompt_len - 1 ... prompt_len + T_gen - 2.
    gen_slice = slice(prompt_len - 1, prompt_len - 1 + T_gen)
    action_index = actions.unsqueeze(-1)

    # Reference log-probs depend only on the frozen reference model and the
    # fixed sequences, so they are computed once rather than once per epoch.
    with torch.no_grad():
        ref_gen_logits = ref_model(sequences)[:, gen_slice, :]
        ref_lp = F.log_softmax(ref_gen_logits, dim=-1).gather(2, action_index).squeeze(-1)

    # One advantage per response, shared by all of its tokens: (G,) -> (G, T_gen)
    adv_expanded = advantages.unsqueeze(1).expand(G, T_gen)

    total_loss_val = 0.0
    for _ in range(n_epochs):
        logits = model(sequences)  # (G, total_len, V)
        gen_log_probs = F.log_softmax(logits[:, gen_slice, :], dim=-1)
        new_log_probs = gen_log_probs.gather(2, action_index).squeeze(-1)

        # Clipped surrogate, as in PPO.
        ratio = torch.exp(new_log_probs - old_log_probs)  # (G, T_gen)
        clipped_ratio = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps)
        surr1 = ratio * adv_expanded
        surr2 = clipped_ratio * adv_expanded
        policy_loss = -torch.min(surr1, surr2).mean()

        # Per-token KL estimate used by GRPO:
        #   exp(ref_lp - new_lp) - (ref_lp - new_lp) - 1
        # It is non-negative and equals zero when both policies agree.
        log_ratio_kl = ref_lp - new_log_probs
        kl = (torch.exp(log_ratio_kl) - log_ratio_kl - 1.0).mean()

        loss = policy_loss + beta_kl * kl

        optimiser.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimiser.step()
        total_loss_val += loss.item()

    return total_loss_val / n_epochs


print("GRPO update function defined.")

In [None]:
# ---------------------------------------------------------------------------
# 4c. GRPO Training Loop
# ---------------------------------------------------------------------------
# Start the GRPO policy from the SFT weights and keep an untouched copy of
# that starting point as the reference policy for the KL penalty.
grpo_model = TinyGPT2(VOCAB_SIZE).to(DEVICE)
grpo_model.load_state_dict(sft_model.state_dict())
grpo_ref = copy.deepcopy(grpo_model)  # frozen reference
grpo_ref.eval()
for p in grpo_ref.parameters():
    p.requires_grad = False

# Only the policy's parameters are optimised.
grpo_optimiser = torch.optim.Adam(grpo_model.parameters(), lr=5e-5)

GRPO_STEPS = 80  # number of generate-then-update iterations
GROUP_SIZE = 32  # responses sampled per iteration (the group G)

# Mean group reward per iteration, consumed by the comparison cell below.
grpo_reward_history = []

print("Starting GRPO training...\n")
for step in range(GRPO_STEPS):
    # prompt_ids and GEN_LEN are the ones defined in the PPO training cell.
    group_data = grpo_generate_group(
        grpo_model, prompt_ids, group_size=GROUP_SIZE,
        max_new_tokens=GEN_LEN, temperature=1.0,
    )
    # The reward logged for this step is measured before the update below.
    mean_r = group_data["rewards"].mean().item()
    grpo_reward_history.append(mean_r)

    loss = grpo_update(
        grpo_model, grpo_ref, grpo_optimiser, group_data,
        clip_eps=0.2, beta_kl=0.04, n_epochs=2,
    )

    # Progress line every 10 iterations.
    if (step + 1) % 10 == 0:
        print(f"Step {step+1:3d}/{GRPO_STEPS}  reward={mean_r:.3f}  loss={loss:.4f}")

print("\nGRPO training complete.")

In [None]:
# ---------------------------------------------------------------------------
# 4d. Compare PPO vs GRPO
# ---------------------------------------------------------------------------
# Overlay both reward curves on one axis, then report the mean reward of the
# last ten training steps of each method.
fig, ax = plt.subplots(1, 1, figsize=(8, 4))
for curve_label, curve in (("PPO", ppo_reward_history), ("GRPO", grpo_reward_history)):
    ax.plot(curve, label=curve_label, linewidth=1.5)
ax.set_xlabel("Step")
ax.set_ylabel("Mean Reward")
ax.set_title("PPO vs GRPO: Reward over Training")
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

ppo_tail_mean = np.mean(ppo_reward_history[-10:])
grpo_tail_mean = np.mean(grpo_reward_history[-10:])
print(f"PPO final reward (last 10 avg):  {ppo_tail_mean:.3f}")
print(f"GRPO final reward (last 10 avg): {grpo_tail_mean:.3f}")
print("\nGRPO achieves comparable rewards without needing a critic network.")

---
## 5. Bradley-Terry Reward Model

In Sections 2--4 the reward was a simple heuristic. In practice, the reward signal comes
from a **learned reward model** trained on human preference data.

The **Bradley-Terry** model (Bradley & Terry, 1952) defines the probability that response $y_w$
is preferred over $y_l$ as:

$$P(y_w \succ y_l) = \sigma\!\big(r_\phi(x, y_w) - r_\phi(x, y_l)\big)$$

The loss is the negative log-likelihood of observed preferences:
$$\mathcal{L}_{\text{BT}} = -\mathbb{E}\!\left[\log\sigma\!\big(r_\phi(x, y_w) - r_\phi(x, y_l)\big)\right]$$

Below we train a small reward model on synthetic finance QA preference pairs.

In [None]:
# ---------------------------------------------------------------------------
# 5a. Preference Data and Reward Model Architecture
# ---------------------------------------------------------------------------
# Synthetic finance-QA preference data for the Bradley-Terry reward model.
# Each tuple: (prompt, preferred_response, dispreferred_response)
# In every pair the preferred response is a precise definition and the
# dispreferred one is a vague answer. NOTE: the preferred response is also
# the longer of the two in every pair, so length alone separates this data.
BT_PREFS = [
    ("What is a bond?",
     "A bond is a fixed-income instrument representing a loan from investor to borrower.",
     "A bond is some kind of financial thing."),
    ("Define stock.",
     "A stock represents ownership of a fraction of a corporation.",
     "Stock is stuff you buy."),
    ("What is volatility?",
     "Volatility measures the degree of variation in trading price over time.",
     "Volatility is when prices go up and down."),
    ("Explain diversification.",
     "Diversification is a risk management strategy mixing various investments.",
     "Diversification means buying different things."),
    ("What is a derivative?",
     "A derivative is a financial contract whose value depends on an underlying asset.",
     "A derivative is complicated finance stuff."),
    ("Define yield.",
     "Yield is the income return on an investment expressed as a percentage.",
     "Yield is the money you get."),
    ("What is liquidity?",
     "Liquidity refers to how quickly an asset converts to cash without loss.",
     "Liquidity is about cash."),
    ("Explain hedging.",
     "Hedging is an investment to reduce the risk of adverse price movements.",
     "Hedging is like insurance or something."),
    ("What is leverage?",
     "Leverage is using borrowed capital to increase potential return on investment.",
     "Leverage means borrowing money."),
    ("Define alpha.",
     "Alpha is the excess return of an investment relative to its benchmark index.",
     "Alpha is a greek letter used in finance."),
    ("What is arbitrage?",
     "Arbitrage is the simultaneous purchase and sale of an asset to profit from price differences.",
     "Arbitrage is free money in the market."),
    ("Explain short selling.",
     "Short selling involves borrowing shares and selling them, hoping to buy back at a lower price.",
     "Short selling is betting stocks go down."),
]


class RewardModel(nn.Module):
    """Small reward model: transformer backbone + scalar head.

    The sequence is embedded, passed through a stack of ``TransformerBlock``
    layers and a final LayerNorm, and the hidden state at the last *non-pad*
    token of each row is projected to a single scalar reward.

    Pooling at the last real token (instead of the last position) matters for
    right-padded batches: otherwise the reward of a padded sequence would be
    read from a ``<pad>`` position and would change with the amount of
    padding. NOTE(review): full invariance to trailing padding additionally
    requires the blocks to use causal attention -- confirm in
    ``TransformerBlock``.

    Args:
        vocab_size: Size of the token vocabulary.
        d_model: Embedding / hidden width.
        n_heads: Attention heads per block.
        n_layers: Number of transformer blocks.
        max_len: Maximum sequence length (size of the position table).
        pad_id: Token id treated as padding when locating the last real
            token. Defaults to 0, the tokeniser's ``<pad>`` id.
    """

    def __init__(self, vocab_size: int, d_model: int = 96, n_heads: int = 4,
                 n_layers: int = 3, max_len: int = 128, pad_id: int = 0):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_len, d_model)
        self.blocks = nn.ModuleList(
            [TransformerBlock(d_model, n_heads, max_len, dropout=0.1)
             for _ in range(n_layers)]
        )
        self.ln_f = nn.LayerNorm(d_model)
        self.reward_head = nn.Linear(d_model, 1)
        self.max_len = max_len
        self.pad_id = pad_id
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02); zero the biases."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        """Returns scalar reward for each sequence in the batch. Shape: (B, 1).

        ``idx`` holds token ids with shape (B, T). For a row without padding
        the reward is read from the final position; for a padded row it is
        read from the last position whose token differs from ``pad_id``.
        """
        B, T = idx.shape
        pos = torch.arange(0, T, device=idx.device)
        x = self.tok_emb(idx) + self.pos_emb(pos.unsqueeze(0))
        for block in self.blocks:
            x = block(x)
        x = self.ln_f(x)

        # Position of the last non-pad token in each row. Pad positions are
        # zeroed out before taking the max, so a row consisting only of
        # padding falls back to position 0.
        is_token = (idx != self.pad_id).long()
        last_idx = (is_token * pos).max(dim=1).values  # (B,)
        rows = torch.arange(B, device=idx.device)
        last_hidden = x[rows, last_idx]  # (B, d_model)

        reward = self.reward_head(last_hidden)
        return reward


# Instantiate the reward model with its default architecture and report its
# size alongside the number of available preference pairs.
rm = RewardModel(VOCAB_SIZE).to(DEVICE)
print(f"Reward Model parameters: {count_parameters(rm):,}")
print(f"Preference pairs: {len(BT_PREFS)}")

In [None]:
# ---------------------------------------------------------------------------
# 5b. Bradley-Terry Training and Evaluation
# ---------------------------------------------------------------------------
# Tokenise preference pairs
# Encode every preference triple as a (preferred_ids, dispreferred_ids) pair.
# Both sides share the same "Q: ... A: ..." template and differ only in the
# answer text; each is wrapped in <bos> ... <eos>.
all_bt_data = [
    (
        tokeniser.encode(f"Q: {prompt} A: {y_w}", add_bos=True, add_eos=True),
        tokeniser.encode(f"Q: {prompt} A: {y_l}", add_bos=True, add_eos=True),
    )
    for prompt, y_w, y_l in BT_PREFS
]

# First nine pairs are used for training, the remainder is held out.
train_bt, test_bt = all_bt_data[:9], all_bt_data[9:]


def pad_pair(ids_w, ids_l, pad_id=0):
    """Right-pad two token-id lists to a common length and tensorise them.

    Returns a tuple ``(w, l)`` of long tensors on ``DEVICE``, each with shape
    ``(1, max(len(ids_w), len(ids_l)))``. The input lists are not modified.
    """
    target_len = max(len(ids_w), len(ids_l))

    def _as_row(ids):
        # Append pad_id until the list reaches target_len, then add a batch dim.
        filled = ids + [pad_id] * (target_len - len(ids))
        return torch.tensor([filled], dtype=torch.long, device=DEVICE)

    return _as_row(ids_w), _as_row(ids_l)


rm_optimiser = torch.optim.Adam(rm.parameters(), lr=1e-3)
BT_EPOCHS = 60
bt_losses = []  # mean training loss per epoch

print("Training Bradley-Terry Reward Model...\n")
# The evaluation section below leaves the model in eval mode. Switching back
# explicitly keeps dropout active if this cell is executed more than once.
rm.train()
for epoch in range(BT_EPOCHS):
    epoch_loss = 0.0
    # Visit the training pairs in a fresh random order every epoch.
    perm = np.random.permutation(len(train_bt))
    for i in perm:
        ids_w, ids_l = train_bt[i]
        tw, tl = pad_pair(ids_w, ids_l)
        r_w = rm(tw)
        r_l = rm(tl)
        # Bradley-Terry negative log-likelihood: -log sigmoid(r_w - r_l).
        loss = -F.logsigmoid(r_w - r_l).mean()
        rm_optimiser.zero_grad()
        loss.backward()
        rm_optimiser.step()
        epoch_loss += loss.item()
    avg_loss = epoch_loss / len(train_bt)
    bt_losses.append(avg_loss)
    if (epoch + 1) % 20 == 0:
        print(f"Epoch {epoch+1:3d}/{BT_EPOCHS}  loss={avg_loss:.4f}")

# --- Evaluation ---
rm.eval()

fig, ax = plt.subplots(1, 1, figsize=(7, 3.5))
ax.plot(bt_losses, linewidth=1.5)
ax.set_xlabel("Epoch")
ax.set_ylabel("Bradley-Terry Loss")
ax.set_title("Reward Model: Training Loss")
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

# Accuracy on train and test: a pair counts as correct when the preferred
# response receives the strictly higher reward. Scoring needs no gradients,
# so it runs under no_grad to avoid building autograd graphs.
with torch.no_grad():
    for label, data in [("Train", train_bt), ("Test", test_bt)]:
        correct = 0
        for ids_w, ids_l in data:
            tw, tl = pad_pair(ids_w, ids_l)
            if rm(tw).item() > rm(tl).item():
                correct += 1
        print(f"{label} accuracy: {correct}/{len(data)} = {correct/len(data):.0%}")

---
## Summary

| Method | Requires Reward Model? | Requires Critic? | Key Mechanism |
|--------|----------------------|------------------|---------------|
| **SFT** | No | No | Teacher-forced cross-entropy on demonstrations |
| **REINFORCE** | Yes (or heuristic) | No | $\nabla J = \mathbb{E}[\nabla\log\pi \cdot R]$, high variance |
| **PPO** | Yes (or heuristic) | Yes (value fn) | Clipped ratio + GAE advantages |
| **GRPO** | Yes (or heuristic) | **No** | Group-normalised advantages, no critic |
| **Bradley-Terry RM** | N/A (this *is* the RM) | No | Pairwise preference loss |

- SFT provides the foundation; RL methods refine behaviour towards a reward signal.
- REINFORCE is simple but has high variance; baselines help.
- PPO stabilises training via clipping and GAE, but requires a critic.
- GRPO achieves similar results without a critic by using group-relative advantages.
- The Bradley-Terry model provides a principled way to learn rewards from human preferences.