# Save and Load Model

Let's learn how to save and load models in PyTorch.

```python
import torch
import torchvision.models as models
```

## Saving and Loading Model Weights

PyTorch models store their learned parameters in an internal state dictionary called the `state_dict`, which maps each parameter name to its tensor.
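To see what a `state_dict` holds, here is a minimal sketch that uses a small, hypothetical two-layer `nn.Sequential` instead of VGG16, so nothing needs to be downloaded. It prints each entry's name and tensor shape:

```python
import torch
from torch import nn

# A small illustrative model (hypothetical stand-in for VGG16).
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# Each entry maps a parameter name to its tensor. ReLU has no parameters,
# so index 1 does not appear among the keys.
for name, tensor in model.state_dict().items():
    print(name, tuple(tensor.shape))
# 0.weight (8, 4)
# 0.bias (8,)
# 2.weight (2, 8)
# 2.bias (2,)
```

Names are built from each submodule's position or attribute name, which is why loading a `state_dict` later requires a model with the same structure.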

### Save

Use `torch.save` to save a model's `state_dict`.

```python
# Download VGG16 with ImageNet-pretrained weights, then save only its learned parameters
model = models.vgg16(weights='IMAGENET1K_V1')
torch.save(model.state_dict(), 'model_weights.pth')
```



### Load

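To load model weights, first create an instance of the same model, then read the saved `state_dict` with `torch.load` and copy it into the model with its `load_state_dict()` method.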

Call `model.eval()` before inference to set dropout and batch normalization layers to evaluation mode. Skipping this step yields inconsistent inference results.
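To see why `eval()` matters, here is a minimal sketch with a hypothetical small model containing an `nn.Dropout` layer. In evaluation mode dropout passes its input through unchanged, so repeated forward passes on the same input give identical results:

```python
import torch
from torch import nn

torch.manual_seed(0)
# Hypothetical small model whose behavior depends on train/eval mode.
model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
x = torch.ones(1, 4)

model.train()  # training mode: Dropout randomly zeroes activations
print(model.training)  # True

model.eval()  # evaluation mode: Dropout becomes a pass-through
with torch.no_grad():
    eval_out1 = model(x)
    eval_out2 = model(x)

print(model.training)  # False
print(torch.equal(eval_out1, eval_out2))  # True: eval-mode inference is deterministic
```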

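```python
model = models.vgg16()  # same architecture, no pretrained weights downloaded
# weights_only=True (PyTorch >= 1.13) restricts unpickling to tensors and basic containers
model.load_state_dict(torch.load('model_weights.pth', weights_only=True))
model.eval()  # put layers such as dropout into inference mode
```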



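The save/load cycle for weights can be checked end to end without downloading VGG16. This is a minimal sketch assuming PyTorch >= 1.13 (for the `weights_only` argument). It uses a hypothetical `make_model()` helper standing in for `models.vgg16()`: the sketch saves a `state_dict` to a temporary file, loads it into a fresh instance of the same architecture, and compares outputs:

```python
import os
import tempfile

import torch
from torch import nn


def make_model():
    # Hypothetical small architecture standing in for models.vgg16().
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


torch.manual_seed(0)
trained = make_model()

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model_weights.pth")
    torch.save(trained.state_dict(), path)

    # Instantiate the same architecture, then copy the saved parameters into it.
    restored = make_model()
    result = restored.load_state_dict(torch.load(path, weights_only=True))

restored.eval()
print(result.missing_keys, result.unexpected_keys)  # [] []

x = torch.randn(3, 4)
with torch.no_grad():
    print(torch.equal(trained(x), restored(x)))  # True
```

`load_state_dict` returns the missing and unexpected keys. Both lists are empty when the saved parameters and the model architecture match exactly.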

## Saving and Loading Models with Shapes

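When loading model weights, we had to instantiate the model class first, because the class defines the structure of the network. To save the structure together with the weights, pass the model itself (rather than `model.state_dict()`) to `torch.save`. This approach uses Python's `pickle` module, so the model's class definition must be available when the model is loaded.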

### Save

```python
torch.save(model, 'model.pth')
```

### Load

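```python
# A full model needs weights_only=False to unpickle (default is True in recent releases); load trusted files only
model = torch.load('model.pth', weights_only=False)
```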

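Here is a minimal sketch of the full-model round trip, assuming PyTorch >= 1.13 and using a small `nn.Sequential` in place of VGG16. Because `torch.save(model, ...)` pickles the module object itself, no instance has to be created before loading:

```python
import os
import tempfile

import torch
from torch import nn

torch.manual_seed(0)
# Hypothetical small model; nn.Sequential is importable from torch, so unpickling can find the class.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.pth")
    torch.save(model, path)  # pickles the whole module object
    # weights_only=False is required to unpickle arbitrary classes; only use it for trusted files.
    loaded = torch.load(path, weights_only=False)

loaded.eval()
print(type(loaded).__name__)  # Sequential

x = torch.randn(2, 4)
with torch.no_grad():
    print(torch.equal(model(x), loaded(x)))  # True
```

The trade-off is that a `state_dict` depends only on parameter names and shapes. A pickled model depends on the exact class definitions and module paths that existed when it was saved.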
