# Import Modules

In [1]:
import torch
from torch import nn
from torch.nn import init

  from .autonotebook import tqdm as notebook_tqdm


# Custom Weight and Bias

In [2]:
class Model(nn.Module):
    def __init__(self, init_weight_bias=False):
        super(Model, self).__init__()
        
        self.features = nn.Sequential(
            nn.Linear(2, 4),
            nn.Linear(4, 8),
        )
        
        self.regressor = nn.Sequential(
            nn.Linear(8, 1)
        )
        
        if init_weight_bias:
            self.init_weight_bias()

    def init_weight_bias(self):
        for layer in self.get_submodule("features"):
            layer.weight = nn.parameter.Parameter(
                init.kaiming_normal_(
                    tensor=torch.empty(layer.weight.shape),
                    a=0,
                    mode="fan_in",
                    nonlinearity="leaky_relu"
                )
            )
            
            layer.bias = nn.parameter.Parameter(
                init.kaiming_normal_(
                    tensor=torch.empty(layer.bias.shape[0], 1),
                    a=0,
                    mode="fan_in",
                    nonlinearity="leaky_relu"
                ).flatten()
            )
            
        for layer in self.get_submodule("regressor"):
            layer.weight = nn.parameter.Parameter(
                init.kaiming_normal_(
                    tensor=torch.empty(layer.weight.shape),
                    a=0,
                    mode="fan_in",
                    nonlinearity="leaky_relu"
                )
            )
            
            layer.bias = nn.parameter.Parameter(
                init.kaiming_normal_(
                    tensor=torch.empty(layer.bias.shape[0], 1),
                    a=0,
                    mode="fan_in",
                    nonlinearity="leaky_relu"
                ).flatten()
            )
        
    def forward(self, X):
        X = self.features(X)
        X = self.regressor(X)
        
        return X

Model()

Model(
  (features): Sequential(
    (0): Linear(in_features=2, out_features=4, bias=True)
    (1): Linear(in_features=4, out_features=8, bias=True)
  )
  (regressor): Sequential(
    (0): Linear(in_features=8, out_features=1, bias=True)
  )
)

# Without Init Weight Bias

In [3]:
model_without_init_weight_bias = Model(init_weight_bias=False)
model_without_init_weight_bias.state_dict()

OrderedDict([('features.0.weight',
              tensor([[-0.1543, -0.3934],
                      [ 0.2480, -0.1861],
                      [ 0.6934, -0.4746],
                      [ 0.1439,  0.3509]])),
             ('features.0.bias', tensor([-0.3414,  0.5008, -0.0406, -0.4307])),
             ('features.1.weight',
              tensor([[ 0.0737, -0.3186,  0.3577, -0.0697],
                      [ 0.3686, -0.3904, -0.3349, -0.2428],
                      [-0.0981,  0.2275, -0.3821, -0.1596],
                      [-0.1840,  0.1613,  0.3970, -0.2375],
                      [-0.4849,  0.3978, -0.2229,  0.0538],
                      [-0.3657,  0.1577, -0.4224, -0.4594],
                      [ 0.2261, -0.2111,  0.1444, -0.1431],
                      [ 0.0707, -0.3296, -0.1776,  0.2904]])),
             ('features.1.bias',
              tensor([ 0.0435,  0.4030,  0.1115,  0.3327,  0.1933,  0.0435, -0.0812,  0.4018])),
             ('regressor.0.weight',
              tensor([[-0.1586

In [4]:
X = torch.randint(low=1, high=100, size=(2, ), dtype=torch.float32)
model_without_init_weight_bias(X).item()

-5.4122209548950195

# With Init Weight Bias

In [5]:
model_with_init_weight_bias = Model(init_weight_bias=True)
model_with_init_weight_bias.state_dict()

OrderedDict([('features.0.weight',
              tensor([[-0.8373, -0.6031],
                      [-0.5011, -0.0018],
                      [-0.0609, -0.7094],
                      [ 0.9570,  0.3440]])),
             ('features.0.bias', tensor([-2.6650, -0.5723,  0.1846, -1.7616])),
             ('features.1.weight',
              tensor([[ 0.1436, -0.2205, -0.8100, -1.1376],
                      [ 0.7415, -0.7468, -0.0129,  0.6369],
                      [ 0.9104, -0.0813, -0.1333, -0.2930],
                      [ 0.9118,  0.8248,  0.1434, -0.6823],
                      [ 0.1461, -0.6782, -0.4863,  0.6377],
                      [-0.5486, -1.0457,  0.5480, -0.6921],
                      [-0.3371,  0.3533,  0.3993,  1.5323],
                      [ 0.6777,  0.1470,  0.8740, -1.0514]])),
             ('features.1.bias',
              tensor([-1.0013, -0.5680,  0.2845,  0.9347,  1.5941, -0.7580, -1.0245, -1.1026])),
             ('regressor.0.weight',
              tensor([[-0.7181

In [6]:
X = torch.randint(low=1, high=100, size=(2, ), dtype=torch.float32)
model_with_init_weight_bias(X).item()

-8.561196327209473