In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import torch
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from ipywidgets import interact
import matplotlib.cm as cm

In [None]:
# Hyper-parameters shared by the cells below.
BATCH_SIZE = 512      # images per DataLoader batch (both splits)
Z_DIM = 2             # latent size; presumably must match the checkpoints' encoder (the plot treats each latent as (x, y))
NUM_EMBEDDINGS = 20   # codebook size; used to enumerate model.vq.embedding and to scale code-index colours

In [None]:
# Run inference on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# Both MNIST splits share the same root, tensor conversion and download policy;
# only the `train` flag differs (True -> training split, False -> test split).
train_dataset, validation_dataset = (
    MNIST(root='data/', train=is_train, transform=transforms.ToTensor(), download=True)
    for is_train in (True, False)
)

In [None]:
# Training batches are reshuffled on every pass. Validation batches keep a fixed
# order so the latent-space plot below is drawn identically each time it is
# re-rendered. Shuffling would only change which points are painted on top.
dataloaders = {
    'train': DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True),
    'validation': DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False),
}

In [None]:
# Select which checkpoint family to visualise. Both values are interpolated into
# the checkpoint filename that the interactive plot loads.
initialise_from = "vae"  # "none" or "ae" or "vae"
initialise_embedding_vectors = "normal"  # "uniform" or "normal"

# Fail fast on a typo here rather than with a FileNotFoundError inside the widget callback.
assert initialise_from in ("none", "ae", "vae"), initialise_from
assert initialise_embedding_vectors in ("uniform", "normal"), initialise_embedding_vectors

In [None]:
def _encode_split(model, loader):
    """Encode and quantise every batch yielded by ``loader``.

    Each batch is moved to ``device`` and passed through ``model.encoder``. The
    result is then passed through ``model.vq.encode``.

    Returns:
        ``(latents, labels, codes)`` as NumPy arrays concatenated over all batches:
        one row of encoder output per image, the dataset label per image, and the
        codebook index assigned to each image.
    """
    latent_batches, label_batches, code_batches = [], [], []
    # Inference only: no_grad stops autograd from recording a graph for every
    # batch, since the outputs are only plotted.
    with torch.no_grad():
        for inputs, labels in loader:
            z = model.encoder(inputs.to(device))
            # NOTE(review): vq.encode receives an extra leading axis; its expected
            # input layout is not visible in this notebook.
            _, indices, _ = model.vq.encode(z.unsqueeze(0))
            # reshape (not squeeze) keeps the batch axis even for a final batch of
            # size 1. Assumes each image yields Z_DIM latent values and a single
            # code index -- TODO confirm against the model definition.
            latent_batches.append(z.cpu().reshape(len(inputs), -1))
            code_batches.append(indices.cpu().reshape(-1))
            label_batches.append(labels)
    return (
        torch.cat(latent_batches).numpy(),
        torch.cat(label_batches).numpy(),
        torch.cat(code_batches).numpy(),
    )


@interact(epoch=(0, 200))
def xxx(epoch=0):
    """Visualise the 2-D latent space of the VQ-VAE checkpoint saved at ``epoch``.

    Top panel: validation latents coloured by their assigned codebook index, with
    the first NUM_EMBEDDINGS codebook vectors drawn in black and labelled by index.
    Bottom panel: the same latents coloured by their true digit label.
    """
    checkpoint_path = f'checkpoints/vqvae_from_{initialise_from}_using_{initialise_embedding_vectors}_epoch_{epoch}.pt'
    # map_location lets a checkpoint written on a GPU machine load on a CPU-only one.
    # NOTE(review): this unpickles a whole model object. torch>=2.6 defaults to
    # weights_only=True, which rejects such files. Pass weights_only=False only
    # for trusted checkpoints (older torch versions lack the argument).
    model = torch.load(checkpoint_path, map_location=device)
    model.eval()

    latents, labels, codes = _encode_split(model, dataloaders['validation'])
    xs, ys = latents[:, 0], latents[:, 1]

    fig, (ax_codes, ax_labels) = plt.subplots(2, 1, figsize=(4, 8))
    fig.suptitle(f'epoch {epoch}')

    # Fixed vmin/vmax: without them matplotlib rescales colours to the codes
    # actually in use, so a given code index would change colour between epochs.
    ax_codes.scatter(xs, ys, c=codes, cmap=cm.rainbow, vmin=0, vmax=NUM_EMBEDDINGS - 1)
    codebook = torch.stack(
        [model.vq.embedding[k] for k in range(NUM_EMBEDDINGS)]
    ).detach().cpu().numpy()
    ax_codes.scatter(codebook[:, 0], codebook[:, 1], color='black')
    for k, (x, y) in enumerate(codebook):
        ax_codes.text(x + 0.05, y + 0.05, k, fontsize=8)
    ax_codes.set(title='Assigned codebook index', xlabel='z[0]', ylabel='z[1]')

    ax_labels.scatter(xs, ys, c=labels, cmap=cm.rainbow, vmin=0, vmax=9)
    ax_labels.set(title='True digit label', xlabel='z[0]', ylabel='z[1]')

    fig.tight_layout()
    plt.show()
