<a href="https://colab.research.google.com/github/ericatsu/msc_capsmodel/blob/main/caps_detection.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Dropout, BatchNormalization
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
import numpy as np
import os

In [None]:
# Mount Google Drive (if using Google Colab)
from google.colab import drive
drive.mount('/content/drive')

# Define constants
IMG_HEIGHT, IMG_WIDTH = 128, 128
BATCH_SIZE = 32
NUM_CLASSES = 4
EPOCHS = 10
TRAIN_DIR = "/content/drive/MyDrive/msc_capsmodel/dataset/training"
TEST_DIR = "/content/drive/MyDrive/msc_capsmodel/dataset/testing"

Mounted at /content/drive


In [None]:
# Data loading and preprocessing with augmentation
def load_data(data_dir, is_training=True):
    if is_training:
        datagen = ImageDataGenerator(
            rescale=1./255,
            rotation_range=20,
            width_shift_range=0.2,
            height_shift_range=0.2,
            horizontal_flip=True,
            zoom_range=0.2,
            brightness_range=[0.8, 1.2],
        )
    else:
        datagen = ImageDataGenerator(rescale=1./255)

    return datagen.flow_from_directory(
        data_dir,
        target_size=(IMG_HEIGHT, IMG_WIDTH),
        batch_size=BATCH_SIZE,
        class_mode='categorical'
    )

train_data = load_data(TRAIN_DIR, is_training=True)
test_data = load_data(TEST_DIR, is_training=False)

Found 5711 images belonging to 4 classes.
Found 1311 images belonging to 4 classes.


In [None]:
# Inspect a batch of data
for data_batch, labels_batch in train_data:
    print(data_batch.shape)
    print(labels_batch.shape)
    break

(32, 128, 128, 3)
(32, 4)


In [None]:
# Squash activation function for capsules
def squash(vectors, axis=-1):
    s_squared_norm = tf.reduce_sum(tf.square(vectors), axis=axis, keepdims=True)
    scale = s_squared_norm / (1 + s_squared_norm) / tf.sqrt(s_squared_norm + 1e-8)
    return scale * vectors

In [None]:
# Capsule layer
class CapsuleLayer(layers.Layer):
    def __init__(self, num_capsules, dim_capsules, routings=3, **kwargs):
        super(CapsuleLayer, self).__init__(**kwargs)
        self.num_capsules = num_capsules
        self.dim_capsules = dim_capsules
        self.routings = routings

    def build(self, input_shape):
        self.W = self.add_weight(shape=[1, input_shape[1], self.num_capsules, self.dim_capsules, input_shape[-1]],
                                 initializer='glorot_uniform',
                                 name='W')

    def call(self, inputs):
        u = tf.expand_dims(inputs, 2)
        u = tf.expand_dims(u, 3)
        u_hat = tf.reduce_sum(self.W * u, axis=-1)

        b = tf.zeros(shape=[tf.shape(inputs)[0], inputs.shape[1], self.num_capsules, 1])

        for i in range(self.routings):
            c = tf.nn.softmax(b, axis=2)
            s = tf.reduce_sum(c * u_hat, axis=1, keepdims=True)
            v = squash(s, axis=-1)
            if i < self.routings - 1:
                b += tf.reduce_sum(u_hat * v, axis=-1, keepdims=True)

        return tf.squeeze(v, axis=1)

    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.num_capsules, self.dim_capsules)

In [None]:
# Custom loss function for capsule network
def margin_loss(y_true, y_pred):
    num_classes = tf.shape(y_pred)[-1]
    y_true = tf.one_hot(tf.argmax(y_true, axis=1), depth=num_classes)
    m_plus = 0.9
    m_minus = 0.1
    lambda_val = 0.5

    y_pred = tf.clip_by_value(y_pred, 1e-9, 1.0 - 1e-9)

    L = y_true * tf.square(tf.maximum(0., m_plus - y_pred)) + \
        lambda_val * (1 - y_true) * tf.square(tf.maximum(0., y_pred - m_minus))

    return tf.reduce_mean(tf.reduce_sum(L, axis=1))

# Create the hybrid CNN-Capsule model
def create_hybrid_model():
    # Load pre-trained ResNet50
    base_model = ResNet50(include_top=False, weights='imagenet', input_shape=(IMG_HEIGHT, IMG_WIDTH, 3))

    # Unfreeze the last few layers of ResNet50
    for layer in base_model.layers[-10:]:
        layer.trainable = True

    # Add custom layers on top
    x = base_model.output
    x = GlobalAveragePooling2D()(x)
    x = Dense(256, activation='relu')(x)
    x = BatchNormalization()(x)
    x = Dropout(0.5)(x)
    x = tf.keras.layers.Reshape((-1, 256))(x)
    capsule = CapsuleLayer(num_capsules=NUM_CLASSES, dim_capsules=16)(x)
    outputs = tf.keras.layers.Lambda(lambda z: tf.sqrt(tf.reduce_sum(tf.square(z), axis=-1)))(capsule)
    model = Model(inputs=base_model.input, outputs=outputs)

    return model

In [None]:
# Create and compile the model
model = create_hybrid_model()
model.compile(optimizer=optimizers.Adam(learning_rate=1e-4), loss=margin_loss, metrics=['accuracy'])

# Print model summary
model.summary()

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5
[1m94765736/94765736[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 0us/step


In [None]:
# Learning rate scheduler
def lr_schedule(epoch):
    lr = 1e-4
    if epoch > 10:
        lr *= 0.1
    if epoch > 20:
        lr *= 0.1
    return lr

In [None]:
# Callbacks for early stopping and learning rate reduction
early_stopping = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=5, min_lr=1e-7)

In [11]:
# Train the model
history = model.fit(
    train_data,
    epochs=EPOCHS,
    validation_data=test_data,
    callbacks=[tf.keras.callbacks.LearningRateScheduler(lr_schedule), early_stopping, reduce_lr]
)

Epoch 1/10


  self._warn_if_super_not_called()


[1m179/179[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4328s[0m 23s/step - accuracy: 0.4187 - loss: 0.5593 - val_accuracy: 0.1747 - val_loss: 0.5875 - learning_rate: 1.0000e-04
Epoch 2/10
[1m179/179[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1898s[0m 11s/step - accuracy: 0.8169 - loss: 0.1693 - val_accuracy: 0.2082 - val_loss: 0.6222 - learning_rate: 1.0000e-04
Epoch 3/10
[1m179/179[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1995s[0m 11s/step - accuracy: 0.8722 - loss: 0.1172 - val_accuracy: 0.2845 - val_loss: 0.5898 - learning_rate: 1.0000e-04
Epoch 4/10
[1m179/179[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1899s[0m 11s/step - accuracy: 0.8900 - loss: 0.1027 - val_accuracy: 0.3402 - val_loss: 0.6589 - learning_rate: 1.0000e-04
Epoch 5/10
[1m179/179[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2005s[0m 11s/step - accuracy: 0.9123 - loss: 0.0826 - val_accuracy: 0.7414 - val_loss: 0.2477 - learning_rate: 1.0000e-04
Epoch 6/10
[1m179/179[0m [32m━━━━━━━━━━━━━━

In [12]:
model.save('/content/drive/MyDrive/msc_capsmodel/saved_model/capsModelV1.h5')



In [13]:
model.save('/content/drive/MyDrive/msc_capsmodel/saved_model/capsModelV1.keras')

In [14]:
test_loss, test_accuracy = model.evaluate(test_data)
print(f"Test accuracy: {test_accuracy:.4f}")

[1m41/41[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m89s[0m 2s/step - accuracy: 0.9096 - loss: 0.0882
Test accuracy: 0.9008
