# Saving and Loading Models

It is important to understand how to save a model so that we can reuse it without rerunning the entire training pipeline, which formats are available, and how PyTorch handles the process. We can also, for example, save the best state of a model during training so that we keep the best-performing weights.
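
As a minimal sketch of keeping the best state during training, the loop below saves the weights whenever the monitored loss improves. The toy data, the `best_model.pth` file name, and the use of the training loss as the selection criterion are assumptions made for illustration; a real pipeline would normally monitor a validation metric.

```py
import os
import tempfile

import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy regression data (illustrative only)
X = torch.randn(64, 6)
y = torch.randn(64, 1)

model = nn.Linear(6, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)

best_path = os.path.join(tempfile.mkdtemp(), "best_model.pth")  # hypothetical file name
best_loss = float("inf")

for epoch in range(20):
    # One training step
    optimizer.zero_grad()
    loss = criterion(model(X), y)
    loss.backward()
    optimizer.step()

    # Measure the loss of the updated weights without tracking gradients
    with torch.no_grad():
        current_loss = criterion(model(X), y).item()

    # Save only when the monitored loss improves
    if current_loss < best_loss:
        best_loss = current_loss
        torch.save(model.state_dict(), best_path)

# Restore the best weights seen during training
best_model = nn.Linear(6, 1)
best_model.load_state_dict(torch.load(best_path))
best_model.eval()
```

Only the `state_dict` is written here, so the file stays small and can be reloaded into any freshly constructed model with the same architecture.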

In [1]:
# Imports
import torch
import torch.nn as nn

There are three functions to know for saving and loading a model:

- `torch.save(arg, PATH)`
- `torch.load(PATH)`
- `model.load_state_dict(arg)`



`torch.save` accepts tensors, models, or dictionaries as the object to save. It uses Python's `pickle` module to serialize that object.
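
To illustrate the non-model cases, here is a small sketch that round-trips a tensor and a dictionary through `torch.save` and `torch.load`. The temporary directory and the file names `tensor.pt` and `bundle.pt` are assumptions chosen for the example.

```py
import os
import tempfile

import torch

tmp_dir = tempfile.mkdtemp()

# A single tensor
tensor = torch.arange(6, dtype=torch.float32).reshape(2, 3)
tensor_path = os.path.join(tmp_dir, "tensor.pt")  # hypothetical file name
torch.save(tensor, tensor_path)
loaded_tensor = torch.load(tensor_path)

# A dictionary mixing tensors and plain Python values
bundle = {"weights": torch.ones(3), "step": 10}
bundle_path = os.path.join(tmp_dir, "bundle.pt")  # hypothetical file name
torch.save(bundle, bundle_path)
loaded_bundle = torch.load(bundle_path)

# Both objects come back with the same contents
print(torch.equal(tensor, loaded_tensor))
print(loaded_bundle["step"])
```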

The lazy way is to save the entire model object:
```py
torch.save(model, PATH)
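# PyTorch 2.6+ defaults to weights_only=True, which refuses to unpickle a full model
model = torch.load(PATH, weights_only=False)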
model.eval()  # Evaluation mode
```
The disadvantage of this approach is that the serialized data is bound to the specific classes and the exact directory structure used when the model was saved, so loading can break after the code is refactored or moved to another project.
<br>
The recommended way is to save only the model's `state_dict`:
```py
torch.save(model.state_dict(), PATH)  # Saves only the parameters
# The model must be re-created with the same constructor arguments before loading
model = Model(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
```
Now let's see some examples in practice.

In [3]:
class Model(nn.Module):
    def __init__(self, n_input_features):
        super(Model, self).__init__()
        self.linear = nn.Linear(n_input_features, 1)
    
    def forward(self, x):
        y_pred = torch.sigmoid(self.linear(x))
        return y_pred

# This is a super simple example model
model = Model(n_input_features=6)
# Training loop would go here
# ...

FILE = "model.pth"
torch.save(model, FILE)  # Lazy approach: pickles the whole model object

In [4]:
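# Load the whole model
# (PyTorch 2.6+ defaults to weights_only=True, which refuses to unpickle a full model)
model = torch.load(FILE, weights_only=False)
model.eval()  # Evaluation mode

# Inspect the loaded parameters
for param in model.parameters():
    print(param)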

Parameter containing:
tensor([[-0.3934, -0.3836,  0.0619, -0.1382, -0.0944, -0.3254]],
       requires_grad=True)
Parameter containing:
tensor([-0.2454], requires_grad=True)


In [5]:
# Recommended approach: save only the state dict
torch.save(model.state_dict(), FILE)

In [6]:
# Load from state dict
loaded_model = Model(n_input_features=6)  # Reinstantiate
loaded_model.load_state_dict(torch.load(FILE))
loaded_model.eval()  # Evaluation mode

# Inspect the parameters of the reloaded model
for param in loaded_model.parameters():
    print(param)

Parameter containing:
tensor([[-0.3934, -0.3836,  0.0619, -0.1382, -0.0944, -0.3254]],
       requires_grad=True)
Parameter containing:
tensor([-0.2454], requires_grad=True)


In [7]:
# The state dict maps each parameter name to its tensor
print(model.state_dict())

OrderedDict([('linear.weight', tensor([[-0.3934, -0.3836,  0.0619, -0.1382, -0.0944, -0.3254]])), ('linear.bias', tensor([-0.2454]))])
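
Because a state dict is an ordinary ordered mapping from parameter names to tensors, it can be inspected like any dictionary. The sketch below uses a stand-alone `nn.Linear` layer instead of the `Model` class so that it runs on its own, which is why the keys are `weight` and `bias` rather than `linear.weight` and `linear.bias`. It also shows that `load_state_dict` is strict by default and rejects a dictionary with a missing key.

```py
import torch
import torch.nn as nn

layer = nn.Linear(6, 1)
state = layer.state_dict()

# Each entry is a parameter name mapped to a tensor
for name, tensor in state.items():
    print(name, tuple(tensor.shape))

# load_state_dict is strict by default: a missing key raises a RuntimeError
incomplete = {"weight": state["weight"]}  # "bias" deliberately left out
try:
    layer.load_state_dict(incomplete)
    strict_failed = False
except RuntimeError as err:
    strict_failed = True
    print("strict load failed:", type(err).__name__)
```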


In [9]:
learning_rate = 0.01
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
# Checkpoints let us stop training and resume it later
# A checkpoint is a dict holding everything needed to resume
checkpoint = {
    "epoch": 90,  # Example
    "model_state": model.state_dict(),
    "optim_state": optimizer.state_dict(),
}

torch.save(checkpoint, "checkpoint.pth")

# Load checkpoint
loaded_checkpoint = torch.load("checkpoint.pth")
epoch = loaded_checkpoint["epoch"]

# Create a new model and optimizer to continue training
model = Model(n_input_features=6)
optimizer = torch.optim.SGD(model.parameters(), lr=0)  # Placeholder lr; load_state_dict restores the saved value

# Restore from the checkpoint that was loaded from disk
model.load_state_dict(loaded_checkpoint["model_state"])
optimizer.load_state_dict(loaded_checkpoint["optim_state"])

# Training can now continue from the epoch stored in the checkpoint
# Note: the device matters when loading (CPU vs. GPU). Check the docs.
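
On the device note: `torch.load` accepts a `map_location` argument that remaps the saved tensors onto a chosen device while loading. The sketch below saves a checkpoint, reloads it onto whichever device is available, and runs two more epochs. The file name `resume.pth`, the stand-alone `nn.Linear` model, the random batches, and the choice to resume at `epoch + 1` are assumptions made for illustration.

```py
import os
import tempfile

import torch
import torch.nn as nn

path = os.path.join(tempfile.mkdtemp(), "resume.pth")  # hypothetical file name

# Save a checkpoint from a model that lives on the CPU
model = nn.Linear(6, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
torch.save(
    {
        "epoch": 90,
        "model_state": model.state_dict(),
        "optim_state": optimizer.state_dict(),
    },
    path,
)

# Pick the target device and remap the saved tensors onto it while loading
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
checkpoint = torch.load(path, map_location=device)

resumed = nn.Linear(6, 1).to(device)
resumed.load_state_dict(checkpoint["model_state"])
resumed_optim = torch.optim.SGD(resumed.parameters(), lr=0)  # placeholder lr
resumed_optim.load_state_dict(checkpoint["optim_state"])  # restores lr=0.01

# Assumes the stored epoch was completed, so training resumes at the next one
start_epoch = checkpoint["epoch"] + 1
resumed.train()  # training mode, since we are continuing to train
for epoch in range(start_epoch, start_epoch + 2):
    inputs = torch.randn(8, 6, device=device)   # random stand-in batch
    targets = torch.randn(8, 1, device=device)
    resumed_optim.zero_grad()
    loss = nn.functional.mse_loss(resumed(inputs), targets)
    loss.backward()
    resumed_optim.step()
```

Moving the model to the device before creating the optimizer keeps the optimizer's parameter references and the restored state on the same device.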