In [1]:
import sys
sys.path.append('../..')  # Expose top level program access

import time
import pickle
import logging
import torch
import numpy as np
import pandas as pd
import seaborn as sns
import pytorch_lightning as pl
import matplotlib.pyplot as plt

from functools import reduce
from pathlib import Path

from modules.cae_base_module import CAEBaseModule
from modules.vae_base_module import VAEBaseModule
from modules.aae_base_module import AAEBaseModule
from modules.pca_base_module import PCABaseModule
from models import supported_models
from datasets.curiosity import CuriosityDataset
from utils import tools, metrics, supported_preprocessing_transforms
from utils.dtypes import *

logging.getLogger("matplotlib").setLevel(logging.WARNING)  # Suppress verbose font warnings from matplotlib

In [14]:
# Import configurations and paths to logged models
root = Path.cwd().parents[1]
log_path = root / 'logs' / 'CuriosityDataModule'
paths_to_archived_models = list(Path(log_path).glob('**/*Baseline*/archive*'))

print('Found archived models:\n------')
print('\n'.join([f'{p.parent.name}/{p.name}' for p in paths_to_archived_models]))

Found archived models:
------
BaselineVAE/archive_v1_2021-04-23
BaselineVAE/archive-lowperf_v2_2021-04-23
BaselineCAE/archive_v2_2021-04-21
BaselineCAE/archive_v1_2021-04-12
BaselineCAE/archive-redundant_v4_2021-04-22
BaselineCAE/archive-redundant_v3_2021-04-22
BaselineAAE/archive-lowperf_v1_2021-05-07
BaselineAAE/archive_v2_2021-05-08


In [16]:
model = supported_models['BaselineAAE'](in_shape=(6, 64, 64), latent_nodes=10)
module = AAEBaseModule(model, **{'learning_rate': 0.007, 'weight_decay_coefficient': 0.001})

# Load the state_dict into the module architecture
checkpoint = torch.load('/home/brahste/Projects/novelty-detection/logs/CuriosityDataModule/BaselineAAE/archive_v2_2021-05-08/checkpoints/val_r_loss=0.58-epoch=29.ckpt')
module.load_state_dict(checkpoint['state_dict'])

preprocessing_transforms = supported_preprocessing_transforms['CuriosityPreprocessing']

novelty_classes = [f'test_novel/{c}/' for c in ('all', 'bedrock', 'broken-rock', 'drill-hole', 'drt', 'dump-pile', 'float', 'meteorite', 'veins')]

test_sets = {}

for nov_class in novelty_classes:
    test_sets[nov_class] = CuriosityDataset(
        '/home/brahste/Datasets/MartianCuriosity',
        train=False,
        data_transforms=preprocessing_transforms,
        novel_class_specifier=nov_class
    )

In [18]:
module.encoder.eval()
module.decoder.eval()

with torch.no_grad():
    for nov_class in novelty_classes:
        test_novelty_scores = []
        test_novelty_labels = []
        for batch_nb, (image, label) in enumerate(test_sets[nov_class]):
            image = image[None]

            result = module.test_step((image, label), batch_nb)

            test_novelty_scores.append(result['scores'])
            test_novelty_labels.append(result['labels'])
#             print(f'[BATCH {batch_nb}] Mean score: {np.mean(result["scores"])}')
        fpr, tpr, thresholds, auc = metrics.roc(test_novelty_scores, test_novelty_labels)
        print(batch_nb, nov_class)
        print(f'AUC: {auc}')

855 test_novel/all/
AUC: 0.6712304836772574
436 test_novel/bedrock/
AUC: 0.5516431924882629
501 test_novel/broken-rock/
AUC: 0.7246108228317272
487 test_novel/drill-hole/
AUC: 0.6083976980160533
536 test_novel/drt/
AUC: 0.7253309647675844
518 test_novel/dump-pile/
AUC: 0.6391286788833358
443 test_novel/float/
AUC: 0.6264997391757956
459 test_novel/meteorite/
AUC: 0.5935515051090858
455 test_novel/veins/
AUC: 0.6863067292644758


0.5134976525821596
