In [2]:
import torch
from torch import nn

In [50]:
batch_size = 32
height = 10
length = 16
x = torch.randn(batch_size, height, length)

hidden = [10, 10]
kernel_size = 3
stride = 2
padding = 0


def conv1d_size(
    length: int, kernel_size: int, stride: int, padding: int
) -> int:
    """Calculate output dimensions for 1d convolution.

    Args:
        length (int): Input size.
        kernel_size (int): Kernel_size.
        stride (int): Stride.

    Returns:
        int: Output size.
    """
    return int((length + (2 * padding) - (kernel_size - 1) - 1) / stride) + 1


shapes = []
encoder = []
channels = height

for h in hidden:
    shapes.append((channels, length))
    conv = nn.Conv1d(
        in_channels=channels,
        out_channels=h,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
    )
    encoder.append(conv)
    length = conv1d_size(length, kernel_size, stride, padding)
    channels = h

shapes.append((channels, length))
encoder = nn.Sequential(*encoder)
print(shapes)
shapes.reverse()

decoder = []
for i in range(len(hidden)):
    channels_in, length_in = shapes[i]
    channels_out, length_out = shapes[i + 1]
    if length_out % 2 == 1:
        pad = 1
    else:
        pad = 0
    conv = nn.ConvTranspose1d(
        in_channels=channels_in,
        out_channels=channels_out,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        output_padding=pad,
    )
    decoder.append(conv)
decoder = nn.Sequential(*decoder)

print(x.shape)
x = encoder(x)
print(x.shape)
x = decoder(x)
print(x.shape)

[(10, 16), (10, 7), (10, 3)]
torch.Size([32, 10, 16])
torch.Size([32, 10, 3])
torch.Size([32, 10, 17])


In [46]:
def conv1d_size(
    length: int, kernel_size: int, stride: int, padding: int
) -> int:
    """Calculate output dimensions for 1d convolution.

    Args:
        length (int): Input size.
        kernel_size (int): Kernel_size.
        stride (int): Stride.

    Returns:
        int: Output size.
    """
    return int((length + (2 * padding) - (kernel_size - 1) - 1) / stride) + 1


batch_size = 32
height = 10
length = 19
y = torch.randn(batch_size, height, length)
print(y.shape)

conv = nn.Conv1d(
    in_channels=height, out_channels=height, kernel_size=2, stride=2, padding=0
)

lengths = [y.shape[-1]]
for i in range(2):
    length = conv1d_size(length, 2, 2, 0)
    lengths.append(length)
    y = conv(y)
    print(y.shape)

print(lengths)

torch.Size([32, 10, 19])
torch.Size([32, 10, 9])
torch.Size([32, 10, 4])
[19, 9, 4]


In [58]:
batch_size = 32
channels = 3
height = 10
length = 24
e = torch.randn(batch_size, channels, length, height)

conv = nn.Conv2d(
    in_channels=channels,
    out_channels=height,
    kernel_size=(2, height),
    stride=(2, 1),
    padding=(1, 0),
)

e = conv(e)
print(e.shape)

torch.Size([32, 10, 13, 1])
