# CapsNet on MNIST
- Author : *Jinhong Du*

- Reference: 
    
    1. https://github.com/naturomics/CapsNet-Tensorflow/blob/master
    2. https://www.jiqizhixin.com/articles/2017-11-05

# Content

1. [Import Related Modules and Packages](#Sec1)
2. [Hyperparameters](#Sec2)
3. [Prepare for Data](#Sec3)
    - [Load Data](#Sec3.1)
    - [Data Preprocessing](#Sec3.2)
    - [Generate Dataset](#Sec3.3)
4. [Build Model](#Sec4)
    - [Squash Function](#Sec4.1)
    - [Routing Function](#Sec4.2)
    - [Primary Capsules](#Sec4.3)
    - [Digit Capsules](#Sec4.4)
    - [Build CapsNet](#Sec4.5)
    - [Loss Function](#Sec4.6)
    - [Accuracy Evaluator](#Sec4.7)
    - [Optimizer](#Sec4.8)
    - [Training](#Sec4.9)

## 1. Import Related Modules and Packages<a id='Sec1'></a>

In [1]:
import tensorflow as tf
tfe = tf.contrib.eager
# Enable eager execution mode
tf.enable_eager_execution()

import numpy as np
import matplotlib.pyplot as plt
import time

## 2. Hyperparameters<a id='Sec2'></a>

In [None]:
# Number of output classes (MNIST digits 0-9); also the one-hot label width.
NUM_CLASS = 10
# Mini-batch size used by the tf.data pipelines in section 3.3, which
# reference BATCH_SIZE. 128 is a common choice for CapsNet on MNIST;
# adjust to fit the available memory.
BATCH_SIZE = 128

## 3. Prepare for Data<a id='Sec3'></a>

### 3.1. Load Data<a id='Sec3.1'></a>

In [None]:
def _describe_split(title, data, labels):
    """Print the shape and dtype of one (data, labels) split."""
    print(title)
    print('       Data :\t shape:', np.shape(data), '\t type:', data.dtype)
    print('       Label:\t shape:', np.shape(labels), '\t\t type:', labels.dtype)


# Load the MNIST train/test splits through the Keras datasets API.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data(path='mnist.npz')

_describe_split('Training set:', x_train, y_train)
_describe_split('Testing set :', x_test, y_test)

### 3.2. Data Preprocessing<a id='Sec3.2'></a>

In [None]:
# Cast pixel values to float32 and scale them from [0, 255] into [0, 1].
x_train, x_test = (arr.astype(np.float32) / 255 for arr in (x_train, x_test))
print(np.shape(x_train), np.shape(x_test))

# Encode the integer class labels as float32 one-hot vectors of length NUM_CLASS.
y_train, y_test = (tf.one_hot(lbl, NUM_CLASS, dtype=tf.float32) for lbl in (y_train, y_test))
print(y_train, y_test)

### 3.3. Generate Dataset<a id='Sec3.3'></a>

In [None]:
# Training pipeline: (image, one-hot label) pairs, shuffled through a
# 5000-element buffer, grouped into full batches of BATCH_SIZE.
TrainDataset = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(buffer_size=5000)
    .batch(BATCH_SIZE, drop_remainder=True)
)

# Test pipeline: same pairing and batching, without shuffling.
# NOTE(review): drop_remainder=True discards the final partial batch, so a
# few test samples are never evaluated -- confirm this is intended.
TestDataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
TestDataset = TestDataset.batch(BATCH_SIZE, drop_remainder=True)

## 4. Build Model<a id='Sec4'></a>

### 4.1 Squash Function<a id='Sec4.1'></a>

In [None]:
def squash(vector, epsilon=1e-9):
    '''Squashing non-linearity corresponding to Eq. 1.

    Computes v = ||s||^2 / (1 + ||s||^2) * s / ||s|| along the 'vec_len'
    axis (axis -2). Short vectors shrink towards zero and long vectors
    approach unit length; their direction is unchanged.

    Args:
        vector: A tensor with shape [batch_size, 1, num_caps, vec_len, 1]
            or [batch_size, num_caps, vec_len, 1].
        epsilon: Small constant added under the square root so that an
            all-zero capsule does not cause a division by zero.
    Returns:
        A tensor with the same shape as `vector`, squashed in the
        'vec_len' dimension.
    '''
    # ||s||^2 per capsule. The size-1 axis is kept so it broadcasts over vec_len.
    vec_squared_norm = tf.reduce_sum(tf.square(vector), axis=-2, keepdims=True)
    scalar_factor = vec_squared_norm / (1 + vec_squared_norm) / tf.sqrt(vec_squared_norm + epsilon)
    return scalar_factor * vector  # element-wise, broadcast over vec_len

### 4.2 Routing Function<a id='Sec4.2'></a>

In [None]:
def routing(input_tensor, b, W, biases, num_outputs=10, num_dims=16, iter_routing=3):
    ''' The dynamic routing algorithm.

    Args:
        input_tensor: A Tensor with shape [batch_size, num_caps_l=1152, 1, length(u_i)=8, 1],
               num_caps_l meaning the number of capsules in the layer l.
        b: Initial routing logits b_IJ, shape [1, num_caps_l, num_outputs, 1, 1];
               broadcast over the batch. The caller's tensor is not modified.
        W: Transformation weights, shape
               [1, num_caps_l, num_dims * num_outputs, length(u_i), 1].
        biases: Bias added to every s_J, shape [1, 1, num_outputs, num_dims, 1].
        num_outputs: the number of output capsules.
        num_dims: the number of dimensions for output capsule.
        iter_routing: number of routing iterations; must be at least 1.
    Returns:
        A Tensor of shape [batch_size, 1, num_outputs, num_dims=16, 1]
        representing the vector output `v_j` in the layer l+1.
    Raises:
        ValueError: if `iter_routing` < 1 (no output would be computed).
    Notes:
        u_i represents the vector output of capsule i in the layer l, and
        v_j the vector output of capsule j in the layer l+1.
    '''
    if iter_routing < 1:
        raise ValueError('iter_routing must be >= 1, got %r' % (iter_routing,))

    input_tensor = tf.convert_to_tensor(input_tensor)
    # Static number of input capsules, used in the reshape and tile below.
    num_caps_l = input_tensor.get_shape().as_list()[1]

    # Eq.2, calc u_hat
    # Since tf.matmul is a time-consuming op, use element-wise multiply,
    # reduce_sum and reshape instead: matmul [a, b] x [b, c] equals an
    # element-wise multiply [a*c, b] * [a*c, b], reduce_sum at axis=1 and a
    # reshape to [a, c].
    # => [batch_size, num_caps_l, num_dims * num_outputs, len_u_i, 1]
    input_tensor = tf.tile(input_tensor, [1, 1, num_dims * num_outputs, 1, 1])
    u_hat = tf.reduce_sum(W * input_tensor, axis=3, keepdims=True)
    # => [batch_size, num_caps_l, num_outputs, num_dims, 1]
    u_hat = tf.reshape(u_hat, shape=[-1, num_caps_l, num_outputs, num_dims, 1])

    # Forward: u_hat_stopped == u_hat. Backward: no gradient flows through
    # u_hat_stopped, so the inner routing iterations do not train W.
    u_hat_stopped = tf.stop_gradient(u_hat, name='stop_gradient')

    # line 3: for r iterations do
    for r_iter in range(iter_routing):
        # line 4: coupling coefficients, softmax over the output capsules
        # => [batch_size or 1, num_caps_l, num_outputs, 1, 1]
        c_IJ = tf.nn.softmax(b, axis=2)

        if r_iter == iter_routing - 1:
            # Last iteration: use `u_hat` so gradients reach W.
            # line 5: weight u_hat with c_IJ, then sum over the input capsules
            # => [batch_size, 1, num_outputs, num_dims, 1]
            s_J = tf.reduce_sum(tf.multiply(c_IJ, u_hat), axis=1, keepdims=True) + biases
            # line 6: squash using Eq.1
            v_J = squash(s_J)
        else:
            # Inner iterations: no backpropagation through u_hat.
            s_J = tf.reduce_sum(tf.multiply(c_IJ, u_hat_stopped), axis=1, keepdims=True) + biases
            v_J = squash(s_J)

            # line 7: agreement u_hat . v_j for every (i, j) pair.
            # Tile v_J from [batch_size, 1, ...] to [batch_size, num_caps_l, ...],
            # then dot product over num_dims => [batch_size, num_caps_l, num_outputs, 1, 1]
            v_J_tiled = tf.tile(v_J, [1, num_caps_l, 1, 1, 1])
            u_produce_v = tf.reduce_sum(u_hat_stopped * v_J_tiled, axis=3, keepdims=True)

            # Rebind rather than update in place, so the caller's `b` is untouched.
            b = b + u_produce_v

    return v_J

### 4.3. Primary Capsules<a id='Sec4.3'></a>

In [None]:
class PrimaryCaps(tf.keras.Model):
    '''Primary capsule layer.

    Runs `vec_len` parallel convolution units, each with `num_outputs`
    filters. It then stacks their flattened outputs, so that each
    (spatial position, filter) pair becomes one capsule of length
    `vec_len`. With the defaults and the 20x20x256 output of CapsNet's
    Conv1, this gives 6x6x32 = 1152 capsules of length 8.

    NOTE(review): subclassing tf.keras.Model makes the layer callable from
    CapsNet and lets Keras track the conv units. Confirm with the installed
    TF version that variables held in the `caps` list are tracked.
    '''

    def __init__(self, num_outputs=32, vec_len=8, kernel_size=9, strides=2):
        super(PrimaryCaps, self).__init__()
        self.num_outputs = num_outputs
        self.vec_len = vec_len
        # One convolution unit per capsule dimension.
        self.caps = [
            tf.layers.Conv2D(filters=num_outputs, kernel_size=kernel_size, strides=strides,
                             padding='valid', activation=None, name='PrimaryCapsules%d' % i)
            for i in range(vec_len)]

    def call(self, input_tensor):
        '''Turn a [batch_size, H, W, C] feature map into squashed capsules.

        Returns:
            A tensor of shape [batch_size, num_caps, vec_len, 1], which is
            [batch_size, 1152, 8, 1] for the default MNIST setup.
        '''
        batch_size = tf.shape(input_tensor)[0]
        capsules = []
        for conv in self.caps:
            # Flatten each conv unit's output into a column of scalars:
            # [batch_size, num_caps, 1, 1].
            capsules.append(tf.reshape(conv(input_tensor), shape=(batch_size, -1, 1, 1)))

        # Concatenate the vec_len columns along axis 2 to form the capsule vectors:
        # [batch_size, num_caps, vec_len, 1].
        capsules = tf.concat(capsules, axis=2)
        # Squash each capsule vector (scaling plus non-linear activation).
        return squash(capsules)

### 4.4. Digit Capsules<a id='Sec4.4'></a>

In [None]:
class DigitCaps(tf.keras.Model):
    '''Digit capsule layer: routes the primary capsules to `num_outputs`
    output capsules of length `num_dims` via `routing`.

    Args:
        input_shape: Shape of the incoming capsules,
            [batch_size, num_caps_in, vec_len, 1] (a 5-D
            [batch_size, num_caps_in, 1, vec_len, 1] shape also works).
            Only index 1 and the last two entries are used.
        num_outputs: Number of output (digit) capsules.
        num_dims: Length of each output capsule vector.

    NOTE(review): confirm that Keras tracks the tfe.Variable attributes
    `W` and `biases` for the installed TF version, so that they appear in
    the enclosing model's variables.
    '''

    def __init__(self, input_shape=(-1, 1152, 8, 1), num_outputs=10, num_dims=16):
        super(DigitCaps, self).__init__()
        # Stored under a private name because Keras layers/models expose a
        # read-only `input_shape` property.
        self._caps_input_shape = list(input_shape)
        self.num_outputs = num_outputs
        self.num_dims = num_dims

        num_caps_in = self._caps_input_shape[1]
        # Transformation matrices W_ij, laid out as routing() expects:
        # [1, num_caps_in, num_dims * num_outputs, vec_len, 1].
        self.W = tfe.Variable(
            tf.random.normal(shape=[1, num_caps_in, num_dims * num_outputs] + self._caps_input_shape[-2:],
                             dtype=tf.float32, stddev=0.01))
        # Bias added to every s_J: [1, 1, num_outputs, num_dims, 1].
        self.biases = tfe.Variable(
            tf.zeros(shape=(1, 1, num_outputs, num_dims, 1), dtype=tf.float32))

    def call(self, input_tensor):
        '''Return digit capsules of shape [batch_size, num_outputs, num_dims, 1].'''
        num_caps_in = self._caps_input_shape[1]
        # routing() expects [batch_size, num_caps_in, 1, vec_len, 1].
        input_tensor = tf.reshape(input_tensor,
                                  shape=[-1, num_caps_in, 1] + self._caps_input_shape[-2:])
        # Routing logits b_IJ start at zero on every forward pass.
        b = tf.zeros([1, num_caps_in, self.num_outputs, 1, 1], dtype=tf.float32)
        capsules = routing(input_tensor, b, self.W, self.biases,
                           num_outputs=self.num_outputs, num_dims=self.num_dims)
        # Drop the singleton axis with tf.squeeze:
        # [batch_size, 1, num_outputs, num_dims, 1] -> [batch_size, num_outputs, num_dims, 1].
        return tf.squeeze(capsules, axis=1)

### 4.5. Build CapsNet<a id='Sec4.5'></a>

In [None]:
class CapsNet(tf.keras.Model):
    '''Three-layer CapsNet: Conv1 -> PrimaryCaps -> DigitCaps.

    Args:
        device: Optional device string. It is stored as `self.device` and
            not used inside this class.

    NOTE(review): `loss()` also takes a reconstruction `X_decoded`, but this
    model has no decoder yet, so one still has to be added before training
    with that loss.
    '''
    def __init__(self, device=None):
        super(CapsNet, self).__init__()
        self.device = device
        # Conv2D needs an explicit channel axis; the images are single-channel 28x28.
        self._input_shape = [-1, 28, 28, 1]

        # 1st layer: 9x9 conv with ReLU, 28x28x1 -> 20x20x256
        self.Conv1 = tf.layers.Conv2D(filters = 256, kernel_size = 9, strides = 1,
                                      padding = 'valid', activation = tf.nn.relu, name = 'Conv1')

        # 2nd layer - PrimaryCaps
        self.PrimaryCaps = PrimaryCaps()

        # 3rd layer - DigitCaps
        self.DigitCaps = DigitCaps()

    def call(self, input_tensor):
        '''Run a batch of images ([batch, 28, 28] or [batch, 28, 28, 1]) through the network.'''
        # The preprocessed images have no channel axis, so add it before Conv1.
        input_tensor = tf.reshape(input_tensor, self._input_shape)
        input_tensor = self.Conv1(input_tensor)
        input_tensor = self.PrimaryCaps(input_tensor)
        return self.DigitCaps(input_tensor)

### 4.6. Loss Function<a id='Sec4.6'></a>

In [None]:
def MarginLoss(v, Y, m_plus = 0.9, m_minus = 0.1, lambda_val = 0.5, epsilon = 1e-9):
    '''
    Calculate the sum of the per-digit margin losses for every sample
    and average it over the batch.
    Input:
        v          - The output tensor of the DigitCaps layer,
                     [batch_size, num_class, vec_len, 1].
        Y          - One-hot labels, [batch_size, num_class].
        m_plus     - Target lower bound on ||v_c|| for the true class.
        m_minus    - Target upper bound on ||v_c|| for the other classes.
        lambda_val - Down-weighting factor for the absent-class term.
        epsilon    - Added under the square root so the length (and its
                     gradient) stays finite for all-zero capsules.
    Returns:
        Scalar tensor: mean over the batch of the summed per-class losses.
    '''
    batch_size = tf.shape(v)[0]
    # Length of every class capsule: ||v_c||, [batch_size, 10, 1, 1]
    v_length = tf.sqrt(tf.reduce_sum(tf.square(v), axis=-2, keepdims=True) + epsilon)

    # max_l = max(0, m_plus-||v_c||)^2
    max_l = tf.square(tf.maximum(0., m_plus - v_length))
    # max_r = max(0, ||v_c||-m_minus)^2
    max_r = tf.square(tf.maximum(0., v_length - m_minus))

    # reshape: [batch_size, 10, 1, 1] => [batch_size, 10]
    max_l = tf.reshape(max_l, shape=(batch_size, -1))
    max_r = tf.reshape(max_r, shape=(batch_size, -1))

    # T_c: 1 for the true class, 0 otherwise, [batch_size, 10]
    T_c = Y
    # [batch_size, 10], element-wise multiply
    L_c = T_c * max_l + lambda_val * (1 - T_c) * max_r

    return tf.reduce_mean(tf.reduce_sum(L_c, axis=1))

def ReconstructionLoss(X, X_decoded):
    '''
    Sum of squared reconstruction errors per sample, averaged over the batch.
    Input:
        X         - Input tensor; flattened to [batch_size, -1] before comparison.
        X_decoded - The output tensor of the reconstruction layer. It is
                    presumably already flat ([batch_size, num_pixels]) --
                    TODO confirm once the decoder exists.
    '''
    flat_input = tf.reshape(X, shape=(tf.shape(X)[0], -1))
    per_sample_error = tf.reduce_sum(tf.square(X_decoded - flat_input), axis=1)
    return tf.reduce_mean(per_sample_error)
    
def loss(X, Y, v, X_decoded, regularization_scale = 0.0005):
    '''
    Total CapsNet loss: the margin loss plus a scaled reconstruction loss.
    Input:
        X                    - Input tensor.
        Y                    - One-hot labels.
        v                    - The output tensor of the DigitCaps layer.
        X_decoded            - The output tensor of the reconstruction layer.
        regularization_scale - Weight applied to the reconstruction term.
    '''
    margin_term = MarginLoss(v, Y)
    reconstruction_term = ReconstructionLoss(X, X_decoded)
    return margin_term + regularization_scale * reconstruction_term

### 4.7. Accuracy Evaluator<a id='Sec4.7'></a>

### 4.8. Optimizer<a id='Sec4.8'></a>

### 4.9. Training<a id='Sec4.9'></a>