# Probability Integral Transform
##### author: Elizabeth A. Barnes

## Imports and environment setup

In [None]:
import sys, os, copy
import importlib as imp

import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_probability as tfp

import scipy.stats as stats
import seaborn as sns
import custom_metrics

import shap
from scipy.optimize import curve_fit
import gc
import plots
import regions

import regionmask
import experiment_settings
import file_methods, plots, data_processing, transfer_learning, xai

import matplotlib as mpl
import cartopy as ct
import cartopy.feature as cfeature

# Figure defaults: white background and 150 dpi on screen; figures written
# through plots.savefig below use savefig_dpi.
mpl.rcParams.update({"figure.facecolor": "white", "figure.dpi": 150})
savefig_dpi = 300
# The commented-out "seaborn-notebook" name is the older spelling of this style.
plt.style.use("seaborn-v0_8-notebook")

import warnings

# Silence Python warnings and TensorFlow's C++ logging.
# NOTE(review): TF_CPP_MIN_LOG_LEVEL is normally read when tensorflow is
# imported; tensorflow is already imported above, so this may have no effect.
warnings.filterwarnings("ignore")
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

In [None]:
# Record the interpreter and key library versions, one per line.
print(
    *(
        f"{package} version = {version}"
        for package, version in (
            ("python", sys.version),
            ("numpy", np.__version__),
            ("xarray", xr.__version__),
            ("tensorflow", tf.__version__),
            ("tensorflow-probability", tfp.__version__),
            ("shap", shap.__version__),
        )
    ),
    sep="\n",
)

## User Choices

In [None]:
# Parent experiment whose saved settings and models are evaluated
# (swap in "exp082" to look at the other experiment of the comparison below).
PARENT_EXP_NAME = "exp134"
RNG_SEED = 66
ipcc_region = "SAH"

# -------------------------------------------------------
# Start from the parent experiment's settings and specialise them to one
# IPCC region and random seed.
settings = experiment_settings.get_settings(PARENT_EXP_NAME)
settings["target_temp"] = None
settings["rng_seed"] = RNG_SEED
settings["target_region"] = f"ipcc_{ipcc_region}"
settings["exp_name"] = f"{PARENT_EXP_NAME}_{ipcc_region}"

# -------------------------------------------------------
# Input/output locations (relative paths).
MODEL_DIRECTORY = "saved_models/"
PREDICTIONS_DIRECTORY = "saved_predictions/"
DATA_DIRECTORY = "../../../2022/target_temp_detection/data/"  # point to where your data is sitting
GCM_DATA_DIRECTORY = OBS_DIRECTORY = "../data/"
DIAGNOSTICS_DIRECTORY = "model_diagnostics/"
FIGURE_DIRECTORY = "figures/"

In [None]:
# raise ValueError()

In [None]:

# Build the evaluation data for the region chosen above and load the matching
# trained model. Only the validation/testing pieces are kept; every `__` slot
# is discarded.
# NOTE(review): this positional unpack must match the 18-element tuple that
# data_processing.create_data returns; re-check it if that function changes.
# settings is passed as a copy, presumably so create_data cannot alter it.
(
    __,
    x_val,
    x_test,
    __,
    __,
    __,
    __,
    onehot_val,
    onehot_test,
    __,
    __,
    y_yrs_test,
    __,
    target_temps_val,
    target_temps_test,
    target_years_region,
    map_shape,
    __,
) = data_processing.create_data(DATA_DIRECTORY, settings.copy(), verbose=0)

# Reset Keras state and seed the Python/NumPy/TensorFlow RNGs before loading.
tf.keras.backend.clear_session()
tf.keras.utils.set_random_seed(settings["rng_seed"])

# The saved model is located by the name derived from the settings.
model_name = file_methods.get_model_name(settings)
model = file_methods.load_tf_model(model_name, MODEL_DIRECTORY)
print(model_name)

In [None]:
# concatenate validation and testing data for PIT calculation
# Pool the validation and testing samples along the sample axis; the PIT is
# evaluated on both together.
x_data, target_temps_data, onehot_data = (
    np.concatenate(pair, axis=0)
    for pair in (
        (x_val, x_test),
        (target_temps_val, target_temps_test),
        (onehot_val, onehot_test),
    )
)
print(f"{onehot_data.shape = }")

# Predict on the pooled samples, then force a garbage-collection pass.
p = model.predict((x_data, target_temps_data), verbose=None)
__ = gc.collect()

# PIT histogram (bins, hist_shash) plus the D statistic and EDp_shash, which
# the next cell prints next to D.
# NOTE(review): EDp_shash is presumably the D expected under perfect
# calibration; confirm in custom_metrics.compute_pit.
bins, hist_shash, D_shash, EDp_shash = custom_metrics.compute_pit(onehot_data, p)

In [None]:
# Plot the PIT histogram for the region chosen above and save it.
plt.figure(figsize=(8, 5))
ax = plt.subplot(1, 1, 1)
clr_shash = "tab:purple"
bins_inc = bins[1] - bins[0]

# Centre each bar on its PIT bin and draw it slightly narrower than the bin
# so that neighbouring bars stay visually separated.
bin_add = bins_inc / 2
bin_width = bins_inc * 0.98
ax.bar(
    hist_shash[1][:-1] + bin_add,
    hist_shash[0],
    width=bin_width,
    color=clr_shash,
    label="SHASH",
)

# Reference line for a perfectly calibrated model: every bin holds the same
# share, 1 / n_bins (0.1 for 10 bins). The value is derived from the histogram
# rather than hard-coded, so it stays correct if compute_pit's binning changes.
# NOTE(review): assumes hist_shash[0] holds per-bin fractions summing to 1, as
# the "probability" y-label suggests; confirm against custom_metrics.
n_pit_bins = len(hist_shash[0])
ax.axhline(
    y=1.0 / n_pit_bins,
    linestyle="--",
    color="k",
    linewidth=2.0,
)

# Keep the usual 0.25 ceiling, but raise it when a bar would otherwise be clipped.
y_top = max(0.25, float(np.max(hist_shash[0])) * 1.05)
yticks = np.around(np.arange(0, y_top + 0.05, 0.05), 2)
ax.set_yticks(yticks, yticks)
ax.set_ylim(0, y_top)
ax.set_xticks(bins, np.around(bins, 1))

# Annotate with the D statistic and, in parentheses, EDp_shash.
ax.text(
    0.0,
    np.max(ax.get_ylim()) * 0.99,
    f"D statistic: {np.round(D_shash, 4)!s} ({np.round(EDp_shash, 3)!s})",
    color=clr_shash,
    verticalalignment="top",
    fontsize=12,
)

ax.set_xlabel("probability integral transform")
ax.set_ylabel("probability")
ax.set_title(ipcc_region, fontsize=12, color="k")

plots.savefig(
    f"{FIGURE_DIRECTORY}{model_name}_pit",
    dpi=savefig_dpi,
)

plt.show()

In [None]:
# Intentional stop for "Run All": the cells below repeat the PIT evaluation
# for every IPCC AR6 land region and every experiment in the comparison.
raise ValueError(
    "Intentional stop: the cells below loop over every IPCC region and experiment."
)

## Loop over IPCC regions and experiments to compute the PIT D statistic

In [None]:
# Compute the PIT D statistic for every IPCC AR6 land region and experiment.
IPCC_REGION_LIST = regionmask.defined_regions.ar6.land.abbrevs
EXP_NAME_LIST = ("exp134", "exp082")

# One D value per region for each experiment; NaN until computed.
pitd_dict = {exp: np.full(len(IPCC_REGION_LIST), np.nan) for exp in EXP_NAME_LIST}

# Work on a copy, with loop variables that do not reuse the "User Choices"
# names, so the sweep leaves the single-region choices (settings,
# ipcc_region, PARENT_EXP_NAME) untouched.
sweep_settings = settings.copy()

# NOTE(review): every experiment in EXP_NAME_LIST reuses the settings of the
# parent experiment loaded in "User Choices" (only exp_name is swapped), and
# create_data receives whatever exp_name the previous iteration left behind.
# Confirm that the experiments share their data and model-naming settings.

for ireg, region in enumerate(IPCC_REGION_LIST):
    # if region not in ("WCE", "CNA", "SAH", "ESB", "NSA", "CAU"):
    #     continue

    sweep_settings["target_region"] = "ipcc_" + region

    (
        __,
        x_val,
        x_test,
        __,
        __,
        __,
        __,
        onehot_val,
        onehot_test,
        __,
        __,
        y_yrs_test,
        __,
        target_temps_val,
        target_temps_test,
        target_years_region,
        map_shape,
        __,
    ) = data_processing.create_data(DATA_DIRECTORY, sweep_settings.copy(), verbose=0)

    # The evaluation data depend only on the region, so pool the validation
    # and testing samples once and reuse them for every experiment's model.
    x_data = np.concatenate((x_val, x_test), axis=0)
    target_temps_data = np.concatenate((target_temps_val, target_temps_test), axis=0)
    onehot_data = np.concatenate((onehot_val, onehot_test), axis=0)
    print(f"{onehot_data.shape = }")

    for exp_name in EXP_NAME_LIST:
        sweep_settings["exp_name"] = exp_name + "_" + region

        tf.keras.backend.clear_session()
        tf.keras.utils.set_random_seed(sweep_settings["rng_seed"])

        model_name = file_methods.get_model_name(sweep_settings)
        model = file_methods.load_tf_model(model_name, MODEL_DIRECTORY)
        print(model_name)

        # make model predictions on the pooled validation and testing data
        p = model.predict((x_data, target_temps_data), verbose=None)
        __ = gc.collect()

        # compute PIT and keep the D statistic for this region/experiment
        bins, hist_shash, D_shash, EDp_shash = custom_metrics.compute_pit(onehot_data, p)
        pitd_dict[exp_name][ireg] = D_shash

print("done.")

In [None]:
# Compare the distribution of per-region PIT D statistics across experiments.
# Separate names are used so the global `bins` (PIT bin edges) is not overwritten.
d_step = 0.0025
all_d = np.concatenate([pitd_dict[exp] for exp in EXP_NAME_LIST])
finite_d = all_d[np.isfinite(all_d)]
# The fixed edges used to stop at 0.0375, silently dropping larger D values.
# Extend the top edge so poorly calibrated regions are still counted.
d_top = max(0.04, float(finite_d.max()) + d_step) if finite_d.size else 0.04
d_bins = np.arange(0, d_top, d_step)

max_count = 0
for exp in EXP_NAME_LIST:
    counts = plt.hist(
        pitd_dict[exp],
        d_bins,
        histtype="step",
        linewidth=3,
        alpha=0.5,
        label=exp,
        density=False,
    )[0]
    max_count = max(max_count, int(np.max(counts)))
plt.legend()
# One tick per integer count; keep the original 0-18 range unless a bar is taller.
plt.yticks(range(max(19, max_count + 1)))
plt.xlabel("PIT D statistic")
plt.ylabel("count")

plots.savefig(
    f"{FIGURE_DIRECTORY}compare_pitd_across_experiments_rng_seed{settings['rng_seed']}",
    dpi=savefig_dpi,
)
plt.show()