# Neural Autoregressive Density Estimation

* `Neural Autoregressive Distribution Estimation`, [arXiv:1605.02226](https://arxiv.org/abs/1605.02226)
  * Benigno Uria, Marc-Alexandre Côté, Karol Gregor, Iain Murray, and Hugo Larochelle

* Implemented with [`tf.keras.layers`](https://www.tensorflow.org/api_docs/python/tf/keras/layers) and [`eager execution`](https://www.tensorflow.org/guide/eager).

## Import modules

In [None]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import time
import glob

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import PIL
import imageio
from IPython import display

import tensorflow as tf
from tensorflow.keras import layers
tf.enable_eager_execution()

sys.path.append(os.path.dirname(os.path.abspath('.')))
from utils.image_utils import *
from utils.ops import *

os.environ["CUDA_VISIBLE_DEVICES"]="0"

## Setting hyperparameters

In [None]:
# Training flags (hyperparameter configuration)

# Experiment identity and output directory.
model_name = 'nade'
train_dir = 'train/{}/exp1/'.format(model_name)

# Epoch schedule: a constant-rate phase followed by a decay phase.
constant_lr_epochs, decay_lr_epochs = 5, 5
max_epochs = constant_lr_epochs + decay_lr_epochs

# Logging / saving cadence.
print_steps = 10          # log every `print_steps` global steps
save_images_epochs = 1    # save sample images every N epochs
save_model_epochs = 1     # write a checkpoint every N epochs

# Optimization.
batch_size = 128
learning_rate = 5e-2

# Model / data geometry.
MNIST_SIZE = 28           # images are MNIST_SIZE x MNIST_SIZE pixels
hidden_dims = 500         # hidden units of the NADE model
num_examples_to_generate = 16

## Load the MNIST dataset

In [None]:
# Fetch MNIST through tf.keras (train and test splits are both returned).
(train_data, train_labels), (test_data, test_labels) = \
    tf.keras.datasets.mnist.load_data()

# Flatten every image into one row of MNIST_SIZE * MNIST_SIZE float32 pixels
# and divide by 255.
train_data = train_data.reshape(-1, MNIST_SIZE * MNIST_SIZE).astype('float32') / 255.

# Binarization: pixels >= 0.5 become 1., everything else becomes 0.
train_data = (train_data >= .5).astype('float32')

In [None]:
# Preview one binarized training digit together with its label.
index = 219
print("label = {}".format(train_labels[index]))

# Rows of train_data are flat; restore the 2-D layout for plotting.
preview = train_data[index].reshape([MNIST_SIZE, MNIST_SIZE])
plt.imshow(preview)
plt.colorbar()
plt.show()

## Set up dataset with `tf.data`

### create input pipeline with `tf.data.Dataset`

In [None]:
def max_pooling(image):
  """Max-pool `image` and flatten the result.

  A leading axis of size 1 is added to `image`, a default-configured
  `MaxPooling2D` layer is applied, and the pooled tensor is reshaped to a
  vector of MNIST_SIZE * MNIST_SIZE values.

  NOTE(review): the only call site in this notebook is commented out. The final
  reshape requires the pooled output to hold exactly MNIST_SIZE * MNIST_SIZE
  elements -- confirm the expected input shape before re-enabling.

  Args:
    image: image tensor to pool (expected shape not shown here -- TODO confirm).

  Returns:
    1-rank Tensor with MNIST_SIZE * MNIST_SIZE elements.
  """
  batched = tf.expand_dims(image, axis=0)
  pooled = layers.MaxPooling2D()(batched)
  return tf.reshape(pooled, [MNIST_SIZE * MNIST_SIZE])

In [None]:
#tf.set_random_seed(219)

# Training input pipeline: shuffle with a buffer covering the whole training
# set, then emit batches of exactly `batch_size` rows (the last partial batch
# is dropped). The optional max_pooling map step stays disabled.
N = len(train_data)
train_dataset = (
    tf.data.Dataset.from_tensor_slices(train_data)
    .shuffle(buffer_size=N)
    .batch(batch_size=batch_size, drop_remainder=True)
)
print(train_dataset)

## Create the NADE model

In [None]:
def log_pmf(sample, probability):
  """Element-wise Bernoulli log-probability.

  Computes sample * log(p) + (1 - sample) * log(1 - p), adding a small
  epsilon inside each logarithm so that p == 0 or p == 1 does not yield -inf.

  Args:
    sample: observed values (the callers in this notebook pass 0./1. pixels).
    probability: probability of the value 1, broadcastable against `sample`.

  Returns:
    Tensor of log-probabilities with the broadcast shape of the inputs.
  """
  epsilon = 1e-10
  log_prob_one = tf.log(probability + epsilon)
  log_prob_zero = tf.log(1. - probability + epsilon)
  return sample * log_prob_one + (1. - sample) * log_prob_zero

In [None]:
class NeuralAutoregressiveDensityEstimation(tf.keras.Model):
  """NADE model for flattened, binarized MNIST_SIZE x MNIST_SIZE images.

  Pixel i is modelled as a Bernoulli variable conditioned on the pixels
  before it:
    h_i = sigmoid(c + x_{<i} W_{<i, :})
    p(x_i = 1 | x_{<i}) = models[i](h_i)
  `W` and `c` are shared by all conditionals; every pixel owns its own
  hidden-to-output dense layer.
  """

  def __init__(self, hidden_dims=500):
    """Initialize learnable parameters.

    Args:
      hidden_dims (int): number of hidden units (input to hidden)

    Values:
      self.input_dims (int): number of pixels per image (MNIST_SIZE**2)
      self.W (float32 2-rank Tensor): shared training weight (input to hidden),
        shape [input_dims, hidden_dims]
      self.c (float32 1-rank Tensor): shared training bias (input to hidden),
        shape [hidden_dims]
      self.models (list): one dense layer (single sigmoid unit) per pixel,
        i.e. input_dims layers in total
    """
    super(NeuralAutoregressiveDensityEstimation, self).__init__()
    self.hidden_dims = hidden_dims
    self.input_dims = MNIST_SIZE**2
    self.W = tf.get_variable(name='shared_weight', shape=[self.input_dims, self.hidden_dims],
                             initializer=tf.random_normal_initializer(mean=0., stddev=0.04))
    self.c = tf.get_variable(name='shared_bias', shape=[self.hidden_dims],
                             initializer=tf.zeros_initializer())
    self.models = []
    for _ in range(self.input_dims):
      self.models.append(tf.keras.Sequential([layers.InputLayer(input_shape=[self.hidden_dims]),
                                              layers.Dense(units=1, activation='sigmoid')]))

  def _tiled_bias(self, num_rows):
    """Return the shared bias repeated `num_rows` times, shape [num_rows, hidden_dims]."""
    # tf.stack lets `num_rows` be either a Python int or a scalar int32 Tensor.
    return tf.tile(tf.expand_dims(self.c, axis=0), tf.stack([num_rows, 1]))

  def call(self, inputs):
    """Build a likelihood function.

    Follows the paper's algorithm to reduce the cost of p(x): the hidden
    pre-activation `a` is updated with a single row of W per pixel instead of
    being recomputed from all previous pixels.

    Args:
      inputs (float32 2-rank Tensor): flattened images, shape [batch, input_dims].
        Any batch size is accepted.

    Returns:
      logpx (float32 2-rank Tensor): log likelihood of every example,
        shape [batch, 1]
    """
    # Size the activation from the actual batch rather than the global
    # `batch_size`, so evaluation with other batch sizes works too.
    a = self._tiled_bias(tf.shape(inputs)[0])
    logpx = 0.
    for i in range(self.input_dims):
      h = tf.math.sigmoid(a)
      probability = self.models[i](h)
      pixel = inputs[:, i:i+1]
      logpx += log_pmf(pixel, probability)
      # Fold pixel i into the pre-activation used for pixel i + 1.
      a = a + tf.matmul(pixel, self.W[i:i+1, :])

    return logpx

  def sampling(self, num_samples):
    """Sample images pixel by pixel (ancestral sampling).

    Uses the same incremental update of the hidden pre-activation as `call`,
    so each pixel costs one [num_samples, 1] x [1, hidden_dims] product.

    Args:
      num_samples (int): number of sample images

    Returns:
      samples (int32 2-rank Tensor): sampled 0/1 images,
        shape [num_samples, input_dims]
    """
    a = self._tiled_bias(num_samples)
    pixels = []
    for i in range(self.input_dims):
      h = tf.math.sigmoid(a)
      probability = self.models[i](h)  # p(x_i = 1 | x_{<i}), shape [num_samples, 1]
      # Column 0 holds log p(x_i = 0), column 1 holds log p(x_i = 1), so the
      # sampled class index is the pixel value itself.
      logits = tf.log(tf.concat((1. - probability, probability), axis=1))
      pixel = tf.random.categorical(logits, num_samples=1, dtype=tf.int32)
      pixels.append(pixel)
      a = a + tf.matmul(tf.cast(pixel, dtype=tf.float32), self.W[i:i+1, :])
    return tf.concat(pixels, axis=1)

In [None]:
# Build the NADE model with `hidden_dims` hidden units.
nade = NeuralAutoregressiveDensityEstimation(hidden_dims=hidden_dims)

In [None]:
# Wrap `nade.call` with defun for a performance boost.
# NOTE: tf.contrib.eager.defun will be deprecated in TF version 2.0.
nade.call = tf.contrib.eager.defun(nade.call)

## Define the loss functions and the optimizer

* The loss is the negative log-likelihood of the training batch under the model (averaged over the batch).

### Define learning rate decay functions

In [None]:
lr = learning_rate
def get_lr():
  """Return the learning rate for the current `global_step`.

  The rate stays at `learning_rate` for the first `constant_lr_epochs` epochs
  and then decays linearly to 0 over the next `decay_lr_epochs` epochs
  (like tf.train.polynomial_decay with power 1).

  The value is derived from `global_step` alone, so calling this function
  several times in the same step returns the same rate, and the rate never
  drops below 0 after the decay window. The module-level `lr` is kept in sync
  with the returned value.

  Returns:
    float: learning rate to use at the current global step.
  """
  global lr
  num_steps_per_epoch = int(N / batch_size)
  constant_steps = num_steps_per_epoch * constant_lr_epochs
  decay_steps = num_steps_per_epoch * decay_lr_epochs
  steps_into_decay = int(global_step.numpy()) - constant_steps
  if steps_into_decay <= 0 or decay_steps <= 0:
    # Still in the constant phase (or no decay phase configured).
    lr = learning_rate
  else:
    remaining = max(0., 1. - steps_into_decay / float(decay_steps))
    lr = learning_rate * remaining
  return lr

### Define optimizer

In [None]:
# Adam with a step size of `learning_rate` and beta1=0.5.
# NOTE(review): get_lr() is not passed to the optimizer in the visible code, so
# the linear decay defined above does not appear to be applied -- confirm intent.
optimizer = tf.train.AdamOptimizer(learning_rate, beta1=0.5)
#optimizer = tf.train.GradientDescentOptimizer(learning_rate)

## Checkpoints (Object-based saving)

In [None]:
# Checkpoints are written into the training directory using the "ckpt" prefix.
checkpoint_dir = train_dir
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")

# Create the directory on first use.
if not tf.gfile.Exists(checkpoint_dir):
  tf.gfile.MakeDirs(checkpoint_dir)

# Object-based checkpoint tracking both the optimizer and the model.
checkpoint = tf.train.Checkpoint(optimizer=optimizer, nade=nade)

## Training

In [None]:
# Main training loop: minimize the mean negative log-likelihood of each batch,
# periodically print progress with fresh samples, and save sample images and a
# checkpoint at the configured epoch intervals.
print('Start Training.')
# Step counter; it is handed to apply_gradients below and read back for logging.
global_step = tf.train.get_or_create_global_step()
num_batches_per_epoch = int(N / batch_size)
for epoch in range(max_epochs):

  for step, images in enumerate(train_dataset):
    start_time = time.time()

    # Record the forward pass so the loss can be differentiated.
    with tf.GradientTape() as tape:
      log_likelihood = nade(images)
      loss = -tf.reduce_mean(log_likelihood) # batch mean of the negative log-likelihood

    gradients = tape.gradient(loss, nade.variables)
    optimizer.apply_gradients(zip(gradients, nade.variables), global_step=global_step)

    # Fractional epoch count, used only for the progress message.
    epochs = epoch + step / float(num_batches_per_epoch)
    # Covers the forward pass, backward pass and parameter update of this step.
    duration = time.time() - start_time

    if global_step.numpy() % print_steps == 0:
      display.clear_output(wait=True)
      examples_per_sec = batch_size / float(duration)
      print("Epochs: {:.2f} global_step: {} loss: {:.3g} ({:.2f} examples/sec; {:.3f} sec/batch)".format(
                epochs, global_step.numpy(), loss.numpy(), examples_per_sec, duration))
      # Draw fresh samples from the current model and hand them to the project
      # helper (presumably displays them inline here -- see utils.image_utils).
      sample_images = nade.sampling(num_examples_to_generate)
      print_or_save_sample_images(sample_images.numpy(), num_examples_to_generate)

  # Every `save_images_epochs` epochs: sample again and call the helper with is_save=True.
  if (epoch + 1) % save_images_epochs == 0:
    display.clear_output(wait=True)
    print("This images are saved at {} epoch".format(epoch+1))
    sample_images = nade.sampling(num_examples_to_generate)
    print_or_save_sample_images(sample_images.numpy(), num_examples_to_generate,
                                is_square=True, is_save=True, epoch=epoch+1,
                                checkpoint_dir=checkpoint_dir)

  # Save a checkpoint of the model and optimizer every `save_model_epochs` epochs.
  if (epoch + 1) % save_model_epochs == 0:
    checkpoint.save(file_prefix=checkpoint_prefix)

print('Training Done.')

In [None]:
# Generate and save samples after the final epoch.
# Relies on `epoch` still holding its last value from the training loop.
display.clear_output(wait=True)
sample_images = nade.sampling(num_examples_to_generate)
print_or_save_sample_images(sample_images.numpy(), num_examples_to_generate,
                            is_square=True, is_save=True, epoch=epoch+1,
                            checkpoint_dir=checkpoint_dir)

## Restore the latest checkpoint

In [None]:
# Restore model and optimizer state from the latest checkpoint found in checkpoint_dir.
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

## Display an image using the epoch number

In [None]:
# Project helper from utils; presumably shows the image saved for epoch
# `max_epochs` in checkpoint_dir -- TODO confirm against utils.image_utils.
display_image(max_epochs, checkpoint_dir)

## Generate a GIF of all the saved images.

In [None]:
# Output name for the animation: '<model_name>.gif'.
filename = model_name + '.gif'
# Project helper from utils; presumably assembles the images saved in
# checkpoint_dir into `filename` -- TODO confirm against utils.image_utils.
generate_gif(filename, checkpoint_dir)

In [None]:
# Displays the file named `filename + '.png'` (i.e. '<model_name>.gif.png').
# NOTE(review): presumably generate_gif also writes this '.png' copy -- confirm.
display.Image(filename=filename + '.png')