# Model Comparison: RNN vs LSTM

This notebook implements a rigorous hyperparameter search and model comparison following proper scientific methodology:

1. **Train/Validation/Test split** - Hyperparameters are selected by validation loss; the test set is used only for the final evaluation of the already-selected configurations
2. **Reproducible seeds** - All random operations are seeded
3. **Statistical robustness** - Best configs evaluated with multiple seeds, reporting mean +/- std
4. **Fair comparison** - Same data split and search budget (number of trials) for all models; the LSTM additionally samples `num_layers` within that same budget

In [None]:
%load_ext autoreload
%autoreload 2

import torch
import numpy as np
import matplotlib.pyplot as plt

from ml_rnn_names.model import CharRNN, CharLSTM
from ml_rnn_names.processing import n_letters
from ml_rnn_names.training import train, set_seed
from ml_rnn_names.evaluation import compute_accuracy, evaluate_model, compute_confusion_matrix, plot_confusion_matrix
from ml_rnn_names.utils import torch_device_setup
from ml_rnn_names.data import load_data

In [None]:
# Configuration: every knob a reader may want to tune for the search lives here.
CONFIG = dict(
    n_trials=10,            # random hyperparameter configurations sampled per model
    seeds_per_config=3,     # independent re-trainings of each selected config
    data_seed=2024,         # fixed seed for the data split -> same split on every run
    validation_split=0.15,  # intended 70% train / 15% val / 15% test (assumes load_data reserves an equal-sized test split — confirm)
    param_ranges=dict(
        hidden_size=[64, 128, 256],
        learning_rate=(0.01, 0.5),  # (low, high) bounds
        n_batch_size=[32, 64, 128],
        n_epoch=[10, 15, 20],
    ),
    lstm_extra_params=dict(
        num_layers=[2, 3, 4],  # only sampled for the LSTM
    ),
)

In [None]:
# A single fixed train/val/test split (its seed comes from CONFIG["data_seed"])
# is shared by every model below, so model differences are not confounded by
# differences in the data split.
device = torch_device_setup()
(train_set, val_set, test_set, labels_uniq) = load_data(
    device, validation_split=CONFIG["validation_split"], seed=CONFIG["data_seed"],
)

print(f"\nClasses: {labels_uniq}")

In [None]:
# Model factory
def create_model(model_type, hidden_size, num_layers=2):
    """Build a fresh, untrained model of the requested architecture.

    Args:
        model_type: "RNN" or "LSTM".
        hidden_size: hidden dimension passed through to the model constructor.
        num_layers: number of layers passed to CharLSTM; not used for "RNN".

    Returns:
        A CharRNN or CharLSTM with ``input_size=n_letters`` and
        ``output_size=len(labels_uniq)`` (both read from notebook globals).

    Raises:
        ValueError: if ``model_type`` is neither "RNN" nor "LSTM".
    """
    if model_type not in ("RNN", "LSTM"):
        raise ValueError(f"Unknown model type: {model_type}")

    if model_type == "RNN":
        return CharRNN(
            input_size=n_letters,
            hidden_size=hidden_size,
            output_size=len(labels_uniq),
        )

    return CharLSTM(
        input_size=n_letters,
        hidden_size=hidden_size,
        num_layers=num_layers,
        output_size=len(labels_uniq),
    )

In [None]:
def random_search(model_type, n_trials, param_ranges, train_data, val_data, device, base_seed=0,
                  extra_params=None):
    """
    Run random hyperparameter search, selecting by VALIDATION loss.

    Reproducibility:
        * Hyperparameters are drawn from a private ``np.random.Generator`` seeded
          with ``base_seed``. The sampled configurations therefore do not depend on
          anything ``train`` / ``set_seed`` do to the global NumPy RNG between trials.
        * Before each model is constructed, the global RNGs are re-seeded via
          ``set_seed(base_seed + trial, device)``. This is the same seed handed to
          ``train``, so a trial can be reproduced in isolation. Presumably this also
          covers weight initialisation, if the models draw their initial weights
          in their constructors.

    Args:
        model_type: "RNN" or "LSTM"
        n_trials: number of configurations to try (must be >= 1)
        param_ranges: dict with "hidden_size", "n_batch_size", "n_epoch" (lists of
            choices) and "learning_rate" (a ``(low, high)`` pair, sampled uniformly)
        train_data: training dataset
        val_data: validation dataset
        device: torch device
        base_seed: base seed for reproducibility
        extra_params: model-specific choices; only "num_layers" is read, and only
            when ``model_type == "LSTM"``. Defaults to
            ``CONFIG.get("lstm_extra_params", {})`` for backward compatibility.

    Returns:
        results: list of all trial results, one dict per trial, in trial order
        best: best trial by final validation loss. Non-finite losses (diverged
            runs) rank last instead of possibly being selected.

    Raises:
        ValueError: if ``n_trials < 1``.
    """
    if n_trials < 1:
        raise ValueError(f"n_trials must be >= 1, got {n_trials}")
    if extra_params is None:
        extra_params = CONFIG.get("lstm_extra_params", {})

    # Private stream for hyperparameter sampling (see docstring).
    sampler = np.random.default_rng(base_seed)
    results = []

    for trial in range(n_trials):
        trial_seed = base_seed + trial

        # Sample hyperparameters
        hidden_size = int(sampler.choice(param_ranges["hidden_size"]))
        learning_rate = float(sampler.uniform(*param_ranges["learning_rate"]))
        n_batch_size = int(sampler.choice(param_ranges["n_batch_size"]))
        n_epoch = int(sampler.choice(param_ranges["n_epoch"]))

        # LSTM-specific: sample num_layers. 2 matches create_model's default;
        # the RNN branch of create_model does not use it.
        num_layers = 2
        if model_type == "LSTM" and "num_layers" in extra_params:
            num_layers = int(sampler.choice(extra_params["num_layers"]))

        print(f"Trial {trial + 1}/{n_trials}: hidden={hidden_size}, lr={learning_rate:.4f}, "
              f"batch={n_batch_size}, epochs={n_epoch}" +
              (f", layers={num_layers}" if model_type == "LSTM" else ""))

        # Seed *before* building the model: construction happens before `train`
        # ever sees `seed`, so otherwise the initial weights would depend on
        # whatever RNG state earlier trials left behind.
        set_seed(trial_seed, device)
        model = create_model(model_type, hidden_size, num_layers)

        result = train(
            model=model,
            training_data=train_data,
            validation_data=val_data,
            n_epoch=n_epoch,
            n_batch_size=n_batch_size,
            learning_rate=learning_rate,
            device=device,
            seed=trial_seed,
            report_every=100,  # Reduce verbosity
        )
        train_losses = result["train_losses"]
        val_losses = result["val_losses"]

        results.append({
            "model_type": model_type,
            "hidden_size": hidden_size,
            "num_layers": num_layers,
            "learning_rate": learning_rate,
            "n_batch_size": n_batch_size,
            "n_epoch": n_epoch,
            "seed": trial_seed,
            "final_train_loss": train_losses[-1],
            "final_val_loss": val_losses[-1],
            "train_losses": train_losses,
            "val_losses": val_losses,
        })

        print(f"  -> train_loss={train_losses[-1]:.4f}, val_loss={val_losses[-1]:.4f}\n")

    def _selection_key(record):
        # Every comparison against NaN is False. With min(), a NaN loss on the
        # first trial would therefore never be displaced. Map non-finite losses to
        # +inf so that diverged runs rank last.
        loss = float(record["final_val_loss"])
        return loss if np.isfinite(loss) else float("inf")

    # Select by VALIDATION loss (not training loss!)
    best = min(results, key=_selection_key)
    if not np.isfinite(float(best["final_val_loss"])):
        print(f"WARNING: every {model_type} trial ended with a non-finite validation loss.")
    return results, best

## Hyperparameter Search

Run random search for each model type, selecting the best configuration by **validation loss**.

In [None]:
# Run the random search once per architecture. Each model type gets its own
# base seed (RNN: 42, LSTM: 142).
all_results = {}
best_configs = {}
search_base_seeds = {"RNN": 42, "LSTM": 142}
banner = "=" * 60

for model_type, model_base_seed in search_base_seeds.items():
    print(f"\n{banner}")
    print(f"Searching hyperparameters for {model_type}")
    print(banner + "\n")

    trial_records, best = random_search(
        model_type=model_type,
        n_trials=CONFIG["n_trials"],
        param_ranges=CONFIG["param_ranges"],
        train_data=train_set,
        val_data=val_set,
        device=device,
        base_seed=model_base_seed,
    )
    results = trial_records

    all_results[model_type] = trial_records
    best_configs[model_type] = best

    print(f"\nBest {model_type} config (selected by validation loss):")
    # Loss curves are long lists; show only the scalar summary fields.
    scalar_items = ((key, value) for key, value in best.items()
                    if key not in ("train_losses", "val_losses"))
    for key, value in scalar_items:
        print(f"  {key}: {value}")

## Statistical Robustness

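Retrain the best configuration for each model with multiple random seeds to assess run-to-run variance. Each retrained model is then evaluated on the held-out test set; this is the first point in the notebook where test data is used.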

In [None]:
def evaluate_with_multiple_seeds(config, train_data, val_data, test_data, seeds, device):
    """
    Train the same config with multiple seeds and report test-set statistics.

    For each seed, the global RNGs are seeded via ``set_seed(seed, device)``
    *before* the model is built. Model construction would otherwise depend on
    leftover RNG state from earlier cells rather than on ``seed``. The same seed
    is also passed to ``train``.

    Args:
        config: a trial record from ``random_search``. Reads "model_type",
            "hidden_size", "n_epoch", "n_batch_size", "learning_rate" and,
            optionally, "num_layers" (default 2).
        train_data: training dataset
        val_data: validation dataset
        test_data: test dataset (first touched here)
        seeds: iterable of integer seeds, one training run per seed (non-empty)
        device: torch device

    Returns:
        dict with:
            "test_accuracy_mean", "test_accuracy_std",
            "test_loss_mean", "test_loss_std": summary statistics. The std is the
                sample standard deviation (ddof=1) when more than one seed is
                given, else 0.
            "test_accuracies", "test_losses", "seeds": per-seed values, in seed order.
            "models": the trained models, in seed order.

    Raises:
        ValueError: if ``seeds`` is empty.
    """
    seeds = list(seeds)
    if not seeds:
        raise ValueError("seeds must contain at least one seed")

    test_accuracies = []
    test_losses = []
    trained_models = []

    for seed in seeds:
        print(f"  Training with seed {seed}...")
        # Seed before construction so `seed` also governs weight initialisation.
        set_seed(seed, device)
        model = create_model(
            config["model_type"],
            config["hidden_size"],
            config.get("num_layers", 2)
        )

        train(
            model=model,
            training_data=train_data,
            validation_data=val_data,
            n_epoch=config["n_epoch"],
            n_batch_size=config["n_batch_size"],
            learning_rate=config["learning_rate"],
            device=device,
            seed=seed,
            report_every=100,
        )

        # Evaluate on TEST set (only now do we touch test data)
        metrics = evaluate_model(model, train_data, val_data, test_data)
        test_accuracies.append(metrics["test_accuracy"])
        test_losses.append(metrics["test_loss"])
        trained_models.append(model)

        print(f"    test_accuracy={metrics['test_accuracy']:.4f}, test_loss={metrics['test_loss']:.4f}")

    # The seeds are a small sample of possible training runs, so report the
    # Bessel-corrected sample std. With a single run, the spread is defined as 0.
    ddof = 1 if len(seeds) > 1 else 0
    return {
        "test_accuracy_mean": float(np.mean(test_accuracies)),
        "test_accuracy_std": float(np.std(test_accuracies, ddof=ddof)),
        "test_loss_mean": float(np.mean(test_losses)),
        "test_loss_std": float(np.std(test_losses, ddof=ddof)),
        "test_accuracies": test_accuracies,
        "test_losses": test_losses,
        "seeds": seeds,
        "models": trained_models,
    }

In [None]:
# Evaluation seeds. CONFIG["seeds_per_config"] controls how many runs each best
# config gets. The first three seeds are the historical defaults (42, 123, 456),
# so default results are unchanged. If more runs are requested, extra seeds
# (1003, 1004, ...) are appended deterministically.
DEFAULT_EVAL_SEEDS = [42, 123, 456]
n_eval_seeds = CONFIG["seeds_per_config"]
seeds = DEFAULT_EVAL_SEEDS[:n_eval_seeds] + [
    1000 + i for i in range(len(DEFAULT_EVAL_SEEDS), n_eval_seeds)
]
final_results = {}

for model_type, config in best_configs.items():
    print(f"\nEvaluating {model_type} with {len(seeds)} seeds...")
    stats = evaluate_with_multiple_seeds(
        config, train_set, val_set, test_set, seeds, device
    )
    final_results[model_type] = stats

## Final Results

In [None]:
# Test-set summary (mean +/- std over the evaluation seeds), followed by the
# hyperparameters that produced it.
print("\n" + "="*70)
print("FINAL RESULTS (Test Set)")
print("="*70)
print(f"{'Model':<10} {'Accuracy':<25} {'Loss':<25}")
print("-"*70)

for model_type, stats in final_results.items():
    acc = f"{stats['test_accuracy_mean']*100:.1f}% +/- {stats['test_accuracy_std']*100:.1f}%"
    loss = f"{stats['test_loss_mean']:.4f} +/- {stats['test_loss_std']:.4f}"
    print(f"{model_type:<10} {acc:<25} {loss:<25}")

print("\nBest hyperparameters used:")
for model_type, config in best_configs.items():
    print(f"\n{model_type}:")
    reported_keys = ["hidden_size", "num_layers", "learning_rate", "n_batch_size", "n_epoch"]
    if model_type == "RNN":
        # create_model never passes num_layers to CharRNN, so the stored
        # placeholder value (2) would be misleading here.
        reported_keys.remove("num_layers")
    for k in reported_keys:
        if k in config:
            print(f"  {k}: {config[k]}")

## Visualization

In [None]:
# Validation-loss curves for every search trial, one panel per model type.
# The trial selected as best (by final validation loss) is drawn bold and opaque.
model_types = list(all_results)
fig, axes = plt.subplots(1, len(model_types), figsize=(7 * len(model_types), 5), squeeze=False)

for ax, model_type in zip(axes[0], model_types):
    best_trial = best_configs[model_type]
    for trial_no, result in enumerate(all_results[model_type], start=1):
        # List every sampled hyperparameter so that legend entries are
        # unambiguous. h and a 2-decimal lr alone can collide between trials.
        label = (f"#{trial_no}: h={result['hidden_size']}, lr={result['learning_rate']:.3f}, "
                 f"bs={result['n_batch_size']}, ep={result['n_epoch']}")
        if model_type == "LSTM":
            label += f", L={result['num_layers']}"
        # random_search returns the best record itself, so identity is exact.
        is_best = result is best_trial
        if is_best:
            label += " (best)"
        line_style = {"alpha": 1.0, "linewidth": 2.5} if is_best else {"alpha": 0.5}
        ax.plot(result["val_losses"], label=label, **line_style)
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Validation Loss")
    ax.set_title(f"{model_type} - Hyperparameter Search")
    ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left", fontsize=8)

plt.tight_layout()
plt.show()

In [None]:
# Confusion matrices for best models (using first seed's model), with each row
# normalised to sum to 1. Rows with no samples are left as zeros.
model_types = list(final_results)
n_classes = len(labels_uniq)
fig, axes = plt.subplots(1, len(model_types), figsize=(8 * len(model_types), 7), squeeze=False)

for ax, model_type in zip(axes[0], model_types):
    model = final_results[model_type]["models"][0]  # First seed's model
    # .float(): if compute_confusion_matrix returns integer counts, writing the
    # divided rows back into that tensor would truncate every entry to 0.
    counts = compute_confusion_matrix(model, test_set, n_classes).float()
    row_totals = counts.sum(dim=1, keepdim=True)
    # Normalise non-empty rows only; empty rows keep their (zero) counts
    # instead of becoming 0/0 = NaN.
    confusion = torch.where(row_totals > 0, counts / row_totals, counts)

    # Fixed 0..1 colour range so the panels are directly comparable.
    cax = ax.matshow(confusion.cpu().numpy(), vmin=0.0, vmax=1.0)
    fig.colorbar(cax, ax=ax)
    ax.set_xticks(np.arange(n_classes))
    ax.set_yticks(np.arange(n_classes))
    ax.set_xticklabels(labels_uniq, rotation=90)
    ax.set_yticklabels(labels_uniq)
    ax.set_title(f"{model_type} Confusion Matrix (Test Set)")

plt.tight_layout()
plt.show()