# ALME Optimizer Analysis

This notebook analyzes the performance of the Adaptive Local Minima Escape (ALME) optimizer compared to baseline optimizers (SGD, Adam, AdamW).

## Overview

ALME combines Adam-based gradient descent with a local minima escape mechanism:
1. Detects stagnation via gradient norm and loss plateaus
2. Samples perturbed weight candidates
3. Evaluates candidates with mini-optimization runs
4. Continues from the best candidate
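The four steps above can be illustrated with a minimal NumPy sketch on a 1-D toy objective. This is not the `styx.optimizers.ALME` implementation. The stagnation thresholds (`grad_tol`, `patience`, `rel_tol`), the Gaussian perturbation scales, the candidate count, and the use of plain gradient descent (rather than Adam) for both the main loop and the mini-optimization runs are all assumptions chosen to keep the example short.

```python
import numpy as np


def loss(x):
    # Toy 1-D objective with many local minima (assumed for illustration)
    return x**2 + 10 * np.sin(3 * x) + 10


def grad(x):
    return 2 * x + 30 * np.cos(3 * x)


def mini_optimize(x, lr=0.01, steps=20):
    # Short plain-gradient-descent run used to score a candidate
    for _ in range(steps):
        x = x - lr * grad(x)
    return x


def is_stagnant(x, loss_history, grad_tol=1e-3, patience=5, rel_tol=1e-4):
    # Step 1: stagnation = small gradient norm AND no meaningful
    # loss improvement over the last `patience` steps
    if len(loss_history) <= patience:
        return False
    recent_gain = loss_history[-patience - 1] - loss_history[-1]
    return abs(grad(x)) < grad_tol and recent_gain < rel_tol * abs(loss_history[-1])


def escape(x, rng, n_candidates=8, scales=(0.5, 1.0, 2.0)):
    # Steps 2-4: sample perturbed candidates at several scales, refine each
    # with a mini-optimization run, and keep the best (current point is the fallback)
    best_x, best_loss = x, loss(x)
    for _ in range(n_candidates):
        scale = rng.choice(scales)
        candidate = mini_optimize(x + scale * rng.standard_normal())
        if loss(candidate) < best_loss:
            best_x, best_loss = candidate, loss(candidate)
    return best_x


def alme_sketch(x0, lr=0.01, n_steps=500, seed=0):
    rng = np.random.default_rng(seed)
    x, history, escapes = x0, [], 0
    for _ in range(n_steps):
        x = x - lr * grad(x)  # ordinary descent step
        history.append(loss(x))
        if is_stagnant(x, history):
            new_x = escape(x, rng)
            if new_x != x:
                escapes += 1
                x = new_x
                history.clear()  # restart plateau tracking after an escape
    return x, escapes
```

Because an escape is accepted only when a candidate has strictly lower loss, the sketch never ends worse than the point where it stagnated; the price is the extra mini-optimization steps spent on rejected candidates.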

We'll test ALME on:
- MNIST classification
- Smooth vs jagged loss landscapes
- Synthetic optimization problems

In [None]:
import json
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn

# Add project to path
sys.path.insert(0, str(Path.cwd().parent / "src"))

from styx.datasets.loaders import get_mnist_loaders
from styx.models.simple_nets import MLP
from styx.optimizers import ALME
from styx.visualization.plots import (
    plot_escape_events,
    plot_optimizer_comparison_detailed,
    plot_population_diversity,
)

%matplotlib inline
plt.style.use("seaborn-v0_8-darkgrid")

## 1. Load Experiment Results

First, let's load the results from the benchmark experiments.

In [None]:
results_dir = Path("../experiments/results")


def load_results(filename):
    # Read one JSON results file and close the handle afterwards
    with open(results_dir / filename) as f:
        return json.load(f)


# Load baseline results
baseline_sgd = load_results("baseline_sgd_results.json")
baseline_adam = load_results("baseline_adam_results.json")
baseline_adamw = load_results("baseline_adamw_results.json")

# Load ALME results
alme_shallow = load_results("alme_mnist_shallow_results.json")
alme_aggressive = load_results("alme_mnist_aggressive_results.json")

print("✓ Results loaded successfully")
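The cells below index each results dictionary with a fixed set of keys. A small check like the following (a hypothetical helper, not part of `styx`) fails fast with a readable message if a results file came from a different script version. The key names are taken from how this notebook reads the results; the ALME-only keys are required only for ALME runs.

```python
# Keys this notebook reads from every results file
REQUIRED_KEYS = {
    "final_train_loss",
    "final_val_loss",
    "best_val_loss",
    "train_loss_history",
    "val_loss_history",
    "train_accuracy_history",
    "val_accuracy_history",
}
# Additional keys read from ALME results files only
ALME_KEYS = {"alme_stats", "final_alme_stats"}


def check_results(name, results, is_alme=False):
    """Raise KeyError listing any keys the notebook expects but the file lacks."""
    expected = REQUIRED_KEYS | (ALME_KEYS if is_alme else set())
    missing = sorted(expected - set(results))
    if missing:
        raise KeyError(f"{name} is missing keys: {missing}")
```

For example, `check_results("SGD", baseline_sgd)` and `check_results("ALME (shallow)", alme_shallow, is_alme=True)`.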

## 2. Compare Optimizer Performance

Let's create comprehensive comparison plots.

In [None]:
# Prepare results for comparison
comparison_results = {
    "SGD": baseline_sgd,
    "Adam": baseline_adam,
    "AdamW": baseline_adamw,
    "ALME (shallow)": alme_shallow,
    "ALME (aggressive)": alme_aggressive,
}

# Create detailed comparison plot
plot_optimizer_comparison_detailed(
    comparison_results,
    save_path="../experiments/results/optimizer_comparison.png",
)

## 3. Analyze ALME Escape Events

Let's examine when and how ALME escapes local minima.

In [None]:
# Plot escape events for shallow ALME
print("ALME (Shallow Configuration):")
plot_escape_events(
    alme_shallow["val_loss_history"],
    alme_shallow["alme_stats"],
    save_path="../experiments/results/alme_shallow_escapes.png",
)

# Print statistics
final_stats = alme_shallow["final_alme_stats"]
print(f"Total escapes: {final_stats['escape_count']}")
if final_stats["escape_count"] > 0:
    print(f"Average escape distance: {final_stats['avg_escape_distance']:.6f}")
print(f"Best validation loss: {final_stats['best_val_loss']:.4f}")

In [None]:
# Plot escape events for aggressive ALME
print("\nALME (Aggressive Configuration):")
plot_escape_events(
    alme_aggressive["val_loss_history"],
    alme_aggressive["alme_stats"],
    save_path="../experiments/results/alme_aggressive_escapes.png",
)

# Print statistics
final_stats = alme_aggressive["final_alme_stats"]
print(f"Total escapes: {final_stats['escape_count']}")
if final_stats["escape_count"] > 0:
    print(f"Average escape distance: {final_stats['avg_escape_distance']:.6f}")
print(f"Best validation loss: {final_stats['best_val_loss']:.4f}")
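The notebook reports an average escape distance without defining it. One plausible definition, assumed for this sketch and not confirmed against `styx`, is the L2 norm of the change in the full parameter vector between the pre-escape weights and the selected candidate, averaged over all escapes. Parameters are represented as lists of NumPy arrays (one per layer) to keep the example framework-agnostic.

```python
import numpy as np


def escape_distance(params_before, params_after):
    # L2 norm of the difference between two parameter sets,
    # treating all layers as one flattened vector
    return float(np.sqrt(sum(
        np.sum((after - before) ** 2)
        for before, after in zip(params_before, params_after)
    )))


def average_escape_distance(escape_pairs):
    # escape_pairs: list of (params_before, params_after) tuples, one per escape.
    # Returns nan when there were no escapes, since the average is undefined
    if not escape_pairs:
        return float("nan")
    return float(np.mean([escape_distance(b, a) for b, a in escape_pairs]))
```

Under this definition, distances scale with the parameter count, so values are comparable across the shallow and aggressive runs only because both use the same model.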

## 4. Population Diversity Analysis

Analyze how ALME explores the parameter space during escapes.

In [None]:
# Plot population diversity for both configurations
fig, axes = plt.subplots(1, 2, figsize=(15, 5))

plt.sca(axes[0])
plot_population_diversity(alme_shallow["alme_stats"])
axes[0].set_title("ALME Shallow - Population Diversity")

plt.sca(axes[1])
plot_population_diversity(alme_aggressive["alme_stats"])
axes[1].set_title("ALME Aggressive - Population Diversity")

plt.tight_layout()
plt.savefig("../experiments/results/population_diversity_comparison.png", dpi=300, bbox_inches="tight")
plt.show()
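Population diversity measures how spread out the sampled candidates are at each escape. The metric behind `plot_population_diversity` isn't shown here; a common choice, assumed for this sketch, is the mean pairwise Euclidean distance between flattened candidate parameter vectors:

```python
import numpy as np


def population_diversity(candidates):
    """Mean pairwise Euclidean distance between candidate parameter vectors.

    candidates: array-like of shape (n_candidates, n_params), one flattened
    parameter vector per row.
    """
    candidates = np.asarray(candidates, dtype=float)
    n = candidates.shape[0]
    if n < 2:
        return 0.0
    # Pairwise differences via broadcasting: shape (n, n, n_params)
    diffs = candidates[:, None, :] - candidates[None, :, :]
    dists = np.linalg.norm(diffs, axis=-1)
    # Average over the n * (n - 1) off-diagonal entries (the diagonal is zero)
    return float(dists.sum() / (n * (n - 1)))
```

The broadcast builds an `(n, n, n_params)` array, which is fine for a handful of candidates but memory-heavy for full network weights; looping over pairs avoids that at the cost of speed.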

## 5. Final Performance Metrics

Compare final performance metrics across all optimizers.

In [None]:
import pandas as pd

# Create summary table
summary_data = []

for name, results in comparison_results.items():
    row = {
        "Optimizer": name,
        "Final Train Loss": results["final_train_loss"],
        "Final Val Loss": results["final_val_loss"],
        "Best Val Loss": results["best_val_loss"],
        "Final Train Acc": results["train_accuracy_history"][-1],
        "Final Val Acc": results["val_accuracy_history"][-1],
    }
    
    if "final_alme_stats" in results:
        row["Escape Count"] = results["final_alme_stats"]["escape_count"]
        row["Avg Escape Dist"] = results["final_alme_stats"]["avg_escape_distance"]
    else:
        row["Escape Count"] = "-"
        row["Avg Escape Dist"] = "-"
    
    summary_data.append(row)

summary_df = pd.DataFrame(summary_data)
print("\nOptimizer Performance Summary:")
print(summary_df.to_string(index=False))

# Save to CSV
summary_df.to_csv("../experiments/results/optimizer_summary.csv", index=False)
print("\n✓ Summary saved to optimizer_summary.csv")
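The summary mixes numbers with `"-"` placeholders for optimizers that have no escape statistics, so those columns are not numeric. To rank optimizers, one option (a hypothetical helper) is to coerce the placeholder columns with `pd.to_numeric(..., errors="coerce")`, which turns `"-"` into `NaN`, and then rank on a numeric metric such as `Best Val Loss`:

```python
import pandas as pd


def rank_optimizers(summary, metric="Best Val Loss"):
    # Lower loss is better: rank 1 = best
    ranked = summary.copy()
    for col in ("Escape Count", "Avg Escape Dist"):
        if col in ranked:
            # "-" placeholders become NaN so the column is numeric
            ranked[col] = pd.to_numeric(ranked[col], errors="coerce")
    ranked["Rank"] = ranked[metric].rank(method="min").astype(int)
    return ranked.sort_values("Rank").reset_index(drop=True)
```

In the notebook this would be `print(rank_optimizers(summary_df).to_string(index=False))`. `method="min"` gives tied optimizers the same rank.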

## 6. Loss Landscape Analysis

If landscape analysis results are available, let's examine them.

In [None]:
landscape_results_path = Path("../experiments/results/landscape/landscape_analysis_results.json")

if landscape_results_path.exists():
    with open(landscape_results_path) as f:
        landscape_results = json.load(f)
    
    print("Landscape Analysis Results:")
    print("=" * 60)
    
    for landscape, optimizers in landscape_results.items():
        print(f"\n{landscape}:")
        for opt_name, opt_results in optimizers.items():
            if "final_loss" in opt_results:
                print(f"  {opt_name}: {opt_results['final_loss']:.6f}")
            elif "final_val_loss" in opt_results:
                print(f"  {opt_name}: Val Loss={opt_results['final_val_loss']:.4f}, "
                      f"Val Acc={opt_results['final_val_accuracy']:.4f}")
else:
    print("Landscape analysis results not found. Run landscape_analysis.py first.")
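The landscape experiments compare smooth and jagged surfaces. The test functions used by `landscape_analysis.py` aren't shown here; as an assumed illustration, a convex quadratic is a typical smooth surface and the Rastrigin function a typical jagged one, since its cosine term adds a regular grid of local minima around the single global minimum at the origin:

```python
import numpy as np


def smooth_quadratic(x):
    # Convex bowl: one minimum, no traps
    x = np.asarray(x, dtype=float)
    return float(np.sum(x**2))


def rastrigin(x, a=10.0):
    # Jagged surface: global minimum 0 at the origin, many local minima
    x = np.asarray(x, dtype=float)
    return float(a * x.size + np.sum(x**2 - a * np.cos(2 * np.pi * x)))


def count_local_minima_1d(f, lo=-5.12, hi=5.12, n=2001):
    # Rough jaggedness measure: interior grid points lower than both neighbours
    xs = np.linspace(lo, hi, n)
    ys = np.array([f([x]) for x in xs])
    return int(np.sum((ys[1:-1] < ys[:-2]) & (ys[1:-1] < ys[2:])))
```

A gradient method started anywhere on the quadratic reaches the global minimum, while on Rastrigin it stops in the nearest basin; that contrast is what an escape mechanism is meant to exploit.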

## 7. Training Curves Comparison

Plot training and validation curves side by side for detailed analysis.

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# (axis, history key, y-axis label, marker) for each panel
panels = [
    (axes[0, 0], "train_loss_history", "Training Loss", "o"),
    (axes[0, 1], "val_loss_history", "Validation Loss", "s"),
    (axes[1, 0], "train_accuracy_history", "Training Accuracy", "o"),
    (axes[1, 1], "val_accuracy_history", "Validation Accuracy", "s"),
]

for ax, key, ylabel, marker in panels:
    for name, results in comparison_results.items():
        epochs = range(1, len(results[key]) + 1)
        ax.plot(epochs, results[key], label=name, marker=marker, markersize=3)
    ax.set_xlabel("Epoch")
    ax.set_ylabel(ylabel)
    ax.set_title(f"{ylabel} Comparison")
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig("../experiments/results/training_curves_detailed.png", dpi=300, bbox_inches="tight")
plt.show()

## 8. Insights and Conclusions

### Points to Examine:

1. **Escape Mechanism Effectiveness**: Analyze whether ALME successfully escapes local minima and improves performance
2. **Hyperparameter Sensitivity**: Compare shallow vs aggressive configurations
3. **Computational Cost**: Consider the overhead of candidate sampling and evaluation
4. **Landscape Dependence**: Examine performance on smooth vs jagged surfaces
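For the computational-cost question, a back-of-the-envelope estimate helps. Assuming each escape evaluates `n_candidates` perturbed copies and runs `mini_steps` optimizer steps on each (hypothetical parameter names; substitute the values from the shallow and aggressive configurations), the extra gradient evaluations relative to ordinary training are:

```python
def escape_overhead(escape_count, n_candidates, mini_steps, epochs, steps_per_epoch):
    """Extra gradient evaluations spent on escapes, as a fraction of
    the gradient evaluations used by ordinary training steps.

    Counts only mini-optimization steps, so it is a lower bound if
    candidates are also scored with separate forward passes.
    """
    extra = escape_count * n_candidates * mini_steps
    baseline = epochs * steps_per_epoch
    return extra / baseline
```

With illustrative (not measured) values of 5 escapes, 10 candidates, and 50 mini-steps over 20 epochs of 469 steps (60,000 MNIST training images at an assumed batch size of 128), the overhead is 2,500 / 9,380, or about 27%.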

### Next Steps:

- Test on more challenging datasets (CIFAR-10, Fashion-MNIST)
- Experiment with different perturbation-scale distributions for candidate sampling
- Tune stagnation detection parameters
- Compare with other advanced optimizers (RAdam, Lookahead, etc.)

In [None]:
# Print key insights
print("Key Performance Metrics:")
print("=" * 60)

for name, results in comparison_results.items():
    print(f"\n{name}:")
    print(f"  Best Val Loss: {results['best_val_loss']:.4f}")
    print(f"  Final Val Acc: {results['val_accuracy_history'][-1]:.4f}")
    
    if "final_alme_stats" in results:
        stats = results["final_alme_stats"]
        print(f"  Total Escapes: {stats['escape_count']}")
        if stats['escape_count'] > 0:
            print(f"  Avg Escape Distance: {stats['avg_escape_distance']:.6f}")