In [None]:
import torch
import torch.nn as nn
import numpy as np

In [None]:
# Seed every RNG this notebook draws from so a fresh "Restart & Run All"
# reproduces the same numbers: NumPy generates the test inputs, and PyTorch's
# global generator is what torch.rand uses to initialise the FCCaps parameters.
np.random.seed(0)
torch.manual_seed(0)

# Self Attention Routing

In [None]:
# Problem sizes shared by every cell in this notebook:
#   lower layer : n_l vectors, each of length d_l
#   higher layer: n_h vectors, each of length d_h
n_l, d_l = 3, 4
n_h, d_h = 2, 5
# Size of the leading (batch) axis of the input array.
b = 1

In [None]:
# Uniform [0, 1) samples from NumPy's global (seeded) generator.
# Draw order matters for reproducibility: W first, then B, then U_l.
W_np = np.random.rand(n_l, n_h, d_l, d_h)  # one (d_l, d_h) matrix per (lower, higher) pair
B_np = np.random.rand(n_l, n_h)            # one scalar per (lower, higher) pair
U_l_np = np.random.rand(b, n_l, d_l)       # batched lower-layer input

In [None]:
# Dump the three generated arrays, one after another, as plain text.
print(W_np, B_np, U_l_np, sep="\n")

#### PyTorch

In [None]:
# Wrap the NumPy arrays as PyTorch tensors. torch.from_numpy keeps the NumPy
# dtype and shares memory with the source array rather than copying it.
W, B, U_l = (torch.from_numpy(arr) for arr in (W_np, B_np, U_l_np))

In [None]:
"""
einsum convenventions:
  n_l = i | h
  d_l = j
  n_h = k
  d_h = l
"""
# Project each lower-layer vector with its own (d_l, d_h) matrix per higher unit:
#   U_l (..., n_l, d_l) x W (n_l, n_h, d_l, d_h) -> U_hat (..., n_l, n_h, d_h)
U_hat = torch.einsum('...ij,ikjl->...ikl', U_l, W)

# A (..., n_l, n_l, n_h): A[h, i, k] is the dot product of U_hat[i, k] and U_hat[h, k].
A = torch.einsum("...ikl, ...hkl -> ...hik", U_hat, U_hat)
# Scale the dot products by sqrt(d_l).
A = torch.div(A, torch.Tensor([d_l]).sqrt())
# Sum out the middle n_l axis, leaving (..., n_l, n_h).
A_sum = torch.einsum("...hij->...hj", A)
# Normalise across the n_h axis.
C = A_sum.softmax(dim=-1)
CB = torch.add(C, B)
# Weighted sum over the lower-layer index: U_h (..., n_h, d_h).
U_h = torch.einsum('...ikl,...ik->...kl', U_hat, CB)

In [None]:
class FCCaps(nn.Module):
    """Fully connected layer with self-attention routing.

    Maps ``n_l`` input vectors of length ``d_l`` to ``n_h`` output vectors of
    length ``d_h``.

    Learnable parameters:
        W (n_l, n_h, d_l, d_h): one (d_l, d_h) projection per (input, output) pair.
        B (n_l, n_h): term added to the routing coefficients.
    """

    def __init__(self, n_l, n_h, d_l, d_h):
        super().__init__()
        self.n_l = n_l
        self.d_l = d_l
        self.n_h = n_h
        self.d_h = d_h
        # Both parameters are initialised uniformly in [0, 1).
        self.W = torch.nn.Parameter(torch.rand(n_l, n_h, d_l, d_h))
        self.B = torch.nn.Parameter(torch.rand(n_l, n_h))

    def forward(self, U_l):
        """Route the input vectors ``U_l`` to the output vectors ``U_h``.

        einsum conventions:
          n_l = i | h
          d_l = j
          n_h = k
          d_h = l

        Data tensors (any number of leading batch dimensions is allowed):
            U_l (n_l, d_l)
            U_h (n_h, d_h)
            W   (n_l, n_h, d_l, d_h)
            B   (n_l, n_h)
            A   (n_l, n_l, n_h)
            C   (n_l, n_h)

        Args:
            U_l: tensor of shape (..., n_l, d_l). It is cast to the dtype of
                the layer's parameters, so e.g. float64 input built from
                NumPy arrays is accepted by a float32 layer.

        Returns:
            Tensor of shape (..., n_h, d_h) in the parameters' dtype.
        """
        # einsum does not promote mixed dtypes, so align the input with the weights.
        U_l = U_l.to(self.W.dtype)

        # Use the module's own parameters (not same-named notebook globals) so
        # the layer is self-contained and gradients reach W and B.
        U_hat = torch.einsum('...ij,ikjl->...ikl', U_l, self.W)

        # A (n_l, n_l, n_h): pairwise dot products of the projected vectors.
        A = torch.einsum("...ikl, ...hkl -> ...hik", U_hat, U_hat)

        # A plain Python scale keeps A's dtype and device; no temporary tensor needed.
        A = A / (self.d_l ** 0.5)
        A_sum = torch.einsum("...hij->...hj", A)
        C = torch.softmax(A_sum, dim=-1)
        CB = C + self.B
        U_h = torch.einsum('...ikl,...ik->...kl', U_hat, CB)
        return U_h


In [None]:
# Build the layer with the notebook-wide sizes; keywords make the argument order explicit.
model = FCCaps(n_l=n_l, n_h=n_h, d_l=d_l, d_h=d_h)

In [None]:
# Forward pass of the module on U_l. Note: this rebinds U_h, replacing the
# value computed by the plain-einsum cell earlier in the notebook.
U_h = model(U_l)

In [None]:
# Expect (b, n_h, d_h): the final einsum in forward contracts the n_l index.
U_h.shape

In [None]:
# This is the module-level CB from the plain-einsum cell; the CB inside
# FCCaps.forward is a local variable and never reaches this scope.
CB

In [None]:
# Module-level C (softmax output) from the plain-einsum cell: (b, n_l, n_h).
C.shape

In [None]:
# B as built from B_np with torch.from_numpy: (n_l, n_h).
B.shape

#### TensorFlow

In [None]:
"""
 code from paper
 should give same results ;)
"""
# NOTE(review): `tf` is used below, but no `import tensorflow as tf` is visible
# in this notebook's import cell — confirm it is imported before this cell runs.
# NOTE: W, B and U_l are rebound here to TensorFlow tensors, replacing the
# PyTorch tensors of the same names that were built from the same NumPy arrays.
W = tf.convert_to_tensor(W_np)
B = tf.convert_to_tensor(B_np)
U_l = tf.convert_to_tensor(U_l_np)
#
# (n_l, n_h, d_l, d_h) - > (n_h, n_l, d_l, d_h)
W = tf.transpose(W, (1,0,2,3))

# (n_l, n_h) -> (n_h, n_l) -> (n_h, n_l, 1); the trailing unit axis lets B
# broadcast against c and u below.
B = tf.transpose(B, (1, 0))
B = tf.expand_dims(B, axis=-1)
#
# Indices: j = n_l, i = d_l, k = n_h, z = d_h  ->  u is (b, n_h, n_l, d_h).
u = tf.einsum('...ji,kjiz->...kjz',U_l,W)
# For each n_l entry, sum its dot products with every n_l entry (contracts both
# d_h and the second n_l index); the added axis gives (b, n_h, n_l, 1).
c = tf.einsum('...ij,...kj->...i', u, u)[...,None]
# Scale by sqrt(d_l); the cast matches the float64 tensors built from NumPy.
c = c/tf.sqrt(tf.cast(d_l, tf.float64))
# axis=1 is the n_h axis here because U_l has exactly one leading batch dimension.
c = tf.nn.softmax(c, axis=1)
cb = c + B
# Weighted sum over the n_l axis -> s is (b, n_h, d_h).
s = tf.reduce_sum(tf.multiply(u, cb),axis=-2) 

# Squashing