# Saving and Loading Models

Training neural networks can take a long time, so it is a good idea to save your progress regularly. In this notebook we will go over how to save and load our models and training progress. Let's do our usual imports and define a model to save.

In [1]:
# standard libraries
import math, os, time, glob
import numpy as np

# plotting
import matplotlib.pyplot as plt

# progress bars
from tqdm.notebook import trange, tqdm

# PyTorch
import torch
import torch.nn as nn # let's not write out torch.nn every time

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [2]:
hidden = 32 # how many hidden units we want

network = torch.nn.Sequential(
    torch.nn.Linear(1, hidden), # = x A^T + b
    torch.nn.ReLU(), # = ReLU(x A^T + b)
    torch.nn.Linear(hidden, 1, bias=False), # = ReLU(x A^T + b) C^T
)

network = network.to(device)

## Saving and Loading a whole Model

To save an entire model, we pass it to `torch.save` along with a file name.

In [3]:
torch.save(network, 'model.pth')

And then restore it again.

In [4]:
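# weights_only=False is required on PyTorch >= 2.6, where torch.load defaults to loading
# tensors only; unpickling a full model can run arbitrary code, so only load trusted files
network2 = torch.load('model.pth', weights_only=False)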

network2

Sequential(
  (0): Linear(in_features=1, out_features=32, bias=True)
  (1): ReLU()
  (2): Linear(in_features=32, out_features=1, bias=False)
)

The disadvantage of this approach is that `torch.save` pickles the whole model object, and unpickling it requires the exact same definitions of all the modules that were used when saving. As long as you stay in this environment, with its specific version of PyTorch, and do not define your own modules, saving the entire model is fine. In more advanced scenarios it is recommended to save only the model's parameters (its `state_dict`); more details are available at <https://pytorch.org/tutorials/beginner/saving_loading_models.html>.
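A minimal sketch of the parameters-only approach, assuming the same architecture as above. The model is rebuilt in code by a hypothetical helper `make_network`, only its `state_dict` (a mapping from parameter names to tensors) is written to disk, and the weights are then copied into a fresh instance with `load_state_dict`. The file name `model_weights.pth` is just an example.

```python
import torch
import torch.nn as nn

def make_network(hidden=32):
    # Hypothetical helper: the architecture lives in code, only weights go to disk
    return nn.Sequential(
        nn.Linear(1, hidden),
        nn.ReLU(),
        nn.Linear(hidden, 1, bias=False),
    )

network = make_network()

# Save only the parameters (names -> tensors), not the pickled module object
torch.save(network.state_dict(), 'model_weights.pth')

# Later: rebuild the model and copy the saved parameters into it
network2 = make_network()
network2.load_state_dict(torch.load('model_weights.pth', map_location='cpu'))
network2.eval()  # switch to evaluation mode before inference
```

Because the file contains only named tensors, loading it does not require unpickling any module classes; it only requires that the rebuilt model has parameters with the same names and shapes.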

## Checkpointing

During training we might want to save not just the model or its parameters but also ancillary information such as the current epoch and the state of the optimizer. In the example below we save the epoch number together with the model and optimizer state dictionaries in a file tagged with a timestamp.
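A checkpoint is simply a Python dictionary handed to `torch.save`. The round trip below is a sketch with a small stand-in model; it assumes SGD *with momentum* (unlike the plain SGD used in this notebook) so that the optimizer actually has internal state worth saving. The file name `check_example.pth` is hypothetical.

```python
import torch
import torch.nn as nn

def build():
    # Stand-in model mirroring the notebook's architecture
    return nn.Sequential(nn.Linear(1, 32), nn.ReLU(), nn.Linear(32, 1, bias=False))

model = build()
optimizer = torch.optim.SGD(model.parameters(), lr=0.005, momentum=0.9)

# One training step so the optimizer holds momentum buffers
x = torch.rand(8, 1)
loss = model(x).pow(2).sum()
loss.backward()
optimizer.step()

# Bundle everything needed to resume into one dictionary
torch.save({
    'epoch': 1,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, 'check_example.pth')

# Restore into freshly constructed objects
model2 = build()
optimizer2 = torch.optim.SGD(model2.parameters(), lr=0.005, momentum=0.9)
restored = torch.load('check_example.pth')
model2.load_state_dict(restored['model_state_dict'])
optimizer2.load_state_dict(restored['optimizer_state_dict'])
start_epoch = restored['epoch'] + 1  # resume with the next epoch
```

Restoring the optimizer state matters for optimizers such as momentum SGD or Adam: without it, training would resume with empty buffers even though the weights are correct.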



In [5]:
def timestamp():
    t = time.localtime()
    return time.strftime('%Y_%b_%d_%H_%M_%S', t)
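Because the newest checkpoint is later found by sorting file names, the `%b` field (abbreviated month name) in `timestamp()` only yields chronological order within a single month. A quick check, assuming the default C locale for month names; `FMT_SORTABLE` is a suggested alternative, not part of the original notebook.

```python
import time

FMT_NAME = '%Y_%b_%d_%H_%M_%S'      # format used by timestamp() above
FMT_SORTABLE = '%Y_%m_%d_%H_%M_%S'  # zero-padded month number instead of its name

# Two moments on either side of a month boundary
sep_30 = time.strptime('2021-09-30 12:00:00', '%Y-%m-%d %H:%M:%S')
oct_01 = time.strptime('2021-10-01 12:00:00', '%Y-%m-%d %H:%M:%S')

# Month abbreviations sort alphabetically: 'Oct' comes before 'Sep'
by_name = sorted(time.strftime(FMT_NAME, t) for t in (sep_30, oct_01))

# Zero-padded numeric fields sort in chronological order
by_number = sorted(time.strftime(FMT_SORTABLE, t) for t in (sep_30, oct_01))

print(by_name)    # ['2021_Oct_01_12_00_00', '2021_Sep_30_12_00_00'] (C locale)
print(by_number)  # ['2021_09_30_12_00_00', '2021_10_01_12_00_00']
```

Swapping `%b` for `%m` in `timestamp()` would make name-sorting safe across months and years. The timestamp also has one-second resolution, so two saves within the same second would overwrite each other; adding the epoch number to the file name is one way to avoid that.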

In [6]:
# Lets set up our training process in the usual manner

X = torch.arange(0., 1., step=0.01, device=device)
Y = 0.6 * torch.sin(6*X) * torch.sin(3*X+1) + 0.25
X = X.unsqueeze(-1)
Y = Y.unsqueeze(-1)


dataset = torch.utils.data.TensorDataset(X, Y)
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_data, test_data = torch.utils.data.random_split(dataset, [train_size, test_size])

EPOCHS = 1000
BATCH_SIZE = 8
LEARNING_RATE = 0.005

loader = torch.utils.data.DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)

optimizer = torch.optim.SGD(network.parameters(), lr=LEARNING_RATE)

In [7]:
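# Now we can start training
for epoch in trange(1, EPOCHS+1):

    network.train()  # training mode (matters for layers like dropout; this network has none)

    for x, y in loader:
        optimizer.zero_grad()                 # reset gradients from the previous step
        prediction = network(x)
        loss = (y - prediction).pow(2).sum()  # sum of squared errors over the batch
        loss.backward()                       # backpropagate
        optimizer.step()                      # update the parameters

    # Save a checkpoint every 200 epochs and after the final epoch
    if epoch % 200 == 0 or epoch == EPOCHS:
        torch.save({
            'epoch': epoch,
            'model_state_dict': network.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, f"check_{timestamp()}.pth")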


We can list the checkpoint files we have written to disk using `glob`. `glob.glob` returns a list of the file names that match a pattern, in no particular order, so we sort the list by name using `sorted`.
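Sorting by name gives chronological order only if the names themselves sort chronologically. A sketch of an alternative that ignores the naming scheme and orders files by modification time with `os.path.getmtime`; the temporary directory and file names are illustrative, and the modification times are set explicitly with `os.utime` so the example does not depend on the file system's timestamp resolution.

```python
import glob
import os
import tempfile

with tempfile.TemporaryDirectory() as tmp:
    # Create files whose names do NOT sort in creation order
    names = ['check_b.pth', 'check_a.pth', 'check_c.pth']
    for i, name in enumerate(names):
        path = os.path.join(tmp, name)
        open(path, 'wb').close()
        # Give each file a strictly increasing modification time (seconds since epoch)
        os.utime(path, (1_000_000 + i, 1_000_000 + i))

    files = glob.glob(os.path.join(tmp, 'check_*.pth'))
    by_name = [os.path.basename(f) for f in sorted(files)]
    by_mtime = [os.path.basename(f) for f in sorted(files, key=os.path.getmtime)]

print(by_name)   # ['check_a.pth', 'check_b.pth', 'check_c.pth']
print(by_mtime)  # ['check_b.pth', 'check_a.pth', 'check_c.pth']
```

With this ordering, the last entry is the most recently written file regardless of its name. Modification times can change when files are copied without preserving metadata, though, so a name with a sortable timestamp is often the more robust choice.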

In [8]:
sorted(glob.glob('check_*.pth')) # matching file names, sorted by name

['check_2021_Sep_19_17_24_39.pth',
 'check_2021_Sep_19_17_24_40.pth',
 'check_2021_Sep_19_17_24_41.pth',
 'check_2021_Sep_19_17_24_42.pth',
 'check_2021_Sep_19_17_24_44.pth']

We do not want to keep too many checkpoint files around, so we delete all except the three newest.

In [9]:
checkpoints = sorted(glob.glob('check_*.pth'))
for file in checkpoints[:-3]:
    os.remove(file)
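This clean-up step appears twice in the notebook, so it may be worth wrapping in a function. A sketch, where `prune_checkpoints` is a hypothetical helper that relies on the same name-based ordering as above; the usage runs in a throwaway directory with dummy files.

```python
import glob
import os
import tempfile

def prune_checkpoints(pattern, keep=3):
    """Delete all but the last `keep` files matching `pattern` in sorted name order.

    Returns the list of file paths that were kept.
    """
    files = sorted(glob.glob(pattern))
    # files[:-0] would be an empty list, so handle keep == 0 explicitly
    to_delete = files if keep == 0 else files[:-keep]
    for f in to_delete:
        os.remove(f)
    return files[len(to_delete):]

# Usage: five dummy checkpoint files, keep the three newest
with tempfile.TemporaryDirectory() as tmp:
    for i in range(5):
        open(os.path.join(tmp, f'check_{i}.pth'), 'wb').close()
    kept = [os.path.basename(f)
            for f in prune_checkpoints(os.path.join(tmp, 'check_*.pth'), keep=3)]
    remaining = sorted(os.listdir(tmp))

print(remaining)  # ['check_2.pth', 'check_3.pth', 'check_4.pth']
```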

Now let us load the newest checkpoint and restore our model and optimizer to their previous state.

In [10]:
checkpoint = torch.load(checkpoints[-1])

network.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
last_epoch = checkpoint['epoch']
print(f"Restored from epoch {last_epoch}")

Restored from epoch 1000
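If a checkpoint was written on a GPU and is restored on a machine without one, `torch.load` can remap the stored tensors through its `map_location` argument. A small sketch; the file name `check_device_example.pth` is hypothetical.

```python
import torch

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Write a small checkpoint whose tensor lives on `device` (the GPU if available)
torch.save({'epoch': 5, 'weight': torch.ones(2, 3, device=device)},
           'check_device_example.pth')

# map_location moves every stored tensor onto the requested device while loading
checkpoint = torch.load('check_device_example.pth', map_location='cpu')
print(checkpoint['weight'].device)  # cpu
```

Passing `map_location=device` instead would place the tensors on whichever device the current session uses.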


And we can continue training from where we left off.

In [11]:
# Resume with the epoch after the one stored in the checkpoint
for epoch in trange(last_epoch + 1, last_epoch + EPOCHS + 1):

    network.train()

    for x, y in loader:
        optimizer.zero_grad()
        prediction = network(x)
        loss = (y - prediction).pow(2).sum()
        loss.backward()
        optimizer.step()

    # Save a checkpoint every 200 epochs and after the final epoch
    if epoch % 200 == 0 or epoch == last_epoch + EPOCHS:
        torch.save({
            'epoch': epoch,
            'model_state_dict': network.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, f"check_{timestamp()}.pth")

# clean up checkpoint files, keeping the three newest
checkpoints = sorted(glob.glob('check_*.pth'))
for file in checkpoints[:-3]:
    os.remove(file)
