In [1]:
import numpy as np
import tensorflow as tf
from tensorflow import keras

# %matplotlib inline
import matplotlib
import matplotlib.pyplot as plt

import os
import time

In [2]:
# Number of target classes; passed explicitly so the one-hot width does not
# depend on which labels happen to be present in a given split.
NUM_CLASSES = 10

(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()

# Cast before dividing: dividing the raw integer images by a Python float
# would promote them to float64, doubling memory for no benefit.
# assumes pixel values are in 0..255, so the result lies in [0, 1] — TODO confirm
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

# One-hot encode the integer labels -> shape (num_samples, NUM_CLASSES).
y_train = tf.keras.utils.to_categorical(y_train, num_classes=NUM_CLASSES)
y_test = tf.keras.utils.to_categorical(y_test, num_classes=NUM_CLASSES)

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz


In [3]:
# Sanity check: show the shapes of the image and one-hot label arrays.
print(*(arr.shape for arr in (x_train, x_test, y_train, y_test)))

(50000, 32, 32, 3) (10000, 32, 32, 3) (50000, 10) (10000, 10)


In [4]:
# Input-pipeline settings.
BATCH_SIZE = 64
# NOTE(review): a buffer of 100 only shuffles locally relative to the full
# training set; consider a larger buffer if sample order matters.
SHUFFLE_BUFFER_SIZE = 100

# Training pipeline: slice the in-memory arrays, shuffle, then batch.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(SHUFFLE_BUFFER_SIZE)
    .batch(BATCH_SIZE)
)

# Evaluation pipeline: the test split is batched in its original order.
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
validation_dataset = test_dataset.batch(BATCH_SIZE)


2022-10-23 02:29:27.128090: I tensorflow/core/common_runtime/process_util.cc:146] Creating new thread pool with default inter op setting: 2. Tune using inter_op_parallelism_threads for best performance.


In [5]:
# Let tf.data choose the prefetch buffer size for both pipelines.
AUTOTUNE = tf.data.AUTOTUNE

train_dataset, validation_dataset = (
    dataset.prefetch(buffer_size=AUTOTUNE)
    for dataset in (train_dataset, validation_dataset)
)

In [6]:
@tf.function
def squash(x, axis=-1):
    """Rescale vectors along `axis` by ``norm / (1 + norm**2)``.

    Args:
        x: tensor holding the vectors to rescale.
        axis: axis along which the squared norm is taken.

    Returns:
        Tensor of the same shape as `x`; every vector keeps its direction and
        is multiplied by ``sqrt(n) / (1 + n)``, where ``n`` is its squared
        norm plus the Keras backend epsilon.
    """
    sq_norm = tf.math.reduce_sum(tf.math.square(x), axis, keepdims=True)
    # Epsilon keeps the square root away from an exact zero.
    sq_norm = sq_norm + keras.backend.epsilon()
    norm = tf.math.sqrt(sq_norm)
    return (norm / (1 + sq_norm)) * x

@tf.function
def margin_loss(y_true, y_pred):
    """Per-sample margin loss summed over the last axis.

    Args:
        y_true: target tensor; multiplied element-wise with the loss terms
            (the notebook passes one-hot labels).
        y_pred: predicted score per class, same shape as `y_true`.

    Returns:
        Tensor with the last axis reduced away: for each sample, the sum of
        ``y_true * relu((1 - margin) - y_pred)**2`` and
        ``lamb * (1 - y_true) * relu(y_pred - margin)**2``.
    """
    lamb, margin = 0.5, 0.1
    upper_margin = 1 - margin

    # Penalty where the target is set but the prediction is below 1 - margin.
    present_term = y_true * tf.math.square(tf.nn.relu(upper_margin - y_pred))
    # Down-weighted penalty where the target is unset but the prediction exceeds margin.
    absent_term = lamb * (1 - y_true) * tf.math.square(tf.nn.relu(y_pred - margin))

    return tf.math.reduce_sum(present_term + absent_term, axis=-1)

#@tf.function
def safe_norm(s, axis=-1, epsilon=1e-7, keep_dims=False):
    """Euclidean norm along `axis`, with `epsilon` added under the square root.

    Args:
        s: input tensor.
        axis: axis (or axes) to reduce over.
        epsilon: small constant added to the squared norm before the root.
        keep_dims: if True, keep the reduced axis with length 1.

    Returns:
        ``sqrt(sum(s**2, axis) + epsilon)``.
    """
    sum_of_squares = tf.reduce_sum(tf.square(s), axis=axis, keepdims=keep_dims)
    return tf.sqrt(sum_of_squares + epsilon)

In [7]:
class Capsule(keras.layers.Layer):
    """Capsule layer with iterative routing-by-agreement.

    Maps input capsules of shape ``[batch, n_in, d_in]`` to output capsules of
    shape ``[batch, num_capsule, dim_capsule]``.

    Args:
        num_capsule: number of output capsules.
        dim_capsule: length of each output capsule vector.
        routings: number of routing iterations; must be at least 1.
        **kwargs: forwarded to ``keras.layers.Layer``.

    Raises:
        ValueError: if `routings` is smaller than 1.
    """

    def __init__(self,
                 num_capsule,
                 dim_capsule,
                 routings=3,
                 **kwargs):
        super().__init__(**kwargs)
        if routings < 1:
            # With zero iterations no output capsule would ever be computed.
            raise ValueError(
                "routings must be >= 1, got {}".format(routings))
        self.caps_n = num_capsule
        self.caps_dim = dim_capsule
        self.r = routings

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config().copy()
        config.update({
            'num_capsule': self.caps_n,
            'dim_capsule': self.caps_dim,
            'routings': self.r,
        })
        return config

    def build(self, input_shape):
        """Create the transformation weights.

        `W` holds one ``[dim_capsule, d_in]`` matrix per (input capsule,
        output capsule) pair. The leading axis of size 1 broadcasts over the
        batch dimension in `call`.
        """
        self.W = self.add_weight(name='W',
                    shape=[1, input_shape[1], self.caps_n, self.caps_dim, input_shape[-1]],
                    dtype=tf.float32,
                    initializer='glorot_uniform',
                    trainable=True)

    def call(self, input_tensor):
        """Route the input capsules to the output capsules.

        Args:
            input_tensor: tensor of shape ``[batch, n_in, d_in]``.

        Returns:
            Tensor of shape ``[batch, num_capsule, dim_capsule]`` with the
            squashed output capsule vectors.
        """
        # [batch, n_in, d_in] -> [batch, n_in, 1, d_in, 1]: each input capsule
        # becomes a column vector, with a singleton axis for the output capsules.
        caps_input = tf.expand_dims(tf.expand_dims(input_tensor, -1), 2)

        # Prediction vectors. tf.matmul broadcasts the leading (batch) axes, so
        # neither W nor the input has to be tiled and the batch size does not
        # need to be known statically.
        # [1, n_in, n_out, d_out, d_in] @ [batch, n_in, 1, d_in, 1]
        #   -> [batch, n_in, n_out, d_out, 1]
        caps_predicted = tf.matmul(self.W, caps_input)

        # Routing logits (not trainable), one per (input, output) capsule pair:
        # shape [batch, n_in, n_out, 1, 1], initialised to zero.
        raw_weights = tf.zeros_like(caps_predicted[:, :, :, :1, :])

        for iteration in range(self.r):
            # Softmax over the output-capsule axis: every input capsule
            # distributes a total weight of 1 across the output capsules.
            routing_weights = tf.nn.softmax(raw_weights, axis=2)

            # Weighted sum over the input capsules
            #   -> [batch, 1, n_out, d_out, 1]
            weighted_sum = tf.reduce_sum(
                routing_weights * caps_predicted, axis=1, keepdims=True)

            # Squash along the capsule-vector axis.
            v = squash(weighted_sum, axis=-2)

            if iteration < self.r - 1:
                # Agreement = dot product between each prediction vector and
                # the current output capsule -> [batch, n_in, n_out, 1, 1].
                agreement = tf.reduce_sum(
                    caps_predicted * v, axis=-2, keepdims=True)
                # Accumulate into the logits (not into the softmax output,
                # which is recomputed from the logits every iteration).
                raw_weights += agreement

        # Drop the singleton axes -> [batch, n_out, d_out]
        return tf.squeeze(v, axis=[1, 4])


In [8]:
# Layer instances shared by the model-building cells below.
# Creation order is kept as-is because Keras derives default layer names
# (conv2d, conv2d_1, ...) from it.
c1 = tf.keras.layers.Conv2D(
    16, kernel_size=5, strides=1, padding='valid', activation='relu')
c2 = tf.keras.layers.Conv2D(
    32, kernel_size=5, strides=2, padding='valid', activation='relu')
c3 = tf.keras.layers.Conv2D(
    64, kernel_size=5, strides=1, padding='valid', activation='relu')
c4 = tf.keras.layers.Conv2D(
    128, kernel_size=3, strides=2, padding='valid', activation='relu')

dc1 = tf.keras.layers.DepthwiseConv2D(
    kernel_size=9, strides=(1, 1), padding='valid', activation='relu')

# Output stage: 10 capsules of dimension 16.
last = Capsule(10, 16)

bn1 = tf.keras.layers.BatchNormalization()
bn2 = tf.keras.layers.BatchNormalization()
bn3 = tf.keras.layers.BatchNormalization()
bn4 = tf.keras.layers.BatchNormalization()

In [9]:
# Symbolic 32x32 RGB input; the batch dimension is fixed to the pipeline's
# batch size (BATCH_SIZE is 64, set where the datasets are built).
model_input = keras.Input(shape=(32, 32, 3), batch_size=BATCH_SIZE)

In [10]:
# Shape probe: push the symbolic input through c1 -> c2 -> dc1 (skipping the
# BatchNormalization layers) and display the resulting KerasTensor so its
# shape can be read off before the reshape into capsules.
dc1(c2(c1(model_input)))

<KerasTensor: shape=(64, 4, 4, 32) dtype=float32 (created by layer 'depthwise_conv2d')>

In [11]:
# Wire the functional graph: conv -> BN -> conv -> BN -> depthwise conv
# -> reshape into capsules -> Capsule layer -> capsule lengths.
model_input = keras.Input(shape=(32, 32, 3), batch_size=BATCH_SIZE)

# NOTE(review): training=True is hard-coded on both BatchNormalization calls,
# so they stay in training mode for every call of the model, validation
# included. Confirm this is intended before changing it; the training loop
# currently calls the model without a training flag.
x = c1(model_input)
x = bn1(x, training=True)
x = c2(x)
x = bn2(x, training=True)
x = dc1(x)

# Regroup the feature map into 32 input capsules of dimension 16.
x = tf.reshape(x, [-1, 32, 16])
x = last(x)

# The length of each output capsule vector is the score for that class.
x = safe_norm(x, axis=2)
model_output = x


In [12]:
# Wrap the wired graph into a Keras model.
model = keras.Model(inputs=model_input, outputs=model_output, name="encoder")

In [13]:
# Compile with the margin loss and categorical accuracy, then print the
# layer-by-layer summary.
adam = tf.keras.optimizers.Adam(learning_rate=0.0001)

model.compile(
    optimizer=adam,
    loss=margin_loss,
    metrics=[tf.keras.metrics.CategoricalAccuracy()],
)
model.summary()

Model: "encoder"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         [(64, 32, 32, 3)]         0         
_________________________________________________________________
conv2d (Conv2D)              (64, 28, 28, 16)          1216      
_________________________________________________________________
batch_normalization (BatchNo (64, 28, 28, 16)          64        
_________________________________________________________________
conv2d_1 (Conv2D)            (64, 12, 12, 32)          12832     
_________________________________________________________________
batch_normalization_1 (Batch (64, 12, 12, 32)          128       
_________________________________________________________________
depthwise_conv2d (DepthwiseC (64, 4, 4, 32)            2624      
_________________________________________________________________
tf.reshape (TFOpLambda)      (64, 32, 16)              0   

In [14]:
"""customize training loop."""

# Optimizer used by the hand-written training loop.
base_learning_rate = 0.0001
optimizer = tf.keras.optimizers.Adam(learning_rate=base_learning_rate)

# Loss applied to each batch inside the loop.
loss_fn = margin_loss

# Separate accuracy trackers for the training and validation passes.
train_acc_metric = tf.keras.metrics.CategoricalAccuracy()
val_acc_metric = tf.keras.metrics.CategoricalAccuracy()

In [15]:
epochs = 100

# The number of batches per epoch does not change, so query it once.
num_train_batches = int(train_dataset.cardinality())

for epoch in range(epochs):
    print("\nepoch {}/{}".format(epoch+1,epochs))
    # The accuracy metrics already hold running values over the epoch.
    # Declaring them stateful makes the progress bar show them as-is instead
    # of averaging them a second time across steps.
    pbar = keras.utils.Progbar(
        target=num_train_batches,
        stateful_metrics=['train_acc', 'val_acc'],
    )
    metrics = {}

    # --- training pass -----------------------------------------------------
    for step, (x_batch_train, y_true) in enumerate(train_dataset):
        with tf.GradientTape() as tape:
            # training is passed explicitly so the call does not rely on a
            # default mode.
            y_pred = model(x_batch_train, training=True)
            # y_pred holds one capsule length per class.
            # loss_fn returns one value per sample; reduce to a scalar batch
            # mean so the gradient scale does not depend on the batch size.
            loss_value = tf.reduce_mean(loss_fn(y_true, y_pred))
        grads = tape.gradient(loss_value, model.trainable_weights)  # back prop
        optimizer.apply_gradients(zip(grads, model.trainable_weights))  # weight update

        # Update training metric.
        train_acc_metric.update_state(y_true, y_pred)
        metrics.update({'train_acc': train_acc_metric.result()})
        pbar.update(step+1, values=metrics.items(), finalize=False)

    # --- validation pass at the end of each epoch --------------------------
    for x_batch_val, y_batch_val in validation_dataset:
        val_pred = model(x_batch_val, training=False)
        val_acc_metric.update_state(y_batch_val, val_pred)

    metrics.update({'val_acc': val_acc_metric.result()})

    pbar.update(step+1, values=metrics.items(), finalize=True)

    # Reset training & val metrics at the end of each epoch
    train_acc_metric.reset_states()
    val_acc_metric.reset_states()


epoch 1/100


2022-10-23 02:29:32.180334: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:185] None of the MLIR Optimization Passes are enabled (registered 2)



epoch 2/100

epoch 3/100

epoch 4/100

epoch 5/100

epoch 6/100

epoch 7/100

epoch 8/100

epoch 9/100

epoch 10/100

epoch 11/100

epoch 12/100

epoch 13/100

epoch 14/100

epoch 15/100

epoch 16/100

epoch 17/100

epoch 18/100

epoch 19/100

epoch 20/100

epoch 21/100

epoch 22/100

epoch 23/100

epoch 24/100

epoch 25/100

epoch 26/100

epoch 27/100

epoch 28/100

epoch 29/100

epoch 30/100

epoch 31/100

epoch 32/100

epoch 33/100

epoch 34/100

epoch 35/100

epoch 36/100

epoch 37/100

epoch 38/100

epoch 39/100

epoch 40/100

epoch 41/100

epoch 42/100

epoch 43/100

epoch 44/100

epoch 45/100

epoch 46/100

epoch 47/100

epoch 48/100

epoch 49/100

epoch 50/100

epoch 51/100

epoch 52/100

epoch 53/100

epoch 54/100

epoch 55/100

epoch 56/100

epoch 57/100

epoch 58/100

epoch 59/100

epoch 60/100

epoch 61/100

epoch 62/100

epoch 63/100

epoch 64/100

epoch 65/100

epoch 66/100

epoch 67/100

epoch 68/100

epoch 69/100

epoch 70/100

epoch 71/100

epoch 72/100

epoch 73/100
