In [4]:
import os

import numpy as np
import torch
import torch.nn as nn
# from common_tools import set_seed


class LeNet2(nn.Module):
    """LeNet-style CNN: two conv/ReLU/max-pool stages followed by a 3-layer MLP.

    The first Linear layer expects 16*5*5 = 400 input features. That is what the
    feature extractor produces for 3-channel 32x32 images
    (32 -conv5-> 28 -pool-> 14 -conv5-> 10 -pool-> 5), so inputs are expected
    to have shape (N, 3, 32, 32).

    Args:
        classes: number of output logits produced by the final Linear layer.
    """

    def __init__(self, classes):
        super().__init__()
        # Keep the Sequential layout unchanged: it determines the state_dict keys
        # (e.g. "features.0.weight"). Reordering layers would make previously
        # saved state dicts incompatible with load_state_dict().
        self.features = nn.Sequential(
            nn.Conv2d(3, 6, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        self.classifier = nn.Sequential(
            nn.Linear(16*5*5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, classes)
        )

    def forward(self, x):
        """Return logits of shape (N, classes) for an input batch x of shape (N, 3, 32, 32)."""
        x = self.features(x)
        # torch.flatten keeps the batch dimension and, unlike .view(), also works
        # when the feature map is not contiguous (e.g. channels_last inputs).
        x = torch.flatten(x, 1)
        return self.classifier(x)

    def initialize(self, value=2020):
        """Fill every parameter (weights and biases) with the constant ``value``.

        In this notebook it stands in for "training". The recognisable constant
        makes it easy to check that saved parameters were really loaded back.

        Args:
            value: fill value for all parameters (defaults to 2020, the
                original hard-coded constant).
        """
        # In-place writes to leaf parameters must not be tracked by autograd.
        # torch.no_grad() is the supported way to do this instead of `.data`.
        with torch.no_grad():
            for p in self.parameters():
                p.fill_(value)


net = LeNet2(classes=2019)

# Simulated "training": overwrite every parameter with a recognisable constant,
# so it is easy to verify later that the saved weights were loaded back.
print("训练前: ", net.features[0].weight[0, ...])
net.initialize()
print("训练后: ", net.features[0].weight[0, ...])

path_model = "./02 save_model/model.pkl"
path_state_dict = "./02 save_model/model_state_dict.pkl"

# torch.save does not create missing directories. Create the target folder(s)
# so this cell also works on a fresh checkout instead of raising FileNotFoundError.
os.makedirs(os.path.dirname(path_model), exist_ok=True)
os.makedirs(os.path.dirname(path_state_dict), exist_ok=True)

# Save the whole model object (pickled). Loading it back later requires the
# LeNet2 class definition to be available in the loading session.
torch.save(net, path_model)

# Save only the parameters (the state_dict); this is the more portable option.
net_state_dict = net.state_dict()
torch.save(net_state_dict, path_state_dict)

训练前:  tensor([[[-0.0960,  0.0843, -0.0197,  0.0121, -0.0995],
         [-0.0318, -0.0466,  0.0976,  0.0868, -0.0302],
         [ 0.0419,  0.0014, -0.0378, -0.1023,  0.1013],
         [ 0.0789, -0.0198, -0.0418,  0.1115,  0.0468],
         [-0.0942, -0.0556, -0.1029,  0.0030, -0.0642]],

        [[-0.0340,  0.1095, -0.0899,  0.0892, -0.0790],
         [-0.0199,  0.0397,  0.0863, -0.0758, -0.1098],
         [ 0.0129,  0.0106, -0.0341, -0.0178,  0.0380],
         [ 0.0924, -0.0408,  0.0029, -0.0861, -0.0089],
         [-0.0878,  0.0733,  0.0192, -0.0620,  0.0889]],

        [[ 0.0931,  0.1117, -0.1129, -0.0428, -0.0400],
         [ 0.0137, -0.0524,  0.0354, -0.1121,  0.0883],
         [ 0.0411, -0.0715,  0.0748,  0.0257, -0.0648],
         [-0.0107, -0.0592,  0.0772,  0.0384,  0.0831],
         [-0.0697,  0.0782, -0.0069,  0.1138, -0.1026]]],
       grad_fn=<SelectBackward>)
训练后:  tensor([[[2020., 2020., 2020., 2020., 2020.],
         [2020., 2020., 2020., 2020., 2020.],
         [2020., 

方式一：加载整个模型（`torch.save(net, ...)` 通过 pickle 保存了整个模型对象，因此加载时当前环境中必须能找到 `LeNet2` 类的定义）

In [5]:
path_model = "./02 save_model/model.pkl"

# torch.save(net, ...) pickled the whole nn.Module. Since PyTorch 2.6, torch.load
# defaults to weights_only=True, which refuses to unpickle arbitrary objects such
# as a full model, so opt out explicitly. Only do this for files you trust:
# unpickling can execute arbitrary code.
try:
    net_load = torch.load(path_model, weights_only=False)
except TypeError:
    # PyTorch < 1.13 has no `weights_only` argument; its default already
    # unpickles full objects.
    net_load = torch.load(path_model)

print(net_load)

LeNet2(
  (features): Sequential(
    (0): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): Sequential(
    (0): Linear(in_features=400, out_features=120, bias=True)
    (1): ReLU()
    (2): Linear(in_features=120, out_features=84, bias=True)
    (3): ReLU()
    (4): Linear(in_features=84, out_features=2019, bias=True)
  )
)


方式二：只加载模型参数（state_dict）——先创建一个结构相同的模型对象，再把保存的参数加载到该模型中

In [8]:
path_state_dict = "./02 save_model/model_state_dict.pkl"

# 1) Read the saved parameter dictionary from disk.
state_dict_load = torch.load(path_state_dict)

# 2) Build a new model with the same architecture; its parameters are still the
#    freshly (randomly) initialised ones at this point.
net_new = LeNet2(classes=2019)
print("加载前: ", net_new.features[0].weight[0])

# 3) Copy the saved tensors into the new model, then show the same kernel again.
net_new.load_state_dict(state_dict_load)
print("加载后: ", net_new.features[0].weight[0])

加载前:  tensor([[[-0.0453,  0.0198,  0.0941, -0.0883,  0.0652],
         [-0.0441,  0.0865, -0.0095, -0.0467, -0.0405],
         [ 0.0191,  0.1050,  0.0631, -0.1102, -0.0081],
         [ 0.0625,  0.0106,  0.0098,  0.1109, -0.0674],
         [-0.1075,  0.0584, -0.0211,  0.0718, -0.0728]],

        [[-0.0498,  0.1145,  0.0481,  0.0665, -0.1124],
         [-0.0278, -0.0609, -0.0188,  0.0451, -0.0878],
         [ 0.0433, -0.0745, -0.0804,  0.0561,  0.0483],
         [ 0.0016,  0.0219, -0.0268, -0.0758, -0.0878],
         [ 0.0042, -0.0165,  0.0632,  0.0865, -0.1131]],

        [[ 0.0097, -0.0020,  0.0067,  0.1069, -0.0533],
         [-0.1113, -0.0351,  0.0720,  0.0741, -0.0334],
         [-0.1125, -0.0111,  0.0646, -0.0109, -0.0078],
         [ 0.0534,  0.1092,  0.0464, -0.1090,  0.0230],
         [ 0.0734, -0.0677,  0.0312, -0.0720, -0.0457]]],
       grad_fn=<SelectBackward>)
加载后:  tensor([[[2020., 2020., 2020., 2020., 2020.],
         [2020., 2020., 2020., 2020., 2020.],
         [2020., 

#### 模型的断点续训  
在训练过程中，可能由于某种意外原因（如断电、程序崩溃）导致训练终止；如果没有保存中间状态，就只能从头开始重新训练。断点续训的做法是：在训练过程中每隔一定数量的 epoch，就保存一次模型参数和优化器参数（通常还会保存当前的 epoch 数）。如果训练意外终止，下次可以重新加载最近一次保存的模型参数和优化器参数，从中断处继续训练。