In [None]:
# default_exp jit

# Torchscript
> `torch.jit` (TorchScript) support for `fastai` models

Currently only Vision and Tabular models are supported. More details on this technology can be found in the official [PyTorch documentation](https://pytorch.org/docs/stable/jit.html).

Motivations: 
- What is TorchScript? What is a serialized model?
- Why would one want to export a model to another format?
- Is it faster? 

In [None]:
#hide
from nbdev.showdoc import *

In [None]:
#export
import torch
from fastcore.all import *
from fastai.basics import *
from fastai.learner import *

In [None]:
#exporti
@patch
def requires_grad_(self:TensorBase, requires_grad=True):
    "Set `requires_grad` on this tensor by plain attribute assignment and return `self`"
    # Workaround https://github.com/pytorch/pytorch/issues/50219
    self.requires_grad = requires_grad
    return self

There are two possible scenarios with `jit`: `trace` and `script`. 
- `torch.jit.trace` should be used when the module has no control flow or dynamic behaviour (e.g. modules built with `nn.Sequential`).
- `torch.jit.script` should be used when there is dynamic behaviour (e.g. RNNs, language models).

As a result, `trace` is the default mode. You can also have a [combination of both](https://pytorch.org/docs/stable/jit.html#mixing-tracing-and-scripting).

>From the PyTorch docs: In many cases either tracing or scripting is an easier approach for converting a model to TorchScript. Tracing and scripting can be composed to suit the particular requirements of a part of a model.
Scripted functions can call traced functions. This is particularly useful when you need to use control-flow around a simple feed-forward model. For instance the beam search of a sequence to sequence model will typically be written in script but can call an encoder module generated using tracing.

Most vision encoders are traceable.

In [None]:
#export
# Build the `JitMode` namespace class: one attribute per export mode, each
# holding the lowercase name that `to_jit` looks up on `torch.jit`.
mk_class('JitMode', Trace='trace', Script='script',
         doc="All possible export modes as attributes to get tab-completion and typo-proofing")

In [None]:
#export
# `JitMode` is created by a `mk_class` call rather than a `class` statement;
# presumably it is listed here so nbdev adds it to the module's `__all__` — TODO confirm.
_all_ = ['JitMode']

In [None]:
show_doc(JitMode, title_level=3)

<h3 id="JitMode" class="doc_header"><code>class</code> <code>JitMode</code><a href="" class="source_link" style="float:right">[source]</a></h3>

> <code>JitMode</code>(**\*`args`**, **\*\*`kwargs`**)

All possible export modes as attributes to get tab-completion and typo-proofing

Another important thing to consider is the serving device. If you plan on doing inference on CPU, you should first move your model to CPU and then trace it ([ref](https://pytorch.org/docs/stable/jit.html#frequently-asked-questions)).

In [None]:
#export
@patch
def to_jit(self:Learner, fname='export.ts', mode=JitMode.Trace, device='cpu'):
    "Exports `learn.model` using `jit` with `mode` to `fname`"
    inp = self.dls.one_batch()[:self.dls.n_inp]
    if not isinstance(inp, tuple): inp = (inp,)
    self.model.eval()
    self.model.to(device)
    inp = to_device(inp, device)
    traced_model = getattr(torch.jit, mode)(self.model, inp)
    torch.jit.save(traced_model, learn.path/fname)

In [None]:
show_doc(Learner.to_jit, title_level=3)

<h3 id="Learner.to_jit" class="doc_header"><code>Learner.to_jit</code><a href="__main__.py#L2" class="source_link" style="float:right">[source]</a></h3>

> <code>Learner.to_jit</code>(**`fname`**=*`'export.ts'`*, **`mode`**=*`'trace'`*, **`device`**=*`'cpu'`*)

Exports `learn.model` using [`jit`](/fastexport/jit.html) with `mode` to `fname`

Below you will find a number of examples using `Learner.to_jit` and loading them back in

### Tabular (Multi-Input)

In [None]:
from fastai.tabular.all import *
# Adult sample dataset location
path = untar_data(URLs.ADULT_SAMPLE)

# Tabular DataLoaders: `salary` is the target, with categorical and continuous inputs
dls = TabularDataLoaders.from_csv(
    path/'adult.csv',
    path=path,
    y_names="salary",
    cat_names=['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race'],
    cont_names=['age', 'fnlwgt', 'education-num'],
    procs=[Categorify, FillMissing, Normalize],
)

learn = tabular_learner(dls, metrics=accuracy)

In [None]:
#hide
#slow
# Check that the traced export reproduces the original model's outputs on one batch
with tempfile.TemporaryDirectory() as tmpdir:
    ts_file = f'{tmpdir}/trace.pt'
    cat,cont,_ = dls.one_batch()
    # Reference outputs from the original model, on the batch's device
    learn.model.eval()
    learn.model.to(cat.device)
    with torch.no_grad(): probs = learn.model(cat,cont)
    # Export with tracing, load it back onto the same device and compare
    learn.to_jit(ts_file, 'trace')
    trace = torch.jit.load(ts_file, map_location=cat.device)
    trace.eval()
    probs_jit = trace(cat,cont)
    test_close(probs_jit, probs)

Tabular models can only be exported with `torch.jit.trace`, so we'll use that:

In [None]:
learn.to_jit('models/tab.pt', mode=JitMode.Trace, device='cpu')

Now we can load it back in using raw torch and pass in a batch of data:

In [None]:
# `to_jit` writes the file under `learn.path`, so load it from that same location
# rather than from a path relative to the current working directory.
loaded_model = torch.jit.load(str(learn.path/'models/tab.pt'))
# Put the loaded module in eval mode, as the vision example does
loaded_model.eval()
cat,cont,_ = dls.cpu().one_batch()

And perform inference:

In [None]:
probs = loaded_model(cat,cont); probs[:3]

tensor([[0.0411, 0.0396],
        [0.0045, 0.0858],
        [0.0303, 0.0698]], grad_fn=<SliceBackward>)

> As these are just the models, raw outputs are returned rather than probabilities or class predictions. You still need to apply a softmax or argmax.

### Vision

Below is an example using `ResNet`:

In [None]:
from fastai.vision.all import *
path = untar_data(URLs.PETS)/'images'
def label_func(x):
    "`True` when the first character of `x` is uppercase"
    first_char = x[0]
    return first_char.isupper()
# Gather the image files first, then build the DataLoaders and the learner from them
fnames = get_image_files(path)
dls = ImageDataLoaders.from_name_func(
    '.', fnames,
    valid_pct=0.2,
    label_func=label_func,
    item_tfms=Resize(224),
)
learn = cnn_learner(dls, resnet18)

In [None]:
#hide
#slow
# Check that the traced export reproduces the original model's outputs on one batch
with tempfile.TemporaryDirectory() as tmpdir:
    ts_file = f'{tmpdir}/trace.pt'
    x,_ = dls.one_batch()
    # Reference outputs from the original model, on the batch's device
    learn.model.eval()
    learn.model.to(x.device)
    with torch.no_grad(): probs = learn.model(x)
    # Export with tracing, load it back onto the same device and compare
    learn.to_jit(ts_file, 'trace')
    trace = torch.jit.load(ts_file, map_location=x.device)
    trace.eval()
    probs_trace = trace(x)
    test_close(probs, probs_trace)

Since `ResNet` is a vision model, `trace` should be used:

In [None]:
learn.to_jit('models/resnet.pt', mode=JitMode.Trace)

Just as before we can now load it in and perform inference:

In [None]:
# One batch taken after moving the DataLoaders to CPU
x,_ = dls.cpu().one_batch()

# Reload the exported TorchScript module and switch it to eval mode
loaded_model = torch.jit.load("models/resnet.pt")
loaded_model.eval()

probs = loaded_model(x)
probs[:3]

tensor([[-0.2017, -3.3517],
        [ 2.8087, -2.2481],
        [-0.2001, -2.5412]], grad_fn=<SliceBackward>)