In [1]:
import cifar10_input
import numpy as np
import tensorflow as tf
import time
import os
from tensorflow.python.ops import control_flow_ops

In [2]:
# Make sure the CIFAR-10 data is available locally before any input pipeline
# is built. Helper from cifar10_input; presumably a no-op when the files are
# already on disk -- confirm against that module.
cifar10_input.maybe_download_and_extract()

In [3]:
# ---- Hyperparameters --------------------------------------------------------
# Architecture: hidden-layer widths (not referenced by the code visible in
# this notebook).
N_HIDDEN_1, N_HIDDEN_2 = 256, 256

# Optimisation
ETA = 0.01          # learning rate handed to the optimizer in training()
BATCH_SIZE = 128    # examples per batch for both input pipelines
EPOCHS = 1000       # passes over the training set

# Logging
DISPLAY_STEP = 1    # evaluate / summarise / checkpoint every N epochs

In [4]:
def inputs(eval_data=True):
    """Build the undistorted CIFAR-10 input batch.

    Args:
        eval_data: forwarded unchanged to `cifar10_input.inputs`
            (presumably selects the evaluation split -- confirm there).

    Returns:
        Whatever `cifar10_input.inputs` returns; the training cell unpacks
        it as an (images, labels) pair.
    """
    cifar_bin_dir = os.path.join('data/cifar10_data', 'cifar-10-batches-bin')
    pipeline_args = dict(eval_data=eval_data,
                         data_dir=cifar_bin_dir,
                         batch_size=BATCH_SIZE)
    return cifar10_input.inputs(**pipeline_args)

In [5]:
def distorted_inputs():
    """Build the distorted CIFAR-10 input batch used for training.

    Returns:
        Whatever `cifar10_input.distorted_inputs` returns for the binary
        CIFAR-10 directory and BATCH_SIZE; the training cell unpacks it as
        an (images, labels) pair.
    """
    cifar_bin_dir = os.path.join('data/cifar10_data', 'cifar-10-batches-bin')
    return cifar10_input.distorted_inputs(
        batch_size=BATCH_SIZE, data_dir=cifar_bin_dir)

In [6]:
def batch_norm(x, n_out, phase_train, layer_type, decay=0.9, epsilon=1e-3):
    """Batch-normalise `x` with learned offset (beta) and scale (gamma).

    While `phase_train` is true the statistics of the current batch are used
    and folded into exponential moving averages; otherwise the stored moving
    averages are used instead.

    Args:
        x: activations to normalise; the last axis must have `n_out` entries.
        n_out: size of the feature/channel axis (shape of beta and gamma).
        phase_train: scalar boolean tensor, true during training steps.
        layer_type: 'conv' computes moments over axes [0, 1, 2]; any other
            value computes them over axis 0 only.
        decay: decay of the moving averages of mean and variance.
        epsilon: small constant added to the variance for stability.

    Returns:
        The normalised tensor, same shape as `x`.
    """
    beta = tf.get_variable(
        'beta', [n_out],
        initializer=tf.constant_initializer(value=0.0, dtype=tf.float32))
    gamma = tf.get_variable(
        'gamma', [n_out],
        initializer=tf.constant_initializer(value=1.0, dtype=tf.float32))

    axes = [0, 1, 2] if layer_type == 'conv' else [0]
    batch_mean, batch_var = tf.nn.moments(x, axes, name='moments')

    ema = tf.train.ExponentialMovingAverage(decay=decay)
    # apply() is still called so the shadow variables are created exactly as
    # before (same names, zero-initialised). The update op it returns is
    # deliberately NOT used: an op created out here runs whichever branch of
    # the cond below is selected, so wiring it in through control_dependencies
    # also moved the averages during evaluation runs.
    ema.apply([batch_mean, batch_var])
    ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)

    def mean_var_with_update():
        # Built inside the branch, so these assignments only execute when
        # phase_train is true. Same rule as the EMA: avg -= (1-decay)*(avg-x).
        update_mean = tf.assign_sub(
            ema_mean, (1. - decay) * (ema_mean - batch_mean))
        update_var = tf.assign_sub(
            ema_var, (1. - decay) * (ema_var - batch_var))
        with tf.control_dependencies([update_mean, update_var]):
            return tf.identity(batch_mean), tf.identity(batch_var)

    mean, var = tf.cond(
        phase_train, mean_var_with_update, lambda: (ema_mean, ema_var))

    # tf.nn.batch_normalization broadcasts the [n_out] statistics against the
    # last axis of x, so fully connected inputs need no reshape to 4-D.
    return tf.nn.batch_normalization(x, mean, var, beta, gamma, epsilon)

In [7]:
def conv_batch_norm(x, n_out, phase_train):
    """Batch-normalise a convolutional activation.

    Thin wrapper: delegates to `batch_norm` with `layer_type='conv'`.
    """
    return batch_norm(x, n_out, phase_train, layer_type='conv')


def layer_batch_norm(x, n_out, phase_train):
    """Batch-normalise a fully connected activation.

    Thin wrapper: delegates to `batch_norm` with
    `layer_type='fully_connected'`.
    """
    return batch_norm(x, n_out, phase_train, layer_type='fully_connected')

In [8]:
def filter_summary(V, weight_shape):
    """Emit an image summary named 'filters' for the filter tensor `V`.

    Args:
        V: filter tensor; conv2d passes weights laid out as
            [height, width, in_channels, out_channels].
        weight_shape: shape of `V`. Unused; kept so existing callers do not
            have to change.
    """
    # (3, 0, 1, 2) moves the last (output-channel) axis to the front, giving
    # one [height, width, in_channels] image per filter.
    filters_as_images = tf.transpose(V, (3, 0, 1, 2))
    tf.summary.image('filters', filters_as_images, max_outputs=64)

In [9]:
def conv2d(input, weight_shape, bias_shape, phase_train, visualize=False):
    """Convolution + bias, batch norm, then ReLU.

    Args:
        input: tensor passed to tf.nn.conv2d.
        weight_shape: 4-element filter shape; the first three entries give
            the fan-in used for the initialiser, the fourth is the number of
            output channels handed to conv_batch_norm.
        bias_shape: shape of the bias variable 'b'.
        phase_train: boolean tensor forwarded to conv_batch_norm.
        visualize: when true, also emit an image summary of the filters.

    Returns:
        relu(conv_batch_norm(conv2d(input, W) + b)).
    """
    fan_in = weight_shape[0] * weight_shape[1] * weight_shape[2]
    # Weights ~ N(0, 2 / fan_in); biases start at zero.
    W = tf.get_variable(
        'W', weight_shape,
        initializer=tf.random_normal_initializer(stddev=(2. / fan_in) ** 0.5))
    b = tf.get_variable(
        'b', bias_shape, initializer=tf.constant_initializer(value=0))

    # Stride 1 with 'SAME' padding.
    convolved = tf.nn.conv2d(input, W, strides=[1, 1, 1, 1], padding='SAME')
    pre_activation = tf.nn.bias_add(convolved, b)

    if visualize:
        filter_summary(W, weight_shape)

    normalized = conv_batch_norm(pre_activation, weight_shape[3], phase_train)
    return tf.nn.relu(normalized)

In [10]:
def max_pool(input, k=2):
    """Max-pool `input` with a k x k window and stride k ('SAME' padding)."""
    window = [1, k, k, 1]
    stride = [1, k, k, 1]
    return tf.nn.max_pool(input, ksize=window, strides=stride, padding='SAME')

In [11]:
def layer(input, weight_shape, bias_shape, phase_train):
    """Fully connected layer: matmul + bias, batch norm, then ReLU.

    Args:
        input: activations multiplied by the weight matrix.
        weight_shape: [n_in, n_out] shape of the weight variable 'W'.
        bias_shape: shape of the bias variable 'b'.
        phase_train: boolean tensor forwarded to layer_batch_norm.

    Returns:
        relu(layer_batch_norm(matmul(input, W) + b)).
    """
    n_in, n_out = weight_shape[0], weight_shape[1]
    # Weights ~ N(0, 2 / n_in); biases start at zero.
    W = tf.get_variable(
        'W', weight_shape,
        initializer=tf.random_normal_initializer(stddev=(2. / n_in) ** 0.5))
    b = tf.get_variable(
        'b', bias_shape, initializer=tf.constant_initializer(value=0))

    pre_activation = tf.matmul(input, W) + b
    normalized = layer_batch_norm(pre_activation, n_out, phase_train)
    return tf.nn.relu(normalized)

In [12]:
def inference(x, keep_prob, phase_train):
    """Build the network: two conv+pool stages, two FC layers, 10 outputs.

    Args:
        x: image batch (the training cell feeds a [None, 24, 24, 3]
            placeholder).
        keep_prob: dropout keep probability applied after fc_1 and fc_2.
        phase_train: boolean tensor passed down to every batch-norm layer.

    Returns:
        The output of the final 10-unit layer.
    """
    with tf.variable_scope('conv_1'):
        # 5x5 filters, 3 -> 64 channels; these filters get an image summary.
        pool_1 = max_pool(
            conv2d(x, [5, 5, 3, 64], [64], phase_train, visualize=True))

    with tf.variable_scope('conv_2'):
        pool_2 = max_pool(conv2d(pool_1, [5, 5, 64, 64], [64], phase_train))

    with tf.variable_scope('fc_1'):
        # Flatten everything except the batch axis.
        flat_dim = int(np.prod(pool_2.get_shape()[1:].as_list()))
        flattened = tf.reshape(pool_2, [-1, flat_dim])
        fc_1_drop = tf.nn.dropout(
            layer(flattened, [flat_dim, 384], [384], phase_train), keep_prob)

    with tf.variable_scope('fc_2'):
        fc_2_drop = tf.nn.dropout(
            layer(fc_1_drop, [384, 192], [192], phase_train), keep_prob)

    with tf.variable_scope('output'):
        # NOTE(review): `layer` applies batch norm and ReLU, so the values
        # later used as softmax logits are non-negative -- confirm this is
        # intended rather than a plain linear output layer.
        return layer(fc_2_drop, [192, 10], [10], phase_train)

In [13]:
def loss(output, y):
    """Mean sparse softmax cross-entropy of logits `output` against `y`.

    Args:
        output: logits tensor.
        y: integer class labels; cast to int64 before use.

    Returns:
        Scalar tensor: the cross-entropy averaged over all entries.
    """
    labels = tf.cast(y, tf.int64)
    per_example = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=output, labels=labels)
    return tf.reduce_mean(per_example)

In [14]:
def training(cost, global_step):
    """Create the optimisation op for `cost`.

    Records a 'cost' scalar summary as a side effect and minimises `cost`
    with Adam at learning rate ETA.

    Args:
        cost: scalar loss tensor to minimise.
        global_step: variable passed to minimize() as its global_step.

    Returns:
        The op returned by the optimizer's minimize().
    """
    tf.summary.scalar('cost', cost)
    return tf.train.AdamOptimizer(ETA).minimize(
        cost, global_step=global_step)

In [15]:
def evaluate(output, y):
    """Accuracy of the arg-max prediction of `output` against labels `y`.

    Also records a 'validation_error' scalar summary equal to 1 - accuracy.

    Args:
        output: per-class scores; the arg-max is taken over axis 1.
        y: labels compared against the int32-cast predictions.

    Returns:
        Scalar tensor with the fraction of matching predictions.
    """
    predicted = tf.cast(tf.argmax(output, 1), dtype=tf.int32)
    hits = tf.cast(tf.equal(predicted, y), tf.float32)
    accuracy = tf.reduce_mean(hits)
    tf.summary.scalar('validation_error', 1. - accuracy)
    return accuracy

In [16]:
# Build the model graph, train it, and report accuracy on one evaluation batch.
#
# The graph is made the default *before* tf.device() is entered: tf.device()
# registers the placement on whichever graph is the default at that moment, so
# wrapping tf.Graph().as_default() inside it (as this cell used to) left the
# freshly created graph without the '/gpu:0' request.
with tf.Graph().as_default():
    with tf.device('/gpu:0'):
        with tf.variable_scope('cifar_conv_batchnorm_model'):
            x = tf.placeholder('float', [None, 24, 24, 3])
            y = tf.placeholder('int32', [None])
            keep_prob = tf.placeholder(tf.float32)
            phase_train = tf.placeholder(tf.bool) # T=train, F=Valid/Test
            distorted_images, distorted_labels = distorted_inputs()
            val_images, val_labels = inputs()
            output = inference(x, keep_prob, phase_train)
            cost = loss(output, y)
            global_step = tf.Variable(
                0, name='global_step', trainable=False)
            train_op = training(cost, global_step)
            eval_op = evaluate(output, y)
            summary_op = tf.summary.merge_all()
            saver = tf.train.Saver()
            init_op = tf.global_variables_initializer()

            # Soft placement lets ops without a GPU kernel (or a machine
            # without a GPU) fall back to the CPU now that the device request
            # actually applies to this graph.
            sess = tf.Session(
                config=tf.ConfigProto(allow_soft_placement=True))
            summary_writer = tf.summary.FileWriter(
                'conv_cifar_batchnorm_logs', graph=sess.graph)
            # The coordinator lets the input-queue threads be stopped when
            # the cell finishes or is interrupted.
            coord = tf.train.Coordinator()
            threads = []

            # Loop-invariant: number of training batches per epoch.
            total_batches = int(
                cifar10_input.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN /
                BATCH_SIZE)

            try:
                sess.run(init_op)
                threads = tf.train.start_queue_runners(sess=sess, coord=coord)

                # Train
                for epoch in range(EPOCHS):
                    avg_cost = 0.

                    # Loop over batches
                    for i in range(total_batches):
                        # Pull one distorted batch from the input queue and
                        # take one optimisation step on it.
                        train_x, train_y = sess.run([distorted_images,
                                                     distorted_labels])
                        _, new_cost = sess.run(
                            [train_op, cost],
                            feed_dict={x: train_x,
                                       y: train_y,
                                       keep_prob: 0.5,
                                       phase_train: True})

                        # Running mean of the batch costs for this epoch
                        avg_cost += new_cost / total_batches

                    # Display logs per epoch step
                    if epoch % DISPLAY_STEP == 0:
                        # The reported validation error comes from a single
                        # batch of BATCH_SIZE examples.
                        val_x, val_y = sess.run([val_images, val_labels])
                        accuracy = sess.run(
                            eval_op,
                            feed_dict={x: val_x,
                                       y: val_y,
                                       keep_prob: 1.,
                                       phase_train: False})
                        print('Epoch: %04d\tCost: %.6f\tValidation error: %.6f'
                              %(epoch + 1, avg_cost, 1 - accuracy))
                        # NOTE(review): the merged summaries (including the
                        # one named 'validation_error') are evaluated on the
                        # last *training* batch, as before -- confirm that is
                        # what the TensorBoard curves should show.
                        summary_str = sess.run(
                            summary_op,
                            feed_dict={x: train_x,
                                       y: train_y,
                                       keep_prob: 1.,
                                       phase_train: False})
                        summary_writer.add_summary(summary_str,
                                                   sess.run(global_step))
                        saver.save(
                            sess,
                            'conv_cifar_batchnorm_logs/model-checkpoint',
                            global_step=global_step)
                print('Optimization finished!')

                val_x, val_y = sess.run([val_images, val_labels])
                accuracy = sess.run(
                    eval_op,
                    feed_dict={x: val_x,
                               y: val_y,
                               keep_prob: 1.,
                               phase_train: False})
                print('Test Accuracy:', accuracy)
            finally:
                # Runs on normal completion, on errors and on
                # KeyboardInterrupt: stop the queue threads, flush the event
                # file and release the session so re-running the cell does
                # not pile up threads and device memory.
                coord.request_stop()
                try:
                    coord.join(threads)
                finally:
                    summary_writer.close()
                    sess.close()

Filling queue with 20000 CIFAR images before starting to train. This will take a few minutes.
Epoch: 0001	Cost: 1.660202	Validation error: 0.460938
Epoch: 0002	Cost: 1.313401	Validation error: 0.367188
Epoch: 0003	Cost: 1.174350	Validation error: 0.343750
Epoch: 0004	Cost: 1.085870	Validation error: 0.234375
Epoch: 0005	Cost: 1.029030	Validation error: 0.265625
Epoch: 0006	Cost: 0.998964	Validation error: 0.242188
Epoch: 0007	Cost: 0.957401	Validation error: 0.179688
Epoch: 0008	Cost: 0.934747	Validation error: 0.265625
Epoch: 0009	Cost: 0.910355	Validation error: 0.164062
Epoch: 0010	Cost: 0.897537	Validation error: 0.257812
Epoch: 0011	Cost: 0.870341	Validation error: 0.187500
Epoch: 0012	Cost: 0.851445	Validation error: 0.179688
Epoch: 0013	Cost: 0.847089	Validation error: 0.218750
Epoch: 0014	Cost: 0.815981	Validation error: 0.179688
Epoch: 0015	Cost: 0.815756	Validation error: 0.210938
Epoch: 0016	Cost: 0.801222	Validation error: 0.203125
Epoch: 0017	Cost: 0.794212	Validation erro

Epoch: 0151	Cost: 0.467442	Validation error: 0.179688
Epoch: 0152	Cost: 0.480242	Validation error: 0.070312
Epoch: 0153	Cost: 0.472149	Validation error: 0.125000
Epoch: 0154	Cost: 0.470930	Validation error: 0.187500
Epoch: 0155	Cost: 0.465767	Validation error: 0.117188
Epoch: 0156	Cost: 0.472933	Validation error: 0.125000
Epoch: 0157	Cost: 0.468192	Validation error: 0.148438
Epoch: 0158	Cost: 0.464804	Validation error: 0.125000
Epoch: 0159	Cost: 0.470045	Validation error: 0.242188
Epoch: 0160	Cost: 0.467862	Validation error: 0.171875
Epoch: 0161	Cost: 0.468586	Validation error: 0.132812
Epoch: 0162	Cost: 0.469608	Validation error: 0.187500
Epoch: 0163	Cost: 0.464095	Validation error: 0.195312
Epoch: 0164	Cost: 0.462428	Validation error: 0.156250
Epoch: 0165	Cost: 0.463708	Validation error: 0.125000
Epoch: 0166	Cost: 0.456648	Validation error: 0.125000
Epoch: 0167	Cost: 0.472357	Validation error: 0.125000
Epoch: 0168	Cost: 0.462747	Validation error: 0.164062
Epoch: 0169	Cost: 0.460080	V

KeyboardInterrupt: 