In [None]:
import os
import sys
# Put the parent directory first on the import path; presumably that is where
# the project-local `vqvae` and `utils` modules imported below live.
sys.path.insert(0, '..')

# The run to inspect is selected through the RUN_ID environment variable. The
# value is used further down to build the checkpoint path and the names of the
# saved figures.
run_id = os.environ.get("RUN_ID")
# Reject a missing *and* an empty value: an empty string would otherwise yield
# a meaningless checkpoint path such as "../logs//.ckpt".
assert run_id, (
    "RUN_ID environment variable is not set (or is empty); "
    "set it to the name of the run to inspect before starting the notebook"
)

from vqvae import VQModel

import umap
import torch
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms as T
from torchvision.datasets import CIFAR10
from torchvision.utils import make_grid

from utils import inmap, outmap

# Use the first CUDA device when one is available; otherwise fall back to the
# CPU so the notebook still runs end-to-end on machines without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [None]:
# Restore the model from this run's checkpoint and move it to `device`.
model = VQModel.load_from_checkpoint(f"../logs/{run_id}/{run_id}.ckpt").to(device)
# This notebook only runs inference, so leave training mode: layers such as
# dropout / batch-norm (and any training-only logic inside the model) must not
# be active while computing reconstructions.
# NOTE(review): assumes VQModel is a torch.nn.Module (it exposes `.to`) — confirm.
model.eval()
# Last expression of the cell: displays the module's repr.
model

In [None]:
# Number of test images pushed through the model.
num_img = 16

# CIFAR-10 test split read from ../data; PILToTensor is the only transform.
data = CIFAR10("../data", train=False, transform=T.Compose([T.PILToTensor()]))

# Batch the first `num_img` images along a new leading dimension; the second
# element of every dataset item is discarded.
x = torch.stack([data[idx][0] for idx in range(num_img)])

with torch.inference_mode():
    # Encoder side: inmap -> move to `device` -> encoder -> quant_conv.
    # NOTE(review): inmap presumably rescales raw pixels to the model's input
    # range — confirm in utils.
    h = model.quant_conv(model.encoder(inmap(x).to(device)))

    # quantize returns three values; only the first one is used here.
    quant, _, _ = model.quantize(h)
    quant = model.post_quant_conv(quant)

    # Decoder side: decode, bring the result back to the CPU, then outmap.
    dec = model.decoder(quant).cpu()
    x_recon = outmap(dec)

# Sanity check: the reconstructions must form a (num_img, 3, 32, 32) batch.
expected_shape = (num_img, 3, 32, 32)
assert x_recon.shape == expected_shape, f"z_hat is of shape {x_recon.shape}"

In [None]:
def show(x, x_recon):
    """Plot two images one above the other and save the figure as a PDF.

    Args:
        x: Tensor exposing ``.numpy()`` whose axes are reordered with
            ``(1, 2, 0)`` before plotting; drawn in the top panel.
        x_recon: Tensor handled the same way; drawn in the bottom panel.

    Side effects:
        Saves the current figure to ``f"{run_id}-recon.pdf"``, where
        ``run_id`` is the notebook-level global.
    """
    fig, panels = plt.subplots(nrows=2, ncols=1, figsize=(40, 20))
    for panel, image in zip(panels, (x, x_recon)):
        # Move axis 0 to the end before handing the array to imshow.
        pixels = np.transpose(image.numpy(), (1, 2, 0))
        panel.imshow(pixels, interpolation='nearest')
        # Hide ticks and tick labels on both axes of this panel.
        panel.get_xaxis().set_visible(False)
        panel.get_yaxis().set_visible(False)
    plt.savefig(f"{run_id}-recon.pdf")

In [None]:
# Tile each batch into a single grid image, then plot the inputs (top panel)
# above the reconstructions (bottom panel).
grid_x = make_grid(x)
grid_recon = make_grid(x_recon)
show(grid_x, grid_recon)

In [None]:
# Project the quantizer's embedding matrix to 2-D with UMAP (cosine metric);
# presumably each row of the matrix is one codebook vector.
# The weights are handed over as a NumPy array, which is the input type the
# scikit-learn style `fit_transform` API documents, rather than a torch tensor.
codebook = model.quantize.embedding.weight.detach().cpu().numpy()
# UMAP is stochastic: pin the seed so the saved figure is reproducible under
# Restart & Run All. (umap-learn runs single-threaded when random_state is set.)
proj = umap.UMAP(n_neighbors=3,
                 min_dist=0.1,
                 metric='cosine',
                 random_state=42).fit_transform(codebook)

In [None]:
# Draw on a dedicated figure instead of pyplot's implicit "current" figure:
# outside the inline backend an earlier figure may still be open, and the
# scatter (and the saved PDF) would then end up on top of it.
fig, ax = plt.subplots()
# One translucent marker per row of the 2-D UMAP projection.
ax.scatter(proj[:, 0], proj[:, 1], alpha=0.3)
# Label the plot so the saved figure can be read on its own.
ax.set(xlabel="UMAP 1", ylabel="UMAP 2", title="Codebook embedding (UMAP)")
fig.savefig(f"{run_id}-codebook_embedding.pdf")