In [1]:
import numpy as np
from keras.datasets import mnist
from keras.utils import to_categorical
from keras.models import Model
from keras.layers.convolutional import Convolution2D
from keras.layers.core import Dense, Reshape
from keras.layers import Input
from keras.losses import mean_squared_error
from keras.optimizers import Adam

from utils import LossHistory, plotHistory
from keras_capsnet.layer.capsnet import PrimaryCaps, Caps, Length, Mask
from keras_capsnet.losses import margin

Using TensorFlow backend.


# Args

In [2]:
# Notebook-wide configuration: number of target classes and the shape of one
# input sample as fed to the network's Input layer.
num_class, input_shape = 10, (28, 28, 1)

# Dataset

In [3]:
# Load the MNIST train/test splits; each split is an (images, labels) pair.
train_split, test_split = mnist.load_data()
x_train, y_train = train_split
x_test, y_test = test_split

In [4]:
def _scale_images(images):
    """Return `images` as float32 in [0, 1], shaped (-1, *input_shape).

    Integer-typed input is treated as raw 0-255 pixel data and is converted
    and divided by 255. Non-integer input is assumed to have been scaled
    already, so re-running this cell does not divide by 255 a second time.
    """
    if np.issubdtype(images.dtype, np.integer):
        images = images.astype('float32') / 255
    # Use the notebook-level input_shape instead of repeating 28, 28, 1 here.
    return images.reshape((-1,) + tuple(input_shape))


def _one_hot(labels):
    """One-hot encode 1-D class ids into `num_class` columns.

    Labels that are no longer 1-D are returned unchanged: they were encoded by
    a previous run of this cell, and encoding them again would add an axis.
    """
    if labels.ndim == 1:
        return to_categorical(labels, num_class)
    return labels


x_train = _scale_images(x_train)
x_test = _scale_images(x_test)

y_train = _one_hot(y_train)
y_test = _one_hot(y_test)

# Model

In [5]:
# Two inputs: the image, and the one-hot label that the Mask layer receives as
# `y_true` on the training graph.
x = Input(shape=input_shape)
y = Input(shape=(num_class,))

# Encoder: convolution -> primary capsules -> one 16-dimensional capsule per
# class. The class-capsule count follows `num_class` rather than a literal 10,
# so it cannot drift from the label width used for `y` and `to_categorical`.
encoder = Convolution2D(filters=256, kernel_size=(9, 9), activation='relu')(x)
encoder = PrimaryCaps(capsules=32, capsule_dim=8, kernel_size=(9, 9), strides=2)(encoder)
encoder = Caps(capsules=num_class, capsule_dim=16, routings=3)(encoder)

# Classification head; named 'capsule' so compile() can refer to it by name.
# NOTE(review): Length presumably reduces each capsule vector to its norm
# (the summary shows (None, 10, 16) -> (None, 10)); confirm in keras_capsnet.
output = Length(name='capsule')(encoder)

# Reconstruction head.
# NOTE(review): Mask presumably keeps only the capsule selected by `y_true`;
# confirm against the keras_capsnet implementation.
decoder = Mask()(encoder, y_true=y)
decoder = Dense(512, activation='relu')(decoder)
decoder = Dense(1024, activation='relu')(decoder)
# One sigmoid unit per input value (28 * 28 * 1 = 784 for the current config),
# derived from input_shape so the following Reshape always matches.
decoder = Dense(int(np.prod(input_shape)), activation='sigmoid')(decoder)
decoder = Reshape(input_shape, name='reconstruction')(decoder)

# Training graph takes (image, label) and emits (class lengths, reconstruction);
# the inference graph shares the encoder and needs only the image.
model_training = Model(inputs=[x, y], outputs=[output, decoder])
model = Model(inputs=x, outputs=output)

In [6]:
# Print the layer-by-layer summary of the inference model (image input ->
# 'capsule' output); the label input and reconstruction head are not part of it.
model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         (None, 28, 28, 1)         0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 20, 20, 256)       20992     
_________________________________________________________________
primary_caps_1 (PrimaryCaps) (None, 1152, 8)           5308672   
_________________________________________________________________
caps_1 (Caps)                (None, 10, 16)            1474560   
_________________________________________________________________
capsule (Length)             (None, 10)                0         
Total params: 6,804,224
Trainable params: 6,804,224
Non-trainable params: 0
_________________________________________________________________


# Training

In [7]:
# Key losses and weights by output-layer name, the same way `metrics` already
# is. With positional lists, reordering `outputs=[...]` in the model definition
# would silently attach the margin loss to the reconstruction (and vice versa).
model_training.compile(
    loss={'capsule': margin(), 'reconstruction': mean_squared_error},
    # The total loss is the weighted sum of the two per-output losses.
    loss_weights={'capsule': 0.7, 'reconstruction': 0.3},
    optimizer=Adam(),
    metrics={'capsule': 'accuracy'},
)

In [None]:
# Callback instance; its `loss` and `acc` attributes are read after training.
history = LossHistory()

# Inputs are (image, one-hot label). Targets are the label for the 'capsule'
# output and the image itself for the 'reconstruction' output. The test split
# is passed as validation data in the same (inputs, targets) layout.
holdout = ([x_test, y_test], [y_test, x_test])
hist = model_training.fit(
    x=[x_train, y_train],
    y=[y_train, x_train],
    batch_size=32,
    epochs=2,
    validation_data=holdout,
    callbacks=[history],
)

In [None]:
# Hand the `loss` and `acc` attributes of the LossHistory callback used during
# fit() to the project's plotting helper.
# NOTE(review): presumably these are the recorded training loss/accuracy
# series; confirm against utils.LossHistory.
plotHistory(history.loss, history.acc)