# Evaluate and Compare All Models (Memory-Efficient Versions)

This notebook loads all **memory-efficient** trained models and compares their performance.

**Models**: Baseline (small), Attention (small), Transformer (small)

**Comparison metrics**:
- Validation loss
- Repetition rate
- Vocabulary diversity
- Caption length
- Sample predictions

## Step 1: Setup and Imports

In [None]:
# Notebook bootstrap: configure the environment, make the project importable,
# pull in dependencies and report the PyTorch/CUDA status.

import os
# These variables are set before `import torch` further down; presumably they
# have to be in the environment when the library initialises, so keep this
# order when editing the cell.
os.environ['MKL_THREADING_LAYER'] = 'GNU'
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

import sys
from pathlib import Path

# Put the parent directory at the front of sys.path so that the `src.*`
# imports below resolve (assumes the notebook is run from a sub-directory of
# the project root).
project_root = Path('..').absolute()
sys.path.insert(0, str(project_root))

import torch
# Start from a clean slate: release cached CUDA blocks and run a garbage
# collection pass before any model is loaded.
torch.cuda.empty_cache()
import gc
gc.collect()

import json
import matplotlib.pyplot as plt
import pandas as pd
# Project-local helpers (model factory, dataset, evaluation and plotting).
from src.models import create_model
from src.dataset import ClothoEvalDataset
from src.evaluation import compare_models, get_sample_predictions, print_sample_predictions
from src.utils import load_vocab, get_device, plot_evaluation_metrics

# Quick sanity report so a failed or CPU-only setup is visible immediately.
print("✓ Imports successful!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

## Step 2: Setup Device and Load Vocabulary

In [None]:
# Select the compute device through the project helper.
device = get_device()

# When CUDA is present, report the GPU name plus total/free memory
# (bytes divided by 1024**3).
if torch.cuda.is_available():
    _bytes_per_unit = 1024**3
    _gpu_props = torch.cuda.get_device_properties(0)
    _free_bytes = torch.cuda.mem_get_info()[0]
    total_memory = _gpu_props.total_memory / _bytes_per_unit
    free_memory = _free_bytes / _bytes_per_unit
    for _line in (
        f"\nGPU: {torch.cuda.get_device_name(0)}",
        f"Total GPU memory: {total_memory:.2f} GB",
        f"Free GPU memory: {free_memory:.2f} GB",
    ):
        print(_line)

# Load the vocabulary used by the dataset and the models.
vocab = load_vocab('../vocab.json')
print(f"\n✓ Vocabulary size: {len(vocab)}")

## Step 3: Load Evaluation Dataset

In [None]:
# Build the evaluation dataset from the caption file and the pre-computed
# features directory, using the vocabulary loaded earlier.
_eval_dataset_kwargs = {
    'captions_file': '../data/eval_captions.json',
    'features_dir': '../features/mel_eval/',
    'vocab': vocab,
}

print("Loading evaluation dataset...")
eval_dataset = ClothoEvalDataset(**_eval_dataset_kwargs)
print(f"✓ Loaded {len(eval_dataset)} samples")

## Step 4: Load All Memory-Efficient Trained Models

In [None]:
# Instantiate every memory-efficient model and restore its best checkpoint.
# Models whose checkpoint file is missing are reported and skipped, so
# `models_dict` only contains the models that could actually be loaded.
print("Loading memory-efficient trained models...\n")
models_dict = {}

# Model configurations (memory-efficient versions).
# 'type' is forwarded to create_model() as the first positional argument and
# 'params' as keyword arguments; these must match the settings used in training,
# otherwise load_state_dict() will fail on mismatched shapes/keys.
model_configs = {
    'baseline_small': {
        'type': 'baseline',
        'params': {
            'vocab_size': len(vocab),
            'embed_dim': 128,
            'hidden_dim': 256,
            'num_layers': 1
        }
    },
    'attention_small': {
        'type': 'attention',
        'params': {
            'vocab_size': len(vocab),
            'embed_dim': 128,
            'hidden_dim': 256,
            'num_layers': 1
        }
    },
    'transformer_small': {
        'type': 'transformer',
        'params': {
            'vocab_size': len(vocab),
            'd_model': 256,
            'nhead': 4,
            'num_encoder_layers': 2,
            'num_decoder_layers': 2,
            'dim_feedforward': 512
        }
    }
}

for model_name, config in model_configs.items():
    checkpoint_path = Path(f'../checkpoints/best_{model_name}.pth')

    if not checkpoint_path.exists():
        print(f"  ⚠ Checkpoint not found: {checkpoint_path}")
        continue

    print(f"Loading {model_name}...")

    # Create model with specific configuration
    model = create_model(config['type'], **config['params'])

    # Deserialize the checkpoint on the CPU. The weights are copied into the
    # model *before* it is moved to `device`, so mapping the whole checkpoint
    # (which may hold more than the model weights) onto the GPU would only
    # waste GPU memory.
    # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    # Drop the reference right away so the checkpoint does not stay alive
    # until the next iteration (or, for the last model, for the whole session).
    del checkpoint

    model.to(device)
    model.eval()

    models_dict[model_name] = model
    print(f"  ✓ {model_name} loaded")

print(f"\n✓ Loaded {len(models_dict)} memory-efficient models")

# Check GPU memory after loading all models
if torch.cuda.is_available():
    allocated = torch.cuda.memory_allocated() / 1024**3
    free = torch.cuda.mem_get_info()[0] / 1024**3
    print(f"\nGPU Memory Status:")
    print(f"  Allocated: {allocated:.2f} GB")
    print(f"  Free: {free:.2f} GB")

## Step 5: Compare All Models

This evaluates each loaded model on the evaluation set (called with `num_samples=100`) and compares the resulting metrics.

In [None]:
# Print a banner, then run the project's comparison routine over every
# loaded model; the result is kept in `comparison` for the following cells.
_banner_rule = "=" * 80
print(_banner_rule)
print("COMPARING ALL MEMORY-EFFICIENT MODELS")
print(_banner_rule)

_compare_kwargs = {
    'models_dict': models_dict,
    'eval_dataset': eval_dataset,
    'vocab': vocab,
    'device': device,
    'num_samples': 100,
}
comparison = compare_models(**_compare_kwargs)

## Step 6: Save Comparison Results

In [None]:
# Save comparison results
results_to_save = {}
for model_name, data in comparison.items():
    results_to_save[model_name] = data['metrics']

with open('../results/comparison_results_small.json', 'w') as f:
    json.dump(results_to_save, f, indent=2)

print("✓ Comparison results saved to results/comparison_results_small.json")

## Step 7: Plot Comparison

In [None]:
# Draw the metric comparison with the project helper, write the current
# figure to disk, then display it.
_comparison_plot_path = '../results/comparison_plot_small.png'
_savefig_options = {'dpi': 300, 'bbox_inches': 'tight'}

print("Generating comparison plots...")
plot_evaluation_metrics(comparison)
plt.savefig(_comparison_plot_path, **_savefig_options)
plt.show()
print("✓ Plot saved to results/comparison_plot_small.png")

## Step 8: Sample Predictions from Each Model

In [None]:
# Print five sample predictions for every loaded model, each under its own
# upper-cased heading.
_section_rule = "=" * 80
print("\n" + _section_rule)
print("SAMPLE PREDICTIONS")
print(_section_rule)

for model_name in models_dict:
    model = models_dict[model_name]
    print(f"\n{model_name.upper().replace('_', ' ')}:")
    print("-" * 80)
    samples = get_sample_predictions(
        model,
        eval_dataset,
        vocab,
        device,
        num_samples=5,
    )
    print_sample_predictions(samples, num_to_print=5)

## Step 9: Load Training Histories

In [None]:
# Load training histories
histories = {}
for model_name in models_dict.keys():
    history_path = Path(f'../results/{model_name}_history.json')
    if history_path.exists():
        with open(history_path, 'r') as f:
            histories[model_name] = json.load(f)

print(f"✓ Loaded {len(histories)} training histories")

## Step 10: Plot Training Curves Comparison

In [None]:
from src.utils import plot_model_comparison

# Plot the validation-loss curves of all models that have a history;
# warn instead when no history was loaded.
if not histories:
    print("⚠ No training histories found")
else:
    _training_plot_path = '../results/training_comparison_small.png'
    print("Plotting training curves comparison...")
    plot_model_comparison(histories, metric='val_loss')
    plt.savefig(_training_plot_path, dpi=300, bbox_inches='tight')
    plt.show()
    print("✓ Training comparison saved to results/training_comparison_small.png")

## Step 11: Summary Table

In [None]:
# Create summary dataframe
summary_data = []
for model_name in models_dict.keys():
    metrics = comparison[model_name]['metrics']
    history = histories.get(model_name, {})
    
    # Clean up model name for display
    display_name = model_name.replace('_small', '').replace('_', ' ').title()
    
    summary_data.append({
        'Model': display_name + ' (Small)',
        'Best Val Loss': min(history.get('val_loss', [float('inf')])),
        'Repetition Rate': metrics['avg_repetition_rate'],
        'Vocabulary Diversity': metrics['vocabulary_diversity'],
        'Avg Caption Length': metrics['mean_caption_length'],
        'Unique Words Used': metrics['unique_words_used']
    })

summary_df = pd.DataFrame(summary_data)
print("\n" + "="*80)
print("MODEL COMPARISON SUMMARY (Memory-Efficient Versions)")
print("="*80)
print(summary_df.to_string(index=False))
print("="*80)

# Save to CSV
summary_df.to_csv('../results/model_comparison_small.csv', index=False)
print("\n✓ Summary saved to results/model_comparison_small.csv")

## Step 12: Best Model Analysis

In [None]:
# Find best model based on validation loss
best_model = min(summary_data, key=lambda x: x['Best Val Loss'])

print("\n" + "="*80)
print("BEST MODEL (Among Memory-Efficient Versions)")
print("="*80)
print(f"\nModel: {best_model['Model']}")
print(f"\nPerformance:")
for key, value in best_model.items():
    if key != 'Model':
        if isinstance(value, float):
            print(f"  {key}: {value:.4f}")
        else:
            print(f"  {key}: {value}")
print("="*80)

## Summary

All memory-efficient models have been evaluated and compared!

**Files created:**
- `results/comparison_results_small.json` - Detailed metrics for all models
- `results/comparison_plot_small.png` - Visual comparison
- `results/training_comparison_small.png` - Training curves comparison
- `results/model_comparison_small.csv` - Summary table

**Memory-efficient configurations used:**
- Baseline: embed_dim=128, hidden_dim=256, num_layers=1
- Attention: embed_dim=128, hidden_dim=256, num_layers=1
- Transformer: d_model=256, nhead=4, num_encoder_layers=2, num_decoder_layers=2, dim_feedforward=512