In [22]:
import torch
import torch.nn as nn

class ConvWithoutSharing(nn.Module):
    """2D convolution computed as unfold (im2col) followed by one ``nn.Linear``.

    Every ``kernel_size x kernel_size`` patch of the input (across all input
    channels) is flattened into a row, ordered ``(channel, row, col)``, and
    passed through ``self.fc``. No padding is applied.

    NOTE(review): the same ``self.fc`` is applied at every spatial position, so
    the weights ARE shared across positions. The layer is equivalent to
    ``F.conv2d(x, fc.weight.view(out_channels, in_channels, k, k), fc.bias,
    stride=stride)``. If a locally-connected (unshared) layer was intended, a
    separate weight per output position is required -- confirm intent.

    Args:
        in_channels: number of channels of the input.
        out_channels: number of channels produced.
        kernel_size: side length of the square receptive field.
        stride: step between neighbouring receptive fields (both spatial dims).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.in_channels = in_channels
        self.out_channels = out_channels

        # (in_features, out_features) of the linear map from one flattened patch
        # to the output channels. The name is kept for backward compatibility.
        self.output_size = (kernel_size * kernel_size * in_channels, out_channels)
        self.fc = nn.Linear(*self.output_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the layer.

        Args:
            x: input of shape ``(batch, in_channels, height, width)``.

        Returns:
            Tensor of shape ``(batch, out_channels, out_h, out_w)`` where
            ``out_h = (height - kernel_size) // stride + 1`` (likewise for width).

        Raises:
            ValueError: if ``x`` is not 4-D or its channel count differs from
                ``in_channels``.
        """
        if x.dim() != 4 or x.shape[1] != self.in_channels:
            raise ValueError(
                f"expected input of shape (batch, {self.in_channels}, H, W), got {tuple(x.shape)}"
            )
        batch_size, _, height, width = x.shape
        k, s = self.kernel_size, self.stride
        # Must account for the stride; matches the number of windows unfold produces.
        out_h = (height - k) // s + 1
        out_w = (width - k) // s + 1

        # (batch, C, out_h, out_w, k, k)
        patches = x.unfold(2, k, s).unfold(3, k, s)
        # Move channels next to the kernel dims before flattening, so each row is a
        # single spatial position's full receptive field ordered (C, k, k):
        # (batch * out_h * out_w, C * k * k)
        patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(
            batch_size * out_h * out_w, self.in_channels * k * k
        )

        # (batch * out_h * out_w, out_channels)
        output = self.fc(patches)

        # Rows are ordered (batch, out_h, out_w); restore that layout, then move the
        # channel axis to dim 1.
        return (
            output.view(batch_size, out_h, out_w, self.out_channels)
            .permute(0, 3, 1, 2)
            .contiguous()
        )

In [29]:
kernel_size = 2
in_channels = 1
out_channels = 2
stride = 1
# (in_features, out_features) of the linear map applied to every flattened patch
output_size = (kernel_size * kernel_size * in_channels, out_channels)
print("weights of dims: ", output_size)
weights = nn.Linear(*output_size)

batch_size = 1
input_dim = 4  # height and width of the example input
example_tensor = torch.randn(batch_size, in_channels, input_dim, input_dim)
patches_6d = example_tensor.unfold(2, kernel_size, stride).unfold(3, kernel_size, stride)
# [batch_size, in_channels, out_h, out_w, kernel_size, kernel_size]
print("unfolded dims: ", patches_6d.shape)
# Move channels next to the kernel dims so each row is one patch across all channels
unfolded = patches_6d.permute(0, 2, 3, 1, 4, 5).reshape(-1, kernel_size * kernel_size * in_channels)
# [batch_size * out_h * out_w, in_channels * kernel_size * kernel_size] -- one row per patch
print("unfolded dims: ", unfolded.shape)

weights of dims:  (4, 2)
unfolded dims:  torch.Size([1, 1, 3, 3, 2, 2])
unfolded dims:  torch.Size([9, 4])


In [26]:
in_channels = 2
out_channels = 3
kernel_size = 2
stride = 1
torch.manual_seed(0)  # make the layer init and printed values below reproducible
c = ConvWithoutSharing(in_channels, out_channels, kernel_size, stride)

# [batch, in_channels, height, width]
example_tensor = torch.randn(1, in_channels, 5, 5)
# Call the module itself (not .forward) so nn.Module hooks run
out = c(example_tensor)
# [batch, out_channels, out_h, out_w] with out_h = out_w = (5 - kernel_size) // stride + 1
print(out.shape)
print(out)

torch.Size([1, 3, 4, 4])
tensor([[[[-0.9946, -0.4071, -0.9955, -0.4921],
          [ 0.3420, -0.9385, -0.3527,  0.8769],
          [-0.1706,  0.0570,  0.4960,  0.0404],
          [-0.3765,  0.3728,  0.3542,  0.2269]],

         [[-0.1993, -0.5132,  0.2181,  0.6789],
          [ 0.6060, -0.7805, -0.4876, -1.1001],
          [ 0.8882,  0.7053, -0.4613,  0.7275],
          [ 0.6015,  0.5583,  0.2537, -1.0038]],

         [[-0.4446,  0.1035, -0.4651, -0.0029],
          [-0.1907, -1.0302, -1.1328,  0.1075],
          [-0.0401,  0.4906, -0.9713, -0.8414],
          [-0.4314,  0.5927,  0.4548,  0.4047]]]], grad_fn=<ViewBackward0>)


In [None]:
# TODO: what is peer normalization?