# MNIST-C Evaluation

In [None]:
import os
import sys
import numpy
sys.path.insert(1, os.path.abspath('') + '/../../')
import common.datasets
import common.test
import common.numpy
import common.eval
from experiments.eval import misc
import torch
torch.cuda.set_device(0)
from IPython.display import display, Markdown

In [None]:
# Configuration module to load, and the training-configuration names to
# evaluate from it (resolved by misc.module in a later cell).
config_module = 'config.mnist'
config_training_variables = [
    'normal_training_check',
    'adversarial_training_lr005_i40_half_momentum_backtrack_check',
    (
        'confidence_calibrated_adversarial_training_lr001_ce_f7p_i40_'
        'random_momentum_backtrack_power2_10'
    ),
    'multi_steepest_descent',
]

In [None]:
# Display labels for the evaluated models, presumably paired by position
# with config_training_variables. The backslash-prefixed entries look like
# LaTeX macro names -- TODO confirm against the LaTeX table templates.
config_training_names = [
    r'\Normal',
    r'\AdvTrain',
    r'\ConfTrain',
    r'\Wong',
]

# Labels for the OOD evaluation sets: the combined set 'all' comes first,
# followed by one label per individual MNIST-C corruption.
# Deliberately not evaluated: 'identity', 'all_after'.
ood_names = ['all'] + [
    'brightness',
    'canny_edges',
    'dotted_line',
    'fog',
    'glass_blur',
    'impulse_noise',
    'motion_blur',
    'rotate',
    'scale',
    'shear',
    'shot_noise',
    'spatter',
    'stripe',
    'translate',
    'zigzag',
]

In [None]:
# Resolve the configuration module and the selected training configurations.
# No attack configuration names are requested (empty list).
config, training_configs, attack_configs = misc.module(
    config_module,
    config_training_variables,
    [],
)

In [None]:
# Load the per-model artefacts and their epochs for the selected configurations.
# NOTE(review): new_config_training_variables / new_config_training_names are
# not read by any later cell shown here; later cells use config_training_names.
(
    model_files,
    model_epochs,
    perturbations_files,
    perturbations_epochs,
    new_config_training_variables,
    new_config_training_names,
) = misc.load(training_configs, config_training_names, attack_configs)

In [None]:
# Render the epoch overview returned by misc.epoch_table as Markdown.
epoch_table_markdown = misc.epoch_table(
    model_epochs,
    perturbations_files,
    perturbations_epochs,
    config_training_names,
    attack_configs,
)
display(Markdown(epoch_table_markdown))

In [None]:
# Model outputs on the clean test loader, one entry per model file
# (the OOD evaluation cell reads clean_probabilities[i] for model i).
clean_probabilities = misc.compute_clean_probabilities(
    model_files,
    config.testloader,
)

In [None]:
# Evaluate the clean-set outputs computed in the previous cell.
# NOTE(review): clean_evaluations is not read by any later cell shown here.
clean_evaluations = misc.compute_clean_evaluations(
    config,
    clean_probabilities,
)

In [None]:
# Build the OOD loaders: first one over the combined MNIST-C set ('all'),
# then one per individual corruption. The order of `oodloaders` must match
# `ood_names` position by position, because both lists are passed side by
# side to the reporting helpers in later cells.
indices = list(range(9000))
# For a quick debug run, use: indices = list(range(100))
corruptions = [
    'brightness',
    'canny_edges',
    'dotted_line',
    'fog',
    'glass_blur',
    # 'identity' is intentionally not evaluated.
    'impulse_noise',
    'motion_blur',
    'rotate',
    'scale',
    'shear',
    'shot_noise',
    'spatter',
    'stripe',
    'translate',
    'zigzag',
]


def _make_oodloader(corruption_subset):
    """Return a non-shuffled DataLoader over the MNIST-C examples in `indices`.

    corruption_subset is forwarded as MNISTCTestSet(corruptions=...): a list of
    corruption names, or None for the combined set used as 'all'.
    """
    return torch.utils.data.DataLoader(
        common.datasets.MNISTCTestSet(corruptions=corruption_subset, indices=indices),
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=0,
    )


print('all')
oodloaders = [_make_oodloader(None)]
for corruption in corruptions:
    print(corruption)
    oodloaders.append(_make_oodloader([corruption]))

assert len(oodloaders) == len(corruptions) + 1
# Every per-corruption loader must contain exactly one corruption.
assert all(len(loader.dataset.corruptions) == 1 for loader in oodloaders[1:])
assert len(oodloaders) == len(ood_names), (len(oodloaders), len(ood_names))
# Matching lengths is not enough: a reordered or misspelled entry in either
# hand-written list would silently mislabel the result tables.
assert ood_names == ['all'] + corruptions, (ood_names, corruptions)

In [None]:
# Evaluate every model on every OOD loader; loader 0 is the combined set,
# the rest are single corruptions. Besides the evaluations, collect the
# per-corruption outputs in two ways for the consistency check in the next
# cell:
#   all_probabilities[m]   -- stacked from the single-corruption loaders
#   all_probabilities_p[m] -- the combined loader's output reshaped to
#                             (num_corruptions, -1, ood_probabilities.shape[1])
ood_evaluations = []
all_probabilities = [None] * len(model_files)
all_probabilities_p = [None] * len(model_files)
for model_index, model_file in enumerate(model_files):
    model, model_epochs[model_index] = misc.load_model(model_file)
    model.eval()

    per_loader_evaluations = []
    for loader_index, oodloader in enumerate(oodloaders):
        ood_probabilities = common.test.test(model, oodloader, cuda=True)
        is_combined_loader = loader_index == 0

        if not is_combined_loader:
            # Append this corruption's output as a new leading slice; the
            # first call receives None (presumably treated as "empty" by
            # common.numpy.concatenate -- TODO confirm).
            stacked_so_far = all_probabilities[model_index]
            all_probabilities[model_index] = common.numpy.concatenate(
                stacked_so_far,
                numpy.expand_dims(ood_probabilities, axis=0),
            )

        num_corruptions = len(oodloader.dataset.corruptions)
        ood_probabilities = ood_probabilities.reshape(
            num_corruptions, -1, ood_probabilities.shape[1]
        )
        if is_combined_loader:
            all_probabilities_p[model_index] = ood_probabilities

        per_loader_evaluations.append(
            common.eval.CorruptedEvaluation(
                clean_probabilities[model_index],
                ood_probabilities,
                config.testset.labels,
            )
        )
    ood_evaluations.append(per_loader_evaluations)

In [None]:
# Sanity check: for each model, the combined loader's output (reshaped per
# corruption) must equal the outputs stacked from the single-corruption
# loaders. Compare shapes explicitly first: numpy.allclose broadcasts its
# arguments, so a layout mismatch would otherwise surface as an unhelpful
# broadcasting error or be compared under broadcasting.
for stacked, reshaped in zip(all_probabilities, all_probabilities_p):
    print(stacked.shape, reshaped.shape)
    assert stacked.shape == reshaped.shape, (stacked.shape, reshaped.shape)
    assert numpy.allclose(stacked, reshaped)

In [None]:
# ood_evaluations is indexed [model][ood set]. Each model must carry one
# evaluation per entry of ood_names, or the table columns would not line up.
evaluations_per_model = [len(model_evaluations) for model_evaluations in ood_evaluations]
print(len(ood_names), evaluations_per_model)
assert all(count == len(ood_names) for count in evaluations_per_model), (
    len(ood_names),
    evaluations_per_model,
)
display(Markdown(misc.corrupted_markdown_table(ood_evaluations, config_training_names, ood_names, '99')))

In [None]:
# Main LaTeX table with the '98' setting (meaning defined in misc -- TODO confirm).
misc.corrupted_main_latex_table(
    config,
    config_training_names,
    ood_evaluations,
    '98',
)

In [None]:
# Main LaTeX table with the '99' setting (meaning defined in misc -- TODO confirm).
misc.corrupted_main_latex_table(
    config,
    config_training_names,
    ood_evaluations,
    '99',
)

In [None]:
# Supplementary LaTeX table: per-OOD-set breakdown labelled by ood_names,
# '99' setting (meaning defined in misc -- TODO confirm).
misc.corrupted_supp_latex_table(
    config,
    config_training_names,
    ood_names,
    ood_evaluations,
    '99',
)