In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append("./../..")

In [211]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import tensorflow as tf
import numpy as np


#local
from effcn.layers import Squash, CapsLayer, PrimaryCapsLayer
from effcn.functions import squash_hinton, max_norm_masking, masking
#from effcn.models_mnist import Decoder, CapsNet
#from effcn.models_multimnist import MultiMnistEcnDecoder, CapsNet
from effcn.models_smallnorb import CapsNet

In [53]:
class ConvLayer(nn.Module):
    """Stride-1 convolution followed by a ReLU.

    Args:
        in_channels: channels of the input image / feature map.
        out_channels: number of convolution filters.
        kernel_size: square kernel size (no padding, so the map shrinks by
            ``kernel_size - 1`` in each spatial dimension).
    """

    def __init__(self, in_channels=1, out_channels=256, kernel_size=9):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1)

    def forward(self, x):
        """(B, in_channels, H, W) -> (B, out_channels, H', W'), clamped at zero."""
        pre_activation = self.conv(x)
        return pre_activation.relu()


class PrimaryCaps(nn.Module):
    """Primary capsule layer built from ``num_capsules`` parallel convolutions.

    Every convolution produces ``out_channels`` feature maps. The capsule at
    (channel, row, col) is the vector of values that all ``num_capsules``
    convolutions produce at that location, so the capsule dimension equals
    ``num_capsules``.

    Args:
        num_capsules: number of parallel convolutions (= capsule dimension).
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by each convolution.
        kernel_size: kernel size of every convolution (stride 2, no padding).
        num_routes: capsules per sample; must equal ``out_channels * H' * W'``
            (32 * 6 * 6 for a 20x20 input with the default settings).
    """

    def __init__(self, num_capsules=8, in_channels=256, out_channels=32, kernel_size=9, num_routes=32 * 6 * 6):
        super(PrimaryCaps, self).__init__()
        self.num_routes = num_routes
        self.capsules = nn.ModuleList([
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=2, padding=0)
            for _ in range(num_capsules)])

    def forward(self, x):
        """Map (B, in_channels, H, W) to squashed capsules (B, num_routes, num_capsules)."""
        # Stack on the LAST axis -> (B, out_channels, H', W', num_capsules), so each
        # capsule vector holds the responses of all convolutions at one
        # (channel, row, col). Stacking on dim=1 and then viewing would instead
        # slice contiguous spatial runs of a single convolution into "capsules".
        u = torch.stack([capsule(x) for capsule in self.capsules], dim=-1)
        u = u.view(x.size(0), self.num_routes, -1)
        return squash_hinton(u)

    def squash(self, input_tensor, eps=1e-8):
        """Squash along the last axis: v * |v|^2 / ((1 + |v|^2) * |v|).

        ``eps`` keeps the division finite for all-zero vectors, which
        otherwise produce NaN (0 / 0).
        """
        squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)
        norm = torch.sqrt(squared_norm + eps)
        return squared_norm / (1. + squared_norm) * input_tensor / norm

In [54]:
# Classic CapsNet front end on an MNIST-sized input: 28x28 -> conv9 -> 20x20
# -> primary caps (k=9, s=2) -> 6x6 grid, i.e. 32 * 6 * 6 = 1152 capsules of dim 8.
conv_stage = ConvLayer()
primary_stage = PrimaryCaps()
mnist_like = torch.rand([1, 1, 28, 28])
primary_out = primary_stage(conv_stage(mnist_like))
primary_out.shape

torch.Size([1, 1152, 8])

In [18]:
class EPrimaryCaps(nn.Module):
    """
        Primary capsules from a depthwise convolution whose output is reshaped
        into N capsules of dimension D and squashed.

        Attributes
        ----------
        F: int depthwise conv number of features (input and output channels)
        K: int depthwise conv kernel dimension
        N: int number of primary capsules
        D: int primary capsules dimension (number of properties)
        s: int depthwise conv strides
    """

    def __init__(self, F, K, N, D, s=1):
        super().__init__()
        self.F = F
        self.K = K
        self.N = N
        self.D = D
        self.s = s
        # groups=F -> one KxK filter per channel (depthwise), no padding
        self.dw_conv2d = nn.Conv2d(
            F, F, kernel_size=K, stride=s, groups=F, padding="valid")
        #
        self.squash = Squash(eps=1e-20)

    def forward(self, x):
        """
        IN:  (B, F, H, W)
        OUT: (B, N, D)

        The depthwise conv output (B, F, H', W') must hold exactly N * D
        values per sample, i.e. F * H' * W' == N * D. With H = W = K and
        s = 1 the conv output is 1x1, so F must equal N * D.

        Raises
        ------
        ValueError: if the per-sample element count after the conv is not N * D.
        """
        # (B, F, H, W) -> (B, F, H', W')
        x = self.dw_conv2d(x)

        per_sample = x.shape[1:].numel()
        if per_sample != self.N * self.D:
            raise ValueError(
                f"EPrimaryCaps expects {self.N} * {self.D} = {self.N * self.D} values "
                f"per sample after the depthwise conv, got {per_sample} "
                f"(conv output shape {tuple(x.shape)})")

        # Keep the batch axis explicit: view(-1, N, D) would silently fold any
        # extra spatial positions into extra "samples".
        x = x.view(x.size(0), self.N, self.D)
        return self.squash(x)

In [20]:
# Depthwise primary caps on a 128x9x9 map: the 9x9 kernel collapses space to 1x1,
# so F must equal N * D (128 == 16 * 8).
eprim_caps = EPrimaryCaps(F=128, K=9, N=16, D=8)
eprim_input = torch.rand([1, 128, 9, 9])
eprim_output = eprim_caps(eprim_input)
eprim_output.shape

torch.Size([1, 16, 8])

In [22]:
# number of primary capsules: 32 capsule channels on a 6x6 grid
32*6*6

1152

In [94]:
class testPrimaryCapsLayer(nn.Module):
    """Primary capsules from one convolution with ``c_out * d_l`` channels.

    Parameters
    ----------
    c_in : int
        Channels of the incoming feature map.
    c_out : int
        Number of capsule channels; every spatial position of every capsule
        channel yields one capsule.
    d_l : int
        Dimension of each primary capsule.
    kernel_size, stride :
        Passed straight to the underlying ``nn.Conv2d``.
    """
    def __init__(self, c_in, c_out, d_l, kernel_size, stride):
        super().__init__()
        self.conv = nn.Conv2d(c_in, c_out * d_l, kernel_size=kernel_size, stride=stride)
        self.c_in, self.c_out, self.d_l = c_in, c_out, d_l

    def forward(self, input):
        """(B, c_in, H, W) -> squashed capsules (B, c_out * H' * W', d_l)."""
        features = self.conv(input)
        batch, _, height, width = features.size()
        # Conv channels are read as c_out groups of d_l consecutive channels.
        grouped = features.view(batch, self.c_out, self.d_l, height, width)
        # Move the capsule axis last, then flatten (c_out, H', W') into one axis.
        capsules = grouped.permute(0, 1, 3, 4, 2).reshape(batch, -1, self.d_l)
        return squash_hinton(capsules)

In [207]:
# Shape check with a deeper backbone: 36x36 -> conv9 -> 28x28 -> conv9 -> 20x20,
# so the primary-caps conv (k=9, s=2) again yields a 6x6 grid (32 * 6 * 6 capsules).
# ReLU after each conv: without a nonlinearity the two convs collapse into a single
# linear map (ConvLayer / refCapsNet also apply ReLU after their convs).
C = nn.Sequential(
        nn.Conv2d(1, 256, kernel_size=9, stride=1),
        nn.ReLU(),
        nn.Conv2d(256, 256, kernel_size=9, stride=1),
        nn.ReLU(),
)
P = testPrimaryCapsLayer(256, 32, 8, kernel_size=9, stride=2)
L = CapsLayer(32 * 6 * 6, 8, 10, 16, 3)

b = torch.rand([1, 1, 36, 36])
b = C(b)
b = P(b)
b = L(b)

b.shape


torch.Size([1, 10, 16])

In [163]:
# scratch arithmetic: 4608 / 32 = 144.0 (= 12 * 12); TODO: record where 4608 comes from
4608/32

144.0

In [41]:
class refCapsNet(nn.Module):
    """Reference CapsNet for single-channel 28x28 inputs.

    conv(9x9, 256) + ReLU -> primary capsules (32 * 6 * 6 capsules of dim 8)
    -> output capsules (n_classes capsules of dim 16).
    The hard-coded 6x6 primary grid assumes 28x28 inputs (28 -> 20 -> 6).

    Args:
        routing_iterations: iteration count passed to ``CapsLayer``.
        n_classes: number of output capsules.
    """

    def __init__(self, routing_iterations, n_classes=10):
        # Name *this* class: the old super(CapsNet, self) raised TypeError because
        # refCapsNet is not a subclass of the imported CapsNet.
        super().__init__()
        self.conv1 = nn.Conv2d(1, 256, kernel_size=9, stride=1)
        self.primaryCaps = PrimaryCapsLayer(256, 32, 8, kernel_size=9, stride=2)  # outputs 6*6
        self.num_primaryCaps = 32 * 6 * 6
        # AgreementRouting was never defined or imported here; elsewhere in this
        # notebook CapsLayer is called with an iteration count (n_iter) as its
        # 5th argument, so pass routing_iterations directly.
        self.digitCaps = CapsLayer(self.num_primaryCaps, 8, n_classes, 16, routing_iterations)

    def forward(self, input):
        """Return (output capsules, L2 length of each capsule along dim 2)."""
        x = F.relu(self.conv1(input))
        x = self.primaryCaps(x)
        x = self.digitCaps(x)
        probs = x.norm(dim=2)
        return x, probs

In [90]:
class blaCapsNet(nn.Module):
    """
        CapsNet Implementation for MNIST
        all parameters taken from the paper
    """

    def __init__(self):
        super().__init__()
        # Decoder lives in effcn.models_mnist. Its module-level import is commented
        # out in this notebook because it would also import a second CapsNet, so
        # import only the Decoder here.
        from effcn.models_mnist import Decoder

        # values from paper, are fixed!
        self.n_l = (32 * 6 * 6) # num of primary capsules
        self.d_l = 8            # dim of primary capsules
        self.n_h = 10           # num of output capsules
        self.d_h = 16           # dim of output capsules
        self.n_iter = 3

        self.backbone = nn.Conv2d(1, 256, kernel_size=9, stride=1)
        self.primcaps = PrimaryCapsLayer(c_in=256, c_out=32, d_l=self.d_l, kernel_size=9, stride=2)
        self.digitCaps = CapsLayer(self.n_l, self.d_l, self.n_h, self.d_h, self.n_iter)
        self.decoder = Decoder()

    def forward(self, x, y_true=None):
        """
            IN:
                x (b, 1, 28, 28)
            OUT:
                u_h
                    (b, n_h, d_h)
                    output caps
                x_rec
                    (b, 1, 28, 28)
                    reconstruction of x
        """
        # Conv1 is followed by a ReLU (as in the paper and in refCapsNet /
        # ConvLayer); without it the backbone is purely linear.
        u_l = self.primcaps(F.relu(self.backbone(x)))
        u_h = self.digitCaps(u_l)
        #
        u_h_masked = masking(u_h, y_true)
        x_rec = self.decoder(u_h_masked)
        return u_h, x_rec

In [216]:
# smallNORB CapsNet on a 2-channel 48x48 input; the second output has the
# input's shape (presumably a reconstruction — confirm in effcn.models_smallnorb).
norb_net = CapsNet()
norb_input = torch.rand([1, 2, 48, 48])
norb_caps, norb_recon = norb_net(norb_input)
print(norb_caps.shape, norb_recon.shape)

torch.Size([1, 5, 16]) torch.Size([1, 2, 48, 48])
