# Stability Demo

This notebook demonstrates COMPASS's numerical stability features for RCWA simulations:

1. Float32 vs float64 precision comparison
2. Mixed precision eigendecomposition
3. Adaptive fallback mechanism
4. `StabilityDiagnostics` pre/post simulation checks

RCWA is sensitive to numerical precision because it involves eigendecomposition
of large complex matrices. Small floating-point errors can compound through the
S-matrix recursion and produce unphysical results (QE > 1 or R + T + A != 1).
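Here is a toy illustration of how rounding errors compound. The sketch cascades 1000 random lossless blocks (unitary 8x8 matrices) using plain matrix products. It is a simplified stand-in chosen for illustration, not COMPASS code: real RCWA combines layers with the Redheffer star product rather than a plain product. A lossless cascade must remain unitary, so the deviation of P^H P from the identity plays the role of |R + T - 1|.

```python
import numpy as np

def random_unitary(n, rng):
    """Random n x n unitary matrix: the Q factor of a complex Gaussian matrix."""
    z = rng.standard_normal((n, n)) + 1j * rng.standard_normal((n, n))
    q, _ = np.linalg.qr(z)
    return q

def cascade_unitarity_error(dtype, n_blocks=1000, n=8, seed=0):
    """Multiply n_blocks unitary blocks in `dtype` and return max |P^H P - I|."""
    rng = np.random.default_rng(seed)  # same seed -> same blocks for both dtypes
    product = np.eye(n, dtype=dtype)
    for _ in range(n_blocks):
        product = product @ random_unitary(n, rng).astype(dtype)
    # Toy stand-in for energy conservation: a lossless cascade stays unitary
    gram = product.conj().T @ product
    return float(np.max(np.abs(gram - np.eye(n))))

err32 = cascade_unitarity_error(np.complex64)
err64 = cascade_unitarity_error(np.complex128)
print(f"complex64  unitarity error after 1000 blocks: {err32:.2e}")
print(f"complex128 unitarity error after 1000 blocks: {err64:.2e}")
```

The float32 error grows with the number of blocks, while the float64 error stays near machine precision.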

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import time

from compass.solvers.base import SolverFactory
from compass.solvers.rcwa.stability import (
    PrecisionManager,
    AdaptivePrecisionRunner,
    StabilityDiagnostics,
    EigenvalueStabilizer,
)
from compass.runners.single_run import SingleRunner
from compass.analysis.energy_balance import EnergyBalance

## 1. Setup: Common Pixel Configuration

We use a standard 2x2 BSI pixel with a tungsten color-filter grid. Stability issues are more likely
at high Fourier orders and in structures containing metallic layers, such as this grid.

In [None]:
base_config = {
    "pixel": {
        "pitch": 1.0,
        "unit_cell": [2, 2],
        "bayer_map": [["R", "G"], ["G", "B"]],
        "layers": {
            "air": {"thickness": 1.0, "material": "air"},
            "microlens": {
                "enabled": True, "height": 0.6,
                "radius_x": 0.48, "radius_y": 0.48,
                "material": "polymer_n1p56",
                "profile": {"type": "superellipse", "n": 2.5, "alpha": 1.0},
            },
            "planarization": {"thickness": 0.3, "material": "sio2"},
            "color_filter": {
                "thickness": 0.6,
                "materials": {"R": "cf_red", "G": "cf_green", "B": "cf_blue"},
                "grid": {"enabled": True, "width": 0.05, "material": "tungsten"},
            },
            "barl": {"layers": [
                {"thickness": 0.010, "material": "sio2"},
                {"thickness": 0.025, "material": "hfo2"},
            ]},
            "silicon": {
                "thickness": 3.0, "material": "silicon",
                "photodiode": {"position": [0, 0, 0.5], "size": [0.7, 0.7, 2.0]},
                "dti": {"enabled": True, "width": 0.1, "material": "sio2"},
            },
        },
    },
    "source": {
        "wavelength": {
            "mode": "sweep",
            "sweep": {"start": 0.40, "stop": 0.70, "step": 0.01},
        },
        "polarization": "unpolarized",
    },
    "compute": {"backend": "auto"},
}

print("Base config: 2x2 BSI with tungsten grid (high-contrast)")

## 2. Float32 vs Float64 Comparison

Run the same simulation in pure float32 and pure float64 to see how
precision affects QE accuracy and energy conservation.
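The `dtype` values in the configs below, `complex64` and `complex128`, are built from float32 and float64 components. This short check uses NumPy's `finfo` to show what each precision provides. It is independent of COMPASS.

```python
import numpy as np

for real, cplx in [(np.float32, np.complex64), (np.float64, np.complex128)]:
    info = np.finfo(real)
    # finfo of a complex dtype reports its real component, so complex64 has float32 precision
    assert np.finfo(cplx).eps == info.eps
    print(f"{cplx.__name__:>10}: eps = {info.eps:.3e}, ~{info.precision} decimal digits, "
          f"max = {info.max:.3e}")
```

Float32 carries about 6 to 7 significant decimal digits. That is why errors of order 1e-3 can appear once rounding compounds through a large eigenproblem.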

In [None]:
import copy

# Float32 configuration
config_f32 = copy.deepcopy(base_config)
config_f32["solver"] = {
    "name": "torcwa", "type": "rcwa",
    "params": {"fourier_order": [9, 9], "dtype": "complex64"},
    "stability": {
        "precision_strategy": "float32",
        "allow_tf32": False,
        "fourier_factorization": "li_inverse",
    },
}

# Float64 configuration
config_f64 = copy.deepcopy(base_config)
config_f64["solver"] = {
    "name": "torcwa", "type": "rcwa",
    "params": {"fourier_order": [9, 9], "dtype": "complex128"},
    "stability": {
        "precision_strategy": "float64",
        "allow_tf32": False,
        "fourier_factorization": "li_inverse",
    },
}

print("Running float32...")
t0 = time.perf_counter()
result_f32 = SingleRunner.run(config_f32)
time_f32 = time.perf_counter() - t0

print("Running float64...")
t0 = time.perf_counter()
result_f64 = SingleRunner.run(config_f64)
time_f64 = time.perf_counter() - t0

print(f"\nFloat32: {time_f32:.2f}s")
print(f"Float64: {time_f64:.2f}s")
print(f"Speedup: {time_f64/time_f32:.1f}x")

In [None]:
# Compare QE spectra
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8), height_ratios=[3, 1])

wl_nm = np.asarray(result_f64.wavelengths) * 1000  # um -> nm; asarray also handles list results

# Extract green pixel QE from both
for name, qe in result_f32.qe_per_pixel.items():
    if name.startswith("G"):
        ax1.plot(wl_nm, qe, "--", color="tab:blue", alpha=0.7,
                 label=f"float32 ({name})")
        qe_f32_green = qe
        break

for name, qe in result_f64.qe_per_pixel.items():
    if name.startswith("G"):
        ax1.plot(wl_nm, qe, "-", color="tab:orange", linewidth=2,
                 label=f"float64 ({name})")
        qe_f64_green = qe
        break

ax1.set_ylabel("Green QE")
ax1.set_title("Float32 vs Float64: QE Comparison")
ax1.legend()
ax1.grid(True, alpha=0.3)

# Difference plot
diff = np.abs(np.array(qe_f32_green) - np.array(qe_f64_green))
ax2.plot(wl_nm, diff, "r-", linewidth=1.5)
ax2.axhline(0.01, color="gray", linestyle="--", alpha=0.5, label="1% threshold")
ax2.set_xlabel("Wavelength (nm)")
ax2.set_ylabel("|QE difference|")
ax2.set_title("Absolute QE Difference (f32 vs f64)")
ax2.legend()
ax2.grid(True, alpha=0.3)

fig.tight_layout()

print(f"Max |QE diff|: {np.max(diff):.5f}")
print(f"Mean |QE diff|: {np.mean(diff):.5f}")

## 3. Mixed Precision Eigendecomposition

The `mixed` precision strategy runs most of the computation in float32 for speed,
but promotes eigendecomposition to float64 (on CPU) for stability.

This is the recommended default. Let's demonstrate it directly.
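The promote, solve, demote pattern described above can be sketched in plain NumPy. `mixed_precision_eig_sketch` is a hypothetical illustration under that assumption. It is not the actual `PrecisionManager.mixed_precision_eigen`, whose internals may differ (for example, device handling).

```python
import numpy as np

def mixed_precision_eig_sketch(a, out_dtype=np.complex64):
    """Hypothetical sketch: solve the eigenproblem in complex128, return results in out_dtype."""
    a64 = np.asarray(a).astype(np.complex128)       # promote
    evals, evecs = np.linalg.eig(a64)               # solve in double precision
    return evals.astype(out_dtype), evecs.astype(out_dtype)  # demote for the float32 pipeline

rng = np.random.default_rng(0)
a = (rng.standard_normal((64, 64)) + 1j * rng.standard_normal((64, 64))).astype(np.complex64)
vals, vecs = mixed_precision_eig_sketch(a)

# Eigenpair residuals ||A v_i - lambda_i v_i||, evaluated in double precision
a64, v64, l64 = a.astype(np.complex128), vecs.astype(np.complex128), vals.astype(np.complex128)
res = np.linalg.norm(a64 @ v64 - v64 * l64, axis=0)  # v64 * l64 scales column i by lambda_i
print(vals.dtype, f"max residual = {res.max():.2e}")
```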

In [None]:
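import scipy.linalg

# Demonstrate mixed precision eigendecomp on a synthetic matrix
np.random.seed(42)

# Simulate a Fourier-domain matrix at order [9,9] -> 361 modes -> 722x722
n = 722
matrix_f32 = (np.random.randn(n, n) + 1j * np.random.randn(n, n)).astype(np.complex64)

# Direct float32 eigendecomp. np.linalg.eig upcasts complex64 input to complex128
# internally and casts the result back, so it is not a true float32 baseline;
# scipy.linalg.eig dispatches to single-precision LAPACK for complex64 input.
t0 = time.perf_counter()
evals_f32, evecs_f32 = scipy.linalg.eig(matrix_f32)
time_direct = time.perf_counter() - t0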

# Mixed precision eigendecomp (promotes to f64 internally)
t0 = time.perf_counter()
evals_mixed, evecs_mixed = PrecisionManager.mixed_precision_eigen(matrix_f32)
time_mixed = time.perf_counter() - t0

# Full float64 eigendecomp
matrix_f64 = matrix_f32.astype(np.complex128)
t0 = time.perf_counter()
evals_f64, evecs_f64 = np.linalg.eig(matrix_f64)
time_full = time.perf_counter() - t0

print(f"Direct float32: {time_direct:.3f}s")
print(f"Mixed precision: {time_mixed:.3f}s")
print(f"Full float64:   {time_full:.3f}s")

# Compare eigenpair residuals: ||A*v - lambda*v||
def reconstruction_error(matrix, evals, evecs):
    """Mean eigenpair residual ||A v_i - lambda_i v_i|| over the first 20 pairs.

    NumPy and SciPy both return unit-norm eigenvectors, so the residuals are comparable.
    """
    errors = []
    for i in range(min(20, len(evals))):  # Check first 20
        residual = matrix @ evecs[:, i] - evals[i] * evecs[:, i]
        errors.append(np.linalg.norm(residual))
    return np.mean(errors)

err_f32 = reconstruction_error(matrix_f64, evals_f32.astype(np.complex128),
                               evecs_f32.astype(np.complex128))
err_mixed = reconstruction_error(matrix_f64, evals_mixed.astype(np.complex128),
                                 evecs_mixed.astype(np.complex128))
err_f64 = reconstruction_error(matrix_f64, evals_f64, evecs_f64)

print("\nMean eigenpair residual (||Av - lambda*v||, first 20 pairs):")
print(f"  Float32:         {err_f32:.6e}")
print(f"  Mixed precision: {err_mixed:.6e}")
print(f"  Float64:         {err_f64:.6e}")

## 4. Adaptive Fallback Demonstration

The `AdaptivePrecisionRunner` automatically escalates precision when
energy conservation fails. The fallback chain is:

1. GPU float32 (fastest)
2. GPU float64 (if f32 fails energy check)
3. CPU float64 (most stable, slowest)

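Here we create a runner with a strict 0.5% tolerance. We then run the sweep with the `mixed` strategy and the config-level `energy_check`, with `auto_retry_float64` enabled.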
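The escalation logic described above can be sketched as a loop over the fallback chain. `run_with_fallback`, `max_energy_error`, and `fake_solver` are hypothetical names for illustration, not the `AdaptivePrecisionRunner` API. The toy solver simply pretends that float32 violates energy conservation by 2%.

```python
import numpy as np

FALLBACK_CHAIN = ["gpu-float32", "gpu-float64", "cpu-float64"]  # labels from the list above

def max_energy_error(r, t, a):
    """Max |R + T + A - 1| over all wavelengths."""
    total = np.asarray(r, float) + np.asarray(t, float) + np.asarray(a, float)
    return float(np.max(np.abs(total - 1.0)))

def run_with_fallback(run_fn, tolerance=0.005, chain=FALLBACK_CHAIN):
    """Hypothetical sketch: escalate precision until R + T + A = 1 within tolerance."""
    for level in chain:
        r, t, a = run_fn(level)
        err = max_energy_error(r, t, a)
        if np.isfinite(err) and err <= tolerance:  # NaN/Inf also forces escalation
            return level, (r, t, a), err
    # Every level failed: return the last (most stable) attempt and let the caller decide
    return level, (r, t, a), err

def fake_solver(level):
    """Toy stand-in for a solver run; the drift values are invented for the demo."""
    drift = {"gpu-float32": 0.02, "gpu-float64": 1e-4, "cpu-float64": 1e-9}[level]
    return [0.10], [0.30], [0.60 + drift]

level, rta, err = run_with_fallback(fake_solver)
print(f"accepted at {level}, energy error = {err:.1e}")
```

Stopping at the first level that passes keeps the common case fast. The slow CPU float64 path only runs when cheaper attempts fail.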

In [None]:
# Create an AdaptivePrecisionRunner with tight tolerance
adaptive_runner = AdaptivePrecisionRunner(tolerance=0.005)  # 0.5% tolerance

print(f"Adaptive runner tolerance: {adaptive_runner.tolerance}")
print("Fallback chain: GPU-f32 -> GPU-f64 -> CPU-f64")
print()

# Run the wavelength sweep with the energy check and float64 auto-retry enabled.
# In practice, short wavelengths in structures with metals are the hardest cases.
config_mixed = copy.deepcopy(base_config)
config_mixed["solver"] = {
    "name": "torcwa", "type": "rcwa",
    "params": {"fourier_order": [9, 9]},
    "stability": {
        "precision_strategy": "mixed",
        "allow_tf32": False,
        "fourier_factorization": "li_inverse",
        "energy_check": {
            "enabled": True,
            "tolerance": 0.005,
            "auto_retry_float64": True,
        },
    },
}

print("Running with mixed precision and auto-retry...")
result_mixed = SingleRunner.run(config_mixed)

# Check energy balance against the same 0.5% tolerance used above
tol = adaptive_runner.tolerance
if result_mixed.reflection is not None and result_mixed.transmission is not None:
    total = (np.asarray(result_mixed.reflection, dtype=float)
             + np.asarray(result_mixed.transmission, dtype=float))
    if result_mixed.absorption is not None:
        total = total + np.asarray(result_mixed.absorption, dtype=float)
    max_error = np.max(np.abs(total - 1.0))
    print(f"Max energy error: {max_error:.6f}")
    print(f"Energy check ({tol:.1%} tolerance): {'PASS' if max_error <= tol else 'FAIL'}")

## 5. StabilityDiagnostics: Pre-Simulation Check

Before running a simulation, `StabilityDiagnostics.pre_simulation_check()`
inspects the pixel stack and solver config for potential stability issues.

It checks for:
- Large Fourier order with insufficient precision
- Thick layers that may cause S-matrix overflow
- TF32 being enabled (TF32 keeps only a 10-bit mantissa, far too little precision for RCWA)
- Patterned layers with naive Fourier factorization
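Layer thickness matters numerically because evanescent Fourier harmonics grow or decay as exp(±κd) across a layer. The sketch below uses a simplified 1D, normal-incidence, real-index model, which is an assumption for illustration only. Order 17 matches the risky config below. The 2 µm period (a 2x2 unit cell at 1 µm pitch), 0.45 µm wavelength, and air layer are illustrative values. `evanescent_exponent` is a hypothetical helper, not a COMPASS function.

```python
import numpy as np

# Largest x for which exp(x) is still finite in each precision
limit_f32 = float(np.log(np.finfo(np.float32).max))  # about 88.7
limit_f64 = float(np.log(np.finfo(np.float64).max))  # about 709.8

def evanescent_exponent(order, period_um, wavelength_um, n_layer, thickness_um):
    """kappa * d for a 1D harmonic of the given order (simplified, illustrative model).

    Returns 0 when the harmonic propagates in the layer.
    """
    k0 = 2 * np.pi / wavelength_um
    kx = 2 * np.pi * order / period_um
    kappa = np.sqrt(max(kx**2 - (n_layer * k0) ** 2, 0.0))
    return kappa * thickness_um

print(f"float32 exp() overflows above x = {limit_f32:.1f}; float64 above {limit_f64:.1f}")
for d in (1.0, 2.0, 4.0):
    x = evanescent_exponent(order=17, period_um=2.0, wavelength_um=0.45,
                            n_layer=1.0, thickness_um=d)
    with np.errstate(over="ignore"):
        grows = np.exp(np.float32(x))   # growing term, as in a naive transfer matrix
    decays = np.exp(-np.float32(x))     # decaying term, the kind an S-matrix keeps
    print(f"d = {d} um: kappa*d = {x:6.1f}, float32 exp(+kd) = {grows:.3e}, "
          f"exp(-kd) = {decays:.3e}")
```

S-matrix formulations are arranged so that only the decaying exponentials appear. The growing terms shown here are what a naive transfer-matrix approach would have to represent.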

In [None]:
from compass.materials.database import MaterialDB
from compass.geometry.builder import GeometryBuilder

# Build pixel stack for diagnostics
mat_db = MaterialDB()
builder = GeometryBuilder(base_config["pixel"], mat_db)
pixel_stack = builder.build()

# Test 1: Good config (should have no warnings)
good_config = {
    "name": "torcwa", "type": "rcwa",
    "params": {"fourier_order": [9, 9]},
    "stability": {
        "precision_strategy": "mixed",
        "allow_tf32": False,
        "fourier_factorization": "li_inverse",
    },
}

print("=== Pre-simulation check: Good config ===")
warnings = StabilityDiagnostics.pre_simulation_check(pixel_stack, good_config)
if warnings:
    for w in warnings:
        print(f"  WARNING: {w}")
else:
    print("  No warnings. Config looks stable.")

# Test 2: Risky config (high order with float32, naive factorization)
risky_config = {
    "name": "torcwa", "type": "rcwa",
    "params": {"fourier_order": [17, 17]},
    "stability": {
        "precision_strategy": "float32",
        "fourier_factorization": "naive",
    },
}

print("\n=== Pre-simulation check: Risky config ===")
warnings = StabilityDiagnostics.pre_simulation_check(pixel_stack, risky_config)
for w in warnings:
    print(f"  WARNING: {w}")

## 6. StabilityDiagnostics: Post-Simulation Check

After the simulation, `post_simulation_check()` validates that:
- QE values are in the physical range [0, 1]
- No NaN or Inf values in R, T, A
- Energy conservation holds (R + T + A = 1 within tolerance)
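The three checks listed above can be sketched on plain arrays. `post_check_sketch` is a hypothetical helper, not the `StabilityDiagnostics` implementation. It only mirrors the `{check: {"status", "issue"}}` report shape that the cell below prints, and its 1% default tolerance is an assumption.

```python
import numpy as np

def post_check_sketch(qe, r, t, a, tolerance=0.01):
    """Hypothetical sketch of the three post-simulation checks; returns {name: {status, issue}}."""
    report = {}
    rta = {name: np.asarray(v, dtype=float) for name, v in (("R", r), ("T", t), ("A", a))}
    # 1. No NaN or Inf in R, T, A
    bad = [name for name, v in rta.items() if not np.all(np.isfinite(v))]
    if bad:
        report["finite"] = {"status": "FAIL", "issue": f"NaN/Inf in {', '.join(bad)}"}
    # 2. QE in the physical range [0, 1]
    qe = np.asarray(qe, dtype=float)
    if np.any((qe < 0) | (qe > 1)):
        report["qe_range"] = {"status": "FAIL",
                              "issue": f"QE outside [0, 1]: min={qe.min():.3f}, max={qe.max():.3f}"}
    # 3. Energy conservation; written as `not err <= tol` so that NaN also fails
    err = np.max(np.abs(rta["R"] + rta["T"] + rta["A"] - 1.0))
    if not err <= tolerance:
        report["energy"] = {"status": "FAIL", "issue": f"max |R+T+A-1| = {err:.3g}"}
    return report

ok = post_check_sketch(qe=[0.5, 0.6], r=[0.1, 0.1], t=[0.2, 0.2], a=[0.7, 0.7])
bad = post_check_sketch(qe=[1.2, 0.5], r=[0.1, np.nan], t=[0.2, 0.2], a=[0.7, 0.7])
print(ok)
for key, info in bad.items():
    print(f"  {key}: {info['status']} - {info['issue']}")
```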

In [None]:
def print_report(title, report):
    """Print a post-simulation report: a dict of {check: {"status": ..., "issue": ...}}."""
    print(title)
    if report:
        for key, info in report.items():
            print(f"  {key}: {info['status']} - {info['issue']}")
    else:
        print("  All checks passed. No issues detected.")

# Post-simulation check on our float64 result
report_f64 = StabilityDiagnostics.post_simulation_check(result_f64)
print_report("=== Post-simulation check: float64 result ===", report_f64)

# Post-simulation check on our float32 result
report_f32 = StabilityDiagnostics.post_simulation_check(result_f32)
print_report("\n=== Post-simulation check: float32 result ===", report_f32)

## 7. Energy Conservation vs Fourier Order

Higher Fourier orders create larger matrices that are harder to solve
stably. This plot shows how energy conservation error grows with order
when using float32 vs mixed precision.
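To see how quickly the problem grows, this sketch tabulates problem size per order. A square order [N, N] has (2N+1)^2 harmonics, and the eigenproblem is twice that size (361 harmonics giving 722x722 at N = 9, as in Section 3). `rcwa_sizes` is a hypothetical helper, and the memory column counts only one dense matrix.

```python
import numpy as np

def rcwa_sizes(order):
    """Harmonic count and eigenproblem dimension for a square Fourier order [N, N]."""
    harmonics = (2 * order + 1) ** 2
    dim = 2 * harmonics
    return harmonics, dim

print(f"{'N':>3} {'harmonics':>10} {'matrix':>12} {'c64 MiB':>9} {'c128 MiB':>9}")
for N in [5, 7, 9, 11, 13, 15]:
    harmonics, dim = rcwa_sizes(N)
    # Memory of a single dense dim x dim matrix in each complex dtype
    mib64 = dim * dim * np.dtype(np.complex64).itemsize / 2**20
    mib128 = dim * dim * np.dtype(np.complex128).itemsize / 2**20
    print(f"{N:>3} {harmonics:>10} {f'{dim}x{dim}':>12} {mib64:>9.1f} {mib128:>9.1f}")
```

Dense eigendecomposition cost scales roughly as dim^3. Going from N = 5 to N = 15 therefore multiplies the work by about (1922/242)^3, roughly 500x.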

In [None]:
orders = [5, 7, 9, 11, 13, 15]
energy_err_f32 = []
energy_err_mixed = []

test_config = copy.deepcopy(base_config)
test_config["source"] = {
    "wavelength": {"mode": "single", "value": 0.45},  # Blue: harder for stability
    "polarization": "unpolarized",
}

def energy_error(result):
    """Max |R + T + A - 1| for a result, or NaN if R or T is missing."""
    if result.reflection is None or result.transmission is None:
        return np.nan
    # np.asarray handles scalars, lists, and arrays (float() fails on a list)
    total = (np.asarray(result.reflection, dtype=float)
             + np.asarray(result.transmission, dtype=float))
    if result.absorption is not None:
        total = total + np.asarray(result.absorption, dtype=float)
    return float(np.max(np.abs(total - 1.0)))

for N in orders:
    # Float32
    cfg = copy.deepcopy(test_config)
    cfg["solver"] = {
        "name": "torcwa", "type": "rcwa",
        "params": {"fourier_order": [N, N], "dtype": "complex64"},
        "stability": {"precision_strategy": "float32", "allow_tf32": False,
                       "fourier_factorization": "li_inverse"},
    }
    energy_err_f32.append(energy_error(SingleRunner.run(cfg)))

    # Mixed precision (same config, only the strategy changes)
    cfg["solver"]["stability"]["precision_strategy"] = "mixed"
    energy_err_mixed.append(energy_error(SingleRunner.run(cfg)))

    print(f"Order [{N},{N}]: f32 err={energy_err_f32[-1]:.6f}, "
          f"mixed err={energy_err_mixed[-1]:.6f}")

In [None]:
fig, ax = plt.subplots(figsize=(8, 5))

ax.semilogy(orders, energy_err_f32, "o-", label="Float32",
            color="tab:red", linewidth=2, markersize=8)
ax.semilogy(orders, energy_err_mixed, "s-", label="Mixed precision",
            color="tab:blue", linewidth=2, markersize=8)

ax.axhline(0.01, color="gray", linestyle="--", alpha=0.5, label="1% threshold")
ax.axhline(0.02, color="gray", linestyle=":", alpha=0.5, label="2% threshold")

ax.set_xlabel("Fourier order N (harmonics = (2N+1)^2)")
ax.set_ylabel("|R + T + A - 1| (energy error)")
ax.set_title("Energy Conservation Error vs Fourier Order")
ax.legend()
ax.grid(True, alpha=0.3, which="both")

# Show the number of Fourier harmonics (2N+1)^2 on a secondary x-axis
# (the eigenproblem matrix is 2(2N+1)^2 square)
ax2 = ax.twiny()
ax2.set_xlim(ax.get_xlim())
ax2.set_xticks(orders)
ax2.set_xticklabels([(2*N+1)**2 for N in orders])
ax2.set_xlabel("Number of Fourier harmonics")

fig.tight_layout()

## Summary

This notebook demonstrated the COMPASS numerical stability toolkit:

1. **Float32 vs float64**: Pure float32 is faster but can produce QE errors
   of several percent. Float64 serves as the reference but is slower (roughly 2x; the ratio depends on hardware).

2. **Mixed precision eigendecomp**: Promotes only the eigenvalue problem to
   float64, recovering near-float64 eigenpair accuracy while the rest of the
   pipeline runs in float32. This is the recommended default (`precision_strategy: "mixed"`).

3. **Adaptive fallback**: `AdaptivePrecisionRunner` automatically escalates
   precision (GPU-f32 -> GPU-f64 -> CPU-f64) when energy conservation fails.

4. **StabilityDiagnostics**: Pre-simulation checks catch risky configurations
   (large matrices with low precision, TF32 enabled, naive factorization).
   Post-simulation checks verify QE range and energy conservation.

**Best practices:**
- Always use `precision_strategy: "mixed"` (default)
- Always set `allow_tf32: false` for RCWA
- Use `fourier_factorization: "li_inverse"` for patterned layers, especially high-contrast ones containing metals
- Run `StabilityDiagnostics.pre_simulation_check()` before production sweeps
- Validate energy conservation on every result
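The configuration-related best practices above can be turned into a quick check. `lint_solver_config` is a hypothetical helper, not part of COMPASS. It treats a missing `precision_strategy` as `"mixed"` because the text calls that the default. It flags a missing `allow_tf32` because the text asks for it to be set explicitly.

```python
def lint_solver_config(solver_cfg):
    """Hypothetical helper: list departures from the best practices above."""
    stab = solver_cfg.get("stability", {})
    issues = []
    strategy = stab.get("precision_strategy", "mixed")  # text says "mixed" is the default
    if strategy != "mixed":
        issues.append(f"precision_strategy is {strategy!r}, not 'mixed'")
    if stab.get("allow_tf32") is not False:
        issues.append("allow_tf32 is not explicitly set to False")
    factorization = stab.get("fourier_factorization")
    if factorization != "li_inverse":
        issues.append(f"fourier_factorization is {factorization!r}, not 'li_inverse'")
    return issues

# Stability blocks copied from the good and risky configs in Section 5
good = {"stability": {"precision_strategy": "mixed", "allow_tf32": False,
                      "fourier_factorization": "li_inverse"}}
risky = {"stability": {"precision_strategy": "float32", "fourier_factorization": "naive"}}
for name, cfg in [("good", good), ("risky", risky)]:
    print(name, lint_solver_config(cfg) or "OK")
```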