
# Saving and Loading Models

In this notebook, I'll show you how to save and load models with PyTorch. This is important because you'll often want to load previously trained models to make predictions or to continue training on new data.

In [2]:
!wget 'https://raw.githubusercontent.com/udacity/deep-learning-v2-pytorch/master/intro-to-pytorch/fc_model.py'

--2023-02-21 12:22:43--  https://raw.githubusercontent.com/udacity/deep-learning-v2-pytorch/master/intro-to-pytorch/fc_model.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 3543 (3.5K) [text/plain]
Saving to: ‘fc_model.py’


2023-02-21 12:22:43 (58.6 MB/s) - ‘fc_model.py’ saved [3543/3543]



In [3]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import matplotlib.pyplot as plt

import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torchvision import datasets, transforms

import helper
import fc_model

In [4]:
# Define a transform to normalize the data
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])
# Download and load the training data
trainset = datasets.FashionMNIST('~/.pytorch/F_MNIST_data/', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

# Download and load the test data
testset = datasets.FashionMNIST('~/.pytorch/F_MNIST_data/', download=True, train=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to /root/.pytorch/F_MNIST_data/FashionMNIST/raw/train-images-idx3-ubyte.gz
Extracting /root/.pytorch/F_MNIST_data/FashionMNIST/raw/train-images-idx3-ubyte.gz to /root/.pytorch/F_MNIST_data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to /root/.pytorch/F_MNIST_data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
Extracting /root/.pytorch/F_MNIST_data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to /root/.pytorch/F_MNIST_data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to /root/.pytorch/F_MNIST_data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
Extracting /root/.pytorch/F_MNIST_data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to /root/.pytorch/F_MNIST_data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to /root/.pytorch/F_MNIST_data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting /root/.pytorch/F_MNIST_data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to /root/.pytorch/F_MNIST_data/FashionMNIST/raw



## Train a network

To make things more concise here, I moved the model architecture and training code from the last part to a file called `fc_model.py`. After importing it, we can easily create a fully-connected network with `fc_model.Network` and train it with `fc_model.train`. I'll use this model (once it's trained) to demonstrate how we can save and load models.

In [6]:
# Create the network, define the criterion and optimizer

model = fc_model.Network(784, 10, [512, 256, 128])
criterion = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [7]:
fc_model.train(model, trainloader, testloader, criterion, optimizer, epochs=2)

Epoch: 1/2..  Training Loss: 1.679..  Test Loss: 0.997..  Test Accuracy: 0.630
Epoch: 1/2..  Training Loss: 1.018..  Test Loss: 0.758..  Test Accuracy: 0.724
Epoch: 1/2..  Training Loss: 0.828..  Test Loss: 0.676..  Test Accuracy: 0.751
Epoch: 1/2..  Training Loss: 0.796..  Test Loss: 0.639..  Test Accuracy: 0.742
Epoch: 1/2..  Training Loss: 0.756..  Test Loss: 0.619..  Test Accuracy: 0.769
Epoch: 1/2..  Training Loss: 0.706..  Test Loss: 0.592..  Test Accuracy: 0.775
Epoch: 1/2..  Training Loss: 0.667..  Test Loss: 0.556..  Test Accuracy: 0.792
Epoch: 1/2..  Training Loss: 0.671..  Test Loss: 0.581..  Test Accuracy: 0.783
Epoch: 1/2..  Training Loss: 0.670..  Test Loss: 0.561..  Test Accuracy: 0.790
Epoch: 1/2..  Training Loss: 0.644..  Test Loss: 0.543..  Test Accuracy: 0.796
Epoch: 1/2..  Training Loss: 0.621..  Test Loss: 0.534..  Test Accuracy: 0.803
Epoch: 1/2..  Training Loss: 0.618..  Test Loss: 0.537..  Test Accuracy: 0.801
Epoch: 1/2..  Training Loss: 0.638..  Test Loss: 0.5

## Saving and loading networks

As you can imagine, it's impractical to train a network every time you need to use it. Instead, we can save trained networks then load them later to train more or use them for predictions.

The parameters for PyTorch networks are stored in a model's `state_dict`. We can see the state dict contains the weight matrix and bias vector for each of our layers.

In [8]:
print("Our model: \n\n", model, '\n')
print("The state dict keys: \n\n", model.state_dict().keys())

Our model: 

 Network(
  (hidden_layers): ModuleList(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): Linear(in_features=512, out_features=256, bias=True)
    (2): Linear(in_features=256, out_features=128, bias=True)
  )
  (output): Linear(in_features=128, out_features=10, bias=True)
  (dropout): Dropout(p=0.5, inplace=False)
) 

The state dict keys: 

 odict_keys(['hidden_layers.0.weight', 'hidden_layers.0.bias', 'hidden_layers.1.weight', 'hidden_layers.1.bias', 'hidden_layers.2.weight', 'hidden_layers.2.bias', 'output.weight', 'output.bias'])


In [28]:
model.state_dict()['hidden_layers.0.weight']

tensor([[ 0.0649,  0.0205,  0.0485,  ...,  0.0290,  0.0374,  0.0365],
        [ 0.0165, -0.0147,  0.0378,  ...,  0.0214,  0.0185,  0.0183],
        [ 0.0107,  0.0557,  0.0459,  ...,  0.0284,  0.0203,  0.0028],
        ...,
        [ 0.0684,  0.0567,  0.0904,  ...,  0.0921,  0.0859,  0.0542],
        [ 0.0279,  0.0456, -0.0193,  ..., -0.0090,  0.0289,  0.0370],
        [ 0.0067,  0.0195, -0.0213,  ...,  0.0256, -0.0235,  0.0198]])

In [20]:
model.hidden_layers[0].weight.data

tensor([[ 0.0649,  0.0205,  0.0485,  ...,  0.0290,  0.0374,  0.0365],
        [ 0.0165, -0.0147,  0.0378,  ...,  0.0214,  0.0185,  0.0183],
        [ 0.0107,  0.0557,  0.0459,  ...,  0.0284,  0.0203,  0.0028],
        ...,
        [ 0.0684,  0.0567,  0.0904,  ...,  0.0921,  0.0859,  0.0542],
        [ 0.0279,  0.0456, -0.0193,  ..., -0.0090,  0.0289,  0.0370],
        [ 0.0067,  0.0195, -0.0213,  ...,  0.0256, -0.0235,  0.0198]])
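The two cells above print the same numbers because the state dict is a view onto the layer's own parameters. Here is a minimal sketch of that relationship, using a tiny `nn.Linear` layer as a stand-in (an assumption for illustration, not the notebook's `fc_model.Network`):

```python
import torch
from torch import nn

# A tiny stand-in layer (hypothetical; not the notebook's fc_model.Network)
layer = nn.Linear(4, 3)

# The state dict maps parameter names to tensors
state = layer.state_dict()
shapes = {name: tuple(tensor.shape) for name, tensor in state.items()}
print(shapes)

# The state dict entry holds the same values as the parameter itself
same_values = torch.equal(state['weight'], layer.weight.data)

# The entry also shares storage with the parameter, so an in-place change
# to the parameter is visible through the state dict
with torch.no_grad():
    layer.weight.zero_()
shares_memory = bool((state['weight'] == 0).all())

print(same_values, shares_memory)
```

Because of that shared storage, a state dict you keep around in memory will follow later training updates. Saving it to disk, as shown next, is what freezes a snapshot.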

The simplest approach is to save the state dict with `torch.save`. For example, we can save it to a file named `'checkpoint.pth'`.

In [29]:
torch.save(model.state_dict(), 'checkpoint.pth')

Then we can load the state dict with `torch.load`.

In [30]:
state_dict = torch.load('checkpoint.pth')
print(state_dict.keys())

odict_keys(['hidden_layers.0.weight', 'hidden_layers.0.bias', 'hidden_layers.1.weight', 'hidden_layers.1.bias', 'hidden_layers.2.weight', 'hidden_layers.2.bias', 'output.weight', 'output.bias'])


To load the state dict into the network, call `model.load_state_dict(state_dict)`.

In [31]:
model.load_state_dict(state_dict)

<All keys matched successfully>
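Putting the three steps together, a round trip might look like the sketch below. The small `nn.Sequential` network and the temporary directory are assumptions made so the example runs on its own; the notebook itself uses `fc_model.Network` and writes `checkpoint.pth` to the working directory.

```python
import os
import tempfile

import torch
from torch import nn

torch.manual_seed(0)


def make_net():
    # Hypothetical stand-in architecture, not the notebook's fc_model.Network
    return nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))


model = make_net()

with tempfile.TemporaryDirectory() as tmpdir:
    path = os.path.join(tmpdir, 'checkpoint.pth')

    # 1. Save only the parameters
    torch.save(model.state_dict(), path)

    # 2. Build a fresh network with the same architecture
    restored = make_net()

    # 3. Read the parameters back and copy them into the fresh network
    result = restored.load_state_dict(torch.load(path))

# Both networks should now give identical outputs for the same input
x = torch.randn(5, 8)
with torch.no_grad():
    outputs_match = torch.equal(model(x), restored(x))

print(result)
print(outputs_match)
```

The fresh network starts with different random weights, so matching outputs confirm that the saved parameters were actually copied in.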

Seems pretty straightforward, but as usual it's a bit more complicated. Loading the state dict works only if the model architecture exactly matches the architecture that produced the checkpoint, with the same parameter names and tensor shapes. If I create a model with a different architecture, this fails.

In [32]:
# Try this
model = fc_model.Network(784, 10, [400, 200, 100])
# This will throw an error because the tensor sizes are wrong!
model.load_state_dict(state_dict)

RuntimeError: ignored
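To see what that error reports, here is a small sketch that catches it and also lists the mismatched entries by comparing shapes directly. The helper `make_net` and its layer sizes are assumptions for illustration only.

```python
import torch
from torch import nn


def make_net(hidden):
    # Hypothetical stand-in with a single hidden layer
    return nn.Sequential(nn.Linear(784, hidden), nn.ReLU(), nn.Linear(hidden, 10))


# State dict from one architecture, model built with another
saved_state = make_net(512).state_dict()
different = make_net(400)

try:
    different.load_state_dict(saved_state)
    error_message = None
except RuntimeError as err:
    # PyTorch reports every parameter whose shape does not match
    error_message = str(err)

print(error_message)

# The same information can be found by comparing tensor shapes key by key
current_state = different.state_dict()
mismatched = [name for name, tensor in saved_state.items()
              if current_state[name].shape != tensor.shape]
print(mismatched)
```

Only the final bias is unaffected here, since its shape depends on the output size alone.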

This means we need to rebuild the model exactly as it was when trained. Information about the model architecture needs to be saved in the checkpoint, along with the state dict. To do this, you build a dictionary with all the information you need to completely rebuild the model.

In [35]:
model.hidden_layers

ModuleList(
  (0): Linear(in_features=784, out_features=400, bias=True)
  (1): Linear(in_features=400, out_features=200, bias=True)
  (2): Linear(in_features=200, out_features=100, bias=True)
)

In [34]:
model.hidden_layers[0]

Linear(in_features=784, out_features=400, bias=True)

In [36]:
model.hidden_layers[0].out_features

400

In [39]:
checkpoint = {'input_size': 784,
              'output_size': 10,
              'hidden_layers': [each.out_features for each in model.hidden_layers],
              'state_dict': model.state_dict()}

torch.save(checkpoint, 'checkpoint.pth')

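Now the checkpoint has all the information needed to rebuild the model. Note that `model` was reassigned to the 400-200-100 network in the failed-load example above, so that untrained network is what this checkpoint captures; in practice you would save the trained model. You can easily wrap the saving step in a function if you want. Similarly, we can write a function to load checkpoints.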

In [40]:
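def load_checkpoint(filepath):
    """Rebuild a network from a checkpoint saved as a dictionary."""
    checkpoint = torch.load(filepath)
    # Recreate the architecture from the sizes stored in the checkpoint
    model = fc_model.Network(checkpoint['input_size'],
                             checkpoint['output_size'],
                             checkpoint['hidden_layers'])
    # Copy the saved weights and biases into the new network
    model.load_state_dict(checkpoint['state_dict'])

    return model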

In [41]:
model = load_checkpoint('checkpoint.pth')
print(model)

Network(
  (hidden_layers): ModuleList(
    (0): Linear(in_features=784, out_features=400, bias=True)
    (1): Linear(in_features=400, out_features=200, bias=True)
    (2): Linear(in_features=200, out_features=100, bias=True)
  )
  (output): Linear(in_features=100, out_features=10, bias=True)
  (dropout): Dropout(p=0.5, inplace=False)
)
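Saving can be wrapped in a function the same way as loading. The sketch below pairs a hypothetical `save_checkpoint` with a loader like the one above. So that it runs without `fc_model.py`, it defines a stand-in `Network` class: the layer layout follows the printed model, but the forward pass (ReLU, dropout, log-softmax) is my assumption and may differ from the real `fc_model.Network`.

```python
import os
import tempfile

import torch
from torch import nn
import torch.nn.functional as F


class Network(nn.Module):
    """Stand-in for fc_model.Network. The layer layout mirrors the printed
    model; the forward pass is an assumption."""

    def __init__(self, input_size, output_size, hidden_layers, drop_p=0.5):
        super().__init__()
        sizes = [input_size] + list(hidden_layers)
        self.hidden_layers = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(sizes[:-1], sizes[1:])])
        self.output = nn.Linear(sizes[-1], output_size)
        self.dropout = nn.Dropout(p=drop_p)

    def forward(self, x):
        for layer in self.hidden_layers:
            x = self.dropout(F.relu(layer(x)))
        return F.log_softmax(self.output(x), dim=1)


def save_checkpoint(model, filepath):
    """Hypothetical helper: store the architecture sizes with the weights."""
    checkpoint = {'input_size': model.hidden_layers[0].in_features,
                  'output_size': model.output.out_features,
                  'hidden_layers': [each.out_features for each in model.hidden_layers],
                  'state_dict': model.state_dict()}
    torch.save(checkpoint, filepath)


def load_checkpoint(filepath):
    """Rebuild the network from the stored sizes, then restore the weights."""
    checkpoint = torch.load(filepath)
    model = Network(checkpoint['input_size'],
                    checkpoint['output_size'],
                    checkpoint['hidden_layers'])
    model.load_state_dict(checkpoint['state_dict'])
    return model


original = Network(784, 10, [512, 256, 128])

with tempfile.TemporaryDirectory() as tmpdir:
    path = os.path.join(tmpdir, 'checkpoint.pth')
    save_checkpoint(original, path)
    rebuilt = load_checkpoint(path)

rebuilt_sizes = [layer.out_features for layer in rebuilt.hidden_layers]
weights_match = all(torch.equal(tensor, rebuilt.state_dict()[name])
                    for name, tensor in original.state_dict().items())
print(rebuilt_sizes, weights_match)
```

Reading the sizes from the model itself, rather than hard-coding 784 and 10, keeps the checkpoint consistent with whatever network is passed in.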
