# Saving models

Saving and loading PyTorch models is essential, because any model you build eventually has to be stored, transferred, and deployed. The [official tutorial](https://pytorch.org/tutorials/beginner/saving_loading_models.html) covers the topic in detail; here, we experiment with the options it describes.

In [4]:
import torch
from pathlib import Path

## State dict

The standard way to save and load a model through its state dictionary consists of these steps:

- Retrieve the model's state dictionary with `torch.nn.Module.state_dict()`.
- Save the state dictionary with `torch.save`.
- Load the state dictionary with `torch.load`.
- Load the weights into the model using `torch.nn.Module.load_state_dict()`.
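The four steps above can be sketched as a single round trip. This is a minimal sketch, assuming an in-memory `io.BytesIO` buffer in place of a file on disk (`torch.save` and `torch.load` both accept file-like objects) and a hypothetical `make_model` helper that rebuilds the same two-layer architecture used in this notebook.

```python
import io

import torch


def make_model() -> torch.nn.Module:
    # Hypothetical helper: rebuilds the two-layer architecture used in this notebook.
    return torch.nn.Sequential(torch.nn.Linear(3, 3), torch.nn.Linear(3, 3))


model = make_model()

# 1. Retrieve the state dictionary (an OrderedDict of parameter tensors).
state_dict = model.state_dict()

# 2. Save it; an in-memory buffer stands in for a file path here.
buffer = io.BytesIO()
torch.save(state_dict, buffer)
buffer.seek(0)  # rewind so torch.load reads from the start

# 3. Load it back; a plain state dict holds only tensors, so weights_only=True works.
loaded = torch.load(buffer, weights_only=True)

# 4. Load the weights into a freshly constructed model with the same architecture.
restored = make_model()
result = restored.load_state_dict(loaded)

# Every tensor in the restored model should match the original exactly.
same = all(
    torch.equal(a, b)
    for a, b in zip(model.state_dict().values(), restored.state_dict().values())
)
print(result.missing_keys, result.unexpected_keys)  # [] []
print(same)  # True
```

Because the fresh model starts with different random weights, `same` being `True` shows that the values really came from the saved state dictionary.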

---

In the following cell, we create a simple two-layer model and set all of its parameters to the constant value 3.

In [None]:
model = torch.nn.Sequential(
    torch.nn.Linear(3, 3),
    torch.nn.Linear(3, 3)
)
for p in model.parameters():
    torch.nn.init.constant_(p, 3)

model.state_dict()

OrderedDict([('0.weight',
              tensor([[3., 3., 3.],
                      [3., 3., 3.],
                      [3., 3., 3.]])),
             ('0.bias', tensor([3., 3., 3.])),
             ('1.weight',
              tensor([[3., 3., 3.],
                      [3., 3., 3.],
                      [3., 3., 3.]])),
             ('1.bias', tensor([3., 3., 3.]))])

Now, using `torch.save`, we save the state dictionary and discard the original model.

In [None]:
torch.save(obj=model.state_dict(), f=Path("/tmp")/"my_model")
del model

Next, we load the state dictionary with `torch.load`. Its values are unchanged: every entry is still 3, as it was when saved.

In [None]:
# A plain state dict holds only tensors, so the safer weights_only=True mode suffices.
state_dict = torch.load(Path("/tmp")/"my_model", weights_only=True)
state_dict

OrderedDict([('0.weight',
              tensor([[3., 3., 3.],
                      [3., 3., 3.],
                      [3., 3., 3.]])),
             ('0.bias', tensor([3., 3., 3.])),
             ('1.weight',
              tensor([[3., 3., 3.],
                      [3., 3., 3.],
                      [3., 3., 3.]])),
             ('1.bias', tensor([3., 3., 3.]))])
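Since a state dictionary is just an ordered mapping from parameter names to tensors, it can be inspected like any dictionary. The sketch below assumes an in-memory buffer instead of the `/tmp/my_model` file, and additionally passes `map_location="cpu"`, which remaps every stored tensor onto the CPU; this matters when a file written on a GPU machine is read on a CPU-only one.

```python
import io

import torch

model = torch.nn.Sequential(torch.nn.Linear(3, 3), torch.nn.Linear(3, 3))

buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
buffer.seek(0)

# map_location="cpu" places all loaded tensors on the CPU regardless of where they were saved.
state_dict = torch.load(buffer, map_location="cpu", weights_only=True)

# Keys follow "<submodule index>.<parameter name>" because nn.Sequential
# names its children "0", "1", ...
for name, tensor in state_dict.items():
    print(name, tuple(tensor.shape), tensor.device)
# 0.weight (3, 3) cpu
# 0.bias (3,) cpu
# 1.weight (3, 3) cpu
# 1.bias (3,) cpu
```

These key names are exactly what `load_state_dict` matches against, which is why the model has to be rebuilt with the same structure before loading.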

To restore the model, create a new instance with exactly the same architecture and load the saved state dictionary into it.

In [None]:
model = torch.nn.Sequential(
    torch.nn.Linear(3, 3),
    torch.nn.Linear(3, 3)
)

model.load_state_dict(state_dict)

<All keys matched successfully>

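On success, `torch.nn.Module.load_state_dict` returns a named tuple with `missing_keys` and `unexpected_keys` fields; when both are empty, it is displayed as `<All keys matched successfully>`.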
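Because `load_state_dict` reports missing and unexpected keys, it can also load a checkpoint that only partly matches the model. A minimal sketch, assuming a hypothetical `partial` dictionary built by dropping one entry and adding an unrelated one: with the default `strict=True` these key mismatches would raise a `RuntimeError`, while `strict=False` loads whatever matches and reports the rest. Tensor shape mismatches are rejected in both modes.

```python
import torch

model = torch.nn.Sequential(torch.nn.Linear(3, 3), torch.nn.Linear(3, 3))

# Hypothetical partial checkpoint: one entry removed, one unrelated entry added.
partial = dict(model.state_dict())
del partial["1.bias"]
partial["extra.weight"] = torch.zeros(1)

target = torch.nn.Sequential(torch.nn.Linear(3, 3), torch.nn.Linear(3, 3))

# strict=False loads the matching entries and returns the mismatched key names.
result = target.load_state_dict(partial, strict=False)
print(result.missing_keys)     # ['1.bias']
print(result.unexpected_keys)  # ['extra.weight']

# A shape mismatch is an error even with strict=False.
try:
    target.load_state_dict({"0.weight": torch.zeros(4, 3)}, strict=False)
    shape_error = False
except RuntimeError:
    shape_error = True
print(shape_error)  # True
```

In this sketch, `target` keeps its own randomly initialized `1.bias`, since that entry was missing from the checkpoint.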

## Save entire model

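Passing the entire model object to `torch.save` serializes the whole module with Python's `pickle`, not just its parameters. The model can then be restored with a single `torch.load` call. Because the file contains pickled Python objects rather than plain tensors, loading it requires `weights_only=False`, and the model's class definitions must be importable at load time.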

---

The following cell creates a model, initializes its parameters with a constant, and saves the whole model to disk by passing it as `obj` to `torch.save`.

In [None]:
model = torch.nn.Sequential(
    torch.nn.Linear(3, 3),
    torch.nn.Linear(3, 3)
)
for p in model.parameters():
    torch.nn.init.constant_(p, 3)

torch.save(model, Path("/tmp")/"model.pth")

With `torch.load`, the model can be restored. The following cell shows that the loaded model has the same weights that were set before saving.

In [None]:
torch.load(Path("/tmp")/"model.pth", weights_only=False).state_dict()

OrderedDict([('0.weight',
              tensor([[3., 3., 3.],
                      [3., 3., 3.],
                      [3., 3., 3.]])),
             ('0.bias', tensor([3., 3., 3.])),
             ('1.weight',
              tensor([[3., 3., 3.],
                      [3., 3., 3.],
                      [3., 3., 3.]])),
             ('1.bias', tensor([3., 3., 3.]))])
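To check that the restored object is a working module and not only a set of weights, one can compare forward passes. This is a minimal sketch, assuming an in-memory buffer in place of `/tmp/model.pth` and a random input `x` chosen only for illustration.

```python
import io

import torch

model = torch.nn.Sequential(torch.nn.Linear(3, 3), torch.nn.Linear(3, 3))

buffer = io.BytesIO()
torch.save(model, buffer)  # pickles the whole module object
buffer.seek(0)

# A pickled nn.Module is not plain tensor data, so weights_only=False is required.
# Only load such files from a trusted source: unpickling can execute arbitrary code.
restored = torch.load(buffer, weights_only=False)

x = torch.randn(2, 3)  # illustrative input batch
with torch.no_grad():
    same_output = torch.allclose(model(x), restored(x))

print(type(restored).__name__)  # Sequential
print(same_output)              # True
```

No architecture has to be rebuilt here, which is the convenience of this approach; the cost is that the file is tied to the Python classes that produced it.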