# Fully Visible Sigmoid Belief Networks

## Import modules

In [None]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import time
import glob

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import PIL
import imageio
from IPython import display

import tensorflow as tf
from tensorflow.keras import layers
# Switch TensorFlow 1.x to eager (imperative) execution; the rest of the
# notebook calls `.numpy()` on tensors and iterates datasets directly.
tf.enable_eager_execution()

# Let INFO-level messages emitted through `tf.logging` be shown.
tf.logging.set_verbosity(tf.logging.INFO)

# Expose only the first CUDA device to this process.
# NOTE(review): this is set after TensorFlow has been imported; it is
# conventionally set before TensorFlow initialises its GPU devices --
# confirm that it takes effect in this environment.
os.environ["CUDA_VISIBLE_DEVICES"]="0"

## Load the MNIST dataset

In [None]:
# Fetch MNIST through tf.keras; only the training split is kept.
(train_data, train_labels), _ = tf.keras.datasets.mnist.load_data()

# Flatten every 28x28 image into a 784-dimensional float32 row.
train_data = train_data.reshape(-1, 784).astype('float32')
#train_data = train_data / 255.
# Binarise with a single global threshold: a pixel becomes 1.0 when it is
# strictly brighter than three times the mean pixel value, otherwise 0.0.
binarize_threshold = train_data.mean() * 3.0
train_data_binary = (train_data > binarize_threshold).astype(train_data.dtype)
#train_labels = np.asarray(train_labels, dtype=np.int32)

In [None]:
# Show one training digit next to its binarised version.
index = 219
print("label = {}".format(train_labels[index]))

fig = plt.figure(figsize=(4, 2))
# Left panel: original grey-scale image; right panel: thresholded image.
for panel, pixels in enumerate((train_data, train_data_binary), start=1):
  p = fig.add_subplot(1, 2, panel)
  p.imshow(pixels[index].reshape([28, 28]), cmap='gray')
  p.axis('off')
#plt.show()

## Set up dataset with `tf.data`

### Create the input pipeline with `tf.data.Dataset`

In [None]:
# Number of examples per mini-batch; used when batching the tf.data pipeline
# and when reporting epochs / throughput in the training loop.
batch_size = 32

In [None]:
# Fix TensorFlow's global seed before building the pipeline.
tf.set_random_seed(219)

# Training pipeline: one element per binarised image, shuffled with a buffer
# as large as the whole data set, then grouped into mini-batches.
N = len(train_data_binary)
train_dataset = (
    tf.data.Dataset.from_tensor_slices(train_data_binary)
    .shuffle(buffer_size=N)
    .batch(batch_size=batch_size)
)
print(train_dataset)

In [None]:
# models = []
# #model = tf.keras.Sequential([layers.InputLayer(input_shape=[1]),
# #                             layers.Dense(units=1, activation='sigmoid', use_bias=False)])
# #models.append(model)
# for i in range(1, 784):
#   model = tf.keras.Sequential([layers.InputLayer(input_shape=[i]),
#                                layers.Dense(units=1, activation='sigmoid')])
#   models.append(model)

In [None]:
class SigmoidBeliefNetwork(tf.keras.Model):
  """Fully visible sigmoid belief network over binary vectors.

  The joint distribution is factorised autoregressively,
  p(x) = prod_i p(x_i | x_0, ..., x_{i-1}), and each conditional for
  i >= 1 is a logistic regression (a single-unit Dense layer) on the
  preceding pixels. Pixel 0 has no conditional of its own: it adds nothing
  to the log-likelihood and is always 0 in generated samples.

  Args:
    num_pixels: Length of the binary vectors being modelled. Defaults to
      784, i.e. a flattened 28x28 image.
  """

  def __init__(self, num_pixels=784):
    super(SigmoidBeliefNetwork, self).__init__()
    self.num_pixels = num_pixels
    # models[i - 1] maps pixels 0..i-1 to the logit of p(x_i = 1 | x_<i).
    self.models = []
    for i in range(1, num_pixels):
      conditional = tf.keras.Sequential([layers.InputLayer(input_shape=[i]),
                                         layers.Dense(units=1)])
      self.models.append(conditional)

  def call(self, inputs):
    """Return the log-likelihood of each row of `inputs`.

    Args:
      inputs: Tensor of shape [batch, num_pixels] holding 0/1 values.

    Returns:
      Tensor of shape [batch, 1] with
      sum_{i>=1} log p(x_i | x_0, ..., x_{i-1}) for every example.
    """
    log_probability = 0.0
    for i in range(1, self.num_pixels):
      logits = self.models[i - 1](inputs[:, 0:i])
      targets = tf.cast(inputs[:, i:i + 1], logits.dtype)
      # The Bernoulli log-likelihood x*log(s) + (1-x)*log(1-s), s=sigmoid(z),
      # is the negated sigmoid cross-entropy; computing it from the logits
      # avoids log(0) when the sigmoid saturates.
      log_probability -= tf.nn.sigmoid_cross_entropy_with_logits(
          labels=targets, logits=logits)

    return log_probability

  def sampling(self, num_samples):
    """Draw samples by ancestral sampling, one pixel at a time.

    Args:
      num_samples: Number of vectors to generate.

    Returns:
      NumPy array of shape [num_samples, num_pixels] with 0/1 entries.
    """
    samples = np.zeros([num_samples, self.num_pixels])
    for i in range(1, self.num_pixels):
      context = samples[:, 0:i].astype(np.float32)
      on_probability = tf.nn.sigmoid(self.models[i - 1](context)).numpy()
      # Bernoulli draw: the pixel is on with probability sigmoid(logit).
      # Thresholding the logit instead would make every sample identical.
      samples[:, i:i + 1] = (
          np.random.uniform(size=on_probability.shape) < on_probability)

    return samples

In [None]:
# Instantiate the network; the constructor builds one single-unit Dense
# conditional per pixel index 1..783.
sbn = SigmoidBeliefNetwork()

In [None]:
# Adam optimizer (TF 1.x API) with a learning rate of 0.001.
optimizer = tf.train.AdamOptimizer(0.001)

In [None]:
# Maximise the average log-likelihood of the training batches for 10 passes
# over the data, reporting progress after every optimisation step.
tf.logging.info('Start Training.')
global_step = tf.train.get_or_create_global_step()
for epoch in range(10):

  for images in train_dataset:
    start_time = time.time()

    # Record the forward pass so the loss can be differentiated.
    with tf.GradientTape() as tape:
      log_likelihood = sbn(images)
      loss = -tf.reduce_mean(log_likelihood)

    gradients = tape.gradient(loss, sbn.variables)
    # apply_gradients also increments global_step.
    optimizer.apply_gradients(zip(gradients, sbn.variables), global_step=global_step)

    # Read the step counter once and reuse it for every statistic below.
    step = global_step.numpy()
    epochs = step * batch_size / float(N)
    duration = time.time() - start_time

    print_steps = 1
    if step % print_steps != 0:
      continue

    # Replace the previous progress line instead of appending a new one.
    display.clear_output(wait=True)
    examples_per_sec = batch_size / float(duration)
    print("Epochs: {:.2f} global_step: {} loss: {:.3f} ({:.2f} examples/sec; {:.3f} sec/batch)".format(
              epochs, step, loss.numpy(), examples_per_sec, duration))

In [None]:
# Generate two images with the network's `sampling` method; the result is a
# NumPy array with one 784-pixel row per image.
samples = sbn.sampling(2)

In [None]:
# Display the two generated images side by side.
fig = plt.figure(figsize=(4, 2))
for panel, generated in enumerate(samples[:2], start=1):
  p = fig.add_subplot(1, 2, panel)
  p.imshow(generated.reshape([28, 28]), cmap='gray')
  p.axis('off')
#plt.show()