In [1]:
import numpy as np
import matplotlib.pyplot as plt
import keras
import keras.backend as K

Using TensorFlow backend.


In [2]:
# Load MNIST: 28x28 uint8 digit images with integer class labels.
(X_train, Y_train), (X_test, Y_test) = keras.datasets.mnist.load_data()

# Surround every image with a 2-pixel zero border (28x28 -> 32x32) and
# promote to float64; the batch axis is left unpadded.
_border = ((0, 0), (2, 2), (2, 2))
X_train_new = np.pad(X_train, _border, mode='constant').astype(np.float64)
X_test_new = np.pad(X_test, _border, mode='constant').astype(np.float64)

# Append a trailing channel axis and rescale pixel values from [0, 255] to [0, 1].
X_train = X_train_new[..., np.newaxis] / 255
X_test = X_test_new[..., np.newaxis] / 255

# One-hot encode the ten digit classes.
Y_train, Y_test = (
    keras.utils.to_categorical(labels, num_classes=10)
    for labels in (Y_train, Y_test))

In [3]:
class CapsuleConv2D(keras.layers.Layer):
    """Convolutional matrix-capsule layer with an EM-style routing loop.

    The layer is called on a list ``[poses, activations]``:

    * ``poses`` is indexed as a 6-D tensor
      ``(batch, height, width, capsule types, pose rows, pose cols)``;
    * ``activations`` is indexed as a 4-D tensor
      ``(batch, height, width, capsule types)``.

    It returns ``[poses_new, activations_new]`` shaped
    ``(batch, height', width', channels, mat_size[0], pose cols)`` and
    ``(batch, height', width', channels)``, where the spatial sizes follow a
    ``'same'``-padded strided convolution.

    # Arguments
        channels: number of output capsule types.
        kernel_size: ``(height, width)`` of the receptive field.
        strides: ``(vertical, horizontal)`` stride of the receptive field.
        em_steps: number of routing iterations (must be >= 1).
        temp_start, temp_end: a temperature is interpolated linearly from
            ``temp_start`` (first iteration) to ``temp_end`` (last
            iteration); its reciprocal scales the activation logit.
        mat_size: shape of every learned transformation matrix.
            ``mat_size[1]`` has to equal the number of pose rows so that
            ``kernel @ pose`` is defined.
        coord: if True, normalized x/y coordinates in [-1, 1] are added to
            the last column of the first two pose rows before voting.

    # Raises
        ValueError: if ``em_steps`` is smaller than 1.
    """

    def __init__(self, channels, kernel_size, strides=(1,1),
                 em_steps=3, temp_start=75., temp_end=25.,
                 mat_size=(4,4), coord=False, **kwargs):
        # With zero iterations `call` would never define its outputs.
        if em_steps < 1:
            raise ValueError(
                'em_steps must be at least 1, got %r' % (em_steps,))
        super().__init__(**kwargs)
        self.channels = channels
        self.kernel_size = tuple(kernel_size)
        self.strides = tuple(strides)
        self.em_steps = em_steps
        self.temp_start = temp_start
        self.temp_end = temp_end
        self.mat_size = tuple(mat_size)
        self.coord = coord

    def build(self, input_shape):
        """Create the vote transforms and the two per-channel cost offsets.

        `input_shape` is the list ``[poses_shape, activations_shape]``.
        """
        # One transformation matrix per (receptive-field position x input
        # capsule type, output capsule type) pair.
        self.kernel = self.add_weight(
            shape=(self.kernel_size[0]*self.kernel_size[1]*input_shape[0][3],
                   self.channels, *self.mat_size),
            initializer=keras.initializers.TruncatedNormal(
                mean=0., stddev=1.), name='kernel')
        self.beta_a = self.add_weight(
            shape=(self.channels,), constraint=keras.constraints.non_neg(),
            initializer='zeros', name='beta_a')
        self.beta_u = self.add_weight(
            shape=(self.channels,), constraint=keras.constraints.non_neg(),
            initializer='zeros', name='beta_u')
        super().build(input_shape)

    def call(self, inputs):
        """Compute votes from the input poses and route them for `em_steps` steps.

        Returns ``[poses_new, activations_new]``.
        """
        input_shape = [i_.shape for i_ in inputs]

        poses, activations = inputs
        input_channel_size = poses.shape[3]
        kernel_h, kernel_w = self.kernel_size
        patch_size = kernel_h*kernel_w

        if self.coord:
            # Constant grid of normalized coordinates, broadcast over the
            # batch and capsule-type axes.
            coord_y = np.linspace(-1, 1, num=int(poses.shape[1]))
            coord_y = np.repeat(coord_y[:,np.newaxis],int(poses.shape[2]),axis=1)
            coord_x = np.linspace(-1, 1, num=int(poses.shape[2]))
            coord_x = np.repeat(coord_x[np.newaxis,:],int(poses.shape[1]),axis=0)
            coord = np.zeros((1,int(poses.shape[1]),int(poses.shape[2]),
                1,int(poses.shape[4]),int(poses.shape[5])))
            coord[:,:,:,0,0,-1] = coord_x
            coord[:,:,:,0,1,-1] = coord_y
            poses = poses + K.constant(coord)

        # Fold capsule type and pose matrix into one channel axis so that the
        # patch extraction below can treat the poses like an ordinary image.
        poses = K.reshape(poses, (
            -1,*poses.shape[1:3],poses.shape[3]*poses.shape[4]*poses.shape[5]))
        def conv_expand(x):
            """Gather, for every output position, the values of its receptive field.

            Uses a depthwise convolution with a one-hot kernel: output slot
            ``i*kernel_w + j`` copies the input at kernel offset ``(i, j)``.
            Returns a tensor shaped (batch, h', w', patch_size, channels).
            """
            conv_data_size = x.shape[3]
            conv_kernel = np.zeros((kernel_h, kernel_w, conv_data_size,
                                    patch_size))
            for i in range(kernel_h):
                for j in range(kernel_w):
                    # Row-major slot index: the row stride is the kernel
                    # WIDTH, otherwise non-square kernels collide.
                    conv_kernel[i,j,:,i*kernel_w+j] = 1
            x = K.depthwise_conv2d(x, K.constant(conv_kernel),
                                       strides=self.strides, padding='same')
            x = K.reshape(x, (-1,*x.shape[1:3],conv_data_size,
                                     patch_size))
            x = K.permute_dimensions(x, (0,1,2,4,3))
            return x
        poses = conv_expand(poses)
        activations = conv_expand(activations)

        # Votes: every input capsule in the receptive field is multiplied by
        # one learned matrix per output capsule type.
        poses = K.reshape(poses, (-1,*poses.shape[1:3],
            patch_size*input_channel_size,
            1,input_shape[0][-2],input_shape[0][-1]))
        poses = K.tile(poses, (1,1,1,1,self.channels,1,1))
        kernel = K.reshape(self.kernel, (1,1,1,*self.kernel.shape))
        kernel = K.tile(kernel, K.concatenate([
            K.shape(poses)[:3],K.constant(np.ones((4,)),dtype='int32')]))
        votes = kernel @ poses
        # Flatten each vote matrix into a vector.
        votes = K.reshape(votes, (
            -1,*votes.shape[1:5],self.mat_size[0]*input_shape[0][-1]))

        activations = K.reshape(activations, (-1,*activations.shape[1:3],
            patch_size*input_channel_size,1))
        # Routing weights start uniform over the output capsule types.
        r = K.constant(1/self.channels, shape=(1,*votes.shape[1:5]))
        r = K.tile(r, K.concatenate([
            K.shape(votes)[:1],K.constant(np.ones((4,)),dtype='int32')]))
        for t in range(self.em_steps):
            # Linear schedule from temp_start to temp_end, then inverted.
            inv_temp = self.temp_start*(1-(t/max(1,self.em_steps-1)))
            inv_temp = inv_temp + self.temp_end*(t/max(1,self.em_steps-1))
            inv_temp = 1/inv_temp
            # Weight the routing coefficients by the input activations, then
            # fit a per-output-capsule mean and variance over the input axis.
            r = r * activations
            r_expanded = K.expand_dims(r)
            # NOTE(review): r_sum_i is zero when every weighted coefficient in
            # a receptive field is zero, which makes the divisions below 0/0 —
            # confirm the incoming activations cannot all vanish.
            r_sum_i = K.expand_dims(K.sum(r, axis=3))
            mu = K.sum(r_expanded*votes,axis=3)/r_sum_i
            mu_diff_square = K.square(votes-K.expand_dims(mu,axis=3))
            mu_diff_square = mu_diff_square + K.epsilon()
            sigma_square = K.sum(r_expanded*mu_diff_square,axis=3)/r_sum_i
            sigma = K.sqrt(sigma_square)
            cost = K.reshape(self.beta_u,(1,1,1,self.channels,1))
            cost = cost + K.log(sigma)
            cost = cost * r_sum_i
            cost_all_h = K.sum(cost,axis=-1)
            # for numerical stability, change sigmoid to hard sigmoid
            a = K.hard_sigmoid(inv_temp*(self.beta_a-cost_all_h))
            # The last iteration only needs mu and a; skip the update of r.
            if t+1 == self.em_steps: break
            # for numerical stability, change p from the paper to log p
            log_p = mu_diff_square/K.expand_dims(2*sigma_square,axis=3)
            log_p = -K.sum(log_p,axis=-1)
            log_p = log_p - K.expand_dims(
                K.sum((np.log(2*np.pi)/2)+K.log(sigma),axis=-1),axis=3)
            # for numerical stability, compute r from log p instead of p
            r = K.expand_dims(K.log(a+K.epsilon()),axis=3) + log_p
            r = K.softmax(r,axis=-1)

        # Unflatten the routed means back into pose matrices.
        poses_new = K.reshape(mu,(
            -1,*mu.shape[1:4],self.mat_size[0],input_shape[0][-1]))
        activations_new = a
        return [poses_new, activations_new]

    @staticmethod
    def _same_padding_size(size, stride):
        """Output length of a 'same'-padded strided convolution: ceil(size/stride)."""
        if size is None:
            return None
        return -(-size // stride)

    def compute_output_shape(self, input_shape):
        """Return ``[poses_shape, activations_shape]`` for the given input shapes."""
        poses_shape, activations_shape = input_shape
        return [(poses_shape[0],
                self._same_padding_size(poses_shape[1], self.strides[0]),
                self._same_padding_size(poses_shape[2], self.strides[1]),
                self.channels,
                self.mat_size[0],poses_shape[-1]),(
                activations_shape[0],
                self._same_padding_size(activations_shape[1], self.strides[0]),
                self._same_padding_size(activations_shape[2], self.strides[1]),
                self.channels)]

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config()
        config.update({
            'channels': self.channels,
            'kernel_size': self.kernel_size,
            'strides': self.strides,
            'em_steps': self.em_steps,
            'temp_start': self.temp_start,
            'temp_end': self.temp_end,
            'mat_size': self.mat_size,
            'coord': self.coord,
        })
        return config

In [4]:
# Options shared by every ReLU convolution in this cell.
_conv_opts = dict(padding='same', activation='relu',
                  kernel_initializer='he_normal')

# --- Convolutional feature extractor ---------------------------------------
X_input = keras.layers.Input(X_train.shape[1:])
X = keras.layers.BatchNormalization()(X_input)
for _n_filters, _stride in ((32, 2), (32, 1), (64, 1)):
    X = keras.layers.Conv2D(_n_filters, (3,3), strides=(_stride,_stride),
                            **_conv_opts)(X)

# C and CF are built from the same tensors and therefore share their layers;
# CF is flagged non-trainable before it is compiled.
C = keras.Model(X_input, X)
C.compile(optimizer='nadam', loss='mse', metrics=['acc'])
CF = keras.Model(X_input, X)
CF.trainable = False
CF.compile(optimizer='nadam', loss='mse', metrics=['acc'])

# --- Plain CNN classifier stacked on the trainable extractor C ---------------
X_input = keras.layers.Input(X_train.shape[1:])
X = C(X_input)
for _layer_idx in range(3):
    X = keras.layers.Conv2D(4, (3,3), strides=(2,2), **_conv_opts)(X)
X = keras.layers.Flatten()(X)
X = keras.layers.Dense(Y_train.shape[-1], activation='softmax')(X)
CM = keras.Model(X_input, X)
CM.compile(optimizer='nadam', loss='categorical_crossentropy',
           metrics=['acc'])

In [5]:
# Train the plain CNN classifier; the test split doubles as validation data.
CM.fit(x=X_train, y=Y_train,
       batch_size=32, epochs=5,
       validation_data=(X_test, Y_test))

Train on 60000 samples, validate on 10000 samples
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.callbacks.History at 0x7fea1b9da518>

In [6]:
# --- Capsule classifier on top of the frozen feature extractor CF -----------
X = X_input = keras.layers.Input(X_train.shape[1:])
X = CF(X)

# Primary capsules: a 1x1 convolution yields 4 capsule types per position,
# each with a 4x2 pose (reshaped below) and one sigmoid activation.
X_pose = keras.layers.Conv2D(4*4*2, (1,1), padding='same', activation='tanh')(X)
X_pose = keras.layers.Lambda(
    lambda x: K.reshape(x, (-1,*x.shape[1:3],4,4,2)),
    output_shape=(int(X_pose.shape[1]),int(X_pose.shape[2]),4,4,2))(X_pose)
X = keras.layers.Conv2D(4, (1,1), padding='same', activation='sigmoid')(X)

# Four stride-2 capsule layers; the last one has one capsule type per class.
for _capsule_channels in (4, 4, 4, Y_train.shape[-1]):
    X_pose,X = CapsuleConv2D(
        _capsule_channels, (3,3), strides=(2,2))([X_pose,X])

# Auxiliary head: poses scaled by their activations, flattened, and detached
# with stop_gradient so that its loss does not reach the capsule layers.
X_pose = keras.layers.Lambda(
    lambda x,d: K.stop_gradient(K.reshape(
        x[0]*K.reshape(x[1],(-1,1,1,d,1,1)),
        (-1,1*1*d*4*2))),
    output_shape=(1*1*Y_train.shape[-1]*4*2,),
    arguments={'d':Y_train.shape[-1]})([X_pose,X])
X_pose = keras.layers.Dense(64, activation='tanh')(X_pose)
X = keras.layers.Flatten()(X)
# (A debugging Lambda that wrapped X in K.print_tensor used to sit here; it
# printed the activation tensor on every batch and has been removed.)
X_pose = keras.layers.Concatenate()([X_pose,keras.layers.Lambda(
    lambda x: K.stop_gradient(x))(X)])
X_pose = keras.layers.Dense(64, activation='tanh')(X_pose)
X_pose = keras.layers.Dense(Y_train.shape[-1], activation='softmax')(X_pose)

# Main head: softmax over the class-capsule activations.
X = keras.layers.Softmax()(X)

# Two outputs, both trained with categorical cross-entropy.
M = keras.Model(X_input, [X,X_pose])
M_optimizer = keras.optimizers.SGD(momentum=0.9)
M.compile(M_optimizer, 'categorical_crossentropy', ['acc'])
M.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_3 (InputLayer)            (None, 32, 32, 1)    0                                            
__________________________________________________________________________________________________
model_2 (Model)                 (None, 16, 16, 64)   28068       input_3[0][0]                    
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, 16, 16, 32)   2080        model_2[1][0]                    
__________________________________________________________________________________________________
lambda_1 (Lambda)               (None, 16, 16, 4, 4, 0           conv2d_7[0][0]                   
__________________________________________________________________________________________________
conv2d_8 (

In [7]:
# Train the capsule model; both outputs are fitted against the same labels.
M.fit(x=X_train, y=[Y_train, Y_train],
      batch_size=8, epochs=10,
      validation_data=(X_test, [Y_test, Y_test]))

Train on 60000 samples, validate on 10000 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.History at 0x7fe9ba2d75f8>