In [1]:
from collections import namedtuple
import itertools

In [24]:
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torch.autograd import Variable
from torchvision import datasets, transforms
from torchvision.utils import make_grid
from torch.distributions import multivariate_normal

In [3]:
from typing import Tuple, List

In [4]:
import numpy as np
import matplotlib.pyplot as plt

The inference algorithm:
    
Parameters: $\phi$ variational, $\theta$ generative
    
while not converged do:
        
$\quad \text{x = {Get mini batch}}\\
\quad z_0 ~ q_0(\cdot | x)\\
\quad z_K = f_K( ... f_1(z_0) )\\
\quad F(x) \approx F(x, z_K)\\
\quad \Delta \theta \propto - \nabla_\theta F(x)\\
\quad \Delta \phi \propto - \nabla_\phi F(x)$

In [5]:
class Transformation:
    
    def __init__(self):
        self.training = None
        self.log_det = None
        
    @property
    def training(self):
        return self._training
    
    @training.setter
    def training(self, enable:bool):
        if not enable:
            self.log_det = None
        self._training = enable
    
    def forward(self, zi, params):
        if self.training:
            self.log_det = torch.log( self.det( zi, params ).squeeze() + 1e-7 )
        return self.transform( zi, params )
    
    def get_num_params(self):
        return 0
    
    
class PlanarTransformation(Transformation):
    
    def __init__(self, dim:int, u:list=None, w:list=None, b:list=None, training:bool=True):
        
        self.dim = dim
        self.h = nn.Tanh()
        self.training = training
        
    def get_num_params(self):
        return self.dim * 2 + 1
    
    def transform(self, z, params):
#         if torch.dot(self.w.data.squeeze(), self.u.data.squeeze()) < -1:
#             print( "adjusting u")
#             dotwu = torch.dot( self.w.data.squeeze(), self.u.data.squeeze() )
#             self.u.data = self.u + ( -1 + torch.log( 1 + torch.exp( dotwu ) ) - dotwu ) \
#                             * self.w.data / torch.sqrt( torch.dot( self.w.data.squeeze(), self.w.data.squeeze() ) )
        return z + params[self.dim:-1].unsqueeze(0) * self.h( F.linear(z, params[:self.dim].unsqueeze(0), params[-1]) )
    
    def h_deriv(self, x):
        ff = self.h( x )
        return 1 - ff * ff
    
    def psi(self, z, params):
        return self.h_deriv( F.linear(z, params[:self.dim].unsqueeze(0), params[-1]) ) * params[:self.dim].unsqueeze(0)
    
    def det(self, z, params):
        return ( 1 + torch.mm( self.psi(z, params), params[self.dim:-1].unsqueeze(0).t() ) ).abs()


class RadialTransformation(Transformation):
    
    def __init__(self, dim:int, z0=None, alpha=None, beta=None, training:bool=True):
        
        self.dim = dim
        self.training = training
        
    def get_num_params(self):
        return self.dim + 2
        
    def transform(self, z, params):
#         if self.beta < -self.alpha:
#             print( "adjusting beta")
#             self.beta.data = -self.alpha + torch.log( 1 + torch.exp( self.beta ) )
        r = torch.norm( ( z - params[:self.dim].unsqueeze(0) ), p=2, dim=1, keepdim=True )
        return z + params[-1] * ( self.h( r, params[-2] ) * (z - params[:self.dim].unsqueeze(0)) )
    
    def h(self, r, alpha):
        return 1 / (alpha + r)
    
    def h_deriv(self, r, alpha):
        ff = self.h( r, alpha )
        return - ff * ff
    
    def det(self, z, params):
        r = torch.norm( ( z - params[:self.dim].unsqueeze(0) ), p=2, dim=1, keepdim=True )
        tmp = 1 + params[-1] * self.h( r, params[-2] )
        return torch.clamp(tmp.pow(self.dim - 1) * (tmp + params[-1] * self.h_deriv(r, params[-2]) * r), min=1e-7)
    
class NormalizingFlow:
    
    def __init__( self, transformation, dim:int, K:int, transformations=None ):
        self.K = K
        self.dim = dim
        
        if transformations is None:
            transformations = [ transformation( dim ) for i in range( K ) ]
        self.flow = transformations
        self.nParams = self.flow[0].get_num_params()
        
    def get_last_log_det(self):
        return self.flow[-1].log_det
    
    def get_sum_log_det(self):
        ret = 0
        for trans in self.flow:
            ret += trans.log_det
        return ret
        
    def forward( self, z, params ):
        for i, transf in enumerate( self.flow ):
            z = transf.forward(z, params[i])
        return z

In [13]:
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=100, shuffle=False)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=False, transform=transforms.ToTensor()),
    batch_size=100, shuffle=False)

In [14]:
# a large batch will be used to compute average flow parameters after training
large_batch = []
for i, (data, _) in enumerate(train_loader):
    if i > 1000:
        break
    large_batch.append(data)
large_batch = torch.cat(large_batch)

In [21]:
class VariationalAutoencoderNormalizingFlow(nn.Module):
    def __init__(self, flow_transform, flow_latent, flow_len):
        super(VariationalAutoencoderNormalizingFlow, self).__init__()

        self.fc1 = nn.Linear(784, 400)
        self.fc21 = nn.Linear(400, 20)
        self.fc22 = nn.Linear(400, 20)
        self.flow = NormalizingFlow(flow_transform, flow_latent, flow_len )
        self.fc23 = nn.Linear(400, self.flow.nParams * flow_len )
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 784)

    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1), self.fc23(h1).mean(dim=0).chunk(self.flow.K, dim=0)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        mu, logvar, params = self.encode(x.view(-1, 784))
        z = self.reparameterize(mu, logvar)
        z = self.flow.forward(z, params)
        return self.decode(z), mu, logvar

In [31]:
def plot_image(img):
    plt.figure()
    plt.imshow(make_grid(img.data.numpy()))
    plt.show()
    plt.close()

In [38]:
device = "cpu"
model = VariationalAutoencoderNormalizingFlow(RadialTransformation, 20, 32)

optimizer = optim.Adam(model.parameters(), lr=1e-3)


# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function_VAENF(recon_x, x, mu, logvar, sum_log_det):
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')

    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    
    KLD = KLD / x.size(0) - sum_log_det.mean()  # mean over batch

    return BCE + KLD

def train_VAENF(epoch):
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        sum_log_det = model.flow.get_sum_log_det()
        loss = loss_function_VAENF(recon_batch, data, mu, logvar, sum_log_det)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % 1000 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item() / len(data)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / len(train_loader.dataset)))


def test_VAENF(epoch):
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i, (data, _) in enumerate(test_loader):
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            sum_log_det = model.flow.get_sum_log_det()
            test_loss += loss_function_VAENF(recon_batch, data, mu, logvar, sum_log_det).item()
            if i == 0:
                n = min(data.size(0), 8)
                comparison = torch.cat([data[:n],
                                      recon_batch.view(100, 1, 28, 28)[:n]])
                save_image(comparison.cpu(),
                         '../results/reconstruction_VAENF_' + str(model.flow.K) + "_" + str(epoch) + '.png', nrow=n)

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))

In [39]:
for epoch in range(1, 10 + 1):
    train_VAENF(epoch)
    test_VAENF(epoch)
    with torch.no_grad():
        _ , _, params = model.encode(large_batch.view(-1, 784))
        sample = torch.randn(64, 20).to(device)
        sample = model.decode( model.flow.forward( sample, params ) ).cpu()
        save_image(sample.view(64, 1, 28, 28),
        '../results/sample_VAENF_' + str(model.flow.K) + "_" + str(epoch) + '.png')

====> Epoch: 1 Average loss: 137.2828
====> Test set loss: 102.2544
====> Epoch: 2 Average loss: 85.3859
====> Test set loss: 86.9399
====> Epoch: 3 Average loss: 78.5412
====> Test set loss: 80.7491
====> Epoch: 4 Average loss: 75.4289
====> Test set loss: 77.5757
====> Epoch: 5 Average loss: 73.4791
====> Test set loss: 75.1793
====> Epoch: 6 Average loss: 72.1626
====> Test set loss: 73.4380
====> Epoch: 7 Average loss: 71.1915
====> Test set loss: 72.3078
====> Epoch: 8 Average loss: 70.5028
====> Test set loss: 71.6908
====> Epoch: 9 Average loss: 69.9461
====> Test set loss: 71.4568
====> Epoch: 10 Average loss: 69.4464
====> Test set loss: 70.9226


In [18]:
class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()

        self.fc1 = nn.Linear(784, 400)
        self.fc21 = nn.Linear(400, 20)
        self.fc22 = nn.Linear(400, 20)
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 784)

    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, 784))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


model = VAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)


# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(recon_x, x, mu, logvar):
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')

    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return BCE + KLD


def train(epoch):
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % 1000 == 1:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item() / len(data)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / len(train_loader.dataset)))


def test(epoch):
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i, (data, _) in enumerate(test_loader):
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            test_loss += loss_function(recon_batch, data, mu, logvar).item()
            if i == 0:
                n = min(data.size(0), 8)
                comparison = torch.cat([data[:n],
                                      recon_batch.view(100, 1, 28, 28)[:n]])
                save_image(comparison.cpu(),
                         '../results/reconstruction_VAE_' + str(epoch) + '.png', nrow=n)

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))

for epoch in range(1, 10 + 1):
    train(epoch)
    test(epoch)
    with torch.no_grad():
        sample = torch.randn(64, 20).to(device)
        sample = model.decode(sample).cpu()
        save_image(sample.view(64, 1, 28, 28),
        '../results/sample_VAE_' + str(epoch) + '.png')

====> Epoch: 1 Average loss: 163.7044
====> Test set loss: 129.3005
====> Epoch: 2 Average loss: 122.1876
====> Test set loss: 117.5843
====> Epoch: 3 Average loss: 114.7716
====> Test set loss: 112.3378
====> Epoch: 4 Average loss: 111.6862
====> Test set loss: 110.3866
====> Epoch: 5 Average loss: 109.8727
====> Test set loss: 108.9313
====> Epoch: 6 Average loss: 108.6934
====> Test set loss: 108.6331
====> Epoch: 7 Average loss: 107.8680
====> Test set loss: 107.4852
====> Epoch: 8 Average loss: 107.2541
====> Test set loss: 107.0832
====> Epoch: 9 Average loss: 106.7469
====> Test set loss: 107.1155
====> Epoch: 10 Average loss: 106.3030
====> Test set loss: 106.2739
