In [None]:
import os
from os import path

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from mnist_dataset import ImageAndLabelSet

In [None]:
# MNIST images are 28x28 single-channel pixels.
img_w, img_h = 28, 28
# Directory holding the raw MNIST idx files. Set the MNIST_DATA_ROOT environment variable to
# point elsewhere instead of editing this machine-specific default.
mnist_data_root = path.expanduser(os.environ.get("MNIST_DATA_ROOT", "~/coding/tensorflow_sandbox/data/"))

In [None]:
# Fake MNIST-shaped data used only to smoke-test the graph-building functions below.
# A fixed seed on a local RNG keeps re-runs reproducible without touching numpy's global RNG state.
fake_batch_size = 32
fake_rng = np.random.RandomState(0)
fake_input_batch = fake_rng.uniform(low=0, high=255, size=(fake_batch_size, 28, 28, 1))  # raw pixel range [0, 255)
fake_input_gt = fake_rng.choice(10, size=fake_batch_size)  # integer class labels in [0, 10)
fake_input_gt_oh = np.eye(10, dtype=np.float32)[fake_input_gt]  # one-hot labels, shape (fake_batch_size, 10)

In [None]:
def res_block(input_tensor: tf.Tensor, in_filters: int, out_filters: int, stride=1,
              init_stddev: float = None) -> tf.Tensor:
    """Pre-activation residual block with a 1x1 projection on the skip path.

    Main path: relu -> 3x3 conv (stride) -> relu -> 3x3 conv (stride 1)
    Skip path: relu -> 1x1 conv (stride)
    Output:    main + skip, with no activation applied to the sum.

    Args:
        input_tensor: NHWC tensor whose channel dimension is `in_filters`.
        in_filters: number of input channels.
        out_filters: number of output channels.
        stride: spatial stride of the first 3x3 conv and of the 1x1 skip conv.
        init_stddev: stddev of the truncated-normal kernel initialiser. If None (default), each
            kernel uses He scaling sqrt(2 / fan_in) with fan_in = k * k * in_channels.

    Returns:
        NHWC tensor with `out_filters` channels; 'SAME' padding gives spatial size
        ceil(H / stride) x ceil(W / stride).
    """
    def _kernel(size: int, c_in: int, c_out: int) -> tf.Variable:
        # Truncated-normal conv kernel of shape (size, size, c_in, c_out).
        # A fixed stddev of 1e-4 shrank activations by orders of magnitude per conv, so stacked
        # blocks produced vanishing activations (and gradients); fan-in scaling avoids that.
        stddev = init_stddev if init_stddev is not None else (2.0 / (size * size * c_in)) ** 0.5
        return tf.Variable(tf.truncated_normal(shape=(size, size, c_in, c_out), dtype=tf.float32, stddev=stddev))

    strides = (1, stride, stride, 1)

    # FIXME - may want to apply batch-norm here
    t = tf.nn.relu(input_tensor)

    # Projection for the skip connection so it matches the main path's channels and stride.
    # FIXME - should this start close to an identity transformation?
    t_downsampled = tf.nn.conv2d(t, _kernel(1, in_filters, out_filters), strides=strides, padding='SAME')

    t = tf.nn.conv2d(t, _kernel(3, in_filters, out_filters), strides=strides, padding='SAME')
    # FIXME - should also apply batch norm here to match Chen et al.
    t = tf.nn.relu(t)
    t = tf.nn.conv2d(t, _kernel(3, out_filters, out_filters), strides=(1, 1, 1, 1), padding='SAME')

    # Skip connection.
    # N.B. No activation function. This is the sum of the output of two convolutions (one 3x3, the other 1x1).
    return t + t_downsampled


In [None]:
# Test that the residual block gives outputs of the correct shape and actually runs.
# Built in a throwaway graph so re-running this cell does not pile test variables into the default graph.
with tf.Graph().as_default():
    test_ph1 = tf.placeholder(shape=(None, 28, 28, 1), dtype=tf.float32)
    test_out1 = res_block(test_ph1, in_filters=1, out_filters=64, stride=2)
    assert test_out1.shape.as_list() == [None, 14, 14, 64], test_out1.shape

    test_ph2 = tf.placeholder(shape=(None, 28, 28, 64), dtype=tf.float32)
    test_out2 = res_block(test_ph2, in_filters=64, out_filters=64, stride=1)
    assert test_out2.shape.as_list() == [None, 28, 28, 64], test_out2.shape

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        test_val1 = sess.run(test_out1, feed_dict={test_ph1: fake_input_batch})
        # Tile the single channel 64 times to fake a 64-channel input.
        test_val2 = sess.run(test_out2, feed_dict={test_ph2: np.tile(fake_input_batch, reps=(1, 1, 1, 64))})

assert test_val1.shape == (len(fake_input_batch), 14, 14, 64)
assert test_val2.shape == (len(fake_input_batch), 28, 28, 64)

In [None]:
def downsample_net(input_tensor: tf.Tensor, n_filters: int, in_channels: int = 1,
                   n_downsamples: int = 2, init_stddev: float = 1e-5) -> tf.Tensor:
    """Feature extractor: a biased 3x3 stem conv followed by `n_downsamples` stride-2 residual blocks.

    Args:
        input_tensor: NHWC image tensor with `in_channels` channels.
        n_filters: channel count of the stem conv and of every residual block.
        in_channels: number of channels in `input_tensor` (1 for MNIST).
        n_downsamples: number of stride-2 `res_block`s; each halves H and W ('SAME' padding rounds up).
        init_stddev: stddev of the truncated-normal initialiser for the stem conv kernel.

    Returns:
        NHWC tensor with `n_filters` channels; with the defaults a 28x28 input becomes 7x7.
    """
    # Convolution with 3x3 kernels.
    # N.B. Chen et al. do use biases in this conv.
    kernel = tf.Variable(tf.truncated_normal(shape=(3, 3, in_channels, n_filters), stddev=init_stddev, dtype=tf.float32))
    biases = tf.Variable(tf.zeros(shape=(1, 1, 1, n_filters), dtype=tf.float32))
    t = tf.nn.conv2d(input_tensor, kernel, strides=(1, 1, 1, 1), padding='SAME') + biases

    for _ in range(n_downsamples):
        t = res_block(t, in_filters=n_filters, out_filters=n_filters, stride=2)
    return t

In [None]:
# Test that downsample_net reduces a 28x28 image to 7x7x64 and actually runs.
# Built in a throwaway graph so re-running this cell does not pile test variables into the default graph.
with tf.Graph().as_default():
    test_ph = tf.placeholder(shape=(None, 28, 28, 1), dtype=tf.float32)
    test_out = downsample_net(test_ph, n_filters=64)
    assert test_out.shape.as_list() == [None, 7, 7, 64], test_out.shape

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        test_val = sess.run(test_out, feed_dict={test_ph: fake_input_batch})

assert test_val.shape == (len(fake_input_batch), 7, 7, 64)
assert np.all(np.isfinite(test_val))

In [None]:
class EWMA:
    """Exponentially weighted moving average of a stream of values.

    The first sample initialises the average; every later sample x applies
        ewma <- (1 - decay_lambda) * ewma + decay_lambda * x
    so a larger `decay_lambda` tracks recent values more closely, and 1.0 keeps only the latest value.
    """

    def __init__(self, decay_lambda: float = 0.1):
        """
        Args:
            decay_lambda: weight given to the newest sample; must lie in [0, 1].

        Raises:
            ValueError: if `decay_lambda` is outside [0, 1].
        """
        if not 0.0 <= decay_lambda <= 1.0:
            raise ValueError('decay_lambda must be in [0, 1], got {}'.format(decay_lambda))
        self._decay_lambda = decay_lambda
        self._running_ewma = None  # None until the first sample arrives

    def update(self, x: float):
        """Fold a new sample into the running average (the first sample initialises it)."""
        if self._running_ewma is None:
            self._running_ewma = x
        else:
            # Rebind instead of using in-place `*=`/`+=`: with numpy-array samples the stored
            # value aliases the caller's first array, and in-place ops would mutate it.
            self._running_ewma = (1.0 - self._decay_lambda) * self._running_ewma + self._decay_lambda * x

    def get(self):
        """Return the current average, or None if no sample has been seen yet."""
        return self._running_ewma

In [None]:
class FCNet:
    """MNIST classifier: a `downsample_net` feature extractor followed by one linear (FC) layer.

    The constructor builds the network, the softmax cross-entropy loss and a gradient-descent
    train op in a private graph, then initialises all variables in a private `tf.Session`.
    `train_batch` runs one update step and appends the (EWMA-smoothed) batch cross-entropy
    to `self._ce_history`.
    """

    def __init__(self, learning_rate: float = 1e-2, ce_decay_lambda: float = 1.0):
        """Build the graph and initialise its variables.

        Args:
            learning_rate: step size for plain gradient descent.
            ce_decay_lambda: EWMA weight given to the newest batch CE when smoothing the loss
                history. The default 1.0 disables smoothing (history == raw batch CE).
        """
        # Build into a private graph so re-running this cell (or the test cells) does not keep
        # piling variables into the default graph, and the initialiser only touches our variables.
        self._graph = tf.Graph()
        with self._graph.as_default():
            # Define the network
            self._input_img_batch = tf.placeholder(shape=(None, img_h, img_w, 1), dtype=tf.float32)

            downsampled_img = downsample_net(self._input_img_batch, n_filters=64)
            # Take the flattened size from the static shape instead of assuming an exact /4 reduction.
            dim_ds_img = downsampled_img.shape[1:].num_elements()

            flattened_ds_img = tf.reshape(downsampled_img, shape=(-1, dim_ds_img))
            flattened_act = tf.nn.relu(flattened_ds_img)
            self._flattened_act = flattened_act  # kept for inspecting activation magnitudes

            W_fc = tf.Variable(tf.truncated_normal(shape=(dim_ds_img, 10), dtype=tf.float32, stddev=1e-4), name='W_fc')
            b_fc = tf.Variable(tf.zeros(shape=(10,), dtype=tf.float32), name='b_fc')
            # The bias was previously created but never added to the logits.
            self._output_logits = tf.matmul(flattened_act, W_fc) + b_fc

            self._W_fc = W_fc  # Save a reference for debugging printing

            # Define a loss for training
            self._input_gt_oh = tf.placeholder(shape=(None, 10), dtype=tf.float32)
            self.ce_loss = tf.losses.softmax_cross_entropy(onehot_labels=self._input_gt_oh, logits=self._output_logits)

            # tf.train.MomentumOptimizer(learning_rate=1e-2, momentum=0.9) was also tried here.
            self._train_step = tf.train.GradientDescentOptimizer(learning_rate=learning_rate).minimize(self.ce_loss)

            init_op = tf.global_variables_initializer()

        self._sess = tf.Session(graph=self._graph)
        self._sess.run(init_op)

        self._running_ce = EWMA(decay_lambda=ce_decay_lambda)
        self._ce_history = []

    def train_batch(self, input_images, input_gt_oh):
        """Run one gradient-descent step and record the batch cross-entropy.

        Args:
            input_images: array of shape (batch, img_h, img_w, 1).
            input_gt_oh: one-hot labels of shape (batch, 10).

        Returns:
            The cross-entropy of this batch, fetched in the same session run as the update.
        """
        _, ce_this_batch = self._sess.run([self._train_step, self.ce_loss],
                                          feed_dict={self._input_img_batch: input_images,
                                                     self._input_gt_oh: input_gt_oh})

        # For debugging, fetch self._output_logits or self._flattened_act with the same feed_dict:
        # tiny flattened activations were observed, which make the gradients tiny too.

        self._running_ce.update(ce_this_batch)
        self._ce_history.append(self._running_ce.get())
        return ce_this_batch

    def close(self):
        """Release the session's resources; the instance must not be used afterwards."""
        self._sess.close()

In [None]:
# Training configuration.
n_train_batches = 1000
train_batch_size = 32
log_every = 100

image_and_label_set = ImageAndLabelSet(path.join(mnist_data_root, 'train-images-idx3-ubyte'),
                                       path.join(mnist_data_root, 'train-labels-idx1-ubyte'))
fc_net = FCNet()

for b in range(n_train_batches):
    image_batch, label_batch = image_and_label_set.getNextBatch(batchSize=train_batch_size)
    image_batch = np.reshape(image_batch, (-1, img_h, img_w, 1))  # Add a channel dimension

    fc_net.train_batch(image_batch, label_batch)

    if b % log_every == 0:
        # FIXME - these prints showed the train_step wasn't actually updating anything; in particular
        # W_fc was *not* changing. The problem is that the flattened activations are tiny, so the
        # gradients are tiny. Print fc_net._sess.run(fc_net._W_fc) here to check whether W_fc moves.
        print('batch {}: CE {}'.format(b, fc_net._running_ce.get()))

fig, ax = plt.subplots()
ax.plot(fc_net._ce_history)
ax.set(title='FCNet training loss', xlabel='batch', ylabel='cross-entropy')
plt.show()