In [10]:
%load_ext autoreload
%autoreload 2

import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [11]:
def all_close_flat(a, b):
    return torch.allclose(a.flatten(), b.flatten(), atol=1e-6, rtol=1e-6), (a - b)

def check(a, b, string):
    print(string, a.shape)
    print(len(string) * ' ', b.shape)
    # all_close_flat(a, b)

In [165]:
class MyConv1d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias = True, padding_mode='zeros'):
        super().__init__()
        assert padding_mode in ['zeros', 'circular']
        if padding_mode == 'zeros':
            padding_mode = 'constant'
        assert padding == (kernel_size - 1) // 2
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride=stride
        self.padding = padding
        self.bias = bias
        self.padding_mode = padding_mode

        self.weights = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size))
        self.biases = nn.Parameter(torch.randn(out_channels))

    def einsum_alt(self, patches, weights):
        patches = patches.permute(0, 2, 1, 3) # # batch_size, width, channels, kernel_size 
        patches_unsqueezed = patches.unsqueeze(2) # batch_size, 1, width, channels, kernel_size
        weights_unsqueezed = weights.unsqueeze(0) # 1, out_channels, in_channels, kernel_size
        out = patches_unsqueezed * weights_unsqueezed
        out = out.sum([3, 4]) # batch_size, out_channels, width, channels
        out = out.permute(0, 2, 1) # batch_size, out_channels, output_pixels
        return out

    def forward(self, x):
        batch_size, in_channels2, width = x.shape
        assert in_channels2 == self.in_channels

        
        x_pad = F.pad(x, (self.padding, self.padding), mode=self.padding_mode)
        patches = x_pad.unsqueeze(2).unfold(3, self.kernel_size, 1)

        patches = patches.contiguous().view(batch_size, self.in_channels, width, self.kernel_size)

        self.patches = patches

        

        out = torch.einsum('biwk,oik->bow', patches, self.weights) # (biwk) -> (batch_size, in_channels, width, kernel)
        out2 = self.einsum_alt(patches, self.weights)
        assert torch.allclose(out, out2)#, atol=1e-6, rtol=1e-6)

        # Add the bias
        if self.bias:
            out += self.biases.unsqueeze(0).unsqueeze(2)

        return out
    


batch_size, width = 1, 5
in_channels, out_channels = 1, 1
kernel_size, stride = 3, 1

x = torch.randn(batch_size, in_channels, width)
padding = (kernel_size - 1) // 2

padding_mode = 'circular'
bias = False

# Create conv
conv_torch = nn.Conv1d(in_channels, out_channels, kernel_size, stride=stride, bias=bias, padding=padding, padding_mode=padding_mode)
out_true = conv_torch(x)

conv = MyConv1d(in_channels, out_channels, kernel_size, stride=stride, bias=bias, padding=padding, padding_mode=padding_mode)
conv.weights = nn.Parameter(conv_torch.weight)
conv.biases = nn.Parameter(conv_torch.bias)
out = conv(x)

assert torch.allclose(out_true, out, atol=1e-6, rtol=1e-6)
assert torch.allclose(out_true, out) # For some reason, this fails. TODO: Find out why


In [166]:
p = conv.patches
w = conv.weights

patches = p.permute(0, 2, 1, 3) # # batch_size, width, channels, kernel_size 
patches = patches.unsqueeze(2) # batch_size, 1, width, channels, kernel_size
weigths = w.unsqueeze(0) # 1, out_channels, in_channels, kernel_size
out = patches * weigths
out = out.sum([3, 4]) # batch_size, out_channels, width, channels
out = out.permute(0, 2, 1) # batch_size, out_channels, output_pixels
out2 = out

p2 = patches[0,0,0]
w2 = weigths[0,0]
o2 = (p2 * w2).sum()

o = torch.einsum('biwk,oik->bow', p, w) # (biwk) -> (batch_size, in_channels, width, kernel)
assert o.shape == out_true.shape == out2.shape
assert torch.allclose(out_true[0,0,0], o[0,0,0])
assert torch.allclose(out_true[0,0,0], out2[0,0,0])
assert torch.allclose(out_true[0,0,0], o2)
o.shape, o[0,0,0], o2

(torch.Size([1, 1, 5]),
 tensor(0.8512, grad_fn=<SelectBackward0>),
 tensor(0.8512, grad_fn=<SumBackward0>))

In [167]:
p.shape, w.shape, o.shape

(torch.Size([1, 1, 5, 3]), torch.Size([1, 1, 3]), torch.Size([1, 1, 5]))

In [171]:
def cycle(x, i):
    return torch.cat([x[i:], x[:i]])

eye = torch.eye(5)

torch.sum(torch.stack([cycle(eye, i) for i in range(3)]), dim=0)

tensor([[3., 0., 0., 0., 0.],
        [0., 3., 0., 0., 0.],
        [0., 0., 3., 0., 0.],
        [0., 0., 0., 3., 0.],
        [0., 0., 0., 0., 3.]])

In [202]:
def cycle(x, i):
    return torch.cat([x[i:], x[:i]])

kernel_weights = torch.zeros(width)
kernel_weights[:3] = w.data.squeeze().detach().clone() 
w3 = torch.stack([cycle(kernel_weights, -i) for i in range(-1, width-1)])
w3

tensor([[-0.1305,  0.4368,  0.0000,  0.0000,  0.0532],
        [ 0.0532, -0.1305,  0.4368,  0.0000,  0.0000],
        [ 0.0000,  0.0532, -0.1305,  0.4368,  0.0000],
        [ 0.0000,  0.0000,  0.0532, -0.1305,  0.4368],
        [ 0.4368,  0.0000,  0.0000,  0.0532, -0.1305]])

In [214]:
torch.einsum('w,kw->k', cycle(x.squeeze(), -2), w3)

tensor([-0.1688,  0.3409,  0.8512, -0.7342, -0.1783])

In [207]:
o3 = torch.einsum('w,kw->k', x.squeeze(), w3)
o3

tensor([ 0.8512, -0.7342, -0.1783, -0.1688,  0.3409])

In [209]:
kernel_weights

tensor([ 0.0532, -0.1305,  0.4368,  0.0000,  0.0000])

In [183]:
x.squeeze().shape, w3.shape

(torch.Size([5]), torch.Size([5, 5]))

In [137]:
torch.eye(3, k=1)

TypeError: eye() received an invalid combination of arguments - got (int, k=int), but expected one of:
 * (int n, *, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
 * (int n, int m, *, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)


In [128]:
torch.diag(torch.ones(10), diagonal=1) ++ torch.diag(torch.ones(10), diagonal=0)


RuntimeError: The size of tensor a (11) must match the size of tensor b (10) at non-singleton dimension 1

In [133]:
[torch.diag(torch.ones(3), diagonal=diagonal) for diagonal in range(-1, 1)]

[tensor([[0., 0., 0., 0.],
         [1., 0., 0., 0.],
         [0., 1., 0., 0.],
         [0., 0., 1., 0.]]),
 tensor([[1., 0., 0.],
         [0., 1., 0.],
         [0., 0., 1.]])]

In [131]:
torch.sum([torch.diag(torch.ones(10), diagonal=diagonal) for diagonal in range(-10, 10)])

TypeError: sum(): argument 'input' (position 1) must be Tensor, not list

tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

In [120]:
p

tensor([[[[ 0.4151, -0.4179,  0.0902],
          [-0.4179,  0.0902, -0.7310],
          [ 0.0902, -0.7310,  0.2207],
          [-0.7310,  0.2207, -2.5049],
          [ 0.2207, -2.5049, -0.7893],
          [-2.5049, -0.7893,  2.1092],
          [-0.7893,  2.1092, -0.6504],
          [ 2.1092, -0.6504, -1.3639],
          [-0.6504, -1.3639,  0.4151],
          [-1.3639,  0.4151, -0.4179]]]])

In [122]:
torch.eye(3)

tensor([[1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.]])

In [114]:
w[0] == w2, w2.shape

(tensor([[True, True, True],
         [True, True, True],
         [True, True, True],
         [True, True, True]]),
 torch.Size([4, 3]))

In [116]:
w2.shape

torch.Size([4, 3])

In [110]:
p2.shape, w2.shape, o2.shape

(torch.Size([4, 3]), torch.Size([4, 3]), torch.Size([]))

In [111]:
p2

tensor([[ 0.5580,  1.7661, -0.3027],
        [ 0.6031,  0.4213,  0.6577],
        [-0.6039,  0.0876, -1.5754],
        [ 0.0647, -1.4128, -0.7551]])

In [104]:
x

tensor([[[ 1.7661, -0.3027,  0.2392, -0.1246, -0.2540,  0.8262, -0.1507,
          -1.8762,  0.1936,  0.5580],
         [ 0.4213,  0.6577, -0.6326, -2.1045, -0.7548,  1.4206, -0.1227,
           0.7463,  0.5374,  0.6031],
         [ 0.0876, -1.5754,  0.7371,  0.5767,  1.1310, -0.9098, -0.0612,
          -0.9385, -0.2748, -0.6039],
         [-1.4128, -0.7551, -1.0314, -0.0925,  0.7355, -0.5328, -0.0863,
          -0.5916,  0.9724,  0.0647]],

        [[ 0.2188, -1.0770,  0.0059,  1.1088, -0.2786, -1.2330, -0.8237,
           2.1114, -0.8140, -0.5056],
         [ 0.1743,  1.8318,  1.3409, -0.9931,  0.1322, -0.2115,  0.1997,
           0.2212,  0.3497, -0.2611],
         [ 0.1338, -0.0476,  0.1128,  0.8461, -0.1283, -1.5674, -0.3284,
          -0.3179, -1.0814, -1.2627],
         [-0.4925, -1.1542, -0.4834, -0.7194, -0.1469, -0.4030, -0.5062,
           0.7469, -1.1123,  0.2723]]])

In [84]:
p2.shape, w2.shape

(torch.Size([2, 10, 1, 4, 3]), torch.Size([1, 6, 4, 3]))

(tensor(0.8715, grad_fn=<SelectBackward0>),
 tensor(0.8715, grad_fn=<SelectBackward0>),
 tensor(0.8715, grad_fn=<SelectBackward0>))

In [59]:
p.shape

torch.Size([2, 4, 10, 3])