# PslTV Bham-64 8.8km 2 ensemble members datasets

In [None]:
%reload_ext autoreload

%autoreload 2

%reload_ext dotenv
%dotenv

from collections import defaultdict
import os

import IPython
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr

from mlde_analysis import plot_map, create_map_fig, qq_plot
from mlde_utils import cp_model_rotated_pole, dataset_split_path

In [None]:
# Ensemble members selected from each dataset when it is opened.
ENSEMBLE_MEMBERS = ["01", "04"]

# Dataset name to load for each source label.
DATASET_NAMES = {
    "CPM": "bham_gcmx-4x_12em_psl-temp4th-vort4th_eqvt_random-season",
    # "GCM": "bham_60km-4x_12em_psl-temp4th-vort4th_eqvt_random-season"
}

# Which split of each dataset to open.
SPLIT = "train"

VAR_TYPES = ["vorticity", "spechum", "temp"]
THETAS = [250, 500, 700, 850]

# Names of the variables to analyse: vorticity and temp at every level in
# THETAS, followed by "psl". Only vorticity and temp are expanded here,
# not every entry of VAR_TYPES.
VARIABLES = [
    f"{var_type}{theta}"
    for var_type in ("vorticity", "temp")
    for theta in THETAS
]
VARIABLES.append("psl")

# How many timestamps to draw for the example plots.
N_EXAMPLES = 5

In [None]:
# Open the configured split of every dataset in DATASET_NAMES, keeping only
# the requested ensemble members. The result is keyed by source label.
datasets = {
    source: xr.open_dataset(dataset_split_path(dataset_name, SPLIT)).sel(
        ensemble_member=ENSEMBLE_MEMBERS
    )
    for source, dataset_name in DATASET_NAMES.items()
}
datasets

In [None]:
# Draw N_EXAMPLES distinct time values (sampling without replacement) from
# the CPM dataset. No seed is set in this cell, so the draw depends on the
# state of NumPy's global random generator.
example_timestamps = np.random.choice(
    datasets["CPM"]["time"],
    size=N_EXAMPLES,
    replace=False,
)

In [None]:
def plot_examples(da, timestamps, **kwargs):
    """Draw a figure of example maps for each ensemble member.

    ``da`` is grouped by ``ensemble_member``. For every group a new figure is
    created with a single row of panels, one per entry of ``timestamps``, and
    the member's data at that timestamp is drawn into the panel with a
    colorbar and the ensemble member label as its title.

    Args:
        da: array that can be grouped by ``ensemble_member`` and selected
            along ``time``.
        timestamps: time values to plot; one panel is drawn for each.
        **kwargs: forwarded unchanged to ``plot_map``.
    """
    # The panel layout is the same for every member, so build it once.
    panel_names = [f"Example {i}" for i in range(len(timestamps))]

    for em, em_da in da.groupby("ensemble_member"):
        _, axes = create_map_fig([panel_names])
        for panel_name, timestamp in zip(panel_names, timestamps):
            plot_map(
                em_da.sel(time=timestamp),
                ax=axes[panel_name],
                add_colorbar=True,
                title=em,
                **kwargs,
            )

def plot_means(da):
    """Plot the mean of ``da`` over ``time``, one panel per ensemble member.

    Panels use the ``cp_model_rotated_pole`` projection and have coastlines
    drawn on them.
    """
    time_mean = da.mean(dim=["time"])
    facets = time_mean.plot(
        col="ensemble_member",
        subplot_kws={"projection": cp_model_rotated_pole},
    )
    for panel in facets.axs.flat:
        panel.coastlines()

def plot_std(da):
    """Plot the standard deviation of ``da`` over ``time``, one panel per ensemble member.

    Panels use the ``cp_model_rotated_pole`` projection, the colour scale
    starts at 0, and coastlines are drawn on every panel.
    """
    time_std = da.std(dim=["time"])
    facets = time_std.plot(
        col="ensemble_member",
        subplot_kws={"projection": cp_model_rotated_pole},
        vmin=0,
    )
    for panel in facets.axs.flat:
        panel.coastlines()
        
def plot_histogram(da, **kwargs):
    """Overlay a density histogram of ``da`` for each ensemble member.

    A single wide axes is created and every ensemble member's values are
    drawn onto it as a semi-transparent, density-normalised histogram. All
    members are binned over the same value range (the minimum and maximum of
    the whole of ``da``), so the histograms share bin edges.

    Args:
        da: array that can be grouped by ``ensemble_member``.
        **kwargs: forwarded to the ``hist`` plotting call (e.g. ``log=True``).
            ``bins`` may be supplied here to override the default of 100.
    """
    # Pull ``bins`` out of kwargs so it is passed to ``hist`` exactly once;
    # forwarding it inside **kwargs as well would raise a duplicate-keyword
    # TypeError.
    bins = kwargs.pop("bins", 100)

    _, axes = plt.subplot_mosaic([["Density"]], figsize=(16.5, 5.5), constrained_layout=True)
    ax = axes["Density"]

    xrange = (da.min().values, da.max().values)

    for em, em_da in da.groupby("ensemble_member"):
        # Draw explicitly on the axes created above rather than relying on
        # pyplot's notion of the current axes.
        em_da.plot.hist(ax=ax, label=em, density=True, alpha=0.25, bins=bins, range=xrange, **kwargs)

    # Set the title after plotting so it replaces any title set by the
    # histogram calls.
    ax.set_title("Frequency density")
    ax.legend()

def plot_quantiles(da, quantiles, **kwargs):
    """Plot quantiles of ``da`` with one line per ensemble member.

    The quantiles are computed over the ``time``, ``grid_longitude`` and
    ``grid_latitude`` dimensions together; ``kwargs`` are forwarded to the
    plotting call.
    """
    pooled_dims = ["time", "grid_longitude", "grid_latitude"]
    quantile_values = da.quantile(quantiles, dim=pooled_dims)
    quantile_values.plot(hue="ensemble_member", **kwargs)

## Target PR

In [None]:
# Summary plots for the target precipitation variable of the CPM dataset.
variable = "target_pr"

# Scale by 3600*24 (seconds per day) -- presumably converting a per-second
# rate to a per-day amount; TODO confirm the stored units.
da = datasets["CPM"][variable]*3600*24

IPython.display.display_html(f"<h3>{variable}</h3>", raw=True)
IPython.display.display_html("<h4>Upper quantiles</h4>", raw=True)
# Nine evenly spaced quantiles in each of three ranges approaching 1:
# 0.99991-0.99999, 0.999991-0.999999 and 0.9999991-0.9999999.
quantiles = np.concatenate([np.linspace((1-10**(i+1))+(10**i), (1-10**i), 9) for i in range(-5, -8, -1)])
plot_quantiles(da, quantiles)
plt.yscale('log')
plt.xscale('log')
plt.show()

IPython.display.display_html("<h4>Histogram</h4>", raw=True)
# log=True is forwarded by plot_histogram to the underlying hist call.
plot_histogram(da, log=True)
plt.show()

IPython.display.display_html("<h4>Mean</h4>", raw=True)
plot_means(da)
plt.show()

IPython.display.display_html("<h4>Std Dev</h4>", raw=True)
plot_std(da)
plt.show()

IPython.display.display_html("<h4>Examples</h4>", raw=True)
plot_examples(da, example_timestamps, style="precip")
plt.show()

## Distribution

In [None]:
# For every variable in VARIABLES, show its per-member histogram followed by
# maps of its mean and standard deviation over time (CPM dataset only).
for variable in VARIABLES:
    IPython.display.display_html(f"<h3>{variable}</h3>", raw=True)
    da = datasets["CPM"][variable]    
    # The histogram has no sub-heading of its own; it is shown directly
    # under the variable heading.
    plot_histogram(da)
    plt.show()

    IPython.display.display_html("<h4>Mean</h4>", raw=True)
    plot_means(da)
    plt.show()

    IPython.display.display_html("<h4>Std Dev</h4>", raw=True)
    plot_std(da)
    plt.show()

## Examples

In [None]:
# For every variable in VARIABLES, plot the CPM data at the same set of
# example timestamps, one figure per ensemble member.
for variable in VARIABLES:
    IPython.display.display_html(f"<h3>{variable}</h3>", raw=True)

    da = datasets["CPM"][variable]

    # vmax is the maximum over the whole of this variable's array (all loaded
    # members and times); style, center and vmax are forwarded to plot_map.
    plot_examples(da, example_timestamps, style=None, center=0, vmax=np.amax(da))
    plt.show()