In [None]:
%reload_ext autoreload
%autoreload 2
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
import shap 
import os
import matplotlib.gridspec as gridspec
import pickle # for saving to file
import scipy.stats as stats # for 95% CI
from matplotlib.lines import Line2D
import matplotlib as mlp
from time import time
from torchvision import datasets, transforms
from torch import nn, optim
from torchview import draw_graph
from sklearn.metrics import matthews_corrcoef
from torchmetrics.regression import SpearmanCorrCoef
from torch.utils.data import DataLoader
import seaborn as sns

from models_SSN import Baseline, SemiStructuredNet, IRM
from datasets_SSN import create_dataloader, prepare_data, plot_colored_mnist, make_environment, ColoredMNIST

# Set random seed for reproducibility
torch.manual_seed(0)

# Select the default compute device.
#torch.set_default_device("cpu")
# Use a GPU if one is available. The second GPU (index 1) is preferred, but on a
# machine with a single GPU "cuda:1" is not a valid device, so fall back to "cuda:0".
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
#elif torch.backends.mps.is_available():
#    device = torch.device("mps")
else:
    device = torch.device("cpu")
    
# https://medium.com/analytics-vidhya/multiclass-image-classification-with-pytorch-af7578e10ee6

In [None]:
# Define functions for model fitting and evaluation as well as plotting
def evaluate(model, valloader, model_type, device, cf_dim=2):
    """Run the model's validation step on every batch and aggregate the results.

    The model is switched to evaluation mode, ``model.validation_step`` is called
    once per batch of ``valloader`` and the collected per-batch outputs are handed
    to ``model.validation_epoch_end``, whose return value is returned unchanged.
    """
    model.eval()
    step_outputs = []
    for batch in valloader:
        step_outputs.append(model.validation_step(batch, model_type, cf_dim, device))
    return model.validation_epoch_end(step_outputs)

def fit(model, optimizer, trainloader, valloader, testloader, model_type, batch_size=100, cf_dim=2, num_features=32, epochs=50, device=torch.device("cpu")):
    """Train ``model`` with early stopping and return the per-epoch history.

    Args:
        model: network exposing ``training_step``, ``validation_step``,
            ``validation_epoch_end`` and ``epoch_end``; for ``model_type == 'SSN'``
            also ``set_confounders`` and ``set_delta``.
        optimizer: optimizer that updates ``model``'s parameters.
        trainloader: batches used for the gradient updates.
        valloader: batches used for the internal validation that drives early
            stopping and checkpointing.
        testloader: batches used for the external validation; its accuracy is only
            recorded, it never influences training.
        model_type: ``'SSN'`` enables the confounder/delta updates; the value is also
            forwarded to the model and used in the checkpoint file name.
        batch_size: only used in the checkpoint file name.
        cf_dim, num_features: forwarded to the model's SSN-specific methods.
        epochs: maximum number of epochs.
        device: device forwarded to the model methods.

    Returns:
        A list with one result dict per completed epoch. Each dict is what
        ``evaluate`` returned on ``valloader`` plus ``'train_loss'`` (mean batch
        loss) and ``'test_acc'`` (``'val_acc'`` measured on ``testloader``).

    Side effects:
        Writes the best model (lowest validation loss) to
        ``fitted_models/ColoredMNIST/best-model-{model_type}-{batch_size}.pth`` and
        leaves ``model`` with those best weights loaded, whether training stopped
        early or used all ``epochs``.
    """
    max_patience = 5  # number of consecutive non-improving epochs tolerated
    checkpoint_dir = 'fitted_models/ColoredMNIST'
    checkpoint_path = f'fitted_models/ColoredMNIST/best-model-{model_type}-{batch_size}.pth'
    # Create the target directory up front so the first checkpoint cannot fail on a fresh checkout.
    os.makedirs(checkpoint_dir, exist_ok=True)

    train_history = []
    best_val_loss = float('inf')
    patience = max_patience
    checkpoint_saved = False
    for epoch in range(epochs):
        # Training
        model.train()
        train_losses = []
        for batch in trainloader:
            if model_type == 'SSN':
                cfs = batch['colors'].to(device)
                model.set_confounders(cfs, cf_dim, device)
            optimizer.zero_grad()
            _, loss = model.training_step(batch, model_type, cf_dim, device)
            loss.backward()
            optimizer.step()
            # Keep only the value: storing the attached tensor would hold on to
            # autograd references for every batch until the end of the epoch.
            train_losses.append(loss.detach())
        if model_type == 'SSN':
            model.set_delta(trainloader, cf_dim, num_features, device)
        # Internal Validation
        result = evaluate(model, valloader, model_type, device, cf_dim)
        result['train_loss'] = torch.stack(train_losses).mean().item()
        model.epoch_end(epoch, result)
        train_history.append(result)
        # External Validation
        test = evaluate(model, testloader, model_type, device, cf_dim)
        train_history[-1]['test_acc'] = test['val_acc']
        # Early stopping
        if result['val_loss'] < best_val_loss:
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_saved = True
            print(f"Model saved: {result['val_loss']}")
            best_val_loss = result['val_loss']
            patience = max_patience
        else:
            patience -= 1
            if patience == 0:
                break
    # Always hand back the best checkpoint. Previously this happened only when early
    # stopping fired, so a run that used all epochs returned the last (possibly worse)
    # weights. Skip the reload if this run never wrote a checkpoint (e.g. NaN losses),
    # to avoid loading a stale file from an earlier run.
    if checkpoint_saved:
        model.load_state_dict(torch.load(checkpoint_path))
    return train_history

def plot_train_history(history, model, figsize):
    """Plot accuracy and loss curves of one training run and save them as a PDF.

    Args:
        history: list of per-epoch dicts with the keys ``'val_acc'``, ``'test_acc'``
            and ``'val_loss'``; ``'train_loss'`` is read with ``.get``.
        model: label of the model. ``'Baseline'`` selects a layout with a broken
            y-axis (validation accuracy on top, test accuracy below); any other
            value draws both accuracies in one panel. The label is also used in
            the output file name.
        figsize: figure size forwarded to matplotlib.

    The figure is written to ``plots/ColoredMNIST/training_history_{model}.pdf``
    and then shown.
    """
    def _legend_handles(*colors):
        # Proxy artists so the legend shows thick coloured bars instead of thin lines.
        return [Line2D([0], [0], color=colour, lw=4) for colour in colors]

    def _draw_loss_panel(ax):
        # Training (blue) and validation (red) loss, y-axis on the right-hand side.
        ax.plot(loss_train, '-b')
        ax.plot(loss_val, '-r')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.yaxis.set_label_position("right")
        ax.yaxis.tick_right()
        ax.set_title('Training and Validation Loss')
        ax.legend(_legend_handles('b', 'r'), ['Training', 'Validation'], frameon=False)

    val_accuracy = [epoch_result['val_acc'] for epoch_result in history]
    test_accuracy = [epoch_result['test_acc'] for epoch_result in history]
    loss_train = [epoch_result.get('train_loss') for epoch_result in history]
    loss_val = [epoch_result['val_loss'] for epoch_result in history]
    mlp.style.use('default')

    if model != 'Baseline':
        # Simple layout: accuracies on the left, losses on the right.
        fig, (acc_ax, loss_ax) = plt.subplots(1, 2, figsize=figsize)
        acc_ax.plot(val_accuracy, '-r')
        acc_ax.plot(test_accuracy, '-g')
        acc_ax.set_xlabel('Epoch')
        acc_ax.set_ylabel('Accuracy')
        acc_ax.legend(_legend_handles('r', 'g'), ['Validation', 'Test'], frameon=False)
        acc_ax.set_title('Validation and Test Accuracy')
        _draw_loss_panel(loss_ax)
    else:
        # Broken-axis layout: two stacked accuracy panels share the left column.
        fig = plt.figure(figsize=figsize)
        grid = gridspec.GridSpec(nrows=2, ncols=2)

        upper = fig.add_subplot(grid[0, 0])
        upper.plot(val_accuracy, '-r')
        upper.set_ylabel('Accuracy')
        upper.set_title('Validation and Test Accuracy')
        upper.set_ylim(min(val_accuracy)*.99, max(val_accuracy)*1.01)
        upper.spines['bottom'].set_visible(False)
        upper.xaxis.tick_top()
        upper.xaxis.set_major_locator(plt.NullLocator())
        upper.tick_params(labeltop=False)

        lower = fig.add_subplot(grid[1, 0])
        lower.plot(test_accuracy, '-g')
        lower.set_xlabel('Epoch')
        lower.set_ylabel('Accuracy')
        lower.set_ylim(min(test_accuracy)*.99, max(test_accuracy)*1.01)
        lower.spines['top'].set_visible(False)
        lower.xaxis.tick_bottom()

        upper.legend(_legend_handles('r', 'g'), ['Validation', 'Test'], frameon=False)

        _draw_loss_panel(fig.add_subplot(grid[:, 1]))

        # Diagonal tick marks that indicate the break between the two accuracy panels.
        d = .015
        mark_style = dict(transform=upper.transAxes, color='k', clip_on=False)
        upper.plot((-d, +d), (-d, +d), **mark_style)
        upper.plot((1 - d, 1 + d), (-d, +d), **mark_style)
        mark_style.update(transform=lower.transAxes)
        lower.plot((-d, +d), (1 - d, 1 + d), **mark_style)
        lower.plot((1 - d, 1 + d), (1 - d, 1 + d), **mark_style)

    plt.savefig(f'plots/ColoredMNIST/training_history_{model}.pdf', dpi=300, bbox_inches='tight')
    plt.show()

In [None]:
# Build the train/val/test dataloaders (batch size 100) used by the single-run cells below
preparation = prepare_data()
dataloader = create_dataloader(batch_size=100, **preparation)

In [None]:
# Show example ColoredMNIST digits from the first batch of the training and of the test loader
first_train_batch = next(iter(dataloader['train']))
plot_colored_mnist(first_train_batch, 0.2)
first_test_batch = next(iter(dataloader['test']))
plot_colored_mnist(first_test_batch, 0.9)

In [None]:
# Collect flipped labels, true labels and colours over ALL batches of the training
# and of the test loader, as input for the correlation heatmaps below.
def _collect_split(loader):
    """Concatenate the 'labels', 'true_labels' and 'colors' entries of every batch.

    Returns a tuple ``(labels, true_labels, colors)`` of tensors. Each loader is
    iterated exactly once so the three tensors stay aligned even when it shuffles.
    """
    keys = ('labels', 'true_labels', 'colors')
    chunks = {key: [] for key in keys}
    for batch in loader:
        for key in keys:
            chunks[key].append(batch[key])
    # One concatenation per key instead of re-copying the growing tensor for every
    # batch. The leading empty tensor keeps the result defined (and its dtype as
    # before) when the loader yields no batches.
    return tuple(torch.cat([torch.tensor([]), *chunks[key]], 0) for key in keys)


def _correlation_frame(labels, true_labels, colors):
    """Build the integer-valued DataFrame whose ``.corr()`` is plotted below."""
    def as_int_column(values):
        return values.T.int().numpy().squeeze()
    # Same column labels for both splits so the two heatmaps are directly comparable.
    return pd.DataFrame({'Flipped Labels': as_int_column(labels),
                         'True Labels': as_int_column(true_labels),
                         'Colours': as_int_column(colors)})


# Training set
labels, true_labels, colors = _collect_split(dataloader['train'])
corrdata = _correlation_frame(labels, true_labels, colors)
# Test set
labels_test, true_labels_test, colors_test = _collect_split(dataloader['test'])
corrdata_test = _correlation_frame(labels_test, true_labels_test, colors_test)

In [None]:
# Heatmap of the pairwise correlations in the training data
train_corr_matrix = corrdata.corr()
# Mask the diagonal and the upper triangle so each pair is shown only once
mask = np.triu(np.ones_like(train_corr_matrix, dtype=np.bool_))
training_corr = sns.heatmap(train_corr_matrix, mask=mask, vmin=-1, vmax=1, annot=True, cmap='BrBG')
fig = training_corr.get_figure()
fig.savefig('plots/ColoredMNIST/training_corr.pdf', bbox_inches='tight')

In [None]:
# Heatmap of the pairwise correlations in the test data
test_corr_matrix = corrdata_test.corr()
# Mask the diagonal and the upper triangle so each pair is shown only once
mask = np.triu(np.ones_like(test_corr_matrix, dtype=np.bool_))
testing_corr = sns.heatmap(test_corr_matrix, mask=mask, vmin=-1, vmax=1, annot=True, cmap='BrBG')
fig = testing_corr.get_figure()
fig.savefig('plots/ColoredMNIST/testing_corr.pdf', bbox_inches='tight')

In [None]:
# Train the Baseline model on CPU, then report its loss and accuracy on the test loader
device = torch.device("cpu")
BaselineModel = Baseline().to(device)
BaselineOptimizer = optim.Adam(BaselineModel.parameters(), lr=0.001, weight_decay=0.0001)
print('Training Baseline model...')
baseline_loaders = (dataloader['train'], dataloader['val'], dataloader['test'])
Baseline_train_history = fit(BaselineModel, BaselineOptimizer, *baseline_loaders,
                             batch_size=100, model_type='Baseline', epochs=50)
print('\nEvaluating Baseline model...')
Baseline_test = evaluate(BaselineModel, dataloader['test'], model_type='Baseline', device=device)
baseline_test_loss = Baseline_test["val_loss"]
baseline_test_acc = Baseline_test["val_acc"]
print(f'Testing Loss: {baseline_test_loss:.4f}, Testing Accuracy: {baseline_test_acc:.2%}')

In [None]:
# Render the Baseline architecture with torchview and save the graph with the other plots
graph_options = dict(input_size=(100, 2*14*14), device='meta', graph_dir='TB', hide_module_functions=True,
                     save_graph=True, directory='plots/ColoredMNIST', filename='Baseline_model_graph.pdf')
model_graph = draw_graph(BaselineModel, **graph_options)
model_graph.visual_graph

In [None]:
# Accuracy and loss curves of the Baseline training run (broken-axis layout)
plot_train_history(history=Baseline_train_history, model='Baseline', figsize=(8, 4))

In [None]:
# Train the Semi-Structured Net on CPU, then report its loss and accuracy on the test loader
device = torch.device("cpu")
SemiStructuredModel = SemiStructuredNet(batch_size=100, cf_dim=2, num_features=32).to(device)
SemiStructuredModelOptimizer = optim.Adam(SemiStructuredModel.parameters(), lr=0.001, weight_decay=0.0001)
print('Training Semi-Structured model...')
ssn_loaders = (dataloader['train'], dataloader['val'], dataloader['test'])
SemiStructredModel_train_history = fit(SemiStructuredModel, SemiStructuredModelOptimizer, *ssn_loaders,
                                       batch_size=100, model_type='SSN', cf_dim=2, num_features=32,
                                       epochs=50, device=device)
print('\nEvaluating Semi-Structured model...')
SemiStructured_test = evaluate(SemiStructuredModel, dataloader['test'], model_type='SSN', device=device)
ssn_test_loss = SemiStructured_test["val_loss"]
ssn_test_acc = SemiStructured_test["val_acc"]
print(f'Testing Loss: {ssn_test_loss:.4f}, Testing Accuracy: {ssn_test_acc:.2%}')

In [None]:
# Accuracy and loss curves of the Semi-Structured Net training run (two-panel layout)
plot_train_history(history=SemiStructredModel_train_history, model='SSN', figsize=(8, 4))

In [None]:
# Run the IRM reference script as a benchmark; the cell output is the shell exit status
irm_command = ' '.join([
    'python IRM.py',
    '--hidden_dim=390',
    '--l2_regularizer_weight=0.00110794568',
    '--lr=0.0004898536566546834',
    '--penalty_anneal_iters=190',
    '--penalty_weight=91257.18613115903',
    '--steps=501',
])
os.system(irm_command)

In [None]:
# Investigate model performance for different batch sizes:
# 10 independent runs, each training a Baseline and a Semi-Structured model per batch size.
device = torch.device("cpu")
batch_sizes = [50, 100, 1000, 10000]
# Make sure the output directory exists BEFORE the long training loop, so the
# results cannot be lost to a missing directory when they are pickled at the end.
os.makedirs('results/ColoredMNIST', exist_ok=True)
train_history = {} # run -> model name -> list (one entry per batch size) of per-epoch training histories
test_history = {} # run -> model name -> list (one entry per batch size) of test-set evaluation results
for run in range(10):
    train_history[run], test_history[run] = {}, {} # Initialize dictionaries for current run
    train_history[run]['Baseline'], test_history[run]['Baseline'] = [], [] # Results of the Baseline model
    train_history[run]['SSN'], test_history[run]['SSN'] = [], [] # Results of the Semi-Structured model
    print(f'Run {run+1}...')
    preparation = prepare_data() # Prepare the data once per run; it is shared by all batch sizes of this run
    for batch in batch_sizes:
        dataloader = create_dataloader(**preparation, batch_size=batch) # Create dataloaders for given batch size
        # Create and train the Baseline model
        BaselineModel = Baseline().to(device)
        BaselineOptimizer = optim.Adam(BaselineModel.parameters(), lr=0.001, weight_decay=0.0001)
        print(f'Training Baseline model with batch size {batch}...')
        train_history[run]['Baseline'].append(fit(BaselineModel, BaselineOptimizer, 
                                                  dataloader['train'], dataloader['val'], dataloader['test'], 
                                                  batch_size=batch, model_type='Baseline', epochs=50, device=device))
        test_history[run]['Baseline'].append(evaluate(BaselineModel, dataloader['test'], model_type='Baseline', device=device))
        # Create and train the Semi-Structured model
        SemiStructuredModel = SemiStructuredNet(batch_size=batch, cf_dim=2, num_features=32).to(device)
        SemiStructuredModelOptimizer = optim.Adam(SemiStructuredModel.parameters(), lr=0.001, weight_decay=0.0001)
        print(f'Training Semi-Structured model with batch size {batch}...')
        train_history[run]['SSN'].append(fit(SemiStructuredModel, SemiStructuredModelOptimizer, 
                                             dataloader['train'], dataloader['val'], dataloader['test'], 
                                             batch_size=batch, model_type='SSN', cf_dim=2, num_features=32, epochs=50, device=device))
        test_history[run]['SSN'].append(evaluate(SemiStructuredModel, dataloader['test'], model_type='SSN', device=device))
print('Done!')

# Save results to file
with open('results/ColoredMNIST/train_history.pkl', 'wb') as f:
    pickle.dump(train_history, f)
    print('Training history saved to file.')
with open('results/ColoredMNIST/test_history.pkl', 'wb') as f:
    pickle.dump(test_history, f)
    print('Test history saved to file.')

In [None]:
# Reload the pickled batch-size results so the analysis below can run without retraining
def _read_pickle(path):
    """Unpickle and return the object stored at ``path``."""
    with open(path, 'rb') as f:
        return pickle.load(f)

train_history = _read_pickle('results/ColoredMNIST/train_history.pkl')
test_history = _read_pickle('results/ColoredMNIST/test_history.pkl')
batch_sizes = [50, 100, 1000, 10000]

In [None]:
# Display results of training for different batch sizes in a table
# (mean ± half-width of the 95% Student-t confidence interval over all runs).
# The rows are collected in a list and the DataFrame is built once, because
# DataFrame.append was removed in pandas 2.0.
# NOTE(review): index -6 picks the best epoch only when early stopping fired with
# a patience of 5 (5 non-improving epochs follow the best one); confirm that every
# run stopped early. The 'Final Train Accuracy' column is read from 'val_acc'.
batch_result_rows = []
for model in ['Baseline', 'SSN']:
    for i, batch in enumerate(batch_sizes):
        runs = range(len(train_history))
        samples = {
            'train_loss': [train_history[run][model][i][-6]['train_loss'] for run in runs],
            'test_loss': [test_history[run][model][i]['val_loss'] for run in runs],
            'train_acc': [train_history[run][model][i][-6]['val_acc'] for run in runs],
            'test_acc': [test_history[run][model][i]['val_acc'] for run in runs],
        }
        # Mean and half-width of the 95% confidence interval for every metric
        summary = {}
        for metric, values in samples.items():
            mean = np.mean(values)
            ci = stats.t.interval(0.95, len(values)-1, loc=mean, scale=stats.sem(values))
            summary[metric] = (mean, ci[1]-mean)
        train_loss_mean, train_loss_hw = summary['train_loss']
        test_loss_mean, test_loss_hw = summary['test_loss']
        train_acc_mean, train_acc_hw = summary['train_acc']
        test_acc_mean, test_acc_hw = summary['test_acc']
        # Add results to table
        batch_result_rows.append({'Model': model, 'Batch Size': batch,
                                  'Final Train Loss': f'{train_loss_mean:.4f}' + u"\u00B1" + f'{train_loss_hw:.2f}',
                                  'Final Train Accuracy': f'{train_acc_mean:.2%}' + u"\u00B1" + f'{train_acc_hw:.2%}',
                                  'Test Loss': f'{test_loss_mean:.4f}' + u"\u00B1" + f'{test_loss_hw:.2f}',
                                  'Test Accuracy': f'{test_acc_mean:.2%}' + u"\u00B1" + f'{test_acc_hw:.2%}'})
batch_results = pd.DataFrame(batch_result_rows,
                             columns=['Model', 'Batch Size', 'Final Train Loss', 'Final Train Accuracy', 'Test Loss', 'Test Accuracy'])
batch_results

In [None]:
# Exact orthogonalization with full batch size:
# the whole training split is fed to the Semi-Structured Net as ONE batch.
data = prepare_data()
trainset = data['trainset']
testset = data['testset']
mnist_train = (trainset.data, trainset.targets) # (images, labels)
mnist_test = (testset.data, testset.targets) # (images, labels)
train_val_split = 40000 # Split the training set into training and validation set
# NOTE(review): the third argument of make_environment is 0.1 for train/val although the
# original comment spoke of "20% of the colors flipped" — confirm against make_environment.
envs = [
    make_environment(mnist_train[0], mnist_train[1], 0.1, 2, True), # Environment for training and validation
    make_environment(mnist_test[0], mnist_test[1], 0.9, 2, True) # Environment for testing
]
traindata = ColoredMNIST(envs[0]['images'][:train_val_split], 
                        envs[0]['labels'][:train_val_split], 
                        envs[0]['colors'][:train_val_split],
                        envs[0]['true_labels'][:train_val_split]) # Create ColoredMNIST dataset training set
valdata = ColoredMNIST(envs[0]['images'][train_val_split:], 
                    envs[0]['labels'][train_val_split:], 
                    envs[0]['colors'][train_val_split:],
                    envs[0]['true_labels'][train_val_split:]) # Create ColoredMNIST dataset validation set
testdata = ColoredMNIST(envs[1]['images'], 
                        envs[1]['labels'], 
                        envs[1]['colors'],
                        envs[1]['true_labels']) # Create ColoredMNIST dataset test set
trainloader = DataLoader(traindata, batch_size=40000, shuffle=True) # Create DataLoader for training set
valloader = DataLoader(valdata, batch_size=10000, shuffle=True) # Create DataLoader for validation set
testloader = DataLoader(testdata, batch_size=10000, shuffle=True) # Create DataLoader for test set

device = torch.device("cpu")
train_history = [] # validation accuracy at the selected epoch of every repetition
test_history = [] # test accuracy of every repetition
# Optional Baseline reference run (disabled):
# for i in range(10):
#     BaselineModel = Baseline().to(device)
#     BaselineOptimizer = optim.Adam(BaselineModel.parameters(), lr=0.001, weight_decay=0.0001)
#     Baseline_train_history = fit(BaselineModel, BaselineOptimizer,
#                                  trainloader, valloader, testloader,
#                                  batch_size=40000, model_type='Baseline', epochs=50, device=device)
#     Baseline_test_history = evaluate(BaselineModel, testloader, model_type='Baseline', device=device)
#     train_history.append(Baseline_train_history[-6]['val_acc'])
#     test_history.append(Baseline_test_history['val_acc'])
for i in range(10):
    SemiStructuredModel = SemiStructuredNet(batch_size=40000, cf_dim=2, num_features=32).to(device)
    SemiStructuredModelOptimizer = optim.Adam(SemiStructuredModel.parameters(), lr=0.001, weight_decay=0.0001)
    # Train and evaluate on the full-batch loaders built above. The previous version
    # passed the `dataloader` dict left over from an earlier cell, so the full-batch
    # loaders were never used and the model's batch_size did not match the data.
    SemiStructredModel_train_history = fit(SemiStructuredModel, SemiStructuredModelOptimizer, 
                                        trainloader, valloader, testloader, 
                                        batch_size=40000, model_type='SSN', cf_dim=2, num_features=32, epochs=50, device=device)
    SemiStructured_test = evaluate(SemiStructuredModel, testloader, model_type='SSN', device=device)
    print(f'Testing Loss: {SemiStructured_test["val_loss"]:.4f}, Testing Accuracy: {SemiStructured_test["val_acc"]:.2%}')
    # NOTE(review): index -6 is the best epoch only if early stopping (patience 5) fired.
    train_history.append(SemiStructredModel_train_history[-6]['val_acc'])
    test_history.append(SemiStructured_test['val_acc'])

In [None]:
# Mean and 95% Student-t confidence interval of the accuracies collected in the previous cell
confidence = 0.95
train_acc_mean = np.mean(train_history)
train_acc_sem = stats.sem(train_history)
train_acc_ci = stats.t.interval(confidence, len(train_history)-1, loc=train_acc_mean, scale=train_acc_sem)
test_acc_mean = np.mean(test_history)
test_acc_sem = stats.sem(test_history)
test_acc_ci = stats.t.interval(confidence, len(test_history)-1, loc=test_acc_mean, scale=test_acc_sem)
# Report the half-width (upper bound minus mean) next to each mean
train_half_width = train_acc_ci[1]-train_acc_mean
test_half_width = test_acc_ci[1]-test_acc_mean
print(f'Validation accuracy: {train_acc_mean:.4f} ± {train_half_width:.4f}')
print(f'Test accuracy: {test_acc_mean:.4f} ± {test_half_width:.4f}')

In [None]:
# Performance of SSN on different test environments:
# 5 runs, each training one Semi-Structured model per test-environment value p.
device = torch.device("cpu")
p_e = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
train_history, test_history = {}, {} # run -> model name -> list with one entry per value of p
preparation = prepare_data() # Prepared once and shared by all runs and environments
for run in range(5):
    # Both model keys are created; only the 'SSN' lists are filled in this cell
    train_history[run] = {'Baseline': [], 'SSN': []}
    test_history[run] = {'Baseline': [], 'SSN': []}
    print(f'Run {run+1}...')
    for p in p_e:
        dataloader = create_dataloader(**preparation, batch_size=100, p_test=p) # Dataloaders whose test split uses p
        # Create and train a fresh model for this environment
        SemiStructuredModel = SemiStructuredNet(batch_size=100, cf_dim=2, num_features=32).to(device)
        SemiStructuredModelOptimizer = optim.Adam(SemiStructuredModel.parameters(), lr=0.001, weight_decay=0.0001)
        print(f'Training Semi-Structured model with p = {p}...')
        ssn_fit_history = fit(SemiStructuredModel, SemiStructuredModelOptimizer,
                              dataloader['train'], dataloader['val'], dataloader['test'],
                              batch_size=100, model_type='SSN', cf_dim=2, num_features=32, epochs=50, device=device)
        train_history[run]['SSN'].append(ssn_fit_history)
        ssn_test_result = evaluate(SemiStructuredModel, dataloader['test'], model_type='SSN', device=device)
        test_history[run]['SSN'].append(ssn_test_result)
print('Done!')

In [None]:
# Prepare results for different environments for plotting and show them as a table
# (mean ± half-width of the 95% Student-t confidence interval over all runs).
# The rows are collected in a list and the DataFrame is built once, because
# DataFrame.append was removed in pandas 2.0. The loss statistics that were
# computed here before were never used and have been dropped.
# NOTE(review): index -6 picks the best epoch only when early stopping fired with
# a patience of 5; 'Final Train Accuracy' is read from 'val_acc'.
p_result_rows = []
p_train_mean = []
p_train_ci = []
p_test_mean = []
p_test_ci = []
for i, p in enumerate(p_e):
    runs = range(len(train_history))
    train_accs = [train_history[run]['SSN'][i][-6]['val_acc'] for run in runs]
    test_accs = [test_history[run]['SSN'][i]['val_acc'] for run in runs]
    # Calculate mean and 95% confidence interval for the accuracies
    train_acc_mean = np.mean(train_accs)
    train_acc_ci = stats.t.interval(0.95, len(train_accs)-1, loc=train_acc_mean, scale=stats.sem(train_accs))
    test_acc_mean = np.mean(test_accs)
    test_acc_ci = stats.t.interval(0.95, len(test_accs)-1, loc=test_acc_mean, scale=stats.sem(test_accs))
    train_acc_hw = train_acc_ci[1]-train_acc_mean
    test_acc_hw = test_acc_ci[1]-test_acc_mean
    # Add results to table
    p_result_rows.append({'p': p,
                          'Final Train Accuracy': f'{train_acc_mean:.2%}' + u"\u00B1" + f'{train_acc_hw:.2%}',
                          'Test Accuracy': f'{test_acc_mean:.2%}' + u"\u00B1" + f'{test_acc_hw:.2%}'})
    # Keep the raw numbers for the plot in the next cell
    p_train_mean.append(train_acc_mean)
    p_train_ci.append(train_acc_hw)
    p_test_mean.append(test_acc_mean)
    p_test_ci.append(test_acc_hw)
p_results = pd.DataFrame(p_result_rows, columns=['p', 'Final Train Accuracy', 'Test Accuracy'])
p_results

In [None]:
# Plot validation and test accuracy of the SSN against p^e;
# dashed vertical bars mark the 95% confidence intervals
fig = plt.figure(figsize=(8, 4))
p_axis = np.array(p_e)
val_mean, val_hw = np.array(p_train_mean), np.array(p_train_ci)
tst_mean, tst_hw = np.array(p_test_mean), np.array(p_test_ci)
plt.plot(p_e, p_train_mean, '-r')
plt.plot(p_e, p_test_mean, '-g')
plt.vlines(p_axis, val_mean-val_hw, val_mean+val_hw, colors='r', linestyles='dashed')
plt.vlines(p_axis, tst_mean-tst_hw, tst_mean+tst_hw, colors='g', linestyles='dashed')
plt.legend(['Validation', 'Test'], frameon=False)
plt.xlabel(r'$p^{e}$')
plt.ylabel('Accuracy')
plt.savefig('plots/ColoredMNIST/p_e.pdf', bbox_inches='tight')

In [None]:
###############################################################################################################
# Compute SHAP values for Baseline and Semi-Structured Net model:
###############################################################################################################
device = torch.device("cpu")
batch_size = 100 # Batch size for SHAP computation; also selects which fitted checkpoints are loaded
# Load the fitted models. map_location makes the load work even if a checkpoint was saved from another device.
model_Baseline = Baseline().to(device) # Create model instance
model_Baseline.load_state_dict(torch.load(f'fitted_models/ColoredMNIST/best-model-Baseline-{batch_size}.pth', map_location=device)) # Load fitted model
model_Baseline.eval() # Set model to evaluation mode
model_SemiStructured = SemiStructuredNet(batch_size=batch_size, cf_dim=2, num_features=32).to(device) # Create model instance
model_SemiStructured.load_state_dict(torch.load(f'fitted_models/ColoredMNIST/best-model-SSN-{batch_size}.pth', map_location=device)) # Load fitted model
model_SemiStructured.eval() # Set model to evaluation mode
# Create background sample (first batch of the test loader)
dataloader = create_dataloader(**preparation, batch_size=batch_size)
background_sample = next(iter(dataloader['test']))
background_images = background_sample['images'].to(device)
background_labels = background_sample['labels'].to(device)
background_colors = background_sample['colors'].to(device)
# Set confounders for Semi-Structured Net model
with torch.no_grad():
    model_SemiStructured.set_confounders(background_colors, cf_dim=2, device=device)
background_images = background_images.view(background_images.shape[0], -1) # Flatten every image to one row
explainer_Baseline = shap.DeepExplainer(model_Baseline, background_images) # Create explainer instance for Baseline model
explainer_SemiStructured = shap.DeepExplainer(model_SemiStructured, background_images) # Create explainer instance for Semi-Structured Net model
# Create test sample
# NOTE(review): this draws the first batch of a fresh iterator again; if the test loader
# does not shuffle, the test sample is identical to the background sample — confirm.
test_sample = next(iter(dataloader['test']))
test_images = test_sample['images'].to(device)
test_labels = test_sample['labels'].to(device)
test_colors = test_sample['colors'].to(device)
# Set confounders for Semi-Structured Net model by using the confounder of the background and concatenating it with the confounder of the test sample
# (each part gets a leading column of ones)
with torch.no_grad():
    model_SemiStructured.cfs = nn.Parameter(torch.cat((
        torch.cat((torch.ones((len(background_colors), 1)), torch.Tensor(background_colors)), dim=1).to(device),
        torch.cat((torch.ones((len(test_colors), 1)), torch.Tensor(test_colors)), dim=1).to(device)
    )), requires_grad=False)
# Flatten with the test batch's OWN size; the previous version used the background
# batch's size, which only works when both batches have the same number of images.
test_images = test_images.view(test_images.shape[0], -1)
shap_values_Baseline = explainer_Baseline.shap_values(test_images) # Compute SHAP values for Baseline model
shap_values_SemiStructured = explainer_SemiStructured.shap_values(test_images) # Compute SHAP values for Semi-Structured Net model

In [None]:
# Create plots of SHAP values for Baseline model and save to file:
# one figure per colour channel (columns 0-195 and 196-391 of the flattened image)
# and one for the sum of both channels, each for the first 10 test images
plt.ioff() # Turn interactive mode off to prevent plots from being displayed
image_shape = (len(test_images), 14, 14, 1)
test_pixels = test_images.numpy()
# Channel 0
shap_channel0 = np.reshape(shap_values_Baseline[:,:196], image_shape)[:10]
pixels_channel0 = -np.reshape(test_pixels[:,:196], image_shape)[:10]
shap.image_plot(shap_channel0, pixels_channel0, show=False)
plt.savefig(f'plots/ColoredMNIST/shap_values_Baseline-{batch_size}_channel0.png', dpi=1200, bbox_inches='tight')
# Channel 1
shap_channel1 = np.reshape(shap_values_Baseline[:,196:], image_shape)[:10]
pixels_channel1 = -np.reshape(test_pixels[:,196:], image_shape)[:10]
shap.image_plot(shap_channel1, pixels_channel1, show=False)
plt.savefig(f'plots/ColoredMNIST/shap_values_Baseline-{batch_size}_channel1.png', dpi=1200, bbox_inches='tight')
# Both channels summed, labelled with the (integer) labels of the test batch
shap_both = np.reshape(shap_values_Baseline[:,:196]+shap_values_Baseline[:,196:], image_shape)[:10]
pixels_both = -np.reshape(test_pixels[:,:196]+test_pixels[:,196:], image_shape)[:10]
shap.image_plot(shap_both, pixels_both,
                labels=test_labels[:10].numpy().astype(int),
                show=False)
plt.savefig(f'plots/ColoredMNIST/shap_values_Baseline-{batch_size}_both_channels.png', dpi=1200, bbox_inches='tight')

In [None]:
# Display plots of SHAP values for Baseline model in a grid:
# the saved PNGs are placed side by side (both channels | channel 0 | channel 1)
fig = plt.figure(figsize=(12, 12))
gs = gridspec.GridSpec(nrows=1, ncols=3)
# (grid column, image file), listed in the order in which the panels are created
baseline_panels = [
    (1, f'plots/ColoredMNIST/shap_values_Baseline-{batch_size}_channel0.png'),
    (2, f'plots/ColoredMNIST/shap_values_Baseline-{batch_size}_channel1.png'),
    (0, f'plots/ColoredMNIST/shap_values_Baseline-{batch_size}_both_channels.png'),
]
panel_axes = []
for column, image_path in baseline_panels:
    panel = fig.add_subplot(gs[0, column])
    panel.imshow(plt.imread(image_path))
    panel_axes.append(panel)
ax0, ax1, ax2 = panel_axes
for panel in panel_axes:
    panel.axis('off') # The images already carry their own axes and colour bar
plt.savefig(f'plots/ColoredMNIST/shap_values_Baseline-{batch_size}_full.pdf', dpi=1200, bbox_inches='tight')

In [None]:
# Create plots of SHAP values for Semi-Structured Net model and save to file:
# one figure per colour channel (columns 0-195 and 196-391 of the flattened image)
# and one for the sum of both channels, each for the first 10 test images.
# The figures are saved as PNG, like those of the Baseline model, because the grid
# cell below reads them back with plt.imread from '.png' paths; the previous '.pdf'
# output could not be found (or read) by that cell.
plt.ioff() # Turn interactive mode off to prevent plots from being displayed
shap.image_plot(np.reshape(shap_values_SemiStructured[:,:196], (len(test_images), 14, 14, 1))[:10], 
                -np.reshape(test_images.numpy()[:,:196], (len(test_images), 14, 14, 1))[:10], show=False)
plt.savefig(f'plots/ColoredMNIST/shap_values_SemiStructured-{batch_size}_channel0.png', dpi=1200, bbox_inches='tight')
shap.image_plot(np.reshape(shap_values_SemiStructured[:,196:], (len(test_images), 14, 14, 1))[:10], 
                -np.reshape(test_images.numpy()[:,196:], (len(test_images), 14, 14, 1))[:10], show=False)
plt.savefig(f'plots/ColoredMNIST/shap_values_SemiStructured-{batch_size}_channel1.png', dpi=1200, bbox_inches='tight')
shap.image_plot(np.reshape(shap_values_SemiStructured[:,:196]+shap_values_SemiStructured[:,196:], (len(test_images), 14, 14, 1))[:10], 
                -np.reshape(test_images.numpy()[:,:196]+test_images.numpy()[:,196:], (len(test_images), 14, 14, 1))[:10], 
                labels=test_labels[:10].numpy().astype(int), 
                show=False)
plt.savefig(f'plots/ColoredMNIST/shap_values_SemiStructured-{batch_size}_both_channels.png', dpi=1200, bbox_inches='tight')
plt.close()

In [None]:
# Display plots of SHAP values for Semi-Structured Net model in a grid:
# the saved PNGs are placed side by side (both channels | channel 0 | channel 1)
fig = plt.figure(figsize=(12, 16))
gs = gridspec.GridSpec(nrows=1, ncols=3)
# (grid column, image file), listed in the order in which the panels are created
ssn_panels = [
    (1, f'plots/ColoredMNIST/shap_values_SemiStructured-{batch_size}_channel0.png'),
    (2, f'plots/ColoredMNIST/shap_values_SemiStructured-{batch_size}_channel1.png'),
    (0, f'plots/ColoredMNIST/shap_values_SemiStructured-{batch_size}_both_channels.png'),
]
panel_axes = []
for column, image_path in ssn_panels:
    panel = fig.add_subplot(gs[0, column])
    panel.imshow(plt.imread(image_path))
    panel_axes.append(panel)
ax0, ax1, ax2 = panel_axes
for panel in panel_axes:
    panel.axis('off') # The images already carry their own axes and colour bar
plt.savefig(f'plots/ColoredMNIST/shap_values_SemiStructured-{batch_size}_full.png', dpi=1200, bbox_inches='tight')