# SPIDS Result Analysis Tutorial

Learn how to analyze and compare SPIDS experiment results.

This tutorial covers:
1. Loading saved experiments
2. Visualizing reconstruction quality
3. Analyzing convergence metrics
4. Comparing multiple experiments

**Estimated time**: 20-25 minutes

**Prerequisites**: Run at least one SPIDS experiment first; this notebook reads results from `../../runs/`.

## 1. Setup

In [None]:
import json
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch


sys.path.insert(0, "../..")

%matplotlib inline
plt.rcParams["figure.figsize"] = (16, 10)
plt.rcParams["figure.dpi"] = 100

print("✓ Setup complete")

## 2. Finding and Loading Experiments

First, let's see what experiments are available:

In [None]:
# Find all experiment directories.
# `experiments` is always defined (possibly empty) so later cells can rely on it.
runs_dir = Path("../../runs")
experiments: list[Path] = []

if runs_dir.is_dir():
    experiments = sorted(d for d in runs_dir.iterdir() if d.is_dir())
    print(f"Found {len(experiments)} experiments:\n")
    for exp in experiments:
        print(f"  - {exp.name}")

        # Count checkpoint (*.pt), metrics/config (*.json) and image (*.png) files
        n_files = sum(len(list(exp.glob(pattern))) for pattern in ("*.pt", "*.json", "*.png"))
        if n_files:
            print(f"    Files: {n_files} items")
else:
    print("⚠️  No runs/ directory found. Run an experiment first!")
    print("\nExample:")
    print("  cd ../..")
    print("  uv run python main.py --obj_name europa --n_samples 50 --fermat --name tutorial_test")

## 3. Loading Experiment Data

Load a specific experiment for analysis:

In [None]:
def load_experiment(exp_name: str, runs_dir: "str | Path" = "../../runs") -> dict:
    """
    Load experiment results from the runs directory.

    Parameters
    ----------
    exp_name : str
        Name of the experiment (folder name inside ``runs_dir``).
    runs_dir : str or Path, optional
        Directory containing one sub-folder per experiment. Defaults to
        ``"../../runs"``.

    Returns
    -------
    dict
        Always contains:
        - name: the experiment name
        - path: Path to the experiment folder
        - images: sorted list of ``*.png`` files in the folder
        and, only when the corresponding file exists:
        - model_state: object loaded from ``final_model.pt``
        - metrics: parsed ``metrics.json``
        - config: parsed ``config.json``

    Raises
    ------
    ValueError
        If the experiment folder does not exist.
    """
    exp_dir = Path(runs_dir) / exp_name

    if not exp_dir.is_dir():
        raise ValueError(f"Experiment '{exp_name}' not found in {runs_dir}")

    results = {
        "name": exp_name,
        "path": exp_dir,
    }

    # Load model checkpoint.
    # NOTE(review): depending on the torch version, torch.load may fully unpickle
    # the file (weights_only defaulted to False before torch 2.6), which can run
    # arbitrary code. Only load checkpoints you trust; if final_model.pt is a plain
    # state_dict, weights_only=True would be safer — confirm.
    model_file = exp_dir / "final_model.pt"
    if model_file.exists():
        results["model_state"] = torch.load(model_file, map_location="cpu")
        print("✓ Loaded model checkpoint")

    # Load metrics
    metrics_file = exp_dir / "metrics.json"
    if metrics_file.exists():
        with open(metrics_file, encoding="utf-8") as f:
            results["metrics"] = json.load(f)
        print("✓ Loaded metrics")

    # Load configuration
    config_file = exp_dir / "config.json"
    if config_file.exists():
        with open(config_file, encoding="utf-8") as f:
            results["config"] = json.load(f)
        print("✓ Loaded configuration")

    # Find saved images; sorted because glob order depends on the filesystem
    results["images"] = sorted(exp_dir.glob("*.png"))
    print(f"✓ Found {len(results['images'])} saved images")

    return results


# Select an experiment to analyze.
# Replace with your experiment name; by default the first experiment found in
# section 2 is used, falling back to "tutorial_test".
available_experiments = globals().get("experiments") or []
EXPERIMENT_NAME = available_experiments[0].name if available_experiments else "tutorial_test"

print(f"Loading experiment: {EXPERIMENT_NAME}\n")
try:
    exp_data = load_experiment(EXPERIMENT_NAME)
except ValueError as err:
    # Keep the notebook runnable: the following cells only check membership in,
    # or truthiness of, exp_data, so an empty dict makes them skip gracefully.
    print(f"⚠️  {err}")
    exp_data = {}

## 4. Visualizing Reconstruction

Rebuild the model from the saved checkpoint and display the final reconstruction:

In [None]:
if "model_state" not in exp_data:
    print("⚠️  No model checkpoint found in this experiment")
else:
    # Rebuild the network and load the saved weights into it
    from prism.models.networks import GenCropSpidsNet

    # obj_size / image_size come from config.json when present; otherwise use
    # the tutorial defaults (128 and 512).
    if "config" in exp_data:
        saved_config = exp_data["config"]
        obj_size = saved_config.get("obj_size", 128)
        image_size = saved_config.get("image_size", 512)
    else:
        obj_size, image_size = 128, 512

    model = GenCropSpidsNet(obj_size=obj_size, image_size=image_size)
    model.load_state_dict(exp_data["model_state"])
    model.eval()

    # Forward pass without gradient tracking, result moved to CPU for plotting
    with torch.no_grad():
        reconstruction = model().cpu()

    # Two channels are interpreted as (real, imaginary)
    is_complex = reconstruction.shape[1] == 2
    if is_complex:
        real_part = reconstruction[:, 0]
        imag_part = reconstruction[:, 1]
        magnitude = torch.sqrt(real_part**2 + imag_part**2)
        phase = torch.atan2(imag_part, real_part)

        fig, axes = plt.subplots(1, 2, figsize=(14, 6))
        panels = (
            (axes[0], magnitude[0], "Magnitude", {"cmap": "viridis"}),
            (axes[1], phase[0], "Phase", {"cmap": "twilight", "vmin": -np.pi, "vmax": np.pi}),
        )
        for ax, panel_values, title, imshow_kwargs in panels:
            image = ax.imshow(panel_values, **imshow_kwargs)
            ax.set_title(title, fontsize=14, fontweight="bold")
            ax.axis("off")
            plt.colorbar(image, ax=ax, fraction=0.046)

        plt.suptitle(f"Final Reconstruction: {EXPERIMENT_NAME}", fontsize=16, fontweight="bold")
        plt.tight_layout()
        plt.show()
    else:
        # Real-valued output: show the first channel of the first batch item
        first_channel = reconstruction[0, 0]
        plt.figure(figsize=(10, 10))
        plt.imshow(first_channel, cmap="viridis")
        plt.title(f"Reconstruction: {EXPERIMENT_NAME}")
        plt.colorbar(fraction=0.046)
        plt.axis("off")
        plt.show()

## 5. Training Metrics Analysis

Analyze convergence and training dynamics:

In [None]:
if "metrics" in exp_data:
    metrics = exp_data["metrics"]

    # Extract common metrics (every key is optional in metrics.json)
    losses = metrics.get("losses", [])
    coverage = metrics.get("coverage", None)
    failed_samples = metrics.get("failed_samples", [])
    n_samples = metrics.get("n_samples", 0)

    n_failed = len(failed_samples)
    n_success = n_samples - n_failed
    mean_loss = np.mean(losses) if losses else None

    # Create comprehensive metrics plot; panels without data are hidden
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))

    # 1. Loss curve and 2. loss distribution
    if losses:
        axes[0, 0].plot(losses, linewidth=2)
        axes[0, 0].set_xlabel("Sample Number")
        axes[0, 0].set_ylabel("Loss")
        axes[0, 0].set_title("Training Loss Curve", fontweight="bold")
        axes[0, 0].grid(True, alpha=0.3)
        # Log scale is usually better for losses, but cannot display values <= 0
        if np.min(losses) > 0:
            axes[0, 0].set_yscale("log")

        axes[0, 1].hist(losses, bins=30, edgecolor="black", alpha=0.7)
        axes[0, 1].set_xlabel("Loss Value")
        axes[0, 1].set_ylabel("Frequency")
        axes[0, 1].set_title("Loss Distribution", fontweight="bold")
        axes[0, 1].axvline(mean_loss, color="r", linestyle="--", label=f"Mean: {mean_loss:.4f}")
        axes[0, 1].legend()
        axes[0, 1].grid(True, alpha=0.3)
    else:
        axes[0, 0].axis("off")
        axes[0, 1].axis("off")

    # 3. Success rate
    if n_samples > 0:
        success_rate = n_success / n_samples * 100
        axes[1, 0].bar(
            ["Success", "Failed"],
            [n_success, n_failed],
            color=["green", "red"],
            alpha=0.7,
            edgecolor="black",
        )
        axes[1, 0].set_ylabel("Number of Samples")
        axes[1, 0].set_title(f"Sample Success Rate: {success_rate:.1f}%", fontweight="bold")
        axes[1, 0].grid(True, alpha=0.3, axis="y")
    else:
        axes[1, 0].axis("off")

    # 4. Coverage information (text-only summary panel)
    axes[1, 1].axis("off")
    if coverage is not None:
        axes[1, 1].text(
            0.5,
            0.6,
            f"Coverage: {coverage:.1f}%",
            ha="center",
            va="center",
            fontsize=24,
            fontweight="bold",
        )
        axes[1, 1].text(
            0.5, 0.4, f"Total Samples: {n_samples}", ha="center", va="center", fontsize=16
        )
        if losses:
            axes[1, 1].text(
                0.5, 0.3, f"Final Loss: {losses[-1]:.4f}", ha="center", va="center", fontsize=16
            )
            axes[1, 1].text(
                0.5, 0.2, f"Mean Loss: {mean_loss:.4f}", ha="center", va="center", fontsize=16
            )
        axes[1, 1].set_xlim(0, 1)
        axes[1, 1].set_ylim(0, 1)
        axes[1, 1].set_title("Summary Statistics", fontweight="bold")

    plt.suptitle(f"Metrics Analysis: {EXPERIMENT_NAME}", fontsize=18, fontweight="bold", y=1.00)
    plt.tight_layout()
    plt.show()

    # Print detailed statistics
    print("\n" + "=" * 60)
    print("DETAILED METRICS")
    print("=" * 60)
    if losses:
        print("Loss Statistics:")
        print(f"  Mean:     {mean_loss:.6f}")
        print(f"  Std Dev:  {np.std(losses):.6f}")
        print(f"  Min:      {np.min(losses):.6f}")
        print(f"  Max:      {np.max(losses):.6f}")
        print(f"  Final:    {losses[-1]:.6f}")
    if coverage is not None:
        print(f"\nCoverage: {coverage:.2f}%")
    if n_samples > 0:
        print("\nSamples:")
        print(f"  Total:    {n_samples}")
        print(f"  Success:  {n_success}")
        print(f"  Failed:   {n_failed}")
    print("=" * 60)
else:
    print("⚠️  No metrics found in this experiment")

## 6. Comparing Multiple Experiments

Compare results from different experiments:

In [None]:
def compare_experiments(exp_names: list[str]) -> list[dict]:
    """
    Compare metrics across multiple experiments.

    Each experiment is loaded with ``load_experiment``. Experiments that fail
    to load or process are reported and skipped, as are experiments without
    ``metrics.json``. Draws bar charts of mean loss, coverage and success
    rate, then prints a comparison table.

    Parameters
    ----------
    exp_names : list of str
        Experiment folder names to compare.

    Returns
    -------
    list of dict
        One summary per compared experiment with keys ``name``, ``mean_loss``
        (NaN when no losses were recorded), ``coverage`` (0 when missing or
        null), ``n_samples``, ``failed`` and ``success_rate`` (percent, 0
        when ``n_samples`` is 0). Empty when nothing could be compared.
    """
    comparison_data = []

    for name in exp_names:
        # Best-effort: one unreadable experiment must not abort the comparison.
        # NOTE: load_experiment also loads final_model.pt, which is not needed here.
        try:
            exp = load_experiment(name)
            if "metrics" in exp:
                metrics = exp["metrics"]
                losses = metrics.get("losses") or []
                n_samples = metrics.get("n_samples") or 0
                failed = len(metrics.get("failed_samples") or [])
                comparison_data.append(
                    {
                        "name": name,
                        # np.mean([]) warns and returns NaN; make that NaN explicit
                        "mean_loss": float(np.mean(losses)) if losses else float("nan"),
                        # JSON null would break the bar chart and table formatting
                        "coverage": metrics.get("coverage") or 0,
                        "n_samples": n_samples,
                        "failed": failed,
                        "success_rate": (
                            (n_samples - failed) / n_samples * 100 if n_samples > 0 else 0
                        ),
                    }
                )
        except Exception as e:
            print(f"⚠️  Could not load {name}: {e}")

    if not comparison_data:
        print("No experiments to compare")
        return comparison_data

    names = [d["name"] for d in comparison_data]
    positions = range(len(names))

    # Create comparison plots: (summary key, bar color, y label, title)
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    panels = [
        ("mean_loss", "steelblue", "Mean Loss", "Mean Loss Comparison"),
        ("coverage", "forestgreen", "Coverage (%)", "Coverage Comparison"),
        ("success_rate", "coral", "Success Rate (%)", "Success Rate Comparison"),
    ]
    for ax, (key, color, ylabel, title) in zip(axes, panels):
        ax.bar(
            positions,
            [d[key] for d in comparison_data],
            color=color,
            edgecolor="black",
            alpha=0.7,
        )
        ax.set_xticks(positions)
        ax.set_xticklabels(names, rotation=45, ha="right")
        ax.set_ylabel(ylabel)
        ax.set_title(title, fontweight="bold")
        ax.grid(True, alpha=0.3, axis="y")
    axes[2].set_ylim(0, 100)

    plt.tight_layout()
    plt.show()

    # Print comparison table; widen the name column for long experiment names
    name_width = max([25] + [len(n) for n in names])
    table_width = name_width + 55
    print("\n" + "=" * table_width)
    print("EXPERIMENT COMPARISON")
    print("=" * table_width)
    print(f"{'Experiment':<{name_width}} {'Mean Loss':>12} {'Coverage':>12} {'Success Rate':>15}")
    print("-" * table_width)
    for d in comparison_data:
        print(
            f"{d['name']:<{name_width}} {d['mean_loss']:>12.6f} "
            f"{d['coverage']:>11.1f}% {d['success_rate']:>14.1f}%"
        )
    print("=" * table_width)

    return comparison_data


# Example: compare up to the first three experiments discovered earlier
discovered = globals().get("experiments") or []
if len(discovered) > 0:
    exp_to_compare = [run.name for run in discovered[:3]]
    print(f"Comparing {len(exp_to_compare)} experiments\n")
    compare_experiments(exp_to_compare)
else:
    print("⚠️  Need at least one experiment to compare")

## 7. Quality Metrics

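Calculate simple statistics of the reconstruction: the intensity range and, for complex-valued outputs, the phase. They describe the reconstruction itself; no ground-truth image is used: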

In [None]:
def calculate_quality_metrics(reconstruction: torch.Tensor) -> dict:
    """
    Calculate descriptive statistics for a reconstruction.

    No ground truth is involved; these summarize the reconstruction itself.

    Parameters
    ----------
    reconstruction : torch.Tensor
        Batched reconstruction indexed as ``reconstruction[:, c]``. Supported
        layouts:
        - complex-dtype tensor: channel 0 is used;
        - real tensor with exactly 2 channels: channel 0 is the real part and
          channel 1 the imaginary part;
        - any other real tensor: channel 0 is treated as the magnitude.

    Returns
    -------
    dict
        ``mean_intensity``, ``std_intensity``, ``min_intensity``,
        ``max_intensity`` and ``dynamic_range`` of the magnitude. For complex
        data also ``phase_mean`` / ``phase_std`` (plain statistics of the
        phase wrapped to [-pi, pi]) and ``phase_circular_mean``
        (wrap-aware mean direction, ``atan2(mean(sin), mean(cos))``).
    """
    if torch.is_complex(reconstruction):
        channel = reconstruction[:, 0]
        magnitude = channel.abs()
        phase = torch.angle(channel)
    elif reconstruction.shape[1] == 2:  # real / imaginary stored as two channels
        real_part, imag_part = reconstruction[:, 0], reconstruction[:, 1]
        magnitude = torch.sqrt(real_part**2 + imag_part**2)
        phase = torch.atan2(imag_part, real_part)
    else:
        magnitude = reconstruction[:, 0]
        phase = None

    metrics = {
        "mean_intensity": float(magnitude.mean()),
        "std_intensity": float(magnitude.std()),
        "min_intensity": float(magnitude.min()),
        "max_intensity": float(magnitude.max()),
        "dynamic_range": float(magnitude.max() - magnitude.min()),
    }

    if phase is not None:
        # Phase wraps at ±pi: the arithmetic mean of +3.1 and -3.1 rad is ~0 even
        # though both point almost the same way. The circular mean gives ~±pi.
        circular_mean = torch.atan2(torch.sin(phase).mean(), torch.cos(phase).mean())
        metrics.update(
            {
                "phase_mean": float(phase.mean()),
                "phase_std": float(phase.std()),
                "phase_circular_mean": float(circular_mean),
            }
        )

    return metrics


# Summarize the reconstruction produced by the visualization step, if a model was loaded
if "model_state" in exp_data:
    quality = calculate_quality_metrics(reconstruction)

    divider = "=" * 60
    print(f"\n{divider}\nQUALITY METRICS\n{divider}")
    for metric_name, metric_value in quality.items():
        print(f"{metric_name:<20}: {metric_value:>15.6f}")
    print(divider)

## 8. Export Results

Save a summary of the analysis as `analysis_report.json` inside the experiment folder:

In [None]:
def export_analysis_report(exp_data: dict, output_file: str = "analysis_report.json"):
    """
    Export an analysis report as JSON into the experiment folder.

    Parameters
    ----------
    exp_data : dict
        Experiment data as returned by ``load_experiment``. Must contain
        ``name`` and ``path``; ``metrics`` is optional.
    output_file : str, optional
        File name of the report, written inside ``exp_data["path"]``.

    Returns
    -------
    dict
        The report that was written. ``timestamp`` is the experiment folder's
        modification time (seconds since the epoch) as a string. When metrics
        are present, ``mean_loss`` and ``final_loss`` are ``None`` if no
        losses were recorded.
    """
    exp_dir = Path(exp_data["path"])
    report = {
        "experiment_name": exp_data["name"],
        "timestamp": str(exp_dir.stat().st_mtime),
    }

    if "metrics" in exp_data:
        metrics = exp_data["metrics"]
        losses = metrics.get("losses") or []
        report["metrics"] = {
            # Without losses, np.mean([]) would yield NaN (not valid strict JSON)
            # and losses[-1] would raise IndexError, so report None instead.
            "mean_loss": float(np.mean(losses)) if losses else None,
            "final_loss": float(losses[-1]) if losses else None,
            "coverage": metrics.get("coverage"),
            "n_samples": metrics.get("n_samples"),
            "n_failed": len(metrics.get("failed_samples", [])),
        }

    # Save report
    output_path = exp_dir / output_file
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(report, f, indent=2)

    print(f"✓ Analysis report saved to: {output_path}")
    return report


# Write the report for the loaded experiment and echo its contents
if exp_data:
    saved_report = export_analysis_report(exp_data)
    print("\nReport contents:", json.dumps(saved_report, indent=2), sep="\n")

## Summary

You've learned:
- ✓ How to load and explore saved experiments
- ✓ How to visualize reconstruction results
- ✓ How to analyze training metrics and convergence
- ✓ How to compare multiple experiments
- ✓ How to export analysis reports

## Next Steps

1. **Run systematic experiments** - Vary parameters (n_samples, patterns, objects)
2. **Create comparison datasets** - Build a library of experiments
3. **Optimize parameters** - Find optimal settings for your use case
4. **Share results** - Export reports and visualizations for papers/presentations

## Further Resources

- **Python API Examples**: `examples/python_api/`
- **Documentation**: Project README and docs/
- **Advanced patterns**: `examples/patterns/`
- **Baselines**: `examples/baselines/` for algorithm comparisons