In [None]:
import sys
# NOTE(review): machine-specific absolute path. No import visible in this notebook
# resolves through it (all helpers are defined inline) -- confirm before removing.
sys.path.insert(0, "/home/forest/Matrix-Capsules-EM-Tensorflow/")

In [None]:
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.contrib.layers.python.layers import initializers

In [None]:
# Start from an empty default graph and fix both the NumPy seed and the
# TensorFlow graph-level seed before any op or variable is created.
tf.reset_default_graph()
np.random.seed(42)
tf.set_random_seed(42)

In [None]:
from tensorflow.examples.tutorials.mnist import input_data

# Read the MNIST splits from /tmp/data/ (presumably downloaded there on first use -- confirm).
# Later cells use mnist.train and mnist.validation and feed the labels as integer class ids.
mnist = input_data.read_data_sets("/tmp/data/")

In [None]:
# Show the first few training digits side by side as a quick sanity check of the data.
n_samples = 5

fig = plt.figure(figsize=(n_samples * 2, 3))
for index, flat_pixels in enumerate(mnist.train.images[:n_samples]):
    # Each row of mnist.train.images is a flattened 28x28 image.
    sample_image = flat_pixels.reshape(28, 28)
    panel = fig.add_subplot(1, n_samples, index + 1)
    panel.imshow(sample_image, cmap="binary")
    panel.axis("off")

plt.show()

In [None]:
def build_arch(input, coord_add, is_train: bool, num_classes: int):
    """Build the capsule-network graph: ReLU conv -> primary caps -> two conv-caps layers -> class caps.

    Args:
        input: image batch tensor. Static dimension 0 is used as the batch size and
            static dimension 1 as both height and width (square images are assumed);
            presumably laid out as [batch, size, size, channels] -- confirm with caller.
        coord_add: NumPy array of coordinate offsets holding two values per position of
            the final capsule grid; it is reshaped to [data_size * data_size, 1, 1, 2].
        is_train: forwarded as `trainable` to every `slim.conv2d` layer created here.
        num_classes: number of class capsules.

    Returns:
        A tuple `(output, pose_out)`:
            output: class-capsule activations averaged over the final spatial grid,
                shape [batch_size, num_classes].
            pose_out: class-capsule means (coordinate offsets followed by the pose)
                averaged over the grid, shape [batch_size, num_classes, pose_dim + 2].
    """
    # Prefer the static batch dimension of the input; 50 is kept only as the legacy
    # fallback for an unknown batch dimension.
    static_batch = input.get_shape().as_list()[0]
    batch_size = 50 if static_batch is None else static_batch
    data_size = int(input.get_shape()[1])

    pose_dim = 9  # values per pose; mat_transform hard-codes the matching 3x3 layout
    mat_dim = 9   # number of capsule types produced by each conv-caps layer

    tf.logging.info('input shape: {}'.format(input.get_shape()))

    with slim.arg_scope([slim.conv2d], trainable=is_train):
        with tf.variable_scope('relu_conv1') as scope:
            output = slim.conv2d(input, num_outputs=32, kernel_size=[9, 9], stride=2,
                                 padding='VALID', scope=scope, activation_fn=tf.nn.relu)
            # Read the spatial size from the static shape instead of deriving it by hand,
            # which only matched for even input sizes.
            data_size = int(output.get_shape()[1])

            assert output.get_shape() == [batch_size, data_size, data_size, 32]
            tf.logging.info('conv1 output shape: {}'.format(output.get_shape()))

        with tf.variable_scope('primary_caps') as scope:
            # 8 primary capsule types: one conv produces their poses, another their activations.
            pose = slim.conv2d(output, num_outputs=8 * pose_dim,
                               kernel_size=[4, 4], stride=1, padding='SAME',
                               scope=scope, activation_fn=None)
            activation = slim.conv2d(output, num_outputs=8, kernel_size=[2, 2], stride=1,
                                     padding='SAME', scope='primary_caps/activation',
                                     activation_fn=tf.nn.sigmoid)
            pose = tf.reshape(pose, shape=[batch_size, data_size, data_size, 8, pose_dim])
            activation = tf.reshape(
                activation, shape=[batch_size, data_size, data_size, 8, 1])
            # Each capsule is stored as pose_dim pose values followed by one activation value.
            output = tf.concat([pose, activation], axis=4)
            output = tf.reshape(output, shape=[batch_size, data_size, data_size, -1])
            assert output.get_shape() == [batch_size, data_size, data_size, 8 * (pose_dim + 1)]
            tf.logging.info('primary capsule output shape: {}'.format(output.get_shape()))

        with tf.variable_scope('conv_caps1') as scope:
            # 3x3 patches with stride 2; every patch position becomes one routing problem.
            output = kernel_tile(output, 3, 2)
            data_size = int(output.get_shape()[1])
            output = tf.reshape(output, shape=[batch_size * data_size * data_size,
                                               3 * 3 * 8, (pose_dim + 1)])
            activation = tf.reshape(output[:, :, pose_dim], shape=[
                                    batch_size * data_size * data_size, 3 * 3 * 8, 1])

            with tf.variable_scope('v') as scope:
                votes = mat_transform(output[:, :, :pose_dim], mat_dim, tag=True)
                tf.logging.info('conv cap 1 votes shape: {}'.format(votes.get_shape()))

            with tf.variable_scope('routing') as scope:
                # The third argument is the number of output capsules, i.e. mat_dim
                # (the original passed pose_dim, which merely has the same value).
                miu, activation, _ = em_routing(votes, activation, mat_dim)
                tf.logging.info('conv cap 1 miu shape: {}'.format(miu.get_shape()))
                tf.logging.info('conv cap 1 activation before reshape: {}'.format(
                    activation.get_shape()))

            pose = tf.reshape(miu, shape=[batch_size, data_size, data_size, mat_dim, pose_dim])
            tf.logging.info('conv cap 1 pose shape: {}'.format(pose.get_shape()))
            activation = tf.reshape(
                activation, shape=[batch_size, data_size, data_size, mat_dim, 1])
            tf.logging.info('conv cap 1 activation after reshape: {}'.format(
                activation.get_shape()))
            output = tf.reshape(tf.concat([pose, activation], axis=4), [
                                batch_size, data_size, data_size, -1])
            tf.logging.info('conv cap 1 output shape: {}'.format(output.get_shape()))

        with tf.variable_scope('conv_caps2') as scope:
            # Same as conv_caps1 but with stride 1 and mat_dim input capsule types.
            output = kernel_tile(output, 3, 1)
            data_size = int(output.get_shape()[1])
            output = tf.reshape(output, shape=[batch_size * data_size * data_size,
                                               3 * 3 * mat_dim, (pose_dim + 1)])
            activation = tf.reshape(output[:, :, pose_dim], shape=[
                                    batch_size * data_size * data_size, 3 * 3 * mat_dim, 1])

            with tf.variable_scope('v') as scope:
                votes = mat_transform(output[:, :, :pose_dim], mat_dim)
                tf.logging.info('conv cap 2 votes shape: {}'.format(votes.get_shape()))

            with tf.variable_scope('routing') as scope:
                miu, activation, _ = em_routing(votes, activation, mat_dim)

            pose = tf.reshape(miu, shape=[batch_size * data_size * data_size, mat_dim, pose_dim])
            # Log the pose tensor itself (the original logged the votes shape under this message).
            tf.logging.info('conv cap 2 pose shape: {}'.format(pose.get_shape()))
            activation = tf.reshape(
                activation, shape=[batch_size * data_size * data_size, mat_dim, 1])
            tf.logging.info('conv cap 2 activation shape: {}'.format(activation.get_shape()))

        # It is not clear from the paper whether ConvCaps2 is fully connected to the class
        # capsules, or connected through a 1x1 convolution followed by global average pooling.
        # From the description in Figure 1 of the paper and the number of parameters (310k in
        # the paper and 316,853 in fact), a conv capsule layer plus global average pooling is
        # assumed here.

        with tf.variable_scope('class_caps') as scope:
            with tf.variable_scope('v') as scope:
                votes = mat_transform(pose, num_classes)

                assert votes.get_shape() == [batch_size * data_size *
                                             data_size, mat_dim, num_classes, pose_dim]
                tf.logging.info('class cap votes original shape: {}'.format(votes.get_shape()))

                # Prepend the two coordinate offsets of each grid position to every vote,
                # repeating the grid once per batch element to match the flattened layout.
                coord_add = np.reshape(coord_add, newshape=[data_size * data_size, 1, 1, 2])
                coord_add = np.tile(coord_add, [batch_size, mat_dim, num_classes, 1])
                coord_add_op = tf.constant(coord_add, dtype=tf.float32)

                votes = tf.concat([coord_add_op, votes], axis=3)
                tf.logging.info('class cap votes coord add shape: {}'.format(votes.get_shape()))

            with tf.variable_scope('routing') as scope:
                miu, activation, routing_debug = em_routing(
                    votes, activation, num_classes)
                tf.logging.info(
                    'class cap activation shape: {}'.format(activation.get_shape()))
                tf.summary.histogram(name="class_cap_routing_hist",
                                     values=routing_debug)

            output = tf.reshape(activation, shape=[
                                batch_size, data_size, data_size, num_classes])

        tf.logging.info('output shape: {}'.format(output.get_shape()))

        # Global average pooling over the remaining spatial grid.
        output = tf.reshape(tf.nn.avg_pool(output, ksize=[1, data_size, data_size, 1], strides=[
                            1, 1, 1, 1], padding='VALID'), shape=[batch_size, num_classes])

        tf.logging.info('output shape: {}'.format(output.get_shape()))

        pose = tf.nn.avg_pool(tf.reshape(miu, shape=[batch_size, data_size, data_size, -1]), ksize=[
                              1, data_size, data_size, 1], strides=[1, 1, 1, 1], padding='VALID')

        tf.logging.info('pose shape: {}'.format(pose.get_shape()))

        # Each class mean holds the 2 coordinate offsets plus pose_dim pose values
        # (the original wrote mat_dim + 2, which only has the same value).
        pose_out = tf.reshape(pose, shape=[batch_size, num_classes, (pose_dim + 2)])

        tf.logging.info('pose_out shape: {}'.format(pose_out.get_shape()))
    return output, pose_out

In [None]:
def kernel_tile(input, kernel, stride):
    """Extract kernel x kernel patches from a 4-D tensor, keeping channels separate.

    A depthwise convolution with a one-hot filter copies every position of each
    patch into its own output channel (an alternative to tf.extract_image_patches).

    Args:
        input: 4-D tensor with static shape; dimension 3 is the channel axis.
        kernel: side length of the square patch.
        stride: spatial stride between neighbouring patches.

    Returns:
        Tensor of shape [batch, out_height, out_width, kernel * kernel, channels].
    """
    in_shape = input.get_shape()
    n_taps = kernel * kernel

    # One output channel per patch position: tap (row, col) selects exactly that pixel.
    one_hot_filter = np.zeros(shape=[kernel, kernel, in_shape[3], n_taps],
                              dtype=np.float32)
    for tap in range(n_taps):
        row, col = divmod(tap, kernel)
        one_hot_filter[row, col, :, tap] = 1.0

    filter_op = tf.constant(one_hot_filter, dtype=tf.float32)
    patches = tf.nn.depthwise_conv2d(input, filter_op,
                                     strides=[1, stride, stride, 1], padding='VALID')

    out_shape = patches.get_shape()
    leading_dims = [int(out_shape[axis]) for axis in range(3)]
    # depthwise_conv2d groups its output channels per input channel, so split them as
    # [channels, taps] and then move the tap axis in front of the channel axis.
    patches = tf.reshape(patches, shape=leading_dims + [int(in_shape[3]), n_taps])
    return tf.transpose(patches, perm=[0, 1, 2, 4, 3])

In [None]:
def mat_transform(input, caps_num_c,  tag=False):
    """Turn input-capsule poses into votes for each output capsule.

    Every (input capsule, output capsule) pair owns a trainable 3x3 matrix `w`;
    the vote is the input pose, viewed as a 3x3 matrix, multiplied by that matrix.

    Args:
        input: poses with static shape [batch, caps_num_i, 9].
        caps_num_c: number of output capsules to vote for.
        tag: unused.

    Returns:
        Votes of shape [batch, caps_num_i, caps_num_c, 9].
    """
    in_shape = input.get_shape()
    batch_size, caps_num_i = int(in_shape[0]), int(in_shape[1])

    # View each flattened pose as a 3x3 matrix with a singleton output-capsule axis.
    poses = tf.reshape(input, shape=[batch_size, caps_num_i, 1, 3, 3])

    # The output of a capsule is miu, the mean of a Gaussian, and activation, the sum of
    # probabilities; neither depends on the absolute values of w and the votes.
    # Using weights with a bigger stddev helps numerical stability.
    weights = slim.variable(
        'w',
        shape=[1, caps_num_i, caps_num_c, 3, 3],
        dtype=tf.float32,
        initializer=tf.truncated_normal_initializer(mean=0.0, stddev=1.0))

    # Broadcast by hand: weights along the batch axis, poses along the output-capsule axis.
    weights = tf.tile(weights, [batch_size, 1, 1, 1, 1])
    poses = tf.tile(poses, [1, 1, caps_num_c, 1, 1])

    products = tf.matmul(poses, weights)
    # Flatten each 3x3 vote matrix back into 9 values on the last axis.
    return tf.reshape(products, [batch_size, caps_num_i, caps_num_c, 9])

In [None]:
def em_routing(votes, activation, caps_num_c, tag=False):
    """Route votes to output capsules with a fixed number (2) of EM iterations.

    Args:
        votes: vote tensor. Static dimension 0 is read as the batch size and the last
            dimension as the number of channels per vote; presumably shaped
            [batch, caps_num_i, caps_num_c, n_channels] -- TODO confirm with callers.
        activation: input-capsule activations. Static dimension 1 is read as the number
            of input capsules; it is multiplied into the [batch, caps_num_i, caps_num_c]
            assignment tensor, so presumably [batch, caps_num_i, 1] -- TODO confirm.
        caps_num_c: number of output capsules.
        tag: unused.

    Returns:
        A tuple `(miu, activation_out, test)`:
            miu: weighted vote means from the last iteration (axis 1 reduced with
                keep_dims, so it has size 1).
            activation_out: output activations from the last iteration, the result of
                a softmax over tensor of shape [batch, caps_num_c].
            test: an empty list; nothing is ever appended to it.
    """
    iter_routing = 2
    epsilon = 1e-9
    test = []

    batch_size = int(votes.get_shape()[0])
    caps_num_i = int(activation.get_shape()[1])
    n_channels = int(votes.get_shape()[-1])

    # Placeholders only; each is overwritten with a tensor during the first iteration.
    sigma_square = []
    miu = []
    activation_out = []
    # Trainable, zero-initialised terms used in the final activation computation.
    beta_v = slim.variable('beta_v', shape=[caps_num_c, n_channels], dtype=tf.float32,
                           initializer=tf.constant_initializer(0.0))
    beta_a = slim.variable('beta_a', shape=[caps_num_c], dtype=tf.float32,
                           initializer=tf.constant_initializer(0.0))

    # votes_in = tf.stop_gradient(votes, name='stop_gradient_votes')
    # activation_in = tf.stop_gradient(activation, name='stop_gradient_activation')
    votes_in = votes
    activation_in = activation
    
    for iters in range(iter_routing):
        # if iters == cfg.iter_routing-1:

        # e-step
        if iters == 0:
            # First pass: every input capsule is assigned uniformly (1 / caps_num_c).
            r = tf.constant(np.ones([batch_size, caps_num_i, caps_num_c], dtype=np.float32) / caps_num_c)
        else:
            # Contributor: Yunzhi Shi
            # log and exp here provide higher numerical stability especially for bigger number of iterations
            # Per-channel Gaussian log-density (constant terms dropped) of each vote under
            # the previous iteration's miu / sigma_square.
            log_p_c_h = -tf.log(tf.sqrt(sigma_square)) - \
                        (tf.square(votes_in - miu) / (2 * sigma_square))
            # Shift so the maximum over axes 2 and 3 equals log(10) before exponentiating.
            log_p_c_h = log_p_c_h - \
                        (tf.reduce_max(log_p_c_h, axis=[2, 3], keep_dims=True) - tf.log(10.0))
            p_c = tf.exp(tf.reduce_sum(log_p_c_h, axis=3))

            # Weight each density by the output activation from the previous iteration.
            ap = p_c * tf.reshape(activation_out, shape=[batch_size, 1, caps_num_c])

            # ap = tf.reshape(activation_out, shape=[batch_size, 1, caps_num_c])

            # Normalise across output capsules (axis 2) to get the new assignments.
            r = ap / (tf.reduce_sum(ap, axis=2, keep_dims=True) + epsilon)

        # m-step
        # Scale assignments by the input activations, then renormalise across axis 2.
        r = r * activation_in
        r = r / (tf.reduce_sum(r, axis=2, keep_dims=True)+epsilon)

        # r_sum: total assignment per output capsule; r1: assignments normalised across
        # input capsules (axis 1), with a trailing axis to broadcast over vote channels.
        r_sum = tf.reduce_sum(r, axis=1, keep_dims=True)
        r1 = tf.reshape(r / (r_sum + epsilon),
                        shape=[batch_size, caps_num_i, caps_num_c, 1])

        # Weighted mean and weighted variance of the votes; epsilon keeps the variance
        # strictly positive for the log and the division in the next e-step.
        miu = tf.reduce_sum(votes_in * r1, axis=1, keep_dims=True)
        sigma_square = tf.reduce_sum(tf.square(votes_in - miu) * r1,
                                     axis=1, keep_dims=True) + epsilon
        
        if iters == iter_routing-1:
            # Last iteration: activation from the per-channel cost, scaled by 0.01 and
            # passed through a softmax.
            r_sum = tf.reshape(r_sum, [batch_size, caps_num_c, 1])
            cost_h = (beta_v + tf.log(tf.sqrt(tf.reshape(sigma_square,
                                                         shape=[batch_size, caps_num_c, n_channels])))) * r_sum

            activation_out = tf.nn.softmax(0.01 * (beta_a - tf.reduce_sum(cost_h, axis=2)))
        else:
            # Intermediate iterations: softmax of the summed assignments.
            activation_out = tf.nn.softmax(r_sum)
        # if iters <= cfg.iter_routing-1:
        #     activation_out = tf.stop_gradient(activation_out, name='stop_gradient_activation')

    return miu, activation_out, test

In [None]:
def get_coord_add(dataset_name: str):
    """Return the normalised coordinate-addition grid for a dataset.

    Args:
        dataset_name: 'mnist' or 'smallNORB'.

    Returns:
        float32 array of shape [n, n, 2]; entry [row, col] holds the pair
        (offsets[col], offsets[row]) divided by the dataset's image size.

    Raises:
        KeyError: if `dataset_name` is not a known dataset.
    """
    import numpy as np
    # TODO: get coord add for cifar10/100 datasets (32x32x3)
    # Per dataset: the offsets used along both axes, and the image size that scales them.
    layouts = {
        'mnist': ((8., 12.), 28.),
        'smallNORB': ((8., 12., 16., 24.), 32.),
    }
    offsets, scale = layouts[dataset_name]

    grid = [[[x, y] for x in offsets] for y in offsets]
    return np.array(grid, dtype=np.float32) / scale

In [None]:
# Coordinate-addition grid for MNIST: a 2x2 grid of (x, y) offsets divided by 28.
coord_add = get_coord_add('mnist')

In [None]:
# Number of class capsules handed to build_arch.
num_classes = 10

In [None]:
# Input images: a fixed batch of 50 single-channel 28x28 images. The batch dimension is
# static because the graph-building helpers call int() on it (e.g. mat_transform).
X = tf.placeholder(shape=[50, 28, 28, 1], dtype=tf.float32, name="X")

In [None]:
# Build the network graph: `output` holds the class activations, `pose_out` the class poses.
output, pose_out = build_arch(X, coord_add, is_train=True, num_classes=num_classes)

In [None]:
def spread_loss(output, pose_out, x, y, m, use_regularization=False):
    """Spread loss on the class activations plus a decoder reconstruction loss.

    Args:
        output: class activations with static shape [batch, num_class].
        pose_out: per-class poses, [batch, num_class, pose_channels]; every class except
            the labelled one is zeroed before the result is fed to the decoder.
        x: input images; static dimension 1 is taken as the image side length and the
            tensor is flattened per sample as the reconstruction target.
        y: integer class labels, one per sample (one-hot encoded here).
        m: margin of the spread loss.
        use_regularization: when True, the graph's collected regularization losses
            (the decoder's L2 terms) are added to the total. Defaults to False, which
            matches the previous hard-wired behaviour.

    Returns:
        A tuple `(loss_all, loss, reconstruction_loss, pose_out)`: the total loss, the
        spread loss alone, the mean squared reconstruction error, and the decoder output.
    """
    # Prefer the static batch dimension of `output`; 50 is kept only as the legacy
    # fallback for an unknown batch dimension.
    static_batch = output.get_shape().as_list()[0]
    batch_size = 50 if static_batch is None else static_batch

    num_class = int(output.get_shape()[-1])
    data_size = int(x.get_shape()[1])

    y = tf.one_hot(y, num_class, dtype=tf.float32)

    # spread loss
    output1 = tf.reshape(output, shape=[batch_size, 1, num_class])
    y = tf.expand_dims(y, axis=2)
    # at: activation of the target class, selected by the one-hot label.
    at = tf.matmul(output1, y)
    # Paper eq(5).
    mb = at - output1
    # Multiplying by (1 - y) sums the squared hinge terms over the non-target classes only.
    loss = tf.matmul(tf.square(tf.maximum(0., m - mb)), 1. - y)
    loss = tf.reduce_mean(loss)

    # Mask out the poses of every class except the labelled one, then flatten per sample.
    pose_out = tf.reshape(tf.multiply(pose_out, y), shape=[batch_size, -1])
    tf.logging.info("decoder input value dimension:{}".format(pose_out.get_shape()))

    with tf.variable_scope('decoder'):
        l2 = tf.contrib.layers.l2_regularizer(5e-04)
        pose_out = slim.fully_connected(pose_out, 512, trainable=True,
                                        weights_regularizer=l2)
        pose_out = slim.fully_connected(pose_out, 1024, trainable=True,
                                        weights_regularizer=l2)
        pose_out = slim.fully_connected(pose_out, data_size * data_size,
                                        trainable=True, activation_fn=tf.sigmoid,
                                        weights_regularizer=l2)

        x = tf.reshape(x, shape=[batch_size, -1])
        reconstruction_loss = tf.reduce_mean(tf.square(pose_out - x))

    # The reconstruction term is weighted by 0.0005 times the number of pixels.
    loss_terms = [loss, 0.0005 * data_size * data_size * reconstruction_loss]
    if use_regularization:
        loss_terms += tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    loss_all = tf.add_n(loss_terms)

    return loss_all, loss, reconstruction_loss, pose_out

In [None]:
def test_accuracy(logits, labels):
    """Fraction of samples whose highest-scoring class equals the label.

    Args:
        logits: per-class scores, [batch, num_classes].
        labels: integer class labels, one per sample.

    Returns:
        Scalar float32 tensor in [0, 1].
    """
    predicted = tf.to_int32(tf.argmax(logits, axis=1))
    predicted = tf.reshape(predicted, shape=[-1])
    correct_preds = tf.equal(tf.to_int32(labels), predicted)
    # The mean over the batch replaces the former sum / 50, so any batch size works.
    accuracy = tf.reduce_mean(tf.cast(correct_preds, tf.float32))

    return accuracy

In [None]:
# Scalar margin fed to the spread loss as `m`.
m_op = tf.placeholder(dtype=tf.float32, shape=())

In [None]:
# Integer class labels, one per sample; they are one-hot encoded inside the loss function.
y = tf.placeholder(shape=[None], dtype=tf.int32, name="y")

In [None]:
# Keep the function name `spread_loss` bound to the function: the previous unpacking
# rebound it to a tensor, so re-running this cell failed. `_` is also avoided because
# IPython uses it for the last cell output.
loss, spread_loss_value, mse, decoded_images = spread_loss(
                    output, pose_out, X, y, m_op)

In [None]:
# Batch accuracy of the class activations against the integer labels.
accuracy = test_accuracy(output, y)

In [None]:
# Adam with its default hyper-parameters, minimising the total loss.
optimizer = tf.train.AdamOptimizer()
training_op = optimizer.minimize(loss, name="training_op")

In [None]:
# One op that initialises both global and local variables, plus a saver for checkpoints.
global_init_op = tf.global_variables_initializer()
local_init_op = tf.local_variables_initializer()
init = tf.group(global_init_op, local_init_op)
saver = tf.train.Saver()

In [None]:
# Training configuration. batch_size must match the fixed batch dimension of the `X` placeholder.
n_epochs = 50
batch_size = 50

# NOTE(review): this flag is never read below; no checkpoint is restored before training.
restore_checkpoint = True

n_iterations_per_epoch = mnist.train.num_examples // batch_size
n_iterations_validation = mnist.validation.num_examples // batch_size
best_loss_val = np.inf
checkpoint_path = "./my_capsule_network4"

with tf.Session() as sess:
    init.run()

    # NOTE(review): m_vals and loss1_vals are not used anywhere in this cell.
    m_vals = np.zeros((1))
    loss1_vals = np.zeros((1, 50, 10))

    # The margin stays at m_max for the whole run; m_min is defined but never used.
    m_min = 0.2
    m_max = 0.9
    m = m_max

    for epoch in range(n_epochs):

        for iteration in range(1, n_iterations_per_epoch + 1):
            try:
                X_batch, y_batch = mnist.train.next_batch(batch_size)
                # Run the training operation and evaluate the loss.
                _, loss_train = sess.run(
                    [training_op, loss],
                    feed_dict={X: X_batch.reshape([-1, 28, 28, 1]),
                               y: y_batch,
                               m_op: m})
                print("\rIteration: {}/{} ({:.1f}%)  Loss: {:.5f}  m : {:.5f}".format(
                          iteration, n_iterations_per_epoch,
                          iteration * 100 / n_iterations_per_epoch,
                          loss_train, m),
                      end="")
            except KeyboardInterrupt:
                sess.close()
                sys.exit()
            except tf.errors.InvalidArgumentError:
                # The original handler referenced undefined `logger` and `step`, so it
                # raised NameError instead of skipping the batch.
                tf.logging.warning(
                    '%d iteration contains NaN gradients. Discard.' % iteration)
                continue

        # Evaluate on the validation split once per epoch.
        loss_vals = []
        acc_vals = []

        for iteration in range(1, n_iterations_validation + 1):
            X_batch, y_batch = mnist.validation.next_batch(batch_size)
            loss_val, acc_val = sess.run(
                    [loss, accuracy],
                    feed_dict={X: X_batch.reshape([-1, 28, 28, 1]),
                               y: y_batch,
                               m_op: m})
            loss_vals.append(loss_val)
            acc_vals.append(acc_val)
            print("\rEvaluating the model: {}/{} ({:.1f}%)".format(
                      iteration, n_iterations_validation,
                      iteration * 100 / n_iterations_validation),
                  end=" " * 10)

        loss_val = np.mean(loss_vals)
        acc_val = np.mean(acc_vals)
        print("\rEpoch: {}  Val accuracy: {:.4f}%  Loss: {:.5f}{}".format(
            epoch + 1, acc_val * 100, loss_val,
            " (improved)" if loss_val < best_loss_val else ""))

        # And save the model if it improved:
        if loss_val < best_loss_val:
            save_path = saver.save(sess, checkpoint_path)
            best_loss_val = loss_val