# In [1]:
import numpy as np
import tensorflow as tf

# (notebook output residue, presumably the tail of a warning printed during
#  import:)   return f(*args, **kwds)


# In [13]:
class CapsLayer:

    '''
    Fully connected capsule layer (forward pass) with iterative routing.

    The layer maps `count1` input capsules of dimension `dim1` onto
    `count2` output capsules of dimension `dim2`. Every pair of
    (input capsule, output capsule) owns a `dim1 x dim2` weight matrix
    that turns the input capsule into a "prediction vector" for the
    output capsule. The prediction vectors are then combined by
    `rout_iter` rounds of routing.

    Attributes:
        count1, dim1: number and dimension of the input capsules.
        count2, dim2: number and dimension of the output capsules.
        rout_iter:    number of routing iterations (at least 1).
        weights:      tf.Variable of shape (count1, count2, dim1, dim2);
                      None until the layer is called for the first time.
        batch_size:   scalar int tensor holding the batch size of the most
                      recent input; None until the layer has been called.
    '''

    def __init__(self, count1, dim1, count2, dim2, rout_iter):
        '''
        Store the layer configuration.

        Raises:
            ValueError: if `rout_iter` is smaller than 1, because routing
                        would then never produce an output.
        '''
        if rout_iter < 1:
            raise ValueError(
                "rout_iter must be at least 1, got {}".format(rout_iter)
            )

        # assigning the given parameters for the layers
        self.count1 = count1
        self.dim1 = dim1
        self.count2 = count2
        self.dim2 = dim2
        self.rout_iter = rout_iter

        # created lazily on the first call so that repeated calls of the
        # same layer share one set of weights
        self.weights = None
        self.batch_size = None

    # NOTE: the parameter name `input` shadows the builtin, but it is part
    # of the public signature and therefore kept.
    def __call__(self, input):
        '''
        Compute the output capsules for `input`.

        Args:
            input: 3-D tensor with shape (batch_size, count1, dim1). The
                   batch size may be unknown at graph-construction time.

        Returns:
            Tensor with shape (batch_size, count2, dim2).

        Raises:
            ValueError: if the statically known shape of `input` does not
                        match (batch_size, count1, dim1).
        '''
        c1, d1 = self.count1, self.dim1
        c2, d2 = self.count2, self.dim2

        # check the static shape as far as it is known; unknown dimensions
        # (None) cannot be checked here and are accepted
        static_shape = input.get_shape()
        if static_shape.ndims is not None:
            dims = static_shape.as_list()
            if (len(dims) != 3
                    or dims[1] not in (None, c1)
                    or dims[2] not in (None, d1)):
                raise ValueError(
                    "expected input of shape (batch_size, {}, {}), "
                    "got {}".format(c1, d1, dims)
                )

        # dynamic batch size: also valid when the static one is unknown
        self.batch_size = tf.shape(input)[0]

        # creating the weight tensor (only once per layer)
        if self.weights is None:
            self.weights = tf.Variable(
                tf.truncated_normal(
                    shape=[c1, c2, d1, d2],
                    stddev=0.1
                )
            )

        prediction_vectors = self.predict_vectors(input)

        return self.routing(prediction_vectors)

    def predict_vectors(self, inp):
        '''
        Compute the prediction vector of every input capsule for every
        output capsule, i.e. the product of each (dim1 x dim2) weight
        matrix (transposed) with the matching input capsule.

        Requires `self.weights` to be set (done by `__call__`).

        Args:
            inp: tensor with shape (batch_size, count1, dim1).

        Returns:
            Tensor with shape (batch_size, count1, count2, dim2).
        '''
        c1, d1 = self.count1, self.dim1

        # (batch, c1, d1) -> (batch, c1, 1, d1, 1); -1 keeps the batch
        # dimension dynamic
        inp = tf.reshape(inp, shape=[-1, c1, 1, d1, 1])

        # (c1, c2, d1, d2) -> (1, c1, c2, d1, d2)
        weights = tf.expand_dims(self.weights, axis=0)

        # Broadcasting replaces the explicit tiling over batch and c2.
        # Summing the product over the d1 axis equals W^T u per capsule
        # pair and leaves exactly (batch, c1, c2, d2); no squeeze is needed,
        # so axes of size 1 (e.g. batch_size == 1) are never dropped.
        return tf.reduce_sum(weights * inp, axis=3)

    def routing(self, prediction_vectors):
        '''
        Run `rout_iter` routing iterations and return the capsules of the
        next layer.

        Args:
            prediction_vectors: tensor with shape
                                (batch_size, count1, count2, dim2).

        Returns:
            Tensor with shape (batch_size, count2, dim2).
        '''
        # routing logits, one per (input capsule, output capsule) pair:
        # shape (batch, c1, c2), initialised with zeros
        logits = tf.zeros_like(prediction_vectors[:, :, :, 0])

        activation = None
        for i in range(self.rout_iter):

            # coupling coefficients: softmax over the output capsules
            # (last axis), so each input capsule distributes a total of 1
            coupling_coeffs = tf.nn.softmax(logits)

            # (batch, c1, c2) -> (batch, c1, c2, 1) so that it broadcasts
            # over the d2 components of the prediction vectors
            coupling_coeffs = tf.expand_dims(coupling_coeffs, axis=-1)

            # weighted sum over the input capsules -> (batch, c2, d2)
            drive = tf.reduce_sum(
                coupling_coeffs * prediction_vectors,
                axis=1
            )

            activation = self.squash(drive, axis=-1)

            # if it is not the last iteration compute the
            # agreement and update the logits
            if i != self.rout_iter - 1:

                # agreement = scalar product between every prediction
                # vector and the current output of its target capsule
                agreement = tf.reduce_sum(
                    prediction_vectors * tf.expand_dims(activation, axis=1),
                    axis=-1
                )

                # update the logits
                logits = logits + agreement

        return activation

    # squashing function from paper
    @staticmethod
    def squash(tensor, axis=2, epsilon=1e-7):
        '''
        Shrink the vectors along `axis` to a length in [0, 1) while keeping
        their direction: v = |s|^2 / (1 + |s|^2) * s / |s|.

        Args:
            tensor:  tensor holding the vectors to squash.
            axis:    axis along which the vectors lie (default 2, the last
                     axis of a (batch, count, dim) tensor).
            epsilon: small constant added under the square root so that
                     zero-length vectors give 0 instead of NaN.

        Returns:
            Tensor with the same shape as `tensor`.
        '''
        # squared euclidean norm along `axis`, kept as an axis of size 1
        # so it broadcasts against `tensor`
        squared_norm = tf.expand_dims(
            tf.reduce_sum(tf.square(tensor), axis=axis),
            axis=axis
        )
        normed_tensor = tensor / tf.sqrt(squared_norm + epsilon)

        squashing_factor = squared_norm / (1 + squared_norm)

        return squashing_factor * normed_tensor

                                   
            
            
            
            
            
            
    
    
    
    
    
    
        
    