# Torch jit trace usage

Author: Diluka H.

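`torch.jit.trace` can be used to save a trained model to a file (`NN_Model_Checkpoint/model.pth` below) and later load it and retrain it without the original model code. This works as long as the model's `forward` has no data-dependent control flow (for example `if` statements or loops whose outcome depends on the input values). Tracing records only the operations executed for the example input, so any such branch would be frozen into the saved model.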

1. Training the model.
2. Save the model.
3. Load and retrain the saved model. This section is standalone: it can be run by itself, independently of the other sections, even after restarting the Jupyter kernel.



# Training a model

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import numpy as np
from torch.utils.data import DataLoader, TensorDataset

class Net(nn.Module):
    """LeNet-style CNN: two conv + max-pool stages followed by three linear layers.

    Args:
        input_size: side length of the square, single-channel input images.
            Defaults to 128, which gives the original 13456 inputs to ``fc1``.
            ``input_size=32`` reproduces the classic LeNet value of 16 * 5 * 5.

    Raises:
        ValueError: if ``input_size`` is too small (< 16) to survive both
            conv/pool stages.

    The forward pass returns 8 outputs per sample.
    """

    def __init__(self, input_size=128):
        super().__init__()
        # 1 input image channel, 6 output channels, 5x5 square convolution kernel
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # Spatial side after each stage: a 5x5 conv with no padding shrinks the
        # side by 4, and a 2x2 max pool halves it (rounding down).
        # For 128: 128 -> 124 (conv1) -> 62 (pool) -> 58 (conv2) -> 29 (pool),
        # so fc1 receives 16 * 29 * 29 = 13456 features.
        # How to calculate inputs to the linear layer:
        # https://datascience.stackexchange.com/questions/40906/determining-size-of-fc-layer-after-conv-layer-in-pytorch
        side = ((input_size - 4) // 2 - 4) // 2
        if side < 1:
            raise ValueError(
                f"input_size={input_size} is too small; need at least 16"
            )
        # an affine operation: y = Wx + b
        self.fc1 = nn.Linear(16 * side * side, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 8)

    def forward(self, x):
        """Map a (batch, 1, input_size, input_size) tensor to (batch, 8)."""
        # conv1 + ReLU, then max pooling over a (2, 2) window
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        # If the pooling window is square, a single number is enough
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = torch.flatten(x, 1)  # flatten all dimensions except the batch dimension
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Train a fresh Net on the .npy dataset and record the per-epoch loss.
epochs = 50
batchSize = 10

# Load dataset. unsqueeze(1) inserts the channel axis Net's conv1 expects.
# NOTE(review): presumably data.npy holds (N, 128, 128) images and currents.npy
# holds 8 target values per sample (flattened outputs must match) — confirm.
tensor_x = torch.tensor(np.load('data.npy')).unsqueeze(1)
tensor_y = torch.tensor(np.load('currents.npy'))
dataset = TensorDataset(tensor_x, tensor_y)  # create your dataset
dataLoader = DataLoader(dataset, batch_size=batchSize, shuffle=True)

model = Net()
print(model)
model.float()

# create your optimizer
optimizer = optim.SGD(model.parameters(), lr=0.1)
criterion = nn.SmoothL1Loss()

model.train()
lossHistory = []
for epoch in range(epochs):
    # Sum of the per-batch mean losses for this epoch.
    lossTotal = 0.0
    # The final batch's `x` is reused by the save cell as the trace example input.
    for x, y in dataLoader:
        model.zero_grad()
        yhat = model(x.float())
        loss = criterion(yhat.view(-1), y.float().view(-1))
        loss.backward()
        optimizer.step()
        # Accumulate a plain float: adding the loss tensor itself would chain
        # every batch's autograd history together for the whole epoch.
        lossTotal += loss.item()
    lossHistory.append(lossTotal)
    print("Epoch: ", epoch, "Loss: ", lossTotal)

# To plot the loss curve (requires matplotlib.pyplot imported as plt):
# plt.plot(lossHistory)
# plt.title('Loss')
# plt.xlabel('Epoch')
# plt.ylabel('Loss')

Net(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=13456, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=8, bias=True)
)
Epoch:  0 Loss:  2.3040506839752197
Epoch:  1 Loss:  1.9797697067260742
Epoch:  2 Loss:  1.6775538921356201
Epoch:  3 Loss:  1.5693578720092773
Epoch:  4 Loss:  1.4593698978424072
Epoch:  5 Loss:  1.3346567153930664
Epoch:  6 Loss:  1.3202340602874756
Epoch:  7 Loss:  1.321162223815918
Epoch:  8 Loss:  1.2872483730316162
Epoch:  9 Loss:  1.2617595195770264
Epoch:  10 Loss:  1.190783143043518
Epoch:  11 Loss:  1.2801188230514526
Epoch:  12 Loss:  1.2691946029663086
Epoch:  13 Loss:  1.2173444032669067
Epoch:  14 Loss:  1.204860806465149
Epoch:  15 Loss:  1.284468412399292
Epoch:  16 Loss:  1.2324779033660889
Epoch:  17 Loss:  1.1393040418624878
Epoch:  18 Loss:  1.1703532934188843
E

# Saving a model

In [2]:
# Save model: trace `model` with an example batch and write a TorchScript archive.
# Depends on `model` and `x` (the last training batch) from the training cell,
# so run that cell first.
from pathlib import Path

saveModelPath = "NN_Model_Checkpoint/model.pth"
# torch.jit.save fails if the target directory does not exist yet.
Path(saveModelPath).parent.mkdir(parents=True, exist_ok=True)
with torch.no_grad():
    print(model(x.float()))
    # A single tensor is accepted directly as the trace's example input.
    traced_cell = torch.jit.trace(model, x.float())
torch.jit.save(traced_cell, saveModelPath)

tensor([[-0.0372,  0.3695,  0.9205,  0.3139,  1.1918,  0.4809,  0.2466, -0.0688],
        [ 0.2344,  1.1551,  0.2446,  0.0195,  0.2172,  0.3406,  1.3210,  0.3527]])


# Load and train a model

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import numpy as np
from torch.utils.data import DataLoader, TensorDataset

# Load the traced TorchScript model; the Net class definition is not needed.
savedModelPath = "NN_Model_Checkpoint/model.pth"
jitModel = torch.jit.load(savedModelPath)

# Continue training the loaded model.
epochs = 50
batchSize = 10

# Load dataset. unsqueeze(1) inserts the channel axis the model's first conv expects.
# NOTE(review): presumably the same data.npy / currents.npy layout used for the
# original training run — confirm.
tensor_x = torch.tensor(np.load('data.npy')).unsqueeze(1)
tensor_y = torch.tensor(np.load('currents.npy'))

dataset = TensorDataset(tensor_x, tensor_y)  # create your dataset
dataLoader = DataLoader(dataset, batch_size=batchSize, shuffle=True)

model = jitModel
print(model)
model.float()

# create your optimizer
optimizer = optim.SGD(model.parameters(), lr=0.1)
criterion = nn.SmoothL1Loss()

model.train()
lossHistory = []
for epoch in range(epochs):
    # Sum of the per-batch mean losses for this epoch.
    lossTotal = 0.0
    for x, y in dataLoader:
        model.zero_grad()
        yhat = model(x.float())
        loss = criterion(yhat.view(-1), y.float().view(-1))
        loss.backward()
        optimizer.step()
        # Accumulate a plain float: adding the loss tensor itself would chain
        # every batch's autograd history together for the whole epoch.
        lossTotal += loss.item()
    lossHistory.append(lossTotal)
    print("Epoch: ", epoch, "Loss: ", lossTotal)

RecursiveScriptModule(
  original_name=Net
  (conv1): RecursiveScriptModule(original_name=Conv2d)
  (conv2): RecursiveScriptModule(original_name=Conv2d)
  (fc1): RecursiveScriptModule(original_name=Linear)
  (fc2): RecursiveScriptModule(original_name=Linear)
  (fc3): RecursiveScriptModule(original_name=Linear)
)
Epoch:  0 Loss:  0.6976651549339294
Epoch:  1 Loss:  0.6960465312004089
Epoch:  2 Loss:  0.6697688102722168
Epoch:  3 Loss:  0.6618806719779968
Epoch:  4 Loss:  0.5896336436271667
Epoch:  5 Loss:  0.5941463112831116
Epoch:  6 Loss:  0.5925495624542236
Epoch:  7 Loss:  0.5897407531738281
Epoch:  8 Loss:  0.545324444770813
Epoch:  9 Loss:  0.6062442064285278
Epoch:  10 Loss:  0.5815252661705017
Epoch:  11 Loss:  0.5069475173950195
Epoch:  12 Loss:  0.5409581065177917
Epoch:  13 Loss:  0.57889723777771
Epoch:  14 Loss:  0.4760935306549072
Epoch:  15 Loss:  0.44365933537483215
Epoch:  16 Loss:  0.38971441984176636
Epoch:  17 Loss:  0.39124596118927
Epoch:  18 Loss:  0.4424299895763