# Unsupervised Anomaly Detection on fastMRI

### *Run these cells only when in Google Colab (uncomment the commands first)*

In [None]:
# # Clone the repository
# !git clone https://github.com/compai-lab/mad_seminar_s23.git
# # Move all content to the current directory
# !mv ./mad_seminar_s23/* ./
# # Remove the empty directory
# !rm -rf mad_seminar_s23/
# # Download the data
# !wget <link you got from your supervisor> -P ./data/
# # Extract the data
# !unzip -q ./data/data.zip -d ./data/

In [None]:
# !pip install pytorch_lightning --quiet

In [None]:
import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch
import yaml

from model.model import AutoencoderModel
from data_loader import TrainDataModule, get_all_test_dataloaders

# autoreload imported modules
# IPython magics (notebook-only, not plain Python): enable the autoreload
# extension so edits to the project modules imported above (model.model,
# data_loader) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

## Load the config

In [None]:
# Parse the experiment configuration; safe_load only constructs plain Python
# objects (dicts, lists, scalars) from the YAML file.
with open('./configs/autoencoder_config.yaml', 'r') as config_file:
    config = yaml.safe_load(config_file)

# Seed the random number generators through Lightning so runs are repeatable.
pl.seed_everything(config['seed'])

## Load and visualize the data

In [None]:
# Build the training data module from the config values.
train_data_module = TrainDataModule(
    data_dir=config['train_data_dir'],
    target_size=config['target_size'],
    batch_size=config['batch_size'])

# Grab a single training batch to sanity-check the data pipeline.
batch = next(iter(train_data_module.train_dataloader()))

# Print statistics
print(f"Batch shape: {batch.shape}")
print(f"Batch min: {batch.min()}")
print(f"Batch max: {batch.max()}")

# Plot up to 5 images. Capping at len(batch) avoids an IndexError when the
# configured batch size is smaller than 5; squeeze=False keeps `ax` 2-D even
# when only one column is created.
num_images = min(5, len(batch))
fig, ax = plt.subplots(1, num_images, figsize=(15, 5), squeeze=False)
for i in range(num_images):
    ax[0][i].imshow(batch[i].squeeze(), cmap='gray')
    ax[0][i].axis('off')
plt.show()

## Tensorboard

In [None]:
# IPython magics (notebook-only, not plain Python): load the TensorBoard
# notebook extension and serve logs from lightning_logs/. This is presumably
# where TensorBoardLogger(save_dir='./') in the trainer cell writes (its
# default experiment name) - confirm if the logger name is changed.
%load_ext tensorboard
%tensorboard --logdir lightning_logs/

## Prepare model

In [None]:
# Build the autoencoder from the loaded configuration.
model = AutoencoderModel(config)

# Training length comes from the config; metrics are written both to
# TensorBoard and to CSV files, each rooted in the current directory.
num_epochs = config['num_epochs']
experiment_loggers = [
    pl.loggers.TensorBoardLogger(save_dir='./'),
    pl.loggers.CSVLogger(save_dir='./'),
]
trainer = pl.Trainer(max_epochs=num_epochs, logger=experiment_loggers)

## Run training

In [None]:
# Fit the autoencoder; the data module supplies the training/validation loaders.
trainer.fit(model=model, datamodule=train_data_module)

## Evaluation

In [None]:
# Reconstructions from the validation set
batch = next(iter(train_data_module.val_dataloader()))

# Switch to evaluation mode before inference so layers such as dropout or
# batch norm do not behave as during training (fit() may leave the module
# in training mode).
model.eval()
with torch.inference_mode():
    results = model.detect_anomaly(batch)
    reconstructions = results['reconstruction']

# Plot up to 5 inputs (top row) and their reconstructions (bottom row).
# Capping at len(batch) avoids an IndexError on small batches; squeeze=False
# keeps `ax` 2-D even when only one column is created.
num_images = min(5, len(batch))
fig, ax = plt.subplots(2, num_images, figsize=(15, 5), squeeze=False)
for i in range(num_images):
    ax[0][i].imshow(batch[i].squeeze(), cmap='gray')
    ax[0][i].axis('off')
    ax[1][i].imshow(reconstructions[i].squeeze(), cmap='gray')
    ax[1][i].axis('off')
plt.show()

## Visualize pathology and labels

In [None]:
# Build the test dataloaders; the result is indexed by pathology name
# (e.g. 'mass') in the cells below.
test_dataloaders = get_all_test_dataloaders(
    config['split_dir'],
    config['target_size'],
    config['batch_size'],
)

In [None]:
# One column per pathology. The rows show an input image, its pos_labels
# entry and its neg_masks entry, as returned by that pathology's test
# dataloader.
diseases = ['absent_septum', 'edema', 'enlarged_ventricles', 'mass', 'dural']
fig, ax = plt.subplots(3, len(diseases), figsize=(15, 5), squeeze=False)
for col, disease in enumerate(diseases):
    inputs, pos_labels, neg_masks = next(iter(test_dataloaders[disease]))
    # Always show the first sample of the batch. Indexing by the column
    # number would raise IndexError whenever a pathology's first batch holds
    # fewer than col + 1 samples.
    for row, image in enumerate((inputs[0], pos_labels[0], neg_masks[0])):
        ax[row][col].imshow(image.squeeze(), cmap='gray')
        ax[row][col].axis('off')
    ax[0][col].set_title(disease)
plt.show()

In [None]:
from evaluate import Evaluator

# Ensure the model runs in evaluation mode regardless of which earlier cells
# were executed; calling eval() again is harmless if it is already set.
model.eval()

# Evaluate on the 'mass' test split only. The Evaluator receives the model,
# the model's current device and that split's dataloader.
evaluator = Evaluator(model, model.device, test_dataloaders['mass'])

fig_metrics, metrics = evaluator.evaluate()