# The Alignment Techniques Landscape

In this notebook, you'll explore three axes of the preference optimization design space through small-scale, hands-on exercises.

**What you'll do:**
- Convert preference data between paired comparison format and single-response format, seeing what information is gained and lost
- Compute log-probability ratios between a policy model and a reference model, visualizing the KL-penalty mechanism that prevents forgetting
- Observe how training data becomes "stale" when a policy updates, demonstrating why online methods exist

**For each exercise, PREDICT the output before running the cell.** Wrong predictions are more valuable than correct ones: they reveal gaps in your mental model.

**Important:** These exercises demonstrate *design axes* from the lesson, not full training runs. We use small-scale proxies (short text, GPT-2) to make the concepts concrete and fast to run.

In [None]:
# Setup: self-contained for Google Colab
# transformers and torch are pre-installed in Colab

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM

# Reproducible results
torch.manual_seed(42)
np.random.seed(42)

# Nice plots
plt.style.use('dark_background')
plt.rcParams['figure.figsize'] = [10, 4]

# Load GPT-2 (small, ~500MB), our proxy for a "language model"
# In real alignment work, this would be a much larger model.
# GPT-2 is enough to demonstrate the mechanisms.
print("Loading GPT-2...")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained("gpt2")
model.eval()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
print(f"Model loaded on {device}.")
print("Setup complete.")

In [None]:
# Shared helper: compute per-token log probabilities for a response given a prompt

def get_log_probs(model, tokenizer, prompt: str, response: str) -> torch.Tensor:
    """Compute per-token log probabilities of `response` given `prompt`.

    Returns a 1-D tensor of log-probs, one per response token.
    This is the fundamental quantity in DPO, KTO, and all preference
    optimization methods: they all manipulate these log-probs.
    """
    prompt_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    full_ids = tokenizer.encode(prompt + response, return_tensors="pt").to(device)
    response_start = prompt_ids.shape[1]

    with torch.no_grad():
        outputs = model(full_ids)
        # logits shape: [1, seq_len, vocab_size]
        logits = outputs.logits

    # For each response position, get the log-prob of the actual next token
    # logits[0, t] predicts token t+1, so we shift by 1
    log_probs_all = F.log_softmax(logits[0], dim=-1)
    response_token_ids = full_ids[0, response_start:]
    # Gather the log-prob of each actual response token
    token_log_probs = log_probs_all[response_start - 1 : -1]  # shifted: position t predicts t+1
    token_log_probs = token_log_probs.gather(1, response_token_ids.unsqueeze(1)).squeeze(1)

    return token_log_probs

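The shift-by-one indexing in `get_log_probs` can be checked without a model. A minimal sketch, assuming a toy sequence of made-up token ids; the names `full_ids` and `response_start` mirror the helper, but the values are purely illustrative:

```python
# Toy illustration of the shift used in get_log_probs (no model needed).
# The token ids below are made up for illustration.
full_ids = [11, 22, 33, 44, 55]   # prompt = [11, 22], response = [33, 44, 55]
response_start = 2                # number of prompt tokens

# Row t of the logits predicts the token at position t + 1, so the rows that
# score the response are response_start - 1 through len(full_ids) - 2.
positions = list(range(len(full_ids)))
predicting_rows = positions[response_start - 1 : -1]
response_tokens = full_ids[response_start:]

for row, token in zip(predicting_rows, response_tokens):
    print(f"logits row {row} scores response token {token}")

# There is exactly one predicting row per response token.
assert len(predicting_rows) == len(response_tokens)
```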
---

## Exercise 1: Preference Data Format Conversion (Guided)

The **data format axis** is the most practical constraint in real alignment work. DPO requires paired comparisons ("Response A is better than Response B"). KTO works with single responses labeled good or bad (thumbs-up/thumbs-down). These aren't interchangeable: converting between them reveals what information is gained and lost.

In this exercise, you'll see the same human feedback expressed in both formats and observe what happens when you convert from one to the other.

**Before running, predict:**
- If you have 3 paired comparisons, how many single-response labels can you extract from them?
- If you have 6 single-response labels (thumbs-up/down), can you reconstruct the original pairs? What information is lost?
- Why would KTO exist if paired data is strictly more informative?

In [None]:
# --- Preference data in PAIRED COMPARISON format (for DPO) ---
# Each entry: a prompt, two responses, and which one a human preferred.
# This is the format from RLHF & Alignment: "A is better than B."

paired_data = [
    {
        "prompt": "Explain quantum computing in one sentence.",
        "response_a": "Quantum computing uses qubits that can be 0, 1, or both at once, letting it explore many solutions simultaneously.",
        "response_b": "Quantum computing is a type of computing that is very fast and uses quantum mechanics.",
        "preferred": "a",  # Human chose response A
    },
    {
        "prompt": "What causes rain?",
        "response_a": "God makes it rain when he's sad.",
        "response_b": "Water evaporates, rises, cools into droplets in clouds, and falls when droplets grow heavy enough.",
        "preferred": "b",  # Human chose response B
    },
    {
        "prompt": "Is it safe to eat raw cookie dough?",
        "response_a": "Yes, raw cookie dough is perfectly safe and delicious. Enjoy as much as you want!",
        "response_b": "Raw cookie dough can contain raw eggs and uncooked flour, both of which carry a small risk of bacterial contamination. It's safer to use heat-treated flour and pasteurized eggs if you want to eat it raw.",
        "preferred": "b",  # Human chose response B
    },
]

print("=" * 70)
print("PAIRED COMPARISON DATA (DPO format)")
print("=" * 70)
for i, pair in enumerate(paired_data):
    print(f"\nPair {i+1}: {pair['prompt']}")
    print(f"  Response A: {pair['response_a'][:80]}..." if len(pair['response_a']) > 80 else f"  Response A: {pair['response_a']}")
    print(f"  Response B: {pair['response_b'][:80]}..." if len(pair['response_b']) > 80 else f"  Response B: {pair['response_b']}")
    print(f"  Preferred:  {'A' if pair['preferred'] == 'a' else 'B'}")

print(f"\nTotal paired comparisons: {len(paired_data)}")
print(f"Total individual responses: {len(paired_data) * 2}")

In [None]:
# --- Convert paired data → single-response labels (KTO format) ---
# In KTO format, each response gets an independent label: "good" or "bad".
# No pairing, no "better than", just thumbs-up or thumbs-down.

def pairs_to_singles(paired_data):
    """Convert paired comparison data to single-response labels.
    
    The preferred response becomes "good" (thumbs-up).
    The dispreferred response becomes "bad" (thumbs-down).
    
    This is the conversion that KTO makes possible: if you only
    HAVE single-response labels, you can still train. But if you
    START with paired data and convert, you lose information.
    """
    singles = []
    for pair in paired_data:
        if pair["preferred"] == "a":
            good_response = pair["response_a"]
            bad_response = pair["response_b"]
        else:
            good_response = pair["response_b"]
            bad_response = pair["response_a"]
        
        singles.append({
            "prompt": pair["prompt"],
            "response": good_response,
            "label": "good",  # thumbs-up
        })
        singles.append({
            "prompt": pair["prompt"],
            "response": bad_response,
            "label": "bad",  # thumbs-down
        })
    return singles


single_data = pairs_to_singles(paired_data)

print("=" * 70)
print("SINGLE-RESPONSE DATA (KTO format)")
print("=" * 70)
for i, item in enumerate(single_data):
    label_icon = "👍" if item["label"] == "good" else "👎"
    print(f"\n{i+1}. [{label_icon} {item['label'].upper()}] {item['prompt']}")
    resp = item['response']
    print(f"   {resp[:90]}..." if len(resp) > 90 else f"   {resp}")

print(f"\nTotal single-response labels: {len(single_data)}")
print(f"  Good (thumbs-up):  {sum(1 for s in single_data if s['label'] == 'good')}")
print(f"  Bad (thumbs-down): {sum(1 for s in single_data if s['label'] == 'bad')}")

In [None]:
# --- What information was lost in the conversion? ---

print("=" * 70)
print("WHAT'S LOST IN CONVERSION: Pairs → Singles")
print("=" * 70)

print("""
From paired data, you know:
  ✓ Response A is BETTER THAN Response B (relative ranking)
  ✓ Both responses answer the SAME prompt (they're comparable)
  ✓ The quality gap is implicit (human chose one over the other)

From single-response data, you know:
  ✓ This response is good / this response is bad (absolute label)
  ✗ You do NOT know which responses were compared to each other
  ✗ You do NOT know the relative quality gap
  ✗ A "bad" response might still be decent, just worse than its pair
""")

# Concrete example of the information loss
print("CONCRETE EXAMPLE of information loss:")
print("-" * 50)
print(f'Pair 1, Response B (marked "bad"):')
print(f'  "{paired_data[0]["response_b"]}"')
print(f"\n  This response is vague but not WRONG. It's 'bad' only because")
print(f"  Response A was more specific. In KTO format, it gets the same")
print(f"  thumbs-down as Pair 2's Response A (which IS factually wrong).")
print(f"  The relative quality information is gone.")

print("\n" + "=" * 70)
print("NOW REVERSE: Can we go from singles BACK to pairs?")
print("=" * 70)

print("""
Given 6 single-response labels, can we reconstruct the 3 original pairs?

NO. We have 6 independent items. We don't know:
  - Which responses were originally compared to each other
  - Whether two responses even answer the same prompt
    (multiple prompts might look similar)
  - The DEGREE of preference (slight vs overwhelming)

We could guess (match by prompt, pair good with bad), but we'd be
manufacturing the relative signal that was never collected.
""")

print("=" * 70)
print("WHY KTO EXISTS")
print("=" * 70)
print("""
KTO exists because most real-world feedback IS already in single-response
format. App users click 👍 or 👎 on ONE response; they don't compare two.

Collecting paired comparisons requires showing users two responses and
asking "which is better?" This is:
  - 2x the generation cost (two responses per prompt)
  - More cognitive load for annotators
  - Slower to collect at scale

KTO trades information richness (relative ranking) for data availability
(use the thumbs-up/down you already have). This is the DATA FORMAT axis
of the design space: paired data is more informative per example, but
single-response data is cheaper and more abundant.

The design space tradeoff: DPO needs fewer examples but more expensive
ones. KTO needs more examples but cheaper ones.
""")

**What you just observed:** Converting paired comparisons to single-response labels is a lossy transformation. The relative ranking ("A is better than B") collapses into absolute labels ("A is good, B is bad"), losing the quality gap information and the pairing structure.

This is why KTO exists: not because single-response labels are *better* than pairs, but because they're *what you actually have*. Most user feedback comes as thumbs-up/thumbs-down on individual responses. Forcing that into paired format would require manufacturing comparisons that were never made. KTO meets the data where it is.

The insight maps directly to the **data format axis** from the lesson's design space: paired comparisons (DPO) vs single responses (KTO). Each format constrains which methods are viable. The constraint is practical (what data can you collect?), not theoretical (which loss function is better?).

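The "manufacturing comparisons" problem can be made concrete. A minimal sketch, assuming a hypothetical helper `singles_to_candidate_pairs` and made-up feedback logs (neither is part of this notebook): matching by prompt produces every good/bad combination, and none of those pairs was ever judged side by side by a human.

```python
from collections import defaultdict
from itertools import product

def singles_to_candidate_pairs(singles):
    """Hypothetical heuristic: pair every 'good' response with every 'bad'
    response that shares the same prompt. These pairs are GUESSES; the
    annotator never compared them directly."""
    by_prompt = defaultdict(lambda: {"good": [], "bad": []})
    for item in singles:
        by_prompt[item["prompt"]][item["label"]].append(item["response"])

    pairs = []
    for prompt, groups in by_prompt.items():
        for chosen, rejected in product(groups["good"], groups["bad"]):
            pairs.append({"prompt": prompt, "chosen": chosen, "rejected": rejected})
    return pairs

# Illustrative thumbs-up/down logs: two good and two bad answers to one prompt
toy_singles = [
    {"prompt": "What causes rain?", "response": "answer 1", "label": "good"},
    {"prompt": "What causes rain?", "response": "answer 2", "label": "good"},
    {"prompt": "What causes rain?", "response": "answer 3", "label": "bad"},
    {"prompt": "What causes rain?", "response": "answer 4", "label": "bad"},
]

candidate_pairs = singles_to_candidate_pairs(toy_singles)
# Four labels yield four candidate pairs, all of them inferred rather than observed.
print(len(candidate_pairs))
```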
---

## Exercise 2: Reference Model Drift Visualization (Supported)

The **reference model axis** is about stability. In DPO, IPO, and KTO, a frozen copy of the SFT model acts as an anchor: the KL penalty prevents the policy from drifting too far from what it knew before alignment. This is the "continuous version of 'freeze the backbone'" from the RLHF lesson.

In this exercise, you'll compute the log-probability ratio between a "policy" model and a "reference" model on several outputs. The ratio tells you how much the policy has diverged from the reference on each response. In real DPO training, this ratio IS the implicit KL constraint.

We'll simulate this by comparing GPT-2 (our "reference") against a version with slightly modified logits (our "drifted policy"). You'll fill in the key computation.

<details>
<summary>Hint</summary>

The log-probability ratio for a single token is:
```
ratio = log_prob_policy(token) - log_prob_reference(token)
```

For a full response, sum the per-token ratios. A positive sum means the policy assigns higher probability to this response than the reference does. A negative sum means lower probability.

The KL divergence between policy and reference is related to the *expected* value of this ratio. Large positive or negative ratios indicate drift: the policy's distribution has moved away from the reference.

</details>

In [None]:
# --- Simulating policy drift from a reference model ---
#
# In real alignment training:
#   - Reference model = frozen copy of the SFT model (before alignment)
#   - Policy model = the model being trained (drifts during alignment)
#
# We'll simulate this by using GPT-2 as the reference and creating a
# "drifted" version by adding noise to its logits. This is a proxy:
# real drift comes from gradient updates, not noise. But it lets us
# see the same quantities that DPO's loss function operates on.

def get_log_probs_with_drift(model, tokenizer, prompt, response, drift_scale=0.0):
    """Compute log-probs with optional logit perturbation to simulate drift.
    
    drift_scale=0.0 → reference model (no drift)
    drift_scale>0.0 → "policy" that has drifted from the reference
    """
    prompt_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    full_ids = tokenizer.encode(prompt + response, return_tensors="pt").to(device)
    response_start = prompt_ids.shape[1]

    with torch.no_grad():
        outputs = model(full_ids)
        logits = outputs.logits

        if drift_scale > 0:
            # Add reproducible noise: the fixed seed gives the same base pattern
            # for a given input, and drift_scale only rescales it
            torch.manual_seed(123)
            noise = torch.randn_like(logits) * drift_scale
            logits = logits + noise

    log_probs_all = F.log_softmax(logits[0], dim=-1)
    response_token_ids = full_ids[0, response_start:]
    token_log_probs = log_probs_all[response_start - 1 : -1]
    token_log_probs = token_log_probs.gather(1, response_token_ids.unsqueeze(1)).squeeze(1)

    return token_log_probs


# Responses to evaluate ‚Äî a mix of styles
prompt = "The capital of France is"

responses = [
    " Paris, which is known for the Eiffel Tower.",          # Factual, natural
    " definitely absolutely certainly Paris without doubt.",  # Overly confident filler
    " London, the largest city in England.",                  # Wrong but fluent
    " Paris. Paris is located in northern France.",           # Factual, repetitive
    " unknown to me, I cannot answer that question.",         # Refusal
]

response_labels = [
    "Factual, natural",
    "Overly confident filler",
    "Wrong but fluent",
    "Factual, repetitive",
    "Refusal",
]

print(f"Prompt: \"{prompt}\"")
print(f"\nResponses to evaluate:")
for i, (resp, label) in enumerate(zip(responses, response_labels)):
    print(f"  {i+1}. [{label}] \"{resp.strip()}\"")
print()

In [None]:
# --- Compute log-probability ratios at increasing drift levels ---
#
# drift_scale controls how much the "policy" has diverged from the reference.
# At drift_scale=0, the policy IS the reference (no drift).
# As drift_scale increases, the policy assigns different probabilities.

drift_levels = [0.0, 0.5, 1.0, 2.0, 4.0]

# For each response, compute the sum of per-token log-prob ratios at each drift level
all_ratios = {}  # response_index -> list of ratios (one per drift level)

for resp_idx, response in enumerate(responses):
    ratios_for_response = []

    # Reference model log-probs (drift_scale=0, always the same)
    ref_log_probs = get_log_probs_with_drift(model, tokenizer, prompt, response, drift_scale=0.0)

    for drift in drift_levels:
        # TODO: Compute the policy log-probs at this drift level
        # Use get_log_probs_with_drift() with the current drift value
        # YOUR CODE HERE (1 line)
        policy_log_probs = None  # REPLACE THIS

        # TODO: Compute the per-token log-probability ratio: policy - reference
        # Then sum across all tokens to get the total ratio for this response.
        # A positive sum means the policy assigns MORE probability to this response.
        # A negative sum means the policy assigns LESS probability.
        # YOUR CODE HERE (1 line)
        total_ratio = 0.0  # REPLACE THIS

        ratios_for_response.append(total_ratio)

    all_ratios[resp_idx] = ratios_for_response

# Display the ratios
print("Log-probability ratios (policy - reference), summed over tokens:")
print(f"{'Response':<28} | " + " | ".join(f"drift={d:.1f}" for d in drift_levels))
print("-" * 90)
for idx, label in enumerate(response_labels):
    vals = " | ".join(f"{r:>9.2f}" for r in all_ratios[idx])
    print(f"{label:<28} | {vals}")

In [None]:
# --- Visualize the drift ---

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

colors = ['#6366f1', '#f59e0b', '#ef4444', '#06b6d4', '#10b981']

# Left plot: log-prob ratios across drift levels
ax = axes[0]
for idx, label in enumerate(response_labels):
    ax.plot(drift_levels, all_ratios[idx], 'o-', color=colors[idx], label=label, linewidth=2, markersize=6)
ax.axhline(y=0, color='white', linestyle='--', alpha=0.3, linewidth=1)
ax.set_xlabel('Drift Scale (simulated training steps)', fontsize=11)
ax.set_ylabel('Log-Prob Ratio (policy − reference)', fontsize=11)
ax.set_title('How Policy Drift Affects Different Responses', fontsize=12)
ax.legend(fontsize=8, loc='best')
ax.grid(alpha=0.2)

# Right plot: absolute ratio magnitude (how much each response has shifted)
ax2 = axes[1]
for idx, label in enumerate(response_labels):
    abs_ratios = [abs(r) for r in all_ratios[idx]]
    ax2.plot(drift_levels, abs_ratios, 'o-', color=colors[idx], label=label, linewidth=2, markersize=6)
ax2.set_xlabel('Drift Scale (simulated training steps)', fontsize=11)
ax2.set_ylabel('|Log-Prob Ratio| (magnitude of drift)', fontsize=11)
ax2.set_title('Magnitude of Drift Per Response', fontsize=12)
ax2.legend(fontsize=8, loc='best')
ax2.grid(alpha=0.2)

plt.tight_layout()
plt.show()

print("\nAt drift_scale=0, all ratios are 0 (policy = reference, no drift).")
print("As drift increases, different responses are affected differently.")
print("\nThe KL penalty in DPO/IPO/KTO penalizes LARGE ratios. It says:")
print('  "You can adjust probabilities, but not TOO far from the reference."')
print("Without this constraint, the model could collapse: assigning all")
print("probability to one response style and zero to everything else.")

<details>
<summary>Solution</summary>

The log-probability ratio is the core quantity in preference optimization. For each token, it measures how much the policy has diverged from the reference.

```python
# Compute policy log-probs at this drift level
policy_log_probs = get_log_probs_with_drift(model, tokenizer, prompt, response, drift_scale=drift)

# Per-token ratio summed across the response
total_ratio = (policy_log_probs - ref_log_probs).sum().item()
```

**Why this matters:** In DPO's loss function, the key quantity is exactly this ratio: `log_prob_policy(response) - log_prob_reference(response)`. DPO computes this for both the preferred and dispreferred responses, then pushes the ratio up for preferred and down for dispreferred, but only within the bound set by the implicit KL penalty.

The reference model acts as a "memory" of what the model knew before alignment. Large ratios mean the policy has moved far from its starting point. The KL penalty says: move, but not too far. This prevents the catastrophic forgetting that would happen if the model abandoned everything it learned during pretraining and SFT.

**Common mistake:** Computing `ref_log_probs - policy_log_probs` (reversed). The convention is policy minus reference: positive means the policy assigns *more* probability than the reference.

</details>

**What you just observed:** As the policy drifts from the reference, different responses are affected by different amounts. The log-probability ratio measures this divergence per response, and it is exactly the quantity that DPO, IPO, and KTO use in their loss functions.

The **reference model constraint** (the KL penalty) prevents these ratios from growing too large. Without it, alignment training could collapse: the model might assign all probability mass to responses that superficially satisfy the preference signal while forgetting how to be a generally capable language model. This is the same "reward hacking" problem from RLHF; the KL penalty is the continuous version of "freeze the backbone."

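The summed log-probability ratio from this exercise plugs directly into DPO's pairwise loss. A minimal sketch, assuming the summed response log-probs are already available; the function name `dpo_loss`, the example numbers, and `beta=0.1` are illustrative choices, not values from this notebook:

```python
import math

def dpo_loss(policy_chosen, ref_chosen, policy_rejected, ref_rejected, beta=0.1):
    """DPO loss for one preference pair, from summed response log-probs.

    loss = -log sigmoid(beta * (chosen_ratio - rejected_ratio))
    where each ratio is log_prob_policy - log_prob_reference.
    """
    chosen_ratio = policy_chosen - ref_chosen
    rejected_ratio = policy_rejected - ref_rejected
    margin = beta * (chosen_ratio - rejected_ratio)
    # log(1 + exp(-margin)) is the same as -log(sigmoid(margin))
    return math.log1p(math.exp(-margin))

# Made-up log-probs. Before any update the policy equals the reference,
# so both ratios are 0 and the loss is log(2).
loss_at_reference = dpo_loss(-20.0, -20.0, -25.0, -25.0)

# After an update that raises the chosen response and lowers the rejected one,
# the margin is positive and the loss falls.
loss_after_update = dpo_loss(-18.0, -20.0, -28.0, -25.0)

print(f"loss at reference: {loss_at_reference:.4f}")
print(f"loss after update: {loss_after_update:.4f}")
```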
ORPO removes the reference model entirely. That means it does not have this ratio to constrain. Instead, it uses an odds ratio computed from the model's own probabilities, a different stability mechanism. The lesson's insight applies: **simplification in one dimension creates complexity in another.**

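For contrast, here is a rough sketch of an odds-ratio term that needs no reference model. It assumes a response's probability is taken as the exponentiated mean per-token log-prob, which is how ORPO is commonly described. The helper names and numbers are hypothetical, and the full ORPO objective also includes a standard SFT loss on the chosen response, omitted here.

```python
import math

def log_odds(mean_token_log_prob):
    """log(p / (1 - p)), where p = exp(mean per-token log-prob), so 0 < p < 1."""
    p = math.exp(mean_token_log_prob)
    return mean_token_log_prob - math.log1p(-p)

def odds_ratio_term(chosen_mean_lp, rejected_mean_lp):
    """-log sigmoid(log_odds(chosen) - log_odds(rejected)).

    Only the policy's own probabilities appear; there is no reference model.
    """
    log_odds_ratio = log_odds(chosen_mean_lp) - log_odds(rejected_mean_lp)
    return math.log1p(math.exp(-log_odds_ratio))

# Made-up mean per-token log-probs
term_equal = odds_ratio_term(-1.0, -1.0)     # equal odds: log(2)
term_separated = odds_ratio_term(-1.0, -2.0) # chosen more likely: smaller penalty

print(f"equal odds:     {term_equal:.4f}")
print(f"chosen favored: {term_separated:.4f}")
```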
---

## Exercise 3: Online vs Offline Distribution Mismatch (Supported)

The **online vs offline axis** is about *when* training data is generated. Offline methods (DPO by default) train on a static dataset collected before training. Online methods (PPO, online DPO) generate fresh data from the current model during training.

The problem with offline training: as the policy updates, the pre-collected data becomes "stale." The responses in the dataset were generated by an older version of the model. The policy is being evaluated on a distribution it no longer produces.

In this exercise, you'll simulate this: generate text from a "base" model, then compare the base model's output distribution to what the model would produce after a simulated update. You'll observe the distribution shift.

<details>
<summary>Hint</summary>

To see distribution mismatch, compare the token probability distributions at the same prompt position between the "old" policy (reference) and the "updated" policy (drifted). If the distributions are different, then training data generated by the old policy is "off-distribution" for the new policy.

Use `F.softmax(logits, dim=-1)` to get probability distributions, and compare the top-k tokens between the two models.

The KL divergence between the two distributions quantifies the mismatch: `KL = sum(p_new * log(p_new / p_old))` where p_new is the updated policy and p_old is the original.

</details>

In [None]:
# --- Simulating the online vs offline distribution mismatch ---
#
# Scenario:
#   1. A "base" model generates responses (this is the offline training data)
#   2. The model gets updated during training (the policy changes)
#   3. The old training data no longer reflects what the updated model produces
#
# We simulate step 2 by perturbing the logits (same technique as Exercise 2).

prompt = "The best way to learn machine learning is"

# Step 1: Get the "base" model's next-token distribution at this prompt
prompt_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

with torch.no_grad():
    outputs = model(prompt_ids)
    base_logits = outputs.logits[0, -1]  # logits at the last position

base_probs = F.softmax(base_logits, dim=-1)

# Show the base model's top-10 next tokens
top_k = 10
top_probs, top_indices = torch.topk(base_probs, top_k)

print(f'Prompt: "{prompt}"')
print(f"\nBase model's top-{top_k} next tokens:")
print(f"{'Token':<20} {'Probability':>12}")
print("-" * 35)
for prob, idx in zip(top_probs, top_indices):
    token = tokenizer.decode([idx.item()])
    print(f"{repr(token):<20} {prob.item():>12.4f}")

In [None]:
# Step 2: Simulate policy updates at increasing drift levels
# and measure how the distribution shifts

drift_levels = [0.0, 0.5, 1.0, 2.0, 4.0, 8.0]
kl_divergences = []
top1_changes = []  # Track whether the most-likely token changes

for drift in drift_levels:
    # Simulate a policy update by adding noise to logits
    torch.manual_seed(456)  # Same noise pattern for each drift level, scaled
    noise = torch.randn_like(base_logits) * drift
    updated_logits = base_logits + noise

    # TODO: Compute the updated probability distribution from updated_logits
    # Use F.softmax() just like we did for base_probs above
    # YOUR CODE HERE (1 line)
    updated_probs = None  # REPLACE THIS

    # TODO: Compute the KL divergence from base_probs to updated_probs
    # KL(updated || base) = sum(updated_probs * log(updated_probs / base_probs))
    # Use torch.sum() and torch.log(). Add a small epsilon (1e-10) to avoid log(0).
    # This measures how "stale" the base model's data is relative to the updated policy.
    # YOUR CODE HERE (1 line)
    kl = 0.0  # REPLACE THIS

    kl_divergences.append(kl)

    # Track the most-likely token under the updated policy
    # (the base model's top-1 token is computed once, after the loop)
    updated_top1 = tokenizer.decode([updated_probs.argmax().item()])
    top1_changes.append(updated_top1)

# Display results
print(f"Distribution mismatch as policy updates:")
print(f"{'Drift Level':<15} {'KL Divergence':>15} {'Top-1 Token':>15} {'Changed?':>10}")
print("-" * 60)
base_top1 = tokenizer.decode([base_probs.argmax().item()])
for drift, kl, top1 in zip(drift_levels, kl_divergences, top1_changes):
    changed = "YES" if top1 != base_top1 else "no"
    print(f"{drift:<15.1f} {kl:>15.4f} {repr(top1):>15} {changed:>10}")

In [None]:
# --- Visualize the distribution shift ---

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left: KL divergence vs drift level
ax = axes[0]
ax.plot(drift_levels, kl_divergences, 'o-', color='#f59e0b', linewidth=2, markersize=8)
ax.set_xlabel('Policy Drift (simulated training steps)', fontsize=11)
ax.set_ylabel('KL Divergence (nats)', fontsize=11)
ax.set_title('Distribution Mismatch Grows With Policy Updates', fontsize=12)
ax.grid(alpha=0.2)

# Add annotation
ax.annotate(
    'Offline training data\nbecomes stale here',
    xy=(drift_levels[-2], kl_divergences[-2]),
    xytext=(drift_levels[-2] - 2, kl_divergences[-2] + (max(kl_divergences) * 0.2)),
    fontsize=9,
    color='#f59e0b',
    arrowprops=dict(arrowstyle='->', color='#f59e0b', lw=1.5),
)

# Right: Probability comparison for top tokens at different drift levels
ax2 = axes[1]

# Compare base vs heavily drifted model's top-10 token probs
torch.manual_seed(456)
heavy_drift = 4.0
noise = torch.randn_like(base_logits) * heavy_drift
heavy_probs = F.softmax(base_logits + noise, dim=-1)

# Get union of top-8 tokens from both distributions
base_top8 = set(torch.topk(base_probs, 8).indices.tolist())
heavy_top8 = set(torch.topk(heavy_probs, 8).indices.tolist())
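# Order by the larger of the two probabilities so the cap below keeps the most relevant tokens
union_tokens = sorted(base_top8 | heavy_top8, key=lambda t: max(base_probs[t].item(), heavy_probs[t].item()), reverse=True)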
union_tokens = union_tokens[:12]  # Cap for readability

token_labels = [repr(tokenizer.decode([t])) for t in union_tokens]
base_vals = [base_probs[t].item() for t in union_tokens]
heavy_vals = [heavy_probs[t].item() for t in union_tokens]

x = np.arange(len(union_tokens))
width = 0.35
ax2.bar(x - width/2, base_vals, width, label=f'Base (drift=0)', color='#6366f1', alpha=0.8)
ax2.bar(x + width/2, heavy_vals, width, label=f'Updated (drift={heavy_drift})', color='#f59e0b', alpha=0.8)
ax2.set_xticks(x)
ax2.set_xticklabels(token_labels, rotation=45, ha='right', fontsize=8)
ax2.set_ylabel('Probability', fontsize=11)
ax2.set_title(f'Token Probabilities: Base vs Updated Policy', fontsize=12)
ax2.legend(fontsize=9)
ax2.grid(alpha=0.2, axis='y')

plt.tight_layout()
plt.show()

print("Left: KL divergence grows as the policy updates. Offline training data")
print("becomes increasingly stale: it was generated by an older policy.")
print("\nRight: The probability distribution over next tokens shifts. Tokens")
print("that were likely under the base model may become unlikely under the")
print("updated policy, and vice versa.")

In [None]:
# --- The online vs offline tradeoff ---

print("=" * 70)
print("THE ONLINE VS OFFLINE TRADEOFF")
print("=" * 70)
print(f"""
What you just observed: as the policy updates during training, the
pre-collected (offline) data becomes stale. The KL divergence between
the current policy and the policy that generated the training data grows.

This is the DISTRIBUTION MISMATCH problem:
  - Offline data was generated by policy_v0
  - After training, the model is policy_v1
  - policy_v1 would generate DIFFERENT responses than policy_v0
  - But it's being trained on policy_v0's responses

OFFLINE methods (DPO, default):
  ✓ Cheaper: no generation during training
  ✓ Simpler: standard supervised training loop
  ✗ Data becomes stale as the policy changes
  ✗ Training on off-distribution examples

ONLINE methods (PPO, online DPO):
  ✓ Fresh data from the current policy
  ✓ No distribution mismatch
  ✗ N forward passes per training step (expensive)
  ✗ Cold-start: early generations are low quality

ITERATIVE (compromise):
  ✓ Run offline DPO, then generate new data, then repeat
  ✓ Multiple discrete rounds, not continuous online generation
  ✓ Cheaper than fully online, fresher than fully offline

The practical reality: for many teams, the performance gap between
online and offline is small when the offline data is high quality.
Most deployed models use offline methods because the cost/quality
tradeoff favors it.
""")

<details>
<summary>Solution</summary>

The two key computations are straightforward probability operations:

```python
# Updated probability distribution
updated_probs = F.softmax(updated_logits, dim=-1)

# KL divergence: how different is the updated distribution from the base?
kl = torch.sum(updated_probs * torch.log((updated_probs + 1e-10) / (base_probs + 1e-10))).item()
```

**Why this matters:** The KL divergence quantifies how "stale" the offline training data has become. At the start of training, the data comes from the same distribution the model produces (KL = 0). As training progresses and the policy changes, the KL grows: the model is learning from examples it would no longer generate itself.

Online methods solve this by regenerating data from the *current* policy at each training step. The data is always on-distribution because it comes from the model that is being trained. The cost: generating responses requires forward passes through the model during training, which is expensive for large models.

**Common mistake:** Computing KL(base || updated) instead of KL(updated || base). The direction matters because KL is not symmetric. KL(updated || base) averages over what the *updated* policy would generate and measures how poorly the base distribution, which produced the offline data, covers those outputs.

</details>

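The asymmetry of KL is easy to check on a tiny example. A minimal sketch, assuming two made-up three-token distributions (not outputs from GPT-2) and a plain-Python `kl_divergence` helper that mirrors the formula in the solution:

```python
import math

def kl_divergence(p, q, eps=1e-10):
    """KL(p || q) = sum_i p_i * log(p_i / q_i), in nats."""
    return sum(pi * math.log((pi + eps) / (qi + eps)) for pi, qi in zip(p, q))

# Made-up next-token distributions over a 3-token vocabulary
base = [0.5, 0.4, 0.1]      # "old policy" that generated the offline data
updated = [0.1, 0.1, 0.8]   # "updated policy" after training

kl_updated_base = kl_divergence(updated, base)  # direction used in the exercise
kl_base_updated = kl_divergence(base, updated)  # reversed direction

print(f"KL(updated || base) = {kl_updated_base:.4f}")
print(f"KL(base || updated) = {kl_base_updated:.4f}")
print(f"KL(base || base)    = {kl_divergence(base, base):.4f}")
```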
**What you just observed:** As the policy updates during training, the KL divergence between the current policy and the policy that generated the training data grows. Offline training data becomes increasingly stale: the model is learning from responses it would no longer produce.

This is the fundamental tension on the **online vs offline axis**: offline methods are cheaper but train on stale data; online methods generate fresh data but are expensive. The choice depends on your compute budget and how much the policy changes during training.

The insight connects back to the design space: online vs offline is *orthogonal* to the choice of loss function (DPO vs IPO vs KTO). You can run any of those methods either online or offline. The axis is independent, which is what makes it an *axis* and not a *feature*.

---

## Key Takeaways

1. **The data format axis is a practical constraint, not a theoretical one.** Converting paired comparisons to single-response labels is lossy: you lose relative ranking and pairing structure. KTO exists because most real feedback is already in single-response format (thumbs-up/down). The method fits the data you have, not the data you wish you had.

2. **The reference model is a stability mechanism, not just extra memory cost.** The log-probability ratio between policy and reference IS the KL constraint. It prevents the policy from drifting too far and forgetting what it learned in pretraining and SFT. Removing it (as ORPO does) requires a different stability mechanism.

3. **Offline training data becomes stale as the policy updates.** The distribution mismatch grows with each training step. Online methods regenerate data from the current policy to avoid this, at the cost of additional compute. Most teams use offline methods because the cost/quality tradeoff favors it.

4. **The axes are independent.** Data format, reference model, and online vs offline are orthogonal choices. You can combine them in different ways, which is why the design space has many methods: each represents a different combination of tradeoffs.

5. **Alignment techniques are points in a design space, not steps on a ladder.** No method dominates all axes. The right choice depends on your constraints: what data you have, how much memory you can spare, and how much compute you can afford.
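
The takeaways above can be summarized as a small lookup over the three axes. This is a sketch under stated assumptions: the `Method` dataclass and `viable` helper are hypothetical, the "offline" entries reflect typical defaults rather than hard limits, and the paired-data entries for IPO and ORPO come from how those methods are usually described rather than from this notebook.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Method:
    name: str
    data_format: str      # "paired" or "single"
    uses_reference: bool  # needs a frozen reference model in memory
    data_source: str      # "offline" or "online"

# Typical configurations (assumed defaults, not hard constraints)
methods = [
    Method("DPO", "paired", True, "offline"),
    Method("IPO", "paired", True, "offline"),
    Method("KTO", "single", True, "offline"),
    Method("ORPO", "paired", False, "offline"),
    Method("online DPO", "paired", True, "online"),
]

def viable(methods, data_format, can_hold_reference, can_generate_online):
    """Filter methods by practical constraints.

    Paired data can feed single-response methods (lossy conversion),
    but single-response data cannot feed paired methods.
    """
    return [
        m.name for m in methods
        if (m.data_format == "single" or data_format == "paired")
        and (can_hold_reference or not m.uses_reference)
        and (can_generate_online or m.data_source == "offline")
    ]

# Only thumbs-up/down logs, memory for a reference model, no online generation
print(viable(methods, "single", True, False))
# Paired data, but no memory for a second model and no online generation
print(viable(methods, "paired", False, False))
```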