In [None]:
import numpy as np
import tensorflow as tf
from tensorflow import keras

In [None]:
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt

In [None]:
# Load MNIST and convert it to the format the network expects.

mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Divide the raw pixel values by 255, append a trailing channels axis,
# and store the result as float32.
x_train = (x_train / 255.0)[..., tf.newaxis].astype("float32")
x_test = (x_test / 255.0)[..., tf.newaxis].astype("float32")


Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [None]:
# Primary-capsule layout: 32 feature maps of 8-dimensional capsules.
caps1_n_maps = 32
caps1_n_dims = 8
# The second convolution yields a 6x6 grid, hence 32 * 6 * 6 = 1152 primary capsules.
caps1_n_caps = 6 * 6 * caps1_n_maps

In [None]:
X = x_train[:32]  # first 32 training images, used as a fixed demo batch

In [None]:
# Convolutional stack that feeds the primary capsules.
c1 = tf.keras.layers.Conv2D(
    filters=256, kernel_size=9, strides=1, padding="valid", activation="relu"
)
c2 = tf.keras.layers.Conv2D(
    filters=caps1_n_maps * caps1_n_dims,
    kernel_size=9,
    strides=2,
    padding="valid",
    activation="relu",
)

In [None]:
print(c1(X).shape)
print(c2(c1(X)).shape)

(32, 20, 20, 256)
(32, 6, 6, 256)


In [None]:
z = c2(c1(X))  # raw conv features; (32, 6, 6, 256) for this batch

In [None]:
z.shape

TensorShape([32, 6, 6, 256])

In [None]:
# Regroup the conv features into caps1_n_caps vectors of caps1_n_dims components.
z = tf.reshape(z, shape=[-1, caps1_n_caps, caps1_n_dims])

In [None]:
z.shape

TensorShape([32, 1152, 8])

In [None]:
# primary capsule layer.
def squash(v, epsilon=1e-7, axis=-1):
    """Apply the capsule "squash" non-linearity along ``axis``.

    Every vector taken along ``axis`` is divided by its (epsilon-stabilised)
    length and multiplied by ``|v|^2 / (1 + |v|^2)``.

    Args:
        v: Tensor of capsule vectors.
        epsilon: Small constant added under the square root so the length
            used as divisor is never exactly zero.
        axis: Axis that holds the vector components.

    Returns:
        Tensor with the same shape as ``v``.
    """
    squared_norm = tf.reduce_sum(tf.square(v), axis=axis, keepdims=True)
    safe_length = tf.sqrt(squared_norm + epsilon)
    scale = squared_norm / (1. + squared_norm)
    return scale * (v / safe_length)

    
def primary_capsule(input_tensor):
    """Run ``input_tensor`` through the primary capsule layer.

    Two 9x9 ReLU convolutions are applied (256 filters with stride 1, then
    ``caps1_n_maps * caps1_n_dims`` filters with stride 2). The result is
    regrouped into capsule vectors and squashed.

    Args:
        input_tensor: Batch of images in channels-last layout. The reshape
            relies on the second convolution producing a 6x6 grid, which is
            what 28x28 inputs give.

    Returns:
        Tensor of shape (batch, caps1_n_caps, caps1_n_dims).
    """
    # NOTE: fresh Conv2D layers (with new random weights) are built on every
    # call, so two calls do not share parameters.
    c1 = tf.keras.layers.Conv2D(256, kernel_size=9, strides=1, padding='valid', activation='relu')
    c2 = tf.keras.layers.Conv2D(caps1_n_maps * caps1_n_dims, kernel_size=9, strides=2, padding='valid', activation='relu')
    # Use the argument, not the notebook-level batch `X`.
    z = c2(c1(input_tensor))
    z = tf.reshape(z, [-1, caps1_n_caps, caps1_n_dims])
    return squash(z)




In [None]:
pri_out = primary_capsule(X)
pri_out.shape  # one caps1_n_dims-dimensional vector per primary capsule

TensorShape([32, 1152, 8])

In [None]:
# Digit capsule layer: one 16-dimensional capsule per digit class.
caps2_n_caps = 10  # one capsule for each of the 10 digits
caps2_n_dims = 16  # number of components in each digit capsule


In [None]:
""" Note primary capsule layer and digit capsule layer is fully connected ."""

In [None]:
init_sigma = 0.1

# One (caps2_n_dims x caps1_n_dims) transformation matrix per
# (primary capsule, digit capsule) pair; the leading axis has size 1.
W_init = tf.random.normal(
    shape=(1, caps1_n_caps, caps2_n_caps, caps2_n_dims, caps1_n_dims),
    stddev=init_sigma,
    dtype=tf.float32,
)
W = tf.Variable(initial_value=W_init)

In [None]:
W.shape

TensorShape([1, 1152, 10, 16, 8])

In [None]:
# Repeat the shared weight matrices once per example in the batch.
batch_size = tf.shape(X)[0]
W_tiled = tf.tile(W, multiples=[batch_size, 1, 1, 1, 1])

In [None]:
W_tiled.shape

TensorShape([32, 1152, 10, 16, 8])

In [None]:
pri_out.shape

TensorShape([32, 1152, 8])

In [None]:
# (batch, 1152, 8) -> (batch, 1152, 8, 1): turn each capsule into a column vector.
caps1_output_expanded = tf.expand_dims(pri_out, axis=-1)
# (batch, 1152, 8, 1) -> (batch, 1152, 1, 8, 1): insert a digit-capsule axis.
caps1_output_tile = tf.expand_dims(caps1_output_expanded, axis=2)
# Repeat along the new axis, once per digit capsule.
caps1_output_tiled = tf.tile(caps1_output_tile, multiples=[1, 1, caps2_n_caps, 1, 1])

In [None]:
print(pri_out.shape)
print(caps1_output_expanded.shape)
print(caps1_output_tile.shape)
print(caps1_output_tiled.shape)

(32, 1152, 8)
(32, 1152, 8, 1)
(32, 1152, 1, 8, 1)
(32, 1152, 10, 8, 1)


In [None]:
# Batched matrix product: one (16 x 8) @ (8 x 1) per capsule pair.
caps2_predicted = tf.matmul(a=W_tiled, b=caps1_output_tiled)

In [None]:
caps2_predicted.shape

TensorShape([32, 1152, 10, 16, 1])

In [None]:
""" Routing by agreement. """


In [None]:
# Routing logits start at zero, one per (primary capsule, digit capsule) pair.
raw_weights = tf.zeros(shape=[batch_size, caps1_n_caps, caps2_n_caps, 1, 1])

In [None]:
raw_weights.shape

TensorShape([32, 1152, 10, 1, 1])

In [None]:
# Softmax over the digit-capsule axis, then a weighted sum over the primary capsules.
routing_weights = tf.nn.softmax(raw_weights, axis=2)
weighted_predictions = routing_weights * caps2_predicted
weighted_sum = tf.reduce_sum(weighted_predictions, axis=1, keepdims=True)

In [None]:
print(routing_weights.shape)
print(weighted_predictions.shape)
print(weighted_sum.shape)

(32, 1152, 10, 1, 1)
(32, 1152, 10, 16, 1)
(32, 1, 10, 16, 1)


In [None]:
# Squash along the axis that holds the capsule components (second to last).
v = squash(weighted_sum, axis=-2)
print(v.shape)

(32, 1, 10, 16, 1)


In [None]:
# Copy the digit-capsule outputs once per primary capsule so the shapes line up.
v_tiled = tf.tile(v, multiples=[1, caps1_n_caps, 1, 1, 1])
v_tiled.shape

TensorShape([32, 1152, 10, 16, 1])

In [None]:
# Dot product of every prediction vector with its digit capsule's output.
agreement = tf.matmul(caps2_predicted, v_tiled, transpose_a=True)

In [None]:
agreement.shape

TensorShape([32, 1152, 10, 1, 1])

In [None]:
def Routing(caps2_predicted, r=3):
    """Run routing-by-agreement and return the digit capsule outputs.

    Args:
        caps2_predicted: Prediction vectors. The notebook passes a tensor of
            shape (batch, n_primary_caps, n_digit_caps, n_dims, 1).
        r: Number of routing passes. Values below 1 are treated as 1 so the
            function always returns a tensor.

    Returns:
        Squashed digit capsule outputs with the primary-capsule axis reduced
        to size 1, i.e. (batch, 1, n_digit_caps, n_dims, 1).
    """
    n_passes = max(int(r), 1)
    n_primary_caps = tf.shape(caps2_predicted)[1]
    # One routing logit per (primary capsule, digit capsule) pair. Sizing it
    # from the input keeps the function independent of notebook globals.
    raw_weights = tf.zeros_like(caps2_predicted[:, :, :, :1, :])

    for step in range(n_passes):
        routing_weights = tf.nn.softmax(raw_weights, axis=2)
        weighted_predictions = tf.multiply(routing_weights, caps2_predicted)
        weighted_sum = tf.reduce_sum(weighted_predictions, axis=1, keepdims=True)
        v = squash(weighted_sum, axis=-2)
        if step == n_passes - 1:
            # Last pass: the agreement would not be used any more.
            return v
        v_tiled = tf.tile(v, [1, n_primary_caps, 1, 1, 1])
        agreement = tf.matmul(caps2_predicted, v_tiled, transpose_a=True)
        # Accumulate into the logits; the softmax output is recomputed from
        # them on the next pass, so updating it instead would have no effect.
        raw_weights += agreement
      

In [None]:
# Digit capsule outputs after routing (default number of passes).
caps2_output = Routing(caps2_predicted)
caps2_output.shape

TensorShape([32, 1, 10, 16, 1])

In [None]:
def safe_norm(s, axis=-1, epsilon=1e-7, keep_dims=False):
    """Euclidean norm of ``s`` along ``axis`` with ``epsilon`` added under the root.

    Args:
        s: Input tensor.
        axis: Axis (or axes) to reduce over.
        epsilon: Constant added to the sum of squares before the square root.
        keep_dims: Whether the reduced axis is kept with size 1.

    Returns:
        ``sqrt(sum(s ** 2) + epsilon)`` computed along ``axis``.
    """
    sum_of_squares = tf.reduce_sum(tf.square(s), axis=axis, keepdims=keep_dims)
    return tf.sqrt(sum_of_squares + epsilon)

In [None]:
# Length of every digit capsule's output vector (reduces the component axis).
y_proba = safe_norm(caps2_output, axis=-2)
print(y_proba.shape)

(32, 1, 10, 1)


In [None]:
y = y_train[:32]  # ground-truth labels for the demo batch
# Predicted class = digit capsule with the longest output vector.
# y_proba is (batch, 1, 10, 1): argmax over axis 2 leaves (batch, 1, 1),
# and the squeeze drops the two singleton axes to give (batch,).
y_pred = tf.squeeze(tf.argmax(y_proba, axis=2), axis=[1, 2])

In [None]:
# The condition is hard-coded to False, so this always selects y_pred.
reconstruction_targets = tf.cond(
    pred=False,
    true_fn=lambda: y,
    false_fn=lambda: y_pred,
    name="reconstruction_targets",
)

In [None]:
reconstruction_targets

array([5, 0, 4, 1, 9, 2, 1, 3, 1, 4, 3, 5, 3, 6, 1, 7, 2, 8, 6, 9, 4, 0,
       9, 1, 1, 2, 4, 3, 2, 7, 3, 8], dtype=uint8)

In [None]:
# One-hot encode the selected class of every example.
reconstruction_mask = tf.one_hot(
    indices=reconstruction_targets, depth=caps2_n_caps, name="reconstruction_mask"
)

In [None]:
reconstruction_mask

<tf.Tensor: shape=(32, 10), dtype=float32, numpy=
array([[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
       [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
       [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
       [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
       [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
       [0., 0., 0.,

In [None]:
# (batch, 10) -> (batch, 1, 10, 1, 1) so the mask broadcasts against caps2_output.
reconstruction_mask_reshaped = tf.reshape(
    reconstruction_mask,
    shape=[-1, 1, caps2_n_caps, 1, 1],
    name="reconstruction_mask_reshaped",
)

In [None]:
reconstruction_mask_reshaped.shape

TensorShape([32, 1, 10, 1, 1])

In [None]:
caps2_output.shape

TensorShape([32, 1, 10, 16, 1])

In [None]:
# Zero out every digit capsule except the selected one.
caps2_output_masked = tf.multiply(
    x=caps2_output, y=reconstruction_mask_reshaped, name="caps2_output_masked"
)

In [None]:
caps2_output_masked.shape

TensorShape([32, 1, 10, 16, 1])

In [None]:
# Flatten the masked capsules into one vector of caps2_n_caps * caps2_n_dims values per example.
decoder_input = tf.reshape(
    caps2_output_masked,
    shape=[-1, caps2_n_caps * caps2_n_dims],
    name="decoder_input",
)

In [None]:
decoder_input.shape

TensorShape([32, 160])

In [None]:
# Decoder layer widths; the output has one value per pixel of a 28x28 image.
n_hidden1, n_hidden2 = 512, 1024
n_output = 28 * 28

In [None]:
# Three fully connected layers mapping the masked capsules back to pixel space.
Decoder = tf.keras.Sequential()
Decoder.add(keras.layers.Dense(n_hidden1, activation="relu"))
Decoder.add(keras.layers.Dense(n_hidden2, activation="relu"))
Decoder.add(keras.layers.Dense(n_output, activation="sigmoid"))

In [None]:
# Reconstruct the flattened images from the masked capsule vectors.
decoder_output = Decoder(decoder_input)

In [None]:
decoder_output.shape

TensorShape([32, 784])

In [None]:
Decoder.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense (Dense)               (32, 512)                 82432     
                                                                 
 dense_1 (Dense)             (32, 1024)                525312    
                                                                 
 dense_2 (Dense)             (32, 784)                 803600    
                                                                 
Total params: 1,411,344
Trainable params: 1,411,344
Non-trainable params: 0
_________________________________________________________________


In [None]:
# Mean squared error between the flattened input images and their reconstruction.
X_flat = tf.reshape(X, shape=[-1, n_output], name="X_flat")
print(X_flat.shape)

squared_difference = tf.square(X_flat - decoder_output, name="squared_difference")
print(squared_difference.shape)

reconstruction_loss = tf.reduce_mean(squared_difference, name="reconstruction_loss")
print(reconstruction_loss)

(32, 784)
(32, 784)
tf.Tensor(0.231478, shape=(), dtype=float32)


In [None]:
tmp = caps2_output  # scratch alias used by the next cell

In [None]:
# Start from caps2_output rather than from `tmp` itself so the cell can be
# re-run: norms of the digit capsules, shape (batch, 1, caps2_n_caps, 1).
tmp = safe_norm(caps2_output, axis=-2)
print(tmp.shape)
# Drop the two singleton axes, leaving shape (batch, caps2_n_caps).
tmp = tf.squeeze(tmp, axis=[1, 3])
print(tmp.shape)

(32, 1, 10, 1)
(32, 10)


In [None]:
tf.one_hot(tf.argmax(tmp,axis=1),depth=caps2_n_caps)

<tf.Tensor: shape=(32, 10), dtype=float32, numpy=
array([[0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
       [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0.,

In [None]:
alpha = 0.0005

# Margin loss over the digit-capsule lengths, with the constants from
# Sabour et al. (2017): m+ = 0.9, m- = 0.1, lambda = 0.5.
m_plus = 0.9
m_minus = 0.1
lambda_ = 0.5

caps2_lengths = tf.reshape(y_proba, [-1, caps2_n_caps])  # (batch, 10)
T = tf.one_hot(y, depth=caps2_n_caps)                    # (batch, 10), 1 for the true class
# Penalise a short capsule for the true class ...
present_error = tf.square(tf.maximum(0., m_plus - caps2_lengths))
# ... and a long capsule for every other class.
absent_error = tf.square(tf.maximum(0., caps2_lengths - m_minus))
per_class_loss = T * present_error + lambda_ * (1.0 - T) * absent_error
margin_loss = tf.reduce_mean(tf.reduce_sum(per_class_loss, axis=1), name="margin_loss")

# Total loss: margin loss plus the down-weighted reconstruction loss.
loss = tf.add(margin_loss, alpha * reconstruction_loss, name="loss")