In [1]:
%matplotlib inline
%load_ext tensorboard
import math
import os
from itertools import count, groupby
from operator import itemgetter

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn, optim
from torch.distributions import Bernoulli, Normal
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from pytorch_lightning import LightningModule, Trainer

In [2]:
class BetaTCVAE(nn.Module):
    """beta-TC class VAE

    Parameters
    ----------
    encoder : torch.nn.Module
    decoder : torch.nn.Module
        pytorch networks used to encode and decode to/from the latent variables
    e : int
        The output dimension of the encoder network
    z : int
        The input dimension of the decoder network
    beta : int
        Total Correlation weight term (default=1)
    lamb : float in [0, 1]
        Dimension wise KL term is (1 - lamb)
    """
    def __init__(self, encoder, decoder, e, z, beta=1, lamb=0):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.e_dim = e
        self.z_dim = z
        self.beta = beta
        self.lamb = lamb
        # Learned Z Hyperparams
        # Q: Why logvar and not stddev?
        # A: https://stats.stackexchange.com/a/353222
        self.mu = nn.Linear(e, z)
        self.logvar = nn.Linear(e, z)

    def encode(self, x):
        """Hook for reparameterizing the outs of the encoder"""
        h = self.encoder(x)
        mu = self.mu(h)
        std = torch.exp(0.5*self.logvar(h))
        eps = torch.randn_like(std)
        z = mu + std*eps
        return z, mu, std

    def get_xdist(self, z):
        """Hook for customising interpretation of decoder output"""
        return Bernoulli(logits=self.decoder(z))

    def get_pdist(self, z):
        """Hook to customize prior distribution"""
        return Normal(torch.zeros_like(z), torch.ones_like(z))

    def get_qdist(self, mu, std):
        """Hook to customize construction of qdist from mean and stddev"""
        return Normal(mu, std)

    def forward(self, x, dataset_size):
        """Calculates the Evidence Lower Bound (ELBO) of the VAE on x"""
        x_len = x.shape[0]
        z, mu, std = self.encode(x)

        # log(p(x))
        xdist = self.get_xdist(z)
        logpx = xdist.log_prob(x).view(x_len, -1).sum(1)

        # log(p(z))
        pdist = self.get_pdist(z)
        logpz = pdist.log_prob(z).view(x_len, -1).sum(1)

        # log(q(z|x))
        qdist = self.get_qdist(mu, std)
        logqz_condx = qdist.log_prob(z).view(x_len, -1).sum(1)

        # Calculate matrix of shape (x_len, x_len, z_dim) which contains the
        # log probability of each instance's latent variables under the
        # distributions of every instance latent vars in the batch
        qdist = qdist.expand((1, x_len, self.z_dim))
        qzmat = qdist.log_prob(z.view(x_len, 1, self.z_dim))

        # log(q(z)) via minibatch weighted sampling
        logmn = math.log(dataset_size * x_len)
        logqz = torch.logsumexp(qzmat.sum(2), dim=1) - logmn
        logqz_prodmarginals = (torch.logsumexp(qzmat, dim=1) - logmn).sum(1)

        # Calculate Modified ELBO:
        # Basic ELBO is just logpx + logpz - logqz_condx
        ix_code_mi = logqz_condx - logqz
        total_corr = self.beta * (logqz - logqz_prodmarginals)
        dimwise_kl = (1 - self.lamb) * (logqz_prodmarginals - logpz)
        modified_elbo = logpx - ix_code_mi - total_corr - dimwise_kl
        return modified_elbo

In [3]:
MNIST(os.getcwd(), train=True, download=True)

Dataset MNIST
    Number of datapoints: 60000
    Root location: /Users/stilljm/projects/johnstill/vae_experiments
    Split: Train

In [4]:
data = MNIST(os.getcwd(), transform=ToTensor())
loader = DataLoader(data, batch_size=128)
X, y = next(iter(loader))
X = X.view(X.shape[0], -1)
x_len = len(X)
X.shape, x_len

(torch.Size([128, 784]), 128)

In [5]:
e = 100
z = 20
encoder = nn.Sequential(nn.Linear(784, 400),
                        nn.ReLU(),
                        nn.Linear(400, 200),
                        nn.ReLU(),
                        nn.Linear(200, e),
                        nn.ReLU())
decoder = nn.Sequential(nn.Linear(z, e),
                        nn.ReLU(),
                        nn.Linear(e, 200),
                        nn.ReLU(),
                        nn.Linear(200, 400),
                        nn.ReLU(),
                        nn.Linear(400, 784))
vae = BetaTCVAE(encoder, decoder, e, z)
vae

BetaTCVAE(
  (encoder): Sequential(
    (0): Linear(in_features=784, out_features=400, bias=True)
    (1): ReLU()
    (2): Linear(in_features=400, out_features=200, bias=True)
    (3): ReLU()
    (4): Linear(in_features=200, out_features=100, bias=True)
    (5): ReLU()
  )
  (decoder): Sequential(
    (0): Linear(in_features=20, out_features=100, bias=True)
    (1): ReLU()
    (2): Linear(in_features=100, out_features=200, bias=True)
    (3): ReLU()
    (4): Linear(in_features=200, out_features=400, bias=True)
    (5): ReLU()
    (6): Linear(in_features=400, out_features=784, bias=True)
  )
  (mu): Linear(in_features=100, out_features=20, bias=True)
  (logvar): Linear(in_features=100, out_features=20, bias=True)
)

In [6]:
optimizer = optim.Adam(vae.parameters())

for j in range(10):
    print(f'Epoch {j}')
    for i, (X, y) in enumerate(loader):
        optimizer.zero_grad()

        x_len = X.shape[0]
        X = X.view(x_len, -1)

        z, mu, std = vae.encode(X)
        ez = torch.exp(z)
        assert not torch.isnan(ez).any()

        xdist = vae.get_xdist(ez)
        logpx = xdist.log_prob(X).sum(1)

        pdist = vae.get_pdist(ez)
        logpz = pdist.log_prob(ez).sum(1)

        qdist = vae.get_qdist(mu, std)
        logqz_condx = qdist.log_prob(z).sum(1)

        qdist = qdist.expand((1, x_len, vae.z_dim))
        qzmat = qdist.log_prob(ez.view(x_len, 1, vae.z_dim))

        logmn = math.log(len(data) * x_len)
        logqz = torch.logsumexp(qzmat.sum(2), dim=1) - logmn
        logqz_prodmarginals = (torch.logsumexp(qzmat, dim=1) - logmn).sum(1)

        ix_code_mi = logqz_condx - logqz
        total_corr = vae.beta * (logqz - logqz_prodmarginals)
        dimwise_kl = (1 - vae.lamb) * (logqz_prodmarginals - logpz)
        modified_elbo = logpx - ix_code_mi - total_corr - dimwise_kl

        loss = modified_elbo.mul(-1).mean()
        assert not torch.isnan(loss).any()
        assert not torch.isinf(loss).any()
        if i %100 == 0:
            print(f'Epoch {j} Batch {i} Loss: {loss.item()}')
        loss.backward()
        optimizer.step()

Epoch 0
Epoch 0 Batch 0 Loss: 600.1227416992188
Epoch 0 Batch 100 Loss: 225.73538208007812
Epoch 0 Batch 200 Loss: 191.56504821777344
Epoch 0 Batch 300 Loss: 171.66139221191406
Epoch 0 Batch 400 Loss: 228.4467010498047
Epoch 1
Epoch 1 Batch 0 Loss: 179.67726135253906
Epoch 1 Batch 100 Loss: 172.30633544921875
Epoch 1 Batch 200 Loss: 186.26449584960938
Epoch 1 Batch 300 Loss: 160.93212890625
Epoch 1 Batch 400 Loss: 186.82305908203125
Epoch 2
Epoch 2 Batch 0 Loss: 169.50994873046875
Epoch 2 Batch 100 Loss: 201.45382690429688
Epoch 2 Batch 200 Loss: 183.2902069091797
Epoch 2 Batch 300 Loss: 161.8167724609375
Epoch 2 Batch 400 Loss: 181.29129028320312
Epoch 3
Epoch 3 Batch 0 Loss: 173.8082733154297
Epoch 3 Batch 100 Loss: 172.75070190429688
Epoch 3 Batch 200 Loss: 181.92477416992188
Epoch 3 Batch 300 Loss: 160.8887939453125
Epoch 3 Batch 400 Loss: 186.173828125
Epoch 4
Epoch 4 Batch 0 Loss: 172.81275939941406


AssertionError: 

In [7]:
loss

tensor(nan, grad_fn=<MeanBackward0>)

In [9]:
torch.isnan(ez).any()

tensor(False)

In [15]:
torch.isinf(ez).any()

tensor(True)

In [16]:
ez[torch.isnan(logpx)]

tensor([[5.5890e-25, 1.5919e-08, 3.1325e-13, 4.5740e-10, 1.9590e-15, 2.1965e-25,
         2.9276e-25, 1.5055e-22, 9.7902e-04, 9.0113e-19, 3.2367e-26, 1.7909e-13,
         1.3453e-22, 3.5180e-22, 3.8236e-16,        inf, 8.3145e-19, 6.9134e-20,
         5.9167e-23, 2.0944e-24]], grad_fn=<IndexBackward>)

In [17]:
z[torch.isinf(ez)]

tensor([297.0267], grad_fn=<IndexBackward>)

In [18]:
torch.exp(z[torch.isinf(ez)])

tensor([inf], grad_fn=<ExpBackward>)

In [19]:
np.exp(297.0267)

9.932459158195629e+128

In [27]:
xdist.log_prob(X).sum(1)

tensor([-2.0584e+02, -2.2263e+02, -1.5050e+02, -2.2141e+02, -1.4010e+02,
        -1.7858e+02, -2.5226e+02, -2.0033e+02, -2.6512e+02, -2.0508e+02,
        -1.9433e+02, -1.8068e+02, -1.4199e+03, -1.3812e+02, -1.6304e+02,
        -2.2851e+02, -2.3150e+02, -1.6934e+02, -1.3970e+02, -2.4321e+02,
        -2.0734e+02, -2.2168e+02, -1.9343e+02, -2.0837e+02, -2.9674e+02,
        -1.7253e+02, -1.4250e+02, -1.9290e+02, -1.5223e+02, -2.5754e+02,
        -2.2191e+02, -2.2269e+02, -2.4998e+02, -1.4117e+02, -2.0471e+02,
        -5.0500e+07, -2.0817e+02, -1.3788e+02, -1.9334e+02, -2.4626e+02,
                nan, -2.2713e+02, -1.4206e+02, -1.9521e+02, -1.2639e+02,
        -2.4949e+02, -2.1696e+02, -1.5610e+02, -2.0589e+02, -1.5835e+02,
        -1.1693e+02, -1.7149e+02, -2.4095e+02, -2.2080e+02, -1.9652e+02,
        -1.1750e+02, -1.3911e+02, -1.6507e+02, -3.6388e+02, -1.8240e+02,
        -1.6396e+02, -2.2006e+02, -2.1524e+02, -2.5575e+02, -1.6065e+02,
        -1.4846e+02, -1.8503e+02, -1.8924e+02, -1.4

In [28]:
ez.dtype

torch.float32

In [31]:
inf_idx = torch.isinf(ez)
mu[inf_idx], std[inf_idx]

(tensor([-57.4084], grad_fn=<IndexBackward>),
 tensor([99.9527], grad_fn=<IndexBackward>))