<a href="https://colab.research.google.com/github/ugomezjr/Capsule-Networks-Pytorch/blob/main/pytorch_capsule_network.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from torch import nn
import torch

In [None]:
img_digit = torch.randn(1, 28, 28)

## Layer 1: ReLU Conv1

Conv1 has 256 feature maps, 9 x 9 convolutional kernels with a stride of 1 and ReLU activation. 

* Converts pixel intensities to the activities of local feature detectors that are then used as inputs to the *primary* capsules. 

In [3]:
# This layer converts pixel intensities to the activities of local feature
# detectors that are then used as inputs to the "primary" capsules. 
relu_conv1_block = nn.Sequential(
    nn.Conv2d(in_channels=1,
              out_channels=256,
              kernel_size=9,
              stride=1,
              padding=0),
    nn.ReLU()
)

In [4]:
relu_conv1_block(img_digit.unsqueeze(0)).shape

torch.Size([1, 256, 20, 20])

## Layer 2: PrimaryCapsules

A convolutional capsule layer with 32 channels of convolutional 8D capsules (i.e. each primary capsule contains 8 convolutional units with a 9 x 9 kernel and a stride of 2).

*   Lowest level of multi-dimensional entities.
*   Activating the primary capsules corresponds to inverting the rendering process. 



In [110]:
# In total PrimaryCapsules has [32 x 6 x 6] capsule outputs (each output is an 8D vector) and
# each capsule in the [6 x 6] grid is sharing their weights with each other. 

class PrimaryCaps(nn.Module):
  def __init__(self, 
               in_channels: int=256, 
               out_channels: int=32,
               caps_dim: int=8, 
               kernel_size: int=9,
               stride: int=2,
               padding: int=0):
    super().__init__()

    self.caps_dim = caps_dim

    self.conv2d = nn.Sequential(
        nn.Conv2d(in_channels=in_channels,
                  out_channels=(out_channels*caps_dim),
                  kernel_size=kernel_size,
                  stride=stride,
                  padding=padding)
    )

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    print(f"Input Shape [ReLU Conv1]: {x.shape}")
    batch_size = x.size(0)
    x = self.conv2d(x)
    print(f"PrimaryCaps Shape [original]: {x.shape}")
    x = x.view(batch_size, -1, self.caps_dim)
    print(f"PrimaryCaps Shape [reshaped]: {x.shape}")
    return x


In [113]:
primarycaps = PrimaryCaps()
output = primarycaps(torch.randn(1, 256, 20, 20))

Input Shape [ReLU Conv1]: torch.Size([1, 256, 20, 20])
PrimaryCaps Shape [original]: torch.Size([1, 256, 6, 6])
PrimaryCaps Shape [reshaped]: torch.Size([1, 1152, 8])


## Equation 1: Non-Linear "**Squashing**" Function

In [8]:
# The length of the vector output of a capsule cannot exceed 1 by applying a non-linearity that leaves
# the orientation of the vector unchanged but scales down its magnitude.
def squash(sj: torch.Tensor):
  norm = torch.linalg.norm(sj)
  norm_sqr = norm**2

  return (norm_sqr / (1 + norm_sqr)) * (sj / norm)

In [90]:
def norm(x: torch.Tensor):
  sum_of_sqrs = torch.sum(torch.Tensor([ i**2 for i in x]))
  return torch.sqrt(sum_of_sqrs)

x1 = torch.arange(9, dtype=torch.float)
x2 = torch.zeros(9, dtype=torch.float)

print(torch.norm(x1), norm(x1))
print(torch.norm(x2), norm(x2))

tensor(14.2829) tensor(14.2829)
tensor(0.) tensor(0.)


In [98]:
torch.manual_seed(42)
torch.cuda.manual_seed(42)

class CapsNet(nn.Module):

  def __init__(self, 
               in_channels: int=1,
               conv1_out_channels: int=256,
               primary_caps_out_channels: int=32, 
               caps_dim: int=8,
               kernel_size: int=9):
    super().__init__()

    self.caps_dim = caps_dim

    self.relu_conv1 = nn.Sequential(
        nn.Conv2d(in_channels=1,
                  out_channels=256,
                  kernel_size=9,
                  stride=1,
                  padding=0),
        nn.ReLU()
    )

    self.primary_caps = nn.Sequential(
        nn.Conv2d(in_channels=256,
                  out_channels=32*8,
                  kernel_size=9,
                  stride=2,
                  padding=0)
    )
  

  def squash(self, sj: torch.Tensor, dim: int=-1):
    """
    norm = torch.norm(inputs, p=2, dim=axis, keepdim=True)
    scale = norm**2 / (1 + norm**2) / (norm + 1e-8)
    return scale * inputs
    """
    """
    # Original Implementation:

    norm = torch.linalg.norm(sj)
    norm_sqr = norm**2

    return (norm_sqr / (1 + norm_sqr)) * (sj / norm)
    """

    norm = torch.linalg.norm(sj, ord=2, dim=dim, keepdim=True)
    norm_sqr = norm**2

    scale = norm_sqr / (1 + norm_sqr) / (norm + 1e-8)
    return scale * sj


  def forward(self, x: torch.Tensor):
    batch_size = x.size(0)

    x = self.relu_conv1(x)
    print(f"ReLU Conv1 Shape: {x.shape}")
    x = self.primary_caps(x)
    print(f"PrimaryCaps Shape [original]: {x.shape}")
    x = x.view(batch_size, -1, self.caps_dim) # x.reshape(x_size, -1, 8)
    print(f"PrimaryCaps Shape [reshaped]: {x.shape}")
    x = self.squash(x)
    print(f"PrimaryCaps shape [squashed]: {x.shape}")
    return x

In [99]:
model = CapsNet()
model

CapsNet(
  (relu_conv1): Sequential(
    (0): Conv2d(1, 256, kernel_size=(9, 9), stride=(1, 1))
    (1): ReLU()
  )
  (primary_caps): Sequential(
    (0): Conv2d(256, 256, kernel_size=(9, 9), stride=(2, 2))
  )
)

In [100]:
outputs = model(img_digit.unsqueeze(0))
outputs

ReLU Conv1 Shape: torch.Size([1, 256, 20, 20])
PrimaryCaps Shape [original]: torch.Size([1, 256, 6, 6])
PrimaryCaps Shape [reshaped]: torch.Size([1, 1152, 8])
PrimaryCaps shape [squashed]: torch.Size([1, 1152, 8])


tensor([[[ 0.1852,  0.1201,  0.1423,  ...,  0.3105,  0.1277,  0.3101],
         [ 0.1490,  0.3131,  0.1541,  ...,  0.1951,  0.0989,  0.1846],
         [ 0.3402,  0.2957,  0.1792,  ...,  0.0869,  0.2138,  0.0988],
         ...,
         [-0.0140,  0.0277,  0.0024,  ..., -0.0039, -0.0209, -0.0052],
         [-0.0664, -0.0437, -0.1837,  ...,  0.0165,  0.0360, -0.1984],
         [-0.0148, -0.0383, -0.0140,  ..., -0.0247,  0.0270, -0.0262]]],
       grad_fn=<MulBackward0>)

In [71]:
torch.sum(outputs[0][0])

tensor(1.5287, grad_fn=<SumBackward0>)

In [17]:
32*6*6

1152

In [18]:
32*8

256