# PyTorch model (de)serialization

## Saving the model

In this example we will explore the serialization and deserialization of a PyTorch model. We'll use the [MNIST model](https://github.com/pytorch/examples/tree/master/mnist) from previous examples, augmented with a `torch.save()` call at the end.

We save the trained model like this:

```python
torch.save({
    'epoch': args.epochs, # == 10
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict()
}, './mnist-model.pt')
```
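To try the call without running the full MNIST training script, a minimal self-contained sketch might look like the following. The tiny `nn.Linear` model, the SGD optimizer, the `epochs` variable, and the temporary file path are stand-ins assumed for illustration; they are not the MNIST example's actual network or settings.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim

# Hypothetical stand-ins for the MNIST network and its optimizer.
model = nn.Linear(4, 2)
optimizer = optim.SGD(model.parameters(), lr=0.01)
epochs = 10  # plays the role of args.epochs

# Write the checkpoint to a temporary directory instead of ./mnist-model.pt.
checkpoint_path = os.path.join(tempfile.mkdtemp(), 'sketch-model.pt')
torch.save({
    'epoch': epochs,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict()
}, checkpoint_path)

print(os.path.exists(checkpoint_path))
```

Bundling the epoch counter and both state dictionaries into one dictionary keeps everything needed to resume training in a single file.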

## Loading the model

Here we assume that we already have a file containing the saved MNIST model, trained for 10 epochs with all default hyperparameters. Let's load it.

```python
import torch

model_state = torch.load('./mnist-model.pt')

print(type(model_state))
print(model_state.keys())
print('epoch =', model_state['epoch'])
```

Output:

```
<class 'dict'>
dict_keys(['epoch', 'model_state_dict', 'optimizer_state_dict'])
epoch = 10
```


So `torch.load()` simply reads back the dictionary that was passed to `torch.save()`. For basic Python types it behaves no differently from Python's standard [pickle](https://docs.python.org/3.5/library/pickle.html) module (in fact, it uses pickle under the hood).
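The equivalence for basic Python types can be checked with a small round-trip. This is only a sketch: the `record` dictionary and its values are made up for illustration, and an in-memory `io.BytesIO` buffer is used in place of a file on disk.

```python
import io
import pickle

import torch

# A dictionary holding only basic Python types (hypothetical values).
record = {'epoch': 10, 'name': 'mnist', 'lrs': [0.1, 0.01]}

# Round-trip through torch.save()/torch.load() using an in-memory buffer.
buffer = io.BytesIO()
torch.save(record, buffer)
buffer.seek(0)
from_torch = torch.load(buffer)

# Round-trip through the standard pickle module.
from_pickle = pickle.loads(pickle.dumps(record))

print(from_torch == record)
print(from_pickle == record)
```

Both round-trips should hand back a dictionary equal to the original, which is what the comparison above checks.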