# Import Modules

In [1]:
import torch
from torch import nn
from torch.nn import init

  from .autonotebook import tqdm as notebook_tqdm


# Custom Weight and Bias

In [2]:
class Model(nn.Module):
    def __init__(self, init_weight_bias=False):
        super(Model, self).__init__()
        
        self.features = nn.Sequential(
            nn.Linear(2, 4),
            nn.Linear(4, 8),
        )
        
        self.regressor = nn.Sequential(
            nn.Linear(8, 1)
        )
        
        if init_weight_bias:
            self.init_weight_bias()

    def init_weight_bias(self):
        for layer in self.get_submodule("features"):
            layer.weight = nn.parameter.Parameter(
                init.kaiming_normal_(
                    tensor=torch.empty(layer.weight.shape),
                    a=0,
                    mode="fan_in",
                    nonlinearity="leaky_relu"
                )
            )
            
            layer.bias = nn.parameter.Parameter(
                init.kaiming_normal_(
                    tensor=torch.empty(layer.bias.shape[0], 1),
                    a=0,
                    mode="fan_in",
                    nonlinearity="leaky_relu"
                ).flatten()
            )
            
        for layer in self.get_submodule("regressor"):
            layer.weight = nn.parameter.Parameter(
                init.kaiming_normal_(
                    tensor=torch.empty(layer.weight.shape),
                    a=0,
                    mode="fan_in",
                    nonlinearity="leaky_relu"
                )
            )
            
            layer.bias = nn.parameter.Parameter(
                init.kaiming_normal_(
                    tensor=torch.empty(layer.bias.shape[0], 1),
                    a=0,
                    mode="fan_in",
                    nonlinearity="leaky_relu"
                ).flatten()
            )
        
    def forward(self, X):
        X = self.features(X)
        X = self.regressor(X)
        
        return X

Model()

Model(
  (features): Sequential(
    (0): Linear(in_features=2, out_features=4, bias=True)
    (1): Linear(in_features=4, out_features=8, bias=True)
  )
  (regressor): Sequential(
    (0): Linear(in_features=8, out_features=1, bias=True)
  )
)

# Without Init Weight Bias

In [3]:
model_without_init_weight_bias = Model(init_weight_bias=False)
model_without_init_weight_bias.state_dict()

OrderedDict([('features.0.weight',
              tensor([[-0.1690, -0.5772],
                      [ 0.2068, -0.4578],
                      [-0.6095,  0.2628],
                      [ 0.5465,  0.2796]])),
             ('features.0.bias', tensor([-0.5703, -0.2293, -0.1294,  0.4650])),
             ('features.1.weight',
              tensor([[ 0.1644, -0.2435, -0.3900,  0.0465],
                      [-0.3416, -0.1947,  0.0654,  0.4741],
                      [-0.3091,  0.1285, -0.2488,  0.3686],
                      [ 0.3393, -0.0681,  0.0978,  0.2410],
                      [ 0.0265, -0.2275, -0.0651,  0.3685],
                      [-0.2107, -0.4869, -0.1816,  0.3392],
                      [ 0.3699,  0.2367, -0.0845, -0.3352],
                      [ 0.1770,  0.4642,  0.2331, -0.3586]])),
             ('features.1.bias',
              tensor([-0.3951,  0.3400,  0.4565, -0.3223,  0.4544,  0.1985,  0.3268, -0.3842])),
             ('regressor.0.weight',
              tensor([[-0.3157

In [4]:
X = torch.randint(low=1, high=100, size=(2, ), dtype=torch.float32)
model_without_init_weight_bias(X).item()

2.8935952186584473

# With Init Weight Bias

In [5]:
model_with_init_weight_bias = Model(init_weight_bias=True)
model_with_init_weight_bias.state_dict()

TypeError: empty() received an invalid combination of arguments - got (torch.Size, int), but expected one of:
 * (tuple of ints size, *, tuple of names names, torch.memory_format memory_format, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
 * (tuple of ints size, *, torch.memory_format memory_format, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)


In [None]:
X = torch.randint(low=1, high=100, size=(2, ), dtype=torch.float32)
model_with_init_weight_bias(X).item()