In [1]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
import tensorflow_probability as tfp
from tensorflow_probability import distributions as tfd
import os
import time

In [2]:
print(tf.__version__)

2.6.4


In [3]:
@tf.function
def squash(x, axis=-1):
    s_squared_norm = tf.math.reduce_sum(tf.math.square(x), axis, keepdims=True) + keras.backend.epsilon()
    scale = tf.math.sqrt(s_squared_norm) / (1 + s_squared_norm)
    return scale * x

@tf.function
def margin_loss(y_true, y_pred):
    lamb, margin = 0.5, 0.1
    return tf.math.reduce_sum((y_true * tf.math.square(tf.nn.relu(1 - margin - y_pred)) + lamb * (
        1 - y_true) * tf.math.square(tf.nn.relu(y_pred - margin))), axis=-1)

#@tf.function
def safe_norm(s, axis=-1, epsilon=1e-7, keep_dims=False):
        squared_norm = tf.reduce_sum(tf.square(s),axis=axis,keepdims=keep_dims)
        return tf.sqrt(squared_norm + epsilon)
    

In [4]:
ls ../input/brain-tumor-mri-dataset/

[0m[01;34mTesting[0m/  [01;34mTraining[0m/


In [5]:
PATH="../input/brain-tumor-mri-dataset/"

train_dir = PATH+"/Training" 
validation_dir = PATH+"/Testing"

BATCH_SIZE = 32
IMG_SIZE = (128, 128)

#train data
train_dataset = tf.keras.utils.image_dataset_from_directory(train_dir,
                                                            shuffle=True,
                                                            batch_size=BATCH_SIZE,
                                                            image_size=IMG_SIZE)

print('Number of validation batches: %d' % tf.data.experimental.cardinality(train_dataset))

#validation model.
validation_dataset = tf.keras.utils.image_dataset_from_directory(validation_dir,
                                                                 shuffle=True,
                                                                 batch_size=BATCH_SIZE,
                                                                 image_size=IMG_SIZE)


# creating test data.
val_batches = tf.data.experimental.cardinality(validation_dataset)
test_dataset = validation_dataset.take(val_batches // 5)
validation_dataset = validation_dataset.skip(val_batches // 5)

print('Number of validation batches: %d' % tf.data.experimental.cardinality(validation_dataset))
print('Number of test batches: %d' % tf.data.experimental.cardinality(test_dataset))


Found 5712 files belonging to 4 classes.


2022-11-07 11:24:30.036976: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-11-07 11:24:30.138949: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-11-07 11:24:30.139715: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-11-07 11:24:30.143798: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compil

Number of validation batches: 179
Found 1311 files belonging to 4 classes.
Number of validation batches: 33
Number of test batches: 8


In [6]:
train_dataset.take(1)

<TakeDataset shapes: ((None, 128, 128, 3), (None,)), types: (tf.float32, tf.int32)>

In [7]:
#optimization parameter setting.
AUTOTUNE = tf.data.AUTOTUNE

train_dataset = train_dataset.prefetch(buffer_size=AUTOTUNE)
validation_dataset = validation_dataset.prefetch(buffer_size=AUTOTUNE)
test_dataset = test_dataset.prefetch(buffer_size=AUTOTUNE)

In [8]:
class Capsule(keras.layers.Layer):
   

    def __init__(self,
                 num_capsule,
                 dim_capsule,
                 routings=3,
                 **kwargs):
        super(Capsule, self).__init__(**kwargs)
        self.caps_n = num_capsule
        self.caps_dim = dim_capsule
        self.r = routings

    def get_config(self):
        config = super().get_config().copy()
        config.update({
        'num_capsule':  self.caps_n,
        'dim_capsule' : self.caps_dim,
        'routings':  self.r,      
        })
        return config

    def build(self, input_shape):

        self.W = self.add_weight(name='W',
                    shape=[1, input_shape[1], self.caps_n, self.caps_dim, input_shape[-1]],
                    dtype=tf.float64,
                    initializer='glorot_uniform',
                    trainable=True)
        
        
    def call(self, input_tensor):
        assert input_tensor.shape[2]==self.caps_dim
        input_tensor=tf.cast(input_tensor,dtype=tf.float64)
        assert input_tensor.dtype==tf.float64
        batch_size = input_tensor.shape[0]
        n=input_tensor.shape[1]
        k=self.caps_n
        d=self.caps_dim
        
        W_tiled = tf.tile(self.W, [batch_size, 1, 1, 1, 1]) # replicating the weights for parallel processing of a batch.
        """ W_tiled.shape=[batch_size,caps_n(i-1),caps_n(i),caps_dim(i),caps_dim(i-1)] """

        caps_output_expanded = tf.expand_dims(input_tensor, -1) # converting last dim to a column vector.
        """ the above step change the input shape from 
            [batch_size,caps_n(i-1),caps_dim(i-1)] --> [batch_size,caps_n(i-1),caps_dim(i-1),1]"""

        caps_output_tile = tf.expand_dims(caps_output_expanded, 2)
        """ the above step change the input shape from 
            [batch_size,caps_n(i-1),caps_dim(i-1),1] --> [batch_size,caps_n(i-1),1,caps_dim(i-1),1]"""

        caps_output_tiled = tf.tile(caps_output_tile, [1, 1, self.caps_n, 1, 1]) # replicating the input capsule vector for every output capsule.
        """ i.e [batch_size,caps_n(i-1),1,caps_dim(i-1),1] --> [batch_size,caps_n(i-1),caps_n(i),1,caps_dim(i-1),1]"""

        caps_predicted = tf.matmul(W_tiled, caps_output_tiled) # this is performing element wise tf.matmul() operation.
        """ caps_predicted.shape = [1,caps_n(i-1),caps_n(i),caps_dim(i),1]"""

        """ dynamic routing """
        #initialization step.
        
        pi=np.ones([batch_size,k])/k
        mu=np.random.rand(batch_size,k,d)
        sigma=np.ones([batch_size,k,d])
        R=np.zeros(shape=(batch_size,n,k))

        pi=tf.convert_to_tensor(pi,dtype=tf.float64)
        mu=tf.convert_to_tensor(mu,dtype=tf.float64)
        sigma=tf.convert_to_tensor(sigma,dtype=tf.float64)
        R=tf.convert_to_tensor(R,dtype=tf.float64)

        r=self.r
        while(r):
          r=r-1
          """ E-step. """
          
          x_tmp=tf.expand_dims(input_tensor,axis=1) # x.shape==[b,n,d]
          x_tmp=tf.tile(x_tmp,[1,k,1,1]) # x_tmp.shape==[b,k,n,d]

          mu_tmp=tf.expand_dims(mu,axis=2) # mu.shape==[b,k,d]
          mu_tmp=tf.tile(mu_tmp,[1,1,n,1])   # mu_tmp.shape==[b,k,n,d]

          sig_tmp=tf.expand_dims(sigma,axis=2) # sigma.shape==[b,k,d]
          sig_tmp=tf.tile(sig_tmp,[1,1,n,1])   # sig_tmp.shape == [b,k,n,d]

          N = tfd.MultivariateNormalDiag(loc=mu_tmp,scale_diag=sig_tmp).prob(x_tmp)
          N = pi[:,:,None]*N
          N = N/tf.expand_dims(tf.reduce_sum(N,axis=1),axis=1)
          R = tf.transpose(N,perm=[0,2,1])

          """ M-step. """
          
          # updating pi.
          N_k = tf.reduce_sum(R,axis=1)
          pi = N_k/n

          # updating mu.
          mu = tf.matmul(tf.transpose(R,perm=[0,2,1]),input_tensor)
          mu = mu/N_k[:,:,None]

          # updating sigma.
          mu_tmp=tf.expand_dims(mu,axis=2)
          mu_tmp=tf.tile(mu_tmp,[1,1,n,1])
          x_tmp=x_tmp-mu_tmp
          x_tmp=tf.square(x_tmp)
          R_T=tf.transpose(R,perm=[0,2,1])
          x_tmp = tf.multiply(tf.reshape(R_T,[batch_size,k,n,1]),x_tmp)
          sigma = tf.reduce_sum(x_tmp,axis=2)/tf.reshape(N_k,[batch_size,k,1])
          sigma=tf.sqrt(sigma)
              
        weighted_prediction=tf.multiply(caps_predicted,tf.reshape(R,[batch_size,n,k,1,1]))
        weighted_sum = tf.reduce_sum(weighted_prediction, axis=1, keepdims=True)
        v=squash(weighted_sum, axis=-2)
        v = tf.squeeze(v, axis=[1,4])
        return v

    def compute_output_signature(self,input_shape):
      return tf.TensorSpec(shape=[input_shape[0],self.caps_n,self.caps_dim],dtype=tf.float64)

In [15]:
c1=tf.keras.layers.Conv2D(16,kernel_size=5,strides=2,padding='valid',activation='relu')
c2=tf.keras.layers.Conv2D(32,kernel_size=5,strides=2,padding='valid',activation='relu')
c3=tf.keras.layers.Conv2D(64,kernel_size=5,strides=2,padding='valid',activation='relu')
c4=tf.keras.layers.Conv2D(128,kernel_size=5,strides=1,padding='valid',activation='relu')
dc1=tf.keras.layers.DepthwiseConv2D(kernel_size=9,strides=(1, 1),padding='valid',activation='relu')
last=Capsule(4,8)
bn1=tf.keras.layers.BatchNormalization()
bn2=tf.keras.layers.BatchNormalization()
bn3=tf.keras.layers.BatchNormalization()
bn4=tf.keras.layers.BatchNormalization()

In [16]:
model_input = keras.Input(shape=(128, 128, 3), batch_size=32)
x=c1(model_input)
x=bn1(x,training=True)
x=c2(x)
x=bn2(x,training=True)
x=c3(x)
x=bn3(x,training=True)
x=c4(x)
x=bn4(x,training=True)
x=dc1(x)
x=tf.reshape(x,[-1,16,8])
x=last(x)
x=tf.cast(x,tf.float32)
x=safe_norm(x, axis=2)
model_output = x

In [17]:
model = keras.Model(model_input, model_output, name="encoder")

In [18]:
adam = tf.keras.optimizers.Adam(learning_rate=0.0001) 

model.compile(loss=margin_loss, optimizer=adam, metrics=tf.keras.metrics.CategoricalAccuracy())
model.summary()

Model: "encoder"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_5 (InputLayer)         [(32, 128, 128, 3)]       0         
_________________________________________________________________
conv2d_8 (Conv2D)            (32, 62, 62, 16)          1216      
_________________________________________________________________
batch_normalization_8 (Batch (32, 62, 62, 16)          64        
_________________________________________________________________
conv2d_9 (Conv2D)            (32, 29, 29, 32)          12832     
_________________________________________________________________
batch_normalization_9 (Batch (32, 29, 29, 32)          128       
_________________________________________________________________
conv2d_10 (Conv2D)           (32, 13, 13, 64)          51264     
_________________________________________________________________
batch_normalization_10 (Batc (32, 13, 13, 64)          256 

In [19]:
"""customize training loop."""

# Instantiate an optimizer to train the model.
base_learning_rate = 0.0001
optimizer = tf.keras.optimizers.Adam(learning_rate=base_learning_rate)
# Instantiate a loss function.
loss_fn = margin_loss

# Prepare the metrics.
train_acc_metric = tf.keras.metrics.CategoricalAccuracy()
val_acc_metric = tf.keras.metrics.CategoricalAccuracy()

In [20]:
epochs = 30
for epoch in range(epochs):
    print("\nepoch {}/{}".format(epoch+1,epochs))
    pbar = keras.utils.Progbar(target=int(train_dataset.cardinality()))
    metrics = {}

    # Iterate over the batches of the dataset.
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        y_true = tf.keras.utils.to_categorical(y_batch_train,num_classes=4)
        with tf.GradientTape() as tape:
            y_pred=model(x_batch_train) # $ better design needed.
            # y_pred is prob. dist.
            loss_value = loss_fn(y_true,y_pred) # loss computation
        grads = tape.gradient(loss_value, model.trainable_weights) # back prop
        optimizer.apply_gradients(zip(grads, model.trainable_weights)) # weight update

        # Update training metric.
        train_acc_metric.update_state(y_true, y_pred)
        metrics.update({'train_acc':train_acc_metric.result()})
        pbar.update(step+1, values=metrics.items(), finalize=False)


    # Run a validation loop at the end of each epoch.
    for x_batch_val, y_batch_val in validation_dataset:
      y_batch_val=tf.keras.utils.to_categorical(y_batch_val,num_classes=4)
      val_pred = model(x_batch_val) # $ better design needed
      # Update val metrics
      val_acc_metric.update_state(y_batch_val, val_pred)

    metrics.update({'val_acc':val_acc_metric.result()})
    
    pbar.update(step+1, values=metrics.items(), finalize=True)
    
    # Reset training & val metrics at the end of each epoch
    train_acc_metric.reset_states()
    val_acc_metric.reset_states()


epoch 1/30


2022-11-07 11:27:57.732415: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:185] None of the MLIR Optimization Passes are enabled (registered 2)
2022-11-07 11:27:59.959675: I tensorflow/stream_executor/cuda/cuda_dnn.cc:369] Loaded cuDNN version 8005


  4/179 [..............................] - ETA: 24s - train_acc: 0.3092

Cleanup called...


 32/179 [====>.........................] - ETA: 20s - train_acc: 0.2844

Cleanup called...


 34/179 [====>.........................] - ETA: 20s - train_acc: 0.2841

Cleanup called...




Cleanup called...




Cleanup called...



epoch 2/30
  4/179 [..............................] - ETA: 22s - train_acc: 0.4245

Cleanup called...


 31/179 [====>.........................] - ETA: 19s - train_acc: 0.4239

Cleanup called...


 32/179 [====>.........................] - ETA: 18s - train_acc: 0.4244

Cleanup called...




Cleanup called...




Cleanup called...



epoch 3/30
  4/179 [..............................] - ETA: 23s - train_acc: 0.5391

Cleanup called...


 31/179 [====>.........................] - ETA: 19s - train_acc: 0.5949- ETA: 2

Cleanup called...


 33/179 [====>.........................] - ETA: 18s - train_acc: 0.5954

Cleanup called...




Cleanup called...




Cleanup called...



epoch 4/30
  4/179 [..............................] - ETA: 22s - train_acc: 0.7591

Cleanup called...


 31/179 [====>.........................] - ETA: 17s - train_acc: 0.6989

Cleanup called...


 33/179 [====>.........................] - ETA: 17s - train_acc: 0.6976

Cleanup called...




Cleanup called...




Cleanup called...



epoch 5/30
  4/179 [..............................] - ETA: 22s - train_acc: 0.7448

Cleanup called...


 29/179 [===>..........................] - ETA: 18s - train_acc: 0.7141

Cleanup called...


 32/179 [====>.........................] - ETA: 19s - train_acc: 0.7119

Cleanup called...




Cleanup called...




Cleanup called...



epoch 6/30
  4/179 [..............................] - ETA: 23s - train_acc: 0.7689

Cleanup called...


 31/179 [====>.........................] - ETA: 17s - train_acc: 0.7453

Cleanup called...


 32/179 [====>.........................] - ETA: 17s - train_acc: 0.7449

Cleanup called...




Cleanup called...




Cleanup called...



epoch 7/30
  4/179 [..............................] - ETA: 22s - train_acc: 0.8066

Cleanup called...


 30/179 [====>.........................] - ETA: 18s - train_acc: 0.7782

Cleanup called...


 32/179 [====>.........................] - ETA: 18s - train_acc: 0.7780

Cleanup called...




Cleanup called...




Cleanup called...



epoch 8/30
  4/179 [..............................] - ETA: 21s - train_acc: 0.8053

Cleanup called...


 31/179 [====>.........................] - ETA: 17s - train_acc: 0.7932

Cleanup called...


 32/179 [====>.........................] - ETA: 17s - train_acc: 0.7931

Cleanup called...




Cleanup called...




Cleanup called...



epoch 9/30
  4/179 [..............................] - ETA: 22s - train_acc: 0.8164

Cleanup called...


 31/179 [====>.........................] - ETA: 18s - train_acc: 0.8225

Cleanup called...


 33/179 [====>.........................] - ETA: 18s - train_acc: 0.8225

Cleanup called...




Cleanup called...




Cleanup called...



epoch 10/30
  4/179 [..............................] - ETA: 22s - train_acc: 0.8210

Cleanup called...


 30/179 [====>.........................] - ETA: 19s - train_acc: 0.8279

Cleanup called...


 32/179 [====>.........................] - ETA: 19s - train_acc: 0.8273

Cleanup called...




Cleanup called...




Cleanup called...



epoch 11/30
  4/179 [..............................] - ETA: 21s - train_acc: 0.8997

Cleanup called...


 31/179 [====>.........................] - ETA: 18s - train_acc: 0.8336

Cleanup called...


 33/179 [====>.........................] - ETA: 18s - train_acc: 0.8333

Cleanup called...




Cleanup called...




Cleanup called...



epoch 12/30
  4/179 [..............................] - ETA: 23s - train_acc: 0.8743

Cleanup called...


 31/179 [====>.........................] - ETA: 17s - train_acc: 0.8689

Cleanup called...


 32/179 [====>.........................] - ETA: 17s - train_acc: 0.8688

Cleanup called...




Cleanup called...




Cleanup called...



epoch 13/30
  4/179 [..............................] - ETA: 21s - train_acc: 0.8750

Cleanup called...


 31/179 [====>.........................] - ETA: 17s - train_acc: 0.8829

Cleanup called...


 33/179 [====>.........................] - ETA: 17s - train_acc: 0.8824

Cleanup called...




Cleanup called...




Cleanup called...



epoch 14/30
  4/179 [..............................] - ETA: 21s - train_acc: 0.9460

Cleanup called...


 30/179 [====>.........................] - ETA: 17s - train_acc: 0.9035

Cleanup called...


 33/179 [====>.........................] - ETA: 17s - train_acc: 0.9026

Cleanup called...




Cleanup called...




Cleanup called...



epoch 15/30
  4/179 [..............................] - ETA: 22s - train_acc: 0.8652

Cleanup called...


 30/179 [====>.........................] - ETA: 19s - train_acc: 0.8887

Cleanup called...


 33/179 [====>.........................] - ETA: 18s - train_acc: 0.8890

Cleanup called...




Cleanup called...




Cleanup called...



epoch 16/30
  4/179 [..............................] - ETA: 22s - train_acc: 0.9525

Cleanup called...


 31/179 [====>.........................] - ETA: 17s - train_acc: 0.9287

Cleanup called...


 33/179 [====>.........................] - ETA: 17s - train_acc: 0.9280

Cleanup called...




Cleanup called...




Cleanup called...



epoch 17/30
  4/179 [..............................] - ETA: 21s - train_acc: 0.9375

Cleanup called...


 30/179 [====>.........................] - ETA: 17s - train_acc: 0.9090

Cleanup called...


 33/179 [====>.........................] - ETA: 17s - train_acc: 0.9093

Cleanup called...




Cleanup called...




Cleanup called...



epoch 18/30
  4/179 [..............................] - ETA: 22s - train_acc: 0.9173

Cleanup called...


 31/179 [====>.........................] - ETA: 19s - train_acc: 0.9372

Cleanup called...


 32/179 [====>.........................] - ETA: 18s - train_acc: 0.9372

Cleanup called...




Cleanup called...




Cleanup called...



epoch 19/30
  3/179 [..............................] - ETA: 23s - train_acc: 0.9514

Cleanup called...


 31/179 [====>.........................] - ETA: 18s - train_acc: 0.9376

Cleanup called...


 33/179 [====>.........................] - ETA: 18s - train_acc: 0.9377

Cleanup called...




Cleanup called...




Cleanup called...



epoch 20/30
  4/179 [..............................] - ETA: 22s - train_acc: 0.9674

Cleanup called...


 30/179 [====>.........................] - ETA: 18s - train_acc: 0.9458

Cleanup called...


 33/179 [====>.........................] - ETA: 18s - train_acc: 0.9450

Cleanup called...




Cleanup called...




Cleanup called...



epoch 21/30
  4/179 [..............................] - ETA: 41s - train_acc: 0.9603

Cleanup called...


 30/179 [====>.........................] - ETA: 21s - train_acc: 0.9543

Cleanup called...


 33/179 [====>.........................] - ETA: 20s - train_acc: 0.9545

Cleanup called...




Cleanup called...




Cleanup called...



epoch 22/30
  4/179 [..............................] - ETA: 22s - train_acc: 0.9668

Cleanup called...


 30/179 [====>.........................] - ETA: 18s - train_acc: 0.9634

Cleanup called...


 33/179 [====>.........................] - ETA: 17s - train_acc: 0.9632

Cleanup called...




Cleanup called...




Cleanup called...



epoch 23/30
  4/179 [..............................] - ETA: 22s - train_acc: 0.9616

Cleanup called...


 31/179 [====>.........................] - ETA: 17s - train_acc: 0.9645

Cleanup called...


 33/179 [====>.........................] - ETA: 17s - train_acc: 0.9644

Cleanup called...




Cleanup called...




Cleanup called...



epoch 24/30
  4/179 [..............................] - ETA: 22s - train_acc: 0.9772

Cleanup called...


 31/179 [====>.........................] - ETA: 17s - train_acc: 0.9706

Cleanup called...


 33/179 [====>.........................] - ETA: 17s - train_acc: 0.9704

Cleanup called...




Cleanup called...




Cleanup called...



epoch 25/30
  4/179 [..............................] - ETA: 22s - train_acc: 0.9753

Cleanup called...


 31/179 [====>.........................] - ETA: 17s - train_acc: 0.9690

Cleanup called...


 32/179 [====>.........................] - ETA: 17s - train_acc: 0.9689

Cleanup called...




Cleanup called...




Cleanup called...



epoch 26/30
  4/179 [..............................] - ETA: 25s - train_acc: 0.9661

Cleanup called...


 30/179 [====>.........................] - ETA: 17s - train_acc: 0.9647

Cleanup called...


 32/179 [====>.........................] - ETA: 17s - train_acc: 0.9647

Cleanup called...




Cleanup called...




Cleanup called...



epoch 27/30
  4/179 [..............................] - ETA: 21s - train_acc: 0.9362

Cleanup called...


 31/179 [====>.........................] - ETA: 18s - train_acc: 0.9743

Cleanup called...


 33/179 [====>.........................] - ETA: 18s - train_acc: 0.9748

Cleanup called...




Cleanup called...




Cleanup called...



epoch 28/30
  4/179 [..............................] - ETA: 21s - train_acc: 0.9980

Cleanup called...


 31/179 [====>.........................] - ETA: 18s - train_acc: 0.9883

Cleanup called...


 33/179 [====>.........................] - ETA: 18s - train_acc: 0.9882

Cleanup called...




Cleanup called...




Cleanup called...



epoch 29/30
  4/179 [..............................] - ETA: 24s - train_acc: 0.9824

Cleanup called...


 31/179 [====>.........................] - ETA: 18s - train_acc: 0.9839

Cleanup called...


 33/179 [====>.........................] - ETA: 18s - train_acc: 0.9839

Cleanup called...




Cleanup called...




Cleanup called...



epoch 30/30
  4/179 [..............................] - ETA: 22s - train_acc: 0.9870

Cleanup called...


 30/179 [====>.........................] - ETA: 17s - train_acc: 0.9891

Cleanup called...


 33/179 [====>.........................] - ETA: 17s - train_acc: 0.9891

Cleanup called...




Cleanup called...




Cleanup called...


