In [1]:
import torch
from torch import nn

In [2]:
# Seed torch's global RNG early so the random example inputs and every random
# weight initialization in this notebook are reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

num_samples = 2   # number of example rows to generate
num_inputs = 4    # input width of the first nn.Linear layer
num_hiddens = 8   # hidden width (output of the first, input of the second nn.Linear)
num_outputs = 1   # output width of the final nn.Linear layer

In [3]:
# Example batch of shape (num_samples, num_inputs), sampled uniformly from [0, 1).
# Built from the config constants instead of repeating the literals (2, 4).
X = torch.rand((num_samples, num_inputs))

In [4]:
# Two-layer MLP: Linear(num_inputs -> num_hiddens) -> ReLU -> Linear(num_hiddens -> num_outputs).
net = nn.Sequential(
    nn.Linear(in_features=num_inputs, out_features=num_hiddens),
    nn.ReLU(),
    nn.Linear(in_features=num_hiddens, out_features=num_outputs),
)

In [5]:
# Forward pass of the example batch X through net with its constructor-time
# (default) initialization; displays one output row per input row.
net(X)

tensor([[-0.3134],
        [-0.2397]], grad_fn=<AddmmBackward0>)

In [6]:
def init_normal(module: nn.Module, mean: float = 0.0, std: float = 0.01) -> None:
    """Initialize an ``nn.Linear`` layer with Gaussian weights and a zero bias.

    Meant to be passed to ``Module.apply``. Modules that are not ``nn.Linear``
    (e.g. ``nn.ReLU``) are left untouched.

    Args:
        module: Module to initialize in place.
        mean: Mean of the normal distribution used for the weights.
        std: Standard deviation of the normal distribution used for the weights.
    """
    # isinstance (rather than an exact type comparison) also covers subclasses.
    if isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, mean=mean, std=std)
        # Layers built with bias=False have bias=None; nothing to zero then.
        if module.bias is not None:
            nn.init.zeros_(module.bias)

# Module.apply invokes init_normal on every layer inside net and returns net
# itself, so this cell displays the model's repr.
net.apply(fn=init_normal)

Sequential(
  (0): Linear(in_features=4, out_features=8, bias=True)
  (1): ReLU()
  (2): Linear(in_features=8, out_features=1, bias=True)
)

In [7]:
# Weight matrix of the first Linear layer (net[0]) after init_normal.
net[0].weight

Parameter containing:
tensor([[-0.0015, -0.0047, -0.0116, -0.0020],
        [-0.0117,  0.0098,  0.0010, -0.0054],
        [ 0.0067, -0.0059, -0.0034, -0.0135],
        [-0.0092, -0.0180,  0.0133,  0.0102],
        [-0.0099,  0.0144,  0.0093,  0.0003],
        [ 0.0095, -0.0012, -0.0149, -0.0016],
        [ 0.0049, -0.0030, -0.0025,  0.0005],
        [-0.0005,  0.0046, -0.0198, -0.0098]], requires_grad=True)

In [8]:
# Bias of the first Linear layer; init_normal sets it to zeros.
net[0].bias

Parameter containing:
tensor([0., 0., 0., 0., 0., 0., 0., 0.], requires_grad=True)

In [9]:
def init_constant(module: nn.Module, value: float = 1.0) -> None:
    """Fill an ``nn.Linear`` layer's weights with a constant and zero its bias.

    Meant to be passed to ``Module.apply``. Non-Linear modules are left untouched.

    Args:
        module: Module to initialize in place.
        value: Constant written to every weight entry (default 1).
    """
    # isinstance (rather than an exact type comparison) also covers subclasses.
    if isinstance(module, nn.Linear):
        nn.init.constant_(module.weight, value)
        # Layers built with bias=False have bias=None; nothing to zero then.
        if module.bias is not None:
            nn.init.zeros_(module.bias)

# Re-initialize every Linear layer in net with constant weights; apply returns
# net, so the model's repr is displayed.
net.apply(fn=init_constant)

Sequential(
  (0): Linear(in_features=4, out_features=8, bias=True)
  (1): ReLU()
  (2): Linear(in_features=8, out_features=1, bias=True)
)

In [10]:
# First layer's weights after init_constant: every entry is 1.
net[0].weight

Parameter containing:
tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]], requires_grad=True)

In [11]:
# First layer's bias after init_constant: zeros.
net[0].bias

Parameter containing:
tensor([0., 0., 0., 0., 0., 0., 0., 0.], requires_grad=True)

In [12]:
def init_xavier(module: nn.Module, gain: float = 1.0) -> None:
    """Xavier/Glorot-uniform initialize an ``nn.Linear`` layer's weights.

    Meant to be passed to ``Module.apply``. The bias is deliberately left as
    it was, and non-Linear modules are left untouched.

    Args:
        module: Module to initialize in place.
        gain: Scaling factor forwarded to ``nn.init.xavier_uniform_``.
    """
    # isinstance (rather than an exact type comparison) also covers subclasses.
    if isinstance(module, nn.Linear):
        nn.init.xavier_uniform_(module.weight, gain=gain)

# Xavier-initialize the weights of every Linear layer in net (biases keep their
# previous values); apply returns net, so its repr is displayed.
net.apply(fn=init_xavier)

Sequential(
  (0): Linear(in_features=4, out_features=8, bias=True)
  (1): ReLU()
  (2): Linear(in_features=8, out_features=1, bias=True)
)

In [13]:
# First layer's weights after init_xavier.
net[0].weight

Parameter containing:
tensor([[ 0.2073,  0.0423,  0.3270,  0.4785],
        [-0.0771,  0.1240,  0.2713,  0.6624],
        [ 0.5972,  0.5691,  0.3513,  0.2824],
        [ 0.5032, -0.4697,  0.2592,  0.1832],
        [-0.0554, -0.0767,  0.6015,  0.5767],
        [-0.6897,  0.3948, -0.3485, -0.1676],
        [ 0.3974, -0.3618,  0.2473,  0.5100],
        [-0.5367,  0.0260,  0.3431, -0.2516]], requires_grad=True)

In [14]:
def init_42(module: nn.Module) -> None:
    """Set every weight of an ``nn.Linear`` layer to the constant 42.

    Meant to be passed to ``Module.apply``. The bias is left unchanged, and
    non-Linear modules are left untouched.

    Args:
        module: Module to initialize in place.
    """
    # isinstance (rather than an exact type comparison) also covers subclasses.
    if isinstance(module, nn.Linear):
        nn.init.constant_(module.weight, 42)

# Set every Linear weight in net to 42 (biases unchanged); apply returns net,
# so its repr is displayed.
net.apply(fn=init_42)

Sequential(
  (0): Linear(in_features=4, out_features=8, bias=True)
  (1): ReLU()
  (2): Linear(in_features=8, out_features=1, bias=True)
)

In [15]:
# First layer's weights after init_42: every entry is 42.
net[0].weight

Parameter containing:
tensor([[42., 42., 42., 42.],
        [42., 42., 42., 42.],
        [42., 42., 42., 42.],
        [42., 42., 42., 42.],
        [42., 42., 42., 42.],
        [42., 42., 42., 42.],
        [42., 42., 42., 42.],
        [42., 42., 42., 42.]], requires_grad=True)

In [16]:
def my_init(module: nn.Module, low: float = -10.0, high: float = 10.0,
            threshold: float = 5.0) -> None:
    """Custom init: uniform weights, then zero out the small-magnitude ones.

    For an ``nn.Linear`` layer, prints the name and shape of its first
    parameter, and draws the weights from U[low, high). It then sets every
    weight with ``|w| < threshold`` to 0. Non-Linear modules are left untouched.

    Args:
        module: Module to initialize in place.
        low: Lower bound of the uniform distribution.
        high: Upper bound of the uniform distribution.
        threshold: Weights whose magnitude is below this value are zeroed.
    """
    # isinstance (rather than an exact type comparison) also covers subclasses.
    if isinstance(module, nn.Linear):
        # Only the first registered parameter is reported (the weight, per the
        # printed output), without materializing the full parameter list.
        name, param = next(iter(module.named_parameters()))
        print("Init", name, param.shape)
        nn.init.uniform_(module.weight, low, high)
        # Mutate the parameter in place without autograd tracking instead of via
        # the discouraged `.data` attribute. masked_fill_ writes +0.0, so no -0.0
        # entries appear (multiplying negatives by a False mask produced them).
        with torch.no_grad():
            module.weight.masked_fill_(module.weight.abs() < threshold, 0.0)

# Run my_init on every layer inside net; it prints one "Init ..." line per
# Linear layer, and apply returns net, so its repr is displayed afterwards.
net.apply(fn=my_init)

Init weight torch.Size([8, 4])
Init weight torch.Size([1, 8])


Sequential(
  (0): Linear(in_features=4, out_features=8, bias=True)
  (1): ReLU()
  (2): Linear(in_features=8, out_features=1, bias=True)
)

In [17]:
# First layer's weights after my_init: entries with |w| < 5 have been zeroed.
net[0].weight

Parameter containing:
tensor([[-5.9905,  0.0000, -5.5672,  8.8737],
        [-6.2581, -6.5907, -8.8889,  0.0000],
        [-8.3845,  8.2419, -9.6014, -9.3240],
        [ 0.0000,  6.4026,  5.8749,  8.0636],
        [ 9.6230, -8.4694,  0.0000,  6.3945],
        [-0.0000,  0.0000, -0.0000,  0.0000],
        [ 5.4643,  0.0000,  0.0000, -5.4845],
        [-8.1369,  0.0000,  0.0000, -0.0000]], requires_grad=True)