<a href="https://colab.research.google.com/github/nschuc/normalizing-flows/blob/master/Normalizing_Flows_VAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Variational Inference with Normalizing Flows

## Types of Generative Models


1.   **Generative Aversarial Networks**:  Generator and Discriminator, discriminator learns to distinguish the real data from the fake samples that are produced by the generator model. 
2.   ** Variational Autoencoders**: VAE inexplicitly optimizes the log-likelyhood of the data by maximizing the evidence lower bound (ELBO)
3. ** Flow-based** generative models: are constructed by a sequence of invertible transformations. Unlike GANs and VAEs the model explicitly learns the true data distribution $p(\mathbf x)$ and the loss function is simply the negative log-likelyhood.


Stolen from From [Lilian Weng's](https://lilianweng.github.io/lil-log/2018/10/13/flow-based-deep-generative-models.html#jacobian-matrix-and-determinant) blog:


![alt text](https://lilianweng.github.io/lil-log/assets/images/three-generative-models.png)

In [0]:
%matplotlib inline
import numpy as np

In [0]:
! pip3 install torch torchvision
! pip3 install pillow



In [0]:
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image

##  Normalizing Flows


The basic rule for transformation of densities consideres an invertible, smooth mapping $f: \mathbb{R}^D \rightarrow  \mathbb{R}^D$ with an inverse $f^{-1}=g$, such that  $ g \circ f (\textbf{z}) = \textbf{z}$.  If we use this mapping to transform a random variable $\mathbf{z}$ with distribution $q(\mathbf{z})$, then the resulting random variable $\mathbf{z}^\prime = f(\mathbf{z}$) has a distribution:

$$
q(\mathbf{z}^{\prime}) = q(\mathbf{z}) 
  \left| 
    \det \frac{\partial f^{-1}}{\partial \mathbf{z}^\prime}
  \right| = 
   q(\mathbf{z}) 
  \left| 
    \det \frac{\partial f}{\partial \mathbf{z}}
  \right|^{-1},
$$
where the last equality can be obtained by applying [the inverse function theorem](https://en.wikipedia.org/wiki/Inverse_function_theorem) and taking advantage of the property of Jacobians of intertible functions.

The density $q_K(\mathbf z)$ obtained by successively transforming a random variable $\mathbf z_0$ with distribution
$q_0$ through a chain of $K$ transformations $f_k$ is:

\begin{align}
  \mathbf z_K &= f_K \circ \ldots \circ f_1( \mathbf z_0), \\\
  \ln q_K (\mathbf z_K) &= \ln q_0(\mathbf z_0) - \sum_{k=1}^{K} \ln \det \frac{\partial f_k}{\partial \mathbf{z}_{k-1}}.
\end{align}

The formalism of normalizing flows now gives us a systematic
way of specifying the approximate posterior distributions
$q(\mathbf z| \mathbf x)$ required for variational inference. With an
appropriate choice of transformations $f_K$, we can initially
use simple factorized distributions such as an independent
Gaussian, and apply normalizing flows of different lengths
to obtain increasingly complex and multi-modal distributions.



From [Lilian Weng's  blog](https://lilianweng.github.io/lil-log/2018/10/13/flow-based-deep-generative-models.html#jacobian-matrix-and-determinant) (she uses $p_i$ instead of $q_i$):
![alt text](https://lilianweng.github.io/lil-log/assets/images/normalizing-flow.png)

###  Remark:
If  $\mathbf{p}$ is a point in $\mathbb{R}^D$ and $f$ is differentiable at $\mathbf{p}$, then its derivative is given by $J_f(\mathbf{p})$. In this case, the linear map described by $J_f(\mathbf{p})$ is the best linear approximation of $f$ near the point $\mathbf{p}$, in the sense that

$$
\mathbf f(\mathbf x) = \mathbf f(\mathbf p) + \mathbf J_{\mathbf f}(\mathbf p)(\mathbf x - \mathbf p) + o(\|\mathbf x - \mathbf p\|),
$$

where $\mathbf x$ is close to $\mathbf p$ and where $o$ is the little o-notation.

Since, we can percieve the Jacobian of $f: \mathbb{R}^D \rightarrow  \mathbb{R}^D$ as locally linear map, we can describe the space distortions using the determinant: geometrically the absolute value of the Jacobian determinant gives the magnification/scalling factor when we transform an area or volume. It intuitevely make sense, that if function changes the volume by $a$ it's inverse should change the volme by $\frac{1}{a}$. 

### Abstract Flow

In [0]:
from abc import ABC, abstractmethod
from typing import List, Tuple
from torch import Tensor

class Flow(ABC):
    @abstractmethod
    def forward(self, z, parameters: Tuple[Tensor]) -> Tensor:
        pass

    @abstractmethod
    def log_det_jacobian(self, z, parameters: Tuple[Tensor]):
        pass
      
    @abstractmethod
    def unpack(self, parameters: Tensor) -> Tuple[Tensor, ...]:
        '''
        Method used to unpack the hidden layer to parameters of the flow
        
        From section 4.2:
        For amortized variational inference, we construct an inference model
        using a deep neural network to build a mapping from the observations x
        to the parameters of the initial density q0 = N(µ, σ) (µ∈R^D and σ∈R^D)
        as well as the parameters of the flow λ.
        '''
        pass
    
    @property
    @abstractmethod
    def dim(self) -> int:
        pass

### Planar Flow

In [0]:
def safe_log(z):
    return torch.log(z + 1e-7)
  
def tanh(x):
  return torch.tanh(x)


def tanh_prime(x):
  return 1 - torch.tanh(x)**2


class PlanarFlow(nn.Module, Flow):
  def __init__(self, dim = 2, h = tanh, h_prime= tanh_prime):
      super().__init__()
      '''
      f(z) = z + u h(w^T @ z + b)
      
      
      The flow defined by the transformation above modifies the
      initial density q_0 by applying a series of contractions and
      expansions in the direction perpendicular to the hyperplane
      w^T z+b = 0, hence we refer to these maps as planar flows.
      '''


      # h(·) is a smooth element-wise non-linearity
      self.h = h
      self.h_prime = h_prime
      self.d = dim
       

  def forward(self, z, parameters: Tuple[Tensor]) -> Tensor:
      '''
      f(z) = z + u h(w^T @ z + b)
      '''
      u, w, b = parameters
      
      z = z + self.h(F.linear(z, w, b)) @ u
      return z
  
  def log_det_jacobian(self, z, parameters: Tuple[Tensor]):
      '''
      ψ(z) = h'(w^T @ z + b)w
      |det @f/@z | = |1 + u^T ψ(z)|
      
      '''
      u, w, b = parameters
      
      
      psi = self.h_prime(F.linear(z, w, b)) @ w
      det_jacobian = torch.abs(1 + F.linear(psi, u))
      return safe_log(det_jacobian)
    
  def unpack(self, parameters: Tensor) -> Tuple[Tensor, ...]:
      ''' unpacks the free parameters
      λ = {w ∈ R^D, u ∈ R^D, b ∈ R} are free parameters
      '''
      w, u = parameters[:, :-1].chunk(2, dim=1)
      b = parameters[:, -1].view(-1, 1)
      
      return (w, u, b)

  @property
  def dim(self):
      return 2 * self.d + 1
      

### Radial Flow

In [0]:
class RadialFlow(nn.Module, Flow):
    def __init__(self, dim: int, h = tanh, h_prime= tanh_prime):
      super().__init__()
      '''
      f(z) = z + βh(α, r)(z − z0),
      
      
       It applies radial contractions and expansions
       around the reference point and are thus referred to as
       radial flows
      '''
      
      # λ = {z_0 ∈ R^D, α ∈ R+, β ∈ R}


      # h(·) is a smooth element-wise non-linearity
      self.h = h
      self.h_prime = h_prime
      
    
    
    def forward(self, z, parameters: Tensor) -> Tensor:
        pass

    def log_det_jacobian(self, z, parameters: Tensor):
        pass
      
    def unpack(self, parameters: Tensor) -> Tuple[Tensor, ...]:
        pass

### Normalizing Flows

In [0]:
class NormalizingFlow(nn.Module):
    def __init__(self, K: int, flow_class, *args, **kwargs):
        super().__init__()
        self.flows = nn.Sequential(*(
            flow_class(*args, **kwargs) for _ in range(K)
        ))
      
    def forward(self, z, lambdas: List[Tensor]):
      log_abs_det_jacobians = []

      for flow, parameters in zip(self.flows, lambdas):
          log_abs_det_jacobians.append(flow.log_det_jacobian(z, parameters))
          z = flow(z, parameters)
          
      return z, sum(log_abs_det_jacobians)
    
    def unpack(self, params: Tensor) -> List[Tensor]:
      flow_params = []
      start, end = 0, 0
      
      for flow in self.flows:
        start, end = end, end + flow.dim
        flow_params.append(flow.unpack(params[:, start:end]))

      return flow_params
    
    @property
    def dims(self):
      return sum(flow.dim for flow in self.flows)

In [0]:
#@title
import matplotlib
import matplotlib.pyplot as plt

def h(x):
    return np.tanh(x)

def h_prime(x):
    return 1 - np.tanh(x) ** 2

def f(z, w, u, b):
    return z + np.dot(h(np.dot(z, w) + b).reshape(-1,1), u.reshape(1,-1))
  
  
def plot_flow():
  plt.figure(figsize=[10, 14])

  id_figure = 1
  for i in np.arange(5):
      for j in np.arange(5):
          #represent w and u in polar coordinate system
          theta_w = 0
          rho_w = 5
          theta_u = np.pi / 8 * i
          rho_u = j / 4.0
          
          w = np.array([np.cos(theta_w), np.sin(theta_w)]) * rho_w
          u = np.array([np.cos(theta_u), np.sin(theta_u)]) * rho_u
          b = 0

          grid_use = np.meshgrid(np.arange(-1,1,0.001), np.arange(-1,1,0.001))
          z = np.concatenate([grid_use[0].reshape(-1,1), grid_use[1].reshape(-1,1)], axis=1)
          z = np.random.normal(size=(int(1e6),2))
          z_new = f(z, w, u, b)

          heatmap, xedges, yedges = np.histogram2d(
              z_new[:,0], z_new[:,1], bins=50, range=[[-3,3],[-3,3]])

          extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]

          plt.subplot(5,5,id_figure)
          plt.imshow(heatmap, extent=extent, cmap='viridis')
          plt.title("u=(%.1f,%.1f)"%(u[0],u[1]) + "\n" +
                    "w=(%d,%d)"%(w[0],w[1]) + ", " + "b=%d"%b)
          id_figure += 1

          plt.xlim([-3,3])
          plt.ylim([-3,3])
        
  plt.show()

The effect of planar and radial flows on the Gaussian and uniform distributions. The figure comes from the original paper.
![alt text](http://akosiorek.github.io/resources/simple_flows.png)

## Distributions

In [0]:
import math
import random

from numbers import Number
from itertools import accumulate
from bisect import bisect_right


  
def p_1(z, shift = 2, scale_1 = 0.4 , scale_2 = 0.6):
    '''
    Unnormalized fancy 2D density number 1
    '''

    z1, z2 = torch.chunk(z, chunks=2, dim=1)
    norm = torch.sqrt(z1 ** 2 + z2 ** 2)

    exp1 = torch.exp(-0.5 * ((z1 - shift) / scale_2) ** 2)
    exp2 = torch.exp(-0.5 * ((z1 + shift) / scale_2) ** 2)
    u = 0.5 * ((norm - shift) / scale_1) ** 2 - safe_log(exp1 + exp2)

    return torch.exp(-u)


def sum_probs(point_1, point_2):
  z_1, x_1, y_1 = point_1
  z_2, x_2, y_2 = point_2
  return z_1 + z_2, x_2, y_2


def find_le(a, x):
    'Find rightmost value less than or equal to x'
    i = bisect_right(a, (x, ))
    if i:
        return a[i-1]
    raise ValueError

def sample(points):
  p, x, y = find_le(points, random.random())
  return (x, y)


class EmpiricalSampler:
  def __init__(self,  
               density,
               n_points: int = 600, 
               limits: Tuple[float] = (-4, 4)):
    '''
    Wrapper class to sample from a close form bivariate-pdf
    '''
    self.density = density

    x = np.linspace(*limits, n_points)
    y = np.linspace(*limits, n_points)
    x, y = np.meshgrid(x, y)
    z = density(Tensor(np.c_[x, y])).data.numpy().reshape((n_points, n_points))
    z = z / np.sum(z)
    
    points = zip(z.ravel(), x.ravel(), y.ravel())
    points = list(accumulate(points, sum_probs))
    self.points = points
    
  def sample(self, n: int):
    return np.array([sample(self.points) for _ in range(n)])


class Gaussian:
    def __init__(self, dim: int = 2):
      self.d = 2
      
    def unpack(self, parameters: Tensor) -> Tuple[Tensor, ...]:
      ''' takes hidden state and returns mu and sigma'''
      mu, log_var = parameters.chunk(2, dim=1)
      std = torch.exp(0.5*log_var)
      
      return mu, std 
    
    def sample_with_log_prob(self, n, parameters: Tensor):
      mu, std = parameters
      
      xs = mu + std * Gaussian.sample(n, self.d)

      return xs, self.log_prob(xs, mu, std)
    
    @classmethod
    def sample(cls, n: int, d: int, mean: float = 0, std:float = 1):
        return torch.zeros(n, d).normal_(mean=mean, std=std)
    
    @classmethod
    def log_prob(cls, value, mu, std):
      '''
      Log of density function of multivariate normal with diagonal 
      covariance matrix
      '''
      var = std ** 2
      log_std = math.log(std) if isinstance(std, Number) else std.log()
      return (-((value - mu) ** 2) / (2 * var) - log_std - math.log(math.sqrt(2 * math.pi) )).sum(1, True)

    
    @property
    def dims(self):
      ''' 
      params for mu and sigma
      '''
      return 2 * self.d

## Flow-Based Free Energy Bound

In [0]:
class FreeEnergyBound(nn.Module):

    def __init__(self, p_x):
        super().__init__()
        self.p_x = p_x
        

    def forward(self, log_q_0, z_k, log_jacobians, beta):
        energy =  log_q_0 \
                - beta * safe_log(self.p_x.density(z_k)) \
                - log_jacobians
        
        return energy.mean()

In [0]:
#@title
import os
from torch.autograd import Variable
from matplotlib import pyplot as plt



def scatter_points(points, directory, iteration, flow_length):

    X_LIMS = (-4, 4)
    Y_LIMS = (-4, 4)

    fig = plt.figure(figsize=(7, 7))
    ax = fig.add_subplot(111)
    ax.scatter(points[:, 0], points[:, 1], alpha=0.7, s=25)
    ax.set_xlim(*X_LIMS)
    ax.set_ylim(*Y_LIMS)
    ax.set_title(
        "Flow length: {}\n Samples on iteration #{}"
        .format(flow_length, iteration)
    )

    fig.savefig(os.path.join(directory, "flow_result_{}.png".format(iteration)))
    plt.close()


def plot_density(distribution, directory):

    X_LIMS = (-4, 4)
    Y_LIMS = (-4, 4)

    x1 = np.linspace(*X_LIMS, 300)
    x2 = np.linspace(*Y_LIMS, 300)
    x1, x2 = np.meshgrid(x1, x2)
    shape = x1.shape
    x1 = x1.ravel()
    x2 = x2.ravel()

    z = np.c_[x1, x2]
    z = torch.FloatTensor(z)
    z = Variable(z)

    density_values = distribution.density(z).data.numpy().reshape(shape)

    fig = plt.figure(figsize=(7, 7))
    ax = fig.add_subplot(111)
    ax.imshow(density_values, extent=(*X_LIMS, *Y_LIMS), cmap="summer")
    ax.set_title("True density")

    fig.savefig(os.path.join(directory, "density.png"))
    plt.close()

In [0]:
from itertools import chain


def init_weights(m):
  if type(m) == nn.Linear:
      torch.nn.init.xavier_uniform_(m.weight)
      m.bias.data.fill_(0.01)


class Maxout(nn.Module):
    def __init__(self, pool_size: int = 4):
        super().__init__()
        self._pool_size = pool_size

    def forward(self, x):
        assert x.shape[-1] % self._pool_size == 0, \
            f"Wrong input last dim size ({x.shape[-1]}) for Maxout({self._pool_size})"
        
        m, i = x.view(*x.shape[:-1], x.shape[-1] // self._pool_size, self._pool_size).max(-1)
        
        return m
      

class InferenceNetwork(nn.Module):
    def __init__(self, distribution, flows, sizes: List[Tuple[int]], pool_size: int = 4):
      '''
      Inference model using a deep neural network to build a mapping
      from the observations x to the parameters.
      
      '''
      super().__init__()
      
   
      self.flows = flows
      self.distribution = distribution
      
      layers = list(chain.from_iterable([
          (nn.Linear(d_in, d_out*pool_size), Maxout(pool_size=pool_size)) for d_in, d_out in sizes
      ]))
      
      d_in = sizes[-1][-1]
      d_out = (self.flows.dims + self.distribution.dims) * pool_size
      
      layers.append(nn.Linear(d_in, d_out))
      layers.append(Maxout(pool_size))
      
      
      self.net = nn.Sequential(*layers)
      self.net.apply(init_weights)
      
    
    def forward(self, x):
      parameters = self.net(x)
      # unpack initial distribution parameters
      start, end = 0, self.distribution.dims
      dist_params = self.distribution.unpack(parameters[:, start:end])
      
      # unpack flow parameters
      flow_params = self.flows.unpack(parameters[:, start:])
      
      return dist_params, flow_params
    
      

## Training

In [0]:
import os

directory = '/content/results/'
if not os.path.exists(directory):
  os.makedirs(directory)

In [0]:
K = 2

fancy_dist = EmpiricalSampler(p_1)

q_0 = Gaussian(2)
flow = NormalizingFlow(K=K, flow_class=PlanarFlow, dim=2)

net = InferenceNetwork(q_0, flow, [(2, 100)])
annealed_bound = FreeEnergyBound(p_x = fancy_dist)

optimizer = optim.RMSprop(net.parameters(), lr=1e-5, momentum=0.9)

for iteration in range(1, 500000):
    
    # get samples from the true distribution
    true_samples = Tensor(fancy_dist.sample(100))
    
    # use inference network to find the parameters of the

    dist_params, flow_params = net(true_samples)
    z_0, log_q_0 = q_0.sample_with_log_prob(100, dist_params)
    z_k, log_jacobians = flow(z_0, flow_params)
    loss = annealed_bound(log_q_0, z_k, log_jacobians, min(1, 0.01 + iteration/1000))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    if iteration % 1000 == 0:
      print("Loss on iteration {}: {}".format(iteration , loss.data.item()))
    
    if iteration % 10000 == 0:
        scatter_points(
            z_k.data.numpy(),
            directory='/content/results/',
            iteration=iteration,
            flow_length=K
        )

# Vatiational Autoencoders

## Vanilla VAE

In [0]:
class VAE(nn.Module):
    '''
    On Mnist
    '''
    def __init__(self, 
                 feature_size: int = 784,
                 hidden_size: int = 400,
                 code_size: int = 20):
        super(VAE, self).__init__()
        self.code_size = code_size
        self.feature_size = feature_size
 
        self.encoder = nn.Sequential(
            nn.Linear(feature_size, hidden_size),
            nn.ReLU(True),
            nn.Linear(hidden_size, code_size * 2)
        )

        self.decoder = nn.Sequential(
            nn.Linear(code_size, hidden_size),
            nn.ReLU(True),
            nn.Linear(hidden_size, feature_size),
            nn.Sigmoid()
        )
    
    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5*log_var)
        eps = torch.randn_like(std)
        return mu + std*eps

    def forward(self, x):
        x = x.view(-1, self.feature_size)
        
        mu, log_var = self.encoder(x).chunk(2, dim=1)

        z = self.reparameterize(mu, log_var)
        
        kl_div = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        kl_div = kl_div / x.size(0)  # mean over batch
        
        return self.decoder(z), kl_div

In [0]:
def reconstruction_loss(recon_x, x):
    # batch mean
    return F.binary_cross_entropy(recon_x, x, reduction="sum") / x.size(0)

In [0]:
class Trainer:
    def __init__(self, model, train_loader, test_loader,
                 log_interval: int =10,
                 batch_size: int =128):
      self.model = model
      self.train_loader = train_loader
      self.test_loader = train_loader

      self.log_interval = log_interval
      self.batch_size = batch_size
      
      self.optimizer = optim.Adam(model.parameters(), lr=1e-3)

    def train(self, epoch):
        self.model.train()
        train_loss = 0
        for batch_idx, (data, _) in enumerate(self.train_loader):
            data = data.to(device)
            self.optimizer.zero_grad()
            recon_batch, kl_div = self.model(data)
            recon_loss = reconstruction_loss(recon_batch, data)
            loss =  recon_loss + kl_div
            loss.backward()
            train_loss += loss.item()
            self.optimizer.step()

            if batch_idx % self.log_interval == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader),
                    loss.item() / len(data)))
                
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tRecon: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader),
                    recon_loss.item() / len(data)))
                
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tKL: {:.6f}\n'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader),
                    kl_div.item() / len(data)))
                

        print('====> Epoch: {} Average loss: {:.4f}'.format(
              epoch, train_loss / len(train_loader.dataset)))


    def test(self, epoch, name, fixed_sample = None):
        self.model.eval()
        test_loss = 0
        with torch.no_grad():
            for i, (data, _) in enumerate(self.test_loader):
                data = data.to(device)                
                recon_batch, kl_div = self.model(data)
                test_loss += (reconstruction_loss(recon_batch, data) + kl_div).item()
                
                if i == 0:
                    n = min(data.size(0), 8)
                    comparison = torch.cat([
                        data[:n],
                        recon_batch.view(self.batch_size, 1, 28, 28)[:n]
                    ])
                    
                    save_image(comparison.cpu(),
                             f'/content/results/reconstruction_{name}_{epoch}.png', nrow=n)
              
            if not fixed_sample is None:
                fixed_sample = fixed_sample.to(device)
                recon_sample, _ = self.model(fixed_sample)
                recon_sample = recon_sample.view(8, 1, 28, 28)
                
                n = min(fixed_sample.size(0), 8)
                comparison = torch.cat([
                    fixed_sample[:n],
                    recon_sample[:n]
                ])
                
                save_image(comparison.cpu(),
                           f'/content/results/fixed_reconstruction_{name}_{epoch}.png', nrow=n)   
              

        test_loss /= len(test_loader.dataset)
        print('====> Test set loss: {:.4f}'.format(test_loss))
        
      
    def run(self, num_epochs, name, fixed_sample = None):
        for epoch in range(1, num_epochs + 1):
          self.train(epoch)
          self.test(epoch, name, fixed_sample)
          with torch.no_grad():
              sample = torch.randn(64, model.code_size).to(device)
              sample = model.decoder(sample).cpu()
              sample_name = f"results/sample_{name}_{epoch}.png"
              save_image(sample.view(64, 1, 28, 28), sample_name)

In [0]:
device = torch.device("cuda")

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True, transform=transforms.ToTensor()),
    batch_size=128, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=False, transform=transforms.ToTensor()),
    batch_size=128, shuffle=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Processing...
Done!


In [0]:
model = VAE(code_size=40)
model.to(device)

trainer = Trainer(
    model = model, 
    train_loader = train_loader,
    test_loader = test_loader,
    log_interval = 100
)

In [0]:
trainer.run(10, 'VAE')

  "Please ensure they have the same size.".format(target.size(), input.size()))









  "Please ensure they have the same size.".format(target.size(), input.size()))


====> Epoch: 1 Average loss: 1.3236
====> Test set loss: 6.3234





====> Epoch: 2 Average loss: 0.9862
====> Test set loss: 5.6337





====> Epoch: 3 Average loss: 0.9148
====> Test set loss: 5.3674





====> Epoch: 4 Average loss: 0.8832
====> Test set loss: 5.2217





====> Epoch: 5 Average loss: 0.8645
====> Test set loss: 5.1407





====> Epoch: 6 Average loss: 0.8531
====> Test set loss: 5.0840





====> Epoch: 7 Average loss: 0.8453
====> Test set loss: 5.0471





====> Epoch: 8 Average loss: 0.8390
====> Test set loss: 5.0161





====> Epoch: 9 Average loss: 0.8345
====> Test set loss: 4.9941





====> Epoch: 10 Average loss: 0.8312
====> Test set loss: 4.9694


## VAE with Normalizing Flows

In [0]:
class VAE_NF(nn.Module):
    '''
    On Mnist
    '''
    def __init__(self, 
                 flows: NormalizingFlow,
                 feature_size: int = 784,
                 hidden_size: int = 400,
                 code_size: int = 20):
        super(VAE_NF, self).__init__()
        self.flow = flow
        self.code_size = code_size
        
        self.encoder = nn.Sequential(
            nn.Linear(feature_size, hidden_size),
            nn.ReLU(True),
            nn.Linear(hidden_size, code_size * 2 + self.flow.dims)
        )

        self.decoder = nn.Sequential(
            nn.Linear(code_size, hidden_size),
            nn.ReLU(True),
            nn.Linear(hidden_size, feature_size),
            nn.Sigmoid()
        )

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + std*eps


    def forward(self, x):
        # Run inference network to get the posterior parameters
        params = self.encoder(x.view(-1, 784))
        
        mu = params[:, :self.code_size]
        log_var = params[:, self.code_size: self.code_size * 2]
        flow_params = self.flow.unpack(params[:, self.code_size*2:])
        
        # Get samples from posterior
        z = self.reparameterize(mu, log_var)
        z_K, log_jacobians = self.flow(z, flow_params)
        
        # Push it through generative network
        x_recon = self.decoder(z_K)
        
        # Calculate the loss
        kl_div = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        kl_div = kl_div / x.size(0) - log_jacobians.mean()
        
        return x_recon, kl_div
      

In [0]:
code_size = 40

flow = NormalizingFlow(K=4, flow_class=PlanarFlow, dim=code_size)

model = VAE_NF(flow, code_size=code_size)
model.to(device)

trainer = Trainer(
    model = model, 
    train_loader = train_loader,
    test_loader = test_loader,
    log_interval = 100
)

In [0]:
fixed_sample, _ = next(iter(test_loader))
fixed_sample = fixed_sample[:8]

In [0]:
trainer.run(100, 'VAE_NF', fixed_sample)

  "Please ensure they have the same size.".format(target.size(), input.size()))









  "Please ensure they have the same size.".format(target.size(), input.size()))


====> Epoch: 1 Average loss: 1.8708
====> Test set loss: 11.8883





====> Epoch: 2 Average loss: 1.7947
====> Test set loss: 19.8788





====> Epoch: 3 Average loss: 9.9437
====> Test set loss: 98.3073





====> Epoch: 4 Average loss: 9.9557
====> Test set loss: 12.6605





====> Epoch: 5 Average loss: 11.0579
====> Test set loss: 95.9289





====> Epoch: 6 Average loss: 13.9789
====> Test set loss: 48.9063





====> Epoch: 7 Average loss: 16.6587
====> Test set loss: 92.7288





====> Epoch: 8 Average loss: 19.9690
====> Test set loss: 107.0952





====> Epoch: 9 Average loss: 13.4478
====> Test set loss: 13.1737





====> Epoch: 10 Average loss: 17.5829
====> Test set loss: 179.5058





====> Epoch: 11 Average loss: 15.4364
====> Test set loss: 25.0211





====> Epoch: 12 Average loss: 15.5048
====> Test set loss: 160.0369





====> Epoch: 13 Average loss: 17.5101
====> Test set loss: 17.2657





====> Epoch: 14 Average loss: 18.6279
====> Test set loss: 22.4786





==