In [1]:
import os
import time
import datetime
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from datasets import MVTecDataset, VisADataset, MVTEC_CLASS_NAMES, VISA_CLASS_NAMES
from models.extractors import build_extractor


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import default as c

# Dataset-specific configuration. `default` is the project's config module; this cell
# overrides the fields that depend on which benchmark is processed below.
# Unknown dataset names fail here instead of silently keeping the module defaults.
c.dataset = 'visa'
if c.dataset == 'mvtec':
    c.data_path = '/mnt/disk2/zhouyixuan/datasets/anomaly-detection/image/MVTec'
    c.class_names = MVTEC_CLASS_NAMES
elif c.dataset == 'visa':
    c.data_path = '/mnt/disk2/zhouyixuan/datasets/anomaly-detection/image/VisA_pytorch/1cls/'
    c.class_names = VISA_CLASS_NAMES
else:
    raise ValueError("Unsupported dataset {!r}; expected 'mvtec' or 'visa'".format(c.dataset))

In [3]:
# Name of the dataset root directory as it appears in image paths, keyed by `c.dataset`.
# The feature tree is written next to it as '<name>_features/<extractor>/<input height>'.
_DATASET_ROOT_NAMES = {'mvtec': 'MVTec', 'visa': '1cls'}


def _feature_path(c, image_path):
    """Map an image file path to the `.npy` path its cached features are saved under.

    The first occurrence of the dataset root name ('MVTec' / '1cls') is replaced by
    '<root>_features/<c.extractor>/<c.input_size[0]>' and the file extension is swapped
    for '.npy'. Only the first occurrence is rewritten; this assumes image paths start
    with `c.data_path` (set in the config cell), so that match is the root directory.
    """
    root_name = _DATASET_ROOT_NAMES[c.dataset]
    stem, _ext = os.path.splitext(image_path)  # extension-agnostic (.png, .JPG, ...)
    feature_subdir = '{}_features/{}/{}'.format(root_name, c.extractor, c.input_size[0])
    return stem.replace(root_name, feature_subdir, 1) + '.npy'


def extract(c):
    """Run the feature extractor over the train and test splits and cache the results.

    For every image, the extractor output (iterated as a sequence of tensors,
    presumably one per backbone stage) is saved as ``{index: ndarray}`` in a ``.npy``
    file that mirrors the image path under the features tree (see `_feature_path`).
    `np.save` pickles the dict into a 0-d object array, so readers need
    ``np.load(path, allow_pickle=True).item()``.

    Args:
        c: config namespace. ``dataset``, ``device``, ``extractor`` and ``input_size``
           are read here; ``c`` is also passed to the dataset classes and
           ``build_extractor``.

    Raises:
        ValueError: if ``c.dataset`` is not one of 'mvtec' / 'visa'. This is checked
            before any dataset or model is built.
    """
    if c.dataset not in _DATASET_ROOT_NAMES:
        raise ValueError('Unsupported dataset {!r}; expected one of {}'.format(
            c.dataset, sorted(_DATASET_ROOT_NAMES)))

    Dataset = MVTecDataset if c.dataset == 'mvtec' else VisADataset
    train_dataset = Dataset(c, is_train=True)
    test_dataset = Dataset(c, is_train=False)

    extractor, _output_channels = build_extractor(c)
    extractor = extractor.to(c.device).eval()

    # Pure inference: skip building an autograd graph for every image.
    with torch.no_grad():
        for dataset in (train_dataset, test_dataset):
            for idx, (image, _label, _mask, _cls) in enumerate(dataset):
                # assumes `image` is an unbatched (C, H, W) tensor - TODO confirm
                features = extractor(image.unsqueeze(0).to(c.device))
                # dataset.x[idx] is the source file path of the item just yielded
                feature_path = _feature_path(c, dataset.x[idx])
                os.makedirs(os.path.dirname(feature_path), exist_ok=True)
                feature_dict = {
                    level: feature.squeeze(0).detach().cpu().numpy()
                    for level, feature in enumerate(features)
                }
                np.save(feature_path, feature_dict)

In [4]:
# Cache features for every class at each working resolution. `extract` takes all of its
# settings from the shared config module `c`, so the per-run class name and input size are
# set on it before each call (the dataset classes presumably read `c.class_name`).
for class_name in c.class_names:
    c.class_name = class_name
    for side in (256, 512):
        c.input_size = (side, side)
        print(c.class_name, c.input_size)
        extract(c)

bottle (256, 256)
Channels of extracted features: [256, 512, 1024]
bottle (512, 512)
Channels of extracted features: [256, 512, 1024]
cable (256, 256)
Channels of extracted features: [256, 512, 1024]
cable (512, 512)
Channels of extracted features: [256, 512, 1024]
capsule (256, 256)
Channels of extracted features: [256, 512, 1024]
capsule (512, 512)
Channels of extracted features: [256, 512, 1024]
carpet (256, 256)
Channels of extracted features: [256, 512, 1024]
carpet (512, 512)
Channels of extracted features: [256, 512, 1024]
grid (256, 256)
Channels of extracted features: [256, 512, 1024]
grid (512, 512)
Channels of extracted features: [256, 512, 1024]
hazelnut (256, 256)
Channels of extracted features: [256, 512, 1024]
hazelnut (512, 512)
Channels of extracted features: [256, 512, 1024]
leather (256, 256)
Channels of extracted features: [256, 512, 1024]
leather (512, 512)
Channels of extracted features: [256, 512, 1024]
metal_nut (256, 256)
Channels of extracted features: [256, 5

In [6]:
# MVTec: ground-truth masks are not turned into features. Each class's `ground_truth`
# directory is symlinked into the matching feature tree, presumably so that a loader pointed
# at the feature root can still find the masks - TODO confirm against the feature dataset.
# Safe to re-run: existing links are kept and dangling ones are recreated.
image_root = '/mnt/disk2/zhouyixuan/datasets/anomaly-detection/image/MVTec'
feature_root = '/mnt/disk2/zhouyixuan/datasets/anomaly-detection/image/MVTec_features/wide_resnet50_2/512'

for sub_dir in sorted(os.listdir(image_root)):
    image_dir = os.path.join(image_root, sub_dir, 'ground_truth')
    if not os.path.isdir(image_dir):
        continue  # not a class directory (or a class without masks)
    feature_dir = os.path.join(feature_root, sub_dir, 'ground_truth')
    if os.path.islink(feature_dir) and not os.path.exists(feature_dir):
        # Dangling link from an earlier run: os.path.exists() is False for it, and
        # os.symlink would fail with FileExistsError, so drop it and relink.
        os.unlink(feature_dir)
    if os.path.lexists(feature_dir):
        continue
    # The class directory may not exist yet in the feature tree.
    os.makedirs(os.path.dirname(feature_dir), exist_ok=True)
    os.symlink(image_dir, feature_dir)