This notebook is used to:
- compute the change in SHAP values before and after transfer learning;
- plot the SHAP value maps.

Note: the "prob" in the title means that the explained model output f(x) is the committor (a probability).

In [124]:
import xarray as xr
import numpy as np
import cartopy
from cartopy import crs as ccrs
import matplotlib 
matplotlib.rcParams["font.size"] = 12
from matplotlib import pyplot as plt
from os.path import join, exists
from os import mkdir
import scipy
import netCDF4
import sklearn
import sys
import matplotlib.ticker as mticker
from cartopy.mpl.ticker import (LongitudeFormatter, LatitudeFormatter,
                                LatitudeLocator, LongitudeLocator)
%matplotlib inline
import matplotlib.path as mpath
import importlib.util
import MM_util_AI
import MM_utilplot
import warnings
import pickle

warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*")
import shap
import tensorflow as tf
from sklearn.metrics import confusion_matrix,recall_score,precision_score
import matplotlib.colors as colors
import os

def _load_module_from_path(module_name, file_path):
    """Import the Python source file at ``file_path`` as module ``module_name``.

    The module is registered in ``sys.modules`` under ``module_name`` before it
    is executed, then returned.

    Raises:
        ImportError: if no import spec/loader can be built for ``file_path``.
    """
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None or spec.loader is None:
        raise ImportError("Cannot load module %r from %r" % (module_name, file_path))
    module = importlib.util.module_from_spec(spec)
    # Register under the module's real name. The placeholder key "module.name"
    # (from the importlib recipe) made the second module overwrite the first
    # and left neither reachable by its own name.
    sys.modules[module_name] = module
    spec.loader.exec_module(module)
    return module


# Helper modules that live next to the project notebooks (not on sys.path).
MM_dataprepare = _load_module_from_path(
    "MM_dataprepare",
    "/scratch/hz1994/blocking/MMmodel/MMmodel/notebooks/MM_dataprepare.py")
MM_utilblocking = _load_module_from_path(
    "MM_utilblocking",
    "/scratch/hz1994/blocking/MMmodel/MMmodel/notebooks/MM_utilblocking.py")

 
            
# Each line of filepath.txt is "<key>\t<path>". Map every required key to the
# notebook variable it populates.
_filepath_keys = (
    ("dimensionalized_filepath", "dim_path"),
    ("nondimensionalized_filepath", "nondim_path"),
    ("code_filepath", "code_path"),
    ("DGindex_filepath", "DGindex_path"),
    ("conditionedT_filepath", "train_path"),
    ("model_filepath", "models_path"),
    ("fig_filepath", "fig_path"),
)
_paths = {}
with open("/scratch/hz1994/blocking/data_MMmodel/filepath.txt","r") as fi:
    for ln in fi:
        for _key, _var in _filepath_keys:
            if ln.startswith(_key):
                _fields = ln.strip().split('\t')
                if len(_fields) < 2:
                    raise ValueError("filepath.txt: no tab-separated path for key %r" % _key)
                _paths[_var] = _fields[1]

# Fail loudly if an entry is missing; otherwise the variable would be
# undefined, or silently keep a stale value from an earlier kernel run.
_missing = [_key for _key, _var in _filepath_keys if _var not in _paths]
if _missing:
    raise ValueError("filepath.txt is missing required entries: %s" % ", ".join(_missing))

dim_path = _paths["dim_path"]
nondim_path = _paths["nondim_path"]
code_path = _paths["code_path"]
DGindex_path = _paths["DGindex_path"]
train_path = _paths["train_path"]
models_path = _paths["models_path"]
fig_path = _paths["fig_path"]

print(dim_path)
print(nondim_path)
print(code_path)
print(DGindex_path)
print(train_path)
print(models_path)
print(fig_path)

# Sub-directories for the "T" set of training data and models.
train_path_setA=train_path+'T/'
models_path_setA=models_path+'T/'
import tensorflow as tf

# Global figure style: 300 dpi for both on-screen and saved figures.
plt.rcParams.update({
    'figure.dpi': 300,
    'savefig.dpi': 300,
})

# Default font for every figure produced below.
font = dict(family='sans-serif', weight='regular', size=16)
plt.rc('font', **font)

# Slightly thicker axes frame.
plt.rcParams['axes.linewidth'] = 1.5

from sklearn.metrics import confusion_matrix,recall_score,precision_score
def test_model(model,weightpath,test_data,test_classlabels,test_labels):
    """Load weights into ``model`` and score it on a test set.

    Note that ``model`` is modified in place: its weights are replaced by the
    ones stored at ``weightpath``.

    Args:
        model: model exposing ``load_weights`` and ``predict``.
        weightpath: checkpoint path handed to ``model.load_weights``.
        test_data: inputs handed to ``model.predict``.
        test_classlabels: two-column targets used for the cross-entropy loss.
        test_labels: binary (0/1) labels used for the confusion matrix,
            recall and precision.

    Returns:
        Tuple ``(np.array([TN, FP, FN, TP]), recall, precision, loss)``.
    """
    model.load_weights(weightpath)
    predictions = model.predict(test_data)
    # A sample is predicted positive when output column 1 exceeds column 0.
    pred = predictions[:, 0] < predictions[:, 1]

    # NOTE(review): from_logits=True assumes the model emits raw logits -- confirm.
    bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    loss = bce(test_classlabels, predictions).numpy()

    # Pin the label order so the matrix is always 2x2, even when one class is
    # absent from both the labels and the predictions.
    TN, FP, FN, TP = confusion_matrix(test_labels, pred, labels=[0, 1]).ravel()

    # The positive class is chosen with ``pos_label``; ``labels`` expects an
    # array-like and is not the right argument for a binary score.
    recall = recall_score(test_labels, pred, pos_label=1)
    precision = precision_score(test_labels, pred, pos_label=1)
    return np.array([TN, FP, FN, TP]), recall, precision, loss

def polorplot(ax,data_xr,max_abs,levels,iv=0.02):
    """Draw filled contours of ``data_xr`` on the map axis ``ax``.

    Args:
        ax: axis to draw on; gridlines and coastlines are added to it.
        data_xr: array with ``longitude`` and ``latitude`` dimensions.
        max_abs: colour normalisation spans ``[-max_abs, max_abs]`` with
            zero at the centre.
        levels: contour levels passed to the contour plot.
        iv: unused here; kept so existing callers keep working.

    Returns:
        Tuple ``(ax, im)`` where ``im`` is the contour set that was drawn.
    """
    # Diverging colour scale centred on zero.
    zero_centred_norm = colors.TwoSlopeNorm(vcenter=0, vmin=-max_abs, vmax=max_abs)

    contour_options = dict(
        x="longitude",
        y="latitude",
        ax=ax,
        transform=ccrs.PlateCarree(),
        cmap='coolwarm',
        levels=levels,
        add_colorbar=False,
        norm=zero_centred_norm,
    )
    im = xr.plot.contourf(data_xr, **contour_options)

    # Unlabelled gridlines, with latitude circles at fixed positions.
    gridliner = ax.gridlines(draw_labels=False)
    gridliner.ylocator = mticker.FixedLocator([20, 50, 60, 70])
    ax.coastlines()

    return ax, im

def polorplot_levels(plotmap,latitudes,longitudes ,minval,maxval,label="SHAP values", number_levels=30,iv=0.02):
    """Plot the three slices of ``plotmap`` as side-by-side polar maps.

    Args:
        plotmap: array indexed as ``[latitude, longitude, panel]``; panels
            0..2 are titled Z200, Z500 and Z800.
        latitudes, longitudes: coordinate values for the two map dimensions.
        minval, maxval: range of the 10 contour levels; the larger absolute
            value sets the symmetric colour normalisation.
        label: colour-bar label.
        number_levels: currently unused (the level count is fixed at 10).
        iv: spacing of the colour-bar ticks.

    Returns:
        Tuple ``(fig, ax)`` with the figure and its array of three axes.
    """
    panel_titles = ["Z200", "Z500", "Z800"]
    fig, ax = plt.subplots(
        ncols=3,
        figsize=(9, 3),
        subplot_kw={'projection': ccrs.NorthPolarStereo()},
    )
    max_abs = max(abs(minval), abs(maxval))
    print("min plotmap=", plotmap.min() ,"max plotmap=", plotmap.max() ,)

    contour_levels = np.linspace(minval, maxval, 10)
    for panel, panel_title in enumerate(panel_titles):
        panel_field = xr.DataArray(
            plotmap[:, :, panel],
            coords={'latitude': latitudes, 'longitude': longitudes},
            dims=["latitude", "longitude"],
        )
        ax[panel], im = polorplot(ax[panel], panel_field, max_abs=max_abs,
                                  levels=contour_levels, iv=iv)
        ax[panel].set_title(panel_title)

    # One shared horizontal colour bar below the panels:
    # [left, bottom, width, height] in figure coordinates.
    cbar_ax = fig.add_axes([0.05, -0.1, .9, .05])
    tick_first = iv*int(minval/iv)
    tick_end = iv*int(maxval/iv)+iv
    cbar = fig.colorbar(im, cax=cbar_ax, orientation="horizontal", cmap='coolwarm',
                        ticks=np.arange(tick_first, tick_end, iv),
                        label=label, shrink=1)
    print(contour_levels)
    fig.tight_layout()
    return fig, ax

/scratch/hz1994/blocking/data_MMmodel/dim/
/scratch/hz1994/blocking/data_MMmodel/nondim/
/scratch/hz1994/blocking/MMmodel/MMmodel/code_Lucarini/
/scratch/hz1994/blocking/data_MMmodel/DGindex/
/scratch/hz1994/blocking/data_MMmodel/conditionT/
/scratch/hz1994/blocking/data_MMmodel/CNNmodels/
/scratch/hz1994/blocking/data_MMmodel/fig_MMmodel/


# Prepare data

In [5]:
# --- Experiment configuration (these values are used to build checkpoint paths) ---
Duration = 5
cnnsize = "normal"
data_amount = "1000.0k"
random_seed = 30
epsilon = 0
trainable_layer_number = 7
learning_rate = 0.0001
epoch = 2
epoch_TL = 4

# --- Test inputs and labels ---
era5_dir = "/scratch/hz1994/blocking/data_era5/"
X = np.load(era5_dir + "test_data_1940-2022.npy")
Ysparse = np.load(era5_dir + "test_labels_1940-2022_T%d.npy" % Duration)

# Two-column boolean encoding of the labels:
# column 0 is set where the label is 0, column 1 where it is 1 (blocking).
Y = np.zeros((Ysparse.size, 2), dtype=bool)
Y[Ysparse == 1, 1] = True
Y[Ysparse == 0, 0] = True

# Grid coordinates used for plotting.
latitudes = np.load(dim_path + 'dataX_lat.npy')
longitudes = np.load(dim_path + 'dataX_lon.npy')

# Number of values in a single input sample.
X_dim = X[0].size

# Build the model architectures (base model and transfer-learning model)

In [6]:
# NOTE(review): `experimental_run_functions_eagerly` is the older spelling of
# `tf.config.run_functions_eagerly`; kept as-is to match the installed TF version.
tf.config.experimental_run_functions_eagerly(False)

# Name of the MM_util_AI constructor for each supported `cnnsize`.
# Looked up lazily so only the selected constructor has to exist.
_model_builder_names = {
    "smaller": "make_s_model",
    "smaller_smaller": "make_ss_model",
    "normal": "make_model",
}
if cnnsize not in _model_builder_names:
    # Without this check an unknown size would leave base_model / TL_model
    # undefined, or silently stale from an earlier run of this cell.
    raise ValueError("Unknown cnnsize %r; expected one of %s"
                     % (cnnsize, sorted(_model_builder_names)))

_make_cnn = getattr(MM_util_AI, _model_builder_names[cnnsize])
_input_shape = (18, 90, 3)
# Two separate instances with identical architecture.
base_model = _make_cnn(_input_shape)
TL_model = _make_cnn(_input_shape)


# Select the samples to plot: consistent true positives (TP)

In [11]:
# Start with every sample selected, then keep only those that each checkpoint
# classifies as a true positive (predicted positive AND labelled positive).
TP_consistent = np.ones(Y.shape[0], dtype=bool)

for random_seed in range(30,40):
    path = (
        "/scratch/hz1994/blocking/data_MMmodel/CNNmodels/T/era5_retrainCNN/extreme_%ddaysblocking/trained_layer_%s/learning_rate_%.4f/"
        % (Duration, trainable_layer_number, learning_rate)
        + "data_" + data_amount + "_" + cnnsize
        + "_cnn_" + "regularize" + "_%.3e" % epsilon
        + "_rs_%d" % random_seed + "epoch_%d/" % epoch
    )

    for num in range(10):
        # For each run, evaluate checkpoint `epoch_TL` first and checkpoint 0 second.
        for ckpt_epoch in (epoch_TL, 0):
            weightpath = path + "%d/" % num + "cp-%04d.ckpt" % ckpt_epoch
            TL_model.load_weights(weightpath)
            predictions = TL_model.predict(X, verbose=None)
            # Predicted positive when output column 1 exceeds column 0.
            pred_TL = predictions[:, 0] < predictions[:, 1]
            is_true_positive = np.logical_and(pred_TL, Ysparse)
            TP_consistent = np.logical_and(TP_consistent, is_true_positive)

In [None]:
# Collect, for every run, the stored SHAP values at checkpoint 0 (plot1) and
# at checkpoint `epoch_TL` (plot2), restricted to the consistent true positives.
plot1=[]
plot2=[]
for random_seed in range(30,40):
    path="/scratch/hz1994/blocking/data_MMmodel/CNNmodels/T/era5_retrainCNN/extreme_%ddaysblocking/trained_layer_%s/learning_rate_%.4f/"\
            %( Duration,trainable_layer_number,learning_rate)+    "data_"+data_amount+"_"+cnnsize\
            +"_cnn_"+"regularize"+"_%.3e"%epsilon+"_rs_%d"%random_seed+"epoch_%d/"%epoch

    for num in range(10):
        # NOTE(review): the [.., 1, 0] index presumably picks the positive-class
        # output of the saved SHAP array -- confirm against the code that wrote it.
        init_shap=np.load(path+"%d/"%num+"shapvalue-%04d_prob.npy"%0)[TP_consistent,1,0]
        TL_shap=np.load(path+"%d/"%num+"shapvalue-%04d_prob.npy"%epoch_TL )[TP_consistent,1,0]
        plot1.append(init_shap)
        plot2.append(TL_shap)

# `plot1` is a Python list and has no `.mean`; stack it into an array first,
# then average over runs (axis 0) and samples (axis 1).
plot1_mean=np.array(plot1).mean(axis=(0,1))
# Plot the averaged map, not the raw list of per-run arrays.
fig1,ax=polorplot_levels(plot1_mean,latitudes,longitudes   ,minval=-0.001,maxval=0.003,iv=0.0005)

# Choose the non-overfitting ones

In [86]:
# Stack the per-run SHAP arrays once; rebuilding the array inside the loop
# copied the full data set on every iteration.
plot2_arr=np.array(plot2)

# For every run: average over samples (axis 0 of the run), then take the
# extreme values of the resulting mean map. Iterating over the array itself
# covers however many runs were loaded, instead of assuming exactly 100.
a=np.array([run_shap.mean(axis=0).max() for run_shap in plot2_arr])
b=np.array([run_shap.mean(axis=0).min() for run_shap in plot2_arr])

# 0.003 and 0.0016 are 80% percentile of absolute value of max and mins.
non_overfit=np.logical_and(np.abs(a)<0.003, np.abs(b)<0.0016)

# Save data

In [132]:
# Mean SHAP map over the retained (non-overfitting) runs and over samples,
# for checkpoint 0 (plot1) and checkpoint `epoch_TL` (plot2).
plot1_mean = np.asarray(plot1)[non_overfit].mean(axis=(0, 1))
plot2_mean = np.asarray(plot2)[non_overfit].mean(axis=(0, 1))

# Scale each map so that its entries sum to one.
plot1_mean_normalize = plot1_mean / plot1_mean.sum()
plot2_mean_normalize = plot2_mean / plot2_mean.sum()

# Differences of the normalised maps, scaled by the input size:
# plot3 compares only the positive parts, plot3_2 the full signed values.
_positive_after = np.maximum(plot2_mean_normalize, 0)
_positive_before = np.maximum(plot1_mean_normalize, 0)
plot3 = (_positive_after - _positive_before) * X_dim
plot3_2 = (plot2_mean_normalize - plot1_mean_normalize) * X_dim

# Write the results to the current working directory (np.save appends ".npy").
for _out_name, _out_array in (
    ("shap_before_Fine_tuning", plot1_mean),
    ("shap_after_Fine_tuning", plot2_mean),
    ("normalized_shap_difference_max", plot3),
    ("normalized_shap_difference", plot3_2),
):
    np.save(_out_name, _out_array)

