### TorchScript and Torch.JIT': Optimizing Inference Time on Trained Models

**TorchScript takes in an eager-mode PyTorch module and compiles its code into TorchScript. This format can run on non-Python environments by itself (including C++). Torch.jit(just-in-time) compiler used for generating this code optimizes the model through layer-fusion, quantization and sparsification etc. It is not meant for training, rather, use it for converting your trained model into an optimized equivalent that performs faster for inference.** 

&nbsp;
Here is a good TowardsDataScience articles on this:

&nbsp; 
https://towardsdatascience.com/pytorch-jit-and-torchscript-c2a77bac0fff

In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torchsummary import summary
from torchvision import datasets, transforms
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
import time

class LeNet5(nn.Module):
    def __init__(self, num_classes, **kwargs):
        super(LeNet5, self).__init__(**kwargs)
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=(4,4))
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.conv3 = nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(in_features=480, out_features=84)
        self.fc2 = nn.Linear(in_features=84, out_features=10)


    def forward(self, X):
        h = F.max_pool2d(torch.tanh(self.conv1(X)), kernel_size =(2,2))
        h = F.max_pool2d(torch.tanh(self.conv2(h)), kernel_size =(2,2))
        h = self.flatten(torch.tanh(self.conv3(h)))
        h = torch.tanh(self.fc1(h))
        h = self.fc2(h)
        return h



In [2]:
#Data pipeline.

train_dataset = datasets.MNIST(
    root = 'data/MNIST',
    train=True,
    download=True,
    transform=transforms.ToTensor()

)

test_dataset = datasets.MNIST(
    root='mnist_data',
    train=False,
    download=True,
    transform=transforms.ToTensor()
)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle = True,
    drop_last=True,
    num_workers=8
)

test_dataloader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    drop_last=True,
    num_workers=8
)



Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting data/MNIST/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting data/MNIST/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting data/MNIST/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting data/MNIST/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to mnist_data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting mnist_data/MNIST/raw/train-images-idx3-ubyte.gz to mnist_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to mnist_data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting mnist_data/MNIST/raw/train-labels-idx1-ubyte.gz to mnist_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to mnist_data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting mnist_data/MNIST/raw/t10k-images-idx3-ubyte.gz to mnist_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to mnist_data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting mnist_data/MNIST/raw/t10k-labels-idx1-ubyte.gz to mnist_data/MNIST/raw



In [37]:
#Training loop.
def train(model, epochs):  
    for epoch in range(epochs):
        start = time.time()
        train_loss, tr_correct_preds = 0, 0
        val_loss, tst_correct_preds = 0, 0

        for (train_X, train_y) in train_dataloader:
            train_X, train_y = train_X.to(device), train_y.to(device)
            train_preds = lenet5(train_X)
            loss = loss_fn(train_preds, train_y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            tr_correct_preds += (train_preds.argmax(dim=1) == train_y).sum().item()

        for (val_X, val_y) in test_dataloader:
            with torch.no_grad():
                val_X, val_y = val_X.to(device), val_y.to(device)
                val_preds = lenet5(val_X)
                val_loss += loss_fn(val_preds, val_y).item()
                tst_correct_preds += (val_preds.argmax(dim=1) == val_y).sum().item()
        
        end = time.time()
        
        num_train_steps = len(train_dataloader.dataset) // batch_size
        num_val_steps = len(test_dataloader.dataset) // batch_size
        
        print('Epoch {}: \nTrain Loss:{}, Train acc: {}, Val loss: {}, Val acc:{},\n'
              'Correct Training Samples: {}, Correct Validation Samples: {} \nEpoch took: {} seconds.\n'
              .format(epoch + 1,
                      train_loss/num_train_steps,
                      tr_correct_preds/ (num_train_steps * batch_size),
                      val_loss/num_val_steps,
                      tst_correct_preds/ (num_val_steps * batch_size),
                      tr_correct_preds, tst_correct_preds,
                      (end - start)))

In [38]:
#Hyperparameters and other configuration.
loss_fn = nn.CrossEntropyLoss()
learning_rate = 1e-3
batch_size = 64
epochs = 5
optimizer = torch.optim.SGD(lenet5.parameters(), lr=learning_rate,
                            momentum=0.9, nesterov=True)

In [53]:
#Initialize the dynamic model. 
lenet5_dynamic = LeNet5(10).to(device)
train(lenet5_dynamic, 30)

Epoch 1: 
Train Loss:0.011836165772801168, Train acc: 0.9974319637139808, Val loss: 0.030782481208053724, Val acc:0.9896834935897436,
Correct Training Samples: 59814, Correct Validation Samples: 9881 
Epoch took: 4.035372495651245 seconds.

Epoch 2: 
Train Loss:0.011740395764460223, Train acc: 0.9975653681963714, Val loss: 0.030068295101675455, Val acc:0.9895833333333334,
Correct Training Samples: 59822, Correct Validation Samples: 9880 
Epoch took: 3.941103935241699 seconds.

Epoch 3: 
Train Loss:0.01139823569482085, Train acc: 0.9976654215581644, Val loss: 0.030282602170000946, Val acc:0.9893830128205128,
Correct Training Samples: 59828, Correct Validation Samples: 9878 
Epoch took: 4.0197389125823975 seconds.

Epoch 4: 
Train Loss:0.011241211394320681, Train acc: 0.9976987726787621, Val loss: 0.03065123837908546, Val acc:0.9894831730769231,
Correct Training Samples: 59830, Correct Validation Samples: 9879 
Epoch took: 3.904510498046875 seconds.

Epoch 5: 
Train Loss:0.01104152475728

### Once the model is trained, one can compile it with JIT and save it like this:

In [57]:
lenet5_optimized = torch.jit.script(lenet5_dynamic) # this is way 1, does not require dummy input.


'''
Another alternative is to run torch.jit.trace(model, dummy_input_tensor)
Trace does the same thing, but .script seems to be more robust to control flow and type handling. 
Also .script() does not require dummy input. 
If you use .trace() make sure to send the dummy_input tensor to the same device as the model.
'''

print(type(lenet5_optimized))
#save the optimized script into .pth file. 
lenet5_optimized.save('lenet5_compute_graph.pth')

<class 'torch.jit._script.RecursiveScriptModule'>


### Then one can load the model back on from .pth file like this:

In [70]:
lenet5_loaded = torch.jit.load('lenet5_compute_graph.pth')
lenet5_loaded = lenet5_loaded.to(device) # benchmark on GPU.
print(type(lenet5_loaded))

<class 'torch.jit._script.RecursiveScriptModule'>


### Benchmarking inference speed of the eager mode and static mode.

In [79]:
def inference_time_benchmark(model, runs, batch_size):
    input_batch = torch.rand(size=[batch_size, 1, 28, 28]).cuda() #dummy MNIST batch for benchmarking.
    total_time = 0
    start = time.time()
    for i in range(runs):
        _ = model(input_batch)
        total_time += time.time() - start
    
    return total_time / runs
    


In [83]:
dynamic_time = inference_time_benchmark(lenet5_dynamic, 10000, 64)
static_time = inference_time_benchmark(lenet5_loaded, 10000, 64)

print(static_time / dynamic_time)

0.9165674172903597


### Voila! 8.4% speedup over eager mode. Remember this when deploying a trained model.