# 📝 GPT: Generative Pre-trained Transformer

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/gaurav-redhat/transformer_problems/blob/transformer/transformer_architectures/03_gpt/demo.ipynb)

![Architecture](architecture.png)

### Key Features
- **Decoder-only**: autoregressive generation, one token at a time
- **Causal mask**: each position sees only itself and earlier (past) tokens
- **Next token prediction**: models P(next | previous)

In [None]:
!pip install torch matplotlib numpy -q
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import matplotlib.pyplot as plt
import numpy as np

# Fix the global RNG seed so weight init, batch sampling and generation are repeatable.
torch.manual_seed(42)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Using device: {device}')

## GPT Architecture

In [None]:
class CausalSelfAttention(nn.Module):
    """Multi-head self-attention in which each position attends only to itself and earlier positions.

    Args:
        d_model: Embedding width; must be divisible by ``n_heads``.
        n_heads: Number of attention heads (each of width ``d_model // n_heads``).
        max_len: Longest sequence the precomputed causal mask supports.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``n_heads``.
    """

    def __init__(self, d_model, n_heads, max_len=512, dropout=0.1):
        super().__init__()
        # Fail fast: otherwise the (B, T, 3, n_heads, d_k) reshape in forward()
        # dies later with an opaque shape error.
        if d_model % n_heads != 0:
            raise ValueError(f'd_model ({d_model}) must be divisible by n_heads ({n_heads})')
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.d_model = d_model

        # One fused projection yields Q, K and V together: (B, T, C) -> (B, T, 3C).
        self.c_attn = nn.Linear(d_model, 3 * d_model)
        self.c_proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

        # Causal mask (lower triangular): mask[0, 0, i, j] == 1 iff key j <= query i.
        # A buffer, so it follows .to(device) and is saved in state_dict but is not trained.
        self.register_buffer('mask', torch.tril(torch.ones(max_len, max_len)).view(1, 1, max_len, max_len))

    def forward(self, x):
        """Attend causally over ``x``.

        Args:
            x: Tensor of shape (B, T, d_model) with ``T <= max_len``.

        Returns:
            ``(out, attn)``: ``out`` is (B, T, d_model); ``attn`` is the
            (B, n_heads, T, T) softmax attention distribution, before dropout.

        Raises:
            ValueError: If ``T`` exceeds the mask size (``max_len``).
        """
        B, T, C = x.shape
        max_len = self.mask.size(-1)
        if T > max_len:
            raise ValueError(f'sequence length {T} exceeds max_len {max_len}')

        # (B, T, 3C) -> (3, B, n_heads, T, d_k), then split into Q, K, V.
        qkv = self.c_attn(x).reshape(B, T, 3, self.n_heads, self.d_k).permute(2, 0, 3, 1, 4)
        Q, K, V = qkv[0], qkv[1], qkv[2]

        # Scaled dot-product scores; future keys (mask == 0) get -inf so softmax gives them zero weight.
        scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.d_k)
        scores = scores.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
        attn = F.softmax(scores, dim=-1)

        # Dropout perturbs only the weights used to mix V; the caller receives the
        # clean distribution so attention plots are meaningful even in train mode.
        out = (self.dropout(attn) @ V).transpose(1, 2).reshape(B, T, C)
        return self.c_proj(out), attn


class GPTBlock(nn.Module):
    """Pre-LayerNorm decoder block: ``x + Attn(LN(x))`` followed by ``x + MLP(LN(x))``.

    Args:
        d_model: Embedding width.
        n_heads: Number of attention heads.
        dropout: Dropout probability for attention weights and the MLP output.
        max_len: Longest supported sequence; forwarded to the attention's causal mask
            so it matches the model's positional table instead of a fixed 512.
    """

    def __init__(self, d_model, n_heads, dropout=0.1, max_len=512):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = CausalSelfAttention(d_model, n_heads, max_len=max_len, dropout=dropout)
        self.ln2 = nn.LayerNorm(d_model)
        # Position-wise feed-forward network with the usual 4x hidden expansion.
        self.mlp = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        """Apply the block to ``x`` of shape (B, T, d_model); returns the same shape."""
        attn_out, _ = self.attn(self.ln1(x))
        x = x + attn_out
        x = x + self.mlp(self.ln2(x))
        return x


class GPT(nn.Module):
    """Decoder-only transformer language model with learned positional embeddings.

    Args:
        vocab_size: Number of distinct token ids.
        d_model: Embedding width.
        n_heads: Attention heads per block.
        n_layers: Number of stacked ``GPTBlock``s.
        max_len: Longest context the model accepts (size of the positional table
            and of every block's causal mask).
        dropout: Dropout probability used throughout.
    """

    def __init__(self, vocab_size, d_model=256, n_heads=4, n_layers=4, max_len=512, dropout=0.1):
        super().__init__()
        self.max_len = max_len
        self.token_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(max_len, d_model)
        self.drop = nn.Dropout(dropout)
        self.blocks = nn.ModuleList(
            [GPTBlock(d_model, n_heads, dropout, max_len=max_len) for _ in range(n_layers)]
        )
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)

        # Weight tying: the input embedding and the output projection share one
        # (vocab_size, d_model) matrix.
        self.token_embed.weight = self.head.weight

    def forward(self, idx):
        """Map token ids ``idx`` of shape (B, T) to next-token logits of shape (B, T, vocab_size).

        Raises:
            ValueError: If ``T`` exceeds ``max_len``.
        """
        B, T = idx.shape
        if T > self.max_len:
            raise ValueError(f'sequence length {T} exceeds max_len {self.max_len}')
        pos = torch.arange(T, device=idx.device).unsqueeze(0)

        x = self.drop(self.token_embed(idx) + self.pos_embed(pos))
        for block in self.blocks:
            x = block(x)
        x = self.ln_f(x)
        return self.head(x)

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """Autoregressively sample ``max_new_tokens`` tokens and append them to ``idx``.

        The model's train/eval mode is left as is; call ``.eval()`` first to disable dropout.

        Args:
            idx: Prompt token ids, shape (B, T).
            max_new_tokens: Number of tokens to sample.
            temperature: Positive divisor for the logits; < 1 sharpens, > 1 flattens.
            top_k: If given, sample only among the ``top_k`` most likely tokens
                (clamped to the vocabulary size).

        Returns:
            Tensor of shape (B, T + max_new_tokens).

        Raises:
            ValueError: If ``temperature`` is not positive.
        """
        if temperature <= 0:
            raise ValueError(f'temperature must be positive, got {temperature}')
        for _ in range(max_new_tokens):
            # Keep only the most recent max_len tokens so positions stay inside pos_embed.
            idx_cond = idx[:, -self.max_len:]
            logits = self(idx_cond)[:, -1, :] / temperature

            if top_k is not None:
                # torch.topk raises if k exceeds the vocabulary size.
                k = min(top_k, logits.size(-1))
                v, _ = torch.topk(logits, k)
                logits[logits < v[:, [-1]]] = float('-inf')

            probs = F.softmax(logits, dim=-1)
            next_idx = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, next_idx], dim=1)
        return idx

# Instantiate a small demo model; parameters() yields the tied embedding/head
# matrix only once, so the count is not double-counted.
model = GPT(vocab_size=5000, d_model=128, n_heads=4, n_layers=3).to(device)
n_params = sum(map(torch.numel, model.parameters()))
print(f'GPT Parameters: {n_params:,}')

## Visualize Causal Mask

In [None]:
# Visualize causal mask: row i (query) may attend to column j (key) only when j <= i.
seq_len = 10
causal_mask = torch.ones(seq_len, seq_len).tril()

plt.figure(figsize=(8, 6))
plt.imshow(causal_mask, cmap='Greens')
plt.title('Causal Attention Mask\n(Green = Can Attend)')
plt.xlabel('Key Position')
plt.ylabel('Query Position')
plt.colorbar()
# Write each cell's 0/1 value on top of the heat map.
for i, row in enumerate(causal_mask.int().tolist()):
    for j, val in enumerate(row):
        plt.text(j, i, val, ha='center', va='center')
plt.show()

print('Each position can only attend to itself and previous positions.')

## Training: Character-Level Language Model

In [None]:
# Create tiny character dataset
text = '''
the quick brown fox jumps over the lazy dog
a fast red fox leaps across the sleeping hound
quick foxes jump high over slow dogs
the brown dog runs after the red fox
lazy dogs sleep while foxes jump around
''' * 100  # Repeat for more data

# Character-level tokenization: each distinct character gets one id, in sorted order.
chars = sorted(set(text))
vocab_size = len(chars)
char_to_idx = {c: i for i, c in enumerate(chars)}
idx_to_char = dict(enumerate(chars))


def encode(s):
    """Convert string ``s`` to a list of token ids.

    Raises KeyError for any character that does not occur in ``text``.
    """
    return [char_to_idx[c] for c in s]


def decode(l):
    """Convert an iterable of token ids back to a string (inverse of ``encode``)."""
    return ''.join(idx_to_char[i] for i in l)


data = torch.tensor(encode(text), dtype=torch.long)
print(f'Vocabulary size: {vocab_size}')
print(f'Data length: {len(data)}')
# repr() so the newline in the vocabulary shows as '\n' instead of breaking the line.
print(f'Characters: {"".join(chars)!r}')

In [None]:
# Data loader
def get_batch(data, batch_size, block_size):
    """Sample ``batch_size`` random windows of ``data`` for next-token training.

    Returns ``(x, y)``, each of shape (batch_size, block_size) and moved to the
    global ``device``; ``y`` is ``x`` shifted one position later in ``data``.
    """
    starts = torch.randint(len(data) - block_size, (batch_size,))
    # Take block_size + 1 tokens per start: inputs are the first block_size of
    # them, targets the last block_size.
    windows = torch.stack([data[s:s + block_size + 1] for s in starts])
    inputs = windows[:, :-1].contiguous()
    targets = windows[:, 1:].contiguous()
    return inputs.to(device), targets.to(device)

# Training
model = GPT(vocab_size=vocab_size, d_model=64, n_heads=4, n_layers=3, max_len=64).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

# block_size must not exceed the model's max_len (64), the size of its positional table.
n_steps, batch_size, block_size = 500, 32, 32
losses = []

print('Training GPT (Character-Level LM)...')
for step in range(n_steps):
    x, y = get_batch(data, batch_size, block_size)

    # One optimisation step on next-token cross-entropy over every position.
    optimizer.zero_grad()
    logits = model(x)
    loss = F.cross_entropy(logits.view(-1, vocab_size), y.view(-1))
    loss.backward()
    optimizer.step()

    step_loss = loss.item()
    losses.append(step_loss)
    if (step + 1) % 100 == 0:
        print(f'Step {step + 1}: Loss = {step_loss:.4f}')

plt.figure(figsize=(10, 4))
plt.plot(losses)
plt.title('GPT Training Loss')
plt.xlabel('Step')
plt.ylabel('Loss')
plt.grid(True, alpha=0.3)
plt.show()

In [None]:
# Generate text! Switch to eval mode so dropout is disabled while sampling.
model.eval()

prompts = ['the ', 'fox ', 'dog ', 'quick ']

print('Text Generation:')
print('=' * 50)
for prompt in prompts:
    # Shape (1, len(prompt)): a batch holding a single prompt.
    context = torch.tensor(encode(prompt), dtype=torch.long).unsqueeze(0).to(device)
    generated = model.generate(context, max_new_tokens=50, temperature=0.8)
    completion = decode(generated[0].tolist())
    print(f'Prompt: "{prompt}"')
    print(f'Generated: "{completion}"')
    print()

In [None]:
# Visualize attention patterns of the first block (head 0) on a short sample.
model.eval()
sample = 'the fox'  # single source for both the model input and the axis labels
test_input = torch.tensor([encode(sample)], dtype=torch.long).to(device)

with torch.no_grad():
    first_block = model.blocks[0]
    positions = torch.arange(test_input.size(1), device=device)
    # Recreate the input of block 0: token + positional embeddings.
    x = model.token_embed(test_input) + model.pos_embed(positions)
    _, attn = first_block.attn(first_block.ln1(x))

tokens = list(sample)
plt.figure(figsize=(8, 6))
plt.imshow(attn[0, 0].cpu(), cmap='Blues')
plt.title('GPT Causal Attention (Head 0)')
plt.xlabel('Key Position')
plt.ylabel('Query Position')
plt.xticks(range(len(tokens)), tokens)
plt.yticks(range(len(tokens)), tokens)
plt.colorbar()
plt.show()

# Summary of the ideas demonstrated above (emoji restored from mis-encoded bytes).
print('\n🎯 Key Takeaways:')
print('1. GPT is decoder-only, autoregressive')
print('2. Causal mask: Each token only sees past tokens')
print('3. Next token prediction: Predict P(next | previous)')
print('4. Great for text generation, completion, chatbots')