# Evaluation of model output distributions: 60km -> 2.2km-4x over Birmingham

In [None]:
%reload_ext autoreload

%autoreload 2

%reload_ext dotenv
%dotenv

import functools
import itertools
import math
import string

import IPython
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import scipy
import seaborn as sns
import xarray as xr
import xskillscore as xss

from mlde_utils import cp_model_rotated_pole

from mlde_analysis.data import prep_eval_data
from mlde_analysis.display import pretty_table, VAR_RANGES
from mlde_analysis.distribution import plot_freq_density
from mlde_analysis.mv_distribution import compute_hist2d, plot_hist2d_figure

In [None]:
# Render all figures in this notebook at 300 dpi.
matplotlib.rcParams.update({"figure.dpi": 300})

In [None]:
# Notebook parameters (presumably sample_configs, dataset_configs,
# derived_variables_config, split, ensemble_members, samples_per_run and desc,
# which later cells use but never define) come from this star import.
from mlde_analysis.default_params import *
# Variables to evaluate. Assigned after the star import so this list takes
# precedence over any eval_vars the defaults module may define.
eval_vars=["pr", "relhum150cm", "tmean150cm"]

In [None]:
# Render the notebook description `desc` (presumably from default_params) as
# Markdown; being the cell's last expression, it is shown as the cell output.
IPython.display.Markdown(desc)

In [None]:
# Load evaluation data for the requested variables. EVAL_DS is indexed by
# source (e.g. "CPM") and MODELS[source] maps model name -> plotting spec.
EVAL_DS, MODELS = prep_eval_data(
    sample_configs,
    dataset_configs,
    derived_variables_config,
    eval_vars,
    split,
    ensemble_members=ensemble_members,
    samples_per_run=samples_per_run,
)
# Display the loaded data as the cell output.
EVAL_DS

In [None]:
# Target data arrays from the CPM evaluation dataset, keyed by variable name.
CPM_DAS = {
    name: EVAL_DS["CPM"][f"target_{name}"]
    for name in eval_vars
}

## Figure: conditional distribution

* Frequency density histogram of one variable, conditioned on another variable exceeding a threshold (below: relative humidity where the target mean temperature exceeds 303)

In [None]:
# Frequency density of predicted vs. target relative humidity, restricted to
# points where the *target* mean temperature exceeds 303.
# NOTE(review): 303 presumably means 303 K (~30 degC); confirm tmean150cm units.
# Both predictions and target are masked by the target temperature, not by
# each model's own predicted temperature.
source = "CPM"
ds = EVAL_DS[source]
is_hot = ds["target_tmean150cm"] > 303
hot_ds = ds.where(is_hot)

var = "relhum150cm"
target_da = hot_ds[f"target_{var}"]

# One series per model evaluated on this source.
hist_data = {
    var: [
        dict(
            data=hot_ds[f"pred_{var}"].sel(model=name),
            label=name,
            color=model_spec["color"],
        )
        for name, model_spec in MODELS[source].items()
    ]
}

fig = plt.figure(layout="constrained", figsize=(5.5, 2.5))
axd = fig.subplot_mosaic([["Density"]])
ax = axd["Density"]
plot_freq_density(hist_data[var], ax=ax, target_da=target_da)

## Multivariate: wet vs. dry days

In [None]:
def extract_and_compute_hist2d(ds, var, threshold, xbins, ybins):
    """Compute 2D (var, pr) histograms over wet points only.

    Predictions are masked where the *predicted* precipitation exceeds
    ``threshold``; the target is masked where the *target* precipitation
    exceeds it. Each side is therefore conditioned on its own wet days.

    Args:
        ds: dataset holding ``pred_{var}``, ``target_{var}``, ``pred_pr`` and
            ``target_pr``.
        var: name of the companion variable, without the pred_/target_ prefix.
        threshold: precipitation value a point must strictly exceed to be wet.
        xbins: bin edges for ``var``.
        ybins: bin edges for precipitation.

    Returns:
        The result of ``compute_hist2d`` on the masked arrays.
    """
    pred_is_wet = ds["pred_pr"] > threshold
    target_is_wet = ds["target_pr"] > threshold

    return compute_hist2d(
        ds[f"pred_{var}"].where(pred_is_wet),
        ds["pred_pr"].where(pred_is_wet),
        ds[f"target_{var}"].where(target_is_wet),
        ds["target_pr"].where(target_is_wet),
        xbins=xbins,
        ybins=ybins,
    )

# For each precipitation threshold, this cell does three things:
#   1. plots the wet-day pr frequency density;
#   2. for every other variable, plots its frequency density on dry vs. wet
#      days (linear and log y-axes);
#   3. plots per-season 2D (var, pr) wet-day histograms and tabulates the RMSE
#      between predicted and target 2D densities.


def _model_hist_data(da, model_specs):
    """Build the per-model series list passed to ``plot_freq_density``.

    Args:
        da: data array with a ``model`` dimension.
        model_specs: mapping of model name -> spec dict containing ``"color"``.

    Returns:
        One ``dict(data=..., label=..., color=...)`` per model in ``model_specs``.
    """
    return [
        dict(data=da.sel(model=model), label=model, color=spec["color"])
        for model, spec in model_specs.items()
    ]


source = "CPM"
ds = EVAL_DS[source]
# Only the models for this source are selected from ``ds``, matching the
# conditional-distribution cell. Iterating over every source in MODELS while
# selecting from EVAL_DS[source] fails for models absent from this dataset.
model_specs = MODELS[source]

# Precipitation thresholds (mm/day). A point is "wet" when pr is strictly
# greater than the threshold and "dry" otherwise.
THRESHOLDS = [0.1, 1, 10]

# pr bin edges for the 2D histograms depend on neither threshold nor variable.
ybins = np.histogram_bin_edges([], bins=50, range=VAR_RANGES["pr"])

for threshold in THRESHOLDS:
    IPython.display.display_markdown(f"### Threshold: pr {threshold}mm/day", raw=True)

    # Each side is conditioned on its own precipitation: predictions on the
    # predicted pr, the target on the target pr.
    wet_pred_pr = ds["pred_pr"].where(ds["pred_pr"] > threshold)
    wet_target_pr = ds["target_pr"].where(ds["target_pr"] > threshold)

    fig = plt.figure(layout="constrained", figsize=(3.5, 2.5))
    axd = fig.subplot_mosaic([["wet pr"]])
    ax = axd["wet pr"]
    plot_freq_density(
        _model_hist_data(wet_pred_pr, model_specs),
        ax=ax,
        target_da=wet_target_pr,
        linewidth=1,
        hrange=VAR_RANGES["pr"],
        legend=True,
    )
    ax.set_title("pr wet days")
    plt.show()

    for var in eval_vars:
        if var == "pr":
            continue
        IPython.display.display_markdown(f"#### {var}", raw=True)

        wet_pred_da = ds[f"pred_{var}"].where(ds["pred_pr"] > threshold)
        wet_target_da = ds[f"target_{var}"].where(ds["target_pr"] > threshold)
        dry_pred_da = ds[f"pred_{var}"].where(ds["pred_pr"] <= threshold)
        dry_target_da = ds[f"target_{var}"].where(ds["target_pr"] <= threshold)
        hrange = VAR_RANGES[var]

        # Top row: dry days; bottom row: wet days. Left column: linear y-axis;
        # right column: log y-axis.
        fig = plt.figure(layout="constrained", figsize=(6, 3.5))
        mosaic = [["linear dry", "log dry"], ["linear wet", "log wet"]]
        axd = fig.subplot_mosaic(mosaic, sharex=True, sharey=False)

        for yscale in ["log", "linear"]:
            ax = axd[f"{yscale} dry"]
            plot_freq_density(
                _model_hist_data(dry_pred_da, model_specs),
                ax=ax,
                target_da=dry_target_da,
                linewidth=1,
                hrange=hrange,
                legend=(yscale == "linear"),
                yscale=yscale,
            )
            ax.set_title(f"{var} dry days")

            ax = axd[f"{yscale} wet"]
            # Pass yscale here too so each wet panel uses the scale its name
            # says. It was previously omitted, leaving both wet panels at
            # plot_freq_density's default scale.
            plot_freq_density(
                _model_hist_data(wet_pred_da, model_specs),
                ax=ax,
                target_da=wet_target_da,
                linewidth=1,
                hrange=hrange,
                legend=False,
                yscale=yscale,
            )
            ax.set_title(f"{var} wet days")

        # Show now so this figure renders under its "#### {var}" heading rather
        # than after the season headings emitted below.
        plt.show()

        # 2D (var, pr) histograms of wet points, computed per season group of
        # the time coordinate.
        xbins = np.histogram_bin_edges([], bins=50, range=VAR_RANGES[var])
        hist2d_ds = ds.groupby("time.season").map(
            extract_and_compute_hist2d,
            threshold=threshold,
            var=var,
            xbins=xbins,
            ybins=ybins,
        )

        for season, season_hist2d_ds in hist2d_ds.groupby("season"):
            # Only DJF and JJA are plotted; the RMSE table below covers every season.
            if season not in ["DJF", "JJA"]:
                continue
            IPython.display.display_markdown(f"##### {season}", raw=True)

            fig2d, axd2d = plot_hist2d_figure(season_hist2d_ds, xbins, ybins)
            fig2d.suptitle(f"{season} Wet day ({var}, pr)")

            if var == "relhum150cm":
                # Reference line at x=100 (presumably 100% relative humidity;
                # confirm relhum150cm is in percent).
                for ax in axd2d.values():
                    ax.axvline(x=100, color="k", linestyle="--", linewidth=1)

            # Show each season's figure directly under its heading.
            plt.show()

        # RMSE between predicted and target 2D densities, averaged over the
        # histogram bins. The remaining dimensions (e.g. season) are kept.
        rmse_hist2d = np.sqrt(
            ((hist2d_ds["pred_2d_density"] - hist2d_ds["target_2d_density"]) ** 2).mean(dim=["xbins", "ybins"])
        ).rename("hist2d_rmse")
        pretty_table(rmse_hist2d, round=6)