# Evaluation of a selection of models on 60km -> 2.2km-4x over Birmingham

In [None]:
%reload_ext autoreload

%autoreload 2

%reload_ext dotenv
%dotenv

import functools
from importlib.resources import files
import math
import os
import string

import cftime
import iris
import iris.analysis.cartography
import IPython
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pysteps
import scipy
import xarray as xr

import mlde_utils
from mlde_notebooks import plot_map, distribution_figure, plot_mean_bias, plot_std_bias, scatter_plots, freq_density_plot
from mlde_utils import cp_model_rotated_pole, TIME_PERIODS, platecarree, dataset_split_path
from mlde_notebooks.data import prep_eval_data, open_concat_sample_dataarrays, si_to_mmday
from mlde_notebooks import create_map_fig, STYLES
from mlde_notebooks.distribution import normalized_mean_bias, normalized_std_bias, plot_freq_density, plot_mean_biases, plot_std_biases
# from mlde_notebooks.wet_dry import wet_dry_dataframe, wet_dry_ratio, wet_dry_ratio_error, plot_wet_dry_errors

# Render xarray objects as rich HTML in the notebook output.
xr.set_options(display_style="html")

In [None]:
import matplotlib

# Raise the default resolution of every subsequent figure to 300 dpi.
matplotlib.rcParams.update({"figure.dpi": 300})

In [None]:
# Dataset split to evaluate on.
split = "test"

# Ensemble members to evaluate: 01, 04-13 and 15 (02, 03 and 14 are not used).
ensemble_members = [f"{member:02d}" for member in (1, *range(4, 14), 15)]

# Number of generated samples loaded per run.
samples_per_run = 6

# Pre-computed emulator samples driven by 60km inputs; one config per model.
precomp_60km_samples_data_configs = [
    dict(
        fq_model_id="score-sde/subvpsde/xarray_12em_cncsnpp_continuous/bham-4x_12em_PslS4T4V4_random-season-IstanTsqrturrecen-no-loc-spec/postprocess/gcm-grid",
        checkpoint="epoch-20",
        input_xfm="pixelmmsstan",
        label="Emul(GCM)@60km",
        dataset="bham_60km-4x_12em_psl-sphum4th-temp4th-vort4th_eqvt_random-season",
        deterministic=False,
        CCS=True,
        color="lightgreen",
        order=20,
    ),
]

In [None]:
# Load the pre-computed emulator samples for every configured model and keep
# index 0 along the "model" dimension (only one config is listed above).
precomp_60km_da = open_concat_sample_dataarrays(
    precomp_60km_samples_data_configs,
    split,
    ensemble_members=ensemble_members,
    samples_per_run=samples_per_run,
).isel(model=0)
precomp_60km_da

In [None]:
# Raw GCM precipitation ("pr") from the 60km-60km rawpr dataset for this
# split, passed through si_to_mmday for unit conversion.
raw_gcm_pr = si_to_mmday(
    xr.open_dataset(
        dataset_split_path("bham_60km-60km_12em_rawpr_eqvt_random-season", split)
    )["pr"]
)
raw_gcm_pr

In [None]:
# CPM precipitation ("pr") from the gcmx-60km dataset for this split, i.e.
# the target on the coarse grid, passed through si_to_mmday.
cpm_pr_on_gcm = si_to_mmday(
    xr.open_dataset(
        dataset_split_path("bham_gcmx-60km_12em_pr_eqvt_random-season", split)
    )["pr"]
)
cpm_pr_on_gcm

### Frequency distribution, and bias in mean and std. dev., on the coarse (60km) grid

In [None]:
# Precipitation sources to compare against the CPM target on the 60km grid.
hist_data = [
    {"data": precomp_60km_da, "label": "Diffusion-GCM@60km", "color": "green"},
    {"data": raw_gcm_pr, "label": "GCM", "color": "magenta"},
]

# Normalized bias of each source's mean relative to the CPM target.
mean_biases = [
    {"data": normalized_mean_bias(entry["data"], cpm_pr_on_gcm), "label": entry["label"]}
    for entry in hist_data
]

# Normalized bias of each source's standard deviation relative to the CPM target.
std_biases = [
    {"data": normalized_std_bias(entry["data"], cpm_pr_on_gcm), "label": entry["label"]}
    for entry in hist_data
]

In [None]:
# Composite figure:
#   (a) frequency density of every source vs the CPM target on the 60km grid,
#   (b) one normalized mean-bias map per source,
#   (c) one normalized std-dev-bias map per source.
fig = plt.figure(layout="constrained", figsize=(5.5, 6.5))

# One mosaic column per source.
# NOTE(review): these keys are assumed to be the ones plot_mean_biases /
# plot_std_biases look up in ``axd`` - confirm before renaming them.
meanb_axes_keys = [f"meanb {mb['label']}" for mb in mean_biases]
meanb_spec = np.array(meanb_axes_keys).reshape(1, -1)

stddevb_axes_keys = [f"stddevb {sb['label']}" for sb in std_biases]
stddevb_spec = np.array(stddevb_axes_keys).reshape(1, -1)

# Repeating one key across the whole top row makes the density panel full-width.
density_axes_keys = ["density"]
density_spec = np.array(density_axes_keys * meanb_spec.shape[1]).reshape(1, -1)

# np.concatenate raises if the rows differ in length, so the bias rows must
# have one entry per source.
spec = np.concatenate([density_spec, meanb_spec, stddevb_spec], axis=0)

axd = fig.subplot_mosaic(
    spec,
    gridspec_kw=dict(height_ratios=[3, 2, 2]),
    # Only the bias maps get the rotated-pole projection; the density panel
    # stays a default rectilinear axes.
    per_subplot_kw={
        ak: {"projection": cp_model_rotated_pole}
        for ak in meanb_axes_keys + stddevb_axes_keys
    },
)

# Panel letters: x is in figure fraction, so every letter sits at the same
# left margin; y is in axes fraction, so it is level with the panel's top edge.
panel_label_kwargs = dict(
    xy=(0.04, 1.0),
    xycoords=("figure fraction", "axes fraction"),
    weight="bold",
    ha="left",
    va="bottom",
)

ax = axd["density"]
plot_freq_density(hist_data, target_da=cpm_pr_on_gcm, ax=ax, target_label="CPM@60km")
ax.annotate("a.", **panel_label_kwargs)

# The bias helpers presumably return the axes of their row (first = leftmost);
# the row letter goes on the first one.
axes = plot_mean_biases(mean_biases, axd, plot_map_kwargs={"transform": platecarree})
axes[0].annotate("b.", **panel_label_kwargs)

axes = plot_std_biases(std_biases, axd, plot_map_kwargs={"transform": platecarree})
axes[0].annotate("c.", **panel_label_kwargs)

In [None]:
# One standalone density panel per source, laid out two per row; each panel
# compares that single source with the CPM target on the 60km grid.
density_axes_keys = [f"density {hd['label']}" for hd in hist_data]
if len(density_axes_keys) % 2 == 1:
    # Pad to an even count so the keys fill whole rows of two. "." is
    # subplot_mosaic's default empty_sentinel, so that cell is left blank.
    density_axes_keys.append(".")
density_spec = np.array(density_axes_keys).reshape(-1, 2)

# 5 inches of figure height per row of panels.
fig = plt.figure(figsize=(10, 5 * density_spec.shape[0]))

axd = fig.subplot_mosaic(density_spec)

for hd in hist_data:
    ax = axd[f"density {hd['label']}"]
    plot_freq_density([hd], target_da=cpm_pr_on_gcm, ax=ax, target_label="CPM@60km")