In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import torch
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from ipywidgets import interact
import numpy as np

In [None]:
# Mini-batch size used by both MNIST DataLoaders below.
BATCH_SIZE = 512
# Dimensionality of the autoencoder's latent code; the latent-space plots below
# unpack codes as (x, y), i.e. they rely on this being 2.
Z_DIM = 2

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# Both MNIST splits, stored under data/ (downloaded on first run) and converted to
# float tensors by ToTensor.
train_dataset = MNIST(
    root='data/',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
validation_dataset = MNIST(
    root='data/',
    train=False,
    transform=transforms.ToTensor(),
    download=True,
)

In [None]:
# One shuffled loader per split, keyed by split name. Shuffling only changes the
# iteration order; every sample is still visited once per pass.
dataloaders = {
    split: DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    for split, dataset in (('train', train_dataset), ('validation', validation_dataset))
}

In [None]:
# The checkpoint is a whole pickled model object (later cells call model(...),
# model.encoder and model.decoder), not a state_dict.
# map_location puts its tensors on `device`, so later cells' `.to(device)` inputs match
# the weights no matter which device the checkpoint was saved from.
# NOTE(review): torch.load unpickles arbitrary objects -- only load trusted checkpoints.
# On PyTorch >= 2.6 (weights_only defaults to True) this call also needs weights_only=False.
model = torch.load('checkpoints/ae.pt', map_location=device)
model.eval()  # this notebook only runs inference

In [None]:
# Shared converter from image tensors to PIL images so they can be shown with plt.imshow.
# mode=None is ToPILImage's default (mode inferred from the input).
transformToPILImage = transforms.ToPILImage(mode=None)

### Reconstruction of validation set

In [None]:
@interact(k=(0, len(validation_dataset) - 1))
def xxx(k=0):
    """Show validation image ``k`` (top) above the autoencoder's reconstruction (bottom).

    ``k`` is driven by an ipywidgets slider covering every index of ``validation_dataset``.
    """
    x, label = validation_dataset[k]
    model.eval()
    with torch.no_grad():  # inference only: don't build an autograd graph
        x_hat = model(x.to(device).view(1, 1, 28, 28))

    img = transformToPILImage(x.view(1, 28, 28))
    # ToPILImage multiplies float tensors by 255 and casts to uint8, so values outside
    # [0, 1] would wrap around; clamp first.
    # NOTE(review): a no-op if the decoder already ends in a sigmoid -- confirm.
    img_hat = transformToPILImage(x_hat.cpu().view(1, 28, 28).clamp(0, 1))

    plt.subplot(2, 1, 1)
    plt.imshow(img, cmap='gray')
    plt.title(f'original (label {label})')
    plt.subplot(2, 1, 2)
    plt.imshow(img_hat, cmap='gray')
    plt.title('reconstruction')

### Projection of validation set onto latent space

In [None]:
import matplotlib.cm as cm

# Encode every validation image and scatter its 2-D latent code, coloured by digit label.
model.eval()  # loop-invariant, so set once instead of once per batch
z_batches = []
label_batches = []
with torch.no_grad():  # inference only: no autograd graph, so no .detach() needed
    for inputs, labels in dataloaders['validation']:
        z = model.encoder(inputs.to(device))
        # Flatten to (batch, latent) rather than .squeeze(): squeeze would also drop the
        # batch axis of a size-1 batch and break the per-point (x, y) split below.
        z_batches.append(z.reshape(inputs.size(0), -1).cpu().numpy())
        label_batches.append(labels.numpy())

z_all = np.concatenate(z_batches)
labels_all = np.concatenate(label_batches)
assert z_all.shape[1] == 2, f'latent scatter expects a 2-D code, got {z_all.shape[1]}-D'
xs = z_all[:, 0]
ys = z_all[:, 1]
cs = labels_all / 9  # same per-point colour values as before (label divided by 9)

plt.figure(figsize=(10, 5))
# vmin/vmax pin the colour scale to the digit range, so the colorbar reads in labels.
plt.scatter(xs, ys, c=labels_all, cmap=cm.rainbow, vmin=0, vmax=9)
plt.colorbar(label='digit')
plt.xlabel('z[0]')
plt.ylabel('z[1]')
plt.title('Validation set in latent space');

In [None]:
# Bounding box of the encoded validation set. It sets the span of the latent grid and
# the ranges of the sliders below.
x_min, x_max = np.min(xs), np.max(xs)
y_min, y_max = np.min(ys), np.max(ys)

In [None]:
# Decode a regular grid spanning the encoded validation set's bounding box, to see how
# the decoder output changes across the 2-D latent space.
xs = np.linspace(x_min, x_max, 15, dtype=np.float32)  # 15 columns along z[0]
ys = np.linspace(y_min, y_max, 8, dtype=np.float32)   # 8 rows along z[1]

model.eval()
# Row 0 is drawn at the top, so rows run from the largest y down to the smallest.
# With indexing='ij', grid point [row, col] is (xs[col], flip(ys)[row]).
grid_y, grid_x = np.meshgrid(np.flip(ys), xs, indexing='ij')
z_grid = torch.from_numpy(np.stack([grid_x.ravel(), grid_y.ravel()], axis=1)).to(device)
with torch.no_grad():  # one batched decode of all grid points, without an autograd graph
    x_hat_grid = model.decoder(z_grid).cpu().reshape(len(ys), len(xs), 1, 28, 28)
# ToPILImage multiplies floats by 255 and casts to uint8; clamp so values can't wrap.
x_hat_grid = x_hat_grid.clamp(0, 1)

# squeeze=False keeps `axes` 2-D even if the grid is resized to one row or column.
fig, axes = plt.subplots(len(ys), len(xs), figsize=(10, 5), squeeze=False)
for row in range(len(ys)):
    for col in range(len(xs)):
        ax = axes[row, col]
        ax.imshow(transformToPILImage(x_hat_grid[row, col]), cmap='gray')
        ax.set_axis_off()
        axes[y_idx, x_idx].set_axis_off()

In [None]:
@interact(x=(x_min, x_max, 0.1), y=(y_min, y_max, 0.1))
def xx(x=0, y=0):
    """Decode the latent point ``(x, y)`` chosen with two sliders and show the image.

    The slider ranges cover the bounding box of the encoded validation set.
    """
    model.eval()
    # Explicit float32: a direct call with the int defaults would otherwise build an
    # int64 tensor, which the decoder would reject.
    z = torch.tensor([x, y], dtype=torch.float32, device=device).unsqueeze(0)
    with torch.no_grad():  # inference only
        x_hat = model.decoder(z)
    # ToPILImage multiplies floats by 255 and casts to uint8; clamp so values can't wrap.
    img_hat = transformToPILImage(x_hat.cpu().view(1, 28, 28).clamp(0, 1))
    plt.imshow(img_hat, cmap='gray')
    plt.title(f'decoded z = ({x:.2f}, {y:.2f})')