In [None]:
# Automatically re-import edited Python modules before each cell runs, so
# changes to project code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import torch
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from ipywidgets import interact
import numpy as np
import matplotlib.cm as cm

In [None]:
# Number of images per mini-batch yielded by the DataLoaders below.
BATCH_SIZE: int = 2_048
# Latent dimensionality (presumably must match the checkpoint). The latent-space
# plots in this notebook hard-code 2-element latent vectors.
Z_DIM: int = 2

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# MNIST training and test splits, each image converted to a tensor via ToTensor
# and downloaded into data/ if not already present. The test split serves as
# the validation set.
train_dataset, validation_dataset = (
    MNIST(root='data/', train=is_train, transform=transforms.ToTensor(), download=True)
    for is_train in (True, False)
)

In [None]:
# Shuffle only the training split. The validation loader is used purely for
# evaluation/plotting, where a fixed iteration order keeps results reproducible.
dataloaders = {'train': DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True),
               'validation': DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False)}

In [None]:
# Load the whole pickled VAE module (it is used directly as `model`, with
# `.encoder`/`.decoder` attributes), not just a state_dict.
# map_location remaps stored tensors onto `device`, so a checkpoint saved on a
# GPU can still be loaded on a CPU-only machine.
# NOTE: torch.load unpickles arbitrary objects -- only load trusted checkpoints.
# NOTE(review): on torch >= 2.6 `weights_only` defaults to True, which rejects a
# pickled nn.Module; pass weights_only=False there if loading fails.
model = torch.load('checkpoints/vae.pt', map_location=device)
model.to(device)
model.eval()

In [None]:
# Converts a (C, H, W) tensor image -- here always reshaped to (1, 28, 28),
# presumably with values in [0, 1] -- into a PIL image for plt.imshow.
transformToPILImage = transforms.ToPILImage()

## Analysis

### Reconstruction of the validation set

In [None]:
@interact(k=(0, len(validation_dataset) - 1))
def xxx(k=0):
    """Show validation image ``k`` above its VAE reconstruction.

    Also prints the first output of ``model.encoder`` for the image
    (presumably the latent mean -- TODO confirm against the model definition).

    Args:
        k: Index into ``validation_dataset`` (driven by the slider).
    """
    x, label = validation_dataset[k]
    # Add the batch dimension the model expects: (1, 1, 28, 28).
    x_in = x.to(device).view(1, 1, 28, 28)

    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        z, _ = model.encoder(x_in)
        x_hat, _, _ = model(x_in)

    print(z)

    img = transformToPILImage(x.view(1, 28, 28))
    img_hat = transformToPILImage(x_hat.cpu().view(1, 28, 28))

    fig, (ax_in, ax_out) = plt.subplots(2, 1)
    ax_in.imshow(img, cmap='gray')
    ax_in.set_title(f'input (label {label})')
    ax_out.imshow(img_hat, cmap='gray')
    ax_out.set_title('reconstruction')
    for ax in (ax_in, ax_out):
        ax.set_axis_off()
    # Explicit show so the figure renders inside the interact output area.
    plt.show()

### Projection of the validation set onto the latent space

In [None]:


# Encode the whole validation set and scatter the two latent coordinates,
# coloured by digit label. Defines `xs`, `ys` (latent coordinates) and
# `cs` (label / 9), which later cells rely on.
model.eval()
z_batches = []
label_batches = []
with torch.no_grad():  # inference only
    for inputs, labels in dataloaders['validation']:
        # First encoder output; presumably the latent mean -- TODO confirm.
        z, _ = model.encoder(inputs.to(device))
        # Reshape to (batch, latent) rather than squeeze(), which would also
        # drop the batch axis for a batch of size 1.
        z_batches.append(z.detach().cpu().numpy().reshape(len(labels), -1))
        label_batches.append(labels.numpy())

z_all = np.concatenate(z_batches)
assert z_all.shape[1] == 2, f'expected a 2-D latent space, got {z_all.shape[1]} dims'
xs = z_all[:, 0]
ys = z_all[:, 1]
cs = np.concatenate(label_batches) / 9  # digit label scaled to [0, 1]

fig, ax = plt.subplots()
# Fix the colour range so colours map to digits regardless of which labels appear.
points = ax.scatter(xs, ys, c=cs, cmap=cm.rainbow, vmin=0, vmax=1, s=4)
cbar = fig.colorbar(points, ax=ax, ticks=np.arange(10) / 9)
cbar.set_ticklabels([str(digit) for digit in range(10)])
cbar.set_label('digit')
ax.set(xlabel='z[0]', ylabel='z[1]', title='Validation set in latent space')
plt.show()

In [None]:
# Bounding box of the encoded validation points; it sets the range of the
# latent grid and of the sliders further down.
(xmin, xmax), (ymin, ymax) = [(np.min(coords), np.max(coords)) for coords in (xs, ys)]

In [None]:
# Decode a regular 15 x 8 grid spanning the latent bounding box, to show how
# the decoder output varies across the latent plane.
# Separate names leave the encoded points in `xs`/`ys` untouched, so the
# bounding-box cell can be re-run safely after this one.
grid_xs = np.linspace(xmin, xmax, 15, dtype=np.float32)
grid_ys = np.linspace(ymin, ymax, 8, dtype=np.float32)

fig, axes = plt.subplots(len(grid_ys), len(grid_xs), figsize=(10, 5))

with torch.no_grad():  # inference only
    for xidx, x in enumerate(grid_xs):
        # Flip so the largest y lands in the top row, matching a y-up plot.
        for yidx, y in enumerate(np.flip(grid_ys)):
            z = torch.tensor([x, y], dtype=torch.float32).unsqueeze(0).to(device)
            xhat = model.decoder(z)
            img_hat = transformToPILImage(xhat.cpu().view(1, 28, 28))
            ax = axes[yidx, xidx]
            ax.imshow(img_hat, cmap='gray')
            ax.set_axis_off()

fig.suptitle('Decoder output over the latent grid')
plt.show()
        axes[yidx, xidx].set_axis_off()

In [None]:
@interact(x=(xmin, xmax, 0.001), y=(ymin, ymax, 0.001))
def xx(x=0, y=0):
    """Decode the single latent point ``(x, y)`` chosen with the sliders and show it.

    Args:
        x: First latent coordinate.
        y: Second latent coordinate.
    """
    # Force float32: with the integer defaults, torch.tensor would otherwise
    # build an int64 tensor, presumably rejected by the decoder's float layers.
    z = torch.tensor([x, y], dtype=torch.float32).unsqueeze(0).to(device)
    with torch.no_grad():  # inference only
        xhat = model.decoder(z)
    img_hat = transformToPILImage(xhat.cpu().view(1, 28, 28))
    plt.imshow(img_hat, cmap='gray')
    plt.show()

In [None]:
@interact(mu_x=(xmin, xmax, 0.01), mu_y=(ymin, ymax, 0.01), var_x=(0, 1, 0.01), var_y=(0, 1, 0.01))
def xx(mu_x=0, mu_y=0, var_x=0, var_y=0):
    """Decode one random draw z ~ N(mu, diag(var)) and show the image.

    Fresh noise is drawn on every call, so identical slider settings can give
    different images unless both variances are 0.
    NOTE(review): this shadows the ``xx`` defined in the previous cell;
    consider giving it a distinct name.

    Args:
        mu_x, mu_y: Mean of the latent Gaussian.
        var_x, var_y: Per-coordinate variances (>= 0).
    """
    # Force float32: integer defaults would give an int64 tensor, on which
    # randn_like cannot sample.
    mu = torch.tensor([mu_x, mu_y], dtype=torch.float32)
    # std = sqrt(var), equivalent to exp(log(var) / 2) without the log(0) = -inf detour.
    std = torch.sqrt(torch.tensor([var_x, var_y], dtype=torch.float32))
    z = mu + std * torch.randn_like(mu)
    with torch.no_grad():  # inference only
        xhat = model.decoder(z.to(device).view(1, 2))
    img_hat = transformToPILImage(xhat.cpu().view(1, 28, 28))
    plt.imshow(img_hat, cmap='gray')
    plt.show()

In [None]:
# Decode 100 latent vectors drawn from N(0, I) (the usual VAE prior) and show
# them as a 10 x 10 grid.
n_rows = n_cols = 10
fig, axes = plt.subplots(n_rows, n_cols, figsize=(8, 8))
with torch.no_grad():  # inference only
    # axes.flat walks the grid row by row, like plt.subplot indices 1..100.
    for ax in axes.flat:
        z = torch.randn(1, 2).to(device)
        xhat = model.decoder(z)
        img_hat = transformToPILImage(xhat.cpu().view(1, 28, 28))
        ax.imshow(img_hat, cmap='gray')
        ax.axis('off')
fig.suptitle('Samples decoded from N(0, I)')
plt.show()