In [1]:
import torch
import torch.nn as nn

  from .autonotebook import tqdm as notebook_tqdm


In [23]:
class Conv2dBlock(nn.Module):
    """Two unpadded 2-D convolutions, each followed by an activation.

    Neither convolution is padded and both use the default stride of 1, so a
    ``(kh, kw)`` kernel removes ``kh - 1`` rows and ``kw - 1`` columns per
    convolution (``2 * (kh - 1)`` rows and ``2 * (kw - 1)`` columns per block).

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by both convolutions.
        kernel_size: kernel used by both convolutions (int or ``(kh, kw)``).
        activation: module applied after each convolution; the same instance
            is used at both positions. Defaults to a new ``nn.ReLU()`` for
            each block.
    """

    def __init__(self, in_channels, out_channels, kernel_size, activation=None):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size)
        # A default of ``nn.ReLU()`` in the signature would be evaluated once
        # at definition time and shared by every block; build one per block.
        self.act = nn.ReLU() if activation is None else activation
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size)

    def forward(self, x):
        return self.act(self.conv2(self.act(self.conv1(x))))

In [24]:
class Encoder(nn.Module):
    """Stack of ``Conv2dBlock`` stages, each followed by a max-pool.

    Args:
        layers: channel counts ``[c_in, c_1, ..., c_k]``; stage ``i`` maps
            ``layers[i]`` to ``layers[i + 1]`` channels, so there are
            ``len(layers) - 1`` stages.
        kernels: one kernel size per stage, or a single ``int`` applied to
            every stage.
        pool_size: one max-pool size per stage, or a single ``int`` applied
            to every stage.

    Raises:
        ValueError: if ``kernels`` or ``pool_size`` (after int broadcasting)
            does not have ``len(layers) - 1`` entries.
    """

    def __init__(self, layers: list, kernels: "list | int", pool_size: "list | int"):
        super().__init__()
        n_stages = len(layers) - 1
        # A bare int is shorthand for "the same size at every stage".
        if isinstance(kernels, int):
            kernels = [kernels] * n_stages
        if isinstance(pool_size, int):
            pool_size = [pool_size] * n_stages
        if n_stages != len(kernels):
            raise ValueError("Wrong number of kernels")
        if n_stages != len(pool_size):
            raise ValueError("Wrong number of pool sizes")

        self.enc_blocks = nn.ModuleList()
        self.pools = nn.ModuleList()
        for i in range(n_stages):
            self.enc_blocks.append(Conv2dBlock(layers[i], layers[i + 1], kernels[i]))
            self.pools.append(nn.MaxPool2d(pool_size[i]))

    def forward(self, x):
        """Return the pre-pool output of every stage, first stage first."""
        ftrs = []
        for block, pool in zip(self.enc_blocks, self.pools):
            x = block(x)
            ftrs.append(x)
            # NOTE(review): on the last stage this pooled tensor is discarded,
            # so pool_size[-1] does not affect the returned features.
            x = pool(x)
        return ftrs

In [25]:
# generate fake input
x = torch.randn(torch.Size((1, 1, 40, 25)))
m = torch.full_like(x, 1)
input = torch.concat((x, m), dim=1)

In [26]:
# Per-stage kernel and pool sizes for the 40x25 input. The last stage uses a
# (3, 1) kernel because after two 2x2 pools the map is only 3 columns wide,
# too narrow for a pair of 3x3 convolutions.
encoder = Encoder(
    layers=[2, 32, 64, 128],
    kernels=[(3, 3), (3, 3), (3, 1)],
    pool_size=[(2, 2), (2, 2), (3, 3)],
)

In [28]:
# Call the module itself rather than .forward() so registered hooks run.
out = encoder(input)
for el in out:
    print(el.shape)

torch.Size([1, 32, 36, 21])
torch.Size([1, 64, 14, 6])
torch.Size([1, 128, 3, 3])


In [29]:
# Display the module tree to check each stage's kernel and pool sizes.
encoder

Encoder(
  (enc_blocks): ModuleList(
    (0): Conv2dBlock(
      (conv1): Conv2d(2, 32, kernel_size=(3, 3), stride=(1, 1))
      (act): ReLU()
      (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
    )
    (1): Conv2dBlock(
      (conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
      (act): ReLU()
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    )
    (2): Conv2dBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 1), stride=(1, 1))
      (act): ReLU()
      (conv2): Conv2d(128, 128, kernel_size=(3, 1), stride=(1, 1))
    )
  )
  (pools): ModuleList(
    (0): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
    (1): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
    (2): MaxPool2d(kernel_size=(3, 3), stride=(3, 3), padding=0, dilation=1, ceil_mode=False)
  )
)

In [30]:
# The same three stages as the Encoder above, written as a flat nn.Sequential.
# Unlike Encoder.forward, this also applies the final 3x3 pool and returns
# only that last tensor.
# NOTE(review): this rebinds `encoder` (replacing the Encoder instance) and
# `x` (previously the random input channel); re-running earlier cells after
# this one will pick up the new values.
encoder = nn.Sequential(
    Conv2dBlock(in_channels=2, out_channels=32, kernel_size=(3, 3)),
    nn.MaxPool2d(2, 2),
    Conv2dBlock(in_channels=32, out_channels=64, kernel_size=(3, 3)),
    nn.MaxPool2d(2, 2),
    Conv2dBlock(in_channels=64, out_channels=128, kernel_size=(3, 1)),
    nn.MaxPool2d(3, 3),
)

# Call the module rather than .forward() so registered hooks run.
x = encoder(input)
x.shape

torch.Size([1, 128, 1, 1])

In [31]:
# Placeholder: an nn.Sequential with no layers returns its input unchanged.
# TODO: add the decoder layers.
decoder = nn.Sequential()

In [17]:
# Why the last stage needs a (3, 1) kernel: with uniform 3x3 kernels the 40x25
# input is 7x3 after two 2x2 pools, the last stage's first conv leaves 5x1,
# and its second 3x3 conv cannot be applied. The error is caught and shown.
# A separate name keeps `encoder` from being overwritten.
encoder_uniform = Encoder(
    layers=[2, 32, 64, 128], kernels=[3, 3, 3], pool_size=[2, 2, 2]
)
try:
    shapes = [feature.shape for feature in encoder_uniform(input)]
except RuntimeError as err:
    print(f"RuntimeError: {err}")
else:
    print(shapes)

RuntimeError: Calculated padded input size per channel: (5 x 1). Kernel size: (3 x 3). Kernel size can't be greater than actual input size

In [6]:
class Conv1dBlock(nn.Module):
    """Two unpadded 1-D convolutions, each followed by ReLU.

    Args:
        in_channels: channels of the input.
        out_channels: channels produced by both convolutions.
        kernel_size: kernel size shared by both convolutions.
    """

    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        # Attribute names and order are kept: they define the module repr
        # and the conv1.* / conv2.* state_dict keys.
        self.conv1 = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
        )
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv1d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
        )

    def forward(self, x):
        hidden = self.relu(self.conv1(x))
        return self.relu(self.conv2(hidden))

In [7]:
class Decoder(nn.Module):
    def __init__(self, layers, kernels, upscale_size, stride):
        super().__init__()
        self.layers = layers
        self.upconvs = nn.ModuleList([
            nn.ConvTranspose1d(layers[i], layers[i+1], upscale_size, stride)
            for i in range(len(layers)-1)
        ])
        self.dec_blocks = nn.ModuleList([
            Conv1dBlock(layers[i], layers[i+1], kernels[i])
            for i in range(len(layers)-1)
        ])

    # def forward(self, x):