In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import argparse
import random
import warnings
import numpy as np
import datasets.mvtec as mvtec
import torch
import torch.optim as optim
from cnn.efficientnet import EfficientNet as effnet
from cnn.resnet import resnet18 as resnet18
from cnn.resnet import wide_resnet50_2 as wide_resnet50_2
from cnn.vgg import vgg19_bn as vgg19_bn
from datasets.mvtec import MVTecDataset
from torch.utils.data import DataLoader
from utils.cfa import *
from utils.cfa_new import get_feature_extractor, CfaModel
from utils.metric import *
from utils.visualizer import *

warnings.filterwarnings("ignore", category=UserWarning)
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")


def parse_args(argv=None):
    """Parse the CFA experiment configuration.

    Args:
        argv: Optional list of command-line tokens. When ``None`` (the default)
            an empty list is parsed, so every option takes its default value.
            This avoids reading ``sys.argv``, which inside a Jupyter kernel
            holds the kernel's own launch arguments.

    Returns:
        argparse.Namespace with ``data_path``, ``save_path``, ``Rd``, ``cnn``,
        ``resize``, ``size``, ``gamma_c``, ``gamma_d``, ``class_name`` and
        ``backbone``.
    """

    def _str2bool(value):
        # ``type=bool`` would call bool("False") -> True, so "--Rd False" would
        # silently enable the flag; parse the common spellings explicitly.
        lowered = str(value).strip().lower()
        if lowered in ("true", "t", "yes", "y", "1"):
            return True
        if lowered in ("false", "f", "no", "n", "0"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    # Backbones accepted by both --cnn and --backbone.
    backbones = ["resnet18", "wide_resnet50_2", "efficientnet_b5", "vgg19_bn"]

    parser = argparse.ArgumentParser("CFA configuration")
    parser.add_argument("--data_path", type=str, default="../../../datasets/MVTec/")
    parser.add_argument("--save_path", type=str, default="./mvtec_result")
    parser.add_argument("--Rd", type=_str2bool, default=False)
    parser.add_argument("--cnn", type=str, choices=backbones, default="wide_resnet50_2")
    parser.add_argument("--resize", type=int, choices=[224, 256], default=224)
    parser.add_argument("--size", type=int, choices=[224, 256], default=224)
    parser.add_argument("--gamma_c", type=int, default=1)
    parser.add_argument("--gamma_d", type=int, default=1)

    parser.add_argument("--class_name", type=str, default="zipper")
    # NOTE(review): duplicates --cnn; both are kept so existing callers keep working.
    parser.add_argument("--backbone", type=str, choices=backbones, default="wide_resnet50_2")
    return parser.parse_args(args=[] if argv is None else argv)


# Seed every RNG library used in this notebook (Python, NumPy, PyTorch) so
# shuffling, augmentation and initialisation are reproducible.
seed = 1024
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if use_cuda:
    torch.cuda.manual_seed_all(seed)

args = parse_args()
class_names = mvtec.CLASS_NAMES if args.class_name == "all" else [args.class_name]
# NOTE(review): `class_names` expands "all" to every category, but the datasets
# below are built from the raw `class_name`, so "all" is passed through unchanged.
# This notebook appears to handle a single category only — confirm intended use.
class_name = args.class_name

# Settings shared by the train and test splits; only `is_train` differs.
dataset_kwargs = dict(
    dataset_path=args.data_path,
    class_name=class_name,
    resize=args.resize,
    cropsize=args.size,
    wild_ver=args.Rd,
)
train_dataset = MVTecDataset(is_train=True, **dataset_kwargs)
test_dataset = MVTecDataset(is_train=False, **dataset_kwargs)

BATCH_SIZE = 4
# Pinned host memory only benefits host-to-GPU copies, so enable it only with CUDA.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    pin_memory=use_cuda,
    shuffle=True,
    drop_last=True,
)
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=BATCH_SIZE,
    pin_memory=use_cuda,
)

In [11]:
import pytest
import torch
from einops import rearrange
from torch import nn
from torch.utils.data import DataLoader
from utils.cfa import Descriptor as OldDescriptor

from anomalib.models.cfa.cnn.resnet import wide_resnet50_2
from anomalib.models.cfa.datasets.mvtec import MVTecDataset
from anomalib.models.cfa.torch_model import CfaModel
from anomalib.models.cfa.torch_model import CoordConv2d as NewCoordConv2d
from anomalib.models.cfa.torch_model import Descriptor as NewDescriptor
from anomalib.models.cfa.torch_model import get_feature_extractor
from anomalib.models.cfa.utils.cfa import DSVDD
from anomalib.models.cfa.utils.coordconv import CoordConv2d as OldCoordConv2d

"""Test Init Centroid should return the same memory bank."""
import os

# Dataset root can be overridden via the MVTEC_ROOT environment variable; the
# default is the location this check was originally run against.
MVTEC_ROOT = os.environ.get("MVTEC_ROOT", "/home/sakcay/projects/anomalib/datasets/MVTec/")
# Fall back to the CPU so the check still runs on a machine without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_dataset = MVTecDataset(MVTEC_ROOT, "zipper")
train_loader = DataLoader(train_dataset, 4)

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Compare the feature extractor on the first batch (only the images are used).
images, _, _ = next(iter(train_loader))
images = images.to(device)

old_feature_extractor = wide_resnet50_2(pretrained=True, progress=True).to(device)
old_feature_extractor.eval()

new_feature_extractor = get_feature_extractor("wide_resnet50_2", device=device)

# Inference only: no autograd graph is needed for these comparisons.
with torch.no_grad():
    old_features = old_feature_extractor(images)
    # The new extractor returns a mapping; compare its values in order.
    new_features = list(new_feature_extractor(images).values())

# zip() stops at the shorter sequence, which would hide a missing layer.
assert len(old_features) == len(new_features), "Old and new extractors should return the same number of layers."
for old_feature, new_feature in zip(old_features, new_features):
    assert torch.allclose(old_feature, new_feature, atol=1e-1), "Old and new features should match."
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Compute the memory bank.
cfa_old = DSVDD(old_feature_extractor, train_loader, "wide_resnet50_2", 1, 1, device).to(device)
cfa_new = CfaModel(new_feature_extractor, train_loader, "wide_resnet50_2", 1, 1, device).to(device)
assert torch.allclose(cfa_old.C, cfa_new.memory_bank, atol=1e-1), "Old and new memory banks should match."
assert cfa_new.memory_bank.requires_grad is False
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Compute Forward-Pass.
with torch.no_grad():
    old_loss, old_score = cfa_old(old_features)
    new_loss, new_score = cfa_new(new_features)
# Same threshold as the original `|old - new| / 1000 < 0.1`, stated directly.
# NOTE(review): an absolute tolerance of 100 is very loose; a relative tolerance may be intended.
LOSS_ATOL = 100.0
assert (old_loss - new_loss).abs() < LOSS_ATOL, "Old and new losses should match."
assert torch.allclose(old_score, new_score, atol=1e-1), "Old and new scores should match."

100%|██████████| 60/60 [00:07<00:00,  7.52it/s]
100%|██████████| 60/60 [00:08<00:00,  7.48it/s]


In [7]:
# Display the shapes of the new and old memory banks side by side; they should be equal.
cfa_new.memory_bank.shape, cfa_old.C.shape

(torch.Size([1792, 4096]), torch.Size([1792, 4096]))

In [8]:
# Display (min, max, mean) of the new implementation's memory bank.
# NOTE(review): the saved output was recorded at execution count 8, before the
# parity-check cell (count 11). Its range differs from cfa_old.C's by far more
# than the atol=1e-1 that cell asserts, so the saved outputs are likely stale.
# Re-run top-to-bottom to confirm.
cfa_new.memory_bank.min(), cfa_new.memory_bank.max(), cfa_new.memory_bank.mean()

(tensor(-0.3419, device='cuda:0'),
 tensor(0.3341, device='cuda:0'),
 tensor(0.0016, device='cuda:0'))

In [9]:
# Display (min, max, mean) of the reference implementation's memory bank (`C`)
# for comparison with the new one.
# NOTE(review): saved output predates the parity-check cell (execution count 9 < 11); re-run to confirm.
cfa_old.C.min(), cfa_old.C.max(), cfa_old.C.mean()

(tensor(-0.7807, device='cuda:0'),
 tensor(0.9701, device='cuda:0'),
 tensor(0.0030, device='cuda:0'))

In [7]:
# Score every test image with the reference (old) CFA model and build smoothed,
# rescaled anomaly heatmaps alongside the ground-truth labels and masks.
test_imgs = list()
gt_mask_list = list()
gt_list = list()
heatmap_batches = []
img_size = None

cfa_old.eval()
# Inference only: skip building autograd graphs for every batch.
with torch.no_grad():
    for x, y, mask in test_loader:
        test_imgs.extend(x.cpu().numpy())
        gt_list.extend(y.cpu().numpy())
        gt_mask_list.extend(mask.cpu().numpy())

        p = old_feature_extractor(x.to(device))
        _, score = cfa_old(p)
        # Average the score map over dim 1 (presumably channels — TODO confirm).
        heatmap_batches.append(torch.mean(score.cpu(), dim=1))
        img_size = x.size(2)

if not heatmap_batches:
    raise ValueError("test_loader yielded no batches; cannot build heatmaps.")

# Concatenate once instead of re-copying a growing tensor on every iteration.
heatmaps = torch.cat(heatmap_batches, dim=0)
heatmaps = upsample(heatmaps, size=img_size, mode="bilinear")
heatmaps = gaussian_smooth(heatmaps, sigma=4)

gt_mask = np.asarray(gt_mask_list)
scores = rescale(heatmaps)