# Multi-Head Attention (MHA) Transformer Training - Simplified
## Train Transformer on WikiText-2

**Simple notebook: apart from the pre-processed data, the only project code it needs is the `mha` package from this repository.**

This notebook:
- Loads **pre-processed WikiText-2** data (created by `Datasets.ipynb`)
- Trains an encoder-decoder Transformer with multi-head attention
- Is set up for Google Colab (small batch size, checkpoints saved to Google Drive)

**Prerequisites:**
1. Run `Datasets.ipynb` first to create processed data
2. Download `data_processed.zip` from that notebook
3. Upload it to Colab and extract: `!unzip data_processed.zip`

---

## 1️⃣ Check GPU

In [None]:
import torch

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
    print("⚠️ No GPU detected - training will be slow!")

print(f"\nUsing device: {device}")

## 2️⃣ Mount Google Drive (to save checkpoints)

In [None]:
from google.colab import drive
import os

drive.mount('/content/drive')

# Create checkpoint directory
CHECKPOINT_DIR = '/content/drive/MyDrive/mha_checkpoints'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
print(f"✓ Checkpoints will be saved to: {CHECKPOINT_DIR}")

## 3️⃣ Clone Repository

In [None]:
# Clone your repository
!git clone https://github.com/YOUR_USERNAME/LLM-Journey.git
%cd LLM-Journey

print("✓ Repository cloned!")

## 4️⃣ Install Dependencies

In [None]:
# Install third-party dependencies
!pip install -q datasets transformers tqdm

# Install YOUR mha package in editable mode (-e): pip links to the cloned
# source instead of copying it, so code changes don't need a reinstall
!pip install -q -e .

print("✓ Packages installed!")
print("✓ MHA package installed in editable mode (proper way!)")

## 5️⃣ Import Everything

**Note:** We're using **proper Python package imports** (industry standard)!

No more `sys.path.insert(0, 'mha')` hacks. The package is properly installed with `pip install -e .` in the previous cell.

This means:
- ✅ Imports work from any directory
- ✅ Works reliably in Colab
- ✅ IDE autocomplete works
- ✅ Same import style as PyTorch, Transformers, etc.
- ✅ Professional & reproducible
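To confirm which copy of `mha` Python would actually import (for example, right after `pip install -e .`), you can ask the import system without importing the package. A minimal standard-library sketch; `module_origin` is a hypothetical helper, not part of the repository:

```python
import importlib.util


def module_origin(name):
    """Return the file a top-level module would be imported from, or None if it is not importable."""
    spec = importlib.util.find_spec(name)
    if spec is None:
        return None
    # Regular packages report their __init__.py; namespace packages have no single origin (None)
    return spec.origin


print(module_origin("mha"))
```

For an editable install the printed path should point into the cloned `LLM-Journey` folder; `None` means `mha` is not importable from the current environment.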

In [None]:
# Standard imports
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
from transformers import GPT2Tokenizer
from datasets import load_from_disk
from tqdm import tqdm
import math

# Import from YOUR package (properly installed!)
from mha import Transformer
from mha import create_combined_mask, create_padding_mask

print("✓ All imports successful!")
print("✓ Using properly installed mha package (no sys.path hacks!)")

## 6️⃣ Configuration (Simple Dictionary!)

In [None]:
# Simple configuration - adjust as needed
config = {
    # Model architecture
    'vocab_size': 50257,         # GPT-2 tokenizer vocabulary
    'd_model': 512,              # Model dimension
    'num_heads': 8,              # Number of attention heads
    'num_encoder_layers': 6,     # Encoder depth
    'num_decoder_layers': 6,     # Decoder depth  
    'd_ff': 2048,                # Feed-forward dimension
    'max_seq_length': 512,       # Max sequence length (matches pre-processed data!)
    'dropout': 0.1,              # Dropout probability
    
    # Training
    'batch_size': 8,             # Batch size (small for Colab memory)
    'num_epochs': 3,             # Number of epochs
    'learning_rate': 0.0001,     # Peak learning rate
    'warmup_steps': 2000,        # LR warmup steps
    'gradient_clip': 1.0,        # Gradient clipping
}

print("Configuration:")
for k, v in config.items():
    print(f"  {k}: {v}")
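Multi-head attention splits `d_model` evenly across the heads, so `d_model` must be divisible by `num_heads`; with the values above each head works in 512 / 8 = 64 dimensions. A small sanity check you could run before building the model (`check_config` is a hypothetical helper, not part of the `mha` package; in the notebook you would pass `config` itself):

```python
def check_config(cfg):
    """Validate a few config invariants and return the per-head dimension."""
    d_model, num_heads = cfg['d_model'], cfg['num_heads']
    if d_model % num_heads != 0:
        raise ValueError(f"d_model={d_model} is not divisible by num_heads={num_heads}")
    if not 0.0 <= cfg['dropout'] < 1.0:
        raise ValueError(f"dropout must be in [0, 1), got {cfg['dropout']}")
    # Each head attends within a d_model / num_heads dimensional subspace
    return d_model // num_heads


head_dim = check_config({'d_model': 512, 'num_heads': 8, 'dropout': 0.1})
print(f"Per-head dimension: {head_dim}")  # → Per-head dimension: 64
```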

## 7️⃣ Load Pre-Processed Data (From Datasets.ipynb)

**Note:** This loads the data you already prepared in `Datasets.ipynb`!

Before running this cell:
1. Make sure you ran `Datasets.ipynb` and downloaded `data_processed.zip`
2. Upload `data_processed.zip` to Colab
3. Extract it: `!unzip data_processed.zip`

This will load the **already tokenized** WikiText-2 dataset (512-token sequences, GPT-2 tokenizer).

In [None]:
from datasets import load_from_disk

# Load pre-processed datasets (created from Datasets.ipynb)
print("Loading pre-processed WikiText-2 dataset...")
print("(Make sure you've uploaded data_processed.zip and extracted it!)\n")

# Path to your pre-processed data
DATA_PATH = './data/wikitext2_processed'

try:
    # Load from disk
    dataset = load_from_disk(DATA_PATH)
    
    train_dataset = dataset['train']
    val_dataset = dataset['validation']
    
    # Initialize tokenizer (still needed for pad_token_id)
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    tokenizer.pad_token = tokenizer.eos_token
    
    # Set PyTorch format
    train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask'])
    val_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask'])
    
    print(f"✓ Dataset loaded successfully!")
    print(f"  Tokenizer vocab size: {len(tokenizer)}")
    print(f"  Train samples: {len(train_dataset):,}")
    print(f"  Val samples: {len(val_dataset):,}")
    print(f"  Sequence length: {len(train_dataset[0]['input_ids'])} tokens")
    
except FileNotFoundError:
    print("❌ Error: Pre-processed data not found!")
    print(f"\nExpected path: {DATA_PATH}")
    print("\nPlease:")
    print("  1. Run Datasets.ipynb to create the processed data")
    print("  2. Download data_processed.zip from Datasets.ipynb")
    print("  3. Upload and extract it to Colab:")
    print("     !unzip data_processed.zip")
    raise
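The cell above expects every example to already be exactly 512 tokens long. One common way to get fixed-length blocks, used in the Hugging Face language-modeling examples, is to concatenate all tokenized lines and cut the stream into equal chunks, dropping the remainder. Whether `Datasets.ipynb` did this, or padded/truncated each line instead, is an assumption here; the pure-Python sketch below (`group_into_blocks` is a hypothetical name) only illustrates the idea:

```python
from itertools import chain


def group_into_blocks(token_lists, block_size):
    """Concatenate tokenized lines and split them into full blocks of block_size tokens."""
    stream = list(chain.from_iterable(token_lists))
    # Drop the tail that does not fill a whole block
    usable = (len(stream) // block_size) * block_size
    blocks = [stream[i:i + block_size] for i in range(0, usable, block_size)]
    # No padding is needed, so every position is a real token
    masks = [[1] * block_size for _ in blocks]
    return blocks, masks


blocks, masks = group_into_blocks([[1, 2, 3], [4, 5], [6, 7, 8, 9]], block_size=4)
print(blocks)  # → [[1, 2, 3, 4], [5, 6, 7, 8]]
```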

## 8️⃣ Create DataLoaders (Simple!)

In [None]:
# Simple collate function (data is already in PyTorch format!)
def collate_fn(batch):
    input_ids = torch.stack([item['input_ids'] for item in batch])
    attention_mask = torch.stack([item['attention_mask'] for item in batch])
    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'labels': input_ids.clone()  # For language modeling
    }

# Create dataloaders
train_loader = DataLoader(
    train_dataset,
    batch_size=config['batch_size'],
    shuffle=True,
    collate_fn=collate_fn
)

val_loader = DataLoader(
    val_dataset,
    batch_size=config['batch_size'],
    shuffle=False,
    collate_fn=collate_fn
)

print(f"✓ DataLoaders created!")
print(f"  Train batches: {len(train_loader)}")
print(f"  Val batches: {len(val_loader)}")
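Because `pad_token` is set to `eos_token`, ignoring `pad_token_id` in the loss also ignores any genuine `<|endoftext|>` targets. A hedged alternative, assuming the `attention_mask` saved by `Datasets.ipynb` marks padding with 0, is to mask labels with -100, which is the default `ignore_index` of `nn.CrossEntropyLoss`. The sketch uses plain lists; on tensors the same rule is `labels[attention_mask == 0] = -100`. `mask_padding_labels` is a hypothetical helper:

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def mask_padding_labels(input_ids, attention_mask):
    """Copy input_ids to labels, replacing padded positions (mask == 0) with IGNORE_INDEX."""
    return [
        [tok if keep else IGNORE_INDEX for tok, keep in zip(ids, mask)]
        for ids, mask in zip(input_ids, attention_mask)
    ]


EOS = 50256  # GPT-2 <|endoftext|>, also used as the pad token in this notebook
labels = mask_padding_labels(
    [[464, 3290, EOS, EOS]],  # a real EOS at position 2, padding at position 3
    [[1, 1, 1, 0]],
)
print(labels)  # → [[464, 3290, 50256, -100]]
```

This only matters if the processed data actually contains padding. Using it would also mean building the criterion with `ignore_index=-100` and counting tokens with `tgt_output != -100` instead of comparing against `pad_token_id`.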

## 9️⃣ Initialize Model

In [None]:
# Create transformer model
model = Transformer(
    vocab_size=config['vocab_size'],
    d_model=config['d_model'],
    num_heads=config['num_heads'],
    num_encoder_layers=config['num_encoder_layers'],
    num_decoder_layers=config['num_decoder_layers'],
    d_ff=config['d_ff'],
    max_seq_length=config['max_seq_length'],
    dropout=config['dropout'],
    pe_type='sinusoidal'  # Sinusoidal positional encoding
).to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"✓ Model created!")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")
print(f"  Model size: ~{total_params * 4 / 1e6:.1f} MB")
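The printed "Model size" counts 4 bytes per fp32 parameter, i.e. the weights alone. Training with Adam also keeps a gradient and two moment buffers per parameter, so a rough floor is 16 bytes per parameter before activations. The sketch below estimates the parameter count from the config values under stated assumptions: biases in every linear layer, one token embedding shared by encoder and decoder, an untied output projection, and no learned parameters in the sinusoidal encoding. The real `mha.Transformer` may share or add weights, so compare against `total_params` above. `estimate_params` and `adam_training_bytes` are hypothetical helpers:

```python
def estimate_params(vocab_size, d_model, d_ff, num_encoder_layers, num_decoder_layers):
    """Rough parameter count for an encoder-decoder Transformer (assumptions in the lead-in)."""
    attention = 4 * (d_model * d_model + d_model)        # Q, K, V, output projections + biases
    feed_forward = 2 * d_model * d_ff + d_ff + d_model    # two linear layers + biases
    layer_norm = 2 * d_model                              # scale + shift
    encoder_layer = attention + feed_forward + 2 * layer_norm
    decoder_layer = 2 * attention + feed_forward + 3 * layer_norm  # self- + cross-attention
    embedding = vocab_size * d_model                      # assumed shared by encoder and decoder
    output_proj = vocab_size * d_model + vocab_size       # assumed untied, with bias
    return (num_encoder_layers * encoder_layer
            + num_decoder_layers * decoder_layer
            + embedding + output_proj)


def adam_training_bytes(num_params, bytes_per_value=4):
    """Weights + gradients + Adam's two moment estimates, excluding activations."""
    return num_params * bytes_per_value * 4


n = estimate_params(vocab_size=50257, d_model=512, d_ff=2048,
                    num_encoder_layers=6, num_decoder_layers=6)
print(f"Estimated parameters: {n:,}")
print(f"Weights + grads + Adam state: ~{adam_training_bytes(n) / 1e9:.2f} GB (plus activations)")
```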

## 🔟 Setup Optimizer, Scheduler & Loss

In [None]:
# Adam optimizer
optimizer = optim.Adam(
    model.parameters(),
    lr=config['learning_rate'],
    betas=(0.9, 0.98),
    eps=1e-9
)

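# Learning-rate schedule: linear warmup, then inverse-square-root decay (the shape of
# the schedule in "Attention Is All You Need"), normalized so the multiplier peaks at
# 1.0 when step == warmup_steps. LambdaLR multiplies the optimizer's base lr by this
# factor, so config['learning_rate'] is the peak LR.
def lr_lambda(step):
    warmup = config['warmup_steps']
    if step == 0:
        return 0.0
    return min(step / warmup, (warmup / step) ** 0.5)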

scheduler = LambdaLR(optimizer, lr_lambda)

# Loss function (ignore padding tokens)
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

print("✓ Optimizer, scheduler, and loss function ready!")
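`LambdaLR` sets each step's learning rate to the optimizer's base `lr` times the value the lambda returns. The raw formula from "Attention Is All You Need", `d_model**-0.5 * min(step**-0.5, step * warmup**-1.5)`, is already an absolute learning rate (about 9.9e-4 at step 2000 for `d_model=512`), so multiplying it by a base lr of 1e-4 gives a peak near 1e-7. The pure-Python sketch below compares that raw formula with a multiplier normalized to peak at 1.0 when `step == warmup`; both function names are illustrative:

```python
def noam_lr(step, d_model, warmup):
    """Learning rate from the original Transformer paper (an absolute LR, not a multiplier)."""
    if step == 0:
        return 0.0
    return d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)


def warmup_inv_sqrt_factor(step, warmup):
    """Multiplier for LambdaLR: linear warmup to 1.0, then 1/sqrt(step) decay."""
    if step == 0:
        return 0.0
    return min(step / warmup, (warmup / step) ** 0.5)


base_lr, d_model, warmup = 1e-4, 512, 2000
for step in (1, 1000, 2000, 8000):
    print(f"step {step:>5}: base_lr * noam = {base_lr * noam_lr(step, d_model, warmup):.2e}   "
          f"base_lr * normalized = {base_lr * warmup_inv_sqrt_factor(step, warmup):.2e}")
```

The two columns differ only by the constant factor `d_model**-0.5 * warmup**-0.5`, so using the raw formula as a `LambdaLR` multiplier would need an optimizer base `lr` of 1.0 to reproduce the paper's values.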

## 1️⃣1️⃣ Training Function
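In an encoder-decoder model only the decoder's self-attention is causal; its cross-attention can read every encoder position. So whether next-token prediction is a real task depends on which positions of each sequence go to the encoder (`src`) and which are targets (`tgt_output`). The pure-Python sketch below works with positions rather than token ids and lists the target positions the encoder input already contains, for two layouts: the encoder and decoder receiving the same tokens (`src = tgt_input = x[:-1]`, `tgt_output = x[1:]`) and a prefix/continuation split. `visible_targets` is a hypothetical helper:

```python
def visible_targets(src_positions, target_positions):
    """Target positions that are also in the encoder input (cross-attention can read them)."""
    return sorted(set(src_positions) & set(target_positions))


n = 8  # sequence length

# Layout 1: src = tgt_input = x[:-1], tgt_output = x[1:]
same_seq = visible_targets(range(0, n - 1), range(1, n))

# Layout 2: encoder gets the first half, decoder predicts the second half
split = n // 2
prefix = visible_targets(range(0, split), range(split, n))

print(same_seq)  # → [1, 2, 3, 4, 5, 6]
print(prefix)    # → []
```

With the first layout most targets are already in the encoder input, so training loss can drop far below what the model achieves when generating, where the next token is never part of the encoder input.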

In [None]:
def make_src_tgt(input_ids, labels, attention_mask):
    """Split each sequence into an encoder prefix and a decoder continuation.

    The decoder's cross-attention can read every encoder position, so if the
    encoder received the same tokens the decoder has to predict (src == tgt),
    the model could copy them instead of learning next-token prediction.
    Here the encoder gets tokens [0, split) and the decoder predicts [split, end).
    Assumes right padding, with attention_mask == 1 on real tokens.
    """
    # Half of the shortest real length in the batch, so the decoder input starts on a
    # real token in every non-empty sequence (for unpadded 512-token blocks: split = 256)
    shortest = int(attention_mask.sum(dim=1).min().item())
    split = max(1, shortest // 2)
    src = input_ids[:, :split]              # encoder context
    tgt_input = input_ids[:, split - 1:-1]  # decoder input: last context token + continuation
    tgt_output = labels[:, split:]          # targets: decoder input shifted left by one
    return src, tgt_input, tgt_output


def train_epoch(model, train_loader, optimizer, scheduler, criterion, device, epoch):
    """Train for one epoch; returns (token-weighted mean loss, perplexity)."""
    model.train()
    total_loss = 0.0
    total_tokens = 0
    
    pbar = tqdm(train_loader, desc=f"Epoch {epoch}")
    for batch in pbar:
        # Get data
        input_ids = batch['input_ids'].to(device)
        labels = batch['labels'].to(device)
        
        # Encoder prefix / decoder continuation (targets shifted by 1 for next-token prediction)
        src, tgt_input, tgt_output = make_src_tgt(input_ids, labels, batch['attention_mask'])
        
        # Real (non-padding) target tokens; skip batches with none, whose mean loss would be NaN
        num_tokens = (tgt_output != tokenizer.pad_token_id).sum().item()
        if num_tokens == 0:
            continue
        
        # Create masks: padding-only for the encoder, padding + causal for the decoder
        src_mask = create_padding_mask(src, pad_token_id=tokenizer.pad_token_id)
        tgt_mask = create_combined_mask(tgt_input, pad_token_id=tokenizer.pad_token_id, causal=True)
        
        # Forward pass -> logits of shape (batch, tgt_len, vocab_size)
        output = model(src, tgt_input, src_mask, tgt_mask)
        
        # Compute loss (mean over non-padding target tokens)
        loss = criterion(
            output.reshape(-1, config['vocab_size']),
            tgt_output.reshape(-1)
        )
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), config['gradient_clip'])
        optimizer.step()
        scheduler.step()
        
        # Track metrics (weight each batch's mean loss by its token count)
        total_loss += loss.item() * num_tokens
        total_tokens += num_tokens
        
        # Update progress bar
        current_lr = scheduler.get_last_lr()[0]
        pbar.set_postfix({
            'loss': f"{loss.item():.4f}",
            'lr': f"{current_lr:.2e}"
        })
    
    avg_loss = total_loss / total_tokens
    perplexity = math.exp(avg_loss)
    return avg_loss, perplexity

## 1️⃣2️⃣ Validation Function

In [None]:
@torch.no_grad()
def validate(model, val_loader, criterion, device):
    """Evaluate on the validation set; returns (token-weighted mean loss, perplexity)."""
    model.eval()
    total_loss = 0.0
    total_tokens = 0
    
    for batch in tqdm(val_loader, desc="Validation"):
        input_ids = batch['input_ids'].to(device)
        labels = batch['labels'].to(device)
        
        # Same encoder/decoder split as in training
        src, tgt_input, tgt_output = make_src_tgt(input_ids, labels, batch['attention_mask'])
        
        num_tokens = (tgt_output != tokenizer.pad_token_id).sum().item()
        if num_tokens == 0:
            continue
        
        src_mask = create_padding_mask(src, pad_token_id=tokenizer.pad_token_id)
        tgt_mask = create_combined_mask(tgt_input, pad_token_id=tokenizer.pad_token_id, causal=True)
        
        output = model(src, tgt_input, src_mask, tgt_mask)
        
        loss = criterion(
            output.reshape(-1, config['vocab_size']),
            tgt_output.reshape(-1)
        )
        
        total_loss += loss.item() * num_tokens
        total_tokens += num_tokens
    
    avg_loss = total_loss / total_tokens
    perplexity = math.exp(avg_loss)
    return avg_loss, perplexity

## 1️⃣3️⃣ Main Training Loop 🚀

In [None]:
# Training loop
best_val_loss = float('inf')
history = {'train_loss': [], 'train_ppl': [], 'val_loss': [], 'val_ppl': []}

print("\n" + "="*60)
print("Starting Training!")
print("="*60 + "\n")

for epoch in range(1, config['num_epochs'] + 1):
    print(f"\n{'='*60}")
    print(f"Epoch {epoch}/{config['num_epochs']}")
    print(f"{'='*60}")
    
    # Train
    train_loss, train_ppl = train_epoch(
        model, train_loader, optimizer, scheduler, criterion, device, epoch
    )
    print(f"\n📊 Train Loss: {train_loss:.4f} | Perplexity: {train_ppl:.2f}")
    
    # Validate
    val_loss, val_ppl = validate(model, val_loader, criterion, device)
    print(f"📊 Val Loss: {val_loss:.4f} | Perplexity: {val_ppl:.2f}")
    
    # Save history
    history['train_loss'].append(train_loss)
    history['train_ppl'].append(train_ppl)
    history['val_loss'].append(val_loss)
    history['val_ppl'].append(val_ppl)
    
    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        checkpoint_path = f"{CHECKPOINT_DIR}/best_model_epoch{epoch}.pt"
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'train_loss': train_loss,
            'val_loss': val_loss,
            'val_ppl': val_ppl,
            'config': config,
        }, checkpoint_path)
        print(f"✅ Best model saved! (Val Loss: {val_loss:.4f})")
        print(f"   Saved to: {checkpoint_path}")

print("\n" + "="*60)
print("✅ Training Complete!")
print("="*60)
print(f"\nBest Validation Loss: {best_val_loss:.4f}")
print(f"Final Train PPL: {history['train_ppl'][-1]:.2f}")
print(f"Final Val PPL: {history['val_ppl'][-1]:.2f}")

## 1️⃣4️⃣ Text Generation Function

In [None]:
def generate_text(model, tokenizer, prompt, max_length=50, device='cuda'):
    """Generate a continuation of `prompt` with greedy decoding.

    Mirrors the training setup: the prompt is the encoder input, and the decoder
    starts from the prompt's last token and extends it one token at a time.
    """
    model.eval()
    
    # Encode prompt -> shape (1, prompt_len); it stays fixed as the encoder input
    src = tokenizer.encode(prompt, return_tensors='pt').to(device)
    src_mask = create_padding_mask(src, pad_token_id=tokenizer.pad_token_id)
    
    # Decoder input starts with the last prompt token (as in training)
    tgt = src[:, -1:]
    
    with torch.no_grad():
        for _ in range(max_length):
            tgt_mask = create_combined_mask(tgt, pad_token_id=tokenizer.pad_token_id, causal=True)
            
            # Forward pass -> logits of shape (1, tgt_len, vocab_size)
            output = model(src, tgt, src_mask, tgt_mask)
            
            # Greedy choice of the next token, kept as shape (1, 1)
            next_token = output[:, -1, :].argmax(dim=-1, keepdim=True)
            
            # Stop at end-of-text
            if next_token.item() == tokenizer.eos_token_id:
                break
            
            # Append to the decoder sequence
            tgt = torch.cat([tgt, next_token], dim=1)
    
    # Prompt followed by the generated tokens (drop the duplicated seed token)
    generated_ids = torch.cat([src, tgt[:, 1:]], dim=1)
    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)

# Test generation
print("Testing text generation...\n")

prompts = [
    "The transformer architecture",
    "In the field of artificial intelligence",
    "Machine learning is"
]

for prompt in prompts:
    generated = generate_text(model, tokenizer, prompt, max_length=30, device=device)
    print(f"Prompt: {prompt}")
    print(f"Generated: {generated}")
    print("-" * 60)
    print()
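The loss and perplexity reported for each epoch (and stored in the checkpoint as `val_ppl`) are token-weighted: `nn.CrossEntropyLoss` returns the mean over a batch's non-ignored tokens, multiplying by that batch's token count recovers the sum, and perplexity is `exp` of the overall mean. A pure-Python sketch with illustrative numbers (`epoch_perplexity` is a hypothetical helper):

```python
import math


def epoch_perplexity(batch_mean_losses, batch_token_counts):
    """Token-weighted mean cross-entropy (nats per token) and its perplexity."""
    total_loss = sum(loss * n for loss, n in zip(batch_mean_losses, batch_token_counts))
    total_tokens = sum(batch_token_counts)
    mean_loss = total_loss / total_tokens
    return mean_loss, math.exp(mean_loss)


# Two batches with different numbers of non-padding target tokens
mean_loss, ppl = epoch_perplexity([2.0, 4.0], [300, 100])
print(f"loss={mean_loss:.2f}  ppl={ppl:.2f}")  # → loss=2.50  ppl=12.18
```

A plain average of the two batch losses would give 3.0 (perplexity about 20.09), over-weighting the batch with fewer real tokens.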

## 1️⃣5️⃣ Load Best Model & Evaluate

In [None]:
# Load best checkpoint
import glob

checkpoint_files = glob.glob(f"{CHECKPOINT_DIR}/best_model_epoch*.pt")

if checkpoint_files:
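    # Newest file = last improvement of the most recent run (a file is saved only when val loss improves)
    latest_checkpoint = max(checkpoint_files, key=os.path.getmtime)
    print(f"Loading best model: {latest_checkpoint}")
    
    # map_location lets a checkpoint saved on GPU load on a CPU-only runtime too
    checkpoint = torch.load(latest_checkpoint, map_location=device)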
    model.load_state_dict(checkpoint['model_state_dict'])
    
    print(f"\n✅ Best model loaded!")
    print(f"  Epoch: {checkpoint['epoch']}")
    print(f"  Train Loss: {checkpoint['train_loss']:.4f}")
    print(f"  Val Loss: {checkpoint['val_loss']:.4f}")
    print(f"  Val Perplexity: {checkpoint['val_ppl']:.2f}")
    
    # Test generation with best model
    print("\n" + "="*60)
    print("Testing generation with best model:")
    print("="*60 + "\n")
    
    prompt = "The attention mechanism allows"
    generated = generate_text(model, tokenizer, prompt, max_length=40, device=device)
    print(f"Prompt: {prompt}")
    print(f"Generated: {generated}")
    
else:
    print("❌ No checkpoint found!")

print("\n" + "="*60)
print("✅ All Done!")
print("="*60)
print(f"\nCheckpoints saved at: {CHECKPOINT_DIR}")
print("You can find them in your Google Drive!")
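The load cell picks the newest `best_model_epoch*.pt` by file timestamp. Within one run the newest file is the best one, because a checkpoint is only written when validation loss improves, but files left over from earlier runs in the same Drive folder can confuse any timestamp- or filename-based choice. Each checkpoint also stores `val_loss`, so a more robust rule is to compare that value. In the sketch below the loaded checkpoints are stood in by dicts with toy values, and `pick_best` is a hypothetical helper; in the notebook, `torch.load(path, map_location=device)['val_loss']` would supply the real numbers:

```python
def pick_best(checkpoints):
    """Return the path whose stored val_loss is lowest.

    checkpoints: mapping of path -> checkpoint dict containing at least 'val_loss'.
    """
    if not checkpoints:
        raise ValueError("no checkpoints to choose from")
    return min(checkpoints, key=lambda path: checkpoints[path]['val_loss'])


candidates = {
    'best_model_epoch1.pt': {'epoch': 1, 'val_loss': 6.1},
    'best_model_epoch2.pt': {'epoch': 2, 'val_loss': 5.4},
    'best_model_epoch3.pt': {'epoch': 3, 'val_loss': 5.2},
}
print(pick_best(candidates))  # → best_model_epoch3.pt
```

Reading `val_loss` this way loads every checkpoint file in full, which is slow for checkpoints of this size; saving the metrics in a small side file next to each checkpoint is one lighter alternative.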