# Notebook 03: DPO — Direct Preference Optimization

## Learning Objectives
- Derive the DPO loss step by step
- Implement the DPO loss from scratch in PyTorch
- Train a model on preference pairs
- Compare DPO with RLHF

## Mathematical Derivation

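RLHF optimises a KL-regularised reward objective, where $\beta$ controls how far the policy $\pi$ may drift from the reference policy $\pi_{\text{ref}}$:
$$\max_\pi \mathbb{E}_{x \sim \mathcal{D},\, y \sim \pi(\cdot|x)}[R(x,y)] - \beta \cdot \text{KL}(\pi \| \pi_{\text{ref}})$$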

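The optimal policy has the closed form below, where $Z(x) = \sum_y \pi_{\text{ref}}(y|x) \exp(R(x,y)/\beta)$ is the partition function that normalises the distribution: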
$$\pi^*(y|x) = \frac{\pi_{\text{ref}}(y|x) \exp(R(x,y)/\beta)}{Z(x)}$$

Solving for R:
$$R^*(x,y) = \beta \log \frac{\pi^*(y|x)}{\pi_{\text{ref}}(y|x)} + \beta \log Z(x)$$

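Substituting into the Bradley-Terry preference model $p(y_w \succ y_l | x) = \sigma(R(x,y_w) - R(x,y_l))$, the $\beta \log Z(x)$ terms cancel because they depend only on $x$: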
$$p(y_w \succ y_l | x) = \sigma\!\left(\beta \log \frac{\pi^*(y_w|x)}{\pi_{\text{ref}}(y_w|x)} - \beta \log \frac{\pi^*(y_l|x)}{\pi_{\text{ref}}(y_l|x)}\right)$$
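The cancellation of $Z(x)$ can be checked numerically. The following is a minimal sketch on a toy problem with one prompt and three candidate responses; the values of `pi_ref`, `reward`, and `beta` are made up for illustration and do not come from the notebook's data.

```python
import math

beta = 0.5
# Hypothetical toy problem: one prompt x with three candidate responses.
pi_ref = [0.5, 0.3, 0.2]    # reference policy pi_ref(y|x)
reward = [1.0, 0.2, -0.5]   # reward R(x, y)

# Optimal policy: pi*(y|x) = pi_ref(y|x) * exp(R(x,y)/beta) / Z(x)
unnorm = [p * math.exp(r / beta) for p, r in zip(pi_ref, reward)]
Z = sum(unnorm)
pi_star = [u / Z for u in unnorm]

# Implicit reward without the Z term: beta * log(pi*/pi_ref)
implicit = [beta * math.log(ps / pr) for ps, pr in zip(pi_star, pi_ref)]

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

# Bradley-Terry preference of response 0 over response 1, computed two ways
p_true = sigmoid(reward[0] - reward[1])
p_implicit = sigmoid(implicit[0] - implicit[1])
print(f"from rewards:      {p_true:.6f}")
print(f"from policy ratio: {p_implicit:.6f}")
```

Each implicit reward differs from the true reward by the same constant, $-\beta \log Z(x)$, so the two preference probabilities agree up to floating-point error.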

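**DPO Loss** (negative log-likelihood of the observed preferences, with the trainable policy $\pi_\theta$ in place of $\pi^*$, minimised over preference triples $(x, y_w, y_l)$ from the dataset $\mathcal{D}$):
$$\mathcal{L}_{\text{DPO}} = -\mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}}\left[\log \sigma\!\left(\beta \log \frac{\pi_\theta(y_w|x)}{\pi_{\text{ref}}(y_w|x)} - \beta \log \frac{\pi_\theta(y_l|x)}{\pi_{\text{ref}}(y_l|x)}\right)\right]$$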
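The loss translates almost line for line into PyTorch. The sketch below is one possible from-scratch version. `dpo_loss_sketch` is a hypothetical name, not the repository's `dpo_loss`, and the definitions of accuracy and margin used here are assumptions. Inputs are per-sequence log-probabilities, i.e. token log-probs already summed over each response.

```python
import torch
import torch.nn.functional as F

def dpo_loss_sketch(pol_chosen, pol_rejected, ref_chosen, ref_rejected, beta=0.1):
    """Hypothetical from-scratch DPO loss over per-sequence log-probabilities."""
    # Implicit rewards: beta * log(pi_theta / pi_ref) for each response
    chosen_reward = beta * (pol_chosen - ref_chosen)
    rejected_reward = beta * (pol_rejected - ref_rejected)
    logits = chosen_reward - rejected_reward
    # -log sigmoid(logits), averaged over the batch; logsigmoid is numerically stable
    loss = -F.logsigmoid(logits).mean()
    # Assumed metric definitions: fraction of positive logits, and their mean
    accuracy = (logits > 0).float().mean()
    margin = logits.mean()
    return loss, accuracy, margin

# When the policy equals the reference, every logit is 0 and the loss is log 2.
lp_c = torch.tensor([-2.0, -1.5])
lp_r = torch.tensor([-3.0, -2.5])
loss0, acc0, margin0 = dpo_loss_sketch(lp_c, lp_r, lp_c, lp_r)
print(f"loss at initialisation: {loss0.item():.4f}")
```

Working with differences of log-probabilities avoids ever forming the probability ratios themselves, which would underflow for long sequences.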

In [None]:
# !pip install torch transformers tqdm matplotlib

In [None]:
import sys
# Add the parent directory to the import path so the `src` package can be imported
sys.path.insert(0, '..')
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
print('Ready!')

## Step 1: Implement DPO Loss from Scratch

In [None]:
from src.training.dpo_trainer import compute_log_probs, dpo_loss
print('DPO loss imported successfully')
print()
print('dpo_loss signature: (pol_chosen, pol_rejected, ref_chosen, ref_rejected, beta)')
print()
# Demo with synthetic log-probs
pol_chosen   = torch.tensor([-2.0, -1.5, -1.8])
pol_rejected = torch.tensor([-3.5, -4.0, -3.2])
ref_chosen   = torch.tensor([-2.5, -2.0, -2.2])
ref_rejected = torch.tensor([-3.0, -3.5, -2.8])
loss, acc, margin = dpo_loss(pol_chosen, pol_rejected, ref_chosen, ref_rejected, beta=0.1)
print(f'DPO Loss:       {loss.item():.4f}')
print(f'Reward Acc:     {acc.item():.4f}  (fraction where chosen > rejected)')
print(f'Reward Margin:  {margin.item():.4f} (how much chosen > rejected on average)')

## Step 2: Load Preference Data

In [None]:
from src.data.preference_data import create_synthetic_preferences
pairs = create_synthetic_preferences()
print(f'Loaded {len(pairs)} preference pairs')
for i, p in enumerate(pairs[:2]):
    print(f'Pair {i+1}:')
    print(f'  Chosen:   {p["chosen"][:70]}')
    print(f'  Rejected: {p["rejected"][:70]}')
    print()

## Step 3: Visualize DPO Training Dynamics

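Key metrics to track during DPO training:
1. **Loss** (should decrease from about $\log 2 \approx 0.693$, its value when the policy equals the reference)
2. **Reward accuracy** (fraction of pairs where the chosen implicit reward exceeds the rejected one; should approach 1.0)
3. **Reward margin** (mean gap between chosen and rejected implicit rewards; should increase)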
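The loss and the margin are tied together: for a single pair the DPO loss is $-\log \sigma(m)$, where $m$ is that pair's implicit reward margin. The short sketch below tabulates this relationship for a few hand-picked margins; `loss_from_margin` is a hypothetical helper written for this illustration.

```python
import math

def loss_from_margin(margin):
    """Per-pair DPO loss -log(sigmoid(margin)), computed stably as softplus(-margin)."""
    return max(-margin, 0.0) + math.log1p(math.exp(-abs(margin)))

# Hand-picked margins, purely illustrative
margins = [0.0, 0.25, 0.5, 1.0, 2.0]
losses = [loss_from_margin(m) for m in margins]
for m, l in zip(margins, losses):
    print(f"margin {m:4.2f} -> loss {l:.4f}")
```

A margin of zero gives a loss of $\log 2$, and the loss falls monotonically as the margin grows, which is why a rising margin and a falling loss appear together during training.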

In [None]:
import numpy as np
# Simulated curves for illustration only: exponential trends plus Gaussian noise,
# not the output of an actual training run.
np.random.seed(42)
steps = list(range(1, 51))
loss_curve   = [0.7 - 0.35*(1-np.exp(-s/15)) + np.random.randn()*0.03 for s in steps]
acc_curve    = [0.5 + 0.45*(1-np.exp(-s/15)) + np.random.randn()*0.02 for s in steps]
margin_curve = [0.0 + 0.5*(1-np.exp(-s/20)) + np.random.randn()*0.02 for s in steps]

fig, axes = plt.subplots(1, 3, figsize=(14, 4))
for ax, y, title, color in zip(axes,
        [loss_curve, acc_curve, margin_curve],
        ['DPO Loss', 'Reward Accuracy', 'Reward Margin'],
        ['steelblue', 'green', 'darkorange']):
    ax.plot(steps, y, color=color, linewidth=2)
    ax.set_title(title); ax.set_xlabel('Step')
    ax.grid(alpha=0.3)
plt.suptitle('DPO Training Dynamics', fontsize=14)
plt.tight_layout()
plt.show()

## DPO Variants

| Variant | Change | When to use |
|---------|--------|-------------|
| **IPO** | Identity mapping instead of log-sigmoid, giving a squared-error loss on the log-ratio margin | Avoid overfitting to extreme pairs |
| **KTO** | No paired data needed | Unpaired binary feedback |
| **SimPO** | Reference-free, length-normalized | Simpler, no ref model |
| **ORPO** | Adds an odds-ratio preference term to the SFT loss, no ref model | More efficient, one-stage |
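To make two of the rows concrete, here is a minimal per-pair sketch of the DPO, IPO, and SimPO objectives as they are commonly written. The function names, the response lengths, and the hyperparameter values are assumptions chosen for illustration; KTO and ORPO are left out because they need more than a single pair of sequence log-probs.

```python
import math

def log_sigmoid(z):
    """Numerically stable log(sigmoid(z))."""
    return -(max(-z, 0.0) + math.log1p(math.exp(-abs(z))))

def dpo_pair_loss(pol_w, pol_l, ref_w, ref_l, beta):
    # h is the difference of policy/reference log-ratios
    h = (pol_w - ref_w) - (pol_l - ref_l)
    return -log_sigmoid(beta * h)

def ipo_pair_loss(pol_w, pol_l, ref_w, ref_l, tau):
    # IPO regresses h towards the finite target 1/(2*tau) with a squared error
    h = (pol_w - ref_w) - (pol_l - ref_l)
    return (h - 1.0 / (2.0 * tau)) ** 2

def simpo_pair_loss(pol_w, pol_l, len_w, len_l, beta, gamma):
    # SimPO: no reference model, length-normalised log-probs, target margin gamma
    z = beta * (pol_w / len_w - pol_l / len_l) - gamma
    return -log_sigmoid(z)

# First synthetic pair from the Step 1 demo; lengths 4 and 5 are hypothetical
dpo = dpo_pair_loss(-2.0, -3.5, -2.5, -3.0, beta=0.1)
ipo = ipo_pair_loss(-2.0, -3.5, -2.5, -3.0, tau=0.1)
simpo = simpo_pair_loss(-2.0, -3.5, len_w=4, len_l=5, beta=2.0, gamma=0.5)
print(f"DPO: {dpo:.4f}  IPO: {ipo:.4f}  SimPO: {simpo:.4f}")
```

The IPO loss is minimised at a finite margin, whereas the DPO loss keeps rewarding ever larger margins. This is the behaviour behind the "avoid overfitting to extreme pairs" entry in the table.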

---

## Exercises

1. **Vary beta:** Try beta=0.01, 0.1, 1.0. How does it affect reward margin?
2. **IPO implementation:** Replace the log-sigmoid loss with IPO's squared loss on the log-ratio margin. Compare stability.
3. **More data:** Generate 20 preference pairs from GSM8K. Does accuracy improve?
4. **Reference-free DPO:** What happens if you set ref_log_probs = 0?
5. **Credit connection:** How could the DPO signal be used for agent credit assignment?
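As a possible starting point for Exercise 1, the sketch below sweeps beta over the synthetic log-probs from the Step 1 demo. It holds the log-probs fixed, so it only shows how beta rescales the loss and margin for a given policy. In actual training the policy also changes with beta, which is the part the exercise asks you to investigate. `sweep` is a hypothetical helper, and the margin definition (mean logit) is an assumption.

```python
import math

def log_sigmoid(z):
    """Numerically stable log(sigmoid(z))."""
    return -(max(-z, 0.0) + math.log1p(math.exp(-abs(z))))

# Synthetic per-sequence log-probs copied from the Step 1 demo
pol_chosen   = [-2.0, -1.5, -1.8]
pol_rejected = [-3.5, -4.0, -3.2]
ref_chosen   = [-2.5, -2.0, -2.2]
ref_rejected = [-3.0, -3.5, -2.8]

def sweep(betas):
    """Return {beta: (loss, margin)} for fixed log-probs."""
    results = {}
    for beta in betas:
        logits = [
            beta * ((pc - rc) - (pr - rr))
            for pc, pr, rc, rr in zip(pol_chosen, pol_rejected, ref_chosen, ref_rejected)
        ]
        loss = -sum(log_sigmoid(z) for z in logits) / len(logits)
        margin = sum(logits) / len(logits)
        results[beta] = (loss, margin)
    return results

results = sweep([0.01, 0.1, 1.0])
for beta, (loss, margin) in results.items():
    print(f"beta={beta:<5} loss={loss:.4f} margin={margin:.4f}")
```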