# Save and Load a Neural Network Model
In this notebook, we demonstrate how to persist a PyTorch model: saving it to disk, loading it back, and running predictions with the restored model.

In [1]:
import os

import torch
import torchvision.models as models

## Saving Model Weights
`PyTorch` models store their learned parameters in an internal state dictionary, returned by `model.state_dict()`. This dictionary can be persisted to disk with the `torch.save` function.

In [8]:
# Build a VGG16 network initialised with the ImageNet-1k (V1) pre-trained weights
model = models.vgg16(weights='IMAGENET1K_V1')

# Show the layer-by-layer architecture
print(model)

# List every state-dictionary entry together with the shape of its tensor
print("\nState Dictionary:")
for param_name, tensor in model.state_dict().items():
    print(f"{param_name}: {tensor.shape}")

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1

In [5]:
# Make sure the output directory exists first: torch.save only writes the file,
# it does not create missing parent directories, so a fresh checkout (or a
# Restart & Run All on a clean machine) would otherwise fail here.
os.makedirs('out', exist_ok=True)

# Persist only the learned parameters (the state_dict), not the Python object;
# reloading them requires re-creating the same architecture first.
torch.save(model.state_dict(), 'out/vgg16_weights.pth')

## Loading Model Weights
To load model weights, we first create an instance of the same model architecture, then read the saved state dictionary with `torch.load` and copy it into the model with the `load_state_dict` method.

In [None]:
# Fresh VGG16 instance (no pre-trained weights requested) to receive the saved parameters
model = models.vgg16()

# Read the saved state_dict; weights_only=True restricts what torch.load will
# unpickle to tensors and simple containers, which is all a state_dict needs.
model.load_state_dict(
    torch.load('out/vgg16_weights.pth', weights_only=True)
)

# Put layers such as dropout / batch normalization into inference behaviour
# before making predictions; the returned model is this cell's displayed output.
model.eval()

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1

## Saving and Loading the Entire Model
When loading model weights, we had to instantiate the model class first, because the class defines the structure of the network.

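If we want to save the entire model, i.e., the structure of the network together with its learned parameters, we can pass the `model` object itself, instead of its `state_dict`, to `torch.save`. This saves the complete model, including its architecture and parameters. Because the whole Python object is pickled, loading it back requires `weights_only=False` in `torch.load`, which should only be used with files you trust.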

In [10]:
# Pickle the whole module object (architecture + parameters) in one file
torch.save(obj=model, f='out/vgg16_model.pth')

In [12]:
# Restore the complete model object saved above. A whole-module checkpoint is a
# pickled Python object, so weights_only must be False here; unpickling can run
# arbitrary code, so only do this with files you created or otherwise trust.
model = torch.load(f='out/vgg16_model.pth', weights_only=False)

In [13]:
# Show the architecture of the model restored from disk (header, then the model)
print("\nLoaded Model Architecture:", model, sep="\n")

# ...followed by its state dictionary: one line per entry with the tensor shape
print("\nLoaded Model State Dictionary:")
for param_name, tensor in model.state_dict().items():
    print(f"{param_name}: {tensor.shape}")



Loaded Model Architecture:
VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, str