In [1]:
import numpy as np
import tensorflow as tf
from tensorflow import keras

%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt

import os
import time

In [2]:
@tf.function
def squash(x, axis=-1):
    """Capsule "squash" non-linearity applied along ``axis``.

    Each vector taken along ``axis`` keeps its direction and is rescaled by
    ``sqrt(q) / (1 + q)``, where ``q`` is the sum of squares along that axis
    plus the Keras epsilon. The resulting vector length is therefore always
    below 1 and approaches 0 for very short inputs.

    Args:
        x: tensor holding the vectors to squash.
        axis: axis along which each vector lies.

    Returns:
        A tensor with the same shape as ``x``.
    """
    eps = keras.backend.epsilon()
    # eps keeps q strictly positive, even for an all-zero vector.
    sq_norm = tf.math.reduce_sum(tf.math.square(x), axis, keepdims=True) + eps
    return (tf.math.sqrt(sq_norm) / (1 + sq_norm)) * x

@tf.function
def margin_loss(y_true, y_pred):
    """Margin loss summed over the last axis.

    For entries where ``y_true`` is 1 the prediction is penalised when it is
    below ``1 - margin`` (0.9); for entries where ``y_true`` is 0 it is
    penalised, down-weighted by ``lamb`` (0.5), when it is above ``margin``
    (0.1). Both penalties are squared hinge terms.

    Args:
        y_true: target tensor (one-hot labels in this notebook).
        y_pred: predicted values, same shape as ``y_true``.

    Returns:
        The loss reduced over the last axis.
    """
    lamb, margin = 0.5, 0.1
    present_term = y_true * tf.math.square(tf.nn.relu(1 - margin - y_pred))
    absent_term = lamb * (1 - y_true) * tf.math.square(tf.nn.relu(y_pred - margin))
    return tf.math.reduce_sum(present_term + absent_term, axis=-1)

#@tf.function
def safe_norm(s, axis=-1, epsilon=1e-7, keep_dims=False):
    """Euclidean norm of ``s`` along ``axis`` with ``epsilon`` added under the root.

    Args:
        s: input tensor.
        axis: axis (or axes) to reduce over.
        epsilon: constant added to the sum of squares before the square root,
            so the argument of ``sqrt`` stays positive for all-zero input.
        keep_dims: if True, the reduced axis is kept with size 1.

    Returns:
        ``sqrt(sum(s**2, axis) + epsilon)``.
    """
    sum_of_squares = tf.reduce_sum(tf.square(s), axis=axis, keepdims=keep_dims)
    return tf.sqrt(sum_of_squares + epsilon)

In [3]:
# Toy sizes for the first shape experiment: each sample is n rows of length d;
# k is the size of the second axis of W and the tile count used further down.
n, d, k = 128, 8, 3
input_tensor = keras.Input(shape=(n, d))

In [4]:
init_sigma = 0.1

# One d x d matrix per (row, k) pair, drawn from a normal distribution with
# standard deviation init_sigma. No seed is set, so values differ per run.
W_init = tf.random.normal(shape=(n, k, d, d), stddev=init_sigma, dtype=tf.float32)
W = tf.Variable(W_init)

2022-12-05 12:59:49.132321: I tensorflow/core/common_runtime/process_util.cc:146] Creating new thread pool with default inter op setting: 2. Tune using inter_op_parallelism_threads for best performance.


In [5]:
W.shape

TensorShape([128, 3, 8, 8])

In [6]:
input_tensor.shape

TensorShape([None, 128, 8])

In [7]:
# Append a trailing singleton axis so every length-d row becomes a (d, 1) column vector.
x=tf.expand_dims(input_tensor, -1)

In [8]:
# Insert a singleton axis at position 2 (tiled k times in the next cell).
# NOTE: x is overwritten in place, so re-running this cell adds another axis.
x=tf.expand_dims(x, 2)

In [9]:
# Repeat x k times along axis 2 so its leading axes line up with W's (n, k, ...) layout.
# NOTE: x is overwritten in place, so re-running this cell tiles it again.
x=tf.tile(x, [1, 1, k, 1, 1])

In [10]:
x.shape

TensorShape([None, 128, 3, 8, 1])

In [11]:
# W has no batch axis; matmul broadcasts it against x's leading (batch) axis,
# as the output shape below shows.
tf.matmul(W,x)

<KerasTensor: shape=(None, 128, 3, 8, 1) dtype=float32 (created by layer 'tf.linalg.matmul')>

In [12]:
# Load MNIST in the format used by the rest of the notebook.

mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Divide the pixel values by 255.0, then add a trailing channels axis and
# store the images as float32.
x_train = (x_train / 255.0)[..., tf.newaxis].astype("float32")
x_test = (x_test / 255.0)[..., tf.newaxis].astype("float32")

# One-hot encode the integer class labels.
y_train = tf.keras.utils.to_categorical(y_train)
y_test = tf.keras.utils.to_categorical(y_test)
     

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [13]:
X=x_train[:32]

In [14]:
X.shape

(32, 28, 28, 1)

In [15]:
# Two freshly initialised (untrained) convolutions used only to produce feature
# maps for the shape experiments below: 3x3 stride 1, then 5x5 stride 2,
# both without padding and with ReLU.
c1 = keras.layers.Conv2D(filters=16, kernel_size=3, strides=1, padding='valid', activation='relu')
c2 = keras.layers.Conv2D(filters=8, kernel_size=5, strides=2, padding='valid', activation='relu')

In [16]:
# Run the sample batch through both convolutions.
# NOTE: this rebinds `input_tensor`, which earlier held the symbolic keras.Input,
# to a concrete tensor; the cells above cannot be re-run meaningfully after this.
input_tensor=c2(c1(X))

In [17]:
input_tensor.shape

TensorShape([32, 11, 11, 8])

In [18]:
# Flatten the 11x11 spatial grid into 121 vectors of length 8 per sample.
# NOTE(review): 32, 11*11 and 8 are hard-coded and must match X's batch size and
# the conv output shape printed above.
input_tensor=tf.reshape(input_tensor,shape=[32,11*11,8])

In [19]:
input_tensor.shape

TensorShape([32, 121, 8])

In [20]:
# n input vectors of length d_n (the reshaped conv output), mapped to
# k output vectors of length d_k by the weight tensor W defined next.
n, d_n = 11 * 11, 8
k, d_k = 10, 16

In [21]:
# Initialise W: one (d_k x d_n) matrix per (input vector, output vector) pair,
# drawn from a normal distribution with standard deviation init_sigma.
init_sigma = 0.1

W_init = tf.random.normal(shape=(n, k, d_k, d_n), stddev=init_sigma, dtype=tf.float32)
W = tf.Variable(W_init)

In [22]:
W.shape

TensorShape([121, 10, 16, 8])

In [23]:
batch_size=32

In [24]:
# Build x step by step, printing the shape after each stage. Separate names
# keep every intermediate available for inspection.
column_vectors = tf.expand_dims(input_tensor, -1)   # each length-d_n vector -> (d_n, 1)
print(column_vectors.shape)
with_output_axis = tf.expand_dims(column_vectors, 2)  # singleton axis for the k outputs
print(with_output_axis.shape)
x = tf.tile(with_output_axis, [1, 1, k, 1, 1])      # one copy per output vector
print(x.shape)

(32, 121, 8, 1)
(32, 121, 1, 8, 1)
(32, 121, 10, 8, 1)


In [25]:
# Give W an explicit leading axis of size 1 so it can be tiled per sample below.
W_old=tf.expand_dims(W,axis=0)

In [26]:
W_old.shape

TensorShape([1, 121, 10, 16, 8])

In [27]:
# old method: materialise one copy of W per sample, then do a plain batched matmul.
# The tiled tensor is built from W itself rather than from the current W_old, so
# re-running this cell does not tile an already-tiled tensor.
W_old = tf.tile(tf.expand_dims(W, axis=0), [batch_size, 1, 1, 1, 1])
caps_old=tf.matmul(W_old,x)

In [28]:
caps_old.shape

TensorShape([32, 121, 10, 16, 1])

In [29]:
# new method: skip the explicit tiling and let matmul broadcast W over the batch axis of x.
caps_new=tf.matmul(W,x)

In [30]:
caps_new.shape

TensorShape([32, 121, 10, 16, 1])

In [34]:
# Sanity check that the tiled and the broadcast matmul agree.
# NOTE(review): a signed sum can hide differences that cancel each other out;
# the maximum absolute difference would be a stricter check.
tf.reduce_sum(caps_old-caps_new).numpy()

-9.326154e-07

In [37]:
# routing step: prediction vectors, one per (sample, input vector, output vector) triple,
# obtained with the broadcast matmul checked above.
caps2_predicted=tf.matmul(W,x)

In [66]:
# Routing logits, all zero to start. The leading axis has size 1 and broadcasts
# over the batch when multiplied with caps2_predicted.
raw_weights = tf.zeros([1,n, k, 1, 1])
raw_weights.shape

TensorShape([1, 121, 10, 1, 1])

In [67]:
# Coupling coefficients: softmax over the output axis (axis 2, size k) so that each
# input vector's weights sum to 1 across the k outputs. This matches Routing() below;
# the previous axis=1 normalised over the n inputs instead.
routing_weights = tf.nn.softmax(raw_weights,axis=2)
print(routing_weights.shape)
# Weight every prediction vector, then sum over the n inputs (axis 1).
weighted_predictions = tf.multiply(routing_weights, caps2_predicted)
print(weighted_predictions.shape)
weighted_sum = tf.reduce_sum(weighted_predictions, axis=1, keepdims=True)
print(weighted_sum.shape)

(1, 121, 10, 1, 1)
(32, 121, 10, 16, 1)
(32, 1, 10, 16, 1)


In [68]:
caps2_predicted.shape

TensorShape([32, 121, 10, 16, 1])

In [69]:
v = squash(weighted_sum, axis=-2) # squash along the vector axis (-2): keeps direction, shrinks the length to below 1 (not to unit length)
# Repeat v n times along axis 1 so it has the same shape as caps2_predicted.
v_tiled = tf.tile(v, [1, n, 1, 1, 1])

In [70]:
v_tiled.shape

TensorShape([32, 121, 10, 16, 1])

In [71]:
# With transpose_a=True each (16, 1) prediction becomes (1, 16) and multiplies the
# matching (16, 1) output vector, i.e. a dot product per (sample, input, output) triple.
agreement = tf.matmul(caps2_predicted, v_tiled,transpose_a=True)

In [72]:

agreement.shape

TensorShape([32, 121, 10, 1, 1])

In [75]:
#routing_weights+agreement

In [64]:
def Routing(caps2_predicted,r=3):
    """Routing-by-agreement over predicted output vectors.

    Starting from all-zero routing logits, each iteration
      1. turns the logits into coupling coefficients with a softmax over the
         output axis (axis 2),
      2. forms the weighted sum of the predictions over the input axis (axis 1),
      3. squashes that sum along the vector axis (-2) to get the outputs ``v``,
      4. (except on the last iteration) adds the dot product between every
         prediction and its output vector to the logits.

    Args:
        caps2_predicted: prediction tensor laid out as
            (batch, n_inputs, n_outputs, d_out, 1), as produced by
            ``tf.matmul(W, x)`` in this notebook.
        r: number of routing iterations; must be at least 1.

    Returns:
        The squashed outputs ``v`` of the final iteration, with shape
        (batch, 1, n_outputs, d_out, 1).

    Raises:
        ValueError: if ``r`` is smaller than 1.
    """
    if r < 1:
        raise ValueError("r must be >= 1, got {}".format(r))

    # Zero logits shaped (batch, n_inputs, n_outputs, 1, 1). The shape and dtype are
    # taken from the input itself rather than from the notebook globals n and k.
    raw_weights = tf.zeros_like(caps2_predicted[..., :1, :])

    for iteration in range(r):
        routing_weights = tf.nn.softmax(raw_weights, axis=2)
        weighted_sum = tf.reduce_sum(routing_weights * caps2_predicted, axis=1, keepdims=True)
        v = squash(weighted_sum, axis=-2)

        if iteration < r - 1:
            # Dot product along the vector axis; v broadcasts over the input axis,
            # so no explicit tiling is needed.
            agreement = tf.reduce_sum(caps2_predicted * v, axis=-2, keepdims=True)
            # The update must go to the logits: routing_weights is recomputed from
            # raw_weights at the top of every iteration, so updating it has no effect.
            raw_weights = raw_weights + agreement

    return v

In [65]:
Routing(caps2_predicted)

<tf.Tensor: shape=(32, 1, 10, 16, 1), dtype=float32, numpy=
array([[[[[-5.24298579e-04],
          [-5.17527151e-05],
          [ 3.81430727e-04],
          ...,
          [-2.55915686e-04],
          [-7.91414350e-05],
          [ 8.45036528e-04]],

         [[-4.87259502e-04],
          [ 1.19481818e-03],
          [ 2.78140633e-05],
          ...,
          [ 8.78550985e-04],
          [-1.72511311e-04],
          [-4.93526750e-04]],

         [[-7.04654376e-04],
          [-1.14782539e-03],
          [ 1.15033268e-04],
          ...,
          [ 1.35363807e-04],
          [-2.95238948e-04],
          [-3.39778257e-04]],

         ...,

         [[ 5.00864349e-04],
          [-8.57401665e-06],
          [-4.24146128e-04],
          ...,
          [-4.48245992e-04],
          [ 5.33454004e-04],
          [ 9.83036734e-05]],

         [[ 1.18718017e-03],
          [-1.30032958e-03],
          [ 7.56149937e-04],
          ...,
          [ 5.79594169e-04],
          [-6.34861120e-04],
 