In [None]:
# default_exp onnx

# 03_onnx
> Exporting models to `ONNX` format

In [None]:
# export
from fastai2.tabular.all import *
import onnxruntime as ort

In [None]:
#export
@patch
def to_onnx(x:Learner, fname='export', path=Path('.')):
    """Export model to `ONNX` format, together with a pickled `Learner`.

    Writes `{fname}.onnx` (the model traced with one sample batch) and
    `{fname}.pkl` (via `Learner.export`) using `path` as the parent directory.
    `path` may be a `Path` or a string. The first axis of every model input and
    of the single output (named 'output') is exported as a dynamic 'batch_size' axis.
    """
    path = Path(path)
    train_dl = x.dls[0]
    # Temporarily use a batch size of 1 to grab a dummy batch for tracing.
    # The original batch size is restored even if drawing the batch fails,
    # so the learner's DataLoader is never left modified.
    orig_bs = train_dl.bs
    train_dl.bs = 1
    try: dummy_inp = next(iter(train_dl))
    finally: train_dl.bs = orig_bs
    # ONNX input names mirror the arguments of `forward` (minus `self`)
    names = inspect.getfullargspec(x.model.forward).args[1:]
    dynamic_axes = {n:{0:'batch_size'} for n in names}
    dynamic_axes['output'] = {0:'batch_size'}
    # Only the first `n_inp` entries of the batch are model inputs; the rest are targets
    torch.onnx.export(x.model, dummy_inp[:x.dls.n_inp], path/f'{fname}.onnx',
                     input_names=names, output_names=['output'],
                     dynamic_axes=dynamic_axes)
    x.export(path/f'{fname}.pkl')

`to_onnx` currently supports only single-output models. It writes `{fname}.onnx` together with a `{fname}.pkl` export of the `Learner`. See an example usage below:

In [None]:
# Fetch the Adult sample dataset and read its CSV into a DataFrame
path = untar_data(URLs.ADULT_SAMPLE)
df = pd.read_csv(
    path / 'adult.csv'
)

In [None]:
# Column roles and preprocessing steps for the Adult dataset
y_names = 'salary'
cat_names = [
    'workclass', 'education', 'marital-status',
    'occupation', 'relationship', 'race',
]
cont_names = ['age', 'fnlwgt', 'education-num']
procs = [Categorify, FillMissing, Normalize]
# Split the row indices of `df` with a default `RandomSplitter`
splits = RandomSplitter()(range_of(df))

In [None]:
# Build the DataLoaders from the configured columns/procs, then a tabular learner on top
dls = TabularPandas(
    df,
    procs=procs,
    cat_names=cat_names,
    cont_names=cont_names,
    y_names=y_names,
    splits=splits,
).dataloaders()
learn = tabular_learner(dls, layers=[200, 100])

In [None]:
# Writes `tabular.onnx` and `tabular.pkl` under the default `path` ('.')
learn.to_onnx('tabular')

In [None]:
#export
from fastinference.inference import _fully_decode, _decode_loss

/media/mldata/fastai2/zach/fastinference
/media/mldata/fastai2/zach/fastinference/nbs


In [None]:
#export
class fastONNX():
    "ONNX wrapper for `Learner`"
    def __init__(self, fn):
        "Load `{fn}.onnx` into an ONNX Runtime session and `{fn}.pkl` as the companion `Learner`"
        fn = str(fn)  # accept `Path` stems as well as plain strings
        self.ort_session = ort.InferenceSession(fn+'.onnx')
        # Prefer the CUDA provider; fall back to the CPU one if it cannot be enabled
        try: self.ort_session.set_providers(['CUDAExecutionProvider'])
        except Exception: self.ort_session.set_providers(['CPUExecutionProvider'])
        # NOTE(review): `load_learner` is expected to unpickle the file - only load exports you trust
        self.learn = load_learner(fn+'.pkl')
        # Predictions go through the ONNX session, so the `Learner`'s own model is dropped;
        # the `Learner` is kept for its `dls` and `loss_func`
        self.learn.model = None

    def to_numpy(self, t:Tensor):
        "Return tensor `t` as a numpy array on the CPU, detached from autograd"
        # `detach()` is valid whether or not `t` requires grad, so no branching is needed
        return t.detach().cpu().numpy()

    def predict(self, inps):
        "Run the ONNX session on `inps` (tensors or numpy arrays, in model-input order)"
        if isinstance(inps[0], Tensor): inps = [self.to_numpy(x) for x in inps]
        # Inputs are matched positionally to the session's input names
        names = [i.name for i in self.ort_session.get_inputs()]
        xs = dict(zip(names, inps))
        return self.ort_session.run(None, xs)

    def get_preds(self, dl=None, raw_outs=False, decoded_loss=True, fully_decoded=False):
        """Get predictions for every batch in `dl`, with possible decoding.

        `dl` is required despite its `None` default. The first entry of the
        returned list is the raw session output when `raw_outs` is set,
        otherwise the loss function's activation of it (falling back to the
        decoded output if the activation cannot be applied). With
        `fully_decoded` / `decoded_loss` the list is further processed by
        `_fully_decode` / `_decode_loss`.
        """
        if dl is None: raise TypeError("`dl` is required: pass a DataLoader, e.g. from `test_dl`")
        outs, dec_out, raw = [], [], []
        loss_func = self.learn.loss_func
        n_inp = self.learn.dls.n_inp
        is_multi = n_inp > 1
        # With several inputs, keep one list of per-batch arrays per input
        inps = [[] for _ in range(n_inp)] if is_multi else []
        need_decode = decoded_loss or fully_decoded
        for batch in dl:
            # Only the first `n_inp` entries of a batch are fed to the model
            batch_np = [self.to_numpy(batch[i]) for i in range(n_inp)]
            if is_multi:
                for i,item in enumerate(batch_np): inps[i].append(item)
            else:
                # NOTE(review): single-input layout expected by `_fully_decode` is not
                # visible here - confirm a bare per-batch array is what it wants
                inps.append(batch_np[0])
            out = self.predict(batch_np)
            raw.append(out)
            if need_decode: dec_out.append(loss_func.decodes(tensor(out)))
        # assumes each `predict` result is a list of per-output arrays, so the sample
        # axis of every stacked entry is axis 1 - TODO confirm for multi-output sessions
        axis = 1 if len(dl) > 1 else 0
        raw = np.concatenate(raw, axis=axis)
        if need_decode: dec_out = np.concatenate(dec_out, axis=axis)
        if not raw_outs:
            try: outs.insert(0, loss_func.activation(tensor(raw)).numpy())
            # Best effort: loss functions without a usable activation fall back to decoded output
            except Exception: outs.insert(0, dec_out)
        else:
            outs.insert(0, raw)
        if fully_decoded: outs = _fully_decode(self.learn.dls, inps, outs, dec_out, is_multi)
        if decoded_loss: outs = _decode_loss(self.learn.dls.vocab, dec_out, outs)
        return outs

    def test_dl(self, test_items, **kwargs):
        "Build a test DataLoader for `test_items` from the wrapped `Learner`'s `dls`"
        return self.learn.dls.test_dl(test_items, **kwargs)

In [None]:
# Test DataLoader over the first 100 rows of `df`, with a batch size of 64
dl = learn.dls.test_dl(df.iloc[:100], bs=64)

In [None]:
# Loads `tabular.onnx` and `tabular.pkl`, the files written by `learn.to_onnx('tabular')`
tab_inf = fastONNX('tabular')

In [None]:
%%time
# Time batched inference over `dl`; the result is discarded, only the timing matters here
_ = tab_inf.get_preds(dl=dl, raw_outs=False, decoded_loss=True)

CPU times: user 9.11 ms, sys: 207 µs, total: 9.32 ms
Wall time: 8.59 ms


In [None]:
# Take the first batch of the test DataLoader
batch = next(iter(dl))

In [None]:
%%time
# Time one ONNX forward pass on the first two entries of the batch
# (presumably the categorical and continuous inputs - confirm against the model's `forward`)
_ = tab_inf.predict(batch[:2])

CPU times: user 3.85 ms, sys: 738 µs, total: 4.59 ms
Wall time: 2.22 ms
