## 1 Custom Layers

### 1.1 Custom layers without model parameters
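The `CenteredLayer` class below subclasses `Module` to define a layer that subtracts the mean from its input and returns the result. The layer's computation is defined in the `forward` method. This layer contains no model parameters.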

```python
import torch
import torch.nn as nn

class CenteredLayer(nn.Module):
    def __init__(self, **kwargs):
        super(CenteredLayer, self).__init__(**kwargs)

    def forward(self, x):
        # subtract the mean of all elements of x; nothing here is learnable
        return x - x.mean()
```

Test:

```python
layer = CenteredLayer()
# calling the module invokes forward() automatically
print(layer(torch.tensor([1, 2, 3, 4, 5], dtype=torch.float)))
```

```
tensor([-2., -1.,  0.,  1.,  2.])
```
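
Because `CenteredLayer` never creates an `nn.Parameter`, it registers nothing that an optimizer could update, yet its output still carries autograd history from the input. A minimal check, re-declaring the layer so the snippet runs on its own:

```python
import torch
import torch.nn as nn

class CenteredLayer(nn.Module):
    def __init__(self, **kwargs):
        super(CenteredLayer, self).__init__(**kwargs)

    def forward(self, x):
        # subtract the mean of the whole tensor
        return x - x.mean()

layer = CenteredLayer()
# no nn.Parameter was created, so there is nothing to train or save
print(len(list(layer.parameters())))  # 0
print(len(layer.state_dict()))        # 0

# gradients can still flow through the layer to its input
x = torch.tensor([1., 2., 3., 4., 5.], requires_grad=True)
y = layer(x)
print(y.requires_grad)  # True
```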


As part of a more complex model:

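```python
net = nn.Sequential(nn.Linear(8, 128), CenteredLayer())
y = net(torch.rand(4, 8))
print(y.mean().item())
```

```
7.916241884231567e-09
```

The printed mean is not exactly 0 only because of float32 rounding error; mathematically, the mean of the layer's output is 0.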
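
Note that `x.mean()` with no `dim` argument averages over every element of the batch, so `CenteredLayer` centers the batch as a whole rather than each sample separately. The sketch below makes that visible. `PerSampleCenteredLayer` is a hypothetical variant introduced only for comparison; it is not part of the original code.

```python
import torch
import torch.nn as nn

class CenteredLayer(nn.Module):
    def forward(self, x):
        return x - x.mean()  # one scalar mean over the whole tensor

class PerSampleCenteredLayer(nn.Module):  # hypothetical variant, for comparison only
    def forward(self, x):
        return x - x.mean(dim=1, keepdim=True)  # one mean per row (sample)

torch.manual_seed(0)
linear = nn.Linear(8, 128)
x = torch.rand(4, 8)

with torch.no_grad():
    h = linear(x)
    y_global = CenteredLayer()(h)
    y_rows = PerSampleCenteredLayer()(h)

print(y_global.mean().item())  # ~0, up to float32 rounding
print(y_global.mean(dim=1))    # per-row means: generally not 0
print(y_rows.mean(dim=1))      # per-row means: ~0 for every row
```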


### 1.2 Custom layers with model parameters
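When defining a custom layer that has model parameters, declare each parameter as a `torch.nn.Parameter` so that it is registered with the module. In addition, `ParameterList` and `ParameterDict` can be used to define a list or a dictionary of parameters, respectively.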
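
Assigning an `nn.Parameter` as a module attribute is what registers it. A plain tensor attribute, or `Parameter`s kept in an ordinary Python list, are not registered and would be invisible to `parameters()` and to optimizers, which is why `ParameterList` and `ParameterDict` exist. A small sketch; the class `Registration` and its attribute names are hypothetical:

```python
import torch
import torch.nn as nn

class Registration(nn.Module):  # hypothetical class for illustration
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(4, 4))     # registered
        self.plain = torch.randn(4, 4)                    # plain tensor: not registered
        self.py_list = [nn.Parameter(torch.randn(4, 4))]  # Python list: not registered
        self.p_list = nn.ParameterList([nn.Parameter(torch.randn(4, 4))])  # registered

names = [name for name, _ in Registration().named_parameters()]
print(names)  # ['weight', 'p_list.0']
```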

Using `ParameterList`:

```python
class MyDense(nn.Module):
    def __init__(self):
        super(MyDense, self).__init__()
        # three 4x4 weight matrices, registered through ParameterList
        self.params = nn.ParameterList([nn.Parameter(torch.randn(4, 4)) for _ in range(3)])
        # a final 4x1 matrix; append() registers it as well
        self.params.append(nn.Parameter(torch.randn(4, 1)))

    def forward(self, x):
        # chain the matrix products: (n, 4) -> (n, 4) -> (n, 4) -> (n, 4) -> (n, 1)
        for p in self.params:
            x = torch.mm(x, p)
        return x

net = MyDense()
print(net)
y = net(torch.rand(1, 4))
print(y)
```

```
MyDense(
  (params): ParameterList(
      (0): Parameter containing: [torch.FloatTensor of size 4x4]
      (1): Parameter containing: [torch.FloatTensor of size 4x4]
      (2): Parameter containing: [torch.FloatTensor of size 4x4]
      (3): Parameter containing: [torch.FloatTensor of size 4x1]
  )
)
tensor([[8.4681]], grad_fn=<MmBackward>)
```
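
Because `ParameterList` registers each entry, all four matrices appear in `named_parameters()` under index-based names, and the forward pass turns an `(n, 4)` input into an `(n, 1)` output. A quick check, with `MyDense` re-declared so the snippet is self-contained:

```python
import torch
import torch.nn as nn

class MyDense(nn.Module):
    def __init__(self):
        super().__init__()
        self.params = nn.ParameterList([nn.Parameter(torch.randn(4, 4)) for _ in range(3)])
        self.params.append(nn.Parameter(torch.randn(4, 1)))

    def forward(self, x):
        for p in self.params:
            x = torch.mm(x, p)  # (n, 4) @ (4, k) -> (n, k)
        return x

net = MyDense()
for name, p in net.named_parameters():
    print(name, tuple(p.shape))
# params.0 (4, 4)
# params.1 (4, 4)
# params.2 (4, 4)
# params.3 (4, 1)

print(net(torch.rand(3, 4)).shape)  # torch.Size([3, 1])
```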


Using `ParameterDict`:

```python
class MyDictDense(nn.Module):
    def __init__(self):
        super(MyDictDense, self).__init__()
        self.params = nn.ParameterDict({
            'linear1': nn.Parameter(torch.randn(4, 4)),
            'linear2': nn.Parameter(torch.randn(4, 1))
        })
        # entries added later with update() are registered too
        self.params.update({'linear3': nn.Parameter(torch.randn(4, 2))})

    def forward(self, x, choice='linear1'):
        # multiply by the parameter selected by its key
        return torch.mm(x, self.params[choice])

net = MyDictDense()
print(net)
```

```
MyDictDense(
  (params): ParameterDict(
      (linear1): Parameter containing: [torch.FloatTensor of size 4x4]
      (linear2): Parameter containing: [torch.FloatTensor of size 4x1]
      (linear3): Parameter containing: [torch.FloatTensor of size 4x2]
  )
)
```


```python
x = torch.ones(1, 4)
print(net(x, 'linear1'))
print(net(x, 'linear2'))
print(net(x, 'linear3'))
```

```
tensor([[-0.2474,  0.2693, -0.1179, -0.5974]], grad_fn=<MmBackward>)
tensor([[2.6846]], grad_fn=<MmBackward>)
tensor([[-1.2708,  1.7712]], grad_fn=<MmBackward>)
```
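
Since `forward` uses only the parameter selected by `choice`, backpropagating through one call fills in `.grad` for that entry alone; the unused entries keep `.grad` as `None`. A sketch of this, re-declaring `MyDictDense` so it runs on its own:

```python
import torch
import torch.nn as nn

class MyDictDense(nn.Module):
    def __init__(self):
        super().__init__()
        self.params = nn.ParameterDict({
            'linear1': nn.Parameter(torch.randn(4, 4)),
            'linear2': nn.Parameter(torch.randn(4, 1)),
        })
        self.params.update({'linear3': nn.Parameter(torch.randn(4, 2))})

    def forward(self, x, choice='linear1'):
        return torch.mm(x, self.params[choice])

net = MyDictDense()
out = net(torch.ones(1, 4), 'linear2')  # only 'linear2' takes part in this pass
out.sum().backward()

for name, p in net.params.items():
    print(name, p.grad is None)
# linear1 True
# linear2 False
# linear3 True
```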


As part of a more complex model:

```python
net = nn.Sequential(
    MyDictDense(),
    MyDense(),
)
print(net)
print(net(x))
```

```
Sequential(
  (0): MyDictDense(
    (params): ParameterDict(
        (linear1): Parameter containing: [torch.FloatTensor of size 4x4]
        (linear2): Parameter containing: [torch.FloatTensor of size 4x1]
        (linear3): Parameter containing: [torch.FloatTensor of size 4x2]
    )
  )
  (1): MyDense(
    (params): ParameterList(
        (0): Parameter containing: [torch.FloatTensor of size 4x4]
        (1): Parameter containing: [torch.FloatTensor of size 4x4]
        (2): Parameter containing: [torch.FloatTensor of size 4x4]
        (3): Parameter containing: [torch.FloatTensor of size 4x1]
    )
  )
)
tensor([[7.3381]], grad_fn=<MmBackward>)
```
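
Inside `nn.Sequential`, `MyDictDense` runs with its default `choice='linear1'`, so a `(1, 4)` input stays `(1, 4)` and `MyDense` then maps it to `(1, 1)`. Every parameter of both sub-layers is reachable from the outer container, which is what lets an optimizer train the whole model. A sketch that counts them, with both classes repeated so it runs on its own:

```python
import torch
import torch.nn as nn

class MyDense(nn.Module):
    def __init__(self):
        super().__init__()
        self.params = nn.ParameterList([nn.Parameter(torch.randn(4, 4)) for _ in range(3)])
        self.params.append(nn.Parameter(torch.randn(4, 1)))

    def forward(self, x):
        for p in self.params:
            x = torch.mm(x, p)
        return x

class MyDictDense(nn.Module):
    def __init__(self):
        super().__init__()
        self.params = nn.ParameterDict({
            'linear1': nn.Parameter(torch.randn(4, 4)),
            'linear2': nn.Parameter(torch.randn(4, 1)),
        })
        self.params.update({'linear3': nn.Parameter(torch.randn(4, 2))})

    def forward(self, x, choice='linear1'):
        return torch.mm(x, self.params[choice])

net = nn.Sequential(MyDictDense(), MyDense())

# names are prefixed with the submodule index, e.g. '0.params.linear1', '1.params.0'
params = list(net.named_parameters())
print(len(params))                        # 7
print(sum(p.numel() for _, p in params))  # 80 = (16 + 4 + 8) + (3 * 16 + 4)
print(net(torch.ones(1, 4)).shape)        # torch.Size([1, 1])

# every registered parameter can be handed to an optimizer in a single call
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
```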
