In [61]:
import numpy as np
import h5py
import keras.optimizers as optimizers
from keras.models import Model, Sequential
import keras.layers as layers
import keras.backend as K
from matplotlib import pyplot as plt

from src.capsulelayers import PrimaryCap, CapsuleLayer, Length

%matplotlib inline

In [62]:
def buildNetwork(
                input_shape=(28, 28, 1),
                n_class=10,
                dim_capsule=16,
                routings=3):
    
    # Base network
    x = layers.Input(shape=input_shape)
    
    conv1 = layers.Conv2D(filters=128, kernel_size=3, strides=1, padding='valid', activation='relu')(x)

    primarycaps = PrimaryCap(conv1, dim_capsule=8, n_channels=16, kernel_size=3, strides=2, padding='valid')

    digitcaps = CapsuleLayer(num_capsule=n_class, dim_capsule=16, routings=routings,
                             name='digitcaps')(primarycaps)

    out_caps = Length(name='capsnet')(digitcaps)
    

    # Decoder network.
    decoder = Sequential(name='decoder')
    decoder.add(layers.Reshape((16*n_class,), input_shape=(n_class, 16)))
    decoder.add(layers.Dense(512, activation='relu'))
    decoder.add(layers.Dense(1024, activation='relu'))
    decoder.add(layers.Dense(np.prod(input_shape), activation='sigmoid'))
    decoder.add(layers.Reshape(target_shape=input_shape, name='out_recon'))

    # Models for training and evaluation (prediction)
    train_model = Model(x, decoder(digitcaps))
    eval_model = Model(x, out_caps)

    return train_model, eval_model

In [63]:
train_model, eval_model = buildNetwork(input_shape=(160, 60, 3))

In [64]:
def trainGenerator(database="cuhk.h5", batch_size=32):
    # Open database
    with h5py.File(database, "r") as db:
        n_ids = len(db["train"])
        image_shape = db["train"]["0"].shape[1:]

        while True:
            batch_x = np.zeros((batch_size, *image_shape))

            for index in range(batch_size):
                pair_id = np.random.choice(n_ids)
                image_id = np.random.choice(len(db["train"][str(pair_id)]))

                batch_x[index] = db["train"][str(pair_id)][image_id]
                
            yield batch_x, batch_x

In [65]:
def margin_loss(y_true, y_pred):
    L = y_true * K.square(K.maximum(0., 0.9 - y_pred)) + \
        0.5 * (1 - y_true) * K.square(K.maximum(0., y_pred - 0.1))

    return K.mean(K.sum(L, 1))

In [66]:
train_model.compile(optimizer=optimizers.Adam(lr=0.001),
          loss=[margin_loss],
          metrics={'capsnet': 'accuracy'})

In [68]:
train_model.fit_generator(generator=trainGenerator(), steps_per_epoch=100, epochs=3)

Epoch 1/3
  4/100 [>.............................] - ETA: 614s - loss: 17.3359

KeyboardInterrupt: 