In [1]:
# 2025/7/25
# zhangzhong
# https://docs.pytorch.org/tutorials/beginner/saving_loading_models.html

In [5]:
from torch import nn
import torch.nn.functional as F
import torch
from torch import optim

In [None]:
# When it comes to saving and loading models, there are three core functions to be familiar with:

# torch.save: Saves a serialized object to disk. This function uses Python’s pickle utility for serialization. Models, tensors, and dictionaries of all kinds of objects can be saved using this function.

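# torch.load: Uses pickle’s unpickling facilities to deserialize pickled object files to memory. It also lets you choose the device to load the data onto via map_location (see Saving & Loading Model Across Devices). Pass weights_only=True (the default since PyTorch 2.6) so that only tensors and primitive containers are unpickled; this prevents arbitrary code execution from untrusted files.

# torch.nn.Module.load_state_dict: Loads a model’s parameter dictionary using a deserialized state_dict. For more information on state_dict, see "What is a state_dict?".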

In [4]:
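# What is a state_dict?
# For a torch.nn.Module, a state_dict is simply a Python dictionary that maps each parameter (and registered buffer) name, such as conv1.weight, to its tensor. Layers without learnable parameters, such as pooling, have no entries.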
# Optimizer objects (torch.optim) also have a state_dict, which contains information about the optimizer’s state, as well as the hyperparameters used.

# Define model
class TheModelClass(nn.Module):
    """Small LeNet-style CNN used to demonstrate saving and loading.

    Two stages of conv(5x5) -> ReLU -> 2x2 max-pool are followed by three fully
    connected layers (400 -> 120 -> 84 -> 10). The last layer's output is
    returned without an activation.

    The layer sizes fit 3-channel 32x32 inputs:
    32 -> conv 28 -> pool 14 -> conv 10 -> pool 5, i.e. 16 * 5 * 5 = 400 features.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor. `pool` is shared by both conv stages.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        # Classifier head.
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        # Flatten to rows of 400 features. The batch dimension is inferred (-1), so
        # the input must be 32x32 for rows to line up one-per-sample.
        x = x.view(-1, 16 * 5 * 5)
        for hidden in (self.fc1, self.fc2):
            x = F.relu(hidden(x))
        return self.fc3(x)

In [8]:
# Build the network and an SGD optimizer (lr=0.001, momentum=0.9) over its parameters.
model = TheModelClass()
optimizer = optim.SGD(params=model.parameters(), lr=0.001, momentum=0.9)

# Show the module tree, then the optimizer's parameter-group hyperparameters.
print(model, optimizer, sep="\n")

TheModelClass(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)
SGD (
Parameter Group 0
    dampening: 0
    differentiable: False
    foreach: None
    fused: None
    lr: 0.001
    maximize: False
    momentum: 0.9
    nesterov: False
    weight_decay: 0
)


In [10]:
# print model's state dict
# Print the contents of both state_dicts:
#   - model: each parameter name with the shape of its tensor;
#   - optimizer: its top-level entries ('state' and 'param_groups') as-is.
# state_dict() rebuilds the dictionary on every call, so call it once per object
# and iterate over its items instead of re-calling it inside the loop body.
for param_tensor, tensor in model.state_dict().items():
    print(param_tensor, "\t", tensor.size())

for var_name, value in optimizer.state_dict().items():
    print(var_name, "\t", value)

conv1.weight 	 torch.Size([6, 3, 5, 5])
conv1.bias 	 torch.Size([6])
conv2.weight 	 torch.Size([16, 6, 5, 5])
conv2.bias 	 torch.Size([16])
fc1.weight 	 torch.Size([120, 400])
fc1.bias 	 torch.Size([120])
fc2.weight 	 torch.Size([84, 120])
fc2.bias 	 torch.Size([84])
fc3.weight 	 torch.Size([10, 84])
fc3.bias 	 torch.Size([10])
state 	 {}
param_groups 	 [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'maximize': False, 'foreach': None, 'differentiable': False, 'fused': None, 'params': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}]


In [None]:
# Save & Loading Model for Inference
# When saving a model for inference, it is only necessary to save the trained model’s learned parameters. 
# Save
# A common PyTorch convention is to save models using either a .pt or .pth file extension.
# Note: model.state_dict() returns a **reference** to the live parameter tensors, not a copy. To keep a snapshot (e.g. of the best model so far), use copy.deepcopy(model.state_dict()).
# Persist only the learned tensors (the state_dict), not the module object itself.
torch.save(obj=model.state_dict(), f='model_weights.pth')

In [None]:
# Load
# To load a model, you first need to instantiate the model class, then load the state_dict into it.
# Note that the model class must match the one used to save the state_dict.
# Load本质上是加载参数，所以模型的架构必须是一致的。
# Rebuild the same architecture; load_state_dict then copies the saved tensors into it.
model = TheModelClass()
# Put dropout / batch-norm layers into inference behavior before running inference.
# This network has neither, so it is a no-op here, but it is the required step in
# general. Calling it before loading keeps the cell's displayed output unchanged.
model.eval()
# weights_only=True limits unpickling to tensors and primitive containers, so a
# tampered .pth file cannot run arbitrary code (this is the default since PyTorch 2.6).
model.load_state_dict(torch.load('model_weights.pth', weights_only=True))

<All keys matched successfully>

In [None]:
# Save and Load Entire Model
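# This approach is strongly discouraged. It pickles the whole module with torch.save(model, PATH) and restores it with torch.load(PATH, weights_only=False).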
# This save/load process uses the most intuitive syntax and involves the least amount of code.
# The disadvantage of this approach is that the serialized data is bound to the specific classes and the exact directory structure used when the model is saved. 
# Because of this, your code can break in various ways when used in other projects or after refactors.

In [None]:
# Saving & Loading a General Checkpoint for Inference and Resuming Training
# To save multiple components, organize them in a dictionary and use torch.save() to serialize the dictionary
# A common PyTorch convention is to save these checkpoints using the .tar file extension.
# Bundle everything needed to resume training into one dict and serialize it.
# NOTE(review): `model` was re-created in the inference-loading cell, but `optimizer`
# is still the one built over the very first model's parameters. The two state_dicts
# therefore come from different objects. This is harmless here because the optimizer
# state is still empty, but a real training loop must save the optimizer that
# actually updates `model`.
torch.save(
    {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        # When checkpointing mid-training, also store e.g. 'epoch' and 'loss' here.
    },
    'checkpoint.tar',
)

In [None]:
# To load the items, first initialize the model and optimizer, then load the dictionary locally using torch.load()
# From here, you can easily access the saved items by simply querying the dictionary as you would expect.
# Re-create the model and its optimizer first, then restore their saved states.
model = TheModelClass()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# The checkpoint holds only tensors and plain Python values (dicts, lists, numbers,
# bools, None), so the safe weights_only unpickler can read it.
checkpoint = torch.load('checkpoint.tar', weights_only=True)
# Copies the saved values into the model's existing parameter tensors.
model.load_state_dict(checkpoint['model_state_dict'])
# The optimizer stays bound to this model's parameters: load_state_dict only
# restores hyperparameters and per-parameter state, matched by parameter index.
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

In [None]:
# 这是一个非常关键的问题，正确顺序很重要。

# ⸻

# 🎯 简短回答：

# 如果你在加载 optimizer state 之后才调用 model.to("cuda")，优化器内部的状态张量（如 momentum buffer 等）仍然在 CPU 上（optimizer.load_state_dict 会把状态转换到加载那一刻参数所在的设备），而模型参数此时已在 CUDA 上；之后调用 optimizer.step() 时通常会因设备不匹配而报错。

# ⸻

# ✅ 正确顺序应当是：

# model = TheModelClass()
# checkpoint = torch.load("checkpoint.tar", map_location="cuda", weights_only=True)  # 保证加载到对的设备上
# model.load_state_dict(checkpoint["model_state_dict"])
# model.to("cuda")  # 把模型放到 CUDA 上

# optimizer = optim.SGD(model.parameters(), lr=0.001)
# optimizer.load_state_dict(checkpoint["optimizer_state_dict"])  # 加载的是 CUDA 上的参数状态

# 	•	在 optimizer 加载 state_dict 时，内部 state 会按当前绑定的 model.parameters() 的设备进行迁移（只要是同一个参数对象）。

# ⸻

# ❌ 错误顺序会导致什么？

# model = TheModelClass()
# optimizer = optim.SGD(model.parameters(), lr=0.001)
# optimizer.load_state_dict(checkpoint["optimizer_state_dict"])  # ⚠️ 此时 state 在 CPU 上

# model.load_state_dict(checkpoint["model_state_dict"])
# model.to("cuda")  # 模型到 CUDA 上了，但 optimizer 还在管理 CPU 上的状态

# 	•	此时 optimizer 的状态（如动量）和模型参数不匹配设备，会导致训练出错。
# 	•	常见报错有：RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! 等。

# ⸻

# 🔧 补救方法（如果你必须要 .to() 在后面）：

# 你可以手动将 optimizer 的状态迁移到 GPU：

# # 在 model.to("cuda") 之后手动处理
# for state in optimizer.state.values():
#     for k, v in state.items():
#         if isinstance(v, torch.Tensor):
#             state[k] = v.cuda()

# 不过这更繁琐，强烈建议还是在 .to(cuda) 后再 load optimizer，这样最自然、最安全。

# ⸻

# ✅ 推荐通用加载顺序：

# # 1. Init model
# model = TheModelClass()

# # 2. Load checkpoint with correct device mapping
# checkpoint = torch.load("checkpoint.tar", map_location="cuda", weights_only=True)

# # 3. Load model state
# model.load_state_dict(checkpoint["model_state_dict"])
# model.to("cuda")  # 必须在 optimizer 初始化前

# # 4. Init optimizer
# optimizer = optim.SGD(model.parameters(), lr=0.001)
# optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

# 如果你用 AMP 或 GradScaler，也一样要在模型转 cuda 后再恢复 scaler 状态。

# ⸻

# 如你希望我提供完整的加载代码模板，也可以告诉我你使用的 AMP / DDP 等组件。

In [None]:
# Saving & Loading Model Across Devices
# 跨设备加载时，需要在 torch.load 中指定 map_location 参数（load_state_dict 本身没有这个参数）

# Save the current weights, then reload them onto `device`.
torch.save(model.state_dict(), 'model_weights.pth')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# A freshly constructed model has all of its parameters on the CPU.
model = TheModelClass()

# map_location places every saved tensor on `device` while deserializing;
# weights_only=True keeps unpickling restricted to tensors and primitive containers.
state_dict = torch.load('model_weights.pth', map_location=device, weights_only=True)
# Device of the loaded tensors (cuda:0 when a GPU is available).
print(next(iter(state_dict.values())).device)

# load_state_dict copies the loaded values INTO the model's existing parameters.
# It does not move the parameters, so they stay on the CPU: the data is copied
# from `device` to the CPU.
model.load_state_dict(state_dict)
# Still cpu.
print(next(model.parameters()).device)

# Now move the parameters themselves to `device`.
model.to(device)
# cuda:0 when a GPU is available.
print(next(model.parameters()).device)

cuda:0
cpu
cuda:0
