# Attention Is All You Need - Interactive Implementation

**Paper**: [Attention Is All You Need](https://arxiv.org/abs/1706.03762) (Vaswani et al., 2017)

This notebook implements the Transformer architecture from scratch using a **progressive, experiment-first approach**. We'll build small, composable functions and gradually combine them into the full model.

## Key Innovation
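The paper introduces the **Transformer** architecture, which relies entirely on **attention mechanisms** (in particular self-attention) instead of recurrence or convolutions. This enables:
- Parallel processing of all positions in a sequence during training
- Easier learning of long-range dependencies (any two positions are connected by a constant number of operations)
- State-of-the-art translation quality at the time of publication (WMT 2014 English-German and English-French)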

In [None]:
# Setup and imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
import matplotlib.pyplot as plt

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

print(f"PyTorch version: {torch.__version__}")
print(f"Device: {'cuda' if torch.cuda.is_available() else 'cpu'}")

## 1. The Core: Scaled Dot-Product Attention

The fundamental building block is **scaled dot-product attention**. The idea is simple:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

Where:
- **Q** (Query): What we're looking for
- **K** (Key): What each position offers
- **V** (Value): The actual content to retrieve
- **d_k**: Dimension of keys (used for scaling)

### Why scaling?
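For large $d_k$, the dot products grow large in magnitude, pushing the softmax into regions with extremely small gradients. If the components of $q$ and $k$ are independent with mean 0 and variance 1, then $q \cdot k = \sum_{i=1}^{d_k} q_i k_i$ has mean 0 and variance $d_k$; dividing by $\sqrt{d_k}$ brings the variance back to 1.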
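A quick numerical check of this argument, under the same assumption of independent standard-normal components (the sample count and dimensions are arbitrary choices for illustration):

```python
import torch

torch.manual_seed(0)
n_samples = 10_000
variances = {}

for d_k in (4, 64, 512):
    q = torch.randn(n_samples, d_k)   # components ~ N(0, 1)
    k = torch.randn(n_samples, d_k)
    dots = (q * k).sum(dim=-1)        # one dot product q.k per row
    scaled = dots / d_k ** 0.5
    variances[d_k] = (dots.var().item(), scaled.var().item())
    print(f"d_k={d_k:4d}  var(q.k)={variances[d_k][0]:8.2f}  var(q.k/sqrt(d_k))={variances[d_k][1]:.3f}")

# Effect on the softmax: unscaled scores typically make it far more peaked (closer to one-hot)
d_k = 512
query, keys = torch.randn(1, d_k), torch.randn(10, d_k)
raw_scores = query @ keys.T
max_weight_raw = torch.softmax(raw_scores, dim=-1).max().item()
max_weight_scaled = torch.softmax(raw_scores / d_k ** 0.5, dim=-1).max().item()
print(f"largest weight, unscaled: {max_weight_raw:.3f}   scaled: {max_weight_scaled:.3f}")
```

The unscaled variance tracks $d_k$ while the scaled one stays near 1. A nearly one-hot softmax has gradients close to zero for all but one key, which is the failure mode the scaling avoids.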

In [None]:
# Let's experiment with the core attention mechanism
def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Compute scaled dot-product attention.
    
    Args:
        Q: Queries (..., seq_len_q, d_k)
        K: Keys (..., seq_len_k, d_k)
        V: Values (..., seq_len_k, d_v)
        mask: Optional mask broadcastable to (..., seq_len_q, seq_len_k);
              positions where mask == 0 are blocked
    
    Returns:
        output: Attention output (..., seq_len_q, d_v)
        attention_weights: Attention weights (..., seq_len_q, seq_len_k)
    """
    d_k = Q.size(-1)
    
    # Compute attention scores: Q @ K^T / sqrt(d_k)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    
    # Apply mask if provided (set masked positions to large negative value)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    
    # Apply softmax to get attention weights
    attention_weights = F.softmax(scores, dim=-1)
    
    # Apply attention weights to values
    output = torch.matmul(attention_weights, V)
    
    return output, attention_weights

In [None]:
# Experiment: Let's see attention in action with a simple example
batch_size = 1
seq_len = 4
d_k = 8

# Create random Q, K, V tensors
Q = torch.randn(batch_size, seq_len, d_k)
K = torch.randn(batch_size, seq_len, d_k)
V = torch.randn(batch_size, seq_len, d_k)

# Compute attention
output, attn_weights = scaled_dot_product_attention(Q, K, V)

print(f"Input shapes: Q{Q.shape}, K{K.shape}, V{V.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {attn_weights.shape}")
print("\nAttention weights (should sum to 1 per row):")
print(attn_weights[0].detach().numpy())
print(f"\nRow sums: {attn_weights[0].sum(dim=-1).detach().numpy()}")

In [None]:
# Visualize attention pattern
plt.figure(figsize=(6, 5))
plt.imshow(attn_weights[0].detach().numpy(), cmap='viridis')
plt.colorbar(label='Attention Weight')
plt.xlabel('Key Position')
plt.ylabel('Query Position')
plt.title('Attention Pattern')
plt.tight_layout()
plt.show()

## 2. Multi-Head Attention: Multiple Perspectives

Instead of performing a single attention operation, **multi-head attention** runs multiple attention operations in parallel (different "heads"), each with different learned projections:

1. Project Q, K, V into `h` different subspaces
2. Apply attention in each subspace independently
3. Concatenate results and project back

$$\text{MultiHead}(Q,K,V) = \text{Concat}(\text{head}_1, ..., \text{head}_h)W^O$$

where $\text{head}_i = \text{Attention}(QW^Q_i, KW^K_i, VW^V_i)$

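This allows the model to jointly attend to information from different representation subspaces at different positions. Each head works in dimension $d_k = d_v = d_{model}/h$ (64 in the paper, with $d_{model} = 512$ and $h = 8$), so the total cost is similar to single-head attention with full dimensionality.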
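In `MultiHeadAttention`, the $h$ separate projections $W^Q_i$ (and likewise $W^K_i$, $W^V_i$) are implemented as one `nn.Linear(d_model, d_model)` whose output is reshaped, so each head simply works on its own contiguous block of $d_k$ features. A minimal sketch, with toy sizes chosen for illustration, checking that the `split_heads`/`combine_heads` reshapes behave this way, plus the projection parameter count at the paper's size:

```python
import torch
import torch.nn as nn

batch_size, seq_len, d_model, num_heads = 2, 5, 16, 4
d_k = d_model // num_heads            # per-head dimension (64 in the paper)

x = torch.randn(batch_size, seq_len, d_model)

# Same reshape as split_heads: (B, T, d_model) -> (B, T, h, d_k) -> (B, h, T, d_k)
heads = x.view(batch_size, seq_len, num_heads, d_k).transpose(1, 2)

# Head 1 sees exactly feature slice [d_k, 2*d_k) of the projected input
head_1_is_slice = torch.equal(heads[:, 1], x[..., d_k:2 * d_k])

# Same reshape as combine_heads: it exactly undoes the split
merged = heads.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
roundtrip_ok = torch.equal(merged, x)

# Four d_model x d_model projections (W_q, W_k, W_v, W_o) with biases, at d_model = 512
mha_params = 4 * sum(p.numel() for p in nn.Linear(512, 512).parameters())
print(heads.shape, head_1_is_slice, roundtrip_ok, f"{mha_params:,}")
```

The parameter count, $4 \cdot (512 \cdot 512 + 512)$, is the figure the multi-head experiment cell prints for `mha`.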

In [None]:
# Build multi-head attention from our composable attention function
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        """
        Args:
            d_model: Dimension of model (must be divisible by num_heads)
            num_heads: Number of attention heads
        """
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        # Linear projections for Q, K, V and output
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
    def split_heads(self, x):
        """Split the last dimension into (num_heads, d_k)"""
        batch_size, seq_len, d_model = x.size()
        return x.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
    
    def combine_heads(self, x):
        """Inverse of split_heads"""
        batch_size, num_heads, seq_len, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
    
    def forward(self, Q, K, V, mask=None):
        """
        Args:
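            Q: (batch_size, seq_len_q, d_model)
            K, V: (batch_size, seq_len_k, d_model) (seq_len_k != seq_len_q in cross-attention)
            mask: broadcastable to (batch_size, 1, seq_len_q, seq_len_k),
                  e.g. (batch_size, 1, 1, seq_len_k) for padding or (1, 1, seq_len, seq_len) for causal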
        """
        # Linear projections
        Q = self.W_q(Q)
        K = self.W_k(K)
        V = self.W_v(V)
        
        # Split into multiple heads: (batch_size, num_heads, seq_len, d_k)
        Q = self.split_heads(Q)
        K = self.split_heads(K)
        V = self.split_heads(V)
        
        # Apply attention to each head
        attn_output, attn_weights = scaled_dot_product_attention(Q, K, V, mask)
        
        # Combine heads
        output = self.combine_heads(attn_output)
        
        # Final linear projection
        output = self.W_o(output)
        
        return output, attn_weights

In [None]:
# Experiment: Test multi-head attention
d_model = 512
num_heads = 8
seq_len = 10
batch_size = 2

mha = MultiHeadAttention(d_model, num_heads)

# Create random input
x = torch.randn(batch_size, seq_len, d_model)

# Self-attention: Q, K, V all come from same input
output, attn_weights = mha(x, x, x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {attn_weights.shape}")
print(f"Number of parameters: {sum(p.numel() for p in mha.parameters()):,}")

## 3. Positional Encoding: Giving the Model a Sense of Order

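Attention on its own has no notion of sequence order: self-attention is *permutation-equivariant*, meaning that shuffling the input positions just shuffles the outputs in the same way. We therefore need to inject **positional information**, which the paper adds to the input embeddings.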
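A quick self-contained check of this property, using a bare-bones self-attention with no learned projections (an assumption made for brevity; projections applied identically at every position do not change the conclusion):

```python
import math
import torch
import torch.nn.functional as F

def attention(Q, K, V):
    """Plain scaled dot-product attention (no mask, no learned projections)."""
    scores = Q @ K.transpose(-2, -1) / math.sqrt(Q.size(-1))
    return F.softmax(scores, dim=-1) @ V

torch.manual_seed(0)
x = torch.randn(1, 6, 8)          # one sequence: 6 positions, 8 features
perm = torch.randperm(6)          # a random reordering of the positions

out = attention(x, x, x)
x_perm = x[:, perm]
out_perm = attention(x_perm, x_perm, x_perm)

# Permuting the inputs permutes the outputs identically: order is invisible to attention
equivariant = torch.allclose(out[:, perm], out_perm, atol=1e-5)
print(equivariant)   # True
```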

The paper uses sinusoidal functions:

$$PE_{(pos, 2i)} = \sin(pos / 10000^{2i/d_{model}})$$
$$PE_{(pos, 2i+1)} = \cos(pos / 10000^{2i/d_{model}})$$

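The authors chose this form because:
- For any fixed offset $k$, $PE_{pos+k}$ is a linear function of $PE_{pos}$, which they hypothesized would let the model easily learn to attend by relative positions
- It may allow the model to extrapolate to sequence lengths longer than those seen during training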
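The relative-position property can be made concrete: for each frequency $\omega_i = 1/10000^{2i/d_{model}}$, shifting the pair $(\sin(\omega_i\, pos), \cos(\omega_i\, pos))$ by $k$ positions is a 2×2 rotation of the unshifted pair, so $PE_{pos+k} = M_k\, PE_{pos}$ for a block-diagonal matrix $M_k$ that depends only on $k$. A minimal numerical check, recomputing the encoding in float64 with a hypothetical `sinusoidal_pe` helper:

```python
import math
import torch

def sinusoidal_pe(seq_len, d_model):
    """Same formula as get_positional_encoding, in float64 for a tight numerical check."""
    position = torch.arange(seq_len, dtype=torch.float64).unsqueeze(1)
    omega = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float64) * (-math.log(10000.0) / d_model))
    pe = torch.zeros(seq_len, d_model, dtype=torch.float64)
    pe[:, 0::2] = torch.sin(position * omega)
    pe[:, 1::2] = torch.cos(position * omega)
    return pe, omega

d_model, k = 16, 3
pe, omega = sinusoidal_pe(50, d_model)

# Block-diagonal M_k: one 2x2 rotation per frequency, depending only on the offset k
M = torch.zeros(d_model, d_model, dtype=torch.float64)
for i, w in enumerate(omega.tolist()):
    c, s = math.cos(w * k), math.sin(w * k)
    M[2 * i:2 * i + 2, 2 * i:2 * i + 2] = torch.tensor([[c, s], [-s, c]], dtype=torch.float64)

# Rows are positions, so PE[pos + k] = M_k @ PE[pos] becomes pe[pos] @ M_k.T
shifted = pe[:-k] @ M.T
max_err = (shifted - pe[k:]).abs().max().item()
print(f"max |M_k PE[pos] - PE[pos+k]| = {max_err:.2e}")
```

Each block follows from the angle-addition identities $\sin(a+b) = \sin a \cos b + \cos a \sin b$ and $\cos(a+b) = \cos a \cos b - \sin a \sin b$.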

In [None]:
# Create positional encodings
def get_positional_encoding(seq_len, d_model):
    """
    Generate positional encoding matrix.
    
    Args:
        seq_len: Sequence length
        d_model: Model dimension
    
    Returns:
        pos_encoding: (seq_len, d_model)
    """
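    assert d_model % 2 == 0, "d_model must be even (dimensions come in sin/cos pairs)"
    position = torch.arange(seq_len).unsqueeze(1).float()  # (seq_len, 1)
    # div_term[i] = 1 / 10000^(2i / d_model), computed in log space
    div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))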
    
    pos_encoding = torch.zeros(seq_len, d_model)
    pos_encoding[:, 0::2] = torch.sin(position * div_term)
    pos_encoding[:, 1::2] = torch.cos(position * div_term)
    
    return pos_encoding

In [None]:
# Experiment: Visualize positional encodings
seq_len = 100
d_model = 512

pos_enc = get_positional_encoding(seq_len, d_model)

plt.figure(figsize=(12, 4))
plt.imshow(pos_enc.T.numpy(), aspect='auto', cmap='RdBu')
plt.colorbar(label='Encoding Value')
plt.xlabel('Position')
plt.ylabel('Dimension')
plt.title('Positional Encoding Pattern')
plt.tight_layout()
plt.show()

# Plot specific dimensions
plt.figure(figsize=(12, 4))
for i in [0, 1, 2, 3]:
    plt.plot(pos_enc[:, i].numpy(), label=f'Dimension {i}')
plt.xlabel('Position')
plt.ylabel('Encoding Value')
plt.title('Positional Encoding - First 4 Dimensions')
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

In [None]:
# Wrap in a module for easy reuse
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_len=5000):
        super().__init__()
        self.d_model = d_model
        
        # Precompute positional encodings
        pe = get_positional_encoding(max_seq_len, d_model)
        self.register_buffer('pe', pe.unsqueeze(0))  # (1, max_seq_len, d_model)
    
    def forward(self, x):
        """
        Args:
            x: (batch_size, seq_len, d_model)
        Returns:
            x with positional encoding added
        """
        seq_len = x.size(1)
        return x + self.pe[:, :seq_len, :]

## 4. Position-wise Feed-Forward Networks

After attention, each position passes through a **feed-forward network**, applied to every position separately and identically:

$$\text{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2$$

This is a two-layer network with a ReLU activation in between:
- The first layer expands the dimension (in the paper, from $d_{model} = 512$ to $d_{ff} = 2048$, i.e. 4x)
- The second layer projects back to $d_{model}$

The same weights are used at every position within a layer, but each layer has its own parameters.
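"Position-wise" means the FFN treats every position as an independent input; the paper notes it can equivalently be described as two convolutions with kernel size 1. A minimal sketch checking both views, with small sizes chosen for illustration:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
d_model, d_ff, seq_len = 8, 32, 5
w_1, w_2 = nn.Linear(d_model, d_ff), nn.Linear(d_ff, d_model)

def ffn(x):
    return w_2(F.relu(w_1(x)))   # FFN(x) = max(0, xW1 + b1)W2 + b2

x = torch.randn(1, seq_len, d_model)
whole = ffn(x)                                                    # all positions at once
per_pos = torch.stack([ffn(x[:, t]) for t in range(seq_len)], dim=1)  # one position at a time
same_per_position = torch.allclose(whole, per_pos, atol=1e-6)

# Copy the same weights into two kernel-size-1 convolutions (Conv1d wants (N, C, L))
conv1 = nn.Conv1d(d_model, d_ff, kernel_size=1)
conv2 = nn.Conv1d(d_ff, d_model, kernel_size=1)
with torch.no_grad():
    conv1.weight.copy_(w_1.weight.unsqueeze(-1))   # (d_ff, d_model) -> (d_ff, d_model, 1)
    conv1.bias.copy_(w_1.bias)
    conv2.weight.copy_(w_2.weight.unsqueeze(-1))
    conv2.bias.copy_(w_2.bias)
conv_out = conv2(F.relu(conv1(x.transpose(1, 2)))).transpose(1, 2)
same_as_conv = torch.allclose(whole, conv_out, atol=1e-5)
print(same_per_position, same_as_conv)
```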

In [None]:
# Simple feed-forward network
class PositionWiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        """
        Args:
            d_model: Model dimension
            d_ff: Hidden dimension (typically 4 * d_model)
            dropout: Dropout rate
        """
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
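        # Dropout on the hidden activations is an implementation choice; the paper only
        # specifies dropout on sub-layer outputs and on the embedding + positional-encoding sums
        self.dropout = nn.Dropout(dropout)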
    
    def forward(self, x):
        # (batch_size, seq_len, d_model) -> (batch_size, seq_len, d_ff) -> (batch_size, seq_len, d_model)
        return self.w_2(self.dropout(F.relu(self.w_1(x))))

In [None]:
# Experiment: Test feed-forward network
d_model = 512
d_ff = 2048
batch_size = 2
seq_len = 10

ffn = PositionWiseFeedForward(d_model, d_ff)
x = torch.randn(batch_size, seq_len, d_model)
output = ffn(x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Number of parameters: {sum(p.numel() for p in ffn.parameters()):,}")

## 5. Add & Norm: Residual Connections and Layer Normalization

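Each sub-layer (attention, feed-forward) is wrapped with a **residual connection** followed by **layer normalization**, which normalizes across the feature dimension. The paper uses this *post-norm* ordering, applying dropout to the sub-layer output before the addition:

$$\text{output} = \text{LayerNorm}(x + \text{Sublayer}(x))$$

The `ResidualConnection` below uses the *pre-norm* ordering instead, $x + \text{Dropout}(\text{Sublayer}(\text{LayerNorm}(x)))$, a common variant in later implementations that is often reported to train more stably. Because the residual stream is then never normalized after the last layer, `Encoder` and `Decoder` each apply one final `LayerNorm`.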

This helps with:
- Training stability
- Gradient flow
- Faster convergence
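The `ResidualConnection` cell below normalizes *before* the sublayer (pre-norm), while the paper's formula normalizes *after* the residual addition (post-norm). A minimal side-by-side sketch of the two orderings; the class names are hypothetical:

```python
import torch
import torch.nn as nn

class PostNormResidual(nn.Module):
    """Paper ordering: LayerNorm(x + Dropout(Sublayer(x)))."""
    def __init__(self, d_model, dropout=0.1):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        return self.norm(x + self.dropout(sublayer(x)))

class PreNormResidual(nn.Module):
    """Notebook ordering (ResidualConnection): x + Dropout(Sublayer(LayerNorm(x)))."""
    def __init__(self, d_model, dropout=0.1):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        return x + self.dropout(sublayer(self.norm(x)))

def per_position_stats(y):
    """Mean and (biased) standard deviation over the feature dimension."""
    mean = y.mean(dim=-1)
    std = ((y - mean.unsqueeze(-1)) ** 2).mean(dim=-1).sqrt()
    return mean, std

torch.manual_seed(0)
d_model = 16
x = torch.randn(2, 5, d_model) * 3 + 1      # deliberately un-normalized input
sublayer = nn.Linear(d_model, d_model)

with torch.no_grad():
    y_post = PostNormResidual(d_model).eval()(x, sublayer)   # eval() disables dropout
    y_pre = PreNormResidual(d_model).eval()(x, sublayer)

post_mean, post_std = per_position_stats(y_post)
pre_mean, pre_std = per_position_stats(y_pre)
print(f"post-norm: max |mean| = {post_mean.abs().max().item():.1e}, avg std = {post_std.mean().item():.3f}")
print(f"pre-norm:  avg std = {pre_std.mean().item():.3f} (residual stream left un-normalized)")
```

Swapping the post-norm line into `ResidualConnection.forward` would reproduce the paper's ordering; the final `LayerNorm` in `Encoder` and `Decoder` would then be redundant. With pre-norm, the output of the last layer is not normalized, which is why that final norm is there.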

In [None]:
# Helper function for residual connection with layer norm
class ResidualConnection(nn.Module):
    def __init__(self, d_model, dropout=0.1):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, sublayer):
        """Pre-norm residual: x + Dropout(Sublayer(LayerNorm(x))); sublayer must preserve the shape of x."""
        return x + self.dropout(sublayer(self.norm(x)))

## 6. Building the Encoder Layer

Now we compose our building blocks into an **Encoder Layer**:

1. Multi-head self-attention
2. Add & Norm
3. Feed-forward network
4. Add & Norm

The full encoder stacks N of these layers (paper uses N=6).

In [None]:
# Encoder layer: compose attention + FFN with residuals
class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff, dropout)
        self.residual_1 = ResidualConnection(d_model, dropout)
        self.residual_2 = ResidualConnection(d_model, dropout)
    
    def forward(self, x, mask=None):
        # Self-attention with residual
        x = self.residual_1(x, lambda x: self.self_attn(x, x, x, mask)[0])
        # Feed-forward with residual
        x = self.residual_2(x, self.feed_forward)
        return x

In [None]:
# Stack multiple encoder layers
class Encoder(nn.Module):
    def __init__(self, num_layers, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(d_model)
    
    def forward(self, x, mask=None):
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)

In [None]:
# Experiment: Test encoder
num_layers = 6
d_model = 512
num_heads = 8
d_ff = 2048
batch_size = 2
seq_len = 10

encoder = Encoder(num_layers, d_model, num_heads, d_ff)
x = torch.randn(batch_size, seq_len, d_model)
output = encoder(x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Number of parameters: {sum(p.numel() for p in encoder.parameters()):,}")

## 7. Building the Decoder Layer

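The **Decoder Layer** is similar but has three sub-layers:

1. **Masked self-attention**: prevents each position from attending to subsequent (future) positions
2. **Encoder-decoder attention** (cross-attention): queries come from the decoder, keys and values from the encoder output
3. **Feed-forward network**

Each sub-layer is wrapped in the same residual connection and layer normalization as in the encoder.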
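Because the causal mask zeroes the weights on later keys, the output at position $t$ cannot depend on tokens after $t$; this is what lets the decoder train on a whole target sequence in parallel with teacher forcing. A minimal, self-contained check that re-implements the masked attention step in a few lines rather than reusing the notebook's function:

```python
import math
import torch
import torch.nn.functional as F

def masked_attention(Q, K, V, mask):
    scores = Q @ K.transpose(-2, -1) / math.sqrt(Q.size(-1))
    scores = scores.masked_fill(mask == 0, -1e9)   # same convention: 0 = blocked
    return F.softmax(scores, dim=-1) @ V

torch.manual_seed(0)
T, d = 6, 8
mask = torch.tril(torch.ones(T, T))      # same pattern as create_causal_mask(T)[0, 0]
x = torch.randn(1, T, d)
out = masked_attention(x, x, x, mask)

# Replace the last three positions with new random values
x_changed = x.clone()
x_changed[:, 3:] = torch.randn(1, T - 3, d)
out_changed = masked_attention(x_changed, x_changed, x_changed, mask)

prefix_unchanged = torch.allclose(out[:, :3], out_changed[:, :3])      # positions 0..2
suffix_changed = not torch.allclose(out[:, 3:], out_changed[:, 3:])    # positions 3..5
print(prefix_unchanged, suffix_changed)
```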

In [None]:
# Helper to create causal mask (for autoregressive decoding)
def create_causal_mask(seq_len):
    """
    Create a mask that prevents attending to future positions.
    Returns a lower-triangular matrix of ones (1 = allowed, 0 = blocked).
    """
    mask = torch.tril(torch.ones(seq_len, seq_len))
    return mask.unsqueeze(0).unsqueeze(0)  # (1, 1, seq_len, seq_len)

In [None]:
# Visualize causal mask
mask = create_causal_mask(10)
plt.figure(figsize=(6, 5))
plt.imshow(mask[0, 0].numpy(), cmap='Blues')
plt.colorbar(label='Mask Value')
plt.xlabel('Key Position')
plt.ylabel('Query Position')
plt.title('Causal Mask (1=allowed, 0=blocked)')
plt.tight_layout()
plt.show()

In [None]:
# Decoder layer: masked self-attention + cross-attention + FFN
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff, dropout)
        self.residual_1 = ResidualConnection(d_model, dropout)
        self.residual_2 = ResidualConnection(d_model, dropout)
        self.residual_3 = ResidualConnection(d_model, dropout)
    
    def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
        # Masked self-attention
        x = self.residual_1(x, lambda x: self.self_attn(x, x, x, tgt_mask)[0])
        # Cross-attention (Q from decoder, K,V from encoder)
        x = self.residual_2(x, lambda x: self.cross_attn(x, encoder_output, encoder_output, src_mask)[0])
        # Feed-forward
        x = self.residual_3(x, self.feed_forward)
        return x

In [None]:
# Stack multiple decoder layers
class Decoder(nn.Module):
    def __init__(self, num_layers, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList([
            DecoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(d_model)
    
    def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
        for layer in self.layers:
            x = layer(x, encoder_output, src_mask, tgt_mask)
        return self.norm(x)

In [None]:
# Experiment: Test decoder
num_layers = 6
d_model = 512
num_heads = 8
d_ff = 2048
batch_size = 2
src_seq_len = 10
tgt_seq_len = 8

decoder = Decoder(num_layers, d_model, num_heads, d_ff)

# Decoder input and encoder output
tgt = torch.randn(batch_size, tgt_seq_len, d_model)
encoder_output = torch.randn(batch_size, src_seq_len, d_model)

# Create causal mask for decoder
tgt_mask = create_causal_mask(tgt_seq_len)

output = decoder(tgt, encoder_output, tgt_mask=tgt_mask)

print(f"Decoder input shape: {tgt.shape}")
print(f"Encoder output shape: {encoder_output.shape}")
print(f"Decoder output shape: {output.shape}")
print(f"Number of parameters: {sum(p.numel() for p in decoder.parameters()):,}")

## 8. The Complete Transformer

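Now we assemble everything into the full **Transformer** model:

1. **Input Embedding** (scaled by $\sqrt{d_{model}}$) + Positional Encoding
2. **Encoder** (stack of N=6 layers)
3. **Decoder** (stack of N=6 layers)
4. **Output Linear** projection to vocabulary logits (the softmax is folded into `nn.CrossEntropyLoss` during training, and greedy decoding only needs the `argmax` of the logits)

This follows the paper's architecture, with two notable differences: the residual blocks use the pre-norm ordering of `ResidualConnection`, and the embeddings and output projection have separate weights, whereas the paper shares one weight matrix between the two embedding layers and the pre-softmax linear transformation.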
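The paper (Section 3.4) shares one weight matrix between the embedding layers and the pre-softmax linear transformation. A minimal sketch of tying a target embedding to an output projection, using standalone toy modules whose sizes are illustrative assumptions:

```python
import math
import torch
import torch.nn as nn

vocab_size, d_model = 100, 32
embedding = nn.Embedding(vocab_size, d_model)                   # weight: (vocab_size, d_model)
output_projection = nn.Linear(d_model, vocab_size, bias=False)  # weight: (vocab_size, d_model)

# Tie: both modules now hold the very same Parameter object
output_projection.weight = embedding.weight

tokens = torch.tensor([[5, 7, 9]])
hidden = embedding(tokens) * math.sqrt(d_model)   # scaled embeddings, as in encode/decode
logits = output_projection(hidden)
print(logits.shape)   # torch.Size([1, 3, 100])

# Only one distinct weight tensor is stored
distinct_params = {id(p) for p in list(embedding.parameters()) + list(output_projection.parameters())}
```

In the `Transformer` class this would be the hypothetical one-line change `self.output_projection.weight = self.tgt_embedding.weight` at the end of `__init__` (both weights have shape `(tgt_vocab_size, d_model)`). Sharing the source embedding as well only makes sense when source and target use one common vocabulary.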

In [None]:
# Complete Transformer model
class Transformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512, num_heads=8,
                 num_layers=6, d_ff=2048, max_seq_len=5000, dropout=0.1):
        """
        Args:
            src_vocab_size: Source vocabulary size
            tgt_vocab_size: Target vocabulary size
            d_model: Model dimension (default 512)
            num_heads: Number of attention heads (default 8)
            num_layers: Number of encoder/decoder layers (default 6)
            d_ff: Feed-forward dimension (default 2048)
            max_seq_len: Maximum sequence length (default 5000)
            dropout: Dropout rate (default 0.1)
        """
        super().__init__()
        
        # Embeddings
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        
        # Positional encoding
        self.pos_encoding = PositionalEncoding(d_model, max_seq_len)
        
        # Encoder and Decoder
        self.encoder = Encoder(num_layers, d_model, num_heads, d_ff, dropout)
        self.decoder = Decoder(num_layers, d_model, num_heads, d_ff, dropout)
        
        # Output projection
        self.output_projection = nn.Linear(d_model, tgt_vocab_size)
        
        # Dropout
        self.dropout = nn.Dropout(dropout)
        
        # Stored so encode/decode can scale embeddings by sqrt(d_model), as in the paper
        self.d_model = d_model
        
        # Initialize parameters
        self._init_parameters()
    
    def _init_parameters(self):
        """Initialize parameters with Xavier uniform."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
    
    def encode(self, src, src_mask=None):
        """Encode source sequence."""
        # Embed and scale
        src = self.src_embedding(src) * math.sqrt(self.d_model)
        # Add positional encoding
        src = self.pos_encoding(src)
        src = self.dropout(src)
        # Encode
        return self.encoder(src, src_mask)
    
    def decode(self, tgt, encoder_output, src_mask=None, tgt_mask=None):
        """Decode target sequence."""
        # Embed and scale
        tgt = self.tgt_embedding(tgt) * math.sqrt(self.d_model)
        # Add positional encoding
        tgt = self.pos_encoding(tgt)
        tgt = self.dropout(tgt)
        # Decode
        return self.decoder(tgt, encoder_output, src_mask, tgt_mask)
    
    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        """
        Args:
            src: Source sequences (batch_size, src_seq_len)
            tgt: Target sequences (batch_size, tgt_seq_len)
            src_mask: Source mask (batch_size, 1, 1, src_seq_len)
            tgt_mask: Target mask (batch_size, 1, tgt_seq_len, tgt_seq_len)
        
        Returns:
            logits: (batch_size, tgt_seq_len, tgt_vocab_size)
        """
        # Encode
        encoder_output = self.encode(src, src_mask)
        # Decode
        decoder_output = self.decode(tgt, encoder_output, src_mask, tgt_mask)
        # Project to vocabulary
        logits = self.output_projection(decoder_output)
        return logits
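The `forward` docstring expects a `src_mask` of shape `(batch_size, 1, 1, src_seq_len)`, but nothing above builds one; the copy task below uses fixed-length sequences and can skip it. For padded batches, a minimal sketch of a hypothetical `create_padding_mask` helper, assuming token id 0 is padding (the id the copy dataset below reserves for it):

```python
import torch

def create_padding_mask(seq, pad_idx=0):
    """Token ids (batch_size, seq_len) -> mask (batch_size, 1, 1, seq_len); 1 = real token, 0 = padding."""
    return (seq != pad_idx).long().unsqueeze(1).unsqueeze(2)

src = torch.tensor([[5, 8, 3, 0, 0],
                    [4, 9, 6, 7, 0]])
src_mask = create_padding_mask(src)
print(src_mask.shape)       # torch.Size([2, 1, 1, 5])
print(src_mask[0, 0, 0])    # tensor([1, 1, 1, 0, 0])

# Decoder self-attention needs both constraints: no padded keys AND no future positions
tgt = torch.tensor([[1, 7, 4, 0],
                    [1, 5, 2, 9]])
causal = torch.tril(torch.ones(tgt.size(1), tgt.size(1), dtype=torch.long))
tgt_mask = create_padding_mask(tgt) * causal   # broadcasts to (batch_size, 1, tgt_len, tgt_len)
print(tgt_mask.shape)       # torch.Size([2, 1, 4, 4])
```

These masks would be passed as `model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)`. Padded target positions should also be excluded from the loss, for example with `nn.CrossEntropyLoss(ignore_index=0)`.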

In [None]:
# Experiment: Create a Transformer and test it
src_vocab_size = 10000
tgt_vocab_size = 10000
d_model = 512
num_heads = 8
num_layers = 6
d_ff = 2048

model = Transformer(src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff)

# Create dummy data
batch_size = 2
src_seq_len = 10
tgt_seq_len = 8

src = torch.randint(0, src_vocab_size, (batch_size, src_seq_len))
tgt = torch.randint(0, tgt_vocab_size, (batch_size, tgt_seq_len))
tgt_mask = create_causal_mask(tgt_seq_len)

# Forward pass
logits = model(src, tgt, tgt_mask=tgt_mask)

print(f"Source shape: {src.shape}")
print(f"Target shape: {tgt.shape}")
print(f"Output logits shape: {logits.shape}")
print(f"\nTotal parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

## 9. Training Example: Simple Copy Task

Let's train the Transformer on a simple task: **copy the input sequence**.

This demonstrates:
- How to prepare data (teacher forcing with a right-shifted target)
- A basic training loop
- Greedy autoregressive inference

For real translation, you'd use parallel text corpora and proper tokenization.

In [None]:
# Create a simple copy dataset
class CopyDataset:
    def __init__(self, vocab_size, seq_len, num_samples):
        self.vocab_size = vocab_size
        self.seq_len = seq_len
        self.num_samples = num_samples
    
    def __len__(self):
        return self.num_samples
    
    def __getitem__(self, idx):
        # Draw a fresh random sequence on every call (idx is ignored), so each epoch sees new data.
        # Token ids 0 (padding), 1 (SOS) and 2 (EOS) are reserved and never sampled.
        seq = torch.randint(3, self.vocab_size, (self.seq_len,))
        # Target is same as source, but shifted right with SOS token
        src = seq
        tgt_input = torch.cat([torch.tensor([1]), seq[:-1]])  # Start with SOS token
        tgt_output = seq  # Target output is the original sequence
        return src, tgt_input, tgt_output

In [None]:
# Training setup
from torch.utils.data import DataLoader

# Hyperparameters for small model (faster training)
vocab_size = 100
seq_len = 10
d_model = 128
num_heads = 4
num_layers = 2
d_ff = 512
batch_size = 32
num_epochs = 20
learning_rate = 0.0001

# Create dataset and dataloader
train_dataset = CopyDataset(vocab_size, seq_len, num_samples=1000)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Create model
model = Transformer(vocab_size, vocab_size, d_model, num_heads, num_layers, d_ff)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

print(f"Training on {device}")
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
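The cell above trains with a constant learning rate of 1e-4. The paper instead uses Adam with $\beta_1 = 0.9$, $\beta_2 = 0.98$, $\epsilon = 10^{-9}$, the schedule $lrate = d_{model}^{-0.5} \cdot \min(step^{-0.5},\ step \cdot warmup\_steps^{-1.5})$ with $warmup\_steps = 4000$, and label smoothing with $\epsilon_{ls} = 0.1$. A minimal sketch of that setup using `LambdaLR` and the `label_smoothing` argument of `nn.CrossEntropyLoss` (PyTorch 1.10+). The `warmup_steps=400` value is an assumption scaled down for this copy task, which runs only about 640 optimizer steps (20 epochs × 32 batches), and the `toy_*` names are hypothetical stand-ins so the cell does not overwrite `model`, `optimizer` or `criterion`:

```python
import torch
import torch.nn as nn

def noam_lambda(d_model, warmup_steps=4000):
    """LR multiplier from the paper: d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)."""
    def lr_at(step):
        step = max(step, 1)   # LambdaLR evaluates step 0 at construction; avoid 0 ** -0.5
        return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)
    return lr_at

toy_model = nn.Linear(128, 100)
# base lr = 1.0, so the lambda's value is the actual learning rate
toy_optimizer = torch.optim.Adam(toy_model.parameters(), lr=1.0, betas=(0.9, 0.98), eps=1e-9)
scheduler = torch.optim.lr_scheduler.LambdaLR(toy_optimizer, lr_lambda=noam_lambda(128, warmup_steps=400))
smoothed_criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

lrs = []
for step in range(1000):
    toy_optimizer.step()   # in real training: zero_grad(), loss.backward(), then step()
    scheduler.step()
    lrs.append(toy_optimizer.param_groups[0]["lr"])

print(f"peak lr {max(lrs):.5f} at step {lrs.index(max(lrs)) + 1}")
```

To use it for real, build the optimizer on `model.parameters()`, use `smoothed_criterion` in place of `criterion`, and call `scheduler.step()` after every `optimizer.step()`. The peak learning rate depends on `d_model` and `warmup_steps`, so it may need tuning for a model this small.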

In [None]:
# Training loop
losses = []

model.train()
for epoch in range(num_epochs):
    epoch_loss = 0
    for src, tgt_input, tgt_output in train_loader:
        src = src.to(device)
        tgt_input = tgt_input.to(device)
        tgt_output = tgt_output.to(device)
        
        # Create target mask
        tgt_mask = create_causal_mask(tgt_input.size(1)).to(device)
        
        # Forward pass
        logits = model(src, tgt_input, tgt_mask=tgt_mask)
        
        # Compute loss
        loss = criterion(logits.reshape(-1, vocab_size), tgt_output.reshape(-1))
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
    
    avg_loss = epoch_loss / len(train_loader)
    losses.append(avg_loss)
    
    if (epoch + 1) % 5 == 0:
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")

print("Training complete!")

In [None]:
# Plot training loss
plt.figure(figsize=(10, 4))
plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss Over Time')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

In [None]:
# Inference: Generate output autoregressively
def generate(model, src, max_len=20, sos_token=1):
    """
    Generate output sequence autoregressively (greedy decoding).
    
    Args:
        model: Trained Transformer model
        src: Source sequence (batch_size, src_len)
        max_len: Number of tokens to generate (not counting SOS)
        sos_token: Start-of-sequence token ID
    
    Returns:
        Generated sequence (batch_size, max_len)
    """
    model.eval()
    batch_size = src.size(0)
    device = src.device
    
    with torch.no_grad():
        # Encode the source once; it is reused at every decoding step
        encoder_output = model.encode(src)
        
        # Start with SOS token
        generated = torch.full((batch_size, 1), sos_token, dtype=torch.long, device=device)
        
        for _ in range(max_len):
            # Causal mask for the tokens generated so far
            tgt_mask = create_causal_mask(generated.size(1)).to(device)
            
            # Decode the whole prefix (no key/value caching)
            decoder_output = model.decode(generated, encoder_output, tgt_mask=tgt_mask)
            
            # Project the last position to vocabulary and pick the most likely token
            logits = model.output_projection(decoder_output[:, -1, :])
            next_token = logits.argmax(dim=-1, keepdim=True)
            
            # Append to generated sequence
            generated = torch.cat([generated, next_token], dim=1)
    
    return generated[:, 1:]  # Remove SOS token

In [None]:
# Test the model on a few examples
model.eval()
num_test = 5

print("Testing copy task:\n")
for i in range(num_test):
    # Generate test sequence
    src = torch.randint(3, vocab_size, (1, seq_len)).to(device)
    
    # Generate output
    output = generate(model, src, max_len=seq_len, sos_token=1)
    
    # Compare
    src_seq = src[0].cpu().numpy()
    out_seq = output[0].cpu().numpy()
    
    match = np.array_equal(src_seq, out_seq)
    
    print(f"Example {i+1}:")
    print(f"  Input:  {src_seq}")
    print(f"  Output: {out_seq}")
    print(f"  Match:  {'✓' if match else '✗'}")
    print()

## Summary and Key Takeaways

We've built the Transformer architecture from scratch using a **progressive, composable approach**:

### Core Components:
1. **Scaled Dot-Product Attention**: The fundamental mechanism
2. **Multi-Head Attention**: Multiple parallel attention operations
3. **Positional Encoding**: Injecting sequence order information
4. **Feed-Forward Networks**: Position-wise transformations
5. **Residual Connections & Layer Norm**: Training stability

### Architecture:
- **Encoder**: Processes input sequence
- **Decoder**: Generates output autoregressively
- **Full Transformer**: Combines encoder and decoder

### Key Innovations:
- **No recurrence**: Enables parallel processing
- **Self-attention**: Captures long-range dependencies
- **Multi-head**: Different representation subspaces
- **Positional encoding**: Order without recurrence

### What's Next?
This architecture forms the basis for:
- **BERT** (encoder-only): Bidirectional understanding
- **GPT** (decoder-only): Autoregressive generation
- **T5, BART** (encoder-decoder): Sequence-to-sequence tasks
- **Vision Transformers**: Applying to images

Most modern LLMs build on these foundational concepts!

## Exercises to Explore Further

1. **Experiment with hyperparameters**: Try different numbers of heads, layers, dimensions
2. **Visualize attention patterns**: Plot attention weights from different layers
3. **Try different tasks**: Reverse sequence, sorting, arithmetic
4. **Add beam search**: Improve generation quality
5. **Implement label smoothing**: Better training regularization
6. **Add learning rate scheduling**: Warmup + decay as in paper
7. **Test on real data**: Apply to translation or text generation

The modular design makes it easy to swap components and experiment!