In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append("./../..")

In [3]:
import math
#
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision.datasets as datasets
from torch.nn.modules.loss import _Loss
import numpy as np
#local
from misc.utils import count_parameters

#### Spread Loss

In [123]:
### my interpretation of spread loss


def spread_loss(y_pred, y_true, m, device):
    """Spread loss over a batch of class activations.

    With a_t the activation of the target class and a_i the activation of
    every other class, the per-sample loss is

        L = sum_{i != t} max(0, m - (a_t - a_i))^2

    and the result is the mean of L over the batch.

    Args:
        y_pred: float tensor of shape (bs, n_classes) with class activations.
        y_true: integer tensor of shape (bs,) with the target class indices.
        m: scalar margin.
        device: device on which the loss is computed.

    Returns:
        0-dim tensor holding the batch mean of the spread loss.
    """
    y_pred = y_pred.to(device)
    target = y_true.to(device).long().view(-1, 1)

    # activation of the target class, one per row: (bs, 1)
    at = y_pred.gather(1, target)

    # squared hinge for every class, the target column is removed below
    hinge = torch.clamp(m - (at - y_pred), min=0) ** 2

    # Mask the target column by INDEX. Selecting by value (y_pred != a_t)
    # would also drop non-target classes that happen to tie with a_t.
    class_idx = torch.arange(y_pred.shape[1], device=y_pred.device).unsqueeze(0)
    is_target = class_idx == target
    loss = hinge.masked_fill(is_target, 0.0).sum(dim=1)

    return loss.mean()


In [124]:
class SpreadLoss(_Loss):
    """Spread loss with a margin that is interpolated between two bounds.

    Per sample, with a_t the target-class activation and a_i the others:

        L = sum_{i != t} max(0, margin - (a_t - a_i))^2

    Args:
        device: device on which the loss is computed.
        m_min: margin returned for ``reps == 0``.
        m_max: margin returned for ``reps == 1``.
    """

    def __init__(self, device, m_min=0.2, m_max=0.9):
        super(SpreadLoss, self).__init__()
        self.m_min = m_min
        self.m_max = m_max
        self.device = device

    def margin(self, reps):
        """Linearly interpolate the margin: m_min at reps=0, m_max at reps=1."""
        return self.m_min + (self.m_max - self.m_min)*reps

    def forward(self, y_pred, y_true, reps):
        """Compute the batch-mean spread loss.

        Args:
            y_pred: float tensor (bs, n_classes) with class activations.
            y_true: integer tensor (bs,) with target class indices.
            reps: interpolation factor passed to :meth:`margin`.

        Returns:
            0-dim tensor with the mean loss over the batch.
        """
        y_pred = y_pred.to(self.device)
        target = y_true.to(self.device).long().view(-1, 1)
        ma = self.margin(reps)

        # activation of the target class, one per row: (bs, 1)
        at = y_pred.gather(1, target)

        # squared hinge for every class, the target column is removed below
        hinge = torch.clamp(ma - (at - y_pred), min=0) ** 2

        # Mask the target column by INDEX. Selecting by value (y_pred != a_t)
        # would also drop non-target classes that happen to tie with a_t.
        class_idx = torch.arange(y_pred.shape[1], device=y_pred.device).unsqueeze(0)
        is_target = class_idx == target
        loss = hinge.masked_fill(is_target, 0.0).sum(dim=1)

        # mean over batch
        return loss.mean()

In [125]:

# Prefer the GPU, but fall back to the CPU so the cell also runs without CUDA.
loss_dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(0)
bs = 8
# NOTE(review): randint's upper bound is exclusive, so class 9 is never drawn
# although y_pred has 10 classes — confirm whether that is intended.
y_true = torch.randint(0, 9, (bs,), requires_grad=False).to(loss_dev)
y_pred = torch.rand(bs,10, requires_grad=True).to(loss_dev)
spread_loss(y_pred, y_true, 0.2, loss_dev)

#y_true.unsqueeze(1).repeat(1,4)

tensor(2.0286, device='cuda:0', grad_fn=<MeanBackward0>)

In [126]:
A = SpreadLoss(loss_dev)
# Call the module itself instead of .forward() so nn.Module hooks are honoured.
# reps=0 selects the minimum margin (m_min=0.2), the same margin passed to
# spread_loss in the previous cell.
loss = A(y_pred, y_true, 0)

#loss.backward()
loss

#same result as in gitstuff

tensor(2.0286, device='cuda:0', grad_fn=<MeanBackward0>)

#### CapsNetEM

In [18]:
# MNIST training split as tensors; downloaded to ../../data when missing.
ds_train = datasets.MNIST(root='../../data', train=True, download=True, transform=T.ToTensor())

# Unshuffled batches of 3, so the first batch is always the same samples.
dl_train = torch.utils.data.DataLoader(ds_train, 
                                        batch_size=3, 
                                        shuffle=False,
                                        num_workers=2)              

In [19]:
# Take the first batch from the loader: x are the images, y the labels.
x, y = next(iter(dl_train))

x.shape,y

(torch.Size([3, 1, 28, 28]), tensor([5, 0, 4]))

In [10]:
class CapsNetEM(nn.Module):
    """
    Generate CapsNet with EM routing
    Args:
        A: output channels of normal conv
        B: output channels of primary caps
        C: output channels of 1st conv caps
        D: output channels of 2nd conv caps
        E: output channels of class caps (i.e. number of classes)
        K: kernel of conv caps
        P: size of square pose matrix
        iter: number of EM iterations
        hw_out: (height, width) of the network input, used to pre-compute
            the spatial size of every capsule layer

        input: (bs, 1, 28, 28)
    """

    def __init__(self, A=32, B=32, C=32, D=32,E=10, K=3, P=4, iter=3, hw_out=(28,28)):
        super().__init__()
        # hw_out is updated layer by layer so each ConvCaps knows its output size
        hw_out = self.hw_cal(hw_out, kernel=5, padding=2, dilatation=1, stride=2)
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=A, kernel_size=(5, 5), stride=2, padding=2),
            nn.ReLU(inplace=True),
            #nn.BatchNorm2d(num_features=A),
        )
        hw_out = self.hw_cal(hw_out, kernel=1, padding=0, dilatation=1, stride=1)
        self.prime_caps = PrimaryCaps(ch_in=A, ch_out=B, K=1, P=P, stride=1, padding="valid")
        #
        hw_out = self.hw_cal(hw_out, kernel=K, padding=0, dilatation=1, stride=2)
        self.conv_caps1 = ConvCaps(ch_in=B, ch_out=C, K=K, P=P, stride=2, iter=iter, hw_out=hw_out, class_caps=False)
        #
        hw_out = self.hw_cal(hw_out, kernel=K, padding=0, dilatation=1, stride=1)
        # 2nd conv caps must emit D capsule types: that is what class_caps consumes
        # (previously ch_out=B, which only worked because B == D by default)
        self.conv_caps2 = ConvCaps(ch_in=C, ch_out=D, K=K, P=P, stride=1, iter=iter, hw_out=hw_out, class_caps=False)
        #
        # class caps use K=1 / stride=1, so the spatial size stays unchanged
        self.class_caps = ConvCaps(ch_in=D, ch_out=E, K=1, P=P, stride=1, iter=iter, hw_out=hw_out, class_caps=True)

    def hw_cal(self, hw_in, kernel, padding=0, dilatation=1, stride=1):
        """Output size of a convolution (same formula as nn.Conv2d).

        Args:
            hw_in: input size, either a single int or a (height, width) pair
                given as tuple or list.
            kernel, padding, dilatation, stride: convolution hyper-parameters.

        Returns:
            An int for int input, otherwise a (height, width) tuple.

        Raises:
            TypeError: if ``hw_in`` is neither an int nor a tuple/list.
        """
        def _out(size):
            return math.floor((size + 2*padding - dilatation * (kernel - 1) - 1) / stride + 1)

        if isinstance(hw_in, int):
            return _out(hw_in)
        if isinstance(hw_in, (tuple, list)):
            return (_out(hw_in[0]), _out(hw_in[1]))
        raise TypeError("hw_in must be an int or a (height, width) pair, got {}".format(type(hw_in).__name__))


    def forward(self, x):
        """Run the image batch through conv -> primary caps -> conv caps -> class caps.

        Returns the (pose, activation) pair produced by the class-capsule layer.
        """
        x = self.conv1(x)
        x = self.prime_caps(x)
        x = self.conv_caps1(x)
        x = self.conv_caps2(x)
        x = self.class_caps(x)
        return x

In [11]:
class PrimaryCaps(nn.Module):
    """Primary capsule layer: two parallel convolutions produce poses and activations.

    Args:
        ch_in: channels of the incoming feature map (output of the normal conv)
        ch_out: number of capsule types
        K: kernel size of convolution
        P: size of pose matrix is P*P
        stride: stride of convolution
        padding: padding of both convolutions; anything nn.Conv2d accepts
            (int, tuple, "valid" or "same"). The default "valid" means no padding.
    Shape:
        input:  (*, ch_in, h, w)                 (bs, 32, 14, 14)
        output: p -> (*, ch_out, h', w', P, P)   (bs, 32, 14, 14, 4, 4)
                a -> (*, ch_out, h', w')         (bs, 32, 14, 14)
        h', w' is computed the same way as convolution layer
        parameter size is: K*K*ch_in*ch_out*P*P + ch_out*P*P   (pose conv)
                         + K*K*ch_in*ch_out + ch_out           (activation conv)
    """

    def __init__(self, ch_in=32, ch_out=32, K=1, P=4, stride=1, padding="valid"):
        super().__init__()
        # `padding` is forwarded to both convolutions (it used to be silently ignored)
        self.pose = nn.Conv2d(in_channels=ch_in, out_channels=ch_out*P*P, kernel_size=K, stride=stride, padding=padding, bias=True)
        self.acti = nn.Sequential(
            nn.Conv2d(in_channels=ch_in, out_channels=ch_out, kernel_size=K, stride=stride, padding=padding, bias=True),
            nn.Sigmoid()
        )
        self.P = P

    def forward(self, x):
        """Return ``(p, a)``: poses (bs, ch_out, h', w', P, P) and activations (bs, ch_out, h', w')."""
        p = self.pose(x)
        a = self.acti(x)
        # NOTE(review): the conv output is laid out as (bs, ch_out*P*P, h', w') and is
        # reinterpreted here with a plain view (no permute). ConvCaps.voting applies
        # another plain view to this tensor; confirm the resulting layout is intended.
        p = p.view(p.shape[0],-1,p.shape[2],p.shape[3],self.P,self.P)

        return p, a

In [119]:
class ConvCaps(nn.Module):
    r"""Create a convolutional capsule layer
    that transfers capsule layer L to capsule layer L+1
    by EM routing.
    Args:
        ch_in: input number of types of capsules
        ch_out: output number of types of capsules
        K: kernel size of the fixed averaging convolution
        P: size of pose matrix is P*P
        stride: stride of the fixed averaging convolution
        iter: number of EM iterations (must be >= 1)
        hw_out: (height, width) of this layer's output; sizes the vote weights
        final_lambda: scale of the inverse-temperature schedule
        class_caps: if True the layer acts as class-capsule layer
            (coordinate addition, routing over all positions at once)
    Shape:
        input:  poses        (bs, ch_in, h, w, P, P)      (bs, 32, 14, 14, 4, 4)
                activations  (bs, ch_in, h, w)            (bs, 32, 14, 14)
        output: poses        (bs, ch_out, h', w', P, P)   (bs, 32, 6, 6, 4, 4)
                activations  (bs, ch_out, h', w')         (bs, 32, 6, 6)
                with class_caps=True: poses (bs, ch_out, P, P), activations (bs, ch_out)
        h', w' is computed the same way as convolution layer
        trainable parameters: ch_in*h'*w'*P*P*ch_out (votes) + 2*ch_out (b_u, b_a)
    """

    def __init__(self, ch_in=32, ch_out=32, K=3, P=4, stride=2, iter=3, hw_out=(1,1), final_lambda=1e-02, class_caps=False):
        super().__init__()
        if iter < 1:
            # with zero iterations em_routing would have nothing to return
            raise ValueError("iter must be >= 1, got {}".format(iter))

        # init vars
        self.ch_in  = ch_in
        self.ch_out = ch_out
        self.K = K
        self.P = P
        self.psize = P*P
        self.stride = stride
        self.iter = iter
        self.hw_out = hw_out
        self.class_caps = class_caps
        self.final_lambda = final_lambda

        # constants
        self.eps = 1e-8

        # params
        self.b_u = nn.Parameter(torch.zeros(ch_out), requires_grad=True)
        self.b_a = nn.Parameter(torch.zeros(ch_out), requires_grad=True)
        # one vote weight per (input type, output position, pose entry, output type);
        # width uses hw_out[1] so that non-square outputs work
        self.w = nn.Parameter(torch.rand([1, ch_in, hw_out[0], hw_out[1], P, P, ch_out]), requires_grad=True)

        #conv with static kernel: every weight is 1/K**2 and is excluded from training
        self.conv_stat = nn.Conv2d(in_channels=ch_in, out_channels=ch_in, kernel_size=K, stride=stride, bias=False, padding=0)
        self.conv_stat.weight = torch.nn.Parameter((torch.ones_like(self.conv_stat.weight)/K**2),requires_grad=False)

        # activations
        self.sigmoid = nn.Sigmoid()
        self.softmax = nn.Softmax(dim=2)

    def voting(self, x):
        """
                Input:     (bs, ch_in, h, w, p, p)
                Output:    (bs, ch_in, h', w', p, p, ch_out)
        """
        sh_in = x.shape
        #conv & shaping
        # NOTE(review): both views regroup dimensions without a permute — confirm
        # that this memory layout is the intended one.
        x = x.view(sh_in[0]*self.P*self.P, self.ch_in,sh_in[2], sh_in[3])
        x = self.conv_stat(x)
        x = x.view(-1, x.shape[1], x.shape[2], x.shape[3], self.P,self.P)

        # the vote weights are sized from hw_out, so the pooled map has to match
        assert x.shape[1:] == self.w.shape[1:-1], \
            "pooled poses {} do not match vote weights {}".format(tuple(x.shape), tuple(self.w.shape))

        # compute v: broadcasting over batch (w) and ch_out (x) gives the same result
        # as repeating both tensors to full size, without the extra copies
        v = x.unsqueeze(-1) * self.w
        return v

    def add_cord(self, v):
        """
            Add the scaled row / column coordinate to the votes.
            Input:
                v:         (bs, ch_in, h, w, p, p, ch_out)
            Output:
                v:         (bs, ch_in, h, w, p, p, ch_out)
        """
        #split shapes
        s_bs, s_ch_in, s_h, s_w, s_p1, s_p2, s_ch_out = v.shape
        # NOTE(review): the trailing (p, p, ch_out) dims are viewed as (ch_out, p*p)
        # without a permute — confirm which entries should receive the coordinates.
        v = v.view(s_bs, s_ch_in, s_h, s_w, s_ch_out, s_p1* s_p2)
        #coordinate addition (tensors follow the device / dtype of v)
        ar_h = torch.arange(s_h, dtype=v.dtype, device=v.device) / s_h
        ar_w = torch.arange(s_w, dtype=v.dtype, device=v.device) / s_w
        coor_h = v.new_zeros(1, 1, s_h, 1, 1, s_p1* s_p2)
        coor_w = v.new_zeros(1, 1, 1, s_w, 1, s_p1* s_p2)
        coor_h[0, 0, :, 0, 0, 0] = ar_h
        coor_w[0, 0, 0, :, 0, 1] = ar_w
        v = v + coor_h + coor_w
        v = v.view(s_bs, s_ch_in, s_h, s_w, s_p1, s_p2, s_ch_out)
        return v

    def _inv_temp(self,it):
        """Inverse temperature (lambda) for routing iteration ``it`` (0-based)."""
        # AG 18/07/2018: modified schedule for inverse_temperature (lambda) based
        # on Hinton's response to questions on OpenReview.net:
        # https://openreview.net/forum?id=HJWLfGWRb
        return (self.final_lambda * (1. - 0.95**(1 + it)))

    def em_routing(self, v, a):
        """
            Input:
                v:         (bs, ch_in, h, w, p, p, ch_out)
                a_in:      (bs, ch_in, h, w)

            For ConvCaps:
            Output:
                mu:        (bs, ch_out, h, w, p, p)
                a_out:     (bs, ch_out, h, w)
            Note that some dimensions are merged
            for computation convenience, that is
                v:         (bs*h*w, ch_in, ch_out, p*p)
                a_in:      (bs*h*w, ch_in, 1)

            For ClassCaps:
            Output:
                mu:        (bs, ch_out, p, p)
                a_out:     (bs, ch_out)
            Note that some dimensions are merged
            for computation convenience, that is
                v:         (bs, ch_in*h*w, ch_out, p*p)
                a_in:      (bs, ch_in*h*w, 1)
        """

        # split shapes
        s_bs, s_ch_in, s_h, s_w, s_p1, s_p2, s_ch_out = v.shape

        if not self.class_caps:
            # conv caps: every output position is routed on its own
            n_groups, n_in = s_bs*s_h*s_w, s_ch_in
        else:
            # class caps: coordinate addition, then all positions vote together
            v = self.add_cord(v)
            n_groups, n_in = s_bs, s_ch_in*s_h*s_w

        # NOTE(review): these views merge dimensions without a permute — confirm
        # that the grouping matches the intended (position, capsule) layout.
        v = v.view(n_groups, n_in, s_ch_out, s_p1*s_p2)
        a = a.view(n_groups, n_in).unsqueeze(-1)
        # uniform initial assignment, created on the device / dtype of the votes
        r = v.new_full((n_groups, n_in, s_ch_out), 1./s_ch_out)

        #iteration
        for it in range(self.iter):
            # M-Step (with inverse temperature scheduler lambda)
            lambd=self._inv_temp(it)
            a_out, mu, sig_sq = self.m_step(a, r, v, lambd=lambd)

            # E-Step (skipped after the last M-step, r would not be used anymore)
            if it < self.iter - 1:
                r = self.e_step(mu, sig_sq, a_out, v)

        #reshape mu and a as output
        if not self.class_caps:
            mu = mu.view(s_bs, s_ch_out, s_h, s_w, s_p1, s_p2)
            a_out = a_out.view(s_bs, s_ch_out, s_h, s_w)
        else:
            mu = mu.view(s_bs, s_ch_out, s_p1, s_p2)
            a_out = a_out.view(s_bs, s_ch_out)

        return mu, a_out

    def e_step(self, mu, sig_sq, a_out, v):
        r"""
            ln_p_j = - sum_h \dfrac{(V^h_{ij} - \mu^h_j)^2}{2 (\sigma^h_j)^2}
                     - sum_h ln(\sigma^h_j) - 0.5*\sum_h ln(2*\pi)
            r = softmax(ln(a_j*p_j))
              = softmax(ln(a_j) + ln(p_j))
            Input:
                mu:        (bs*h*w, 1, ch_out, P*P)
                sig_sq:    (bs*h*w, 1, ch_out, P*P)
                a_out:     (bs*h*w, 1, ch_out, 1)
                v:         (bs*h*w, ch_in, ch_out, p*p)
            Local:
                p_ln:      (bs*h*w, ch_in, ch_out, p*p)
                ap_ln:     (bs*h*w, ch_in, ch_out)
            Output:
                r:         (bs*h*w, ch_in, ch_out)
        """
        p_ln = -1. * (v - mu)**2 / (2 * sig_sq) - torch.log(sig_sq.sqrt()) - 0.5*math.log(2*math.pi)
        # eps keeps the log finite when an activation underflows to exactly 0
        ap_ln = (p_ln.sum(dim=3, keepdim=True) + torch.log(a_out + self.eps)).squeeze(-1)
        r = self.softmax(ap_ln)

        return r

    def m_step(self, a, r, v, lambd):
        r"""
            \mu^h_j = \dfrac{\sum_i r_{ij} V^h_{ij}}{\sum_i r_{ij}}
            (\sigma^h_j)^2 = \dfrac{\sum_i r_{ij} (V^h_{ij} - \mu^h_j)^2}{\sum_i r_{ij}}
            cost_h = (\beta_u + log \sigma^h_j) * \sum_i r_{ij}
            a_j = logistic(\lambda * (\beta_a - \sum_h cost_h))
            Input:
                a_in:      (bs*h*w, ch_in, 1)
                r:         (bs*h*w, ch_in, ch_out)
                v:         (bs*h*w, ch_in, ch_out, p*p)
            Local:
                cost_h:    (bs*h*w, 1, ch_out, P*P)
                r_sum:     (bs*h*w, 1, ch_out, 1)
            Output:
                a_out:     (bs*h*w, 1, ch_out, 1)
                mu:        (bs*h*w, 1, ch_out, P*P)
                sig_sq:    (bs*h*w, 1, ch_out, P*P)
        """
        s_ch_out = v.shape[2]
        # weight the assignments by the input activations; eps avoids division by zero
        r = (r * a).unsqueeze(-1) + self.eps
        r_sum = r.sum(dim=1, keepdim=True)
        mu = torch.sum(r * v, dim=1, keepdim=True) / r_sum
        sig_sq = (torch.sum(r * (v - mu)**2, dim=1, keepdim=True) / r_sum)  + self.eps
        cost = (self.b_u.view(1,1,s_ch_out,1) + torch.log(sig_sq.sqrt())) * r_sum
        a_out = self.sigmoid(lambd*(self.b_a.view(1,1,s_ch_out,1) - cost.sum(dim=3, keepdim=True)))

        return a_out, mu, sig_sq

    def forward(self, x):
        """Route a ``(poses, activations)`` pair to the next capsule layer."""
        #split pose and activation
        x, a = x

        # conv of poses to get votes v
        x = self.voting(x)

        # conv activations
        a = self.conv_stat(a)

        #routing
        x, a = self.em_routing(x, a)


        return x, a

#### Test Stuff

In [122]:
# Smoke test: build the network, print its trainable-parameter count and run
# the sample batch `x` through it.
B = CapsNetEM()
print(count_parameters(B))
z, a = B(x)

# z / a are the two outputs of the class-capsule layer (poses / activations)
z.shape, a.shape, a




952820


(torch.Size([3, 10, 4, 4]),
 torch.Size([3, 10]),
 tensor([[1.0916e-19, 5.0000e-01, 5.0000e-01, 4.9990e-01, 0.0000e+00, 4.9855e-01,
          3.6116e-10, 4.9999e-01, 3.7994e-01, 4.3996e-01],
         [8.2380e-17, 5.0000e-01, 5.0000e-01, 4.9998e-01, 0.0000e+00, 4.9471e-01,
          5.1797e-06, 5.0000e-01, 1.2167e-01, 3.8224e-01],
         [3.1269e-15, 5.0000e-01, 5.0000e-01, 5.0000e-01, 0.0000e+00, 4.9580e-01,
          6.5609e-10, 5.0000e-01, 2.0584e-01, 2.3622e-01]],
        grad_fn=<ViewBackward0>))

In [None]:
    # Reference only (not runnable in this notebook): TensorFlow snippet of the
    # inverse-temperature schedule; the last line is the equivalent expression
    # used by ConvCaps._inv_temp.
    for it in range(FLAGS.iter_routing):  
      # AG 17/09/2018: modified schedule for inverse_temperature (lambda) based
      # on Hinton's response to questions on OpenReview.net: 
      # https://openreview.net/forum?id=HJWLfGWRb
      # "the formula we used for lambda is:
      # lambda = final_lambda * (1 - tf.pow(0.95, tf.cast(i + 1, tf.float32)))
      # where 'i' is the routing iteration (range is 0-2). Final_lambda is set 
      # to 0.01."
      # final_lambda = 0.01
      final_lambda = FLAGS.final_lambda
      inverse_temperature = (final_lambda * 
                             (1 - tf.pow(0.95, tf.cast(it + 1, tf.float32))))


                                     lambd = self.final_lambda * (1. - 0.95**(1 + it))

In [104]:
# NOTE(review): `add_pathes` is not defined anywhere in the visible notebook
# (the recorded output is a NameError) — restore its definition or drop this cell.
q = torch.randn(1, 14, 14,32*17)
B = 32
K = 3
psize = 16
stride = 2

q, oh, ow = add_pathes(x=q, B=B, K=K, psize=psize, stride=stride)

q.shape, oh, ow

NameError: name 'add_pathes' is not defined

In [None]:
# View z with dims 0, 2 and 3 kept and all remaining elements folded into dim 1.
# NOTE(review): the recorded shape below does not match the z returned by the
# CapsNetEM smoke test in this notebook — presumably stale kernel state; confirm.
r = z.view(z.shape[0],-1,z.shape[2],z.shape[3])
r.shape



torch.Size([1, 16384, 6, 6])

In [None]:
# Grouped-convolution experiment: count the parameters of a 3x3 / stride-2 conv
# with 32 groups applied to `r`.
# NOTE(review): the recorded run fails on the final .view (the conv output has
# only 1024 elements, too few for a 6x6x4x4 target) — the target shape needs revisiting.
c_out = 16 * 4 * 4
CV =nn.Conv2d(in_channels= r.shape[1], out_channels=c_out, kernel_size=(3,3), stride=2, padding=0,
 dilation=1, groups=32, bias=True, padding_mode='zeros', device=None, dtype=None)


print(count_parameters(CV))
cv = CV(r).view(1,-1,6,6,4,4).shape



1179904


RuntimeError: shape '[1, -1, 6, 6, 4, 4]' is invalid for input of size 1024

In [None]:
#test view
# Compare element order: a plain flatten of the (-1, 4, 4) tensor vs. flattening
# after the two trailing dims have been permuted to the front.
dd = torch.arange(0,512,1)
dd = dd.reshape(-1,4,4)
dd.view(-1)

cc = dd.clone()
#cc =cc.reshape(4,4,-1)
# permute changes the element order seen by the following reshape
cc = cc.permute(1,2,0)
cc.reshape(-1)

tensor([  0,  16,  32,  48,  64,  80,  96, 112, 128, 144, 160, 176, 192, 208,
        224, 240, 256, 272, 288, 304, 320, 336, 352, 368, 384, 400, 416, 432,
        448, 464, 480, 496,   1,  17,  33,  49,  65,  81,  97, 113, 129, 145,
        161, 177, 193, 209, 225, 241, 257, 273, 289, 305, 321, 337, 353, 369,
        385, 401, 417, 433, 449, 465, 481, 497,   2,  18,  34,  50,  66,  82,
         98, 114, 130, 146, 162, 178, 194, 210, 226, 242, 258, 274, 290, 306,
        322, 338, 354, 370, 386, 402, 418, 434, 450, 466, 482, 498,   3,  19,
         35,  51,  67,  83,  99, 115, 131, 147, 163, 179, 195, 211, 227, 243,
        259, 275, 291, 307, 323, 339, 355, 371, 387, 403, 419, 435, 451, 467,
        483, 499,   4,  20,  36,  52,  68,  84, 100, 116, 132, 148, 164, 180,
        196, 212, 228, 244, 260, 276, 292, 308, 324, 340, 356, 372, 388, 404,
        420, 436, 452, 468, 484, 500,   5,  21,  37,  53,  69,  85, 101, 117,
        133, 149, 165, 181, 197, 213, 229, 245, 261, 277, 293, 3

In [None]:
#test view
# NOTE(review): the recorded run raises a RuntimeError — 1*32*4*4*6*6 = 18432
# elements cannot be reshaped to (1, 32, 4); the intended target shape is unclear.
dd = torch.rand(1,32,4,4,6,6)
dd = dd.reshape(1,32,4)
#dd.view(-1)
dd.shape
#cc = dd.clone()
#cc =cc.reshape(4,4,-1)
#cc = cc.permute(1,2,0)
#cc.reshape(-1)

RuntimeError: shape '[1, 32, 4]' is invalid for input of size 18432

In [None]:
# Scratch arithmetic — presumably per-layer parameter counts; TODO confirm
# which layer each number / product refers to.
2359808
147968
74240

4609

print(32*32*17)
print(9*32*32*16)
print(9*32*32*16)
print(32*10*16)
# sum of the four products printed above
print(32*32*17 + 9*32*32*16 + 9*32*32*16+ 32*10*16)

17408
147456
147456
5120
317440


In [None]:
# Shape check: matmul on 5-D tensors multiplies the trailing 4x4 matrices and
# treats the leading three dims as batch dims (see the recorded output shape).
q = torch.randn(24, 288, 32, 4, 4)
j = torch.randn(24, 288, 32, 4, 4)
torch.matmul(q,j).shape


torch.Size([24, 288, 32, 4, 4])

In [None]:
# Grouped F.conv2d with a constant averaging kernel (all weights 1/k**2).
torch.manual_seed(0)
nb_channels = 10
h, w = 14,14
k = 3
x = torch.randn(5, nb_channels, h, w)
# NOTE(review): this identity kernel is overwritten by the next assignment and
# never used.
weights = torch.tensor([[0., 0., 0.],
                        [0., 1., 0.],
                        [0., 0., 0.]])
# 2 output channels, each averaging one group of 5 input channels
weights = (torch.ones(1,1,k,k)/k**2).repeat(2, 5, 1, 1)
#weights = weights.view(1, 1, 3, 3).repeat(1, nb_channels, 1, 1)

output = F.conv2d(x, weights,groups=2)

output.shape

torch.Size([5, 2, 12, 12])

In [None]:
def count_parameters(model):
    """Return the total number of elements over all trainable parameters of ``model``.

    Parameters with ``requires_grad == False`` are not counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

In [None]:
def fix_conv(x):
    """Average-pool poses with a fixed, non-trainable convolution.

    Reads the notebook globals ``nb_channels``, ``out_channels``, ``k``, ``bs``,
    ``p``, ``h`` and ``w``. ``x`` is viewed as (bs*p*p, nb_channels, h, w),
    convolved with a kernel whose weights are all 1/k**2, and the result is
    viewed as (-1, channels, h', w', p, p).
    """
    box_filter = nn.Conv2d(nb_channels, out_channels, k, stride=1, bias=False)
    # replace the random init by a constant kernel that is excluded from training
    uniform = torch.ones_like(box_filter.weight) / k**2
    box_filter.weight = torch.nn.Parameter(uniform, requires_grad=False)

    pooled = box_filter(x.view(bs*p*p, nb_channels, h, w))
    n_ch, out_h, out_w = pooled.shape[1:]
    return pooled.view(-1, n_ch, out_h, out_w, p, p)


In [None]:
# Prototype of the vote computation: pool the poses with fix_conv, then multiply
# element-wise with a random weight tensor that has one slice per output type.
torch.manual_seed(0)
nb_channels = 32
out_channels = 32
h, w = 14,14
p = 4
k = 5
bs = 2
x = torch.randn(bs, nb_channels, p, p , h, w)
C = 32

#D = nn.Conv2d(in_channels=nb_channels, out_channels=out_channels, kernel_size=k, stride=1, bias=False)
#D.weight = torch.nn.Parameter((torch.ones_like(D.weight)/k**2),requires_grad=False)


#x = x.view(bs*p*p, nb_channels,h, w)
#x = x.view(bs, nb_channels, p, p , h, w)
#x.shape
#D.weight.shape , D.weight.requires_grad, D.weight.mean(), torch.numel(D.weight)

#x = D(x)
#x = x.view(-1, x.shape[1], x.shape[2], x.shape[3], p,p)

# fix_conv reads the globals set at the top of this cell
x = fix_conv(x)

# expand x over a new trailing dim of size C and w over the batch
x = x.unsqueeze(-1).repeat([1, 1, 1, 1, 1, 1, C])
# NOTE: from here on `w` is the weight tensor, no longer the width set above
w = torch.rand([1, 32, x.shape[2], x.shape[3], 4, 4, C])
w = w.repeat([bs, 1, 1, 1, 1, 1, 1])

assert x.shape == w.shape

v = torch.mul(x, w)



v.shape, v.numel(), w.numel()/1e6

(torch.Size([2, 32, 10, 10, 4, 4, 32]), 3276800, 3.2768)

In [None]:
# Element-wise product of a (10, 1) column repeated to (10, 4) with a (10, 4) tensor.
a = torch.rand(10,1).repeat(1,4)
b = torch.rand(10,4)

c = torch.mul(a,b)
c.shape

torch.Size([10, 4])

In [None]:
    def transform_view(x, w, C, P, w_shared=False):
        """Multiply every input pose with one weight matrix per output capsule type.

            For conv_caps:
                Input:     (b*H*W, K*K*B, P*P)
                Output:    (b*H*W, K*K*B, C, P*P)
            For class_caps:
                Input:     (b, H*W*B, P*P)
                Output:    (b, H*W*B, C, P*P)

            With ``w_shared`` the weights are first tiled along dim 1 so that
            they cover all input capsules. The shapes / sizes of the expanded
            tensors are printed for inspection.
        """
        n, n_caps, pose_len = x.shape
        assert pose_len == P*P

        poses = x.view(n, n_caps, 1, P, P)
        if w_shared:
            # tile the shared weights along the input-capsule dimension
            w = w.repeat(1, int(n_caps / w.size(1)), 1, 1, 1)

        # expand weights over the batch and poses over the C output types
        w = w.repeat(n, 1, 1, 1, 1)
        poses = poses.repeat(1, 1, C, 1, 1)
        votes = torch.matmul(poses, w).view(n, n_caps, C, P*P)
        print(w.shape, w.numel(), poses.shape, poses.numel())
        return votes

In [None]:
# Scratch: integer ratio of two element counts (evaluates to 9, see output).
21233664//2359296

9

In [None]:
# Try transform_view on conv-caps sized inputs and compare the number of
# elements of the votes with the number of weights.
x = torch.rand(1*12*12,3*3*32,4*4)
w = torch.rand(1,3*3*32,32,4,4)
t = transform_view(x=x, w=w, C=32, P=4)

t.shape, t.numel(), w.numel(), t.view(-1,12,12,32,9,32,16).shape

torch.Size([144, 288, 32, 4, 4]) 21233664 torch.Size([144, 288, 32, 4, 4]) 21233664


(torch.Size([144, 288, 32, 16]),
 21233664,
 147456,
 torch.Size([1, 12, 12, 32, 9, 32, 16]))

In [None]:
# Compare element-wise mul, matmul and the equivalent einsum on the same
# 10x10 integer matrix.
#h = torch.arange(0,10,1)
p = torch.arange(0,100,1)
p = p.reshape(-1,10)
h = p.clone()

mu = torch.mul(h,p)
ma = torch.matmul(h,p)
# 'ij,jk->ik' is the matrix product, i.e. the same operation as matmul
ei = torch.einsum('ij,jk->ik', h, p)

#print(h.shape, h)
#print(p.shape, p)
print(mu.shape, mu)
print(ma.shape, ma)
print(ei.shape, ei)

torch.Size([10, 10]) tensor([[   0,    1,    4,    9,   16,   25,   36,   49,   64,   81],
        [ 100,  121,  144,  169,  196,  225,  256,  289,  324,  361],
        [ 400,  441,  484,  529,  576,  625,  676,  729,  784,  841],
        [ 900,  961, 1024, 1089, 1156, 1225, 1296, 1369, 1444, 1521],
        [1600, 1681, 1764, 1849, 1936, 2025, 2116, 2209, 2304, 2401],
        [2500, 2601, 2704, 2809, 2916, 3025, 3136, 3249, 3364, 3481],
        [3600, 3721, 3844, 3969, 4096, 4225, 4356, 4489, 4624, 4761],
        [4900, 5041, 5184, 5329, 5476, 5625, 5776, 5929, 6084, 6241],
        [6400, 6561, 6724, 6889, 7056, 7225, 7396, 7569, 7744, 7921],
        [8100, 8281, 8464, 8649, 8836, 9025, 9216, 9409, 9604, 9801]])
torch.Size([10, 10]) tensor([[ 2850,  2895,  2940,  2985,  3030,  3075,  3120,  3165,  3210,  3255],
        [ 7350,  7495,  7640,  7785,  7930,  8075,  8220,  8365,  8510,  8655],
        [11850, 12095, 12340, 12585, 12830, 13075, 13320, 13565, 13810, 14055],
        [16350, 1

In [29]:
    def add_coord( v, b, h, w, B, C, psize):
        """Add the scaled row / column coordinate to the first two pose entries.

            The row coordinate i/h is added to pose entry 0 and the column
            coordinate j/w to pose entry 1 of every vote at position (i, j).
            The coordinate tensors are created on the device of ``v`` and
            ``h`` and ``w`` may differ.

            Shape:
                Input:     (b, H*W*B, C, P*P)
                Output:    (b, H*W*B, C, P*P)
        """
        v = v.view(b, h, w, B, C, psize)
        # follow v's device; keep float32 coordinates for non-floating inputs
        dtype = v.dtype if v.is_floating_point() else torch.float32
        coor_h = torch.zeros(1, h, 1, 1, 1, psize, dtype=dtype, device=v.device)
        coor_w = torch.zeros(1, 1, w, 1, 1, psize, dtype=dtype, device=v.device)
        coor_h[0, :, 0, 0, 0, 0] = torch.arange(h, dtype=dtype, device=v.device) / h
        coor_w[0, 0, :, 0, 0, 1] = torch.arange(w, dtype=dtype, device=v.device) / w
        v = v + coor_h + coor_w
        v = v.view(b, h*w*B, C, psize)
        return v

In [48]:
# Check add_coord on a 4x4 grid: the difference to the input shows which
# coordinate was added to the first two pose entries at every position.
a = torch.rand(1,16,1,16)

add = add_coord(v=a, b=1, h=4, w=4, B=1, C=1, psize=16)

delta = add-a
delta.shape

# column 0 holds the row coordinate, column 1 the column coordinate (see output)
delta[0,:,0,:2]


tensor([[0.0000, 0.0000],
        [0.0000, 0.2500],
        [0.0000, 0.5000],
        [0.0000, 0.7500],
        [0.2500, 0.0000],
        [0.2500, 0.2500],
        [0.2500, 0.5000],
        [0.2500, 0.7500],
        [0.5000, 0.0000],
        [0.5000, 0.2500],
        [0.5000, 0.5000],
        [0.5000, 0.7500],
        [0.7500, 0.0000],
        [0.7500, 0.2500],
        [0.7500, 0.5000],
        [0.7500, 0.7500]])