# Evaluation of a selection of models on 60km -> 2.2km-4x over Birmingham

In [None]:
%reload_ext autoreload

%autoreload 2

import math
import os

import cftime
import iris
import iris.analysis.cartography
import IPython
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pysteps
import xarray as xr

from mlde_utils.utils import cp_model_rotated_pole, plot_grid, prep_eval_data, plot_examples, distribution_figure, plot_mean_bias, plot_std_bias, plot_psd, scatter_plots, seasonal_distribution_figure, compute_gridspec
from mlde_utils.plotting import create_map_fig, qq_plot

In [None]:
# Bounds of the three 20-year time slices on the 360-day calendar
# (cftime.Datetime360Day). Each slice runs from 1 December of its start year to
# 30 November twenty years later, both at 12:00; the cells below compare
# times against these bounds with >= and <=, i.e. inclusively.
time_slices = {
    ts_key: (
        cftime.Datetime360Day(start_year, 12, 1, 12, 0, 0, 0, has_year_zero=True),
        cftime.Datetime360Day(start_year + 20, 11, 30, 12, 0, 0, 0, has_year_zero=True),
    )
    for ts_key, start_year in (("TS1", 1980), ("TS2", 2020), ("TS3", 2060))
}

In [None]:
split = "val"

# One entry per group of runs that are evaluated against the same datasets.
# "datasets" maps each source label (CPM / GCM) to a dataset name, and each run
# tuple is handed to prep_eval_data unchanged. Presumably the fields are
# (run directory, checkpoint, input-transform key, display name) -- TODO confirm.
data_config = [
    {
        "datasets": {
            "CPM": "bham_gcmx-4x_psl-temp-vort_random-season",
            "GCM": "bham_60km-4x_psl-temp-vort_random-season",
        },
        "runs": [
            (
                "score-sde/subvpsde/xarray_cncsnpp_continuous/bham-4x_PslTV_random-season-IstanTsqrturrecen-shuffle-fix",
                "epoch-100",
                "stan",
                "PslTV shuffle",
            ),
        ],
    },
    {
        "datasets": {
            "CPM": "bham_gcmx-4x_pr_random",
            "GCM": "bham_60km-4x_pr_random",
        },
        "runs": [
            ("id-pr", "epoch-0", "", "LR precip"),
        ],
    },
]
desc = """
Describe in more detail the models being compared
"""
# Default datasets for comparisons such as the PSD, which need a GCM-based lo-res
# (linear) precip field and a CPM-based hi-res precip field respectively.
# NOTE(review): both names point at the same dataset; this looks deliberate because
# the hi-res field is read from its "target_pr" variable in the PSD cell -- confirm.
gcm_lr_lin_pr_dataset = "bham_60km-4x_linpr_random"
cpm_hr_pr_dataset = "bham_60km-4x_linpr_random"

In [None]:
# Render the free-text model description (`desc`, set in the configuration cell) as Markdown.
IPython.display.Markdown(desc)

In [None]:
# Evaluate every configured group of runs against its datasets and combine the
# per-group results into one dataset (displayed as the cell output).
merged_ds = xr.merge(
    [
        prep_eval_data(run_group["datasets"], run_group["runs"], split)
        for run_group in data_config
    ]
)
merged_ds

## Spread

In [None]:
# Spread vs error of the sample ensembles. For each model, scatter the time-mean
# squared error of the ensemble mean against the time-mean ensemble variance
# (presumably one point per grid box, as time is the last non-spatial dim left).
ensemble_mean = merged_ds["pred_pr"].mean(dim="sample_id")
# Mean squared deviation of the samples about the ensemble mean (ensemble variance, not RMS).
ensemble_spread = ((merged_ds["pred_pr"] - ensemble_mean) ** 2).mean(dim="sample_id")
# Squared error of the ensemble mean against the CPM target.
ensemble_mean_error = (merged_ds["target_pr"] - ensemble_mean) ** 2

spread_source = "CPM"  # set to "GCM" to inspect the GCM-driven samples instead
for model in ensemble_spread["model"].values:
    plt.scatter(
        ensemble_mean_error.sel(source=spread_source, model=model).mean(dim="time"),
        ensemble_spread.sel(source=spread_source, model=model).mean(dim="time"),
        label=model,  # one colour per model, identified in the legend
    )
plt.xlabel("Time-mean squared error of ensemble mean")
plt.ylabel("Time-mean ensemble variance")
plt.title(f"{spread_source}: ensemble spread vs error")
plt.legend()
plt.show()

## Pixel distribution

In [None]:
# Quantile levels that zoom in on the upper tail one decade at a time:
# 0.1..0.9 (step 0.1), 0.91..0.99 (step 0.01), ..., 0.9999991..0.9999999 (step 1e-7).
quantiles = np.concatenate(
    [
        np.linspace(1 - 10 ** (exponent + 1) + 10**exponent, 1 - 10**exponent, 9)
        for exponent in range(-1, -8, -1)
    ]
)

# CPM target precip taken from the CPM source; used as the reference distribution
# that samples from every source are compared against below.
cpm_pr = merged_ds.sel(source="CPM")["target_pr"]

In [None]:
for source in merged_ds["source"].values:
    IPython.display.display_html(f"<h1>{source}</h1>", raw=True)
    # Pixel-level distribution of this source's samples against the CPM target,
    # pooling over time and both grid dimensions.
    source_ds = merged_ds.sel(source=source)
    distribution_figure(
        source_ds,
        cpm_pr,
        quantiles,
        quantile_dims=["time", "grid_latitude", "grid_longitude"],
    )

### Seasonal

In [None]:
for source in merged_ds["source"].values:
    IPython.display.display_html(f"<h1>{source}</h1>", raw=True)
    # Same pixel-level comparison as above, split by season.
    source_ds = merged_ds.sel(source=source)
    seasonal_distribution_figure(
        source_ds,
        cpm_pr,
        quantiles,
        quantile_dims=["time", "grid_latitude", "grid_longitude"],
    )

### Subregions

In [None]:
# Two 16x16 subregions, given as index slices into the grid, used for the regional
# distributions below.
subregions = {
    "SE": dict(grid_latitude=slice(10, 26), grid_longitude=slice(38, 54)),
    "NW": dict(grid_latitude=slice(44, 60), grid_longitude=slice(18, 34)),
}

fig, axd = create_map_fig([["subregions"]])

# Locator map: a field of ones shaped like a single CPM prediction, cut to each
# subregion and scaled by a constant (NW=10, SE=1) so the two stand out.
subregion_template = xr.ones_like(
    merged_ds["pred_pr"].sel(source="CPM").isel(sample_id=0, model=0, time=0)
)
for region_name, fill_value in (("NW", 10), ("SE", 1)):
    plot_grid(
        fill_value * subregion_template.isel(**subregions[region_name]),
        ax=axd["subregions"],
        style="precip",
    )

axd["subregions"].set_extent([-2, 3, -2.5, 2.5], crs=cp_model_rotated_pole)

# Quantile levels for the subregion plots stop at 1 - 1e-6, one decade short of the
# full-domain set (presumably because each subregion has far fewer pixels).
quantiles = np.concatenate(
    [
        np.linspace((1 - 10 ** (i + 1)) + (10**i), (1 - 10**i), 9)
        for i in range(-1, -7, -1)
    ]
)

#### NW

In [None]:
srname = "NW"
srbnds = subregions[srname]

for source in merged_ds["source"].values:
    IPython.display.display_html(f"<h6>{source}</h6>", raw=True)
    # Cut both the samples and the CPM reference down to the subregion's grid boxes.
    sr_source_ds = merged_ds.sel(source=source).isel(**srbnds)
    sr_cpm_pr = cpm_pr.isel(**srbnds)
    distribution_figure(
        sr_source_ds,
        sr_cpm_pr,
        quantiles,
        quantile_dims=["time", "grid_latitude", "grid_longitude"],
    )

##### NW Winter

In [None]:
# NW subregion, winter (DJF) only. The subregion is set explicitly here rather than
# relying on `srbnds` left over from whichever subregion cell happened to run last.
srname = "NW"
srbnds = subregions[srname]
srseason = "DJF"
for source in merged_ds["source"].values:
    IPython.display.display_html(f"<h5>{source}</h5>", raw=True)
    sr_source_ds = merged_ds.sel(source=source).isel(**srbnds)
    # Boolean mask over time that keeps only timestamps in the chosen season.
    srseason_mask = sr_source_ds["time.season"] == srseason
    srseason_mask_sample_ds = sr_source_ds.sel(time=srseason_mask)
    # The CPM reference is cut to the same subregion and season as the samples.
    srseason_mask_cpm_pr = cpm_pr.isel(**srbnds).sel(time=srseason_mask)
    distribution_figure(
        srseason_mask_sample_ds,
        srseason_mask_cpm_pr,
        quantiles,
        quantile_dims=["time", "grid_latitude", "grid_longitude"],
    )

#### SE

In [None]:
srname = "SE"
srbnds = subregions[srname]
for source in merged_ds["source"].values:
    IPython.display.display_html(f"<h6>{source}</h6>", raw=True)
    # Cut both the samples and the CPM reference down to the subregion's grid boxes.
    sr_source_ds = merged_ds.sel(source=source).isel(**srbnds)
    sr_cpm_pr = cpm_pr.isel(**srbnds)
    distribution_figure(
        sr_source_ds,
        sr_cpm_pr,
        quantiles,
        quantile_dims=["time", "grid_latitude", "grid_longitude"],
    )

##### SE Summer

In [None]:
# SE subregion, summer (JJA) only. Samples and the CPM reference must be cut to the
# SAME subregion; both use the SE bounds set explicitly here.
srname = "SE"
srbnds = subregions[srname]
srseason = "JJA"
for source in merged_ds["source"].values:
    IPython.display.display_html(f"<h5>{source}</h5>", raw=True)
    sr_source_ds = merged_ds.sel(source=source).isel(**srbnds)
    # Boolean mask over time that keeps only timestamps in the chosen season.
    srseason_mask = sr_source_ds["time.season"] == srseason
    srseason_mask_sample_ds = sr_source_ds.sel(time=srseason_mask)
    srseason_mask_cpm_pr = cpm_pr.isel(**srbnds).sel(time=srseason_mask)
    distribution_figure(
        srseason_mask_sample_ds,
        srseason_mask_cpm_pr,
        quantiles,
        quantile_dims=["time", "grid_latitude", "grid_longitude"],
    )

### Scatter plots

In [None]:
# Disabled: per-source scatter plots of samples against the CPM target.
# (The dataset in this notebook is `merged_ds`; there is no `ds` at this point.)
# for source in merged_ds["source"].values:
#     IPython.display.display_html(f"<h5>{source}</h5>", raw=True)
#     scatter_plots(merged_ds.sel(source=source), cpm_pr)

## Domain-mean distribution

In [None]:
# Quantile levels for domain-mean precip: 0.1-0.8 in steps of 0.1, then
# 0.90-0.99 in steps of 0.01, then 0.991-0.999 in steps of 0.001.
quantiles = np.concatenate(
    (
        np.linspace(0.1, 0.8, 8),
        np.linspace(0.9, 0.99, 10),
        np.linspace(0.991, 0.999, 9),
    )
)

# Average target and sample precip over the whole grid (grid_longitude, grid_latitude),
# leaving one domain-mean value per remaining index (time, sample, ...).
mean_ds = merged_ds[["target_pr", "pred_pr"]].mean(dim=["grid_longitude", "grid_latitude"])
# Domain-mean CPM target precip: the reference distribution for the cells below.
cpm_mean_pr = mean_ds.sel(source="CPM")["target_pr"]

In [None]:
for source in mean_ds["source"].values:
    IPython.display.display_html(f"<h1>{source}</h1>", raw=True)
    # Distribution over time of domain-mean precip: samples vs CPM target.
    source_mean_ds = mean_ds.sel(source=source)
    distribution_figure(source_mean_ds, cpm_mean_pr, quantiles, quantile_dims=["time"])

In [None]:
for source in mean_ds["source"].values:
    IPython.display.display_html(f"<h1>{source}</h1>", raw=True)
    # As above, split by season.
    source_mean_ds = mean_ds.sel(source=source)
    seasonal_distribution_figure(source_mean_ds, cpm_mean_pr, quantiles, quantile_dims=["time"])

### Scatter plots

In [None]:
# Disabled: per-source scatter plots of domain-mean samples against the domain-mean
# CPM target. (Uses the names defined in this notebook: `mean_ds` and `cpm_mean_pr`;
# `ds` and `target_mean_pr` do not exist here.)
# for source in mean_ds["source"].values:
#     IPython.display.display_html(f"<h5>{source}</h5>", raw=True)
#     scatter_plots(mean_ds.sel(source=source), cpm_mean_pr)

## Climate change signal

### Pixel quantiles

In [None]:
# Quantile levels for the climate-change QQ plots: 0.1..0.9, 0.91..0.99, and so on
# up to 0.9999991..0.999999 (six decades, nine levels each).
quantiles = np.concatenate(
    [
        np.linspace(1 - 10 ** (exponent + 1) + 10**exponent, 1 - 10**exponent, 9)
        for exponent in range(-1, -7, -1)
    ]
)

# Timestamps inside the historical slice TS1 (both bounds inclusive).
historical_ts_mask = (merged_ds["time"] >= time_slices["TS1"][0]) & (merged_ds["time"] <= time_slices["TS1"][1])

# Dimensions pooled into each pixel-quantile estimate; samples additionally pool sample_id.
pixel_dims = ["time", "grid_latitude", "grid_longitude"]
historical_cpm_pr_quantiles = cpm_pr.sel(time=historical_ts_mask).quantile(quantiles, dim=pixel_dims)

for source in merged_ds["source"].values:
    IPython.display.display_html(f"<h1>{source}</h1>", raw=True)

    # One QQ panel per time slice, laid out in a single row.
    fig, axd = plt.subplot_mosaic([list(time_slices)], figsize=(18, 6))
    for ts_key, (ts_start, ts_end) in time_slices.items():
        ts_mask = (merged_ds["time"] >= ts_start) & (merged_ds["time"] <= ts_end)
        ts_pred_pr = merged_ds.sel(source=source).sel(time=ts_mask)["pred_pr"]
        ts_cpm_pr = cpm_pr.sel(time=ts_mask)

        ts_sample_quantiles = ts_pred_pr.quantile(quantiles, dim=pixel_dims + ["sample_id"])
        ts_cpm_quantiles = ts_cpm_pr.quantile(quantiles, dim=pixel_dims)
        # Add this slice's own CPM quantiles as an extra "model" so they are drawn next
        # to the samples, everything plotted against the TS1 CPM quantiles.
        ts_quantiles = xr.concat(
            [ts_sample_quantiles, ts_cpm_quantiles.expand_dims(model=["CPM"])],
            dim="model",
        )
        ax = axd[ts_key]
        qq_plot(
            ax,
            historical_cpm_pr_quantiles,
            ts_quantiles,
            title=f"{ts_key} sample quantiles vs TS1 CPM quantiles",
            xlabel="TS1 CPM precip (mm day$^{-1}$)",
            tr=200,
            guide_label=None,
        )
    plt.show()

### Mean change

In [None]:
# Relative change in mean precip of each time slice against the historical (TS1) CPM
# mean, for the CPM itself and for each model's samples.
cpm_pr_historical_mean = cpm_pr.sel(time=historical_ts_mask).mean(dim=["time"])

target_name = "$\\mu_{CPM}$"
# Shared colour scaling so every panel, and the single colorbar, is comparable.
change_plot_kwargs = dict(cmap="BrBG", norm=None, center=0, vmax=0.2)

for source in merged_ds["source"].values:
    IPython.display.display_html(f"<h1>{source}</h1>", raw=True)
    source_ds = merged_ds.sel(source=source)
    models = source_ds["model"].values
    # The panel layout depends only on the models, so it is built once per source.
    grid_spec = compute_gridspec(models, target_name)

    for ts_key, (ts_start, ts_end) in time_slices.items():
        IPython.display.display_html(f"<h2>{ts_key}</h2>", raw=True)

        fig, axd = plt.subplot_mosaic(
            grid_spec,
            figsize=(grid_spec.shape[1] * 5.5, grid_spec.shape[0] * 5.5),
            subplot_kw=dict(projection=cp_model_rotated_pole),
            constrained_layout=True,
        )

        ts_mask = (merged_ds["time"] >= ts_start) & (merged_ds["time"] <= ts_end)
        ts_ds = source_ds.sel(time=ts_mask)

        cpm_pr_ts_mean = cpm_pr.sel(time=ts_mask).mean(dim=["time"])
        cpm_ts_change = (cpm_pr_ts_mean - cpm_pr_historical_mean) / cpm_pr_historical_mean
        # Keep the CPM panel's mappable too, so the shared colorbar below always has
        # one even if there are no model panels.
        pcm = plot_grid(
            cpm_ts_change,
            axd[target_name],
            title=f"{ts_key} CPM mean change vs TS1 CPM",
            add_colorbar=False,
            **change_plot_kwargs,
        )

        for model in models:
            sample_ts_mean = ts_ds.sel(model=model)["pred_pr"].mean(dim=["time", "sample_id"])
            sample_ts_change = (sample_ts_mean - cpm_pr_historical_mean) / cpm_pr_historical_mean
            pcm = plot_grid(
                sample_ts_change,
                axd[model],
                title=f"{ts_key} {model} mean change vs TS1 CPM",
                **change_plot_kwargs,
            )

        fig.colorbar(pcm, ax=list(axd.values()), location="left", shrink=0.8, extend="both")
        plt.show()

## Bias $\frac{\mu_{sample}-\mu_{CPM}}{\mu_{CPM}}$

### All

In [None]:
for source in merged_ds["source"].values:
    IPython.display.display_html(f"<h3>{source}</h3>", raw=True)
    # Bias of this source's sample mean relative to the CPM target mean.
    source_ds = merged_ds.sel(source=source)
    plot_mean_bias(source_ds, cpm_pr)

### Seasonal

In [None]:
for season, seasonal_ds in merged_ds.groupby("time.season"):
    IPython.display.display_html(f"<h3>{season}</h3>", raw=True)
    # CPM reference restricted to the same season as the grouped samples.
    season_mask = cpm_pr["time.season"] == season
    seasonal_cpm_pr = cpm_pr.sel(time=season_mask)
    for source in merged_ds["source"].values:
        IPython.display.display_html(f"<h4>{source}</h4>", raw=True)
        source_seasonal_ds = seasonal_ds.sel(source=source)
        plot_mean_bias(source_seasonal_ds, seasonal_cpm_pr)

## Standard deviation ratio $\frac{\sigma_{sample}}{\sigma_{CPM}}$

### All

In [None]:
for source in merged_ds["source"].values:
    IPython.display.display_html(f"<h3>{source}</h3>", raw=True)
    # Standard deviation of this source's samples compared with that of the CPM target.
    source_ds = merged_ds.sel(source=source)
    plot_std_bias(source_ds, cpm_pr)

### Seasonal

In [None]:
for season, seasonal_ds in merged_ds.groupby("time.season"):
    IPython.display.display_html(f"<h3>{season}</h3>", raw=True)
    # CPM reference restricted to the same season as the grouped samples.
    season_mask = cpm_pr["time.season"] == season
    seasonal_cpm_pr = cpm_pr.sel(time=season_mask)
    for source in merged_ds["source"].values:
        IPython.display.display_html(f"<h4>{source}</h4>", raw=True)
        source_seasonal_ds = seasonal_ds.sel(source=source)
        plot_std_bias(source_seasonal_ds, seasonal_cpm_pr)

## FSS

In [None]:
source = "CPM"
for model in merged_ds["model"].values:
    ds = merged_ds.sel(source=source, model=model)
    # One FSS accumulator per model (threshold 0.1, scale 3), fed every
    # (time, sample) forecast field paired with the target field at that time.
    fss = pysteps.verification.spatialscores.fss_init(0.1, 3)
    n_times = ds.sizes["time"]
    n_samples = ds.sizes["sample_id"]
    for t in range(n_times):
        # All samples at time t are scored against the same target field.
        target_field = ds["target_pr"].isel(time=t).values
        for sample_id in range(n_samples):
            sample_field = ds["pred_pr"].isel(time=t, sample_id=sample_id).values
            pysteps.verification.spatialscores.fss_accum(fss, sample_field, target_field)
    fss_score = pysteps.verification.spatialscores.fss_compute(fss)
    print(f"{source}\t{model}\t{fss_score:.3}")

In [None]:
# FSS as in the explicit loop above, but letting xarray broadcast the per-field
# accumulation over every non-spatial dimension (e.g. time and sample_id).
source = "CPM"
for model in merged_ds["model"].values:
    ds = merged_ds.sel(source=source, model=model)
    fss_obj = pysteps.verification.spatialscores.fss_init(0.1, 3)

    def wrap_fss_accum(Xf, Xo):
        # Accumulate one 2-D sample/target pair into this iteration's fss_obj.
        return pysteps.verification.spatialscores.fss_accum(fss_obj, Xf, Xo)

    # Run only for its side effect on fss_obj; the returned array is discarded.
    xr.apply_ufunc(
        wrap_fss_accum,
        ds["pred_pr"],  # arguments in the order wrap_fss_accum expects
        ds["target_pr"],
        input_core_dims=[["grid_latitude", "grid_longitude"], ["grid_latitude", "grid_longitude"]],
        output_core_dims=[[]],
        vectorize=True,
    )

    fss_score = pysteps.verification.spatialscores.fss_compute(fss_obj)
    print(f"{source}\t{model}\t{fss_score:.3}")

In [None]:
# FSS scored separately for each remaining 2-D field of the CPM-source data (presumably
# one score per model/time/sample_id), returned as a DataArray shown as the cell output.
ds = merged_ds.sel(source="CPM")
xr.apply_ufunc(
    pysteps.verification.spatialscores.fss,  # first the function
    ds["pred_pr"],  # now arguments in the order expected by 'fss'
    ds["target_pr"],
    0.1,  # same threshold/scale pair as fss_init(0.1, 3) in the cells above
    3,
    input_core_dims=[["grid_latitude", "grid_longitude"], ["grid_latitude", "grid_longitude"], [], []],  # list with one entry per arg
    output_core_dims=[[]],
    # exclude_dims=set(("grid_latitude", "grid_longitude",)),  # dimensions allowed to change size. Must be set!
    vectorize=True,
)

## PSD

In [None]:
def _open_derived_pr(dataset_name, variable, split_name):
    """Open one precip variable from a derived nc-dataset split, rescaled to mm day-1.

    The file is ``$MOOSE_DERIVED_DATA/nc-datasets/<dataset_name>/<split_name>.nc``.
    Values are multiplied by 3600 * 24 (presumably a per-second rate converted to a
    daily total -- confirm against the file's own units) and tagged "mm day-1".

    Raises:
        RuntimeError: if the MOOSE_DERIVED_DATA environment variable is not set.
    """
    data_root = os.getenv("MOOSE_DERIVED_DATA")
    if data_root is None:
        # os.path.join(None, ...) would otherwise fail with an unhelpful TypeError.
        raise RuntimeError("MOOSE_DERIVED_DATA environment variable is not set")
    path = os.path.join(data_root, "nc-datasets", dataset_name, f"{split_name}.nc")
    return (xr.open_dataset(path)[variable] * 3600 * 24).assign_attrs({"units": "mm day-1"})


# Reference simulation fields for the PSD comparison.
gcm_lr_lin_pr = _open_derived_pr(gcm_lr_lin_pr_dataset, "linpr", split)
cpm_hr_pr = _open_derived_pr(cpm_hr_pr_dataset, "target_pr", split)

In [None]:
simulation_data = {"CPM pr": cpm_hr_pr, "GCM pr": gcm_lr_lin_pr}
for source in merged_ds["source"].values:
    IPython.display.display_html(f"<h1>{source}</h1>", raw=True)
    source_ds = merged_ds.sel(source=source)
    # One sample series per model, followed by the simulation fields for reference.
    ml_data = {
        f"{model_name} Sample": source_ds.sel(model=model_name)["pred_pr"]
        for model_name in merged_ds["model"].values
    }
    plot_psd({**ml_data, **simulation_data})

## Correlation

## Samples

In [None]:
# Example samples at a few characteristic timestamps per source and season, chosen by
# ranking the CPM target's domain standard deviation and domain mean.
for source, sourced_ds in merged_ds.groupby("source"):
    IPython.display.display_html(f"<h1>{source}</h1>", raw=True)
    for season, seasonal_ds in sourced_ds.groupby("time.season"):
        IPython.display.display_html(f"<h2>{season}</h2>", raw=True)

        # Timestamps ordered from most to least spatially varied, and wettest to driest.
        std = seasonal_ds["target_pr"].std(dim=["grid_longitude", "grid_latitude"])
        std_sorted_time = std.sortby(-std)["time"].values
        mean = seasonal_ds["target_pr"].mean(dim=["grid_longitude", "grid_latitude"])
        mean_sorted_time = mean.sortby(-mean)["time"].values

        # Clamp the hard-coded ranks so a season with few timestamps cannot raise
        # IndexError; with 21+ timestamps these match the plain ranks 20 / -20.
        n_times = len(mean_sorted_time)
        timestamp_chunks = {
            "very varied": std_sorted_time[min(20, n_times - 1)],
            "quite wet": mean_sorted_time[min(math.ceil(n_times * 0.20), n_times - 1)],
            "very dry": mean_sorted_time[-min(20, n_times)],
        }

        # Loop variable is `label`, not `desc`, so the notebook's model description
        # (`desc`, displayed near the top) is not overwritten.
        for label, timestamp in timestamp_chunks.items():
            IPython.display.display_html(f"<h3>{label}</h3>", raw=True)
            plot_examples(seasonal_ds, [timestamp])