In [1]:
import torch
from torch import nn

class CenteredLayer(nn.Module):
    """Parameter-free layer that shifts its input so it is centred on zero.

    The mean is taken over *every* element of ``X`` (no ``dim`` argument),
    so the whole tensor is shifted by a single scalar.  The layer registers
    no parameters; any keyword arguments are forwarded to ``nn.Module``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, X):
        global_mean = X.mean()  # scalar mean over all elements
        return X - global_mean

In [3]:
# Float input 1..5 (mean 3.0) so the centring effect is easy to read.
X = torch.arange(1, 6, dtype=torch.float)
layer = CenteredLayer()

In [4]:
# Display the centred tensor: every element of X minus the mean of X.
layer(X)

tensor([-2., -1.,  0.,  1.,  2.])

In [5]:
# Use the custom layer as the last stage of a model: 8 -> 128 linear map,
# then subtract the global mean of the 128-wide activations.
net = nn.Sequential(
    nn.Linear(8, 128),
    CenteredLayer(),
)

In [6]:
X = torch.rand(4, 8)
y = net(X)
# CenteredLayer subtracts the global mean, so the output mean should be zero
# up to float32 rounding (the recorded value is ~1e-9, not exactly 0).
# Check the invariant instead of relying on eyeballing the printed number.
assert abs(y.mean().item()) < 1e-5, "CenteredLayer output is not centred"
y.mean().item()

5.587935447692871e-09

In [9]:
class MyDense(nn.Module):
    """Chain of matrix products whose weights are stored in an nn.ParameterList.

    Three 4x4 weights are followed by one 4x1 weight, so ``forward`` maps an
    input with 4 columns (e.g. shape ``(batch, 4)``) to shape ``(batch, 1)``.
    Using ``nn.ParameterList`` (rather than a plain Python list) is what
    registers the tensors as module parameters.
    """

    def __init__(self):
        super().__init__()
        self.params = nn.ParameterList(
            [nn.Parameter(torch.randn(4, 4)) for _ in range(3)]
        )
        # ParameterList supports list-style append for adding parameters later.
        self.params.append(nn.Parameter(torch.randn(4, 1)))

    def forward(self, X):
        # Iterate the ParameterList directly instead of indexing via range(len(...)).
        for weight in self.params:
            X = torch.mm(X, weight)
        return X

# Instantiate the list-backed layer; printing shows the four parameters
# registered in its ParameterList (three 4x4 and one 4x1).
net = MyDense()
print(net)

MyDense(
  (params): ParameterList(
      (0): Parameter containing: [torch.FloatTensor of size 4x4]
      (1): Parameter containing: [torch.FloatTensor of size 4x4]
      (2): Parameter containing: [torch.FloatTensor of size 4x4]
      (3): Parameter containing: [torch.FloatTensor of size 4x1]
  )
)


In [10]:
class MyDictDense(nn.Module):
    """Layer holding three named weight matrices in an nn.ParameterDict.

    ``forward`` multiplies the input by the single weight named by ``choice``
    ('linear1' -> 4 columns, 'linear2' -> 1 column, 'linear3' -> 2 columns),
    so one module can produce different output widths.
    """

    def __init__(self):
        super().__init__()
        initial_weights = {
            'linear1': nn.Parameter(torch.randn(4, 4)),
            'linear2': nn.Parameter(torch.randn(4, 1)),
        }
        self.params = nn.ParameterDict(initial_weights)
        # ParameterDict supports dict-style update() to register more entries.
        extra_weights = {'linear3': nn.Parameter(torch.randn(4, 2))}
        self.params.update(extra_weights)

    def forward(self, X, choice='linear1'):
        selected = self.params[choice]
        return torch.mm(X, selected)

# Rebind `net` to the dict-backed layer; printing lists the three named
# weights registered in its ParameterDict.
net = MyDictDense()
print(net)

MyDictDense(
  (params): ParameterDict(
      (linear1): Parameter containing: [torch.FloatTensor of size 4x4]
      (linear2): Parameter containing: [torch.FloatTensor of size 4x1]
      (linear3): Parameter containing: [torch.FloatTensor of size 4x2]
  )
)


In [11]:
X = torch.ones(1, 4)
# Run the same input through each named weight in turn.
for choice in ('linear1', 'linear2', 'linear3'):
    print(net(X, choice))

tensor([[-0.3173,  0.2780, -0.0839, -3.5765]], grad_fn=<MmBackward>)
tensor([[4.2340]], grad_fn=<MmBackward>)
tensor([[ 2.4874, -0.0837]], grad_fn=<MmBackward>)


In [13]:
# Compose the two custom layers: MyDictDense uses its default 'linear1'
# weight (4 -> 4), then MyDense maps 4 columns down to 1.
net = nn.Sequential(
    MyDictDense(),
    MyDense(),
)
print(net)
print(net(X))  # X is the torch.ones(1, 4) input defined in the previous cell

Sequential(
  (0): MyDictDense(
    (params): ParameterDict(
        (linear1): Parameter containing: [torch.FloatTensor of size 4x4]
        (linear2): Parameter containing: [torch.FloatTensor of size 4x1]
        (linear3): Parameter containing: [torch.FloatTensor of size 4x2]
    )
  )
  (1): MyDense(
    (params): ParameterList(
        (0): Parameter containing: [torch.FloatTensor of size 4x4]
        (1): Parameter containing: [torch.FloatTensor of size 4x4]
        (2): Parameter containing: [torch.FloatTensor of size 4x4]
        (3): Parameter containing: [torch.FloatTensor of size 4x1]
    )
  )
)
tensor([[-39.0587]], grad_fn=<MmBackward>)
