# 4d. Inference

Now that the model has been saved, we want to load it back and use it for inference.

In [71]:
import random
import math
import collections
import math
from typing import Tuple, List

import torch
import torchvision
import matplotlib.pyplot as plt
import numpy as np

%matplotlib inline
import seaborn as sns
sns.set(style="darkgrid")

from workshop import data
import helper

import time

def get_model():
    """Build a freshly initialized instance of the notebook's MLP.

    Layers, in order: flatten the input, a 784 -> 256 linear layer, a sigmoid
    activation, and a 256 -> 10 linear output layer. The layer names become the
    prefixes of the state-dict keys, so they must stay identical for
    ``load_state_dict`` on saved parameters to succeed.

    Returns:
        torch.nn.Sequential whose named children are "reshape", "hidden",
        "sigmoid" and "output".
    """
    n_inputs = 28 * 28
    n_hidden = 256
    n_outputs = 10
    named_layers = [
        ("reshape", torch.nn.Flatten()),
        ("hidden", torch.nn.Linear(n_inputs, n_hidden)),
        ("sigmoid", torch.nn.Sigmoid()),
        ("output", torch.nn.Linear(n_hidden, n_outputs)),
    ]
    return torch.nn.Sequential(collections.OrderedDict(named_layers))


def time_it(f, n, *args):
    """Return the average duration of ``f(*args)`` in milliseconds.

    ``f`` is called once as an untimed warm-up, then ``n`` more times inside
    the timed region; only the timed calls contribute to the average.

    Args:
        f: Callable to benchmark.
        n: Number of timed calls; must be a positive integer.
        *args: Positional arguments forwarded unchanged to every call of ``f``.

    Returns:
        Mean duration of one call of ``f``, in milliseconds.

    Raises:
        ValueError: If ``n`` is not positive (the average would be undefined).
    """
    if n <= 0:
        raise ValueError(f"n must be a positive number of runs, got {n}")

    # Warm-up: the first call can pay one-off costs (lazy initialization,
    # caches, JIT specialization) that should not skew the steady-state average.
    f(*args)

    # perf_counter is monotonic and high-resolution, unlike time.time(), which
    # follows the system clock and can jump if it is adjusted mid-measurement.
    start = time.perf_counter()
    for _ in range(n):
        f(*args)
    elapsed = time.perf_counter() - start
    return 1000 * elapsed / n

# Loading

Load the whole model (architecture, parameters):

In [12]:
# The full-model checkpoint is a pickled nn.Module, so it needs weights_only=False
# (PyTorch 2.6 changed the default to True, which rejects pickled modules).
# Unpickling can execute arbitrary code: only load checkpoints you trust.
# map_location="cpu" lets a checkpoint saved from a GPU load on a CPU-only machine.
model = torch.load("../data/model.pt", map_location="cpu", weights_only=False)
model.eval()  # important because operations like dropout behave differently on inference

Sequential(
  (reshape): Flatten()
  (hidden): Linear(in_features=784, out_features=256, bias=True)
  (sigmoid): Sigmoid()
  (output): Linear(in_features=256, out_features=10, bias=True)
)

Load only the model's parameters (the architecture has to be rebuilt in code first, here with `get_model()`):

In [15]:
model = get_model()  # the architecture is not stored in the parameter file, so rebuild it
# A state dict holds only tensors, so the restricted (safer) weights_only loader suffices.
# map_location="cpu" lets parameters saved from a GPU load on a CPU-only machine.
state_dict = torch.load("../data/model_params.pt", map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
model.eval()

Sequential(
  (reshape): Flatten()
  (hidden): Linear(in_features=784, out_features=256, bias=True)
  (sigmoid): Sigmoid()
  (output): Linear(in_features=256, out_features=10, bias=True)
)

# Optimizing for inference

PyTorch builds the computation graph dynamically, on every forward pass. For most models, however, an optimized static graph can be extracted and saved.

PyTorch models can be converted to TorchScript, which makes them more portable (e.g. the same model can be loaded from a C++ program). TorchScript models are executed by TorchScript's own optimized runtime rather than by the Python interpreter, which can make inference faster.

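Tracing invokes the model with example data and records the operations that are executed to build an optimized graph. Because only the operations run for that example are recorded, data-dependent control flow (e.g. an `if` on tensor values) is frozen to the path taken during tracing.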

For more information, see [TorchScript](https://pytorch.org/docs/stable/jit.html).

In [72]:
N_RUNS = 200  # timed calls per model; time_it adds one untimed warm-up call

# A batch of one random input with 784 features (the hidden layer expects 28*28).
example_data = torch.rand(1, 784)

# Inference needs no gradients: disabling autograd skips graph bookkeeping, so
# both models are timed on the work they would actually do when serving.
with torch.no_grad():
    traced_model = torch.jit.trace(model, (example_data,))

    time_normal = time_it(model, N_RUNS, example_data)
    print(f"Avg. inference time for normal model: {time_normal:.4f} ms")

    time_traced = time_it(traced_model, N_RUNS, example_data)
    print(f"Avg. inference time for traced model: {time_traced:.4f} ms")

Avg. inference time for normal model: 0.0820
Avg. inference time for traced model: 0.0676


# Exercise 1 (Optional):

If you have a CUDA-enabled graphics card, move the data and the model to the GPU by calling *.cuda()* and compare the inference times.

In [70]:
# TODO: Exercise 1
# Hint: move both the model and the input with `.cuda()` (guard with
# `torch.cuda.is_available()`), then re-run `time_it` on each model.
# NOTE: CUDA kernels launch asynchronously, so without a
# `torch.cuda.synchronize()` before reading the clock the measured times may be
# too small.