学习如何把内存中训练好的模型参数存储在硬盘上供后续使用

读写tensor

In [1]:
# torch.save stores a tensor (other objects work too): the object is
# serialized first and then written to disk.
# torch.load reads such a file and deserializes it back into memory.
import torch
from torch import nn
x = torch.ones(3)
# Write x to the file 'x.pt' (a path relative to the working directory).
torch.save(x,'x.pt')


In [2]:
# Read the data back from the saved file 'x.pt' into memory.
x2 = torch.load('x.pt')
# Bare expression: the notebook shows its value as the cell output.
x2

tensor([1., 1., 1.])

In [3]:
# A list of tensors can be saved and loaded the same way.
y = torch.zeros(4)
torch.save([x,y],'xy.pt')
xy_list = torch.load('xy.pt')
# Bare expression: the notebook shows the loaded list as the cell output.
xy_list

[tensor([1., 1., 1.]), tensor([0., 0., 0., 0.])]

In [4]:
# A dict that maps string keys to tensors can be saved and loaded as well.
torch.save({'x':x,'y':y},'xy_dict.pt')
xy = torch.load('xy_dict.pt')
# Bare expression: the notebook shows the loaded dict as the cell output.
xy

{'x': tensor([1., 1., 1.]), 'y': tensor([0., 0., 0., 0.])}

读写模型

In [7]:
#state_dict是从参数名称映射到参数tensor的字典对象
class MLP(nn.Module):
    """Small multilayer perceptron: Linear(3, 2) -> ReLU -> Linear(2, 1)."""

    def __init__(self):
        super().__init__()
        # The attribute names become the prefixes of the state_dict keys
        # (e.g. 'hidden.weight'), so they are part of the saved-file format.
        self.hidden = nn.Linear(3, 2)
        self.act = nn.ReLU()
        self.output = nn.Linear(2, 1)

    def forward(self, x):
        """Apply hidden layer, activation and output layer to x, in that order."""
        return self.output(self.act(self.hidden(x)))

In [8]:
# Build an MLP instance and print its module structure.
net = MLP()
print(net)

MLP(
  (hidden): Linear(in_features=3, out_features=2, bias=True)
  (act): ReLU()
  (output): Linear(in_features=2, out_features=1, bias=True)
)


In [12]:
# Display the model's state_dict: an ordered mapping from parameter names to
# parameter tensors.
net.state_dict()
# Only layers with learnable parameters have entries in the state_dict (the
# ReLU layer `act` has none, so it does not appear). Optimizers (torch.optim)
# also have a state_dict of their own.

OrderedDict([('hidden.weight',
              tensor([[ 0.1915, -0.2214,  0.3726],
                      [ 0.0424,  0.3245,  0.4636]])),
             ('hidden.bias', tensor([-0.2195, -0.1572])),
             ('output.weight', tensor([[-0.2206,  0.6449]])),
             ('output.bias', tensor([-0.3518]))])

In [14]:
# SGD optimizer over the network's parameters (learning rate 0.001, momentum 0.9).
# NOTE(review): `optmizer` looks like a misspelling of `optimizer`; the name is
# kept unchanged here because other cells may refer to it.
optmizer = torch.optim.SGD(net.parameters(),lr = 0.001,momentum=0.9)
# Display the optimizer's state_dict: a dict with 'state' and 'param_groups'
# (the latter carries the hyperparameters and the parameter indices).
optmizer.state_dict()

{'state': {},
 'param_groups': [{'lr': 0.001,
   'momentum': 0.9,
   'dampening': 0,
   'weight_decay': 0,
   'nesterov': False,
   'maximize': False,
   'foreach': None,
   'differentiable': False,
   'fused': None,
   'params': [0, 1, 2, 3]}]}

保存和加载模型

In [16]:
#只保存和加载模型参数：state_dict
#保存：torch.save(model.state_dict(),PATH)
#加载：model.load_state_dict(torch.load(PATH))（需先实例化结构相同的模型）
#推荐的文件后缀名是pt或者pth

In [17]:
#保存和加载整个模型
# torch.save(model,PATH)
# model = torch.load(PATH)

In [19]:
# Method 1 in practice: save only the state_dict, then load it into a freshly
# constructed model with the same architecture.
X = torch.randn(2, 3)
Y = net(X)
PATH = "./net.pt"
torch.save(net.state_dict(), PATH)
net2 = MLP()
# weights_only=True limits unpickling to tensors and plain containers, which is
# all a state_dict contains; this avoids running arbitrary pickled code and
# matches the loading pattern recommended in the PyTorch documentation.
net2.load_state_dict(torch.load(PATH, weights_only=True))
Y2 = net2(X)
# Both models now hold the same parameters, so the outputs for the same input
# are compared element-wise.
print(Y2 == Y)

tensor([[True],
        [True]])
