权重衰减（Weight Decay）

In [None]:
import torch.optim as optim
import torch.nn as nn
import torch

class SimpleModel(nn.Module):
    """A minimal model made of a single fully connected layer (10 inputs -> 1 output)."""

    def __init__(self):
        super().__init__()
        # The only layer: maps 10 input features to 1 output value.
        self.fc = nn.Linear(in_features=10, out_features=1)

    def forward(self, x):
        """Apply the fully connected layer to ``x`` (its last dimension must be 10)."""
        out = self.fc(x)
        return out

model = SimpleModel()

# Variant 1: one global weight_decay (L2 penalty, 5e-4) applied to every
# parameter of the model, weight and bias alike, combined with SGD momentum.
optimizer = optim.SGD(
    params=model.parameters(),
    lr=0.01,
    momentum=0.9,
    weight_decay=5e-4,
)

# Variant 2: per-parameter groups. The weight matrix is decayed with 0.01,
# the bias is not decayed at all; no momentum argument is given here.
optimizer_2 = optim.SGD(
    params=[
        {'params': model.fc.weight, 'weight_decay': 0.01},  # weight: decayed
        {'params': model.fc.bias, 'weight_decay': 0},  # bias: no decay
    ],
    lr=0.01,
)



权重初始化

In [None]:
def init_weights(m):
    """Xavier-uniform initialize Linear/Conv2d weights and set their biases to 0.

    Designed to be passed to ``nn.Module.apply``, which calls it on every
    submodule. Modules that are neither ``nn.Linear`` nor ``nn.Conv2d`` are
    left untouched. Returns ``None``.
    """
    if not isinstance(m, (nn.Linear, nn.Conv2d)):
        return
    nn.init.xavier_uniform_(m.weight)
    # Layers built with bias=False have ``bias`` set to None; skip those.
    has_bias = m.bias is not None
    if has_bias:
        nn.init.constant_(m.bias, 0)
# Visit every module in `model` (recursively) and re-initialize it with init_weights.
model.apply(fn=init_weights)

# Common parameter-initialization methods, demonstrated on a stand-alone tensor.
# `param` must be a real tensor: nn.init functions fill it in place, so an int
# placeholder raises on the first call. It must also have at least 2 dimensions,
# because the Xavier, Kaiming and orthogonal initializers need fan-in/fan-out
# (rows/columns). Each call below overwrites the previous contents of `param`.
param = torch.empty(3, 5)

# All zeros
nn.init.constant_(param, 0)

# All ones
nn.init.constant_(param, 1)

# Uniform distribution on [a, b]
nn.init.uniform_(param, a=-0.1, b=0.1)

# Normal distribution with the given mean and standard deviation
nn.init.normal_(param, mean=0, std=0.01)

# Xavier (Glorot) uniform; suited to networks with sigmoid or tanh activations
nn.init.xavier_uniform_(param)

# Xavier (Glorot) normal; suited to networks with sigmoid or tanh activations
nn.init.xavier_normal_(param)

# Kaiming (He) uniform; suited to ReLU networks and deep networks
nn.init.kaiming_uniform_(param, mode='fan_in', nonlinearity='relu')

# Kaiming (He) normal; suited to ReLU networks and deep networks
nn.init.kaiming_normal_(param, mode='fan_in', nonlinearity='relu')

# Orthogonal: fills `param` with a (semi-)orthogonal matrix.
# NOTE(review): the original note recommends this for self-attention; it is more
# commonly cited for recurrent weight matrices. Confirm before relying on it.
nn.init.orthogonal_(param, gain=1.0)

# Inspect the learnable parameters: print each parameter's dotted name and the tensor itself.
# NOTE: these loops run at notebook top level, so `name`, `param` and `layer`
# stay bound to their last values afterwards (this also rebinds `param`).
for name, param in model.named_parameters():
    print(f"name: {name}, param: {param}")

# All submodules, recursively (PyTorch also yields the root model itself, with name '').
for name, layer in model.named_modules():
    print(f"Layer name: {name}, Layer type: {type(layer)}")

# Direct children only: one level below the model, no recursion.
for name, layer in model.named_children():
    print(f"Layer name: {name}, Layer type: {type(layer)}")

丢弃法

In [None]:
class MyNet(nn.Module):
    """Three stacked linear layers (10 -> 5 -> 3 -> 1) with dropout after the first two.

    Each Dropout uses p=0.2. NOTE(review): forward applies no activation
    function between layers, so with dropout disabled (eval mode) the whole
    network is a single affine map. Presumably intentional for a
    dropout-placement demo; confirm.
    """

    def __init__(self):
        super().__init__()
        # Linear stack: 10 -> 5 -> 3 -> 1.
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 3)
        self.fc3 = nn.Linear(3, 1)
        # One Dropout module after each of the first two linear layers.
        drop_prob = 0.2
        self.dropout_1 = nn.Dropout(drop_prob)
        self.dropout_2 = nn.Dropout(drop_prob)

    def forward(self, x):
        """Run fc1 -> dropout_1 -> fc2 -> dropout_2 -> fc3 and return the result."""
        hidden_stages = ((self.fc1, self.dropout_1), (self.fc2, self.dropout_2))
        for linear, dropout in hidden_stages:
            x = dropout(linear(x))
        return self.fc3(x)

