# PixelCNN

* PixelCNN 은 PixelRNN 논문에서 처음 소개됨. 즉 따로 PixelCNN 논문이 있는 것은 아님.  
* 해당 논문의 이름처럼 그 논문에서 최고 성능은 diagonal BiLSTM PixelRNN 이었으나, PixelCNN 의 낮은 computational cost 때문에, 즉 가성비 때문에 그 이후로 PixelCNN 을 발전시킨 conditional PixelCNN, PixelCNN++ 등이 제안됨.

### References

* https://tensorflow.blog/2016/11/29/pixelcnn-1601-06759-summary/
    * https://github.com/rickiepark/pixel-rnn-tensorflow/blob/pixel-cnn/pixel-cnn.py
* https://github.com/carpedm20/pixel-rnn-tensorflow

개선참조 (위 참조를 통해 구현 후 아래 참조를 통해 개선함)

* https://github.com/igul222/pixel_rnn

In [1]:
import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data

In [2]:
# Read the MNIST data sets rooted at 'MNIST_data/'; later cells use
# mnist.train and mnist.test.
mnist = input_data.read_data_sets('MNIST_data/')

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz


In [3]:
def mask_conv(inputs, filters, mask_type, kernel_size=(3, 3), strides=1, name='mask-conv', activ_fn=None):
    """Masked 2-D convolution as used by PixelCNN.

    The kernel is multiplied element-wise by a fixed binary mask so that an
    output position only reads input positions in the rows above it, or to
    its left in the same row. Mask type 'A' hides the center position as
    well; mask type 'B' keeps it visible.

    Args:
        inputs: 4-D tensor; the last axis is taken as the channel axis
            (NHWC layout is assumed).
        filters: number of output channels.
        mask_type: 'A' or 'B'.
        kernel_size: (height, width) of the kernel; both must be odd.
        strides: stride applied to both spatial axes.
        name: variable scope holding the "weight" and "bias" variables.
        activ_fn: optional callable applied to the biased output.

    Returns:
        The convolution output after the bias and the optional activation.

    Raises:
        ValueError: if mask_type is neither 'A' nor 'B'.
    """
    # Anything other than 'A' used to fall through as 'B'; a typo such as 'a'
    # on the first layer would then silently expose the current pixel.
    if mask_type not in ('A', 'B'):
        raise ValueError("mask_type should be 'A' or 'B', got {!r}".format(mask_type))

    with tf.variable_scope(name):
        kernel_h, kernel_w = kernel_size
        assert kernel_h % 2 == 1 and kernel_w % 2 == 1, "kernel height and width should be odd number"
        center_h = kernel_h // 2
        center_w = kernel_w // 2

        # tf.nn.conv2d expects a kernel shaped
        # (kernel_height, kernel_width, input_channel_size, output_channel_size).
        # NHWC is assumed, so the input channel count is the last input axis.
        in_channels = int(inputs.shape[-1])
        entire_kernel_shape = (kernel_h, kernel_w, in_channels, filters)

        mask = np.ones(entire_kernel_shape, dtype=np.float32)
        # Hide everything right of the center in the center row ...
        mask[center_h, center_w+1:, :, :] = 0
        # ... and every row below the center row.
        mask[center_h+1:, :, :, :] = 0
        # Mask 'A' also zeroes the center tap, so a position never reads its
        # own value from the layer below.
        if mask_type == 'A':
            mask[center_h, center_w, :, :] = 0

        weight = tf.get_variable("weight", entire_kernel_shape, tf.float32,
                                 tf.contrib.layers.variance_scaling_initializer())
        # The mask is applied on every forward pass, so masked taps contribute
        # nothing to the output.
        masked_weight = weight * tf.constant(mask, dtype=tf.float32)

        bias = tf.get_variable("bias", [filters], tf.float32, tf.zeros_initializer())
        with tf.variable_scope('conv'):
            outputs = tf.nn.conv2d(inputs, masked_weight, strides=[1, strides, strides, 1], padding='SAME')
            # There is one bias value per kernel (output channel), e.g.
            # outputs of shape (?, 28, 28, 64) pair with a bias of shape (64,).
            outputs = tf.nn.bias_add(outputs, bias)

        if activ_fn:
            outputs = activ_fn(outputs)

    return outputs

In [4]:
def binarize(images):
    """Stochastically binarize pixel intensities.

    Every value is compared against an independent draw from
    np.random.uniform (default range [0, 1)), so a value p becomes 1.0 with
    probability p and 0.0 otherwise: 0.7 turns into 1 about 70% of the time
    and into 0 about 30% of the time. This converts continuous intensities
    into a discrete (binary) sample instead of applying a fixed threshold.

    Args:
        images: array or array-like (for example a nested list) of
            intensities, expected to lie in [0, 1].

    Returns:
        A float32 numpy array of zeros and ones with the same shape as
        images.
    """
    # np.asarray lets plain lists through; ndarrays are passed along as-is.
    images = np.asarray(images)
    return (np.random.uniform(size=images.shape) < images).astype('float32')

### Model spec

* phase 1
    * 7x7 conv mask A
* phase 2
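    * 3x3 conv mask B \* 15 layers (단, 아래 구현에서는 `n_conv_layers = 7` 로 설정되어 있음)
    * residual connection (단, 아래 build_net 코드에는 residual connection 이 구현되어 있지 않음)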
* phase 3
    * 1x1 conv mask B \* 2 layers
* readout (phase 4)
    * dim matching (1x1 conv, without mask) - 단, 1x1 conv 에서는 mask B 가 no mask conv 와 동일해서 그냥 그걸 씀
    * softmax (256-color) or sigmoid (mnist)
    
no pooling layers (no downsampling)

In [5]:
batch_size = 128  # number of images fed per training / evaluation step
n_filters = 64  # output channels of the phase 1 and phase 2 masked convolutions
n_filters_out = 64  # output channels of the phase 3 1x1 convolutions
n_conv_layers = 7  # number of 3x3 mask-B layers stacked in phase 2
n_out_layers = 2  # number of 1x1 mask-B layers stacked in phase 3

# Whole batches per pass over each split; floor division leaves any trailing
# partial batch uncounted.
train_steps = mnist.train.num_examples // batch_size
test_steps = mnist.test.num_examples // batch_size

In [6]:
# build_net: PixelCNN graph for binarized 28x28 single-channel images.
tf.reset_default_graph()

# Images are fed flattened (784 values per row) and reshaped to NHWC.
X = tf.placeholder(tf.float32, [None, 784])
x_img = tf.reshape(X, [-1, 28, 28, 1])

# phase 1: a single 7x7 mask-A convolution, so no position sees its own input pixel.
net = mask_conv(x_img, n_filters, mask_type='A', kernel_size=[7,7], name='p1-convA')

# phase 2: stack of 3x3 mask-B convolutions with ReLU.
for i in range(n_conv_layers):
    net = mask_conv(net, n_filters, mask_type='B', kernel_size=[3,3], name='p2-convB{}'.format(i), activ_fn=tf.nn.relu)

# phase 3: 1x1 mask-B convolutions with ReLU.
for i in range(n_out_layers):
    net = mask_conv(net, n_filters_out, mask_type='B', kernel_size=[1,1], name='p3-1x1convB{}'.format(i), activ_fn=tf.nn.relu)

# phase 4: per-pixel logit, sigmoid read-out and binary cross-entropy against
# the input image itself.
logits = mask_conv(net, 1, mask_type='B', kernel_size=[1,1], name='logits')
y_pred = tf.nn.sigmoid(logits)
loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=x_img))

# Gradient clipping: the reference implementation clips gradients, although the
# paper does not seem to discuss it. Each gradient is clipped element-wise to
# [-1, 1] before being applied.
optim = tf.train.AdamOptimizer()
grads_and_vars = optim.compute_gradients(loss)
# compute_gradients may pair a variable with None when the loss does not
# depend on it; such pairs are skipped because clip_by_value cannot take None.
new_grads_and_vars = [
    (tf.clip_by_value(grad, -1, 1), var)
    for grad, var in grads_and_vars
    if grad is not None
]
train_op = optim.apply_gradients(new_grads_and_vars)

summary_op = tf.summary.merge([
    tf.summary.scalar("loss", loss)
])

In [None]:
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

def plot(samples, n_rows=4, n_cols=4, image_shape=(28, 28)):
    """Draw samples as a grid of grayscale images and return the figure.

    The figure is neither shown nor closed here; the caller decides whether
    to display, save or close it.

    Args:
        samples: iterable of arrays, each reshapeable to image_shape. It
            should hold no more than n_rows * n_cols entries, one per grid
            cell.
        n_rows: number of grid rows (default 4).
        n_cols: number of grid columns (default 4).
        image_shape: (height, width) each sample is reshaped to before
            drawing (default 28x28).

    Returns:
        The matplotlib figure holding the grid.
    """
    # One inch per cell keeps the default 4x4 grid at the original 4x4 inches.
    fig = plt.figure(figsize=(n_cols, n_rows))
    gs = gridspec.GridSpec(n_rows, n_cols)
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(samples):
        ax = plt.subplot(gs[i])
        # Hide axes and tick labels so only the image is visible.
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(sample.reshape(image_shape), cmap='Greys_r')

    return fig

In [None]:
# train
import os

epoch_n = 300

# plt.savefig does not create missing directories, so make sure the sample
# output directory exists before the first figure is written.
if not os.path.isdir('out'):
    os.makedirs('out')

# The session is created without a `with` block so that it stays open after
# this cell has run.
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# writer
train_writer = tf.summary.FileWriter('summary/train', sess.graph)
test_writer = tf.summary.FileWriter('summary/test')

for epoch in range(epoch_n):
    # train
    train_loss = 0.
    test_loss = 0.
    for i in range(train_steps):
        batch = mnist.train.next_batch(batch_size)
        images = binarize(batch[0])
        _, cur_loss, cur_summary = sess.run([train_op, loss, summary_op], {X: images})
        train_loss += cur_loss
        # Use a running step count; writing every batch at step `epoch` would
        # stack all of an epoch's points on the same x position.
        train_writer.add_summary(cur_summary, epoch * train_steps + i)

    # test
    for i in range(test_steps):
        batch = mnist.test.next_batch(batch_size)
        images = binarize(batch[0])
        cur_loss = sess.run(loss, {X: images})
        test_loss += cur_loss

    train_loss /= train_steps
    test_loss /= test_steps

    # Log the epoch's mean test loss once, positioned at the training step
    # count reached at the end of this epoch so both curves share an x axis.
    test_summary = tf.Summary(value=[tf.Summary.Value(tag='loss', simple_value=float(test_loss))])
    test_writer.add_summary(test_summary, (epoch + 1) * train_steps)
    train_writer.flush()
    test_writer.flush()

    # Function-call form prints the same text on Python 2 and Python 3.
    print("[{}/{}] train_loss: {:.4f} / test_loss: {:.4f}".format(epoch+1, epoch_n, train_loss, test_loss))

    # generate samples
    if epoch == 0 or (epoch+1)%10 == 0:
        samples = np.zeros((16, 784), dtype='float32')
        gen_image = np.zeros((16, 784), dtype='float32')
        for i in range(28):
            for j in range(28):
                # Pixels are generated one at a time in raster order. Each
                # step runs the whole network on the partially filled images,
                # which is wasteful but simple.
                pos = i*28 + j
                cur_gen_image = sess.run(y_pred, {X: samples})
                pixel_probs = cur_gen_image[:, i, j, 0]
                # Sampling from the predicted probability is what injects
                # randomness; only the current pixel needs to be sampled.
                samples[:, pos] = binarize(pixel_probs)
                gen_image[:, pos] = pixel_probs

        # Binarized samples are saved; gen_image keeps the raw probabilities
        # for inspection.
        fig = plot(samples)
        plt.savefig('out/{:0>4d}.png'.format(epoch), bbox_inches='tight')
        plt.close(fig)

[1/300] train_loss: 0.1572 / test_loss: 0.1204
[2/300] train_loss: 0.1199 / test_loss: 0.1185
[3/300] train_loss: 0.1183 / test_loss: 0.1168
[4/300] train_loss: 0.1175 / test_loss: 0.1163
[5/300] train_loss: 0.1168 / test_loss: 0.1156
[6/300] train_loss: 0.1163 / test_loss: 0.1153
[7/300] train_loss: 0.1160 / test_loss: 0.1147
[8/300] train_loss: 0.1156 / test_loss: 0.1145
[9/300] train_loss: 0.1154 / test_loss: 0.1152
[10/300] train_loss: 0.1151 / test_loss: 0.1144
[11/300] train_loss: 0.1149 / test_loss: 0.1142
[117/300] train_loss: 0.1111 / test_loss: 0.1106
[118/300] train_loss: 0.1111 / test_loss: 0.1108
[190/300] train_loss: 0.1109 / test_loss: 0.1106
[191/300] train_loss: 0.1108 / test_loss: 0.1105
[192/300] train_loss: 0.1108 / test_loss: 0.1105
[193/300] train_loss: 0.1109 / test_loss: 0.1107
[194/300] train_loss: 0.1108 / test_loss: 0.1105
[195/300] train_loss: 0.1108 / test_loss: 0.1105
[196/300] train_loss: 0.1108 / test_loss: 0.1104
[197/300] train_loss: 0.1108 / test_loss