# 24、手写并验证向量内积实现PyTorch二维卷积

In [21]:
from typing import Optional
import torch
from torch import Tensor
import torch.nn.functional as F

In [22]:
def my_conv(
    input_mat: torch.Tensor,
    kernel_mat: Tensor,
    bias: Optional[Tensor] = None,
    stride: int = 1,
    padding: int = 0,
) -> Tensor:
    """Cross-correlate a single 2-D matrix with a single 2-D kernel.

    PyTorch's ``conv2d`` works on 4-D input; this helper handles the plain
    2-D case. Each output element is the sum of the element-wise product of
    the kernel and the input window it currently covers.

    Args:
        input_mat: 2-D input matrix.
        kernel_mat: 2-D convolution kernel.
        bias: Optional tensor added to the whole output map; it must be
            broadcastable to the output shape. Defaults to None (no bias).
        stride: Step between consecutive windows, used for both dimensions.
            Defaults to 1.
        padding: Number of zeros added on every side of the input before
            sliding the kernel. Defaults to 0.

    Returns:
        2-D tensor with ``(size_in + 2 * padding - size_ker) // stride + 1``
        elements along each dimension.
    """
    if padding != 0:
        # Same amount on all four sides; F.pad fills with zeros by default.
        input_mat = F.pad(input_mat, [padding, padding, padding, padding])
    rows_in, cols_in = input_mat.shape
    rows_ker, cols_ker = kernel_mat.shape

    # input_mat is already padded here, so padding is not added again.
    rows_out = (rows_in - rows_ker) // stride + 1
    cols_out = (cols_in - cols_ker) // stride + 1
    out = torch.zeros(rows_out, cols_out)
    for i in range(rows_out):
        top = i * stride
        for j in range(cols_out):
            left = j * stride
            window = input_mat[top : top + rows_ker, left : left + cols_ker]
            out[i, j] = (window * kernel_mat).sum()
    # Compare against None explicitly: truth-testing a tensor raises for
    # multi-element tensors and is False for a single zero value.
    if bias is not None:
        out += bias
    return out

In [23]:
def my_conv2d(
    input: Tensor,
    kernel: Tensor,
    bias: Optional[Tensor] = None,
    stride: int = 1,
    padding: int = 0,
) -> Tensor:
    """Batched, multi-channel 2-D convolution built on top of ``my_conv``.

    Every output channel owns one 2-D kernel. That kernel is applied to each
    input channel separately and the per-channel results are summed, which
    matches ``F.conv2d`` with the same kernel repeated across input channels.

    Args:
        input: 4-D tensor of shape (batch_size, input_channels, rows, cols).
        kernel: 3-D tensor of shape (out_channels, kernel_rows, kernel_cols).
        bias: Optional tensor with one value per output channel (or a single
            value shared by all channels). Defaults to None (no bias).
        stride: Step between consecutive windows, used for both spatial
            dimensions. Defaults to 1.
        padding: Number of zeros added on every side of each input channel.
            Defaults to 0.

    Returns:
        4-D tensor of shape (batch_size, out_channels, rows_out, cols_out).
    """
    assert input.dim() == 4, "input must be 4-D: (batch, channels, rows, cols)"
    batch_size, input_channels, rows_in, cols_in = input.shape
    out_channels, rows_ker, cols_ker = kernel.shape

    # Padding enlarges each spatial dimension by 2 * padding; the output
    # buffer must account for it or the accumulation below mismatches.
    rows_out = (rows_in + 2 * padding - rows_ker) // stride + 1
    cols_out = (cols_in + 2 * padding - cols_ker) // stride + 1
    out = torch.zeros(batch_size, out_channels, rows_out, cols_out)
    for b in range(batch_size):
        for oc in range(out_channels):
            kernel_oc = kernel[oc]
            for ic in range(input_channels):
                # Bias is deliberately not forwarded: it must be added once
                # per output channel, not once per input channel.
                out[b, oc] += my_conv(
                    input[b, ic], kernel_oc, None, stride=stride, padding=padding
                )

    if bias is not None:
        # Shape (1, C, 1, 1) broadcasts each channel's bias over batch and
        # spatial positions.
        out += bias.reshape(1, -1, 1, 1)
    return out

In [24]:
def test_conv():
    """Check ``my_conv2d`` against ``F.conv2d`` on random single-channel input."""
    batch = torch.randn((3, 1, 4, 4))
    weights = torch.randn((2, 3, 3))
    # F.conv2d expects a 4-D weight, so insert a size-1 input-channel axis.
    expected = F.conv2d(batch, weights.unsqueeze(1), stride=1)
    actual = my_conv2d(batch, weights)
    assert actual.allclose(expected)
    print("All Test Passed!")


# Run the comparison against F.conv2d immediately when this cell executes.
test_conv()

All Test Passed!
