In [1]:
import argparse
import numpy as np
import os
import pickle
from tqdm import tqdm
from collections import OrderedDict
from sklearn.metrics import roc_auc_score
from sklearn.metrics import roc_curve
from sklearn.metrics import precision_recall_curve
from scipy.ndimage import gaussian_filter
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.models import wide_resnet50_2, Wide_ResNet50_2_Weights
import datasets.mvtec as mvtec


def parse_args():
    parser = argparse.ArgumentParser('SPADE')
    parser.add_argument("--top_k", type=int, default=5)
    parser.add_argument("--save_path", type=str, default="./result")
    return parser.parse_args()


def main(args):

    #args = parse_args()

    # device setup
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # load model
    model = wide_resnet50_2(weights= Wide_ResNet50_2_Weights.IMAGENET1K_V1, progress=True)
    model.to(device)
    model.eval()

    # set model's intermediate outputs
    outputs = []
    def hook(module, input, output):
        outputs.append(output)
    model.layer1[-1].register_forward_hook(hook)
    model.layer2[-1].register_forward_hook(hook)
    model.layer3[-1].register_forward_hook(hook)
    model.avgpool.register_forward_hook(hook)

    os.makedirs(os.path.join(args.save_path, 'temp'), exist_ok=True)

    fig, ax = plt.subplots(1, 2, figsize=(20, 10))
    fig_img_rocauc = ax[0]
    fig_pixel_rocauc = ax[1]

    total_roc_auc = []
    total_pixel_roc_auc = []

    for class_name in mvtec.CLASS_NAMES:

        train_dataset = mvtec.MVTecDataset(class_name=class_name, is_train=True)
        train_dataloader = DataLoader(train_dataset, batch_size=32, pin_memory=True)
        test_dataset = mvtec.MVTecDataset(class_name=class_name, is_train=False)
        test_dataloader = DataLoader(test_dataset, batch_size=32, pin_memory=True)

        train_outputs = OrderedDict([('layer1', []), ('layer2', []), ('layer3', []), ('avgpool', [])])
        test_outputs = OrderedDict([('layer1', []), ('layer2', []), ('layer3', []), ('avgpool', [])])

        # extract train set features
        train_feature_filepath = os.path.join(args.save_path, 'temp', 'train_%s.pkl' % class_name)
        if not os.path.exists(train_feature_filepath):
            for (x, y, mask) in tqdm(train_dataloader, '| feature extraction | train | %s |' % class_name):
                # model prediction
                with torch.no_grad():
                    pred = model(x.to(device))
                # get intermediate layer outputs
                for k, v in zip(train_outputs.keys(), outputs):
                    train_outputs[k].append(v)
                # initialize hook outputs
                outputs = []
            for k, v in train_outputs.items():
                train_outputs[k] = torch.cat(v, 0)
            # save extracted feature
            with open(train_feature_filepath, 'wb') as f:
                pickle.dump(train_outputs, f)
        else:
            print('load train set feature from: %s' % train_feature_filepath)
            with open(train_feature_filepath, 'rb') as f:
                train_outputs = pickle.load(f)
            
        gt_list = []
        gt_mask_list = []
        test_imgs = []

        # extract test set features
        for (x, y, mask) in tqdm(test_dataloader, '| feature extraction | test | %s |' % class_name):
            test_imgs.extend(x.cpu().detach().numpy())
            gt_list.extend(y.cpu().detach().numpy())
            gt_mask_list.extend(mask.cpu().detach().numpy())
            # model prediction
            with torch.no_grad():
                pred = model(x.to(device))
            # get intermediate layer outputs
            for k, v in zip(test_outputs.keys(), outputs):
                test_outputs[k].append(v)
            # initialize hook outputs
            outputs = []
        for k, v in test_outputs.items():
            test_outputs[k] = torch.cat(v, 0)

        # calculate distance matrix
        dist_matrix = calc_dist_matrix(torch.flatten(test_outputs['avgpool'], 1),
                                       torch.flatten(train_outputs['avgpool'], 1))
        

        # select K nearest neighbor and take average
        topk_values, topk_indexes = torch.topk(dist_matrix, k=args.top_k, dim=1, largest=False)
        scores = torch.mean(topk_values, 1).cpu().detach().numpy()

        # calculate image-level ROC AUC score
        fpr, tpr, _ = roc_curve(gt_list, scores)
        roc_auc = roc_auc_score(gt_list, scores)
        total_roc_auc.append(roc_auc)
        print('%s ROCAUC: %.3f' % (class_name, roc_auc))
        fig_img_rocauc.plot(fpr, tpr, label='%s ROCAUC: %.3f' % (class_name, roc_auc))

        score_map_list = []
        for t_idx in tqdm(range(test_outputs['avgpool'].shape[0]), '| localization | test | %s |' % class_name):
            score_maps = []
            for layer_name in ['layer1', 'layer2', 'layer3']:  # for each layer

                # construct a gallery of features at all pixel locations of the K nearest neighbors
                topk_feat_map = train_outputs[layer_name][topk_indexes[t_idx]]
                test_feat_map = test_outputs[layer_name][t_idx:t_idx + 1]
                feat_gallery = topk_feat_map.transpose(3, 1).flatten(0, 2).unsqueeze(-1).unsqueeze(-1)

                # calculate distance matrix
                '''
                dist_matrix_list = []
                for d_idx in range(feat_gallery.shape[0] // 100):
                    dist_matrix = torch.pairwise_distance(feat_gallery[d_idx * 100:d_idx * 100 + 100], test_feat_map)
                    dist_matrix_list.append(dist_matrix)
                dist_matrix = torch.cat(dist_matrix_list, 0)
                '''
                score_map = (topk_feat_map - test_feat_map)
                # k nearest features from the gallery (k=1)
                #score_map = torch.min(dist_matrix, dim=0)[0]
                score_map = F.interpolate(score_map.unsqueeze(0).unsqueeze(0), size=224,
                                          mode='bilinear', align_corners=False)
                score_maps.append(score_map)
            import pdb;pdb.set_trace()
            # average distance between the features
            score_map = torch.mean(torch.cat(score_maps, 0), dim=0)

            # apply gaussian smoothing on the score map
            score_map = gaussian_filter(score_map.squeeze().cpu().detach().numpy(), sigma=4)
            score_map_list.append(score_map)

        flatten_gt_mask_list = np.concatenate(gt_mask_list).ravel()
        flatten_score_map_list = np.concatenate(score_map_list).ravel()

        # calculate per-pixel level ROCAUC
        fpr, tpr, _ = roc_curve(flatten_gt_mask_list, flatten_score_map_list)
        per_pixel_rocauc = roc_auc_score(flatten_gt_mask_list, flatten_score_map_list)
        total_pixel_roc_auc.append(per_pixel_rocauc)
        print('%s pixel ROCAUC: %.3f' % (class_name, per_pixel_rocauc))
        fig_pixel_rocauc.plot(fpr, tpr, label='%s ROCAUC: %.3f' % (class_name, per_pixel_rocauc))

        # get optimal threshold
        precision, recall, thresholds = precision_recall_curve(flatten_gt_mask_list, flatten_score_map_list)
        a = 2 * precision * recall
        b = precision + recall
        f1 = np.divide(a, b, out=np.zeros_like(a), where=b != 0)
        threshold = thresholds[np.argmax(f1)]
        print(threshold)
        threshold = 0.70
        # visualize localization result
        visualize_loc_result(test_imgs, gt_mask_list, score_map_list, threshold, args.save_path, class_name, vis_num=5)

    print('Average ROCAUC: %.3f' % np.mean(total_roc_auc))
    fig_img_rocauc.title.set_text('Average image ROCAUC: %.3f' % np.mean(total_roc_auc))
    fig_img_rocauc.legend(loc="lower right")

    print('Average pixel ROCUAC: %.3f' % np.mean(total_pixel_roc_auc))
    fig_pixel_rocauc.title.set_text('Average pixel ROCAUC: %.3f' % np.mean(total_pixel_roc_auc))
    fig_pixel_rocauc.legend(loc="lower right")

    fig.tight_layout()
    fig.savefig(os.path.join(args.save_path, 'roc_curve.png'), dpi=100)


def calc_dist_matrix(x, y):
    """Calculate Euclidean distance matrix with torch.tensor"""
    n = x.size(0)
    m = y.size(0)
    d = x.size(1)
    x = x.unsqueeze(1).expand(n, m, d)
    y = y.unsqueeze(0).expand(n, m, d)
    dist_matrix = torch.sqrt(torch.pow(x - y, 2).sum(2))
    return dist_matrix


def visualize_loc_result(test_imgs, gt_mask_list, score_map_list, threshold,
                         save_path, class_name, vis_num=5):

    for t_idx in range(vis_num):
        test_img = test_imgs[t_idx]
        test_img = denormalization(test_img)
        test_gt = gt_mask_list[t_idx].transpose(1, 2, 0).squeeze()
        test_pred = score_map_list[t_idx]
        #test_pred[test_pred <= threshold] = 0
        #test_pred[test_pred > threshold] = 1
        test_pred_img = test_img.copy()
        test_pred_img[test_pred == 0] = 0

        fig_img, ax_img = plt.subplots(1, 4, figsize=(12, 4))
        fig_img.subplots_adjust(left=0, right=1, bottom=0, top=1)

        for ax_i in ax_img:
            ax_i.axes.xaxis.set_visible(False)
            ax_i.axes.yaxis.set_visible(False)

        ax_img[0].imshow(test_img)
        ax_img[0].title.set_text('Image')
        ax_img[1].imshow(test_gt, cmap='gray')
        ax_img[1].title.set_text('GroundTruth')
        ax_img[2].imshow(test_pred, cmap='gray')
        ax_img[2].title.set_text('Predicted mask')
        ax_img[3].imshow(test_pred_img)
        ax_img[3].title.set_text('Predicted anomalous image')

        os.makedirs(os.path.join(save_path, 'images'), exist_ok=True)
        fig_img.savefig(os.path.join(save_path, 'images', '%s_%03d.png' % (class_name, t_idx)), dpi=100)
        fig_img.clf()
        plt.close(fig_img)


def denormalization(x):
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    x = (((x.transpose(1, 2, 0) * std) + mean) * 255.).astype(np.uint8)
    return x




In [1]:
import argparse
import numpy as np
import os
import pickle
from tqdm import tqdm
from collections import OrderedDict
from sklearn.metrics import roc_auc_score
from sklearn.metrics import roc_curve
from sklearn.metrics import precision_recall_curve
from scipy.ndimage import gaussian_filter
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.models import wide_resnet50_2

import datasets.mvtec as mvtec

In [2]:
device='cuda'
train_dataset = mvtec.MVTecDataset(class_name='bottle', is_train=True)
train_dataloader = DataLoader(train_dataset, batch_size=32, pin_memory=True)
model = wide_resnet50_2(pretrained=True, progress=True)
model.to(device)
model.eval()



ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), strid

In [None]:
class params():
    def __init__(self):
        self.save_path = 'temp'
        self.top_k     = 3
args = params()
main(args)



load train set feature from: temp/temp/train_bottle.pkl


| feature extraction | test | bottle |: 100%|█████| 3/3 [00:05<00:00,  1.95s/it]


bottle ROCAUC: 0.971


| localization | test | bottle |:   0%|                  | 0/83 [00:00<?, ?it/s]

> [0;32m/tmp/ipykernel_3233951/2068658623.py[0m(149)[0;36mmain[0;34m()[0m
[0;32m    147 [0;31m            [0;32mimport[0m [0mpdb[0m[0;34m;[0m[0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    148 [0;31m            [0;31m# average distance between the features[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 149 [0;31m            [0mscore_map[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mmean[0m[0;34m([0m[0mtorch[0m[0;34m.[0m[0mcat[0m[0;34m([0m[0mscore_maps[0m[0;34m,[0m [0;36m0[0m[0;34m)[0m[0;34m,[0m [0mdim[0m[0;34m=[0m[0;36m0[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    150 [0;31m[0;34m[0m[0m
[0m[0;32m    151 [0;31m            [0;31m# apply gaussian smoothing on the score map[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  l


[1;32m    144 [0m                score_map = F.interpolate(score_map.unsqueeze(0).unsqueeze(0), size=224,
[1;32m    145 [0m                                          mode='bilinear', align_corners=False)
[1;32m    146 [0m                [0mscore_maps[0m[0;34m.[0m[0mappend[0m[0;34m([0m[0mscore_map[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[1;32m    147 [0m            [0;32mimport[0m [0mpdb[0m[0;34m;[0m[0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[1;32m    148 [0m            [0;31m# average distance between the features[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[0;32m--> 149 [0;31m            [0mscore_map[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mmean[0m[0;34m([0m[0mtorch[0m[0;34m.[0m[0mcat[0m[0;34m([0m[0mscore_maps[0m[0;34m,[0m [0;36m0[0m[0;34m)[0m[0;34m,[0m [0mdim[0m[0;34m=[0m[0;36m0[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[1;32m    150 [0m[0;34m[0m[0m
[1;32m    151 [0m  

ipdb>  ll


[1;32m     27 [0m[0;32mdef[0m [0mmain[0m[0;34m([0m[0margs[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[1;32m     28 [0m[0;34m[0m[0m
[1;32m     29 [0m    [0;31m#args = parse_args()[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[1;32m     30 [0m[0;34m[0m[0m
[1;32m     31 [0m    [0;31m# device setup[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[1;32m     32 [0m    [0mdevice[0m [0;34m=[0m [0;34m'cuda'[0m [0;32mif[0m [0mtorch[0m[0;34m.[0m[0mcuda[0m[0;34m.[0m[0mis_available[0m[0;34m([0m[0;34m)[0m [0;32melse[0m [0;34m'cpu'[0m[0;34m[0m[0;34m[0m[0m
[1;32m     33 [0m[0;34m[0m[0m
[1;32m     34 [0m    [0;31m# load model[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[1;32m     35 [0m    [0mmodel[0m [0;34m=[0m [0mwide_resnet50_2[0m[0;34m([0m[0mweights[0m[0;34m=[0m [0mWide_ResNet50_2_Weights[0m[0;34m.[0m[0mIMAGENET1K_V1[0m[0;34m,[0m [0mprogress[0m[0;34m=[0m[0;32mTrue[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[1;

ipdb>  test_feat_map.shape


torch.Size([1, 1024, 14, 14])


ipdb>  topk_feat_map.shape


torch.Size([3, 1024, 14, 14])


ipdb>  dist_matrix.shape


torch.Size([500, 1024, 14])


ipdb>  dmt = torch.pairwise_distance(topk_feat_map.mean(axis=0), test_feat_map)
ipdb>  dmt.shape


torch.Size([1, 1024, 14])


ipdb>  feat_gallery.shape


torch.Size([588, 1024, 1, 1])


In [None]:
!ls /home/irfan/Desktop/data/

In [None]:
plt.imshow(topk_feat_map[0,0].to('cpu'));plt.show()