In [1]:
import torch
import torchvision.datasets as dset

In [2]:
# Fetch the MNIST training split (downloaded into `root` on first run).
mnist_dataset = dset.MNIST(root='./minst/', train=True, download=True)

# Report how many training images were loaded.
num_train_images = len(mnist_dataset)
print("Number of MNIST training samples: {} images".format(num_train_images))

Number of MNIST training samples: 60000 images


In [6]:
import keras.backend as K
import tensorflow as tf
from keras import initializers, layers

class Length(layers.Layer):
    """
    Compute the Euclidean length (L2 norm) of vectors along the last axis.

    Inputs:  - Tensor [None, num_vectors, dim_vector]
    Outputs: - Tensor [None, num_vectors]
    """

    def call(self, inputs, **kwargs):
        # Square each component, sum over the last axis, then take the root.
        # K.epsilon() keeps the gradient of sqrt finite for all-zero vectors.
        return K.sqrt(K.sum(K.square(inputs), -1) + K.epsilon())

    def compute_output_shape(self, input_shape):
        # The last (vector) axis is reduced away by the sum in `call`.
        return input_shape[:-1]
    
class Mask(layers.Layer):
    """
    Keep a single capsule of a [None, num_capsule, dim_vector] tensor, zero the
    others, and flatten the result to [None, num_capsule * dim_vector].

    The capsule to keep is selected either by an explicit mask supplied as the
    second element of a two-item input list, or, when a single tensor is given,
    by picking the capsule with the greatest length.
    """
    def call(self, inputs, **kwargs):
        if type(inputs) is not list:
            # No mask supplied: build a one-hot mask from the longest capsule.
            lengths = K.sqrt(K.sum(K.square(inputs), -1))
            longest = K.argmax(lengths, 1)
            mask = K.one_hot(indices=longest,
                             num_classes=lengths.get_shape().as_list()[1])
        else:
            # The mask (e.g. the true label, shape=[None, n_classes]) comes with the inputs.
            assert len(inputs) == 2
            inputs, mask = inputs

        # Broadcast the mask over the vector axis, then flatten per sample.
        return K.batch_flatten(inputs * K.expand_dims(mask, -1))

    def compute_output_shape(self, input_shape):
        # With an explicit mask, input_shape is a list whose first entry is the capsule shape.
        capsule_shape = input_shape[0] if type(input_shape[0]) is tuple else input_shape
        return (None, capsule_shape[1] * capsule_shape[2])
        
def squash(vectors, axis=-1):
    """
    Non-linear activation used within capsules. It drives the length of a large
    vector to near 1 and the length of a small vector to near 0, while keeping
    the vector's direction.

    Args:
        vectors: tensor of any rank holding the vectors to squash
        axis:    the axis along which the vectors lie

    Returns:
        a tensor with the same shape as `vectors`
    """
    # Squared magnitude of each vector: sum of its squared components.
    magnitude_squared = K.sum(K.square(vectors), axis, keepdims=True)
    scaling_factor = magnitude_squared / (1 + magnitude_squared)
    # K.epsilon() keeps the normalisation finite for all-zero vectors, which
    # would otherwise evaluate to 0 / 0 = NaN and poison the whole batch.
    unit_vectors = vectors / K.sqrt(magnitude_squared + K.epsilon())
    return scaling_factor * unit_vectors


Using TensorFlow backend.


In [32]:
class CapsuleLayer(layers.Layer):
    """
    Capsule layer with routing between input and output capsules. Similar to a
    Dense layer, except that each unit outputs a vector instead of a scalar.

    inputs  = [None, input_num_capsule, input_dim_capsule]
    outputs = [None, num_capsule, dim_capsule]

    Args:
        num_capsule:        number of capsules in this layer
        dim_capsule:        dimension of each output capsule vector
        num_routing:        number of routing iterations (must be > 0)
        kernel_initializer: initializer for the transform matrices
    """
    def __init__(self, num_capsule, dim_capsule, num_routing=3,
                 kernel_initializer='glorot_uniform', **kwargs):
        super(CapsuleLayer, self).__init__(**kwargs)
        self.num_capsule = num_capsule
        self.dim_capsule = dim_capsule
        self.num_routing = num_routing
        self.kernel_initializer = initializers.get(kernel_initializer)

    def build(self, input_shape):
        assert len(input_shape) >= 3, "The input Tensor should have shape=[None, input_num_capsule, input_dim_capsule]"
        self.input_num_capsule = input_shape[1]
        self.input_dim_capsule = input_shape[2]

        # Transform matrices: one [dim_capsule, input_dim_capsule] matrix for
        # every (output capsule, input capsule) pair.
        self.W = self.add_weight(shape=[self.num_capsule, self.input_num_capsule,
                                        self.dim_capsule, self.input_dim_capsule],
                                 initializer=self.kernel_initializer, name='w')
        self.built = True

    def call(self, inputs, training=None):
        # inputs.shape=[None, input_num_capsule, input_dim_capsule]
        # Insert a new axis 1 so the inputs can be replicated once per output capsule.
        # inputs_expand.shape=[None, 1, input_num_capsule, input_dim_capsule]
        inputs_expand = K.expand_dims(inputs, 1)

        # Replicate along the num_capsule axis to prepare for multiplication by W.
        # inputs_tiled.shape=[None, num_capsule, input_num_capsule, input_dim_capsule]
        inputs_tiled = K.tile(inputs_expand, [1, self.num_capsule, 1, 1])

        # Compute `inputs * W` by scanning inputs_tiled on dimension 0.
        # Regard the first two dims as `batch` dims,
        # then matmul: [input_dim_capsule] x [dim_capsule, input_dim_capsule]^T -> [dim_capsule].
        # inputs_hat.shape=[None, num_capsule, input_num_capsule, dim_capsule]
        inputs_hat = K.map_fn(lambda x: K.batch_dot(x, self.W, [2, 3]), elems=inputs_tiled)

        # Begin: Routing algorithm ---------------------------------------------------------------------#
        # The routing iterations only update the coupling logits `b`, so they use a
        # copy of `inputs_hat` through which no gradient flows.
        inputs_hat_stopped = K.stop_gradient(inputs_hat)

        # Routing logits start at zero.
        # b.shape=[batch_size, num_capsule, input_num_capsule]
        b = tf.zeros(shape=[K.shape(inputs_hat)[0], self.num_capsule, self.input_num_capsule])

        assert self.num_routing > 0, 'The num_routing should be > 0.'
        for i in range(self.num_routing):
            # Softmax over axis 1 (the num_capsule axis). The axis is passed
            # positionally because its keyword is `dim` in old TensorFlow releases
            # and `axis` in newer ones.
            c = tf.nn.softmax(b, 1)

            # At the last iteration, use `inputs_hat` to compute `outputs` in order to backpropagate gradient.
            if i == self.num_routing - 1:
                # c.shape=[batch_size, num_capsule, input_num_capsule]
                # inputs_hat.shape=[None, num_capsule, input_num_capsule, dim_capsule]
                # The first two dimensions act as `batch` dimensions,
                # then matmul: [input_num_capsule] x [input_num_capsule, dim_capsule] -> [dim_capsule].
                # outputs.shape=[None, num_capsule, dim_capsule]
                outputs = squash(K.batch_dot(c, inputs_hat, [2, 2]))
            else:  # Otherwise, use `inputs_hat_stopped` to update `b`. No gradients flow on this path.
                outputs = squash(K.batch_dot(c, inputs_hat_stopped, [2, 2]))

                # outputs.shape=[None, num_capsule, dim_capsule]
                # inputs_hat.shape=[None, num_capsule, input_num_capsule, dim_capsule]
                # The first two dimensions act as `batch` dimensions,
                # then matmul: [dim_capsule] x [input_num_capsule, dim_capsule]^T -> [input_num_capsule].
                # b.shape=[batch_size, num_capsule, input_num_capsule]
                b += K.batch_dot(outputs, inputs_hat_stopped, [2, 3])
        # End: Routing algorithm -----------------------------------------------------------------------#

        return outputs

    def compute_output_shape(self, input_shape):
        return tuple([None, self.num_capsule, self.dim_capsule])

    
def PrimaryCap(inputs, dim_capsule, n_channels, kernel_size, strides, padding):
    """
    Build the primary capsules: one Conv2D producing `dim_capsule * n_channels`
    feature maps, reshaped into vectors of length `dim_capsule` and squashed.

    :param inputs: 4D tensor, shape=[None, width, height, channels]
    :param dim_capsule: the dim of the output vector of capsule
    :param n_channels: the number of types of capsules
    :param kernel_size: forwarded to Conv2D
    :param strides: forwarded to Conv2D
    :param padding: forwarded to Conv2D
    :return: output tensor, shape=[None, num_capsule, dim_capsule]
    """
    # A single convolution yields the components of every capsule type at once.
    conv = layers.Conv2D(
        filters=dim_capsule*n_channels,
        kernel_size=kernel_size,
        strides=strides,
        padding=padding,
        name='primarycap_conv2d',
    )
    feature_maps = conv(inputs)

    # Regroup the feature maps into capsule vectors of length `dim_capsule`.
    to_capsules = layers.Reshape(target_shape=[-1, dim_capsule],
                                 name='primarycap_reshape')
    capsules = to_capsules(feature_maps)

    # Apply the squash non-linearity to each capsule vector.
    return layers.Lambda(squash, name='primarycap_squash')(capsules)
