In [1]:
%load_ext autoreload
%autoreload 2

import os
import sys

# Root of the heronWorkspace checkout. Override it with the HERON_WORKSPACE
# environment variable to run this notebook on a different machine; the
# default reproduces the original hardcoded location.
HERON_WORKSPACE = os.environ.get("HERON_WORKSPACE", "/data/tim/heronWorkspace")

# Make the project's source folders importable for the project-local imports below.
# os.path.join(root, "") yields the root with a trailing slash, as before.
sys.path.extend(
    os.path.join(HERON_WORKSPACE, subdir)
    for subdir in ("src", "0_preProcessing", "1_AE", "2_postProcessing", "")
)


from AEHeronModel import CAEHeron
from lightning.pytorch.callbacks import ModelSummary
from torchsummary import summary
import HeronImageLoader
from torch.utils.data import DataLoader, BatchSampler
from matplotlib import pyplot as plt
import lightning.pytorch as pl
from lightning.pytorch.tuner import Tuner
import pandas as pd
from lightning.pytorch.loggers import CSVLogger
from models import MLPBasic, CAEBigBottleneck, CAESmallBottleneckWithLinear, MLPBasicHeatMap, CAEV1
import numpy as np
import torch.nn.functional as F
import torch
from torchvision.transforms import GaussianBlur
from PIL import Image, ImageFilter
import random
from scipy.stats import loguniform
from ClassifierDatasets import DatasetThreeConsecutive, UnNormalize
# from torchmetrics.image import StructuralSimilarityIndexMeasure
from skimage.metrics import structural_similarity as ssim
import seaborn as sns
from sklearn.model_selection import ParameterSampler
from scipy.stats import loguniform
import functorch
from sklearn.metrics import confusion_matrix, accuracy_score, recall_score, precision_score, f1_score, roc_auc_score, roc_curve, precision_recall_curve
from PostProcessingHelper import MinFilter, PostProcess
import glob
from pathlib import Path
import os


# Project colour palette, installed as seaborn's default colour cycle,
# followed by the white-with-grid axes style.
colors = ["#32829C", "#E38538", "#51AC8C", "#D94841", "#7A5C96"]
sns.set_palette(sns.color_palette(colors))
sns.set_style("whitegrid")

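# Global vs. Camera-Specific Model

Compare the ROC curves of the global CAE checkpoint and the SBU3-specific checkpoint, once on all anomalies and once on obvious anomalies only.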

In [7]:
# wanted checkpoints: one per BasicCAE1 run version 2-5 (collected here, not used further in this cell)
path = '/data/tim/heronWorkspace/logs/BasicCAE1/'  # use your path
ckptList = []
for i in range(2, 6):
    ckptDir = os.path.join(path, f"version_{i}", "checkpoints")
    versionCkpts = glob.glob(os.path.join(ckptDir, "*.ckpt"))
    if not versionCkpts:
        raise FileNotFoundError(f"No *.ckpt file found in {ckptDir}")
    # glob.glob returns matches in arbitrary order, so indexing with [-1] did not
    # reliably select the latest checkpoint; pick the most recently written file instead.
    ckptList.append(max(versionCkpts, key=os.path.getmtime))


loaderParams = dict(
    lblValidationMode = "Manual",
    balanced = True,
    anomalyObviousness = "all",
    distinctCAETraining = False,
    colorMode = "RGB",
    random_state = 1,
    set = "all"
)
startState = dict(
    cameras = ["SBU4"],#["NEN1", "SBU3"], #["GBU1", "GBU4", "KBU2", "PSU1", "PSU2", "PSU3", "SBU3", "SGN1"],
    balanced = True,
    distinctCAETraining = False,
    gaussianFilterSize = 5,
    gaussianFilterSigma = 5,
    filter = "MinFilter", #["MinFilter", "GaussianFilter"]
    minFilterKernelSize = 4,
    zeroThreshold = 0.2, # threshold for zeroing out the image
    # sumThreshold = 50.9,
    lossFn = "L1"
)

checkPointGlobal = "/data/tim/heronWorkspace/logs/BasicCAE1/version_3/checkpoints/epoch=14-step=40770.ckpt"
checkPointSBU3 = "/data/tim/heronWorkspace/logs/BasicCAE1/version_12/checkpoints/epoch=24-step=4825.ckpt"


def sumsPerCheckpoint(checkpoints, loaderParamsForRun, params):
    """Run PostProcess.computeSum once per checkpoint.

    Returns two parallel lists, (sum values, labels), with one entry per
    checkpoint in the order the checkpoints were given.
    """
    sumValsList, lblValsList = [], []
    for ckpt in checkpoints:
        sumVals, lblVals = PostProcess.computeSum(params=params, loaderParams=loaderParamsForRun, checkPoint=ckpt)
        sumValsList.append(sumVals)
        lblValsList.append(lblVals)
    return sumValsList, lblValsList


comparedCheckpoints = [checkPointGlobal, checkPointSBU3]

# Evaluation on all anomalies.
sumValsListAll, lblValsListAll = sumsPerCheckpoint(comparedCheckpoints, loaderParams, startState)

# Evaluation on obvious anomalies only.
# NOTE: mutates loaderParams in place, so later cells see anomalyObviousness == "obvious".
loaderParams["anomalyObviousness"] = "obvious"
sumValsListObv, lblValsListObv = sumsPerCheckpoint(comparedCheckpoints, loaderParams, startState)

print("successfully computed sumVals and lblVals")


Length of dataset: 740
{'cameras': ['SBU4'], 'balanced': True, 'distinctCAETraining': False, 'gaussianFilterSize': 5, 'gaussianFilterSigma': 5, 'filter': 'MinFilter', 'minFilterKernelSize': 4, 'zeroThreshold': 0.2, 'lossFn': 'L1'}
Length of dataset: 740
{'cameras': ['SBU4'], 'balanced': True, 'distinctCAETraining': False, 'gaussianFilterSize': 5, 'gaussianFilterSigma': 5, 'filter': 'MinFilter', 'minFilterKernelSize': 4, 'zeroThreshold': 0.2, 'lossFn': 'L1'}


In [None]:
def plotRocPanel(panelAx, sumValsList, lblValsList, modelNames, title):
    """Draw one ROC curve per model on `panelAx`, plus the chance diagonal.

    sumValsList / lblValsList are parallel lists (one entry per model) of
    anomaly scores and ground-truth labels; modelNames labels each curve.
    """
    for name, sumVals, lblVals in zip(modelNames, sumValsList, lblValsList):
        fpr, tpr, _ = roc_curve(lblVals, sumVals)
        rocAuc = roc_auc_score(lblVals, sumVals)
        panelAx.plot(fpr, tpr, label=f'{name}, area = {rocAuc:0.2f}')

    # Drawn once per panel (previously re-drawn on every loop iteration).
    panelAx.plot([0, 1], [0, 1], 'r--')
    panelAx.set_xlabel('False Positive Rate')
    panelAx.set_ylabel('True Positive Rate')
    panelAx.set_title(title)
    panelAx.legend(loc="lower right")


# Order must match the checkpoint order used when computing the sums.
modelNames = ['Global Model', 'SBU3 Model']

fig, ax = plt.subplots(1, 2, figsize=(16, 11))
plotRocPanel(ax[0], sumValsListAll, lblValsListAll, modelNames, 'ROC: All Anomalies')
# Previously these curves were plotted on ax[0] by mistake, leaving the right panel empty.
plotRocPanel(ax[1], sumValsListObv, lblValsListObv, modelNames, 'ROC: Obvious Anomalies')
plt.show()