In [None]:
## General imports
import os

import numpy as np
import torch
import torch.nn as nn

In [None]:
## Notebook config
# When True and a checkpoint already exists at `model_saving_path`, the training
# cell loads it instead of training from scratch.
use_saved_model_if_exists = False
# When True, the training cell writes the trained weights to `model_saving_path`.
save_trained_model = True
model_saving_path = "./models/cnn_autoencoder.pt"

# Make sure the checkpoint directory exists. It is derived from the path above so
# the two cannot drift apart; `exist_ok=True` replaces the check-then-create
# pattern (no race, safe to re-run) and makedirs also creates missing parents.
# The `or "."` covers a bare filename, whose dirname is the empty string.
os.makedirs(os.path.dirname(model_saving_path) or ".", exist_ok=True)

In [None]:
# Pick the compute device: CUDA when torch reports it as available, CPU otherwise.
is_device_cuda = torch.cuda.is_available()
device = "cuda" if is_device_cuda else "cpu"
print(f"Device is {device}")

In [None]:
from data_loading import get_MNIST_train_validation_test_dataloaders

# The helper's result is unpacked into three values: train, validation, test.
# NOTE(review): train_split=0.0 is meant to disable the validation set; confirm
# against data_loading that 0.0 does not instead leave the training set empty.
mnist_dataloaders = get_MNIST_train_validation_test_dataloaders(
    filepath="./../datasets",
    use_cuda=is_device_cuda,
    batch_size_train=64,
    batch_size_test=1,
    train_split=0.0,  # Don't use a validation set
)

# The middle element (validation dataloader) is discarded.
train_dataloader, _, test_dataloader = mnist_dataloaders

In [None]:
## Model training
from autoencoder import get_cnn_mnist_autoencoder
from training import train_autoencoder

model = get_cnn_mnist_autoencoder()
# Both branches below need the model on the target device, so move it once here.
model.to(device)

if use_saved_model_if_exists and os.path.exists(model_saving_path):
    # Reuse a previously saved checkpoint instead of training.
    # map_location remaps the stored tensors onto the current device; without it
    # a state dict saved from a CUDA model fails to load on a CPU-only machine.
    model.load_state_dict(torch.load(model_saving_path, map_location=device))
else:
    # Train from scratch with Adam (lr=1e-3) and an MSE criterion for 30 epochs.
    train_autoencoder(
        model=model,
        optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
        criterion=nn.MSELoss(),
        dataloader=train_dataloader,
        epochs=30,
        device=device,
    )

    # Persist only the weights (state dict), not the whole module object.
    if save_trained_model:
        torch.save(model.state_dict(), model_saving_path)


In [None]:
## Model evaluation
from evaluation import (
    get_autoencoder_original_reconstructed_pairs,
    plot_original_reconstructed_per_class_grayscale,
)

# The helper receives the model, the test dataloader and the device, and its
# result is unpacked into originals, reconstructions and labels.
originals, reconstructed, labels = get_autoencoder_original_reconstructed_pairs(
    model, test_dataloader, device
)

# Mean squared error between originals and reconstructions, averaged over
# every element.
squared_errors = np.square(originals - reconstructed)
mse = np.mean(squared_errors)
print(f"MSE is {mse}")

# The labels are wrapped in a torch.Tensor before being passed to the plot helper.
label_tensor = torch.Tensor(labels)
plot_original_reconstructed_per_class_grayscale(
    originals,
    reconstructed,
    label_tensor,
    subplot_shape=(4,6),
    figsize=(24,16),
)