In [1]:
import onnx
import torch
import onnxruntime as ort
import numpy as np


class Model(torch.nn.Module):
    """Split `data` along dim 0 into consecutive pieces.

    `indices` holds the end offset of each piece. In this notebook's examples
    its last entry equals len(data), so that entry is dropped before calling
    torch.tensor_split, which takes split points rather than end points.
    """

    def forward(self, data, indices):
        # With zero or one offset there is nothing to split: the whole tensor
        # is returned as a single piece.
        num_offsets = indices.size(-1)
        pieces = [data]
        if num_offsets > 1:
            pieces = torch.tensor_split(data, indices[:-1])
        return pieces
    
# Export configuration shared by the cells below.
# Both inputs get a dynamic leading axis, so the exported graph accepts any
# data length ('n') and any number of split offsets ('m').
dynamic_axes = {'data': {0: 'n'}, 'indices': {0: 'm'}}
input_names = ['data', 'indices']
output_names = ['test_output']
onnx_path = 'test_op.onnx'

# Seed torch's RNG so the random inputs drawn later with torch.rand are
# reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED);

In [2]:
def to_numpy(tensor):
    """Return a NumPy array for a torch tensor, moving it to the CPU first.

    Tensors that track gradients are detached before conversion, because
    .numpy() refuses tensors with requires_grad=True.
    """
    plain = tensor.detach() if tensor.requires_grad else tensor
    return plain.cpu().numpy()

In [3]:
# Instantiate the split module (it defines only forward, with no parameters or submodules).
model = Model()

In [4]:
# Compile with TorchScript. Presumably this uses scripting rather than tracing so
# that the data-dependent `if` in forward survives into the export as control
# flow. TODO confirm.
model = torch.jit.script(model)

In [5]:
# Display the scripted module; its repr is shown below.
model

RecursiveScriptModule(original_name=Model)

In [6]:
# Switch to eval mode before export. eval() returns the module itself, which is
# why its repr is echoed below.
model.eval()

RecursiveScriptModule(original_name=Model)

In [7]:
# Export-time example: 6 values split at offsets 1, 3 and 5. The final offset
# (6) equals len(data) and is dropped by Model.forward.
data = torch.rand(6, dtype=torch.float32, device='cpu')
indices = torch.tensor([1, 3, 5, 6], dtype=torch.int64, device='cpu')

In [8]:
# Positional (data, indices) tuple, in the same order as input_names.
temp_inputs = data, indices
temp_inputs

(tensor([0.1599, 0.6062, 0.8264, 0.8728, 0.4459, 0.1708]),
 tensor([1, 3, 5, 6]))

In [9]:
def create_ort_inputs(inputs, input_names):
    """Build the ONNX Runtime input feed.

    Each graph input name is mapped to the NumPy conversion of the tensor at
    the same position in `inputs`.
    """
    feed = {}
    for position, input_name in enumerate(input_names):
        feed[input_name] = to_numpy(inputs[position])
    return feed

In [10]:
def run_ort_inference(onnx_path, inputs, input_names, providers=None, log_severity_level=0):
    """Run an exported ONNX model with ONNX Runtime.

    Args:
        onnx_path: path to the .onnx file to load.
        inputs: sequence of torch tensors, in the same order as `input_names`.
        input_names: graph input names to bind each tensor to.
        providers: execution providers for the InferenceSession. Defaults to
            CPU only. ORT 1.9+ builds that expose several providers (e.g. GPU
            builds) refuse to create a session when none is given explicitly.
        log_severity_level: per-run ORT log level (0 = verbose, which emits the
            "Begin execution" lines seen in this notebook). The default keeps
            the original behavior.

    Returns:
        The list returned by InferenceSession.run, with one entry per graph output.
    """
    if providers is None:
        providers = ['CPUExecutionProvider']
    ort_session = ort.InferenceSession(onnx_path, providers=providers)
    run_options = ort.RunOptions()
    run_options.log_severity_level = log_severity_level
    ort_out = ort_session.run(None, create_ort_inputs(inputs, input_names), run_options=run_options)
    return ort_out

In [11]:
def run_pt_inference(model, inputs):
    """Run `model` on the unpacked `inputs` and return its raw output.

    This is inference only, so it runs under torch.no_grad() to avoid
    recording an autograd graph for the outputs.
    """
    with torch.no_grad():
        return model(*inputs)

In [12]:
def export_onnx(model, inputs, dynamic_axes, input_names, output_names, onnx_path, opset_version=15):
    """Export `model` to `onnx_path`, then run PyTorch and ONNX Runtime on
    `inputs` and print both outputs for side-by-side inspection.

    Returns nothing; the outputs are only printed.
    """
    export_kwargs = dict(
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
        opset_version=opset_version,
    )
    torch.onnx.export(model, inputs, onnx_path, **export_kwargs)

    pytorch_result = run_pt_inference(model, inputs)
    onnxruntime_result = run_ort_inference(onnx_path, inputs, input_names)

    print(f"Model output - \n{pytorch_result}")
    print(f"ORT output - \n{onnxruntime_result}")

In [13]:
export_onnx(
    model,
    inputs=temp_inputs,
    dynamic_axes=dynamic_axes,
    input_names=input_names,
    output_names=output_names,
    onnx_path=onnx_path,
)

Model output - 
[tensor([0.1599]), tensor([0.6062, 0.8264]), tensor([0.8728, 0.4459]), tensor([0.1708])]
ORT output - 
[[array([0.15993482], dtype=float32), array([0.6061704, 0.8263963], dtype=float32), array([0.8728442 , 0.44591272], dtype=float32), array([0.17079455], dtype=float32)]]


2022-05-27 00:25:05.731471 [I:onnxruntime:, sequential_executor.cc:176 Execute] Begin execution
2022-05-27 00:25:05.731533 [I:onnxruntime:, sequential_executor.cc:176 Execute] Begin execution


In [14]:
# Second input set: a longer vector (9 values) with more offsets (6) than the
# export-time example, exercising the dynamic axes of the exported graph.
data = torch.rand(9, dtype=torch.float32, device='cpu')
indices = torch.tensor([1, 3, 5, 6, 8, 9], dtype=torch.int64, device='cpu')
temp_inputs = data, indices
temp_inputs

(tensor([0.6227, 0.2334, 0.1477, 0.3524, 0.3059, 0.0208, 0.4220, 0.8973, 0.2717]),
 tensor([1, 3, 5, 6, 8, 9]))

In [15]:
# Run the exported graph on inputs whose shapes differ from the export-time
# example, and cross-check against the scripted PyTorch model so that a
# mismatch fails loudly instead of only being displayed.
ort_output = run_ort_inference(onnx_path, temp_inputs, input_names)
pt_pieces = run_pt_inference(model, temp_inputs)
# The graph has a single output (the list of split pieces), hence [0].
ort_pieces = ort_output[0]
assert len(pt_pieces) == len(ort_pieces), (len(pt_pieces), len(ort_pieces))
for pt_piece, ort_piece in zip(pt_pieces, ort_pieces):
    np.testing.assert_allclose(to_numpy(pt_piece), ort_piece)
ort_output

2022-05-27 00:25:05.748837 [I:onnxruntime:, sequential_executor.cc:176 Execute] Begin execution
2022-05-27 00:25:05.748870 [I:onnxruntime:, sequential_executor.cc:176 Execute] Begin execution


[[array([0.6227192], dtype=float32),
  array([0.23342085, 0.14773607], dtype=float32),
  array([0.3523618, 0.3059066], dtype=float32),
  array([0.02083826], dtype=float32),
  array([0.42197102, 0.8973381 ], dtype=float32),
  array([0.2716781], dtype=float32)]]

In [16]:
# Single offset equal to len(data): indices has one entry, so Model.forward
# takes the no-split branch and returns the whole vector as one piece.
data = torch.rand(10, dtype=torch.float32, device='cpu')
indices = torch.tensor([10], dtype=torch.int64, device='cpu')
temp_inputs = data, indices
temp_inputs

(tensor([0.1340, 0.3370, 0.9966, 0.5550, 0.2751, 0.7691, 0.3250, 0.4477, 0.3297,
         0.7813]),
 tensor([10]))

In [17]:
# Single-offset case: check the exported graph's no-split branch against the
# scripted PyTorch model rather than only displaying the ORT result.
ort_output = run_ort_inference(onnx_path, temp_inputs, input_names)
pt_pieces = run_pt_inference(model, temp_inputs)
# The graph has a single output (the list of split pieces), hence [0].
ort_pieces = ort_output[0]
# forward returns [data] when indices has at most one entry.
assert len(pt_pieces) == len(ort_pieces) == 1, (len(pt_pieces), len(ort_pieces))
for pt_piece, ort_piece in zip(pt_pieces, ort_pieces):
    np.testing.assert_allclose(to_numpy(pt_piece), ort_piece)
ort_output

2022-05-27 00:25:05.771019 [I:onnxruntime:, sequential_executor.cc:176 Execute] Begin execution
2022-05-27 00:25:05.771055 [I:onnxruntime:, sequential_executor.cc:176 Execute] Begin execution


[[array([0.13395113, 0.3369828 , 0.9966254 , 0.55499226, 0.27509874,
         0.7691482 , 0.3250174 , 0.44767803, 0.32971817, 0.78128254],
        dtype=float32)]]