In [2]:
import numpy as np
import tensorflow as tf
from tensorflow import keras

# %matplotlib inline
import matplotlib
import matplotlib.pyplot as plt

import os
import time


In [3]:
print(tf.__version__)

2.6.4


In [4]:

@tf.function
def squash(v,epsilon=1e-7,axis=-1):
    sqnrm=tf.reduce_sum(tf.square(v), axis=axis,keepdims=True)
    nrm=tf.sqrt(sqnrm + epsilon) #safe norm to avoid divide by zero.
    sqsh_factor = sqnrm / (1. + sqnrm)
    unit_vect = v / nrm
    return sqsh_factor*unit_vect

@tf.function
def safe_norm(s, axis=-1, epsilon=1e-7, keep_dims=False):
        squared_norm = tf.reduce_sum(tf.square(s),axis=axis,keepdims=keep_dims)
        return tf.sqrt(squared_norm + epsilon)

class Primary_caps_layer(tf.keras.layers.Layer):
  """ caps_n(i) --> no of capsule in ith layer 
      caps_dim(i) --> dimension of capsule in ith layer. 
      
      primary_caps_layer output shape = [batch_size,caps_n,caps_dim]"""

  def __init__(self,caps_n=1152,k1=256,k2=256,k_s1=9,k_s2=5,s1=1,s2=3):
    super(Primary_caps_layer, self).__init__()
    self.caps_n=caps_n  # no of capsule in this layer.(as initialized by usr this may be changed based on other parameters.)
    self.k1=k1          # no of filter in 1st conv layer.
    self.k2=k2          # no of filter in 2nd conv layer.
    self.k_s1=k_s1      # kernel_size of 1st conv layer.
    self.k_s2=k_s2      # kernel_size of 2nd conv layer.
    self.s1=s1          # stride in 1st conv layer.
    self.s2=s2          # stride in 2nd conv layer.
    self.conv1=tf.keras.layers.Conv2D(k1,kernel_size=k_s1,strides=s1,padding='valid',activation='relu') 
    self.conv2=tf.keras.layers.Conv2D(k2,kernel_size=k_s2,strides=s2,padding='valid',activation='relu')

  def call(self, input_tensor):
    batch_size=input_tensor.shape[0]
    x=self.conv1(input_tensor)
    x=self.conv2(x) 

    assert x.shape[1]*x.shape[1]*self.k2==self.caps_n*self.caps_dim # $ eqn--1

    x=tf.reshape(x,[batch_size,self.caps_n,self.caps_dim]) # *
    return squash(x)

  def build(self,input_shape):
    self.batch_size=input_shape[0] 
    tmp=int(((input_shape[1]-self.k_s1)/self.s1))+1
    self.conv1_output_shape=[input_shape[0],tmp,tmp,self.k1]
    tmp=int(((tmp-self.k_s2)/self.s2))+1
    self.conv2_output_shape=[input_shape[0],tmp,tmp,self.k2]
    tmp1=tmp*tmp*self.k2
    self.caps_n=self.caps_n-(tmp1%self.caps_n) # recomputing apropriate no of capsule : $ eqn--1 is true.
    self.caps_dim=int((tmp*tmp*self.k2)/self.caps_n); # same is done for caps_dim.
    

class Digit_caps_layer(tf.keras.layers.Layer):
  """ caps_n(i) --> no of capsule in ith layer 
      caps_dim(i) --> dimension of capsule in ith layer. 
      and we assume this is ith layer. 
      output.shape of ith layer = [batch_size, 1,caps_n(i),caps_dim(i), 1]"""

  def __init__(self,caps_dim=16,caps_n=10,r=3):
    super(Digit_caps_layer,self).__init__()
    self.caps_n=caps_n # no of capsule.
    self.caps_dim=caps_dim # dim of each capsule.
    self.r=r # no of iteration in routing by agreement algorithm.
    
  def build(self,input_shape): # input_shape = [batch_size,caps_n(i-1),caps_dim(i-1)] 
    self.W = tf.Variable(initial_value=tf.random.normal(
    shape=(1, input_shape[1], self.caps_n, self.caps_dim, input_shape[-1]),
    stddev=0.1, dtype=tf.float32),
    trainable=True)  #weigth initialization for this layer W.shape=[1,caps_n(i-1),caps_n(i),caps_dim(i),caps_dim(i-1)].

  def call(self,input_tensor): #input_tensor.shape=[batch_size,caps_n(i-1),caps_dim(i-1)]
    batch_size = input_tensor.shape[0]
    W_tiled = tf.tile(self.W, [batch_size, 1, 1, 1, 1]) # replicating the weights for parallel processing of a batch.
    """ W_tiled.shape=[batch_size,caps_n(i-1),caps_n(i),caps_dim(i),caps_dim(i-1)] """

    caps_output_expanded = tf.expand_dims(input_tensor, -1) # converting last dim to a column vector.
    """ the above step change the input shape from 
        [batch_size,caps_n(i-1),caps_dim(i-1)] --> [batch_size,caps_n(i-1),caps_dim(i-1),1]"""

    caps_output_tile = tf.expand_dims(caps_output_expanded, 2)
    """ the above step change the input shape from 
        [batch_size,caps_n(i-1),caps_dim(i-1),1] --> [batch_size,caps_n(i-1),1,caps_dim(i-1),1]"""

    caps_output_tiled = tf.tile(caps_output_tile, [1, 1, self.caps_n, 1, 1]) # replicating the input capsule vector for every output capsule.
    """ i.e [batch_size,caps_n(i-1),1,caps_dim(i-1),1] --> [batch_size,caps_n(i-1),caps_n(i),1,caps_dim(i-1),1]"""

    caps_predicted = tf.matmul(W_tiled, caps_output_tiled) # this is performing element wise tf.matmul() operation.
    """ caps_predicted.shape = [1,caps_n(i-1),caps_n(i),caps_dim(i),1]"""

    """ dynamic routing """
    raw_weights = tf.zeros([batch_size,input_tensor.shape[1] , self.caps_n, 1, 1]) # non trainable weights.
    """ raw_weights.shape=[batch_size,caps_n(i-1) ,caps_n(i), 1, 1]"""

    r=self.r
    while(r):
      r-=1
      routing_weights = tf.nn.softmax(raw_weights,axis=2)
      """ [batch_size,caps_n(i-1) ,caps_n(i), 1, 1]  softmax applied along the pointed dim.
                                       ^                                                   """

      weighted_predictions = tf.multiply(routing_weights, caps_predicted)
      """ weighted_predictions.shape = [batch_size, caps_n(i-1),caps_n(i),caps_dim(i), 1]"""

      weighted_sum = tf.reduce_sum(weighted_predictions, axis=1, keepdims=True)
      """ [batch_size,caps_n(i-1) ,caps_n(i),caps_dim(i), 1]  sum applied along the pointed dim.
                           ^                                                               
      therefore weighted_sum.shape=[batch_size,1 ,caps_n(i),caps_dim(i), 1]"""

      v = squash(weighted_sum, axis=-2) #normalize to unit length vector.
      v_tiled = tf.tile(v, [1, input_tensor.shape[1], 1, 1, 1])
      """ v_tiled.shape=[batch_size,caps_n(i-1),caps_n(i),caps_dim(i), 1]"""

      agreement = tf.matmul(caps_predicted, v_tiled,transpose_a=True)
      """ agreement.shape=[batch_size,caps_n(i-1),caps_n(i), 1, 1]"""

      if(r>0):
          routing_weights+=agreement
      else:
          return v

In [5]:
ls ../input/brain-tumor-mri-dataset/

[0m[01;34mTesting[0m/  [01;34mTraining[0m/


In [6]:
PATH="../input/brain-tumor-mri-dataset/"

train_dir = PATH+"/Training" 
validation_dir = PATH+"/Testing"

BATCH_SIZE = 32
IMG_SIZE = (256, 256)

#train data
train_dataset = tf.keras.utils.image_dataset_from_directory(train_dir,
                                                            shuffle=True,
                                                            batch_size=BATCH_SIZE,
                                                            image_size=IMG_SIZE)

print('Number of validation batches: %d' % tf.data.experimental.cardinality(train_dataset))

#validation model.
validation_dataset = tf.keras.utils.image_dataset_from_directory(validation_dir,
                                                                 shuffle=True,
                                                                 batch_size=BATCH_SIZE,
                                                                 image_size=IMG_SIZE)


# creating test data.
val_batches = tf.data.experimental.cardinality(validation_dataset)
test_dataset = validation_dataset.take(val_batches // 5)
validation_dataset = validation_dataset.skip(val_batches // 5)

print('Number of validation batches: %d' % tf.data.experimental.cardinality(validation_dataset))
print('Number of test batches: %d' % tf.data.experimental.cardinality(test_dataset))


Found 5712 files belonging to 4 classes.


2022-09-20 05:05:38.075614: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-09-20 05:05:38.174367: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-09-20 05:05:38.175170: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-09-20 05:05:38.177673: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compil

Number of validation batches: 179
Found 1311 files belonging to 4 classes.
Number of validation batches: 33
Number of test batches: 8


In [7]:
train_dataset.take(1)

<TakeDataset shapes: ((None, 256, 256, 3), (None,)), types: (tf.float32, tf.int32)>

In [8]:
#optimization parameter setting.
AUTOTUNE = tf.data.AUTOTUNE

train_dataset = train_dataset.prefetch(buffer_size=AUTOTUNE)
validation_dataset = validation_dataset.prefetch(buffer_size=AUTOTUNE)
test_dataset = test_dataset.prefetch(buffer_size=AUTOTUNE)

In [9]:

class Caps_net(tf.keras.Model):

  def __init__(self,no_classes=10):
    super(Caps_net,self).__init__()
    self.no_classes=no_classes

    self.optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)
    self.loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=False)
    self.train_acc_metric = tf.keras.metrics.CategoricalAccuracy()
    self.val_acc_metric = tf.keras.metrics.CategoricalAccuracy()

    self.pri_layer=Primary_caps_layer(caps_n=256,k1=64,k2=64,k_s1=9,k_s2=5,s1=1,s2=3)
    self.dig_layer=Digit_caps_layer(caps_dim=8,caps_n=no_classes,r=3)

    self.decoder=tf.keras.Sequential([
      keras.layers.Dense(128, activation='relu'),
      keras.layers.Dense(128, activation='relu'),
      keras.layers.Dense(256*256*3, activation='sigmoid'),
    ])

  def call(self,input_tensor,y,training=False):
    """ y should not be prob. dist/one-hot vectors it should be list of label for mnist it would 
        be as [1,4,6,3,8,7,...,5]. 
        when training is false y is not needed."""

    batch_size=input_tensor.shape[0]
    img_dim=input_tensor.shape[1] # considering image size=(img_dim,img_dim,img_depth)
    img_depth=input_tensor.shape[-1]

    x = self.pri_layer(input_tensor) #x.shape=[batch_size,caps_n(i),caps_dim(i)]
    x = self.dig_layer(x) #x.shape=[batch_size, 1,caps_n(i),caps_dim(i), 1]
    z = safe_norm(x, axis=-2) #x.shape=[batch_size,1,caps_n(i-1),1]
    z = tf.nn.softmax(z,axis=2) #converting those probabilities to prob dist.
    y_pred = tf.squeeze(z, axis=[1,3]) #reducing the extra dims. therefore the output shape =[batch_size,caps_n(i-1)] 
    if(training==False):
      return y_pred  # y_pred is a prob. dist.

    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False)(tf.one_hot(y,depth=self.no_classes), y_pred)

    #loss2 i.e reconstruction loss.
    reconstruction_mask = tf.one_hot(y,depth=self.no_classes) # recon_mask is one-hot vect rep. of y.
    
    reconstruction_mask_reshaped = tf.reshape(reconstruction_mask, [batch_size, 1, self.no_classes, 1, 1])
    # above reshape is done so that we can apply the mask.
    lastcaps_output_masked = tf.multiply(x, reconstruction_mask_reshaped)

    lastcaps_n=x.shape[2] # no of capsule in last layer.
    lastcaps_dims=x.shape[3] # dim of capsule in last layer.

    decoder_input = tf.reshape(lastcaps_output_masked,[batch_size, lastcaps_n * lastcaps_dims])
    
    decoder_output=self.decoder(decoder_input) 
    """ reconstruction of the input image based on the output vector of last layer
        we apply the mask to the output of the last layer such that only the vector corresponding to a
        particular lable is passed to the decoder."""

    X_flat = tf.reshape(input_tensor, [batch_size,img_dim*img_dim*img_depth]) 
    
    squared_difference = tf.square(X_flat - decoder_output)
    reconstruction_loss = tf.reduce_mean(squared_difference) # computation of mean squared loss between input image and reconstructed image.
  
    return loss+0.0005*reconstruction_loss

  def fit(self,train_dataset,validation_dataset,epochs=3):

    for epoch in range(epochs):
      print("\nepoch {}/{}".format(epoch+1,epochs))
      pbar = keras.utils.Progbar(target=int(train_dataset.cardinality()))
      metrics = {}

      # Iterate over the batches of the dataset.
      for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
          with tf.GradientTape() as tape:
              y_pred=self(x_batch_train,y_batch_train,training=False) # $ better design needed.
              # y_pred is prob. dist.
              loss_value = self(x_batch_train,y_batch_train,training=True) # loss computation
          grads = tape.gradient(loss_value, self.trainable_weights) # back prop
          self.optimizer.apply_gradients(zip(grads, self.trainable_weights)) # weight update

          # Update training metric.
          self.train_acc_metric.update_state(tf.keras.utils.to_categorical(y_batch_train,num_classes=self.no_classes), y_pred)
          metrics.update({'train_acc':self.train_acc_metric.result()})
          pbar.update(step+1, values=metrics.items(), finalize=False)


      # Run a validation loop at the end of each epoch.
      for x_batch_val, y_batch_val in validation_dataset:
        y_batch_val=tf.keras.utils.to_categorical(y_batch_val,num_classes=self.no_classes)
        val_pred = self(x_batch_val,y_batch_val,training=False) # $ better design needed
        # Update val metrics
        self.val_acc_metric.update_state(y_batch_val, val_pred)

      metrics.update({'val_acc':self.val_acc_metric.result()})
      
      pbar.update(step+1, values=metrics.items(), finalize=True)
      
      # Reset training & val metrics at the end of each epoch
      self.train_acc_metric.reset_states()
      self.val_acc_metric.reset_states()


In [10]:

model=Caps_net(no_classes=4)



In [11]:
model.fit(train_dataset,validation_dataset,epochs=30)


epoch 1/30


2022-09-20 05:06:08.489795: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:185] None of the MLIR Optimization Passes are enabled (registered 2)
Cleanup called...
2022-09-20 05:06:10.689780: I tensorflow/stream_executor/cuda/cuda_dnn.cc:369] Loaded cuDNN version 8005




Cleanup called...




Cleanup called...




Cleanup called...




Cleanup called...



epoch 2/30


Cleanup called...




Cleanup called...




Cleanup called...




Cleanup called...




Cleanup called...



epoch 3/30


Cleanup called...




Cleanup called...




Cleanup called...




Cleanup called...




Cleanup called...



epoch 4/30


Cleanup called...




Cleanup called...




Cleanup called...




Cleanup called...




Cleanup called...



epoch 5/30


Cleanup called...




Cleanup called...




Cleanup called...




Cleanup called...




Cleanup called...



epoch 6/30


Cleanup called...




Cleanup called...




Cleanup called...




Cleanup called...




Cleanup called...



epoch 7/30


Cleanup called...




Cleanup called...




Cleanup called...




Cleanup called...




Cleanup called...



epoch 8/30


Cleanup called...




Cleanup called...




Cleanup called...




Cleanup called...




Cleanup called...



epoch 9/30


Cleanup called...




Cleanup called...




Cleanup called...




Cleanup called...




Cleanup called...



epoch 10/30


Cleanup called...




Cleanup called...




Cleanup called...




Cleanup called...




Cleanup called...



epoch 11/30


Cleanup called...




Cleanup called...




Cleanup called...




Cleanup called...




Cleanup called...



epoch 12/30


Cleanup called...




Cleanup called...




Cleanup called...




Cleanup called...




Cleanup called...



epoch 13/30


Cleanup called...




Cleanup called...




Cleanup called...




Cleanup called...




Cleanup called...



epoch 14/30


Cleanup called...




Cleanup called...




Cleanup called...




Cleanup called...




Cleanup called...



epoch 15/30


Cleanup called...




Cleanup called...




Cleanup called...




Cleanup called...




Cleanup called...



epoch 16/30


Cleanup called...




Cleanup called...




Cleanup called...




Cleanup called...




Cleanup called...



epoch 17/30


Cleanup called...




Cleanup called...




Cleanup called...




Cleanup called...




Cleanup called...



epoch 18/30


Cleanup called...




Cleanup called...




Cleanup called...




Cleanup called...




Cleanup called...



epoch 19/30


Cleanup called...




Cleanup called...




Cleanup called...




Cleanup called...




Cleanup called...



epoch 20/30


Cleanup called...




Cleanup called...




Cleanup called...




Cleanup called...




Cleanup called...



epoch 21/30


Cleanup called...




Cleanup called...




Cleanup called...




Cleanup called...




Cleanup called...



epoch 22/30


Cleanup called...




Cleanup called...




Cleanup called...




Cleanup called...




Cleanup called...



epoch 23/30


Cleanup called...




Cleanup called...




Cleanup called...




Cleanup called...




Cleanup called...



epoch 24/30


Cleanup called...




Cleanup called...




Cleanup called...




Cleanup called...




Cleanup called...



epoch 25/30


Cleanup called...




Cleanup called...




Cleanup called...




Cleanup called...




Cleanup called...



epoch 26/30


Cleanup called...




Cleanup called...




Cleanup called...




Cleanup called...




Cleanup called...



epoch 27/30


Cleanup called...




Cleanup called...




Cleanup called...




Cleanup called...




Cleanup called...



epoch 28/30


Cleanup called...




Cleanup called...




Cleanup called...




Cleanup called...




Cleanup called...



epoch 29/30


Cleanup called...




Cleanup called...




Cleanup called...




Cleanup called...




Cleanup called...



epoch 30/30


Cleanup called...




Cleanup called...




Cleanup called...




Cleanup called...




Cleanup called...




In [12]:
model.save_weights('./checkpoints/my_checkpoint')