In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import scipy.stats as st
import shap
import torch
import warnings
from torch import nn
from utils import generate_data
from synthetic_dataset import SyntheticDataset
from models import BaselineNet, MDN_Linear, MDN_Conv, SSN
from metadatanorm import MetadataNorm

In [None]:
# Load the accuracy histories of the SSN with one-dimensional orthogonalisation
# (trained for 500 epochs, repeated 10 times) for batch sizes 200, 1000 and 2000.
# NOTE(review): the batch-size-2000 histories come from 'experiments/plot/SSN' while the other
# two come from 'experiments/same_test/cf1/SSN' - confirm that this difference is intended.
run_ids = range(1, 11)
SSN_200_train = [np.load('experiments/same_test/cf1/SSN/batch_size200/run{}/d2000.npy'.format(run)) for run in run_ids]
SSN_200_val = [np.load('experiments/same_test/cf1/SSN/batch_size200/run{}/val_acc_d2000.npy'.format(run)) for run in run_ids]
SSN_1000_train = [np.load('experiments/same_test/cf1/SSN/batch_size1000/run{}/d2000.npy'.format(run)) for run in run_ids]
SSN_1000_val = [np.load('experiments/same_test/cf1/SSN/batch_size1000/run{}/val_acc_d2000.npy'.format(run)) for run in run_ids]
SSN_2000_train = [np.load('experiments/plot/SSN/batch_size2000/run{}/d2000.npy'.format(run)) for run in run_ids]
SSN_2000_val = [np.load('experiments/plot/SSN/batch_size2000/run{}/val_acc_d2000.npy'.format(run)) for run in run_ids]

# Transpose so that every run becomes one column (seaborn then draws one line per run).
df_SSN_200_train, df_SSN_200_val = (pd.DataFrame(runs).T for runs in (SSN_200_train, SSN_200_val))
df_SSN_1000_train, df_SSN_1000_val = (pd.DataFrame(runs).T for runs in (SSN_1000_train, SSN_1000_val))
df_SSN_2000_train, df_SSN_2000_val = (pd.DataFrame(runs).T for runs in (SSN_2000_train, SSN_2000_val))

# 2x3 grid: top row = training accuracy, bottom row = validation accuracy,
# one column per batch size. Panels are listed in the row-major order of `ax.flat`.
fig, ax = plt.subplots(2, 3, sharey=True, figsize=(15, 8))
panels = [
    (df_SSN_200_train, 'Training Accuracy (Batch Size 200)'),
    (df_SSN_1000_train, 'Training Accuracy (Batch Size 1000)'),
    (df_SSN_2000_train, 'Training Accuracy (Batch Size 2000)'),
    (df_SSN_200_val, 'Validation Accuracy (Batch Size 200)'),
    (df_SSN_1000_val, 'Validation Accuracy (Batch Size 1000)'),
    (df_SSN_2000_val, 'Validation Accuracy (Batch Size 2000)'),
]
for axis, (history, _) in zip(ax.flat, panels):
    sns.lineplot(ax=axis, data=history, palette="tab10", linewidth=0.5, legend=False)
# Dashed red reference line at 5/6 * 100 % accuracy on every panel.
for axis in ax.flat:
    axis.axhline(y=5/6*100, color='r', linestyle='--')
for axis in ax.flat:
    axis.set_ylim(50, 100)
# Only the bottom row gets x labels and only the left column gets y labels.
for axis in ax[1]:
    axis.set_xlabel('Epochs')
for axis in ax[:, 0]:
    axis.set_ylabel('Accuracy (%)')
for axis, (_, title) in zip(ax.flat, panels):
    axis.set_title(title)
plt.savefig('SSN_histories.pdf', dpi=1200)

In [None]:
# Collect the per-run test accuracies from the skmetrics.txt / skmetrics_f.txt files into the
# `results` dictionary (the mean and 95% confidence interval are computed in a later cell).
# The value of interest is the third whitespace-separated token on selected lines:
#   indices[0] -> 0-based line numbers 8, 17, ..., 89 of each skmetrics.txt   (keys ending in _A)
#   indices[1] -> 0-based line numbers 7, 16, ..., 88 of each skmetrics_f.txt (keys ending in _B)
indices = [np.arange(8, 90, 9), np.arange(7, 89, 9)]

# Result-key prefix -> experiment sub-directory below 'experiments/final'.
model_dirs = {'Baseline': 'Baseline', 'Linear': 'Linear', 'Conv': 'Conv',
              'SSN1': 'cf1/SSN', 'SSN2': 'cf2/SSN'}
batch_sizes = (200, 1000, 2000)

# All skmetrics.txt files first, then all skmetrics_f.txt files (model-major, batch-size-minor).
files = [f'experiments/final/{model_dir}/batch_size{batch_size}/{metrics_name}'
         for metrics_name in ('skmetrics.txt', 'skmetrics_f.txt')
         for model_dir in model_dirs.values()
         for batch_size in batch_sizes]

# Interleave the two halves so that every skmetrics.txt is directly followed by the
# skmetrics_f.txt of the same configuration, matching the _A/_B order of the result keys.
half = len(files)//2
files_new = [path for pair in zip(files[:half], files[half:]) for path in pair]

results = {f'{model}_{batch_size}_{split}': []
           for model in model_dirs
           for batch_size in batch_sizes
           for split in ('A', 'B')}
assert len(results) == len(files_new), "every result key needs exactly one metrics file"

for position, (key, path) in enumerate(zip(results, files_new)):
    # Even positions are skmetrics.txt files, odd positions are skmetrics_f.txt files.
    wanted_lines = set(indices[position % 2].tolist())
    # The context manager closes the file again; the original left every handle open.
    with open(path, 'r') as metrics_file:
        for line_number, line in enumerate(metrics_file):
            if line_number in wanted_lines:
                results[key].append(float(line.split()[2]))

In [None]:
# Append the PMDN test accuracies to `results`. These numbers were obtained by running the
# PMDN train.py file from the command line and were copied in by hand.
# Key pattern: PMDN_<batch size>_<A|B>, one list entry per run.
results.update({
    'PMDN_200_A': [0.8885, 0.814, 0.6105, 0.9065, 0.9215, 0.881, 0.827, 0.8335, 0.9095, 0.6025],
    'PMDN_200_B': [0.7205, 0.728, 0.553, 0.7085, 0.626, 0.7725, 0.6815, 0.7555, 0.687, 0.573],
    'PMDN_1000_A': [0.7945, 0.789, 0.945, 0.7855, 0.897, 0.8275, 0.9215, 0.8255, 0.756, 0.5845],
    'PMDN_1000_B': [0.624, 0.645, 0.488, 0.602, 0.6335, 0.6135, 0.6135, 0.6055, 0.5785, 0.571],
    'PMDN_2000_A': [0.8155, 0.8585, 0.9385, 0.9415, 0.6095, 0.937, 0.844, 0.782, 0.8225, 0.627],
    'PMDN_2000_B': [0.478, 0.5025, 0.427, 0.455, 0.5165, 0.401, 0.4595, 0.5255, 0.625, 0.51],
})

In [None]:
# Summarise the test accuracies in `results`: one row per configuration with the mean over runs
# ('mean') and the half-width of a normal-approximation 95% confidence interval ('conf').
df = pd.DataFrame.from_dict(results)
n_runs = len(df)  # rows of `df` = runs per configuration

df_mean = df.mean().to_frame(name='mean')
df_std = df.std()  # pandas default: sample standard deviation (ddof=1)
# Half-width = 1.96 * standard error of the mean.
# NOTE(review): 1.96 is the normal quantile; if each configuration only has about 10 runs, the
# Student-t quantile (about 2.262 for 9 degrees of freedom) would give a wider interval - confirm
# which one is intended.
df_conf = (1.96 * (df_std / np.sqrt(n_runs))).to_frame(name='conf')

# Mean and confidence half-width side by side; the bare last expression displays the table.
df_mean_conf = pd.concat([df_mean, df_conf], axis=1)
df_mean_conf

In [None]:
# Generate training and test data for estimation of SHAP values
# The second GPU ("cuda:1") was hard-coded originally, which fails on machines with fewer than
# two GPUs. It is still preferred when present; otherwise fall back to the first GPU or the CPU.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
seed = 1234
N = 1000
# Training data uses `seed`, held-out data uses `seed + 1`.
labels, cf, _, _, x, y = generate_data(N, seed=seed)
labels_val, cf_val, _, _, x_val, y_val = generate_data(N, seed=seed+1)
x = np.swapaxes(x, 1, 3) # move channels after batch so we have (N, channels, h, w)
x_val = np.swapaxes(x_val, 1, 3)
trainset_size = 2 * N
# Design matrix with one row per training sample and the columns [label, confounder, 1].
# The row count 2 * N presumes generate_data(N, ...) returns 2 * N samples - TODO confirm.
X = np.column_stack((labels, cf, np.ones(2 * N)))
XTX = X.T @ X
# (X^T X)^-1, handed to the MDN_Linear / MDN_Conv constructors as a fixed (non-trainable) parameter.
kernel = np.linalg.inv(XTX)
cf_kernel = nn.Parameter(torch.tensor(kernel).float().to(device), requires_grad=False)

In [None]:
warnings.filterwarnings("ignore")
# Batch size 200
# Accumulates, per model family, the mean absolute SHAP value inside the image region
# [:, :16, 16:] of every explained test sample, summed over samples and over the checkpoints
# run1..run5. The totals are normalised in the following print cell.
batch_size = 200
n_runs = 5

# The loaders shuffle, so one batch is a random draw. Dedicated seeded generators make the
# draw reproducible after a kernel restart without touching the global torch RNG.
train_set = SyntheticDataset(x, labels, cf)
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=batch_size, shuffle=True, pin_memory=True,
    generator=torch.Generator().manual_seed(seed))
val_set = SyntheticDataset(x_val, labels_val, cf_val)
val_loader = torch.utils.data.DataLoader(
    val_set,
    batch_size=batch_size, shuffle=True, pin_memory=True,
    generator=torch.Generator().manual_seed(seed + 1))

# One training batch serves as SHAP background, one validation batch is explained.
train_batch = next(iter(train_loader))
train_data = train_batch['image'].float().to(device)
train_target = train_batch['label'].float().to(device)
train_cf_batch = train_batch['cfs'].float().to(device)
test_batch = next(iter(val_loader))
test_data = test_batch['image'].float().to(device)
test_target = test_batch['label'].float().to(device)
test_cf_batch = test_batch['cfs'].float().to(device)

# Confounder matrices are identical for every run, so they are built once up front.
# MDN layers: columns [label, confounder, 1].
X_batch = np.column_stack((train_target.cpu().detach().numpy(),
                           train_cf_batch.cpu().detach().numpy(),
                           np.ones((batch_size,))))
X_batch_test = np.column_stack((test_target.cpu().detach().numpy(),
                                test_cf_batch.cpu().detach().numpy(),
                                np.ones((batch_size,))))
mdn_cfs_background = torch.tensor(X_batch, dtype=torch.float32, device=device)
mdn_cfs_test = torch.tensor(X_batch_test, dtype=torch.float32, device=device)
# SSN with cf_dim=1: the confounder as a single column. The batches are already float tensors
# on `device`, so no re-wrapping in torch.Tensor(...) is needed.
ssn1_cfs_background = train_cf_batch[:, None]
ssn1_cfs_test = test_cf_batch[:, None]
# SSN with cf_dim=2: columns [1, confounder].
ssn2_cfs_background = torch.cat((torch.ones_like(ssn1_cfs_background), ssn1_cfs_background), dim=1)
ssn2_cfs_test = torch.cat((torch.ones_like(ssn1_cfs_test), ssn1_cfs_test), dim=1)

# (result key, checkpoint sub-directory, model factory, background cfs, test cfs)
model_specs = [
    ('Baseline', 'Baseline', lambda: BaselineNet(), None, None),
    ('Linear', 'Linear', lambda: MDN_Linear(2*N, batch_size, cf_kernel), mdn_cfs_background, mdn_cfs_test),
    ('Conv', 'Conv', lambda: MDN_Conv(2*N, batch_size, cf_kernel), mdn_cfs_background, mdn_cfs_test),
    ('SSN1', 'cf1/SSN', lambda: SSN(batch_size, cf_dim=1, num_features=32), ssn1_cfs_background, ssn1_cfs_test),
    ('SSN2', 'cf2/SSN', lambda: SSN(batch_size, cf_dim=2, num_features=32), ssn2_cfs_background, ssn2_cfs_test),
]

shap_totals = {}
for family, checkpoint_dir, build_model, cfs_background, cfs_test in model_specs:
    family_total = 0
    for run in range(1, n_runs + 1):
        model = build_model().to(device)
        # map_location lets checkpoints saved on another device load onto `device`.
        # torch.load unpickles the file - only load checkpoints from a trusted source.
        state_dict = torch.load(
            f'plot_data/experiments/final/{checkpoint_dir}/batch_size{batch_size}/run{run}/best_model.pth',
            map_location=device)
        model.load_state_dict(state_dict)
        model.eval()
        if cfs_background is not None:
            # Confounders of the background batch, set before the explainer is constructed.
            model.cfs = nn.Parameter(cfs_background.clone(), requires_grad=False)
        explainer = shap.DeepExplainer(model, train_data)
        if cfs_background is not None:
            # For shap_values the confounders of background and test batch are stacked
            # (presumably DeepExplainer passes both through the model together - TODO confirm).
            explainer.explainer.model.cfs = nn.Parameter(
                torch.cat((cfs_background, cfs_test), dim=0), requires_grad=False)
        shap_values = explainer.shap_values(test_data)
        # A separate loop variable keeps the run index intact (the original reused `i` here).
        for sample_values in shap_values:
            family_total += np.abs(sample_values[:, :16, 16:]).mean()
    shap_totals[family] = family_total

# Names read by the following print cell.
Baseline_200_SHAP = shap_totals['Baseline']
Linear_200_SHAP = shap_totals['Linear']
Conv_200_SHAP = shap_totals['Conv']
SSN1_200_SHAP = shap_totals['SSN1']
SSN2_200_SHAP = shap_totals['SSN2']
        

In [None]:
# Report the accumulated SHAP totals for batch size 200, divided by 5
# (the number of checkpoints, run1..run5, that were summed per model family).
for caption, shap_total in (
        ('Baseline SHAP: ', Baseline_200_SHAP),
        ('Linear SHAP: ', Linear_200_SHAP),
        ('Conv SHAP: ', Conv_200_SHAP),
        ('SSN1 SHAP: ', SSN1_200_SHAP),
        ('SSN2 SHAP: ', SSN2_200_SHAP)):
    print(caption, shap_total/5)

In [None]:
warnings.filterwarnings("ignore")
# Batch size 1000
# Accumulates, per model family, the mean absolute SHAP value inside the image region
# [:, :16, 16:] of every explained test sample, summed over samples and over the checkpoints
# run1..run5. The totals are normalised in the following print cell.
batch_size = 1000
n_runs = 5

# The loaders shuffle, so one batch is a random draw. Dedicated seeded generators make the
# draw reproducible after a kernel restart without touching the global torch RNG.
train_set = SyntheticDataset(x, labels, cf)
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=batch_size, shuffle=True, pin_memory=True,
    generator=torch.Generator().manual_seed(seed))
val_set = SyntheticDataset(x_val, labels_val, cf_val)
val_loader = torch.utils.data.DataLoader(
    val_set,
    batch_size=batch_size, shuffle=True, pin_memory=True,
    generator=torch.Generator().manual_seed(seed + 1))

# One training batch serves as SHAP background, one validation batch is explained.
train_batch = next(iter(train_loader))
train_data = train_batch['image'].float().to(device)
train_target = train_batch['label'].float().to(device)
train_cf_batch = train_batch['cfs'].float().to(device)
test_batch = next(iter(val_loader))
test_data = test_batch['image'].float().to(device)
test_target = test_batch['label'].float().to(device)
test_cf_batch = test_batch['cfs'].float().to(device)

# Confounder matrices are identical for every run, so they are built once up front.
# MDN layers: columns [label, confounder, 1].
X_batch = np.column_stack((train_target.cpu().detach().numpy(),
                           train_cf_batch.cpu().detach().numpy(),
                           np.ones((batch_size,))))
X_batch_test = np.column_stack((test_target.cpu().detach().numpy(),
                                test_cf_batch.cpu().detach().numpy(),
                                np.ones((batch_size,))))
mdn_cfs_background = torch.tensor(X_batch, dtype=torch.float32, device=device)
mdn_cfs_test = torch.tensor(X_batch_test, dtype=torch.float32, device=device)
# SSN with cf_dim=1: the confounder as a single column. The batches are already float tensors
# on `device`, so no re-wrapping in torch.Tensor(...) is needed.
ssn1_cfs_background = train_cf_batch[:, None]
ssn1_cfs_test = test_cf_batch[:, None]
# SSN with cf_dim=2: columns [1, confounder].
ssn2_cfs_background = torch.cat((torch.ones_like(ssn1_cfs_background), ssn1_cfs_background), dim=1)
ssn2_cfs_test = torch.cat((torch.ones_like(ssn1_cfs_test), ssn1_cfs_test), dim=1)

# (result key, checkpoint sub-directory, model factory, background cfs, test cfs)
model_specs = [
    ('Baseline', 'Baseline', lambda: BaselineNet(), None, None),
    ('Linear', 'Linear', lambda: MDN_Linear(2*N, batch_size, cf_kernel), mdn_cfs_background, mdn_cfs_test),
    ('Conv', 'Conv', lambda: MDN_Conv(2*N, batch_size, cf_kernel), mdn_cfs_background, mdn_cfs_test),
    ('SSN1', 'cf1/SSN', lambda: SSN(batch_size, cf_dim=1, num_features=32), ssn1_cfs_background, ssn1_cfs_test),
    ('SSN2', 'cf2/SSN', lambda: SSN(batch_size, cf_dim=2, num_features=32), ssn2_cfs_background, ssn2_cfs_test),
]

shap_totals = {}
for family, checkpoint_dir, build_model, cfs_background, cfs_test in model_specs:
    family_total = 0
    for run in range(1, n_runs + 1):
        model = build_model().to(device)
        # map_location lets checkpoints saved on another device load onto `device`.
        # torch.load unpickles the file - only load checkpoints from a trusted source.
        state_dict = torch.load(
            f'plot_data/experiments/final/{checkpoint_dir}/batch_size{batch_size}/run{run}/best_model.pth',
            map_location=device)
        model.load_state_dict(state_dict)
        model.eval()
        if cfs_background is not None:
            # Confounders of the background batch, set before the explainer is constructed.
            model.cfs = nn.Parameter(cfs_background.clone(), requires_grad=False)
        explainer = shap.DeepExplainer(model, train_data)
        if cfs_background is not None:
            # For shap_values the confounders of background and test batch are stacked
            # (presumably DeepExplainer passes both through the model together - TODO confirm).
            explainer.explainer.model.cfs = nn.Parameter(
                torch.cat((cfs_background, cfs_test), dim=0), requires_grad=False)
        shap_values = explainer.shap_values(test_data)
        # A separate loop variable keeps the run index intact (the original reused `i` here).
        for sample_values in shap_values:
            family_total += np.abs(sample_values[:, :16, 16:]).mean()
    shap_totals[family] = family_total

# Names read by the following print cell.
Baseline_1000_SHAP = shap_totals['Baseline']
Linear_1000_SHAP = shap_totals['Linear']
Conv_1000_SHAP = shap_totals['Conv']
SSN1_1000_SHAP = shap_totals['SSN1']
SSN2_1000_SHAP = shap_totals['SSN2']

In [None]:
# Report the accumulated SHAP totals for batch size 1000, divided by 5 and again by 5.
# The first 5 matches the number of checkpoints (run1..run5) summed per model family; the
# second 5 presumably rescales the 1000-sample sum to the batch-size-200 sum - TODO confirm.
for caption, shap_total in (
        ('Baseline SHAP: ', Baseline_1000_SHAP),
        ('Linear SHAP: ', Linear_1000_SHAP),
        ('Conv SHAP: ', Conv_1000_SHAP),
        ('SSN1 SHAP: ', SSN1_1000_SHAP),
        ('SSN2 SHAP: ', SSN2_1000_SHAP)):
    print(caption, shap_total/5/5)

In [None]:
warnings.filterwarnings("ignore")
# Batch size 2000
# Accumulates, per model family, the mean absolute SHAP value inside the image region
# [:, :16, 16:] of every explained test sample, summed over samples and over the checkpoints
# run1..run5. The totals are normalised in the following print cell.
batch_size = 2000
n_runs = 5

# The loaders shuffle, so one batch is a random draw. Dedicated seeded generators make the
# draw reproducible after a kernel restart without touching the global torch RNG.
train_set = SyntheticDataset(x, labels, cf)
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=batch_size, shuffle=True, pin_memory=True,
    generator=torch.Generator().manual_seed(seed))
val_set = SyntheticDataset(x_val, labels_val, cf_val)
val_loader = torch.utils.data.DataLoader(
    val_set,
    batch_size=batch_size, shuffle=True, pin_memory=True,
    generator=torch.Generator().manual_seed(seed + 1))

# One training batch serves as SHAP background, one validation batch is explained.
train_batch = next(iter(train_loader))
train_data = train_batch['image'].float().to(device)
train_target = train_batch['label'].float().to(device)
train_cf_batch = train_batch['cfs'].float().to(device)
test_batch = next(iter(val_loader))
test_data = test_batch['image'].float().to(device)
test_target = test_batch['label'].float().to(device)
test_cf_batch = test_batch['cfs'].float().to(device)

# Confounder matrices are identical for every run, so they are built once up front.
# MDN layers: columns [label, confounder, 1].
X_batch = np.column_stack((train_target.cpu().detach().numpy(),
                           train_cf_batch.cpu().detach().numpy(),
                           np.ones((batch_size,))))
X_batch_test = np.column_stack((test_target.cpu().detach().numpy(),
                                test_cf_batch.cpu().detach().numpy(),
                                np.ones((batch_size,))))
mdn_cfs_background = torch.tensor(X_batch, dtype=torch.float32, device=device)
mdn_cfs_test = torch.tensor(X_batch_test, dtype=torch.float32, device=device)
# SSN with cf_dim=1: the confounder as a single column. The batches are already float tensors
# on `device`, so no re-wrapping in torch.Tensor(...) is needed.
ssn1_cfs_background = train_cf_batch[:, None]
ssn1_cfs_test = test_cf_batch[:, None]
# SSN with cf_dim=2: columns [1, confounder].
ssn2_cfs_background = torch.cat((torch.ones_like(ssn1_cfs_background), ssn1_cfs_background), dim=1)
ssn2_cfs_test = torch.cat((torch.ones_like(ssn1_cfs_test), ssn1_cfs_test), dim=1)

# (result key, checkpoint sub-directory, model factory, background cfs, test cfs)
model_specs = [
    ('Baseline', 'Baseline', lambda: BaselineNet(), None, None),
    ('Linear', 'Linear', lambda: MDN_Linear(2*N, batch_size, cf_kernel), mdn_cfs_background, mdn_cfs_test),
    ('Conv', 'Conv', lambda: MDN_Conv(2*N, batch_size, cf_kernel), mdn_cfs_background, mdn_cfs_test),
    ('SSN1', 'cf1/SSN', lambda: SSN(batch_size, cf_dim=1, num_features=32), ssn1_cfs_background, ssn1_cfs_test),
    ('SSN2', 'cf2/SSN', lambda: SSN(batch_size, cf_dim=2, num_features=32), ssn2_cfs_background, ssn2_cfs_test),
]

shap_totals = {}
for family, checkpoint_dir, build_model, cfs_background, cfs_test in model_specs:
    family_total = 0
    for run in range(1, n_runs + 1):
        model = build_model().to(device)
        # map_location lets checkpoints saved on another device load onto `device`.
        # torch.load unpickles the file - only load checkpoints from a trusted source.
        state_dict = torch.load(
            f'plot_data/experiments/final/{checkpoint_dir}/batch_size{batch_size}/run{run}/best_model.pth',
            map_location=device)
        model.load_state_dict(state_dict)
        model.eval()
        if cfs_background is not None:
            # Confounders of the background batch, set before the explainer is constructed.
            model.cfs = nn.Parameter(cfs_background.clone(), requires_grad=False)
        explainer = shap.DeepExplainer(model, train_data)
        if cfs_background is not None:
            # For shap_values the confounders of background and test batch are stacked
            # (presumably DeepExplainer passes both through the model together - TODO confirm).
            explainer.explainer.model.cfs = nn.Parameter(
                torch.cat((cfs_background, cfs_test), dim=0), requires_grad=False)
        shap_values = explainer.shap_values(test_data)
        # A separate loop variable keeps the run index intact (the original reused `i` here).
        for sample_values in shap_values:
            family_total += np.abs(sample_values[:, :16, 16:]).mean()
    shap_totals[family] = family_total

# Names read by the following print cell.
Baseline_2000_SHAP = shap_totals['Baseline']
Linear_2000_SHAP = shap_totals['Linear']
Conv_2000_SHAP = shap_totals['Conv']
SSN1_2000_SHAP = shap_totals['SSN1']
SSN2_2000_SHAP = shap_totals['SSN2']

In [None]:
# Report the accumulated SHAP totals for batch size 2000, divided by 5 and then by 10.
# The 5 matches the number of checkpoints (run1..run5) summed per model family; the 10
# presumably rescales the 2000-sample sum to the batch-size-200 sum - TODO confirm.
for caption, shap_total in (
        ('Baseline SHAP: ', Baseline_2000_SHAP),
        ('Linear SHAP: ', Linear_2000_SHAP),
        ('Conv SHAP: ', Conv_2000_SHAP),
        ('SSN1 SHAP: ', SSN1_2000_SHAP),
        ('SSN2 SHAP: ', SSN2_2000_SHAP)):
    print(caption, shap_total/5/10)