In [1]:
import torch
from torch import nn
from torch.nn import functional as F

In [4]:
class CenteredLayer(nn.Module):
    """Parameter-free layer that subtracts the input's overall mean.

    The mean is taken over every element of the input (no ``dim`` is
    passed), so one scalar is subtracted from the whole tensor.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, X):
        """Return ``X`` shifted by the mean of all of its elements."""
        overall_mean = X.mean()
        centered = X - overall_mean
        return centered

In [5]:
# Sanity check: the values 1..5 have mean 3, so the result should be -2..2.
layer = CenteredLayer()
layer(torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32))

tensor([-2., -1.,  0.,  1.,  2.])

In [6]:
# Chain a built-in linear layer with the custom centering layer.
net = nn.Sequential(
    nn.Linear(in_features=8, out_features=128),
    CenteredLayer(),
)

In [7]:
# Feed a random 4x8 batch through the network and display the mean of the
# output, which should be zero up to floating-point rounding.
Y = net(torch.rand((4, 8)))
torch.mean(Y)

tensor(-3.7253e-09, grad_fn=<MeanBackward0>)

In [8]:
class MyLinear(nn.Module):
    """Fully connected layer followed by a ReLU activation.

    Args:
        in_units: Size of the last dimension of the input.
        units: Number of output units.

    The layer owns two learnable parameters, both initialised from a
    standard normal distribution: ``weight`` with shape
    ``(in_units, units)`` and ``bias`` with shape ``(units,)``.
    """

    def __init__(self, in_units, units) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.randn(in_units, units))
        self.bias = nn.Parameter(torch.randn(units,))

    def forward(self, X):
        """Return ``relu(X @ weight + bias)``."""
        # Use the parameters themselves, not ``.data``: ``.data`` yields
        # tensors that autograd does not track, so the output would be
        # disconnected from ``weight``/``bias`` and they could never
        # receive gradients during training.
        linear = torch.matmul(X, self.weight) + self.bias
        return F.relu(linear)

In [9]:
# Instantiate the custom layer and display its weight parameter.
dense = MyLinear(in_units=5, units=3)
dense.weight

Parameter containing:
tensor([[ 1.8256,  1.9848,  0.8014],
        [ 0.2299, -0.4476, -0.2398],
        [ 0.2305, -0.5860, -0.3505],
        [ 0.3462,  0.5602,  0.7189],
        [-1.3748,  0.0814,  0.1008]], requires_grad=True)

In [11]:
# Stack two custom layers in a Sequential container and list the
# parameters it registers.
net = nn.Sequential(
    MyLinear(in_units=64, units=8),
    MyLinear(in_units=8, units=1),
)
net.state_dict()

OrderedDict([('0.weight',
              tensor([[ 3.6528e-01, -2.2441e-01, -2.4445e-01, -1.0680e+00,  6.0884e-01,
                       -4.5474e-01, -1.3496e+00,  1.6570e+00],
                      [-1.5456e+00,  8.8172e-01, -9.5126e-01,  7.2025e-01,  6.0231e-01,
                       -6.3199e-01, -2.7070e-01, -4.8182e-02],
                      [ 6.8575e-02, -7.8239e-01, -7.5946e-01,  1.0078e+00,  9.7391e-03,
                       -3.6915e-02,  1.3527e+00, -4.2855e-01],
                      [-1.1152e+00,  9.8074e-01, -2.4605e-01,  3.2954e-01, -1.4208e+00,
                       -9.7146e-01,  6.3970e-01, -3.5377e-01],
                      [ 2.1305e+00,  1.3395e+00,  2.0693e-02, -2.3210e-01, -1.1961e+00,
                        1.0886e+00, -4.9638e-01,  3.0510e-01],
                      [-1.9477e+00, -1.3312e+00, -1.5776e+00,  1.2780e+00,  3.6000e-02,
                       -9.2632e-02,  1.5674e-01, -1.6363e+00],
                      [-6.0585e-01, -3.2625e-01,  5.6412e-01,  3.757