<a href="https://colab.research.google.com/github/dongminkim0220/pytorch_tutorial/blob/master/Save_and_Load_Models.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Save and Load Models

PyTorch Tutorial 17 - Saving and Loading Models

https://www.youtube.com/watch?v=9L9jEOwRrCg&t=45s

In [1]:
import torch
import torch.nn as nn

In [2]:
class Model(nn.Module):
    """Minimal network: one linear layer followed by a sigmoid.

    Args:
        n_input_features: size of the last dimension of the input tensor.
    """

    def __init__(self, n_input_features):
        super().__init__()
        # The attribute name is the prefix of the state-dict keys
        # ("linear.weight", "linear.bias"), so it must stay "linear".
        self.linear = nn.Linear(in_features=n_input_features, out_features=1)

    def forward(self, x):
        """Return ``sigmoid(linear(x))`` for the input tensor ``x``."""
        logits = self.linear(x)
        return torch.sigmoid(logits)

# File name used by the save/load cells of the first two sections.
FILE = "simple_model.pth"
# Six input features -> a 1x6 weight matrix plus a single bias.
model = Model(n_input_features=6)

## Complete Model

In [3]:
# Serialize the whole model object, not just its parameters.
torch.save(obj=model, f=FILE)

In [4]:
# The file holds a pickled Model object, so it needs full unpickling.
# PyTorch >= 2.6 defaults torch.load to weights_only=True, which refuses
# arbitrary classes such as Model; opt out explicitly (the keyword exists
# since PyTorch 1.13). Only do this for files you created yourself:
# unpickling can execute arbitrary code.
model = torch.load(FILE, weights_only=False)
# Switch to inference mode; eval() returns the module, which the cell displays.
model.eval()

Model(
  (linear): Linear(in_features=6, out_features=1, bias=True)
)

In [5]:
# Show each learnable tensor of the model in turn.
for parameter in model.parameters():
    print(parameter)

Parameter containing:
tensor([[ 0.2178, -0.0961, -0.2846,  0.2673, -0.1091, -0.2597]],
       requires_grad=True)
Parameter containing:
tensor([0.1104], requires_grad=True)


## State Dict

In [6]:
# Persist only the parameter dictionary rather than the whole model object.
torch.save(obj=model.state_dict(), f=FILE)

In [7]:
# A state dict contains only parameter tensors, so rebuild the architecture
# first and then copy the saved parameters into it.
loaded_model = Model(n_input_features=6)
loaded_model.load_state_dict(torch.load(FILE))
# Put the *loaded* model (not the original one) into inference mode;
# eval() returns the module, which the cell displays.
loaded_model.eval()

Model(
  (linear): Linear(in_features=6, out_features=1, bias=True)
)

In [8]:
# Inspect the model that was rebuilt from the file, so the printed values
# actually confirm that the state dict was loaded correctly.
print(loaded_model.state_dict())

OrderedDict([('linear.weight', tensor([[ 0.2178, -0.0961, -0.2846,  0.2673, -0.1091, -0.2597]])), ('linear.bias', tensor([0.1104]))])


## Checkpointing

In [9]:
learning_rate = 0.01
optimizer = torch.optim.SGD(params=model.parameters(), lr=learning_rate)
# No step has been taken yet, so only the hyperparameter groups are populated.
print(optimizer.state_dict())

{'state': {}, 'param_groups': [{'lr': 0.01, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [0, 1]}]}


In [10]:
# Bundle the epoch counter with the model and optimizer states in one dict.
checkpoint = dict(
    epoch=90,  # hard-coded example value; no training loop runs in this notebook
    model_state=model.state_dict(),
    optim_state=optimizer.state_dict(),
)

In [11]:
# Write the checkpoint dictionary to disk.
torch.save(obj=checkpoint, f="checkpoint.pth")

In [12]:
# Read the checkpoint dictionary back from disk.
loaded_checkpoint = torch.load(f="checkpoint.pth")

In [13]:
# Restore everything from the dictionary read back from disk
# (loaded_checkpoint), not from the in-memory dict it was created from, so
# the cell really exercises the saved file.
epoch = loaded_checkpoint["epoch"]
model = Model(n_input_features=6)
# lr=0 is only a placeholder: load_state_dict() below overwrites the
# param-group hyperparameters with the saved ones (lr becomes 0.01 again).
optimizer = torch.optim.SGD(model.parameters(), lr=0)
model.load_state_dict(loaded_checkpoint["model_state"])
optimizer.load_state_dict(loaded_checkpoint["optim_state"])

In [14]:
# Display the optimizer state; lr should show the saved 0.01, not the
# placeholder 0 passed to the constructor.
optimizer.state_dict()

{'param_groups': [{'dampening': 0,
   'lr': 0.01,
   'momentum': 0,
   'nesterov': False,
   'params': [0, 1],
   'weight_decay': 0}],
 'state': {}}