# Learning Objectives

Building on the previous step, we'll learn how to save and load models, either to resume training or to apply them to new data.

_Note: we have packaged some of the training functions in the `helpers.py` module to simplify the notebook._

### Learning Objectives

- save and load a model to resume interrupted training
- save and load a model for use in deployment

### Requirements

To benefit from this content, it is preferable to know:
- how to train a simple model (see step 03)

In [1]:
import torch
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from helpers import BasicMultiClassTargetDataset
from helpers import BasicNeuralNet
from helpers import BasicModelTrainer

# 1. Training a model (iris again)

This whole section should already be familiar: we train the same basic Iris neural net again. This time, all the steps are packaged into the `helpers.py` module so that the usual boilerplate code stays out of the notebook.

We will just:
- load the iris data from scikit-learn
- package it in a torch `Dataset`
- create a `Module` class for our model
- execute a training loop using autograd.
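
Since `helpers.py` is not shown here, the exact code of `BasicMultiClassTargetDataset` and `BasicNeuralNet` is unknown. The sketch below is a hypothetical stand-in, assuming the dataset one-hot encodes the targets (the testing code in section 4 calls `argmax` on them) and that the network matches the layer structure printed later in this notebook (Linear, Sigmoid, Linear, Softmax). The names `OneHotDataset` and `TinyIrisNet` are illustrative only.

```python
import torch
from torch.utils.data import Dataset


class OneHotDataset(Dataset):
    """Hypothetical stand-in for helpers.BasicMultiClassTargetDataset."""

    def __init__(self, features, labels, class_count):
        # features: array-like of shape (n, 4); labels: integer class ids
        self.x = torch.as_tensor(features, dtype=torch.float32)
        labels = torch.as_tensor(labels, dtype=torch.long)
        # one-hot float targets, so they can be compared with MSELoss
        self.y = torch.nn.functional.one_hot(labels, class_count).float()

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]


class TinyIrisNet(torch.nn.Module):
    """Hypothetical stand-in mirroring the printed BasicNeuralNet layers."""

    def __init__(self, input_size, output_size, hidden_size):
        super().__init__()
        self.x_to_z = torch.nn.Linear(input_size, hidden_size)
        self.z_to_h = torch.nn.Sigmoid()
        self.h_to_s = torch.nn.Linear(hidden_size, output_size)
        self.s_to_y = torch.nn.Softmax(dim=1)

    def forward(self, x):
        return self.s_to_y(self.h_to_s(self.z_to_h(self.x_to_z(x))))
```

With these, `OneHotDataset(X_train, y_train, 3)` and `TinyIrisNet(4, 3, 6)` would play the roles of the helper classes used below.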

In [2]:
data = load_iris()

np.random.seed(481516)     # makes the train/test split reproducible between runs
torch.manual_seed(481516)  # also fixes weight initialization and any torch-based shuffling
X_train, X_test, y_train, y_test = train_test_split(data.data, data.target, test_size=0.33)

In [3]:
# see class in helpers.py
iris_training_dataset = BasicMultiClassTargetDataset(X_train, y_train, 3) # class_count=3
iris_testing_dataset = BasicMultiClassTargetDataset(X_test, y_test, 3)    # class_count=3

In [4]:
model = BasicNeuralNet(
    4,  # input has size 4 (attributes)
    3,  # output has size 3 (one-hot, 3 classes)
    6   # hidden layer (param)
)

We use plain SGD with `MSELoss` as the criterion. The optimizer is constructed from the model instance's `parameters()`, which are the tensors it will update.
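
A minimal sketch of one pass of the "usual loop" that `BasicModelTrainer.fit()` presumably runs (see step 03). The helper's real code is not shown, so `train_one_epoch` is hypothetical, and the usage below runs on random placeholder tensors rather than the Iris data. It also makes visible that SGD holds references to exactly the tensors passed in through `model.parameters()`.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def train_one_epoch(model, optimizer, criterion, loader):
    """Hypothetical training pass; returns the mean batch loss of the epoch."""
    model.train()
    total_loss = 0.0
    for inputs, targets in loader:
        optimizer.zero_grad()               # clear gradients left over from the previous batch
        outputs = model(inputs)             # forward pass
        loss = criterion(outputs, targets)  # e.g. MSE between softmax output and one-hot target
        loss.backward()                     # autograd fills p.grad for every parameter
        optimizer.step()                    # plain SGD update: p <- p - lr * p.grad
        total_loss += loss.item()
    return total_loss / len(loader)


# usage with random placeholder tensors (not the Iris data)
torch.manual_seed(0)
net = torch.nn.Sequential(
    torch.nn.Linear(4, 6), torch.nn.Sigmoid(),
    torch.nn.Linear(6, 3), torch.nn.Softmax(dim=1),
)
x = torch.randn(20, 4)
y = torch.nn.functional.one_hot(torch.randint(0, 3, (20,)), 3).float()
loader = DataLoader(TensorDataset(x, y), batch_size=10)
sgd = torch.optim.SGD(net.parameters(), lr=0.1)  # SGD keeps references to net's parameters
epoch_loss = train_one_epoch(net, sgd, torch.nn.MSELoss(), loader)
```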

In [5]:
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = torch.nn.MSELoss()

# this is a helper class just executing the usual loop (see step 03)
trainer = BasicModelTrainer(
    model,
    optimizer,
    criterion,
    verbose=True
)

epochs = 500

# executing the training
model, loss = trainer.fit(
    iris_training_dataset,
    epochs=epochs,  # just for trying
    batch_size=10
)

batch_loss=0.303242	 avg_loss=0.026988	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=0]	 epoch_loss=0.026988
batch_loss=0.299518	 avg_loss=0.025907	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=1]	 epoch_loss=0.025907
batch_loss=0.236603	 avg_loss=0.025199	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=2]	 epoch_loss=0.025199
batch_loss=0.230935	 avg_loss=0.024626	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=3]	 epoch_loss=0.024626
batch_loss=0.213823	 avg_loss=0.024131	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=4]	 epoch_loss=0.024131
batch_loss=0.206273	 avg_loss=0.023643	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=5]	 epoch_loss=0.023643
batch_loss=0.244583	 avg_loss=0.023215	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=6]	 epoch_loss=0.023215
batch_loss=0.226429	 avg_loss=0.022835	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=7]	 epoch_loss=0.022835
batch_loss=0.224443	 avg_loss=0.022557	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=8]	 epoch_loss=0.022557
...

batch_loss=0.099115	 avg_loss=0.013372	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=73]	 epoch_loss=0.013372
batch_loss=0.105574	 avg_loss=0.013278	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=74]	 epoch_loss=0.013278
batch_loss=0.108426	 avg_loss=0.013132	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=75]	 epoch_loss=0.013132
batch_loss=0.125160	 avg_loss=0.013041	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=76]	 epoch_loss=0.013041
batch_loss=0.145483	 avg_loss=0.012885	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=77]	 epoch_loss=0.012885
batch_loss=0.121510	 avg_loss=0.012813	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=78]	 epoch_loss=0.012813
batch_loss=0.152269	 avg_loss=0.012700	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=79]	 epoch_loss=0.012700
batch_loss=0.137863	 avg_loss=0.012609	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=80]	 epoch_loss=0.012609
batch_loss=0.130769	 avg_loss=0.012532	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=81]	 ...

batch_loss=0.095843	 avg_loss=0.009692	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=146]	 epoch_loss=0.009692
batch_loss=0.094493	 avg_loss=0.009637	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=147]	 epoch_loss=0.009637
batch_loss=0.110032	 avg_loss=0.009578	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=148]	 epoch_loss=0.009578
batch_loss=0.087718	 avg_loss=0.009589	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=149]	 epoch_loss=0.009589
batch_loss=0.132563	 avg_loss=0.009566	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=150]	 epoch_loss=0.009566
batch_loss=0.081678	 avg_loss=0.009495	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=151]	 epoch_loss=0.009495
batch_loss=0.104082	 avg_loss=0.009502	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=152]	 epoch_loss=0.009502
batch_loss=0.092372	 avg_loss=0.009464	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=153]	 epoch_loss=0.009464
batch_loss=0.083066	 avg_loss=0.009388	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=154]	 ...

batch_loss=0.094702	 avg_loss=0.007251	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=218]	 epoch_loss=0.007251
batch_loss=0.054635	 avg_loss=0.007273	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=219]	 epoch_loss=0.007273
batch_loss=0.071178	 avg_loss=0.007276	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=220]	 epoch_loss=0.007276
batch_loss=0.070639	 avg_loss=0.007180	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=221]	 epoch_loss=0.007180
batch_loss=0.080030	 avg_loss=0.007103	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=222]	 epoch_loss=0.007103
batch_loss=0.093507	 avg_loss=0.007161	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=223]	 epoch_loss=0.007161
batch_loss=0.088336	 avg_loss=0.007076	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=224]	 epoch_loss=0.007076
batch_loss=0.091850	 avg_loss=0.007145	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=225]	 epoch_loss=0.007145
batch_loss=0.094682	 avg_loss=0.007022	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=226]	 ...

batch_loss=0.039875	 avg_loss=0.005433	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=290]	 epoch_loss=0.005433
batch_loss=0.059270	 avg_loss=0.005229	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=291]	 epoch_loss=0.005229
batch_loss=0.034210	 avg_loss=0.005220	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=292]	 epoch_loss=0.005220
batch_loss=0.039790	 avg_loss=0.005050	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=293]	 epoch_loss=0.005050
batch_loss=0.048305	 avg_loss=0.005065	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=294]	 epoch_loss=0.005065
batch_loss=0.032187	 avg_loss=0.005185	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=295]	 epoch_loss=0.005185
batch_loss=0.073851	 avg_loss=0.005080	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=296]	 epoch_loss=0.005080
batch_loss=0.061647	 avg_loss=0.005046	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=297]	 epoch_loss=0.005046
batch_loss=0.030993	 avg_loss=0.005045	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=298]	 ...

batch_loss=0.016470	 avg_loss=0.003754	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=362]	 epoch_loss=0.003754
batch_loss=0.034155	 avg_loss=0.003843	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=363]	 epoch_loss=0.003843
batch_loss=0.058577	 avg_loss=0.003941	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=364]	 epoch_loss=0.003941
batch_loss=0.035034	 avg_loss=0.003953	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=365]	 epoch_loss=0.003953
batch_loss=0.028039	 avg_loss=0.003788	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=366]	 epoch_loss=0.003788
batch_loss=0.028894	 avg_loss=0.003850	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=367]	 epoch_loss=0.003850
batch_loss=0.047196	 avg_loss=0.003823	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=368]	 epoch_loss=0.003823
batch_loss=0.043864	 avg_loss=0.003895	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=369]	 epoch_loss=0.003895
batch_loss=0.046011	 avg_loss=0.003753	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=370]	 ...

batch_loss=0.052184	 avg_loss=0.003124	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=434]	 epoch_loss=0.003124
batch_loss=0.023898	 avg_loss=0.003215	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=435]	 epoch_loss=0.003215
batch_loss=0.024430	 avg_loss=0.003191	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=436]	 epoch_loss=0.003191
batch_loss=0.024844	 avg_loss=0.003101	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=437]	 epoch_loss=0.003101
batch_loss=0.035915	 avg_loss=0.003109	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=438]	 epoch_loss=0.003109
batch_loss=0.044338	 avg_loss=0.003119	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=439]	 epoch_loss=0.003119
batch_loss=0.039690	 avg_loss=0.003080	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=440]	 epoch_loss=0.003080
batch_loss=0.023983	 avg_loss=0.003093	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=441]	 epoch_loss=0.003093
batch_loss=0.032014	 avg_loss=0.003043	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=442]	 ...

In [6]:
print("Accuracy: {}%".format(
    trainer.test_accuracy(iris_testing_dataset)
))

Accuracy: 100.0%


# 2. Saving the model

See the [PyTorch tutorial on saving and loading models](https://pytorch.org/tutorials/beginner/saving_loading_models.html):
> When saving a general checkpoint, to be used for either inference or resuming training, you must save more than just the model’s `state_dict`. It is important to also save the optimizer’s `state_dict`, as this contains buffers and parameters that are updated as the model trains. Other items that you may want to save are the epoch you left off on, the latest recorded training loss, external `torch.nn.Embedding` layers, etc.
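
A small self-contained illustration of why the optimizer's `state_dict` is worth saving: some optimizers keep per-parameter buffers that evolve during training. This sketch assumes a stand-in `torch.nn.Linear` model and SGD *with momentum* (the notebook itself uses plain SGD), and writes to an in-memory `io.BytesIO` buffer instead of a file.

```python
import io

import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 3)
# momentum is an assumption here, used so the optimizer has per-parameter buffers
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# one dummy step so the momentum buffers get created
loss = model(torch.randn(8, 4)).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()

checkpoint = {
    "epoch": 1,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "loss": loss.item(),
}

buffer = io.BytesIO()  # in-memory stand-in for a checkpoint file
torch.save(checkpoint, buffer)
buffer.seek(0)
restored = torch.load(buffer)

# optimizer state is keyed by parameter index; each entry holds that parameter's buffers
opt_state = restored["optimizer_state_dict"]["state"]
print(sorted(opt_state[0].keys()))
```

Saving only `model_state_dict` here would lose the momentum buffers, and a resumed run would restart them from zero.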

In [7]:
import os

os.makedirs("models", exist_ok=True)  # torch.save() does not create missing directories
model_file_path = "models/step-04-model-state-epoch{}-loss{:2f}.tar".format(epochs, loss)

torch.save(
    {
        'epoch': epochs,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    },
    model_file_path
)

print("saved as {}".format(model_file_path))

saved as models/step-04-model-state-epoch500-loss0.002800.tar
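
The `{:2f}` spec in the file name is easy to misread: it sets a minimum field width of 2 and keeps the default precision of 6 digits, which is why the name ends in `loss0.002800`. A quick check in plain Python:

```python
value = 0.0028  # roughly the loss shown in the saved file name
print("{:2f}".format(value))   # 0.002800  (min width 2, default precision 6)
print("{:.2f}".format(value))  # 0.00      (2 decimals: too coarse for this loss)
print("{:.2e}".format(value))  # 2.80e-03  (scientific notation keeps the digits)
```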


# 3. Loading the model for resuming training

> To load the items, first initialize the model and optimizer, then load the dictionary locally using `torch.load()`. From here, you can easily access the saved items by simply querying the dictionary as you would expect.

In [8]:
model_2 = BasicNeuralNet(
    4,  # input has size 4 (attributes)
    3,  # output has size 3 (one-hot, 3 classes)
    6   # hidden layer (param)
)

optimizer_2 = torch.optim.SGD(model_2.parameters(), lr=0.01)
criterion_2 = torch.nn.MSELoss()

# uncomment the line below to load your own checkpoint instead...
#checkpoint_2 = torch.load(model_file_path)

# ...or use the provided demo checkpoint
checkpoint_2 = torch.load("models/step-04-demo-model-state-epoch500-loss0.238621.tar")

# this loads the state dict into the model and optimizer
model_2.load_state_dict(checkpoint_2['model_state_dict'])
optimizer_2.load_state_dict(checkpoint_2['optimizer_state_dict'])

restart_epoch = checkpoint_2['epoch']
loss_2 = checkpoint_2['loss']
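
One detail worth knowing: `optimizer.load_state_dict()` also restores the hyperparameters stored in the checkpoint's `param_groups`, such as the learning rate. So `optimizer_2` should continue with the learning rate saved in the checkpoint, not the `lr=0.01` passed to its constructor, unless you change it after loading. A self-contained check, using stand-in `torch.nn.Linear` models:

```python
import torch

model_a = torch.nn.Linear(4, 3)
opt_a = torch.optim.SGD(model_a.parameters(), lr=0.1)
saved = opt_a.state_dict()

# a fresh model/optimizer pair, deliberately created with another learning rate
model_b = torch.nn.Linear(4, 3)
opt_b = torch.optim.SGD(model_b.parameters(), lr=0.01)
opt_b.load_state_dict(saved)
print(opt_b.param_groups[0]["lr"])  # 0.1: the saved value wins

# to really use a different learning rate, set it on each param group after loading
for group in opt_b.param_groups:
    group["lr"] = 0.01
```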

> Remember that you must call `model.eval()` to set dropout and batch normalization layers to evaluation mode before running inference. Failing to do this will yield inconsistent inference results. If you wish to resume training, call `model.train()` to ensure these layers are in training mode.
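
The `BasicNeuralNet` printed below has neither dropout nor batch normalization, so for this model the mode makes no numerical difference; it matters as soon as such layers are added. A sketch with a toy `torch.nn.Sequential` model containing `torch.nn.Dropout` (an assumption, not the notebook's model):

```python
import torch

torch.manual_seed(0)
# toy model with a dropout layer (BasicNeuralNet itself has none)
net = torch.nn.Sequential(
    torch.nn.Linear(4, 8), torch.nn.Dropout(p=0.5), torch.nn.Linear(8, 3)
)
x = torch.randn(5, 4)

net.train()  # dropout active: random units are zeroed on each call
train_a, train_b = net(x), net(x)

net.eval()   # dropout disabled: the forward pass is deterministic
with torch.no_grad():
    eval_a, eval_b = net(x), net(x)

print(torch.equal(train_a, train_b))  # almost surely False (different dropout masks)
print(torch.equal(eval_a, eval_b))    # True
```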

In [9]:
# use .eval() when loading the model for inference (production)
#model_2.eval()

# use .train() when loading the model for training more (interrupted training?)
model_2.train()

BasicNeuralNet(
  (x_to_z): Linear(in_features=4, out_features=6, bias=True)
  (z_to_h): Sigmoid()
  (h_to_s): Linear(in_features=6, out_features=3, bias=True)
  (s_to_y): Softmax(dim=1)
)

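We can now resume training. Note that `fit()` starts counting epochs from 0 again, as the output below shows; the number of epochs completed before the checkpoint is available in `restart_epoch`.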
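
If you want the epoch counter (and later checkpoints) to continue from where the previous run stopped, a loop along these lines could be used instead. This is a hypothetical sketch: `resume_training` and `make_net` are not part of `helpers.py`, training is full-batch for brevity, and the "previous run" is a placeholder checkpoint at epoch 500 with untrained weights and random inputs.

```python
import torch


def resume_training(model, optimizer, criterion, inputs, targets, checkpoint, extra_epochs):
    """Hypothetical helper: continue full-batch training from checkpoint['epoch']."""
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    start_epoch = checkpoint["epoch"]
    model.train()
    last_loss = float("nan")
    for epoch in range(start_epoch, start_epoch + extra_epochs):
        # epoch numbers continue from the checkpoint: 500, 501, ...
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()
        last_loss = loss.item()
    # an updated checkpoint that can itself be saved with torch.save()
    return {
        "epoch": start_epoch + extra_epochs,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "loss": last_loss,
    }


def make_net():
    return torch.nn.Sequential(
        torch.nn.Linear(4, 6), torch.nn.Sigmoid(),
        torch.nn.Linear(6, 3), torch.nn.Softmax(dim=1),
    )


# usage: a placeholder "previous run" checkpoint at epoch 500 (weights are untrained)
torch.manual_seed(0)
old_net = make_net()
old_opt = torch.optim.SGD(old_net.parameters(), lr=0.1)
saved = {
    "epoch": 500,
    "model_state_dict": old_net.state_dict(),
    "optimizer_state_dict": old_opt.state_dict(),
}

x = torch.randn(30, 4)  # placeholder inputs
y = torch.nn.functional.one_hot(torch.randint(0, 3, (30,)), 3).float()
new_net = make_net()
new_opt = torch.optim.SGD(new_net.parameters(), lr=0.01)
updated = resume_training(new_net, new_opt, torch.nn.MSELoss(), x, y, saved, extra_epochs=10)
print(updated["epoch"])  # 510
```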

In [10]:
from helpers import BasicModelTrainer

trainer_2 = BasicModelTrainer(
    model_2,
    optimizer_2,
    criterion_2,
    verbose=True
)

In [11]:
model_3, loss_3 = trainer_2.fit(
    iris_training_dataset,
    epochs=10,
    batch_size=10
)

batch_loss=0.014966	 avg_loss=0.002447	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=0]	 epoch_loss=0.002447
batch_loss=0.039009	 avg_loss=0.002326	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=1]	 epoch_loss=0.002326
batch_loss=0.008579	 avg_loss=0.002361	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=2]	 epoch_loss=0.002361
batch_loss=0.031452	 avg_loss=0.002334	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=3]	 epoch_loss=0.002334
batch_loss=0.023984	 avg_loss=0.002346	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=4]	 epoch_loss=0.002346
batch_loss=0.014880	 avg_loss=0.002414	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=5]	 epoch_loss=0.002414
batch_loss=0.023900	 avg_loss=0.002530	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=6]	 epoch_loss=0.002530
batch_loss=0.015412	 avg_loss=0.002345	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=7]	 epoch_loss=0.002345
batch_loss=0.008640	 avg_loss=0.002381	 epoch_ETA:  0: 0: 0 secs (data=100/100)
[epoch=8]	 epoch_loss=0.002381
...

# 4. Loading a model for inference (production)

When saving or loading a model for production use, you only need the model's `state_dict`, not the optimizer state. After loading it, call `model.eval()` so that layers such as dropout and batch normalization run in evaluation mode.
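
A sketch of that weights-only workflow, assuming a stand-in `make_net()` built with `torch.nn.Sequential` to mirror the printed `BasicNeuralNet` layout (the real class lives in `helpers.py`). `io.BytesIO` replaces a file on disk, and `map_location="cpu"` lets weights saved on a GPU be loaded on a CPU-only machine.

```python
import io

import torch


def make_net():
    # same layer layout as the printed BasicNeuralNet (4 -> 6 -> 3)
    return torch.nn.Sequential(
        torch.nn.Linear(4, 6), torch.nn.Sigmoid(),
        torch.nn.Linear(6, 3), torch.nn.Softmax(dim=1),
    )


torch.manual_seed(0)
trained = make_net()

# save only the weights: no optimizer state, epoch or loss
buffer = io.BytesIO()  # in-memory stand-in for a weights file
torch.save(trained.state_dict(), buffer)

# later, e.g. in a serving process: rebuild the architecture, then load the weights
buffer.seek(0)
deployed = make_net()
deployed.load_state_dict(torch.load(buffer, map_location="cpu"))
deployed.eval()

x = torch.randn(3, 4)
with torch.no_grad():
    same = torch.equal(trained(x), deployed(x))
print(same)  # True
```

The architecture itself is not stored, only the tensors, so the serving code must be able to construct the same model class.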

In [12]:
model_4 = BasicNeuralNet(
    4,  # input has size 4 (attributes)
    3,  # output has size 3 (one-hot, 3 classes)
    6   # hidden layer (param)
)

# comment/uncomment below to use your own saved model
#checkpoint_4 = torch.load(model_file_path)

# or simply use the demo
checkpoint_4 = torch.load("models/step-04-demo-model-state-epoch500-loss0.238621.tar")

# this loads the state dict into the model only
model_4.load_state_dict(checkpoint_4['model_state_dict'])

# use .eval() when loading the model for inference (production)
model_4.eval()

BasicNeuralNet(
  (x_to_z): Linear(in_features=4, out_features=6, bias=True)
  (z_to_h): Sigmoid()
  (h_to_s): Linear(in_features=6, out_features=3, bias=True)
  (s_to_y): Softmax(dim=1)
)

We can now use the loaded model to evaluate its accuracy on the testing set.

In [13]:
# batch the testing data as well
iris_testing_loader = torch.utils.data.DataLoader(
    dataset=iris_testing_dataset,
    batch_size=10,
    shuffle=False  # order does not matter when only counting correct predictions
)

correct = 0
total = 0

with torch.no_grad():  # deactivate autograd during testing
    for inputs, targets in iris_testing_loader:  # iterate on (inputs, targets) batches
        
        # apply the NN
        outputs = model_4(inputs)                 # compute output class tensor
        predicted = torch.argmax(outputs, dim=1)  # get argmax of P(y_hat|x)
        actual = torch.argmax(targets, dim=1)     # get y

        # compute score
        total += targets.size(0)
        correct += (predicted == actual).sum().item()

print("Accuracy: {:.2f}%".format(100 * correct / total))

Accuracy: 98.00%
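
To apply a loaded model to a single new flower rather than a test batch, the input must be a float32 tensor with a leading batch dimension. A hypothetical `predict_class` sketch, run here on an untrained stand-in network (so the predicted index is arbitrary); with the notebook's objects you would pass `model_4` instead.

```python
import torch


def predict_class(model, features):
    """Return the predicted class index for one sample of 4 measurements."""
    x = torch.as_tensor(features, dtype=torch.float32).unsqueeze(0)  # shape (1, 4)
    model.eval()
    with torch.no_grad():
        probabilities = model(x)  # shape (1, 3); each row sums to 1 after Softmax
    return int(torch.argmax(probabilities, dim=1).item())


# untrained stand-in with the same layout as BasicNeuralNet
torch.manual_seed(0)
stand_in = torch.nn.Sequential(
    torch.nn.Linear(4, 6), torch.nn.Sigmoid(),
    torch.nn.Linear(6, 3), torch.nn.Softmax(dim=1),
)
sample = [5.1, 3.5, 1.4, 0.2]  # sepal length, sepal width, petal length, petal width (cm)
predicted = predict_class(stand_in, sample)
print(predicted)
```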
