# Analyze trained models

Rader et al. 2023

In [None]:
%matplotlib inline
%load_ext autotime

import palettable.scientific.sequential
import importlib as imp
import warnings
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import experiments
import base_directories
import tensorflow as tf
import build_model
import build_data
import plots
import metrics
import pickle
import save_load_model_run
import silence_tensorflow.auto
from silence_tensorflow import silence_tensorflow
silence_tensorflow()
from shapely.errors import ShapelyDeprecationWarning
warnings.filterwarnings("ignore", category=ShapelyDeprecationWarning)

# Project directory lookup; later cells read its "data_directory" and
# "figure_directory" entries.
dir_settings = base_directories.get_directories()

In [None]:
# Reload the experiments module so edits to it are picked up without
# restarting the kernel.
imp.reload(experiments)
# Experiment identifier; it is also the first component of the saved-model name.
exp_name = "exp000"
settings = experiments.get_experiment(exp_name)
# Override the seed and model type; both are used below to build the name of
# the trained run to load. The trailing list is presumably the set of seeds
# that were trained -- confirm against the experiment definition.
settings["rng_seed"] = 34  # 23, 34, 45
settings["model_type"] = "interp_model"

In [None]:
# Build every data split in a single call: the analog pool, the
# train / validation / test states of interest (SOI), two dictionaries that
# are presumably the input/output standardization statistics, and the
# lat/lon grid coordinates.
(
    analog_input, analog_output,
    soi_train_input, soi_train_output,
    soi_val_input, soi_val_output,
    soi_test_input, soi_test_output,
    input_standard_dict, output_standard_dict,
    lat, lon,
) = build_data.build_data(settings, dir_settings["data_directory"])

# Sanity check: shape of the test SOI inputs, then of the analog pool.
print(soi_test_input.shape, analog_input.shape, sep="\n")

## Explore masks

In [None]:
# LOAD THE TRAINED MODEL
# Reset Keras' global state before loading.
tf.keras.backend.clear_session()

# Saved runs are named "<experiment>_<model type>_rng_seed_<seed>".
savename_prefix = f"{exp_name}_{settings['model_type']}_rng_seed_{settings['rng_seed']}"
settings["savename_prefix"] = savename_prefix

model = save_load_model_run.load_model(
    settings,
    settings["savename_prefix"],
    [soi_train_input, analog_input],
)

# Pull the three named sub-models out of the loaded model so each stage can
# be inspected on its own in the cells below.
mask_model, dissimilarity_model, prediction_model = (
    model.get_layer(layer_name)
    for layer_name in ('mask_model', 'dissimilarity_model', 'prediction_model')
)

In [None]:
# Draw one batch (batch_size=1_000) from the training generator, then close
# the generator since no further batches are needed.
# NOTE(review): x_input_train is indexed with [0] and [1] in later cells, so it
# is presumably the (SOI input, analog input) pair -- confirm against
# build_data.batch_generator.
gen = build_data.batch_generator(settings, soi_train_input, soi_train_output,
                                  analog_input, analog_output, batch_size=1_000, rng_seed=settings["rng_seed"])
x_input_train, target_train = next(gen)
gen.close()

In [None]:
imp.reload(plots)

# Run the sampled training batch through the three sub-models and keep the
# intermediate results: mask weights, dissimilarities and predictions.
(weights_train,
 dissimilarities_train,
 prediction_train,
 ) = build_model.parse_model([x_input_train[0], x_input_train[1]], mask_model, dissimilarity_model, prediction_model)

# PLOT THE MASKS AND THEIR VARIANCE
fig = plt.figure(figsize=(15, 5*2))

# Top row: mask averaged over the training batch, one panel per channel.
# Channel 1 reuses the colour limits returned for channel 0 so the two panels
# are directly comparable.
ax1, climits = plots.plot_interp_masks(fig, settings, weights_train[:, :, :, 0].mean(axis=0), lat=lat, lon=lon, central_longitude=215., title_text="(a) Mask for Channel 0", subplot=(2, 2, 1))
# plot_interp_masks returns an (axis, climits) pair; unpack it like the other
# three calls so ax2 is the axis itself rather than the whole tuple.
ax2, __ = plots.plot_interp_masks(fig, settings, weights_train[:, :, :, 1].mean(axis=0), lat=lat, lon=lon, central_longitude=215., climits=climits, title_text="(b) Mask for Channel 1", subplot=(2, 2, 2))

# Bottom row: variance of the mask across the same training batch (not
# validation data), again sharing colour limits between the two channels.
ax3, climits = plots.plot_interp_masks(fig, settings, weights_train[:, :, :, 0].var(axis=0), lat=lat, lon=lon,
                                       central_longitude=215., title_text="(c) Variance of Mask for Channel 0",
                                       subplot=(2, 2, 3), )
ax4, climits = plots.plot_interp_masks(fig, settings, weights_train[:, :, :, 1].var(axis=0), lat=lat, lon=lon,
                                       central_longitude=215., climits=climits, title_text="(d) Variance of Mask for Channel 1",
                                       subplot=(2, 2, 4), )
plt.tight_layout()
plt.show()

In [None]:
# Total mask weight per sample (summed over the two spatial axes and the
# channel axis), reported raw and normalized by the number of grid points.
mask_sums = np.sum(weights_train[:, :, :, :], axis=(1, 2, 3))
n_gridpoints = len(lat) * len(lon)
print('UNIQUE MASK SUMS = ' + str(np.unique(mask_sums).round(1)))
print('UNIQUE MASK NORMALIZED SUMS = ' + str(np.unique(mask_sums / n_gridpoints).round(4)))

## Explore predictions

In [None]:
import gc
# Full-model predictions for the sampled training batch.
y_predict = model.predict([x_input_train[0], x_input_train[1]])
# Force a garbage-collection pass after predict; the returned count is discarded.
_ = gc.collect()

# Histogram of the predicted values.
plt.figure()
plt.hist(y_predict)
plt.title('histogram of training predictions')
plt.show()

In [None]:
# Range of dissimilarities seen in the sampled training batch.
dissim_min = np.min(dissimilarities_train)
dissim_max = np.max(dissimilarities_train)

# Probe the prediction sub-model on an evenly spaced sweep of dissimilarity
# values running from 0 to 10% beyond the largest training value.
synthetic_similarities = np.arange(0, dissim_max * 1.1, .01)
prediction = prediction_model([synthetic_similarities])

plt.figure(figsize=(6, 3))
plt.plot(synthetic_similarities, prediction, '.')
plt.xlabel('input dissimilarity values')
plt.ylabel('output prediction')
rounded_range = (dissim_min.round(3), dissim_max.round(3))
plt.title('training range: ' + str(rounded_range))
plt.show()

print((dissim_min, dissim_max))

## Metrics

In [None]:
# error('here')

In [None]:
# # MAKE SUMMARY PLOT ACROSS ALL MODEL TYPES
# rng_string = settings["savename_prefix"][settings["savename_prefix"].find('rng'):]
#
# plt.figure(figsize=(8, 4 * 3))
# for i_rng, rng_string in enumerate(("rng_seed_" + str(settings["rng_seed_list"][0]),
#                                     "rng_seed_" + str(settings["rng_seed_list"][1]),
#                                     "rng_seed_" + str(settings["rng_seed_list"][2]),
#                                     )):
#     # GET THE METRICS DATA
#     with open(dir_settings["metrics_directory"] + settings[
#         "exp_name"] + "_interp_model_" + rng_string + '_metrics.pickle', 'rb') as f:
#         plot_metrics = pickle.load(f)
#     with open(dir_settings["metrics_directory"] + settings[
#         "exp_name"] + '_ann_model_' + rng_string + '_metrics.pickle', 'rb') as f:
#         ann_metrics = pickle.load(f)
#     with open(dir_settings["metrics_directory"] + settings[
#         "exp_name"] + '_ann_analog_model_' + rng_string + '_metrics.pickle', 'rb') as f:
#         ann_analog_metrics = pickle.load(f)
#
#     # PLOT THE METRICS
#     plt.subplot(3, 1, i_rng + 1)
#
#     plots.summarize_skill_score(plot_metrics)
#
#     plot_ann_metrics = ann_metrics
#     y_plot = 1. - metrics.eval_function(plot_ann_metrics["error_network"]) / metrics.eval_function(
#         plot_ann_metrics["error_climo"])
#     plt.axhline(y=y_plot, linestyle='--', color="teal", alpha=.8, label="vanilla ann")
#
#     plot_ann_metrics = ann_analog_metrics
#     y_plot = 1. - metrics.eval_function(plot_ann_metrics["error_network"]) / metrics.eval_function(
#         plot_ann_metrics["error_climo"])
#     plt.plot(plot_ann_metrics["analogue_vector"], y_plot, '-', color="teal", alpha=.8, label="ann analogue")
#
#     plt.text(0.0, .99, ' ' + settings["exp_name"] + "_interp_model_" + rng_string + '\n smooth_time: ['
#              + str(settings["smooth_len_input"]) + ', ' + str(settings["smooth_len_output"]) + '], leadtime: '
#              + str(settings["lead_time"]),
#              fontsize=6, color="gray", va="top", ha="left", fontfamily="monospace",
#              transform=plt.gca().transAxes)
#     plt.grid(False)
#     plt.ylim(-.4, .4)
#     plt.legend(fontsize=6, loc=4)
#
#     plt.tight_layout()
#     plt.savefig(dir_settings["figure_directory"] + 'metric_summaries/' + settings["exp_name"]
#                 + "multiple_rng" + '_skill_score_vs_nanalogues.png',
#                 dpi=300, bbox_inches='tight')
# plt.show()

## Explore Case Studies

In [None]:
# Deliberate stop: presumably here so that "Run All" does not continue into
# the slower case-study cells below. Raise explicitly instead of relying on
# the NameError produced by calling an undefined name.
raise RuntimeError('here')

In [None]:
imp.reload(plots)
import gc

CMAP = "RdBu_r"

# Number of lowest-output analogs whose targets are averaged into the prediction.
n_analogues = 15
# Figure rows: one for the masks, one for the SOI, and five analog rows.
n_rows = 5 + 2#n_analogues + 2
n_testing_soi = soi_test_input.shape[0]
n_testing_analogs = analog_input.shape[0]

# Shuffle the test SOIs and the analog pool reproducibly. All samples are
# kept (the draw size equals the population size), so this is a permutation.
rng_eval = np.random.default_rng(settings["rng_seed"])
i_soi = rng_eval.choice(np.arange(0, soi_test_input.shape[0]), n_testing_soi, replace=False)
i_analog = rng_eval.choice(np.arange(0, analog_input.shape[0]), n_testing_analogs, replace=False)

x_input_test = [soi_test_input[i_soi,:,:,:], analog_input[i_analog,:,:,:]]
x_output_test = [soi_test_output[i_soi], analog_output[i_analog]]

# PLOT CASE STUDIES: for a few hand-picked SOI samples, show the mean mask,
# the SOI maps and the first analog maps, once with the thresholded mask
# applied and once without.
for sample in (13, 44, 201):
    # Pair this single SOI map with every analog by broadcasting it to the
    # shape of the analog array, then score all pairs with the full model.
    prediction_test = model.predict(
        [np.broadcast_to(x_input_test[0][sample:sample+1], x_input_test[1].shape),
         x_input_test[1]],
        batch_size=10_000,
    )
    __ = gc.collect()

    # Indices of the n_analogues analogs with the smallest model output.
    min_index = np.concatenate(np.argsort(prediction_test, axis=0))[:n_analogues]
    y_truth = str(np.round(x_output_test[0][sample],3))
    # Analog forecast: mean target over the selected analogs.
    y_predict = str(np.round(x_output_test[1][min_index].mean(),3))
    print(y_truth, y_predict)

    for masked in (True, False):
        fig = plt.figure(figsize=(15, 4.5*n_rows), dpi=100)
        if masked:
            # Binary mask: keep grid points whose mean weight reaches at least
            # 15% of the largest mean weight.
            mask = weights_train.mean(axis=0)
            mask = np.where(mask < np.max(mask)*.15, 0., 1.)
            mask_channel0 = mask[:,:,0]
            mask_channel1 = mask[:,:,1]
        else:
            # No masking: multiply every map by ones.
            mask_channel0 = np.ones((weights_train.shape[1],weights_train.shape[2]))
            mask_channel1 = np.ones((weights_train.shape[1],weights_train.shape[2]))

        # Row 1: mean mask per channel (channel 1 reuses channel 0's colour limits).
        __, climits = plots.plot_interp_masks(fig, settings, mask_channel0*weights_train[:,:,:,0].mean(axis=0), lat=lat, lon=lon, central_longitude=215.,
                                              title_text="(a) Mask Channel 0", subplot=(n_rows, 2, 1))
        __, climits = plots.plot_interp_masks(fig, settings, mask_channel1*weights_train[:,:,:,1].mean(axis=0), lat=lat, lon=lon, central_longitude=215.,
                                              climits=climits, title_text="(b) Mask Channel 1", subplot=(n_rows, 2, 2))

        # Row 2: the SOI itself, on fixed (-3, 3) colour limits that are then
        # carried through to every analog panel.
        __, climits = plots.plot_interp_masks(fig, settings, mask_channel0*x_input_test[0][sample,:,:,0], lat=lat, lon=lon, central_longitude=215.,
                                              climits=(-3,3), title_text=f"(a) SOI #{sample}; Channel 0; Truth={y_truth}, Predicted={y_predict}", subplot=(n_rows, 2, 3), cmap=CMAP)
        __, climits = plots.plot_interp_masks(fig, settings, mask_channel1*x_input_test[0][sample,:,:,1], lat=lat, lon=lon, central_longitude=215.,
                                              climits=climits, title_text=f"(b) SOI #{sample}; Channel 1", subplot=(n_rows, 2, 4), cmap=CMAP)

        # Remaining rows: the first (n_rows - 2) selected analogs. The loop
        # variable has its own name so that it does not overwrite the shuffled
        # index array `i_analog` defined above.
        for i_row in range(n_rows - 2):
            analog_id = min_index[i_row]
            __, climits = plots.plot_interp_masks(fig, settings, mask_channel0*x_input_test[1][analog_id,:,:,0], lat=lat, lon=lon, central_longitude=215.,
                                                  climits=climits, title_text=f"(a) Analog #{analog_id}; Channel 0", subplot=(n_rows, 2, 5+i_row*2), cmap=CMAP)
            __, climits = plots.plot_interp_masks(fig, settings, mask_channel1*x_input_test[1][analog_id,:,:,1], lat=lat, lon=lon, central_longitude=215.,
                                                  climits=climits, title_text=f"(b) Analog #{analog_id}; Channel 1", subplot=(n_rows, 2, 6+i_row*2), cmap=CMAP)

        plt.tight_layout()
        plt.savefig(dir_settings["figure_directory"] + 'case_studies/' + settings["savename_prefix"]
                    + '_casestudymaps_mask' + str(masked) + '_sample' + str(sample) + '.png', dpi=300, bbox_inches='tight')
        # plt.show()
        # Close the figure after saving so figures do not accumulate across iterations.
        plt.close()


## XAI

In [None]:
# LOAD THE TRAINED MODEL
# This rebinds `model` and settings["savename_prefix"] to the "ann_model" run
# for the same experiment and seed; later cells see these new values.
# NOTE(review): settings["model_type"] is not updated here and still holds the
# value set near the top of the notebook -- confirm load_model does not
# depend on it.
tf.keras.backend.clear_session()

savename_prefix = f"{exp_name}_ann_model_rng_seed_{settings['rng_seed']}"
settings["savename_prefix"] = savename_prefix

model = save_load_model_run.load_model(
    settings,
    settings["savename_prefix"],
    [soi_train_input, analog_input],
)

In [None]:
import xai
imp.reload(xai)
#---------------------------------------
# Gradient x Input
#---------------------------------------
# Compute gradient * input for every test SOI sample, take the absolute
# value, and plot the mean over samples for each channel. The product is left
# in the shape of the input (the reshape below stays disabled).
top_pred_idx = 0
soi_input = soi_test_input
soi_output = soi_test_output

grads = xai.get_gradients(model,soi_input,top_pred_idx).numpy()
grad_x_input = grads * soi_input
# grad_x_input = grad_x_input.reshape((soi_input.shape[0],soi_input.shape[1],soi_input.shape[2]))
print(np.shape(grad_x_input))

# Only the magnitude is plotted; the sign is discarded before averaging.
grad_x_input = np.abs(grad_x_input)

# NOTE(review): n_rows comes from the case-study cell, so that cell must have
# been run first; only the top row of the (n_rows x 2) grid is used here.
fig = plt.figure(figsize=(15, 4.5*n_rows), dpi=100)
__, climits = plots.plot_interp_masks(fig, settings, grad_x_input[:,:,:,0].mean(axis=0), lat=lat, lon=lon, central_longitude=215., title_text="(a) XAI Channel 0", subplot=(n_rows, 2, 1),climits=(0, .0075))

# The second panel is labelled "(b)", matching the (a)/(b) channel labelling
# used by the other figures in this notebook; it reuses panel (a)'s colour limits.
__, climits = plots.plot_interp_masks(fig, settings, grad_x_input[:,:,:,1].mean(axis=0), lat=lat, lon=lon, central_longitude=215., title_text="(b) XAI Channel 1", subplot=(n_rows, 2, 2),climits=climits)

plt.tight_layout()
plt.savefig(dir_settings["figure_directory"] + 'case_studies/' + settings["savename_prefix"]
            + '_xai_mean.png', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# import model_diagnostics
# import metrics
#
# imp.reload(model_diagnostics)
# imp.reload(plots)
# imp.reload(metrics)
# rng_eval = np.random.default_rng(settings["rng_seed"])
#
# for n_testing_analogs in (1000,):
#     for n_testing_soi in (250,):
#         print('---' + str((n_testing_soi, n_testing_analogs)) + '---')
#
#         i_soi = rng_eval.choice(np.arange(0, soi_test_input.shape[0]), n_testing_soi, replace=False)
#         i_analog = rng_eval.choice(np.arange(0, analog_train_input.shape[0]), n_testing_analogs, replace=False)
#
#         metrics_dict = model_diagnostics.assess_metrics(settings, model,
#                                                         soi_test_input[i_soi, :, :, :],
#                                                         soi_test_output[i_soi],
#                                                         analog_train_input[i_analog, :, :, :],
#                                                         analog_train_output[i_analog],
#                                                         lat, lon,
#                                                         mask=np.mean(weights_train, axis=0)[np.newaxis, :, :, :],
#                                                         analogue_vector=[1, 2, 5, 10, 15, 20, 30, 50, 75],
#                                                         show_figure=True,
#                                                         save_figure=False,
#                                                         )
#
#         with open(dir_settings["metrics_directory"]+settings["savename_prefix"]+'_metrics_testing.pickle', 'wb') as f:
#             pickle.dump(metrics_dict, f, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# Evaluate metrics with the `model_diagnostics_orig` implementation on the
# first 500 training SOIs against the first 500 analogs.
import model_diagnostics_orig
import metrics
import importlib as imp
# NOTE(review): timebudget is imported but not used anywhere in this cell.
from timebudget import timebudget

# Reload the project modules so local edits take effect without a kernel restart.
imp.reload(model_diagnostics_orig)
imp.reload(plots)
imp.reload(metrics)

# metrics_dict = model_diagnostics.assess_metrics(settings, model,
#                                                 soi_train_input[:200,:,:,:],
#                                                 soi_train_output[:200],
#                                                 analog_input[:1000,:,:,:],
#                                                 analog_output[:1000],
#                                                 lat, lon,
#                                                 mask=np.mean(weights_train, axis=0)[np.newaxis, :, :, :],
#                                                 analogue_vector=[15,],
#                                                 show_figure=False,
#                                                 save_figure=False,
#                                                 )
# Active call. The mask is the batch-mean of weights_train with a leading
# axis of length 1 re-added.
# NOTE(review): if the cells are run top to bottom, `model` was last rebound
# in the XAI section, while weights_train was computed from the sub-models of
# the earlier model -- confirm this pairing is intended.
metrics_dict = model_diagnostics_orig.assess_metrics(settings, model,
                                                soi_train_input[:500,:,:,:],
                                                soi_train_output[:500],
                                                analog_input[:500,:,:,:],
                                                analog_output[:500],
                                                lat, lon,
                                                mask=np.mean(weights_train, axis=0)[np.newaxis, :, :, :],
                                                analogue_vector=[15,],
                                                show_figure=False,
                                                save_figure=False,
                                                )
# metrics_dict = model_diagnostics.assess_metrics(settings, model,
#                                                 soi_train_input[:200,:,:,0:1],
#                                                 soi_train_output[:200],
#                                                 analog_input[:1000,:,:,0:1],
#                                                 analog_output[:1000],
#                                                 lat, lon,
#                                                 mask=np.ones(shape=(1,96,192,2)),
#                                                 analogue_vector=[15,],
#                                                 show_figure=False,
#                                                 save_figure=False,
#                                                 )

        # with open(dir_settings["metrics_directory"]+settings["savename_prefix"]+'_metrics_testing.pickle', 'wb') as f:
        #     pickle.dump(metrics_dict, f, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
import model_diagnostics
import metrics
import importlib as imp
from timebudget import timebudget

# Reload the project modules so local edits take effect without a kernel restart.
imp.reload(model_diagnostics)
imp.reload(plots)
imp.reload(metrics)

# Same evaluation as the previous cell, but using the current
# `model_diagnostics` implementation: the first 500 training SOIs scored
# against the first 500 analogs, with the batch-mean mask (leading axis of
# length 1 retained).
metrics_dict = model_diagnostics.assess_metrics(
    settings,
    model,
    soi_train_input[:500, :, :, :],
    soi_train_output[:500],
    analog_input[:500, :, :, :],
    analog_output[:500],
    lat,
    lon,
    mask=np.mean(weights_train, axis=0, keepdims=True),
    analogue_vector=[15],
    show_figure=False,
    save_figure=False,
)