In [None]:
import matplotlib.pyplot as plt
import matplotlib.cm as cm
%matplotlib inline

import time

import pandas as pd
import numpy as np
from tqdm import tqdm

import tensorflow as tf
import tensorflow.contrib.slim as slim

from tensorflow.examples.tutorials.mnist import input_data
# Load MNIST from ./MNIST_data (the tutorial loader presumably downloads it there if absent).
# reshape=False keeps each image 2-D instead of flattening it to a 784-vector;
# NOTE(review): per the loader's defaults images should be float32 already scaled to [0, 1]
# with shape (N, 28, 28, 1) — confirm before applying any further pixel scaling.
mnist = input_data.read_data_sets("MNIST_data/", reshape=False)

In [None]:
img_size = (28, 28)
n_classes = 10

# The loader (default dtype=float32) already rescales pixels to [0, 1], so no
# further division is needed: dividing by 255 again squashed inputs to [0, 1/255].
# Reshape to NCHW for the channels_first model. With a single channel, the
# (N, 28, 28, 1) -> (N, 1, 28, 28) reshape only relabels axes; it does not permute data.
imgs_train = mnist.train.images.reshape((-1, 1, *img_size))
imgs_val = mnist.validation.images.reshape((-1, 1, *img_size))
imgs_test = mnist.test.images.reshape((-1, 1, *img_size))

y_train = mnist.train.labels.astype(np.int32)
y_val = mnist.validation.labels.astype(np.int32)
y_test = mnist.test.labels.astype(np.int32)

# Guard against double (or missing) pixel scaling and mismatched image/label counts.
assert all(0.0 <= imgs.min() and imgs.max() <= 1.0
           for imgs in (imgs_train, imgs_val, imgs_test)), 'expected pixels in [0, 1]'
assert all(len(imgs) == len(labels) for imgs, labels in
           ((imgs_train, y_train), (imgs_val, y_val), (imgs_test, y_test)))
assert all(0 <= labels.min() and labels.max() < n_classes
           for labels in (y_train, y_val, y_test)), 'label outside [0, n_classes)'

def make_batch_iter(x, y, batch_size, shuffle=False):
    """Yield aligned ``(x_batch, y_batch)`` mini-batches over paired arrays.

    Args:
        x: array-like of inputs; the first axis indexes samples.
        y: array-like of targets aligned with ``x`` along the first axis.
        batch_size: maximum samples per batch; the final batch may be smaller.
        shuffle: if True, visit samples in a fresh random order drawn from
            ``np.random`` on each call; otherwise keep the original order.

    Yields:
        ``(x_batch, y_batch)``, where ``x_batch`` is float32 and ``y_batch``
        keeps ``y``'s dtype.

    Raises:
        ValueError: if ``x`` and ``y`` differ in length or ``batch_size < 1``
            (raised on the first ``next()``, since this is a generator).
    """
    x = np.asarray(x)
    y = np.asarray(y)
    n = len(x)
    if len(y) != n:
        raise ValueError('x and y must have the same length, got {} and {}'.format(n, len(y)))
    if batch_size < 1:
        raise ValueError('batch_size must be >= 1, got {}'.format(batch_size))

    idx = np.random.permutation(n) if shuffle else np.arange(n)

    for start in range(0, n, batch_size):
        # Index through `idx`. The original sliced x/y directly, so the
        # shuffled order was computed but never used.
        batch_idx = idx[start:start + batch_size]
        yield x[batch_idx].astype(np.float32, copy=False), y[batch_idx]

In [None]:
def cba(t, n_chans, kernel_size, name, training,
        activation=tf.identity):
    """Convolution -> batch norm -> activation on an NCHW tensor.

    Args:
        t: input tensor in channels_first (NCHW) layout.
        n_chans: number of output channels of the convolution.
        kernel_size: convolution kernel size.
        name: variable scope wrapping the block's variables and ops.
        training: bool (tensor) selecting batch-norm training vs. inference mode.
        activation: callable applied last; identity by default.

    Returns:
        The activated, batch-normalized convolution output.
    """
    with tf.variable_scope(name):
        conv_out = tf.layers.conv2d(
            inputs=t,
            filters=n_chans,
            kernel_size=kernel_size,
            use_bias=False,  # batch norm's learned offset replaces a conv bias
            data_format='channels_first',
        )
        normed = slim.batch_norm(
            inputs=conv_out,
            decay=0.9,
            scale=True,
            is_training=training,
            data_format='NCHW',
            fused=True,
        )
        return activation(normed)

def dba(t, units, name, training, activation=tf.identity):
    """Dropout -> dense -> batch norm -> activation block.

    Used here on the 2-D (batch, features) output of the CNN's global pooling.
    Dropout uses tf.layers.dropout's default rate and is active only when
    ``training`` is true.

    Args:
        t: input tensor.
        units: width of the dense layer.
        name: variable scope wrapping the block's variables and ops.
        training: bool (tensor) toggling dropout and batch-norm mode.
        activation: callable applied last; identity by default.

    Returns:
        The activated, batch-normalized dense output.
    """
    with tf.variable_scope(name):
        dropped = tf.layers.dropout(inputs=t, training=training)
        projected = tf.layers.dense(inputs=dropped, units=units, use_bias=False)
        normed = slim.batch_norm(
            inputs=projected,
            decay=0.9,
            scale=True,
            is_training=training,
            fused=True,
        )
        return activation(normed)
    

def build_cnn(t, blocks, kernel_size, training, name):
    """Stack of ReLU conv blocks with 2x2 max-pooling, then global average pooling.

    Args:
        t: NCHW input tensor.
        blocks: output channel count for each conv block, in order.
        kernel_size: kernel size shared by every conv block.
        training: bool (tensor) forwarded to batch norm.
        name: enclosing variable scope.

    Returns:
        A (batch, channels) tensor: the spatial mean of the last block's output.
    """
    with tf.variable_scope(name):
        net = t
        for block_idx, chans in enumerate(blocks):
            net = cba(net, chans, kernel_size,
                      name='cba_{}'.format(block_idx),
                      training=training,
                      activation=tf.nn.relu)
            net = tf.layers.max_pooling2d(
                net, pool_size=2, strides=2, data_format='channels_first')
        # Averaging NCHW over axes 2 and 3 (H, W) leaves one value per channel.
        net = tf.reduce_mean(net, axis=(2, 3), name='global_average_pooling')
    return net


def build_mlp(t, blocks, training, name):
    """Multi-layer perceptron made of ``dba`` blocks.

    Args:
        t: input feature tensor.
        blocks: layer widths; every entry but the last gets a ReLU, and the
            last entry is the (un-activated) output layer, scoped ``logits``.
        training: bool (tensor) forwarded to dropout and batch norm.
        name: enclosing variable scope.

    Returns:
        The output of the final (identity-activation) layer.
    """
    with tf.variable_scope(name):
        net = t
        for layer_idx, n_units in enumerate(blocks[:-1]):
            net = dba(net, n_units,
                      name='fc_{}'.format(layer_idx),
                      training=training,
                      activation=tf.nn.relu)
        # Final layer keeps dba's default identity activation: raw logits.
        net = dba(net, blocks[-1], name='logits', training=training)
    return net

class Model:
    """CNN feature extractor followed by an MLP classifier, built in the default graph.

    Args:
        cnn_blocks: output channels of each conv block (see ``build_cnn``).
        mpl_blocks: widths of the MLP layers; the last entry is the number of
            logits. (The parameter keeps its original spelling for backward
            compatibility; it is the MLP specification.)
        kernel_size: conv kernel size shared by all conv blocks.

    Attributes:
        x_ph: float32 image placeholder, NCHW with 1 channel and free spatial size.
        training: bool placeholder toggling dropout / batch-norm mode.
        y_ph: int64 class-index label placeholder.
        cnn: pooled CNN features; logits: MLP output.
        loss: sparse softmax cross-entropy; acc: accuracy of argmax(logits).
        summary: merged scalar summaries of loss and acc.
        train_op: Adam training op built by slim.
    """

    def __init__(self, cnn_blocks, mpl_blocks, kernel_size):
        self.x_ph = tf.placeholder(tf.float32, (None, 1, None, None))
        self.training = tf.placeholder(tf.bool)
        self.y_ph = tf.placeholder(tf.int64, (None,))

        with tf.variable_scope('model'):
            self.cnn = build_cnn(self.x_ph, cnn_blocks, kernel_size,
                                 self.training, 'cnn')
            # Use the constructor argument. Previously this read the global
            # `mlp_blocks`, silently ignoring `mpl_blocks`.
            self.logits = build_mlp(self.cnn, mpl_blocks,
                                    self.training, 'mlp')

        self.loss = tf.losses.sparse_softmax_cross_entropy(self.y_ph, self.logits)
        self.acc = tf.contrib.metrics.accuracy(tf.argmax(self.logits, 1), self.y_ph)

        with tf.name_scope('summary'):
            loss_summary = tf.summary.scalar('loss', self.loss)
            acc_summary = tf.summary.scalar('acc', self.acc)
            # Merge only this model's summaries. merge_all would also pull in
            # any other summaries already present in the graph.
            self.summary = tf.summary.merge([loss_summary, acc_summary])

        # Per slim's documentation, create_train_op runs the UPDATE_OPS collection
        # by default, which is needed for slim.batch_norm's moving averages.
        self.train_op = slim.learning.create_train_op(self.loss, tf.train.AdamOptimizer())

# Architecture: conv channels per block, MLP widths (last = number of classes),
# and the conv kernel size.
cnn_blocks = [32, 32, 64]
mlp_blocks = [1024, 1024, n_classes]  # tie the output width to the class count
kernel_size = 3
seed = 0

tf.reset_default_graph()
# The graph-level seed belongs to the new default graph, so set it after the reset.
tf.set_random_seed(seed)
np.random.seed(seed)  # make_batch_iter draws its shuffle order from np.random

model = Model(cnn_blocks, mlp_blocks, kernel_size)

In [None]:
n_epoch = 500
batch_size = 1024

graph = tf.get_default_graph()
file_writer = tf.summary.FileWriter('./log/new2', graph=graph, flush_secs=10)


def run_epoch(session, model, batch_iter, training, writer=None, step=0):
    """Run one pass of ``model`` over ``batch_iter``.

    Args:
        session: open tf.Session holding the model's variables.
        model: a Model instance (placeholders, loss, acc, summary, train_op).
        batch_iter: iterable of ``(x_batch, y_batch)`` pairs, e.g. from make_batch_iter.
        training: if True, run the train op and (when ``writer`` is given) log a
            summary per batch; otherwise only evaluate loss and accuracy.
        writer: summary writer used while training.
        step: summary step index for the first batch.

    Returns:
        ``(loss, acc, step)``: batch-size-weighted means over the pass and the
        step index after the last batch.
    """
    losses, accs, weights = [], [], []
    for x_batch, y_batch in batch_iter:
        feed_dict = {model.x_ph: x_batch, model.y_ph: y_batch,
                     model.training: training}
        if training:
            _, loss, acc, summary = session.run(
                [model.train_op, model.loss, model.acc, model.summary], feed_dict)
            if writer is not None:
                writer.add_summary(summary, step)
            step += 1
        else:
            loss, acc = session.run([model.loss, model.acc], feed_dict)
        losses.append(loss)
        accs.append(acc)
        weights.append(len(x_batch))
    # Weight by batch size so the short final batch doesn't skew the mean.
    return (np.average(losses, weights=weights),
            np.average(accs, weights=weights),
            step)


k = 0
try:
    with tf.Session(graph=graph) as session:
        session.run(tf.global_variables_initializer())
        for epoch in range(n_epoch):
            # Reshuffle the training set every epoch; validation order does not matter.
            train_iter = make_batch_iter(imgs_train, y_train,
                                         batch_size=batch_size, shuffle=True)
            val_iter = make_batch_iter(imgs_val, y_val, batch_size=batch_size)

            start = time.time()
            train_loss, train_acc, k = run_epoch(
                session, model, train_iter, training=True, writer=file_writer, step=k)
            val_loss, val_acc, _ = run_epoch(
                session, model, val_iter, training=False)
            end = time.time()

            print('Epoch {}'.format(epoch))
            print('Train:', train_loss, train_acc)
            print('Val  :', val_loss, val_acc)
            print('Time :', end - start)
            print('\n')
finally:
    # Flush pending events and release the event file even if training is interrupted.
    file_writer.close()

In [None]:
# List the TensorBoard run directories under ./log (the training cell writes to ./log/new2).
!ls log