In [None]:
import os
import sys
sys.path.append('/home/edshkim98/synapse/one_class/anomalib/')
import math
import glob
import numpy as np
import torch
import matplotlib.pyplot as plt
from sklearn import metrics
from sklearn.metrics import roc_auc_score
from tqdm import tqdm
import config as c
from model import get_cs_flow_model, save_model, FeatureExtractor, nf_forward
from utils import *
import time
import glob
from sklearn import metrics
from PIL import Image
import cv2
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple, Union
from kornia.filters import gaussian_blur2d
from torch import Tensor


from anomalib.pre_processing.transforms import Denormalize
from anomalib.utils.metrics import (
    AdaptiveThreshold,
    AnomalibMetricCollection,
    AnomalyScoreDistribution,
    MinMax,
)
from efficientnet_pytorch import EfficientNet
from albumentations.pytorch import ToTensorV2
import timm 
import time
from collections import OrderedDict
import shutil

In [None]:
import numpy as np
import torch
from tqdm import tqdm
import config as c
from model import FeatureExtractor
from utils import *
import os


# AUROC curve (from saved `fpr.npy` / `tpr.npy`)

In [None]:
fpr = np.load('fpr.npy')
tpr = np.load('tpr.npy')

# Area under the stored ROC curve, plus the first operating point that reaches
# full recall (TPR == 1); that point is highlighted on the plot.
roc = metrics.auc(fpr, tpr)
idx = np.flatnonzero(tpr == 1)[0]

plt.plot(fpr, tpr)
plt.scatter(fpr[idx], tpr[idx], s=50)
plt.title('Eff B5: Coupling block 4 Warm Restart')

print(f"Image AUC: {roc} FPR: {fpr[idx]}")

# Load model

In [None]:
model = get_cs_flow_model()
checkpoint_path = os.path.join(os.getcwd(), "models", "smt", "cs_flow_couplingx4_restart.pth")
# map_location="cpu" lets a checkpoint that was saved from a GPU load on any
# machine; load_state_dict copies the values into the model's existing
# parameters, so where the model itself lives is unaffected.
pretrained = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(pretrained)

In [None]:
# Feature extractor configured from the config module `c`. Later cells pass the
# test-loader inputs through it before the flow model whenever c.pre_extracted
# is False.
fe = FeatureExtractor(c)

# Load data

In [None]:
# Both splits live under <dataset_path>/<class_name>/{train,test}; the test set
# is built with ret_path=True.
class_root = c.dataset_path + '/' + c.class_name
trainset = SynapseData(c, class_root + '/train', train=True)
testset = SynapseData(c, class_root + '/test', train=False, ret_path=True)
train_loader, test_loader = make_dataloaders(trainset, testset)

In [None]:
# Peek at a single test batch to see how many items the loader yields
# (presumably inputs, labels and paths, since the test set uses ret_path=True).
batches = iter(test_loader)
data = next(batches)
del batches  # don't keep the loader iterator (and any workers) alive
len(data)

In [None]:
# Put both networks on the configured device and switch them to inference mode.
for net in (model, fe):
    net.to(c.device)
    net.eval()

# Score the test set and pick the optimal threshold (maximum G-mean on the ROC curve)

In [None]:
def g_means(tpr, fpr):
    """Geometric mean of sensitivity and specificity: sqrt(TPR * (1 - FPR)).

    Accepts scalars or NumPy arrays (evaluated element-wise), so a whole ROC
    curve can be scored in one call. The ROC point that maximises this value
    balances true-positive rate against false-positive rate.

    Args:
        tpr: true-positive rate(s), expected in [0, 1].
        fpr: false-positive rate(s), expected in [0, 1].

    Returns:
        The G-mean, with the same shape as the broadcast inputs.
    """
    return np.sqrt(tpr * (1 - fpr))

# Image-level scoring of the test set, ROC analysis and threshold selection.
times = []
test_z = []
test_labels = []
with torch.no_grad():
    for inputs, labels, _ in tqdm(test_loader):
        # Use the same device the model/feature extractor were moved to.
        inputs = inputs.to(c.device)
        on_gpu = inputs.is_cuda
        if on_gpu:
            # CUDA kernels run asynchronously: synchronise around the timed
            # region so it measures the forward pass, not just its launch.
            torch.cuda.synchronize()
        start = time.perf_counter()
        features = inputs if c.pre_extracted else fe(inputs)
        z, jac = nf_forward(model, features)
        if on_gpu:
            torch.cuda.synchronize()
        times.append(time.perf_counter() - start)

        # Image-level anomaly score: std of z**2 over axes 1 and 2 of the
        # concatenated latent maps.
        z_concat = t2np(concat_maps(z))
        test_z.append(np.std(z_concat ** 2, axis=(1, 2)))
        test_labels.append(t2np(labels))

print("Inf TIme: {}".format(np.mean(np.array(times))))

test_labels = np.concatenate(test_labels)
# Any non-zero label counts as anomalous.
is_anomaly = (test_labels != 0).astype(int)
anomaly_score = np.concatenate(test_z, axis=0)

print(anomaly_score.shape, is_anomaly.shape)

fpr, tpr, thresholds = metrics.roc_curve(is_anomaly, anomaly_score)
roc = roc_auc_score(is_anomaly, anomaly_score)

# Optimal threshold = ROC point maximising the G-mean sqrt(TPR * (1 - FPR)).
# NOTE(review): the threshold is picked on the same test set it is later
# evaluated on, so downstream detection results are optimistically biased.
g_mean_scores = [g_means(t, f) for t, f in zip(tpr, fpr)]
optimal_idx = int(np.argmax(g_mean_scores))
optimal_thresh = thresholds[optimal_idx]

print("ROC: {}".format(roc))

idx = np.where(tpr == 1)[0][0]  # first operating point with full recall
plt.plot(fpr, tpr)
plt.scatter(fpr[idx], tpr[idx], s=50, label="first TPR = 1")
plt.scatter(fpr[optimal_idx], tpr[optimal_idx], s=50, label="max G-mean")
plt.xlabel("False positive rate")
plt.ylabel("True positive rate")
plt.legend()
plt.title('EffB5')

print("Image AUC: {} FPR: {}".format(roc, fpr[idx]))

# Calculate heatmaps and visualise correctly detected anomalies

In [None]:
def gen_anomaly_map(z, jac, optimal_thresh, size=512, sigma=10, skip_first_scale=True):
    """Build a smoothed pixel-level anomaly map from CS-Flow latents.

    For each latent scale used, the per-pixel value exp(-0.5 * mean(z**2)) is
    computed, where the mean runs over dim 1 (presumably channels; TODO confirm
    against nf_forward). That is the standard-normal likelihood up to a constant.
    The value is negated so that low-likelihood pixels score higher, then
    bilinearly resized to ``size``. The per-scale maps are averaged,
    ``optimal_thresh`` is subtracted, and a Gaussian blur is applied.

    Args:
        z: sequence of 4-D latent tensors, one per scale.
        jac: per-scale log-Jacobians. They are only zipped with ``z``, so at most
            ``min(len(z), len(jac))`` scales are used; they do not enter the map.
        optimal_thresh: scalar offset subtracted from the averaged map.
        size: output spatial size (int or (h, w)) passed to ``F.interpolate``.
        sigma: standard deviation of the Gaussian blur; the kernel size is
            ``2 * int(4 * sigma + 0.5) + 1``.
        skip_first_scale: if True (default), the first scale is left out of
            the map.

    Returns:
        (anomaly_map, image_anomaly_map): the blurred map on the CPU, and the
        maximum of the un-blurred map over the whole batch.

    Raises:
        ValueError: if no scales remain to build a map from.
    """
    # zip() bounds the number of scales by both sequences.
    latents = [hidden_variable for hidden_variable, _ in zip(z, jac)]
    if skip_first_scale:
        latents = latents[1:]
    if not latents:
        raise ValueError("gen_anomaly_map: no latent scales left to build an anomaly map from")

    flow_maps: List[Tensor] = []
    for hidden_variable in latents:
        log_prob = -0.5 * torch.mean(hidden_variable ** 2, dim=1, keepdim=True)
        prob = torch.exp(log_prob)
        flow_maps.append(
            F.interpolate(input=-prob, size=size, mode="bilinear", align_corners=False)
        )

    # Average over scales (stacked on a new trailing axis), then shift by the threshold.
    anomaly_map = torch.mean(torch.stack(flow_maps, dim=-1), dim=-1) - optimal_thresh
    image_anomaly_map = torch.max(anomaly_map)

    kernel_size = 2 * int(4.0 * sigma + 0.5) + 1
    anomaly_map = gaussian_blur2d(anomaly_map, (kernel_size, kernel_size), sigma=(sigma, sigma))
    return anomaly_map.cpu(), image_anomaly_map

def _finish_figure(fig, save_dir, filename):
    """Save ``fig`` into ``save_dir`` (when given), display it, then close it."""
    if save_dir is not None:
        # Save before plt.show(): some backends discard the figure once shown.
        fig.savefig(os.path.join(save_dir, filename), dpi=fig.dpi)
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across the loop


def _show_true_positive(image, heatmap, index, save_dir=None):
    """Show a heatmap overlay and the plain image for one detected anomaly.

    When ``save_dir`` is given, the figures are written there as
    ``<index>_heatmap.png`` and ``<index>.png``.
    """
    fig = plt.figure()
    plt.imshow(image)
    plt.imshow(heatmap, alpha=0.8)
    _finish_figure(fig, save_dir, f"{index}_heatmap.png")

    fig = plt.figure()
    plt.imshow(image)
    _finish_figure(fig, save_dir, f"{index}.png")


def eval(test_loader, model, threshold, save_dir=None):
    """Visualise anomaly heatmaps for correctly detected anomalous test images.

    Each test image is scored with std(z**2) over axes 1 and 2 of the
    concatenated latents, and a pixel-level map is built by
    ``gen_anomaly_map``. Images predicted anomalous (``score >= threshold``)
    whose label is non-zero are displayed: first the denormalised image with
    the heatmap overlaid, then the image alone. Every sample in a batch is
    handled, not just the first.

    Args:
        test_loader: iterable of ``(inputs, labels, paths)`` batches of raw images.
        model: trained flow model, already on ``c.device`` and in eval mode.
        threshold: decision threshold on the image score.
        save_dir: optional directory; when given, figures are also saved there
            as ``<n>_heatmap.png`` / ``<n>.png`` (n counts shown images from 0).

    Raises:
        ValueError: if ``c.pre_extracted`` is set, because then the loader
            yields features and there is no image to draw heatmaps on.

    Note:
        The name shadows the builtin ``eval``; it is kept for existing callers.
        The function relies on the module-level ``fe`` and ``c``.
    """
    if c.pre_extracted:
        raise ValueError("eval() needs raw images; it cannot run with c.pre_extracted=True")
    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)

    n_shown = 0
    with torch.no_grad():
        for inputs, labels, _ in test_loader:
            inputs = inputs.to(c.device)
            z, jac = nf_forward(model, fe(inputs))
            anomaly_maps, _ = gen_anomaly_map(z, jac, threshold)

            z_concat = t2np(concat_maps(z))
            scores = np.std(z_concat ** 2, axis=(1, 2))
            labels = t2np(labels)
            images = inputs.cpu()

            for k, (score, label) in enumerate(zip(scores, labels)):
                # sklearn's roc_curve counts a sample as positive when
                # score >= threshold, so use the same rule here.
                is_true_positive = score >= threshold and label != 0
                if not is_true_positive:
                    continue
                _show_true_positive(Denormalize()(images[k]), anomaly_maps[k][0], n_shown, save_dir)
                n_shown += 1


eval(test_loader, model, optimal_thresh)