In [13]:
import torch.nn as nn
import torch
import torch.optim as optim

In [14]:
# Define a simple network
class net(nn.Module):
    """Tiny fully connected network made of two stacked linear layers.

    The first layer (``fc1``) maps 8 input features to 4 hidden features and
    the second (``fc2``) maps those 4 features to ``num_class`` outputs.
    No activation function is applied between the two layers.

    Args:
        num_class: Number of output features produced by ``fc2`` (default 10).
    """

    def __init__(self, num_class=10):
        super().__init__()
        hidden_dim = 4  # width shared by fc1's output and fc2's input
        self.fc1 = nn.Linear(8, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_class)

    def forward(self, x):
        """Pass ``x`` through ``fc1`` and then ``fc2`` and return the result."""
        hidden = self.fc1(x)
        return self.fc2(hidden)

In [15]:
# Seed PyTorch's global RNG before anything random happens, so that the weight
# initialisation below and the random batches drawn with torch.randn /
# torch.randint afterwards are identical on every Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

model = net()  # uses the default num_class=10
loss_fn = nn.CrossEntropyLoss()

In [16]:
# Model parameters before training
fc1_weight, fc2_weight = model.fc1.weight, model.fc2.weight

# Show the weight values first ...
print("model.fc1.weight", fc1_weight)
print("model.fc2.weight", fc2_weight)
# ... then whether autograd is tracking each of them.
print("model.fc1.weight.requires_grad:", fc1_weight.requires_grad)
print("model.fc2.weight.requires_grad:", fc2_weight.requires_grad)

model.fc1.weight Parameter containing:
tensor([[-0.0624, -0.1627, -0.1996,  0.1495,  0.2983, -0.2941, -0.2399, -0.2755],
        [ 0.2841,  0.2877, -0.0122,  0.2792, -0.0042,  0.1647,  0.0747, -0.3440],
        [-0.2076,  0.1101, -0.0151,  0.1523,  0.1737,  0.2304, -0.1436, -0.0883],
        [ 0.1509, -0.0522, -0.3128, -0.2118,  0.0528,  0.1268, -0.2851,  0.0909]],
       requires_grad=True)
model.fc2.weight Parameter containing:
tensor([[ 0.4246, -0.1300, -0.1669, -0.1591],
        [ 0.0152,  0.2271,  0.1063,  0.2823],
        [ 0.3106,  0.1452,  0.1611, -0.1234],
        [-0.2521, -0.1838, -0.4629,  0.2355],
        [ 0.0928, -0.4808,  0.4481,  0.0825],
        [ 0.4910, -0.0614,  0.1732,  0.0329],
        [-0.3894,  0.4884,  0.1835,  0.4894],
        [ 0.2525, -0.4871,  0.3600,  0.2108],
        [-0.3905, -0.1755,  0.2169,  0.4896],
        [-0.3713,  0.0057, -0.2189, -0.0039]], requires_grad=True)
model.fc1.weight.requires_grad: True
model.fc2.weight.requires_grad: True


In [17]:
# Freeze the parameters of the fc1 layer.
# Parameter names are dotted paths such as "fc1.weight"; compare against the
# individual path components instead of doing a substring search, so that a
# sibling layer whose name merely contains "fc1" (e.g. "fc10") is not frozen
# by accident.
for name, param in model.named_parameters():
    if "fc1" in name.split("."):
        param.requires_grad = False

# Give the optimizer only the parameters that still have requires_grad=True,
# so the frozen layer is never part of an update step.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.SGD(trainable_params, lr=1e-2)

# Take ten SGD steps, each on a freshly sampled random batch.
for epoch in range(10):
    # Clear gradients left over from the previous step before the new backward pass.
    optimizer.zero_grad()

    # Random batch: 3 samples with 8 features each, plus one integer target
    # per sample drawn from {0, 1, 2}.
    x = torch.randn(3, 8)
    label = torch.randint(low=0, high=3, size=(3,), dtype=torch.long)

    output = model(x)
    loss = loss_fn(output, label)

    loss.backward()
    optimizer.step()

# The fc1 layer's parameters are expected to be unchanged after training.
fc1_weight, fc2_weight = model.fc1.weight, model.fc2.weight

print("model.fc1.weight", fc1_weight)
print("model.fc2.weight", fc2_weight)

# Report whether autograd is still tracking each weight.
print("model.fc1.weight.requires_grad:", fc1_weight.requires_grad)
print("model.fc2.weight.requires_grad:", fc2_weight.requires_grad)

model.fc1.weight Parameter containing:
tensor([[-0.0624, -0.1627, -0.1996,  0.1495,  0.2983, -0.2941, -0.2399, -0.2755],
        [ 0.2841,  0.2877, -0.0122,  0.2792, -0.0042,  0.1647,  0.0747, -0.3440],
        [-0.2076,  0.1101, -0.0151,  0.1523,  0.1737,  0.2304, -0.1436, -0.0883],
        [ 0.1509, -0.0522, -0.3128, -0.2118,  0.0528,  0.1268, -0.2851,  0.0909]])
model.fc2.weight Parameter containing:
tensor([[ 0.4230, -0.1221, -0.1666, -0.1624],
        [ 0.0144,  0.2241,  0.1178,  0.2899],
        [ 0.3056,  0.1448,  0.1660, -0.1174],
        [-0.2497, -0.1847, -0.4646,  0.2338],
        [ 0.0927, -0.4808,  0.4453,  0.0812],
        [ 0.4900, -0.0619,  0.1709,  0.0317],
        [-0.3879,  0.4873,  0.1820,  0.4883],
        [ 0.2517, -0.4870,  0.3567,  0.2088],
        [-0.3880, -0.1763,  0.2137,  0.4873],
        [-0.3683,  0.0044, -0.2208, -0.0047]], requires_grad=True)
model.fc1.weight.requires_grad: False
model.fc2.weight.requires_grad: True
