In [2]:
import numpy as np
from torch import nn
import torch
import pytest


In [29]:
%%file conv_transpose.py

def conv_transpose(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
    """Reference NumPy implementation of a 2-D transposed convolution.

    Mirrors ``torch.nn.functional.conv_transpose2d`` (forward only), including its
    weight layout, output-size formula, channel grouping and dilation.

    Args:
        input: array of shape ``(batch, in_channels, in_height, in_width)``.
        weight: array of shape ``(in_channels, out_channels // groups, kernel_height,
            kernel_width)`` -- the PyTorch ``ConvTranspose2d`` layout (input channels first).
        bias: optional array of shape ``(out_channels,)`` added to every output position.
        stride, padding, output_padding, dilation: an int (used for both spatial axes)
            or a ``(height, width)`` pair.
        groups: number of channel groups; ``in_channels`` must be divisible by it.
            Input-channel group ``g`` only feeds output-channel group ``g``.

    Returns:
        float64 array of shape ``(batch, out_channels, out_height, out_width)`` with,
        per spatial axis,
        ``out = (in - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1``.

    Raises:
        ValueError: on non-4-D inputs, mismatched channel counts, invalid
            hyper-parameters, or a non-positive output size.
    """
    import numpy as np  # local import: the %%file-written module has no top-level imports

    def _pair(value, name):
        # Accept a scalar (applied to both axes) or a (height, width) pair, like PyTorch.
        if np.ndim(value) == 0:
            return int(value), int(value)
        pair = tuple(int(v) for v in value)
        if len(pair) != 2:
            raise ValueError(f"{name} must be an int or a pair, got {value!r}")
        return pair

    stride_h, stride_w = _pair(stride, "stride")
    pad_h, pad_w = _pair(padding, "padding")
    opad_h, opad_w = _pair(output_padding, "output_padding")
    dil_h, dil_w = _pair(dilation, "dilation")
    if min(stride_h, stride_w, dil_h, dil_w) < 1:
        raise ValueError("stride and dilation must be positive")
    if min(pad_h, pad_w, opad_h, opad_w) < 0:
        raise ValueError("padding and output_padding must be non-negative")

    x = np.asarray(input, dtype=np.float64)
    w = np.asarray(weight, dtype=np.float64)
    if x.ndim != 4 or w.ndim != 4:
        raise ValueError("input and weight must both be 4-D arrays")
    batch_size, in_channels, in_height, in_width = x.shape
    weight_in_channels, out_per_group, kernel_height, kernel_width = w.shape
    if groups < 1 or in_channels % groups != 0:
        raise ValueError(f"in_channels={in_channels} is not divisible by groups={groups}")
    if weight_in_channels != in_channels:
        raise ValueError(
            f"weight expects {weight_in_channels} input channels, input has {in_channels}"
        )
    in_per_group = in_channels // groups
    out_channels = out_per_group * groups

    out_height = (in_height - 1) * stride_h - 2 * pad_h + dil_h * (kernel_height - 1) + opad_h + 1
    out_width = (in_width - 1) * stride_w - 2 * pad_w + dil_w * (kernel_width - 1) + opad_w + 1
    if out_height <= 0 or out_width <= 0:
        raise ValueError(f"computed output size ({out_height}, {out_width}) is not positive")

    # Scatter every input pixel onto an uncropped canvas: input (ih, iw) hit by kernel tap
    # (s, t) lands at canvas row ih*stride_h + s*dil_h and column iw*stride_w + t*dil_w.
    # `padding` is cropped off afterwards; `output_padding` only appends rows/cols at
    # the bottom/right, which is why it is added to the canvas size.
    canvas_height = (in_height - 1) * stride_h + dil_h * (kernel_height - 1) + 1 + opad_h
    canvas_width = (in_width - 1) * stride_w + dil_w * (kernel_width - 1) + 1 + opad_w
    canvas = np.zeros((batch_size, out_channels, canvas_height, canvas_width))
    # Extent covered by one tap's strided slice (exactly in_height / in_width elements).
    span_h = (in_height - 1) * stride_h + 1
    span_w = (in_width - 1) * stride_w + 1

    for g in range(groups):
        in_slice = slice(g * in_per_group, (g + 1) * in_per_group)
        out_slice = slice(g * out_per_group, (g + 1) * out_per_group)
        x_g = x[:, in_slice]
        w_g = w[in_slice]
        for s in range(kernel_height):
            row0 = s * dil_h
            for t in range(kernel_width):
                col0 = t * dil_w
                # contribution[b, o, h, w] = sum_k x_g[b, k, h, w] * w_g[k, o, s, t]
                contribution = np.einsum("bkhw,ko->bohw", x_g, w_g[:, :, s, t])
                canvas[:, out_slice,
                       row0:row0 + span_h:stride_h,
                       col0:col0 + span_w:stride_w] += contribution

    output = canvas[:, :, pad_h:pad_h + out_height, pad_w:pad_w + out_width].copy()
    if bias is not None:
        output += np.asarray(bias, dtype=np.float64).reshape(1, out_channels, 1, 1)
    return output


Writing conv_transpose.py


In [30]:
# `%%file` only writes conv_transpose.py to disk; it does not define the function in this
# kernel. Run the module here (equivalent to `%run -i conv_transpose.py`); `-i` executes it
# in the notebook namespace, so names imported in the first cell (e.g. `np`) are visible.
get_ipython().run_line_magic("run", "-i conv_transpose.py")

torch.manual_seed(0)  # reproducible inputs across Restart & Run All

input_data = torch.randn(1, 1, 3, 3)
weight = torch.randn(1, 1, 3, 3)

your_result = conv_transpose(input_data.numpy(), weight.numpy())
torch_result = torch.nn.functional.conv_transpose2d(input_data, weight).numpy()

print(your_result)
print(torch_result)
# Fail loudly on a mismatch instead of relying on eyeballing the two printouts
# (tolerance covers float64 NumPy vs float32 torch accumulation).
np.testing.assert_allclose(your_result, torch_result, rtol=1e-5, atol=1e-5)

[[[[ 0.68636662  0.0527702  -2.08839497  2.40579349 -3.70227385]
   [ 0.34910472  0.14816615 -3.01318112  7.85112751 -2.62486166]
   [ 0.19530126  2.19017723 -1.55486235 -1.82652484 -0.07011274]
   [ 0.51614574 -1.48924741  0.18732784  2.83440238 -1.18785825]
   [-0.19654755  0.56382792  0.83627185 -1.3440052   0.2573778 ]]]]
[[[[ 0.6863666   0.0527702  -2.0883949   2.4057934  -3.7022738 ]
   [ 0.3491047   0.14816618 -3.0131812   7.8511276  -2.6248617 ]
   [ 0.19530126  2.190177   -1.5548624  -1.8265249  -0.07011276]
   [ 0.51614577 -1.4892474   0.18732795  2.8344026  -1.1878582 ]
   [-0.19654755  0.56382793  0.8362719  -1.3440052   0.2573778 ]]]]
