# Convolution operation

In [1]:
import torch

def convolve2d(input, kernel, stride, padding=0, dilation=1):
    """Slide a 2-D kernel over a 2-D tensor and sum the element-wise products.

    No kernel flipping is performed, i.e. this is the cross-correlation that
    deep-learning libraries call "convolution".

    Args:
        input: 2-D tensor of shape (height, width).
        kernel: 2-D tensor of shape (kernel_height, kernel_width).
        stride: Step between consecutive window positions (same on both axes).
        padding: Number of zeros added on every side of ``input``.
        dilation: Spacing between the input elements sampled by the kernel.

    Returns:
        A 2-D tensor of shape (oh, ow) whose dtype is whatever
        ``torch.sum(region * kernel)`` yields, so floating-point inputs are
        no longer truncated to integers.

    Raises:
        ValueError: If the dilated kernel does not fit into the padded input.
    """
    ih, iw = input.shape
    kh, kw = kernel.shape

    # Extent of the input covered by the dilated kernel along each axis.
    span_h = dilation * (kh - 1) + 1
    span_w = dilation * (kw - 1) + 1

    # Number of window positions that fit into the padded input.
    oh = (ih + 2 * padding - span_h) // stride + 1
    ow = (iw + 2 * padding - span_w) // stride + 1
    if oh < 1 or ow < 1:
        raise ValueError(
            f"kernel extent ({span_h}x{span_w}) does not fit into the padded "
            f"input ({ih + 2 * padding}x{iw + 2 * padding})"
        )

    if padding:
        # Four entries: left/right of the last axis, then top/bottom of the
        # first axis; the default fill value is zero.
        input = torch.nn.functional.pad(input, pad=[padding] * 4)

    rows = []
    for i in range(0, oh * stride, stride):
        # Stepping the slice by ``dilation`` picks exactly kh x kw elements.
        row = [
            torch.sum(input[i:i + span_h:dilation, j:j + span_w:dilation] * kernel)
            for j in range(0, ow * stride, stride)
        ]
        rows.append(torch.stack(row))

    # Stacking the 0-dim sums keeps their dtype instead of casting to int.
    return torch.stack(rows)

def print_tensor(tensor, title=""):
    """Print a header, the tensor itself, and a trailing blank line.

    The header is the upper-cased ``title`` followed by the tensor's
    dimensions joined with "x" in parentheses, e.g. ``INPUT (6x6)``.
    """
    dims = "x".join(map(str, tensor.shape))
    header = f"{title.upper()} ({dims})"
    # The empty third item produces the blank separator line.
    print(header, tensor, "", sep="\n")




## Input and Kernel

In [2]:
# Fix the seed so the random input below is the same on every run.
torch.manual_seed(42)
input = torch.randint(0, 10, (6, 6))

# 3x3 kernel with a single 1 at its centre: each output value is simply the
# input element that lies under the centre of the window.
kernel = torch.tensor([
    [0, 0, 0],
    [0, 1, 0],
    [0, 0, 0],
])

print_tensor(input, title="input")
print_tensor(kernel, title="kernel")

INPUT (6x6)
tensor([[2, 7, 6, 4, 6, 5],
        [0, 4, 0, 3, 8, 4],
        [0, 4, 1, 2, 5, 5],
        [7, 6, 9, 6, 3, 1],
        [9, 3, 1, 9, 7, 9],
        [2, 0, 5, 9, 3, 4]])

KERNEL (3x3)
tensor([[0, 0, 0],
        [0, 1, 0],
        [0, 0, 0]])



## Convolution stride 1 padding 0 dilation 1

In [3]:
# Show the operands again so this cell's output is self-contained.
print_tensor(input, title="input")
print_tensor(kernel, title="kernel")

# Baseline configuration: unit stride, no padding, no dilation.
stride, padding, dilation = 1, 0, 1
print(f"stride: {stride}  padding: {padding}  dilation: {dilation}")

output = convolve2d(input, kernel, stride, padding, dilation)
print_tensor(output, title="output")

# Reference result from PyTorch: two leading singleton axes are added to both
# operands for the call and squeezed away again afterwards.
torch_output = torch.nn.functional.conv2d(
    input=input[None, None],
    weight=kernel[None, None],
    stride=stride,
    padding=padding,
    dilation=dilation,
).squeeze()

assert torch.equal(output, torch_output)

INPUT (6x6)
tensor([[2, 7, 6, 4, 6, 5],
        [0, 4, 0, 3, 8, 4],
        [0, 4, 1, 2, 5, 5],
        [7, 6, 9, 6, 3, 1],
        [9, 3, 1, 9, 7, 9],
        [2, 0, 5, 9, 3, 4]])

KERNEL (3x3)
tensor([[0, 0, 0],
        [0, 1, 0],
        [0, 0, 0]])

stride: 1  padding: 0  dilation: 1
OUTPUT (4x4)
tensor([[4, 0, 3, 8],
        [4, 1, 2, 5],
        [6, 9, 6, 3],
        [3, 1, 9, 7]])



## Convolution stride 2 padding 0 dilation 1

In [4]:
# Show the operands again so this cell's output is self-contained.
print_tensor(input, title="input")
print_tensor(kernel, title="kernel")

# Only the stride differs from the baseline: the window moves two steps at a time.
stride, padding, dilation = 2, 0, 1
print(f"stride: {stride}  padding: {padding}  dilation: {dilation}")

output = convolve2d(input, kernel, stride, padding, dilation)
print_tensor(output, title="output")

# Reference result from PyTorch: two leading singleton axes are added to both
# operands for the call and squeezed away again afterwards.
torch_output = torch.nn.functional.conv2d(
    input=input[None, None],
    weight=kernel[None, None],
    stride=stride,
    padding=padding,
    dilation=dilation,
).squeeze()

assert torch.equal(output, torch_output)

INPUT (6x6)
tensor([[2, 7, 6, 4, 6, 5],
        [0, 4, 0, 3, 8, 4],
        [0, 4, 1, 2, 5, 5],
        [7, 6, 9, 6, 3, 1],
        [9, 3, 1, 9, 7, 9],
        [2, 0, 5, 9, 3, 4]])

KERNEL (3x3)
tensor([[0, 0, 0],
        [0, 1, 0],
        [0, 0, 0]])

stride: 2  padding: 0  dilation: 1
OUTPUT (2x2)
tensor([[4, 3],
        [6, 6]])



## Convolution stride 1 padding 1 dilation 1

In [5]:
# Show the operands again so this cell's output is self-contained.
print_tensor(input, title="input")
print_tensor(kernel, title="kernel")

# Only the padding differs from the baseline: one ring of zeros around the input.
stride, padding, dilation = 1, 1, 1
print(f"stride: {stride}  padding: {padding}  dilation: {dilation}")

output = convolve2d(input, kernel, stride, padding, dilation)
print_tensor(output, title="output")

# Reference result from PyTorch: two leading singleton axes are added to both
# operands for the call and squeezed away again afterwards.
torch_output = torch.nn.functional.conv2d(
    input=input[None, None],
    weight=kernel[None, None],
    stride=stride,
    padding=padding,
    dilation=dilation,
).squeeze()

assert torch.equal(output, torch_output)

INPUT (6x6)
tensor([[2, 7, 6, 4, 6, 5],
        [0, 4, 0, 3, 8, 4],
        [0, 4, 1, 2, 5, 5],
        [7, 6, 9, 6, 3, 1],
        [9, 3, 1, 9, 7, 9],
        [2, 0, 5, 9, 3, 4]])

KERNEL (3x3)
tensor([[0, 0, 0],
        [0, 1, 0],
        [0, 0, 0]])

stride: 1  padding: 1  dilation: 1
OUTPUT (6x6)
tensor([[2, 7, 6, 4, 6, 5],
        [0, 4, 0, 3, 8, 4],
        [0, 4, 1, 2, 5, 5],
        [7, 6, 9, 6, 3, 1],
        [9, 3, 1, 9, 7, 9],
        [2, 0, 5, 9, 3, 4]])



## Convolution stride 1 padding 0 dilation 2

In [6]:
# Show the operands again so this cell's output is self-contained.
print_tensor(input, title="input")
print_tensor(kernel, title="kernel")

# Only the dilation differs from the baseline: the kernel samples every
# second input element.
stride, padding, dilation = 1, 0, 2
print(f"stride: {stride}  padding: {padding}  dilation: {dilation}")

output = convolve2d(input, kernel, stride, padding, dilation)
print_tensor(output, title="output")

# Reference result from PyTorch: two leading singleton axes are added to both
# operands for the call and squeezed away again afterwards.
torch_output = torch.nn.functional.conv2d(
    input=input[None, None],
    weight=kernel[None, None],
    stride=stride,
    padding=padding,
    dilation=dilation,
).squeeze()

assert torch.equal(output, torch_output)

INPUT (6x6)
tensor([[2, 7, 6, 4, 6, 5],
        [0, 4, 0, 3, 8, 4],
        [0, 4, 1, 2, 5, 5],
        [7, 6, 9, 6, 3, 1],
        [9, 3, 1, 9, 7, 9],
        [2, 0, 5, 9, 3, 4]])

KERNEL (3x3)
tensor([[0, 0, 0],
        [0, 1, 0],
        [0, 0, 0]])

stride: 1  padding: 0  dilation: 2
OUTPUT (2x2)
tensor([[1, 2],
        [9, 6]])



# Batches and Channels

In [7]:
# Fix the seed so the random input below is the same on every run.
torch.manual_seed(42)

# Input dimensions: batch size, channels, height, width.
iN, iC, iH, iW = 1, 2, 6, 6
input = torch.randint(low=0, high=10, size=(iN, iC, iH, iW))

groups = 1
# oC - number of kernels (output channels)
# fC - input channels seen by each kernel (iC // groups)
oC, fC, kH, kW = 4, iC // groups, 3, 3
# Filter bank of shape (oC, fC, kH, kW): height comes before width.
filter = torch.zeros((oC, fC, kH, kW)).long()
# Give every (output channel, input channel) kernel a single 1, advancing one
# position in row-major order for each successive kernel.
pos = 0
for i in range(oC):
    for j in range(fC):
        fi, fj = divmod(pos, kW)
        filter[i, j][fi, fj] = 1
        pos += 1

stride = 1
oH = (iH - kH) // stride + 1
oW = (iW - kW) // stride + 1
result = torch.zeros((iN, oC, oH, oW), dtype=torch.long)
for i in range(iN):
    for j in range(oC):
        # An output channel is the sum of the 2-D convolutions of every input
        # channel with its matching kernel. Indexing the filter by the input
        # channel is only valid while groups == 1 (so that fC == iC).
        for k in range(iC):
            channel = input[i, k]
            kernel = filter[j, k]
            total = convolve2d(channel, kernel, stride=stride)
            result[i, j] += total


print_tensor(input, title="input")
print_tensor(filter, title="filter")
print_tensor(result, title="output")

# Cross-check the hand-rolled result against PyTorch, as the earlier cells do.
torch_output = torch.nn.functional.conv2d(input=input, weight=filter, stride=stride)
assert torch.equal(result, torch_output)


INPUT (1x2x6x6)
tensor([[[[2, 7, 6, 4, 6, 5],
          [0, 4, 0, 3, 8, 4],
          [0, 4, 1, 2, 5, 5],
          [7, 6, 9, 6, 3, 1],
          [9, 3, 1, 9, 7, 9],
          [2, 0, 5, 9, 3, 4]],

         [[9, 6, 2, 0, 6, 2],
          [7, 9, 7, 3, 3, 4],
          [3, 7, 0, 9, 0, 9],
          [6, 9, 5, 4, 8, 8],
          [6, 0, 0, 0, 0, 1],
          [3, 0, 1, 1, 7, 9]]]])

FILTER (4x2x3x3)
tensor([[[[1, 0, 0],
          [0, 0, 0],
          [0, 0, 0]],

         [[0, 1, 0],
          [0, 0, 0],
          [0, 0, 0]]],


        [[[0, 0, 1],
          [0, 0, 0],
          [0, 0, 0]],

         [[0, 0, 0],
          [1, 0, 0],
          [0, 0, 0]]],


        [[[0, 0, 0],
          [0, 1, 0],
          [0, 0, 0]],

         [[0, 0, 0],
          [0, 0, 1],
          [0, 0, 0]]],


        [[[0, 0, 0],
          [0, 0, 0],
          [1, 0, 0]],

         [[0, 0, 0],
          [0, 0, 0],
          [0, 1, 0]]]])

OUTPUT (1x4x4x4)
tensor([[[[ 8,  9,  6, 10],
          [ 9, 11,  3,  6],
