### torch.save (general checkpoint)

* quick intro (not that useful): https://pytorch.org/tutorials/beginner/basics/saveloadrun_tutorial.html
* full tutorial (covers general checkpoints and more): https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_a_general_checkpoint.html

* what this tutorial is based on: https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-multiple-models-in-one-file


Let's follow the second one!

# 1. loading and saving models
(prereq: review `OrderedDict` first; a notebook summarizing it is already in the same directory)

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F  # needed for F.relu in the model's forward()
import torch.optim as optim
import os
import torchsummary

## 1.1. basic saving and loading
**PyTorch stores parameter values in an internal state dictionary called `state_dict`!**

#### What is a `state_dict`?
* A dictionary holding the state of a model or optimizer (parameter values and so on); for a model it is an **ordered dictionary**!

<br>

#### So we just need to save and load this state_dict! Here is how:
* Save: `torch.save(model.state_dict(), PATH)`
* Load: `model = Model()`, **then** `model.load_state_dict(torch.load(PATH))`
    * That is, **first create** a model instance, then load the weights onto it!


    * This `state_dict` can be saved with `torch.save`!

(The entire model can apparently also be saved via pickle, but I skip that here, since saving the `state_dict()` is reportedly what is done in almost all cases.)


In [2]:
# Define model
class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Initialize model
model = TheModelClass()
# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

### 1.1 Inspecting what a `state_dict` is

In [3]:
print(type(model.state_dict()))  # i.e., the state_dict is an OrderedDict

print("Model's state_dict : ")
for key in model.state_dict():
    print(key, '\t', model.state_dict()[key].shape)

# i.e., model.state_dict() is an OrderedDict with the parameter names
# as the keys and the parameter tensors as the values
    
print("\nOptimizer's state_dict : ")
for key in optimizer.state_dict():
    print(key, '\t', optimizer.state_dict()[key])

<class 'collections.OrderedDict'>
Model's state_dict : 
conv1.weight 	 torch.Size([6, 3, 5, 5])
conv1.bias 	 torch.Size([6])
conv2.weight 	 torch.Size([16, 6, 5, 5])
conv2.bias 	 torch.Size([16])
fc1.weight 	 torch.Size([120, 400])
fc1.bias 	 torch.Size([120])
fc2.weight 	 torch.Size([84, 120])
fc2.bias 	 torch.Size([84])
fc3.weight 	 torch.Size([10, 84])
fc3.bias 	 torch.Size([10])

Optimizer's state_dict : 
state 	 {}
param_groups 	 [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'maximize': False, 'foreach': None, 'params': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}]
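
The empty `state {}` in the output above is expected, because the optimizer has not taken a step yet. The following is a minimal sketch, assuming a hypothetical two-parameter `nn.Linear` model and random dummy data (neither is part of this notebook), of how the optimizer's `state` fills in after one step. This per-parameter state is what gets lost when only the model weights are saved.

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Hypothetical tiny model, used only to keep the example short.
tiny_model = nn.Linear(4, 2)
tiny_optimizer = optim.SGD(tiny_model.parameters(), lr=0.001, momentum=0.9)

# Before any step, the optimizer holds no per-parameter state.
state_before = tiny_optimizer.state_dict()["state"]
print(len(state_before))

# One dummy training step on random data.
loss = tiny_model(torch.randn(8, 4)).sum()
tiny_optimizer.zero_grad()
loss.backward()
tiny_optimizer.step()

# After the step, each parameter (weight and bias) has its own entry,
# keyed by the parameter index used in param_groups['params'].
state_after = tiny_optimizer.state_dict()["state"]
print(len(state_after))
print(list(state_after[0].keys()))
```

With momentum enabled, SGD keeps a buffer per parameter, so resuming training without this state would restart the momentum from scratch.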


### 1.2 save/load only the model (not a "general" checkpoint) 
* Save the `state_dict()`
* This is only useful for inference, since the optimizer state and so on are not saved

In [4]:
## saving
torch.save(model.state_dict(), "./model_ckpt.pth")

## loading
model = TheModelClass()  # the model must be instantiated first!
imported_state_dict = torch.load('./model_ckpt.pth')  # load the saved state_dict from disk
model.load_state_dict(imported_state_dict)            # copy the loaded weights into the model
model.eval()  # switch to evaluation mode before running inference

TheModelClass(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)
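
As a sanity check, the sketch below verifies that a save/load round trip restores the weights exactly. It assumes a small stand-in `nn.Linear` model and a temporary directory instead of `TheModelClass` and `./model_ckpt.pth`, purely to keep the example self-contained.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical stand-in model; any nn.Module is handled the same way.
original = nn.Linear(4, 2)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "model_ckpt.pth")
    torch.save(original.state_dict(), path)

    # A fresh instance starts with its own random weights ...
    restored = nn.Linear(4, 2)
    # ... until the saved state_dict is loaded into it.
    restored.load_state_dict(torch.load(path))

restored.eval()

# Compare every tensor in the two state_dicts, key by key.
all_equal = all(
    torch.equal(original.state_dict()[key], restored.state_dict()[key])
    for key in original.state_dict()
)
print(all_equal)
```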

###  1.3. General Checkpoint
If the goal is to resume training, the optimizer's state dict, the current epoch, and so on must be saved as well! Let's see how to do this.

* Call `torch.save(<the dictionary>, PATH)`, with the model and optimizer `state_dict`s, the epoch value, and so on placed inside `<the dictionary>`
* After loading the dictionary with `ckpt = torch.load(XX)`, retrieve each item (the model's state_dict and so on) one by one as `ckpt['XX']`

<br>

#### **In other words, `torch.save` can save any dictionary, not just a `state_dict`!**

In [5]:
###SAVING THE GENERAL CHECKPOINT
from collections import OrderedDict
save_path = './checkpoint.pth'
dict_to_save = OrderedDict({'epoch': 10, 
                            'model_state_dict':model.state_dict(),
                            'optimizer_state_dict': optimizer.state_dict(),
                            'loss' : 0.3 })
# a value in the dictionary can itself be a dictionary!
type(dict_to_save['model_state_dict'])

torch.save(dict_to_save, save_path)

In [6]:
###LOADING THE GENERAL CHECKPOINT
model = TheModelClass()
optimizer = optim.SGD(params = model.parameters(), lr = 0.01)

checkpoint = torch.load(save_path)    # load the checkpoint that holds all the state_dicts and other values
print(checkpoint.keys())  # the same keys that were saved above


## now let's load the model's and optimizer's state_dicts and the other values
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

print(epoch, loss )

odict_keys(['epoch', 'model_state_dict', 'optimizer_state_dict', 'loss'])
10 0.3
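
What to do with the restored `epoch` is left open above. One possible approach, sketched here under assumptions, is to restart the training loop at `epoch + 1`. The `train_one_epoch` helper, the stand-in `nn.Linear` model, the names `net` and `sgd`, and the dummy data are all hypothetical and not part of the original tutorial.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim


def train_one_epoch(net, sgd, inputs, targets):
    """Hypothetical helper: one full-batch step, returns the loss as a float."""
    sgd.zero_grad()
    loss = nn.functional.mse_loss(net(inputs), targets)
    loss.backward()
    sgd.step()
    return loss.item()


torch.manual_seed(0)
inputs, targets = torch.randn(16, 4), torch.randn(16, 2)  # dummy data

with tempfile.TemporaryDirectory() as tmp_dir:
    ckpt_path = os.path.join(tmp_dir, "checkpoint.pth")

    # First run: train for 3 epochs (0, 1, 2), then save a general checkpoint.
    net = nn.Linear(4, 2)
    sgd = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
    for epoch in range(3):
        loss = train_one_epoch(net, sgd, inputs, targets)
    torch.save({"epoch": epoch,
                "model_state_dict": net.state_dict(),
                "optimizer_state_dict": sgd.state_dict(),
                "loss": loss}, ckpt_path)

    # Second run: rebuild the objects first, then restore their state.
    net = nn.Linear(4, 2)
    sgd = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
    checkpoint = torch.load(ckpt_path)
    net.load_state_dict(checkpoint["model_state_dict"])
    sgd.load_state_dict(checkpoint["optimizer_state_dict"])
    start_epoch = checkpoint["epoch"] + 1  # continue after the last finished epoch

net.train()  # make sure the model is in training mode before resuming
epochs_run = []
for epoch in range(start_epoch, 5):
    loss = train_one_epoch(net, sgd, inputs, targets)
    epochs_run.append(epoch)
print(epochs_run)
```

Storing the index of the last finished epoch and resuming at the next one is only one convention; saving the next epoch to run works equally well as long as saving and loading agree.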


## 1.4. DDP checkpoints and the like also exist, but are skipped here
(https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-multiple-models-in-one-file)
