# Samples from models on 60km -> 2.2km-4x over Birmingham

In [None]:
%reload_ext autoreload

%autoreload 2

%reload_ext dotenv
%dotenv

import IPython
import matplotlib.pyplot as plt

from mlde_notebooks.data import prep_eval_and_model_data
from mlde_notebooks.samples import show_samples, em_timestamps

In [None]:
import matplotlib

# Raise the default figure resolution to 300 dpi for every figure created
# after this cell runs.
matplotlib.rc("figure", dpi=300)

In [None]:
# Dataset split passed to prep_eval_and_model_data.
split = "val"

# Ensemble member ids as zero-padded two-digit strings ("01", "04", ...).
# Note the sequence is not contiguous: 02, 03 and 14 are absent.
ensemble_members = [
    f"{member:02d}" for member in (1, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 15)
]

# Forwarded as the samples_per_run argument of prep_eval_and_model_data.
samples_per_run = 3
# Model runs to load, grouped by source ("CPM" / "GCM"). Each entry names a
# model, the checkpoint to use, the input transform, the dataset it was run
# on, and presentation settings (label, colour, ordering).
data_configs = {
    "CPM": [
        dict(
            fq_model_id="score-sde/subvpsde/xarray_12em_cncsnpp_continuous/bham-4x_12em_PslS4T4V4_random-season-IstanTsqrturrecen-no-loc-spec",
            checkpoint="epoch-20",
            input_xfm="stan",
            label="Diffusion (cCPM)",
            dataset="bham_gcmx-4x_12em_psl-sphum4th-temp4th-vort4th_eqvt_random-season",
            deterministic=False,
            PSD=True,
            color="blue",
            order=10,
        ),
        dict(
            fq_model_id="u-net/bham-4x_12em_PslS4T4V4_random-season-IstanTsqrturrecen",
            checkpoint="epoch-100",
            input_xfm="stan",
            label="U-Net (cCPM)",
            dataset="bham_gcmx-4x_12em_psl-sphum4th-temp4th-vort4th_eqvt_random-season",
            deterministic=True,
            PSD=True,
            color="orange",
            order=1,
        ),
        # Bilinear baseline: no input transform and a different (linpr) dataset.
        dict(
            fq_model_id="id-linpr",
            checkpoint="epoch-0",
            input_xfm="none",
            label="Bilinear cCPM",
            deterministic=True,
            dataset="bham_gcmx-4x_12em_linpr_eqvt_random-season",
            color="red",
            order=0,
        ),
    ],
    "GCM": [
        # Same diffusion model and checkpoint as the first CPM entry, but run
        # on the 60km dataset with the "pixelmmsstan" input transform.
        dict(
            fq_model_id="score-sde/subvpsde/xarray_12em_cncsnpp_continuous/bham-4x_12em_PslS4T4V4_random-season-IstanTsqrturrecen-no-loc-spec",
            checkpoint="epoch-20",
            input_xfm="pixelmmsstan",
            label="Diffusion (GCM)",
            dataset="bham_60km-4x_12em_psl-sphum4th-temp4th-vort4th_eqvt_random-season",
            deterministic=False,
            CCS=True,
            color="green",
            order=20,
        ),
    ],
}

# Percentiles used to pick example timestamps for each source. Both sources
# use the same two selections; each source gets its own list object.
sample_percentiles = {
    source: [
        {"label": "Wet", "percentile": 0.8},
        {"label": "Wettest", "percentile": 1},
    ]
    for source in ("CPM", "GCM")
}

# Per-source overrides handed to em_timestamps; none are set here.
sample_overrides = {source: {} for source in ("CPM", "GCM")}

# Free-text description rendered as Markdown by the next cell.
desc = """
Describe in more detail the models being compared
"""

In [None]:
# Wrap the free-text description in a Markdown display object; as the cell's
# final expression it is what the notebook shows as this cell's output.
IPython.display.Markdown(desc)

In [None]:
# Load the evaluation data and the model data described by data_configs,
# passing through the chosen split, ensemble members and samples-per-run.
# Both results are indexed by source name in the cells that follow.
EVAL_DS, MODELS = prep_eval_and_model_data(
    data_configs,
    split,
    ensemble_members=ensemble_members,
    samples_per_run=samples_per_run,
)

# Final expression of the cell, so the notebook displays it.
EVAL_DS

In [None]:
# Seasons passed to em_timestamps when selecting examples.
seasons = ["DJF", "JJA"]

# Result of em_timestamps for each source; read back by the plotting cell.
examples_to_plot = {}

for source in EVAL_DS.keys():
    # Emit a section heading for this source before selecting its examples.
    IPython.display.display_html(f"<h2>{source} Samples</h2>", raw=True)

    examples_to_plot[source] = em_timestamps(
        EVAL_DS[source],
        seasons=seasons,
        percentiles=sample_percentiles[source],
        overrides=sample_overrides[source],
    )

In [None]:
# Draw one figure of samples per source.
for source in EVAL_DS.keys():
    source_models = MODELS[source]

    # Width grows with the number of models for this source (plus three)
    # and is capped at 5.5; height is the number of percentile selections
    # times the number of seasons.
    width = min(len(source_models) + 3, 5.5)
    height = len(sample_percentiles[source]) * len(seasons)

    fig = plt.figure(layout="constrained", figsize=(width, height))
    show_samples(
        EVAL_DS[source],
        examples_to_plot[source],
        models=source_models,
        fig=fig,
        sim_title=source,
    )
    plt.show()