# Overcast Evaluation

In [None]:
import json

import numpy as np

from scipy import stats

from pathlib import Path

from overcast import models
from overcast import datasets
from overcast.models import ensembles
from overcast.visualization import plotting

from sklearn.preprocessing import MinMaxScaler

import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
# Dataset target name -> (LaTeX) axis label used in every figure below.
TARGET_KEYS = dict(
    Nd=r"$N_d$",
    re=r"$r_e$",
    COD=r"$\tau$",
    CWP="CWP",
    LPC=r"$CF$",
)

In [None]:
# Relative project layout: <project>/data holds the datasets,
# <project>/output holds the trained runs.
project_dir = Path("MR-MLforACI")
data_dir, output_dir = (project_dir / sub for sub in ("data", "output"))

In [None]:
class Experiment:
    """A trained Overcast ensemble plus its cached test-set predictions.

    ``experiment_path`` is a run directory containing ``config.json`` and one
    ``checkpoints/model-<i>/mu`` checkpoint per ensemble member. Runs whose
    path contains ``"daily"`` are loaded as ``AppendedTreatmentAttentionNetwork``
    (transformer) models; all others as ``AppendedTreatmentNeuralNetwork``
    models with the ``"resnet"`` architecture.

    Construction loads the test/valid/train splits and every ensemble member.
    It builds the treatment grid and the observed test outcomes. It then
    computes, or reads from ``apos_ensemble.npy``, the ensemble APOs and
    predicts ensemble means on the test split.
    """

    def __init__(self, experiment_path):
        self.experiment_dir = Path(experiment_path)
        # str() so both str and pathlib.Path arguments work; a substring test
        # directly on a Path object raises TypeError.
        self.transformer = "daily" in str(experiment_path)
        config_path = self.experiment_dir / "config.json"
        self.checkpoint_dir = self.experiment_dir / "checkpoints"
        # Cached prediction arrays (*.npy) are stored in the run directory.
        self.ensemble_dir = self.experiment_dir

        with open(config_path, encoding="utf-8") as cp:
            config = json.load(cp)

        # Override the saved data location of every split with this
        # notebook's module-level `data_dir`.
        for split in ("ds_test", "ds_valid", "ds_train"):
            config[split]["data_dir"] = data_dir

        self.dataset_name = config.get("dataset_name")
        self.num_components_outcome = config.get("num_components_outcome")
        self.num_components_treatment = config.get("num_components_treatment")
        self.dim_hidden = config.get("dim_hidden")
        self.depth = config.get("depth")
        self.negative_slope = config.get("negative_slope")
        self.beta = config.get("beta")
        self.layer_norm = config.get("layer_norm")
        self.dropout_rate = config.get("dropout_rate")
        self.spectral_norm = config.get("spectral_norm")
        self.learning_rate = config.get("learning_rate")
        self.batch_size = config.get("batch_size")
        self.epochs = config.get("epochs")
        self.ensemble_size = config.get("ensemble_size")
        # Only the attention model takes a number of heads.
        self.num_heads = config.get("num_heads") if self.transformer else None

        dataset_cls = datasets.DATASETS.get(self.dataset_name)
        if dataset_cls is None:
            raise ValueError(
                f"Unknown dataset_name {self.dataset_name!r} in {config_path}"
            )
        self.ds = {
            "test": dataset_cls(**config.get("ds_test")),
            "valid": dataset_cls(**config.get("ds_valid")),
            "train": dataset_cls(**config.get("ds_train")),
        }

        # Outcome column index -> target name; used to look up TARGET_KEYS labels.
        self.target_keys = dict(enumerate(self.ds["test"].target_names))

        self.ensemble = self.load_ensemble()
        self.treatments = self.load_treatments()
        self.outcomes = self.load_outcomes()
        # Both of these read self.ensemble / self.treatments set just above.
        self.apos_ensemble = self.load_apos_ensemble()
        self.means_ensemble = self.get_means_ensemble()

        # Sensitivity-limit cache keyed by log_lambda; filled by get_apo_limits.
        self.apo_limits = {}

    def load_ensemble(self):
        """Load all ensemble members using the model family of this run."""
        if self.transformer:
            return self.load_transformer_ensemble()
        return self.load_nn_ensemble()

    def _shared_model_kwargs(self, ensemble_id):
        """Return the constructor kwargs common to both model families."""
        train = self.ds["train"]
        return dict(
            job_dir=self.checkpoint_dir / f"model-{ensemble_id}" / "mu",
            dim_input=train.dim_input,
            dim_treatment=train.dim_treatments,
            dim_output=train.dim_targets,
            num_components_outcome=self.num_components_outcome,
            num_components_treatment=self.num_components_treatment,
            dim_hidden=self.dim_hidden,
            depth=self.depth,
            negative_slope=self.negative_slope,
            beta=self.beta,
            layer_norm=self.layer_norm,
            spectral_norm=self.spectral_norm,
            dropout_rate=self.dropout_rate,
            num_examples=len(train),
            learning_rate=self.learning_rate,
            batch_size=self.batch_size,
            epochs=self.epochs,
            num_workers=0,
            seed=ensemble_id,
        )

    def _load_members(self, model_cls, **extra_kwargs):
        """Build ``ensemble_size`` models of ``model_cls`` and load each checkpoint."""
        ensemble = []
        for ensemble_id in range(self.ensemble_size):
            model = model_cls(**self._shared_model_kwargs(ensemble_id), **extra_kwargs)
            model.load()
            ensemble.append(model)
        return ensemble

    def load_transformer_ensemble(self):
        """Load the attention-network ensemble members."""
        return self._load_members(
            models.AppendedTreatmentAttentionNetwork,
            num_heads=self.num_heads,
            patience=50,
        )

    def load_nn_ensemble(self):
        """Load the feed-forward (resnet) ensemble members."""
        return self._load_members(
            models.AppendedTreatmentNeuralNetwork,
            architecture="resnet",
            patience=self.epochs,
        )

    def load_treatments(self, num_quantiles=32):
        """Return the treatment grid: ``num_quantiles`` training-set quantiles.

        Treatments are mapped back through ``treatments_xfm.inverse_transform``
        before taking the quantiles ``0, 1/n, ..., (n-1)/n``; the q=1.0 quantile
        (the maximum) is dropped.
        """
        train = self.ds["train"]
        raw = train.treatments
        if self.transformer:
            # Here `treatments` is a sequence of arrays; stack them first.
            raw = np.concatenate(raw, axis=0)
        treatments = train.treatments_xfm.inverse_transform(raw)
        # linspace avoids the float round-off that arange with step 1/n can
        # hit for n that are not powers of two (identical values for n=32).
        q = np.linspace(0, 1, num_quantiles + 1)
        return np.quantile(treatments, q=q)[:-1]

    def load_outcomes(self):
        """Return the observed test-set outcomes, one column per target."""
        test = self.ds["test"]
        if self.transformer:
            # Assumes the target columns are the trailing columns of the data
            # frame (as before), but takes as many as there are targets rather
            # than a fixed 4, so single-outcome runs get the right column.
            num_targets = len(test.target_names)
            return test.data_frame.to_numpy()[:, -num_targets:]
        return test.targets_xfm.inverse_transform(test.targets)

    def load_apos_ensemble(self, from_scratch=False):
        """Return ensemble APOs over ``self.treatments``, cached in ``apos_ensemble.npy``.

        Set ``from_scratch=True`` to ignore and overwrite an existing cache.
        """
        apos_ensemble_path = self.ensemble_dir / "apos_ensemble.npy"
        if apos_ensemble_path.exists() and not from_scratch:
            return np.load(apos_ensemble_path)
        capos_ensemble = ensembles.predict_capos(
            ensemble=self.ensemble,
            dataset=self.ds["test"],
            treatments=self.treatments,
            batch_size=1 if self.transformer else 20000,
        )
        # Average the conditional APOs over axis 2 (presumably the test-example
        # axis — TODO confirm against ensembles.predict_capos).
        apos_ensemble = capos_ensemble.mean(2)
        np.save(apos_ensemble_path, apos_ensemble)
        return apos_ensemble

    def get_apo_limits(self, log_lambda, from_scratch=False):
        """Compute (or load) lower/upper APO limits for ``log_lambda``.

        The result is stacked as ``[lower, upper]`` along axis 0 and cached both
        on disk (``apo_limits_<log_lambda>.npy``) and in ``self.apo_limits``.
        It is also returned for convenience.
        """
        apo_limits_path = self.ensemble_dir / f"apo_limits_{log_lambda}.npy"
        if (not apo_limits_path.exists()) or from_scratch:
            lower_capos, upper_capos = ensembles.predict_intervals(
                ensemble=self.ensemble,
                dataset=self.ds["test"],
                treatments=self.treatments,
                log_lambda=log_lambda,
                num_samples=100,
                batch_size=1 if self.transformer else 10000,
            )
            lower_apos = np.expand_dims(lower_capos.mean(2), 0)
            upper_apos = np.expand_dims(upper_capos.mean(2), 0)
            apo_limits = np.concatenate([lower_apos, upper_apos], axis=0)
            np.save(apo_limits_path, apo_limits)
        else:
            apo_limits = np.load(apo_limits_path)
        self.apo_limits[log_lambda] = apo_limits
        return apo_limits

    def get_means_ensemble(self):
        """Predict each ensemble member's mean outcome on the test split."""
        return ensembles.predict_mean(self.ensemble, self.ds["test"], batch_size=None)

## Load Experiments

In [None]:
# Each run lives at <output_dir>/<dataset>/<model>/<hyperparameters>.
# Naming: tr = transformer, nn = neural network; lr = low resolution;
# p = Pacific, a = Atlantic; rh / w500 = run without the RH / W500 covariates.
tr_lrp = Experiment(
    str(
        output_dir
        / "jasmin-daily-four_outputs_liqcf_pacific_treatment-AOD_covariates-RH900-RH850-RH700-LTS-EIS-W500-SST_outcomes-re-COD-CWP-LPC_bins-1"
        / "appended-treatment-transformer"
        / "dh-128_nco-22_nct-27_dp-3_nh-8_ns-0.28_bt-0.0_ln-False_dr-0.42_sn-0.0_lr-0.0001_bs-128_ep-500"
    )
)
nn_lrp = Experiment(
    str(
        output_dir
        / "nice"
        / "jasmin-four_outputs_liqcf_pacific_treatment-AOD_covariates-RH900-RH850-RH700-LTS-EIS-W500-SST_outcomes-re-COD-CWP-LPC_bins-1"
        / "appended-treatment-nn"
        / "dh-256_nco-5_nct-2_dp-2_ns-0.1_bt-0.0_ln-False_dr-0.09_sn-0.0_lr-0.0002_bs-224_ep-9"
    )
)
tr_lra = Experiment(
    str(
        output_dir
        / "jasmin-daily-four_outputs_liqcf_atlantic_treatment-AOD_covariates-RH900-RH850-RH700-LTS-EIS-W500-SST_outcomes-re-COD-CWP-LPC_bins-1"
        / "appended-treatment-transformer"
        / "dh-128_nco-24_nct-7_dp-4_nh-8_ns-0.19_bt-0.0_ln-True_dr-0.16_sn-0.0_lr-0.0001_bs-160_ep-500"
    )
)
tr_lrprh = Experiment(
    str(
        output_dir
        / "jasmin-daily-four_outputs_liqcf_pacific_treatment-AOD_covariates-LTS-EIS-W500-SST_outcomes-re-COD-CWP-LPC_bins-1"
        / "appended-treatment-transformer"
        / "dh-128_nco-22_nct-27_dp-3_nh-8_ns-0.28_bt-0.0_ln-False_dr-0.42_sn-0.0_lr-0.0001_bs-128_ep-500"
    )
)
tr_lrpw500 = Experiment(
    str(
        output_dir
        / "jasmin-daily-four_outputs_liqcf_pacific_treatment-AOD_covariates-RH900-RH850-RH700-LTS-EIS-SST_outcomes-re_bins-1"
        / "appended-treatment-transformer"
        / "dh-256_nco-24_nct-24_dp-3_nh-4_ns-0.01_bt-0.0_ln-False_dr-0.5_sn-0.0_lr-0.0002_bs-32_ep-500"
    )
)

In [None]:
def plot_all_scatterplot(experiment, savepath=None):
    """Scatter ensemble-mean predictions against observed test outcomes.

    Draws one panel per outcome. Each panel shows the points, the y=x reference
    line, and a least-squares fit labelled with its r^2. Axes are clipped to the
    1st-99th percentile of the observed outcome.

    Args:
        experiment: ``(Experiment, label, color)`` triple. ``label`` is unused
            here but kept so every plotting helper takes the same triple.
        savepath: If given, the figure is saved to ``f"{savepath}.png"``.
    """
    exp, _label, color = experiment
    num_targets = len(exp.target_keys)
    # squeeze=False keeps `axs` 2-D even for a single outcome; otherwise
    # subplots() returns a bare Axes and indexing it fails.
    fig, axs = plt.subplots(
        1, num_targets, figsize=(num_targets * 6, 6), squeeze=False
    )
    # Average over ensemble members once, not once per panel.
    predicted = exp.means_ensemble.mean(0)
    for idx, ax in enumerate(axs[0]):
        target = TARGET_KEYS[exp.target_keys[idx]]
        observed = exp.outcomes[:, idx]
        qs = np.quantile(observed, [0.01, 0.99])
        # A fixed number of points draws the reference lines independently of
        # the outcome's units (a 0.01 step degenerates for narrow ranges).
        domain = np.linspace(qs[0], qs[1], 100)
        slope, intercept, r, _, _ = stats.linregress(observed, predicted[:, idx])
        ax.scatter(x=observed, y=predicted[:, idx], s=0.01, c=color)
        ax.plot(domain, domain, c="C2")
        ax.plot(domain, domain * slope + intercept, c=color, label=f"$r^2$={r**2:.03f}")
        ax.set_xlim(qs)
        ax.set_ylim(qs)
        ax.set_xlabel(f"{target} true")
        ax.set_ylabel(f"{target} predicted")
        ax.legend(loc="upper left")
    if savepath is not None:
        fig.savefig(f"{savepath}.png", format="png", bbox_inches="tight")
    plt.show()

In [None]:
def plot_single_scatterplot(experiment, idx_outcome, savepath=None):
    """Scatter one outcome's ensemble-mean prediction against its observed value.

    Args:
        experiment: ``(Experiment, label, color)`` triple; the label is unused.
        idx_outcome: Column of the outcome in ``exp.outcomes``.
        savepath: Optional path stem; the figure is written to ``<savepath>.png``.
    """
    exp, label, color = experiment
    target = TARGET_KEYS[exp.target_keys[idx_outcome]]
    observed = exp.outcomes[:, idx_outcome]
    # Axis limits: central 98% of the observed values.
    qs = np.quantile(observed, [0.01, 0.99])
    domain = np.arange(*qs, 0.01)
    predicted = exp.means_ensemble.mean(0)[:, idx_outcome]
    fit = stats.linregress(observed, predicted)

    plt.figure(figsize=(6, 6))
    ax = plt.gca()
    ax.scatter(x=observed, y=predicted, s=0.01, c=color)
    ax.plot(domain, domain, c="C2")
    ax.plot(
        domain,
        domain * fit.slope + fit.intercept,
        c=color,
        label=f"$r^2$={fit.rvalue**2:.03f}",
    )
    ax.set_xlim(qs)
    ax.set_ylim(qs)
    ax.set_xlabel(f"{target} true")
    ax.set_ylabel(f"{target} predicted")
    ax.legend(loc="upper left")
    if savepath is not None:
        plt.savefig(f"{savepath}.png", format="png", bbox_inches="tight")
    plt.show()

In [None]:
def plot_all_scatterplot_comp(experiment1, experiment2, savepath=None):
    """Overlay predicted-vs-observed scatter plots of two experiments.

    Draws one panel per outcome. Each panel shows both experiments' points, the
    y=x reference line, and one least-squares fit per experiment labelled with
    its r^2. Axes cover the union of both experiments' 1st-99th percentiles.

    Args:
        experiment1, experiment2: ``(Experiment, label, color)`` triples with
            identical ``target_keys``.
        savepath: If given, the figure is saved to ``f"{savepath}.png"``.

    Raises:
        ValueError: if the two experiments have different outcomes.
    """
    exp1, label1, color1 = experiment1
    exp2, label2, color2 = experiment2
    if exp1.target_keys != exp2.target_keys:
        raise ValueError(
            f"Outcome mismatch: {exp1.target_keys} vs {exp2.target_keys}"
        )
    num_targets = len(exp1.target_keys)
    # squeeze=False keeps `axs` 2-D even for a single outcome.
    fig, axs = plt.subplots(
        1, num_targets, figsize=(num_targets * 6, 6), squeeze=False
    )
    # (experiment, ensemble-averaged predictions, label, color)
    runs = [
        (exp1, exp1.means_ensemble.mean(0), label1, color1),
        (exp2, exp2.means_ensemble.mean(0), label2, color2),
    ]
    for idx, ax in enumerate(axs[0]):
        target = TARGET_KEYS[exp1.target_keys[idx]]
        qs1 = np.quantile(exp1.outcomes[:, idx], [0.01, 0.99])
        qs2 = np.quantile(exp2.outcomes[:, idx], [0.01, 0.99])
        qs = (min(qs1[0], qs2[0]), max(qs1[1], qs2[1]))
        domain = np.linspace(qs[0], qs[1], 100)
        for exp, predicted, _, color in runs:
            ax.scatter(x=exp.outcomes[:, idx], y=predicted[:, idx], s=0.01, c=color)
        # y=x reference, as drawn by the other scatter-plot helpers.
        ax.plot(domain, domain, c="C2")
        for exp, predicted, label, color in runs:
            fit = stats.linregress(exp.outcomes[:, idx], predicted[:, idx])
            ax.plot(
                domain,
                domain * fit.slope + fit.intercept,
                c=color,
                label=f"{label} $r^2$={fit.rvalue**2:.03f}",
            )
        ax.set_xlim(qs)
        ax.set_ylim(qs)
        ax.set_xlabel(f"{target} true")
        ax.set_ylabel(f"{target} predicted")
        ax.legend(loc="upper left")
    if savepath is not None:
        fig.savefig(f"{savepath}.png", format="png", bbox_inches="tight")
    plt.show()

In [None]:
def plot_single_scatterplot_comp(experiment1, experiment2, idx_outcome, savepath=None):
    """Overlay two experiments' predicted-vs-observed scatter for one outcome.

    Args:
        experiment1, experiment2: ``(Experiment, label, color)`` triples whose
            ``target_keys[idx_outcome]`` agree.
        idx_outcome: Column of the outcome in each ``outcomes`` array.
        savepath: Optional path stem; the figure is written to ``<savepath>.png``.
    """
    exp1, label1, color1 = experiment1
    exp2, label2, color2 = experiment2
    assert exp1.target_keys[idx_outcome] == exp2.target_keys[idx_outcome]
    target = TARGET_KEYS[exp1.target_keys[idx_outcome]]
    runs = [(exp1, label1, color1), (exp2, label2, color2)]

    observed = [run[0].outcomes[:, idx_outcome] for run in runs]
    bounds = [np.quantile(values, [0.01, 0.99]) for values in observed]
    # Axis limits span both experiments' central 98%.
    qs = min(b[0] for b in bounds), max(b[1] for b in bounds)
    domain = np.arange(qs[0], qs[1], 0.01)

    plt.figure(figsize=(6, 6))
    plt.plot(domain, domain, c="C2")
    predicted = [run[0].means_ensemble.mean(0)[:, idx_outcome] for run in runs]
    fits = [stats.linregress(o, p) for o, p in zip(observed, predicted)]
    for (_, _, color), o, p in zip(runs, observed, predicted):
        plt.scatter(x=o, y=p, s=0.01, c=color)
    for (_, label, color), fit in zip(runs, fits):
        plt.plot(
            domain,
            domain * fit.slope + fit.intercept,
            c=color,
            label=f"{label} $r^2$={fit.rvalue**2:.03f}",
        )
    plt.xlim(qs)
    plt.ylim(qs)
    plt.xlabel(f"{target} true")
    plt.ylabel(f"{target} predicted")
    plt.legend(loc="upper left", fontsize=12)
    if savepath is not None:
        plt.savefig(f"{savepath}.png", format="png", bbox_inches="tight")

In [None]:
def plot_single_apo_1(experiment, idx_outcome, savepath=None):
    """Plot one outcome's ensemble-mean APO curve against AOD with a 95% band.

    The band spans the 2.5%-97.5% quantiles across ensemble members.

    Args:
        experiment: ``(Experiment, label, color)`` triple; the label is unused.
        idx_outcome: Index of the outcome in ``exp.apos_ensemble``.
        savepath: Optional path stem; the figure is written to ``<savepath>.pdf``.
    """
    exp, label, color = experiment
    alpha = 0.05
    plt.figure()
    ax = plt.gca()
    apos = exp.apos_ensemble[idx_outcome]
    ax.plot(exp.treatments, apos.mean(0), c=color)
    ax.fill_between(
        x=exp.treatments,
        y1=np.quantile(apos, 1 - alpha / 2, axis=0),
        y2=np.quantile(apos, alpha / 2, axis=0),
        alpha=0.2,
        label=r"$\Lambda \to 1.0 $",
        color=color,
    )
    ax.legend(title=rf"$\alpha=${alpha}", loc="upper right")
    ax.set_xlim([0.03, 0.3])
    ax.set_ylabel(TARGET_KEYS[exp.target_keys[idx_outcome]])
    ax.set_xlabel("AOD")
    if savepath is not None:
        plt.savefig(f"{savepath}.pdf", format="pdf", bbox_inches="tight")

In [None]:
def plot_all_apo_1(experiment, savepath=None):
    """Plot every outcome's ensemble-mean APO curve against AOD, one panel each.

    Each panel shades the 2.5%-97.5% quantile band across ensemble members.

    Args:
        experiment: ``(Experiment, label, color)`` triple; the label is unused.
        savepath: If given, the figure is saved to ``f"{savepath}.pdf"``.
    """
    exp, _label, color = experiment
    alpha = 0.05
    num_targets = len(exp.target_keys)
    # squeeze=False keeps `axs` 2-D even for a single outcome; otherwise
    # subplots() returns a bare Axes and indexing it fails.
    fig, axs = plt.subplots(
        1, num_targets, figsize=(num_targets * 6, 6), squeeze=False
    )
    for idx, ax in enumerate(axs[0]):
        target = TARGET_KEYS[exp.target_keys[idx]]
        apos = exp.apos_ensemble[idx]
        ax.plot(exp.treatments, apos.mean(0), c=color)
        ax.fill_between(
            x=exp.treatments,
            y1=np.quantile(apos, 1 - alpha / 2, axis=0),
            y2=np.quantile(apos, alpha / 2, axis=0),
            alpha=0.2,
            label=r"$\Lambda \to 1.0 $",
            color=color,
        )
        ax.set_xlim([0.03, 0.3])
        ax.set_xlabel("AOD")
        ax.set_ylabel(target)
        # Only the first panel puts its legend in the upper-right corner.
        ax.legend(
            title=r"$\alpha=$" + f"{alpha}",
            loc="upper right" if idx == 0 else "upper left",
        )
    if savepath is not None:
        fig.savefig(f"{savepath}.pdf", format="pdf", bbox_inches="tight")
    plt.show()

In [None]:
def plot_single_apo_1_comp(experiment1, experiment2, idx_outcome, savepath=None):
    """Overlay two experiments' APO curves for one outcome, min-max scaled.

    Each experiment's curve and 95% ensemble band are scaled by the min/max of
    its own ensemble-mean APO. This compares shapes rather than magnitudes.

    Args:
        experiment1, experiment2: ``(Experiment, label, color)`` triples whose
            ``target_keys[idx_outcome]`` agree.
        idx_outcome: Index of the outcome in ``apos_ensemble``.
        savepath: Optional path stem; the figure is written to ``<savepath>.pdf``.
    """
    exp1, label1, color1 = experiment1
    exp2, label2, color2 = experiment2
    assert exp1.target_keys[idx_outcome] == exp2.target_keys[idx_outcome]
    target = TARGET_KEYS[exp1.target_keys[idx_outcome]]
    alpha = 0.05
    plt.figure()
    runs = [(exp1, label1, color1), (exp2, label2, color2)]
    scalers = [
        MinMaxScaler().fit(run[0].apos_ensemble[idx_outcome].mean(0).reshape(-1, 1))
        for run in runs
    ]
    for (exp, label, color), scaler in zip(runs, scalers):
        apos = exp.apos_ensemble[idx_outcome]
        upper = np.quantile(apos, 1 - alpha / 2, axis=0)
        lower = np.quantile(apos, alpha / 2, axis=0)
        plt.plot(
            exp.treatments,
            scaler.transform(apos.mean(0).reshape(-1, 1)),
            color=color,
            label=label,
        )
        plt.fill_between(
            x=exp.treatments,
            y1=scaler.transform(upper.reshape(-1, 1)).flatten(),
            y2=scaler.transform(lower.reshape(-1, 1)).flatten(),
            alpha=0.2,
            label=r"$\Lambda \to 1.0 $",
            color=color,
        )
    plt.legend(title=rf"$\alpha=${alpha}", loc="upper right")
    plt.xlim([0.03, 0.3])
    plt.ylabel(target)
    plt.xlabel("AOD")
    if savepath is not None:
        plt.savefig(f"{savepath}.pdf", format="pdf", bbox_inches="tight")

In [None]:
def plot_single_apo_1_comp_noscale(experiment1, experiment2, idx_outcome, savepath=None):
    """Overlay two experiments' raw (unscaled) APO curves for one outcome.

    Each curve carries its 2.5%-97.5% quantile band across ensemble members.

    Args:
        experiment1, experiment2: ``(Experiment, label, color)`` triples whose
            ``target_keys[idx_outcome]`` agree.
        idx_outcome: Index of the outcome in ``apos_ensemble``.
        savepath: Optional path stem; the figure is written to ``<savepath>.pdf``.
    """
    exp1, label1, color1 = experiment1
    exp2, label2, color2 = experiment2
    assert exp1.target_keys[idx_outcome] == exp2.target_keys[idx_outcome]
    target = TARGET_KEYS[exp1.target_keys[idx_outcome]]
    alpha = 0.05
    plt.figure()
    for exp, label, color in ((exp1, label1, color1), (exp2, label2, color2)):
        apos = exp.apos_ensemble[idx_outcome]
        plt.plot(exp.treatments, apos.mean(0), color=color, label=label)
        plt.fill_between(
            x=exp.treatments,
            y1=np.quantile(apos, 1 - alpha / 2, axis=0),
            y2=np.quantile(apos, alpha / 2, axis=0),
            alpha=0.2,
            label=r"$\Lambda \to 1.0 $",
            color=color,
        )
    plt.legend(title=rf"$\alpha=${alpha}", loc="upper right")
    plt.xlim([0.03, 0.3])
    plt.ylabel(target)
    plt.xlabel("AOD")
    if savepath is not None:
        plt.savefig(f"{savepath}.pdf", format="pdf", bbox_inches="tight")

In [None]:
def plot_single_apo_lambda_cov(experiment1, experiment2, idx_outcome, log_lambda, savepath=None):
    """Compare exp1's ensemble band with exp2's Lambda-sensitivity limits.

    Both mean APO curves are drawn min-max scaled, each by its own mean-APO
    range. Three bands are shaded:

    * exp1's 2.5%-97.5% ensemble band (Lambda -> 1);
    * the gap between that band's top and the 97.5% quantile of exp2's upper
      limits at ``log_lambda``;
    * the gap between that band's bottom and the 2.5% quantile of exp2's lower
      limits at ``log_lambda``.

    The bands are drawn on ``exp1.treatments``. exp2's limits are computed on
    ``exp2.treatments``, so they are linearly interpolated onto exp1's grid.
    The two grids differ when the experiments were trained on different data,
    e.g. Pacific vs Atlantic. When the grids coincide, this is a no-op.

    Args:
        experiment1, experiment2: ``(Experiment, label, color)`` triples whose
            ``target_keys[idx_outcome]`` agree.
        idx_outcome: Index of the outcome in the APO arrays.
        log_lambda: Log sensitivity parameter; limits are computed/cached via
            ``exp2.get_apo_limits`` if not yet present.
        savepath: Optional path stem; the figure is written to ``<savepath>.pdf``.

    Raises:
        ValueError: if the experiments disagree on the outcome at ``idx_outcome``.
    """
    exp1, label1, color1 = experiment1
    exp2, label2, color2 = experiment2
    if exp1.target_keys[idx_outcome] != exp2.target_keys[idx_outcome]:
        raise ValueError(
            f"Outcome mismatch at index {idx_outcome}: "
            f"{exp1.target_keys[idx_outcome]!r} vs {exp2.target_keys[idx_outcome]!r}"
        )
    target = TARGET_KEYS[exp1.target_keys[idx_outcome]]
    alpha = 0.05

    apos1 = exp1.apos_ensemble[idx_outcome]
    apos2 = exp2.apos_ensemble[idx_outcome]
    scaler1 = MinMaxScaler().fit(apos1.mean(0).reshape(-1, 1))
    scaler2 = MinMaxScaler().fit(apos2.mean(0).reshape(-1, 1))

    def rescale(scaler, values):
        # Apply a fitted 1-feature scaler to a 1-D curve; return a 1-D curve.
        return scaler.transform(np.asarray(values).reshape(-1, 1)).flatten()

    def onto_exp1_grid(values):
        # Resample a curve defined on exp2.treatments onto exp1.treatments.
        return np.interp(exp1.treatments, exp2.treatments, values)

    if log_lambda not in exp2.apo_limits:
        exp2.get_apo_limits(log_lambda)
    lower_limits = exp2.apo_limits[log_lambda][0][idx_outcome]
    upper_limits = exp2.apo_limits[log_lambda][1][idx_outcome]

    band1_upper = rescale(scaler1, np.quantile(apos1, 1 - alpha / 2, axis=0))
    band1_lower = rescale(scaler1, np.quantile(apos1, alpha / 2, axis=0))
    band2_upper = onto_exp1_grid(
        rescale(scaler2, np.quantile(upper_limits, 1 - alpha / 2, axis=0))
    )
    band2_lower = onto_exp1_grid(
        rescale(scaler2, np.quantile(lower_limits, alpha / 2, axis=0))
    )

    plt.figure()
    plt.plot(exp1.treatments, rescale(scaler1, apos1.mean(0)), color=color1, label=label1)
    plt.plot(exp2.treatments, rescale(scaler2, apos2.mean(0)), color=color2, label=label2)
    plt.fill_between(
        x=exp1.treatments,
        y1=band1_upper,
        y2=band1_lower,
        alpha=0.2,
        label=r"$\Lambda \to 1.0 $ for " + f"{label1}",
        color=color1,
    )
    plt.fill_between(
        x=exp1.treatments,
        y1=band2_upper,
        y2=band1_upper,
        alpha=0.2,
        color=color2,
        label=r"$\Lambda=$" + f"{np.exp(log_lambda):.2f} for {label2}",
    )
    plt.fill_between(
        x=exp1.treatments,
        y1=band1_lower,
        y2=band2_lower,
        alpha=0.2,
        color=color2,
    )
    plt.legend(title=r"$\alpha=$" + f"{alpha}", loc="upper right")
    plt.xlim([0.03, 0.3])
    plt.ylabel(target)
    plt.xlabel("AOD")
    if savepath is not None:
        plt.savefig(f"{savepath}.pdf", format="pdf", bbox_inches="tight")

## Plots

In [None]:
# Theme for the first batch of figures: 22 pt labels, titles and legend text;
# 14 pt tick labels; white background with constrained layout.
rc = {
    "figure.constrained_layout.use": True,
    "figure.facecolor": "white",
    "figure.figsize": (6, 6),
    "legend.frameon": True,
    "xtick.labelsize": 14,
    "ytick.labelsize": 14,
}
rc.update(
    dict.fromkeys(
        ("axes.labelsize", "axes.titlesize", "legend.fontsize", "legend.title_fontsize"),
        22,
    )
)
_ = sns.set(style="whitegrid", palette="colorblind", rc=rc)

In [None]:
figures_dir = "/users/ms21mmso/msc-project/msc-project-report/figures"

# Predicted-vs-observed scatter plots, one figure per experiment.
for spec, name in [
    ((tr_lrp, "Transformer Pacific", "C0"), "tr_lrp"),
    ((nn_lrp, "Neural Network Pacific", "C1"), "nn_lrp"),
    ((tr_lra, "Transformer Atlantic", "C1"), "tr_lra"),
    ((tr_lrprh, "Pacific without RH", "C1"), "tr_lrprh"),
]:
    plot_all_scatterplot(spec, savepath=f"{figures_dir}/scatter_{name}")

# APO curves; these figures go into the overcast/ subfolder.
for spec, name in [
    ((tr_lrp, "Transformer Pacific", "C0"), "tr_lrp"),
    ((nn_lrp, "Neural Network Pacific", "C1"), "nn_lrp"),
    ((tr_lra, "Transformer Atlantic", "C1"), "tr_lra"),
]:
    plot_all_apo_1(spec, savepath=f"{figures_dir}/overcast/apo_{name}")

In [None]:
# NOTE: `tr_hrp` (the high-resolution Pacific transformer) is not loaded in the
# "Load Experiments" section. Referencing it unconditionally raised NameError
# and aborted "Run All", so plot it only once it has been defined.
if "tr_hrp" in globals():
    plot_all_scatterplot(
        [tr_hrp, 'Transformer HR Pacific', 'C1'],
        savepath='/users/ms21mmso/msc-project/msc-project-report/figures/scatter_tr_hrp',
    )

    plot_all_apo_1(
        [tr_hrp, 'Transformer HR Pacific', 'C1'],
        savepath='/users/ms21mmso/msc-project/msc-project-report/figures/overcast/apo_tr_hrp',
    )
else:
    print("Skipping high-resolution plots: `tr_hrp` has not been loaded.")

In [None]:
# Theme for the two-experiment comparison panels: 20 pt labels, titles and
# legend text; 14 pt tick labels.
rc = {
    "figure.constrained_layout.use": True,
    "figure.facecolor": "white",
    "figure.figsize": (6, 6),
    "legend.frameon": True,
    "xtick.labelsize": 14,
    "ytick.labelsize": 14,
}
rc.update(
    dict.fromkeys(
        ("axes.labelsize", "axes.titlesize", "legend.fontsize", "legend.title_fontsize"),
        20,
    )
)
_ = sns.set(style="whitegrid", palette="colorblind", rc=rc)

In [None]:
figures_dir = "/users/ms21mmso/msc-project/msc-project-report/figures"

# Side-by-side scatter comparisons: model family, then ocean basin.
for first, second, name in [
    ((tr_lrp, "Transformer", "C0"), (nn_lrp, "Neural Network", "C1"), "tr-vs-nn_lrp"),
    ((tr_lrp, "Pacific", "C0"), (tr_lra, "Atlantic", "C1"), "tr_lrp-vs-lra"),
]:
    plot_all_scatterplot_comp(first, second, savepath=f"{figures_dir}/scatter_{name}")

In [None]:
# NOTE: `tr_hrp` (high-resolution Pacific transformer) is not loaded in the
# "Load Experiments" section; skip the comparison instead of raising NameError.
if "tr_hrp" in globals():
    plot_all_scatterplot_comp(
        [tr_lrp, 'Low Resolution', 'C0'],
        [tr_hrp, 'High Resolution', 'C1'],
        savepath='/users/ms21mmso/msc-project/msc-project-report/figures/scatter_tr_lrp-vs-hrp',
    )
else:
    print("Skipping low- vs high-resolution comparison: `tr_hrp` has not been loaded.")

In [None]:
# Theme for the single-outcome figures: 18 pt labels, titles and legend text;
# 14 pt tick labels.
rc = {
    "figure.constrained_layout.use": True,
    "figure.facecolor": "white",
    "figure.figsize": (6, 6),
    "legend.frameon": True,
    "xtick.labelsize": 14,
    "ytick.labelsize": 14,
}
rc.update(
    dict.fromkeys(
        ("axes.labelsize", "axes.titlesize", "legend.fontsize", "legend.title_fontsize"),
        18,
    )
)
_ = sns.set(style="whitegrid", palette="colorblind", rc=rc)

In [None]:
figures_dir = "/users/ms21mmso/msc-project/msc-project-report/figures"

# Outcome index 0 only (the "_re" figures).
for first, second, name in [
    ((tr_lrp, "Pacific", "C0"), (tr_lra, "Atlantic", "C1"), "tr_lrp-vs-lra"),
    ((tr_lrp, "Transformer", "C0"), (nn_lrp, "Neural Network", "C1"), "tr-vs-nn_lrp"),
]:
    plot_single_scatterplot_comp(first, second, 0, savepath=f"{figures_dir}/scatter_{name}_re")

for spec, name in [
    ((tr_lrp, "Transformer Pacific", "C0"), "tr_lrp"),
    ((nn_lrp, "Neural Network Pacific", "C1"), "nn_lrp"),
    ((tr_lra, "Transformer Atlantic", "C1"), "tr_lra"),
    ((tr_lrprh, "Pacific without RH", "C1"), "tr_lrprh"),
    ((tr_lrpw500, r"Pacific without $\omega500$", "C1"), "tr_lrpw500"),
]:
    plot_single_apo_1(spec, 0, savepath=f"{figures_dir}/apo_{name}_re")

In [None]:
# Theme for the APO comparison figures: 14 pt for every text element.
rc = {
    "figure.constrained_layout.use": True,
    "figure.facecolor": "white",
    "figure.figsize": (6, 6),
    "legend.frameon": True,
    "xtick.labelsize": 14,
    "ytick.labelsize": 14,
}
rc.update(
    dict.fromkeys(
        ("axes.labelsize", "axes.titlesize", "legend.fontsize", "legend.title_fontsize"),
        14,
    )
)
_ = sns.set(style="whitegrid", palette="colorblind", rc=rc)

In [None]:
figures_dir = "/users/ms21mmso/msc-project/msc-project-report/figures"

# For each pair: raw APO overlay first, then the min-max scaled overlay.
for first, second, name in [
    ((tr_lrp, "Transformer", "C0"), (nn_lrp, "Neural Network", "C1"), "tr-vs-nn_lrp"),
    ((tr_lrp, "Pacific", "C0"), (tr_lra, "Atlantic", "C1"), "tr_lrp-vs-lra"),
    ((tr_lrp, "Pacific", "C0"), (tr_lrpw500, r"Pacific without $\omega500$", "C1"), "tr_lrp-vs-lrpw500"),
    ((tr_lrp, "Pacific", "C0"), (tr_lrprh, "Pacific without RH", "C1"), "tr_lrp-vs-lrprh"),
]:
    plot_single_apo_1_comp_noscale(first, second, 0, savepath=f"{figures_dir}/apo_{name}_re")
    plot_single_apo_1_comp(first, second, 0, savepath=f"{figures_dir}/apo-scaled_{name}_re")

# Ensemble band of the first experiment vs. Lambda-limits of the second.
for first, second, log_lambda, name in [
    ((tr_lrp, "Pacific", "C0"), (tr_lra, "Atlantic", "C1"), 0.05, "tr_lrp-vs-lra_lra"),
    ((tr_lra, "Atlantic", "C1"), (tr_lrp, "Pacific", "C0"), 0.07, "tr_lrp-vs-lra_lrp"),
    ((tr_lrp, "Pacific", "C0"), (tr_lrpw500, r"Pacific without $\omega500$", "C1"), 0.01, "tr_lrp-vs-lrpw500_lrpw500"),
    ((tr_lrp, "Pacific", "C0"), (tr_lrprh, "Pacific without RH", "C1"), 0.04, "tr_lrp-vs-lrprh_lrprh"),
]:
    plot_single_apo_lambda_cov(
        first, second, 0, log_lambda, savepath=f"{figures_dir}/apo_{name}_re"
    )