In [None]:
# Enable IPython's autoreload extension. Mode 2 re-imports all loaded modules
# before each cell executes, so edits to imported source files (such as the
# `models` and `utils` packages used below) take effect without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
from models.VAE import VanillaVAE, IWAE, VectorQuantizedVAE

### Toy example: 8 Gaussians

In [None]:
import numpy as np
import matplotlib.pyplot as plt

def eightgaussian(n_points, radius=5.0, noise_std=0.2, seed=None):
    """
    Returns the eight gaussian dataset.

    Each sample picks one of eight modes uniformly at random. Mode ``k``
    (k = 0..7) is centred at angle ``k * pi / 4`` on a circle of radius
    ``radius``; independent zero-mean Gaussian noise scaled by ``noise_std``
    is then added to the x and y coordinates.

    Parameters
    ----------
    n_points : int
        Number of samples to draw.
    radius : float, optional
        Distance of every mode centre from the origin (default 5.0).
    noise_std : float, optional
        Standard deviation of the per-coordinate noise (default 0.2).
    seed : int or None, optional
        If given, samples come from a private ``np.random.RandomState(seed)``,
        making the call reproducible without touching NumPy's global RNG.
        If None (default), the global ``np.random`` state is used, exactly as
        before these parameters existed.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n_points, 2)``; column 0 is x, column 1 is y.
    """
    # The `np.random` module and a RandomState instance expose the same
    # `randint` / `normal` API, so either can serve as the sampler.
    rng = np.random if seed is None else np.random.RandomState(seed)

    # Draw order (modes, x-noise, y-noise) is kept fixed so that a seeded
    # global RNG yields the same points as the original implementation.
    mode = rng.randint(0, 8, n_points)
    noise_x = rng.normal(size=n_points) * noise_std
    noise_y = rng.normal(size=n_points) * noise_std

    angle = mode * np.pi / 4.0
    x = np.cos(angle) * radius + noise_x
    y = np.sin(angle) * radius + noise_y
    return np.vstack((x, y)).T
            
# Draw independent training and held-out samples from the toy distribution.
X = eightgaussian(10000)
X_test = eightgaussian(5000)

# Sanity check: one row per sample, two coordinates (x, y) per row.
assert X.shape == (10000, 2) and X_test.shape == (5000, 2)

# Scatter the training set on square axes so the ring of modes is not distorted.
fig, ax = plt.subplots(figsize=(6, 6))
ax.scatter(X[:, 0], X[:, 1], s=1)
ax.set_title("Eight Gaussians (training set)")
ax.axis('square');

In [None]:
from utils.data import make_dataloaders, save_model, DotDict
import torch

# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Convert each split to a float32 tensor once, up front. The DataLoader's
# default collation then stacks rows into (batch, 2) float32 tensors, which is
# what the previous per-batch `torch.Tensor(list_of_ndarrays)` collate produced,
# without the slow list-of-arrays conversion on every batch.
train_loader = torch.utils.data.DataLoader(
    torch.as_tensor(X, dtype=torch.float32), batch_size=32
)
val_loader = torch.utils.data.DataLoader(
    torch.as_tensor(X_test, dtype=torch.float32), batch_size=32
)

# Maximum gradient norm passed to clip_grad_norm_ in the training loop.
clipping_value = 0.1

def train(model, epochs, return_losses=False):
    """
    Train ``model`` on the module-level ``train_loader`` for ``epochs`` epochs.

    The model is moved to the module-level ``device``, its optimizer is
    obtained from ``model.get_optimizer()``, and for every batch the loss
    returned as the first element of ``model.step(batch)`` is back-propagated
    with gradient-norm clipping at ``clipping_value``.

    Parameters
    ----------
    model
        Object exposing ``to``, ``get_optimizer``, ``description``, ``step``,
        ``parameters`` and ``zero_grad`` (all used below).
    epochs : int
        Number of full passes over ``train_loader``.
    return_losses : bool, optional
        When True, return the list of per-batch loss values. Defaults to
        False so existing calls keep returning None (and a notebook cell
        ending in ``train(...)`` does not echo a long list).

    Returns
    -------
    list of float or None
        Per-batch losses in iteration order if ``return_losses`` is True,
        otherwise None.
    """
    losses = []
    model.to(device)
    optim = model.get_optimizer()

    for epoch in range(epochs):
        print(f"{model.description} epoch: {epoch}")
        for batch in train_loader:
            # The model lives on `device`; move the batch there too so the
            # forward pass does not mix CPU and CUDA tensors. This is a no-op
            # when the batch is already on that device.
            batch = batch.to(device)

            optim.zero_grad()
            # Only the loss (first element) is needed here.
            loss, *_ = model.step(batch)
            losses.append(loss.item())
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), clipping_value)
            optim.step()
            # Kept from the original loop: clears gradients on every model
            # parameter, including any the optimizer might not own.
            model.zero_grad()

    return losses if return_losses else None

In [None]:
def display_vq(model, dims=(0, 1)):
    """
    Plot how ``model`` encodes, quantises and reconstructs the first 300 test points.

    Six layers are drawn on a single figure:

    * ``x``              - the first 300 rows of the module-level ``X_test``
    * ``z_e_x``          - second output of ``model.forward`` (columns ``dims``)
    * ``z_q_x``          - third output of ``model.forward`` (columns ``dims``)
    * ``reconstructed``  - first output of ``model.forward`` (columns 0 and 1)
    * ``embeddings``     - rows of ``model.codebook.embedding.weight`` (columns ``dims``)
    * ``embeddings_dec`` - ``model.net.decode`` applied to those rows (columns 0 and 1)

    Each codebook row and its decoded counterpart is annotated with its index.

    Parameters
    ----------
    model
        Model exposing ``forward``, ``codebook.embedding.weight`` and
        ``net.decode``. The embeddings are read from this argument, not from
        any global variable.
    dims : sequence of two ints, optional
        Which two latent columns to plot for ``z_e_x``, ``z_q_x`` and the
        embeddings. Decoded and reconstructed points are always plotted with
        columns 0 and 1, like the input data.
    """
    d0, d1 = dims[0], dims[1]
    x_np = X_test[:300]

    codebook = model.codebook.embedding.weight
    # Gradients are not needed for plotting; inputs are placed on the same
    # device as the codebook so the forward pass does not mix devices.
    with torch.no_grad():
        x_tilde, z_e_x, z_q_x = model.forward(torch.Tensor(x_np).to(codebook.device))
        embs_dec = model.net.decode(codebook.detach())

    # Matplotlib needs host-side numpy arrays, not (possibly CUDA) tensors.
    rec_x = x_tilde.detach().cpu().numpy()
    z_e_x = z_e_x.detach().cpu().numpy()
    z_q_x = z_q_x.detach().cpu().numpy()
    embs = codebook.detach().cpu().numpy()
    embs_dec = embs_dec.detach().cpu().numpy()

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.scatter(x_np[:, 0], x_np[:, 1], s=2, c="y", alpha=0.5, label="x")
    ax.scatter(z_e_x[:, d0], z_e_x[:, d1], s=2, c="g", alpha=0.5, label="z_e_x")
    ax.scatter(z_q_x[:, d0], z_q_x[:, d1], s=100, edgecolors="b", label="z_q_x", alpha=1, marker="o", facecolors='none')
    ax.scatter(rec_x[:, 0], rec_x[:, 1], s=100, edgecolors="r", alpha=1.0, label="reconstructed", marker="o", facecolors='none')
    ax.scatter(embs[:, d0], embs[:, d1], s=20, c="b", label="embeddings", marker="x")
    # Decoded embeddings are in data space, so index them like rec_x.
    ax.scatter(embs_dec[:, 0], embs_dec[:, 1], s=20, c="r", label="embeddings_dec", marker="x")

    style = dict(size=20, color='gray')
    for i in range(embs.shape[0]):
        # Labels sit on the same coordinates as the markers drawn above.
        ax.text(embs[i, d0], embs[i, d1], f"{i}", **style)
        ax.text(embs_dec[i, 0], embs_dec[i, 1], f"{i}", **style)
    ax.legend()

    plt.show()

In [None]:
# Build a VectorQuantizedVAE with a 2-d latent space and 2-d output,
# then train it for 20 epochs on the toy data.
vqvae = VectorQuantizedVAE(
    latent_dim=2,
    K=16,
    output_dim=2,
    hidden_dim=50,
    archi="large",
    data_type="continuous",
)
train(vqvae, 20)

In [None]:
# Plot test points, their latent encodings and quantisations, the reconstructions
# and the codebook embeddings for the model currently bound to `vqvae`.
display_vq(vqvae)

### VQ-VAE with Gumbel-Softmax

In [None]:
from models.VQ import VQEmbeddingGumbel
# Repeat the toy experiment with the model's `gumbel=True` option
# (using tau=2.0 and beta=0.1), train for 20 epochs and plot the result.
# The name `vqvae` is rebound on purpose, replacing the earlier model.
vqvae = VectorQuantizedVAE(
    latent_dim=2,
    K=8,
    output_dim=2,
    gumbel=True,
    tau=2.0,
    beta=0.1,
    hidden_dim=50,
    data_type="continuous",
)
train(vqvae, 20)
display_vq(vqvae)

### MNIST example