# Emergence of Latent Binary Encoding in Deep Neural Network Classifiers
---

### Supplementary material to reproduce results

All plots shown in the manuscript can be made using this notebook. Results are loaded from three 'best_results.pkl' files, which contain the data for the SVHN, CIFAR10 and CIFAR100 datasets, respectively.
Each experiment consists of running multiple trainings with different learning rates; the training with the best accuracy on the training set at the last epoch is selected. Experiments can be run on a GPU cluster using the Slurm script that we provide in 'emergence_binary_encoding/scripts.sh'. The best results for each experiment can then be selected using the Python script that we provide in 'emergence_binary_encoding/find_best_results.py'.

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import matplotlib
import pickle
import yaml
import glob
import os

In [None]:
# Root folder holding one sub-folder per dataset, each with a 'best_results.pkl'.
results_dir = '../results'


def _read_pickle(path):
    """Return the object stored in the pickle file at *path*."""
    # NOTE: pickle.load can execute arbitrary code while unpickling; only open
    # result files that you produced yourself or otherwise trust.
    with open(path, 'rb') as handle:
        return pickle.load(handle)


results_svhn = _read_pickle(results_dir + '/svhn/best_results.pkl')
results_10 = _read_pickle(results_dir + '/cifar10/best_results.pkl')
results_100 = _read_pickle(results_dir + '/cifar100/best_results.pkl')


In [None]:
# Training schedule used for the x-axis of the plots, read from the first entry
# of the CIFAR10 'ib' hyper-parameters (presumably identical across models and
# datasets -- confirm before changing the reference experiment).
_ib_training_hypers = results_10['ib']['training_hypers']
epochs = _ib_training_hypers['epochs'][0]
logging = _ib_training_hypers['logging'][0]

In [None]:
# Opacity of the shaded +/- std bands drawn with fill_between.
alpha = 0.3

# Global seaborn theme for every figure in the notebook.
sns.set(style="whitegrid")

In [None]:
# Keys of the result dictionaries, in the order used for the printed tables.
models = [
    'ib',
    'wide_ib',
    'narrow_ib',
    'no_pen',
    'dropout_no_pen',
    'lin_pen',
    'dropout_lin_pen',
    'nonlin_pen',
]

In [None]:
# One fixed colour per model so that every figure uses the same coding.
# The hex codes are those of matplotlib's default 'tab10' palette.
color_ib, color_ib_wide, color_ib_narrow = '#1f77b4', '#8c564b', '#7f7f7f'
color_no, color_dropout_no = '#ff7f0e', '#9467bd'
color_lin, color_dropout_lin = '#2ca02c', '#e377c2'
color_nonlin = '#d62728'

# Font sizes (in points) for axis labels/titles and for legends.
label_size, legend_size = 18, 12

### Table 1 - Sigma_W

In [None]:
# Table 1: for every dataset and model, print the mean and std of the last
# 'sigma_w' entry of 'collapse_train' (presumably the final logged epoch).
for header, dataset_results in (
    ('CIFAR100', results_100),
    ('\nCIFAR10', results_10),
    ('\nSVHN', results_svhn),
):
    print(header)
    for model in models:
        last_sigma_w = dataset_results[model]['collapse_train']['sigma_w'][-1]
        print(model, ': ', np.mean(last_sigma_w), np.std(last_sigma_w))


### Table 1 - Coeff. of Var.

In [None]:
# Table 1: mean and std (rounded to 3 decimals) of 'coeff_var_train_tpt'
# for every dataset and model.
for header, dataset_results in (
    ('CIFAR100', results_100),
    ('\nCIFAR10', results_10),
    ('\nSVHN', results_svhn),
):
    print(header)
    for model in models:
        coeff_var = dataset_results[model]['coeff_var_train_tpt']
        print(model, ': ', np.around(np.mean(coeff_var),3), np.around(np.std(coeff_var),3))


### Table 1 - Entropy

In [None]:
# Table 1: mean and std (rounded to 3 decimals) of 'entropy_train_tpt'
# for every dataset and model.
for header, dataset_results in (
    ('CIFAR100', results_100),
    ('\nCIFAR10', results_10),
    ('\nSVHN', results_svhn),
):
    print(header)
    for model in models:
        entropy = dataset_results[model]['entropy_train_tpt']
        print(model, ': ', np.around(np.mean(entropy),3), np.around(np.std(entropy),3))


### Table 2 - Robustness

In [None]:
# Table 2: mean and std (rounded to 3 decimals) of 'deepfool_score_tpt'
# for every dataset and model.
for header, dataset_results in (
    ('CIFAR100', results_100),
    ('\nCIFAR10', results_10),
    ('\nSVHN', results_svhn),
):
    print(header)
    for model in models:
        deepfool = dataset_results[model]['deepfool_score_tpt']
        print(model, ': ', np.around(np.mean(deepfool),3), np.around(np.std(deepfool),3))

### Table 2 - Accuracy

In [None]:
# Table 2: test accuracy in percent (rounded to 2 decimals). The mean and std
# are taken over axis 1 of 'accuracy_test' and the last row is reported
# (presumably rows are logged epochs and columns are runs -- confirm).
for header, dataset_results in (
    ('CIFAR100', results_100),
    ('\nCIFAR10', results_10),
    ('\nSVHN', results_svhn),
):
    print(header)
    for model in models:
        accuracy = dataset_results[model]['accuracy_test']
        final_mean = 100*np.mean(accuracy, axis=1)[-1]
        final_std = 100*np.std(accuracy, axis=1)[-1]
        print(model, ':', np.around(final_mean,2), np.around(final_std,2))

### Table 2 - Odin AUROC

In [None]:
# Table 2: mean and std (rounded to 2 decimals) of the 'auroc' entry of
# 'odin_score_tpt' for every dataset and model.
for header, dataset_results in (
    ('CIFAR100', results_100),
    ('\nCIFAR10', results_10),
    ('\nSVHN', results_svhn),
):
    print(header)
    for model in models:
        auroc = dataset_results[model]['odin_score_tpt']['auroc']
        print(model, ': ', np.around(np.mean(auroc),2), np.around(np.std(auroc),2))

### Table 2 - Mahalanobis AUROC

In [None]:
# Table 2: mean and std (rounded to 2 decimals) of the 'auroc' entry of
# 'mahalanobis_score_tpt' for every dataset and model.
for header, dataset_results in (
    ('CIFAR100', results_100),
    ('\nCIFAR10', results_10),
    ('\nSVHN', results_svhn),
):
    print(header)
    for model in models:
        auroc = dataset_results[model]['mahalanobis_score_tpt']['auroc']
        print(model, ': ', np.around(np.mean(auroc),2), np.around(np.std(auroc),2))

### Figure 2

In [None]:
# Figure 2: 'binarity_train' statistics plotted against the epoch.
# Rows are the three stored statistics (score, stds, peaks_distance_mean);
# columns are the datasets (SVHN, CIFAR 10, CIFAR 100). Every curve is the
# mean over axis 1 of the stored array with a shaded +/- one-std band
# (axis 1 is presumably the repeated-runs axis -- confirm against the script
# that writes best_results.pkl).
fig, axes = plt.subplots(3, 3, figsize=(10, 5), sharex=True)

# Short colour aliases; kept at module level in case other cells reuse them.
color_0 = color_ib
color_1 = color_lin
color_2 = color_ib_wide
color_3 = color_dropout_lin
color_4 = color_ib_narrow

# x positions: logging, 2*logging, ... up to epochs (axis is labelled 'Epochs').
x = np.arange(logging, epochs+logging,logging)

# (results key, legend label, colour, extra lineplot kwargs), in draw order.
# The same order and labels are used in every panel, so the single legend
# built from axes[2][1] below describes all nine panels correctly.
fig2_curves = [
    ('ib', 'IB', color_0, {}),
    ('lin_pen', 'LinPen', color_1, {}),
    ('wide_ib', 'WideIB', color_2, {}),
    ('narrow_ib', 'NarrowIB', color_4, {}),
    ('dropout_lin_pen', 'DropoutLinPen', color_3, {'linestyle': '--'}),
]
fig2_metrics = ['score', 'stds', 'peaks_distance_mean']   # one per row
fig2_datasets = [results_svhn, results_10, results_100]   # one per column

for col, dataset_results in enumerate(fig2_datasets):
    for row, metric in enumerate(fig2_metrics):
        panel = axes[row][col]
        for model_key, curve_label, curve_color, line_kwargs in fig2_curves:
            runs = dataset_results[model_key]['binarity_train'][metric]
            y = np.mean(runs, axis=1)
            std = np.std(runs, axis=1)
            sns.lineplot(x=x, y=y, marker='o', ax=panel, color=curve_color,
                         label=curve_label, **line_kwargs)
            panel.fill_between(x, y - std, y + std, alpha=alpha, color=curve_color)

# Column titles on the top row, x-labels on the bottom row.
for col, title in enumerate(['SVHN', 'CIFAR 10', 'CIFAR 100']):
    axes[0][col].set_title(title, size=label_size)
    axes[2][col].set_xlabel('Epochs', size=label_size)

# Row labels on the left column only.
for row, ylabel in enumerate(['Score', 'Std', 'Peaks distance']):
    axes[row][0].set_ylabel(ylabel, size=label_size)

# Hand-tuned y-ranges, indexed as fig2_ylims[row][col].
fig2_ylims = [
    [(-1.5, 2.5), (-1.5, 1.5), (-1.5, 1.5)],
    [(0, 1), (0, 1), (0, 1)],
    [(0, 15), (0, 60), (0, 60)],
]
for row, row_limits in enumerate(fig2_ylims):
    for col, (y_low, y_high) in enumerate(row_limits):
        axes[row][col].set_ylim(y_low, y_high)

# seaborn adds a legend to each panel that has labelled lines; drop them all
# so that only the shared legend below remains. get_legend() returns None for
# a panel without a legend, hence the guard.
for ax in axes.flat:
    panel_legend = ax.get_legend()
    if panel_legend is not None:
        panel_legend.remove()

plt.tight_layout()

# Single shared legend, placed below the bottom-middle panel.
axes[2][1].legend(loc='upper center', bbox_to_anchor=(0.5, -0.45), shadow=True, ncol=5)

plt.savefig('./score.png')

### Figure 3

In [None]:
# Mean 'convergence_epoch' of every model, one dictionary per dataset.
convergence_svhn, convergence_10, convergence_100 = {}, {}, {}

_convergence_pairs = (
    (convergence_svhn, results_svhn),
    (convergence_10, results_10),
    (convergence_100, results_100),
)
for model in models:
    for _target, _source in _convergence_pairs:
        _target[model] = np.mean(_source[model]['convergence_epoch'])

In [None]:
x = np.arange(logging, epochs+logging,logging)

fig, axes = plt.subplots(4, 3, figsize=(14, 10), sharex=True)



y = np.mean(results_svhn['ib']['collapse_train']['within_class_variation_weighted'], axis=1)
std = np.std(results_svhn['ib']['collapse_train']['within_class_variation_weighted'], axis=1)
sns.lineplot(x=x, y=y, marker='o',label='IB', ax=axes[0][0],color=color_ib)
axes[0][0].fill_between(x, y - std, y + std, alpha=alpha,color=color_ib)

y = np.mean(results_svhn['wide_ib']['collapse_train']['within_class_variation_weighted'], axis=1)
std = np.std(results_svhn['wide_ib']['collapse_train']['within_class_variation_weighted'], axis=1)
sns.lineplot(x=x, y=y, marker='o',label='IBWide', ax=axes[0][0],color=color_ib_wide)
axes[0][0].fill_between(x, y - std, y + std, alpha=alpha,color=color_ib_wide)

y = np.mean(results_svhn['narrow_ib']['collapse_train']['within_class_variation_weighted'], axis=1)
std = np.std(results_svhn['narrow_ib']['collapse_train']['within_class_variation_weighted'], axis=1)
sns.lineplot(x=x, y=y, marker='o',label='IBNarrow', ax=axes[0][0],color=color_ib_narrow)
axes[0][0].fill_between(x, y - std, y + std, alpha=alpha,color=color_ib_narrow)

y = np.mean(results_svhn['no_pen']['collapse_train']['within_class_variation_weighted'], axis=1)
std = np.std(results_svhn['no_pen']['collapse_train']['within_class_variation_weighted'], axis=1)
sns.lineplot(x=x, y=y, marker='o',label='NoPen', ax=axes[0][0],color=color_no)
axes[0][0].fill_between(x, y - std, y + std, alpha=alpha,color=color_no)

y = np.mean(results_svhn['dropout_no_pen']['collapse_train']['within_class_variation_weighted'], axis=1)
std = np.std(results_svhn['dropout_no_pen']['collapse_train']['within_class_variation_weighted'], axis=1)
sns.lineplot(x=x, y=y, marker='o',label='NoPenDropout', ax=axes[0][0],color=color_dropout_no)
axes[0][0].fill_between(x, y - std, y + std, alpha=alpha,color=color_dropout_no)

y = np.mean(results_svhn['lin_pen']['collapse_train']['within_class_variation_weighted'], axis=1)
std = np.std(results_svhn['lin_pen']['collapse_train']['within_class_variation_weighted'], axis=1)
sns.lineplot(x=x, y=y, marker='o',label='LinPen', ax=axes[0][0],color=color_lin)
axes[0][0].fill_between(x, y - std, y + std, alpha=alpha,color=color_lin)

y = np.mean(results_svhn['dropout_lin_pen']['collapse_train']['within_class_variation_weighted'], axis=1)
std = np.std(results_svhn['dropout_lin_pen']['collapse_train']['within_class_variation_weighted'], axis=1)
sns.lineplot(x=x, y=y, marker='o',label='LinPenDropout', ax=axes[0][0],color=color_dropout_lin)
axes[0][0].fill_between(x, y - std, y + std, alpha=alpha,color=color_dropout_lin)

y = np.mean(results_svhn['nonlin_pen']['collapse_train']['within_class_variation_weighted'], axis=1)
std = np.std(results_svhn['nonlin_pen']['collapse_train']['within_class_variation_weighted'], axis=1)
sns.lineplot(x=x, y=y, marker='o',label='NonlinPen', ax=axes[0][0],color=color_nonlin)
axes[0][0].fill_between(x, y - std, y + std, alpha=alpha,color=color_nonlin)



y = np.mean(results_10['ib']['collapse_train']['within_class_variation_weighted'], axis=1)
std = np.std(results_10['ib']['collapse_train']['within_class_variation_weighted'], axis=1)
sns.lineplot(x=x, y=y, marker='o',label='IB', ax=axes[0][1],color=color_ib)
axes[0][1].fill_between(x, y - std, y + std, alpha=alpha,color=color_ib)

y = np.mean(results_10['wide_ib']['collapse_train']['within_class_variation_weighted'], axis=1)
std = np.std(results_10['wide_ib']['collapse_train']['within_class_variation_weighted'], axis=1)
sns.lineplot(x=x, y=y, marker='o',label='IBWide', ax=axes[0][1],color=color_ib_wide)
axes[0][1].fill_between(x, y - std, y + std, alpha=alpha,color=color_ib_wide)

y = np.mean(results_10['narrow_ib']['collapse_train']['within_class_variation_weighted'], axis=1)
std = np.std(results_10['narrow_ib']['collapse_train']['within_class_variation_weighted'], axis=1)
sns.lineplot(x=x, y=y, marker='o',label='IBNarrow', ax=axes[0][1],color=color_ib_narrow)
axes[0][1].fill_between(x, y - std, y + std, alpha=alpha,color=color_ib_narrow)

y = np.mean(results_10['no_pen']['collapse_train']['within_class_variation_weighted'], axis=1)
std = np.std(results_10['no_pen']['collapse_train']['within_class_variation_weighted'], axis=1)
sns.lineplot(x=x, y=y, marker='o',label='NoPen', ax=axes[0][1],color=color_no)
axes[0][1].fill_between(x, y - std, y + std, alpha=alpha,color=color_no)

y = np.mean(results_10['dropout_no_pen']['collapse_train']['within_class_variation_weighted'], axis=1)
std = np.std(results_10['dropout_no_pen']['collapse_train']['within_class_variation_weighted'], axis=1)
sns.lineplot(x=x, y=y, marker='o',label='NoPenDropout', ax=axes[0][1],color=color_dropout_no)
axes[0][1].fill_between(x, y - std, y + std, alpha=alpha,color=color_dropout_no)

y = np.mean(results_10['lin_pen']['collapse_train']['within_class_variation_weighted'], axis=1)
std = np.std(results_10['lin_pen']['collapse_train']['within_class_variation_weighted'], axis=1)
sns.lineplot(x=x, y=y, marker='o',label='LinPen', ax=axes[0][1],color=color_lin)
axes[0][1].fill_between(x, y - std, y + std, alpha=alpha,color=color_lin)

# CIFAR10, 'dropout_lin_pen': mean curve with a shaded +/- one-std band.
# Labelled 'LinPenDropout' like the same model's curves in the SVHN and
# CIFAR100 panels; 'NonlinPen' belongs to the 'nonlin_pen' curve that follows.
y = np.mean(results_10['dropout_lin_pen']['collapse_train']['within_class_variation_weighted'], axis=1)
std = np.std(results_10['dropout_lin_pen']['collapse_train']['within_class_variation_weighted'], axis=1)
sns.lineplot(x=x, y=y, marker='o',label='LinPenDropout', ax=axes[0][1],color=color_dropout_lin)
axes[0][1].fill_between(x, y - std, y + std, alpha=alpha,color=color_dropout_lin)

y = np.mean(results_10['nonlin_pen']['collapse_train']['within_class_variation_weighted'], axis=1)
std = np.std(results_10['nonlin_pen']['collapse_train']['within_class_variation_weighted'], axis=1)
sns.lineplot(x=x, y=y, marker='o',label='NonlinPen', ax=axes[0][1],color=color_nonlin)
axes[0][1].fill_between(x, y - std, y + std, alpha=alpha,color=color_nonlin)


y = np.mean(results_100['ib']['collapse_train']['within_class_variation_weighted'], axis=1)
std = np.std(results_100['ib']['collapse_train']['within_class_variation_weighted'], axis=1)
sns.lineplot(x=x, y=y, marker='o',label='IB', ax=axes[0][2],color=color_ib)
axes[0][2].fill_between(x, y - std, y + std, alpha=alpha,color=color_ib)

y = np.mean(results_100['wide_ib']['collapse_train']['within_class_variation_weighted'], axis=1)
std = np.std(results_100['wide_ib']['collapse_train']['within_class_variation_weighted'], axis=1)
sns.lineplot(x=x, y=y, marker='o',label='IBWide', ax=axes[0][2],color=color_ib_wide)
axes[0][2].fill_between(x, y - std, y + std, alpha=alpha,color=color_ib_wide)

y = np.mean(results_100['narrow_ib']['collapse_train']['within_class_variation_weighted'], axis=1)
std = np.std(results_100['narrow_ib']['collapse_train']['within_class_variation_weighted'], axis=1)
sns.lineplot(x=x, y=y, marker='o',label='IBNarrow', ax=axes[0][2],color=color_ib_narrow)
axes[0][2].fill_between(x, y - std, y + std, alpha=alpha,color=color_ib_narrow)

y = np.mean(results_100['no_pen']['collapse_train']['within_class_variation_weighted'], axis=1)
std = np.std(results_100['no_pen']['collapse_train']['within_class_variation_weighted'], axis=1)
sns.lineplot(x=x, y=y, marker='o',label='NoPen', ax=axes[0][2],color=color_no)
axes[0][2].fill_between(x, y - std, y + std, alpha=alpha,color=color_no)

y = np.mean(results_100['dropout_no_pen']['collapse_train']['within_class_variation_weighted'], axis=1)
std = np.std(results_100['dropout_no_pen']['collapse_train']['within_class_variation_weighted'], axis=1)
sns.lineplot(x=x, y=y, marker='o',label='NoPenDropout', ax=axes[0][2],color=color_dropout_no)
axes[0][2].fill_between(x, y - std, y + std, alpha=alpha,color=color_dropout_no)

y = np.mean(results_100['lin_pen']['collapse_train']['within_class_variation_weighted'], axis=1)
std = np.std(results_100['lin_pen']['collapse_train']['within_class_variation_weighted'], axis=1)
sns.lineplot(x=x, y=y, marker='o',label='LinPen', ax=axes[0][2],color=color_lin)
axes[0][2].fill_between(x, y - std, y + std, alpha=alpha,color=color_lin)

y = np.mean(results_100['dropout_lin_pen']['collapse_train']['within_class_variation_weighted'], axis=1)
std = np.std(results_100['dropout_lin_pen']['collapse_train']['within_class_variation_weighted'], axis=1)
sns.lineplot(x=x, y=y, marker='o',label='LinPenDropout', ax=axes[0][2],color=color_dropout_lin)
axes[0][2].fill_between(x, y - std, y + std, alpha=alpha,color=color_dropout_lin)

y = np.mean(results_100['nonlin_pen']['collapse_train']['within_class_variation_weighted'], axis=1)
std = np.std(results_100['nonlin_pen']['collapse_train']['within_class_variation_weighted'], axis=1)
sns.lineplot(x=x, y=y, marker='o',label='NonlinPen', ax=axes[0][2],color=color_nonlin)
axes[0][2].fill_between(x, y - std, y + std, alpha=alpha,color=color_nonlin)



# Row 1 of the grid ('equiangular' metric): one panel per dataset, one curve per model.
# Each curve is the mean of the metric over axis 1; the shaded band around it is
# +/- one standard deviation over the same axis.
# NOTE(review): axis 1 presumably indexes the repeated runs of an experiment — confirm.
#
# The 24 hand-copied stanzas are replaced by a single data-driven loop so that the
# model key, legend label and colour of a series are stated exactly once.
equiangular_series = (
    # (results key, legend label, line/band colour) — order fixes the legend order
    ('ib', 'IB', color_ib),
    ('wide_ib', 'IBWide', color_ib_wide),
    ('narrow_ib', 'IBNarrow', color_ib_narrow),
    ('no_pen', 'NoPen', color_no),
    ('dropout_no_pen', 'NoPenDropout', color_dropout_no),
    ('lin_pen', 'LinPen', color_lin),
    ('dropout_lin_pen', 'LinPenDropout', color_dropout_lin),
    ('nonlin_pen', 'NonlinPen', color_nonlin),
)

# Columns 0, 1, 2 receive the SVHN, CIFAR10 and CIFAR100 results, in that order.
for column, dataset_results in enumerate((results_svhn, results_10, results_100)):
    panel = axes[1][column]
    for model_key, legend_label, line_color in equiangular_series:
        values = dataset_results[model_key]['collapse_train']['equiangular']
        y = np.mean(values, axis=1)
        std = np.std(values, axis=1)
        sns.lineplot(x=x, y=y, marker='o', label=legend_label, ax=panel, color=line_color)
        panel.fill_between(x, y - std, y + std, alpha=alpha, color=line_color)




# Row 2 of the grid ('maxangle' metric): one panel per dataset, one curve per model.
# Each curve is the mean of the metric over axis 1; the shaded band around it is
# +/- one standard deviation over the same axis.
# NOTE(review): axis 1 presumably indexes the repeated runs of an experiment — confirm.
#
# The 24 hand-copied stanzas are replaced by a single data-driven loop so that the
# model key, legend label and colour of a series are stated exactly once.
maxangle_series = (
    # (results key, legend label, line/band colour) — order fixes the legend order
    ('ib', 'IB', color_ib),
    ('wide_ib', 'IBWide', color_ib_wide),
    ('narrow_ib', 'IBNarrow', color_ib_narrow),
    ('no_pen', 'NoPen', color_no),
    ('dropout_no_pen', 'NoPenDropout', color_dropout_no),
    ('lin_pen', 'LinPen', color_lin),
    ('dropout_lin_pen', 'LinPenDropout', color_dropout_lin),
    ('nonlin_pen', 'NonlinPen', color_nonlin),
)

# Columns 0, 1, 2 receive the SVHN, CIFAR10 and CIFAR100 results, in that order.
for column, dataset_results in enumerate((results_svhn, results_10, results_100)):
    panel = axes[2][column]
    for model_key, legend_label, line_color in maxangle_series:
        values = dataset_results[model_key]['collapse_train']['maxangle']
        y = np.mean(values, axis=1)
        std = np.std(values, axis=1)
        sns.lineplot(x=x, y=y, marker='o', label=legend_label, ax=panel, color=line_color)
        panel.fill_between(x, y - std, y + std, alpha=alpha, color=line_color)


# Row 3 of the grid ('equinorm' metric): one panel per dataset, one curve per model.
# Each curve is the mean of the metric over axis 1; the shaded band around it is
# +/- one standard deviation over the same axis.
# NOTE(review): axis 1 presumably indexes the repeated runs of an experiment — confirm.
#
# The 24 hand-copied stanzas are replaced by a single data-driven loop so that the
# model key, legend label and colour of a series are stated exactly once.
equinorm_series = (
    # (results key, legend label, line/band colour) — order fixes the legend order
    ('ib', 'IB', color_ib),
    ('wide_ib', 'IBWide', color_ib_wide),
    ('narrow_ib', 'IBNarrow', color_ib_narrow),
    ('no_pen', 'NoPen', color_no),
    ('dropout_no_pen', 'NoPenDropout', color_dropout_no),
    ('lin_pen', 'LinPen', color_lin),
    ('dropout_lin_pen', 'LinPenDropout', color_dropout_lin),
    ('nonlin_pen', 'NonlinPen', color_nonlin),
)

# Columns 0, 1, 2 receive the SVHN, CIFAR10 and CIFAR100 results, in that order.
for column, dataset_results in enumerate((results_svhn, results_10, results_100)):
    panel = axes[3][column]
    for model_key, legend_label, line_color in equinorm_series:
        values = dataset_results[model_key]['collapse_train']['equinorm']
        y = np.mean(values, axis=1)
        std = np.std(values, axis=1)
        sns.lineplot(x=x, y=y, marker='o', label=legend_label, ax=panel, color=line_color)
        panel.fill_between(x, y - std, y + std, alpha=alpha, color=line_color)



# Axis cosmetics for the grid (rows 0-3, columns 0-2).
column_titles = ('SVHN', 'CIFAR 10', 'CIFAR 100')
row_labels = ('NC1', 'Equiangular', 'Max angle', 'Equinorm')

for column, column_title in enumerate(column_titles):
    top_panel = axes[0][column]
    # Only the top row uses a logarithmic y-axis; it also carries the dataset title.
    top_panel.set_yscale('log')
    top_panel.set_title(column_title, size=label_size)
    # The x-axis label is written on the bottom row only.
    axes[3][column].set_xlabel('Epochs', size=label_size)

# The metric name of each row is written on the left-most panel only.
for row, row_label in enumerate(row_labels):
    axes[row][0].set_ylabel(row_label, size=label_size)

# Vertical marker lines on the top row: for every dataset column, one line per model
# at the x position stored in that dataset's convergence mapping.
# NOTE(review): the convergence_* values are presumably convergence epochs — confirm.
# NOTE(review): no marker is drawn for 'dropout_lin_pen' — confirm this is intentional.
convergence_line_styles = (
    # (model key, colour, line style) — same per-panel drawing order as before
    ('lin_pen', color_lin, '--'),
    ('nonlin_pen', color_nonlin, '-.'),
    ('no_pen', color_no, '--'),
    ('dropout_no_pen', color_dropout_no, '-.'),
    ('ib', color_ib, '--'),
    ('wide_ib', color_ib_wide, '--'),
    ('narrow_ib', color_ib_narrow, '--'),
)

for column, convergence in enumerate((convergence_svhn, convergence_10, convergence_100)):
    for model_key, marker_color, dash_style in convergence_line_styles:
        axes[0][column].axvline(x=convergence[model_key], color=marker_color, linestyle=dash_style)


# Drop the legend of every panel so that one shared legend can be shown instead.
# get_legend() must return a legend for each panel here; a panel without one would
# make the .remove() call fail.
for ax in axes.flat:
    panel_legend = ax.get_legend()
    panel_legend.remove()

# Attach the single shared legend to axes[3][1], anchored below its axes box
# (bbox y = -0.35) and laid out in four columns.
# NOTE(review): the legend lies outside the axes area; if it turns out clipped in the
# saved PNG, passing bbox_inches='tight' to savefig is the usual remedy — confirm.
legend_host = axes[3][1]
legend_host.legend(loc='upper center', bbox_to_anchor=(0.5, -0.35), shadow=True, ncol=4)

plt.savefig('./nc_metrics.png')
