[Original VQ-VAE paper](https://arxiv.org/pdf/1711.00937)

The basic idea:

- Start from a standard VAE-style encoder and decoder
- At the bottleneck, use a codebook of $K$ embedding vectors $e_k$, each of dimensionality $D$
    - All vectors share the same fixed dimension $D$, matching the encoder output
    - The vectors are learned during training
- In the forward pass, each latent vector output by the encoder is replaced by its nearest-neighbour embedding vector
- Gradients flow back from the decoder to the encoder via a straight-through estimator at the embedding layer, since the nearest-neighbour lookup itself is not differentiable
- For image generation, the encoder outputs a 2D grid of latent vectors, and each one is quantised independently
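
A minimal sketch of the nearest-neighbour lookup with a straight-through gradient, for a single latent vector. The function name `quantise_st` and the toy codebook are made up for illustration; this is one common way to write the trick in JAX, not necessarily how the paper's authors implemented it.

```python
import jax
import jax.numpy as jnp


def quantise_st(z_e, codebook):
    """Nearest-neighbour lookup with a straight-through gradient.

    z_e: (D,) encoder output. codebook: (K, D) embedding vectors.
    """
    # squared L2 distance from z_e to every codebook vector, shape (K,)
    distances = jnp.sum((codebook - z_e) ** 2, axis=-1)
    e = codebook[jnp.argmin(distances)]
    # forward value equals e; in the backward pass the lookup acts as
    # the identity with respect to z_e
    return z_e + jax.lax.stop_gradient(e - z_e)


# toy codebook with K=3, D=2 (illustrative values only)
codebook = jnp.array([[0.0, 0.0], [1.0, 1.0], [4.0, 4.0]])
z_e = jnp.array([0.9, 1.2])

z_q = quantise_st(z_e, codebook)
# gradient of a scalar function of z_q, taken with respect to z_e
grad_z_e = jax.grad(lambda z: jnp.sum(quantise_st(z, codebook) ** 2))(z_e)
```

Here `z_q` takes the value of the nearest codebook vector, while `grad_z_e` is the gradient the decoder side would see at `z_q`, copied unchanged onto the encoder output.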


The loss function is given by
$$
L \;=\; \underbrace{-\log p\bigl(x \mid z_q(x)\bigr)}_{\text{Reconstruction Loss}}
\;+\; \underbrace{\bigl\lVert \mathrm{sg}\bigl(z_e(x)\bigr) - e \bigr\rVert_{2}^{2}}_{\text{Codebook Loss}}
\;+\; \underbrace{\beta \,\bigl\lVert z_e(x) - \mathrm{sg}(e)\bigr\rVert_{2}^{2}}_{\text{Commitment Loss}}.
$$

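Here $\mathrm{sg}(\cdot)$ is the stop-gradient operator (the identity in the forward pass, with zero gradient in the backward pass) and $e$ is the embedding vector nearest to $z_e(x)$. The three terms are: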

1. The Reconstruction Loss trains the encoder and decoder to generate close matches to the input image,
2. The Codebook Loss moves the embedding vectors towards embeddings generated by the encoder,
3. The Commitment Loss incentivises the encoder to keep latents close to the embedding vectors, rather than expanding out into the space in an unbounded manner.
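
To see which term updates which parameters, here is a small sketch of the codebook and commitment terms using `jax.lax.stop_gradient`. The helper names `vq_losses` and `total` are hypothetical, the reconstruction term is left out because it depends on the decoder, and the inputs are toy values.

```python
import jax
import jax.numpy as jnp


def vq_losses(z_e, e, beta=0.25):
    """Codebook and commitment terms for one latent vector and its nearest embedding."""
    sg = jax.lax.stop_gradient
    codebook_loss = jnp.sum((sg(z_e) - e) ** 2)           # gradient reaches e only
    commitment_loss = beta * jnp.sum((z_e - sg(e)) ** 2)  # gradient reaches z_e only
    return codebook_loss, commitment_loss


def total(z_e, e):
    codebook_loss, commitment_loss = vq_losses(z_e, e)
    return codebook_loss + commitment_loss


z_e = jnp.array([0.9, 1.2])  # toy encoder output
e = jnp.array([1.0, 1.0])    # toy nearest embedding

grad_z_e, grad_e = jax.grad(total, argnums=(0, 1))(z_e, e)
```

In this sketch `grad_e` comes entirely from the codebook term and pulls the embedding towards the encoder output, while `grad_z_e` comes entirely from the commitment term, scaled by $\beta$.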

The paper uses $\beta = 0.25$ throughout, which is a good starting point. The latent grid shapes it reports are 32x32 (ImageNet) and 8x8x10 (CIFAR-10).

In [30]:
import jax

import flax.linen as nn
import jax.numpy as jnp
from jax import random

from genlearn.data import get_cifar_dataloaders

In [None]:
trainloader, valloader = get_cifar_dataloaders()

In [None]:
K = 512
D = 64
latent_shape = (8, 8, 10, D)

In [None]:
key = random.key(0)
key, subkey = random.split(key)
embeddings = random.normal(subkey, (K, D))
key, subkey = random.split(key)
z = random.normal(subkey, latent_shape)

In [None]:
def embed(z, embeddings):
    'Quantise a single latent vector to its closest codebook latent vector'
    def distance(e, z):
        'The L2 distance between two vectors'
        return jnp.linalg.norm(e - z)

    # vectorise to calc over all codebook vectors at once
    codebook_distances = jax.vmap(distance, (0, None), 0)(embeddings, z)
    nearest_embedding_ix = jnp.argmin(codebook_distances)
    return embeddings[nearest_embedding_ix]

# with an 8x8x10xD latent space, we need to vmap over the first three axes
# to fully vectorise the embedding function; the codebook is shared (None)
batch_embed = jax.vmap(embed,       in_axes=(0, None))
batch_embed = jax.vmap(batch_embed, in_axes=(0, None))
batch_embed = jax.vmap(batch_embed, in_axes=(0, None))
# quantisation preserves the latent shape
batch_embed(z, embeddings).shape

(8, 8, 10, 64)
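
As an alternative to the nested `vmap`, the same lookup can be sketched with broadcasting, which also returns the discrete code indices. The name `embed_all` is made up for this example; it expands the squared distance as $\lVert z \rVert^2 - 2\, z \cdot e + \lVert e \rVert^2$ and assumes the same `K = 512`, `D = 64` and latent shape as above.

```python
import jax.numpy as jnp
from jax import random


def embed_all(z, embeddings):
    """Quantise every latent vector in z (..., D) against embeddings (K, D)."""
    # squared L2 distance between every latent and every codebook vector
    sq_dist = (
        jnp.sum(z ** 2, axis=-1, keepdims=True)   # (..., 1)
        - 2.0 * z @ embeddings.T                  # (..., K)
        + jnp.sum(embeddings ** 2, axis=-1)       # (K,)
    )
    indices = jnp.argmin(sq_dist, axis=-1)        # (...), integer codes
    return embeddings[indices], indices


key_e, key_z = random.split(random.key(0))
embeddings = random.normal(key_e, (512, 64))
z = random.normal(key_z, (8, 8, 10, 64))

z_q, indices = embed_all(z, embeddings)
print(z_q.shape, indices.shape)
```

The argmin is unchanged by dropping the square root, so this should pick the same codebook entries as `embed` up to floating-point ties.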

In [22]:
# grab a single batch to inspect
x, y = next(iter(trainloader))