# Evaluation of the distributions produced by models on 60km -> 2.2km-4x over Birmingham

In [None]:
%reload_ext autoreload

%autoreload 2

%reload_ext dotenv
%dotenv

import functools
import itertools
import math
import string

import IPython
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import scipy
import seaborn as sns
import xarray as xr
import xskillscore as xss

from mlde_utils import cp_model_rotated_pole

from mlde_analysis.data import prep_eval_data
from mlde_analysis.display import pretty_table, VAR_RANGES
from mlde_analysis.mv_distribution import compute_hist2d, plot_hist2d_figure

In [None]:
# Render every figure in this notebook at high resolution.
matplotlib.rcParams.update({"figure.dpi": 300})

In [None]:
from mlde_analysis.default_params import *
# Variables evaluated in this notebook. This assignment comes after the
# default_params star import, so it presumably overrides any default there.
eval_vars = [
    "pr",
    "relhum150cm",
    "tmean150cm",
]

In [None]:
# Render the run description `desc` as Markdown as this cell's output.
# NOTE(review): `desc` is not defined in this notebook; presumably it comes
# from the `mlde_analysis.default_params` star import — confirm.
IPython.display.Markdown(desc)

In [None]:
# Load the evaluation data. EVAL_DS is indexed by source (e.g. EVAL_DS["CPM"])
# and is used by the cells below. The configuration names passed here are
# presumably supplied by the default_params star import — confirm.
EVAL_DS, MODELS = prep_eval_data(
    sample_configs,
    dataset_configs,
    derived_variables_config,
    eval_vars,
    split,
    ensemble_members=ensemble_members,
    samples_per_run=samples_per_run,
)
EVAL_DS

In [None]:
# Target DataArrays from the "CPM" source, keyed by variable name
# (e.g. CPM_DAS["pr"] is EVAL_DS["CPM"]["target_pr"]).
CPM_DAS = {
    eval_var: EVAL_DS["CPM"][f"target_{eval_var}"]
    for eval_var in eval_vars
}

## Multivariate (MV) distribution distances

In [None]:
# for source, ds in EVAL_DS.items():
#     target_h, target_bins = np.histogramdd(
#         [ds[f"target_{var}"].values.reshape(-1) for var in eval_vars], 
#         bins=50,
#     )
#     for model, mds in ds.groupby("model"):
#         pred_h, pred_bins = np.histogramdd([mds[f"pred_{var}"].values.reshape(-1) for var in eval_vars], bins=target_bins)
#         print(model, "rmse", np.sqrt(np.mean(np.power(target_h - pred_h, 2))))
        
    # for model, mds in ds.groupby("model"):
    #     npt=np.stack([ds[f"target_{var}"].values.reshape(-1) for var in eval_vars], axis=1)
    #     npp=np.stack([mds[f"pred_{var}"].values.reshape(-1) for var in eval_vars], axis=1)
    #     print(model, "Wd samples", scipy.stats.wasserstein_distance_nd(npt, npp))
        
    # for model, mds in ds.groupby("model"):
    #     pred_h, pred_bins = np.histogramdd([mds[f"pred_{var}"].values.reshape(-1) for var in eval_vars], bins=target_bins)
    #     print(model, "Wd hist", scipy.stats.wasserstein_distance_nd(
    #         np.stack(np.meshgrid(target_bins[0][:-1], target_bins[1][:-1], target_bins[2][:-1], indexing="ij"), axis=-1).reshape(-1, 3), 
    #         np.stack(np.meshgrid(pred_bins[0][:-1], pred_bins[1][:-1], pred_bins[2][:-1], indexing="ij"), axis=-1).reshape(-1, 3),
    #         target_h.reshape(-1), pred_h.reshape(-1)
    #     ))

## Figure: joint distribution

* 2-D Frequency Density Histograms

In [None]:
def extract_and_compute_hist2d(ds, var_pair, xbins, ybins):
    """Compute predicted and target 2-D histograms for a pair of variables.

    Args:
        ds: dataset containing ``pred_<var>`` and ``target_<var>`` variables
            for both names in ``var_pair``.
        var_pair: ``(x_var, y_var)``. ``x_var`` fields are passed as the x
            arguments of ``compute_hist2d``, ``y_var`` fields as the y arguments.
        xbins: bin edges forwarded to ``compute_hist2d`` as ``xbins``.
        ybins: bin edges forwarded to ``compute_hist2d`` as ``ybins``.

    Returns:
        The result of ``compute_hist2d``. Callers in this notebook index it
        with ``pred_2d_density`` and ``target_2d_density``.
    """
    x_var, y_var = var_pair[0], var_pair[1]
    # Argument order expected by compute_hist2d:
    # x_pred, y_pred, x_target, y_target.
    preds = [ds[f"pred_{name}"] for name in (x_var, y_var)]
    targets = [ds[f"target_{name}"] for name in (x_var, y_var)]
    return compute_hist2d(*preds, *targets, xbins=xbins, ybins=ybins)

### Annual

In [None]:
# Annual joint distributions. For each source and each pair of evaluation
# variables: plot predicted vs target 2-D density histograms, then tabulate
# the per-model RMSE between the two densities.
rmse_hist2ds = []
for source, ds in EVAL_DS.items():
    for var_pair in itertools.combinations(eval_vars, 2):
        IPython.display.display_markdown(f"#### {var_pair}", raw=True)
        # 50 equal-width bins over each variable's fixed display range.
        xbins, ybins = (
            np.histogram_bin_edges([], bins=50, range=VAR_RANGES[name])
            for name in var_pair
        )
        hist2d_ds = extract_and_compute_hist2d(ds, var_pair, xbins, ybins)

        fig, axd = plot_hist2d_figure(hist2d_ds, xbins, ybins)
        # Dashed reference line at relative humidity = 100 on every panel,
        # on whichever axis carries relhum150cm.
        for ax in axd.values():
            if var_pair[0] == "relhum150cm":
                ax.axvline(x=100, color='k', linestyle='--', linewidth=1)
            if var_pair[1] == "relhum150cm":
                ax.axhline(y=100, color='k', linestyle='--', linewidth=1)
        plt.show()

        # RMSE over all 2-D bins between predicted and target densities.
        squared_diff = (hist2d_ds["pred_2d_density"] - hist2d_ds["target_2d_density"]) ** 2
        rmse = np.sqrt(squared_diff.mean(dim=["xbins", "ybins"]))
        rmse_hist2ds.append(rmse.rename("hist2d_rmse").expand_dims(vars=[f"{var_pair}"]))
pretty_table(xr.merge(rmse_hist2ds).transpose("vars", "model"), round=6)

### Seasonal

In [None]:
# Seasonal joint distributions. 2-D densities are computed per season
# (grouping on "time.season"). Figures are shown only for DJF and JJA, while
# the RMSE table covers every season present in the data.
rmse_hist2ds = []
for source, ds in EVAL_DS.items():
    for var_pair in itertools.combinations(eval_vars, 2):
        IPython.display.display_markdown(f"#### {var_pair}", raw=True)
        xbins = np.histogram_bin_edges([], bins=50, range=VAR_RANGES[var_pair[0]])
        ybins = np.histogram_bin_edges([], bins=50, range=VAR_RANGES[var_pair[1]])
        # One set of densities per season; the result carries a "season" dim.
        hist2d_ds = ds.groupby("time.season").map(
            extract_and_compute_hist2d, var_pair=var_pair, xbins=xbins, ybins=ybins
        )

        for season, season_hist2d_ds in hist2d_ds.groupby("season"):
            # Only winter (DJF) and summer (JJA) are plotted.
            if season not in ["DJF", "JJA"]:
                continue
            IPython.display.display_markdown(f"##### {season}", raw=True)

            fig, axd = plot_hist2d_figure(season_hist2d_ds, xbins, ybins)
            # Dashed reference line at relative humidity = 100 on every panel.
            for ax in axd.values():
                if var_pair[0] == "relhum150cm":
                    ax.axvline(x=100, color='k', linestyle='--', linewidth=1)
                if var_pair[1] == "relhum150cm":
                    ax.axhline(y=100, color='k', linestyle='--', linewidth=1)
            plt.show()

        # Per-model, per-season RMSE over all 2-D bins. It is computed once
        # per variable pair, outside the season loop, so each pair contributes
        # exactly one entry, whichever seasons were plotted.
        squared_diff = (hist2d_ds["pred_2d_density"] - hist2d_ds["target_2d_density"]) ** 2
        rmse = np.sqrt(squared_diff.mean(dim=["xbins", "ybins"]))
        rmse_hist2ds.append(rmse.rename("hist2d_rmse").expand_dims(vars=[f"{var_pair}"]))
pretty_table(xr.merge(rmse_hist2ds).transpose("vars", "model", "season"), round=6)