In [1]:
import torch
import torch.nn as nn

In [10]:
# Two-layer MLP built from lazy layers: only the output widths (8 and 1) are
# given here; the input widths are filled in on the first forward pass.
net = nn.Sequential(
    nn.LazyLinear(8),
    nn.ReLU(),
    nn.LazyLinear(1),
)
# Random mini-batch of 2 examples with 4 features each.
X = torch.rand(2, 4)
# Running the batch through the network materializes the lazy layers.
net(X)



tensor([[-0.3195],
        [-0.2268]], grad_fn=<AddmmBackward0>)

In [11]:
# Built-in initialization

def init_normal(module):
    """Initialize a linear layer with small Gaussian weights and a zero bias.

    Intended to be passed to ``nn.Module.apply``. Weights are drawn from a
    normal distribution with mean 0 and standard deviation 0.01; the bias,
    when the layer has one, is set to zero.

    Args:
        module: Any ``nn.Module``. Only modules whose type is exactly
            ``nn.Linear`` are modified (subclasses are skipped); everything
            else is left untouched.
    """
    if type(module) is nn.Linear:
        nn.init.normal_(module.weight, mean=0, std=0.01)
        # A layer created with bias=False stores None as its bias, and
        # nn.init.zeros_(None) would raise, so only zero a bias that exists.
        if module.bias is not None:
            nn.init.zeros_(module.bias)

# Module.apply calls init_normal on every submodule of `net` (and on `net`
# itself), so both Linear layers are re-initialized in place.
net.apply(init_normal)
# Show the first row of the first layer's weight matrix and its first bias entry.
net[0].weight.data[0], net[0].bias.data[0]

(tensor([ 0.0107, -0.0029, -0.0049,  0.0065]), tensor(0.))

In [12]:
# We can also initialize all the weights to a given constant value (biases are set to zero)
def init_constant(module):
    """Set every weight of a linear layer to 1 and its bias to zero.

    Intended to be passed to ``nn.Module.apply``.

    Args:
        module: Any ``nn.Module``. Only modules whose type is exactly
            ``nn.Linear`` are modified (subclasses are skipped); everything
            else is left untouched.
    """
    if type(module) is nn.Linear:
        nn.init.constant_(module.weight, 1)
        # A layer created with bias=False stores None as its bias, and
        # nn.init.zeros_(None) would raise, so only zero a bias that exists.
        if module.bias is not None:
            nn.init.zeros_(module.bias)

# Re-initialize every Linear layer with init_constant. apply() returns the
# module it was called on, so the cell output is the network's repr.
net.apply(init_constant)

Sequential(
  (0): Linear(in_features=4, out_features=8, bias=True)
  (1): ReLU()
  (2): Linear(in_features=8, out_features=1, bias=True)
)

In [13]:
# Applying different initializers for certain blocks

def init_xavier(module):
    """Re-draw a linear layer's weights with Xavier (Glorot) uniform init.

    Only the weight matrix is touched; the bias is left as it was.

    Args:
        module: Any ``nn.Module``. Modules whose type is not exactly
            ``nn.Linear`` are ignored.
    """
    if type(module) is not nn.Linear:
        return
    nn.init.xavier_uniform_(module.weight)

def init_2(module):
    """Fill a linear layer's weight matrix with the constant 2.

    Only the weight matrix is touched; the bias is left as it was.

    Args:
        module: Any ``nn.Module``. Modules whose type is not exactly
            ``nn.Linear`` are ignored.
    """
    if type(module) is not nn.Linear:
        return
    nn.init.constant_(module.weight, 2)

# Calling apply() on an individual layer restricts the initializer to that
# layer: the first Linear gets Xavier-uniform weights, the last gets all 2s.
net[0].apply(init_xavier)
net[2].apply(init_2)
# First weight row of the first layer, then the full weight matrix of the last layer.
print(net[0].weight.data[0])
print(net[2].weight.data)

tensor([-0.6802,  0.0539, -0.1071, -0.0554])
tensor([[2., 2., 2., 2., 2., 2., 2., 2.]])


In [17]:
# Custom initialization

def my_init(module):
    """Custom initializer: uniform weights with small magnitudes zeroed out.

    For a plain ``nn.Linear`` layer this prints the name and shape of the
    layer's first parameter, draws the weights uniformly from [-10, 10), and
    then sets every weight whose absolute value is below 5 to zero.

    Args:
        module: Any ``nn.Module``. Modules whose type is not exactly
            ``nn.Linear`` are ignored.
    """
    if type(module) is not nn.Linear:
        return

    # Report the first registered parameter of the layer.
    first_name, first_param = next(module.named_parameters())
    print("Init", first_name, first_param.shape)

    nn.init.uniform_(module.weight, -10, 10)

    # Boolean mask: True where the magnitude is at least 5. Multiplying by it
    # keeps those entries and zeroes the rest.
    keep_mask = module.weight.data.abs() >= 5
    module.weight.data *= keep_mask

# Run the custom initializer over every submodule; it prints one "Init" line
# per Linear layer.
net.apply(my_init)
# Display the first two weight rows of the first layer. This indexes the
# parameter itself (not .data), so the displayed tensor carries a grad_fn.
net[0].weight[:2]

Init weight torch.Size([8, 4])
Init weight torch.Size([1, 8])


tensor([[-7.4097, -5.7399,  8.6192, -6.2231],
        [-0.0000, -0.0000, -0.0000,  5.9861]], grad_fn=<SliceBackward0>)

In [18]:
# Note: We always have the option of setting parameters directly

# Edit the first layer's weights in place through .data:
# add 1 to every entry ...
net[0].weight.data[:] += 1
# ... then overwrite the single entry at row 0, column 0.
net[0].weight.data[0, 0] = 42
# Display the first weight row to confirm both edits.
net[0].weight.data[0]

tensor([42.0000, -4.7399,  9.6192, -5.2231])