In [1]:
#Capsule Network routing twice without reconstruction

In [2]:
#Import dependencies
import os                 #For checkpoint path handling
import time               #For timing the training loop

import numpy as np        #For Mathematial Operations
import tensorflow as tf   #For Machine Learning
from tensorflow.examples.tutorials.mnist import input_data       #Import MNIST data object

In [3]:
#Reset the default TensorFlow graph (not the session) so the cells below build into a fresh graph
tf.reset_default_graph()

In [4]:
#Session object used by the training/testing cell to run the graph
sess = tf.InteractiveSession()

In [5]:
#MNIST dataset read from 'MNIST_data/' with one-hot encoded labels
MNIST = input_data.read_data_sets('MNIST_data/', one_hot = True)

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz


In [6]:
n_epochs = 1        #Number of passes over the training set
batch_size = 128    #Examples per batch; also used as a fixed dimension of W_tiled and raw_weights_round1

In [7]:
#Numerically safe implementation of the squashing non-linearity from the Capsule Net paper:
#    squash(s) = (||s||^2 / (1 + ||s||^2)) * (s / ||s||)
#"Safe" because an epsilon is added under the square root, so an all-zero vector does not
#cause a division by zero.
def squash(arr,name,axis=-1,keep_dims=False,epsilon=1e-7):
    """Squash the vectors lying along `axis` of `arr`.

    Args:
        arr: tensor whose `axis` dimension holds the vectors to squash.
        name: name scope under which the ops are created.
        axis: dimension along which the vector norm is computed.
        keep_dims: accepted for backward compatibility only. The result always has
            the shape of `arr`, so the norm is always reduced with the dimension
            kept; otherwise it could not broadcast against `arr`.
        epsilon: small constant added under the square root for numerical safety.

    Returns:
        A tensor with the same shape as `arr` whose vectors along `axis` keep their
        direction and have a length below 1.
    """
    with tf.name_scope(name):
        #Keep the reduced dimension so the norm broadcasts against arr in the division below
        squared_norm = tf.reduce_sum(tf.square(arr), axis=axis, keep_dims=True, name = 'squared_norm')
        safe_norm = tf.sqrt(squared_norm + epsilon)
        unit_vector = arr/safe_norm
        #Scale factor from the paper uses the squared norm, not the norm itself
        squash_factor = squared_norm/(1.0 + squared_norm)
        return squash_factor*unit_vector
        

In [8]:
#Safe vector length: used to turn the secondary capsules into one score per class
def norm(arr,name,axis=-1,epsilon=1e-7,keep_dims=False):
    """Return sqrt(sum(arr**2 along `axis`) + epsilon).

    Args:
        arr: input tensor.
        name: name scope for the ops; also the name of the returned op.
        axis: dimension that is summed over.
        epsilon: small constant added before taking the square root.
        keep_dims: whether the reduced dimension is kept with size 1.

    Returns:
        Tensor of (epsilon-stabilised) vector lengths.
    """
    with tf.name_scope(name):
        squares = tf.square(arr)
        sum_of_squares = tf.reduce_sum(squares, axis=axis, keep_dims=keep_dims, name=name+'sum')
        stabilised = sum_of_squares + epsilon
        return tf.sqrt(stabilised, name=name)

In [9]:
#Placeholders for inputs: X takes flattened images (784 values per example),
#Y takes the matching 10-element label vectors (MNIST is loaded with one_hot=True)
with tf.name_scope('placeholders'):
    X = tf.placeholder(dtype=tf.float32, shape=[None,784], name='X_placeholder')
    Y = tf.placeholder(dtype=tf.float32, shape=[None,10], name='Y_placeholder')

In [10]:
#The first convolutional layer: [batch,28,28,1] -> [batch,20,20,256]
with tf.variable_scope('conv_1') as scope:
    #Reshape the flat 784-vectors into 28x28 single-channel images
    images = tf.reshape(X, shape=[-1,28,28,1], name='images')
    #256 kernels of size 9x9x1
    #NOTE(review): truncated_normal_initializer() is used with its default stddev;
    #confirm that this scale is intended for kernels of this size
    kernels = tf.get_variable(name='kernels', shape=[9,9,1,256], dtype=tf.float32,
                             initializer = tf.truncated_normal_initializer())
    #Stride-1 convolution with VALID padding (no bias term): 28 - 9 + 1 = 20 per spatial dimension
    conv = tf.nn.conv2d(images, kernels, strides=[1,1,1,1], padding='VALID')
    #First conv layer output after relu
    conv1 = tf.nn.relu(conv, name=scope.name)

In [11]:
#Second convolutional layer: [batch,20,20,256] -> [batch,6,6,256]
with tf.variable_scope('conv_2') as scope:
    #256 kernels of size 9x9x256
    #NOTE(review): same default-stddev initializer as conv_1; confirm the scale is intended
    kernels = tf.get_variable(name='kernels', shape=[9,9,256,256], dtype=tf.float32,
                             initializer=tf.truncated_normal_initializer())
    #Stride-2 convolution with VALID padding (no bias term): (20 - 9)//2 + 1 = 6 per spatial dimension
    conv = tf.nn.conv2d(conv1, kernels, strides=[1,2,2,1], padding='VALID')
    #Second conv layer output after relu
    conv2 = tf.nn.relu(conv, name=scope.name)

In [12]:
#Primary capsules layer
with tf.variable_scope('primary_capsules') as scope:
    #Regroup the 6x6x256 conv2 activations into 1152 capsules of 8 values each (6*6*256/8 = 1152)
    conv2_reshape = tf.reshape(conv2, shape=[-1,1152,8], name='conv2_reshape')
    #Squash every 8-d capsule vector (last axis) so that its length stays below 1;
    #the result keeps the shape [batch,1152,8]
    caps_1 = squash(conv2_reshape, keep_dims=True, name=scope.name)

In [13]:
#Weight definitions for secondary capsule computation
with tf.variable_scope('weights'):
    #One 16x8 transformation matrix per (primary capsule, secondary capsule) pair: [1,1152,10,16,8]
    W_raw = tf.get_variable(name='W_raw', shape=[1,1152,10,16,8], dtype=tf.float32,
                           initializer=tf.random_normal_initializer(stddev=0.01))
    #Repeat the weight matrices along the batch axis: [batch_size,1152,10,16,8]
    #(uses the Python constant batch_size, so the graph expects batches of exactly that size)
    W_tiled = tf.tile(W_raw,[batch_size,1,1,1,1],name='W_tiled')
    #Turn every 8-d primary capsule into a column vector: [batch,1152,8,1]
    caps_1_expanded = tf.expand_dims(caps_1, axis=-1, name='caps_1_expanded')
    #Insert an axis for the secondary capsules: [batch,1152,1,8,1]
    caps_1_tile = tf.expand_dims(caps_1_expanded, axis=2, name='caps_1_tile')
    #Repeat each primary capsule once per secondary capsule: [batch,1152,10,8,1]
    caps_1_output = tf.tile(caps_1_tile,[1,1,10,1,1], name='caps_1_output')

In [14]:
#Prediction vectors for the secondary capsules: each primary capsule (8x1) is multiplied by its
#16x8 matrix for every one of the 10 secondary capsules, giving [batch_size,1152,10,16,1]
with tf.variable_scope('secondary_capsules'):
    caps_2 = tf.matmul(W_tiled,caps_1_output,name='caps_2')

In [15]:
#Round 1 of dynamic routing
with tf.name_scope('round_1'):
    #Raw routing weights, one per (primary capsule, secondary capsule) pair, all zero initially:
    #[batch_size,1152,10,1,1]
    raw_weights_round1 = tf.zeros(shape=[batch_size,1152,10,1,1], dtype=tf.float32, name='raw_weights_round1')
    #Routing weights: softmax over the 10 secondary capsules (dim=2)
    routing_weights_round1 = tf.nn.softmax(raw_weights_round1, dim=2, name='routing_weights_round1' )
    #Prediction vectors scaled by their routing weights
    weighted_preds_round1 = tf.multiply(routing_weights_round1,caps_2,name='weighted_preds_round1')
    #Weighted SUM over the 1152 primary capsules (axis=1). This used to be reduce_mean, which
    #scaled the result by 1/1152 and disagreed with round 2 and with the routing algorithm.
    weighted_sum_round1 = tf.reduce_sum(weighted_preds_round1,axis=1,keep_dims=True,name='weighted_sum_round1')
    #Squash along the 16-d capsule axis (axis=-2) to get the secondary capsules output for round 1
    caps2_output_round1 = squash(weighted_sum_round1,axis=-2,keep_dims=True,name='caps2_output_round1')

In [16]:
#Computing agreement between the prediction vectors and the round-1 secondary capsule outputs
with tf.variable_scope('agreement'):
    #Repeat the round-1 output along axis 1 so there is one copy per primary capsule (1152)
    caps2_output_tiled = tf.tile(caps2_output_round1,[1,1152,1,1,1],name='caps2_output_tiled')
    #transpose_a turns each 16x1 prediction into 1x16, so the product is the scalar (1x1)
    #dot product between a prediction and the matching secondary capsule output
    agreement = tf.matmul(caps_2,caps2_output_tiled,transpose_a=True, name='agreement')

In [17]:
#Dynamic routing round 2
with tf.name_scope('round2'):
    #Raw weights for round 2 are the round-1 raw weights plus the agreement
    raw_weights_round2 = tf.add(raw_weights_round1,agreement,name='raw_weights_round2')
    #Routing weights for round 2: softmax over the 10 secondary capsules (dim=2)
    routing_weights_round2 = tf.nn.softmax(raw_weights_round2,dim=2,name='routing_weights_round2')
    #Prediction vectors scaled by their routing weights
    weighted_preds_round2 = tf.multiply(routing_weights_round2,caps_2,name='weighted_preds_round2')
    #Weighted sum over the 1152 primary capsules (axis=1)
    weighted_sum_round2 = tf.reduce_sum(weighted_preds_round2,axis=1,keep_dims=True,name='weighted_sum_round2')
    #Squash along the 16-d capsule axis (axis=-2) to get the secondary capsule output for round 2
    caps2_output_round2 = squash(weighted_sum_round2,axis=-2,keep_dims=True,name='caps2_output_round2')

In [18]:
#The output of the second routing round is the final secondary capsule output
caps2_output = caps2_output_round2

In [19]:
#Score for each class: length of each secondary capsule vector, taken along the 16-d axis (axis=-2)
y_probs = norm(caps2_output,name='y_probs',axis=-2)

In [20]:
#Reducing dimensions to (batch_size,10) by dropping the size-1 axes 1 and -1
y_probs_final = tf.squeeze(y_probs,[1,-1],name='y_probs_squeezed')

In [21]:
#Parameters in loss computation
m_plus = 0.9     #Scores of the true class below this value are penalised (present_error)
m_minus = 0.1    #Scores of the other classes above this value are penalised (absent_error)
lambda_ = 0.5    #Weight applied to the absent-class part of the loss

In [22]:
#One-hot vector of output labels for computing loss (alias of the Y placeholder)
T = Y

In [23]:
#Part one of margin loss: squared shortfall of each class score below m_plus
present_margin = m_plus - y_probs_final
present_hinge = tf.maximum(0., present_margin)
present_error = tf.square(present_hinge, name='present_error')

In [24]:
#Part two of margin loss: squared excess of each class score above m_minus
absent_margin = y_probs_final - m_minus
absent_hinge = tf.maximum(0., absent_margin)
absent_error = tf.square(absent_hinge, name='absent_error')

In [25]:
#Margin loss per sample and class: present_error counts where the label is 1,
#absent_error (weighted by lambda_) counts where the label is 0
present_term = T*present_error
absent_term = lambda_*(1 - T)*absent_error
margin_loss_raw = tf.add(present_term, absent_term, name='margin_loss_raw')

In [26]:
#Total batch loss: sum over the classes of each sample, then average over the batch
per_sample_loss = tf.reduce_sum(margin_loss_raw, axis=1)
margin_loss = tf.reduce_mean(per_sample_loss, name='margin_loss')

In [27]:
#Training op: one Adam update step (default optimizer settings) that minimises margin_loss.
#Note that `optimizer` holds the result of minimize(), not the AdamOptimizer object itself.
optimizer = tf.train.AdamOptimizer().minimize(margin_loss)

In [None]:
#The Process: initialise variables, restore a checkpoint if one exists, train, then evaluate
sess.run(tf.global_variables_initializer())
#Graph writer for tensorboard
#NOTE(review): '/graphs' and the checkpoint prefix below are absolute paths at the filesystem
#root; confirm that this is intended rather than a notebook-relative directory
writer = tf.summary.FileWriter('/graphs',sess.graph)
#Saver object to save and restore
saver = tf.train.Saver()
checkpoint_prefix = '/checkpoints/caps_net/checkpoint'
checkpoint_dir = os.path.dirname(checkpoint_prefix)
#Make sure the checkpoint directory exists before the first save
if not os.path.isdir(checkpoint_dir):
    os.makedirs(checkpoint_dir)
#Fetch checkpoint if present
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
#Restore model if checkpoint present
if ckpt and ckpt.model_checkpoint_path:
    saver.restore(sess, ckpt.model_checkpoint_path)

#Training
report_every = 10   #Steps between loss reports / checkpoints
n_batches = MNIST.train.num_examples // batch_size
loss_sum = 0.0
start = time.time()
for i in range(n_epochs*n_batches):
    X_batch, Y_batch = MNIST.train.next_batch(batch_size)
    _, loss = sess.run([optimizer, margin_loss], {X:X_batch, Y:Y_batch})
    #Accumulate the loss so that the report below is the mean over the last report_every steps
    loss_sum += loss

    if (i+1)%report_every==0:
        print( "Loss at step {}: {:5.1f} ".format(i+1, loss_sum/report_every))
        saver.save(sess, checkpoint_prefix, i)
        loss_sum = 0.0
print('Optimization finished')
print('Time taken: {}'.format(time.time() - start))

#testing
total_correct_preds = 0.0
#Only full batches are evaluated because the graph is built for exactly batch_size examples
n_batches = MNIST.test.num_examples // batch_size
for i in range(n_batches):
    X_batch, Y_batch = MNIST.test.next_batch(batch_size)
    #Fetch the predictions only: running the optimizer here would train on the test set
    y_preds = sess.run(y_probs_final, {X:X_batch})
    #Compare predicted and true classes with numpy so that no new ops are added to the graph
    total_correct_preds += np.sum(np.argmax(y_preds,1) == np.argmax(Y_batch,1))
#Divide by the number of examples actually evaluated, not by the full test-set size
n_evaluated = n_batches*batch_size
print("Accuracy = {}". format(total_correct_preds/n_evaluated))
writer.close()