## Imports

In [None]:
from argparse import Namespace

import numpy as np
import torch
from torch.utils.data import DataLoader

from mednist import (
    MEDNISTDIR,
    download_mednist,
    get_mednist_files,
    MedNISTDataset,
    MedNISTTestDataset
)
from models import Autoencoder
from utils import plot

# Select device to train on: 'cuda' when torch reports CUDA as available,
# otherwise fall back to 'cpu'. Stored as a plain string.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Autoreload modules without restarting the kernel.
# NOTE: these are IPython magics, so this cell only runs inside
# Jupyter/IPython, not as a plain Python script.
%load_ext autoreload
%autoreload 2

## Config

In [None]:
# Global experiment configuration; later cells add further attributes.
config = Namespace(
    batch_size=128,
    val_split=0.8,   # fraction of the HeadCT file list at which validation starts
    test_split=0.9,  # fraction of the file list at which the held-out test split starts
    device=device,
)

## Download MedNIST and create DataLoaders

In [None]:
# Download MedNIST into MEDNISTDIR if necessary
download_mednist(MEDNISTDIR)

# Get all HeadCT (in-distribution) and Hand (out-of-distribution) files
head_ct_files = get_mednist_files(MEDNISTDIR, 'HeadCT')
hand_files = get_mednist_files(MEDNISTDIR, 'Hand')

# Split points for the HeadCT list:
#   [0, val_split_idx)              -> training
#   [val_split_idx, test_split_idx) -> validation
#   [test_split_idx, end)           -> test
val_split_idx = int(len(head_ct_files) * config.val_split)
test_split_idx = int(len(head_ct_files) * config.test_split)

# Training: first `val_split` fraction of HeadCT (8000 if there are 10000 images)
train_files = head_ct_files[:val_split_idx]

# Validation: HeadCT between `val_split` and `test_split` (1000 of 10000)
val_files = head_ct_files[val_split_idx:test_split_idx]

# Test: the held-out HeadCT tail plus the same fraction of Hand images.
# The Hand split index is computed from hand_files itself. Reusing
# test_split_idx (derived from the HeadCT count) only gives the intended
# fraction when both classes happen to have the same size, and can silently
# give zero Hand images otherwise.
hand_test_split_idx = int(len(hand_files) * config.test_split)
test_files_1 = head_ct_files[test_split_idx:]
test_files_2 = hand_files[hand_test_split_idx:]
test_labels_1 = [0] * len(test_files_1)  # HeadCT are in-distribution -> 0
test_labels_2 = [1] * len(test_files_2)  # Hand are out-of-distribution -> 1
test_files = test_files_1 + test_files_2
test_labels = test_labels_1 + test_labels_2

# Create a training dataset with HeadCT files
train_ds = MedNISTDataset(train_files)
trainloader = DataLoader(train_ds, batch_size=config.batch_size, shuffle=True)

# Create a validation dataset with HeadCT files
val_ds = MedNISTDataset(val_files)
valloader = DataLoader(val_ds, batch_size=config.batch_size, shuffle=False)

# Create a test dataset with HeadCT and Hand files. The loader is shuffled,
# presumably so that one preview batch mixes both classes; the AUROC computed
# later does not depend on sample order.
test_ds = MedNISTTestDataset(test_files, test_labels)
testloader = DataLoader(test_ds, batch_size=config.batch_size, shuffle=True)

print('Train dataset size:', len(train_ds))
print('Batches in trainloader:', len(trainloader))

### Show some images

In [None]:
# Preview channel 0 of the first ten images of one batch from each loader.
# list(tensor) splits along dim 0, exactly like iterating the tensor.
plot(list(next(iter(trainloader))[:10, 0]))
plot(list(next(iter(valloader))[:10, 0]))
imgs, labels = next(iter(testloader))
plot(list(imgs[:10, 0]), titles=labels[:10])

## Create an Autoencoder

In [None]:
# Build one Autoencoder just to print its architecture. The training cell
# creates a fresh instance (and moves that one to the device); this instance
# is not moved to the device and is not trained.
config.latent_dim = 128
ae = Autoencoder(latent_dim=config.latent_dim)
print(ae)

In [None]:
def ae_train_step(ae, x, optimizer, device):
    """Run a single optimisation step of ``ae`` on the batch ``x``.

    The model is switched to training mode and gradients are cleared.
    ``x`` is moved to ``device`` and reconstructed. The loss
    ``ae.loss_function(x, reconstruction)`` is back-propagated and one
    ``optimizer`` step is applied.

    Returns:
        The batch loss as a Python float.
    """
    ae.train()
    optimizer.zero_grad()
    batch = x.to(device)
    step_loss = ae.loss_function(batch, ae(batch))  # MSE loss
    step_loss.backward()
    optimizer.step()
    return step_loss.item()


def ae_val_step(ae, x, device):
    """Evaluate ``ae`` on the batch ``x`` without tracking gradients.

    Returns:
        ``(loss, reconstruction)``. ``loss`` is the batch loss as a Python
        float. ``reconstruction`` is the output tensor, left on ``device``.
    """
    ae.eval()
    batch = x.to(device)
    with torch.no_grad():
        reconstruction = ae(batch)
        loss_value = ae.loss_function(batch, reconstruction).item()
    return loss_value, reconstruction

In [None]:
 def train_ae(config, ae, optimizer, trainloader, valloader):
    i_step = 0
    i_epoch = 0
    all_losses = []
    losses = []
    ae.train()
    while True:
        for x in trainloader:
            # Train step
            loss = ae_train_step(ae, x, optimizer, config.device)

            # Store metrics
            losses.append(loss)
            all_losses.append(loss)

            # Log
            if i_step % config.log_frequency == 0:
                print(f'Iteration {i_step} - train loss {np.mean(losses):.4f}')
                losses = []

            # Validate
            if i_step % config.val_frequency == 0:
                val_loss, x_val, x_recon = validate_ae(config, ae, valloader)
                print(f'Iteration {i_step} - val loss {val_loss:.4f}')
                residual = torch.abs(x_val - x_recon)
                plot([x_val[0, 0], x_recon[0, 0], residual[0, 0]],
                     titles=['input', 'reconstruction', 'residual'])

            # Finish
            i_step += 1
            if i_step >= config.num_steps:
                print('Finished training')
                return
        i_epoch += 1
        print(f'Finished epoch {i_epoch}')


def validate_ae(config, ae, valloader):
    """Compute the validation loss of ``ae`` over every batch of ``valloader``.

    Args:
        config: Namespace providing ``device``.
        ae: Autoencoder to evaluate.
        valloader: Iterable of validation batches (tensors whose first
            dimension is the batch size).

    Returns:
        ``(mean_loss, x, x_recon)``:
        - ``mean_loss`` is the batch-size-weighted mean of the per-batch
          losses.
        - ``x`` is the last validation batch, moved to the CPU.
        - ``x_recon`` is its reconstruction, moved to the CPU.
        ``train_ae`` uses ``x`` and ``x_recon`` for plotting.

    Raises:
        ValueError: If ``valloader`` yields no batches. Previously this
            produced a NaN warning and then an UnboundLocalError on ``x``.
    """
    losses = []
    batch_sizes = []
    for x in valloader:
        loss, x_recon = ae_val_step(ae, x, config.device)
        losses.append(loss)
        batch_sizes.append(x.shape[0])
    if not losses:
        raise ValueError('valloader yielded no batches; cannot validate')
    # Weight each batch by its size so a smaller final batch does not count as
    # much as a full one. NOTE(review): assumes loss_function returns a
    # per-sample mean (e.g. MSE with mean reduction) -- TODO confirm.
    mean_loss = np.average(losses, weights=batch_sizes)
    return mean_loss, x.cpu(), x_recon.cpu()

In [None]:
# Train config
config.lr = 1e-3
config.num_steps = 1000
config.log_frequency = 10
config.val_frequency = 100

# Re-initialize the Autoencoder so training starts from fresh weights.
# Place it on config.device, the same device ae_train_step / ae_val_step move
# every batch to. Model and data then stay on one device even if
# config.device is changed.
ae = Autoencoder(latent_dim=config.latent_dim).to(config.device)

# Optimizer
optimizer = torch.optim.Adam(ae.parameters(), lr=config.lr)

# Train
print('Start training...')
train_ae(config, ae, optimizer, trainloader, valloader)

In [None]:
# Testing

def ae_test_step(ae, x, device):
    """Reconstruct ``x`` with ``ae`` in eval mode, without gradient tracking.

    Returns:
        ``(x, x_recon, residual)``:
        - ``x`` is the input moved to ``device``.
        - ``x_recon`` is the reconstruction.
        - ``residual`` is the element-wise absolute error ``|x - x_recon|``.
    """
    ae.eval()
    inputs = x.to(device)
    with torch.no_grad():
        reconstruction = ae(inputs)
    residual = (inputs - reconstruction).abs()
    return inputs, reconstruction, residual

def test_ae(config, ae, testloader):
    """Score every test sample by its mean absolute reconstruction error.

    Args:
        config: Namespace providing ``device``.
        ae: Trained autoencoder.
        testloader: Iterable of ``(x, y)`` batches, where ``y`` holds the
            labels. In this notebook 0 = HeadCT and 1 = Hand. ``y`` must
            support ``.numpy()``, i.e. be a CPU tensor.

    Returns:
        ``(scores, labels)``: lists aligned sample by sample. A higher score
        means a worse reconstruction.
    """
    scores = []
    labels = []
    ae.eval()
    for x, y in testloader:
        _, _, residual = ae_test_step(ae, x, config.device)
        # Mean over every non-batch dimension. flatten(1) handles inputs of
        # any rank >= 2, not only (batch, C, H, W).
        anomaly_score = residual.flatten(start_dim=1).mean(dim=1)
        scores.extend(anomaly_score.cpu().numpy())
        labels.extend(y.numpy())

    return scores, labels

scores, labels = test_ae(config, ae, testloader)

In [None]:
# Evaluation: AUROC of the anomaly scores against the in/out-of-distribution
# labels. sklearn is imported here only for this quick check (remove later).
from sklearn.metrics import roc_auc_score
auroc = roc_auc_score(y_true=labels, y_score=scores)
print(auroc)