In [1]:
# Reload edited project modules (utils.cfa, anomalib, ...) live before each cell runs,
# so changes to the old/new CFA code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import argparse
import random
import warnings

import datasets.mvtec as mvtec
import torch
import torch.optim as optim
from cnn.efficientnet import EfficientNet as effnet
from cnn.resnet import resnet18 as resnet18
from cnn.resnet import wide_resnet50_2 as wide_resnet50_2
from cnn.vgg import vgg19_bn as vgg19_bn
from datasets.mvtec import MVTecDataset
from torch.utils.data import DataLoader
from utils.cfa import *
from utils.cfa_new import CfaModel, get_feature_extractor
from utils.metric import *
from utils.visualizer import *

# Silence UserWarnings and choose the compute device once; later cells read
# `use_cuda` and `device` from here.
warnings.filterwarnings(action="ignore", category=UserWarning)
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")


def parse_args(argv=None):
    """Build the CFA notebook configuration.

    Args:
        argv: Optional sequence of command-line style tokens, e.g.
            ``["--Rd", "--class_name", "bottle"]``. ``None`` (the default) parses
            an empty list, so the notebook kernel's own ``sys.argv`` is never
            interpreted as CFA options and every option keeps its default.

    Returns:
        ``argparse.Namespace`` holding the parsed options.
    """
    parser = argparse.ArgumentParser("CFA configuration")
    parser.add_argument("--data_path", type=str, default="../../../datasets/MVTec/")
    parser.add_argument("--save_path", type=str, default="./mvtec_result")
    # `type=bool` would map any non-empty string (including "False") to True,
    # so an on/off switch must be a flag instead.
    parser.add_argument("--Rd", action="store_true", default=False)
    parser.add_argument(
        "--cnn",
        type=str,
        choices=["resnet18", "wide_resnet50_2", "efficientnet_b5", "vgg19_bn"],
        default="wide_resnet50_2",
    )
    parser.add_argument("--resize", type=int, choices=[224, 256], default=224)
    parser.add_argument("--size", type=int, choices=[224, 256], default=224)
    parser.add_argument("--gamma_c", type=int, default=1)
    parser.add_argument("--gamma_d", type=int, default=1)

    parser.add_argument("--class_name", type=str, default="zipper")
    parser.add_argument(
        "--backbone",
        type=str,
        choices=["resnet18", "wide_resnet50_2", "efficientnet_b5", "vgg19_bn"],
        default="wide_resnet50_2",
    )
    return parser.parse_args(args=[] if argv is None else list(argv))


# Seed the Python and torch RNGs so shuffling and weight init are repeatable.
seed = 1024
random.seed(seed)
torch.manual_seed(seed)
if use_cuda:
    torch.cuda.manual_seed_all(seed)

args = parse_args()
class_names = [args.class_name] if args.class_name != "all" else mvtec.CLASS_NAMES
# NOTE(review): `class_names` is unused below; with --class_name all, the literal
# string "all" is what reaches MVTecDataset — confirm that is intended.
class_name = args.class_name

# The train and test splits share every dataset option except `is_train`.
_mvtec_kwargs = dict(
    dataset_path=args.data_path,
    class_name=class_name,
    resize=args.resize,
    cropsize=args.size,
    wild_ver=args.Rd,
)
train_dataset = MVTecDataset(is_train=True, **_mvtec_kwargs)
test_dataset = MVTecDataset(is_train=False, **_mvtec_kwargs)

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, pin_memory=True, drop_last=True)
test_loader = DataLoader(test_dataset, batch_size=4, pin_memory=True)

In [3]:
# Build the pretrained backbone named by `args.cnn`. DSVDD is also told `args.cnn`,
# so the model and the loss can no longer silently disagree about the architecture.
# NOTE(review): assumes resnet18 / vgg19_bn accept the same (pretrained, progress)
# arguments as wide_resnet50_2 — confirm against cnn/resnet.py and cnn/vgg.py.
_backbones = {
    "resnet18": resnet18,
    "wide_resnet50_2": wide_resnet50_2,
    "vgg19_bn": vgg19_bn,
}
if args.cnn not in _backbones:
    raise ValueError(f"Backbone {args.cnn!r} is not wired up in this notebook; choose one of {sorted(_backbones)}.")
model = _backbones[args.cnn](pretrained=True, progress=True)
model = model.to(device)
model.eval()

loss_fn = DSVDD(model, train_loader, args.cnn, args.gamma_c, args.gamma_d, device)
loss_fn = loss_fn.to(device)

100%|██████████| 60/60 [00:08<00:00,  6.76it/s]


In [4]:
# Take the first element (the image batch) of one training batch and show its shape.
# Indexing instead of `_`-unpacking avoids overwriting IPython's `_` output variable.
x = next(iter(train_loader))[0]
x.shape

torch.Size([4, 3, 224, 224])

In [5]:
# Forward one batch through the backbone, then through the DSVDD loss built in In [3].
# `loss_fn` returns a pair; only the first item (the loss) is kept.
# `features` is reused by the next cell.
features = model(x.to(device))
loss, _ = loss_fn(features)
print(loss)

tensor(1490.4032, device='cuda:0', grad_fn=<AddBackward0>)


In [6]:
# Print the shape of each feature map returned by the backbone.
for layer_out in features:
    print(layer_out.size())

torch.Size([4, 256, 56, 56])
torch.Size([4, 512, 28, 28])
torch.Size([4, 1024, 14, 14])


In [7]:
import torch
from torch import nn
from utils.cfa import Descriptor as OldDescriptor

from anomalib.models.cfa.torch_model import Descriptor as NewDescriptor


def initialize_weights(m):
    """Deterministically re-initialise a module, for use with ``nn.Module.apply``.

    Every call reseeds torch's global RNG to 0. For ``nn.Conv2d`` modules, the
    weights then get Kaiming-uniform init (ReLU gain) and any bias is zeroed.
    All other module types are left untouched.
    """
    torch.manual_seed(0)
    if not isinstance(m, nn.Conv2d):
        return
    nn.init.kaiming_uniform_(m.weight.data, nonlinearity="relu")
    if m.bias is not None:
        nn.init.constant_(m.bias.data, 0)


# Random tensors shaped like the three Wide-ResNet50-2 feature maps (batch of 4).
# They are created on `device` (from the setup cell) rather than with a hard-coded
# `.cuda()`, so the check also runs on CPU-only machines.
feature_maps = [
    torch.rand(4, 256, 56, 56, device=device),
    torch.rand(4, 512, 28, 28, device=device),
    torch.rand(4, 1024, 14, 14, device=device),
]
use_cuda = feature_maps[0].device.type == "cuda"

d1 = OldDescriptor(gamma_d=1, cnn="wide_resnet50_2").to(device)
d2 = NewDescriptor(gamma_d=1, backbone="wide_resnet50_2").to(device)

# Give both descriptors identical deterministic conv weights, so any output
# difference comes from the implementations themselves.
d1.apply(initialize_weights)
d2.apply(initialize_weights)

o1 = d1(feature_maps)
o2 = d2(feature_maps)

torch.allclose(o1, o2)

True

In [9]:
# Sanity check: functional `torch.mean` and method `Tensor.mean` agree over the batch axis.
oriented_features = torch.rand(1, 1792, 56, 56)
m1 = torch.mean(oriented_features, 0, keepdim=True)
m2 = oriented_features.mean(keepdim=True, dim=0)

torch.allclose(m1, m2)

True

In [10]:
# Build a memory bank from one random batch as a running mean of descriptor outputs.
# The input uses `device` instead of a hard-coded `.cuda()` so this runs on CPU too.
x = torch.rand((4, 3, 224, 224), device=device)
i = 0
memory_bank = torch.tensor(0, requires_grad=False)

with torch.no_grad():
    features = model(x)
    oriented_features = d1(features)
    # Running mean: after batch i the bank holds the average of batch means 0..i.
    memory_bank = ((memory_bank * i) + oriented_features.mean(dim=0, keepdim=True)) / (i + 1)

memory_bank = memory_bank.transpose(-1, -2).detach()

print(memory_bank.requires_grad)

False


In [None]:
memory_bank.size()

In [9]:
# Size of dimension 2 of the first feature map, shown with its Python type.
scale = features[0].shape[2]
(type(scale), scale)

(int, 56)

In [11]:
from einops import rearrange

In [14]:
from anomalib.models.cfa.torch_model import CfaModel
from anomalib.models.cfa.utils.cfa import DSVDD

In [15]:
# Old CFA implementation (DSVDD), built on the backbone from In [3].
cfa_old = DSVDD(model, train_loader, "wide_resnet50_2", 1, 1, device)
cfa_old = cfa_old.to(device)

100%|██████████| 60/60 [00:07<00:00,  7.56it/s]


In [16]:
# New CFA implementation (CfaModel), built with the same arguments as `cfa_old`.
cfa_new = CfaModel(model, train_loader, "wide_resnet50_2", 1, 1, device)
cfa_new = cfa_new.to(device)

100%|██████████| 60/60 [00:07<00:00,  7.79it/s]


In [21]:
# Random 1792-channel features, flattened from (b, c, h, w) to (b, h*w, c) for the
# soft-boundary loss. `device` replaces the hard-coded `.cuda()`.
features = torch.rand((4, 1792, 56, 56), device=device)
features = rearrange(features, "b c h w -> b (h w) c")
loss = cfa_old._soft_boundary(features)
print(loss)

tensor(611152.9375, device='cuda:0', grad_fn=<AddBackward0>)

In [31]:
import torch
from einops import rearrange
from torch import nn
from torch.utils.data import DataLoader
from utils.cfa import Descriptor as OldDescriptor

from anomalib.models.cfa.cnn.resnet import wide_resnet50_2
from anomalib.models.cfa.datasets.mvtec import MVTecDataset
from anomalib.models.cfa.torch_model import CfaModel
from anomalib.models.cfa.torch_model import CoordConv2d as NewCoordConv2d
from anomalib.models.cfa.torch_model import Descriptor as NewDescriptor
from anomalib.models.cfa.torch_model import get_feature_extractor
from anomalib.models.cfa.utils.cfa import DSVDD
from anomalib.models.cfa.utils.coordconv import CoordConv2d as OldCoordConv2d


def initialize_weights(m) -> None:
    """Deterministically re-initialise a module, for use with ``nn.Module.apply``.

    Every call reseeds torch's global RNG to 0. For ``nn.Conv2d`` modules, the
    weights then get Kaiming-uniform init (ReLU gain) and any bias is zeroed.
    All other module types are left untouched.

    Note: three data-setup lines (device / MVTecDataset / DataLoader) used to sit,
    mis-indented, inside the Conv2d branch. They rebuilt a dataset from a hard-coded
    absolute path on every Conv2d and then discarded it, so they were removed.
    """
    torch.manual_seed(0)
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_uniform_(m.weight.data, nonlinearity="relu")
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0)


# NOTE(review): `device` and `train_loader` used below come from the setup cell (In [2]).


# Compare the old and new feature extractors on one real training batch.
# `device` is used throughout instead of a mix of `device`, `.cuda()` and torch.device("cuda").
images = next(iter(train_loader))[0].to(device)

old_feature_extractor = wide_resnet50_2(pretrained=True, progress=True).to(device)
old_feature_extractor.eval()

new_feature_extractor = get_feature_extractor("wide_resnet50_2", device=device)

old_features = old_feature_extractor(images)
# The new extractor returns a dict-like object; only its values (the feature maps) are compared.
new_features = list(new_feature_extractor(images).values())

# zip() would silently skip extra layers, so check the layer counts first.
assert len(old_features) == len(new_features), "Old and new extractors should return the same number of layers."
for old_feature, new_feature in zip(old_features, new_features):
    assert torch.allclose(old_feature, new_feature, atol=1e-1), "Old and new features should match."


# Compute the memory bank with both implementations and check they agree.
cfa_old = DSVDD(old_feature_extractor, train_loader, "wide_resnet50_2", 1, 1, device).to(device)
cfa_new = CfaModel(new_feature_extractor, train_loader, "wide_resnet50_2", 1, 1, device).to(device)
assert torch.allclose(cfa_old.C, cfa_new.memory_bank, atol=1e-1)
assert cfa_new.memory_bank.requires_grad is False


# Forward pass. The losses are large, so the check uses a loose absolute tolerance
# (|old - new| < 100). Indexing avoids overwriting IPython's `_`.
old_loss = cfa_old(old_features)[0]
new_loss = cfa_new(new_features)[0]
assert (old_loss - new_loss).abs() / 1000 < 1e-1, "Old and new losses should match."

  assert(torch.allclose(old_feature, new_feature, atol=1e-1), "Old and new features should match.")
100%|██████████| 60/60 [00:07<00:00,  7.77it/s]
100%|██████████| 60/60 [00:07<00:00,  7.76it/s]
  assert((old_loss - new_loss).abs()/1000 < 1e-1, "Old and new losses should match.")


In [30]:
# NOTE(review): the output below is stale. It comes from In [30], which ran before
# In [31] reassigned `new_loss` to just the loss tensor, so it still shows the full
# (tensor, tensor) tuple returned by `cfa_new`. Re-run to refresh.
new_loss

(tensor(1782.5446, device='cuda:0', grad_fn=<MulBackward0>),
 tensor([[[[0.2759, 0.2337, 0.3223,  ..., 0.3570, 0.2941, 0.3232],
           [0.2293, 0.3200, 0.3982,  ..., 0.4049, 0.3766, 0.2758],
           [0.2537, 0.2877, 0.4034,  ..., 0.3845, 0.3050, 0.2768],
           ...,
           [0.3576, 0.4133, 0.4989,  ..., 0.3087, 0.2425, 0.2177],
           [0.3463, 0.4520, 0.4932,  ..., 0.3214, 0.3161, 0.2245],
           [0.3794, 0.3341, 0.3979,  ..., 0.2614, 0.2258, 0.2417]]],
 
 
         [[[0.2157, 0.1812, 0.2434,  ..., 0.4345, 0.4311, 0.4588],
           [0.1773, 0.2465, 0.2855,  ..., 0.5210, 0.5528, 0.4045],
           [0.1785, 0.2006, 0.2577,  ..., 0.5137, 0.4807, 0.4416],
           ...,
           [0.1664, 0.2012, 0.2795,  ..., 0.5641, 0.5193, 0.4713],
           [0.1745, 0.2443, 0.2802,  ..., 0.5432, 0.6477, 0.4538],
           [0.1965, 0.1770, 0.2304,  ..., 0.4324, 0.4567, 0.4814]]],
 
 
         [[[0.1802, 0.1520, 0.1995,  ..., 0.3348, 0.2626, 0.2925],
           [0.1486, 0.20