# Evaluation of a selection of models on 60km -> 2.2km-4x over Birmingham in a low data scenario

In [None]:
%reload_ext autoreload

%autoreload 2

%reload_ext dotenv
%dotenv

import functools
import math
import string

import cftime
import IPython
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

from mlde_notebooks.data import prep_eval_and_model_data
from mlde_notebooks import plot_map, precip_cmap, precip_norm
from mlde_notebooks.psd import plot_psd, pysteps_rapsd
from mlde_notebooks.distribution import normalized_mean_bias, normalized_std_bias, plot_freq_density, plot_mean_biases, plot_std_biases, rms_mean_bias, rms_std_bias
from mlde_utils import cp_model_rotated_pole

In [None]:
import matplotlib

# Render every figure in this notebook at 300 dpi.
matplotlib.rcParams.update({"figure.dpi": 300})

In [None]:
# Evaluation setup: data split, ensemble members to evaluate and samples per model run.
split = "test"
ensemble_members = ["01"]
samples_per_run = 3

# The low-data ("_ld") diffusion model appears under both sources below with the same model id and
# checkpoint; only the input dataset, input transform, label, colour and order differ.
LOW_DATA_MODEL_ID = "score-sde/subvpsde/xarray_cncsnpp_continuous/bham-4x_1em_PslS4T4V4_random-season-historic-IstanTsqrturrecen-no-loc-spec"
LOW_DATA_CHECKPOINT = "epoch-300"

# Source ("CPM" / "GCM") -> list of model configs handed to prep_eval_and_model_data.
data_configs = dict(
    CPM=[
        dict(
            fq_model_id="score-sde/subvpsde/xarray_12em_cncsnpp_continuous/bham-4x_12em_PslS4T4V4_random-season-IstanTsqrturrecen-no-loc-spec",
            checkpoint="epoch-20",
            input_xfm="stan",
            label="Diffusion-cCPM",
            dataset="bham_gcmx-4x_12em_psl-sphum4th-temp4th-vort4th_eqvt_random-season-historic",
            deterministic=False,
            PSD=True,
            CCS=False,
            color="blue",
            order=10,
        ),
        dict(
            fq_model_id=LOW_DATA_MODEL_ID,
            checkpoint=LOW_DATA_CHECKPOINT,
            input_xfm="stan",
            label="Diffusion_ld-cCPM",
            dataset="bham_gcmx-4x_1em_psl-sphum4th-temp4th-vort4th_eqvt_random-season-historic",
            deterministic=False,
            PSD=True,
            CCS=False,
            color="purple",
            order=11,
        ),
    ],
    GCM=[
        dict(
            fq_model_id=LOW_DATA_MODEL_ID,
            checkpoint=LOW_DATA_CHECKPOINT,
            input_xfm="pixelmmsstan",
            label="Diffusion_ld-GCM",
            dataset="bham_60km-4x_1em_psl-sphum4th-temp4th-vort4th_eqvt_random-season-historic",
            deterministic=False,
            PSD=True,
            CCS=False,
            color="darkgreen",
            order=21,
        ),
    ],
)

# TODO: replace this placeholder with a real description of the compared models.
desc = """
Describe in more detail the models being compared
"""

In [None]:
# Render the free-text description of the compared models as Markdown (last expression = cell output).
IPython.display.Markdown(data=desc)

In [None]:
# Load targets and model predictions for the configured split/ensemble members.
# `merged_ds` is indexed by source (e.g. "CPM"); `MODELS` maps source -> {model label: spec}.
merged_ds, MODELS = prep_eval_and_model_data(
    data_configs,
    split,
    ensemble_members=ensemble_members,
    samples_per_run=samples_per_run,
)
merged_ds

In [None]:
# CPM target precipitation: the reference field the model predictions are compared against below.
cpm_ds = merged_ds["CPM"]
cpm_pr = cpm_ds["target_pr"]

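## Figure: distribution

The figure below shows (a) the frequency density histogram of rainfall intensities, (b) example CPM target and sampled precipitation maps, and (c) the radially averaged power spectral density (RAPSD).

Also computed for each model (as `mean_biases` and `std_biases`; not plotted in this figure):

* Maps of mean bias ($\frac{\mu_{sample}-\mu_{CPM}}{\mu_{CPM}}$) over all samples, time and ensemble members
* Maps of std dev bias ($\frac{\sigma_{sample}}{\sigma_{CPM}}$) over all samples, time and ensemble members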

In [None]:
# Grid spacing (km) passed to the RAPSD: the 2.2km CPM grid coarsened 4x ("2.2km-4x") -> 8.8km pixels.
RAPSD_PIXEL_SIZE_KM = 8.8


def _mean_rapsd(pr_da, example_dims):
    """Example-averaged radially averaged power spectral density of a precipitation field.

    Parameters
    ----------
    pr_da : xarray.DataArray
        Precipitation with `grid_latitude`/`grid_longitude` dims plus every dim in `example_dims`.
    example_dims : list of str
        Dims flattened into a single `example` dim and averaged over.

    Returns
    -------
    xarray.DataArray
        Mean RAPSD over the examples, with the zero-frequency (`freq=0`) bin dropped.
    """
    examples = pr_da.stack(example=list(example_dims)).transpose("example", "grid_latitude", "grid_longitude")
    return pysteps_rapsd(examples, pixel_size=RAPSD_PIXEL_SIZE_KM).mean(dim="example").drop_sel(freq=0)


# Per-model predictions for the frequency density histogram.
hist_data = [
    dict(data=merged_ds[spec["source"]]["pred_pr"].sel(model=model), label=model, color=spec["color"])
    for mconfigs in MODELS.values()
    for model, spec in mconfigs.items()
]

# Per-model normalized mean / std-dev bias fields against the CPM target.
mean_biases = [
    dict(data=normalized_mean_bias(sample_pr.squeeze("model"), cpm_pr), label=model)
    for source in merged_ds.keys()
    for model, sample_pr in merged_ds[source]["pred_pr"].groupby("model", squeeze=False)
]

std_biases = [
    dict(data=normalized_std_bias(sample_pr.squeeze("model"), cpm_pr), label=model)
    for source in merged_ds.keys()
    for model, sample_pr in merged_ds[source]["pred_pr"].groupby("model", squeeze=False)
]

# Target RAPSD: average over every (ensemble member, time) field; reuses `cpm_pr` defined above.
cpm_hr_rapsd = _mean_rapsd(cpm_pr, ["ensemble_member", "time"])

# Per-model RAPSD: additionally average over the generated samples.
pred_rapsds = [
    {
        "label": model,
        "color": spec["color"],
        "data": _mean_rapsd(
            merged_ds[spec["source"]]["pred_pr"].sel(model=model),
            ["ensemble_member", "sample_id", "time"],
        ),
    }
    for mconfigs in MODELS.values()
    for model, spec in mconfigs.items()
]

In [None]:
fig = plt.figure(layout='constrained', figsize=(5.5, 6.5))

# Number of generated samples shown per model, next to the CPM target.
samples_to_show = 2

# Row label -> (ensemble member, timestamp) of the example field to plot.
em_ts = {
    # ("01", "2029-12-04"),
    "": ("01", cftime.Datetime360Day(1993, 8, 1, 12)),
    # "JJA Wettest": ("01", "2073-07-30")
}

models = ["Diffusion_ld-cCPM"]

# Which merged dataset holds each model's predictions (the same `spec["source"]` lookup used for the
# histogram and RAPSD data), rather than assuming every plotted model lives in merged_ds["CPM"].
model_sources = {model: spec["source"] for mconfigs in MODELS.values() for model, spec in mconfigs.items()}

# One mosaic row per (model, example): the CPM target followed by `samples_to_show` samples.
sample_axes_keys = np.array([
    [f"CPM {em_ts_key}"] + [f"{model} {em_ts_key} sample {sidx}" for sidx in range(samples_to_show)]
    for model in models
    for em_ts_key in em_ts.keys()
])
sample_spec = sample_axes_keys.reshape(len(models) * len(em_ts), -1)

n_cols = sample_spec.shape[1]
dist_spec = np.array(["Density"] * n_cols).reshape(1, -1)
rapsd_spec = np.array(["RAPSD"] * n_cols).reshape(1, -1)

spec = np.concatenate([dist_spec, sample_spec, rapsd_spec], axis=0)

# Density and RAPSD panels are taller than the map rows; one ratio per map row so the
# mosaic stays valid when more models or example timestamps are added.
height_ratios = [3] + [2] * sample_spec.shape[0] + [3]

axd = fig.subplot_mosaic(
    spec,
    gridspec_kw=dict(height_ratios=height_ratios),
    per_subplot_kw={ak: {"projection": cp_model_rotated_pole} for ak in sample_axes_keys.flat},
)

ax = axd["Density"]

plot_freq_density(hist_data, target_da=cpm_pr, ax=ax)
ax.annotate("a.", xy=(0.04, 1.0), xycoords=("figure fraction", "axes fraction"), weight='bold', ha="left", va="bottom")

for model in models:
    model_ds = merged_ds[model_sources[model]]
    # `row_label` (not `desc`) so this loop does not overwrite the notebook-level model description.
    for row_label, (em, ts) in em_ts.items():
        print(f"Precip from EM{em} on {ts}")

        # CPM target at the chosen ensemble member and nearest timestamp.
        ax = axd[f"CPM {row_label}"]
        target_pr = merged_ds["CPM"].sel(ensemble_member=em).sel(time=ts, method="nearest")["target_pr"]
        plot_map(
            target_pr,
            ax,
            cmap=precip_cmap,
            norm=precip_norm,
            add_colorbar=False,
        )
        ax.set_title("CPM", fontsize="medium")
        # label row
        ax.text(
            -0.1,
            0.5,
            row_label,
            transform=ax.transAxes,
            ha="right",
            va="center",
            fontsize="medium",
            rotation=90,
        )
        # annotate row with identifier
        ax.annotate("b.", xy=(0.04, 1.0), xycoords=("figure fraction", "axes fraction"), weight='bold', ha="left", va="bottom")

        # Selection does not depend on the sample index, so do it once per row.
        ts_ds = model_ds.sel(ensemble_member=em).sel(time=ts, method="nearest")
        for sidx in range(samples_to_show):
            ax = axd[f"{model} {row_label} sample {sidx}"]
            plot_map(
                ts_ds.sel(model=model).isel(sample_id=sidx)["pred_pr"],
                ax,
                cmap=precip_cmap,
                norm=precip_norm,
                add_colorbar=False,
            )
            ax.set_title(f"Sample {sidx+1}", fontsize="medium")

ax = axd["RAPSD"]
plot_psd(cpm_hr_rapsd, pred_rapsds, ax=ax, legend_kwargs={"fontsize": "small"})
# Trailing `;` suppresses the Annotation repr; the figure is still rendered by the inline backend.
ax.annotate("c.", xy=(0.04, 1.0), xycoords=("figure fraction", "axes fraction"), weight='bold', ha="left", va="bottom");

In [None]:
def rmmb(pred_pr):
    """Return `rms_mean_bias` of `pred_pr` against the CPM target precipitation `cpm_pr`."""
    return rms_mean_bias(pred_pr, cpm_pr)


def rmsb(pred_pr):
    """Return `rms_std_bias` of `pred_pr` against the CPM target precipitation `cpm_pr`."""
    return rms_std_bias(pred_pr, cpm_pr)


def _metric_per_model(metric):
    """Apply `metric` to each model's `pred_pr` in every merged source and join the results along `model`."""
    per_source = [
        merged_ds[source]["pred_pr"].groupby("model", squeeze=False).map(metric)
        for source in merged_ds.keys()
    ]
    return xr.concat(per_source, dim="model")


rms_mean_biases = _metric_per_model(rmmb)
rms_std_biases = _metric_per_model(rmsb)

# One HTML table per metric, values rounded to 2 decimal places.
for metric_da, column_name in (
    (rms_mean_biases, "Root Mean Square Mean Bias (mm/day)"),
    (rms_std_biases, "Root Mean Square Std Dev Bias (mm/day)"),
):
    table_html = metric_da.rename(column_name).to_dataframe().round(2).to_html()
    IPython.display.display_html(table_html, raw=True)