# Understanding Transformer Architecture

This notebook provides a hands-on exploration of the transformer architecture, the foundation of modern LLMs such as GPT and Llama. Through visualization and code exploration, we'll build an intuition for why transformers have revolutionized NLP.

## Learning Objectives

By the end of this notebook, you'll understand:
- The key components of transformer architecture
- How self-attention works and why it's revolutionary
- How positional encoding enables sequence processing
- The role of multi-head attention
- How to inspect and visualize these components in real models
- The differences between encoder-only, decoder-only, and encoder-decoder architectures

Let's start by setting up our environment.

In [None]:
# Install required packages if not already installed
%pip install transformers torch matplotlib pandas seaborn numpy plotly


In [None]:
# Import necessary libraries
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from transformers import AutoModel, AutoModelForCausalLM, AutoConfig, AutoTokenizer
from torch.nn import functional as F

# Set styles for better visualization
plt.style.use('ggplot')
sns.set(style="whitegrid")

# Ensure plots appear in the notebook
%matplotlib inline

# Helpers for displaying rich HTML output in the notebook
from IPython.display import display, HTML

## 1. The Evolution to Transformers

Before diving into transformers, let's understand why they were such a breakthrough:

1. **RNNs/LSTMs (Pre-2017)**: 
   - Processed text sequentially (one word at a time)
   - Suffered from vanishing gradients with long sequences
   - Limited parallel processing capabilities

2. **Convolutional Networks for NLP**:
   - Better parallelization than RNNs
   - Limited receptive field (context window)

3. **Transformers (2017+)**:
   - Process entire sequences simultaneously
   - Capture long-range dependencies through attention
   - Highly parallelizable training

The key innovation of transformers was to rely entirely on the **attention mechanism**, dropping recurrence and convolution, so that each word can directly "attend" to every other word in a sequence, regardless of how far apart they are.


## 2. Exploring Pre-trained Transformer Models

Let's start by loading some pre-trained transformer models and examining their architecture. We'll look at a few different types:

1. **GPT-2**: A decoder-only transformer (for text generation)
2. **BERT**: An encoder-only transformer (for understanding text)
3. **T5**: An encoder-decoder transformer (for translation/summarization)

In [None]:
# Function to explore a model's configuration and properties
def explore_model(model_name):
    """
    Load and explore a pre-trained model's configuration and architecture.
    
    Args:
        model_name (str): Name of the HuggingFace model to load
    
    Returns:
        tuple: The loaded model and its configuration
    """
    print(f"Exploring model: {model_name}")
    print("-" * 50)
    
    # Load model configuration first to examine parameters
    config = AutoConfig.from_pretrained(model_name)
    print(f"Model type: {config.model_type}")
    print(f"Model architecture: {config.architectures if hasattr(config, 'architectures') else 'Not specified'}")
    
    # Print key model dimensions
    if hasattr(config, "hidden_size"):
        print(f"Hidden size: {config.hidden_size}")
    if hasattr(config, "num_hidden_layers"):
        print(f"Number of layers: {config.num_hidden_layers}")
    if hasattr(config, "num_attention_heads"):
        print(f"Number of attention heads: {config.num_attention_heads}")
    if hasattr(config, "intermediate_size"):
        print(f"Intermediate (feed-forward) size: {config.intermediate_size}")
    
    # Print model's vocabulary size
    if hasattr(config, "vocab_size"):
        print(f"Vocabulary size: {config.vocab_size}")
    
    # Calculate theoretical parameter count from config
    try:
        if config.model_type == "gpt2":
            # Parameter calculation for GPT-2 style models
            hidden = config.hidden_size
            # GPT2Config has no intermediate_size; n_inner is None when the 4x default applies
            ffn_dim = getattr(config, "n_inner", None) or 4 * hidden
            embed_params = config.vocab_size * hidden
            pos_embed_params = config.max_position_embeddings * hidden
            attn_params = 4 * hidden * hidden + 4 * hidden        # Q, K, V, output projections + biases
            ffn_params = 2 * hidden * ffn_dim + ffn_dim + hidden  # Two linear layers + biases
            norm_params = 4 * hidden                              # Two LayerNorms per block
            layer_params = attn_params + ffn_params + norm_params
            total_layer_params = layer_params * config.num_hidden_layers
            final_norm_params = 2 * hidden
            approx_total = embed_params + pos_embed_params + total_layer_params + final_norm_params
            print(f"Approximate parameter count from config: {approx_total:,}")
        else:
            print("Parameter calculation from config not implemented for this model type")
    except (AttributeError, TypeError):
        print("Could not estimate parameters from config")
    
    # Load the actual model
    print("\nLoading model weights to count actual parameters...")
    model = AutoModel.from_pretrained(model_name)
    
    # Count actual parameters
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Actual parameter count: {total_params:,}")
    
    # Print trainable vs non-trainable parameters
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Trainable parameters: {trainable_params:,}")
    print(f"Non-trainable parameters: {total_params - trainable_params:,}")
    
    print("\nModel architecture summary:")
    # Print the first level of modules
    for name, module in model.named_children():
        print(f"  {name}: {module.__class__.__name__}")
        
        # For the first layer, print more details
        if "encoder" in name or "decoder" in name or name in ("transformer", "h"):
            # Try to find the first layer (BERT: `layer`, T5: `block`, GPT-2: `h`)
            first_layer = None
            if isinstance(module, torch.nn.ModuleList) and len(module) > 0:
                # AutoModel("gpt2") exposes its blocks directly as the ModuleList `h`
                first_layer = module[0]
            else:
                for attr in ("layer", "block", "h", "layers"):
                    blocks = getattr(module, attr, None)
                    if isinstance(blocks, torch.nn.ModuleList) and len(blocks) > 0:
                        first_layer = blocks[0]
                        break
                
            if first_layer is not None:
                print(f"\nExamining the first layer:")
                for sub_name, sub_module in first_layer.named_children():
                    print(f"    {sub_name}: {sub_module.__class__.__name__}")
                    # For attention module, print even more details
                    if "attn" in sub_name or "attention" in sub_name:
                        for attn_name, attn_module in sub_module.named_children():
                            print(f"      {attn_name}: {attn_module.__class__.__name__}")
    
    return model, config

In [None]:
# Explore a decoder-only transformer (GPT-2)
gpt2_model, gpt2_config = explore_model("gpt2")

In [None]:
# Explore an encoder-only transformer (BERT)
bert_model, bert_config = explore_model("bert-base-uncased")

In [None]:
# Explore an encoder-decoder transformer (T5)
t5_model, t5_config = explore_model("t5-small")
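
As a rough cross-check on the counts that `explore_model` prints, the parameter total of a GPT-2-style decoder can be worked out from five numbers. The sketch below is dependency-free and assumes the standard GPT-2 small dimensions (vocabulary 50,257; 1,024 positions; hidden size 768; 12 layers; feed-forward size 3,072). The helper name `gpt2_style_param_count` is hypothetical and not part of `transformers`.

```python
def gpt2_style_param_count(vocab_size, n_positions, hidden, n_layers, ffn_dim):
    """Count parameters of a GPT-2-style decoder stack (no separate LM head)."""
    embeddings = vocab_size * hidden + n_positions * hidden
    # Attention: fused Q/K/V projection plus output projection, each with biases
    attention = (hidden * 3 * hidden + 3 * hidden) + (hidden * hidden + hidden)
    # Feed-forward: expand to ffn_dim, then project back, each with biases
    feed_forward = (hidden * ffn_dim + ffn_dim) + (ffn_dim * hidden + hidden)
    # Two LayerNorms per block, each with a weight and a bias vector
    layer_norms = 2 * 2 * hidden
    per_layer = attention + feed_forward + layer_norms
    final_norm = 2 * hidden
    return embeddings + n_layers * per_layer + final_norm

# Assumed GPT-2 small dimensions
total = gpt2_style_param_count(50257, 1024, 768, 12, 3072)
print(f"{total:,}")  # 124,439,808
```

Under these assumptions, roughly a third of the parameters sit in the embedding tables and the rest in the twelve transformer blocks.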

## 3. Understanding Self-Attention

The heart of the transformer is the **self-attention mechanism**. This is what allows transformers to weigh the importance of words in relation to each other.

Let's visualize and understand how self-attention works:


In [None]:
# Function to visualize self-attention
def visualize_self_attention(text, model_name="gpt2"):
    """
    Tokenize a text, run it through a model, and visualize the self-attention patterns.
    
    Args:
        text (str): Input text to analyze
        model_name (str): Model to use for analysis
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
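    # Eager attention is requested explicitly because fused attention kernels may not return weights
    model = AutoModel.from_pretrained(model_name, output_attentions=True, attn_implementation="eager")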
    
    # Tokenize the input text
    inputs = tokenizer(text, return_tensors="pt")
    
    # Get the token IDs and convert them back to tokens for display
    input_ids = inputs["input_ids"][0]
    tokens = [tokenizer.decode([token_id]) for token_id in input_ids]
    
    # Run through the model and get attention weights
    with torch.no_grad():
        outputs = model(**inputs)
    
    # Get attention weights: a tuple with one tensor per layer,
    # each of shape [batch, heads, seq_len, seq_len]
    attention = outputs.attentions
    
    # Print model information
    print(f"Model: {model_name}")
    print(f"Number of layers: {len(attention)}")
    print(f"Number of attention heads: {attention[0].shape[1]}")
    print(f"Input sequence length: {len(tokens)}")
    
    # Select a specific layer and head to visualize
    layer_idx = 0  # First layer
    head_idx = 0   # First attention head
    
    # Create attention heatmap for the selected layer and head
    attention_weights = attention[layer_idx][0, head_idx].numpy()
    
    # Create a DataFrame for the heatmap
    df = pd.DataFrame(attention_weights, index=tokens, columns=tokens)
    
    # Plot the heatmap
    plt.figure(figsize=(10, 8))
    sns.heatmap(df, annot=True, cmap="YlGnBu", fmt=".2f")
    plt.title(f"Self-Attention Weights (Layer {layer_idx+1}, Head {head_idx+1})")
    plt.ylabel("Query Tokens")
    plt.xlabel("Key Tokens")
    plt.tight_layout()
    plt.show()
    
    # Let's also create an interactive plot with Plotly
    fig = px.imshow(
        attention_weights,
        x=tokens,
        y=tokens,
        color_continuous_scale='Viridis',
        labels=dict(x="Key Tokens", y="Query Tokens", color="Attention Weight"),
        title=f"Self-Attention Weights (Layer {layer_idx+1}, Head {head_idx+1})"
    )
    
    # Customize layout
    fig.update_layout(
        width=700,
        height=600,
        xaxis=dict(side="top"),
    )
    
    # Show the interactive plot
    fig.show()
    
    # Create a function to explore different layers and heads
    def explore_attention(layer=0, head=0):
        attention_weights = attention[layer][0, head].numpy()
        df = pd.DataFrame(attention_weights, index=tokens, columns=tokens)
        
        plt.figure(figsize=(10, 8))
        sns.heatmap(df, annot=True, cmap="YlGnBu", fmt=".2f")
        plt.title(f"Self-Attention Weights (Layer {layer+1}, Head {head+1})")
        plt.ylabel("Query Tokens")
        plt.xlabel("Key Tokens")
        plt.tight_layout()
        plt.show()
    
    # Return the explore function for further exploration
    return explore_attention, attention, tokens

In [None]:
# Let's visualize self-attention patterns for a simple sentence
explore_fn, attention_weights, tokens = visualize_self_attention(
    "The transformer architecture revolutionized natural language processing."
)

In [None]:
# Look at different layers and heads
explore_fn(layer=5, head=3)  # Middle layer, different head

In [None]:
# Let's also look at the last layer
explore_fn(layer=11, head=0)  # Last layer (for GPT-2), first head

### Understanding Self-Attention Computation

Now let's break down the mathematical computation behind self-attention. Self-attention involves three main steps:

1. **Computing Query, Key, and Value vectors** from input embeddings
2. **Calculating attention scores** between all tokens
3. **Creating weighted representations** by aggregating values according to attention scores
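
In matrix form, the three steps above are usually written as the scaled dot-product attention formula from the original transformer paper. Here $X$ denotes the matrix of input embeddings and $W_Q$, $W_K$, $W_V$ the learned projection matrices (the symbol $X$ is introduced for this sketch):

```latex
\begin{aligned}
Q &= X W_Q, \qquad K = X W_K, \qquad V = X W_V \\
\mathrm{Attention}(Q, K, V) &= \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V
\end{aligned}
```

The softmax is applied row by row, so each query token gets a set of weights over all key tokens that sums to 1.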

Let's implement these steps manually to understand the process:


In [None]:
# Implement self-attention from scratch
def self_attention_from_scratch(input_embeddings, d_k=64):
    """
    Implement vanilla self-attention from scratch.
    
    Args:
        input_embeddings: Input token embeddings [batch_size, seq_len, embedding_dim]
        d_k: Dimensionality of query and key vectors
    
    Returns:
        context_vectors: Attention-weighted outputs
        attention_weights: Attention weight matrix
    """
    # For simplicity, we'll use random weight matrices
    batch_size, seq_len, d_model = input_embeddings.shape
    
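    # 1. Create random weight matrices for Q, K, V projections
    # (scaled by 1/sqrt(d_model) so the scores stay in a range where softmax is not one-hot)
    W_Q = torch.randn(d_model, d_k) / d_model ** 0.5
    W_K = torch.randn(d_model, d_k) / d_model ** 0.5
    W_V = torch.randn(d_model, d_model) / d_model ** 0.5  # Kept at d_model here; real heads use a smaller d_v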
    
    # 2. Create Query, Key, Value projections
    Q = torch.matmul(input_embeddings, W_Q)  # [batch_size, seq_len, d_k]
    K = torch.matmul(input_embeddings, W_K)  # [batch_size, seq_len, d_k]
    V = torch.matmul(input_embeddings, W_V)  # [batch_size, seq_len, d_model]
    
    # 3. Calculate attention scores
    scores = torch.matmul(Q, K.transpose(-2, -1))  # [batch_size, seq_len, seq_len]
    
    # 4. Scale the scores (to prevent softmax saturation)
    scores = scores / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
    
    # 5. Apply softmax to get attention weights
    attention_weights = F.softmax(scores, dim=-1)  # [batch_size, seq_len, seq_len]
    
    # 6. Apply attention weights to values
    context_vectors = torch.matmul(attention_weights, V)  # [batch_size, seq_len, d_model]
    
    return context_vectors, attention_weights
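
To see why step 4 divides by the square root of `d_k`, it can help to run softmax on the same scores with and without scaling. This is a minimal, dependency-free sketch; the score values are hypothetical and chosen only to make the effect visible.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

d_k = 64
raw_scores = [16.0, 8.0, 0.0]  # Hypothetical unscaled Q·K^T values for one query
scaled_scores = [s / math.sqrt(d_k) for s in raw_scores]  # [2.0, 1.0, 0.0]

unscaled_weights = softmax(raw_scores)
scaled_weights = softmax(scaled_scores)

print([round(w, 3) for w in unscaled_weights])  # [1.0, 0.0, 0.0]
print([round(w, 3) for w in scaled_weights])    # [0.665, 0.245, 0.09]
```

Without scaling, almost all of the weight lands on one key, which leaves very small gradients for the others. With scaling, the weights stay spread out enough for the model to learn from them.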


In [None]:
# Create some toy token embeddings and visualize self-attention
def visualize_self_attention_computation():
    # Create toy embeddings for a 5-token sequence
    seq_len = 5
    d_model = 64
    batch_size = 1
    
    # Random embeddings for demonstration
    embeddings = torch.randn(batch_size, seq_len, d_model)
    
    # Run our self-attention implementation
    context_vectors, attention_weights = self_attention_from_scratch(embeddings)
    
    # Visualize the attention weights
    attention_matrix = attention_weights[0].detach().numpy()  # Remove batch dimension
    
    # Create a heatmap
    plt.figure(figsize=(8, 6))
    sns.heatmap(attention_matrix, annot=True, cmap="YlGnBu", fmt=".2f")
    plt.title("Self-Attention Weights (From Scratch Implementation)")
    plt.xlabel("Key Positions")
    plt.ylabel("Query Positions")
    plt.tight_layout()
    plt.show()
    
    # Visualize the query-key-value computation
    fig = go.Figure()
    
    # Add steps for the self-attention computation
    steps = [
        "Token Embeddings [seq_len, d_model]",
        "Query Projection (Q) [seq_len, d_k]",
        "Key Projection (K) [seq_len, d_k]",
        "Value Projection (V) [seq_len, d_model]",
        "Attention Scores = Q·K^T [seq_len, seq_len]",
        "Scaled Attention = Scores/√d_k",
        "Attention Weights = softmax(Scaled Attention)",
        "Output = Weights·V [seq_len, d_model]"
    ]
    
    # Create a visualization of the self-attention process
    fig.add_trace(go.Scatter(
        x=[1, 2, 3, 4, 5, 6, 7, 8],
        y=[1, 1, 1, 1, 1, 1, 1, 1],
        mode="markers+text",
        marker=dict(size=15, color="blue"),
        text=steps,
        textposition="top center"
    ))
    
    # Add arrows to show the flow
    for i in range(len(steps)-1):
        fig.add_annotation(
            x=i+2, y=1,    # Arrowhead at the next step
            ax=i+1, ay=1,  # Tail at the current step
            xref="x", yref="y",
            axref="x", ayref="y",
            showarrow=True,
            arrowhead=3,
            arrowsize=1.5,
            arrowwidth=2,
            arrowcolor="red"
        )
    
    # Update layout
    fig.update_layout(
        title="Self-Attention Computation Flow",
        xaxis=dict(showticklabels=False, showgrid=False, zeroline=False),
        yaxis=dict(showticklabels=False, showgrid=False, zeroline=False),
        showlegend=False,
        width=900,
        height=300,
        plot_bgcolor="white"
    )
    
    fig.show()
    
    return attention_weights


In [None]:
# Run the visualization
attention_weights = visualize_self_attention_computation()

## 4. Positional Encoding: How Transformers Understand Order

Unlike RNNs, transformers process all tokens simultaneously, and self-attention by itself has no notion of token order. **Positional encoding** solves this by adding position information to token embeddings.
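
One common choice is the sinusoidal scheme from the original transformer paper, in which position $pos$ and dimension pair index $i$ determine the encoding as follows:

```latex
\begin{aligned}
PE_{(pos,\,2i)}   &= \sin\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right) \\
PE_{(pos,\,2i+1)} &= \cos\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)
\end{aligned}
```

Not every model uses this scheme: GPT-2 and BERT, for example, learn their position embeddings during training instead.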

Let's visualize how positional encodings work:

In [None]:
# Implement and visualize positional encoding
def positional_encoding(seq_len, d_model):
    """
    Compute sinusoidal positional encodings as used in the original transformer paper.
    
    Args:
        seq_len: Maximum sequence length
        d_model: Dimensionality of the model embeddings
    
    Returns:
        pos_encoding: Positional encodings [seq_len, d_model]
    """
    # Create empty positional encoding matrix
    pos_encoding = torch.zeros(seq_len, d_model)
    
    # Create position indices
    positions = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
    
    # Calculate frequencies for sine/cosine functions
    div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))
    
    # Apply sine to even indices
    pos_encoding[:, 0::2] = torch.sin(positions * div_term)
    
    # Apply cosine to odd indices
    pos_encoding[:, 1::2] = torch.cos(positions * div_term)
    
    return pos_encoding


In [None]:
# Visualize positional encodings
def visualize_positional_encoding():
    # Generate positional encodings
    seq_len = 30
    d_model = 64
    pos_enc = positional_encoding(seq_len, d_model)
    
    # Plot as a heatmap
    plt.figure(figsize=(10, 6))
    sns.heatmap(pos_enc, cmap="coolwarm")
    plt.title("Positional Encodings")
    plt.xlabel("Embedding Dimension")
    plt.ylabel("Position in Sequence")
    plt.tight_layout()
    plt.show()
    
    # Plot a few dimensions to show the wavelength patterns
    plt.figure(figsize=(12, 6))
    
    # Plot several dimensions
    dims_to_plot = [0, 1, 2, 3, 15, 31, 47, 63]
    for i, dim in enumerate(dims_to_plot):
        plt.subplot(2, 4, i+1)
        plt.plot(pos_enc[:, dim])
        plt.title(f"Dimension {dim}")
        plt.xlabel("Position")
        plt.grid(True)
    
    plt.tight_layout()
    plt.show()
    
    # Explain why this works
    print("Why Sinusoidal Positional Encoding Works:")
    print("1. Each position has a unique encoding pattern")
    print("2. For a fixed offset k, the encoding at pos + k is a linear function of the encoding at pos")
    print("3. This may make it easier for the model to learn to attend by relative position")
    print("4. Lower dimensions (higher frequencies) help distinguish nearby positions")
    print("5. Higher dimensions (lower frequencies) capture long-range patterns")

In [None]:
# Run the visualization
visualize_positional_encoding()
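
One way to probe the relative-position property is to check that the dot product between two sinusoidal encodings depends only on the distance between the positions, not on where they sit in the sequence. The sketch below re-implements the encoding with the standard library only; `sinusoidal_encoding` and `dot` are hypothetical helper names for this illustration.

```python
import math

def sinusoidal_encoding(pos, d_model):
    """Sinusoidal encoding for a single position as a plain Python list."""
    encoding = []
    for i in range(0, d_model, 2):
        angle = pos / (10000 ** (i / d_model))
        encoding.append(math.sin(angle))  # Even index
        encoding.append(math.cos(angle))  # Odd index
    return encoding

def dot(a, b):
    """Dot product of two equal-length lists."""
    return sum(x * y for x, y in zip(a, b))

d_model = 16
# Two pairs of positions with the same offset (4) at different absolute positions
near_start = dot(sinusoidal_encoding(3, d_model), sinusoidal_encoding(7, d_model))
further_on = dot(sinusoidal_encoding(20, d_model), sinusoidal_encoding(24, d_model))

print(math.isclose(near_start, further_on, rel_tol=1e-9, abs_tol=1e-9))  # True
```

This follows from the identity sin(a)sin(b) + cos(a)cos(b) = cos(a - b): each sine/cosine pair contributes a term that depends only on the offset.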


## 5. Multi-Head Attention: Looking from Different Perspectives

Instead of a single attention mechanism, transformers use **multi-head attention**. Each head can focus on different aspects of the relationships between tokens.
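
In practice the heads are usually not separate networks: each token's projected vector is split into equal slices, every head runs attention on its own slice, and the results are concatenated again before an output projection. The sketch below shows only the split and merge bookkeeping, assuming GPT-2 small dimensions (hidden size 768, 12 heads); the helper names are hypothetical.

```python
def split_into_heads(vector, num_heads):
    """Split one token's projected vector into equal-sized per-head slices."""
    if len(vector) % num_heads != 0:
        raise ValueError("vector length must be divisible by num_heads")
    head_dim = len(vector) // num_heads
    return [vector[h * head_dim:(h + 1) * head_dim] for h in range(num_heads)]

def merge_heads(heads):
    """Concatenate per-head slices back into a single vector."""
    return [value for head in heads for value in head]

# Assumed GPT-2 small dimensions: hidden size 768, 12 heads -> 64 dimensions per head
hidden_size, num_heads = 768, 12
token_vector = list(range(hidden_size))  # Stand-in for one token's query vector

heads = split_into_heads(token_vector, num_heads)
print(len(heads), len(heads[0]))  # 12 64
```

Because the hidden size is shared out among the heads, using more heads does not by itself add parameters; it changes how the same dimensions are partitioned.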

Let's visualize how different attention heads focus on different patterns:


In [None]:
# Function to extract and visualize multi-head attention
def visualize_multi_head_attention(text, model_name="gpt2", layer_idx=0):
    """
    Visualize how different attention heads focus on different patterns.
    
    Args:
        text (str): Input text to analyze
        model_name (str): Model to use for analysis
        layer_idx (int): Which layer to analyze
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
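    # Eager attention is requested explicitly because fused attention kernels may not return weights
    model = AutoModel.from_pretrained(model_name, output_attentions=True, attn_implementation="eager")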
    
    # Tokenize the input text
    inputs = tokenizer(text, return_tensors="pt")
    
    # Get the token IDs and convert them back to tokens for display
    input_ids = inputs["input_ids"][0]
    tokens = [tokenizer.decode([token_id]) for token_id in input_ids]
    
    # Run through the model and get attention weights
    with torch.no_grad():
        outputs = model(**inputs)
    
    # Get attention weights for the selected layer
    attention = outputs.attentions[layer_idx][0]  # [heads, seq_len, seq_len]
    
    # Get number of heads
    num_heads = attention.shape[0]
    
    # Create a figure for all attention heads
    fig, axes = plt.subplots(2, (num_heads+1)//2, figsize=(15, 6))
    axes = axes.flatten()
    
    # Plot each attention head
    for head_idx in range(num_heads):
        attention_weights = attention[head_idx].numpy()
        sns.heatmap(attention_weights, annot=False, cmap="YlGnBu", 
                   ax=axes[head_idx], xticklabels=False, yticklabels=False)
        axes[head_idx].set_title(f"Head {head_idx+1}")
    
    # Remove any unused subplots
    for i in range(num_heads, len(axes)):
        fig.delaxes(axes[i])
    
    plt.tight_layout()
    plt.suptitle(f"Multi-Head Attention Patterns (Layer {layer_idx+1})", fontsize=16)
    plt.subplots_adjust(top=0.88)
    plt.show()
    
    # Now create a more detailed visualization of a few selected heads
    heads_to_show = min(4, num_heads)
    
    fig, axes = plt.subplots(heads_to_show, 1, figsize=(10, 12))
    if heads_to_show == 1:
        axes = [axes]
    
    for i in range(heads_to_show):
        head_idx = i
        attention_weights = attention[head_idx].numpy()
        
        # Create a DataFrame for better labeling
        df = pd.DataFrame(attention_weights, index=tokens, columns=tokens)
        
        # Plot the heatmap
        sns.heatmap(df, annot=True, cmap="YlGnBu", fmt=".2f", ax=axes[i])
        axes[i].set_title(f"Head {head_idx+1} Attention Pattern")
        axes[i].set_ylabel("Query Tokens")
        if i == heads_to_show - 1:
            axes[i].set_xlabel("Key Tokens")
        else:
            axes[i].set_xlabel("")
    
    plt.tight_layout()
    plt.suptitle(f"Detailed Multi-Head Attention (Layer {layer_idx+1})", fontsize=16)
    plt.subplots_adjust(top=0.95)
    plt.show()
    
    # Analyze patterns - let's examine what each head might be focusing on
    print("Potential attention patterns:")
    
    # Compute some metrics for each head
    for head_idx in range(num_heads):
        weights = attention[head_idx].numpy()
        
        # Calculate diagonal attention (self-attention)
        diagonal_attn = np.mean(np.diag(weights))
        
        # Calculate local attention (attention to neighboring tokens)
        local_attn = 0
        for i in range(len(tokens)):
            for j in range(max(0, i-1), min(len(tokens), i+2)):
                if i != j:  # Exclude the diagonal
                    local_attn += weights[i, j]
        local_attn /= (len(tokens) * 2 - 2)  # Normalize
        
        # Check for attention to the first token (often special tokens)
        first_token_attn = np.mean(weights[:, 0])
        
        print(f"Head {head_idx+1}:")
        print(f"  - Self-attention strength: {diagonal_attn:.3f}")
        print(f"  - Local attention strength: {local_attn:.3f}")
        print(f"  - First token attention: {first_token_attn:.3f}")
        
        # Analyze potential patterns based on metrics
        if diagonal_attn > 0.5:
            print("  - This head seems to focus on the token itself (identity/content)")
        elif local_attn > 0.3:
            print("  - This head seems to focus on local relationships (syntax/phrases)")
        elif first_token_attn > 0.3:
            print("  - This head pays significant attention to the first token (global context)")
        else:
            max_col = np.argmax(np.mean(weights, axis=0))
            if max_col > 0:
                print(f"  - This head focuses on token '{tokens[max_col]}'")
        
        print()
    
    return attention, tokens

In [None]:
# Visualize multi-head attention for a more complex sentence
attention_weights, tokens = visualize_multi_head_attention(
    "The quick brown fox jumps over the lazy dog while the cat watches from a distance.",
    layer_idx=2  # An early layer (the third of GPT-2's 12)
)


## 6. The Feed-Forward Network: Processing Token Representations

After attention, each token's representation goes through a **feed-forward network** (FFN). This is a simple two-layer neural network applied to each position separately.

Let's examine the feed-forward network in transformers:

In [None]:
# Exploring the feed-forward network
def examine_feed_forward_network(model_name="gpt2"):
    """
    Explore the structure and role of feed-forward networks in transformers.
    
    Args:
        model_name (str): Model to examine
    """
    # Load model configuration
    config = AutoConfig.from_pretrained(model_name)
    
    # Print feed-forward network dimensions
    print(f"Model: {model_name}")
    print(f"Hidden size (embedding dimension): {config.hidden_size}")
    
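    # The FFN's expanded dimension goes by different names across configs
    # (intermediate_size for BERT, n_inner for GPT-2, d_ff for T5)
    ffn_dim = None
    for attr in ("intermediate_size", "n_inner", "d_ff"):
        ffn_dim = getattr(config, attr, None)
        if ffn_dim is not None:
            break
    if ffn_dim is None:
        # GPT-2 leaves n_inner unset and defaults to 4x the hidden size
        ffn_dim = 4 * config.hidden_size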
    
    print(f"Feed-forward intermediate dimension: {ffn_dim}")
    print(f"Expansion factor: {ffn_dim / config.hidden_size}x")
    
    # Visualize the feed-forward network structure
    plt.figure(figsize=(10, 6))
    
    # Define the layers
    layers = [
        {"name": "Input", "size": config.hidden_size},
        {"name": "Linear 1", "size": ffn_dim},
        {"name": "GELU", "size": ffn_dim},
        {"name": "Linear 2", "size": config.hidden_size},
        {"name": "Output", "size": config.hidden_size}
    ]
    
    # Draw the network
    max_size = max(ffn_dim, config.hidden_size)
    
    # Draw nodes
    for i, layer in enumerate(layers):
        size_factor = layer["size"] / max_size
        width = 2 + 8 * size_factor
        
        plt.plot([i, i], [0, width], 'b-', linewidth=3)
        plt.text(i, width/2, layer["name"], ha='center', va='center',
                rotation=90, fontsize=12, bbox=dict(facecolor='white', alpha=0.8))
        
        # Add size label
        plt.text(i, width + 0.5, f"{layer['size']}", ha='center')
    
    # Draw connections
    for i in range(len(layers)-1):
        from_size = layers[i]["size"] / max_size
        to_size = layers[i+1]["size"] / max_size
        
        from_width = 2 + 8 * from_size
        to_width = 2 + 8 * to_size
        
        # Draw multiple lines to show the connection width
        num_lines = 10
        for j in range(num_lines):
            from_y = j * from_width / (num_lines-1)
            to_y = j * to_width / (num_lines-1)
            plt.plot([i, i+1], [from_y, to_y], 'k-', alpha=0.1)
    
    # Customize the plot
    plt.title("Feed-Forward Network Structure in Transformer", fontsize=16)
    plt.xlabel("Layer")
    plt.ylabel("Width (proportional to dimension)")
    plt.grid(False)
    plt.xticks([])
    plt.yticks([])
    plt.tight_layout()
    plt.show()
    
    # Explain the role of FFN
    print("\nRole of the Feed-Forward Network:")
    print("1. Processes each token's representation independently")
    print("2. Expands the representation to a higher dimension, allowing more expressivity")
    print("3. Applies non-linear transformation (GELU/ReLU) to capture complex patterns")
    print("4. Projects back to the original dimension for residual connections")
    print("5. Acts as a position-wise fully connected layer")
    print("\nGeva et al. (2021) suggest that feed-forward layers act as 'key-value memories'")
    print("that store knowledge gained during training, functioning like a large associative memory.")

In [None]:
# Examine the feed-forward network
examine_feed_forward_network()
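
The "applied to each position separately" behaviour can be checked on a toy example. The sketch below is a dependency-free feed-forward block with hypothetical weights (hidden size 2, expanded size 4). It uses the exact GELU, whereas GPT-2 uses a tanh approximation, so treat it as an illustration rather than a reproduction of any particular model.

```python
import math

def gelu(x):
    """Exact GELU: x * Phi(x), with Phi the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def linear(vector, weights, bias):
    """Compute weights @ vector + bias for a weight matrix given as a list of rows."""
    return [sum(w * v for w, v in zip(row, vector)) + b for row, b in zip(weights, bias)]

def feed_forward(token_vector, w1, b1, w2, b2):
    """Expand, apply GELU, then project back to the input dimension."""
    hidden = [gelu(h) for h in linear(token_vector, w1, b1)]
    return linear(hidden, w2, b2)

# Toy weights (hypothetical values): hidden size 2, expanded size 4
w1 = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, -1.0]]
b1 = [0.0, 0.0, 0.0, 0.0]
w2 = [[0.5, 0.0, 0.5, 0.0], [0.0, 0.5, 0.0, 0.5]]
b2 = [0.0, 0.0]

sequence = [[1.0, 2.0], [0.0, -1.0], [3.0, 0.5]]
outputs = [feed_forward(tok, w1, b1, w2, b2) for tok in sequence]

# Reordering the tokens only reorders the outputs: no information crosses positions
reversed_outputs = [feed_forward(tok, w1, b1, w2, b2) for tok in reversed(sequence)]
print(reversed_outputs == outputs[::-1])  # True
```

Mixing information between positions is left entirely to the attention layers; the feed-forward block transforms each token's vector on its own.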

## 7. Putting It All Together: The Complete Transformer Architecture

Now let's understand how all these components fit together in the transformer architecture. We'll look at:
1. **Encoder-only models** (like BERT)
2. **Decoder-only models** (like GPT)
3. **Encoder-decoder models** (like T5)
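
A large part of the practical difference between the first two families is the attention mask. Encoder-only models let every token attend to every other token, while decoder-only models apply a causal mask so that position `i` sees only positions up to `i`. The sketch below illustrates this with hypothetical scores and the standard library only; `masked_softmax` is a name made up for this example.

```python
import math

def masked_softmax(scores, causal):
    """Row-wise softmax; with causal=True, position i ignores positions j > i."""
    weights = []
    for i, row in enumerate(scores):
        visible = [s if (not causal or j <= i) else -math.inf for j, s in enumerate(row)]
        peak = max(visible)
        exps = [math.exp(s - peak) for s in visible]
        total = sum(exps)
        weights.append([e / total for e in exps])
    return weights

# Hypothetical raw attention scores for a 3-token sequence
scores = [[1.0, 2.0, 3.0],
          [1.0, 1.0, 1.0],
          [0.0, 1.0, 2.0]]

encoder_style = masked_softmax(scores, causal=False)  # Bidirectional, as in BERT
decoder_style = masked_softmax(scores, causal=True)   # Causal, as in GPT

print(decoder_style[0])  # [1.0, 0.0, 0.0]
print(decoder_style[1])  # [0.5, 0.5, 0.0]
```

Encoder-decoder models combine both: the encoder attends bidirectionally, and the decoder uses causal self-attention plus cross-attention over the encoder's outputs.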

In [None]:
# Function to visualize the transformer architecture
def visualize_transformer_architecture():
    """
    Create visualizations of the three main transformer architecture types.
    """
    # Create a figure with three subplots
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 8))
    
    # Colors for different components
    colors = {
        "embedding": "lightblue",
        "positional": "lightyellow",
        "attention": "lightgreen",
        "ffn": "lightpink",
        "norm": "lightgray",
        "output": "lightcoral"
    }
    
    # 1. Encoder-only (BERT-like) architecture
    encoder_components = [
        {"name": "Token Embedding", "color": colors["embedding"]},
        {"name": "Position Embedding", "color": colors["positional"]},
        {"name": "Layer Norm", "color": colors["norm"]},
        {"name": "Self-Attention", "color": colors["attention"]},
        {"name": "Layer Norm", "color": colors["norm"]},
        {"name": "Feed-Forward", "color": colors["ffn"]},
        {"name": "Layer Norm", "color": colors["norm"]},
        {"name": "Output", "color": colors["output"]}
    ]
    
    ax1.set_xlim(0, 1)
    ax1.set_ylim(0, len(encoder_components) + 1)
    
    # Draw blocks for encoder
    for i, comp in enumerate(encoder_components):
        y_pos = len(encoder_components) - i
        rect = plt.Rectangle((0.2, y_pos - 0.4), 0.6, 0.8, facecolor=comp["color"])
        ax1.add_patch(rect)
        ax1.text(0.5, y_pos, comp["name"], ha='center', va='center', fontsize=10)
        
        # Add arrows
        if i < len(encoder_components) - 1:
            ax1.arrow(0.5, y_pos - 0.5, 0, -0.5, head_width=0.05, head_length=0.1, fc='black', ec='black')
    
    # Draw residual connections
    ax1.arrow(0.7, len(encoder_components) - 3.5, 0, -2, head_width=0.05, head_length=0.1, 
             fc='blue', ec='blue', linestyle='--')
    ax1.arrow(0.7, len(encoder_components) - 5.5, 0, -1, head_width=0.05, head_length=0.1, 
             fc='blue', ec='blue', linestyle='--')
    
    ax1.set_title("Encoder-Only Architecture\n(BERT, RoBERTa)", fontsize=14)
    ax1.axis('off')
    
    # 2. Decoder-only (GPT-like) architecture
    decoder_components = [
        {"name": "Token Embedding", "color": colors["embedding"]},
        {"name": "Position Embedding", "color": colors["positional"]},
        {"name": "Layer Norm", "color": colors["norm"]},
        {"name": "Masked Self-Attention", "color": colors["attention"]},
        {"name": "Layer Norm", "color": colors["norm"]},
        {"name": "Feed-Forward", "color": colors["ffn"]},
        {"name": "Layer Norm", "color": colors["norm"]},
        {"name": "Output", "color": colors["output"]}
    ]
    
    ax2.set_xlim(0, 1)
    ax2.set_ylim(0, len(decoder_components) + 1)
    
    # Draw blocks for decoder
    for i, comp in enumerate(decoder_components):
        y_pos = len(decoder_components) - i
        rect = plt.Rectangle((0.2, y_pos - 0.4), 0.6, 0.8, facecolor=comp["color"])
        ax2.add_patch(rect)
        ax2.text(0.5, y_pos, comp["name"], ha='center', va='center', fontsize=10)
        
        # Add arrows
        if i < len(decoder_components) - 1:
            ax2.arrow(0.5, y_pos - 0.5, 0, -0.5, head_width=0.05, head_length=0.1, fc='black', ec='black')
    
    # Draw residual connections
    ax2.arrow(0.7, len(decoder_components) - 3.5, 0, -2, head_width=0.05, head_length=0.1, 
             fc='blue', ec='blue', linestyle='--')
    ax2.arrow(0.7, len(decoder_components) - 5.5, 0, -1, head_width=0.05, head_length=0.1, 
             fc='blue', ec='blue', linestyle='--')
    
    ax2.set_title("Decoder-Only Architecture\n(GPT, Llama)", fontsize=14)
    ax2.axis('off')
    
    # 3. Encoder-decoder (T5-like) architecture
    encoder_decoder_components = [
        {"side": "encoder", "name": "Token Embedding", "color": colors["embedding"]},
        {"side": "encoder", "name": "Position Embedding", "color": colors["positional"]},
        {"side": "encoder", "name": "Self-Attention", "color": colors["attention"]},
        {"side": "encoder", "name": "Feed-Forward", "color": colors["ffn"]},
        {"side": "encoder", "name": "Output", "color": colors["norm"]},
        {"side": "decoder", "name": "Token Embedding", "color": colors["embedding"]},
        {"side": "decoder", "name": "Position Embedding", "color": colors["positional"]},
        {"side": "decoder", "name": "Masked Self-Attention", "color": colors["attention"]},
        {"side": "decoder", "name": "Cross-Attention", "color": "lightsalmon"},
        {"side": "decoder", "name": "Feed-Forward", "color": colors["ffn"]},
        {"side": "decoder", "name": "Output", "color": colors["output"]}
    ]
    
    ax3.set_xlim(0, 1)
    ax3.set_ylim(0, len(encoder_decoder_components) + 1)
    
    # Track the last positions for each side
    last_encoder_pos = None
    last_decoder_pos = None
    
    # Draw blocks for encoder-decoder
    encoder_count = 0
    decoder_count = 0
    
    for i, comp in enumerate(encoder_decoder_components):
        y_pos = len(encoder_decoder_components) - i
        
        if comp["side"] == "encoder":
            encoder_count += 1
            x_start = 0.1
            rect = plt.Rectangle((x_start, y_pos - 0.4), 0.3, 0.8, facecolor=comp["color"])
            ax3.add_patch(rect)
            ax3.text(x_start + 0.15, y_pos, comp["name"], ha='center', va='center', fontsize=9)
            
            # Track position for cross-attention connection
            if comp["name"] == "Output":
                last_encoder_pos = (x_start + 0.15, y_pos)
            
            # Add arrows for encoder
            if encoder_count > 1:
                ax3.arrow(x_start + 0.15, y_pos + 0.5, 0, -0.5, head_width=0.05, head_length=0.1, 
                         fc='black', ec='black')
        else:
            decoder_count += 1
            x_start = 0.6
            rect = plt.Rectangle((x_start, y_pos - 0.4), 0.3, 0.8, facecolor=comp["color"])
            ax3.add_patch(rect)
            ax3.text(x_start + 0.15, y_pos, comp["name"], ha='center', va='center', fontsize=9)
            
            # Track position for cross-attention connection
            if comp["name"] == "Cross-Attention":
                last_decoder_pos = (x_start + 0.15, y_pos)
            
            # Add arrows for decoder
            if decoder_count > 1:
                ax3.arrow(x_start + 0.15, y_pos + 0.5, 0, -0.5, head_width=0.05, head_length=0.1, 
                         fc='black', ec='black')
    
    # Draw the cross-attention connection
    if last_encoder_pos and last_decoder_pos:
        ax3.arrow(last_encoder_pos[0], last_encoder_pos[1], 
                 last_decoder_pos[0] - last_encoder_pos[0], 
                 last_decoder_pos[1] - last_encoder_pos[1], 
                 head_width=0.05, head_length=0.1, fc='red', ec='red', linestyle='-.')
    
    ax3.set_title("Encoder-Decoder Architecture\n(T5, BART)", fontsize=14)
    ax3.axis('off')
    
    plt.tight_layout()
    plt.suptitle("Three Main Transformer Architecture Types", fontsize=16)
    plt.subplots_adjust(top=0.85)
    plt.show()
    
    # Create table comparing the three architectures
    architectures = ["Encoder-Only", "Decoder-Only", "Encoder-Decoder"]
    features = [
        "Primary Use Cases", 
        "Context Processing", 
        "Attention Type", 
        "Input/Output", 
        "Training Objective", 
        "Example Models"
    ]
    
    data = [
        ["Understanding tasks (classification, NER, etc.)", "Text generation", "Translation, summarization, QA"],
        ["Bidirectional (sees full context)", "Unidirectional (only previous tokens)", "Full context in encoder, unidirectional in decoder"],
        ["Full self-attention", "Masked self-attention", "Both + cross-attention"],
        ["Text → contextual embeddings", "Prompt → continuation", "Input → different output"],
        ["Masked language modeling", "Next token prediction", "Denoising / span corruption"],
        ["BERT, RoBERTa, DeBERTa", "GPT, Llama, Falcon", "T5, BART, Pegasus"]
    ]
    
    # Create table
    fig, ax = plt.subplots(figsize=(12, 6))
    ax.axis('tight')
    ax.axis('off')
    table = ax.table(cellText=data, rowLabels=features, colLabels=architectures,
                    loc='center', cellLoc='center', colWidths=[0.3, 0.3, 0.3])
    
    # Style the table
    table.auto_set_font_size(False)
    table.set_fontsize(10)
    table.scale(1, 1.5)
    
    # Color the header row (row 0 holds the column labels)
    for j in range(len(architectures)):
        cell = table[(0, j)]
        cell.set_facecolor('#4472C4')
        cell.set_text_props(color='white')
    
    # Color the row labels (column -1 holds the row labels)
    for i in range(1, len(features) + 1):
        table[(i, -1)].set_facecolor('#8EA9DB')
    
    plt.title("Comparison of Transformer Architectures", fontsize=16)
    plt.tight_layout()
    plt.show()

In [None]:
# Visualize the complete transformer architecture
visualize_transformer_architecture()
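
The "Attention Type" row of the table comes down to which positions each token is allowed to see. The following is a minimal NumPy sketch of that difference, not the implementation used by any particular model: `attention_weights` and `toy_scores` are hypothetical names introduced here, and the scores are random numbers standing in for real query-key products.

```python
import numpy as np

def softmax(scores):
    """Row-wise softmax that stays stable when some entries are -inf."""
    shifted = scores - scores.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

def attention_weights(scores, causal=False):
    """Turn raw scores [seq_len, seq_len] into attention weights.

    With causal=True, position i may only attend to positions j <= i,
    which is the masking used by decoder-only models.
    """
    scores = np.array(scores, dtype=float)  # work on a copy
    if causal:
        seq_len = scores.shape[-1]
        future = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
        scores[future] = -np.inf  # masked positions get zero weight after softmax
    return softmax(scores)

rng = np.random.default_rng(0)
toy_scores = rng.normal(size=(4, 4))  # hypothetical scores for a 4-token input

bidirectional = attention_weights(toy_scores)        # encoder-style
causal = attention_weights(toy_scores, causal=True)  # decoder-style

print("Bidirectional:\n", np.round(bidirectional, 2))
print("Causal:\n", np.round(causal, 2))
```

In the causal output everything above the diagonal is zero, and the first token can only attend to itself. In an encoder-decoder model, the decoder's self-attention uses the causal form while the encoder uses the bidirectional one.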

## 8. Exercises: Hands-On Transformer Exploration

Now that we've explored the transformer architecture, let's consolidate our learning with some hands-on exercises:

### Exercise 1: Compare Attention Patterns Across Different Models

Choose a sentence and compare how different models (e.g., GPT-2, BERT, T5) attend to it. What differences do you notice?

### Exercise 2: Visualize Attention for Specific Language Phenomena

Try inputs with specific language phenomena (e.g., coreference resolution, subject-verb agreement) and examine how attention heads capture these relationships.

### Exercise 3: Modify and Observe Self-Attention

Implement a custom self-attention layer with modifications (e.g., different scaling, adding a bias) and observe how it changes the attention patterns.

### Exercise 4: Examine Model Scaling Properties

Look at how parameter counts scale across different sizes of the same model family (e.g., GPT-2 small vs. medium). Plot how the parameter count changes with hidden size and number of layers.

Let's start with Exercise 1:


In [None]:
# Exercise 1: Compare attention patterns across different models
def compare_model_attention(text):
    """
    Compare attention patterns across different model types.
    
    Args:
        text (str): Input text to analyze
    """
    models = ["gpt2", "bert-base-uncased", "t5-small"]
    model_types = ["Decoder-only (GPT-2)", "Encoder-only (BERT)", "Encoder-decoder (T5)"]
    
    # Create a figure for comparison
    fig, axes = plt.subplots(len(models), 1, figsize=(12, 6*len(models)))
    
    for i, (model_name, model_type) in enumerate(zip(models, model_types)):
        print(f"Processing {model_type}...")
        
        # Initialize tokenizer and model
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        
        # GPT-2 is loaded with its language-modeling head; BERT uses the bare encoder.
        # T5 is an encoder-decoder: its full forward pass also needs decoder inputs
        # and returns separate encoder/decoder/cross-attention fields, so we keep
        # only the encoder stack to get a comparable self-attention map.
        if "t5" in model_name:
            model = AutoModel.from_pretrained(model_name, output_attentions=True).encoder
        elif "gpt" in model_name:
            model = AutoModelForCausalLM.from_pretrained(model_name, output_attentions=True)
        else:
            model = AutoModel.from_pretrained(model_name, output_attentions=True)
        
        # Tokenize the input text
        inputs = tokenizer(text, return_tensors="pt")
        
        # Get tokens for labeling
        input_ids = inputs["input_ids"][0]
        tokens = [tokenizer.decode([token_id]) for token_id in input_ids]
        
        # Run through the model and get attention weights
        with torch.no_grad():
            outputs = model(**inputs)
        
        # Get attention weights from last layer, first head
        # (This will work for most models, but might need adjustment for some)
        attentions = outputs.attentions
        
        if attentions is not None:
            # Get the last layer's attention (usually most informative)
            last_layer_attention = attentions[-1]
            
            # Get attention from first head (for simplicity)
            attention_weights = last_layer_attention[0, 0].numpy()
            
            # Create a DataFrame for the heatmap
            df = pd.DataFrame(attention_weights, index=tokens, columns=tokens)
            
            # Plot the heatmap
            sns.heatmap(df, annot=False, cmap="YlGnBu", ax=axes[i])
            axes[i].set_title(f"{model_type} Attention Pattern (Last Layer, First Head)")
            axes[i].set_ylabel("Query Tokens")
            
            if i == len(models) - 1:
                axes[i].set_xlabel("Key Tokens")
            
            # Print model-specific notes
            print(f"  Number of layers: {len(attentions)}")
            print(f"  Number of attention heads: {last_layer_attention.shape[1]}")
            print(f"  Sequence length: {len(tokens)}")
            print(f"  Tokens: {tokens}")
            print()
        else:
            axes[i].text(0.5, 0.5, "Attention weights not available for this model", 
                        ha='center', va='center')
            axes[i].set_title(f"{model_type}")
    
    plt.tight_layout()
    plt.show()
    
    print("Key Observations:")
    print("1. GPT-2 (decoder) uses causal masking - each token only attends to previous tokens")
    print("2. BERT (encoder) has full bidirectional attention - each token can attend to all tokens")
    print("3. T5 (encoder-decoder) combines both: bidirectional attention in the encoder, causal self-attention plus cross-attention in the decoder")
    print("4. Different models tokenize text differently, affecting the attention patterns")
    print("5. Special tokens like [CLS], [SEP], <s>, </s> often receive or distribute a lot of attention")

In [None]:
# Run Exercise 1
compare_model_attention("The transformer architecture revolutionized natural language processing by enabling parallel computation and better handling of long-range dependencies.")
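
Heatmaps are easy to misread by eye, so it can help to reduce each attention matrix to a couple of numbers. The sketch below is one possible way to do that, assuming a 2D array of attention weights such as the per-head matrix plotted above. The helper names `future_attention_mass` and `mean_attention_entropy` are made up for this example, and the two matrices are synthetic stand-ins rather than real model output.

```python
import numpy as np

def future_attention_mass(weights):
    """Average attention each query places on later positions (j > i)."""
    weights = np.asarray(weights, dtype=float)
    return float(np.triu(weights, k=1).sum(axis=-1).mean())

def mean_attention_entropy(weights, eps=1e-12):
    """Average entropy (in nats) of each query's attention distribution."""
    weights = np.asarray(weights, dtype=float)
    row_entropy = -(weights * np.log(weights + eps)).sum(axis=-1)
    return float(row_entropy.mean())

# Synthetic examples: uniform attention with and without a causal mask
seq_len = 4
uniform_full = np.full((seq_len, seq_len), 1.0 / seq_len)
lower = np.tril(np.ones((seq_len, seq_len)))
uniform_causal = lower / lower.sum(axis=-1, keepdims=True)

for name, weights in [("bidirectional", uniform_full), ("causal", uniform_causal)]:
    print(f"{name:>13}: future mass = {future_attention_mass(weights):.3f}, "
          f"entropy = {mean_attention_entropy(weights):.3f}")
```

A causally masked head should report a future mass of zero, while a bidirectional head usually will not. Entropy gives a rough sense of how sharply a head focuses: values near zero mean each query attends to essentially one token.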


## 9. Additional Exercises

Continue with the remaining exercises on your own. Here are some starting points:

### Exercise 2: Visualize Attention for Specific Language Phenomena

In [None]:
# Try inputs with coreference resolution
explore_fn_coref, _, _ = visualize_self_attention(
    "John said he was tired. Mary thought he needed rest."
)

# Look at different layers and heads to see if any capture the coreference
explore_fn_coref(layer=5, head=2)

In [None]:
# Try inputs with subject-verb agreement
explore_fn_agreement, _, _ = visualize_self_attention(
    "The keys to the cabinet are on the table."
)

# Look at different layers and heads
explore_fn_agreement(layer=8, head=3)
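
Checking layers and heads one at a time gets tedious. As a rough sketch, the search can be automated by ranking every (layer, head) pair by how much attention one token pays to another. `rank_heads` is a hypothetical helper, and the data below is synthetic with one "planted" head. With a Hugging Face model, something like `[a[0].numpy() for a in outputs.attentions]` should give the expected input, and the token indices would depend on the tokenizer.

```python
import numpy as np

def rank_heads(layer_attentions, query_idx, key_idx, top_k=3):
    """Rank (layer, head) pairs by attention from one token to another.

    layer_attentions: sequence of arrays, one per layer, each shaped
        [num_heads, seq_len, seq_len] (batch dimension already removed).
    Returns a list of (layer, head, weight), strongest first.
    """
    ranked = []
    for layer, attn in enumerate(layer_attentions):
        attn = np.asarray(attn)
        for head in range(attn.shape[0]):
            ranked.append((layer, head, float(attn[head, query_idx, key_idx])))
    ranked.sort(key=lambda item: item[2], reverse=True)
    return ranked[:top_k]

# Synthetic stand-in for model output: 2 layers, 3 heads, 5 tokens,
# with uniform attention everywhere except one planted head.
num_layers, num_heads, seq_len = 2, 3, 5
fake_attentions = [np.full((num_heads, seq_len, seq_len), 1.0 / seq_len)
                   for _ in range(num_layers)]
fake_attentions[1][2, 4, 0] = 0.9     # layer 1, head 2: token 4 -> token 0
fake_attentions[1][2, 4, 1:] = 0.025  # keep that row summing to 1

top = rank_heads(fake_attentions, query_idx=4, key_idx=0)
for layer, head, weight in top:
    print(f"layer {layer}, head {head}: weight {weight:.3f}")
```

For the coreference sentence, `query_idx` would be the position of "he" and `key_idx` the position of "John". A high weight is suggestive rather than conclusive, since attention weights alone do not prove that a head resolves the reference.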

### Exercise 3: Modify and Observe Self-Attention

Let's modify our self-attention implementation with different scaling approaches:

In [None]:
# Modify self-attention with different scaling approaches
def modified_self_attention(input_embeddings, d_k=64, scaling_factor=None, add_bias=False):
    """
    Implement modified self-attention with optional scaling and bias.
    
    Args:
        input_embeddings: Input token embeddings [batch_size, seq_len, embedding_dim]
        d_k: Dimensionality of query and key vectors
        scaling_factor: Custom scaling factor (None for no scaling)
        add_bias: Whether to add a bias term to attention scores
    
    Returns:
        context_vectors: Attention-weighted outputs
        attention_weights: Attention weight matrix
    """
    # For simplicity, we'll use random weight matrices
    batch_size, seq_len, d_model = input_embeddings.shape
    
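    # Create random weight matrices for Q, K, V projections.
    # Dividing by sqrt(d_model) keeps the projected vectors at roughly unit
    # variance, so the effect of the score scaling is visible in the softmax.
    W_Q = torch.randn(d_model, d_k) / d_model ** 0.5
    W_K = torch.randn(d_model, d_k) / d_model ** 0.5
    W_V = torch.randn(d_model, d_model) / d_model ** 0.5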
    
    # Create Query, Key, Value projections
    Q = torch.matmul(input_embeddings, W_Q)
    K = torch.matmul(input_embeddings, W_K)
    V = torch.matmul(input_embeddings, W_V)
    
    # Calculate attention scores
    scores = torch.matmul(Q, K.transpose(-2, -1))
    
    # Apply custom scaling if provided
    if scaling_factor is not None:
        scores = scores / scaling_factor
    
    # Add bias if requested
    if add_bias:
        # Create a bias that encourages local attention by penalizing distant tokens.
        # The [seq_len, seq_len] bias broadcasts across every item in the batch.
        positions = torch.arange(seq_len)
        bias = -0.1 * (positions[:, None] - positions[None, :]).abs()
        scores = scores + bias
    
    # Apply softmax to get attention weights
    attention_weights = F.softmax(scores, dim=-1)
    
    # Apply attention weights to values
    context_vectors = torch.matmul(attention_weights, V)
    
    return context_vectors, attention_weights

# Function to compare different attention variants
def compare_attention_variants():
    # Create toy embeddings for demonstration
    seq_len = 5
    d_model = 64
    batch_size = 1
    
    # Same random embeddings for fair comparison
    torch.manual_seed(42)
    embeddings = torch.randn(batch_size, seq_len, d_model)
    
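    # Test different variants.
    # Standard scaling divides by sqrt(d_k) = 8 here, so dividing by 10 is only
    # slightly stronger. The local-bias variant keeps the standard scaling so
    # that the small distance penalty is not swamped by unscaled scores.
    variants = [
        ("Standard (scaled by √d_k)", lambda e: self_attention_from_scratch(e)),
        ("No scaling", lambda e: modified_self_attention(e, scaling_factor=None)),
        ("Aggressive scaling (÷10)", lambda e: modified_self_attention(e, scaling_factor=10)),
        ("With local bias", lambda e: modified_self_attention(e, scaling_factor=64 ** 0.5, add_bias=True))
    ]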
    
    # Create a figure for comparison
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    axes = axes.flatten()
    
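    # Run each variant and visualize
    for i, (name, func) in enumerate(variants):
        # Reseed so that variants drawing their random projection matrices in
        # the same order get the same weights; otherwise differences between
        # panels partly reflect different weights rather than the modification.
        torch.manual_seed(0)
        _, attention_weights = func(embeddings)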
        
        # Visualize the attention weights
        attention_matrix = attention_weights[0].detach().numpy()
        
        # Create a heatmap
        sns.heatmap(attention_matrix, annot=True, cmap="YlGnBu", fmt=".2f", ax=axes[i])
        axes[i].set_title(f"Self-Attention Variant: {name}")
        axes[i].set_xlabel("Key Positions")
        axes[i].set_ylabel("Query Positions")
    
    plt.tight_layout()
    plt.show()
    
    # Analyze the differences
    print("Observations about attention variants:")
    print("1. Standard scaling (÷√d_k) balances the attention distribution")
    print("2. No scaling leads to more extreme softmax values (near 0 or 1)")
    print("3. Aggressive scaling makes attention more uniform")
    print("4. Local bias encourages attention to nearby tokens")

In [None]:
# Compare different attention variants
compare_attention_variants()
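
The reason for dividing by √d_k can be checked numerically. If the entries of a query and a key are independent with zero mean and unit variance, their dot product has variance d_k, so its standard deviation grows as √d_k. The sketch below is a small simulation under exactly that assumption, which trained models only approximate. `score_std` is a helper name introduced for this example.

```python
import numpy as np

rng = np.random.default_rng(0)
num_pairs = 5_000  # number of random (query, key) pairs per estimate

def score_std(d_k, scale=1.0):
    """Empirical std of (q . k) / scale for unit-variance random q and k."""
    q = rng.normal(size=(num_pairs, d_k))
    k = rng.normal(size=(num_pairs, d_k))
    return float(((q * k).sum(axis=-1) / scale).std())

for d_k in (16, 64, 256):
    raw = score_std(d_k)
    scaled = score_std(d_k, scale=np.sqrt(d_k))
    print(f"d_k={d_k:>3}  raw std={raw:6.2f}  scaled std={scaled:5.2f}")
```

The raw standard deviation should track √d_k (about 4, 8, and 16), while the scaled version stays near 1 regardless of d_k. Keeping the scores in that range is what stops the softmax from collapsing onto a single token as the head dimension grows.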

### Exercise 4: Examine Model Scaling Properties

Let's examine how model parameters scale with different model sizes:

In [None]:
# Examine model scaling properties
def examine_model_scaling():
    """
    Compare parameter counts across different model sizes and visualize scaling relationships.
    """
    # List of models with increasing sizes
    models = [
        "distilgpt2",           # ~82M parameters
        "gpt2",                 # ~124M parameters
        "gpt2-medium",          # ~355M parameters
        "gpt2-large",           # ~774M parameters
        "gpt2-xl"               # ~1.5B parameters
    ]
    
    # Initialize lists to store data
    hidden_sizes = []
    layer_counts = []
    head_counts = []
    param_counts = []
    
    print("Analyzing model scaling properties...")
    
    # Collect data for each model
    for model_name in models:
        try:
            print(f"Loading configuration for {model_name}...")
            config = AutoConfig.from_pretrained(model_name)
            
            # Get model dimensions
            hidden_size = config.hidden_size
            n_layers = config.num_hidden_layers
            n_heads = config.num_attention_heads
            
            # Calculate parameter count (approximate formula for GPT-2 style models)
            vocab_size = config.vocab_size
            embed_params = vocab_size * hidden_size
            pos_embed_params = config.max_position_embeddings * hidden_size
            
            # Per-layer parameters
            layer_params = (
                # Self-attention
                4 * hidden_size * hidden_size +  # Q, K, V, and output projections
                # Feed-forward
                2 * hidden_size * (4 * hidden_size) +  # Expansion (h -> 4h) and projection (4h -> h)
                # Layer norms
                4 * hidden_size  # 2 layer norms with gain and bias
            )
            
            total_params = embed_params + pos_embed_params + (layer_params * n_layers)
            
            # Store the data
            hidden_sizes.append(hidden_size)
            layer_counts.append(n_layers)
            head_counts.append(n_heads)
            param_counts.append(total_params)
            
            print(f"  Hidden size: {hidden_size}")
            print(f"  Layers: {n_layers}")
            print(f"  Attention heads: {n_heads}")
            print(f"  Parameter count (approx): {total_params:,}")
            print()
            
        except Exception as e:
            print(f"Error processing {model_name}: {e}")
    
    # Convert to more readable units (millions of parameters)
    param_counts_m = [p / 1_000_000 for p in param_counts]
    
    # Plot the relationships
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    
    # 1. Hidden size vs. Parameters
    axes[0, 0].plot(hidden_sizes, param_counts_m, 'o-', linewidth=2, markersize=10)
    axes[0, 0].set_title("Hidden Size vs. Parameter Count")
    axes[0, 0].set_xlabel("Hidden Size")
    axes[0, 0].set_ylabel("Parameters (millions)")
    axes[0, 0].grid(True)
    
    # 2. Layer count vs. Parameters
    axes[0, 1].plot(layer_counts, param_counts_m, 'o-', linewidth=2, markersize=10)
    axes[0, 1].set_title("Layer Count vs. Parameter Count")
    axes[0, 1].set_xlabel("Number of Layers")
    axes[0, 1].set_ylabel("Parameters (millions)")
    axes[0, 1].grid(True)
    
    # 3. Head count vs. Parameters
    axes[1, 0].plot(head_counts, param_counts_m, 'o-', linewidth=2, markersize=10)
    axes[1, 0].set_title("Attention Head Count vs. Parameter Count")
    axes[1, 0].set_xlabel("Number of Attention Heads")
    axes[1, 0].set_ylabel("Parameters (millions)")
    axes[1, 0].grid(True)
    
    # 4. Scatter plot of hidden size vs layers with size indicating parameters
    scatter = axes[1, 1].scatter(hidden_sizes, layer_counts, s=[p/5_000_000 for p in param_counts], 
                              alpha=0.7, c=param_counts_m, cmap='viridis')
    axes[1, 1].set_title("Model Architecture Comparison")
    axes[1, 1].set_xlabel("Hidden Size")
    axes[1, 1].set_ylabel("Number of Layers")
    axes[1, 1].grid(True)
    
    # Add colorbar for parameter count
    cbar = fig.colorbar(scatter, ax=axes[1, 1])
    cbar.set_label('Parameters (millions)')
    
    # Add model names as annotations
    for i, model_name in enumerate(models):
        axes[1, 1].annotate(model_name, (hidden_sizes[i], layer_counts[i]),
                         xytext=(5, 5), textcoords='offset points')
    
    plt.tight_layout()
    plt.show()
    
    # Print scaling observations
    print("Observations about model scaling:")
    print("1. Per-layer parameter count scales quadratically with hidden size (d_model)")
    print("2. Parameter count scales linearly with the number of layers")
    print("3. In larger models, most parameters are in attention and feed-forward layers")
    print("4. Doubling the hidden size approximately quadruples the per-layer parameter count")
    print("5. GPT-2 keeps 64 dimensions per head, so head count grows with hidden size")

In [None]:
examine_model_scaling()
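
The scaling arithmetic can also be worked through without downloading any configuration. The sketch below assumes a standard GPT-2-style block (four h×h attention projections and a two-matrix feed-forward network with 4× expansion) and ignores bias terms and the final layer norm, so it is an approximation. `approx_gpt2_params` is a hypothetical helper, and the "wide" model is invented purely to illustrate the effect of doubling the width.

```python
def approx_gpt2_params(vocab_size, n_positions, hidden_size, n_layers):
    """Approximate GPT-2-style parameter count, ignoring bias terms."""
    embeddings = (vocab_size + n_positions) * hidden_size
    attention = 4 * hidden_size * hidden_size          # Q, K, V, output projections
    feed_forward = 2 * hidden_size * 4 * hidden_size   # h -> 4h and 4h -> h
    layer_norms = 4 * hidden_size                      # 2 norms x (gain + bias)
    per_layer = attention + feed_forward + layer_norms
    return embeddings + n_layers * per_layer

# Published GPT-2 small configuration: 50,257 tokens, 1,024 positions,
# hidden size 768, 12 layers.
vocab, positions, layers = 50257, 1024, 12
small = approx_gpt2_params(vocab, positions, 768, layers)

# Same settings with the width doubled (a hypothetical model, for illustration).
wide = approx_gpt2_params(vocab, positions, 1536, layers)

# Compare only the transformer blocks by removing the embedding tables.
small_blocks = small - (vocab + positions) * 768
wide_blocks = wide - (vocab + positions) * 1536

print(f"GPT-2 small (approx): {small / 1e6:.1f}M parameters")
print(f"Block parameters grow {wide_blocks / small_blocks:.2f}x when width doubles")
```

The estimate for GPT-2 small should land close to the commonly quoted ~124M. The block parameters grow by roughly a factor of four when the width doubles, while the embedding tables only double, which is why embeddings matter less and less as models get wider.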


## 10. Key Takeaways

After completing this notebook, you should understand:

1. **The core components of transformers**:
   - Self-attention mechanisms for capturing relationships between tokens
   - Positional encoding for sequence order awareness
   - Multi-head attention for capturing different types of relationships
   - Feed-forward networks for processing token representations

2. **Different transformer architectures**:
   - Encoder-only (BERT): Well suited to understanding tasks
   - Decoder-only (GPT): Well suited to generation tasks
   - Encoder-decoder (T5): Well suited to translation and other sequence-to-sequence tasks

3. **How attention works**:
   - Computes relationships between all pairs of tokens
   - Uses query, key, and value projections
   - Different heads can focus on different linguistic patterns

4. **Why transformers revolutionized NLP**:
   - Parallelizable across sequence positions during training (no step-by-step recurrence as in RNNs)
   - Capable of handling long-range dependencies
   - Highly scalable architecture

5. **Model scaling properties**:
   - Parameter count grows quadratically with model width and linearly with depth
   - Deeper models can capture more complex relationships

In the next notebook, we'll explore how transformers are trained, fine-tuned, and optimized for specific tasks.