In [10]:
import torch
import torch.nn.functional as F
from torch import nn

In [11]:
class CenteredLayer(nn.Module):
    """Parameter-free layer that shifts its input so the result has zero mean."""

    def forward(self, X):
        """Return ``X`` with the mean over all of its elements subtracted."""
        centre = torch.mean(X)  # scalar mean across every element of X
        return torch.sub(X, centre)

In [12]:
# Sanity check: the values 1..5 have mean 3, so centering yields -2..2.
layer = CenteredLayer()
layer(torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32))

tensor([-2., -1.,  0.,  1.,  2.])

In [14]:
# A custom layer slots into nn.Sequential exactly like a built-in one:
# linear projection 8 -> 128 followed by the centering layer.
net = nn.Sequential(
    nn.Linear(8, 128),
    CenteredLayer(),
)

In [15]:
# Feed a random (4, 8) batch through the network. The last layer subtracts
# the global mean, so the mean of Y should be zero up to float rounding.
Y = net(torch.rand((4, 8)))
Y, torch.mean(Y)

(tensor([[ 1.0985e-01,  4.9189e-02,  1.6070e-01,  2.1082e-01, -7.4241e-02,
          -5.6759e-02,  4.0362e-01,  2.5117e-01, -2.3495e-01,  4.8191e-01,
          -7.6575e-01,  9.9875e-02,  3.3878e-01, -6.4911e-01, -4.1621e-01,
           4.0028e-01, -6.2980e-01,  4.6648e-01,  4.6622e-01, -5.8948e-02,
           4.3449e-01, -5.0905e-02,  4.5123e-03, -1.0727e-03, -3.8042e-01,
          -1.3612e-01,  3.8496e-01, -1.6689e-01, -1.7190e-01, -8.9879e-03,
           3.1670e-01,  3.9238e-02,  2.1080e-01, -1.4880e-02, -5.6392e-01,
          -3.6001e-02,  5.0327e-01, -1.0507e-01,  7.3186e-01, -1.9494e-01,
           2.1365e-01,  3.8387e-01, -1.1999e-01, -1.9439e-01, -4.6684e-01,
          -1.5380e-01, -6.5222e-02, -2.6561e-01, -7.8892e-01, -4.8507e-01,
          -7.4711e-01, -6.7983e-01, -1.1224e-01, -1.3697e-01, -6.7161e-01,
          -4.2997e-02, -5.8012e-01,  2.8306e-01,  4.8408e-01, -1.1037e+00,
          -6.9450e-01,  2.9168e-01,  2.7907e-01,  4.5428e-01,  2.1485e-01,
          -8.4338e-02,  1

In [16]:
class MyLinear(nn.Module):
    """Fully connected layer with a built-in ReLU activation.

    Computes ``relu(X @ weight + bias)``.

    Args:
        in_units: Size of the last dimension of the input.
        units: Number of output features.
    """

    def __init__(self, in_units, units):
        super().__init__()
        # Both parameters are initialised from standard-normal samples.
        self.weight = nn.Parameter(torch.randn(in_units, units))
        self.bias = nn.Parameter(torch.randn(units, ))

    def forward(self, X):
        """Return the ReLU-activated affine transform of ``X``."""
        # Use the parameters themselves rather than ``.data``: ``.data``
        # bypasses autograd, so no gradient would ever reach weight or bias
        # and the layer could not be trained.
        linear = torch.matmul(X, self.weight) + self.bias
        return F.relu(linear)

In [17]:
# Build the custom layer and inspect its randomly initialised weight,
# which has shape (in_units, units) = (5, 3).
linear = MyLinear(in_units=5, units=3)
linear.weight

Parameter containing:
tensor([[-0.6454, -0.0870, -2.2103],
        [-0.8660, -1.8385,  1.0740],
        [ 1.4739, -0.4810, -1.1145],
        [ 1.0355, -0.8969,  0.9906],
        [-1.1258, -0.1473, -1.0108]], requires_grad=True)

In [18]:
# Forward pass on a random (2, 5) batch. The layer ends in a ReLU, so any
# negative pre-activation shows up as 0 in the output.
linear(torch.rand((2, 5)))

tensor([[0., 0., 0.],
        [0., 0., 0.]])

In [19]:
# Custom layers compose with each other as well: 64 -> 8 -> 1.
# Each MyLinear ends in a ReLU, so the final outputs are non-negative.
net = nn.Sequential(
    MyLinear(64, 8),
    MyLinear(8, 1),
)
net(torch.rand((2, 64)))

tensor([[9.7259],
        [0.0000]])