In [None]:
# IPython autoreload in mode 2: imported modules are re-imported before each cell runs,
# so edits to project source files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import torch
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from ipywidgets import interact
import numpy as np
import matplotlib.cm as cm

In [None]:
# Number of samples per batch for both DataLoaders created below.
BATCH_SIZE = 512
# Not referenced anywhere else in this notebook; presumably the latent dimensionality
# (the plots below unpack every latent as an (x, y) pair) -- TODO confirm against the model definition.
Z_DIM = 2
# Number of codebook entries iterated over when indexing model.vq.embedding below;
# presumably has to match the codebook size of the loaded checkpoint -- TODO confirm.
NUM_EMBEDDINGS = 20

In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# MNIST training split and held-out split (used as the validation set in this notebook).
# Both are stored under data/ (downloaded on first use) and converted with ToTensor.
train_dataset, validation_dataset = (
    MNIST(root='data/', train=is_train, transform=transforms.ToTensor(), download=True)
    for is_train in (True, False)
)

In [None]:
# One shuffled DataLoader per split, keyed by split name.
dataloaders = {
    split: DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    for split, dataset in (('train', train_dataset), ('validation', validation_dataset))
}

In [None]:
# Selects which trained checkpoint to analyse: both values are interpolated into the
# checkpoint file names passed to torch.load below.
# NOTE(review): the names suggest "how the model was initialised before training" and
# "how the codebook vectors were initialised" -- confirm against the training script.
initialise_from = "vae" # "none" or "ae" or "vae"
initialise_embedding_vectors = "normal" # "uniform" or "normal"

In [None]:
# Load the checkpoint whose file name carries the "_epoch_10" suffix.
# Running the next cell as well rebinds `model` to the checkpoint without that suffix,
# so run only the load cell for the checkpoint you want to analyse.
# map_location remaps the stored tensors onto `device`, so a checkpoint written on a GPU
# machine still loads on a CPU-only one.
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust. The object
# is used directly as a module below, so on PyTorch versions where torch.load defaults to
# weights_only=True this call may additionally need weights_only=False -- confirm.
model = torch.load(
    f'checkpoints/vqvae_from_{initialise_from}_using_{initialise_embedding_vectors}_epoch_10.pt',
    map_location=device,
)

In [None]:
# Load the checkpoint without an epoch suffix. This rebinds `model`, replacing whatever the
# previous load cell produced.
# map_location remaps the stored tensors onto `device`, so a checkpoint written on a GPU
# machine still loads on a CPU-only one.
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust. The object
# is used directly as a module below, so on PyTorch versions where torch.load defaults to
# weights_only=True this call may additionally need weights_only=False -- confirm.
model = torch.load(
    f'checkpoints/vqvae_from_{initialise_from}_using_{initialise_embedding_vectors}.pt',
    map_location=device,
)

In [None]:
# Move the model onto the selected device and switch it to evaluation mode;
# the chained call is the cell's last expression, so the model repr is displayed.
model.to(device).eval()

In [None]:
# Reusable tensor -> PIL image converter; the cells below call it on tensors viewed as (1, 28, 28).
transformToPILImage = transforms.ToPILImage()

## Analysis

### Reconstruction of validation set

In [None]:
@interact(k=(0, len(validation_dataset) - 1))
def xxx(k=0):
    """Show validation image ``k`` (top) and the model's reconstruction of it (bottom).

    k: index into ``validation_dataset``; the slider spans every valid index.
    """
    x, label = validation_dataset[k]

    # Inference only: no autograd graph is needed for a forward pass that is just displayed.
    with torch.no_grad():
        # The model returns three values; only the first (the reconstruction) is used here.
        x_hat, _, _ = model(x.to(device).view(1, 1, 28, 28))

    img = transformToPILImage(x.view(1, 28, 28))
    img_hat = transformToPILImage(x_hat.cpu().view(1, 28, 28))

    fig, (ax_input, ax_recon) = plt.subplots(2, 1)
    ax_input.imshow(img, cmap='gray')
    ax_input.set_title(f'input (label {label})')
    ax_recon.imshow(img_hat, cmap='gray')
    ax_recon.set_title('reconstruction')
    fig.tight_layout()

### Reconstruction of discrete embedding vectors

In [None]:
# Decode every codebook vector on its own and tile the decoded images in one figure.
# The row count is ceil(sqrt(N)); the column count is derived from it so the grid always has
# at least NUM_EMBEDDINGS cells (ceil(sqrt) x floor(sqrt) is too small for e.g. N = 3 or 7).
# For NUM_EMBEDDINGS = 20 this is the same 5 x 4 layout as before.
n_rows = max(1, int(np.ceil(np.sqrt(NUM_EMBEDDINGS))))
n_cols = int(np.ceil(NUM_EMBEDDINGS / n_rows))

with torch.no_grad():  # inference only
    for k in range(NUM_EMBEDDINGS):
        plt.subplot(n_rows, n_cols, k + 1)
        # unsqueeze(0) adds a leading batch axis of size 1 before decoding
        x_hat = model.decoder(model.vq.embedding[k].unsqueeze(0))
        # move to host memory before the PIL conversion, as the other decoding cells do
        img_hat = transformToPILImage(x_hat.cpu().view(1, 28, 28))
        plt.imshow(img_hat, cmap='gray')
        plt.axis('off')

### Projection of validation set onto latent space

#### Color coded by class labels

In [None]:
# Encode the whole validation set and scatter the 2-D encoder outputs, coloured by digit label.
plt.figure()
xs = []
ys = []
cs = []

model.eval()  # once is enough; the mode does not change inside the loop
with torch.no_grad():  # inference only
    for inputs, labels in dataloaders['validation']:
        z = model.encoder(inputs.to(device)).cpu().numpy()
        # Flatten everything except the batch axis. Unlike squeeze(), this keeps the batch
        # axis for a one-sample batch; unpacking the transpose still fails loudly if the
        # encoder output is not two values per sample.
        batch_x, batch_y = z.reshape(len(inputs), -1).T
        xs.extend(batch_x)
        ys.extend(batch_y)
        # MNIST labels are 0..9, so dividing by 9 maps them onto [0, 1]
        cs.extend(labels.numpy() / 9)

plt.scatter(xs, ys, c=cs, cmap=cm.rainbow)

In [None]:
#### Color coded by embedding indices

In [None]:
# Encode the whole validation set and scatter the 2-D encoder outputs, coloured by the
# codebook index returned by model.vq.encode; the codebook vectors are overlaid in black.
plt.figure()
xs = []
ys = []
cs = []

# Codebook indices run over 0..NUM_EMBEDDINGS-1, so this divisor maps them onto [0, 1]
# (the previous divisor 9 only fits the ten digit classes). Guarded for a one-entry codebook.
index_scale = max(NUM_EMBEDDINGS - 1, 1)

model.eval()  # once is enough; the mode does not change inside the loop
with torch.no_grad():  # inference only
    for inputs, _ in dataloaders['validation']:
        z = model.encoder(inputs.to(device))
        # vq.encode returns three values; only the second (the indices) is needed here
        _, indices, _ = model.vq.encode(z.unsqueeze(0))

        # Flatten without squeeze() so a one-sample batch keeps its batch axis; unpacking the
        # transpose still fails loudly if the encoder output is not two values per sample.
        batch_x, batch_y = z.cpu().numpy().reshape(len(inputs), -1).T
        indices = indices.cpu().numpy().reshape(-1)

        xs.extend(batch_x)
        ys.extend(batch_y)
        cs.extend(indices / index_scale)

plt.scatter(xs, ys, c=cs, cmap=cm.rainbow)

# Mark each codebook vector and annotate it with its index, slightly offset from the marker.
for k in range(NUM_EMBEDDINGS):
    x, y = model.vq.embedding[k].cpu().detach().numpy()
    plt.scatter(x, y, color='black')
    plt.text(x + 0.05, y + 0.05, k, fontsize=8)


In [None]:
# Bounding box of the latent coordinates collected by the scatter cell that ran last.
xmin, xmax = np.min(xs), np.max(xs)
ymin, ymax = np.min(ys), np.max(ys)

In [None]:
# Decode a regular 15 x 8 lattice of latent points spanning [xmin, xmax] x [ymin, ymax]
# and show the decoded images in a matching grid of axes.
# Note: this rebinds xs / ys from the scatter coordinates to the lattice coordinates.
xs = np.linspace(xmin, xmax, 15, dtype=np.float32)
ys = np.linspace(ymin, ymax, 8, dtype=np.float32)

fig, axes = plt.subplots(len(ys), len(xs), figsize=(10, 5))

for col, x in enumerate(xs):
    # ys is walked in reverse so the largest y lands in the top row of the figure
    for row, y in enumerate(ys[::-1]):
        ax = axes[row, col]
        latent = torch.tensor([x, y]).unsqueeze(0).to(device)
        decoded = model.decoder(latent).cpu().view(1, 28, 28)
        ax.imshow(transformToPILImage(decoded), cmap='gray')
        ax.set_axis_off()

In [None]:
@interact(x=(xmin, xmax, 0.1), y=(ymin, ymax, 0.1))
def xx(x=0, y=0):
    """Decode the latent point ``(x, y)`` and display the resulting image.

    x, y: latent coordinates; the sliders span the bounding box of the encoded
    validation points in steps of 0.1.
    """
    # Explicit float dtype: the integer defaults would otherwise yield an int64 tensor
    # when the function is called directly rather than through the sliders.
    z = torch.tensor([x, y], dtype=torch.float32).unsqueeze(0).to(device)
    with torch.no_grad():  # inference only
        xhat = model.decoder(z)
    img_hat = transformToPILImage(xhat.cpu().view(1, 28, 28))
    plt.imshow(img_hat, cmap='gray')

In [None]:
from sklearn.metrics import adjusted_rand_score, rand_score, consensus_score

In [None]:
# Collect, over the whole validation set, the true digit labels and the values returned by
# model.encode (used below as predicted cluster ids).
# Batches are gathered in lists and concatenated once: np.append inside the loop copies the
# whole array on every batch and, starting from an empty float array, turns the labels into floats.
true_chunks = []
pred_chunks = []

with torch.no_grad():  # inference only
    for inputs, labels in dataloaders['validation']:
        # ravel() reproduces the flattening np.append applied to each batch
        true_chunks.append(labels.cpu().numpy().ravel())
        pred_chunks.append(model.encode(inputs.to(device)).cpu().numpy().ravel())

labels_true = np.concatenate(true_chunks)
labels_pred = np.concatenate(pred_chunks)

In [None]:
# Distinct values that model.encode produced on the validation set.
np.unique(labels_pred)

In [None]:
# Rand index between the true digit labels and the model's assignments (sklearn.metrics.rand_score).
rand_score(labels_true,labels_pred)

In [None]:
# Chance-adjusted Rand index for the same two labelings (sklearn.metrics.adjusted_rand_score).
adjusted_rand_score(labels_true,labels_pred)