[toc]

# PyTorch ModuleList

有时我们希望用一个 list 来存放多个 module，然后用 for 循环遍历这些 module。

In [43]:
import torch
import torch.nn.functional as F
from torch import nn


class myModel(nn.Module):
    """Pitfall demo: parallel sub-layers stored in a plain Python list.

    The list is an ordinary attribute, so ``nn.Module`` does NOT register the
    ``nn.Linear`` layers inside it: they are absent from ``modules()`` /
    ``children()`` and their weights are absent from ``parameters()``. Anything
    built from ``model.parameters()`` (e.g. an optimizer) therefore never sees them.
    The plain list is intentional here -- it is the bug this notebook demonstrates.

    Args:
        num_layers: number of parallel ``nn.Linear`` branches (default 5).
        in_features: feature width of the input ``x`` (default 3).
        hidden_features: output width of each branch (default 2).
    """

    def __init__(self, num_layers: int = 5, in_features: int = 3, hidden_features: int = 2):
        super().__init__()
        # Intentionally a plain list (see class docstring): these layers stay unregistered.
        self.linears = [nn.Linear(in_features, hidden_features) for _ in range(num_layers)]
        # The head consumes every branch concatenated side by side, so its input
        # width is derived (5 * 2 == 10 with the defaults) rather than hard-coded.
        self.fc = nn.Linear(num_layers * hidden_features, 1)

    def forward(self, x):
        """Run every branch on ``x``, ReLU and concatenate the results, then apply ``fc``.

        Assumes ``x`` is 2-D, ``(batch, in_features)`` -- TODO confirm for other ranks.
        Returns a tensor of shape ``(batch, 1)``.
        """
        branch_outputs = [F.relu(layer(x)) for layer in self.linears]
        # dim=1 is the feature axis, giving num_layers * hidden_features columns.
        return self.fc(torch.cat(branch_outputs, dim=1))

# Seed so the weight initialization and the demo input below are reproducible.
torch.manual_seed(0)
# Demo batch: 10 samples with 3 features each (the branches are nn.Linear(3, 2)).
x = torch.randn(10, 3)
model = myModel()
model(x)

tensor([[0.5875],
        [0.4155],
        [0.6017],
        [0.4432],
        [0.4005],
        [0.3734],
        [0.5208],
        [0.4211],
        [0.4048],
        [0.3235]], grad_fn=<AddmmBackward>)

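这样乍一看能实现功能，但实际上，用普通的 Python list 存放这些 nn.Linear 时，nn.Module 并不会注册它们：nn.Module 只会自动注册直接赋值为 Module 或 Parameter 的属性，而 list 只是一个普通属性。因此这些 nn.Linear 不会出现在 myModel.modules() 中，其参数也不会出现在 myModel.parameters() 中；用 model.parameters() 构建的优化器也就无法更新它们。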

In [44]:
# List what nn.Module has registered: every submodule (the model itself first),
# then every parameter. Layers kept in a plain list are missing from both listings.
for registered in (*model.modules(), *model.parameters()):
    print(registered)

myModel(
  (fc): Linear(in_features=10, out_features=1, bias=True)
)
Linear(in_features=10, out_features=1, bias=True)
Parameter containing:
tensor([[ 0.2114,  0.0426,  0.1360,  0.2596,  0.2948, -0.1039,  0.1681, -0.2994,
         -0.1194,  0.3006]], requires_grad=True)
Parameter containing:
tensor([0.1848], requires_grad=True)


正确的做法是用 nn.ModuleList 来存放这些子模块：

In [47]:
import torch.nn.functional as F


class myModel2(nn.Module):
    """Correct pattern: parallel sub-layers held in an ``nn.ModuleList``.

    ``nn.ModuleList`` is itself an ``nn.Module``, so assigning it to an attribute
    registers it. Every ``nn.Linear`` it holds then shows up in ``modules()`` and
    contributes its weight and bias to ``parameters()``.

    Args:
        num_layers: number of parallel ``nn.Linear`` branches (default 5).
        in_features: feature width of the input ``x`` (default 3).
        hidden_features: output width of each branch (default 2).
    """

    def __init__(self, num_layers: int = 5, in_features: int = 3, hidden_features: int = 2):
        super().__init__()
        self.linears = nn.ModuleList(
            [nn.Linear(in_features, hidden_features) for _ in range(num_layers)]
        )
        # The head consumes every branch concatenated side by side, so its input
        # width is derived (5 * 2 == 10 with the defaults) rather than hard-coded.
        self.fc = nn.Linear(num_layers * hidden_features, 1)

    def forward(self, x):
        """Run every branch on ``x``, ReLU and concatenate the results, then apply ``fc``.

        Assumes ``x`` is 2-D, ``(batch, in_features)`` -- TODO confirm for other ranks.
        Returns a tensor of shape ``(batch, 1)``.
        """
        branch_outputs = [F.relu(layer(x)) for layer in self.linears]
        # dim=1 is the feature axis, giving num_layers * hidden_features columns.
        return self.fc(torch.cat(branch_outputs, dim=1))

# Re-seed so this model's weight initialization (and the output below) is reproducible.
torch.manual_seed(0)
model = myModel2()
model(x)

tensor([[ 0.5319],
        [ 0.3566],
        [ 0.2174],
        [-0.0533],
        [ 0.4019],
        [ 0.4418],
        [-0.1033],
        [ 0.1301],
        [ 0.1030],
        [ 0.3335]], grad_fn=<AddmmBackward>)

可以看到，nn.ModuleList 封装的 nn.Linear 被添加到了 myModel2.modules() 中，其参数也被添加到了 myModel2.parameters() 中。

In [48]:
# Same listing for myModel2: the ModuleList and each Linear inside it are now
# registered submodules, and their weights and biases appear among the parameters.
for registered in (*model.modules(), *model.parameters()):
    print(registered)

myModel2(
  (linears): ModuleList(
    (0): Linear(in_features=3, out_features=2, bias=True)
    (1): Linear(in_features=3, out_features=2, bias=True)
    (2): Linear(in_features=3, out_features=2, bias=True)
    (3): Linear(in_features=3, out_features=2, bias=True)
    (4): Linear(in_features=3, out_features=2, bias=True)
  )
  (fc): Linear(in_features=10, out_features=1, bias=True)
)
ModuleList(
  (0): Linear(in_features=3, out_features=2, bias=True)
  (1): Linear(in_features=3, out_features=2, bias=True)
  (2): Linear(in_features=3, out_features=2, bias=True)
  (3): Linear(in_features=3, out_features=2, bias=True)
  (4): Linear(in_features=3, out_features=2, bias=True)
)
Linear(in_features=3, out_features=2, bias=True)
Linear(in_features=3, out_features=2, bias=True)
Linear(in_features=3, out_features=2, bias=True)
Linear(in_features=3, out_features=2, bias=True)
Linear(in_features=3, out_features=2, bias=True)
Linear(in_features=10, out_features=1, bias=True)
Parameter containing:
