# Evaluation of a selection of models on 60km -> 2.2km-4x over Birmingham

In [None]:
%reload_ext autoreload

%autoreload 2

import math
import os

import cftime
import iris
import iris.analysis.cartography
import IPython
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pysteps
import xarray as xr

from mlde_utils import cp_model_rotated_pole
from mlde_utils.utils import plot_grid, prep_eval_data, show_samples, distribution_figure, plot_mean_bias, plot_std_bias, plot_psd, scatter_plots, seasonal_distribution_figure, compute_gridspec
from mlde_utils.plotting import create_map_fig, qq_plot

In [None]:
# Three named time slices on the 360-day calendar. Each one starts at 12:00 on
# 1 December of its start year and ends at 12:00 on 30 November, 20 years on.
time_slices = {
    f"TS{idx}": (
        cftime.Datetime360Day(start_year, 12, 1, 12, 0, 0, 0, has_year_zero=True),
        cftime.Datetime360Day(start_year + 20, 11, 30, 12, 0, 0, 0, has_year_zero=True),
    )
    for idx, start_year in enumerate((1980, 2020, 2060), start=1)
}

In [None]:
# `split` and `samples_per_run` are forwarded to prep_eval_data when the
# evaluation data is loaded.
split = "val"
samples_per_run = 3

# Settings common to both evaluations of the "PslTV shuffle" model; only the
# dataset it is sampled on differs between the CPM and GCM sources.
_psl_tv_shuffle = {
    "fq_model_id": "score-sde/subvpsde/xarray_cncsnpp_continuous/bham-4x_PslTV_random-season-IstanTsqrturrecen-shuffle-fix",
    "checkpoint": "epoch-100",
    "input_xfm": "stan",
    "label": "PslTV shuffle",
}

# Per-source list of model configurations, keyed by source name.
data_configs = {
    "CPM": [
        {
            **_psl_tv_shuffle,
            "dataset": "bham_gcmx-4x_psl-temp-vort_random-season",
            "deterministic": False,
        },
        dict(
            fq_model_id="id-linpr",
            checkpoint="epoch-0",
            input_xfm="",
            label="Lin. interp. of LR precip",
            deterministic=True,
            dataset="bham_gcmx-4x_linpr_random-season",
        ),
    ],
    "GCM": [
        {
            **_psl_tv_shuffle,
            "dataset": "bham_60km-4x_psl-temp-vort_random-season",
            "deterministic": False,
        },
    ],
}

# Labels of the CPM-sourced models that are singled out in later comparisons.
highlighted_cpm_models = ["PslTV shuffle"]

# Free-text description of the notebook, rendered as Markdown further down.
desc = "\nDescribe in more detail the models being compared\n"
# Default precipitation datasets for comparisons (such as PSD) that need a
# reference field per source: GCM-based lo-res precip for "GCM" and CPM-based
# hi-res precip for "CPM".
gcm_lr_lin_pr_dataset = "bham_60km-4x_linpr_random"
cpm_hr_pr_dataset = "bham_gcmx-4x_linpr_random"

# Built from the two names above so the per-source lookup cannot drift from
# them (the "CPM" entry must point at the gcmx dataset, not the 60km one).
simulation_pr_datasets = {
    "GCM": gcm_lr_lin_pr_dataset,
    "CPM": cpm_hr_pr_dataset,
}

In [None]:
# Render the notebook description string as Markdown; this is the cell's
# final expression, so the notebook displays the resulting object.
IPython.display.Markdown(desc)

In [None]:
# Load the evaluation data once per source: one prep_eval_data call for each
# entry of data_configs, stored under the same source key.
merged_ds = {
    src_name: prep_eval_data(src_configs, split, samples_per_run=samples_per_run)
    for src_name, src_configs in data_configs.items()
}
# Echo the mapping as the cell output.
merged_ds

## Pixel distribution

In [None]:
# Quantile levels for the distribution figures: nine evenly spaced levels per
# decade (0.1..0.9, 0.91..0.99, 0.991..0.999, ...), ending at 1 - 10**-7.
quantiles = np.concatenate(
    [
        np.linspace((1 - 10 ** (exponent + 1)) + (10**exponent), (1 - 10**exponent), 9)
        for exponent in range(-1, -8, -1)
    ]
)
# Target precipitation from the CPM-sourced data; later cells pass it to the
# distribution and bias helpers alongside each set of samples.
cpm_pr = merged_ds["CPM"]["target_pr"]

### CPM

In [None]:
# Distribution figure for the CPM-sourced samples against the CPM target,
# with time and both grid axes named as the quantile dimensions.
distribution_figure(
    merged_ds["CPM"],
    cpm_pr,
    quantiles,
    quantile_dims=["time", "grid_latitude", "grid_longitude"],
)

### GCM

In [None]:
# Predictions of the highlighted CPM-sourced models, relabelled with a "CPM "
# prefix so they stay distinct from the GCM-sourced models after merging.
# assign_coords returns a new dataset; the previous code chained off the
# return value of the in-place Dataset.update, which xarray has deprecated.
hl_cpm_pred = (
    merged_ds["CPM"]
    .sel(model=highlighted_cpm_models)
    .assign_coords(model=("model", [f"CPM {m}" for m in highlighted_cpm_models]))
)[["pred_pr"]]

# GCM-sourced target and predictions together with the relabelled CPM ones.
gcm_hl_cpm_ds = xr.merge([merged_ds["GCM"][["target_pr", "pred_pr"]], hl_cpm_pred])

distribution_figure(
    gcm_hl_cpm_ds,
    cpm_pr,
    quantiles,
    quantile_dims=["time", "grid_latitude", "grid_longitude"],
)

### Seasonal

In [None]:
# One seasonal distribution figure per source, each under its own heading and
# compared against the CPM target precipitation.
for source in merged_ds:
    IPython.display.display_html(f"<h1>{source}</h1>", raw=True)
    seasonal_distribution_figure(
        merged_ds[source],
        cpm_pr,
        quantiles,
        quantile_dims=["time", "grid_latitude", "grid_longitude"],
    )

### Subregions
#### CPM

In [None]:
# Positional (isel) bounds of the two 16x16 grid-cell subregions.
subregions = {
    "SE": {"grid_latitude": slice(10, 26), "grid_longitude": slice(38, 54)},
    "NW": {"grid_latitude": slice(44, 60), "grid_longitude": slice(18, 34)},
}

fig, axd = create_map_fig([["subregions"]])

# Mark each subregion on the map as a constant-valued patch (NW drawn with
# value 10, SE with value 1), cut from a field of ones shaped like one
# prediction slice.
unit_field = xr.ones_like(merged_ds["CPM"]["pred_pr"].isel(sample_id=0, model=0, time=0))
for region_name, fill_value in (("NW", 10), ("SE", 1)):
    plot_grid(
        fill_value * unit_field.isel(**subregions[region_name]),
        ax=axd["subregions"],
        style="precip",
    )

axd["subregions"].set_extent([-2, 3, -2.5, 2.5], crs=cp_model_rotated_pole)

# Rebinds `quantiles` for the subregion cells with one decade fewer than the
# earlier definition: the levels now end at 1 - 10**-6.
quantiles = np.concatenate(
    [
        np.linspace((1 - 10 ** (exponent + 1)) + (10**exponent), (1 - 10**exponent), 9)
        for exponent in range(-1, -7, -1)
    ]
)

# Source used by the subregion cells that follow.
source = "CPM"

##### NW

In [None]:
# Subregion handled by this cell and its isel bounds.
srname = "NW"
srbnds = subregions[srname]

IPython.display.display_html(f"<h6>{source}</h6>", raw=True)

# Cut the samples down to the subregion, then keep the highlighted models only.
ds = merged_ds[source].isel(**srbnds)
ds = ds.sel(model=highlighted_cpm_models)
distribution_figure(
    ds,
    cpm_pr.isel(**srbnds),
    quantiles,
    quantile_dims=["time", "grid_latitude", "grid_longitude"],
)

##### NW Winter

In [None]:
# Subregion and season handled by this cell.
srname = "NW"
srseason = "DJF"
# Look the bounds up from srname here, so the cell never picks up `srbnds`
# left behind by whichever subregion cell happened to run last.
srbnds = subregions[srname]

IPython.display.display_html(f"<h5>{source}</h5>", raw=True)

# Select the subregion once and reuse it for both the mask and the samples.
subregion_ds = merged_ds[source].isel(**srbnds)
# Boolean mask over time: True where the time step falls in the chosen season.
srseason_mask = subregion_ds["time.season"] == srseason
srseason_mask_sample_ds = subregion_ds.sel(time=srseason_mask).sel(model=highlighted_cpm_models)
srseason_mask_cpm_pr = cpm_pr.isel(**srbnds).sel(time=srseason_mask)
distribution_figure(
    srseason_mask_sample_ds,
    srseason_mask_cpm_pr,
    quantiles,
    quantile_dims=["time", "grid_latitude", "grid_longitude"],
)

##### SE

In [None]:
# Subregion handled by this cell and its isel bounds.
srname = "SE"
srbnds = subregions[srname]

IPython.display.display_html(f"<h6>{source}</h6>", raw=True)

# Cut the samples down to the subregion, then keep the highlighted models only.
ds = merged_ds[source].isel(**srbnds)
ds = ds.sel(model=highlighted_cpm_models)
distribution_figure(
    ds,
    cpm_pr.isel(**srbnds),
    quantiles,
    quantile_dims=["time", "grid_latitude", "grid_longitude"],
)

##### SE Summer

In [None]:
# Subregion and season handled by this cell.
srname = "SE"
srseason = "JJA"
# Look the bounds up from srname here, so the cell never picks up `srbnds`
# left behind by whichever subregion cell happened to run last.
srbnds = subregions[srname]

IPython.display.display_html(f"<h5>{source}</h5>", raw=True)

# Select the subregion once and reuse it for both the mask and the samples.
subregion_ds = merged_ds[source].isel(**srbnds)
# Boolean mask over time: True where the time step falls in the chosen season.
srseason_mask = subregion_ds["time.season"] == srseason
srseason_mask_sample_ds = subregion_ds.sel(time=srseason_mask).sel(model=highlighted_cpm_models)
srseason_mask_cpm_pr = cpm_pr.isel(**srbnds).sel(time=srseason_mask)
distribution_figure(
    srseason_mask_sample_ds,
    srseason_mask_cpm_pr,
    quantiles,
    quantile_dims=["time", "grid_latitude", "grid_longitude"],
)

## Bias $\frac{\mu_{sample}-\mu_{CPM}}{\mu_{CPM}}$

### All

In [None]:
# One mean-bias plot per source, each under its own heading and measured
# against the CPM target precipitation.
for source in merged_ds:
    IPython.display.display_html(f"<h3>{source}</h3>", raw=True)
    plot_mean_bias(merged_ds[source], cpm_pr)

## Samples

In [None]:
# For every source and season, show samples at a few time steps picked by rank
# of the target precipitation's spatial standard deviation or spatial mean.
for source, sourced_ds in merged_ds.items():
    IPython.display.display_html(f"<h1>{source}</h1>", raw=True)
    for season, seasonal_ds in sourced_ds.groupby("time.season"):
        IPython.display.display_html(f"<h2>{season}</h2>", raw=True)

        spatial_dims = ["grid_longitude", "grid_latitude"]
        # Time stamps ordered from largest to smallest spatial std / mean.
        std = seasonal_ds["target_pr"].std(dim=spatial_dims)
        std_sorted_time = std.sortby(-std)["time"].values
        mean = seasonal_ds["target_pr"].mean(dim=spatial_dims)
        mean_sorted_time = mean.sortby(-mean)["time"].values

        # Clamp every rank to the available range so a season with only a few
        # time steps picks the nearest valid rank instead of raising IndexError.
        n_times = len(mean_sorted_time)
        last_rank = n_times - 1
        # Other picks tried previously: "very wet" (rank 20 by mean) and
        # "quite varied" (20% of the way down the std ordering).
        timestamp_chunks = {
            "very varied": std_sorted_time[min(20, last_rank)],
            "quite wet": mean_sorted_time[min(math.ceil(n_times * 0.20), last_rank)],
            "very dry": mean_sorted_time[max(-20, -n_times)],
        }

        # The loop variable is deliberately not called `desc`: that name holds
        # the notebook description string rendered near the top of the notebook.
        for chunk_desc, timestamp in timestamp_chunks.items():
            IPython.display.display_html(f"<h3>{chunk_desc}</h3>", raw=True)
            show_samples(seasonal_ds, [timestamp])