In [1]:
%run Latex_macros.ipynb

$$
\newcommand{\V}{\mathbf{V}}
\newcommand{\Vr}{\mathbb{V}}
\newcommand{\codebook}{\mathbf{E}}
\newcommand{\encoder}{\mathcal{E}}
\newcommand{\decoder}{\mathcal{D}}
\def\prs#1#2{\mathcal{p}_{#2}(#1)}
\def\qr#1{\mathcal{q}(#1)}
\def\qrs#1#2{\mathcal{q}_{#2}(#1)}
$$

- [paper](https://arxiv.org/pdf/1711.00937.pdf), [VQ-VAE-2 paper](https://arxiv.org/pdf/1906.00446.pdf)
    - [example code from paper's authors](https://github.com/deepmind/sonnet/blob/v1/sonnet/examples/vqvae_example.ipynb)
        - [`VectorQuantizer` from paper's authors](https://github.com/deepmind/sonnet/blob/v1/sonnet/python/modules/nets/vqvae.py)
            - can't run: depends on module `sonnet`
- [Colab](https://keras.io/examples/generative/vq_vae/)

Charlie Snell


# From PCA to VQ-VAE

[paper: vanilla VQ-VAE](https://arxiv.org/pdf/1711.00937.pdf)

[paper: VQ-VAE-2](https://arxiv.org/pdf/1906.00446.pdf)

The common element in the design of any Autoencoder method is 
- to create a 
latent representation $\z$ of input $\x$ 
- such that $\z$ can be (approximately) inverted
to reconstruct $\x$.


Principal Components Analysis is a type of Autoencoder that produces a latent representation $\z$ of $\x$

- $\x$ is a vector of length $n$: $\x \in \Reals^n$
- $\z$ is a vector of length $n' \le n$: $\z \in \Reals^{n'}$

Usually $n' \ll n$, which achieves *dimensionality reduction*.

This is accomplished by decomposing $\x$ into a weighted sum of $n$ *Principal Components*
$$\x = \z' \V^T$$
- where $\V \in \Reals^{n \times n}$ and $\z' \in \Reals^n$
- rows of $\V^T$ are the components

So $\x$ can be decomposed into the weighted sum (with $\z'$ specifying the weights) 
- of $n$ component vectors
- each of length $n$

Since $\z' \in \Reals^n$: there is **no** dimensionality reduction just yet.

One can view $\V^T$ as a kind of *code book*
- any $\x$ can be represented as a linear combination of the *codes* (components) in $\V^T$
$$\x = \z' \V^T$$

$\z'$ is like a translation of $\x$, using $\V$ as the vocabulary. It holds
- weights on the codes in the codebook
- rather than weights on the standard basis, i.e., the rows of the identity matrix $I \in \Reals^{n \times n}$
$$
\x = \x I
$$

Dimensionality reduction is achieved by defining $\z$ as the length $n'$ prefix of $\z'$
- $\z = \z'_{1:n'}$
- $\z \in \Reals^{n'}$

Similarly, we need only the first $n'$ components from $\V$
- $\Vr^T = \V^T_{1:n'}$
- $\Vr^T \in \Reals^{n' \times n}$

We can construct an *approximation* $\hat\x$ of $\x$ using the *reduced dimension* $\z$ and $\Vr$
$$
\hat\x = \z \Vr^T
$$
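
The decomposition and truncation above can be sketched numerically. This is a minimal sketch, assuming the principal components are obtained from the SVD of mean-centered data; the random data and the names `X`, `Vt`, `z_full`, and `n_reduced` are illustrative and do not come from the original notebook.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical data: 200 examples with n = 5 correlated features, mean-centered
X = rng.normal(size=(200, 5)) @ rng.normal(size=(5, 5))
X = X - X.mean(axis=0)

# Rows of Vt play the role of V^T: the n principal components (the "code book")
_, _, Vt = np.linalg.svd(X, full_matrices=False)

x = X[0]                 # one example, length n
z_full = x @ Vt.T        # z': weights of x on the components
x_exact = z_full @ Vt    # x = z' V^T: exact, no dimensionality reduction yet

n_reduced = 2            # n'
z = z_full[:n_reduced]   # length n' prefix of z'
Vr_t = Vt[:n_reduced]    # first n' components
x_hat = z @ Vr_t         # approximation of x from the reduced representation

print(np.allclose(x_exact, x))        # full decomposition reproduces x
print(np.linalg.norm(x - x_hat))      # approximation error after truncation
```

Increasing `n_reduced` can only shrink the approximation error, since each extra component restores one more term of the weighted sum.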


The Autoencoder (and variants such as VAE) produces $\z^\ip$, the latent representation of $\x^\ip$
- directly
- independent of any other training example $\x^{(i')}$ for $i \ne i'$

Our goal in using AE's is in generating synthetic data
- the dimensionality reduction achieved thus far was a necessity, not a goal

# Vector Quantized Autoencoder

A *Vector Quantized VAE* is a VAE with similarities to PCA.  It creates $\z$
- which is an **integer**
- that is the index of a row
- in a codebook with $K$ rows

That is: the input is represented by one of $K$ possible vectors.



The goal is **not necessarily** dimensionality reduction.

Rather, there are some advantages to a **discrete** representation of a continuously-valued vector.
- Each vector
- Drawn from the infinite space of continuously-valued vectors of length $n$
- Can be approximated by one of $K$ possible vectors of length $n$

Thus, a sequence of $T$ continuously valued vectors
- can be represented as a sequence of $T$ integers
- over a "vocabulary" defined by the code book

This is analogous to text
- sequence of words
- represented as a sequence of integer indices in a vocabulary of tokens

Once we put complex objects
- like images
- timeseries
- speech

into a representation similar to text
- we can have *mixed type* sequences
    - e.g., words, images
    

In a subsequent module we will take advantage of mixed type sequences
- to produce an image
- from a text *description* of the image
- using the "predict the next" element of a sequence technique of Large Language Models

<table>
    <tr>
        <th><center>DALL-E: Text to Image</center></th>
    </tr>  
    <tr>
        <td>
            Text input: "An illustration of a baby daikon radish in a tutu walking a dog"
        </td>
    </tr>
    <tr>
        <td>
            <center>Image output:</center>
        </td>
    </tr>
    <tr>
        <td>
            <img src="https://cdn.openai.com/dall-e/v2/samples/anthropomorphism/091432009673a3a126fdec860933cdce_26.png">
        </td>
    </tr>
    
</table>

# Details

Here is a diagram of a VQ-VAE
- that creates a latent representation of a 3-dimensional image $(w \times h \times 3)$
- as a 2-dimensional matrix of integers

There is a bit of notation: referring to the diagram should facilitate understanding the notation.

<table>
    <tr>
        <th><center>VQ-VAE</center></th>
    </tr>
    <tr>
        <td><img src="https://i.imgur.com/R9VMWD6.png" width = 200%></td>
    </tr>
</table>


In general, we assume the input has $\#S$ *spatial*  dimensions
- where each location in the spatial dimension is a vector of length $n$
- input shape $(n_1 \times n_2 \ldots \times n_{\#S} \times n)$

We will explain this diagram in steps.

First, we summarize the notation in a single spot for easy subsequent reference.

**Notation summary**

term &nbsp; &nbsp; &nbsp; &nbsp; | shape &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp;  &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; | meaning 
:---|:---|:---
$S$ | $(n_1 \times n_2 \ldots \times n_{\#S})$ | Spatial dimensions of $\#S$-dimensional input
$\x$ | $\Reals^{S \times n}$ | Input
$D$ | | length of latent vectors (Encoder output, Quantized Encoder output, Codebook entry)
$\encoder$ | | Encoder function
$\mathbb{z}_e(\x)$ | $\mathbb{R}^{S \times D}$ |  Encoder output over all spatial locations
| | $\mathbb{z}_e(\x) = \encoder(\x)$
$\z_e(\x)$ | $\mathbb{R}^D$ |  Encoder output at a **single** representative spatial location
| | $\z_e(\x) = \mathbb{z}_e(\x)_s$ for location $s$
$K$ | | number of codes
$\codebook$ | $\mathbb{R}^{K \times D}$ | Codebook/Embedding
  | | $K$ codes, each of length $D$
$e \in \codebook$ | $\mathbb{R}^D$  | code/embedding
$\mathbb{z}$ | $\{1, \ldots,  K\}^{S}$ | Latent representation over all spatial locations
| | one integer per spatial location
$\z$ | $\{1, \ldots,  K\}$ | Latent representation at a **single** representative spatial location
$\qr{ \z | \x }$ | $\text{integer} \in [1 \ldots K] $| Index $k$ of $e_k \in \codebook$ that is closest to $\z_e(\x)$ 
| | $k = \argmin{j \in [1,K]} \| \z_e(\x) - \e_j \|_2$
         | | actually: encoded as a OHE vector of length $K$
$\z_q(\x) $     | $\mathbb{R}^D$ | Quantized $\z_e(\x)$
|| $\z_q(\x) = e_k$ where  $k = \qr{\z | \x }$
| | i.e, the element of codebook that is closest to $\z_e(\x)$
  | | $\z_q(\x) \approx \z_e(\x)$
$\tilde\x$ | $\Reals^n$ | Output: reconstructed $\x$
| | $\pr{ \x | \z_q(\x) }$ 
$\decoder$  | $\mathbb{R}^{D} \rightarrow \mathbb{R}^n$ | Decoder
| | input: element of codebook $\codebook$
| | $\tilde\x = \decoder( \z_q(\x) )$

## Quantization

Let $S$ denote the spatial dimensions, e.g. $S = (n_1 \times n_2)$ for 2D

So input $\x \in \Reals^{S \times n}$
- $n$ features over $S$ spatial locations

The input $\x$ is transformed in a sequence of steps
- Encoder output (continuous value)
- Latent representation (discrete value)
    - Quantized (continuous value)

In the first step, the *Encoder* maps input $\x$ 
- to Encoder output $\mathbb{z}_e(\x)$
- an alternate representation of $D$ features over $S'$ spatial locations

(For simplicity, we will assume $S' = S$)


**Notational simplification**

In the sequel, we will apply the same transformation **to each element** of the spatial dimension

Rather than explicitly iterating over each location
we write

$$\z_e(\x) \in \Reals^D$$

to denote a representative element of $\mathbb{z}_e(\x)$ at a single location $s = (i_1, \ldots, i_{\#S})$

$$
\z_e(\x) = \mathbb{z}_e(\x)_{s}
$$

We will continue the transformation at the single representative location
- and implicitly iterate over all locations $s \in S$

The continuous (length $D$) Encoder output vector $\z_e(\x)$
- is mapped to a *latent representation* $\qr{ \z | \x }$
- which is a **discrete** value (integer)

$$k = \qr{ \z | \x } \in \{1, \ldots, K\}$$

where $k$ is the *index* of a row $\e_k$ in codebook $\codebook$
$$\e_k = \codebook_k \in \Reals^D $$

$k$ is chosen such that $\e_k$ is the row in $\codebook$ **closest to** $\z_e(\x)$

$$\begin{array}{lll}
k & = & \qr{ \z | \x } \\
  & = & \argmin{j \in \{ 1, \ldots, K \} } \| \z_e(\x) - \e_j \|_2
\end{array}
$$

We denote the codebook vector 
- closest to representative encoder output $\z_e(\x)$
- as $\z_q( \x )$
$$
\z_q( \x ) = \e_k \in \Reals^D, \quad k \in \{1, \ldots, K \}
$$
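
The nearest-code lookup can be sketched in NumPy as follows. This is a minimal illustration rather than the authors' implementation: the function name `quantize` and the toy codebook are assumptions, and indices are zero-based ($0, \ldots, K-1$) instead of the $1, \ldots, K$ used in the text.

```python
import numpy as np

def quantize(z_e, codebook):
    """Map each length-D encoder vector to its nearest codebook row.

    z_e:      array of shape (..., D), e.g. (n_1, n_2, D) for a 2D input
    codebook: array of shape (K, D)
    Returns (k, z_q): integer indices of shape (...) and quantized vectors (..., D).
    """
    D = codebook.shape[1]
    flat = z_e.reshape(-1, D)                                  # one row per spatial location
    # squared L2 distance from every location to every code: shape (num_locations, K)
    dist2 = ((flat[:, None, :] - codebook[None, :, :]) ** 2).sum(axis=-1)
    k = dist2.argmin(axis=1)                                   # index of the closest code
    z_q = codebook[k]                                          # look the code up
    return k.reshape(z_e.shape[:-1]), z_q.reshape(z_e.shape)

# Toy example: K = 3 codes of length D = 2, over a 2 x 2 spatial grid
codebook = np.array([[0.0, 0.0], [1.0, 1.0], [5.0, 5.0]])
z_e = np.array([[[0.1, -0.1], [0.9, 1.2]],
                [[4.0, 6.0], [0.6, 0.6]]])
k, z_q = quantize(z_e, codebook)
print(k)
```

The integer grid `k` is the discrete latent representation; `z_q` is what the Decoder receives.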




The Decoder tries to invert the codebook entry $\e_k = \z_q(\x) $
so that
$$
\begin{array}{lll}
\tilde\x & = & \decoder( \z_q(\x) ) \\
& \approx & \x \\
\end{array}
$$


# Discussion

## Why do we need the CNN Encoder ?

The input $\x$ is first transformed into an *alternate representation*
- the **number** and shape of the spatial dimensions are preserved (not necessary)
- but the number of features is transformed from $n$ raw features to $D \ge n$ synthetic features
    - typical behavior for, e.g., an image classifier
    

The part of the VQ-VAE after the initial CNN
- reduces the size of the **feature dimension** from $D$ to 1
- this is the primary source of dimensionality reduction
    - the raw $n$ of image input is usually only $n=3$ channels

It may be useful for the CNN to *down-sample* spatial dimension $S$ to a smaller $S'$

For example
- 3 CNN layers, each with stride 2
- will reduce a 2D image of spatial dimension $(n_1 \times n_2)$
- to spatial dimension $(\frac{n_1}{8} \times \frac{n_2}{8})$

This replaces each $(8 \times 8 \times n)$ *patch* of raw input
- with a single vector of length $D$
- that summarizes the $(8 \times 8)$ patch

One possible role (not strictly necessary) for the CNN Encoder
- is to replace large spatial dimensions
- by smaller "summaries" of local neighborhoods (patches)
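
The down-sampling arithmetic can be checked with a few lines of Python. This sketch assumes "same" padding, so that each strided layer maps a spatial size $n$ to $\lceil n / \text{stride} \rceil$; the helper name `downsampled_shape` and the example image size are hypothetical.

```python
import math

def downsampled_shape(spatial, num_layers, stride=2):
    """Spatial shape after `num_layers` strided layers, assuming "same" padding."""
    shape = tuple(spatial)
    for _ in range(num_layers):
        shape = tuple(math.ceil(n / stride) for n in shape)
    return shape

spatial_in = (256, 256)                                   # hypothetical image size (n_1, n_2)
spatial_out = downsampled_shape(spatial_in, num_layers=3)
patch_side = 2 ** 3                                       # each output location summarizes an 8 x 8 patch
num_tokens = spatial_out[0] * spatial_out[1]              # integers in the latent representation
print(spatial_out, patch_side, num_tokens)
```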

## Why quantize ?

Quantization 
- converts the continuous $\z_e(\x)$
- into discrete $\qr{ \z | \x }$
- representing the approximation $\z_q(\x) \approx \z_e(\x) $

The Decoder inverts the approximation.

Why bother when the Quantization/De-Quantization is Lossy ?


One motivation comes from observing what happens if we *quantize and flatten* the $\#S'$-dimensional
spatial locations to a one-dimensional vector.

Quantizing replaces each patch with a single integer index.
- the integer is the index of an *image token* within a list of $K$ possible tokens

By flattening the quantized higher dimensional matrix of patches, we convert the input
- into a sequence of image tokens
- over a "vocabulary" defined by the codebook $\codebook$.


This yields an image representation
- similar to the representation of text

Thus, we open the possibility of processing sequences
of mixed text and image tokens.

### Quantized image embeddings mixed with Text: preview of DALL-E

The Large Language Model operates on a sequence of text tokens
- where the text tokens are fragments of words
- when run autoregressively
    - concatenating each output to the initial input sequence
    - the LLM shows an ability to produce a "sensible" continuation of an initial "thought"

Suppose we train a LLM on input sequences
- that start with a sequence of *text* tokens describing an image
- followed by a separator `[SEP]` token
- followed by a sequence of quantized image tokens

        <text token> <text token> ... <text token> [SEP] <image token> <image token> ...

What continuation will our trained LLM produce when given the prompt below?

        <text token> <text token> ... <text token> [SEP]
        
Hopefully:
- a sequence of *image tokens*
- that can be reconstructed
- into an image matching the description given by the text tokens !
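
The sequence layout can be sketched as follows. Everything here is an assumption made for illustration: the vocabulary size, the id chosen for `[SEP]`, and the trick of shifting image-token indices by an offset so that text and image tokens share one integer vocabulary. It is not a description of how DALL-E assigns token ids.

```python
import numpy as np

TEXT_VOCAB_SIZE = 1000               # hypothetical number of text tokens (ids 0..999)
SEP_ID = TEXT_VOCAB_SIZE             # hypothetical id for the [SEP] token
IMAGE_OFFSET = TEXT_VOCAB_SIZE + 1   # image token k is stored as IMAGE_OFFSET + k

def build_sequence(text_ids, image_index_grid):
    """<text tokens> [SEP] <image tokens>, with the index grid flattened row by row."""
    image_tokens = np.asarray(image_index_grid).reshape(-1)
    return list(text_ids) + [SEP_ID] + [int(IMAGE_OFFSET + k) for k in image_tokens]

def split_sequence(seq, grid_shape):
    """Recover the text ids and the grid of codebook indices from a mixed sequence."""
    sep = seq.index(SEP_ID)
    grid = np.array(seq[sep + 1:]) - IMAGE_OFFSET
    return seq[:sep], grid.reshape(grid_shape)

text_ids = [5, 17, 42]                    # hypothetical text token ids
index_grid = np.array([[0, 3], [2, 1]])   # hypothetical quantized image (codebook indices)
seq = build_sequence(text_ids, index_grid)
text_back, grid_back = split_sequence(seq, index_grid.shape)
print(seq)
```

At generation time, only the part after `[SEP]` would be produced by the model; `split_sequence` then recovers the index grid whose codebook entries are passed to the Decoder.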

That is the key idea behind a Text to Image model called DALL-E that we will discuss in a later module.

There remains an important technical detail
- the embedding space of text and image are distinct
- they need to be merged into a common embedding space

We will visit these issues in the module on CLIP.

# Loss function

The Loss function for the VQ-VAE entails several parts
- Reconstruction loss
    - enforcing constraint that reconstructed image is similar to input
    $$\tilde{\x} \approx \x
    $$
- Vector Quantization (VQ) Loss:
    - enforcing similarity of quantized encoder output and actual encoder output
    $$
    \z_q(\x) \approx \z_e(\x)
    $$
- Commitment Loss
    - a constraint that prevents the Quantization of $\z_e(\x)$ from alternating rapidly between code book entries

The Reconstruction Loss term is the familiar Maximum Likelihood objective
- written as minimizing the negative of the log likelihood, as usual
$$
- \log \pr{ \x | \z_q(\x) } 
$$

The Vector Quantization Loss is more complex

$$
\| \text{sg} ( \z_e(\x) ) - \z_q(\x) \|_2^2
$$


The `sg` operator is the *Stop Gradient* operator.

We will explain this in more detail below and give reference to a `VectorQuantizer` layer type.


Commitment Loss:
$$
\| \z_e(\x) - \text{sg} ( \z_q(\x) )  \|_2^2
$$

The Commitment and Vector Quantization losses are similar except for the placement of the Stop Gradient.

The Stop Gradient in the Commitment Loss treats the embedding $\z_q(\x)$ as a constant: this term updates only the Encoder weights (and thus, $\z_e(\x)$), pulling the Encoder output toward its chosen embedding.

The Stop Gradient in the Vector Quantization Loss treats the Encoder output $\z_e(\x)$ as a constant: this term updates only the embeddings, pulling the chosen embedding toward the Encoder output.

This prevents a feedback loop 
- Encoder updating $\z_e(\x)$ reduces Reconstruction Loss *assuming* embeddings remain constant
- But changing Encoder output results in embeddings being updated
- So embeddings *do not* remain constant 
- The net effect may not be a reduction in Reconstruction Loss


Which parts of the architecture are responsible for each Loss component?
- The Decoder is responsible for the Reconstruction Loss (through the term $\tilde{\x}$)
- The Encoder (through the term $\z_e(\x)$) is responsible for
    - The Reconstruction Loss 
    - The Commitment Loss
- The embeddings $\codebook$ are updated via the Vector Quantization Loss
    - Does not affect the Encoder or Decoder weights

Straight Through Estimation (discussed below) causes the gradient from Reconstruction Loss to "by-pass" $\codebook$
- effectively, for the purpose of gradient/weight update: 
$$\z_q(\x) = \z_e(\x)$$
If there were no Vector Quantization Loss, the Reconstruction Loss would not lead to Embeddings $\codebook$ being updated


Loss function

$$\begin{array}{llll}
\loss(\x, \decoder(\e)) & = & \| \x - \decoder(\e) \|_2^2 & \text{Reconstruction Loss} \\
& & + \| \text{sg}[\encoder(\x)] - \e \|_2^2  & \text{VQ Loss (codebook loss): train codebook entry } \e \\
& & + \beta \| \text{sg}[\e] - \encoder(\x) \|_2^2 & \text{Commitment Loss: force } \encoder(\x) \text{ to be close to codebook entries} \\
& & \text{where } \e = \z_q(\x)
\end{array}
$$

Need the stop gradient operator $\text{sg}$ to control the mutual dependence
- of $\encoder(\x)$ and $\e$
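
The forward-pass value of the three terms can be sketched in NumPy. Since `sg` is the identity on the forward pass, the VQ and Commitment terms have the same value up to the factor $\beta$; they differ only in which parameters would receive gradient, which plain NumPy does not model. The function name, the toy inputs, and the value `beta=0.25` are assumptions made for this illustration.

```python
import numpy as np

def vqvae_loss_terms(x, x_tilde, z_e, z_q, beta=0.25):
    """Forward-pass values of the three loss terms (no gradients are computed here)."""
    reconstruction = np.sum((x - x_tilde) ** 2)
    # VQ loss: sg is on z_e, so in training only the codebook entry z_q would be updated
    vq = np.sum((z_e - z_q) ** 2)
    # Commitment loss: sg is on z_q, so in training only the Encoder would be updated
    commitment = beta * np.sum((z_q - z_e) ** 2)
    return reconstruction, vq, commitment, reconstruction + vq + commitment

x       = np.array([1.0, 2.0])    # hypothetical input
x_tilde = np.array([1.0, 1.0])    # hypothetical reconstruction
z_e     = np.array([0.5, 0.5])    # hypothetical encoder output
z_q     = np.array([1.0, 1.0])    # hypothetical nearest codebook entry
recon, vq, commit, total = vqvae_loss_terms(x, x_tilde, z_e, z_q)
print(recon, vq, commit, total)
```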

# Straight-through Estimation and the Stop Gradient operator sg

Gradient Descent is the algorithm that we use to find values for a model's weights that minimize the model's Loss Function.

Recall: it works by recursively (backwards from head to input) layer by layer
- updating the partial of the Loss with respect to the layer's inputs
    - respectively: the partial of the Loss with respect to each operation

But there is a problem in the Quantization operation
- argmin is not differentiable !
- it is not continuous at the point that its value switches between $k$ and $k' \ne k$
    - For example,
        - Non-unique arguments: when $\e_k = \e_{k'}$ for $k \ne k'$
        - small changes in the arguments cause a change from $k$ to $k'$

The non-differentiability of certain operators led to the creation of the Stop Gradient operator `sg`


$$
\begin{array}{llll}
\text{sg}(\x) & = & \x \\
\frac{\partial \, \text{sg}(\x)}{\partial \y} & = & 0 & \text{for all } \y
\end{array}
$$

It is the identity operation on the Forward pass.

But on the Backward pass (Gradient Descent) it treats its argument as if it were a constant.


## Straight through estimation

The Stop Gradient operator can be used in conjunction with *Straight Through Estimation*.

Let's recall the definition of the Loss Gradient

Let
$$\loss'_\llp = \frac{\partial \loss}{\partial \y_\llp}$$ 
denote the derivative of $\loss$ with respect to the output of layer $\ll$, i.e., $\y_\llp$.

This is called the **loss gradient**.
- although we state this with respect to a "layer-ed" architecture this is for notational convenience only
- the same is true if we replace "layer" with "operator" whose input is denoted $\y_{(\ll-1)}$ and output denoted $\y_\llp$

Back propagation inductively updates the Loss Gradient from the output of layer $\ll$ to its inputs (e.g., prior layer's output $\y_{(\ll-1)}$)
- Given $\loss'_\llp$
- Compute $\loss'_{(\ll-1)}$
- Using the chain rule

$$
\begin{array}{lll}
\loss'_{(\ll-1)} & = & \frac{\partial \loss}{\partial \y_{(\ll-1)}} \\
         & = & \frac{\partial \loss}{\partial \y_\llp} \frac{\partial \y_\llp}{\partial \y_{(\ll-1)}} \\
         & = & \loss'_\llp \frac{\partial \y_\llp}{\partial \y_{(\ll-1)}}
\end{array}
$$

The loss gradient "flows backward", from $\y_{(L+1)}$ to $\y_{(1)}$.

This is referred to as the *backward pass*.

That is:
- the upstream Loss Gradient $\loss'_\llp$
- is modulated by the local gradient $\frac{\partial \y_\llp}{\partial \y_{(\ll-1)}}$
- where the "layer" is the operation transforming input $\y_{(\ll-1)}$ to output $\y_\llp$

What happens when the operation implemented by the function that takes $\y_{(\ll-1)}$ to $\y_\llp$ is
- non-differentiable
- or has zero derivative almost everywhere
- or non-deterministic (e.g., `tf.argmin` when two inputs are identical)?

This is the case with any type of quantization operation (uses `tf.argmin`) resulting in
- $\frac{\partial \y_\llp}{\partial \y_{(\ll-1)}} = 0$ 
- and $\loss'_{(\ll-1)} = \loss'_\llp * 0 = 0$

So the quantization operation disconnects the gradient flow from the Decoder backwards to the Encoder.
- Encoder won't learn

Hence, the notion of a Straight Through Estimator was developed
- identity operation on forward pass
- with local derivative $\frac{\partial \y_\llp}{\partial \y_{(\ll-1)}}$ **defined** to be equal to $1$

We see this in
the [Colab](https://keras.io/examples/generative/vq_vae/) 
implementation of Vector Quantization (the
`VectorQuantizer` layer)
```
class VectorQuantizer(layers.Layer):
...
    def call(self, x):
...
        # Straight-through estimator.
        quantized = x + tf.stop_gradient(quantized - x)
```
Code similar to the [`VectorQuantizer` of the paper's authors](https://github.com/deepmind/sonnet/blob/v1/sonnet/python/modules/nets/vqvae.py)

The last line is a ["straight through estimator"](https://www.hassanaskary.com/python/pytorch/deep%20learning/2020/09/19/intuitive-explanation-of-straight-through-estimators.html)
- On the forward pass: the value is unchanged, since `x + (quantized - x)` equals `quantized`
- On the backward pass, the Loss gradient is passed through unchanged from upstream
    - i.e., from the output (the tensor `quantized`) to the *layer input* (denoted by formal parameter `x`; don't confuse it with the VQ-VAE's input)
    - this is because `tf.stop_gradient` causes the enclosed expression to be treated as a constant
        - hence it will contribute $0$ to the loss gradient back propagation
        

So
- `tf.stop_gradient` **kills** the gradient along one path
- the Straight Through Estimator passes it through unchanged
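
A short derivation may make the two passes explicit. It uses only the definition of $\text{sg}$ given earlier; writing the layer input as $\z_e(\x)$ and the layer output as $\y$ is notation assumed here for illustration.

$$
\begin{array}{llll}
\y & = & \z_e(\x) + \text{sg}\left( \z_q(\x) - \z_e(\x) \right) \\
   & = & \z_q(\x) & \text{forward pass: sg is the identity} \\
\frac{\partial \y}{\partial \z_e(\x)} & = & I + 0 = I & \text{backward pass: the sg term is treated as a constant} \\
\frac{\partial \loss}{\partial \z_e(\x)} & = & \frac{\partial \loss}{\partial \y} \frac{\partial \y}{\partial \z_e(\x)} = \frac{\partial \loss}{\partial \y} & \text{upstream gradient passes through unchanged}
\end{array}
$$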

In the VQ-VAE, straight through estimation
- passes the gradient from the Decoder input back to the Encoder outputs
- ignoring the quantization
- allowing the Encoder to adapt to reduce Reconstruction Loss

# Learning the distribution of latents

For a VAE, we assume a functional form for the prior distribution of latents  $\qr{\z}$
- usually Normal

The authors wish to do away with an assumption of the prior distribution $\qr{\z}$.

Retaining spatial/temporal dimensions in $\z_q(\x)$ is key to achieving this goal.


The authors *flatten* the spatial/temporal dimensions
- Assume (for example) a two dimensional $\mathbf{Z}$ with $h$ rows and $w$ columns
- $\mathbf{Z}^\ip_j$ denotes the vector of length $D$ at row $i$, column $j$ of $\mathbf{Z}$
- Flatten $\mathbf{Z}$ into a *sequence* $[\z_1, \z_2, \ldots ]$
    - where $\z_k$ is the quantization of $\mathbf{Z}^{(r)}_c$
        - for $r = \lfloor \frac{k-1}{w} \rfloor + 1, c = ((k-1) \bmod w) + 1$ (rows, columns, and sequence positions all counted from 1)
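
The index arithmetic of the flattening can be checked with a small sketch. It assumes row-major order and zero-based positions, rows, and columns; for a position $k$ counted from 1, apply the same mapping to $k-1$ and add 1 to each result. The grid size and the helper name `position_to_row_col` are hypothetical.

```python
import numpy as np

h, w = 3, 4                                   # hypothetical grid: h rows, w columns
index_grid = np.arange(h * w).reshape(h, w)   # stand-in for the grid of quantized indices
sequence = index_grid.reshape(-1)             # row-major flatten into a sequence

def position_to_row_col(pos, w):
    """Zero-based sequence position -> zero-based (row, column)."""
    return divmod(pos, w)

# every sequence position maps back to the grid cell it came from
round_trip_ok = all(
    sequence[p] == index_grid[position_to_row_col(p, w)] for p in range(h * w)
)
print(position_to_row_col(5, w), round_trip_ok)
```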




The authors then learn an autoregressive model for sequences

$$
\pr{\z_{k+1} | \z_1, \ldots, \z_k }
$$

by using an Autoregressive model (e.g., PixelCNN) to predict each element of the sequence from its predecessors.


The Autoregressive model
- learns the distribution of $\z_k$ given its predecessors, without assuming what type of distribution it comes from
- can be sampled
    - seed the model with $\z_1$, generate the rest of the sequence
    - append predicted $\z_k$ to sequence upon which $\z_{k+1}$ is conditioned
- Is trained *subsequent* to learning the Embeddings
    - future research: learn them jointly

Thus, adding the Autoregressive step facilitates generating new sample sequences from
which to generate synthetic examples.
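
The sampling loop can be sketched as follows. The "model" here is a hypothetical stand-in: a random table that conditions only on the previous token, where a trained model such as PixelCNN would condition on the whole prefix. The sizes, the random codebook, and the names `next_token_probs` and `sample_sequence` are all assumptions for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
K, D, h, w = 4, 2, 2, 3                      # hypothetical sizes
codebook = rng.normal(size=(K, D))           # stand-in for the learned embeddings

# Stand-in for a trained autoregressive model: row j holds p(next token | previous token = j)
transition = rng.dirichlet(np.ones(K), size=K)

def next_token_probs(prefix):
    """Distribution over the next token given the sequence so far (uses only the last token)."""
    return transition[prefix[-1]]

def sample_sequence(first_token, length):
    seq = [first_token]                              # seed the model with z_1
    while len(seq) < length:
        probs = next_token_probs(seq)
        seq.append(int(rng.choice(K, p=probs)))      # append the sample, then condition on it
    return seq

tokens = sample_sequence(first_token=0, length=h * w)
# look up each token in the codebook and restore the spatial layout for the Decoder
z_q = codebook[np.array(tokens)].reshape(h, w, D)
print(tokens, z_q.shape)
```

Passing `z_q` to the trained Decoder would then produce a synthetic example.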

