In [1]:
import numpy as np
import tensorflow as tf

  from ._conv import register_converters as _register_converters


## Load in data

In [2]:
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("../Tensorflow-applications/MNIST_data//", one_hot=True)

Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
Instructions for updating:
Please write your own downloading logic.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting ../Tensorflow-applications/MNIST_data//train-images-idx3-ubyte.gz
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting ../Tensorflow-applications/MNIST_data//train-labels-idx1-ubyte.gz
Instructions for updating:
Please use tf.one_hot on tensors.
Extracting ../Tensorflow-applications/MNIST_data//t10k-images-idx3-ubyte.gz
Extracting ../Tensorflow-applications/MNIST_data//t10k-labels-idx1-ubyte.gz
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.


## Settings

In [29]:
caps1_size = 8
caps2_size = 10 # actually number of capsules
pred_matrix_size = 16 # this is actually capsule size of 2
conv1_channels = 256
conv1_filter = 9
primaryCaps_channels = 32
primaryCaps_filter = 9
routing_rounds = 3
epsilon = 1e-7
learning_rate = 0.001

## Squash function

In [30]:
def squash(input_vector, axis):
    normalised_input = tf.reduce_sum(tf.square(input_vector), axis = axis, keepdims = True)
    scale = tf.divide(normalised_input, tf.add(normalised_input, 1.))
    vector = tf.divide(input_vector, tf.sqrt(tf.add(normalised_input, epsilon)))
    output = tf.multiply(scale, vector)
    
    return(output)

## Convolutional layer

In [31]:
def convolutional(input_data, conv_shape, stride_shape, relu=True):
    weights = tf.Variable(tf.truncated_normal(conv_shape, stddev=0.3), name='W')
    bias = tf.Variable(tf.truncated_normal([conv_shape[-1]], stddev=0.3), name='B')
    out_layer = tf.nn.conv2d(input_data, weights, stride_shape, padding = 'VALID')
    out_layer_bias = tf.add(out_layer, bias)
    
    if relu == True:
        out_layer_final = tf.nn.relu(out_layer_bias)
        return(out_layer_final)
    
    return(out_layer_bias)

## Primarycaps

In [36]:
def primarycaps(input_data, conv_shape, stride_shape, primaryCaps_channels, caps1_size, caps2_size, pose_size, batch):
    output = convolutional(input_data, conv_shape, stride_shape, relu=False)
    filter_size = output.get_shape().as_list()[1]
    caps1_raw = tf.reshape(output, [-1,filter_size*filter_size*primaryCaps_channels,caps1_size], name='caps1_raw')
    caps1_output = squash(caps1_raw, axis=-1)
    caps1_output_expand = tf.expand_dims(caps1_output, axis=-1)
    caps1_output_expand2 = tf.expand_dims(caps1_output_expand, axis=2)
    caps1_output_expand2_tiled = tf.tile(caps1_output_expand2, [1,1,caps2_size,1,1], name = 'caps1_out_tiled')
    
    weight_matrix = tf.Variable(tf.truncated_normal([filter_size*filter_size*primaryCaps_channels, caps2_size, pose_size, caps1_size], stddev=0.1), name='W_matrix')
    #weight_matrix_tiled = tf.tile(weight_matrix, [batch, 1, 1, 1, 1], name = 'W_matrix_tiled')
    #caps2_predicted = tf.matmul(weight_matrix_tiled, caps1_output_expand2_tiled, name='caps2_predicted')
    caps2_predicted = tf.einsum('abdc,iabcf->iabdf', weight_matrix, caps1_output_expand2_tiled)
    
    return(caps2_predicted)

## Routing by agreement

In [33]:
def routing_by_agreement(input_data, caps2_size, rounds, batch):
    raw_weights = tf.zeros([batch, input_data.get_shape().as_list()[1], caps2_size, 1, 1], name = 'raw_weights')
    
    for i in range(rounds):
        routing_weights = tf.nn.softmax(raw_weights, axis=2, name = 'routing_weights' + str(i))
        weighted_predictions = tf.multiply(routing_weights, input_data, name = 'weighted_predictions' + str(i))
        weighted_sum = tf.reduce_sum(weighted_predictions, axis=1, name = 'weighted_sum' + str(i), keepdims = True)
        caps2_output = squash(weighted_sum, axis=-2)
    
        #caps2_output_tiled = tf.tile(caps2_output, [1, input_data.get_shape().as_list()[1], 1, 1, 1], name = 'caps2_output_tiled'+ str(i))
        #agreement = tf.matmul(input_data, caps2_output_tiled, transpose_a = True, name = 'agreement'+ str(i))
        agreement = tf.einsum('iabcd,ifbcd->iabcd', input_data, caps2_output)
        raw_weights = tf.add(raw_weights, agreement, name = 'raw_weights' + str(i))
        
    return(caps2_output)

## Graph

In [34]:
X = tf.placeholder(tf.float32, [None,784])
X_reshape = tf.reshape(X, shape = [-1,28,28,1])
y = tf.placeholder(tf.float32, [None,10])

In [37]:
conv1 = convolutional(X_reshape, [conv1_filter,conv1_filter,X_reshape.get_shape().as_list()[-1],
                                  conv1_channels],[1,1,1,1])
primary = primarycaps(conv1, [primaryCaps_filter,primaryCaps_filter,conv1.get_shape().as_list()[-1],
                              primaryCaps_channels*caps1_size], [1,2,2,1],primaryCaps_channels,
                      caps1_size,caps2_size, pred_matrix_size, batch=tf.shape(X)[0])
output = routing_by_agreement(primary, caps2_size, routing_rounds, batch=tf.shape(X)[0])

## DigitCaps

In [38]:
squared_norm_caps2 = tf.reduce_sum(tf.square(output), axis=-2, keepdims = True, name='caps2_norm')
caps2_activation = tf.sqrt(tf.add(squared_norm_caps2, epsilon), name = 'caps2_activation')

y_estimate = tf.squeeze(caps2_activation, axis=[1,3,4], name = 'y_pred')
y_pred = tf.nn.softmax(y_estimate, axis=1)

## Loss function

In [39]:
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits = y_estimate, labels=y))
optimiser = tf.train.AdamOptimizer(learning_rate = learning_rate).minimize(cross_entropy)

## Make prediction

In [40]:
correct_prediction = tf.equal(tf.argmax(y,axis=1), tf.argmax(y_pred, axis=1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, dtype = tf.float32))

## Setup training

In [41]:
init_op = tf.global_variables_initializer()
epochs = 10
batch_size = 50

#If run on AWS
config = tf.ConfigProto(allow_soft_placement=True)
config.gpu_options.allocator_type = 'BFC'
config.gpu_options.allow_growth=True


with tf.Session(config=config) as sess:
    init_op.run()

    total_batch = int(len(mnist.train.images)/batch_size)
    total_batch_validation = int(len(mnist.validation.images)/batch_size)

    for epoch in range(epochs):
        avg_cost = 0
        for i in range(total_batch):
            batch_x, batch_y = mnist.train.next_batch(batch_size=batch_size)
            a, c= sess.run([optimiser, cross_entropy], feed_dict={X: batch_x, y: batch_y})
            avg_cost += c/total_batch
            print('batch:', i)
        print("Epoch:", (epoch + 1), "cost =", "{:.3f}".format(avg_cost))
        
        acc_vals = []
        for iterations in range(total_batch_validation):
            batch_x, batch_y = mnist.validation.next_batch(batch_size=batch_size)
            val_acc = sess.run(accuracy, feed_dict={X: batch_x, y: batch_y})
            acc_vals.append(val_acc)
        acc_val = np.mean(acc_vals)
        print('val_acc:', acc_val)

batch: 0
batch: 1
batch: 2
batch: 3
batch: 4
batch: 5


KeyboardInterrupt: 