# Disclaimer

We used elements of the code from:
https://github.com/yoonholee/pytorch-vae



# Import

In [0]:
import torch
from torch import nn
from torch.distributions.normal import Normal
from torch.distributions.multivariate_normal import MultivariateNormal
import numpy as np
import h5py
import torch
import torch.utils.data as data
from torchvision import transforms
from torchvision.utils import save_image
import os
import numpy as np
from PIL import Image
import urllib.request
import matplotlib.pyplot as plt
from scipy.special import logsumexp

import os
import sys
import numpy as np
import torch
from torch import optim

# Data loader

Taken from  https://raw.githubusercontent.com/yoonholee/pytorch-vae/master/data_loader/fixed_mnist.py

In [0]:
class fixedMNIST(data.Dataset):
    """ Binarized MNIST dataset, proposed in
    http://proceedings.mlr.press/v15/larochelle11a/larochelle11a.pdf

    Based on https://raw.githubusercontent.com/yoonholee/pytorch-vae/master/data_loader/fixed_mnist.py

    """
    train_file = 'binarized_mnist_train.amat'
    val_file = 'binarized_mnist_valid.amat'
    test_file = 'binarized_mnist_test.amat'

    def __init__(self, root, train=True, transform=None, download=False):
        # we ignore transform.
        self.root = os.path.expanduser(root)
        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.' + ' You can use download=True to download it')

        self.data = self._get_data(train=train)

    def __getitem__(self, index):
        img = self.data[index]
        img = Image.fromarray(img)
        img = transforms.ToTensor()(img).type(torch.FloatTensor)
        return img, torch.tensor(-1)  # Meaningless tensor instead of target

    def __len__(self):
        return len(self.data)

    def _get_data(self, train=True):
        with h5py.File(os.path.join(self.root, 'data.h5'), 'r') as hf:
            data = hf.get('train' if train else 'test')
            data = np.array(data)
        return data

    def get_mean_img(self):
        return self.data.mean(0).flatten()

    def download(self):
        if self._check_exists():
            return
        if not os.path.exists(self.root):
            os.makedirs(self.root)

        print('Downloading MNIST with fixed binarization...')
        for dataset in ['train', 'valid', 'test']:
            filename = 'binarized_mnist_{}.amat'.format(dataset)
            url = 'http://www.cs.toronto.edu/~larocheh/public/datasets/binarized_mnist/binarized_mnist_{}.amat'.format(dataset)
            print('Downloading from {}...'.format(url))
            local_filename = os.path.join(self.root, filename)
            urllib.request.urlretrieve(url, local_filename)
            print('Saved to {}'.format(local_filename))

        def filename_to_np(filename):
            with open(filename) as f:
                lines = f.readlines()
            return np.array([[int(i)for i in line.split()] for line in lines]).astype('int8')

        train_data = np.concatenate([filename_to_np(os.path.join(self.root, self.train_file)),
                                     filename_to_np(os.path.join(self.root, self.val_file))])
        test_data = filename_to_np(os.path.join(self.root, self.val_file))
        with h5py.File(os.path.join(self.root, 'data.h5'), 'w') as hf:
            hf.create_dataset('train', data=train_data.reshape(-1, 28, 28))
            hf.create_dataset('test', data=test_data.reshape(-1, 28, 28))
        print('Done!')

    def _check_exists(self):
        return os.path.exists(os.path.join(self.root, 'data.h5'))


def data_loaders():

    loader_fn, root = fixedMNIST, './dataset/fixedmnist'

    dataset_dir = r'/content/'
    root = dataset_dir
    kwargs = {}

    batch_size = 32
    test_batch_size = 32

    # Train set
    train_loader = torch.utils.data.DataLoader(
        loader_fn(root, train=True,
                  download=True,
                  transform=transforms.ToTensor()),
        batch_size=batch_size,
        shuffle=True,
        **kwargs)

    # Test set
    test_loader = torch.utils.data.DataLoader(
        loader_fn(root,
                  train=False,
                  download=True,
                  transform=transforms.ToTensor()),
        batch_size=test_batch_size,
        shuffle=False,
        **kwargs)

    return train_loader, test_loader


def show_input(loader):

    fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(7,7), squeeze=False)

    ax[0, 0].imshow(loader.dataset[11][0].reshape((28,28)),
               cmap='gray', interpolation='nearest')
    ax[0, 1].imshow(loader.dataset[1121][0].reshape((28,28)),
               cmap='gray', interpolation='nearest')
    ax[1, 0].imshow(loader.dataset[111][0].reshape((28,28)),
               cmap='gray', interpolation='nearest')
    ax[1, 1].imshow(loader.dataset[121][0].reshape((28,28)),
               cmap='gray', interpolation='nearest')

# VAE Class

In [0]:
class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), 16 * 8 * 8).contiguous()


class UnFlatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), 256, 1, 1).contiguous()


class ConvVAE(nn.Module):

    def __init__(self, device='cuda', z_dim=100, elbo=True):

        super().__init__()
        self.train_step = 0
        self.best_loss = np.inf
        self.elbo=elbo
        self.device = device
        # Prior on P(Z) => N(0, 1)
        self.prior = Normal(torch.zeros([z_dim], device = self.device ), \
                            torch.ones([z_dim], device = self.device ))
        
        self.proc_data = lambda x: x.to( self.device )

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=(3, 3)),
            nn.ELU(),
            nn.AvgPool2d(kernel_size=(2, 2), stride=2),
            nn.Conv2d(32, 64, kernel_size=(3, 3)),
            nn.ELU(),
            nn.AvgPool2d(kernel_size=(2, 2), stride=2),
            nn.Conv2d(64, 256, kernel_size=(5, 5)),
            nn.ELU()
        )

        # Distribution parameters
        self.enc_mu = nn.Linear(in_features=256, out_features=100)
        self.enc_log_var = nn.Linear(in_features=256, out_features=100)

        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(in_features=100, out_features=256),
            UnFlatten(),
            nn.ELU(),
            nn.Conv2d(256, 64, kernel_size=(5, 5), padding=(4, 4)),
            nn.ELU(),
            nn.UpsamplingBilinear2d(scale_factor=2),
            nn.Conv2d(64, 32, kernel_size=(3, 3), padding=(2, 2)),
            nn.ELU(),
            nn.UpsamplingBilinear2d(scale_factor=2),
            nn.Conv2d(32, 16, kernel_size=(3, 3), padding=(2, 2)),
            nn.ELU(),
            nn.Conv2d(16, 1, kernel_size=(3, 3), padding=(2, 2)))
        
    def encode(self, x):
        x = self.proc_data(x)
        # Encoder
        h = self.encoder(x)
        # Reshape before linear.
        h = h.view(32, 256)
        # Distribution parameters.
        mu = self.enc_mu(h)
        log_var = self.enc_log_var(h)

        return mu, log_var

    def decode(self, z):
        x_hat = self.decoder(z)
        return torch.sigmoid(x_hat)

    def forward(self, x):
        # Encode
        mu, log_var = self.encode(x)
        std = 1e-10 + torch.sqrt(torch.exp(log_var))
        # Simulate K random
        loss = 0
        e = torch.randn((32, 100)).to( self.device )
        z = mu + std * e
        # DECODE
        x_hat = self.decode(z)
        loss += self.loss(x_hat, x, mu, std)

        return loss, x_hat, mu, std
    
    def loss(self, x_hat, x, mu, std):
        """
        ELBO 
        """
        x = self.proc_data(x)
        recon_loss = nn.functional.binary_cross_entropy(x_hat,
                                                        x,
                                                        reduction='sum')
        kl_loss = 0.5 * torch.sum(-1. - torch.log(std**2) + mu ** 2 + std**2)
        loss = recon_loss + kl_loss
        return loss


# Utils

GPU?

In [4]:
print(torch.cuda.is_available())

True


Train

In [5]:
# Use the GPU if you have one
if torch.cuda.is_available():
    print("Using the GPU")
    device = torch.device("cuda") 
else:
    print("WARNING: You are about to run on cpu, and this will likely run out \
      of memory. \n You can try setting batch_size=1 to reduce memory usage")
    device = torch.device("cpu")


def train(epoch):
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):

        data = data.to(device)
        optimizer.zero_grad()
        loss, x_hat, mu, std = model(data)
        train_loss += loss.item()
        loss.backward()
        optimizer.step()
        model.train_step += 1
        if model.train_step % 100 == 0:
            print('Train Epoch: {} ({:.0f}%)\tELBO: {:.6f}'.format(
                epoch, 
                100. * batch_idx / len(train_loader), 
                loss.item() / 32))

    # Normalize per instance
    print('====> Epoch: {} - Train set - Per instance ELBO: {:.4f}'.format(
        epoch, train_loss / len(train_loader.dataset)))


def test(epoch):
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for batch_idx, (data, _) in enumerate(test_loader):
            if data.size(0) == 32: 
                data = data.to(device)
                loss, x_hat, mu, std = model(data)
                test_loss += loss.item()
                if batch_idx == 0:
                    n = min(data.size(0), 32)
                    comparison = torch.cat(
                        [data[:n], x_hat.view(32, 1, 28, 28)[:n]]
                    )
                    save_image(comparison.cpu(),
                               r'/content/' + str(epoch) + '.png',
                               nrow=n)
    # Normalize per instance.
    test_loss /= len(test_loader.dataset)                
    print('====> Epoch: {} - Test set - Per Instance ELBO: {:.4f}'.format(
        epoch, test_loss))


Using the GPU


# Run training

In [6]:
train_loader, test_loader = data_loaders()
n_epochs = 20

model = ConvVAE(device=device,
                z_dim=100,
                elbo=True)
model.to(device)

optimizer = optim.Adam(model.parameters(), lr=3*1e-4, eps=1e-4)
batch_size = 32

for epoch in range(1, n_epochs + 1):
        train(epoch)
        test(epoch)
        
torch.save(model.state_dict(), r'/content/model_simple')

Downloading MNIST with fixed binarization...
Downloading from http://www.cs.toronto.edu/~larocheh/public/datasets/binarized_mnist/binarized_mnist_train.amat...
Saved to /content/binarized_mnist_train.amat
Downloading from http://www.cs.toronto.edu/~larocheh/public/datasets/binarized_mnist/binarized_mnist_valid.amat...
Saved to /content/binarized_mnist_valid.amat
Downloading from http://www.cs.toronto.edu/~larocheh/public/datasets/binarized_mnist/binarized_mnist_test.amat...
Saved to /content/binarized_mnist_test.amat
Done!




Train Epoch: 1 (5%)	ELBO: 212.066666
Train Epoch: 1 (11%)	ELBO: 171.716995
Train Epoch: 1 (16%)	ELBO: 186.088882
Train Epoch: 1 (21%)	ELBO: 162.172287
Train Epoch: 1 (27%)	ELBO: 160.602661
Train Epoch: 1 (32%)	ELBO: 143.329346
Train Epoch: 1 (37%)	ELBO: 149.581604
Train Epoch: 1 (43%)	ELBO: 134.818954
Train Epoch: 1 (48%)	ELBO: 129.019791
Train Epoch: 1 (53%)	ELBO: 139.632126
Train Epoch: 1 (59%)	ELBO: 124.642929
Train Epoch: 1 (64%)	ELBO: 125.998009
Train Epoch: 1 (69%)	ELBO: 133.418945
Train Epoch: 1 (75%)	ELBO: 121.071045
Train Epoch: 1 (80%)	ELBO: 127.615189
Train Epoch: 1 (85%)	ELBO: 130.657379
Train Epoch: 1 (91%)	ELBO: 124.666245
Train Epoch: 1 (96%)	ELBO: 116.282753
====> Epoch: 1 - Train set - Per instance ELBO: 153.5802
====> Epoch: 1 - Test set - Per Instance ELBO: 118.2221
Train Epoch: 2 (1%)	ELBO: 112.939766
Train Epoch: 2 (7%)	ELBO: 117.157059
Train Epoch: 2 (12%)	ELBO: 117.544724
Train Epoch: 2 (17%)	ELBO: 106.003166
Train Epoch: 2 (23%)	ELBO: 115.502991
Train Epoch: 2 (

# Simulated output

In [7]:
for batch_idx, (data, _) in enumerate(test_loader):
    break
    
data.to(device)
mu, log_var = model.encode(data)
std = 1e-7 + torch.exp(log_var / 2)
tmp_list = [data]
for _ in range(10):
    e = torch.randn((32, 100))
    z = mu + std * e.to(device)
    # DECODE
    tmp_list.append(model.decode(z).cpu())
    
out = torch.cat(tmp_list)

save_image(out.cpu(), r'/content/simulated_output.png',  nrow=32)



# Log-likelihood estimation

In [0]:
def log_likelihood(loader, model, k):
    prior = Normal(torch.zeros([100], device=device), 
                   torch.ones([100], device=device))

    with torch.no_grad():

        cum_log_p = 0

        for batch_idx, (data, _) in enumerate(loader):

            if data.size(0) == 32:

                _, x_hat, mu, std = model(data)

                data.to(device)
                x_hat.to(device)
                mu.to(device)
                std.to(device)

                ExpSumLog = torch.FloatTensor(k, 32).to(device)

                for i in range(k):

                    e = torch.randn(32, 100).to(model.device)
                    z = mu + std * e # z \sim q(z|x)
                    x_hat = model.decode(z).to(device)

                    # prob of z according to the prior p(z)
                    log_p_z = prior.log_prob(z).sum(-1)  
                    # prob of z according to q(z|x)  
                    log_q_z_x = Normal(mu, std).log_prob(z).sum(-1)       
                    # prob of x|z
                    log_p_x_z = nn.functional.binary_cross_entropy(
                        x_hat.to(device),
                        data.to(device),
                        reduction='none'
                    ).sum([-1, -2, -3]) # prob of x|z

                    ExpSumLog[i, :] = -log_p_x_z + log_p_z - log_q_z_x

                MaxESL, _ = ExpSumLog.max(0)

                # vector log(p(x_1)), log(p(x_2)), ..., log(p(x_M))
                log_p = torch.log(torch.exp(ExpSumLog - MaxESL).sum(0)) \
                + MaxESL\
                - torch.log(torch.FloatTensor([k])).to(device)    

                cum_log_p += log_p.sum()     

        print('====> Log-likelihood: {:.4f}'.format(
            cum_log_p / len(loader.dataset)))


In [9]:
model.eval()
k = 200
print("Training set log-likelihood:")
log_likelihood(train_loader, model, k)
print("Test set log-likelihood")
log_likelihood(test_loader, model, k)


Training set log-likelihood:




====> Log-likelihood: -86.8304
Test set log-likelihood
====> Log-likelihood: -86.7261
