In [1]:
import os
import torch
import numpy as np
import pytorch_lightning as pl
import matplotlib.pyplot as plt

from utils import tools, metrics, criterion
from modules.vae_base_module import VAEBaseModule
from models import supported_models
from datasets import supported_datamodules

plt.style.use('seaborn')
config_file = '../configs/vae/vae_simple_mnist.yaml'
log_path = '../logs/NoveltyMNISTDataModule/SimpleVAE/version_8'
model_path = log_path + '/checkpoints/val_elbo_loss=-0.40-epoch=18.ckpt'

In [2]:
config = tools.load_config(config_file)
exp_params = config['experiment-parameters']
data_params = config['data-parameters']
module_params = config['module-parameters']

datamodule = supported_datamodules[exp_params['datamodule']](**data_params)
datamodule.setup('test')

model = supported_models[exp_params['model']](datamodule.data_shape, **module_params)

module = VAEBaseModule(model, **module_params)

checkpoint = torch.load(model_path)
module.load_state_dict(checkpoint['state_dict'])

Experimental parameters
------
{'data-parameters': {'batch_size': 128,
                     'root_data_path': '/home/fenrir/Documents/Datasets/NoveltyMNIST',
                     'train_fraction': 0.9},
 'experiment-parameters': {'datamodule': 'NoveltyMNISTDataModule',
                           'log_dir': 'logs',
                           'model': 'SimpleVAE',
                           'patience': None},
 'module-parameters': {'latent_nodes': 10, 'learning_rate': 0.01}}


<All keys matched successfully>

In [5]:
test_novelty_scores = []
test_novelty_labels = []
score_criterion = criterion.MixedLoss()

module.model.eval()
with torch.no_grad():
    for batch_nb, batch_tuple in enumerate(datamodule.test_dataloader()):
        result = module.test_step(batch_tuple, batch_nb, score_criterion);
        test_novelty_scores.extend(result['scores'].numpy())
        test_novelty_labels.extend(result['labels'].numpy())

Testing with Mixed Loss criterion
Testing with Reconstruction Probability criterion
Testing with KLD criterion
tensor([[[[-6.6804e-03, -5.9665e-03, -5.9994e-03,  ..., -2.1997e-04,
            4.6569e-04, -2.4187e-03],
          [-5.4450e-03, -5.0222e-03, -4.8801e-03,  ...,  2.4731e-03,
            1.9311e-03, -2.5088e-03],
          [-5.3076e-03, -4.6290e-03, -4.4439e-03,  ...,  3.3873e-03,
            2.3496e-03, -2.3721e-03],
          ...,
          [-1.5365e-03,  1.4714e-03,  2.4027e-03,  ..., -1.7450e-03,
           -1.3154e-03, -4.1069e-03],
          [-1.9834e-03,  2.5540e-04,  1.2502e-03,  ..., -2.1978e-03,
           -3.6332e-03, -9.0692e-03],
          [-5.3176e-03, -4.7670e-03, -4.3755e-03,  ..., -6.0276e-03,
           -1.6184e-02, -1.6611e-02]]],


        [[[-6.8606e-03, -5.6877e-03, -5.6103e-03,  ..., -3.0153e-05,
            2.4354e-04, -2.6921e-03],
          [-5.6439e-03, -4.6693e-03, -4.9201e-03,  ...,  1.0338e-03,
            4.3637e-04, -3.3397e-03],
          [-5.

Testing with Reconstruction Probability criterion
Testing with KLD criterion
tensor([[[[-7.1292e-03, -6.4617e-03, -6.6528e-03,  ..., -3.1295e-04,
           -3.8598e-04, -3.3755e-03],
          [-5.9141e-03, -5.0348e-03, -5.3601e-03,  ...,  2.3336e-03,
            1.2182e-03, -3.1626e-03],
          [-5.6568e-03, -4.2480e-03, -4.4643e-03,  ...,  3.6412e-03,
            2.1731e-03, -2.5991e-03],
          ...,
          [-2.2965e-03,  3.8239e-05, -1.4673e-04,  ..., -9.8362e-05,
           -1.8390e-04, -3.5479e-03],
          [-2.5851e-03, -8.8002e-04, -7.4733e-04,  ..., -1.4606e-03,
           -3.7839e-03, -1.1006e-02],
          [-5.5662e-03, -5.2380e-03, -5.1758e-03,  ..., -5.8547e-03,
           -1.7141e-02, -1.9667e-02]]],


        [[[-7.1726e-03, -6.4318e-03, -5.2945e-03,  ..., -3.5218e-03,
           -2.1650e-03, -3.7474e-03],
          [-6.1037e-03, -5.3254e-03, -3.1508e-03,  ..., -1.6519e-03,
           -1.3001e-03, -4.0953e-03],
          [-5.8601e-03, -4.9223e-03, -3.4826e-03

Testing with Reconstruction Probability criterion
Testing with KLD criterion
tensor([[[[-7.8265e-03, -6.7539e-03, -6.7002e-03,  ..., -3.6256e-03,
           -3.0822e-03, -4.7970e-03],
          [-7.1967e-03, -6.0342e-03, -6.0344e-03,  ..., -3.5414e-03,
           -4.5446e-03, -8.5546e-03],
          [-7.2245e-03, -5.5161e-03, -4.8683e-03,  ..., -3.0267e-02,
           -8.5863e-03, -1.1124e-02],
          ...,
          [-3.0645e-03, -9.8685e-04, -2.3350e-03,  ..., -1.3866e-03,
           -1.0140e-03, -3.9025e-03],
          [-3.2503e-03, -1.8767e-03, -2.4917e-03,  ..., -2.0982e-03,
           -3.6999e-03, -9.6147e-03],
          [-5.9770e-03, -5.9396e-03, -6.1635e-03,  ..., -6.0296e-03,
           -1.6790e-02, -1.7644e-02]]],


        [[[-6.9214e-03, -5.7635e-03, -5.8127e-03,  ..., -5.6992e-04,
           -2.4237e-04, -2.9510e-03],
          [-5.6920e-03, -4.6270e-03, -5.2492e-03,  ..., -1.0415e-04,
           -7.2911e-04, -4.0798e-03],
          [-5.4800e-03, -4.1280e-03, -4.8046e-03

Testing with Reconstruction Probability criterion
Testing with KLD criterion
tensor([[[[-7.6532e-03, -6.7887e-03, -6.2763e-03,  ..., -1.6595e-03,
           -1.5823e-03, -4.0571e-03],
          [-6.7918e-03, -5.7892e-03, -5.1949e-03,  ..., -1.2157e-03,
           -1.9855e-03, -5.0094e-03],
          [-6.6405e-03, -5.3914e-03, -5.1851e-03,  ..., -4.2410e-03,
           -4.0672e-03, -5.6779e-03],
          ...,
          [-3.7066e-03,  3.3097e-03,  8.5658e-03,  ..., -1.3175e-03,
           -8.3355e-04, -3.6715e-03],
          [-3.4602e-03, -2.3203e-03, -2.7681e-03,  ..., -2.1550e-03,
           -3.6216e-03, -9.4628e-03],
          [-5.9410e-03, -5.8871e-03, -6.0399e-03,  ..., -6.0777e-03,
           -1.7115e-02, -1.7715e-02]]],


        [[[-6.5507e-03, -5.1020e-03, -4.9227e-03,  ..., -5.1923e-04,
            3.0766e-04, -2.3759e-03],
          [-5.6461e-03, -4.8357e-03, -5.0151e-03,  ...,  1.0690e-03,
            9.7048e-04, -2.8025e-03],
          [-6.0509e-03, -5.6275e-03, -6.0059e-03

Testing with Reconstruction Probability criterion
Testing with KLD criterion
tensor([[[[-7.0960e-03, -6.0628e-03, -6.0293e-03,  ..., -1.4947e-03,
           -4.0839e-04, -2.7625e-03],
          [-6.1223e-03, -5.1900e-03, -5.0749e-03,  ...,  8.0336e-04,
            8.0972e-04, -2.9374e-03],
          [-6.1014e-03, -4.8159e-03, -4.7053e-03,  ...,  1.6746e-03,
            1.1141e-03, -2.8516e-03],
          ...,
          [-1.5825e-03,  1.9003e-03,  1.1228e-03,  ..., -1.5707e-03,
           -9.5036e-04, -3.7150e-03],
          [-2.4001e-03, -4.9856e-04, -1.0209e-03,  ..., -2.2203e-03,
           -3.5646e-03, -9.1633e-03],
          [-5.7225e-03, -5.6206e-03, -5.8309e-03,  ..., -6.1072e-03,
           -1.7493e-02, -1.7503e-02]]],


        [[[-6.3842e-03, -4.8517e-03, -4.6355e-03,  ..., -8.0420e-04,
            1.2052e-05, -2.6029e-03],
          [-5.1219e-03, -3.8463e-03, -3.8095e-03,  ...,  2.1627e-03,
            1.7184e-03, -2.5081e-03],
          [-5.2102e-03, -3.9337e-03, -3.8628e-03

Testing with Reconstruction Probability criterion
Testing with KLD criterion
tensor([[[[-7.6568e-03, -5.4317e-03, -4.8136e-03,  ..., -2.2647e-03,
           -2.0634e-03, -4.2532e-03],
          [-7.0965e-03, -5.0222e-03, -4.8055e-03,  ..., -2.5641e-03,
           -3.0392e-03, -5.5314e-03],
          [-7.0844e-03, -5.0608e-03, -4.5561e-03,  ..., -1.4499e-02,
           -5.2735e-03, -6.2671e-03],
          ...,
          [-2.0379e-03,  1.1583e-03,  1.6581e-03,  ..., -6.7778e-04,
           -4.2963e-04, -3.5353e-03],
          [-2.4550e-03, -2.2406e-04,  3.8745e-04,  ..., -1.7178e-03,
           -3.3233e-03, -9.0941e-03],
          [-5.5860e-03, -5.0772e-03, -4.8588e-03,  ..., -5.9087e-03,
           -1.5810e-02, -1.6748e-02]]],


        [[[-6.7837e-03, -5.2524e-03, -4.7063e-03,  ..., -6.6190e-05,
            3.3144e-04, -2.5613e-03],
          [-5.5668e-03, -4.3298e-03, -3.9257e-03,  ...,  2.2504e-03,
            1.6315e-03, -2.5414e-03],
          [-5.3307e-03, -4.0720e-03, -3.9704e-03

Testing with Reconstruction Probability criterion
Testing with KLD criterion
tensor([[[[-7.5791e-03, -6.8745e-03, -6.4455e-03,  ..., -2.1750e-03,
           -1.9385e-03, -4.2389e-03],
          [-6.9201e-03, -6.2742e-03, -5.9007e-03,  ..., -2.0144e-03,
           -2.6107e-03, -5.4391e-03],
          [-7.0959e-03, -6.4068e-03, -5.8383e-03,  ..., -4.5220e-03,
           -4.5082e-03, -6.1848e-03],
          ...,
          [-2.1639e-03,  2.0036e-04,  7.0902e-03,  ..., -1.3660e-03,
           -1.4666e-03, -4.5094e-03],
          [-2.5152e-03, -7.0611e-04, -3.3158e-04,  ..., -2.1259e-03,
           -5.0859e-03, -1.3830e-02],
          [-5.5937e-03, -5.2674e-03, -5.1718e-03,  ..., -6.0726e-03,
           -2.0489e-02, -2.4147e-02]]],


        [[[-6.8487e-03, -6.0411e-03, -5.7435e-03,  ..., -2.5096e-03,
           -1.6327e-03, -3.6093e-03],
          [-5.6829e-03, -5.1618e-03, -4.8180e-03,  ..., -1.0655e-03,
           -9.8115e-04, -3.9358e-03],
          [-5.5761e-03, -4.9643e-03, -4.7641e-03

Testing with Reconstruction Probability criterion
Testing with KLD criterion
tensor([[[[-7.4404e-03, -6.9277e-03, -6.8532e-03,  ..., -4.7945e-04,
           -6.3349e-04, -3.6010e-03],
          [-6.2334e-03, -5.4584e-03, -5.2736e-03,  ...,  2.2003e-03,
            9.7975e-04, -3.3598e-03],
          [-5.7368e-03, -4.4466e-03, -4.5135e-03,  ...,  3.2693e-03,
            1.7874e-03, -2.8503e-03],
          ...,
          [-2.2185e-03,  6.3451e-04,  1.7040e-03,  ..., -5.7808e-04,
           -3.6963e-04, -3.5142e-03],
          [-2.4331e-03, -2.0315e-04,  8.3126e-04,  ..., -1.6670e-03,
           -3.3314e-03, -9.2067e-03],
          [-5.4684e-03, -4.8274e-03, -4.4116e-03,  ..., -5.8927e-03,
           -1.5803e-02, -1.6893e-02]]],


        [[[-7.8209e-03, -6.4537e-03, -5.8905e-03,  ..., -1.3521e-03,
           -1.4596e-03, -4.0674e-03],
          [-7.1553e-03, -5.9274e-03, -5.5538e-03,  ..., -1.5841e-03,
           -2.5903e-03, -5.5756e-03],
          [-6.9861e-03, -5.7608e-03, -5.3800e-03

Testing with Reconstruction Probability criterion
Testing with KLD criterion
tensor([[[[-0.0066, -0.0054, -0.0057,  ..., -0.0002,  0.0005, -0.0024],
          [-0.0055, -0.0042, -0.0056,  ...,  0.0020,  0.0016, -0.0027],
          [-0.0058, -0.0040, -0.0052,  ...,  0.0027,  0.0019, -0.0026],
          ...,
          [-0.0014,  0.0013,  0.0006,  ..., -0.0015, -0.0012, -0.0042],
          [-0.0021, -0.0004, -0.0006,  ..., -0.0021, -0.0042, -0.0110],
          [-0.0054, -0.0053, -0.0054,  ..., -0.0061, -0.0182, -0.0199]]],


        [[[-0.0074, -0.0061, -0.0055,  ..., -0.0020, -0.0018, -0.0041],
          [-0.0065, -0.0048, -0.0042,  ..., -0.0023, -0.0030, -0.0057],
          [-0.0063, -0.0040, -0.0030,  ..., -0.0247, -0.0056, -0.0066],
          ...,
          [-0.0036,  0.0023, -0.0065,  ..., -0.0007, -0.0007, -0.0039],
          [-0.0034, -0.0028, -0.0035,  ..., -0.0019, -0.0043, -0.0121],
          [-0.0059, -0.0061, -0.0067,  ..., -0.0060, -0.0190, -0.0217]]],


        [[[-0.0067, -

Testing with Reconstruction Probability criterion
Testing with KLD criterion
tensor([[[[-7.5380e-03, -7.1880e-03, -6.3253e-03,  ..., -3.2648e-03,
           -2.6961e-03, -4.6118e-03],
          [-6.5328e-03, -5.9776e-03, -4.2133e-03,  ..., -1.2245e-03,
           -1.5643e-03, -4.6086e-03],
          [-6.2452e-03, -5.4738e-03, -4.3914e-03,  ..., -9.6944e-04,
           -1.3022e-03, -4.2403e-03],
          ...,
          [-3.0063e-03, -4.2936e-04,  6.8444e-04,  ..., -1.0220e-03,
           -7.3537e-04, -3.7378e-03],
          [-2.8972e-03, -8.2545e-04,  1.8904e-04,  ..., -1.8966e-03,
           -3.5586e-03, -9.5058e-03],
          [-5.6645e-03, -5.0893e-03, -4.7054e-03,  ..., -5.9636e-03,
           -1.6402e-02, -1.7419e-02]]],


        [[[-7.2754e-03, -6.7052e-03, -6.7141e-03,  ..., -1.7971e-03,
           -1.1966e-03, -3.5806e-03],
          [-6.3698e-03, -5.7138e-03, -5.5097e-03,  ...,  7.4666e-04,
            3.4369e-04, -3.4743e-03],
          [-6.4436e-03, -5.5527e-03, -5.5619e-03

In [6]:
fpr, tpr, thresholds, auc = metrics.roc(test_novelty_scores, test_novelty_labels)

plt.plot(fpr, tpr)
plt.plot([0., 1.], [0., 1.])
plt.show()
print('Model ROC AUC: ', auc)
print('Random ROC AUC: 0.5')

ValueError: Input contains NaN, infinity or a value too large for dtype('float32').

In [None]:
pak = metrics.precision_at_k(test_novelty_scores, test_novelty_labels)

uniques, counts = np.unique(test_novelty_labels, return_counts=True)
random = counts.min() / counts.sum()

plt.plot(pak)
plt.plot([0, 10000], [random, random])
plt.ylim([0., 1.])
plt.show()

In [None]:
print(len(pak))