# SHAP and Transfer Learning: Hyperparameter Sweep Comparison
##### authors: Elizabeth A. Barnes

## Python stuff

In [None]:
# %matplotlib inline
# %load_ext autotime

# --- standard library -------------------------------------------------------
import copy
import gc
import importlib as imp
import os
import pickle
import sys
import warnings

# TensorFlow's C++ logger reads TF_CPP_MIN_LOG_LEVEL once, normally while
# `import tensorflow` runs, so it has to be set *before* that import to
# reliably silence the native log output (setting it afterwards may be ignored).
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

# --- third-party and project modules (original relative order kept) --------
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_probability as tfp
from scipy.optimize import curve_fit
import custom_metrics
import regionmask
import pandas as pd
import experiment_settings
import file_methods
import plots
import data_processing
import matplotlib as mpl
import transfer_learning
import regions
import cartopy as ct

# Figure defaults for on-screen display; savefig_dpi is used when writing files.
mpl.rcParams["figure.facecolor"] = "white"
mpl.rcParams["figure.dpi"] = 150
savefig_dpi = 300
plt.style.use("seaborn-v0_8-notebook")

# Hide all Python warnings from this point on (warnings raised by the imports
# above are still shown).
warnings.filterwarnings("ignore")

# tf.config.set_visible_devices([], "GPU")  # turn-off tensorflow-metal if it is on

## User Choices

In [None]:
# Locations written to / read back by this notebook. Every path ends in "/"
# because file names are appended with plain string concatenation
# (e.g. PREDICTIONS_DIRECTORY + model_name + "_metrics.pickle").
MODEL_DIRECTORY = "saved_models/"
PREDICTIONS_DIRECTORY = "saved_predictions/"
DIAGNOSTICS_DIRECTORY = "model_diagnostics/"
FIGURE_DIRECTORY = "figures/"

# Input data locations -- point DATA_DIRECTORY to where your data is sitting.
DATA_DIRECTORY = "../../../2022/target_temp_detection/data/"
GCM_DATA_DIRECTORY = "../data/"
OBS_DIRECTORY = "../data/"

## Compare hyperparameters

In [None]:
imp.reload(data_processing)
IPCC_REGION_LIST = regionmask.defined_regions.ar6.land.abbrevs

exp_list = (
    "exp134",
    "exp061",
    "exp062",
    "exp063",
    "exp064",
    "exp065",
    "exp066",
    "exp067",
    "exp068",
)
ipcc_region = "WCE"
ireg = IPCC_REGION_LIST.index(ipcc_region)

# Scalar metrics copied (first entry of each) from "<model_name>_metrics.pickle".
_METRIC_KEYS = (
    "loss_test",
    "error_test",
    "d_test",
    "loss_val",
    "error_val",
    "d_val",
    "d_valtest",
)


def _obs_predictions_filename(exp_name, rng_seed, final_year_of_obs, kind):
    """Return the path of the pickled observation predictions for one run.

    `kind` selects the file suffix: "base" or "transfer".
    """
    return (
        PREDICTIONS_DIRECTORY
        + exp_name
        + "_rng_seed"
        + str(rng_seed)
        + "_observations_"
        + str(final_year_of_obs)
        + "_predictions_"
        + kind
        + ".pickle"
    )


def _load_obs_prediction(filename, target_temp, region_index):
    """Return entries 0 and 1 (mu, sigma) of one region's prediction, or None.

    The pickle is unpacked as a 1-tuple whose item is indexed as
    ``item[target_temp][region_index, :]``. A missing file is skipped
    silently (its columns end up NaN in df_metrics); any other failure is
    reported and skipped, so one bad file does not abort the sweep.
    """
    try:
        with open(filename, "rb") as f:
            (obs_dict,) = pickle.load(f)
        prediction = obs_dict[target_temp][region_index, :]
        return prediction[0], prediction[1]
    except FileNotFoundError:
        return None
    except Exception as err:  # keep going past a malformed file, but say so
        print(f"skipping {filename}: {err!r}")
        return None


# Collect one dict per (experiment, seed) and build the table once at the end;
# concatenating inside the loop would copy the growing frame every iteration.
rows = []
for parent_exp_name in exp_list:
    try:
        settings = experiment_settings.get_settings(parent_exp_name)
    except Exception as err:  # experiment not available: skip, but report it
        print(f"skipping {parent_exp_name}: {err!r}")
        continue

    settings["exp_name"] = parent_exp_name + "_" + ipcc_region

    for rng_seed in settings["rng_seed_list"]:
        settings["rng_seed"] = rng_seed
        model_name = file_methods.get_model_name(settings)

        # Base model metrics; a run without a metrics file is skipped entirely.
        metrics_filename = PREDICTIONS_DIRECTORY + model_name + "_metrics.pickle"
        try:
            df = pd.read_pickle(metrics_filename)
        except FileNotFoundError:
            continue
        except Exception as err:
            print(f"skipping {metrics_filename}: {err!r}")
            continue

        row = {"exp_name": parent_exp_name, "rng_seed": rng_seed}
        row.update({key: df[key][0] for key in _METRIC_KEYS})

        # -----------------------------------------------
        # GET OBS METRICS AND PREDICTIONS (base and transfer models).
        # target_temp is set after get_model_name(), as in the original order.
        settings["target_temp"] = 3.0
        final_year = settings["final_year_of_obs"]
        for kind in ("base", "transfer"):
            prediction = _load_obs_prediction(
                _obs_predictions_filename(parent_exp_name, rng_seed, final_year, kind),
                settings["target_temp"],
                ireg,
            )
            if prediction is None:
                continue
            mu, sigma = prediction
            # mu is stored as an offset from final_year_of_obs; add it back.
            row[f"obs_{kind}_3C_mu"] = mu + final_year
            row[f"obs_{kind}_3C_sigma"] = sigma

        rows.append(row)

# NOTE: `model_name` keeps the last value computed above; the plotting cell
# uses it to name the saved figure.
df_metrics = pd.DataFrame(rows)
df_metrics

In [None]:
import seaborn as sns

# x tick labels, one per entry of exp_list (same order).
labels = ["default", "hp1", "hp2", "hp3", "hp4", "hp5", "hp6", "hp7", "hp8"]

palette = (
    "tab:gray",
    "tab:purple",
    "tab:orange",
    "tab:blue",
    "tab:red",
    "tab:green",
    "tab:pink",
    "tab:brown",
    "tab:olive",
)
# Per-experiment index into `palette`: the first experiment ("default") is
# drawn in tab:orange, every other experiment in tab:gray.
clr_order = [2, 0, 0, 0, 0, 0, 0, 0, 0]
size = 3.0  # swarm marker size
FS = 10  # base font size for titles and tick labels

# Hollow marker for the first series of a panel. Raw string: "\c" is an
# invalid escape sequence in a normal literal (DeprecationWarning, and a
# SyntaxWarning since Python 3.12); the resulting text is identical.
HOLLOW_MARKER = r"$\circ$"


def _swarm(ax, column, **kwargs):
    """Draw `column` of df_metrics as one swarm per experiment on `ax`.

    `order` pins one x position per entry of exp_list, so the hard-coded
    `labels` stay aligned with their swarms even if an experiment produced
    no rows in df_metrics.
    """
    sns.swarmplot(
        x="exp_name",
        y=column,
        order=list(exp_list),
        palette=np.array(palette)[clr_order],
        data=df_metrics,
        size=size,
        ax=ax,
        **kwargs,
    )


def _finish_panel(ax, title, ylim):
    """Apply the title, y-limits, tick styling and x tick labels to `ax`."""
    ax.set_title(title, fontsize=FS)
    ax.set_ylim(*ylim)
    ax.tick_params(axis="x", labelrotation=45, labelsize=FS * 0.8)
    ax.tick_params(axis="y", labelsize=FS * 0.8)
    ax.set_xlabel("", fontsize=FS)
    ax.set_ylabel("", fontsize=FS)
    ax.set_xticklabels(labels)


# One entry per panel, in axes order:
# (column drawn with hollow markers, column drawn with filled markers,
#  title, (ymin, ymax)); None means that series is not drawn.
panel_specs = [
    ("error_val", "error_test", "[a] Mean Absolute Error\n(years)", (2.0, None)),
    ("loss_val", "loss_test", "[b] Loss", (2.5, None)),
    (None, "d_valtest", "[c] PIT D Metric", (0, None)),
    (
        "obs_base_3C_mu",
        "obs_transfer_3C_mu",
        "[d] 3.0C Obs. Prediction\nInitialized w/ 2023",
        (2020, 2070),
    ),
    (
        "obs_base_3C_sigma",
        None,
        "[e] 3.0C Obs. Uncertainty Prediction\nInitialized w/ 2023 (years)",
        (0.0, 20.0),
    ),
]

fig, axs = plt.subplots(2, 3, figsize=(8.4, 5.0))
axs = axs.flatten()

for ax, (hollow_column, filled_column, title, ylim) in zip(axs, panel_specs):
    plots.format_spines(ax)
    if hollow_column is not None:
        _swarm(ax, hollow_column, marker=HOLLOW_MARKER)
    if filled_column is not None:
        _swarm(ax, filled_column)
    _finish_panel(ax, title, ylim)

axs[5].remove()  # only five panels are used

plt.tight_layout()

# NOTE(review): `model_name` is whatever the metrics-collection cell computed
# last, so the file name reflects the final experiment/seed visited rather
# than the sweep as a whole.
plots.savefig(
    FIGURE_DIRECTORY + model_name + "_hyperparameter_sweep_panels",
    dpi=savefig_dpi,
)

plt.show()