In [11]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
import tensorflow_probability as tfp
from tensorflow_probability import distributions as tfd

%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt

from scipy import random

import os
import time

In [12]:
@tf.function
def squash(x, axis=-1):
    """Apply the capsule "squash" non-linearity to ``x`` along ``axis``.

    With ``q`` the sum of squares of ``x`` over ``axis`` plus
    ``keras.backend.epsilon()``, every element is multiplied by
    ``sqrt(q) / (1 + q)``.

    Args:
        x: Tensor to rescale.
        axis: Axis over which the squared norm is accumulated.

    Returns:
        A tensor with the same shape as ``x``.
    """
    # keepdims=True so the scale factor broadcasts back against x.
    norm_sq = tf.reduce_sum(tf.square(x), axis=axis, keepdims=True)
    norm_sq = norm_sq + keras.backend.epsilon()
    return x * (tf.sqrt(norm_sq) / (1 + norm_sq))

@tf.function
def margin_loss(y_true, y_pred):
    """Margin loss summed over the last axis.

    For each entry the loss is
    ``y_true * relu(1 - margin - y_pred)^2
    + lamb * (1 - y_true) * relu(y_pred - margin)^2``
    with ``lamb = 0.5`` and ``margin = 0.1``.

    Args:
        y_true: Target tensor (used as a multiplicative mask).
        y_pred: Predicted tensor, broadcast-compatible with ``y_true``.

    Returns:
        The per-entry loss reduced with a sum over ``axis=-1``.
    """
    lamb, margin = 0.5, 0.1
    # Penalty where the target is set: prediction below the upper margin.
    positive_term = y_true * tf.square(tf.nn.relu(1 - margin - y_pred))
    # Down-weighted penalty where the target is unset: prediction above the lower margin.
    negative_term = lamb * (1 - y_true) * tf.square(tf.nn.relu(y_pred - margin))
    return tf.reduce_sum(positive_term + negative_term, axis=-1)

#@tf.function
def safe_norm(s, axis=-1, epsilon=1e-7, keep_dims=False):
    """Return ``sqrt(sum(s^2) + epsilon)`` reduced over ``axis``.

    Args:
        s: Input tensor.
        axis: Axis (or axes) to reduce over.
        epsilon: Constant added under the square root.
        keep_dims: Forwarded to the reduction as ``keepdims``.

    Returns:
        Tensor of norms.
    """
    sum_of_squares = tf.math.reduce_sum(
        tf.math.square(s), axis=axis, keepdims=keep_dims
    )
    return tf.math.sqrt(sum_of_squares + epsilon)

In [13]:
class Capsule(keras.layers.Layer):
    """Capsule layer that mixes input capsules into ``num_capsule`` outputs.

    A trainable logit tensor ``R`` of shape ``[1, n_in, num_capsule]`` is
    soft-maxed over the input-capsule axis. Each output capsule is the
    resulting weighted sum of the input capsule vectors, passed through
    ``squash``.

    Args:
        num_capsule: Number of output capsules.
        dim_capsule: Dimension reported for each output capsule in
            ``compute_output_signature``.
        **kwargs: Forwarded to ``keras.layers.Layer``.

    NOTE(review): ``call`` keeps the last dimension of its input, so
    ``dim_capsule`` is expected to equal that dimension; this is not
    enforced here.
    """

    def __init__(self,
                 num_capsule,
                 dim_capsule,
                 **kwargs):
        super().__init__(**kwargs)
        self.caps_n = num_capsule
        self.caps_dim = dim_capsule

    def get_config(self):
        """Return the layer config, keyed by the constructor argument names."""
        config = super().get_config().copy()
        config.update({
            'num_capsule': self.caps_n,
            'dim_capsule': self.caps_dim,
        })
        return config

    def build(self, input_shape):
        """Create the routing logits ``R`` with shape ``[1, input_shape[1], num_capsule]``."""
        self.R = self.add_weight(name='R',
                                 shape=[1, input_shape[1], self.caps_n],
                                 dtype=tf.float32,
                                 initializer='glorot_uniform',
                                 trainable=True)

    def call(self, input_tensor):
        """Combine the input capsules into ``num_capsule`` squashed outputs.

        Args:
            input_tensor: Rank-3 tensor indexed as ``[batch, n_in, dim]``.

        Returns:
            Tensor indexed as ``[batch, num_capsule, dim]``.
        """
        # Normalise the logits over the input-capsule axis. This is done on R
        # directly rather than on a per-batch tiled copy, so no static batch
        # size is needed.
        routing = tf.nn.softmax(self.R, axis=1)

        # Weighted sum over the input capsules for every output capsule:
        # [batch, n, d] x [n, k] -> [batch, k, d]. This avoids materialising
        # a [batch, n, d, k] intermediate.
        caps = tf.einsum('bnd,nk->bkd', input_tensor, routing[0])
        return squash(caps)

    def compute_output_signature(self, input_shape):
        """Return the output ``tf.TensorSpec``.

        Args:
            input_shape: Either a shape-like sequence or an object exposing
                ``.shape`` (such as ``tf.TensorSpec``).
        """
        # Accept a spec object as well as a plain shape.
        shape = getattr(input_shape, 'shape', input_shape)
        return tf.TensorSpec(shape=[shape[0], self.caps_n, self.caps_dim],
                             dtype=tf.float32)

In [14]:
# Layer instances used to assemble the model.
c1 = tf.keras.layers.Conv2D(
    filters=16, kernel_size=5, strides=1, padding='valid', activation='relu'
)
c2 = tf.keras.layers.Conv2D(
    filters=32, kernel_size=9, strides=1, padding='valid', activation='relu'
)
bn1 = tf.keras.layers.BatchNormalization()
bn2 = tf.keras.layers.BatchNormalization()
# Output capsule layer: 10 capsules, declared dimension 16.
last = Capsule(num_capsule=10, dim_capsule=16)

In [15]:
# Batch dimension is fixed at 32; fit/predict must feed batches of exactly
# this size.
model_input = keras.Input(shape=(28, 28, 1), batch_size=32)
x = c1(model_input)
# No hard-coded training=True: Keras supplies the training flag, so batch
# normalisation uses batch statistics while fitting and its moving averages
# during validation and inference.
x = bn1(x)
x = c2(x)
x = bn2(x)
# Regroup the flattened feature map into 16*32 = 512 vectors of length 16.
x = tf.reshape(x, [-1, 16 * 32, 16])
x = last(x)
# The norm of each output capsule vector is the model's per-class output.
x = safe_norm(x, axis=2)
model_output = x


In [16]:
# Wrap the functional graph into a trainable Keras model.
model = keras.Model(inputs=model_input, outputs=model_output)

In [17]:
# Adam optimiser with a learning rate of 1e-4, margin loss as the objective,
# categorical accuracy as the reported metric.
adam = tf.keras.optimizers.Adam(learning_rate=0.0001)
model.compile(
    optimizer=adam,
    loss=margin_loss,
    metrics=tf.keras.metrics.CategoricalAccuracy(),
)
model.summary()

Model: "model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         [(32, 28, 28, 1)]         0         
_________________________________________________________________
conv2d_2 (Conv2D)            (32, 24, 24, 16)          416       
_________________________________________________________________
batch_normalization_2 (Batch (32, 24, 24, 16)          64        
_________________________________________________________________
conv2d_3 (Conv2D)            (32, 16, 16, 32)          41504     
_________________________________________________________________
batch_normalization_3 (Batch (32, 16, 16, 32)          128       
_________________________________________________________________
tf.reshape_1 (TFOpLambda)    (32, 512, 16)             0         
_________________________________________________________________
capsule_1 (Capsule)          (32, 10, 16)              5120

In [18]:
# Load MNIST and convert it to the format the model consumes.

mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Divide pixel values by 255, append a trailing channel axis, store as float32.
x_train = (x_train / 255.0)[..., tf.newaxis].astype("float32")
x_test = (x_test / 255.0)[..., tf.newaxis].astype("float32")

# One-hot encode the integer labels.
y_train = tf.keras.utils.to_categorical(y_train)
y_test = tf.keras.utils.to_categorical(y_test)

In [19]:
# Train for 30 epochs with batches of 32, holding out 20% of the training
# data for validation.
history = model.fit(
    x_train,
    y_train,
    batch_size=32,
    epochs=30,
    validation_split=0.2,
)

Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30
Epoch 11/30
Epoch 12/30
Epoch 13/30
Epoch 14/30
Epoch 15/30
Epoch 16/30
Epoch 17/30
Epoch 18/30
Epoch 19/30
Epoch 20/30
Epoch 21/30
Epoch 22/30
Epoch 23/30
Epoch 24/30
Epoch 25/30
Epoch 26/30
Epoch 27/30
Epoch 28/30
Epoch 29/30
Epoch 30/30
