# Finetuning for Classification

**Module 4.4, Lesson 1** | CourseAI

In this notebook you will:

1. **Load pretrained GPT-2** via HuggingFace and examine its architecture — identify what to keep and what to replace
2. **Freeze the backbone and add a classification head** — replace `lm_head` with `nn.Linear(768, num_classes)`
3. **Prepare a text classification dataset** — tokenize SST-2 examples and create a DataLoader
4. **Implement the finetuning training loop** — frozen backbone, trainable head, cross-entropy loss
5. **Evaluate, unfreeze layers, and compare** — measure accuracy, try partial unfreezing with differential learning rates

For each exercise, **PREDICT the output before running the cell.**

---

## Setup

Run this cell to install dependencies and configure the environment. Use a **GPU runtime** in Colab — even frozen-backbone training benefits from GPU acceleration for the transformer forward pass.

In [None]:
!pip install -q transformers datasets tiktoken

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import numpy as np

from transformers import GPT2LMHeadModel
import tiktoken
from datasets import load_dataset

# Reproducible results
torch.manual_seed(42)
np.random.seed(42)

# Use GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
if device.type == 'cuda':
    print(f'GPU: {torch.cuda.get_device_name()}')

# Nice plots
plt.style.use('dark_background')
plt.rcParams['figure.figsize'] = [10, 4]

---

## Exercise 1: Load GPT-2 and Examine the Architecture [Guided]

The first step of transfer learning is always the same: load the pretrained model and understand what you have. In the CNN transfer learning lesson, you loaded a pretrained ResNet and identified `model.fc` as the part to replace. Here you will do the same with GPT-2.

**Before running, predict:**
- GPT-2 has a `lm_head` that projects to vocabulary size. What shape is `lm_head.weight`?
- How many total parameters does GPT-2 have?
- What is the hidden dimension (the size of each token's representation)?

In [None]:
# Load pretrained GPT-2 from HuggingFace
hf_model = GPT2LMHeadModel.from_pretrained('gpt2')
hf_model.eval()

print("=" * 60)
print("GPT-2 Architecture Overview")
print("=" * 60)

# The model has two main parts:
# 1. transformer — the backbone (embeddings + 12 transformer blocks + layer norm)
# 2. lm_head — the output projection to vocabulary
print("\nTop-level modules:")
for name, _ in hf_model.named_children():
    print(f"  {name}")

# lm_head shape: projects from hidden_dim to vocab_size
print(f"\nlm_head weight shape: {hf_model.lm_head.weight.shape}")
print(f"  -> Projects from {hf_model.lm_head.weight.shape[1]} (hidden dim) "
      f"to {hf_model.lm_head.weight.shape[0]} (vocab size)")

# Count parameters
total_params = sum(p.numel() for p in hf_model.parameters())
backbone_params = sum(p.numel() for p in hf_model.transformer.parameters())
lm_head_params = sum(p.numel() for p in hf_model.lm_head.parameters())

print(f"\nParameter counts:")
print(f"  Total:     {total_params:>12,}")
print(f"  Backbone:  {backbone_params:>12,}")
print(f"  lm_head:   {lm_head_params:>12,}")

# The hidden dimension — this is what our classification head will take as input
n_embd = hf_model.config.n_embd
n_layer = hf_model.config.n_layer
print(f"\nHidden dimension (n_embd): {n_embd}")
print(f"Number of transformer blocks: {n_layer}")

# Verify weight tying: lm_head shares weights with token embeddings
print(f"\nWeight tying check:")
print(f"  lm_head.weight data_ptr:  {hf_model.lm_head.weight.data_ptr()}")
print(f"  wte.weight data_ptr:      {hf_model.transformer.wte.weight.data_ptr()}")
print(f"  Same tensor? {hf_model.lm_head.weight.data_ptr() == hf_model.transformer.wte.weight.data_ptr()}")

### What happened

GPT-2 has two main parts: `transformer` (the backbone) and `lm_head` (the output projection). The backbone extracts language features from tokens — 12 transformer blocks that turn input tokens into 768-dimensional hidden states. The `lm_head` projects those 768-dimensional hidden states to 50,257 vocabulary logits for next-token prediction.

For classification, we will **keep the backbone** (it is our text feature extractor) and **replace `lm_head`** with a new linear layer that projects to the number of classes. This is exactly the same strategy as replacing `model.fc` in ResNet.

Notice the weight tying: `lm_head` shares its weight tensor with the token embedding (`wte`). When we replace `lm_head`, we break that tie — the embedding stays (we still need it to embed input tokens), but the output projection now maps to class labels.
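
The parameter counts printed above can be reproduced by hand. Doing so also shows why the backbone and `lm_head` counts add up to more than the total: the tied matrix appears once in `parameters()`. A minimal sketch, assuming the standard GPT-2 small configuration (12 blocks, 768 hidden units, 1,024 positions, 4x MLP expansion); the variable names are illustrative:

```python
# Hand count of GPT-2 small parameters (assumed config: the standard 'gpt2' checkpoint).
vocab_size, n_embd, n_layer, n_positions = 50257, 768, 12, 1024

wte = vocab_size * n_embd            # token embeddings
wpe = n_positions * n_embd           # position embeddings
layer_norm = 2 * n_embd              # scale + shift

# Attention: fused q/k/v projection plus output projection (weights + biases)
attn = (n_embd * 3 * n_embd + 3 * n_embd) + (n_embd * n_embd + n_embd)
# MLP: expand to 4 * n_embd, then project back (weights + biases)
mlp = (n_embd * 4 * n_embd + 4 * n_embd) + (4 * n_embd * n_embd + n_embd)
block = 2 * layer_norm + attn + mlp

backbone = wte + wpe + n_layer * block + layer_norm   # plus the final layer norm
lm_head = vocab_size * n_embd        # no bias; the same tensor as wte

print(f"per block:        {block:,}")
print(f"backbone:         {backbone:,}")
print(f"lm_head (tied):   {lm_head:,}")
print(f"unique total:     {backbone:,}")
print(f"naive sum:        {backbone + lm_head:,}")
```

If the numbers from your run differ, the checkpoint configuration differs from the one assumed here.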

---

## Exercise 2: Freeze the Backbone and Add a Classification Head [Guided]

Now we build the classification model. The pattern is identical to CNN transfer learning:
1. Keep the pretrained backbone
2. Freeze it (`requires_grad = False`)
3. Add a new classification head

The one genuinely new question: **which hidden state represents the whole sequence?** In a CNN, global average pooling collapses the spatial feature map into one vector. In a causal transformer, we take the **last token's** hidden state — it is the only position that has attended to all previous tokens (because of causal masking).
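
To see why only the last position qualifies, it helps to write out a causal mask for a short sequence. This is an illustrative sketch in plain Python; the helper name `causal_mask` is made up for this example and is not part of GPT-2's code:

```python
def causal_mask(seq_len):
    """Row i marks which positions token i may attend to (1 = visible)."""
    return [[1 if j <= i else 0 for j in range(seq_len)] for i in range(seq_len)]

mask = causal_mask(5)
for i, row in enumerate(mask):
    print(f"position {i}: {row}  sees {sum(row)} of {len(row)} tokens")

# Only the final row is all ones, so only the last position sees everything.
full_context = [i for i, row in enumerate(mask) if all(row)]
print("positions with full context:", full_context)
```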

**Before running, predict:**
- How many parameters will the classification head have for binary classification (2 classes)?
- What fraction of total model parameters will be trainable?
- What shape will `last_hidden` have for a batch of 4 sequences, each 20 tokens long?

In [None]:
class GPT2ForClassification(nn.Module):
    def __init__(self, hf_model, num_classes):
        super().__init__()
        # Keep the pretrained transformer backbone
        self.transformer = hf_model.transformer

        # Replace lm_head with a classification head
        # 768 = GPT-2's hidden dimension (n_embd)
        self.classifier = nn.Linear(768, num_classes)

    def forward(self, input_ids, attention_mask=None):
        # 1. Run through the transformer backbone
        outputs = self.transformer(input_ids, attention_mask=attention_mask)
        hidden_states = outputs.last_hidden_state
        # hidden_states shape: (batch, seq_len, 768)

        # 2. Take the LAST token's hidden state
        # Causal masking means the last token has attended to ALL previous
        # tokens — it is the only position with full sequence context.
        if attention_mask is not None:
            # Find the actual last token (not padding) for each sequence
            # attention_mask is 1 for real tokens, 0 for padding
            seq_lengths = attention_mask.sum(dim=1) - 1  # 0-indexed
            last_hidden = hidden_states[
                torch.arange(hidden_states.size(0), device=hidden_states.device),
                seq_lengths
            ]
        else:
            last_hidden = hidden_states[:, -1, :]
        # last_hidden shape: (batch, 768)

        # 3. Classify
        logits = self.classifier(last_hidden)
        # logits shape: (batch, num_classes)

        return logits


# Create the classification model
num_classes = 2  # Sentiment: positive / negative
model = GPT2ForClassification(hf_model, num_classes).to(device)

# Freeze the entire transformer backbone
for param in model.transformer.parameters():
    param.requires_grad = False

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
frozen_params = total_params - trainable_params

print("=" * 60)
print("GPT2ForClassification")
print("=" * 60)
print(f"Total parameters:     {total_params:>12,}")
print(f"Frozen (backbone):    {frozen_params:>12,}")
print(f"Trainable (head):     {trainable_params:>12,}")
print(f"Trainable fraction:   {trainable_params / total_params:.4%}")

# Verify the classification head dimensions
print(f"\nClassifier weight shape: {model.classifier.weight.shape}")
print(f"Classifier bias shape:   {model.classifier.bias.shape}")
print(f"Head params: {768 * num_classes} (weight) + {num_classes} (bias) = {768 * num_classes + num_classes}")

# Quick forward pass to verify shapes
test_ids = torch.randint(0, 50257, (4, 20)).to(device)
with torch.no_grad():
    test_logits = model(test_ids)

print(f"\nForward pass test:")
print(f"  Input shape:  {test_ids.shape}  (batch=4, seq_len=20)")
print(f"  Output shape: {test_logits.shape}  (batch=4, num_classes=2)")

### What happened

The classification head is tiny: 768 x 2 + 2 = 1,538 trainable parameters out of ~124 million total. You are training about 0.001% of the model. The backbone is a **general text feature extractor** — the pretrained transformer blocks convert token sequences into rich 768-dimensional representations. The classification head learns to map those representations to class predictions.

The key architectural decision: we take `hidden_states[:, -1, :]` — the **last token's** hidden state. Because of causal masking, the last position is the only one that has attended to all previous tokens. Using any earlier position would throw away information. This is the direct consequence of the causal attention pattern you studied earlier.

When sequences have padding, we use the attention mask to find the actual last real token for each sequence, rather than blindly taking position -1 (which would be a padding token).
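
The indexing trick is easier to trust on a tiny tensor. The sketch below uses made-up hidden states, with each position filled with its own index, so the selected position is visible in the output. It assumes right-padding, as in this notebook:

```python
import torch

# Toy batch: 3 sequences, 5 positions, hidden size 2.
# Every position's vector holds its own index, so we can read off what was picked.
batch, seq_len, hidden = 3, 5, 2
hidden_states = (
    torch.arange(seq_len, dtype=torch.float32)
    .view(1, seq_len, 1)
    .expand(batch, seq_len, hidden)
)

# 1 = real token, 0 = padding (right-padded)
attention_mask = torch.tensor([
    [1, 1, 1, 1, 1],   # no padding
    [1, 1, 1, 0, 0],   # 2 pad tokens
    [1, 0, 0, 0, 0],   # 4 pad tokens
])

last_idx = attention_mask.sum(dim=1) - 1                      # 0-indexed last real token
last_hidden = hidden_states[torch.arange(batch), last_idx]    # shape: (batch, hidden)

print("last real index per sequence:", last_idx.tolist())
print("selected vectors:            ", last_hidden.tolist())
print("naive [:, -1, :] would pick: ", hidden_states[:, -1, :].tolist())
```

The naive slice returns position 4 for every row, which is padding in the second and third sequences.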

---

## Exercise 3: Prepare the SST-2 Dataset [Supported]

We need a text classification dataset. SST-2 (Stanford Sentiment Treebank) is a standard benchmark: movie review sentences labeled as positive (1) or negative (0).

Your job: tokenize the sentences with tiktoken (the same BPE tokenizer GPT-2 uses), pad/truncate to a fixed length, and create DataLoaders.

The setup is the same input pipeline as generation — same tokenizer, same token IDs. The only difference is that each example also has a label.

<details>
<summary>Hint: tokenization</summary>

Use `enc.encode(text)` to get token IDs. Truncate to `max_length` if longer. Pad with a pad token (we will use 50256, the `<|endoftext|>` token) if shorter. Build an attention mask: 1 for real tokens, 0 for padding.

</details>

In [None]:
# Load SST-2 dataset
sst2 = load_dataset('glue', 'sst2')
print(f"Train examples: {len(sst2['train'])}")
print(f"Validation examples: {len(sst2['validation'])}")
print(f"\nSample: {sst2['train'][0]}")

# Use a subset for speed: 2,000 training examples; the validation split (872 examples) is used in full
train_data = sst2['train'].shuffle(seed=42).select(range(2000))
val_data = sst2['validation']

# Tokenizer — same tiktoken encoder GPT-2 uses
enc = tiktoken.get_encoding('gpt2')
PAD_TOKEN = enc.eot_token  # 50256, the <|endoftext|> token (enc.encode() rejects special-token text by default)
MAX_LENGTH = 64

print(f"\nPad token ID: {PAD_TOKEN}")
print(f"Max sequence length: {MAX_LENGTH}")

In [None]:
class SST2Dataset(Dataset):
    def __init__(self, hf_dataset, tokenizer, max_length):
        self.input_ids = []
        self.attention_masks = []
        self.labels = []

        for example in hf_dataset:
            text = example['sentence']
            label = example['label']

            # TODO: Tokenize the text
            # 1. Encode with tiktoken: token_ids = tokenizer.encode(text)
            # 2. Truncate to max_length if longer: token_ids = token_ids[:max_length]
            # 3. Create attention_mask: [1] * len(token_ids) + [0] * padding_needed
            # 4. Pad token_ids to max_length with PAD_TOKEN


            self.input_ids.append(torch.tensor(token_ids, dtype=torch.long))
            self.attention_masks.append(torch.tensor(attention_mask, dtype=torch.long))
            self.labels.append(label)

        self.labels = torch.tensor(self.labels, dtype=torch.long)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.input_ids[idx], self.attention_masks[idx], self.labels[idx]


# TODO: Create train_dataset and val_dataset
# train_dataset = SST2Dataset(train_data, enc, MAX_LENGTH)
# val_dataset = SST2Dataset(val_data, enc, MAX_LENGTH)


# TODO: Create DataLoaders
# train_loader: batch_size=32, shuffle=True
# val_loader: batch_size=32, shuffle=False


# Verify
sample_ids, sample_mask, sample_label = train_dataset[0]
print(f"Sample input_ids shape: {sample_ids.shape}")
print(f"Sample attention_mask:  {sample_mask.shape}")
print(f"Sample label: {sample_label}")
print(f"Real tokens: {sample_mask.sum().item()}, padding: {(sample_mask == 0).sum().item()}")
print(f"\nTrain batches: {len(train_loader)}")
print(f"Val batches:   {len(val_loader)}")

<details>
<summary>Solution</summary>

The key insight: this is the **same tokenization pipeline** you used for text generation. Same tiktoken encoder, same BPE token IDs. The only addition is padding (so batches have uniform length) and an attention mask (so the model knows which tokens are real).

```python
            # Tokenize
            token_ids = tokenizer.encode(text)
            # Truncate
            token_ids = token_ids[:max_length]
            # Create attention mask before padding
            padding_needed = max_length - len(token_ids)
            attention_mask = [1] * len(token_ids) + [0] * padding_needed
            # Pad
            token_ids = token_ids + [PAD_TOKEN] * padding_needed
```

Then create datasets and loaders:

```python
train_dataset = SST2Dataset(train_data, enc, MAX_LENGTH)
val_dataset = SST2Dataset(val_data, enc, MAX_LENGTH)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
```

Common mistake: forgetting to create the attention mask. Without it, the classifier falls back to `hidden_states[:, -1, :]`, which is a padding token's representation for every sequence shorter than `max_length`, and that is useless for classification.

</details>
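
A toy run of the same pad-and-mask logic, using made-up token IDs instead of a real tokenizer so the result is easy to inspect. The helper name `pad_and_mask` is hypothetical, introduced only for this sketch:

```python
PAD_TOKEN = 50256  # <|endoftext|>, reused as the pad token in this notebook

def pad_and_mask(token_ids, max_length, pad_token=PAD_TOKEN):
    """Truncate or right-pad to max_length and build the matching attention mask."""
    token_ids = token_ids[:max_length]
    padding_needed = max_length - len(token_ids)
    attention_mask = [1] * len(token_ids) + [0] * padding_needed
    return token_ids + [pad_token] * padding_needed, attention_mask

short_ids, short_mask = pad_and_mask([11, 22, 33], max_length=6)   # needs padding
long_ids, long_mask = pad_and_mask(list(range(10)), max_length=6)  # needs truncation

print(short_ids, short_mask)
print(long_ids, long_mask)
print("index of last real token (short):", sum(short_mask) - 1)
```

Both outputs have exactly `max_length` entries, and the mask sum minus one is the index the classifier reads from.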

---

## Exercise 4: Implement the Finetuning Training Loop [Supported]

The training loop follows the same heartbeat as every loop you have written: forward, loss, backward, step. The differences from pretraining are surface-level:
- Loss compares against **class labels** (not next-token targets)
- Optimizer updates only **classifier parameters** (not the full model)
- `nn.CrossEntropyLoss` with 2 classes instead of 50,257

Fill in the TODOs to complete the training loop and evaluation function.
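
One reference point worth having before you train: a classifier that is undecided between two classes has a loss of ln 2 ≈ 0.693, so a loss that stays near that value means the head has not learned anything yet. The sketch below computes cross-entropy by hand for a single example. The function name is illustrative; `nn.CrossEntropyLoss` applies the same formula and averages over the batch by default:

```python
import math

def cross_entropy(logits, label):
    """Softmax cross-entropy for one example: -log p(label)."""
    max_logit = max(logits)                           # subtract max for numerical stability
    exps = [math.exp(z - max_logit) for z in logits]
    prob = exps[label] / sum(exps)
    return -math.log(prob)

print(f"undecided [0.0, 0.0],  label 1: {cross_entropy([0.0, 0.0], 1):.4f}")
print(f"correct   [-2.0, 2.0], label 1: {cross_entropy([-2.0, 2.0], 1):.4f}")
print(f"wrong     [2.0, -2.0], label 1: {cross_entropy([2.0, -2.0], 1):.4f}")
print(f"ln 2 = {math.log(2):.4f}")
```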

<details>
<summary>Hint: what to pass to the optimizer</summary>

Only the classification head parameters need gradients: `model.classifier.parameters()`. The backbone is frozen, so passing `model.parameters()` to the optimizer would also work (parameters without gradients are skipped at each step), but it hides which parameters you intend to train. Be explicit.

</details>

In [None]:
def evaluate(model, data_loader, device):
    """Evaluate classification accuracy on a dataset."""
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for input_ids, attention_mask, labels in data_loader:
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            labels = labels.to(device)

            # TODO: Forward pass and count correct predictions
            # 1. Get logits from model (pass both input_ids and attention_mask)
            # 2. Get predicted class: _, predicted = torch.max(logits, dim=1)
            # 3. Update correct and total counts


    return correct / total


# TODO: Set up optimizer — only train the classifier head parameters
# optimizer = torch.optim.AdamW(???, lr=1e-3)


criterion = nn.CrossEntropyLoss()
num_epochs = 5

# Training history
history = {'train_loss': [], 'train_acc': [], 'val_acc': []}

print(f"Training for {num_epochs} epochs (frozen backbone, trainable head)")
print(f"Trainable parameters: {sum(p.numel() for p in model.classifier.parameters()):,}")
print("=" * 65)

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for input_ids, attention_mask, labels in train_loader:
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        labels = labels.to(device)

        # TODO: Complete the training step
        # 1. Forward pass: logits = model(input_ids, attention_mask)
        # 2. Compute loss: loss = criterion(logits, labels)
        # 3. Backward pass: optimizer.zero_grad(), loss.backward(), optimizer.step()


        # Track metrics
        running_loss += loss.item() * input_ids.size(0)
        _, predicted = torch.max(logits, dim=1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)

    epoch_loss = running_loss / total
    epoch_acc = correct / total
    val_acc = evaluate(model, val_loader, device)

    history['train_loss'].append(epoch_loss)
    history['train_acc'].append(epoch_acc)
    history['val_acc'].append(val_acc)

    print(f"Epoch {epoch+1}/{num_epochs}  "
          f"Train Loss: {epoch_loss:.4f}  "
          f"Train Acc: {epoch_acc:.1%}  "
          f"Val Acc: {val_acc:.1%}")

print("=" * 65)
print(f"\nFinal validation accuracy (frozen backbone): {history['val_acc'][-1]:.1%}")

<details>
<summary>Solution</summary>

The training loop is the same heartbeat — forward, loss, backward, step. The only difference from pretraining: the optimizer only receives the classifier head parameters, and loss is against class labels.

Evaluate function:
```python
            logits = model(input_ids, attention_mask)
            _, predicted = torch.max(logits, dim=1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
```

Optimizer:
```python
optimizer = torch.optim.AdamW(model.classifier.parameters(), lr=1e-3)
```

Training step:
```python
        logits = model(input_ids, attention_mask)
        loss = criterion(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```

Notice we pass `model.classifier.parameters()` to the optimizer, not `model.parameters()`. The backbone is frozen, so its parameters never receive gradients and the optimizer would skip them anyway; being explicit makes the intent clear and keeps the optimizer's parameter list small.

</details>
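
Freezing can be verified directly: frozen parameters get no gradient and do not move when the optimizer steps. A minimal sketch on a toy two-layer model standing in for the backbone and the head (the layer sizes and names are arbitrary assumptions, not GPT-2's):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

backbone = nn.Linear(8, 4)    # stand-in for the pretrained transformer
head = nn.Linear(4, 2)        # stand-in for the classification head

for p in backbone.parameters():
    p.requires_grad = False   # freeze the backbone

optimizer = torch.optim.AdamW(head.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# Snapshot weights so we can compare after one update
backbone_before = backbone.weight.clone()
head_before = head.weight.clone()

x = torch.randn(16, 8)
y = torch.randint(0, 2, (16,))

loss = criterion(head(backbone(x)), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

print("backbone grad is None:", backbone.weight.grad is None)
print("backbone unchanged:   ", torch.equal(backbone.weight, backbone_before))
print("head changed:         ", not torch.equal(head.weight, head_before))
```

The same three checks can be run against `model.transformer` and `model.classifier` after one training batch.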

In [None]:
# Plot training curves
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

epochs_range = range(1, num_epochs + 1)

ax1.plot(epochs_range, history['train_loss'], 'o-', linewidth=2)
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('Training Loss (Frozen Backbone)')
ax1.grid(alpha=0.3)

ax2.plot(epochs_range, history['train_acc'], 'o-', linewidth=2, label='Train')
ax2.plot(epochs_range, history['val_acc'], 's-', linewidth=2, label='Val')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy')
ax2.set_title('Accuracy (Frozen Backbone)')
ax2.legend()
ax2.grid(alpha=0.3)

plt.tight_layout()
plt.show()

---

## Exercise 5: Unfreeze Layers, Compare, and Explore [Independent]

The frozen backbone gave you a baseline accuracy. Now explore what happens when you partially unfreeze the model.

Your tasks:

1. **Record the frozen-backbone accuracy** from Exercise 4.
2. **Create a fresh model** (reload weights — important so the comparison is fair).
3. **Unfreeze the last 2 transformer blocks** while keeping the rest frozen.
4. **Use differential learning rates**: lower LR for the unfrozen backbone layers (1e-5), higher LR for the classification head (1e-3). This is the same strategy you used with ResNet.
5. **Train for 3 epochs** and compare validation accuracy.
6. **Generate text** with both models (frozen and partially-unfrozen) to check for catastrophic forgetting. Use a prompt like `"The movie was"` and generate 20 tokens.

Think about:
- Does unfreezing improve accuracy? By how much?
- Does the partially-unfrozen model still generate coherent text?
- What is the tradeoff between frozen and unfrozen training?
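
Differential learning rates from task 4 are expressed in PyTorch as optimizer parameter groups, each with its own `lr`. The sketch below shows the pattern on toy layers (the sizes are arbitrary and the names are stand-ins, not the GPT-2 modules) and prints each group so you can confirm the rates landed where you intended:

```python
import torch
import torch.nn as nn

# Toy stand-ins: sizes are arbitrary, only the grouping pattern matters.
late_block = nn.Linear(8, 8)    # plays the role of an unfrozen transformer block
classifier = nn.Linear(8, 2)    # plays the role of the new head

optimizer = torch.optim.AdamW([
    {'params': late_block.parameters(), 'lr': 1e-5},   # pretrained weights: small steps
    {'params': classifier.parameters(), 'lr': 1e-3},   # new head: larger steps
])

for i, group in enumerate(optimizer.param_groups):
    n_params = sum(p.numel() for p in group['params'])
    print(f"group {i}: lr={group['lr']:.0e}, params={n_params}")
```

Printing `optimizer.param_groups` like this is a cheap way to catch a group that was accidentally left at the default learning rate.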

In [None]:
# Your code here.
#
# Suggested structure:
#
# 1. Store the frozen accuracy:
#    frozen_acc = history['val_acc'][-1]
#
# 2. Create a fresh model:
#    hf_model_fresh = GPT2LMHeadModel.from_pretrained('gpt2')
#    model_unfrozen = GPT2ForClassification(hf_model_fresh, num_classes=2).to(device)
#
# 3. Freeze everything first, then unfreeze last 2 blocks:
#    for param in model_unfrozen.transformer.parameters():
#        param.requires_grad = False
#    for block_idx in [10, 11]:  # last 2 of 12 blocks (0-indexed)
#        for param in model_unfrozen.transformer.h[block_idx].parameters():
#            param.requires_grad = True
#
# 4. Set up optimizer with differential LR (parameter groups):
#    optimizer = torch.optim.AdamW([
#        {'params': model_unfrozen.transformer.h[10].parameters(), 'lr': 1e-5},
#        {'params': model_unfrozen.transformer.h[11].parameters(), 'lr': 1e-5},
#        {'params': model_unfrozen.classifier.parameters(), 'lr': 1e-3},
#    ])
#
# 5. Train for 3 epochs (same loop structure as Exercise 4)
#
# 6. Compare accuracies
#
# 7. Generate text with both models to check for catastrophic forgetting:
#    def generate_text(transformer_backbone, prompt, max_new_tokens=20):
#        """Generate text using just the transformer backbone."""
#        input_ids = torch.tensor([enc.encode(prompt)]).to(device)
#        with torch.no_grad():
#            for _ in range(max_new_tokens):
#                outputs = transformer_backbone(input_ids)
#                # Project the last hidden state through the transposed embedding
#                # matrix. lm_head is weight-tied to wte and has no bias, so this
#                # reproduces the original lm_head projection rather than approximating it.
#                hidden = outputs.last_hidden_state[:, -1, :]
#                logits = hidden @ transformer_backbone.wte.weight.T
#                next_token = torch.argmax(logits, dim=-1, keepdim=True)
#                input_ids = torch.cat([input_ids, next_token], dim=1)
#        return enc.decode(input_ids[0].tolist())



<details>
<summary>Solution</summary>

The key insight is that partial unfreezing trades safety for potential accuracy. With a frozen backbone, you cannot overfit the backbone — there is zero risk of catastrophic forgetting because the backbone weights never change. With partial unfreezing, you allow the last few layers to adapt their representations to your task, which can help, but you need a lower learning rate (differential LR) to avoid destroying the pretrained features.

```python
# 1. Record frozen accuracy
frozen_acc = history['val_acc'][-1]

# 2. Fresh model
hf_model_fresh = GPT2LMHeadModel.from_pretrained('gpt2')
model_unfrozen = GPT2ForClassification(hf_model_fresh, num_classes=2).to(device)

# 3. Freeze all, then unfreeze last 2 blocks
for param in model_unfrozen.transformer.parameters():
    param.requires_grad = False
for block_idx in [10, 11]:
    for param in model_unfrozen.transformer.h[block_idx].parameters():
        param.requires_grad = True

unfrozen_trainable = sum(p.numel() for p in model_unfrozen.parameters() if p.requires_grad)
print(f"Trainable parameters (unfrozen): {unfrozen_trainable:,}")

# 4. Differential learning rates
optimizer_unfrozen = torch.optim.AdamW([
    {'params': model_unfrozen.transformer.h[10].parameters(), 'lr': 1e-5},
    {'params': model_unfrozen.transformer.h[11].parameters(), 'lr': 1e-5},
    {'params': model_unfrozen.classifier.parameters(), 'lr': 1e-3},
])

criterion = nn.CrossEntropyLoss()
unfrozen_history = {'train_loss': [], 'val_acc': []}

# 5. Train
for epoch in range(3):
    model_unfrozen.train()
    running_loss = 0.0
    total = 0

    for input_ids, attention_mask, labels in train_loader:
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        labels = labels.to(device)

        logits = model_unfrozen(input_ids, attention_mask)
        loss = criterion(logits, labels)

        optimizer_unfrozen.zero_grad()
        loss.backward()
        optimizer_unfrozen.step()

        running_loss += loss.item() * input_ids.size(0)
        total += input_ids.size(0)

    val_acc = evaluate(model_unfrozen, val_loader, device)
    unfrozen_history['train_loss'].append(running_loss / total)
    unfrozen_history['val_acc'].append(val_acc)
    print(f"Epoch {epoch+1}/3  Loss: {running_loss/total:.4f}  Val Acc: {val_acc:.1%}")

# 6. Compare
unfrozen_acc = unfrozen_history['val_acc'][-1]
print(f"\n{'='*40}")
print(f"Frozen backbone:  {frozen_acc:.1%}")
print(f"Unfrozen (last 2): {unfrozen_acc:.1%}")
print(f"Difference: {unfrozen_acc - frozen_acc:+.1%}")

# 7. Generate text to check catastrophic forgetting
def generate_text(transformer_backbone, prompt, max_new_tokens=20):
    transformer_backbone.eval()
    input_ids = torch.tensor([enc.encode(prompt)]).to(device)
    with torch.no_grad():
        for _ in range(max_new_tokens):
            outputs = transformer_backbone(input_ids)
            hidden = outputs.last_hidden_state[:, -1, :]
            logits = hidden @ transformer_backbone.wte.weight.T
            next_token = torch.argmax(logits, dim=-1, keepdim=True)
            input_ids = torch.cat([input_ids, next_token], dim=1)
    return enc.decode(input_ids[0].tolist())

prompt = "The movie was"
print(f"\nGeneration check (prompt: '{prompt}')")
print(f"  Frozen model:   {generate_text(model.transformer, prompt)}")
print(f"  Unfrozen model: {generate_text(model_unfrozen.transformer, prompt)}")
```

You should observe that the frozen model generates text identically to vanilla GPT-2 (no forgetting at all), while the partially-unfrozen model may show slight differences but should still be largely coherent. With only 2 blocks unfrozen and a low LR, catastrophic forgetting is minimal. If you unfroze all blocks with a high LR, the text generation would degrade noticeably.

</details>
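
The generation helper in the solution projects hidden states through `wte.weight.T` instead of calling `lm_head`. With tied weights and a bias-free output layer, which is how the HuggingFace GPT-2 checkpoint is set up as far as the weight tying check in Exercise 1 shows, the two give the same logits. A toy-sized sketch with made-up dimensions:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size, n_embd = 50, 8     # toy sizes, not GPT-2's

wte = nn.Embedding(vocab_size, n_embd)
lm_head = nn.Linear(n_embd, vocab_size, bias=False)
lm_head.weight = wte.weight    # weight tying: both modules share one tensor

hidden = torch.randn(3, n_embd)   # stand-in for last-token hidden states

with torch.no_grad():
    logits_via_head = lm_head(hidden)
    logits_via_wte = hidden @ wte.weight.T

print("same tensor: ", lm_head.weight.data_ptr() == wte.weight.data_ptr())
print("logits match:", torch.allclose(logits_via_head, logits_via_wte))
```

This only holds while `wte` itself stays frozen, as it does in both models in this exercise.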

---

## Key Takeaways

1. **A pretrained transformer is a text feature extractor.** Add a classification head, freeze the backbone, train the head. The same transfer learning pattern as CNNs — only the feature extractor changed.

2. **Use the last token's hidden state as the sequence representation.** Causal masking means the last position has attended to all previous tokens — it is the only position with full sequence context. The architecture dictates this choice.

3. **The classification head is tiny.** For binary classification: 768 x 2 + 2 = 1,538 trainable parameters out of ~124 million total. You are training about 0.001% of the model.

4. **The training loop is the same heartbeat.** Forward, loss, backward, step — same structure as every loop since Series 1. The only differences: loss is against class labels, optimizer updates only head parameters.

5. **Start frozen, unfreeze if needed.** A frozen backbone is safe (the backbone cannot forget, and a 1,538-parameter head has little capacity to overfit), fast, and often good enough. Partial unfreezing with differential learning rates is the middle ground, the same strategy as CNN transfer learning.