### *Run these cells only when in Google Colab*

In [None]:
# Clone the repository
!git clone https://github.com/beerger/mad_seminar_ws23.git
# Move all (non-hidden) content to the current directory so the project
# packages imported below (model/, data_loader/) are importable from here
!mv ./mad_seminar_ws23/* ./
# Remove what is left of the clone; the `*` glob above does not match
# dotfiles, so the directory still contains e.g. .git at this point
!rm -rf mad_seminar_ws23/

In [None]:
# Install additional packages
!pip install pytorch_lightning --quiet
!pip install lpips

## Imports for inference

In [None]:
import pytorch_lightning as pl
import yaml
import torch
import matplotlib.pyplot as plt
import json
from google.colab import drive

from model.local_net import LocalNet
from model.dad import DADHead
from model.iad import iad_head
from model.global_net import GlobalNet
from data_loader.mvtec_inference_data_loader import MVTecInferenceDataModule
from model.anomaly_detector import AnomalyDetector

# Autoreload: modules that changed on disk are reloaded before each cell runs,
# so edits to the project modules take effect without restarting the kernel
%load_ext autoreload
%autoreload 2

## Load and visualize data

Mount Google Drive in the current Colab session (the image datasets and the trained model weights are stored there)

In [None]:
# Mounting prompts you with an authentication link; afterwards the Drive
# contents are reachable below the mount point
drive_mount_point = '/content/drive'
drive.mount(drive_mount_point)

Copy the compressed archive of the zipper dataset from Google Drive to the current Colab session and extract it

In [None]:
!cp "/content/drive/MyDrive/AnomalyDetection/Datasets/MVTec/zipper.tar.xz" "/content/"
# Unzip it
!tar -xf /content/zipper.tar.xz -C /content/
# Remove zip file
!rm -rf zipper.tar.xz

Get paths to test images

In [None]:
import os

# Directory that holds the test split of the zipper dataset.
# (The variable keeps its original name even though it points at test data.)
train_directory = '/content/zipper/test/'
# Number of files the assertion below expects to find under that directory
EXPECTED_NUM_TEST_FILES = 151

# Collect every file below the directory. os.walk yields entries in
# filesystem order, so the result is sorted to give downstream cells a
# stable, reproducible file order.
test_file_paths = sorted(
    os.path.join(root, name)
    for root, _dirs, files in os.walk(train_directory, topdown=False)
    for name in files
)

assert len(test_file_paths) == EXPECTED_NUM_TEST_FILES, (
    f"expected {EXPECTED_NUM_TEST_FILES} files under {train_directory}, "
    f"found {len(test_file_paths)}"
)
# Print a short summary instead of dumping every path into the output
print(f"{len(test_file_paths)} test files, e.g. {test_file_paths[:3]}")


Create dataloader

In [None]:
# Options for the inference data module: one image per batch, loading in the
# main process (num_workers=0), caching strategy 'none'
data_module_options = dict(
    batch_size=1,
    num_workers=0,
    caching_strategy='none',
)
data_module = MVTecInferenceDataModule(test_file_paths, **data_module_options)

Display inference images

In [None]:
# Maximum number of images to show from the first batch
BATCH_SIZE=4

images = next(iter(data_module.test_dataloader()))

# Show only as many images as the batch actually holds, so a data module
# created with a smaller batch size does not raise an IndexError here
num_shown = min(BATCH_SIZE, len(images))

fig, ax = plt.subplots(1, num_shown, figsize=(10, 10), squeeze=False)
ax = ax.ravel()  # squeeze=False always returns a 2-D array; flatten it to 1-D

for i in range(num_shown):
  # imshow expects channels last, hence the permute to (H, W, C)
  ax[i].imshow(images[i].permute(1, 2, 0).cpu().numpy())

plt.show()


Check that AnomalyDetector creates correct patches and binary masks

In [None]:
# No networks are passed here (all three arguments are None): this detector is
# only used to call create_patches_and_masks on the first image of the batch.
# NOTE(review): presumably that method does not touch the networks - confirm.
anomaly_detector = AnomalyDetector(None, None, None)
patches, binary_masks = anomaly_detector.create_patches_and_masks(images[0])

# Reverse the normalization process done by data module
# to avoid the following error:
# WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
def denormalize(tensor):
    """Undo per-channel mean/std normalization so an image can be displayed.

    Args:
        tensor: Normalized 3-channel image tensor that broadcasts against
            shape (1, 3, 1, 1), e.g. (3, H, W) or (N, 3, H, W).

    Returns:
        A new tensor (the input is left untouched) with values clamped to
        [0, 1]. Because of broadcasting, a (3, H, W) input is returned with
        shape (1, 3, H, W).
    """
    # Create the statistics on the input's device so CUDA tensors work too
    mean = torch.tensor([0.485, 0.456, 0.406], device=tensor.device).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=tensor.device).view(1, 3, 1, 1)
    restored = tensor * std + mean  # out-of-place: allocates a new tensor
    return restored.clamp(0, 1)  # keep values inside imshow's valid float range

# The loops below lay patches and masks out row by row on a square grid.
# NOTE(review): assumes create_patches_and_masks yields GRID_SIZE**2 items - confirm.
GRID_SIZE = 20

# Denormalize into a new list so `patches` itself is left unchanged
patches_denorm = [denormalize(patch) for patch in patches]

# Plot the full image, denormalized like the patches so that imshow does not
# clip values outside [0, 1]
image = images[0].squeeze(0)
plt.imshow(denormalize(image).squeeze(0).permute(1, 2, 0).cpu().numpy())
plt.show()

# Plot all (overlapping) patches; ax.flat walks the grid in row-major order
fig1, ax1 = plt.subplots(GRID_SIZE, GRID_SIZE, figsize=(20, 20))
for patch_ax, patch in zip(ax1.flat, patches_denorm):
  patch_ax.imshow(patch.squeeze(0).permute(1, 2, 0).cpu().numpy())
  patch_ax.axis('off')

# Plot all (overlapping) binary masks. The axes are deliberately left on here
# (the original kept axis('off') disabled for this figure).
fig2, ax2 = plt.subplots(GRID_SIZE, GRID_SIZE, figsize=(20, 20))
for mask_ax, mask in zip(ax2.flat, binary_masks):
  mask_ax.imshow(mask.squeeze(0).cpu().numpy(), cmap='gray', interpolation='none')

plt.show()


## Create the Anomaly Detector

Load all trained models

In [None]:
# Use the first GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Locations of the trained checkpoints on the mounted Google Drive
DRIVE_ROOT = '/content/drive/MyDrive/AnomalyDetection'
LOCAL_NET_CKPT = f'{DRIVE_ROOT}/LocalNet/Fine-tuning/Trained Models/local_net_finetuned.pth'
GLOBAL_NET_CKPT = f'{DRIVE_ROOT}/GlobalNet_DAD/Trained Models/V1/global_net.pth'
DAD_HEAD_CKPT = f'{DRIVE_ROOT}/GlobalNet_DAD/Trained Models/V1/dad_head.pth'

# Load all state dictionaries from Google Drive, mapped onto `device`.
# NOTE(review): torch.load unpickles the file, so only load checkpoints you
# trust (consider weights_only=True if the installed torch version supports it).
local_state_dict = torch.load(LOCAL_NET_CKPT, map_location=device)
global_state_dict = torch.load(GLOBAL_NET_CKPT, map_location=device)
dad_state_dict = torch.load(DAD_HEAD_CKPT, map_location=device)

# Initialise all networks on the chosen device
local_net = LocalNet().to(device)
global_net = GlobalNet().to(device)
dad_head = DADHead().to(device)

# Restore the trained weights
local_net.load_state_dict(local_state_dict)
global_net.load_state_dict(global_state_dict)
dad_head.load_state_dict(dad_state_dict)

# This notebook only runs inference: switch the networks to evaluation mode so
# that any dropout / batch-norm layers they contain behave deterministically
local_net.eval()
global_net.eval()
dad_head.eval()

anomaly_detector = AnomalyDetector(local_net, global_net, dad_head).to(device)



In [None]:
# Run detection on the first image of the batch; unsqueeze(0) adds a leading
# batch dimension of size 1.
# NOTE(review): the image comes straight from the dataloader - confirm that
# detect_anomalies moves it to the detector's device when running on GPU.
first_image_batch = images[0].unsqueeze(0)
anomaly_detector.detect_anomalies(first_image_batch)