# PyTorch Model Persistence

A paramount requirement in deep learning (DL) model training is the ability to **save** the internal state of a model so that it can be restored later!

We might want to **load and save** a model for:

1. **inference**;
2. **re-start** the training where we left off (i.e. from a _checkpoint_);
3. **save** the best hyper-parameter configuration (and the corresponding model) found during a randomised or _grid search_ optimisation;
4. $\ldots$

There are **two** approaches for saving and loading models for inference in PyTorch. 

The **first** is saving and loading the `state_dict`, and the second is saving and loading the **entire model**.

##### Let's define our (usual) model and optimiser first

In [1]:
import torch

In [2]:
class TwoLayerNet(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        """
        In the constructor we instantiate two nn.Linear modules and assign them as
        member variables.
        """
        super(TwoLayerNet, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """
        In the forward function we accept a Tensor of input data and we must return
        a Tensor of output data. We can use Modules defined in the constructor as
        well as arbitrary operators on Tensors.
        """
        h_relu = torch.relu(self.linear1(x))
        y_pred = self.linear2(h_relu)
        return y_pred

In [3]:
# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Construct our model by instantiating the class defined above
model = TwoLayerNet(D_in, H, D_out)

# Construct our loss function and an Optimizer. The call to model.parameters()
# in the SGD constructor will contain the learnable parameters of the two
# nn.Linear modules which are members of the model.
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)

#### Saving the entire model using `pickle`

We could use the Python `pickle` module to save and load an entire model.

Using this approach yields the most intuitive syntax and involves the least amount of code.

In [4]:
import pickle

with open('model_serialisation.pkl', 'wb') as pkf: 
    pickle.dump(model, pkf)



In [5]:
with open('model_serialisation.pkl', 'rb') as pkf:
    model_pkl = pickle.load(pkf)
    for name_str, param in model_pkl.named_parameters():
        print("{:21} {:19} {}".format(name_str, str(param.shape), param.numel()))

linear1.weight        torch.Size([100, 1000]) 100000
linear1.bias          torch.Size([100])   100
linear2.weight        torch.Size([10, 100]) 1000
linear2.bias          torch.Size([10])    10


**However**, this method is far from being flexible: the serialized data is bound to the specific classes and the exact directory structure used when the model is saved. 

This is because pickle does not save the model class itself. 
Rather, it saves a reference to the class (its module path and class name), which is resolved at load time. 

**For this reason**, your code can break in various ways when used in other projects or after refactors. 
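To make the class-reference problem concrete, the sketch below pickles a small module, checks that the byte stream contains the class *name* rather than its code, and then deletes the class to simulate a refactor. `TinyNet` is a hypothetical stand-in for a project model, and the snippet assumes it runs as a script or notebook cell, so the class lives in `__main__`.

```python
import pickle

import torch


class TinyNet(torch.nn.Module):  # hypothetical stand-in for a project model
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


payload = pickle.dumps(TinyNet())

# The byte stream stores a *reference* (module name + class name), not the class code
stores_class_name = b"TinyNet" in payload
print("class name in pickle:", stores_class_name)

# Simulate a refactor that renamed or moved the class
del TinyNet

try:
    pickle.loads(payload)
    load_failed = False
except AttributeError as err:  # the unpickler cannot resolve the stored reference
    load_failed = True
    print("load failed:", err)
```

The weights are all still in `payload`; it is only the missing class definition that makes the file unreadable.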

## Introducing `model.state_dict()` and `optimizer.state_dict()`

In PyTorch, the learnable parameters (i.e. weights and biases) of a `torch.nn.Module` model are contained in the model’s parameters (accessed with `model.parameters()`). 

A `state_dict` is simply a Python dictionary object that maps each parameter (and buffer) name, such as `linear1.weight`, to its tensor.
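A quick way to check this on a toy layer (`net` is a hypothetical `torch.nn.Linear`, not the notebook's model) is sketched below. It also shows a detail worth knowing when caching the "best" model: by default the values are detached tensors that still share memory with the live parameters, so a `copy.deepcopy` is needed to take a real snapshot.

```python
import copy

import torch

net = torch.nn.Linear(3, 2)
sd = net.state_dict()

# A plain (ordered) Python dict: parameter name -> tensor
print(isinstance(sd, dict), list(sd.keys()))  # True ['weight', 'bias']

# Values are detached from autograd by default (keep_vars=False)...
print(sd["weight"].requires_grad)  # False

# ...but they share memory with the live parameters, so sd is not a snapshot
snapshot = copy.deepcopy(net.state_dict())
with torch.no_grad():
    net.weight.add_(1.0)  # simulate a training update

live_tracks_param = torch.equal(sd["weight"], net.weight)        # sd sees the update
snapshot_tracks_param = torch.equal(snapshot["weight"], net.weight)  # the copy does not
print(live_tracks_param, snapshot_tracks_param)  # True False
```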

In [6]:
# model (named) parameters
for name_str, param in model.named_parameters():
    print("{:21} {:19} {}".format(name_str, str(param.shape), param.numel()))

linear1.weight        torch.Size([100, 1000]) 100000
linear1.bias          torch.Size([100])   100
linear2.weight        torch.Size([10, 100]) 1000
linear2.bias          torch.Size([10])    10


In [7]:
optimizer.param_groups

[{'params': [Parameter containing:
   tensor([[-0.0132, -0.0153, -0.0047,  ...,  0.0122, -0.0004,  0.0304],
           [-0.0260, -0.0314, -0.0120,  ...,  0.0271,  0.0107, -0.0053],
           [-0.0219, -0.0167,  0.0007,  ..., -0.0293,  0.0283,  0.0269],
           ...,
           [-0.0116, -0.0082,  0.0290,  ..., -0.0251, -0.0084, -0.0225],
           [-0.0150,  0.0082, -0.0266,  ...,  0.0112,  0.0007, -0.0288],
           [ 0.0042,  0.0121,  0.0178,  ..., -0.0171, -0.0172,  0.0250]],
          requires_grad=True),
   Parameter containing:
   tensor([-2.8804e-02,  9.2716e-03,  3.0880e-02, -1.9075e-03, -2.9148e-02,
            1.5161e-03, -1.6314e-02, -1.2834e-02,  2.6948e-02, -2.5903e-02,
           -3.7491e-03, -2.8486e-03,  8.6375e-03, -4.5993e-03, -2.6027e-02,
           -2.3151e-02,  1.9023e-02,  2.8956e-02,  2.5606e-02, -7.8710e-03,
            2.1047e-02,  1.8598e-02, -2.4842e-02, -2.1742e-02,  2.3683e-02,
           -1.6587e-02, -1.8220e-02, -7.4297e-03, -7.9439e-03, -2.3919e-02

In [8]:
p = optimizer.param_groups[0]
type(p)

dict

In [9]:
list(p.keys())

['params', 'lr', 'momentum', 'dampening', 'weight_decay', 'nesterov']

In [10]:
p['params']

[Parameter containing:
 tensor([[-0.0132, -0.0153, -0.0047,  ...,  0.0122, -0.0004,  0.0304],
         [-0.0260, -0.0314, -0.0120,  ...,  0.0271,  0.0107, -0.0053],
         [-0.0219, -0.0167,  0.0007,  ..., -0.0293,  0.0283,  0.0269],
         ...,
         [-0.0116, -0.0082,  0.0290,  ..., -0.0251, -0.0084, -0.0225],
         [-0.0150,  0.0082, -0.0266,  ...,  0.0112,  0.0007, -0.0288],
         [ 0.0042,  0.0121,  0.0178,  ..., -0.0171, -0.0172,  0.0250]],
        requires_grad=True),
 Parameter containing:
 tensor([-2.8804e-02,  9.2716e-03,  3.0880e-02, -1.9075e-03, -2.9148e-02,
          1.5161e-03, -1.6314e-02, -1.2834e-02,  2.6948e-02, -2.5903e-02,
         -3.7491e-03, -2.8486e-03,  8.6375e-03, -4.5993e-03, -2.6027e-02,
         -2.3151e-02,  1.9023e-02,  2.8956e-02,  2.5606e-02, -7.8710e-03,
          2.1047e-02,  1.8598e-02, -2.4842e-02, -2.1742e-02,  2.3683e-02,
         -1.6587e-02, -1.8220e-02, -7.4297e-03, -7.9439e-03, -2.3919e-02,
         -2.3253e-02,  1.3756e-03, -2.75

In [11]:
type(p['params'][0])

torch.nn.parameter.Parameter

In [12]:
for optim_param in ('lr', 'momentum', 'nesterov', 'weight_decay'):
    print(f'{optim_param}: {p[optim_param]}')

lr: 0.0001
momentum: 0
nesterov: False
weight_decay: 0


When saving a DL model, we always **need** to save the model parameters (e.g. for _inference_), but in other cases (e.g. a _model checkpoint_) we **also need** to save the **optimiser**'s state and hyper-parameters.
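The hyper-parameters shown above really do travel with the optimiser's `state_dict`. In the sketch below (a hypothetical one-layer model; the `lr=0.5` value is an arbitrary placeholder), a freshly built optimiser adopts the saved learning rate and momentum once the saved state is loaded.

```python
import torch

model = torch.nn.Linear(3, 1)

# Optimiser used during (hypothetical) training
trained_opt = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
saved = trained_opt.state_dict()

# A fresh optimiser built with placeholder hyper-parameters...
fresh_opt = torch.optim.SGD(model.parameters(), lr=0.5)
fresh_opt.load_state_dict(saved)

# ...takes on the saved ones once the state_dict is loaded
group = fresh_opt.param_groups[0]
print(group["lr"], group["momentum"])  # 0.0001 0.9
```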

##### `state_dict`

A `state_dict` is central to saving and loading models in PyTorch. 

Because `state_dict` objects are Python dictionaries, they can be easily saved, updated, altered, and restored, adding a great deal of modularity to PyTorch models and optimizers. 
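As an illustration of that modularity, the sketch below alters a `state_dict` with an ordinary dict comprehension (keeping only the first layer) and restores it into a model with a different output head, using `strict=False` so the missing keys are tolerated and reported. The two `Sequential` models are hypothetical examples, not part of this notebook.

```python
import torch

torch.manual_seed(0)
pretrained = torch.nn.Sequential(
    torch.nn.Linear(10, 6), torch.nn.ReLU(), torch.nn.Linear(6, 2)
)

# Alter: keep only the first layer's entries ("0.weight", "0.bias")
partial = {k: v for k, v in pretrained.state_dict().items() if k.startswith("0.")}

# Restore into a model whose last layer has a different output size
target = torch.nn.Sequential(
    torch.nn.Linear(10, 6), torch.nn.ReLU(), torch.nn.Linear(6, 5)
)
report = target.load_state_dict(partial, strict=False)  # tolerate missing keys

print(report.missing_keys)     # ['2.weight', '2.bias']
print(report.unexpected_keys)  # []
print(torch.equal(target[0].weight, pretrained[0].weight))  # True
```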

Note that **only** layers with learnable parameters and registered buffers (e.g. batchnorm’s `running_mean`) have entries in the model’s `state_dict`. Optimizer objects (`torch.optim`) also have a `state_dict`, which contains information about the optimizer’s state, as well as the **hyperparameters** used. 
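The notebook's `TwoLayerNet` has no buffers, so the distinction does not show up in its output. The hypothetical model below adds a parameter-free `ReLU` and a `BatchNorm1d` layer: the set difference between `state_dict()` keys and `named_parameters()` names reveals the batch-norm buffers, while the `ReLU` contributes no entries at all.

```python
import torch

net = torch.nn.Sequential(
    torch.nn.Linear(8, 4),    # "0": learnable weight and bias
    torch.nn.ReLU(),          # "1": no parameters, no buffers
    torch.nn.BatchNorm1d(4),  # "2": weight/bias plus running statistics (buffers)
)

param_names = {name for name, _ in net.named_parameters()}
state_names = set(net.state_dict().keys())

buffer_only = sorted(state_names - param_names)
print(buffer_only)  # ['2.num_batches_tracked', '2.running_mean', '2.running_var']

relu_has_entries = any(k.startswith("1.") for k in state_names)
print(relu_has_entries)  # False
```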

In [13]:
# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

print()

# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

Model's state_dict:
linear1.weight 	 torch.Size([100, 1000])
linear1.bias 	 torch.Size([100])
linear2.weight 	 torch.Size([10, 100])
linear2.bias 	 torch.Size([10])

Optimizer's state_dict:
state 	 {}
param_groups 	 [{'lr': 0.0001, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [140227236321408, 140227235900480, 140227235899072, 140227235901184]}]


### Saving and Loading models for Inference in PyTorch

Each instance of a `torch.nn.Module` can be saved using the `torch.save()` function.

Saving the model’s `state_dict` with the `torch.save()` function will give you the most flexibility for restoring the model later. 

This is the **recommended method** for saving models, because it is only really necessary to save the trained model’s learned parameters. 



In [14]:
# Save
torch.save(model.state_dict(), "model_state_dict.pt")

In [15]:
# Load
model = TwoLayerNet(D_in, H, D_out)
model.load_state_dict(torch.load("model_state_dict.pt"))

<All keys matched successfully>

A common PyTorch **convention** is to save models using either a `.pt` or `.pth` file extension.

Notice that the `load_state_dict()` function takes a dictionary object, NOT a `path` to a saved object. 

This means that you **must** deserialize the saved `state_dict` (e.g. with `torch.load()`) before you pass it to the `load_state_dict()` function. 

For example, you **CANNOT** just load using `model.load_state_dict("path_to_file.pt")`.
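The two-step pattern (deserialise first, then load the dict) is sketched below on a hypothetical `torch.nn.Linear`, with an in-memory `io.BytesIO` standing in for a `.pt` file. The `map_location="cpu"` argument and the `model.eval()` call are additions not used in the cells above: the former remaps tensors saved on another device, and the latter switches dropout and batch-norm layers to inference behaviour.

```python
import io

import torch

torch.manual_seed(0)
source = torch.nn.Linear(5, 3)

buffer = io.BytesIO()  # stands in for a file such as "model_state_dict.pt"
torch.save(source.state_dict(), buffer)
buffer.seek(0)

# Step 1: deserialise the saved object into an in-memory dict of tensors
state = torch.load(buffer, map_location="cpu")

# Step 2: pass that dict (not the file or path) to a freshly built model
restored = torch.nn.Linear(5, 3)
result = restored.load_state_dict(state)
restored.eval()  # put dropout/batch-norm layers in inference mode

x = torch.randn(2, 5)
with torch.no_grad():
    same_output = torch.equal(source(x), restored(x))
print(result, same_output)
```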

###### Saving and Loading Entire Model

Let’s try the same thing with the entire model.

In [16]:
# Save
torch.save(model, "model.pth")

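# Load. Since PyTorch 2.6, torch.load defaults to weights_only=True, which refuses
# to unpickle arbitrary classes; loading a full model (from a trusted file only)
# therefore needs weights_only=False.
model = torch.load("model.pth", weights_only=False)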



In [17]:
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

linear1.weight 	 torch.Size([100, 1000])
linear1.bias 	 torch.Size([100])
linear2.weight 	 torch.Size([10, 100])
linear2.bias 	 torch.Size([10])


---

### Saving and loading model checkpoint

Saving and loading a general checkpoint, whether for inference or for resuming training, lets you pick up where you last left off. 

When saving a general checkpoint, you must save more than just the model’s `state_dict`. 

It is **also important** to save the **optimizer**’s `state_dict`, as this contains buffers and parameters that are updated as the model trains. 

**Moreover**, you might also want to save the `epoch` you left off on, the latest recorded `training loss`, external layers, and more, based on your own algorithm.
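Why the optimiser state matters is easy to see with momentum SGD: before the first step its `state` is empty, and after one step it holds a momentum buffer for every parameter. Losing those buffers on restart would change the subsequent updates. The model, data, and `momentum=0.9` below are hypothetical choices for illustration.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

before = optimizer.state_dict()["state"]
print("state before any step:", before)  # {}

loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()

after = optimizer.state_dict()["state"]
# One entry per parameter (weight and bias), each holding a momentum buffer
first_state = next(iter(after.values()))
print(len(after), "momentum_buffer" in first_state)  # 2 True
```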

In [18]:
# Additional information
EPOCH = 5
LOSS = 0.4
CHKPOINT = "model_checkpoint.pth"

torch.save({'epoch': EPOCH,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': LOSS,
            }, CHKPOINT)

In [19]:
# Load
model = TwoLayerNet(D_in, H, D_out)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)

checkpoint = torch.load(CHKPOINT)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

In [20]:
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

print()

# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

Model's state_dict:
linear1.weight 	 torch.Size([100, 1000])
linear1.bias 	 torch.Size([100])
linear2.weight 	 torch.Size([10, 100])
linear2.bias 	 torch.Size([10])

Optimizer's state_dict:
state 	 {}
param_groups 	 [{'lr': 0.0001, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [140227209983552, 140227209983936, 140227209984128, 140227209984192]}]


In [21]:
print('Loss from Checkpoint: ', loss)
print('Epoch: ', epoch)

Loss from Checkpoint:  0.4
Epoch:  5
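The cells above restore the checkpoint but never continue training. The sketch below closes that loop under some assumptions: a toy full-batch regression problem, momentum SGD, and two hypothetical helpers (`make_model_and_optimizer`, `train_one_epoch`). It trains for three epochs, writes a checkpoint in the same dictionary layout as above (to an `io.BytesIO` standing in for the `.pth` file), rebuilds everything from scratch, resumes from `epoch + 1`, and checks that the result matches an uninterrupted run.

```python
import io

import torch

# Toy full-batch regression problem (hypothetical data)
torch.manual_seed(0)
x, y = torch.randn(64, 10), torch.randn(64, 1)
criterion = torch.nn.MSELoss()


def make_model_and_optimizer(seed):
    """Hypothetical helper: build a fresh model and a momentum-SGD optimiser."""
    torch.manual_seed(seed)
    model = torch.nn.Linear(10, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    return model, optimizer


def train_one_epoch(model, optimizer):
    """Hypothetical helper: one full-batch gradient step per 'epoch'."""
    model.train()
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()
    return loss.item()


NUM_EPOCHS = 6

# Reference run: all epochs without interruption
reference, ref_opt = make_model_and_optimizer(seed=1)
for _ in range(NUM_EPOCHS):
    train_one_epoch(reference, ref_opt)

# Interrupted run: train 3 epochs, then write a checkpoint
model, optimizer = make_model_and_optimizer(seed=1)
for epoch in range(3):
    loss = train_one_epoch(model, optimizer)

checkpoint_file = io.BytesIO()  # stands in for a path such as "model_checkpoint.pth"
torch.save({"epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "loss": loss}, checkpoint_file)

# Later (e.g. in a new process): rebuild, restore, and carry on
checkpoint_file.seek(0)
model, optimizer = make_model_and_optimizer(seed=123)  # this init is overwritten below
checkpoint = torch.load(checkpoint_file)
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
start_epoch = checkpoint["epoch"] + 1  # resume after the last completed epoch

for epoch in range(start_epoch, NUM_EPOCHS):
    loss = train_one_epoch(model, optimizer)

matches = (torch.equal(model.weight, reference.weight)
           and torch.equal(model.bias, reference.bias))
print("resumed run matches uninterrupted run:", matches)
```

Dropping the `optimizer.load_state_dict(...)` line would reset the momentum buffers, and the resumed run would in general no longer match the reference.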
