# Samples from models on 60km -> 2.2km-4x over Birmingham

In [None]:
%reload_ext autoreload

%autoreload 2

%reload_ext dotenv
%dotenv

import IPython
import matplotlib
import matplotlib.pyplot as plt

from mlde_analysis.data import prep_eval_data
from mlde_analysis.examples import em_timestamps
from mlde_analysis.perspective_paper import pp_plot_examples

In [None]:
# Use 300 DPI for every figure created in this notebook.
matplotlib.rcParams["figure.dpi"] = 300

In [None]:
# Parameters

# Which split and which ensemble members are passed to prep_eval_data.
split = "test"
ensemble_members = [
    "01", "04", "05", "06", "07", "08",
    "09", "10", "11", "12", "13", "15",
]

# Dataset name for each input source.
dataset_configs = {
    "CPM": "bham64_ccpm-4x_12em_psl-sphum4th-temp4th-vort4th_pr",
    "GCM": "bham64_gcm-4x_12em_psl-sphum4th-temp4th-vort4th_pr",
}

samples_per_run = 6

# Model sample configurations, grouped by input source.
# Both sources use the same fq_model_id and checkpoint; they differ in
# input transform, dataset, and presentation settings.
sample_configs = {
    "CPM": [
        {
            "label": "Diffusion (cCPM)",
            "sample_specs": [
                {
                    "fq_model_id": "score-sde/subvpsde/xarray_12em_cncsnpp_continuous/bham-4x_12em_PslS4T4V4_random-season-IstanTsqrturrecen-no-loc-spec",
                    "checkpoint": "epoch_20",
                    "input_xfm": "bham64_ccpm-4x_12em_psl-sphum4th-temp4th-vort4th_pr-stan",
                    "dataset": "bham64_ccpm-4x_12em_psl-sphum4th-temp4th-vort4th_pr",
                    "variables": ["pr"],
                },
            ],
            "deterministic": False,
            "CCS": True,
            "color": "tab:blue",
            "order": 10,
        },
    ],
    "GCM": [
        {
            "label": "Diffusion (GCM)",
            "sample_specs": [
                {
                    "fq_model_id": "score-sde/subvpsde/xarray_12em_cncsnpp_continuous/bham-4x_12em_PslS4T4V4_random-season-IstanTsqrturrecen-no-loc-spec",
                    "checkpoint": "epoch_20",
                    "input_xfm": "bham64_gcm-4x_12em_psl-sphum4th-temp4th-vort4th_pr-pixelmmsstan",
                    "dataset": "bham64_gcm-4x_12em_psl-sphum4th-temp4th-vort4th_pr",
                    "variables": ["pr"],
                },
            ],
            "deterministic": False,
            "CCS": True,
            "PSD": True,
            "UQ": False,
            "color": "tab:cyan",
            "order": 100,
        },
    ],
}

# Percentile-based example selection per source: seasonal median and
# 0.9974 percentile of "pr" for DJF and JJA.
example_percentiles = {
    "CPM": {
        "DJF Median": {"percentile": 0.5, "variable": "pr", "season": "DJF"},
        "DJF Annual max": {"percentile": 0.9974, "variable": "pr", "season": "DJF"},
        "JJA Median": {"percentile": 0.5, "variable": "pr", "season": "JJA"},
        "JJA Annual max": {"percentile": 0.9974, "variable": "pr", "season": "JJA"},
    },
    "GCM": {
        "DJF Median": {"percentile": 0.5, "variable": "pr", "season": "DJF"},
        "DJF Annual max": {"percentile": 0.9974, "variable": "pr", "season": "DJF"},
        "JJA Median": {"percentile": 0.5, "variable": "pr", "season": "JJA"},
        "JJA Annual max": {"percentile": 0.9974, "variable": "pr", "season": "JJA"},
    },
}
# No per-source overrides of the selected examples.
example_overrides = {"CPM": {}, "GCM": {}}

n_samples_per_example = 2
derived_variables_config = {}
eval_vars = ["pr"]

## Data

* Uses all 12 ensemble members over the 1981-2000, 2021-2040 and 2061-2080 periods for the initial UKCP Local release (but using data from after the graupel bug fix)
* Splits are based on random choice of seasons with equal number of seasons from each time slice
* Target domain and resolution: 64x64@8.8km (4x 2.2km) England and Wales
* Input resolution: 60km (cCPM is CPM coarsened to GCM 60km grid)

## CPMGEM models

Compare:

* cCPM input source
* GCM with bias correction input source

### Shared specs

* Input variables (unless otherwise stated): pSTV (pressure at sea level and 4 levels of specific humidity, air temp and relative vorticity)
* Input transforms are fitted on the dataset in use (i.e. separate GCM and CPM versions), while the target transform is fitted only at training time on the CPM dataset
* No loc-spec params
* 6 samples per example

In [None]:
# prep_eval_data returns a pair; both EVAL_DS and MODELS are indexed by
# source name ("CPM"/"GCM") in the cells below.
EVAL_DS, MODELS = prep_eval_data(
    sample_configs,
    dataset_configs,
    derived_variables_config,
    eval_vars,
    split,
    ensemble_members=ensemble_members,
    samples_per_run=samples_per_run,
)
# Last expression of the cell, so the notebook displays it.
EVAL_DS

In [None]:
# Restricted to CPM here; example_percentiles.keys() would cover every source.
SOURCES = ["CPM"]

# Map each selected source to the examples em_timestamps returns for it,
# given that source's percentile definitions and overrides.
examples_to_plot = {
    src: em_timestamps(
        EVAL_DS[src],
        percentiles=example_percentiles[src],
        overrides=example_overrides[src],
    )
    for src in SOURCES
}

In [None]:
# Draw one figure of example samples per source, under an HTML heading.
for source, examples in examples_to_plot.items():
    IPython.display.display_html(f"<h2>{source} Samples</h2>", raw=True)

    # Figure size in inches: width grows with the number of models for this
    # source, height with the number of examples.
    # (A capped width, min(2 + len(models) + 1, 5.5), was previously computed
    # here and immediately overwritten; that dead assignment is removed.)
    fig_width = 2 + 2 * len(MODELS[source])
    fig_height = 1.1 * len(examples) + 1

    fig = plt.figure(layout="constrained", figsize=(fig_width, fig_height))
    pp_plot_examples(
        EVAL_DS[source],
        examples,
        vars=eval_vars,
        models=MODELS[source],
        fig=fig,
        sim_title=source,
        n_samples_per_example=n_samples_per_example,
    )
    plt.show()