In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from time import time

import torch
import torch.optim as optim
import torch.nn as nn
import torch.utils.data as data
import torch.nn.functional as F
import torch.distributions as TD

import numpy as np
from matplotlib import pyplot as plt

import sys
sys.path.append('../../homeworks') # to grab dgm_utils from ../../homeworks directory
from tqdm.notebook import tqdm

if torch.cuda.is_available():
    DEVICE = 'cuda'
    GPU_DEVICE = 2
    torch.cuda.set_device(GPU_DEVICE)
else:
    DEVICE='cpu'
# DEVICE='cpu'

import warnings
warnings.filterwarnings('ignore')

# dgm_utils
from dgm_utils import train_model, show_samples, visualize_images
from dgm_utils import visualize_2d_samples, visualize_2d_densities, visualize_2d_data

def reset_seed():
    OUTPUT_SEED = 0xBADBEEF
    torch.manual_seed(OUTPUT_SEED)
    np.random.seed(OUTPUT_SEED)

reset_seed()

# <center>Deep Generative Models</center>
## <center>Seminar 12</center>

<center><img src="pics/AIMastersLogo.png" width=600 /></center>
<center>28.11.2022</center>


## Plan

1. VQ-VAE implementation hints

2. StyleGAN details


## VQ-VAE implementation hints

Link to the original paper: [article](https://arxiv.org/pdf/1711.00937v2.pdf).

<center><img src="pics/vqvae_scheme.png" width=1000 /></center>

**Original VAE ELBO objective**:

$$L_{\text{scaled}}(\phi, \theta) = \frac{1}{N} \sum\limits_{n = 1}^{N} \left(\mathbb{E}_{z_n \sim q(z_n| x_n, \phi)} \ln p(x_n|z_n, \theta) - KL(q(z| x_n, \phi)||p(z))\right)$$

**VQ-VAE ELBO objective**: ?

$$L_{\text{scaled}}(\phi, \theta) = \frac{1}{N} \sum\limits_{n = 1}^{N} \left( \ln p(x_n|z_q(x_n| \phi), \theta)\right) - \log K + \text{<embeddings related loss>}$$

**Question 1**: Let the data images have shape `(1, w, h)` (Binarized MNIST). What is the **encoder** $z_e(\cdot | \phi)$? What are its input and output dimensionalities?

```python
# x : tensor (bs, 1, w, h) 
# D is embedding vectors dimensionality

encoded = z_e(x) # (bs, D, w_z, h_z)
encoded_perm = encoded.permute(0, 2, 3, 1) # latent codes (bs, w_z, h_z, D)

```
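A minimal sketch of one possible encoder, to make the shapes concrete. The architecture (two stride-2 convolutions), the value `D = 16` and the 28x28 input size are illustrative assumptions, not the seminar's reference solution.

```python
import torch
import torch.nn as nn

D = 16  # embedding dimensionality (illustrative value)

# Hypothetical encoder: two stride-2 convolutions shrink 28x28 down to 7x7
z_e = nn.Sequential(
    nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1),  # (bs, 32, 14, 14)
    nn.ReLU(),
    nn.Conv2d(32, D, kernel_size=4, stride=2, padding=1),  # (bs, D, 7, 7)
)

x = torch.rand(8, 1, 28, 28)                # stand-in for a batch of MNIST images
encoded = z_e(x)                            # (bs, D, w_z, h_z)
encoded_perm = encoded.permute(0, 2, 3, 1)  # (bs, w_z, h_z, D)
```

Here `w_z = h_z = 7`, so each image is described by a 7x7 grid of `D`-dimensional vectors.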

**Question 2**: What is the posterior distribution $q(z | x, \phi)$? How do we map the **encoder** output `encoded` to a sample $z \sim q(z | x, \phi)$?

* The distribution $q(z | x, \phi)$ is $\big{[}\text{Categorical}(\pi_1, \dots, \pi_K)\big{]}^{w\_z \,\times\, h\_z}$, i.e. $z \in \{1, 2, \dots K\}^{w\_z \,\times\, h\_z}$

    ```python
    # encoded_perm : tensor (bs, w_z, h_z, D)
    codes = encoded2z(encoded_perm) # (bs, w_z, h_z)
    ```

* Model the embeddings $e_1, \dots, e_K$; where $e_i \in \mathbb{R}^D$, then:

$$q(z_{i, j} = k | x, \phi) = \begin{cases}1, k = \arg\min\limits_{k'} \Vert z_e(x | \phi)_{i, j} - e_{k'} \Vert \\ 0, \text{otherwise}\end{cases}$$

```python
# K : number of embedding vectors
# D : dimensionality of embedding vectors 
embedding_module = torch.nn.Embedding(K, D)
...
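# encoded_perm : (bs, w_z, h_z, D)
# permute makes the tensor non-contiguous, so use .reshape instead of .view
flat_encoded = encoded_perm.reshape(-1, D) # (bs * w_z * h_z, D)
# pairwise Euclidean distances between encoder outputs and embeddings
distances = torch.cdist(flat_encoded, embedding_module.weight)  # (bs * w_z * h_z, K)

# What to do next? How to transform `distances` to quantized codes?
# Take the index of the NEAREST embedding (argmin, not argmax)
codes = torch.argmin(distances, dim=1) # (bs * w_z * h_z,)
codes = codes.view(encoded_perm.shape[:-1]) # (bs, w_z, h_z)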
```

<center><img src="pics/fqgan_lookup.png" width=600 /></center>

**Question 3.** How to map `codes` to decoder input $z_q(x | \phi)$?

<center><img src="pics/fqgan_cnn.png" width=600 /></center>

```python
quantized_perm = embedding_module(codes) # (bs, w_z, h_z, D)

# What to do next to prepare z_q(x)?
```

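* Permute back to the channels-first layout expected by the convolutional decoder: `quantized = quantized_perm.permute(0, 3, 1, 2)`, shape `(bs, D, w_z, h_z)`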
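The steps from Questions 2 and 3 can be put together in one function. This is a sketch under assumptions: the name `quantize` is hypothetical, and `torch.cdist` is used as the distance function (the notes leave the choice of distance helper open).

```python
import torch
import torch.nn as nn

def quantize(encoded, embedding_module):
    """Map encoder output (bs, D, w_z, h_z) to codes (bs, w_z, h_z)
    and quantized vectors z_q (bs, D, w_z, h_z)."""
    D = embedding_module.embedding_dim
    encoded_perm = encoded.permute(0, 2, 3, 1)                      # (bs, w_z, h_z, D)
    flat_encoded = encoded_perm.reshape(-1, D)                      # (bs * w_z * h_z, D)
    distances = torch.cdist(flat_encoded, embedding_module.weight)  # (bs * w_z * h_z, K)
    codes = distances.argmin(dim=1).view(encoded_perm.shape[:-1])   # (bs, w_z, h_z)
    quantized_perm = embedding_module(codes)                        # (bs, w_z, h_z, D)
    quantized = quantized_perm.permute(0, 3, 1, 2)                  # (bs, D, w_z, h_z)
    return codes, quantized

torch.manual_seed(0)
K, D = 10, 4                          # illustrative codebook size and dimensionality
embedding_module = nn.Embedding(K, D)
encoded = torch.randn(2, D, 3, 3)     # stand-in for z_e(x)
codes, quantized = quantize(encoded, embedding_module)
```

Each spatial position of `quantized` holds the codebook vector closest to the encoder output at that position.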

**Question 4**: What is the **decoder**? What are its input and output dimensionalities? (Recall: we train our model on Binarized MNIST)


* Input: `quantized` of shape `(bs, D, w_z, h_z)`

* Output: logits of shape `(bs, 2, 1, w, h)` -> `nn.CrossEntropyLoss`
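A sketch of how the decoder output can be fed to the cross-entropy loss. The decoder architecture and all sizes are illustrative assumptions; the point is the shape convention: the class dimension (2 pixel values) goes second, and the target holds integer pixel values.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

bs, D, w_z, h_z = 8, 16, 7, 7  # illustrative sizes

# Hypothetical decoder; 2 output channels = logits for pixel values {0, 1}
decoder = nn.Sequential(
    nn.ConvTranspose2d(D, 32, kernel_size=4, stride=2, padding=1),  # (bs, 32, 14, 14)
    nn.ReLU(),
    nn.ConvTranspose2d(32, 2, kernel_size=4, stride=2, padding=1),  # (bs, 2, 28, 28)
)

quantized = torch.randn(bs, D, w_z, h_z)   # stand-in for z_q(x)
x = torch.randint(0, 2, (bs, 1, 28, 28))   # binarized images, dtype long

logits = decoder(quantized).unsqueeze(2)   # (bs, 2, 1, w, h)
# Mean over batch and pixels of -ln p(x | z_q)
recon_loss = F.cross_entropy(logits, x)
```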

### VQ-VAE loss

<center><img src="pics/vqvae_scheme.png" width=1000 /></center>

**VQ-VAE ELBO**:
$$L_{\text{scaled}}(\phi, \theta) = \frac{1}{N} \sum\limits_{n = 1}^{N} \left( \ln p(x_n|z_q(x_n| \phi), \theta)\right) - \log K + \text{<embeddings related loss>}$$
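Where the constant term comes from: a short derivation, assuming a uniform prior $p(z = k) = 1 / K$ over the codebook indices and considering a single latent position. Since $q(z | x, \phi)$ puts all its mass on one index $k^*$ (with the convention $0 \ln 0 = 0$):

$$KL(q(z | x, \phi) || p(z)) = \sum\limits_{k = 1}^{K} q(z = k | x, \phi) \ln \frac{q(z = k | x, \phi)}{1 / K} = 1 \cdot \ln \frac{1}{1 / K} = \ln K$$

Under the same assumptions, a $w\_z \,\times\, h\_z$ grid of independent positions gives $w\_z \cdot h\_z \cdot \ln K$. In either case the term does not depend on $\phi$ or $\theta$, so it does not affect the gradients.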

**Question 1.** How to estimate $\frac{\partial z_q(x_n | \phi)}{\partial \phi}$?

<center><img src="pics/straight_through.png" width=700 /></center>

```python
# encoded : (bs, D, w_z, h_z)
# quantized : (bs, D, w_z, h_z)
quantized = encoded + (quantized - encoded).detach()
```
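A small check of what the straight-through trick does, using random tensors as stand-ins for the encoder output and the looked-up embeddings (the sizes are arbitrary assumptions).

```python
import torch

torch.manual_seed(0)
encoded = torch.randn(2, 4, 3, 3, requires_grad=True)  # stand-in for z_e(x)
quantized = torch.randn(2, 4, 3, 3)                    # stand-in for looked-up embeddings

quantized_st = encoded + (quantized - encoded).detach()

# Forward pass: the value equals the quantized tensor (up to float rounding)
same_forward = torch.allclose(quantized_st, quantized, atol=1e-6)

# Backward pass: the quantization step is treated as the identity map,
# so the decoder gradient is copied to the encoder output unchanged
quantized_st.sum().backward()
grad_is_identity = torch.equal(encoded.grad, torch.ones_like(encoded))
```

Both `same_forward` and `grad_is_identity` should be `True`: the decoder sees $z_q(x)$, while the encoder receives the gradient as if no quantization had happened.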

**Embeddings related loss**

$$L_{\text{emb}}(\zeta, \phi) = \frac{1}{N} \sum\limits_{n = 1}^{N} \left( \Vert e_{z_e(x_n)} - \text{stop_gradient}(z_e(x_n)) \Vert_2^2 + \beta \Vert \text{stop_gradient}(e_{z_e(x_n)}) - z_e(x_n) \Vert_2^2 \right)$$

* $\zeta$ parameterizes the embedding layer (weights of $e_i, \, i \in \{1, \dots K\}$)

**Final Loss** **(to be maximized!):**

$$L_{\text{scaled}}(\phi, \theta) = \frac{1}{N} \sum\limits_{n = 1}^{N} \left( \ln p(x_n|z_q(x_n| \phi), \theta) - \Vert e_{z_e(x_n | \phi)} - \text{stop_gradient}(z_e(x_n | \phi)) \Vert_2^2 - \beta \Vert \text{stop_gradient}(e_{z_e(x_n | \phi)}) - z_e(x_n | \phi) \Vert_2^2\right)$$
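In code it is more convenient to minimize the negated objective. The sketch below is one way to write it; the function name `vqvae_loss` is hypothetical, `beta=0.25` is the value used in the original paper, and `F.mse_loss` averages over all elements instead of summing squared norms, which only rescales the two embedding terms.

```python
import torch
import torch.nn.functional as F

def vqvae_loss(logits, x, encoded, quantized, beta=0.25):
    """Negative of the objective above (to be MINIMIZED).

    logits:    (bs, 2, 1, w, h) decoder output
    x:         (bs, 1, w, h) binarized images, dtype long
    encoded:   (bs, D, w_z, h_z) encoder output z_e(x)
    quantized: (bs, D, w_z, h_z) embeddings BEFORE the straight-through trick
    """
    recon_loss = F.cross_entropy(logits, x)                    # -ln p(x | z_q)
    codebook_loss = F.mse_loss(quantized, encoded.detach())    # moves embeddings to encoder outputs
    commitment_loss = F.mse_loss(encoded, quantized.detach())  # moves encoder outputs to embeddings
    return recon_loss + codebook_loss + beta * commitment_loss

torch.manual_seed(0)
encoded = torch.randn(2, 4, 3, 3, requires_grad=True)
quantized = torch.randn(2, 4, 3, 3, requires_grad=True)
logits = torch.randn(2, 2, 1, 28, 28)
x = torch.randint(0, 2, (2, 1, 28, 28))

loss = vqvae_loss(logits, x, encoded, quantized)
loss.backward()
```

The two `detach()` calls play the role of $\text{stop_gradient}$: the codebook term updates only the embeddings, and the commitment term updates only the encoder.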

### VQ-VAE prior

<center><img src="pics/vqvae_scheme.png" width=1000 /></center>

Recall the $q_{\text{agg}}(z)$ distribution: 

$$q_{\text{agg}}(z) = \frac{1}{N} \sum\limits_{n = 1}^{N} q(z | x_n) \, , \, z \in \mathbb{R}^{d_{\text{latent}}}$$

**Question 1**: What is the type of distribution $q_{\text{agg}}(z)$?

* $w\_z \,\times\, h\_z$-dimensional K-categorical

**Empirical fact**: at inference it is better to sample latent codes from $q_{\text{agg}}(z)$ than from the uniform prior $p(z)$

**Question 2**: How to obtain $q_{\text{agg}}(z)$ given trained **VQ-VAE** model? What NN model can be used for this purpose? What is the data to train the model? What is the objective function for such training? Input dims/output tensors?

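* One choice is an autoregressive **PixelCNN** model

* The training data are the `codes` computed for every training image $x_n$ (samples from $q(z | x_n)$)

* The loss function is `nn.CrossEntropyLoss`

* Pipeline:
    * Input `codes`: `(bs, w_z, h_z)`
    
    * One-hot encoding: `(bs, K, w_z, h_z)`
    
    * `PixelCNN` application -> output logits `(bs, K, w_z, h_z)`
    
    * `nn.CrossEntropyLoss` between the logits and the input `codes`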
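The shapes in this pipeline can be sketched as follows. A single masked convolution is used here as a stand-in for a full PixelCNN, and random integers replace the codes produced by a trained encoder; the class name `MaskedConv2d` and all sizes are illustrative assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MaskedConv2d(nn.Conv2d):
    """Convolution that sees only positions above / to the left of the current one (type 'A' mask)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        kh, kw = self.kernel_size
        mask = torch.ones(kh, kw)
        mask[kh // 2, kw // 2:] = 0  # current position and everything to its right
        mask[kh // 2 + 1:, :] = 0    # all rows below
        self.register_buffer("mask", mask[None, None])

    def forward(self, x):
        return F.conv2d(x, self.weight * self.mask, self.bias, self.stride, self.padding)

K, bs, w_z, h_z = 10, 8, 7, 7
prior = MaskedConv2d(K, K, kernel_size=7, padding=3)       # one-layer stand-in for PixelCNN

codes = torch.randint(0, K, (bs, w_z, h_z))                # (bs, w_z, h_z)
one_hot = F.one_hot(codes, K).permute(0, 3, 1, 2).float()  # (bs, K, w_z, h_z)
logits = prior(one_hot)                                    # (bs, K, w_z, h_z)
loss = F.cross_entropy(logits, codes)                      # target: the codes themselves
```

At inference, codes are sampled position by position from the trained prior and then passed through the embedding layer and the decoder.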

## StyleGAN

<center><img src="pics/stylegan_scheme.png" width=600 /></center>