# Evaluation of a selection of models on 60km -> 2.2km-4x over Birmingham

In [None]:
%reload_ext autoreload

%autoreload 2

%reload_ext dotenv
%dotenv
    
import math
import os

import cftime
import iris
import iris.analysis.cartography
import IPython
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pysteps
import scipy
import xarray as xr

from mlde_analysis import plot_map, plot_examples, distribution_figure, plot_mean_bias, plot_std_bias, scatter_plots, seasonal_distribution_figure, compute_gridspec, freq_density_plot
from mlde_analysis import plot_psd
from mlde_utils import cp_model_rotated_pole
from mlde_utils.utils import prep_eval_data
from mlde_analysis import create_map_fig, qq_plot, STYLES

In [None]:
def reasonable_quantiles(da):
    """Return quantile levels that grow denser towards 1, scaled to ``da.size``.

    For each decade ``10**i`` (``i = -1, -2, ...``) nine evenly spaced levels
    are produced: 0.1..0.9, then 0.91..0.99, then 0.991..0.999 and so on. The
    last decade used is ``int(log10(1 / da.size))``, and a final ``1.0`` is
    always appended. That decade limit is also printed.

    Args:
        da: Any object exposing a ``size`` attribute (number of elements).

    Returns:
        A 1-D numpy array of quantile levels ending in ``1.0``.
    """
    limit = int(np.log10(1 / da.size))
    print(limit)

    per_decade_levels = []
    for exponent in range(-1, limit - 1, -1):
        # Start one step above the previous decade's top level and stop one
        # step short of 1, so consecutive decades never overlap.
        lowest = (1 - 10 ** (exponent + 1)) + (10**exponent)
        highest = 1 - 10**exponent
        per_decade_levels.append(np.linspace(lowest, highest, 9))

    # The maximum is always included as the final quantile.
    per_decade_levels.append([1.0])
    return np.concatenate(per_decade_levels)

In [None]:
# Three named time windows on the 360-day calendar. Each one runs from
# 1 December (12:00) of its first year to 30 November (12:00) twenty years on.
time_slices = {
    slice_name: (
        cftime.Datetime360Day(first_year, 12, 1, 12, 0, 0, 0, has_year_zero=True),
        cftime.Datetime360Day(first_year + 20, 11, 30, 12, 0, 0, 0, has_year_zero=True),
    )
    for slice_name, first_year in (
        ("historic", 1980),
        ("present", 2020),
        ("future", 2060),
    )
}

In [None]:
# Parameters
# Notebook-level settings. `split`, `ensemble_members` and `samples_per_run`
# are passed unchanged to prep_eval_data for every source in `data_configs`.
split = "val"
# Identifiers of the 12 ensemble members to load.
ensemble_members = [
    "01",
    "04",
    "05",
    "06",
    "07",
    "08",
    "09",
    "10",
    "11",
    "12",
    "13",
    "15",
]
samples_per_run = 3
# Model configurations grouped by input source ("CPM", "GCM"). Each source maps
# to a list of per-model dicts; the whole list is handed to prep_eval_data.
# Both sources list the same two models ("Bilinear" and "Diffusion"), differing
# in the dataset name and, for the diffusion model, in `input_xfm`.
data_configs = {
    "CPM": [
        {
            "fq_model_id": "id-linpr",
            "checkpoint": "epoch-0",
            "input_xfm": "none",
            "label": "Bilinear",
            "deterministic": True,
            "dataset": "bham_gcmx-4x_12em_linpr_eqvt_random-season",
        },
        {
            "fq_model_id": "score-sde/subvpsde/xarray_cncsnpp_continuous/bham-4x_12em_PslT4V4_random-season-IstanTsqrturrecen",
            "checkpoint": "epoch-20",
            "input_xfm": "stan",
            "label": "Diffusion",
            "dataset": "bham_gcmx-4x_12em_psl-temp4th-vort4th_eqvt_random-season",
            "deterministic": False,
        },
    ],
    "GCM": [
        {
            "fq_model_id": "id-linpr",
            "checkpoint": "epoch-0",
            "input_xfm": "none",
            "label": "Bilinear",
            "deterministic": True,
            "dataset": "bham_60km-4x_12em_linpr_eqvt_random-season",
        },
        {
            "fq_model_id": "score-sde/subvpsde/xarray_cncsnpp_continuous/bham-4x_12em_PslT4V4_random-season-IstanTsqrturrecen",
            "checkpoint": "epoch-20",
            "input_xfm": "pixelmmsstan",
            "label": "Diffusion",
            "dataset": "bham_60km-4x_12em_psl-temp4th-vort4th_eqvt_random-season",
            "deterministic": False,
        },
    ],
}
# NOTE(review): "Diffusion 12em" does not match any "label" in data_configs
# above (those are "Bilinear" and "Diffusion") - confirm the intended label.
highlighted_cpm_models = ["Diffusion 12em"]
# Same dataset name as the GCM "Bilinear" entry in data_configs.
gcm_lr_lin_pr_dataset = "bham_60km-4x_12em_linpr_eqvt_random-season"
# Same dataset name as the CPM "Bilinear" entry in data_configs.
cpm_hr_pr_dataset = "bham_gcmx-4x_12em_linpr_eqvt_random-season"
# Markdown description of this evaluation; rendered in the following cell.
# NOTE(review): the text mentions U-net models, but no U-net entry is
# configured in data_configs - confirm the description is up to date.
desc = "Multi-ensemble member models\nSplits are based on random choice of seasons with equal number of seasons from each time slice\n\nCompare:\n\n### Diffusion models\n* PslT4V4 IstanTsqrturrecen (without variables at 925 hPa)\n\n### Lo-res precip:\n* id-linpr\n\n## Diff models and U-net models\n\n8-channels loc-spec params (diff models only)\n\nInputs from: pressure at sea level and 5 levels of temp and vorticity\n\nTarget domain and resolution: 64x64 2.2km-4x England and Wales\n\nInput resolution: 60km/gcmx\n\nInput transforms are fitted on dataset in use (ie separate GCM and CPM versions) while target transform is fitted only at training on the CPM dataset\n"


In [None]:
# Render the markdown description from the parameters cell as the cell output.
IPython.display.Markdown(desc)

In [None]:
# Load evaluation data per source: each source's list of model configs goes
# through prep_eval_data with the shared split, ensemble members and
# samples-per-run settings. The result is keyed by source name.
merged_ds = {
    src: prep_eval_data(
        model_configs,
        split,
        ensemble_members=ensemble_members,
        samples_per_run=samples_per_run,
    )
    for src, model_configs in data_configs.items()
}
# Bare expression so the notebook displays the resulting mapping.
merged_ds

## Pixel distribution

In [None]:
# "target_pr" variable of the CPM-sourced evaluation data; used as the
# reference for the distribution plots.
cpm_pr = merged_ds["CPM"]["target_pr"]
# Quantile levels whose depth into the upper tail depends on cpm_pr.size.
quantiles = reasonable_quantiles(cpm_pr)

### GCM

In [None]:
# Plot distributions against the CPM target, with quantiles taken over every
# dimension listed in quantile_dims.
# NOTE(review): this cell sits under the "### GCM" heading but passes
# merged_ds["CPM"] - confirm whether merged_ds["GCM"] was intended here.
distribution_figure(
    merged_ds["CPM"],
    cpm_pr,
    quantiles,
    quantile_dims=[
        "ensemble_member",
        "time",
        "grid_latitude",
        "grid_longitude",
    ],
)