In [2]:
import torch
import torch.nn as nn
import torchvision
import numpy as np
import cv2
import matplotlib.pyplot as plt
%matplotlib inline

In [153]:
class FSNet(torch.autograd.Function):
    """Convolution whose kernels are overlapping windows of one shared tensor ``FS``.

    ``FS`` has shape ``(C, K1*x, K2*y)``. Each of the ``K = K1*K2*K3`` kernels is a
    ``(C, S1, S2)`` window of ``FS`` (with wrap-around), whose start is stepped by
    ``z`` along dim 0, ``x`` along dim 1 and ``y`` along dim 2.

    NOTE(review): ``torch.autograd.Function`` is meant to be used through static
    ``forward``/``backward`` methods and ``.apply``; this class only works when
    ``forward`` is called directly as a plain method. It looks like an
    ``nn.Module`` would be the natural base class -- confirm before changing.

    NOTE(review): the kernels are copied into ``conv_real.weight.data``, so
    autograd does not connect ``conv_real.weight`` back to ``FS``. ``grad_index``
    is presumably meant to help map weight gradients back onto ``FS`` -- confirm.
    """

    def __init__(self, S1, S2, C, x, y, z, K1, K2, K3):
        """
        S1, S2: kernel height and width (``kernel_size`` of the convolution).
        C: number of input channels; also the size of dim 0 of ``FS``.
        x, y, z: step of the window start along dim 1, dim 2 and dim 0 of ``FS``.
        K1, K2, K3: number of window positions along dim 1, dim 2 and dim 0;
            the convolution has ``K1*K2*K3`` output channels.
        """
        self.C = C
        self.S1 = S1
        self.S2 = S2
        self.x = x
        self.y = y
        self.z = z
        self.K1 = K1
        self.K2 = K2
        self.K3 = K3
        self.K = K1 * K2 * K3

        self.FS = nn.init.kaiming_normal_(torch.zeros((C, K1 * x, K2 * y)))

        # Number of kernels that read each element of FS; same shape as FS.
        self.grad_index = self._count_window_hits()

        self.conv_real = nn.Conv2d(in_channels=C, out_channels=self.K, kernel_size=(S1, S2))

    def _windows(self):
        """Yield ``(kernel_index, dim0_slice, dim1_slice, dim2_slice)`` for every kernel."""
        for k3 in range(self.K3):
            for k2 in range(self.K2):
                for k1 in range(self.K1):
                    kernel_index = k1 + k2 * self.K1 + k3 * self.K1 * self.K2
                    yield (
                        kernel_index,
                        slice(k3 * self.z, k3 * self.z + self.C),
                        slice(k1 * self.x, k1 * self.x + self.S1),
                        slice(k2 * self.y, k2 * self.y + self.S2),
                    )

    def _count_window_hits(self):
        """Return a fresh FS-shaped tensor counting how many kernels use each FS element."""
        c, h, w = self.FS.size()
        # Count on the doubled grid first, because windows may run past the end of FS.
        hits = torch.zeros((2 * c, 2 * h, 2 * w))
        for _, dim0, dim1, dim2 in self._windows():
            hits[dim0, dim1, dim2] += 1
        # Fold the wrapped-around halves back onto the original FS extent.
        hits = hits[:c] + hits[c:]
        hits = hits[:, :h] + hits[:, h:]
        hits = hits[:, :, :w] + hits[:, :, w:]
        return hits

    def forward(self, input):
        """
        Rebuild the convolution weights from ``FS`` and apply the convolution.

        input: (N, C, H, W) tensor, or ``None`` to only rebuild the weights.

        Returns ``conv_real(input)``, or ``None`` when ``input`` is ``None``.
        Side effects: sets ``FS_extend``, ``grad_index`` and ``conv_real.weight.data``.
        """
        # Tile FS twice along every axis so that windows can wrap around its edges.
        FS_extend = torch.cat([self.FS, self.FS], dim=0)
        FS_extend = torch.cat([FS_extend, FS_extend], dim=1)
        FS_extend = torch.cat([FS_extend, FS_extend], dim=2)
        self.FS_extend = FS_extend

        conv_weight = torch.zeros((self.K, self.C, self.S1, self.S2))
        for kernel_index, dim0, dim1, dim2 in self._windows():
            conv_weight[kernel_index] = FS_extend[dim0, dim1, dim2]

        # Recomputed from scratch on every call: accumulating into the stored
        # tensor made repeated calls double-count and shrink it each time.
        self.grad_index = self._count_window_hits()
        print(conv_weight.size())

        self.conv_real.weight.data = conv_weight

        if input is None:
            return None
        return self.conv_real(input)

In [154]:
sample = FSNet(
    S1=3, S2=3,        # 3x3 kernels
    C=64,              # input channels
    x=2, y=2, z=16,    # window steps along dim 1, dim 2 and dim 0
    K1=4, K2=4, K3=4,  # 4 * 4 * 4 = 64 kernels
)

In [155]:
# Run forward once to fill conv_real.weight and FS_extend; None is passed
# because this step does not need an input tensor.
sample.forward(None)

torch.Size([64, 64, 3, 3])


In [159]:
# Sanity check: count exactly-zero entries in the assembled conv weights.
# The weight buffer starts as zeros, so a non-zero count would point to a slot never filled from FS.
(sample.conv_real.weight.data == 0).sum()

tensor(0)

In [151]:
# Sanity check: count exactly-zero entries in the tiled FS tensor.
(sample.FS_extend == 0).sum()

tensor(0)