In [2]:
## Import common packages

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline
import torch
import torch.nn as nn
import torch.nn.functional as F

In [99]:
N = 64
C_i, C_o = 3, 6
H, W = 28, 28
K = 5
dtype=torch.float64

In [100]:
# N = 1
# C_i, C_o = 1, 1
# H, W = 2, 2
# K = 2

In [101]:
input_feats = torch.rand((N, C_i, H, W), dtype=dtype)
weight = torch.rand((C_o, C_i, K, K))
bias = torch.rand(C_o)
stride = 1
padding = 0
kernel_size = weight.size(2)

input_feats.shape, weight.shape, bias.shape

(torch.Size([64, 3, 28, 28]), torch.Size([6, 3, 5, 5]), torch.Size([6]))

In [102]:
from torch.nn import Conv2d

conv2d = Conv2d(C_i, C_o, kernel_size, stride=stride, padding=padding, dtype=dtype)

conv2d_output = conv2d(input_feats)

weight = torch.clone(conv2d.weight)
bias = torch.clone(conv2d.bias)

conv2d_output.shape, weight.shape, bias.shape

(torch.Size([64, 6, 24, 24]), torch.Size([6, 3, 5, 5]), torch.Size([6]))

In [103]:
# from torch.nn.functional import fold, unfold

# input_feats[0, :, :5, :5][0, 1, 0], unfold(input_feats[0, :, :5, :5], kernel_size)[5, 0]

In [104]:
from torch.nn.functional import fold, unfold

input_unfolded = unfold(input_feats, kernel_size, padding=padding, stride=stride)
weight_unfolded = weight.view(C_o, -1)

output_unfolded = weight_unfolded @ input_unfolded
if bias is not None:
    output_unfolded += bias.view(-1, 1)
H_o = (H + 2 * padding - kernel_size) // stride + 1
output = output_unfolded.view(N, C_o, H_o, -1)

input_unfolded.shape, weight_unfolded.shape, output_unfolded.shape, output.shape

(torch.Size([64, 75, 576]),
 torch.Size([6, 75]),
 torch.Size([64, 6, 576]),
 torch.Size([64, 6, 24, 24]))

In [108]:
import math
for i, (x, y) in enumerate(zip(output.flatten(), conv2d_output.flatten())):
    if not math.isclose(x.item(), y.item()):
        print(i, x.item(), y.item())
        break

In [110]:
from torch.nn.functional import fold, unfold

input_feats = torch.tensor([[[[1, 2], [3, 4]], [[5, 6], [7, 8]]]], dtype=torch.float64)
unfold(input_feats, 2)

tensor([[[1.],
         [2.],
         [3.],
         [4.],
         [5.],
         [6.],
         [7.],
         [8.]]], dtype=torch.float64)

In [111]:
H_o, W_o = output.size(2), output.size(3)

H_o, W_o

(24, 24)

In [190]:
N = 64
C_i, C_o = 3, 6
H, W = 28, 28
K = 5
dtype=torch.float64
padding = 0
stride = 1

In [191]:
# N = 1
# C_i, C_o = 1, 1
# H, W = 1, 1
# K = 1
# dtype=torch.float64
# padding = 0
# stride = 1

In [192]:
torch.manual_seed(42)
input_feats = torch.rand((N, C_i, H, W), dtype=dtype)
kernel_size = K

In [193]:
from torch.nn import Conv2d

conv2d = Conv2d(C_i, C_o, kernel_size, padding=padding, stride=stride, dtype=dtype)
weight = torch.clone(conv2d.weight)
bias = torch.clone(conv2d.bias)

In [194]:
conv2d.weight.grad, conv2d.bias.grad = None, None
conv2d_output = conv2d(input_feats)
conv2d_loss = torch.sum(conv2d_output)
conv2d_loss.backward()

conv2d.weight.grad, conv2d.bias.grad

(tensor([[[[18440.4801, 18469.7358, 18466.8721, 18451.5637, 18451.9761],
           [18453.1814, 18475.4890, 18472.2463, 18459.1552, 18455.8155],
           [18452.8338, 18479.6673, 18474.2342, 18471.7449, 18469.8156],
           [18469.9589, 18493.1058, 18492.5274, 18485.5500, 18482.8736],
           [18483.2240, 18503.7942, 18505.1529, 18503.4853, 18501.9564]],
 
          [[18459.0968, 18460.5930, 18454.0254, 18456.0020, 18462.5861],
           [18449.8066, 18451.7519, 18449.0025, 18453.2529, 18455.9310],
           [18442.6837, 18441.5760, 18429.1081, 18446.9385, 18447.6622],
           [18454.9040, 18459.6661, 18451.6285, 18462.9197, 18460.1363],
           [18451.6130, 18456.9143, 18454.4624, 18475.3149, 18472.0842]],
 
          [[18411.8377, 18397.1826, 18388.4731, 18391.7583, 18378.3028],
           [18403.8543, 18395.1834, 18394.3676, 18391.7302, 18369.1994],
           [18380.8593, 18375.9233, 18375.8145, 18370.1944, 18351.2029],
           [18364.2336, 18358.7732, 18349.605

In [199]:
## Forward
input_unfolded = unfold(input_feats, kernel_size, padding=padding, stride=stride)
weight_unfolded = weight.view(C_o, -1)

output_unfolded = weight_unfolded @ input_unfolded
if bias is not None:
    output_unfolded += bias.view(-1, 1)
H_o = (H + 2 * padding - kernel_size) // stride + 1
output = output_unfolded.view(N, C_o, H_o, -1)

input_unfolded.shape, weight_unfolded.shape, output_unfolded.shape, output.shape

(torch.Size([64, 75, 576]),
 torch.Size([6, 75]),
 torch.Size([64, 6, 576]),
 torch.Size([64, 6, 24, 24]))

In [200]:
## Check output

import math
for i, (x, y) in enumerate(zip(output.flatten(), conv2d_output.flatten())):
    if not math.isclose(x.item(), y.item()):
        print(i, x.item(), y.item())
        break

In [204]:
## Backward

from torch.nn.functional import fold, unfold

H_o, W_o = output.size(2), output.size(3)
grad_output = torch.ones(N, C_o, H_o, W_o, dtype=dtype)
grad_output_unfolded = grad_output.view(N, C_o, -1)
input_transpose = torch.transpose(input_unfolded, 1, 2)

grad_weight_unfolded = grad_output_unfolded @ input_transpose
grad_weight = grad_weight_unfolded.view(C_o, -1, kernel_size, kernel_size)

if bias is not None:
    grad_bias = grad_output.sum((0, 2, 3))
    
grad_input_unfolded = weight_unfolded.T @ grad_output_unfolded
print(f'{grad_input_unfolded.shape=}')
grad_input = fold(grad_input_unfolded, (H, W), kernel_size, padding=padding, stride=stride)

grad_output_unfolded.shape, input_transpose.shape, grad_weight_unfolded.shape, grad_weight.shape, grad_bias.shape, grad_input_unfolded.shape, grad_input.shape

grad_input_unfolded.shape=torch.Size([64, 75, 576])


(torch.Size([64, 6, 576]),
 torch.Size([64, 576, 75]),
 torch.Size([64, 6, 75]),
 torch.Size([6, 192, 5, 5]),
 torch.Size([6]),
 torch.Size([64, 75, 576]),
 torch.Size([64, 3, 28, 28]))

In [None]:
grad_weight, grad_bias