# Multi-Head Attention (MHA) Transformer Training
## FYP: Comparison of Transformer Attention Mechanisms

This notebook trains a full encoder-decoder transformer with Multi-Head Attention on the WikiText-2 dataset.

**Author:** Your Name  
**Dataset:** WikiText-2  
**Architecture:** Encoder-Decoder Transformer with MHA

## 1. Setup Environment

In [None]:
# Report the PyTorch build and, when a GPU is visible, its CUDA details
import torch

cuda_ready = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_ready}")

if cuda_ready:
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    # total_memory is divided by 1e9 to show decimal gigabytes
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU Memory: {gpu_props.total_memory / 1e9:.2f} GB")

In [None]:
# Mount Google Drive (for saving checkpoints)
# Colab-only: the `google.colab` package is provided by the Colab runtime.
# The configuration cell later points the checkpoint and log directories at
# paths under /content/drive/MyDrive, so this mount must succeed first.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Clone your repository
# Replace with your actual repo URL
!git clone https://github.com/YOUR_USERNAME/LLM-Journey.git
%cd LLM-Journey

In [None]:
# Install dependencies
!pip install -q transformers datasets tensorboard matplotlib seaborn tqdm

## 2. Verify Data

In [None]:
# Check if data exists
import os
data_path = "data_processed/wikitext2_processed"
if os.path.exists(data_path):
    print(f"✓ Data found at {data_path}")
    !ls -lh {data_path}
else:
    print(f"✗ Data not found at {data_path}")
    print("Please upload your processed data or run preprocessing script")

## 3. Configuration

In [None]:
# Read the MHA experiment configuration from the repo and display it
import json

with open('mha/config.json', 'r') as config_file:
    config = json.loads(config_file.read())

# Pretty-print the full configuration under a heading
print("Current Configuration:", json.dumps(config, indent=2), sep="\n")

In [None]:
# Optional Colab overrides.
# Uncomment either line to shrink the run (e.g. when the GPU runs out of memory):
# config['training_config']['batch_size'] = 16  # Reduce if OOM
# config['training_config']['num_epochs'] = 5   # Adjust as needed

# Point checkpoints and logs at the mounted Google Drive
logging_cfg = config['logging_config']
logging_cfg['checkpoint_dir'] = '/content/drive/MyDrive/LLM-Journey/checkpoints/mha'
logging_cfg['log_dir'] = '/content/drive/MyDrive/LLM-Journey/logs/mha'

# Make sure both output directories exist (no error if they already do)
for output_dir in (logging_cfg['checkpoint_dir'], logging_cfg['log_dir']):
    os.makedirs(output_dir, exist_ok=True)

# Write the edited configuration back over the repo copy
with open('mha/config.json', 'w') as config_file:
    json.dump(config, config_file, indent=2)

print("✓ Configuration updated for Colab")

## 4. Load Model and Data

In [None]:
# Import training modules
import sys
# Put the repo's mha/ directory first on the import path so the bare module
# names below (transformer, data_loader, utils) resolve to the project files.
# The path is relative, so the working directory must be the repository root.
sys.path.insert(0, 'mha')

from transformer import Transformer
from data_loader import WikiTextDataModule
from utils import set_seed, count_parameters

print("✓ Modules imported successfully")

In [None]:
# Fix the seed before the model is built (set_seed comes from the project's utils)
set_seed(config['random_seed'])

# Prefer the GPU when one is visible, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Build the encoder-decoder model from the configuration sections
model_cfg = config['model_config']
model = Transformer(
    vocab_size=model_cfg['vocab_size'],
    d_model=model_cfg['d_model'],
    num_heads=model_cfg['num_heads'],
    num_encoder_layers=model_cfg['num_encoder_layers'],
    num_decoder_layers=model_cfg['num_decoder_layers'],
    d_ff=model_cfg['d_ff'],
    max_seq_length=model_cfg['max_seq_length'],
    dropout=model_cfg['dropout'],
    pe_type=config['positional_encoding']['type'],
)
model = model.to(device)

# Report model size
total_params, trainable_params = count_parameters(model)
print("\nModel Parameters:")
print(f"  Total: {total_params:,}")
print(f"  Trainable: {trainable_params:,}")

In [None]:
# Assemble the settings the data module needs from the three config sections
data_config = dict(
    train_path=config['data_config']['train_path'],
    val_path=config['data_config']['val_path'],
    batch_size=config['training_config']['batch_size'],
    max_seq_length=config['model_config']['max_seq_length'],
    tokenizer=config['data_config']['tokenizer'],
)

# Build the data module and run its setup step
data_module = WikiTextDataModule(data_config)
data_module.setup()

# Summarize the dataset sizes
print("\n✓ Data loaded successfully")
print(f"  Train samples: {len(data_module.train_dataset)}")
print(f"  Val samples: {len(data_module.val_dataset)}")

## 5. Training

In [None]:
# Load TensorBoard (optional)
# Opens the inline TensorBoard viewer on the configured log directory.
# NOTE(review): assumes the Trainer writes its TensorBoard event files to
# config['logging_config']['log_dir'] — confirm against train.py.
%load_ext tensorboard
%tensorboard --logdir {config['logging_config']['log_dir']}

In [None]:
# Run training for the configured number of epochs
from train import Trainer

# NOTE(review): the `model` built earlier is not passed in; the Trainer only
# receives the config and device, so it presumably constructs its own model
# — confirm against train.py.
trainer = Trainer(config, device=device)
epochs_to_run = config['training_config']['num_epochs']
trainer.train(num_epochs=epochs_to_run)

## 6. Evaluation and Visualization

In [None]:
# Restore the best checkpoint into the notebook's model, if one was saved
best_checkpoint = os.path.join(config['logging_config']['checkpoint_dir'], 'best_model.pt')

if not os.path.exists(best_checkpoint):
    # Nothing is loaded here; `model` keeps whatever weights it already had
    print("Best checkpoint not found")
else:
    checkpoint = torch.load(best_checkpoint, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print("✓ Best model loaded")
    # Validation metrics stored alongside the weights
    best_metrics = checkpoint['metrics']
    print(f"  Val Loss: {best_metrics['val_loss']:.4f}")
    print(f"  Val Perplexity: {best_metrics['val_ppl']:.2f}")

In [None]:
# Placeholder for attention visualization: runs one masked forward pass on a
# validation batch. The forward pass does not return attention weights yet.
from utils import AttentionVisualizer
from attention import create_combined_mask, create_padding_mask

# NOTE(review): pad id is hard-coded to 0 — confirm it matches the tokenizer
PAD_TOKEN_ID = 0

model.eval()

# Take the first batch from the validation loader
val_loader = data_module.val_dataloader()
batch = next(iter(val_loader))
input_ids = batch['input_ids'].to(device)

# Encoder and decoder inputs are the same slice (everything but the last token).
# NOTE(review): confirm this matches how the Trainer builds src/tgt.
src, tgt = input_ids[:, :-1], input_ids[:, :-1]

# Padding mask for the source, padding + causal mask for the target
src_mask = create_padding_mask(src, pad_token_id=PAD_TOKEN_ID)
tgt_mask = create_combined_mask(tgt, pad_token_id=PAD_TOKEN_ID, causal=True)

# Gradient-free forward pass; `output` is not used further in this cell
with torch.no_grad():
    output = model(src, tgt, src_mask, tgt_mask)

print(
    "\nAttention visualization would go here",
    "(Requires modification to return attention weights from forward pass)",
    sep="\n",
)

In [None]:
# Text generation (simple greedy decoding)
def generate_text(model, tokenizer, prompt, max_length=50, device='cuda'):
    """Greedily extend ``prompt`` one token at a time and return the decoded text.

    On every step the whole running sequence is passed to the model as both
    the source and the target, and the highest-scoring token at the last
    position is appended. Generation ends after ``max_length`` new tokens or
    right after the tokenizer's EOS token has been appended.

    Args:
        model: Callable as ``model(src, tgt)``; its result is indexed as
            ``[:, -1, :]`` and arg-maxed over the last dimension.
        tokenizer: Must provide ``encode(prompt, return_tensors='pt')``,
            ``decode(ids, skip_special_tokens=True)`` and ``eos_token_id``.
        prompt: Text to continue.
        max_length: Maximum number of NEW tokens to generate (not the total length).
        device: Device the encoded prompt is moved to.

    Returns:
        The decoded first sequence (prompt plus generated tokens), decoded
        with ``skip_special_tokens=True``.
    """
    model.eval()

    # Running sequence: starts as the encoded prompt and grows by one column per step
    sequence = tokenizer.encode(prompt, return_tensors='pt').to(device)

    with torch.no_grad():
        for _ in range(max_length):
            # NOTE(review): no masks are passed here, unlike the masked forward
            # pass used elsewhere in this notebook — confirm the model's forward
            # applies a causal mask by default.
            scores = model(sequence, sequence)

            # Greedy choice from the scores at the final position
            picked = torch.argmax(scores[:, -1, :], dim=-1, keepdim=True)
            sequence = torch.cat((sequence, picked), dim=1)

            # .item() requires exactly one element, i.e. a single sequence
            if picked.item() == tokenizer.eos_token_id:
                break

    return tokenizer.decode(sequence[0], skip_special_tokens=True)

# Smoke-test generation with a short prompt and the data module's tokenizer
prompt = "The transformer architecture"
generated = generate_text(
    model,
    data_module.tokenizer,
    prompt,
    max_length=30,
    device=device,
)
print(f"Prompt: {prompt}", f"Generated: {generated}", sep="\n")

## 7. Save Results

In [None]:
# Nothing is written here; this cell only reports where the run's outputs live
output_dirs = config['logging_config']
print("Training complete!")
print(f"Checkpoints saved to: {output_dirs['checkpoint_dir']}")
print(f"Logs saved to: {output_dirs['log_dir']}")
print("\nYou can access these files in your Google Drive.")