### forward

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

#### funcs

In [None]:
def squash_func(x,eps=10e-21):
    """
    Squash capsule vectors: keep the direction, map the length to 1 - exp(-||x||).

    Input:
        x(b,n,d)  ... capsules, the norm is taken over the last dimension
        eps       ... small constant added to both denominators
    Output:
        y = squash(x(b,n,d)), same shape as x
    """
    x_norm = torch.norm(x, dim=-1, keepdim=True)
    # eps has to sit inside the denominators: added after the division it
    # cannot prevent 0/0 for an all-zero capsule and would only shift the result.
    scale = 1 - 1 / (torch.exp(x_norm) + eps)
    direction = x / (x_norm + eps)

    return scale * direction


def margin_loss(u, y_true, lbd=0.5, m_plus=0.9, m_minus=0.1):
    """
    Margin loss on capsule lengths with linear (non-squared) hinge terms.

    IN:
        u      (b,n,d)  ... capsules, n equals the number of classes
        y_true (b,n)    ... label vector in categorical (one-hot) form
        lbd             ... weight of the term for classes that are absent
        m_plus, m_minus ... upper / lower margin on the capsule length
    OUT:
        loss, scalar (summed over classes, averaged over the batch)
    """
    lengths = torch.norm(u, dim=2)
    # penalty for a present class whose capsule is shorter than m_plus
    too_short = F.relu(m_plus - lengths)
    # penalty for an absent class whose capsule is longer than m_minus
    too_long = F.relu(lengths - m_minus)
    per_class = y_true * too_short + lbd * (1.0 - y_true) * too_long
    return per_class.sum(dim=1).mean()


def margin_loss2(u, y_true, lbd=0.5, m_plus=0.9, m_minus=0.1):
    """
    Margin loss on capsule lengths with squared hinge terms.

    Input:  u      (b,n,d)  ... capsules, n equals the number of classes
            y_true (b,n)    ... label vector in categorical (one-hot) form
            lbd             ... weight of the term for classes that are absent
            m_plus, m_minus ... upper / lower margin on the capsule length
    Output:
        loss, scalar (summed over classes, averaged over the batch)
    """
    lengths = torch.norm(u, dim=-1)
    # squaring the hinge terms is what sets this apart from margin_loss
    present_term = torch.square(F.relu(m_plus - lengths))
    absent_term = torch.square(F.relu(lengths - m_minus))

    per_class = y_true * present_term + lbd * (1-y_true) * absent_term
    return per_class.sum(dim=1).mean()


def max_norm_masking(u):
    """
    Keep only the longest capsule of every sample and zero out all others.

    IN:
        u (b, n, d) ... capsules
    OUT:
        masked(u)  (b, n, d) where:
        - the norm is taken over dimension d of u
        - the capsule with the largest norm along dimension n is kept
        - every other capsule is set to zero
    """
    _, n_classes, _ = u.shape
    u_norm = torch.norm(u, dim=2)
    winner = torch.argmax(u_norm, dim=1)
    # one_hot yields an int64 tensor; cast it to u's dtype so the product does
    # not depend on mixed-dtype handling and the result keeps u's dtype.
    mask = torch.nn.functional.one_hot(winner, num_classes=n_classes).to(u.dtype)
    # (b, n, 1) mask broadcast over the capsule dimension d
    return u * mask.unsqueeze(-1)



In [None]:
class Squash(nn.Module):
    """Module form of squash_func, so the non-linearity can be used as a layer."""

    def __init__(self, eps=1e-20):
        super().__init__()
        # handed unchanged to squash_func on every forward pass
        self.eps = eps

    def forward(self, x):
        """
        Input:  x(b,n,d)
        Output: squash(x(b,n,d))
        """
        squashed = squash_func(x, self.eps)
        return squashed
    

class View(nn.Module):
    """
    Layer that reshapes its input to a fixed shape with Tensor.view.

    shape ... sequence of ints passed on to x.view (may contain -1)
    """
    def __init__(self, shape):
        # nn.Module keeps its hook/parameter registries on the instance;
        # without this call the layer cannot be invoked.
        super().__init__()
        self.shape = shape

    def forward(self, x):
        return x.view(*self.shape)


class PrimaryCaps(nn.Module):
    """
        Primary capsule layer: a depthwise convolution whose output is
        reshaped into N capsules of dimension D (no squashing is applied).

        Attributes
        ----------
        F: int depthwise conv number of features
        K: int depthwise conv kernel dimension 
        N: int number of primary capsules
        D: int primary capsules dimension (number of properties)
        s: int depthwise conv strides
    """
    def __init__(self, F, K, N, D, s=1):
        super().__init__()
        self.F = F
        self.K = K
        self.N = N
        self.D = D
        self.s = s
        # groups=F gives every input channel its own filter (depthwise conv)
        self.dw_conv2d = nn.Conv2d(F, F, kernel_size=K, stride=s, groups=F, padding="valid")

    def forward(self, x):
        """
        IN:  (B,C,H,W)
        OUT: (B, N, D)

        The convolution output must hold exactly N*D values per sample,
        e.g. (B,C,H,W) = (B,F,K,K) with F = N*D, which the convolution
        reduces to (B,F,1,1).

        Raises ValueError if the convolution output cannot be split into
        N capsules of dimension D per sample.
        """
        # (B,C,H,W) -> (B,C,H',W')
        x = self.dw_conv2d(x)

        # an unbatched (C,H,W) input is treated as a batch of one
        batch = x.size(0) if x.dim() == 4 else 1
        # Reshaping with -1 would silently push surplus spatial positions
        # into the batch axis, so the element count is checked first.
        if batch * self.N * self.D != x.numel():
            raise ValueError(
                f"conv output of shape {tuple(x.shape)} cannot be reshaped to "
                f"(B, {self.N}, {self.D})"
            )

        # (B,C,H',W') -> (B, N, D)
        return x.view(batch, self.N, self.D)

    
class PrimaryCaps2(nn.Module):
    """
        Create a primary capsule layer with the methodology described in 'Efficient-CapsNet: Capsule Network with Self-Attention Routing'. 
        Properties of each capsule s_n are extracted using a 2D depthwise convolution,
        then the capsules are squashed.

        ...

        Attributes
        ----------
        F: int depthwise conv number of features
        K: int depthwise conv kernel dimension 
        N: int number of primary capsules
        D: int primary capsules dimension (number of properties)
        s: int depthwise conv strides
    """
    def __init__(self, F, K, N, D, s=1):
        super().__init__()
        self.F = F
        self.K = K
        self.N = N
        self.D = D
        self.s = s
        # groups=F gives every input channel its own filter (depthwise conv)
        self.dw_conv2d = nn.Conv2d(F, F, kernel_size=K, stride=s, groups=F, padding="valid")
        self.squash = Squash()

    def forward(self, x):
        """
         X in (B,C,H,W), e.g. (B,F,K,K) with F = N*D
         -> (B, N, D)

         Raises ValueError if the convolution output cannot be split into
         N capsules of dimension D per sample.
        """
        # (B,C,H,W) -> (B,C,H',W')
        x = self.dw_conv2d(x)

        # an unbatched (C,H,W) input is treated as a batch of one
        batch = x.size(0) if x.dim() == 4 else 1
        # Reshaping with -1 would silently push surplus spatial positions
        # into the batch axis, so the element count is checked first.
        if batch * self.N * self.D != x.numel():
            raise ValueError(
                f"conv output of shape {tuple(x.shape)} cannot be reshaped to "
                f"(B, {self.N}, {self.D})"
            )

        # (B,C,H',W') -> (B, N, D)
        x = x.view(batch, self.N, self.D)
        # the squash is the only difference to PrimaryCaps
        x = self.squash(x)
        return x
    

class FCCaps(nn.Module):
    """
        Fully connected capsule layer with self-attention routing.

        Attributes
        ----------
        n_l ... number of lower layer capsules
        d_l ... dimension of lower layer capsules
        n_h ... number of higher layer capsules
        d_h ... dimension of higher layer capsules

        W   (n_l, n_h, d_l, d_h) ... weight tensor
        B   (n_l, n_h)           ... bias tensor
    """
    def __init__(self, n_l, n_h, d_l, d_h):
        super().__init__()
        self.n_l, self.d_l = n_l, d_l
        self.n_h, self.d_h = n_h, d_h

        self.W = nn.Parameter(torch.rand(n_l, n_h, d_l, d_h), requires_grad=True)
        self.B = nn.Parameter(torch.rand(n_l, n_h), requires_grad=True)
        self.squash = Squash(eps=1e-20)

        # The uniform values above are overwritten in place by Kaiming-normal
        # samples. The original author doubted that this scheme fits the
        # layer but kept it because the paper prescribes it.
        nn.init.kaiming_normal_(self.W, a=0, mode='fan_in', nonlinearity='leaky_relu')
        nn.init.kaiming_normal_(self.B, a=0, mode="fan_in", nonlinearity="leaky_relu")

        # divisor applied to the agreement tensor in forward
        self.attention_scaling = np.sqrt(self.d_l)

    def forward(self, U_l):
        """
        einsum conventions:
          n_l = i | h
          d_l = j
          n_h = k
          d_h = l

        Data tensors:
            IN:  U_l ... lower layer capsules
            OUT: U_h ... higher layer capsules
            DIMS (leading batch dimensions are covered by '...'):
                U_l (n_l, d_l)
                U_h (n_h, d_h)
                W   (n_l, n_h, d_l, d_h)
                B   (n_l, n_h)
                A   (n_l, n_l, n_h)
                C   (n_l, n_h)
        """
        # prediction of every lower capsule i for every higher capsule k
        votes = torch.einsum('...ij,ikjl->...ikl', U_l, self.W)
        # dot products between the predictions of lower capsules h and i, per k
        agreement = torch.einsum("...ikl, ...hkl -> ...hik", votes, votes) / self.attention_scaling
        # sum out the middle axis of the agreement tensor
        agreement_sum = torch.einsum("...hij->...hj", agreement)
        # coupling coefficients: softmax over the higher capsules, plus bias
        coupling = torch.softmax(agreement_sum, dim=-1) + self.B
        # weighted sum of the predictions over the lower capsules
        U_h = torch.einsum('...ikl,...ik->...kl', votes, coupling)
        return self.squash(U_h)
    
    
class FCCaps2(nn.Module):
    """
    Fully-connected caps layer. It exploits the routing mechanism explained in 'Efficient-CapsNet: Capsule Network with Self-Attention Routing'
    to create a parent layer of capsules.

    nl: number of input capsules          (nl...i)(nl...h)
    dl: dimension of input capsules       (dl...j)
    nh: number of output capsules         (nh...k)
    dh: dimension of output capsules      (dh...l)
    b: batch size                         (...)

    W: weight tensor                      (W->ikjl)
    B: bias matrix                        (B->1ik)
    U_l: input capsules matrix            (U->...ij)
    U_hat: weighted input capsules matrix (U_hat->...ikl)
    A: covariance tensor                  (A->...hik)
    C: coupling coefficients              (C->...ik)

    constructor arguments (in this order): nl, nh, dl, dh
    """

    def __init__(self, nl, nh,dl, dh):
        super().__init__()
        self.nl = nl
        self.dl = dl
        self.nh = nh
        self.dh = dh

        # uniform [0, 1) initialisation; no custom weight init is implemented
        self.W = torch.nn.Parameter(torch.rand([self.nl,self.nh,self.dl,self.dh]), requires_grad=True)
        # B carries an explicit leading axis of size 1 that broadcasts over the batch
        self.B = torch.nn.Parameter(torch.rand([1,self.nl,self.nh]), requires_grad=True)
        # eps is left at the default of Squash
        self.squash = Squash()

    def forward(self, U_l):
        """
        Data tensors:
            Input:  U_l ... lower layer capsules  (..., nl, dl)
            Output: U_h ... higher layer capsules (..., nh, dh)
        """
        # prediction of every input capsule i for every output capsule k
        U_hat = torch.einsum("...ij,ikjl->...ikl",U_l,self.W)
        # dot products between the predictions of input capsules h and i, per k
        A = torch.einsum("...hkl,...ikl->...hik",U_hat, U_hat)
        # Scale by a plain Python number: a freshly built CPU tensor here would
        # clash with inputs living on another device and would pull
        # lower-precision inputs up to float32.
        A = A / (self.dl ** 0.5)
        # sum the agreement over h
        A_hat = torch.einsum("...hik->...ik",A)
        # coupling coefficients: softmax over the output capsules, plus bias
        C = torch.softmax(A_hat,dim=-1)
        CB = C+self.B
        # weighted sum of the predictions over the input capsules
        U_h = torch.einsum("...ikl,...ik->...kl",U_hat,CB)
        U_h = self.squash(U_h)
        return U_h

In [None]:

# Shape check: primary capsules followed by the fully connected capsule layer.
x = torch.rand((1,128,9,9))
print(x.size())
S = PrimaryCaps(F=128, K=9, N=16, D=8)(x)
print(S.size())

# The result must not be called F: that would rebind the module-level alias
# of torch.nn.functional that margin_loss / margin_loss2 rely on.
U_h = FCCaps2(16,10,8,16)(S)

print(U_h.size())

#### mnist

In [None]:
class MnistEcnBackbone(nn.Module):
    """
        Backbone model from Efficient-CapsNet for MNIST:
        four unpadded convolutions, each followed by ReLU and batch norm.
    """
    def __init__(self):
        super().__init__()
        # (in_channels, out_channels, kernel_size, stride) per stage, in order
        stages = (
            (1, 32, (5, 5), 1),
            (32, 64, (3, 3), 1),
            (64, 64, (3, 3), 1),
            (64, 128, (3, 3), 2),
        )
        modules = []
        for c_in, c_out, kernel, stride in stages:
            modules.append(nn.Conv2d(c_in, c_out, kernel_size=kernel, stride=stride, padding="valid"))
            modules.append(nn.ReLU(inplace=True))
            modules.append(nn.BatchNorm2d(c_out))
        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        """
            IN:
                x (b, 1, 28, 28)
            OUT:
                x (b, 128, 9, 9)
        """
        features = self.layers(x)
        return features

In [None]:
# Shape check: one random MNIST-sized image through the backbone.
x = torch.rand(1, 1, 28, 28)
print(x.shape)
A = MnistEcnBackbone()(x)
print(A.shape)

In [None]:
class MnistEcnDecoder(nn.Module):
    """
        Decoder model from Efficient-CapsNet for MNIST:
        three fully connected layers mapping 16*10 capsule values to a 28x28 image.
    """
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(16*10, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 28*28),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        """
            IN:
                x (b, n, d) with n=10 and d=16,
                or the same capsules already flattened to (b, n*d)
            OUT:
                x_rec (b, 1, 28, 28)
            Notes:
                input must be masked!
        """
        # The first linear layer expects n*d features on the last axis, so
        # capsules handed over as (b, n, d) are flattened per sample first.
        if x.dim() > 2:
            x = x.flatten(start_dim=1)
        x = self.layers(x)
        x = x.view(-1, 1, 28, 28)
        return x

In [None]:
# Shape check: one flattened capsule vector (10 * 16 values) through the decoder.
x = torch.rand(1, 160)
print(x.shape)
A = MnistEcnDecoder()(x)
print(A.shape)

In [None]:
class MnistEffCapsNet(nn.Module):
    """
        EffCaps Implementation for MNIST
        (backbone -> primary capsules -> routed capsules -> decoder),
        all parameters taken from the paper
    """
    def __init__(self):
        super().__init__()
        # values from paper, are fixed!
        self.n_l, self.d_l = 16, 8    # number / dimension of primary capsules
        self.n_h, self.d_h = 10, 16   # number / dimension of output capsules

        self.backbone = MnistEcnBackbone()
        # F = n_l * d_l must hold for the capsule reshape
        self.primcaps = PrimaryCaps(F=128, K=9, N=self.n_l, D=self.d_l)
        self.fcncaps = FCCaps(self.n_l, self.n_h, self.d_l, self.d_h)
        self.decoder = MnistEcnDecoder()

    def forward(self, x):
        """
            IN:
                x (b, 1, 28, 28)
            OUT:
                u_h
                    (b, n_h, d_h)
                    output caps
                x_rec
                    (b, 1, 28, 28)
                    reconstruction of x
        """
        features = self.backbone(x)
        u_l = self.primcaps(features)
        u_h = self.fcncaps(u_l)
        # the decoder only gets the longest output capsule, flattened per sample
        decoder_in = torch.flatten(max_norm_masking(u_h), start_dim=1)
        x_rec = self.decoder(decoder_in)
        return u_h, x_rec

In [None]:
# Shape check: one random MNIST-sized image through the full network.
x = torch.rand(1, 1, 28, 28)
print(x.shape)
u_h, x_rec = MnistEffCapsNet()(x)
print(u_h.shape)
print(x_rec.shape)

#### multimnist

In [None]:
class MultiMnistEcnBackbone(nn.Module):
    """
        Backbone model from Efficient-CapsNet for MultiMNIST:
        four unpadded convolutions, each followed by ReLU and batch norm.
    """
    def __init__(self):
        super().__init__()
        # (in_channels, out_channels, kernel_size, stride) per stage, in order
        stages = (
            (1, 32, (5, 5), 1),
            (32, 64, (3, 3), 1),
            (64, 64, (3, 3), 2),
            (64, 128, (3, 3), 2),
        )
        modules = []
        for c_in, c_out, kernel, stride in stages:
            modules.append(nn.Conv2d(c_in, c_out, kernel_size=kernel, stride=stride, padding="valid"))
            modules.append(nn.ReLU(inplace=True))
            modules.append(nn.BatchNorm2d(c_out))
        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        """
            IN:
                x (b, 1, 36, 36)
            OUT:
                x (b, 128, 6, 6)
        """
        features = self.layers(x)
        return features

In [None]:
# Shape check: a batch of ten random MultiMNIST-sized images through the backbone.
x = torch.rand(10, 1, 36, 36)
print(x.shape)
A = MultiMnistEcnBackbone()(x)
print(A.shape)

In [None]:
class MultiMnistEcnDecoder(nn.Module):
    """
        Decoder model from Efficient-CapsNet for MultiMNIST:
        three fully connected layers mapping 16*10 capsule values to a 36x36 image.
    """
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(16*10, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 36*36),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        """
            IN:
                x (b, n, d) with n=10 and d=16,
                or the same capsules already flattened to (b, n*d)
            OUT:
                x_rec (b, 1, 36, 36)
            Notes:
                input must be masked!
        """
        # The first linear layer expects n*d features on the last axis, so
        # capsules handed over as (b, n, d) are flattened per sample first.
        if x.dim() > 2:
            x = x.flatten(start_dim=1)
        x = self.layers(x)
        x = x.view(-1, 1, 36, 36)
        return x


In [None]:
# Shape check: one flattened capsule vector (10 * 16 values) through the decoder.
x = torch.rand(1, 160)
print(x.shape)
A = MultiMnistEcnDecoder()(x)
print(A.shape)

In [None]:
class MultiMnistEffCapsNet(nn.Module):
    """
        EffCaps Implementation for MultiMNIST
        (backbone -> primary capsules -> routed capsules -> decoder),
        all parameters taken from the paper
    """
    def __init__(self):
        super().__init__()
        # values from paper, are fixed!
        self.n_l, self.d_l = 16, 8    # number / dimension of primary capsules
        self.n_h, self.d_h = 10, 16   # number / dimension of output capsules

        self.backbone = MultiMnistEcnBackbone()
        # F = n_l * d_l must hold for the capsule reshape
        self.primcaps = PrimaryCaps(F=128, K=5, N=self.n_l, D=self.d_l, s=2)
        self.fcncaps = FCCaps2(self.n_l,  self.n_h, self.d_l, self.d_h)
        self.decoder = MultiMnistEcnDecoder()

    def forward(self, x):
        """
            IN:
                x (b, 1, 36, 36)
            OUT:
                u_h
                    (b, n_h, d_h)
                    output caps
                x_rec
                    (b, 1, 36, 36)
                    reconstruction of x
        """
        features = self.backbone(x)
        u_l = self.primcaps(features)
        u_h = self.fcncaps(u_l)
        # the decoder only gets the longest output capsule, flattened per sample
        decoder_in = torch.flatten(max_norm_masking(u_h), start_dim=1)
        x_rec = self.decoder(decoder_in)
        return u_h, x_rec

In [None]:

# End-to-end check: forward pass plus reconstruction MSE on one random image.
# x is wrapped in a Parameter so that it requires grad.
x = torch.nn.Parameter(torch.rand(1, 1, 36, 36), requires_grad=True)
print(x.shape)
u_h, x_rec = MultiMnistEffCapsNet()(x)
print(x_rec.shape)
loss_rec = torch.nn.functional.mse_loss(input=x, target=x_rec)
print(loss_rec)

print(x_rec)

In [None]:
# The import (re)binds the alias F to torch.nn.functional before it is used;
# both spellings of one_hot are then applied to the same labels and printed.
import torch.nn.functional as F
labels = torch.arange(0, 5) % 3
x = torch.nn.functional.one_hot(labels)
y = F.one_hot(labels)
print(x, "\n", y)

In [None]:
import tensorflow as tf
import tensorflow_addons as tfa

In [None]:
# Symbolic Keras input declared with shape=(10,); printed to inspect
# the resulting placeholder.
y_true1 = tf.keras.layers.Input(shape=(10,))

print(y_true1)

In [None]:
class Mask(tf.keras.layers.Layer):
    """
    Mask operation described in 'Dynamic routing between capsules'.

    Multiplies the capsules with a one-hot mask over the capsule axis and
    flattens the result per sample. The mask is either supplied by the caller
    (inputs given as a list) or derived from the capsule lengths.
    
    ...
    
    Methods
    -------
    call(inputs, double_mask)
        mask a capsule layer
        set double_mask for multimnist dataset
    """
    def call(self, inputs, double_mask=None, **kwargs):
        # NOTE(review): exact type check - only a plain list selects the
        # "mask supplied by caller" branch; a tuple falls through to the
        # else branch below. Confirm this is intended.
        if type(inputs) is list:
            if double_mask:
                # [capsules, mask1, mask2]
                inputs, mask1, mask2 = inputs
            else:
                # [capsules, mask]
                inputs, mask = inputs
        else:  
            # capsule lengths: Euclidean norm over the last axis
            x = tf.sqrt(tf.reduce_sum(tf.square(inputs), -1))
            if double_mask:
                # one-hot masks for the longest and the second longest capsule
                # (entries 0 and 1 of the descending argsort over the last axis)
                mask1 = tf.keras.backend.one_hot(tf.argsort(x,direction='DESCENDING',axis=-1)[...,0],num_classes=x.get_shape().as_list()[1])
                mask2 = tf.keras.backend.one_hot(tf.argsort(x,direction='DESCENDING',axis=-1)[...,1],num_classes=x.get_shape().as_list()[1])
            else:
                # one-hot mask for the longest capsule along axis 1
                mask = tf.keras.backend.one_hot(indices=tf.argmax(x, 1), num_classes=x.get_shape().as_list()[1])

        # expand_dims adds a trailing axis so the mask broadcasts over the
        # capsule dimension; batch_flatten then flattens each sample.
        if double_mask:
            masked1 = tf.keras.backend.batch_flatten(inputs * tf.expand_dims(mask1, -1))
            masked2 = tf.keras.backend.batch_flatten(inputs * tf.expand_dims(mask2, -1))
            return masked1, masked2
        else:
            masked = tf.keras.backend.batch_flatten(inputs * tf.expand_dims(mask, -1))
            return masked

    def compute_output_shape(self, input_shape):
        # Output is (None, product of entries 1 and 2 of the capsule shape).
        # NOTE(review): exact type check on tuple; presumably the first
        # branch handles the list-of-inputs case - confirm which container
        # types Keras passes in here.
        if type(input_shape[0]) is tuple:  
            return tuple([None, input_shape[0][1] * input_shape[0][2]])
        else:  # generation step
            return tuple([None, input_shape[1] * input_shape[2]])

    def get_config(self):
        # no extra constructor arguments: the base class config is returned as is
        config = super(Mask, self).get_config()
        return config
    

In [None]:
# Exercise Mask in double-mask mode, once with caller-supplied masks and
# once with masks derived from the capsule lengths.
# NOTE(review): Input is declared with shape (1,10,16); Mask reads the number
# of classes from axis 1 of the length tensor - confirm that (10,16) was not
# the intended shape.
inputs = tf.keras.Input((1,10,16))
y_true1 = tf.keras.layers.Input(shape=(10,))
y_true2 = tf.keras.layers.Input(shape=(10,))

# masks taken from the list [inputs, y_true1, y_true2]
masked_by_y1,masked_by_y2 = Mask()([inputs, y_true1, y_true2],double_mask=True)  
# masks computed inside the layer
masked1,masked2 = Mask()(inputs,double_mask=True)

print(masked_by_y1)

### SmallNORB

In [None]:
class SmalNorbEcnBackbone(nn.Module):
    """
        Backbone model from Efficient-CapsNet for SmallNORB:
        four unpadded convolutions, each followed by LeakyReLU(0.3)
        and instance norm.
    """
    def __init__(self):
        super().__init__()
        # (in_channels, out_channels, kernel_size, stride) per stage, in order
        stages = (
            (2, 32, (7, 7), 2),
            (32, 64, (3, 3), 1),
            (64, 64, (3, 3), 1),
            (64, 128, (3, 3), 2),
        )
        modules = []
        for c_in, c_out, kernel, stride in stages:
            modules.append(nn.Conv2d(c_in, c_out, kernel_size=kernel, stride=stride, padding="valid"))
            modules.append(nn.LeakyReLU(0.3, inplace=True))
            modules.append(nn.InstanceNorm2d(c_out))
        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        """
            IN:
                x (b, 2, 48, 48)
            OUT:
                x (b, 128, 8, 8)
        """
        features = self.layers(x)
        return features

In [None]:
# Shape check: one random two-channel 48x48 input through the backbone.
# x is wrapped in a Parameter so that it requires grad.
x = torch.nn.Parameter(torch.rand(1, 2, 48, 48), requires_grad=True)
print(x.shape)
A = SmalNorbEcnBackbone()(x)
print(A.shape)

In [None]:
    
    # Keras shape probe: four Conv2D layers with the same filter counts,
    # kernel sizes and strides as the convolutions in SmalNorbEcnBackbone,
    # applied to a random channels-last array, without activations.
    x = np.random.rand(1,48,48,2)

    print(x.shape)

    x = tf.keras.layers.Conv2D(32,7,2,activation=None, padding='valid', kernel_initializer='he_normal')(x)

    x = tf.keras.layers.Conv2D(64,3, activation=None, padding='valid', kernel_initializer='he_normal')(x)

    x = tf.keras.layers.Conv2D(64,3, activation=None, padding='valid', kernel_initializer='he_normal')(x) 

    x = tf.keras.layers.Conv2D(128,3,2, activation=None, padding='valid', kernel_initializer='he_normal')(x)   

    # final feature-map shape, for comparison with the PyTorch backbone
    print(x.shape)


In [None]:
class SmalNorbEcnDecoder(nn.Module):
    """
        Decoder model from Efficient-CapsNet for SmallNORB:
        a linear layer producing an 8x8 map, followed by three
        upsample + convolution stages and a final sigmoid convolution.
    """
    def __init__(self):
        super().__init__()
        # 16*5 capsule values -> 64 values, viewed as a 1-channel 8x8 map in forward
        self.layer1 = nn.Sequential(
            nn.Linear(16*5, 64),
        )
        # spatial size: 8 -> 16 -> 14 -> 28 -> 26 -> 52 -> 50 -> 48
        self.layer2 = nn.Sequential(
            nn.UpsamplingBilinear2d(scale_factor=2),
            nn.Conv2d(1, 64, kernel_size=(3, 3), padding="valid"),
            nn.LeakyReLU(0.3,inplace=True),
            nn.UpsamplingBilinear2d(scale_factor=2),
            nn.Conv2d(64, 128, kernel_size=(3, 3), padding="valid"),
            nn.LeakyReLU(0.3,inplace=True),
            nn.UpsamplingBilinear2d(scale_factor=2),
            nn.Conv2d(128, 128, kernel_size=(3, 3), padding="valid"),
            nn.LeakyReLU(0.3,inplace=True),
            nn.Conv2d(128, 2, kernel_size=(3, 3), padding="valid"),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        """
            IN:
                x (b, n, d) with n=5 and d=16,
                or the same capsules already flattened to (b, n*d)
            OUT:
                x_rec (b, 2, 48, 48)
            Notes:
                input must be masked!
        """
        # The linear layer expects n*d features on the last axis, so
        # capsules handed over as (b, n, d) are flattened per sample first.
        if x.dim() > 2:
            x = x.flatten(start_dim=1)
        x = self.layer1(x)
        x = x.view(-1, 1, 8, 8)
        x = self.layer2(x)
        return x


In [None]:
# Shape check: one flattened capsule vector (5 * 16 values) through the decoder.
x = torch.rand(1, 80)
print(x.shape)
A = SmalNorbEcnDecoder()(x)
print(A.shape)

In [None]:
# Small integer array (values 0..23, shape (2, 2, 2, 3)) so the effect of
# UpSampling2D with size=(1, 2) can be read off the two printouts.
input_shape = (2, 2, 2, 3)
x = np.arange(np.prod(input_shape)).reshape(input_shape)
print(x)
y = tf.keras.layers.UpSampling2D(size=(1, 2))(x)
print(y)


In [None]:
    # Keras version of the SmallNORB decoder stack, built symbolically to
    # inspect the output tensor: Dense(64) -> reshape to (8,8,1) -> three
    # bilinear upsample + 3x3 conv stages -> 3x3 conv with 2 filters and sigmoid.
    # NOTE(review): Input is given the bare integer 16*5 as its shape -
    # confirm the installed Keras version accepts a non-tuple shape.
    inputs = tf.keras.Input(16*5)

    x = tf.keras.layers.Dense(64)(inputs)
    x = tf.keras.layers.Reshape(target_shape=(8,8,1))(x)
    x = tf.keras.layers.UpSampling2D(size=(2,2), interpolation='bilinear')(x)
    x = tf.keras.layers.Conv2D(filters=64, kernel_size=(3,3), padding="valid", activation=tf.nn.leaky_relu)(x)
    x = tf.keras.layers.UpSampling2D(size=(2,2), interpolation='bilinear')(x)
    x = tf.keras.layers.Conv2D(filters=128, kernel_size=(3,3), padding="valid", activation=tf.nn.leaky_relu)(x)
    x = tf.keras.layers.UpSampling2D(size=(2,2), interpolation='bilinear')(x)
    x = tf.keras.layers.Conv2D(filters=128, kernel_size=(3,3), padding="valid", activation=tf.nn.leaky_relu)(x)
    x = tf.keras.layers.Conv2D(filters=2, kernel_size=(3,3), padding="valid", activation=tf.nn.sigmoid)(x)  
    
    print(x)

In [None]:
class SmalNorbEffCapsNet(nn.Module):
    """
        EffCaps Implementation for SmallNORB
        (backbone -> primary capsules -> routed capsules -> decoder),
        all parameters taken from the paper
    """
    def __init__(self):
        super().__init__()
        # values from paper, are fixed!
        self.n_l, self.d_l = 16, 8    # number / dimension of primary capsules
        self.n_h, self.d_h = 5, 16    # number / dimension of output capsules

        self.backbone = SmalNorbEcnBackbone()
        # F = n_l * d_l must hold for the capsule reshape
        self.primcaps = PrimaryCaps(F=128, K=8, N=self.n_l, D=self.d_l, s=2)
        self.fcncaps = FCCaps2(self.n_l, self.n_h ,self.d_l, self.d_h)
        self.decoder = SmalNorbEcnDecoder()

    def forward(self, x):
        """
            IN:
                x (b, 2, 48, 48)
            OUT:
                u_h
                    (b, n_h, d_h)
                    output caps
                x_rec
                    (b, 2, 48, 48)
                    reconstruction of x
        """
        features = self.backbone(x)
        u_l = self.primcaps(features)
        u_h = self.fcncaps(u_l)
        # the decoder only gets the longest output capsule, flattened per sample
        decoder_in = torch.flatten(max_norm_masking(u_h), start_dim=1)
        x_rec = self.decoder(decoder_in)
        return u_h, x_rec

In [None]:
# End-to-end check: forward pass plus reconstruction MSE on one random
# two-channel 48x48 input. x is wrapped in a Parameter so that it requires grad.
x = torch.nn.Parameter(torch.rand(1, 2, 48, 48), requires_grad=True)
print(x.shape)
u_h, x_rec = SmalNorbEffCapsNet()(x)
print(x_rec.shape)
# second channel of the first reconstructed sample
print(x_rec[0, 1, :, :])
loss_rec = torch.nn.functional.mse_loss(input=x, target=x_rec)
print(loss_rec)