In [None]:
# default_exp jit

# Jit
> Jit support for `fastai` models

Currently only Vision and Tabular models are supported

In [None]:
#hide
from nbdev.showdoc import *

In [None]:
#export
import torch
from fastcore.all import *
from fastai.learner import *
from fastai.torch_core import TensorBase

In [None]:
#exporti
@patch
def requires_grad_(self:TensorBase, requires_grad=True):
    "Set `self.requires_grad` to `requires_grad` by plain attribute assignment and return `self`"
    # Workaround https://github.com/pytorch/pytorch/issues/50219
    # Patched onto `TensorBase`, so `requires_grad_` calls on those tensors run this assignment instead.
    self.requires_grad = requires_grad
    return self

In [None]:
#slow
from fastai.vision.all import *

In [None]:
#slow
# Seed the run before building the data so the split below is repeatable
set_seed(99, True)
# Download the PETS dataset (if needed) and point at its `images` folder
path = untar_data(URLs.PETS)/'images'
# Label each item by whether the first character passed to `label_func` is uppercase,
# hold out 20% for validation and resize every item to 224
dls = ImageDataLoaders.from_name_func(
    path, get_image_files(path), valid_pct=0.2,
    label_func=lambda x: x[0].isupper(), item_tfms=Resize(224))

We'll train a quick model to export:

In [None]:
#slow
# Build a `resnet34` learner tracking `error_rate`, switched to fp16 via `to_fp16`
learn = cnn_learner(dls, resnet34, metrics=error_rate).to_fp16()
# Short fine-tuning run; the two result tables below come from this single call
learn.fine_tune(1)

Downloading: "https://download.pytorch.org/models/resnet34-333f7ec4.pth" to /root/.cache/torch/hub/checkpoints/resnet34-333f7ec4.pth


HBox(children=(FloatProgress(value=0.0, max=87306240.0), HTML(value='')))




epoch,train_loss,valid_loss,error_rate,time
0,0.161436,0.026999,0.008119,00:46


epoch,train_loss,valid_loss,error_rate,time
0,0.059667,0.012131,0.005413,00:49


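There are two possible scenarios with `jit`: `trace` and `script`. `torch.jit.trace` records the operations executed on an example batch, so it works with most models but does not capture data-dependent control flow; `torch.jit.script` compiles the model's code and preserves control flow, but requires the whole model to be TorchScript-compatible. As a result `trace` is used by default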

`Learner.to_jit()` will use `trace` unless a specific `mode` is specified:

For the models shown below `torch.jit.trace` should be used: it works for convolutional models (such as CNNs) as well as for models that are not convolutional in nature (such as tabular models)

In [None]:
#export
# `JitMode.Trace == 'trace'` and `JitMode.Script == 'script'`
mk_class('JitMode', Trace='trace', Script='script',
         doc="All possible export modes as attributes to get tab-completion and typo-proofing")

In [None]:
#export
@patch
def to_jit(self:Learner, fname='export.pt', mode=JitMode.Trace):
    "Exports `learn.model` using `jit` with `mode` to `fname`"
    # Fail early with a clear message instead of looking up an arbitrary `torch.jit` attribute
    if mode not in (JitMode.Trace, JitMode.Script):
        raise ValueError(f"`mode` must be '{JitMode.Trace}' or '{JitMode.Script}', got {mode!r}")
    # Example inputs: the first `n_inp` elements of one batch (the model inputs, not the targets)
    inp = self.dls.one_batch()[:self.dls.n_inp]
    if not isinstance(inp, tuple): inp = (inp,)
    # Export in eval mode, with the model on the same device as the example inputs
    self.model.eval()
    self.model.to(inp[0].device)
    if mode == JitMode.Script:
        # `torch.jit.script` compiles the module itself; its second positional parameter is the
        # deprecated `optimize` flag, so the example batch must not be passed to it
        jit_model = torch.jit.script(self.model)
    else:
        # `torch.jit.trace` records the operations run on the example inputs
        jit_model = torch.jit.trace(self.model, inp)
    torch.jit.save(jit_model, fname)

In [None]:
# Render the generated documentation entry for `Learner.to_jit`
show_doc(Learner.to_jit, title_level=3)

<h4 id="Learner.to_jit" class="doc_header"><code>Learner.to_jit</code><a href="__main__.py#L2" class="source_link" style="float:right">[source]</a></h4>

> <code>Learner.to_jit</code>(**`fname`**=*`'export.pt'`*, **`mode`**=*`'trace'`*)

Exports `learn.model` using `jit` with `mode` to `fname`

Below you will find a number of examples using `Learner.to_jit` and loading them back in

### Tabular (Multi-Input)

In [None]:
from fastai.tabular.all import *
# Download the ADULT_SAMPLE dataset (if needed)
path = untar_data(URLs.ADULT_SAMPLE)

# Build tabular DataLoaders from `adult.csv`: `salary` is the target, with six categorical
# and three continuous columns, preprocessed by Categorify, FillMissing and Normalize
dls = TabularDataLoaders.from_csv(path/'adult.csv', path=path, y_names="salary",
    cat_names = ['workclass', 'education', 'marital-status', 'occupation',
                 'relationship', 'race'],
    cont_names = ['age', 'fnlwgt', 'education-num'],
    procs = [Categorify, FillMissing, Normalize])

# Untrained tabular learner; it is only used to demonstrate the export
learn = tabular_learner(dls, metrics=accuracy)

In [None]:
#hide
#slow
# Round-trip check: the traced model must reproduce the eager model's outputs on the same batch
with tempfile.TemporaryDirectory() as tmpdir:
    cat,cont,_ = dls.one_batch()
    # Reference outputs from the eager model, in eval mode and on the batch's device
    learn.model.eval()
    learn.model.to(cat.device)
    with torch.no_grad(): probs = learn.model(cat,cont)
    # Export into the temporary directory, then load the file back onto the same device
    learn.to_jit(f'{tmpdir}/trace.pt', 'trace')
    trace = torch.jit.load(f'{tmpdir}/trace.pt', map_location=cat.device)
    trace.eval()
    probs_jit = trace(cat,cont)
    test_close(probs_jit, probs)

Tabular models can only be exported with `torch.jit.trace`, so we'll use that:

In [None]:
# Writes the traced tabular model to `trace.pt` in the current working directory
learn.to_jit('trace.pt', mode=JitMode.Trace)

Now we can load it back in using raw torch and pass in a batch of data:

In [None]:
loaded_model = torch.jit.load("trace.pt")
# Put the loaded module in evaluation mode before inference, as the vision example below does
loaded_model.eval()
# A tabular batch is (categorical inputs, continuous inputs, targets); the targets are not needed here
cat,cont,_ = dls.one_batch()

And perform inference:

In [None]:
# Run the loaded TorchScript model on the batch and show the first three rows of its output
probs = loaded_model(cat,cont)
probs[:3]

tensor([[-0.0414,  0.0794],
        [-0.0426,  0.1249],
        [-0.0102,  0.1299]], device='cuda:0', grad_fn=<SliceBackward>)

> Note: As these are just the models, the raw model outputs (not probabilities) are returned. You still need to apply a softmax or argmax

### Vision

Below is an example using `ResNet`:

In [None]:
from fastai.vision.all import *
# Download the PETS dataset (if needed) and point at its `images` folder
path = untar_data(URLs.PETS)/'images'
# Label each item by whether the first character passed to `label_func` is uppercase,
# hold out 20% for validation and resize every item to 224
dls = ImageDataLoaders.from_name_func(
    path, get_image_files(path), valid_pct=0.2,
    label_func=lambda x: x[0].isupper(), item_tfms=Resize(224))
# `resnet18` learner; it is not fine-tuned in this cell and only demonstrates the export
learn = cnn_learner(dls, resnet18)

In [None]:
#hide
#slow
# Round-trip check: the traced model must reproduce the eager model's outputs on the same batch
with tempfile.TemporaryDirectory() as tmpdir:
    x,_ = dls.one_batch()
    # Reference outputs from the eager model, in eval mode and on the batch's device
    learn.model.eval()
    learn.model.to(x.device)
    with torch.no_grad(): probs = learn.model(x)
    # Export into the temporary directory, then load the file back onto the same device
    learn.to_jit(f'{tmpdir}/trace.pt', 'trace')
    trace = torch.jit.load(f'{tmpdir}/trace.pt', map_location=x.device)
    trace.eval()
    probs_trace = trace(x)
    test_close(probs, probs_trace)

Since `ResNet` is a vision model, `trace` should be used:

In [None]:
# Writes the traced vision model to `trace.pt` (the same file name the tabular example used)
learn.to_jit('trace.pt', mode=JitMode.Trace)

Just as before we can now load it in and perform inference:

In [None]:
# Load the exported TorchScript file and put it in evaluation mode
loaded_model = torch.jit.load("trace.pt")
loaded_model.eval()
# Only the images of the batch are needed for inference
x,_ = dls.one_batch()

# Run the loaded model and show the first three rows of its output
probs = loaded_model(x)
probs[:3]

tensor([[-1.1999, -2.8738],
        [ 5.3266,  1.8526],
        [ 0.1073, -0.3077]], device='cuda:0', grad_fn=<SliceBackward>)