# Train Transformer Model (Memory-Efficient for 8GB GPU)

This notebook trains a **memory-optimized** transformer model.

**Optimizations for 8GB GPU**:
- Smaller d_model (256 vs 512)
- Fewer attention heads (4 vs 8)
- Fewer encoder/decoder layers (2 vs 3)
- Smaller feed-forward dimension (512 vs 2048)
- Smaller batch size (8 vs 32)

**Model**: CNN + Transformer encoder-decoder with positional encoding

**Estimated time**: 4-6 hours (up to 40 epochs on GPU; early stopping may end training sooner)

## Step 1: Setup and Imports

In [None]:
# Environment setup. These variables are set BEFORE `import torch` below so
# they are already in place when torch (and its MKL / CUDA allocator) load.
import os
os.environ['MKL_THREADING_LAYER'] = 'GNU'  # Work around MKL threading-layer conflicts (e.g. with libgomp)
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'  # Let the CUDA caching allocator grow segments, reducing fragmentation

import sys
from pathlib import Path

# Make the project package `src` importable. Path('..') is resolved against the
# current working directory, so this assumes the notebook is launched from a
# direct subdirectory of the project root — NOTE(review): confirm.
project_root = Path('..').absolute()
sys.path.insert(0, str(project_root))

import torch
# Best-effort cleanup of cached GPU memory and unreachable Python objects;
# mainly useful when re-running this cell in a live kernel.
torch.cuda.empty_cache()
import gc
gc.collect()

# Project-local helpers (resolved via the sys.path entry added above).
import json
from src.models import create_model
from src.dataset import create_dataloaders
from src.trainer import ModelTrainer
from src.utils import load_vocab, set_seed, get_device, count_parameters, make_json_serializable


print("✓ Imports successful!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

## Step 2: Set Random Seed and Check GPU

In [None]:
# Seed the RNGs through the project helper, then pick the compute device.
set_seed(42)
device = get_device()

# Report the hardware we are about to train on.
if not torch.cuda.is_available():
    print("\n⚠️ No GPU available, using CPU (will be slower)")
else:
    bytes_per_gb = 1024**3
    total_memory = torch.cuda.get_device_properties(0).total_memory / bytes_per_gb
    free_memory = torch.cuda.mem_get_info()[0] / bytes_per_gb
    gpu_name = torch.cuda.get_device_name(0)
    print(f"\nGPU: {gpu_name}")
    print(f"Total GPU memory: {total_memory:.2f} GB")
    print(f"Free GPU memory: {free_memory:.2f} GB")

## Step 3: Load Vocabulary

In [None]:
print("Loading vocabulary...")
vocab = load_vocab('../vocab.json')
print(f"\n✓ Vocabulary size: {len(vocab)}")

# Show the ids assigned to the special tokens.
for special_token in ('<pad>', '<sos>', '<eos>', '<unk>'):
    print(f"  {special_token}: {vocab[special_token]}")

## Step 4: Create Dataloaders with Small Batch Size

In [None]:
BATCH_SIZE = 8   # Smallest batch size - transformers are memory hungry
NUM_WORKERS = 2

# Caption files and feature directories for each split.
split_paths = {
    'train_captions': '../data/train_captions.json',
    'val_captions': '../data/val_captions.json',
    'eval_captions': '../data/eval_captions.json',
    'train_features_dir': '../features/mel/',
    'val_features_dir': '../features/mel/',
    'eval_features_dir': '../features/mel_eval/',
}

print("Creating dataloaders...")
print(f"Using batch_size={BATCH_SIZE} (transformers use most memory!)\n")

train_loader, val_loader, eval_dataset = create_dataloaders(
    **split_paths,
    vocab=vocab,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
)

print("\n✓ Dataloaders created:")
print(f"  Train batches: {len(train_loader)}")
print(f"  Val batches: {len(val_loader)}")
print(f"  Eval samples: {len(eval_dataset)}")

## Step 5: Create Memory-Efficient Transformer Model

In [None]:
# Reduced-size transformer hyperparameters, defined ONCE so the printed summary
# and the actual create_model() call cannot drift apart.
# Each entry: (printed label, create_model keyword, value, full-size value).
SMALL_TRANSFORMER_CONFIG = [
    ("d_model", "d_model", 256, 512),
    ("num_heads", "nhead", 4, 8),
    ("num_encoder_layers", "num_encoder_layers", 2, 3),
    ("num_decoder_layers", "num_decoder_layers", 2, 3),
    ("dim_feedforward", "dim_feedforward", 512, 2048),
]

print("Creating MEMORY-EFFICIENT transformer model...\n")
print("Model configuration (optimized for 8GB GPU):")
for label, _kwarg, value, full_size in SMALL_TRANSFORMER_CONFIG:
    print(f"  - {label}: {value} (reduced from {full_size})")
print("\nThis significantly reduces memory usage while maintaining transformer architecture\n")

model = create_model(
    'transformer',
    vocab_size=len(vocab),
    **{kwarg: value for _label, kwarg, value, _full in SMALL_TRANSFORMER_CONFIG},
)

print("\nModel architecture:")
print(model)

print("\nParameter count:")
# count_parameters returns a (total, trainable) pair, per this unpacking.
total_params, trainable_params = count_parameters(model)

# Rough estimate covering float32 weights only (4 bytes per parameter);
# gradients, optimizer state and activations are not included.
model_memory = total_params * 4 / (1024**3)
print(f"\nEstimated model memory: {model_memory:.2f} GB")

# Move to GPU and check actual usage
print("\nMoving model to GPU...")
model = model.to(device)

if torch.cuda.is_available():
    allocated = torch.cuda.memory_allocated() / 1024**3
    reserved = torch.cuda.memory_reserved() / 1024**3
    free = torch.cuda.mem_get_info()[0] / 1024**3

    print(f"\nGPU Memory Status After Loading Model:")
    print(f"  Allocated: {allocated:.2f} GB")
    print(f"  Reserved: {reserved:.2f} GB")
    print(f"  Free: {free:.2f} GB")

    # Heuristic headroom thresholds for the training activations still to come.
    if free > 3:
        print(f"\n✓ Model fits comfortably in GPU memory! ({free:.2f} GB free)")
    elif free > 1:
        print(f"\n✓ Model fits in GPU memory! ({free:.2f} GB free)")
    else:
        print(f"\n⚠️ Low memory warning: only {free:.2f} GB free")

## Step 6: Create Trainer

In [None]:
print("Initializing trainer...")

# Distinct run name so this run is distinguishable from the full-size model
# (presumably used in checkpoint/result filenames — confirm in ModelTrainer).
RUN_NAME = 'transformer_small'
trainer = ModelTrainer(model=model, vocab=vocab, device=device, model_name=RUN_NAME)

print("✓ Trainer ready!")

## Step 7: Train Model

**This will take 4-6 hours on GPU**

Progress will be shown with:
- Training loss per batch (progress bar)
- Validation loss per epoch
- Sample generations every 5 epochs
- Learning rate changes
- Early stopping if there is no improvement (patience=5)

In [None]:
# Build the whole start-of-training banner and emit it in one go.
divider = "=" * 80
banner = [
    "Starting training...\n",
    divider,
    "TRANSFORMER MODEL TRAINING (Memory-Efficient Version)",
    divider,
    "\nTraining smaller transformer model optimized for 8GB GPU",
    "Expected training time: 4-6 hours",
    "\nMonitor GPU memory during training:",
    "  watch -n 1 nvidia-smi",
    "\n" + divider + "\n",
]
print("\n".join(banner))

# Optimization / early-stopping settings for this run.
fit_settings = dict(
    num_epochs=40,
    learning_rate=1e-4,
    weight_decay=1e-4,
    patience=5,
    label_smoothing=0.1,
    save_dir='../checkpoints',
)

history = trainer.fit(
    train_loader=train_loader,
    val_loader=val_loader,
    eval_dataset=eval_dataset,
    **fit_settings,
)

print("\n✓ Training complete!")

## Step 8: Plot Training History

In [None]:
from src.utils import plot_training_history
import matplotlib.pyplot as plt

# Persist the history BEFORE plotting so a plotting/display failure cannot
# prevent saving the record of a multi-hour run.
history_path = Path('../results/transformer_small_history.json')
# open(..., 'w') raises FileNotFoundError if the directory does not exist yet.
history_path.parent.mkdir(parents=True, exist_ok=True)
with open(history_path, 'w') as f:
    # Same conversion the evaluation results go through, in case the history
    # holds values json cannot encode directly (e.g. numpy/torch scalars).
    json.dump(make_json_serializable(history), f, indent=2)
print("\n✓ History saved to results/transformer_small_history.json")

print("Plotting training history...")
plot_training_history(history)
plt.show()

## Step 9: Final Evaluation

In [None]:
from src.evaluation import evaluate_model

print("Running final evaluation...\n")

# num_samples=100 limits how many eval items are scored
# (presumably a subset of eval_dataset — confirm in evaluate_model).
results, captions, refs = evaluate_model(
    trainer.model,
    eval_dataset,
    vocab,
    device=device,
    num_samples=100
)

# Save results; the directory is created first so this cell also works when
# run without the history-saving cell above.
results_path = Path('../results/transformer_small_results.json')
results_path.parent.mkdir(parents=True, exist_ok=True)
serializable_results = make_json_serializable(results)
with open(results_path, 'w') as f:
    json.dump(serializable_results, f, indent=2)

print("\n✓ Results saved to results/transformer_small_results.json")

## Step 10: Show Sample Predictions

In [None]:
from src.evaluation import get_sample_predictions, print_sample_predictions

# Generate and display the same number of examples.
NUM_SAMPLE_PREDICTIONS = 10

print("Generating sample predictions...\n")

samples = get_sample_predictions(
    trainer.model, eval_dataset, vocab,
    device=device, num_samples=NUM_SAMPLE_PREDICTIONS,
)
print_sample_predictions(samples, num_to_print=NUM_SAMPLE_PREDICTIONS)

## Summary

In [None]:
print("\n" + "=" * 80)
print("TRAINING SUMMARY")
print("=" * 80)

# NOTE: these values mirror the create_model() call in Step 5; keep them in sync.
print("\nModel: Transformer (Memory-Efficient)")
print("  - d_model: 256")
print("  - num_heads: 4")
print("  - encoder_layers: 2")
print("  - decoder_layers: 2")
print(f"  - Parameters: {total_params:,}")

# min() raises on an empty sequence, so guard against a history without
# any recorded validation epochs.
val_losses = history['val_loss']
if val_losses:
    print(f"\nBest validation loss: {min(val_losses):.4f}")
else:
    print("\nBest validation loss: n/a (no validation epochs recorded)")

print("\nEvaluation metrics:")
print(f"  - Repetition rate: {results['avg_repetition_rate']:.4f}")
print(f"  - Vocabulary diversity: {results['vocabulary_diversity']:.4f}")
print(f"  - Mean caption length: {results['mean_caption_length']:.2f} words")

# Check that each artifact actually exists instead of unconditionally claiming
# success. The checkpoint name assumes ModelTrainer writes
# 'best_<model_name>.pth' into save_dir — NOTE(review): confirm.
print("\nFiles saved:")
for artifact in (
    "../checkpoints/best_transformer_small.pth",
    "../results/transformer_small_history.json",
    "../results/transformer_small_results.json",
):
    status = "✓" if Path(artifact).exists() else "✗ (missing)"
    print(f"  {status} {artifact}")

print("\n" + "=" * 80)
print("Next: Compare all models with 05_evaluate_all.ipynb")
print("=" * 80)