In [1]:
import torch
from torch import nn
import torch.nn.functional as F

In [2]:
class ResidualBlock(nn.Module):
    def __init__(self, channel, kernel_size, stride=1, padding=None):
        super(ResidualBlock, self).__init__()
        if padding is None:
            padding = 1
        self.conv1 = nn.Conv2d(channel, channel, kernel_size=kernel_size, stride=stride, padding=padding)
        self.conv2 = nn.Conv2d(channel, channel, kernel_size=kernel_size, stride=stride, padding=padding)
        self.ln1 = nn.InstanceNorm2d(channel, affine=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        logits = self.conv1(x)
        logits = F.relu(logits)
        logits = self.conv2(logits)
        logits = self.ln1(logits)
        logits = logits + x
        logits = F.relu(logits)
        return logits

In [3]:
class Encoder(nn.Module):
    def __init__(self, latent_channels=10, latent_dim=100, input_size=(2, 128, 400)):
        super(Encoder, self).__init__()
        self.latent_dim = latent_channels
        self.input_size = input_size
        self.conv0 = nn.Conv2d(2, latent_channels, kernel_size=3)
        self.conv1 = nn.ParameterList([nn.Conv2d(latent_channels, latent_channels, 5) for _ in range(20)])
        self.resd1 = nn.ParameterList([ResidualBlock(latent_channels, 3) for _ in range(10)])
        self.conv2 = nn.ParameterList([nn.Conv2d(latent_channels, latent_channels, 3) for _ in range(20)])
        # to latent dim (20, 11, 79)
        self.conv3 = nn.Conv2d(latent_channels, latent_channels, kernel_size=3, padding=1)
        self.fc_mu = nn.Linear(10 * 6 * 278, latent_dim)
        self.fc_logvar = nn.Linear(10 * 6 * 278, latent_dim)
    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        logit = self.conv0(x)
        logit = F.relu(logit)

        for layer in self.conv1:
            logit = layer(logit)
            logit = F.relu(logit)

        for resblock in self.resd1:
            logit = resblock(logit)
            logit = F.relu(logit)

        for layer in self.conv2:
            logit = layer(logit)
            logit = F.relu(logit)

        logit = self.conv3(logit)
        logit = F.relu(logit)

        logit = logit.reshape(logit.size()[0], -1)
        return self.fc_mu(logit), self.fc_logvar(logit)


class Decoder(nn.Module):
    def __init__(self, latent_channels=10, latent_dim=100, input_size=(2, 128, 400)):
        super(Decoder, self).__init__()
        self.input_size = input_size
        self.latent_channels = latent_channels
        self.latent_dim = latent_dim
        self.fc1 = nn.Linear(latent_dim, 16680)
        self.conv1 = nn.Conv2d(latent_channels, latent_channels, 3, 1, padding=1)
        self.conv2 = nn.ParameterList([nn.ConvTranspose2d(latent_channels, latent_channels, 3, 1) for _ in range(20)])
        self.resd1 = nn.ParameterList([ResidualBlock(latent_channels, 3) for _ in range(10)])
        self.conv3 = nn.ParameterList([nn.ConvTranspose2d(latent_channels, latent_channels, 5) for _ in range(20)])
        self.conv0 = nn.ConvTranspose2d(latent_channels, 2, kernel_size=3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        logit = self.fc1(x)
        logit = F.relu(logit)
        logit = logit.reshape(logit.size()[0], self.latent_channels, 6, 278)
        logit = self.conv1(logit)
        logit = F.relu(logit)
        for resblock in self.conv2:
            logit = resblock(logit)
            logit = F.relu(logit)
        for resblock in self.resd1:
            logit = resblock(logit)
            logit = F.relu(logit)
        for resblock in self.conv3:
            logit = resblock(logit)
            logit = F.relu(logit)
        logit = self.conv0(logit)
        return logit

encoder = Encoder()
decoder = Decoder()

In [4]:
out: torch.Tensor = encoder(torch.randn(1, 2, 128, 400))

In [6]:
out

(tensor([[-0.0765,  0.0094,  0.0191,  0.0062,  0.0088, -0.0091, -0.0235,  0.0162,
          -0.0107,  0.0102, -0.0036,  0.0186, -0.0145,  0.0260, -0.0088, -0.0189,
           0.0090,  0.0109,  0.0013, -0.0242,  0.0143, -0.0303,  0.0351,  0.0099,
          -0.0115, -0.0188, -0.0197, -0.0086, -0.0129, -0.0062, -0.0023,  0.0133,
           0.0266,  0.0086,  0.0099,  0.0322,  0.0107, -0.0196, -0.0466, -0.0203,
           0.0008,  0.0459, -0.0278, -0.0044, -0.0037, -0.0586, -0.0058, -0.0192,
           0.0210,  0.0105, -0.0433,  0.0386,  0.0534, -0.0230, -0.0193, -0.0197,
           0.0102, -0.0418, -0.0044,  0.0488, -0.0149,  0.0078,  0.0243,  0.0200,
           0.0082,  0.0331,  0.0299,  0.0489,  0.0143, -0.0111, -0.0146,  0.0182,
          -0.0264,  0.0303,  0.0326, -0.0269, -0.0333, -0.0115, -0.0208,  0.0522,
          -0.0436,  0.0272, -0.0214, -0.0158,  0.0414, -0.0503,  0.0275,  0.0041,
           0.0238, -0.0295, -0.0236,  0.0496, -0.0139, -0.0379, -0.0419, -0.0204,
           0.017

In [129]:
out: torch.Tensor = encoder(torch.randn(1, 2, 128, 400))

In [7]:
decoder(torch.randn(1, 100)).size()

torch.Size([1, 2, 128, 400])