In [1]:
import torch
from torch.nn import functional as F
import torchvision.transforms as transforms
from torchvision.io import read_image 
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt 
import seaborn as sns
sns.set_theme()

In [568]:
# Scratch experiment: run PyTorch's reference transposed convolution on small
# random integer-valued tensors to inspect the output shape and the filters.
inp = 2  # input height and width
kern = 3  # kernel height and width
stride = 1
dilation = 1
padding = 0
bias = torch.Tensor([1, 2, 3])  # one bias value per output channel

# inputs are (batch=2, channels=1, inp, inp); for conv_transpose2d the weights
# are laid out as (in_channels=1, out_channels=3, kern, kern).
inputs = torch.randint(1, 11, (2, 1, inp, inp)).float()
filters = torch.randint(1, 11, (1, 3, kern, kern)).float()

conv = F.conv_transpose2d(inputs, filters, padding=padding, dilation=dilation, stride=stride, bias=bias)
print(conv.shape)
filters

torch.Size([2, 3, 4, 4])


tensor([[[[ 4.,  2., 10.],
          [ 8.,  2., 10.],
          [10.,  1.,  1.]],

         [[ 3.,  6.,  4.],
          [ 2.,  3.,  1.],
          [ 9.,  6., 10.]],

         [[ 6.,  2., 10.],
          [ 4.,  4.,  5.],
          [ 9.,  4.,  2.]]]])

In [569]:
from torch import Tensor
from math import floor

def dim_window_count(input, weights, stride, padding, dilation, dim):
    """Count the sliding-window positions along one spatial dimension.

    Evaluates ``floor((size + 2*pad - dil*(kernel - 1) - 1) / stride + 1)``,
    where ``size`` and ``kernel`` are read from axis ``2 + dim`` of ``input``
    and ``weights``.

    Args:
        input: Object whose ``shape[2 + dim]`` is the input extent.
        weights: Object whose ``shape[2 + dim]`` is the kernel extent.
        stride: Per-dimension stride, indexed by ``dim``.
        padding: Per-dimension padding, indexed by ``dim``.
        dilation: Per-dimension dilation, indexed by ``dim``.
        dim: Spatial dimension index (0 or 1 for the 2-D callers in this file).

    Returns:
        The number of window positions as an ``int``.
    """
    axis = 2 + dim
    padded_extent = input.shape[axis] + 2 * padding[dim]
    # Distance between the first and last kernel tap once dilation is applied.
    kernel_reach = dilation[dim] * (weights.shape[axis] - 1)
    return floor((padded_extent - kernel_reach - 1) / stride[dim] + 1)

def custom_conv2d(
    input: Tensor,
    weights: Tensor,
    bias: Tensor | None = None,
    stride: tuple | int = 1,
    padding: int | str | tuple[int, int] = 0,
    dilation: tuple | int = 1,
    groups: int = 1,
):
    if len(input.shape) < 3 or len(input.shape) > 4:
        raise ValueError(
            f"Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: {input.shape}"
        )
    if len(input.shape) == 3:
        input = input.unsqueeze(0)
        print(f"unsqueezed input: {input.shape}")

    if len(weights.shape) != 4:
        raise ValueError(
            f"Expected 4D weights with shape [out_channels, in_channels/groups, kH, kW], but got weights with shape: {weights.shape}"
        )

    if bias is not None and bias.shape[0] != weights.shape[0]:
        raise ValueError(
            f"Expected bias shape to be [out_channels], but got bias with shape: {bias.shape}"
        )

    batch_size, input_channels, input_h, input_w = input.shape
    output_channels, kernel_channels, kernel_h, kernel_w = weights.shape

    if input_channels % groups != 0 or output_channels % groups != 0:
        raise ValueError(
            f"Groups should be divisible by both input_channels and output_channels. Got:\ninput channels: {input_channels}\noutput channels: {output_channels}\ngroups: {groups}"
        )
    if kernel_channels != input_channels // groups:
        raise ValueError(
            f"Expected kernel channels to be [input_channels/groups], got:\n\
                input channels: {input_channels}\n\
                groups: {groups}\n\
                output channels: {output_channels}"
        )
    if type(dilation) is int:
        dilation = (dilation, dilation)
    if type(padding) is int:
        padding = (padding, padding)
    if type(stride) is int:
        stride = (stride, stride)
    if type(padding) is int:
        padding = (padding, padding)

    print(padding)

    # input_paded = F.pad(
    #     input, (padding[0], padding[0], padding[1], padding[1]), "constant", 0
    # )

    output_h = dim_window_count(input, weights, stride, padding, dilation, 0)
    output_w = dim_window_count(input, weights, stride, padding, dilation, 1)
    output = torch.zeros(size=(batch_size, output_channels, output_h, output_w))

    windows_per_input_channel = output_h*output_w

    unfolded = F.unfold(input, (kernel_h, kernel_w), dilation, padding, stride)

    for img_num, img in enumerate(input):
        for kernel_num, kernel in enumerate(weights):
            group_num = kernel_num // (output_channels // groups)
            input_channels_start = group_num * (input_channels // groups)
            for input_ch_num in range(
                input_channels_start, input_channels_start + input_channels // groups
            ):
                # for kernel_ch_num, kernel_ch in enumerate(kernel):
                kernel_ch_num = input_ch_num % kernel_channels
                # print(
                #     f"kernel channel {kernel_ch_num} applied to input channel {input_ch_num} to output {kernel_num}"
                # )


                kernel_area = kernel_h * kernel_w
                channel_windows_start = input_ch_num * kernel_area

                folds = unfolded[
                    img_num,
                    channel_windows_start : channel_windows_start + kernel_area,
                ]
                current_input_channel_windows = folds.T.reshape(windows_per_input_channel, kernel_h, kernel_w)
                # print(img[input_ch_num])
                # print(windows)
                product = current_input_channel_windows * kernel[kernel_ch_num]
                # print(f'p: {product[:3]}')
                weighted_sum = torch.sum(product, dim=(1, 2), keepdim=True)
                # print(f'ws: {weighted_sum.shape}')

                weighted_sum = weighted_sum.reshape(output_h, output_w)

                output[img_num, kernel_num] += weighted_sum

    if bias is not None:
        output += bias

    return output

In [587]:
def custom_conv_transpose2d(    
    input: Tensor,
    weights: Tensor,
    bias: Tensor | None = None,
    stride: tuple | int = 1,
    padding: int | str | tuple[int, int] = 0,
) :
    if len(input.shape) < 3 or len(input.shape) > 4:
        raise ValueError(
            f"Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: {input.shape}"
        )
    if len(input.shape) == 3:
        input = input.unsqueeze(0)
        print(f"unsqueezed input: {input.shape}")

    if len(weights.shape) != 4:
        raise ValueError(
            f"Expected 4D weights with shape [out_channels, in_channels/groups, kH, kW], but got weights with shape: {weights.shape}"
        )

    if bias is not None and bias.shape[0] != weights.shape[1]:
        raise ValueError(
            f"Expected bias shape to be [out_channels], but got bias with shape: {bias.shape}"
        )

    batch_size, input_channels, input_h, input_w = input.shape
    input_channels, out_channels, kernel_h, kernel_w = weights.shape

    if type(stride) is int:
        stride = (stride, stride)
    if type(padding) is int:
        padding = (padding, padding)

    z = tuple(map(lambda x: x-1, stride))
    p_ = tuple(map(lambda t: t[0] - t[1] - 1, zip(weights.shape[2:], padding)))

    expanded_h = input_h + (input_h-1) * (z[0])
    expanded_w = input_w + (input_w-1) * (z[1])
    expanded = torch.zeros((batch_size, input_channels, expanded_h, expanded_w))
    expanded[::, ::, ::z[0]+1, ::z[1]+1] = input[::, ::, ::, ::]

    expanded_padded = Tensor(np.pad(expanded, ((0,0), (0,0), (p_[0], p_[0]), (p_[1], p_[1]))))

    print(f'expanded inp: {expanded_padded}')

    result = torch.zeros((batch_size, out_channels, expanded_padded.shape[2]-kernel_h+1, expanded_padded.shape[3]-kernel_w+1))

    for img_num, img in enumerate(result):
        for ch_num, _ in enumerate(img):
            for k_num, _ in enumerate(expanded_padded[img_num:img_num+1]):
                result[img_num:img_num+1, ch_num:ch_num+1] += custom_conv2d(
                    expanded_padded[img_num:img_num+1, k_num:k_num+1],
                    weights[k_num:k_num+1, ch_num:ch_num+1, ::, ::].flip(3,2),
                    stride=1,
                    bias=bias[ch_num:ch_num+1],
                )

    return result

In [567]:
# Set up ipytest before the `%%ipytest` test cell further down is executed.
import ipytest
ipytest.autoconfig()

In [590]:
%%ipytest

def test_1():
    """Stride 1, no padding: 2x2 inputs with a 3x3 kernel must match ``F.conv_transpose2d``."""
    # A locally seeded generator makes the random inputs identical on every
    # run without touching the global RNG state.
    gen = torch.Generator().manual_seed(1)
    stride = 1
    padding = 0
    bias = torch.tensor([1.0])
    inputs = torch.rand((2, 1, 2, 2), generator=gen)
    filters = torch.rand((1, 1, 3, 3), generator=gen)

    conv = F.conv_transpose2d(inputs, filters, stride=stride, padding=padding, bias=bias)
    my_conv = custom_conv_transpose2d(inputs, filters, stride=stride, padding=padding, bias=bias)
    assert torch.allclose(conv, my_conv)

def test_2():
    """Stride 2, no padding: 3x3 inputs with a 5x5 kernel must match ``F.conv_transpose2d``."""
    # A locally seeded generator makes the random inputs identical on every
    # run without touching the global RNG state.
    gen = torch.Generator().manual_seed(2)
    stride = 2
    padding = 0
    bias = torch.tensor([2.0])
    inputs = torch.rand((2, 1, 3, 3), generator=gen)
    filters = torch.rand((1, 1, 5, 5), generator=gen)

    conv = F.conv_transpose2d(inputs, filters, stride=stride, padding=padding, bias=bias)
    my_conv = custom_conv_transpose2d(inputs, filters, stride=stride, padding=padding, bias=bias)
    assert torch.allclose(conv, my_conv)

def test_3():
    """Stride 3, padding 1: 8x8 inputs with a 4x4 kernel must match ``F.conv_transpose2d``."""
    # A locally seeded generator makes the random inputs identical on every
    # run without touching the global RNG state.
    gen = torch.Generator().manual_seed(3)
    stride = 3
    padding = 1
    bias = torch.tensor([4.0])
    inputs = torch.rand((2, 1, 8, 8), generator=gen)
    filters = torch.rand((1, 1, 4, 4), generator=gen)

    conv = F.conv_transpose2d(inputs, filters, stride=stride, padding=padding, bias=bias)
    my_conv = custom_conv_transpose2d(inputs, filters, stride=stride, padding=padding, bias=bias)
    assert torch.allclose(conv, my_conv)

[32m.[0m[32m.[0m[32m.[0m[32m                                                                                          [100%][0m
[32m[32m[1m3 passed[0m[32m in 0.06s[0m[0m
