# 官方文档关于这部分内容很棒

https://pytorch.org/tutorials/beginner/saving_loading_models.html

In [1]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision

# Use the first CUDA device when CUDA is available; otherwise use the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")

In [2]:
from pprint import pprint
import matplotlib.pyplot as plt
import numpy as np
from IPython.core.debugger import set_trace

# `state_dict()`

## `nn.Module.state_dict()`
`nn.Module` 通过 `state_dict()` 暴露其 state dict：一个有序字典，把每一层的参数名（如 `conv1.weight`）映射到对应的参数张量

In [3]:
class Net(nn.Module):
    """Small convolutional classifier used to demonstrate ``state_dict()``.

    Two conv -> ReLU -> max-pool stages are followed by three fully
    connected layers; the last layer produces 10 outputs per sample.
    """

    def __init__(self):
        """Create the layers; attribute names become the state-dict key prefixes."""
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        """Run the forward pass.

        Args:
            x: Input batch with 3 channels. Presumably shaped (N, 3, 32, 32),
                since that spatial size yields the 16 * 5 * 5 features that
                ``fc1`` expects -- TODO confirm against the data pipeline.

        Returns:
            Tensor with 10 values per flattened sample (no softmax applied).
        """
        # Both convolutional stages share the same ReLU + pooling treatment.
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        # Flatten to (-1, 400); fc1.in_features equals 16 * 5 * 5.
        x = x.view(-1, self.fc1.in_features)
        for hidden in (self.fc1, self.fc2):
            x = F.relu(hidden(x))
        return self.fc3(x)

In [4]:
# Create a Net instance; no training has happened, so its parameters are
# whatever the layer constructors initialised them to.
model = Net()

In [5]:
# Display the module's full state dict (an ordered mapping of names to tensors).
model.state_dict()

OrderedDict([('conv1.weight',
              tensor([[[[-7.1123e-02, -4.4851e-02,  1.6729e-02, -3.4377e-02, -9.8778e-02],
                        [ 5.5455e-02,  7.6009e-03,  7.0884e-03, -1.0246e-01,  9.6086e-03],
                        [ 3.4145e-02,  8.4261e-02,  5.2151e-02,  1.0591e-01, -8.7947e-02],
                        [ 1.8519e-02, -1.0095e-01, -1.3092e-02,  1.3405e-02, -4.7419e-02],
                        [ 1.0840e-01,  2.5495e-02,  1.0014e-01, -8.9385e-03,  8.7777e-03]],
              
                       [[ 1.0578e-01,  7.4442e-02, -6.6563e-02,  1.0604e-01, -6.9550e-02],
                        [ 6.0601e-02,  1.6669e-02,  5.6252e-02,  8.8460e-02, -8.1655e-02],
                        [ 2.6876e-02,  3.4845e-02,  3.8986e-02, -8.5444e-02,  2.1472e-02],
                        [ 6.1246e-02, -9.9257e-02,  7.1392e-03,  7.8273e-02,  1.0704e-01],
                        [-1.0204e-01,  1.0887e-01, -3.5698e-02,  2.5086e-02,  2.8041e-02]],
              
                       [[ 1.

In [14]:
def state_dict_info(obj):
    """Print a two-column summary of ``obj.state_dict()``.

    Each entry is printed as its key, padded to a width of 25, followed by
    the value's ``shape`` attribute, or by the value itself when it has no
    ``shape`` attribute.

    Args:
        obj: Any object whose ``state_dict()`` returns a mapping (this
            notebook passes an ``nn.Module`` and an optimizer).
    """
    print(f"{'layer':25} shape")
    print("===================================================")
    for name, entry in obj.state_dict().items():
        # Values without `.shape` (e.g. plain dicts or lists) are shown as-is.
        summary = getattr(entry, "shape", entry)
        print(f"{name:25} {summary}")

In [9]:
# Summarise each entry of the model's state dict as name + tensor shape.
state_dict_info(model)

layer                     shape
conv1.weight              torch.Size([6, 3, 5, 5])
conv1.bias                torch.Size([6])
conv2.weight              torch.Size([16, 6, 5, 5])
conv2.bias                torch.Size([16])
fc1.weight                torch.Size([120, 400])
fc1.bias                  torch.Size([120])
fc2.weight                torch.Size([84, 120])
fc2.bias                  torch.Size([84])
fc3.weight                torch.Size([10, 84])
fc3.bias                  torch.Size([10])


## `torch.optim.Optimizer`

另一方面，优化器（`torch.optim` 中的 Optimizer）也有自己的 `state_dict`，其中包含 `state` 和 `param_groups` 两项。

In [16]:
# SGD over all of the model's parameters with a learning rate of 0.01.
optimizer = optim.SGD(model.parameters(), lr=0.01)

In [17]:
# The optimizer's entries have no `.shape`, so the helper prints their raw values.
state_dict_info(optimizer)

layer                     shape
state                     {}
param_groups              [{'lr': 0.01, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [5031464032, 5031464680, 5014780376, 5030253048, 4581119248, 5014737112, 5031534456, 5031531576, 5031530640, 5031502112]}]


In [12]:
# Look at the `state` entry of the optimizer's state dict.
optimizer.state_dict()["state"]

{}

In [13]:
# Look at the `param_groups` entry of the optimizer's state dict.
# TODO(review): re-run this cell -- the stored output lists 'rho' and 'eps'
# keys, which do not match the SGD optimizer defined above.
optimizer.state_dict()["param_groups"]

[{'lr': 1.0,
  'rho': 0.9,
  'eps': 1e-06,
  'weight_decay': 0,
  'params': [5031464032,
   5031464680,
   5014780376,
   5030253048,
   4581119248,
   5014737112,
   5031534456,
   5031531576,
   5031530640,
   5031502112]}]

## 存储和加载`state_dict`

In [13]:
# Save only the model's state dict (not the model object itself) to a
# relative path.
model_file = "model_state_dict.pt"
torch.save(model.state_dict(), model_file)

In [14]:
# Build a fresh Net and copy the saved parameters from disk into it.
# NOTE: this rebinds `model`; an optimizer created earlier from the previous
# instance's `model.parameters()` keeps referring to the old parameters.
# NOTE(review): torch.load relies on pickle (by default on older PyTorch
# releases -- confirm for the installed version); only load files you trust.
model = Net()
model.load_state_dict(torch.load(model_file))

<All keys matched successfully>

# 一个常用的存储方式：Checkpoint

例子:
```python
torch.save(
    {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epoch,
        'loss': loss,
    },
    PATH,
)
```

# Exercise
- `optimizer.state_dict()["state"]`为什么是`{}`？
- 自己写一个checkpoint函数，可以保存、加载checkpoint