In [31]:
import torch
from torch import nn

In [16]:
# Let's say I have an image with C channels of size L x W.
C = 10
L = W = 15
# Seed the global RNG so the random image (and any later random draws in
# this notebook) are reproducible on Restart Kernel -> Run All.
torch.manual_seed(0)
image = torch.rand(C, L, W)

In [17]:
# Next: convolve the image down to one output channel with a 3x3 kernel.
# Conv2d always spans every input channel, so the weight tensor is
# (out_channels=1, in_channels=C, 3, 3): a single kernel mixing all C channels.
conv = nn.Conv2d(in_channels=C, out_channels=1, kernel_size=3)
conv.weight.size()

torch.Size([1, 10, 3, 3])

In [19]:
# With no padding a 3x3 kernel trims one pixel from every border,
# so the result has shape (1, L-2, W-2).
conv(image).size()

torch.Size([1, 13, 13])

In [23]:
# Now try a kernel that only sees C/2 channels. The plan is to cut the image's
# C channels into two halves and treat them as a batch of two images, each
# of shape (C/2, L, W).
conv2 = nn.Conv2d(in_channels=C // 2, out_channels=1, kernel_size=3)
conv2.weight.size()

torch.Size([1, 5, 3, 3])

In [30]:
# The weight is now half the size (C/2 * 3 * 3 weights instead of C * 3 * 3,
# plus one bias each), so the parameter count is roughly halved. Stacking the
# two channel halves as a batch of 2 means the same kernel convolves the first
# half and the second half independently, so the output is (2, L-2, W-2).
# Split at C // 2 rather than a hard-coded 5 so this stays consistent with
# conv2 if C is changed above (assumes C is even, as conv2 does).
image2 = torch.stack((image[:C // 2], image[C // 2:]), dim=0)
# squeeze(1) removes only the single output-channel axis; a bare squeeze()
# would also drop any other size-1 axis.
conv2(image2).squeeze(1).shape

torch.Size([2, 13, 13])

In [60]:
# What if we make a convolution that only looks at a subset of the input
# channels? For now, it produces just one output channel.
class Sub_Conv(nn.Module):
    """Conv2d with one output channel that reads only the first ``sub`` input channels.

    Args:
        inCh: number of channels the input is expected to have; only used to
            validate ``sub``.
        kernel: kernel size passed to ``nn.Conv2d``.
        sub: number of leading input channels the convolution reads.
    """

    def __init__(self, inCh, kernel, sub):
        super().__init__()

        assert sub <= inCh
        self.sub = sub

        # in_channels is just the kept subset; a single output channel.
        self.conv = nn.Conv2d(sub, 1, kernel)

    def forward(self, X):
        """Convolve the first ``self.sub`` channels of ``X``.

        Accepts an unbatched ``(C, H, W)`` or batched ``(N, C, H, W)`` tensor
        and returns ``(1, H', W')`` or ``(N, 1, H', W')`` respectively.

        Raises:
            NotImplementedError: if ``X`` is neither 3-D nor 4-D.
        """
        if X.dim() not in (3, 4):
            raise NotImplementedError(
                f"Sub_Conv expects a 3-D or 4-D input, got {X.dim()}-D"
            )
        # The channel axis is third from the end in both layouts, so a single
        # slice handles batched and unbatched input alike.
        return self.conv(X[..., :self.sub, :, :])

In [45]:
# A 3x3 conv over only the first 4 of the image's C channels.
sub_conv = Sub_Conv(inCh=C, kernel=3, sub=4)

In [46]:
# Input is unbatched: (C, L, W).
image.size()

torch.Size([10, 15, 15])

In [47]:
# One output channel, spatial dims shrunk by the 3x3 kernel.
sub_conv(image).size()

torch.Size([1, 13, 13])

In [82]:
# Now let's try this with multiple output channels, where each output channel
# convolves its own randomly chosen subset of the input channels.
class Rand_Conv(nn.Module):
    """Multi-output conv in which each output channel sees a random channel subset.

    For each of the ``outCh`` output channels, ``sub`` distinct input channels
    are drawn once at construction and stored in the ``channel_idx`` buffer
    (shape ``(outCh, sub)``). Fixing the subset ties each small convolution's
    weights to specific input channels, makes ``forward`` deterministic, and
    saves the choice in ``state_dict``.

    Args:
        inCh: number of input channels.
        outCh: number of output channels (one ``Conv2d(sub, 1, kernel)`` each).
        kernel: kernel size of every sub-convolution.
        sub: number of input channels each output channel reads.
    """

    def __init__(self, inCh, outCh, kernel, sub):
        super().__init__()

        assert sub <= inCh
        self.sub = sub

        # ModuleList (not ParameterList) is the container for sub-modules.
        self.convs = nn.ModuleList([nn.Conv2d(sub, 1, kernel) for _ in range(outCh)])

        # Row k: a random permutation of range(inCh) truncated to `sub`
        # entries, i.e. the input channels seen by output channel k.
        self.register_buffer(
            "channel_idx", torch.rand(outCh, inCh).argsort(dim=1)[:, :sub]
        )

    def forward(self, X):
        """Apply every sub-convolution to its channel subset.

        Accepts ``(C, H, W)`` or ``(N, C, H, W)`` input; unbatched input is
        given a batch axis of 1. Returns ``(N, outCh, H', W')``.

        Raises:
            NotImplementedError: if ``X`` is neither 3-D nor 4-D.
        """
        if X.dim() == 3:
            X = X.unsqueeze(0)
        elif X.dim() != 4:
            raise NotImplementedError(
                f"Rand_Conv expects a 3-D or 4-D input, got {X.dim()}-D"
            )

        # Each conv yields (N, 1, H', W'); concatenating on the channel axis
        # gives (N, outCh, H', W') directly. The previous squeeze()/permute
        # approach also dropped the batch axis when N == 1 and then failed.
        outs = [conv(X[:, idx]) for conv, idx in zip(self.convs, self.channel_idx)]
        return torch.cat(outs, dim=1)

In [83]:
# Random-subset conv: C input channels, 4 output channels, a 3x3 kernel,
# and 4 randomly chosen input channels per output channel.
sub_conv2 = Rand_Conv(C, 4, 3, 4)

In [84]:
# Same unbatched (C, L, W) input as before.
image.size()

torch.Size([10, 15, 15])

In [85]:
# Batch axis added for the unbatched input, then one channel per sub-conv.
sub_conv2(image).size()

torch.Size([1, 4, 13, 13])