# Saving and Loading Models
PyTorch provides several methods for saving and loading models.

This demo covers several of these methods using an example model.

#### Functions for Saving and Loading 
`torch.save()`: Saves PyTorch objects (models, tensors, dictionaries, etc.) to disk using Python's pickle module.

`torch.load()`: Loads PyTorch objects into memory.

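`load_state_dict()`: Loads a saved `state_dict` (a model's parameters and buffers, or an optimizer's state and hyperparameters) into a model or optimizer object.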

In [76]:
# Example Fake Model
import torch.nn as nn
import torch.nn.functional as F

class FakeNet(nn.Module):
    """Tiny example regressor: Linear(10->50) -> ReLU -> BatchNorm1d(50) -> Linear(50->1)."""

    def __init__(self):
        super().__init__()
        # Registration order (fc1, batch_norm, fc2) fixes the key order of state_dict().
        self.fc1 = nn.Linear(10, 50)
        self.batch_norm = nn.BatchNorm1d(50)
        self.fc2 = nn.Linear(50, 1)

    def forward(self, x):
        """Map a batch of 10-feature rows to one output value per row."""
        hidden = self.batch_norm(F.relu(self.fc1(x)))
        return self.fc2(hidden)


In [None]:
# Instantiate the example network and show its layer summary
model = FakeNet()
print(repr(model))

In [78]:
# Create a fake dataset
import torch
from torch.utils.data import Dataset, DataLoader

class FakeDataset(Dataset):
    """Synthetic regression dataset that draws fresh random samples on every access.

    Args:
        num_samples: Length reported by ``len()`` (default 1000).
    """

    def __init__(self, num_samples=1000):
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # idx is ignored: every call returns a new random (features, target) pair.
        features = torch.randn(10)  # 10 input features
        target = torch.randn(1)     # a single regression target
        return features, target



# Build 1000 random samples and serve them as shuffled mini-batches of 32
dataset = FakeDataset(num_samples=1000)
data_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True)

In [79]:
# Mean-squared-error loss, and plain SGD (lr=0.01) over the current model's parameters
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.01)

In [80]:
# Train the model for a few epochs and report each epoch's average loss
N_EPOCHS = 5

model.train()  # BatchNorm uses batch statistics and updates its running stats in train mode
for epoch in range(N_EPOCHS):
    running_loss = 0.0
    for inputs, targets in data_loader:
        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        # Accumulate loss as a plain Python float
        running_loss += loss.item()

    # Average loss over all batches of this epoch
    print(f"Epoch {epoch + 1}/{N_EPOCHS} - loss: {running_loss / len(data_loader):.4f}")


# Saving and Loading using `state_dict`
`state_dict` is a Python dictionary. For a model, it maps each layer to its learnable parameters (weights and biases) and persistent buffers (such as BatchNorm running statistics). For an optimizer, it holds the optimizer's internal state and its hyperparameters. This makes it easy to save, load, and transfer a model's parameters, allowing flexible saving and reloading across different environments.

In [None]:
# Show every tensor stored in the model's state_dict
model_state = model.state_dict()
print(model_state)

In [None]:
# Print the name and tensor shape of every entry in the state_dict
state = model.state_dict()
for k in state:
    v = state[k]
    print(f"Layer Name: {k} Parameters:{v.size()}")

In [None]:
# Show the optimizer's state_dict: its per-parameter 'state' and its 'param_groups' (hyperparameters such as lr)
opt_state = optimizer.state_dict()
print(opt_state)

In [84]:
# Save only the model's state_dict (the recommended approach)
import torch

# .pt or .pth are the conventional extensions for saved PyTorch files
torch.save(obj=model.state_dict(), f="model_state_dict.pt")

In [85]:
# Save the optimizer's state_dict too, using the same .pt convention as the model file
torch.save(optimizer.state_dict(), "optimizer_state_dict.pt")

In [86]:
# NOTE: a state_dict saves ONLY the parameters (and buffers), not the model class or its code!!!

### Model Inference
REVIEW: Inference is the process of using a trained model to make predictions.

Let's load a model using its state_dict and prepare it for inference.

In [None]:
# Build a second, freshly initialized (untrained) model
new_model = FakeNet()
print(repr(new_model))

In [None]:
# Show the freshly initialized values in the new model's state_dict
state = new_model.state_dict()
for k, v in zip(state.keys(), state.values()):
    print(f"Layer Name: {k} Parameters:{v}")

In [None]:
# Load the saved parameters; weights_only=True limits unpickling to tensors and plain containers
saved_state = torch.load("model_state_dict.pt", weights_only=True)
new_model.load_state_dict(saved_state)  # ONLY the parameters!

In [None]:
# Print the state_dict again: the values now match the saved file
for k, v in list(new_model.state_dict().items()):
    print(f"Layer Name: {k} Parameters:{v}")

In [91]:
# The parameters have been updated after loading!

In [None]:
# Create example input
import torch

# One sample (batch size 1) with 10 random features, matching fc1's input size of 10
sample_input = torch.randn(1, 10)
print(sample_input)

In [None]:
# Let's run an example inference on our reloaded model.
# eval() is required here: in training mode BatchNorm1d rejects a batch of a single sample,
# whereas in eval mode it normalizes with its stored running statistics.
new_model.eval()

# no_grad() skips building the autograd graph, which inference does not need
with torch.no_grad():
    output = new_model(sample_input)
print(output)

# Saving and Loading entire Model
PyTorch provides the option to save a full model to the filesystem as well.

Full model = the entire model object serialized with Python's pickle.

This can potentially cause issues because it relies on the exact class definitions and file structure from when the model was saved, so loading may fail if used in a different project or after code changes.

In [94]:
# Save the entire model object (a reference to its class plus its parameters) via pickle
import torch

full_model_path = "model_full.pt"
torch.save(model, full_model_path)

In [None]:
# Import the model class from file
from fake_net import FakeNet

# Initialize the model from the importable class, then copy in the trained weights,
# so the full-model file saved next contains the trained model rather than random weights.
# NOTE(review): assumes fake_net.FakeNet defines the same layers as the notebook's FakeNet.
model = FakeNet()
model.load_state_dict(torch.load("model_state_dict.pt", weights_only=True))
print(model)

In [96]:
# Save the full model again; pickle now records the class by its module path (fake_net.FakeNet)
torch.save(model, f="model_full.pt")

In [None]:
# Compare file sizes: the pickled full model vs. the state_dict-only file
# (IPython shell escape; runs only inside Jupyter/IPython)
!ls -lh model*

In [None]:
# Load a full model
from fake_net import FakeNet

# The class must be importable for a full-model pickle to be loaded; build an instance to inspect
new_model = FakeNet()
print(repr(new_model))

In [None]:
# Unpickle the whole model object. weights_only=False lets pickle construct arbitrary
# objects, so only load full-model files that come from a source you trust.
loaded = torch.load("model_full.pt", weights_only=False)  # More than just the parameters
new_model = loaded
print(repr(new_model))

In [None]:
# Check inference with the unpickled model.
# eval() makes BatchNorm use its running statistics (required for a batch of one sample).
new_model.eval()

# Inference needs no gradients, so skip building the autograd graph
with torch.no_grad():
    output = new_model(sample_input)
print(output)

# Saving and Loading a Checkpoint
A model checkpoint saves the training state (model parameters, optimizer state, and extra information such as the epoch and loss) as a snapshot at a point in time.

This is helpful to continue a long training job that may have failed at some point or to give multiple models as options to use from a training run.

In [101]:
# Save a checkpoint
import torch

# Dummy values standing in for a real training run's epoch and loss
epoch, loss = 5, 0.05

In [102]:
# Bundle everything needed to resume training into one dict and save it
checkpoint_data = {
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
}
torch.save(checkpoint_data, f'{epoch}_checkpoint.tar')  # .tar file

In [None]:
# Load the checkpoint
# Initialize the Model as we have before. NOTE: also optimizer in our case
from fake_net import FakeNet

# Initialize the model AND a fresh optimizer bound to its parameters. An optimizer keeps
# references to the parameter tensors it was constructed with, so reusing the old one
# would keep updating the previous model instead of this new one.
model = FakeNet()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # same settings as the original optimizer
print(model)

In [104]:
# Load the model as a checkpoint
import torch

# Load the tar file. weights_only takes a bool (not the string 'true'); this checkpoint
# holds only ints/floats, dicts and tensors, which the weights-only unpickler accepts.
checkpoint = torch.load(f"{epoch}_checkpoint.tar", weights_only=True)

In [None]:
# Show the checkpoint contents: epoch, model/optimizer state_dicts and loss
print(checkpoint)

In [None]:
# Restore the model's parameters from the checkpoint
model_state = checkpoint['model_state_dict']
model.load_state_dict(state_dict=model_state)

In [107]:
# Restore the optimizer's state (including hyperparameters such as lr) from the checkpoint
optimizer.load_state_dict(state_dict=checkpoint['optimizer_state_dict'])

In [None]:
# Restore the bookkeeping values. NOTE: a checkpoint could have saved other information as well
epoch, loss = checkpoint['epoch'], checkpoint['loss']
print(loss, epoch)

In [None]:
# Test inference with the restored model (eval mode + no autograd graph)
model.eval()
with torch.no_grad():
    output = model(sample_input)
print(output)

# Adding Checkpoints to Training
It's good practice to include checkpoints as part of your training loop.

How you save checkpoints is up to you, e.g., every few epochs, every epoch, or only on epochs that improve the loss.

In [110]:
# Train with checkpointing: save a checkpoint every 2 epochs (epochs 0, 2, 4, ...)
# plus a final checkpoint after the last epoch.
N_EPOCHS = 10
CHECKPOINT_EVERY = 2


def _checkpoint_dict(epoch, loss):
    """Bundle the current model/optimizer state with the epoch and loss into one dict."""
    return {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }


# Earlier inference cells left the model in eval() mode; switch back so BatchNorm
# normalizes with batch statistics and keeps updating its running statistics.
model.train()

for epoch in range(N_EPOCHS):
    running_loss = 0.0
    for inputs, targets in data_loader:
        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        # Accumulate loss as a plain Python float
        running_loss += loss.item()

    # Average batch loss for this epoch, stored as a float rather than a graph-attached tensor
    epoch_loss = running_loss / len(data_loader)

    if epoch % CHECKPOINT_EVERY == 0:
        torch.save(_checkpoint_dict(epoch, epoch_loss), f'training_checkpoint_{epoch}.tar')

# Save the final checkpoint after the last epoch
torch.save(_checkpoint_dict(N_EPOCHS, epoch_loss), 'training_checkpoint_final.tar')

In [None]:
# List all the checkpoint files written by the training loop
# (IPython shell escape; runs only inside Jupyter/IPython)
!ls -l training_checkpoint*

NOTE: We can now load any of these checkpoints to either continue training from that point in time or run inference. 

# Warmstarting
Warmstarting is where we initialize a new model to train from trained parameters of a previously trained model.

This is helpful in Transfer Learning which is covered in more detail later.

With warmstarting we can also initialize only certain layers of a previously trained model.

In [112]:
# Example Fake Model
import torch.nn as nn
import torch.nn.functional as F

class FakeNet(nn.Module):
    """Example network used for warmstarting: Linear(10->50), ReLU, BatchNorm1d(50), Linear(50->1)."""

    def __init__(self):
        super().__init__()
        # Same attribute names as before, so state_dict keys line up for loading.
        self.fc1 = nn.Linear(10, 50)
        self.batch_norm = nn.BatchNorm1d(50)
        self.fc2 = nn.Linear(50, 1)

    def forward(self, x):
        """Return one prediction per input row (inputs have 10 features)."""
        activated = F.relu(self.fc1(x))
        normalized = self.batch_norm(activated)
        return self.fc2(normalized)

In [None]:
# Create a new, untrained model that we will warmstart
new_model = FakeNet()
new_model  # the notebook displays the module summary

In [None]:
# Show the randomly initialized parameters before warmstarting
fresh_state = new_model.state_dict()
fresh_state

In [None]:
# Load our very first trained model parameters into the new one.
# strict=False tolerates missing/unexpected keys (useful when only some layers match);
# the returned result lists any keys that were skipped. weights_only=True matches the
# other state_dict loads in this notebook.
new_model.load_state_dict(torch.load('model_state_dict.pt', weights_only=True), strict=False)

In [None]:
# Show the parameters after warmstarting
warm_state = new_model.state_dict()
print(warm_state)

We would now take the parameters we just added into this model and train it!

# Saving and Loading Across Devices
PyTorch supports multiple different devices such as CPU and GPUs.

It's common practice, for example, to train on a GPU for speed but run inference on a CPU to save cost.

In [117]:
# Load a model on CPU that was saved on GPU
import torch

# map_location='cpu' remaps every stored tensor (even ones saved from a GPU) onto the CPU.
# torch.load returns the saved state_dict, not a model, so load it into a FakeNet
# instance instead of rebinding `model` to a plain dictionary.
state_dict = torch.load('model_state_dict.pt', map_location='cpu', weights_only=True)
model = FakeNet()
model.load_state_dict(state_dict)

In [None]:
# CPU to GPU and GPU to GPU: map_location places every stored tensor on the given GPU.
# Guarded so the cell does not crash on a machine without CUDA.
if torch.cuda.is_available():
    state_dict = torch.load('model_state_dict.pt', map_location='cuda:0', weights_only=True)
    # torch.load returns a state_dict, so load it into a model rather than rebinding `model`
    model = FakeNet()
    model.load_state_dict(state_dict)

In addition to the above, we must also move the model to the GPU:

```py
model.to('cuda')
```

We must also move the inputs to the GPU for inference:
```py
model.eval()
outputs = model(sample_input.to('cuda'))
```

In [None]:
# Pick the GPU when CUDA is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)