In [1]:
import torch
import torchvision.models as models

  from .autonotebook import tqdm as notebook_tqdm


# 保存和加载模型权重

In [2]:
# Build VGG16 with its pretrained weights, then persist only the
# parameter dictionary (state_dict) rather than the whole module object.
model = models.vgg16(pretrained=True)
torch.save(obj=model.state_dict(), f='model_weights.pth')



In [3]:
## Build the same architecture, this time without pretrained weights
model = models.vgg16()
## Restore the parameters that were written to 'model_weights.pth'
model.load_state_dict(state_dict=torch.load(f='model_weights.pth'))
## Before running inference with loaded weights, switch dropout and batch
## normalization layers to evaluation mode; skipping this step leads to
## inconsistent inference results
model.eval()

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1

# 保存和加载整个模型结构和权重

In [4]:
# Serialize the entire module object (structure + weights), not just its state_dict
torch.save(obj=model, f="model.pth")
# NOTE(review): newer PyTorch releases default torch.load to weights_only=True,
# which refuses fully pickled modules — confirm against the installed version.
model = torch.load(f="model.pth")

# 保存和加载训练时的断点
## 1.导入必要的库

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

## 2.定义和初始化神经网络

In [6]:
class Net(nn.Module):
    """Small convolutional classifier used to demonstrate checkpointing.

    Architecture: two 5x5 convolutions (3 -> 6 -> 16 channels), each followed
    by ReLU and a shared 2x2 max-pool, then three fully connected layers
    (400 -> 120 -> 84 -> 10). The first linear layer takes 16*5*5 features per
    sample, which is what the conv/pool stack produces for 3x32x32 inputs.
    """

    def __init__(self):
        super(Net,self).__init__()
        self.conv1 = nn.Conv2d(3,6,5)
        self.pool = nn.MaxPool2d(2,2)
        self.conv2 = nn.Conv2d(6,16,5)
        self.fc1 = nn.Linear(16*5*5,120)
        self.fc2 = nn.Linear(120,84)
        self.fc3 = nn.Linear(84,10)

    def forward(self,x):
        """Run a batch through the network and return the 10 output scores.

        The method must be named ``forward`` so that ``nn.Module.__call__``
        dispatches to it when the model is invoked as ``net(x)``.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        ## view works like reshape: flatten each sample to 16*5*5 features
        x = x.view(-1,16*5*5)
        # Keep the result in lowercase `x` so the fc1 output actually feeds fc2
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()
print(net)

Net(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)


## 3.初始化优化器

In [7]:
# SGD with momentum over all parameters of `net`
optimizer = optim.SGD(
    net.parameters(),
    lr=0.001,
    momentum=0.9,
)

## 4.保存一般断点

In [8]:
## Extra bookkeeping stored next to the weights
EPOCH = 5
PATH = "model.pt"
LOSS = 0.4

# Bundle the epoch, model weights, optimizer state and loss into one
# dictionary and write it to PATH as a single checkpoint file.
torch.save(
    obj=dict(
        epoch=EPOCH,
        model_state_dict=net.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        loss=LOSS,
    ),
    f=PATH,
)

## 5.加载一般断点

In [9]:
## Initialize the model and the optimizer before restoring their saved states
model = Net()
# The optimizer has to be built over the parameters of the freshly created
# `model` (not the earlier `net`); otherwise resumed training would keep
# updating the old network while `model` stays frozen.
optimizer = optim.SGD(model.parameters(),lr=0.001,momentum=0.9)

# Read the checkpoint dictionary back and restore each saved piece
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

## Choose the mode for what follows: eval() for inference, train() to resume
## training. Both calls are shown as alternatives; the last one executed wins.
model.eval()
# - or -
model.train()

Net(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)

## 6.迁移学习下的热启动模式
### 只需设置strict=False来忽略非匹配的模型层参数

In [10]:
# Write only the weights to PATH (the same file name the checkpoint cell used)
torch.save(obj=model.state_dict(), f=PATH)

# Fresh network; strict=False makes load_state_dict ignore parameter keys
# that do not match between the file and the model
model = Net()
model.load_state_dict(state_dict=torch.load(f=PATH), strict=False)

<All keys matched successfully>