In [1]:
import torch
from torch import nn

In [2]:
net = nn.Sequential(nn.LazyLinear(8), nn.ReLU(), nn.LazyLinear(1))
X = torch.rand(size=(2, 4))
net(X).shape



torch.Size([2, 1])

# 6.3.1. Built-in Initialization

In [3]:
def init_normal(module):
    if type(module) == nn.Linear:
        nn.init.normal_(module.weight, mean=0, std=0.01)
        nn.init.zeros_(module.bias)

net.apply(init_normal)
net[0].weight.data[0], net[0].bias.data[0]

(tensor([-0.0120,  0.0071,  0.0049, -0.0040]), tensor(0.))

In [4]:
def init_constant(module):
    if type(module) == nn.Linear:
        nn.init.constant_(module.weight, 1)
        nn.init.zeros_(module.bias)

net.apply(init_constant)
net[0].weight.data[0], net[0].bias.data[0]

(tensor([1., 1., 1., 1.]), tensor(0.))

In [5]:
def init_xavier(module):
    if type(module) == nn.Linear:
        nn.init.xavier_uniform_(module.weight)

def init_42(module):
    if type(module) == nn.Linear:
        nn.init.constant_(module.weight, 42)

net[0].apply(init_xavier)
net[2].apply(init_42)
print(net[0].weight.data[0])
print(net[2].weight.data[0])

tensor([ 0.1828,  0.1586, -0.4040, -0.4338])
tensor([42., 42., 42., 42., 42., 42., 42., 42.])


# 6.3.1.1. Custom Initialization

In [6]:
def my_init(module):
    if type(module) == nn.Linear:
        print("Init", *[(name, param) for name, param in module.named_parameters()][0])
        nn.init.uniform_(module.weight, -10, 10)
        module.weight.data *= module.weight.data.abs() >= 5


net.apply(my_init)
net[0].weight[:2]

Init weight Parameter containing:
tensor([[ 0.1828,  0.1586, -0.4040, -0.4338],
        [ 0.6289, -0.2841,  0.0749,  0.1344],
        [-0.2147, -0.1357, -0.4925,  0.3307],
        [ 0.3571, -0.4535,  0.0080,  0.2013],
        [ 0.0601, -0.5887,  0.0063,  0.1932],
        [ 0.2966,  0.3536, -0.1678, -0.5390],
        [ 0.2883, -0.6859, -0.5471, -0.3631],
        [ 0.6301,  0.0031, -0.4234,  0.2842]], requires_grad=True)
Init weight Parameter containing:
tensor([[42., 42., 42., 42., 42., 42., 42., 42.]], requires_grad=True)


tensor([[ 5.6342,  9.4741, -0.0000,  0.0000],
        [-5.8484, -0.0000,  8.1051, -0.0000]], grad_fn=<SliceBackward0>)

In [7]:
print(net[0].weight.data)
net[0].weight.data[:] += 1
net[0].weight.data[0, 0] = 42
net[0].weight.data[0]

tensor([[ 5.6342,  9.4741, -0.0000,  0.0000],
        [-5.8484, -0.0000,  8.1051, -0.0000],
        [-0.0000, -5.9879,  6.5829,  7.1877],
        [ 0.0000, -6.9106,  0.0000, -0.0000],
        [ 9.1797,  0.0000, -9.5172,  7.6742],
        [-0.0000, -0.0000,  6.2690,  0.0000],
        [ 5.9203,  5.7909, -0.0000, -0.0000],
        [-6.3098, -0.0000,  9.5610,  9.9933]])


tensor([42.0000, 10.4741,  1.0000,  1.0000])