# Week 2: RL Training Methods for Language Models

**MSc Course -- Generative Models in Finance**

This notebook walks through the key training paradigms used to align and improve language models after pretraining:

| Section | Method | Key Idea |
|---------|--------|----------|
| 1 | **SFT** | Supervised fine-tuning on instruction-response pairs |
| 2 | **REINFORCE** | Classic policy-gradient with a simple reward heuristic |
| 3 | **PPO** | Proximal Policy Optimisation with GAE and clipped objective |
| 4 | **GRPO** | Group Relative Policy Optimisation (critic-free) |
| 5 | **DPO** | Direct Preference Optimisation from pairwise preferences |
| 6 | **Bradley-Terry Reward Model** | Learn a reward model from preference data |
| 7 | **GRPO + Lean Verification** | RL for math with a formal theorem prover as reward |

All models are tiny (~1M parameters) and train on CPU in under 2 minutes.

**Dependencies:** `torch`, `numpy`, `matplotlib`, `requests`

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import copy
import math
from torch.utils.data import DataLoader, Dataset

torch.manual_seed(42)
np.random.seed(42)

DEVICE = torch.device("cpu")
print(f"Using device: {DEVICE}")
print(f"PyTorch version: {torch.__version__}")

## Shared Infrastructure: Tiny GPT-2 Model and Tokeniser

We build a minimal GPT-2-style transformer from scratch. It has:
- A small vocabulary (character-level + special tokens)
- 4 transformer layers, 4 attention heads, embedding dim 128
- ~0.8M parameters with the default sizes (just under 1M)

This is shared across all experiments.
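
As a quick sanity check on the size claim, the parameter count can be worked out by hand. The sketch below is illustrative only: the helper name `tiny_gpt2_param_count` is made up for this example, and it assumes a 99-token vocabulary (95 printable ASCII characters plus 4 special tokens), a context length of 128, biased linear layers, a 4x MLP expansion, and an output head tied to the token embedding.

```python
def tiny_gpt2_param_count(vocab_size=99, d_model=128, n_layers=4, max_len=128):
    """Hand count of trainable parameters for a GPT-2-style model.

    Assumptions (not read from any model object): biased linear layers,
    a 4x MLP expansion, two LayerNorms per block plus a final one, and an
    output head that reuses the token-embedding matrix (so it adds nothing).
    """
    # Token embedding + learned positional embedding
    embeddings = vocab_size * d_model + max_len * d_model
    # Fused QKV projection and output projection, each with a bias
    attention = (d_model * 3 * d_model + 3 * d_model) + (d_model * d_model + d_model)
    # Two MLP layers: d_model -> 4*d_model -> d_model, each with a bias
    mlp = (d_model * 4 * d_model + 4 * d_model) + (4 * d_model * d_model + d_model)
    # Two LayerNorms per block, each with a scale and a shift vector
    layer_norms = 2 * (2 * d_model)
    per_block = attention + mlp + layer_norms
    final_ln = 2 * d_model
    return embeddings + n_layers * per_block + final_ln


print(f"Estimated parameters: {tiny_gpt2_param_count():,}")
```

Under these assumptions the total comes to 822,400, i.e. a little under one million, with almost all of it in the four transformer blocks rather than the embeddings.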

In [None]:
# ---------------------------------------------------------------------------
# Simple character-level tokeniser
# ---------------------------------------------------------------------------
class CharTokeniser:
    """Character-level tokeniser with special tokens."""

    def __init__(self):
        # Printable ASCII 32-126 plus special tokens
        chars = [chr(i) for i in range(32, 127)]  # 95 chars
        self.special = {"<pad>": 0, "<bos>": 1, "<eos>": 2, "<unk>": 3}
        self.char2id = {**self.special}
        for i, c in enumerate(chars):
            self.char2id[c] = len(self.special) + i
        self.id2char = {v: k for k, v in self.char2id.items()}
        self.vocab_size = len(self.char2id)
        self.pad_id = self.special["<pad>"]
        self.bos_id = self.special["<bos>"]
        self.eos_id = self.special["<eos>"]

    def encode(self, text: str, add_bos=True, add_eos=True) -> list[int]:
        ids = []
        if add_bos:
            ids.append(self.bos_id)
        for c in text:
            ids.append(self.char2id.get(c, self.special["<unk>"]))
        if add_eos:
            ids.append(self.eos_id)
        return ids

    def decode(self, ids: list[int]) -> str:
        chars = []
        for i in ids:
            tok = self.id2char.get(i, "")
            if tok in ("<pad>", "<bos>", "<eos>", "<unk>"):
                continue
            chars.append(tok)
        return "".join(chars)


tokeniser = CharTokeniser()
VOCAB_SIZE = tokeniser.vocab_size
print(f"Vocab size: {VOCAB_SIZE}")
print(f"Example encode: {tokeniser.encode('hello')}")
print(f"Example decode: {tokeniser.decode(tokeniser.encode('hello'))}")

In [None]:
# ---------------------------------------------------------------------------
# Tiny GPT-2 Model
# ---------------------------------------------------------------------------
class CausalSelfAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int, max_len: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        # Causal mask
        self.register_buffer(
            "mask",
            torch.tril(torch.ones(max_len, max_len)).unsqueeze(0).unsqueeze(0),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape
        qkv = self.qkv(x)  # (B, T, 3C)
        q, k, v = qkv.chunk(3, dim=-1)
        # Reshape to (B, n_heads, T, head_dim)
        q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        # Attention scores
        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        att = self.dropout(att)
        out = att @ v  # (B, n_heads, T, head_dim)
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.proj(out)


class TransformerBlock(nn.Module):
    def __init__(self, d_model: int, n_heads: int, max_len: int, dropout: float = 0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = CausalSelfAttention(d_model, n_heads, max_len, dropout)
        self.ln2 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x


class TinyGPT2(nn.Module):
    """Minimal GPT-2 style autoregressive language model."""

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 128,
        n_heads: int = 4,
        n_layers: int = 4,
        max_len: int = 128,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.max_len = max_len
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_len, d_model)
        self.drop = nn.Dropout(dropout)
        self.blocks = nn.ModuleList(
            [TransformerBlock(d_model, n_heads, max_len, dropout) for _ in range(n_layers)]
        )
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)
        # Weight tying
        self.head.weight = self.tok_emb.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        """Returns logits of shape (B, T, vocab_size)."""
        B, T = idx.shape
        assert T <= self.max_len, f"Sequence length {T} > max_len {self.max_len}"
        pos = torch.arange(0, T, device=idx.device).unsqueeze(0)  # (1, T)
        x = self.drop(self.tok_emb(idx) + self.pos_emb(pos))
        for block in self.blocks:
            x = block(x)
        x = self.ln_f(x)
        return self.head(x)

    @torch.no_grad()
    def generate(
        self, idx: torch.Tensor, max_new_tokens: int, temperature: float = 1.0
    ) -> torch.Tensor:
        """Autoregressive generation via sampling."""
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.max_len :]
            logits = self(idx_cond)[:, -1, :] / temperature
            probs = F.softmax(logits, dim=-1)
            next_tok = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, next_tok], dim=1)
        return idx


def count_parameters(model: nn.Module) -> int:
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# Instantiate and inspect
_test_model = TinyGPT2(VOCAB_SIZE)
print(f"TinyGPT2 parameters: {count_parameters(_test_model):,}")
del _test_model

---
## 1. Supervised Fine-Tuning (SFT)

SFT is the first step after pretraining. We take a set of **(instruction, response)** pairs and
fine-tune the model with **teacher-forced cross-entropy loss**: at each position $t$ the model
predicts the next token given the ground-truth prefix.

$$\mathcal{L}_{\text{SFT}} = -\frac{1}{T}\sum_{t=1}^{T} \log p_\theta(y_t \mid y_{<t})$$

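We train on a small synthetic dataset of 16 finance-related instruction-response pairs. Each pair is
formatted as `Q: <instruction> A: <response>`, and the loss is taken over the whole string, so the
instruction tokens contribute to it as well.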
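
To make the formula concrete, here is a small worked example in plain Python. It assumes we already know the probability the model assigned to each correct next token (the numbers are invented for illustration), and the helper name `sft_loss` is hypothetical rather than part of the notebook's code.

```python
import math


def sft_loss(token_probs, is_pad=None):
    """Mean negative log-likelihood of the target tokens.

    token_probs[t] is the probability the model gave to the correct token y_t
    given the ground-truth prefix y_<t. Positions flagged in is_pad are
    excluded from the average, mirroring how padding is ignored in training.
    """
    if is_pad is None:
        is_pad = [False] * len(token_probs)
    nll = [-math.log(p) for p, pad in zip(token_probs, is_pad) if not pad]
    return sum(nll) / len(nll)


# Made-up probabilities for a 4-token target
confident = [0.9, 0.8, 0.95, 0.7]
uncertain = [0.3, 0.2, 0.25, 0.1]

print(f"confident model: {sft_loss(confident):.3f}")
print(f"uncertain model: {sft_loss(uncertain):.3f}")
# The last position is padding, so only the first two tokens count
print(f"with padding:    {sft_loss([0.9, 0.8, 0.5], is_pad=[False, False, True]):.3f}")
```

A model that puts high probability on every correct token gets a loss near zero, while a uniform guess over two tokens would give exactly $\log 2$.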

In [None]:
# ---------------------------------------------------------------------------
# 1a. SFT Dataset
# ---------------------------------------------------------------------------
SFT_DATA = [
    ("What is a bond?", "A bond is a fixed-income instrument representing a loan made by an investor to a borrower."),
    ("Define stock.", "A stock is a security that represents ownership of a fraction of a corporation."),
    ("What is volatility?", "Volatility measures the degree of variation in a trading price over time."),
    ("Explain diversification.", "Diversification is a risk management strategy mixing a variety of investments."),
    ("What is a derivative?", "A derivative is a financial contract whose value depends on an underlying asset."),
    ("Define yield.", "Yield is the income return on an investment, typically expressed as a percentage."),
    ("What is liquidity?", "Liquidity refers to how quickly an asset can be converted to cash without loss."),
    ("Explain hedging.", "Hedging is an investment to reduce the risk of adverse price movements in an asset."),
    ("What is a put option?", "A put option gives the holder the right to sell an asset at a specified price."),
    ("Define alpha.", "Alpha is the excess return of an investment relative to its benchmark index."),
    ("What is a call option?", "A call option gives the holder the right to buy an asset at a specified price."),
    ("Explain leverage.", "Leverage is the use of borrowed capital to increase the potential return on investment."),
    ("What is beta?", "Beta measures the sensitivity of an asset's returns to the overall market returns."),
    ("Define portfolio.", "A portfolio is a collection of financial investments like stocks, bonds, and cash."),
    ("What is arbitrage?", "Arbitrage is the simultaneous purchase and sale of an asset to profit from price differences."),
    ("Explain short selling.", "Short selling involves borrowing shares and selling them, hoping to buy back at a lower price."),
]


class SFTDataset(Dataset):
    def __init__(self, data, tokeniser, max_len=128):
        self.samples = []
        for instruction, response in data:
            text = f"Q: {instruction} A: {response}"
            ids = tokeniser.encode(text, add_bos=True, add_eos=True)
            ids = ids[:max_len]  # truncate
            self.samples.append(torch.tensor(ids, dtype=torch.long))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]


def collate_pad(batch, pad_id=0):
    max_len = max(len(s) for s in batch)
    padded = torch.full((len(batch), max_len), pad_id, dtype=torch.long)
    for i, s in enumerate(batch):
        padded[i, : len(s)] = s
    return padded


sft_dataset = SFTDataset(SFT_DATA, tokeniser)
sft_loader = DataLoader(sft_dataset, batch_size=8, shuffle=True, collate_fn=collate_pad)
print(f"SFT dataset size: {len(sft_dataset)} samples")
print(f"Example (decoded): {tokeniser.decode(sft_dataset[0].tolist())}")

In [None]:
# ---------------------------------------------------------------------------
# 1b. SFT Training Loop
# ---------------------------------------------------------------------------
sft_model = TinyGPT2(VOCAB_SIZE).to(DEVICE)
sft_optimiser = torch.optim.AdamW(sft_model.parameters(), lr=3e-4, weight_decay=0.01)

NUM_SFT_EPOCHS = 60
sft_losses = []

sft_model.train()
for epoch in range(NUM_SFT_EPOCHS):
    epoch_loss = 0.0
    n_batches = 0
    for batch in sft_loader:
        batch = batch.to(DEVICE)
        # Teacher forcing: input = tokens[:-1], target = tokens[1:]
        inputs = batch[:, :-1]
        targets = batch[:, 1:]
        logits = sft_model(inputs)  # (B, T-1, V)
        # Flatten for cross-entropy; ignore padding (id=0)
        loss = F.cross_entropy(
            logits.reshape(-1, VOCAB_SIZE),
            targets.reshape(-1),
            ignore_index=tokeniser.pad_id,
        )
        sft_optimiser.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(sft_model.parameters(), 1.0)
        sft_optimiser.step()
        epoch_loss += loss.item()
        n_batches += 1
    avg_loss = epoch_loss / n_batches
    sft_losses.append(avg_loss)
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1:3d}/{NUM_SFT_EPOCHS}  loss={avg_loss:.4f}")

print("\nSFT training complete.")

In [None]:
# ---------------------------------------------------------------------------
# 1c. SFT Training Loss Curve and Sample Generation
# ---------------------------------------------------------------------------
fig, ax = plt.subplots(1, 1, figsize=(7, 3.5))
ax.plot(sft_losses, linewidth=1.5)
ax.set_xlabel("Epoch")
ax.set_ylabel("Cross-Entropy Loss")
ax.set_title("SFT Training Loss")
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

# Generate a sample
sft_model.eval()
prompt = "Q: What is a bond? A:"
prompt_ids = torch.tensor([tokeniser.encode(prompt, add_bos=True, add_eos=False)], device=DEVICE)
gen_ids = sft_model.generate(prompt_ids, max_new_tokens=60, temperature=0.7)
print(f"Prompt:    {prompt}")
print(f"Generated: {tokeniser.decode(gen_ids[0].tolist())}")

---
## 2. REINFORCE (Policy Gradient)

REINFORCE treats the language model as a **policy** $\pi_\theta$ that generates token sequences
(actions). A scalar **reward** $R$ is assigned to each complete sequence.

The policy gradient estimator is:

$$\nabla_\theta J = \mathbb{E}_{\tau \sim \pi_\theta}\left[\sum_{t=1}^{T} \nabla_\theta \log \pi_\theta(a_t \mid s_t) \cdot (R(\tau) - b)\right]$$

where $b$ is an optional **baseline** (e.g. the running mean reward) to reduce variance.
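
The effect of the baseline can be checked exactly on a toy problem. The sketch below is an illustration under stated assumptions, not the training code used in this notebook: it uses a hypothetical two-action policy with $\pi_\theta(a=1) = \sigma(\theta)$, made-up rewards, and the expected reward as the baseline. Because there are only two actions, the mean and variance of the one-sample gradient estimator can be computed in closed form.

```python
import math


def grad_estimator_stats(theta, rewards, baseline):
    """Exact mean and variance of the one-sample REINFORCE estimator
    for a two-action policy with pi(a=1) = sigmoid(theta)."""
    p1 = 1.0 / (1.0 + math.exp(-theta))
    probs = [1.0 - p1, p1]
    # Score function d/dtheta log pi(a): -p1 for a=0 and (1 - p1) for a=1
    scores = [-p1, 1.0 - p1]
    # One possible value of the estimator per action
    samples = [s * (r - baseline) for s, r in zip(scores, rewards)]
    mean = sum(p * g for p, g in zip(probs, samples))
    var = sum(p * (g - mean) ** 2 for p, g in zip(probs, samples))
    return mean, var


theta = -0.5
rewards = [1.0, 2.0]  # illustrative rewards for actions 0 and 1
p1 = 1.0 / (1.0 + math.exp(-theta))
expected_reward = (1.0 - p1) * rewards[0] + p1 * rewards[1]

mean_0, var_0 = grad_estimator_stats(theta, rewards, baseline=0.0)
mean_b, var_b = grad_estimator_stats(theta, rewards, baseline=expected_reward)

print(f"no baseline:   mean={mean_0:.4f}  var={var_0:.4f}")
print(f"with baseline: mean={mean_b:.4f}  var={var_b:.4f}")
```

Both estimators have the same expected value, so the baseline does not bias the gradient, but in this example the variance is far smaller once the baseline is subtracted.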

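**Reward heuristic:** We use a simple reward that encourages the model to generate sequences
that (a) contain a target word like "profit", (b) have moderate length (penalising outputs that
are too short or too long), and (c) end with a period. In the code below the reward is computed
on the decoded prompt plus completion, and the prompt `Q: Explain profit. A:` already contains
the target word, so every sample earns the +2.0 term; in practice only the length and
punctuation terms separate one sample from another.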

In [None]:
# ---------------------------------------------------------------------------
# 2a. Reward function and rollout generation
# ---------------------------------------------------------------------------
TARGET_WORD = "profit"


def compute_reward(text: str) -> float:
    """Simple heuristic reward for a generated text.
    +2.0 if text contains the target word 'profit'
    Length penalty: -0.02 * |len - 40|  (prefer ~40 chars)
    Small bonus for ending with a period.
    """
    r = 0.0
    if TARGET_WORD in text.lower():
        r += 2.0
    # Length penalty: prefer around 40 characters
    r -= 0.02 * abs(len(text) - 40)
    # Punctuation bonus
    if text.strip().endswith("."):
        r += 0.5
    return r


def generate_with_logprobs(
    model: TinyGPT2,
    prompt_ids: torch.Tensor,
    max_new_tokens: int,
    temperature: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Generate tokens and collect log-probabilities for each sampled token.
    Returns (generated_ids [B, T+max_new], log_probs [B, max_new]).
    """
    model.eval()
    all_log_probs = []
    idx = prompt_ids.clone()
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -model.max_len :]
        with torch.no_grad():
            logits = model(idx_cond)[:, -1, :] / temperature
        probs = F.softmax(logits, dim=-1)
        dist = torch.distributions.Categorical(probs)
        next_tok = dist.sample()
        all_log_probs.append(dist.log_prob(next_tok))
        idx = torch.cat([idx, next_tok.unsqueeze(1)], dim=1)
    log_probs = torch.stack(all_log_probs, dim=1)  # (B, max_new_tokens)
    return idx, log_probs


# Quick test
rl_model = TinyGPT2(VOCAB_SIZE).to(DEVICE)
# Copy SFT weights as starting point
rl_model.load_state_dict(sft_model.state_dict())
_p = torch.tensor([tokeniser.encode("Q: Explain profit. A:", add_bos=True, add_eos=False)], device=DEVICE)
_gen, _lp = generate_with_logprobs(rl_model, _p, max_new_tokens=20)
print(f"Generated: {tokeniser.decode(_gen[0].tolist())}")
print(f"Log-prob shape: {_lp.shape}, sum: {_lp.sum().item():.3f}")

In [None]:
# ---------------------------------------------------------------------------
# 2b. REINFORCE step (optional constant baseline)
# ---------------------------------------------------------------------------
def reinforce_step(
    model: TinyGPT2,
    prompt_ids: torch.Tensor,
    max_new_tokens: int,
    batch_size: int,
    temperature: float = 1.0,
    baseline: float = 0.0,
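) -> tuple[torch.Tensor, float, list[float]]:
    """Sample a batch, score it, and build the REINFORCE loss.
    Returns (loss, mean_reward, per_sample_rewards). The loss is a tensor with
    a graph attached; the caller runs backward() and the optimiser step.
    """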
    # Expand prompt for batch
    prompts = prompt_ids.expand(batch_size, -1)

    # Generate with gradients through log-probs
    model.eval()  # keep dropout off for generation
    all_log_probs = []
    idx = prompts.clone()
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -model.max_len :]
        logits = model(idx_cond)[:, -1, :] / temperature
        probs = F.softmax(logits, dim=-1)
        dist = torch.distributions.Categorical(probs)
        next_tok = dist.sample()
        all_log_probs.append(dist.log_prob(next_tok))
        idx = torch.cat([idx, next_tok.unsqueeze(1).detach()], dim=1)

    log_probs = torch.stack(all_log_probs, dim=1)  # (B, max_new_tokens)

    # Compute rewards
    rewards = []
    for i in range(batch_size):
        gen_text = tokeniser.decode(idx[i].tolist())
        rewards.append(compute_reward(gen_text))
    rewards_t = torch.tensor(rewards, dtype=torch.float32, device=DEVICE)

    # REINFORCE loss: -E[log_prob * (R - baseline)]
    advantages = rewards_t - baseline
    per_sample_loss = -(log_probs.sum(dim=1) * advantages)
    loss = per_sample_loss.mean()

    return loss, rewards_t.mean().item(), rewards


print("REINFORCE step function defined.")

In [None]:
# ---------------------------------------------------------------------------
# 2c. Train with REINFORCE -- compare with/without baseline
# ---------------------------------------------------------------------------
def train_reinforce(use_baseline: bool, n_steps: int = 80, batch_size: int = 16, lr: float = 1e-4):
    model = TinyGPT2(VOCAB_SIZE).to(DEVICE)
    model.load_state_dict(sft_model.state_dict())
    optimiser = torch.optim.Adam(model.parameters(), lr=lr)

    prompt = "Q: Explain profit. A:"
    prompt_ids = torch.tensor(
        [tokeniser.encode(prompt, add_bos=True, add_eos=False)], device=DEVICE
    )

    reward_history = []
    grad_norm_history = []
    running_baseline = 0.0

    for step in range(n_steps):
        baseline_val = running_baseline if use_baseline else 0.0
        loss, mean_r, _ = reinforce_step(
            model, prompt_ids, max_new_tokens=30, batch_size=batch_size,
            temperature=1.0, baseline=baseline_val,
        )
        optimiser.zero_grad()
        loss.backward()
        # Record gradient norm
        total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        grad_norm_history.append(total_norm.item())
        optimiser.step()

        reward_history.append(mean_r)
        running_baseline = 0.95 * running_baseline + 0.05 * mean_r

        if (step + 1) % 20 == 0:
            tag = "baseline" if use_baseline else "no baseline"
            print(f"  [{tag}] step {step+1:3d}  reward={mean_r:.3f}  grad_norm={total_norm:.3f}")

    return reward_history, grad_norm_history, model


print("Training REINFORCE without baseline...")
rewards_no_bl, grads_no_bl, _ = train_reinforce(use_baseline=False)
print("\nTraining REINFORCE with baseline...")
rewards_bl, grads_bl, reinforce_model = train_reinforce(use_baseline=True)
print("\nDone.")

In [None]:
# ---------------------------------------------------------------------------
# 2d. Visualise REINFORCE results
# ---------------------------------------------------------------------------
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Reward curves
axes[0].plot(rewards_no_bl, alpha=0.7, label="No baseline")
axes[0].plot(rewards_bl, alpha=0.7, label="With baseline")
axes[0].set_xlabel("Step")
axes[0].set_ylabel("Mean Reward")
axes[0].set_title("REINFORCE: Reward over Training")
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Gradient norm comparison
axes[1].plot(grads_no_bl, alpha=0.7, label="No baseline")
axes[1].plot(grads_bl, alpha=0.7, label="With baseline")
axes[1].set_xlabel("Step")
axes[1].set_ylabel("Gradient Norm")
axes[1].set_title("REINFORCE: Gradient Variance")
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# The spread of the gradient norm across steps is a rough proxy for gradient variance.
std_no_bl = np.std(grads_no_bl)
std_bl = np.std(grads_bl)
print(f"\nGrad norm std (no baseline):   {std_no_bl:.3f}")
print(f"Grad norm std (with baseline): {std_bl:.3f}")
reduction = (1 - std_bl / std_no_bl) * 100
print(f"-> Relative reduction in gradient norm std with the baseline: {reduction:.0f}% (negative means it increased).")

In [None]:
# ---------------------------------------------------------------------------
# 2e. Show sample generations from the REINFORCE-trained model
# ---------------------------------------------------------------------------
reinforce_model.eval()
prompt = "Q: Explain profit. A:"
prompt_ids = torch.tensor(
    [tokeniser.encode(prompt, add_bos=True, add_eos=False)], device=DEVICE
)
print("Sample generations from REINFORCE model:\n")
for i in range(5):
    gen = reinforce_model.generate(prompt_ids, max_new_tokens=40, temperature=0.8)
    text = tokeniser.decode(gen[0].tolist())
    r = compute_reward(text)
    print(f"  [{i+1}] (R={r:.2f}) {text}")

---
## 3. PPO from Scratch

Proximal Policy Optimisation (PPO) improves on vanilla REINFORCE with:

1. **Clipped surrogate objective** -- prevents the policy from changing too much:
   $$L^{\text{CLIP}} = \mathbb{E}_t\left[\min\left(r_t(\theta) \hat{A}_t,\; \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon)\hat{A}_t\right)\right]$$
   where $r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{\text{old}}}(a_t|s_t)}$

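2. **Generalised Advantage Estimation (GAE)** -- a biased but lower-variance advantage estimator:
   $$\hat{A}_t^{\text{GAE}} = \sum_{l=0}^{T-t-1} (\gamma\lambda)^l \delta_{t+l}, \quad \delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)$$
   Here $r_t$ is the reward received at step $t$ (not the probability ratio $r_t(\theta)$ above), and $V(s_T) = 0$ at the end of the episode.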

3. **Value function** -- a critic that estimates the expected return from each state.

4. **Multiple minibatch epochs** over the same rollout data.
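
The first two ingredients can be illustrated with scalar arithmetic. The following is a minimal sketch in plain Python, assuming a single finished episode with a terminal value of zero; the function names `clipped_surrogate` and `gae` and all numbers are illustrative only.

```python
import math


def clipped_surrogate(new_logp, old_logp, advantage, clip_eps=0.2):
    """Per-token PPO objective: min(r * A, clip(r, 1-eps, 1+eps) * A)."""
    ratio = math.exp(new_logp - old_logp)
    clipped = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
    return min(ratio * advantage, clipped * advantage)


def gae(rewards, values, gamma=1.0, lam=0.95):
    """GAE for one finished episode; the value after the last step is 0."""
    advantages = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        next_value = values[t + 1] if t + 1 < len(values) else 0.0
        # TD error: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
        delta = rewards[t] + gamma * next_value - values[t]
        # Recursion: A_t = delta_t + gamma * lambda * A_{t+1}
        running = delta + gamma * lam * running
        advantages[t] = running
    return advantages


# The policy has raised the probability of this action by 50% (ratio = 1.5).
up = math.log(1.5)
# Positive advantage: the gain is capped at (1 + eps) * A.
print(clipped_surrogate(up, 0.0, advantage=1.0))
# Negative advantage: the min keeps the unclipped, more pessimistic term.
print(clipped_surrogate(up, 0.0, advantage=-1.0))

# Sparse reward on the last token, as in sequence-level RL for language models.
rewards = [0.0, 0.0, 1.0]
values = [0.1, 0.2, 0.3]
print(gae(rewards, values, gamma=1.0, lam=1.0))  # return-to-go minus value
print(gae(rewards, values, gamma=1.0, lam=0.0))  # one-step TD errors
```

With $\lambda = 1$ the advantage reduces to the Monte Carlo return minus the value estimate, and with $\lambda = 0$ it reduces to the one-step TD error; intermediate values trade variance against bias.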

In [None]:
# ---------------------------------------------------------------------------
# 3a. Value Network (Critic)
# ---------------------------------------------------------------------------
class ValueHead(nn.Module):
    """Scalar value head on top of the transformer backbone.
    Shares the backbone with the policy but has its own two-layer MLP head."""

    def __init__(self, d_model: int = 128):
        super().__init__()
        self.head = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.Tanh(),
            nn.Linear(d_model, 1),
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """hidden_states: (B, T, d_model) -> values: (B, T)"""
        return self.head(hidden_states).squeeze(-1)


class ActorCritic(nn.Module):
    """Wraps TinyGPT2 (actor) and a ValueHead (critic)."""

    def __init__(self, vocab_size: int, d_model: int = 128, **kwargs):
        super().__init__()
        self.backbone = TinyGPT2(vocab_size, d_model=d_model, **kwargs)
        self.value_head = ValueHead(d_model)

    def forward(self, idx: torch.Tensor):
        """Returns (logits, values)."""
        B, T = idx.shape
        pos = torch.arange(0, T, device=idx.device).unsqueeze(0)
        x = self.backbone.drop(self.backbone.tok_emb(idx) + self.backbone.pos_emb(pos))
        for block in self.backbone.blocks:
            x = block(x)
        x = self.backbone.ln_f(x)  # (B, T, d_model)
        logits = self.backbone.head(x)  # (B, T, V)
        values = self.value_head(x)  # (B, T)
        return logits, values

    def get_logits(self, idx: torch.Tensor) -> torch.Tensor:
        logits, _ = self(idx)
        return logits


# Quick check
_ac = ActorCritic(VOCAB_SIZE)
print(f"ActorCritic parameters: {count_parameters(_ac):,}")
del _ac

In [None]:
# ---------------------------------------------------------------------------
# 3b. Rollout collection
# ---------------------------------------------------------------------------
@torch.no_grad()
def collect_rollouts(
    actor_critic: ActorCritic,
    prompt_ids: torch.Tensor,
    batch_size: int,
    max_new_tokens: int,
    temperature: float = 1.0,
) -> dict:
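    """Generate sequences and collect data needed for PPO."""
    # Switch off dropout, as reinforce_step does, so that the stored log-probs
    # and values are not perturbed by random dropout masks.
    actor_critic.eval()
    prompts = prompt_ids.expand(batch_size, -1)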
    prompt_len = prompt_ids.shape[1]

    idx = prompts.clone()
    all_log_probs = []
    all_values = []
    all_actions = []

    for t in range(max_new_tokens):
        idx_cond = idx[:, -actor_critic.backbone.max_len :]
        logits, values = actor_critic(idx_cond)
        # Take last position logits and value
        last_logits = logits[:, -1, :] / temperature
        last_value = values[:, -1]  # (B,)
        probs = F.softmax(last_logits, dim=-1)
        dist = torch.distributions.Categorical(probs)
        action = dist.sample()
        log_prob = dist.log_prob(action)

        all_actions.append(action)
        all_log_probs.append(log_prob)
        all_values.append(last_value)

        idx = torch.cat([idx, action.unsqueeze(1)], dim=1)

    # Score each sequence once. Episodes are treated as complete, so the value
    # after the last generated token is taken to be zero (no bootstrapping).
    rewards_list = []
    for i in range(batch_size):
        gen_text = tokeniser.decode(idx[i].tolist())
        rewards_list.append(compute_reward(gen_text))

    # Sparse token-level rewards: the sequence reward is placed on the last
    # generated token and every earlier token receives zero.
    token_rewards = torch.zeros(batch_size, max_new_tokens, device=DEVICE)
    for i in range(batch_size):
        token_rewards[i, -1] = rewards_list[i]

    return {
        "sequences": idx,  # (B, prompt_len + max_new_tokens)
        "actions": torch.stack(all_actions, dim=1),  # (B, max_new_tokens)
        "log_probs": torch.stack(all_log_probs, dim=1),  # (B, max_new_tokens)
        "values": torch.stack(all_values, dim=1),  # (B, max_new_tokens)
        "token_rewards": token_rewards,  # (B, max_new_tokens)
        "total_rewards": torch.tensor(rewards_list, device=DEVICE),  # (B,)
        "prompt_len": prompt_len,
    }


print("Rollout collection function defined.")

In [None]:
# ---------------------------------------------------------------------------
# 3c. GAE Computation
# ---------------------------------------------------------------------------
def compute_gae(
    token_rewards: torch.Tensor,
    values: torch.Tensor,
    gamma: float = 1.0,
    lam: float = 0.95,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute Generalised Advantage Estimation.

    Args:
        token_rewards: (B, T) rewards at each token step
        values: (B, T) value estimates from critic
        gamma: discount factor
        lam: GAE lambda parameter

    Returns:
        advantages: (B, T)
        returns: (B, T) = advantages + values
    """
    B, T = token_rewards.shape
    advantages = torch.zeros_like(token_rewards)
    last_gae = torch.zeros(B, device=token_rewards.device)

    for t in reversed(range(T)):
        if t == T - 1:
            next_value = torch.zeros(B, device=token_rewards.device)  # terminal
        else:
            next_value = values[:, t + 1]
        # TD error: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
        delta = token_rewards[:, t] + gamma * next_value - values[:, t]
        # GAE: A_t = delta_t + gamma * lambda * A_{t+1}
        last_gae = delta + gamma * lam * last_gae
        advantages[:, t] = last_gae

    returns = advantages + values
    return advantages, returns


# Quick test
_r = torch.tensor([[0.0, 0.0, 1.0]])
_v = torch.tensor([[0.1, 0.2, 0.3]])
_adv, _ret = compute_gae(_r, _v, gamma=0.99, lam=0.95)
print(f"Test rewards: {_r}")
print(f"Test values:  {_v}")
print(f"GAE advantages: {_adv}")
print(f"Returns:        {_ret}")

In [None]:
# ---------------------------------------------------------------------------
# 3d. PPO Clipped Objective
# ---------------------------------------------------------------------------
def ppo_loss(
    new_log_probs: torch.Tensor,
    old_log_probs: torch.Tensor,
    advantages: torch.Tensor,
    clip_eps: float = 0.2,
) -> torch.Tensor:
    """PPO clipped surrogate objective.

    L = min(r_t * A_t, clip(r_t, 1-eps, 1+eps) * A_t)

    We negate because we want to maximise the objective (gradient ascent).
    """
    # Importance sampling ratio
    ratio = torch.exp(new_log_probs - old_log_probs)  # r_t(theta)
    # Clipped ratio
    clipped_ratio = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps)
    # Surrogate losses
    surr1 = ratio * advantages
    surr2 = clipped_ratio * advantages
    # Take the minimum (pessimistic bound) and negate for minimisation
    policy_loss = -torch.min(surr1, surr2).mean()
    return policy_loss


def value_loss(predicted_values: torch.Tensor, returns: torch.Tensor) -> torch.Tensor:
    """Simple MSE loss for the value function."""
    return F.mse_loss(predicted_values, returns)


print("PPO loss functions defined.")

In [None]:
# ---------------------------------------------------------------------------
# 3e. PPO update step with minibatch epochs
# ---------------------------------------------------------------------------
def ppo_update(
    actor_critic: ActorCritic,
    optimiser: torch.optim.Optimizer,
    rollout: dict,
    n_epochs: int = 4,
    minibatch_size: int = 8,
    clip_eps: float = 0.2,
    vf_coef: float = 0.5,
    entropy_coef: float = 0.01,
    gamma: float = 1.0,
    lam: float = 0.95,
) -> dict:
    """Perform PPO update over multiple epochs of minibatches."""
    sequences = rollout["sequences"]
    actions = rollout["actions"]
    old_log_probs = rollout["log_probs"]
    old_values = rollout["values"]
    token_rewards = rollout["token_rewards"]
    prompt_len = rollout["prompt_len"]

    # Compute GAE
    advantages, returns = compute_gae(token_rewards, old_values, gamma=gamma, lam=lam)
    # Normalise advantages
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    B = sequences.shape[0]
    T_gen = actions.shape[1]
    total_policy_loss = 0.0
    total_value_loss = 0.0
    n_updates = 0

    for epoch in range(n_epochs):
        # Shuffle indices
        perm = torch.randperm(B)
        for start in range(0, B, minibatch_size):
            mb_idx = perm[start : start + minibatch_size]
            mb_seq = sequences[mb_idx]  # (mb, total_len)
            mb_actions = actions[mb_idx]  # (mb, T_gen)
            mb_old_lp = old_log_probs[mb_idx]
            mb_advantages = advantages[mb_idx]
            mb_returns = returns[mb_idx]

            # Forward pass through actor-critic
            # We feed the full sequence and extract logits/values for the generated part
            full_logits, full_values = actor_critic(mb_seq)
            # Logits at positions [prompt_len-1 ... prompt_len+T_gen-2] predict tokens
            # at positions [prompt_len ... prompt_len+T_gen-1]
            gen_logits = full_logits[:, prompt_len - 1 : prompt_len - 1 + T_gen, :]
            gen_values = full_values[:, prompt_len - 1 : prompt_len - 1 + T_gen]

            # Compute new log probs
            gen_log_probs_all = F.log_softmax(gen_logits, dim=-1)
            new_log_probs = gen_log_probs_all.gather(
                2, mb_actions.unsqueeze(-1)
            ).squeeze(-1)

            # Entropy bonus
            entropy = -(F.softmax(gen_logits, dim=-1) * gen_log_probs_all).sum(-1).mean()

            # Losses
            p_loss = ppo_loss(new_log_probs, mb_old_lp, mb_advantages, clip_eps=clip_eps)
            v_loss = value_loss(gen_values, mb_returns)
            total_loss = p_loss + vf_coef * v_loss - entropy_coef * entropy

            optimiser.zero_grad()
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(actor_critic.parameters(), 1.0)
            optimiser.step()

            total_policy_loss += p_loss.item()
            total_value_loss += v_loss.item()
            n_updates += 1

    return {
        "policy_loss": total_policy_loss / max(n_updates, 1),
        "value_loss": total_value_loss / max(n_updates, 1),
    }


print("PPO update function defined.")

In [None]:
# ---------------------------------------------------------------------------
# 3f. PPO Training Loop
# ---------------------------------------------------------------------------
ppo_ac = ActorCritic(VOCAB_SIZE).to(DEVICE)
# Initialise from SFT weights for the backbone
ppo_ac.backbone.load_state_dict(sft_model.state_dict())

ppo_optimiser = torch.optim.Adam(ppo_ac.parameters(), lr=5e-5)

prompt_text = "Q: Explain profit. A:"
prompt_ids = torch.tensor(
    [tokeniser.encode(prompt_text, add_bos=True, add_eos=False)], device=DEVICE
)

PPO_STEPS = 80
ROLLOUT_BATCH = 32
GEN_LEN = 30

ppo_reward_history = []
ppo_policy_loss_history = []
ppo_value_loss_history = []

print("Starting PPO training...\n")
for step in range(PPO_STEPS):
    # 1. Collect rollouts
    rollout = collect_rollouts(
        ppo_ac, prompt_ids, batch_size=ROLLOUT_BATCH,
        max_new_tokens=GEN_LEN, temperature=1.0,
    )
    mean_reward = rollout["total_rewards"].mean().item()
    ppo_reward_history.append(mean_reward)

    # 2. PPO update
    losses = ppo_update(
        ppo_ac, ppo_optimiser, rollout,
        n_epochs=4, minibatch_size=8, clip_eps=0.2,
        vf_coef=0.5, entropy_coef=0.01,
    )
    ppo_policy_loss_history.append(losses["policy_loss"])
    ppo_value_loss_history.append(losses["value_loss"])

    if (step + 1) % 10 == 0:
        print(
            f"Step {step+1:3d}/{PPO_STEPS}  "
            f"reward={mean_reward:.3f}  "
            f"pi_loss={losses['policy_loss']:.4f}  "
            f"v_loss={losses['value_loss']:.4f}"
        )

print("\nPPO training complete.")

In [None]:
# ---------------------------------------------------------------------------
# 3g. PPO Training Plots
# ---------------------------------------------------------------------------
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

axes[0].plot(ppo_reward_history, linewidth=1.5)
axes[0].set_xlabel("Step")
axes[0].set_ylabel("Mean Reward")
axes[0].set_title("PPO: Reward over Training")
axes[0].grid(True, alpha=0.3)

axes[1].plot(ppo_policy_loss_history, linewidth=1.5, color="tab:orange")
axes[1].set_xlabel("Step")
axes[1].set_ylabel("Policy Loss")
axes[1].set_title("PPO: Clipped Policy Loss")
axes[1].grid(True, alpha=0.3)

axes[2].plot(ppo_value_loss_history, linewidth=1.5, color="tab:green")
axes[2].set_xlabel("Step")
axes[2].set_ylabel("Value Loss")
axes[2].set_title("PPO: Value Function Loss")
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# ---------------------------------------------------------------------------
# 3h. PPO Sample Generations
# ---------------------------------------------------------------------------
ppo_ac.eval()
print("Sample generations from PPO-trained model:\n")
for i in range(5):
    with torch.no_grad():
        gen = ppo_ac.backbone.generate(prompt_ids, max_new_tokens=40, temperature=0.8)
    text = tokeniser.decode(gen[0].tolist())
    r = compute_reward(text)
    print(f"  [{i+1}] (R={r:.2f}) {text}")

---
## 4. Group Relative Policy Optimisation (GRPO)

GRPO (introduced in DeepSeekMath and later used to train DeepSeek-R1) is a **critic-free** alternative to PPO. For each prompt:

1. Generate a **group** of $G$ responses from the current policy.
2. Score each response with a reward function.
3. Compute **group-normalised advantages**: $A_i = \frac{R_i - \mu_G}{\sigma_G}$
4. Update the policy with a clipped objective (like PPO) but using these advantages.
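
As a small illustration of step 3, the sketch below normalises one hypothetical group of four rewards in plain Python so the arithmetic is visible. The helper name `group_normalised_advantages` and the reward values are made up for this example; the `eps` guard and the sample standard deviation (n - 1 denominator) are assumptions chosen to mirror PyTorch's default `Tensor.std()`.

```python
import statistics


def group_normalised_advantages(rewards, eps=1e-8):
    """Return A_i = (R_i - mean) / (std + eps) for one group of rewards."""
    mean_r = statistics.fmean(rewards)
    # Sample standard deviation (n - 1 denominator), needs at least 2 rewards
    std_r = statistics.stdev(rewards)
    return [(r - mean_r) / (std_r + eps) for r in rewards]


group_rewards = [1.0, 0.0, 0.0, 1.0]  # hypothetical rewards for G = 4 responses
advantages = group_normalised_advantages(group_rewards)
print([round(a, 3) for a in advantages])  # → [0.866, -0.866, -0.866, 0.866]
```

Responses that beat the group mean get a positive advantage and the rest a negative one. If every response in the group receives the same reward, all advantages are zero and the group contributes no gradient.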

Key benefit: no value network is needed, which saves memory and compute and removes the critic as a source of training instability.

Additionally, GRPO includes a KL penalty to keep the policy close to a reference:
$$\mathcal{L}_{\text{GRPO}} = -\frac{1}{G}\sum_{i=1}^{G}\frac{1}{|o_i|}\sum_{t=1}^{|o_i|}\left[\min(r_t A_i,\; \text{clip}(r_t, 1-\epsilon, 1+\epsilon) A_i) - \beta\, D_{KL}\right]$$
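
To make the bracketed term concrete, here is a minimal per-token sketch in plain Python. The function name `grpo_token_terms` and the example log-probabilities are hypothetical. The KL term uses the non-negative estimator $e^{d} - d - 1$ with $d = \log\pi_{\text{ref}} - \log\pi_\theta$, which is one common choice and is assumed here rather than dictated by the formula above.

```python
import math


def grpo_token_terms(new_lp, old_lp, ref_lp, advantage, clip_eps=0.2, beta=0.04):
    """Return (clipped surrogate, KL estimate, loss) for a single token."""
    ratio = math.exp(new_lp - old_lp)  # importance ratio r_t
    clipped = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
    # Pessimistic choice: take the smaller of the raw and clipped surrogates
    surrogate = min(ratio * advantage, clipped * advantage)
    # Non-negative KL estimator, zero when the policy equals the reference
    d = ref_lp - new_lp
    kl = math.exp(d) - d - 1.0
    loss = -surrogate + beta * kl
    return surrogate, kl, loss


# Hypothetical token: probability rose from 0.4 (old policy) to 0.6 (new policy)
new_lp, old_lp, ref_lp = math.log(0.6), math.log(0.4), math.log(0.4)
print(grpo_token_terms(new_lp, old_lp, ref_lp, advantage=+1.0))
print(grpo_token_terms(new_lp, old_lp, ref_lp, advantage=-1.0))
```

With a positive advantage the ratio of 1.5 is capped at $1 + \epsilon = 1.2$, so the policy gains nothing from moving further. With a negative advantage the `min` keeps the unclipped value of $-1.5$, so the penalty for raising the probability of a bad token is not capped.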

In [None]:
# ---------------------------------------------------------------------------
# 4a. GRPO: Generate group and compute group-normalised advantages
# ---------------------------------------------------------------------------
def grpo_generate_group(
    model: TinyGPT2,
    prompt_ids: torch.Tensor,
    group_size: int,
    max_new_tokens: int,
    temperature: float = 1.0,
) -> dict:
    """Generate a group of responses and compute group-normalised advantages."""
    prompts = prompt_ids.expand(group_size, -1)
    prompt_len = prompt_ids.shape[1]

    # Sample from the current policy and record the old log-probs.
    # No gradients are needed here; grpo_update recomputes log-probs with gradients.
    model.eval()
    idx = prompts.clone()
    all_log_probs = []
    all_actions = []

    for _ in range(max_new_tokens):
        idx_cond = idx[:, -model.max_len :]
        with torch.no_grad():
            logits = model(idx_cond)[:, -1, :] / temperature
        probs = F.softmax(logits, dim=-1)
        dist = torch.distributions.Categorical(probs)
        action = dist.sample()
        log_prob = dist.log_prob(action)
        all_actions.append(action)
        all_log_probs.append(log_prob)
        idx = torch.cat([idx, action.unsqueeze(1)], dim=1)

    # Compute rewards
    rewards = []
    for i in range(group_size):
        text = tokeniser.decode(idx[i].tolist())
        rewards.append(compute_reward(text))
    rewards_t = torch.tensor(rewards, dtype=torch.float32, device=DEVICE)

    # Group-normalised advantages: A_i = (R_i - mean) / std
    mean_r = rewards_t.mean()
    std_r = rewards_t.std() + 1e-8
    advantages = (rewards_t - mean_r) / std_r  # (G,)

    return {
        "sequences": idx,
        "actions": torch.stack(all_actions, dim=1),
        "old_log_probs": torch.stack(all_log_probs, dim=1),
        "advantages": advantages,
        "rewards": rewards_t,
        "prompt_len": prompt_len,
    }


print("GRPO group generation function defined.")

In [None]:
# ---------------------------------------------------------------------------
# 4b. GRPO Update Step
# ---------------------------------------------------------------------------
def grpo_update(
    model: TinyGPT2,
    ref_model: TinyGPT2,
    optimiser: torch.optim.Optimizer,
    group_data: dict,
    clip_eps: float = 0.2,
    beta_kl: float = 0.04,
    n_epochs: int = 2,
) -> float:
    """GRPO policy update with clipped objective and KL penalty."""
    sequences = group_data["sequences"]
    actions = group_data["actions"]
    old_log_probs = group_data["old_log_probs"]
    advantages = group_data["advantages"]  # (G,)
    prompt_len = group_data["prompt_len"]
    G, T_gen = actions.shape

    total_loss_val = 0.0
    for epoch in range(n_epochs):
        # Forward pass
        logits = model(sequences)  # (G, total_len, V)
        gen_logits = logits[:, prompt_len - 1 : prompt_len - 1 + T_gen, :]
        gen_log_probs = F.log_softmax(gen_logits, dim=-1)
        new_log_probs = gen_log_probs.gather(2, actions.unsqueeze(-1)).squeeze(-1)

        # Reference model log probs (for KL)
        with torch.no_grad():
            ref_logits = ref_model(sequences)
            ref_gen_logits = ref_logits[:, prompt_len - 1 : prompt_len - 1 + T_gen, :]
            ref_log_probs = F.log_softmax(ref_gen_logits, dim=-1)
            ref_lp = ref_log_probs.gather(2, actions.unsqueeze(-1)).squeeze(-1)

        # Importance sampling ratio
        ratio = torch.exp(new_log_probs - old_log_probs)  # (G, T_gen)
        clipped_ratio = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps)

        # Expand advantages to token level: (G,) -> (G, T_gen)
        adv_expanded = advantages.unsqueeze(1).expand_as(ratio)

        surr1 = ratio * adv_expanded
        surr2 = clipped_ratio * adv_expanded
        policy_loss = -torch.min(surr1, surr2).mean()

        # KL divergence penalty (approx): D_KL = exp(ref_lp - new_lp) - (ref_lp - new_lp) - 1
        log_ratio_kl = ref_lp - new_log_probs
        kl = (torch.exp(log_ratio_kl) - log_ratio_kl - 1.0).mean()

        loss = policy_loss + beta_kl * kl

        optimiser.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimiser.step()
        total_loss_val += loss.item()

    return total_loss_val / n_epochs


print("GRPO update function defined.")

In [None]:
# ---------------------------------------------------------------------------
# 4c. GRPO Training Loop
# ---------------------------------------------------------------------------
grpo_model = TinyGPT2(VOCAB_SIZE).to(DEVICE)
grpo_model.load_state_dict(sft_model.state_dict())
grpo_ref = copy.deepcopy(grpo_model)  # frozen reference
grpo_ref.eval()
for p in grpo_ref.parameters():
    p.requires_grad = False

grpo_optimiser = torch.optim.Adam(grpo_model.parameters(), lr=5e-5)

GRPO_STEPS = 80
GROUP_SIZE = 32

grpo_reward_history = []

print("Starting GRPO training...\n")
for step in range(GRPO_STEPS):
    group_data = grpo_generate_group(
        grpo_model, prompt_ids, group_size=GROUP_SIZE,
        max_new_tokens=GEN_LEN, temperature=1.0,
    )
    mean_r = group_data["rewards"].mean().item()
    grpo_reward_history.append(mean_r)

    loss = grpo_update(
        grpo_model, grpo_ref, grpo_optimiser, group_data,
        clip_eps=0.2, beta_kl=0.04, n_epochs=2,
    )

    if (step + 1) % 10 == 0:
        print(f"Step {step+1:3d}/{GRPO_STEPS}  reward={mean_r:.3f}  loss={loss:.4f}")

print("\nGRPO training complete.")

In [None]:
# ---------------------------------------------------------------------------
# 4d. Compare PPO vs GRPO
# ---------------------------------------------------------------------------
fig, ax = plt.subplots(1, 1, figsize=(8, 4))
ax.plot(ppo_reward_history, label="PPO", linewidth=1.5)
ax.plot(grpo_reward_history, label="GRPO", linewidth=1.5)
ax.set_xlabel("Step")
ax.set_ylabel("Mean Reward")
ax.set_title("PPO vs GRPO: Reward over Training")
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

print(f"PPO final reward (last 10 avg):  {np.mean(ppo_reward_history[-10:]):.3f}")
print(f"GRPO final reward (last 10 avg): {np.mean(grpo_reward_history[-10:]):.3f}")
print("\nIf the curves are close, GRPO is matching PPO here without a critic network.")

---
## 5. Direct Preference Optimisation (DPO)

DPO eliminates the need for a separate reward model entirely. Given **preference pairs**
$(y_w, y_l)$ where $y_w$ is preferred over $y_l$, DPO directly optimises the policy:

$$\mathcal{L}_{\text{DPO}} = -\mathbb{E}\left[\log\sigma\left(\beta\left(\log\frac{\pi_\theta(y_w|x)}{\pi_{\text{ref}}(y_w|x)} - \log\frac{\pi_\theta(y_l|x)}{\pi_{\text{ref}}(y_l|x)}\right)\right)\right]$$

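The implicit reward under DPO is:
$$r(x, y) = \beta \log\frac{\pi_\theta(y|x)}{\pi_{\text{ref}}(y|x)} + \beta\log Z(x)$$

The partition term $\beta\log Z(x)$ depends only on the prompt, so it cancels in the difference $r(x, y_w) - r(x, y_l)$. This is why the training cells track only the log-ratio term.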
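
A worked example of the loss for a single preference pair, in plain Python. The function name `dpo_pair_loss` and the sequence log-probabilities are invented for illustration; only the formula comes from the text above.

```python
import math


def dpo_pair_loss(lp_w, lp_l, ref_lp_w, ref_lp_l, beta=0.1):
    """Return (loss, implicit reward margin) for one preference pair."""
    margin = beta * ((lp_w - ref_lp_w) - (lp_l - ref_lp_l))
    # -log(sigmoid(m)) = log(1 + exp(-m)); adequate for moderate margins
    loss = math.log1p(math.exp(-margin))
    return loss, margin


# At initialisation the policy equals the reference: margin 0, loss log 2
print(dpo_pair_loss(lp_w=-45.0, lp_l=-30.0, ref_lp_w=-45.0, ref_lp_l=-30.0))

# Hypothetical later state: y_w became 5 nats more likely, y_l 5 nats less likely
print(dpo_pair_loss(lp_w=-40.0, lp_l=-35.0, ref_lp_w=-45.0, ref_lp_l=-30.0))
```

The first call gives a loss of about 0.693 with zero margin; the second gives a margin of 1.0 and a loss of about 0.313. The dispreferred response is still more likely in absolute terms in the second state ($-35 > -40$). DPO only rewards movement relative to the reference model.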

We construct synthetic preference pairs from our finance QA data.

In [None]:
# ---------------------------------------------------------------------------
# 5a. Synthetic Preference Dataset
# ---------------------------------------------------------------------------
# Each tuple: (prompt, preferred_response, dispreferred_response)
DPO_PREFS = [
    ("What is a bond?",
     "A bond is a fixed-income instrument representing a loan from investor to borrower.",
     "A bond is some kind of financial thing."),
    ("Define stock.",
     "A stock represents ownership of a fraction of a corporation.",
     "Stock is stuff you buy."),
    ("What is volatility?",
     "Volatility measures the degree of variation in trading price over time.",
     "Volatility is when prices go up and down."),
    ("Explain diversification.",
     "Diversification is a risk management strategy mixing various investments.",
     "Diversification means buying different things."),
    ("What is a derivative?",
     "A derivative is a financial contract whose value depends on an underlying asset.",
     "A derivative is complicated finance stuff."),
    ("Define yield.",
     "Yield is the income return on an investment expressed as a percentage.",
     "Yield is the money you get."),
    ("What is liquidity?",
     "Liquidity refers to how quickly an asset converts to cash without loss.",
     "Liquidity is about cash."),
    ("Explain hedging.",
     "Hedging is an investment to reduce the risk of adverse price movements.",
     "Hedging is like insurance or something."),
    ("What is leverage?",
     "Leverage is using borrowed capital to increase potential return on investment.",
     "Leverage means borrowing money."),
    ("Define alpha.",
     "Alpha is the excess return of an investment relative to its benchmark index.",
     "Alpha is a greek letter used in finance."),
    ("What is arbitrage?",
     "Arbitrage is the simultaneous purchase and sale of an asset to profit from price differences.",
     "Arbitrage is free money in the market."),
    ("Explain short selling.",
     "Short selling involves borrowing shares and selling them, hoping to buy back at a lower price.",
     "Short selling is betting stocks go down."),
]


def compute_sequence_log_prob(
    model: TinyGPT2, full_ids: torch.Tensor, prompt_len: int
) -> torch.Tensor:
    """Compute sum of log-probs for the response portion of a sequence.
    full_ids: (1, total_len), prompt_len: length of the prompt prefix.
    Returns a tensor of shape (1,) holding the summed response log-prob.
    """
    logits = model(full_ids)  # (1, total_len, V)
    # Shift: logits[t] predicts token[t+1]
    # Response tokens are at positions [prompt_len ... total_len-1]
    # Corresponding logits are at positions [prompt_len-1 ... total_len-2]
    response_logits = logits[:, prompt_len - 1 : -1, :]  # (1, resp_len, V)
    response_targets = full_ids[:, prompt_len:]  # (1, resp_len)
    log_probs = F.log_softmax(response_logits, dim=-1)
    token_lp = log_probs.gather(2, response_targets.unsqueeze(-1)).squeeze(-1)
    return token_lp.sum(dim=1)  # (1,)


print(f"DPO preference pairs: {len(DPO_PREFS)}")

In [None]:
# ---------------------------------------------------------------------------
# 5b. DPO Training
# ---------------------------------------------------------------------------
dpo_model = TinyGPT2(VOCAB_SIZE).to(DEVICE)
dpo_model.load_state_dict(sft_model.state_dict())
dpo_ref = copy.deepcopy(dpo_model)
dpo_ref.eval()
for p in dpo_ref.parameters():
    p.requires_grad = False

dpo_optimiser = torch.optim.Adam(dpo_model.parameters(), lr=1e-4)

DPO_BETA = 0.1
DPO_EPOCHS = 40

dpo_losses = []
implicit_reward_margins = []  # track reward_w - reward_l over training

print("Starting DPO training...\n")
for epoch in range(DPO_EPOCHS):
    epoch_loss = 0.0
    epoch_margins = []

    # Shuffle preference pairs
    perm = np.random.permutation(len(DPO_PREFS))

    for idx in perm:
        prompt, y_w, y_l = DPO_PREFS[idx]
        prompt_text = f"Q: {prompt} A: "

        # Encode full sequences
        prompt_ids_enc = tokeniser.encode(prompt_text, add_bos=True, add_eos=False)
        prompt_len = len(prompt_ids_enc)

        yw_ids = prompt_ids_enc + tokeniser.encode(y_w, add_bos=False, add_eos=True)
        yl_ids = prompt_ids_enc + tokeniser.encode(y_l, add_bos=False, add_eos=True)

        yw_t = torch.tensor([yw_ids], dtype=torch.long, device=DEVICE)
        yl_t = torch.tensor([yl_ids], dtype=torch.long, device=DEVICE)

        # Current policy log probs
        lp_w = compute_sequence_log_prob(dpo_model, yw_t, prompt_len)
        lp_l = compute_sequence_log_prob(dpo_model, yl_t, prompt_len)

        # Reference policy log probs
        with torch.no_grad():
            ref_lp_w = compute_sequence_log_prob(dpo_ref, yw_t, prompt_len)
            ref_lp_l = compute_sequence_log_prob(dpo_ref, yl_t, prompt_len)

        # DPO loss: -log sigma(beta * (log pi(yw)/pi_ref(yw) - log pi(yl)/pi_ref(yl)))
        log_ratio_w = lp_w - ref_lp_w  # log(pi/pi_ref) for preferred
        log_ratio_l = lp_l - ref_lp_l  # log(pi/pi_ref) for dispreferred
        logit = DPO_BETA * (log_ratio_w - log_ratio_l)
        loss = -F.logsigmoid(logit).mean()

        dpo_optimiser.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(dpo_model.parameters(), 1.0)
        dpo_optimiser.step()

        epoch_loss += loss.item()

        # Track implicit reward margin
        with torch.no_grad():
            implicit_rw = DPO_BETA * (compute_sequence_log_prob(dpo_model, yw_t, prompt_len) -
                                       compute_sequence_log_prob(dpo_ref, yw_t, prompt_len))
            implicit_rl = DPO_BETA * (compute_sequence_log_prob(dpo_model, yl_t, prompt_len) -
                                       compute_sequence_log_prob(dpo_ref, yl_t, prompt_len))
            epoch_margins.append((implicit_rw - implicit_rl).item())

    avg_loss = epoch_loss / len(DPO_PREFS)
    avg_margin = np.mean(epoch_margins)
    dpo_losses.append(avg_loss)
    implicit_reward_margins.append(avg_margin)

    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1:3d}/{DPO_EPOCHS}  loss={avg_loss:.4f}  reward_margin={avg_margin:.4f}")

print("\nDPO training complete.")

In [None]:
# ---------------------------------------------------------------------------
# 5c. DPO Visualisation: Loss and Implicit Reward Margin
# ---------------------------------------------------------------------------
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

axes[0].plot(dpo_losses, linewidth=1.5)
axes[0].set_xlabel("Epoch")
axes[0].set_ylabel("DPO Loss")
axes[0].set_title("DPO: Training Loss")
axes[0].grid(True, alpha=0.3)

axes[1].plot(implicit_reward_margins, linewidth=1.5, color="tab:red")
axes[1].axhline(y=0, color="grey", linestyle="--", alpha=0.5)
axes[1].set_xlabel("Epoch")
axes[1].set_ylabel("Implicit Reward Margin (w - l)")
axes[1].set_title("DPO: Implicit Reward Margin")
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("A rising implicit reward margin means the model assigns higher implicit")
print("reward to preferred responses than to dispreferred ones.")

In [None]:
# ---------------------------------------------------------------------------
# 5d. DPO: Inspect Implicit Rewards for Specific Pairs
# ---------------------------------------------------------------------------
dpo_model.eval()
print("Implicit reward for each preference pair:\n")
print(f"{'Prompt':<30} {'R(preferred)':>14} {'R(dispreferred)':>16} {'Margin':>8}")
print("-" * 72)

for prompt, y_w, y_l in DPO_PREFS[:6]:
    prompt_text = f"Q: {prompt} A: "
    prompt_ids_enc = tokeniser.encode(prompt_text, add_bos=True, add_eos=False)
    p_len = len(prompt_ids_enc)

    yw_ids = prompt_ids_enc + tokeniser.encode(y_w, add_bos=False, add_eos=True)
    yl_ids = prompt_ids_enc + tokeniser.encode(y_l, add_bos=False, add_eos=True)

    yw_t = torch.tensor([yw_ids], dtype=torch.long, device=DEVICE)
    yl_t = torch.tensor([yl_ids], dtype=torch.long, device=DEVICE)

    with torch.no_grad():
        r_w = DPO_BETA * (
            compute_sequence_log_prob(dpo_model, yw_t, p_len) -
            compute_sequence_log_prob(dpo_ref, yw_t, p_len)
        ).item()
        r_l = DPO_BETA * (
            compute_sequence_log_prob(dpo_model, yl_t, p_len) -
            compute_sequence_log_prob(dpo_ref, yl_t, p_len)
        ).item()

    print(f"{prompt:<30} {r_w:>14.4f} {r_l:>16.4f} {r_w - r_l:>8.4f}")

---
## 6. Bradley-Terry Reward Model

The **Bradley-Terry** model is the standard approach for training a reward model from
pairwise preference data. Given a pair $(y_w, y_l)$ where $y_w$ is preferred:

$$P(y_w \succ y_l) = \sigma(r_\phi(x, y_w) - r_\phi(x, y_l))$$

The loss is:
$$\mathcal{L}_{\text{BT}} = -\mathbb{E}\left[\log\sigma(r_\phi(x, y_w) - r_\phi(x, y_l))\right]$$
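
The two formulas can be checked numerically with a few lines of plain Python. The helper names and reward values below are hypothetical and only serve to show how the preference probability and the loss behave.

```python
import math


def bt_preference_prob(r_w, r_l):
    """P(y_w preferred over y_l) = sigmoid(r_w - r_l)."""
    return 1.0 / (1.0 + math.exp(-(r_w - r_l)))


def bt_loss(r_w, r_l):
    """Negative log-likelihood of the observed preference."""
    return -math.log(bt_preference_prob(r_w, r_l))


print(bt_preference_prob(0.0, 0.0), bt_loss(0.0, 0.0))    # untrained: 0.5, log 2
print(bt_preference_prob(1.0, -1.0), bt_loss(1.0, -1.0))  # gap of 2: about 0.881, 0.127
print(bt_preference_prob(11.0, 9.0))                      # same gap, shifted rewards
```

Only the difference $r_\phi(x, y_w) - r_\phi(x, y_l)$ matters, so a Bradley-Terry reward model is identified only up to an additive constant: shifting both rewards by 10 leaves the probability unchanged.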

We train a small reward model that takes a (prompt, response) pair and outputs a scalar reward.

In [None]:
# ---------------------------------------------------------------------------
# 6a. Reward Model Architecture
# ---------------------------------------------------------------------------
class RewardModel(nn.Module):
    """Small reward model: transformer backbone + scalar head.
    Uses a separate (smaller) backbone for efficiency."""

    def __init__(self, vocab_size: int, d_model: int = 96, n_heads: int = 4,
                 n_layers: int = 3, max_len: int = 128):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_len, d_model)
        self.blocks = nn.ModuleList(
            [TransformerBlock(d_model, n_heads, max_len, dropout=0.1)
             for _ in range(n_layers)]
        )
        self.ln_f = nn.LayerNorm(d_model)
        self.reward_head = nn.Linear(d_model, 1)
        self.max_len = max_len
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        """Returns scalar reward for each sequence in the batch. Shape: (B, 1)."""
        B, T = idx.shape
        pos = torch.arange(0, T, device=idx.device).unsqueeze(0)
        x = self.tok_emb(idx) + self.pos_emb(pos)
        for block in self.blocks:
            x = block(x)
        x = self.ln_f(x)
        # Pool: use the hidden state at the last position.
        # This assumes inputs are not right-padded; with right-padding the
        # last position would be a pad token rather than the final real token.
        last_hidden = x[:, -1, :]  # (B, d_model)
        reward = self.reward_head(last_hidden)  # (B, 1)
        return reward


rm = RewardModel(VOCAB_SIZE).to(DEVICE)
print(f"Reward Model parameters: {count_parameters(rm):,}")

In [None]:
# ---------------------------------------------------------------------------
# 6b. Bradley-Terry Training
# ---------------------------------------------------------------------------
# Prepare data: split into train/test
all_bt_data = []
for prompt, y_w, y_l in DPO_PREFS:
    full_w = f"Q: {prompt} A: {y_w}"
    full_l = f"Q: {prompt} A: {y_l}"
    ids_w = tokeniser.encode(full_w, add_bos=True, add_eos=True)
    ids_l = tokeniser.encode(full_l, add_bos=True, add_eos=True)
    all_bt_data.append((ids_w, ids_l))

# 9 train, 3 test
train_bt = all_bt_data[:9]
test_bt = all_bt_data[9:]


def pad_pair(ids_w, ids_l, pad_id=0):
    # Each sequence is scored in its own forward pass (batch size 1), so no
    # padding is needed. Right-padding the shorter sequence would make the
    # reward head read a pad position instead of the final token.
    # `pad_id` is kept only so existing calls keep working.
    return (
        torch.tensor([ids_w], dtype=torch.long, device=DEVICE),
        torch.tensor([ids_l], dtype=torch.long, device=DEVICE),
    )


rm_optimiser = torch.optim.Adam(rm.parameters(), lr=1e-3)
BT_EPOCHS = 60
bt_losses = []

print("Training Bradley-Terry Reward Model...\n")
for epoch in range(BT_EPOCHS):
    epoch_loss = 0.0
    perm = np.random.permutation(len(train_bt))
    for i in perm:
        ids_w, ids_l = train_bt[i]
        tw, tl = pad_pair(ids_w, ids_l)

        r_w = rm(tw)  # (1, 1)
        r_l = rm(tl)  # (1, 1)

        # BT loss: -log sigma(r_w - r_l)
        loss = -F.logsigmoid(r_w - r_l).mean()

        rm_optimiser.zero_grad()
        loss.backward()
        rm_optimiser.step()
        epoch_loss += loss.item()

    avg_loss = epoch_loss / len(train_bt)
    bt_losses.append(avg_loss)

    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1:3d}/{BT_EPOCHS}  loss={avg_loss:.4f}")

print("\nBradley-Terry training complete.")

In [None]:
# ---------------------------------------------------------------------------
# 6c. Evaluate Reward Model
# ---------------------------------------------------------------------------
# Training loss plot
fig, ax = plt.subplots(1, 1, figsize=(7, 3.5))
ax.plot(bt_losses, linewidth=1.5)
ax.set_xlabel("Epoch")
ax.set_ylabel("Bradley-Terry Loss")
ax.set_title("Reward Model: Training Loss")
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

# Accuracy on train and test sets
rm.eval()


def eval_accuracy(data, label):
    correct = 0
    total = len(data)
    for ids_w, ids_l in data:
        tw, tl = pad_pair(ids_w, ids_l)
        with torch.no_grad():
            r_w = rm(tw).item()
            r_l = rm(tl).item()
        if r_w > r_l:
            correct += 1
    acc = correct / total
    print(f"{label} accuracy: {correct}/{total} = {acc:.1%}")
    return acc


train_acc = eval_accuracy(train_bt, "Train")
test_acc = eval_accuracy(test_bt, "Test")

# Show reward scores for all pairs
print(f"\n{'Prompt':<30} {'R(preferred)':>13} {'R(dispreferred)':>16} {'Correct?':>9}")
print("-" * 72)
for i, (prompt, y_w, y_l) in enumerate(DPO_PREFS):
    ids_w, ids_l = all_bt_data[i]
    tw, tl = pad_pair(ids_w, ids_l)
    with torch.no_grad():
        r_w = rm(tw).item()
        r_l = rm(tl).item()
    correct = "yes" if r_w > r_l else "NO"
    split = "train" if i < 9 else "test"
    print(f"[{split}] {prompt:<25} {r_w:>12.4f} {r_l:>16.4f} {correct:>9}")

---
## 7. RL for Mathematics with Lean Verification

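In Sections 2–4 we used a **heuristic reward** (does the response contain "profit"?). In practice, RLHF reward models are learned neural networks whose scores are noisy and can be exploited by the policy (reward hacking).

**Mathematics is different.** A formal theorem prover such as [Lean 4](https://lean-lang.org/) checks proofs mechanically, so an accepted proof is correct up to the soundness of the prover. This gives us a **near-perfect reward signal that is very hard to hack**, provided escape hatches such as `sorry` (which Lean accepts with only a warning) are rejected:

$$R(x, y) = \begin{cases} +1 & \text{if Lean verifies the proof } y \text{ for theorem } x \\ \phantom{+}0 & \text{otherwise} \end{cases}$$

The implementation in 7c departs slightly from this formula: it returns $-0.2$ rather than $0$ for a failed proof and reserves $0$ for empty output.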
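
The sketch below shows how a binary verification reward interacts with GRPO's group normalisation. It does not talk to Lean: `stub_verifier` is a hypothetical stand-in that accepts one hard-coded tactic, and the helper names and candidate strings are made up for the example.

```python
import statistics


def binary_reward(theorem: str, proof: str, verifier) -> float:
    """R(x, y) = 1 if the verifier accepts the proof, else 0."""
    return 1.0 if verifier(theorem, proof) else 0.0


def stub_verifier(theorem: str, proof: str) -> bool:
    # Hypothetical stand-in for a real Lean check: accepts one hard-coded tactic
    return proof.strip() == "norm_num"


def group_advantages(rewards, eps=1e-8):
    """Group-normalised advantages, as in Section 4."""
    mean_r = statistics.fmean(rewards)
    std_r = statistics.stdev(rewards)
    return [(r - mean_r) / (std_r + eps) for r in rewards]


theorem = "theorem t1 : 2 + 3 = 5"
mixed = [binary_reward(theorem, p, stub_verifier)
         for p in ["norm_num", "nrm", "smp", "norm_num"]]
failed = [binary_reward(theorem, p, stub_verifier)
          for p in ["nrm", "smp", "omg", "rng"]]

print(mixed, group_advantages(mixed))    # successes get positive advantage
print(failed, group_advantages(failed))  # all-fail group: every advantage is zero
```

A group in which every candidate fails yields zero advantages, so the step carries no learning signal. With a sparse verification reward, the policy therefore needs to succeed at least occasionally before GRPO can improve it.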

This section demonstrates the idea on a toy scale:
1. Define a small set of simple Lean theorems.
2. Build a reward function that sends generated proofs to a **public Lean server** for verification.
3. Train a tiny model with **GRPO** (Section 4) using verification as the reward.

The same idea, RL against a formal verifier, underlies **DeepSeek-Prover** and **AlphaProof** at scale.

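> **Note:** This section requires internet access to reach the Lean server. If the server is unreachable, only previously cached results (including the pre-seeded known-correct proofs) are available, and every other proof is scored as a failure.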

In [None]:
# ---------------------------------------------------------------------------
# 7a. Lean Verifier Client
# ---------------------------------------------------------------------------
import requests, json as _json, hashlib

# --- Configuration ---
# Point this to any Lean 4 server that accepts POST requests with a "cmd"
# field and returns JSON with diagnostics.  The default is the public
# lean4web instance; change it if you run your own (e.g. via lean-lsp-mcp).
LEAN_SERVER_URL = "https://live.lean-lang.org/api/check"

# Cache verified results so we don't spam the server during RL training
_lean_cache: dict[str, bool] = {}


def verify_lean_code(code: str, timeout: float = 30.0) -> dict:
    """Send Lean 4 code to a public server and return verification result.

    Returns
    -------
    dict with keys:
        verified : bool   – True iff no errors
        messages : list   – diagnostic messages from Lean
        cached   : bool   – True if result came from cache
    """
    cache_key = hashlib.sha256(code.encode()).hexdigest()
    if cache_key in _lean_cache:
        return {"verified": _lean_cache[cache_key], "messages": [], "cached": True}

    try:
        resp = requests.post(
            LEAN_SERVER_URL,
            json={"cmd": code, "env": 0},
            timeout=timeout,
        )
        resp.raise_for_status()
        data = resp.json()

        # lean4web returns {"env": int, "messages": [...]}
        messages = data.get("messages", [])
        has_error = any(
            m.get("severity", "") == "error" for m in messages
        )
        verified = not has_error
    except Exception as e:
        # Server unreachable or bad response: report as unverified and do not
        # cache the result, so a later call can retry
        return {"verified": False, "messages": [str(e)], "cached": False}

    _lean_cache[cache_key] = verified
    return {"verified": verified, "messages": messages, "cached": False}


# Quick smoke test
test_result = verify_lean_code("theorem test : 2 + 2 = 4 := by norm_num")
print(f"Smoke test — 'theorem test : 2 + 2 = 4 := by norm_num'")
print(f"  verified: {test_result['verified']}")
print(f"  cached:   {test_result['cached']}")

bad_result = verify_lean_code("theorem bad : 2 + 2 = 5 := by norm_num")
print(f"\nSmoke test — 'theorem bad : 2 + 2 = 5 := by norm_num'")
print(f"  verified: {bad_result['verified']}")

if test_result["verified"] and not bad_result["verified"]:
    print("\n✓ Lean server is reachable and working correctly.")
    LEAN_AVAILABLE = True
else:
    print("\n⚠ Lean server may be unavailable. Using cached fallback.")
    LEAN_AVAILABLE = False

In [None]:
# ---------------------------------------------------------------------------
# 7b. Training Data: Simple Lean Theorems
# ---------------------------------------------------------------------------
# Each entry is a dict with keys "statement", "tactic" (a correct proof) and "desc".
# The model will learn to generate the tactic given the statement.

LEAN_THEOREMS = [
    {
        "statement": "theorem t1 : 2 + 3 = 5",
        "tactic": "norm_num",
        "desc": "simple arithmetic",
    },
    {
        "statement": "theorem t2 : 10 - 3 = 7",
        "tactic": "norm_num",
        "desc": "subtraction",
    },
    {
        "statement": "theorem t3 : 3 * 4 = 12",
        "tactic": "norm_num",
        "desc": "multiplication",
    },
    {
        "statement": "theorem t4 : (1 + 1) * 3 = 6",
        "tactic": "norm_num",
        "desc": "nested arithmetic",
    },
    {
        "statement": "theorem t5 (a : Nat) : a + 0 = a",
        "tactic": "simp",
        "desc": "additive identity",
    },
    {
        "statement": "theorem t6 (a : Nat) : 0 + a = a",
        "tactic": "simp",
        "desc": "additive identity (comm)",
    },
    {
        "statement": "theorem t7 (a b : Nat) : a + b = b + a",
        "tactic": "omega",
        "desc": "commutativity of addition",
    },
    {
        "statement": "theorem t8 (a b c : Nat) : a + b + c = a + (b + c)",
        "tactic": "omega",
        "desc": "associativity of addition",
    },
    {
        "statement": "theorem t9 (n : Nat) : n * 1 = n",
        "tactic": "simp",
        "desc": "multiplicative identity",
    },
    {
        "statement": "theorem t10 (n : Nat) : n * 0 = 0",
        "tactic": "simp",
        "desc": "multiplication by zero",
    },
]

# Build a cache of known correct results (fallback if server is down)
KNOWN_CORRECT: dict[str, str] = {}
for thm in LEAN_THEOREMS:
    full_code = f"{thm['statement']} := by {thm['tactic']}"
    _lean_cache[hashlib.sha256(full_code.encode()).hexdigest()] = True
    KNOWN_CORRECT[thm["statement"]] = thm["tactic"]

print(f"Defined {len(LEAN_THEOREMS)} training theorems.")
print("\nExamples:")
for t in LEAN_THEOREMS[:3]:
    print(f"  {t['statement']} := by {t['tactic']}  ({t['desc']})")

In [None]:
# ---------------------------------------------------------------------------
# 7c. Lean-based Reward Function
# ---------------------------------------------------------------------------

# Tactics are generated as free-form characters with the shared CharTokeniser.
# This list only records short tactics that could plausibly close our goals.
LEAN_TACTICS = ["norm_num", "simp", "omega", "ring", "decide", "rfl"]


def compute_lean_reward(theorem_statement: str, generated_tactic: str) -> float:
    """Verify a generated tactic proof via Lean and return a scalar reward.

    Reward scheme:
        +1.0  if Lean verifies the proof successfully
        -0.2  if verification fails (encourages exploration without harsh penalty)
         0.0  if the generated text is empty
    """
    tactic = generated_tactic.strip()
    if not tactic:
        return 0.0

    # `sorry` and `admit` close any goal with only a warning, not an error, so
    # the error-based check in verify_lean_code would accept them. Reject them
    # here to keep the reward from being trivially hackable.
    if "sorry" in tactic or "admit" in tactic:
        return -0.2

    # Construct the full Lean file
    lean_code = f"{theorem_statement} := by {tactic}"

    result = verify_lean_code(lean_code)
    return 1.0 if result["verified"] else -0.2


# Test the reward function
for thm in LEAN_THEOREMS[:3]:
    r_correct = compute_lean_reward(thm["statement"], thm["tactic"])
    r_wrong = compute_lean_reward(thm["statement"], "sorry")
    print(f"{thm['desc']:30s}  correct={r_correct:+.1f}  wrong={r_wrong:+.1f}")
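
As a quick sanity check on this scheme, suppose a non-empty tactic verifies with probability $p$ (an illustrative assumption, not a measured quantity). The expected reward per attempt is then

$$
\mathbb{E}[r] = p \cdot (+1.0) + (1 - p) \cdot (-0.2) = 1.2\,p - 0.2,
$$

which crosses zero at $p = 1/6 \approx 0.17$. A mean reward above zero therefore corresponds to more than one verified proof in every six attempts.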

In [None]:
# ---------------------------------------------------------------------------
# 7d. GRPO Training with Lean Verification Reward
# ---------------------------------------------------------------------------
# We use a simplified GRPO loop that generates candidate tactics (as short
# character strings) for each theorem and updates based on Lean verification.
#
# Because our TinyGPT2 is character-level and tactics are short strings,
# we encode the prompt as: "<theorem statement> := by "
# and let the model generate the tactic completion.

# Prepare prompt tokens for each theorem
lean_prompts = []
for thm in LEAN_THEOREMS:
    prompt_text = f"{thm['statement']} := by "
    ids = tokeniser.encode(prompt_text)
    lean_prompts.append(torch.tensor([ids], dtype=torch.long, device=DEVICE))

# Fresh model from SFT checkpoint
lean_model = TinyGPT2(VOCAB_SIZE).to(DEVICE)
lean_model.load_state_dict(sft_model.state_dict())
lean_ref = copy.deepcopy(lean_model)
lean_ref.eval()
for p in lean_ref.parameters():
    p.requires_grad = False

lean_optimiser = torch.optim.Adam(lean_model.parameters(), lr=5e-5)

LEAN_STEPS = 40
LEAN_GROUP_SIZE = 8
LEAN_MAX_TACTIC_LEN = 12  # tactics are short: "norm_num", "simp", "omega"

lean_reward_history = []
lean_success_rate_history = []

print("Starting GRPO training with Lean verification rewards...\n")
for step in range(LEAN_STEPS):
    step_rewards = []
    step_verified = 0
    step_total = 0

    # Cycle through theorems
    thm_idx = step % len(LEAN_THEOREMS)
    thm = LEAN_THEOREMS[thm_idx]
    prompt_ids = lean_prompts[thm_idx]

    # Generate a group of tactic candidates
    group_data = grpo_generate_group(
        lean_model, prompt_ids, group_size=LEAN_GROUP_SIZE,
        max_new_tokens=LEAN_MAX_TACTIC_LEN, temperature=1.0,
    )

    # Replace heuristic rewards with Lean verification rewards
    lean_rewards = []
    for i in range(LEAN_GROUP_SIZE):
        full_text = tokeniser.decode(group_data["sequences"][i].tolist())
        # Extract tactic: everything after " := by "
        marker = " := by "
        if marker in full_text:
            tactic = full_text.split(marker, 1)[1].strip()
        else:
            tactic = full_text.strip()
        reward = compute_lean_reward(thm["statement"], tactic)
        lean_rewards.append(reward)
        if reward > 0:
            step_verified += 1
        step_total += 1

    rewards_t = torch.tensor(lean_rewards, dtype=torch.float32, device=DEVICE)

    # Recompute group-normalised advantages with Lean rewards
    mean_r = rewards_t.mean()
    std_r = rewards_t.std() + 1e-8
    group_data["advantages"] = (rewards_t - mean_r) / std_r
    group_data["rewards"] = rewards_t

    # GRPO update
    loss = grpo_update(
        lean_model, lean_ref, lean_optimiser, group_data,
        clip_eps=0.2, beta_kl=0.04, n_epochs=2,
    )

    avg_reward = mean_r.item()
    success_rate = step_verified / step_total
    lean_reward_history.append(avg_reward)
    lean_success_rate_history.append(success_rate)

    if (step + 1) % 10 == 0:
        print(
            f"Step {step+1:3d}/{LEAN_STEPS}  "
            f"thm={thm['desc'][:20]:20s}  "
            f"reward={avg_reward:.3f}  "
            f"verified={step_verified}/{step_total}  "
            f"loss={loss:.4f}"
        )

print("\nGRPO+Lean training complete.")
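
Two properties of the group-normalised advantages used in this loop are easy to check by hand. First, if every candidate in a group receives the same reward (for instance, all eight fail), all advantages are zero and that step gives the policy no direction to move in. Second, normalisation removes any shift or positive rescaling of the rewards, so with only two reward levels in a group the size of the failure penalty does not change the advantages. The sketch below is a pure-Python illustration with a hypothetical `group_advantages` helper and made-up reward groups; it is not the notebook's `grpo_update`.

```python
import statistics

def group_advantages(rewards, eps=1e-8):
    """(r - mean) / (sample std + eps), computed within one group."""
    mean = statistics.fmean(rewards)
    std = statistics.stdev(rewards)  # n - 1 denominator, like torch's default
    return [(r - mean) / (std + eps) for r in rewards]

# Case 1: every attempt fails, so all rewards are equal.
all_failed = group_advantages([-0.2] * 8)
print(all_failed)  # all zeros: no learning signal from this group

# Case 2: one verified proof among eight attempts.
one_success = group_advantages([1.0] + [-0.2] * 7)
print([round(a, 3) for a in one_success])

# Case 3: same outcomes with a harsher failure penalty of -1.0.
harsher = group_advantages([1.0] + [-1.0] * 7)
max_gap = max(abs(a - b) for a, b in zip(one_success, harsher))
print(f"max difference between penalties: {max_gap:.1e}")
```

In the second case the single verified proof gets a large positive advantage and the seven failures share a small negative one; the third case differs only through the tiny `eps` term.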

In [None]:
# ---------------------------------------------------------------------------
# 7e. Visualisation: Reward & Proof Success Rate
# ---------------------------------------------------------------------------
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

# Reward curve
ax1.plot(lean_reward_history, linewidth=1.5, color="tab:blue")
ax1.axhline(y=0, color="gray", linestyle="--", alpha=0.5)
ax1.set_xlabel("Step")
ax1.set_ylabel("Mean Reward")
ax1.set_title("GRPO + Lean: Mean Reward over Training")
ax1.grid(True, alpha=0.3)

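# Success rate (smoothed)
window = max(1, len(lean_success_rate_history) // 10)
if len(lean_success_rate_history) >= window:
    smoothed = np.convolve(
        lean_success_rate_history, np.ones(window) / window, mode="valid"
    )
    # Plot each average at the last step of its window so it lines up with the raw curve
    smoothed_x = np.arange(window - 1, len(lean_success_rate_history))
    ax2.plot(smoothed_x, smoothed, linewidth=1.5, color="tab:green", label=f"smoothed (w={window})")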
ax2.plot(lean_success_rate_history, alpha=0.3, color="tab:green", label="raw")
ax2.set_xlabel("Step")
ax2.set_ylabel("Proof Success Rate")
ax2.set_title("Fraction of Generated Proofs Verified by Lean")
ax2.set_ylim(-0.05, 1.05)
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Show some example generations
print("\n--- Example generated tactics (first 5 theorems) ---")
lean_model.eval()
for thm in LEAN_THEOREMS[:5]:
    prompt_text = f"{thm['statement']} := by "
    ids = tokeniser.encode(prompt_text)
    input_ids = torch.tensor([ids], dtype=torch.long, device=DEVICE)
    with torch.no_grad():
        gen = lean_model.generate(input_ids, max_new_tokens=LEAN_MAX_TACTIC_LEN, temperature=0.5)
    full_text = tokeniser.decode(gen[0].tolist())
    marker = " := by "
    tactic = full_text.split(marker, 1)[1].strip() if marker in full_text else "???"
    result = verify_lean_code(f"{thm['statement']} := by {tactic}")
    status = "✓" if result["verified"] else "✗"
    print(f"  {thm['statement']}  =>  {tactic:15s}  [{status}]  (expected: {thm['tactic']})")
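
A note on reading the smoothed curve: a `valid`-mode moving average over `n` points with window `w` returns `n - w + 1` values, and each value summarises `w` consecutive steps. The pure-Python sketch below reproduces that arithmetic on made-up success rates; `moving_average` is a hypothetical helper, not part of the notebook.

```python
def moving_average(values, window):
    """Mean over each full window (same output as a 'valid' box filter)."""
    return [
        sum(values[i:i + window]) / window
        for i in range(len(values) - window + 1)
    ]

rates = [0.0, 0.0, 0.25, 0.5, 0.5, 0.75, 1.0, 1.0]  # made-up per-step success rates
window = 4
smoothed = moving_average(rates, window)

# The k-th smoothed value covers steps k .. k + window - 1,
# so it is naturally plotted at x = k + window - 1.
x_positions = [k + window - 1 for k in range(len(smoothed))]
print(len(smoothed), x_positions)
```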

### Discussion

**What we demonstrated:** Even a tiny (~1M parameter) model can, in principle, learn to associate theorem statements with correct Lean tactics when given binary verification feedback through GRPO. The reward signal is **unambiguous**: Lean either accepts or rejects a proof.

**Connection to real systems:**
- **DeepSeek-Prover-V2** (2025) applies the same basic recipe at scale: sample groups of proof candidates, verify them with Lean 4, and update the policy with GRPO.
- **AlphaProof** (DeepMind, 2024) pairs Lean verification with AlphaZero-style reinforcement learning and search over proof steps; together with AlphaGeometry 2 it reached silver-medal standard at IMO 2024.

**Why this matters for finance:** The same principle — *RL with a verifiable reward* — applies whenever we have an automated checker:
- **Code generation:** run unit tests as reward
- **Quantitative finance:** backtest trading strategies as reward
- **Regulatory compliance:** rule-based checkers as reward
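
As a minimal sketch of the first bullet, a unit-test reward can mirror the Lean reward used in this section: run the candidate against fixed checks and map pass or fail to a scalar. Everything here (`unit_test_reward`, the `add` task, the test cases) is hypothetical, and calling `exec` on model output is only acceptable inside a sandbox.

```python
def unit_test_reward(candidate_src: str, func_name: str, cases) -> float:
    """Return +1.0 if the candidate passes every case, else -0.2.

    Mirrors the reward levels of the Lean example; an empty candidate gets 0.0.
    """
    if not candidate_src.strip():
        return 0.0
    namespace = {}
    try:
        exec(candidate_src, namespace)  # NOTE: sandbox this in real use
        func = namespace[func_name]
        passed = all(func(*args) == expected for args, expected in cases)
    except Exception:
        passed = False  # syntax errors, missing function, runtime errors
    return 1.0 if passed else -0.2

cases = [((1, 2), 3), ((0, 0), 0), ((-1, 1), 0)]
good = "def add(a, b):\n    return a + b\n"
bad = "def add(a, b):\n    return a - b\n"

print(unit_test_reward(good, "add", cases))  # 1.0
print(unit_test_reward(bad, "add", cases))   # -0.2
```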

**Limitations of our toy example:**
- The model is far too small to learn generalised proof strategies
- We only use single-tactic proofs; real proofs require multi-step reasoning
- Training steps are too few for meaningful convergence
- The character-level tokeniser is not ideal for Lean syntax
- If the public Lean server is unreachable, verification falls back to cached results. Only exact matches to the known-correct ground-truth proofs then receive $+1$; every other attempt, including tactics Lean would have accepted, receives $-0.2$, so the reward curve typically stays flat. When the server is live, the model can be rewarded for any tactic that Lean accepts, and the curve can rise.
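
The last limitation follows from how the fallback cache is keyed. Assuming the cache is looked up by the SHA-256 digest of the full proof text, as in the cell that populates it, only a character-for-character match hits; a different but valid tactic misses. The sketch below uses a hypothetical `offline_verify` helper and a one-entry cache to show this.

```python
import hashlib

def proof_key(statement: str, tactic: str) -> str:
    """Digest of the assembled proof text, used here as the cache key."""
    full_code = f"{statement} := by {tactic}"
    return hashlib.sha256(full_code.encode()).hexdigest()

statement = "theorem t6 (a : Nat) : 0 + a = a"
cache = {proof_key(statement, "simp"): True}  # ground-truth proof only

def offline_verify(statement: str, tactic: str) -> bool:
    """Hypothetical offline check: a cache miss counts as unverified."""
    return cache.get(proof_key(statement, tactic), False)

print(offline_verify(statement, "simp"))   # True: exact match with the cached proof
print(offline_verify(statement, "omega"))  # False, although omega also proves this goal
```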

---
## Summary

This notebook covered the key post-training methods for language models:

| Method | Requires Reward Model? | Requires Critic? | Key Mechanism |
|--------|----------------------|------------------|---------------|
| **SFT** | No | No | Teacher-forced cross-entropy on demonstrations |
| **REINFORCE** | Yes (or heuristic) | No | $\nabla J = \mathbb{E}[\nabla\log\pi \cdot R]$, high variance |
| **PPO** | Yes (or heuristic) | Yes (value fn) | Clipped ratio + GAE advantages |
| **GRPO** | Yes (or heuristic) | **No** | Group-normalised advantages, no critic |
| **DPO** | **No** | No | Direct optimisation from preferences |
| **Bradley-Terry RM** | N/A (this *is* the RM) | No | Pairwise preference loss |
| **GRPO + Lean** | **No** (verifier) | **No** | Formal verification as perfect reward |

**Key takeaways:**
- SFT provides the foundation; RL methods refine behaviour towards a reward signal.
- REINFORCE is simple but suffers from high variance; baselines help.
- PPO stabilises training via clipping and GAE, but requires a critic.
- GRPO dispenses with the critic by normalising rewards within a group of samples drawn for the same prompt.
- DPO bypasses reward modelling entirely, optimising directly from preferences.
- The Bradley-Terry model provides a principled way to learn rewards from preferences.
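- **Formal verification** (e.g. Lean 4) provides an unambiguous, automatically checkable reward signal for mathematical reasoning, provided placeholder proofs such as `sorry` are rejected. The same GRPO algorithm applies, but the feedback comes from a proof checker rather than a learned reward model.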
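
To make the link between the Bradley-Terry and DPO takeaways concrete, both can be written as the same pairwise logistic loss; they differ only in what plays the role of the reward. The sketch below uses made-up scalar values, and the names `pairwise_loss`, `beta`, and the log-probabilities are illustrative assumptions rather than values from the notebook.

```python
import math

def pairwise_loss(score_chosen: float, score_rejected: float) -> float:
    """-log sigmoid(score_chosen - score_rejected), computed stably."""
    margin = score_chosen - score_rejected
    # -log(sigmoid(m)) = log(1 + exp(-m)); pick the stable form for each sign of m.
    if margin >= 0:
        return math.log1p(math.exp(-margin))
    return -margin + math.log1p(math.exp(margin))

# Bradley-Terry reward model: scores are learned scalar rewards.
bt_loss = pairwise_loss(score_chosen=1.5, score_rejected=0.5)

# DPO: scores are beta * (log pi(y|x) - log pi_ref(y|x)), the implicit reward.
beta = 0.1
logp_chosen, ref_logp_chosen = -12.0, -14.0      # made-up sequence log-probs
logp_rejected, ref_logp_rejected = -15.0, -13.0
dpo_loss = pairwise_loss(
    beta * (logp_chosen - ref_logp_chosen),
    beta * (logp_rejected - ref_logp_rejected),
)

print(f"Bradley-Terry loss: {bt_loss:.4f}")
print(f"DPO loss:           {dpo_loss:.4f}")
```

Both losses shrink as the chosen response's score pulls ahead of the rejected one, and both equal $\log 2$ when the two scores are tied.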