In [1]:
import torch
import numpy as np
from abc import ABC, abstractmethod

In [2]:
def calc_out_shape(input_matrix_shape, out_channels, kernel_size, stride, padding):
    """Return the NCHW output shape of a 2-D convolution (dilation 1).

    Args:
        input_matrix_shape: 4-tuple (batch, channels, height, width) of the input.
        out_channels: number of output channels (filters).
        kernel_size: side length of a square kernel.
        stride: step between neighbouring windows along both axes.
        padding: zero padding added on every side of both spatial axes.

    Returns:
        Tuple (batch, out_channels, output_height, output_width).
    """
    batch_size, _, input_height, input_width = input_matrix_shape

    def _positions(length):
        # How many kernel windows fit along an axis of `length` (plus padding).
        return (length + 2 * padding - kernel_size) // stride + 1

    return batch_size, out_channels, _positions(input_height), _positions(input_width)

In [3]:
class ABCConv2d(ABC):
    """Abstract base for the 2-D convolution layers compared in this notebook.

    Holds the layer hyper-parameters and the kernel; concrete subclasses
    implement ``__call__`` to apply the convolution to an input tensor.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        (self.in_channels, self.out_channels,
         self.kernel_size, self.stride) = (in_channels, out_channels,
                                           kernel_size, stride)

    def set_kernel(self, kernel):
        """Store ``kernel`` as the layer weights (available as ``self.kernel``)."""
        self.kernel = kernel

    @abstractmethod
    def __call__(self, input_tensor):
        """Apply the convolution to ``input_tensor``; subclasses must override."""

In [4]:
class Conv2d(ABCConv2d):
    """Reference layer backed by ``torch.nn.Conv2d`` (no padding, no bias)."""

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        # Initialise the base class so this layer exposes the same attributes
        # (in_channels, out_channels, kernel_size, stride) as other ABCConv2d layers.
        super().__init__(in_channels, out_channels, kernel_size, stride)
        self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size,
                                      stride, padding=0, bias=False)

    def set_kernel(self, kernel):
        """Use ``kernel`` as the convolution weights.

        ``kernel`` is expected in ``torch.nn.Conv2d`` weight layout
        (out_channels, in_channels, kH, kW).
        """
        # Keep ``self.kernel`` in sync, as the base-class contract does.
        super().set_kernel(kernel)
        # Replace the weight storage directly; the Parameter object itself is kept.
        self.conv2d.weight.data = kernel

    def __call__(self, input_tensor):
        """Apply the wrapped ``torch.nn.Conv2d`` to ``input_tensor``."""
        return self.conv2d(input_tensor)

In [5]:
def create_and_call_conv2d_layer(conv2d_layer_class, stride, kernel, input_matrix):
    """Instantiate ``conv2d_layer_class`` sized from ``kernel`` and apply it.

    Layer hyper-parameters are read from the kernel shape
    (out_channels, in_channels, kernel_size, ...).

    Args:
        conv2d_layer_class: callable taking (in_channels, out_channels,
            kernel_size, stride) and returning a layer with ``set_kernel``.
        stride: convolution stride passed to the layer constructor.
        kernel: weight tensor whose first three dimensions size the layer.
        input_matrix: value passed to the constructed layer.

    Returns:
        Whatever the layer returns for ``input_matrix``.
    """
    shape = kernel.shape
    out_channels, in_channels, kernel_size = shape[0], shape[1], shape[2]

    conv_layer = conv2d_layer_class(in_channels, out_channels, kernel_size, stride)
    conv_layer.set_kernel(kernel)
    return conv_layer(input_matrix)

In [6]:
def test_conv2d_layer(conv2d_layer_class, batch_size=2,
                      input_height=4, input_width=4, stride=2, kernel=None):
    """Compare ``conv2d_layer_class`` with the torch-backed ``Conv2d`` reference.

    Builds a deterministic input filled with 0, 1, 2, ... of shape
    (batch_size, in_channels, input_height, input_width), runs both layers
    with the same kernel and stride, and compares their outputs.

    Args:
        conv2d_layer_class: ABCConv2d subclass under test.
        batch_size, input_height, input_width: size of the generated input.
        stride: convolution stride passed to both layers.
        kernel: optional float32 weight tensor
            (out_channels, in_channels, k, k); defaults to the fixed
            1x3x3x3 kernel below.

    Returns:
        True if the two outputs agree within ``torch.allclose`` tolerances.
    """
    if kernel is None:
        kernel = torch.tensor(
            [[[[0., 1, 0],
               [1, 2, 1],
               [0, 1, 0]],

              [[1, 2, 1],
               [0, 3, 3],
               [0, 1, 10]],

              [[10, 11, 12],
               [13, 14, 15],
               [16, 17, 18]]]])

    in_channels = kernel.shape[1]

    # Explicit dtype replaces the legacy `out=torch.FloatTensor()` idiom.
    input_tensor = torch.arange(
        batch_size * in_channels * input_height * input_width,
        dtype=torch.float32,
    ).reshape(batch_size, in_channels, input_height, input_width)

    custom_conv2d_out = create_and_call_conv2d_layer(
        conv2d_layer_class, stride, kernel, input_tensor)
    conv2d_out = create_and_call_conv2d_layer(
        Conv2d, stride, kernel, input_tensor)

    return torch.allclose(custom_conv2d_out, conv2d_out)

In [7]:
class Conv2dMatrix(ABCConv2d):
    """2-D convolution (no padding, no bias) computed as a single matrix product.

    The kernel is "unsqueezed" into a matrix with one row per output element
    and one column per input element. The whole convolution then becomes
    ``kernel_unsqueezed @ flattened_input``.
    """

    def _unsqueeze_kernel(self, torch_input, output_height, output_width):
        """Expand ``self.kernel`` into a dense convolution matrix.

        Args:
            torch_input: input tensor (batch, in_channels, height, width); only
                its spatial size, dtype and device are used.
            output_height, output_width: spatial size of the convolution output.

        Returns:
            Tensor of shape (out_channels * output_height * output_width,
            in_channels * height * width). The row for output position
            (o, i, j), flattened in that order, holds kernel ``o`` placed at the
            window whose top-left corner is (i * stride, j * stride). All other
            entries are zero.
        """
        out_channels, in_channels, kernel_height, kernel_width = self.kernel.shape
        input_height, input_width = torch_input.shape[2], torch_input.shape[3]

        # 6-D layout (out_c, out_h, out_w, in_c, in_h, in_w): every window is a slice.
        kernel_unsqueezed = torch.zeros(
            out_channels, output_height, output_width,
            in_channels, input_height, input_width,
            dtype=torch_input.dtype, device=torch_input.device)
        # Match the input dtype/device so the later matmul is well-typed.
        kernel = self.kernel.to(dtype=torch_input.dtype, device=torch_input.device)

        for row in range(output_height):
            top = row * self.stride
            for col in range(output_width):
                left = col * self.stride
                kernel_unsqueezed[:, row, col, :,
                                  top:top + kernel_height,
                                  left:left + kernel_width] = kernel

        return kernel_unsqueezed.reshape(
            out_channels * output_height * output_width,
            in_channels * input_height * input_width)

    def __call__(self, torch_input):
        """Convolve ``torch_input`` (batch, in_channels, H, W) with ``self.kernel``.

        Returns a tensor (batch, out_channels, out_H, out_W). The output size
        comes from ``calc_out_shape`` using the kernel height as ``kernel_size``,
        so the kernel is assumed square.
        """
        batch_size, out_channels, output_height, output_width = calc_out_shape(
            input_matrix_shape=torch_input.shape,
            out_channels=self.kernel.shape[0],
            kernel_size=self.kernel.shape[2],
            stride=self.stride,
            padding=0)

        kernel_unsqueezed = self._unsqueeze_kernel(torch_input, output_height, output_width)
        # (out_c*out_h*out_w, C*H*W) @ (C*H*W, batch) -> (out_c*out_h*out_w, batch)
        result = kernel_unsqueezed @ torch_input.reshape(batch_size, -1).T
        # reshape (not view): the transposed result is not contiguous.
        return result.T.reshape(batch_size, out_channels, output_height, output_width)

# The check is run automatically by the following code
# (uncomment it to check it yourself;
#  in the code submitted for the assignment it must be commented out):
# test_conv2d_layer returns the result of torch.allclose, so this prints True or False.
print(test_conv2d_layer(Conv2dMatrix))

True
