# Train Improved Baseline Model (Memory-Efficient for 8GB GPU)

This notebook trains the **improved baseline model** with optimizations for limited GPU memory.

## Improved Baseline Features:
- **Deeper CNN encoder** (4 conv layers vs 3)
- **Bidirectional LSTM** for temporal modeling
- **Audio context concatenation** with word embeddings
- **Dropout** for regularization
- **Average pooling** instead of max pooling (preserves temporal info)
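The pooling point can be seen in a toy, stdlib-only illustration. This is not the code in `src.models`. With non-overlapping windows, max pooling keeps only each window's peak, while average pooling reflects every frame in the window:

```python
def max_pool_1d(x, k):
    """Non-overlapping max pooling with window size k (trailing remainder dropped)."""
    return [max(x[i:i + k]) for i in range(0, len(x) - k + 1, k)]


def avg_pool_1d(x, k):
    """Non-overlapping average pooling with window size k (trailing remainder dropped)."""
    return [sum(x[i:i + k]) / k for i in range(0, len(x) - k + 1, k)]


# Two windows with the same peak but very different activity over time
frames = [0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0]
print(max_pool_1d(frames, 4))  # [1.0, 1.0]  -> both windows look identical
print(avg_pool_1d(frames, 4))  # [0.25, 1.0] -> the difference survives
```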

## Memory Optimizations for 8GB GPU:
- Smaller batch size (16 vs 32)
- Reduced embedding dimension (256)
- Reduced hidden dimension (512)
- 2 LSTM layers

**Estimated time**: 3-5 hours on GPU (up to 35 epochs; early stopping may end training sooner)

## Step 1: Setup and Imports

In [None]:
import os
os.environ['MKL_THREADING_LAYER'] = 'GNU'  # Avoid MKL/libgomp threading-layer conflicts
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'  # Reduce CUDA memory fragmentation (must be set before CUDA initializes)

import sys
from pathlib import Path

# Add parent directory to path
project_root = Path('..').absolute()
sys.path.insert(0, str(project_root))

import torch
torch.cuda.empty_cache()
import gc
gc.collect()

import json
from src.models import create_model
from src.dataset import create_dataloaders
from src.trainer import ModelTrainer
from src.utils import load_vocab, set_seed, get_device, count_parameters, make_json_serializable

print("✓ Imports successful!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

## Step 2: Set Random Seed and Check GPU

In [None]:
set_seed(42)
device = get_device()

# Check GPU memory
if torch.cuda.is_available():
    total_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3
    free_memory = torch.cuda.mem_get_info()[0] / 1024**3
    print(f"\nGPU: {torch.cuda.get_device_name(0)}")
    print(f"Total GPU memory: {total_memory:.2f} GB")
    print(f"Free GPU memory: {free_memory:.2f} GB")
else:
    print("\n⚠️ No GPU available, using CPU (will be slower)")

## Step 3: Load Vocabulary

In [None]:
print("Loading vocabulary...")
vocab = load_vocab('../vocab.json')
print(f"\n✓ Vocabulary size: {len(vocab)}")
print(f"  <pad>: {vocab['<pad>']}")
print(f"  <sos>: {vocab['<sos>']}")
print(f"  <eos>: {vocab['<eos>']}")
print(f"  <unk>: {vocab['<unk>']}")

## Step 4: Create Dataloaders with Optimized Batch Size

In [None]:
print("Creating dataloaders...")
print("Using batch_size=16 (optimized for 8GB GPU)\n")

train_loader, val_loader, eval_dataset = create_dataloaders(
    train_captions='../data/train_captions.json',
    val_captions='../data/val_captions.json',
    eval_captions='../data/eval_captions.json',
    train_features_dir='../features/mel/',
    val_features_dir='../features/mel/',
    eval_features_dir='../features/mel_eval/',
    vocab=vocab,
    batch_size=16,      # Optimized for improved baseline
    num_workers=2
)

print(f"\n✓ Dataloaders created:")
print(f"  Train batches: {len(train_loader)}")
print(f"  Val batches: {len(val_loader)}")
print(f"  Eval samples: {len(eval_dataset)}")

## Step 5: Create Improved Baseline Model

The improved baseline adds:
- Deeper CNN (4 layers)
- Bidirectional LSTM for encoding
- Audio context concatenated with embeddings
- Dropout for regularization
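The real `create_model('improved_baseline', ...)` lives in `src/models` and is not shown here. The hypothetical PyTorch sketch below gives a rough picture of how the listed pieces could fit together. Several details are assumptions, not the project's actual choices:

- the channel widths and `n_mels=64`
- the 2x2 average pooling
- using the mean over time as the audio context
- giving each encoder direction `hidden_dim // 2` units

```python
import torch
import torch.nn as nn


class ImprovedBaselineSketch(nn.Module):
    """Hypothetical sketch: 4-layer CNN + BiLSTM encoder, LSTM decoder fed [word embedding; audio context]."""

    def __init__(self, vocab_size, n_mels=64, embed_dim=256, hidden_dim=512,
                 num_layers=2, dropout=0.3):
        super().__init__()
        # Deeper CNN: four conv blocks, each followed by 2x2 average pooling
        channels = [1, 32, 64, 128, 128]
        blocks = []
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            blocks += [nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                       nn.BatchNorm2d(c_out), nn.ReLU(), nn.AvgPool2d(2)]
        self.cnn = nn.Sequential(*blocks)
        cnn_out = channels[-1] * (n_mels // 16)  # four 2x poolings shrink frequency by 16

        # Bidirectional encoder LSTM; hidden_dim // 2 per direction gives hidden_dim outputs
        self.encoder = nn.LSTM(cnn_out, hidden_dim // 2, batch_first=True, bidirectional=True)

        # Decoder input = word embedding concatenated with the audio context vector
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.decoder = nn.LSTM(embed_dim + hidden_dim, hidden_dim, num_layers=num_layers,
                               batch_first=True, dropout=dropout)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, mel, captions):
        # mel: (batch, 1, time, n_mels); captions: (batch, seq_len) token ids (teacher forcing)
        x = self.cnn(mel)                                # (B, C, T', F')
        b, c, t, f = x.shape
        x = x.permute(0, 2, 1, 3).reshape(b, t, c * f)   # one feature vector per time step
        enc_out, _ = self.encoder(x)                     # (B, T', hidden_dim)
        context = enc_out.mean(dim=1)                    # average over time -> (B, hidden_dim)

        emb = self.dropout(self.embedding(captions))               # (B, L, embed_dim)
        ctx = context.unsqueeze(1).expand(-1, emb.size(1), -1)     # same context at every step
        dec_out, _ = self.decoder(torch.cat([emb, ctx], dim=-1))   # (B, L, hidden_dim)
        return self.fc(self.dropout(dec_out))                      # (B, L, vocab_size)


model = ImprovedBaselineSketch(vocab_size=100)
logits = model(torch.randn(2, 1, 160, 64), torch.randint(0, 100, (2, 12)))
print(logits.shape)  # torch.Size([2, 12, 100])
```

Concatenating the context at every step, rather than only using it to initialize the decoder state, keeps the audio encoding directly visible to the decoder throughout generation.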

In [None]:
print("Creating IMPROVED BASELINE model...\n")
print("Model configuration:")
print("  - Embedding dimension: 256")
print("  - Hidden dimension: 512")
print("  - LSTM layers: 2")
print("  - CNN layers: 4 (deeper than baseline)")
print("  - Bidirectional encoder LSTM")
print("  - Audio context concatenation")
print("  - Dropout: 0.3\n")

model = create_model(
    'improved_baseline', 
    vocab_size=len(vocab),
    embed_dim=256,
    hidden_dim=512,
    num_layers=2
)

print("\nModel architecture:")
print(model)

print("\nParameter count:")
total_params, trainable_params = count_parameters(model)

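# Rough estimate: float32 weights only (4 bytes per parameter).
# Gradients, optimizer state, and activations need additional memory during training.
model_memory = total_params * 4 / (1024**3)
print(f"\nEstimated model memory (weights only): {model_memory:.2f} GB")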

# Move to GPU and check actual usage
print("\nMoving model to GPU...")
model = model.to(device)

if torch.cuda.is_available():
    allocated = torch.cuda.memory_allocated() / 1024**3
    reserved = torch.cuda.memory_reserved() / 1024**3
    free = torch.cuda.mem_get_info()[0] / 1024**3
    
    print(f"\nGPU Memory Status After Loading Model:")
    print(f"  Allocated: {allocated:.2f} GB")
    print(f"  Reserved: {reserved:.2f} GB")
    print(f"  Free: {free:.2f} GB")
    
    if free > 3:
        print(f"\n✓ Model fits comfortably in GPU memory! ({free:.2f} GB free)")
    elif free > 1:
        print(f"\n✓ Model fits in GPU memory! ({free:.2f} GB free)")
    else:
        print(f"\n⚠️ Low memory warning: only {free:.2f} GB free")

## Step 6: Create Trainer

In [None]:
print("Initializing trainer...")

trainer = ModelTrainer(
    model=model,
    vocab=vocab,
    device=device,
    model_name='improved_baseline'
)

print("✓ Trainer ready!")

## Step 7: Train Model

**This will take 3-5 hours on GPU**

Progress will be shown with:
- Training loss per batch (progress bar)
- Validation loss per epoch
- Sample generations every 5 epochs
- Learning rate changes
- Early stopping if no improvement
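`ModelTrainer.fit` is not shown in this notebook, so the exact stopping rule it applies is an assumption. A common patience-based rule consistent with `patience=5` looks like this. The class name and the `min_delta` parameter are hypothetical.

```python
class EarlyStopping:
    """Stop when validation loss has not improved for `patience` consecutive epochs."""

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta  # minimum decrease that counts as an improvement
        self.best = float('inf')
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_epochs = 0  # improvement: reset the counter (a trainer would save a checkpoint here)
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


stopper = EarlyStopping(patience=5)
losses = [3.0, 2.5, 2.4, 2.45, 2.41, 2.42, 2.43, 2.44]
stopped_at = None
for epoch, loss in enumerate(losses, start=1):
    if stopper.step(loss):
        stopped_at = epoch
        break
print(stopped_at, stopper.best)  # 8 2.4
```

Because the best checkpoint is kept, stopping five epochs after the last improvement costs some compute but not model quality.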

In [None]:
print("Starting training...\n")
print("="*80)
print("IMPROVED BASELINE MODEL TRAINING")
print("="*80)
print("\nTraining improved baseline with:")
print("  - Deeper CNN encoder")
print("  - Bidirectional LSTM")
print("  - Audio context concatenation")
print("  - Dropout regularization")
print("\nExpected training time: 3-5 hours")
print("\nMonitor GPU memory during training:")
print("  watch -n 1 nvidia-smi")
print("\n" + "="*80 + "\n")

history = trainer.fit(
    train_loader=train_loader,
    val_loader=val_loader,
    eval_dataset=eval_dataset,
    num_epochs=35,           # More epochs for improved model
    learning_rate=1e-3,      # Standard learning rate
    weight_decay=1e-5,       # L2 regularization
    patience=5,              # Early stopping patience
    label_smoothing=0.0,     # No label smoothing for baseline
    save_dir='../checkpoints'
)

print("\n✓ Training complete!")

## Step 8: Plot Training History

In [None]:
from src.utils import plot_training_history
import matplotlib.pyplot as plt

print("Plotting training history...")
plot_training_history(history)
plt.show()

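# Save history (create the results directory if needed)
os.makedirs('../results', exist_ok=True)
with open('../results/improved_baseline_history.json', 'w') as f:
    json.dump(make_json_serializable(history), f, indent=2)
print("\n✓ History saved to results/improved_baseline_history.json")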

## Step 9: Final Evaluation

In [None]:
from src.evaluation import evaluate_model

print("Running final evaluation...\n")

results, captions, refs = evaluate_model(
    trainer.model,
    eval_dataset,
    vocab,
    device=device,
    num_samples=100
)

# Save results
serializable_results = make_json_serializable(results)
with open('../results/improved_baseline_results.json', 'w') as f:
    json.dump(serializable_results, f, indent=2)

print("\n✓ Results saved to results/improved_baseline_results.json")

## Step 10: Show Sample Predictions

In [None]:
from src.evaluation import get_sample_predictions, print_sample_predictions

print("Generating sample predictions...\n")

samples = get_sample_predictions(
    trainer.model,
    eval_dataset,
    vocab,
    device=device,
    num_samples=10
)

print_sample_predictions(samples, num_to_print=10)

## Step 11: Compare with Baseline (Optional)

If you trained the baseline model, let's compare the results.
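The percentages printed below follow one convention: positive means better, whichever direction the metric improves in. The small helper below makes that explicit. It also guards against a zero baseline, which would otherwise raise `ZeroDivisionError`. The helper name and the metric values are hypothetical.

```python
def relative_improvement(baseline, improved, higher_is_better):
    """Percent improvement of `improved` over `baseline`; a positive result always means better."""
    if baseline == 0:
        return float('nan')  # undefined when the baseline metric is zero
    delta = improved - baseline if higher_is_better else baseline - improved
    return delta / abs(baseline) * 100


# Hypothetical numbers: repetition drops 0.20 -> 0.16, diversity rises 0.30 -> 0.36
print(f"{relative_improvement(0.20, 0.16, higher_is_better=False):+.1f}%")  # +20.0%
print(f"{relative_improvement(0.30, 0.36, higher_is_better=True):+.1f}%")   # +20.0%
```

For non-negative metrics such as these, this gives the same numbers as the formulas in the cell below.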

In [None]:
import os.path as osp
import pandas as pd

# Try to load baseline results for comparison
baseline_path = '../results/baseline_results.json'

if osp.exists(baseline_path):
    print("Comparing with baseline model...\n")
    
    with open(baseline_path, 'r') as f:
        baseline_results = json.load(f)
    
    # Create comparison table
    comparison = pd.DataFrame({
        'Baseline': baseline_results,
        'Improved Baseline': serializable_results
    }).T
    
    print("="*80)
    print("MODEL COMPARISON")
    print("="*80)
    print(comparison.to_string())
    print("="*80)
    
    # Calculate improvements
    print("\nIMPROVEMENTS:")
    print("-"*80)
    
    # Repetition (lower is better)
    rep_improvement = ((baseline_results['avg_repetition_rate'] - serializable_results['avg_repetition_rate']) 
                      / baseline_results['avg_repetition_rate'] * 100)
    print(f"Repetition rate: {rep_improvement:+.1f}% (lower is better)")
    
    # Diversity (higher is better)
    div_improvement = ((serializable_results['vocabulary_diversity'] - baseline_results['vocabulary_diversity']) 
                      / baseline_results['vocabulary_diversity'] * 100)
    print(f"Vocabulary diversity: {div_improvement:+.1f}% (higher is better)")
    
    print("="*80)
else:
    print("Baseline results not found. Skipping comparison.")

## Summary

In [None]:
print("\n" + "="*80)
print("TRAINING SUMMARY")
print("="*80)

print(f"\nModel: Improved Baseline")
print(f"  - Embedding dim: 256")
print(f"  - Hidden dim: 512")
print(f"  - LSTM layers: 2")
print(f"  - Parameters: {total_params:,}")

print(f"\nArchitecture improvements:")
print(f"  ✓ Deeper CNN encoder (4 layers)")
print(f"  ✓ Bidirectional LSTM for temporal modeling")
print(f"  ✓ Audio context concatenation")
print(f"  ✓ Dropout regularization (0.3)")
print(f"  ✓ Average pooling (preserves temporal info)")

print(f"\nBest validation loss: {min(history['val_loss']):.4f}")

print(f"\nEvaluation metrics:")
print(f"  - Repetition rate: {results['avg_repetition_rate']:.4f}")
print(f"  - Vocabulary diversity: {results['vocabulary_diversity']:.4f}")
print(f"  - Mean caption length: {results['mean_caption_length']:.2f} words")
print(f"  - Unique words used: {results['unique_words_used']}")

print(f"\nFiles saved:")
print(f"  ✓ ../checkpoints/best_improved_baseline.pth")
print(f"  ✓ ../results/improved_baseline_history.json")
print(f"  ✓ ../results/improved_baseline_results.json")

print("\n" + "="*80)
print("Next steps:")
print("  - Train attention model: 03_train_attention_memory_efficient.ipynb")
print("  - Train transformer: 04_train_transformer_memory_efficient.ipynb")
print("  - Compare all models: 05_evaluate_all_memory_efficient.ipynb")
print("="*80)

## Expected Improvements Over Baseline

The improved baseline should show:

### Better Audio Understanding
- **Deeper CNN**: Extracts more complex audio features
- **Bidirectional LSTM**: Captures temporal context from both directions
- **Average pooling**: Preserves more temporal information than max pooling

### Better Caption Generation
- **Audio context concatenation**: Decoder has direct access to audio encoding at each step
- **Dropout**: Reduces overfitting, improves generalization
- **Larger capacity**: More parameters to learn complex patterns

### Typical Improvements
- **15-25%** reduction in repetition rate
- **10-20%** increase in vocabulary diversity
- **5-15%** increase in mean caption length
- **Better semantic accuracy** (to be validated with reference-based metrics, which this notebook does not compute)
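The notebook reports `avg_repetition_rate` and `vocabulary_diversity` from `src.evaluation`, but their definitions are not shown here. The sketch below assumes one plausible reading, purely for illustration:

- **Repetition rate:** the share of tokens in a caption that repeat a word already used in that caption.
- **Vocabulary diversity:** distinct words divided by total words across all captions (a type-token ratio).

The project's actual formulas may differ.

```python
def repetition_rate(tokens):
    """Assumed definition: fraction of tokens that repeat a word already used in the same caption."""
    if not tokens:
        return 0.0
    return 1 - len(set(tokens)) / len(tokens)


def vocabulary_diversity(captions):
    """Assumed definition: distinct words across all captions divided by total words."""
    words = [w for cap in captions for w in cap]
    return len(set(words)) / len(words) if words else 0.0


caps = [["a", "dog", "barks", "a", "dog", "barks"], ["rain", "falls", "on", "a", "roof"]]
print([round(repetition_rate(c), 3) for c in caps])  # [0.5, 0.0]
print(round(vocabulary_diversity(caps), 3))          # 0.636
```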

### Trade-offs
- ⚡ **Slightly slower**: ~1.5x training time vs baseline
- 💾 **More memory**: ~2-3x parameters vs baseline
- 🎯 **Better quality**: Worth the computational cost