# 10: Transformer Basics

**Duration:** 4-5 hours | **Difficulty:** Advanced

## Learning Objectives
- Transformer architecture fundamentals
- Self-attention and positional encoding
- Complete transformer implementation
- Text generation with transformers

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import json

import sys
sys.path.append('../')
from utils.text_utils import SimpleTokenizer
from utils.model_helpers import get_device, count_parameters

# Select the compute device used by the rest of the notebook.
# NOTE(review): get_device comes from utils.model_helpers; "auto" presumably
# picks the best available accelerator and falls back to CPU — confirm there.
device = get_device("auto")
print(f"Using device: {device}")

## Transformer Components

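**Key innovations**: self-attention (every token attends directly to every other token), parallel processing of the whole sequence (no recurrence), and positional encoding (injects token order, which attention alone ignores).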

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The input is projected to queries, keys and values, split into
    ``n_heads`` heads of size ``d_model // n_heads``, attended per head,
    then concatenated and passed through an output projection.

    Args:
        d_model: Size of the input and output feature dimension.
        n_heads: Number of attention heads; must divide ``d_model`` evenly.

    Raises:
        ValueError: If ``n_heads`` is not positive or does not divide
            ``d_model``.
    """

    def __init__(self, d_model, n_heads):
        super().__init__()
        # Fail early with a clear message: with a non-divisible d_model the
        # integer division below would truncate and the reshape in forward()
        # would fail later with an opaque shape error.
        if n_heads < 1 or d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

    def _split_heads(self, projected, batch_size, seq_len):
        """Reshape (batch, seq, d_model) into (batch, n_heads, seq, d_k)."""
        return projected.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape (batch, seq_len, d_model).
            mask: Optional tensor broadcastable to
                (batch, n_heads, seq_len, seq_len). Positions where the mask
                equals 0 are excluded from attention.

        Returns:
            Tensor of shape (batch, seq_len, d_model).
        """
        batch_size, seq_len, _ = x.shape

        Q = self._split_heads(self.w_q(x), batch_size, seq_len)
        K = self._split_heads(self.w_k(x), batch_size, seq_len)
        V = self._split_heads(self.w_v(x), batch_size, seq_len)

        # Scale by sqrt(d_k) so the logits do not grow with the head size.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # Use the most negative finite value of the score dtype instead of
            # a hard-coded -1e9, which is not representable in float16.
            # A finite value (rather than -inf) keeps fully masked rows from
            # turning into NaN after the softmax.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        attention = F.softmax(scores, dim=-1)
        output = torch.matmul(attention, V)

        # Merge the heads back: (batch, n_heads, seq, d_k) -> (batch, seq, d_model).
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        return self.w_o(output)

class TransformerBlock(nn.Module):
    """One transformer layer: self-attention, then a position-wise MLP.

    Each sub-layer is wrapped in a residual connection, and layer
    normalisation is applied after the residual sum.

    Args:
        d_model: Feature dimension of the block's input and output.
        n_heads: Number of attention heads passed to the attention layer.
        d_ff: Hidden width of the feed-forward network.
    """

    def __init__(self, d_model, n_heads, d_ff):
        super().__init__()
        self.attention = MultiHeadAttention(d_model, n_heads)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        # Two-layer MLP: widen to d_ff, apply ReLU, project back to d_model.
        widen = nn.Linear(d_model, d_ff)
        narrow = nn.Linear(d_ff, d_model)
        self.feed_forward = nn.Sequential(widen, nn.ReLU(), narrow)

    def forward(self, x, mask=None):
        """Run the block on ``x``, forwarding ``mask`` to the attention layer."""
        # Attention sub-layer: residual add, then normalise.
        attended = self.attention(x, mask)
        hidden = self.norm1(x + attended)

        # Feed-forward sub-layer: residual add, then normalise.
        transformed = self.feed_forward(hidden)
        return self.norm2(hidden + transformed)

class TransformerLM(nn.Module):
    """Decoder-only transformer language model with learned positions.

    Token embeddings (scaled by ``sqrt(d_model)``) are summed with a learned
    positional table, passed through ``n_layers`` causally masked transformer
    blocks, layer-normalised, and projected to vocabulary logits.

    Args:
        vocab_size: Number of token ids the model can embed and predict.
        d_model: Embedding / hidden dimension.
        n_heads: Attention heads per block.
        n_layers: Number of stacked transformer blocks.
        d_ff: Hidden width of each block's feed-forward network.
        max_len: Number of rows in the positional table, i.e. the longest
            sequence ``forward`` accepts.
    """

    def __init__(self, vocab_size, d_model=256, n_heads=8, n_layers=4, d_ff=1024,
                 max_len=1000):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        # Learned positional table: one d_model vector per position.
        self.pos_encoding = nn.Parameter(torch.randn(max_len, d_model))
        self.blocks = nn.ModuleList(
            [TransformerBlock(d_model, n_heads, d_ff) for _ in range(n_layers)]
        )
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        """Compute next-token logits for every position.

        Args:
            x: Integer token ids of shape (batch, seq_len).

        Returns:
            Logits of shape (batch, seq_len, vocab_size).

        Raises:
            ValueError: If ``seq_len`` exceeds the positional table size.
        """
        seq_len = x.size(1)
        max_len = self.pos_encoding.size(0)
        if seq_len > max_len:
            # Without this check the addition below fails with a cryptic
            # broadcasting error.
            raise ValueError(
                f"Sequence length {seq_len} exceeds the maximum supported length {max_len}"
            )

        x = self.embedding(x) * math.sqrt(self.d_model)
        x = x + self.pos_encoding[:seq_len]

        # Causal mask: position i may attend only to positions <= i.
        # Built directly on the input's device to avoid a host-to-device copy.
        mask = torch.tril(torch.ones(seq_len, seq_len, device=x.device))
        mask = mask.unsqueeze(0).unsqueeze(0)

        for block in self.blocks:
            x = block(x, mask)

        x = self.ln_f(x)
        return self.head(x)

    def generate(self, start_tokens, max_length=30, eos_token_id=2):
        """Sample a continuation of ``start_tokens`` autoregressively.

        Switches the model to eval mode (and leaves it there), then samples
        one token at a time from the softmax over the last position's logits.

        Args:
            start_tokens: Integer token ids of shape (batch, prompt_len).
            max_length: Maximum number of new tokens to append.
            eos_token_id: Token id that ends a sequence. Generation stops
                once every row in the batch has produced it; rows that
                finish early are padded with this id.

        Returns:
            Tensor of shape (batch, prompt_len + generated) containing the
            prompt followed by the sampled tokens.
        """
        self.eval()
        max_context = self.pos_encoding.size(0)
        with torch.no_grad():
            tokens = start_tokens
            # Per-row completion flags so batches larger than one work
            # (calling .item() on a multi-element tensor raises).
            finished = torch.zeros(tokens.size(0), dtype=torch.bool, device=tokens.device)
            for _ in range(max_length):
                # Feed only the most recent tokens the positional table covers.
                logits = self(tokens[:, -max_context:])
                probs = F.softmax(logits[:, -1], dim=-1)
                next_token = torch.multinomial(probs, 1)
                # Rows that already ended keep emitting EOS as padding.
                next_token = next_token.masked_fill(finished.unsqueeze(1), eos_token_id)
                tokens = torch.cat([tokens, next_token], dim=1)
                finished = finished | (next_token.squeeze(1) == eos_token_id)
                if finished.all():
                    break
            return tokens

# Load the question/answer pairs and build the model.
# JSON is UTF-8 by specification, so decode it explicitly instead of relying
# on the platform's locale-dependent default encoding.
with open('../data/conversations/simple_qa_pairs.json', 'r', encoding='utf-8') as f:
    conversations = [(item['question'], item['answer']) for item in json.load(f)]

# Fit the tokenizer on every question and answer string.
# NOTE(review): SimpleTokenizer is project-local; presumably fit() builds the
# vocabulary exposed as tokenizer.vocab — confirm in utils.text_utils.
tokenizer = SimpleTokenizer(vocab_size=2000)
all_text = [text for conv in conversations for text in conv]
tokenizer.fit(all_text)

# Size the embedding and output layers to the fitted vocabulary.
model = TransformerLM(len(tokenizer.vocab)).to(device)
print(f"Model parameters: {count_parameters(model)['total']:,}")

# Test generation with an untrained model (output is random; this is a smoke test).
# NOTE(review): these hard-coded ids assume the fitted vocabulary has more
# than 30 entries; a smaller vocabulary would make the embedding lookup fail.
test_input = torch.tensor([[1, 10, 20, 30]]).to(device)  # Sample tokens
output = model.generate(test_input)
print(f"Generated: {output.shape}")

print("\n=== Transformer Basics Complete ===")
print("Key Concepts Learned:")
print("• Self-attention mechanism")
print("• Positional encoding")
print("• Transformer blocks with residual connections")
print("• Autoregressive text generation")
print("• Foundation for modern language models")