# TgV Bham-64 8.8km datasets

In [None]:
%reload_ext autoreload

%autoreload 2

from collections import defaultdict
import os

import IPython
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr

from ml_downscaling_emulator.plotting import plot_map, create_map_fig, qq_plot

In [None]:
# Derived dataset name for each source label; the labels become the values of
# the "source" coordinate of the combined dataset below.
datasets = dict(
    cpm="bham_gcmx-4x_tempgrad-vort_random",
    gcm="bham_60km-4x_tempgrad-vort_random",
)

split = "train"

# Open the chosen split of every dataset under $MOOSE_DERIVED_DATA/nc-datasets
# and concatenate them along a "source" dimension labelled with the dict keys.
nc_root = os.path.join(os.getenv("MOOSE_DERIVED_DATA"), "nc-datasets")
ds = xr.concat(
    [xr.open_dataset(os.path.join(nc_root, name, f"{split}.nc")) for name in datasets.values()],
    pd.Index(list(datasets), name="source"),
)
ds

In [None]:
VAR_TYPES = ["vorticity"]
THETAS = [250, 500, 700, 850, 925]

# A random sample of distinct timestamps from the combined dataset.
n_examples = 5
random_timestamps = np.random.choice(ds["time"], size=n_examples, replace=False)

# Standard deviation and mean of the CPM target_pr over the grid for every
# timestep; sorting by the negated statistic orders the timestamps from the
# largest value to the smallest.
cpm_pr = ds["target_pr"].sel(source="cpm")
grid_dims = ["grid_longitude", "grid_latitude"]

std = cpm_pr.std(dim=grid_dims)
mean = cpm_pr.mean(dim=grid_dims)
std_sorted_time = std.sortby(-std)["time"].values
mean_sorted_time = mean.sortby(-mean)["time"].values

In [None]:
# Example timestamps to plot: the first five entries of mean_sorted_time.
timestamps = mean_sorted_time[:5]

In [None]:
def summarize_distribution(da, mean_center=False):
    """Plot the CPM and GCM value distributions of ``da`` side by side.

    Displays a "Distribution" heading followed by one figure with two panels:
    overlaid density histograms of both sources, and a quantile-quantile plot
    of the GCM values against the CPM values.

    Args:
        da: Data array that can be selected with ``source="gcm"`` and
            ``source="cpm"``.
        mean_center: Pass the number ``0`` to compare quantiles across both
            tails (0.001 to 0.999). Any other value, including the default
            ``False``, compares quantiles from 0.1 upwards only.
    """
    gcm_da = da.sel(source="gcm")
    cpm_da = da.sel(source="cpm")

    IPython.display.display_html("<h2>Distribution</h2>", raw=True)
    _, axes = plt.subplot_mosaic([["Density", "QQ"]], figsize=(16.5, 5.5), constrained_layout=True)

    ax = axes["Density"]
    cpm_da.plot.hist(ax=ax, label="CPM", density=True, alpha=0.5, bins=100)
    gcm_da.plot.hist(ax=ax, label="GCM", density=True, alpha=1, bins=100, histtype="step", linewidth=2)
    ax.set_title("Frequency density")
    ax.legend()

    # ``False == 0`` is True in Python, so a bare ``mean_center == 0`` also
    # matched the default and left the upper-tail branch unreachable. Booleans
    # are excluded explicitly so only a numeric 0 selects the two-tailed set.
    zero_centred = not isinstance(mean_center, bool) and mean_center == 0
    if zero_centred:
        quantiles = np.concatenate([
            np.linspace(0.001, 0.009, 9),
            np.linspace(0.01, 0.09, 9),
            np.linspace(0.1, 0.9, 9),
            np.linspace(0.91, 0.99, 9),
            np.linspace(0.991, 0.999, 9),
        ])
    else:
        quantiles = np.concatenate([
            np.linspace(0.1, 0.8, 8),
            np.linspace(0.9, 0.99, 10),
            np.linspace(0.991, 0.999, 9),
        ])

    qq_plot(axes["QQ"], cpm_da, [("GCM vs CPM", gcm_da)], quantiles, title="GCM vs CPM quantiles", xlabel="CPM", ylabel="GCM")
    plt.show()

def summarize_variable(da, etimestamps=None, mean_center=False):
    """Display example maps, distributions and mean/std maps for one variable.

    Sections are shown in this order:

    1. "Examples" (only when ``etimestamps`` is non-empty): one CPM map and
       one GCM map per example timestamp, sharing a colour limit.
    2. "Distribution": delegated to ``summarize_distribution``.
    3. "Mean and standard deviation": maps of the time-mean of each source,
       their difference, and the difference divided by the CPM standard
       deviation; then the same layout for the standard deviations, with the
       difference divided by the CPM standard deviation.

    Args:
        da: Data array with a ``time`` dimension that can be selected with
            ``source="gcm"`` and ``source="cpm"``.
        etimestamps: Timestamps of the examples to plot. ``None`` or an empty
            sequence skips the examples section.
        mean_center: Forwarded as ``center`` to ``plot_map`` for the example
            and mean maps, and as ``mean_center`` to ``summarize_distribution``.
    """
    gcm_da = da.sel(source="gcm")
    cpm_da = da.sel(source="cpm")

    if etimestamps is None:
        etimestamps = []
    # len() rather than truthiness: etimestamps may be a NumPy array.
    n_shown = len(etimestamps)

    if n_shown > 0:
        IPython.display.display_html("<h2>Examples</h2>", raw=True)
        # Size the figure and iterate from the argument itself instead of the
        # notebook globals, so the function honours whatever the caller passes.
        _, axes = create_map_fig([
            [f"CPM Example {i}" for i in range(n_shown)],
            [f"GCM Example {i}" for i in range(n_shown)],
        ])
        examples = da.sel(time=etimestamps)
        emin = examples.min()
        # Shared upper limit: the larger of the maximum and the negated minimum.
        emax = max(-emin, examples.max())
        for i, timestamp in enumerate(etimestamps):
            plot_map(cpm_da.sel(time=timestamp), ax=axes[f"CPM Example {i}"], add_colorbar=True, style=None, center=mean_center, vmax=emax)
            plot_map(gcm_da.sel(time=timestamp), ax=axes[f"GCM Example {i}"], add_colorbar=True, style=None, center=mean_center, vmax=emax)
        plt.show()

    summarize_distribution(da, mean_center=mean_center)

    IPython.display.display_html("<h2>Mean and standard deviation</h2>", raw=True)

    # Raw strings keep the LaTeX backslashes without invalid escape sequences;
    # the resulting titles are unchanged.
    mu_cpm = r"$\mu_{CPM}$"
    mu_gcm = r"$\mu_{GCM}$"
    mu_diff = r"$\mu_{GCM}$ - $\mu_{CPM}$"
    mu_diff_scaled = r"$\frac{\mu_{GCM} - \mu_{CPM}}{\sigma_{CPM}}$"

    data = {
        mu_cpm: cpm_da.mean(dim=["time"]),
        mu_gcm: gcm_da.mean(dim=["time"]),
    }
    data[mu_diff] = data[mu_gcm] - data[mu_cpm]
    data[mu_diff_scaled] = data[mu_diff] / cpm_da.std(dim=["time"])

    mean_min = min(data[mu_cpm].min(), data[mu_gcm].min())
    mean_max = max(-mean_min, data[mu_cpm].max(), data[mu_gcm].max())
    plotkwargs = defaultdict(dict, **{
        mu_cpm: dict(center=mean_center, vmax=mean_max),
        mu_gcm: dict(center=mean_center, vmax=mean_max),
        mu_diff: dict(center=0),
        mu_diff_scaled: dict(center=0),
    })

    _, axd = create_map_fig([data.keys()])
    for key, data_summary in data.items():
        plot_map(data_summary, axd[key], title=key, style=None, add_colorbar=True, **plotkwargs[key])
    plt.show()

    sigma_cpm = r"$\sigma_{CPM}$"
    sigma_gcm = r"$\sigma_{GCM}$"
    sigma_diff = r"$\sigma_{GCM}$ - $\sigma_{CPM}$"
    sigma_diff_scaled = r"$\frac{\sigma_{GCM} - \sigma_{CPM}}{\sigma_{CPM}}$"

    data = {
        sigma_cpm: cpm_da.std(dim=["time"]),
        sigma_gcm: gcm_da.std(dim=["time"]),
    }
    data[sigma_diff] = data[sigma_gcm] - data[sigma_cpm]
    data[sigma_diff_scaled] = data[sigma_diff] / data[sigma_cpm]

    # Both standard-deviation maps share one colour range.
    sigma_min = min(data[sigma_gcm].min(), data[sigma_cpm].min())
    sigma_max = max(data[sigma_gcm].max(), data[sigma_cpm].max())
    plotkwargs = defaultdict(dict, **{
        sigma_cpm: dict(vmin=sigma_min, vmax=sigma_max),
        sigma_gcm: dict(vmin=sigma_min, vmax=sigma_max),
        sigma_diff: dict(center=0),
        sigma_diff_scaled: dict(center=0),
    })

    _, axd = create_map_fig([data.keys()])
    for key, data_summary in data.items():
        plot_map(data_summary, axd[key], title=key, style=None, add_colorbar=True, **plotkwargs[key])
    plt.show()

def seasonal_summarize_variable(da, mean_center=False):
    """Display per-season comparisons of the GCM and CPM values of ``da``.

    Four figures are shown, each with one panel per season (DJF, MAM, JJA,
    SON), where seasons come from grouping ``da`` by ``time.season``:

    1. Map of the GCM minus CPM time-mean, divided by the CPM standard deviation.
    2. Map of the GCM minus CPM standard deviation.
    3. Overlaid density histograms of both sources.
    4. Quantile-quantile plot of GCM against CPM.

    Args:
        da: Data array with a ``time`` dimension that can be selected with
            ``source="gcm"`` and ``source="cpm"``.
        mean_center: Pass the number ``0`` to compare quantiles across both
            tails (0.001 to 0.999). Any other value, including the default
            ``False``, compares quantiles from 0.1 upwards only.
    """
    season_names = ["DJF", "MAM", "JJA", "SON"]

    # Raw strings keep the LaTeX backslashes without invalid escape sequences;
    # the emitted HTML is unchanged.
    IPython.display.display_html(r"<h2>Seasonal $\frac{\mu_{GCM} - \mu_{CPM}}{\sigma_{CPM}}$</h2>", raw=True)
    _, axd = create_map_fig([list(season_names)])
    for season, seasonal_da in da.groupby("time.season"):
        gcm_da = seasonal_da.sel(source="gcm")
        cpm_da = seasonal_da.sel(source="cpm")
        scaled_mean_diff = (gcm_da.mean(dim=["time"]) - cpm_da.mean(dim=["time"])) / cpm_da.std(dim=["time"])
        plot_map(scaled_mean_diff, axd[season], title=season, style=None, add_colorbar=True, center=0)
    plt.show()

    IPython.display.display_html(r"<h2>Seasonal $\sigma_{GCM}$ - $\sigma_{CPM}$</h2>", raw=True)
    _, axd = create_map_fig([list(season_names)])
    for season, seasonal_da in da.groupby("time.season"):
        gcm_da = seasonal_da.sel(source="gcm")
        cpm_da = seasonal_da.sel(source="cpm")
        std_diff = gcm_da.std(dim=["time"]) - cpm_da.std(dim=["time"])
        plot_map(std_diff, axd[season], title=season, style=None, add_colorbar=True, center=0)
    plt.show()

    # The heading was previously emitted without its closing tag.
    IPython.display.display_html("<h2>Seasonal distribution</h2>", raw=True)
    _, axd = plt.subplot_mosaic([list(season_names)], figsize=(22, 5.5), constrained_layout=True)
    for season, seasonal_da in da.groupby("time.season"):
        gcm_da = seasonal_da.sel(source="gcm")
        cpm_da = seasonal_da.sel(source="cpm")

        ax = axd[season]
        cpm_da.plot.hist(ax=ax, label="CPM", density=True, alpha=0.5, bins=100)
        gcm_da.plot.hist(ax=ax, label="GCM", density=True, alpha=1, bins=100, histtype="step", linewidth=2)
        ax.set_title(season)
        ax.legend()
    plt.show()

    # The quantile set does not depend on the season, so build it once.
    # ``False == 0`` is True in Python, so booleans are excluded explicitly;
    # only a numeric 0 selects the two-tailed set.
    zero_centred = not isinstance(mean_center, bool) and mean_center == 0
    if zero_centred:
        quantiles = np.concatenate([
            np.linspace(0.001, 0.009, 9),
            np.linspace(0.01, 0.09, 9),
            np.linspace(0.1, 0.9, 9),
            np.linspace(0.91, 0.99, 9),
            np.linspace(0.991, 0.999, 9),
        ])
    else:
        quantiles = np.concatenate([
            np.linspace(0.1, 0.8, 8),
            np.linspace(0.9, 0.99, 10),
            np.linspace(0.991, 0.999, 9),
        ])

    _, axd = plt.subplot_mosaic([list(season_names)], figsize=(22, 5.5), constrained_layout=True)
    for season, seasonal_da in da.groupby("time.season"):
        # Select this season's data inside the loop; reusing the variables left
        # over from the histogram loop would plot the same season in every panel.
        gcm_da = seasonal_da.sel(source="gcm")
        cpm_da = seasonal_da.sel(source="cpm")
        qq_plot(axd[season], cpm_da, [("GCM vs CPM", gcm_da)], quantiles, title=f"{season} GCM vs CPM quantiles", xlabel="CPM", ylabel="GCM")
    plt.show()
    

In [None]:
# Summarise each vorticity variable (one per entry in THETAS) with colour
# scales and quantiles centred on zero.
for variable in map("vorticity{}".format, THETAS):
    IPython.display.display_html(f"<h1>{variable}</h1>", raw=True)
    summarize_variable(ds[variable], timestamps, mean_center=0)
    # seasonal_summarize_variable(ds[variable], mean_center=0)

In [None]:
# Summarise both tempgrad variables using the default (uncentred) settings.
for variable in ("tempgrad700500", "tempgrad850700"):
    IPython.display.display_html(f"<h1>{variable}</h1>", raw=True)
    summarize_variable(ds[variable], timestamps)
    # seasonal_summarize_variable(ds[variable], mean_center=0)

In [None]:
# Distribution of the untransformed target_pr values for both sources.
IPython.display.display_html("<h1>precip</h1>", raw=True)
summarize_distribution(ds["target_pr"])

In [None]:
# Distribution of the square root of target_pr. The heading is a raw string so
# the LaTeX backslash is not an invalid escape sequence; its text is unchanged.
IPython.display.display_html(r"<h1>$\sqrt{precip}$</h1>", raw=True)
summarize_distribution(np.sqrt(ds["target_pr"]))

In [None]:
# Distribution of log(1 + target_pr). np.log1p computes the same quantity as
# np.log(1 + x) and matches the call used for the round-trip check in this
# notebook. The heading is a raw string so the LaTeX backslash is not an
# invalid escape sequence; its text is unchanged.
IPython.display.display_html(r"<h1>$\log(1+precip)$</h1>", raw=True)
summarize_distribution(np.log1p(ds["target_pr"]))

In [None]:
# Distribution of target_pr after applying log1p and then its inverse, expm1.
# The heading is a raw string so the LaTeX backslash is not an invalid escape
# sequence; its text is unchanged.
IPython.display.display_html(r"<h1>$e^{(\log(1+precip))}-1$</h1>", raw=True)
summarize_distribution(np.expm1(np.log1p(ds["target_pr"])))

In [None]:
# Distribution of target_pr raised to the power 1/3. The heading is a raw
# string so the LaTeX backslash is not an invalid escape sequence; its text is
# unchanged.
IPython.display.display_html(r"<h1>$\sqrt[3]{precip}$</h1>", raw=True)
summarize_distribution(np.power(ds["target_pr"], 1/3))

In [None]:
# Distribution of target_pr raised to the power 1/4. The heading is a raw
# string so the LaTeX backslash is not an invalid escape sequence; its text is
# unchanged.
IPython.display.display_html(r"<h1>$\sqrt[4]{precip}$</h1>", raw=True)
summarize_distribution(np.power(ds["target_pr"], 1/4))

## Correlation