# Local-Net

### *Run these cells only when in Google Colab*

In [None]:

# Clone the repository (the model/ and data_loader/ imports and ./configs/*.yaml used
# below are read from the cloned files)
!git clone https://github.com/beerger/mad_seminar_ws23.git
# Move all content to the current directory (the * glob does not match hidden
# entries such as .git, so those stay behind)
!mv ./mad_seminar_ws23/* ./
# Remove the leftover clone directory; it is NOT empty at this point, it still holds
# the hidden entries (e.g. .git) that the glob above did not move
!rm -rf mad_seminar_ws23/

In [None]:
# Install additional packages
!pip install pytorch_lightning --quiet
!pip install lpips

## Imports for Local-Net

In [None]:
import json
import math
import os

import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch
import yaml
from google.colab import drive

from data_loader.localnet_data_loader import LocalNetDataModule
from model.local_net import LocalNet
from model.model_utils import load_resnet_18_teacher_model
from model.one_layer_decoder import OneLayerDecoder
from model.student_training_module import StudentTrainingModule

# Automatically reload imported modules (e.g. the project's model/ and data_loader/
# packages) before executing each cell, so edits to those .py files are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Pre-training

The following code blocks train Local-Net. This is referred to as the pre-training of the framework, since Local-Net's parameters are kept fixed while Global-Net and the DAD head are trained. Pre-training consists of two major steps:

* **Distillation**: on ImageNet, where the teacher network is pretrained ResNet-18.
* **Fine-tuning**: on a specific category of MVTec AD.

## 1. Distillation

## Load the config

In [None]:
# Hyper-parameters for the distillation stage come from the repository's YAML config.
with open('./configs/local_net_distillation.yaml', 'r') as config_file:
    config = yaml.safe_load(config_file)

# Seed the random number generators through Lightning so runs are reproducible.
seed = config['seed']
pl.seed_everything(seed)

## Load and visualize data

Mount Google Drive in the current Colab session (the training/validation images are stored there).

In [None]:
# Attach Google Drive at /content/drive. The call prints an authorization link that
# has to be followed before the cell can finish.
drive.mount(mountpoint='/content/drive')

Copy data from Google Drive to Colab VM's local storage

In [None]:
%%bash
# Abort on the first failing command: otherwise a failed move would be followed by
# the `rm -rf` below, which would silently delete the freshly extracted data.
set -e

# Copy training and validation zips from Google Drive to the VM's local storage
# Takes ~ 7 - 15 min
cp "/content/drive/MyDrive/AnomalyDetection/Datasets/ImageNet/train.zip" "/content/"
cp "/content/drive/MyDrive/AnomalyDetection/Datasets/ImageNet/val.zip" "/content/"

# Unzip (-q keeps the per-file listing out of the notebook output)
# Takes ~ 7 min
unzip -q "/content/train.zip" -d "/content/train"
unzip -q "/content/val.zip" -d "/content/val"

# Delete the zip files to free up space
rm "/content/train.zip"
rm "/content/val.zip"

# The archives contain a content/<split>/ prefix, so the files land in
# /content/<split>/content/<split>/; move them up into /content/<split>/.
# A glob inside quotes ("dir/*") is never expanded by the shell, so mv would fail;
# find -exec is used instead, which also avoids "Argument list too long" for
# folders with very many files.
find /content/train/content/train -mindepth 1 -maxdepth 1 -exec mv -t /content/train/ {} +
rm -rf "/content/train/content"
find /content/val/content/val -mindepth 1 -maxdepth 1 -exec mv -t /content/val/ {} +
rm -rf "/content/val/content"

Create data module for ImageNet

In [None]:
def load_image_paths(file_path):
    """Read a JSON file and return its decoded contents.

    Args:
        file_path: Path to a JSON file (here: a list of local image paths).

    Returns:
        Whatever the JSON document decodes to.
    """
    with open(file_path, 'r') as handle:
        raw_text = handle.read()
    return json.loads(raw_text)

# Pre-computed lists of local image paths, stored as JSON files on Drive.
train_image_paths, val_image_paths = (
    load_image_paths(index_file)
    for index_file in (
        '/content/drive/MyDrive/AnomalyDetection/Datasets/ImageNet/train_image_local_paths.json',
        '/content/drive/MyDrive/AnomalyDetection/Datasets/ImageNet/val_image_local_paths.json',
    )
)

# caching_strategy='none' — presumably no in-memory caching for this large dataset;
# confirm against LocalNetDataModule.
data_module = LocalNetDataModule(
    train_image_paths,
    val_image_paths,
    batch_size=config['batch_size'],
    num_workers=4,
    caching_strategy='none',
)

for split_name, split_paths in (("training", train_image_paths), ("validation", val_image_paths)):
    print(f"Number of {split_name} images: {len(split_paths)}")

Plot patches for Local-Net and ResNet-18

In [None]:
# Number of preview columns; expected to equal the batch size `data_module` uses.
BATCH_SIZE = config['batch_size']

# Reverse the normalization process done by LocalNetDataModule so the patches can be
# displayed; otherwise imshow emits the following warning:
# WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
def denormalize(tensor):
    """Map a mean/std-normalized image batch back to the [0, 1] display range.

    Args:
        tensor: Image batch, expected shape (N, 3, H, W); the per-channel statistics
            are broadcast as (1, 3, 1, 1), so a single (3, H, W) image also works but
            comes back with a leading batch dimension of 1.

    Returns:
        A new tensor ``tensor * std + mean`` clamped to [0, 1]. The input is not
        modified.
    """
    # Build the statistics on the input's device so GPU tensors can be passed too.
    mean = torch.tensor([0.485, 0.456, 0.406], device=tensor.device).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=tensor.device).view(1, 3, 1, 1)
    # Out-of-place arithmetic: allocates a new tensor, `tensor` itself is untouched.
    restored = tensor * std + mean
    # Ensures the pixel values are within [0, 1]
    return restored.clamp(0, 1)


# Retrieve one batch: the loader yields (Local-Net patch, ResNet-18 patch) pairs.
patch_local, patch_resnet = next(iter(data_module.train_dataloader()))

# Denormalize the patches for visualization
patch_local = denormalize(patch_local)
patch_resnet = denormalize(patch_resnet)

# The first batch can hold fewer than BATCH_SIZE samples (dataset smaller than one
# batch), so never index past what was actually returned.
n_cols = min(BATCH_SIZE, patch_local.shape[0], patch_resnet.shape[0])

# squeeze=False keeps `ax` 2-D even when n_cols == 1, so ax[row, col] always works.
fig, ax = plt.subplots(2, n_cols, figsize=(20, 8), squeeze=False)

# Row 0: Local-Net patches, row 1: ResNet-18 patches.
for row, patches in enumerate((patch_local, patch_resnet)):
    for col in range(n_cols):
        # (C, H, W) -> (H, W, C), the channel-last layout imshow expects
        ax[row, col].imshow(patches[col].permute(1, 2, 0).cpu().numpy())
        ax[row, col].axis('off')

plt.show()

## Set up tensorboard

In [None]:
%load_ext tensorboard
%tensorboard --logdir lightning_logs/

## Set up all models for distillation

In [None]:
# Use the first GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Build order is kept as-is: each constructor may draw from the seeded RNG.
# Student network, configured by the distillation config.
local_net = LocalNet(config).to(device)
# Teacher: pre-trained ResNet-18 loaded from 'resnet18-5c106cde.pth' in the working directory.
resnet_18 = load_resnet_18_teacher_model('resnet18-5c106cde.pth', device)
# Decoder between the two feature sizes (presumably maps Local-Net features to the
# teacher's dimensionality — see OneLayerDecoder).
decoder = OneLayerDecoder(
    config['local_net_output_dimensions'],
    config['resnet_output_dimensions'],
).to(device)

student_train_module = StudentTrainingModule(
    config,
    decoder=decoder,
    teacher_model=resnet_18,
    student_model=local_net,
    mode='distillation',
)

## Calculate number of epochs

In [None]:
# Given by paper is batch size of 64 for 50k iterations. The Trainer is configured in
# epochs, so convert the iteration budget into a number of epochs.
total_iterations = config['iterations']
batch_size = config['batch_size']
num_training_images = len(train_image_paths)

try:
    # Exact number of optimizer steps per epoch, honouring the loader's batch_size
    # and drop_last (a partial last batch is still one step).
    steps_per_epoch = len(data_module.train_dataloader())
except TypeError:
    # Iterable-style datasets have no length; assume PyTorch's default drop_last=False.
    steps_per_epoch = math.ceil(num_training_images / batch_size)

if steps_per_epoch == 0:
    raise ValueError(
        f"No training batches: {num_training_images} images with batch_size={batch_size}"
    )

# Round up so that at least `total_iterations` optimizer steps are run.
max_epochs = math.ceil(total_iterations / steps_per_epoch)
print(f"Calculated max_epochs: {max_epochs}")

## Create callbacks for training

In [None]:
from pytorch_lightning.callbacks import ModelCheckpoint

# Checkpoints go to Drive so they survive the end of the Colab session.
checkpoint_callback = ModelCheckpoint(
    dirpath="/content/drive/MyDrive/AnomalyDetection/LocalNet/Distillation/Checkpoints",  # Path where checkpoints will be saved
    filename="{epoch}-{val_loss:.2f}",  # e.g. epoch=3-val_loss=0.12.ckpt
    monitor="val_loss",  # Metric used to rank checkpoints
    every_n_epochs=1,  # Consider saving after every epoch ...
    save_top_k=1,  # ... but keep only the best one by val_loss (Lightning's default)
    save_last=True,  # Also keep last.ckpt, the natural file to resume training from
    # Full checkpoints (model + optimizer/scheduler + loop state). The "Setup trainer
    # from checkpoint" cell resumes from these, and Lightning cannot resume from a
    # weights-only checkpoint.
    save_weights_only=False,
)

## Setup new trainer

In [None]:

# Train from scratch: checkpoints via checkpoint_callback, TensorBoard logs below ./
tensorboard_logger = pl.loggers.TensorBoardLogger(save_dir='./')
trainer = pl.Trainer(
    max_epochs=max_epochs,
    accelerator='gpu' if torch.cuda.is_available() else 'cpu',
    devices=1,
    callbacks=[checkpoint_callback],
    logger=[tensorboard_logger],
)

## Setup trainer from checkpoint


In [None]:
# Resume an interrupted distillation run.
# Change <CHECK_POINT.ckpt> to a valid checkpoint file located in
# /content/drive/MyDrive/AnomalyDetection/LocalNet/Distillation/Checkpoints/
# It has to be a full checkpoint (saved with save_weights_only=False): Lightning cannot
# restore optimizer/loop state from a weights-only file.

trainer = pl.Trainer(
    max_epochs=max_epochs,
    accelerator='gpu' if torch.cuda.is_available() else 'cpu',
    devices=1,
    callbacks=[checkpoint_callback],
    logger=[
        pl.loggers.TensorBoardLogger(save_dir='./')
    ]
)
# Trainer(resume_from_checkpoint=...) was removed in Lightning 2.0 (the unpinned
# `pip install pytorch_lightning` installs the latest release). Setting
# trainer.ckpt_path makes the next trainer.fit(...) restore from this file.
trainer.ckpt_path = "/content/drive/MyDrive/AnomalyDetection/LocalNet/Distillation/Checkpoints/<CHECK_POINT.ckpt>"

## Run distillation

In [None]:
# Run distillation (resumes from trainer.ckpt_path if one was set on the trainer)
trainer.fit(model=student_train_module, datamodule=data_module)

## Save distilled model

In [None]:
# Persist the distilled weights to Drive so the fine-tuning stage can start from them.
# The decoder is saved too, because the fine-tuning setup loads a distilled decoder
# state dict next to the Local-Net one.
distilled_dir = '/content/drive/MyDrive/AnomalyDetection/LocalNet/Distillation/Trained Models'
os.makedirs(distilled_dir, exist_ok=True)  # torch.save does not create missing folders
torch.save(local_net.state_dict(), os.path.join(distilled_dir, 'local_net_distilled.pth'))
torch.save(decoder.state_dict(), os.path.join(distilled_dir, 'decoder_distilled.pth'))

## 2. Fine-tuning

Local-Net has now been distilled from the pre-trained ResNet-18 on ImageNet; in this part it is fine-tuned on a specific MVTec AD category, using the same loss as in distillation. The ***zipper*** category has been chosen because it looks interesting, and because it also scored 0.99 pixel-level AUROC and 0.992 per-region overlap (PRO).

## Load the config

In [None]:
# Hyper-parameters for the fine-tuning stage come from the repository's YAML config.
with open('./configs/local_net_fine_tune.yaml', 'r') as config_file:
    config = yaml.safe_load(config_file)

# Seed the random number generators through Lightning so runs are reproducible.
seed = config['seed']
pl.seed_everything(seed)

## Load and visualize data

In [None]:
# Attach Google Drive at /content/drive. The call prints an authorization link that
# has to be followed before the cell can finish.
drive.mount(mountpoint='/content/drive')

Copy the compressed zipper dataset archive from Google Drive to the current Colab session and extract it.

In [None]:
!cp "/content/drive/MyDrive/AnomalyDetection/Datasets/MVTec/zipper.tar.xz" "/content/"
# Unzip it
!tar -xf /content/zipper.tar.xz -C /content/
# Remove zip file
!rm -rf zipper.tar.xz

Get paths to training images

In [None]:
# Directory holding the 'zipper' training images (the train/good folder)
train_directory = '/content/zipper/train/good/'

# os.listdir returns entries in an arbitrary, filesystem-dependent order. Sort them so
# the seeded train/validation split below is reproducible across sessions.
file_list = sorted(
    os.path.join(train_directory, name)
    for name in os.listdir(train_directory)
    if os.path.isfile(os.path.join(train_directory, name))
)

assert len(file_list) == 240, (
    f"expected 240 training images in {train_directory}, found {len(file_list)}"
)

Split into train/validation

In [None]:
from sklearn.model_selection import train_test_split

# Hold out 10% of the image paths in `file_list` for validation (adjust test_size as
# needed). Seeding with config['seed'] makes the split repeatable for a given file order.
train_image_paths, val_image_paths = train_test_split(
    file_list,
    test_size=0.1,
    random_state=config['seed'],
)

Create data module for MVTec AD

In [None]:
# caching_strategy='at-init' — presumably loads the small MVTec image set up front;
# confirm against LocalNetDataModule.
data_module = LocalNetDataModule(
    train_image_paths,
    val_image_paths,
    batch_size=config['batch_size'],
    num_workers=2,
    caching_strategy='at-init',
)

for split_name, split_paths in (("training", train_image_paths), ("validation", val_image_paths)):
    print(f"Number of {split_name} images: {len(split_paths)}")

Plot patches for Local-Net and ResNet-18

In [None]:
# Number of preview columns; expected to equal the batch size `data_module` uses.
BATCH_SIZE = config['batch_size']

# Reverse the normalization process done by LocalNetDataModule so the patches can be
# displayed; otherwise imshow emits the following warning:
# WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
def denormalize(tensor):
    """Map a mean/std-normalized image batch back to the [0, 1] display range.

    Args:
        tensor: Image batch, expected shape (N, 3, H, W); the per-channel statistics
            are broadcast as (1, 3, 1, 1), so a single (3, H, W) image also works but
            comes back with a leading batch dimension of 1.

    Returns:
        A new tensor ``tensor * std + mean`` clamped to [0, 1]. The input is not
        modified.
    """
    # Build the statistics on the input's device so GPU tensors can be passed too.
    mean = torch.tensor([0.485, 0.456, 0.406], device=tensor.device).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=tensor.device).view(1, 3, 1, 1)
    # Out-of-place arithmetic: allocates a new tensor, `tensor` itself is untouched.
    restored = tensor * std + mean
    # Ensures the pixel values are within [0, 1]
    return restored.clamp(0, 1)


# Retrieve one batch: the loader yields (Local-Net patch, ResNet-18 patch) pairs.
patch_local, patch_resnet = next(iter(data_module.train_dataloader()))

# Denormalize the patches for visualization
patch_local = denormalize(patch_local)
patch_resnet = denormalize(patch_resnet)

# The first batch can hold fewer than BATCH_SIZE samples (dataset smaller than one
# batch), so never index past what was actually returned.
n_cols = min(BATCH_SIZE, patch_local.shape[0], patch_resnet.shape[0])

# squeeze=False keeps `ax` 2-D even when n_cols == 1, so ax[row, col] always works.
fig, ax = plt.subplots(2, n_cols, figsize=(20, 8), squeeze=False)

# Row 0: Local-Net patches, row 1: ResNet-18 patches.
for row, patches in enumerate((patch_local, patch_resnet)):
    for col in range(n_cols):
        # (C, H, W) -> (H, W, C), the channel-last layout imshow expects
        ax[row, col].imshow(patches[col].permute(1, 2, 0).cpu().numpy())
        ax[row, col].axis('off')

plt.show()

## Set up tensorboard

In [None]:
%load_ext tensorboard
%tensorboard --logdir lightning_logs/

## Set up all models for fine-tuning

In [None]:
# Use the first GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Distilled weights produced by the distillation stage (versioned folder on Drive)
distilled_dir = '/content/drive/MyDrive/AnomalyDetection/LocalNet/Distillation/Trained Models/V2'

# Student: Local-Net initialised from the distilled weights.
# NOTE(review): the distillation cell builds LocalNet(config) while this one uses
# LocalNet() — confirm the defaults produce the architecture the weights came from.
local_net = LocalNet().to(device)
# map_location lets weights saved from a GPU session load on a CPU-only runtime.
local_state_dict = torch.load(os.path.join(distilled_dir, 'local_net_distilled_v2.pth'), map_location=device)
local_net.load_state_dict(local_state_dict)

# Teacher: pre-trained ResNet-18 loaded from 'resnet18-5c106cde.pth' in the working directory.
resnet_18 = load_resnet_18_teacher_model('resnet18-5c106cde.pth', device)

# Decoder, initialised from the distilled decoder weights.
decoder = OneLayerDecoder(config['local_net_output_dimensions'],
                          config['resnet_output_dimensions']).to(device)
# Same map_location as above: without it a GPU-saved decoder fails to load on CPU.
decoder_state_dict = torch.load(os.path.join(distilled_dir, 'decoder_v2.pth'), map_location=device)
decoder.load_state_dict(decoder_state_dict)

student_train_module = StudentTrainingModule(
    config,
    student_model=local_net,
    teacher_model=resnet_18,
    decoder=decoder,
    mode='finetuning'
)

## Calculate number of epochs

In [None]:
# The training budget is given in iterations (config['iterations']), but the Trainer
# is configured in epochs, so convert.
total_iterations = config['iterations']
batch_size = config['batch_size']
num_training_images = len(train_image_paths)

try:
    # Exact number of optimizer steps per epoch, honouring the loader's batch_size
    # and drop_last (a partial last batch is still one step).
    steps_per_epoch = len(data_module.train_dataloader())
except TypeError:
    # Iterable-style datasets have no length; assume PyTorch's default drop_last=False.
    steps_per_epoch = math.ceil(num_training_images / batch_size)

if steps_per_epoch == 0:
    raise ValueError(
        f"No training batches: {num_training_images} images with batch_size={batch_size}"
    )

# Round up so that at least `total_iterations` optimizer steps are run.
max_epochs = math.ceil(total_iterations / steps_per_epoch)
print(f"Calculated max_epochs: {max_epochs}")

## Create callbacks for training

In [None]:
from pytorch_lightning.callbacks import ModelCheckpoint

# Checkpoints go to Drive so they survive the end of the Colab session.
checkpoint_callback = ModelCheckpoint(
    dirpath="/content/drive/MyDrive/AnomalyDetection/LocalNet/Fine-tuning/Checkpoints",  # Path where checkpoints will be saved
    filename="{epoch}-{val_loss:.2f}",  # e.g. epoch=3-val_loss=0.12.ckpt
    monitor="val_loss",  # Metric used to rank checkpoints
    every_n_epochs=1,  # Consider saving after every epoch ...
    save_top_k=3,  # ... but keep only the 3 best by val_loss
    save_last=True,  # Also keep last.ckpt to resume training later
    # Full checkpoints (model + optimizer/scheduler + loop state). Resuming via the
    # "Setup trainer from checkpoint" cell needs them; a weights-only checkpoint
    # cannot be resumed from.
    save_weights_only=False,
    verbose=True  # If True, print a message to stdout for each save
)

## Setup new trainer

In [None]:

# Fine-tune from the loaded weights: checkpoints via checkpoint_callback, TensorBoard logs below ./
tensorboard_logger = pl.loggers.TensorBoardLogger(save_dir='./')
trainer = pl.Trainer(
    max_epochs=max_epochs,
    accelerator='gpu' if torch.cuda.is_available() else 'cpu',
    devices=1,
    callbacks=[checkpoint_callback],
    logger=[tensorboard_logger],
)

## Setup trainer from checkpoint


In [None]:
# Resume an interrupted fine-tuning run.
# Change <CHECK_POINT.ckpt> to a valid checkpoint file located in
# /content/drive/MyDrive/AnomalyDetection/LocalNet/Fine-tuning/Checkpoints
# It has to be a full checkpoint (saved with save_weights_only=False): Lightning cannot
# restore optimizer/loop state from a weights-only file.

trainer = pl.Trainer(
    max_epochs=max_epochs,
    accelerator='gpu' if torch.cuda.is_available() else 'cpu',
    devices=1,
    callbacks=[checkpoint_callback],
    logger=[
        pl.loggers.TensorBoardLogger(save_dir='./')
    ]
)
# Trainer(resume_from_checkpoint=...) was removed in Lightning 2.0 (the unpinned
# `pip install pytorch_lightning` installs the latest release). Setting
# trainer.ckpt_path makes the next trainer.fit(...) restore from this file.
trainer.ckpt_path = "/content/drive/MyDrive/AnomalyDetection/LocalNet/Fine-tuning/Checkpoints/<CHECK_POINT.ckpt>"

## Run fine-tuning

In [None]:
# Run fine-tuning (resumes from trainer.ckpt_path if one was set on the trainer)
trainer.fit(model=student_train_module, datamodule=data_module)

## Save fine-tuned model

In [None]:
# Persist the fine-tuned Local-Net weights to Drive.
finetuned_dir = '/content/drive/MyDrive/AnomalyDetection/LocalNet/Fine-tuning/Trained Models'
os.makedirs(finetuned_dir, exist_ok=True)  # torch.save does not create missing folders
torch.save(local_net.state_dict(), os.path.join(finetuned_dir, 'local_net_finetuned.pth'))

## Save model by first loading given checkpoint

In [None]:
# Rebuild the fine-tuning module, restore a Lightning checkpoint into it and export
# only the Local-Net weights.
# NOTE(review): LocalNet is built without a config here but as LocalNet(config) for
# distillation — confirm the defaults match the checkpointed architecture.
local_net = LocalNet()
resnet_18 = load_resnet_18_teacher_model('resnet18-5c106cde.pth', device)
# Take the decoder dimensions from the config (instead of hard-coded 128/512) so they
# always match the decoder built for fine-tuning and stored in the checkpoint.
decoder = OneLayerDecoder(config['local_net_output_dimensions'],
                          config['resnet_output_dimensions'])

student_train_module = StudentTrainingModule(
    config,
    student_model=local_net,
    teacher_model=resnet_18,
    decoder=decoder,
    mode='finetuning'
)

# Replace with correct checkpoint path.
# map_location lets a checkpoint written on a GPU runtime load on a CPU-only one.
checkpoint = torch.load(
    "/content/drive/MyDrive/AnomalyDetection/LocalNet/Fine-tuning/Checkpoints/V4/epoch=3535-val_loss=1890.43.ckpt",
    map_location=device,
)
# Lightning checkpoints store the LightningModule's weights under 'state_dict'.
student_train_module.load_state_dict(checkpoint['state_dict'])

local_net = student_train_module.student_model

# Save the state dictionaries of the individual models
export_dir = '/content/drive/MyDrive/AnomalyDetection/LocalNet/Fine-tuning/Trained Models/V4'
os.makedirs(export_dir, exist_ok=True)  # torch.save does not create missing folders
torch.save(local_net.state_dict(), os.path.join(export_dir, 'local_net_finetuned_v4.pth'))
