# CapsNet in Chainer

* Original TensorFlow implementation:
https://github.com/Sarasra/models/tree/master/research/capsules

* Paper: Sabour, Frosst & Hinton, "Dynamic Routing Between Capsules" (2017):
https://arxiv.org/abs/1710.09829



In [26]:
import chainer
from chainer import initializers
from chainer import functions as F, links as L

In [27]:
import cupy
import numpy as np
# Array module used throughout this notebook (e.g. to allocate the routing
# logits). It is hard-wired to CuPy, so a CUDA-capable GPU is assumed.
xp = cupy

In [28]:
class CapsNet(chainer.Chain):
    """CapsNet classifier: Conv1 -> primary capsules -> routed class capsules.

    ``__call__`` returns the length of every class capsule, one score per
    class, with shape ``(batch, n_out)``.

    Args:
        image_channel: Number of channels of the input images.
        n_out: Number of output classes (one 16-D capsule per class).
    """

    def __init__(self, image_channel=1, n_out=10):
        super(CapsNet, self).__init__()
        with self.init_scope():
            # Conv1: 256 9x9 kernels, stride 1 (ReLU is applied in __call__).
            self.conv1 = L.Convolution2D(image_channel, out_channels=256, ksize=9, stride=1)
            # PrimaryCaps: 32 channels of 8-D convolutional capsules fed by the
            # 256 Conv1 feature maps.
            # NOTE(review): the paper uses stride 2 for this layer; stride 1 is
            # kept here to preserve this notebook's configuration.
            self.capsconv = CapsConv(in_channels=1, input_dim_vec=256,
                                     out_channels=32, dim_vec=8, ksize=9, stride=1)
            # Class capsules: one 16-D capsule per output class.
            self.caps = CapsuleLayer(in_channels=32, input_dim_vec=8,
                                     out_channels=n_out, dim_vec=16)

    def __call__(self, inputs):
        h = F.relu(self.conv1(inputs))
        h = self.capsconv(h)
        # Assumes self.caps returns (batch, dim_vec, n_out), i.e. the capsule
        # vectors lie along axis 1 -- TODO confirm if the layer layout changes.
        h = self.caps(h)
        # Capsule length computed with chainer functions so gradients flow
        # back through it (xp.linalg.norm works on raw arrays and cuts the
        # graph). The epsilon keeps the sqrt gradient finite for zero vectors.
        return F.sqrt(F.sum(h * h, axis=1) + 1e-9)
    

In [29]:
 def _squash(inputs):
        norm = xp.linalg.norm(inputs, axis=2, keepdims=True)
        norm_squred = norm * norm
        return (inputs / norm) * (norm_squred / (1 + norm_squred))

In [30]:
def _leaky_routing(logits):
    """Softmax over axis 2 with an extra all-zero "leak" logit prepended.

    The leak slot competes in the softmax and is then dropped, so it can absorb
    part of the routing weight that would otherwise go to a real output.
    """
    leak_shape = logits.shape[:2] + (1,) + logits.shape[3:]
    leak = xp.zeros(leak_shape, dtype=logits.dtype)
    leaky_logits = F.concat((leak, logits), axis=2)
    return F.softmax(leaky_logits, axis=2)[:, :, 1:]


def _update_routing(votes, bias, logit_shape, num_dims, input_dim, num_routing=3, leaky=False,
                    output_dim=None):
    """Route votes from input capsules to output capsules (dynamic routing).

    Args:
        votes: Variable shaped ``(batch, input_dim, output_dim, atoms, ...)``.
            Trailing axes (e.g. height/width for convolutional capsules) are
            carried through unchanged.
        bias: Variable broadcastable to ``(batch, output_dim, atoms, ...)``,
            added to the routed pre-activations.
        logit_shape: Shape of the routing logits,
            ``(batch, input_dim, output_dim, ...)``.
        num_dims: Number of dimensions of ``votes``.
        input_dim: Number of input capsules. Kept for interface compatibility;
            explicit broadcasting makes it unnecessary.
        num_routing: Number of routing iterations; must be at least 1.
        leaky: If True, route through ``_leaky_routing`` instead of a plain
            softmax over the output axis.
        output_dim: Accepted (and unused) so callers passing it keep working.

    Returns:
        The squashed output-capsule activation after the last iteration,
        shaped ``(batch, output_dim, atoms, ...)``.

    Raises:
        ValueError: If ``num_routing`` is less than 1.
    """
    if num_routing < 1:
        raise ValueError('num_routing must be >= 1, got {}'.format(num_routing))

    extra_axes = list(range(4, num_dims))
    # Move the atom axis to the front so routing weights, which are shaped
    # like the logits, broadcast over it: (atoms, batch, in, out, ...).
    votes_t_shape = [3, 0, 1, 2] + extra_axes
    # Inverse permutation, back to (batch, in, out, atoms, ...).
    r_t_shape = [1, 2, 3, 0] + extra_axes
    votes_trans = F.transpose(votes, votes_t_shape)

    # Wrapped in a Variable so `logits + distances` is a chainer op rather
    # than an in-place update of a raw array.
    logits = chainer.Variable(xp.zeros(logit_shape, dtype=np.float32))
    act = None
    for _ in range(num_routing):
        route = _leaky_routing(logits) if leaky else F.softmax(logits, axis=2)
        route = F.broadcast_to(route, votes_trans.shape)
        pre_act_t = F.transpose(route * votes_trans, r_t_shape)
        pre_act = F.sum(pre_act_t, axis=1)  # sum over input capsules
        pre_act = pre_act + F.broadcast_to(bias, pre_act.shape)
        act = _squash(pre_act)

        # Agreement: dot product of every vote with the current output,
        # taken over the atom axis, accumulated into the logits.
        act_rep = F.broadcast_to(F.expand_dims(act, 1), votes.shape)
        distances = F.sum(votes * act_rep, axis=3)
        logits = logits + distances
    return act

In [31]:
class CapsConv(chainer.Link):
    """Convolutional capsule layer producing the primary capsules.

    A single convolution produces ``out_channels * dim_vec`` feature maps. These
    are reshaped into capsule votes and passed through dynamic routing.

    NOTE(review): the convolution consumes all ``in_channels * input_dim_vec``
    input maps at once, so the reshape below only works for
    ``in_channels == 1`` (the value used in this notebook). Other values
    raise a reshape error. Per-channel convolution is still a TODO.

    Args:
        in_channels: Number of input capsule types.
        input_dim_vec: Dimension of each input capsule (input feature maps).
        out_channels: Number of output capsule types.
        dim_vec: Dimension of each output capsule.
        ksize: Convolution kernel size.
        stride: Convolution stride.
    """

    def __init__(self, in_channels=1, input_dim_vec=256, out_channels=32, dim_vec=8, ksize=9, stride=1):
        super(CapsConv, self).__init__()
        self.in_channels = in_channels
        self.input_dim_vec = input_dim_vec
        self.out_channels = out_channels
        self.dim_vec = dim_vec

        with self.init_scope():
            # Registered inside init_scope so optimizers/serializers see it.
            # Its shape is known here, so it is allocated once rather than
            # re-initialised (reset to zero) on every forward pass.
            self.bias = chainer.Parameter(initializers.Zero(), (out_channels, dim_vec, 1, 1))
            # TODO: convolve each input capsule type separately.
            self.conv = L.Convolution2D(in_channels=in_channels*input_dim_vec, out_channels=out_channels*dim_vec, ksize=ksize, stride=stride)

    def __call__(self, inputs):
        """Return capsules shaped ``(batch, out_channels, dim_vec, height, width)``."""
        h = self.conv(inputs)
        batch, _, height, width = h.shape
        # Votes laid out as (batch, input caps, output caps, atoms, height, width).
        votes = F.reshape(h, (batch, self.in_channels, self.out_channels, self.dim_vec, height, width))
        # One routing logit per (input cap, output cap, spatial position).
        logit_shape = (batch, self.in_channels, self.out_channels, height, width)
        # Share each capsule's bias across every spatial position.
        bias_rep = F.tile(self.bias, (1, 1, height, width))
        return _update_routing(votes=votes,
                               bias=bias_rep,
                               logit_shape=logit_shape,
                               num_dims=6,
                               input_dim=self.in_channels,
                               num_routing=3)

In [32]:
class CapsuleLayer(chainer.Link):
    """Fully connected capsule layer with dynamic routing.

    Every input capsule (one per input channel and spatial position) predicts
    every output capsule through its own slice of ``W``. Those predictions
    ("votes") are combined by ``_update_routing``.

    Args:
        in_channels: Number of input capsule types (axis 1 of the input).
        input_dim_vec: Dimension of each input capsule.
        out_channels: Number of output capsules.
        dim_vec: Dimension of each output capsule.
        num_routing: Number of routing iterations.
    """

    def __init__(self, in_channels, input_dim_vec, out_channels, dim_vec, num_routing=3):
        super(CapsuleLayer, self).__init__()
        self.in_channels = in_channels
        self.input_dim_vec = input_dim_vec
        self.out_channels = out_channels
        self.dim_vec = dim_vec
        self.num_routing = num_routing

        with self.init_scope():
            # W's first axis depends on the input's spatial size, so both
            # parameters are allocated lazily on the first call.
            self.W = chainer.Parameter(initializer=initializers.GlorotNormal(),
                                       shape=None
                                      )
            self.bias = chainer.Parameter(initializer=initializers.Zero(),
                                          shape=None
                                         )

    def __call__(self, inputs):
        """Route ``inputs`` shaped ``(batch, in_channels, input_dim_vec, height, width)``.

        Returns:
            Variable shaped ``(batch, dim_vec, out_channels)``, i.e. each
            output capsule vector lies along axis 1.
        """
        self.input_shape = inputs.shape
        batch = inputs.shape[0]
        # (batch, channels, atoms, h, w) -> (batch, channels * h * w, atoms)
        capsule_atom_last = F.transpose(inputs, (0, 1, 3, 4, 2))
        capsule_3d = F.reshape(capsule_atom_last, (batch, -1, self.input_dim_vec))
        num_inputs = capsule_3d.shape[1]
        votes_dim = self.out_channels * self.dim_vec

        # Initialise only once; re-initialising on every call would throw away
        # the learned values.
        if self.W.data is None:
            self.W.initialize((num_inputs, self.input_dim_vec, votes_dim))
        if self.bias.data is None:
            self.bias.initialize((self.out_channels, self.dim_vec))

        # votes[b, i, k] = sum_a capsule_3d[b, i, a] * W[i, a, k]
        tiled_shape = (batch, num_inputs, self.input_dim_vec, votes_dim)
        input_tiled = F.broadcast_to(F.expand_dims(capsule_3d, 3), tiled_shape)
        w_tiled = F.broadcast_to(self.W, tiled_shape)
        votes = F.sum(input_tiled * w_tiled, axis=2)
        # Layout for routing: (batch, input caps, output caps, atoms).
        votes_reshaped = F.reshape(votes, (batch, num_inputs, self.out_channels, self.dim_vec))

        # routing algorithm
        logit_shape = (batch, num_inputs, self.out_channels)
        activations = _update_routing(votes=votes_reshaped,
                                      bias=self.bias,
                                      logit_shape=logit_shape,
                                      num_dims=4,
                                      input_dim=num_inputs,
                                      num_routing=self.num_routing
                                     )
        # (batch, out_channels, dim_vec) -> (batch, dim_vec, out_channels):
        # capsule vectors on axis 1, the layout the original code assumed.
        return F.transpose(activations, (0, 2, 1))


# Backward-compatible alias for the original (misspelled) class name.
CapuleLayer = CapsuleLayer
