In [1]:
import numpy as np
import tensorflow as tf
from tensorflow import keras

In [22]:
import tensorflow_probability as tfp

In [2]:
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt

In [23]:
from scipy import random

In [3]:
import os
import time

In [4]:
@tf.function
def squash(x, axis=-1):
    """Capsule squashing non-linearity.

    Rescales ``x`` along ``axis`` by ``||x|| / (1 + ||x||^2)``. Short vectors
    shrink towards zero, long vectors approach unit length, and the direction
    is kept. ``keras.backend.epsilon()`` is added to the squared norm so a
    zero vector does not give a zero norm.
    """
    squared_norm = keras.backend.epsilon() + tf.math.reduce_sum(
        tf.math.square(x), axis, keepdims=True)
    factor = tf.math.sqrt(squared_norm) / (1 + squared_norm)
    return x * factor

@tf.function
def margin_loss(y_true, y_pred, lamb=0.5, margin=0.1):
    """Capsule-network margin loss, summed over the last axis.

    For each class ``c`` with target ``T_c`` (from ``y_true``) and predicted
    capsule length ``p_c`` (from ``y_pred``)::

        L_c = T_c * max(0, (1 - margin) - p_c)^2
              + lamb * (1 - T_c) * max(0, p_c - margin)^2

    Args:
        y_true: targets, presumably one-hot. They are cast to ``y_pred``'s
            dtype, so integer labels can be passed directly.
        y_pred: predicted capsule lengths, same shape as ``y_true``.
        lamb: weight of the absent-class term (default 0.5).
        margin: lower margin; the upper margin is ``1 - margin``
            (default 0.1, i.e. 0.9 / 0.1).

    Returns:
        The loss reduced over the last axis of its inputs.
    """
    y_pred = tf.convert_to_tensor(y_pred)
    # Without this, e.g. uint8/int64 one-hot labels fail with a dtype mismatch.
    y_true = tf.cast(y_true, y_pred.dtype)
    present = y_true * tf.math.square(tf.nn.relu(1 - margin - y_pred))
    absent = lamb * (1 - y_true) * tf.math.square(tf.nn.relu(y_pred - margin))
    return tf.math.reduce_sum(present + absent, axis=-1)

def safe_norm(s, axis=-1, epsilon=1e-7, keep_dims=False):
    """Euclidean norm of ``s`` along ``axis`` with ``epsilon`` added under the root.

    The offset keeps the result and its gradient finite when ``s`` is all
    zeros along ``axis``. ``keep_dims`` is forwarded as ``keepdims``.
    """
    sum_of_squares = tf.reduce_sum(tf.square(s), axis=axis, keepdims=keep_dims)
    return tf.sqrt(sum_of_squares + epsilon)

In [92]:
class Capsule(keras.layers.Layer):
    """Capsule layer whose coupling coefficients come from a Gaussian mixture.

    Each input capsule ``i`` predicts every output capsule ``j`` through a
    learned matrix ``W[i, j]``. The coupling coefficients are the
    responsibilities of a ``num_capsule``-component Gaussian mixture, fitted
    by ``routings`` EM iterations to the input capsule vectors. Each output
    capsule is the squashed, responsibility-weighted sum of its predictions.

    The mixture means live in the input space but have ``dim_capsule``
    entries, so the input capsule dimension must equal ``dim_capsule``. The
    EM starting point is drawn once in ``build`` for the batch size seen
    there. The layer must therefore always be called with that same, static
    batch size.

    Args:
        num_capsule: number of output capsules (one mixture component each).
        dim_capsule: dimension of every output (and input) capsule.
        routings: number of EM iterations run on each call.

    Input shape:  ``[batch_size, num_input_capsules, dim_capsule]``.
    Output shape: ``[batch_size, num_capsule, dim_capsule]``, float64.
    """

    def __init__(self,
                 num_capsule,
                 dim_capsule,
                 routings=3,
                 **kwargs):
        super(Capsule, self).__init__(**kwargs)
        self.caps_n = num_capsule
        self.caps_dim = dim_capsule
        self.r = routings

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config().copy()
        config.update({
            'num_capsule': self.caps_n,
            'dim_capsule': self.caps_dim,
            'routings': self.r,
        })
        return config

    def build(self, input_shape):
        """Create ``W`` and the fixed EM starting point.

        Raises:
            ValueError: if the batch dimension of ``input_shape`` is unknown,
                because the EM state is allocated per batch element.
        """
        batch_size = input_shape[0]
        n = input_shape[1]
        k = self.caps_n
        d = self.caps_dim
        if batch_size is None:
            raise ValueError(
                'Capsule needs a static batch size, got input_shape=%s.'
                % (input_shape,))

        # One learned [d, d_in] matrix per (input capsule, output capsule)
        # pair. The leading 1 is tiled to the batch size in call().
        self.W = self.add_weight(name='W',
                                 shape=[1, n, k, d, input_shape[-1]],
                                 dtype=tf.float64,
                                 initializer='glorot_uniform',
                                 trainable=True)

        # EM initial state. Every call() restarts EM from these values.
        # Means are uniform in [-10, 10). ``np.random`` is the module the
        # deprecated ``scipy.random`` alias pointed to, so the draws match.
        self.mu = np.random.rand(batch_size, k, d) * 20 - 10                              # [B, k, d]
        # Identity covariance for every component.
        self.sigma = tf.tile(np.eye(d)[np.newaxis, np.newaxis], [batch_size, k, 1, 1])  # [B, k, d, d]
        # Uniform mixing weights.
        self.pi = tf.tile(np.full((1, k), 1.0 / k), [batch_size, 1])                    # [B, k]
        # Coupling coefficients (responsibilities); zeros until the first E-step.
        self.R = tf.zeros([batch_size, n, k], dtype=tf.float64)                         # [B, n, k]

    def call(self, input_tensor):
        """Route a batch of input capsules to the output capsules.

        Args:
            input_tensor: ``[batch_size, n, dim_capsule]`` input capsule
                vectors, cast to float64 before use.

        Returns:
            ``[batch_size, num_capsule, dim_capsule]`` float64 tensor of
            squashed output capsules.

        Raises:
            ValueError: if the capsule dimension is not ``dim_capsule``, or
                the batch size differs from the one the layer was built with.
        """
        x = tf.cast(input_tensor, dtype=tf.float64)
        batch_size = x.shape[0]
        n = x.shape[1]
        k = self.caps_n
        d = self.caps_dim
        if x.shape[2] != d:
            raise ValueError(
                'Capsule expects input capsules of dimension %d, got %s.'
                % (d, x.shape[2]))
        if batch_size != self.mu.shape[0]:
            raise ValueError(
                'Capsule was built for batch size %d, got %s.'
                % (self.mu.shape[0], batch_size))

        # Predictions u_hat[b, i, j] = W[i, j] @ x[b, i] for every input
        # capsule i and output capsule j.
        W_tiled = tf.tile(self.W, [batch_size, 1, 1, 1, 1])  # [B, n, k, d, d]
        x_column = x[:, :, tf.newaxis, :, tf.newaxis]        # [B, n, 1, d, 1]
        x_tiled = tf.tile(x_column, [1, 1, k, 1, 1])         # [B, n, k, d, 1]
        caps_predicted = tf.matmul(W_tiled, x_tiled)         # [B, n, k, d, 1]

        # EM for a Gaussian mixture over this call's input capsule vectors,
        # restarted from the fixed initial state drawn in build(). Plain
        # tensors carry the state between iterations. Per-call tf.Variables
        # would break under tf.function.
        pi = tf.convert_to_tensor(self.pi, dtype=tf.float64)        # [B, k]
        mu = tf.convert_to_tensor(self.mu, dtype=tf.float64)        # [B, k, d]
        sigma = tf.convert_to_tensor(self.sigma, dtype=tf.float64)  # [B, k, d, d]
        R = tf.convert_to_tensor(self.R, dtype=tf.float64)          # [B, n, k]

        # Put the points in the leading (sample) position, with a singleton
        # where the components go. One Gaussian with batch shape [B, k] then
        # scores every point under every component at once.
        points = tf.transpose(x, [1, 0, 2])[:, :, tf.newaxis, :]  # [n, B, 1, d]
        for _ in range(self.r):
            # E-step: R[b, p, j] is proportional to pi[b, j] * N(x[b, p] | mu[b, j], sigma[b, j]).
            gaussians = tfp.distributions.MultivariateNormalTriL(
                loc=mu, scale_tril=tf.linalg.cholesky(sigma))
            density = tf.transpose(gaussians.prob(points), [1, 0, 2])  # [B, n, k]
            R = pi[:, tf.newaxis, :] * density
            # NOTE(review): if every component density underflows to 0 for a
            # point, this is 0/0 and that row of R becomes NaN.
            R = R / tf.reduce_sum(R, axis=2, keepdims=True)

            # M-step: re-estimate mixing weights, means and covariances.
            N_k = tf.reduce_sum(R, axis=1)                                  # [B, k]
            pi = N_k / n
            mu = tf.matmul(R, x, transpose_a=True) / N_k[:, :, tf.newaxis]  # [B, k, d]
            diff = x[:, :, tf.newaxis, :] - mu[:, tf.newaxis, :, :]         # [B, n, k, d]
            sigma = (tf.einsum('bnk,bnkp,bnkq->bkpq', R, diff, diff)
                     / N_k[:, :, tf.newaxis, tf.newaxis])                   # [B, k, d, d]

        # Output capsules: squash the responsibility-weighted sum of predictions.
        weighted_prediction = caps_predicted * tf.reshape(R, [batch_size, n, k, 1, 1])
        weighted_sum = tf.reduce_sum(weighted_prediction, axis=1, keepdims=True)  # [B, 1, k, d, 1]
        v = squash(weighted_sum, axis=-2)
        return tf.squeeze(v, axis=[1, 4])                                         # [B, k, d]

In [93]:
# Convolutional front end. For 28x28x1 MNIST images:
# 28x28x1 -> 26x26x16 (c1) -> 11x11x32 (c2).
c1 = tf.keras.layers.Conv2D(filters=16, kernel_size=3, strides=1,
                            padding='valid', activation='relu')
c2 = tf.keras.layers.Conv2D(filters=32, kernel_size=5, strides=2,
                            padding='valid', activation='relu')

In [94]:
# 32 output capsules of dimension 8 (default 3 EM routing iterations).
model = Capsule(num_capsule=32, dim_capsule=8)

In [95]:
# Load MNIST, scale pixel values to [0, 1] and append a trailing channels
# axis for the Conv2D layers.
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train = (x_train / 255.0)[..., tf.newaxis].astype("float64")
x_test = (x_test / 255.0)[..., tf.newaxis].astype("float64")

In [96]:
X = x_train[0:32]  # first 32 training images, used as a single batch

In [97]:
z = c1(X)
z = c2(z)

In [98]:
# Conv feature map shape: (batch, height, width, channels).
z.shape

TensorShape([32, 11, 11, 32])

In [99]:
# Regroup the conv feature map into 8-D primary capsule vectors. The capsule
# count is inferred instead of hard-coding 484 (= 11 * 11 * 32 / 8), so this
# keeps working if the conv layers change: [32, 11, 11, 32] -> [32, 484, 8].
z = tf.reshape(z, (z.shape[0], -1, 8))

In [100]:
z = tf.cast(z, tf.float64)  # match the float64 precision the Capsule layer computes in

In [101]:
# Primary capsules: (batch, num_capsules, capsule_dim).
z.shape

TensorShape([32, 484, 8])

In [102]:
# Route the primary capsules through the Capsule layer -> (batch, 32, 8), float64.
model(z)

<tf.Tensor: shape=(32, 32, 8), dtype=float64, numpy=
array([[[ 2.60525010e-39, -1.10744637e-39,  5.17747581e-39, ...,
         -5.12111835e-39,  1.47361385e-39,  4.20327545e-39],
        [ 3.77290147e-15,  3.03763049e-15, -3.15913735e-15, ...,
         -2.56150151e-15, -2.39209760e-16,  4.92679079e-16],
        [-3.09397448e-38,  3.19553504e-38, -3.38924579e-38, ...,
         -3.66954949e-38,  1.93367930e-38,  9.97688436e-40],
        ...,
        [-8.91536481e-24, -1.25219929e-23, -2.38204248e-23, ...,
          4.82921722e-23,  5.76069843e-24, -3.93484628e-23],
        [-2.67297426e-44, -1.23796962e-44, -2.22491837e-44, ...,
          1.48530721e-44,  1.59411309e-44, -1.61663921e-44],
        [ 2.72908273e-14,  1.09342643e-13,  1.01090134e-13, ...,
          5.03776621e-14, -1.04412165e-13, -6.17895605e-14]],

       [[-1.43309715e-12, -5.12576378e-12,  1.47106051e-12, ...,
         -2.68715109e-12, -7.38666075e-13,  3.31803853e-12],
        [ 4.81911919e-19,  2.08407644e-19, -1.2166