In [1]:
import torch
from typing import Any

In [9]:
# Residual connections
class MyModelLN(torch.nn.Module):
  """Fully connected classifier assembled from LayerNorm residual blocks.

  The network flattens its input, projects it to `hidden_size` features with a
  bias-free linear layer, applies one residual `Block` per entry of
  `layer_size`, and finishes with a bias-free linear layer that produces
  `num_classes` outputs.
  """

  class Block(torch.nn.Module):
    """Two Linear -> LayerNorm -> ReLU stages wrapped by a residual connection."""

    def __init__(self, in_channels, out_channels):
      """Build the main branch and the skip branch.

      Args:
        in_channels: Size of the last dimension of the block input.
        out_channels: Size of the last dimension of the block output.
      """
      super().__init__()
      # Main branch: two Linear -> LayerNorm -> ReLU stages run before the
      # residual sum.
      self.model = torch.nn.Sequential(
        torch.nn.Linear(in_channels, out_channels),
        torch.nn.LayerNorm(out_channels),
        torch.nn.ReLU(),
        torch.nn.Linear(out_channels, out_channels),
        torch.nn.LayerNorm(out_channels),
        torch.nn.ReLU()
      )

      # The residual sum needs both branches to have the same width. When the
      # block changes the width, the skip path is a linear projection;
      # otherwise the input is passed through untouched.
      if in_channels != out_channels:
        self.skip = torch.nn.Linear(in_channels, out_channels)
      else:
        self.skip = torch.nn.Identity()

    def forward(self, x):
      """Return the main-branch output plus the (possibly projected) input."""
      return self.model(x) + self.skip(x)

  def __init__(self, layer_size=(512, 512, 512), *, in_features=128 * 128 * 3,
               hidden_size=512, num_classes=102):
    """Assemble the layer stack.

    Args:
      layer_size: Iterable with the output width of each residual block, in
        order. An empty iterable yields a network without residual blocks.
      in_features: Number of features per sample after flattening. The
        default matches a 3 x 128 x 128 input.
      hidden_size: Output width of the first linear layer, i.e. the input
        width of the first residual block.
      num_classes: Number of outputs produced by the final linear layer.
    """
    super().__init__()
    # Always start with a linear layer, then stack the residual blocks.
    layers = [
      torch.nn.Flatten(),
      torch.nn.Linear(in_features, hidden_size, bias=False),
    ]
    c = hidden_size  # Width fed into the next layer.
    for s in layer_size:
      layers.append(self.Block(c, s))
      c = s
    layers.append(torch.nn.Linear(c, num_classes, bias=False))
    self.model = torch.nn.Sequential(*layers)

  def forward(self, x) -> torch.Tensor:
    """Run `x` through the whole stack and return the final layer's output."""
    return self.model(x)

# Smoke test: a random batch of ten 3 x 128 x 128 inputs is pushed through a
# network built with four residual blocks of width 512.
x = torch.rand((10, 3, 128, 128))
net = MyModelLN(layer_size=[512] * 4)
print(net(x))

tensor([[ 0.4891, -1.5137, -0.2610,  ..., -0.3551, -2.3938, -0.8180],
        [ 1.2984, -1.4633,  0.2115,  ..., -0.9101, -2.0068,  0.1915],
        [ 0.8688, -0.8597, -0.3191,  ..., -0.9664, -2.4525, -0.5118],
        ...,
        [ 1.3604, -1.1332, -0.0411,  ...,  0.0526, -2.3494, -0.6312],
        [ 0.3147, -1.7935,  0.5190,  ..., -0.4628, -2.2549,  0.0798],
        [ 0.4137, -1.5348,  0.3141,  ..., -0.4908, -2.5118,  0.4319]],
       grad_fn=<MmBackward0>)
