In [1]:
import numpy as np
import matplotlib.pyplot as plt
import keras
import keras.backend as K

Using TensorFlow backend.


In [2]:
# MNIST digits are 28x28; zero-padding 2 px per side gives 32x32, a power of
# two, so the stride-2 stacks downstream halve it exactly (32 -> 16 -> ... -> 1).
MNIST_PAD = 2

def pad_and_scale(images, pad=MNIST_PAD):
    """Zero-pad (N, H, W) integer images and scale them to [0, 1].

    Returns a float32 array of shape (N, H + 2*pad, W + 2*pad, 1). Padding is
    done on the original integer array and the cast happens once, which avoids
    the large float64 intermediates of a zeros-and-slice-assign approach.
    float32 is also the dtype Keras computes in by default.
    """
    padded = np.pad(images, ((0, 0), (pad, pad), (pad, pad)), mode='constant')
    return padded[..., np.newaxis].astype(np.float32) / 255

(X_train, Y_train), (X_test, Y_test) = keras.datasets.mnist.load_data()
X_train, X_test = pad_and_scale(X_train), pad_and_scale(X_test)
# One-hot labels for categorical cross-entropy.
Y_train = keras.utils.to_categorical(Y_train, num_classes=10)
Y_test = keras.utils.to_categorical(Y_test, num_classes=10)

In [3]:
class CapsuleConv2D(keras.layers.Layer):
    """2-D convolutional capsule layer with dynamic routing-by-agreement.

    Input:  ``(batch, H, W, C_in, P, Q)`` -- at every spatial position there
            are ``C_in`` capsules, each holding a ``P x Q`` pose matrix.
    Output: ``(batch, ceil(H/sh), ceil(W/sw), channels, mat_size[0], Q)``.

    For every output position, each of the ``kh*kw*C_in`` input capsules in
    the receptive field casts one vote ``W_ij @ pose_i`` per output capsule
    type (``W_ij`` has shape ``mat_size``, so ``mat_size[1]`` must equal
    ``P``). Votes are combined by ``steps`` iterations of routing-by-agreement
    followed by the squash non-linearity. Padding is always ``'same'``.

    Args:
        channels: number of output capsule types.
        kernel_size: ``(kh, kw)`` receptive field, or an int for a square one.
        strides: ``(sh, sw)`` strides, or an int for equal strides.
        steps: number of routing iterations; must be >= 1.
        mat_size: shape of each transformation matrix ``W_ij``.
    """

    def __init__(self, channels, kernel_size, strides=(1,1),
                 steps=3, mat_size=(4,4), **kwargs):
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        if isinstance(strides, int):
            strides = (strides, strides)
        if steps < 1:
            # With zero iterations no output capsule would ever be computed.
            raise ValueError('steps must be >= 1, got %r' % (steps,))
        self.channels = channels
        self.kernel_size = tuple(kernel_size)
        self.strides = tuple(strides)
        self.steps = steps
        self.mat_size = tuple(mat_size)
        super().__init__(**kwargs)

    def build(self, input_shape):
        # One transformation matrix per (receptive-field input capsule,
        # output capsule type) pair; input_shape[3] is C_in.
        self.kernel = self.add_weight(
            shape=(self.kernel_size[0]*self.kernel_size[1]*input_shape[3],
                   self.channels, *self.mat_size),
            initializer=keras.initializers.TruncatedNormal(
                mean=0., stddev=1.), name='kernel')
        super().build(input_shape)

    def call(self, inputs):
        input_shape = inputs.shape
        kh, kw = self.kernel_size
        n_taps = kh * kw
        input_channel_size = inputs.shape[3]

        # Fold every capsule's pose into the channel axis: (B, H, W, C_in*P*Q).
        poses = K.reshape(inputs, (
            -1, *inputs.shape[1:3],
            inputs.shape[3]*inputs.shape[4]*inputs.shape[5]))

        def conv_expand(x):
            # Patch extraction with a fixed one-hot depthwise kernel: tap
            # (i, j) of every channel is copied into its own multiplier slot.
            conv_data_size = x.shape[3]
            conv_kernel = np.zeros((kh, kw, conv_data_size, n_taps))
            for i in range(kh):
                for j in range(kw):
                    # Row-major tap index uses the kernel WIDTH; indexing by
                    # the height collides / overflows for non-square kernels.
                    conv_kernel[i, j, :, i*kw + j] = 1
            x = K.depthwise_conv2d(x, K.constant(conv_kernel),
                                   strides=self.strides, padding='same')
            # Depthwise output channels are laid out as
            # in_channel * multiplier + tap, hence (channel, tap) here.
            x = K.reshape(x, (-1, *x.shape[1:3], conv_data_size, n_taps))
            # -> (B, H', W', n_taps, C_in*P*Q)
            x = K.permute_dimensions(x, (0, 1, 2, 4, 3))
            return x
        poses = conv_expand(poses)

        # (B, H', W', n_taps*C_in, 1, P, Q): one receptive-field capsule per row.
        poses = K.reshape(poses, (-1, *poses.shape[1:3],
            n_taps*input_channel_size,
            1, input_shape[-2], input_shape[-1]))
        # Replicate each input pose once per output capsule type.
        poses = K.tile(poses, (1, 1, 1, 1, self.channels, 1, 1))
        # Broadcast the shared transformation matrices over batch and space.
        kernel = K.reshape(self.kernel, (1, 1, 1, *self.kernel.shape))
        kernel = K.tile(kernel, K.concatenate([
            K.shape(poses)[:3], K.constant(np.ones((4,)), dtype='int32')]))
        # vote_ij = W_ij @ pose_i, a mat_size[0] x Q matrix, flattened.
        votes = kernel @ poses
        votes = K.reshape(votes, (
            -1, *votes.shape[1:5], self.mat_size[0]*input_shape[-1]))

        # Routing logits b_ij, one per (input capsule, output capsule) pair.
        b = K.zeros((1, *votes.shape[1:5]))
        b = K.tile(b, K.concatenate([
            K.shape(votes)[:1], K.constant(np.ones((4,)), dtype='int32')]))
        for t in range(self.steps):
            # Coupling coefficients: softmax over the output capsules.
            c = K.softmax(b, axis=-1)
            # Coupling-weighted sum of votes over input capsules:
            # (B, H', W', channels, D).
            s = K.sum(K.expand_dims(c) * votes, axis=3)
            # squash(s) = |s|^2 / (1 + |s|^2) * s / |s|
            s_norm2 = K.sum(K.square(s), axis=-1, keepdims=True)
            scale = s_norm2 / (1 + s_norm2)
            v = (s / K.clip(K.sqrt(s_norm2), K.epsilon(), None)) * scale
            if t == self.steps - 1:
                break
            # Agreement update: b_ij += <vote_ij, v_j>.
            b = b + K.sum(votes * K.expand_dims(v, axis=3), axis=-1)

        return K.reshape(v, (
            -1, *v.shape[1:4], self.mat_size[0], input_shape[-1]))

    def compute_output_shape(self, input_shape):
        def same_out(size, stride):
            # 'same' padding yields ceil(size / stride); keep unknown dims.
            return None if size is None else -(-size // stride)
        return (input_shape[0],
                same_out(input_shape[1], self.strides[0]),
                same_out(input_shape[2], self.strides[1]),
                self.channels,
                self.mat_size[0], input_shape[-1])

    def get_config(self):
        # Lets models containing this layer be saved and re-loaded.
        config = super().get_config()
        config.update({
            'channels': self.channels,
            'kernel_size': self.kernel_size,
            'strides': self.strides,
            'steps': self.steps,
            'mat_size': self.mat_size,
        })
        return config

In [4]:
def conv_relu(filters, strides):
    """Layer factory: 3x3 'same'-padded ReLU Conv2D with He-normal init."""
    return keras.layers.Conv2D(filters, (3,3), strides=strides, padding='same',
                               activation='relu', kernel_initializer='he_normal')

# Shared convolutional feature extractor: (32, 32, 1) -> (16, 16, 64).
X = X_input = keras.layers.Input(X_train.shape[1:])
X = keras.layers.BatchNormalization()(X)
X = conv_relu(32, (2,2))(X)
X = conv_relu(32, (1,1))(X)
X = conv_relu(64, (1,1))(X)
C = keras.Model(X_input, X)
# CF wraps the SAME layers as C (so it shares C's weights) but is frozen: a
# model built on CF reuses what C learns without updating it. Neither C nor
# CF is fit directly, so neither needs its own compile(); the freeze takes
# effect when the enclosing model is compiled.
CF = keras.Model(X_input, X)
CF.trainable = False

# Baseline CNN classifier on top of the trainable extractor C.
X = X_input = keras.layers.Input(X_train.shape[1:])
X = C(X)
for _ in range(3):
    X = conv_relu(4, (2,2))(X)
X = keras.layers.Flatten()(X)
X = keras.layers.Dense(Y_train.shape[-1], activation='softmax')(X)
CM = keras.Model(X_input, X)
CM.compile('nadam', 'categorical_crossentropy', ['acc'])

In [5]:
# Train the baseline CNN (this also trains the shared extractor C / CF).
CM.fit(x=X_train, y=Y_train,
       batch_size=32, epochs=5,
       validation_data=(X_test, Y_test))

Train on 60000 samples, validate on 10000 samples
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.callbacks.History at 0x7fe4d45b6a90>

In [6]:
# Primary capsule layout: CAPS_CHANNELS capsule types per position, each with a
# POSE_ROWS x POSE_COLS pose matrix (4 * 4 * 2 = 32 conv channels).
CAPS_CHANNELS, POSE_ROWS, POSE_COLS = 4, 4, 2
N_CLASSES = Y_train.shape[-1]
# Square (R, R) transformation matrices keep every capsule pose at (R, Q).
CAPS_MAT_SIZE = (POSE_ROWS, POSE_ROWS)

# Capsule network on top of the frozen feature extractor CF.
X = X_input = keras.layers.Input(X_train.shape[1:])
X = CF(X)
X = keras.layers.Conv2D(CAPS_CHANNELS*POSE_ROWS*POSE_COLS, (1,1),
                        padding='same', activation='tanh')(X)
X = keras.layers.Lambda(
    lambda x, shape: K.reshape(x, (-1, *x.shape[1:3], *shape)),
    output_shape=(int(X.shape[1]), int(X.shape[2]),
                  CAPS_CHANNELS, POSE_ROWS, POSE_COLS),
    arguments={'shape': (CAPS_CHANNELS, POSE_ROWS, POSE_COLS)})(X)
# Four stride-2 capsule layers (16 -> 8 -> 4 -> 2 -> 1); the last has one
# capsule per class.
for _ in range(3):
    X = CapsuleConv2D(CAPS_CHANNELS, (3,3), strides=(2,2),
                      mat_size=CAPS_MAT_SIZE)(X)
X = CapsuleConv2D(N_CLASSES, (3,3), strides=(2,2), mat_size=CAPS_MAT_SIZE)(X)

# Auxiliary probe classifying from the class-capsule poses; stop_gradient
# keeps this head from training the capsule layers.
X_pose = keras.layers.Lambda(
    lambda x, d: K.stop_gradient(K.reshape(x, (-1, d))),
    output_shape=(N_CLASSES*POSE_ROWS*POSE_COLS,),
    arguments={'d': N_CLASSES*POSE_ROWS*POSE_COLS})(X)
X_pose = keras.layers.Dense(64, activation='tanh')(X_pose)
X_pose = keras.layers.Dense(N_CLASSES, activation='softmax')(X_pose)

# Class score = Frobenius norm of each class capsule's pose matrix.
# NOTE(review): squashed norms lie in [0, 1), so this softmax stays close to
# uniform; a margin loss on the norms may give a stronger signal.
X = keras.layers.Lambda(
    lambda x, d: K.reshape(K.sqrt(K.sum(K.sum(
        K.square(x), axis=-1), axis=-1)), (-1, d)),
    output_shape=(N_CLASSES,),
    arguments={'d': N_CLASSES})(X)
X = keras.layers.Softmax()(X)
M = keras.Model(X_input, [X, X_pose])
M_optimizer = keras.optimizers.SGD(momentum=0.9)
M.compile(M_optimizer, 'categorical_crossentropy', ['acc'])
M.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_3 (InputLayer)            (None, 32, 32, 1)    0                                            
__________________________________________________________________________________________________
model_2 (Model)                 (None, 16, 16, 64)   28068       input_3[0][0]                    
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, 16, 16, 32)   2080        model_2[1][0]                    
__________________________________________________________________________________________________
lambda_1 (Lambda)               (None, 16, 16, 4, 4, 0           conv2d_7[0][0]                   
__________________________________________________________________________________________________
capsule_co

In [7]:
# Both heads (capsule norms and pose probe) use the same one-hot labels.
M.fit(x=X_train, y=[Y_train, Y_train],
      batch_size=8, epochs=10,
      validation_data=(X_test, [Y_test, Y_test]))

Train on 60000 samples, validate on 10000 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.History at 0x7fe44447eba8>