In [1]:
from arguments import parser 
import torch 
import numpy as np 
import pandas as pd 
import matplotlib.pyplot as plt 
from sklearn.manifold import TSNE
import os 
from datasets import create_dataset 
from torch.utils.data import DataLoader
from omegaconf import OmegaConf



os.environ['CUDA_VISIBLE_DEVICES'] = '1' 
cfg = OmegaConf.load('results/PatchCore/MVTecAD/screw/baseline-identity-sampling_ratio_0.1-anomaly_ratio_0.0/seed_0/configs.yaml')

# model  = __import__('models').__dict__[cfg.MODEL.method](
#         backbone = cfg.MODEL.backbone,
#         **cfg.MODEL.params
#         )
# # model.load_state_dict(
# #         torch.load('results/ProxyCoreBase/MVTecAD/screw/baseline-anomaly_ratio_0.0/seed_0/model_best.pt')
# # )
# model.to('cuda')

trainset, testset = create_dataset(
    dataset_name  = cfg.DATASET.dataset_name,
    datadir       = cfg.DATASET.datadir,
    class_name    = cfg.DATASET.class_name,
    img_size      = cfg.DATASET.img_size,
    mean          = cfg.DATASET.mean,
    std           = cfg.DATASET.std,
    aug_info      = cfg.DATASET.aug_info,
    **cfg.DATASET.get('params',{})
)

trainloader = DataLoader(
        dataset     = trainset,
        batch_size  = cfg.DATASET.batch_size,
        num_workers = cfg.DATASET.num_workers,
        shuffle     = True 
    )    

testloader = DataLoader(
        dataset     = testset,
        batch_size  = cfg.DATASET.batch_size,
        num_workers = cfg.DATASET.num_workers,
        shuffle     = False 
    )    

# train_featureloader = model.get_feature_loader(trainloader)
# test_featureloader = model.get_feature_loader(testloader)

# self = model.core

# import timm 
# import torch.nn.functional as F 
# vit = timm.create_model('vit_small_patch16_224_in21k',pretrained=True)
# vit.to('cuda')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from models import PatchCore
import sys 
sys.path.append('/Volume/VAD/UAADF/softpatch/src')
import logging
import os
import pickle
import tqdm

import torch
import torch.nn as nn 
import timm 
from models.softpatch.src import common, sampler, multi_variate_gaussian, backbones

from sklearn.neighbors import LocalOutlierFactor
from sklearn.ensemble import IsolationForest    
from skimage.filters import threshold_otsu, threshold_mean, threshold_li, threshold_yen, threshold_triangle
import torch.nn.functional as F
import numpy as np

def get_sampler(sampler_name, sampling_ratio, device):
    if sampler_name == "identity":
        return sampler.IdentitySampler()
    elif sampler_name == "greedy_coreset":
        return sampler.GreedyCoresetSampler(sampling_ratio, device)
    elif sampler_name == "approx_greedy_coreset":
        return sampler.ApproximateGreedyCoresetSampler(sampling_ratio, device)
    
class MaskGenerator:
    def __init__(self, input_size=192, mask_patch_size=32, model_patch_size=4, mask_ratio=0.6):
        self.input_size = input_size
        self.mask_patch_size = mask_patch_size
        self.model_patch_size = model_patch_size
        self.mask_ratio = mask_ratio
        
        assert self.input_size % self.mask_patch_size == 0
        assert self.mask_patch_size % self.model_patch_size == 0
        
        self.rand_size = self.input_size // self.mask_patch_size
        self.scale = self.mask_patch_size // self.model_patch_size
        
        self.token_count = self.rand_size ** 2
        self.mask_count = int(np.ceil(self.token_count * self.mask_ratio))
        
    def __call__(self):
        mask_idx = np.random.permutation(self.token_count)[:self.mask_count]
        mask = np.zeros(self.token_count, dtype=int)
        mask[mask_idx] = 1
        
        mask = mask.reshape((self.rand_size, self.rand_size))
        mask = mask.repeat(self.scale, axis=0).repeat(self.scale, axis=1)
        
        return mask

class MIMCore(PatchCore):
    def __init__(self, encoder, backbone, faiss_on_gpu, faiss_num_workers, 
                 sampling_ratio, device, input_shape, threshold='quant15', weight_method='identity'):
        super(MIMCore,self).__init__(backbone,faiss_on_gpu,faiss_num_workers,sampling_ratio,device,input_shape,
                                     threshold, weight_method)
        
        self.load(
            backbone = timm.create_model(backbone, pretrained=True),
            device         = device,
            input_shape    = input_shape,
            nn_method      = common.FaissNN(faiss_on_gpu,faiss_num_workers,int(device.strip('cuda:'))),
            featuresampler = get_sampler(sampler_name = 'approx_greedy_coreset',
                                              sampling_ratio = sampling_ratio,
                                              device = device),
            threshold = threshold,
            weight_method = weight_method
            )
        
        self.reducing_mapper = torch.nn.Linear(1024,384,bias=False)
        self.mask_generator = MaskGenerator(
                                                    input_size = 224,
                                                    mask_patch_size = 32,
                                                    mask_ratio = 0.6,
                                                    model_patch_size=16
                                                )
        self.encoder = encoder 
        
    @torch.no_grad()
    def _embed(self, images):
        with torch.no_grad():
            features = self.forward_modules['feature_aggregator'](images)

        features = [features[layer] for layer in self.layers_to_extract_from]
        features = [self.patch_maker.patchify(x, return_spatial_info=True) for x in features]

        patch_shapes = [x[1] for x in features] # [[28, 28], [14, 14]]
        features = [x[0] for x in features]
        ref_num_patches = patch_shapes[1] # [14,14]

        for i in range(0, len(features)):
            _features = features[i]
            patch_dims = patch_shapes[i]

            _features = _features.reshape(
                _features.shape[0], patch_dims[0], patch_dims[1], *_features.shape[2:]
            )
           
            _features = _features.permute(0, -3, -2, -1, 1, 2)
            perm_base_shape = _features.shape
            _features = _features.reshape(-1, *_features.shape[-2:])
          
            _features = F.interpolate(
                        _features.unsqueeze(1),
                        size=(ref_num_patches[0], ref_num_patches[1]),
                        mode="bilinear",
                        align_corners=False,
                    )

            _features = _features.squeeze(1)
            _features = _features.reshape(
                        *perm_base_shape[:-2], ref_num_patches[0], ref_num_patches[1]
                    )
           
            _features = _features.permute(0, -2, -1, 1, 2, 3)
 
            _features = _features.reshape(len(_features), -1, *_features.shape[-3:])
 
            
            features[i] = _features
        features = [x.reshape(-1, *x.shape[-3:]) for x in features]
        features = self.forward_modules["preprocessing"](features)
        features = self.forward_modules["preadapt_aggregator"](features)
        features = features.reshape(-1,196,1024)
        return features 
    
    def forward(self, images:torch.Tensor):
        features = self._embed(images)
        features = self.reducing_mapper(features)
        
        x = features 
        B, L, _ = features.shape

        mask = torch.Tensor(self.mask_generator().reshape(1,14,14,1)).to('cuda')

        mask_token = nn.Parameter(torch.zeros(1,1,384))
        mask_token = mask_token.expand(B, L, -1).to('cuda')
        w = mask.flatten(1).unsqueeze(-1).type_as(mask_token)
        x = x * (1 - w) + mask_token * w

        cls_tokens = self.encoder.cls_token.expand(B,-1,-1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = self.encoder.pos_drop(x)
        x = self.encoder.blocks(x)
        x = self.encoder.norm(x)

        output = x[:, 1:]
        output = output.reshape(-1,14,14,384)
        self.output = output 
        self.features = features 
        features = features.reshape(-1,14,14,384)
        loss = F.l1_loss(features,output,reduction='none')
        loss = (loss * mask).sum() / (mask.sum() + 1e-5) / 3
        return loss

In [3]:
import timm 
import torch.nn.functional as F 
encoder = timm.create_model('vit_small_patch16_224_in21k',pretrained=True)
encoder.to('cuda')

model  = MIMCore(
        backbone = cfg.MODEL.backbone,
        encoder = encoder,
        **cfg.MODEL.params
        )
model.to('cuda')
optimizer = torch.optim.Adam(model.parameters(),lr=1e-3)

In [9]:
from accelerate import Accelerator 
acc = Accelerator() 
model, optimizer, trainloader,testloader = acc.prepare(model, optimizer, trainloader,testloader)

In [7]:
for e in range(100):
    for i,(imgs, labels, gts) in enumerate(trainloader):
        optimizer.zero_grad()
        loss = model(imgs)
        loss.backward()
        optimizer.step()
        
    if e%10 ==0:
        print(loss)

tensor(1475.8062, device='cuda:0', grad_fn=<DivBackward0>)
tensor(679.1708, device='cuda:0', grad_fn=<DivBackward0>)
tensor(476.9736, device='cuda:0', grad_fn=<DivBackward0>)
tensor(388.8877, device='cuda:0', grad_fn=<DivBackward0>)
tensor(336.1979, device='cuda:0', grad_fn=<DivBackward0>)
tensor(296.8904, device='cuda:0', grad_fn=<DivBackward0>)
tensor(268.5957, device='cuda:0', grad_fn=<DivBackward0>)
tensor(258.1698, device='cuda:0', grad_fn=<DivBackward0>)
tensor(264.4916, device='cuda:0', grad_fn=<DivBackward0>)
tensor(323.8585, device='cuda:0', grad_fn=<DivBackward0>)


In [30]:
features = [] 
for imgs, labels,gts in trainloader:
    with torch.no_grad():
        feats = model._embed(imgs)
        feats = model.reducing_mapper(feats)
        
        B, L, _ = feats.shape
        
        cls_tokens = model.encoder.cls_token.expand(B,-1,-1)
        x = torch.cat((cls_tokens, feats), dim=1)
        x = model.encoder.pos_drop(x)
        x = model.encoder.blocks(x)
        x = model.encoder.norm(x)
        output = x[:, 1:]
        
    features.append(output.reshape(-1,384).detach().cpu().numpy())
features = np.vstack(features)    
sample_features, _ = model.featuresampler.run(features)
model.anomaly_scorer.fit(detection_features=[sample_features])

In [51]:
from utils.metrics import MetricCalculator
img_level = MetricCalculator(metric_list = ['auroc','average_precision'])
pix_level = MetricCalculator(metric_list = ['auroc','average_precision'])

    
_ = model.forward_modules.eval()
for imgs, labels,gts in testloader:
    batchsize = imgs.shape[0]
    with torch.no_grad():
        feats = model._embed(imgs)
        feats = model.reducing_mapper(feats)
        
        B, L, _ = feats.shape
        
        cls_tokens = model.encoder.cls_token.expand(B,-1,-1)
        x = torch.cat((cls_tokens, feats), dim=1)
        x = model.encoder.pos_drop(x)
        x = model.encoder.blocks(x)
        x = model.encoder.norm(x)
        output = x[:, 1:]
        output = output.reshape(-1,384)
        
        image_scores, _, _ = model.anomaly_scorer.predict([output.detach().cpu().numpy()])      
        # get patch wise anomaly score using image score    
        patch_scores = model.patch_maker.unpatch_scores(
        image_scores, batchsize=batchsize
        ) # Unfold : (B)
                
        scales = [14,14]
        patch_scores = patch_scores.reshape(batchsize, scales[0], scales[1])
        masks = model.anomaly_segmentor.convert_to_segmentation(patch_scores) # interpolation : (B,pw,ph) -> (B,W,H)
                
        score_map = np.concatenate([np.expand_dims(sm,0) for sm in masks])
        score_map = np.expand_dims(score_map,1)

        # get image wise anomaly score 
        image_scores = model.patch_maker.unpatch_scores(
        image_scores, batchsize=batchsize
        )
        image_scores = image_scores.reshape(*image_scores.shape[:2], -1)
        image_scores = model.patch_maker.score(image_scores)      
    # result update 
    img_level.update(image_scores, labels.type(torch.int))
    pix_level.update(score_map, gts.type(torch.int))

In [52]:
img_level.compute(), pix_level.compute()

({'auroc': 0.6305986696230598, 'average_precision': 0.6627197968950415},
 {'auroc': 0.9653490618530766, 'average_precision': 0.12091722784310494})