# Week 1 Practical: Language Models & Transformers

## Foundations: Tokenization, Embeddings, Attention

### Learning Objectives

By the end of this practical session, you will:

1. **Understand tokenization algorithms** (Exercise 1)
   - Compare BPE, WordPiece, and SentencePiece
   - Train tokenizers on custom corpora
   - Understand how tokenization affects model input

2. **Analyze contextualized embeddings** (Exercise 2)
   - Explore how token representations change across transformer layers
   - Visualize word sense disambiguation (polysemy)
   - Investigate coreference resolution in embeddings

3. **Visualize attention mechanisms** (Exercise 3)
   - Use BertViz to explore attention patterns
   - Identify different attention head specializations
   - Understand how attention evolves through layers

4. **Implement self-attention from scratch** (Exercise 4)
   - Build a multi-head self-attention model
   - Train it on a toy task requiring positional reasoning
   - Analyze what attention patterns emerge during training

## Setup

The next cell imports the necessary modules – just run it, no need to look at
the details!

In [None]:
import os
import random
import tempfile
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sentencepiece as spm
import torch
import torch.nn as nn
import torch.optim as optim
from bertviz import head_view, model_view
from IPython.display import HTML, display
from sklearn.decomposition import PCA
from sklearn.metrics.pairwise import cosine_similarity
from tokenizers import Tokenizer
from tokenizers.models import BPE, WordPiece
from tokenizers.pre_tokenizers import ByteLevel, Whitespace
from tokenizers.trainers import BpeTrainer, WordPieceTrainer
from torch.utils.data import DataLoader, Dataset, random_split
from tqdm.autonotebook import tqdm
from transformers import AutoModel, AutoTokenizer

warnings.filterwarnings('ignore')

# Device: CUDA > MPS > CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

---

# Exercise 1: Tokenization

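In this exercise, we first review three subword tokenization algorithms and
look at how a pretrained tokenizer segments sentences. We then train BPE,
WordPiece and SentencePiece tokenizers on small corpora and compare them.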


## 1. Byte Pair Encoding (BPE)

**Paper:** [Neural Machine Translation of Rare Words with Subword
Units](https://arxiv.org/abs/1508.07909) (Sennrich et al., 2016)

### How It Works

1. Start with a **character-level vocabulary** (or bytes for byte-level BPE)
2. Find the **most frequent adjacent pair** of tokens in the corpus
3. Merge that pair into a single new token
4. Repeat until the target vocabulary size is reached
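The merge loop above can be sketched in a few lines of Python. This is a minimal, unoptimized illustration under simplifying assumptions: `learn_bpe` is a hypothetical helper, the corpus is a toy `{word: count}` dictionary, no end-of-word marker is used, and ties are broken by whichever pair is seen first (real implementations may break ties differently).

```python
from collections import Counter


def learn_bpe(word_counts: dict[str, int], num_merges: int) -> list[tuple[str, str]]:
    """Learn BPE merges from a {word: count} dict, starting from characters."""
    # Each word starts as a tuple of single characters
    splits = {word: tuple(word) for word in word_counts}
    merges = []
    for _ in range(num_merges):
        # Count adjacent symbol pairs, weighted by word frequency
        pair_counts: Counter = Counter()
        for word, symbols in splits.items():
            for pair in zip(symbols, symbols[1:]):
                pair_counts[pair] += word_counts[word]
        if not pair_counts:
            break  # every word is already a single symbol
        best = max(pair_counts, key=pair_counts.get)
        merges.append(best)
        # Replace every occurrence of the best pair with the merged symbol
        merged = best[0] + best[1]
        for word, symbols in splits.items():
            new_symbols, i = [], 0
            while i < len(symbols):
                if i < len(symbols) - 1 and (symbols[i], symbols[i + 1]) == best:
                    new_symbols.append(merged)
                    i += 2
                else:
                    new_symbols.append(symbols[i])
                    i += 1
            splits[word] = tuple(new_symbols)
    return merges


print(learn_bpe({"low": 5, "lower": 2, "newest": 6, "widest": 3}, num_merges=3))
# [('e', 's'), ('es', 't'), ('l', 'o')]
```

On this toy corpus, the first two merges build the suffix `est`, shared by `newest` and `widest`.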

## 2. WordPiece

**Papers:**
- [Japanese and Korean Voice
  Search](https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/37842.pdf)
  (Schuster & Nakajima, 2012)
- Adopted and popularized in [BERT](https://arxiv.org/abs/1810.04805) (Devlin
  et al., 2019)

### How It Works

1. Start with a **character-level vocabulary**
2. For each possible pair, compute a **likelihood score**: `likelihood =
   count(pair) / (count(first) × count(second))`
3. Merge the pair with the **highest likelihood** (not frequency)
4. Repeat until the target vocabulary size is reached
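To see how the likelihood score differs from raw frequency, here is a minimal sketch of one WordPiece selection step. `wordpiece_best_pair` is a hypothetical helper and the word counts are invented toy data; it also omits the `##` prefix that BERT's WordPiece uses to mark word-internal pieces.

```python
from collections import Counter


def wordpiece_best_pair(
    splits: dict[str, list[str]], word_counts: dict[str, int]
) -> tuple[str, str]:
    """Return the adjacent pair with the highest WordPiece-style score."""
    unit_counts: Counter = Counter()
    pair_counts: Counter = Counter()
    for word, symbols in splits.items():
        freq = word_counts[word]
        for symbol in symbols:
            unit_counts[symbol] += freq
        for pair in zip(symbols, symbols[1:]):
            pair_counts[pair] += freq
    # score = count(pair) / (count(first) * count(second))
    scores = {
        pair: count / (unit_counts[pair[0]] * unit_counts[pair[1]])
        for pair, count in pair_counts.items()
    }
    return max(scores, key=scores.get)


word_counts = {"hug": 10, "pug": 5, "pun": 12, "bun": 4, "hugs": 5}
splits = {word: list(word) for word in word_counts}
print(wordpiece_best_pair(splits, word_counts))  # ('g', 's')
```

On the same toy counts, BPE would merge the most frequent pair, `('u', 'g')` (20 occurrences), whereas the WordPiece score favors `('g', 's')` because `s` only ever occurs right after `g`.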


## 3. SentencePiece (Unigram)

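**Paper:** [Subword Regularization: Improving Neural Network Translation
Models with Multiple Subword Candidates](https://arxiv.org/abs/1804.10959)
(Kudo, 2018)

Note that the SentencePiece library implements both this unigram algorithm and
BPE (selected with its `model_type` option), and that it encodes spaces as an
ordinary symbol (`▁`). The `TokenizerTrainer` used later in this exercise trains
SentencePiece with `model_type='bpe'`.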

### How It Works

1. Start with a **large initial vocabulary** (all characters plus frequent
   substrings, e.g., obtained with BPE)
2. Iteratively **remove the tokens** whose removal decreases the corpus
   likelihood the least
3. Continue until the target vocabulary size is reached
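The pruning criterion can be illustrated with a simplified sketch: segment each word with the Viterbi algorithm under a unigram model, then measure how much the corpus log-likelihood drops when one token is removed. This is an assumption-laden simplification, not the full algorithm: real Unigram training runs EM over all segmentations, prunes a fraction of tokens per round and always keeps single characters. `viterbi_segment`, `corpus_log_likelihood` and `removal_loss` are hypothetical helpers, and the probabilities are invented.

```python
import math


def viterbi_segment(word: str, log_probs: dict[str, float]) -> tuple[list[str], float]:
    """Most likely segmentation of `word` under a unigram model (log-probabilities)."""
    n = len(word)
    best = [-math.inf] * (n + 1)  # best[i]: best log-probability of word[:i]
    back = [0] * (n + 1)          # back[i]: start index of the last piece ending at i
    best[0] = 0.0
    for end in range(1, n + 1):
        for start in range(end):
            piece = word[start:end]
            if piece in log_probs and best[start] + log_probs[piece] > best[end]:
                best[end] = best[start] + log_probs[piece]
                back[end] = start
    if best[n] == -math.inf:
        raise ValueError(f"cannot segment {word!r} with this vocabulary")
    # Follow the back-pointers to recover the pieces
    pieces, end = [], n
    while end > 0:
        pieces.append(word[back[end]:end])
        end = back[end]
    return pieces[::-1], best[n]


def corpus_log_likelihood(word_counts: dict[str, int], log_probs: dict[str, float]) -> float:
    """Sum of best-segmentation log-probabilities, weighted by word frequency."""
    return sum(count * viterbi_segment(word, log_probs)[1] for word, count in word_counts.items())


def removal_loss(token: str, word_counts: dict[str, int], log_probs: dict[str, float]) -> float:
    """How much the corpus log-likelihood drops when `token` is removed."""
    reduced = {t: lp for t, lp in log_probs.items() if t != token}
    return corpus_log_likelihood(word_counts, log_probs) - corpus_log_likelihood(word_counts, reduced)


probs = {"b": 0.1, "o": 0.2, "k": 0.1, "s": 0.15, "bo": 0.05, "ok": 0.1, "book": 0.3}
log_probs = {t: math.log(p) for t, p in probs.items()}
word_counts = {"book": 3, "books": 1}

print(viterbi_segment("books", log_probs)[0])  # ['book', 's']
for token in ["bo", "ok", "book"]:
    print(token, removal_loss(token, word_counts, log_probs))
```

Removing `bo` or `ok` costs nothing because the best segmentations never use them, so they would be pruned first, while removing `book` forces a costlier segmentation (`bo ok`).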

## Basic Tokenization

We first load a tokenizer (from HuggingFace) – here, a *WordPiece* tokenizer.

In [None]:
tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')

We can now tokenize any text using the `tokenize` method:

In [None]:
text = "The quick brown fox jumps."
tokens = tokenizer.tokenize(text)
print(f"Text: {text}")
print(f"Tokens: {tokens}")
print(f"Num tokens: {len(tokens)}")

Notice that in the example above, each word (as well as the final period)
becomes a single token. This is not always the case, as the example below shows:

In [None]:
test_words = ["hello", "internationalization", "COVID-19"]
for word in test_words:
    tokens = tokenizer.tokenize(word)
    print(f"{word:30s} -> {tokens}")

## Tokenizing for Transformers

When feeding text to a Transformer, we use token ids rather than token strings.
The call `tokenizer(text).input_ids` returns these ids, which can be mapped back
to tokens with `tokenizer.convert_ids_to_tokens`. Notice that special tokens
(`[CLS]` at the beginning, `[SEP]` at the end) are added by default.

In [None]:
ids = tokenizer(text).input_ids
tokens = tokenizer.convert_ids_to_tokens(ids)
print(f"{ids} -> {tokens}")

Finally, Transformers, like most models, process batches of texts. Here we use two new options: `return_tensors="pt"`, which ensures that we get PyTorch tensors as outputs, and `padding=True`, which adds padding tokens so that all tokenized sentences in the batch have the same length.

In [None]:
tokenized = tokenizer(["the quick brown fox jumps.", "the squirrel jumps."], return_tensors="pt", padding=True)

print(tokenized.input_ids)
print(tokenized.attention_mask)

**Question:** What is the purpose of `attention_mask`?

## Step 1.4: Training new tokenizers

In the next cell, we define the `TokenizerTrainer` class, which lets us *train* tokenizers.

### Understanding the `TokenizerTrainer` Class

This class trains three different tokenizers (BPE, WordPiece, and SentencePiece) on a given corpus and allows you to compare their behavior.

**Key methods:**
- `__init__(corpus, vocab_size)`: Trains all three tokenizers on the given corpus
- `encode_bpe(text)`, `encode_wp(text)`, `encode_sp(text)`: Tokenize text with each algorithm
- `compare(text)`: Returns tokenization results from all three tokenizers
- `get_vocab_*()`: Retrieve the learned vocabulary for each tokenizer

**Helper function:**
- `print_comparison(name, corpus, examples, vocab_size)`: Trains tokenizers and compares their outputs on example texts

The implementation is hidden below, but you can expand it if you want to see the details.

In [None]:
class TokenizerTrainer:
    """Trains BPE, SentencePiece and WordPiece tokenizers on the same corpus in the constructor."""
    
    def __init__(self, corpus: str, vocab_size: int = 50, suppress_output: bool = True):
        """
        Initialize and train all three tokenizers.
        
        Args:
            corpus: Text corpus to train on
            vocab_size: Target vocabulary size
            suppress_output: Whether to suppress SentencePiece training logs
        """
        # The trainers read from files, so write the corpus to a temporary file
        with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', encoding='utf-8', delete=False) as f:
            f.write(corpus)
            corpus_file = f.name
        
        # Train all three tokenizers immediately, then delete the temporary file
        try:
            self.tokenizer_bpe = self._train_bpe(corpus_file, vocab_size)
            self.tokenizer_sp = self._train_sentencepiece(corpus_file, vocab_size, suppress_output)
            self.tokenizer_wp = self._train_wordpiece(corpus_file, vocab_size)
        finally:
            os.remove(corpus_file)
    
    @staticmethod
    def _train_bpe(corpus_file: str, vocab_size: int) -> Tokenizer:
        """Train and return BPE tokenizer."""
        tokenizer = Tokenizer(BPE())
        tokenizer.pre_tokenizer = ByteLevel()
        trainer = BpeTrainer(
            vocab_size=vocab_size,
            min_frequency=1,
            special_tokens=['[UNK]']
        )
        tokenizer.train([corpus_file], trainer)
        tokenizer.model.unk_token = '[UNK]'
        return tokenizer

    @staticmethod
    def _train_wordpiece(corpus_file: str, vocab_size: int) -> Tokenizer:
        """Train and return WordPiece tokenizer."""
        tokenizer = Tokenizer(WordPiece(unk_token="[UNK]"))
        tokenizer.pre_tokenizer = Whitespace()
        trainer = WordPieceTrainer(
            vocab_size=vocab_size,
            min_frequency=1,
            special_tokens=['[UNK]']
        )
        tokenizer.train([corpus_file], trainer)
        return tokenizer
    
    @staticmethod
    def _train_sentencepiece(corpus_file: str, vocab_size: int, suppress_output: bool):
        """Train and return SentencePiece tokenizer (with optional output suppression)."""
        sp_prefix = os.path.join(tempfile.gettempdir(), f"sp_model_{os.getpid()}")
        
        if suppress_output:
            # Redirect file descriptors 1 (stdout) and 2 (stderr) to /dev/null
            # This works for C++ subprocesses, unlike contextlib.redirect_*
            devnull_fd = os.open(os.devnull, os.O_WRONLY)
            old_stdout = os.dup(1)
            old_stderr = os.dup(2)
            
            try:
                os.dup2(devnull_fd, 1)
                os.dup2(devnull_fd, 2)
                
                spm.SentencePieceTrainer.train(
                    input=corpus_file,
                    model_prefix=sp_prefix,
                    vocab_size=vocab_size,
                    model_type='bpe',
                    character_coverage=1.0,
                )
            finally:
                os.dup2(old_stdout, 1)
                os.dup2(old_stderr, 2)
                os.close(old_stdout)
                os.close(old_stderr)
                os.close(devnull_fd)
        else:
            spm.SentencePieceTrainer.train(
                input=corpus_file,
                model_prefix=sp_prefix,
                vocab_size=vocab_size,
                model_type='bpe',
                character_coverage=1.0,
            )
        
        return spm.SentencePieceProcessor(model_file=f"{sp_prefix}.model")
    
    def encode_bpe(self, text: str) -> list[str]:
        """Encode text with BPE."""
        return self.tokenizer_bpe.encode(text).tokens

    def encode_wp(self, text: str) -> list[str]:
        """Encode text with WordPiece."""
        return self.tokenizer_wp.encode(text).tokens
    
    def encode_sp(self, text: str) -> list[str]:
        """Encode text with SentencePiece."""
        ids = self.tokenizer_sp.encode(text, out_type=int)
        return [self.tokenizer_sp.id_to_piece(i) for i in ids]

    def compare(self, text: str) -> tuple[list[str], list[str], list[str]]:
        """Encode text with all three tokenizers; returns (BPE, SentencePiece, WordPiece) tokens."""
        return self.encode_bpe(text), self.encode_sp(text), self.encode_wp(text)
    
    def get_vocab_sp(self) -> dict[str, int]:
        """Get SentencePiece vocabulary as {piece: id} dict."""
        vocab_size = self.tokenizer_sp.vocab_size()
        return {self.tokenizer_sp.id_to_piece(i): i for i in range(vocab_size)}
    
    def get_vocab_bpe(self) -> dict[str, int]:
        """Get BPE vocabulary as {token: id} dict."""
        return self.tokenizer_bpe.get_vocab()
        
    def get_vocab_wp(self) -> dict[str, int]:
        """Get WordPiece vocabulary as {token: id} dict."""
        return self.tokenizer_wp.get_vocab()

def print_comparison(name: str, corpus: str, examples: list[str], vocab_size: int = 50):
    """
    Helper to train and compare tokenizers.
    
    Args:
        name: Name of this comparison (printed as header)
        corpus: Corpus to train on
        examples: List of text examples to encode and compare
        vocab_size: Vocabulary size for training
    """
    trainer = TokenizerTrainer(corpus, vocab_size=vocab_size)

    print("\n" + "-" * 80)
    print(f"--- {name} ---")
    print("-" * 80 + "\n")

    for example in examples:
        bpe_tokens, sp_tokens, wp_tokens = trainer.compare(example)
        print(f"Text: {example!r}")
        print(f"  BPE:           {bpe_tokens}")
        print(f"  SentencePiece: {sp_tokens}")
        print(f"  WordPiece:     {wp_tokens}")
        print()

    return trainer

In [None]:
corpus = """book booker booking booked bookshelf bookcase bookkeeper
reading reader readers reads reading bookshelf
teacher teaching teaches taught classroom
learning learner learns learned education
running runner runs run quickly running
""" * 10

examples = ["book", "booker", "booking", "booked", "bookshelf", "book shelf", "bookshélf"]

trainer = print_comparison("BPE vs SentencePiece vs WordPiece: Space Handling", corpus, examples, vocab_size=50)

In [None]:
display(HTML("<b>BPE vocabulary</b>"))
print(" ".join(sorted([k for k, v in trainer.get_vocab_bpe().items()])))

display(HTML("<b>WordPiece vocabulary</b>"))
print(" ".join(sorted([k for k, v in trainer.get_vocab_wp().items()])))

display(HTML("<b>SentencePiece vocabulary</b>"))
print(" ".join(sorted([k for k, v in trainer.get_vocab_sp().items()])))

**Exercise:** Try to find small corpora on which WordPiece, BPE and/or SentencePiece behave differently, and try to explain why.

**Tips for exploration:**
- Try corpora with different characteristics (repetitive patterns, rare characters, different word lengths)
- Look at how each algorithm handles spaces and special characters
- Compare the vocabularies learned by each tokenizer
- Experiment with different vocabulary sizes to see how it affects tokenization

**Example questions to investigate:**
- How do the algorithms handle words not seen during training?
- What happens with accented characters or special symbols?
- How does vocabulary size affect the granularity of tokenization?

In [None]:
# [[STUDENT]]...

assert False, 'Not implemented yet'


# Exercise 2: Looking at (contextual) embeddings

The goal of this exercise is to look at how embeddings evolve through the layers of a transformer.
We will use a few sentences containing the ambiguous English word "lead".

In [None]:
pd.DataFrame([
    ["guide", "The captain will lead the team to victory."], 
    ["guide", "She leads him by the hand through the dark corridor."],
], columns=["marker", "sentence"])

In [None]:
df_sentences = pd.DataFrame([
    # Sense 1: "lead" (verb) = to guide or direct
    ["guide", "The captain will lead the team to victory."], 
    ["guide", "They lead him by the hand through the dark corridor."],
    ["guide", "The roads lead to the village."],
    ["guide", "These facts lead us to believe he is guilty."],
    ["guide", "Good teachers lead students to discover answers themselves."],

    # Sense 2: "lead" (noun) = the metal (pronounced "led")
    ["metal", "The old pipes contained lead and needed replacement."],
    ["metal", "Lead is a toxic heavy metal."],
    ["metal", "The artist used lead in the stained glass window."],
    ["metal", "Lead weights are used in fishing."],
    ["metal", "Gasoline with lead was banned decades ago."],

    # Sense 3: "lead" (noun) = main role or position
    ["main", "She played the lead in the school play."],
    ["main", "The lead singer of the band quit yesterday."],
    ["main", "He took the lead in the race."],
    ["main", "The company currently has a lead of 10 points."],
    ["main", "Our team has a commanding lead in the championship."],
    
    # Sense 4: "lead" (noun) = a clue or tip
    ["clue", "The detective followed a promising lead."],
    ["clue", "We got a new lead on the missing person."],
    ["clue", "The tip provided a solid lead for the investigation."],
    ["clue", "A new sales lead was passed on to the team."],
], columns=['marker', 'sentence'])

In this exercise, we use [DistilBERT](https://huggingface.co/distilbert/distilbert-base-uncased), a smaller, distilled version of BERT.

In the next cell, we load an encoder model, `distilbert-base-uncased`, and its associated tokenizer.

Have a look at the configuration and PyTorch model, and try to understand the architecture and settings of the model.

In [None]:
model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name, output_hidden_states=True)
model.to(device)
model.eval()

print("Configuration")
print(model.config)
print()

print("Model")
print(model)

## Extract Embeddings Helper

### Understanding the `process_sentences` function

The `process_sentences(df, model, tokenizer)` function processes a DataFrame of sentences through a transformer model and extracts embeddings at all layers.

**Input:**
- `df`: A DataFrame with a 'sentence' column (and potentially other metadata columns)
- `model`: A transformer encoder (the function requests hidden states itself via `output_hidden_states=True`)
- `tokenizer`: The corresponding tokenizer

**Output:**
- `df_out`: An expanded DataFrame where each row represents one token, with columns:
  - Original columns from input df
  - `ix`: Global token index (unique across all sentences)
  - `sentence_ix`: Index of the source sentence
  - `token_ix`: Position of token within its sentence
  - `token`: The actual token string
- `embeddings`: A list indexed by `ix`, where each entry contains embeddings from all layers:
  - `embeddings[ix][0]`: Word embedding (layer 0, before any transformer processing)
  - `embeddings[ix][1]`: Output of 1st transformer layer
  - `embeddings[ix][n]`: Output of nth transformer layer

**Usage example:** After processing, you can select tokens of interest from `df_out` and visualize their embeddings across layers.

In [None]:
@torch.no_grad()
def process_sentences(
    df: pd.DataFrame, model, tokenizer
) -> tuple[pd.DataFrame, list[torch.Tensor]]:
    """
    Process sentences through a transformer encoder.
    
    Returns:
        - DataFrame with columns: [original columns] + ix, sentence_ix, token_ix, token
        - List of embeddings indexed by ix: [word_embedding, layer_1, ..., layer_n]
    """
    device = next(model.parameters()).device
    model.eval()

    rows = []
    embeddings = []
    global_ix = 0

    for sentence_ix, row in tqdm(df.iterrows(), total=len(df)):
        sentence = row['sentence']

        # Tokenize
        encoded = tokenizer(sentence, return_tensors='pt').to(device)

        # Get all hidden states
        outputs = model(**encoded, output_hidden_states=True)

        # Extract token info
        tokens = tokenizer.convert_ids_to_tokens(encoded['input_ids'][0])

        word_embeddings = model.embeddings.word_embeddings(encoded['input_ids'])[0].to('cpu')

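        # Keep every token, including the special [CLS] and [SEP] tokens
        for token_ix, token in enumerate(tokens):
            # Collect the embeddings of this token: index 0 is the raw word
            # embedding and index k is the output of transformer layer k.
            # hidden_states[0] (the embedding layer output, with position
            # embeddings added) is skipped so that the indices match the
            # layout described above.
            token_embeddings = [
                word_embeddings[token_ix]
            ] + [
                outputs.hidden_states[layer_idx][0, token_ix].cpu()
                for layer_idx in range(1, len(outputs.hidden_states))
            ]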

            embeddings.append(token_embeddings)

            # Create row
            new_row = row.to_dict()
            new_row.update({
                'ix': global_ix,
                'sentence_ix': sentence_ix,
                'token_ix': token_ix,
                'token': token
            })
            rows.append(new_row)

            global_ix += 1

    df_out = pd.DataFrame(rows)
    return df_out, embeddings


df, embeddings = process_sentences(df_sentences, model, tokenizer)

### Understanding the Visualization Helper Functions

The following cells define several helper functions for analyzing and visualizing embeddings. These functions are hidden but documented here:

#### `compute_pca_projection(df_selection, embeddings, layer_ix)`
Reduces high-dimensional embeddings to 2D using PCA for visualization.
- **Returns:** DataFrame with added `x1`, `x2` columns (the 2D coordinates)

#### `show_pca_projection(df_selection, embeddings, layer_ix, group=None, grey=None, label=None)`
Visualizes embeddings in 2D space after PCA projection.
- `group`: pandas Series for coloring points by category (e.g., word sense)
- `grey`: pandas bool Series to mark points as grey (background context)
- `label`: pandas bool Series to show token labels on the plot
- **Use case:** See how tokens cluster based on context/meaning
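`compute_pca_projection` relies on scikit-learn's `PCA`. For intuition, a 2D PCA projection amounts to centering the vectors and projecting them onto the top two right singular vectors of the centered matrix. The hypothetical `pca_2d` below sketches this with NumPy; its coordinates should agree with scikit-learn's up to the sign of each component, and the random matrix merely stands in for 768-dimensional DistilBERT vectors.

```python
import numpy as np


def pca_2d(X: np.ndarray) -> np.ndarray:
    """Project the rows of X onto their first two principal components."""
    X_centered = X - X.mean(axis=0)  # PCA operates on mean-centered data
    # Rows of Vt are the principal directions, sorted by decreasing singular value
    _, _, Vt = np.linalg.svd(X_centered, full_matrices=False)
    return X_centered @ Vt[:2].T     # shape (n_tokens, 2)


rng = np.random.default_rng(0)
X = rng.normal(size=(20, 768))       # e.g. 20 token vectors of size 768
print(pca_2d(X).shape)  # (20, 2)
```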

#### `show_similarities(df_selection, embeddings, layer_ix, group=None, similarity='cosine')`
Creates a heatmap showing pairwise similarities between token embeddings.
- `similarity`: 'cosine' (normalized) or 'inner' (dot product)
- `group`: pandas Series to organize and group tokens in the heatmap
- **Use case:** Quantify how similar different token occurrences are
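The two `similarity` options differ only by a normalization: cosine similarity divides each inner product by the norms of the two vectors. The hypothetical `similarity_matrices` sketch below makes this explicit. With the inner product, long vectors produce large values whatever their direction, which is why the heatmap's color scale is only fixed to [-1, 1] for cosine.

```python
import numpy as np


def similarity_matrices(X: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Return (inner-product, cosine) similarity matrices for the rows of X."""
    inner = X @ X.T
    norms = np.linalg.norm(X, axis=1, keepdims=True)
    cosine = inner / (norms * norms.T)  # divide entry (i, j) by ||x_i|| * ||x_j||
    return inner, cosine


X = np.array([[1.0, 0.0], [10.0, 0.0], [0.0, 1.0]])
inner, cosine = similarity_matrices(X)
print(inner[0, 1], cosine[0, 1])  # 10.0 1.0
```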

**How to use these functions:**
1. Select tokens from `df` (e.g., `df[df.token == "lead"]`)
2. Choose a layer to analyze (0 = word embeddings, 1-6 = transformer layers for DistilBERT)
3. Optionally specify grouping (e.g., by word sense) for better visualization

In [None]:
def compute_pca_projection(
    df_selection: pd.DataFrame, embeddings: list[torch.Tensor], layer_ix: int
) -> pd.DataFrame:
    """
    Compute 2D PCA projection for selected tokens at a given layer.
    
    Args:
        df_selection: Selected rows from df_out
        embeddings: List of embeddings from process_sentences
        layer_ix: Which layer to project (0 = token embeddings, 1+ = transformer layers)
    
    Returns:
        DataFrame with x1, x2 columns added
    """
    # Extract embeddings for selected tokens
    token_embeddings = []
    for ix in df_selection['ix'].values:
        token_embeddings.append(embeddings[ix][layer_ix].numpy())

    # Stack into matrix
    X = np.stack(token_embeddings)

    # Apply PCA
    pca = PCA(n_components=2)
    projections = pca.fit_transform(X)

    # Add to dataframe
    df_result = df_selection.copy()
    df_result['x1'] = projections[:, 0]
    df_result['x2'] = projections[:, 1]

    return df_result


def show_pca_projection(df_selection, embeddings, layer_ix, group=None, grey=None, label=None, figsize=(6, 5)):
    """
    Visualize PCA projection with optional grouping and labeling.
    
    Args:
        df_selection: Selected rows from df_out
        embeddings: List of embeddings from process_sentences
        layer_ix: Which layer to project
        group: Optional pandas Series (same length as df_selection) with group identifiers to color points
        grey: Optional pandas bool Series (same length as df_selection) marking points as grey
        label: Optional pandas bool Series (same length as df_selection) indicating which tokens to label
    """
    # Compute PCA
    df_pca = compute_pca_projection(df_selection, embeddings, layer_ix)
    
    fig, ax = plt.subplots(figsize=figsize)
    
    # Get indices for proper alignment with Series
    indices = df_pca.index
    
    # Handle grouping and coloring
    if group is not None:
        groups = group.loc[indices].unique()
        colors = plt.cm.tab10(np.linspace(0, 1, len(groups)))
        group_colors = {g: colors[i] for i, g in enumerate(groups)}
        
        for group_val in groups:
            mask = (group.loc[indices] == group_val)
            if grey is not None:
                mask = mask & ~grey.loc[indices]
            
            ax.scatter(
                df_pca.loc[mask, 'x1'],
                df_pca.loc[mask, 'x2'],
                c=[group_colors[group_val]],
                label=str(group_val),
                alpha=0.7,
                s=50
            )
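    else:
        # Plot all non-grey points in a single color
        mask = pd.Series(True, index=indices)
        if grey is not None:
            mask &= ~grey.loc[indices]
        ax.scatter(df_pca.loc[mask, 'x1'], df_pca.loc[mask, 'x2'], alpha=0.7, s=50, color='steelblue')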
    
    # Grey points (not in any group)
    if grey is not None:
        grey_mask = grey.loc[indices]
        ax.scatter(
            df_pca.loc[grey_mask, 'x1'],
            df_pca.loc[grey_mask, 'x2'],
            c='lightgrey',
            alpha=0.5,
            s=50
        )
    
    # Add token labels where indicated
    if label is not None:
        label_mask = label.loc[indices]
        for idx in df_pca.loc[label_mask].index:
            row = df_pca.loc[idx]
            ax.annotate(row['token'], (row['x1'], row['x2']), fontsize=9)
    
    ax.set_xlabel('PC1')
    ax.set_ylabel('PC2')
    if group is not None:
        ax.legend(title='Group')
    ax.set_title(f'PCA Projection (Layer {layer_ix})')
    
    plt.tight_layout()
    plt.show()

def show_similarities(df_selection, embeddings, layer_ix, group=None, similarity='cosine', figsize=(6,5)):
    """
    Visualize token similarities as a heatmap with optional group dividers.
    
    Args:
        df_selection: Selected rows from df_out
        embeddings: List of embeddings from process_sentences
        layer_ix: Which layer to use
        group: Optional pandas Series (same length as df_selection) to group and organize tokens
        similarity: 'cosine' or 'inner'
    """
    # Extract embeddings for selected tokens
    token_embeddings = []
    for ix in df_selection['ix'].values:
        token_embeddings.append(embeddings[ix][layer_ix].numpy())
    
    X = np.stack(token_embeddings)
    
    # Compute similarity matrix
    if similarity == 'cosine':
        sim_matrix = cosine_similarity(X)
    elif similarity == 'inner':
        sim_matrix = np.dot(X, X.T)
    else:
        raise ValueError("similarity must be 'cosine' or 'inner'")
    
    # Create labels and optionally reorder by group
    tokens = df_selection['token'].values
    sentence_ixs = df_selection['sentence_ix'].values
    indices = df_selection.index
    group_boundaries = []
    
    if group is not None:
        group_vals = group.loc[indices].values
        sort_idx = np.argsort(group_vals, kind='stable')  # keep sentence order within each group
        
        sim_matrix = sim_matrix[np.ix_(sort_idx, sort_idx)]
        tokens = tokens[sort_idx]
        sentence_ixs = sentence_ixs[sort_idx]
        group_vals = group_vals[sort_idx]
        
        labels = [f"{token}(s{sentence_ixs[i]}, {group_vals[i]})" for i, token in enumerate(tokens)]
        
        # Find group boundaries
        for i in range(1, len(group_vals)):
            if group_vals[i] != group_vals[i-1]:
                group_boundaries.append(i - 0.5)
    else:
        labels = [f"{token}(s{sentence_ixs[i]})" for i, token in enumerate(tokens)]
    
    # Plot
    fig, ax = plt.subplots(figsize=figsize)
    
    # Set vmin/vmax only for cosine
    if similarity == 'cosine':
        im = ax.imshow(sim_matrix, cmap='coolwarm', aspect='auto', vmin=-1, vmax=1)
    else:
        im = ax.imshow(sim_matrix, cmap='coolwarm', aspect='auto')
    
    ax.set_xticks(np.arange(len(labels)))
    ax.set_yticks(np.arange(len(labels)))
    ax.set_xticklabels(labels, rotation=45, ha='right', fontsize=9)
    ax.set_yticklabels(labels, fontsize=9)
    
    # Draw group dividers
    for boundary in group_boundaries:
        ax.axhline(y=boundary, color='black', linewidth=2)
        ax.axvline(x=boundary, color='black', linewidth=2)
    
    cbar = plt.colorbar(im, ax=ax)
    cbar.set_label(similarity)
    ax.set_title(f'Token Similarities - {similarity} (Layer {layer_ix})')
    
    plt.tight_layout()
    plt.show()

In [None]:
df[df.token == "lead"].head()

In [None]:
df_lead = df[df.token == "lead"]

show_pca_projection(df_lead, embeddings, layer_ix=6, group=df_lead.marker)

In [None]:
show_similarities(df_lead, embeddings, layer_ix=2, group=df_lead.marker, similarity='inner')

**Question:** How can you interpret those results?

**Things to consider:**
- Do embeddings of "lead" with the same meaning cluster together?
- How do the embeddings change across different layers? (Try comparing layer 0, 3, and 6)
- What does high/low cosine similarity tell you about contextual understanding?
- Are the word embeddings (layer 0) context-aware, or does contextualization happen in later layers?

## Exercise: Coreference Resolution

Now, try to use the same kind of analysis to look into another linguistic phenomenon: **coreference** (anaphora resolution).

**Background:** Coreference occurs when different words refer to the same entity. For example:
- "John went to the store. **He** bought milk." — "He" refers to "John"
- "The cat chased **its** tail." — "its" refers to "The cat"

**Your task:**
1. Create sentences with clear coreference relationships (pronouns referring to nouns)
2. Process them through the model using `process_sentences`
3. Use `show_pca_projection` and `show_similarities` to analyze:
   - Do pronoun embeddings become similar to their referents in deeper layers?
   - At which layer does the model seem to "resolve" coreferences?
4. Try examples with ambiguous pronouns to see how the model handles them
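To locate the layer at which a pronoun starts to resemble its referent, one option is to track their similarity layer by layer. The helper below is a hypothetical sketch that assumes the `embeddings` structure returned by `process_sentences` (one list of per-layer tensors per global index `ix`); the random tensors only stand in for real embeddings.

```python
import torch
import torch.nn.functional as F


def layerwise_cosine(embeddings: list[list[torch.Tensor]], ix_a: int, ix_b: int) -> list[float]:
    """Cosine similarity between two tokens (global indices `ix`) at every layer."""
    return [
        F.cosine_similarity(emb_a, emb_b, dim=0).item()
        for emb_a, emb_b in zip(embeddings[ix_a], embeddings[ix_b])
    ]


torch.manual_seed(0)
# Two fake tokens, each with 7 vectors (word embedding + 6 layers) of size 8
fake = [[torch.randn(8) for _ in range(7)] for _ in range(2)]
sims = layerwise_cosine(fake, 0, 1)
print(len(sims))  # 7
```

With real data, the two indices would come from the `ix` column of the DataFrame returned by `process_sentences` (for example the rows of the pronoun and of its referent). Comparing against an unrelated token from the same sentence gives a useful baseline, since similarities can rise across the board in deeper layers.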

Use the cell below for your own experiments, and try modifying your sentences to explore different scenarios!

In [None]:
# [[STUDENT]]...

assert False, 'Not implemented yet'


---

# Exercise 3: Attention Visualization with BertViz

In this exercise, we'll visualize the attention patterns learned by the transformer model. Attention mechanisms allow the model to focus on different parts of the input when processing each token.

**BertViz** provides interactive visualizations of attention weights across:
- **Heads**: The model uses multiple attention heads (DistilBERT has 12 heads per layer)
- **Layers**: DistilBERT has 6 layers, each learning different patterns

**What to look for:**
- Do certain heads specialize in specific patterns (e.g., attending to previous words, next words, or syntactic dependencies)?
- How do attention patterns change across layers?
- Can you identify heads that might be learning linguistic phenomena (syntax, semantics)?

## Step 3.1: Load Attention Model

We load the same model but with `output_attentions=True` to extract attention weights.

In [None]:
model_attn = AutoModel.from_pretrained(model_name, output_attentions=True)
model_attn.to(device)
model_attn.eval()
print("Model loaded")

## Step 3.2: Head View

The **head view** shows attention from each token to all other tokens, for each attention head.

**How to use:**
- Click on a token to see where it attends
- Compare different heads to see if they learn different patterns
- Darker lines = stronger attention

**Try these experiments:**
- Does "the" attend strongly to the noun it modifies?
- Do verbs attend to their subjects or objects?
- Try longer, more complex sentences to see clearer patterns

In [None]:
text = "The cat sat on the mat."
inputs = tokenizer.encode(text, return_tensors="pt").to(device)
tokens = tokenizer.convert_ids_to_tokens(inputs[0])

with torch.no_grad():
    outputs = model_attn(inputs, output_attentions=True)
    attention = outputs.attentions  # tuple: one (batch, heads, seq, seq) tensor per layer

print(f"Text: {text}")
head_view(attention, tokens)

## Step 3.3: Model View

The **model view** shows attention across all layers and heads simultaneously, allowing you to see how attention patterns evolve through the network.

**How to use:**
- Rows = layers, Columns = heads
- Click on any head to see its attention pattern
- Compare early layers with late layers

**Questions to explore:**
- Do early layers capture local/syntactic patterns while later layers capture semantic relationships?
- Which layer seems most important for understanding the sentence meaning?

In [None]:
text = "Alice gave the book to Bob."
inputs = tokenizer.encode(text, return_tensors="pt").to(device)
tokens = tokenizer.convert_ids_to_tokens(inputs[0])

with torch.no_grad():
    outputs = model_attn(inputs, output_attentions=True)
    attention = outputs.attentions  # tuple: one (batch, heads, seq, seq) tensor per layer

model_view(attention, tokens)
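As a complement to the visual inspection, the sketch below scores how much each head attends to the immediately preceding token, one of the patterns listed under "What to look for". It assumes the Hugging Face layout for `outputs.attentions` (a tuple with one `(batch, num_heads, seq_len, seq_len)` tensor per layer, where row `q` holds the weights that position `q` gives to every position); `previous_token_scores` is a hypothetical helper, and the fake attention tensor only illustrates the expected shape.

```python
import torch


def previous_token_scores(attention: tuple[torch.Tensor, ...]) -> torch.Tensor:
    """Average attention each head puts on the immediately preceding token.

    Returns a (num_layers, num_heads) tensor, computed on the first sequence
    of the batch.
    """
    scores = []
    for layer_attn in attention:
        attn = layer_attn[0]  # (num_heads, seq_len, seq_len)
        # Entries just below the diagonal: weight from position q to position q - 1
        prev = torch.diagonal(attn, offset=-1, dim1=-2, dim2=-1)  # (num_heads, seq_len - 1)
        scores.append(prev.mean(dim=-1))
    return torch.stack(scores)


seq_len = 5
# Head 0 always attends to the previous token (position 0 attends to itself)
head0 = torch.zeros(seq_len, seq_len)
head0[0, 0] = 1.0
head0[torch.arange(1, seq_len), torch.arange(seq_len - 1)] = 1.0
# Head 1 spreads its attention uniformly
head1 = torch.full((seq_len, seq_len), 1.0 / seq_len)
fake_attention = (torch.stack([head0, head1]).unsqueeze(0),)  # one layer, batch of 1
print(previous_token_scores(fake_attention))  # approximately [[1.0, 0.2]]
```

In the notebook, the same function could be applied to the `attention` computed above; keep in mind that the scores include the special `[CLS]` and `[SEP]` positions.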

---

# Exercise 4: Self-Attention Implementation

In this exercise, we will implement multi-head self-attention. In order to test its behavior, we will use a simple toy dataset.

## Token Matching Task

For each token at position $i$ with value $t_i$, classify it as 1 if:
1. The value $(t_i - 1)$ appears **somewhere before** position $i$, AND
2. The value $(t_i + 1)$ appears **somewhere after** position $i$

Otherwise classify as 0.

**Example**: `5 3 4 1 5 6` → `0 0 1 0 1 0`

| Pos | Token | Need (t-1) | Need (t+1) | Before? | After? | Label |
|-----|-------|-----------|-----------|---------|--------|-------|
| 0   | 5     | 4         | 6         | ✗       | ✓      | 0     |
| 1   | 3     | 2         | 4         | ✗       | ✓      | 0     |
| 2   | 4     | 3         | 5         | ✓ (pos 1) | ✓ (pos 4) | **1** |
| 3   | 1     | 0         | 2         | ✗       | ✗      | 0     |
| 4   | 5     | 4         | 6         | ✓ (pos 2) | ✓ (pos 5) | **1** |
| 5   | 6     | 5         | 7         | ✓ (pos 4) | ✗      | 0     |

**Example**: `1 2 3 2 3` → `0 1 0 1 0`

| Pos | Token | Need (t-1) | Need (t+1) | Before? | After? | Label |
|-----|-------|-----------|-----------|---------|--------|-------|
| 0   | 1     | 0         | 2         | ✗       | ✓      | 0     |
| 1   | 2     | 1         | 3         | ✓ (pos 0) | ✓ (pos 2) | **1** |
| 2   | 3     | 2         | 4         | ✓ (pos 1) | ✗      | 0     |
| 3   | 2     | 1         | 3         | ✓ (pos 0) | ✓ (pos 4) | **1** |
| 4   | 3     | 2         | 4         | ✓ (pos 3) | ✗      | 0     |
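
The rule can also be computed without nested loops. A value has appeared before position $i$ exactly when its one-hot cumulative sum over the prefix `seq[:i]` is positive, and the same trick applied to the reversed sequence covers the suffix. The sketch below reproduces both worked examples. The function name `token_matching_labels` is ours and is not part of the notebook.

```python
import torch
import torch.nn.functional as F


def token_matching_labels(seq: torch.Tensor, vocab_size: int) -> torch.Tensor:
    """1 where t_i - 1 occurs before position i and t_i + 1 occurs after it."""
    n = seq.numel()
    # Column v + 1 holds value v; columns 0 and vocab_size + 1 stand for the
    # out-of-range values -1 and vocab_size, which never occur
    onehot = F.one_hot(seq + 1, vocab_size + 2)  # (n, vocab_size + 2)
    seen_upto = onehot.cumsum(dim=0) > 0  # value seen at or before i
    seen_from = onehot.flip(0).cumsum(dim=0).flip(0) > 0  # seen at or after i
    empty_row = torch.zeros_like(seen_upto[:1])
    before = torch.cat([empty_row, seen_upto[:-1]])  # strictly before i
    after = torch.cat([seen_from[1:], empty_row])  # strictly after i
    rows = torch.arange(n)
    has_pred = before[rows, seq]  # column of value t_i - 1 is t_i
    has_succ = after[rows, seq + 2]  # column of value t_i + 1 is t_i + 2
    return (has_pred & has_succ).long()


print(token_matching_labels(torch.tensor([5, 3, 4, 1, 5, 6]), vocab_size=7))
# tensor([0, 0, 1, 0, 1, 0])
print(token_matching_labels(torch.tensor([1, 2, 3, 2, 3]), vocab_size=4))
# tensor([0, 1, 0, 1, 0])
```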

**Why this task for attention?** The model must learn to:
1. Determine which values to look for ($t_i - 1$ and $t_i + 1$)
2. Search the left context for $t_i - 1$ and the right context for $t_i + 1$
3. Use attention patterns to gather this evidence across positions
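
To see why attention is a natural fit, here is a hand-built (not learned) head that answers the question "does $t_i - 1$ occur before $i$?". All of its ingredients are our assumptions:

* tokens are one-hot embedded;
* the query for position $i$ is the one-hot of $t_i - 1$;
* the key for position $j$ is the one-hot of $t_j$;
* "before" is enforced with an explicit mask.

The model in this exercise has no such mask, so it would have to derive the ordering constraint from the positional encodings instead. An extra "sink" slot with an intermediate score keeps the softmax well defined when no key matches. The attention mass on real positions is then close to 1 exactly when a match exists. `predecessor_detector` is a hypothetical name.

```python
import torch
import torch.nn.functional as F


def predecessor_detector(seq: torch.Tensor, vocab_size: int, beta: float = 20.0) -> torch.Tensor:
    """Hand-built attention head. Returns, per position i, the attention mass
    on earlier positions; it is close to 1 iff some j < i has t_j == t_i - 1."""
    n = seq.numel()
    onehot = F.one_hot(seq, vocab_size).float()  # (n, V)
    # shift maps the one-hot of t to the one-hot of t - 1 (zero row for t = 0)
    shift = torch.zeros(vocab_size, vocab_size)
    shift[torch.arange(1, vocab_size), torch.arange(vocab_size - 1)] = 1.0
    queries = onehot @ shift  # row i = one-hot(t_i - 1)
    keys = onehot  # row j = one-hot(t_j)
    scores = beta * (queries @ keys.T)  # beta where t_j == t_i - 1, else 0
    # Strictly earlier positions only (explicit mask: an assumption of this sketch)
    earlier = torch.tril(torch.ones(n, n, dtype=torch.bool), diagonal=-1)
    scores = scores.masked_fill(~earlier, float("-inf"))
    # Sink column: outscores non-matching keys (0) but loses to a match (beta)
    sink = torch.full((n, 1), beta / 2)
    weights = torch.softmax(torch.cat([scores, sink], dim=1), dim=1)
    return weights[:, :n].sum(dim=1)  # mass placed on real positions


seq = torch.tensor([5, 3, 4, 1, 5, 6])
print((predecessor_detector(seq, vocab_size=7) > 0.5).long())
# tensor([0, 0, 1, 0, 1, 1])
```

A successor head would mirror this: shift each value to $t + 1$ and mask with `torch.triu(..., diagonal=1)`. The label is then the AND of the two detectors, which is the kind of combination the final classifier has to learn.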

## The Toy Dataset

### Understanding the `TokenMatchingDataset` Class

This dataset generates sequences and labels for the token matching task described above.

**Constructor parameters:**
- `num_sequences`: Number of random sequences to generate
- `seq_length`: Length of each sequence
- `vocab_size`: Number of possible token values (0 to vocab_size-1)
- `seed`: Random seed for reproducibility (seeds Python's and PyTorch's global random number generators; if `None`, no seeding is done)

**Output format:**
- `sequence`: A tensor of token IDs, e.g., `[5, 3, 4, 1, 5, 6]`
- `labels`: Binary labels (0 or 1) for each position

The dataset automatically computes labels based on the rule: label $= 1$ if $t_i - 1$ appears before position $i$ and $t_i + 1$ appears after position $i$.
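
`__getitem__` returns a dictionary. PyTorch's default collate function therefore batches items into a dictionary of stacked tensors, which is why the training loop can index `batch['sequence']` and `batch['labels']`. This relies on every sequence having the same length, because the default collate stacks the tensors. The minimal check below uses a hypothetical stand-in dataset (`DictDataset`, not part of the notebook) that mimics the same item format.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class DictDataset(Dataset):
    """Hypothetical stand-in with the same item format as TokenMatchingDataset."""

    def __init__(self, num_sequences=10, seq_length=6, vocab_size=4):
        self.sequences = torch.randint(0, vocab_size, (num_sequences, seq_length))
        self.labels = torch.zeros(num_sequences, seq_length, dtype=torch.long)

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        return {"sequence": self.sequences[idx], "labels": self.labels[idx]}


loader = DataLoader(DictDataset(), batch_size=4, shuffle=False)
batch = next(iter(loader))  # dict of tensors stacked along a new batch dimension
print(batch["sequence"].shape, batch["labels"].shape)
# torch.Size([4, 6]) torch.Size([4, 6])
```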

In [None]:
class TokenMatchingDataset(Dataset):
    
    def __init__(self, num_sequences: int, seq_length: int, vocab_size: int, seed: int = None):
        if seed is not None:
            random.seed(seed)
            torch.manual_seed(seed)
        
        self.sequences = []
        self.labels = []
        
        for _ in range(num_sequences):
            seq = torch.randint(0, vocab_size, (seq_length,))
            labels = []
            
            for i in range(seq_length):
                t_i = seq[i].item()
                
                # Check if t_i - 1 appears before position i
                has_predecessor = any(seq[j].item() == t_i - 1 for j in range(i))
                
                # Check if t_i + 1 appears after position i
                has_successor = any(seq[j].item() == t_i + 1 for j in range(i + 1, seq_length))
                
                label = 1 if (has_predecessor and has_successor) else 0
                labels.append(label)
            
            self.sequences.append(seq)
            self.labels.append(torch.tensor(labels, dtype=torch.long))
    
    def __len__(self):
        return len(self.sequences)
    
    def __getitem__(self, idx):
        return {
            'sequence': self.sequences[idx],
            'labels': self.labels[idx],
        }

train_data = TokenMatchingDataset(num_sequences=200, seq_length=10, vocab_size=6)
val_data = TokenMatchingDataset(num_sequences=50, seq_length=10, vocab_size=6)
print(f"Dataset: {len(train_data)} train, {len(val_data)} val")

print(train_data[0])

## Model Components

### Positional Encoding

Self-attention does not inherently capture word order. It is permutation-equivariant: permuting the input tokens simply permutes the outputs in the same way. We therefore add **positional encodings** to give the model information about token positions.

The `PositionalEncoding` class implements the sinusoidal positional encoding from the original Transformer paper:
- Even dimensions use sine: $PE_{pos, 2i} = \sin(pos / 10000^{2i/d})$
- Odd dimensions use cosine: $PE_{pos, 2i+1} = \cos(pos / 10000^{2i/d})$

These are added to the token embeddings before passing through the attention mechanism.
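
In the code below, `div` holds the factor $1/10000^{2i/d}$, computed in log space as $\exp(-2i \ln(10000)/d)$. The sketch compares a direct, loop-based transcription of the two formulas with that vectorized form. `pe_direct` and `pe_vectorized` are our names, and the sketch uses `math.log` where the class uses `np.log`. $d$ is assumed even: the code interleaves sine and cosine columns, so an odd $d$ would make the shapes mismatch.

```python
import math

import torch


def pe_direct(max_len: int, d: int) -> torch.Tensor:
    """Literal transcription of PE(pos, 2i) = sin(.), PE(pos, 2i+1) = cos(.)."""
    pe = torch.zeros(max_len, d)
    for pos in range(max_len):
        for i in range(d // 2):
            angle = pos / (10000 ** (2 * i / d))
            pe[pos, 2 * i] = math.sin(angle)
            pe[pos, 2 * i + 1] = math.cos(angle)
    return pe


def pe_vectorized(max_len: int, d: int) -> torch.Tensor:
    """Same computation as PositionalEncoding.__init__ (log-space div term)."""
    pos = torch.arange(max_len, dtype=torch.float).unsqueeze(1)
    div = torch.exp(torch.arange(0, d, 2).float() * -(math.log(10000) / d))
    pe = torch.zeros(max_len, d)
    pe[:, 0::2] = torch.sin(pos * div)
    pe[:, 1::2] = torch.cos(pos * div)
    return pe


print(torch.allclose(pe_direct(50, 16), pe_vectorized(50, 16), atol=1e-5))
# True
```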

In [None]:
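class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding (d_model must be even)."""

    def __init__(self, d_model, max_len=100):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        # div[i] = 1 / 10000^(2i / d_model), computed in log space
        div = torch.exp(torch.arange(0, d_model, 2).float() * -(np.log(10000) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)  # even dimensions
        pe[:, 1::2] = torch.cos(pos * div)  # odd dimensions
        # Buffer: follows .to(device) but is not a trainable parameter
        self.register_buffer('pe', pe.unsqueeze(0))  # (1, max_len, d_model)
    
    def forward(self, x):
        # x: (batch, seq_len, d_model), with seq_len <= max_len
        return x + self.pe[:, :x.size(1), :]

# -- Usage

# Create a positional encoding module (d_model=128, max_len=100)
pe = PositionalEncoding(128, 100)

# Apply it to a batch of token embeddings (batch x seq_len x d_model)
pe(torch.randn(2, 10, 128))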

## Full Model Implementation

### Understanding the Model Architecture

Below we implement a complete self-attention model for the token matching task. The model consists of:

1. **`Attention` class**: Implements scaled dot-product attention
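   - Computes attention scores: $\text{scores} = QK^T / \sqrt{d_k}$, where $d_k$ is the dimension of each query/key vector (the `d` passed to `Attention`)
   - Applies a softmax over the keys to get attention weights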
   - Returns weighted sum of values: $\text{output} = \text{softmax}(\text{scores}) \cdot V$

2. **`SelfAttention` class**: The complete model with:
   - **Token embeddings**: Convert token IDs to dense vectors
   - **Positional encodings**: Add position information
   - **Multi-head attention**: Learn multiple attention patterns simultaneously
   - **Layer normalization**: Stabilize training
   - **Classifier**: Final linear layer producing one logit per token (predicted label 1 when the sigmoid of the logit exceeds 0.5)
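
Item 1 is what PyTorch's `torch.nn.functional.scaled_dot_product_attention` computes. That function is available from PyTorch 2.0, and its default scale is $1/\sqrt{d_k}$, with $d_k$ the last dimension of the queries. The notebook's `Attention` class can therefore be sanity-checked against it. The built-in does not return the attention weights, which is one reason to keep the explicit version for the analysis at the end of this exercise. The sketch repeats the class so that it is self-contained, and the shapes `(batch, heads, seq_len, d_k)` are our choice.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class Attention(nn.Module):
    """Same computation as the notebook's Attention class."""

    def __init__(self, d):
        super().__init__()
        self.scale = 1 / math.sqrt(d)

    def forward(self, Q, K, V):
        scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale  # (..., L, L)
        weights = torch.softmax(scores, dim=-1)  # each query row sums to 1
        return torch.matmul(weights, V), weights


torch.manual_seed(0)
Q, K, V = (torch.randn(2, 4, 10, 9) for _ in range(3))  # (batch, heads, L, d_k)
out, weights = Attention(9)(Q, K, V)
ref = F.scaled_dot_product_attention(Q, K, V)  # requires PyTorch >= 2.0
print(torch.allclose(out, ref, atol=1e-5))
```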

### Implementation Exercise

The `__init__` and `forward` methods have placeholders that you need to complete. The reference implementation is provided in the code but hidden from the notebook output.

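**Key concepts to implement:**
- Multi-head attention: split the d-dimensional embeddings into `heads` separate attention computations (so `d` must be divisible by `heads`)
- Each head works with dimension `d // heads`
- Concatenate all head outputs and project them with an output layer
- `forward(x)` must return a tuple `(logits, attention_weights)`, where `logits` has shape `(batch, seq_len)`. The training loop below unpacks `logits, _ = model(sequences)` and passes `logits` to `BCEWithLogitsLoss` together with labels of that shape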
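
Splitting into heads and concatenating them back are only reshapes. The sketch below shows that bookkeeping and leaves the learned Q/K/V and output projections to you. It assumes inputs of shape `(batch, seq_len, d)` with `d % heads == 0`, and `split_heads` and `merge_heads` are hypothetical helper names.

```python
import torch


def split_heads(x: torch.Tensor, heads: int) -> torch.Tensor:
    """(batch, seq_len, d) -> (batch, heads, seq_len, d // heads)."""
    batch, seq_len, d = x.shape
    assert d % heads == 0, "d must be divisible by the number of heads"
    return x.reshape(batch, seq_len, heads, d // heads).transpose(1, 2)


def merge_heads(x: torch.Tensor) -> torch.Tensor:
    """(batch, heads, seq_len, d_head) -> (batch, seq_len, heads * d_head)."""
    batch, heads, seq_len, d_head = x.shape
    # transpose makes the tensor non-contiguous, so reshape (not view) is used
    return x.transpose(1, 2).reshape(batch, seq_len, heads * d_head)


x = torch.randn(2, 10, 36)
h = split_heads(x, heads=4)
print(h.shape)  # torch.Size([2, 4, 10, 9])
print(torch.equal(merge_heads(h), x))  # True
```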

In [None]:
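class Attention(nn.Module):
    """Scaled dot-product attention; d is the per-head query/key dimension."""
    def __init__(self, d):
        super().__init__()
        self.scale = 1 / np.sqrt(d)
    
    def forward(self, Q, K, V):
        # Q, K, V: (..., seq_len, d); scores: (..., seq_len, seq_len)
        scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale
        # Softmax over keys: each query's weights sum to 1
        weights = torch.softmax(scores, dim=-1)
        return torch.matmul(weights, V), weights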

class SelfAttention(nn.Module):
    
    def __init__(self, vocab_size=10, d=64, heads=4, max_len=64):
        # Implement the initialization

        assert False, 'Not implemented yet'

    
    def forward(self, x):
        # Implement the forward

        assert False, 'Not implemented yet'


## Training the Model on the Toy Dataset

Now that the model has been defined, let's train it on our toy dataset.

**Training details:**
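- **Data**: 3000 generated sequences of length 10 over 4 token values, split 80/20 into training and validation sets
- **Loss function**: Binary cross-entropy with logits (`BCEWithLogitsLoss`), suitable for per-token binary classification; it applies the sigmoid internally, so the model outputs raw logits
- **Optimizer**: AdamW with learning rate 1e-3 and weight decay 1e-3 for regularization
- **Metric**: Token-level accuracy (fraction of positions classified correctly)
- **Epochs**: 50 passes over the training set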

**What to observe:**
- Does the model achieve high accuracy (>90%)?
- How quickly does it converge?
- Is there overfitting (train accuracy >> val accuracy)?
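
Accuracy is measured per token, and the classes are imbalanced, so it helps to know what the trivial "always predict 0" classifier scores. Let $V$ be `vocab_size` and $L$ the sequence length. Tokens $0$ and $V-1$ can never be labeled 1. Now take any other value at position $i$, assuming tokens are drawn uniformly and independently, as `torch.randint` does here. The probability that $t_i - 1$ occurs among the $i$ earlier tokens and $t_i + 1$ among the $L - 1 - i$ later ones is $(1 - q^{i})(1 - q^{L-1-i})$, with $q = 1 - 1/V$. The two factors multiply because the two windows do not overlap. Averaging over positions gives:

$$P(\text{label}=1) = \frac{V-2}{V}\cdot\frac{1}{L}\sum_{i=0}^{L-1}\left(1-q^{i}\right)\left(1-q^{L-1-i}\right)$$

`positive_rate` is our name for a direct transcription of this formula:

```python
def positive_rate(vocab_size: int, seq_length: int) -> float:
    """Expected fraction of tokens labeled 1 under uniform i.i.d. tokens."""
    q = 1 - 1 / vocab_size  # probability that a token differs from a given value
    mean_over_positions = sum(
        (1 - q**i) * (1 - q ** (seq_length - 1 - i)) for i in range(seq_length)
    ) / seq_length
    # Only values 1 .. vocab_size - 2 can have both neighbours in range
    return (vocab_size - 2) / vocab_size * mean_over_positions


p = positive_rate(vocab_size=4, seq_length=10)
print(f"P(label=1) = {p:.3f}, always-0 accuracy = {1 - p:.3f}")
# P(label=1) = 0.160, always-0 accuracy = 0.840
```

With the settings used below, about 84% accuracy is therefore reachable without looking at the input. Judge the training curves against that baseline rather than against 50%.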

The training loop below will show progress and plot accuracy curves.

In [None]:
vocab_size = 4
seq_length = 10

# d must be divisible by heads (here 36 // 4 = 9 dimensions per head)
model = SelfAttention(vocab_size=vocab_size, d=36, heads=4, max_len=seq_length)
model.to(device)

dataset = TokenMatchingDataset(num_sequences=3000, seq_length=seq_length, vocab_size=vocab_size, seed=42)
train_set, val_set = random_split(dataset, [0.8, 0.2])

train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
val_loader = DataLoader(val_set, batch_size=32, shuffle=False)

criterion = nn.BCEWithLogitsLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-3)

num_epochs = 50
train_accs = []
val_accs = []
pb = tqdm(range(num_epochs))

for epoch in pb:
    # Training
    model.train()
    correct, total = 0, 0
    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False):
        sequences = batch['sequence'].to(device)
        labels = batch['labels'].float().to(device)
        
        logits, _ = model(sequences)  # logits: (batch, seq_len); attention weights unused here
        loss = criterion(logits, labels)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        preds = (torch.sigmoid(logits) > 0.5).long()
        correct += (preds == labels.long()).sum().item()
        total += labels.numel()
    
    train_acc = correct / total
    train_accs.append(train_acc)
    
    # Validation
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for batch in val_loader:
            sequences = batch['sequence'].to(device)
            labels = batch['labels'].float().to(device)
            logits, _ = model(sequences)
            preds = (torch.sigmoid(logits) > 0.5).long()
            correct += (preds == labels.long()).sum().item()
            total += labels.numel()
    
    val_acc = correct / total
    val_accs.append(val_acc)
    
    pb.set_description(f"Epoch {epoch+1}: Train Acc={train_acc:.4f}, Val Acc={val_acc:.4f}")

# Plot
plt.figure(figsize=(10, 5))
plt.plot(train_accs, label='Train Acc')
plt.plot(val_accs, label='Val Acc')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

## Interpreting the Results

**Question:** Can you explain what is happening by looking at the attention
heads?

**To investigate:**
1. Look at the attention weights from the trained model
2. Examine what patterns each head has learned:
   - Does a head attend to specific token values or positions?
   - Can you identify heads that look for predecessors (t-1) vs successors
     (t+1)?
3. Try visualizing attention for specific examples from the dataset
4. Consider: How does the model solve this task using only attention?
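
One way to look for predecessor and successor heads is to average the attention weights by token *value* rather than by position. For each head, build a $V \times V$ table whose entry $(a, b)$ is the average attention that a query with value $a$ puts on keys with value $b$. A head that looks for $t_i - 1$ should show large entries just below the diagonal ($b = a - 1$), and a successor head just above it ($b = a + 1$). The sketch assumes that your `SelfAttention.forward` returns attention weights of shape `(batch, heads, seq_len, seq_len)` as its second output. `value_attention_profile` is a hypothetical helper. Because the table ignores positions, it cannot tell "before" from "after" by itself, so pair it with per-example heatmaps of the raw weights.

```python
import torch
import torch.nn.functional as F


def value_attention_profile(sequences, weights, vocab_size):
    """Average attention from query value a to key value b, per head.

    sequences: (batch, seq_len) long tensor of token ids
    weights:   (batch, heads, seq_len, seq_len), rows = queries summing to 1
    returns:   (heads, vocab_size, vocab_size)
    """
    onehot = F.one_hot(sequences, vocab_size).to(weights.dtype)  # (B, L, V)
    # mass[b, h, i, v]: attention from query i to all keys holding value v
    mass = torch.einsum("bhij,bjv->bhiv", weights, onehot)
    # Sum those rows grouped by the query's own value a
    totals = torch.einsum("bhiv,bia->hav", mass, onehot)  # (H, V, V)
    counts = onehot.sum(dim=(0, 1)).clamp(min=1)  # number of queries per value a
    return totals / counts.view(1, -1, 1)


# Usage with the trained model (assumes forward returns (logits, weights)):
# batch = next(iter(val_loader))
# seqs = batch['sequence'].to(device)
# with torch.no_grad():
#     _, attn = model(seqs)
# profile = value_attention_profile(seqs, attn, vocab_size)
# pred_score = profile.diagonal(offset=-1, dim1=1, dim2=2).mean(-1)  # b = a - 1
# succ_score = profile.diagonal(offset=1, dim1=1, dim2=2).mean(-1)   # b = a + 1
```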

In [None]:
#
# Study the attention heads

assert False, 'Not implemented yet'
