# Evaluation of a selection of models on 60km -> 2.2km-4x over Birmingham

In [None]:
%reload_ext autoreload

%autoreload 2

%reload_ext dotenv
%dotenv

import functools
import math
import string

import IPython
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

from mlde_notebooks.data import prep_eval_and_model_data
from mlde_notebooks import plot_map, distribution_figure, scatter_plots
from mlde_notebooks.distribution import normalized_mean_bias, normalized_std_bias, plot_freq_density, plot_mean_biases, plot_std_biases, rms_mean_bias, rms_std_bias
from mlde_notebooks.psd import plot_psd, pysteps_rapsd
from mlde_notebooks.uncertainty import plot_spread_error
from mlde_notebooks.wet_dry import wet_prop_stats, wet_day_prop, wet_day_prop_error, wet_day_prop_change, plot_wet_dry_errors
from mlde_utils import cp_model_rotated_pole, TIME_PERIODS

In [None]:
import matplotlib

# Raise the default resolution of every figure created below to 300 dpi.
matplotlib.rcParams.update({"figure.dpi": 300})

In [None]:
# Dataset split handed to prep_eval_and_model_data.
split = "val"

# Ensemble member ids are zero-padded two-digit strings; 02, 03 and 14 are not part of the list.
ensemble_members = [f"{member:02d}" for member in (1, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 15)]
samples_per_run = 3

# Values that appear in more than one model configuration below.
_DIFFUSION_MODEL_ID = "score-sde/subvpsde/xarray_12em_cncsnpp_continuous/bham-4x_12em_PslS4T4V4_random-season-IstanTsqrturrecen-no-loc-spec"
_CCPM_DATASET = "bham_gcmx-4x_12em_psl-sphum4th-temp4th-vort4th_eqvt_random-season"

# Model configurations grouped by the source of their inputs ("CPM" or "GCM").
# Each entry names a model and checkpoint, the dataset and input transform it is
# paired with, and how it is labelled, coloured and ordered.
data_configs = {
    "CPM": [
        dict(
            fq_model_id=_DIFFUSION_MODEL_ID,
            checkpoint="epoch-20",
            input_xfm="bham_gcmx-4x_12em_psl-sphum4th-temp4th-vort4th_eqvt_random-season-stan",
            label="Diffusion-cCPM",
            dataset=_CCPM_DATASET,
            deterministic=False,
            PSD=True,
            color="blue",
            order=10,
        ),
        dict(
            fq_model_id="u-net/bham-4x_12em_PslS4T4V4_random-season-IstanTsqrturrecen",
            checkpoint="epoch-100",
            input_xfm="stan",
            label="U-Net-cCPM",
            dataset=_CCPM_DATASET,
            deterministic=True,
            PSD=True,
            color="orange",
            order=1,
        ),
        dict(
            fq_model_id="id-linpr",
            checkpoint="epoch-0",
            input_xfm="none",
            label="cCPM Bilinear",
            deterministic=True,
            dataset="bham_gcmx-4x_12em_linpr_eqvt_random-season",
            color="dimgrey",
            order=0,
        ),
    ],
    "GCM": [
        dict(
            fq_model_id=_DIFFUSION_MODEL_ID,
            checkpoint="epoch-20",
            input_xfm="bham_60km-4x_12em_psl-sphum4th-temp4th-vort4th_eqvt_random-season-pixelmmsstan",
            label="Diffusion-GCM",
            dataset="bham_60km-4x_12em_psl-sphum4th-temp4th-vort4th_eqvt_random-season",
            deterministic=False,
            CCS=True,
            color="green",
            order=20,
        ),
    ],
}

# Placeholder text, rendered as Markdown by the following cell.
desc = "\nDescribe in more detail the models being compared\n"

In [None]:
# Wrap the description text in a Markdown display object; as the only (and
# therefore last) expression of the cell it becomes the cell's rendered output.
IPython.display.Markdown(desc)

In [None]:
# Load the evaluation data and the samples of every configured model.
# `merged_ds` is indexed by source name and `MODELS` by source name then model
# label (that is how both are used in the cells below).
merged_ds, MODELS = prep_eval_and_model_data(
    data_configs,
    split,
    ensemble_members=ensemble_members,
    samples_per_run=samples_per_run,
)

# Show the loaded data as the cell output.
merged_ds

In [None]:
# Reference precipitation: the "target_pr" variable of the "CPM" entry. The
# cells below pass it as the comparison target for the models' "pred_pr".
cpm_pr = merged_ds["CPM"]["target_pr"]

## Figure: distribution

* Frequency Density Histogram of rainfall intensities
* Maps of Mean bias ($\frac{\mu_{sample}-\mu_{CPM}}{\mu_{CPM}}$) over all samples, time and ensemble members
* Maps of Std Dev bias ($\frac{\sigma_{sample}-\sigma_{CPM}}{\sigma_{CPM}}$) over all samples, time and ensemble members

In [None]:
# Flatten the {source: {model label: spec}} mapping once so the three lists
# below are built in the same model order.
model_entries = [
    (src, name, cfg)
    for src, per_source in MODELS.items()
    for name, cfg in per_source.items()
]


def _pred_pr(src, name):
    """Return the "pred_pr" variable of source ``src`` selected to model ``name``."""
    return merged_ds[src]["pred_pr"].sel(model=name)


# Inputs for the frequency density histogram (one coloured line per model).
hist_data = [
    dict(data=_pred_pr(src, name), label=name, color=cfg["color"])
    for src, name, cfg in model_entries
]

# Per-model mean bias relative to the CPM target.
mean_biases = [
    dict(data=normalized_mean_bias(_pred_pr(src, name), cpm_pr), label=name)
    for src, name, _ in model_entries
]

# Per-model standard deviation bias relative to the CPM target.
std_biases = [
    dict(data=normalized_std_bias(_pred_pr(src, name), cpm_pr), label=name)
    for src, name, _ in model_entries
]

In [None]:
fig = plt.figure(layout='constrained', figsize=(5.5, 6.5))

# One column per model label, in MODELS order.
labels = [label for mconfigs in MODELS.values() for label in mconfigs.keys()]

# Axes keys for the two rows of bias maps.
meanb_axes_keys = [f"meanb {label}" for label in labels]
stddevb_axes_keys = [f"stddevb {label}" for label in labels]

# Each row of the mosaic is a (1, n_models) array; the histogram row repeats
# the same key so that it spans every column.
meanb_spec = np.array([meanb_axes_keys])
stddevb_spec = np.array([stddevb_axes_keys])
dist_spec = np.array([["Density"] * meanb_spec.shape[1]])

spec = np.vstack([dist_spec, meanb_spec, stddevb_spec])

axd = fig.subplot_mosaic(
    spec,
    gridspec_kw={"height_ratios": [3, 2, 2]},
    # Only the map panels use the rotated pole projection.
    per_subplot_kw={
        key: {"projection": cp_model_rotated_pole}
        for key in (*meanb_axes_keys, *stddevb_axes_keys)
    },
)


def _letter_panel(target_ax, letter):
    """Put a bold panel letter at the left of the figure, level with the top of ``target_ax``."""
    target_ax.annotate(
        letter,
        xy=(0.04, 1.0),
        xycoords=("figure fraction", "axes fraction"),
        weight='bold',
        ha="left",
        va="bottom",
    )


# Panel a: frequency density histogram.
ax = axd["Density"]
plot_freq_density(hist_data, ax=ax, target_da=cpm_pr)
_letter_panel(ax, "a.")

# Panel b: mean bias maps.
axes = plot_mean_biases(mean_biases, axd)
_letter_panel(axes[0], "b.")

# Panel c: standard deviation bias maps.
axes = plot_std_biases(std_biases, axd)
_letter_panel(axes[0], "c.")

fig.savefig("cpm-gcm-distribution.pdf", format="pdf", bbox_inches='tight')

In [None]:
def _rms_bias_by_model(bias_fn, normalize):
    """Apply ``bias_fn`` to each model's "pred_pr" in every source and join along "model".

    ``bias_fn`` is ``rms_mean_bias`` or ``rms_std_bias``; it is called once per
    model group with ``cpm_pr`` as the reference and the given ``normalize`` flag.
    """
    per_model_fn = functools.partial(bias_fn, cpm_pr=cpm_pr, normalize=normalize)
    return xr.concat(
        [
            merged_ds[source]["pred_pr"].groupby("model", squeeze=False).map(per_model_fn)
            for source in merged_ds.keys()
        ],
        dim="model",
    )


def _show_bias_table(bias_da, caption):
    """Display ``bias_da`` as an HTML table rounded to 2 d.p., with ``caption`` as the column name."""
    IPython.display.display_html(bias_da.rename(caption).to_dataframe().round(2).to_html(), raw=True)


# Absolute RMS biases. Built once and reused for the first relative tables
# below rather than being rebuilt from scratch.
abs_rms_mean_biases = _rms_bias_by_model(rms_mean_bias, normalize=False)
abs_rms_std_biases = _rms_bias_by_model(rms_std_bias, normalize=False)

_show_bias_table(abs_rms_mean_biases, "Root Mean Square Mean Bias (mm/day)")
_show_bias_table(abs_rms_std_biases, "Root Mean Square Std Dev Bias (mm/day)")

# Absolute RMS biases as a percentage of the overall CPM mean / std dev.
_show_bias_table(100 * abs_rms_mean_biases / cpm_pr.mean(), "Relative Root Mean Square Mean Bias (%)")
_show_bias_table(100 * abs_rms_std_biases / cpm_pr.std(), "Relative Root Mean Square Std Dev Bias (%)")

# RMS biases with the normalisation done inside the helpers (normalize=True).
# NOTE(review): these tables carry the same captions as the pair above although
# they are computed differently - consider distinguishing the captions.
rms_mean_biases = _rms_bias_by_model(rms_mean_bias, normalize=True)
rms_std_biases = _rms_bias_by_model(rms_std_bias, normalize=True)

_show_bias_table(rms_mean_biases, "Relative Root Mean Square Mean Bias (%)")
_show_bias_table(rms_std_biases, "Relative Root Mean Square Std Dev Bias (%)")

## Figure: spread

* Mean precip scatter: sample vs CPM
* Spread-error
  * https://journals.ametsoc.org/view/journals/hydr/15/4/jhm-d-14-0008_1.xml?tab_body=fulltext-display
  * https://journals.ametsoc.org/view/journals/aies/2/2/AIES-D-22-0061.1.xml
  * https://www.sciencedirect.com/science/article/pii/S0021999107000812

In [None]:
# Two side-by-side sub-figures: scatter plots on the left, spread-error on the right.
fig_width = 5.5
fig = plt.figure(layout='constrained', figsize=(fig_width, fig_width*(2/3.0)))
scatter_fig, ss_fig = fig.subfigures(1, 2, width_ratios=[2,1.075])

source = "CPM"

# Domain mean precipitation per model, also averaged over samples; the
# "cCPM Bilinear" model is dropped when present.
source_ds_without_bilinear = merged_ds[source].drop_sel(model="cCPM Bilinear", errors="ignore")
domain_mean_cpm_ds = source_ds_without_bilinear[["target_pr", "pred_pr"]].mean(
    dim=["grid_latitude", "grid_longitude", "sample_id"]
)

axd = scatter_plots(domain_mean_cpm_ds, fig=scatter_fig, line_props=MODELS[source])

# Only models not flagged as deterministic go into the spread-error panel.
stochastic_labels = [
    model_label
    for model_label, model_cfg in MODELS[source].items()
    if not model_cfg["deterministic"]
]
cpm_ds = merged_ds[source].sel(model=stochastic_labels)

axd = ss_fig.subplot_mosaic([["Spread-Error"]])
ax = axd["Spread-Error"]
plot_spread_error(cpm_ds, ax, MODELS[source])
ax.annotate(
    "c.",
    xy=(0, 1.05),
    xycoords=("axes fraction", "axes fraction"),
    weight='bold',
    ha="left",
    va="bottom",
)

## Figure: structure

* PSD

In [None]:
def _mean_rapsd(pr_da, example_dims):
    """Return the RAPSD of ``pr_da`` averaged over all examples, without the freq=0 entry.

    ``example_dims`` are stacked into a single leading "example" dimension
    before ``pysteps_rapsd`` is called.
    """
    stacked = pr_da.stack(example=example_dims).transpose("example", "grid_latitude", "grid_longitude")
    return pysteps_rapsd(stacked, pixel_size=8.8).mean(dim="example").drop_sel(freq=0)


# Reference spectrum from the CPM target precipitation.
cpm_hr_rapsd = _mean_rapsd(merged_ds["CPM"]["target_pr"], ["ensemble_member", "time"])

# One spectrum per model; predictions have an extra "sample_id" dimension.
# NOTE(review): data_configs sets a "PSD" flag on some models but every model
# is plotted here - confirm whether the flag was meant to filter this list.
pred_rapsds = [
    dict(
        label=model_label,
        color=model_cfg["color"],
        data=_mean_rapsd(
            merged_ds[src]["pred_pr"].sel(model=model_label),
            ["ensemble_member", "sample_id", "time"],
        ),
    )
    for src, per_source in MODELS.items()
    for model_label, model_cfg in per_source.items()
]

structure_fig = plt.figure(figsize=(5.5, 3.5), layout="constrained")

axd = structure_fig.subplot_mosaic([["PSD"]])
ax = axd["PSD"]

plot_psd(cpm_hr_rapsd, pred_rapsds, ax=ax)

plt.show()

## Figure: subdomain

In [None]:
fd_fig = plt.figure(figsize=(5.5, 4.5), layout="constrained")

# Index slices of the two sub-regions examined.
subregions = dict(
    SE=dict(grid_latitude=slice(10, 26), grid_longitude=slice(38, 54)),
    NW=dict(grid_latitude=slice(44, 60), grid_longitude=slice(18, 34)),
)

human_names = dict(DJF="Winter", JJA="Summer")

axd = fd_fig.subplot_mosaic([["NW"], ["SE"]], sharex=True)

# One histogram panel per (sub-region, season) pair.
for i, (srname, season) in enumerate([("NW", "DJF"), ("SE", "JJA")]):
    region_slices = subregions[srname]
    season_mask = cpm_pr["time"]["time.season"] == season

    srseason_cpm_pr = cpm_pr.isel(**region_slices).sel(time=season_mask)

    srseason_hist_data = [
        dict(
            data=merged_ds[src]["pred_pr"].sel(model=model_label).isel(**region_slices).sel(time=season_mask),
            label=model_label,
            color=model_cfg["color"],
        )
        for src, per_source in MODELS.items()
        for model_label, model_cfg in per_source.items()
    ]

    ax = axd[srname]
    plot_freq_density(srseason_hist_data, ax=ax, target_da=srseason_cpm_pr)
    ax.set_title(f"{srname} {human_names[season]}")

    if i == 0:
        # Top panel only: drop its x label and legend, and add an inset map
        # showing where the two sub-regions are.
        ax.set_xlabel("")
        ax.get_legend().remove()
        reg_ax = fd_fig.add_axes([0.8, 0.8, 0.2, 0.2], projection=cp_model_rotated_pole)
        nw_cpm_pr = cpm_pr.isel(**subregions["NW"])
        se_cpm_pr = cpm_pr.isel(**subregions["SE"])

        # Fill each sub-region with a constant value (10 for NW, 1 for SE).
        for fill_value, region_pr in ((10, nw_cpm_pr), (1, se_cpm_pr)):
            plot_map(
                fill_value*xr.ones_like(region_pr.isel(ensemble_member=0, time=0)),
                ax=reg_ax,
                style="precip",
                cl_kwargs=dict(alpha=0.2),
            )

        # Label each sub-region at the mean of its coordinates.
        for region_label, region_pr in (("NW", nw_cpm_pr), ("SE", se_cpm_pr)):
            reg_ax.annotate(
                region_label,
                xy=(region_pr.grid_longitude.mean().values.item(), region_pr.grid_latitude.mean().values.item()),
                xycoords="data",
                fontsize="medium",
                ha="center",
                va="center",
            )

        reg_ax.set_extent([-2, 3, -2.5, 2.5], crs=cp_model_rotated_pole)

    # Panel letter ("a.", "b.") at the left of the figure.
    panel_letter = f"{string.ascii_lowercase[i]}."
    ax.annotate(
        panel_letter,
        xy=(0.02, 1.0),
        xycoords=("figure fraction", "axes fraction"),
        weight='bold',
        ha="left",
        va="bottom",
    )

## Figure: Wet/dry

In [None]:
# Precipitation thresholds used to classify a day as wet (the displays below
# label them as mm/day).
WET_DAY_THRESHOLDS = [0.1]

# Predicted precipitation of every model, in MODELS order.
model_pr_das = [
    merged_ds[src]["pred_pr"].sel(model=model_label)
    for src, per_source in MODELS.items()
    for model_label in per_source
]

# Wet day statistics of the models against the CPM target, keyed by threshold.
wet_day_stats = {
    wet_threshold: wet_prop_stats(model_pr_das, cpm_pr, wet_threshold)
    for wet_threshold in WET_DAY_THRESHOLDS
}

### Wet day prop: domain mean

In [None]:
# One HTML table per threshold: the wet day statistics averaged over the grid,
# styled inline so that several tables sit side by side.
dfs = [
    (
        wet_day_stats[threshold]
        .mean(dim=["grid_latitude", "grid_longitude"])
        .to_dataframe()
        .style.set_table_attributes("style='display:inline'")
        .set_caption(f"Threshold: {threshold}mm/day")
        .format(precision=1)
        .to_html()
    )
    for threshold in WET_DAY_THRESHOLDS
]

# str.join concatenates the fragments in one pass and, unlike reduce without an
# initial value, also copes with an empty list of thresholds.
IPython.display.display_html("".join(dfs), raw=True)

### Wet day prop: grid box

In [None]:
# For each threshold: a Markdown sub-heading followed by the wet/dry error plot.
for threshold in WET_DAY_THRESHOLDS:
    threshold_heading = f"#### Threshold: {threshold}mm/day"
    IPython.display.display_markdown(threshold_heading, raw=True)

    threshold_stats = wet_day_stats[threshold]
    plot_wet_dry_errors(threshold_stats)

    plt.show()

### CCS Wet day prop: grid box

In [None]:
# Same statistics as before but computed with the wet_day_prop_change
# statistic, keyed by threshold.
wet_day_change_stats = {
    wet_threshold: wet_prop_stats(
        model_pr_das,
        cpm_pr,
        threshold=wet_threshold,
        wet_prop_statistic=wet_day_prop_change,
    )
    for wet_threshold in WET_DAY_THRESHOLDS
}

# For each threshold: a Markdown sub-heading followed by the plot of the
# "change in % wet day" variable.
for threshold in WET_DAY_THRESHOLDS:
    threshold_heading = f"#### Threshold: {threshold}mm/day"
    IPython.display.display_markdown(threshold_heading, raw=True)

    change_in_wet_days = wet_day_change_stats[threshold]["change in % wet day"]
    plot_wet_dry_errors(change_in_wet_days, style="change")

    plt.show()

### CCS Wet day prop: domain mean

In [None]:
# One HTML table per threshold: the wet day change statistics averaged over the
# grid, styled inline so that several tables sit side by side.
dfs = [
    (
        wet_day_change_stats[threshold]
        .mean(dim=["grid_latitude", "grid_longitude"])
        .to_dataframe()
        .style.set_table_attributes("style='display:inline'")
        .set_caption(f"Threshold: {threshold}mm/day")
        .format(precision=1)
        .to_html()
    )
    for threshold in WET_DAY_THRESHOLDS
]

# str.join concatenates the fragments in one pass and, unlike reduce without an
# initial value, also copes with an empty list of thresholds.
IPython.display.display_html("".join(dfs), raw=True)