# Figures for perspective paper
# Evaluation of UoB models on 60km -> 2.2km-4x over Birmingham

In [None]:
%reload_ext autoreload

%autoreload 2

%reload_ext dotenv
%dotenv

import functools
import math
import string

import IPython
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

from mlde_notebooks.data import prep_eval_and_model_data
from mlde_notebooks import plot_map
from mlde_notebooks.distribution import normalized_mean_bias, compute_fractional_contribution, plot_freq_density, plot_fractional_contribution, plot_mean_biases, rms_mean_bias
from mlde_notebooks.ccs import compute_changes, plot_changes, plot_tp_fd, bootstrap_seasonal_mean_pr_change_samples
from mlde_utils import cp_model_rotated_pole

In [None]:
import matplotlib

# Raise the default figure resolution to 300 dpi for every figure created below.
matplotlib.rcParams["figure.dpi"] = 300

## Without humidity

In [None]:
# Evaluation configuration for this section; all four values are passed to
# prep_eval_and_model_data in the next cell.

# Name of the dataset split to evaluate.
split = "test"
# Ids of the 12 ensemble members to include.
ensemble_members = [
    "01",
    "04",
    "05",
    "06",
    "07",
    "08",
    "09",
    "10",
    "11",
    "12",
    "13",
    "15",
]
# Forwarded as samples_per_run; presumably the number of samples drawn per model run -- TODO confirm.
samples_per_run = 6
# Model configurations grouped by input source ("CPM" or "GCM").
# Keys read later in this notebook: "color" (line colour in plots) and "CCS"
# (entries with CCS=True are selected for the change figures).
# "PSD", "order", "deterministic", "checkpoint", "input_xfm" and "dataset" are
# not read in the visible cells; presumably consumed by prep_eval_and_model_data -- TODO confirm.
data_configs = {
    "CPM": [
        {
            # Diffusion model evaluated on the "gcmx" dataset with the "-stan" input transform.
            "fq_model_id": "score-sde/subvpsde/xarray_12em_cncsnpp_continuous/bham-4x_12em_PslS4T4V4_random-season-IstanTsqrturrecen-no-loc-spec",
            "checkpoint": "epoch_20",
            "input_xfm": "bham_gcmx-4x_12em_psl-sphum4th-temp4th-vort4th_eqvt_random-season-stan",
            "label": "Diffusion (cCPM)",
            "dataset": "bham_gcmx-4x_12em_psl-sphum4th-temp4th-vort4th_eqvt_random-season",
            "deterministic": False,
            "PSD": True,
            "color": "blue",
            "order": 10,
        },
    ],
    "GCM": [
        {
            # Same model id as the CPM entry, run on the "60km" dataset with the
            # "-pixelmmsstan" input transform.
            "fq_model_id": "score-sde/subvpsde/xarray_12em_cncsnpp_continuous/bham-4x_12em_PslS4T4V4_random-season-IstanTsqrturrecen-no-loc-spec",
            "checkpoint": "epoch_20",
            "input_xfm": "bham_60km-4x_12em_psl-sphum4th-temp4th-vort4th_eqvt_random-season-pixelmmsstan",
            "label": "Diffusion (GCM)",
            "dataset": "bham_60km-4x_12em_psl-sphum4th-temp4th-vort4th_eqvt_random-season",
            "deterministic": False,
            "CCS": True,
            "color": "green",
            "order": 20,
            
        },
        {
            # "no-bc" variant: same model and "60km" dataset, but reusing the
            # "gcmx ... -stan" input transform from the CPM entry.
            "fq_model_id": "score-sde/subvpsde/xarray_12em_cncsnpp_continuous/bham-4x_12em_PslS4T4V4_random-season-IstanTsqrturrecen-no-loc-spec",
            "checkpoint": "epoch_20",
            "input_xfm": "bham_gcmx-4x_12em_psl-sphum4th-temp4th-vort4th_eqvt_random-season-stan",
            "label": "Diff no-bc (GCM)",
            "dataset": "bham_60km-4x_12em_psl-sphum4th-temp4th-vort4th_eqvt_random-season",
            "deterministic": False,
            "CCS": True,
            "color": "red",
            "order": 21,
        },
        {
            # "no-hum" variant: model id, input transform and dataset all omit the
            # "sphum4th" variable (PslT4V4 rather than PslS4T4V4).
            "fq_model_id": "score-sde/subvpsde/xarray_12em_cncsnpp_continuous/bham-4x_12em_PslT4V4_random-season-IstanTsqrturrecen-no-loc-spec",
            "checkpoint": "epoch_20",
            "input_xfm": "bham_60km-4x_12em_psl-temp4th-vort4th_eqvt_random-season-pixelmmsstan",
            "label": "Diff no-hum (GCM)",
            "dataset": "bham_60km-4x_12em_psl-temp4th-vort4th_eqvt_random-season",
            "deterministic": False,
            "CCS": True,
            "color": "orange",
            "order": 22,
            
        },
    ],
}

In [None]:
# Load evaluation data and model predictions for every configured source.
# MODELS is later indexed as MODELS[source][model] -> spec dict.
merged_ds, MODELS = prep_eval_and_model_data(
    data_configs,
    split,
    ensemble_members=ensemble_members,
    samples_per_run=samples_per_run,
)
merged_ds

In [None]:
# Reference precipitation that all predictions are compared against below
# (presumably the CPM simulation output -- TODO confirm).
cpm_pr = merged_ds["CPM"]["target_pr"]

In [None]:
# Predicted precipitation of every configured model, with its plot label and colour.
hist_data = [
    dict(
        data=merged_ds[model_spec["source"]]["pred_pr"].sel(model=model_name),
        label=model_name,
        color=model_spec["color"],
    )
    for _source, source_models in MODELS.items()
    for model_name, model_spec in source_models.items()
]

# Normalized mean bias against the CPM target for each GCM-driven model.
mean_biases = [
    dict(
        data=normalized_mean_bias(
            merged_ds["GCM"]["pred_pr"].sel(model=model_name),
            cpm_pr,
        ),
        label=model_name,
    )
    for model_name, _model_spec in MODELS["GCM"].items()
]

# Seasons and models (those flagged with CCS=True) used for the change analysis.
ccs_seasons = ["DJF", "JJA"]
ccs_models = {
    source_name: {
        model_name: model_spec
        for model_name, model_spec in source_models.items()
        if model_spec.get("CCS", False)
    }
    for source_name, source_models in MODELS.items()
}

# CPM target restricted to the "historic" time period.
historical_cpm_pr = cpm_pr.where(cpm_pr["time_period"] == "historic", drop=True)

# 50 equal-width bins spanning hrange, shared by the fractional-contribution computations.
hrange = (0, 250)
bins = np.histogram_bin_edges([], bins=50, range=hrange)

def frac_contrib_change(pr_da, bins, range):
    """Return the future-minus-historic change in fractional contribution.

    Args:
        pr_da: Precipitation data carrying a ``time_period`` coordinate with
            "future" and "historic" values.
        bins: Bin specification forwarded to ``compute_fractional_contribution``.
        range: Value range forwarded to ``compute_fractional_contribution``
            (the name shadows the builtin but is kept for keyword callers).

    Returns:
        The result of ``compute_fractional_contribution`` for the "future"
        subset minus the result for the "historic" subset.
    """

    def _period_contribution(period):
        # Keep only the requested time period, dropping everything else.
        period_pr = pr_da.where(pr_da["time_period"] == period, drop=True)
        return compute_fractional_contribution(period_pr, bins=bins, range=range)

    return _period_contribution("future") - _period_contribution("historic")

# Time period used as the key of the change panel in the figure below.
tp_key = "future"
tp_cpm_pr = cpm_pr.where(cpm_pr["time_period"] == tp_key, drop=True)

# Change in fractional contribution: the CPM target first, then every CCS-flagged model.
tp_frcontrib_data = [
    dict(
        data=frac_contrib_change(cpm_pr, bins, hrange),
        label="CPM",
        color="black",
        source="CPM",
    ),
    *(
        dict(
            data=frac_contrib_change(
                merged_ds[source_name]["pred_pr"].sel(model=model_name),
                bins,
                hrange,
            ),
            label=model_name,
            color=model_spec["color"],
            source=source_name,
        )
        for source_name, source_models in ccs_models.items()
        for model_name, model_spec in source_models.items()
    ),
]

In [None]:
# Three-row figure:
#   a. fractional contribution of all models ("Density" axes, spanning the full width),
#   b. one mean-bias map per GCM-driven model,
#   c. change in fractional contribution from historic to future (spanning the full width).
fig = plt.figure(layout='constrained', figsize=(4.5, 6.5))

# One mosaic key per mean-bias map; rows a and c repeat a single key so they span every column.
meanb_axes_keys = [f"meanb {mb['label']}" for mb in mean_biases]
meanb_spec = np.array(meanb_axes_keys).reshape(1,-1)

dist_spec = np.array(["Density"] * meanb_spec.shape[1]).reshape(1,-1)
ccs_spec = np.array([tp_key] * meanb_spec.shape[1]).reshape(1,-1)

spec = np.concatenate([dist_spec, meanb_spec, ccs_spec], axis=0)

# Only the mean-bias map axes get the rotated-pole projection.
axd = fig.subplot_mosaic(spec, gridspec_kw=dict(height_ratios=[3, 2, 3]), per_subplot_kw={ak: {"projection": cp_model_rotated_pole} for ak in meanb_axes_keys})

ax = axd["Density"]

plot_fractional_contribution(hist_data, ax=ax, target_da=cpm_pr, title="All periods", linewidth=1,)
# Panel letters: x is a figure fraction so the letters line up vertically across rows.
ax.annotate("a.", xy=(0.04, 1.0), xycoords=("figure fraction", "axes fraction"), weight='bold', ha="left", va="bottom")

axes = plot_mean_biases(mean_biases, axd, colorbar=True)
axes[0].annotate("b.", xy=(0.04, 1.0), xycoords=("figure fraction", "axes fraction"), weight='bold', ha="left", va="bottom")

ax = axd[tp_key]
for pred in tp_frcontrib_data:
    ax.stairs(
        pred["data"],
        bins,
        baseline=None,
        fill=False,
        color=pred["color"],
        alpha=pred.get("alpha", 0.95),
        linestyle=pred.get("linestyle", "-"),
        linewidth=1,
        label=pred["label"],
    )
ax.set_title("Change from Historic to Future")
ax.set_xlabel("Precip (mm/day)")
ax.set_ylabel("Change in frac. contrib.")

# # linthresh based on minimum value
# linthresh = (min((map(lambda h: np.min(np.fabs(h["data"][h["data"].nonzero()])), tp_frcontrib_data))))
# print(linthresh)
# linthresh = (10 ** math.floor(math.log10(linthresh)))/2
# print(linthresh)

# # linthresh based on minimum value from CPM
# linthresh = np.min(np.fabs(tp_frcontrib_data[0]["data"][tp_frcontrib_data[0]["data"].nonzero()]))
# print(linthresh)
# linthresh = (10 ** math.floor(math.log10(linthresh)))/2
# print(linthresh)

# linthreshold based on single observation at reasonably high precip
# np.prod replaces np.product, which was removed in NumPy 2.0.
mindensity = 1 / (np.prod(cpm_pr.shape)/3) # divide by 3 as considering single time periods
print(mindensity)
linthresh = 10 ** (math.floor(math.log10(100*mindensity))) / 2 # multiply by 100 as frac contrib is density times intensity
print(linthresh)

# Symmetric log scale: linear within +/- linthresh, logarithmic beyond, so that
# both positive and negative changes of very different magnitude stay visible.
ax.set_yscale("symlog", linthresh=linthresh)
ax.tick_params(axis="both", which="major")
ax.legend(ncols=2, fontsize="small")
ax.annotate("c.", xy=(0.04, 1.0), xycoords=("figure fraction", "axes fraction"), weight='bold', ha="left", va="bottom")

plt.show()

# Table of the RMS mean bias against the CPM target, computed per model for every source.
rms_mean_biases = xr.concat([ merged_ds[source]["pred_pr"].groupby("model", squeeze=False).map(lambda x: rms_mean_bias(x, cpm_pr)) for source in merged_ds.keys() ], dim="model")

IPython.display.display_html(rms_mean_biases.rename("Root Mean Square Mean Bias (mm/day)").to_dataframe().round(2).to_html(), raw=True)

### CCS Figures

In [None]:
# Seasons and models (those flagged with CCS=True) used in the change figures.
ccs_seasons = ["DJF", "JJA"]
ccs_models = {
    source_name: {
        model_name: model_spec
        for model_name, model_spec in source_models.items()
        if model_spec.get("CCS", False)
    }
    for source_name, source_models in MODELS.items()
}

# Predicted precipitation of each CCS model, in configuration order.
ccs_pred_pr_das = [
    merged_ds[source_name]["pred_pr"].sel(model=model_name)
    for source_name, source_models in ccs_models.items()
    for model_name in source_models
]

### Mean change

In [None]:
# Seasonal change statistics using the mean as the summary statistic.
changes = compute_changes(
    ccs_pred_pr_das,
    cpm_pr,
    ccs_seasons,
    stat_func=xr.DataArray.mean,
)

mean_change_fig = plt.figure(figsize=(5.5, 5.5), layout="compressed")
# show_change receives the model name of the first CCS prediction only.
plot_changes(
    changes,
    ccs_seasons,
    mean_change_fig,
    show_change=[ccs_pred_pr_das[0]["model"].data.item()],
)

plt.show()

### Quantile changes

In [None]:
# Same change analysis as for the mean, repeated for upper quantiles of precipitation.
# Uses functools.partial via the notebook's top-level `import functools` rather than
# a second, mid-notebook import.
for q in [0.90, 0.95, 0.99]:
    IPython.display.display_markdown(f"#### Quantile: {q}", raw=True)

    changes = compute_changes(
        ccs_pred_pr_das,
        cpm_pr,
        ccs_seasons,
        stat_func=functools.partial(xr.DataArray.quantile, q=q),
    )

    # Separate name so the mean-change figure handle from the previous cell is not clobbered.
    quantile_change_fig = plt.figure(figsize=(5.5, 5.5), layout="compressed")
    # show_change receives the model name of the first CCS prediction only.
    plot_changes(
        changes,
        ccs_seasons,
        quantile_change_fig,
        show_change=[ccs_pred_pr_das[0]["model"].data.item()],
    )

    plt.show()