# X-ray

### *Run these cells only when in Google Colab*

In [None]:

# Clone the project repository (presumably the source of the model/,
# data_loader/ and configs/ folders used below)
!git clone https://github.com/beerger/mad_seminar_ws23.git
# Move all content to the current directory.
# NOTE(review): the `*` glob does not match dotfiles (e.g. .git, .gitignore),
# so those are NOT moved; re-running this cell after a successful move will
# likely fail because the target directories already exist.
!mv ./mad_seminar_ws23/* ./
# Remove the clone directory (including the dotfiles the glob above skipped)
!rm -rf mad_seminar_ws23/

In [None]:
# # Download the data
!wget https://syncandshare.lrz.de/dl/fiH6r4B6WyzAaxZXTEAYCE/data.zip
# # Extract the data
!unzip -q ./data.zip

In [None]:
# Install additional packages
!pip install pytorch_lightning --quiet
!pip install lpips

In [None]:
import pytorch_lightning as pl
import yaml
import torch
import matplotlib.pyplot as plt
import json
from google.colab import drive
import os
import pandas as pd

from model.local_net import LocalNet
from model.model_utils import load_resnet_18_teacher_model
from model.student_training_module import StudentTrainingModule
from data_loader.mvtec_data_loader import MVTecDataModule
from model.one_layer_decoder import OneLayerDecoder

# Automatically reload imported modules before each cell runs, so edits to the
# project's .py files take effect without restarting the kernel
%load_ext autoreload
%autoreload 2

## Load the config

In [None]:
# Read the fine-tuning hyper-parameters (seed, batch_size, iterations, ...)
# from the YAML config; safe_load only builds plain Python objects.
with open('./configs/local_net_fine_tune.yaml', 'r') as config_file:
    config = yaml.safe_load(config_file)

# Seed all RNGs via Lightning so runs are reproducible
pl.seed_everything(config['seed'])

## Load and visualize data

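Mount the current Colab session to Google Drive. The distilled model weights are loaded from Drive and the fine-tuning checkpoints are saved there (the training/val images may also be stored there, depending on the paths listed in the split CSVs).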

In [None]:
# Mount Google Drive at /content/drive; Colab asks you to authorize access
# (via an authentication link / pop-up) the first time
drive.mount('/content/drive')

Get the file paths of the train/val images from the split CSVs

In [None]:
# Split CSVs; each has a 'filename' column listing one image per row
split_dir = "./data/splits"

train_csv_ixi = os.path.join(split_dir, 'ixi_normal_train.csv')
train_csv_fastMRI = os.path.join(split_dir, 'normal_train.csv')
val_csv = os.path.join(split_dir, 'normal_val.csv')
# Load csv files
train_files_ixi = pd.read_csv(train_csv_ixi)['filename'].tolist()
train_files_fastMRI = pd.read_csv(train_csv_fastMRI)['filename'].tolist()
val_files = pd.read_csv(val_csv)['filename'].tolist()
# Combine both normal-only training sources into one training list
train_image_paths = train_files_ixi + train_files_fastMRI
val_image_paths = val_files

# Later cells divide by the number of training images, so fail early here
assert train_image_paths, "No training images listed in the split CSVs"

print(f"Using {len(train_files_ixi)} IXI images "
      f"and {len(train_files_fastMRI)} fastMRI images for training. "
      f"Using {len(val_files)} images for validation.")

# Spot-check that the entries look like file paths. Show only a few: printing
# the full lists floods the notebook output.
N_PREVIEW = 5
print("Train path examples:", train_image_paths[:N_PREVIEW])
print("Val path examples:", val_image_paths[:N_PREVIEW])

Create the data module (provides the train/val data loaders)

In [None]:
# Lightning data module over the combined training paths and the validation
# paths; its batch size comes from the same config entry used elsewhere.
# NOTE(review): caching_strategy='none' presumably disables in-memory image
# caching — confirm against MVTecDataModule.
data_module = MVTecDataModule(
    train_image_paths,
    val_image_paths,
    caching_strategy='none',
    num_workers=4,
    batch_size=config['batch_size'],
)

In [None]:
# Read from the same config key that was handed to the data module, so this
# value and the data module's batch size always agree.
BATCH_SIZE = config['batch_size']

# Reverse the ImageNet mean/std normalization so patches can be shown with
# imshow without the following warning:
# WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
def denormalize(tensor):
    """Undo ImageNet mean/std normalization for display.

    Assumes the data module normalized its images with the standard ImageNet
    statistics used below (TODO confirm against MVTecDataModule).

    Args:
        tensor: normalized 3-channel image tensor, either a batch shaped
            (B, 3, H, W) or a single image shaped (3, H, W).

    Returns:
        A NEW tensor of the same shape holding ``tensor * std + mean`` clamped
        to [0, 1]; the input tensor is left unmodified.
    """
    # Create the statistics on the input's device so GPU tensors work too.
    # Shape (3, 1, 1) broadcasts over both (B, 3, H, W) and (3, H, W).
    mean = torch.tensor([0.485, 0.456, 0.406], device=tensor.device).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=tensor.device).view(3, 1, 1)
    # Out-of-place arithmetic, then clamp pixel values into imshow's [0, 1] range
    return (tensor * std + mean).clamp(0, 1)


# Retrieve one batch (a pair of patch tensors) from the training loader
patch_local, patch_resnet = next(iter(data_module.train_dataloader()))

# Denormalize the patches for visualization
patch_local = denormalize(patch_local)
patch_resnet = denormalize(patch_resnet)

# Size the grid from the batch actually returned rather than BATCH_SIZE, so a
# dataset smaller than one batch does not raise an IndexError.
n_images = patch_local.shape[0]

# squeeze=False keeps `ax` 2-D even when n_images == 1, so ax[row, i] is valid
fig, ax = plt.subplots(2, n_images, figsize=(20, 8), squeeze=False)
fig.suptitle('One training batch - top row: patch_local, bottom row: patch_resnet')

for row, patches in enumerate((patch_local, patch_resnet)):
    for i in range(n_images):
        # Permute the tensor from (C, H, W) to (H, W, C) for imshow
        image = patches[i].permute(1, 2, 0)
        ax[row, i].imshow(image.cpu().numpy())
        ax[row, i].axis('off')

plt.show()

## Set up tensorboard

In [None]:
# Start TensorBoard inline, watching lightning_logs/ (where the
# TensorBoardLogger with save_dir='./' below writes its runs)
%load_ext tensorboard
%tensorboard --logdir lightning_logs/

## Set up all models for fine-tuning

In [None]:
# Use the first GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Google Drive folder holding the distilled LocalNet and decoder weights
DISTILLED_DIR = '/content/drive/MyDrive/AnomalyDetection/LocalNet/Distillation/Trained Models/V2'

# Student: distilled LocalNet loaded from Google Drive.
# map_location lets weights saved on a GPU load on a CPU-only runtime.
local_net = LocalNet().to(device)
local_state_dict = torch.load(os.path.join(DISTILLED_DIR, 'local_net_distilled_v2.pth'),
                              map_location=device)
local_net.load_state_dict(local_state_dict)

# Teacher: ResNet-18 loaded from its local weight file
resnet_18 = load_resnet_18_teacher_model('resnet18-5c106cde.pth', device)

# Decoder whose weights are stored alongside the distilled LocalNet; its
# dimensions come from the config. Loaded with the same map_location as above.
decoder = OneLayerDecoder(config['local_net_output_dimensions'],
                          config['resnet_output_dimensions']).to(device)
decoder_state_dict = torch.load(os.path.join(DISTILLED_DIR, 'decoder_v2.pth'),
                                map_location=device)
decoder.load_state_dict(decoder_state_dict)

# Wrap student, teacher and decoder in the Lightning module in fine-tuning mode
student_train_module = StudentTrainingModule(
    config,
    student_model=local_net,
    teacher_model=resnet_18,
    decoder=decoder,
    mode='finetuning'
)

## Calculate number of epochs

In [None]:
# The paper trains with a batch size of 64 for 50k iterations. The Trainer is
# configured in epochs, so convert the iteration budget into max_epochs:
#   max_epochs = ceil(total_iterations / (num_training_images / batch_size))
#              = ceil(total_iterations * batch_size / num_training_images)
total_iterations = int(config['iterations'])
batch_size = int(config['batch_size'])
num_training_images = len(train_image_paths)
if num_training_images == 0:
    raise ValueError("No training images found; cannot compute max_epochs.")
# Integer ceiling division is exact; the float form can pick up round-off
# (e.g. 3.0000000000000004) that rounds an exact result up by one epoch.
max_epochs = (total_iterations * batch_size + num_training_images - 1) // num_training_images
print(f"Calculated max_epochs: {max_epochs}")

## Create callbacks for training

In [None]:
from pytorch_lightning.callbacks import ModelCheckpoint

# Keep the 3 checkpoints with the lowest val_loss plus the most recent one.
# NOTE: save_weights_only=True stores only the model weights (no optimizer or
# LR-scheduler state). The last checkpoint can therefore restore the weights,
# but it cannot resume training with the optimizer state intact. Set it to False
# if a full resume is needed (checkpoints get larger).
checkpoint_callback = ModelCheckpoint(
    dirpath="/content/drive/MyDrive/AnomalyDetection/LocalNet/Fine-tuning/X-ray/Checkpoints/V1",  # Path where checkpoints will be saved
    filename="{epoch}-{val_loss:.2f}",  # e.g. "epoch=12-val_loss=0.34.ckpt"
    monitor="val_loss",  # Must be logged by the LightningModule
    every_n_epochs=1,  # Consider saving after every epoch
    save_weights_only=True,  # Model weights only, no optimizer/scheduler state
    save_top_k=3,  # Keep the 3 best checkpoints by val_loss
    save_last=True,  # Also keep the most recent checkpoint
    verbose=True  # Print a message to stdout for each save
)

## Set up a new trainer

In [None]:

# Single-device trainer: GPU when available, CPU otherwise. The TensorBoard
# logger writes under ./lightning_logs/, the directory TensorBoard watches.
trainer = pl.Trainer(
    accelerator=('gpu' if torch.cuda.is_available() else 'cpu'),
    devices=1,
    max_epochs=max_epochs,
    callbacks=[checkpoint_callback],
    logger=[pl.loggers.TensorBoardLogger(save_dir='./')],
)

## Run fine-tuning

In [None]:
# Fine-tune the student; checkpoints are written by checkpoint_callback
trainer.fit(student_train_module, datamodule=data_module)

## Save the fine-tuned LocalNet weights by first loading a checkpoint

In [None]:
# Rebuild the module skeleton so a saved Lightning checkpoint can be loaded.
# The decoder dimensions come from the config, exactly as the decoder was built
# for fine-tuning; hardcoded sizes would break load_state_dict if the config
# differs.
local_net = LocalNet()
resnet_18 = load_resnet_18_teacher_model('resnet18-5c106cde.pth', device)
decoder = OneLayerDecoder(config['local_net_output_dimensions'],
                          config['resnet_output_dimensions'])

student_train_module = StudentTrainingModule(
    config,
    student_model=local_net,
    teacher_model=resnet_18,
    decoder=decoder,
    mode='finetuning'
)

# Replace with correct checkpoint path.
# NOTE(review): these paths point at "Fine-tuning/.../V4", whereas this
# notebook's ModelCheckpoint writes to "Fine-tuning/X-ray/Checkpoints/V1".
# Confirm they are the intended input/output locations.
checkpoint_path = "/content/drive/MyDrive/AnomalyDetection/LocalNet/Fine-tuning/Checkpoints/V4/epoch=3535-val_loss=1890.43.ckpt"
output_path = '/content/drive/MyDrive/AnomalyDetection/LocalNet/Fine-tuning/Trained Models/V4/local_net_finetuned_v4.pth'

# map_location lets a checkpoint written on a GPU load on a CPU-only runtime
checkpoint = torch.load(checkpoint_path, map_location=device)
student_train_module.load_state_dict(checkpoint['state_dict'])

local_net = student_train_module.student_model

# Save only the fine-tuned student's weights; create the folder if missing
os.makedirs(os.path.dirname(output_path), exist_ok=True)
torch.save(local_net.state_dict(), output_path)
