In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import sampler

from torchvision import datasets
import torchvision.transforms as T
import torch.nn.functional as F

import torch.distributions as distrib
import torch.distributions.transforms as transform

import numpy as np
import pandas as pd
import os
import glob
import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import sklearn

# Machine epsilon for Python floats (float64). Not referenced by any code visible in this notebook.
eps = np.finfo(float).eps

# Default size (inches) for every matplotlib figure
plt.rcParams['figure.figsize'] = 10, 10
# Render figures inline in the notebook
%matplotlib inline


# Automatically reload edited modules before running a cell (extension must be loaded first)
%load_ext autoreload
%autoreload 2

In [None]:
# Base class for learnable flows. It inherits from torch's `Transform` (so flows
# can be composed with `ComposeTransform` / `TransformedDistribution`) and from
# `nn.Module` (so their parameters are registered and trained).
class Flow(transform.Transform, nn.Module):
    """Learnable bijection usable both as a torch Transform and as an nn.Module."""

    # Flows act on vectors (event_dim == 1). Recent torch versions read
    # `domain` / `codomain` when building a TransformedDistribution, and the
    # base Transform only annotates them, so subclasses must provide them.
    domain = torch.distributions.constraints.real_vector
    codomain = torch.distributions.constraints.real_vector

    def __init__(self):
        transform.Transform.__init__(self)
        nn.Module.__init__(self)

    def init_parameters(self):
        """Initialise every registered parameter uniformly in [-0.01, 0.01]."""
        # In-place init under no_grad instead of mutating `.data`.
        with torch.no_grad():
            for param in self.parameters():
                param.uniform_(-0.01, 0.01)

    def __hash__(self):
        # Transform defines __eq__, which makes Python set its __hash__ to None;
        # restore identity hashing so nn.Module can keep this module in sets/dicts.
        return nn.Module.__hash__(self)
    
    
class PlanarFlow(Flow):
    """Planar flow f(z) = z + u * tanh(w^T z + b).

    `weight` holds w (1, dim), `scale` holds u (1, dim) and `bias` holds b (1,).
    All three are initialised by `init_parameters`.
    """

    def __init__(self, dim):
        super(PlanarFlow, self).__init__()
        self.weight = nn.Parameter(torch.Tensor(1, dim))
        self.scale = nn.Parameter(torch.Tensor(1, dim))
        self.bias = nn.Parameter(torch.Tensor(1))
        self.init_parameters()

    def _call(self, z):
        # f_z = w^T z + b, one value per row of z
        f_z = F.linear(z, self.weight, self.bias)
        return z + self.scale * torch.tanh(f_z)

    def log_abs_det_jacobian(self, z, y=None):
        """Return log|det df/dz| = log|1 + u^T psi(z)|, psi(z) = (1 - tanh^2(w^T z + b)) w.

        Args:
            z: input of the flow; must be 2-D (batch, dim) because of `torch.mm`.
            y: ignored. It is accepted so the method matches the torch Transform
                signature `log_abs_det_jacobian(x, y)` used by ComposeTransform;
                calling it with `z` only still works.

        Returns:
            Tensor of shape (batch, 1).
        """
        f_z = F.linear(z, self.weight, self.bias)
        psi = (1 - torch.tanh(f_z) ** 2) * self.weight
        det_grad = 1 + torch.mm(psi, self.scale.t())
        # NOTE(review): u is not constrained so that w^T u >= -1, so the map is not
        # guaranteed to be invertible; the 1e-9 only guards against log(0).
        return torch.log(det_grad.abs() + 1e-9)
    
    
# Main class for normalizing flow
class NormalizingFlow(nn.Module):
    """Chain of flows: the `blocks` sequence repeated `flow_length` times.

    Args:
        dim: dimensionality passed to every flow constructor.
        blocks: sequence of flow classes, each instantiated as `cls(dim)`.
        flow_length: how many times the `blocks` sequence is repeated.
        density: base distribution that the composed flows transform.
    """

    def __init__(self, dim, blocks, flow_length, density):
        super().__init__()
        # Outer loop over repetitions, inner loop over block types.
        flows = [make_flow(dim) for _ in range(flow_length) for make_flow in blocks]
        self.transforms = transform.ComposeTransform(flows)
        self.bijectors = nn.ModuleList(flows)
        self.base_density = density
        self.final_density = distrib.TransformedDistribution(density, self.transforms)
        self.log_det = []

    def forward(self, z):
        """Apply every flow in order.

        Returns:
            (z_K, log_dets), where log_dets[k] is the k-th flow's
            log_abs_det_jacobian evaluated at that flow's input. The same list
            is stored on `self.log_det`.
        """
        self.log_det = log_dets = []
        for flow in self.bijectors:
            log_dets.append(flow.log_abs_det_jacobian(z))
            z = flow(z)
        return z, self.log_det
    

In [None]:
class VAE(nn.Module):
    """Fully-connected variational auto-encoder with a diagonal Gaussian posterior.

    The encoder maps x to (mu_z, sigma_z), where sigma_z is a standard deviation
    (Softplus, then clamped to [1e-4, 5]). z is drawn with the
    reparameterisation trick and decoded to `input_dim * n_classes` outputs.

    Args:
        input_dim: number of input features (e.g. flattened pixels).
        n_classes: 1 for a Bernoulli decoder (Sigmoid outputs in (0, 1));
            > 1 for a categorical decoder that outputs raw logits.
        encoder_dim: hidden width of the encoder MLP.
        decoder_dim: hidden width of the decoder MLP.
        latent_dim: dimensionality of z.
    """

    def __init__(self, input_dim, n_classes, encoder_dim, decoder_dim, latent_dim):
        super().__init__()
        self.input_dim = input_dim
        self.n_classes = n_classes
        self.encoder_dim = encoder_dim
        self.decoder_dim = decoder_dim
        self.latent_dim = latent_dim

        # Encoder: input -> hidden features shared by the mu / sigma heads
        self.encoder = nn.Sequential(
            nn.Linear(self.input_dim, self.encoder_dim),
            nn.ReLU(True),
            nn.Linear(self.encoder_dim, self.encoder_dim),
            nn.ReLU(True),
            nn.Linear(self.encoder_dim, self.encoder_dim),
        )

        # Decoder: z -> input_dim * n_classes outputs. Only the Bernoulli case is
        # squashed to probabilities; the categorical loss applies log_softmax
        # itself and needs unbounded logits.
        decoder_layers = [
            nn.Linear(self.latent_dim, self.decoder_dim),
            nn.ReLU(True),
            nn.Linear(self.decoder_dim, self.decoder_dim),
            nn.ReLU(True),
            nn.Linear(self.decoder_dim, self.input_dim * n_classes),
        ]
        if n_classes == 1:
            decoder_layers.append(nn.Sigmoid())
        self.decoder = nn.Sequential(*decoder_layers)

        # Posterior heads: mean, and a strictly positive standard deviation
        self.mu = nn.Linear(self.encoder_dim, self.latent_dim)
        self.sigma = nn.Sequential(
            nn.Linear(self.encoder_dim, self.latent_dim),
            nn.Softplus(),
            nn.Hardtanh(min_val=1e-4, max_val=5.),
        )

    def encode(self, x):
        """Return (mu_z, sigma_z) of q(z|x); sigma_z is a standard deviation."""
        hidden = self.encoder(x)
        return self.mu(hidden), self.sigma(hidden)

    def decode(self, z):
        """Map latent codes to decoder outputs (probabilities or logits)."""
        return self.decoder(z)

    def reparameterize(self, x, mu_z, sigma_z):
        """Sample z = mu + sigma * noise and compute KL(q(z|x) || N(0, I)).

        Args:
            x: input batch; only its first dimension (batch size) is used.
            mu_z: posterior means.
            sigma_z: posterior standard deviations (must be > 0).

        Returns:
            (z, kl_div), where kl_div is the closed-form KL summed over latent
            dimensions and averaged over the batch.
        """
        noise = torch.randn_like(sigma_z)
        z = noise.mul(sigma_z) + mu_z

        # sigma_z is a standard deviation, so log(sigma^2) = 2 * log(sigma):
        # KL = -0.5 * sum(1 + log sigma^2 - mu^2 - sigma^2)
        kl_div = -0.5 * torch.sum(1 + 2 * torch.log(sigma_z) - mu_z.pow(2) - sigma_z.pow(2))
        return z, kl_div / x.size(0)

    def forward(self, x):
        """Return (reconstruction, batch-averaged KL divergence)."""
        mu_z, sigma_z = self.encode(x)
        z_hat, kl_div = self.reparameterize(x, mu_z, sigma_z)
        x_hat = self.decode(z_hat)
        return x_hat, kl_div

In [None]:
def binary_loss(x_hat, x):
    """Bernoulli reconstruction loss: binary cross-entropy summed over all elements."""
    total_bce = F.binary_cross_entropy(input=x_hat, target=x, reduction='sum')
    return total_bce

def multinomial_loss(x_logit, x, num_classes=None):
    """Categorical reconstruction loss for inputs quantised into `num_classes` levels.

    Args:
        x_logit: decoder logits with `num_classes` values per element of x,
            laid out as (batch, num_classes * prod(x.shape[1:])).
        x: targets in [0, 1] with shape (batch, ...), either image-shaped
            (batch, C, H, W) or already flattened (batch, D).
        num_classes: number of quantisation levels. When omitted it is inferred
            from the element counts of `x_logit` and `x`.

    Returns:
        Negative log-likelihood summed over the batch dimension (shape
        x.shape[1:]), multiplied by 100 as in the original implementation.
    """
    batch_size = x.shape[0]
    if num_classes is None:
        num_classes = x_logit.numel() // x.numel()
    # Put the class axis at dim 1, as nll_loss expects: (batch, classes, *x.shape[1:])
    log_probs = F.log_softmax(x_logit.reshape(batch_size, num_classes, *x.shape[1:]), dim=1)
    # Round (not truncate) so that e.g. k/255 * 255 maps to class k despite float error
    target = (x * (num_classes - 1)).round().long().clamp_(0, num_classes - 1)
    # Cross entropy computed separately for every element
    ce = F.nll_loss(log_probs, target, weight=None, reduction='none')
    return ce.sum(dim=0) * 100

def reconstruction_loss(x_tilde, x, num_classes=1, average=True):
    """Reconstruction term of the VAE objective.

    num_classes == 1 uses `binary_loss` against x flattened to (batch, -1);
    any other value delegates to `multinomial_loss` with x as given.
    When `average` is true, the loss is summed and divided by the batch size.
    """
    if num_classes != 1:
        loss = multinomial_loss(x_tilde, x)
    else:
        flat_x = x.view(x.size(0), -1)
        loss = binary_loss(x_tilde, flat_x)
    if not average:
        return loss
    return loss.sum() / x.size(0)

#### Dataset

In [None]:
# Select the compute device once; later cells read use_cuda / dtype / device.
use_cuda = torch.cuda.is_available()
if use_cuda:
    dtype = torch.cuda.FloatTensor
    device = torch.device("cuda:0")
    print('GPU')
else:
    dtype = torch.FloatTensor
    # Device type strings are lowercase; torch.device("CPU") raises.
    device = torch.device("cpu")


def to_cuda(tensor):
    """Move `tensor` to the selected `device`; non-tensors are returned unchanged.

    On a CPU-only machine this leaves tensors on the CPU instead of raising.
    """
    if isinstance(tensor, torch.Tensor):
        tensor = tensor.to(device)
    return tensor

In [None]:
# Extra DataLoader options (worker processes, pinned memory) are only used with CUDA.
kwargs = dict(num_workers=4, pin_memory=True) if use_cuda else {}
batch_size = 64
tens_t = T.ToTensor()

# FashionMNIST splits; only the training split is created with download=True.
train_dset = datasets.FashionMNIST('data', train=True, download=True, transform=tens_t)
test_dset = datasets.FashionMNIST('data', train=False, transform=tens_t)

# Both loaders shuffle and share the same batch size and options.
train_loader, test_loader = (
    torch.utils.data.DataLoader(dset, batch_size=batch_size, shuffle=True, **kwargs)
    for dset in (train_dset, test_dset)
)

In [None]:
def plot_batch(batch, nslices=8):
    """Tile a batch of images into one grid and draw it with `plt.imshow`.

    Args:
        batch: tensor indexed as (batch, channels, height, width); channels
            are summed so each tile is a single 2-D image.
        nslices: tiles per row, and the minimum number of rows.

    Returns:
        A 1-tuple holding the image returned by `plt.imshow`. The trailing comma
        is kept for compatibility; presumably it lets results be collected as
        animation frames.
    """
    n_items, height, width = batch.shape[0], batch.shape[2], batch.shape[3]
    # Add rows when the batch exceeds nslices**2 images, so the canvas no longer
    # overflows; smaller batches keep the nslices x nslices layout.
    n_rows = max(nslices, -(-n_items // nslices))
    # One-pixel gap between tiles
    img = np.zeros(((height + 1) * n_rows, (width + 1) * nslices))
    for b in range(n_items):
        row, col = divmod(b, nslices)
        r_p = row * (height + 1)
        c_p = col * (width + 1)
        img[r_p:(r_p + height), c_p:(c_p + width)] = torch.sum(batch[b], 0)
    im = plt.imshow(img, cmap='Greys', interpolation='nearest'),
    return im

# Take one test batch for visualisation. test_loader shuffles, so the batch
# changes between runs unless a random seed is set.
fixed_batch, fixed_targets = next(iter(test_loader))
plt.figure(figsize=(10, 10))
plot_batch(fixed_batch);

### Training

In [None]:
def train_vae(model, optimizer, scheduler, train_loader, model_name='basic', epochs=50, plot_it=1, flatten=True, use_cuda=True, num_classes=None):
    """Train a VAE with a linear KL warm-up and return per-epoch mean losses.

    Args:
        model: module whose forward returns (x_tilde, kl_div).
        optimizer: stepped once per batch.
        scheduler: learning-rate scheduler, also stepped once per batch.
        train_loader: yields (x, label) batches; labels are ignored.
        model_name: unused; kept for interface compatibility.
        epochs: number of passes over `train_loader`.
        plot_it: unused; kept for interface compatibility.
        flatten: reshape (B, C, H, W) inputs to (B, H*W) before the forward pass.
        use_cuda: move batches to the GPU; ignored when CUDA is unavailable.
        num_classes: forwarded to `reconstruction_loss`. Defaults to
            `model.n_classes`, or 1 when the model has no such attribute.

    Returns:
        Tensor of shape (epochs, 2): column 0 is the mean reconstruction loss
        per batch, column 1 the mean KL term per batch.
    """
    use_cuda = use_cuda and torch.cuda.is_available()
    if num_classes is None:
        num_classes = getattr(model, 'n_classes', 1)
    losses = torch.zeros(epochs, 2)
    model.train()
    for it in range(epochs):
        # KL warm-up: beta grows linearly from 0 (first epoch) towards 1
        beta = it / float(epochs)
        n_batch = 0
        for x, _ in train_loader:
            x = x.cuda() if use_cuda else x
            if flatten:
                x = x.view(-1, x.size(2) * x.size(3))
            x_tilde, loss_latent = model(x)
            loss_recons = reconstruction_loss(x_tilde, x, num_classes)
            loss = loss_recons + beta * loss_latent
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Learning-rate decay is applied per batch, not per epoch
            scheduler.step()
            losses[it, 0] += loss_recons.item()
            losses[it, 1] += loss_latent.item()
            n_batch += 1
        # Guard against an empty loader (avoids 0 / 0)
        losses[it, :] /= max(n_batch, 1)
    return losses

### Model definition and settings

In [None]:
# Observation model: 1 -> Bernoulli (binary cross-entropy), > 1 -> categorical loss
num_classes = 1
# Hidden-layer width and latent dimensionality
n_hidden = 512
n_latent = 2

# Input dimensionality: pixels per image (height * width of the fixed batch)
nin = fixed_batch.shape[-2] * fixed_batch.shape[-1]

model = VAE(
    input_dim=nin,
    n_classes=num_classes,
    encoder_dim=n_hidden,
    decoder_dim=n_hidden,
    latent_dim=n_latent,
).type(dtype)
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)
# Exponential learning-rate decay; gamma is applied at every scheduler.step()
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99995)

In [None]:
# Launch the optimisation. The device flag is passed explicitly because
# train_vae defaults to use_cuda=True, which fails on CPU-only machines.
losses_kld = train_vae(model, optimizer, scheduler, train_loader, model_name='basic', epochs=100, use_cuda=use_cuda)



### Negative log-likelihood (NLL) evaluation of the generative model

In [None]:
import scipy.special as ss

def evaluate_nll_bpd(data_loader, model, batch_size = 500, R = 5, use_cuda=True):
    """Importance-sampled negative log-likelihood of a Bernoulli-decoder VAE.

    For every data point x, K = batch_size * R latent samples z_k ~ q(z|x) are
    drawn (R forward passes of `batch_size` samples), and

        log p(x) ~= logsumexp_k [log p(x|z_k) + log p(z_k) - log q(z_k|x)] - log K

    with p(z) = N(0, I) and q(z|x) = N(mu_z, sigma_z^2) taken from `model.encode`.

    Args:
        data_loader: yields (x, label) batches; labels are ignored.
        model: exposes encode(x) -> (mu_z, sigma_z) with sigma_z a standard
            deviation, and decode(z) -> per-element probabilities with as many
            elements per row as flattened x (i.e. a Bernoulli decoder).
        batch_size: importance samples evaluated per forward pass.
        R: number of forward passes per data point.
        use_cuda: move data to the GPU; ignored when CUDA is unavailable.

    Returns:
        (nll, bpd): mean NLL in nats over all data points, and NLL / (D * log 2)
        with D the number of input dimensions (mainly meaningful for discrete data).
    """
    use_cuda = use_cuda and torch.cuda.is_available()
    was_training = model.training
    model.eval()
    log_2pi = np.log(2. * np.pi)
    log_likelihoods = []
    n_dims = 1
    try:
        with torch.no_grad():
            for x, _ in data_loader:
                x = x.cuda() if use_cuda else x
                x = x.reshape(x.size(0), -1)
                n_dims = x.size(1)
                for x_j in x:
                    # One data point repeated; each row carries an independent z sample
                    x_rep = x_j.unsqueeze(0).expand(batch_size, -1).contiguous()
                    log_w = []
                    for _ in range(R):
                        mu_z, sigma_z = model.encode(x_rep)
                        z = mu_z + sigma_z * torch.randn_like(sigma_z)
                        x_hat = model.decode(z)
                        log_px_z = -F.binary_cross_entropy(x_hat, x_rep, reduction='none').sum(-1)
                        log_pz = (-0.5 * (z.pow(2) + log_2pi)).sum(-1)
                        log_qz_x = (-0.5 * ((z - mu_z) / sigma_z).pow(2)
                                    - torch.log(sigma_z) - 0.5 * log_2pi).sum(-1)
                        log_w.append(log_px_z + log_pz - log_qz_x)
                    log_w = torch.cat(log_w)
                    log_px = torch.logsumexp(log_w, dim=0) - np.log(log_w.numel())
                    log_likelihoods.append(log_px.item())
    finally:
        # Restore the caller's train/eval mode
        model.train(was_training)
    nll = -float(np.mean(log_likelihoods))
    bpd = nll / (n_dims * np.log(2.))
    return nll, bpd

In [None]:
# Plot the per-epoch mean reconstruction loss (column 0 of train_vae's output)
plt.figure()
plt.plot(losses_kld[:, 0].numpy())
plt.xlabel('Epoch')
plt.ylabel('Reconstruction loss')
plt.title('Training reconstruction loss');
# Evaluate log-likelihood and bits per dim. The device flag is passed explicitly
# because evaluate_nll_bpd defaults to use_cuda=True.
nll, _ = evaluate_nll_bpd(test_loader, model, use_cuda=use_cuda)
print('Negative Log-Likelihood : ' + str(nll))