In [37]:
import torch

def convolve(input, kernel, stride=1):
    """Slide a 2-D ``kernel`` over a 2-D ``input`` with no padding ("valid" mode).

    Like ``torch.nn.functional.conv2d`` this is a cross-correlation: the kernel
    is applied as-is, not flipped.

    Args:
        input: 2-D tensor of shape (ih, iw). The name shadows the Python builtin;
            it is kept so existing keyword callers keep working.
        kernel: 2-D tensor of shape (kh, kw).
        stride: step between neighbouring windows along both axes; must be >= 1.

    Returns:
        Tensor of shape (ch, cw), where ch = (ih - kh) // stride + 1 and
        cw = (iw - kw) // stride + 1 (each clamped at 0 when the kernel does
        not fit). Integer/bool inputs accumulate in int64 (as ``torch.sum``
        does); floating/complex inputs keep their promoted dtype instead of
        being truncated to int.

    Raises:
        ValueError: if ``stride`` < 1.
    """
    if stride < 1:
        raise ValueError(f"stride must be >= 1, got {stride}")

    ih, iw = input.shape
    kh, kw = kernel.shape
    ch = (ih - kh) // stride + 1
    cw = (iw - kw) // stride + 1

    if ch <= 0 or cw <= 0:
        # The kernel does not fit anywhere; unfold() would raise, so return a
        # correctly shaped empty result with the dtype a full result would have.
        dtype = torch.promote_types(input.dtype, kernel.dtype)
        if not (dtype.is_floating_point or dtype.is_complex):
            dtype = torch.int64  # torch.sum promotes integral/bool inputs to int64
        return torch.zeros((max(ch, 0), max(cw, 0)), dtype=dtype)

    # (ih, iw) -> (ch, iw, kh) -> (ch, cw, kh, kw): one kh x kw window per output cell.
    windows = input.unfold(0, kh, stride).unfold(1, kw, stride)
    # Broadcast the kernel over every window and reduce each window to a scalar.
    return (windows * kernel).sum(dim=(-2, -1))

def print_tensor(tensor, title=""):
    """Print a header line ``"<title> (AxBx...)"``, then the tensor, then a blank line."""
    dims = "x".join(map(str, tensor.shape))
    print(f"{title} ({dims})", tensor, "", sep="\n")

torch.manual_seed(42)

# Named `image` / `center_tap` rather than `input` / `filter` so the Python
# builtins of those names are not shadowed.
image = torch.randint(low=0, high=10, size=(6, 6))
# Centre-tap kernel: each output value is the pixel under the kernel's centre,
# so convolving with it simply samples the input.
center_tap = torch.tensor([[0, 0, 0],
                           [0, 1, 0],
                           [0, 0, 0]])

print_tensor(image, title="input")
print_tensor(center_tap, title="filter")

stride = 2
convolved = convolve(image, center_tap, stride=stride)
print(f"stride: {stride}")
print_tensor(convolved, title="convolved")

# Sanity check: with a centre tap and no padding, the result is the input's
# interior (rows/cols 1..-2) sampled every `stride` pixels.
assert torch.equal(convolved, image[1:-1:stride, 1:-1:stride])




input (6x6)
tensor([[2, 7, 6, 4, 6, 5],
        [0, 4, 0, 3, 8, 4],
        [0, 4, 1, 2, 5, 5],
        [7, 6, 9, 6, 3, 1],
        [9, 3, 1, 9, 7, 9],
        [2, 0, 5, 9, 3, 4]])

filter (3x3)
tensor([[0, 0, 0],
        [0, 1, 0],
        [0, 0, 0]])

stride: 2
convolved (2x2)
tensor([[4, 3],
        [6, 6]])



In [76]:
torch.manual_seed(42)

# Input batch: iN images, iC channels, iH x iW pixels.
iN, iC, iH, iW = 1, 2, 6, 6
input = torch.randint(low=0, high=10, size=(iN, iC, iH, iW))

groups = 1
# oC - number of kernels (output channels)
# fC - input channels seen by each kernel = iC // groups
oC, fC, fH, fW = 4, iC // groups, 3, 3
assert iC % groups == 0 and oC % groups == 0, "groups must divide both iC and oC"

# Weight layout follows F.conv2d: (out_channels, in_channels // groups, kH, kW).
filter = torch.zeros((oC, fC, fH, fW), dtype=torch.long)
# Put a single 1 in each (kernel, channel) slice, walking the taps in row-major
# order; while oC * fC <= fH * fW every slice gets a different tap, and the
# modulo wraps around instead of indexing past the last row.
tap = 0
for o in range(oC):
    for c in range(fC):
        fi, fj = divmod(tap % (fH * fW), fW)
        filter[o, c, fi, fj] = 1
        tap += 1

stride = 1
cH = (iH - fH) // stride + 1
cW = (iW - fW) // stride + 1
oC_per_group = oC // groups
result = torch.zeros((iN, oC, cH, cW), dtype=torch.long)
for n in range(iN):
    for o in range(oC):
        g = o // oC_per_group  # group that output channel `o` belongs to
        for c in range(fC):
            # Output channel `o` only sees the fC input channels of its own group;
            # per-channel 2-D results are summed into that output channel.
            result[n, o] += convolve(input[n, g * fC + c], filter[o, c], stride)

print_tensor(input, title="input")
print_tensor(filter, title="filter")
print_tensor(result, title="convolved")


input (1x2x6x6)
tensor([[[[2, 7, 6, 4, 6, 5],
          [0, 4, 0, 3, 8, 4],
          [0, 4, 1, 2, 5, 5],
          [7, 6, 9, 6, 3, 1],
          [9, 3, 1, 9, 7, 9],
          [2, 0, 5, 9, 3, 4]],

         [[9, 6, 2, 0, 6, 2],
          [7, 9, 7, 3, 3, 4],
          [3, 7, 0, 9, 0, 9],
          [6, 9, 5, 4, 8, 8],
          [6, 0, 0, 0, 0, 1],
          [3, 0, 1, 1, 7, 9]]]])

filter (4x2x3x3)
tensor([[[[1, 0, 0],
          [0, 0, 0],
          [0, 0, 0]],

         [[0, 1, 0],
          [0, 0, 0],
          [0, 0, 0]]],


        [[[0, 0, 1],
          [0, 0, 0],
          [0, 0, 0]],

         [[0, 0, 0],
          [1, 0, 0],
          [0, 0, 0]]],


        [[[0, 0, 0],
          [0, 1, 0],
          [0, 0, 0]],

         [[0, 0, 0],
          [0, 0, 1],
          [0, 0, 0]]],


        [[[0, 0, 0],
          [0, 0, 0],
          [1, 0, 0]],

         [[0, 0, 0],
          [0, 0, 0],
          [0, 1, 0]]]])

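convolved (1x4x4x4)
tensor([[[[ 8,  9,  6, 10],
          [ 9, 11,  3,  6],
          [ 7,  4, 10,  2],
          [16, 11, 13, 14]],

         [[13, 13, 13,  8],
          [ 3, 10,  8, 13],
          [ 7, 11, 10,  9],
          [15,  6,  3,  1]],

         [[11,  3,  6, 12],
          [ 4, 10,  2, 14],
          [11, 13, 14, 11],
          [ 3,  1,  9,  8]],

         [[ 7,  4, 10,  2],
          [16, 11, 13, 14],
          [ 9,  3,  1,  9],
          [ 2,  1,  6, 16]]]])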

In [77]:
import torch.nn.functional as F

# Reference: PyTorch's own conv2d with the same weights, stride and groups
# as the hand-rolled loop in the previous cell.
reference = F.conv2d(input=input,
                     weight=filter,
                     stride=stride,
                     groups=groups)
# Integer arithmetic on both sides, so the results must match exactly.
assert torch.equal(reference, result), "manual convolution disagrees with F.conv2d"
reference

tensor([[[[ 8,  9,  6, 10],
          [ 9, 11,  3,  6],
          [ 7,  4, 10,  2],
          [16, 11, 13, 14]],

         [[13, 13, 13,  8],
          [ 3, 10,  8, 13],
          [ 7, 11, 10,  9],
          [15,  6,  3,  1]],

         [[11,  3,  6, 12],
          [ 4, 10,  2, 14],
          [11, 13, 14, 11],
          [ 3,  1,  9,  8]],

         [[ 7,  4, 10,  2],
          [16, 11, 13, 14],
          [ 9,  3,  1,  9],
          [ 2,  1,  6, 16]]]])

In [82]:

# Non-overlapping 2x2 max pool of the manual convolution result (stride equals
# kernel_size, which is also max_pool2d's default). The cast to float32 is kept
# from the original; presumably it exists because integer support in
# max_pool2d varies across PyTorch versions — TODO confirm.
F.max_pool2d(result.to(torch.float32), kernel_size=2, stride=2)

tensor([[[[11., 10.],
          [16., 14.]],

         [[13., 13.],
          [15., 10.]],

         [[11., 14.],
          [13., 14.]],

         [[16., 14.],
          [ 9., 16.]]]])