# NICE with MNIST

* `NICE: NON-LINEAR INDEPENDENT COMPONENTS ESTIMATION`, [arXiv:1410.8516](https://arxiv.org/abs/1410.8516)
  * Laurent Dinh, David Krueger and Yoshua Bengio
  
* Implemented with [`tf.keras.layers`](https://www.tensorflow.org/api_docs/python/tf/keras/layers) and [`eager execution`](https://www.tensorflow.org/guide/eager).

## Import modules

In [None]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import time
import glob

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import PIL
import imageio
from IPython import display

import tensorflow as tf
from tensorflow.keras import layers
tf.enable_eager_execution()

import tensorflow_probability as tfp

tf.logging.set_verbosity(tf.logging.INFO)

os.environ["CUDA_VISIBLE_DEVICES"]="0"

## Setting hyperparameters

In [None]:
# Training flags (hyperparameter configuration)

# Output directory for checkpoints and saved sample grids.
train_dir = 'train/nice/exp1/'

# Schedule: how long to train and how often to save / report.
max_epochs = 1500
save_model_epochs = 100   # checkpoint every N epochs
save_images_epochs = 50   # save a sample grid every N epochs
print_steps = 100         # log + preview every N optimizer steps

# Optimization.
batch_size = 256
learning_rate = 1e-3

# Sampling and data geometry.
num_examples_to_generate = 25
MNIST_SIZE = 28 * 28      # a flattened 28x28 MNIST image
noise_dim = MNIST_SIZE    # the flow's latent has the input's dimensionality

## Load the MNIST dataset

In [None]:
# Load MNIST via tf.keras; only the training split is kept, the test split
# is discarded.
(train_data, train_labels), _ = tf.keras.datasets.mnist.load_data()

# Flatten every image into a MNIST_SIZE vector (as float32), then scale the
# pixel values by 1/255.
train_data = train_data.reshape(-1, MNIST_SIZE).astype('float32') / 255.
train_labels = train_labels.astype(np.int32)

## Set up dataset with `tf.data`

### Create the input pipeline with `tf.data.Dataset`

In [None]:
tf.set_random_seed(219)

# Training pipeline: a shuffle buffer as large as the whole training set,
# followed by batching (no drop_remainder, so the final batch may be smaller).
N = len(train_data)
train_dataset = (
    tf.data.Dataset.from_tensor_slices(train_data[:N])
    .shuffle(buffer_size=N)
    .batch(batch_size=batch_size))
print(train_dataset)

## Create the NICE models

In [None]:
class ReLUMLP(tf.keras.Model):
  """Shift network m(.) used inside an additive coupling layer.

  Five 1000-unit ReLU layers followed by a linear output layer of width
  MNIST_SIZE - input_size, i.e. the width of the partition being shifted.
  """
  def __init__(self, input_size):
    super(ReLUMLP, self).__init__()
    self.input_size = input_size
    # Layer attribute names (fc1..fc6) are what object-based checkpoints key
    # on, so they are kept as individual attributes.
    self.fc1 = layers.Dense(units=1000, activation='relu')
    self.fc2 = layers.Dense(units=1000, activation='relu')
    self.fc3 = layers.Dense(units=1000, activation='relu')
    self.fc4 = layers.Dense(units=1000, activation='relu')
    self.fc5 = layers.Dense(units=1000, activation='relu')
    self.fc6 = layers.Dense(units=MNIST_SIZE - self.input_size)

  def call(self, inputs, training=True):
    """Run the model. `training` is accepted for Keras compatibility but unused."""
    hidden = inputs
    for dense in (self.fc1, self.fc2, self.fc3, self.fc4, self.fc5):
      hidden = dense(hidden)
    return self.fc6(hidden)

In [None]:
class AdditiveCouplingLayer(tf.keras.Model):
  """Additive coupling: (x1, x2) -> (x1, x2 + m(x1)) with m a ReLUMLP.

  `input_size` is the width of the conditioning partition x1; the shift
  network outputs MNIST_SIZE - input_size values, matching x2.
  """
  def __init__(self, input_size):
    super(AdditiveCouplingLayer, self).__init__()
    self.input_size = input_size
    self.relumlp = ReLUMLP(self.input_size)

  def call(self, x1, x2):
    """Forward map: x1 passes through, x2 is shifted by m(x1)."""
    return x1, x2 + self.relumlp(x1)

  def inverse(self, y1, y2):
    """Exact inverse of `call`: subtract the same shift m(y1) from y2."""
    return y1, y2 - self.relumlp(y1)

In [None]:
def partition(inputs, method='oddeven', p1_size=int(MNIST_SIZE/2)):
  """Split a 2-D (batch, columns) input column-wise into two parts.

  Args:
    inputs: 2-D array/tensor supporting [:, slice] indexing.
    method: 'oddeven' -> (even-indexed columns, odd-indexed columns);
            'topdown' -> (first p1_size columns, remaining columns).
    p1_size: split point for 'topdown'; ignored for 'oddeven'.

  Returns:
    (partition1, partition2)

  Raises:
    ValueError: for any other `method`.
  """
  if method == 'topdown':
    return inputs[:, :p1_size], inputs[:, p1_size:]
  if method == 'oddeven':
    return inputs[:, 0::2], inputs[:, 1::2]
  raise ValueError('Not allowed method')

In [None]:
def merge(partition1, partition2, method='oddeven'):
  """Reassemble two column partitions into one 2-D tensor (inverse of `partition`).

  Args:
    partition1, partition2: 2-D (batch, columns) tensors/arrays with a
      statically known column count, presumably produced by `partition`
      with the same `method`.
    method: 'oddeven' interleaves columns as
      (p1[:, 0], p2[:, 0], p1[:, 1], p2[:, 1], ...); partition1 may have one
      extra trailing column (odd total width). 'topdown' concatenates
      partition1 then partition2.

  Returns:
    The merged tensor of shape (batch, cols1 + cols2).

  Raises:
    ValueError: for an unknown `method`, or for 'oddeven' when partition1
      does not have the same number of columns as partition2 or exactly one more.
  """
  if method == 'oddeven':
    n1 = int(partition1.shape[1])
    n2 = int(partition2.shape[1])
    if n1 not in (n2, n2 + 1):
      raise ValueError(
          "'oddeven' merge needs partition1 to have as many columns as "
          "partition2 or one more, got {} and {}".format(n1, n2))
    # Pair column j of each partition on a new last axis; a row-major reshape
    # then puts partition1[:, j] at column 2j and partition2[:, j] at 2j+1.
    # This replaces a per-column Python loop (one slice op per column).
    paired = tf.stack([partition1[:, :n2], partition2], axis=2)
    # Explicit batch size (not -1) so an empty pairing (n2 == 0) still reshapes.
    merged = tf.reshape(paired, [tf.shape(paired)[0], 2 * n2])
    if n1 > n2:
      # Odd total width: the last even-indexed column has no partner.
      merged = tf.concat([merged, partition1[:, n2:]], axis=1)
  elif method == 'topdown':
    merged = tf.concat((partition1, partition2), axis=1)
  else:
    raise ValueError('Not allowed method')

  return merged

In [None]:
class NICE(tf.keras.Model):
  """NICE flow on flattened MNIST images.

  Four additive coupling layers alternate which partition they update, then
  an elementwise scaling maps to the latent: h = merge(...) * exp(scaling).

  Args:
    partition_method: 'oddeven' or 'topdown'; forwarded to `partition` and
      `merge`.
    partition_size: width of the first partition. 'topdown' splits at this
      column. 'oddeven' always gives the first partition
      (MNIST_SIZE + 1) // 2 columns, so any other value is rejected.

  Raises:
    ValueError: if partition_method is 'oddeven' and partition_size does not
      match the fixed odd/even split.
  """
  def __init__(self, partition_method, partition_size):
    super(NICE, self).__init__()
    oddeven_size = (MNIST_SIZE + 1) // 2
    if partition_method == 'oddeven' and partition_size != oddeven_size:
      # Otherwise the mismatch only surfaces later as an opaque shape error
      # inside the coupling layers.
      raise ValueError(
          "partition_size must be {} for the 'oddeven' method, got {}".format(
              oddeven_size, partition_size))
    self.partition_method = partition_method
    self.partition_size1 = partition_size
    self.partition_size2 = MNIST_SIZE - partition_size

    # Each coupling is sized by the partition it conditions on; couplings
    # alternate between conditioning on partition 1 and partition 2.
    self.coupling1 = AdditiveCouplingLayer(self.partition_size1)
    self.coupling2 = AdditiveCouplingLayer(self.partition_size2)
    self.coupling3 = AdditiveCouplingLayer(self.partition_size1)
    self.coupling4 = AdditiveCouplingLayer(self.partition_size2)
    # Per-dimension log-scale. Additive couplings have unit Jacobian
    # determinant, so sum(scaling) is the log|det J| of the whole flow.
    self.scaling = tf.get_variable('scaling', shape=[MNIST_SIZE], dtype=tf.float32)

  def call(self, inputs):
    """Map data x to latent h.

    Returns:
      (h, scaling): the latent tensor and the log-scale variable.
    """
    # Pass the configured split point so 'topdown' matches the coupling widths.
    x1, x2 = partition(inputs, self.partition_method, self.partition_size1)

    # naming rule: (num_layer)_(num_partition)
    h1_1, h1_2 = self.coupling1(x1, x2)
    h2_2, h2_1 = self.coupling2(h1_2, h1_1)
    h3_1, h3_2 = self.coupling3(h2_1, h2_2)
    h4_2, h4_1 = self.coupling4(h3_2, h3_1)

    h = merge(h4_1, h4_2, self.partition_method) * tf.exp(self.scaling)

    return h, self.scaling

  def generate_sample(self, noise_vector):
    """Invert the flow: map latent vectors back to data space.

    Undoes the scaling, then applies the coupling inverses in reverse order.
    """
    h4 = noise_vector / tf.exp(self.scaling)
    h4_1, h4_2 = partition(h4, self.partition_method, self.partition_size1)

    h3_2, h3_1 = self.coupling4.inverse(h4_2, h4_1)
    h2_1, h2_2 = self.coupling3.inverse(h3_1, h3_2)
    h1_2, h1_1 = self.coupling2.inverse(h2_2, h2_1)
    x1, x2 = self.coupling1.inverse(h1_1, h1_2)

    x = merge(x1, x2, self.partition_method)

    return x

In [None]:
# Odd/even split: each partition holds half of the MNIST_SIZE columns.
nice = NICE(partition_method='oddeven', partition_size=MNIST_SIZE // 2)

In [None]:
# Defun for performance boost: wrap nice.call in a graph-compiled function.
# NOTE(review): only `call` is wrapped; `generate_sample` still runs op-by-op
# in eager mode.
nice.call = tf.contrib.eager.defun(nice.call)

## Define the loss function and the optimizer

In [None]:
def negative_log_likelihood(h, prior='logistic'):
  """Batch-mean negative log-density of latents `h` under a factorized prior.

  The log-determinant of the flow is NOT included here; the caller adds it.

  Args:
    h: 2-D tensor (batch, dims) of latent vectors.
    prior: 'logistic' (standard logistic per dimension) or 'gaussian'
      (standard normal per dimension).

  Returns:
    Scalar tensor: -mean over the batch of sum over dims of log p(h).

  Raises:
    ValueError: for an unknown `prior`.
  """
  if prior == 'logistic':
    # log p(h) = -(log(1 + e^h) + log(1 + e^-h)); softplus keeps it stable.
    log_likelihood = -tf.reduce_sum(tf.math.softplus(h) + tf.math.softplus(-h), axis=1)
  elif prior == 'gaussian':
    # Include the 0.5*log(2*pi) normalizer so the value is a true log-density,
    # like the logistic branch (it does not affect gradients).
    log_likelihood = -0.5 * tf.reduce_sum(h**2 + np.log(2.0 * np.pi), axis=1)
  else:
    raise ValueError('Unknown prior: {}'.format(prior))

  return -tf.reduce_mean(log_likelihood, axis=0)

In [None]:
# RMSProp at the configured learning rate. Other optimizers that can be
# swapped in here:
#   tf.train.AdamOptimizer(learning_rate)
#   tf.train.AdamOptimizer(learning_rate, beta1=0.9, beta2=0.01, epsilon=1e-4)
#   tf.train.GradientDescentOptimizer(learning_rate)
optimizer = tf.train.RMSPropOptimizer(learning_rate=learning_rate)

## Checkpoints (Object-based saving)

In [None]:
# Checkpoints and sample images share train_dir; checkpoint files use the
# "ckpt" prefix.
checkpoint_dir = train_dir
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
if not tf.gfile.Exists(checkpoint_dir):
  # MakeDirs also creates any missing parent directories.
  tf.gfile.MakeDirs(checkpoint_dir)
# Track both the optimizer state and the model weights.
checkpoint = tf.train.Checkpoint(optimizer=optimizer, nice=nice)

## Training

In [None]:
# Latent vectors are drawn once and reused for every preview grid, so the
# samples can be compared across training.
location = 0.0 # location
# NOTE(review): the training prior in negative_log_likelihood is a unit-scale
# logistic; sampling at 0.6 presumably trades diversity for cleaner samples.
scale = 0.6 # scale
random_vector_for_generation = tfp.distributions.Logistic(
    loc=location, scale=scale).sample([num_examples_to_generate, noise_dim])

In [None]:
def _tile_images(images, num_rows, num_columns):
  """Arrange num_rows * num_columns flattened 28x28 images into one mosaic.

  Image i * num_columns + j lands in row-block i, column-block j; the result
  has shape (num_rows * 28, num_columns * 28).
  """
  tiles = images.reshape([num_rows, num_columns, 28, 28])
  # (row, col, y, x) -> (row, y, col, x) so a row-major reshape places the
  # images side by side.
  tiles = tiles.swapaxes(1, 2)
  return tiles.reshape([num_rows * 28, num_columns * 28])


def print_or_save_sample_images(sample_images, max_print_size=num_examples_to_generate,
                                is_square=False, is_save=False, epoch=None,
                                checkpoint_dir=checkpoint_dir):
  """Show generated samples as a row or a square grid, optionally saving it.

  Args:
    sample_images: NumPy array whose first rows are flattened 28x28 images.
      Values are clipped to [0, 1] for display.
    max_print_size: number of images to show, 1..25.
    is_square: if True, use a floor(sqrt(k)) x floor(sqrt(k)) grid. When
      max_print_size is not a perfect square, the extra images are dropped.
      Otherwise, show one row of max_print_size images.
    is_save: if True and `epoch` is given, save the figure as
      checkpoint_dir/image_at_epoch_XXXX.png before showing it.
    epoch: epoch number used in the saved filename.
    checkpoint_dir: directory for the saved figure.

  Raises:
    ValueError: if max_print_size is outside 1..25, or sample_images has
      fewer rows than the layout needs.
  """
  # Explicit check instead of `assert`, which is stripped under `python -O`.
  if max_print_size not in range(1, 26):
    raise ValueError(
        'max_print_size must be in 1..25, got {}'.format(max_print_size))

  # Clip to the displayable range; a per-image min-max rescale is the
  # alternative, but clipping keeps absolute intensities.
  sample_images = np.clip(sample_images, 0.0, 1.0)

  if is_square:
    num_rows = num_columns = int(np.sqrt(max_print_size))
    figsize = (num_columns, num_columns)
  else:
    num_rows, num_columns = 1, max_print_size
    figsize = (max_print_size, 1)

  num_shown = num_rows * num_columns
  if len(sample_images) < num_shown:
    raise ValueError('need at least {} sample images, got {}'.format(
        num_shown, len(sample_images)))

  mosaic = _tile_images(sample_images[:num_shown], num_rows, num_columns)

  plt.figure(figsize=figsize)
  if is_square:
    plt.subplots_adjust(left=0.0, right=1.0, top=1.0, bottom=0.0)
  plt.imshow(mosaic, cmap='gray')
  plt.axis('off')

  if is_save and epoch is not None:
    filepath = os.path.join(checkpoint_dir, 'image_at_epoch_{:04d}.png'.format(epoch))
    plt.savefig(filepath)

  plt.show()

In [None]:
tf.logging.info('Start Training.')
global_step = tf.train.get_or_create_global_step()
for epoch in range(max_epochs):

  for images in train_dataset:
    start_time = time.time()

    with tf.GradientTape() as tape:
      hidden_state, scaling = nice(images)
      nll = negative_log_likelihood(hidden_state, prior='logistic')
      # log p(x) = log p(h) + sum(scaling): the scaling is the only part of
      # the flow with a non-unit Jacobian determinant.
      ss = -tf.reduce_sum(scaling) # sum of scaling
      loss = nll + ss

    gradients = tape.gradient(loss, nice.variables)

    optimizer.apply_gradients(zip(gradients, nice.variables), global_step=global_step)

    step = global_step.numpy()
    epochs = step * batch_size / float(N)
    duration = time.time() - start_time

    if step % print_steps == 0:
      display.clear_output(wait=True)
      examples_per_sec = batch_size / float(duration)
      # Convert eager tensors to NumPy scalars before numeric formatting.
      # max_h / min_h are taken over the whole batch of latents.
      print("Epochs: {:.2f} global_step: {} loss: {:.3f}  negative log likelihood: {:.3f}  ss: {:.3f}  max_ss: {:.3f}  min_ss: {:.3f}  max_h: {:.3f}  min_h: {:.3f}  ({:.2f} examples/sec; {:.3f} sec/batch)".format(
                epochs, step, loss.numpy(), nll.numpy(), ss.numpy(),
                tf.reduce_max(nice.scaling).numpy(), tf.reduce_min(nice.scaling).numpy(),
                tf.reduce_max(hidden_state).numpy(), tf.reduce_min(hidden_state).numpy(),
                examples_per_sec, duration))
      sample_images = nice.generate_sample(random_vector_for_generation)
      print_or_save_sample_images(sample_images.numpy(), max_print_size=num_examples_to_generate, is_square=True)

  if (epoch + 1) % save_images_epochs == 0:
    display.clear_output(wait=True)
    print("This images are saved at {} epoch".format(epoch+1))
    sample_images = nice.generate_sample(random_vector_for_generation)
    print_or_save_sample_images(sample_images.numpy(), is_square=True,
                                is_save=True, epoch=epoch+1,
                                checkpoint_dir=checkpoint_dir)

  # saving (checkpoint) the model every save_model_epochs
  if (epoch + 1) % save_model_epochs == 0:
    checkpoint.save(file_prefix = checkpoint_prefix)

In [None]:
# After training: render one more grid from the fixed latent vectors and save
# it under the last epoch number the training loop reached (epoch + 1).
display.clear_output(wait=True)
sample_images = nice.generate_sample(random_vector_for_generation)
print_or_save_sample_images(
    sample_images.numpy(),
    is_square=True,
    is_save=True,
    epoch=epoch + 1,
    checkpoint_dir=checkpoint_dir)

## Restore the latest checkpoint

In [None]:
# restoring the latest checkpoint in checkpoint_dir (optimizer state and model
# weights, as tracked by `checkpoint`).
# NOTE(review): tf.train.latest_checkpoint returns None when no checkpoint has
# been written yet — confirm how restore(None) behaves in this TF version.
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

## Display an image using the epoch number

In [None]:
def display_image(epoch_no, checkpoint_dir=checkpoint_dir):
  """Open the sample grid saved for epoch `epoch_no` as a PIL image."""
  image_name = 'image_at_epoch_{:04d}.png'.format(epoch_no)
  image_path = os.path.join(checkpoint_dir, image_name)
  return PIL.Image.open(image_path)

In [None]:
# Show the grid saved for the final epoch (the training loop saves grids on
# multiples of save_images_epochs; the final-generation cell saves one too).
display_image(max_epochs)

## Generate a GIF of all the saved images.

In [None]:
import shutil

# Build an animated GIF from the saved sample grids (zero-padded epoch numbers
# make lexicographic order match epoch order).
with imageio.get_writer('nice.gif', mode='I') as writer:
  filenames = sorted(glob.glob(os.path.join(checkpoint_dir, 'image*.png')))
  last = -1
  for i, filename in enumerate(filenames):
    # Keep a frame only when round(2*sqrt(i)) increases, so early grids are
    # sampled densely and later ones sparsely.
    frame = 2*(i**0.5)
    if round(frame) > round(last):
      last = frame
    else:
      continue
    image = imageio.imread(filename)
    writer.append_data(image)
  if filenames:
    # Append the final grid once more so the GIF lingers on it. The guard
    # avoids reading an undefined loop variable when no images exist.
    image = imageio.imread(filenames[-1])
    writer.append_data(image)

# this is a hack to display the gif inside the notebook: copy it under a .png
# name for display.Image. shutil avoids spawning a shell for `cp`.
shutil.copyfile('nice.gif', 'nice.gif.png')

In [None]:
# Render the GIF (copied under a .png filename) inline in the notebook.
display.Image(filename="nice.gif.png")