The performance gap is huge: https://twitter.com/bwasti/status/1449922665360338945

In [1]:
import torch.nn.functional as f
import torch as th

In [393]:
help(f.pad)

Help on function _pad in module torch.nn.functional:

_pad(input: torch.Tensor, pad: List[int], mode: str = 'constant', value: float = 0) -> torch.Tensor
    Pads tensor.
    
    Padding size:
        The padding size by which to pad some dimensions of :attr:`input`
        are described starting from the last dimension and moving forward.
        :math:`\left\lfloor\frac{\text{len(pad)}}{2}\right\rfloor` dimensions
        of ``input`` will be padded.
        For example, to pad only the last dimension of the input tensor, then
        :attr:`pad` has the form
        :math:`(\text{padding\_left}, \text{padding\_right})`;
        to pad the last 2 dimensions of the input tensor, then use
        :math:`(\text{padding\_left}, \text{padding\_right},`
        :math:`\text{padding\_top}, \text{padding\_bottom})`;
        to pad the last 3 dimensions, use
        :math:`(\text{padding\_left}, \text{padding\_right},`
        :math:`\text{padding\_top}, \text{padding\_bottom}`
        :math

In [401]:
def buck_pad(inp, pad: int):
    inp_height, inp_width = inp.shape
    return th.cat([th.zeros([pad, inp_width + pad * 2]),
                  th.cat([th.zeros([inp_height, pad]), inp, th.zeros([inp_height, pad])], dim=1),
                  th.zeros([pad, inp_width + pad * 2])], dim=0)

In [402]:
buck_pad(th.ones(4, 4), 2)

tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 1., 1., 1., 0., 0.],
        [0., 0., 1., 1., 1., 1., 0., 0.],
        [0., 0., 1., 1., 1., 1., 0., 0.],
        [0., 0., 1., 1., 1., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.]])

In [396]:
th.cat

<function _VariableFunctionsClass.cat>

In [2]:
f.conv1d

<function _VariableFunctionsClass.conv1d>

In [360]:
B = 5
C_in = 4
C_out = 4
LEN = 10
WINDOW_SIZE = 3

In [361]:
x = th.randn(B, C_in, LEN)

In [362]:
w = th.randn(C_out, C_in, WINDOW_SIZE)

In [363]:
x.stride()

(40, 10, 1)

In [364]:
def buck_conv1d(inp, weight, stride=1, dilation=1, groups=1):
    B, C_in, LEN = inp.shape
    C_out, C_in2, WINDOW_SIZE = weight.shape
    assert C_in == C_in2 * groups, (C_in, C_in2 * groups)
    size = [B, groups, C_in2, (LEN - ((WINDOW_SIZE - 1) * dilation)) // stride, WINDOW_SIZE]
    stride = [LEN * C_in, LEN * C_in2, LEN, stride, dilation]
    strided_x_dilated = inp.as_strided(size=size, stride=stride)
    
    return th.einsum('bgcnw,gdcw->bgdn', strided_x_dilated, weight.reshape([groups, C_out//groups, C_in2, WINDOW_SIZE])
                    ).reshape([B, C_out, size[3]])

In [365]:
groups_W.shape

torch.Size([4, 2, 5])

In [366]:
assert th.allclose(f.conv1d(x, w), buck_conv1d(x, w))
assert th.allclose(f.conv1d(x, w, stride=2), buck_conv1d(x, w, stride=2))
assert th.allclose(f.conv1d(x, w, dilation=2), buck_conv1d(x, w, dilation=2))
assert th.allclose(f.conv1d(x, w, stride=2, dilation=2), buck_conv1d(x, w, stride=2, dilation=2))
groups_W = th.randn(C_out, C_in // 2, WINDOW_SIZE)
assert th.allclose(f.conv1d(x, groups_W, groups=2),
                  buck_conv1d(x, groups_W, groups=2))


## `conv2d`

In [367]:
th.arange(12).reshape(3, 4)

tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])

In [368]:
th.arange(12).reshape(3, 4).as_strided([2, 3, 2, 2], [4, 1, 4, 1])

tensor([[[[ 0,  1],
          [ 4,  5]],

         [[ 1,  2],
          [ 5,  6]],

         [[ 2,  3],
          [ 6,  7]]],


        [[[ 4,  5],
          [ 8,  9]],

         [[ 5,  6],
          [ 9, 10]],

         [[ 6,  7],
          [10, 11]]]])

In [369]:
B = 2
C_in = 4
C_out = 4
H = 10
W = 1
kH = 10
kW = 1

x_2d = th.randn(B, C_in, H, W)
w_2d = th.randn(C_out, C_in, kH, kW)

In [370]:
f.conv2d(x_2d, w_2d).shape

torch.Size([2, 4, 1, 1])

In [371]:
initial_stride = x_2d.stride()
initial_stride

(40, 10, 1, 1)

In [372]:
x_2d.shape

torch.Size([2, 4, 10, 1])

In [373]:
strided_x = x_2d.as_strided([B, C_in, H - kH + 1, W - kW + 1, kH, kW], list(initial_stride) + [W, 1])

In [374]:
buck_conv_result = th.einsum('bcxyij,dcij->bdxy', strided_x, w_2d)

In [375]:
buck_conv_result.shape

torch.Size([2, 4, 1, 1])

In [390]:
def buck_conv2d(inp, weight, stride=1):
    B, C_in, H, W = inp.shape
    C_out, C_in2, kH, kW = weight.shape
    dilation = 1
    size = [B, C_in, (H - ((kH - 1) * dilation)) // stride, (W - ((kW - 1) * dilation)) // stride, kH, kW]

    stride = list(inp.stride()) + [W * stride, 1 * stride]
    strided_x = inp.as_strided(size=size, stride=stride)
    
    return th.einsum('bcxyij,dcij->bdxy', strided_x, weight)

In [391]:
assert th.allclose(buck_conv2d(x_2d, w_2d), f.conv2d(x_2d, w_2d))
assert th.allclose(buck_conv2d(x_2d, w_2d, stride=2), f.conv2d(x_2d, w_2d, stride=2))
# TODO: dilation, groups

Debugging tips:

- Are there any numbers you can set to one?
- Are there any special cases of conv2d that are closely related to something you've already done?