In [None]:
import tensorflow as tf
import keras
from keras.models import Model
from keras.layers import Conv2D, Dense, Input, Reshape, Lambda, Layer, Flatten
from keras import backend as K
import numpy as np
from tqdm import tqdm

from keras import initializers, regularizers
from keras.utils import to_categorical
from keras.layers.core import Activation
import pathlib
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from datetime import datetime
import matplotlib.pyplot as plt

In [None]:
# Download and extract the TensorFlow flowers archive, then fix the input-pipeline settings.
import pathlib

dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = pathlib.Path(
    tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
)

# Images per batch and the (height, width) every image is resized to by the generators.
batch_size = 64
img_height, img_width = 56, 56

Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz


In [None]:
# Training images are augmented (small rotations / shifts / shear / zoom); validation images
# are only rescaled, so validation metrics are measured on unmodified data.
# Both generators use the same validation_split; flow_from_directory picks each subset's files
# deterministically from that fraction, so the train/validation partition is identical.
VALIDATION_SPLIT = 0.2

train_gen = ImageDataGenerator(rescale=1./255, validation_split=VALIDATION_SPLIT, rotation_range=8, width_shift_range=0.08, shear_range=0.3, height_shift_range=0.08, zoom_range=0.08)
val_gen = ImageDataGenerator(rescale=1./255, validation_split=VALIDATION_SPLIT)
gen = train_gen  # original name, kept for any later cell that refers to `gen`

train_ds = train_gen.flow_from_directory(
    data_dir,
    target_size=(img_height, img_width),
    batch_size=batch_size,
    shuffle=True,
    seed=123,
    subset='training',
    class_mode='sparse'  # class ids (yielded as floats, see printed `y` below)
)

val_ds = val_gen.flow_from_directory(
    data_dir,
    target_size=(img_height, img_width),
    class_mode='sparse',
    batch_size=batch_size,
    shuffle=True,
    seed=123,
    subset='validation'
)

num_classes = len(train_ds.class_indices)

Found 2939 images belonging to 5 classes.
Found 731 images belonging to 5 classes.


In [None]:
# Where checkpoints, weights and TensorBoard logs are written
# (presumably a mounted Google Drive in Colab — adjust for local runs).
checkpoint_path = 'drive/My Drive/TIES4911/capsule_network/checkpoint'
model_weight_path = 'drive/My Drive/TIES4911/capsule_network/model_weight'

# Each run logs its scalars to a fresh timestamped directory.
stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
scalar_logdir = f'drive/My Drive/TIES4911/capsule_network/logs/scalars/{stamp}'
file_writer = tf.summary.create_file_writer(f"{scalar_logdir}/metrics")

In [None]:
# Per-epoch step limit used by the training loop. These are nominal budgets, not the real
# subset sizes: the generators above report 2939 training / 731 validation images.
dataset_len = 1000
training_dataset_size = dataset_len * batch_size
testing_dataset_size = training_dataset_size * 0.2

In [None]:
class CapsuleNetwork(tf.keras.Model):
    """CapsNet-style classifier with a dense reconstruction decoder.

    Pipeline: Conv2D -> BatchNorm -> ReLU -> primary-capsule Conv2D -> squash ->
    per-pair pose transform (`self.w`) -> `r` iterations of dynamic routing ->
    output capsules `v`. `v` masked by the one-hot label feeds a 3-layer dense
    decoder that reconstructs the flattened input image.

    Args:
        no_of_conv_kernels: filters of the first convolution.
        no_of_primary_capsules: number of primary-capsule types (channel groups).
        primary_capsule_vector: length of each primary-capsule vector.
        no_of_secondary_capsules: number of output capsules (one per class).
        secondary_capsule_vector: length of each output-capsule vector.
        r: number of dynamic-routing iterations.
        reconstruction_dim: length of the flattened reconstruction; must equal
            H * W * C of the input images. Default 9408 = 56 * 56 * 3.
    """

    def __init__(self, no_of_conv_kernels, no_of_primary_capsules, primary_capsule_vector,
                 no_of_secondary_capsules, secondary_capsule_vector, r, reconstruction_dim=9408):
        super(CapsuleNetwork, self).__init__()
        self.no_of_conv_kernels = no_of_conv_kernels
        self.no_of_primary_capsules = no_of_primary_capsules
        self.primary_capsule_vector = primary_capsule_vector
        self.no_of_secondary_capsules = no_of_secondary_capsules
        self.secondary_capsule_vector = secondary_capsule_vector
        self.r = r
        self.reconstruction_dim = reconstruction_dim

        with tf.name_scope("Variables") as scope:
            # Encoder: 48x48 'valid' conv -> BatchNorm -> ReLU.
            self.convolution = tf.keras.layers.Conv2D(self.no_of_conv_kernels, 48, strides=[1,1], kernel_initializer="he_normal")
            self.convolution2 = tf.keras.layers.BatchNormalization(axis=3)
            self.convolution3 = tf.keras.layers.Activation('relu')
            # Every group of `primary_capsule_vector` output channels forms one primary capsule.
            self.primary_capsule = tf.keras.layers.Conv2D(self.no_of_primary_capsules * self.primary_capsule_vector, [9,9], strides=[2,2], name="PrimaryCapsule")
            # Pose matrices, one (secondary_vec x primary_vec) matrix per (primary, secondary) pair.
            # Dim 1 must equal the number of primary capsules built in `_capsule_outputs`,
            # i.e. no_of_primary_capsules * H * W of the PrimaryCapsule output. This assumes
            # H = W = 1, which holds for 56x56 inputs (56-48+1 = 9, (9-9)//2+1 = 1).
            self.w = tf.Variable(tf.random_normal_initializer()(shape=[1, self.no_of_primary_capsules, self.no_of_secondary_capsules, self.secondary_capsule_vector, self.primary_capsule_vector]), dtype=tf.float32, name="PoseEstimation", trainable=True)
            # Reconstruction decoder.
            self.dense_1 = tf.keras.layers.Dense(units = 512, activation='relu')
            self.dense_2 = tf.keras.layers.Dense(units = 1024, activation='relu')
            self.dense_3 = tf.keras.layers.Dense(units = self.reconstruction_dim, activation='sigmoid', dtype='float32')

    def build(self, input_shape):
        # All layers and variables are created eagerly in __init__.
        pass

    def squash(self, s, epsilon=None):
        """Squash non-linearity: ||s||^2 / (1 + ||s||^2) * s / ||s|| along the last axis.

        `epsilon` (default: the Keras backend epsilon) keeps the division finite when
        ||s|| == 0; it is resolved here instead of via a global defined in a later cell.
        """
        if epsilon is None:
            epsilon = tf.keras.backend.epsilon()
        with tf.name_scope("SquashFunction") as scope:
            s_norm = tf.norm(s, axis=-1, keepdims=True)
            return tf.square(s_norm)/(1 + tf.square(s_norm)) * s/(s_norm + epsilon)

    def _capsule_outputs(self, input_x, training=None):
        """Encoder + primary capsules + dynamic routing shared by `call` and prediction.

        Returns the output capsules `v`, shape (batch, 1, no_of_secondary_capsules,
        secondary_capsule_vector). `training` is forwarded to BatchNormalization.
        """
        x = self.convolution(input_x)                  # (batch, 9, 9, no_of_conv_kernels) for 56x56 input
        x = self.convolution2(x, training=training)
        x = self.convolution3(x)
        x = self.primary_capsule(x)                    # (batch, 1, 1, no_of_primary_capsules * primary_capsule_vector)

        with tf.name_scope("CapsuleFormation") as scope:
            num_input_capsules = self.no_of_primary_capsules * x.shape[1] * x.shape[2]
            u = tf.reshape(x, (-1, num_input_capsules, self.primary_capsule_vector))  # (batch, N, 8)
            u = self.squash(u)
            u = tf.expand_dims(u, axis=-2)             # (batch, N, 1, 8)
            u = tf.expand_dims(u, axis=-1)             # (batch, N, 1, 8, 1)
            u_hat = tf.matmul(self.w, u)               # (batch, N, n_sec, 16, 1)
            u_hat = tf.squeeze(u_hat, [4])             # (batch, N, n_sec, 16)

        with tf.name_scope("DynamicRouting") as scope:
            # Routing logits, (batch, N, n_sec, 1). Built from u_hat so an unknown
            # (None) batch dimension is supported.
            b = tf.zeros_like(u_hat[..., :1])
            for _ in range(self.r):
                c = tf.nn.softmax(b, axis=-2)          # coupling coefficients over output capsules
                s = tf.reduce_sum(tf.multiply(c, u_hat), axis=1, keepdims=True)  # (batch, 1, n_sec, 16)
                v = self.squash(s)
                # Dot product <u_hat_i|j, v_j>: (batch, N, n_sec, 1, 16) @ (batch, 1, n_sec, 16, 1)
                # -> (batch, N, n_sec, 1, 1), squeezed to (batch, N, n_sec, 1).
                agreement = tf.squeeze(tf.matmul(tf.expand_dims(u_hat, axis=-1), tf.expand_dims(v, axis=-1), transpose_a=True), [4])
                b += agreement
        return v

    @tf.function
    def call(self, inputs, training=None):
        """Forward pass.

        Args:
            inputs: pair (images, one-hot labels); the labels select which capsule
                is passed to the decoder.
            training: forwarded to BatchNormalization. Pass True in the training
                step and False for evaluation; None keeps the previous behaviour.
        Returns:
            (v, reconstructed_image): capsules (batch, 1, n_sec, secondary_vec) and
            the flattened reconstruction (batch, reconstruction_dim).
        """
        input_x, y = inputs
        v = self._capsule_outputs(input_x, training=training)

        with tf.name_scope("Masking") as scope:
            y = tf.expand_dims(y, axis=-1)             # (batch, n_sec, 1)
            y = tf.expand_dims(y, axis=1)              # (batch, 1, n_sec, 1)
            mask = tf.cast(y, dtype=tf.float32)
            v_masked = tf.multiply(mask, v)            # zero every capsule except the labelled one

        with tf.name_scope("Reconstruction") as scope:
            v_ = tf.reshape(v_masked, [-1, self.no_of_secondary_capsules * self.secondary_capsule_vector])
            reconstructed_image = self.dense_1(v_)
            reconstructed_image = self.dense_2(reconstructed_image)
            reconstructed_image = self.dense_3(reconstructed_image)

        return v, reconstructed_image

    @tf.function
    def predict_capsule_output(self, inputs):
        """Return only the output capsules `v` for an image batch (no labels needed)."""
        return self._capsule_outputs(inputs)

In [None]:
def safe_norm(v, axis=-1, epsilon=1e-7):
    """Euclidean length of `v` along `axis` (axis kept, size 1).

    `epsilon` is added under the square root so the value and its gradient stay
    finite for all-zero vectors.
    """
    return tf.sqrt(tf.reduce_sum(v * v, axis=axis, keepdims=True) + epsilon)

def get_val_loss(x, y):
    """Evaluate one validation batch and update `val_acc_metric`.

    Args:
        x: image batch from `val_ds`.
        y: sparse class ids (class_mode='sparse' yields them as floats).
    Returns:
        Scalar loss tensor from `loss_function`.
    """
    # tf.one_hot expects integer indices; cast explicitly instead of relying on implicit conversion.
    y_one_hot = tf.one_hot(tf.cast(y, tf.int32), depth=num_classes)
    # training=False: layers such as BatchNormalization must use inference behaviour here.
    v, reconstructed_image = model([x, y_one_hot], training=False)
    loss = loss_function(v, reconstructed_image, y_one_hot, x)
    val_acc_metric(y, normalize_prediction(v))
    return loss

def train(x, y):
    """Run one optimisation step on a batch and update `train_acc_metric`.

    Args:
        x: image batch from `train_ds`.
        y: sparse class ids (class_mode='sparse' yields them as floats).
    Returns:
        Scalar loss tensor for the batch (before the update).
    """
    # tf.one_hot expects integer indices; cast explicitly instead of relying on implicit conversion.
    y_one_hot = tf.one_hot(tf.cast(y, tf.int32), depth=num_classes)
    with tf.GradientTape() as tape:
        # training=True is required in a custom loop: otherwise Keras layers such as
        # BatchNormalization run in inference mode and never update their moving statistics.
        v, reconstructed_image = model([x, y_one_hot], training=True)
        loss = loss_function(v, reconstructed_image, y_one_hot, x)
    grad = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grad, model.trainable_variables))

    train_acc_metric(y, normalize_prediction(v))
    return loss

def loss_function(v, reconstructed_image, y, y_image):
    """Total loss for a batch: capsule margin loss + alpha * reconstruction loss.

    Args:
        v: output capsules, shape (batch, 1, n_classes, capsule_dim).
        reconstructed_image: decoder output, shape (batch, H*W*C).
        y: one-hot labels, shape (batch, n_classes).
        y_image: the input image batch; flattened here to match `reconstructed_image`.
    Returns:
        Scalar loss tensor.

    Reads the module-level hyper-parameters m_plus, m_minus, lambda_ and alpha.
    """
    # Capsule lengths act as per-class scores, (batch, n_classes); the class count is
    # taken from the labels rather than from a global.
    prediction = tf.reshape(safe_norm(v), [-1, tf.shape(y)[-1]])

    # Margin loss: penalise a short capsule for the true class and long capsules for the others.
    left_margin = tf.square(tf.maximum(0.0, m_plus - prediction))
    right_margin = tf.square(tf.maximum(0.0, prediction - m_minus))
    l = tf.add(y * left_margin, lambda_ * (1.0 - y) * right_margin)
    margin_loss = tf.reduce_mean(tf.reduce_sum(l, axis=-1))

    # Reconstruction loss: mean squared error over all pixels of the flattened images.
    # NOTE(review): this averages over pixels; if alpha is meant to weight a per-image *sum*
    # of squared errors, this term is H*W*C times smaller than intended — confirm.
    y_image_flat = tf.reshape(y_image, [-1, np.prod(y_image.shape[1:])])
    reconstruction_loss = tf.reduce_mean(tf.square(y_image_flat - reconstructed_image))

    return tf.add(margin_loss, alpha * reconstruction_loss)

def normalize_prediction(v):
    """Predicted class id per sample: argmax over output-capsule lengths.

    `v` has shape (batch, 1, n_classes, capsule_dim); returns a NumPy array of shape (batch,).
    """
    capsule_lengths = tf.squeeze(safe_norm(v), [1])   # (batch, n_classes, 1)
    return np.argmax(capsule_lengths[:, :, 0], axis=1)

def predict(model, x):
    """Classify the image batch `x` with `model` (no labels required); returns class ids."""
    return normalize_prediction(model.predict_capsule_output(x))

In [None]:
# Hyper-parameters (values based on the CapsNet paper).
epsilon = K.epsilon()       # Keras backend numerical floor
m_plus = 0.9                # margin loss: target minimum length of the true-class capsule
m_minus = 0.1               # margin loss: target maximum length of the other capsules
lambda_ = 0.5               # weight of the absent-class margin term
alpha = 0.0005              # weight of the reconstruction loss
epochs = 50
no_of_secondary_capsules = num_classes   # one output capsule per class

optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
train_acc_metric = tf.keras.metrics.Accuracy()
val_acc_metric = tf.keras.metrics.Accuracy()

params = dict(
    no_of_conv_kernels=256,
    no_of_primary_capsules=64,
    no_of_secondary_capsules=no_of_secondary_capsules,
    primary_capsule_vector=8,
    secondary_capsule_vector=16,
    r=3,
)

model = CapsuleNetwork(**params)
#model.load_weights(model_weight_path + '-epoch 4')

In [None]:
# Custom training loop: one training pass + one validation pass per epoch, saving
# weights/checkpoints each epoch and logging scalars to TensorBoard.

# NOTE(review): overrides `epochs = 50` from the parameters cell (short run); remove for a full run.
epochs = 1

checkpoint = tf.train.Checkpoint(model=model)

# Keras directory iterators keep yielding batches past the end of the data, so an epoch must
# be bounded explicitly: one full pass over each subset (len(...) batches), optionally capped
# by `dataset_len` for quick experiments.
train_steps = max(1, min(dataset_len, len(train_ds)))
val_steps = max(1, min(int(dataset_len * 0.2), len(val_ds)))

losses = []
accuracy = []
val_losses = []
val_accuary = []  # (sic) name kept for compatibility with later cells
for i in range(1, epochs + 1):
    loss = 0
    val_loss = 0

    with tqdm(total=train_steps) as pbar:
        pbar.set_description_str("Epoch " + str(i) + "/" + str(epochs))

        # ---- training pass ----
        step_per_epoch = 0
        n_seen = 0  # images seen; deliberately not named `len` (would shadow the builtin)
        for X_batch, y_batch in train_ds:
            step_per_epoch += 1
            n_seen += X_batch.shape[0]
            loss += train(X_batch, y_batch)
            pbar.update(1)
            if step_per_epoch >= train_steps:
                break

        model.save_weights(model_weight_path + '-epoch ' + str(i))
        checkpoint.save(checkpoint_path)

        print('train_ds len: ', n_seen)
        print('train_ds loss step: ', step_per_epoch)
        loss /= step_per_epoch
        losses.append(loss.numpy())

        train_acc = train_acc_metric.result()
        train_acc_metric.reset_states()
        accuracy.append(train_acc.numpy())

        # ---- validation pass ----
        step_per_epoch = 0
        n_seen = 0
        for X_batch, y_batch in val_ds:
            step_per_epoch += 1
            n_seen += X_batch.shape[0]
            val_loss += get_val_loss(X_batch, y_batch)
            if step_per_epoch >= val_steps:
                break
        print('val_ds len: ', n_seen)
        print('val_ds loss step: ', step_per_epoch)
        val_loss /= step_per_epoch
        val_losses.append(val_loss.numpy())

        val_acc = val_acc_metric.result()
        val_acc_metric.reset_states()
        val_accuary.append(val_acc.numpy())

        with file_writer.as_default():
            tf.summary.scalar('loss', data=loss.numpy(), step=i)
            tf.summary.scalar('val_loss', data=val_loss.numpy(), step=i)
            tf.summary.scalar('accuracy', data=accuracy[-1], step=i)
            tf.summary.scalar('val_accuracy', data=val_accuary[-1], step=i)

        print_statement = ("loss :" + str(loss.numpy()) + ", accuracy :" + str(accuracy[-1])
                           + ", val_loss :" + str(val_loss.numpy()) + ", val_accuracy :" + str(val_accuary[-1]))
        pbar.set_postfix_str(print_statement)

model.save_weights(model_weight_path)

print('loss: ', losses)
print('val_loss: ', val_losses)
print('accuracy: ', accuracy)
print('val_accuracy: ', val_accuary)

Epoch 1/1:   0%|          | 0/1000 [00:00<?, ?it/s]

output x:  (64, 1, 1, 512)
y:  (64, 5)
v:  (64, 1, 5, 16)
mask;  (64, 1, 5, 1)
v_masked;  (64, 1, 5, 16)
v_:  (64, 80)
reconstructed_image:  (64, 9408)
output x:  (64, 1, 1, 512)
y:  (64, 5)
v:  (64, 1, 5, 16)
mask;  (64, 1, 5, 1)
v_masked;  (64, 1, 5, 16)
v_:  (64, 80)
reconstructed_image:  (64, 9408)
prediction:  (64, 5)


Epoch 1/1:   0%|          | 1/1000 [00:11<3:08:58, 11.35s/it]

y:  [0. 4. 2. 2. 1. 0. 2. 4. 1. 3. 3. 4. 1. 2. 0. 2. 1. 0. 4. 3. 0. 2. 3. 1.
 2. 1. 1. 3. 3. 3. 1. 2. 1. 2. 2. 3. 2. 1. 0. 3. 0. 0. 0. 0. 4. 2. 1. 1.
 3. 1. 3. 4. 0. 3. 3. 1. 1. 1. 4. 3. 1. 4. 1. 2.]
pred:  [4 4 4 4 0 4 0 0 4 4 0 4 3 3 0 0 4 3 4 4 4 4 4 4 4 4 4 3 3 4 4 3 4 0 0 4 4
 0 4 4 3 4 4 0 0 4 4 4 4 3 4 3 3 4 4 4 4 4 4 3 4 4 4 0]
sum:  10
train_ds len:  64
train_ds loss step:  1


Epoch 1/1:   0%|          | 1/1000 [00:12<3:31:24, 12.70s/it, loss :0.75446194, accuracy :0.15625, val_loss :0.449345, val_accuracy :0.1875 Checkpoint Saved]

y:  [1. 2. 3. 0. 1. 3. 1. 1. 0. 3. 4. 2. 4. 0. 3. 0. 3. 0. 0. 2. 4. 3. 2. 0.
 2. 4. 2. 2. 0. 3. 1. 1. 2. 4. 4. 2. 4. 1. 2. 4. 1. 3. 2. 1. 4. 1. 2. 4.
 0. 3. 0. 1. 3. 3. 0. 1. 1. 4. 1. 1. 1. 1. 3. 4.]
pred:  [3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3]
sum:  12
prediction:  (64, 5)
val_ds len:  64
val_ds loss step:  1
loss:  [0.75446194]
val_loss:  [0.449345]
accuracy:  [0.15625]
val_accuracy:  [0.1875]





In [None]:
loss:  [0.4398154, 0.35719398, 0.3411422, 0.3387112, 0.33299312]
val_loss:  [0.39099583, 0.35036805, 0.33167157, 0.32724562, 0.32251003]
accuracy:  [0.26430517, 0.39441416, 0.4332425, 0.4294959, 0.4400545]
val_accuracy:  [0.39918256, 0.4414169, 0.4509537, 0.4509537, 0.45912805]