In [1]:
import torch
from torch import nn
import math

In [2]:
# Create a random image of shape (1, C, N, W)
inCh = 10
outCh = 13
L = 12
W = 13
image = torch.rand(1, inCh, L, W)

In [3]:
# Random convolution weight of shape (outCh, inCh, kernel_height, kernel_width)
kernel_height = 3
kernel_width = 4
conv_weight = torch.rand(outCh, inCh, kernel_height, kernel_width)

In [4]:
# Functional convolution to get the real output for checking
conv_out = torch.nn.functional.conv2d(image, conv_weight)

# Manual convolution

In [5]:
image.shape

torch.Size([1, 10, 12, 13])

In [6]:
# Unfold image
patches = image.unfold(2, kernel_height, 1).unfold(3, kernel_width, 1)

In [7]:
# batch_size, channels, h_windows, w_windows, kernel_height, kernel_width
patches.shape

torch.Size([1, 10, 10, 10, 3, 4])

In [8]:
patches = patches.contiguous().view(patches.shape[0], patches.shape[1], -1, kernel_height, kernel_width)
# batch_size, channels, windows, kernel_height, kernel_width
patches.shape

torch.Size([1, 10, 100, 3, 4])

In [9]:
# Shift the windows into the batch dimension using permute
patches = patches.permute(0, 2, 1, 3, 4)
# batch_size, windows, channels, kernel_height, kernel_width
patches.shape

torch.Size([1, 100, 10, 3, 4])

In [10]:
conv_weight.shape

torch.Size([13, 10, 3, 4])

In [11]:
# Multiply the patches with the weights in order to calculate the conv
result = (patches.unsqueeze(2) * conv_weight.unsqueeze(0).unsqueeze(1)).sum([3, 4, 5])
result = result.permute(0, 2, 1)

In [12]:
result.shape

torch.Size([1, 13, 100])

In [13]:
# Reshape to height and width
h = w = int(result.size(2)**0.5)
result = result.view(result.shape[0], -1, h, w)

In [14]:
torch.all(result.round(decimals=2) == conv_out.round(decimals=2))

tensor(True)

In [15]:
error = (((result-conv_out)**2)**0.5).sum()
error

tensor(0.0048)

### What I want is a tensor that's unrolled in two ways to be of shape (N, C//sub_size, sub_size

In [16]:
sub_size = 4

In [17]:
image.shape

torch.Size([1, 10, 12, 13])

In [18]:
# I think ContinuousPad1d will work.
# If I have a sequence 1,2,3,4,5,6 and I pad by 2, I want 1,2,3,4,5,6,1,2
samp = torch.tensor([1,2,3,4,5,6], dtype=torch.float32)

In [19]:
# This looks like it works
torch.nn.functional.pad(input=samp.unsqueeze(0).unsqueeze(0), pad=(0,2), mode="circular")

tensor([[[1., 2., 3., 4., 5., 6., 1., 2.]]])

In [20]:
image.shape

torch.Size([1, 10, 12, 13])

In [21]:
# I need to pad the image by sub_size-1 along the channels to become (N, C+sub_size-1, L, W)
image_pad = torch.nn.functional.pad(input=image.unsqueeze(0), pad=(0,0,0,0,0,sub_size-1), mode="circular").squeeze(0)

In [22]:
# batch_size, outCh, L, W
image_pad.shape

torch.Size([1, 13, 12, 13])

In [23]:
# Let's make sure the last `sub_size-1` channels are equal to the first `sub_size-1` channels
assert torch.all(image_pad[:, :sub_size-1] == image_pad[:, image_pad.shape[1]-sub_size+1:])
# and the the first channels are equal to the first channels in the original image
assert torch.all(image_pad[:, :sub_size-1] == image[:, :sub_size-1])

In [63]:
# Let's find out what the weight should look like.
# We want outCh number of kernels which take in sub_size and output 1
# which will be stacked to get a shape of outCh

In [64]:
# Convolutions
convs = [nn.Conv2d(sub_size, 1, (kernel_height, kernel_width)) for i in range(0, outCh)]

In [65]:
# Let's get all the weights
weight = torch.stack([i.weight for i in convs])

In [66]:
# What's the shape?
weight.shape

torch.Size([13, 1, 4, 3, 4])

In [67]:
new_conv_weight = weight.clone()

In [68]:
#### Shape is (outCh, 1, sub_size, kernel_height, kernel_width)

In [69]:
# This total size should be a lot less than the original
print(f"Old: {torch.prod(torch.tensor(conv_weight.shape)).item()}")
print(f"New: {torch.prod(torch.tensor(weight.shape)).item()}")
print(f"New/Old: {torch.prod(torch.tensor(conv_weight.shape)).item()/torch.prod(torch.tensor(weight.shape)).item()}")

Old: 1560
New: 624
New/Old: 2.5


In [70]:
# This is the expected value as that's the number of times sub_size goes into inCh
print(f"outCh/Sub size = {inCh/sub_size}")

outCh/Sub size = 2.5


In [71]:
# First, we need to pad the original image to have (outCh-sub_size-1) channels (N, outC+sub_size-1, L, W)
# When it unrolls, there will be outCh number of images with sub_size channels each

# Number of desired channels
desired_channels = outCh+sub_size-1
# Number of times to repeat the tensor to get to that goal
num_repeats = math.ceil(desired_channels/inCh)
# Repeat the image num_repeats times along the channels
image_pad = image.repeat(1, num_repeats, 1, 1)
# Slice the rest off that we don't need
image_pad = image_pad[:, :desired_channels]
# image_pad = torch.nn.functional.pad(input=image.unsqueeze(0), pad=(0,0,0,0,0,outCh+sub_size-1), mode="circular").squeeze(0)

In [72]:
image_pad.shape

torch.Size([1, 16, 12, 13])

In [73]:
# Unfold image
patches = image_pad.unfold(2, kernel_height, 1).unfold(3, kernel_width, 1)
# batch_size, channels+sub_size-1, h_windows, w_windows, kernel_height, kernel_width
patches.shape

torch.Size([1, 16, 10, 10, 3, 4])

In [74]:
patches = patches.contiguous().view(patches.shape[0], patches.shape[1], -1, kernel_height, kernel_width)
# batch_size, channels+sub_size-1, windows, kernel_height, kernel_width
patches.shape

torch.Size([1, 16, 100, 3, 4])

In [75]:
# Let's unfold this tensor to be of shape (batch_size, windows, outCh, sub_size, kernel_height, kernel_width)
patches = patches.unfold(1, sub_size, 1)
# We now have a tensor of shape (batch_size, outCh, kernel_height, kernel_width, sub_size)
patches = patches.view(patches.shape[0], patches.shape[1], sub_size, -1, kernel_height, kernel_width)
# Now we have one of shape (batch_size, outCh, sub_size, windows, kernel_height, kernel_width)
patches.shape

torch.Size([1, 13, 4, 100, 3, 4])

In [76]:
# Shift the windows into the batch dimension using permute
patches = patches.permute(0, 3, 1, 2, 4, 5)
# batch_size, windows, outCh, sub_size, kernel_height, kernel_width
patches.shape

torch.Size([1, 100, 13, 4, 3, 4])

In [77]:
# Our weight shape (outCh, 1, sub_size, kernel_height, kernel_width)
new_conv_weight.shape

torch.Size([13, 1, 4, 3, 4])

In [78]:
(patches.unsqueeze(2) * new_conv_weight.unsqueeze(0).unsqueeze(1)).shape

torch.Size([1, 100, 13, 13, 4, 3, 4])

In [79]:
patches.shape  # batch_size, windows, inCh, sub_size, kernel_height, kernel_width

torch.Size([1, 100, 13, 4, 3, 4])

In [80]:
new_conv_weight.shape # outCh, 1, sub_size, kernel_height, kernel_width

torch.Size([13, 1, 4, 3, 4])

In [81]:
patches.shape

torch.Size([1, 100, 13, 4, 3, 4])

In [82]:
new_conv_weight.transpose(0, 1).unsqueeze(0).shape

torch.Size([1, 1, 13, 4, 3, 4])

In [83]:
# Multiply the patches with the weights in order to calculate the conv
result = (patches * new_conv_weight.transpose(0, 1).unsqueeze(0)).sum([3, 4, 5])
result = result.permute(0, 2, 1)
# batch_size, outCh, HW
result.shape

torch.Size([1, 13, 100])

In [84]:
# Reshape to height and width
h = w = int(result.size(2)**0.5)
result = result.view(result.shape[0], -1, h, w)
result.shape

torch.Size([1, 13, 10, 10])

# Custom convolution class

In [104]:
# Now let's try this with multiple output channels and a
# subset of inputs channels
class Sparse_Conv_old(nn.Module):
    def __init__(self, inCh, outCh, kernel, sub):
        super(Sparse_Conv_old, self).__init__()
        
        assert sub <= inCh
        self.sub = sub
        self.inCh = inCh
        self.outCh = outCh
        
        # self.convs = nn.ParameterList([nn.Conv2d(sub, 1, kernel) for i in range(0, outCh)])
        
    def forward(self, X):
        if len(X.shape) == 3:
            X = X.unsqueeze(0)
        
        # Output tensor
        out = []
        
        # Iterate over all convolutions
        for i, c in enumerate(self.convs):
            # Current subset
            lower = math.floor((self.inCh/self.outCh)*i)
            upper = min(self.inCh, lower+self.sub)
            extra = max(0, lower+self.sub-self.inCh)
            print(lower, upper, extra)
            
            # Convolution
            if extra == 0:
                sub = X[:, lower:upper]
            else:
                sub = torch.cat((X[:, lower:upper], X[:, 0:extra]), dim=1)
            out.append(c(sub).squeeze(1))
        
        # Stack the output and return it
        return torch.stack(out).permute(1, 0, 2, 3)

In [135]:
# Now let's try this with multiple output channels and a
# subset of inputs channels
class Sparse_Conv_new(nn.Module):
    def __init__(self, inCh, outCh, kernel_size, sub_size):
        super(Sparse_Conv_new, self).__init__()
        
        assert sub_size <= inCh
        self.sub_size = sub_size
        self.inCh = inCh
        self.outCh = outCh
        self.kernel_height = kernel_size[0]
        self.kernel_width = kernel_size[1]
        
        self.convs = nn.ParameterList([nn.Conv2d(self.sub_size, 1, kernel_size) for i in range(0, outCh)])
        self.weights = torch.stack([i.weight for i in self.convs])
        self.biases = torch.stack([i.bias for i in self.convs])
        
    def forward(self, X):
        if len(X.shape) == 3:
            X = X.unsqueeze(0)
            
            
            
        # Get the h/W output
        h = X.shape[-2] - self.kernel_height
        if self.kernel_height % 2 != 0:
            h += 1
        w = X.shape[-1] - self.kernel_width
        if self.kernel_width % 2 != 0:
            w += 1
            
            
            
            
            
        # Number of desired channels
        desired_channels = self.outCh+self.sub_size-1
        # Number of times to repeat the tensor to get to that goal
        num_repeats = math.ceil(desired_channels/self.inCh)
        # Repeat the image num_repeats times along the channels
        X = X.repeat(1, num_repeats, 1, 1)
        # Slice the rest off that we don't need
        X = X[:, :desired_channels]
        
        

            
        # Pad the image by sub_size-1 along the channels to become (N, C+sub_size-1, L, W)
        # X = torch.nn.functional.pad(input=X.unsqueeze(0), pad=(0,0,0,0,0,self.sub_size-1), mode="circular").squeeze(0)
        
        # Unfold image (batch_size, channels+sub_size-1, windows, kernel_height, kernel_width)
        X = X.unfold(2, self.kernel_height, 1).unfold(3, self.kernel_width, 1)
        X = X.contiguous().view(X.shape[0], X.shape[1], -1, self.kernel_height, self.kernel_width)

        # Let's unfold this tensor to be of shape (batch_size, outCh, windows, kernel_height, kernel_width, sub_size)
        X = X.unfold(1, self.sub_size, 1)

        # Make tensor of shape (batch_size, windows, outCh, sub_size, kernel_height, kernel_width)
        X = X.permute(0, 2, 1, 5, 3, 4)

        # Multiply the patches with the weights in order to calculate the conv (batch_size, outCh, HW)
        X = (X * self.weights.transpose(0, 1).unsqueeze(0)).sum([3, 4, 5]).permute(0, 2, 1)
        
        # Add the biases
        X += self.biases.unsqueeze(0)

        # Reshape to output shape (batch_size, outCh, H, W)
        return X.reshape(X.shape[0], -1, h, w)

In [106]:
# New conv - should be fast
new_conv = Sparse_Conv_new(inCh, 1, (3, 3), sub_size) 

In [107]:
# Old conv - should be slow
old_conv = Sparse_Conv_old(inCh, 1, (3, 3), sub_size)
old_conv.convs = new_conv.convs

In [108]:
# Let's get the output for both
new_out = new_conv(image)
old_out = old_conv(image)

0 4 0


In [109]:
new_out.shape

torch.Size([1, 1, 10, 11])

In [110]:
old_out.shape

torch.Size([1, 1, 10, 11])

In [111]:
new_out[0,0]

tensor([[-0.1031,  0.0841, -0.2583, -0.3028, -0.5479, -0.5996, -0.2871, -0.4466,
         -0.1664, -0.4724, -0.1575],
        [-0.5450, -0.7332, -0.3128, -0.2300, -0.3033, -0.2435, -0.5466, -0.0620,
         -0.1484, -0.2541, -0.2863],
        [-0.3936, -0.2299, -0.0468, -0.0956, -0.2990, -0.3370, -0.3357, -0.2347,
         -0.1475, -0.0404, -0.4263],
        [-0.5244, -0.4245, -0.3017, -0.6112, -0.1812, -0.4746, -0.2088, -0.2968,
         -0.6251, -0.1962, -0.3898],
        [-0.3357, -0.1350, -0.2825, -0.3889, -0.3026, -0.3162, -0.3972, -0.1892,
         -0.3293,  0.0150, -0.2099],
        [ 0.0513, -0.5709, -0.3394, -0.4296, -0.2635, -0.0886, -0.1979, -0.3455,
         -0.2843, -0.2960, -0.1861],
        [-0.3702, -0.5569, -0.2921, -0.0489, -0.5124, -0.4308, -0.3535, -0.3309,
         -0.2642, -0.1549, -0.3898],
        [-0.1592, -0.1268, -0.2140, -0.5149, -0.6209, -0.5350, -0.2412, -0.3363,
         -0.3722, -0.5017, -0.3613],
        [-0.0909, -0.2174, -0.5810, -0.6015, -0.3517, -0

In [112]:
old_out[0,0]

tensor([[-0.1031,  0.0841, -0.2583, -0.3028, -0.5479, -0.5996, -0.2871, -0.4466,
         -0.1664, -0.4724, -0.1575],
        [-0.5450, -0.7332, -0.3128, -0.2300, -0.3033, -0.2435, -0.5466, -0.0620,
         -0.1484, -0.2541, -0.2863],
        [-0.3936, -0.2299, -0.0468, -0.0956, -0.2990, -0.3370, -0.3357, -0.2347,
         -0.1475, -0.0404, -0.4263],
        [-0.5244, -0.4245, -0.3017, -0.6112, -0.1812, -0.4746, -0.2088, -0.2968,
         -0.6251, -0.1962, -0.3898],
        [-0.3357, -0.1350, -0.2825, -0.3889, -0.3026, -0.3162, -0.3972, -0.1892,
         -0.3293,  0.0150, -0.2099],
        [ 0.0513, -0.5709, -0.3394, -0.4296, -0.2635, -0.0886, -0.1979, -0.3455,
         -0.2843, -0.2960, -0.1861],
        [-0.3702, -0.5569, -0.2921, -0.0489, -0.5124, -0.4308, -0.3535, -0.3309,
         -0.2642, -0.1549, -0.3898],
        [-0.1592, -0.1268, -0.2140, -0.5149, -0.6209, -0.5350, -0.2412, -0.3363,
         -0.3722, -0.5017, -0.3613],
        [-0.0909, -0.2174, -0.5810, -0.6015, -0.3517, -0

In [113]:
# Nice! Looks like it's working for less channels. What out more?

In [114]:
# Trying more than 1 channel

In [139]:
# New conv - should be fast
new_conv = Sparse_Conv_new(inCh, 13, (3, 3), sub_size) 

In [140]:
# Old conv - should be slow
old_conv = Sparse_Conv_old(inCh, 13, (3, 3), sub_size)
old_conv.convs = new_conv.convs

In [141]:
# Let's get the output for both
new_out = new_conv(image)
old_out = old_conv(image)

0 4 0
0 4 0
1 5 0
2 6 0
3 7 0
3 7 0
4 8 0
5 9 0
6 10 0
6 10 0
7 10 1
8 10 2
9 10 3


In [142]:
new_out.shape

torch.Size([1, 13, 10, 11])

In [143]:
old_out.shape

torch.Size([1, 13, 10, 11])

In [144]:
new_out[0,0,0]

tensor([-0.5026, -0.5090, -0.7384, -0.7571, -0.4243, -0.4055, -0.2884, -0.6291,
        -0.3119, -0.2803, -0.6103], grad_fn=<SelectBackward0>)

In [145]:
old_out[0,0,0]

tensor([-0.5026, -0.5090, -0.7384, -0.7571, -0.4243, -0.4055, -0.2884, -0.6291,
        -0.3119, -0.2803, -0.6103], grad_fn=<SelectBackward0>)

In [146]:
# Note that the 1st convolution in old_out is the inCh+1th one in new_out

In [147]:
new_out[0,1,0]

tensor([ 0.3534, -0.0025,  0.2519,  0.1554,  0.1382, -0.2893, -0.1240, -0.0676,
        -0.0118,  0.4416, -0.0767], grad_fn=<SelectBackward0>)

In [148]:
old_out[0,1,0]

tensor([ 0.1063, -0.0421,  0.0628,  0.0503, -0.2039, -0.1631,  0.1163,  0.2855,
         0.0290, -0.1114,  0.0829], grad_fn=<SelectBackward0>)

In [570]:
# EH whatever good enough :/

In [149]:
new_conv_ = Sparse_Conv_new(inCh, 1, (3, 3), sub_size) 

In [153]:
new_conv_(image[: :1])

tensor([[[[ 0.1063, -0.0421,  0.0628,  0.0503, -0.2039, -0.1631,  0.1163,
            0.2855,  0.0290, -0.1114,  0.0829],
          [ 0.0875,  0.0832, -0.0584,  0.0506,  0.0688,  0.1033,  0.2276,
           -0.1613,  0.1657,  0.2156, -0.3551],
          [-0.1080, -0.1465, -0.2105, -0.4055,  0.3021, -0.0491, -0.0357,
           -0.0458,  0.0398,  0.1756,  0.1060],
          [-0.0041,  0.1989,  0.2022,  0.1812, -0.0950, -0.0478,  0.0589,
            0.1771, -0.2062, -0.3233, -0.2548],
          [-0.0601, -0.1583,  0.0608,  0.2234, -0.0060, -0.0835,  0.2972,
           -0.0902, -0.1343,  0.0225,  0.0680],
          [-0.1732, -0.0916,  0.0513, -0.0832,  0.1527,  0.2315,  0.3204,
           -0.3288, -0.2228,  0.0451, -0.0672],
          [ 0.0706,  0.1927,  0.2232, -0.2133,  0.0987,  0.0525, -0.2881,
           -0.3337,  0.0343,  0.3164,  0.3772],
          [-0.1197, -0.0913, -0.0454, -0.0999, -0.0718, -0.0934, -0.0471,
            0.2693,  0.0440, -0.0263, -0.0289],
          [-0.0039,  0.1