# Import Modules

In [1]:
import torch
from torch import nn
from torch.nn import init

# Custom Weight and Bias

In [2]:
class Model(nn.Module):
    """Small fully connected regressor with optional Kaiming-normal init.

    ``features`` is ``Linear(2, 4) -> Linear(4, 8)`` and ``regressor`` is
    ``Linear(8, 1)``; no activation is placed between the layers.

    Args:
        init_weight_bias: If true, the weights and biases of both
            sub-networks are re-drawn by :meth:`init_weight_bias` right after
            construction. If false, ``nn.Linear``'s default values are kept.
    """

    def __init__(self, init_weight_bias=False):
        super().__init__()

        self.features = nn.Sequential(
            nn.Linear(2, 4),
            nn.Linear(4, 8),
        )

        self.regressor = nn.Sequential(
            nn.Linear(8, 1)
        )

        # The argument shadows the method name locally; ``self.`` still
        # resolves to the method.
        if init_weight_bias:
            self.init_weight_bias("features")
            self.init_weight_bias("regressor")

    @staticmethod
    def _kaiming_normal_like(shape, reference):
        """Return a Kaiming-normal tensor of ``shape`` on ``reference``'s dtype/device."""
        return init.kaiming_normal_(
            tensor=torch.empty(
                shape, dtype=reference.dtype, device=reference.device
            ),
            a=0,
            mode="fan_in",
            nonlinearity="leaky_relu",
        )

    def init_weight_bias(self, layer_name):
        """Replace the parameters of every layer in a submodule with Kaiming-normal values.

        Args:
            layer_name: Name of an iterable submodule of this model (here
                ``"features"`` or ``"regressor"``), as accepted by
                ``nn.Module.get_submodule``.

        Layers that have no ``weight`` and/or no ``bias`` (for example an
        activation, or ``nn.Linear(..., bias=False)``) keep that parameter
        untouched instead of raising. New parameters are created with the
        dtype and device of the ones they replace, so calling this after
        ``.double()`` or ``.to(device)`` does not reset them to float32/CPU.
        """
        for layer in self.get_submodule(layer_name):
            weight = getattr(layer, "weight", None)
            bias = getattr(layer, "bias", None)

            # Weight is drawn before bias so the RNG consumption order is
            # stable for a given seed.
            if weight is not None:
                layer.weight = nn.Parameter(
                    self._kaiming_normal_like(weight.shape, weight)
                )

            if bias is not None:
                # kaiming_normal_ needs at least 2 dimensions, so the bias is
                # drawn as an (n, 1) column (fan_in == 1) and flattened back.
                layer.bias = nn.Parameter(
                    self._kaiming_normal_like(
                        (bias.shape[0], 1), bias
                    ).flatten()
                )

    def forward(self, X):
        """Apply ``features`` and then ``regressor`` to ``X`` and return the result."""
        return self.regressor(self.features(X))

# Build a model with the default arguments; as the cell's last expression,
# its module structure is what gets displayed.
Model()

Model(
  (features): Sequential(
    (0): Linear(in_features=2, out_features=4, bias=True)
    (1): Linear(in_features=4, out_features=8, bias=True)
  )
  (regressor): Sequential(
    (0): Linear(in_features=8, out_features=1, bias=True)
  )
)

# Without Custom Weight and Bias Initialization

In [3]:
# Seed the RNG before construction so the parameter values are reproducible.
torch.manual_seed(42)

# Baseline: the custom re-initialization is switched off.
model_without_init_weight_bias = Model(init_weight_bias=False)
# Last expression of the cell: shows every parameter tensor by name.
model_without_init_weight_bias.state_dict()

OrderedDict([('features.0.weight',
              tensor([[ 0.5406,  0.5869],
                      [-0.1657,  0.6496],
                      [-0.1549,  0.1427],
                      [-0.3443,  0.4153]])),
             ('features.0.bias', tensor([ 0.6233, -0.5188,  0.6146,  0.1323])),
             ('features.1.weight',
              tensor([[ 0.3694,  0.0677,  0.2411, -0.0706],
                      [ 0.3854,  0.0739, -0.2334,  0.1274],
                      [-0.2304, -0.0586, -0.2031,  0.3317],
                      [-0.3947, -0.2305, -0.1412, -0.3006],
                      [ 0.0472, -0.4938,  0.4516, -0.4247],
                      [ 0.3860,  0.0832, -0.1624,  0.3090],
                      [ 0.0779,  0.4040,  0.0547, -0.1577],
                      [ 0.1343, -0.1356,  0.2104,  0.4464]])),
             ('features.1.bias',
              tensor([ 0.2890, -0.2186,  0.2886,  0.0895,  0.2539, -0.3048, -0.4950, -0.1932])),
             ('regressor.0.weight',
              tensor([[-0.2712

In [4]:
# Re-seed so the random input below is reproducible.
torch.manual_seed(42)

# One sample with 2 features: values drawn from [1, 100) and stored as float32.
X = torch.randint(low=1, high=100, size=(2, ), dtype=torch.float32)
# Forward pass; .item() converts the single-element output to a Python float.
model_without_init_weight_bias(X).item()

-9.79812240600586

# With Custom Weight and Bias Initialization

In [5]:
# Seed the RNG before construction so the parameter values are reproducible.
torch.manual_seed(42)

# Same architecture, but with the custom re-initialization switched on.
model_with_init_weight_bias = Model(init_weight_bias=True)
# Last expression of the cell: shows every parameter tensor by name.
model_with_init_weight_bias.state_dict()

OrderedDict([('features.0.weight',
              tensor([[-0.1735,  1.3850],
                      [ 0.7045,  1.2197],
                      [-0.6778, -0.5920],
                      [-0.6382, -1.9187]])),
             ('features.0.bias', tensor([-0.9109, -0.8571, -0.2015,  1.3755])),
             ('features.1.weight',
              tensor([[ 3.7096e-01,  8.0695e-01,  3.6518e-02,  5.1485e-01],
                      [-5.0250e-01, -4.2573e-01,  6.7914e-01,  2.8625e-01],
                      [-5.8532e-01,  9.4378e-01,  3.4191e-01, -1.3970e-01],
                      [ 8.9683e-01,  8.6568e-01,  6.9379e-02,  1.2320e+00],
                      [ 4.1242e-01,  7.5443e-01, -3.1831e-01, -4.7995e-01],
                      [ 4.0610e-01,  1.3276e-01, -2.5288e-01, -2.2381e-01],
                      [ 9.0027e-01, -9.2692e-04, -2.1468e-01, -6.9752e-01],
                      [ 8.7186e-02,  2.4739e-01,  4.3648e-01, -1.1974e-01]])),
             ('features.1.bias',
              tensor([-1.9137,  0.2

In [6]:
# Re-seed so the random input below is reproducible.
torch.manual_seed(42)

# One sample with 2 features: values drawn from [1, 100) and stored as float32.
X = torch.randint(low=1, high=100, size=(2, ), dtype=torch.float32)
# Forward pass; .item() converts the single-element output to a Python float.
model_with_init_weight_bias(X).item()

137.23822021484375