In [None]:
import torch
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as T
import matplotlib.pyplot as plt
#
import torch.nn as nn
from torch import optim
#
import numpy as np

In [None]:
# MNIST splits, downloaded into ./data on first use; ToTensor converts each
# image to a float tensor so the loaders below can batch them directly.
ds_train = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=T.ToTensor(),
)
# The held-out split (train=False) is used for validation in this notebook.
ds_valid = datasets.MNIST(
    root="./data",
    train=False,
    download=True,
    transform=T.ToTensor(),
)

In [None]:
# Sanity check: show the first raw training image with its label as the title.
sample_image = ds_train.data[0]
sample_label = ds_train.targets[0]
plt.imshow(sample_image, cmap='gray')
plt.title('%i' % sample_label)
plt.show()

In [None]:
# Mini-batch loaders over the two splits; both reshuffle on every pass and
# load samples with 4 worker processes.
dl_train = torch.utils.data.DataLoader(
    dataset=ds_train,
    batch_size=256,
    shuffle=True,
    num_workers=4,
)
dl_valid = torch.utils.data.DataLoader(
    dataset=ds_valid,
    batch_size=16,
    shuffle=True,
    num_workers=4,
)

# Baseline CNN for MNIST

In [None]:
class CNN(nn.Module):
    """
    Baseline convolutional classifier.

    Two conv -> ReLU -> 2x2 max-pool stages (16, then 32 feature maps; 5x5
    kernels, stride 1, padding 2) followed by a single linear layer that
    produces 10 class scores. The linear layer expects 32 * 7 * 7 features,
    i.e. the network is sized for 1-channel 28x28 inputs.
    """

    def __init__(self):
        super().__init__()
        # Stage 1: 1 -> 16 channels, spatial size halved by the pooling.
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # Stage 2: 16 -> 32 channels, spatial size halved again.
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        # Classifier head over the flattened feature maps.
        self.out = nn.Linear(in_features=32 * 7 * 7, out_features=10)

    def forward(self, x):
        """Return raw class scores of shape (batch_size, 10)."""
        features = self.conv2(self.conv1(x))
        # Collapse (C, H, W) into one axis per sample before the linear head.
        flat = features.view(features.size(0), -1)
        return self.out(flat)

In [None]:
# Instantiate the baseline CNN with freshly initialised weights.
model = CNN()

In [None]:
# Training objective and optimiser for the baseline model: cross-entropy on
# the raw class scores, Adam over every model parameter.
loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=0.01)

In [None]:
num_epochs = 1
# Report the loss every `log_every` steps. With batch_size=256 on the
# 60,000-image MNIST training split an epoch has ~235 steps, so the interval
# must be well below that for any progress to be shown.
log_every = 50

model.train()
for epoch in range(num_epochs):
    for idx, (x, y_true) in enumerate(dl_train):
        # Forward pass and loss on the current mini-batch.
        y_pred = model(x)
        loss = loss_func(y_pred, y_true)

        # Drop gradients left over from the previous step before accumulating new ones.
        optimizer.zero_grad()

        # Backpropagate, then apply the update.
        loss.backward()
        optimizer.step()

        if idx % log_every == 0:
            # Epochs are shown 1-based so the counter reads "1/1" rather than "0/1".
            print("Epoch[{}/{}] - step {} loss: {:.4f}".format(epoch + 1, num_epochs, idx, loss.item()))

In [None]:
# Classification accuracy of the baseline model on the validation loader.
model.eval()
correct, total = 0, 0
# No gradients are needed while scoring.
with torch.no_grad():
    for x, y_true in dl_valid:
        logits = model(x)
        # Predicted class = index of the highest score in each row.
        y_pred = logits.argmax(dim=1)
        correct += int((y_pred == y_true).sum())
        total += y_true.size(0)
    acc = correct / total

In [None]:
# Validation accuracy, then the number of misclassified samples (one per line).
print(acc, total - correct, sep="\n")

# Efficient-CapsNet (Eff-Caps): capsule layers in PyTorch

In [None]:
class PrimaryCaps(nn.Module):
    """
        Primary capsule layer following the methodology of
        'Efficient-CapsNet: Capsule Network with Self-Attention Routing'.
        The properties of each capsule s_n are extracted with a 2D depthwise
        convolution (one filter per input channel, no padding), and the result
        is reshaped into N capsules of D properties each.

        ...

        Attributes
        ----------
        F: int depthwise conv number of features (input and output channels)
        K: int depthwise conv kernel dimension
        N: int number of primary capsules
        D: int primary capsules dimension (number of properties)
        s: int depthwise conv strides
        dw_conv2d: nn.Conv2d the depthwise convolution (groups=F)
    """
    def __init__(self, F, K, N, D, s=1):
        super().__init__()
        self.F, self.K, self.N, self.D, self.s = F, K, N, D, s
        # groups=F makes the convolution depthwise: each channel gets its own KxK filter.
        self.dw_conv2d = nn.Conv2d(
            in_channels=F,
            out_channels=F,
            kernel_size=K,
            stride=s,
            padding="valid",
            groups=F,
        )

    def forward(self, x):
        """
         x: (B, C, H, W) feature maps with C = F channels
         returns: (B', N, D); the leading size is inferred, so it equals B
                  only when the convolution output holds N * D values per sample
        """
        features = self.dw_conv2d(x)
        # Regroup the convolution output into N capsules of D properties.
        return features.view((-1, self.N, self.D))

class Squash(nn.Module):
    """
        Squash non-linearity applied to each capsule vector.

        Every vector along dim 2 keeps its direction and is rescaled to length
        1 - 1 / (exp(||x||) + eps); eps also guards the division by the norm.
    """
    def __init__(self, eps=10e-21):
        super().__init__()
        self.eps = eps

    def forward(self, x):
        """
         in:  (b, n, d)
         out: (b, n, d)
        """
        length = torch.norm(x, dim=2, keepdim=True)
        # New length of each capsule.
        scale = 1 - 1 / (torch.exp(length) + self.eps)
        # Unit-length direction of each capsule.
        direction = x / (length + self.eps)
        return scale * direction

class FCCaps(nn.Module):
    """
        Fully-connected capsule layer: maps n_l lower-level capsules of
        dimension d_l to n_h higher-level capsules of dimension d_h, weighting
        the per-pair predictions with coefficients derived from their mutual
        agreement (scaled dot products) plus a learned bias.

        Attributes
        ----------
        n_l: int number of lower-level (input) capsules
        n_h: int number of higher-level (output) capsules
        d_l: int lower-level capsule dimension
        d_h: int higher-level capsule dimension
        W: Parameter (n_l, n_h, d_l, d_h) per-pair transformation weights
        B: Parameter (n_l, n_h) additive bias on the routing coefficients
        squash: Squash non-linearity applied to the output capsules
    """
    def __init__(self, n_l, n_h, d_l, d_h):
        super().__init__()
        self.n_l = n_l
        self.d_l = d_l
        self.n_h = n_h
        self.d_h = d_h
        #
        self.W = torch.nn.Parameter(torch.rand(n_l, n_h, d_l, d_h), requires_grad=True)
        self.B = torch.nn.Parameter(torch.rand(n_l, n_h), requires_grad=True)
        self.squash = Squash()

    def forward(self, U_l):
        """
        einsum conventions:
          n_l = i | h
          d_l = j
          n_h = k
          d_h = l

        Data tensors:
            IN:  U_l
            OUT: U_h
            DIMS (leading batch dimensions are carried by the ellipsis):
                U_l (n_l, d_l)
                U_h (n_h, d_h)
                W   (n_l, n_h, d_l, d_h)
                B   (n_l, n_h)
                A   (n_l, n_l, n_h)
                C   (n_l, n_h)
        """
        # Prediction of every lower capsule for every higher capsule: (..., n_l, n_h, d_h).
        U_hat = torch.einsum('...ij,ikjl->...ikl', U_l, self.W)
        # Pairwise agreement between lower capsules, per higher capsule.
        A = torch.einsum("...ikl, ...hkl -> ...hik", U_hat, U_hat)
        # Scale by sqrt(d_l) taken from the layer's own configuration. Dividing
        # by a Python float keeps the result on A's device and dtype and does
        # not depend on any variable outside the module.
        A = A / (self.d_l ** 0.5)
        # Sum the agreement over the second lower-capsule axis -> (..., n_l, n_h).
        A_sum = torch.einsum("...hij->...hj", A)
        # Normalise across the higher capsules and add the learned bias.
        C = torch.softmax(A_sum, dim=-1)
        CB = C + self.B
        # Weighted sum of the predictions over the lower capsules -> (..., n_h, d_h).
        U_h = torch.einsum('...ikl,...ik->...kl', U_hat, CB)
        return self.squash(U_h)


class EfficientCapsNets(nn.Module):
    """
        Placeholder for the full Efficient-CapsNet model.

        The network body is not implemented yet: `call` is an empty stub that
        returns None. TODO: assemble the layers and implement the forward pass.
    """
    def __init__(self):
        super().__init__()

    def call(self, x):
        """Keras-style entry point kept for backward compatibility (stub, returns None)."""
        pass

    def forward(self, x):
        """
        PyTorch entry point.

        nn.Module.__call__ dispatches to `forward`, not to `call`, so without
        this method invoking the module raises NotImplementedError. It
        delegates to `call` so both spellings behave the same.
        """
        return self.call(x)

class View(nn.Module):
    """
        Module wrapper around Tensor.view so a reshape can be used as a layer
        (e.g. inside nn.Sequential).

        Attributes
        ----------
        shape: sequence of ints, target shape passed to Tensor.view (may contain -1)
    """
    def __init__(self, shape):
        # nn.Module must be initialised first; without it the hook and
        # parameter registries do not exist and calling the module fails.
        super().__init__()
        self.shape = shape

    def forward(self, x):
        """Return `x` viewed with the configured shape."""
        return x.view(*self.shape)

In [None]:
# CNN Part: convolutional stem that feeds the primary capsules.
# Four unpadded conv -> ReLU -> BatchNorm stages; the last conv uses stride 2.
# TODO: add He-normal initializer
cn = nn.Sequential(
    nn.Conv2d(1, 32, kernel_size=(5, 5), padding="valid"),
    nn.ReLU(inplace=True),
    nn.BatchNorm2d(32),
    nn.Conv2d(32, 64, kernel_size=(3, 3), padding="valid"),
    nn.ReLU(inplace=True),
    nn.BatchNorm2d(64),
    nn.Conv2d(64, 64, kernel_size=(3, 3), padding="valid"),
    nn.ReLU(inplace=True),
    nn.BatchNorm2d(64),
    nn.Conv2d(64, 128, kernel_size=(3, 3), stride=2, padding="valid"),
    nn.ReLU(inplace=True),
    nn.BatchNorm2d(128),
)
# Fetch an input batch explicitly instead of relying on whatever `x` was left
# behind by the last iteration of an earlier training/evaluation loop.
x, _labels = next(iter(dl_valid))
x_h = cn(x)
print(x_h.shape)

In [None]:
# Capsule layout: (count, dimension) of the lower-level and higher-level capsules.
n_l, d_l = 16, 8
n_h, d_h = 10, 16

# Primary capsules from the conv-stem features: n_l capsules of d_l properties each.
pc = PrimaryCaps(F=128, K=9, N=n_l, D=d_l)
U_0 = pc(x_h)
print(U_0.shape)

In [None]:
# Route the n_l primary capsules into n_h higher-level capsules of dimension d_h.
fcn = FCCaps(n_l=n_l, n_h=n_h, d_l=d_l, d_h=d_h)
U_1 = fcn(U_0)

In [None]:
# Flatten every dimension after the first into a single feature vector per sample.
U_flat = U_1.flatten(start_dim=1)

In [None]:
# Reconstruction decoder: 160 flattened capsule features -> 512 -> 1024 -> 784
# values (one per pixel of a 28x28 image), squashed into [0, 1] by the sigmoid.
generator = nn.Sequential(
    nn.Linear(in_features=16*10, out_features=512),
    nn.ReLU(inplace=True),
    nn.Linear(in_features=512, out_features=1024),
    nn.ReLU(inplace=True),
    nn.Linear(in_features=1024, out_features=28*28),
    nn.Sigmoid(),
)

In [None]:
# Decode the flattened capsules and reshape to image batches (B, 1, 28, 28).
x_rec = generator(U_flat).view(-1, 1, 28, 28)

In [None]:
# Drop the channel axis so each reconstruction is a plain 28x28 image.
x_view = x_rec.reshape(-1, 28, 28)

In [None]:
# Display the shape of the reconstructed image batch.
x_view.shape

In [None]:
# Show the reconstruction at batch index 5; detach from the autograd graph
# before converting to a NumPy array for matplotlib.
reconstruction = x_view[5].detach().numpy()
plt.imshow(reconstruction)

In [None]:
# Show the first channel of the first image in the input batch.
plt.imshow(x[0, 0])

In [None]:
# Display the input and reconstruction shapes side by side.
x.shape, x_rec.shape

In [None]:
# Mean-squared error between the input batch and its reconstruction.
# NOTE(review): mse_loss is declared as (input, target); here the original
# images are bound to `input` and the reconstruction to `target`. Confirm the
# intended roles before using this as a training loss.
torch.nn.functional.mse_loss(input=x, target=x_rec)

# TensorFlow reference layers (Length, Mask)

In [None]:
class Length(tf.keras.layers.Layer):
    """
    Compute the length (Euclidean norm) of each capsule n of a layer l.
    ...

    Methods
    -------
    call(inputs)
        compute the length of each capsule
    """

    def call(self, inputs, **kwargs):
        """
        Compute the length of each capsule

        Parameters
        ----------
        inputs: tensor
           tensor with shape [None, num_capsules (N), dim_capsules (D)]

        Returns
        -------
        tensor with the last axis reduced away
        """
        squared_norm = tf.reduce_sum(tf.square(inputs), -1)
        # The Keras epsilon is added under the square root.
        return tf.sqrt(squared_norm + tf.keras.backend.epsilon())

    def compute_output_shape(self, input_shape):
        # Same as the input shape without its last (capsule dimension) entry.
        return input_shape[:-1]

    def get_config(self):
        # No extra hyper-parameters beyond the base layer configuration.
        return super().get_config()

class Mask(tf.keras.layers.Layer):
    """
    Mask operation described in 'Dynamic routing between capsules'.

    ...

    Methods
    -------
    call(inputs, double_mask)
        mask a capsule layer
        set double_mask for multimnist dataset
    """
    def call(self, inputs, double_mask=None, **kwargs):
        """
        Zero out all capsules except the selected one(s) and flatten the result.

        When `inputs` is a list, the mask(s) are taken from it:
        [capsules, mask] or, with double_mask, [capsules, mask1, mask2].
        Otherwise the mask is a one-hot of the longest capsule (or, with
        double_mask, of the longest and second-longest capsules).

        Returns one flattened tensor, or a pair of them when double_mask is set.
        """
        if type(inputs) is list:
            # Masks supplied by the caller alongside the capsules.
            if double_mask:
                inputs, mask1, mask2 = inputs
            else:
                inputs, mask = inputs
        else:
            # Derive the mask(s) from the capsule lengths.
            lengths = tf.sqrt(tf.reduce_sum(tf.square(inputs), -1))
            n_classes = lengths.get_shape().as_list()[1]
            if double_mask:
                ranking = tf.argsort(lengths, direction='DESCENDING', axis=-1)
                mask1 = tf.keras.backend.one_hot(ranking[..., 0], num_classes=n_classes)
                mask2 = tf.keras.backend.one_hot(ranking[..., 1], num_classes=n_classes)
            else:
                mask = tf.keras.backend.one_hot(indices=tf.argmax(lengths, 1), num_classes=n_classes)

        flatten = tf.keras.backend.batch_flatten
        if not double_mask:
            return flatten(inputs * tf.expand_dims(mask, -1))

        first = flatten(inputs * tf.expand_dims(mask1, -1))
        second = flatten(inputs * tf.expand_dims(mask2, -1))
        return first, second

    def compute_output_shape(self, input_shape):
        # A tuple as first element means a list of shapes was passed (capsules
        # plus mask); otherwise input_shape is the capsule shape itself
        # (generation step).
        capsule_shape = input_shape[0] if type(input_shape[0]) is tuple else input_shape
        return tuple([None, capsule_shape[1] * capsule_shape[2]])

    def get_config(self):
        # No extra hyper-parameters beyond the base layer configuration.
        return super().get_config()