In [41]:
%load_ext autoreload
%autoreload 2

import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [49]:
def all_close_flat(a, b):
    return torch.allclose(a.flatten(), b.flatten(), atol=1e-6, rtol=1e-6), (a - b)

def check(a, b, string):
    print(string, a.shape)
    print(len(string) * ' ', b.shape)
    # all_close_flat(a, b)

In [82]:
batch_size, width = 2, 32
channels_in, channels_out = 1, 2
kernel_size, stride = 3, 1

x = torch.randn(batch_size, channels_in, width)

# setup the parameters for Conv2d

width_unpadded = width - kernel_size + 1

# Create conv
conv = nn.Conv1d(channels_in, channels_out, kernel_size, stride=stride, bias=True, padding = 0, padding_mode='circular')
conv_pad = nn.Conv1d(channels_in, channels_out, kernel_size, stride=stride, bias=True, padding = 1, padding_mode='circular')

weights, biases = conv.weight, conv.bias
weights_pad, biases_pad = conv_pad.weight, conv_pad.bias
true_conv = conv(x)
true_conv_pad = conv_pad(x)

x_pad = F.pad(x, (1, 1), mode='circular')

patches = x.unsqueeze(2).unfold(3, kernel_size, 1)
patches_pad = x_pad.unsqueeze(2).unfold(3, kernel_size, 1)
check(patches, patches_pad, 'patches')

patches = patches.contiguous().view(batch_size, channels_in, width_unpadded, kernel_size)
patches_pad = patches_pad.contiguous().view(batch_size, channels_in, width, kernel_size)
check(patches, patches_pad, 'patches')

# Shift the windows into the batch dimension using permute
# nb_windows = patches.size(2)
patches = patches.permute(0, 2, 1, 3) # # batch_size, nb_windows, channels, kernel_size 
patches_pad = patches_pad.permute(0, 2, 1, 3) # # batch_size, nb_windows, channels, kernel_size 
check(patches, patches_pad, 'patches')

# Multiply the patches with the weights in order to calculate the conv
result = (patches.unsqueeze(2) * weights.unsqueeze(0)).sum([3, 4])
patches_pad = (patches_pad.unsqueeze(2) * weights_pad.unsqueeze(0)).sum([3, 4])
result = result.permute(0, 2, 1) # batch_size, out_channels, output_pixels
result_pad = patches_pad.permute(0, 2, 1) # batch_size, out_channels, output_pixels
check(result, result_pad, 'result')

# Add the bias
result += biases.unsqueeze(0).unsqueeze(2)
result_pad += biases.unsqueeze(0).unsqueeze(2)
check(result, result_pad, 'result')

result = result.view(batch_size, channels_out, width_unpadded)
result_pad = result_pad.view(batch_size, channels_out, width)
check(result, result_pad, 'result')

# print(torch.allclose(result, true_conv))#, atol=1e-6, rtol=1e-6)
print(torch.allclose(result_pad, true_conv_pad))#, atol=1e-6, rtol=1e-6)


patches torch.Size([2, 1, 1, 30, 3])
        torch.Size([2, 1, 1, 32, 3])
patches torch.Size([2, 1, 30, 3])
        torch.Size([2, 1, 32, 3])
patches torch.Size([2, 30, 1, 3])
        torch.Size([2, 32, 1, 3])
result torch.Size([2, 2, 30])
       torch.Size([2, 2, 32])
result torch.Size([2, 2, 30])
       torch.Size([2, 2, 32])
result torch.Size([2, 2, 30])
       torch.Size([2, 2, 32])
True


In [88]:
batch_size, width = 2, 32
channels_in, channels_out = 1, 2
kernel_size, stride = 3, 1

x = torch.randn(batch_size, channels_in, width)

# setup the parameters for Conv2d

width_unpadded = width - kernel_size + 1

# Create conv
conv = nn.Conv1d(channels_in, channels_out, kernel_size, stride=stride, bias=True, padding = 0, padding_mode='circular')
conv_pad = nn.Conv1d(channels_in, channels_out, kernel_size, stride=stride, bias=True, padding = 1, padding_mode='circular')

weights, biases = conv.weight, conv.bias
weights_pad, biases_pad = conv_pad.weight, conv_pad.bias
true_conv = conv(x)
true_conv_pad = conv_pad(x)

x_pad = F.pad(x, (1, 1), mode='circular')

patches = x.unsqueeze(2).unfold(3, kernel_size, 1)
patches_pad = x_pad.unsqueeze(2).unfold(3, kernel_size, 1)
check(patches, patches_pad, 'patches')

patches = patches.contiguous().view(batch_size, channels_in, width_unpadded, kernel_size)
patches_pad = patches_pad.contiguous().view(batch_size, channels_in, width, kernel_size)
check(patches, patches_pad, 'patches')

# Shift the windows into the batch dimension using permute
# nb_windows = patches.size(2)
patches = patches.permute(0, 2, 1, 3) # # batch_size, nb_windows, channels, kernel_size 
patches_pad = patches_pad.permute(0, 2, 1, 3) # # batch_size, nb_windows, channels, kernel_size 
check(patches, patches_pad, 'patches')

# Multiply the patches with the weights in order to calculate the conv
result = (patches.unsqueeze(2) * weights.unsqueeze(0)).sum([3, 4])
patches_pad = (patches_pad.unsqueeze(2) * weights_pad.unsqueeze(0)).sum([3, 4])
result = result.permute(0, 2, 1) # batch_size, out_channels, output_pixels
result_pad = patches_pad.permute(0, 2, 1) # batch_size, out_channels, output_pixels
check(result, result_pad, 'result')

# Add the bias
result += biases.unsqueeze(0).unsqueeze(2)
result_pad += biases.unsqueeze(0).unsqueeze(2)
check(result, result_pad, 'result')

result = result.view(batch_size, channels_out, width_unpadded)
result_pad = result_pad.view(batch_size, channels_out, width)
check(result, result_pad, 'result')

print(torch.allclose(result, true_conv))#, atol=1e-6, rtol=1e-6)
print(torch.allclose(result_pad, true_conv_pad))#, atol=1e-6, rtol=1e-6)


patches torch.Size([2, 1, 1, 30, 3])
        torch.Size([2, 1, 1, 32, 3])
patches torch.Size([2, 1, 30, 3])
        torch.Size([2, 1, 32, 3])
patches torch.Size([2, 30, 1, 3])
        torch.Size([2, 32, 1, 3])
result torch.Size([2, 2, 30])
       torch.Size([2, 2, 32])
result torch.Size([2, 2, 30])
       torch.Size([2, 2, 32])
result torch.Size([2, 2, 30])
       torch.Size([2, 2, 32])
True
False


In [97]:
batch_size, width = 2, 32
channels_in, channels_out = 1, 2
kernel_size, stride = 3, 1

x = torch.randn(batch_size, channels_in, width)

# setup the parameters for Conv2d

width_unpadded = width - kernel_size + 1

# Create conv
conv = nn.Conv1d(channels_in, channels_out, kernel_size, stride=stride, bias=True, padding = 0, padding_mode='circular')
conv_pad = nn.Conv1d(channels_in, channels_out, kernel_size, stride=stride, bias=True, padding = 1, padding_mode='circular')

weights, biases = conv.weight, conv.bias
weights_pad, biases_pad = conv_pad.weight, conv_pad.bias
true_conv = conv(x)
true_conv_pad = conv_pad(x)

x_pad = F.pad(x, (1, 1), mode='circular')

patches = x.unsqueeze(2).unfold(3, kernel_size, 1)
patches_pad = x_pad.unsqueeze(2).unfold(3, kernel_size, 1)
check(patches, patches_pad, 'patches')

patches = patches.contiguous().view(batch_size, channels_in, width_unpadded, kernel_size)
patches_pad = patches_pad.contiguous().view(batch_size, channels_in, width, kernel_size)
check(patches, patches_pad, 'patches')

# Shift the windows into the batch dimension using permute
# nb_windows = patches.size(2)
patches = patches.permute(0, 2, 1, 3) # # batch_size, nb_windows, channels, kernel_size 
patches_pad = patches_pad.permute(0, 2, 1, 3) # # batch_size, nb_windows, channels, kernel_size 
check(patches, patches_pad, 'patches')

# Multiply the patches with the weights in order to calculate the conv
result = (patches.unsqueeze(2) * weights.unsqueeze(0)).sum([3, 4])
patches_pad = (patches_pad.unsqueeze(2) * weights_pad.unsqueeze(0)).sum([3, 4])
result = result.permute(0, 2, 1) # batch_size, out_channels, output_pixels
result_pad = patches_pad.permute(0, 2, 1) # batch_size, out_channels, output_pixels
check(result, result_pad, 'result')

# Add the bias
result += biases.unsqueeze(0).unsqueeze(2)
result_pad += biases_pad.unsqueeze(0).unsqueeze(2)
check(result, result_pad, 'result')

result = result.view(batch_size, channels_out, width_unpadded)
result_pad = result_pad.view(batch_size, channels_out, width)
check(result, result_pad, 'result')

print(torch.allclose(result, true_conv))#, atol=1e-6, rtol=1e-6)
print(torch.allclose(result_pad, true_conv_pad))#, atol=1e-6, rtol=1e-6)


patches torch.Size([2, 1, 1, 30, 3])
        torch.Size([2, 1, 1, 32, 3])
patches torch.Size([2, 1, 30, 3])
        torch.Size([2, 1, 32, 3])
patches torch.Size([2, 30, 1, 3])
        torch.Size([2, 32, 1, 3])
result torch.Size([2, 2, 30])
       torch.Size([2, 2, 32])
result torch.Size([2, 2, 30])
       torch.Size([2, 2, 32])
result torch.Size([2, 2, 30])
       torch.Size([2, 2, 32])
True
True


patches torch.Size([2, 1, 1, 30, 3])
        torch.Size([2, 1, 1, 32, 3])
patches torch.Size([2, 1, 30, 3])
        torch.Size([2, 1, 32, 3])
patches torch.Size([2, 30, 1, 3])
        torch.Size([2, 32, 1, 3])
result torch.Size([2, 2, 30])
       torch.Size([2, 2, 32])
result torch.Size([2, 2, 30])
       torch.Size([2, 2, 32])
result torch.Size([2, 2, 30])
       torch.Size([2, 2, 32])
True


In [85]:
batch_size, width = 2, 32
channels_in, channels_out = 1, 2
kernel_size, stride = 3, 1

x = torch.randn(batch_size, channels_in, width)

# setup the parameters for Conv2d

width_out = width - kernel_size + 1
width_out_pad = width

# Create conv
conv = nn.Conv1d(channels_in, channels_out, kernel_size, stride=stride, bias=True,     padding = 0, padding_mode='zeros')
conv_pad = nn.Conv1d(channels_in, channels_out, kernel_size, stride=stride, bias=True, padding = 1, padding_mode='circular')

weights, biases = conv_pad.weight, conv_pad.bias
true_conv = conv(x)

weights_pad, biases_pad = conv.weight, conv.bias
true_conv_pad = conv(x)

x_pad = F.pad(x, (1, 1), mode='circular')

patches = x.unsqueeze(2).unfold(3, kernel_size, 1)
patches_pad = x_pad.unsqueeze(2).unfold(3, kernel_size, 1)
check(patches, patches_pad, 'patches')

patches = patches.contiguous().view(batch_size, channels_in, width_out, kernel_size)
patches_pad = patches_pad.contiguous().view(batch_size, channels_in, width, kernel_size)
check(patches, patches_pad, 'patches')

# Shift the windows into the batch dimension using permute
# nb_windows = patches.size(2)
patches = patches.permute(0, 2, 1, 3) # # batch_size, nb_windows, channels, kernel_size 

# Multiply the patches with the weights in order to calculate the conv
result = (patches.unsqueeze(2) * weights.unsqueeze(0)).sum([3, 4])
patches_pad = (patches_pad.unsqueeze(2) * weights.unsqueeze(0)).sum([3, 4])
result = result.permute(0, 2, 1) # batch_size, out_channels, output_pixels
result_pad = patches_pad.permute(0, 2, 1) # batch_size, out_channels, output_pixels
check(result, result_pad, 'result')

# Add the bias
result += biases.unsqueeze(0).unsqueeze(2)
result_pad += biases.unsqueeze(0).unsqueeze(2)
check(result, result_pad, 'result')

result = result.view(batch_size, channels_out, width_out)
result_pad = result_pad.view(batch_size, channels_out, width_out_pad)
check(result, result_pad, 'result')

torch.allclose(result_pad, true_conv_pad, atol=1e-8, rtol=1e-6)
print(torch.allclose(result, true_conv))#, atol=1e-6, rtol=1e-6)


patches torch.Size([2, 1, 1, 30, 3])
        torch.Size([2, 1, 1, 32, 3])
patches torch.Size([2, 1, 30, 3])
        torch.Size([2, 1, 32, 3])
result torch.Size([2, 2, 30])
       torch.Size([2, 2, 1])
result torch.Size([2, 2, 30])
       torch.Size([2, 2, 1])


RuntimeError: shape '[2, 2, 32]' is invalid for input of size 4