In [1]:
import math
#
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from torch.nn.modules.loss import _Loss
import numpy as np
#

#### Spread Loss

In [78]:
### my interpretation of spread loss


def spread_loss(y_pred, y_true, m):
    """Spread loss: squared hinge on the gap between target and other activations.

    For each sample the loss is ``sum_{i != t} max(0, m - (a_t - a_i)) ** 2``,
    where ``a_t`` is the activation of the target class ``t``. The per-sample
    losses are averaged over the batch.

    Args:
        y_pred: 2-D tensor of activations, indexed as ``y_pred[sample, class]``.
        y_true: integer class index per sample (one entry per row of ``y_pred``).
        m: scalar margin.

    Returns:
        Scalar tensor holding the batch-mean loss, on the device of ``y_pred``.
    """
    # Work on whatever device y_pred lives on instead of a notebook-global.
    target = y_true.to(device=y_pred.device, dtype=torch.long).view(-1, 1)

    # Activation of the target class for every sample, shape (batch, 1).
    a_t = y_pred.gather(1, target)

    # Squared hinge for every class; a_t broadcasts across the class axis.
    hinge = torch.clamp(m - (a_t - y_pred), min=0) ** 2

    # Exclude the target column by *index*. Selecting by value (y_pred != a_t)
    # would also drop any other class whose activation ties with the target,
    # which breaks the per-sample reshape.
    class_ids = torch.arange(y_pred.shape[1], device=y_pred.device).unsqueeze(0)
    is_target = class_ids == target
    loss = hinge.masked_fill(is_target, 0.0).sum(dim=1)

    return loss.mean()


In [80]:
class SpreadLoss(_Loss):
    """Spread loss with a margin interpolated between ``m_min`` and ``m_max``.

    Per sample the loss is ``sum_{i != t} max(0, margin - (a_t - a_i)) ** 2``
    with ``a_t`` the activation of the target class; the result is the mean
    over the batch.

    Args:
        device: stored on the instance as ``self.device``. ``forward`` runs on
            the device of ``y_pred``; the argument is kept so existing
            ``SpreadLoss(device)`` calls continue to work.
        m_min: margin returned by ``margin(0)``.
        m_max: margin returned by ``margin(1)``.
    """

    def __init__(self, device, m_min=0.2, m_max=0.9):
        super(SpreadLoss, self).__init__()
        self.m_min = m_min
        self.m_max = m_max
        self.device = device

    def margin(self, reps):
        """Linearly interpolate the margin: ``m_min`` at reps=0, ``m_max`` at reps=1."""
        return self.m_min + (self.m_max - self.m_min)*reps

    def forward(self, y_pred, y_true, reps):
        """Compute the batch-mean spread loss.

        Args:
            y_pred: 2-D tensor of activations, indexed ``y_pred[sample, class]``.
            y_true: integer class index per sample.
            reps: interpolation factor passed to :meth:`margin`.

        Returns:
            Scalar tensor with the mean loss over the batch.
        """
        ma = self.margin(reps)
        target = y_true.to(device=y_pred.device, dtype=torch.long).view(-1, 1)

        # Target-class activation per sample, shape (batch, 1).
        a_t = y_pred.gather(1, target)

        # Squared hinge against every class; a_t broadcasts over the class axis.
        hinge = torch.clamp(ma - (a_t - y_pred), min=0) ** 2

        # Drop the target column by index. Masking by value (y_pred != a_t)
        # would also remove non-target classes that tie with the target and
        # make the per-sample reshape fail.
        class_ids = torch.arange(y_pred.shape[1], device=y_pred.device).unsqueeze(0)
        loss = hinge.masked_fill(class_ids == target, 0.0).sum(dim=1)

        # mean over batch
        return loss.mean()

In [101]:

# Use the GPU when one is available, otherwise fall back to the CPU so the
# cell still runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(0)
bs = 8
# NOTE(review): randint's upper bound is exclusive, so class 9 of the 10
# columns in y_pred is never drawn as a target here -- confirm this is intended.
y_true = torch.randint(0, 9, (bs,), requires_grad=False).to(device)
y_pred = torch.rand(bs,10, requires_grad=True).to(device)
spread_loss(y_pred, y_true, 0.2)

#y_true.unsqueeze(1).repeat(1,4)

tensor(2.0286, device='cuda:0', grad_fn=<MeanBackward0>)

In [102]:
A = SpreadLoss(device)
# reps=0 makes margin() return m_min (0.2 by default). Call the module itself
# rather than .forward() so that any registered hooks are run.
loss = A(y_pred, y_true, 0)
loss

# same result as spread_loss above, and as in the git reference code ("gitstuff")

tensor(2.0286, device='cuda:0', grad_fn=<MeanBackward0>)

#### CapsNetEM

In [104]:
class CapsNetEM(nn.Module):
    """
    Capsule network with EM routing (this constructor builds the first
    convolutional stage only).

    Args:
        A: output channels of the initial plain convolution
        B: output channels of primary caps
        C: output channels of 1st conv caps
        D: output channels of 2nd conv caps
        K: kernel of conv caps
        P: size of square pose matrix
        iter: number of EM iterations

    Only ``A`` is consumed by the constructor as written; the remaining
    arguments are accepted but not used yet. There is no parameter for the
    number of classes.
    """

    def __init__(self, A=32, B=32, C=32, D=32, K=3, P=4, iter=3):
        super(CapsNetEM, self).__init__()

        # 5x5, stride-2 convolution over a single-channel input, no padding.
        stem_conv = nn.Conv2d(1, A, kernel_size=(5, 5), stride=2, padding="valid")
        stem_act = nn.ReLU(inplace=True)
        stem_norm = nn.BatchNorm2d(A, eps=0.001, momentum=0.1, affine=True)

        # Order is conv -> ReLU -> batch norm.
        self.conv1 = nn.Sequential(stem_conv, stem_act, stem_norm)



In [107]:
# Instantiate with the default arguments; as the cell's last expression the
# module's repr is displayed.
CapsNetEM()

CapsNetEM()