# Evaluation of distribution of models on 60km -> 2.2km-4x over Birmingham

In [None]:
# Load (or reload) IPython's autoreload extension; mode 2 re-imports all
# modules before each cell runs, so edits to project code are picked up
# without restarting the kernel.
%reload_ext autoreload

%autoreload 2

# Load (or reload) the dotenv extension and run its %dotenv magic
# (presumably reads environment variables from a local .env file — confirm).
%reload_ext dotenv
%dotenv

import functools
import math
import string

import IPython
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

from mlde_notebooks.data import prep_eval_data
from mlde_notebooks import plot_map
from mlde_notebooks.display import pretty_table, VAR_RANGES
from mlde_notebooks.distribution import mean_bias, std_bias, plot_freq_density, plot_mean_biases, plot_std_biases, rms_mean_bias, rms_std_bias, xr_hist, hist_dist, plot_distribution_figure, compute_metrics, DIST_THRESHOLDS
from mlde_notebooks.wet_dry import threshold_exceeded_prop_stats, threshold_exceeded_prop, threshold_exceeded_prop_error, threshold_exceeded_prop_change, plot_threshold_exceedence_errors, THRESHOLDS
from mlde_utils import cp_model_rotated_pole
from mlde_notebooks import qq_plot, reasonable_quantiles

In [None]:
# Set the default figure resolution to 300 dpi for every figure created below.
matplotlib.rcParams['figure.dpi'] = 300

In [None]:
from mlde_notebooks.default_params import *

In [None]:
# Render `desc` as Markdown as this cell's output. `desc` is not defined in
# this notebook — presumably it comes from the default_params star import.
IPython.display.Markdown(desc)

In [None]:
# Build the evaluation datasets and the model specifications used by every
# later cell. The arguments are not defined in this notebook — presumably
# they come from the default_params star import (confirm).
EVAL_DS, MODELS = prep_eval_data(
    sample_configs,
    dataset_configs,
    derived_variables_config,
    eval_vars,
    split,
    ensemble_members=ensemble_members,
    samples_per_run=samples_per_run,
)
# Bare expression kept last so the notebook renders it as the cell output.
EVAL_DS

In [None]:
# CPM target for each evaluated variable, renamed from "target_<var>" to
# "cpm_<var>" so it can sit next to the predictions in one dataset.
CPM_DAS = {
    var: EVAL_DS["CPM"][f"target_{var}"].rename(f"cpm_{var}")
    for var in eval_vars
}

# One merged dataset per variable holding "pred_<var>" from every source plus
# the matching "cpm_<var>" target.
# Each source contributes its prediction array exactly once. Iterating over the
# models of a source here would only hand xr.merge the same array repeatedly.
# Sources with no models are skipped.
VAR_DAS = {
    var: xr.merge(
        [
            EVAL_DS[source][f"pred_{var}"]
            for source, models in MODELS.items()
            if models
        ]
        + [CPM_DAS[var]]
    )
    for var in eval_vars
}

# Flat lookup from model label to its spec, with the owning source added under
# the "source" key (a "source" key already in the spec takes precedence).
MODELLABEL2SPEC = {
    model: {"source": source} | spec
    for source, models in MODELS.items()
    for model, spec in models.items()
}

## Figures: Seasonal distribution

* Frequency density histogram of each evaluated variable (e.g. rainfall intensities)
* Maps of Mean bias ($\frac{\mu_{sample}-\mu_{CPM}}{\mu_{CPM}}$) over all samples, time and ensemble members
* Std Dev Bias $\frac{\sigma_{sample}}{\sigma_{CPM}}$ over all samples, time and ensemble members

In [None]:
# For every evaluated variable and season, draw the distribution figure twice:
# once with a log-scaled and once with a linear-scaled density axis.
for var in eval_vars:
    IPython.display.display_markdown(f"### {var}", raw=True)
    # Biases are requested in normalized form only for "pr".
    normalize = var == "pr"

    for season, season_ds in VAR_DAS[var].groupby("time.season"):
        IPython.display.display_markdown(f"#### {season}", raw=True)
        hist_das = season_ds[f"pred_{var}"]
        cpm_da = season_ds[f"cpm_{var}"]

        # Per-model biases of the predictions against the CPM target.
        mean_biases = hist_das.groupby("model").map(
            mean_bias, cpm_da=cpm_da, normalize=normalize
        )
        std_biases = hist_das.groupby("model").map(
            std_bias, cpm_da=cpm_da, normalize=normalize
        )

        bias_kwargs = {"style": f"{var}Bias"}
        for yscale in ("log", "linear"):
            fd_kwargs = {"yscale": yscale}
            fig = plt.figure(layout="constrained", figsize=(5.5, 6.5))
            axd = plot_distribution_figure(
                fig,
                hist_das,
                cpm_da,
                mean_biases,
                std_biases,
                MODELLABEL2SPEC,
                hrange=VAR_RANGES[var],
                fd_kwargs=fd_kwargs,
                bias_kwargs=bias_kwargs,
            )
            if var == "relhum150cm":
                # Dashed black reference line at x=100 on the density panel.
                axd["Density"].axvline(x=100, color='k', linestyle='--', linewidth=1)

        # Show both figures for this season together.
        plt.show()

## Seasonal RMS biases and J-S Distances

In [None]:
# For every evaluated variable and season, compute the distribution metrics of
# the predictions against the CPM target and tabulate them.
for var in eval_vars:
    IPython.display.display_markdown(f"### {var}", raw=True)
    pred_name = f"pred_{var}"
    cpm_name = f"cpm_{var}"

    for season, season_ds in VAR_DAS[var].groupby("time.season"):
        IPython.display.display_markdown(f"#### {season}", raw=True)

        metrics_ds, thshd_exceedence_ds = compute_metrics(
            season_ds[pred_name],
            season_ds[cpm_name],
            thresholds=DIST_THRESHOLDS[var],
        )

        # Threshold-exceedance table first (rounded to 8 places), then the
        # remaining metrics (rounded to 4 places).
        pretty_table(thshd_exceedence_ds, round=8)
        pretty_table(metrics_ds, round=4)

## Seasonal QQ plots

In [None]:
# Dimensions pooled together when computing quantiles.
quantile_dims = ["ensemble_member", "time", "grid_latitude", "grid_longitude"]

# One figure per variable: a 2x2 grid of QQ plots, one panel per season,
# comparing predicted quantiles against CPM quantiles.
for var in eval_vars:
    IPython.display.display_markdown(f"### {var}", raw=True)

    fig = plt.figure(layout='constrained', figsize=(5.5, 5.5))
    axd = fig.subplot_mosaic([["DJF", "MAM"], ["JJA", "SON"]])

    for season, season_ds in VAR_DAS[var].groupby("time.season"):
        season_cpm_da = season_ds[f"cpm_{var}"]
        season_pred_da = season_ds[f"pred_{var}"]

        # The quantile levels are chosen from the CPM data and then applied to
        # both the CPM and predicted arrays.
        quantiles = reasonable_quantiles(season_cpm_da)
        season_cpm_quantiles = season_cpm_da.quantile(
            quantiles, dim=quantile_dims
        ).rename("target_q")
        season_pred_quantiles = season_pred_da.quantile(
            quantiles, dim=quantile_dims
        ).rename("pred_q")

        # Axis labels are built from each array's attributes.
        cpm_label = xr.plot.utils.label_from_attrs(da=season_cpm_da)
        pred_label = xr.plot.utils.label_from_attrs(da=season_pred_da)
        xlabel = f"CPM \n{cpm_label}"
        ylabel = f"Predicted \n{pred_label}"

        qq_plot(
            axd[season],
            season_cpm_quantiles,
            season_pred_quantiles,
            title=season,
            xlabel=xlabel,
            ylabel=ylabel,
        )

    plt.show()