Fixing issue #8, and reimplement the routing algorithm in a matrix way
naturomics committed Nov 1, 2017
2 parents 4cba77d + 24ac9fb commit 2c8d0ce
Showing 4 changed files with 131 additions and 127 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -37,6 +37,7 @@ my weChat: ![my_wechat](/imgs/my_wechat_QR.png)
- NumPy
- [Tensorflow](https://github.com/tensorflow/tensorflow) (I'm using 1.3.0; other versions should work too)
- tqdm (for showing training progress info)
- scipy (for saving images)

## Usage

@@ -83,6 +84,7 @@ Results for the 'wrong' version (Issue #8):
![reconstruction_loss](imgs/reconstruction_loss.png)

- test acc

|Epoch|49|51|
|:----:|:----:|:--:|
|test acc|94.69|94.71|
229 changes: 115 additions & 114 deletions capsLayer.py
@@ -4,143 +4,144 @@
from config import cfg


class CapsConv(object):
class CapsLayer(object):
''' Capsule layer.
Args:
input: A 4-D tensor.
num_units: integer, the length of the output vector of a capsule.
num_outputs: the number of capsules in this layer.
vec_len: integer, the length of the output vector of a capsule.
layer_type: string, one of 'FC' or 'CONV', the type of this layer;
fully connected or convolutional, kept for future expansion.
with_routing: boolean, whether this capsule layer routes with the
lower-level capsules.
num_outputs: the number of capsules in this layer.
Returns:
A 4-D tensor.
'''
def __init__(self, num_units, with_routing=True):
self.num_units = num_units
self.with_routing = with_routing

def __call__(self, input, num_outputs, kernel_size=None, stride=None):
def __init__(self, num_outputs, vec_len, with_routing=True, layer_type='FC'):
self.num_outputs = num_outputs
self.kernel_size = kernel_size
self.stride = stride

if not self.with_routing:
# the PrimaryCaps layer
# input: [batch_size, 20, 20, 256]
assert input.get_shape() == [cfg.batch_size, 20, 20, 256]

capsules = []
for i in range(self.num_units):
# each capsule i: [batch_size, 6, 6, 32]
with tf.variable_scope('ConvUnit_' + str(i)):
caps_i = tf.contrib.layers.conv2d(input,
self.num_outputs,
self.kernel_size,
self.stride,
padding="VALID")
caps_i = tf.reshape(caps_i, shape=(cfg.batch_size, -1, 1, 1))
capsules.append(caps_i)

assert capsules[0].get_shape() == [cfg.batch_size, 1152, 1, 1]

# [batch_size, 1152, 8, 1]
capsules = tf.concat(capsules, axis=2)
capsules = squash(capsules)
assert capsules.get_shape() == [cfg.batch_size, 1152, 8, 1]

else:
# the DigitCaps layer
# Reshape the input into shape [batch_size, 1152, 8, 1]
self.input = tf.reshape(input, shape=(cfg.batch_size, 1152, 8, 1))

# b_IJ: [1, num_caps_l, num_caps_l_plus_1, 1]
b_IJ = tf.zeros(shape=[1, 1152, 10, 1], dtype=np.float32)
capsules = []
for j in range(self.num_outputs):
with tf.variable_scope('caps_' + str(j)):
caps_j, b_IJ = capsule(input, b_IJ, j)
capsules.append(caps_j)

# Return a tensor with shape [batch_size, 10, 16, 1]
capsules = tf.concat(capsules, axis=1)
assert capsules.get_shape() == [cfg.batch_size, 10, 16, 1]

return(capsules)


def capsule(input, b_IJ, idx_j):
''' The routing algorithm for one capsule in the layer l+1.
self.vec_len = vec_len
self.with_routing = with_routing
self.layer_type = layer_type

def __call__(self, input, kernel_size=None, stride=None):
'''
The parameters 'kernel_size' and 'stride' will be used only when 'layer_type' equals 'CONV'.
'''
if self.layer_type == 'CONV':
self.kernel_size = kernel_size
self.stride = stride

if not self.with_routing:
# the PrimaryCaps layer, a convolutional layer
# input: [batch_size, 20, 20, 256]
assert input.get_shape() == [cfg.batch_size, 20, 20, 256]

capsules = []
for i in range(self.vec_len):
# each capsule i: [batch_size, 6, 6, 32]
with tf.variable_scope('ConvUnit_' + str(i)):
caps_i = tf.contrib.layers.conv2d(input, self.num_outputs,
self.kernel_size, self.stride,
padding="VALID")
caps_i = tf.reshape(caps_i, shape=(cfg.batch_size, -1, 1, 1))
capsules.append(caps_i)

assert capsules[0].get_shape() == [cfg.batch_size, 1152, 1, 1]

# [batch_size, 1152, 8, 1]
capsules = tf.concat(capsules, axis=2)
capsules = squash(capsules)
assert capsules.get_shape() == [cfg.batch_size, 1152, 8, 1]
return(capsules)

if self.layer_type == 'FC':
if self.with_routing:
# the DigitCaps layer, a fully connected layer
# Reshape the input into [batch_size, 1152, 1, 8, 1]
self.input = tf.reshape(input, shape=(cfg.batch_size, 1152, 1, 8, 1))

with tf.variable_scope('routing'):
# b_IJ: [1, num_caps_l, num_caps_l_plus_1, 1, 1]
b_IJ = tf.zeros(shape=[1, 1152, 10, 1, 1], dtype=np.float32)
capsules = routing(self.input, b_IJ)
capsules = tf.squeeze(capsules, axis=1)

return(capsules)


def routing(input, b_IJ):
''' The routing algorithm.
Args:
input: A Tensor with [batch_size, num_caps_l=1152, length(u_i)=8, 1]
input: A Tensor with [batch_size, num_caps_l=1152, 1, length(u_i)=8, 1]
shape, num_caps_l meaning the number of capsules in the layer l.
Returns:
A Tensor of shape [batch_size, 1, length(v_j)=16, 1] representing the
vector output `v_j` of capsule j in the layer l+1
A Tensor of shape [batch_size, 1, num_caps_l_plus_1=10, length(v_j)=16, 1]
representing the vector output `v_j` of the capsules in the layer l+1
Notes:
u_i represents the vector output of capsule i in the layer l, and
v_j the vector output of capsule j in the layer l+1.
'''

with tf.variable_scope('routing'):
w_initializer = np.random.normal(size=[1, 1152, 8, 16], scale=0.01)
W_Ij = tf.Variable(w_initializer, dtype=tf.float32)
# repeat W_Ij with batch_size times to shape [batch_size, 1152, 8, 16]
W_Ij = tf.tile(W_Ij, [cfg.batch_size, 1, 1, 1])

# calc u_hat
# [8, 16].T x [8, 1] => [16, 1] => [batch_size, 1152, 16, 1]
u_hat = tf.matmul(W_Ij, input, transpose_a=True)
assert u_hat.get_shape() == [cfg.batch_size, 1152, 16, 1]

shape = b_IJ.get_shape().as_list()
size_splits = [idx_j, 1, shape[2] - idx_j - 1]
for r_iter in range(cfg.iter_routing):
# line 4:
# [1, 1152, 10, 1]
c_IJ = tf.nn.softmax(b_IJ, dim=2)
assert c_IJ.get_shape() == [1, 1152, 10, 1]

# line 5:
# weighting u_hat with c_I in the third dim,
# then sum in the second dim, resulting in [batch_size, 1, 16, 1]
b_Il, b_Ij, b_Ir = tf.split(b_IJ, size_splits, axis=2)
c_Il, c_Ij, b_Ir = tf.split(c_IJ, size_splits, axis=2)
assert c_Ij.get_shape() == [1, 1152, 1, 1]

s_j = tf.multiply(c_Ij, u_hat)
s_j = tf.reduce_sum(tf.multiply(c_Ij, u_hat),
axis=1, keep_dims=True)
assert s_j.get_shape() == [cfg.batch_size, 1, 16, 1]

# line 6:
# squash using Eq.1, resulting in [batch_size, 1, 16, 1]
v_j = squash(s_j)
assert s_j.get_shape() == [cfg.batch_size, 1, 16, 1]

# line 7:
# tile v_j from [batch_size ,1, 16, 1] to [batch_size, 1152, 16, 1]
# [16, 1].T x [16, 1] => [1, 1], then reduce mean in the
# batch_size dim, resulting in [1, 1152, 1, 1]
v_j_tiled = tf.tile(v_j, [1, 1152, 1, 1])
u_produce_v = tf.matmul(u_hat, v_j_tiled, transpose_a=True)
assert u_produce_v.get_shape() == [cfg.batch_size, 1152, 1, 1]
b_Ij += tf.reduce_sum(u_produce_v, axis=0, keep_dims=True)
b_IJ = tf.concat([b_Il, b_Ij, b_Ir], axis=2)

return(v_j, b_IJ)
# W: [1, num_caps_i, num_caps_j, len_u_i, len_v_j]
W = tf.get_variable('Weight', shape=(1, 1152, 10, 8, 16), dtype=tf.float32,
initializer=tf.random_normal_initializer(stddev=cfg.stddev))

# Eq.2, calc u_hat
# do tiling for input and W before matmul
# input => [batch_size, 1152, 10, 8, 1]
# W => [batch_size, 1152, 10, 8, 16]
input = tf.tile(input, [1, 1, 10, 1, 1])
W = tf.tile(W, [cfg.batch_size, 1, 1, 1, 1])
assert input.get_shape() == [cfg.batch_size, 1152, 10, 8, 1]

# in last 2 dims:
# [8, 16].T x [8, 1] => [16, 1] => [batch_size, 1152, 10, 16, 1]
u_hat = tf.matmul(W, input, transpose_a=True)
assert u_hat.get_shape() == [cfg.batch_size, 1152, 10, 16, 1]

# line 3: for r iterations do
for r_iter in range(cfg.iter_routing):
# line 4:
# => [1, 1152, 10, 1, 1]
c_IJ = tf.nn.softmax(b_IJ, dim=2)
c_IJ = tf.tile(c_IJ, [cfg.batch_size, 1, 1, 1, 1])
assert c_IJ.get_shape() == [cfg.batch_size, 1152, 10, 1, 1]

# line 5:
# weighting u_hat with c_IJ, element-wise in the last two dims
# => [batch_size, 1152, 10, 16, 1]
s_J = tf.multiply(c_IJ, u_hat)
# then sum in the second dim, resulting in [batch_size, 1, 10, 16, 1]
s_J = tf.reduce_sum(s_J, axis=1, keep_dims=True)
assert s_J.get_shape() == [cfg.batch_size, 1, 10, 16, 1]

# line 6:
# squash using Eq.1,
v_J = squash(s_J)
assert v_J.get_shape() == [cfg.batch_size, 1, 10, 16, 1]

# line 7:
# tile v_J from [batch_size, 1, 10, 16, 1] to [batch_size, 1152, 10, 16, 1],
# then matmul in the last two dims: [16, 1].T x [16, 1] => [1, 1], and reduce
# sum in the batch_size dim, resulting in [1, 1152, 10, 1, 1]
v_J_tiled = tf.tile(v_J, [1, 1152, 1, 1, 1])
u_produce_v = tf.matmul(u_hat, v_J_tiled, transpose_a=True)
assert u_produce_v.get_shape() == [cfg.batch_size, 1152, 10, 1, 1]
b_IJ += tf.reduce_sum(u_produce_v, axis=0, keep_dims=True)

return(v_J)
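
For intuition, here is a minimal NumPy sketch of the same per-example routing update (1152 input capsules of length 8, 10 output capsules of length 16); the standalone form, names, and random `u_hat` are illustrative, not the TensorFlow code above:

```python
import numpy as np

def squash_np(s, eps=1e-9):
    # Eq. 1 applied along the last axis: v = (|s|^2 / (1 + |s|^2)) * (s / |s|)
    sq_norm = np.sum(np.square(s), axis=-1, keepdims=True)
    return (sq_norm / (1. + sq_norm)) * s / np.sqrt(sq_norm + eps)

def routing_np(u_hat, num_iters=3):
    # u_hat: [1152, 10, 16], prediction vectors u_hat_{j|i} = W_ij u_i
    b = np.zeros(u_hat.shape[:2])                              # logits b_ij: [1152, 10]
    for _ in range(num_iters):
        c = np.exp(b) / np.exp(b).sum(axis=1, keepdims=True)   # line 4: softmax over j
        s = np.einsum('ij,ijk->jk', c, u_hat)                  # line 5: weighted sum over i
        v = squash_np(s)                                        # line 6: squash, [10, 16]
        b = b + np.einsum('ijk,jk->ij', u_hat, v)               # line 7: agreement update
    return v

v = routing_np(0.01 * np.random.randn(1152, 10, 16))
print(v.shape)  # (10, 16)
```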


def squash(vector):
'''Squashing function.
'''Squashing function corresponding to Eq. 1
Args:
vector: A 4-D tensor with shape [batch_size, num_caps, vec_len, 1],
vector: A 5-D tensor with shape [batch_size, 1, num_caps, vec_len, 1],
Returns:
A 4-D tensor with the same shape as vector but
squashed in 3rd and 4th dimensions.
A 5-D tensor with the same shape as vector but squashed in the 4th and 5th dimensions.
'''
vec_abs = tf.sqrt(tf.reduce_sum(tf.square(vector))) # a scalar
scalar_factor = tf.square(vec_abs) / (1 + tf.square(vec_abs))
vec_squashed = scalar_factor * tf.divide(vector, vec_abs) # element-wise
vec_squared_norm = tf.reduce_sum(tf.square(vector), -2, keep_dims=True)
scalar_factor = vec_squared_norm / (1 + vec_squared_norm) / tf.sqrt(vec_squared_norm)
vec_squashed = scalar_factor * vector # element-wise
return(vec_squashed)
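
For reference, the scalar_factor above is just Eq. 1 of the paper rearranged:

```latex
v_j = \frac{\|s_j\|^2}{1 + \|s_j\|^2} \cdot \frac{s_j}{\|s_j\|}
```

i.e. scalar_factor = ||s_j||^2 / (1 + ||s_j||^2) / ||s_j||, which multiplied by the raw vector s_j gives the squashed output v_j.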
19 changes: 8 additions & 11 deletions capsNet.py
@@ -2,7 +2,7 @@

from config import cfg
from utils import get_batch_data
from capsLayer import CapsConv
from capsLayer import CapsLayer


class CapsNet(object):
@@ -34,19 +34,16 @@ def build_arch(self):
padding='VALID')
assert conv1.get_shape() == [cfg.batch_size, 20, 20, 256]

# TODO: Rewrite the 'CapsConv' class as a function; the capsLayer
# function should be encapsulated into two functions, one like conv2d
# and another like fully_connected in Tensorflow.
# Primary Capsules, [batch_size, 1152, 8, 1]
# Primary Capsules layer, return [batch_size, 1152, 8, 1]
with tf.variable_scope('PrimaryCaps_layer'):
primaryCaps = CapsConv(num_units=8, with_routing=False)
caps1 = primaryCaps(conv1, num_outputs=32, kernel_size=9, stride=2)
primaryCaps = CapsLayer(num_outputs=32, vec_len=8, with_routing=False, layer_type='CONV')
caps1 = primaryCaps(conv1, kernel_size=9, stride=2)
assert caps1.get_shape() == [cfg.batch_size, 1152, 8, 1]

# DigitCaps layer, [batch_size, 10, 16, 1]
# DigitCaps layer, return [batch_size, 10, 16, 1]
with tf.variable_scope('DigitCaps_layer'):
digitCaps = CapsConv(num_units=16, with_routing=True)
self.caps2 = digitCaps(caps1, num_outputs=10)
digitCaps = CapsLayer(num_outputs=10, vec_len=16, with_routing=True, layer_type='FC')
self.caps2 = digitCaps(caps1)

# Decoder structure in Fig. 2
# 1. Do masking, how:
@@ -60,7 +57,7 @@ def build_arch(self):

# b). pick out the index of max softmax val of the 10 caps
# [batch_size, 10, 1, 1] => [batch_size] (index)
argmax_idx = tf.argmax(self.softmax_v, axis=1, output_type=tf.int32)
argmax_idx = tf.to_int32(tf.argmax(self.softmax_v, axis=1))
assert argmax_idx.get_shape() == [cfg.batch_size, 1, 1]

# c). indexing
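
The indexing code itself is collapsed in this view; a minimal sketch of step c) using a one-hot mask (illustrative names, not necessarily the file's exact code):

```python
# c) keep only the winning capsule's 16-D activity vector per example,
#    here by zeroing the other 9 digit capsules with a one-hot mask
argmax_idx = tf.reshape(argmax_idx, shape=(cfg.batch_size,))                     # [batch_size]
mask = tf.reshape(tf.one_hot(argmax_idx, depth=10), (cfg.batch_size, 10, 1, 1))
self.masked_v = self.caps2 * mask     # [batch_size, 10, 16, 1], non-winners zeroed
```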
8 changes: 6 additions & 2 deletions config.py
@@ -11,10 +11,14 @@
flags.DEFINE_float('m_plus', 0.9, 'the parameter of m plus')
flags.DEFINE_float('m_minus', 0.1, 'the parameter of m minus')
flags.DEFINE_float('lambda_val', 0.5, 'down weight of the loss for absent digit classes')
flags.DEFINE_integer('batch_size', 128, 'batch size')
flags.DEFINE_integer('epoch', 500, 'epoch')

# for training
flags.DEFINE_integer('batch_size', 32, 'batch size')
flags.DEFINE_integer('epoch', 50, 'epoch')
flags.DEFINE_integer('iter_routing', 3, 'number of iterations in routing algorithm')

flags.DEFINE_float('stddev', 0.01, 'stddev for W initializer')
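
For context, m_plus, m_minus and lambda_val are the constants of the margin loss (Eq. 4 of the paper). A minimal sketch of how they enter the loss, with v_length (capsule lengths, [batch_size, 10]) and T_k (one-hot labels) as illustrative placeholders rather than this repo's exact variables:

```python
# Eq. 4: L_k = T_k * max(0, m+ - ||v_k||)^2 + lambda * (1 - T_k) * max(0, ||v_k|| - m-)^2
max_l = tf.square(tf.maximum(0., cfg.m_plus - v_length))   # penalize short vectors of present classes
max_r = tf.square(tf.maximum(0., v_length - cfg.m_minus))  # penalize long vectors of absent classes
L_k = T_k * max_l + cfg.lambda_val * (1. - T_k) * max_r
margin_loss = tf.reduce_mean(tf.reduce_sum(L_k, axis=1))   # sum over the 10 classes, mean over batch
```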


############################
# environment setting #
