## Containers

In [1]:
from collections import OrderedDict

import torch
import torch.nn as nn

### [Module](https://docs.pytorch.org/docs/stable/generated/torch.nn.Module.html)

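Base class for all neural network modules. Subclass it, call `super().__init__()`, assign submodules as attributes in `__init__`, and define the computation in `forward`; calling the instance runs `forward`.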


In [2]:
class Model(nn.Module):
    """Two 5x5 convolutions, each followed by a ReLU.

    Conv2d defaults to stride 1 and no padding, so each 5x5 kernel trims
    4 pixels from height and width: a (1, 1, 32, 32) input comes out as
    (1, 20, 24, 24).
    """

    def __init__(self):
        super().__init__()
        # Assignment order fixes registration order (and state_dict order).
        self.conv1 = nn.Conv2d(1, 20, 5)  # 1 -> 20 channels, 5x5 kernel
        self.conv2 = nn.Conv2d(20, 20, 5)  # 20 -> 20 channels, 5x5 kernel
        self.relu = nn.ReLU()  # parameter-free, so one instance is reused twice

    def forward(self, x):
        hidden = self.relu(self.conv1(x))
        out = self.conv2(hidden)
        return self.relu(out)


Model()(torch.randn(1, 1, 32, 32)).shape

torch.Size([1, 20, 24, 24])

### [Sequential](https://docs.pytorch.org/docs/stable/generated/torch.nn.Sequential.html)

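A sequential container: modules are applied in the order they are passed to the constructor, each one's output feeding the next. They can be passed positionally or as an `OrderedDict` of named modules.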


In [3]:
# Positional form: children are registered under the string indices "0".."3".
positional_model = nn.Sequential(
    nn.Conv2d(1, 20, 5), nn.ReLU(), nn.Conv2d(20, 64, 5), nn.ReLU()
)
# Check the positional model's output instead of silently discarding it.
positional_out = positional_model(torch.randn(1, 1, 32, 32))
assert positional_out.shape == torch.Size([1, 64, 24, 24])

# Named form: an OrderedDict of (name, module) pairs fixes both the run order
# and the child names, so layers are reachable as model.conv1, model.relu1, ...
model = nn.Sequential(
    OrderedDict(
        [
            ("conv1", nn.Conv2d(1, 20, 5)),
            ("relu1", nn.ReLU()),
            ("conv2", nn.Conv2d(20, 64, 5)),
            ("relu2", nn.ReLU()),
        ]
    )
)
model(torch.randn(1, 1, 32, 32)).shape

torch.Size([1, 64, 24, 24])

### [ModuleList](https://docs.pytorch.org/docs/stable/generated/torch.nn.ModuleList.html)

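Holds submodules in a list. It can be indexed and iterated like a regular Python list, but the modules it contains are properly registered, so their parameters are visible to all `Module` methods.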


In [4]:
class Model(nn.ModuleList):
    """A ModuleList of ten 10 -> 10 Linear layers that is also callable.

    Subclassing ``nn.ModuleList`` directly makes the layers children "0".."9"
    of the model itself, so ``self`` can be indexed and has a length.
    """

    def __init__(self):
        super().__init__([nn.Linear(10, 10) for _ in range(10)])

    def forward(self, x):
        # Step i sums layer i // 2 and layer i, both applied to the same x,
        # so the first term walks the early layers twice: 0, 0, 1, 1, 2, 2, ...
        for idx in range(len(self)):
            x = self[idx // 2](x) + self[idx](x)
        return x


Model()(torch.randn(1, 10)).shape

torch.Size([1, 10])

### [ModuleDict](https://docs.pytorch.org/docs/stable/generated/torch.nn.ModuleDict.html)

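Holds submodules in a dictionary. It can be indexed like a regular Python dictionary, but the modules it contains are properly registered, and insertion order is preserved.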


In [5]:
class Model(nn.Module):
    """Selects one layer and one activation by name at call time.

    Both collections are ``nn.ModuleDict``s, so every candidate is registered
    as a submodule even if a given call never selects it.
    """

    def __init__(self):
        super().__init__()
        # A ModuleDict can be built from a mapping ...
        self.choices = nn.ModuleDict(
            {
                "conv": nn.Conv2d(10, 10, 3),  # no padding: 10x10 -> 8x8
                "pool": nn.MaxPool2d(3),  # stride defaults to 3: 10x10 -> 3x3
            }
        )
        # ... or from an iterable of (key, module) pairs.
        self.activations = nn.ModuleDict(
            [
                ("lrelu", nn.LeakyReLU()),
                ("prelu", nn.PReLU()),
            ]
        )

    def forward(self, x, choice, act):
        # An unknown key in either lookup raises KeyError.
        transformed = self.choices[choice](x)
        return self.activations[act](transformed)


Model()(torch.randn(1, 10, 10, 10), "conv", "lrelu").shape

torch.Size([1, 10, 8, 8])