In [None]:
%load_ext autoreload
%autoreload 2

from context import uncertify

In [None]:
import logging
from uncertify.log import setup_logging
# Run the project's logging setup first, then create this notebook's logger.
# NOTE(review): setup_logging() is defined in uncertify.log; what it configures
# is not visible here, so keep it ahead of the per-logger tweaks below.
setup_logging()
LOG = logging.getLogger(__name__)

# Matplotlib is extremely verbose at DEBUG level; only let WARNING and above through.
mpl_logger = logging.getLogger('matplotlib')
mpl_logger.setLevel(logging.WARNING)

In [None]:
from pathlib import Path
import itertools

import torch
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt

from uncertify.tutorials.auto_encoder import AutoEncoder
from uncertify.common import DATA_DIR_PATH

from typing import Tuple

In [None]:
def get_mnist_data_loaders(transform: transforms.Compose,
                           data_path: Path,
                           batch_size: int,
                           num_workers: int) -> Tuple[DataLoader, DataLoader]:
    """Build the MNIST training and test data loaders.

    Both datasets are created with ``download=True`` under ``data_path`` and
    share the same ``transform``, ``batch_size`` and ``num_workers``.

    Args:
        transform: transformation applied to every sample of both splits.
        data_path: root directory handed to ``torchvision.datasets.MNIST``.
        batch_size: number of samples per batch for both loaders.
        num_workers: number of loader worker processes for both loaders.

    Returns:
        ``(train_loader, test_loader)``; only the training loader shuffles.
    """
    loaders = []
    # Training split first, then test split; shuffling is enabled for training only.
    for is_train_split in (True, False):
        dataset = torchvision.datasets.MNIST(root=data_path,
                                             train=is_train_split,
                                             download=True,
                                             transform=transform)
        loaders.append(DataLoader(dataset,
                                  batch_size=batch_size,
                                  shuffle=is_train_split,
                                  num_workers=num_workers))
    train_loader, test_loader = loaders
    return train_loader, test_loader

In [None]:
# The only preprocessing step is the conversion of each image to a tensor.
transform = transforms.Compose([transforms.ToTensor()])

# MNIST is stored below the project's data directory; both loaders use
# batches of 4 samples and 4 worker processes.
train_loader, test_loader = get_mnist_data_loaders(
    transform=transform,
    data_path=DATA_DIR_PATH / 'mnist_data',
    batch_size=4,
    num_workers=4,
)

In [None]:
# --- Training configuration ---------------------------------------------------
SEED = 42             # fixed seed so repeated runs start from the same torch RNG state
N_EPOCHS = 10         # number of full passes over the training loader
LEARNING_RATE = 0.01  # SGD step size
PRINT_STEPS = 1000    # print the averaged loss every PRINT_STEPS batches
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch.manual_seed(SEED)

# 784 = 28 * 28 flattened MNIST pixels.
# NOTE(review): 265 looks like a typo for 256 -- confirm before changing it,
# since the value determines the network architecture.
# The model is moved to `device` before the optimizer is created so that the
# optimizer is built from parameters that already live where training happens.
model = AutoEncoder(input_dim=784,
                    latent_dim=128,
                    encoder_hidden_dims=[512, 265]).to(device)

# Reconstruction loss: mean squared error between output and flattened input.
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=0.9)

def train(model, device, train_loader, loss_fn=None, opt=None, n_epochs=None, print_steps=None):
    """Train ``model`` to reconstruct its own (flattened) input.

    Every batch is flattened to ``(-1, 784)``, moved to ``device`` and used as
    both input and target of the loss.

    Args:
        model: module to train; it is moved to ``device`` and put in training mode.
        device: device on which the model and the batches are placed.
        train_loader: iterable yielding ``(features, labels)`` batches; the
            labels are ignored.
        loss_fn: loss called as ``loss_fn(outputs, inputs)``. Defaults to the
            notebook-level ``criterion``.
        opt: optimizer that must have been built from ``model``'s parameters.
            Defaults to the notebook-level ``optimizer``.
        n_epochs: number of passes over ``train_loader``. Defaults to ``N_EPOCHS``.
        print_steps: print the averaged running loss every ``print_steps``
            batches. Defaults to ``PRINT_STEPS``.

    Returns:
        The trained model (the object returned by ``model.to(device)``).
    """
    # Fall back to the notebook globals only when the caller did not supply a
    # value, so the function can also be used with a different model/optimizer.
    loss_fn = criterion if loss_fn is None else loss_fn
    opt = optimizer if opt is None else opt
    n_epochs = N_EPOCHS if n_epochs is None else n_epochs
    print_steps = PRINT_STEPS if print_steps is None else print_steps

    model = model.to(device)
    model.train()  # make sure the module is in training mode
    for epoch_idx in range(n_epochs):
        running_loss = 0.0
        for batch_idx, (batch_features, _) in enumerate(train_loader):
            batch_flat_features = batch_features.view(-1, 784).to(device)
            opt.zero_grad()
            # Call the module itself rather than .forward() so registered hooks run.
            outputs = model(batch_flat_features)
            loss = loss_fn(outputs, batch_flat_features)
            loss.backward()
            opt.step()
            running_loss += loss.item()
            if (batch_idx + 1) % print_steps == 0:
                print(f'epoch {epoch_idx + 1:<2} | batch {batch_idx + 1:5}  >>>  loss: {running_loss / print_steps:.3f}')
                running_loss = 0.0
    return model

# Run the training loop; the returned model is used for the visualisation below.
trained_model = train(model=model, device=device, train_loader=train_loader)

In [None]:
def visualize_reconstructions(trained_model, test_loader, cmap='hot', n_batches=1):
    """Plot input images next to the model's reconstructions.

    For each sample of the first ``n_batches`` batches a new figure with two
    panels is created: the input on the left, the reconstruction on the right.

    Args:
        trained_model: model mapping flattened ``(-1, 784)`` batches to
            outputs that can be viewed as ``28 x 28`` images.
        test_loader: iterable yielding ``(features, labels)`` batches; the
            labels are ignored.
        cmap: colormap used for both panels.
        n_batches: number of batches taken from ``test_loader``.
    """
    # Feed inputs on the device the model's weights live on instead of relying
    # on a notebook-level `device` variable.
    first_param = next(trained_model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')

    # Evaluate in eval mode, then restore whatever mode the model was in.
    was_training = trained_model.training
    trained_model.eval()
    try:
        with torch.no_grad():
            for batch_features, _ in itertools.islice(test_loader, n_batches):
                batch_flat_feature = batch_features.view(-1, 784)
                # Call the module itself rather than .forward() so registered hooks run.
                outputs = trained_model(batch_flat_feature.to(model_device))
                for in_feature, out in zip(batch_features, outputs):
                    out_np = out.view(28, 28).cpu().numpy()
                    fig, (ax1, ax2) = plt.subplots(1, 2)
                    # Pass the colormap per image; plt.set_cmap would change the
                    # global default for every later plot.
                    ax1.imshow(in_feature.view(28, 28).numpy(), cmap=cmap)
                    ax2.imshow(out_np, cmap=cmap)
                    ax1.set_title('input')
                    ax2.set_title('reconstruction')
                    ax1.set_axis_off()
                    ax2.set_axis_off()
    finally:
        trained_model.train(was_training)
                
# Show input/reconstruction pairs for the first three test batches.
visualize_reconstructions(trained_model=trained_model,
                          test_loader=test_loader,
                          n_batches=3)