In [1]:
import numpy as np
import tensorflow as tf
from tensorflow import keras

In [2]:
from keras.engine import data_adapter

In [3]:
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt

In [4]:
@tf.function
def squash(v, epsilon=1e-7, axis=-1):
    """Apply the capsule "squash" non-linearity along `axis`.

    Each vector taken along `axis` keeps its direction while its length is
    rescaled to ||v||^2 / (1 + ||v||^2).

    Args:
        v: tensor whose vectors lie along `axis`.
        epsilon: small constant added under the square root so that an
            all-zero vector does not cause a division by zero.
        axis: axis holding the vector components.

    Returns:
        A tensor with the same shape as `v`.
    """
    squared_len = tf.reduce_sum(tf.square(v), axis=axis, keepdims=True)
    # The epsilon-protected length is used only for the normalisation.
    safe_len = tf.sqrt(squared_len + epsilon)
    return (squared_len / (1. + squared_len)) * (v / safe_len)

@tf.function
def safe_norm(s, axis=-1, epsilon=1e-7, keep_dims=False):
    """Euclidean norm of `s` along `axis`, with `epsilon` added under the root.

    Args:
        s: input tensor.
        axis: axis (or axes) to reduce over.
        epsilon: constant added to the squared norm before the square root,
            keeping the result strictly positive for all-zero input.
        keep_dims: when True the reduced axis is kept with size 1.

    Returns:
        Tensor of norms.
    """
    sum_of_squares = tf.reduce_sum(tf.square(s), axis=axis, keepdims=keep_dims)
    return tf.sqrt(sum_of_squares + epsilon)

In [5]:
# Load MNIST and bring it into the format the model expects.
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Divide pixel values by 255, append a trailing channels axis and cast to float32.
x_train = (x_train / 255.0)[..., tf.newaxis].astype("float32")
x_test = (x_test / 255.0)[..., tf.newaxis].astype("float32")

# One-hot encode the integer class labels.
y_train = tf.keras.utils.to_categorical(y_train)
y_test = tf.keras.utils.to_categorical(y_test)

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [6]:
# Batch size and shuffle-buffer size for the tf.data pipelines.
BATCH_SIZE = 64
SHUFFLE_BUFFER_SIZE = 100

# Training pipeline: slice -> shuffle -> batch.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(SHUFFLE_BUFFER_SIZE)
    .batch(BATCH_SIZE)
)

# Test pipeline: slice -> batch (no shuffling).
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)

In [7]:
# Primary capsule layer: 32 capsule maps, each capsule an 8-dim vector.
caps1_n_maps = 32
caps1_n_dims = 8
caps1_n_caps = 6 * 6 * caps1_n_maps  # one capsule per map per cell of a 6x6 grid -> 1152

# Digit capsule layer: one 16-dim capsule per digit class.
caps2_n_caps = 10
caps2_n_dims = 16


In [8]:
class Primary_caps_layer(tf.keras.layers.Layer):
  """Primary capsule layer: two convolutions followed by a reshape and squash.

  caps_n   --> number of capsules produced by this layer
  caps_dim --> dimension of each capsule vector

  Output shape = [batch_size, caps_n, caps_dim].
  The conv output must contain exactly caps_n * caps_dim values per example
  for the reshape to keep the batch dimension intact (with two valid 9x9
  convolutions, strides 1 and 2, a 28x28 image gives 6*6*256 = 1152*8).
  """

  def __init__(self,caps_dim=8,caps_n=1152):
    super().__init__()
    self.caps_n = caps_n      # number of capsules in this layer
    self.caps_dim = caps_dim  # dimension of each capsule in this layer
    # Both convolutions share everything except the stride.
    shared_conv_args = dict(kernel_size=9, padding='valid', activation='relu')
    self.conv1 = tf.keras.layers.Conv2D(256, strides=1, **shared_conv_args)
    self.conv2 = tf.keras.layers.Conv2D(256, strides=2, **shared_conv_args)

  def call(self, input_tensor):
    feature_maps = self.conv2(self.conv1(input_tensor))
    # Regroup the conv features into caps_n vectors of length caps_dim.
    capsules = tf.reshape(feature_maps, [-1, self.caps_n, self.caps_dim])
    return squash(capsules)
    

In [9]:
class Digit_caps_layer(tf.keras.layers.Layer):
  """Capsule layer whose inputs are routed to its capsules by agreement.

  Notation: caps_n(i) / caps_dim(i) are the number / dimension of capsules in
  this layer, caps_n(i-1) / caps_dim(i-1) those of the layer feeding it.

  input shape  = [batch_size, caps_n(i-1), caps_dim(i-1)]
  output shape = [batch_size, 1, caps_n(i), caps_dim(i), 1]

  Args:
    caps_dim: dimension of each capsule in this layer.
    caps_n: number of capsules in this layer.
    r: number of routing-by-agreement iterations (must be >= 1).

  Raises:
    ValueError: if `r` is smaller than 1.
  """

  def __init__(self,caps_dim=16,caps_n=10,r=3):
    super(Digit_caps_layer,self).__init__()
    if r < 1:
      raise ValueError("r (number of routing iterations) must be >= 1, got %r" % (r,))
    self.caps_n=caps_n # number of capsules.
    self.caps_dim=caps_dim # dimension of each capsule.
    self.r=r # number of iterations of the routing-by-agreement algorithm.

  def build(self,input_shape): # input_shape = [batch_size,caps_n(i-1),caps_dim(i-1)]
    # One transformation matrix per (input capsule, output capsule) pair.
    # W.shape = [1, caps_n(i-1), caps_n(i), caps_dim(i), caps_dim(i-1)];
    # the leading 1 broadcasts over the batch.
    self.W = tf.Variable(initial_value=tf.random.normal(
    shape=(1, input_shape[1], self.caps_n, self.caps_dim, input_shape[-1]),
    stddev=0.1, dtype=tf.float32),
    trainable=True)

  def call(self,input_tensor): # input_tensor.shape=[batch_size,caps_n(i-1),caps_dim(i-1)]
    # Turn every input capsule into a column vector and add an axis for the
    # output capsules:
    # [batch_size,caps_n(i-1),caps_dim(i-1)] --> [batch_size,caps_n(i-1),1,caps_dim(i-1),1]
    caps_in = tf.expand_dims(tf.expand_dims(input_tensor, -1), 2)

    # Prediction of every input capsule for every output capsule. tf.matmul
    # broadcasts the leading (batch) dimensions, so neither W nor the inputs
    # need to be tiled and the batch size does not have to be known statically.
    # caps_predicted.shape = [batch_size,caps_n(i-1),caps_n(i),caps_dim(i),1]
    caps_predicted = tf.matmul(self.W, caps_in)

    # Dynamic routing. The routing logits are not trainable and start at zero.
    # raw_weights.shape = [batch_size,caps_n(i-1),caps_n(i),1,1]
    raw_weights = tf.zeros_like(caps_predicted[..., :1, :])

    v = None
    for step in range(self.r):
      # Coupling coefficients: softmax over the output-capsule axis.
      routing_weights = tf.nn.softmax(raw_weights, axis=2)

      # Weighted sum of the predictions over the input capsules.
      # weighted_sum.shape = [batch_size,1,caps_n(i),caps_dim(i),1]
      weighted_sum = tf.reduce_sum(routing_weights * caps_predicted,
                                   axis=1, keepdims=True)

      # Squash along the capsule-dimension axis (length ends up in [0, 1)).
      v = squash(weighted_sum, axis=-2)

      if step < self.r - 1:
        # Agreement = dot product between each prediction and the current
        # output, taken along the capsule-dimension axis.
        # agreement.shape = [batch_size,caps_n(i-1),caps_n(i),1,1]
        agreement = tf.reduce_sum(caps_predicted * v, axis=-2, keepdims=True)
        # Accumulate into the logits (not into the softmaxed coefficients) so
        # that the next iteration's softmax actually sees the agreement.
        raw_weights = raw_weights + agreement

    return v

In [10]:
class Caps_net(tf.keras.Model):
  """Capsule network: primary capsule layer followed by a class capsule layer.

  Args:
    no_classes: number of output classes; the final capsule layer gets one
      capsule per class.

  call() returns a tensor of shape [batch_size, no_classes] that is already
  softmax-normalised, i.e. probabilities rather than logits. Pair it with a
  loss that expects probabilities (from_logits=False).

  Training uses the default Keras train_step.
  """

  def __init__(self,no_classes=10):
    super(Caps_net,self).__init__()
    self.pri_layer=Primary_caps_layer(caps_dim=8,caps_n=1152)
    # One output capsule per class.
    self.dig_layer=Digit_caps_layer(caps_dim=16,caps_n=no_classes,r=3)

  def call(self,input_tensor):
    x = self.pri_layer(input_tensor) # x.shape=[batch_size,caps_n(i-1),caps_dim(i-1)]
    x = self.dig_layer(x) # x.shape=[batch_size,1,caps_n(i),caps_dim(i),1]

    # The length of each output capsule vector is that class's score.
    # safe_norm adds an epsilon under the square root.
    x = safe_norm(x, axis=-2) # x.shape=[batch_size,1,caps_n(i),1]

    x = tf.nn.softmax(x,axis=2) # normalise the lengths into a distribution over classes.
    x = tf.squeeze(x, axis=[1,3]) # drop the singleton axes --> [batch_size,caps_n(i)]
    return x


In [11]:
# Instantiate the capsule network with one output per MNIST digit class.
model=Caps_net(no_classes=10)

In [12]:
# Smoke test: run a forward pass on the first 32 training images and show the output shape.
model(x_train[:32]).shape

TensorShape([32, 10])

In [13]:
# Caps_net.call() already ends with tf.nn.softmax, so the loss receives
# probabilities, not logits; from_logits=True would apply softmax a second time.
model.compile(
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False),
    metrics=[tf.keras.metrics.CategoricalAccuracy()],
    optimizer=tf.keras.optimizers.Adam(),
)

In [14]:
# Train for 2 epochs in mini-batches of 32, holding out 20% of the training
# data for validation.
model.fit(
    x=x_train,
    y=y_train,
    batch_size=32,
    epochs=2,
    validation_split=0.2,
)

Epoch 1/2
Epoch 2/2


<keras.callbacks.History at 0x7fdcea5d8950>