In [None]:
# Show the GPU, driver and CUDA version visible to this runtime before training.
!nvidia-smi

Tue Sep 10 22:31:47 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA L4                      Off | 00000000:00:03.0 Off |                    0 |
| N/A   35C    P8              12W /  72W |      1MiB / 23034MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

## Required Imports and Device Setup

In [None]:
%pip install umap-learn

Collecting umap-learn
  Downloading umap_learn-0.5.6-py3-none-any.whl.metadata (21 kB)
Collecting pynndescent>=0.5 (from umap-learn)
  Downloading pynndescent-0.5.13-py3-none-any.whl.metadata (6.8 kB)
Downloading umap_learn-0.5.6-py3-none-any.whl (85 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m85.7/85.7 kB[0m [31m3.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pynndescent-0.5.13-py3-none-any.whl (56 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.9/56.9 kB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pynndescent, umap-learn
Successfully installed pynndescent-0.5.13 umap-learn-0.5.6


In [None]:
from __future__ import print_function


import matplotlib.pyplot as plt
import numpy as np
from scipy.signal import savgol_filter


from six.moves import xrange

import umap

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.optim as optim

import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torchvision.utils import make_grid
from tqdm import tqdm

import sys

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

## Load Data

In [None]:
# Tensor conversion followed by a per-channel (x - 0.5) / 0.5 normalisation for RGB images.
simple_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

In [None]:
# Same normalisation as `simple_transform`, preceded by a 256 x 256 centre crop.
transform = transforms.Compose(
    [
        transforms.CenterCrop(size=256),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

In [None]:
# CIFAR-10 train / test splits with identical preprocessing: convert to a tensor and
# subtract 0.5 from every channel (the std of 1.0 leaves the scale untouched).
# Each split gets its own Compose instance; the train split is built first.
training_data, validation_data = (
    datasets.CIFAR10(
        root="data",
        train=is_train_split,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (1.0, 1.0, 1.0)),
        ]),
    )
    for is_train_split in (True, False)
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:03<00:00, 43478747.39it/s]


Extracting data/cifar-10-python.tar.gz to data
Files already downloaded and verified


In [None]:
# training_data = datasets.Flowers102(root="data", split="train", download=True,
#                                     transform=transforms.Compose([
#                                       transforms.RandomCrop(256),
#                                       transforms.ToTensor(),
#                                       transforms.Normalize((0.5,0.5,0.5), (1.0,1.0,1.0))
#                                   ]))
# validation_data = datasets.Flowers102(root="data", split="val", download=True,
#                                       transform=transforms.Compose([
#                                       transforms.CenterCrop(256),
#                                       transforms.ToTensor(),
#                                       transforms.Normalize((0.5,0.5,0.5), (1.0,1.0,1.0))
#                                   ]))
# testing_data = datasets.Flowers102(root="data", split="test", download=True,
#                                    transform=transforms.Compose([
#                                       transforms.CenterCrop(256),
#                                       transforms.ToTensor(),
#                                       transforms.Normalize((0.5,0.5,0.5), (1.0,1.0,1.0))
#                                   ]))

In [None]:
# data = datasets.MNIST(root="data", train=True, download=True,
#                                   transform=transforms.Compose([
#                                       transforms.ToTensor(),
#                                       transforms.Normalize((0.5,), (1.0,))
#                                   ]))


# training_data, validation_data = torch.utils.data.random_split(data, [50000, 10000])

In [None]:
# training_data = datasets.CelebA("data", split="train", download=True, transform=transform)
# validation_data = datasets.CelebA("data", split="valid", download=True, transform=transform)
# testing_data = datasets.CelebA("data", split="test", download=True, transform=transform)

In [None]:
# Number of samples in the training split (bare expression so the notebook displays it).
len(training_data)

50000

In [None]:
# Show the dataset summary: split, root directory and the transform pipeline in use.
display(training_data)

Dataset CIFAR10
    Number of datapoints: 50000
    Root location: data
    Split: Train
    StandardTransform
Transform: Compose(
               ToTensor()
               Normalize(mean=(0.5, 0.5, 0.5), std=(1.0, 1.0, 1.0))
           )

In [None]:
# universal experiment set-up

# -- optimisation --
batch_size = 8
num_epochs = 1000
learning_rate = 1e-4

# -- network widths --
num_channels = 3  # 1 for grayscale images 3 for RGB images
num_hiddens = 128
num_residual_layers = 2
num_residual_hiddens = 32

# -- codebook --
num_embeddings = 512
embedding_dim = 16
commitment_cost = 0.25

# Global step after which the adversarial term is switched on. Only runs of at least
# 5000 epochs enable it (after 80% of all steps); shorter runs use sys.maxsize, so the
# `global_step > disc_start` check in the training loop never fires.
if num_epochs >= 5000:
    disc_start = int(0.8 * (len(training_data) / batch_size) * num_epochs)
else:
    disc_start = sys.maxsize

log_interval = 100

debug = False

In [None]:
# define dataloaders
# Training batches are reshuffled each epoch; validation keeps dataset order and a fixed batch of 32.
training_loader = DataLoader(dataset=training_data, batch_size=batch_size,
                             shuffle=True, pin_memory=True)

validation_loader = DataLoader(dataset=validation_data, batch_size=32,
                               shuffle=False, pin_memory=True)

In [None]:
# compute dataset variance (used to normalise the reconstruction loss)
#
# Averaging the per-batch `x.var()` values is not the variance of the dataset: it drops
# the spread of the batch means and changes with the shuffle order. Accumulate the first
# two moments over every value instead and form the variance once at the end.
value_count = 0
value_sum = 0.0
value_sq_sum = 0.0

for x, _ in training_loader:
    batch64 = x.double()  # float64 so the running sums do not lose precision
    value_count += batch64.numel()
    value_sum += batch64.sum().item()
    value_sq_sum += (batch64 * batch64).sum().item()

value_mean = value_sum / value_count
# Population variance E[x^2] - E[x]^2, kept as a 0-dim float32 tensor as before.
data_variance = torch.tensor(value_sq_sum / value_count - value_mean ** 2, dtype=torch.float32)

print(f'Dataset variance: {data_variance}')

Dataset variance: 0.06132998690009117


## VQ-VAE

Variational Auto Encoders (VAEs) can be thought of as what all but the last layer of a neural network is doing, namely feature extraction or separating out the data. Thus, given some data, we can think of using a neural network for representation generation.

Recall that the goal of a generative model is to estimate the probability distribution of high dimensional data such as images, videos, audio or even text by learning the underlying structure in the data as well as the dependencies between the different elements of the data. This is very useful since we can then use this representation to generate new data with similar properties. This way we can also learn useful features from the data in an unsupervised fashion.

The VQ-VAE uses a discrete latent representation mostly because many important real-world objects are discrete. For example, in images we might have categories like "Cat", "Car", etc., and it might not make sense to interpolate between these categories. Discrete representations are also easier to model since each category has a single value, whereas with a continuous latent space we would need to normalize the density function and learn the dependencies between the different variables, which could be very complex.

### Code

I have followed the code from the TensorFlow implementation by the author which you can find here [vqvae.py](https://github.com/deepmind/sonnet/blob/master/sonnet/python/modules/nets/vqvae.py) and [vqvae_example.ipynb](https://github.com/deepmind/sonnet/blob/master/sonnet/examples/vqvae_example.ipynb).

Another PyTorch implementation is found at [pytorch-vqvae](https://github.com/ritheshkumar95/pytorch-vqvae).


### Basic Idea

The overall architecture is summarized in the diagram below:

![](https://github.com/zalandoresearch/pytorch-vq-vae/blob/master/images/vq-vae.png?raw=1)

We start by defining a latent embedding space of dimension `[K, D]` where `K` is the number of embeddings and `D` is the dimensionality of each latent embedding vector, i.e. $e_i \in \mathbb{R}^{D}$. The model is comprised of an encoder and a decoder. The encoder will map the input to a sequence of discrete latent variables, whereas the decoder will try to reconstruct the input from these latent sequences.

More precisely, the model will take in batches of RGB images, say $x$, each of size 32x32 for our example, and pass them through a ConvNet encoder producing some output $E(x)$, where we make sure the number of channels is the same as the dimensionality of the latent embedding vectors. To calculate the discrete latent variable we find the nearest embedding vector and output its index.

The input to the decoder is the embedding vector corresponding to the index which is passed through the decoder to produce the reconstructed image.

Since the nearest neighbour lookup has no real gradient in the backward pass we simply pass the gradients from the decoder to the encoder  unaltered. The intuition is that since the output representation of the encoder and the input to the decoder share the same `D` channel dimensional space, the gradients contain useful information for how the encoder has to change its output to lower the reconstruction loss.

### Vector Quantizer Layer

This layer takes a tensor to be quantized. The channel dimension will be used as the space in which to quantize. All other dimensions will be flattened and will be seen as different examples to quantize.

The output tensor will have the same shape as the input.

As an example, for a `BCHW` tensor of shape `[16, 64, 32, 32]`, we will first convert it to a `BHWC` tensor of shape `[16, 32, 32, 64]` and then reshape it into `[16384, 64]`, and all `16384` vectors of size `64` will be quantized independently. In other words, the channels are used as the space in which to quantize. All other dimensions will be flattened and be seen as different examples to quantize, `16384` in this case.

In [None]:
class VectorQuantizer(nn.Module):
    """Nearest-neighbour vector quantizer whose codebook is trained through the loss.

    Args:
        num_embeddings: number of codebook vectors.
        embedding_dim: length of each codebook vector; must equal the channel count
            of the tensor passed to `forward`.
        commitment_cost: weight of the commitment term in the returned loss.
    """

    def __init__(self, num_embeddings, embedding_dim, commitment_cost):
        super().__init__()

        self._num_embeddings = num_embeddings
        self._embedding_dim = embedding_dim
        self._commitment_cost = commitment_cost

        # Codebook entries start uniformly in [-1/K, 1/K].
        self._embedding = nn.Embedding(num_embeddings, embedding_dim)
        bound = 1 / num_embeddings
        self._embedding.weight.data.uniform_(-bound, bound)

    def forward(self, inputs):
        """Quantize a BCHW tensor along its channel dimension.

        Returns:
            (loss, quantized, perplexity, encodings): the codebook + commitment loss,
            the quantized tensor in BCHW layout with straight-through gradients, the
            perplexity of the code usage in this call, and the one-hot assignment
            matrix with one row per spatial position.
        """
        codebook = self._embedding.weight

        # BCHW -> BHWC so that every row of the flattened view is one vector to quantize
        inputs = inputs.permute(0, 2, 3, 1).contiguous()
        bhwc_shape = inputs.shape
        flat_input = inputs.view(-1, self._embedding_dim)

        # Squared distances expanded as |x|^2 + |e|^2 - 2 x.e
        input_sq = torch.sum(flat_input ** 2, dim=1, keepdim=True)
        code_sq = torch.sum(codebook ** 2, dim=1)
        cross = torch.matmul(flat_input, codebook.t())
        distances = input_sq + code_sq - 2 * cross

        # One-hot encoding of the nearest code for every row
        encoding_indices = distances.argmin(dim=1).unsqueeze(1)
        encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
        encodings.scatter_(1, encoding_indices, 1)

        # Look the codes up and restore the BHWC shape
        quantized = torch.matmul(encodings, codebook).view(bhwc_shape)

        # Codebook term moves the codes towards the encoder output; the commitment
        # term moves the encoder output towards the codes.
        q_latent_loss = F.mse_loss(quantized, inputs.detach())
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)
        loss = q_latent_loss + self._commitment_cost * e_latent_loss

        # Straight-through estimator: forward value is the code, gradient goes to `inputs`
        quantized = inputs + (quantized - inputs).detach()

        usage = encodings.mean(dim=0)
        perplexity = torch.exp(-torch.sum(usage * torch.log(usage + 1e-10)))

        # BHWC -> BCHW
        return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encodings

We will also implement a slightly modified version which uses exponential moving averages to update the embedding vectors instead of an auxiliary loss. This has the advantage that the embedding updates are independent of the choice of optimizer for the encoder, decoder and other parts of the architecture. For most experiments the EMA version trains faster than the non-EMA version.

In [None]:
class VectorQuantizerEMA(nn.Module):
    """Vector quantizer whose codebook follows an exponential moving average (EMA).

    The codebook is not trained through the loss. In training mode every forward pass
    moves each code towards the mean of the encoder vectors assigned to it. Only the
    commitment term is returned as a loss.

    Args:
        num_embeddings: number of codebook vectors.
        embedding_dim: length of each codebook vector; must equal the channel count
            of the tensor passed to `forward`.
        commitment_cost: weight of the commitment term in the returned loss.
        decay: EMA decay applied to the cluster sizes and to the per-code sums.
        epsilon: Laplace-smoothing constant that keeps unused codes from dividing by zero.
    """

    def __init__(self, num_embeddings, embedding_dim, commitment_cost, decay, epsilon=1e-5):
        super(VectorQuantizerEMA, self).__init__()

        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings

        self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
        self._embedding.weight.data.normal_()
        self._commitment_cost = commitment_cost

        self.register_buffer('_ema_cluster_size', torch.zeros(num_embeddings))
        # Kept as a Parameter so it stays in `parameters()` and the state dict as before.
        # It never receives a gradient; it is only rewritten in place by the EMA update.
        self._ema_w = nn.Parameter(torch.empty(num_embeddings, self._embedding_dim))
        self._ema_w.data.normal_()

        self._decay = decay
        self._epsilon = epsilon

    @torch.no_grad()
    def _ema_update(self, encodings, flat_input):
        """Apply one EMA step to the cluster sizes, the per-code sums and the codebook.

        All tensors are updated in place. The Parameter objects an optimizer (or any
        other holder) already references therefore stay valid, and no autograd graph
        is recorded for the statistics.
        """
        self._ema_cluster_size.copy_(
            self._ema_cluster_size * self._decay
            + (1 - self._decay) * torch.sum(encodings, 0))

        # Laplace smoothing of the cluster size
        n = torch.sum(self._ema_cluster_size)
        self._ema_cluster_size.copy_(
            (self._ema_cluster_size + self._epsilon)
            / (n + self._num_embeddings * self._epsilon) * n)

        dw = torch.matmul(encodings.t(), flat_input)
        self._ema_w.copy_(self._ema_w * self._decay + (1 - self._decay) * dw)

        self._embedding.weight.copy_(self._ema_w / self._ema_cluster_size.unsqueeze(1))

    def forward(self, inputs):
        """Quantize a BCHW tensor along its channel dimension.

        Returns:
            (loss, quantized, perplexity, encodings): the commitment loss, the
            quantized tensor in BCHW layout with straight-through gradients, the
            perplexity of the code usage in this call, and the one-hot assignment
            matrix with one row per spatial position.
        """
        # convert inputs from BCHW -> BHWC
        inputs = inputs.permute(0, 2, 3, 1).contiguous()
        input_shape = inputs.shape

        # Flatten input
        flat_input = inputs.view(-1, self._embedding_dim)

        # Calculate distances
        distances = (torch.sum(flat_input**2, dim=1, keepdim=True)
                    + torch.sum(self._embedding.weight**2, dim=1)
                    - 2 * torch.matmul(flat_input, self._embedding.weight.t()))

        # Encoding
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
        encodings.scatter_(1, encoding_indices, 1)

        # Quantize and unflatten (uses the codebook as it was before this step's update)
        quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)

        # Use EMA to update the embedding vectors
        if self.training:
            self._ema_update(encodings, flat_input)

        # Loss
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)
        loss = self._commitment_cost * e_latent_loss

        # Straight Through Estimator
        quantized = inputs + (quantized - inputs).detach()
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        # convert quantized from BHWC -> BCHW
        return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encodings

### Encoder & Decoder Architecture

The encoder and decoder architecture is based on a ResNet and is implemented below:

In [None]:
class Residual(nn.Module):
    """Residual block: `x + Conv1x1(ReLU(Conv3x3(ReLU(x))))`.

    Args:
        in_channels: channels of the input tensor.
        num_hiddens: channels produced by the block; must equal `in_channels`
            for the skip connection to add up.
        num_residual_hiddens: channels of the intermediate 3x3 convolution.
    """

    def __init__(self, in_channels, num_hiddens, num_residual_hiddens):
        super(Residual, self).__init__()
        self._block = nn.Sequential(
            # Deliberately NOT in-place: an in-place ReLU here would overwrite `x`
            # before the skip connection reads it, turning the block into
            # relu(x) + block(x) and mutating the caller's tensor.
            nn.ReLU(),
            nn.Conv2d(in_channels=in_channels,
                      out_channels=num_residual_hiddens,
                      kernel_size=3, stride=1, padding=1, bias=False),
            # In-place is safe here: it only touches the fresh convolution output.
            nn.ReLU(True),
            nn.Conv2d(in_channels=num_residual_hiddens,
                      out_channels=num_hiddens,
                      kernel_size=1, stride=1, bias=False)
        )

    def forward(self, x):
        """Return `x` plus the block's output; `x` itself is left unmodified."""
        return x + self._block(x)


class ResidualStack(nn.Module):
    """A chain of `Residual` blocks followed by a final ReLU.

    Args:
        in_channels: channels of the input tensor (passed to every block).
        num_hiddens: channels produced by every block.
        num_residual_layers: how many blocks to chain.
        num_residual_hiddens: intermediate width of every block.
    """

    def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
        super().__init__()
        self._num_residual_layers = num_residual_layers

        blocks = []
        for _ in range(num_residual_layers):
            blocks.append(Residual(in_channels, num_hiddens, num_residual_hiddens))
        self._layers = nn.ModuleList(blocks)

    def forward(self, x):
        """Run `x` through every block in order and apply a ReLU to the result."""
        for block in self._layers:
            x = block(x)
        return F.relu(x)

In [None]:
class Encoder(nn.Module):
    """Convolutional encoder: two stride-2 convolutions, a 3x3 convolution and a residual stack.

    Args:
        in_channels: channels of the input image tensor.
        num_hiddens: channels of the output feature map (the first convolution
            produces half as many).
        num_residual_layers: number of blocks in the residual stack.
        num_residual_hiddens: intermediate width of the residual blocks.
    """

    def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
        super().__init__()
        half_hiddens = num_hiddens // 2

        # Each of the two 4x4 convolutions uses stride 2.
        self._conv_1 = nn.Conv2d(in_channels, half_hiddens,
                                 kernel_size=4, stride=2, padding=1)
        self._conv_2 = nn.Conv2d(half_hiddens, num_hiddens,
                                 kernel_size=4, stride=2, padding=1)
        # Stride-1 convolution that keeps the spatial size
        self._conv_3 = nn.Conv2d(num_hiddens, num_hiddens,
                                 kernel_size=3, stride=1, padding=1)
        self._residual_stack = ResidualStack(in_channels=num_hiddens,
                                             num_hiddens=num_hiddens,
                                             num_residual_layers=num_residual_layers,
                                             num_residual_hiddens=num_residual_hiddens)

    def forward(self, inputs):
        """Map a BCHW image batch to a feature map with `num_hiddens` channels."""
        hidden = F.relu(self._conv_1(inputs))
        hidden = F.relu(self._conv_2(hidden))
        return self._residual_stack(self._conv_3(hidden))

In [None]:
class Decoder(nn.Module):
    """Convolutional decoder: 3x3 convolution, residual stack and two stride-2 transposed convolutions.

    Args:
        in_channels: channels of the (quantized) latent tensor.
        num_channels: channels of the reconstructed image.
        num_hiddens: width of the residual stack (the first transposed convolution
            produces half as many channels).
        num_residual_layers: number of blocks in the residual stack.
        num_residual_hiddens: intermediate width of the residual blocks.
    """

    def __init__(self, in_channels, num_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
        super().__init__()
        half_hiddens = num_hiddens // 2

        self._conv_1 = nn.Conv2d(in_channels, num_hiddens,
                                 kernel_size=3, stride=1, padding=1)

        self._residual_stack = ResidualStack(in_channels=num_hiddens,
                                             num_hiddens=num_hiddens,
                                             num_residual_layers=num_residual_layers,
                                             num_residual_hiddens=num_residual_hiddens)

        # Each of the two 4x4 transposed convolutions uses stride 2.
        self._conv_trans_1 = nn.ConvTranspose2d(num_hiddens, half_hiddens,
                                                kernel_size=4, stride=2, padding=1)
        self._conv_trans_2 = nn.ConvTranspose2d(half_hiddens, num_channels,
                                                kernel_size=4, stride=2, padding=1)

    def forward(self, inputs):
        """Map a BCHW latent batch to a reconstruction with `num_channels` channels."""
        features = self._residual_stack(self._conv_1(inputs))
        upsampled = F.relu(self._conv_trans_1(features))
        # No activation on the output layer
        return self._conv_trans_2(upsampled)

In [None]:
class Discriminator(nn.Module):
    """Fully convolutional discriminator that returns a one-channel map of scores.

    Args:
        num_channels: channels of the input images.
        num_filters_last: base filter count; deeper layers use multiples of it,
            capped at 8x.
        n_layers: number of Conv -> BatchNorm -> LeakyReLU stages after the stem.
    """

    def __init__(self, num_channels, num_filters_last=64, n_layers=3):
        super().__init__()

        # Stem: strided convolution without normalisation
        blocks = [nn.Conv2d(num_channels, num_filters_last, 4, 2, 1), nn.LeakyReLU(0.2)]

        prev_mult = 1
        for depth in range(1, n_layers + 1):
            mult = min(2 ** depth, 8)
            # Every stage halves the resolution except the last one, which uses stride 1
            stride = 1 if depth == n_layers else 2
            blocks.extend((
                nn.Conv2d(num_filters_last * prev_mult, num_filters_last * mult, 4,
                          stride, 1, bias=False),
                nn.BatchNorm2d(num_filters_last * mult),
                nn.LeakyReLU(0.2, True),
            ))
            prev_mult = mult

        # Head: project to a single score channel
        blocks.append(nn.Conv2d(num_filters_last * prev_mult, 1, 4, 1, 1))
        self.model = nn.Sequential(*blocks)

    def forward(self, x):
        """Return the raw (unbounded) score map for the image batch `x`."""
        return self.model(x)

### Training Experiments

We use the hyperparameters from the author's code:

In [None]:
# hyperparameters

# EMA decay handed to Model: any value > 0 selects VectorQuantizerEMA,
# otherwise the loss-trained VectorQuantizer is used.
decay = 0.99

In [None]:
class Model(nn.Module):
    """VQ-VAE autoencoder: Encoder -> 1x1 convolution -> vector quantizer -> Decoder.

    Args:
        num_hiddens: feature width of the encoder and decoder.
        num_channels: channels of the input and reconstructed images.
        num_residual_layers: residual blocks per stack.
        num_residual_hiddens: intermediate width of the residual blocks.
        num_embeddings: number of codebook vectors.
        embedding_dim: length of each codebook vector.
        commitment_cost: weight of the commitment loss.
        decay: EMA decay; a value > 0 selects `VectorQuantizerEMA`, otherwise
            `VectorQuantizer` is used.
    """

    def __init__(self, num_hiddens, num_channels, num_residual_layers, num_residual_hiddens,
                 num_embeddings, embedding_dim, commitment_cost, decay=0):
        super(Model, self).__init__()

        self._encoder = Encoder(num_channels, num_hiddens,
                                num_residual_layers,
                                num_residual_hiddens)
        # 1x1 convolution that maps encoder features to the codebook dimensionality
        self._pre_vq_conv = nn.Conv2d(in_channels=num_hiddens,
                                      out_channels=embedding_dim,
                                      kernel_size=1,
                                      stride=1)
        if decay > 0.0:
            print(f'Performing EMA updates with decay rate: {decay}...')
            self._vq_vae = VectorQuantizerEMA(num_embeddings, embedding_dim,
                                              commitment_cost, decay)
        else:
            self._vq_vae = VectorQuantizer(num_embeddings, embedding_dim,
                                           commitment_cost)
        self._decoder = Decoder(embedding_dim,
                                num_channels,
                                num_hiddens,
                                num_residual_layers,
                                num_residual_hiddens)

    def calculate_lambda(self, perceptual_loss, gan_loss, epsilon=1e-4, max_lambda=1e4, scale=0.8):
        '''Calculate the adaptive weight (lambda) of the adversarial loss term.

        Lambda is the ratio of the gradient norms of the two losses with respect to
        the weight of the decoder's last layer, so the adversarial term is rescaled
        to the magnitude of the reconstruction term.

        Args:
            perceptual_loss: scalar reconstruction loss whose graph reaches the decoder.
            gan_loss: scalar generator loss whose graph reaches the decoder.
            epsilon: added to the adversarial gradient norm to avoid division by zero.
            max_lambda: upper clamp for the ratio.
            scale: constant factor applied after clamping.

        Returns:
            A detached 0-dim tensor `scale * clamp(ratio, 0, max_lambda)`.
        '''
        ell = self._decoder._conv_trans_2 # the last layer of the decoder
        ell_weight = ell.weight
        # retain_graph: both losses are back-propagated again by the caller
        perceptual_loss_gradients = torch.autograd.grad(perceptual_loss, ell_weight, retain_graph=True)[0]
        gan_loss_gradients = torch.autograd.grad(gan_loss, ell_weight, retain_graph=True)[0]

        # epsilon guards the denominator, so it is added to the norm itself. Adding it
        # to every gradient element before taking the norm would bias the ratio by an
        # amount that grows with the number of weights.
        lambda_factor = torch.norm(perceptual_loss_gradients) / (torch.norm(gan_loss_gradients) + epsilon)
        lambda_factor = torch.clamp(lambda_factor, min=0.0, max=max_lambda).detach()

        return scale * lambda_factor

    def forward(self, x):
        """Encode, quantize and decode `x`.

        Returns:
            (loss, x_recon, perplexity, z): the quantizer loss, the reconstruction,
            the codebook perplexity and the pre-quantization latents.
        """
        z = self._encoder(x)
        z = self._pre_vq_conv(z)
        loss, quantized, perplexity, _ = self._vq_vae(z)
        x_recon = self._decoder(quantized)

        return loss, x_recon, perplexity, z

In [None]:
def init_weights(m):
    """Re-initialise a convolution or batch-norm module in place.

    Intended for `nn.Module.apply`. Modules whose class name contains "Conv" get
    weights drawn from N(0, 0.02). Modules whose class name contains "BatchNorm"
    get weights drawn from N(1, 0.02) and a zero bias. Anything else is left untouched.

    Args:
        m (nn.Module): Module to initialize the weights of.
    """
    kind = type(m).__name__
    if "Conv" in kind:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif "BatchNorm" in kind:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

In [None]:
# Build the discriminator on the target device, then re-initialise its Conv/BatchNorm
# weights. `apply` returns the module, so the cell also displays its layout.
discriminator = Discriminator(num_channels=num_channels).to(device)
discriminator.apply(init_weights)

Discriminator(
  (model): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1), bias=False)
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
  )
)

In [None]:
model = Model(num_hiddens, num_channels, num_residual_layers, num_residual_hiddens,
              num_embeddings, embedding_dim,
              commitment_cost, decay).to(device)

# Separate optimizers: one for the autoencoder, one for the discriminator.
optimizer = optim.Adam(model.parameters(), lr=learning_rate, amsgrad=False)
opt_disc = optim.Adam(discriminator.parameters(), lr=learning_rate, amsgrad=False)

print('='*50)
print('Model Summary:')
print(model.eval())
print('-'*50)
print('Experiment Settings:')
print(f'- Number of epochs: {num_epochs}')
print(f'- Learning rate: {learning_rate}')
print(f'- Batch size: {batch_size}')
print(f'- Commitment cost: {commitment_cost}')
print(f'- Number of Embeddings: {num_embeddings}')
print(f'- Embedding dimension: {embedding_dim}')
print('='*50)

# Printing the summary switched the model to eval mode; switch back before training.
model.train()
train_res_recon_error = []
train_res_perplexity = []
train_res_vq_loss = []
train_res_gan_loss = []
train_res_psnr = []
disc_factor = 0.

global_step = 0

for epoch in range(num_epochs):
    with tqdm(training_loader, unit="batch") as pbar:
        for data, _ in pbar:
            data = data.to(device)

            vq_loss, data_recon, perplexity, latents = model(data)

            mse = F.mse_loss(data_recon, data)
            recon_error = mse / data_variance  # variance-normalised MSE (NMSE)
            loss = recon_error + vq_loss

            if global_step > disc_start:
                disc_factor = 0.2
            # While the factor is 0 the adversarial terms contribute nothing, so the
            # discriminator passes and the gradient-norm computation are skipped.
            adversarial = disc_factor > 0

            # ---- autoencoder (generator) update ----
            if adversarial:
                g_loss = -torch.mean(discriminator(data_recon))
                loss = loss + disc_factor * model.calculate_lambda(recon_error, g_loss) * g_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # ---- discriminator update ----
            if adversarial:
                disc_real = discriminator(data)
                # detach(): the hinge loss must train the discriminator only and must
                # not push gradients back into the autoencoder.
                disc_fake = discriminator(data_recon.detach())

                disc_loss_real = torch.mean(F.relu(1. - disc_real))
                disc_loss_fake = torch.mean(F.relu(1. + disc_fake))
                gan_loss = disc_factor * 0.5 * (disc_loss_real + disc_loss_fake)

                # zero_grad() here also discards the gradients that the generator
                # loss left on the discriminator parameters.
                opt_disc.zero_grad()
                gan_loss.backward()
                opt_disc.step()
                gan_loss_value = gan_loss.item()
            else:
                gan_loss_value = 0.0

            # compute PSNR from the plain MSE, not the variance-normalised error.
            # Peak value 1.0 assumes the data span a range of 1 ([-0.5, 0.5] with the
            # Normalize(0.5, 1.0) transform used for the datasets) -- adjust if that changes.
            psnr = 10 * torch.log10(1 / mse.detach())

            train_res_recon_error.append(recon_error.item())
            train_res_perplexity.append(perplexity.item())
            train_res_vq_loss.append(vq_loss.item())
            train_res_gan_loss.append(gan_loss_value)
            train_res_psnr.append(psnr.item())

            pbar.set_description(f"Epoch {epoch + 1}/{num_epochs}")

            # Running means over the last 100 steps
            pbar.set_postfix(recon_error=np.mean(train_res_recon_error[-100:]),
                             perplexity=np.mean(train_res_perplexity[-100:]),
                             vq_loss=np.mean(train_res_vq_loss[-100:]),
                             gan_loss=np.mean(train_res_gan_loss[-100:]),
                             psnr=np.mean(train_res_psnr[-100:]))

            global_step += 1

Performing EMA updates with decay rate: 0.99...
Model Summary:
Model(
  (_encoder): Encoder(
    (_conv_1): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (_conv_2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (_conv_3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (_residual_stack): ResidualStack(
      (_layers): ModuleList(
        (0-1): 2 x Residual(
          (_block): Sequential(
            (0): ReLU(inplace=True)
            (1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (2): ReLU(inplace=True)
            (3): Conv2d(32, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
        )
      )
    )
  )
  (_pre_vq_conv): Conv2d(128, 16, kernel_size=(1, 1), stride=(1, 1))
  (_vq_vae): VectorQuantizerEMA(
    (_embedding): Embedding(512, 16)
  )
  (_decoder): Decoder(
    (_conv_1): Conv2d(16, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)

Epoch 1/1000: 100%|██████████| 6250/6250 [02:07<00:00, 49.04batch/s, gan_loss=0, perplexity=118, psnr=10.8, recon_error=0.0849, vq_loss=0.0153]
Epoch 2/1000:  61%|██████    | 3789/6250 [01:16<00:47, 51.62batch/s, gan_loss=0, perplexity=170, psnr=11.4, recon_error=0.0735, vq_loss=0.0235]

In [None]:
# Switch to evaluation mode: the EMA quantizer only updates its codebook while training.
model.eval()

In [None]:
# Save the autoencoder weights (state dict only) to vqgan.pt in the working directory.
# The discriminator is not part of `model` and is therefore not saved here.
torch.save(model.state_dict(), 'vqgan.pt')

In [None]:
# Total number of scalar parameters in the encoder.
sum(p.numel() for p in model._encoder.parameters())

### Plot Evaluations for the VQ-GAN Training

In [None]:
# Savitzky-Golay smoothing of the per-step training curves (window of 201 steps, degree-7 polynomial).
train_res_recon_error_smooth = savgol_filter(train_res_recon_error, window_length=201, polyorder=7)
train_res_perplexity_smooth = savgol_filter(train_res_perplexity, window_length=201, polyorder=7)
train_res_vq_loss_smooth = savgol_filter(train_res_vq_loss, window_length=201, polyorder=7)
# The GAN-loss curve is not smoothed here:
# train_res_gan_loss_smooth = savgol_filter(train_res_gan_loss, 201, 7)
train_res_psnr_smooth = savgol_filter(train_res_psnr, window_length=201, polyorder=7)

In [None]:
f = plt.figure(figsize=(16,10))

# Panels 1-3 share one layout: smoothed curve on a log-scaled y axis.
for _slot, (_curve, _title) in enumerate((
        (train_res_recon_error_smooth, 'Smoothed NMSE.'),
        (train_res_perplexity_smooth, 'Smoothed perplexity.'),
        (train_res_vq_loss_smooth, 'Smoothed VQ loss.')), start=1):
    ax = f.add_subplot(2, 2, _slot)
    ax.plot(_curve)
    ax.set_yscale('log')
    ax.set_title(_title)
    ax.set_xlabel('iteration')

# Panel 4 shows PSNR on a linear axis. The GAN-loss panel that could go here is disabled:
# ax.plot(train_res_gan_loss_smooth)
# ax.set_yscale('log')
# ax.set_title('Smoothed GAN loss.')
# ax.set_xlabel('iteration')
ax = f.add_subplot(2, 2, 4)
ax.plot(train_res_psnr_smooth)
ax.set_title('Smoothed PSNR.')
ax.set_xlabel('iteration')

### View Reconstructions

In [None]:
model.eval()

# First validation batch, moved to the model's device
(valid_originals, _) = next(iter(validation_loader))
valid_originals = valid_originals.to(device)

# Inference only: without no_grad every activation of the forward pass would be kept
# alive in an autograd graph that is never used.
with torch.no_grad():
    # Pre-quantization latents, their quantized version and the decoded images
    vq_output_eval = model._pre_vq_conv(model._encoder(valid_originals))
    _, valid_quantize, _, _ = model._vq_vae(vq_output_eval)
    valid_reconstructions = model._decoder(valid_quantize)

In [None]:
(train_originals, _) = next(iter(training_loader))
train_originals = train_originals.to(device)

# Reconstruct through the whole autoencoder. The quantizer on its own expects encoder
# features with `embedding_dim` channels, not raw images, and its second return value
# is the quantized latent rather than a reconstruction.
model.eval()  # keep the EMA codebook frozen while evaluating
with torch.no_grad():
    _, train_reconstructions, _, _ = model(train_originals)

In [None]:
# Default matplotlib figure size (width, height in inches) from here on.
plt.rcParams["figure.figsize"] = (15, 10)

In [None]:
def show(img):
    """Display a channels-first (C, H, W) image tensor with both axes hidden.

    The tensor is detached and moved to the CPU first, so tensors that live on
    the GPU or still track gradients can be shown directly.

    :param img: image tensor in (C, H, W) layout, e.g. the output of make_grid
    """
    npimg = img.detach().cpu().numpy()
    # matplotlib expects channels-last (H, W, C)
    image = plt.imshow(np.transpose(npimg, (1,2,0)), interpolation='nearest', cmap='gray')
    image.axes.get_xaxis().set_visible(False)
    image.axes.get_yaxis().set_visible(False)

In [None]:
# Index of the PCA direction (column of V) used by the single-channel latent views below.
channel_idx = 0

In [None]:
# Latent space images
latents = vq_output_eval.permute(0, 2, 3, 1).contiguous().view(-1, embedding_dim)
U, S, V = torch.pca_lowrank(latents)
projections = torch.matmul(latents, V[:, :num_channels]).view(32, num_channels, 7, 7) # project to 3 channel to view the latents as RGB images
latent_imgs = projections.cpu().data
for i in range(len(latent_imgs)):
    for j in range(num_channels):
      tmp = latent_imgs[i,j,:,:]
      tmp -= tmp.min()
      tmp /= tmp.max()
      latent_imgs[i,j,:,:] = tmp
show(make_grid(latent_imgs), )

In [None]:
# Project the latents onto a single principal direction and view the result as
# one-channel images; the sizes come from the latent tensor instead of being hard-coded.
n_imgs, _, lat_h, lat_w = vq_output_eval.shape
projections_top = torch.matmul(latents, V[:, channel_idx]).view(n_imgs, 1, lat_h, lat_w)
latent_imgs_top = projections_top.detach().cpu()
# Min-max normalise every image to [0, 1]; the clamp avoids 0/0 on constant images.
latent_imgs_top = latent_imgs_top - latent_imgs_top.amin(dim=(2, 3), keepdim=True)
latent_imgs_top = latent_imgs_top / latent_imgs_top.amax(dim=(2, 3), keepdim=True).clamp_min(1e-12)
show(make_grid(latent_imgs_top), )

In [None]:
# Latent space reconstructions
reconstructed_latents = valid_quantize.permute(0, 2, 3, 1).contiguous().view(-1, embedding_dim)
U, S, V = torch.pca_lowrank(reconstructed_latents)
projections = torch.matmul(reconstructed_latents, V[:, :num_channels]).view(32, num_channels, 7, 7) # project to 3 channel to view the latents as RGB images
latent_imgs = projections.cpu().data
for i in range(len(latent_imgs)):
    for j in range(num_channels):
      tmp = latent_imgs[i,j,:,:]
      tmp -= tmp.min()
      tmp /= tmp.max()
      latent_imgs[i,j,:,:] = tmp
show(make_grid(latent_imgs), )

In [None]:
# Project the quantised latents onto a single principal direction and view the
# result as one-channel images; sizes come from the quantised tensor itself.
n_imgs, _, lat_h, lat_w = valid_quantize.shape
projections_recon_top = torch.matmul(reconstructed_latents, V[:, channel_idx]).view(n_imgs, 1, lat_h, lat_w)
latent_imgs_recon_top = projections_recon_top.detach().cpu()
# Min-max normalise every image to [0, 1]; the clamp avoids 0/0 on constant images.
latent_imgs_recon_top = latent_imgs_recon_top - latent_imgs_recon_top.amin(dim=(2, 3), keepdim=True)
latent_imgs_recon_top = latent_imgs_recon_top / latent_imgs_recon_top.amax(dim=(2, 3), keepdim=True).clamp_min(1e-12)
show(make_grid(latent_imgs_recon_top), )

In [None]:
# Mean squared error between the pre-quantisation latents and their quantised counterparts.
print(F.mse_loss(latents, reconstructed_latents))

In [None]:
# Show the validation batch as a grid; 0.5 is added before display.
# NOTE(review): presumably this undoes a -0.5 shift applied by the data transform - confirm.
show(make_grid(valid_originals.cpu()+0.5))

In [None]:
# Add the same 0.5 offset used when displaying the originals, then clip every
# channel of every image to [0, 1] (the previous loop only clipped channel 0).
reconstructed_imgs = torch.clamp(valid_reconstructions.detach().cpu() + 0.5, 0, 1)
show(make_grid(reconstructed_imgs), )

In [None]:
# compute PSNR (peak value taken as 1) on the validation batch; no gradients are needed
with torch.no_grad():
    mse = F.mse_loss(valid_originals, valid_reconstructions)
    psnr = 10 * torch.log10(1 / mse)
# print a plain float, matching the DL-GAN evaluation cell
print(f'PSNR: {psnr.item()}')

### View Embedding

In [None]:
# 2-D UMAP embedding of the codebook vectors (cosine metric, 3 neighbours).
# NOTE(review): no random_state is set, so the layout is expected to differ between runs.
proj = umap.UMAP(n_neighbors=3,
                 min_dist=0.1,
                 metric='cosine').fit_transform(model._vq_vae._embedding.weight.data.cpu())

In [None]:
# Scatter plot of the two UMAP coordinates, one point per codebook vector.
plt.scatter(proj[:,0], proj[:,1], alpha=0.3)

### Visualize the Codebook Vectors

In [None]:
# Raw weight tensor of the quantiser's embedding layer.
# Note: `.data` does not copy, so in-place edits to `codebook` would change the model.
codebook = model._vq_vae._embedding.weight.data
print(codebook.shape)

In [None]:
def visualize_codebook(codebook, nrows, ncols):
    """Plot the first ``nrows * ncols`` codebook vectors as small grayscale tiles.

    Each vector is min-max normalised to [0, 1] on a private copy, so the
    tensor passed in (which may alias the model weights) is never modified.
    A vector whose length is a perfect square is drawn as a square tile
    (length 16 gives 4 x 4); any other length is drawn as a 1 x dim strip.

    :param codebook: 2-D tensor, one codebook vector per row
    :param nrows: number of tile rows in the figure
    :param ncols: number of tile columns in the figure
    """
    import math

    fig = plt.figure(figsize=(10, 10))
    assert nrows * ncols <= codebook.shape[0]

    dim = codebook.shape[1]
    side = math.isqrt(dim)
    tile_shape = (side, side) if side * side == dim else (1, dim)

    for i in range(nrows * ncols):
        ax = fig.add_subplot(nrows, ncols, i + 1)
        embedding_vector = codebook[i].detach().cpu().reshape(tile_shape)
        # out-of-place arithmetic: in-place ops here would write through to the codebook
        embedding_vector = embedding_vector - embedding_vector.min()
        peak = embedding_vector.max()
        if peak > 0:  # leave constant vectors at zero instead of dividing 0 by 0
            embedding_vector = embedding_vector / peak
        ax.imshow(embedding_vector.numpy(), cmap='gray')
        ax.axis('off')

    plt.show()

In [None]:
# Draw the first 16 x 16 = 256 codebook vectors.
visualize_codebook(codebook, 16, 16)

## DL-GAN

### Dictionary Learning Bottleneck

In [None]:
class DictLearn(nn.Module):
    """
    Dictionary-learning bottleneck.

    Every spatial position of the encoder output is treated as one signal and
    approximated by a sparse combination of unit-norm dictionary atoms. The
    sparse codes come from Batch Orthogonal Matching Pursuit; the dictionary is
    an ``nn.Embedding`` weight of shape (embedding_dim, num_embeddings), i.e.
    one atom per column, and is trained through the loss returned by
    ``forward``.

    See:
    - M. Aharon, M. Elad, and A. Bruckstein, "K-SVD: An Algorithm for Designing Overcomplete Dictionaries for Sparse Representation," IEEE Trans. Signal Processing, vol. 54, no. 11, pp. 4311-4322, 2006.
    - Rubinstein, R., Zibulevsky, M. and Elad, M., "Efficient Implementation of the K-SVD Algorithm using Batch Orthogonal Matching Pursuit," CS Technion, 2008.
    """
    def __init__(self, num_embeddings, embedding_dim, commitment_cost, eta, sparsity_level):
        """
        class constructor for DictLearn

        :param num_embeddings: number of dictionary atoms (K)
        :param embedding_dim: dimension of each atom / latent vector (D)
        :param commitment_cost: weight of the encoder (commitment) term of the loss
        :param eta: step size used by ``update_dictionary``
        :param sparsity_level: maximum sparsity (number of non-zero coefficients) of the representation, reduces to K-Means (Vector Quantization) when set to 1
        """
        super(DictLearn, self).__init__()

        self._num_embeddings = num_embeddings
        self._embedding_dim = embedding_dim
        self._commitment_cost = commitment_cost
        self._eta = eta
        self._sparsity_level = sparsity_level

        # D x K matrix: one atom per column, random initialisation, unit-norm columns
        self._dictionary = nn.Embedding(embedding_dim, num_embeddings)
        self._dictionary.weight.data.normal_(0, 1)
        self._normalize_dictionary()

        # Sparse codes (K x N) of the most recent forward pass. Kept as a plain
        # tensor rather than an nn.Parameter: its shape depends on the batch, so
        # it must be able to change between calls and has no place in state_dict.
        self._gamma = None
        self._A = None
        self._B = None

    def _normalize_dictionary(self):
        """Rescale every atom (column of the dictionary) to unit L2 norm, in place."""
        weight = self._dictionary.weight.data
        # clamp guards against dividing by a zero-norm atom
        weight.copy_(weight / torch.linalg.norm(weight, dim=0, keepdim=True).clamp_min(1e-12))

    def forward(self, z_e):
        """
        Sparse-code the encoder output and return its dictionary reconstruction.

        :param z_e: encoder output of shape (B, embedding_dim, H, W)
        :return: tuple ``(loss, e_latent_loss, recon, z_e, perplexity, encodings)``:
                 ``loss`` is the commitment term plus the dictionary term,
                 ``e_latent_loss`` is the commitment term alone,
                 ``recon`` is the reconstruction in (B, D, H, W) layout with a straight-through gradient,
                 ``z_e`` is the detached input in channels-last (B, H, W, D) layout,
                 ``perplexity`` summarises how evenly the atoms are used,
                 ``encodings`` are the sparse codes of shape (K, B*H*W)
        """
        if z_e.dim() != 4 or z_e.shape[1] != self._embedding_dim:
            raise ValueError('expected input of shape (B, {}, H, W), got {}'.format(
                self._embedding_dim, tuple(z_e.shape)))

        # channels-last so that the D values of one spatial position are contiguous
        z_e = z_e.permute(0, 2, 3, 1).contiguous()
        ze_shape = z_e.shape
        # D x N matrix with one column per spatial position, i.e. each column is a data point
        signals = z_e.view(-1, self._embedding_dim).t()

        """
        Sparse Coding Stage
        """
        self._normalize_dictionary()
        gamma = self.update_gamma(signals.detach(), self._dictionary.weight.detach(), debug=False)
        self._gamma = gamma  # remembered for update_dictionary

        encodings = gamma # sparse codes

        # compute reconstruction; gradients reach the dictionary only (the codes are constants)
        recon = (self._dictionary.weight @ gamma).t().reshape(ze_shape)

        # compute loss
        e_latent_loss = F.mse_loss(recon.detach(), z_e) * self._commitment_cost # latent loss from encoder
        loss = e_latent_loss + F.mse_loss(recon, z_e.detach()) # reconstruction loss

        # straight-through gradient estimator
        recon = z_e + (recon - z_e).detach()

        # compute perplexity
        avg_probs = torch.mean(encodings.bool().float() / self._sparsity_level, dim=1) # convert nonzero entries of encodings to 1.0
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        return loss, e_latent_loss, recon.permute(0, 3, 1, 2).contiguous(), z_e.detach(), perplexity, encodings

    def update_dictionary(self, z_e, dictionary, t):
        """Take one explicit update step on the dictionary, then re-normalise its atoms.

        The step moves the dictionary against ``(D @ gamma - z_e) @ sign(gamma)^T``
        scaled by ``eta / N``, using the sparse codes stored by the last ``forward`` call.

        Background on online dictionary updates:
        - Mairal J, Bach F, Ponce J, Sapiro G. Online dictionary learning for sparse coding.
          In Proceedings of the 26th annual international conference on machine learning 2009 Jun 14 (pp. 689-696).

        :param z_e: signals that produced the stored codes, either as a (D, N) matrix
                    or in the channels-last (B, H, W, D) layout returned by ``forward``
        :param dictionary: unused, kept for interface compatibility
        :param t: unused, kept for interface compatibility
        """
        if self._gamma is None:
            raise RuntimeError('update_dictionary() needs the sparse codes of a previous forward() call')
        if z_e.dim() != 2:
            # channels-last latents -> D x N signal matrix
            z_e = z_e.reshape(-1, self._embedding_dim).t()

        with torch.no_grad():
            weight = self._dictionary.weight.data
            residual = weight @ self._gamma - z_e
            weight.copy_(weight - (self._eta / z_e.shape[1]) * residual @ torch.sign(self._gamma).t())
        self._normalize_dictionary()

    def update_gamma(self, signals, dictionary, debug=False):
        """sparse coding stage

        Implemented using the Batch Orthogonal Matching Pursuit (OMP) algorithm.
        The progressive Cholesky update assumes unit-norm dictionary atoms.

        Reference:
        - Rubinstein, R., Zibulevsky, M. and Elad, M., "Efficient Implementation of the K-SVD Algorithm using Batch Orthogonal Matching Pursuit," CS Technion, 2008.

        :param signals: input signals to be sparsely coded, shape (embedding_dim, num_signals)
        :param dictionary: dictionary with one atom per column, shape (embedding_dim, num_atoms)
        :param debug: print the largest residual norm after every selection step
        :return: sparse codes of shape (num_atoms, num_signals) with at most
                 ``sparsity_level`` non-zero entries per column
        """
        embedding_dim, num_signals = signals.shape
        num_atoms = dictionary.shape[1]
        dictionary_t = dictionary.t() # save the transpose of the dictionary for faster computation
        gram_matrix = dictionary_t.mm(dictionary) # the Gram matrix, dimension: num_atoms x num_atoms
        corr_init = dictionary_t.mm(signals).t() # initial correlations, transposed to make num_signals the first dimension
        gamma = torch.zeros_like(corr_init) # placeholder for the sparse coefficients, num_signals x num_atoms

        corr = corr_init
        L = torch.ones(num_signals, 1, 1, dtype=signals.dtype, device=signals.device) # progressive Cholesky factor of the Gram matrix restricted to the selected atoms
        I = torch.zeros(num_signals, 0, dtype=torch.long, device=signals.device) # selected atom indices per signal
        omega = torch.ones_like(corr_init, dtype=torch.bool) # True where an atom is still selectable
        signal_idx = torch.arange(num_signals, device=signals.device)

        # never try to select more atoms than the dictionary holds
        for k in range(1, min(self._sparsity_level, num_atoms) + 1):
            # Select the atom with the largest absolute correlation. Used atoms
            # get a score of -1 so they lose even against a correlation of
            # exactly 0; picking an atom twice would make the Cholesky factor singular.
            scores = torch.abs(corr).masked_fill(~omega, -1.0)
            k_hats = torch.argmax(scores, dim=1)
            omega[signal_idx, k_hats] = False

            if k > 1: # Cholesky update
                # Gram entries between the already selected atoms and the new atom, for all signals at once
                G_ = gram_matrix[I, k_hats.unsqueeze(1)].view(num_signals, k - 1, 1)
                w = torch.linalg.solve_triangular(L, G_, upper=False).view(-1, 1, k - 1)
                # L bottom-right corner element: sqrt(1 - w.t().mm(w)); clamped because
                # round-off with nearly collinear atoms can push the argument below zero
                w_br = torch.sqrt(torch.clamp(1 - (w**2).sum(dim=2, keepdim=True), min=1e-12))

                # concatenate into the new Cholesky: L <- [[L, 0], [w, w_br]]
                k_zeros = torch.zeros(num_signals, k - 1, 1, dtype=signals.dtype, device=signals.device)
                L = torch.cat((
                    torch.cat((L, k_zeros), dim=2),
                    torch.cat((w, w_br), dim=2),
                    ), dim=1)

            # update non-zero indices
            I = torch.cat([I, k_hats.unsqueeze(1)], dim=1)

            # solve (L L^T) gamma_I = corr_init_I for the coefficients of the selected atoms
            corr_ = torch.gather(corr_init, 1, I).unsqueeze(-1)
            gamma_ = torch.cholesky_solve(corr_, L)

            # de-stack gamma into the non-zero elements
            gamma.scatter_(1, I, gamma_.squeeze(-1))

            # beta = G_I * gamma_I
            beta = gamma_.transpose(1, 2).bmm(gram_matrix[I]).squeeze(1)

            corr = corr_init - beta

            if debug:
                # actual residual of the current approximation, one norm per signal
                residual = torch.norm(signals - dictionary.mm(gamma.t()), dim=0)
                print('Step {}, residual: {:.4f}, below tolerance: {:.4f}'.format(k, residual.max().item(), (residual < 1e-7).float().mean().item()))

        return gamma.t().contiguous() # num_atoms x num_signals, one sparse code per column

### Define the Model DL-GAN

In [None]:
class DLGAN(nn.Module):
    """Auto-encoder with a dictionary-learning bottleneck.

    Data flow: encoder -> 1x1 convolution to ``embedding_dim`` channels ->
    ``DictLearn`` sparse-coding bottleneck -> decoder.
    """
    def __init__(self, num_hiddens, num_channels, num_residual_layers, num_residual_hiddens,
                 num_embeddings, embedding_dim, commitment_cost, eta, sparsity_level):
        """
        :param num_hiddens: forwarded to the encoder and decoder; also the input width of the pre-bottleneck convolution
        :param num_channels: forwarded to the encoder and decoder
        :param num_residual_layers: forwarded to the encoder and decoder
        :param num_residual_hiddens: forwarded to the encoder and decoder
        :param num_embeddings: number of dictionary atoms of the bottleneck
        :param embedding_dim: channel count of the latents entering the bottleneck
        :param commitment_cost: forwarded to the bottleneck
        :param eta: forwarded to the bottleneck
        :param sparsity_level: forwarded to the bottleneck
        """
        super(DLGAN, self).__init__()

        self._encoder = Encoder(num_channels, num_hiddens,
                                num_residual_layers,
                                num_residual_hiddens)

        # 1x1 convolution mapping the encoder features to embedding_dim channels
        self._pre_vq_conv = nn.Conv2d(in_channels=num_hiddens,
                                      out_channels=embedding_dim,
                                      kernel_size=1,
                                      stride=1)

        self._dl_bottleneck = DictLearn(num_embeddings,
                                     embedding_dim,
                                     commitment_cost=commitment_cost,
                                     eta=eta,
                                     sparsity_level=sparsity_level)

        self._decoder = Decoder(embedding_dim,
                                num_channels,
                                num_hiddens,
                                num_residual_layers,
                                num_residual_hiddens)

    def forward(self, x, global_step):
        """Encode ``x``, sparse-code the latents and decode the result.

        :param x: input batch
        :param global_step: currently unused; only the disabled explicit dictionary update referenced it
        :return: ``(loss, dl_loss, x_recon, latents, perplexity, encodings)``; everything
                 except ``x_recon`` is passed through unchanged from the bottleneck
        """
        z = self._encoder(x)
        z = self._pre_vq_conv(z)
        loss, dl_loss, sparsified, latents, perplexity, encodings = self._dl_bottleneck(z)
        # self._dl_bottleneck.update_dictionary(latents.detach(), self._dl_bottleneck._dictionary.weight, global_step)
        x_recon = self._decoder(sparsified)

        return loss, dl_loss, x_recon, latents, perplexity, encodings

    def calculate_lambda(self, perceptual_loss, gan_loss, epsilon=1e-4, max_lambda=1e4, scale=0.8):
        '''Calculate the lambda value for the loss function.

        Lambda balances the two losses by the ratio of their gradient norms with
        respect to the weight of the decoder layer ``_conv_trans_2``:
        ``scale * clamp(||grad perceptual|| / (||grad gan|| + epsilon), 0, max_lambda)``.

        :param perceptual_loss: reconstruction-type loss, still attached to the graph
        :param gan_loss: generator adversarial loss, still attached to the graph
        :param epsilon: added to the denominator so that a vanishing GAN gradient cannot cause a division by zero
        :param max_lambda: upper clamp applied before scaling
        :param scale: final multiplier
        :return: detached scalar tensor
        '''
        ell = self._decoder._conv_trans_2 # the layer whose weight gradients are compared
        ell_weight = ell.weight
        # retain_graph: both losses are back-propagated again by the caller
        perceptual_loss_gradients = torch.autograd.grad(perceptual_loss, ell_weight, retain_graph=True)[0]
        gan_loss_gradients = torch.autograd.grad(gan_loss, ell_weight, retain_graph=True)[0]

        # epsilon belongs on the norm; adding it to every gradient element would
        # make the stabiliser grow with the number of weights
        lambda_factor = torch.norm(perceptual_loss_gradients) / (torch.norm(gan_loss_gradients) + epsilon)
        lambda_factor = torch.clamp(lambda_factor, min=0.0, max=max_lambda).detach()

        return scale * lambda_factor

In [None]:
# Discriminator instance for the DL-GAN experiment, moved to the compute device;
# init_weights is applied to the module and all of its sub-modules.
dl_disc = Discriminator(num_channels).to(device)
dl_disc.apply(init_weights)

### Training Experiments

In [None]:
# hyperparameters

# maximum number of non-zero coefficients per sparse code (DictLearn's sparsity_level)
sparsity_level = 8
# passed to DLGAN as eta; DictLearn.update_dictionary divides it by the number of signals to get its step size
eta = int(embedding_dim / sparsity_level)

In [None]:
# Build the DL-GAN generator and move it to the compute device.
dlgan_model = DLGAN(num_hiddens, num_channels, num_residual_layers, num_residual_hiddens,
              num_embeddings, embedding_dim,
              commitment_cost, eta, sparsity_level).to(device)

In [None]:
# Displays the module structure; the training cell switches back with dlgan_model.train().
dlgan_model.eval()

In [None]:
# optimizer setup: one Adam optimizer for the generator (dlgan_model) and one
# for the discriminator (dl_disc), both with the same learning rate
optimizer = optim.Adam(dlgan_model.parameters(), lr=learning_rate, amsgrad=False)
opt_disc = optim.Adam(dl_disc.parameters(), lr=learning_rate, amsgrad=False)

In [None]:
dlgan_model.train()

print('='*50)
print('Experiment Settings:')
print(f'- Number of epochs: {num_epochs}')
print(f'- Learning rate: {learning_rate}')
print(f'- Batch size: {batch_size}')
print(f'- Sparsity level: {sparsity_level}')
print(f'- eta: {eta}')
print(f'- Commitment cost: {commitment_cost}')
print(f'- Number of Embeddings: {num_embeddings}')
print(f'- Embedding dimension: {embedding_dim}')
print('='*50)

# per-iteration training logs
train_res_recon_error = []
train_res_perplexity = []
train_res_dl_loss = []
train_res_gan_loss = []
train_res_psnr = []
# weight of the adversarial terms; stays 0 until global_step exceeds disc_start
disc_factor = 0.

global_step = 0

for epoch in range(num_epochs):
  with tqdm(training_loader, unit="batch") as pbar:
    pbar.set_description(f"Epoch {epoch + 1}/{num_epochs}")
    # Use the batch the progress bar yields; drawing another batch through a
    # fresh iterator on every step would waste this one and restart the loader each time.
    for i, (data, _) in enumerate(pbar):
      data = data.to(device)

      # ---------------- generator update ----------------
      optimizer.zero_grad()

      sc_loss, dl_loss, data_recon, latents, perplexity, encodings = dlgan_model(data, global_step)

      if debug:
        print(f'DEBUG: sparsity check: {torch.count_nonzero(encodings, dim=0)}')

      mse = F.mse_loss(data_recon, data)
      recon_error = mse / data_variance  # variance-normalised MSE (the NMSE shown in the plots)

      if global_step > disc_start:
        disc_factor = 0.2

      # adversarial term, judged by the discriminator that opt_disc trains (dl_disc)
      g_loss = -torch.mean(dl_disc(data_recon))
      lambda_factor = dlgan_model.calculate_lambda(recon_error, g_loss)

      loss = sc_loss + recon_error + disc_factor * lambda_factor * g_loss
      loss.backward()
      optimizer.step()

      # ---------------- discriminator update ----------------
      # zero_grad also discards the gradients the generator loss left on dl_disc
      opt_disc.zero_grad()

      disc_real = dl_disc(data)
      # detach: the discriminator loss must not back-propagate into the generator
      disc_fake = dl_disc(data_recon.detach())

      # hinge loss
      disc_loss_real = torch.mean(F.relu(1. - disc_real))
      disc_loss_fake = torch.mean(F.relu(1. + disc_fake))

      gan_loss = disc_factor * 0.5 * (disc_loss_real + disc_loss_fake)
      gan_loss.backward()
      opt_disc.step()

      # compute PSNR from the plain MSE (peak value 1), as the evaluation cell does
      psnr = 10 * torch.log10(1 / mse.detach())

      train_res_recon_error.append(recon_error.item())
      train_res_perplexity.append(perplexity.item())
      train_res_dl_loss.append(dl_loss.item())
      train_res_gan_loss.append(gan_loss.item())
      train_res_psnr.append(psnr.item())

      global_step += 1

      # running means over the last 100 iterations
      pbar.set_postfix(recon_error=np.mean(train_res_recon_error[-100:]),
                       perplexity=np.mean(train_res_perplexity[-100:]),
                       dl_loss=np.mean(train_res_dl_loss[-100:]),
                       gan_loss=np.mean(train_res_gan_loss[-100:]),
                       psnr=np.mean(train_res_psnr[-100:]))

In [None]:
# Save only the weights (state_dict) of the DL-GAN generator to 'dlgan.pt' in the current working directory.
torch.save(dlgan_model.state_dict(), 'dlgan.pt')

In [None]:
# Total number of scalar parameters in the DL-GAN encoder.
sum(p.numel() for p in dlgan_model._encoder.parameters())

### Plot Evaluations for the DL-GAN Training

In [None]:
# Smooth the per-iteration DL-GAN logs with a Savitzky-Golay filter (window of 201 samples).
# Perplexity uses polynomial order 2; the other curves use order 7.
# NOTE(review): savgol_filter's default mode needs at least as many samples as
# the window, i.e. >= 201 logged iterations - confirm before running on short runs.
train_res_recon_error_smooth = savgol_filter(train_res_recon_error, 201, 7)
train_res_perplexity_smooth = savgol_filter(train_res_perplexity, 201, 2)
train_res_dl_loss_smooth = savgol_filter(train_res_dl_loss, 201, 7)
# GAN-loss smoothing is disabled; the matching plot panel is disabled as well.
# train_res_gan_loss_smooth = savgol_filter(train_res_gan_loss, 201, 7)
train_res_psnr_smooth = savgol_filter(train_res_psnr, 201, 7)

In [None]:
# 2x2 dashboard of the smoothed DL-GAN training curves.
f = plt.figure(figsize=(16,10))
# (curve, panel title, use a logarithmic y-axis?)
panels = [
    (train_res_recon_error_smooth, 'Smoothed NMSE.', True),
    (train_res_perplexity_smooth, 'Smoothed Perplexity.', True),
    (train_res_dl_loss_smooth, 'Smoothed DL loss.', True),
    # The GAN-loss curve is not plotted; PSNR fills the fourth panel on a linear axis.
    (train_res_psnr_smooth, 'Smoothed PSNR.', False),
]
for slot, (curve, title, log_scale) in enumerate(panels, start=1):
    ax = f.add_subplot(2, 2, slot)
    ax.plot(curve)
    if log_scale:
        ax.set_yscale('log')
    ax.set_title(title)
    ax.set_xlabel('Iteration')

### View Reconstructions

In [None]:
dlgan_model.eval()

# Take one batch from the validation loader and move it to the compute device.
(valid_originals, _) = next(iter(validation_loader))
valid_originals = valid_originals.to(device)

# Inference only: no autograd graph is needed for any of the steps below.
with torch.no_grad():
    sc_output_eval = dlgan_model._pre_vq_conv(dlgan_model._encoder(valid_originals))
    # preprocess: channels-last, so the embedding_dim values of one spatial position are contiguous
    sc_output_eval = sc_output_eval.permute(0, 2, 3, 1).contiguous()
    sc_output_eval_shape = sc_output_eval.shape
    # (embedding_dim x N) signal matrix in which each column really is one data
    # point; viewing the buffer directly as (embedding_dim, -1) would mix values
    # from unrelated positions into every column
    signals_eval = sc_output_eval.view(-1, embedding_dim).t().contiguous()
    # normalize the dictionary
    dictionary = dlgan_model._dl_bottleneck._dictionary.weight.data / torch.linalg.norm(dlgan_model._dl_bottleneck._dictionary.weight.data, dim=0)
    # compute the sparse code
    valid_sc = dlgan_model._dl_bottleneck.update_gamma(signals_eval, dictionary, debug=False)
    # reconstruct the latent space: columns back to rows, rows back to (B, H, W, D), then channels first
    valid_latent = (dictionary @ valid_sc).t().contiguous()
    valid_latent = valid_latent.view(sc_output_eval_shape).permute(0, 3, 1, 2).contiguous()
    # reconstruct the images
    valid_reconstructions = dlgan_model._decoder(valid_latent)

In [None]:
# Show the validation batch as a grid; 0.5 is added before display.
# NOTE(review): presumably this undoes a -0.5 shift applied by the data transform - confirm.
show(make_grid(valid_originals.cpu() + 0.5))

In [None]:
# Add the display offset of 0.5 and clip all channels of all images to [0, 1] in one call.
reconstructed_imgs = torch.clamp(valid_reconstructions.cpu().data + 0.5, 0, 1)
show(make_grid(reconstructed_imgs), )

In [None]:
# compute psnr on the validation batch from the plain MSE, taking the peak value as 1
recon_error = F.mse_loss(valid_reconstructions, valid_originals)
psnr = 10 * torch.log10(1 / recon_error)
print(f'PSNR: {psnr.item()}')

### View Latent Space Reconstructions

In [None]:
# Index of the PCA direction (column of V) used by the single-channel latent views below.
channel_idx = 0

In [None]:
# Latent space images
original_latents = sc_output_eval.view(-1, embedding_dim)
original_latents_shape = original_latents.shape
U, S, V = torch.pca_lowrank(original_latents)
projections = torch.matmul(original_latents, V[:, :num_channels]).view(32, num_channels, 7, 7) # project to 3 channel to view the latents as RGB images
latent_imgs = projections.cpu().data
for i in range(len(latent_imgs)):
    for j in range(num_channels):
      tmp = latent_imgs[i,j,:,:]
      tmp -= tmp.min()
      tmp /= tmp.max()
      latent_imgs[i,j,:,:] = tmp
show(make_grid(latent_imgs), )

In [None]:
# Latent space images top channel
projections_top = torch.matmul(original_latents, V[:, channel_idx]).view(32, 1, 7, 7) # project to 3 channel to view the latents as RGB images
latent_imgs_top = projections_top.cpu().data
for i in range(len(latent_imgs_top)):
    for j in range(1):
      tmp = latent_imgs_top[i,j,:,:]
      tmp -= tmp.min()
      tmp /= tmp.max()
      latent_imgs_top[i,j,:,:] = tmp
show(make_grid(latent_imgs_top), )

In [None]:
# Latent space images
reconstructed_latents = valid_latent.permute(0, 2, 3, 1).contiguous().view(-1, embedding_dim)
U, S, V = torch.pca_lowrank(reconstructed_latents)
projections = torch.matmul(reconstructed_latents, V[:, :3]).view(32, num_channels, 7, 7) # project to 3 channel to view the latents as RGB images
latent_imgs = projections.cpu().data
for i in range(len(latent_imgs)):
    for j in range(num_channels):
      tmp = latent_imgs[i,j,:,:]
      tmp -= tmp.min()
      tmp /= tmp.max()
      latent_imgs[i,j,:,:] = tmp
show(make_grid(latent_imgs), )

In [None]:
# Latent space images
projections_recon_top = torch.matmul(reconstructed_latents, V[:, channel_idx]).view(32, 1, 7, 7) # project to 3 channel to view the latents as RGB images
latent_imgs_recon_top = projections_recon_top.cpu().data
for i in range(len(latent_imgs_recon_top)):
    for j in range(1):
      tmp = latent_imgs_recon_top[i,j,:,:]
      tmp -= tmp.min()
      tmp /= tmp.max()
      latent_imgs_recon_top[i,j,:,:] = tmp
show(make_grid(latent_imgs_recon_top), )

In [None]:
# Mean squared error between the encoder latents and their sparse-coded reconstruction.
print(F.mse_loss(original_latents, reconstructed_latents))

### View the Dictionary Atoms

In [None]:
# Transposed view of the learned dictionary, so that each row is one atom.
# Note: `.data` does not copy, so in-place edits to `dictionary` would change the model.
dictionary = dlgan_model._dl_bottleneck._dictionary.weight.t().data # dim: K x D
dictionary.shape

In [None]:
# 2-D UMAP embedding of the dictionary atoms (cosine metric, 3 neighbours).
# NOTE(review): no random_state is set, so the layout is expected to differ between runs.
proj = umap.UMAP(n_neighbors=3,
                 min_dist=0.1,
                 metric='cosine').fit_transform(dictionary.cpu())

In [None]:
# Scatter plot of the two UMAP coordinates, one point per dictionary atom.
plt.scatter(proj[:,0], proj[:,1], alpha=0.3)

In [None]:
def visualize_dictionary(dictionary, nrows, ncols):
    """Plot the first ``nrows * ncols`` dictionary atoms as small grayscale tiles.

    Each atom is min-max normalised to [0, 1] on a private copy, so the tensor
    passed in (which may alias the model weights) is never modified. An atom
    whose length is a perfect square is drawn as a square tile (length 16
    gives 4 x 4); any other length is drawn as a 1 x dim strip.

    :param dictionary: 2-D tensor, one atom per row
    :param nrows: number of tile rows in the figure
    :param ncols: number of tile columns in the figure
    """
    import math

    fig = plt.figure(figsize=(10, 10))
    assert nrows * ncols <= dictionary.shape[0]

    dim = dictionary.shape[1]
    side = math.isqrt(dim)
    tile_shape = (side, side) if side * side == dim else (1, dim)

    for i in range(nrows * ncols):
        ax = fig.add_subplot(nrows, ncols, i + 1)
        atom = dictionary[i].detach().cpu().reshape(tile_shape)
        # out-of-place arithmetic: in-place ops here could write through to the dictionary
        atom = atom - atom.min()
        peak = atom.max()
        if peak > 0:  # leave constant atoms at zero instead of dividing 0 by 0
            atom = atom / peak
        ax.imshow(atom.numpy(), cmap='gray')
        ax.axis('off')

    plt.show()

In [None]:
# Draw the first 16 x 16 = 256 dictionary atoms.
visualize_dictionary(dictionary, 16, 16)

In [None]:
# Prints the dictionary tensor itself (K x D); the output can be large.
print(dictionary)