In [1]:
import numpy as np
import tensorflow as tf

  from ._conv import register_converters as _register_converters


## Squash function

In [None]:
def squash(input_vector, axis, epsilon=1e-7):
    """Capsule-network squashing non-linearity.

    Shrinks every vector taken along ``axis`` so its length becomes roughly
    ``||s||^2 / (1 + ||s||^2)`` (short vectors go to ~0, long ones to just
    below 1) while keeping its direction:

        squash(s) = ||s||^2 / (1 + ||s||^2) * s / sqrt(||s||^2 + epsilon)

    Args:
        input_vector: tensor of vectors to squash.
        axis: axis holding the vector components; the norm is computed over it.
        epsilon: small constant added under the square root so that
            zero-length vectors do not cause a division by zero.

    Returns:
        Tensor with the same shape as ``input_vector``.
    """
    # keepdims=True so the per-vector norm broadcasts back over `axis`.
    squared_norm = tf.reduce_sum(tf.square(input_vector), axis=axis, keepdims=True)
    scale = tf.divide(squared_norm, tf.add(squared_norm, 1.))
    unit_vector = tf.divide(input_vector, tf.sqrt(tf.add(squared_norm, epsilon)))
    return tf.multiply(scale, unit_vector)

## Convolutional layer

In [None]:
def convolutional(input_data, conv_shape, stride_shape, relu=True):
    """2-D convolution with learned weights and bias, optionally followed by ReLU.

    Args:
        input_data: feature map passed to ``tf.nn.conv2d`` (presumably NHWC,
            TF's default data format -- confirm against the caller).
        conv_shape: filter shape for ``tf.nn.conv2d``,
            ``[filter_height, filter_width, in_channels, out_channels]``;
            ``conv_shape[3]`` also sizes the bias.
        stride_shape: the ``strides`` argument of ``tf.nn.conv2d``.
        relu: if truthy, apply ``tf.nn.relu`` to the biased output.

    Returns:
        The 'VALID'-padded convolution plus bias, ReLU-activated when ``relu`` is set.
    """
    weights = tf.Variable(tf.truncated_normal(conv_shape, stddev=0.3), name='W')
    # One bias per output channel. Random ops need a 1-D shape, so the scalar
    # channel count is wrapped in a list (a bare int is rejected as rank 0).
    bias = tf.Variable(tf.truncated_normal([conv_shape[3]], stddev=0.3), name='B')
    out_layer = tf.nn.conv2d(input_data, weights, stride_shape, padding='VALID')
    out_layer_bias = tf.add(out_layer, bias)

    if relu:
        return tf.nn.relu(out_layer_bias)
    return out_layer_bias

## Primarycaps

In [None]:
def primarycaps(input_data, conv_shape, stride_shape, caps1_shape, caps1_size, caps2_size, pose_size, batch):
    """Build the primary-capsule layer and its predictions for every output capsule.

    The conv output is reshaped into primary capsules and squashed. Each capsule
    is then multiplied by a learned transformation matrix, one per
    (primary capsule, output capsule) pair.

    Args:
        input_data: feature map fed to ``convolutional``.
        conv_shape: conv filter shape forwarded to ``convolutional``.
        stride_shape: conv strides forwarded to ``convolutional``.
        caps1_shape: target shape for ``tf.reshape`` of the conv output. The
            result must be rank 3, ``[batch, caps1_n, caps1_size]``, with
            ``caps1_n`` statically known (e.g. ``[-1, caps1_n, caps1_size]``).
        caps1_size: length of each primary-capsule vector (last reshaped dim).
        caps2_size: number of output capsules.
        pose_size: length of each predicted output-capsule vector.
        batch: batch size used to tile the batch-shared weight matrix.

    Returns:
        ``caps2_predicted`` with shape ``[batch, caps1_n, caps2_size, pose_size, 1]``.

    Raises:
        ValueError: if the number of primary capsules is not statically known.
    """
    output = convolutional(input_data, conv_shape, stride_shape, relu=False)
    caps1_raw = tf.reshape(output, caps1_shape, name='caps1_raw')

    # Number of primary capsules, taken from the static shape rather than
    # hard-coded (was 1152), so other input sizes / conv settings also work.
    caps1_n = caps1_raw.shape.as_list()[1]
    if caps1_n is None:
        raise ValueError(
            'primarycaps: the number of primary capsules (dim 1 of caps1_shape) '
            'must be statically known to size the weight matrix')

    caps1_output = squash(caps1_raw, axis=-1)
    # [batch, caps1_n, caps1_size] -> [batch, caps1_n, 1, caps1_size, 1]: each
    # capsule becomes a column vector, then is copied once per output capsule.
    caps1_output_expand = tf.expand_dims(caps1_output, axis=-1)
    caps1_output_expand2 = tf.expand_dims(caps1_output_expand, axis=2)
    caps1_output_expand2_tiled = tf.tile(caps1_output_expand2, [1, 1, caps2_size, 1, 1], name='caps1_out_tiled')

    # One (pose_size x caps1_size) matrix per (primary, output) capsule pair,
    # shared across the batch and tiled to match it for tf.matmul.
    weight_matrix = tf.Variable(
        tf.truncated_normal([1, caps1_n, caps2_size, pose_size, caps1_size], stddev=0.1),
        name='W_matrix')
    weight_matrix_tiled = tf.tile(weight_matrix, [batch, 1, 1, 1, 1], name='W_matrix_tiled')
    # Batched matmul over the last two dims: [pose_size, caps1_size] @ [caps1_size, 1].
    caps2_predicted = tf.matmul(weight_matrix_tiled, caps1_output_expand2_tiled, name='caps2_predicted')

    return caps2_predicted

## Routing by agreement

In [None]:
def routing_by_agreement(input_data, caps2_size, rounds):
    """Combine predicted capsule outputs with iterative routing-by-agreement.

    Args:
        input_data: predictions for the output capsules, expected shape
            ``[batch, caps1_n, caps2_size, pose_size, 1]`` (the layout produced
            by ``primarycaps``; batch and caps1_n may be dynamic).
        caps2_size: number of output capsules (dim 2 of ``input_data``).
        rounds: number of routing iterations; must be >= 1.

    Returns:
        Squashed output capsules, shape ``[batch, 1, caps2_size, pose_size, 1]``.

    Raises:
        ValueError: if ``rounds`` is less than 1.
    """
    if rounds < 1:
        raise ValueError('routing_by_agreement: rounds must be >= 1, got %r' % (rounds,))

    # Derive batch size and primary-capsule count from the input itself instead
    # of global variables / hard-coded constants.
    input_shape = tf.shape(input_data)
    batch_size = input_shape[0]
    caps1_n = input_shape[1]
    raw_weights = tf.zeros([batch_size, caps1_n, caps2_size, 1, 1],
                           dtype=input_data.dtype, name='raw_weights')

    caps2_output = None
    for i in range(rounds):
        # Softmax over the output-capsule axis: each primary capsule's coupling
        # coefficients sum to 1 across the output capsules.
        routing_weights = tf.nn.softmax(raw_weights, axis=2, name='routing_weights' + str(i))
        weighted_predictions = tf.multiply(routing_weights, input_data, name='weighted_predictions' + str(i))
        weighted_sum = tf.reduce_sum(weighted_predictions, axis=1, keepdims=True, name='weighted_sum' + str(i))
        caps2_output = squash(weighted_sum, axis=-2)

        # The agreement update after the final round would never be read.
        if i == rounds - 1:
            break

        caps2_output_tiled = tf.tile(caps2_output, [1, caps1_n, 1, 1, 1], name='caps2_output_tiled' + str(i))
        # Dot product between each prediction and the current output capsule.
        agreement = tf.matmul(input_data, caps2_output_tiled, transpose_a=True, name='agreement' + str(i))
        raw_weights = tf.add(raw_weights, agreement, name='raw_weights' + str(i))

    return caps2_output