In [None]:
import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data

# Pin TensorFlow's graph-level random seed (reproducibility).
tf.set_random_seed(seed=777)

# MNIST splits with one-hot labels, using "MNIST_data/" as the data directory.
mnist = input_data.read_data_sets(train_dir="MNIST_data/", one_hot=True)

In [None]:
# ---- hyper parameters -------------------------------------------------------
learning_rate = 1e-3        # intended optimizer step size
batch_size = 128            # examples per mini-batch (the graph is built for this size)
epsilon = 1e-9              # small constant added inside sqrt for numerical stability
m_plus, m_minus = 0.9, 0.1  # margins of the capsule-length (margin) loss
lambda_val = 0.5            # weight of the loss term for absent classes
reg_scale = 0.392           # weight of the reconstruction error (0.0005 * 784)

class CapsNetModel(object):
    """CapsNet-style MNIST classifier with a fully connected reconstruction decoder.

    Usage: create the model, call `build_graph(batch_size)` once, then use
    `train`, `predict` and `get_accuracy`. The placeholders are created with a
    fixed leading dimension, so every fed array must hold exactly `batch_size`
    rows.
    """

    def __init__(self):
        """Open the TensorFlow session used by every other method."""
        self.sess = tf.Session()

    def _squash(self, vector, axis=-2):
        """Apply the capsule "squash" non-linearity to each capsule vector.

        Every vector lying along `axis` is rescaled to length
        |s|^2 / (1 + |s|^2) while keeping its direction.

        Args:
            vector: tensor whose `axis` dimension holds the capsule components.
            axis: dimension to normalise over. The default (-2) matches both
                call sites in this class: [batch, 1152, 8, 1] primary capsules
                and [batch, 1, 10, 16, 1] routed capsules.

        Returns:
            A tensor with the same shape as `vector`.
        """
        # Reduce over the capsule dimension only. Reducing over every axis
        # would tie all capsules of all examples to one shared scalar norm.
        squared_norm = tf.reduce_sum(tf.square(vector), axis=axis, keep_dims=True)
        # epsilon keeps the division (and the sqrt gradient) finite when a
        # capsule has zero length.
        scale = squared_norm / (1. + squared_norm) / tf.sqrt(squared_norm + epsilon)
        return scale * vector

    def _routing(self, input, b):
        """Compute the digit capsules from the primary capsules by iterative routing.

        Args:
            input: primary capsules, shape [batch_size, 1152, 1, 8, 1].
            b: initial routing logits; must broadcast against
                [batch_size, 1152, 10, 1, 1].

        Returns:
            Squashed digit capsules, shape [batch_size, 1, 10, 16, 1].
        """
        W = tf.get_variable('Weight', shape=(1, 1152, 10, 8, 16), dtype=tf.float32,
                            initializer=tf.random_normal_initializer(stddev=0.01))
        input = tf.tile(input, [1, 1, 10, 1, 1])
        W = tf.tile(W, [self.batch_size, 1, 1, 1, 1])
        # [8, 16].T x [8, 1] in the last two dims
        # => prediction vectors of shape [batch_size, 1152, 10, 16, 1]
        u_hat = tf.matmul(W, input, transpose_a=True)
        # Gradient-free copy used by the inner routing iterations.
        u_hat_stopped = tf.stop_gradient(u_hat, name='stop_gradient')

        iter_routing = 3
        for r_iter in range(iter_routing):
            with tf.variable_scope('iter_' + str(r_iter)):
                # Coupling coefficients: softmax of the logits over the 10 output capsules.
                c_IJ = tf.nn.softmax(b, dim=2)

                if r_iter == iter_routing - 1:
                    # Last iteration: use `u_hat` so gradients reach W and the
                    # layers below.
                    # Weight u_hat with c_IJ => [batch_size, 1152, 10, 16, 1]
                    s_J = tf.multiply(c_IJ, u_hat)
                    # Sum over the 1152 input capsules => [batch_size, 1, 10, 16, 1]
                    s_J = tf.reduce_sum(s_J, axis=1, keep_dims=True)
                    assert s_J.get_shape() == [self.batch_size, 1, 10, 16, 1]

                    v_J = self._squash(s_J)
                    assert v_J.get_shape() == [self.batch_size, 1, 10, 16, 1]
                else:
                    # Inner iterations: no backpropagation through the routing.
                    s_J = tf.multiply(c_IJ, u_hat_stopped)
                    s_J = tf.reduce_sum(s_J, axis=1, keep_dims=True)
                    v_J = self._squash(s_J)

                    # Agreement between each prediction and the current output:
                    # tile v_J to [batch_size, 1152, 10, 16, 1], then
                    # [16, 1].T x [16, 1] => [1, 1] in the last two dims.
                    v_J_tiled = tf.tile(v_J, [1, 1152, 1, 1, 1])
                    u_produce_v = tf.matmul(u_hat_stopped, v_J_tiled, transpose_a=True)
                    assert u_produce_v.get_shape() == [self.batch_size, 1152, 10, 1, 1]

                    # Per-example logit update (logits are not averaged over the batch).
                    b += u_produce_v

        return v_J

    def build_graph(self, batch_size):
        """Build the whole graph (network, losses, train op) and initialise variables.

        Args:
            batch_size: fixed number of examples per fed batch; baked into the
                placeholder shapes.

        Side effects:
            Sets `X`, `Y`, `labels`, `v_len`, `argmax_index`, `masked`,
            `decoded`, `loss`, `reconstruction_err`, `total_loss`, `optimizer`,
            `train_op` and `accuracy` on the instance and runs
            `tf.global_variables_initializer()` in `self.sess`.
        """
        self.batch_size = batch_size

        # Input placeholders: flattened 28x28 images and one-hot labels.
        self.X = tf.placeholder(tf.float32, [batch_size, 784], name="INPUT_IMAGE")
        X_img = tf.reshape(self.X, [-1, 28, 28, 1])
        self.Y = tf.placeholder(tf.float32, [batch_size, 10])
        self.labels = tf.to_int32(tf.argmax(self.Y, axis=1))

        # Convolutional layer: 28x28x1 -> 20x20x256
        conv1 = tf.layers.conv2d(inputs=X_img, filters=256, kernel_size=[9, 9],
                                 padding="VALID", activation=tf.nn.relu)

        # Primary capsules: 8 parallel convolutions (20x20x256 -> 6x6x32 each),
        # each flattened to 1152 values and used as one capsule component.
        capsule_components = []
        for _ in range(8):
            capsule = tf.layers.conv2d(inputs=conv1, filters=32, kernel_size=[9, 9],
                                       padding="VALID", strides=2, activation=None)
            capsule_components.append(tf.reshape(capsule, (-1, 1152, 1, 1)))
        caps = tf.concat(capsule_components, axis=2)  # [batch, 1152, 8, 1]
        caps = self._squash(caps)

        # Digit capsules. The logits start at zero with shape
        # [batch_size, 1, 10, 1, 1] and broadcast over the 1152 input capsules.
        b = tf.constant(np.zeros([batch_size, 1, 10, 1, 1], dtype=np.float32))
        recaps = tf.reshape(caps, (-1, 1152, 1, 8, 1))

        caps2 = self._routing(recaps, b)
        caps2 = tf.squeeze(caps2, axis=1)  # [batch, 10, 16, 1]

        # Capsule lengths, shape [batch, 10, 1, 1].
        self.v_len = tf.sqrt(tf.reduce_sum(tf.square(caps2), axis=2, keep_dims=True) + epsilon)
        sm = tf.nn.softmax(self.v_len, dim=1)

        # Predicted class per example; named "OUTPUT" so it can be exported.
        argmax_index = tf.to_int32(tf.argmax(sm, axis=1))
        self.argmax_index = tf.reshape(argmax_index, (self.batch_size, ), name="OUTPUT")

        # Keep only the capsule of the *predicted* class for each example.
        # One gather replaces a Python loop of per-example slices.
        example_index = tf.range(self.batch_size, dtype=tf.int32)
        gather_index = tf.stack([example_index, self.argmax_index], axis=1)
        selected = tf.gather_nd(caps2, gather_index)  # [batch, 16, 1]
        self.masked = tf.reshape(selected, shape=(self.batch_size, 1, 16, 1))

        # Reconstruction decoder: 16 -> 512 -> 1024 -> 784 (sigmoid).
        vector_j = tf.reshape(self.masked, shape=(self.batch_size, -1))
        fc1 = tf.contrib.layers.fully_connected(vector_j, num_outputs=512)
        assert fc1.get_shape() == [self.batch_size, 512]
        fc2 = tf.contrib.layers.fully_connected(fc1, num_outputs=1024)
        assert fc2.get_shape() == [self.batch_size, 1024]
        self.decoded = tf.contrib.layers.fully_connected(fc2, num_outputs=784, activation_fn=tf.sigmoid)

        # Margin loss on the capsule lengths.
        max_l = tf.square(tf.maximum(0., m_plus - self.v_len))
        max_r = tf.square(tf.maximum(0., self.v_len - m_minus))

        max_l = tf.reshape(max_l, shape=(self.batch_size, -1))
        max_r = tf.reshape(max_r, shape=(self.batch_size, -1))

        T_c = self.Y
        L_c = T_c * max_l + lambda_val * (1 - T_c) * max_r

        self.loss = tf.reduce_mean(tf.reduce_sum(L_c, axis=1))

        # Reconstruction loss: mean squared error against the input pixels.
        origin = tf.reshape(self.X, shape=(self.batch_size, -1))
        squared = tf.square(self.decoded - origin)
        self.reconstruction_err = tf.reduce_mean(squared)

        # Total loss.
        self.total_loss = self.loss + reg_scale * self.reconstruction_err

        # Pass the configured step size explicitly so the `learning_rate`
        # hyper-parameter actually takes effect.
        self.optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
        self.train_op = self.optimizer.minimize(self.total_loss)

        correct_prediction = tf.equal(tf.to_int32(self.labels), self.argmax_index)
        self.accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

        self.sess.run(tf.global_variables_initializer())

    def predict(self, x_test):
        """Return the predicted class index for each row of `x_test`.

        `x_test` must have shape [batch_size, 784].
        """
        return self.sess.run(self.argmax_index,
                             feed_dict={self.X: x_test})

    def get_accuracy(self, x_test, y_test):
        """Return the fraction of correct predictions on one batch.

        `x_test` must have shape [batch_size, 784] and `y_test` (one-hot)
        shape [batch_size, 10].
        """
        return self.sess.run(self.accuracy,
                             feed_dict={self.X: x_test,
                                        self.Y: y_test})

    def train(self, x_data, y_data):
        """Run one optimisation step on a batch.

        Returns:
            The `sess.run` results for `[total_loss, train_op]`; the first
            element is the batch's total loss.
        """
        return self.sess.run([self.total_loss, self.train_op], feed_dict={
            self.X: x_data, self.Y: y_data})

    def save(self, filename):
        """Write the session's variables to the checkpoint `filename`."""
        tf.train.Saver().save(self.sess, filename)

    def restore(self, filename):
        """Load the session's variables from the checkpoint `filename`."""
        tf.train.Saver().restore(self.sess, filename)

    def export_graph(self, output_node_name_list, filename):
        """Freeze the default graph and write it to `filename` as a serialized GraphDef.

        Args:
            output_node_name_list: names of the output nodes to keep
                (e.g. ['OUTPUT']).
            filename: destination path of the frozen graph.
        """
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            self.sess,  # the session is used to retrieve the weights
            tf.get_default_graph().as_graph_def(),  # the graph_def is used to retrieve the nodes
            output_node_name_list
        )
        # Serialize and dump the frozen graph to the filesystem.
        with tf.gfile.GFile(filename, "wb") as f:
            f.write(output_graph_def.SerializeToString())

In [None]:
# Create the model (this opens its TensorFlow session), then build the graph
# for the configured fixed batch size.
m1 = CapsNetModel()
m1.build_graph(batch_size=batch_size)

In [None]:
training_epochs = 15
log_every = 100  # print progress every N mini-batches instead of on every step

print('Learning Started!')

# Number of full mini-batches per epoch (loop-invariant, so computed once).
total_batch = mnist.train.num_examples // batch_size

# train my model
for epoch in range(training_epochs):
    avg_cost = 0

    for i in range(total_batch):
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
        c, _ = m1.train(batch_xs, batch_ys)
        avg_cost += c / total_batch
        if i % log_every == 0:
            print(i, '/', total_batch)

    print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.9f}'.format(avg_cost))

# Average the accuracy over every full test mini-batch rather than reporting
# a single 128-image batch. The graph only accepts batches of exactly
# `batch_size` rows, so the trailing partial batch is not evaluated.
test_batches = mnist.test.num_examples // batch_size
test_accuracy = 0.
for _ in range(test_batches):
    batch_xs, batch_ys = mnist.test.next_batch(batch_size)
    test_accuracy += m1.get_accuracy(batch_xs, batch_ys) / test_batches
print('Accuracy:', test_accuracy)

print('Learning Finished!')

In [None]:
# Spot-check on one test mini-batch: accuracy, then predicted vs. true class.
batch_xs, batch_ys = mnist.test.next_batch(batch_size)
print('Accuracy:', m1.get_accuracy(batch_xs, batch_ys))

print('predict')
results = m1.predict(batch_xs)
# Labels are one-hot; convert them to class indices so each printed pair
# (prediction, truth) is directly comparable.
true_labels = np.argmax(batch_ys, axis=1)
for i in range(min(30, len(results))):
    print(results[i], true_labels[i])

In [None]:
# Persist the trained variables to a checkpoint in the working directory.
m1.save(filename='./caps_mnist.ckpt')

In [None]:
# Reload the variables from the checkpoint written above.
m1.restore(filename='./caps_mnist.ckpt')

In [None]:
# Freeze the graph up to the 'OUTPUT' node and write it to disk.
m1.export_graph(output_node_name_list=['OUTPUT'], filename='frozen_graph.pb')