# Jupyter Notebook for the Medical Data experiment

Sources of data: \
Guangzhou: https://www.kaggle.com/datasets/paultimothymooney/chest-xray-pneumonia/versions/1?resource=download&select=chest_xray \
RSNA: https://www.kaggle.com/competitions/rsna-pneumonia-detection-challenge/data

In [None]:
%load_ext autoreload
%reload_ext autoreload
%autoreload 2
import torch 
import torchvision
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pydicom
import os
import cv2
import seaborn as sns
from torch import nn, optim
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import OneCycleLR, ConstantLR
from torchvision.datasets import ImageFolder
from torchvision import transforms
from collections import OrderedDict
from pathlib import Path
from torchview import draw_graph
from torchmetrics.classification import BinaryAUROC, MulticlassAUROC

from models import Baseline, SemiStructuredNet

# Set the CUDA caching-allocator option first so that it is already in place
# before any CUDA query or allocation happens below.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:734"

# Prefer the second GPU (index 1) as before, but fall back to the first GPU on
# single-GPU machines (where "cuda:1" would be an invalid device) and to the
# CPU when CUDA is not available at all.
if not torch.cuda.is_available():
    device = torch.device("cpu")
elif torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
else:
    device = torch.device("cuda:0")

downloads_path = str(Path.home() / "Downloads")  # Downloads folder

In [None]:
# Creating the data directory for the adult (RSNA) dataset: the DICOM training
# images are converted to JPEG and sorted into NORMAL/ and PNEUMONIA/ folders
# according to the 'Target' column of the label file.
create_rsna = False # Set marker for creating the RSNA data directory
if create_rsna:
    rsna_root = downloads_path + "/rsna-pneumonia-detection-challenge/"
    # Read label file
    labels = pd.read_csv(rsna_root + "stage_2_train_labels.csv")[['patientId','Target']]
    # Build the patientId -> Target lookup once instead of filtering the frame
    # for every image. If a patient occurs on several rows, the first row wins,
    # which matches the previous `.values[0]` lookup.
    label_by_patient = labels.drop_duplicates('patientId').set_index('patientId')['Target'].to_dict()
    # Define input path and output paths
    input_path = rsna_root + "stage_2_train_images/"
    output_path_normal = rsna_root + "train/NORMAL/"
    output_path_pneumonia = rsna_root + "train/PNEUMONIA/"
    # Convert whenever at least one class folder is missing. Previously nothing
    # happened if exactly one of the two folders already existed.
    if not (os.path.exists(output_path_normal) and os.path.exists(output_path_pneumonia)):
        # makedirs also creates the parent train/ folder and tolerates existing folders
        os.makedirs(output_path_normal, exist_ok=True)
        os.makedirs(output_path_pneumonia, exist_ok=True)
        # Only DICOM files are converted; anything else in the folder is skipped
        train_list = [f for f in os.listdir(input_path) if f.endswith('.dcm')]
        for f in train_list: # Iterate over all the images
            file = pydicom.dcmread(input_path + f) # Read the DICOM file (dcmread replaces the deprecated read_file)
            image = file.pixel_array # Get the image array
            patient_id = os.path.splitext(f)[0] # File name without extension is the patient id
            id_label = label_by_patient[patient_id] # Raises KeyError if the image has no label row
            if id_label == 0: # If the image is normal
                cv2.imwrite(output_path_normal + f.replace('.dcm','.jpg'), image) # Save the image in the normal folder
            else: # If the image is pneumonia
                cv2.imwrite(output_path_pneumonia + f.replace('.dcm','.jpg'), image) # Save the image in the pneumonia folder

In [None]:
# Build (or load from the working directory) four lists of (image tensor, label)
# samples: pediatric control/disease and adult control/disease.

def _downsample_in_place(dataset, n_control, n_disease, keep):
    """Trim an ImageFolder in place to at most `keep` samples per class.

    Removes the leading surplus control entries and the trailing surplus
    disease entries from both `dataset.imgs` and `dataset.targets`. This is the
    result the former one-element-at-a-time `list.remove` loops produced, and
    like them it relies on all control (label 0) samples being listed before
    the disease (label 1) samples.
    """
    surplus_control = max(n_control - keep, 0)
    surplus_disease = max(n_disease - keep, 0)
    # Delete in place (no rebinding) so any other reference to these lists
    # sees the trimmed content as well.
    # NOTE(review): torchvision is expected to alias `imgs` and `samples` to the same list — confirm for the installed version.
    for entries in (dataset.imgs, dataset.targets):
        del entries[:surplus_control]
        del entries[len(entries) - surplus_disease:]


def _split_by_label(dataset):
    """Return (control, disease) lists of the samples of `dataset`.

    Each sample is fetched exactly once; indexing the dataset loads and
    transforms the image, so fetching it twice per sample doubled the work.
    """
    control, disease = [], []
    for index in range(len(dataset)):
        sample = dataset[index]
        if sample[1] == 0: # Label 0 is control
            control.append(sample)
        else: # Any other label is disease
            disease.append(sample)
    return control, disease


cache_files = ["control_p.pt", "disease_p.pt", "control_a.pt", "disease_a.pt"]
if not all(os.path.exists(cache_file) for cache_file in cache_files):
    # Both datasets are resized to 200x200 and converted to tensors
    resize_to_tensor = transforms.Compose([transforms.Resize((200, 200)), transforms.ToTensor()])

    # Loading the data for the pediatric dataset
    dataset_pediatric = ImageFolder(root=downloads_path + "/archive/chest_xray/chest_xray/train",
                                    transform=resize_to_tensor)

    # Loading the data for the adult dataset
    dataset_adult = ImageFolder(root=downloads_path + "/rsna-pneumonia-detection-challenge/train",
                                transform=resize_to_tensor)

    # Check numbers of disease and control in each age group to identify the minority class
    disease_pediatric = dataset_pediatric.targets.count(1) # Count the number of disease images
    control_pediatric = dataset_pediatric.targets.count(0) # Count the number of control images
    disease_adult = dataset_adult.targets.count(1) # Count the number of disease images
    control_adult = dataset_adult.targets.count(0) # Count the number of control images
    print(f'Pediatric: Control: {control_pediatric}, Disease: {disease_pediatric}\nAdult: Control: {control_adult}, Disease: {disease_adult}') # Print the numbers

    # Sample every class down to (at most) this many images
    minority_size = 1000
    _downsample_in_place(dataset_pediatric, control_pediatric, disease_pediatric, minority_size)
    _downsample_in_place(dataset_adult, control_adult, disease_adult, minority_size)

    # Check numbers of disease and control in each age group to see if they match
    disease_pediatric = dataset_pediatric.targets.count(1) # Count the number of disease images
    control_pediatric = dataset_pediatric.targets.count(0) # Count the number of control images
    disease_adult = dataset_adult.targets.count(1) # Count the number of disease images
    control_adult = dataset_adult.targets.count(0) # Count the number of control images
    print(f'Pediatric: Control: {control_pediatric}, Disease: {disease_pediatric}\nAdult: Control: {control_adult}, Disease: {disease_adult}') # Print the numbers

    # Split the pediatric and adult datasets into control and disease yielding 4 lists for the different classes
    control_p, disease_p = _split_by_label(dataset_pediatric)
    control_a, disease_a = _split_by_label(dataset_adult)

    # Save the lists so later runs can skip the image loading
    torch.save(control_p, "control_p.pt") # Save the control pediatric dataset
    torch.save(disease_p, "disease_p.pt") # Save the disease pediatric dataset
    torch.save(control_a, "control_a.pt") # Save the control adult dataset
    torch.save(disease_a, "disease_a.pt") # Save the disease adult dataset
else:
    # Load the previously saved lists
    control_p = torch.load("control_p.pt") # Load the control pediatric dataset
    disease_p = torch.load("disease_p.pt") # Load the disease pediatric dataset
    control_a = torch.load("control_a.pt") # Load the control adult dataset
    disease_a = torch.load("disease_a.pt") # Load the disease adult dataset

# Extend every sample tuple by a confounder entry: 0 for the two pediatric
# lists and 1 for the two adult lists. The lists are updated in place.
for samples, confounder in ((control_p, 0), (disease_p, 0), (control_a, 1), (disease_a, 1)):
    samples[:] = [sample + (confounder,) for sample in samples]
    
# Hold out the test set: draw 100 positions once and take the sample at each
# drawn position from all four class lists (four samples per position)
testdata_size = 100 # Number of positions to draw
sampling_indices = np.random.choice(range(1000), testdata_size, replace=False) # Draw positions without replacement
testdata = []
for position in sampling_indices:
    testdata.extend(samples[position] for samples in (control_p, disease_p, control_a, disease_a))
# Remove the held-out samples from the class lists. Going from the largest
# position to the smallest keeps the not-yet-deleted positions valid.
for position in sorted(sampling_indices, reverse=True):
    for samples in (control_p, control_a, disease_p, disease_a):
        del samples[position]
    
# Balanced pool: draw 500 positions once and take the sample at each drawn
# position from every class list, so all four classes contribute equally
length = 500 # Number of positions to draw
sampling_indices = np.random.choice(range(900), length, replace=False) # Draw positions without replacement
balanced = []
for position in sampling_indices:
    for samples in (control_p, disease_p, control_a, disease_a):
        balanced.append(samples[position])
    
# Totally confounded pools: disease status and age group coincide completely
# total_1: pediatric disease samples together with adult control samples
# total_2: adult disease samples together with pediatric control samples

length_total = 500 # Number of positions drawn per pool (two samples per position)
sampling_indices = np.random.choice(range(900), length_total, replace=False) # Positions for total_1
total_1 = []
for position in sampling_indices:
    total_1 += [disease_p[position], control_a[position]]
sampling_indices = np.random.choice(range(900), length_total, replace=False) # Fresh positions for total_2
total_2 = []
for position in sampling_indices:
    total_2 += [disease_a[position], control_p[position]]
    
# Lightly confounded pools, keyed '45-5', '40-10', ..., '5-45'. Each pool holds
# `length` (pediatric disease, adult control) pairs followed by 500 - `length`
# (adult disease, pediatric control) pairs; `length` shrinks by 50 per step.
light_confounded_data = {}
for step in range(1, 10):
    length = int(500 - 50*step) # Pairs of the first kind
    remainder = 500 - length # Pairs of the second kind
    sampling_indices_1 = np.random.choice(range(900), length, replace=False) # Positions for the first kind
    sampling_indices_2 = np.random.choice(range(900), remainder, replace=False) # Positions for the second kind
    mixture = []
    for position in sampling_indices_1:
        mixture += [disease_p[position], control_a[position]]
    for position in sampling_indices_2:
        mixture += [disease_a[position], control_p[position]]
    light_confounded_data[f'{int(50-5*step)}-{int(5*step)}'] = mixture
    
# Randomly split every pool into a training part and a validation (internal testing) part
total_1_train, total_1_val = torch.utils.data.random_split(total_1, [800, 200])
total_2_train, total_2_val = torch.utils.data.random_split(total_2, [800, 200])
balanced_train, balanced_val = torch.utils.data.random_split(balanced, [1600, 400])
# The light-confounding splits are stored under '<key>_train' and '<key>_val'
light_confounded_train_val = {}
for key, pool in light_confounded_data.items():
    train_part, val_part = torch.utils.data.random_split(pool, [800, 200])
    light_confounded_train_val[f'{key}_train'] = train_part
    light_confounded_train_val[f'{key}_val'] = val_part
# Write both dictionaries to the working directory
torch.save(light_confounded_data, "light_confounded_data.pt")
torch.save(light_confounded_train_val, "light_confounded_train_val.pt")

In [None]:
# Show four samples of the balanced training split (positions 5 to 8) in one row
plt.figure(figsize=(20, 20))
for panel, sample_index in enumerate(range(5, 9), start=1):
    sample = balanced_train[sample_index] # (image tensor, disease label, confounder label)
    image = transforms.ToPILImage()(sample[0]) # Convert the tensor to a PIL image
    plt.subplot(1, 4, panel)
    plt.axis('off')
    if sample[2] == 0:
        age = 'Child'
    else:
        age = 'Adult'
    if sample[1] == 0:
        disease = 'Control'
    else:
        disease = 'Disease'
    plt.title(age + ' - ' + disease) # Title shows age group and disease status
    plt.imshow(image)
plt.savefig('cxr_sample.pdf') # Write the figure to disk
plt.show()

Replication of results from Garcia Santa Cruz, Husch, Hertel (2022):

In [None]:
def evaluate(model, valloader, device):
    """Evaluate `model` on every batch of `valloader`.

    The model is switched to eval mode and an AUROC metric is created
    (multiclass with macro averaging when `model.num_classes > 1`, binary
    otherwise). The metric is handed to `model.validation_step` for each batch
    and the collected step outputs are passed to `model.validation_epoch_end`.

    Gradient tracking is disabled for the whole evaluation so the stored step
    outputs do not keep autograd graphs alive.

    Args:
        model: model exposing `num_classes`, `validation_step(batch, metric,
            device)` and `validation_epoch_end(outputs)`.
        valloader: iterable of batches.
        device: device forwarded to `validation_step`.

    Returns:
        Whatever `model.validation_epoch_end` returns (elsewhere in this file
        it is treated as a dict containing 'val_auc').
    """
    model.eval()
    if model.num_classes > 1:
        metric = MulticlassAUROC(num_classes=model.num_classes, average='macro', thresholds=None) # Initialize AUROC metric
    else:
        metric = BinaryAUROC() # Initialize AUROC metric
    with torch.no_grad():
        outputs = [model.validation_step(batch, metric, device) for batch in valloader]
        return model.validation_epoch_end(outputs)

def fit(model, optimizer, scheduler, trainloader, valloader, epochs, device, print_results=True):
    """Train `model` for `epochs` epochs and evaluate it after each epoch.

    For every batch the gradients are reset, `model.training_step` provides the
    loss, and the optimizer and the scheduler are both stepped (the scheduler
    is stepped once per batch, not once per epoch). Models whose `name` is
    'SSN' additionally get `model.set_delta(trainloader, device)` called after
    every batch.

    Args:
        model: model exposing `training_step(batch, device)` (returning a pair
            whose second entry is the loss), `name`, `epoch_end(epoch, result)`
            and the interface required by `evaluate`.
        optimizer: optimizer for the model parameters.
        scheduler: learning-rate scheduler, stepped per batch.
        trainloader: iterable of training batches.
        valloader: iterable of validation batches passed to `evaluate`.
        epochs: number of epochs.
        device: device forwarded to the model methods.
        print_results: if True, `model.epoch_end` is called after each epoch.

    Returns:
        List with one result per epoch: the output of `evaluate` extended by
        'train_loss', the mean training loss of the epoch (NaN if the
        trainloader yielded no batches).
    """
    torch.cuda.empty_cache() # Empty GPU cache
    train_history = []
    for epoch in range(epochs):
        model.train()
        train_losses = []
        for batch in trainloader:
            optimizer.zero_grad()
            _, loss = model.training_step(batch, device)
            loss.backward()
            # Only the value is needed for the epoch average; storing the
            # detached tensor avoids holding on to the autograd graph.
            train_losses.append(loss.detach())
            optimizer.step()
            scheduler.step()
            if model.name == 'SSN':
                model.set_delta(trainloader, device)
            del batch
            torch.cuda.empty_cache()
        model.eval()
        result = evaluate(model, valloader, device)
        # torch.stack raises on an empty list, so guard the no-batch case
        if train_losses:
            result['train_loss'] = torch.stack(train_losses).mean().item()
        else:
            result['train_loss'] = float('nan')
        if print_results:
            model.epoch_end(epoch, result)
        train_history.append(result)
    return train_history

In [None]:
# NOTE: The triple-quoted string below is a bare string literal, so none of the
# code inside it is executed. It holds an inline version of the training and
# evaluation loop over the light-confounding levels. The next cell launches
# run_experiment.py instead (presumably the script version of this loop —
# confirm against run_experiment.py).
'''
train_results = {'Baseline': {}, 'SSN_1': {}, 'SSN_2': {}} # Initialize the dictionary
test_balanced_results = {'Baseline': {}, 'SSN_1': {}, 'SSN_2': {}} # Initialize the dictionary
test_inverse_results = {'Baseline': {}, 'SSN_1': {}, 'SSN_2': {}} # Initialize the dictionary
keys = [f'{int(50-5*step)}-{int(5*step)}' for step in range(1,10)] # Initialize the keys
for key in keys:
    train_results['Baseline'][key] = {} # Initialize the dictionary
    train_results['SSN_1'][key] = {} # Initialize the dictionary
    train_results['SSN_2'][key] = {} # Initialize the dictionary
    test_balanced_results['Baseline'][key] = {} # Initialize the dictionary
    test_balanced_results['SSN_1'][key] = {} # Initialize the dictionary
    test_balanced_results['SSN_2'][key] = {} # Initialize the dictionary
    test_inverse_results['Baseline'][key] = {} # Initialize the dictionary
    test_inverse_results['SSN_1'][key] = {} # Initialize the dictionary
    test_inverse_results['SSN_2'][key] = {} # Initialize the dictionary
    for run in range(10):
        trainloader = DataLoader(light_confounded_train_val[key + '_train'], batch_size=50, shuffle=True) # Get the trainloader
        valloader = DataLoader(light_confounded_train_val[key + '_val'], batch_size=50, shuffle=True) # Get the valloader
        Model_Baseline = Baseline(num_classes=2).to(device)
        optimizer_Baseline = optim.Adam(Model_Baseline.parameters(), lr=0.001, weight_decay=0.0001) # Create optimizer
        scheduler_Baseline = torch.optim.lr_scheduler.OneCycleLR(optimizer_Baseline, max_lr=0.01, epochs=15, steps_per_epoch=len(trainloader))
        Model_SSN_1 = SemiStructuredNet(batch_size=50, cf_dim=1, num_classes=2, num_features=128).to(device)
        optimizer_SSN_1 = optim.Adam(Model_SSN_1.parameters(), lr=0.001, weight_decay=0.0001) # Create optimizer
        scheduler_SSN_1 = torch.optim.lr_scheduler.OneCycleLR(optimizer_SSN_1, max_lr=0.01, epochs=15, steps_per_epoch=len(trainloader))
        Model_SSN_2 = SemiStructuredNet(batch_size=50, cf_dim=2, num_classes=2, num_features=128).to(device)
        optimizer_SSN_2 = optim.Adam(Model_SSN_2.parameters(), lr=0.001, weight_decay=0.0001) # Create optimizer
        scheduler_SSN_2 = torch.optim.lr_scheduler.OneCycleLR(optimizer_SSN_2, max_lr=0.01, epochs=15, steps_per_epoch=len(trainloader))
        print(f'Light confounding: {key}')
        # Train the baseline model
        print('Baseline')
        train_results['Baseline'][key][run] = fit(Model_Baseline, optimizer_Baseline, scheduler_Baseline, trainloader, valloader, epochs = 15, device=device, print_results=False)
        # Train the SSN models
        print('SSN_1')
        train_results['SSN_1'][key][run] = fit(Model_SSN_1, optimizer_SSN_1, scheduler_SSN_1, trainloader, valloader, epochs = 15, device=device, print_results=False)
        print('SSN_2')
        train_results['SSN_2'][key][run] = fit(Model_SSN_2, optimizer_SSN_2, scheduler_SSN_2, trainloader, valloader, epochs = 15, device=device, print_results=False)
        # Test the models
        print('Test Balanced')
        testloader = DataLoader(testdata, batch_size=50, shuffle=True) # Get the testloader
        test_balanced_results['Baseline'][key][run] = evaluate(Model_Baseline, testloader, device)['val_auc']
        test_balanced_results['SSN_1'][key][run] = evaluate(Model_SSN_1, testloader, device)['val_auc']
        test_balanced_results['SSN_2'][key][run] = evaluate(Model_SSN_2, testloader, device)['val_auc']
        print('Test Inverse')
        testloader = DataLoader(light_confounded_train_val[key.split('-')[1] + '-' + key.split('-')[0] + '_val'], batch_size=50, shuffle=True) # Get the testloader
        test_inverse_results['Baseline'][key][run] = evaluate(Model_Baseline, testloader, device)['val_auc']
        test_inverse_results['SSN_1'][key][run] = evaluate(Model_SSN_1, testloader, device)['val_auc']
        test_inverse_results['SSN_2'][key][run] = evaluate(Model_SSN_2, testloader, device)['val_auc']
        # Save fitted models
        torch.save(Model_Baseline.state_dict(), f"fitted_models/Baseline_{key}_{run}.pt")
        torch.save(Model_SSN_1.state_dict(), f"fitted_models/SSN_1_{key}_{run}.pt")
        torch.save(Model_SSN_2.state_dict(), f"fitted_models/SSN_2_{key}_{run}.pt")
torch.save(train_results, "train_results.pt") # Save the training results
torch.save(test_balanced_results, "test_balanced_results.pt") # Save the test results
torch.save(test_inverse_results, "test_inverse_results.pt") # Save the test results
'''

In [None]:
# Run experiment
import subprocess
import sys

# Start the script with the interpreter that runs this notebook, so it uses the
# same environment. check=True raises CalledProcessError when the script exits
# with a non-zero status; os.system only returned the status, which was ignored.
subprocess.run([sys.executable, 'run_experiment.py'], check=True)

In [None]:
# Read the three saved result objects (training histories, balanced-test AUCs
# and inverse-test AUCs) from the working directory
train_results, test_balanced_results, test_inverse_results = (
    torch.load(result_file)
    for result_file in ("train_results.pt", "test_balanced_results.pt", "test_inverse_results.pt")
)

In [None]:
# Ignore warnings for plotting
import warnings
warnings.filterwarnings("ignore")


def _collect_auc_frame(results, models, cf_degrees, pick_auc, n_runs=10):
    """Return a long-format DataFrame with columns Model / Confounding Degree / AUC.

    `results[model][cf_degree][run]` is passed through `pick_auc` to obtain the
    AUC of that run. The per-group frames are combined with a single pd.concat
    because DataFrame.append is no longer available in pandas 2.x.
    """
    frames = [
        pd.DataFrame({
            'Model': [model] * n_runs,
            'Confounding Degree': [cf_degree] * n_runs,
            'AUC': [pick_auc(results[model][cf_degree][run]) for run in range(n_runs)],
        })
        for model in models
        for cf_degree in cf_degrees
    ]
    return pd.concat(frames, ignore_index=True)


# Make boxplots of results for each confounding level
fig, ax = plt.subplots(3, 1, figsize=(8, 15)) # Initialize the figure
Models = ['Baseline', 'SSN_1', 'SSN_2'] # Initialize the models vector
Cf_degrees = ['25-25', '30-20', '20-30', '35-15', '15-35', '40-10', '10-40', '45-5', '5-45'] # Initialize the confounding degrees vector

# Internal validation AUC: taken from the last entry of each run's training history
plot_train_results = _collect_auc_frame(train_results, Models, Cf_degrees, lambda history: history[-1]['val_auc'])
sns.boxplot(ax=ax[0], x='Confounding Degree', y='AUC', hue='Model', data=plot_train_results, palette='tab10')
ax[0].set_title('AUC on (Internal) Validation Set')

# AUC on the balanced test set: stored directly per run
plot_balanced_results = _collect_auc_frame(test_balanced_results, Models, Cf_degrees, lambda auc: auc)
sns.boxplot(ax=ax[1], x='Confounding Degree', y='AUC', hue='Model', data=plot_balanced_results, palette='tab10')
ax[1].set_title('AUC on (External) Balanced Test Set')

# AUC on the "inverse" validation set: stored directly per run
plot_inverse_results = _collect_auc_frame(test_inverse_results, Models, Cf_degrees, lambda auc: auc)
sns.boxplot(ax=ax[2], x='Confounding Degree', y='AUC', hue='Model', data=plot_inverse_results, palette='tab10')
ax[2].set_title('AUC on "Inverse" Validation Set')

fig.tight_layout(pad=3.0)
plt.savefig('results_AUC_1.pdf') # Save the plot

In [None]:
# Make boxplots of the absolute AUC differences per confounding level:
# right panel: balanced test set vs. inverse validation set,
# left panel: balanced test set vs. internal validation set.
# The frames are built in one go from row records (model -> degree -> run order)
# because DataFrame.append is no longer available in pandas 2.x.
fig, ax = plt.subplots(1, 2, figsize=(15, 5)) # Initialize the figure
plot_results = pd.DataFrame(
    [
        {
            'Model': model,
            'Confounding Degree': cf_degree,
            'AUC': np.abs(test_balanced_results[model][cf_degree][i] - test_inverse_results[model][cf_degree][i]),
        }
        for model in Models
        for cf_degree in Cf_degrees
        for i in range(10)
    ],
    columns=['Model', 'Confounding Degree', 'AUC'],
)
sns.boxplot(ax=ax[1], x='Confounding Degree', y='AUC', hue='Model', data=plot_results, palette='tab10')
ax[1].set_title('Balanced Test Set vs. "Inverse" Validation Set')
ax[1].set_ylabel('Absolute AUC difference') # The 'AUC' column holds differences here, not AUCs
plot_results = pd.DataFrame(
    [
        {
            'Model': model,
            'Confounding Degree': cf_degree,
            'AUC': np.abs(test_balanced_results[model][cf_degree][i] - train_results[model][cf_degree][i][-1]['val_auc']),
        }
        for model in Models
        for cf_degree in Cf_degrees
        for i in range(10)
    ],
    columns=['Model', 'Confounding Degree', 'AUC'],
)
sns.boxplot(ax=ax[0], x='Confounding Degree', y='AUC', hue='Model', data=plot_results, palette='tab10')
ax[0].set_title('Balanced Test Set vs. Validation Set')
ax[0].set_ylabel('Absolute AUC difference') # The 'AUC' column holds differences here, not AUCs
plt.savefig('results_AUC_2.pdf') # Save the plot

Appendix (old versions of plots)

In [None]:
# Average the AUCs over the 10 runs for every model and every confounding
# level. The levels are visited in the order '45-5', '40-10', ..., '5-45'.
model_names = ('Baseline', 'SSN_1', 'SSN_2')
internal_aucs = {name: [] for name in model_names}
external_balanced_aucs = {name: [] for name in model_names}
external_inverse_aucs = {name: [] for name in model_names}
for step in range(1, 10):
    degree = f'{int(50-5*step)}-{int(5*step)}'
    for name in model_names:
        # Internal AUC comes from the last entry of each run's training history
        internal_aucs[name].append(np.mean([train_results[name][degree][run][-1]['val_auc'] for run in range(10)]))
        external_balanced_aucs[name].append(np.mean([test_balanced_results[name][degree][run] for run in range(10)]))
        external_inverse_aucs[name].append(np.mean([test_inverse_results[name][degree][run] for run in range(10)]))

# Split each series around the middle entry (position 4, the '25-25' level):
# *_1 holds positions 0-3 in reversed order, *_2 holds positions 5-8 as they are.
internal_auc_Baseline_1 = internal_aucs['Baseline'][:4][::-1]
internal_auc_Baseline_2 = internal_aucs['Baseline'][5:]
external_balanced_auc_Baseline_1 = external_balanced_aucs['Baseline'][:4][::-1]
external_balanced_auc_Baseline_2 = external_balanced_aucs['Baseline'][5:]
external_inverse_auc_Baseline_1 = external_inverse_aucs['Baseline'][:4][::-1]
external_inverse_auc_Baseline_2 = external_inverse_aucs['Baseline'][5:]
internal_auc_SSN_1_1 = internal_aucs['SSN_1'][:4][::-1]
internal_auc_SSN_1_2 = internal_aucs['SSN_1'][5:]
external_balanced_auc_SSN_1_1 = external_balanced_aucs['SSN_1'][:4][::-1]
external_balanced_auc_SSN_1_2 = external_balanced_aucs['SSN_1'][5:]
external_inverse_auc_SSN_1_1 = external_inverse_aucs['SSN_1'][:4][::-1]
external_inverse_auc_SSN_1_2 = external_inverse_aucs['SSN_1'][5:]
internal_auc_SSN_2_1 = internal_aucs['SSN_2'][:4][::-1]
internal_auc_SSN_2_2 = internal_aucs['SSN_2'][5:]
external_balanced_auc_SSN_2_1 = external_balanced_aucs['SSN_2'][:4][::-1]
external_balanced_auc_SSN_2_2 = external_balanced_aucs['SSN_2'][5:]
external_inverse_auc_SSN_2_1 = external_inverse_aucs['SSN_2'][:4][::-1]
external_inverse_auc_SSN_2_2 = external_inverse_aucs['SSN_2'][5:]

In [None]:
# Every final series starts with the mean AUC of the '25-25' level over the 10 runs
runs = range(10)
internal_auc_Baseline = [np.mean([train_results['Baseline']['25-25'][run][-1]['val_auc'] for run in runs])]
internal_auc_SSN_1 = [np.mean([train_results['SSN_1']['25-25'][run][-1]['val_auc'] for run in runs])]
internal_auc_SSN_2 = [np.mean([train_results['SSN_2']['25-25'][run][-1]['val_auc'] for run in runs])]
external_balanced_auc_Baseline = [np.mean([test_balanced_results['Baseline']['25-25'][run] for run in runs])]
external_balanced_auc_SSN_1 = [np.mean([test_balanced_results['SSN_1']['25-25'][run] for run in runs])]
external_balanced_auc_SSN_2 = [np.mean([test_balanced_results['SSN_2']['25-25'][run] for run in runs])]
external_inverse_auc_Baseline = [np.mean([test_inverse_results['Baseline']['25-25'][run] for run in runs])]
external_inverse_auc_SSN_1 = [np.mean([test_inverse_results['SSN_1']['25-25'][run] for run in runs])]
external_inverse_auc_SSN_2 = [np.mean([test_inverse_results['SSN_2']['25-25'][run] for run in runs])]

# Then the two half series are interleaved: for each of the four positions,
# first the entry of the *_1 list and then the entry of the *_2 list is appended
interleave_plan = [
    (internal_auc_Baseline, internal_auc_Baseline_1, internal_auc_Baseline_2),
    (internal_auc_SSN_1, internal_auc_SSN_1_1, internal_auc_SSN_1_2),
    (internal_auc_SSN_2, internal_auc_SSN_2_1, internal_auc_SSN_2_2),
    (external_balanced_auc_Baseline, external_balanced_auc_Baseline_1, external_balanced_auc_Baseline_2),
    (external_balanced_auc_SSN_1, external_balanced_auc_SSN_1_1, external_balanced_auc_SSN_1_2),
    (external_balanced_auc_SSN_2, external_balanced_auc_SSN_2_1, external_balanced_auc_SSN_2_2),
    (external_inverse_auc_Baseline, external_inverse_auc_Baseline_1, external_inverse_auc_Baseline_2),
    (external_inverse_auc_SSN_1, external_inverse_auc_SSN_1_1, external_inverse_auc_SSN_1_2),
    (external_inverse_auc_SSN_2, external_inverse_auc_SSN_2_1, external_inverse_auc_SSN_2_2),
]
for position in range(4):
    for merged, first_half, second_half in interleave_plan:
        merged.append(first_half[position])
        merged.append(second_half[position])

In [None]:
# Line plots of the mean AUCs per confounding level, one panel per evaluation set
fig, ax = plt.subplots(1,3, figsize=(20,5))
panels = [
    ('AUC on (Internal) Validation Set', internal_auc_Baseline, internal_auc_SSN_1, internal_auc_SSN_2),
    ('AUC on (External) Balanced Test Set', external_balanced_auc_Baseline, external_balanced_auc_SSN_1, external_balanced_auc_SSN_2),
    ('AUC on "inverted" Validation Set', external_inverse_auc_Baseline, external_inverse_auc_SSN_1, external_inverse_auc_SSN_2),
]
for axis, (title, baseline_curve, ssn_1_curve, ssn_2_curve) in zip(ax, panels):
    axis.plot(range(1,10), baseline_curve, '-b', label='Baseline')
    axis.plot(range(1,10), ssn_1_curve, '-r', label='SSN (no constant)')
    axis.plot(range(1,10), ssn_2_curve, '-g', label='SSN (with constant)')
    axis.set_title(title)
    # Tick positions 1..9 are labelled with the confounding levels
    axis.set_xticks(range(1,10), ['25-25', '30-20', '20-30', '35-15', '15-35', '40-10', '10-40', '45-5', '5-45'])
    axis.set_ylim(0.5, 1)
    axis.set_xlabel('Degree of confounding')
    axis.set_ylabel('AUC')
    axis.legend()
plt.savefig('AUCs.pdf')

In [None]:
# Plot the absolute difference between the mean internal validation AUC and
# (left) the mean AUC on the balanced test set, (right) the mean AUC on the
# inverse validation set. Each panel gets its own title; previously the right
# panel also carried the "balanced test set" title.
fig, ax = plt.subplots(1,2, figsize=(20,7))
comparisons = [
    (
        'Absolute difference between AUC on internal validation set and external (balanced) test set',
        external_balanced_auc_Baseline, external_balanced_auc_SSN_1, external_balanced_auc_SSN_2,
    ),
    (
        'Absolute difference between AUC on internal validation set and "inverted" validation set',
        external_inverse_auc_Baseline, external_inverse_auc_SSN_1, external_inverse_auc_SSN_2,
    ),
]
for axis, (title, other_Baseline, other_SSN_1, other_SSN_2) in zip(ax, comparisons):
    axis.plot(range(1,10), np.abs(np.array(internal_auc_Baseline) - np.array(other_Baseline)), '-b', label='Baseline')
    axis.plot(range(1,10), np.abs(np.array(internal_auc_SSN_1) - np.array(other_SSN_1)), '-r', label='SSN (no constant)')
    axis.plot(range(1,10), np.abs(np.array(internal_auc_SSN_2) - np.array(other_SSN_2)), '-g', label='SSN (with constant)')
    # Tick positions 1..9 are labelled with the confounding levels
    axis.set_xticks(range(1,10), ['25-25', '30-20', '20-30', '35-15', '15-35', '40-10', '10-40', '45-5', '5-45'])
    axis.legend()
    axis.set_xlabel('Degree of confounding')
    axis.set_ylabel('Absolute difference')
    axis.set_title(title)
plt.savefig('AUC_difference_balanced.pdf')