# Emulate full spectrum using PCA

GP emulation of 6S for a complete spectrum using PCA.

**Author:** Brian Schubert &lt;<schubert.b@northeastern.edu>&gt;

**Date:** 28 August 2023


In [None]:
import copy
import itertools
import math
import pathlib
import random
from typing import Final

import alive_progress
import matplotlib.pyplot as plt
import numpy as np
import rtm_wrapper.parameters as rtm_param
import scipy.stats.qmc as sci_qmc
import sklearn.base as skl_base
import sklearn.decomposition as skl_decomp
import sklearn.gaussian_process as skl_gp
import sklearn.pipeline
import sklearn.preprocessing as skl_pre
import xarray as xr 
from rtm_wrapper.engines.sixs import PySixSEngine, pysixs_default_inputs
from rtm_wrapper.execution import ConcurrentExecutor
from rtm_wrapper.simulation import SweepSimulation

from scratch_emulator import sweep_hash, unit2range, brute_maximin

# Worker count passed as ``max_workers`` to the ConcurrentExecutor instances
# that run the 6S sweeps below.
MAX_SWEEP_WORKERS: Final = 24

## Set wavelength and input parameter ranges

In [None]:
# Fixed spectrum to simulate: 0.25 um up to (but excluding) 4 um in 0.0025 um steps.
# Note: even though 6S's lower bound is 0.2 um, it generates a handful of nan values
# below 0.25 um, so the grid starts at 0.25 um.
WAVELENGTHS: Final = np.arange(0.25, 4, 0.0025)  # micrometers

# Atmosphere parameter ranges to simulate, each given as a (low, high) pair.
OZONE_RANGE: Final = (0.25, 0.45)  # cm-atm
WATER_RANGE: Final = (1, 4)  # g/cm^2
AOT_RANGE: Final = (0.05, 0.5)  # unit 1 (dimensionless)
ZENITH_COSINE_RANGE: Final = (0.5, 1)  # unit 1 (dimensionless)

# Swept input parameters keyed by their parameter path.
# The insertion order matters: it fixes the column order of the LHS samples
# and of the emulator's input feature arrays built later in this file.
INPUT_RANGES: Final = {
    "atmosphere.ozone": OZONE_RANGE,
    "atmosphere.water": WATER_RANGE,
    "aerosol_profile.aot": AOT_RANGE,
    "geometry.solar_zenith.cosine": ZENITH_COSINE_RANGE,
}

# Model output to emulate; used as the data-variable name when reading sweep results.
target_output: Final = "total_transmission"

## Define base 6S inputs

In [None]:
# Start from the default PySixS inputs and replace the atmosphere, aerosol
# profile and solar zenith entries with the parameter variants used by the
# sweeps in this file.
_input_overrides = {
    "atmosphere": rtm_param.AtmosphereWaterOzone(),
    "aerosol_profile": rtm_param.AerosolAOTSingleLayer(
        profile="Maritime", height=100
    ),
    "geometry__solar_zenith": rtm_param.AngleCosineParameter(),
}
base_inputs = pysixs_default_inputs().replace(**_input_overrides)


def param_rich_name(param_name: str) -> str:
    """Build a display label for an input parameter.

    The parameter's metadata is looked up on ``base_inputs``. The label is the
    metadata ``title`` (falling back to ``param_name``) followed by the
    metadata ``unit`` (falling back to ``?``) wrapped in ``$...$`` and
    parentheses.
    """
    metadata = base_inputs.get_metadata(param_name)
    title = metadata.get("title", param_name)
    unit = metadata.get("unit", "?")
    return f"{title} (${unit}$)"

# Run true 6S simulation

## Sample atmosphere input ranges

In [None]:
# Size of the Latin hypercube design.
NUM_SAMPLES: Final = 100

# Seeded generator so the same design is drawn on every run.
rng = np.random.default_rng(2023_09_01)
lhs_sampler = sci_qmc.LatinHypercube(d=len(INPUT_RANGES), seed=rng)
# One column of unit-interval samples per entry of INPUT_RANGES.
raw_samples = lhs_sampler.random(NUM_SAMPLES)

# Alternative design kept for reference:
# raw_samples = brute_maximin(
#         NUM_SAMPLES, len(INPUT_RANGES), iterations=10_000, pick="min", metric="euclidean", rng=rng
# )  # metric=lambda u,v: np.min(np.abs(u-v))

# Map each unit-interval column onto the range of its input parameter.
input_samples = {}
for sample_column, (input_name, (range_low, range_high)) in enumerate(
    INPUT_RANGES.items()
):
    input_samples[input_name] = unit2range(
        raw_samples[:, sample_column], range_low, range_high
    )

## Plot atmosphere input samples

In [None]:
# One scatter panel per unordered pair of input parameters, arranged on a
# roughly square grid.
param_combos = list(itertools.combinations(INPUT_RANGES.keys(), r=2))
ncols = math.floor(math.sqrt(len(param_combos)))
nrows = math.ceil(len(param_combos) / ncols)

# squeeze=False keeps ``axs`` a 2D array even for a 1x1 grid, so ``axs.flat``
# below works for any number of parameter pairs.
fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(8, 8), squeeze=False)

for ax, (param_x, param_y) in zip(axs.flat, param_combos):
    ax.scatter(input_samples[param_x], input_samples[param_y], s=15)
    # Show the full sampled range of each parameter, not just the data extent.
    ax.set_xlim(INPUT_RANGES[param_x])
    ax.set_ylim(INPUT_RANGES[param_y])
    ax.set_xlabel(param_rich_name(param_x))
    ax.set_ylabel(param_rich_name(param_y))

# Hide grid cells left over when the number of pairs does not fill the grid.
for unused_ax in axs.flat[len(param_combos):]:
    unused_ax.set_visible(False)

fig.suptitle("Atmosphere Input LHS Samples")
fig.tight_layout()

## Perform simulation

In [None]:
# Sweep specification: the LHS input samples under the "lhs" key plus the
# wavelength grid, applied on top of the base 6S inputs.
train_sweep = SweepSimulation(
    {"lhs": input_samples, "wavelength.value": WAVELENGTHS},
    base=base_inputs,
)

# Results are cached on disk under a file name derived from the sweep's hash.
train_sweep_path = pathlib.Path(f"sweep_{sweep_hash(train_sweep)[:10]}.nc")

if not train_sweep_path.exists():
    # No cached results: run the sweep and save its results for next time.
    engine = PySixSEngine()
    runner = ConcurrentExecutor(max_workers=MAX_SWEEP_WORKERS)

    with alive_progress.alive_bar(train_sweep.sweep_size, force_tty=True) as bar:
        # Advance the progress bar once per step callback.
        runner.run(train_sweep, engine, step_callback=lambda _: bar())

    train_results = runner.collect_results()
    train_results.to_netcdf(train_sweep_path)
    print(f"Saved sweep results to '{train_sweep_path}'")
else:
    print(f"Loading sweep results from '{train_sweep_path}'")
    train_results = xr.load_dataset(train_sweep_path)

# Keep only the output variable being emulated.
train_output = train_results.data_vars[target_output]
display(train_results)

## Assess performance vs number of PCs

In [None]:
# Centre the spectra (no variance scaling), then project onto principal components.
pca_pipe = sklearn.pipeline.Pipeline(
    [
        ("scale", skl_pre.StandardScaler(with_std=False)),
        # whiten=True - scale down PC components by singular values so that the
        # output features are isotropic.
        ("pca", skl_decomp.PCA(n_components=None, whiten=True)),
    ]
)

# Candidate component counts: 1 up to (but excluding) the smallest of a quarter
# of the wavelength count, the number of training samples, and 20.
max_test_components = min(len(WAVELENGTHS) // 4, NUM_SAMPLES, 20)
test_components = list(range(1, max_test_components))

rmse_vs_comps = []
for num_components in test_components:
    # Refit with this many components, then measure how well the training
    # spectra survive a projection / reconstruction round trip.
    pca_pipe.set_params(pca__n_components=num_components)
    pca_pipe.fit(train_output)

    proj = pca_pipe.transform(train_output)
    round_trip = pca_pipe.inverse_transform(proj)
    err = train_output - round_trip

    rmse_vs_comps.append(np.sqrt(np.mean(err**2)))

fig, axs = plt.subplots(ncols=2, figsize=(10, 6))
rmse_ax, variance_ax = axs

# Left panel: reconstruction RMSE on a log scale.
rmse_ax.semilogy(test_components, rmse_vs_comps, "x-")
rmse_ax.set_ylabel("RMSE")
rmse_ax.set_xlabel("Number of components")

# Right panel: cumulative explained variance. ``pca_pipe`` is still fitted with
# the last (largest) candidate count from the loop above, so its variance
# ratios line up one-to-one with ``test_components``.
variance_ax.plot(
    test_components,
    np.cumsum(pca_pipe.named_steps["pca"].explained_variance_ratio_),
    "x-",
)
variance_ax.set_ylabel("Cumulative fraction of explained variance")
variance_ax.set_xlabel("Number of components")
variance_ax.set_ylim(0, 1)

fig.suptitle("PCA Performance vs Number of Components")

## Fix number of PCs

In [None]:
# Number of principal components the emulator is trained on.
NUM_TRAIN_COMPONENTS: Final = 8
pca_pipe.set_params(pca__n_components=NUM_TRAIN_COMPONENTS)
pca_pipe.fit(train_output)

fitted_pca = pca_pipe.named_steps["pca"]
print(fitted_pca.singular_values_)
print(fitted_pca.explained_variance_)

# One printed row per wavelength, listing that wavelength's entry in each component.
for wavelength, pc_contrib in zip(WAVELENGTHS, fitted_pca.components_.T):
    formatted_contrib = " ".join(f"{c:7.4f}" for c in pc_contrib)
    print(f" {wavelength*1e3:6.1f}nm: {formatted_contrib}")

# Train Emulator

## Extract training arrays

In [None]:
# Emulator inputs, shape (examples) x (features): one column per swept
# parameter, in INPUT_RANGES order, read from the sweep result coordinates.
input_columns = [
    train_output.coords[parameter].values for parameter in INPUT_RANGES
]
x_train = np.stack(input_columns, axis=-1)

# Emulator targets, shape (examples) x (pcs): training spectra projected onto
# the fitted principal components.
y_train = pca_pipe.transform(train_output)
print(f"{x_train.shape=}, {y_train.shape=}")

## Create GP model

In [None]:
# Scaled RBF kernel. A noise term could be added with
# ``+ skl_gp.kernels.WhiteKernel()``.
kernel = 1.0 * skl_gp.kernels.RBF()
gaussian_process = skl_gp.GaussianProcessRegressor(
    kernel=kernel,
    alpha=1e-1,  # alpha=1 was also tried
    n_restarts_optimizer=20,
    # Normalize targets to zero means, unit variance.
    normalize_y=True,
)

# Single-step pipeline. Input preprocessing steps that were tried and are
# currently disabled:
#   ("scale", skl_pre.MinMaxScaler())                      - rescale inputs to [0, 1]
#   ("normalize", skl_pre.StandardScaler(with_std=False))  - rescale to zero mean
pipeline = sklearn.pipeline.Pipeline([("gp", gaussian_process)])
display(pipeline)
display(pipeline.named_steps["gp"].kernel.hyperparameters)

# One independent, unfitted copy of the pipeline per principal component.
pc_models = []
for _ in range(NUM_TRAIN_COMPONENTS):
    pc_models.append(skl_base.clone(pipeline))

## Fit model

In [None]:
# Fit each model on the column of PC scores it is responsible for, then report
# the kernel found by the optimizer.
for pc_idx, model in enumerate(pc_models):
    pc_targets = y_train[:, pc_idx]
    model.fit(x_train, pc_targets)
    fitted_kernel = model.named_steps["gp"].kernel_
    print(f"PC {pc_idx}: {fitted_kernel}")

## Plot marginal likelihood surface

In [None]:
# Indices of the two kernel hyperparameters to vary and plot MLL over.
# Every use below goes through these indices, so the plots stay valid when the
# kernel has more than two hyperparameters (the others are held at their fit
# values).
plot_hyper_idx = [0, 1]
hyper_idx_0, hyper_idx_1 = plot_hyper_idx

for pc_idx, model in enumerate(pc_models):
    fig = plt.figure(figsize=(10, 5), layout="constrained")
    # Extract fit hyperparameter values (natural-log scaled, see np.exp below).
    gp = model.named_steps["gp"]
    fit_theta = gp.kernel_.theta
    fit_theta_0 = fit_theta[hyper_idx_0]
    fit_theta_1 = fit_theta[hyper_idx_1]

    # Hyperparameter ranges to compute marginal likelihood over.
    # Natural log scaled, spanning five decades either side of the fit
    # hyperparameter values found above.
    log_sweep_0 = np.log(10) * np.linspace(-5, 5, 60) + fit_theta_0
    log_sweep_1 = np.log(10) * np.linspace(-5, 5, 60) + fit_theta_1

    mesh_hyper_0, mesh_hyper_1 = np.meshgrid(log_sweep_0, log_sweep_1)
    # Preallocate array for likelihood at each hyperparameter combination.
    log_marginal_likelihoods = np.zeros(mesh_hyper_0.shape)

    # Compute MLL for each hyperparameter combination.
    for hyper_0, hyper_1, out in np.nditer(
        [mesh_hyper_0, mesh_hyper_1, log_marginal_likelihoods],
        op_flags=[["readonly"], ["readonly"], ["writeonly"]],
    ):
        theta = fit_theta.copy()
        theta[hyper_idx_0] = hyper_0
        theta[hyper_idx_1] = hyper_1
        out[...] = gp.log_marginal_likelihood(theta)

    # Plot MLL contours.
    ax = fig.add_subplot(1, 2, 1)
    ax.set_xscale("log")
    ax.set_yscale("log")
    # Pick contour levels. Increase level density near max to better show peaks.
    peak_switch = np.percentile(log_marginal_likelihoods, 85)
    levels = np.hstack(
        (
            np.linspace(log_marginal_likelihoods.min(), peak_switch, 40)[:-1],
            np.linspace(peak_switch, log_marginal_likelihoods.max(), 5),
        )
    )
    # levels = 30
    ax.contour(
        np.exp(mesh_hyper_0), np.exp(mesh_hyper_1), log_marginal_likelihoods, levels
    )
    # Mark the fit values of the two plotted hyperparameters.
    ax.plot(np.exp(fit_theta_0), np.exp(fit_theta_1), "x")
    ax.set_xlabel("magnitude scale")
    ax.set_ylabel("length scale")

    # Plot 3D MLL surface, with axes in log10 of the hyperparameter values.
    ax = fig.add_subplot(1, 2, 2, projection="3d")
    ax.computed_zorder = False  # Prevent surface from hiding point, https://stackoverflow.com/q/51241367/11082165
    ax.view_init(elev=30, azim=-135)
    ax.scatter(
        [fit_theta_0 / np.log(10)],
        [fit_theta_1 / np.log(10)],
        [gp.log_marginal_likelihood(fit_theta)],
        c="r",
        s=5,
        zorder=2,
    )
    ax.plot_surface(
        mesh_hyper_0 / np.log(10),
        mesh_hyper_1 / np.log(10),
        log_marginal_likelihoods,
        # cmap="coolwarm",
        zorder=1,
    )
    ax.contour(
        mesh_hyper_0 / np.log(10),
        mesh_hyper_1 / np.log(10),
        log_marginal_likelihoods,
        levels=levels,
        zorder=3,
    )

    ax.set_xlabel("log10(magnitude scale)")
    ax.set_ylabel("log10(length scale)")
    # ax.set_zlabel("log mll")
    fig.suptitle(f"PC {pc_idx}: Marginal Likelihood vs Hyperparameters")

# Assess Emulator

## Generate test data

In [None]:
# Number of evenly spaced test values per input parameter.
grid_size = 8

# Evenly spaced values spanning each parameter's range, endpoints included.
dense_input_test = {}
for param_name, (range_low, range_high) in INPUT_RANGES.items():
    dense_input_test[param_name] = np.linspace(range_low, range_high, grid_size)

## Obtain actual sim results for test data

In [None]:
# Sweep specification: every dense test axis as its own sweep key, plus the
# wavelength grid, applied on top of the base 6S inputs.
test_sweep_spec = dict(dense_input_test)
test_sweep_spec["wavelength.value"] = WAVELENGTHS
test_sweep = SweepSimulation(test_sweep_spec, base=base_inputs)

# Results are cached on disk under a file name derived from the sweep's hash.
test_sweep_path = pathlib.Path(f"sweep_{sweep_hash(test_sweep)[:10]}.nc")

if not test_sweep_path.exists():
    # No cached results: run the sweep and save its results for next time.
    engine = PySixSEngine()
    runner = ConcurrentExecutor(max_workers=MAX_SWEEP_WORKERS)

    with alive_progress.alive_bar(test_sweep.sweep_size, force_tty=True) as bar:
        # Advance the progress bar once per step callback.
        runner.run(test_sweep, engine, step_callback=lambda _: bar())

    test_results = runner.collect_results()
    test_results.to_netcdf(test_sweep_path)
    print(f"Saved sweep results to '{test_sweep_path}'")
else:
    print(f"Loading sweep results from '{test_sweep_path}'")
    test_results = xr.load_dataset(test_sweep_path)

# Keep only the output variable being emulated.
test_output = test_results.data_vars[target_output]
display(test_results)

## Extract test arrays

In [None]:
# Coordinate grids over the test inputs; "ij" indexing keeps the axis order
# equal to the order of ``dense_input_test``.
dense_input_meshes = np.meshgrid(*dense_input_test.values(), indexing="ij")
# The flattening below relies on wavelength being the trailing dimension.
assert test_output.dims[-1] == "wavelength.value"

# Test inputs: one row per grid point, one column per input parameter.
x_test = np.column_stack([mesh.ravel() for mesh in dense_input_meshes])
# True spectra: one row per grid point, one column per wavelength.
num_wavelengths = len(WAVELENGTHS)
y_test_wl = test_output.values.reshape(-1, num_wavelengths)
# The same spectra expressed as PC scores.
y_test_pc = pca_pipe.transform(y_test_wl)

print(f"{x_test.shape=}, {y_test_wl.shape=} {y_test_pc.shape=}")

## Evaluate model on test data

In [None]:
# Shape of the test input grid; every per-PC result below is reshaped onto it.
grid_shape = dense_input_meshes[0].shape

# Per-PC lists of posterior mean, posterior std, error (true - predicted) and
# true PC score, each entry shaped like the test grid.
pc_pred_means, pc_pred_stds, pc_pred_errors, pc_y_shaped = [], [], [], []

for pc_index, model in enumerate(pc_models):
    pc_truth = y_test_pc[:, pc_index]
    mean, std = model.predict(x_test, return_std=True)

    pc_pred_means.append(mean.reshape(grid_shape))
    pc_pred_stds.append(std.reshape(grid_shape))
    pc_pred_errors.append((pc_truth - mean).reshape(grid_shape))
    pc_y_shaped.append(pc_truth.reshape(grid_shape))

## Compute metrics on PC output feature

In [None]:
# Report absolute and relative error statistics for each PC's predictions.
for pc_idx in range(NUM_TRAIN_COMPONENTS):
    print(f"PC {pc_idx}:")

    pc_errors = pc_pred_errors[pc_idx]
    rmse = np.sqrt(np.mean(pc_errors**2))
    abs_error = np.abs(pc_errors)

    print(f"  RMSE: {rmse:0.2f}")
    print(f"  Avg abs err: {np.mean(abs_error):0.2f}")
    print(f"  Max abs err: {np.max(abs_error):0.2f}")

    # Relative error is taken against the magnitude of the true PC score.
    rel_error = abs_error / np.abs(pc_y_shaped[pc_idx])
    print(f"  Avg rel err: {np.mean(rel_error):0.2%}")
    print(f"  Max rel err: {np.max(rel_error):0.2%}")

## Plot posterior mean, std, error for PC output feature

In [None]:
# Parameter names and every unordered pair of parameter indices; one plot row
# is drawn per pair.
param_names = list(INPUT_RANGES.keys())
param_idx_combos = list(itertools.combinations(range(len(INPUT_RANGES)), r=2))

for pc_index in range(NUM_TRAIN_COMPONENTS):
    fig, axs = plt.subplots(
        nrows=len(param_idx_combos),
        ncols=5,
        figsize=(16, 3 * len(param_idx_combos)),
        # sharex="row",
        # sharey="row",
        layout="constrained",
    )

    for ax_row, (param_x_idx, param_y_idx) in zip(axs, param_idx_combos):
        name_x = param_names[param_x_idx]
        name_y = param_names[param_y_idx]

        local_mesh_x, local_mesh_y = np.meshgrid(
            dense_input_test[name_x],
            dense_input_test[name_y],
            indexing="ij",
        )

        # Average over every input axis except the two being plotted.
        other_dims = tuple(
            i for i in range(len(INPUT_RANGES)) if i not in (param_x_idx, param_y_idx)
        )

        pred_mean_only = pc_pred_means[pc_index].mean(axis=other_dims)
        y_test_only = pc_y_shaped[pc_index].mean(axis=other_dims)
        pred_std_only = pc_pred_stds[pc_index].mean(axis=other_dims)
        pred_error_only = pc_pred_errors[pc_index].mean(axis=other_dims)

        # Shared colour limits so predicted and true surfaces are comparable.
        vmin = min(pred_mean_only.min(), y_test_only.min())
        vmax = max(pred_mean_only.max(), y_test_only.max())
        shared_limits = {"vmin": vmin, "vmax": vmax}

        # Surfaces for the first four columns, left to right: posterior mean,
        # true output, posterior std, error. Only the first two share limits.
        panels = [
            (pred_mean_only, shared_limits),
            (y_test_only, shared_limits),
            (pred_std_only, {}),
            (pred_error_only, {}),
        ]
        for ax, (surface, color_limits) in zip(ax_row, panels):
            art = ax.pcolormesh(local_mesh_x, local_mesh_y, surface, **color_limits)
            # Overlay the training sample locations.
            ax.plot(
                input_samples[name_x],
                input_samples[name_y],
                "o",
                color="k",
                markerfacecolor="none",
            )
            fig.colorbar(art)

        ax_row[0].set_ylabel(
            f"{param_rich_name(name_x)}\nvs\n{param_rich_name(name_y)}"
        )

        # Last column: histogram of the averaged error divided by the averaged std.
        ax = ax_row[4]
        residue_norm = pred_error_only / pred_std_only
        counts, bins = np.histogram(residue_norm.flat, bins=20)
        ax.stairs(counts, bins)
        ax.set_xlim(-4, 4)

    column_titles = (
        "Posterior Mean",
        "True Output",
        "Posterior Std",
        "Error",
        "Residue/Std Hist",
    )
    for column, column_title in enumerate(column_titles):
        axs[0, column].set_title(column_title)
    fig.suptitle(f"PC {pc_index} Emulator Performance")

In [None]:
import scipy.stats as sci_stats

# One row of diagnostics per PC: residual histogram, PIT histogram, Q-Q plot.
fig, axs = plt.subplots(
    nrows=NUM_TRAIN_COMPONENTS,
    ncols=3,
    figsize=(8.5, 3 * NUM_TRAIN_COMPONENTS),
    # sharex="row",
    # sharey="row",
    layout="constrained",
)
# Standard normal reference distribution for the normalised residuals.
r = sci_stats.norm()

num_bins = 40

for ax_row, pc_index in zip(axs, range(NUM_TRAIN_COMPONENTS)):
    hist_ax, pit_ax, qq_ax = ax_row
    # Prediction error divided by the predicted std at each test point.
    residue_norm = pc_pred_errors[pc_index] / pc_pred_stds[pc_index]

    # Residual density against the reference pdf.
    hist_count, hist_bins = np.histogram(residue_norm, bins=num_bins, density=True)
    hist_ax.stairs(hist_count, hist_bins)
    hist_range = np.linspace(hist_bins[0], hist_bins[-1], 1000)
    hist_ax.plot(hist_range, r.pdf(hist_range), "--", color="tab:orange")
    hist_ax.set_title("Residue Histogram")
    hist_ax.set_xlabel("residue / std")
    hist_ax.set_ylabel("Relative Frequency")
    hist_ax.set_xlim(-5, 5)
    hist_ax.set_ylim(0, 0.6)

    # Density of the reference cdf evaluated at the residuals, against the
    # flat line at 1 drawn as "Expected".
    pit_counts, pit_bins = np.histogram(
        r.cdf(residue_norm), bins=num_bins, density=True
    )
    pit_ax.stairs(pit_counts, pit_bins, label="Actual")
    pit_ax.axhline([1], linestyle="--", color="tab:orange", label="Expected")
    pit_ax.legend(loc="upper center")
    pit_ax.set_title(f"PC {pc_index}\nPIT")
    pit_ax.set_xlabel("x")
    pit_ax.set_ylabel("Relative Frequency")
    pit_ax.set_xlim(0, 1)
    pit_ax.set_ylim(0, pit_counts.max() * 1.2)

    # Ordered residuals against the reference quantiles, with the y = x line.
    osm, osr = sci_stats.probplot(residue_norm.flat, dist=r, fit=False)
    qq_ax.plot(osm, osr)
    qq_ax.plot(
        np.linspace(-5, 5, 100), np.linspace(-5, 5, 100), "--", color="tab:orange"
    )
    qq_ax.set_title("Q-Q")
    qq_ax.set_xlabel("Theoretical Quantiles")
    qq_ax.set_ylabel("Ordered Responses")
    qq_ax.set_xlim(-5, 5)

## Compute model output spectrum

In [None]:
 # r = sci_stats.norm()
# chisq = sci_stats.chisquare(counts, r.pdf(bins[:-1]), axis=None)

In [None]:
import copy

# Flatten the per-PC grids into (test points) x (pcs) matrices.
pc_mean_matrix = np.column_stack([m.ravel() for m in pc_pred_means])
pc_variance_matrix = np.column_stack([np.square(s.ravel()) for s in pc_pred_stds])

# Posterior mean spectrum: map predicted PC scores back through the PCA pipeline.
pred_mean_wl = pca_pipe.inverse_transform(pc_mean_matrix)

# Posterior std spectrum: push the per-PC variances through a copy of the PCA
# step whose components were replaced by their absolute values, then take the
# square root.
# NOTE(review): linear variance propagation would normally weight each PC
# variance by the squared loading rather than by |components|; confirm this
# approximation is intended. The wavelength-space residual analysis later in
# this file is built around the current definition.
std_inverse = copy.deepcopy(pca_pipe.named_steps["pca"])
std_inverse.components_ = np.abs(std_inverse.components_)
pred_std_wl = np.sqrt(std_inverse.inverse_transform(pc_variance_matrix))

# Error of the emulated spectrum against the true 6S spectrum.
pred_error_wl = y_test_wl - pred_mean_wl

# Views of the same arrays laid out as (test grid axes..., wavelength).
spectrum_shape = (*grid_shape, len(WAVELENGTHS))
pred_mean_wl_shaped = pred_mean_wl.reshape(spectrum_shape)
pred_std_wl_shaped = pred_std_wl.reshape(spectrum_shape)
y_test_wl_shaped = y_test_wl.reshape(spectrum_shape)
pred_error_wl_shaped = pred_error_wl.reshape(spectrum_shape)

## Evaluate metrics on output spectrum

In [None]:
# Absolute emulator error over every test point and wavelength.
abs_error_wl = np.abs(pred_error_wl)
print(f"Avg abs err {np.mean(abs_error_wl):0.5f}")
print(f"Max abs err {np.max(abs_error_wl):0.5f}")

In [None]:
# Pick wavelength to plot mean, error surfaces for.
wavelength_idx = 130

param_idx_combos = list(itertools.combinations(range(len(INPUT_RANGES)), r=2))
param_names = list(INPUT_RANGES.keys())

fig, axs = plt.subplots(
    nrows=len(param_idx_combos),
    ncols=4,
    figsize=(12, 3 * len(param_idx_combos)),
    sharex="row",
    sharey="row",
    layout="constrained",
)

for ax_row, (param_x_idx, param_y_idx) in zip(axs, param_idx_combos):
    local_mesh_x, local_mesh_y = np.meshgrid(
        dense_input_test[param_names[param_x_idx]],
        dense_input_test[param_names[param_y_idx]],
        indexing="ij",
    )

    other_dims = tuple(
        i for i in range(len(INPUT_RANGES)) if i not in (param_x_idx, param_y_idx)
    )

    pred_mean_only = pred_mean_wl_shaped[..., wavelength_idx].mean(axis=other_dims)
    pred_std_only = pred_std_wl_shaped[..., wavelength_idx].mean(axis=other_dims)
    y_test_only = y_test_wl_shaped[..., wavelength_idx].mean(axis=other_dims)
    pred_error_only = pred_error_wl_shaped[..., wavelength_idx].mean(axis=other_dims)

    vmin = min(pred_mean_only.min(), y_test_only.min())
    vmax = max(pred_mean_only.max(), y_test_only.max())

    # Plot predicted mean surface.
    ax = ax_row[0]
    art = ax.pcolormesh(
        local_mesh_x, local_mesh_y, pred_mean_only, vmin=vmin, vmax=vmax
    )
    ax.plot(
        input_samples[param_names[param_x_idx]],
        input_samples[param_names[param_y_idx]],
        "o",
        color="k",
        markerfacecolor="none",
    )

    cbar = fig.colorbar(art)

    # Plot true output surface.
    ax = ax_row[1]
    art = ax.pcolormesh(local_mesh_x, local_mesh_y, y_test_only, vmin=vmin, vmax=vmax)
    ax.plot(
        input_samples[param_names[param_x_idx]],
        input_samples[param_names[param_y_idx]],
        "o",
        color="k",
        markerfacecolor="none",
    )

    fig.colorbar(art)
    
    # Plot predicted variance surface.
    ax = ax_row[2]
    art = ax.pcolormesh(local_mesh_x, local_mesh_y, pred_std_only)
    ax.plot(
        input_samples[param_names[param_x_idx]],
        input_samples[param_names[param_y_idx]],
        "o",
        color="k",
        markerfacecolor="none",
    )

    fig.colorbar(art)

    # Plot error surface.
    ax = ax_row[3]
    art = ax.pcolormesh(
        local_mesh_x, local_mesh_y, pred_error_only
    )
    ax.plot(
        input_samples[param_names[param_x_idx]],
        input_samples[param_names[param_y_idx]],
        "o",
        color="k",
        markerfacecolor="none",
    )

    fig.colorbar(art)

    ax_row[0].set_ylabel(
        f"{param_rich_name(param_names[param_x_idx])}\nvs\n{param_rich_name(param_names[param_y_idx])}"
    )

axs[0, 0].set_title("Posterior Mean")
axs[0, 1].set_title("True Output")
axs[0, 2].set_title("Std")
axs[0, 3].set_title("Error")
fig.suptitle(f"{WAVELENGTHS[wavelength_idx]*1e3:0.1f}nm Emulator Performance")

In [None]:
# Pick one random test grid point (one index per input axis).
# NOTE(review): randint is inclusive at both ends, so with a lower bound of 1
# index 0 of each axis is never drawn - confirm this is intended.
rand_input = tuple(
    random.randint(1, len(input_choices) - 1)
    for input_choices in dense_input_test.values()
)

# rand_input = (rand_input[0], 0, rand_input[2])

true_spectrum = y_test_wl_shaped[rand_input]
emulated_spectrum = pred_mean_wl_shaped[rand_input]
emulated_std = pred_std_wl_shaped[rand_input]

fig, ax = plt.subplots(figsize=(12, 8))

ax.plot(WAVELENGTHS, true_spectrum, label="6S", linewidth=1.5, linestyle="--")
ax.plot(
    WAVELENGTHS,
    emulated_spectrum,
    label="Emulator",
    linewidth=1.5,
    linestyle="-.",
)
# Best the emulator could do: the true spectrum after a PCA round trip.
pca_temp = pca_pipe.inverse_transform(
    pca_pipe.transform(true_spectrum.reshape(1, -1))
).reshape(-1, 1)
ax.plot(
    WAVELENGTHS,
    pca_temp.flat,
    label="PCA",
    linewidth=1.5,
    linestyle=":",
)
ax.set_ylim(0, 1)
ax.set_xlim(WAVELENGTHS[0], WAVELENGTHS[-1])
ax.set_xlabel(param_rich_name("wavelength.value"))
ax.set_ylabel(test_output.attrs.get("title", "Output"))
# Title lists the value of each input at the selected grid point.
ax.set_title(
    ", ".join(
        f"{param}={value_choice[value_index]:0.2f}"
        for param, value_choice, value_index in zip(
            test_output.dims, dense_input_test.values(), rand_input
        )
    ),
    fontsize=10,
)
# Predicted std drawn as a band rising from zero ...
ax.fill_between(
    WAVELENGTHS,
    np.zeros_like(WAVELENGTHS),
    emulated_std,
    alpha=0.3,
    color="tab:blue",
)
# ... and as a +/- 1 std band around the emulator mean.
ax.fill_between(
    WAVELENGTHS,
    emulated_spectrum - emulated_std,
    emulated_spectrum + emulated_std,
    alpha=0.2,
    color="tab:blue",
    # Raw string: "\p" is not a valid escape sequence in a normal string literal.
    label=r"$\pm 1$ std.",
)

ax.legend(loc="upper left")

# Wavelengths at which any test spectrum drops below 0.2; only used by the
# optional masking line below.
ignore_idx = (y_test_wl_shaped < 0.2).any(axis=0).any(axis=0).any(axis=0).any(axis=0)
temp = pred_error_wl_shaped.copy()
# temp[rand_input + (ignore_idx,)] = float("nan")

# Errors are drawn on a secondary y-axis.
ax2 = ax.twinx()
ax2.plot(
    WAVELENGTHS,
    temp[rand_input],
    color="black",
    linestyle=":",
    label="Emulator Error",
)
ax2.plot(
    WAVELENGTHS,
    (true_spectrum - pca_temp.reshape(1, -1)).flat,
    color="red",
    linestyle="--",
    label="PCA Error",
)
ax2.set_ylabel("Error")
# Symmetric limits at twice the largest absolute error over the whole test set.
max_error = np.max(np.abs(pred_error_wl_shaped))
ax2.set_ylim(-2 * max_error, 2 * max_error)
ax2.legend(loc="upper right")

fig.suptitle("True 6S Output vs Emulator Output");

In [None]:
fig, axs = plt.subplots(nrows=2, figsize=(16, 16), sharex="col", layout="constrained")

# Number of leading PCs to compare. Capped at the number of PCs the emulator
# was trained with, since no predictions exist beyond that.
plot_components = min(15, NUM_TRAIN_COMPONENTS)

for ax in axs:
    # One viridis colour per curve.
    ax.set_prop_cycle(color=plt.cm.viridis(np.linspace(0, 1, plot_components)))
    ax.set_xlim(WAVELENGTHS[0], WAVELENGTHS[-1] * 1.1)

    ax.set_xlabel(param_rich_name("wavelength.value"))
    ax.set_ylabel(f'Error in {test_output.attrs.get("title", "Output")}')
    ax.axhline(0, color="k")

# Predicted PC scores as a (test points) x (pcs) matrix.
pc_means_flat = np.hstack([m.reshape(-1, 1) for m in pc_pred_means])

for n_components in range(1, plot_components + 1):
    # Keep the first n_components predicted PCs and zero out the rest.
    # Slicing from the front is deliberate: a negative-offset slice such as
    # x[:, -k:] selects every column when k == 0.
    x = pc_means_flat.copy()
    x[:, n_components:] = 0

    pred_mean_wl_part = pca_pipe.inverse_transform(x)
    pred_error_wl_part = y_test_wl - pred_mean_wl_part
    abs_error_part = np.abs(pred_error_wl_part)

    axs[0].plot(
        WAVELENGTHS,
        abs_error_part.mean(axis=0),
        label=f"#PC={n_components}",
        linewidth=0.5,
    )
    axs[1].plot(
        WAVELENGTHS,
        abs_error_part.max(axis=0),
        label=f"#PC={n_components}",
        linewidth=0.5,
    )


axs[0].legend(ncols=2)
axs[0].set_title("Mean Abs. Error")
axs[1].set_title("Max Abs. Error")
fig.suptitle("Posterior Error for Varying #PCs")

# Wavelength-space Residues



In [None]:
import scipy.stats as sci_stats

# wavelength_idx = random.randint(0, len(WAVELENGTHS)) # 370 345 934 577

# Up to 10 distinct wavelength indices to inspect; sampling without replacement
# avoids drawing the same figure twice.
display_wavelengths = sorted(
    np.random.choice(len(WAVELENGTHS), min(10, len(WAVELENGTHS)), replace=False)
)

num_bins = 40

# The quantities below do not depend on the wavelength, so compute them once.
# PCA components with each row scaled by the square root of its explained variance.
pca_step = pca_pipe.named_steps["pca"]
pc_inverse_map = (
    np.sqrt(pca_step.explained_variance_[:, np.newaxis]) * pca_step.components_
)
# Spectrum error divided by the predicted spectrum std, over the whole test grid.
residue_norm = pred_error_wl_shaped / pred_std_wl_shaped

for wavelength_idx in display_wavelengths:
    fig, axs = plt.subplots(
        nrows=1,
        ncols=3,
        figsize=(12, 4),
        # sharex="row",
        # sharey="row",
        layout="constrained",
    )

    wavelength_pc_weights = pc_inverse_map[:, wavelength_idx]

    # Variance used for the reference distribution: sum of squared weights over
    # the sum of absolute weights. The absolute value matches the |components|
    # map used to build pred_std_wl; a signed sum can cancel or go negative,
    # which would make the scale below NaN.
    # NOTE(review): presumably this assumes all PCs share a similar predictive
    # std - confirm the derivation.
    expected_var = (wavelength_pc_weights**2).sum() / np.abs(
        wavelength_pc_weights
    ).sum()
    # scale=std
    r = sci_stats.norm(scale=expected_var**0.5)

    # Residuals at this wavelength, with nan entries dropped.
    residue_norm_single = residue_norm[..., wavelength_idx]
    residue_norm_single = residue_norm_single[~np.isnan(residue_norm_single)]
    print(residue_norm_single.shape)

    # Residual density against the reference pdf.
    ax = axs[0]
    hist_count, hist_bins = np.histogram(
        residue_norm_single, bins=num_bins, density=True
    )
    ax.stairs(hist_count, hist_bins)
    hist_range = np.linspace(hist_bins[0], hist_bins[-1], 1000)
    ax.plot(hist_range, r.pdf(hist_range), "--", color="tab:orange")

    # Density of the reference cdf evaluated at the residuals, against the flat
    # line at 1 drawn as "Expected".
    ax = axs[1]
    pit_counts, pit_bins = np.histogram(
        r.cdf(residue_norm_single), bins=num_bins, density=True
    )
    ax.stairs(pit_counts, pit_bins, label="Actual")
    ax.axhline([1], linestyle="--", color="tab:orange", label="Expected")
    ax.legend(loc="upper center")

    # Ordered residuals against the reference quantiles, with the y = x line.
    ax = axs[2]
    osm, osr = sci_stats.probplot(residue_norm_single.flat, dist=r, fit=False)
    ax.plot(osm, osr)
    ax.plot(np.linspace(-5, 5, 100), np.linspace(-5, 5, 100), "--", color="tab:orange")

    axs[0].set_title("Residue Histogram")
    axs[2].set_title("Q-Q")
    axs[1].set_title(f"{WAVELENGTHS[wavelength_idx]*1e3:.1f}nm ({wavelength_idx})\nPIT")
    axs[0].set_xlabel("residue / std")
    axs[1].set_xlabel("x")
    axs[2].set_xlabel("Theoretical Quantiles")
    axs[0].set_ylabel("Relative Frequency")
    axs[1].set_ylabel("Relative Frequency")
    axs[2].set_ylabel("Ordered Responses")
    axs[0].set_xlim(-2, 2)
    axs[1].set_xlim(0, 1)
    axs[2].set_xlim(-2, 2)
    axs[0].set_ylim(0, 3.5)
    axs[1].set_ylim(0, pit_counts.max() * 1.2);

In [None]:
# Inspect: squared entries of each PCA component row, summed along axis 1
# (one value per component).
(pca_pipe.named_steps["pca"].components_**2).sum(axis=1)

In [None]:
# Inspect: entry of every PCA component at the current ``wavelength_idx``.
pca_pipe.named_steps["pca"].components_[:, wavelength_idx]

In [None]:
# Inspect: PCA components with each row scaled by the square root of that
# component's explained variance.
np.sqrt(pca_pipe.named_steps["pca"].explained_variance_[:, np.newaxis]) * pca_pipe.named_steps["pca"].components_

In [None]:
# Inspect: the wavelength index currently in scope.
wavelength_idx

In [None]:
# Inspect: predicted spectrum std over the whole test grid at the current
# ``wavelength_idx``.
pred_std_wl_shaped[..., wavelength_idx]