diff --git a/README.md b/README.md
index 5e69fcf..6422232 100644
--- a/README.md
+++ b/README.md
@@ -37,6 +37,7 @@ my weChat: ![my_wechat](/imgs/my_wechat_QR.png)
 
 - NumPy
 - [Tensorflow](https://github.com/tensorflow/tensorflow) (I'm using 1.3.0, others should work, too)
 - tqdm (for showing training progress info)
+- scipy (for saving images)
 
 ## Usage
@@ -83,6 +84,7 @@ Results for the 'wrong' version(Issues #8):
 ![reconstruction_loss](imgs/reconstruction_loss.png)
 
 - test acc
+
 |Epoch|49|51|
 |:----:|:----:|:--:|
 |test acc|94.69|94.71|
diff --git a/capsLayer.py b/capsLayer.py
index 3599431..6b4925f 100644
--- a/capsLayer.py
+++ b/capsLayer.py
@@ -4,143 +4,144 @@ from config import cfg
 
 
-class CapsConv(object):
+class CapsLayer(object):
     ''' Capsule layer.
     Args:
         input: A 4-D tensor.
-        num_units: integer, the length of the output vector of a capsule.
+        num_outputs: the number of capsules in this layer.
+        vec_len: integer, the length of the output vector of a capsule.
+        layer_type: string, one of 'FC' or "CONV", the type of this layer,
+            fully connected or convolutional, for future expansion capability
         with_routing: boolean, this capsule is routing with the
                       lower-level layer capsule.
-        num_outputs: the number of capsule in this layer.
 
     Returns:
         A 4-D tensor.
     '''
-    def __init__(self, num_units, with_routing=True):
-        self.num_units = num_units
-        self.with_routing = with_routing
-
-    def __call__(self, input, num_outputs, kernel_size=None, stride=None):
+    def __init__(self, num_outputs, vec_len, with_routing=True, layer_type='FC'):
         self.num_outputs = num_outputs
-        self.kernel_size = kernel_size
-        self.stride = stride
-
-        if not self.with_routing:
-            # the PrimaryCaps layer
-            # input: [batch_size, 20, 20, 256]
-            assert input.get_shape() == [cfg.batch_size, 20, 20, 256]
-
-            capsules = []
-            for i in range(self.num_units):
-                # each capsule i: [batch_size, 6, 6, 32]
-                with tf.variable_scope('ConvUnit_' + str(i)):
-                    caps_i = tf.contrib.layers.conv2d(input,
-                                                      self.num_outputs,
-                                                      self.kernel_size,
-                                                      self.stride,
-                                                      padding="VALID")
-                    caps_i = tf.reshape(caps_i, shape=(cfg.batch_size, -1, 1, 1))
-                    capsules.append(caps_i)
-
-            assert capsules[0].get_shape() == [cfg.batch_size, 1152, 1, 1]
-
-            # [batch_size, 1152, 8, 1]
-            capsules = tf.concat(capsules, axis=2)
-            capsules = squash(capsules)
-            assert capsules.get_shape() == [cfg.batch_size, 1152, 8, 1]
-
-        else:
-            # the DigitCaps layer
-            # Reshape the input into shape [batch_size, 1152, 8, 1]
-            self.input = tf.reshape(input, shape=(cfg.batch_size, 1152, 8, 1))
-
-            # b_IJ: [1, num_caps_l, num_caps_l_plus_1, 1]
-            b_IJ = tf.zeros(shape=[1, 1152, 10, 1], dtype=np.float32)
-            capsules = []
-            for j in range(self.num_outputs):
-                with tf.variable_scope('caps_' + str(j)):
-                    caps_j, b_IJ = capsule(input, b_IJ, j)
-                    capsules.append(caps_j)
-
-            # Return a tensor with shape [batch_size, 10, 16, 1]
-            capsules = tf.concat(capsules, axis=1)
-            assert capsules.get_shape() == [cfg.batch_size, 10, 16, 1]
-
-        return(capsules)
-
-
-def capsule(input, b_IJ, idx_j):
-    ''' The routing algorithm for one capsule in the layer l+1.
+        self.vec_len = vec_len
+        self.with_routing = with_routing
+        self.layer_type = layer_type
+
+    def __call__(self, input, kernel_size=None, stride=None):
+        '''
+        The parameters 'kernel_size' and 'stride' will be used only when 'layer_type' equals 'CONV'
+        '''
+        if self.layer_type == 'CONV':
+            self.kernel_size = kernel_size
+            self.stride = stride
+
+            if not self.with_routing:
+                # the PrimaryCaps layer, a convolutional layer
+                # input: [batch_size, 20, 20, 256]
+                assert input.get_shape() == [cfg.batch_size, 20, 20, 256]
+
+                capsules = []
+                for i in range(self.vec_len):
+                    # each capsule i: [batch_size, 6, 6, 32]
+                    with tf.variable_scope('ConvUnit_' + str(i)):
+                        caps_i = tf.contrib.layers.conv2d(input, self.num_outputs,
+                                                          self.kernel_size, self.stride,
+                                                          padding="VALID")
+                        caps_i = tf.reshape(caps_i, shape=(cfg.batch_size, -1, 1, 1))
+                        capsules.append(caps_i)
+
+                assert capsules[0].get_shape() == [cfg.batch_size, 1152, 1, 1]
+
+                # [batch_size, 1152, 8, 1]
+                capsules = tf.concat(capsules, axis=2)
+                capsules = squash(capsules)
+                assert capsules.get_shape() == [cfg.batch_size, 1152, 8, 1]
+                return(capsules)
+
+        if self.layer_type == 'FC':
+            if self.with_routing:
+                # the DigitCaps layer, a fully connected layer
+                # Reshape the input into [batch_size, 1152, 1, 8, 1]
+                self.input = tf.reshape(input, shape=(cfg.batch_size, 1152, 1, 8, 1))
+
+                with tf.variable_scope('routing'):
+                    # b_IJ: [1, num_caps_l, num_caps_l_plus_1, 1, 1]
+                    b_IJ = tf.zeros(shape=[1, 1152, 10, 1, 1], dtype=np.float32)
+                    capsules = routing(self.input, b_IJ)
+                    capsules = tf.squeeze(capsules, axis=1)
+
+            return(capsules)
+
+
+def routing(input, b_IJ):
+    ''' The routing algorithm.
 
     Args:
-        input: A Tensor with [batch_size, num_caps_l=1152, length(u_i)=8, 1]
+        input: A Tensor with [batch_size, num_caps_l=1152, 1, length(u_i)=8, 1]
               shape, num_caps_l meaning the number of capsule in the layer l.
     Returns:
-        A Tensor of shape [batch_size, 1, length(v_j)=16, 1] representing the
-        vector output `v_j` of capsule j in the layer l+1
+        A Tensor of shape [batch_size, 1, num_caps_l_plus_1, length(v_j)=16, 1]
+        representing the vector output `v_j` in the layer l+1
     Notes:
        u_i represents the vector output of capsule i in the layer l, and
       v_j the vector output of capsule j in the layer l+1.
     '''
-    with tf.variable_scope('routing'):
-        w_initializer = np.random.normal(size=[1, 1152, 8, 16], scale=0.01)
-        W_Ij = tf.Variable(w_initializer, dtype=tf.float32)
-        # repeat W_Ij with batch_size times to shape [batch_size, 1152, 8, 16]
-        W_Ij = tf.tile(W_Ij, [cfg.batch_size, 1, 1, 1])
-
-        # calc u_hat
-        # [8, 16].T x [8, 1] => [16, 1] => [batch_size, 1152, 16, 1]
-        u_hat = tf.matmul(W_Ij, input, transpose_a=True)
-        assert u_hat.get_shape() == [cfg.batch_size, 1152, 16, 1]
-
-        shape = b_IJ.get_shape().as_list()
-        size_splits = [idx_j, 1, shape[2] - idx_j - 1]
-        for r_iter in range(cfg.iter_routing):
-            # line 4:
-            # [1, 1152, 10, 1]
-            c_IJ = tf.nn.softmax(b_IJ, dim=2)
-            assert c_IJ.get_shape() == [1, 1152, 10, 1]
-
-            # line 5:
-            # weighting u_hat with c_I in the third dim,
-            # then sum in the second dim, resulting in [batch_size, 1, 16, 1]
-            b_Il, b_Ij, b_Ir = tf.split(b_IJ, size_splits, axis=2)
-            c_Il, c_Ij, b_Ir = tf.split(c_IJ, size_splits, axis=2)
-            assert c_Ij.get_shape() == [1, 1152, 1, 1]
-
-            s_j = tf.multiply(c_Ij, u_hat)
-            s_j = tf.reduce_sum(tf.multiply(c_Ij, u_hat),
-                                axis=1, keep_dims=True)
-            assert s_j.get_shape() == [cfg.batch_size, 1, 16, 1]
-
-            # line 6:
-            # squash using Eq.1, resulting in [batch_size, 1, 16, 1]
-            v_j = squash(s_j)
-            assert s_j.get_shape() == [cfg.batch_size, 1, 16, 1]
-
-            # line 7:
-            # tile v_j from [batch_size ,1, 16, 1] to [batch_size, 1152, 16, 1]
-            # [16, 1].T x [16, 1] => [1, 1], then reduce mean in the
-            # batch_size dim, resulting in [1, 1152, 1, 1]
-            v_j_tiled = tf.tile(v_j, [1, 1152, 1, 1])
-            u_produce_v = tf.matmul(u_hat, v_j_tiled, transpose_a=True)
-            assert u_produce_v.get_shape() == [cfg.batch_size, 1152, 1, 1]
-            b_Ij += tf.reduce_sum(u_produce_v, axis=0, keep_dims=True)
-            b_IJ = tf.concat([b_Il, b_Ij, b_Ir], axis=2)
-
-    return(v_j, b_IJ)
+    # W: [1, num_caps_i, num_caps_j, len_u_i, len_v_j]
+    W = tf.get_variable('Weight', shape=(1, 1152, 10, 8, 16), dtype=tf.float32,
+                        initializer=tf.random_normal_initializer(stddev=cfg.stddev))
+
+    # Eq.2, calc u_hat
+    # do tiling for input and W before matmul
+    # input => [batch_size, 1152, 10, 8, 1]
+    # W => [batch_size, 1152, 10, 8, 16]
+    input = tf.tile(input, [1, 1, 10, 1, 1])
+    W = tf.tile(W, [cfg.batch_size, 1, 1, 1, 1])
+    assert input.get_shape() == [cfg.batch_size, 1152, 10, 8, 1]
+
+    # in last 2 dims:
+    # [8, 16].T x [8, 1] => [16, 1] => [batch_size, 1152, 10, 16, 1]
+    u_hat = tf.matmul(W, input, transpose_a=True)
+    assert u_hat.get_shape() == [cfg.batch_size, 1152, 10, 16, 1]
+
+    # line 3, for r iterations do
+    for r_iter in range(cfg.iter_routing):
+        # line 4:
+        # => [1, 1152, 10, 1, 1]
+        c_IJ = tf.nn.softmax(b_IJ, dim=3)
+        c_IJ = tf.tile(c_IJ, [cfg.batch_size, 1, 1, 1, 1])
+        assert c_IJ.get_shape() == [cfg.batch_size, 1152, 10, 1, 1]
+
+        # line 5:
+        # weighting u_hat with c_IJ, element-wise in the last two dims
+        # => [batch_size, 1152, 10, 16, 1]
+        s_J = tf.multiply(c_IJ, u_hat)
+        # then sum in the second dim, resulting in [batch_size, 1, 10, 16, 1]
+        s_J = tf.reduce_sum(s_J, axis=1, keep_dims=True)
+        assert s_J.get_shape() == [cfg.batch_size, 1, 10, 16, 1]
+
+        # line 6:
+        # squash using Eq.1
+        v_J = squash(s_J)
+        assert v_J.get_shape() == [cfg.batch_size, 1, 10, 16, 1]
+
+        # line 7:
+        # tile v_J from [batch_size, 1, 10, 16, 1] to [batch_size, 1152, 10, 16, 1],
+        # then matmul in the last two dims: [16, 1].T x [16, 1] => [1, 1], reduce sum in the
+        # batch_size dim, resulting in [1, 1152, 10, 1, 1]
+        v_J_tiled = tf.tile(v_J, [1, 1152, 1, 1, 1])
+        u_produce_v = tf.matmul(u_hat, v_J_tiled, transpose_a=True)
+        assert u_produce_v.get_shape() == [cfg.batch_size, 1152, 10, 1, 1]
+        b_IJ += tf.reduce_sum(u_produce_v, axis=0, keep_dims=True)
+
+    return(v_J)
 
 
 def squash(vector):
-    '''Squashing function.
+    '''Squashing function corresponding to Eq. 1
     Args:
-        vector: A 4-D tensor with shape [batch_size, num_caps, vec_len, 1],
+        vector: A 5-D tensor with shape [batch_size, 1, num_caps, vec_len, 1],
     Returns:
-        A 4-D tensor with the same shape as vector but
-        squashed in 3rd and 4th dimensions.
+        A 5-D tensor with the same shape as vector but squashed in 4th and 5th dimensions.
     '''
-    vec_abs = tf.sqrt(tf.reduce_sum(tf.square(vector))) # a scalar
-    scalar_factor = tf.square(vec_abs) / (1 + tf.square(vec_abs))
-    vec_squashed = scalar_factor * tf.divide(vector, vec_abs) # element-wise
+    vec_squared_norm = tf.reduce_sum(tf.square(vector), -2, keep_dims=True)
+    scalar_factor = vec_squared_norm / (1 + vec_squared_norm) / tf.sqrt(vec_squared_norm)
+    vec_squashed = scalar_factor * vector # element-wise
     return(vec_squashed)
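The rewritten `routing` and `squash` above do all of their work on 5-D tensors, so the shape bookkeeping is the hardest part to follow. Below is a minimal NumPy sketch (not part of the patch; toy sizes and variable names are illustrative only) of the same pieces: Eq. 2's prediction vectors `u_hat`, the routing loop (softmax over `b_IJ`, weighted sum, squash, agreement update), and Eq. 1's squash nonlinearity.

```python
import numpy as np

def squash(s, axis=-2, eps=1e-9):
    # Eq. 1: v = (||s||^2 / (1 + ||s||^2)) * (s / ||s||), applied along `axis`
    squared_norm = np.sum(np.square(s), axis=axis, keepdims=True)
    return squared_norm / (1.0 + squared_norm) * s / np.sqrt(squared_norm + eps)

def softmax(x, axis):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

# toy sizes instead of 1152/10/8/16: 4 input capsules (len 3) -> 2 output capsules (len 5)
batch, caps_i, caps_j, len_u, len_v, iters = 2, 4, 2, 3, 5, 3

u = np.random.randn(batch, caps_i, 1, len_u, 1)        # u_i, outputs of layer l
W = np.random.randn(1, caps_i, caps_j, len_u, len_v)   # transform matrices W_ij
# Eq. 2: u_hat_{j|i} = W_ij^T u_i, computed in the last two dims
u_hat = np.matmul(np.swapaxes(W, -2, -1), np.tile(u, (1, 1, caps_j, 1, 1)))
b = np.zeros((1, caps_i, caps_j, 1, 1))                # routing logits b_IJ

for _ in range(iters):
    c = softmax(b, axis=2)                             # coupling coefficients c_IJ
    s = np.sum(c * u_hat, axis=1, keepdims=True)       # [batch, 1, caps_j, len_v, 1]
    v = squash(s)                                      # v_J
    # agreement u_hat . v, summed over the batch dim, updates b
    agreement = np.matmul(np.swapaxes(u_hat, -2, -1), np.tile(v, (1, caps_i, 1, 1, 1)))
    b = b + np.sum(agreement, axis=0, keepdims=True)

print(v.shape)  # (2, 1, 2, 5, 1) == [batch_size, 1, num_caps_j, len_v, 1]
```

With the real sizes this is exactly the [batch_size, 1152, 10, 8, 1] x [batch_size, 1152, 10, 8, 16] bookkeeping asserted in `routing` above.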
diff --git a/capsNet.py b/capsNet.py
index 1cfae25..6f74d22 100644
--- a/capsNet.py
+++ b/capsNet.py
@@ -2,7 +2,7 @@
 
 from config import cfg
 from utils import get_batch_data
-from capsLayer import CapsConv
+from capsLayer import CapsLayer
 
 
 class CapsNet(object):
@@ -34,19 +34,16 @@ def build_arch(self):
                                              padding='VALID')
         assert conv1.get_shape() == [cfg.batch_size, 20, 20, 256]
 
-        # TODO: Rewrite the 'CapsConv' class as a function, the capsLay
-        # function should be encapsulated into tow function, one like conv2d
-        # and another is fully_connected in Tensorflow.
-        # Primary Capsules, [batch_size, 1152, 8, 1]
+        # Primary Capsules layer, return [batch_size, 1152, 8, 1]
         with tf.variable_scope('PrimaryCaps_layer'):
-            primaryCaps = CapsConv(num_units=8, with_routing=False)
-            caps1 = primaryCaps(conv1, num_outputs=32, kernel_size=9, stride=2)
+            primaryCaps = CapsLayer(num_outputs=32, vec_len=8, with_routing=False, layer_type='CONV')
+            caps1 = primaryCaps(conv1, kernel_size=9, stride=2)
             assert caps1.get_shape() == [cfg.batch_size, 1152, 8, 1]
 
-        # DigitCaps layer, [batch_size, 10, 16, 1]
+        # DigitCaps layer, return [batch_size, 10, 16, 1]
         with tf.variable_scope('DigitCaps_layer'):
-            digitCaps = CapsConv(num_units=16, with_routing=True)
-            self.caps2 = digitCaps(caps1, num_outputs=10)
+            digitCaps = CapsLayer(num_outputs=10, vec_len=16, with_routing=True, layer_type='FC')
+            self.caps2 = digitCaps(caps1)
 
         # Decoder structure in Fig. 2
         # 1. Do masking, how:
@@ -60,7 +57,7 @@ def build_arch(self):
 
             # b). pick out the index of max softmax val of the 10 caps
             # [batch_size, 10, 1, 1] => [batch_size] (index)
-            argmax_idx = tf.argmax(self.softmax_v, axis=1, output_type=tf.int32)
+            argmax_idx = tf.to_int32(tf.argmax(self.softmax_v, axis=1))
             assert argmax_idx.get_shape() == [cfg.batch_size, 1, 1]
 
             # c). indexing
diff --git a/config.py b/config.py
index f1d3d40..00ff226 100644
--- a/config.py
+++ b/config.py
@@ -11,10 +11,14 @@
 flags.DEFINE_float('m_plus', 0.9, 'the parameter of m plus')
 flags.DEFINE_float('m_minus', 0.1, 'the parameter of m minus')
 flags.DEFINE_float('lambda_val', 0.5, 'down weight of the loss for absent digit classes')
-flags.DEFINE_integer('batch_size', 128, 'batch size')
-flags.DEFINE_integer('epoch', 500, 'epoch')
+
+# for training
+flags.DEFINE_integer('batch_size', 32, 'batch size')
+flags.DEFINE_integer('epoch', 50, 'epoch')
 flags.DEFINE_integer('iter_routing', 3, 'number of iterations in routing algorithm')
+flags.DEFINE_float('stddev', 0.01, 'stddev for W initializer')
+
 
 ############################
 # environment setting #
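The `m_plus`, `m_minus` and `lambda_val` flags kept as context in this hunk feed the margin loss of Eq. 4 in the CapsNet paper; that loss is defined outside this diff, so the following is only an illustrative NumPy sketch of how the three values enter it, not the repository's implementation.

```python
import numpy as np

def margin_loss(v_len, T, m_plus=0.9, m_minus=0.1, lambda_val=0.5):
    # Eq. 4: L_k = T_k * max(0, m+ - ||v_k||)^2 + lambda * (1 - T_k) * max(0, ||v_k|| - m-)^2
    present = np.maximum(0.0, m_plus - v_len) ** 2
    absent = np.maximum(0.0, v_len - m_minus) ** 2
    return np.mean(np.sum(T * present + lambda_val * (1 - T) * absent, axis=1))

# v_len: lengths of the 10 DigitCaps vectors; a confident, correct prediction for digit 3
v_len = np.full((1, 10), 0.05)
v_len[0, 3] = 0.95
T = np.zeros((1, 10))
T[0, 3] = 1.0
print(margin_loss(v_len, T))  # 0.0 -- both hinge terms are inactive
```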