In [6]:
import torch
import torch.nn as nn

import warnings
# Blanket filter: no category or module is given, so every warning raised
# after this line is suppressed for the rest of the session.
warnings.filterwarnings("ignore")

THREE DIFFERENT METHODS TO REMEMBER:
 - `torch.save(arg, PATH)`: `arg` can be a model, a tensor, or a dictionary
 - `torch.load(PATH)`
 - `model.load_state_dict(arg)`: a method of the model (an `nn.Module`), not a function in the `torch` namespace


TWO DIFFERENT WAYS OF SAVING

1) lazy way: save whole model

`torch.save(model, PATH)`

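The model class must be defined (or importable) in the code that loads the file, because the saved file stores only a reference to the class, not its source code.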

```python
model = torch.load(PATH)
model.eval()
```

2) recommended way: save only the state_dict

`torch.save(model.state_dict(), PATH)`

The model must be instantiated again with the same constructor arguments before the saved parameters are loaded into it.

```python
model = Model(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
```

In [2]:
# Create a model class
class Model(nn.Module):
    """Single linear layer followed by a sigmoid.

    Args:
        n_input_features: size of the last dimension of the input tensor.
    """

    def __init__(self, n_input_features):
        super().__init__()
        # The attribute name determines the state_dict keys
        # ('linear.weight', 'linear.bias').
        self.linear = nn.Linear(n_input_features, 1)

    def forward(self, x):
        """Apply the linear layer to `x` and squash the result with a sigmoid."""
        logits = self.linear(x)
        return torch.sigmoid(logits)

In [3]:
# Build the network that the save/load examples below operate on.
N_INPUT_FEATURES = 6
model = Model(n_input_features=N_INPUT_FEATURES)
# train your model...

### Save all

In [9]:
# Show the current parameter tensors so they can be compared after reloading.
for tensor in model.parameters():
    print(tensor)


Parameter containing:
tensor([[ 0.1238, -0.3248, -0.3635,  0.1627, -0.4055, -0.1760]],
       requires_grad=True)
Parameter containing:
tensor([0.0947], requires_grad=True)


In [7]:
# Serialize the entire model object to disk.

FILE = "model.pth"
torch.save(obj=model, f=FILE)

In [8]:
# Load the whole model back and print its parameters again.
# weights_only=False is required to unpickle a full nn.Module: since
# PyTorch 2.6 torch.load defaults to weights_only=True, which only accepts
# tensors and plain containers and rejects custom classes such as Model.
# Full unpickling can execute arbitrary code, so only do this with files
# you created yourself. (The argument exists from PyTorch 1.13 onwards.)
loaded_model = torch.load(FILE, weights_only=False)
loaded_model.eval()

for param in loaded_model.parameters():
    print(param)

Parameter containing:
tensor([[ 0.1238, -0.3248, -0.3635,  0.1627, -0.4055, -0.1760]],
       requires_grad=True)
Parameter containing:
tensor([0.0947], requires_grad=True)


### Save only state dict

In [10]:
# Save only the state dict (parameter tensors keyed by name).
# Use a dedicated file name: writing to "model.pth" again would replace the
# whole-model pickle saved earlier with a plain dictionary, and re-running
# the whole-model load cell would then fail on `.eval()`.
FILE = "model_state_dict.pth"
torch.save(model.state_dict(), FILE)

print(model.state_dict())

OrderedDict([('linear.weight', tensor([[ 0.1238, -0.3248, -0.3635,  0.1627, -0.4055, -0.1760]])), ('linear.bias', tensor([0.0947]))])


In [11]:
# The file holds only parameters, so an instance of the model must exist first.
loaded_model = Model(n_input_features=6)
# load_state_dict expects the deserialized dictionary, not the file path.
saved_state = torch.load(FILE)
loaded_model.load_state_dict(saved_state)
loaded_model.eval()

print(loaded_model.state_dict())


OrderedDict([('linear.weight', tensor([[ 0.1238, -0.3248, -0.3635,  0.1627, -0.4055, -0.1760]])), ('linear.bias', tensor([0.0947]))])


### Load from a checkpoint

In [12]:
# Optimizer setup: plain SGD over the model's parameters.
learning_rate = 0.01
optimizer = torch.optim.SGD(params=model.parameters(), lr=learning_rate)

In [13]:
# Bundle the epoch counter with the model and optimizer state dicts
# into a single checkpoint dictionary.
model_state = model.state_dict()
optim_state = optimizer.state_dict()
checkpoint = {"epoch": 90, "model_state": model_state, "optim_state": optim_state}
print(checkpoint)

{'epoch': 90, 'model_state': OrderedDict([('linear.weight', tensor([[ 0.1238, -0.3248, -0.3635,  0.1627, -0.4055, -0.1760]])), ('linear.bias', tensor([0.0947]))]), 'optim_state': {'state': {}, 'param_groups': [{'lr': 0.01, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [6520825168, 6520825408]}]}}


In [14]:
# Write the checkpoint dictionary to its own file.

FILE = "checkpoint.pth"
torch.save(obj=checkpoint, f=FILE)

In [16]:
# Define a fresh model and optimizer to load the checkpoint into.
# lr=0 is only a placeholder: the optimizer state loaded from the checkpoint
# in the next cell replaces it with the saved learning rate.

model = Model(n_input_features=6)
optimizer = torch.optim.SGD(model.parameters(), lr=0)

In [19]:
# Restore the epoch counter, model parameters and optimizer state
# from the checkpoint file.

checkpoint = torch.load(FILE)
epoch = checkpoint['epoch']
model.load_state_dict(checkpoint['model_state'])
optimizer.load_state_dict(checkpoint['optim_state'])

# Evaluation mode here; call model.train() instead when resuming training.
model.eval()
print(optimizer.state_dict())

{'state': {}, 'param_groups': [{'lr': 0.01, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [6539766992, 6539771984]}]}
