In [1]:
import torch as t
from fancy_einsum import einsum
from typing import Union, Tuple, Optional

import utils

### Exercise 1

In [75]:
def conv1d_minimal(x: t.Tensor, weights: t.Tensor) -> t.Tensor:
    '''Like torch's conv1d using bias=False and all other keyword arguments left at their default values.

    x: shape (batch, in_channels, width)
    weights: shape (out_channels, in_channels, kernel_width)

    Returns: shape (batch, out_channels, output_width), where output_width = width - kernel_width + 1

    Raises: ValueError if x and weights disagree on in_channels.
    '''
    b, ic, iw = x.shape            # batch, in_channels, input_width
    oc, w_ic, kw = weights.shape   # out_channels, in_channels (according to weights), kernel_width
    if w_ic != ic:
        # The strided view below must be sized by x's own channel count. If the weights'
        # count were used instead, too few would silently drop channels of x and too many
        # would make the view reach past x's channels (other batch items / out of bounds).
        raise ValueError(f"weights expect {w_ic} input channels, but x has {ic}")
    ow = iw - kw + 1               # output_width

    bs, ics, iws = x.stride()      # batch_stride, input_channel_stride, input_width_stride
    # x_strided[b, c, o, k] == x[b, c, o + k]: a copy-free view of every kernel-sized window.
    x_strided = x.as_strided(size=(b, ic, ow, kw), stride=(bs, ics, iws, iws))

    # Same contraction as 'b ic ow kw, oc ic kw -> b oc ow', written with single letters:
    # b = batch, c = in_channels, w = output_width, k = kernel_width, o = out_channels.
    return t.einsum('bcwk,ock->bow', x_strided, weights)

utils.test_conv1d_minimal(conv1d_minimal)

### Exercise 2

In [20]:
def conv2d_minimal(x: t.Tensor, weights: t.Tensor) -> t.Tensor:
    '''Like torch's conv2d using bias=False and all other keyword arguments left at their default values.

    x: shape (batch, in_channels, height, width)
    weights: shape (out_channels, in_channels, kernel_height, kernel_width)

    Returns: shape (batch, out_channels, output_height, output_width), where
             output_height = height - kernel_height + 1 and output_width = width - kernel_width + 1

    Raises: ValueError if x and weights disagree on in_channels.
    '''
    b, ic, ih, iw = x.shape            # batch, in_channels, input_height, input_width
    oc, w_ic, kh, kw = weights.shape   # out_channels, in_channels (according to weights), kernel_height, kernel_width
    if w_ic != ic:
        # The strided view below must be sized by x's own channel count; using the weights'
        # count would silently drop channels of x or read past them.
        raise ValueError(f"weights expect {w_ic} input channels, but x has {ic}")
    oh = ih - kh + 1                   # output_height
    ow = iw - kw + 1                   # output_width

    bs, ics, ihs, iws = x.stride()     # batch_stride, input_channel_stride, input_height_stride, input_width_stride
    # x_strided[b, c, h, w, i, j] == x[b, c, h + i, w + j]: a copy-free view of every window.
    x_strided = x.as_strided(size=(b, ic, oh, ow, kh, kw), stride=(bs, ics, ihs, iws, ihs, iws))

    # Same contraction as 'b ic oh ow kh kw, oc ic kh kw -> b oc oh ow', written with single letters:
    # c = in_channels, h/w = output position, i/j = kernel offset, o = out_channels.
    return t.einsum('bchwij,ocij->bohw', x_strided, weights)

utils.test_conv2d_minimal(conv2d_minimal)

### Exercise 3

In [43]:
def pad1d(x: t.Tensor, left: int, right: int, pad_value: float) -> t.Tensor:
    '''Return a new tensor with padding applied to the edges.

    A fresh tensor filled with pad_value is allocated via x.new_full and x is copied
    into its middle; x itself is left untouched.

    x: shape (batch, in_channels, width), dtype float32
    left, right: number of pad_value entries placed before / after the width axis

    Return: shape (batch, in_channels, left + right + width)
    '''
    batch, channels, width = x.shape
    padded_width = left + width + right

    result = x.new_full((batch, channels, padded_width), pad_value)
    interior = slice(left, left + width)
    result[..., interior] = x
    return result


utils.test_pad1d(pad1d)
utils.test_pad1d_multi_channel(pad1d)

In [51]:
def pad2d(x: t.Tensor, left: int, right: int, top: int, bottom: int, pad_value: float) -> t.Tensor:
    '''Return a new tensor with padding applied to the edges.

    A fresh tensor filled with pad_value is allocated via x.new_full and x is copied
    into the rectangle starting at row `top`, column `left`; x itself is left untouched.

    x: shape (batch, in_channels, height, width), dtype float32
    left, right: pad_value columns added before / after the width axis
    top, bottom: pad_value rows added before / after the height axis

    Return: shape (batch, in_channels, top + height + bottom, left + width + right)
    '''
    batch, channels, height, width = x.shape
    out_shape = (batch, channels, top + height + bottom, left + width + right)

    result = x.new_full(out_shape, pad_value)
    rows = slice(top, top + height)
    cols = slice(left, left + width)
    result[..., rows, cols] = x
    return result

utils.test_pad2d(pad2d)
utils.test_pad2d_multi_channel(pad2d)

### Exercise 4

In [92]:
def conv1d(x: t.Tensor, weights: t.Tensor, stride: int = 1, padding: int = 0) -> t.Tensor:
    '''Like torch's conv1d using bias=False.

    x: shape (batch, in_channels, width)
    weights: shape (out_channels, in_channels, kernel_width)
    stride: step between consecutive kernel positions along the width
    padding: number of zeros added to both ends of the width before convolving

    Returns: shape (batch, out_channels, output_width), where
             output_width = (width + 2 * padding - kernel_width) // stride + 1

    Raises: ValueError if x and weights disagree on in_channels.
    '''
    x = pad1d(x, padding, padding, 0)

    b, ic, iw = x.shape             # batch, in_channels, padded input_width
    oc, w_ic, kw = weights.shape    # out_channels, in_channels (according to weights), kernel_width
    if w_ic != ic:
        # The strided view must be sized by x's own channel count; using the weights'
        # count would silently drop channels of x or read past them.
        raise ValueError(f"weights expect {w_ic} input channels, but x has {ic}")
    ow = (iw - kw) // stride + 1    # output_width

    bs, ics, iws = x.stride()       # batch_stride, input_channel_stride, input_width_stride
    # Window o starts `stride` elements further along the width than window o - 1, so the
    # output axis must advance iws * stride storage slots (the bare `stride` is only right
    # when iws == 1). Within a window the kernel axis advances one element (iws).
    x_strided = x.as_strided(size=(b, ic, ow, kw), stride=(bs, ics, iws * stride, iws))
    return einsum('b ic ow kw, oc ic kw -> b oc ow', x_strided, weights)

utils.test_conv1d(conv1d)

In [97]:
Pair = Tuple[int, int]
IntOrPair = Union[int, Pair]

def force_pair(v: IntOrPair) -> Pair:
    '''Normalize an int-or-pair argument (such as a stride or padding) to a pair of ints.

    An int n becomes (n, n); a 2-tuple has each element passed through int().
    Anything else, including tuples of any other length, raises ValueError(v).
    '''
    if isinstance(v, int):
        return (v, v)
    if not isinstance(v, tuple) or len(v) != 2:
        raise ValueError(v)
    return int(v[0]), int(v[1])

# Examples of how this function can be used:
#       force_pair((1, 2))     ->  (1, 2)
#       force_pair(2)          ->  (2, 2)
#       force_pair((1, 2, 3))  ->  ValueError

In [102]:
def conv2d(x: t.Tensor, weights: t.Tensor, stride: IntOrPair = 1, padding: IntOrPair = 0) -> t.Tensor:
    '''Like torch's conv2d using bias=False

    x: shape (batch, in_channels, height, width)
    weights: shape (out_channels, in_channels, kernel_height, kernel_width)
    stride: step between kernel positions, an int or a (height, width) pair
    padding: zeros added to each side, an int or a (height, width) pair

    Returns: shape (batch, out_channels, output_height, output_width)

    Raises: ValueError if x and weights disagree on in_channels, or (via force_pair)
            if stride / padding is neither an int nor a 2-tuple.
    '''
    padding_h, padding_w = force_pair(padding)
    stride_h, stride_w = force_pair(stride)

    x = pad2d(x, padding_w, padding_w, padding_h, padding_h, 0)

    b, ic, ih, iw = x.shape             # batch, in_channels, padded input_height, padded input_width
    oc, w_ic, kh, kw = weights.shape    # out_channels, in_channels (according to weights), kernel_height, kernel_width
    if w_ic != ic:
        # The strided view must be sized by x's own channel count; using the weights'
        # count would silently drop channels of x or read past them.
        raise ValueError(f"weights expect {w_ic} input channels, but x has {ic}")
    oh = (ih - kh) // stride_h + 1      # output_height
    ow = (iw - kw) // stride_w + 1      # output_width

    bs, ics, ihs, iws = x.stride()      # batch_stride, input_channel_stride, input_height_stride, input_width_stride
    # Output positions step by the conv stride (scaled by the storage stride);
    # kernel offsets step one element at a time.
    x_strided = x.as_strided(
        size=(b, ic, oh, ow, kh, kw),
        stride=(bs, ics, ihs * stride_h, iws * stride_w, ihs, iws)
    )

    return einsum('b ic oh ow kh kw, oc ic kh kw -> b oc oh ow', x_strided, weights)

utils.test_conv2d(conv2d)

### Exercise 5

In [159]:
def maxpool2d(x: t.Tensor, kernel_size: IntOrPair, stride: Optional[IntOrPair] = None, padding: IntOrPair = 0
) -> t.Tensor:
    '''Like PyTorch's maxpool2d.

    Pads x with -inf (so padded cells can never be the maximum), takes a copy-free
    strided view of every kernel-sized window, and reduces each window to its max.

    x: shape (batch, channels, height, width)
    kernel_size: window size, an int or a (height, width) pair
    stride: if None, should be equal to the kernel size
    padding: -inf rows/columns added to each side, an int or a (height, width) pair

    Return: shape (batch, channels, output_height, output_width)
    '''
    kernel_h, kernel_w = force_pair(kernel_size)
    pad_h, pad_w = force_pair(padding)
    step_h, step_w = force_pair(kernel_size if stride is None else stride)

    padded = pad2d(x, pad_w, pad_w, pad_h, pad_h, -t.inf)

    batch, channels, height, width = padded.shape
    out_h = 1 + (height - kernel_h) // step_h
    out_w = 1 + (width - kernel_w) // step_w

    sb, sc, sh, sw = padded.stride()
    windows = padded.as_strided(
        size=(batch, channels, out_h, out_w, kernel_h, kernel_w),
        stride=(sb, sc, sh * step_h, sw * step_w, sh, sw),
    )

    # Collapse the two kernel axes, leaving one maximum per window.
    return windows.amax(dim=(-1, -2))

utils.test_maxpool2d(maxpool2d)