In [12]:
import torch
import torch.nn as nn
import numpy as np

In [4]:
# Seed every RNG this notebook draws from so Restart & Run All is reproducible:
# NumPy generates the W_np / B_np / U_l_np fixtures, while PyTorch's RNG is used
# by torch.rand when FCCaps initialises its parameters.
np.random.seed(0)
torch.manual_seed(0)

# Self Attention Routing

In [5]:
# Toy problem sizes; the names match the einsum index legend used further down.
n_l, d_l = 3, 4  # input side:  U_l is built with shape (b, n_l, d_l)
n_h, d_h = 2, 5  # output side: U_h is documented as (n_h, d_h)
b = 1            # leading dimension of U_l (presumably the batch size)

In [6]:
# Uniform [0, 1) fixtures. The draw order (W, then B, then U_l) is significant:
# with the fixed NumPy seed it determines the exact values each array receives.
W_np = np.random.random(size=(n_l, n_h, d_l, d_h))  # one (d_l, d_h) matrix per (n_l, n_h) pair
B_np = np.random.random(size=(n_l, n_h))            # one scalar per (n_l, n_h) pair
U_l_np = np.random.random(size=(b, n_l, d_l))       # input vectors

In [7]:
# Dump the three fixtures in creation order, one after another.
print(W_np, B_np, U_l_np, sep="\n")

[[[[0.5488135  0.71518937 0.60276338 0.54488318 0.4236548 ]
   [0.64589411 0.43758721 0.891773   0.96366276 0.38344152]
   [0.79172504 0.52889492 0.56804456 0.92559664 0.07103606]
   [0.0871293  0.0202184  0.83261985 0.77815675 0.87001215]]

  [[0.97861834 0.79915856 0.46147936 0.78052918 0.11827443]
   [0.63992102 0.14335329 0.94466892 0.52184832 0.41466194]
   [0.26455561 0.77423369 0.45615033 0.56843395 0.0187898 ]
   [0.6176355  0.61209572 0.616934   0.94374808 0.6818203 ]]]


 [[[0.3595079  0.43703195 0.6976312  0.06022547 0.66676672]
   [0.67063787 0.21038256 0.1289263  0.31542835 0.36371077]
   [0.57019677 0.43860151 0.98837384 0.10204481 0.20887676]
   [0.16130952 0.65310833 0.2532916  0.46631077 0.24442559]]

  [[0.15896958 0.11037514 0.65632959 0.13818295 0.19658236]
   [0.36872517 0.82099323 0.09710128 0.83794491 0.09609841]
   [0.97645947 0.4686512  0.97676109 0.60484552 0.73926358]
   [0.03918779 0.28280696 0.12019656 0.2961402  0.11872772]]]


 [[[0.31798318 0.41426299 0.

#### PyTorch

In [8]:
# Wrap the NumPy fixtures as torch tensors (from_numpy keeps the NumPy dtype
# and shares memory with the source array instead of copying).
W, B, U_l = map(torch.from_numpy, (W_np, B_np, U_l_np))

In [9]:
"""
einsum conventions:
  n_l = i | h
  d_l = j
  n_h = k
  d_h = l
"""
# Project every input vector through its own weight matrix:
#   U_hat[..., i, k, l] = sum_j U_l[..., i, j] * W[i, k, j, l]
U_hat = torch.einsum('...ij,ikjl->...ikl', U_l, W)

# A (n_l, n_l, n_h): for each k, the dot product over l between every pair of
# projected vectors (h, i), divided by sqrt(d_l).
A = torch.einsum("...ikl,...hkl->...hik", U_hat, U_hat) / torch.sqrt(torch.Tensor([d_l]))

# Collapse the middle n_l axis, then normalise across the last (n_h) axis.
A_sum = torch.einsum("...hij->...hj", A)
C = A_sum.softmax(dim=-1)

# Add B to the softmaxed scores and use the result to weight-and-sum the
# projected vectors over the n_l axis.
CB = C + B
U_h = torch.einsum('...ikl,...ik->...kl', U_hat, CB)

In [43]:
class FCCaps(nn.Module):
    """Capsule layer that combines input vectors using self-attention routing.

    Each of the ``n_l`` input vectors (size ``d_l``) is projected once per
    output slot with its own weight matrix; the projections are then combined
    into ``n_h`` output vectors (size ``d_h``) using weights derived from the
    pairwise dot products of the projections plus a learnable term ``B``.

    Args:
        n_l: number of input vectors.
        n_h: number of output vectors.
        d_l: size of each input vector.
        d_h: size of each output vector.
    """

    def __init__(self, n_l, n_h, d_l, d_h):
        super().__init__()
        self.n_l = n_l
        self.d_l = d_l
        self.n_h = n_h
        self.d_h = d_h
        # Learnable parameters, initialised uniformly in [0, 1).
        self.W = torch.nn.Parameter(torch.rand(n_l, n_h, d_l, d_h))
        self.B = torch.nn.Parameter(torch.rand(n_l, n_h))

    def forward(self, U_l):
        """Route the input vectors to the output vectors.

        einsum conventions:
          n_l = i | h
          d_l = j
          n_h = k
          d_h = l

        Data tensors:
            U_l (..., n_l, d_l)
            U_h (..., n_h, d_h)
            W   (n_l, n_h, d_l, d_h)
            B   (n_l, n_h)
            A   (..., n_l, n_l, n_h)
            C   (..., n_l, n_h)

        Args:
            U_l: input tensor whose last two dimensions are ``(n_l, d_l)``;
                any leading dimensions are carried through unchanged. It is
                cast to the dtype of the layer's parameters.

        Returns:
            Tensor with the same leading dimensions and trailing shape
            ``(n_h, d_h)``, in the parameters' dtype.
        """
        # The parameters are float32 by default while data converted from NumPy
        # is float64; einsum needs matching dtypes, so follow the parameters.
        U_l = U_l.to(self.W.dtype)

        # Use the module's own parameters (self.W / self.B / self.d_l), never
        # same-named notebook globals, so training updates what forward uses.
        U_hat = torch.einsum('...ij,ikjl->...ikl', U_l, self.W)

        # A (n_l, n_l, n_h): scaled pairwise dot products between projections.
        A = torch.einsum("...ikl, ...hkl -> ...hik", U_hat, U_hat)
        A = A / (self.d_l ** 0.5)

        # Sum over the middle n_l axis, then softmax across the n_h axis.
        A_sum = torch.einsum("...hij->...hj", A)
        C = torch.softmax(A_sum, dim=-1)

        CB = C + self.B
        U_h = torch.einsum('...ikl,...ik->...kl', U_hat, CB)
        return U_h


In [39]:
# Build the layer with the toy sizes defined at the top of the notebook.
model = FCCaps(n_l=n_l, n_h=n_h, d_l=d_l, d_h=d_h)

In [40]:
# Forward pass on the (b, n_l, d_l) fixture; this rebinds U_h, which the
# standalone einsum cell above had already assigned.
U_h = model(U_l)

In [41]:
# Expected: (b, n_h, d_h)
U_h.shape

torch.Size([1, 2, 5])

In [None]:
# Scratch check: this CB is the module-level value from the standalone einsum
# cell (C + B); the CB inside FCCaps.forward is local and not visible here.
CB

In [None]:
# Scratch check: C from the standalone einsum cell; expected (b, n_l, n_h).
C.shape

In [None]:
# Scratch check: B is (n_l, n_h) as long as this runs before the TensorFlow
# section below rebinds it.
B.shape

#### TensorFlow

In [None]:
"""
 code from paper
 should give same results ;)
"""
# NOTE(review): `tf` is not imported in the visible import cell (only torch,
# torch.nn and numpy are) — this cell needs `import tensorflow as tf` to run.
# NOTE(review): W, B and U_l are rebound here from torch tensors to TensorFlow
# tensors; re-running the PyTorch cells after this one would pick these up.
W = tf.convert_to_tensor(W_np)
B = tf.convert_to_tensor(B_np)
U_l = tf.convert_to_tensor(U_l_np)
# Re-order the fixtures into the layout the reference code indexes:
# (n_l, n_h, d_l, d_h) -> (n_h, n_l, d_l, d_h)
W = tf.transpose(W, (1,0,2,3))

# (n_l, n_h) -> (n_h, n_l), then a trailing unit axis -> (n_h, n_l, 1)
B = tf.transpose(B, (1, 0))
B = tf.expand_dims(B, axis=-1)
# u: (b, n_h, n_l, d_h) — each input vector projected by its own weight matrix.
u = tf.einsum('...ji,kjiz->...kjz',U_l,W)
# c[..., i] = sum over k and j of u[..., i, j] * u[..., k, j]; [..., None] adds
# a trailing unit axis, giving (b, n_h, n_l, 1).
c = tf.einsum('...ij,...kj->...i', u, u)[...,None]
# Scale by sqrt(d_l); the cast matches the float64 tensors built from NumPy.
c = c/tf.sqrt(tf.cast(d_l, tf.float64))
# axis=1 is the n_h axis in the (b, n_h, n_l, 1) layout.
c = tf.nn.softmax(c, axis=1)
# B (n_h, n_l, 1) broadcasts against c (b, n_h, n_l, 1).
cb = c + B
# Weight the projections and sum over the n_l axis -> (b, n_h, d_h).
s = tf.reduce_sum(tf.multiply(u, cb),axis=-2) 

# Squashing