In [1]:
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Configuration shared by every generator in this cell.
IMG_SIZE = (128, 128)   # (height, width) every image is resized to
BATCH_SIZE = 32
VAL_SPLIT = 0.2         # fraction of each class in train_dir held out for validation
SEED = 42               # makes the shuffling of the training subset reproducible

train_dir = './dataset/training'
test_dir = './dataset/testing'

# Training images: rescaled to [0, 1] plus random augmentation.
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest',
    validation_split=VAL_SPLIT)

# Validation images: same rescaling, but NO random augmentation, so validation
# metrics are measured on unaltered images and are comparable between epochs.
# Keras assigns files to the 'training'/'validation' subsets by sorted filename
# order (not randomly), so using the same validation_split on a second
# generator yields exactly the same file partition as before.
val_datagen = ImageDataGenerator(
    rescale=1./255,
    validation_split=VAL_SPLIT)

train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    subset='training',
    seed=SEED)

validation_generator = val_datagen.flow_from_directory(
    train_dir,
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    subset='validation',
    shuffle=False,  # fixed order lets predictions be aligned with validation_generator.classes
    seed=SEED)

# Both subsets come from the same directory, so they must agree on the label encoding.
assert train_generator.class_indices == validation_generator.class_indices


Found 4571 images belonging to 4 classes.
Found 1141 images belonging to 4 classes.


#### Adjusting CapsNet Architecture for Detection
##### Capsule Network Design

In [4]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models
from capsule_layers import CapsuleLayer, Mask, margin_loss, squash

def _capsule_length(vectors):
    """Return the Euclidean norm of each capsule vector, reducing the last axis.

    ``epsilon()`` is added under the square root so the gradient stays finite
    when a capsule vector is exactly zero.
    """
    return tf.sqrt(tf.reduce_sum(tf.square(vectors), axis=-1) + tf.keras.backend.epsilon())


def CapsNet(input_shape, n_classes, routings, lam_recon=0.392):
    """Build and compile a capsule network with a reconstruction decoder.

    Args:
        input_shape: Shape of one input image, e.g. ``(128, 128, 3)``.
        n_classes: Number of classes; one output capsule is created per class.
        routings: Number of routing iterations passed to ``CapsuleLayer``.
        lam_recon: Weight of the reconstruction (MSE) loss relative to the
            margin loss. The default 0.392 is the value the original code used
            (it equals 0.0005 * 784, presumably inherited from a 28x28 MNIST
            setting); consider tuning it for other image sizes.

    Returns:
        A compiled ``keras.Model`` with inputs ``[images, one_hot_labels]`` and
        two outputs, in this order:
          * ``capsnet``: per-class capsule lengths, trained with ``margin_loss``
            and reported with accuracy;
          * ``decoder``: an image of shape ``input_shape`` reconstructed from
            the label-masked capsules, trained with MSE.
        Training therefore expects data of the form
        ``([images, one_hot_labels], [one_hot_labels, images])``.
    """
    x = layers.Input(shape=input_shape)

    # Feature extractor: one 9x9 'valid' convolution with 256 ReLU channels.
    conv1 = layers.Conv2D(filters=256, kernel_size=9, strides=1, padding='valid', activation='relu')(x)

    # Primary capsules: every run of 8 consecutive channels at each spatial
    # position becomes one 8-D capsule vector, which is then squashed.
    # NOTE(review): there is no strided PrimaryCaps convolution here, so the
    # number of primary capsules is H' * W' * 256 / 8 (120 * 120 * 32 = 460,800
    # for a 128x128 input). CapsuleLayer's weights/predictions presumably scale
    # with this count -- confirm the model fits in memory.
    primarycaps = layers.Reshape(target_shape=[-1, 8], name='primarycap_reshape')(conv1)
    primarycaps = layers.Lambda(squash, name='primarycap_squash')(primarycaps)
    # NOTE(review): normalizing after squash rescales the capsule vectors, so
    # the length bound squash presumably enforces no longer holds; kept as in
    # the original -- confirm it is intended.
    primarycaps = layers.BatchNormalization()(primarycaps)

    # One 16-D output capsule per class, computed with `routings` routing iterations.
    digitcaps = CapsuleLayer(num_capsule=n_classes, dim_capsule=16, routings=routings, name='digitcaps')(primarycaps)

    # Class scores are the capsule lengths. margin_loss and the accuracy metric
    # compare against one-hot labels of shape (batch, n_classes), so they must
    # receive these lengths rather than the raw capsule vectors (presumably
    # (batch, n_classes, 16)). The name 'capsnet' is what the metrics dict in
    # compile() refers to.
    out_caps = layers.Lambda(_capsule_length, name='capsnet')(digitcaps)

    # Decoder branch: Mask presumably keeps only the capsule selected by the
    # one-hot label `y`; the decoder reconstructs the input image from it.
    y = layers.Input(shape=(n_classes,))
    masked_by_y = Mask()([digitcaps, y])

    decoder = models.Sequential([
        layers.Dense(512, activation='relu'),
        layers.Dense(1024, activation='relu'),
        # Sigmoid keeps pixels in [0, 1], matching images rescaled by 1/255.
        layers.Dense(np.prod(input_shape), activation='sigmoid'),
        layers.Reshape(target_shape=input_shape)
    ], name='decoder')
    reconstruction = decoder(masked_by_y)

    model = models.Model([x, y], [out_caps, reconstruction])

    # Losses and weights are listed in output order: [capsnet, decoder].
    model.compile(optimizer='adam',
                  loss=[margin_loss, 'mse'],
                  loss_weights=[1., lam_recon],
                  metrics={'capsnet': 'accuracy'})

    return model
