# PyTorch model (de)serialization

## Saving the model

In this example we will explore the serialization and deserialization of a PyTorch model. We'll use the [MNIST model](https://github.com/pytorch/examples/tree/master/mnist) from previous examples, with a `torch.save()` call added at the end.

We save the trained model like this:

```python
torch.save({
    'epoch': args.epochs, # == 10
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict()
}, './mnist-model.pt')
```
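The optimizer state is saved alongside the model so that training can be resumed later, not just evaluated. Below is a minimal sketch of that round trip. It assumes a tiny hypothetical `nn.Linear` stand-in instead of the MNIST `Net` and writes to an in-memory `io.BytesIO` buffer instead of `./mnist-model.pt`. The checkpoint layout is the same as above.

```python
import io

import torch
import torch.nn as nn

# Hypothetical stand-in for the MNIST Net; any nn.Module works the same way.
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

# One training step, so that SGD creates its per-parameter momentum buffers.
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()

# Same checkpoint layout as above, written to memory instead of a file.
buffer = io.BytesIO()
torch.save({
    'epoch': 10,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, buffer)
buffer.seek(0)  # rewind so torch.load reads from the start

# Resuming: rebuild identical objects first, then restore their state.
checkpoint = torch.load(buffer)
resumed_model = nn.Linear(4, 2)
resumed_optimizer = torch.optim.SGD(resumed_model.parameters(), lr=0.01, momentum=0.5)
resumed_model.load_state_dict(checkpoint['model_state_dict'])
resumed_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1  # continue with the next epoch
```

Without the optimizer state, the restored SGD would start with empty momentum buffers. The resumed run would then behave differently from one that had never stopped.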

## Loading PyTorch model

Here we assume that we already have a file with the saved MNIST model, trained with all default hyperparameters for 10 epochs. Loading it is simple:

```python
import torch

# torch.load() returns whatever object was passed to torch.save()
model_state = torch.load('./mnist-model.pt')

print(type(model_state))
print(model_state.keys())
print('epoch =', model_state['epoch'])
```

```
<class 'dict'>
dict_keys(['epoch', 'model_state_dict', 'optimizer_state_dict'])
epoch = 10
```


So the `torch.load()` function simply reads back the dictionary that was passed to `torch.save()`. For basic Python types it behaves no differently from the standard Python [`pickle`](https://docs.python.org/3.5/library/pickle.html) module (in fact, it uses pickle under the hood). The most interesting parts here are the model's and the optimizer's parameters, as returned by their `state_dict()` methods (for the model, [`torch.nn.Module.state_dict()`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module.state_dict)). Let's take a closer look.

```python
model_params = model_state['model_state_dict']
print("model_params:", type(model_params), "\n")

for (key, val) in model_params.items():
    print("%12s: %s %s" % (key, type(val), val.size()))

print("\n%12s: %s" % ("conv1.bias", model_params['conv1.bias']))
```

```
model_params: <class 'collections.OrderedDict'>

conv1.weight: <class 'torch.Tensor'> torch.Size([10, 1, 5, 5])
  conv1.bias: <class 'torch.Tensor'> torch.Size([10])
conv2.weight: <class 'torch.Tensor'> torch.Size([20, 10, 5, 5])
  conv2.bias: <class 'torch.Tensor'> torch.Size([20])
  fc1.weight: <class 'torch.Tensor'> torch.Size([50, 320])
    fc1.bias: <class 'torch.Tensor'> torch.Size([50])
  fc2.weight: <class 'torch.Tensor'> torch.Size([10, 50])
    fc2.bias: <class 'torch.Tensor'> torch.Size([10])

  conv1.bias: tensor([ 0.0272, -0.0762, -0.0617,  0.0235,  0.1745,  0.0320,  0.0871,  0.0674,
        -0.0222, -0.0541])
```


That is, `.state_dict()` produces an `OrderedDict` of tensors. Each key is the layer's attribute name followed by the parameter name (for example, `conv1.weight` and `conv1.bias`).
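Here is a minimal sketch of how these key names are formed, using a hypothetical `Toy` module that is not part of the MNIST example. Nested containers such as `nn.Sequential` add their child index to the key. Layers without parameters, such as dropout or ReLU, contribute no entries at all.

```python
import torch.nn as nn


class Toy(nn.Module):
    """Hypothetical module, used only to illustrate state_dict key names."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 2, kernel_size=3)
        self.drop = nn.Dropout2d()  # no parameters, so no state_dict entries
        self.head = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))


for key, tensor in Toy().state_dict().items():
    print(key, tuple(tensor.shape))
# conv.weight (2, 1, 3, 3)
# conv.bias (2,)
# head.0.weight (4, 8)
# head.0.bias (4,)
# head.2.weight (2, 4)
# head.2.bias (2,)
```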

Now we need to populate the actual model's parameters (on CUDA or CPU) with that data. For that, we use the [`torch.nn.Module.load_state_dict()`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module.load_state_dict) method. Unfortunately, it won't recreate the model's topology for us, so we have to use the code from the [MNIST example](https://github.com/pytorch/examples/blob/master/mnist/main.py) to build the model explicitly:

```python
import torch.nn as nn
import torch.nn.functional as F

# Same architecture as in the MNIST example; attribute names must match
# the keys stored in the saved state_dict (conv1, conv2, fc1, fc2).
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

model = Net()
```
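`load_state_dict()` matches tensors by key name and shape, so the class definition has to agree with the one used for training. The minimal sketch below uses small hypothetical `nn.Sequential` models rather than `Net` to show what happens otherwise. With the default `strict=True`, a shape mismatch raises a `RuntimeError`. With `strict=False`, the matching tensors are loaded and the leftover keys are reported.

```python
import torch
import torch.nn as nn

# Hypothetical "trained" model and a rebuild whose hidden layer is wider.
trained = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
rebuilt = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 2))

try:
    rebuilt.load_state_dict(trained.state_dict())
except RuntimeError as err:
    # The message lists every size mismatch; print only its first line.
    print("load failed:", str(err).splitlines()[0])

# A model with one extra layer: strict=False loads what matches
# and returns the keys it could not pair up.
bigger = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2), nn.Linear(2, 2))
result = bigger.load_state_dict(trained.state_dict(), strict=False)
print(result.missing_keys)     # ['3.weight', '3.bias']
print(result.unexpected_keys)  # []
```

Note that `strict=False` relaxes only the key check. Tensors whose shapes disagree still cause an error.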

Remember that all parameters are initially set to random values, e.g.:

```python
model.conv1.bias
```

```
Parameter containing:
tensor([ 0.1578,  0.1650, -0.1272, -0.1976,  0.0318, -0.1246, -0.0474, -0.0620,
         0.1829, -0.1198], requires_grad=True)
```
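These random values change every time a layer is constructed. The sketch below uses a stand-in `nn.Conv2d` with the same shape as `conv1` (the helper `make_conv1` is hypothetical). It shows that only reseeding the global generator with `torch.manual_seed` reproduces the same initial values. This is why a saved `state_dict` is needed to recover the trained weights.

```python
import torch
import torch.nn as nn


def make_conv1():
    """Hypothetical helper: a layer shaped like the MNIST model's conv1."""
    return nn.Conv2d(1, 10, kernel_size=5)


a, b = make_conv1(), make_conv1()
print(torch.equal(a.bias, b.bias))  # almost surely False: fresh random draws

torch.manual_seed(0)
c = make_conv1()
torch.manual_seed(0)
d = make_conv1()
print(torch.equal(c.bias, d.bias))  # True: same seed, same initial values
```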

Now we can populate them with data from the file:

```python
model.load_state_dict(model_params)

model.conv1.bias
```

```
Parameter containing:
tensor([ 0.0272, -0.0762, -0.0617,  0.0235,  0.1745,  0.0320,  0.0871,  0.0674,
        -0.0222, -0.0541], requires_grad=True)
```
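Before using a restored model for inference, it is usually moved to the target device and switched to evaluation mode. The MNIST `Net` uses `Dropout2d` and calls `F.dropout(x, training=self.training)`, so `eval()` matters there. The sketch below assumes a small hypothetical `nn.Sequential` with dropout in place of `Net` and an in-memory buffer in place of the file. `map_location` lets a checkpoint saved on a GPU be loaded on a CPU-only machine.

```python
import io

import torch
import torch.nn as nn


def build():
    """Hypothetical stand-in architecture with a dropout layer."""
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Dropout(0.5), nn.Linear(8, 2))


buffer = io.BytesIO()
torch.save({'model_state_dict': build().state_dict()}, buffer)
buffer.seek(0)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# map_location remaps the stored tensors onto `device` while loading.
checkpoint = torch.load(buffer, map_location=device)

restored = build()
restored.load_state_dict(checkpoint['model_state_dict'])
restored.to(device)
restored.eval()  # dropout becomes a no-op, so outputs are deterministic

with torch.no_grad():  # no autograd graph is needed for inference
    x = torch.randn(3, 4, device=device)
    out1 = restored(x)
    out2 = restored(x)
```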