## CNN with BN in TensorFlow

* DCGAN 의 CNN 모델을 만들어서 MNIST classification 을 수행해보자. 
* TF 에서 BN 을 적용하는 걸 연습하는 용도. 
* MNIST 로 정확도를 테스트하기 어렵다면 다른 데이터셋도 구해서 적용해보자.

Discriminator of DCGAN:

![Discriminator of DCGAN](http://bamos.github.io/data/2016-08-09/discrim-architecture.png)

In [1]:
import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data

In [3]:
def weight_init(shape):
    return tf.truncated_normal(shape, stddev=0.1)

def bias_init(shape):
    return tf.constant(0.1, shape=shape)

In [4]:
# 참조한 image completion 코드에서는 다른 식으로 구현하는데, 그게 더 빠른가?
# 특이하게 구현함. https://github.com/bamos/dcgan-completion.tensorflow/blob/master/ops.py
def lrelu(x, leak=0.2):
    return tf.maximum(x, x*leak)

In [5]:
# 일단 MNIST datset 은 28x28x1 이므로, 
X = tf.placeholder(tf.float32, shape=[None, 784])
Y = tf.placeholder(tf.float32, shape=[None, 10])

# reshape for CNN
X_img = tf.reshape(X, [-1, 28, 28, 1])

# first conv layer: 
W1 = tf.Variable(weight_init([5, 5, 1, 64]))
b1 = tf.Variable(bias_init([64]))

a1 = tf.nn.conv2d(X_img, W1, strides=[1, 2, 2, 1], padding='SAME') + b1
h1 = lrelu(a1)

In [12]:
h1

<tf.Tensor 'Maximum:0' shape=(?, 14, 14, 64) dtype=float32>

In [13]:
W2 = tf.Variable(weight_init([5, 5, 64, 128]))
b2 = tf.Variable(bias_init([128]))

a2 = tf.nn.conv2d(h1, W2, strides=[1, 2, 2, 1], padding='SAME') + b2
h2 = lrelu(a2)

In [14]:
h2

<tf.Tensor 'Maximum_1:0' shape=(?, 7, 7, 128) dtype=float32>

In [22]:
W3 = tf.Variable(weight_init([5, 5, 128, 256]))
b3 = tf.Variable(bias_init([256]))

a3 = tf.nn.conv2d(h2, W3, strides=[1, 2, 2, 1], padding='SAME') + b3
h3 = lrelu(a3)

In [24]:
h3

<tf.Tensor 'Maximum_3:0' shape=(?, 4, 4, 256) dtype=float32>

In [26]:
# FC layer. 원래 더 해야 하지만 이건 MNIST 니까 여기까지만 하자.
# 원래는 sigmoid 로 0/1 만 판별하는데, MNIST 는 10개니까 softmax 로 해야함

W4 = tf.Variable(weight_init([4096, 10]))
b4 = tf.Variable(bias_init([10]))
h3_flat = tf.reshape(h3, [-1, 4096])

# last layer activation = logit
logits = tf.matmul(h3_flat, W4) + b4
y_prob = tf.nn.softmax(logits)

In [28]:
loss = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=Y)
solver = tf.train.AdamOptimizer().minimize(loss)

In [60]:
pred = tf.argmax(logits, axis=1)
correction = tf.equal(pred, tf.argmax(Y, axis=1))
accuracy = tf.reduce_mean(tf.cast(correction, "float"))

In [27]:
mnist = input_data.read_data_sets('MNIST_data/', one_hot=True)

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz


In [70]:
batch_size = 100
total_batch = mnist.train.num_examples / batch_size

sess = tf.Session()
sess.run(tf.global_variables_initializer())

for epoch in range(10):
    loss_sum = 0
    for i in range(total_batch / 10):
        batch = mnist.train.next_batch(batch_size)
        
        loss_cur, _ = sess.run([loss, solver], feed_dict={X: batch[0], Y: batch[1]})
        loss_sum += np.average(loss_cur)
    
#     train_loss = np.average(loss_cur)
    train_loss = loss_sum / total_batch
    test_batch = mnist.test.next_batch(10000)
    test_loss = np.average(sess.run(loss, feed_dict={X: test_batch[0], Y: test_batch[1]}))
    
    train_acc = sess.run(accuracy, {X:batch[0], Y:batch[1]})
    test_acc = sess.run(accuracy, {X:test_batch[0], Y:test_batch[1]})
    print("[{:3}] train: {:.5f} / test: {:.5f} | [acc] train: {:.4f} / test: {:.4f}"
          .format(epoch+1, train_loss, test_loss, train_acc, test_acc))

[  1] train: 0.21900 / test: 0.37174 | [acc] train: 0.8900 / test: 0.8946
[  2] train: 0.03331 / test: 0.31657 | [acc] train: 0.9100 / test: 0.9061
[  3] train: 0.02625 / test: 0.22063 | [acc] train: 0.9200 / test: 0.9349
[  4] train: 0.02256 / test: 0.18872 | [acc] train: 0.9300 / test: 0.9415
[  5] train: 0.01758 / test: 0.16194 | [acc] train: 0.9200 / test: 0.9521
[  6] train: 0.01824 / test: 0.14636 | [acc] train: 0.9500 / test: 0.9557
[  7] train: 0.01499 / test: 0.14213 | [acc] train: 0.9500 / test: 0.9554
[  8] train: 0.01541 / test: 0.10929 | [acc] train: 0.9800 / test: 0.9651
[  9] train: 0.01411 / test: 0.12993 | [acc] train: 0.9700 / test: 0.9584
[ 10] train: 0.01096 / test: 0.08974 | [acc] train: 0.9900 / test: 0.9733
