In [1]:
import numpy as np
import torch
from torch.nn.functional import conv2d as libConv2d

In [11]:
class Conv2D:
    def __init__(self, 
                 input_data, 
                 kernel_size: tuple | int, 
                 bias: float | None = None,
                 stride: int = 1, 
                 padding: tuple[int, int] | int | str = (0, 0), 
                 dilation: int = 1
                 ):
        self.input_data_numpy = input_data[0, 0].numpy()
        self.input_data_torch = input_data
        self.bias = bias
        if type(kernel_size) == int:
            self.kernel_size = (kernel_size, kernel_size) 
        else: 
            self.kernel_size = kernel_size
        self.stride  = stride
        self.dilation = dilation
        if isinstance(padding, tuple):
            self.padding = padding[0]
        elif padding == "same" and stride == 1:
                self.padding = self.kernel_size[0] - 1
        else: 
            self.padding = 0
        self.weight_tensor_torch = torch.randn(1, 1, *self.kernel_size)
        self.weight_tensor_numpy = self.weight_tensor_torch[0, 0].numpy()

    def conv2d(self):
        image_height, image_width = self.input_data_numpy.shape
        weight_height, weight_width = self.weight_tensor_numpy.shape
        H_out = int((image_height - self.dilation * (weight_height - 1) - 1 + 2 * self.padding) / self.stride) + 1
        W_out = int((image_width - self.dilation * (weight_width - 1) - 1 + 2 * self.padding) / self.stride) + 1
        if self.bias:
            result += self.bias
        if self.padding > 0:
            self.input_data_numpy = np.pad(self.input_data_numpy, self.padding, mode='constant') 
        else:
            self.input_data_numpy = self.input_data_numpy
        result = np.array([[np.sum(self.input_data_numpy[y * self.stride:y * self.stride + weight_height,
                x * self.stride:x * self.stride + weight_width] * self.weight_tensor_numpy) for x in range(W_out)] for y in range(H_out)])

        return result

    def torch_conv2d(self):
        return libConv2d(self.input_data_torch, self.weight_tensor_torch, self.bias, self.stride, self.padding,
                         self.dilation)

    def test(self, print_flag=False):
        my_conv2d = self.conv2d()
        torch_out = np.array(self.torch_conv2d())
        if print_flag:
            print(my_conv2d)
            print(torch_out[0, 0])
        print(np.allclose(my_conv2d, torch_out[0, 0]))

### Проверяем работоспособность класса `Conv2D`:

#### ТЕСТЫ

In [13]:
# Tests: compare the NumPy implementation with torch.nn.functional.conv2d
# for several kernel / padding / dilation / stride configurations.
torch.manual_seed(0)  # fixed seed so that Restart & Run All reproduces the same inputs and weights

test_cases = [
    {"kernel_size": 2},
    {"kernel_size": 2, "padding": "valid"},
    {"kernel_size": 2, "padding": "same"},
    {"kernel_size": 1, "padding": "same"},
    {"kernel_size": 1, "dilation": 3},
    {"kernel_size": 2, "stride": 4},
]

for case_kwargs in test_cases:
    # A fresh random 5x5 single-channel image for every configuration.
    image = torch.randn(1, 1, 5, 5)
    Conv2D(image, **case_kwargs).test()


True
True
True
True
True
True


### Результаты совпадают с `torch.nn.functional.conv2d`