In [2]:
import torch.nn as nn


class ModuleList(nn.Module):
    """Demo network that stores its layers in an ``nn.ModuleList``.

    The container starts with 20 ``nn.Linear(10, 10)`` layers. Because they
    live in an ``nn.ModuleList`` (rather than a plain Python list) each layer
    is registered as a submodule of this network.
    """

    def __init__(self):
        super().__init__()
        # nn.ModuleList accepts a plain Python list of layers.
        self.linears = nn.ModuleList([nn.Linear(10, 10) for _ in range(20)])

    def change(self):
        """Grow ``self.linears`` in place: one ReLU, then two more Linear(10, 10) layers."""
        self.linears.append(nn.ReLU())
        self.linears.extend([nn.Linear(10, 10), nn.Linear(10, 10)])

    def forward(self, x):
        """Forward pass.

        For every position ``i`` in ``self.linears`` the running value is
        replaced by ``self.linears[i // 2](x) + self.linears[i](x)``.

        Args:
            x: input tensor; its last dimension must be 10 to match the
               ``Linear(10, 10)`` layers.

        Returns:
            The tensor produced after the last layer has been applied.
        """
        for i, layer in enumerate(self.linears):
            # Index the *container* (not the individual layer) by position, and
            # call the current layer on x; a single nn.Linear is not
            # subscriptable, so indexing it would raise TypeError.
            x = self.linears[i // 2](x) + layer(x)
        return x


# Build the demo network; the bare `net` on the last line makes the notebook
# display the module's repr as the cell output.
net = ModuleList()
net

ModuleList(
  (linears): ModuleList(
    (0): Linear(in_features=10, out_features=10, bias=True)
    (1): Linear(in_features=10, out_features=10, bias=True)
    (2): Linear(in_features=10, out_features=10, bias=True)
    (3): Linear(in_features=10, out_features=10, bias=True)
    (4): Linear(in_features=10, out_features=10, bias=True)
    (5): Linear(in_features=10, out_features=10, bias=True)
    (6): Linear(in_features=10, out_features=10, bias=True)
    (7): Linear(in_features=10, out_features=10, bias=True)
    (8): Linear(in_features=10, out_features=10, bias=True)
    (9): Linear(in_features=10, out_features=10, bias=True)
    (10): Linear(in_features=10, out_features=10, bias=True)
    (11): Linear(in_features=10, out_features=10, bias=True)
    (12): Linear(in_features=10, out_features=10, bias=True)
    (13): Linear(in_features=10, out_features=10, bias=True)
    (14): Linear(in_features=10, out_features=10, bias=True)
    (15): Linear(in_features=10, out_features=10, bias=Tru

In [3]:
# An nn.ModuleList can be indexed or sliced by layer position, like a list.
first_layer = net.linears[0]
for shown in (first_layer, type(first_layer), net.linears[8:15]):
    print(shown)

Linear(in_features=10, out_features=10, bias=True)
<class 'torch.nn.modules.linear.Linear'>
ModuleList(
  (0): Linear(in_features=10, out_features=10, bias=True)
  (1): Linear(in_features=10, out_features=10, bias=True)
  (2): Linear(in_features=10, out_features=10, bias=True)
  (3): Linear(in_features=10, out_features=10, bias=True)
  (4): Linear(in_features=10, out_features=10, bias=True)
  (5): Linear(in_features=10, out_features=10, bias=True)
  (6): Linear(in_features=10, out_features=10, bias=True)
)


In [4]:
# Add a single layer at the end (same idea as list.append).
extra_relu = nn.ReLU()
net.linears.append(extra_relu)
net

ModuleList(
  (linears): ModuleList(
    (0): Linear(in_features=10, out_features=10, bias=True)
    (1): Linear(in_features=10, out_features=10, bias=True)
    (2): Linear(in_features=10, out_features=10, bias=True)
    (3): Linear(in_features=10, out_features=10, bias=True)
    (4): Linear(in_features=10, out_features=10, bias=True)
    (5): Linear(in_features=10, out_features=10, bias=True)
    (6): Linear(in_features=10, out_features=10, bias=True)
    (7): Linear(in_features=10, out_features=10, bias=True)
    (8): Linear(in_features=10, out_features=10, bias=True)
    (9): Linear(in_features=10, out_features=10, bias=True)
    (10): Linear(in_features=10, out_features=10, bias=True)
    (11): Linear(in_features=10, out_features=10, bias=True)
    (12): Linear(in_features=10, out_features=10, bias=True)
    (13): Linear(in_features=10, out_features=10, bias=True)
    (14): Linear(in_features=10, out_features=10, bias=True)
    (15): Linear(in_features=10, out_features=10, bias=Tru

In [5]:
# Add several layers at the end in one call (same idea as list.extend).
extra_linears = [nn.Linear(10, 10) for _ in range(2)]
net.linears.extend(extra_linears)
net

ModuleList(
  (linears): ModuleList(
    (0): Linear(in_features=10, out_features=10, bias=True)
    (1): Linear(in_features=10, out_features=10, bias=True)
    (2): Linear(in_features=10, out_features=10, bias=True)
    (3): Linear(in_features=10, out_features=10, bias=True)
    (4): Linear(in_features=10, out_features=10, bias=True)
    (5): Linear(in_features=10, out_features=10, bias=True)
    (6): Linear(in_features=10, out_features=10, bias=True)
    (7): Linear(in_features=10, out_features=10, bias=True)
    (8): Linear(in_features=10, out_features=10, bias=True)
    (9): Linear(in_features=10, out_features=10, bias=True)
    (10): Linear(in_features=10, out_features=10, bias=True)
    (11): Linear(in_features=10, out_features=10, bias=True)
    (12): Linear(in_features=10, out_features=10, bias=True)
    (13): Linear(in_features=10, out_features=10, bias=True)
    (14): Linear(in_features=10, out_features=10, bias=True)
    (15): Linear(in_features=10, out_features=10, bias=Tru

In [6]:
# Put a layer at a given position (same idea as list.insert).
inserted_relu = nn.ReLU()
net.linears.insert(index=1, module=inserted_relu)
net

ModuleList(
  (linears): ModuleList(
    (0): Linear(in_features=10, out_features=10, bias=True)
    (1): ReLU()
    (2): Linear(in_features=10, out_features=10, bias=True)
    (3): Linear(in_features=10, out_features=10, bias=True)
    (4): Linear(in_features=10, out_features=10, bias=True)
    (5): Linear(in_features=10, out_features=10, bias=True)
    (6): Linear(in_features=10, out_features=10, bias=True)
    (7): Linear(in_features=10, out_features=10, bias=True)
    (8): Linear(in_features=10, out_features=10, bias=True)
    (9): Linear(in_features=10, out_features=10, bias=True)
    (10): Linear(in_features=10, out_features=10, bias=True)
    (11): Linear(in_features=10, out_features=10, bias=True)
    (12): Linear(in_features=10, out_features=10, bias=True)
    (13): Linear(in_features=10, out_features=10, bias=True)
    (14): Linear(in_features=10, out_features=10, bias=True)
    (15): Linear(in_features=10, out_features=10, bias=True)
    (16): Linear(in_features=10, out_feat