In [1]:
import tensorflow as tf
from keras.datasets import mnist
from keras import backend as K
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, Lambda, Layer
from capsulelayers import CapsuleLayer, PrimaryCap, Length, Mask
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
from keras import initializers

import matplotlib.pyplot as plt
import sys
import numpy as np

Using TensorFlow backend.


In [130]:
#
# The length of the output vector of a capsule is to represent the probability that the entity represented by the capsule
# is present in the current unit. A nonlinear squashing function ensures that
# - short vectors get shrunk to almost zero length and
# - long vectors get shrunk to a length slightly below 1
# this function is designed as
# v_j = \frac{||s_j||^2}{1 + ||s_j||^2 } \frac{s_j}{||s_j||}
#
def squash(output_vector, axis=-1):
    norm = tf.reduce_sum(tf.square(output_vector), axis, keep_dims=True)
    return output_vector * norm / ((1 + norm) * tf.sqrt(norm + 1.0e-10))

#
# This layer takes to input vectors:
#   - the first one is the output of the CapsuleLayer, 'n_calss' arrays
#   - the ground truth vector, an array with a length of 'n_class', with one of the elements is '1', the rests are '0'
#
class MaskingLayer(Layer):
    def call(self, inputs, **kwargs):
        input, mask = inputs
        return K.batch_dot(input, mask, 1)

    def compute_output_shape(self, input_shape):
        *_, output_shape = input_shape[0]
        return (None, output_shape)


#
# construct a conv layer, then reshape and apply squash operation
#
def PrimaryCapsule(n_vector, n_channel, n_kernel_size, n_stride, padding='valid'):
    def builder(inputs):
        output = Conv2D(filters=n_vector * n_channel, kernel_size=n_kernel_size, strides=n_stride, padding=padding)(inputs)
        output = Reshape( target_shape=[-1, n_vector], name='primary_capsule_reshape')(output)
        return Lambda(squash, name='primary_capsule_squash')(output)
    return builder

#
# Traditional Neural Network          Capsule
# scalar in scalar out       -->>     vector in vector out/matrix in matrix out
# back propagation update    -->>     routing update
#
class CapsuleLayer(Layer):
    def __init__(self, n_capsule, n_vec, n_routing, **kwargs):
        super(CapsuleLayer, self).__init__(**kwargs)
        self.n_capsule = n_capsule
        self.n_vector = n_vec
        self.n_routing = n_routing
        self.kernel_initializer = initializers.get('he_normal')
        self.bias_initializer = initializers.get('zeros')

    def build(self, input_shape): # input_shape is a 4D tensor
        _, self.input_n_capsule, self.input_n_vector, *_ = input_shape
        self.W = self.add_weight(shape=[self.input_n_capsule, self.n_capsule, self.input_n_vector, self.n_vector], initializer=self.kernel_initializer, name='W')
        self.bias = self.add_weight(shape=[1, self.input_n_capsule, self.n_capsule, 1, 1], initializer=self.bias_initializer, name='bias', trainable=False)
        self.built = True

    def call(self, inputs, training=None):
        input_expand = tf.expand_dims(tf.expand_dims(inputs, 2), 2)
        input_tiled = tf.tile(input_expand, [1, 1, self.n_capsule, 1, 1])
        input_hat = tf.scan(lambda ac, x: K.batch_dot(x, self.W, [3, 2]), elems=input_tiled, initializer=K.zeros( [self.input_n_capsule, self.n_capsule, 1, self.n_vector]))
        for i in range(self.n_routing): # routing
            c = tf.nn.softmax(self.bias, dim=2)
            outputs = squash(tf.reduce_sum( c * input_hat, axis=1, keep_dims=True))
            if i != self.n_routing - 1:
                self.bias += tf.reduce_sum(input_hat * outputs, axis=-1, keep_dims=True)
        return tf.reshape(outputs, [-1, self.n_capsule, self.n_vector])

    def compute_output_shape(self, input_shape):
        # output current layer capsules
        return (None, self.n_capsule, self.n_vector)

#
# This layer takes 'n_class' arrays as input, outputs an array of size 'n_class',
# each eleemnt in the output array represent the possibility,
# i.e., the last layer in Figure 2.
#
class LengthLayer(Layer):
    def call(self, inputs, **kwargs):
        return tf.sqrt(tf.reduce_sum(tf.square(inputs), axis=-1, keep_dims=False))

    def compute_output_shape(self, input_shape):
        *output_shape, _ = input_shape
        return tuple(output_shape)
    
    
#
# margin loss is employed to measure the accuracy of the capsule net,
# in the code below, mean absolute error is used to measure the accuracy of the reconstructed image
#
def margin_loss(y_ground_truth, y_prediction):
    _m_plus = 0.9
    _m_minus = 0.1
    _lambda = 0.5
    L = y_ground_truth * tf.square(tf.maximum(0., _m_plus - y_prediction)) + _lambda * ( 1 - y_ground_truth) * tf.square(tf.maximum(0., y_prediction - _m_minus))
    return tf.reduce_mean(tf.reduce_sum(L, axis=1))

In [2]:
img_rows = 28 
img_cols = 28
channels = 1

In [3]:
optimizer = Adam(0.0002, 0.5)

In [4]:
def build_generator():

    noise_shape = (100,)
    x_noise = Input(shape=noise_shape)

    x = Dense(128 * 7 * 7, activation="relu")(x_noise)
    x = Reshape((7, 7, 128))(x)
    x = BatchNormalization(momentum=0.8)(x)
    x = UpSampling2D()(x)
    x = Conv2D(128, kernel_size=3, padding="same")(x)
    x = Activation("relu")(x)
    x = BatchNormalization(momentum=0.8)(x)
    x = UpSampling2D()(x)
    x = Conv2D(64, kernel_size=3, padding="same")(x)
    x = Activation("relu")(x)
    x = BatchNormalization(momentum=0.8)(x)
    x = Conv2D(1, kernel_size=3, padding="same")(x)
    gen_out = Activation("tanh")(x)

    return Model(x_noise, gen_out)

In [5]:
# Build and compile the generator
generator = build_generator()
generator.summary()
generator.compile(loss='binary_crossentropy', optimizer=optimizer)

# The generator takes noise as input and generated imgs
z = Input(shape=(100,))
img = generator(z)

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         (None, 100)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 6272)              633472    
_________________________________________________________________
reshape_1 (Reshape)          (None, 7, 7, 128)         0         
_________________________________________________________________
batch_normalization_1 (Batch (None, 7, 7, 128)         512       
_________________________________________________________________
up_sampling2d_1 (UpSampling2 (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 14, 14, 128)       147584    
_________________________________________________________________
activation_1 (Activation)    (None, 14, 14, 128)       0         
__________

In [121]:
def build_discriminator():

        img_shape = (img_rows, img_cols, channels)
        x_img = Input(shape=img_shape)

        # first typical convlayer outputs 20x20x256 matrix
        x = Conv2D(filters=32, kernel_size=9, strides=1, padding='valid', activation='relu', name='conv1')(x_img)
        
        #
        # capsule architecture starts from here. primarycaps coming first
        #
        # filters 256 (n_vectors=8 * channels=32)
        x = Conv2D(filters=256, kernel_size=9, strides=2, padding='valid', name='conv2_primarycaps')(x)
        # reshape into 8D vector for all 32 maps combined
        # (primary capsule has collections of activations which denote the orientation of digit
        # while intensity of the vector which denotes the presence of the digit)
        x = Reshape(target_shape=[-1, 8], name='conv2_reshape')(x)
        # the purpose is to output a number between 0 and 1 for each capsule where the length of the input decides the amount
        x = Lambda(squash, name='squash_primarycaps')(x)
        
        #
        # digitcaps are here. in this approach i'm writing a simplified version of digitcaps i.e. without tiling the input
        # but using ordinary keras dense layers as weight holders
        #
        # a capsule (i) in a lower-level layer needs to decide how to send its output vector to higher-level capsules (j)
        # it makes this decision by changing scalar weight (c_ij) that will multiply its output vector and then be treated as input to a higher-level capsule
        #
        # uhat = prediction vector, w = weight matrix but will act as a dense layer # ANY CORRECTIONS ARE APPRECIATED HERE, PLEASE SUBMIT PULL REQUESTS
        # uhat (prediction vector) = u (output from a previous layer) * w
        x = Flatten()(x)
        # neurons 160 (num_capsules=10 * num_vectors=16)
        x = Dense(160, kernel_initializer='he_normal', name='weights_digitcaps')(x)
        # coupling coeff = a softmax over uhat * c (coupling coefficient) | "the coupling coefficients between capsule (i) and all the capsules in the layer above sum to 1"
        # we treat the coupling coefficiant as a softmax over bias weights from the previous dense layer
        x = Activation('softmax', name='softmax_digitcap1')(x) # softmax will make sure that each weight c_ij is a non-negative number and their sum equals to one
        x = Dense(160, activation=squash)(x) #apply a final squashing function
        
        #
        # we will repeat the routing part 2 more times (num_routing=3) to roll out the loop
        #
        x = Activation('softmax', name='softmax_digitcap2')(x) # softmax will make sure that each weight c_ij is a non-negative number and their sum equals to one
        x = Dense(160, activation=squash)(x) #apply a final squashing function
        x = Activation('softmax', name='softmax_digitcap')(x) # softmax will make sure that each weight c_ij is a non-negative number and their sum equals to one
        x = Dense(160, activation=squash)(x) #apply a final squashing function
    
        pred = Dense(1, activation='sigmoid')(x)

        return Model(x_img, pred)

In [8]:
# discriminator structure
def build_discriminator():

        img_shape = (img_rows, img_cols, channels)
        
        img = Input(shape=img_shape)
        
        # Layer 1: Just a conventional Conv2D layer
        conv1 = Conv2D(filters=32, kernel_size=3, strides=2, padding='same', activation='relu', name='conv1')(img)
        # Layer 2: Conv2D layer with `squash` activation, then reshape to [None, num_capsule, dim_capsule]
        primarycaps = PrimaryCap(conv1, dim_vector=8, n_channels=32, kernel_size=9, strides=2, padding='valid')
        x = Dropout(0.5)(primarycaps)
        x = BatchNormalization(momentum=0.8)(x)
        #digitcaps = CapsuleLayer(num_capsule=10, dim_vector=16, num_routing=3, name='digitcaps')(primarycaps)
        x = Flatten()(x)
        x = Dense(500, activation='relu')(x)
        x = Dense(100, activation='relu')(x)
        x = Dense(10, activation='relu')(x)
        validity = Dense(1, activation='sigmoid')(x)


        return Model(img, validity)

In [9]:
# Build and compile the discriminator
discriminator = build_discriminator()
discriminator.summary()
discriminator.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_4 (InputLayer)         (None, 28, 28, 1)         0         
_________________________________________________________________
conv1 (Conv2D)               (None, 14, 14, 32)        320       
_________________________________________________________________
primarycap_conv2d (Conv2D)   (None, 3, 3, 256)         663808    
_________________________________________________________________
primarycap_reshape (Reshape) (None, 288, 8)            0         
_________________________________________________________________
primarycap_squash (Lambda)   (None, 288, 8)            0         
_________________________________________________________________
dropout_1 (Dropout)          (None, 288, 8)            0         
_________________________________________________________________
batch_normalization_4 (Batch (None, 288, 8)            32        
__________

In [10]:
# For the combined model we will only train the generator
discriminator.trainable = True

# The valid takes generated images as input and determines validity
valid = discriminator(img)

In [11]:
# The combined model  (stacked generator and discriminator) takes
# noise as input => generates images => determines validity 
combined = Model(z, valid)
combined.summary()
combined.compile(loss='binary_crossentropy', optimizer=optimizer)

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         (None, 100)               0         
_________________________________________________________________
model_1 (Model)              (None, 28, 28, 1)         856705    
_________________________________________________________________
model_2 (Model)              (None, 1)                 1867781   
Total params: 2,724,486
Trainable params: 2,723,830
Non-trainable params: 656
_________________________________________________________________


In [12]:
def save_imgs(epoch):
        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, 100))
        gen_imgs = generator.predict(noise)

        # Rescale images 0 - 1
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        #fig.suptitle("DCGAN: Generated digits", fontsize=12)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
                axs[i,j].axis('off')
                cnt += 1
        fig.savefig("images/mnist_%d.png" % epoch)
        plt.close()

In [13]:
# Losses for further plotting
D_L_REAL = []
D_L_FAKE = []
D_L = []
D_ACC = []
G_L = []

In [14]:
def train(epochs, batch_size=128, save_interval=50):

        # Load the dataset
        (X_train, _), (_, _) = mnist.load_data()

        # Rescale -1 to 1
        X_train = (X_train.astype(np.float32) - 127.5) / 127.5
        X_train = np.expand_dims(X_train, axis=3)

        half_batch = int(batch_size / 2)

        for epoch in range(epochs):

            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Select a random half batch of images
            idx = np.random.randint(0, X_train.shape[0], half_batch)
            imgs = X_train[idx]

            # Sample noise and generate a half batch of new images
            noise = np.random.normal(0, 1, (half_batch, 100))
            gen_imgs = generator.predict(noise)

            # Train the discriminator (real classified as ones and generated as zeros)
            d_loss_real = discriminator.train_on_batch(imgs, np.ones((half_batch, 1)))
            d_loss_fake = discriminator.train_on_batch(gen_imgs, np.zeros((half_batch, 1)))
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------
            #  Train Generator
            # ---------------------

            noise = np.random.normal(0, 1, (batch_size, 100))

            # Train the generator (wants discriminator to mistake images as real)
            g_loss = combined.train_on_batch(noise, np.ones((batch_size, 1)))

            # Plot the progress
            print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))
            D_L_REAL.append(d_loss_real)
            D_L_FAKE.append(d_loss_fake)
            D_L.append(d_loss)
            D_ACC.append(d_loss[1])
            G_L.append(g_loss)

            # If at save interval => save generated image samples
            if epoch % save_interval == 0:
                save_imgs(epoch)

In [15]:
train(epochs=30000, batch_size=32, save_interval=50)

0 [D loss: 0.741382, acc.: 28.12%] [G loss: 0.323144]
1 [D loss: 0.378989, acc.: 50.00%] [G loss: 0.040662]
2 [D loss: 0.455916, acc.: 50.00%] [G loss: 0.013651]
3 [D loss: 0.540281, acc.: 50.00%] [G loss: 0.020041]
4 [D loss: 0.585189, acc.: 50.00%] [G loss: 0.072252]
5 [D loss: 0.662552, acc.: 50.00%] [G loss: 0.144692]
6 [D loss: 0.792815, acc.: 50.00%] [G loss: 0.202287]
7 [D loss: 0.896737, acc.: 50.00%] [G loss: 0.233637]
8 [D loss: 0.965858, acc.: 50.00%] [G loss: 0.235365]
9 [D loss: 1.109696, acc.: 50.00%] [G loss: 0.188132]
10 [D loss: 0.951979, acc.: 50.00%] [G loss: 0.263158]
11 [D loss: 0.841134, acc.: 50.00%] [G loss: 0.188074]
12 [D loss: 0.599117, acc.: 59.38%] [G loss: 0.084556]
13 [D loss: 0.709516, acc.: 59.38%] [G loss: 0.027972]
14 [D loss: 0.673140, acc.: 65.62%] [G loss: 0.019935]
15 [D loss: 0.985649, acc.: 53.12%] [G loss: 0.034106]
16 [D loss: 0.815501, acc.: 46.88%] [G loss: 0.162321]
17 [D loss: 1.110515, acc.: 46.88%] [G loss: 0.142084]
18 [D loss: 1.526166

KeyboardInterrupt: 

In [None]:
plt.plot(D_L)
plt.xlabel('Epochs')
plt.ylabel('Loss (blue), Accuracy(orange)')
plt.show()