# 3.5 模型初始化

In [2]:
import torch
import torch.nn as nn

# Build one convolutional layer and one fully connected layer to experiment on.
conv = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=3)
linear = nn.Linear(in_features=10, out_features=1)

# Type checks: only the last expression of a cell is displayed, so the
# recorded output reflects the check on `linear`, not the one on `conv`.
isinstance(conv, nn.Conv2d)    # is conv an nn.Conv2d instance?
isinstance(linear, nn.Conv2d)  # is linear an nn.Conv2d instance?

False

In [4]:
# Show the default (randomly initialised) weights: conv first, then linear.
print(conv.weight.data, linear.weight.data, sep="\n")

tensor([[[[-0.1658,  0.1045,  0.1799],
          [ 0.2633,  0.0619, -0.2959],
          [ 0.0295,  0.2326, -0.0225]]],


        [[[ 0.0827, -0.3130, -0.2353],
          [-0.1460,  0.2899,  0.2258],
          [-0.2041, -0.0764,  0.0616]]],


        [[[-0.2411, -0.1778,  0.0363],
          [ 0.2287, -0.0314, -0.0919],
          [-0.3135,  0.1814, -0.0925]]]])
tensor([[-0.0383, -0.3064,  0.3131,  0.3029,  0.1184, -0.1115,  0.1538,  0.1375,
         -0.1329,  0.2193]])


In [6]:
# Re-initialise the conv kernel in place with Kaiming-normal values.
nn.init.kaiming_normal_(conv.weight.data)
# Overwrite every weight of the linear layer with the constant 0.3.
nn.init.constant_(linear.weight.data, val=0.3)
# Display both weight tensors after initialisation (conv first, then linear).
print(conv.weight.data, linear.weight.data, sep="\n")

tensor([[[[ 0.7561, -0.6994, -0.5043],
          [-0.3809, -1.0173, -0.0489],
          [-0.0896,  0.3203,  0.4248]]],


        [[[ 0.8974, -0.0754, -0.3074],
          [-0.0452,  0.5232,  0.2168],
          [ 0.3735, -0.1852,  0.2766]]],


        [[[-0.9125,  0.3434, -0.1310],
          [-0.0467,  0.2664,  0.8275],
          [-0.3877,  0.4532, -0.1888]]]])
tensor([[0.3000, 0.3000, 0.3000, 0.3000, 0.3000, 0.3000, 0.3000, 0.3000, 0.3000,
         0.3000]])


In [7]:
def initialize_weights(model):
    """Initialise the parameters of every supported layer inside ``model``.

    Walks ``model.modules()`` (the module itself plus all nested submodules)
    and applies a per-type scheme in place:

    * ``nn.Conv2d``: weights set to 0, bias (if present) set to 0.3.
    * ``nn.Linear``: weights drawn from a normal distribution with mean 0.1
      and standard deviation 1.0, bias (if present) set to 0.
    * ``nn.BatchNorm2d``: weight set to 1 and bias set to 0 when the layer
      has affine parameters.

    Modules of any other type are left untouched.

    Args:
        model: an ``nn.Module``; call this once on the root module. Because
            the function already recurses, there is no need to pass it to
            ``Module.apply``.

    Returns:
        None. Parameters are modified in place.
    """
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.zeros_(m.weight)
            # Layers built with bias=False expose bias as None.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.3)
        elif isinstance(m, nn.Linear):
            # The second positional argument of normal_ is the MEAN, not the
            # std; spell both out so the distribution is unambiguous.
            nn.init.normal_(m.weight, mean=0.1, std=1.0)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.BatchNorm2d):
            # With affine=False both weight and bias are None, so guard each.
            if m.weight is not None:
                nn.init.ones_(m.weight)
            if m.bias is not None:
                # The in-place tensor method is zero_(); zeros_() does not
                # exist on tensors and raised AttributeError here before.
                nn.init.zeros_(m.bias)
   
class MLP(nn.Module):
    """Tiny demo network: Conv2d(1, 1, 3) -> ReLU -> Linear(10, 1).

    The linear layer has ``in_features=10`` and is applied to the output of
    the convolution, so the last dimension of that output has to be 10.
    """

    def __init__(self, **kwargs):
        # Any keyword arguments are forwarded unchanged to nn.Module.
        super().__init__(**kwargs)
        self.hidden = nn.Conv2d(1, 1, 3)
        self.act = nn.ReLU()
        self.output = nn.Linear(10, 1)

    def forward(self, x):
        """Run ``x`` through the convolution, the ReLU and the linear layer."""
        features = self.hidden(x)
        activated = self.act(features)
        return self.output(activated)

mlp = MLP()
# Conv weights straight after construction (PyTorch's default initialisation).
print(mlp.hidden.weight.data)
print("-------初始化-------")
# initialize_weights already walks model.modules(), so call it once on the
# root module. Passing it to mlp.apply() would invoke it again for every
# submodule and re-initialise each layer several times.
initialize_weights(mlp)
# Conv weights after the custom initialisation.
print(mlp.hidden.weight.data)

tensor([[[[ 0.0968,  0.3289,  0.2632],
          [-0.0845,  0.2719, -0.1006],
          [-0.0358, -0.0613,  0.2231]]]])
-------初始化-------
tensor([[[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]]])
