# SAVING AND LOADING MODELS

This document provides solutions to a variety of use cases regarding the saving and loading of PyTorch models. Feel free to read the whole document, or just skip to the code you need for a desired use case.

When it comes to saving and loading models, there are three core functions to be familiar with:

- `torch.save`: Saves a serialized object to disk. This function uses Python’s `pickle` utility for serialization. Models, tensors, and dictionaries of all kinds of objects can be saved using this function.
- `torch.load`: Uses `pickle`’s unpickling facilities to deserialize pickled object files to memory. This function also lets you choose the device to load the data onto, via its `map_location` argument (see Saving & Loading Model Across Devices).
- `torch.nn.Module.load_state_dict`: Loads a model’s parameter dictionary using a deserialized `state_dict`. For more information on `state_dict`, see the What is a state_dict? section.
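To see the three functions working together, here is a minimal sketch that round-trips a `state_dict` through an in-memory `io.BytesIO` buffer rather than a file. It assumes PyTorch 1.13 or newer (for the `weights_only` argument of `torch.load`), and the small `nn.Linear` model is only a stand-in for a real network.

```python
import io

import torch
import torch.nn as nn

# Stand-in model; any nn.Module behaves the same way
model = nn.Linear(4, 2)

# torch.save: pickle the state_dict into an in-memory buffer instead of a file
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)

# torch.load: unpickle it again; map_location chooses the device for the tensors
buffer.seek(0)
state_dict = torch.load(buffer, map_location="cpu", weights_only=True)

# load_state_dict: copy the deserialized tensors into a freshly constructed model
restored = nn.Linear(4, 2)
restored.load_state_dict(state_dict)

print(torch.equal(model.weight, restored.weight))  # True
```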

## Contents:

- What is a state_dict?
- Saving & Loading Model for Inference
- Saving & Loading a General Checkpoint
- Saving Multiple Models in One File
- Warmstarting Model Using Parameters from a Different Model
- Saving & Loading Model Across Devices

## What is a state_dict?

In PyTorch, the learnable parameters (i.e. weights and biases) of a `torch.nn.Module` model are contained in the model’s parameters (accessed with `model.parameters()`). A `state_dict` is simply a Python dictionary object that maps each parameter or buffer name (such as `conv1.weight`) to its tensor. Note that only layers with learnable parameters (convolutional layers, linear layers, etc.) and registered buffers (e.g., batchnorm’s `running_mean`) have entries in the model’s `state_dict`. Optimizer objects (`torch.optim`) also have a `state_dict`, which contains information about the optimizer’s state, as well as the hyperparameters used.
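The following sketch (a throwaway `nn.Sequential`, not the tutorial's model) shows which modules contribute entries: the convolution contributes its weight and bias, the batch-norm layer adds its running statistics as buffers, and the parameter-free ReLU contributes nothing.

```python
import torch.nn as nn

# Conv2d has parameters, BatchNorm2d has parameters *and* buffers,
# ReLU has neither, so it contributes no keys
net = nn.Sequential(
    nn.Conv2d(3, 6, 5),
    nn.BatchNorm2d(6),
    nn.ReLU(),
)

keys = list(net.state_dict().keys())
print(keys)
# ['0.weight', '0.bias', '1.weight', '1.bias',
#  '1.running_mean', '1.running_var', '1.num_batches_tracked']

# Buffers appear in the state_dict but are not returned by parameters()
n_params = len(list(net.parameters()))  # 4 tensors
n_entries = len(net.state_dict())       # 7 entries
```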

Because `state_dict` objects are Python dictionaries, they can be easily saved, updated, altered, and restored, adding a great deal of modularity to PyTorch models and optimizers.
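As a rough illustration of that modularity, the sketch below edits a `state_dict` like an ordinary dictionary and loads it back. Zeroing the bias is an arbitrary edit chosen only for demonstration, and `twin` is a hypothetical second model with the same architecture.

```python
import copy

import torch
import torch.nn as nn

model = nn.Linear(3, 1)
original_weight = model.weight.detach().clone()

# state_dict() returns tensors that share storage with the live parameters,
# so take a deep copy before editing to keep the model untouched until we load
sd = copy.deepcopy(model.state_dict())

# Alter it like any ordinary dictionary: here, replace the bias with zeros
sd["bias"] = torch.zeros_like(sd["bias"])

# Restore the edited dictionary; load_state_dict copies the values in
model.load_state_dict(sd)

# The same dictionary can also initialize a second model of the same shape
twin = nn.Linear(3, 1)
twin.load_state_dict(sd)

print(model.bias.detach())                     # tensor([0.])
print(torch.equal(twin.weight, model.weight))  # True
```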

## Example:

Let’s take a look at the `state_dict` from the simple model used in the Training a classifier tutorial.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
```

```python
# Define model
class TheModelClass(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


# Initialize model
model = TheModelClass()

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])
```

```
Model's state_dict:
conv1.weight 	 torch.Size([6, 3, 5, 5])
conv1.bias 	 torch.Size([6])
conv2.weight 	 torch.Size([16, 6, 5, 5])
conv2.bias 	 torch.Size([16])
fc1.weight 	 torch.Size([120, 400])
fc1.bias 	 torch.Size([120])
fc2.weight 	 torch.Size([84, 120])
fc2.bias 	 torch.Size([84])
fc3.weight 	 torch.Size([10, 84])
fc3.bias 	 torch.Size([10])
Optimizer's state_dict:
state 	 {}
param_groups 	 [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [140158265378928, 140158265379720, 140158265378568, 140158265378712, 140158265378496, 140158265378280, 140158265378064, 140158265377848, 140158265378856, 140158265379216]}]
```
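The optimizer's `state` entry above is empty because no optimization step has run yet. The sketch below, using a hypothetical one-layer model and random data, takes one SGD step and inspects the result. On recent PyTorch versions the `params` list holds integer indices rather than the large object ids in the output above, which presumably came from an older release; treat the exact contents as version-dependent.

```python
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(4, 2)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Before any step, the per-parameter state is empty
empty_state = optimizer.state_dict()["state"]
print(empty_state)  # {}

# One dummy training step on random data, purely illustrative
loss = model(torch.randn(8, 4)).sum()
loss.backward()
optimizer.step()

sd = optimizer.state_dict()
# Parameters are referred to by their position in the param group
print(sd["param_groups"][0]["params"])  # [0, 1]
# With momentum, SGD stores a momentum buffer for each parameter
print(sorted(sd["state"][0].keys()))    # ['momentum_buffer']
```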


```python
model.state_dict()
```

```
OrderedDict([('conv1.weight',
              tensor([[[[-4.2429e-02,  5.9782e-02,  5.0412e-02,  1.8451e-02,  2.6615e-02],
                        [ 9.4491e-02, -7.4326e-02, -3.4779e-03,  9.1112e-02, -5.9225e-02],
                        [ 9.0978e-02, -6.1702e-03,  5.4920e-02, -5.3032e-03,  6.6603e-02],
                        [-4.8112e-03, -3.5489e-02, -1.1200e-01,  3.2576e-02, -9.5213e-02],
                        [ 1.0924e-01,  6.0533e-02,  1.0165e-01, -8.4772e-03, -4.2743e-02]],

                       [[-3.8417e-02, -4.2031e-02, -6.0666e-03, -8.6070e-02, -9.4199e-02],
                        [ 1.1180e-01,  1.1926e-02, -7.3397e-02, -9.8471e-02,  3.5933e-02],
                        [ 5.8302e-02, -7.2425e-02, -9.3167e-02, -7.1973e-02, -5.0978e-02],
                        [ 1.1384e-01,  7.0082e-02, -6.8690e-02,  6.5220e-02,  8.0614e-02],
                        [ 7.5960e-02,  5.6229e-02, -9.1138e-03,  1.1053e-01,  2.2933e-02]],

                       ...
```

## Saving & Loading Model for Inference

### Save/Load `state_dict` (Recommended)

#### Save:

```python
torch.save(model.state_dict(), PATH)
```

#### Load:

```python
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
```

```python
PATH = "../../../../MEGA/DatabaseLocal/myNet.pt"
torch.save(model.state_dict(), PATH)
```

**When saving a model for inference**, it is only necessary to save the trained model’s learned parameters. Saving the model’s state_dict with the torch.save() function will give you the most flexibility for restoring the model later, which is why it is the recommended method for saving models.
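A minimal end-to-end sketch of the recommended approach, writing to a temporary directory so it cleans up after itself. `make_model` is a hypothetical helper standing in for `TheModelClass(*args, **kwargs)`; the point is that the architecture is rebuilt in code and only the tensors come from disk. `weights_only=True` assumes PyTorch 1.13 or newer.

```python
import os
import tempfile

import torch
import torch.nn as nn


def make_model():
    # Hypothetical architecture; must be identical at save and load time
    return nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 2))


model = make_model()

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.pt")  # .pt/.pth is only a convention
    torch.save(model.state_dict(), path)

    # Rebuild the architecture in code, then fill it with the saved tensors
    restored = make_model()
    restored.load_state_dict(torch.load(path, weights_only=True))
    restored.eval()

# Every saved tensor matches the original
same = all(
    torch.equal(a, b)
    for a, b in zip(model.state_dict().values(), restored.state_dict().values())
)
print(same)  # True
```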

A common PyTorch convention is to save models using either a .pt or .pth file extension.

Remember that you must call model.eval() to set dropout and batch normalization layers to evaluation mode before running inference. Failing to do this will yield inconsistent inference results.
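The effect of `model.eval()` is easy to see on a lone `nn.Dropout` layer (used here as a stand-in for a full model): in training mode it zeroes random entries and rescales the survivors by 1/(1 - p), while in evaluation mode it passes the input through unchanged.

```python
import torch
import torch.nn as nn

drop = nn.Dropout(p=0.5)
x = torch.ones(1000)

drop.train()       # training mode (the default for a new module)
y_train = drop(x)  # roughly half the entries zeroed, the rest scaled to 2.0

drop.eval()        # evaluation mode
y_eval = drop(x)   # dropout becomes the identity

print(torch.equal(y_eval, x))                        # True
print(set(y_train.unique().tolist()) <= {0.0, 2.0})  # True
```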

#### NOTE

Notice that the load_state_dict() function takes a dictionary object, NOT a path to a saved object. This means that you must deserialize the saved state_dict before you pass it to the load_state_dict() function. For example, you CANNOT load using model.load_state_dict(PATH).
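For illustration, passing a path string straight to `load_state_dict` fails before any file is touched. The sketch assumes a hypothetical `model.pt`; the exact exception type depends on the PyTorch version, so both plausible types are caught.

```python
import torch.nn as nn

model = nn.Linear(2, 2)
PATH = "model.pt"  # hypothetical path; never opened in this sketch

caught = None
try:
    model.load_state_dict(PATH)  # wrong: a path string, not a dictionary
except (TypeError, AttributeError) as err:
    # Recent PyTorch raises TypeError ("Expected state_dict to be dict-like ...")
    caught = err
    print(type(err).__name__, err)

# Correct form: deserialize first, then pass the dictionary
# model.load_state_dict(torch.load(PATH, weights_only=True))
```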

```python
model = TheModelClass()
model.load_state_dict(torch.load(PATH))
model.eval()
```

```
TheModelClass(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)
```

### Save/Load Entire Model

#### Save:

```python
torch.save(model, PATH)
```

#### Load:

```python
# Model class must be defined somewhere
# PyTorch 2.6+ defaults to weights_only=True, which refuses to unpickle a full
# module, so a pickled model needs weights_only=False (only for trusted files)
model = torch.load(PATH, weights_only=False)
model.eval()
```

```
TheModelClass(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)
```

This save/load process uses the most intuitive syntax and involves the least amount of code. Saving a model in this way saves the entire module using Python’s pickle module. **The disadvantage of this approach is that the serialized data is bound to the specific classes and the exact directory structure used when the model is saved**. This is because pickle does not save the model class itself; rather, it saves a reference to the class (its module path and name), which is resolved at load time. Because of this, your code can break in various ways when used in other projects or after refactors.
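The sketch below tries to make that fragility concrete. It registers a fake module named `hypothetical_models` (an invented name) in `sys.modules`, saves an entire `TinyNet` model (also hypothetical), then removes the module to mimic a rename or refactor. Loading succeeds while the module is importable and fails afterwards, because the pickle only stores the reference `hypothetical_models.TinyNet`. `weights_only=False` is needed for full-model loading on PyTorch 2.6+.

```python
import io
import sys
import types

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical model used only for this demonstration."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(2, 1)

    def forward(self, x):
        return self.fc(x)


# Pretend TinyNet lives in a project module called "hypothetical_models"
fake_module = types.ModuleType("hypothetical_models")
TinyNet.__module__ = "hypothetical_models"
fake_module.TinyNet = TinyNet
sys.modules["hypothetical_models"] = fake_module

# Save the entire model: the pickle records "hypothetical_models.TinyNet",
# not the class's source code
buffer = io.BytesIO()
torch.save(TinyNet(), buffer)

# While the module is importable, loading works
buffer.seek(0)
reloaded = torch.load(buffer, weights_only=False)
print(type(reloaded).__name__)  # TinyNet

# Simulate a refactor that renames or deletes the module
del sys.modules["hypothetical_models"]

buffer.seek(0)
try:
    torch.load(buffer, weights_only=False)
    load_failed = False
except ImportError as err:  # ModuleNotFoundError subclasses ImportError
    load_failed = True
    print(f"Load failed: {err}")
```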


## Saving & Loading a General Checkpoint for Inference and/or Resuming Training

#### Save:

#### Load: