In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append("./../..")

In [3]:
import math
#
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision.datasets as datasets
from torch.nn.modules.loss import _Loss
import numpy as np
#local
from misc.utils import count_parameters

#### Spread Loss

In [4]:
### my interprtation off spread loss


def spread_loss(y_pred, y_true, m):

    at = torch.zeros(y_true.shape).to(device)
    zr = torch.zeros((y_pred.shape[0],y_pred.shape[1]-1)).to(device)

    #create at
    for i, cl in enumerate(y_true):
        at[i] = y_pred[i][cl]
    
    at = at.unsqueeze(1).repeat(1,y_pred.shape[1])
    ai = y_pred[y_pred!=at].view(y_pred.shape[0],-1)

    loss = ((torch.max( m-(at[:,:-1] - ai), zr))**2).sum(dim=1)

    return loss.mean()


In [5]:
class SpreadLoss(_Loss):

    def __init__(self, device, m_min=0.2, m_max=0.9):
        super(SpreadLoss, self).__init__()
        self.m_min = m_min
        self.m_max = m_max
        self.device = device

    def margin(self, reps):
        return self.m_min + (self.m_max - self.m_min)*reps

    def forward(self, y_pred, y_true, reps):
        at = torch.zeros(y_true.shape).to(self.device)
        zr = torch.zeros((y_pred.shape[0],y_pred.shape[1]-1)).to(self.device)
        ma = self.margin(reps)

        #create at
        for i, cl in enumerate(y_true):
            at[i] = y_pred[i][cl]
        
        at = at.unsqueeze(1).repeat(1,y_pred.shape[1])
        ai = y_pred[y_pred!=at].view(y_pred.shape[0],-1)

        loss = ((torch.max( ma - (at[:,:-1] - ai), zr))**2).sum(dim=1)

        # mean over batch
        return loss.mean()

In [6]:

device = torch.device("cuda")
torch.manual_seed(0)
bs = 8
y_true = torch.randint(0, 9, (bs,), requires_grad=False).to(device)
y_pred = torch.rand(bs,10, requires_grad=True).to(device)
spread_loss(y_pred, y_true, 0.2)

#y_true.unsqueeze(1).repeat(1,4)

tensor(2.0286, device='cuda:0', grad_fn=<MeanBackward0>)

In [7]:
A = SpreadLoss(device)
A.margin(0)
loss = A.forward(y_pred, y_true, 0)

#loss.backward()
loss

#same result as in gitstuff

tensor(2.0286, device='cuda:0', grad_fn=<MeanBackward0>)

#### CapsNetEM

In [8]:
ds_train = datasets.MNIST(root='../../data', train=True, download=True, transform=T.ToTensor())

dl_train = torch.utils.data.DataLoader(ds_train, 
                                        batch_size=1, 
                                        shuffle=False,
                                        num_workers=2)              

In [9]:
x, y = next(iter(dl_train))

x.shape,y

(torch.Size([1, 1, 28, 28]), tensor([5]))

In [10]:
class CapsNetEM(nn.Module):
    """
    Genrate CapsNet with EM routing
    Args:
        A: output channels of normal conv
        B: output channels of primary caps
        C: output channels of 1st conv caps
        D: output channels of 2nd conv caps
        E: output channels of class caps (i.e. number of classes)
        K: kernel of conv caps
        P: size of square pose matrix
        iters: number of EM iterations
        ...

        input: (bs, 1, 28, 28)
    """

    def __init__(self, A=32, B=32, C=32, D=32,E=10, K=3, P=4, iter=3):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=A, kernel_size=(5, 5), stride=2, padding=2),
            nn.ReLU(inplace=True),
            #nn.BatchNorm2d(num_features=A),
        )
        self.prim_caps0 = PrimaryCaps(ch_in=A, ch_out=B, K=1, P=P, stride=1, padding="valid")
        self.conv_caps1 = ConvCaps(ch_in=B, ch_out=C, K=K, P=P, stride=2, iter=iter, class_caps=False)
        self.conv_caps2 = ConvCaps(ch_in=C, ch_out=B, K=K, P=P, stride=1, iter=iter, class_caps=False)
        self.class_caps = ConvCaps(ch_in=D, ch_out=E, K=1, P=P, stride=1, iter=iter, class_caps=True)

    def forward(self, x):
        x = self.conv1(x)
        x = self.prim_caps0(x)
        x = self.conv_caps1(x)
        #x = self.conv_caps2(x)
        #x = self.class_caps(x)
        return x

In [11]:
class PrimaryCaps(nn.Module):
    """
    Args:
        A: output of the normal conv layer
        B: number of types of capsules
        K: kernel size of convolution
        P: size of pose matrix is P*P
        stride: stride of convolution
    Shape:
        input:  (*, A, h, w)                (bs, 32, 14, 14)
        output: p -> (*,B, h', w', P, P)    (bs, 32, 14, 14, 4, 4)
                a -> (*,B, h', w')          (bs, 32, 14, 14)
        h', w' is computed the same way as convolution layer
        parameter size is: K*K*A*B*P*P + B*P*P
    """

    def __init__(self, ch_in=32, ch_out=32, K=1, P=4, stride=1, padding="valid"):
        super().__init__()
        self.pose = nn.Conv2d(in_channels=ch_in, out_channels=ch_out*P*P, kernel_size=K, stride=stride, bias=True)
        self.acti = nn.Sequential(
            nn.Conv2d(in_channels=ch_in, out_channels=ch_out, kernel_size=K, stride=stride, bias=True),
            nn.Sigmoid()
        )
        self.P = P

    def forward(self, x):
        p = self.pose(x)
        a = self.acti(x)
        p = p.view(p.shape[0],-1,p.shape[2],p.shape[3],self.P,self.P)

        return p, a

In [12]:
class ConvCaps(nn.Module):
    """Create a convolutional capsule layer
    that transfer capsule layer L to capsule layer L+1
    by EM routing.
    Args:
        B: input number of types of capsules
        C: output number on types of capsules
        K: kernel size of convolution
        P: size of pose matrix is P*P
        stride: stride of convolution
        iters: number of EM iterations
        coor_add: use scaled coordinate addition or not
        w_shared: share transformation matrix across w*h.
    Shape:
        input:  (*,B, h,  w, P, P)      (bs, 32, 14, 14, 4, 4)
                (*,B, h,  w, 1)         (bs, 32, 14, 14)
        output: (*,C, h,  w, P, P)      (bs, 32, 6, 6, 4, 4)
                (*,C, h,  w, 1)         (bs, 32, 6, 6)
        h', w' is computed the same way as convolution layer
        parameter size is: K*K*B*C*P*P + B*P*P
    """   

    def __init__(self, ch_in=32, ch_out=32, K=3, P=4, stride=2, iter=3, class_caps=False):
        super().__init__()
        # init vars
        self.ch_in  = ch_in
        self.ch_out = ch_out
        self.K = K
        self.P = P
        self.psize = P*P
        self.stride = stride
        self.iter = iter
        self.class_caps = class_caps

        # constants
            #actualy none
        
        # params
        self.b_u = nn.Parameter(torch.zeros(ch_out))
        self.b_a = nn.Parameter(torch.zeros(ch_out))

        #deepwiese conv for single caps poses M to generate trained W
        self.conv_deep = nn.Conv2d(in_channels=(ch_in*P*P), out_channels=(ch_in*ch_out*P*P), kernel_size=K, stride=stride,
                        padding=0, groups=ch_in*ch_out)

        # activations
        self.sigmoid = nn.Sigmoid()
        self.softmax = nn.Softmax(dim=2)

    def forward(self, x):
        #split pose and activation
        p, a = x

        #deepwise conv of poses to get votes v
        v = p.view(p.shape[0],-1,p.shape[2],p.shape[3])
        v = self.conv_deep(v)

        v = v.view(v.shape[0],-1,v.shape[2],v.shape[3],self.P,self.P)
        return v, a

    def em_routing(self):
        pass

    def e_step(self):
        pass

    def m_step(self):
        pass

#### Test Stuff

In [13]:
B = CapsNetEM()
print(count_parameters(B))
z, b = B(x)

z.shape, b.shape




ValueError: in_channels must be divisible by groups

In [14]:
def add_pathes(x, B, K, psize, stride):
    """
        Shape:
            Input:     (b, H, W, B*(P*P+1))
            Output:    (b, H', W', K, K, B*(P*P+1))
    """
    b, h, w, c = x.shape
    assert h == w
    assert c == B*(psize+1)
    oh = ow = int(((h - K )/stride)+ 1) # moein - changed from: oh = ow = int((h - K + 1) / stride)
    idxs = [[(h_idx + k_idx) \
            for k_idx in range(0, K)] \
            for h_idx in range(0, h - K + 1, stride)]
    print(idxs)
    print(x.shape)
    x = x[:, idxs, :, :]
    print(x.shape)
    x = x[:, :, :, idxs, :]
    print(x.shape)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    return x, oh, ow

In [15]:
q = torch.randn(1, 14, 14,32*17)
B = 32
K = 3
psize = 16
stride = 2

q, oh, ow = add_pathes(x=q, B=B, K=K, psize=psize, stride=stride)

q.shape, oh, ow

[[0, 1, 2], [2, 3, 4], [4, 5, 6], [6, 7, 8], [8, 9, 10], [10, 11, 12]]
torch.Size([1, 14, 14, 544])
torch.Size([1, 6, 3, 14, 544])
torch.Size([1, 6, 3, 6, 3, 544])


(torch.Size([1, 6, 6, 3, 3, 544]), 6, 6)

In [16]:
r = z.view(z.shape[0],-1,z.shape[2],z.shape[3])
r.shape



NameError: name 'z' is not defined

In [18]:
c_out = 16 * 4 * 4
CV =nn.Conv2d(in_channels= r.shape[1], out_channels=c_out, kernel_size=(3,3), stride=2, padding=0,
 dilation=1, groups=32, bias=True, padding_mode='zeros', device=None, dtype=None)


print(count_parameters(CV))
cv = CV(r).view(1,-1,6,6,4,4).shape



NameError: name 'r' is not defined

In [17]:
#test view
dd = torch.arange(0,512,1)
dd = dd.reshape(-1,4,4)
dd.view(-1)

cc = dd.clone()
#cc =cc.reshape(4,4,-1)
cc = cc.permute(1,2,0)
cc.reshape(-1)

tensor([  0,  16,  32,  48,  64,  80,  96, 112, 128, 144, 160, 176, 192, 208,
        224, 240, 256, 272, 288, 304, 320, 336, 352, 368, 384, 400, 416, 432,
        448, 464, 480, 496,   1,  17,  33,  49,  65,  81,  97, 113, 129, 145,
        161, 177, 193, 209, 225, 241, 257, 273, 289, 305, 321, 337, 353, 369,
        385, 401, 417, 433, 449, 465, 481, 497,   2,  18,  34,  50,  66,  82,
         98, 114, 130, 146, 162, 178, 194, 210, 226, 242, 258, 274, 290, 306,
        322, 338, 354, 370, 386, 402, 418, 434, 450, 466, 482, 498,   3,  19,
         35,  51,  67,  83,  99, 115, 131, 147, 163, 179, 195, 211, 227, 243,
        259, 275, 291, 307, 323, 339, 355, 371, 387, 403, 419, 435, 451, 467,
        483, 499,   4,  20,  36,  52,  68,  84, 100, 116, 132, 148, 164, 180,
        196, 212, 228, 244, 260, 276, 292, 308, 324, 340, 356, 372, 388, 404,
        420, 436, 452, 468, 484, 500,   5,  21,  37,  53,  69,  85, 101, 117,
        133, 149, 165, 181, 197, 213, 229, 245, 261, 277, 293, 3

In [43]:
#test view
dd = torch.rand(1,32,4,4,6,6)
dd = dd.reshape(1,32,4)
#dd.view(-1)
dd.shape
#cc = dd.clone()
#cc =cc.reshape(4,4,-1)
#cc = cc.permute(1,2,0)
#cc.reshape(-1)

torch.Size([1, 32, 4, 4, 6, 6])

In [None]:
2359808
147968
74240

4609

print(32*32*17)
print(9*32*32*16)
print(9*32*32*16)
print(32*10*16)
print(32*32*17 + 9*32*32*16 + 9*32*32*16+ 32*10*16)

17408
147456
147456
5120
317440


In [None]:
q = torch.randn(24, 288, 32, 4, 4)
j = torch.randn(24, 288, 32, 4, 4)
torch.matmul(q,j).shape


torch.Size([24, 288, 32, 4, 4])

In [65]:
torch.manual_seed(0)
nb_channels = 10
h, w = 14,14
k = 3
x = torch.randn(5, nb_channels, h, w)
weights = torch.tensor([[0., 0., 0.],
                        [0., 1., 0.],
                        [0., 0., 0.]])
weights = (torch.ones(1,1,k,k)/k**2).repeat(2, 5, 1, 1)
#weights = weights.view(1, 1, 3, 3).repeat(1, nb_channels, 1, 1)

output = F.conv2d(x, weights,groups=2)

output.shape

torch.Size([5, 2, 12, 12])

In [None]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [104]:
def fix_conv(x):
    D = nn.Conv2d(in_channels=nb_channels, out_channels=out_channels, kernel_size=k, stride=1, bias=False)
    D.weight = torch.nn.Parameter((torch.ones_like(D.weight)/k**2),requires_grad=False)
    
    x = x.view(bs*p*p, nb_channels,h, w)
    x = D(x)
    x = x.view(-1, x.shape[1], x.shape[2], x.shape[3], p,p)

    return x


In [156]:
torch.manual_seed(0)
nb_channels = 32
out_channels = 32
h, w = 14,14
p = 4
k = 3
bs = 1
x = torch.randn(bs, nb_channels, p, p , h, w)
C = 32

#D = nn.Conv2d(in_channels=nb_channels, out_channels=out_channels, kernel_size=k, stride=1, bias=False)
#D.weight = torch.nn.Parameter((torch.ones_like(D.weight)/k**2),requires_grad=False)


#x = x.view(bs*p*p, nb_channels,h, w)
#x = x.view(bs, nb_channels, p, p , h, w)
#x.shape
#D.weight.shape , D.weight.requires_grad, D.weight.mean(), torch.numel(D.weight)

#x = D(x)
#x = x.view(-1, x.shape[1], x.shape[2], x.shape[3], p,p)

x = fix_conv(x)


x.shape


w = torch.rand([1, 32, 12, 12, 4, 4])
#w.shape

v = torch.mul(x, w)

v.shape, v.numel(), w.numel()#*3+ count_parameters(PrimaryCaps())

(torch.Size([1, 32, 12, 12, 4, 4]), 73728, 73728)

In [128]:
a = torch.rand(10,1).repeat(1,4)
b = torch.rand(10,4)

c = torch.mul(a,b)
c.shape

torch.Size([10, 4])

In [139]:
    def transform_view(x, w, C, P, w_shared=False):
        """
            For conv_caps:
                Input:     (b*H*W, K*K*B, P*P)
                Output:    (b*H*W, K*K*B, C, P*P)
            For class_caps:
                Input:     (b, H*W*B, P*P)
                Output:    (b, H*W*B, C, P*P)
        """
        b, B, psize = x.shape
        assert psize == P*P

        x = x.view(b, B, 1, P, P)
        if w_shared:
            hw = int(B / w.size(1))
            w = w.repeat(1, hw, 1, 1, 1)

        w = w.repeat(b, 1, 1, 1, 1)
        x = x.repeat(1, 1, C, 1, 1)
        v = torch.matmul(x, w)
        v = v.view(b, B, C, P*P)
        return v

In [145]:
x = torch.rand(1*12*12,3*3*32,4*4)
w = torch.rand(1,3*3*32,32,4,4)
t = transform_view(x=x, w=w, C=32, P=4)

t.shape, t.numel(), w.numel()

(torch.Size([144, 288, 32, 16]), 21233664, 147456)