# PyTorch saving/loading demo

Say that you build and train a model using PyTorch for your project.  The training may take a long time.  The TA will want to run your model *without* retraining it.  This is possible if the model is properly saved.  We will now demonstrate how to save and load a PyTorch model.

In [1]:
import numpy as np
import torch
import torch.nn as nn

First, we will create a very simple PyTorch model.

In [2]:
d_in = 30   # number of input features
d_out = 20  # number of output features

class DumbNet(nn.Module):
    """A single fully connected layer: out = x @ W.T + b."""
    def __init__(self):
        super().__init__()
        self.dense = nn.Linear(d_in, d_out)

    def forward(self, x):
        return self.dense(x)

model = DumbNet()

Normally we would train the model at this point. Since this is only a saving demo, we skip training and use the random weights that the model was initialized with.
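The "random weights" here are the weight matrix and bias vector of the single `nn.Linear` layer. The sketch below builds a standalone layer of the same shape (the name `layer` is illustrative, not part of the notebook) and counts its parameters: $30 \times 20$ weights plus $20$ biases gives $620$ numbers, which is everything the saved file has to store for this model.

```python
import torch
import torch.nn as nn

d_in, d_out = 30, 20
layer = nn.Linear(d_in, d_out)  # same shape as the layer inside DumbNet

# nn.Linear stores its weight as (out_features, in_features) and its bias as (out_features,)
print(layer.weight.shape)  # torch.Size([20, 30])
print(layer.bias.shape)    # torch.Size([20])

# Total number of trainable values: 30*20 + 20
n_params = sum(p.numel() for p in layer.parameters())
print(n_params)  # 620
```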

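One way to save a PyTorch model is to convert it to TorchScript by *tracing*: we pass an example input through the model, and `torch.jit.trace` records the operations it executes. The saved trace can later be loaded and run without the Python definition of `DumbNet`. The values in the example input do not matter, but its shape must be one the model accepts, so we create a random input with the proper dimension.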
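To see that the values of the example input do not matter for a model like this one, the sketch below traces the same layer twice with two different random inputs and checks that both traces give identical results on a third input. It uses a plain `nn.Linear` named `demo_layer` as a stand-in for `DumbNet` (an assumption for illustration) and prints the recorded TorchScript code. One caveat: tracing only records the path taken for the example input, so a model whose `forward` branches on tensor *values* would have that branch frozen in the trace; `torch.jit.script` is the usual alternative for such models.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
demo_layer = nn.Linear(30, 20).eval()  # stand-in for DumbNet: same layer shape

with torch.no_grad():
    # Trace the same layer twice, each time with a different random example input.
    traced_a = torch.jit.trace(demo_layer, (torch.randn(1, 30),))
    traced_b = torch.jit.trace(demo_layer, (torch.randn(1, 30),))

    # Evaluate both traces on a third, unrelated input of the same shape.
    z = torch.randn(1, 30)
    same = torch.equal(traced_a(z), traced_b(z))

print(same)           # expected: True
print(traced_a.code)  # the recorded forward() as TorchScript source
```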

In [3]:
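x = torch.randn(d_in)  # random example input of shape (d_in,)
x = x[None, :]         # add a singleton batch dimension -> shape (1, d_in)

model.eval()  # put layers such as dropout/batch norm in inference mode before tracing
with torch.no_grad():
    traced_cell = torch.jit.trace(model, (x,))  # example inputs are passed as a tuple
torch.jit.save(traced_cell, "./saved_model.pth")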
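Tracing is not the only way to save a model. For comparison, PyTorch can also save just the learned tensors through `state_dict()`; the catch is that whoever loads the file must have the `DumbNet` class definition available to rebuild the model first. The sketch below assumes a temporary directory and a hypothetical file name, and uses fresh variable names (`net`, `net_copy`, `x_check`) so it does not overwrite the notebook's `model` or `x`.

```python
import os
import tempfile

import torch
import torch.nn as nn

d_in, d_out = 30, 20

class DumbNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.dense = nn.Linear(d_in, d_out)

    def forward(self, x):
        return self.dense(x)

net = DumbNet()

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "dumbnet_state.pth")  # hypothetical file name
    # A state_dict maps parameter names ("dense.weight", "dense.bias") to tensors.
    torch.save(net.state_dict(), path)

    # Loading needs the class: build a fresh instance, then copy the saved tensors in.
    net_copy = DumbNet()
    net_copy.load_state_dict(torch.load(path))
    net_copy.eval()

x_check = torch.randn(1, d_in)
with torch.no_grad():
    match = torch.equal(net(x_check), net_copy(x_check))
print(match)  # expected: True
```

The traced file is the more convenient choice for a grader who only needs to run the model; a `state_dict` is the lighter option when the model code travels with the weights.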

Let's check whether the traced model was saved correctly. First, we pass the input `x` through the original model:

In [4]:
# original model:
with torch.no_grad():
    out = model(x)
print(out)

tensor([[ 1.1029,  0.2097, -0.9790,  0.7794,  0.0414, -0.0873, -0.3453,  0.4293,
          0.1349, -0.7586, -0.2868,  0.5888,  0.3675, -0.0620, -0.1760, -0.7405,
          0.3168,  0.1510, -0.5165,  0.2528]])


Next, we pass the same input through the reloaded model.

In [5]:
# reloaded model:
model2 = torch.jit.load("./saved_model.pth")
with torch.no_grad():
    out2 = model2(x)
print(out2)

tensor([[ 1.1029,  0.2097, -0.9790,  0.7794,  0.0414, -0.0873, -0.3453,  0.4293,
          0.1349, -0.7586, -0.2868,  0.5888,  0.3675, -0.0620, -0.1760, -0.7405,
          0.3168,  0.1510, -0.5165,  0.2528]])
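Note that `torch.jit.load` needed only the file, not the `DumbNet` class. If a model is traced on a GPU and the grader's machine has none, `torch.jit.load` also accepts a `map_location` argument that places the saved tensors on another device. The sketch below round-trips a stand-in `nn.Linear` (named `layer_demo`, an illustrative assumption) through a temporary file and loads it onto the CPU.

```python
import os
import tempfile

import torch
import torch.nn as nn

layer_demo = nn.Linear(30, 20).eval()  # stand-in for DumbNet
x_demo = torch.randn(1, 30)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "saved_model.pth")
    with torch.no_grad():
        torch.jit.save(torch.jit.trace(layer_demo, (x_demo,)), path)

    # map_location="cpu" puts any tensors that were saved on a GPU onto the CPU instead.
    loaded_demo = torch.jit.load(path, map_location="cpu")

with torch.no_grad():
    match = torch.equal(layer_demo(x_demo), loaded_demo(x_demo))
print(match)  # expected: True
```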


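The two outputs are identical, so the reloaded model reproduces the original. The printout shows only four decimal places; `torch.equal(out, out2)` or `torch.allclose(out, out2)` compares every element at full precision.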
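Comparing a single printed output is a weak test. A slightly stronger check, sketched below as a hypothetical helper `outputs_match` (not a PyTorch function), feeds several random inputs through both models and compares them at full precision with `torch.allclose`. The inputs default to batch size 1 to match the shape used for tracing; `ref_layer` and `ref_traced` are illustrative stand-ins.

```python
import torch
import torch.nn as nn

def outputs_match(original, reloaded, in_features, n_trials=5, batch_size=1):
    """Hypothetical helper: compare two models on several random inputs."""
    original.eval()
    reloaded.eval()
    with torch.no_grad():
        for _ in range(n_trials):
            x = torch.randn(batch_size, in_features)
            if not torch.allclose(original(x), reloaded(x)):
                return False
    return True

# Usage with a stand-in model and its trace.
ref_layer = nn.Linear(30, 20).eval()
with torch.no_grad():
    ref_traced = torch.jit.trace(ref_layer, (torch.randn(1, 30),))
print(outputs_match(ref_layer, ref_traced, in_features=30))  # expected: True
```

In the notebook itself, the same call would be `outputs_match(model, model2, in_features=d_in)`.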