In [3]:
import torch
from torch import nn

In [105]:
batch_size = 32
height = 10
length = 5
x = torch.randn(batch_size, height, length)

hidden = [4, 2]
kernel_size = 2
stride = 2


def conv1d_size(length: int, kernel_size: int, stride: int) -> int:
    """Calculate output dimensions for 1d convolution.

    Args:
        length (int): Input size.
        kernel_size (int): Kernel_size.
        stride (int): Stride.

    Returns:
        int: Output size.
    """
    return (length - kernel_size) // stride + 1


shapes = []
encoder = []
channels = height

for h in hidden:
    shapes.append((channels, length))
    conv = nn.Conv1d(
        in_channels=channels,
        out_channels=h,
        kernel_size=kernel_size,
        stride=stride,
        padding=0,
    )
    encoder.append(conv)
    length = conv1d_size(length, kernel_size, stride)
    channels = h

shapes.append((channels, length))
encoder = nn.Sequential(*encoder)
print(shapes)
shapes.reverse()

decoder = []
for i in range(len(hidden)):
    channels_in, length_in = shapes[i]
    channels_out, length_out = shapes[i + 1]
    if length_out % 2 == 0:
        padding = 0
    else:
        padding = 1
    conv = nn.ConvTranspose1d(
        in_channels=channels_in,
        out_channels=channels_out,
        kernel_size=kernel_size,
        stride=stride,
        output_padding=padding,
    )
    decoder.append(conv)
decoder = nn.Sequential(*decoder)

print(x.shape)
x = encoder(x)
print(x.shape)
x = decoder(x)
print(x.shape)

[(10, 5), (4, 2), (2, 1)]
torch.Size([32, 10, 5])
torch.Size([32, 2, 1])
torch.Size([32, 10, 5])


In [94]:
y = torch.randn_like(x)
print(y.shape)

dconv = nn.ConvTranspose1d(
    in_channels=height,
    out_channels=height,
    kernel_size=2,
    stride=2,
    output_padding=0,
)
dconv_padded = nn.ConvTranspose1d(
    in_channels=height,
    out_channels=height,
    kernel_size=2,
    stride=2,
    output_padding=1,
)

for pad in pads[::-1]:
    if pad:
        print("not padded")
        y = dconv_padded(y)
    else:
        print("padded")
        y = dconv(y)
    print(y.shape)

torch.Size([32, 10, 2])
padded
torch.Size([32, 10, 4])
padded
torch.Size([32, 10, 8])
not padded
torch.Size([32, 10, 17])


In [58]:
batch_size = 32
channels = 3
height = 10
length = 24
e = torch.randn(batch_size, channels, length, height)

conv = nn.Conv2d(
    in_channels=channels,
    out_channels=height,
    kernel_size=(2, height),
    stride=(2, 1),
    padding=(1, 0),
)

e = conv(e)
print(e.shape)

torch.Size([32, 10, 13, 1])
