# Unsupervised Anomaly Detection on fastMRI

### *Run these cells only when working in Google Colab (uncomment the commands first)*

In [None]:
# # Clone the repository
# !git clone https://github.com/compai-lab/mad_seminar_s23.git
# # Move all content to the current directory
# !mv ./mad_seminar_s23/* ./
# # Remove the empty directory
# !rm -rf mad_seminar_s23/
# # Download the data
# !wget https://syncandshare.lrz.de/dl/fiNRBX6FN1Vh67NqugzqKo/data.zip -P ./data/
# # Extract the data
# !unzip -q ./data/data.zip -d ./data/

In [None]:
# # Install pytorch-lightning
# !pip install -q pytorch-lightning

## Imports

In [None]:
import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch
import yaml

from trainer.trainer import AutoencoderModel
from data_loader import TrainDataModule

# autoreload imported modules
%load_ext autoreload
%autoreload 2

## Load the config

In [None]:
# Read the experiment configuration. The encoding is passed explicitly so the
# file is decoded the same way regardless of the platform's default locale.
with open('./configs/autoencoder_config.yaml', 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

# Reproducibility: seed via Lightning with the seed stored in the config
pl.seed_everything(config['seed'])

## Load and visualize the data

In [None]:
# Build the training data module from the values stored in the config
train_data_module = TrainDataModule(
    data_dir=config['train_data_dir'],
    target_size=config['target_size'],
    batch_size=config['batch_size'])

# Plot some images from the first training batch
batch = next(iter(train_data_module.train_dataloader()))

# Show at most 5 images: a batch holding fewer than 5 samples (e.g. a small
# configured batch_size) would otherwise raise an IndexError below.
num_images = min(5, len(batch))

# squeeze=False guarantees a 2-D axes array even for a single column, so
# taking row 0 always yields a 1-D array of axes that can be indexed by i.
fig, ax = plt.subplots(1, num_images, figsize=(15, 5), squeeze=False)
ax = ax[0]
for i in range(num_images):
    ax[i].imshow(batch[i].squeeze(), cmap='gray')
    ax[i].axis('off')
plt.show()

## Prepare model

In [None]:
# Build the autoencoder from the loaded configuration
model = AutoencoderModel(config)

# The epoch budget comes from the config and is handed to the Lightning trainer
max_epochs = config['num_epochs']
trainer = pl.Trainer(
    max_epochs=max_epochs,
)

## Run training

In [None]:
# Fit the autoencoder on the batches provided by the training data module
trainer.fit(
    model=model,
    datamodule=train_data_module,
)

## Evaluation

In [None]:
# Take one batch of training images to inspect the reconstructions
batch = next(iter(train_data_module.train_dataloader()))

# Switch to evaluation mode so that mode-dependent layers (e.g. dropout or
# batch normalisation, if the model has any) behave deterministically,
# whatever mode the preceding training run left the model in.
model.eval()

# No gradients are needed for inference
with torch.inference_mode():
    results = model.detect_anomaly(batch)
    reconstructions = results['reconstruction']

# Show at most 5 pairs: indexing past the end of a smaller batch would raise
# an IndexError.
num_images = min(5, len(batch), len(reconstructions))

# Plot images (top row) and reconstructions (bottom row).
# squeeze=False keeps the axes array 2-D even when only one column is drawn.
fig, ax = plt.subplots(2, num_images, figsize=(15, 5), squeeze=False)
for i in range(num_images):
    ax[0][i].imshow(batch[i].squeeze(), cmap='gray')
    ax[0][i].axis('off')
    ax[1][i].imshow(reconstructions[i].squeeze(), cmap='gray')
    ax[1][i].axis('off')
# Render explicitly, consistent with the earlier visualisation cell
plt.show()

In [None]:
from data_loader import get_all_test_dataloaders

# Create the test dataloaders from the split directory, image size and batch
# size stored in the config
test_dataloaders = get_all_test_dataloaders(
    config['split_dir'],
    config['target_size'],
    config['batch_size'],
)

In [None]:
# Bare expression: the notebook echoes the object returned by
# get_all_test_dataloaders (presumably a mapping from test-set name to
# dataloader, since it is indexed by a string key later — confirm in data_loader)
test_dataloaders

In [None]:
# Fetch the first batch from the 'absent_septum' test dataloader
absent_septum_loader = test_dataloaders['absent_septum']
batch = next(iter(absent_septum_loader))