In [1]:
import os
import torch
import numpy as np
import pytorch_lightning as pl
import matplotlib.pyplot as plt

from utils import tools, metrics, criterion
from modules.vae_base_module import VAEBaseModule
from models import supported_models
from datasets import supported_datamodules

# The bare 'seaborn' style name was deprecated in matplotlib 3.6 (renamed to
# 'seaborn-v0_8') and is no longer shipped by newer releases. Prefer the new
# name when this matplotlib provides it; fall back to the legacy name otherwise.
plot_style = 'seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn'
plt.style.use(plot_style)

# Experiment configuration file and the logged run this notebook evaluates.
config_file = '../configs/vae/vae_simple_mnist.yaml'
log_path = '../logs/NoveltyMNISTDataModule/SimpleVAE/version_0'
model_path = log_path + '/checkpoints/last.ckpt'
# Derived from log_path (same value as before) so that pointing the notebook
# at another run only requires editing log_path.
recons_path = log_path + '/train_recons.pt'

In [2]:
# Read the experiment configuration and pull out its three sections.
config = tools.load_config(config_file)
exp_params = config['experiment-parameters']
data_params = config['data-parameters']
module_params = config['module-parameters']

# Look up the datamodule class named in the config and prepare its test stage.
datamodule = supported_datamodules[exp_params['datamodule']](**data_params)
datamodule.setup('test')

# Build the network for the datamodule's data shape and wrap it in the VAE module.
model = supported_models[exp_params['model']](datamodule.data_shape, **module_params)

module = VAEBaseModule(model, **module_params)

# map_location='cpu' lets a checkpoint that was saved from a GPU be restored on
# a machine without CUDA. The module is never moved off the CPU in this
# notebook, so the loaded weights end up in the same place either way.
checkpoint = torch.load(model_path, map_location='cpu')
# Left as the last expression so the notebook displays the key-matching result.
module.load_state_dict(checkpoint['state_dict'])

Experimental parameters
------
{'data-parameters': {'batch_size': 128,
                     'root_data_path': '/home/sleipnir/Documents/Datasets/NoveltyMNIST',
                     'train_fraction': 0.9},
 'experiment-parameters': {'datamodule': 'NoveltyMNISTDataModule',
                           'log_dir': 'logs',
                           'model': 'SimpleVAE',
                           'patience': None},
 'module-parameters': {'latent_nodes': 10, 'learning_rate': 0.01}}


<All keys matched successfully>

In [3]:
# Accumulators filled batch by batch during the test loop.
test_novelty_scores, test_novelty_labels = [], []

# Scoring criterion handed to module.test_step; constructed from the
# reconstructions file at recons_path.
score_criterion = criterion.BestLoss(recons_path)

Testing with the Best Loss criterion
torch.Size([337, 128, 1, 28, 28])


In [None]:
# Start from empty lists every time this cell runs. Without this, re-running
# the cell would append a second copy of every score/label to the lists
# created in an earlier cell and silently skew the metrics computed below.
test_novelty_scores = []
test_novelty_labels = []

# Evaluation mode and no gradient tracking: this loop only scores test batches.
module.model.eval()
with torch.no_grad():
    for batch_nb, batch_tuple in enumerate(datamodule.test_dataloader()):
        result = module.test_step(batch_tuple, batch_nb, score_criterion)
        # Flatten the per-batch results into one running list per quantity.
        test_novelty_scores.extend(result['scores'].numpy())
        test_novelty_labels.extend(result['labels'].numpy())

In [None]:
# ROC statistics for the collected novelty scores against their labels.
fpr, tpr, thresholds, auc = metrics.roc(test_novelty_scores, test_novelty_labels)

# Explicit Axes with labels, title and legend so the figure is readable on its
# own; the diagonal is the reference line for a random scorer.
fig, ax = plt.subplots()
ax.plot(fpr, tpr, label='Model')
ax.plot([0., 1.], [0., 1.], label='Random')
ax.set_xlabel('False positive rate')
ax.set_ylabel('True positive rate')
ax.set_title('ROC curve')
ax.legend()
plt.show()
print('Model ROC AUC: ', auc)
print('Random ROC AUC: 0.5')

In [None]:
# Precision-at-k values for the collected novelty scores.
pak = metrics.precision_at_k(test_novelty_scores, test_novelty_labels)

# Baseline for a random ranking: the share of the least frequent label value.
# NOTE(review): this assumes the novel class is the minority class - confirm.
uniques, counts = np.unique(test_novelty_labels, return_counts=True)
random = counts.min() / counts.sum()

# plt.plot(pak) places the points at x = 0 .. len(pak) - 1, so the baseline is
# drawn across that same range instead of a hard-coded 10000.
plt.plot(pak, label='Model')
plt.plot([0, len(pak) - 1], [random, random], label='Random')
plt.ylim([0., 1.])
plt.xlabel('k')
plt.ylabel('Precision at k')
plt.legend()
plt.show()

In [None]:
# Report how many entries the precision-at-k result contains.
n_pak_entries = len(pak)
print(n_pak_entries)