In [None]:
from typing import TypeVar

import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import ConcatDataset, DataLoader, Dataset, Subset, random_split
from torch import nn
from tqdm.notebook import tqdm

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU and
# say so, because training will be noticeably slower.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type != 'cuda':
    print('WARNING: not using CUDA')


class MNISTDataSet(MNIST):
    """MNIST split stored under ``../.data`` whose images are converted with ``ToTensor``.

    The data is downloaded on construction (``download=True``).

    Args:
        train: Forwarded to ``MNIST`` to choose between the training and the
            test split.
    """

    def __init__(self, train: bool):
        to_tensor = transforms.ToTensor()
        super().__init__(
            root='../.data',
            train=train,
            transform=to_tensor,
            download=True,
        )


class Encoder(nn.Module):
    """Convolutional encoder: three stride-2 stages, each Conv2d -> BatchNorm2d -> ReLU.

    Channel widths go 1 -> 16 -> 32 -> 64. No padding is used, so a 28x28 input
    shrinks to 14x14, then 6x6, then 2x2.
    """

    def __init__(self):
        super().__init__()

        # Use the same encoder architecture as in the paper.
        # Each entry is (in_channels, out_channels, kernel_size).
        # NOTE(review): the first stage uses kernel_size=1 while the others use 3;
        # confirm this matches the paper.
        stage_specs = (
            (1, 16, 1),
            (16, 32, 3),
            (32, 64, 3),
        )

        layers = []
        for in_channels, out_channels, kernel_size in stage_specs:
            layers += [
                nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=2),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            ]

        self.sequence = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the stacked stages and return the resulting feature map."""
        return self.sequence(x)


class Decoder(nn.Module):
    """Transposed-convolution decoder that upsamples a 64-channel feature map to 1 channel.

    The two hidden stages are ConvTranspose2d -> BatchNorm2d -> ReLU; the final
    ConvTranspose2d is left without normalisation or activation. For a 2x2
    input the spatial size grows to 6x6, then 13x13, then 28x28.
    """

    def __init__(self):
        super().__init__()

        # Hidden stages as (in_channels, out_channels, output_padding).
        hidden_specs = (
            (64, 64, 1),
            (64, 32, 0),
        )

        layers = []
        for in_channels, out_channels, output_padding in hidden_specs:
            layers += [
                nn.ConvTranspose2d(
                    in_channels,
                    out_channels,
                    kernel_size=3,
                    stride=2,
                    output_padding=output_padding,
                ),
                nn.BatchNorm2d(num_features=out_channels),
                nn.ReLU(),
            ]

        # Output stage: raw transposed convolution only.
        layers.append(
            nn.ConvTranspose2d(32, 1, kernel_size=3, stride=2, output_padding=1)
        )
        self.sequence = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the stacked stages and return the reconstruction."""
        return self.sequence(x)


class AutoEncoder(nn.Module):
    """Autoencoder that feeds the output of an ``Encoder`` into a ``Decoder``."""

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        """Encode ``x`` and return the decoder's reconstruction of it."""
        latent = self.encoder(x)
        reconstruction = self.decoder(latent)
        return reconstruction

T = TypeVar('T')

def extract_single_target_class(dataset: Dataset[T], target, targets=None) -> tuple[Subset[T], Subset[T]]:
    """Split ``dataset`` into the samples labelled ``target`` and all other samples.

    Args:
        dataset: Dataset to partition. When ``targets`` is omitted it must
            expose a ``targets`` attribute holding one label per sample.
        target: Label value selecting the first subset.
        targets: Optional per-sample labels to use instead of
            ``dataset.targets``. Accepts anything ``np.asarray`` can convert
            (list, NumPy array, CPU tensor) with exactly ``len(dataset)`` entries.

    Returns:
        ``(matching, others)``: two ``Subset`` views of ``dataset``. The first
        holds the samples whose label equals ``target``, the second holds the
        rest. Both keep the original sample order.

    Raises:
        ValueError: If the labels are not one-dimensional with one entry per
            sample.
    """
    # Normalise to a NumPy array so the comparison is element-wise even for
    # plain Python lists (``[3, 1] == 3`` would otherwise be a scalar False).
    labels = np.asarray(dataset.targets if targets is None else targets)
    if labels.shape != (len(dataset),):
        raise ValueError(
            f'expected {len(dataset)} labels (one per sample), got shape {labels.shape}'
        )

    mask = labels == target
    return Subset(dataset, np.flatnonzero(mask)), Subset(dataset, np.flatnonzero(~mask))

NORMAL_TARGET_CLASS = 3
# Seed for the train/test split below, so that the partition (and therefore the
# reported metrics' test population) is the same after every kernel restart.
SPLIT_SEED = 0

# Separate the "normal" digit from every other digit in both official MNIST splits.
testing_normal_set, testing_outlier_set = extract_single_target_class(
    MNISTDataSet(train=False), NORMAL_TARGET_CLASS
)
training_normal_set, training_outlier_set = extract_single_target_class(
    MNISTDataSet(train=True), NORMAL_TARGET_CLASS
)

# Further split the normal samples of the 'training' set into the part the
# autoencoder is trained on (2/3) and extra normal samples for testing (1/3).
# A dedicated generator keeps the split reproducible without touching the
# global RNG state.
training_set, testing_normal_set_from_training_set = random_split(
    training_normal_set,
    [2/3, 1/3],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Everything that is not used for training ends up in the testing sets,
# including all outliers from the official training split.
testing_normal_set = ConcatDataset((testing_normal_set, testing_normal_set_from_training_set))
testing_outlier_set = ConcatDataset((testing_outlier_set, training_outlier_set))
testing_set = ConcatDataset((testing_normal_set, testing_outlier_set))

print(f'Training set (normal only):        {len(training_set)} samples')
print(f'Testing set (normal only):         {len(testing_normal_set)} samples')
print(f'Testing set (outliers only):       {len(testing_outlier_set)} samples')
print(f'Testing set (normal and outliers): {len(testing_set)} samples')
print(f'Total:                             {len(training_set) + len(testing_set)} samples')

# NOTE(review): shuffle=False replays the batches in the same order every
# epoch; confirm this is intended for training.
training_loader = DataLoader(training_set, batch_size=32, shuffle=False)

# Preview grid: top row shows the first 5 training samples, bottom row the
# first 5 testing outliers. Only the first channel of each image is drawn.
preview_fig, axes = plt.subplots(2, 5)
for axis_row, sample_source in zip(axes, (training_set, testing_outlier_set)):
    for ax, sample in zip(axis_row, sample_source):
        ax.imshow(sample[0][0])
        ax.axis('off')
plt.tight_layout()

# Show the layer structure of a freshly constructed (untrained) autoencoder.
display(AutoEncoder())

# Sanity check: a single 1x28x28 image must come back with the same shape.
probe_model = AutoEncoder()
probe_output = probe_model(training_set[0][0].reshape(1, 1, 28, 28))
output_shape = probe_output.shape[1:]
assert output_shape == (1, 28, 28), output_shape

In [None]:
class ModelTrainer:
    """Reconstruction-style training loop with nested progress bars.

    Every batch yielded by ``dataloader`` is unpacked as ``(data, _)``; the
    second element is ignored and the loss is ``loss_fn(model(data), data)``.
    """

    def __init__(
        self, *,
        model: nn.Module,
        dataloader: DataLoader,
        loss_fn: nn.Module,
        optimizer: torch.optim.Optimizer,
        num_epoch: int,
    ):
        """Store the components used by :meth:`train`.

        Args:
            model: Network to optimise.
            dataloader: Iterable of ``(data, _)`` batches.
            loss_fn: Callable invoked as ``loss_fn(reconstructed, data)``.
            optimizer: Optimiser stepped once per batch.
            num_epoch: Number of passes over ``dataloader``.
        """
        self.model = model
        self.dataloader = dataloader
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.num_epoch = num_epoch

    def train(self) -> nn.Module:
        """Put the model in training mode, run ``num_epoch`` epochs and return the model."""
        self.model.train(True)
        with tqdm(
            range(self.num_epoch),
            total=self.num_epoch,
            desc='Training',
            unit='epoch',
        ) as pbar:
            for epoch in pbar:
                avg_loss = self._train_one_epoch(epoch, self.num_epoch)
                pbar.set_postfix({'Avg. loss': avg_loss})

        return self.model

    def _model_device(self) -> torch.device:
        """Return the device of the model's first parameter (CPU if it has none)."""
        first_param = next(self.model.parameters(), None)
        return torch.device('cpu') if first_param is None else first_param.device

    def _train_one_epoch(self, epoch_index, num_epoch):
        """Run one pass over ``self.dataloader`` and return the mean batch loss.

        Args:
            epoch_index: Zero-based epoch number, used only for the progress label.
            num_epoch: Total number of epochs, used only for the progress label.

        Returns:
            The average of the per-batch losses, or ``nan`` if the dataloader
            produced no batches.
        """
        # Adapted for autoencoder from:
        # https://pytorch.org/tutorials/beginner/introyt/trainingyt.html

        # Move batches to wherever the model lives instead of relying on a
        # notebook-level ``device`` variable.
        batch_device = self._model_device()

        training_loss = 0.0
        num_batches = 0
        with tqdm(
            self.dataloader,
            desc=f'Epoch {epoch_index + 1}/{num_epoch}',
            unit='batch',
            leave=False,
        ) as pbar:
            for data, _ in pbar:
                data = data.to(batch_device)
                self.optimizer.zero_grad()
                reconstructed = self.model(data)
                loss = self.loss_fn(reconstructed, data)
                loss.backward()
                self.optimizer.step()

                batch_loss = loss.item()
                training_loss += batch_loss
                num_batches += 1
                pbar.set_postfix({'Loss': batch_loss})

        # Average over the batches this trainer actually processed (not over
        # some unrelated global loader); an empty loader has no meaningful mean.
        if num_batches == 0:
            return float('nan')
        return training_loss / num_batches

# Build the autoencoder on the selected device, then replace `model` with the
# trained instance returned by the trainer (MSE reconstruction loss, Adam).
model = AutoEncoder().to(device)
model = ModelTrainer(
    num_epoch=60,
    optimizer=torch.optim.Adam(model.parameters(), lr=1e-4),
    loss_fn=nn.MSELoss(),
    dataloader=training_loader,
    model=model,
).train()

In [None]:
def plot_reconstruction(image):
    """Plot ``image`` next to the notebook-level ``model``'s reconstruction of it.

    Uses the notebook-level ``model`` and ``device``; ``model`` is switched to
    eval mode as a side effect. Only the first channel of each image is drawn.

    Args:
        image: A single sample without batch dimension; a batch dimension of
            size 1 is added before it is passed to the model.
            (assumes a CPU tensor shaped (channels, height, width) - TODO confirm)
    """
    model.eval()

    batch = image.unsqueeze(0)
    with torch.no_grad():
        recon = model(batch.to(device))

    panels = (
        ('Original', batch[0][0]),
        ('Reconstructed', recon[0][0].cpu().numpy()),
    )

    fig, axes = plt.subplots(1, 2)
    for ax, (title, pixels) in zip(axes, panels):
        ax.imshow(pixels)
        ax.set_title(title)
        ax.axis('off')

# Reconstruct one training sample, one held-out normal sample and two outliers.
for sample_source, sample_index in (
    (training_set, 0),
    (testing_normal_set, 10),
    (testing_outlier_set, 10),
    (testing_outlier_set, 20),
):
    plot_reconstruction(sample_source[sample_index][0])


In [None]:
import numpy as np

def validate_model(model):
    """Score every sample of the notebook-level ``testing_set`` with ``model``.

    ``model`` is switched to eval mode and evaluated without gradients in
    batches of 256. Relies on the notebook-level ``testing_set``, ``device``
    and ``NORMAL_TARGET_CLASS``.

    Args:
        model: Module whose output for a batch has the same shape as the batch.

    Returns:
        ``(y_true, y_score)`` as NumPy arrays with one entry per test sample:
        ``y_true`` is True where the label equals ``NORMAL_TARGET_CLASS``;
        ``y_score`` is the negated per-sample mean squared reconstruction error.
    """
    model.eval()
    loader = DataLoader(testing_set, batch_size=256)

    labels, scores = [], []
    with torch.no_grad():
        for batch, batch_targets in tqdm(loader, total=len(loader), unit='batch'):
            batch = batch.to(device)
            squared_error = (model(batch) - batch) ** 2
            per_sample_mse = squared_error.mean(dim=(1, 2, 3)).cpu().numpy()
            is_normal = (batch_targets == NORMAL_TARGET_CLASS).cpu().numpy()

            labels.extend(is_normal)
            # Lower MSE = higher score, so negate the value
            scores.extend(-per_sample_mse)

    return np.asarray(labels), np.asarray(scores)

# Score the whole testing set with the trained model: y_true marks the samples
# of the normal class, y_score is the negated reconstruction error per sample.
y_true, y_score = validate_model(model)

In [None]:
from sklearn.metrics import roc_curve, roc_auc_score

# ROC curve of the autoencoder's scores, with the AUC shown in the legend and
# a square plotting area.
false_positive_rate, true_positive_rate = roc_curve(y_true, y_score)[:2]
auc = roc_auc_score(y_true, y_score)
plt.plot(
    false_positive_rate,
    true_positive_rate,
    label=f'AE (AUC = {auc:g})',
)
plt.legend()
plt.gca().set_box_aspect(1)