In [2]:
import math
#
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision.datasets as datasets
from torch.nn.modules.loss import _Loss
import numpy as np
#

#### Spread Loss

In [3]:
### my interprtation off spread loss


def spread_loss(y_pred, y_true, m):

    at = torch.zeros(y_true.shape).to(device)
    zr = torch.zeros((y_pred.shape[0],y_pred.shape[1]-1)).to(device)

    #create at
    for i, cl in enumerate(y_true):
        at[i] = y_pred[i][cl]
    
    at = at.unsqueeze(1).repeat(1,y_pred.shape[1])
    ai = y_pred[y_pred!=at].view(y_pred.shape[0],-1)

    loss = ((torch.max( m-(at[:,:-1] - ai), zr))**2).sum(dim=1)

    return loss.mean()


In [4]:
class SpreadLoss(_Loss):

    def __init__(self, device, m_min=0.2, m_max=0.9):
        super(SpreadLoss, self).__init__()
        self.m_min = m_min
        self.m_max = m_max
        self.device = device

    def margin(self, reps):
        return self.m_min + (self.m_max - self.m_min)*reps

    def forward(self, y_pred, y_true, reps):
        at = torch.zeros(y_true.shape).to(self.device)
        zr = torch.zeros((y_pred.shape[0],y_pred.shape[1]-1)).to(self.device)
        ma = self.margin(reps)

        #create at
        for i, cl in enumerate(y_true):
            at[i] = y_pred[i][cl]
        
        at = at.unsqueeze(1).repeat(1,y_pred.shape[1])
        ai = y_pred[y_pred!=at].view(y_pred.shape[0],-1)

        loss = ((torch.max( ma - (at[:,:-1] - ai), zr))**2).sum(dim=1)

        # mean over batch
        return loss.mean()

In [5]:

device = torch.device("cuda")
torch.manual_seed(0)
bs = 8
y_true = torch.randint(0, 9, (bs,), requires_grad=False).to(device)
y_pred = torch.rand(bs,10, requires_grad=True).to(device)
spread_loss(y_pred, y_true, 0.2)

#y_true.unsqueeze(1).repeat(1,4)

tensor(2.0286, device='cuda:0', grad_fn=<MeanBackward0>)

In [6]:
A = SpreadLoss(device)
A.margin(0)
loss = A.forward(y_pred, y_true, 0)

#loss.backward()
loss

#same result as in gitstuff

tensor(2.0286, device='cuda:0', grad_fn=<MeanBackward0>)

#### CapsNetEM

In [7]:
ds_train = datasets.MNIST(root='../../data', train=True, download=True, transform=T.ToTensor())

dl_train = torch.utils.data.DataLoader(ds_train, 
                                        batch_size=1, 
                                        shuffle=False,
                                        num_workers=2)              

In [8]:
x, y = next(iter(dl_train))

x.shape,y

(torch.Size([1, 1, 28, 28]), tensor([5]))

In [53]:
class CapsNetEM(nn.Module):
    """
    Genrate CapsNet with EM routing
    Args:
        A: output channels of normal conv
        B: output channels of primary caps
        C: output channels of 1st conv caps
        D: output channels of 2nd conv caps
        E: output channels of class caps (i.e. number of classes)
        K: kernel of conv caps
        P: size of square pose matrix
        iters: number of EM iterations
        ...

        input: (bs, 1, 28, 28)
    """

    def __init__(self, A=32, B=32, C=32, D=32, K=3, P=4, iter=3):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=A, kernel_size=(5, 5), stride=2, padding=2),
            nn.ReLU(inplace=True),
            #nn.BatchNorm2d(num_features=A),
        )
        self.prim_caps0 = PrimaryCaps(ch_in=A, ch_out=B, K=1, P=P, stride=1, padding="valid")
        self.conv_caps1 = ConvCaps(ch_in=B, ch_out=C, K=K, P=P, stride=2, iter=iter, class_caps=False)

    def forward(self, x):
        x = self.conv1(x)
        x = self.prim_caps0(x)
        x = self.conv_caps1(x)

        return x

In [52]:
class PrimaryCaps(nn.Module):
    """
    Args:
        A: output of the normal conv layer
        B: number of types of capsules
        K: kernel size of convolution
        P: size of pose matrix is P*P
        stride: stride of convolution
    Shape:
        input:  (*, A, h, w)                (bs, 32, 14, 14)
        output: p -> (*,B, h', w', P, P)    (bs, 32, 14, 14, 4, 4)
                a -> (*,B, h', w')          (bs, 32, 14, 14)
        h', w' is computed the same way as convolution layer
        parameter size is: K*K*A*B*P*P + B*P*P
    """

    def __init__(self, ch_in=32, ch_out=32, K=1, P=4, stride=1, padding="valid"):
        super().__init__()
        self.pose = nn.Conv2d(in_channels=ch_in, out_channels=ch_out*P*P, kernel_size=K, stride=stride, bias=True)
        self.acti = nn.Sequential(
            nn.Conv2d(in_channels=ch_in, out_channels=ch_out, kernel_size=K, stride=stride, bias=True),
            nn.Sigmoid()
        )

    def forward(self, x):
        p = self.pose(x)
        a = self.acti(x)
        p = p.view(p.shape[0],-1,p.shape[2],p.shape[3],4,4)

        return p, a

In [54]:
class ConvCaps(nn.Module):
    """Create a convolutional capsule layer
    that transfer capsule layer L to capsule layer L+1
    by EM routing.
    Args:
        B: input number of types of capsules
        C: output number on types of capsules
        K: kernel size of convolution
        P: size of pose matrix is P*P
        stride: stride of convolution
        iters: number of EM iterations
        coor_add: use scaled coordinate addition or not
        w_shared: share transformation matrix across w*h.
    Shape:
        input:  (*,B, h,  w, P, P)      (bs, 32, 14, 14, 4, 4)
                (*,B, h,  w, 1)         (bs, 32, 14, 14)
        output: (*,C, h,  w, P, P)      (bs, 32, 6, 6, 4, 4)
                (*,C, h,  w, 1)         (bs, 32, 6, 6)
        h', w' is computed the same way as convolution layer
        parameter size is: K*K*B*C*P*P + B*P*P
    """   

    def __init__(self, ch_in=32, ch_out=32, K=3, P=4, stride=2, iter=3, class_caps=False):
        super().__init__()
        # init vars
        self.ch_in  = ch_in
        self.ch_out = ch_out
        self.K = K
        self.P = P
        self.stride = stride
        self.iter = iter
        self.class_caps = class_caps

        # constants
            #actualy none
        
        # params
        self.b_u = nn.Parameter(torch.zeros(ch_out))
        self.b_a = nn.Parameter(torch.zeros(ch_out))
        self.M   = nn.Parameter(torch.randn(1, K*K*ch_in, ch_out, P, P))

        # activations
        self.sigmoid = nn.Sigmoid()
        self.softmax = nn.Softmax(dim=2)

    def forward(self, x):
        #split pose and activation
        p, a = x

        print(self.M.shape)


        return p, a

    def em_routing(self):
        pass

    def e_step(self):
        pass

    def m_step(self):
        pass

In [55]:
B = CapsNetEM()

z, b = B(x)

z.shape, b.shape



torch.Size([1, 288, 32, 4, 4])


(torch.Size([1, 32, 14, 14, 4, 4]), torch.Size([1, 32, 14, 14]))