# Brain MRI

### *Run these cells only when in Google Colab*

In [None]:

# Clone the course repository into ./mad_seminar_ws23
!git clone https://github.com/beerger/mad_seminar_ws23.git
# Move the repository content into the current working directory.
# NOTE: the `*` glob does not match dotfiles, so hidden entries (e.g. .git) stay behind.
!mv ./mad_seminar_ws23/* ./
# Remove the leftover clone directory together with whatever is still inside it
!rm -rf mad_seminar_ws23/

In [None]:
# Download the data archive
!wget https://syncandshare.lrz.de/dl/fiH6r4B6WyzAaxZXTEAYCE/data.zip
# Extract the archive (quietly) into the current working directory
!unzip -q ./data.zip

In [None]:
# Install additional packages
!pip install pytorch_lightning --quiet
!pip install lpips

In [None]:
import pytorch_lightning as pl
import yaml
import torch
import matplotlib.pyplot as plt
from google.colab import drive
import os
import pandas as pd

from model.local_net import LocalNet
from utils.model_utils import load_resnet_18_teacher_model
from training_module.student_training_module import StudentTrainingModule
from data_loader.localnet_data_loader import LocalNetDataModule
from model.one_layer_decoder import OneLayerDecoder
from model.dad import DADHead
from model.global_net import GlobalNet
from data_loader.joint_training_data_loader import JointTrainingDataModule
from training_module.joint_training_module import JointGlobalDADTrainingModule
from anomaly_detector import AnomalyDetector
from data_loader.test_data_loader import TestDataModule


# autoreload imported modules
%load_ext autoreload
%autoreload 2

Load train/val image paths

In [None]:
# Directory that contains the CSV split files
split_dir = "./data/splits"

train_csv_ixi = os.path.join(split_dir, 'ixi_normal_train.csv')
train_csv_fastMRI = os.path.join(split_dir, 'normal_train.csv')
val_csv = os.path.join(split_dir, 'normal_val.csv')

# The 'filename' column of each CSV is read into a plain Python list
train_files_ixi = pd.read_csv(train_csv_ixi)['filename'].tolist()
train_files_fastMRI = pd.read_csv(train_csv_fastMRI)['filename'].tolist()
val_files = pd.read_csv(val_csv)['filename'].tolist()

# Training uses the IXI and fastMRI lists together; validation uses its list as is
train_file_paths = train_files_ixi + train_files_fastMRI
val_file_paths = val_files

print(f"Using {len(train_files_ixi)} IXI images "
      f"and {len(train_files_fastMRI)} fastMRI images for training. "
      f"Using {len(val_files)} images for validation.")

# Sanity check that the entries are file paths: preview only the first few
# instead of dumping the complete lists into the notebook output
print(train_file_paths[:5])
print(val_file_paths[:5])

Load test image paths

In [None]:
# Directory whose files are collected as test images.
# NOTE(review): the variable is called `train_directory` although it points at
# a test folder; the name is kept so that any code relying on it keeps working.
train_directory = '/content/zipper/test/'

# Recursively collect every file below the directory. Sorting makes the order
# independent of the file-system traversal order.
test_file_paths = sorted(
    os.path.join(root, name)
    for root, _dirs, files in os.walk(train_directory)
    for name in files
)

# Fail with a helpful message when the data set is incomplete or misplaced
EXPECTED_NUM_TEST_FILES = 151
assert len(test_file_paths) == EXPECTED_NUM_TEST_FILES, (
    f"Expected {EXPECTED_NUM_TEST_FILES} test files in {train_directory}, "
    f"found {len(test_file_paths)}"
)
# Preview only the first few paths instead of printing the full list
print(test_file_paths[:5])


# 1. Fine-tune Local-Net

## Load the config

In [None]:
# Read the Local-Net fine-tuning configuration from its YAML file
with open('./configs/local_net_fine_tune.yaml', 'r') as config_file:
    config = yaml.safe_load(config_file)

# Seed via Lightning with the configured value so that runs are reproducible
pl.seed_everything(config['seed'])

## Load and visualize data

Mount Google Drive in the current Colab session

In [None]:
# Mount Google Drive at /content/drive.
# Colab will provide you with an authentication link.
drive.mount('/content/drive')

Create data loader

In [None]:
# Loader settings for the Local-Net fine-tuning data
local_net_loader_options = dict(
    batch_size=config['batch_size'],
    num_workers=4,
    caching_strategy='none',
)

# Data module built from the train/val file lists loaded above
data_module = LocalNetDataModule(
    train_file_paths,
    val_file_paths,
    **local_net_loader_options,
)

In [None]:
# Number of image columns plotted below. It has to match the batch size of
# `data_module`, which is why both are read from config['batch_size'].

BATCH_SIZE=config['batch_size']

# Reverse the normalization process done by LocalNetDataModule
# to avoid the following matplotlib warning:
# WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
def denormalize(tensor):
    """Undo a per-channel mean/std normalization and clamp the result to [0, 1].

    Args:
        tensor: Normalized image tensor with 3 channels whose shape broadcasts
            against (1, 3, 1, 1), e.g. (N, 3, H, W).

    Returns:
        A new tensor with values in [0, 1]; the input tensor is not modified.
    """
    # Create the statistics on the input's device so that CUDA tensors work too
    mean = torch.tensor([0.485, 0.456, 0.406], device=tensor.device).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=tensor.device).view(1, 3, 1, 1)
    # Out-of-place: the arithmetic allocates a new tensor
    restored = tensor * std + mean
    # Ensure the pixel values are within [0, 1]
    return restored.clamp(0, 1)


# Retrieve one batch of images
patch_local, patch_resnet = next(iter(data_module.train_dataloader()))

# Denormalize the patches for visualization
patch_local = denormalize(patch_local)
patch_resnet = denormalize(patch_resnet)

# Never index past the end of the batch, even if it holds fewer than
# BATCH_SIZE samples
num_columns = min(BATCH_SIZE, len(patch_local))

# 2 rows, one column per sample. squeeze=False keeps `ax` two-dimensional
# even when only a single column is drawn.
fig, ax = plt.subplots(2, num_columns, figsize=(20, 8), squeeze=False)

# Row 0 shows patch_local, row 1 shows patch_resnet
for row, batch in enumerate((patch_local, patch_resnet)):
    for i in range(num_columns):
        # Permute the tensor to the format (H, W, C) expected by imshow
        image = batch[i].permute(1, 2, 0)

        # Display the image
        ax[row, i].imshow(image.cpu().numpy())
        ax[row, i].axis('off')

plt.show()

## Set up tensorboard

In [None]:
# Load the TensorBoard notebook extension and show the logs stored in lightning_logs/
%load_ext tensorboard
%tensorboard --logdir lightning_logs/

## Set up all models for fine-tuning

In [None]:
# Use the first CUDA device when a GPU is available, otherwise the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load distilled local_net from Google Drive
local_net = LocalNet().to(device)
# Load the state dictionary from the saved file
local_state_dict = torch.load('/content/drive/MyDrive/AnomalyDetection/LocalNet/Distillation/Trained Models/V2/local_net_distilled_v2.pth', map_location=device)
# Update the local_net model's state dictionary
local_net.load_state_dict(local_state_dict)

# Teacher network used by the student training module
resnet_18 = load_resnet_18_teacher_model('resnet18-5c106cde.pth', device)

# Decoder whose input/output sizes are taken from the config
decoder = OneLayerDecoder(config['local_net_output_dimensions'],
                          config['resnet_output_dimensions']).to(device)

# map_location (as for local_net above) lets a file that was saved from a GPU
# be loaded on a CPU-only runtime as well
decoder_state_dict = torch.load('/content/drive/MyDrive/AnomalyDetection/LocalNet/Distillation/Trained Models/V2/decoder_v2.pth', map_location=device)

decoder.load_state_dict(decoder_state_dict)

# Lightning module that fine-tunes the student with the teacher and decoder
student_train_module = StudentTrainingModule(
    config,
    student_model=local_net,
    teacher_model=resnet_18,
    decoder=decoder,
    mode='finetuning'
)

## Calculate number of epochs

In [None]:
# The paper specifies a batch size of 64 for 50k iterations, while the
# trainer is configured in epochs, so the iteration budget is converted here.
total_iterations = config['iterations']
batch_size = config['batch_size']
num_training_images = len(train_file_paths)

# Iterations that one pass over the training images accounts for (may be fractional)
iterations_per_epoch = num_training_images / batch_size
exact_epochs = total_iterations / iterations_per_epoch

# Round up: a started epoch counts as a whole epoch
has_partial_epoch = exact_epochs % 1 > 0
max_epochs = int(exact_epochs) + has_partial_epoch
print(f"Calculated max_epochs: {max_epochs}")

## Create callbacks for training

In [None]:
from pytorch_lightning.callbacks import ModelCheckpoint

# Google Drive directory that receives the fine-tuning checkpoints
checkpoint_dir = "/content/drive/MyDrive/AnomalyDetection/LocalNet/Fine-tuning/X-ray/Checkpoints/V1"

# Checkpointing is driven by the validation loss
checkpoint_callback = ModelCheckpoint(
    dirpath=checkpoint_dir,
    filename="{epoch}-{val_loss:.2f}",  # checkpoint file name template
    monitor="val_loss",                 # metric that ranks the checkpoints
    save_top_k=3,                       # keep the three best checkpoints
    save_last=True,                     # additionally keep the most recent one
    save_weights_only=True,             # store the weights only
    every_n_epochs=1,                   # check after every epoch
    verbose=True,                       # report every save on stdout
)

## Setup new trainer

In [None]:

# TensorBoard logger that writes below the current working directory
tensorboard_logger = pl.loggers.TensorBoardLogger(save_dir='./')

# Train on the GPU when CUDA is available, otherwise on the CPU
accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'

trainer = pl.Trainer(
    max_epochs=max_epochs,
    accelerator=accelerator,
    devices=1,
    callbacks=[checkpoint_callback],
    logger=[tensorboard_logger],
)

## Run fine-tuning

In [None]:
# Fine-tune the student training module on the Local-Net data module
trainer.fit(student_train_module, datamodule=data_module)

## Save the model from a given checkpoint

In [None]:
# Build a training module with freshly initialised networks; the checkpoint's
# state dict is loaded into it below.
local_net = LocalNet()
resnet_18 = load_resnet_18_teacher_model('resnet18-5c106cde.pth', device)
# NOTE(review): 128 / 512 presumably mirror config['local_net_output_dimensions']
# and config['resnet_output_dimensions'] used during training — confirm.
decoder = OneLayerDecoder(128, 512)

student_train_module = StudentTrainingModule(
    config,
    student_model=local_net,
    teacher_model=resnet_18,
    decoder=decoder,
    mode='finetuning'
)

# Replace with correct checkpoint path.
# map_location lets a checkpoint written on a GPU be loaded on a CPU-only runtime.
checkpoint = torch.load("/content/drive/MyDrive/AnomalyDetection/LocalNet/Fine-tuning/Checkpoints/V4/epoch=3535-val_loss=1890.43.ckpt", map_location=device)
student_train_module.load_state_dict(checkpoint['state_dict'])

# The fine-tuned Local-Net is the student of the training module
local_net = student_train_module.student_model

# Save the state dictionary of the fine-tuned model. torch.save does not
# create missing directories, so make sure the target directory exists.
local_net_output_path = '/content/drive/MyDrive/AnomalyDetection/LocalNet/Fine-tuning/Trained Models/V4/local_net_finetuned_v4.pth'
os.makedirs(os.path.dirname(local_net_output_path), exist_ok=True)
torch.save(local_net.state_dict(), local_net_output_path)


# 2. Train Global-Net and DAD-head

## Load the config

In [None]:
# Read the Global-Net / DAD-head training configuration from its YAML file
with open('./configs/global_dad_config.yaml', 'r') as config_file:
    config = yaml.safe_load(config_file)

# Seed via Lightning with the configured value so that runs are reproducible
pl.seed_everything(config['seed'])

## Load and visualize data

Mount Google Drive in the current Colab session

In [None]:
# Mount Google Drive at /content/drive.
# Colab will provide you with an authentication link.
drive.mount('/content/drive')

In [None]:
# Loader settings for the joint Global-Net / DAD-head training data
joint_loader_options = dict(
    batch_size=config['batch_size'],
    num_workers=2,
    caching_strategy='at-init',
)

# Data module built from the train/val file lists loaded above
data_module = JointTrainingDataModule(
    train_file_paths,
    val_file_paths,
    **joint_loader_options,
)

# Report the data set sizes
num_train_images = len(train_file_paths)
num_val_images = len(val_file_paths)
print(f"Number of training images: {num_train_images}")
print(f"Number of validation images: {num_val_images}")

Visualize patches, images, binary masks

In [None]:
# Number of samples plotted below. It is set by hand, so make sure the batch
# size of `data_module` is equal to (or at least not smaller than) BATCH_SIZE.

BATCH_SIZE=4

# Reverse the normalization process done by the data module
# to avoid the following matplotlib warning:
# WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
def denormalize(tensor):
    """Undo a per-channel mean/std normalization and clamp the result to [0, 1].

    Args:
        tensor: Normalized image tensor with 3 channels whose shape broadcasts
            against (1, 3, 1, 1), e.g. (N, 3, H, W).

    Returns:
        A new tensor with values in [0, 1]; the input tensor is not modified.
    """
    # Create the statistics on the input's device so that CUDA tensors work too
    mean = torch.tensor([0.485, 0.456, 0.406], device=tensor.device).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=tensor.device).view(1, 3, 1, 1)
    # Out-of-place: the arithmetic allocates a new tensor
    restored = tensor * std + mean
    # Ensure the pixel values are within [0, 1]
    return restored.clamp(0, 1)


# Retrieve one batch: images, patches, binary masks and labels
I, patches, binary_masks, labels = next(iter(data_module.train_dataloader()))

# Denormalize the patches for visualization
patches = denormalize(patches)

# BATCH_SIZE is set by hand and may exceed the number of samples in the
# batch; never index past the end of the batch.
num_columns = min(BATCH_SIZE, len(patches))

# squeeze=False keeps the axes array two-dimensional even for a single
# column; its only row is then taken as a 1-D sequence of axes.
fig1, ax1 = plt.subplots(1, num_columns, figsize=(20, 8), squeeze=False)
ax1 = ax1[0]

# Plotting patches
for i in range(num_columns):
    # Permute the tensor to the format (H, W, C)
    image = patches[i].permute(1, 2, 0)

    # Display the image
    ax1[i].imshow(image.cpu().numpy())
    ax1[i].axis('on')

fig2, ax2 = plt.subplots(1, num_columns, figsize=(20, 8), squeeze=False)
ax2 = ax2[0]

# Plotting images
for i in range(num_columns):
    # Permute the tensor to the format (H, W, C)
    image = I[i].permute(1, 2, 0)

    # Display the image
    ax2[i].imshow(image.cpu().numpy())
    ax2[i].axis('on')

fig3, ax3 = plt.subplots(1, num_columns, figsize=(20, 8), squeeze=False)
ax3 = ax3[0]

# Plotting binary masks
for i in range(num_columns):
    # Drop singleton dimensions, e.g. [1, H, W] -> [H, W]
    mask = binary_masks[i].squeeze()

    # Display the mask
    ax3[i].imshow(mask.cpu().numpy(), cmap='gray', interpolation='none')
    ax3[i].axis('on')

plt.show()

print(labels)

## Set up tensorboard

In [None]:
# Load the TensorBoard notebook extension and show the logs stored in lightning_logs/
%load_ext tensorboard
%tensorboard --logdir lightning_logs/

## Set up all models for training

In [None]:
# First CUDA device when a GPU is available, otherwise the CPU
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Fine-tuned Local-Net: create the network, then load its weights from Google Drive
local_net = LocalNet().to(device)
state_dict = torch.load(
    '/content/drive/MyDrive/AnomalyDetection/LocalNet/Fine-tuning/X-ray/Trained Models/V1/local_net_finetuned_xray_v1.pth',
    map_location=device,
)
local_net.load_state_dict(state_dict)

# Global-Net and DAD-head start from their default initialisation
global_net = GlobalNet().to(device)
dad_head = DADHead().to(device)

# Lightning module for the joint training of the networks
joint_train_module = JointGlobalDADTrainingModule(
    config,
    local_net=local_net,
    global_net=global_net,
    dad_head=dad_head,
)

## Calculate number of epochs

In [None]:
# The paper specifies a batch size of 64 for 50k iterations, while the
# trainer is configured in epochs, so the iteration budget is converted here.
total_iterations = config['iterations']
batch_size = config['batch_size']
num_training_images = len(train_file_paths)

# Iterations that one pass over the training images accounts for (may be fractional)
iterations_per_epoch = num_training_images / batch_size
exact_epochs = total_iterations / iterations_per_epoch

# Round up: a started epoch counts as a whole epoch
has_partial_epoch = exact_epochs % 1 > 0
max_epochs = int(exact_epochs) + has_partial_epoch
print(f"Calculated max_epochs: {max_epochs}")

## Create callbacks for training

In [None]:
from pytorch_lightning.callbacks import ModelCheckpoint

# Google Drive directory that receives the joint-training checkpoints
checkpoint_dir = "/content/drive/MyDrive/AnomalyDetection/GlobalNet_DAD/Checkpoints/V3"

# Checkpointing is driven by the validation loss
checkpoint_callback = ModelCheckpoint(
    dirpath=checkpoint_dir,
    filename="{epoch}-{val_loss:.2f}",  # checkpoint file name template
    monitor="val_loss",                 # metric that ranks the checkpoints
    save_top_k=3,                       # keep the three best checkpoints
    save_last=True,                     # additionally keep the most recent one
    save_weights_only=True,             # store the weights only
    every_n_epochs=1,                   # check after every epoch
    verbose=True,                       # report every save on stdout
)

## Setup new trainer

In [None]:

# TensorBoard logger that writes below the current working directory
tensorboard_logger = pl.loggers.TensorBoardLogger(save_dir='./')

# Train on the GPU when CUDA is available, otherwise on the CPU
accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'

trainer = pl.Trainer(
    max_epochs=max_epochs,
    accelerator=accelerator,
    devices=1,
    callbacks=[checkpoint_callback],
    logger=[tensorboard_logger],
)

## Run joint training (and save trained models)

In [None]:
# Create the output directory BEFORE training: torch.save does not create
# missing directories, and discovering that only after the complete training
# run would lose the trained weights of this session.
trained_models_dir = '/content/drive/MyDrive/AnomalyDetection/GlobalNet_DAD/Trained Models/V3'
os.makedirs(trained_models_dir, exist_ok=True)

# Jointly train Global-Net and DAD-head
trainer.fit(joint_train_module, datamodule=data_module)

# Save the state dictionaries of the trained networks
torch.save(global_net.state_dict(), os.path.join(trained_models_dir, 'global_net_v3.pth'))
torch.save(dad_head.state_dict(), os.path.join(trained_models_dir, 'dad_head_v3.pth'))

Optional: load a specific checkpoint and save the models stored in it.

In [None]:
# Use the first CUDA device when a GPU is available, otherwise the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Step 1: Build the joint module with freshly initialised networks; the
# checkpoint's state dict is loaded into it below.
local_net = LocalNet().to(device)
global_net = GlobalNet().to(device)
dad_head = DADHead().to(device)

joint_train_module = JointGlobalDADTrainingModule(
    config,
    local_net=local_net,
    global_net=global_net,
    dad_head=dad_head,
)

# Step 2: Load the checkpoint
# Replace the path below with the path to your checkpoint file
checkpoint = torch.load("/content/drive/MyDrive/AnomalyDetection/GlobalNet_DAD/X-ray/Checkpoints/V1/epoch=5481-val_loss=0.09021.ckpt", map_location=device)
joint_train_module.load_state_dict(checkpoint['state_dict'])

# Step 3: Extract the individual models from the joint module
global_net = joint_train_module.global_net
dad_head = joint_train_module.dad_head

# Step 4: Save the state dictionaries of the individual models. torch.save
# does not create missing directories, so make sure the target exists first.
trained_models_dir = '/content/drive/MyDrive/AnomalyDetection/GlobalNet_DAD/X-ray/Trained Models/V1'
os.makedirs(trained_models_dir, exist_ok=True)
torch.save(global_net.state_dict(), os.path.join(trained_models_dir, 'global_net_xray_v1.pth'))
torch.save(dad_head.state_dict(), os.path.join(trained_models_dir, 'dad_head_xray_v1.pth'))


# 3. Inference and evaluation

Create dataloader

In [None]:
# Directory with the split files of the test data
split_dir = "/content/data/splits"

# No batch size is passed on purpose: a batch then has the size of the test
# data (all images, or the images of one specific pathology), which makes
# the evaluation easier.
test_data_module = TestDataModule(split_dir)


Display inference images

In [None]:
IMAGES_TO_DISPLAY = 4

# Only the images are needed here; the second batch element is discarded
images, _ = next(iter(test_data_module.test_dataloader_all()))

fig, ax = plt.subplots(1, IMAGES_TO_DISPLAY, figsize=(10, 10))

# One image per axes, converted from (C, H, W) to (H, W, C) for imshow
for i, image_axes in enumerate(ax):
    image_hwc = images[i].permute(1, 2, 0)
    image_axes.imshow(image_hwc.cpu().numpy())

plt.show()

Check that AnomalyDetector creates correct patches and binary masks

In [None]:
# Detector without any networks (all three model arguments are None); it is
# only used here to inspect the patch and mask creation.
# NOTE(review): `none=True` presumably tells AnomalyDetector to skip the model setup — confirm.
anomaly_detector = AnomalyDetector(None, None, None, none=True)
# Create the patches and binary masks for the first test image
patches, binary_masks = anomaly_detector.create_patches_and_masks(images[0])

# Reverse the normalization process done by data module
# to avoid the following matplotlib warning:
# WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
def denormalize(tensor):
    """Undo a per-channel mean/std normalization and clamp the result to [0, 1].

    Args:
        tensor: Normalized image tensor with 3 channels whose shape broadcasts
            against (1, 3, 1, 1), e.g. (N, 3, H, W).

    Returns:
        A new tensor with values in [0, 1]; the input tensor is not modified.
    """
    # Create the statistics on the input's device so that CUDA tensors work too
    mean = torch.tensor([0.485, 0.456, 0.406], device=tensor.device).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=tensor.device).view(1, 3, 1, 1)
    # Out-of-place: the arithmetic allocates a new tensor
    restored = tensor * std + mean
    # Ensure the pixel values are within [0, 1]
    return restored.clamp(0, 1)

# Denormalize every patch, writing the result back into `patches`
for i in range(len(patches)):
    patches[i] = denormalize(patches[i])

# Plot the image the patches were created from
image = images[0].squeeze(0)
image_hwc = image.permute(1, 2, 0)
plt.imshow(image_hwc.cpu().numpy())
plt.show()

# Plot all (overlapping) patches on a 20 x 20 grid. `.flat` walks the axes
# row by row, so the running index equals column + row * 20.
fig1, ax1 = plt.subplots(20, 20, figsize=(20, 20))
for index, patch_axes in enumerate(ax1.flat):
    patch_hwc = patches[index].squeeze(0).permute(1, 2, 0)
    patch_axes.imshow(patch_hwc.cpu().numpy())
    patch_axes.axis('off')

# Plot all (overlapping) binary masks on a second 20 x 20 grid; the axes
# stay visible here
fig2, ax2 = plt.subplots(20, 20, figsize=(20, 20))
for index, mask_axes in enumerate(ax2.flat):
    mask_2d = binary_masks[index].squeeze(0)
    mask_axes.imshow(mask_2d.cpu().numpy(), cmap='gray', interpolation='none')

plt.show()

Create the Anomaly Detector

In [None]:
# Assuming 'device' is either 'cuda' if a GPU is available, otherwise 'cpu'
# First CUDA device when a GPU is available, otherwise the CPU
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Local-Net: initialise the network and load its trained weights
local_net = LocalNet().to(device)
local_state_dict = torch.load('/content/pretrained_models/local_net_finetuned_xray_v1.pth', map_location=device)
local_net.load_state_dict(local_state_dict)

# Global-Net: initialise the network and load its trained weights
global_net = GlobalNet().to(device)
global_state_dict = torch.load('/content/pretrained_models/global_net_xray_v1.pth', map_location=device)
global_net.load_state_dict(global_state_dict)

# DAD-head: initialise the network and load its trained weights
dad_head = DADHead().to(device)
dad_state_dict = torch.load('/content/pretrained_models/dad_head_xray_v1.pth', map_location=device)
dad_head.load_state_dict(dad_state_dict)

# Detector that combines the three trained networks
anomaly_detector = AnomalyDetector(local_net, global_net, dad_head)


Get anomaly scores for each image in batch

In [None]:
# Compute the anomaly score maps for the images of the batch.
# NOTE(review): `images` is passed as is here, while the histogram cell calls
# `images.to(device)` first — confirm which form detect_anomalies expects.
anomaly_score_maps = anomaly_detector.detect_anomalies(images)

Visualize blended images

In [None]:
# Visualize every image together with its anomaly score map
# (alpha=0.7, plasma colormap)
for image, anomaly_map in zip(images, anomaly_score_maps):
  anomaly_detector.visualize_anomaly(image, anomaly_map, alpha=0.7, cmap=plt.cm.plasma)

Display images and binary masks for one specific pathology

In [None]:
# One batch holding the images and ground-truth masks of a single pathology
images, gt_masks = next(iter(test_data_module.test_dataloader("resection")))

SIZE = len(images)

# squeeze=False keeps `ax` two-dimensional even when the pathology has only
# a single image (otherwise `ax[0][i]` would fail)
fig, ax = plt.subplots(2, SIZE, figsize=(15, 3), squeeze=False)

# Top row: images, bottom row: ground-truth masks
for i in range(SIZE):
    ax[0][i].imshow(images[i].permute(1, 2, 0).cpu().numpy())
    ax[1][i].imshow(gt_masks[i].permute(1, 2, 0).cpu().numpy(), cmap='gray')
    ax[0][i].axis('off')
    ax[1][i].axis('off')

plt.show()

Display images and binary masks for all pathologies

In [None]:

# One batch with the images and ground-truth masks of all pathologies
images, gt_masks = next(iter(test_data_module.test_dataloader_all()))

fig, ax = plt.subplots(2, 20, figsize=(25, 2))

# Walk the two axes rows in parallel: images on top, masks below
for i, (image_axes, mask_axes) in enumerate(zip(ax[0], ax[1])):
    image_hwc = images[i].permute(1, 2, 0)
    image_axes.imshow(image_hwc.cpu().numpy())
    mask_hwc = gt_masks[i].permute(1, 2, 0)
    mask_axes.imshow(mask_hwc.cpu().numpy(), cmap='gray')
    image_axes.axis('off')
    mask_axes.axis('off')

plt.show()

Run evaluation on all pathologies

In [None]:
# Evaluate the detector on one batch containing all pathologies; `results`
# is read by the plotting cell below.
images, gt_masks = next(iter(test_data_module.test_dataloader_all()))
results = anomaly_detector.evaluate_performance(images, gt_masks)

Plot the Precision-Recall curve and the Receiver Operating Characteristic (ROC) curve

In [None]:
# Precision-Recall curve together with its area under the curve
(precision, recall), prc_auc = results['PRC'], results['AUPRC']
anomaly_detector.plot_prc(precision, recall, prc_auc, "all pathologies")

# ROC curve together with its area under the curve
fpr, tpr, roc_auc = (results[key] for key in ('FPR', 'TPR', 'AUROC'))
anomaly_detector.plot_roc(fpr, tpr, roc_auc, "all pathologies")

Get AUROC/AUPRC per pathology

In [None]:
# Pathology names passed to the test data loader
pathologies = [
        'absent_septum',
        'artefacts',
        'craniatomy',
        'dural',
        'ea_mass',
        'edema',
        'encephalomalacia',
        'enlarged_ventricles',
        'intraventricular',
        'lesions',
        'mass',
        'posttreatment',
        'resection',
        'sinus',
        'wml',
        'other'
    ]

results_json = []

for pathology in pathologies:
    # Dedicated names: the loop must not overwrite `images`, `gt_masks` and
    # `results` of the all-pathology evaluation, otherwise re-running the
    # PRC/ROC plotting cell would silently plot the last pathology instead.
    pathology_images, pathology_masks = next(iter(test_data_module.test_dataloader(pathology)))
    pathology_results = anomaly_detector.evaluate_performance(pathology_images, pathology_masks)
    results_json.append({
        'pathology': pathology,
        'auroc': pathology_results['AUROC'],
        'auprc': pathology_results['AUPRC'],
    })

# Report both metrics for every pathology
for result in results_json:
    print(f"Pathology: {result['pathology']}")
    print(f"AUROC: {result['auroc']}")
    print(f"AUPRC: {result['auprc']}")


Create histograms

In [None]:
# Run the detector on one batch containing all pathologies (moved to `device` first).
# NOTE(review): `results` here holds the output of detect_anomalies, whereas the
# evaluation cells store the dict returned by evaluate_performance under the
# same name — re-run the evaluation cell before plotting PRC/ROC again.
images, gt_masks = next(iter(test_data_module.test_dataloader_all()))
results = anomaly_detector.detect_anomalies(images.to(device))

# Plot the histogram from the detector output and the ground-truth masks
anomaly_detector.plot_histogram(results, gt_masks)