<h1><center>Gated Pixel CNN</center></h1>

[The original ***Pixel Recurrent Neural Networks*** paper](https://arxiv.org/abs/1601.06759)

[The original ***Conditional Image Generation with PixelCNN Decoders*** paper](https://arxiv.org/abs/1606.05328)

[PixelCNN](http://sergeiturukin.com/2017/02/22/pixelcnn.html)

[Gated PixelCNN](http://sergeiturukin.com/2017/02/24/gated-pixelcnn.html)

In [1]:
import tensorflow as tf
import tensorflow.keras as K
import chainer.links as L

import numpy as np
import os
import scipy.misc
from datetime import datetime
import tensorflow as tf

# Blind Spot of PixelCNN

What information does layer ***L*** get from layer ***L-2***? We are particularly interested in the rightmost pixels, because they are the ones that extend the receptive field. Due to masking, those pixels do not have access to some of the pixels that layer ***L*** is interested in. Following this logic through the stack, it is clear that layer ***L does not see*** some pixels: this is the blind spot.

In [2]:
# Blind-spot demonstration: push a 5x5 input through two masked 3x3
# convolutions and show that changing the pixel at [1, 4] does not change
# the result.

# create 5x5 input
# chainer requires input to have shape [BATCH, CHANNELS, HEIGHT, WIDTH]
# NOTE: the name `input` shadows the Python builtin for the rest of the notebook.
input = np.arange(25).reshape([1,1,5,5]).astype('f')
print('input: ', input)
# array([[[[    0.,     1.,     2.,     3.,     4.],
#          [    5.,     6.,     7.,     8.,     9.],
#          [   10.,    11.,    12.,    13.,    14.],
#          [   15.,    16.,    17.,    18.,    19.],
#          [   20.,    21.,    22.,    23.,    24.]]]], dtype=float32)

# create a kernel of ones so the convolution just sums all values under it
# ones are used for simplicity: the result is easy to check by hand
kernel = np.ones([3, 3])
# turn it into a proper type 'A' mask:
# zero the row below the centre, then the centre and everything right of it
kernel[2:, :] = 0.0
kernel[1, 1:] = 0.0
print('kernel: ', kernel)
# array([[ 1.,  1.,  1.],
#        [ 1.,  0.,  0.],
#        [ 0.,  0.,  0.]])

# create two convolution layers with total receptive field size 5x5
# so our input is an exact fit

l1 = L.Convolution2D(1, 1, ksize=3, initialW=kernel)
l2 = L.Convolution2D(1, 1, ksize=3, initialW=kernel)
# c1 = K.layers.Conv2D(filters=1, kernel_size=3, padding='SAME', strides=(1, 1), kernel_initializer=lambda x: kernel)
# c2 = K.layers.Conv2D(filters=1, kernel_size=3, padding='SAME', strides=(1, 1), kernel_initializer=lambda x: kernel)

# here is the trick: the pixel at position [1, 4] will be inside the blind spot
# if we perform the convolution, its value won't be included in the final sum
# so let's make its value large so that this is easy to check
input[:, :, 1, 4] = 1000
print('input modified: ', input)
# array([[[[    0.,     1.,     2.,     3.,     4.],
#          [    5.,     6.,     7.,     8.,  1000.],
#          [   10.,    11.,    12.,    13.,    14.],
#          [   15.,    16.,    17.,    18.,    19.],
#          [   20.,    21.,    22.,    23.,    24.]]]], dtype=float32)

output = l2(l1(input)).data
print('output: ', output)
# array([[[[ 64.]]]], dtype=float32)
# Voila! The sum is less than 1000, which means the pixel at [1, 4] wasn't seen!

# Now restore its original value
input[:, :, 1, 4] = 9
print('input initial: ', input)
# array([[[[    0.,     1.,     2.,     3.,     4.],
#          [    5.,     6.,     7.,     8.,     9.],
#          [   10.,    11.,    12.,    13.,    14.],
#          [   15.,    16.,    17.,    18.,    19.],
#          [   20.,    21.,    22.,    23.,    24.]]]], dtype=float32)

# perform the computation again..
output = l2(l1(input)).data
print('output: ', output)
# array([[[[ 64.]]]], dtype=float32)
# More evidence: no matter what value we assign to that pixel, the final sum doesn't change.
# That proves it is within the blind spot and we can't access the information in it.

input:  [[[[ 0.  1.  2.  3.  4.]
   [ 5.  6.  7.  8.  9.]
   [10. 11. 12. 13. 14.]
   [15. 16. 17. 18. 19.]
   [20. 21. 22. 23. 24.]]]]
kernel:  [[1. 1. 1.]
 [1. 0. 0.]
 [0. 0. 0.]]
input modified:  [[[[   0.    1.    2.    3.    4.]
   [   5.    6.    7.    8. 1000.]
   [  10.   11.   12.   13.   14.]
   [  15.   16.   17.   18.   19.]
   [  20.   21.   22.   23.   24.]]]]
output:  [[[[64.]]]]
input initial:  [[[[ 0.  1.  2.  3.  4.]
   [ 5.  6.  7.  8.  9.]
   [10. 11. 12. 13. 14.]
   [15. 16. 17. 18. 19.]
   [20. 21. 22. 23. 24.]]]]
output:  [[[[64.]]]]


# Useful functions

In [3]:
def _build_mask(shape, horizontal, mask_mode, mask):
    """Build the 0/1 float32 array that hides subsequent pixel values.

    `shape` is indexed as [rows, cols, :, :], so it must have four entries.
    The centre of the filter is (shape[0] // 2, shape[1] // 2).

    * mask_mode 'noblind', vertical stack (horizontal is falsy):
      mask 'a' zeroes the centre row and every row below it; any other
      mask value zeroes only the rows below the centre row.
    * every other combination (horizontal stack, or any other mask_mode):
      rows below the centre are zeroed, as are the columns to the right of
      the centre within the centre row.
    * in all cases mask 'a' additionally zeroes the centre tap itself.
    """
    mid_y = shape[0] // 2
    mid_x = shape[1] // 2
    mask_filter = np.ones(shape, dtype=np.float32)

    if mask_mode == 'noblind' and not horizontal:
        # In the first layer ('a') the vertical stack can ONLY access rows
        # above the current one; in later layers it may also use the
        # current row.
        first_hidden_row = mid_y if mask == 'a' else mid_y + 1
        mask_filter[first_hidden_row:, :, :, :] = 0.0
    else:
        # All rows after the centre must be zero ...
        mask_filter[mid_y + 1:, :, :, :] = 0.0
        # ... and so must all columns after the centre in the centre row.
        mask_filter[mid_y, mid_x + 1:, :, :] = 0.0

    if mask == 'a':
        # The centre tap must be zero in the first layer.
        mask_filter[mid_y, mid_x, :, :] = 0.0

    return mask_filter


def get_weights(shape, name, horizontal=False, mask_mode='noblind', mask=None):
    """Return a Xavier-initialised float32 variable, optionally masked.

    Args:
        shape: shape of the variable; must be 4-D whenever `mask` is given.
        name: name passed to `tf.get_variable`.
        horizontal: whether the weights belong to the horizontal stack.
            Only consulted when `mask` is given and `mask_mode` is
            'noblind'. Defaults to False so that callers creating unmasked
            weights can omit it.
        mask_mode: 'noblind' selects the separate vertical/horizontal
            masking scheme; any other value selects the single standard
            mask.
        mask: falsy for no masking, 'a' for a first-layer mask (centre tap
            hidden); any other truthy value keeps the centre tap visible.

    Returns:
        The variable itself when `mask` is falsy, otherwise the tensor
        obtained by multiplying the variable by the mask.
    """
    weights_initializer = tf.contrib.layers.xavier_initializer()
    W = tf.get_variable(name, shape, tf.float32, weights_initializer)

    if mask:
        W *= _build_mask(shape, horizontal, mask_mode, mask)
    return W

def get_bias(shape, name):
    """Return the float32 `tf.get_variable` named `name`, initialised with zeros."""
    zeros = tf.zeros_initializer
    return tf.get_variable(name, shape, dtype=tf.float32, initializer=zeros)

def conv_op(x, W):
    """Convolve `x` with filter `W` using unit strides and 'SAME' padding."""
    unit_strides = [1, 1, 1, 1]
    convolved = tf.nn.conv2d(x, W, strides=unit_strides, padding='SAME')
    return convolved

def max_pool_2x2(x):
    """Max-pool `x` with a 2x2 window, stride 2 and 'SAME' padding."""
    window = [1, 2, 2, 1]
    step = [1, 2, 2, 1]
    pooled = tf.nn.max_pool(x, ksize=window, strides=step, padding='SAME')
    return pooled

class GatedCNN():
    """One convolutional layer of the PixelCNN, either gated or plain.

    The layer is built in the constructor; the resulting tensor is stored
    in `self.fan_out` and returned by `output()`.

    Args:
        W_shape: [filter_rows, filter_cols, output_channels]. The input
            channel count is read from the last dimension of `fan_in`.
        fan_in: input tensor of the layer.
        horizontal: whether this layer belongs to the horizontal stack
            (forwarded to `get_weights` to choose the mask).
        gated: build the tanh/sigmoid gated unit when True, otherwise a
            plain convolution.
        payload: optional tensor added to both gated convolutions before
            the non-linearities (gated layers only).
        mask: None, 'a' or 'b'; forwarded to `get_weights`.
        activation: apply ReLU in the plain convolution when True.
        conditional: optional 2-D tensor [batch, features] that is
            projected to a per-sample bias (gated layers only).
        conditional_image: optional image-like tensor that is projected by
            a 1x1 convolution and added to the bias (gated layers only).
    """

    def __init__(self, W_shape, fan_in, horizontal, gated=True, payload=None, mask=None, activation=True, conditional=None, conditional_image=None):
        self.fan_in = fan_in
        in_dim = self.fan_in.get_shape()[-1]
        # Full filter shape: [rows, cols, input channels, output channels].
        self.W_shape = [W_shape[0], W_shape[1], in_dim, W_shape[2]]
        self.b_shape = W_shape[2]

        self.in_dim = in_dim
        self.payload = payload
        self.mask = mask
        self.activation = activation
        self.conditional = conditional
        self.conditional_image = conditional_image
        self.horizontal = horizontal

        if gated:
            self.gated_conv()
        else:
            self.simple_conv()

    def gated_conv(self):
        """Build tanh(conv_f + b_f) * sigmoid(conv_g + b_g) into `self.fan_out`."""
        out_dim = self.W_shape[3]

        W_f = get_weights(self.W_shape, "v_W", self.horizontal, mask=self.mask)
        W_g = get_weights(self.W_shape, "h_W", self.horizontal, mask=self.mask)

        b_f_total = get_bias(self.b_shape, "v_b")
        b_g_total = get_bias(self.b_shape, "h_b")

        if self.conditional is not None:
            # Project the conditioning vector to one bias per output channel.
            h_shape = int(self.conditional.get_shape()[1])
            V_f = get_weights([h_shape, out_dim], "v_V", self.horizontal)
            b_f = tf.matmul(self.conditional, V_f)
            V_g = get_weights([h_shape, out_dim], "h_V", self.horizontal)
            b_g = tf.matmul(self.conditional, V_g)

            # Reshape [batch, channels] -> [batch, 1, 1, channels] so the
            # bias broadcasts over both spatial dimensions.
            b_f_shape = tf.shape(b_f)
            b_f = tf.reshape(b_f, (b_f_shape[0], 1, 1, b_f_shape[1]))
            b_g_shape = tf.shape(b_g)
            b_g = tf.reshape(b_g, (b_g_shape[0], 1, 1, b_g_shape[1]))

            b_f_total = b_f_total + b_f
            b_g_total = b_g_total + b_g

        if self.conditional_image is not None:
            # The projection is added to the convolution output, so it must
            # produce one map per OUTPUT channel. Using the input channel
            # count only broadcasts when it is 1 or equals the output width.
            b_f_total = b_f_total + tf.layers.conv2d(self.conditional_image, out_dim, 1, use_bias=False, name="ci_f")
            b_g_total = b_g_total + tf.layers.conv2d(self.conditional_image, out_dim, 1, use_bias=False, name="ci_g")

        conv_f = conv_op(self.fan_in, W_f)
        conv_g = conv_op(self.fan_in, W_g)

        if self.payload is not None:
            conv_f += self.payload
            conv_g += self.payload

        self.fan_out = tf.multiply(tf.tanh(conv_f + b_f_total), tf.sigmoid(conv_g + b_g_total))

    def simple_conv(self):
        """Build a plain (optionally ReLU-activated) convolution into `self.fan_out`."""
        W = get_weights(self.W_shape, "W", self.horizontal, mask_mode="standard", mask=self.mask)
        b = get_bias(self.b_shape, "b")
        conv = conv_op(self.fan_in, W)
        pre_activation = tf.add(conv, b)
        if self.activation:
            self.fan_out = tf.nn.relu(pre_activation)
        else:
            self.fan_out = pre_activation

    def output(self):
        """Return the output tensor built by the constructor."""
        return self.fan_out

# Models

In [5]:
class PixelCNN(object):
    """Gated PixelCNN graph with a vertical and a horizontal stack.

    Attributes set by the constructor:
        X: the input tensor.
        h: conditioning tensor (given, or a new placeholder of shape
            [None, conf.num_classes]) when `conf.conditional` is True,
            otherwise None.
        fc2: output logits.
        loss: scalar training loss.
        pred: per-pixel prediction (sigmoid probabilities for MNIST,
            arg-max intensity otherwise).
        pred_sampling: non-MNIST only; intensities sampled from the
            predicted distribution.

    Args:
        X: input image tensor.
        conf: configuration object; reads `data`, `conditional`, `layers`,
            `f_map` and, depending on the branch, `num_classes` and
            `channel`.
        full_horizontal: use a full-height filter in the horizontal stack
            when True, a 1-row filter otherwise.
        h: optional pre-built conditioning tensor.
    """

    def __init__(self, X, conf, full_horizontal=True, h=None):
        self.X = X
        # Image normalisation for CIFAR-10 was supposed to be done here but
        # was never implemented, so the input is used unchanged for every
        # dataset.
        self.X_norm = X
        v_stack_in, h_stack_in = self.X_norm, self.X_norm

        if conf.conditional is True:
            if h is not None:
                self.h = h
            else:
                self.h = tf.placeholder(tf.float32, shape=[None, conf.num_classes])
        else:
            self.h = None

        for i in range(conf.layers):
            first_layer = i == 0
            # The first layer uses a large 7x7 type-'a' filter and has no
            # residual connection; later layers use 3x3 type-'b' filters.
            filter_size = 7 if first_layer else 3
            mask = 'a' if first_layer else 'b'
            residual = not first_layer
            idx = str(i)

            with tf.variable_scope("v_stack"+idx):
                v_stack = GatedCNN([filter_size, filter_size, conf.f_map], v_stack_in, False, mask=mask, conditional=self.h).output()
                v_stack_in = v_stack

            # 1x1 convolution feeding the vertical stack into the horizontal one.
            with tf.variable_scope("v_stack_1"+idx):
                v_stack_1 = GatedCNN([1, 1, conf.f_map], v_stack_in, False, gated=False, mask=None).output()

            with tf.variable_scope("h_stack"+idx):
                h_stack = GatedCNN([filter_size if full_horizontal else 1, filter_size, conf.f_map], h_stack_in, True, payload=v_stack_1, mask=mask, conditional=self.h).output()

            with tf.variable_scope("h_stack_1"+idx):
                h_stack_1 = GatedCNN([1, 1, conf.f_map], h_stack, True, gated=False, mask=None).output()
                if residual:
                    h_stack_1 += h_stack_in # Residual connection
                h_stack_in = h_stack_1

        with tf.variable_scope("fc_1"):
            fc1 = GatedCNN([1, 1, conf.f_map], h_stack_in, True, gated=False, mask='b').output()

        if conf.data == "mnist":
            with tf.variable_scope("fc_2"):
                self.fc2 = GatedCNN([1, 1, 1], fc1, True, gated=False, mask='b', activation=False).output()
            self.loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.fc2, labels=self.X))
            self.pred = tf.nn.sigmoid(self.fc2)
        else:
            color_dim = 256
            with tf.variable_scope("fc_2"):
                self.fc2 = GatedCNN([1, 1, conf.channel * color_dim], fc1, True, gated=False, mask='b', activation=False).output()
                # One row of `color_dim` logits per (sample, pixel, channel).
                self.fc2 = tf.reshape(self.fc2, (-1, color_dim))

            # TF 1.x requires `labels`/`logits` to be passed by keyword;
            # positional arguments raise a ValueError.
            flat_labels = tf.cast(tf.reshape(self.X, [-1]), dtype=tf.int32)
            self.loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=flat_labels, logits=self.fc2))

            # Since this code was not run on CIFAR-10, it is not certain
            # which of the two outputs below is the more suitable way to
            # generate 3-channel images.
            #
            # `tf.multinomial` expects unnormalised log-probabilities, so
            # it is fed the logits directly rather than softmax output
            # (which would be soft-maxed a second time).
            self.pred_sampling = tf.reshape(tf.multinomial(self.fc2, num_samples=1, seed=100), tf.shape(self.X))
            # arg-max of the logits equals arg-max of their softmax; fc2 is
            # 2-D after the reshape above, so the class axis is axis 1.
            self.pred = tf.reshape(tf.argmax(self.fc2, axis=1), tf.shape(self.X))


class ConvolutionalEncoder(object):
    """Encoder half of a convolutional auto-encoder.

    This is the 6-layer architecture for the Convolutional Autoencoder
    mentioned in the original paper:
    Stacked Convolutional Auto-Encoders for Hierarchical Feature Extraction.
    Only the encoder part is implemented, as PixelCNN is taken as the
    decoder.

    The result, a softmax over 10 units, is stored in `self.pred`.

    Args:
        X: input image tensor with `conf.channel` channels.
        conf: configuration object; only `conf.channel` is read.
    """

    def __init__(self, X, conf):
        # `get_weights` takes the stack orientation as its third argument.
        # It is irrelevant for unmasked weights but must be supplied.
        W_conv1 = get_weights([5, 5, conf.channel, 100], "W_conv1", False)
        b_conv1 = get_bias([100], "b_conv1")
        conv1 = tf.nn.relu(conv_op(X, W_conv1) + b_conv1)
        pool1 = max_pool_2x2(conv1)

        W_conv2 = get_weights([5, 5, 100, 150], "W_conv2", False)
        b_conv2 = get_bias([150], "b_conv2")
        conv2 = tf.nn.relu(conv_op(pool1, W_conv2) + b_conv2)
        pool2 = max_pool_2x2(conv2)

        W_conv3 = get_weights([3, 3, 150, 200], "W_conv3", False)
        b_conv3 = get_bias([200], "b_conv3")
        conv3 = tf.nn.relu(conv_op(pool2, W_conv3) + b_conv3)

        # Flatten using the statically known feature-map size so inputs
        # other than 28x28 work too. Fall back to the original 7*7*200
        # (28x28 input pooled twice) when the static shape is not known.
        static_dims = conv3.get_shape().as_list()[1:]
        if None in static_dims:
            flat_dim = 7*7*200
        else:
            flat_dim = static_dims[0] * static_dims[1] * static_dims[2]
        conv3_reshape = tf.reshape(conv3, (-1, flat_dim))

        W_fc = get_weights([flat_dim, 10], "W_fc", False)
        b_fc = get_bias([10], "b_fc")
        self.pred = tf.nn.softmax(tf.add(tf.matmul(conv3_reshape, W_fc), b_fc))

# Utility

In [6]:
def binarize(images):
    """Stochastically binarise `images`.

    Each element becomes 1.0 when a fresh uniform [0, 1) draw is strictly
    below it, and 0.0 otherwise. The result is a float32 array with the
    same shape as `images`.
    """
    noise = np.random.uniform(size=images.shape)
    is_on = np.less(noise, images)
    return is_on.astype(np.float32)

def generate_samples(sess, X, h, pred, conf, suff):
    """Sample a 10x10 grid of images pixel by pixel and save it.

    Starting from all-zero images, the model is run once per
    (row, column, channel) position and only that position is copied from
    the prediction into the samples, so every pixel is conditioned on the
    pixels generated before it.

    Args:
        sess: session used to evaluate `pred`.
        X: input placeholder fed with the samples generated so far.
        h: conditioning placeholder; only fed when `conf.conditional` is True.
        pred: prediction tensor to evaluate.
        conf: configuration; reads `img_height`, `img_width`, `channel`,
            `data`, `conditional` and (when conditional) `num_classes`.
        suff: suffix appended to the saved file name.
    """
    print("Generating Sample Images...")
    n_row, n_col = 10,10
    n_samples = n_row * n_col
    samples = np.zeros((n_samples, conf.img_height, conf.img_width, conf.channel), dtype=np.float32)

    # `samples` is updated in place below, so the feed dict can be built
    # once: every run sees the current contents of the array.
    data_dict = {X: samples}
    if conf.conditional is True:
        # Cycle through the classes 0, 1, ..., num_classes-1, 0, 1, ...
        # (identical to [0..9]*10 for ten classes).
        class_ids = np.arange(n_samples) % conf.num_classes
        data_dict[h] = one_hot(class_ids, conf.num_classes)

    is_mnist = conf.data == "mnist"
    for i in range(conf.img_height):
        for j in range(conf.img_width):
            for k in range(conf.channel):
                next_sample = sess.run(pred, feed_dict=data_dict)
                if is_mnist:
                    next_sample = binarize(next_sample)
                samples[:, i, j, k] = next_sample[:, i, j, k]

    save_images(samples, n_row, n_col, conf, suff)


def generate_ae(sess, encoder_X, decoder_X, y, data, conf, suff=''):
    """Sample a 10x10 grid of images from the auto-encoder and save it.

    A batch of real images is fed to the encoder input while the decoder
    input is filled in pixel by pixel from the tensor `y`.

    Args:
        sess: session used to evaluate `y`.
        encoder_X: encoder input placeholder, fed with real images.
        decoder_X: decoder input placeholder, fed with the samples so far.
        y: decoder prediction tensor.
        data: for MNIST an object exposing `train.next_batch`, otherwise an
            array that `get_batch` can slice.
        conf: configuration; reads `img_height`, `img_width`, `channel`
            and `data`.
        suff: suffix appended to the saved file name.
    """
    print("Generating Sample Images...")
    n_row, n_col = 10,10
    n_samples = n_row * n_col
    samples = np.zeros((n_samples, conf.img_height, conf.img_width, conf.channel), dtype=np.float32)
    is_mnist = conf.data == 'mnist'

    if is_mnist:
        raw = data.train.next_batch(n_samples)[0]
        labels = binarize(raw.reshape(n_samples, conf.img_height, conf.img_width, conf.channel))
    else:
        # get_batch returns [batch, next_pointer]; only the batch is an image array.
        labels, _ = get_batch(data, 0, n_samples)

    # `samples` is updated in place, so one feed dict serves every step.
    feed = {encoder_X: labels, decoder_X: samples}
    for i in range(conf.img_height):
        for j in range(conf.img_width):
            for k in range(conf.channel):
                next_sample = sess.run(y, feed)
                if is_mnist:
                    next_sample = binarize(next_sample)
                samples[:, i, j, k] = next_sample[:, i, j, k]

    save_images(samples, n_row, n_col, conf, suff)


def save_images(samples, n_row, n_col, conf, suff):
    """Tile `samples` into one image and save it as a time-stamped JPEG.

    `samples` holds n_row * n_col images. Image number r * n_col + c is
    drawn at block-row c and block-column r, so the tiled picture is
    (img_height * n_col) pixels high and (img_width * n_row) pixels wide.
    The file is written to `conf.samples_path` under the name
    "<YYYY_MM_DD_HH_MM><suff>.jpg".

    Args:
        samples: array of images to tile.
        n_row, n_col: grid dimensions used to split the leading axis.
        conf: configuration; reads `data`, `img_height`, `img_width`,
            `channel` and `samples_path`.
        suff: suffix appended to the file name.
    """
    # After the transpose the block axes are (n_col, height, n_row, width),
    # so these are the tiled dimensions. The original reshape used
    # (height * n_row, width * n_col), which is only the same for square grids.
    grid_h = conf.img_height * n_col
    grid_w = conf.img_width * n_row

    if conf.data == "mnist":
        images = samples.reshape((n_row, n_col, conf.img_height, conf.img_width))
        images = images.transpose(1, 2, 0, 3)
        images = images.reshape((grid_h, grid_w))
    else:
        images = samples.reshape((n_row, n_col, conf.img_height, conf.img_width, conf.channel))
        images = images.transpose(1, 2, 0, 3, 4)
        images = images.reshape((grid_h, grid_w, conf.channel))

    filename = datetime.now().strftime('%Y_%m_%d_%H_%M')+suff+".jpg"
    # NOTE(review): scipy.misc.toimage was removed in SciPy 1.2; this call
    # needs an older SciPy (with PIL installed) or a port to another image
    # writer.
    # NOTE(review): cmax=1.0 matches binarised MNIST samples; non-MNIST
    # predictions look like 0..255 class indices and would saturate -
    # confirm the intended range.
    scipy.misc.toimage(images, cmin=0.0, cmax=1.0).save(os.path.join(conf.samples_path, filename))


def get_batch(data, pointer, batch_size):
    """Return the `pointer`-th batch of `data` and the next pointer.

    Batches are consecutive, non-overlapping slices of `batch_size` rows.
    When the requested batch would run past the end of `data`, the pointer
    wraps around to the first batch, so only full batches are returned
    (unless `data` itself is shorter than one batch).

    Args:
        data: array-like supporting `.shape[0]` and slicing.
        pointer: zero-based index of the batch to fetch.
        batch_size: number of rows per batch.

    Returns:
        A two-element list `[batch, next_pointer]`.
    """
    # The batch covers rows [batch_size * pointer, batch_size * (pointer + 1)).
    # Wrap only when that range does not fit. The previous test,
    # (batch_size + 1) * pointer >= len, wrapped too early and skipped the
    # tail of the data.
    if batch_size * (pointer + 1) > data.shape[0]:
        pointer = 0
    start = batch_size * pointer
    batch = data[start : start + batch_size]
    return [batch, pointer + 1]


def one_hot(batch_y, num_classes):
    """Encode the class indices in `batch_y` as one-hot rows.

    Returns an array of shape (len(batch_y), num_classes) that is zero
    everywhere except for a 1 at column batch_y[i] of row i.
    """
    n_items = batch_y.shape[0]
    encoded = np.zeros((n_items, num_classes))
    row_ids = np.arange(n_items)
    encoded[row_ids, batch_y] = 1
    return encoded


def makepaths(conf):
    """Create the checkpoint, sample and summary directories for a run.

    Side effects on `conf`:
      * `conf.ckpt_file` is set to "model.ckpt" inside a run-specific
        sub-directory of `conf.ckpt_path`;
      * `conf.samples_path` is replaced by a run-specific sub-directory of
        its previous value.
    Any existing `conf.summary_path` directory is deleted and recreated
    empty. Returns the same `conf` object.
    """
    def _ensure_dir(path):
        # Create the directory (and parents) unless the path already exists.
        if not os.path.exists(path):
            os.makedirs(path)

    ckpt_tag = "data=%s_bs=%d_layers=%d_fmap=%d"%(conf.data, conf.batch_size, conf.layers, conf.f_map)
    ckpt_dir = os.path.join(conf.ckpt_path, ckpt_tag)
    _ensure_dir(ckpt_dir)
    conf.ckpt_file = os.path.join(ckpt_dir, "model.ckpt")

    samples_tag = "epoch=%d_bs=%d_layers=%d_fmap=%d"%(conf.epochs, conf.batch_size, conf.layers, conf.f_map)
    conf.samples_path = os.path.join(conf.samples_path, samples_tag)
    _ensure_dir(conf.samples_path)

    # Start from an empty summary directory every time.
    summary_dir = conf.summary_path
    if tf.gfile.Exists(summary_dir):
        tf.gfile.DeleteRecursively(summary_dir)
    tf.gfile.MakeDirs(summary_dir)

    return conf

# Train

In [None]:
def train(conf, data):
    """Build a PixelCNN, train it, and periodically save and sample from it.

    Gradients are clipped element-wise to [-conf.grad_clip, conf.grad_clip]
    and applied with RMSProp (learning rate 1e-3). Every 10th epoch the
    trainable variables are checkpointed and a sample grid is generated; a
    final grid is generated after training.

    Args:
        conf: configuration; reads `img_height`, `img_width`, `channel`,
            `grad_clip`, `ckpt_file`, `epochs`, `num_batches`, `data`,
            `batch_size`, `conditional` and (for conditional MNIST)
            `num_classes`.
        data: for MNIST an object exposing `train.next_batch`, otherwise an
            array that `get_batch` can slice.

    Raises:
        NotImplementedError: for conditional training on non-MNIST data,
            because no labels are available on that path.
    """
    is_mnist = conf.data == "mnist"
    if conf.conditional is True and not is_mnist:
        # The non-MNIST branch below never produces labels, so the original
        # code failed with an unbound `batch_y`; fail early and clearly.
        raise NotImplementedError("conditional training is only supported for MNIST data")

    X = tf.placeholder(tf.float32, shape=[None, conf.img_height, conf.img_width, conf.channel])
    model = PixelCNN(X, conf)

    trainer = tf.train.RMSPropOptimizer(1e-3)
    gradients = trainer.compute_gradients(model.loss)

    # compute_gradients yields (None, var) for variables the loss does not
    # depend on; those pairs cannot be clipped and are skipped.
    clipped_gradients = [
        (tf.clip_by_value(grad, -conf.grad_clip, conf.grad_clip), var)
        for grad, var in gradients
        if grad is not None
    ]
    optimizer = trainer.apply_gradients(clipped_gradients)

    saver = tf.train.Saver(tf.trainable_variables())

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # `ckpt_file` is a checkpoint *prefix*: the saver writes
        # "<prefix>.index", "<prefix>.data-*", ... so a plain
        # os.path.exists() on the prefix never matches.
        if tf.train.checkpoint_exists(conf.ckpt_file):
            saver.restore(sess, conf.ckpt_file)
            print("Model Restored")

        if conf.epochs > 0:
            print("Started Model Training...")
        pointer = 0
        # Keeps the epoch report valid even when there are no batches.
        cost = float('nan')
        for i in range(conf.epochs):
            for j in range(conf.num_batches):
                if is_mnist:
                    batch_X, batch_y = data.train.next_batch(conf.batch_size)
                    batch_X = binarize(batch_X.reshape(
                        [conf.batch_size, conf.img_height, conf.img_width, conf.channel]))
                    batch_y = one_hot(batch_y, conf.num_classes)
                else:
                    batch_X, pointer = get_batch(data, pointer, conf.batch_size)
                data_dict = {X:batch_X}
                if conf.conditional is True:
                    data_dict[model.h] = batch_y
                _, cost = sess.run([optimizer, model.loss], feed_dict=data_dict)
            print("Epoch: %d, Cost: %f"%(i, cost))
            if (i+1)%10 == 0:
                saver.save(sess, conf.ckpt_file)
                generate_samples(sess, X, model.h, model.pred, conf, "")

        generate_samples(sess, X, model.h, model.pred, conf, "")

# Script entry point: parse the command line, load the chosen dataset,
# prepare the output directories and start the selected training mode.
if __name__ == "__main__":
    # argparse is not imported at the top of the file, so import it here.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--data', type=str, default='mnist')
    parser.add_argument('--layers', type=int, default=12)
    parser.add_argument('--f_map', type=int, default=32)
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--batch_size', type=int, default=100)
    parser.add_argument('--grad_clip', type=int, default=1)
    parser.add_argument('--model', type=str, default='')
    parser.add_argument('--data_path', type=str, default='data')
    parser.add_argument('--ckpt_path', type=str, default='ckpts')
    parser.add_argument('--samples_path', type=str, default='samples')
    parser.add_argument('--summary_path', type=str, default='logs')
    conf = parser.parse_args()

    if conf.data == 'mnist':
        from tensorflow.examples.tutorials.mnist import input_data
        if not os.path.exists(conf.data_path):
            os.makedirs(conf.data_path)
        data = input_data.read_data_sets(conf.data_path)
        conf.num_classes = 10
        conf.img_height = 28
        conf.img_width = 28
        conf.channel = 1
        conf.num_batches = data.train.num_examples // conf.batch_size
    else:
        from keras.datasets import cifar10
        data = cifar10.load_data()
        labels = data[0][1]
        data = data[0][0].astype(np.float32)
        # The rest of the file works on channels-last images. Only
        # transpose when the loader returned channels-first data (axis 1
        # has length 3); channels-last data is left as it is.
        if data.shape[1] == 3:
            data = np.transpose(data, (0, 2, 3, 1))
        # Subtract the per-channel mean.
        # NOTE(review): PixelCNN's non-MNIST loss casts these pixel values
        # to int32 class labels for a 256-way softmax; mean-centred values
        # are not valid labels - confirm the intended preprocessing.
        for c in range(3):
            data[..., c] -= np.mean(data[..., c])
        conf.img_height = 32
        conf.img_width = 32
        conf.channel = 3
        conf.num_classes = 10
        conf.num_batches = data.shape[0] // conf.batch_size

    conf = makepaths(conf)
    model_kind = conf.model.lower()
    if conf.model == '':
        conf.conditional = False
        train(conf, data)
    elif model_kind == 'conditional':
        conf.conditional = True
        train(conf, data)
    elif model_kind == 'autoencoder':
        conf.conditional = True
        # trainAE is not defined in the visible part of this file; it is
        # expected to be provided elsewhere.
        trainAE(conf, data)