In [2]:
import torch
from torch import nn
from d2l import torch as d2l

In [7]:
def corr2d(X: torch.Tensor, K: torch.Tensor) -> torch.Tensor:
    """
    Compute the 2D cross-correlation of an input with a kernel.

    The kernel is slid over every position where it fits entirely inside the
    input (no padding, stride 1). Each output element is the sum of the
    elementwise product between the kernel and the input window beneath it.

    Params:
    -------
    X: 2D input tensor of shape (H, W)
    K: 2D kernel tensor of shape (h, w)

    Returns:
    --------
    Tensor of shape (H - h + 1, W - w + 1) in the default floating-point
    dtype, allocated on the same device as X.
    """
    h, w = K.shape
    # Allocate the output next to the input so the result can be combined
    # with other tensors living on X's device (e.g. a layer's bias on a GPU).
    Y = torch.zeros((X.shape[0] - h + 1, X.shape[1] - w + 1), device=X.device)
    for i in range(Y.shape[0]):
        for j in range(Y.shape[1]):
            # Window of X currently covered by the kernel, weighted and summed.
            Y[i, j] = (X[i:i + h, j:j + w] * K).sum()
    return Y

In [8]:
# Worked example: a 3x3 input holding 0..8 and a 2x2 kernel holding 0..3.
X = torch.arange(9.0).reshape(3, 3)
K = torch.arange(4.0).reshape(2, 2)

corr2d(X, K)

tensor([[19., 25.],
        [37., 43.]])

In [11]:
class Conv2D(nn.Module):
    """Minimal 2D convolutional layer built on top of ``corr2d``.

    Holds a learnable kernel of shape ``kernel_size`` (drawn uniformly at
    random from [0, 1)) and a single learnable bias initialized to zero.
    """

    def __init__(self, kernel_size):
        super().__init__()
        initial_kernel = torch.rand(kernel_size)
        initial_bias = torch.zeros(1)
        # Registration order (weight, then bias) determines parameters() order.
        self.weight = nn.Parameter(initial_kernel)
        self.bias = nn.Parameter(initial_bias)

    def forward(self, x):
        """Cross-correlate ``x`` with the kernel, then add the bias."""
        correlated = corr2d(x, self.weight)
        return correlated + self.bias

In [13]:
# Build a layer with a 2x2 kernel and inspect its bias parameter.
layer = Conv2D(kernel_size=(2, 2))
layer.bias

Parameter containing:
tensor([0.], requires_grad=True)

In [15]:
# 6x8 "image": columns 0-1 and 6-7 are ones and columns 2-5 are zeros,
# which gives one vertical edge on each side of the zero band.
X = torch.zeros((6, 8))
X[:, :2] = 1
X[:, 6:] = 1
X

tensor([[1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.]])

In [16]:
# 1x2 difference kernel: each output is (left neighbour - right neighbour),
# so it is +1 at a 1 -> 0 transition, -1 at a 0 -> 1 transition, 0 elsewhere.
K = torch.tensor([1.0, -1.0]).reshape(1, 2)
Y = corr2d(X, K)
Y

tensor([[ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.]])

In [18]:
# On the transposed image the edges run horizontally, so this kernel (which
# only compares horizontal neighbours) produces all zeros.
corr2d(torch.t(X), K)

tensor([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]])

### LEARNING A KERNEL

In [41]:
# Learn the edge-detection kernel from the (X, Y) pair by gradient descent.
# The layer has 1 input channel, 1 output channel and a 1x2 kernel; the bias
# is disabled so the kernel is the only learned parameter.
conv2d = nn.Conv2d(1, 1, kernel_size=(1, 2), bias=False)

# A 2D convolutional layer uses 4D input and output in the format
# (example, channel, height, width), where the batch size (number of examples
# in the batch) and the number of channels are both 1.
X = X.reshape((1, 1, 6, 8))
Y = Y.reshape((1, 1, 6, 7))
lr = 3e-2  # learning rate
num_epochs = 20

for i in range(num_epochs):
    Y_hat = conv2d(X)
    l = (Y_hat - Y) ** 2  # elementwise squared error
    total_loss = l.sum()  # reduced once; reused for backward() and logging
    conv2d.zero_grad()
    total_loss.backward()
    # Update the kernel (weights) in place. no_grad() keeps the update out of
    # the autograd graph without going through the legacy `.data` attribute.
    with torch.no_grad():
        conv2d.weight -= lr * conv2d.weight.grad
    if (i + 1) % 5 == 0:
        print(f'epoch {i + 1}, loss {total_loss:.3f}')


epoch 5, loss 1.035
epoch 10, loss 0.080
epoch 15, loss 0.008
epoch 20, loss 0.001


In [42]:
# Inspect the learned kernel; the target kernel used to generate Y was [1, -1].
conv2d.weight.detach()

tensor([[[[ 1.0029, -0.9968]]]])

In [37]:
# Scratch check: the scientific-notation literal 1e3 is the float 1000.0.
1e3

1000.0