In [None]:
import tensorflow as tf
import numpy as np

In [None]:
from tensorflow.keras import layers, initializers, models, optimizers, callbacks
from tensorflow.keras import backend as K

In [None]:
def squash(vectors, axis=-1):
    """Capsule non-linearity: shrink each vector's length into [0, 1) while keeping its direction.

    Computes ``|s|^2 / (1 + |s|^2) * s / |s|`` along ``axis``; ``K.epsilon()``
    inside the square root guards against division by zero for all-zero vectors.

    Args:
        vectors: tensor of capsule vectors.
        axis: axis holding the vector components (default: last).

    Returns:
        Tensor of the same shape as ``vectors``.
    """
    squared_norm = K.sum(K.square(vectors), axis, keepdims=True)
    norm = K.sqrt(squared_norm + K.epsilon())
    return squared_norm / (1 + squared_norm) / norm * vectors

In [None]:
class CapsuleLayer(layers.Layer):
    """Fully connected capsule layer with dynamic routing-by-agreement.

    Every input capsule i predicts every output capsule j through its own
    transform matrix ``W[i, j]``. The predictions are then combined over
    ``num_routing`` iterations of routing-by-agreement (Sabour et al., 2017).

    Input shape:  [batch, input_num_capsule, input_dim_vector]
    Output shape: [batch, num_capsule, dim_vector]

    Args:
        num_capsule: number of output capsules.
        dim_vector: length of each output capsule vector.
        num_routing: number of routing iterations; must be > 0.
        kernel_initializer: initializer for the transform matrices ``W``.
        bias_initializer: initializer for ``bias``, the non-trainable prior
            routing logits that every forward pass starts from. With the
            default ``'zeros'``, each input capsule initially couples
            uniformly to all output capsules.
    """

    def __init__(self, num_capsule, dim_vector, num_routing=3,
                 kernel_initializer='glorot_uniform',
                 bias_initializer='zeros',
                 **kwargs):
        super(CapsuleLayer, self).__init__(**kwargs)
        self.num_capsule = num_capsule
        self.dim_vector = dim_vector
        self.num_routing = num_routing
        self.kernel_initializer = initializers.get(kernel_initializer)
        self.bias_initializer = initializers.get(bias_initializer)

    def build(self, input_shape):
        # input_shape = [batch, input_num_capsule, input_dim_vector]; both
        # capsule dimensions must be statically known to size W.
        self.input_num_capsule = input_shape[1]
        self.input_dim_vector = input_shape[2]

        # One transform matrix per (input capsule, output capsule) pair.
        self.W = self.add_weight(shape=[self.input_num_capsule, self.num_capsule,
                                        self.input_dim_vector, self.dim_vector],
                                 initializer=self.kernel_initializer,
                                 name='W')

        # Prior routing logits, shared by all examples. This is kept as a
        # (non-trainable) weight so previously saved weights still load. It is
        # only READ in call(); the per-example routing state lives in a local
        # tensor and is never written back here.
        self.bias = self.add_weight(shape=[1, self.input_num_capsule, self.num_capsule, 1, 1],
                                    initializer=self.bias_initializer,
                                    name='bias',
                                    trainable=False)
        super(CapsuleLayer, self).build(input_shape)

    def call(self, inputs, training=None):
        # Prediction vectors u_hat[b, i, j] = u[b, i] @ W[i, j]:
        #   [batch, in_caps, in_dim] x [in_caps, num_caps, in_dim, dim]
        #   -> [batch, in_caps, num_caps, dim]
        # A singleton axis is then inserted so routing works on the 5-D layout
        # [batch, in_caps, num_caps, 1, dim].
        inputs_hat = tf.einsum('bid,ijde->bije', inputs, self.W)
        inputs_hat = K.expand_dims(inputs_hat, 3)

        assert self.num_routing > 0, 'The num_routing should be > 0'

        # Routing logits b, shape [batch, in_caps, num_caps, 1, 1]. They are
        # recomputed from the prior on every forward pass so that routing
        # state never leaks between examples or batches.
        logits = self.bias + tf.zeros_like(inputs_hat[..., :1])
        for i in range(self.num_routing):
            # Coupling coefficients: softmax over the output-capsule axis.
            c = tf.nn.softmax(logits, axis=2)
            # Weighted sum over input capsules -> [batch, 1, num_caps, 1, dim].
            outputs = squash(K.sum(c * inputs_hat, 1, keepdims=True))

            # The last iteration's logits would never be used, so skip the update.
            if i != self.num_routing - 1:
                # Agreement = dot product between each prediction and the output.
                logits = logits + K.sum(inputs_hat * outputs, -1, keepdims=True)

        return K.reshape(outputs, [-1, self.num_capsule, self.dim_vector])

    def compute_output_shape(self, input_shape):
        return tuple([None, self.num_capsule, self.dim_vector])

In [None]:
class Length(layers.Layer):
    """Replace each capsule vector by its Euclidean length (norm over the last axis).

    Input shape:  [..., dim_vector]
    Output shape: [...]
    """

    def call(self, inputs, **kwargs):
        squared_norm = K.sum(K.square(inputs), -1)
        return K.sqrt(squared_norm)

    def compute_output_shape(self, input_shape):
        # Drop the vector axis.
        return input_shape[:-1]

In [None]:
class Mask(layers.Layer):
    """Keep a single capsule and return its vector.

    Call with either:
      * ``[capsules, labels]``: ``labels`` (one-hot, shape [batch, num_capsule])
        selects which capsule to keep. This form is used when the true label is
        fed to the decoder.
      * ``capsules`` alone: the capsule with the largest length is kept.

    ``capsules`` has shape [batch, num_capsule, dim_vector]; the output has
    shape [batch, dim_vector].
    """

    def call(self, inputs, **kwargs):
        if isinstance(inputs, (list, tuple)):
            # True label is provided, shape=[batch_size, n_classes].
            assert len(inputs) == 2
            inputs, mask = inputs
        else:
            # No label: select the capsule whose vector is longest.
            # Capsule lengths, shape [batch_size, num_capsule].
            lengths = K.sqrt(K.sum(K.square(inputs), -1))
            # Stretch so the max length maps to 1 and every other entry becomes
            # very negative, then clip to [0, 1] -> (near) one-hot mask.
            x = (lengths - K.max(lengths, 1, True)) / K.epsilon() + 1
            mask = K.clip(x, 0, 1)

        # Contract over the capsule axis; with a one-hot mask this picks out a
        # single capsule. Shape [batch_size, dim_vector].
        inputs_masked = K.batch_dot(inputs, mask, [1, 1])
        return inputs_masked

    def compute_output_shape(self, input_shape):
        if isinstance(input_shape, (list, tuple)) and \
                isinstance(input_shape[0], (list, tuple, tf.TensorShape)):
            # True label provided: input_shape = [capsules_shape, labels_shape].
            return tuple([None, input_shape[0][-1]])
        return tuple([None, input_shape[-1]])

In [None]:
def PrimaryCap(inputs, dim_vector, n_channels, kernel_size, strides, padding):
    """Primary capsules: a Conv2D whose feature maps are regrouped into capsule vectors.

    The convolution emits ``dim_vector * n_channels`` filters. Its output is
    reshaped to [batch, num_capsules, dim_vector] and squashed.

    Args:
        inputs: 4-D feature map [batch, height, width, channels].
        dim_vector: length of each primary capsule vector.
        n_channels: number of capsule channels per spatial position.
        kernel_size, strides, padding: forwarded to ``layers.Conv2D``.

    Returns:
        Squashed capsules, shape [batch, num_capsules, dim_vector].
    """
    conv = layers.Conv2D(filters=dim_vector*n_channels, kernel_size=kernel_size,
                         strides=strides, padding=padding,
                         name='primaryCap_conv2d')(inputs)

    # Spell out the capsule count instead of using -1 so the static shape
    # stays known: CapsuleLayer.build sizes its weights from input_shape[1].
    height, width, depth = conv.shape[1], conv.shape[2], conv.shape[3]
    n_capsules = height * width * depth // dim_vector
    capsules = layers.Reshape(target_shape=(n_capsules, dim_vector),
                              name='primaryCap_reshape')(conv)

    return layers.Lambda(squash, name='primarycap_squash')(capsules)

In [None]:
from tensorflow.keras.utils import to_categorical
def CapsNet(input_shape, n_class, num_routing):
    """Build the two-input / two-output CapsNet training model.

    Inputs:  [image, one-hot label]
    Outputs: [capsule lengths (``out_caps``), reconstructed image (``out_recon``)]

    The decoder reconstructs the image from the digit capsule selected by the
    supplied label, so the label input is required whenever the model is called.

    Args:
        input_shape: image shape, e.g. [28, 28, 1].
        n_class: number of digit capsules / classes.
        num_routing: routing iterations for the digit-capsule layer.
    """
    # Encoder.
    image = layers.Input(shape=input_shape)
    features = layers.Conv2D(filters=256, kernel_size=9, strides=1, padding='valid',
                             activation='relu', name='conv1')(image)
    primary = PrimaryCap(features, dim_vector=8, n_channels=32, kernel_size=9,
                         strides=2, padding='valid')
    digit = CapsuleLayer(num_capsule=n_class, dim_vector=16, num_routing=num_routing,
                         name='digitcaps')(primary)
    class_scores = Length(name='out_caps')(digit)

    # Decoder: reconstruct the image from the capsule chosen by the label.
    label = layers.Input(shape=(n_class,))
    hidden = Mask()([digit, label])
    for units in (512, 1024):
        hidden = layers.Dense(units, activation='relu')(hidden)
    flat_recon = layers.Dense(np.prod(input_shape), activation='sigmoid')(hidden)
    reconstruction = layers.Reshape(target_shape=input_shape, name='out_recon')(flat_recon)

    return models.Model([image, label], [class_scores, reconstruction])

In [None]:
def margin_loss(y_true, y_pred):
    """CapsNet margin loss (Sabour et al., 2017).

    Uses m+ = 0.9, m- = 0.1 and down-weights absent classes by 0.5. The loss is
    summed over classes (axis 1) and averaged over the batch.

    Args:
        y_true: one-hot labels, shape [batch, n_class].
        y_pred: capsule lengths, shape [batch, n_class].
    """
    present = y_true * K.square(K.maximum(0., 0.9 - y_pred))
    absent = 0.5 * (1 - y_true) * K.square(K.maximum(0., y_pred - 0.1))
    per_class = present + absent
    return K.mean(K.sum(per_class, 1))

def train(model, data, lr=0.001, lam_recon=0.39,
          batch_size=100, epochs=10):
    """Compile and fit a CapsNet model built by ``CapsNet``.

    Args:
        model: model with inputs [image, label] and outputs [out_caps, out_recon].
        data: ((x_train, y_train), (x_test, y_test)) with one-hot labels.
        lr: Adam learning rate.
        lam_recon: weight of the reconstruction (MSE) loss relative to the margin loss.
        batch_size: mini-batch size.
        epochs: number of training epochs.

    Returns:
        The ``History`` object returned by ``model.fit``.
    """
    (x_train, y_train), (x_test, y_test) = data

    # `learning_rate` is the supported name; the old `lr` alias is rejected by
    # newer Keras optimizers.
    model.compile(optimizer=optimizers.Adam(learning_rate=lr),
                  loss=[margin_loss, 'mse'],
                  loss_weights=[1., lam_recon],
                  metrics={'out_caps': 'accuracy'})
    # The labels are also fed as an input because the decoder masks by the
    # true label. validation_data must be a tuple (x_val, y_val), not a list.
    return model.fit([x_train, y_train], [y_train, x_train],
                     batch_size=batch_size,
                     epochs=epochs,
                     validation_data=([x_test, y_test], [y_test, x_test]))
    
def load_mnist():
    """Load MNIST as float images in [0, 1] and one-hot labels.

    Returns:
        ((x_train, y_train), (x_test, y_test)). Images have shape [N, 28, 28, 1]
        (float32); labels are one-hot encoded via ``to_categorical``.
    """
    from tensorflow.keras.datasets import mnist

    def _prep_images(raw):
        return raw.reshape(-1, 28, 28, 1).astype('float32') / 255.

    def _prep_labels(raw):
        return to_categorical(raw.astype('float32'))

    (train_raw, train_digits), (test_raw, test_digits) = mnist.load_data()
    return ((_prep_images(train_raw), _prep_labels(train_digits)),
            (_prep_images(test_raw), _prep_labels(test_digits)))

In [None]:
# Seed numpy and TensorFlow so weight initialisation and batch shuffling repeat
# on Restart & Run All (some GPU kernels may still be nondeterministic).
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

(x_train, y_train), (x_test, y_test) = load_mnist()
model = CapsNet(input_shape=[28, 28, 1], n_class=10, num_routing=3)
model.summary()
# Keep the training history instead of letting it be the cell's display output.
history = train(model=model, data=((x_train, y_train), (x_test, y_test)))