# Solver Comparison

This notebook runs the same pixel structure through multiple electromagnetic (EM) solvers
and compares their quantum efficiency (QE) predictions.

COMPASS supports the following solvers:
- **RCWA**: torcwa, grcwa, meent
- **FDTD**: flaport

All solvers consume the same `PixelStack` and return a `SimulationResult`,
making cross-solver comparison straightforward.

In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

from compass.analysis.energy_balance import check_energy_balance
from compass.analysis.qe_calculator import compute_qe
from compass.analysis.solver_comparison import compare_results
from compass.core.config_schema import load_config
from compass.geometry.builder import GeometryBuilder
from compass.materials.database import MaterialDB
from compass.solvers.base import SolverFactory
from compass.sources.planewave import PlaneWave

## 1. Setup: Build Pixel Stack

We use a common pixel configuration for all solvers.

In [None]:
config = load_config(Path("../configs/pixel/default_bsi_1um.yaml"))
mat_db = MaterialDB()

builder = GeometryBuilder(config.pixel, mat_db)
pixel_stack = builder.build()

print(f"Pixel: {config.pixel.pitch_x} x {config.pixel.pitch_y} um")
print(f"Stack: {pixel_stack.total_thickness:.2f} um, {len(pixel_stack.layer_slices)} layers")

## 2. Create Solver Instances

The `SolverFactory` creates solver instances by name. Each solver has
its own configuration (harmonic orders, grid resolution, etc.), but all
of them implement the `SolverBase` interface.
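As a rough illustration of how a name-based factory and a shared base class fit together, here is a minimal registry sketch. Every name in it (`ToySolverBase`, `ToySolverFactory`, `DummySolver`, `ToyResult`) is hypothetical and only mimics the shape of `SolverFactory.create(name, config)`; it is not COMPASS's implementation.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass


@dataclass
class ToyResult:
    # Hypothetical stand-in for a solver result: reflectance, transmittance, absorption.
    reflectance: float
    transmittance: float
    absorption: float


class ToySolverBase(ABC):
    """Hypothetical common interface: every solver exposes the same solve() signature."""

    @abstractmethod
    def solve(self, stack, wavelength: float) -> ToyResult: ...


class ToySolverFactory:
    """Hypothetical name -> class registry, in the spirit of SolverFactory.create(name, cfg)."""

    _registry: dict[str, type[ToySolverBase]] = {}

    @classmethod
    def register(cls, name: str):
        # Decorator that records a solver class under a string name.
        def decorator(solver_cls: type[ToySolverBase]) -> type[ToySolverBase]:
            cls._registry[name] = solver_cls
            return solver_cls

        return decorator

    @classmethod
    def create(cls, name: str, **kwargs) -> ToySolverBase:
        # Look up the class by name and instantiate it with solver-specific settings.
        try:
            solver_cls = cls._registry[name]
        except KeyError:
            raise ValueError(f"Unknown solver {name!r}; available: {sorted(cls._registry)}") from None
        return solver_cls(**kwargs)


@ToySolverFactory.register("dummy")
class DummySolver(ToySolverBase):
    def __init__(self, absorption: float = 0.5):
        self.absorption = absorption

    def solve(self, stack, wavelength: float) -> ToyResult:
        # Fixed power split, used only to exercise the interface (ignores stack and wavelength).
        r = 0.1
        return ToyResult(reflectance=r, transmittance=1.0 - r - self.absorption, absorption=self.absorption)
```

For example, `ToySolverFactory.create("dummy", absorption=0.6)` returns a `DummySolver`. Adding a backend then means registering one more subclass; code that only calls `solve()` does not change.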

In [None]:
solver_names = ["torcwa", "grcwa", "meent"]

solvers = {}
for name in solver_names:
    solvers[name] = SolverFactory.create(name, config.solver)
    print(f"Created solver: {name}")
    print(f"  Backend: {solvers[name].backend}")
    print(f"  Harmonics: {solvers[name].num_harmonics}")

## 3. Wavelength Sweep

Run each solver from 400 to 700 nm in 10 nm steps to compute spectral QE,
using normal incidence and unpolarized light.
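For power quantities such as R, T, and A, unpolarized illumination can be modeled as the average of two independent solves with orthogonal linear polarizations (for example TE and TM). Whether `PlaneWave(polarization="unpolarized")` is handled this way inside COMPASS is an assumption. The sketch below uses a hypothetical `solve_one(wavelength, polarization)` callable and a toy model in place of a real solver.

```python
from typing import Callable, Dict

import numpy as np

# Hypothetical signature: returns power quantities (R, T, A) for one linear polarization.
SolveOne = Callable[[float, str], Dict[str, float]]


def solve_unpolarized(solve_one: SolveOne, wavelength: float) -> Dict[str, float]:
    """Average power quantities over two orthogonal linear polarizations."""
    te = solve_one(wavelength, "TE")
    tm = solve_one(wavelength, "TM")
    return {key: 0.5 * (te[key] + tm[key]) for key in te}


def toy_solve_one(wavelength: float, polarization: str) -> Dict[str, float]:
    # Toy model only: polarization-dependent reflectance, fixed absorption.
    r = 0.10 if polarization == "TE" else 0.20
    a = 0.60
    return {"R": r, "T": 1.0 - r - a, "A": a}


toy_wavelengths = np.linspace(0.400, 0.700, 31)  # 400-700 nm in 10 nm steps, in um
toy_spectrum = [solve_unpolarized(toy_solve_one, wl) for wl in toy_wavelengths]
```

A 31-point `np.linspace` grid covers 400-700 nm in 10 nm steps without relying on a floating-point `np.arange` stop value to decide whether the endpoint is included.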

In [None]:
wavelengths = np.linspace(0.400, 0.700, 31)  # 400-700 nm in 10 nm steps, in um

# Store results: solver_name -> {wavelength -> SimulationResult}
all_results = {}

for solver_name, solver in solvers.items():
    print(f"Running {solver_name}...")
    results = {}
    for wl in wavelengths:
        source = PlaneWave(wavelength=wl, theta=0.0, phi=0.0, polarization="unpolarized")
        results[wl] = solver.solve(pixel_stack, source)
    all_results[solver_name] = results
    print(f"  Completed {len(wavelengths)} wavelengths")

## 4. Energy Balance Check

Before comparing QE, verify that each solver conserves energy (R + T + A = 1)
within the expected tolerance (< 1%).
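The balance check reduces to one number per solver: the largest deviation of R + T + A from 1 across the sweep. A minimal NumPy sketch follows; `max_energy_error` and `energy_status` are hypothetical helpers, and it is an assumption that `check_energy_balance` computes something equivalent.

```python
import numpy as np


def max_energy_error(reflectance, transmittance, absorption) -> float:
    """Largest |R + T + A - 1| over a spectrum (inputs are same-length sequences)."""
    r = np.asarray(reflectance, dtype=float)
    t = np.asarray(transmittance, dtype=float)
    a = np.asarray(absorption, dtype=float)
    return float(np.max(np.abs(r + t + a - 1.0)))


def energy_status(max_error: float, tol: float = 0.01) -> str:
    # Absolute tolerance of 0.01 (1%), matching the threshold used in the cell below.
    return "PASS" if max_error < tol else "FAIL"
```

The check is only informative when absorption is computed independently of R and T; if a solver defines A as 1 - R - T, the error is zero by construction.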

In [None]:
for solver_name, results in all_results.items():
    max_error = 0.0
    for _wl, result in results.items():
        energy_error = check_energy_balance(result)
        max_error = max(max_error, abs(energy_error))
    status = "PASS" if max_error < 0.01 else "FAIL"
    print(f"{solver_name}: max energy error = {max_error:.6f} [{status}]")

## 5. Compare QE Spectra

Compute QE for each solver and plot the spectra together. `compute_qe` returns QE
per pixel; for a single-pixel unit cell, pixel (0, 0) holds the total QE, and that is the value compared across solvers.
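The next cell indexes the output of `compute_qe` as `qe[0, 0]`, which suggests a 2D per-pixel array; that shape is inferred from the indexing rather than documented here. Under that assumption, stacking the maps for all wavelengths into one `(n_wavelengths, ny, nx)` array turns a single pixel's spectrum into one slice. `pixel_spectrum` is a hypothetical helper.

```python
import numpy as np


def pixel_spectrum(qe_maps, row: int = 0, col: int = 0) -> np.ndarray:
    """Stack per-wavelength QE maps of shape (ny, nx) and return one pixel's spectrum."""
    cube = np.stack([np.asarray(m, dtype=float) for m in qe_maps])  # (n_wavelengths, ny, nx)
    return cube[:, row, col]


# Toy data: three wavelengths, 2x2 pixel maps.
toy_maps = [
    np.full((2, 2), 0.5),
    np.full((2, 2), 0.6),
    np.array([[0.7, 0.1], [0.1, 0.1]]),
]
toy_pixel_qe = pixel_spectrum(toy_maps)
```

With the notebook's variables, `pixel_spectrum([compute_qe(results[wl], pixel_stack) for wl in wavelengths])` would produce the same values as the per-wavelength loop in the next cell.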

In [None]:
# Compute QE spectra for pixel (0,0)
qe_spectra = {}

for solver_name, results in all_results.items():
    qe_values = []
    for wl in wavelengths:
        qe = compute_qe(results[wl], pixel_stack)
        qe_values.append(qe[0, 0])  # pixel (0,0)
    qe_spectra[solver_name] = np.array(qe_values)

In [None]:
# Plot QE comparison
fig, axes = plt.subplots(2, 1, figsize=(10, 8), gridspec_kw={"height_ratios": [3, 1]})

# Top: QE spectra
colors = {"torcwa": "tab:blue", "grcwa": "tab:orange", "meent": "tab:green"}
for solver_name, qe in qe_spectra.items():
    axes[0].plot(
        wavelengths * 1000,  # convert um to nm for display
        qe,
        label=solver_name,
        color=colors[solver_name],
        linewidth=2,
    )

axes[0].set_ylabel("Quantum Efficiency")
axes[0].set_title("Solver Comparison: QE Spectrum (normal incidence, unpolarized)")
axes[0].legend()
axes[0].set_ylim(0, 1)
axes[0].grid(True, alpha=0.3)

# Bottom: difference of each solver from the torcwa reference
reference = qe_spectra["torcwa"]
for solver_name, qe in qe_spectra.items():
    if solver_name == "torcwa":
        continue
    diff = qe - reference
    axes[1].plot(
        wavelengths * 1000,
        diff,
        label=f"{solver_name} - torcwa",
        color=colors[solver_name],
        linewidth=2,
    )

axes[1].set_xlabel("Wavelength (nm)")
axes[1].set_ylabel("QE Difference")
axes[1].axhline(0, color="gray", linestyle="--", linewidth=0.5)
axes[1].legend()
axes[1].grid(True, alpha=0.3)

fig.tight_layout()

## 6. Reflectance, Transmittance, Absorption Comparison

Compare the full R/T/A breakdown across solvers.
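Overlaid curves make large disagreements obvious, but a per-wavelength spread (maximum minus minimum across solvers) gives a single number to quote for each of R, T, and A. A small sketch with a hypothetical helper and toy data:

```python
import numpy as np


def solver_spread(values_by_solver: dict) -> np.ndarray:
    """Per-wavelength spread (max - min) of one quantity across solvers."""
    stacked = np.vstack([np.asarray(v, dtype=float) for v in values_by_solver.values()])
    return stacked.max(axis=0) - stacked.min(axis=0)


# Toy reflectance values at three wavelengths for three hypothetical solvers.
toy_reflectance = {
    "solver_a": [0.10, 0.12, 0.15],
    "solver_b": [0.11, 0.12, 0.13],
    "solver_c": [0.10, 0.14, 0.15],
}
toy_spread = solver_spread(toy_reflectance)
```

With the notebook's data, the input for reflectance would be `{name: [res[wl].reflectance for wl in wavelengths] for name, res in all_results.items()}`, assuming each `reflectance` is a scalar as the plots imply.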

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(16, 5))

quantities = [
    ("Reflectance", lambda r: r.reflectance),
    ("Transmittance", lambda r: r.transmittance),
    ("Absorption", lambda r: r.absorption),
]

for ax, (qty_name, getter) in zip(axes, quantities):
    for solver_name, results in all_results.items():
        values = [getter(results[wl]) for wl in wavelengths]
        ax.plot(wavelengths * 1000, values, label=solver_name,
                color=colors[solver_name], linewidth=2)
    ax.set_xlabel("Wavelength (nm)")
    ax.set_ylabel(qty_name)
    ax.set_title(qty_name)
    ax.legend()
    ax.grid(True, alpha=0.3)

fig.suptitle("R / T / A Comparison Across Solvers", fontsize=14)
fig.tight_layout()

## 7. Statistical Comparison

Use the built-in `compare_results` utility to compute agreement metrics
between solver pairs.
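`compare_results` reports a maximum absolute difference, an RMS difference, and R² for each pair. Its exact definitions are not shown in this notebook; the sketch assumes the common forms, with the first spectrum treated as the reference for R² = 1 - Σ(b - a)² / Σ(a - mean(a))². `pair_metrics` is a hypothetical helper.

```python
import numpy as np


def pair_metrics(reference, other) -> dict:
    """Agreement metrics between two spectra sampled on the same wavelength grid."""
    a = np.asarray(reference, dtype=float)
    b = np.asarray(other, dtype=float)
    diff = b - a
    ss_res = float(np.sum(diff**2))
    ss_tot = float(np.sum((a - a.mean()) ** 2))
    return {
        "max_abs_diff": float(np.max(np.abs(diff))),
        "rms_diff": float(np.sqrt(np.mean(diff**2))),
        # Coefficient of determination with `reference` treated as ground truth;
        # undefined (NaN) when the reference spectrum is flat.
        "r_squared": 1.0 - ss_res / ss_tot if ss_tot > 0 else float("nan"),
    }
```

For instance, `pair_metrics(qe_spectra["torcwa"], qe_spectra["grcwa"])` would score grcwa against torcwa on the notebook's wavelength grid. Because R² is normalized by the variance of the reference spectrum, it is less informative for nearly flat spectra, where the absolute and RMS differences are the better guide.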

In [None]:
# Compare all solver pairs
comparison = compare_results(all_results, wavelengths, pixel_stack)

print("Pairwise QE comparison (pixel 0,0):")
print(f"{'Pair':<25} {'Max Abs Diff':>15} {'RMS Diff':>15} {'R-squared':>12}")
print("-" * 70)

for pair_name, metrics in comparison.items():
    print(f"{pair_name:<25} {metrics['max_abs_diff']:>15.6f} "
          f"{metrics['rms_diff']:>15.6f} {metrics['r_squared']:>12.8f}")

## 8. Convergence Study: Harmonic Orders

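For RCWA solvers, the number of harmonic orders (Fourier terms) controls
the accuracy-performance tradeoff: more orders represent the fields and permittivity more accurately but make each solve more expensive. Here we sweep the harmonic count at 550 nm and check where QE stops changing.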
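Both sides of this tradeoff fit in a few lines. For a 2D-periodic structure with a rectangular truncation at diffraction orders -Nx..Nx and -Ny..Ny, RCWA keeps M = (2Nx + 1)(2Ny + 1) Fourier harmonics, and each layer's eigenproblem has size 2M, so cost grows roughly as M^3. How COMPASS maps `solver.num_harmonics` onto Nx and Ny (or onto a total harmonic count) is not stated in this notebook, so that mapping is an assumption. The convergence helper is a simple successive-difference test with an assumed tolerance; all three function names are hypothetical.

```python
from typing import Optional, Sequence


def num_fourier_harmonics(nx: int, ny: int) -> int:
    """Harmonics kept by a rectangular truncation at orders -nx..nx and -ny..ny."""
    return (2 * nx + 1) * (2 * ny + 1)


def relative_eig_cost(nx: int, ny: int, base_nx: int = 1, base_ny: int = 1) -> float:
    """Eigensolve cost relative to a baseline truncation, assuming O(M^3) scaling."""
    ratio = num_fourier_harmonics(nx, ny) / num_fourier_harmonics(base_nx, base_ny)
    return ratio**3


def first_converged_order(
    orders: Sequence[int], values: Sequence[float], tol: float = 1e-3
) -> Optional[int]:
    """Smallest order from which every later step in the sweep changes the value by < tol.

    Returns None if the last step is still at least tol (not yet converged).
    """
    if len(orders) != len(values):
        raise ValueError("orders and values must have the same length")
    converged_from = None
    # Walk backwards from the highest order, extending the converged tail
    # while each successive change stays below the tolerance.
    for i in range(len(values) - 1, 0, -1):
        if abs(values[i] - values[i - 1]) < tol:
            converged_from = orders[i - 1]
        else:
            break
    return converged_from
```

After the sweep in the next cell, `first_converged_order(harmonic_orders, convergence["torcwa"])` would report the order from which torcwa's QE at 550 nm changes by less than 1e-3 per step. A small per-step change does not bound the total remaining error, so comparing against the highest order in the sweep is a useful second check.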

In [None]:
from compass.core.config_schema import override_config

harmonic_orders = [3, 5, 7, 9, 11, 15]
test_wavelength = 0.550  # um
source = PlaneWave(wavelength=test_wavelength, theta=0.0, phi=0.0, polarization="unpolarized")

convergence = {name: [] for name in solver_names}

for n_harm in harmonic_orders:
    cfg_override = override_config(config, {"solver.num_harmonics": n_harm})
    for solver_name in convergence:
        solver = SolverFactory.create(solver_name, cfg_override.solver)
        result = solver.solve(pixel_stack, source)
        qe = compute_qe(result, pixel_stack)
        convergence[solver_name].append(qe[0, 0])
        print(f"{solver_name} N={n_harm}: QE={qe[0, 0]:.6f}")

In [None]:
fig, ax = plt.subplots(figsize=(8, 5))

for solver_name, qe_values in convergence.items():
    ax.plot(harmonic_orders, qe_values, "o-", label=solver_name,
            color=colors[solver_name], linewidth=2, markersize=6)

ax.set_xlabel("Number of Harmonic Orders")
ax.set_ylabel("QE at 550 nm")
ax.set_title("RCWA Convergence: QE vs Harmonic Orders")
ax.legend()
ax.grid(True, alpha=0.3)
fig.tight_layout()

## Summary

This notebook demonstrated the COMPASS multi-solver comparison workflow:

1. **Same pixel stack** fed to multiple solvers via `SolverFactory`
2. **Wavelength sweep** with `PlaneWave` source for spectral QE
3. **Energy balance** verification (R + T + A = 1 within 1%)
4. **QE comparison** plots with pairwise difference analysis
5. **R/T/A breakdown** per solver
6. **Statistical metrics** via `compare_results()`
7. **Convergence study** for RCWA harmonic orders

For automated batch comparisons, see `scripts/compare_solvers.py`, which uses
`ComparisonRunner` to orchestrate multi-solver sweeps from a single config.