In [None]:
import numpy as np
import tensorflow as tf
from tensorflow import keras

In [None]:
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt

In [None]:
# Load MNIST, scale pixel values into [0, 1], and append a trailing
# channels axis of size 1 so the images can be fed to Conv2D.
mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Dividing by 255.0 yields float64 arrays; the final cast brings them to float32.
x_train = (x_train / 255.0)[..., tf.newaxis].astype("float32")
x_test = (x_test / 255.0)[..., tf.newaxis].astype("float32")


Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [None]:
# Primary-capsule configuration: caps1_n_maps capsule maps laid over a 6x6
# grid (the spatial size the two 9x9 convolutions below produce), each
# capsule being a caps1_n_dims-component vector.
caps1_n_maps = 32
caps1_n_dims = 8
caps1_n_caps = caps1_n_maps * 6 * 6  # 32 * 36 = 1152 primary capsules

In [None]:
# One-image batch (slicing keeps the batch axis) used to trace tensor shapes.
X = x_train[0:1]

In [None]:
# Convolutions feeding the primary capsules; the second one emits
# caps1_n_maps * caps1_n_dims channels so its output can be split into capsules.
c1 = tf.keras.layers.Conv2D(
    filters=256, kernel_size=9, strides=1,
    padding='valid', activation='relu')
c2 = tf.keras.layers.Conv2D(
    filters=caps1_n_maps * caps1_n_dims, kernel_size=9, strides=2,
    padding='valid', activation='relu')

In [None]:
# Output shape after each convolution, one per line.
print(c1(X).shape, c2(c1(X)).shape, sep="\n")

(1, 20, 20, 256)
(1, 6, 6, 256)


In [None]:
# Feature map that will be regrouped into primary capsules.
z = c2(c1(X))

In [None]:
# Shape of the conv feature map before regrouping into capsules.
z.get_shape()

TensorShape([1, 6, 6, 256])

In [None]:
# Regroup the feature map into caps1_n_caps capsules of caps1_n_dims components each.
z = tf.reshape(z, shape=[-1, caps1_n_caps, caps1_n_dims])

In [None]:

# Shape after regrouping: (batch, caps1_n_caps, caps1_n_dims).
z.get_shape()

TensorShape([1, 1152, 8])

In [None]:
# Squash non-linearity, applied to the outputs of both capsule layers.
def squash(v, epsilon=1e-7, axis=-1):
    """Rescale vectors along `axis` to a length below 1, keeping their direction.

    Computes ||v||^2 / (1 + ||v||^2) * v / sqrt(||v||^2 + epsilon). The epsilon
    keeps the division finite when a vector is all zeros.

    Args:
        v: tensor holding the vectors to squash.
        epsilon: small constant added under the square root.
        axis: axis along which the vector components lie.

    Returns:
        Tensor with the same shape as `v`.
    """
    squared_norm = tf.reduce_sum(tf.square(v), axis=axis, keepdims=True)
    safe_length = tf.sqrt(squared_norm + epsilon)
    scale = squared_norm / (1. + squared_norm)
    return scale * (v / safe_length)

    
def primary_capsule(input_tensor):
    """Compute the primary-capsule layer output for `input_tensor`.

    Applies a 9x9 / stride-1 Conv2D with 256 filters and a 9x9 / stride-2
    Conv2D with caps1_n_maps * caps1_n_dims filters (both 'valid', ReLU).
    It then regroups the result into caps1_n_caps capsules of caps1_n_dims
    components and squashes each capsule.

    Note: both Conv2D layers are created anew on every call. Each call
    therefore runs with freshly initialised, untrained weights, and two calls
    on the same input will generally give different values.

    Args:
        input_tensor: image batch. Assumed to be shaped like `X`, i.e.
            (batch, 28, 28, 1) float32 (TODO confirm for other inputs).
            Other spatial sizes change the conv grid, so the reshape would
            no longer give caps1_n_caps capsules per image.

    Returns:
        Tensor of shape (batch, caps1_n_caps, caps1_n_dims).
    """
    c1 = tf.keras.layers.Conv2D(256, kernel_size=9, strides=1,
                                padding='valid', activation='relu')
    c2 = tf.keras.layers.Conv2D(caps1_n_maps * caps1_n_dims, kernel_size=9,
                                strides=2, padding='valid', activation='relu')
    # Use the argument; the previous body read the global `X` instead.
    z = c2(c1(input_tensor))
    z = tf.reshape(z, [-1, caps1_n_caps, caps1_n_dims])
    return squash(z)




In [None]:
# Primary-capsule output for the sample batch. Show the shape of the tensor
# we keep, instead of calling primary_capsule again (which would build and
# discard a second set of randomly initialised conv layers).
pri_out = primary_capsule(X)
pri_out.shape  # output shape of primary capsule.

TensorShape([1, 1152, 8])

In [None]:
# Digit capsule layer: one capsule per digit class, each a 16-component vector.
caps2_n_dims = 16  # components per digit capsule
caps2_n_caps = 10  # one capsule per digit


In [None]:
""" Note primary capsule layer and digit capsule layer is fully connected ."""

In [None]:
# Transformation matrices: W[0, i, j] is a (caps2_n_dims x caps1_n_dims) matrix
# mapping primary capsule i's output to its prediction for digit capsule j.
# Entries are drawn from a normal distribution with std init_sigma.
init_sigma = 0.1

W_init = tf.random.normal(
    shape=(1, caps1_n_caps, caps2_n_caps, caps2_n_dims, caps1_n_dims),
    mean=0.0,
    stddev=init_sigma,
    dtype=tf.float32,
)
W = tf.Variable(W_init)

In [None]:
# (1, primary capsules, digit capsules, digit dims, primary dims)
W.get_shape()

TensorShape([1, 1152, 10, 16, 8])

In [None]:
# Repeat W once per example so it can be batch-multiplied with the
# per-example capsule outputs.
batch_size = tf.shape(X)[0]
W_tiled = tf.tile(W, multiples=[batch_size, 1, 1, 1, 1])

In [None]:
# Same layout as W, with the leading axis now equal to the batch size.
W_tiled.get_shape()

TensorShape([1, 1152, 10, 16, 8])

In [None]:
# Primary-capsule output shape.
pri_out.get_shape()

In [None]:
# Make each capsule vector a column (new last axis), add a digit-capsule
# axis at position 2, then repeat along it once per digit capsule.
caps1_output_expanded = tf.expand_dims(pri_out, axis=-1)
caps1_output_tile = tf.expand_dims(caps1_output_expanded, axis=2)
caps1_output_tiled = tf.tile(caps1_output_tile,
                             multiples=[1, 1, caps2_n_caps, 1, 1])

In [None]:
# Shapes at each step of the expand/tile sequence, one per line.
print(pri_out.shape,
      caps1_output_expanded.shape,
      caps1_output_tile.shape,
      caps1_output_tiled.shape,
      sep="\n")

(1, 1152, 8)
(1, 1152, 8, 1)
(1, 1152, 1, 8, 1)
(1, 1152, 10, 8, 1)


In [None]:
# Predictions u_hat[i, j] = W[i, j] @ u_i for every (primary, digit) capsule pair.
caps2_predicted = tf.linalg.matmul(W_tiled, caps1_output_tiled)

In [None]:
# One column-vector prediction per (primary capsule, digit capsule) pair.
caps2_predicted.get_shape()

TensorShape([1, 1152, 10, 16, 1])

In [None]:
""" Routing by agreement. """


In [None]:
# Routing logits, one per (primary capsule, digit capsule) pair, all starting at 0.
raw_weights = tf.zeros(shape=[batch_size, caps1_n_caps, caps2_n_caps, 1, 1])

In [None]:
# Logit tensor shape.
raw_weights.get_shape()

TensorShape([1, 1152, 10, 1, 1])

In [None]:
# One routing step. Softmax over the digit-capsule axis (2) gives coupling
# coefficients; weight each prediction and sum over primary capsules (axis 1).
routing_weights = tf.nn.softmax(raw_weights, axis=2)
weighted_predictions = routing_weights * caps2_predicted
weighted_sum = tf.reduce_sum(weighted_predictions, axis=1, keepdims=True)

In [None]:
# The primary-capsule axis has been summed away (kept as size 1).
weighted_sum.get_shape()

TensorShape([1, 1, 10, 16, 1])

In [None]:
# Shapes produced by the routing step, one per line.
print(routing_weights.shape,
      weighted_predictions.shape,
      weighted_sum.shape,
      sep="\n")

(1, 1152, 10, 1, 1)
(1, 1152, 10, 16, 1)
(1, 1, 10, 16, 1)


In [None]:
# NOTE(review): exact duplicate of the previous cell; consider deleting it.
print(routing_weights.shape,
      weighted_predictions.shape,
      weighted_sum.shape,
      sep="\n")

(1, 1152, 10, 1, 1)
(1, 1152, 10, 16, 1)
(1, 1, 10, 16, 1)


In [None]:
# Squash along axis -2, the digit-capsule component axis (the last axis is the size-1 column).
v = squash(weighted_sum, axis=-2)
print(v.get_shape())

(1, 1, 10, 16, 1)


In [None]:
# Repeat the digit-capsule outputs once per primary capsule.
v_tiled = tf.tile(v, multiples=[1, caps1_n_caps, 1, 1, 1])
v_tiled.get_shape()

TensorShape([1, 1152, 10, 16, 1])

In [None]:
# Scalar product u_hat^T v for every (primary capsule, digit capsule) pair.
agreement = tf.linalg.matmul(caps2_predicted, v_tiled, transpose_a=True)

In [None]:
# One scalar agreement per pair, matching the logit tensor's shape.
agreement.get_shape()

TensorShape([1, 1152, 10, 1, 1])

In [None]:
def Routing(caps2_predicted, r=3):
    """Route primary-capsule predictions to digit capsules by agreement.

    Starts from all-zero routing logits. Each of the `r` iterations:
      1. softmaxes the logits over the digit-capsule axis (2);
      2. takes the weighted sum of predictions over primary capsules (axis 1);
      3. squashes that sum along axis -2 to get the digit-capsule outputs `v`;
      4. except on the last iteration, adds the agreement u_hat^T v to the logits.

    Args:
        caps2_predicted: predictions laid out as
            (batch, n_primary, n_digit, dims, 1).
        r: number of routing iterations; must be >= 1.

    Returns:
        Digit-capsule outputs of shape (batch, 1, n_digit, dims, 1).

    Raises:
        ValueError: if r < 1 (previously this silently returned None).
    """
    if r < 1:
        raise ValueError(f"Routing needs at least one iteration, got r={r!r}")
    # Zero logits shaped (batch, n_primary, n_digit, 1, 1). They are derived
    # from the input itself, so the function no longer relies on the global
    # batch_size or on module-level capsule counts.
    raw_weights = tf.zeros_like(caps2_predicted[..., :1, :])
    n_primary = tf.shape(caps2_predicted)[1]
    v = None
    for iteration in range(r):
        routing_weights = tf.nn.softmax(raw_weights, axis=2)
        weighted_sum = tf.reduce_sum(routing_weights * caps2_predicted,
                                     axis=1, keepdims=True)
        v = squash(weighted_sum, axis=-2)
        if iteration == r - 1:
            break
        v_tiled = tf.tile(v, [1, n_primary, 1, 1, 1])
        agreement = tf.matmul(caps2_predicted, v_tiled, transpose_a=True)
        # Update the logits. The previous code added to routing_weights, which
        # is recomputed from the unchanged zero logits on the next iteration,
        # so routing had no effect at all.
        raw_weights = raw_weights + agreement
    return v
      





In [None]:
# Digit-capsule outputs after routing (default r=3 iterations).
caps2_output = Routing(caps2_predicted)
caps2_output.get_shape()

TensorShape([1, 1, 10, 16, 1])

In [None]:
def safe_norm(s, axis=-1, epsilon=1e-7, keep_dims=False):
    """Euclidean norm of `s` along `axis`, with `epsilon` added under the root.

    The epsilon keeps the square root away from zero for all-zero vectors.
    `keep_dims` is passed through as tf.reduce_sum's `keepdims`.
    """
    return tf.sqrt(
        tf.reduce_sum(tf.square(s), axis=axis, keepdims=keep_dims) + epsilon
    )

In [None]:
# Length of each digit capsule along its component axis (-2), used as the class score.
y_proba = safe_norm(caps2_output, axis=-2, keep_dims=False)

In [None]:
# Sum of capsule lengths over digits; lengths are not normalised, so this need not be 1.
tf.reduce_sum(input_tensor=y_proba, axis=2)

<tf.Tensor: shape=(1, 1, 1), dtype=float32, numpy=array([[[0.00319418]]], dtype=float32)>

In [None]:
# Per-digit capsule lengths (y_proba) for the sample image.
y_proba

<tf.Tensor: shape=(1, 1, 10, 1), dtype=float32, numpy=
array([[[[0.0003206 ],
         [0.0003218 ],
         [0.000317  ],
         [0.00032217],
         [0.00031829],
         [0.00031771],
         [0.00032117],
         [0.0003185 ],
         [0.0003177 ],
         [0.00031925]]]], dtype=float32)>

In [None]:
# Softmax over the digit axis. In the run shown it is near-uniform because the lengths are tiny.
tf.nn.softmax(logits=y_proba, axis=2)

<tf.Tensor: shape=(1, 1, 10, 1), dtype=float32, numpy=
array([[[[0.10000012],
         [0.10000024],
         [0.09999976],
         [0.10000028],
         [0.09999989],
         [0.09999983],
         [0.10000017],
         [0.09999991],
         [0.09999983],
         [0.09999999]]]], dtype=float32)>

In [None]:
# Index of the longest digit capsule along the digit axis.
y_proba_argmax = tf.argmax(input=y_proba, axis=2)

In [None]:
# Index of the longest digit capsule for the sample image.
y_proba_argmax

<tf.Tensor: shape=(1, 1, 1), dtype=int64, numpy=array([[[3]]])>

In [None]:
# Drop the two singleton axes, leaving one predicted digit per image.
y_pred = tf.squeeze(input=y_proba_argmax, axis=[1, 2], name="y_pred")

In [None]:
# Predicted digit label(s), one per image in X.
y_pred

<tf.Tensor: shape=(1,), dtype=int64, numpy=array([3])>