In [1]:
import torch
from torch import nn
from d2l import torch as d2l

In [12]:
def corr2d(X, K):
    h, w = K.shape
    Y = torch.zeros(size=(X.shape[0] - h + 1, X.shape[1] - w + 1))
    for i in range(Y.shape[0]):
        for j in range(Y.shape[1]):
            Y[i, j] = (X[i:i+h, j:j+w] * K).sum()
    return Y

In [3]:
class Conv2D(nn.Module):
    def __init__(self, kernel_size):
        super().__init__()
        self.weight = nn.Parameter(torch.rand(kernel_size))
        self.bias = nn.Parameter(torch.zeros(1))
        
    def forward(self, x):
        return corr2d(x, self.weight) + self.bias

In [4]:
X = torch.ones((6,8))
X[:,2:6] = 0
X

tensor([[1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.]])

In [5]:
K = torch.tensor([[1.,-1.]])

In [7]:
K

tensor([[ 1., -1.]])

In [13]:
Y = corr2d(X, K)
Y

tensor([[ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.]])

In [145]:
conv2d = nn.Conv2d(1,1,(1, 2),bias=False)
X = X.reshape((1, 1, 6, 8))
Y = Y.reshape((1, 1, 6, 7))
lr = 3e-2

for i in range(1000):
    Y_hat = conv2d(X)
    l = (Y_hat - Y) ** 2
    conv2d.zero_grad()
    l.mean().backward()
    conv2d.weight.data[:] -= lr * conv2d.weight.grad
    print(f'epoch: {i}, lr: {l.sum():.3f}')

epoch: 0, lr: 31.704
epoch: 1, lr: 31.139
epoch: 2, lr: 30.585
epoch: 3, lr: 30.043
epoch: 4, lr: 29.512
epoch: 5, lr: 28.992
epoch: 6, lr: 28.481
epoch: 7, lr: 27.981
epoch: 8, lr: 27.490
epoch: 9, lr: 27.009
epoch: 10, lr: 26.537
epoch: 11, lr: 26.074
epoch: 12, lr: 25.620
epoch: 13, lr: 25.175
epoch: 14, lr: 24.737
epoch: 15, lr: 24.308
epoch: 16, lr: 23.887
epoch: 17, lr: 23.473
epoch: 18, lr: 23.067
epoch: 19, lr: 22.668
epoch: 20, lr: 22.277
epoch: 21, lr: 21.892
epoch: 22, lr: 21.515
epoch: 23, lr: 21.144
epoch: 24, lr: 20.780
epoch: 25, lr: 20.422
epoch: 26, lr: 20.071
epoch: 27, lr: 19.726
epoch: 28, lr: 19.387
epoch: 29, lr: 19.054
epoch: 30, lr: 18.727
epoch: 31, lr: 18.406
epoch: 32, lr: 18.090
epoch: 33, lr: 17.780
epoch: 34, lr: 17.475
epoch: 35, lr: 17.175
epoch: 36, lr: 16.881
epoch: 37, lr: 16.592
epoch: 38, lr: 16.308
epoch: 39, lr: 16.029
epoch: 40, lr: 15.754
epoch: 41, lr: 15.485
epoch: 42, lr: 15.220
epoch: 43, lr: 14.959
epoch: 44, lr: 14.703
epoch: 45, lr: 14.45

In [142]:
conv2d.weight.data

tensor([[[[ 0.7077, -0.7080]]]])