<a href="https://colab.research.google.com/github/girishp1983/llama2/blob/master/categorical_vae.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# PyTorch Categorical VAE with Gumbel-Softmax

This notebook shows how to train a VAE with categorical latents using the Gumbel-softmax trick. The accompanying blog post is here: https://jxmo.io/posts/variational-autoencoders

In [None]:
import math
import numpy as np
import os
import torch
import torchvision
from typing import Tuple
from tqdm.auto import tqdm
from PIL import Image

import torch.distributions as dist
import torch.optim as optim
import torchvision.transforms as transforms

In [None]:
class Encoder(torch.nn.Module):
    """Convolutional encoder that maps an image batch to categorical logits.

    The network is three strided convolutions followed by two linear layers.
    The flattened feature size is hard-coded to ``3 * 3 * 32``, which is what
    the convolution stack produces for single-channel 28x28 inputs.

    Attributes:
        network: the sequential conv + linear stack.
        input_shape: shape of one input item as passed to the constructor
            (stored for reference; not read by the layers).
        N: number of categorical distributions.
        K: number of classes per distribution.
    """
    network: torch.nn.Module
    input_shape: torch.Size
    N: int # number of categorical distributions
    K: int # number of classes

    def __init__(self, N: int, K: int, input_shape: torch.Size):
        super().__init__()
        self.N = N
        self.K = K
        self.input_shape = input_shape
        print('N =', N, 'and K =', K)
        self.network = torch.nn.Sequential(
            torch.nn.Conv2d(1, 8, 3, stride=2, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(8, 16, 3, stride=2, padding=1),
            torch.nn.BatchNorm2d(16),
            torch.nn.ReLU(),
            torch.nn.Conv2d(16, 32, 3, stride=2, padding=0),
            torch.nn.ReLU(),
            torch.nn.Flatten(),
            torch.nn.Linear(3 * 3 * 32, 128),
            torch.nn.ReLU(),
            # One unnormalized score per (distribution, class) pair.
            torch.nn.Linear(128, N*K),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encodes an image batch `x` into categorical logits.

        Args:
            x: tensor of shape [B, C, Y, X].

        Returns:
            Tensor of shape [B, N, K]: for each item, the unnormalized logits
            of N categorical distributions over K classes. No softmax is
            applied here; sampling `z` from these logits is left to the caller.
        """
        assert len(x.shape) == 4 # x should be of shape [B, C, Y, X]
        return self.network(x).view(-1, self.N, self.K)


class Decoder(torch.nn.Module):
    """Maps a batch of [N, K] latent samples back to image space.

    Two linear layers expand the flattened latent to a 32x3x3 feature map,
    which three transposed convolutions upsample; a sigmoid squashes the
    result into (0, 1).
    """
    output_shape: torch.Size
    N: int # number of categorical distributions
    K: int # number of classes

    def __init__(self, N: int, K: int, output_shape: torch.Size):
        super().__init__()
        self.N, self.K = N, K
        self.output_shape = output_shape

        # Layers are listed in the exact order they are applied; the order
        # also fixes the parameter names inside `self.network`.
        dense_layers = [
            torch.nn.Flatten(),
            torch.nn.Linear(N*K, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 3 * 3 * 32),
            torch.nn.ReLU(),
            torch.nn.Unflatten(dim=1, unflattened_size=(32, 3, 3)),
        ]
        upsampling_layers = [
            torch.nn.ConvTranspose2d(32, 16, 3, stride=2, output_padding=0),
            torch.nn.BatchNorm2d(16),
            torch.nn.ReLU(),
            torch.nn.ConvTranspose2d(16, 8, 3, stride=2, padding=1, output_padding=1),
            torch.nn.BatchNorm2d(8),
            torch.nn.ReLU(),
            torch.nn.ConvTranspose2d(8, 1, 3, stride=2, padding=1, output_padding=1),
            torch.nn.Sigmoid(),
        ]
        self.network = torch.nn.Sequential(*dense_layers, *upsampling_layers)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """Decodes latent `z` into a reconstruction `x_hat`.

        `z` must have shape [B, N, K]: a batch dimension followed by, per
        item, N categorical distributions over K classes each. The output is
        reshaped to [-1, *output_shape].
        """
        assert z.dim() == 3 # [B, N, K]
        expected_latent_shape = (self.N, self.K)
        assert z.shape[1:] == expected_latent_shape
        reconstruction = self.network(z)
        target_shape = (-1,) + self.output_shape
        return reconstruction.view(target_shape)

In [None]:
# Prefer the GPU when CUDA is available; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def gumbel_distribution_sample(shape: torch.Size, eps=1e-20) -> torch.Tensor:
    """Samples from the Gumbel distribution given a tensor shape and value of epsilon.

    Args:
        shape: shape of the tensor of samples to draw.
        eps: small constant added inside both logarithms, purely for
            numerical stability (keeps the argument of each log above zero).

    Returns:
        A tensor of the requested shape, created by `torch.rand` (so it uses
        torch's default device and dtype). The computation is essentially
            > -log(-log(rand(shape)))
        where rand generates random numbers on U(0, 1).
    """
    uniform = torch.rand(shape)
    return -torch.log(-torch.log(uniform + eps) + eps)

def gumbel_softmax_distribution_sample(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    """Adds Gumbel noise to `logits` and applies softmax along the last dimension.

    Softmax is applied wrt a given temperature value. A higher temperature will make the softmax
    softer (less spiky). Lower temperature will make softmax more spiky and less soft. As
    temperature -> 0, this distribution approaches a categorical distribution.

    Args:
        logits: tensor of shape (b, n_classes).
        temperature: positive divisor applied to the noisy logits before the softmax.

    Returns:
        Tensor of the same shape as `logits`; each row sums to 1.
    """
    assert len(logits.shape) == 2 # (should be of shape (b, n_classes))
    # Put the noise on the logits' own device instead of a module-level global,
    # so the function also works when the logits live somewhere else.
    noise = gumbel_distribution_sample(logits.shape).to(logits.device)
    y = logits + noise
    return torch.nn.functional.softmax(y / temperature, dim=-1)

def gumbel_softmax(logits: torch.Tensor, temperature: float, batch=False) -> torch.Tensor:
    """
    Gumbel-softmax.

    Draws a relaxed (soft) categorical sample for every row of class logits.
    The result is a probability vector per row, not a hard one-hot vector.

    input: [b, n_classes], or [b, n, n_classes] when `batch` is True
    return: tensor with the same shape as the input
    """
    input_shape = logits.shape
    if batch:
        assert len(input_shape) == 3
        # Fold the batch and distribution axes together so every categorical
        # distribution becomes one row; reshape also accepts non-contiguous input.
        logits = logits.reshape(-1, input_shape[-1])
    assert len(logits.shape) == 2
    y = gumbel_softmax_distribution_sample(logits, temperature)
    return y.view(input_shape)

In [None]:
class CategoricalVAE(torch.nn.Module):
    """VAE with categorical latents sampled via the Gumbel-softmax relaxation.

    Attributes:
        encoder: module producing categorical logits phi from an input batch.
        decoder: module producing a reconstruction from a latent sample.
        temperature: initialised to 1.0. Note that `forward` uses its own
            `temperature` argument and does not read this attribute.
    """
    encoder: torch.nn.Module
    decoder: torch.nn.Module
    temperature: float

    def __init__(self, encoder: torch.nn.Module, decoder: torch.nn.Module):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.temperature = 1.0

    def forward(self, x: torch.Tensor, temperature: float = 1.0) -> Tuple[torch.Tensor, torch.Tensor]:
        """VAE forward pass. Encoder produces phi, the parameters of a categorical distribution.
        Samples from categorical(phi) using gumbel softmax to produce a z. Passes z through decoder p(x|z)
        to get x_hat, a reconstruction of x.

        Args:
            x: input batch handed to the encoder.
            temperature: Gumbel-softmax temperature used for this pass.

        Returns:
            phi: parameters of categorical distribution that produced z
            x_hat: auto-encoder reconstruction of x
        """
        phi = self.encoder(x)
        # batch=True: phi is expected to be [B, N, K]; the sample keeps that shape.
        z_given_x = gumbel_softmax(phi, temperature, batch=True)
        x_hat = self.decoder(z_given_x)
        return phi, x_hat

In [None]:
import math
import numpy as np
import os
import torch
import torchvision
import tqdm
from PIL import Image

import torch.distributions as dist
import torch.optim as optim
import torchvision.transforms as transforms

# Prefer the GPU when CUDA is available; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def load_training_data():
    """Return the MNIST training split with each image converted to a tensor.

    Data is stored under ``./data`` and requested with ``download=True``.
    """
    # TODO implement datasets better
    # Only ToTensor is applied; a Normalize((0.5,), (1.0,)) step is deliberately left out.
    to_tensor_only = transforms.Compose([transforms.ToTensor()])
    mnist_train = torchvision.datasets.MNIST(
        root='./data',
        train=True,
        transform=to_tensor_only,
        download=True,
    )
    return mnist_train

def categorical_kl_divergence(phi: torch.Tensor) -> torch.Tensor:
    """KL divergence between categorical(phi) and a uniform categorical prior.

    Args:
        phi: logits of shape [B, N, K] where B is batch, N is number of
            categorical distributions, K is number of classes.

    Returns:
        Tensor of shape [B, N] holding KL(q || p) for every distribution,
        where q = Categorical(logits=phi) and p is uniform over K classes.
    """
    B, N, K = phi.shape
    flat_logits = phi.reshape(B * N, K)
    q = dist.Categorical(logits=flat_logits)
    # Build the uniform prior on phi's own device/dtype rather than a global
    # device, so the function works wherever the logits live.
    uniform_probs = torch.full((B * N, K), 1.0 / K, dtype=phi.dtype, device=phi.device)
    p = dist.Categorical(probs=uniform_probs) # uniform bunch of K-class categorical distributions
    kl = dist.kl.kl_divergence(q, p) # kl is of shape [B*N]
    return kl.view(B, N)

In [None]:
# ---- Hyperparameters ----
K = 10 # number of classes
N = 30 # number of categorical distributions
batch_size = 64
max_steps = 50_000
model_save_interval = 5_000
initial_learning_rate = 0.001
initial_temperature = 1.0
minimum_temperature = 0.5
temperature_anneal_rate = 0.00003

# ---- Output folder for checkpoints / images ----
output_dir = os.path.join('outputs', 'categorical_vae')
os.makedirs(output_dir, exist_ok=True)

# ---- Data ----
training_images = load_training_data()
train_dataset = torch.utils.data.DataLoader(
    training_images,
    batch_size=batch_size,
    shuffle=True,
)

# Peek at one batch to learn the shape of a single image.
first_batch_images, _ = next(iter(train_dataset))
image_shape = first_batch_images[0].shape # [1, 28, 28]

# ---- Model (encoder is built before the decoder) ----
encoder = Encoder(N, K, image_shape)
decoder = Decoder(N, K, image_shape)
model = CategoricalVAE(encoder, decoder)

# ---- Optimisation state ----
parameters = list(model.parameters())
optimizer = optim.SGD(parameters, lr=initial_learning_rate, momentum=0.0)
learning_rate_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
temperature = initial_temperature
step = 0

N = 30 and K = 10


In [None]:
# Pick the compute device (GPU if CUDA is available), report it, and move the model there.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print('using device:', device)
model = model.to(device)

using device: cuda


In [None]:
# Main training loop: runs exactly `max_steps` optimisation steps, cycling
# through the data loader as many times as needed.
progress_bar = tqdm.tqdm(total=max_steps, desc='Training')
while step < max_steps:
    for data in train_dataset: # each batch is (images, labels); images have shape [B, C, Y, X]
        x = data[0].to(device)

        # Gradients accumulate across backward() calls, so clear the previous
        # step's gradients before computing new ones.
        optimizer.zero_grad()

        phi, x_hat = model(x, temperature) # phi shape: [B, N, K]; x_hat shape: [B, C, Y, X]
        # Per-pixel binary cross-entropy, summed over everything and averaged over the batch.
        reconstruction_loss = (
            torch.nn.functional.binary_cross_entropy(x_hat, x, reduction="none").sum()) / x.shape[0]
        # KL summed over the N distributions, averaged over the batch.
        kl_loss = torch.mean(
            torch.sum(categorical_kl_divergence(phi), dim=1)
        )
        loss = kl_loss + reconstruction_loss
        progress_bar.set_description(f'Training | Recon. loss = {reconstruction_loss:.7f} / KL loss = {kl_loss:.7f}')

        loss.backward()
        # Clip AFTER backward so the freshly computed gradients are the ones being clipped.
        torch.nn.utils.clip_grad_norm_(parameters, 1)
        optimizer.step()

        # Incrementally anneal temperature and learning rate.
        if step % 1000 == 1:
            temperature = np.maximum(initial_temperature*np.exp(-temperature_anneal_rate*step), minimum_temperature)
            learning_rate_scheduler.step() # should multiply learning rate by 0.9

        if (step+1) % model_save_interval == 0:
            torch.save(model.state_dict(), os.path.join(output_dir, f'save_{step}.pt'))

        step += 1
        progress_bar.update(1)

        # Stop mid-epoch once the step budget is used up instead of
        # finishing the current pass over the data.
        if step >= max_steps:
            break
progress_bar.close()

Training | Recon. loss = 106.8197479 / KL loss = 16.9503593:  25%|██▍       | 12332/50000 [03:49<11:14, 55.86it/s]