# Pixel-Level Localization on MVTec AD

### *Run these cells only when in Google Colab*

In [None]:

# Clone the repository
!git clone https://github.com/beerger/mad_seminar_ws23.git
# Move all content to the current directory
!mv ./mad_seminar_ws23/* ./
# Remove the empty directory
!rm -rf mad_seminar_ws23/

In [None]:
# Install additional packages
!pip install pytorch_lightning --quiet
!pip install lpips

## Imports for Local-Net

In [None]:
import pytorch_lightning as pl
import yaml

from model.local_net import LocalNet
from model.model_utils import load_resnet_18_teacher_model
from model.student_training_module import StudentTrainingModule
from image_net_data_loader import ImageNetDataModule

# autoreload imported modules
%load_ext autoreload
%autoreload 2

## 1. Pre-training

Until the next numbered step, the following code blocks train Local-Net. This is referred to as the pre-training of the framework, since Local-Net's parameters stay fixed while Global-Net and the DAD-head are trained. Pre-training consists of two major steps:

* **Distillation**: on ImageNet, where the teacher network is a pretrained ResNet-18.
* **Fine-tuning**: on a single category of MVTec AD
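
The distillation step can be sketched as follows. This is a minimal illustration, not the code in *StudentTrainingModule*: the tiny `teacher` and `student` modules are hypothetical stand-ins for the pretrained ResNet-18 and Local-Net, and the plain MSE feature-matching loss is an assumption (the loss actually used in the repository may contain additional terms). The only facts taken from this notebook are that the teacher is frozen and that the student sees 33x33 patches while the teacher sees the same patches resized to 224x224.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-ins so the sketch runs quickly and offline.
# Both map an image batch to a 512-dimensional feature vector.
teacher = nn.Sequential(nn.AdaptiveAvgPool2d(4), nn.Flatten(), nn.Linear(48, 512))
student = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 512),
)

# Freeze the teacher: only the student is updated during distillation.
teacher.eval()
for p in teacher.parameters():
    p.requires_grad_(False)

optimizer = torch.optim.Adam(student.parameters(), lr=1e-3)


def distillation_step(patches_33, patches_224):
    """One optimization step matching student features to teacher features."""
    with torch.no_grad():
        target = teacher(patches_224)
    pred = student(patches_33)
    loss = F.mse_loss(pred, target)  # assumed loss, see the note above
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


# Random data in place of ImageNet patches.
patches_33 = torch.randn(4, 3, 33, 33)
patches_224 = F.interpolate(
    patches_33, size=(224, 224), mode="bilinear", align_corners=False
)
losses = [distillation_step(patches_33, patches_224) for _ in range(20)]
print(f"first loss: {losses[0]:.4f}, last loss: {losses[-1]:.4f}")
```

On this fixed toy batch the loss should decrease over the 20 steps, which is a quick sanity check that gradients reach the student and not the teacher.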

Pre-processing follows (*) the ResNet-18 documentation at
https://pytorch.org/vision/main/models/generated/torchvision.models.resnet18.html

(*) According to the documentation, the input is first resized to 256x256 and then
center cropped to 224x224. That step is skipped here,
and the input is instead resized directly to 224x224.
The reason is that the input to Local-Net is a 33x33 patch,
while ResNet-18 expects a 224x224 input.
Resizing a patch to 256x256 and then center cropping it to 224x224
would cut away its border and change its effective scale, altering the spatial relationships of its features.
This could lead to a mismatch when comparing the features ResNet-18 extracts from the resized and
cropped patch with those Local-Net extracts from the original 33x33 patch.

Pre-processing the data for distillation consists of 4 major steps:

1. Load images from ImageNet
2. Resize images to 256 x 256 (common practice, and mentioned in the Supplementary Material)
3. Extract 33 x 33 patches from each resized image
4. Create two separate transform pipelines
    * For Local-Net: Convert the 33x33 patches to PyTorch tensors and normalize them
    * For ResNet-18 (teacher model): Resize the 33x33 patches to 224x224, then convert to tensors and normalize
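
The four steps above can be sketched as follows. This is an illustrative approximation, not the code of *ImageNetDataModule*: the helper names (`random_patch`, `normalize`, `make_pair`) are hypothetical, cropping at a random location is an assumption, and the sketch works on tensors throughout instead of resizing PIL images before conversion. The mean and standard deviation are the ImageNet statistics listed in the ResNet-18 documentation.

```python
import torch
import torch.nn.functional as F

# ImageNet channel statistics from the torchvision ResNet-18 documentation.
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)


def resize(image, size):
    """Bilinearly resize a (C, H, W) image to (C, size, size)."""
    return F.interpolate(
        image.unsqueeze(0), size=(size, size), mode="bilinear", align_corners=False
    ).squeeze(0)


def random_patch(image, patch_size=33):
    """Crop one patch at a random location (random placement is an assumption)."""
    _, h, w = image.shape
    top = torch.randint(0, h - patch_size + 1, (1,)).item()
    left = torch.randint(0, w - patch_size + 1, (1,)).item()
    return image[:, top:top + patch_size, left:left + patch_size]


def normalize(x):
    """Normalize a (3, H, W) tensor in [0, 1] with the ImageNet statistics."""
    return (x - IMAGENET_MEAN) / IMAGENET_STD


def make_pair(image):
    """Return (Local-Net input, teacher input) for one (3, H, W) image in [0, 1]."""
    resized = resize(image, 256)                    # step 2
    patch = random_patch(resized)                   # step 3
    student_input = normalize(patch)                # step 4, Local-Net pipeline
    teacher_input = normalize(resize(patch, 224))   # step 4, teacher pipeline
    return student_input, teacher_input


# A random tensor stands in for an ImageNet image (step 1).
image = torch.rand(3, 300, 400)
student_input, teacher_input = make_pair(image)
print(student_input.shape, teacher_input.shape)
```

Both outputs come from the same 33x33 patch, so the teacher and Local-Net see the same image content at different resolutions.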

<span style="color:yellow"> **Note**: All pre-processing is done by *ImageNetDataModule* provided in *image_net_data_loader.py* </span>

## Load the config

In [None]:
with open('./configs/local_net_distillation.yaml', 'r') as f:
    config = yaml.safe_load(f)

# Reproducibility
pl.seed_everything(config['seed'])

## Load and visualize data

Hugging Face packages and login

In [None]:
# Packages required to use ImageNet from Hugging Face
!pip install -U "huggingface_hub[cli]"
!pip install datasets
# To log in, `huggingface_hub` requires a token generated at https://huggingface.co/settings/tokens
!huggingface-cli login

Mount Google Drive in Colab

In [None]:
from google.colab import drive
# Will provide you with an authentication link
drive.mount('/content/drive')

Set Cache Directory

In [None]:
import os
from datasets import load_dataset

# Set a path on your Google Drive for the cache
cache_dir = '/content/drive/MyDrive/ImageNet'

# Make sure the cache directory exists
os.makedirs(cache_dir, exist_ok=True)

# Load the dataset with specified cache directory
dataset = load_dataset('imagenet-1k', cache_dir=cache_dir)


Visualize training data from ImageNet

In [None]:
import matplotlib.pyplot as plt

image_net_train_data_module = ImageNetDataModule(dataset, config['batch_size'])

batch = next(iter(image_net_train_data_module.train_dataloader()))

# Print statistics
print(f"Batch shape: {batch.shape}")
print(f"Batch min: {batch.min()}")
print(f"Batch max: {batch.max()}")

fig, ax = plt.subplots(1, 5, figsize=(15, 5))
for i in range(5):
    ax[i].imshow(batch[i].squeeze(), cmap='gray')
    ax[i].axis('off')
plt.show()
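
If the printed minimum is negative, the patches are normalized and may look washed out or clipped when plotted directly. The helper below is a minimal sketch for undoing the normalization before display. It assumes the batch is a single `(N, 3, H, W)` tensor normalized with the ImageNet statistics; what *ImageNetDataModule* actually returns is not shown here, so `denormalize` is a hypothetical helper and not part of the repository.

```python
import torch

# ImageNet channel statistics, shaped for broadcasting over (N, 3, H, W).
MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)


def denormalize(batch):
    """Map a normalized (N, 3, H, W) batch back to [0, 1], shaped (N, H, W, 3)."""
    images = batch * STD + MEAN
    # Channels-last layout is what matplotlib's imshow expects for RGB images.
    return images.clamp(0, 1).permute(0, 2, 3, 1)


# Round trip on random data standing in for a batch of patches.
original = torch.rand(5, 3, 33, 33)
normalized = (original - MEAN) / STD
restored = denormalize(normalized)
print(restored.shape)
```

Each `restored[i]` can then be passed to `ax[i].imshow(...)` without a `cmap` argument.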


## Set up TensorBoard

In [None]:
%load_ext tensorboard
%tensorboard --logdir lightning_logs/

## Load the Distillation config