# Effect of quantile mapping on GCM-driven samples

In [None]:
%reload_ext autoreload

%autoreload 2

%load_ext dotenv
%dotenv

import glob
import math
import os

import cartopy.crs as ccrs
import cftime
from cmethods import CMethods
import iris
import iris.analysis.cartography
import IPython
import matplotlib
import matplotlib.pyplot as plt
import metpy.plots.ctables
import numpy as np
import pandas as pd
import scipy
import seaborn as sns
import xarray as xr

In [None]:
# ---- evaluation settings ----
time_period = "present"
split = "val"
samples_per_run = 3

# ---- sample runs to compare, grouped by the source of the driving inputs ----
# Each entry locates one run's prediction files (model id, checkpoint, input transform, dataset)
# and gives the label used in the plots; "deterministic" runs are read from a single file.
data_configs = dict(
    CPM=[
        dict(
            fq_model_id="score-sde/subvpsde/xarray_cncsnpp_continuous/bham-4x_PslTV_random-season-IstanTsqrturrecen-shuffle-fix",
            checkpoint="epoch-100",
            input_xfm="stan",
            label="Diffusion",
            dataset=f"bham_gcmx-4x_psl-temp-vort_random-season-{time_period}",
            deterministic=False,
        ),
        dict(
            fq_model_id="id-linpr",
            checkpoint="epoch-0",
            input_xfm="",
            label="Coarsened CPM precip (interp)",
            deterministic=True,
            dataset=f"bham_gcmx-4x_linpr_random-season-{time_period}",
        ),
    ],
    GCM=[
        dict(
            fq_model_id="score-sde/subvpsde/xarray_cncsnpp_continuous/bham-4x_PslTV_random-season-IstanTsqrturrecen-shuffle-fix",
            checkpoint="epoch-100",
            input_xfm="stan",
            label="Diffusion",
            dataset=f"bham_60km-4x_psl-temp-vort_random-season-{time_period}",
            deterministic=False,
        ),
        dict(
            fq_model_id="score-sde/subvpsde/xarray_cncsnpp_continuous/bham-4x_PslTV_random-season-IstanTsqrturrecen-shuffle-fix",
            checkpoint="epoch-100",
            input_xfm="pixelmmsstan",
            label="Diffusion bc inputs",
            dataset=f"bham_60km-4x_psl-temp-vort_random-season-{time_period}",
            deterministic=False,
        ),
        dict(
            fq_model_id="id-linpr",
            checkpoint="epoch-0",
            input_xfm="",
            label="GCM precip (interp)",
            deterministic=True,
            dataset=f"bham_60km-4x_linpr_random-season-{time_period}",
        ),
    ],
)
highlighted_cpm_models = ["Diffusion"]
desc = """
Compare diffusion model (PslTV) with and without quantile mapping (and simple interpolation of LR precip as baseline)
"""
# Default datasets for comparisons such as the PSD, which need CPM-based hi-res precip
# and GCM-based lo-res (interpolated) precip respectively
simulation_pr_datasets = dict(
    GCM=f"bham_60km-4x_linpr_random-season-{time_period}",
    CPM=f"bham_gcmx-4x_linpr_random-season-{time_period}",
)
gcm_lr_lin_pr_dataset = f"bham_60km-4x_linpr_random-season-{time_period}"
cpm_hr_pr_dataset = f"bham_gcmx-4x_psl-temp-vort_random-season-{time_period}"

In [None]:
# Show the notebook description `desc` rendered as Markdown (it is the cell's final expression, so Jupyter displays it).
IPython.display.Markdown(desc)

## Load data

In [None]:
def si_to_mmday(ds, varname):
    """Return ``ds[varname]`` converted from kg m-2 s-1 (i.e. mm s-1) to mm day-1.

    The returned array carries a ``units`` attribute of ``"mm day-1"``.
    """
    seconds_per_hour = 3600
    hours_per_day = 24
    per_day = ds[varname] * seconds_per_hour * hours_per_day
    return per_day.assign_attrs({"units": "mm day-1"})


def open_samples_ds(
    run_name,
    checkpoint_id,
    dataset_name,
    input_xfm_key,
    split,
    num_samples,
    deterministic,
):
    """Open the prediction files of one sampling run.

    The files matched are ``predictions-*.nc`` under
    ``$DERIVED_DATA/workdirs/<run_name>/samples/<checkpoint_id>/<dataset_name>/<input_xfm_key>/<split>``.
    They are taken in sorted filename order.

    * Stochastic runs (``deterministic=False``): the first ``num_samples`` files are
      concatenated along a new ``sample_id`` dimension.
    * Deterministic runs: only the first file is opened and returned; there is no
      ``sample_id`` dimension.

    ``pred_pr`` is converted to mm day-1 before returning.

    Raises:
        RuntimeError: if no file matches the pattern, or if a stochastic run has fewer than
            ``num_samples`` files.
    """
    samples_filepath_pattern = os.path.join(
        os.getenv("DERIVED_DATA"),
        "workdirs",
        run_name,
        f"samples/{checkpoint_id}",
        dataset_name,
        input_xfm_key,
        split,
        "predictions-*.nc",
    )
    # glob returns files in arbitrary (filesystem) order; sort so the chosen subset of
    # samples, and hence every downstream result, is reproducible
    sample_filepaths = sorted(glob.glob(samples_filepath_pattern))
    if len(sample_filepaths) == 0:
        raise RuntimeError(f"{samples_filepath_pattern} has no sample files")

    if deterministic:
        # only the first file is used, so do not open (and leave open) the others
        ds = xr.open_dataset(sample_filepaths[0])
    else:
        # check the count before opening anything
        if len(sample_filepaths) < num_samples:
            raise RuntimeError(
                f"{samples_filepath_pattern} does not have {num_samples} sample files"
            )
        ds = xr.concat(
            [xr.open_dataset(path) for path in sample_filepaths[:num_samples]],
            dim="sample_id",
        )

    ds["pred_pr"] = si_to_mmday(ds, "pred_pr")

    return ds


def open_split_ds(dataset_name, split):
    """Open ``$DERIVED_DATA/moose/nc-datasets/<dataset_name>/<split>.nc``.

    The dataset is returned with its ``target_pr`` variable converted to mm day-1.
    """
    split_path = os.path.join(
        os.getenv("DERIVED_DATA"), "moose", "nc-datasets", dataset_name, f"{split}.nc"
    )
    split_ds = xr.open_dataset(split_path)
    split_ds["target_pr"] = si_to_mmday(split_ds, "target_pr")
    return split_ds


def open_merged_split_datasets(sample_runs, split):
    """Open the ``split`` file of every distinct dataset used by ``sample_runs`` and merge them.

    Each dataset is opened once, in order of first appearance in ``sample_runs``. The merge
    uses ``compat="override"``, so a variable present in several datasets is taken from the
    first one. The order must therefore be stable. Iterating a ``set`` of names is not stable
    across interpreter runs because of string hash randomisation.
    """
    # dict.fromkeys de-duplicates while keeping first-appearance order
    dataset_names = dict.fromkeys(sample_run["dataset"] for sample_run in sample_runs)
    return xr.merge(
        [open_split_ds(dataset_name, split) for dataset_name in dataset_names],
        compat="override",
    )


def open_concat_sample_datasets(sample_runs, split, samples_per_run):
    """Load ``pred_pr`` for every sample run and stack them along a new ``model`` dimension.

    Each run is loaded with ``open_samples_ds``. The ``model`` coordinate holds the runs'
    ``label`` values, in the order of ``sample_runs``.
    """
    run_pred_prs = []
    for run in sample_runs:
        run_ds = open_samples_ds(
            run_name=run["fq_model_id"],
            checkpoint_id=run["checkpoint"],
            dataset_name=run["dataset"],
            input_xfm_key=run["input_xfm"],
            split=split,
            num_samples=samples_per_run,
            deterministic=run["deterministic"],
        )
        run_pred_prs.append(run_ds["pred_pr"])

    model_index = pd.Index([run["label"] for run in sample_runs], name="model")
    return xr.concat(run_pred_prs, model_index)


def prep_eval_data(sample_runs, split, samples_per_run=3):
    """Return the runs' ``pred_pr`` samples merged with the ground-truth datasets they used.

    The merge keeps only coordinates common to both (``join="inner"``). Conflicting variables
    are taken from the samples (``compat="override"``).
    """
    predictions = open_concat_sample_datasets(sample_runs, split, samples_per_run)
    ground_truth = open_merged_split_datasets(sample_runs, split)
    return xr.merge([predictions, ground_truth], join="inner", compat="override")


In [None]:
# Samples of every configured run merged with the matching ground truth, keyed by driving source
merged_ds = {
    driving_source: prep_eval_data(run_configs, split, samples_per_run=samples_per_run)
    for driving_source, run_configs in data_configs.items()
}
merged_ds

In [None]:
# CPM "truth" precip for the time period: the train split is used to fit the quantile mapping,
# the evaluation split to assess it
cpm_sim_pr_train = open_split_ds(cpm_hr_pr_dataset, "train")["target_pr"]
cpm_sim_pr_split = open_split_ds(cpm_hr_pr_dataset, split)["target_pr"]


# GCM-driven diffusion runs whose samples are quantile mapped
configs_for_qm = [
    run_config
    for run_config in data_configs["GCM"]
    if run_config["label"] in ("Diffusion", "Diffusion bc inputs")
]
# a single train-split sample per run is used to fit the mapping
gcm_ml_pred_pr_train = open_concat_sample_datasets(configs_for_qm, "train", 1).isel(sample_id=0)

# evaluation-split samples of those runs: the data the quantile mapping is applied to
gcm_ml_pred_pr_split = merged_ds["GCM"]["pred_pr"].sel(
    model=[run_config["label"] for run_config in configs_for_qm]
)

## Quantile map eval data

In [None]:
def qm_dom_aware(obs, simh, simp, n_quantiles=250, kind="+"):
    """Empirical quantile mapping of ``simp`` onto the distribution of ``obs``.

    This is the "domain-aware" variant: the CDFs of ``obs`` and ``simh`` are each estimated on
    ``n_quantiles`` equal-width histogram bins spanning that series' own range, rather than a
    shared global range. Each value of ``simp`` is converted to a non-exceedance probability
    with the ``simh`` CDF, then mapped back to a value with the inverse ``obs`` CDF.

    Args:
        obs: reference ("observed") values for the calibration period, 1-D array-like.
        simh: modelled values for the calibration period, 1-D array-like. It need not have
            the same length as ``obs``.
        simp: modelled values to adjust, array-like.
        n_quantiles: number of histogram bins used to estimate each CDF.
        kind: accepted for signature compatibility with cmethods-style functions; ignored.

    Returns:
        numpy array shaped like ``simp`` with the adjusted values. Values of ``simp`` outside
        the range of ``simh`` are clamped to the extremes of ``obs``.
    """
    obs, simh, simp = np.asarray(obs), np.asarray(simh), np.asarray(simp)

    def get_bins_and_cdf(x):
        # linspace (rather than arange with a float step) gives exactly n_quantiles bins whose
        # last edge is exactly max(x), and it does not fail on a constant series (zero width)
        xbins = np.linspace(np.amin(x), np.amax(x), n_quantiles + 1)
        counts, _ = np.histogram(x, xbins)
        cdf = np.insert(np.cumsum(counts), 0, 0.0)
        # normalise counts to probabilities so that obs and simh may differ in length
        return xbins, cdf / cdf[-1]

    xbins_obs, cdf_obs = get_bins_and_cdf(obs)
    xbins_simh, cdf_simh = get_bins_and_cdf(simh)

    # probability of each simp value under the simh distribution ...
    epsilon = np.interp(simp, xbins_simh, cdf_simh)
    # ... mapped to the obs value with the same probability
    return np.interp(epsilon, cdf_obs, xbins_obs)

def qm_vec(
    sim_train_da,
    ml_train_da,
    ml_eval_da,
    n_quantiles=250,
    qm_1d_comp_func=qm_dom_aware,
    kind="+",
):
    """Apply a 1-D quantile-mapping function independently at every grid point.

    For each point (every combination of non-``time`` dims), the ``time`` series of
    ``sim_train_da`` (target distribution), ``ml_train_da`` (model distribution over the
    calibration period) and ``ml_eval_da`` (values to adjust) are passed to
    ``qm_1d_comp_func(obs, simh, simp, n_quantiles=..., kind=...)``. The three ``time`` axes
    may differ in length.

    Returns:
        DataArray with ``time`` first, then ``grid_latitude`` and ``grid_longitude``, then any
        other remaining dims. It carries ``ml_eval_da``'s ``time`` coordinate.
    """
    return (
        xr.apply_ufunc(
            qm_1d_comp_func,  # first the function
            sim_train_da,  # then the arguments in the order the function expects
            ml_train_da,
            ml_eval_da,
            kwargs=dict(n_quantiles=n_quantiles, kind=kind),
            input_core_dims=[["time"], ["time"], ["time"]],  # one entry per argument
            output_core_dims=[["time"]],
            # dims allowed to change size between inputs and output; must be set because the
            # train and eval series have different lengths
            exclude_dims={"time"},
            vectorize=True,
        )
        # "..." keeps any extra broadcast dims (e.g. a length-1 sample_id) instead of failing
        .transpose("time", "grid_latitude", "grid_longitude", ...)
        # exclude_dims drops the time coordinate, so restore the evaluation timestamps
        .assign_coords(time=ml_eval_da["time"])
    )



In [None]:
# Quantile map each GCM-driven diffusion model's evaluation samples, one sample_id at a time.
# The mapping is fitted between the CPM train-split precip and that model's own train-split sample.
qm_output = []
for model, model_da in gcm_ml_pred_pr_split.groupby("model"):
    # hoisted: the same training sample is used for every sample_id of this model
    model_train_da = gcm_ml_pred_pr_train.sel(model=model)
    qm_model_da = model_da.groupby("sample_id").map(
        lambda sample_da, train_da=model_train_da: qm_vec(
            cpm_sim_pr_train, train_da, sample_da, qm_1d_comp_func=qm_dom_aware
        )
    )
    qm_output.append(qm_model_da.rename("pred_pr").assign_coords({"model": model + " qm"}))

qm_gcm_ml_pred_pr_split = xr.concat(qm_output, dim="model")

# add the quantile-mapped "<model> qm" variants alongside the existing GCM-driven models
merged_ds["GCM"] = xr.merge([merged_ds["GCM"], qm_gcm_ml_pred_pr_split.to_dataset()])
merged_ds["GCM"]

## QQ and histogram

In [None]:
def freq_density_plot(ax, ds, target_pr, grouping_key="model"):
    """Plot log-scale density histograms of CPM precip and of each group of sampled precip.

    CPM precip (``target_pr``) is drawn as a filled grey histogram with 50 bins spanning the
    combined range of both datasets. The same bins are reused for a step histogram per value of
    ``grouping_key`` in ``ds["pred_pr"]``.
    """
    samples = ds["pred_pr"]

    lower = min(samples.min().values, target_pr.min().values)
    upper = max(samples.max().values, target_pr.max().values)
    common = dict(density=True, log=True, range=(lower, upper))

    _, edges, _ = target_pr.plot.hist(
        ax=ax, bins=50, color="black", alpha=0.2, label="CPM", **common
    )
    for value in samples[grouping_key].values:
        samples.sel({grouping_key: value}).plot.hist(
            ax=ax,
            bins=edges,
            alpha=0.75,
            histtype="step",
            label=f"{value}",
            linewidth=2,
            linestyle="-",
            **common,
        )

    ax.set_title("Log density of samples and CPM precip")
    ax.set_xlabel("Precip (mm day-1)")
    ax.tick_params(axis="both", which="major")
    ax.legend()

def one_minus_cdf_plot(ax, ds, target_pr, grouping_key="model"):
    """Plot 1 - CDF (log scale) of CPM precip and of each group of sampled precip.

    ``grouping_key`` selects the dimension of ``ds["pred_pr"]`` that gets one curve per value.
    The default is ``"model"``, as before, and the parameter matches ``freq_density_plot``.
    CPM precip is drawn as a filled grey histogram whose 50 bins, spanning the combined range of
    both datasets, are reused for every group.
    """
    pred_pr = ds["pred_pr"]

    hrange = (
        min(pred_pr.min().values, target_pr.min().values),
        max(pred_pr.max().values, target_pr.max().values),
    )
    # cumulative=-1 accumulates from the right: with density=True this is 1 - CDF
    _, bins, _ = target_pr.plot.hist(
        ax=ax,
        bins=50,
        density=True,
        cumulative=-1,
        color="black",
        alpha=0.2,
        label="CPM",
        log=True,
        range=hrange,
    )
    for group_name, group in pred_pr.groupby(grouping_key):
        group.plot.hist(
            ax=ax,
            bins=bins,
            density=True,
            cumulative=-1,
            alpha=0.75,
            histtype="step",
            label=f"{group_name}",
            log=True,
            range=hrange,
            linewidth=2,
            linestyle="-",
        )

    ax.set_title("1 - CDF (log scale) of samples and CPM precip")
    ax.set_xlabel("Precip (mm day-1)")
    ax.tick_params(axis="both", which="major")
    ax.legend()

def qq_plot(
    ax,
    target_quantiles,
    sample_quantiles,
    grouping_key="model",
    title="Sample vs CPM quantiles",
    xlabel="CPM precip (mm/day)",
    ylabel="Sample precip (mm/day)",
    tr=200,
    bl=0,
    guide_label="Ideal",
    show_legend=True,
    **lineplot_args,
):
    """Draw a quantile-quantile plot of sampled precip against CPM precip on ``ax``.

    One seaborn line is drawn per value of ``grouping_key`` in ``sample_quantiles``, with the
    CPM quantile on the x axis and the sample quantile on the y axis. When grouping by
    something other than ``sample_id``, each group's samples are melted into long form, so
    seaborn aggregates them per x. A dashed ``y = x`` guide from ``bl`` to ``tr`` marks perfect
    agreement. Extra keyword arguments override the default ``sns.lineplot`` styling.
    """
    # ideal-agreement reference line
    ax.plot(
        [bl, tr], [bl, tr], color="black", linestyle="--", label=guide_label, alpha=0.2
    )

    line_style = {"errorbar": None, "marker": "X", "alpha": 0.75, **lineplot_args}

    for group_label, group_da in sample_quantiles.groupby(grouping_key):
        # dropna: a bit of a hack while some models exist only for GCM and others only for CPM
        frame = group_da.squeeze().to_pandas().dropna().reset_index()
        if grouping_key == "sample_id":
            sample_columns = [0]
        else:
            sample_columns = list(group_da["sample_id"].values)
        frame = frame.melt(id_vars="quantile", value_vars=sample_columns)
        cpm_frame = target_quantiles.to_pandas().rename("cpm_quantile").reset_index()
        frame = frame.merge(cpm_frame)

        sns.lineplot(
            data=frame,
            x="cpm_quantile",
            y="value",
            ax=ax,
            label=group_label,
            **line_style,
        )

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    legend = ax.legend()
    if not show_legend:
        legend.remove()
    ax.set_aspect(aspect=1)
    
def distribution_figure(
    ds,
    target_pr,
    quantiles,
    quantile_dims,
    grouping_key="model",
    density_kwargs=None,
    qq_kwargs=None,
):
    """Show three figures comparing sampled precip (``ds["pred_pr"]``) with CPM precip.

    1. Log-density histograms, one per ``grouping_key`` value (``freq_density_plot``).
    2. A QQ plot of every group against the CPM quantiles (``qq_plot``).
    3. One QQ panel per ``grouping_key`` value, with each ``sample_id`` drawn separately.

    Quantiles at the levels in ``quantiles`` are computed over ``quantile_dims``.
    ``density_kwargs`` and ``qq_kwargs`` are forwarded to the first two plots. The default
    ``None`` means no extra arguments (previously a shared mutable ``dict()`` default).
    """
    density_kwargs = {} if density_kwargs is None else density_kwargs
    qq_kwargs = {} if qq_kwargs is None else qq_kwargs

    # 1. densities
    _, axes = plt.subplot_mosaic(
        [["Density"]], figsize=(7, 3.5), constrained_layout=True
    )
    freq_density_plot(
        axes["Density"], ds, target_pr, grouping_key=grouping_key, **density_kwargs
    )
    plt.show()

    cpm_quantiles = target_pr.quantile(quantiles, dim=quantile_dims)
    sample_quantiles = ds["pred_pr"].quantile(quantiles, dim=quantile_dims)

    # 2. all groups on one QQ plot
    _, axes = plt.subplot_mosaic(
        [["Quantiles"]], figsize=(3.5, 3.5), constrained_layout=True
    )
    qq_plot(
        axes["Quantiles"],
        cpm_quantiles,
        sample_quantiles,
        grouping_key=grouping_key,
        **({"title": None} | qq_kwargs),
    )
    plt.show()

    # 3. one panel per group; width scales with the number of panels (3 panels -> 10.5 in)
    group_values = list(ds[grouping_key].values)
    _, axes = plt.subplot_mosaic(
        [group_values],
        figsize=(3.5 * len(group_values), 3.5),
        constrained_layout=True,
    )
    for group_value, group_quantiles in sample_quantiles.groupby(grouping_key):
        qq_plot(
            axes[group_value],
            cpm_quantiles,
            group_quantiles,
            title=group_value,
            grouping_key="sample_id",
            alpha=0.5,
            show_legend=False,
        )
    plt.show()

### Training set

In [None]:
# In-sample check: quantile map the train-split samples with a mapping fitted on those same
# samples, for each number of quantile bins nq, and compare against CPM train-split precip.
for nq in [250]:
    qm_gcm_ml_pred_pr_train = xr.concat(
        [
            qm_vec(cpm_sim_pr_train, model_da, model_da, n_quantiles=nq)
            .rename("pred_pr")
            .assign_coords({"model": model + " qm"})
            for model, model_da in gcm_ml_pred_pr_train.groupby("model")
        ],
        dim="model",
    )
    # raw and quantile-mapped samples side by side; built once and shared by both panels
    train_comparison_ds = xr.merge([gcm_ml_pred_pr_train, qm_gcm_ml_pred_pr_train])

    fig, axes = plt.subplot_mosaic(
        [["Density"], ["CDF"]], figsize=(7, 7), constrained_layout=True
    )
    freq_density_plot(
        axes["Density"], train_comparison_ds, cpm_sim_pr_train, grouping_key="model"
    )
    one_minus_cdf_plot(axes["CDF"], train_comparison_ds, cpm_sim_pr_train)
    # identify the bin count when several values of nq are swept
    fig.suptitle(f"Train split, n_quantiles={nq}")
    plt.show()

### Eval set

In [None]:
# Quantile levels for the QQ plots: 0.1-0.9, then nine evenly spaced levels in each successive
# band of nines (0.91-0.99, 0.991-0.999, ..., up to 1 - 1e-7), and finally the maximum (1.0)
quantiles = np.concatenate(
    [
        *(
            np.linspace((1 - 10 ** (exponent + 1)) + (10**exponent), (1 - 10**exponent), 9)
            for exponent in range(-1, -8, -1)
        ),
        [1.0],
    ]
)

distribution_figure(
    merged_ds["GCM"],
    cpm_sim_pr_split,
    quantiles,
    quantile_dims=["time", "grid_latitude", "grid_longitude"],
)

## PSD

In [None]:
def psd(batch):
    """Return the (unnormalised) power spectrum |F|^2 of each 2-D field in ``batch``.

    The 2-D FFT is taken over axes 1 and 2 of ``batch``, i.e. it is indexed
    (example, y, x). The result is shifted so the zero-frequency component sits at the centre
    of each spectrum.
    """
    spatial_axes = (1, 2)
    spectrum = np.fft.fftn(batch, axes=spatial_axes)
    centred_spectrum = np.fft.fftshift(spectrum, axes=spatial_axes)
    return np.abs(centred_spectrum) ** 2

def raspd(precip_da):
    """Radially averaged power spectral density (RAPSD) of the fields in ``precip_da``.

    Every 2-D (grid_latitude, grid_longitude) field is Fourier transformed with ``psd``. Its
    power is averaged over unit-width rings of wavenumber magnitude around the zero frequency,
    and the ring means are then averaged over all fields (all other dims, e.g. time and
    sample_id).

    Returns:
        (kvals, mean_Abins): ring-centre wavenumbers 0, 1, ..., npix // 2 and the mean power
        in each ring.

    Raises:
        ValueError: if the grid is not square. The wavenumber grid assumes npix x npix fields.
    """
    npix = precip_da.sizes["grid_latitude"]
    if precip_da.sizes["grid_longitude"] != npix:
        raise ValueError(
            f"raspd needs a square grid, got {npix} x {precip_da.sizes['grid_longitude']}"
        )
    # Move the spatial dims last so reshape(-1, npix, npix) splits the data into whole fields
    # whatever the order of the other dims. Otherwise fields could be silently interleaved.
    fields = precip_da.transpose(..., "grid_latitude", "grid_longitude").values
    fourier_amplitudes = psd(fields.reshape(-1, npix, npix))

    kfreq = np.fft.fftshift(np.fft.fftfreq(npix)) * npix
    kfreq2D = np.meshgrid(kfreq, kfreq)
    knrm = np.sqrt(kfreq2D[0] ** 2 + kfreq2D[1] ** 2)
    kbins = np.arange(-0.5, npix // 2 + 1, 1.0)
    kvals = 0.5 * (kbins[1:] + kbins[:-1])

    # kbins define equal-width (not equal-area) rings of the centred Fourier plane. knrm is the
    # distance of each point from the centre, so it determines which ring every amplitude
    # belongs to. Mean amplitude per ring, for each example separately:
    Abins, _, _ = scipy.stats.binned_statistic(
        knrm.flatten(),
        fourier_amplitudes.reshape(-1, npix * npix),
        statistic="mean",
        bins=kbins,
    )
    # then the mean over all the examples
    mean_Abins = np.mean(Abins, axis=0)

    return kvals, mean_Abins

def raspd_pysteps(precip_da):
    """Radially averaged power spectral density, computed pysteps-style with rounded radii.

    Each centred power spectrum (from ``psd``) is averaged over the pixels whose rounded
    distance from the zero frequency equals r, for r = 0 .. npix // 2 - 1. The average is taken
    jointly over all fields (all dims other than grid_latitude / grid_longitude).

    Returns:
        (freq, mean_Abins): the radii 0 .. npix // 2 - 1 (as floats) and the mean power at each.

    Raises:
        ValueError: if the grid is not square.
    """
    npix = precip_da.sizes["grid_latitude"]
    if precip_da.sizes["grid_longitude"] != npix:
        raise ValueError(
            f"raspd_pysteps needs a square grid, got {npix} x {precip_da.sizes['grid_longitude']}"
        )
    # spatial dims last so the reshape yields whole fields whatever the input dim order
    fields = precip_da.transpose(..., "grid_latitude", "grid_longitude").values
    fourier_amplitudes = psd(fields.reshape(-1, npix, npix))

    half = npix // 2
    # Pixel offsets from the zero frequency, which fftshift places at index npix // 2. This
    # spans exactly npix values for odd as well as even npix (the old [-half:half] slice gave
    # npix - 1 values when npix was odd).
    yc, xc = np.ogrid[-half : npix - half, -half : npix - half]

    r_grid = np.sqrt(xc * xc + yc * yc).round()

    r_range = np.arange(0, half)
    freq = np.fft.fftfreq(npix) * npix
    freq = freq[r_range]

    pys_result = []
    for r in r_range:
        mask = r_grid == r
        psd_vals = fourier_amplitudes[:, mask]
        pys_result.append(np.mean(psd_vals))

    mean_Abins = np.array(pys_result)

    return freq, mean_Abins

def plot_psd(cpm_hr_pr, gcm_lr_lin_pr, pred_pr):
    """Plot the RAPSD (log-log) of CPM precip, interpolated GCM precip and each model's samples.

    The two simulation references are thick black lines with different dash patterns so they
    can be told apart in the legend. Each value of ``pred_pr``'s ``model`` dim gets its own line.

    Returns:
        (fig, axd) from ``plt.subplot_mosaic``; the axes are under key ``"PSD"``.
    """
    fig, axd = plt.subplot_mosaic([["PSD"]], tight_layout=True)
    ax = axd["PSD"]

    reference_style = dict(color="black", linewidth=5, alpha=0.5)
    ax.loglog(*raspd(cpm_hr_pr), label="CPM pr", linestyle="-.", **reference_style)
    # previously drawn with exactly the same style as the CPM curve, making the two
    # legend entries indistinguishable
    ax.loglog(*raspd(gcm_lr_lin_pr), label="GCM interp. pr", linestyle=":", **reference_style)

    for model, precip_da in pred_pr.groupby("model"):
        ax.loglog(*raspd(precip_da), label=model)

    ax.set_xlabel("$k$")
    ax.set_ylabel("$P(k)$")
    ax.legend(ncols=3)
    ax.set_title("RAPSD")

    return fig, axd


In [None]:
# Evaluation-split reference precip in mm day-1, loaded with the same helpers as the rest of
# the notebook instead of re-implementing the path and unit conversion inline:
# GCM lo-res precip linearly interpolated to the hi-res grid ...
gcm_lr_lin_pr = si_to_mmday(
    xr.open_dataset(
        os.path.join(
            os.getenv("DERIVED_DATA"), "moose", "nc-datasets", gcm_lr_lin_pr_dataset, f"{split}.nc"
        )
    ),
    "linpr",
)
# ... and CPM hi-res precip
cpm_hr_pr = open_split_ds(cpm_hr_pr_dataset, split)["target_pr"]

plot_psd(cpm_hr_pr, gcm_lr_lin_pr, merged_ds["GCM"]["pred_pr"])

## Samples

In [None]:
# Rotated-pole coordinate system used as the projection of every map axis below
cp_model_rotated_pole = ccrs.RotatedPole(pole_longitude=177.5, pole_latitude=37.5)

# Discrete precip colour scale (mm day-1): one MetPy "precipitation" colour per interval
precip_clevs = [0, 0.1, 1, 2.5, 5, 7.5, 10, 15, 20, 30, 40, 50, 70, 100, 150, 200]
precip_cmap = matplotlib.colors.ListedColormap(
    metpy.plots.ctables.colortables["precipitation"][: len(precip_clevs) - 1],
    name="precipitation",
)
precip_norm = matplotlib.colors.BoundaryNorm(boundaries=precip_clevs, ncolors=precip_cmap.N)

# Named plot_map styles: default cmap/norm keyword arguments for pcolormesh
STYLES = dict(
    precip=dict(cmap=precip_cmap, norm=precip_norm),
    logBlues=dict(cmap="Blues", norm=matplotlib.colors.LogNorm()),
)

def plot_map(da, ax, title="", style="logBlues", add_colorbar=False, **kwargs):
    """Draw ``da`` with ``pcolormesh`` on ``ax``, then set the title and add coastlines.

    ``style`` names an entry of ``STYLES`` whose ``cmap``/``norm`` act as defaults that explicit
    ``kwargs`` override; ``style=None`` uses only ``kwargs``.

    Returns:
        The artist returned by ``da.plot.pcolormesh``.
    """
    import copy

    if style is not None:
        # Deep-copy the style. matplotlib autoscales a norm's unset vmin/vmax in place on first
        # use, so handing out the shared STYLES norm (e.g. the LogNorm of "logBlues") would fix
        # every later map to the colour range of the first map drawn with it.
        kwargs = copy.deepcopy(STYLES[style]) | kwargs
    pcm = da.plot.pcolormesh(ax=ax, add_colorbar=add_colorbar, **kwargs)
    ax.set_title(title)
    ax.coastlines()
    return pcm

    
def plot_examples(ds, timestamps):
    """For each timestamp, plot the input fields, then the target alongside every model's samples.

    Two figures are shown per timestamp:

    * the ``vorticity850`` and ``psl`` input fields;
    * a grid with one row per ``model``. The first column is the simulated ``target_pr``
      (taken from the first model and shared by all rows), then a cell with the model label,
      then one map per ``sample_id`` of ``pred_pr``. A shared precip colour bar is placed to
      the right.
    """
    input_variables = ["vorticity850", "psl"]
    models = ds["model"].values
    n_samples = len(ds["sample_id"])

    for ts in timestamps:
        ts_ds = ds.sel(time=ts)

        # inputs
        fig, axd = plt.subplot_mosaic(
            [input_variables],
            figsize=(12, 2.5),
            constrained_layout=True,
            subplot_kw={"projection": cp_model_rotated_pole},
        )
        fig.suptitle(f"Inputs {ts}")
        for var in input_variables:
            plot_map(ts_ds[var], ax=axd[var], style=None, title=var, add_colorbar=False)

        plt.show()

        # target and samples: "Target" repeated in every row makes it span the first column
        grid_spec = [
            ["Target", f"{model} Name"]
            + [f"{model} Sample {sample_idx}" for sample_idx in range(n_samples)]
            for model in models
        ]
        fig, axd = plt.subplot_mosaic(
            grid_spec,
            figsize=(12, 10),
            constrained_layout=True,
            subplot_kw={"projection": cp_model_rotated_pole},
        )
        fig.suptitle(f"Precip {ts}")

        plot_map(
            ts_ds.isel(model=0)["target_pr"],
            axd["Target"],
            title="Simulation",
            cmap=precip_cmap,
            norm=precip_norm,
            add_colorbar=False,
        )

        for model in models:
            name_ax = axd[f"{model} Name"]
            name_ax.text(x=0, y=0, s=model)
            name_ax.set_axis_off()
            model_ds = ts_ds.sel(model=model)
            for sample_idx in range(n_samples):
                plot_map(
                    model_ds.isel(sample_id=sample_idx)["pred_pr"],
                    axd[f"{model} Sample {sample_idx}"],
                    cmap=precip_cmap,
                    norm=precip_norm,
                    add_colorbar=False,
                    title="Sample",
                )

        # shared colour bar to the right of the grid
        cax = fig.add_axes([1.05, 0.0, 0.05, 0.95])
        cb = matplotlib.colorbar.Colorbar(cax, cmap=precip_cmap, norm=precip_norm)
        cb.ax.set_yticks(precip_clevs)
        cb.ax.set_yticklabels(precip_clevs)
        cb.ax.tick_params(axis="both", which="major")
        cb.ax.set_ylabel("Precip (mm day-1)")

        plt.show()

### Coarsened-CPM-driven

In [None]:
# CPM - 2034-12-03
# Example day from the coarsened-CPM-driven samples (nearest available 360-day-calendar time);
# uses the public cftime.Datetime360Day rather than the private cftime._cftime path
cpm_samples_ds = merged_ds["CPM"].sel(
    time=[cftime.Datetime360Day(2034, 12, 3, 12, 0, 0)], method="nearest"
)
plot_examples(cpm_samples_ds, cpm_samples_ds["time"].values)

### GCM-driven

In [None]:
# GCM - 2022-10-21
# Example day from the GCM-driven samples, including the quantile-mapped models (nearest
# available 360-day-calendar time); uses the public cftime.Datetime360Day
gcm_samples_ds = merged_ds["GCM"].sel(
    time=[cftime.Datetime360Day(2022, 10, 21, 12, 0, 0)], method="nearest"
)
plot_examples(gcm_samples_ds, gcm_samples_ds["time"].values)