In [1]:
import torch
from torch import nn

In [2]:
# Small MLP: 8 hidden units with ReLU, then a single output unit.
# LazyLinear leaves the input width unspecified until data is passed through.
net = nn.Sequential(
    nn.LazyLinear(8),
    nn.ReLU(),
    nn.LazyLinear(1),
)
# Random mini-batch: 2 examples with 4 features each.
X = torch.rand(2, 4)
# One forward pass; display the output shape.
net(X).shape



torch.Size([2, 1])

# Built-in Initialization

In [4]:
def init_normal(module):
    """Initialize a linear layer with small Gaussian weights and zero bias.

    Meant to be passed to ``nn.Module.apply``. Only modules whose exact type
    is ``nn.Linear`` are touched; everything else is ignored.

    Args:
        module: The module to (possibly) initialize in place.
    """
    if type(module) == nn.Linear:
        # Weights drawn from N(0, 0.01^2).
        nn.init.normal_(module.weight, mean=0, std=0.01)
        # Layers built with bias=False have ``bias`` set to None; skip them
        # instead of crashing inside ``zeros_``.
        if module.bias is not None:
            nn.init.zeros_(module.bias)
# Run init_normal over every submodule of the network.
net.apply(init_normal)
# Show the first weight row and first bias entry of the first layer.
first_layer = net[0]
first_layer.weight.data[0], first_layer.bias.data[0]

(tensor([ 0.0050,  0.0142, -0.0123, -0.0108]), tensor(0.))

In [5]:
def init_constant(module):
    """Set every weight of a linear layer to 1 and its bias to 0.

    Meant to be passed to ``nn.Module.apply``. Only modules whose exact type
    is ``nn.Linear`` are touched; everything else is ignored.

    Args:
        module: The module to (possibly) initialize in place.
    """
    if type(module) == nn.Linear:
        nn.init.constant_(module.weight, 1)
        # Layers built with bias=False have ``bias`` set to None; skip them
        # instead of crashing inside ``zeros_``.
        if module.bias is not None:
            nn.init.zeros_(module.bias)

# Run init_constant over every submodule of the network.
net.apply(init_constant)
# Show the first weight row and first bias entry of the first layer.
first_layer = net[0]
first_layer.weight.data[0], first_layer.bias.data[0]

(tensor([1., 1., 1., 1.]), tensor(0.))

In [8]:
def init_xavier(module):
    """Apply Xavier (Glorot) uniform initialization to a linear layer's weights.

    Modules whose exact type is not ``nn.Linear`` are left unchanged, and the
    bias is never modified.

    Args:
        module: The module to (possibly) initialize in place.
    """
    if type(module) is not nn.Linear:
        return
    nn.init.xavier_uniform_(module.weight)
def init_42(module):
    """Fill a linear layer's weight matrix with the constant 42.

    Modules whose exact type is not ``nn.Linear`` are left unchanged, and the
    bias is never modified.

    Args:
        module: The module to (possibly) initialize in place.
    """
    if type(module) == nn.Linear:
        # Use the in-place ``constant_``; the underscore-less ``constant`` is
        # a deprecated alias that emits a warning on every call.
        nn.init.constant_(module.weight, 42)
# Use a different initializer for each linear layer:
# Xavier-uniform for the first one, the constant 42 for the last one.
hidden_layer, output_layer = net[0], net[2]
hidden_layer.apply(init_xavier)
output_layer.apply(init_42)
# Display both weight matrices.
hidden_layer.weight.data, output_layer.weight.data

  nn.init.constant(module.weight,42)


(tensor([[-0.6268,  0.2511,  0.2685,  0.2984],
         [-0.2180, -0.1583, -0.3445,  0.4060],
         [-0.3499,  0.5430,  0.5767,  0.2828],
         [ 0.1821,  0.5244,  0.1052, -0.3641],
         [-0.2921, -0.4281,  0.6907, -0.5320],
         [ 0.0124, -0.1961, -0.5019,  0.6298],
         [ 0.2724, -0.2867,  0.4791,  0.0070],
         [-0.2186, -0.1635, -0.1525, -0.2373]]),
 tensor([[42., 42., 42., 42., 42., 42., 42., 42.]]))

# Custom Initialization

In [9]:
def my_init(module):
    """Custom initializer: wide uniform weights with small values zeroed out.

    For a module whose exact type is ``nn.Linear`` this prints the name and
    shape of its first registered parameter, draws the weights from
    U(-10, 10), and then sets every weight with magnitude below 5 to zero.
    Any other module is ignored.

    Args:
        module: The module to (possibly) initialize in place.
    """
    if type(module) is not nn.Linear:
        return
    # Report only the first (name, shape) pair of the layer's parameters.
    first_name, first_param = list(module.named_parameters())[0]
    print("Init", first_name, first_param.shape)
    nn.init.uniform_(module.weight, -10, 10)
    # Boolean mask: True where the weight survives, False where it is zeroed.
    keep_mask = module.weight.data.abs() >= 5
    module.weight.data *= keep_mask

# Run the custom initializer over every submodule of the network.
net.apply(my_init)
# Display the first two rows of the first layer's weight matrix.
first_layer = net[0]
first_layer.weight[:2]

Init weight torch.Size([8, 4])
Init weight torch.Size([1, 8])


tensor([[ 8.4648,  0.0000, -7.2185,  5.0860],
        [ 5.2136,  0.0000,  6.3120,  0.0000]], grad_fn=<SliceBackward0>)

In [10]:
# Parameters can also be edited directly through their underlying data tensor.
first_weight = net[0].weight.data
first_weight += 1        # shift every entry by one, in place
first_weight[0, 0] = 42  # overwrite a single entry
# Display the first row after the edits.
first_weight[0]

tensor([42.0000,  1.0000, -6.2185,  6.0860])