# Evaluation of a selection of models on 60km -> 2.2km-4x over Birmingham

In [None]:
%reload_ext autoreload

%autoreload 2

%reload_ext dotenv
%dotenv

import functools
import math
import string

import IPython
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

from mlde_analysis.data import prep_eval_data
from mlde_analysis import plot_map
from mlde_analysis.ccs import compute_changes, plot_changes, plot_tp_fd, plot_hist_per_tp
from mlde_analysis.distribution import normalized_mean_bias, normalized_std_bias, plot_freq_density, plot_mean_biases, plot_std_biases, rms_mean_bias, rms_std_bias
from mlde_analysis.wet_dry import threshold_exceeded_prop_stats, threshold_exceeded_prop_change, plot_threshold_exceedence_errors, THRESHOLDS, wd_mean, wd_mean_change, wd_mean_bias
from mlde_analysis.display import pretty_table
from mlde_analysis import display
from mlde_utils import cp_model_rotated_pole, TIME_PERIODS

In [None]:
# Render every figure in this notebook at 300 dpi.
matplotlib.rcParams.update({"figure.dpi": 300})

In [None]:
from mlde_analysis.default_params import *

In [None]:
# Display the notebook description as Markdown (last expression of the cell).
# NOTE(review): `desc` is not defined in this notebook; presumably it comes from
# the `mlde_analysis.default_params` star import above — confirm.
IPython.display.Markdown(desc)

In [None]:
# Load the evaluation data. EVAL_DS maps each source (e.g. "CPM") to its dataset;
# MODELS maps each source to {model label: spec} (see how both are indexed below).
EVAL_DS, MODELS = prep_eval_data(
    sample_configs,
    dataset_configs,
    derived_variables_config,
    eval_vars,
    split,
    ensemble_members=ensemble_members,
    samples_per_run=samples_per_run,
)
# Bare expression so Jupyter renders the loaded datasets.
EVAL_DS

In [None]:
# CPM target data array for each evaluation variable.
CPM_DAS = {
    var: EVAL_DS["CPM"][f"target_{var}"]
    for var in eval_vars
}

# Flatten MODELS ({source: {model label: spec}}) into {model label: spec},
# tagging each spec with the source it came from under "source".
MODELLABEL2SPEC = {
    label: {"source": src} | model_spec
    for src, source_models in MODELS.items()
    for label, model_spec in source_models.items()
}

# Each variable's predictions from every source, concatenated along "model".
PRED_DAS = {
    var: xr.concat(
        [source_ds[f"pred_{var}"] for source_ds in EVAL_DS.values()],
        dim="model",
    )
    for var in eval_vars
}

In [None]:
# Three-month seasons reported in the CCS (change) analysis.
CCS_SEASONS = [
    "DJF",  # December-January-February
    "MAM",  # March-April-May
    "JJA",  # June-July-August
    "SON",  # September-October-November
]

In [None]:
# Models whose spec has a truthy "CCS" entry, grouped by source like MODELS.
CCS_MODELS = {
    source: {model: spec for model, spec in mconfigs.items() if spec.get("CCS", False)}
    for source, mconfigs in MODELS.items()
}

# Flat, ordered list of CCS model labels. It is derived once and reused by both
# selections below so they cannot drift apart.
CCS_MODEL_LABELS = [model for models in CCS_MODELS.values() for model in models]

# Predictions restricted to the CCS models.
CCS_PRED_DAS = {var: PRED_DAS[var].sel(model=CCS_MODEL_LABELS) for var in eval_vars}

# CPM targets plus CCS predictions in one dataset. The CCS models are re-selected
# (in CCS_MODEL_LABELS order) after combining.
CCS_DS = xr.combine_by_coords(
    [*CPM_DAS.values(), *CCS_PRED_DAS.values()]
).sel(model=CCS_MODEL_LABELS)

# Threshold proportion change

### CCS threshold exceedance frequency change: gridbox

In [None]:
# Per-gridbox change in % of threshold exceedance, keyed by variable and then by
# threshold. Only variables listed in eval_vars are included.
# NOTE(review): `change_stats` is rebound by the wet-day mean intensity cell below.
# Re-run this cell before re-running the domain-mean table cell that follows it.
change_stats = {
    var: {
        threshold: threshold_exceeded_prop_stats(
            CCS_DS[f"pred_{var}"],
            CPM_DAS[var],
            threshold=threshold,
            threshold_exceeded_prop_statistic=threshold_exceeded_prop_change,
        )
        for threshold in thresholds
    }
    for var, thresholds in THRESHOLDS.items()
    if var in eval_vars
}

# Map each threshold's change field. The loop iterates the stats computed above
# directly, so the plots always match what was calculated.
for var, var_stats in change_stats.items():
    for threshold, stats in var_stats.items():
        IPython.display.display_markdown(f"#### Threshold: {threshold}mm/day", raw=True)

        plot_threshold_exceedence_errors(stats["change in % threshold exceeded"], style="change")

        plt.show()

### CCS threshold exceedance frequency change: domain mean

In [None]:
# Domain mean (over grid_latitude and grid_longitude) of `change_stats`, with one
# table per evaluated variable and threshold. This shows the threshold-exceedance
# change when the previous cell was the last to bind `change_stats`.
for var, thresholds in THRESHOLDS.items():
    if var in eval_vars:
        for threshold in thresholds:
            pretty_table(
                change_stats[var][threshold].mean(
                    dim=["grid_latitude", "grid_longitude"]
                ),
                dim_order=["season", "model"],
                caption=f"{var} threshold: {threshold}{display.ATTRS[var]['units']}",
            )

### CCS wet-day mean intensity change: gridbox

In [None]:
# Prediction arrays for every configured model (not only CCS models), keyed by
# variable. Only variables in both THRESHOLDS and eval_vars are included.
# NOTE(review): `model_das` is not used in the visible cells. Confirm it is needed
# further down the notebook before removing it.
model_das = {
    var: [
        EVAL_DS[source][f"pred_{var}"].sel(model=model)
        for source, mconfigs in MODELS.items()
        for model in mconfigs
    ]
    for var in THRESHOLDS
    if var in eval_vars
}

# Per-gridbox wet-day mean change (via wd_mean_change), keyed by variable and then
# by threshold.
# NOTE(review): this rebinds `change_stats`, which previously held the
# threshold-exceedance change stats. Cells that expect the earlier values must be
# re-run after that earlier cell.
change_stats = {
    var: {
        threshold: threshold_exceeded_prop_stats(
            CCS_DS[f"pred_{var}"],
            CPM_DAS[var],
            threshold=threshold,
            threshold_exceeded_prop_statistic=wd_mean_change,
        )
        for threshold in thresholds
    }
    for var, thresholds in THRESHOLDS.items()
    if var in eval_vars
}


# Map each threshold's wet-day mean change, iterating the stats computed above.
for var, var_stats in change_stats.items():
    for threshold, stats in var_stats.items():
        IPython.display.display_markdown(f"#### Threshold: {threshold}mm/day", raw=True)

        plot_threshold_exceedence_errors(stats["Change in wd mean (mm/day)"], style="change")

        plt.show()

In [None]:
# Domain mean (over grid_latitude and grid_longitude) of `change_stats` from the
# previous cell (wet-day mean change). Each table gets singleton "threshold" and
# "variable" dimensions so they appear in the table layout.
for var, thresholds in THRESHOLDS.items():
    if var in eval_vars:
        for threshold in thresholds:
            pretty_table(
                change_stats[var][threshold]
                .mean(dim=["grid_latitude", "grid_longitude"])
                .expand_dims({"threshold": [threshold], "variable": [var]}),
                dim_order=["variable", "threshold", "season", "model"],
                caption=f"{var} threshold: {threshold}{display.ATTRS[var]['units']}",
            )