In [37]:
%load_ext autoreload
%autoreload 2

import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [38]:
def all_close_flat(a, b):
    """Compare two tensors element-wise after flattening both of them.

    Returns a ``(is_close, difference)`` pair:
      * ``is_close`` -- ``torch.allclose`` on the flattened tensors with
        ``atol=1e-6`` and ``rtol=1e-6``;
      * ``difference`` -- ``a - b`` computed on the original (unflattened)
        tensors, handy for inspecting where a mismatch comes from.
    """
    flat_a = a.flatten()
    flat_b = b.flatten()
    is_close = torch.allclose(flat_a, flat_b, rtol=1e-6, atol=1e-6)
    difference = a - b
    return is_close, difference

def check(a, b, string):
    """Print the shapes of ``a`` and ``b`` on two lines so they line up.

    ``string`` labels the first line; the second line is prefixed by the same
    number of spaces, so ``b.shape`` appears directly under ``a.shape``.
    Only shapes are printed -- no numerical comparison is performed.
    """
    blank_label = " " * len(string)
    for label, tensor in ((string, a), (blank_label, b)):
        print(label, tensor.shape)

In [53]:
class MyConv1d(nn.Module):
    """A from-scratch 1-D convolution meant to reproduce ``nn.Conv1d``.

    The input is padded with ``F.pad``, cut into sliding windows with
    ``Tensor.unfold`` (a strided view, no copy) and contracted with the kernel
    via ``torch.einsum``.

    Args:
        in_channels: number of channels in the input.
        out_channels: number of channels produced by the convolution.
        kernel_size: width of the 1-D kernel.
        stride: step between consecutive windows.
        padding: number of elements added to *each* side of the input.
        bias: if True, ``biases`` is added to every output position.
        padding_mode: ``'zeros'``, ``'reflect'``, ``'replicate'`` or
            ``'circular'`` (the same names ``nn.Conv1d`` accepts).

    Learnable parameters:
        weights: shape ``(out_channels, in_channels, kernel_size)``.
        biases: shape ``(out_channels,)``; always created, but only used when
            ``bias`` is True.

    Shapes:
        input:  ``(batch, in_channels, width)`` or unbatched ``(in_channels, width)``.
        output: ``(batch, out_channels, width_out)`` (or ``(out_channels, width_out)``),
                with ``width_out = (width + 2 * padding - kernel_size) // stride + 1``.
    """

    # nn.Conv1d padding-mode names -> F.pad mode names.
    _PAD_MODES = {
        'zeros': 'constant',
        'reflect': 'reflect',
        'replicate': 'replicate',
        'circular': 'circular',
    }

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias = True, padding_mode='zeros'):
        super().__init__()
        assert padding_mode in self._PAD_MODES, f"unsupported padding_mode {padding_mode!r}"
        assert kernel_size >= 1 and stride >= 1 and padding >= 0
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.bias = bias
        # Stored in F.pad vocabulary ('zeros' becomes 'constant').
        self.padding_mode = self._PAD_MODES[padding_mode]

        self.weights = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size))
        self.biases = nn.Parameter(torch.randn(out_channels))

    def forward(self, x):
        """Apply the convolution to ``x`` (batched 3-D or unbatched 2-D)."""
        unbatched = x.dim() == 2
        if unbatched:
            x = x.unsqueeze(0)

        _, in_channels2, _ = x.shape
        assert in_channels2 == self.in_channels

        if self.padding > 0:
            x = F.pad(x, (self.padding, self.padding), mode=self.padding_mode)

        # Windows of length kernel_size taken every `stride` elements:
        # (batch, in_channels, width_out, kernel_size). Honouring the stride
        # here is what makes stride > 1 and arbitrary padding/kernel sizes work.
        patches = x.unfold(2, self.kernel_size, self.stride)

        # out[b, o, w] = sum_{i, k} patches[b, i, w, k] * weights[o, i, k]
        out = torch.einsum('biwk,oik->bow', patches, self.weights)

        if self.bias:
            out = out + self.biases.view(1, -1, 1)

        if unbatched:
            out = out.squeeze(0)
        return out
    


# Sanity check: MyConv1d must reproduce nn.Conv1d once both share the same parameters.
torch.manual_seed(0)  # reproducible input and reference initialisation

batch_size, width = 2, 10
in_channels, out_channels = 4, 6
kernel_size, stride = 3, 1

x = torch.randn(batch_size, in_channels, width)
padding = (kernel_size - 1) // 2  # "same" padding for odd kernel sizes

padding_mode = 'circular'
bias = True

# Reference convolution
conv_torch = nn.Conv1d(in_channels, out_channels, kernel_size, stride=stride, bias=bias, padding=padding, padding_mode=padding_mode)
out_true = conv_torch(x)

# Our implementation, with parameters taken from the reference
conv_my = MyConv1d(in_channels, out_channels, kernel_size, stride=stride, bias=bias, padding=padding, padding_mode=padding_mode)
conv_my.weights = nn.Parameter(conv_torch.weight)
conv_my.biases = nn.Parameter(conv_torch.bias)
out_my = conv_my(x)

assert out_my.shape == out_true.shape, (out_my.shape, out_true.shape)
# Tolerances are loosened on purpose. torch.allclose checks
# |a - b| <= atol + rtol * |b| with defaults rtol=1e-5, atol=1e-8. The einsum path
# and nn.Conv1d's kernel sum the same float32 products in a different order, so
# they differ by a few ulps (~1e-7). For outputs close to zero, rtol * |b| is
# negligible and atol=1e-8 alone is too strict -- that is why the default
# `torch.allclose(out_true, out_my)` fails even though both are correct.
assert torch.allclose(out_true, out_my, atol=1e-6, rtol=1e-6)
