- A module can describe a single layer, a component consisting of multiple layers, or the entire model itself. It is a reusable abstraction that can be combined into larger artefacts.

In [9]:
import torch
from torch import nn
from torch.nn import functional as F

In [12]:
class MLP(nn.Module):
    """
    A custom MLP: one ReLU hidden layer followed by a linear output layer.

    Both layers are ``nn.LazyLinear``, so the number of input features is
    inferred from the first batch passed through the network. By default the
    hidden layer has 256 units and the output layer has 10 units.

    Args:
        num_hiddens: Width of the hidden layer (default 256).
        num_outputs: Number of output units (default 10).
    """

    def __init__(self, num_hiddens=256, num_outputs=10):
        super().__init__()
        self.hidden = nn.LazyLinear(num_hiddens)
        self.out = nn.LazyLinear(num_outputs)

    def forward(self, X):
        """Return ``out(relu(hidden(X)))``; no activation is applied to the output."""
        return self.out(F.relu(self.hidden(X)))

In [None]:
net = MLP()
# Input batch: 2 examples, 20 features each (the lazy layers infer 20 on first use).
X = torch.rand(size=(2, 20))
print(X)
# First forward call materialises the lazy layers; the output has 10 columns.
net(X).shape

## Sequential Module

In [15]:
class MySequential(nn.Module):
    """
    Chain an arbitrary number of modules so they run one after another.

    Each positional argument is registered as a child module named after its
    position ("0", "1", ...), so its parameters are tracked by the container.
    ``forward`` feeds the input through the children in registration order
    and returns the output of the last one (or the input itself if empty).
    """

    def __init__(self, *args):
        super().__init__()
        position = 0
        for layer in args:
            # add_module keeps insertion order, which forward() relies on.
            self.add_module(str(position), layer)
            position += 1

    def forward(self, X):
        output = X
        for child in self.children():
            output = child(output)
        return output

In [16]:
# Same architecture as MLP (256-unit ReLU hidden layer, 10 outputs), built by chaining modules.
net = MySequential(
    nn.LazyLinear(256),
    nn.ReLU(),
    nn.LazyLinear(10),
)
# Reuses X from the earlier cell.
net(X).shape



torch.Size([2, 10])

In [19]:
class FixedHiddenMLP(nn.Module):
    """A non-sequential network with fixed weights, a shared layer and control flow.

    ``forward`` applies ``self.linear``, then ``relu(X @ rand_weight + 1)`` with a
    constant random 20x20 matrix, then the *same* ``self.linear`` again (its
    parameters are shared by both calls). Finally it halves the result until the
    sum of absolute values is at most 1 and returns the scalar sum.
    """

    def __init__(self):
        # nn.Module.__init__ must run before any submodule is assigned;
        # without it, ``self.linear = ...`` below raises AttributeError.
        super().__init__()
        # Fixed random weights, registered as a buffer rather than a Parameter:
        # they get no gradient and are never updated by backprop, but they still
        # move with ``.to(...)`` and are saved in ``state_dict``.
        self.register_buffer("rand_weight", torch.rand((20, 20)))
        # 20 output features so the result can be multiplied by rand_weight.
        self.linear = nn.LazyLinear(20)

    def forward(self, X):
        X = self.linear(X)
        # Affine step with the fixed weights and a constant bias of 1.
        X = F.relu(X @ self.rand_weight + 1)
        # Reuse the same linear layer: parameters are shared between both calls.
        X = self.linear(X)
        # Python control flow: halve until the L1 norm is at most 1. Out-of-place
        # division avoids mutating a tensor that is part of the autograd graph.
        # NOTE(review): if X ever contains inf, this loop never terminates.
        while X.abs().sum() > 1:
            X = X / 2
        return X.sum()

In [26]:
# Scratch cell: a 3x3x3 tensor of uniform samples, echoed by the notebook.
torch.rand(size=(3, 3, 3))

tensor([[[0.8432, 0.5344, 0.7995],
         [0.8956, 0.1778, 0.8018],
         [0.9378, 0.6929, 0.1207]],

        [[0.8299, 0.9529, 0.3596],
         [0.1234, 0.6471, 0.6257],
         [0.0470, 0.6792, 0.6505]],

        [[0.2190, 0.3297, 0.8481],
         [0.3235, 0.0333, 0.8065],
         [0.9824, 0.0164, 0.4167]]])