
# Save and Load Your PyTorch Models

By [Adrian Tam](https://machinelearningmastery.com/author/adriantam/ "Posts by Adrian Tam") on February 13, 2023 in [Deep Learning with PyTorch](https://machinelearningmastery.com/category/deep-learning-with-pytorch/ "View all items in Deep Learning with PyTorch")


Last Updated on April 8, 2023

A deep learning model is a mathematical abstraction of data that involves a large number of parameters. Training these parameters can take hours, days, or even weeks, but afterward you can apply the result to new data. This is called inference in machine learning. It is important to know how to preserve the trained model on disk and later load it for inference. In this post, you will discover how to save your PyTorch models to files and load them again to make predictions. After reading this post, you will know:

- What are states and parameters in a PyTorch model
- How to save model states
- How to load model states

**Kick-start your project** with my book [Deep Learning with PyTorch](https://machinelearningmastery.com/deep-learning-with-pytorch/). It provides **self-study tutorials** with **working code**.

## Overview

This post is in three parts; they are

- Build an Example Model
- What’s Inside a PyTorch Model
- Accessing `state_dict` of a Model

## Build an Example Model

Let’s start with a very simple model in PyTorch. It is a model based on the iris dataset. You will load the dataset using scikit-learn (where the targets are the integer labels 0, 1, and 2) and train a neural network for this multiclass classification problem. In this model, you use log softmax as the output activation so that you can combine it with the negative log likelihood loss function. This is equivalent to using no output activation combined with the cross entropy loss function.

In [1]:
#@title Example Model
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# Load data into NumPy arrays
data = load_iris()
X, y = data["data"], data["target"]

# convert NumPy array into PyTorch tensors
X = torch.tensor(X, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.long)

# split
X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.7, shuffle=True)

# PyTorch model
class Multiclass(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(4, 8)
        self.act = nn.ReLU()
        self.output = nn.Linear(8, 3)
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        x = self.act(self.hidden(x))
        x = self.logsoftmax(self.output(x))
        return x

model = Multiclass()
    
# loss metric and optimizer
loss_fn = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# prepare model and training parameters
n_epochs = 100
batch_size = 5
batch_start = torch.arange(0, len(X_train), batch_size)

# training loop
for epoch in range(n_epochs):
    for start in batch_start:
        # take a batch
        X_batch = X_train[start:start+batch_size]
        y_batch = y_train[start:start+batch_size]
        # forward pass
        y_pred = model(X_batch)
        loss = loss_fn(y_pred, y_batch)
        # backward pass
        optimizer.zero_grad()
        loss.backward()
        # update weights
        optimizer.step()

With such a simple model and small dataset, it shouldn’t take long to finish training. Afterward, you can confirm that the model works by evaluating it with the test set:

In [2]:
...
y_pred = model(X_test)
acc = (torch.argmax(y_pred, 1) == y_test).float().mean()
print("Accuracy: %.2f" % acc)

Accuracy: 1.00
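
The equivalence mentioned earlier, between log softmax with the negative log likelihood loss and raw logits with the cross entropy loss, can be checked numerically. The following is a minimal sketch: the `logits` and `targets` tensors are made up for illustration and are not part of the iris example.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Made-up data: 6 samples, 3 classes (hypothetical, not the iris dataset)
logits = torch.randn(6, 3)
targets = torch.tensor([0, 2, 1, 1, 0, 2])

# Option 1: log softmax activation followed by negative log likelihood loss
loss_nll = nn.NLLLoss()(nn.LogSoftmax(dim=1)(logits), targets)

# Option 2: no output activation, cross entropy loss on the raw logits
loss_ce = nn.CrossEntropyLoss()(logits, targets)

# Compare the two scalar losses up to floating point tolerance
same_loss = torch.allclose(loss_nll, loss_ce)
print(same_loss)
```

If the two losses agree, either formulation can be used for training. The choice mainly affects whether the model outputs log probabilities or raw logits.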


## What’s Inside a PyTorch Model

A PyTorch model is an object in Python. It holds deep learning building blocks such as various kinds of layers and activation functions. It also knows how to connect them so that it can produce an output from your input tensors. The algorithm of a model is fixed at the time you create it. However, it has trainable parameters that are modified during the training loop so that the model becomes more accurate.

You saw how to get the model parameters when you set up the optimizer for your training loop, namely,

In [None]:

optimizer = optim.Adam(model.parameters(), lr=0.001)
optimizer

The function `model.parameters()` gives you a generator that references each layer’s trainable parameters in turn, in the form of PyTorch tensors. Therefore, you can make a copy of them or overwrite them, for example:

In [3]:

# Create a new model
newmodel = Multiclass()

# Ask PyTorch to ignore autograd on update and overwrite parameters
# Copy old tensor from model.parameters() to new tensor of newmodel.parameters()
with torch.no_grad():
    for newtensor, oldtensor in zip(newmodel.parameters(), model.parameters()):
        newtensor.copy_(oldtensor)

# newmodel has been initialized by copying its
# parameter tensors from model

# test with new model using copied tensor
y_pred = newmodel(X_test)

# show accuracy
acc = (torch.argmax(y_pred, 1) == y_test).float().mean()
print("Accuracy: %.2f" % acc)

Accuracy: 1.00


The result should be exactly the same as before, since you essentially made the two models identical by copying the parameters.

However, this is not always the case. Some models have **non-trainable parameters**. One example is the batch normalization layer that is common in many convolutional neural networks. It normalizes the tensors produced by the previous layer and passes the normalized tensors on to the next layer. To do so, it keeps two statistics, a running mean and a running variance, which are estimated from your input data during the training loop but are not updated by the optimizer. Therefore, these are not part of `model.parameters()` but are equally important.
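
To make the distinction concrete, the following minimal sketch lists the trainable parameters and the non-trainable buffers of a small network with a batch normalization layer. The `bn_model` network is a hypothetical example and is not the iris model used in this post.

```python
import torch
import torch.nn as nn

# A small, hypothetical model with a batch normalization layer
bn_model = nn.Sequential(
    nn.Linear(4, 8),
    nn.BatchNorm1d(8),
    nn.ReLU(),
    nn.Linear(8, 3),
)

# Names seen by the optimizer (trainable parameters only)
param_names = [name for name, _ in bn_model.named_parameters()]
# Names of the non-trainable state that PyTorch keeps as buffers
buffer_names = [name for name, _ in bn_model.named_buffers()]

print("parameters:", param_names)
print("buffers:", buffer_names)
```

The running statistics of the batch normalization layer show up only among the buffers, so copying `parameters()` alone would leave them behind.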

## Accessing `state_dict` of a Model

To access all parameters of a model, trainable or not, you can use the `state_dict()` function. For the model above, this is what you get:

In [None]:

import pprint
pp = pprint.PrettyPrinter(indent=4)
pp.pprint(model.state_dict())

It is called `state_dict` because all state variables of a model are here. It is an `OrderedDict` object from the `collections` module in Python’s standard library. Every component of a PyTorch model has a name, and so do the parameters therein. The `OrderedDict` object allows you to map the weights back to the parameters correctly by matching their names.
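
The name matching can be illustrated with a short sketch. `TinyNet` below is a hypothetical stand-in that reuses the layer names `hidden` and `output` from the model above. It prints each entry of the state dict with its shape, then loads the states into a second instance.

```python
import torch
import torch.nn as nn

class TinyNet(nn.Module):
    """Hypothetical stand-in with the same layer names as Multiclass."""
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(4, 8)
        self.output = nn.Linear(8, 3)

    def forward(self, x):
        return self.output(torch.relu(self.hidden(x)))

source = TinyNet()
target = TinyNet()

# Each key has the form "<attribute name>.<parameter name>"
shapes = {name: tuple(t.shape) for name, t in source.state_dict().items()}
for name, shape in shapes.items():
    print(name, shape)

# load_state_dict() assigns tensors to layers by matching these names
target.load_state_dict(source.state_dict())
weights_match = torch.equal(target.hidden.weight, source.hidden.weight)
print(weights_match)
```

Because the keys come from attribute names, renaming a layer in the class definition changes the keys, and an older state dict would no longer match.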

This is how you should save and load the model: fetch the model states into an `OrderedDict`, serialize it, and save it to disk. For inference, you create a model first (without training) and load the states. In Python, the native format for serialization is pickle:

In [6]:
#@title Save model
import pickle

# Save model
with open("iris-model.pickle", "wb") as fp:
    pickle.dump(model.state_dict(), fp)
    
# Create new model and load states
newmodel = Multiclass()
with open("iris-model.pickle", "rb") as fp:
    newmodel.load_state_dict(pickle.load(fp))

# test with new model using loaded states
y_pred = newmodel(X_test)
acc = (torch.argmax(y_pred, 1) == y_test).float().mean()
print("Accuracy: %.2f" % acc)

Accuracy: 1.00


You know it works because the model you didn’t train produced the same result as the one you trained.

In fact, the recommended way is to use the PyTorch API to save and load the states, instead of using pickle manually:

In [7]:
#@title Save model
torch.save(model.state_dict(), "iris-model.pth")

# Create new model and load states
newmodel = Multiclass()
newmodel.load_state_dict(torch.load("iris-model.pth"))

# test with new model using loaded states
y_pred = newmodel(X_test)
acc = (torch.argmax(y_pred, 1) == y_test).float().mean()
print("Accuracy: %.2f" % acc)

Accuracy: 1.00


The `*.pth` file is in fact a zip file containing pickle files created by PyTorch. This format is recommended because PyTorch can store additional information in it. Note that you stored only the states but not the model. You still need to create the model using Python code and load the states into it. If you wish to store the model as well, you can pass in the entire model instead of the states:

In [None]:

#@title Save model
torch.save(model, "iris-model-full.pth")
 
# Load model (recent PyTorch releases need weights_only=False to unpickle a full model)
newmodel = torch.load("iris-model-full.pth", weights_only=False)
 
# test with the loaded model
y_pred = newmodel(X_test)
acc = (torch.argmax(y_pred, 1) == y_test).float().mean()
print("Accuracy: %.2f" % acc)

But remember, due to the nature of the Python language, doing so does not relieve you from keeping the code of the model. The `newmodel` object above is an instance of the `Multiclass` class that you defined before. When you load the model from disk, Python needs to know in detail how this class is defined. If you run a script with just the `torch.load()` line, you will see the following error message:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/.../torch/serialization.py", line 789, in load
    return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
  File "/.../torch/serialization.py", line 1131, in _load
    result = unpickler.load()
  File "/.../torch/serialization.py", line 1124, in find_class
    return super().find_class(mod_name, name)
AttributeError: Can't get attribute 'Multiclass' on <module '__main__' (built-in)>


That’s why it is recommended to save only the state dict rather than the entire model.
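
As noted earlier, the file written by `torch.save()` is a zip archive. The sketch below inspects one with Python’s `zipfile` module, assuming the default zip-based serialization of recent PyTorch releases. The `layer` object and the file name `demo-model.pth` are made up for illustration.

```python
import os
import tempfile
import zipfile

import torch
import torch.nn as nn

layer = nn.Linear(4, 3)  # hypothetical stand-in for a trained model

with tempfile.TemporaryDirectory() as tmpdir:
    path = os.path.join(tmpdir, "demo-model.pth")  # made-up file name
    torch.save(layer.state_dict(), path)

    # Check the file format and list the archive members
    is_zip = zipfile.is_zipfile(path)
    with zipfile.ZipFile(path) as zf:
        members = zf.namelist()

print(is_zip)
for name in members:
    print(name)
```

The listing should include a pickle file that describes the saved object, plus separate entries that hold the raw tensor data. The exact member names may vary between PyTorch versions.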

Putting everything together, the following is the complete code to demonstrate how to create a model, train it, and save it to disk:

In [9]:
#@title How to create a model, train it, and save to disk
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# Load data into NumPy arrays
data = load_iris()
X, y = data["data"], data["target"]

# convert NumPy array into PyTorch tensors
X = torch.tensor(X, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.long)

# split
X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.7, shuffle=True)

# PyTorch model
class Multiclass(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(4, 8)
        self.act = nn.ReLU()
        self.output = nn.Linear(8, 3)
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        x = self.act(self.hidden(x))
        x = self.logsoftmax(self.output(x))
        return x

model = Multiclass()
    
# loss metric and optimizer
loss_fn = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# prepare model and training parameters
n_epochs = 100
batch_size = 5
batch_start = torch.arange(0, len(X_train), batch_size)

# training loop
for epoch in range(n_epochs):
    for start in batch_start:
        # take a batch
        X_batch = X_train[start:start+batch_size]
        y_batch = y_train[start:start+batch_size]
        # forward pass
        y_pred = model(X_batch)
        loss = loss_fn(y_pred, y_batch)
        # backward pass
        optimizer.zero_grad()
        loss.backward()
        # update weights
        optimizer.step()

# Save model
torch.save(model.state_dict(), "iris-model.pth")

And the following is how to load the model from disk and run it for inference:

In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# Load data into NumPy arrays
data = load_iris()
X, y = data["data"], data["target"]

# convert NumPy array into PyTorch tensors
X = torch.tensor(X, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.long)

# PyTorch model
class Multiclass(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(4, 8)
        self.act = nn.ReLU()
        self.output = nn.Linear(8, 3)
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        x = self.act(self.hidden(x))
        x = self.logsoftmax(self.output(x))
        return x

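# Create new model and load the states saved by torch.save()
model = Multiclass()
model.load_state_dict(torch.load("iris-model.pth"))

# Hold out a test set (a fresh random split, not the one used in training)
X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.7, shuffle=True)

# Run model for inference
y_pred = model(X_test)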
acc = (torch.argmax(y_pred, 1) == y_test).float().mean()
print("Accuracy: %.2f" % acc)

Accuracy: 1.00


## Further Readings

This section provides more resources on the topic if you are looking to go deeper.

- [Saving and loading models](https://pytorch.org/tutorials/beginner/saving_loading_models.html) from PyTorch tutorial

## Summary

In this post, you learned how to keep a copy of your trained PyTorch model on disk and how to reuse it. In particular, you learned:

- What are parameters and states in a PyTorch model
- How to save all necessary states from a model to disk
- How to rebuild a working model from the saved states
