In [31]:
import torch as t
from fancy_einsum import einsum

import utils

In [51]:
def conv1d_minimal(x: t.Tensor, weights: t.Tensor) -> t.Tensor:
    '''Like torch's conv1d using bias=False and all other keyword arguments left at their default values.

    x: shape (batch, in_channels, width)
    weights: shape (out_channels, in_channels, kernel_width)

    Returns: shape (batch, out_channels, output_width), where
        output_width = width - kernel_width + 1 (stride 1, no padding).

    Raises:
        RuntimeError: if x and weights disagree on in_channels, or if
            kernel_width is larger than width (mirrors torch's conv1d,
            which raises RuntimeError in both cases).
    '''
    b, ic, iw = x.shape           # batch, in_channels, input_width
    oc, w_ic, kw = weights.shape  # out_channels, in_channels, kernel_width
    # Without this check a mismatch would make the strided view below use the
    # wrong channel count: silently dropping channels, or reading past them.
    if w_ic != ic:
        raise RuntimeError(
            f"in_channels mismatch: x has {ic} channels but weights expect {w_ic}"
        )
    ow = iw - kw + 1              # output_width
    if ow < 1:
        raise RuntimeError(
            f"kernel_width ({kw}) can't be greater than input width ({iw})"
        )

    # Zero-copy sliding-window view: x_strided[b, c, i, k] == x[b, c, i + k].
    # Advancing one output position and advancing one kernel tap both move one
    # element along the width axis, hence the width stride is used twice.
    # as_strided keeps x's storage offset, so sliced/non-contiguous x works.
    bs, ics, iws = x.stride()     # batch, in_channel, width strides
    x_strided = x.as_strided(size=(b, ic, ow, kw), stride=(bs, ics, iws, iws))

    # Contract over in_channels (c) and kernel taps (k):
    #   out[b, o, w] = sum_{c, k} x[b, c, w + k] * weights[o, c, k]
    return t.einsum('bcwk,ock->bow', x_strided, weights)

# Run the test helper from the local `utils` module (not shown here) against
# conv1d_minimal; what it checks is defined in utils — confirm there.
utils.test_conv1d_minimal(conv1d_minimal)