
<div style="text-align: center;">
    <h1>ReConPatch: anomaly detection</h1>
    <h3>Authors:</h3>
    <p>Dario Loi 1940849, Elena Muia 1938610, Martina Doku 1938629</p>

</div>


<div>
    <h2>0 - Introduction</h2>
    <p>This project aims to reimplement and potentially advance the ReConPatch method proposed in the paper titled  <a href="https://arxiv.org/pdf/2305.16713v3">”ReConPatch:
Anomaly Detection by Linear Modulation of Pretrained Features.”</a> This method addresses the challenge of
anomaly detection by constructing discriminative features through a linear modulation of patch features extracted
from pre-trained models and employs contrastive representation learning to collect and distribute features in a way
that produces a target-oriented and easily separable representation of the data.</p>
</div>


In [9]:
#install package implementing EMA with pytorch
%pip install ema-pytorch 
%pip install lightning
%pip install wandb

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [10]:
import os
from typing import Optional
from lightning import LightningDataModule, LightningModule
import lightning.pytorch as pl
import torch.optim as optim
import torch
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import torch.nn as nn
from tqdm import tqdm
from ema_pytorch import EMA
import wandb
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint

<div>
    <h2>1.1 - Data</h2>
    <p>In this study, we used the <a href="https://arxiv.org/pdf/2305.16713v3">MVTec AD</a> dataset
and <a href="https://arxiv.org/pdf/2305.16713v3">BTAD</a> dataset for our experiments</p>
</div>


In [11]:
#data
#downloaded from https://www.kaggle.com/uciml/pima-indians-diabetes-database
#wget https://www.mvtec.com/company/research/datasets/mvtec-ad/downloads/mvtec_anomaly_detection.tar.xz


<div>
    <h2>1.2 - Data preprocessing</h2>
    <p>We create a unique data module to feed to the lightning module for both the datasets. The preprocessing will follow the ones specified in the sections 4.3 and 4.4 of the aforementioned paper.</p>
</div>


In [12]:

class MVTecDataModule(LightningDataModule):
    def __init__(
        self,
        data_dir: str,
        batch_size: int = 32,
        num_workers: int = 4,
        train_val_split: float = 0.8,
        shuffle: bool = True,
        pin_memory: bool = True,
        image_size: int = 256,
        normalize: Optional[transforms.Normalize] = None,
        **kwargs,
    ):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.train_val_split = train_val_split
        self.shuffle = shuffle
        self.pin_memory = pin_memory
        self.image_size = image_size
        self.normalize = normalize

    def setup(self, stage=None):
        # Define transformations
        transform = transforms.Compose([
            transforms.Resize((self.image_size, self.image_size)),
            transforms.ToTensor(),
            transforms.CenterCrop(size=(224, 224)),
            self.normalize if self.normalize else transforms.Lambda(lambda x: x),
        ])

        # Load dataset
        tot_num=0
        for subclass in ["bottle","cable","capsule","carpet","grid","hazelnut","leather","metal_nut","pill","screw","tile","toothbrush","transistor","wood","zipper"]:
            dataset = ImageFolder(os.path.join(self.data_dir, subclass,"train"), transform=transform)
            # Split dataset into train and validation sets
            num_train = int(len(dataset) * self.train_val_split)
            num_val = len(dataset) - num_train
            tot_num+=len(dataset)
            train_dataset, val_dataset = torch.utils.data.random_split(
                dataset, [num_train, num_val])
            #add the subclass to the dataset
            if subclass == "bottle":
                self.train_dataset = train_dataset
                self.val_dataset = val_dataset
            else:
                self.train_dataset = torch.utils.data.ConcatDataset([self.train_dataset,train_dataset])
                self.val_dataset = torch.utils.data.ConcatDataset([self.val_dataset,val_dataset])

        #print the number of images in the dataset
        print("number of images in train dataset",len(self.train_dataset))
        print("number of images in val dataset",len(self.val_dataset))

        #create the dataloaders
        self.train_dataloader=self.get_train_dataloader()
        self.val_dataloader=self.get_val_dataloader()
        self.test_dataloader=self.get_test_dataloader()

    def get_train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
        )

    def get_val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
        )

    def get_test_dataloader(self):
        # Load test dataset without shuffling
        transform = transforms.Compose([
            transforms.Resize((self.image_size, self.image_size)),
            transforms.ToTensor(),
            self.normalize if self.normalize else transforms.Lambda(lambda x: x),
        ])
        for subclass in ["bottle","cable","capsule","carpet","grid","hazelnut","leather","metal_nut","pill","screw","tile","toothbrush","transistor","wood","zipper"]:
            test_dataset = ImageFolder(os.path.join(self.data_dir, subclass,"test"), transform=transform)
            if subclass == "bottle":
                self.test_dataset = test_dataset
            else:
                self.test_dataset = torch.utils.data.ConcatDataset([self.test_dataset,test_dataset])

        print("number of images in test dataset",len(self.test_dataset))
        return DataLoader(
            test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
        )

<p> setup the dataset </p>

In [13]:
data_dir=os.path.join('/kaggle/input/mvtec-ad')
datamod=MVTecDataModule(data_dir=data_dir,batch_size=8,num_workers=4,train_val_split=0.8,shuffle=True,pin_memory=True,image_size=256,normalize=None)
datamod.setup()

number of images in train dataset 2900
number of images in val dataset 729
number of images in test dataset 1725


<div>
    <h2>2 - Model</h2>
    <p>We start replicating the model as described in the paper</p>
</div>

In [20]:
def relaxedContrastiveLoss(w, delta, m): 
    return torch.mean(w * delta**2 + (1 - w) * torch.max(m - delta, torch.tensor(0.0))**2)


#We set \sigma = 1 based on the cited papers 18 and 19 in ReConPatch
class ContextualSimilarity(LightningModule):
    def __init__(self, k = 5, alpha = 0.5):
        super(ContextualSimilarity, self).__init__()
        self.k = k
        self.alpha = alpha
        
    def forward(self, z):
        distances = torch.cdist(z,z)  
        topk_dist = -torch.topk(-distances, self.k)[0][:,-1]
        filtering = (distances <= topk_dist.unsqueeze(-1)).float()
        similarity = torch.matmul(filtering, filtering.transpose(0, 1)) / torch.sum(filtering, dim=-1, keepdim=True)
        R = filtering * filtering.transpose(0, 1)
        similarity = torch.matmul(similarity, R.transpose(0, 1)) / torch.sum(R, dim=-1, keepdim=True)
        return self.alpha * (similarity + similarity.transpose(0, 1))
        
        
class PairwiseSimilarity(LightningModule):
    def __init__(self, sigma = 1.0):
        super(PairwiseSimilarity, self).__init__()
        self.sigma = sigma
    def forward(self,z):
        return torch.exp(-(torch.cdist(z,z)/self.sigma))


class ReConPatch(LightningModule):
    def __init__(
        self,
        input_dim,
        emb_dim = 512,
        proj_dim = 1024,
        alpha = 0.5,
        margin=0.1
    ):  
        super(ReConPatch, self).__init__() 
        self.margin=margin
        self.alpha=alpha
        self.wr_model = torch.hub.load('pytorch/vision:v0.13.0', 'wide_resnet50_2', pretrained=True)
        self.fmap = []
        def hook(module, input, output):
            self.fmap.append(output)
        self.wr_model.layer2.register_forward_hook(hook)            
        self.wr_model.layer3.register_forward_hook(hook)
        #setup network 1
        self.repr_layer = nn.Linear(input_dim, emb_dim)
        self.proj_layer = nn.Linear(emb_dim, proj_dim)
        
        #setup network 2
        self.repr_layer_2=nn.Linear(input_dim, emb_dim)
        self.proj_layer_2=nn.Linear(emb_dim, proj_dim)
        self.ema_repr = EMA(self.repr_layer_2)
        self.ema_proj = EMA(self.proj_layer_2) 
        #???
        with torch.no_grad():
            self.proj_layer.weight.copy_(torch.randn_like(self.proj_layer.weight))
            self.proj_layer.bias.copy_(torch.randn_like(self.proj_layer.bias))
            self.repr_layer.weight.copy_(torch.randn_like(self.repr_layer.weight))
            self.repr_layer.bias.copy_(torch.randn_like(self.repr_layer.bias))
        self.ema_repr.update()
        self.ema_proj.update()
        
        self.pairwise_sim=PairwiseSimilarity()
        self.contextual_sim=ContextualSimilarity()
    def forward(self, x):
        '''
        x has dimensions B x C x H x W (batch channels height width)
        '''
        #-----------FROM PRETRAINED MODEL TO FEATURE MAP
        y = self.wr_model(x)
        #NOTE: We are taking as dimensions the ones of the feature map with higher resolution as specified in
        #chapter 3.1 of the patchcore paper (https://arxiv.org/pdf/2106.08265)
        dimensions = (
                int(torch.Tensor([t.shape[-2] for t in self.fmap]).max().item()),
                int(torch.Tensor([t.shape[-1] for t in self.fmap]).max().item())
            )
        blur = nn.AvgPool2d(3, stride = 1)
        resizer = nn.AdaptiveAvgPool2d(dimensions)
        preprocess = lambda t : resizer(blur(t))
        feature_stacks = torch.cat([preprocess(m) for m in self.fmap], dim=1)
        
        #-----------RECONPATCH
        feature_stacks_reshaped = feature_stacks.reshape(feature_stacks.shape[1], -1).T

        #----------network1 pass
        h1=self.ema_repr(feature_stacks_reshaped)
        z1=self.ema_proj(h1)
        
        p_sim=self.pairwise_sim(z1)
        c_sim=self.contextual_sim(z1)
        
        w=self.alpha*p_sim+(1-self.alpha)*c_sim
        #----------network2 pass
        h2=self.repr_layer_2(feature_stacks_reshaped)
        z2=self.proj_layer_2(h2)
        
        
        pairwise_distances = torch.cdist(z2, z2, p=2)

        # Add a small epsilon to avoid division by zero when taking square root
        epsilon = 1e-9
        distances = torch.sqrt(pairwise_distances + epsilon)
        delta=distances/torch.mean(distances)
        
        return w ,delta
         
    def training_step(self, batch, batch_idx):
        x, y = batch
        w,delta = self(x)
        loss = relaxedContrastiveLoss(w, delta, self.margin)
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

    def validation_step(self, batch, batch_idx):
        x, y = batch
        w,delta = self(x)
        loss = relaxedContrastiveLoss(w, delta, self.margin)
        self.log('val_loss', loss)
        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        w,delta = self(x)
        loss = relaxedContrastiveLoss(w, delta, self.margin)
        self.log('test_loss', loss)
        return loss

In [21]:
#visualize a sample and a target
sample,target= next(iter(datamod.train_dataloader))
print(sample.shape)
print(target.shape)

torch.Size([8, 3, 256, 256])
torch.Size([8])


In [24]:
#testing the forward method
sample,target= next(iter(datamod.train_dataloader))
print(sample.shape)
#1536 is the shape of the final features after the aggregation of the features
#of the different layers
model=ReConPatch(1536)
w,delta=model(sample)

torch.Size([8, 3, 256, 256])


Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.13.0


dimesions of feature reshaped torch.Size([8192, 1536])


<div>
    <h2>3 - Training and Evaluation</h2>
</div>

In [25]:
#create a wandb training loop for the model


wandb.login()
wandb.init(project='ReConPatch')
wandb_logger = WandbLogger(project='ReConPatch')
checkpoint_callback = ModelCheckpoint(monitor='val_loss', save_top_k=1, mode='min')
model = ReConPatch(1536)
trainer = pl.Trainer(max_epochs=10, logger=wandb_logger, callbacks=[checkpoint_callback])
trainer.fit(model, datamod.train_dataloader)

#close the wandb session
wandb.finish()

[34m[1mwandb[0m: Currently logged in as: [33mmartinadoku[0m ([33mpoi-dl-airo[0m). Use [1m`wandb login --relogin`[0m to force relogin


Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.13.0
INFO: Trainer will use only 1 of 2 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=2)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
INFO: GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
/opt/conda/lib/python3.10/site-packages/lightning/pytorch/trainer/configuration_validator.py:74: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
/opt/conda/lib/python3.10/site-packages/lightning/pytorch/loggers/wandb.py:396: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
INFO: LOCAL_RANK: 0 - CU

Training: |          | 0/? [00:00<?, ?it/s]

dimesions of feature reshaped torch.Size([8192, 1536])
dimesions of feature reshaped torch.Size([8192, 3072])


RuntimeError: mat1 and mat2 shapes cannot be multiplied (8192x3072 and 1536x512)

Versione esterna dal modello se si vuole implementare pezzetto pezzetto senza far partire tutto il training insieme

In [None]:
   
model = torch.hub.load('pytorch/vision:v0.10.0', 'wide_resnet50_2', pretrained=True)
f_maps = []
def hook(module, input, output)-> None:
    f_maps.append(output)
model.layer2.register_forward_hook(hook)    
model.layer3.register_forward_hook(hook) 
#model.fc = nn.Linear(model.fc.in_features, 512)
model.eval()

In [None]:
sample, target = next(iter(datamod.train_dataloader()))
with torch.no_grad():
    y = model(sample)
    

In [None]:
dimensions = (
    int(torch.Tensor([t.shape[-2] for t in f_maps]).mean().item()),
    int(torch.Tensor([t.shape[-1] for t in f_maps]).mean().item())
)

In [None]:
blur = nn.AvgPool2d(3, stride = 1)
resizer = nn.AdaptiveAvgPool2d(dimensions)
preprocess = lambda t : resizer(blur(t))

In [None]:
feature_stacks = torch.cat([preprocess(m) for m in f_maps], dim=1)
feature_stacks_reshaped = feature_stacks.reshape(feature_stacks.shape[1], -1).T