In [1]:
import torch
from torch import nn

In [2]:
class CenteredLayer(nn.Module):
    """Parameter-free layer that subtracts the global mean from its input.

    The mean is taken over *all* elements of ``x`` (not per-sample or
    per-feature), so the returned tensor has zero mean up to float error.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, x):
        global_mean = x.mean()
        return x.sub(global_mean)

In [3]:
# Centering 1..5 should give -2..2 (the mean is 3).
layer = CenteredLayer()
layer(torch.arange(1, 6, dtype=torch.float))


tensor([-2., -1.,  0.,  1.,  2.])

In [4]:
# An 8 -> 128 linear map whose outputs are then mean-centered.
net = nn.Sequential(
    nn.Linear(8, 128),
    CenteredLayer(),
)


In [5]:
# Feed a random (4, 8) batch; after centering, the overall mean is ~0 up to float error.
y = net(torch.rand(4, 8))
torch.mean(y).item()

9.313225746154785e-10

In [8]:
# Centering does not change the shape produced by the Linear layer.
y.size()


torch.Size([4, 128])

## 含模型参数的自定义层

In [9]:
class MyDense(nn.Module):
    """Dense stack whose weight matrices are held in an ``nn.ParameterList``.

    Three 4x4 matrices followed by one 4x1 matrix are multiplied onto the
    input in order. A ``(batch, 4)`` input therefore yields a ``(batch, 1)`` output.
    """

    def __init__(self):
        super(MyDense, self).__init__()
        self.params = nn.ParameterList([nn.Parameter(torch.randn(4, 4)) for i in range(3)])
        # ParameterList supports list-style append; the new entry is registered too.
        self.params.append(nn.Parameter(torch.randn(4, 1)))

    def forward(self, x):
        """Multiply ``x`` by every weight in ``self.params``, in order."""
        for weight in self.params:
            x = torch.mm(x, weight)
        return x


net = MyDense()
print(net)

MyDense(
  (params): ParameterList(
      (0): Parameter containing: [torch.FloatTensor of size 4x4]
      (1): Parameter containing: [torch.FloatTensor of size 4x4]
      (2): Parameter containing: [torch.FloatTensor of size 4x4]
      (3): Parameter containing: [torch.FloatTensor of size 4x1]
  )
)


In [12]:
# A plain Python list of three 4x4 Parameters; nothing here attaches them to a Module.
[nn.Parameter(torch.randn(4, 4)) for _ in range(3)]

[Parameter containing:
 tensor([[-1.0593, -0.4763, -0.0411, -0.7268],
         [-0.7693, -2.0863, -0.1402, -1.6674],
         [-0.4251,  0.3844,  0.4256, -0.0722],
         [ 1.3343, -1.1075, -0.1018, -0.6938]], requires_grad=True),
 Parameter containing:
 tensor([[ 0.0387, -0.3003,  0.7171,  0.6594],
         [-1.5237,  1.4976, -0.7864, -0.3164],
         [-0.4153,  0.0448, -0.7776, -0.1961],
         [-0.2953,  0.3014, -0.6460,  1.8578]], requires_grad=True),
 Parameter containing:
 tensor([[-0.2585,  0.4178, -1.0071, -1.9221],
         [-1.4201, -0.0967, -1.0294, -1.1975],
         [-1.4365,  0.6457, -2.2471, -0.0823],
         [ 0.7351,  1.0136, -1.1364,  1.6566]], requires_grad=True)]

In [16]:
class MyDictDense(nn.Module):
    """Layer whose weight matrices live in an ``nn.ParameterDict``.

    Keys: ``'linear1'`` (4x4), ``'linear2'`` (4x1), ``'linear3'`` (4x2).
    ``forward`` matrix-multiplies the input by the weight named by ``choice``.
    """

    def __init__(self):
        super().__init__()
        initial_weights = {
            'linear1': nn.Parameter(torch.randn(4, 4)),
            'linear2': nn.Parameter(torch.randn(4, 1)),
        }
        self.params = nn.ParameterDict(initial_weights)
        # A ParameterDict can be extended after construction, like a regular dict.
        self.params.update({'linear3': nn.Parameter(torch.randn(4, 2))})

    def forward(self, x, choice='linear1'):
        selected = self.params[choice]
        return x.mm(selected)


net = MyDictDense()
print(net)


MyDictDense(
  (params): ParameterDict(
      (linear1): Parameter containing: [torch.FloatTensor of size 4x4]
      (linear2): Parameter containing: [torch.FloatTensor of size 4x1]
      (linear3): Parameter containing: [torch.FloatTensor of size 4x2]
  )
)


In [17]:
x = torch.ones(1, 4)
# Route the same all-ones input through each stored weight matrix in turn.
for weight_name in ('linear1', 'linear2', 'linear3'):
    print(net(x, weight_name))


tensor([[-0.0744,  0.3404,  4.2077,  0.0982]], grad_fn=<MmBackward>)
tensor([[-1.7985]], grad_fn=<MmBackward>)
tensor([[ 3.5068, -1.4044]], grad_fn=<MmBackward>)


In [19]:
# Inside Sequential the layer is called as net(x) with a single argument,
# so forward uses its default choice='linear1'.
net = nn.Sequential(MyDictDense())
print(net)
print(net(x))


Sequential(
  (0): MyDictDense(
    (params): ParameterDict(
        (linear1): Parameter containing: [torch.FloatTensor of size 4x4]
        (linear2): Parameter containing: [torch.FloatTensor of size 4x1]
        (linear3): Parameter containing: [torch.FloatTensor of size 4x2]
    )
  )
)
tensor([[-3.0449, -1.8065, -2.0970, -1.3760]], grad_fn=<MmBackward>)
