[Original VQ-VAE paper](https://arxiv.org/pdf/1711.00937)

The basic idea:

- Take a standard VAE encoder and decoder
- At the bottleneck, use a codebook of $K$ embedding vectors $e_k$, each of dimensionality $D$
    - The codebook vectors are learned during training
- In the forward pass (during both training and inference), each latent vector output by the encoder is replaced by its nearest-neighbour codebook vector
- The nearest-neighbour lookup has no useful gradient, so gradients flow back to the encoder through a straight-through estimator at the embedding layer
- For image generation, the encoder produces a 2D grid of latent vectors, so each image is represented as a grid of discrete codebook indices
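The quantise-then-copy-the-gradient step can be sketched in a few lines of JAX. This is a minimal illustration, not the paper's code: `quantise_straight_through`, `toy_codebook`, `toy_z_e` and `downstream_loss` are hypothetical names, and the straight-through trick is written as `z_e + stop_gradient(e - z_e)`, which is one common way to express it.

```python
import jax
import jax.numpy as jnp


def quantise_straight_through(z_e, codebook):
    """Replace each row of z_e (N, D) by its nearest row of codebook (K, D).

    Returns (z_q, idx): z_q equals the chosen codebook vectors in the forward
    pass, but its gradient w.r.t. z_e is the identity (straight-through).
    """
    # squared Euclidean distance from every z_e row to every codebook row: (N, K)
    d2 = jnp.sum((z_e[:, None, :] - codebook[None, :, :]) ** 2, axis=-1)
    idx = jnp.argmin(d2, axis=1)  # index of the nearest codebook vector per row
    e = codebook[idx]  # (N, D) quantised vectors
    # forward value is e; backward pass treats quantisation as the identity.
    # e sits inside stop_gradient, so no gradient reaches the codebook this way.
    z_q = z_e + jax.lax.stop_gradient(e - z_e)
    return z_q, idx


toy_codebook = jnp.array([[0.0, 0.0], [1.0, 1.0], [-1.0, 2.0]])
toy_z_e = jnp.array([[0.9, 1.2], [-0.1, 0.1]])

z_q, idx = quantise_straight_through(toy_z_e, toy_codebook)
# idx == [1, 0]; z_q == [[1., 1.], [0., 0.]]


def downstream_loss(z_e):
    # stand-in for "decoder + reconstruction loss": sum of squares of z_q
    z_q, _ = quantise_straight_through(z_e, toy_codebook)
    return jnp.sum(z_q**2)


grad_z_e = jax.grad(downstream_loss)(toy_z_e)
# grad_z_e == 2 * z_q == [[2., 2.], [0., 0.]]: dL/dz_q copied straight onto z_e
```

Because the codebook vectors only appear inside `stop_gradient` here, the reconstruction gradient never updates the codebook; that job is left to a separate term in the loss.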


The loss function is given by
$$
L \;=\; \underbrace{-\log p\bigl(x \mid z_q(x)\bigr)}_{\text{Reconstruction Loss}}
\;+\; \underbrace{\Bigl\lVert \mathrm{sg}\bigl(z_e(x)\bigr) - e \Bigr\rVert_{2}^{2}}_{\text{Codebook Loss}}
\;+\; \underbrace{\beta \,\Bigl\lVert z_e(x) - \mathrm{sg}(e)\Bigr\rVert_{2}^{2}}_{\text{Commitment Loss}}.
$$

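where $\mathrm{sg}(\cdot)$ is the stop-gradient operator (the identity in the forward pass, with zero gradient in the backward pass), $z_e(x)$ is the encoder output and $e$ is its nearest codebook vector. The three terms are: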

1. The reconstruction loss trains the decoder, and (through the straight-through estimator) the encoder, to reproduce the input image.
2. The codebook loss moves each embedding vector towards the encoder outputs assigned to it; because of $\mathrm{sg}$, it updates only the codebook.
3. The commitment loss, weighted by $\beta$, pulls the encoder outputs towards their chosen embedding vectors so they do not grow without bound in the latent space; it updates only the encoder.
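A minimal sketch of this objective in JAX, under two stated assumptions: mean squared error stands in for $-\log p(x \mid z_q(x))$ (this matches a fixed-variance Gaussian decoder up to scale and an additive constant), and every term is averaged over its elements rather than summing the squared norms, which rescales the codebook and commitment terms relative to the equation. `vq_vae_loss` and the `*_toy` names are hypothetical, not part of any library.

```python
import jax
import jax.numpy as jnp

sg = jax.lax.stop_gradient


def vq_vae_loss(x, x_recon, z_e, e, beta=0.25):
    """Hypothetical VQ-VAE loss: MSE reconstruction + codebook + beta * commitment.

    z_e: encoder outputs; e: their nearest codebook vectors, same shape as z_e and
    gathered from the codebook array so that gradients can reach it.
    """
    recon = jnp.mean((x - x_recon) ** 2)  # decoder (and encoder, via straight-through)
    codebook_term = jnp.mean((sg(z_e) - e) ** 2)  # gradient reaches only e
    commitment_term = jnp.mean((z_e - sg(e)) ** 2)  # gradient reaches only z_e
    total = recon + codebook_term + beta * commitment_term
    return total, (recon, codebook_term, commitment_term)


# toy check: perfect reconstruction, one 2-D latent sitting at distance 1 from its code
x_toy = jnp.zeros((1, 4))
z_e_toy = jnp.array([[1.0, 0.0]])
e_toy = jnp.array([[0.0, 0.0]])


def total_loss(z_e, e):
    return vq_vae_loss(x_toy, x_toy, z_e, e)[0]


loss = total_loss(z_e_toy, e_toy)  # 0.5 (codebook) + 0.25 * 0.5 (commitment) = 0.625
g_z_e, g_e = jax.grad(total_loss, argnums=(0, 1))(z_e_toy, e_toy)
# g_z_e == [[0.25, 0.0]]: only the beta-weighted commitment term reaches the encoder output
# g_e   == [[-1.0, 0.0]]: only the codebook term reaches the embedding vector
```

The toy gradients make the division of labour concrete: the same squared distance appears twice, and the placement of $\mathrm{sg}$ decides which side each copy updates.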

The paper uses $\beta = 0.25$ and reports that results are robust for $\beta$ between 0.1 and 2.0. The latent grids used in the paper are 32x32 for ImageNet and 8x8x10 for CIFAR-10.
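These latent shapes translate directly into a compression ratio, since each code is an index costing $\log_2 K$ bits. A quick calculation, assuming 8-bit RGB inputs, the notebook's $K = 512$ for CIFAR-10, and 128x128 inputs with $K = 512$ for the paper's ImageNet setting; `pixel_bits` and `latent_bits` are hypothetical helpers.

```python
import math


def pixel_bits(height, width, channels, bits_per_channel=8):
    """Bits to store a raw image at a fixed bit depth per channel."""
    return height * width * channels * bits_per_channel


def latent_bits(h, w, c, K):
    """Bits to store an h x w x c grid of codes, each an index into a K-entry codebook."""
    return h * w * c * math.log2(K)


cifar_ratio = pixel_bits(32, 32, 3) / latent_bits(8, 8, 10, K=512)
imagenet_ratio = pixel_bits(128, 128, 3) / latent_bits(32, 32, 1, K=512)
print(f"CIFAR-10: {cifar_ratio:.2f}x, ImageNet: {imagenet_ratio:.2f}x")
# CIFAR-10: 24576 / 5760 = 4.27x, ImageNet: 393216 / 9216 = 42.67x
```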

In [32]:
import jax

from flax import nnx
import jax.numpy as jnp
from jax import random
from einops import rearrange
from genlearn.data_utils import get_cifar_dataloaders
import chex

In [33]:
trainloader, valloader = get_cifar_dataloaders()

In [34]:
K = 512  # number of codebook vectors
D = 64  # embedding dimension
h, w, c = 8, 8, 10  # latent grid per image: h*w*c codes, each a D-dim vector (32, 32, 1 for ImageNet)
batch_size = 64

In [35]:
rngs = nnx.Rngs(0)
codebook = rngs.normal((K, D))
z = rngs.normal((batch_size, h, w, c, D))

In [36]:
z_flattened = rearrange(z, "b h w c d -> (b h w c) d")
z_flattened.shape

(40960, 64)

In [None]:
from functools import partial


@jax.jit
def embed(z, codebook):
    """Map each row of z (N, D) to its nearest vector in codebook (K, D)."""

    # distance from every vector in z to every vector in the codebook:
    # the outer vmap maps over rows of z, the inner vmap over rows of the codebook
    @partial(jax.vmap, in_axes=(0, None))
    @partial(jax.vmap, in_axes=(None, 0))
    def get_distance(z_i, e_k):
        chex.assert_rank([z_i, e_k], 1)  # comparing two D-dimensional vectors
        return jnp.linalg.norm(z_i - e_k)

    distances = get_distance(z, codebook)  # (N, K) matrix of distances
    nearest_embeddings_idx = jnp.argmin(distances, axis=1)  # nearest codebook index per row of z
    return codebook[nearest_embeddings_idx]

In [38]:
z_embedded = embed(z_flattened, codebook)
z_embedded = rearrange(
    z_embedded, "(b h w c) d -> b h w c d", b=batch_size, h=h, w=w, c=c
)
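The vmapped `embed` expresses the distance computation directly, but written out it corresponds to an $(N, K, D)$ tensor of differences: for the $40960 \times 512 \times 64$ case above that is about 1.3 billion floats, although XLA may fuse the operations and avoid materialising it. A common alternative expands $\lVert z - e \rVert^2 = \lVert z \rVert^2 - 2\, z \cdot e + \lVert e \rVert^2$ so the distances come from a single matrix product. The sketch below is one way to write that; `embed_matmul`, `z_demo` and `codebook_demo` are hypothetical names, and the result is checked against a brute-force lookup.

```python
import jax
import jax.numpy as jnp


@jax.jit
def embed_matmul(z, codebook):
    """Nearest codebook row for each row of z, via ||z - e||^2 = ||z||^2 - 2 z.e + ||e||^2."""
    z_sq = jnp.sum(z**2, axis=1, keepdims=True)  # (N, 1)
    e_sq = jnp.sum(codebook**2, axis=1)  # (K,), broadcast across rows
    d2 = z_sq - 2.0 * (z @ codebook.T) + e_sq  # (N, K) squared distances from one matmul
    idx = jnp.argmin(d2, axis=1)  # squared distance has the same argmin as distance
    return codebook[idx], idx


key_z, key_e = jax.random.split(jax.random.PRNGKey(0))
z_demo = jax.random.normal(key_z, (256, 64))
codebook_demo = jax.random.normal(key_e, (512, 64))

e_demo, idx_demo = embed_matmul(z_demo, codebook_demo)

# brute-force reference that builds the full (N, K, D) difference tensor
ref_idx = jnp.argmin(
    jnp.linalg.norm(z_demo[:, None, :] - codebook_demo[None, :, :], axis=-1), axis=1
)
agreement = jnp.mean(idx_demo == ref_idx)  # should be 1.0 apart from rare float32 near-ties
```

The expanded form trades a little numerical precision (it subtracts large, similar quantities) for memory and matmul throughput, so exact ties between codebook vectors may occasionally resolve differently.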

In [None]:
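class Encoder(nnx.Module):
    def __init__(self):
        "ResNet-style encoder, down to an 8x8x10x64 latent output"
        ...


class Decoder(nnx.Module):
    def __init__(self):
        "ResNet-style decoder, upsampling 8x8x10x64 latents back to a 32x32x3 RGB output"
        ...


class Embed(nnx.Module):
    def __init__(self, K: int, D: int, *, rngs: nnx.Rngs):
        "Quantisation layer: maps the 8x8x10x64 latents to their nearest vectors in a learned codebook"
        self.codebook = nnx.Param(rngs.normal((K, D)))

    def __call__(self, ze_bhwcd):
        # flatten the latent grid, look up nearest codebook vectors with `embed`, restore the grid
        b, h, w, c, _ = ze_bhwcd.shape
        ze_flat = rearrange(ze_bhwcd, "b h w c d -> (b h w c) d")
        e_flat = embed(ze_flat, self.codebook.value)
        return rearrange(e_flat, "(b h w c) d -> b h w c d", b=b, h=h, w=w, c=c)


class VQ_VAE(nnx.Module):
    def __init__(self, K: int, D: int, *, rngs: nnx.Rngs):
        self.encoder = Encoder()
        self.embed = Embed(K, D, rngs=rngs)  # Embed needs the codebook size and an RNG
        self.decoder = Decoder()

    def __call__(self, x_bhwc):
        ze_bhwcd = self.encoder(x_bhwc)
        e_bhwcd = self.embed(ze_bhwcd)
        # straight-through estimator: the decoder sees e in the forward pass, while
        # its gradient is copied unchanged onto the encoder output ze
        zq_bhwcd = ze_bhwcd + jax.lax.stop_gradient(e_bhwcd - ze_bhwcd)
        x_recon_bhwc = self.decoder(zq_bhwcd)
        # reconstruction, the quantised vectors (for the codebook loss), the encoder latent
        return x_recon_bhwc, e_bhwcd, ze_bhwcd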

In [31]:
# grab a single batch to inspect shapes
for x, y in trainloader:
    break