In [1]:
import torch
from torch import nn
from d2l import torch as d2l

def corr2d(X, K):
    """Return the 2-D cross-correlation of ``X`` with the kernel ``K``.

    "Valid" mode: no padding and stride 1, so the result has shape
    ``(X.shape[0] - K.shape[0] + 1, X.shape[1] - K.shape[1] + 1)``.
    The result is allocated with ``torch.zeros`` (default dtype and device)
    and filled element by element; autograd records these assignments, so
    gradients still flow back to ``X`` and ``K``.
    """
    kh, kw = K.shape
    n_rows = X.shape[0] - kh + 1
    n_cols = X.shape[1] - kw + 1
    out = torch.zeros((n_rows, n_cols))
    for r in range(n_rows):
        for c in range(n_cols):
            # Elementwise product of K with the window whose top-left corner is (r, c), summed.
            out[r, c] = (X[r:r + kh, c:c + kw] * K).sum()
    return out

In [2]:
# Worked example: a 3x3 input and a 2x2 kernel give a 2x2 output.
X = torch.tensor([[0.0, 1.0, 2.0],
                  [3.0, 4.0, 5.0],
                  [6.0, 7.0, 8.0]])
K = torch.tensor([[0.0, 1.0],
                  [2.0, 3.0]])
corr2d(X, K)

tensor([[19., 25.],
        [37., 43.]])

In [11]:
# Convolution layer
class Conv2D(nn.Module):
    """Minimal 2-D convolution layer built on ``corr2d``.

    Holds a learnable kernel ``weight`` (initialised with ``torch.rand``,
    i.e. uniform on [0, 1)) and a learnable one-element ``bias``
    (initialised to zero). The input is expected to be 2-D, since
    ``corr2d`` indexes exactly two dimensions.
    """

    def __init__(self, kernel_size):
        super().__init__()
        init_kernel = torch.rand(kernel_size)
        self.weight = nn.Parameter(init_kernel)
        self.bias = nn.Parameter(torch.zeros(1))

    def forward(self, X):
        """Cross-correlate ``X`` with the kernel, then add the bias (broadcast)."""
        correlated = corr2d(X, self.weight)
        return correlated + self.bias

In [12]:
# Apply a freshly initialised layer with a 2x2 kernel to the 3x3 example X.
# The kernel comes from an unseeded torch.rand, so these values differ between runs.
net = Conv2D(kernel_size=(2, 2))
net(X)

tensor([[2.7653, 4.2916],
        [7.3444, 8.8707]], grad_fn=<AddBackward0>)

In [23]:
# Edge detection: a 6x8 "image" of 1s with a band of 0s in columns 2-5,
# so there is a 1->0 edge between columns 1/2 and a 0->1 edge between 5/6.
X = torch.ones((6, 8))
X[:, 2:6] = 0.0  # zero out the middle band of columns
X

tensor([[1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.]])

In [24]:
K = torch.tensor([1.0, -1.0]).reshape((1, 2))  # 1x2 horizontal-difference kernel

In [25]:
# Cross-correlate the image with K: 1 marks a 1->0 transition, -1 a 0->1 transition, 0 elsewhere.
Y = corr2d(X, K)
Y

tensor([[ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.]])

In [18]:
# Same kernel on the transposed image: the edges now run horizontally, so the
# horizontal 1x2 kernel does not respond to them. NOTE: this rebinds Y.
Y = corr2d(X.t(), K)

In [19]:
# Display the result of cross-correlating the transposed image with K (previous cell).
Y

tensor([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]])

In [32]:
# Learn the kernel: fit a bias-free 1x2 nn.Conv2d so that it reproduces the
# edge map that the hand-made kernel K produces on the image X.
conv2d = nn.Conv2d(1, 1, kernel_size=(1, 2), bias=False)

# nn.Conv2d takes (batch, channel, height, width) tensors. Build new 4-D tensors
# instead of reassigning X / Y. The target is recomputed from X and K rather than
# read from Y, because the transposed-input cell above rebinds Y to an 8x5 tensor
# that cannot be reshaped to (1, 1, 6, 7) under Restart & Run All. Leaving X
# untouched also makes this cell safe to re-run.
X_in = X.reshape((1, 1, 6, 8))
Y_target = corr2d(X, K).reshape((1, 1, 6, 7))

lr = 3e-2  # learning rate for plain gradient descent
num_epochs = 10
for i in range(num_epochs):
    conv2d.zero_grad()
    Y_hat = conv2d(X_in)
    loss = ((Y_hat - Y_target) ** 2).sum()  # summed squared error
    loss.backward()
    # Update the weight in place without recording the step in the autograd graph.
    with torch.no_grad():
        conv2d.weight -= lr * conv2d.weight.grad
    if (i + 1) % 2 == 0:
        print(f'epoch {i + 1}, loss {loss:.3f}')

epoch 2, loss 6.447
epoch 4, loss 2.131
epoch 6, loss 0.787
epoch 8, loss 0.308
epoch 10, loss 0.124


In [33]:
# Learned kernel, flattened to 1x2 for comparison with K.
conv2d.weight.detach().reshape((1, 2))

tensor([[ 1.0285, -0.9564]])