https://pytorch.org/tutorials/beginner/saving_loading_models.html

In [1]:
import torch
import torch.nn as nn

In [2]:
class Model(nn.Module):
    """Minimal binary classifier: one linear layer followed by a sigmoid."""

    def __init__(self, n_input):
        """Build the single linear layer mapping ``n_input`` features to 1 output."""
        super().__init__()
        self.linear = nn.Linear(in_features=n_input, out_features=1)

    def forward(self, x):
        """Apply the linear layer to ``x`` and squash the result with a sigmoid."""
        logits = self.linear(x)
        return torch.sigmoid(logits)

In [16]:
# Instantiate the classifier for inputs with six features.
model = Model(n_input=6)

##   ***Lazy way to Save & Load model***

In [4]:
# "Lazy" approach: serialize the entire module object, not just its weights.
filename = 'model.pth'
torch.save(obj=model, f=filename)

In [15]:
# The file contains the whole pickled Model object, not only tensors, so it
# cannot be read in weights-only mode (the torch.load default since PyTorch 2.6);
# request full unpickling explicitly.
# NOTE: full unpickling can execute arbitrary code -- only load files you trust.
my_model = torch.load(filename, weights_only=False)
# Switch to evaluation mode; as the last expression this also displays the module.
my_model.eval()

Model(
  (linear): Linear(in_features=6, out_features=1, bias=True)
)

In [13]:
# Print each learnable tensor of `model` (here: the linear layer's weight, then its bias).
for param in model.parameters():
    print(param)

Parameter containing:
tensor([[ 0.0718,  0.2085,  0.2047, -0.0131, -0.0875,  0.3444]],
       requires_grad=True)
Parameter containing:
tensor([-0.1778], requires_grad=True)


##  ***Preferred way to Save & Load model***

In [18]:
# Save only the model's state_dict (its parameter tensors), not the whole object.
filename = 'prefer_model.pth'
torch.save(obj=model.state_dict(), f=filename)

In [24]:
# Rebuild the architecture first, then copy the saved weights into it.
model = Model(n_input=6)
model.load_state_dict(state_dict=torch.load(f=filename))
# Put the restored model into evaluation mode (also displays the module).
model.eval()

Model(
  (linear): Linear(in_features=6, out_features=1, bias=True)
)

#### ***State_dict and Model checkpoint***

In [22]:
# Create a fresh model and an SGD optimizer over its parameters,
# then display the optimizer's state_dict.
model = Model(n_input=6)
lr = 0.01
optimizer = torch.optim.SGD(params=model.parameters(), lr=lr)
optimizer.state_dict()

{'state': {},
 'param_groups': [{'lr': 0.01,
   'momentum': 0,
   'dampening': 0,
   'weight_decay': 0,
   'nesterov': False,
   'params': [0, 1]}]}

In [26]:
# Bundle the epoch counter with the model and optimizer state_dicts in one
# dictionary and write it to a single checkpoint file.
checkpoint = {
    'epoch': 90,
    'model_state': model.state_dict(),
    'optim_state': optimizer.state_dict(),
}
torch.save(obj=checkpoint, f='checkpoint_model.pth')

In [33]:
# Restore a training checkpoint: epoch counter, model weights, optimizer state.
loaded_checkpoint = torch.load('checkpoint_model.pth')
epoch = loaded_checkpoint['epoch']

# Rebuild the model and copy the saved weights into it.
model = Model(n_input=6)
model.load_state_dict(loaded_checkpoint['model_state'])

# The optimizer must be re-created over the NEW model's parameters; reusing the
# previous optimizer would leave it bound to the old model's tensors, so it
# would never update the restored model. lr=0 is only a placeholder -- the real
# hyperparameters are overwritten by the saved state loaded on the next line.
optimizer = torch.optim.SGD(model.parameters(), lr=0)
optimizer.load_state_dict(loaded_checkpoint['optim_state'])

In [47]:
# Take the first entry of the optimizer's 'param_groups' list.
n = optimizer.state_dict()['param_groups'][0]

# Display the learning rate stored in that group.
n['lr']

0.01