# Figures for perspective paper
# Evaluation of UoB models (60km -> 2.2km-4x over Birmingham)

In [None]:
%reload_ext autoreload

%autoreload 2

%reload_ext dotenv
%dotenv

import functools
import math
import string

import IPython
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

from mlde_analysis.data import prep_eval_data
from mlde_analysis import plot_map
from mlde_analysis.distribution import normalized_mean_bias, plot_freq_density, plot_mean_biases, rms_mean_bias
from mlde_analysis.fractional_contribution import compute_fractional_contribution, plot_fractional_contribution, frac_contrib_change, fc_bins
from mlde_analysis.ccs import compute_changes, plot_changes, plot_tp_fd, bootstrap_seasonal_mean_pr_change_samples
from mlde_analysis.display import pretty_table
from mlde_utils import cp_model_rotated_pole

In [None]:
import matplotlib

# Render every figure produced by this notebook at 300 dpi.
matplotlib.rcParams.update({"figure.dpi": 300})

In [None]:
# Which data split to evaluate on.
split = "test"
# UKCP Local ensemble members included in the evaluation (12 members).
ensemble_members = [
    "01",
    "04",
    "05",
    "06",
    "07",
    "08",
    "09",
    "10",
    "11",
    "12",
    "13",
    "15",
]
# Evaluation dataset per input source: "CPM" uses the coarsened-CPM (ccpm)
# inputs, "GCM" uses the GCM inputs.
dataset_configs = {
    "CPM": "bham64_ccpm-4x_12em_psl-sphum4th-temp4th-vort4th_pr",
    "GCM": "bham64_gcm-4x_12em_psl-sphum4th-temp4th-vort4th_pr",
}
# Number of generated samples per example, passed to prep_eval_data.
samples_per_run = 6
# Models to evaluate, grouped by input source (same keys as dataset_configs).
# Each entry gives a plot label, the sample files to load (model id,
# checkpoint, input transform, dataset, variables), and flags:
#   "CCS"   - include in the change (historic -> future) analyses below;
#   "color" - line colour in the plots.
# "deterministic", "PSD", "UQ" and "order" are not read by any visible cell;
# presumably they are consumed by prep_eval_data or other notebooks - confirm.
sample_configs = {
    "CPM": [
        # Emulator driven by coarsened-CPM inputs with the cCPM-fitted input transform.
        {
            "label": "Diffusion (cCPM)",
            "sample_specs": [
                {
                    "fq_model_id": "score-sde/subvpsde/ukcp_local_pr_12em_cncsnpp_continuous/bham-4x_12em_pSTV",
                    "checkpoint": "epoch_20",
                    "input_xfm": "bham64_ccpm-4x_12em_psl-sphum4th-temp4th-vort4th_pr-stan",
                    "dataset": "bham64_ccpm-4x_12em_psl-sphum4th-temp4th-vort4th_pr",
                    "variables": ["pr"],
                },
            ],
            "deterministic": False,
            "CCS": False,
            "color": "tab:blue",
            "order": 10,
        }
    ],
    "GCM": [
        # Same model driven by GCM inputs, using the GCM-fitted
        # "pixelmmsstan" input transform (the bias-corrected variant).
        {
            "label": "Diffusion (GCM)",
            "sample_specs": [
                {
                    "fq_model_id": "score-sde/subvpsde/ukcp_local_pr_12em_cncsnpp_continuous/bham-4x_12em_pSTV",
                    "checkpoint": "epoch_20",
                    "input_xfm": "bham64_gcm-4x_12em_psl-sphum4th-temp4th-vort4th_pr-pixelmmsstan",
                    "dataset": "bham64_gcm-4x_12em_psl-sphum4th-temp4th-vort4th_pr",
                    "variables": ["pr"],
                }
            ],
            "deterministic": False,
            "CCS": True,
            "PSD": True,
            "UQ": False,
            "color": "tab:green",
            "order": 100,
        },
        # GCM inputs but with the cCPM-fitted input transform, i.e. no bias
        # correction of the GCM inputs.
        {
            "label": "Diff no-bc (GCM)",
            "sample_specs": [
                {
                    "fq_model_id": "score-sde/subvpsde/ukcp_local_pr_12em_cncsnpp_continuous/bham-4x_12em_pSTV",
                    "checkpoint": "epoch_20",
                    "input_xfm": "bham64_ccpm-4x_12em_psl-sphum4th-temp4th-vort4th_pr-stan",
                    "dataset": "bham64_gcm-4x_12em_psl-sphum4th-temp4th-vort4th_pr",
                    "variables": ["pr"],
                }
            ],
            "deterministic": False,
            "CCS": True,
            "PSD": True,
            "UQ": False,
            "color": "tab:red",
            "order": 100,
        },
        # Separate model trained without the specific-humidity inputs
        # (pTV instead of pSTV), driven by GCM inputs.
        {
            "label": "Diff no-hum (GCM)",
            "sample_specs": [
                {
                    "fq_model_id": "score-sde/subvpsde/ukcp_local_pr_12em_cncsnpp_continuous/bham-4x_12em_pTV-original",
                    "checkpoint": "epoch_20",
                    "input_xfm": "bham64_gcm-4x_12em_psl-temp4th-vort4th_pr-pixelmmsstan",
                    "dataset": "bham64_gcm-4x_12em_psl-temp4th-vort4th_pr",
                    "variables": ["pr"],
                }
            ],
            "deterministic": False,
            "CCS": True,
            "PSD": True,
            "UQ": False,
            "color": "tab:orange",
            "order": 100,
        },
    ],
}
# No derived variables are requested.
derived_variables_config={}
# Variables evaluated by this notebook.
eval_vars = ["pr"]

## Data

* Using all 12 ensemble members over the 1981-2000, 2021-2040 and 2061-2080 periods of the initial UKCP Local release (but with data produced after the graupel bug fix)
* Splits are based on random choice of seasons with equal number of seasons from each time slice
* Target domain and resolution: 64x64 grid at 8.8km (4x 2.2km) over England and Wales
* Input resolution: 60km (cCPM is CPM coarsened to GCM 60km grid)

## CPMGEM models

Compare:

* cCPM input source
* GCM with bias correction input source
* GCM input source without bias correction
* GCM input source without humidity (pTV)

### Shared model specs

* Input variables (unless otherwise stated): pSTV (pressure at sea level, plus specific humidity, air temperature and relative vorticity at 4 levels each)
* Input transforms are fitted on the dataset in use (i.e. separate GCM and CPM versions), while the target transform is fitted only at training time, on the CPM dataset
* No loc-spec params
* 6 samples per example

In [None]:
# Load predictions and targets.
# EVAL_DS: input source -> evaluation dataset (with pred_<var> / target_<var>).
# MODELS:  input source -> model label -> sample config.
EVAL_DS, MODELS = prep_eval_data(
    sample_configs,
    dataset_configs,
    derived_variables_config,
    eval_vars,
    split,
    ensemble_members=ensemble_members,
    samples_per_run=samples_per_run,
)
EVAL_DS

In [None]:
# Reference CPM data for each evaluated variable, taken from the CPM-source dataset.
CPM_DAS = {
    name: EVAL_DS["CPM"]["target_" + name]
    for name in eval_vars
}

## Fractional contribution (including change from Historic to Future) and mean bias 

In [None]:
ccs_seasons = ["DJF", "JJA"]
# Per input source, keep only the models flagged "CCS" (used for the
# historic -> future change analyses).
ccs_models = {
    source: {
        model: spec
        for model, spec in mconfigs.items()
        if spec.get("CCS", False)
    }
    for source, mconfigs in MODELS.items()
}

# CPM precipitation restricted to the "historic" time period.
# NOTE(review): not referenced by any later visible cell.
historical_cpm_pr = CPM_DAS["pr"].where(
    CPM_DAS["pr"]["time_period"] == "historic",
    drop=True,
)

def fig_data(eval_ds, cpm_pr):
    """Assemble the plot-data lists consumed by ``plot_figure``.

    Model selection comes from the notebook-level ``MODELS`` and ``ccs_models``
    mappings (input source -> model label -> spec).

    Args:
        eval_ds: mapping of input source (e.g. "CPM", "GCM") to an evaluation
            dataset whose ``pred_pr`` variable has a ``model`` dimension.
        cpm_pr: CPM precipitation used as the reference.

    Returns:
        ``(fraccontrib_data, fraccontrib_err_data, tp_frcontrib_data, mean_biases)``,
        each a list of dicts with ``data`` and ``label`` keys (plus ``color`` /
        ``source`` where relevant):

        * ``fraccontrib_data``: CPM fractional contribution only.
        * ``fraccontrib_err_data``: model minus CPM fractional contribution,
          for every model of every source.
        * ``tp_frcontrib_data``: historic -> future change in fractional
          contribution, CPM first, then each CCS model.
        * ``mean_biases``: normalized mean bias of each GCM-driven model.
    """
    # The same bins are used for every fractional-contribution computation.
    bins = fc_bins()

    mean_biases = [
        dict(data=normalized_mean_bias(eval_ds["GCM"]["pred_pr"].sel(model=model), cpm_pr), label=model)
        for model in MODELS["GCM"]
    ]

    cpm_fc = compute_fractional_contribution(cpm_pr, bins=bins)

    fraccontrib_data = [
        dict(data=cpm_fc, label="CPM", color="black", source="CPM")
    ]

    # Index eval_ds by the source key MODELS is grouped under (as for the change
    # data below) instead of relying on each spec carrying its own "source" entry.
    fraccontrib_err_data = [
        dict(
            data=compute_fractional_contribution(eval_ds[source]["pred_pr"].sel(model=model), bins=bins) - cpm_fc,
            label=model,
            color=spec["color"],
        )
        for source, mconfigs in MODELS.items()
        for model, spec in mconfigs.items()
    ]

    tp_frcontrib_data = [
        dict(data=frac_contrib_change(cpm_pr, bins=bins), label="CPM", color="black", source="CPM")
    ] + [
        dict(
            data=frac_contrib_change(eval_ds[source]["pred_pr"].sel(model=model), bins=bins),
            label=model,
            color=spec["color"],
            source=source,
        )
        for source, mconfigs in ccs_models.items()
        for model, spec in mconfigs.items()
    ]

    return fraccontrib_data, fraccontrib_err_data, tp_frcontrib_data, mean_biases

In [None]:
def plot_figure(fraccontrib_data, fraccontrib_err_data, frcontrib_change_data, mean_biases):
    """Draw the three-row summary figure and return it.

    Rows, top to bottom:
      a. ``fraccontrib_err_data`` plotted as "Error", with a small inset in the
         top-right corner plotting ``fraccontrib_data`` ("CPM frac. contrib.");
      b. one map per entry of ``mean_biases`` (rotated-pole projection);
      c. ``frcontrib_change_data`` ("Change from Historic to Future").

    The arguments are plot-data lists in the order returned by ``fig_data``;
    each ``mean_biases`` entry needs a ``label`` key, which names its map axes.

    Returns:
        The matplotlib Figure.
    """
    fig = plt.figure(layout="constrained", figsize=(4.5, 6.5))

    map_keys = [f"meanb {entry['label']}" for entry in mean_biases]
    n_maps = len(map_keys)

    # The error and change panels each span the full width; the middle row
    # holds one mean-bias map per entry.
    mosaic = np.concatenate(
        [
            np.array(["FCErr"] * n_maps).reshape(1, -1),
            np.array(map_keys).reshape(1, -1),
            np.array(["Change"] * n_maps).reshape(1, -1),
        ],
        axis=0,
    )
    axd = fig.subplot_mosaic(
        mosaic,
        gridspec_kw=dict(height_ratios=[3, 2, 3]),
        per_subplot_kw={key: {"projection": cp_model_rotated_pole} for key in map_keys},
    )

    def label_panel(panel_ax, letter):
        # Bold panel letter at the figure's left edge, level with the top of panel_ax.
        panel_ax.annotate(
            letter,
            xy=(0.04, 1.0),
            xycoords=("figure fraction", "axes fraction"),
            weight="bold",
            ha="left",
            va="bottom",
        )

    err_ax = axd["FCErr"]
    plot_fractional_contribution(fraccontrib_err_data, ax=err_ax, title="Error", linewidth=1, ylim=[-0.2, 0.2])
    label_panel(err_ax, "a.")

    map_axes = plot_mean_biases(mean_biases, axd, colorbar=True)
    label_panel(map_axes[0], "b.")

    change_ax = axd["Change"]
    plot_fractional_contribution(
        frcontrib_change_data,
        ax=change_ax,
        title="Change from Historic to Future",
        linewidth=1,
        legend=False,
        ylim=[-0.3, 0.3],
    )
    label_panel(change_ax, "c.")

    # Inset of the CPM fractional contribution, giving context for the error panel.
    inset_ax = fig.add_axes([0.79, 0.82, 0.18, 0.18])
    plot_fractional_contribution(fraccontrib_data, ax=inset_ax, title="", linewidth=1, legend=False)
    inset_ax.set_title("CPM frac. contrib.", pad=1, fontsize="xx-small")
    inset_ax.set_xlabel("")
    inset_ax.set_ylabel("")
    for tick_kind in ("major", "minor"):
        inset_ax.tick_params(axis="both", which=tick_kind, labelsize="xx-small")
    inset_ax.set_aspect(1)

    return fig

### Annual Figure

In [None]:
# Annual (all seasons) figure, followed by a table of the per-model RMS of the mean bias.
(
    fraccontrib_data,
    fraccontrib_err_data,
    tp_frcontrib_data,
    mean_biases,
) = fig_data(EVAL_DS, CPM_DAS["pr"])

rms_mean_biases = xr.concat(
    [
        source_ds["pred_pr"]
        .groupby("model", squeeze=False)
        .map(lambda model_da: rms_mean_bias(model_da, CPM_DAS["pr"]))
        for source_ds in EVAL_DS.values()
    ],
    dim="model",
)

plot_figure(fraccontrib_data, fraccontrib_err_data, tp_frcontrib_data, mean_biases)
plt.show()

pretty_table(rms_mean_biases.rename("Root Mean Square Mean Bias (mm/day)"), round=2)

### Seasonal figures

In [None]:
for season in ["DJF", "JJA"]:
    IPython.display.display_markdown(f"#### {season}", raw=True)

    # Mask (rather than drop) every time outside the current season, for the
    # CPM reference and for every input source.
    seasonal_cpm_pr = CPM_DAS["pr"].where(CPM_DAS["pr"]["time.season"] == season)
    seasonal_eval_ds = {
        source: ds.where(ds["time.season"] == season)
        for source, ds in EVAL_DS.items()
    }

    seasonal_fig = plot_figure(*fig_data(seasonal_eval_ds, seasonal_cpm_pr))
    seasonal_fig.suptitle(season)
    plt.show()

    rms_mean_biases = xr.concat(
        [
            ds["pred_pr"]
            .groupby("model", squeeze=False)
            .map(lambda x: rms_mean_bias(x, seasonal_cpm_pr))
            for ds in seasonal_eval_ds.values()
        ],
        dim="model",
    )
    # NOTE(review): this call is not the cell's last expression; if pretty_table
    # returns a displayable instead of displaying itself, nothing is shown - confirm.
    pretty_table(rms_mean_biases.rename("Root Mean Square Mean Bias (mm/day)"), round=2)

## CCS Figures

In [None]:
ccs_seasons = ["DJF", "JJA"]
# Per input source, keep only the models flagged "CCS" (same selection as
# in the fractional-contribution section).
ccs_models = {
    source: {
        model: spec
        for model, spec in mconfigs.items()
        if spec.get("CCS", False)
    }
    for source, mconfigs in MODELS.items()
}

# One pred_pr DataArray per CCS model, selected by a scalar model label.
ccs_pred_pr_das = [
    EVAL_DS[source]["pred_pr"].sel(model=model)
    for source, models in ccs_models.items()
    for model in models
]

### 99th percentile changes

In [None]:
from functools import partial

q = 0.99
IPython.display.display_markdown(f"#### Quantile: {q}", raw=True)

# One figure per CCS model; the statistic handed to compute_changes is the
# q-quantile of the data.
for ccs_pred_da in ccs_pred_pr_das:
    changes = compute_changes(
        [ccs_pred_da],
        CPM_DAS["pr"],
        ccs_seasons,
        stat_func=partial(xr.DataArray.quantile, q=q),
    )

    change_fig = plt.figure(figsize=(5.5, 5.5), layout="compressed")
    plot_changes(
        changes,
        ccs_seasons,
        change_fig,
        show_change=[ccs_pred_da.model.values.item()],
    )

plt.show()

## Rough: wet-day mean

In [None]:
from mlde_analysis.wet_dry import threshold_exceeded_prop_stats, threshold_exceeded_prop_change, plot_threshold_exceedence_errors

# Wet-day thresholds (mm/day) to evaluate, per variable.
THRESHOLDS = dict(pr=[0.1, 1.0])

def wd_mean(da, threshold):
    """Mean of the values strictly above ``threshold``, over every non-spatial dim.

    Values at or below the threshold are masked out before averaging; the
    grid_latitude/grid_longitude dims (if present) are kept. The result is
    named "wd mean (mm/day)".
    """
    spatial_dims = {"grid_latitude", "grid_longitude"}
    reduce_dims = set(da.dims).difference(spatial_dims)
    wet_only = da.where(da > threshold)
    return wet_only.mean(dim=reduce_dims).rename("wd mean (mm/day)")

def wd_mean_change(da, threshold):
    """Wet-day mean for the historic and future periods, and their difference.

    ``da`` must carry a ``time_period`` coordinate with "historic" and
    "future" values. Returns a Dataset with the variables
    "Historic wd mean (mm/day)", "Future wd mean (mm/day)" and
    "Change in wd mean (mm/day)" (future minus historic).
    """
    def period_wd_mean(period, name):
        # Wet-day mean over only the times belonging to `period`.
        in_period = da.where(da["time_period"] == period, drop=True)
        return wd_mean(in_period, threshold=threshold).rename(name)

    historic = period_wd_mean("historic", "Historic wd mean (mm/day)")
    future = period_wd_mean("future", "Future wd mean (mm/day)")
    delta = (future - historic).rename("Change in wd mean (mm/day)")

    return xr.merge([historic, future, delta])

def wd_mean_bias(pred_da, cpm_da, threshold):
    """Wet-day mean of ``pred_da`` minus the wet-day mean of ``cpm_da`` (same threshold)."""
    return wd_mean(pred_da, threshold) - wd_mean(cpm_da, threshold)

### Wet-day mean

In [None]:
# Predictions of every model (from every input source), stacked along "model",
# per evaluated variable. Pairing model_das[var] with CPM_DAS[var] below keeps
# each variable compared against its own CPM data; previously pr was used for
# every variable in THRESHOLDS.
model_das = {
    var: xr.concat(
        [
            EVAL_DS[source][f"pred_{var}"].sel(model=[model])
            for source, mconfigs in MODELS.items()
            for model in mconfigs
        ],
        dim="model",
    )
    for var in THRESHOLDS
    if var in eval_vars
}

# threshold_exceeded_prop_stats with wd_mean as the statistic, per variable and threshold.
change_stats = {
    var: {
        threshold: threshold_exceeded_prop_stats(
            model_das[var],
            CPM_DAS[var],
            threshold=threshold,
            threshold_exceeded_prop_statistic=wd_mean,
        )
        for threshold in thresholds
    }
    for var, thresholds in THRESHOLDS.items()
    if var in eval_vars
}

for var, thresholds in THRESHOLDS.items():
    if var in eval_vars:
        for threshold in thresholds:
            IPython.display.display_markdown(f"#### Threshold: {threshold}mm/day", raw=True)

            plot_threshold_exceedence_errors(change_stats[var][threshold], style="change")

            plt.show()

### Change from historic to future

In [None]:
# Predictions of every model (from every input source), stacked along "model",
# per evaluated variable.
model_das = {
    var: xr.concat(
        [
            EVAL_DS[source][f"pred_{var}"].sel(model=[model])
            for source, mconfigs in MODELS.items()
            for model in mconfigs
        ],
        dim="model",
    )
    for var in THRESHOLDS
    if var in eval_vars
}

# Historic/future wet-day means and their change, per variable and threshold.
# Compare each variable's predictions with the CPM data for that same variable
# (previously CPM pr was used regardless of var).
change_stats = {
    var: {
        threshold: threshold_exceeded_prop_stats(
            model_das[var],
            CPM_DAS[var],
            threshold=threshold,
            threshold_exceeded_prop_statistic=wd_mean_change,
        )
        for threshold in thresholds
    }
    for var, thresholds in THRESHOLDS.items()
    if var in eval_vars
}


for var, thresholds in THRESHOLDS.items():
    if var in eval_vars:
        for threshold in thresholds:
            IPython.display.display_markdown(f"#### Threshold: {threshold}mm/day", raw=True)

            plot_threshold_exceedence_errors(change_stats[var][threshold]["Change in wd mean (mm/day)"], style="change")

            plt.show()

# Domain-mean summary table (kept as the cell's last expression so it is displayed).
pretty_table(xr.merge([
    change_stats[var][threshold].mean(dim=["grid_latitude", "grid_longitude"]).expand_dims({"threshold": [threshold]})
    for var, thresholds in THRESHOLDS.items() if var in eval_vars for threshold in thresholds
]), dim_order=["threshold", "season", "model"])

## Rough: Wet day frequency

In [None]:
# Predictions of every model (from every input source), stacked along "model",
# per evaluated variable.
model_das = {
    var: xr.concat(
        [
            EVAL_DS[source][f"pred_{var}"].sel(model=[model])
            for source, mconfigs in MODELS.items()
            for model in mconfigs
        ],
        dim="model",
    )
    for var in THRESHOLDS
    if var in eval_vars
}

# Change in the proportion of times exceeding each threshold, per variable.
# Compare each variable's predictions with the CPM data for that same variable
# (previously CPM pr was used regardless of var).
change_stats = {
    var: {
        threshold: threshold_exceeded_prop_stats(
            model_das[var],
            CPM_DAS[var],
            threshold=threshold,
            threshold_exceeded_prop_statistic=threshold_exceeded_prop_change,
        )
        for threshold in thresholds
    }
    for var, thresholds in THRESHOLDS.items()
    if var in eval_vars
}


for var, thresholds in THRESHOLDS.items():
    if var in eval_vars:
        for threshold in thresholds:
            IPython.display.display_markdown(f"#### Threshold: {threshold}mm/day", raw=True)

            plot_threshold_exceedence_errors(change_stats[var][threshold]["change in % threshold exceeded"], style="change")

            plt.show()

# Domain-mean summary table (kept as the cell's last expression so it is displayed).
pretty_table(xr.merge([
    change_stats[var][threshold].mean(dim=["grid_latitude", "grid_longitude"]).expand_dims({"threshold": [threshold]})
    for var, thresholds in THRESHOLDS.items() if var in eval_vars for threshold in thresholds
]), dim_order=["threshold", "season", "model"])