In [6]:
import torch

In [9]:
class ConvNet(torch.nn.Module):
    """Fully convolutional network mapping an image to a single-channel map.

    Layout of ``self.network`` (a ``Sequential``):
      0. 11x11 stride-2 "stem" convolution (``in_channels -> channels_l0``)
      1. ReLU
      2 .. n_blocks+1. ``Block`` stages, each halving the resolution and
         doubling the channel count
      last. 1x1 convolution down to one output channel

    Every strided stage maps a side of length ``H`` to ``ceil(H / 2)``, so the
    output side is the input side divided by ``2 ** (n_blocks + 1)`` when that
    divides evenly (e.g. 64x64 input with ``n_blocks=3`` -> 4x4 output).
    """

    class Block(torch.nn.Module):
        """Downsampling stage: three 3x3 convolutions, each followed by ReLU.

        Only the first convolution is strided and changes the channel count;
        the other two keep both the resolution and ``out_channels``.
        """

        def __init__(self, in_channels: int, out_channels: int, stride: int) -> None:
            super().__init__()
            kernel_size = 3
            padding = (kernel_size - 1) // 2  # "same" padding for an odd kernel
            self.c1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
            self.c2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size, 1, padding)
            self.c3 = torch.nn.Conv2d(out_channels, out_channels, kernel_size, 1, padding)
            # ReLU has no parameters or state, so one instance is reused after each conv.
            self.relu = torch.nn.ReLU()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = self.relu(self.c1(x))
            x = self.relu(self.c2(x))
            x = self.relu(self.c3(x))
            return x

    def __init__(self, channels_l0: int = 64, n_blocks: int = 4, in_channels: int = 3) -> None:
        """Build the network.

        Args:
            channels_l0: channels produced by the stem convolution; each
                ``Block`` doubles this count.
            n_blocks: number of downsampling ``Block`` stages.
            in_channels: channels of the input image (defaults to 3, the value
                that used to be hard-coded).
        """
        super().__init__()
        cnn_layers = [
            # Large-receptive-field stem; padding 5 == (11 - 1) // 2, so the
            # stride of 2 halves the spatial size.
            torch.nn.Conv2d(in_channels, channels_l0, kernel_size=11, stride=2, padding=5),
            torch.nn.ReLU(),
        ]
        channels = channels_l0
        for _ in range(n_blocks):
            cnn_layers.append(self.Block(channels, channels * 2, stride=2))
            channels *= 2
        # 1x1 convolution collapsing the features to a single channel. No global
        # pooling follows, so the output keeps a spatial grid.
        cnn_layers.append(torch.nn.Conv2d(channels, 1, kernel_size=1))
        self.network = torch.nn.Sequential(*cnn_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the network to ``x``, an image batch of shape (N, in_channels, H, W)."""
        return self.network(x)

# Sanity-check the output shape on a random 64x64 RGB image.
torch.manual_seed(0)  # reproducible weights and input
net = ConvNet(n_blocks=3)
x = torch.randn(1, 3, 64, 64)
with torch.no_grad():  # shape check only; no autograd graph needed
    out = net(x)
# The stem and each of the 3 blocks halve the side: 64 / 2**4 == 4.
assert out.shape == (1, 1, 4, 4), out.shape
out.shape

torch.Size([1, 1, 4, 4])
