In [215]:
import tensorflow as tf
import numpy as np
import datetime

In [32]:
W_exp = tf.tile(tf.expand_dims(W, axis=0), [32, 1, 1, 1, 1])
X_exp = tf.tile(tf.expand_dims(tf.expand_dims(X, axis=-1), axis=2), [1, 1, 10, 1, 1])

Y_matmul = tf.matmul(W_exp, X_exp)

print(W_exp, X_exp, Y_matmul, sep='\n')

%%timeit
with tf.Session() as sess:
    Y.eval()
    
%%timeit
with tf.Session() as sess:
    Y_matmul.eval()

Tensor("random_normal_4:0", shape=(1152, 10, 16, 8), dtype=float32)
Tensor("random_normal_5:0", shape=(32, 1152, 8), dtype=float32)
Tensor("einsum/transpose_2:0", shape=(32, 1152, 10, 16), dtype=float32)


## Load *notMNIST* dataset

In [231]:


# Load data
# Seed for the shuffle below so the train/dev/test split is reproducible
SEED = 42

# Load data; the context manager closes the underlying .npz file handle
with np.load('notMNIST.npz') as data:
    images, labels = data['images'], data['labels']

# Get length of data
N = len(labels)

# Create a reproducible shuffle index
shuffle_idx = np.random.RandomState(SEED).permutation(N)

# Assign images and labels.
# Normalise X to [0, 1] as float32 (the placeholder dtype, avoids a cast on every feed)
# and add a trailing channel axis.
X = np.expand_dims(images[shuffle_idx].astype(np.float32) / 255, axis=-1)
Y = labels[shuffle_idx]

# Split dataset into train/dev/test according to split_ratio
split_ratio = [0.8, 0.1, 0.1]
split_indices = (N * np.cumsum(split_ratio)[:-1]).astype(int)

X_train, X_dev, X_test = np.split(X, split_indices)
Y_train, Y_dev, Y_test = np.split(Y, split_indices)

print('Shape of X train, dev and test: ', X_train.shape, X_dev.shape, X_test.shape, sep='\n')
print('Shape of Y train, dev and test: ', Y_train.shape, Y_dev.shape, Y_test.shape, sep='\n')

Shape of X train, dev and test: 
(14979, 28, 28, 1)
(1872, 28, 28, 1)
(1873, 28, 28, 1)
Shape of Y train, dev and test: 
(14979,)
(1872,)
(1873,)


### Building Graph

In [260]:
def squash(v, axis=-1, name=None):
    '''
    Squashing non-linearity for capsule vectors: rescales each vector along
    `axis` by |v|^2 / (1 + |v|^2) / |v|, keeping its direction.

    Inputs:
        v: A tf tensor
        axis: Axis along which each capsule vector lies

    Returns:
        A tensor with the same shape as `v`
    '''
    with tf.name_scope(name, default_name='squashed'):
        # Added under the square root so the norm never becomes exactly zero
        eps = 1e-6
        norm = tf.sqrt(tf.reduce_sum(v**2, axis=axis, keepdims=True) + eps)
        norm_sq = norm**2

        scale = norm_sq / (1 + norm_sq) / norm
        return scale * v

def routing_by_agreement(input_caps, num_input_caps, num_output_caps, dim_input_caps, dim_output_caps, num_loops=3, name=None):
    '''
    Dynamic routing-by-agreement between two capsule layers
    (Sabour, Frosst & Hinton 2017, Procedure 1).

    Inputs:
        input_caps: Tensor of shape (batch, num_input_caps, dim_input_caps)
        num_input_caps: Number of lower-level capsules
        num_output_caps: Number of upper-level capsules
        dim_input_caps: Dimensionality of each lower-level capsule vector
        dim_output_caps: Dimensionality of each upper-level capsule vector
        num_loops: Number of routing iterations (python int)
        name: Optional variable scope name

    Returns:
        Tensor of shape (batch, num_output_caps, dim_output_caps) holding the
        squashed upper-capsule outputs of the final routing iteration
        (all zeros if num_loops <= 0).
    '''
    with tf.variable_scope(name, default_name='routing_by_agreement'):
        # Infer batch size
        batch_size = tf.shape(input_caps)[0]

        # Transformation matrices: one per (input capsule, output capsule) pair
        W = tf.get_variable('W_caps', [num_input_caps, num_output_caps, dim_output_caps, dim_input_caps],
                            initializer=tf.truncated_normal_initializer(stddev=0.1))

        # Predictions u_hat made by each input capsule for each output capsule:
        # (batch, num_input_caps, num_output_caps, dim_output_caps)
        preds_by_inputs = tf.einsum('abcd,iad->iabc', W, input_caps, name='preds_by_inputs')
        # Pin the static shape so the while_loop shape invariants below are satisfied
        preds_by_inputs.set_shape([None, num_input_caps, num_output_caps, dim_output_caps])

        # Routing logits b_ij start at zero for every example on every forward pass.
        # They must be neither a trainable variable nor shared across the batch,
        # otherwise one example's routing would depend on the rest of the batch.
        initial_logits = tf.zeros([batch_size, num_input_caps, num_output_caps], name='initial_logits')
        initial_outputs = tf.zeros([batch_size, num_output_caps, dim_output_caps], name='initial_outputs')

        def condition(counter, logits, output_caps):
            return tf.less(counter, num_loops)

        def loop(counter, logits, output_caps):
            # Coupling coefficients: softmax over the output capsules
            routing_w = tf.nn.softmax(logits, axis=-1, name='routing_weights')

            # Weighted sum of predictions over input capsules -> (batch, num_output_caps, dim_output_caps)
            mean_preds = tf.reduce_sum(tf.expand_dims(routing_w, -1) * preds_by_inputs, axis=1,
                                       name='mean_unscaled_pred')

            # Squash mean prediction
            output_caps = squash(mean_preds, axis=-1, name='mean_scaled_pred')

            # Agreement (dot product) between each prediction and the resulting output,
            # per example -> (batch, num_input_caps, num_output_caps)
            agreement = tf.reduce_sum(preds_by_inputs * tf.expand_dims(output_caps, 1), axis=-1,
                                      name='update_to_logits')

            return tf.add(counter, 1), logits + agreement, output_caps

        with tf.name_scope('loop'):
            counter = tf.constant(0)
            _, _, output_caps = tf.while_loop(condition, loop, [counter, initial_logits, initial_outputs])

        return output_caps
    
def vector_lengths(v, axis=-1, name=None):
    '''
    Returns the L2-norm of `v` calculated at `axis`.

    A small epsilon is added under the square root: the derivative of sqrt is
    infinite at 0, so an exactly-zero vector would otherwise yield NaN gradients.

    Inputs:
        v: A tensor, e.g. of shape (B x n_caps x dim_caps)
        axis: Axis holding the vector components
        name: Optional name scope (defaults to 'lengths')

    Returns:
        A tensor with `axis` reduced away, e.g. of shape (B x n_caps), containing lengths
    '''
    with tf.name_scope(name, default_name='lengths'):
        # For numerical stability
        epsilon = 1e-6

        return tf.sqrt(tf.reduce_sum(v**2, axis=axis) + epsilon)
    
def decoder_network(encodings, images, n_output, n_units=(512, 1024), name=None):
    '''
    Create a decoder network used as a reconstruction regulariser.

    The encodings are flattened and passed through one dense ReLU layer per
    entry of `n_units`, followed by a dense sigmoid output layer.

    Inputs:
        encodings: Tensor whose first axis is the batch; remaining axes are flattened
        images: Target images whose first axis is the batch; flattened for the loss
        n_output: Number of output units (must equal the number of pixels per image)
        n_units: Widths of the hidden layers
        name: Optional name scope (defaults to 'Decoder')

    Returns:
        output: (batch, n_output) sigmoid reconstructions
        reconstruction_loss: Scalar; sum of squared pixel errors per example,
            averaged over the batch (so its scale does not depend on batch size)
    '''
    with tf.name_scope(name, default_name='Decoder'):
        # tf.layers.flatten keeps the static feature dimension, which dense layers require
        flattened_enc = tf.layers.flatten(encodings, name='flattened_encs')
        flattened_images = tf.layers.flatten(images, name='flattened_imgs')

        # Construct decoder hidden layers on the flattened encodings
        hidden = flattened_enc
        for (i, hidden_units) in enumerate(n_units):
            hidden = tf.layers.dense(hidden,
                                     hidden_units,
                                     activation=tf.nn.relu,
                                     name='decoder{}'.format(i))

        # Construct output layer
        output = tf.layers.dense(hidden, n_output, activation=tf.nn.sigmoid, name='decoder_output')

        # Sum of squared differences per example, then mean over the batch.
        # (sqrt(x**2) was an L1 loss whose gradient is NaN wherever the error is exactly 0.)
        with tf.name_scope('ReconstructionLoss'):
            squared_error = tf.reduce_sum(tf.square(output - flattened_images), axis=-1)
            reconstruction_loss = tf.reduce_mean(squared_error, name='reconstruction_loss')

        return output, reconstruction_loss
    
def calc_margin_loss(caps_lengths, labels, m_plus=0.9, m_minus=0.1, lambda_=0.5, name=None):
    '''
    Calculate the margin loss as per Hinton's paper (Sabour et al. 2017, eq. 4):

        L_k = T_k * max(0, m_plus - |v_k|)^2 + lambda_ * (1 - T_k) * max(0, |v_k| - m_minus)^2

    summed over the output capsules k and averaged over the batch.

    Inputs:
        caps_lengths: (batch, num_classes) capsule lengths
        labels: (batch,) integer class labels
        m_plus, m_minus: Upper / lower margins
        lambda_: Down-weighting of the loss for absent classes
        name: Optional name scope (defaults to 'MarginLoss')

    Returns:
        Scalar margin loss
    '''
    with tf.name_scope(name, default_name='MarginLoss'):
        # Create one-hot encoding of labels
        target_one_hot = tf.one_hot(tf.cast(labels, tf.int32), depth=tf.shape(caps_lengths)[1], name='target_ohe')

        # Per-class loss terms for the present and absent classes
        present_loss = target_one_hot * tf.square(tf.maximum(0., m_plus - caps_lengths))
        absent_loss = lambda_ * (1 - target_one_hot) * tf.square(tf.maximum(0., caps_lengths - m_minus))

        # Sum over capsules, mean over the batch
        margin_loss = tf.reduce_mean(tf.reduce_sum(present_loss + absent_loss, axis=-1), name='margin_loss')

        return margin_loss
    
# def evaluate_tensors(tensors, data, target, batch, batch_size=32, mask_with_labels=True):
#     '''
#     Wrapper to perform sess.run
#     '''
#     results = sess.run(tensors, feed_dict={
#         _X: data[(batch * batch_size):((batch + 1) * batch_size)],
#         _Y: target[(batch * batch_size):((batch + 1) * batch_size)],
#         mask_with_labels: mask_with_labels
#     })
    
#     return results

In [266]:
caps_dim = [8, 16]      # [primary capsule dim, character capsule dim]
num_classes = 10
num_pixels = 28 * 28
alpha = 0.0005          # weight of the reconstruction loss in the total loss

conv1_params = {
    'filters': 256,
    'kernel_size': 9,
    'activation': tf.nn.relu,
}

conv2_params = {
    'filters': 256,
    'kernel_size': 9,
    'strides': 2,
}

tf.reset_default_graph()

with tf.name_scope('Placeholders'):
    _X = tf.placeholder(tf.float32, [None, 28, 28, 1], name='Images')
    _Y = tf.placeholder(tf.int64, [None], name='Labels')

layer1 = tf.layers.conv2d(_X, name='Conv1', **conv1_params)

with tf.name_scope('PrimaryCaps'):
    conv2 = tf.layers.conv2d(layer1, name='Conv2', **conv2_params)
    batch_size = tf.shape(_X)[0]
    print(_X, conv2, sep='\n')

    # Number of primary capsules derived from Conv2's static output shape
    # (6 * 6 * 256 / 8 = 1152 for 28x28 inputs) instead of hard-coding it
    num_primary_caps = conv2.shape[1:].num_elements() // caps_dim[0]
    conv2_reshaped = tf.reshape(conv2, [-1, num_primary_caps, caps_dim[0]], name='conv2_reshaped')
    caps1 = squash(conv2_reshaped)

with tf.name_scope('CharCaps'):
    caps2 = routing_by_agreement(caps1, num_input_caps=num_primary_caps, num_output_caps=num_classes,
                                 dim_input_caps=caps_dim[0], dim_output_caps=caps_dim[1], name='caps2')

with tf.name_scope('Lengths'):
    caps_lengths = vector_lengths(caps2)

with tf.name_scope('Predictions'):
    preds = tf.argmax(caps_lengths, axis=-1, name='predictions')

with tf.name_scope('Accuracy'):
    acc = tf.reduce_mean(tf.cast(tf.equal(preds, _Y), tf.float32))

with tf.name_scope('MarginLoss'):
    margin_loss = calc_margin_loss(caps_lengths, _Y)

with tf.name_scope('ReconstructionLoss'):
    # Create Boolean for masking choice: true labels (training) or predictions (inference)
    mask_with_labels = tf.placeholder_with_default(False, shape=[], name='mask_with_labels')

    # Create reconstruction masks
    reconstruction_targets = tf.cond(mask_with_labels, lambda: _Y, lambda: preds, name='reconstruction_targets')
    reconstruction_masks = tf.one_hot(reconstruction_targets, depth=num_classes, name='reconstruction_masks')

    # Zero out every capsule except the target one, but keep all num_classes x dim slots
    # (flattened), so the decoder can tell which class the surviving vector belongs to.
    # Summing over capsules would collapse this to a single class-agnostic vector.
    masked_encodings = tf.reshape(caps2 * tf.expand_dims(reconstruction_masks, -1),
                                  [-1, num_classes * caps_dim[1]], name='masked_encodings')

    # Feed encodings through decoder network
    decoder_output, reconstruction_loss = decoder_network(masked_encodings, _X, num_pixels)

with tf.name_scope('total_loss'):
    total_loss = tf.add(margin_loss, alpha * reconstruction_loss, name='total_loss')

    # Add summaries
    tf.summary.scalar('margin_loss', margin_loss)
    tf.summary.scalar('reconstruction_loss', reconstruction_loss)
    tf.summary.scalar('total_loss', total_loss)
    tf.summary.scalar('accuracy', acc)

opt = tf.train.AdamOptimizer().minimize(total_loss)

Tensor("Placeholders/Images:0", shape=(?, 28, 28, 1), dtype=float32)
Tensor("PrimaryCaps/Conv2/BiasAdd:0", shape=(?, 6, 6, 256), dtype=float32)


In [None]:
epochs = 10
initial_batch_size = 32
early_stopping = True

# Derived schedule; max(1, ...) avoids a modulo-by-zero when epochs < 10
log_every = max(1, min(50, epochs // 10))       # write TensorBoard summaries every N epochs
print_every = max(1, min(1000, epochs // 5))    # print progress every N epochs
patience = max(1, epochs // 10)                 # epochs without dev improvement before stopping

early_stop = {'loss': np.inf, 'epoch': 0}

NOW = datetime.datetime.now().strftime('%b %d, %Y/%I.%M%p')
log_path = 'Logs/{}'.format(NOW)

# Create shuffle indices (permuted in place during training)
train_shuffle_idx = np.arange(len(X_train))
dev_shuffle_idx = np.arange(len(X_dev))
n_train_batches = len(X_train) // initial_batch_size


def make_feed(images, labels, idx):
    '''Feed dict for the examples selected by index array `idx` (copies only those rows).'''
    return {_X: images[idx], _Y: labels[idx], mask_with_labels: True}


def mean_loss(sess, images, labels, batch_size=initial_batch_size):
    '''Example-weighted mean of `total_loss` over all of `images`/`labels`, evaluated in batches.'''
    total, count = 0.0, 0
    for start in range(0, len(images), batch_size):
        idx = np.arange(start, min(start + batch_size, len(images)))
        total += sess.run(total_loss, feed_dict=make_feed(images, labels, idx)) * len(idx)
        count += len(idx)
    return total / max(count, 1)


with tf.Session() as sess:
    # Create summary writers
    train_writer = tf.summary.FileWriter('{}/train'.format(log_path), sess.graph)
    dev_writer = tf.summary.FileWriter('{}/dev'.format(log_path))

    try:
        # Obtain handle to all summaries
        merged = tf.summary.merge_all()

        sess.run(tf.global_variables_initializer())

        def log_summaries(step):
            '''Write train/dev summaries computed on the first batch of the current shuffles.'''
            train_summary = sess.run(merged, feed_dict=make_feed(
                X_train, Y_train, train_shuffle_idx[:initial_batch_size]))
            train_writer.add_summary(train_summary, step)

            dev_summary = sess.run(merged, feed_dict=make_feed(
                X_dev, Y_dev, dev_shuffle_idx[:initial_batch_size]))
            dev_writer.add_summary(dev_summary, step)

        # Summaries of the untrained model at step 0
        log_summaries(0)

        for epoch in range(epochs):
            # Reshuffle the training order every epoch
            np.random.shuffle(train_shuffle_idx)

            for batch in range(n_train_batches):
                batch_idx = train_shuffle_idx[(batch * initial_batch_size):((batch + 1) * initial_batch_size)]
                _, loss = sess.run([opt, total_loss], feed_dict=make_feed(X_train, Y_train, batch_idx))

            # Log summaries and print status updates
            if (epoch + 1) % print_every == 0:
                print('Epoch {} complete'.format(epoch + 1))

            if (epoch + 1) % log_every == 0:
                np.random.shuffle(dev_shuffle_idx)
                log_summaries(epoch + 1)

            # Evaluate on the whole validation set (a single 32-example batch is too noisy)
            dev_loss = mean_loss(sess, X_dev, Y_dev)

            # Check for early stopping; compare against the best *dev* loss so far
            if dev_loss < early_stop['loss']:
                early_stop['loss'], early_stop['epoch'] = dev_loss, epoch + 1
            elif early_stopping and (epoch + 1) - early_stop['epoch'] >= patience:
                print('Terminated training due to early stopping.')
                break
    finally:
        # close() flushes pending events to disk, even if training is interrupted
        train_writer.close()
        dev_writer.close()