# SHAP and Transfer Learning
##### authors: Elizabeth A. Barnes

## Python stuff

In [None]:
import sys, os, copy
import importlib as imp

import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_probability as tfp

import scipy.stats as stats
import seaborn as sns
import pickle

import shap
from scipy.optimize import curve_fit
import gc
import plots
import regions

import regionmask
import experiment_settings
import file_methods, plots, data_processing, transfer_learning, xai

import matplotlib as mpl
import cartopy as ct
import cartopy.feature as cfeature

# Matplotlib defaults for this notebook: white figure background and a
# moderate on-screen resolution; `savefig_dpi` is passed to plots.savefig
# further down so that saved figures are written at a higher resolution.
mpl.rcParams.update({"figure.facecolor": "white", "figure.dpi": 150})
savefig_dpi = 300
# Older matplotlib releases name this style sheet "seaborn-notebook".
plt.style.use("seaborn-v0_8-notebook")

import warnings

# Silence all Python warnings and ask TensorFlow's C++ layer to log as little
# as possible.
# NOTE(review): tensorflow has already been imported by the time this
# environment variable is set, so it may be too late to take full effect --
# confirm, and move it above the tensorflow import if the log spam persists.
warnings.filterwarnings("ignore")
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

In [None]:
# Record the interpreter and key library versions alongside the results so a
# run can be reproduced later. One line is printed per package.
print(
    *(
        f"{package} version = {version}"
        for package, version in (
            ("python", sys.version),
            ("numpy", np.__version__),
            ("xarray", xr.__version__),
            ("tensorflow", tf.__version__),
            ("tensorflow-probability", tfp.__version__),
            ("shap", shap.__version__),
        )
    ),
    sep="\n",
)

## User Choices

In [None]:
# ---- values the user is expected to edit --------------------------------
PARENT_EXP_NAME = "exp134"  # "exp134"
EVAL_THRESHOLD = 2.0
RESAVE = True
RNG_SEED = 66

# ---- experiment settings -------------------------------------------------
# Start from the stored settings of the parent experiment, then override the
# evaluation threshold and the random seed with the choices made above.
settings = experiment_settings.get_settings(PARENT_EXP_NAME)
settings.update({"target_temp": EVAL_THRESHOLD, "rng_seed": RNG_SEED})

# ---- input / output locations (relative to the notebook) -----------------
MODEL_DIRECTORY = "saved_models/"
PREDICTIONS_DIRECTORY = "saved_predictions/"
# point to where your data is sitting
DATA_DIRECTORY = "../../../2022/target_temp_detection/data/"
GCM_DATA_DIRECTORY = "../data/"
OBS_DIRECTORY = "../data/"
DIAGNOSTICS_DIRECTORY = "model_diagnostics/"
FIGURE_DIRECTORY = "figures/"

## Load the data

In [None]:
# ---------------------------------------------------------------------------
# Observations
# ---------------------------------------------------------------------------
# The observation settings start as a (shallow) copy of the experiment
# settings so the flags toggled below do not leak back into `settings`.
settings_obs = settings.copy()
settings_obs.update({"training_only": True, "obs_training_only": True})

# First call: keep the second and third return values (the observation inputs
# and the global-mean series); the first one is discarded.
__, x_obs, global_mean_obs = data_processing.get_observations(
    OBS_DIRECTORY, settings_obs, verbose=False
)

# Second call, with the training-only and cumulative-history flags switched
# off: keep only the first return value, which is needed for its grid.
settings_obs.update(
    {
        "obs_training_only": False,
        "cumulative_history": False,
        "cumulative_history_only": False,
    }
)
da_obs, __, __ = data_processing.get_observations(
    OBS_DIRECTORY, settings_obs, verbose=False
)

# Build the AR6 land-region mask on the observation grid and remember the
# array shape; the data array itself is then released.
mask = regionmask.defined_regions.ar6.land.mask(da_obs)
da_obs_shape = da_obs.shape
del da_obs

# ---------------------------------------------------------------------------
# CMIP testing data
# ---------------------------------------------------------------------------
tf.keras.utils.set_random_seed(settings["rng_seed"])

# create_data returns 18 values; only the ones bound to real names are used.
# The trailing comments give the positions within the returned tuple.
(
    __, __, x_test,  # 0-2
    __, __, __,  # 3-5
    __, __, onehot_test,  # 6-8
    __, __, y_yrs_test,  # 9-11
    __, __, target_temps_test,  # 12-14
    target_years_region, map_shape, __,  # 15-17
) = data_processing.create_data(DATA_DIRECTORY, settings.copy(), verbose=0)

# ---------------------------------------------------------------------------
# One member per GCM
# ---------------------------------------------------------------------------
# No target region is selected here, and the anomaly bounds are set equal to
# the baseline bounds.
settings_gcm = settings.copy()
settings_gcm.update(
    target_region=None,
    anomaly_yr_bounds=settings_gcm["baseline_yr_bounds"],
)
cmip_data = data_processing.get_one_model_one_vote_data(
    GCM_DATA_DIRECTORY, settings_gcm
)

In [None]:
# raise ValueError()

## Transfer Learning Main Analysis

In [None]:
import network

# Reload the project modules so edits made on disk are picked up without
# restarting the kernel.
imp.reload(data_processing)
imp.reload(plots)
imp.reload(transfer_learning)
imp.reload(network)

IPCC_REGION_LIST = regionmask.defined_regions.ar6.land.abbrevs

# One (n_regions, 2) table per warming threshold, NaN until a region is
# processed. The base and transfer tables start out as independent copies.
obs_transfer_dict = {
    threshold: np.full((len(IPCC_REGION_LIST), 2), np.nan)
    for threshold in (1.5, 2.0, 3.0)
}
obs_base_dict = copy.deepcopy(obs_transfer_dict)

for ireg, ipcc_region in enumerate(IPCC_REGION_LIST):
    # Restrict the run to the regions of interest. Other selections used
    # previously: ("WCE", "CNA", "SAH", "ESB", "NSA", "CAU") and ("WCE",).
    if ipcc_region not in ("CNA",):
        continue

    # ------------------------------------
    # SETTINGS AND MODEL SETUP

    # Point both settings dictionaries at this region's experiment.
    region_exp_name = PARENT_EXP_NAME + "_" + ipcc_region
    region_target = "ipcc_" + ipcc_region

    settings["exp_name"] = region_exp_name
    settings["target_region"] = region_target

    settings_obs["target_region"] = region_target
    settings_obs["exp_name"] = region_exp_name

    settings_obs["obs_training_only"] = False
    settings_obs["cumulative_history"] = False
    settings_obs["cumulative_history_only"] = False
    da_obs, __, __ = data_processing.get_observations(
        OBS_DIRECTORY, settings_obs, verbose=False
    )
    da_obs, __, __ = regions.extract_region(settings_obs, da_obs)

    # Load the same saved model twice: `model` stays untouched as the base
    # network, `transfer_model` is the copy handed to the fine-tuning step.
    tf.keras.backend.clear_session()
    tf.keras.utils.set_random_seed(settings_obs["rng_seed"])

    try:
        model_name = file_methods.get_model_name(settings_obs)
        model = file_methods.load_tf_model(model_name, MODEL_DIRECTORY)
        transfer_model = file_methods.load_tf_model(model_name, MODEL_DIRECTORY)
    except Exception as err:
        # A region whose model cannot be loaded is still skipped, but the
        # reason is reported, and KeyboardInterrupt / SystemExit are no longer
        # swallowed as they were by the former bare `except`.
        print(f"skipping {ipcc_region}: could not load model ({err!r})")
        continue

    print("---" + str(ireg) + ": " + ipcc_region + "---")

    # ------------------------------------
    # FINE-TUNE WITH TRANSFER LEARNING

    (
        transfer_model,
        obs_timeseries,
        obs_years,
        obs_yearsvals_dict,
    ) = transfer_learning.perform_transfer_learning(
        transfer_model, da_obs, x_obs, settings_obs, plot=False
    )

    # Fill this region's row in the base and the transfer prediction tables.
    obs_base_dict = transfer_learning.compute_threshold_predictions(
        obs_base_dict, ireg, model, x_obs
    )
    obs_transfer_dict = transfer_learning.compute_threshold_predictions(
        obs_transfer_dict, ireg, transfer_model, x_obs
    )

    # Plot the transfer learning results against the CMIP data masked to this
    # region.
    # NOTE(review): `mask == ireg` assumes the position in IPCC_REGION_LIST
    # equals the region number stored in the mask -- confirm.
    cmip_masked_region = xr.where(mask == ireg, cmip_data, np.nan).transpose(
        "gcm", "time", "lat", "lon"
    )
    plots.plot_transferlearning_timeseries(
        model,
        transfer_model,
        x_obs,
        obs_timeseries,
        obs_years,
        obs_yearsvals_dict,
        cmip_masked_region,
        settings,
        title=ipcc_region,
    )
    plt.tight_layout()
    plots.savefig(
        FIGURE_DIRECTORY + model_name + "_transfer_learning",
        dpi=savefig_dpi,
    )
    # plt.show()
    plt.close()

    # ------------------------------------
    # CLEAR BIG THINGS
    # `da_obs` is always bound at this point (assigned earlier in this same
    # iteration), so the delete needs no try/except guard.
    del da_obs
    _ = gc.collect()

    # ------------------------------------
    # SAVE THE TRANSFER LEARNING DATA
    # Both tables are re-written after every processed region, so the files
    # on disk always hold the results accumulated so far.
    if RESAVE:
        for kind, predictions in (
            ("base", obs_base_dict),
            ("transfer", obs_transfer_dict),
        ):
            filename = (
                f"{PREDICTIONS_DIRECTORY}{PARENT_EXP_NAME}"
                f"_rng_seed{settings_obs['rng_seed']}"
                f"_observations_{settings['final_year_of_obs']}"
                f"_predictions_{kind}.pickle"
            )
            with open(filename, "wb") as f:
                # Stored as a one-element list, matching the existing format.
                pickle.dump([predictions], f)
print("done.")

In [None]:
# Deliberate hard stop: this cell always raises, presumably so that running
# the notebook top to bottom does not continue into the cells below.
raise ValueError()

## Transfer Learning Schematic

In [None]:
# Reload the project modules so edits made on disk are picked up without
# restarting the kernel.
imp.reload(data_processing)
imp.reload(plots)
imp.reload(transfer_learning)

IPCC_REGION_LIST = regionmask.defined_regions.ar6.land.abbrevs

# Reset the per-threshold prediction tables to NaN. They are not filled in
# this cell, but the reset is kept so the notebook state after running it is
# the same as before this rewrite.
obs_transfer_dict = {
    threshold: np.full((len(IPCC_REGION_LIST), 2), np.nan)
    for threshold in (1.5, 2.0, 3.0)
}
obs_base_dict = copy.deepcopy(obs_transfer_dict)

for ireg, ipcc_region in enumerate(IPCC_REGION_LIST):
    # The schematic is produced for a single region only.
    if ipcc_region not in ("WCE",):
        continue

    # ------------------------------------
    # SETTINGS AND MODEL SETUP

    # Point both settings dictionaries at this region's experiment.
    region_exp_name = PARENT_EXP_NAME + "_" + ipcc_region
    region_target = "ipcc_" + ipcc_region

    settings["exp_name"] = region_exp_name
    settings["target_region"] = region_target

    settings_obs["target_region"] = region_target
    settings_obs["exp_name"] = region_exp_name

    settings_obs["obs_training_only"] = False
    settings_obs["cumulative_history"] = False
    settings_obs["cumulative_history_only"] = False
    da_obs, __, __ = data_processing.get_observations(
        OBS_DIRECTORY, settings_obs, verbose=False
    )
    da_obs, __, __ = regions.extract_region(settings_obs, da_obs)

    # Load the same saved model twice: `model` stays untouched as the base
    # network, `transfer_model` is the copy handed to the fine-tuning step.
    tf.keras.backend.clear_session()
    tf.keras.utils.set_random_seed(settings_obs["rng_seed"])

    try:
        model_name = file_methods.get_model_name(settings_obs)
        model = file_methods.load_tf_model(model_name, MODEL_DIRECTORY)
        transfer_model = file_methods.load_tf_model(model_name, MODEL_DIRECTORY)
    except Exception as err:
        # A region whose model cannot be loaded is still skipped, but the
        # reason is reported, and KeyboardInterrupt / SystemExit are no longer
        # swallowed as they were by the former bare `except`.
        print(f"skipping {ipcc_region}: could not load model ({err!r})")
        continue

    print("---" + str(ireg) + ": " + ipcc_region + "---")

    # ------------------------------------
    # FINE-TUNE WITH TRANSFER LEARNING
    # Same call as in the main analysis, but with plot=True.

    (
        transfer_model,
        obs_timeseries,
        obs_years,
        obs_yearsvals_dict,
    ) = transfer_learning.perform_transfer_learning(
        transfer_model, da_obs, x_obs, settings_obs, plot=True
    )

print("done.")

In [None]:
# Deliberate hard stop: this cell always raises, presumably so that running
# the notebook top to bottom does not continue into the cells below.
raise ValueError()

## Transfer Learning on only 3 temperature thresholds

In [None]:
# Reload the project modules so edits made on disk are picked up without
# restarting the kernel.
imp.reload(data_processing)
imp.reload(plots)
imp.reload(transfer_learning)

# Suffix appended to the figure and pickle names written by this cell, and
# passed to perform_transfer_learning, so the outputs stay distinguishable
# from those of the main analysis.
EXP_SUFFIX = "_3thresholds"
IPCC_REGION_LIST = regionmask.defined_regions.ar6.land.abbrevs

# One (n_regions, 2) table per warming threshold, NaN until a region is
# processed. The base and transfer tables start out as independent copies.
obs_transfer_dict = {
    threshold: np.full((len(IPCC_REGION_LIST), 2), np.nan)
    for threshold in (1.5, 2.0, 3.0)
}
obs_base_dict = copy.deepcopy(obs_transfer_dict)

for ireg, ipcc_region in enumerate(IPCC_REGION_LIST):
    # ------------------------------------
    # SETTINGS AND MODEL SETUP

    # Point both settings dictionaries at this region's experiment and
    # restrict the transfer step to three temperature values.
    region_exp_name = PARENT_EXP_NAME + "_" + ipcc_region
    region_target = "ipcc_" + ipcc_region

    settings["exp_name"] = region_exp_name
    settings["target_region"] = region_target
    settings["transfer_temp_vec"] = (0.8, 1.0, 1.2)

    settings_obs["target_region"] = region_target
    settings_obs["exp_name"] = region_exp_name
    settings_obs["transfer_temp_vec"] = settings["transfer_temp_vec"]

    settings_obs["obs_training_only"] = False
    settings_obs["cumulative_history"] = False
    settings_obs["cumulative_history_only"] = False
    da_obs, __, __ = data_processing.get_observations(
        OBS_DIRECTORY, settings_obs, verbose=False
    )
    da_obs, __, __ = regions.extract_region(settings_obs, da_obs)

    # Load the same saved model twice: `model` stays untouched as the base
    # network, `transfer_model` is the copy handed to the fine-tuning step.
    tf.keras.backend.clear_session()
    tf.keras.utils.set_random_seed(settings_obs["rng_seed"])

    try:
        model_name = file_methods.get_model_name(settings_obs)
        model = file_methods.load_tf_model(model_name, MODEL_DIRECTORY)
        transfer_model = file_methods.load_tf_model(model_name, MODEL_DIRECTORY)
    except Exception as err:
        # A region whose model cannot be loaded is still skipped, but the
        # reason is reported, and KeyboardInterrupt / SystemExit are no longer
        # swallowed as they were by the former bare `except`.
        print(f"skipping {ipcc_region}: could not load model ({err!r})")
        continue

    print("---" + str(ireg) + ": " + ipcc_region + "---")

    # ------------------------------------
    # FINE-TUNE WITH TRANSFER LEARNING

    (
        transfer_model,
        obs_timeseries,
        obs_years,
        obs_yearsvals_dict,
    ) = transfer_learning.perform_transfer_learning(
        transfer_model, da_obs, x_obs, settings_obs, plot=False, suffix=EXP_SUFFIX
    )

    # Fill this region's row in the base and the transfer prediction tables.
    obs_base_dict = transfer_learning.compute_threshold_predictions(
        obs_base_dict, ireg, model, x_obs
    )
    obs_transfer_dict = transfer_learning.compute_threshold_predictions(
        obs_transfer_dict, ireg, transfer_model, x_obs
    )

    # Plot the transfer learning results against the CMIP data masked to this
    # region.
    # NOTE(review): `mask == ireg` assumes the position in IPCC_REGION_LIST
    # equals the region number stored in the mask -- confirm.
    cmip_masked_region = xr.where(mask == ireg, cmip_data, np.nan).transpose(
        "gcm", "time", "lat", "lon"
    )
    plots.plot_transferlearning_timeseries(
        model,
        transfer_model,
        x_obs,
        obs_timeseries,
        obs_years,
        obs_yearsvals_dict,
        cmip_masked_region,
        settings,
        title=ipcc_region,
    )
    plt.tight_layout()
    plots.savefig(
        FIGURE_DIRECTORY + model_name + "_transfer_learning" + EXP_SUFFIX,
        dpi=savefig_dpi,
    )
    # plt.show()
    plt.close()

    # ------------------------------------
    # CLEAR BIG THINGS
    # `da_obs` is always bound at this point (assigned earlier in this same
    # iteration), so the delete needs no try/except guard.
    del da_obs
    _ = gc.collect()

    # ------------------------------------
    # SAVE THE TRANSFER LEARNING DATA
    # Both tables are re-written after every processed region, so the files
    # on disk always hold the results accumulated so far.
    if RESAVE:
        for kind, predictions in (
            ("base", obs_base_dict),
            ("transfer", obs_transfer_dict),
        ):
            filename = (
                f"{PREDICTIONS_DIRECTORY}{PARENT_EXP_NAME}"
                f"_rng_seed{settings_obs['rng_seed']}"
                f"_observations_{settings['final_year_of_obs']}"
                f"_predictions_{kind}{EXP_SUFFIX}.pickle"
            )
            with open(filename, "wb") as f:
                # Stored as a one-element list, matching the existing format.
                pickle.dump([predictions], f)
print("done.")

In [None]:
# Deliberate hard stop: this cell always raises, presumably so that running
# the notebook top to bottom does not continue into the cells below.
raise ValueError()

## XAI with SHAP

In [None]:
# Reload the project modules so edits made on disk are picked up without
# restarting the kernel.
imp.reload(xai)
imp.reload(plots)

# DO NOT TURN ON THE GPU! There are issues with DEEPSHAP, so every GPU device
# is hidden from TensorFlow before any SHAP computation.
tf.config.set_visible_devices([], "GPU")

# Settings handed to xai.calculate_shap_values (data selection and baseline).
xai_settings = {
    "obs_start": 2000,
    "obs_end": settings["final_year_of_obs"],
    "baseline_factor": 0.0,
    "baseline_start": 2000,
    "baseline_end": 2000,
    "n_base_samples": 1,
    "n_cmip_samples": None,
    "target_temp": settings["target_temp"],
    # Only seed 66 is analysed. This key used to appear twice in the literal
    # (first as settings["rng_seed_list"]); the later entry silently won, so
    # only the effective value is kept. Use settings["rng_seed_list"] here to
    # accumulate over every seed listed in the experiment settings.
    "rng_seed_list": (66,),
    # "rng_seed" is filled in per seed inside the loop below.
    "diff_scaling": 0.2,
}

IPCC_REGION_LIST = regionmask.defined_regions.ar6.land.abbrevs

for ireg, ipcc_region in enumerate(IPCC_REGION_LIST):
    # Restrict the analysis to the regions of interest.
    if ipcc_region not in ("WCE", "CNA", "SAH", "ESB", "NSA", "CAU"):
        continue

    print(f"---{ireg}: {ipcc_region}---")

    # Running sums of the SHAP maps over seeds, one per data source. Their
    # shape is the observation array's shape without its leading dimension.
    original_shap_seeds = np.zeros(da_obs_shape[1:])
    transfer_shap_seeds = np.zeros(da_obs_shape[1:])
    cmip_shap_seeds = np.zeros(da_obs_shape[1:])

    for seed in xai_settings["rng_seed_list"]:
        xai_settings["rng_seed"] = seed

        # ------------------------------------
        # SETTINGS AND MODEL SETUP

        # Point both settings dictionaries at this region and seed.
        region_exp_name = PARENT_EXP_NAME + "_" + ipcc_region
        region_target = "ipcc_" + ipcc_region

        settings["exp_name"] = region_exp_name
        settings["target_region"] = region_target
        settings["rng_seed"] = seed

        settings_obs["target_region"] = region_target
        settings_obs["exp_name"] = region_exp_name
        settings_obs["rng_seed"] = seed

        settings_obs["obs_training_only"] = False
        settings_obs["cumulative_history"] = False
        settings_obs["cumulative_history_only"] = False
        da_obs, __, __ = data_processing.get_observations(
            OBS_DIRECTORY, settings_obs, verbose=False
        )
        # Years of the last len(x_obs) time steps of the observations.
        x_obs_years = da_obs["time.year"].to_numpy()[-x_obs.shape[0] :]

        # Load the base model and its separately saved "_transfer" version.
        tf.keras.backend.clear_session()
        tf.keras.utils.set_random_seed(settings_obs["rng_seed"])

        model_name = file_methods.get_model_name(settings_obs)
        model = file_methods.load_tf_model(model_name, MODEL_DIRECTORY)
        print(model_name)

        transfer_model_name = file_methods.get_model_name(settings) + "_transfer"
        transfer_model = file_methods.load_tf_model(
            transfer_model_name, MODEL_DIRECTORY
        )
        print(transfer_model_name)

        (
            original_shap,
            transfer_shap,
            cmip_shap,
            expected_value,
            expected_value_transfer,
        ) = xai.calculate_shap_values(
            model,
            transfer_model,
            x_obs,
            x_obs_years,
            x_test,
            y_yrs_test,
            xai_settings,
            baseline_factor=xai_settings["baseline_factor"],
        )

        # Add this seed's maps to the running sums: take channel 0 of the
        # first SHAP array and average it over its leading (sample) axis.
        original_shap_seeds = original_shap_seeds + original_shap[0][0][
            :, :, :, 0
        ].mean(axis=0)
        transfer_shap_seeds = transfer_shap_seeds + transfer_shap[0][0][
            :, :, :, 0
        ].mean(axis=0)
        cmip_shap_seeds = cmip_shap_seeds + cmip_shap[0][0][:, :, :, 0].mean(axis=0)

    # CREATE THE PLOT
    # NOTE(review): `scaling` uses `expected_value` from the last seed only,
    # while the maps are summed over all seeds. That is harmless with a single
    # seed, but confirm before enabling several.
    #
    # Disabled alternative scaling, kept for reference (normalise by the most
    # negative value equatorward of 60 degrees latitude):
    #   ilat = np.where(np.abs(da_obs.lat) < 60)[0]
    #   scaling = 1.0 / (np.abs(np.min(original_shap_seeds[ilat, :])) * 0.7)
    plots.plot_xai_heatmaps(
        original_shap_seeds,
        transfer_shap_seeds,
        cmip_shap_seeds,
        da_obs.lat,
        da_obs.lon,
        ipcc_region,
        subplots=2,
        scaling=1e3 / np.abs(expected_value[0]),
        diff_scaling=xai_settings["diff_scaling"],
        # title=f"\n{model_name} {xai_settings['target_temp']}C\n{xai_settings['obs_start']}-{xai_settings['obs_end']}",
        title=None,
        colorbar=False,
    )
    plt.tight_layout()
    plots.savefig(
        FIGURE_DIRECTORY + model_name + "_shap_values_2subplots",
        dpi=savefig_dpi,
    )
    plt.show()
    # plt.close()

print("done.")