In [1]:
import torch
from torch import nn
import torchvision
from torch.utils.data.dataloader import DataLoader
import math

In [2]:
# Used to transform the data to a transor
transform = torchvision.transforms.Compose(
    [
        # Transform to a tensor
        torchvision.transforms.ToTensor(),
    ]
)

In [15]:
# Now let's try this with multiple output channels and a
# subset of inputs channels
class Sparse_Conv(nn.Module):
    def __init__(self, inCh, outCh, kernel_size, sub_size):
        super(Sparse_Conv, self).__init__()
        
        assert sub_size <= inCh
        self.sub_size = sub_size
        self.inCh = inCh
        self.outCh = outCh
        self.kernel_height = kernel_size[0]
        self.kernel_width = kernel_size[1]
        
        self.convs = nn.ParameterList([nn.Conv2d(self.sub_size, 1, kernel_size) for i in range(0, outCh)])
        self.weights = torch.stack([i.weight for i in self.convs])
        self.biases = torch.stack([i.bias for i in self.convs])
        
    def forward(self, X):
        if len(X.shape) == 3:
            X = X.unsqueeze(0)
            
            
            
        # Get the h/W output
        h = X.shape[-2] - self.kernel_height
        if self.kernel_height % 2 != 0:
            h += 1
        w = X.shape[-1] - self.kernel_width
        if self.kernel_width % 2 != 0:
            w += 1
            
            
            
            
        # Number of desired channels
        desired_channels = self.outCh+self.sub_size-1
        # Number of times to repeat the tensor to get to that goal
        num_repeats = math.ceil(desired_channels/self.inCh)
        # Repeat the image num_repeats times along the channels
        X = X.repeat(1, num_repeats, 1, 1)
        # Slice the rest off that we don't need
        X = X[:, :desired_channels]
        
        

            
        # Pad the image by sub_size-1 along the channels to become (N, C+sub_size-1, L, W)
        # X = torch.nn.functional.pad(input=X.unsqueeze(0), pad=(0,0,0,0,0,self.sub_size-1), mode="circular").squeeze(0)
        
        # Unfold image (batch_size, channels+sub_size-1, windows, kernel_height, kernel_width)
        X = X.unfold(2, self.kernel_height, 1).unfold(3, self.kernel_width, 1)
        X = X.contiguous().view(X.shape[0], X.shape[1], -1, self.kernel_height, self.kernel_width)

        # Let's unfold this tensor to be of shape (batch_size, outCh, windows, kernel_height, kernel_width, sub_size)
        X = X.unfold(1, self.sub_size, 1)

        # Make tensor of shape (batch_size, windows, outCh, sub_size, kernel_height, kernel_width)
        X = X.permute(0, 2, 1, 5, 3, 4)

        # Multiply the patches with the weights in order to calculate the conv (batch_size, outCh, HW)
        X = (X * self.weights.transpose(0, 1).unsqueeze(0)).sum([3, 4, 5]).permute(0, 2, 1)
        
        # Add the biases
        X += self.biases.unsqueeze(0)

        # Reshape to output shape (batch_size, outCh, H, W)
        return X.reshape(X.shape[0], -1, h, w)

In [16]:
# Load in MNIST
MNIST_dataset = torchvision.datasets.MNIST("./", train=True, transform=transform, download=True)

In [17]:
# Used to load in the dataset
data_loader = DataLoader(MNIST_dataset, batch_size=64,
        pin_memory=True, num_workers=1, 
        drop_last=False, shuffle=True
    )

In [52]:
# Model with 1x28x28 input and 10 output
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        
        # convolution layers
        self.convs = nn.Sequential( # 1x28x28
            Sparse_Conv(1, 32, (3, 3), 1), # 32x26x26
            nn.ReLU(),
            
            Sparse_Conv(32, 64, (3, 3), 5), # 64x24x24
            nn.ReLU(),
            nn.MaxPool2d(3), # 64x8x8
            
            nn.Flatten(1, -1), # 4096
            nn.Linear(4096, 250), # 250
            nn.ReLU(),
            nn.Linear(250, 10),
            nn.Softmax(-1)
        )
        
    def forward(self, X):
        return self.convs(X)

In [53]:
# Create the model
model = Model()

In [54]:
# Optimizer
optim = torch.optim.AdamW(model.parameters())

In [55]:
# Loss function
loss_funct = nn.CrossEntropyLoss()

In [None]:
# Training loop
epochs = 10
steps = 0
for epoch in range(0, epochs):
    # Iterate over all data
    for X,labels in data_loader:
        # Send the data through the model
        y_hat = model(X)
        
        # Get the loss
        loss = loss_funct(y_hat, labels)
        
        # Backprop the loss
        loss.backward()
        
        # Update model
        optim.step()
        optim.zero_grad()
        steps += 1
        print(f"step {steps}: {loss.detach().item()}")

step 1: 2.30289888381958
step 2: 2.298799514770508
step 3: 2.3022468090057373
step 4: 2.291330099105835
step 5: 2.27962589263916
step 6: 2.2782809734344482
step 7: 2.261451005935669
step 8: 2.2630674839019775
step 9: 2.253166437149048
step 10: 2.2540359497070312
step 11: 2.262906074523926
step 12: 2.2558646202087402
step 13: 2.2808828353881836
step 14: 2.2111637592315674
step 15: 2.1603310108184814
step 16: 2.207650899887085
step 17: 2.1825063228607178
step 18: 2.16615629196167
step 19: 2.1009445190429688
step 20: 2.1772756576538086
step 21: 2.134047746658325
step 22: 2.1929731369018555
step 23: 2.1163365840911865
step 24: 2.106511354446411
step 25: 2.0999929904937744
step 26: 2.1082231998443604
step 27: 2.0805699825286865
step 28: 2.06715726852417
step 29: 2.067384958267212
step 30: 2.1241672039031982
step 31: 2.0837721824645996
step 32: 2.1100542545318604
step 33: 2.045210838317871
step 34: 2.0776586532592773
step 35: 2.0383312702178955
step 36: 2.012605667114258
step 37: 1.971793174

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x0000025E92BF9BD0>
Traceback (most recent call last):
  File "C:\users\gabri\appdata\local\miniconda3\lib\site-packages\torch\utils\data\dataloader.py", line 1466, in __del__
    self._shutdown_workers()
  File "C:\users\gabri\appdata\local\miniconda3\lib\site-packages\torch\utils\data\dataloader.py", line 1430, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "C:\users\gabri\appdata\local\miniconda3\lib\multiprocessing\process.py", line 149, in join
    res = self._popen.wait(timeout)
  File "C:\users\gabri\appdata\local\miniconda3\lib\multiprocessing\popen_spawn_win32.py", line 108, in wait
    res = _winapi.WaitForSingleObject(int(self._handle), msecs)
KeyboardInterrupt: 


KeyboardInterrupt: 