# Evaluation of a selection of models on 60km -> 2.2km-4x over Birmingham

In [None]:
%reload_ext autoreload

%autoreload 2

import math
import os

import iris
import iris.analysis.cartography
import IPython
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pysteps
import seaborn as sns
import xarray as xr

from mlde_utils.utils import cp_model_rotated_pole, plot_grid, prep_eval_data, show_samples, distribution_figure, plot_mean_bias, plot_std_bias, plot_psd, scatter_plots, seasonal_distribution_figure
from mlde_utils.plotting import qq_plot

In [None]:
# Dataset split to evaluate and the number of samples to use per run; both are
# handed to prep_eval_data when the evaluation data is loaded.
split = "val"
samples_per_run = 3

# One entry per group of datasets to evaluate. "datasets" maps a source label
# to a dataset name; "runs" lists the runs to evaluate on those datasets.
# NOTE(review): each run tuple looks like (run name, checkpoint, display
# label) -- confirm against prep_eval_data.
data_config = [
    dict(
        datasets=dict(
            CPM="bham_gcmx-4x_pr_random",
            GCM="bham_60km-4x_pr_random",
        ),
        runs=[
            ("id-pr", "epoch-0", "LR precip"),
        ],
    ),
]

# Free-text description of the comparison; rendered as Markdown below.
desc = """
Describe in more detail the models being compared
"""

In [None]:
# Wrap the free-text description in a Markdown display object. It must stay
# the final expression of the cell so the notebook renders it as the output.
IPython.display.Markdown(desc)

In [None]:
# Prepare the evaluation data for every entry of data_config, then merge the
# per-entry results into a single dataset.
per_config_datasets = []
for config in data_config:
    per_config_datasets.append(
        prep_eval_data(
            config["datasets"],
            config["runs"],
            split,
            samples_per_run=samples_per_run,
        )
    )
merged_ds = xr.merge(per_config_datasets)
# Final expression of the cell: shows the merged dataset's summary.
merged_ds

## Frequency distribution

### Pixel

In [None]:
# Probabilities at which the distributions are compared: nine evenly spaced
# values per decade, i.e. 0.1..0.9, 0.91..0.99, 0.991..0.999, and so on for
# exponents -1 down to -7 (63 values in total, approaching 1).
decade_quantiles = []
for exponent in range(-1, -8, -1):
    decade_start = (1 - 10 ** (exponent + 1)) + (10 ** exponent)
    decade_end = 1 - 10 ** exponent
    decade_quantiles.append(np.linspace(decade_start, decade_end, 9))
quantiles = np.concatenate(decade_quantiles)

# Reference precipitation used throughout the notebook: the "target_pr"
# variable taken from the "CPM" entry of the source dimension.
target_pr = merged_ds.sel(source="CPM")["target_pr"]

In [None]:
# One panel per source, laid out in a single row; the same layout is used for
# both figures in this cell.
sources = merged_ds["source"].values
panel_names = [f"{source} Quantiles" for source in sources]

# Figure 1: qq_plot of every source against the target precipitation.
fig, axes = plt.subplot_mosaic(
    [panel_names], figsize=(16.5, 5.5), constrained_layout=True
)

# Quantiles of the target precipitation (also reused by later cells).
target_quantiles = target_pr.quantile(quantiles)

for source, panel in zip(sources, panel_names):
    qq_plot(axes[panel], target_pr, merged_ds.sel(source=source), quantiles)

# Figure 2: predicted-minus-target quantile differences per source, plotted
# against the target quantiles.
fig, axes = plt.subplot_mosaic(
    [panel_names], figsize=(16.5, 5.5), constrained_layout=True
)

# Running y-limits across all panels, starting at 0 so the zero line is
# always in view.
ymin = 0
ymax = 0

# Dimensions collapsed when computing the predicted quantiles.
pooled_dims = ["grid_longitude", "grid_latitude", "time", "sample_id"]

for source, panel in zip(sources, panel_names):
    source_quantiles = (
        merged_ds["pred_pr"].sel(source=source).quantile(quantiles, dim=pooled_dims)
    )
    qdiff = source_quantiles - target_quantiles

    # Pad the observed range by 5 on either side.
    ymin = min(ymin, qdiff.min() - 5)
    ymax = max(ymax, qdiff.max() + 5)

    axes[panel].plot(target_quantiles, qdiff)

# Apply the common limits so the panels are directly comparable.
for panel in panel_names:
    axes[panel].set_ylim(ymin, ymax)


#### Quantile error

In [None]:
# Single-panel figure showing the quantile error of the predictions relative
# to the target quantiles, one line per source.
fig, axes = plt.subplot_mosaic(
    [["Quantile Diffs"]], figsize=(16.5, 5.5), constrained_layout=True
)

# Predicted quantiles over the pooled spatial, time and sample dimensions
# (also reused by the following cell), and their difference from the target.
sample_quantiles = merged_ds["pred_pr"].quantile(
    quantiles, dim=["grid_longitude", "grid_latitude", "time", "sample_id"]
)
qdiff = sample_quantiles - target_quantiles

ax = axes["Quantile Diffs"]

# Zero-error reference line.
ax.axhline(0, c="black", linestyle="--")

# Target quantile values keyed by quantile, used as the x-axis; identical for
# every source, so built once.
cpm_quantile_frame = target_quantiles.to_pandas().rename('cpm_quantile').reset_index()
model_names = list(qdiff.model.values)

for source in merged_ds["source"].values:
    # Wide (quantile x model) table for this source ...
    wide = qdiff.sel(source=source).to_pandas().reset_index()
    # ... reshaped to one row per (quantile, model) ...
    long = wide.melt(id_vars='quantile', value_vars=model_names)
    # ... and joined with the target quantile values.
    data = long.merge(cpm_quantile_frame)
    # Rows sharing an x value are aggregated by seaborn; errorbar="sd"
    # requests a standard-deviation band around the line.
    sns.lineplot(data=data, x="cpm_quantile", y="value", errorbar="sd", ax=ax)

#### Quantile error std dev

In [None]:
# Spread of each predicted quantile across models: the standard deviation of
# sample_quantiles over the "model" dimension, plotted against the matching
# target quantile value with one line per source.
cpm_pr_frame = target_quantiles.to_pandas().rename("CPM pr").reset_index()
quantile_std = (
    sample_quantiles.std(dim=["model"])
    .to_pandas()
    .reset_index()
    .merge(cpm_pr_frame)
    .melt(
        id_vars="CPM pr",
        value_vars=merged_ds["source"].values,
        value_name="Model quantile std",
        var_name="source",
    )
)
ax = sns.lineplot(data=quantile_std, x="CPM pr", y="Model quantile std", hue="source")
# Raw string: "\s" is not a valid Python escape sequence, so the non-raw
# literal triggers a DeprecationWarning/SyntaxWarning. The label text itself
# is unchanged.
ax.set(ylabel=r"Model quantile $\sigma$")
# Shade the area under each source's line.
for source in merged_ds["source"].values:
    source_std = quantile_std[quantile_std["source"] == source]
    ax.fill_between(source_std["CPM pr"], source_std["Model quantile std"], alpha=0.5)

In [None]:
# For every source: emit an HTML heading, then draw the seasonal distribution
# figure of that source against the target precipitation.
for source in merged_ds["source"].values:
    heading = f"<h1>{source}</h1>"
    IPython.display.display_html(heading, raw=True)
    source_ds = merged_ds.sel(source=source)
    seasonal_distribution_figure(source_ds, target_pr, quantiles)

## Bias $\frac{\mu_{sample}-\mu_{CPM}}{\mu_{CPM}}$

### All

In [None]:
# For every source: emit an HTML heading, then plot the mean bias of that
# source relative to the target precipitation.
for source in merged_ds["source"].values:
    heading = f"<h3>{source}</h3>"
    IPython.display.display_html(heading, raw=True)
    source_ds = merged_ds.sel(source=source)
    plot_mean_bias(source_ds, target_pr)

### Seasonal

In [None]:
# Mean bias broken down by season: the merged dataset is grouped by season
# and the target is restricted to the same season before each comparison.
for season, seasonal_ds in merged_ds.groupby("time.season"):
    season_heading = f"<h3>{season}</h3>"
    IPython.display.display_html(season_heading, raw=True)

    # Boolean mask over the target's time axis selecting this season only.
    in_season = target_pr["time.season"] == season
    seasonal_target_pr = target_pr.sel(time=in_season)

    for source in merged_ds["source"].values:
        source_heading = f"<h4>{source}</h4>"
        IPython.display.display_html(source_heading, raw=True)
        seasonal_source_ds = seasonal_ds.sel(source=source)
        plot_mean_bias(seasonal_source_ds, seasonal_target_pr)

## Standard deviation $\sigma_{sample}$/$\sigma_{CPM}$

### All

In [None]:
# For every source: emit an HTML heading, then plot the standard-deviation
# bias of that source relative to the target precipitation.
for source in merged_ds["source"].values:
    heading = f"<h3>{source}</h3>"
    IPython.display.display_html(heading, raw=True)
    source_ds = merged_ds.sel(source=source)
    plot_std_bias(source_ds, target_pr)

### Seasonal

In [None]:
# Standard-deviation bias broken down by season: the merged dataset is
# grouped by season and the target is restricted to the same season before
# each comparison.
for season, seasonal_ds in merged_ds.groupby("time.season"):
    season_heading = f"<h3>{season}</h3>"
    IPython.display.display_html(season_heading, raw=True)

    # Boolean mask over the target's time axis selecting this season only.
    in_season = target_pr["time.season"] == season
    seasonal_target_pr = target_pr.sel(time=in_season)

    for source in merged_ds["source"].values:
        source_heading = f"<h4>{source}</h4>"
        IPython.display.display_html(source_heading, raw=True)
        seasonal_source_ds = seasonal_ds.sel(source=source)
        plot_std_bias(seasonal_source_ds, seasonal_target_pr)

## PSD

In [None]:
def _load_daily_pr(dataset_name, variable):
    """Open a derived dataset's ``val.nc`` and return one variable scaled to per-day.

    The file is read from
    ``$MOOSE_DERIVED_DATA/nc-datasets/<dataset_name>/val.nc``. The selected
    variable is multiplied by 3600 * 24 and its ``units`` attribute is set to
    ``"mm day-1"``.
    NOTE(review): the scaling implies the stored values are per-second rates
    -- confirm against the dataset metadata.

    Args:
        dataset_name: Name of the dataset directory under ``nc-datasets``.
        variable: Name of the variable to extract from the opened dataset.

    Returns:
        The scaled variable as an xarray DataArray.

    Raises:
        KeyError: If the ``MOOSE_DERIVED_DATA`` environment variable is unset.
    """
    # os.environ[...] rather than os.getenv(...): an unset variable fails
    # here with a KeyError naming it, instead of os.path.join raising an
    # opaque TypeError about NoneType.
    derived_data_dir = os.environ["MOOSE_DERIVED_DATA"]
    path = os.path.join(derived_data_dir, "nc-datasets", dataset_name, "val.nc")
    return (xr.open_dataset(path)[variable] * 3600 * 24).assign_attrs(
        {"units": "mm day-1"}
    )


# Simulation fields compared against the model samples in the PSD plots.
gcm_lr_lin_pr = _load_daily_pr("bham_60km-4x_linpr_random", "linpr")

cpm_hr_pr = _load_daily_pr("bham_gcmx-4x_linpr_random", "target_pr")

In [None]:
# Simulation fields shown alongside the model samples in every PSD plot.
simulation_data = {"CPM pr": cpm_hr_pr, "GCM pr": gcm_lr_lin_pr}

for source in merged_ds["source"].values:
    IPython.display.display_html(f"<h1>{source}</h1>", raw=True)

    # Predicted precipitation of each model for this source, keyed by a
    # "<model> Sample" label.
    ml_data = {}
    for model in merged_ds["model"].values:
        model_ds = merged_ds.sel(source=source, model=model)
        ml_data[f"{model} Sample"] = model_ds["pred_pr"]

    # Model samples first, simulation fields last (the latter win on any
    # key clash, as with the dict union operator).
    plot_psd({**ml_data, **simulation_data})