In [None]:
# default_exp onnx

# 03_onnx
> Exporting models to `ONNX` format

In [None]:
# export
from fastai2.tabular.all import *
import onnxruntime as ort

In [None]:
@patch
def to_onnx(x:Learner, fname='export.onnx', dynamic_batch=True):
    """Export model to `ONNX` format

    Traces `x.model` on a single-sample batch drawn from the training
    `DataLoader` (`x.dls[0]`) and writes the graph to `fname`.

    - All but the last element of the batch are fed to the model; the last
      element is assumed to be the (single) target and is not exported.
    - Input names are taken from the arguments of `x.model.forward`
      (excluding `self`); the single output is named `output`.
    - If `dynamic_batch` is True, dimension 0 of every input and of `output`
      is declared dynamic, so the exported model accepts any batch size rather
      than only the batch size of 1 used for tracing.

    The training DataLoader's batch size is restored afterwards, so later uses
    of `x.dls[0]` are unaffected.
    """
    train_dl = x.dls[0]
    orig_bs = train_dl.bs
    train_dl.bs = 1
    try:
        dummy_inp = next(iter(train_dl))
    finally:
        # Always put the caller's batch size back, even if fetching a batch fails
        train_dl.bs = orig_bs
    # Skip `self`; the remaining forward() argument names label the ONNX inputs
    names = inspect.getfullargspec(x.model.forward).args[1:]
    dynamic_axes = {n: {0: 'batch'} for n in names + ['output']} if dynamic_batch else None
    torch.onnx.export(x.model, dummy_inp[:-1], fname,
                      input_names=names, output_names=['output'],
                      dynamic_axes=dynamic_axes)

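Currently only single-output models are supported: the exported graph has one output named `output`, and the last element of each batch is treated as the target and left out of the model inputs. See an example usage below: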

In [None]:
# Fetch the ADULT_SAMPLE dataset and read its CSV into a DataFrame
path = untar_data(URLs.ADULT_SAMPLE)
df = pd.read_csv(path.joinpath('adult.csv'))

In [None]:
# Fixed seed so the train/valid split is the same on every Restart & Run All
splits = RandomSplitter(seed=42)(range_of(df))
cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race']
cont_names = ['age', 'fnlwgt', 'education-num']
procs = [Categorify, FillMissing, Normalize]
y_names = 'salary'

In [None]:
# Treat the `salary` target as a categorical (classification) label.
# `y_block` was never defined in this notebook, which raised NameError on a fresh kernel.
dls = TabularPandas(df, procs=procs, cat_names=cat_names, cont_names=cont_names,
                   y_names=y_names, y_block=CategoryBlock(), splits=splits).dataloaders()
learn = tabular_learner(dls, layers=[200,100])

In [None]:
# Write the learner's model to tabular.onnx (relative to the working directory)
learn.to_onnx(fname='tabular.onnx')

In [None]:
class fastONNX():
    "Onnx wrapper for `predict`"
    def __init__(self, fn):
        "Load the ONNX model at `fn` into an onnxruntime session, preferring CUDA and falling back to CPU."
        self.ort_session = ort.InferenceSession(fn)
        # If the CUDA provider cannot be set (e.g. a CPU-only onnxruntime build),
        # fall back to CPU. Catch only real errors so KeyboardInterrupt/SystemExit
        # still propagate.
        try:
            self.ort_session.set_providers(['CUDAExecutionProvider'])
        except Exception:
            self.ort_session.set_providers(['CPUExecutionProvider'])

    def predict(self, *inp):
        """Run the session on `inp` and return the list of raw model outputs.

        Inputs are matched positionally to the session's input names. Extra
        trailing inputs (e.g. a target) are ignored. Each input may be a torch
        tensor, which is detached and copied to CPU, or any value onnxruntime
        accepts directly, such as a numpy array.
        """
        names = [i.name for i in self.ort_session.get_inputs()]
        feed = {name: self._to_numpy(x) for name, x in zip(names, inp)}
        # `None` requests every output of the model
        return self.ort_session.run(None, feed)

    @staticmethod
    def _to_numpy(x):
        "Convert a torch tensor to a CPU numpy array (detaching from autograd); pass anything else through."
        return x.detach().cpu().numpy() if hasattr(x, 'detach') else x

In [None]:
# Open an onnxruntime session on the model exported above
tab_inf = fastONNX(fn='tabular.onnx')

In [None]:
# One training batch, reused so both runtimes are timed on identical inputs
batch = next(iter(learn.dls.train))

In [None]:
%%timeit
# Time onnxruntime inference on the first two batch elements (the model inputs; the target is excluded)
_ = tab_inf.predict(*batch[:2])

187 µs ± 287 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [None]:
%%timeit
with torch.no_grad():
    learn.model.eval()
    out = learn.model(*batch[:2])

494 µs ± 291 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
