In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
from tensorboardX import SummaryWriter
os.environ['CUDA_VISIBLE_DEVICES'] = ''
torch.cuda.is_available()

False

In [7]:
class Net(nn.Module):
    def __init__(self, C=10):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, C)
        self.fc4 = nn.Linear(C, 2)
        self.C = 10

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
    
class Net2(nn.Module):
    def __init__(self):
        super(Net2, self).__init__()
        self.fc1 = nn.Linear(16 * 5 * 5, 120)

    def forward(self, x):
        pass

In [8]:
import torch.optim as optim
def print_state_dict(model, size=True):
    """
    Args:
    - model(nn.Module):
    """
    for key, param in model.state_dict().items():
        print(key, '\t', param.size() if size else param)
net = Net()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

In [9]:
#打印模型的状态字典
print("Net's state_dict:")
print_state_dict(net)

print('')
print("optimizer's state_dict:")
print_state_dict(optimizer, False)


Net's state_dict:
conv1.weight 	 torch.Size([6, 3, 5, 5])
conv1.bias 	 torch.Size([6])
conv2.weight 	 torch.Size([16, 6, 5, 5])
conv2.bias 	 torch.Size([16])
fc1.weight 	 torch.Size([120, 400])
fc1.bias 	 torch.Size([120])
fc2.weight 	 torch.Size([84, 120])
fc2.bias 	 torch.Size([84])
fc3.weight 	 torch.Size([10, 84])
fc3.bias 	 torch.Size([10])
fc4.weight 	 torch.Size([2, 10])
fc4.bias 	 torch.Size([2])

optimizer's state_dict:
state 	 {}
param_groups 	 [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [4776700976, 4776698576, 4423520160, 4423519200, 4776703232, 4447407152, 4423517840, 4423517920, 4783396576, 4783394976, 4783396096, 4783397296]}]


In [10]:
#模型和优化器，都有状态字典
#没有maxpooling，因为这是个操作，没有参数

In [11]:
"""方式一：状态字典
可以热加载
"""
SAVE_PATH = './sd_net.pth'

#保存模型pickle
#torch.save(net.state_dict(), SAVE_PATH)

#加载模型
sd_net = Net()
sd_net.load_state_dict(torch.load(SAVE_PATH), strict=False)
#strict,允许加载部分参数（热加载）
sd_net.eval() #把模型的dropout，BN设置为评估状态

#查看模型状态
print("Loaded Net's state_dict:")
print_state_dict(sd_net)

Loaded Net's state_dict:
conv1.weight 	 torch.Size([6, 3, 5, 5])
conv1.bias 	 torch.Size([6])
conv2.weight 	 torch.Size([16, 6, 5, 5])
conv2.bias 	 torch.Size([16])
fc1.weight 	 torch.Size([120, 400])
fc1.bias 	 torch.Size([120])
fc2.weight 	 torch.Size([84, 120])
fc2.bias 	 torch.Size([84])
fc3.weight 	 torch.Size([10, 84])
fc3.bias 	 torch.Size([10])
fc4.weight 	 torch.Size([2, 10])
fc4.bias 	 torch.Size([2])


In [12]:
"""方式二：保存整个模型
依赖源码
"""
SAVE_PATH = './sd_net_whole.pth'

#保存模型
#torch.save(net, SAVE_PATH)
#save函数，没有保存类的代码（类的定义），只保存了类的数据

#加载模型
loaded_net = torch.load(SAVE_PATH)#要求类必须存在于上下文环境中，否则报错

#查看模型状态
print("Loaded Net's state_dict:")
print_state_dict(loaded_net)

Loaded Net's state_dict:
conv1.weight 	 torch.Size([6, 3, 5, 5])
conv1.bias 	 torch.Size([6])
conv2.weight 	 torch.Size([16, 6, 5, 5])
conv2.bias 	 torch.Size([16])
fc1.weight 	 torch.Size([120, 400])
fc1.bias 	 torch.Size([120])
fc2.weight 	 torch.Size([84, 120])
fc2.bias 	 torch.Size([84])
fc3.weight 	 torch.Size([10, 84])
fc3.bias 	 torch.Size([10])


In [13]:
#保存/加载多个模型

In [14]:
net = Net()
net2 = Net2()

SAVE_PATH = './sd_net_double.pth'

#保存模型pickle
torch.save({
    'net':net.state_dict(),
    'net2':net2.state_dict(),
}, SAVE_PATH)

#加载模型
sd_net = Net()
sd_net2 = Net2()

state_dict = torch.load(SAVE_PATH)
sd1, sd2 = state_dict['net'], state_dict['net2']

sd_net.load_state_dict(sd1, strict=False)
sd_net2.load_state_dict(sd2, strict=False)

#strict,允许加载部分参数（热加载）
sd_net.eval() #把模型的dropout，BN设置为评估状态
sd_net2.eval()

#查看模型状态
print("Loaded Net's state_dict:")
print_state_dict(sd_net)

print('')
print("Loaded Net2's state_dict:")
print_state_dict(sd_net2)


Loaded Net's state_dict:
conv1.weight 	 torch.Size([6, 3, 5, 5])
conv1.bias 	 torch.Size([6])
conv2.weight 	 torch.Size([16, 6, 5, 5])
conv2.bias 	 torch.Size([16])
fc1.weight 	 torch.Size([120, 400])
fc1.bias 	 torch.Size([120])
fc2.weight 	 torch.Size([84, 120])
fc2.bias 	 torch.Size([84])
fc3.weight 	 torch.Size([10, 84])
fc3.bias 	 torch.Size([10])
fc4.weight 	 torch.Size([2, 10])
fc4.bias 	 torch.Size([2])

Loaded Net2's state_dict:
fc1.weight 	 torch.Size([120, 400])
fc1.bias 	 torch.Size([120])


In [15]:
"""checkpoint
训练了10 epoch (checkpoint) -> +5 epoch
1.保存模型用于预测，是不需要保存优化器的状态字典
2.保存模型用于训练，此时需要保存优化器的状态字典
"""
SAVE_PATH = './sd_net_ckt.tar'
#保存checkpoint，后缀一般用tar
#保存模型，一般用pth
net = Net()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
optim_sd = optimizer.state_dict()

#保存checkpoint
torch.save({
    'net':net.state_dict(),
    'optim':optim_sd,
    #'opt_state':optim_sd['state']
    #'opt_param':optim_sd['param_groups'],
}, SAVE_PATH)

#加载checkpoint
state_dict = torch.load(SAVE_PATH)

#恢复模型
sd_net = Net()
sd_net.load_state_dict(state_dict['net'], strict=False)
sd_net.eval() #继续训练，不需要这一行

#恢复优化器
sd_optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
sd_optimizer.load_state_dict(state_dict['optim'])

print("Loaded Net's state_dict:")
print_state_dict(sd_net)

print('')
print("Loaded Optimizer's state_dict:")
print_state_dict(sd_optimizer, False)

Loaded Net's state_dict:
conv1.weight 	 torch.Size([6, 3, 5, 5])
conv1.bias 	 torch.Size([6])
conv2.weight 	 torch.Size([16, 6, 5, 5])
conv2.bias 	 torch.Size([16])
fc1.weight 	 torch.Size([120, 400])
fc1.bias 	 torch.Size([120])
fc2.weight 	 torch.Size([84, 120])
fc2.bias 	 torch.Size([84])
fc3.weight 	 torch.Size([10, 84])
fc3.bias 	 torch.Size([10])
fc4.weight 	 torch.Size([2, 10])
fc4.bias 	 torch.Size([2])

Loaded Optimizer's state_dict:
state 	 {}
param_groups 	 [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [4784160512, 4784160192, 4776702272, 4785081472, 4784159952, 4784160432, 4784159872, 4784177552, 4784179312, 4784180592, 4784179232, 4784179152]}]


In [16]:
"""device
1.模型在CPU上保存，在CPU上加载
2.模型在GPU上保存：
    2.1 在CPU上加载 ：设置torch.save的map_location='cpu'
    2.2 在GPU上加载 ：状态字典在GPU上，模型取决于实例化的模型所在的设备
"""

SAVE_PATH = './sd_net.pth'
device =torch.device('cuda:0')

#保存模型
net = Net()
net = net.to(device)
sd = torch.load(SAVE_PATH)
print('Save weight device= %s' % (sd['fc1.weight'].device))
print('Save net device= %s' % (net.fc1.weight.data.device))

torch.save(net.state_dict(), SAVE_PATH)

#加载模型
sd_net = Net()
sd = torch.load(SAVE_PATH, map_location='cpu')
print('Loaded weight device = %s' % (sd['fc1.weight'].device))
sd_net.load_state_dict(sd, strict=False)
print('Loaded net device = %s' % (sd_net.fc1.weight.data.device))



AssertionError: Torch not compiled with CUDA enabled