In [7]:
import torch as t
import w5d1_tests
from fancy_einsum import einsum
from typing import Union
from einops import rearrange
import torch.nn.functional as F

In [8]:
IntOrPair = Union[int, tuple[int, int]]
Pair = tuple[int, int]
def conv1d_minimal(x: t.Tensor, weights: t.Tensor) -> t.Tensor:
    """Like torch's conv1d using bias=False and all other keyword arguments left at their default values.

    x: shape (batch, in_channels, width)
    weights: shape (out_channels, in_channels, kernel_width)

    Returns: shape (batch, out_channels, output_width)
    """
    
    batch, in_channels, width = x.shape
    out_channels, in_channels_2, kernel_width = weights.shape
    assert in_channels == in_channels_2, "in_channels for x and weights don't match up"
    output_width = width - kernel_width + 1
    
    xsB, xsI, xsWi = x.stride()
    wsO, wsI, wsW = weights.stride()
    
    x_new_shape = (batch, in_channels, output_width, kernel_width)
    x_new_stride = (xsB, xsI, xsWi, xsWi)
    # Common error: xsWi is always 1, so if you put 1 here you won't spot your mistake until you try this with conv2d!
    x_strided = x.as_strided(size=x_new_shape, stride=x_new_stride)
    
    return einsum(
        "batch in_channels output_width kernel_width, out_channels in_channels kernel_width -> batch out_channels output_width", 
        x_strided, weights
    )

def conv2d_minimal(x: t.Tensor, weights: t.Tensor) -> t.Tensor:
    '''Like torch's conv2d using bias=False and all other keyword arguments left at their default values.

    x: shape (batch, in_channels, height, width)
    weights: shape (out_channels, in_channels, kernel_height, kernel_width)

    Returns: shape (batch, out_channels, output_height, output_width)
    '''
    xsB, xsI, xsHi, xsWi = x.stride()
    x_new_stride = (xsB, xsI, xsHi, xsWi, xsHi, xsWi)

    batch, in_channels, height, width = x.shape
    out_channels, in_channels_2, kernel_height, kernel_width = weights.shape
    output_width = width - kernel_width + 1
    output_height = height - kernel_height + 1
    
    x_new_shape = (batch, in_channels, output_height, output_width, kernel_height, kernel_width)


    return einsum(
        """batch in_channels output_height output_width kernel_height kernel_width, 
        out_channels in_channels kernel_height kernel_width 
        -> 
        batch out_channels output_height output_width""", 
        x.as_strided(x_new_shape, x_new_stride), weights
    )

def pad1d(x: t.Tensor, left: int, right: int, pad_value: float) -> t.Tensor:
    """Return a new tensor with padding applied to the edges.

    x: shape (batch, in_channels, width), dtype float32

    Return: shape (batch, in_channels, left + right + width)
    """
    B, C, W = x.shape
    output = x.new_full(size=(B, C, left + W + right), fill_value=pad_value)
    output[..., left : left + W] = x
    # Note - you can't use `left:-right`, because `right` could be zero.
    return output
    
def force_pair(v: IntOrPair) -> Pair:
    """Convert v to a pair of int, if it isn't already."""
    if isinstance(v, tuple):
        if len(v) != 2:
            raise ValueError(v)
        return (int(v[0]), int(v[1]))
    elif isinstance(v, int):
        return (v, v)
    raise ValueError(v)

def pad2d(x: t.Tensor, left: int, right: int, top: int, bottom: int, pad_value: float) -> t.Tensor:
    """Return a new tensor with padding applied to the edges.

    x: shape (batch, in_channels, height, width), dtype float32

    Return: shape (batch, in_channels, top + height + bottom, left + width + right)
    """
    B, C, H, W = x.shape
    output = x.new_full(size=(B, C, top + H + bottom, left + W + right), fill_value=pad_value)
    output[..., top : top + H, left : left + W] = x
    return output

def conv2d(x, weights, stride: IntOrPair = 1, padding: IntOrPair = 0) -> t.Tensor:
    """Like torch's conv2d using bias=False

    x: shape (batch, in_channels, height, width)
    weights: shape (out_channels, in_channels, kernel_height, kernel_width)


    Returns: shape (batch, out_channels, output_height, output_width)
    """

    stride_h, stride_w = force_pair(stride)
    padding_h, padding_w = force_pair(padding)
    
    x_padded = pad2d(x, left=padding_w, right=padding_w, top=padding_h, bottom=padding_h, pad_value=0)
    
    batch, in_channels, height, width = x_padded.shape
    out_channels, in_channels_2, kernel_height, kernel_width = weights.shape
    assert in_channels == in_channels_2, "in_channels for x and weights don't match up"
    output_width = 1 + (width - kernel_width) // stride_w
    output_height = 1 + (height - kernel_height) // stride_h
    
    xsB, xsIC, xsH, xsW = x_padded.stride() # B for batch, IC for input channels, H for height, W for width
    wsOC, wsIC, wsH, wsW = weights.stride()
    
    x_new_shape = (batch, in_channels, output_height, output_width, kernel_height, kernel_width)
    x_new_stride = (xsB, xsIC, xsH * stride_h, xsW * stride_w, xsH, xsW)
    
    x_strided = x_padded.as_strided(size=x_new_shape, stride=x_new_stride)
    
    return einsum("B IC OH OW wH wW, OC IC wH wW -> B OC OH OW", x_strided, weights)


In [9]:
def test_conv1d_minimal(conv1d_minimal, n_tests=20):
    import numpy as np
    for _ in range(n_tests):
        b = np.random.randint(1, 10)
        h = np.random.randint(10, 30)
        ci = np.random.randint(1, 5)
        co = np.random.randint(1, 5)
        kernel_size = np.random.randint(1, 10)
        x = t.randn((b, ci, h))
        weights = t.randn((co, ci, kernel_size))
        my_output = conv1d_minimal(x, weights)
        torch_output = t.conv1d(x, weights, stride=1, padding=0)
        t.testing.assert_close(my_output, torch_output)
    print("All tests in `test_conv1d_minimal` passed!")

test_conv1d_minimal(conv1d_minimal)

All tests in `test_conv1d_minimal` passed!


In [10]:
def conv_transpose1d_minimal(x: t.Tensor, weights: t.Tensor) -> t.Tensor:
    '''Like torch's conv_transpose1d using bias=False and all other keyword arguments left at their default values.

    x: shape (batch, in_channels, width)
    weights: shape (in_channels, out_channels, kernel_width)

    Returns: shape (batch, out_channels, output_width)
    '''

    """
    x: shape (batch, in_channels, width)
    weights: shape (out_channels, in_channels, kernel_width)
    """
    _, _, width = x.shape
    _, _, kernel_width = weights.shape
    # add padding
    padlen = kernel_width - 1
    new_width = width + 2 * padlen
    # reverse weights
    reversed_weights = weights.flip(-1)

    import torch.nn.functional as F
    # now we expand to size (7, 11) by appending a row of 0s at pos 0 and pos 6, 
    # and a column of 0s at pos 10
    padded_x = F.pad(input=x, pad=(padlen, padlen, 0, 0, 0, 0), mode='constant', value=0)
    reversed_weights = rearrange(reversed_weights, 'in out w -> out in w')
    return conv1d_minimal(padded_x, reversed_weights)


# w5d1_tests.test_conv_transpose1d_minimal(conv_transpose1d_minimal)
testtens = t.tensor([[[1,2,3]]])
testweights = t.tensor([[[4,5]]])
conv_transpose1d_minimal(testtens, testweights)

tensor([[[ 4, 13, 22, 15]]])

In [11]:
def fractional_stride_1d(x, stride: int = 1):
    '''Returns a version of x suitable for transposed convolutions, i.e. "spaced out" with zeros between its values.
    This spacing only happens along the last dimension.

    x: shape (batch, in_channels, width)

    Example: 
        x = [[[1, 2, 3], [4, 5, 6]]]
        stride = 2
        output = [[[1, 0, 2, 0, 3], [4, 0, 5, 0, 6]]]
    '''
    if stride == 1:
        return x 
    shape = list(x.shape)
    print(stride)
    print(shape)
    stride = stride - 1
    width = x.shape[2]
    new_width = width + width * (stride) - stride
    new_shape = [shape[0], shape[1], new_width]
    print(new_shape)
    new_x = t.zeros(new_shape, dtype=x.dtype)
    print(t.arange(0, new_shape[2], stride+1))
    new_x[:,:,t.arange(0, new_shape[2], stride+1)] = x    
    return new_x

x = t.tensor([[[1, 2, 3], [4, 5, 6]]])
# fractional_stride_1d(x, 2)
print(x[:,:,::1])
print(t.arange(0, 5, 2))

w5d1_tests.test_fractional_stride_1d(fractional_stride_1d)

tensor([[[1, 2, 3],
         [4, 5, 6]]])
tensor([0, 2, 4])
2
[1, 2, 3]
[1, 2, 5]
tensor([0, 2, 4])
All tests in `test_fractional_stride_1d` passed!


In [12]:
# def conv_transpose1d_minimal(x: t.Tensor, weights: t.Tensor) -> t.Tensor:
#     '''Like torch's conv_transpose1d using bias=False and all other keyword arguments left at their default values.

#     x: shape (batch, in_channels, width)
#     weights: shape (in_channels, out_channels, kernel_width)

#     Returns: shape (batch, out_channels, output_width)
#     '''

#     """
#     x: shape (batch, in_channels, width)
#     weights: shape (out_channels, in_channels, kernel_width)
#     """
#     _, _, width = x.shape
#     _, _, kernel_width = weights.shape
#     # add padding
#     padlen = kernel_width - 1
#     new_width = width + 2 * padlen
#     # reverse weights
#     reversed_weights = weights.flip(-1)

#     import torch.nn.functional as F
#     # now we expand to size (7, 11) by appending a row of 0s at pos 0 and pos 6, 
#     # and a column of 0s at pos 10
#     padded_x = F.pad(input=x, pad=(padlen, padlen, 0, 0, 0, 0), mode='constant', value=0)
#     reversed_weights = rearrange(reversed_weights, 'in out w -> out in w')
#     return conv1d_minimal(padded_x, reversed_weights)


def conv_transpose1d(x, weights, stride: int = 1, padding: int = 0) -> t.Tensor:
    '''Like torch's conv_transpose1d using bias=False and all other keyword arguments left at their default values.

    x: shape (batch, in_channels, width)
    weights: shape (in_channels, out_channels, kernel_width)

    Returns: shape (batch, out_channels, output_width)
    '''
    _, _, width = x.shape
    _, _, kernel_width = weights.shape
    # add padding
    padlen = kernel_width - 1 - padding
    x_frac = fractional_stride_1d(x, stride)
    padded_x = F.pad(input=x_frac, pad=(padlen, padlen, 0, 0, 0, 0), mode='constant', value=0)
    reversed_weights = weights.flip(-1)
    reversed_weights = rearrange(reversed_weights, 'in out w -> out in w')
    
    return conv1d_minimal(padded_x, reversed_weights)
    
    

w5d1_tests.test_conv_transpose1d(conv_transpose1d)

3
[4, 5, 41]
[4, 5, 121]
tensor([  0,   3,   6,   9,  12,  15,  18,  21,  24,  27,  30,  33,  36,  39,
         42,  45,  48,  51,  54,  57,  60,  63,  66,  69,  72,  75,  78,  81,
         84,  87,  90,  93,  96,  99, 102, 105, 108, 111, 114, 117, 120])
2
[6, 6, 41]
[6, 6, 81]
tensor([ 0,  2,  4,  6,  8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34,
        36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62, 64, 66, 68, 70,
        72, 74, 76, 78, 80])
3
[2, 11, 36]
[2, 11, 106]
tensor([  0,   3,   6,   9,  12,  15,  18,  21,  24,  27,  30,  33,  36,  39,
         42,  45,  48,  51,  54,  57,  60,  63,  66,  69,  72,  75,  78,  81,
         84,  87,  90,  93,  96,  99, 102, 105])
2
[2, 8, 37]
[2, 8, 73]
tensor([ 0,  2,  4,  6,  8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34,
        36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62, 64, 66, 68, 70,
        72])
2
[3, 4, 43]
[3, 4, 85]
tensor([ 0,  2,  4,  6,  8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34,