# 🔄 Vanilla Transformer: Complete Implementation & Training

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/gaurav-redhat/transformer_problems/blob/main/transformer_architectures/01_vanilla_transformer/demo.ipynb)

## "Attention Is All You Need" (Vaswani et al., 2017)

![Architecture](architecture.png)

### What You'll Learn
- Multi-Head Self-Attention from scratch
- Positional Encoding
- Encoder-Decoder architecture
- Training on a toy machine translation task

In [None]:
!pip install torch matplotlib numpy -q
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import matplotlib.pyplot as plt
import numpy as np

# Seed every RNG this notebook draws from so that runs are reproducible:
# torch drives weight initialisation and dropout, while NumPy drives the
# synthetic dataset generation and the random batch sampling.
torch.manual_seed(42)
np.random.seed(42)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using device: {device}')

## Part 1: Positional Encoding

Since transformers have no recurrence, we need to inject position information:

$$PE_{(pos, 2i)} = \sin(pos / 10000^{2i/d_{model}})$$
$$PE_{(pos, 2i+1)} = \cos(pos / 10000^{2i/d_{model}})$$

In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to a batch of embeddings.

    A table of shape ``(1, max_len, d_model)`` is precomputed once and stored
    as the non-trainable buffer ``pe``. Even feature columns hold sines, odd
    feature columns hold cosines, each pair sharing one frequency.

    Args:
        d_model: Embedding width. Both even and odd widths are supported.
        max_len: Number of positions to precompute; inputs passed to
            ``forward`` must not be longer than this.
        dropout: Dropout probability applied after the encoding is added.
    """

    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        # Create positional encoding matrix
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()

        # Inverse frequencies 1 / 10000^(2i / d_model), computed in log space.
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe[:, 0::2] = torch.sin(angles)  # Even indices
        # With an odd d_model there is one fewer odd column than even columns,
        # so drop the last frequency instead of failing on a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])  # Odd indices

        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        # A buffer follows .to(device) and state_dict() but is never trained.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``seq_len`` positions.

        Args:
            x: Tensor of shape ``(batch, seq_len, d_model)``.

        Returns:
            Tensor of the same shape as ``x``, after dropout.
        """
        # x: (batch, seq_len, d_model); the table broadcasts over the batch.
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)

# Visualize positional encoding
# Heat-map of the precomputed table: the first 50 positions run along the
# x-axis and the first 64 feature dimensions along the y-axis.
pe = PositionalEncoding(128, 100, dropout=0)
pe_values = pe.pe[0, :50, :64].numpy()

fig, ax = plt.subplots(figsize=(12, 4))
image = ax.imshow(pe_values.T, cmap='RdBu', aspect='auto')
ax.set_xlabel('Position')
ax.set_ylabel('Dimension')
ax.set_title('Positional Encoding Visualization')
fig.colorbar(image, ax=ax)
plt.show()

## Part 2: Multi-Head Attention

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

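Multi-head attention runs several attention operations in parallel, letting the model attend to information from different representation subspaces at different positions: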

In [None]:
class MultiHeadAttention(nn.Module):
    """Scaled dot-product attention computed in parallel over ``n_heads`` heads.

    The model width ``d_model`` is split evenly across heads, so each head
    works on ``d_k = d_model // n_heads`` features.

    Args:
        d_model: Width of the query/key/value inputs and of the output.
        n_heads: Number of heads; must divide ``d_model``.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        assert d_model % n_heads == 0

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        # Input projections for queries, keys and values, then the output
        # projection that mixes the concatenated heads.
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(self.d_k)

    def _split_heads(self, projected, batch_size):
        """Reshape (batch, seq, d_model) into (batch, n_heads, seq, d_k)."""
        per_head = projected.view(batch_size, -1, self.n_heads, self.d_k)
        return per_head.transpose(1, 2)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` positions to ``key``/``value`` positions.

        Args:
            query: Tensor of shape ``(batch, seq_q, d_model)``.
            key: Tensor of shape ``(batch, seq_k, d_model)``.
            value: Tensor of shape ``(batch, seq_k, d_model)``.
            mask: Optional tensor broadcastable to
                ``(batch, n_heads, seq_q, seq_k)``. Entries equal to 0 are
                excluded from attention. A query row whose entries are all
                masked produces NaN weights, because the softmax then sees
                only ``-inf``.

        Returns:
            ``(output, attn_weights)`` with shapes ``(batch, seq_q, d_model)``
            and ``(batch, n_heads, seq_q, seq_k)``. The returned weights are
            the ones after dropout.
        """
        batch_size = query.size(0)

        q_heads = self._split_heads(self.W_q(query), batch_size)
        k_heads = self._split_heads(self.W_k(key), batch_size)
        v_heads = self._split_heads(self.W_v(value), batch_size)

        # Similarity of every query with every key, scaled by sqrt(d_k).
        scores = torch.matmul(q_heads, k_heads.transpose(-2, -1)) / self.scale
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn_weights = self.dropout(F.softmax(scores, dim=-1))

        # Weighted sum of values, then merge the heads back into d_model.
        per_head_context = torch.matmul(attn_weights, v_heads)
        merged = per_head_context.transpose(1, 2).contiguous()
        merged = merged.view(batch_size, -1, self.d_model)

        return self.W_o(merged), attn_weights

# Test attention
# Smoke test: the same random tensor is used as query, key and value
# (self-attention) and the resulting shapes are printed. No mask is passed.
mha = MultiHeadAttention(d_model=64, n_heads=4)
x = torch.randn(2, 10, 64)  # (batch, seq_len, d_model)
out, attn = mha(x, x, x)
print(f'Input: {x.shape}')
print(f'Output: {out.shape}')  # same shape as the input: (2, 10, 64)
print(f'Attention weights: {attn.shape}')  # (batch, n_heads, seq_q, seq_k) = (2, 4, 10, 10)

## Part 3: Feed-Forward Network, Encoder Layer & Decoder Layer

In [None]:
class FeedForward(nn.Module):
    """Two-layer MLP applied to the last dimension of its input.

    Computes ``linear2(dropout(relu(linear1(x))))``, expanding from
    ``d_model`` to ``d_ff`` features and projecting back to ``d_model``.

    Args:
        d_model: Input and output width.
        d_ff: Width of the hidden layer.
        dropout: Dropout probability applied to the hidden activations.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return the transformed tensor, same shape as ``x``."""
        hidden = F.relu(self.linear1(x))
        hidden = self.dropout(hidden)
        return self.linear2(hidden)

class EncoderLayer(nn.Module):
    """One encoder block: self-attention then feed-forward, post-norm style.

    Each sub-layer output goes through dropout, is added to the sub-layer
    input (residual connection) and is then layer-normalised.

    Args:
        d_model: Model width.
        n_heads: Number of attention heads.
        d_ff: Hidden width of the feed-forward network.
        dropout: Dropout probability used throughout the block.
    """

    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def _add_and_norm(self, norm, residual, sublayer_out):
        """Apply dropout to the sub-layer output, add the residual, normalise."""
        return norm(residual + self.dropout(sublayer_out))

    def forward(self, x, mask=None):
        """Run the block on ``x``; ``mask`` is forwarded to self-attention."""
        attended, _ = self.self_attn(x, x, x, mask)
        x = self._add_and_norm(self.norm1, x, attended)
        return self._add_and_norm(self.norm2, x, self.ffn(x))

class DecoderLayer(nn.Module):
    """One decoder block: self-attention, cross-attention, feed-forward.

    Every sub-layer output goes through dropout, is added to the sub-layer
    input (residual connection) and is then layer-normalised.

    Args:
        d_model: Model width.
        n_heads: Number of attention heads.
        d_ff: Hidden width of the feed-forward network.
        dropout: Dropout probability used throughout the block.
    """

    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.cross_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def _add_and_norm(self, norm, residual, sublayer_out):
        """Apply dropout to the sub-layer output, add the residual, normalise."""
        return norm(residual + self.dropout(sublayer_out))

    def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
        """Run the block on decoder states ``x``.

        Args:
            x: Decoder-side input tensor.
            enc_output: Encoder output, used as keys and values in
                cross-attention.
            src_mask: Optional mask forwarded to cross-attention.
            tgt_mask: Optional mask forwarded to self-attention.

        Returns:
            Tensor produced by the final add-and-norm step.
        """
        # Self-attention over the decoder input, restricted by tgt_mask.
        self_ctx, _ = self.self_attn(x, x, x, tgt_mask)
        x = self._add_and_norm(self.norm1, x, self_ctx)

        # Queries come from the decoder, keys/values from the encoder.
        cross_ctx, _ = self.cross_attn(x, enc_output, enc_output, src_mask)
        x = self._add_and_norm(self.norm2, x, cross_ctx)

        return self._add_and_norm(self.norm3, x, self.ffn(x))

# Progress marker: reached only if the class definitions in this cell ran without error.
print('Encoder and Decoder layers defined!')

## Part 4: Complete Transformer

In [None]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer built from the layers defined in this file.

    Source and target tokens have separate embedding tables but share one
    ``PositionalEncoding`` module. The decoder output is projected to
    target-vocabulary logits by ``output_proj``.

    Args:
        src_vocab: Size of the source vocabulary.
        tgt_vocab: Size of the target vocabulary.
        d_model: Model width.
        n_heads: Number of attention heads per layer.
        n_encoder_layers: Number of stacked encoder layers.
        n_decoder_layers: Number of stacked decoder layers.
        d_ff: Hidden width of each feed-forward network.
        dropout: Dropout probability used throughout the model.
    """

    def __init__(self, src_vocab, tgt_vocab, d_model=256, n_heads=8,
                 n_encoder_layers=3, n_decoder_layers=3, d_ff=512, dropout=0.1):
        super().__init__()

        self.d_model = d_model

        # Token embeddings plus the shared positional encoding.
        self.src_embed = nn.Embedding(src_vocab, d_model)
        self.tgt_embed = nn.Embedding(tgt_vocab, d_model)
        self.pos_enc = PositionalEncoding(d_model, dropout=dropout)

        # Encoder stack first, then decoder stack.
        encoder_stack = [
            EncoderLayer(d_model, n_heads, d_ff, dropout)
            for _ in range(n_encoder_layers)
        ]
        self.encoder_layers = nn.ModuleList(encoder_stack)

        decoder_stack = [
            DecoderLayer(d_model, n_heads, d_ff, dropout)
            for _ in range(n_decoder_layers)
        ]
        self.decoder_layers = nn.ModuleList(decoder_stack)

        # Maps decoder states to logits over the target vocabulary.
        self.output_proj = nn.Linear(d_model, tgt_vocab)

        self._init_weights()

    def _init_weights(self):
        """Xavier-initialise every parameter with more than one dimension.

        This covers weight matrices and both embedding tables; 1-D parameters
        (biases, LayerNorm scale and shift) keep their default initialisation.
        """
        for param in self.parameters():
            if param.dim() <= 1:
                continue
            nn.init.xavier_uniform_(param)

    def _embed(self, embedding, tokens):
        """Embed token ids, scale by sqrt(d_model), add positional encoding."""
        scaled = embedding(tokens) * math.sqrt(self.d_model)
        return self.pos_enc(scaled)

    def encode(self, src, src_mask=None):
        """Run the source token ids through the encoder stack."""
        hidden = self._embed(self.src_embed, src)
        for encoder_layer in self.encoder_layers:
            hidden = encoder_layer(hidden, src_mask)
        return hidden

    def decode(self, tgt, enc_output, src_mask=None, tgt_mask=None):
        """Run the target token ids through the decoder stack.

        Returns decoder states; apply ``output_proj`` to obtain logits.
        """
        hidden = self._embed(self.tgt_embed, tgt)
        for decoder_layer in self.decoder_layers:
            hidden = decoder_layer(hidden, enc_output, src_mask, tgt_mask)
        return hidden

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        """Encode ``src``, decode ``tgt`` against it, return vocabulary logits."""
        memory = self.encode(src, src_mask)
        decoded = self.decode(tgt, memory, src_mask, tgt_mask)
        return self.output_proj(decoded)

def create_causal_mask(size):
    """Build a ``(size, size)`` boolean mask for autoregressive attention.

    Entry ``[i, j]`` is True when query position ``i`` may attend to key
    position ``j``, i.e. on and below the main diagonal (``j <= i``).
    """
    allowed = torch.tril(torch.ones(size, size))
    return allowed.bool()

# Create model
# Demo instance used to report the parameter count; `model` is rebuilt with
# the dataset's vocabulary sizes before training. d_ff and dropout keep their
# defaults here (512 and 0.1).
model = Transformer(src_vocab=1000, tgt_vocab=1000, d_model=128, 
                    n_heads=4, n_encoder_layers=2, n_decoder_layers=2)
print(f'Model parameters: {sum(p.numel() for p in model.parameters()):,}')

## Part 5: Training on Tiny Translation Task

In [None]:
# Create a tiny synthetic dataset (number-to-word "translation")
# E.g., source tokens ["1", "2", "3"] -> target tokens ["w1", "w2", "w3"]

class TinyTranslationDataset:
    """Synthetic parallel corpus mapping number tokens to word-like tokens.

    The source vocabulary holds the strings ``'0'`` .. ``'99'`` and the target
    vocabulary holds ``'w0'`` .. ``'w99'``; both start with ``<pad>`` (0),
    ``<sos>`` (1) and ``<eos>`` (2). All randomness comes from ``np.random``.

    Args:
        size: Number of (source, target) pairs to generate.
        max_len: Exclusive upper bound on the number of content tokens per
            sequence; lengths are drawn from ``[2, max_len)``.
    """

    def __init__(self, size=1000, max_len=8):
        self.size = size
        self.max_len = max_len

        # Special tokens occupy ids 0-2; content tokens start at id 3.
        specials = {'<pad>': 0, '<sos>': 1, '<eos>': 2}
        self.src_vocab = dict(specials)
        self.tgt_vocab = dict(specials)
        for number in range(100):
            token_id = number + 3
            self.src_vocab[str(number)] = token_id
            self.tgt_vocab[f'w{number}'] = token_id  # w0, w1, w2, ...

        self.src_vocab_size = len(self.src_vocab)
        self.tgt_vocab_size = len(self.tgt_vocab)

        self.data = self._generate_data()

    def _generate_data(self):
        """Return ``size`` pairs of (source ids, target ids) as Python lists.

        Each sequence is wrapped in ``<sos>`` ... ``<eos>``.
        """
        src_sos, src_eos = self.src_vocab['<sos>'], self.src_vocab['<eos>']
        tgt_sos, tgt_eos = self.tgt_vocab['<sos>'], self.tgt_vocab['<eos>']

        pairs = []
        for _ in range(self.size):
            n_tokens = np.random.randint(2, self.max_len)
            numbers = np.random.randint(0, 100, size=n_tokens)

            src_ids = [self.src_vocab[str(n)] for n in numbers]
            tgt_ids = [self.tgt_vocab[f'w{n}'] for n in numbers]

            pairs.append(([src_sos] + src_ids + [src_eos],
                          [tgt_sos] + tgt_ids + [tgt_eos]))
        return pairs

    def get_batch(self, batch_size):
        """Sample ``batch_size`` pairs (with replacement) as padded tensors.

        Returns:
            ``(src_batch, tgt_batch)``: two ``torch.long`` tensors of shape
            ``(batch_size, longest sequence in the batch)``, right-padded
            with 0 (the ``<pad>`` id).
        """
        chosen = np.random.choice(len(self.data), batch_size)
        samples = [self.data[idx] for idx in chosen]

        src_width = max(len(src) for src, _ in samples)
        tgt_width = max(len(tgt) for _, tgt in samples)

        # Start from all-<pad> tensors and copy each sequence into its row.
        src_batch = torch.zeros(batch_size, src_width, dtype=torch.long)
        tgt_batch = torch.zeros(batch_size, tgt_width, dtype=torch.long)
        for row, (src, tgt) in enumerate(samples):
            src_batch[row, :len(src)] = torch.tensor(src)
            tgt_batch[row, :len(tgt)] = torch.tensor(tgt)

        return src_batch, tgt_batch

# Build the dataset used for training and inference in the cells that follow.
dataset = TinyTranslationDataset(size=2000)
print(f'Dataset size: {len(dataset.data)}')
# Both vocabularies contain 3 special tokens plus 100 content tokens.
print(f'Source vocab: {dataset.src_vocab_size}')
print(f'Target vocab: {dataset.tgt_vocab_size}')

# Show example
# A batch of one contains no padding, so the full <sos> ... <eos> rows print.
src, tgt = dataset.get_batch(1)
print(f'\nExample batch:')
print(f'Source: {src[0].tolist()}')
print(f'Target: {tgt[0].tolist()}')

In [None]:
# Training loop
# Each "epoch" below is a single optimisation step on one randomly sampled
# mini-batch, not a full pass over the dataset.
model = Transformer(
    src_vocab=dataset.src_vocab_size, 
    tgt_vocab=dataset.tgt_vocab_size,
    d_model=128, n_heads=4, 
    n_encoder_layers=2, n_decoder_layers=2,
    d_ff=256, dropout=0.1
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.98))
# Padded target positions contribute nothing to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=dataset.tgt_vocab['<pad>'])

losses = []
n_epochs = 100
batch_size = 32

print('Training Transformer...')
for epoch in range(n_epochs):
    model.train()
    
    src, tgt = dataset.get_batch(batch_size)
    src, tgt = src.to(device), tgt.to(device)
    
    # Teacher forcing: the decoder input is tgt[:, :-1] and the prediction
    # target is tgt[:, 1:], i.e. the same sequence shifted left by one token.
    tgt_input = tgt[:, :-1]
    tgt_output = tgt[:, 1:]
    
    # Source padding mask, shape (batch, 1, 1, src_len), True at real tokens.
    # It broadcasts over heads and query positions, so neither the encoder's
    # self-attention nor the decoder's cross-attention attends to <pad>.
    # Without it, training sees padded keys that single-sequence inference
    # never does.
    src_mask = (src != dataset.src_vocab['<pad>']).unsqueeze(1).unsqueeze(2)
    
    # Causal mask: each target position sees only itself and earlier ones.
    # Target padding sits at the end of every row, so real target positions
    # never attend to it under this mask.
    tgt_mask = create_causal_mask(tgt_input.size(1)).to(device)
    
    # Forward
    optimizer.zero_grad()
    output = model(src, tgt_input, src_mask=src_mask, tgt_mask=tgt_mask)
    
    # Loss over flattened (batch * seq_len) positions
    loss = criterion(output.reshape(-1, output.size(-1)), tgt_output.reshape(-1))
    loss.backward()
    
    # Clip the global gradient norm to 1.0 before the optimiser step.
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    
    losses.append(loss.item())
    
    if (epoch + 1) % 20 == 0:
        print(f'Epoch {epoch+1}: Loss = {loss.item():.4f}')

# Plot training
# One point per optimisation step recorded in `losses`.
fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(losses)
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Transformer Training Loss')
ax.grid(True, alpha=0.3)
plt.show()

In [None]:
# Inference (greedy decoding)
def greedy_decode(model, src, max_len=20, sos_idx=1, eos_idx=2):
    """Greedily decode one source sequence into a list of target token ids.

    At every step the most probable next token is appended, until ``eos_idx``
    is produced or ``max_len`` tokens have been generated.

    Args:
        model: Object exposing ``encode``, ``decode`` and ``output_proj`` as
            the ``Transformer`` class in this file does. It is switched to
            eval mode and left that way.
        src: Source token ids of shape ``(1, src_len)``. Only a single
            sequence is supported because stopping is decided per sequence.
        max_len: Maximum number of tokens to generate after ``<sos>``.
        sos_idx: Id of the start-of-sequence token.
        eos_idx: Id of the end-of-sequence token.

    Returns:
        List of ints starting with ``sos_idx`` and, if decoding stopped
        before ``max_len``, ending with ``eos_idx``.

    Raises:
        ValueError: If ``src`` is not a 2-D tensor with batch size 1.
    """
    if src.dim() != 2 or src.size(0) != 1:
        raise ValueError(
            f'greedy_decode expects src of shape (1, src_len), got {tuple(src.shape)}'
        )

    # Create all working tensors on the input's device, not a global one.
    run_device = src.device

    model.eval()
    with torch.no_grad():
        # The source is encoded once and reused at every decoding step.
        enc_output = model.encode(src)
        
        # Start with <sos>
        tgt = torch.tensor([[sos_idx]], dtype=torch.long, device=run_device)
        
        for _ in range(max_len):
            tgt_mask = create_causal_mask(tgt.size(1)).to(run_device)
            output = model.decode(tgt, enc_output, tgt_mask=tgt_mask)
            # Only the last position predicts the next token.
            logits = model.output_proj(output[:, -1, :])
            next_token = logits.argmax(dim=-1, keepdim=True)
            tgt = torch.cat([tgt, next_token], dim=1)
            
            if next_token.item() == eos_idx:
                break
        
        return tgt[0].tolist()

# Test inference
# Decodes five single-sequence batches drawn from the same dataset object
# that supplied the training batches.
print('Testing Inference:')
print('='*50)

for _ in range(5):
    # A batch of one contains no padding.
    src, tgt = dataset.get_batch(1)
    src = src.to(device)
    
    predicted = greedy_decode(model, src)
    
    # Target and prediction are printed as raw token ids, including the
    # leading <sos> (1) and, when present, the trailing <eos> (2).
    print(f'Source:    {src[0].tolist()}')
    print(f'Target:    {tgt[0].tolist()}')
    print(f'Predicted: {predicted}')
    print()

In [None]:
# Visualize attention patterns
def visualize_attention(model, src, tgt, layer_idx=0, head_idx=0):
    """Return one encoder self-attention map for the first sequence in ``src``.

    Args:
        model: Object exposing ``src_embed``, ``pos_enc``, ``d_model`` and
            ``encoder_layers`` as the ``Transformer`` class in this file
            does. It is switched to eval mode and left that way.
        src: Source token ids; only the first sequence of the batch is used.
        tgt: Unused; kept so existing callers keep working.
        layer_idx: Index of the encoder layer whose attention is returned.
        head_idx: Index of the attention head to return.

    Returns:
        NumPy array with the attention weights of the selected layer and
        head for the first batch element. No mask is applied.
    """
    model.eval()
    with torch.no_grad():
        # Embed the source as the encoder would.
        x = model.src_embed(src) * math.sqrt(model.d_model)
        x = model.pos_enc(x)
        
        # Run the layers below the requested one so that it receives the
        # input it would see in a normal forward pass.
        for layer in model.encoder_layers[:layer_idx]:
            x = layer(x)
        
        _, attn = model.encoder_layers[layer_idx].self_attn(x, x, x)
        
    return attn[0, head_idx].cpu().numpy()  # First batch element, chosen head

# Draw one example and plot the attention map returned for it
# (rows are query positions, columns are key positions).
src, tgt = dataset.get_batch(1)
src = src.to(device)
attn = visualize_attention(model, src, tgt)

fig, ax = plt.subplots(figsize=(8, 6))
heatmap = ax.imshow(attn, cmap='Blues')
ax.set_xlabel('Key Position')
ax.set_ylabel('Query Position')
ax.set_title('Self-Attention Pattern (Encoder, Head 0)')
fig.colorbar(heatmap, ax=ax)
plt.show()

# Closing summary. The heading emoji had been stored as mis-decoded UTF-8
# bytes; it is restored to the intended character.
print('\n🎯 Key Takeaways:')
print('1. Transformer uses self-attention instead of recurrence')
print('2. Multi-head attention captures different relationships')
print('3. Encoder processes source, decoder generates target autoregressively')
print('4. Positional encoding provides position information')