In [65]:
import tensorflow as tf
from keras import layers, models
import numpy as np
from keras.utils import to_categorical
from keras import backend as K
from keras import initializers

In [70]:
#(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar100.load_data()
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.astype('float32') / 255.
x_train = tf.expand_dims(x_train, axis=-1)
y_train =  np.array(to_categorical(y_train.astype('float32')))

x_test = x_test.astype('float32') / 255.
x_test = tf.expand_dims(x_test, axis=-1)
y_test = np.array(to_categorical(y_test.astype('float32')))

num_classes = len(y_test[0])

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [71]:
no_of_secondary_capsules = num_classes

params = {
    "no_of_conv_kernels": 256,
    "no_of_primary_capsules": 64,
    "no_of_secondary_capsules": num_classes,
    "primary_capsule_vector": 8,
    "secondary_capsule_vector": 16,
    "r":3,
}

In [56]:
def squash(s):
  s_norm = tf.norm(s, axis=-1, keepdims=True)
  return tf.square(s_norm)/(1 + tf.square(s_norm)) * s/(s_norm + K.epsilon())

def safe_norm(v, axis=-1):
    v_ = tf.reduce_sum(tf.square(v), axis = axis, keepdims=True)
    return tf.sqrt(v_ + K.epsilon())

def output_layer(v):
   return tf.reshape(safe_norm(v), [-1, no_of_secondary_capsules])

def loss_fn(y_true, y_pred):
 
    L = y_true * K.square(K.maximum(0., 0.9 - y_pred)) + 0.5 * (1 - y_true) * K.square(K.maximum(0., y_pred - 0.1))
 
    return K.mean(K.sum(L, 1))

In [73]:
input_shape = layers.Input(shape=(28,28,1))
convolution = layers.Conv2D(params['no_of_conv_kernels'], 9, activation='relu', kernel_initializer="he_normal")(input_shape)
primary_capsule = layers.Conv2D(params['no_of_primary_capsules'] * params['primary_capsule_vector'], [9,9], strides=[2,2])(convolution)

In [83]:
class CapsuleLayer(layers.Layer):
    def __init__(self, no_of_conv_kernels, no_of_primary_capsules, primary_capsule_vector, no_of_secondary_capsules, secondary_capsule_vector, r):
        super(CapsuleLayer, self).__init__()
        self.no_of_conv_kernels = no_of_conv_kernels
        self.no_of_primary_capsules = no_of_primary_capsules
        self.primary_capsule_vector = primary_capsule_vector
        self.no_of_secondary_capsules = no_of_secondary_capsules
        self.secondary_capsule_vector = secondary_capsule_vector
        self.r = r
        self.kernel_initializer = initializers.get('glorot_uniform')
        
        '''
        with tf.name_scope("Variables") as scope:
            self.convolution = tf.keras.layers.Conv2D(self.no_of_conv_kernels, 24, strides=[1,1], name='ConvolutionLayer', activation='relu')
            self.convolution2 = tf.keras.layers.MaxPooling2D()
            self.convolution3 = tf.keras.layers.Conv2D(self.no_of_conv_kernels, 9, strides=[1,1], activation='relu')
            self.convolution4 = tf.keras.layers.MaxPooling2D()
            self.convolution5 = tf.keras.layers.Conv2D(self.no_of_conv_kernels, 9, strides=[1,1], activation='relu')
            self.convolution6 = tf.keras.layers.MaxPooling2D()
            self.primary_capsule = tf.keras.layers.Conv2D(self.no_of_primary_capsules * self.primary_capsule_vector, [9,9], strides=[2,2], name="PrimaryCapsule")
            self.w = tf.Variable(tf.random_normal_initializer()(shape=[1, 64, self.no_of_secondary_capsules, self.secondary_capsule_vector, self.primary_capsule_vector]), dtype=tf.float32, name="PoseEstimation", trainable=True)
            self.dense_1 = tf.keras.layers.Dense(units = 512, activation='relu')
            self.dense_2 = tf.keras.layers.Dense(units = 1024, activation='relu')
            self.dense_3 = tf.keras.layers.Dense(units = 784, activation='sigmoid', dtype='float32')
            self.preconvolution = tf.keras.layers.Conv2D(3, [28,28], strides=[7, 7], name='ConvolutionLayer', activation='relu', input_shape=(180, 180, 3), padding='same')
        '''
        
    def build(self, input_shape):
        print(input_shape)
        self.w = tf.Variable(tf.random_normal_initializer()(shape=[1, self.no_of_primary_capsules * input_shape[1] * input_shape[2], self.no_of_secondary_capsules, self.secondary_capsule_vector, self.primary_capsule_vector]), dtype=tf.float32, name="PoseEstimation", trainable=True)
        self.b = tf.zeros((1, self.no_of_primary_capsules * input_shape[1] * input_shape[2], self.no_of_secondary_capsules, 1))
        self.built = True

    @tf.function
    def call(self, inputs):
        #input_x, y = inputs
        x = inputs
        #print('call: ', input_x.shape)

        #with tf.name_scope("Encoder") as scope:
            #input_x = self.preconvolution(input_x)

        # input_x.shape: (None, 28, 28, 1)
        # y.shape: (None, 10)
        
        #x = self.convolution(input_x) # x.shape: (None, 20, 20, 256)
        #x = self.convolution2(x)
        #x = self.convolution3(x)
        #x = self.convolution4(x)
        #x = self.convolution5(x)
        #x = self.convolution6(x)
        #x = self.primary_capsule(x) # x.shape: (None, 6, 6, 256)

        print('output x: ', x.shape)
        
        #with tf.name_scope("CapsuleFormation") as scope:
        #w = tf.Variable(tf.random_normal_initializer()(shape=[1, self.no_of_primary_capsules * x.shape[1] * x.shape[2], self.no_of_secondary_capsules, self.secondary_capsule_vector, self.primary_capsule_vector]), dtype=tf.float32, name="PoseEstimation", trainable=True)
        #print('w: ', w.shape)
        u = tf.reshape(x, (-1, self.no_of_primary_capsules * x.shape[1] * x.shape[2], 8)) # u.shape: (None, 1152, 8)
        u = tf.expand_dims(u, axis=-2) # u.shape: (None, 1152, 1, 8)
        u = tf.expand_dims(u, axis=-1) # u.shape: (None, 1152, 1, 8, 1)
        u_hat = tf.matmul(self.w, u) # u_hat.shape: (None, 1152, 10, 16, 1)
        u_hat = tf.squeeze(u_hat, [4]) # u_hat.shape: (None, 1152, 10, 16)

        
        #with tf.name_scope("DynamicRouting") as scope:
        #b = tf.zeros((x.shape[0], self.no_of_primary_capsules * x.shape[1] * x.shape[2], self.no_of_secondary_capsules, 1)) # b.shape: (None, 1152, 10, 1)
        for i in range(self.r): # self.r = 3
            c = tf.nn.softmax(self.b, axis=-2) # c.shape: (None, 1152, 10, 1)
            s = tf.reduce_sum(tf.multiply(c, u_hat), axis=1, keepdims=True) # s.shape: (None, 1, 10, 16)
            v = squash(s) # v.shape: (None, 1, 10, 16)
            agreement = tf.squeeze(tf.matmul(tf.expand_dims(u_hat, axis=-1), tf.expand_dims(v, axis=-1), transpose_a=True), [4]) # agreement.shape: (None, 1152, 10, 1)
            # Before matmul following intermediate shapes are present, they are not assigned to a variable but just for understanding the code.
            # u_hat.shape (Intermediate shape) : (None, 1152, 10, 16, 1)
            # v.shape (Intermediate shape): (None, 1, 10, 16, 1)
            # Since the first parameter of matmul is to be transposed its shape becomes:(None, 1152, 10, 1, 16)
            # Now matmul is performed in the last two dimensions, and others are broadcasted
            # Before squeezing we have an intermediate shape of (None, 1152, 10, 1, 1)
            self.b += agreement
        '''
        with tf.name_scope("Masking") as scope:
            y = tf.expand_dims(y, axis=-1) # y.shape: (None, 10, 1)
            y = tf.expand_dims(y, axis=1) # y.shape: (None, 1, 10, 1)
            mask = tf.cast(y, dtype=tf.float32) # mask.shape: (None, 1, 10, 1)
            v_masked = tf.multiply(mask, v) # v_masked.shape: (None, 1, 10, 16)
            
        with tf.name_scope("Reconstruction") as scope:
            v_ = tf.reshape(v_masked, [-1, self.no_of_secondary_capsules * self.secondary_capsule_vector]) # v_.shape: (None, 160)
            reconstructed_image = self.dense_1(v_) # reconstructed_image.shape: (None, 512)
            reconstructed_image = self.dense_2(reconstructed_image) # reconstructed_image.shape: (None, 1024)
            reconstructed_image = self.dense_3(reconstructed_image) # reconstructed_image.shape: (None, 784)
        '''
        
        print('v: ', v)
        #return v, reconstructed_image
        #v = safe_norm(v)
        #v = tf.reshape(v, [-1, no_of_secondary_capsules])
        return v
    @tf.function
    def predict_capsule_output(self, inputs):
        x = self.convolution(inputs) # x.shape: (None, 20, 20, 256)
        x = self.primary_capsule(x) # x.shape: (None, 6, 6, 256)
        
        with tf.name_scope("CapsuleFormation") as scope:
            u = tf.reshape(x, (-1, self.no_of_primary_capsules * x.shape[1] * x.shape[2], 8)) # u.shape: (None, 1152, 8)
            u = tf.expand_dims(u, axis=-2) # u.shape: (None, 1152, 1, 8)
            u = tf.expand_dims(u, axis=-1) # u.shape: (None, 1152, 1, 8, 1)
            u_hat = tf.matmul(self.w, u) # u_hat.shape: (None, 1152, 10, 16, 1)
            u_hat = tf.squeeze(u_hat, [4]) # u_hat.shape: (None, 1152, 10, 16)

        
        with tf.name_scope("DynamicRouting") as scope:
            b = tf.zeros((inputs.shape[0], 1152, self.no_of_secondary_capsules, 1)) # b.shape: (None, 1152, 10, 1)
            for i in range(self.r): # self.r = 3
                c = tf.nn.softmax(b, axis=-2) # c.shape: (None, 1152, 10, 1)
                s = tf.reduce_sum(tf.multiply(c, u_hat), axis=1, keepdims=True) # s.shape: (None, 1, 10, 16)
                v = self.squash(s) # v.shape: (None, 1, 10, 16)
                agreement = tf.squeeze(tf.matmul(tf.expand_dims(u_hat, axis=-1), tf.expand_dims(v, axis=-1), transpose_a=True), [4]) # agreement.shape: (None, 1152, 10, 1)
                # Before matmul following intermediate shapes are present, they are not assigned to a variable but just for understanding the code.
                # u_hat.shape (Intermediate shape) : (None, 1152, 10, 16, 1)
                # v.shape (Intermediate shape): (None, 1, 10, 16, 1)
                # Since the first parameter of matmul is to be transposed its shape becomes:(None, 1152, 10, 1, 16)
                # Now matmul is performed in the last two dimensions, and others are broadcasted
                # Before squeezing we have an intermediate shape of (None, 1152, 10, 1, 1)
                b += agreement
        return v

    @tf.function
    def regenerate_image(self, inputs):
        with tf.name_scope("Reconstruction") as scope:
            v_ = tf.reshape(inputs, [-1, self.no_of_secondary_capsules * self.secondary_capsule_vector]) # v_.shape: (None, 160)
            reconstructed_image = self.dense_1(v_) # reconstructed_image.shape: (None, 512)
            reconstructed_image = self.dense_2(reconstructed_image) # reconstructed_image.shape: (None, 1024)
            reconstructed_image = self.dense_3(reconstructed_image) # reconstructed_image.shape: (None, 784)
        return reconstructed_image

In [84]:
digit_caps = CapsuleLayer(**params)(primary_capsule)
outputs = layers.Lambda(output_layer)(digit_caps)

(None, 6, 6, 512)
output x:  (None, 6, 6, 512)
v:  Tensor("truediv_5:0", shape=(None, 1, 10, 16), dtype=float32)


In [85]:
model = models.Model(input_shape, outputs)
m = 128
epochs = 2
model.compile(optimizer='adam',loss=loss_fn,loss_weights = [1. ,0.0005],metrics=['accuracy'])
model.fit(x_train[:32], y_train[:32], epochs = epochs, validation_data = (x_test, y_test))

Epoch 1/2

KeyboardInterrupt: ignored