In [1]:
import torch as t
import torch.nn as nn

from einops import rearrange
from typing import Union, Tuple
import utils

In [2]:
def conv_transpose1d_minimal(x: t.Tensor, weights: t.Tensor) -> t.Tensor:
    '''Transposed 1D convolution with stride 1, no padding and no bias, i.e. torch's
    conv_transpose1d(x, weights) with every keyword argument at its default.

    Implemented as a plain conv1d: zero-pad the input by (kernel_width - 1) on both sides,
    swap the kernel's channel axes and reverse its taps.

    x: shape (batch, in_channels, width)
    weights: shape (in_channels, out_channels, kernel_width)

    Returns: shape (batch, out_channels, width + kernel_width - 1)
    '''
    halo = weights.shape[-1] - 1
    padded_input = t.nn.functional.pad(x, (halo, halo))

    # (in, out, k) -> (out, in, k) as conv1d expects, then reverse along the kernel axis.
    conv_kernel = weights.transpose(0, 1).flip(-1)

    return t.nn.functional.conv1d(padded_input, conv_kernel)

# Check conv_transpose1d_minimal with the project-local utils test helper.
utils.test_conv_transpose1d_minimal(conv_transpose1d_minimal)

All tests in `test_conv1d_minimal` passed!


In [3]:
def fractional_stride_1d(x, stride: int = 1):
    '''Returns a version of x suitable for transposed convolutions, i.e. "spaced out" with zeros
    between its values.
    This spacing only happens along the last dimension.

    x: shape (batch, in_channels, width) -- any leading dims are preserved
    stride: positive int; stride - 1 zeros are inserted between neighbouring values

    Returns: same dtype/device as x, last dim of size (width - 1) * stride + 1
        (or 0 when width == 0).

    Raises: ValueError if stride < 1.

    Example:
        x = [[[1, 2, 3], [4, 5, 6]]]
        stride = 2
        output = [[[1, 0, 2, 0, 3], [4, 0, 5, 0, 6]]]
    '''
    if stride < 1:
        raise ValueError(f"stride must be a positive integer, got {stride}")

    width = x.shape[-1]
    # width values separated by (width - 1) gaps of (stride - 1) zeros. Guard width == 0,
    # where the formula would otherwise yield a negative size.
    new_width = (width - 1) * stride + 1 if width > 0 else 0

    x_fstrided = x.new_zeros((*x.shape[:-1], new_width))
    x_fstrided[..., ::stride] = x
    return x_fstrided

# Check fractional_stride_1d with the project-local utils test helper.
utils.test_fractional_stride_1d(fractional_stride_1d)

All tests in `test_fractional_stride_1d` passed!


In [4]:

def conv_transpose1d(x, weights, stride: int = 1, padding: int = 0) -> t.Tensor:
    '''Like torch's conv_transpose1d using bias=False and all other keyword arguments left at their
    default values.

    Implemented as an ordinary convolution:
    1. Space the input out by `stride` (see fractional_stride_1d).
    2. Pad it by (kernel_width - 1 - padding) on each side.
    3. Run conv1d with the kernel's channel axes swapped and the kernel flipped spatially.

    x: shape (batch, in_channels, width)
    weights: shape (in_channels, out_channels, kernel_width) -- torch's conv_transpose layout
    stride: positive int, spacing between consecutive input samples in the output
    padding: int, trims `padding` positions from each end of the full transposed-conv output

    Returns: shape (batch, out_channels, output_width), where
        output_width = (width - 1) * stride - 2 * padding + kernel_width
    '''
    kernel_width = weights.shape[-1]
    # Negative when padding > kernel_width - 1; F.pad then crops instead of padding,
    # which is exactly the trimming conv_transpose's `padding` requires.
    pad_size = kernel_width - 1 - padding

    x = fractional_stride_1d(x, stride)
    x = t.nn.functional.pad(x, (pad_size, pad_size))

    # conv_transpose weights are (in, out, k); conv1d wants (out, in, k) with taps reversed.
    weights = rearrange(weights, "i o w -> o i w")
    weights = t.flip(weights, (-1,))

    return t.nn.functional.conv1d(x, weights)

# Check conv_transpose1d with the project-local utils test helper.
utils.test_conv_transpose1d(conv_transpose1d)

All tests in `test_conv_transpose1d` passed!


In [5]:
# Hyperparameters such as stride/padding may be given as a single int or as an explicit pair.
IntOrPair = Union[int, Tuple[int, int]]
Pair = Tuple[int, int]


def force_pair(v: IntOrPair) -> Pair:
    '''Normalise v to a (first, second) tuple of ints.

    An int is duplicated into (v, v). A 2-tuple has each element converted with int().
    Anything else -- wrong-length tuples, lists, floats, strings -- raises ValueError(v).
    '''
    if isinstance(v, int):
        return (v, v)
    if isinstance(v, tuple) and len(v) == 2:
        first, second = v
        return (int(first), int(second))
    raise ValueError(v)

In [6]:
def fractional_stride_2d(x, stride_h: int, stride_w: int):
    '''
    Same as fractional_stride_1d, except we apply it along the last 2 dims of x (height, then
    width): stride_h - 1 zero rows are inserted between neighbouring rows and stride_w - 1 zero
    columns between neighbouring columns.

    x: shape (..., height, width)
    stride_h, stride_w: positive ints

    Returns: same dtype/device as x, shape (..., (height - 1) * stride_h + 1,
        (width - 1) * stride_w + 1). An empty spatial dim stays empty.

    Raises: ValueError if either stride is < 1.
    '''
    if stride_h < 1 or stride_w < 1:
        raise ValueError(f"strides must be positive integers, got ({stride_h}, {stride_w})")

    height, width = x.shape[-2], x.shape[-1]
    # Guard empty dims, where (size - 1) * stride + 1 would be negative or wrongly 1.
    new_height = (height - 1) * stride_h + 1 if height > 0 else 0
    new_width = (width - 1) * stride_w + 1 if width > 0 else 0

    x_fstrided = x.new_zeros((*x.shape[:-2], new_height, new_width))
    x_fstrided[..., ::stride_h, ::stride_w] = x
    return x_fstrided

In [7]:
def conv_transpose2d(x, weights, stride: IntOrPair = 1, padding: IntOrPair = 0) -> t.Tensor:
    '''Like torch's conv_transpose2d using bias=False.

    Same construction as conv_transpose1d, applied along both spatial dims:
    1. Space out the input by the stride.
    2. Pad each side by (kernel - 1 - padding); a negative amount crops.
    3. Run a plain conv2d with the kernel's channel axes swapped and the kernel flipped
       along height and width.

    x: shape (batch, in_channels, height, width)
    weights: shape (in_channels, out_channels, kernel_height, kernel_width) -- torch's
        conv_transpose layout
    stride: int or (stride_h, stride_w)
    padding: int or (padding_h, padding_w)

    Returns: shape (batch, out_channels, output_height, output_width), where
        output_height = (height - 1) * stride_h - 2 * padding_h + kernel_height
        output_width  = (width - 1) * stride_w - 2 * padding_w + kernel_width
    '''
    stride_h, stride_w = force_pair(stride)
    padding_h, padding_w = force_pair(padding)

    kernel_height, kernel_width = weights.shape[-2], weights.shape[-1]
    pad_size_h = kernel_height - 1 - padding_h
    pad_size_w = kernel_width - 1 - padding_w

    x = fractional_stride_2d(x, stride_h, stride_w)
    # F.pad takes the last dim first: (left, right, top, bottom).
    x = t.nn.functional.pad(x, (pad_size_w, pad_size_w, pad_size_h, pad_size_h))

    # (in, out, kh, kw) -> (out, in, kh, kw) for conv2d, reversed along both spatial axes.
    weights = rearrange(weights, "i o h w -> o i h w")
    weights = t.flip(weights, (-2, -1))

    return t.nn.functional.conv2d(x, weights)

# Check conv_transpose2d with the project-local utils test helper.
utils.test_conv_transpose2d(conv_transpose2d)

All tests in `test_conv_transpose2d` passed!


In [8]:
class ConvTranspose2d(nn.Module):
    '''Transposed 2D convolution layer without bias, built on conv_transpose2d.'''

    def __init__(self,
        in_channels: int,
        out_channels: int,
        kernel_size: IntOrPair,
        stride: IntOrPair = 1,
        padding: IntOrPair = 0
    ):
        '''
        Same as torch.nn.ConvTranspose2d with bias=False.

        Name your weight field `self.weight` for compatibility with the tests.

        The weight has shape (in_channels, out_channels, kernel_h, kernel_w). It is initialised
        uniformly in [-1/sqrt(fan), 1/sqrt(fan)] with fan = out_channels * kernel_h * kernel_w,
        the same bound torch's default init yields for transposed convolutions.
        '''
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = force_pair(kernel_size)
        self.stride = force_pair(stride)
        self.padding = force_pair(padding)

        kernel_h, kernel_w = self.kernel_size
        # Plain Python float bound: uniform_ takes numbers, not 0-dim tensors.
        bound = (out_channels * kernel_h * kernel_w) ** -0.5
        weight = t.empty(in_channels, out_channels, kernel_h, kernel_w).uniform_(-bound, bound)
        self.weight = nn.Parameter(weight)

    def forward(self, x: t.Tensor) -> t.Tensor:
        '''x: (batch, in_channels, height, width) -> (batch, out_channels, out_h, out_w).'''
        return conv_transpose2d(x, self.weight, stride=self.stride, padding=self.padding)

    def extra_repr(self) -> str:
        return (
            f"{self.in_channels}, {self.out_channels}, kernel_size={self.kernel_size}, "
            f"stride={self.stride}, padding={self.padding}"
        )

# Check the ConvTranspose2d module with the project-local utils test helper.
utils.test_ConvTranspose2d(ConvTranspose2d)

All tests in `test_ConvTranspose2d` passed!


In [9]:
class Tanh(nn.Module):
    '''Elementwise hyperbolic tangent, evaluated without overflow.'''

    def forward(self, x: t.Tensor) -> t.Tensor:
        '''
        The textbook form (e^{2x} - 1) / (e^{2x} + 1) overflows to inf/inf = NaN once e^{2x}
        leaves the float range (|x| > ~44 in float32). It also cancels to 0 for tiny x.

        This version uses tanh's odd symmetry and expm1 on a non-positive argument, so every
        intermediate stays finite and small |x| keeps full precision:
            tanh(x) = sign(x) * -expm1(-2|x|) / (2 + expm1(-2|x|))
        '''
        em1 = t.expm1(-2 * x.abs())  # in (-1, 0]
        return t.sign(x) * (-em1) / (2 + em1)

# Check the Tanh module with the project-local utils test helper.
utils.test_Tanh(Tanh)

All tests in `test_Tanh` passed.


In [10]:
class LeakyReLU(nn.Module):
    '''Elementwise leaky ReLU: x for x >= 0, negative_slope * x for x < 0.'''

    def __init__(self, negative_slope: float = 0.01):
        super().__init__()
        self.negative_slope = negative_slope

    def forward(self, x: t.Tensor) -> t.Tensor:
        # Split x into its non-negative and non-positive parts; only the latter is scaled.
        positive_part = x.clamp(min=0)
        negative_part = x.clamp(max=0)
        return positive_part + negative_part * self.negative_slope

    def extra_repr(self) -> str:
        return f'negative_slope={self.negative_slope}'

# Check the LeakyReLU module with the project-local utils test helper.
utils.test_LeakyReLU(LeakyReLU)

All tests in `test_LeakyReLU` passed.


In [11]:
class Sigmoid(nn.Module):
    '''Elementwise logistic function 1 / (1 + e^{-x}).'''

    def forward(self, x: t.Tensor) -> t.Tensor:
        # For very negative x, e^{-x} becomes inf and the quotient cleanly becomes 0.
        denominator = 1 + t.exp(-x)
        return 1 / denominator

# Check the Sigmoid module with the project-local utils test helper.
utils.test_Sigmoid(Sigmoid)

All tests in `test_Sigmoid` passed.
