# Global-Net and DAD-head

### *Run these cells only when in Google Colab*

In [None]:
# Clone the repository
!git clone https://github.com/beerger/mad_seminar_ws23.git
# Move all non-hidden content to the current directory (the `*` glob does not match dotfiles)
!mv ./mad_seminar_ws23/* ./
# Remove the leftover clone directory; it is not empty, it still holds hidden entries such as .git
!rm -rf mad_seminar_ws23/

In [None]:
# Install additional packages
!pip install pytorch_lightning --quiet
!pip install lpips

## Imports for Global-Net and DAD-head

In [None]:
import json
import math
import os

import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch
import yaml
from google.colab import drive

from model.local_net import LocalNet
from model.dad import DADHead
from model.iad import iad_head
from model.global_net import GlobalNet
from data_loader.joint_training_data_loader import JointTrainingDataModule
from model.joint_training_module import JointGlobalDADTrainingModule

# Automatically reload imported modules (including the local model/ and data_loader/
# packages) before each cell executes, so edits to their .py files take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

## Load the config

In [None]:
# Hyper-parameters for joint Global-Net / DAD-head training
config_path = './configs/global_dad_config.yaml'
with open(config_path, 'r') as config_file:
    config = yaml.safe_load(config_file)

# Seed the Python, NumPy and PyTorch RNGs so runs are reproducible
pl.seed_everything(config['seed'])

## Load and visualize data

Mount Google Drive into the current Colab session (the training/validation images, the fine-tuned LocalNet weights, checkpoints and trained models are all stored there)

In [None]:
# Mount Google Drive at a fixed location; Colab prompts for authorization when this runs
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Copy the compressed MVTec AD *zipper* dataset (`zipper.tar.xz`) from Google Drive to the current Colab session and extract it

In [None]:
!cp "/content/drive/MyDrive/AnomalyDetection/Datasets/MVTec/zipper.tar.xz" "/content/"
# Unzip it
!tar -xf /content/zipper.tar.xz -C /content/
# Remove zip file
!rm -rf zipper.tar.xz

Get paths to training images

In [None]:
import os

# Directory holding the defect-free ("good") training images of the zipper category
train_directory = '/content/zipper/train/good/'

# Full path of every regular file in train_directory.
# os.listdir returns entries in an arbitrary, filesystem-dependent order; sorting makes
# the list, and therefore the seeded train/validation split, reproducible.
file_list = sorted(
    os.path.join(train_directory, name)
    for name in os.listdir(train_directory)
    if os.path.isfile(os.path.join(train_directory, name))
)

# Sanity check on the expected number of "good" zipper training images
assert len(file_list) == 240, f"Expected 240 training images, found {len(file_list)}"

Split into train/validation

In [None]:
from sklearn.model_selection import train_test_split

# Hold out a fraction of the defect-free training images (file_list) for validation;
# the config seed keeps the split reproducible.
VAL_FRACTION = 0.1  # adjust as needed
train_image_paths, val_image_paths = train_test_split(
    file_list,
    test_size=VAL_FRACTION,
    random_state=config['seed'],
)

Create data module for joint training on MVTec AD

In [None]:
# Data module for joint Global-Net / DAD-head training on the MVTec AD zipper images.
# caching_strategy='at-init' is forwarded to JointTrainingDataModule
# (presumably images are cached when the module is constructed — confirm).
data_module = JointTrainingDataModule(
    train_image_paths,
    val_image_paths,
    batch_size=config['batch_size'],
    num_workers=2,
    caching_strategy='at-init',
)

for split_name, split_paths in (("training", train_image_paths), ("validation", val_image_paths)):
    print(f"Number of {split_name} images: {len(split_paths)}")

Visualize one training batch: patches, full images and their binary masks

In [None]:
# Maximum number of samples from one training batch to display. It is capped at the
# actual batch size below, so it no longer has to match config['batch_size'].
BATCH_SIZE = 4

# Channel statistics assumed to match the normalization applied by the data module
# (ImageNet mean/std) — confirm against JointTrainingDataModule.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def denormalize(tensor):
    """Reverse ImageNet mean/std normalization so images can be shown with imshow.

    Clamping to [0, 1] avoids matplotlib's warning:
    "Clipping input data to the valid range for imshow with RGB data
    ([0..1] for floats or [0..255] for integers)."

    Args:
        tensor: Batch of normalized RGB images of shape (N, 3, H, W).

    Returns:
        A new tensor with values in [0, 1]; the input tensor is not modified.
    """
    # Build the statistics on the input's device/dtype so GPU tensors work as well
    mean = torch.tensor(IMAGENET_MEAN, dtype=tensor.dtype, device=tensor.device).view(1, 3, 1, 1)
    std = torch.tensor(IMAGENET_STD, dtype=tensor.dtype, device=tensor.device).view(1, 3, 1, 1)
    return (tensor * std + mean).clamp(0, 1)


def _show_row(images, **imshow_kwargs):
    """Draw each tensor in `images` side by side in one figure row and return the figure."""
    # squeeze=False keeps `axes` 2-D, so this also works when only one image is shown
    fig, axes = plt.subplots(1, len(images), figsize=(20, 8), squeeze=False)
    for ax, image in zip(axes[0], images):
        ax.imshow(image.cpu().numpy(), **imshow_kwargs)
        ax.axis('on')
    return fig


# Retrieve one batch of images
I, patches, binary_masks, labels = next(iter(data_module.train_dataloader()))
num_shown = min(BATCH_SIZE, len(patches))

# Denormalize the patches for visualization
patches = denormalize(patches)

# Tensors are (C, H, W); imshow expects (H, W, C)
fig1 = _show_row([patches[i].permute(1, 2, 0) for i in range(num_shown)])
# NOTE(review): full images are shown without denormalization (as before); if the data
# module normalizes them too, wrap them in denormalize() as well.
fig2 = _show_row([I[i].permute(1, 2, 0) for i in range(num_shown)])
# Masks may be (1, H, W); squeeze to (H, W) for a grayscale plot
fig3 = _show_row(
    [binary_masks[i].squeeze() for i in range(num_shown)],
    cmap='gray',
    interpolation='none',
)

plt.show()

print(labels)

# Shape and unique values of the displayed masks only (avoids flooding output for large batches)
for i, mask in enumerate(binary_masks[:num_shown]):
    print(f"Size of binary mask {i}: {mask.size()}")
    print(f"Mask {i} values: Unique: {torch.unique(mask)}")


## Set up tensorboard

In [None]:
# Start TensorBoard inline. The trainer's TensorBoardLogger(save_dir='./') writes to
# ./lightning_logs/ by default, which is the directory watched here.
%load_ext tensorboard
%tensorboard --logdir lightning_logs/

## Set up all models for training

In [None]:
# Use the first GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Fine-tuned LocalNet weights stored on Google Drive
LOCAL_NET_WEIGHTS = '/content/drive/MyDrive/AnomalyDetection/LocalNet/Fine-tuning/Trained Models/local_net_finetuned.pth'

# Build LocalNet on the target device and restore its fine-tuned weights
local_net = LocalNet().to(device)
state_dict = torch.load(LOCAL_NET_WEIGHTS, map_location=device)
local_net.load_state_dict(state_dict)

# Networks handed to the joint training module (constructed in the same order as before,
# so random initialisation is unchanged)
global_net, dad_head = GlobalNet().to(device), DADHead().to(device)

joint_train_module = JointGlobalDADTrainingModule(
    config,
    local_net=local_net,
    global_net=global_net,
    dad_head=dad_head,
)

## Calculate number of epochs

In [None]:
# The paper trains with batch size 64 for 50k iterations. The trainer is driven by
# max_epochs, so convert the iteration budget into a number of epochs.
total_iterations = config['iterations']
batch_size = config['batch_size']
num_training_images = len(train_image_paths)
if num_training_images == 0:
    raise ValueError("No training images available; cannot derive max_epochs")

# Iterations per epoch, treating a partial final batch as a fractional iteration.
# NOTE(review): if the dataloader keeps the last partial batch (drop_last=False), one epoch is
# actually ceil(num_training_images / batch_size) iterations — confirm against the data module.
iterations_per_epoch = num_training_images / batch_size
# Round up so that at least `total_iterations` iterations are run
max_epochs = math.ceil(total_iterations / iterations_per_epoch)
print(f"Calculated max_epochs: {max_epochs}")

## Create callbacks for training

In [None]:
from pytorch_lightning.callbacks import ModelCheckpoint

# Set up the checkpoint callback; checkpoints go to Google Drive so they survive the Colab session
checkpoint_callback = ModelCheckpoint(
    dirpath="/content/drive/MyDrive/AnomalyDetection/GlobalNet_DAD/Checkpoints",  # Path where checkpoints will be saved
    filename="{epoch}-{val_loss:.2f}",  # Filename template
    monitor="val_loss",  # Metric to monitor for saving (must be logged by the LightningModule)
    mode="min",  # Lower val_loss is better
    every_n_epochs=1,  # Save every epoch
    # Save the full training state (optimizer/loop state), not only the weights: resuming
    # training from a checkpoint needs the optimizer state, which weights-only files lack.
    save_weights_only=False,
    save_top_k=3,  # Keep the top 3 checkpoints based on val_loss
    save_last=True,  # Also write last.ckpt so training can be resumed later
    verbose=True  # Print a message to stdout for each save
)

## Set up a new trainer

In [None]:

# Fresh trainer: one device (GPU if available), checkpointing via checkpoint_callback,
# TensorBoard logs under ./lightning_logs/
use_gpu = torch.cuda.is_available()
tb_logger = pl.loggers.TensorBoardLogger(save_dir='./')
trainer = pl.Trainer(
    max_epochs=max_epochs,
    accelerator='gpu' if use_gpu else 'cpu',
    devices=1,
    callbacks=[checkpoint_callback],
    logger=[tb_logger],
)

## Set up a trainer that resumes from a checkpoint (run instead of the previous cell)

In [None]:
# Change <CHECK_POINT.ckpt> to a valid checkpoint file located in
# /content/drive/MyDrive/AnomalyDetection/GlobalNet_DAD/Checkpoints
# (e.g. last.ckpt, which ModelCheckpoint writes when save_last=True).
RESUME_CKPT_PATH = "/content/drive/MyDrive/AnomalyDetection/GlobalNet_DAD/Checkpoints/<CHECK_POINT.ckpt>"

# Fail early with a clear message instead of after the trainer has been set up
if not os.path.isfile(RESUME_CKPT_PATH):
    raise FileNotFoundError(f"Checkpoint not found: {RESUME_CKPT_PATH}")

trainer = pl.Trainer(
    max_epochs=max_epochs,
    accelerator='gpu' if torch.cuda.is_available() else 'cpu',
    devices=1,
    callbacks=[checkpoint_callback],
    logger=[
        pl.loggers.TensorBoardLogger(save_dir='./')
    ]
)

# `Trainer(resume_from_checkpoint=...)` was removed in PyTorch Lightning 2.0 (passing it raises
# a TypeError). Setting `trainer.ckpt_path` makes the next `trainer.fit(...)` call restore the
# model, optimizer and loop state from this file.
# NOTE: resuming needs a full checkpoint; files saved with ModelCheckpoint(save_weights_only=True)
# contain no optimizer state and cannot be resumed from.
trainer.ckpt_path = RESUME_CKPT_PATH

## Run joint training (and save trained models)

In [None]:
trainer.fit(joint_train_module, datamodule=data_module)

# Persist the trained networks to Google Drive. global_net and dad_head are the objects handed
# to JointGlobalDADTrainingModule, so (assuming the module keeps those references) they hold the
# weights from the end of training, not necessarily the best-val_loss checkpoint.
MODEL_DIR = '/content/drive/MyDrive/AnomalyDetection/GlobalNet_DAD/Trained Models'
os.makedirs(MODEL_DIR, exist_ok=True)  # torch.save does not create missing directories
torch.save(global_net.state_dict(), os.path.join(MODEL_DIR, 'global_net.pth'))
torch.save(dad_head.state_dict(), os.path.join(MODEL_DIR, 'dad_head.pth'))