# Emulator Schematic

In [None]:
%reload_ext autoreload

%autoreload 2

%reload_ext dotenv
%dotenv
    
import IPython
import matplotlib.pyplot as plt
import math

from mlde_utils import cp_model_rotated_pole
from mlde_notebooks.data import prep_eval_and_model_data
from mlde_notebooks import plot_map, sorted_em_time_by_mean_pr

In [None]:
import matplotlib

# Render every figure produced by this notebook at 300 dpi.
matplotlib.rcParams.update({"figure.dpi": 300})

In [None]:
# Evaluation split and the ensemble members to load (note: not a contiguous
# range -- members 02, 03 and 14 are not listed).
split = "test"
ensemble_members = [
    f"{member:02d}" for member in (1, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 15)
]
samples_per_run = 3

# Model configurations handed to prep_eval_and_model_data, keyed by source.
data_configs = {
    "CPM": [
        dict(
            fq_model_id="score-sde/subvpsde/xarray_12em_cncsnpp_continuous/bham-4x_12em_PslS4T4V4_random-season-IstanTsqrturrecen-no-loc-spec",
            checkpoint="epoch-20",
            input_xfm="stan",
            label="Diff",
            dataset="bham_gcmx-4x_12em_psl-sphum4th-temp4th-vort4th_eqvt_random-season",
            deterministic=False,
        ),
    ]
}

# Labelled percentiles per source; each source gets its own fresh list of
# dicts. The last CPM entry picks the timestep shown in the schematic.
sample_percentiles = {
    source: [
        {"label": label, "percentile": pct}
        for label, pct in (("Wet", 0.8), ("Wettest", 1))
    ]
    for source in ("CPM", "GCM")
}

# TODO: replace this placeholder; it is rendered as Markdown in the next cell.
desc = "\nDescribe in more detail the models being compared\n"

In [None]:
# Wrap the free-text description in a Markdown object; because it is the
# last expression of the cell, the notebook renders it as the cell output.
IPython.display.Markdown(desc)

In [None]:
# Load the evaluation datasets and the per-source model mapping for the
# configured split, members and samples; both results are indexed by source
# (e.g. "CPM") further down.
EVAL_DS, MODELS = prep_eval_and_model_data(
    data_configs,
    split,
    ensemble_members=ensemble_members,
    samples_per_run=samples_per_run,
)
# Show the loaded evaluation data as the cell output.
EVAL_DS

## Figure: Emulator Schematic

In [None]:
# Pick the (ensemble_member, time) pair at the requested percentile of the CPM
# data ordered by mean precipitation.
# NOTE(review): a percentile of 1 ("Wettest") maps to index 0, so the helper
# presumably sorts by descending mean pr -- confirm.
mean_sorted_em_time = sorted_em_time_by_mean_pr(EVAL_DS["CPM"])
# Clamp to the last element so a percentile of 0 does not index past the end.
idx = min(
    math.ceil(len(mean_sorted_em_time) * (1 - sample_percentiles["CPM"][-1]["percentile"])),
    len(mean_sorted_em_time) - 1,
)
em_time = mean_sorted_em_time[idx]

# One timestep of one ensemble member, taken from the last model key listed
# in MODELS["CPM"].
plot_ds = (
    EVAL_DS["CPM"]
    .sel(time=em_time[1], ensemble_member=em_time[0])
    .sel(model=list(MODELS["CPM"].keys())[-1])
)
subplot_kw = dict(projection=cp_model_rotated_pole)

# Number of generated precipitation samples drawn in the output row.
nsamples = 3

# Level suffixes appended to multi-level variable names (e.g. "spechum850").
thetas = [850, 700, 500, 250]

# (variable-name prefix, level suffixes, panel title) for each input class.
possible_variables = [
    ("spechum", thetas, "Specific Humidity\n(multi-level)"),
    ("temp", thetas, "Temperature\n(multi-level)"),
    ("vorticity", thetas, "Vorticity\n(multi-level)"),
    ("psl", [""], "Sea-level pressure"),
]
# Keep only the classes with at least one variable present in plot_ds.
variables = [
    v
    for v in possible_variables
    if any(k.startswith(v[0]) for k in plot_ds.variables)
]

# Mosaic keys for every input panel, e.g. "temp850", ..., "psl".
full_variable_set = [
    f"{varclass}{level}" for varclass, levels, _title in variables for level in levels
]

# Layout in figure-fraction coordinates (as passed to Axes.set_position):
# each input panel is awidth square, and each successive level in a stack is
# shifted by offset_width.
awidth = 0.12
offset_width = 0.009
n_stacks = len(variables)
max_levels = max((len(levels) for _varclass, levels, _title in variables), default=1)
# Width of the widest stack, including all of its level offsets.
stacked_awidth = awidth + offset_width * (max_levels - 1)
# Spread the stacks across the full figure width; a single stack needs no gap
# (and would otherwise divide by zero).
gap = (1 - stacked_awidth * n_stacks) / (n_stacks - 1) if n_stacks > 1 else 0.0

# NOTE(review): every axes below is placed by hand with set_position, which
# matplotlib treats as opting out of layout management, so "constrained"
# likely has little effect here -- confirm.
fig = plt.figure(figsize=(5.5, 5.5), layout="constrained")

# A single mosaic row: all input panels, then each output sample, then the
# model box. All share the rotated-pole projection from subplot_kw.
mosaic_row = [*full_variable_set, *(f"pred_pr {i}" for i in range(nsamples)), "AI"]
axd = fig.subplot_mosaic([mosaic_row], subplot_kw=subplot_kw)

# Black box labelled "Diffusion" standing in for the emulator.
ax = axd["AI"]
ax.set_facecolor("black")
ax.text(
    0.5,
    0.5,
    "Diffusion",
    ha="center",
    va="center",
    color="white",
    weight="bold",
    transform=ax.transAxes,
)
ax.set_position([0.43, awidth + 0.15, 0.15, 0.15])

def _output_arrow(corner_x, connectionstyle=None):
    """Return annotate() kwargs for an arrow from the model box to one sample.

    The tail (``xytext``) is the bottom centre of the model box in its axes
    coordinates; the head (``xy``) is at ``(corner_x, 1)`` -- the top edge --
    of the sample map, in that map's axes coordinates.
    """
    props = {"facecolor": "black", "shrinkB": 5, "arrowstyle": "fancy"}
    if connectionstyle is not None:
        props["connectionstyle"] = connectionstyle
    return {"xy": (corner_x, 1), "xytext": (0.5, 0), "arrowprops": props}


# Left sample: arrow to its top-right corner curving one way; middle: straight
# to top centre; right: to its top-left corner curving the other way.
output_arrows = [
    _output_arrow(1, "arc3,rad=-0.2"),
    _output_arrow(0.5),
    _output_arrow(0, "arc3,rad=0.2"),
]

ai_ax = axd["AI"]
sample_stride = awidth + 0.03  # horizontal pitch between neighbouring samples
for sampleidx in range(nsamples):
    ax = axd[f"pred_pr {sampleidx}"]
    plot_map(plot_ds["pred_pr"].isel(sample_id=sampleidx), ax=ax, style="precip")
    # Sample 1 is centred at x = 0.505; the others step out by sample_stride.
    sample_left = (0.505 - awidth / 2) + (sampleidx - 1) * sample_stride
    ax.set_position([sample_left, awidth / 2, awidth, awidth])

    if sampleidx == 1:
        # Row caption, placed once under the middle sample.
        ax.text(
            0.5,
            -0.15,
            "High-resolution precipitation",
            fontsize="small",
            ha="center",
            va="center",
            transform=ax.transAxes,
        )

    ai_ax.annotate(
        "",
        xycoords=ax.transAxes,
        textcoords=ai_ax.transAxes,
        **output_arrows[sampleidx],
    )
def _input_arrow(xytext, **shrink):
    """Return annotate() kwargs for an arrow from an input stack to the model box.

    The head (``xy``) is the centre of the model box in its axes coordinates;
    ``xytext`` is the tail in the input panel's axes coordinates. ``shrink``
    carries the shrinkA/shrinkB trims (tail/head) for this arrow.
    """
    return dict(
        xy=(0.5, 0.5),
        xytext=xytext,
        arrowprops=dict(facecolor="black", arrowstyle="simple", **shrink),
    )


# One arrow per input-stack position, left to right.
arrows = [
    _input_arrow((0.5, 0.5), shrinkA=38, shrinkB=33),
    _input_arrow((0.85, -0.25), shrinkB=24),
    _input_arrow((0.5, -0.25), shrinkB=24),
    _input_arrow((0.5, 0.5), shrinkA=27, shrinkB=33),
]

# Draw each input variable class as a stack of per-level maps, positioned by
# hand, and connect the first map of each stack to the model box.
for vi, (varclass, levels, vartitle) in enumerate(variables):
    variable_set = [f"{varclass}{level}" for level in levels]

    for i, var in enumerate(variable_set):
        ax = axd[var]
        # Vorticity gets center=0 passed to plot_map (presumably to centre a
        # diverging colour scale on zero -- confirm against plot_map).
        var_plot_kwargs = {"center": 0} if varclass == "vorticity" else {}
        plot_map(plot_ds[var], ax=ax, style=None, **var_plot_kwargs)

        # Each successive level is shifted right and down by offset_width so
        # the levels read as a stack. set_position takes [left, bottom, width,
        # height], so `top` here is the panel's bottom edge.
        left = (stacked_awidth + gap) * vi + offset_width * i
        top = 0.5 - offset_width * i
        if vi in (0, len(variables) - 1):
            # The outermost stacks sit lower than the inner ones.
            top -= 0.1
        ax.set_position([left, top, awidth, awidth])

        if i == 0:
            # Title and the arrow into the model box go on the first
            # (top-left) map of each stack only.
            ax.set_title(vartitle, fontsize="small")
            axd["AI"].annotate(
                "",
                xycoords=axd["AI"].transAxes,
                textcoords=ax.transAxes,
                **arrows[vi],
            )