# BERT Head Replacement: Adapting Pre-trained Models for Custom Tasks

In this notebook you will learn about head replacement: a technique for adapting a pre-trained BERT model to a specific task by swapping its final task layer (the head) while preserving the pre-trained representations.
You'll see how BERT works, why head replacement is effective, and how to implement it with practical examples.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/hyperskill-content/hyperskill-ml-notebooks/blob/main/Attention_developments/bert_head_replacement.ipynb)

## 🚀 Prerequisites

Make sure you're comfortable with the topics below before starting this notebook:

| # | Topic (clickable links)|
|---|-------|
| 1 | **[Transformers in NLP](https://hyperskill.org/learn/step/30103)**  |
| 2 | **[Tokenization](https://hyperskill.org/learn/step/51949)** |
| 3 | **[BERT architecture](https://lena-voita.github.io/nlp_course/transfer_learning.html#bert)** |


*If you're new to any item above, review it quickly, then dive back in here.*

For this notebook, enable the hosted GPU runtime in Colab.

## 📑 Table of Contents  
  
1. [What's the Big Idea Behind Head Replacement?](#sec-idea)  
2. [BERT Architecture Primer](#sec-bert-primer)  
3. [Base Model vs Head: The Conceptual Split](#sec-base-vs-head)
4. [Head Replacement vs Fine-tuning: When to Use What?](#sec-comparison)  
5. [Requirements for Head Replacement](#sec-requirements)
6. [Hands-on: Classification Head Replacement](#sec-classification)  
7. [Question Answering Head Implementation](#sec-qa)
8. [Token Classification Head Implementation](#sec-token-class)  
9. [Performance Comparison & Best Practices](#sec-best-practices)
10. [Practice Exercises](#sec-exercises)

In [None]:
import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer
import pandas as pd
from datetime import datetime


# Check if we have GPU available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"PyTorch version: {torch.__version__}")
print(f"Device: {device}")
print(f"Current date: {datetime.now().strftime('%Y-%m-%d')}")
print("Transformers library loaded!")
print("Ready to explore BERT head replacement!")

## 1. What's the Big Idea Behind Head Replacement? <a id="sec-idea"></a>

**Head replacement** adapts a pre-trained BERT model by keeping its layers frozen and training only a new, task-specific head. It contrasts with traditional fine-tuning as follows:

**Traditional fine-tuning**: Updates all model parameters on task-specific data, which can alter the pre-trained representations.    
**Head replacement**: Leaves the pre-trained layers unchanged and trains only the new final layer(s).

### Advantages of Head Replacement

| **Advantage** | **Explanation** |
|---------------|-----------------|
| **Speed** | Gradients are computed only for the head, so each training step is cheaper |
| **Data Efficiency** | The head has few parameters, so it can be fit with fewer task-specific examples |
| **Stability** | Pre-trained language representations are left untouched |
| **Multiple Tasks** | Same base model can support multiple task-specific heads |

## 2. BERT Architecture Primer <a id="sec-bert-primer"></a>
**BERT = Bidirectional Encoder Representations from Transformers**

**Bidirectional**: Processes text in both directions simultaneously, considering context from both left and right         
**Encoder**: Uses the encoder component of the Transformer architecture to create text representations    
**Representations**: Generates numerical embeddings that capture semantic meaning          
**Transformers**: Built on the attention-based Transformer architecture   

### BERT's Architecture in Simple Terms

```
Input Text: "The cat sat on the mat"
     ↓
[Tokenization] → ["[CLS]", "the", "cat", "sat", "on", "the", "mat", "[SEP]"]  (bert-base-uncased lowercases the input)
     ↓
[Embedding Layer] → Convert tokens to vectors
     ↓
[12 Transformer Layers] → Deep understanding through attention
     ↓
[Base Model Output] → Rich representations for each token
     ↓
[TASK HEAD] ← This is what we replace!
     ↓
[Final Prediction]
```

### Key Components:

1. **[CLS] Token**: Special token added at the start of every input; its final hidden state is commonly used as a summary of the whole sequence
2. **[SEP] Token**: Separates the two segments of a sentence pair and marks the end of the input
3. **Transformer Layers**: 12 stacked encoder layers in BERT-base, each refining the token representations
4. **Hidden Size**: 768 dimensions per token representation
5. **Attention Heads**: 12 heads per layer, each of which can attend to different relationships between tokens

![BERT Architecture](https://towardsdatascience.com/wp-content/uploads/2024/05/1Qww2aaIdqrWVeNmo3AS0ZQ.png)
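
As a rough check on the numbers above, the size of BERT-base can be worked out by hand. The sketch below assumes the published `bert-base-uncased` configuration (vocabulary of 30,522 tokens, 512 positions, 2 segment types, feed-forward size 3,072); those values are not stated above, and the variable names are purely illustrative.

```python
# Back-of-the-envelope parameter count for the BERT-base encoder (no task head).
# Assumed configuration values, taken from the bert-base-uncased config:
hidden = 768          # hidden size
layers = 12           # Transformer layers
vocab = 30522         # WordPiece vocabulary size
max_positions = 512   # learned position embeddings
segments = 2          # token type (segment) embeddings
ffn = 3072            # feed-forward inner size (4 * hidden)

layer_norm = 2 * hidden  # one scale and one shift vector

# Token, position, and segment embedding tables plus one LayerNorm
embeddings = (vocab + max_positions + segments) * hidden + layer_norm

# Query, key, value, and output projections (weights + biases)
attention = 4 * (hidden * hidden + hidden)
# Two feed-forward projections (weights + biases)
feed_forward = (hidden * ffn + ffn) + (ffn * hidden + hidden)
per_layer = attention + feed_forward + 2 * layer_norm

# Dense layer that the library applies to the [CLS] vector
pooler = hidden * hidden + hidden

total = embeddings + layers * per_layer + pooler
print(f"Embeddings: {embeddings:,}")
print(f"Per layer:  {per_layer:,}")
print(f"Total:      {total:,}")
```

The number of attention heads does not appear in the count: the heads split the 768-dimensional projections between them rather than adding parameters.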


In [None]:
# Let's load a BERT model and explore its structure
model_name = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(model_name)
bert_model = BertModel.from_pretrained(model_name)

print("BERT Model Structure:")
print(f"Model name: {model_name}")
print(f"Hidden size: {bert_model.config.hidden_size}")
print(f"Number of layers: {bert_model.config.num_hidden_layers}")
print(f"Number of attention heads: {bert_model.config.num_attention_heads}")
print(f"Vocabulary size: {bert_model.config.vocab_size}")

# Let's see the actual model architecture
print("\nBERT Architecture Overview:")
print(bert_model)

In [None]:
# Let's see how BERT processes text
sample_text = "The cat sat on the mat"
print(f"Original text: '{sample_text}'")

# Tokenize the text
tokens = tokenizer.tokenize(sample_text)
print(f"Tokens: {tokens}")

# Add special tokens and convert to IDs
encoded = tokenizer.encode(sample_text, add_special_tokens=True)
print(f"Token IDs: {encoded}")

# Convert back to tokens to see special tokens
decoded_tokens = tokenizer.convert_ids_to_tokens(encoded)
print(f"Tokens with special tokens: {decoded_tokens}")

# Get the actual text representations
input_ids = torch.tensor([encoded])
with torch.no_grad():
    outputs = bert_model(input_ids)

# The output contains representations for each token
last_hidden_states = outputs.last_hidden_state
print(f"\nOutput shape: {last_hidden_states.shape}")
print(f"Shape explanation: [batch_size=1, sequence_length={len(encoded)}, hidden_size=768]")

## 3. Base Model vs Head: The Conceptual Split <a id="sec-base-vs-head"></a>

BERT models can be conceptually divided into two components: the base model and the task head.

### The Base Model

The **base model** contains the core language understanding components:

- All Transformer layers (12 in BERT-base)
- Attention mechanisms that capture relationships between words
- Pre-trained language representations learned from large text corpora

During head replacement, the base model parameters are typically frozen to preserve the learned representations.

### The Task Head

The **task head** is a small neural network (usually 1-2 layers) that processes the base model's output for specific tasks:

| **Task** | **Head Type** | **Function** |
|----------|---------------|--------------|
| **Text Classification** | Linear layer | Maps the [CLS] representation to class logits |
| **Question Answering** | Linear layer with two outputs per token | Scores each token as the start or end of the answer span |
| **Token Classification** | Linear layer applied to every token | Labels individual tokens (NER, POS tagging) |
| **Similarity** | Cosine similarity | Compares sentence embeddings |

### Why This Architecture Works

The separation allows the base model to provide universal language understanding while the head handles task-specific predictions. This design enables efficient adaptation to new tasks without modifying the pre-trained representations.
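
To give a sense of how small these heads are, the following sketch counts the parameters of a single linear head for each task. It assumes a hidden size of 768; the helper `linear_params` and the label counts (3 classes, 9 token labels) are hypothetical choices for illustration.

```python
# Parameter counts for single-layer heads on top of a 768-dimensional encoder.
hidden_size = 768

def linear_params(in_features: int, out_features: int) -> int:
    """Weights plus biases of one fully connected layer."""
    return in_features * out_features + out_features

heads = {
    "text classification (3 classes)": linear_params(hidden_size, 3),
    "question answering (start + end)": linear_params(hidden_size, 2),
    "token classification (9 labels)": linear_params(hidden_size, 9),
}

for name, count in heads.items():
    print(f"{name}: {count:,} parameters")
```

Each head has a few thousand parameters, which is why training it alone is cheap compared with updating the whole encoder.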


In [None]:
# Let's demonstrate the base model vs head concept
class CustomClassificationHead(nn.Module):
    """
    A simple classification head that we can attach to BERT
    """
    def __init__(self, hidden_size, num_classes, dropout_prob=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout_prob)
        self.classifier = nn.Linear(hidden_size, num_classes)

    def forward(self, bert_output):
        # Use the [CLS] token representation (first token)
        cls_representation = bert_output.last_hidden_state[:, 0, :]  # [batch_size, hidden_size]
        cls_representation = self.dropout(cls_representation)
        logits = self.classifier(cls_representation)
        return logits

class BERTWithCustomHead(nn.Module):
    """
    BERT base model + our custom head
    """
    def __init__(self, model_name, num_classes):
        super().__init__()
        # Base model (we'll freeze this)
        self.bert = BertModel.from_pretrained(model_name)

        # Custom head (we'll train this)
        self.head = CustomClassificationHead(
            hidden_size=self.bert.config.hidden_size,
            num_classes=num_classes
        )

        # Freeze the base model
        for param in self.bert.parameters():
            param.requires_grad = False

        print(f"🔒 Frozen BERT parameters: {sum(p.numel() for p in self.bert.parameters())}")
        print(f"🔓 Trainable head parameters: {sum(p.numel() for p in self.head.parameters())}")

    def forward(self, input_ids, attention_mask=None):
        # Get representations from base model
        bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask)

        # Apply our custom head
        logits = self.head(bert_output)
        return logits

# Let's create an example model for 3-class classification
model_with_head = BERTWithCustomHead(model_name="bert-base-uncased", num_classes=3)
print("\nCreated BERT with custom classification head!")
print(f"Input size to head: {model_with_head.bert.config.hidden_size}")
print(f"Output classes: {model_with_head.head.classifier.out_features}")

## 4. Head Replacement vs Fine-tuning: When to Use What? <a id="sec-comparison"></a>

Head replacement and fine-tuning are two approaches for adapting BERT to specific tasks, each with distinct advantages.

### Head Replacement

**Process**: Train only the new task head while keeping the base model parameters frozen.

**Best suited for**:
- Small datasets (< 10k samples)
- Rapid prototyping
- Multiple tasks using the same base model
- Preserving general language representations

### Fine-tuning

**Process**: Train the entire model including both base layers and task head.

**Best suited for**:
- Large datasets (> 50k samples)
- Domain-specific tasks (medical, legal texts)
- Maximum performance requirements
- Tasks significantly different from BERT's pre-training

### Comparison

| **Aspect** | **Head Replacement** | **Full Fine-tuning** |
|------------|---------------------|---------------------|
| **Training Speed** | Fast (often minutes) | Slower (often hours) |
| **Data Required** | Small (roughly 1k-10k samples) | Large (10k+, ideally 50k+) |
| **Memory Usage** | Low | High |
| **Risk of Overfitting** | Low | Higher |
| **Performance** | Good | Best (with sufficient data) |
| **Multiple Tasks** | Easy to manage | Requires separate models |
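
The memory row can be made more concrete with some arithmetic. The sketch below is an estimate under stated assumptions, not a measurement: fp32 storage, an Adam-style optimizer that keeps two extra buffers per trainable parameter, activations ignored, and an encoder of 109,482,240 parameters (the `bert-base-uncased` encoder). That figure and the helper `training_bytes` are assumptions introduced here.

```python
BYTES_FP32 = 4
encoder_params = 109_482_240   # assumed size of the bert-base-uncased encoder
head_params = 768 * 3 + 3      # 3-class linear head

def training_bytes(trainable: int, frozen: int) -> int:
    """Estimate training memory, excluding activations.

    Trainable parameters need weights, gradients, and two Adam moment buffers
    (4 values each); frozen parameters need only their weights.
    """
    return trainable * 4 * BYTES_FP32 + frozen * BYTES_FP32

head_only = training_bytes(trainable=head_params, frozen=encoder_params)
full = training_bytes(trainable=encoder_params + head_params, frozen=0)

print(f"Head replacement: ~{head_only / 1e6:.0f} MB")
print(f"Full fine-tuning: ~{full / 1e6:.0f} MB")
print(f"Ratio: {full / head_only:.1f}x")
```

Under these assumptions the frozen encoder still dominates the head replacement footprint, since its weights must stay in memory for the forward pass. The saving comes from not storing gradients and optimizer state for it.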

### Selection Guidelines

**Use Head Replacement when**:
- Training data is limited
- Quick experimentation is needed
- Working with standard NLP tasks
- Managing multiple task-specific models
- Computational resources are constrained

**Use Fine-tuning when**:
- Abundant training data is available
- Domain is highly specialized
- Maximum performance is required
- Sufficient computational resources are available

In [None]:
# Let's create a practical comparison showing the difference
import torch.optim as optim

def count_trainable_parameters(model):
    """Count how many parameters will be updated during training"""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def create_head_replacement_model(num_classes):
    """Model with frozen BERT + trainable head"""
    model = BERTWithCustomHead("bert-base-uncased", num_classes)
    return model

def create_full_finetuning_model(num_classes):
    """Model where everything is trainable"""
    model = BERTWithCustomHead("bert-base-uncased", num_classes)

    # Unfreeze all BERT parameters
    for param in model.bert.parameters():
        param.requires_grad = True

    return model

# Compare the two approaches
print("🔍 Comparing Head Replacement vs Fine-tuning")
print("=" * 50)

head_model = create_head_replacement_model(num_classes=3)
finetune_model = create_full_finetuning_model(num_classes=3)

head_params = count_trainable_parameters(head_model)
finetune_params = count_trainable_parameters(finetune_model)

print(f"📊 Head Replacement - Trainable parameters: {head_params:,}")
print(f"📊 Full Fine-tuning - Trainable parameters: {finetune_params:,}")
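print(f"⚡ Parameter ratio: {finetune_params / head_params:,.1f}x more parameters to update")

# Rough size of the trainable weights alone (fp32 = 4 bytes per parameter).
# Gradients, optimizer state, and activations are excluded, as are the frozen
# BERT weights that the head replacement model still holds in memory.
print(f"\n💾 Approximate memory for trainable weights:")
print(f"Head Replacement: ~{head_params * 4 / 1e6:.3f} MB")
print(f"Full Fine-tuning: ~{finetune_params * 4 / 1e6:.1f} MB")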

## 5. Requirements for Head Replacement <a id="sec-requirements"></a>

### Data Requirements

Data should be structured according to the target task:

| **Task Type** | **Input Format** | **Output Format** | **Example** |
|---------------|------------------|-------------------|-------------|
| **Text Classification** | Text | Class label | "I love this movie" → "positive" |
| **Question Answering** | Context + Question | Start & End positions | "Where is Paris?" → position 15-20 |
| **Token Classification** | Text | Label per token | "John lives in NYC" → ["PERSON", "O", "O", "CITY"] |

### Technical Requirements

1. **Base Model**: Select appropriate BERT variant (bert-base-uncased, bert-large, etc.)
2. **Tokenizer**: Must match the chosen base model
3. **Head Architecture**: Design suited to the specific task
4. **Dataset**: As a rule of thumb, at least 1k samples; 5k+ per class is preferable
5. **Evaluation Metrics**: Task-appropriate performance measures

### Data Preparation

1. **Data Cleaning**: Remove duplicates and handle missing values
2. **Tokenization**: Apply tokenizer consistent with base model
3. **Data Splitting**: Partition into train/validation/test sets (70%/15%/15%)
4. **Class Balance**: Verify reasonable distribution across classes
5. **PyTorch Formatting**: Create appropriate DataLoader objects
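
The splitting and class-balance steps can be combined in a stratified split, which divides each class separately so that all three sets keep similar class proportions. The following is a minimal standard-library sketch; the function `stratified_split` and the toy labels are illustrative assumptions, and a library routine such as scikit-learn's `train_test_split` with `stratify` is a common alternative.

```python
import random
from collections import defaultdict

def stratified_split(labels, train_frac=0.70, val_frac=0.15, seed=0):
    """Return (train, val, test) index lists with similar class proportions."""
    # Group example indices by class label
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    rng = random.Random(seed)
    train, val, test = [], [], []
    for indices in by_class.values():
        rng.shuffle(indices)
        n_train = round(len(indices) * train_frac)
        n_val = round(len(indices) * val_frac)
        train += indices[:n_train]
        val += indices[n_train:n_train + n_val]
        test += indices[n_train + n_val:]  # whatever remains goes to the test set
    return train, val, test

# Toy labels: 20 examples for each of 3 classes
labels = [0] * 20 + [1] * 20 + [2] * 20
train_idx, val_idx, test_idx = stratified_split(labels)
print(len(train_idx), len(val_idx), len(test_idx))
```

The returned index lists can then be used to select texts and labels for three separate `Dataset` objects.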

In [None]:
# Let's create a comprehensive data preparation example
from torch.utils.data import Dataset, DataLoader

class TextClassificationDataset(Dataset):
    """
    Dataset class for text classification with BERT
    """
    def __init__(self, texts, labels, tokenizer, max_length=128):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = str(self.texts[idx])
        label = self.labels[idx]

        # Tokenize the text
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'label': torch.tensor(label, dtype=torch.long)
        }

# Let's create some sample data to demonstrate
sample_data = {
    'texts': [
        "I absolutely love this product! It's amazing!",
        "This is the worst thing I've ever bought.",
        "It's okay, nothing special but does the job.",
        "Fantastic quality and great customer service!",
        "Terrible experience, would not recommend.",
        "Pretty good, meets my expectations.",
        "Outstanding! Exceeded all my expectations!",
        "Not worth the money, very disappointed.",
        "Average product, nothing to complain about."
    ],
    'labels': [2, 0, 1, 2, 0, 1, 2, 0, 1]  # 0=negative, 1=neutral, 2=positive
}

# Create label mapping for better understanding
label_map = {0: "negative", 1: "neutral", 2: "positive"}
print("Sample Data Overview:")
print(f"Total samples: {len(sample_data['texts'])}")

# Show class distribution
import collections
label_counts = collections.Counter(sample_data['labels'])
for label, count in label_counts.items():
    print(f"  {label_map[label]}: {count} samples")

# Create dataset and dataloader
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
dataset = TextClassificationDataset(
    texts=sample_data['texts'],
    labels=sample_data['labels'],
    tokenizer=tokenizer,
    max_length=64
)

# Create a DataLoader
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

print("\nSample batch from DataLoader:")
for batch in dataloader:
    print(f"Input IDs shape: {batch['input_ids'].shape}")
    print(f"Attention mask shape: {batch['attention_mask'].shape}")
    print(f"Labels shape: {batch['label'].shape}")
    print(f"Sample input IDs: {batch['input_ids'][0][:10]}...")  # First 10 tokens
    break

## 6. Hands-on: Classification Head Replacement <a id="sec-classification"></a>

Now let's implement our first head replacement! We'll create a sentiment classification model using BERT with a custom head.

### The Plan

1. **Load pre-trained BERT** as our base model
2. **Create a classification head** (simple linear layer)
3. **Freeze BERT weights** (only train the head)
4. **Train on our sentiment data**
5. **Evaluate performance**

This is the most common type of head replacement and a good starting point.

In [None]:
# Complete implementation of BERT with classification head
import torch.nn.functional as F
from tqdm import tqdm

class SentimentClassifier(nn.Module):
    """
    BERT-based sentiment classifier with replaceable head
    """
    def __init__(self, model_name='bert-base-uncased', num_classes=3, dropout=0.3):
        super().__init__()

        # Load pre-trained BERT
        self.bert = BertModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(dropout)

        # Classification head - this is what we'll train!
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_classes)

        # Freeze BERT parameters
        for param in self.bert.parameters():
            param.requires_grad = False

        # Initialize the classifier layer
        nn.init.normal_(self.classifier.weight, std=0.02)
        nn.init.zeros_(self.classifier.bias)

        print(f"Created sentiment classifier:")
        print(f"Frozen BERT parameters: {sum(p.numel() for p in self.bert.parameters()):,}")
        print(f"Trainable head parameters: {sum(p.numel() for p in self.classifier.parameters()):,}")

    def forward(self, input_ids, attention_mask=None):
        # Get BERT outputs
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask
        )

        # Use [CLS] token representation
        pooled_output = outputs.last_hidden_state[:, 0]  # [batch_size, hidden_size]
        pooled_output = self.dropout(pooled_output)

        # Apply classification head
        logits = self.classifier(pooled_output)
        return logits

# Create our model
model = SentimentClassifier(num_classes=3)
model = model.to(device)

# Create optimizer - only for the head parameters!
optimizer = optim.Adam([
    {'params': model.classifier.parameters(), 'lr': 2e-3}
])

loss_fn = nn.CrossEntropyLoss()

print("\nModel created and ready for training!")
print(f"Running on: {device}")

In [None]:
# Training function for our head replacement model
def train_epoch(model, dataloader, optimizer, loss_fn, device):
    """Train the model for one epoch"""
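    model.train()
    # Keep the frozen encoder in eval mode so its internal dropout stays off;
    # only the head's dropout should be active during training.
    model.bert.eval()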
    total_loss = 0
    correct_predictions = 0
    total_predictions = 0

    for batch in tqdm(dataloader, desc="Training"):
        # Move batch to device
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)

        # Forward pass
        optimizer.zero_grad()
        logits = model(input_ids, attention_mask)

        # Calculate loss
        loss = loss_fn(logits, labels)

        # Backward pass
        loss.backward()
        optimizer.step()

        # Track metrics
        total_loss += loss.item()
        predictions = torch.argmax(logits, dim=1)
        correct_predictions += (predictions == labels).sum().item()
        total_predictions += labels.size(0)

    avg_loss = total_loss / len(dataloader)
    accuracy = correct_predictions / total_predictions

    return avg_loss, accuracy

def evaluate_model(model, dataloader, loss_fn, device):
    """Evaluate the model"""
    model.eval()
    total_loss = 0
    correct_predictions = 0
    total_predictions = 0

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)

            logits = model(input_ids, attention_mask)
            loss = loss_fn(logits, labels)

            total_loss += loss.item()
            predictions = torch.argmax(logits, dim=1)
            correct_predictions += (predictions == labels).sum().item()
            total_predictions += labels.size(0)

    avg_loss = total_loss / len(dataloader)
    accuracy = correct_predictions / total_predictions

    return avg_loss, accuracy

# Let's do a quick training demo
print("Starting training demonstration...")

# Since we have limited data, let's do just a few epochs
num_epochs = 3
training_history = []

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch + 1}/{num_epochs}")

    # Train
    train_loss, train_acc = train_epoch(model, dataloader, optimizer, loss_fn, device)

    # Evaluate (using same data for demo - normally you'd use separate validation set)
    val_loss, val_acc = evaluate_model(model, dataloader, loss_fn, device)

    training_history.append({
        'epoch': epoch + 1,
        'train_loss': train_loss,
        'train_acc': train_acc,
        'val_loss': val_loss,
        'val_acc': val_acc
    })

    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
    print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

print("\nTraining completed!")

In [None]:
# Let's test our trained model with some examples
def predict_sentiment(model, tokenizer, text, device, label_map):
    """Make prediction on a single text"""
    model.eval()

    # Tokenize
    encoding = tokenizer(
        text,
        truncation=True,
        padding='max_length',
        max_length=64,
        return_tensors='pt'
    )

    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    with torch.no_grad():
        logits = model(input_ids, attention_mask)
        probabilities = F.softmax(logits, dim=1)
        prediction = torch.argmax(logits, dim=1).item()

    return prediction, probabilities[0].cpu().numpy()

# Test with some new examples
test_texts = [
    "This movie is absolutely fantastic! I loved every minute of it!",
    "Boring movie, fell asleep halfway through.",
    "It was an okay film, not bad but not great either.",
    "Worst movie ever made, complete waste of time.",
    "Amazing cinematography and brilliant acting!"
]

label_map = {0: "Negative", 1: "Neutral", 2: "Positive"}

print("🧪 Testing our trained sentiment classifier:")
print("=" * 60)

for i, text in enumerate(test_texts):
    prediction, probabilities = predict_sentiment(model, tokenizer, text, device, label_map)

    print(f"\nText {i+1}: '{text[:50]}{'...' if len(text) > 50 else ''}'")
    print(f"Prediction: {label_map[prediction]}")
    print(f"Confidence: {probabilities[prediction]:.3f}")
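    # Convert NumPy floats to rounded Python floats for readable output
    probs_by_label = {name: round(float(p), 3) for name, p in zip(label_map.values(), probabilities)}
    print(f"All probabilities: {probs_by_label}")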

print("\nHead replacement working successfully!")

## 7. Question Answering Head Implementation <a id="sec-qa"></a>

Question answering requires predicting the start and end positions of answers within context text, rather than single class labels.

### Question Answering Process

1. **Input**: A question and the context text that contains the answer
2. **Processing**: BERT encodes the pair as a single sequence, `[CLS] question [SEP] context [SEP]`
3. **Output**: Start and end token positions of the answer span
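
The layout of a question/context pair can be sketched without any model. The example below assumes both texts are already tokenized (the token lists are hypothetical) and builds the combined sequence together with BERT-style segment ids, where 0 marks the question part and 1 marks the context part. The helper `build_pair_input` is an illustration, not a library function.

```python
def build_pair_input(question_tokens, context_tokens):
    """Lay out a question/context pair as [CLS] question [SEP] context [SEP]."""
    tokens = ["[CLS]"] + question_tokens + ["[SEP]"] + context_tokens + ["[SEP]"]
    # Segment 0 covers [CLS], the question, and the first [SEP];
    # segment 1 covers the context and the final [SEP].
    segment_ids = [0] * (len(question_tokens) + 2) + [1] * (len(context_tokens) + 1)
    return tokens, segment_ids

# Hypothetical, already-tokenized inputs
question_tokens = ["who", "created", "python", "?"]
context_tokens = ["python", "was", "created", "by", "guido", "van", "rossum", "."]

tokens, segment_ids = build_pair_input(question_tokens, context_tokens)
context_start = segment_ids.index(1)  # index of the first context token

for tok, seg in zip(tokens, segment_ids):
    print(f"{seg}  {tok}")
print(f"Context starts at token index {context_start}")
```

Answer positions are indices into this combined sequence, so a valid answer span always starts at or after `context_start`. In the Hugging Face tokenizer output these segment ids appear as `token_type_ids`; a model that is called with only `input_ids` and `attention_mask` falls back to all-zero segment ids.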

### QA Head Architecture

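A question answering head scores every token twice: once as a possible answer start and once as a possible answer end. The two scores can come from a single linear layer with two outputs per token: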
- **Start position classifier**: Predicts answer start position for each token
- **End position classifier**: Predicts answer end position for each token
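
Taking the argmax of the start and end scores independently can produce an end that comes before the start. A common remedy, sketched here as an assumption rather than as part of this notebook's pipeline, is to search for the pair with the highest combined score subject to `start <= end` and a maximum span length. The function `best_span` and the logits are made up for illustration.

```python
def best_span(start_logits, end_logits, max_answer_len=30):
    """Return the (start, end) pair with the highest summed score, with start <= end."""
    best_score, best = float("-inf"), (0, 0)
    for s, s_score in enumerate(start_logits):
        # Only consider ends at or after the start, within the length limit
        last = min(len(end_logits), s + max_answer_len)
        for e in range(s, last):
            score = s_score + end_logits[e]
            if score > best_score:
                best_score, best = score, (s, e)
    return best

# Toy logits for a 6-token sequence (made-up numbers)
start_logits = [0.1, 0.2, 3.0, 0.5, 2.9, 0.0]
end_logits = [0.0, 4.0, 0.3, 2.5, 0.1, 0.2]

naive = (start_logits.index(max(start_logits)), end_logits.index(max(end_logits)))
print(f"Independent argmax: {naive}")
print(f"Constrained search: {best_span(start_logits, end_logits)}")
```

With these numbers the independent argmax picks an end position before the start, while the constrained search returns a valid span.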

In [None]:
class QuestionAnsweringHead(nn.Module):
    """
    Question Answering head that predicts start and end positions
    """
    def __init__(self, hidden_size, dropout=0.1):
        super().__init__()
        self.qa_outputs = nn.Linear(hidden_size, 2)  # 2 outputs: start_logits, end_logits
        self.dropout = nn.Dropout(dropout)

    def forward(self, sequence_output):
        """
        Args:
            sequence_output: [batch_size, seq_len, hidden_size]
        Returns:
            start_logits, end_logits: [batch_size, seq_len]
        """
        sequence_output = self.dropout(sequence_output)
        logits = self.qa_outputs(sequence_output)  # [batch_size, seq_len, 2]

        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)  # [batch_size, seq_len]
        end_logits = end_logits.squeeze(-1)      # [batch_size, seq_len]

        return start_logits, end_logits

class BERTQuestionAnswering(nn.Module):
    """
    BERT model with Question Answering head
    """
    def __init__(self, model_name='bert-base-uncased'):
        super().__init__()

        # Load pre-trained BERT
        self.bert = BertModel.from_pretrained(model_name)

        # QA head
        self.qa_head = QuestionAnsweringHead(self.bert.config.hidden_size)

        # Freeze BERT (only train the head)
        for param in self.bert.parameters():
            param.requires_grad = False

        print(f"Created QA model:")
        print(f"   Frozen BERT parameters: {sum(p.numel() for p in self.bert.parameters()):,}")
        print(f"   Trainable QA head parameters: {sum(p.numel() for p in self.qa_head.parameters()):,}")

    def forward(self, input_ids, attention_mask=None):
        # Get BERT outputs for all tokens
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask
        )

        sequence_output = outputs.last_hidden_state  # [batch_size, seq_len, hidden_size]

        # Apply QA head
        start_logits, end_logits = self.qa_head(sequence_output)

        return start_logits, end_logits

# Create QA model
qa_model = BERTQuestionAnswering()
qa_model = qa_model.to(device)

print("\nQuestion Answering model ready!")

In [None]:
# Let's create a simple QA dataset and demonstrate the concept
class QADataset(Dataset):
    """
    Dataset for Question Answering
    """
    def __init__(self, contexts, questions, answers, tokenizer, max_length=256):
        self.contexts = contexts
        self.questions = questions
        self.answers = answers
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.contexts)

    def __getitem__(self, idx):
        context = self.contexts[idx]
        question = self.questions[idx]
        answer = self.answers[idx]

        # Tokenize context and question together
        encoding = self.tokenizer(
            question,
            context,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )

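        # Locate the answer span by matching the answer's token IDs inside the
        # encoded sequence, starting after the first [SEP] to skip the question.
        # Falls back to (0, 0), the [CLS] position, if the answer is not found
        # (for example, when truncation removed it).
        ids = encoding['input_ids'][0].tolist()
        answer_ids = self.tokenizer.encode(answer, add_special_tokens=False)
        context_start = ids.index(self.tokenizer.sep_token_id) + 1
        start_position, end_position = 0, 0
        for i in range(context_start, len(ids) - len(answer_ids) + 1):
            if ids[i:i + len(answer_ids)] == answer_ids:
                start_position, end_position = i, i + len(answer_ids) - 1
                break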

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'start_position': torch.tensor(start_position, dtype=torch.long),
            'end_position': torch.tensor(end_position, dtype=torch.long)
        }

# Sample QA data
qa_sample_data = {
    'contexts': [
        "Paris is the capital of France. It is located in northern France on the river Seine.",
        "The Eiffel Tower was built in 1889 for the World's Fair. It stands 324 meters tall.",
        "Python is a programming language created by Guido van Rossum in 1991."
    ],
    'questions': [
        "What is the capital of France?",
        "When was the Eiffel Tower built?",
        "Who created Python?"
    ],
    'answers': [
        "Paris",
        "1889",
        "Guido van Rossum"
    ]
}

# Demonstrate how QA tokenization works
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

print("Question Answering Tokenization Example:")
print("=" * 50)

context = qa_sample_data['contexts'][0]
question = qa_sample_data['questions'][0]
answer = qa_sample_data['answers'][0]

print(f"Context: {context}")
print(f"Question: {question}")
print(f"Answer: {answer}")

# Tokenize question + context
encoding = tokenizer(
    question,
    context,
    truncation=True,
    padding='max_length',
    max_length=128,
    return_tensors='pt'
)

# Show the tokenized result
tokens = tokenizer.convert_ids_to_tokens(encoding['input_ids'][0])
print(f"\nTokenized (first 20 tokens): {tokens[:20]}")
print(f"Total tokens: {len(tokens)}")

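# Demonstrate model output. The QA head is untrained, so the predicted positions are arbitrary.
qa_model.eval()  # disable the head's dropout for inference
with torch.no_grad():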
    start_logits, end_logits = qa_model(
        encoding['input_ids'].to(device),
        encoding['attention_mask'].to(device)
    )

print(f"\nModel outputs:")
print(f"Start logits shape: {start_logits.shape}")
print(f"End logits shape: {end_logits.shape}")
print(f"Predicted start position: {torch.argmax(start_logits, dim=1).item()}")
print(f"Predicted end position: {torch.argmax(end_logits, dim=1).item()}")

## 8. Token Classification Head Implementation <a id="sec-token-class"></a>

Token classification predicts labels for individual tokens, making it suitable for tasks like Named Entity Recognition (NER) and Part-of-Speech (POS) tagging.

### Token Classification Tasks

| **Task** | **Function** | **Example** |
|----------|--------------|-------------|
| **NER** | Identify entities in text | "John lives in **New York**" → PERSON, CITY |
| **POS Tagging** | Classify word types | "The cat runs" → DET, NOUN, VERB |
| **Chunk Detection** | Identify phrases | "The big red car" → [NP: The big red car] |

### Architecture Distinction
- **Sequence Classification**: One prediction per sequence (entire text)
- **Token Classification**: One prediction per token (individual words)
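
One practical detail: BERT predicts per WordPiece token, and a single word can be split into several pieces, so word-level labels have to be aligned to tokens. A common convention, shown here as an assumption rather than something this notebook prescribes, is to label only the first piece of each word and mark the remaining pieces and the special tokens with `-100`, which `torch.nn.CrossEntropyLoss` ignores by default. The WordPiece split below is hypothetical.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss

def align_labels(word_labels, word_pieces):
    """Expand word-level labels to token level, labeling only each word's first piece."""
    aligned = [IGNORE_INDEX]  # [CLS]
    for label, pieces in zip(word_labels, word_pieces):
        aligned.append(label)
        aligned.extend([IGNORE_INDEX] * (len(pieces) - 1))
    aligned.append(IGNORE_INDEX)  # [SEP]
    return aligned

# Sentence: "John lives in Reykjavik"
# Hypothetical WordPiece split; a real tokenizer may split differently
word_pieces = [["john"], ["lives"], ["in"], ["rey", "##k", "##ja", "##vik"]]
word_labels = [1, 0, 0, 2]  # 0 = O, 1 = PERSON, 2 = LOCATION

aligned = align_labels(word_labels, word_pieces)
print(aligned)
```

The aligned list has one entry per token, including `[CLS]` and `[SEP]`, so it matches the sequence dimension of the logits produced by a token classification head.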

In [None]:
class TokenClassificationHead(nn.Module):
    """
    Token classification head that predicts a label for each token
    """
    def __init__(self, hidden_size, num_labels, dropout=0.1):
        super().__init__()
        self.num_labels = num_labels
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, sequence_output):
        """
        Args:
            sequence_output: [batch_size, seq_len, hidden_size]
        Returns:
            logits: [batch_size, seq_len, num_labels]
        """
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)
        return logits

class BERTTokenClassification(nn.Module):
    """
    BERT model with Token Classification head for NER
    """
    def __init__(self, model_name='bert-base-uncased', num_labels=3):
        super().__init__()

        # Load pre-trained BERT
        self.bert = BertModel.from_pretrained(model_name)

        # Token classification head
        self.token_head = TokenClassificationHead(
            hidden_size=self.bert.config.hidden_size,
            num_labels=num_labels
        )

        # Freeze BERT (only train the head)
        for param in self.bert.parameters():
            param.requires_grad = False

        print(f"Created Token Classification model:")
        print(f"Frozen BERT parameters: {sum(p.numel() for p in self.bert.parameters()):,}")
        print(f"Trainable head parameters: {sum(p.numel() for p in self.token_head.parameters()):,}")
        print(f"Number of labels: {num_labels}")

    def forward(self, input_ids, attention_mask=None):
        # Get BERT outputs for all tokens
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask
        )

        sequence_output = outputs.last_hidden_state  # [batch_size, seq_len, hidden_size]

        # Apply token classification head
        logits = self.token_head(sequence_output)

        return logits

# Create token classification model
# For this example, let's do simple NER: O (outside any entity), PERSON, LOCATION
label_map = {0: "O", 1: "PERSON", 2: "LOCATION"}
num_labels = len(label_map)

token_model = BERTTokenClassification(num_labels=num_labels)
token_model = token_model.to(device)

print(f"\nToken Classification model ready!")
print(f"Labels: {label_map}")

In [None]:
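# Let's demonstrate token classification with a practical example.
# Note: the head is randomly initialized and has not been trained yet,
# so the predicted labels are arbitrary. The demo shows the data flow
# and output shapes, not NER quality.
def demonstrate_token_classification():
    """
    Demonstrate how token classification works
    """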
    # Sample text with entities
    text = "John Smith lives in New York and works at Google"

    print("Token Classification Demonstration:")
    print("=" * 50)
    print(f"Text: {text}")

    # Tokenize
    encoding = tokenizer(
        text,
        truncation=True,
        padding='max_length',
        max_length=64,
        return_tensors='pt'
    )

    # Get tokens for display
    tokens = tokenizer.convert_ids_to_tokens(encoding['input_ids'][0])

    # Get model predictions
    token_model.eval()
    with torch.no_grad():
        logits = token_model(
            encoding['input_ids'].to(device),
            encoding['attention_mask'].to(device)
        )

        # Get predictions for each token
        predictions = torch.argmax(logits, dim=-1)[0]  # [seq_len]

    print(f"\nToken-by-token predictions:")
    print("=" * 50)

    # Show the first 15 tokens, skipping special tokens ([CLS], [SEP], [PAD])
    for i, (token, pred) in enumerate(zip(tokens[:15], predictions[:15])):
        if token not in ['[PAD]', '[CLS]', '[SEP]']:
            label = label_map[pred.item()]
            print(f"Token: '{token:12}' → Label: {label}")

    # Create a visualization
    meaningful_tokens = []
    meaningful_labels = []

    for token, pred in zip(tokens, predictions):
        if token not in ['[PAD]', '[CLS]', '[SEP]']:
            meaningful_tokens.append(token)
            meaningful_labels.append(label_map[pred.item()])

    print(f"\nVisualization:")
    print("=" * 50)
    for token, label in zip(meaningful_tokens, meaningful_labels):
        if label != "O":
            print(f"'{token}' → {label}")
        else:
            print(f"'{token}' → _")

demonstrate_token_classification()

# Show the difference in output shapes between different tasks
print(f"\nComparing Output Shapes:")
print("=" * 30)
print(f"Text Classification: [batch_size, num_classes] = [1, 3]")
print(f"Question Answering: [batch_size, seq_len] (start & end) = [1, 64] each")
print(f"Token Classification: [batch_size, seq_len, num_labels] = [1, 64, 3]")

## 9. Performance Comparison & Best Practices <a id="sec-best-practices"></a>

### Performance Considerations

| **Metric** | **Head Replacement** | **Full Fine-tuning** | **Feature Extraction** |
|------------|---------------------|---------------------|------------------------|
| **Training Time** | Fast (minutes) | Slow (hours) | Very Fast (seconds) |
| **Data Efficiency** | Good (1k+ samples) | Requires more (10k+) | Works with small data |
| **Performance** | Good-Very Good | Best | Lower |
| **Memory Usage** | Moderate | High | Low |
| **Multiple Tasks** | Easy | Requires separate models | Very Easy |
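
The training-time and memory differences in the table follow largely from how many parameters receive gradients. The sketch below counts them for a frozen base plus a linear head. `TinyEncoder` is a hypothetical stand-in so the example runs without downloading BERT; the same counting code should work on a real `BertModel`.

```python
import torch
import torch.nn as nn

class TinyEncoder(nn.Module):
    """Hypothetical stand-in for a pre-trained encoder (not BERT)."""
    def __init__(self, vocab_size=1000, hidden_size=768):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.layer = nn.Linear(hidden_size, hidden_size)

    def forward(self, input_ids):
        return torch.tanh(self.layer(self.embed(input_ids)))

def count_parameters(module):
    """Return (trainable, total) parameter counts."""
    total = sum(p.numel() for p in module.parameters())
    trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
    return trainable, total

base = TinyEncoder()
head = nn.Linear(768, 3)

# Head replacement: freeze the base, train only the head
for param in base.parameters():
    param.requires_grad = False

base_trainable, base_total = count_parameters(base)
head_trainable, head_total = count_parameters(head)
print(f"Base: {base_trainable:,} trainable of {base_total:,}")
print(f"Head: {head_trainable:,} trainable of {head_total:,}")  # 768*3 + 3 = 2,307

# Only the head's parameters need to be passed to the optimizer
optimizer = torch.optim.Adam(head.parameters(), lr=1e-3)
```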

### Best Practices for Head Replacement

#### Data Preparation
- **Quality over quantity**: Prioritize high-quality samples over large volumes
- **Balanced datasets**: Maintain reasonable class distribution
- **Proper validation**: Use separate validation sets
- **Text preprocessing**: Clean data without over-processing
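
When the class distribution cannot be balanced by collecting more data, one option is to weight the loss instead. The sketch below computes inverse-frequency weights, `weight_c = N / (K * count_c)`, from a list of labels. The function name and the label list are made up for illustration, and the code assumes every class appears at least once.

```python
from collections import Counter

def inverse_frequency_weights(labels, num_classes):
    """weight_c = N / (K * count_c); rarer classes get larger weights."""
    counts = Counter(labels)
    total = len(labels)
    return [total / (num_classes * counts[c]) for c in range(num_classes)]

# Hypothetical imbalanced label list: 6 of class 0, 3 of class 1, 1 of class 2
labels = [0] * 6 + [1] * 3 + [2] * 1
weights = inverse_frequency_weights(labels, num_classes=3)
print([round(w, 3) for w in weights])  # [0.556, 1.111, 3.333]
```

These weights can then be passed as `weight=torch.tensor(weights)` to `nn.CrossEntropyLoss`.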

#### Model Architecture
- **Start simple**: Begin with a single linear layer
- **Add complexity gradually**: Introduce dropout, extra layers, or skip connections only as needed
- **Monitor overfitting**: Smaller heads have fewer parameters and are less prone to overfitting
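
One way to add complexity gradually is to keep the head behind a small factory, so the linear baseline and a deeper variant share the same interface and can be swapped in one line. The function name `build_head` and the intermediate size of 256 are assumptions for illustration.

```python
import torch
import torch.nn as nn

def build_head(hidden_size, num_labels, intermediate_size=None, dropout=0.1):
    """Return a linear head, or a two-layer MLP head if intermediate_size is set."""
    if intermediate_size is None:
        return nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(hidden_size, num_labels),
        )
    return nn.Sequential(
        nn.Dropout(dropout),
        nn.Linear(hidden_size, intermediate_size),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(intermediate_size, num_labels),
    )

simple_head = build_head(768, 3)
deeper_head = build_head(768, 3, intermediate_size=256)

# Random stand-in for a batch of [CLS] vectors
cls_vectors = torch.randn(4, 768)
print(simple_head(cls_vectors).shape)  # torch.Size([4, 3])
print(deeper_head(cls_vectors).shape)  # torch.Size([4, 3])

# Compare head sizes before deciding whether the extra capacity is worth it
for name, head in [("simple", simple_head), ("deeper", deeper_head)]:
    n_params = sum(p.numel() for p in head.parameters())
    print(f"{name}: {n_params:,} parameters")
```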

#### Training Strategy
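- **Learning rates**: Use higher rates (1e-3 to 1e-2) for head training than the 2e-5 to 5e-5 range commonly used for full fine-tuning, since the head starts from random weights
- **Training duration**: 3-10 epochs are typically sufficient
- **Early stopping**: Stop when validation loss stops improving
- **Warm-up**: Optional; a short learning-rate warm-up can improve training stability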
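
Early stopping can be sketched as a small helper that tracks the best validation loss seen so far. The class name, the `patience` and `min_delta` parameters, and the loss values below are assumptions for illustration, not part of any library.

```python
class EarlyStopping:
    """Signal a stop when validation loss has not improved for `patience` epochs."""
    def __init__(self, patience=2, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.epochs_without_improvement = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience

# Made-up validation losses for illustration
val_losses = [0.90, 0.70, 0.65, 0.66, 0.67, 0.60]
stopper = EarlyStopping(patience=2)
stopped_at = None
for epoch, loss in enumerate(val_losses, start=1):
    if stopper.step(loss):
        stopped_at = epoch
        break
print(stopped_at, stopper.best_loss)  # 5 0.65
```

In a real training loop, `stopper.step(...)` would be called once per epoch with the averaged validation loss, and the head's weights from the best epoch would typically be saved for later use.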

## 10. Practice Exercise: Movie Review Sentiment Analysis <a id="sec-exercises"></a>

### Exercise Overview

**Task**: Build a BERT-based classifier for movie review sentiment analysis

**Objective**: Classify reviews as "Positive", "Negative", or "Neutral"

**Implementation**: Complete pipeline from data preparation to evaluation

### Learning Objectives

- Create custom datasets for BERT
- Build and train classification heads
- Handle real-world text data
- Evaluate model performance

### Step 1: Prepare the Movie Review Dataset

First, let's create a small toy movie review dataset (12 hand-written reviews). Your task is to complete the missing parts!


In [None]:
# Step 1: Create Movie Review Dataset
# TODO: Complete the missing parts marked with "# YOUR CODE HERE"

import torch
from sklearn.model_selection import train_test_split

# Sample movie review data
movie_reviews = {
    'reviews': [
        "This movie was absolutely fantastic! The acting was superb and the plot kept me engaged throughout.",
        "Terrible film. Boring story, bad acting, waste of time. Would not recommend to anyone.",
        "It was okay, nothing special. Some good moments but overall just average entertainment.",
        "One of the best movies I've ever seen! Brilliant cinematography and outstanding performances.",
        "Not great, not terrible. Watchable but forgettable. Could have been much better.",
        "Awful movie with terrible dialogue. The worst film of the year without any doubt.",
        "Decent film with good character development. Worth watching but not a masterpiece.",
        "Absolutely loved it! Amazing story, great acting, and beautiful visuals. Highly recommended!",
        "Mediocre at best. Some interesting ideas but poor execution. Left me disappointed.",
        "Outstanding masterpiece! Every scene was perfect. This will be remembered as a classic.",
        "Boring and predictable. Nothing new or exciting. Felt like a waste of money.",
        "Pretty good movie overall. Well-made with solid performances from the entire cast."
    ],
    'labels': [2, 0, 1, 2, 1, 0, 1, 2, 1, 2, 0, 1]  # 0=Negative, 1=Neutral, 2=Positive
}

# Create label mapping
label_to_text = {0: "Negative", 1: "Neutral", 2: "Positive"}
print("Movie Review Dataset Created!")
print(f"Total reviews: {len(movie_reviews['reviews'])}")

# TODO: Calculate and print the class distribution
# HINT: Use collections.Counter on movie_reviews['labels']
import collections
class_distribution = # YOUR CODE HERE
print("\nClass Distribution:")
for label_id, count in class_distribution.items():
    label_name = # YOUR CODE HERE  # Get label name from label_to_text
    print(f"  {label_name}: {count} reviews")

# TODO: Split the data into train and validation sets
# HINT: Use train_test_split with test_size=0.3 and random_state=42
X_train, X_val, y_train, y_val = # YOUR CODE HERE

print(f"\nData Split:")
print(f"  Training samples: {len(X_train)}")
print(f"  Validation samples: {len(X_val)}")

### Step 2: Create the BERT Movie Classifier

Now build your BERT-based movie review classifier. Complete the missing methods!


In [None]:
# Step 2: Build BERT Movie Review Classifier
# TODO: Complete the missing parts in the class definition

class MovieReviewClassifier(nn.Module):
    """
    BERT-based movie review sentiment classifier
    """
    def __init__(self, model_name='bert-base-uncased', num_classes=3, dropout=0.3):
        super().__init__()

        # Load pre-trained BERT
        self.bert = BertModel.from_pretrained(model_name)

        # TODO: Add dropout layer
        self.dropout = # YOUR CODE HERE

        # TODO: Create classification head (Linear layer)
        # HINT: Input size should be self.bert.config.hidden_size, output size is num_classes
        self.classifier = # YOUR CODE HERE

        # TODO: Freeze BERT parameters
        # HINT: Loop through self.bert.parameters() and set requires_grad = False
        for param in self.bert.parameters():
            # YOUR CODE HERE

        print(f"Created Movie Review Classifier:")
        print(f"   Frozen BERT parameters: {sum(p.numel() for p in self.bert.parameters()):,}")
        print(f"   Trainable parameters: {sum(p.numel() for p in self.classifier.parameters()):,}")

    def forward(self, input_ids, attention_mask=None):
        # TODO: Get BERT outputs
        outputs = # YOUR CODE HERE

        # TODO: Get [CLS] token representation (first token)
        cls_output = # YOUR CODE HERE  # Shape: [batch_size, hidden_size]

        # TODO: Apply dropout
        cls_output = # YOUR CODE HERE

        # TODO: Apply classification head
        logits = # YOUR CODE HERE

        return logits

# Create the model
print("🎬 Creating Movie Review Classifier...")
movie_classifier = MovieReviewClassifier(num_classes=3)
movie_classifier = movie_classifier.to(device)
print("Model created successfully!")

### Step 3: Prepare Data for Training

Create PyTorch datasets and data loaders. Fill in the missing dataset implementation!


In [None]:
# Step 3: Create Dataset and DataLoaders
# TODO: Complete the MovieReviewDataset class

from torch.utils.data import Dataset, DataLoader

class MovieReviewDataset(Dataset):
    """
    Dataset for movie review classification
    """
    def __init__(self, reviews, labels, tokenizer, max_length=128):
        self.reviews = reviews
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.reviews)

    def __getitem__(self, idx):
        review = str(self.reviews[idx])
        label = self.labels[idx]

        # TODO: Tokenize the review
        # HINT: Use self.tokenizer with truncation=True, padding='max_length',
        #       max_length=self.max_length, return_tensors='pt'
        encoding = # YOUR CODE HERE

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'label': torch.tensor(label, dtype=torch.long)
        }

# Create datasets
print("Creating datasets...")
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# TODO: Create training and validation datasets
train_dataset = # YOUR CODE HERE
val_dataset = # YOUR CODE HERE

# TODO: Create data loaders with batch_size=4
train_loader = # YOUR CODE HERE
val_loader = # YOUR CODE HERE

print(f"Datasets created!")
print(f"   Training batches: {len(train_loader)}")
print(f"   Validation batches: {len(val_loader)}")

# Test the data loader
print("\n Sample batch:")
for batch in train_loader:
    print(f"  Input IDs shape: {batch['input_ids'].shape}")
    print(f"  Attention mask shape: {batch['attention_mask'].shape}")
    print(f"  Labels shape: {batch['label'].shape}")
    break

### Step 4: Train Your Model

Implement the training loop. Complete the missing training logic!


In [None]:
# Step 4: Training Setup and Loop
# TODO: Complete the training implementation

# TODO: Create optimizer for only the classifier parameters
# HINT: Use Adam with learning rate 2e-3
optimizer = # YOUR CODE HERE

# TODO: Create loss function
# HINT: Use CrossEntropyLoss
loss_fn = # YOUR CODE HERE

def train_movie_classifier(model, train_loader, val_loader, optimizer, loss_fn, num_epochs=3):
    """
    Train the movie review classifier
    """
    training_history = []

    for epoch in range(num_epochs):
        print(f"\n🚂 Epoch {epoch + 1}/{num_epochs}")

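        # Training phase
        # Note: model.train() also turns on dropout inside the frozen BERT;
        # calling model.bert.eval() right after it keeps BERT's features deterministic.
        model.train()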
        train_loss = 0
        train_correct = 0
        train_total = 0

        for batch in train_loader:
            # Move to device
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)

            # Reset gradients from the previous step
            optimizer.zero_grad()

            # TODO: Forward pass
            logits = # YOUR CODE HERE

            # TODO: Calculate loss
            loss = # YOUR CODE HERE

            # TODO: Backward pass
            # YOUR CODE HERE  # loss.backward()
            # YOUR CODE HERE  # optimizer.step()

            # Track metrics
            train_loss += loss.item()
            predictions = torch.argmax(logits, dim=1)
            train_correct += (predictions == labels).sum().item()
            train_total += labels.size(0)

        # Validation phase
        model.eval()
        val_loss = 0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for batch in val_loader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['label'].to(device)

                # TODO: Forward pass (no gradients needed)
                logits = # YOUR CODE HERE
                loss = loss_fn(logits, labels)

                val_loss += loss.item()
                predictions = torch.argmax(logits, dim=1)
                val_correct += (predictions == labels).sum().item()
                val_total += labels.size(0)

        # Calculate metrics
        train_acc = train_correct / train_total
        val_acc = val_correct / val_total

        epoch_stats = {
            'epoch': epoch + 1,
            'train_loss': train_loss / len(train_loader),
            'train_acc': train_acc,
            'val_loss': val_loss / len(val_loader),
            'val_acc': val_acc
        }

        training_history.append(epoch_stats)

        print(f"Train Acc: {train_acc:.3f}, Val Acc: {val_acc:.3f}")
        print(f"Train Loss: {epoch_stats['train_loss']:.3f}, Val Loss: {epoch_stats['val_loss']:.3f}")

    return training_history

# TODO: Start training!
print("🚀 Starting training...")
# YOUR CODE HERE  # Call train_movie_classifier function

### Step 5: Test Your Trained Model

Finally, test your model on new movie reviews! Complete the prediction function.


In [None]:
# Step 5: Test the Trained Model
# TODO: Complete the prediction function

def predict_movie_sentiment(model, tokenizer, review_text, device):
    """
    Predict sentiment for a single movie review
    """
    model.eval()

    # TODO: Tokenize the review
    encoding = # YOUR CODE HERE  # Use tokenizer with appropriate parameters

    # Move to device
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    with torch.no_grad():
        # TODO: Get model predictions
        logits = # YOUR CODE HERE

        # TODO: Convert to probabilities using softmax
        probabilities = # YOUR CODE HERE

        # TODO: Get predicted class
        prediction = # YOUR CODE HERE  # Use torch.argmax

    # Drop the batch dimension so probabilities has shape [num_classes]
    return prediction.item(), probabilities.squeeze(0).cpu().numpy()

# Test on new movie reviews
test_reviews = [
    "This movie was a complete masterpiece! Incredible acting and stunning visuals.",
    "Boring and poorly written. Complete waste of time and money.",
    "It was okay, not bad but nothing special either.",
    "Absolutely terrible! Worst movie I've ever watched in my entire life.",
    "Pretty good film with some great moments and solid performances."
]

print("🧪 Testing Model on New Reviews:")
print("=" * 60)

for i, review in enumerate(test_reviews):
    # TODO: Get prediction for each review
    pred_label, probabilities = # YOUR CODE HERE

    predicted_sentiment = label_to_text[pred_label]
    confidence = probabilities[pred_label]

    print(f"\n📝 Review {i+1}: '{review[:50]}{'...' if len(review) > 50 else ''}'")
    print(f"🎯 Prediction: {predicted_sentiment}")
    print(f"📊 Confidence: {confidence:.3f}")

    # Show all probabilities
    all_probs = {label_to_text[j]: round(float(prob), 3) for j, prob in enumerate(probabilities)}
    print(f"📈 All probabilities: {all_probs}")

print("\n🎉 Congratulations! You've successfully completed the BERT head replacement exercise!")
print("💡 Try experimenting with:")
print("   • Different dropout values")
print("   • More training epochs")
print("   • Different learning rates")
print("   • Adding more layers to the head")

## Key Takeaways

### Core Concepts
- **BERT Architecture**: Understanding how BERT creates text representations
- **Head Replacement**: Replacing task-specific layers while preserving base knowledge
- **Base vs Head**: Conceptual separation enabling effective transfer learning

### Practical Skills
- **Text Classification**: Sentiment analysis with frozen BERT and trainable head
- **Question Answering**: Locating answer spans in context text
- **Token Classification**: Individual word labeling for NER tasks
- **Advanced Techniques**: Attention pooling, weighted loss, multi-layer heads

### When to Use Head Replacement
- Small to medium datasets (1k-10k samples)
- Quick prototyping and experimentation
- Multiple related tasks with shared base model
- Resource-constrained environments
- Preserving general language knowledge
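
The "multiple related tasks" case can be sketched as one frozen encoder with a dictionary of task-specific heads. The class name, task names, and the `nn.Embedding` stand-in encoder are assumptions so the example runs without downloading BERT; with a real `BertModel`, the hidden states would come from `outputs.last_hidden_state` instead.

```python
import torch
import torch.nn as nn

class SharedBaseMultiTask(nn.Module):
    """One frozen encoder with several task-specific heads."""
    def __init__(self, encoder, hidden_size, task_num_labels):
        super().__init__()
        self.encoder = encoder
        for param in self.encoder.parameters():
            param.requires_grad = False
        # One linear head per task, keyed by task name
        self.heads = nn.ModuleDict({
            task: nn.Linear(hidden_size, n) for task, n in task_num_labels.items()
        })

    def forward(self, input_ids, task):
        with torch.no_grad():
            hidden = self.encoder(input_ids)   # [batch, seq_len, hidden]
        cls_vector = hidden[:, 0, :]           # first position, as with [CLS]
        return self.heads[task](cls_vector)

# Hypothetical stand-in encoder (not BERT)
encoder = nn.Embedding(1000, 64)
model = SharedBaseMultiTask(encoder, hidden_size=64,
                            task_num_labels={"sentiment": 3, "topic": 5})

input_ids = torch.randint(0, 1000, (2, 10))
print(model(input_ids, task="sentiment").shape)  # torch.Size([2, 3])
print(model(input_ids, task="topic").shape)      # torch.Size([2, 5])
```

Because the encoder is frozen, adding a new task only adds one small head; the existing heads and the base are unaffected.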

### Performance Guidelines
- **Start simple**: A single linear layer is often sufficient
- **Handle imbalanced data**: Apply weighted loss functions  
- **Experiment with pooling**: CLS, mean, max, or attention-based approaches
- **Monitor training**: Use proper validation and early stopping
- **Scale gradually**: Add complexity only when necessary
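
The CLS, mean, and max pooling options differ only in how they reduce `[batch, seq_len, hidden]` to `[batch, hidden]`. The sketch below is a minimal version that respects the attention mask so padding does not affect the result; attention-based pooling is left out because it needs trainable weights. The function name `pool` and the tiny tensors are assumptions for illustration.

```python
import torch

def pool(hidden_states, attention_mask, strategy="cls"):
    """Reduce [batch, seq_len, hidden] to [batch, hidden], ignoring padding."""
    mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)  # [batch, seq_len, 1]
    if strategy == "cls":
        return hidden_states[:, 0, :]
    if strategy == "mean":
        summed = (hidden_states * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1.0)  # avoid division by zero
        return summed / counts
    if strategy == "max":
        masked = hidden_states.masked_fill(mask == 0, float("-inf"))
        return masked.max(dim=1).values
    raise ValueError(f"Unknown pooling strategy: {strategy}")

# Tiny hand-made example: 1 sequence, 3 positions (last is padding), hidden size 2
hidden_states = torch.tensor([[[1.0, 4.0], [3.0, 2.0], [9.0, 9.0]]])
attention_mask = torch.tensor([[1, 1, 0]])

print(pool(hidden_states, attention_mask, "cls"))   # tensor([[1., 4.]])
print(pool(hidden_states, attention_mask, "mean"))  # tensor([[2., 3.]])
print(pool(hidden_states, attention_mask, "max"))   # tensor([[3., 4.]])
```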

### Next Steps
- Experiment with different BERT variants (RoBERTa, DeBERTa)
- Explore domain-specific pre-trained models
- Investigate multi-task learning with shared base models
- Compare head replacement vs full fine-tuning performance
- Develop production-ready inference pipelines

## Additional Resources

### Essential Reading
- [BERT Paper (Original)](https://arxiv.org/abs/1810.04805) - The foundational paper
- [Hugging Face Transformers Documentation](https://huggingface.co/docs/transformers) - Comprehensive guide
- [The Illustrated BERT](http://jalammar.github.io/illustrated-bert/) - Visual explanations

### Advanced Topics
- [RoBERTa Improvements](https://arxiv.org/abs/1907.11692) - Better training strategies
- [DeBERTa Enhancements](https://arxiv.org/abs/2006.03654) - Disentangled attention
- [Domain Adaptation with BERT](https://arxiv.org/abs/2004.02288) - Specialized domains