# Notebook 01: Supervised Fine-Tuning (SFT)

## Learning Objectives
- Understand *why* SFT is needed and what it does
- Implement an SFT training loop from scratch (manual PyTorch)
- Learn about label masking (why we mask instruction tokens)
- Compare before/after SFT generation quality

---

## What is SFT?

A pre-trained language model (GPT-2, LLaMA, etc.) is trained to **predict the next token** on a massive text corpus. It learns *how language works*, but it is not trained to *answer questions* or *follow instructions*, so it often continues a prompt instead of responding to it.

**SFT (Supervised Fine-Tuning)** teaches the model to follow instructions by showing it examples of (instruction, good response) pairs.

### The Loss Function

$$\mathcal{L}_{\text{SFT}} = -\frac{1}{T} \sum_{t=1}^{T} \log p_\theta(y_t \mid y_{<t}, x)$$

where:
- $x$ = the instruction (input prompt)
- $y_t$ = the $t$-th response token
- $T$ = total number of **response** tokens (NOT instruction tokens)
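
As a small worked example of the formula, suppose the model assigns the probabilities below to three response tokens. The values are made up for illustration; only the arithmetic matters.

```python
import math

# Hypothetical probabilities p_theta(y_t | y_<t, x) for T = 3 response tokens.
token_probs = [0.5, 0.25, 0.8]

# L_SFT = -(1/T) * sum_t log p_theta(y_t | y_<t, x)
T = len(token_probs)
sft_loss = -sum(math.log(p) for p in token_probs) / T

print(f"T = {T}, loss = {sft_loss:.4f}")  # T = 3, loss = 0.7675
```

Instruction tokens never appear in `token_probs`, which is why they do not count toward $T$.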

### Label Masking

```
Input:  [Instruction tokens...] [Response tokens...]
Labels: [   -100   -100 ...  ] [  actual token IDs ]
          ← masked (ignored) →  ← compute loss here →
```

We mask instruction tokens with `-100`, the default `ignore_index` of PyTorch's cross-entropy loss, so the loss is computed only on the response tokens.
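
The sketch below illustrates the masking on a toy sequence. Random logits stand in for a model's output, and the token IDs, `vocab_size`, and `n_instruction` are arbitrary values chosen for the example. It checks that the masked loss equals the mean per-token loss over response positions only.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

vocab_size = 10
# Toy sequence: 3 instruction tokens followed by 2 response tokens.
input_ids = torch.tensor([[4, 7, 1, 3, 9]])
n_instruction = 3

# Labels are a copy of the inputs with the instruction positions masked out.
labels = input_ids.clone()
labels[:, :n_instruction] = -100

# Random logits stand in for a model's output: (batch, seq_len, vocab_size).
logits = torch.randn(1, input_ids.shape[1], vocab_size)

# A causal LM predicts token t+1 from position t, so shift by one.
shift_logits = logits[:, :-1, :].reshape(-1, vocab_size)
shift_labels = labels[:, 1:].reshape(-1)

# Masked positions contribute nothing and are excluded from the averaging denominator.
masked_loss = F.cross_entropy(shift_logits, shift_labels, ignore_index=-100)

# Manual check: average the per-token loss over response positions only.
per_token = F.cross_entropy(shift_logits, shift_labels.clamp(min=0), reduction="none")
keep = shift_labels != -100
manual_loss = per_token[keep].mean()

print(masked_loss.item(), manual_loss.item())
```

Hugging Face causal LM models apply the same one-position shift internally when `labels` is passed to the forward call, so in that case the labels are supplied unshifted.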

In [None]:
# Install dependencies (run this cell first on Google Colab)
# !pip install torch transformers tqdm matplotlib

In [None]:
import sys
sys.path.insert(0, '..')

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModelForCausalLM

print(f'PyTorch version: {torch.__version__}')
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using device: {device}')

## Step 1: Create a Synthetic Math Instruction Dataset

For this demo we create a small synthetic dataset. In real training you'd use datasets such as:
- OpenOrca, Alpaca, ShareGPT (instruction following)
- MetaMath, MathInstruct (math-specific SFT)

In [None]:
# Synthetic math instruction dataset
MATH_INSTRUCTIONS = [
    {"instruction": "A store has 45 apples. It sells 18. How many remain?",
     "response": "Step 1: Start with 45 apples.\nStep 2: Subtract 18 sold: 45 - 18 = 27.\nThe answer is: 27"},
    {"instruction": "A train goes 60 km/h for 3 hours. Distance?",
     "response": "Distance = speed × time = 60 × 3 = 180 km.\nThe answer is: 180"},
    {"instruction": "6 boxes with 12 apples each. Total apples?",
     "response": "Total = 6 × 12 = 72 apples.\nThe answer is: 72"},
    {"instruction": "Rectangle 8m wide and 5m tall. Area?",
     "response": "Area = width × height = 8 × 5 = 40 m².\nThe answer is: 40"},
    {"instruction": "Sarah earns $15/hour and works 8 hours. Earnings?",
     "response": "Earnings = 15 × 8 = $120.\nThe answer is: 120"},
]

print(f'Dataset size: {len(MATH_INSTRUCTIONS)} examples')
print('\nExample:')
ex = MATH_INSTRUCTIONS[0]
print(f'  Instruction: {ex["instruction"]}')
print(f'  Response:    {ex["response"]}')

## Step 2: Load Model and Tokenizer

We use `distilgpt2`, a small (~82M parameter) model that runs on a free Colab T4.

In [None]:
MODEL_NAME = 'distilgpt2'
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
model = model.to(device)

n_params = sum(p.numel() for p in model.parameters())
print(f'Model: {MODEL_NAME} ({n_params/1e6:.1f}M parameters)')

## Step 3: Test Pre-SFT Generation

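Before fine-tuning, the base model has never seen our prompt template, so we expect it to continue the text loosely instead of producing a step-by-step solution that ends with `The answer is:` and a number.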

In [None]:
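def generate(model, tokenizer, prompt, max_new_tokens=80):
    """Sample a continuation of `prompt` and return only the newly generated text."""
    model.eval()  # disable dropout; the model may still be in train mode after fine-tuning
    # Tokenize with an attention mask, since pad_token == eos_token for this tokenizer
    inputs = tokenizer(prompt, return_tensors='pt').to(device)
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Drop the prompt tokens so only the continuation is decoded
    gen_ids = output[0][inputs['input_ids'].shape[1]:]
    return tokenizer.decode(gen_ids, skip_special_tokens=True)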

test_prompt = '### Instruction:\nA store has 45 apples. It sells 18. How many remain?\n\n### Response:\n'
print('=== PRE-SFT generation ===')
print(f'Prompt: {test_prompt}')
print(f'Output: {generate(model, tokenizer, test_prompt)}')

## Step 4: Implement the SFT Training Loop

In [None]:
# '..' was already added to sys.path in the imports cell above
from src.training.sft_trainer import SFTConfig, SFTTrainer

config = SFTConfig(
    learning_rate=2e-4,
    num_epochs=3,
    batch_size=2,
    max_length=128,
    logging_steps=5,
)

trainer = SFTTrainer(model, tokenizer, config)
print('SFTTrainer ready!')
print(f'Config: {config}')

In [None]:
# Train!
losses = trainer.train(MATH_INSTRUCTIONS)
print(f'Training complete! Final loss: {losses[-1]:.4f}')
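
The training above is delegated to `SFTTrainer`, whose source is not shown in this notebook. To make the mechanics visible, here is a minimal sketch of a manual loop on a toy model. `TinyLM`, `sft_step`, `toy_model`, and `toy_losses` are hypothetical names, and the real trainer may differ (batching, padding, scheduling, logging).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

class TinyLM(nn.Module):
    """Toy stand-in for a causal LM: an embedding followed by a linear head."""
    def __init__(self, vocab_size=16, dim=32):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.head = nn.Linear(dim, vocab_size)

    def forward(self, input_ids):
        return self.head(self.embed(input_ids))  # (batch, seq_len, vocab_size)

def sft_step(model, optimizer, input_ids, labels):
    """Run one optimization step on the masked next-token loss."""
    model.train()
    logits = model(input_ids)
    # Shift so that position t predicts token t+1; -100 labels are ignored.
    loss = F.cross_entropy(
        logits[:, :-1, :].reshape(-1, logits.size(-1)),
        labels[:, 1:].reshape(-1),
        ignore_index=-100,
    )
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

# Two toy sequences: 3 instruction tokens, then 3 response tokens.
toy_input_ids = torch.tensor([[1, 2, 3, 4, 5, 6], [1, 7, 11, 8, 9, 10]])
toy_labels = toy_input_ids.clone()
toy_labels[:, :3] = -100  # mask the instruction

toy_model = TinyLM()
optimizer = torch.optim.AdamW(toy_model.parameters(), lr=1e-2)
toy_losses = [sft_step(toy_model, optimizer, toy_input_ids, toy_labels) for _ in range(50)]
print(f"first: {toy_losses[0]:.4f}, last: {toy_losses[-1]:.4f}")
```

The toy model only looks at the current token, so it is far weaker than `distilgpt2`; the point is the order of operations (forward pass, shifted masked loss, backward pass, optimizer step), not the model.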

## Step 5: Visualize Training Loss

In [None]:
from src.evaluation.visualization import plot_training_curves

fig = plot_training_curves(losses, title='SFT Training Loss on Synthetic Math Dataset')
plt.show()
print(f'Starting loss: {losses[0]:.4f} → Final loss: {losses[-1]:.4f}')
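
With a batch size of 2 the per-step loss can be noisy, so a smoothed curve is often easier to read. The sketch below is a plain-matplotlib fallback for readers who do not have the `src.evaluation.visualization` module. `moving_average` and `plot_losses` are hypothetical helpers, and the code assumes `losses` is a flat list of floats.

```python
def moving_average(values, window=5):
    """Trailing mean over up to `window` most recent values."""
    smoothed = []
    for i in range(len(values)):
        start = max(0, i - window + 1)
        chunk = values[start:i + 1]
        smoothed.append(sum(chunk) / len(chunk))
    return smoothed

def plot_losses(losses, window=5):
    """Plot raw and smoothed losses with plain matplotlib."""
    import matplotlib.pyplot as plt  # imported lazily so moving_average has no dependencies
    fig, ax = plt.subplots(figsize=(7, 4))
    ax.plot(losses, alpha=0.4, label="raw")
    ax.plot(moving_average(losses, window), label=f"moving average (window={window})")
    ax.set_xlabel("step index")
    ax.set_ylabel("loss")
    ax.legend()
    return fig

print(moving_average([4.0, 2.0, 3.0, 1.0], window=2))  # [4.0, 3.0, 2.5, 2.0]
```

Calling `plot_losses(losses)` followed by `plt.show()` would then draw both curves on one axis.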

## Step 6: Compare Before vs After SFT

In [None]:
print('=== POST-SFT generation ===')
print(f'Prompt: {test_prompt}')
print(f'Output: {generate(model, tokenizer, test_prompt)}')

print('\n=== Another test ===')
prompt2 = '### Instruction:\n6 boxes with 12 apples each. How many total?\n\n### Response:\n'
print(f'Output: {generate(model, tokenizer, prompt2)}')
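
Reading generations by eye does not scale, so it can help to score them. Because every training response ends with `The answer is:` followed by a number, a simple exact-match check is possible. The helpers below are a hypothetical sketch, and the sample outputs are invented for illustration, not real model generations.

```python
import re

ANSWER_PATTERN = re.compile(r"The answer is:\s*(-?\d+(?:\.\d+)?)")

def extract_answer(text):
    """Return the first number following 'The answer is:', or None."""
    match = ANSWER_PATTERN.search(text)
    return match.group(1) if match else None

def exact_match_accuracy(outputs, references):
    """Fraction of outputs whose extracted answer equals the reference string."""
    correct = sum(extract_answer(o) == r for o, r in zip(outputs, references))
    return correct / len(references)

# Invented outputs: one correct, one wrong number, one with no answer line.
sample_outputs = [
    "Step 1: Start with 45 apples.\nThe answer is: 27",
    "Total = 6 × 12 = 70 apples.\nThe answer is: 70",
    "I am not sure.",
]
sample_references = ["27", "72", "40"]
accuracy = exact_match_accuracy(sample_outputs, sample_references)
print(f"accuracy = {accuracy:.2f}")  # accuracy = 0.33
```

Since `generate` samples with `temperature=0.7`, scores will vary between runs; averaging over several samples per prompt gives a steadier number.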

## Summary

**What we learned:**
1. SFT fine-tunes a pre-trained LM on instruction-response pairs
2. Loss is computed only on response tokens (label masking with -100)
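3. A decreasing training loss shows the model is fitting the training examples; with only 5 examples this is likely memorization of the format, not evidence of general math ability
4. SFT is typically the **first stage** in RLHF and DPO pipelines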

---

## Exercises

1. **Add more data:** Expand `MATH_INSTRUCTIONS` with 10 more problems. Does the model give the correct final answer more often on problems it was not trained on?
2. **Change the instruction template:** Use a different format (e.g., `Q:` / `A:`). What changes?
3. **Ablate masking:** What happens if you compute loss on instruction tokens too? (Remove masking)
4. **Learning rate sweep:** Try lr=1e-5, 1e-4, 1e-3. Which gives the best results?
5. **Extension:** Modify the trainer to support gradient checkpointing for longer sequences.