# Evaluation: spechum, temp and vorticity, GCMx -> 2.2km-4x over Birmingham

Compare a model using spechum, temp and vort850 with one using just vort850.

## Diff model

No loc-spec params

64x64 pixel target

With stored transforms, so that CPM-based and GCM-based samples use the same set of transforms, computed during training from the CPM training set.


Input transform: standardize (using the train set mean and std), then divide by the standardized train set max.

Target transform: take the square root, then divide by the square-rooted train set max.

Sample inverse transform: multiply by the (square-rooted) train set max, clip negative values to 0, then square.

In [None]:
%reload_ext autoreload

%autoreload 2

import math
import os

import iris
import iris.analysis.cartography
import IPython
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

from ml_downscaling_emulator import UKCPDatasetMetadata
from ml_downscaling_emulator.helpers import plot_over_ts
from ml_downscaling_emulator.utils import cp_model_rotated_pole, platecarree, plot_grid, prep_eval_data, show_samples, distribution_figure, plot_mean_bias, plot_std, plot_psd

In [None]:
# Evaluation split to load for every run in this notebook.
split = "val"

# Dataset to evaluate against for each input source, keyed by source label
# (vorticity850-only inputs).
datasets = {
    "CPM": "2.2km-coarsened-gcm-2.2km-coarsened-4x_birmingham_vorticity850_random",
    "GCM": "60km-2.2km-coarsened-4x_birmingham_vorticity850_random",
}

# Each entry is (run name, 20, plot label). The integer is passed straight to
# prep_eval_data — presumably a checkpoint/epoch id, TODO confirm.
# Commented-out entries are kept so they can be re-enabled quickly.
# NOTE(review): the notebook header says "No loc-spec params" but the active run
# is labelled "LS8" — confirm which configuration this comparison should use.
runs = [
    # ("bham-4x_V850_random-fixed-gcmx-vort-grid", 20, "V850 no-LS"),
    # ("bham-4x_V850_random-shared-xfms", 20, "V850 no-LS ST"),
    ("bham-4x_V850_random-ls8", 20, "V850 LS8"),
]
vort_ds = prep_eval_data(datasets, runs, split)

In [None]:
# Datasets for the spechum + temp + vorticity850 model, keyed by input source.
datasets = dict(
    CPM="bham_gcmx-4x_spechum-temp-vorticity850_random",
    GCM="bham_60km-4x_spechum-temp-vorticity850_random",
)

# (run name, 20, plot label) for every run to evaluate. Disabled runs are kept
# below the active ones for easy re-enabling.
runs = [
    ("bham-4x_STV850_random-ls8-shared-xfms", 20, "STV850 LS8 ST"),
    ("bham-4x_STV850_random-ls8", 20, "STV850 LS8"),
    # ("bham-4x_STV850_random-shared-xfms", 20, "STV no-LS ST"),
    # ("bham-4x_STV850_random", 20, "STV no-LS"),
]
spechum_temp_vort_ds = prep_eval_data(datasets, runs, split)

In [None]:
# Combine both models' evaluation data into one dataset. The "pressure" variable
# is dropped from each input first — NOTE(review): presumably because it differs
# between the two datasets and would conflict on merge; confirm.
merged_ds = xr.merge(
    [ds.drop_vars("pressure") for ds in (vort_ds, spechum_temp_vort_ds)]
)

In [None]:
# CPM-source target precipitation averaged over the spatial grid
# (grid_longitude, grid_latitude), leaving the remaining dimensions such as time.
total_target_pr = (
    merged_ds["target_pr"]
    .sel(source="CPM")
    .mean(dim=["grid_longitude", "grid_latitude"])
)

## Samples

In [None]:
# Time values ordered by ascending spatial-mean CPM target precipitation:
# the driest time step comes first and the wettest last.
sorted_time = (
    total_target_pr
    .sortby(total_target_pr)["time"]
    .values
)

In [None]:
# Start indices into the precipitation-sorted timeline. Each start yields a pair
# of consecutive time steps to plot side by side. Uncomment alternatives to
# inspect other parts of the distribution.
chunk_starts = [
    -20,  # the 20th- and 19th-wettest time steps
    # math.ceil(len(sorted_time)*0.9),
    math.ceil(len(sorted_time)*0.81),  # roughly the 81st-percentile time steps
    # math.ceil(len(sorted_time)*0.5),
    # 17
]
timestamp_chunks = [sorted_time[start:start + 2] for start in chunk_starts]
# timestamps = np.random.choice(sorted_time, size=2*3, replace=False)

# Render a heading and then the sample plots for each chunk.
for chunk_idx, chunk in enumerate(timestamp_chunks):
    IPython.display.display_html(f"<h1>Timestep chunk {chunk_idx}</h1>", raw=True)
    show_samples(merged_ds, chunk)

## Frequency distribution

### Pixel

In [None]:
# Quantile levels use coarse steps through the bulk of the distribution, then
# progressively finer steps into the upper tail.
quantiles = np.concatenate(
    [
        np.linspace(0.1, 0.8, 8),      # 0.1, 0.2, ..., 0.8
        np.linspace(0.9, 0.99, 10),    # 0.90, 0.91, ..., 0.99
        np.linspace(0.991, 0.999, 9),  # 0.991, 0.992, ..., 0.999
    ]
)

# Compare the per-pixel value distributions in the merged dataset.
distribution_figure(merged_ds, quantiles, "Distribution of pixel values")

### Mean over space

In [None]:
# Quantile levels use coarse steps through the bulk of the distribution, then
# progressively finer steps into the upper tail.
quantiles = np.concatenate(
    [
        np.linspace(0.1, 0.8, 8),      # 0.1, 0.2, ..., 0.8
        np.linspace(0.9, 0.99, 10),    # 0.90, 0.91, ..., 0.99
        np.linspace(0.991, 0.999, 9),  # 0.991, 0.992, ..., 0.999
    ]
)

# Average target and predicted precipitation over the spatial grid, then
# compare the distributions of those spatial means.
distribution_figure(
    merged_ds[["target_pr", "pred_pr"]].mean(dim=["grid_longitude", "grid_latitude"]),
    quantiles,
    "Distribution of mean precip over space",
)

## Bias

In [None]:
# Mean-bias plots for the merged evaluation data; rendering is handled by the
# project helper plot_mean_bias.
plot_mean_bias(merged_ds)

## Standard deviation

In [None]:
# Standard-deviation plots for the merged evaluation data; rendering is handled
# by the project helper plot_std.
plot_std(merged_ds)

## Scatter

In [None]:
# Disabled: scatter of predicted vs target precipitation, with the target
# repeated once per sample_id, plus a green y = x reference line drawn up to the
# smaller of the two maxima.
# fig, axs = plt.subplots(1, 1, figsize=(20, 6))

# tr = min(merged_ds["pred_pr"].max(), merged_ds["target_pr"].max())


# ax = axs

# ax.scatter(x=merged_ds["pred_pr"], y=merged_ds["target_pr"].values[None, :].repeat(len(merged_ds["sample_id"]), 0), alpha=0.05)
# ax.plot([0, tr], [0, tr], linewidth=1, color='green')


In [None]:
# Disabled: 3x3 grid of scatter plots, one per sample_id 0-8, of predicted vs
# target precipitation, each with a green y = x reference line.
# NOTE(review): axs[i//3][i%3] hard-codes 3 columns; use sample_ids.shape[1]
# if the grid shape is ever changed.
# sample_ids = np.arange(9).reshape(3, 3)

# fig, axs = plt.subplots(sample_ids.shape[0], sample_ids.shape[1], figsize=(30, 12))

# tr = min(merged_ds["pred_pr"].max(), merged_ds["target_pr"].max())

# for i, sample_id in enumerate(sample_ids.flatten()):
#     ax = axs[i//3][i%3]

#     ax.scatter(x=merged_ds["pred_pr"].sel(sample_id=sample_id), y=merged_ds["target_pr"], alpha=0.1)
#     ax.plot([0, tr], [0, tr], linewidth=1, color='green')
#     ax.set_xlabel(f"Sample {sample_id} pr")
#     ax.set_ylabel(f"Target pr")
    
# plt.show()

In [None]:
# Disabled: scatter of spatial-mean predicted vs target precipitation, with a
# green y = x reference line.
# NOTE(review): total_pred_pr is not defined anywhere in this notebook. Before
# re-enabling this cell, compute it from merged_ds["pred_pr"] in the same way
# total_target_pr is computed from merged_ds["target_pr"].
# fig, axs = plt.subplots(1, 1, figsize=(20, 6))

# tr = min(total_pred_pr.max(), total_target_pr.max())

# ax = axs

# ax.scatter(x=total_pred_pr, y=total_target_pr.values[None, :].repeat(len(total_pred_pr["sample_id"]), 0), alpha=0.25)
# ax.plot([0, tr], [0, tr], linewidth=1, color='green')

# plt.show()

## PSD

## Correlation