In [1]:
import torch
import torch.nn as nn
import torch.nn._functions as tnnf
import torchvision
import numpy as np
import cv2
import matplotlib.pyplot as plt
%matplotlib inline

In [5]:
class FSFunc(torch.autograd.Function):
    def __init__(self, S1, S2, C, x, y, z, K1, K2, K3):
        self.C = C
        self.S1 = S1
        self.S2 = S2
        self.x = x
        self.y = y
        self.z = z
        self.K1 = K1
        self.K2 = K2
        self.K3 = K3
        self.K = K1*K2*K3
        
        self.FS = nn.init.kaiming_normal_(torch.zeros((C, K1*x, K2*y)))
        
        self.grad_index = torch.zeros((self.FS.size()[0]*2, self.FS.size()[1]*2, self.FS.size()[2]*2))
        
        self.conv_real = nn.Conv2d(in_channels=C, out_channels=self.K, kernel_size=(S1, S2))
        
    def forward(self, input):
        """
        input: (N, C, H, W)
        """
        conv_weight = torch.zeros((self.K1*self.K2*self.K3, self.C, self.S1, self.S2))
        combinations = [(i, j, k) for i in range(self.K1) 
                        for j in range(self.K2) 
                        for k in range(self.K3)]
        FS_extend = torch.cat([self.FS, self.FS], dim=0)
        FS_extend = torch.cat([FS_extend, FS_extend], dim=1)
        FS_extend = torch.cat([FS_extend, FS_extend], dim=2)
        
        for (k1, k2, k3) in combinations:
            conv_weight[k1 + k2 * self.K1 + k3 * self.K1 * self.K2] = FS_extend[k3*self.z:k3*self.z+self.C, 
                                                            k1*self.x:k1*self.x+self.S1,
                                                           k2*self.y:k2*self.y+self.S2]
            self.grad_index[k3*self.z:k3*self.z+self.C,
                            k1*self.x:k1*self.x+self.S1,
                            k2*self.y:k2*self.y+self.S2] += 1
        c, h, w = self.grad_index.size()
        self.grad_index[:c//2, :, :] += self.grad_index[c//2:, :, :]
        self.grad_index[:, :h//2, :] += self.grad_index[:, h//2:, :]
        self.grad_index[:, :, :w//2] += self.grad_index[:, :, w//2:]
        self.grad_index = self.grad_index[:c//2, :h//2, :w//2]
        conv_weight.requires_grad_()
        
        self.conv_real.weight.data = conv_weight
        return self.conv_real(input)
    
    def backward(self, grad_output):
        print('aaaaa')
        print(self.conv_real.weight.grad)
        grad_input = grad_output.clone()
        print(grad_input.size())
        print(grad_input)
        #print(self.conv_real.backward(grad_input))
        return grad_input
        
class FSMod(nn.Module):
    def __init__(self, S1, S2, C, x, y, z, K1, K2, K3):
        self.fs_func = FSFunc(S1, S2, C, x, y, z, K1, K2, K3)
    def forward(self, input):
        return self.fs_func.apply(input)

In [6]:
sample = FSFunc(S1=3, S2=3, C=64, x=2, y=2, z=16, K1=4, K2=4, K3=4)

In [7]:
input_sample = torch.randn((2,64, 20, 20))
input_sample.requires_grad_()
out = sample(input_sample)

In [8]:
out.sum().backward()

aaaaa
None
torch.Size([2, 64, 18, 18])
tensor([[[[1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          ...,
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.]],

         [[1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          ...,
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.]],

         [[1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          ...,
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.]],

         ...,

         [[1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          ...

RuntimeError: Function FSFuncLegacyBackward returned an invalid gradient at index 0 - got [2, 64, 18, 18] but expected shape compatible with [2, 64, 20, 20]

In [270]:
nn.init.kaiming_normal_(torch.zeros((1,2)))

torch.Tensor

In [289]:
conv = nn.Conv2d(in_channels=64, out_channels=5, kernel_size=(2,2))
conv.weight = nn.Parameter(torch.randn((5, 64, 2, 2)))
o = conv(input_sample)
o.sum().backward()

In [290]:
conv.weight.grad

tensor([[[[ 32.5298,  30.3631],
          [ 41.9629,  36.2019]],

         [[-18.9460,  -8.0175],
          [-11.9282,  -1.7782]],

         [[ 19.3847,  15.4611],
          [ 15.2742,  14.3128]],

         ...,

         [[ 28.6468,  19.1297],
          [ 23.7953,  12.0854]],

         [[ 42.5016,  51.4581],
          [ 32.3103,  41.4168]],

         [[ 13.0511,  -3.8644],
          [ 19.2382,  -1.6136]]],


        [[[ 32.5298,  30.3631],
          [ 41.9629,  36.2019]],

         [[-18.9460,  -8.0175],
          [-11.9282,  -1.7782]],

         [[ 19.3847,  15.4611],
          [ 15.2742,  14.3128]],

         ...,

         [[ 28.6468,  19.1297],
          [ 23.7953,  12.0854]],

         [[ 42.5016,  51.4581],
          [ 32.3103,  41.4168]],

         [[ 13.0511,  -3.8644],
          [ 19.2382,  -1.6136]]],


        [[[ 32.5298,  30.3631],
          [ 41.9629,  36.2019]],

         [[-18.9460,  -8.0175],
          [-11.9282,  -1.7782]],

         [[ 19.3847,  15.4611],
          

In [292]:
class FSFunc2(torch.autograd.Function):
    def __init__(self, inStride=1, inPad=0, inDil=1, inGroups=1, imgDim=2):
        super(FSFunc2, self).__init__()
        self.conStride = ()
        self.conPad = ()
        self.conDil = ()
        self.conGroups = inGroups
        for k in range(imgDim):
            self.conStride = self.conStride + (inStride,)
            self.conPad = self.conPad + (inPad,)
            self.conDil = self.conDil + (inDil,)

        self.convFct = tnnf.conv.ConvNd(self.conStride, self.conPad, self.conDil,
                                        False, (0, 0), self.conGroups)
    
    def forward(self, inImg, inKernel, inBias=None):
        self.save_for_backward(inImg, inKernel, inBias)
        self.convFct.requires_grad = True
        return self.convFct.forward(inImg, inKernel, inBias)
    
    def backward(self, grad_output):
        inImg, inKernel, inBias = self.saved_tensors
        if inBias != None:
            self.convFct.needs_input_grad = (True, True, True)
        else:
            self.convFct.needs_input_grad = (True, True, False)

        # for surveillance purpose -> time wait
        time.sleep(0)
        gradIn = self.convFct._grad_input(inImg, inKernel, grad_output)
        gradWeight, gradBias = self.convFct._grad_params(inImg, inKernel, inBias, grad_output)
        
        print(gradWeight.shape)
        print(gradWeight)
        return gradIn, gradWeight, gradBias

In [293]:
class FSMod2(nn.Module):
    def __init__(self, S1, S2, C, x, y, z, K1, K2, K3):
        self.fs_func = FSFunc2()
        
        self.C = C
        self.S1 = S1
        self.S2 = S2
        self.x = x
        self.y = y
        self.z = z
        self.K1 = K1
        self.K2 = K2
        self.K3 = K3
        self.K = K1*K2*K3
        
        self.FS = nn.init.kaiming_normal_(torch.zeros((C, K1*x, K2*y)))
        
        self.grad_index = torch.zeros((self.FS.size()[0]*2, self.FS.size()[1]*2, self.FS.size()[2]*2))
        
    def forward(self, input):
        """
        input: (N, C, H, W)
        """
        conv_weight = torch.zeros((self.K1*self.K2*self.K3, self.C, self.S1, self.S2))
        combinations = [(i, j, k) for i in range(self.K1) 
                        for j in range(self.K2) 
                        for k in range(self.K3)]
        FS_extend = torch.cat([self.FS, self.FS], dim=0)
        FS_extend = torch.cat([FS_extend, FS_extend], dim=1)
        FS_extend = torch.cat([FS_extend, FS_extend], dim=2)
        
        for (k1, k2, k3) in combinations:
            conv_weight[k1 + k2 * self.K1 + k3 * self.K1 * self.K2] = FS_extend[k3*self.z:k3*self.z+self.C, 
                                                            k1*self.x:k1*self.x+self.S1,
                                                           k2*self.y:k2*self.y+self.S2]
            self.grad_index[k3*self.z:k3*self.z+self.C,
                            k1*self.x:k1*self.x+self.S1,
                            k2*self.y:k2*self.y+self.S2] += 1
        c, h, w = self.grad_index.size()
        self.grad_index[:c//2, :, :] += self.grad_index[c//2:, :, :]
        self.grad_index[:, :h//2, :] += self.grad_index[:, h//2:, :]
        self.grad_index[:, :, :w//2] += self.grad_index[:, :, w//2:]
        self.grad_index = self.grad_index[:c//2, :h//2, :w//2]
        
        out = self.fs_func(input, nn.Parameter(conv_weight, requires_grad=True))
        return out

In [294]:
sample = FSMod2(S1=3, S2=3, C=64, x=2, y=2, z=16, K1=4, K2=4, K3=4)

AttributeError: module 'torch.nn._functions' has no attribute 'conv'

In [299]:
nn.functional.conv2d.

<function _VariableFunctions.conv2d>

In [63]:
class FSMod3(nn.Module):
    def __init__(self, S1, S2, C, x, y, z, K1, K2, K3):
        super(FSMod3, self).__init__()
        self.C = C
        self.S1 = S1
        self.S2 = S2
        self.x = x
        self.y = y
        self.z = z
        self.K1 = K1
        self.K2 = K2
        self.K3 = K3
        self.K = K1*K2*K3
        
        self.FS = nn.Parameter(nn.init.kaiming_normal_(torch.zeros((C, K1*x, K2*y))))
        
        self.grad_index = torch.zeros((self.FS.size()[0]*2, self.FS.size()[1]*2, self.FS.size()[2]*2))
        
        self.conv_real = nn.Conv2d(in_channels=C, out_channels=self.K, kernel_size=(S1, S2))
        
        self.backward_hook_handle = self.conv_real.register_backward_hook(self.backward_hook)

        
    def __call__(self, input):
        """
        input: (N, C, H, W)
        """
        conv_weight = torch.zeros((self.K1*self.K2*self.K3, self.C, self.S1, self.S2))
        combinations = [(i, j, k) for i in range(self.K1) 
                        for j in range(self.K2) 
                        for k in range(self.K3)]
        FS_extend = torch.cat([self.FS, self.FS], dim=0)
        FS_extend = torch.cat([FS_extend, FS_extend], dim=1)
        FS_extend = torch.cat([FS_extend, FS_extend], dim=2)
        
        for (k1, k2, k3) in combinations:
            conv_weight[k1 + k2 * self.K1 + k3 * self.K1 * self.K2] = FS_extend[k3*self.z:k3*self.z+self.C, 
                                                            k1*self.x:k1*self.x+self.S1,
                                                           k2*self.y:k2*self.y+self.S2]
            self.grad_index[k3*self.z:k3*self.z+self.C,
                            k1*self.x:k1*self.x+self.S1,
                            k2*self.y:k2*self.y+self.S2] += 1
        c, h, w = self.grad_index.size()
        self.grad_index[:c//2, :, :] += self.grad_index[c//2:, :, :]
        self.grad_index[:, :h//2, :] += self.grad_index[:, h//2:, :]
        self.grad_index[:, :, :w//2] += self.grad_index[:, :, w//2:]
        self.grad_index = self.grad_index[:c//2, :h//2, :w//2]
        
        self.conv_real.weight = nn.Parameter(conv_weight, requires_grad=True)
        
        return self.conv_real(input)
    
    def backward_hook(self, module, grad_input, grad_output):
        '''
        grad_input[1] is the grad of weight of conv_real
        '''
        print(grad_input[0].shape)
        print(grad_input[1])

        print(grad_output[0].shape)
        grad_extend = torch.zeros((self.FS.size()[0]*2, self.FS.size()[1]*2, self.FS.size()[2]*2))
        for i, grad in enumerate(grad_input[1]):
            k1 = i%self.K1
            k2 = ((i - k1) // self.K1) % self.K2
            k3 = i // (self.K1 * self.K2)
            grad_extend[k3*self.z:k3*self.z+self.C,
                        k1*self.x:k1*self.x+self.S1,
                        k2*self.y:k2*self.y+self.S2] = grad
        c, h, w = grad_extend.size()
        grad_extend[:c//2, :, :] += grad_extend[c//2:, :, :]
        grad_extend[:, :h//2, :] += grad_extend[:, h//2:, :]
        grad_extend[:, :, :w//2] += grad_extend[:, :, w//2:]
        grad_extend = grad_extend[:c//2, :h//2, :w//2]
        grad_extend = grad_extend/self.grad_index
        self.FS.grad = grad_extend
        

In [64]:
sample_mod = FSMod3(S1=3, S2=3, C=64, x=2, y=2, z=16, K1=4, K2=4, K3=4)

In [69]:
input_sample = torch.randn((2,64, 20, 20))
input_sample.requires_grad_()
out = sample_mod(input_sample)

In [70]:
out.sum().backward()

torch.Size([2, 64, 20, 20])
tensor([[[[  5.4274,  12.8745,  28.2792],
          [  3.3317,   7.0415,  24.3272],
          [  5.8558,   7.4714,  22.1657]],

         [[-36.3913, -31.3278, -22.8151],
          [-28.0632, -19.8239, -12.1555],
          [-45.7838, -37.1138, -26.5622]],

         [[ 19.8825,  24.3742,   7.0627],
          [ 28.0776,  34.5208,  18.4658],
          [ 43.5583,  45.2899,  27.8722]],

         ...,

         [[-26.8156, -23.2839, -11.9284],
          [-26.5238, -19.2960,  -6.4947],
          [-31.0968, -22.3179, -11.5498]],

         [[ 22.1402,  22.3068,  27.6189],
          [  2.9099,   5.8348,  14.9321],
          [ -5.4319,  -0.1535,   4.9789]],

         [[ 15.0959,  21.5191,  16.0882],
          [  4.7585,   9.7723,   6.9737],
          [  0.9108,   8.4358,   6.8235]]],


        [[[  5.4274,  12.8745,  28.2792],
          [  3.3317,   7.0415,  24.3272],
          [  5.8558,   7.4714,  22.1657]],

         [[-36.3913, -31.3278, -22.8151],
          [-28.06

RuntimeError: The size of tensor a (8) must match the size of tensor b (4) at non-singleton dimension 2

In [71]:
sample_mod.conv_real.weight.grad

In [72]:
sample_mod.FS.grad

tensor([[[ 11.2652,  11.5585,   6.7491,  ...,  11.5585,   6.7491,  11.5585],
         [ 16.8650,  16.6484,   9.6257,  ...,  16.6484,   9.6257,  16.6484],
         [  6.5543,   6.5852,   3.6873,  ...,   6.5852,   3.6873,   6.5852],
         ...,
         [ 16.8650,  16.6484,   9.6257,  ...,  16.6484,   9.6257,  16.6484],
         [  6.5543,   6.5852,   3.6873,  ...,   6.5852,   3.6873,   6.5852],
         [ 16.8650,  16.6484,   9.6257,  ...,  16.6484,   9.6257,  16.6484]],

        [[ -6.0082,  -5.6608,  -2.9118,  ...,  -5.6608,  -2.9118,  -5.6608],
         [ -3.6398,  -3.4555,  -1.4560,  ...,  -3.4555,  -1.4560,  -3.4555],
         [ -2.2247,  -2.6110,  -1.0002,  ...,  -2.6110,  -1.0002,  -2.6110],
         ...,
         [ -3.6398,  -3.4555,  -1.4560,  ...,  -3.4555,  -1.4560,  -3.4555],
         [ -2.2247,  -2.6110,  -1.0002,  ...,  -2.6110,  -1.0002,  -2.6110],
         [ -3.6398,  -3.4555,  -1.4560,  ...,  -3.4555,  -1.4560,  -3.4555]],

        [[ -5.7768,  -5.1579,  -2.5486,  ...

In [60]:
[(i, j, k) for i in range(3) 
                        for j in range(4) 
                        for k in range(5)]

[(0, 0, 0),
 (0, 0, 1),
 (0, 0, 2),
 (0, 0, 3),
 (0, 0, 4),
 (0, 1, 0),
 (0, 1, 1),
 (0, 1, 2),
 (0, 1, 3),
 (0, 1, 4),
 (0, 2, 0),
 (0, 2, 1),
 (0, 2, 2),
 (0, 2, 3),
 (0, 2, 4),
 (0, 3, 0),
 (0, 3, 1),
 (0, 3, 2),
 (0, 3, 3),
 (0, 3, 4),
 (1, 0, 0),
 (1, 0, 1),
 (1, 0, 2),
 (1, 0, 3),
 (1, 0, 4),
 (1, 1, 0),
 (1, 1, 1),
 (1, 1, 2),
 (1, 1, 3),
 (1, 1, 4),
 (1, 2, 0),
 (1, 2, 1),
 (1, 2, 2),
 (1, 2, 3),
 (1, 2, 4),
 (1, 3, 0),
 (1, 3, 1),
 (1, 3, 2),
 (1, 3, 3),
 (1, 3, 4),
 (2, 0, 0),
 (2, 0, 1),
 (2, 0, 2),
 (2, 0, 3),
 (2, 0, 4),
 (2, 1, 0),
 (2, 1, 1),
 (2, 1, 2),
 (2, 1, 3),
 (2, 1, 4),
 (2, 2, 0),
 (2, 2, 1),
 (2, 2, 2),
 (2, 2, 3),
 (2, 2, 4),
 (2, 3, 0),
 (2, 3, 1),
 (2, 3, 2),
 (2, 3, 3),
 (2, 3, 4)]