# Evaluation of temp, vorticity 60km -> 2.2km-4x over Birmingham

Compare:

* original: input xfm - standardize with global mean and variance, unit range (based on active set training split max), 2x - 1; target xfm - sqrt, unit range (training domain training split max), 2x - 1 (this one is implemented on forward pass in network, uninverted); step-based training scheme

with

* Inputs are always standardized; the target is always sqrt-transformed
* Then try different versions of applying unit range to the inputs and target.

NO PIXELMMS

## Diff model

8-channel location-specific parameters

Inputs: 5 levels of spechum, temp and vorticity

Target domain and resolution: 64x64 2.2km-4x England and Wales

Input resolution: 60km/gcmx

Input transforms are fitted on the dataset in use (i.e. separate GCM and CPM versions), while the target transform is fitted only at training time on the CPM dataset

In [None]:
%reload_ext autoreload

%autoreload 2

import math
import os

import iris
import iris.analysis.cartography
import IPython
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

from ml_downscaling_emulator.helpers import plot_over_ts
from ml_downscaling_emulator.utils import cp_model_rotated_pole, plot_grid, prep_eval_data, show_samples, distribution_figure, plot_mean_bias, plot_std, plot_psd

In [None]:
# Dataset split to evaluate; passed to prep_eval_data for each group of runs below.
split = "val"

In [None]:
# Runs that use specific humidity, temperature and vorticity (STV) inputs.
# `datasets` names the dataset to use for each input source label.
datasets = dict(
    CPM="bham_gcmx-4x_spechum-temp-vort_random",
    GCM="bham_60km-4x_spechum-temp-vort_random",
)

# One 3-tuple per run; the last element is the label used for the run.
# NOTE(review): the first two elements look like (run path, checkpoint) —
# confirm against prep_eval_data.
runs = [
    # ("bham-4x_STV_random-Iv1Tv1", "epoch-100", "STV v1;v1"),
    (
        "score-sde/subvpsde/xarray_cncsnpp_continuous/bham-4x_STV_random-ls8-IPTS",
        "checkpoint-20",
        "STV original (Stan,UR,2x-1;Sqrt,UR)",
    ),
    (
        "score-sde/subvpsde/xarray_cncsnpp_continuous/bham-4x_STV_random-IStanTSqrt",
        "epoch-100",
        "STV Stan;Sqrt",
    ),
]

stv_ds = prep_eval_data(datasets, runs, split)

In [None]:
# Runs that use only temperature and vorticity (TV) inputs.
# `datasets` names the dataset to use for each input source label.
datasets = dict(
    CPM="bham_gcmx-4x_temp-vort_random",
    GCM="bham_60km-4x_temp-vort_random",
)

# One 3-tuple per run; the last element is the label used for the run.
runs = [
    (
        "score-sde/subvpsde/xarray_cncsnpp_continuous/bham-4x_TV_random-IstanTsqrt",
        "epoch-100",
        "TV Stan;Sqrt",
    ),
]

tv_ds = prep_eval_data(datasets, runs, split)

In [None]:
# datasets = {
#     "CPM": "bham_gcmx-4x_tempgrad-vort_random",
#     "GCM": "bham_60km-4x_tempgrad-vort_random",
# }

# runs = [
#     ("score-sde/subvpsde/xarray_cncsnpp_continuous/bham-4x_TgV_random-IstanTsqrt", "epoch-100", "TgV Stan;Sqrt"),
# ]
# tgv_ds = prep_eval_data(datasets, runs, split)

In [None]:
# Combine the per-input-set evaluation datasets into one dataset so every run
# can be compared side by side. The TgV dataset is currently left out.
merged_ds = xr.merge(
    [
        stv_ds,
        tv_ds,
        # tgv_ds,
    ]
)
# Bare expression so the notebook displays the merged dataset.
merged_ds

In [None]:
# CPM-sourced target precipitation averaged over both grid dimensions.
# Despite the name, this is a spatial mean rather than a sum.
total_target_pr = (
    merged_ds["target_pr"]
    .sel(source="CPM")
    .mean(dim=["grid_longitude", "grid_latitude"])
)

## Samples

In [None]:
# For every source and season, show samples at a few representative timestamps
# chosen by ranking the target precipitation field.
for source, sourced_ds in merged_ds.groupby("source"):
    IPython.display.display_html(f"<h1>{source}</h1>", raw=True)
    for season, seasonal_ds in sourced_ds.groupby("time.season"):
        IPython.display.display_html(f"<h2>{season}</h2>", raw=True)

        # Spatial standard deviation and spatial mean of the target, per timestamp.
        # (Alternative: normalise std by
        #  merged_ds.sel(source="CPM")["target_pr"].mean(dim=["grid_longitude", "grid_latitude"]).)
        target_pr = seasonal_ds["target_pr"]
        std = target_pr.std(dim=["grid_longitude", "grid_latitude"])
        mean = target_pr.mean(dim=["grid_longitude", "grid_latitude"])

        # Timestamps ordered from the largest statistic to the smallest.
        std_sorted_time = std.sortby(-std)["time"].values
        mean_sorted_time = mean.sortby(-mean)["time"].values

        # Position one fifth of the way down the wettest-first ordering.
        fifth_position = math.ceil(len(mean_sorted_time) * 0.20)

        timestamp_chunks = {
            # "very wet": mean_sorted_time[20],
            # 21st most spatially variable timestamp.
            "very varied": std_sorted_time[20],
            "quiet wet": mean_sorted_time[fifth_position],
            # "quiet varied": std_sorted_time[math.ceil(len(std_sorted_time)*0.20):math.ceil(len(std_sorted_time)*0.20)+1],
            # 20th driest timestamp (counted from the dry end).
            "very dry": mean_sorted_time[-20],
        }

        for desc, timestamps in timestamp_chunks.items():
            IPython.display.display_html(f"<h3>{desc}</h3>", raw=True)
            # show_samples is given a list, so wrap the single timestamp.
            show_samples(seasonal_ds, [timestamps])

## Frequency distribution

### Pixel

In [None]:
# quantiles = np.concatenate([np.linspace(0.1,0.9,9), np.linspace(0.91,0.99,9), np.linspace(0.991,0.999,9), 0.9999999])
# Build 7 bands of 9 evenly spaced quantiles, each band ten times closer to 1
# than the previous one: 0.1-0.9, 0.91-0.99, 0.991-0.999, ..., ending at 0.9999999.
quantile_bands = []
for exponent in range(-1, -8, -1):
    band_start = (1 - 10 ** (exponent + 1)) + (10 ** exponent)
    band_end = 1 - 10 ** exponent
    quantile_bands.append(np.linspace(band_start, band_end, 9))
quantiles = np.concatenate(quantile_bands)

# Per-pixel distribution figure: the whole merged dataset is passed without
# any spatial aggregation, together with the quantiles defined above.
distribution_figure(
    merged_ds,
    quantiles,
    "Distribution of pixel values",
)

### Mean over space

In [None]:
# Quantiles for the spatial-mean distribution: steps of 0.1 from 0.1 to 0.8,
# steps of 0.01 from 0.9 to 0.99, then steps of 0.001 from 0.991 to 0.999.
lower_quantiles = np.linspace(0.1, 0.8, 8)
upper_quantiles = np.linspace(0.9, 0.99, 10)
tail_quantiles = np.linspace(0.991, 0.999, 9)
quantiles = np.concatenate([lower_quantiles, upper_quantiles, tail_quantiles])

# Average the target and predicted precipitation over both grid dimensions
# before comparing their distributions.
distribution_figure(
    merged_ds[["target_pr", "pred_pr"]].mean(dim=["grid_longitude", "grid_latitude"]),
    quantiles,
    "Distribution of mean precip over space",
)

## Bias

In [None]:
# Delegates to plot_mean_bias (imported from ml_downscaling_emulator.utils),
# passing the full merged dataset.
plot_mean_bias(merged_ds)

## Standard deviation

In [None]:
# Delegates to plot_std (imported from ml_downscaling_emulator.utils),
# passing the full merged dataset.
plot_std(merged_ds)

## PSD

In [None]:
# One PSD plot per input source: first the samples generated from CPM inputs,
# then the samples generated from GCM inputs. Each plot also includes both
# simulation targets (CPM and GCM) for reference.
for psd_source in ("CPM", "GCM"):
    simulation_data = {
        "CPM Target": merged_ds.sel(source="CPM")["target_pr"],
        "GCM Target": merged_ds.sel(source="GCM")["target_pr"],
    }
    # One entry per model, keyed "<source> <model> Sample".
    ml_data = {
        f"{psd_source} {model} Sample": merged_ds.sel(source=psd_source, model=model)["pred_pr"]
        for model in merged_ds["model"].values
    }
    # Dict union keeps the ML samples first, followed by the simulation targets.
    plot_psd(ml_data | simulation_data)

## Correlation