In [None]:
# Reload edited modules automatically before each cell runs, so changes to the
# package source are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# NOTE(review): `context` is presumably a local helper that makes the `uncertify`
# package importable from the notebook directory — confirm.
from context import uncertify

In [None]:
import logging
from uncertify.log import setup_logging
# Run the project's logging setup helper, then create the logger used by this notebook.
setup_logging()
LOG = logging.getLogger(__name__)

# Matplotlib emits a large volume of DEBUG records; raise its logger to WARNING
# so they do not drown out the notebook's own log output.
mpl_logger = logging.getLogger('matplotlib')
mpl_logger.setLevel(logging.WARNING)

In [None]:
from pathlib import Path

import torch
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms

from uncertify.tutorials.variational_auto_encoder import VariationalAutoEncoder
from uncertify.tutorials.variational_auto_encoder import train_vae, visualize_reconstructions, visualize_generated
from uncertify.common import DATA_DIR_PATH

from typing import Tuple

In [None]:
def get_mnist_data_loaders(transform: transforms.Compose,
                             data_path: Path,
                             batch_size: int,
                             num_workers: int) -> Tuple[DataLoader, DataLoader]:
    """Build the data loaders for the MNIST training and test splits.

    The dataset is downloaded into ``data_path`` if it is not already there.

    Args:
        transform: transformation applied to every sample of both splits
        data_path: root directory handed to ``torchvision.datasets.MNIST``
        batch_size: number of samples per batch for both loaders
        num_workers: number of worker processes for both loaders

    Returns:
        A ``(train_loader, test_loader)`` tuple; only the training loader shuffles.
    """
    def _make_loader(is_train: bool) -> DataLoader:
        # The split flag doubles as the shuffle flag: shuffle training data only.
        dataset = torchvision.datasets.MNIST(root=data_path,
                                             train=is_train,
                                             download=True,
                                             transform=transform)
        return DataLoader(dataset,
                          batch_size=batch_size,
                          shuffle=is_train,
                          num_workers=num_workers)

    # Tuple elements are evaluated left to right, so the training split is built first.
    return _make_loader(True), _make_loader(False)

In [None]:
# Samples are only converted to tensors; the pipeline contains no other step.
transform = transforms.Compose([transforms.ToTensor()])

# Data-loading settings, named so they are easy to find and tune.
MNIST_DATA_PATH = DATA_DIR_PATH / 'mnist_data'
BATCH_SIZE = 64
NUM_WORKERS = 4

train_loader, test_loader = get_mnist_data_loaders(
    transform=transform,
    data_path=MNIST_DATA_PATH,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
)

In [None]:
# Training hyper-parameters.
N_EPOCHS = 30
LEARNING_RATE = 0.00003
# NOTE(review): PRINT_STEPS is never used in this cell; the literal 1 is passed
# positionally to train_vae below. Confirm which value is intended.
PRINT_STEPS = 200
N_Z_SAMPLES = 16
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# input_dim=784 matches a flattened 28 x 28 image.
model = VariationalAutoEncoder(input_dim=784, hidden_dim=128, bottleneck_dim=20)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Latent samples handed to train_vae. They are moved to `device` instead of calling
# .cuda() unconditionally, which would raise on machines without CUDA even though
# `device` above already falls back to the CPU.
sampled_z = torch.randn(N_Z_SAMPLES, model.bottleneck_dim).to(device)
trained_model = train_vae(model, device, train_loader, test_loader, optimizer, N_EPOCHS, 1, sampled_z)

In [None]:
# Plot reconstructions for up to five samples taken from one batch of the test loader.
figs = visualize_reconstructions(
    trained_model,
    test_loader,
    device,
    n_batches=1,
    max_samples=5,
    show=True,
)

In [None]:
def reconstruct_random_z(n_samples: int = 16) -> None:
    """Decode random latent vectors with the notebook's ``model`` and plot the results.

    Args:
        n_samples: number of latent vectors drawn from a standard normal distribution
    """
    # Place the latent samples on the device that holds the model's parameters rather
    # than calling .cuda() unconditionally, which raises on machines without CUDA.
    model_device = next(model.parameters()).device
    sampled_z = torch.randn(n_samples, model.bottleneck_dim).to(model_device)
    # Reshape the decoder output to 28 x 28 images and move it to NumPy for plotting.
    generated = model._decode(sampled_z).view(-1, 28, 28).cpu().detach().numpy()
    visualize_generated(generated)


reconstruct_random_z()
