In [None]:
from typing import TypeVar

import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset, Subset
from torch import nn
from tqdm.notebook import tqdm

# Run on the GPU when PyTorch reports one; otherwise fall back to the CPU and
# say so, since everything below will be noticeably slower.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cpu':
    print('WARNING: not using CUDA')


class MNISTDataSet(MNIST):
    """MNIST with the storage location and preprocessing fixed for this notebook.

    The data is looked up under ``../.data`` (with ``download=True`` passed to
    ``MNIST``), and ``transforms.ToTensor()`` is installed as the transform.
    """

    def __init__(self, train: bool):
        """Open one MNIST split; ``train`` is forwarded to ``MNIST`` unchanged."""
        mnist_options = dict(
            root='../.data',
            train=train,
            download=True,
            transform=transforms.ToTensor(),
        )
        super().__init__(**mnist_options)

class MNISTDataLoader(DataLoader):
    """``DataLoader`` preconfigured with this notebook's batching defaults.

    Called with only a dataset it behaves as before: batches of 32, shuffled.
    The defaults can be overridden, e.g. ``shuffle=False`` for an ordered pass
    over an evaluation set.

    Args:
        dataset: Dataset to draw batches from.
        batch_size: Number of samples per batch. Defaults to 32.
        shuffle: Whether to draw samples in random order. Defaults to True.
        **loader_kwargs: Any further ``DataLoader`` keyword arguments
            (``num_workers``, ``drop_last``, ...), forwarded unchanged.
    """

    def __init__(
        self,
        dataset: Dataset,
        batch_size: int = 32,
        shuffle: bool = True,
        **loader_kwargs,
    ):
        super().__init__(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            **loader_kwargs,
        )


class Encoder(nn.Module):
    """Convolutional encoder built from three ``Conv2d -> BatchNorm2d -> ReLU`` stages.

    Channel widths go 1 -> 16 -> 32 -> 64. Every convolution uses stride 2 and
    no padding, so each stage reduces the spatial size.
    """

    def __init__(self):
        super().__init__()

        # Use the same encoder architecture as in the paper:
        # NOTE(review): the first convolution has kernel_size=1 with stride=2,
        # so it only ever reads every other pixel in each direction. If the
        # paper uses 3x3 kernels throughout this is likely a typo; kernel_size=3
        # would produce the same feature-map sizes for 28x28 inputs. Confirm
        # against the paper before changing.
        convolutions = (
            nn.Conv2d(1, 16, kernel_size=1, stride=2),
            nn.Conv2d(16, 32, kernel_size=3, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, stride=2),
        )

        # Follow every convolution with a batch norm sized to its output
        # channels and a ReLU.
        stages = [
            layer
            for conv in convolutions
            for layer in (conv, nn.BatchNorm2d(conv.out_channels), nn.ReLU())
        ]

        self.sequence = nn.Sequential(*stages)

    def forward(self, x):
        """Apply all stages to ``x`` in order and return the resulting feature map."""
        return self.sequence(x)


class Decoder(nn.Module):
    """Transposed-convolution decoder mapping 64-channel features to one channel.

    The two hidden stages are ``ConvTranspose2d -> BatchNorm2d -> ReLU``. The
    final transposed convolution is left bare (no normalisation or
    activation), so the output values are unbounded.
    """

    def __init__(self):
        super().__init__()

        hidden_deconvs = (
            nn.ConvTranspose2d(64, 64, kernel_size=3, stride=2, output_padding=1),
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2),
        )
        output_deconv = nn.ConvTranspose2d(32, 1, kernel_size=3, stride=2, output_padding=1)

        # Only the hidden stages get batch norm + ReLU; the output layer is
        # appended as-is.
        stages = [
            layer
            for deconv in hidden_deconvs
            for layer in (
                deconv,
                nn.BatchNorm2d(num_features=deconv.out_channels),
                nn.ReLU(),
            )
        ]

        self.sequence = nn.Sequential(*stages, output_deconv)

    def forward(self, x):
        """Apply all stages to ``x`` in order and return the decoded tensor."""
        return self.sequence(x)


class AutoEncoder(nn.Module):
    """Chains an ``Encoder`` and a ``Decoder`` into a single module."""

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        """Encode ``x``, then decode the encoded representation and return it."""
        encoded = self.encoder(x)
        return self.decoder(encoded)

T = TypeVar('T')


def extract_single_target_class(dataset: Dataset[T], target, targets=None) -> tuple[Subset[T], Subset[T]]:
    """Split ``dataset`` into the samples labelled ``target`` and all other samples.

    Args:
        dataset: Dataset to split. Must expose a ``targets`` attribute unless
            ``targets`` is passed explicitly.
        target: Label value that selects the first subset.
        targets: Optional per-sample labels, one per item of ``dataset`` and in
            the same order. When omitted, ``dataset.targets`` is used.

    Returns:
        ``(matching, others)``: a ``Subset`` holding the samples whose label
        equals ``target`` and a ``Subset`` holding the remaining samples. Both
        keep the original sample order.

    Raises:
        ValueError: If the labels do not provide exactly one entry per sample.
    """
    # Honour an explicit label sequence; previously this argument was accepted
    # but silently ignored.
    if targets is None:
        targets = dataset.targets

    is_target = np.asarray(np.asarray(targets) == target)
    if is_target.shape != (len(dataset),):
        raise ValueError(
            f'expected one label per sample ({len(dataset)}), '
            f'got a comparison result of shape {is_target.shape}'
        )

    return (
        Subset(dataset, np.flatnonzero(is_target)),
        Subset(dataset, np.flatnonzero(~is_target)),
    )

# Digit class treated as "normal"; every other class counts as an outlier.
NORMAL_TARGET_CLASS = 3

# Split each MNIST partition into (normal class, everything else).
testing_dataset_normal, testing_dataset_outliers = extract_single_target_class(
    MNISTDataSet(train=False),
    NORMAL_TARGET_CLASS,
)
training_dataset, outliers_from_training_set = extract_single_target_class(
    MNISTDataSet(train=True),
    NORMAL_TARGET_CLASS,
)
print(f'Training set (normal only): {len(training_dataset)} samples')

# Only the normal-class part of the training partition is used for training.
training_loader = MNISTDataLoader(training_dataset)

# Show first 5 samples from training set
axes = plt.subplots(nrows=1, ncols=5)[1]
for ax, image in zip(axes, training_dataset):
    # `image` is a whole dataset item; [0][0] picks the 2-D array to plot
    # (presumably item -> image tensor -> single channel; the trainer below
    # unpacks batches as `data, _`).
    ax.imshow(image[0][0])
    ax.axis('off')

# Render the module structure of a freshly constructed autoencoder.
display(AutoEncoder())

# Sanity check: an untrained autoencoder must map one 1x28x28 image back to
# the same shape (the leading batch dimension is dropped before comparing).
first_training_image = training_dataset[0][0]
output_shape = AutoEncoder()(first_training_image.reshape(1, 1, 28, 28)).shape[1:]
assert output_shape == (1, 28, 28), output_shape

In [None]:
class ModelTrainer:
    """Reconstruction training loop: the model's output is scored against its own input.

    Each batch yielded by ``dataloader`` is unpacked as ``(data, _)``; the
    second element is ignored and ``loss_fn(model(data), data)`` is minimised
    with ``optimizer``.

    Args:
        model: Module to train (updated in place).
        dataloader: Iterable of ``(data, _)`` batches.
        loss_fn: Callable taking ``(reconstructed, data)`` and returning the loss.
        optimizer: Optimizer already bound to ``model``'s parameters.
        num_epoch: Number of full passes over ``dataloader``.
    """

    def __init__(
        self, *,
        model: nn.Module,
        dataloader: DataLoader,
        loss_fn: nn.Module,
        optimizer: torch.optim.Optimizer,
        num_epoch: int,
    ):
        self.model = model
        self.dataloader = dataloader
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.num_epoch = num_epoch

    def train(self) -> nn.Module:
        """Train for ``num_epoch`` epochs and return the trained model.

        The model is switched to training mode first. The outer progress bar
        shows the average batch loss of the most recent epoch.
        """
        self.model.train(True)
        with tqdm(
            range(self.num_epoch),
            total=self.num_epoch,
            desc='Training',
            unit='epoch',
        ) as pbar:
            for epoch in pbar:
                avg_loss = self._train_one_epoch(epoch, self.num_epoch)
                pbar.set_postfix({'Avg. loss': avg_loss})

        return self.model

    def _train_one_epoch(self, epoch_index, num_epoch):
        """Run one pass over ``self.dataloader`` and return the mean batch loss.

        Args:
            epoch_index: Zero-based epoch number (only used for the progress label).
            num_epoch: Total number of epochs (only used for the progress label).

        Returns:
            The mean of the per-batch losses, or ``nan`` if the dataloader
            yielded no batches.
        """
        # Adapted for autoencoder from:
        # https://pytorch.org/tutorials/beginner/introyt/trainingyt.html

        # Send batches to wherever the model's parameters live, so the trainer
        # does not depend on a notebook-level `device` variable.
        first_param = next(self.model.parameters(), None)
        model_device = None if first_param is None else first_param.device

        training_loss = 0.0
        num_batches = 0
        with tqdm(
            self.dataloader,
            desc=f'Epoch {epoch_index + 1}/{num_epoch}',
            unit='batch',
            leave=False,
        ) as pbar:
            for data, _ in pbar:
                if model_device is not None:
                    data = data.to(model_device)
                self.optimizer.zero_grad()
                reconstructed = self.model(data)
                loss = self.loss_fn(reconstructed, data)
                loss.backward()
                self.optimizer.step()
                batch_loss = loss.item()
                training_loss += batch_loss
                num_batches += 1
                pbar.set_postfix({'Loss': batch_loss})

        # Average over the batches actually processed by *this* trainer's
        # dataloader (the previous code divided by the length of the global
        # `training_loader`).
        if num_batches == 0:
            return float('nan')
        return training_loss / num_batches

# Build a fresh autoencoder on the selected device and train it on the
# normal-class training loader with MSE reconstruction loss.
model = AutoEncoder().to(device)
trainer = ModelTrainer(
    model=model,
    dataloader=training_loader,
    loss_fn=nn.MSELoss(),
    optimizer=torch.optim.Adam(model.parameters(), lr=1e-4),
    num_epoch=30,
)
model = trainer.train()

In [None]:
def plot_reconstruction(image):
    """Plot ``image`` next to its reconstruction by the notebook-level ``model``.

    A leading batch dimension is added to ``image`` before it is passed through
    ``model`` on ``device`` with gradients disabled. The first channel of the
    input and of the output are then drawn side by side.

    Note: this puts ``model`` into eval mode and does not switch it back.
    """
    model.eval()

    batch = image.unsqueeze(0)
    with torch.no_grad():
        recon = model(batch.to(device))

    axes = plt.subplots(1, 2)[1]
    panels = (
        ('Original', batch[0][0]),
        # Bring the output back to host memory before handing it to matplotlib.
        ('Reconstructed', recon[0][0].cpu().numpy()),
    )
    for ax, (title, pixels) in zip(axes, panels):
        ax.imshow(pixels)
        ax.set_title(title)
        ax.axis('off')

# One reconstruction per group: normal/outlier samples from the training
# partition, then normal/outlier samples from the test partition.
for source_dataset, sample_index in (
    (training_dataset, 0),
    (outliers_from_training_set, 10),
    (testing_dataset_normal, 10),
    (testing_dataset_outliers, 60),
):
    plot_reconstruction(source_dataset[sample_index][0])
