Closed as not planned
I am getting this error while trying to define an extension for my custom op, which returns a list of tensors:
E onnxruntime.capi.onnxruntime_pybind11_state.InvalidGraph: [ONNXRuntimeError] : 10 : INVALID_GRAPH : Load model from test.onnx failed:This is an invalid model. Type Error: Type 'tensor(float)' of input parameter (tensor_product_uniform_1d_jit) of operator (SequenceAt) in node (n0_2) is invalid.
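To make the message easier to read: SequenceAt expects its first input to be a sequence type such as seq(tensor(float)), while the value it receives here is typed as a plain tensor(float). The snippet below is only a toy model of that check, written in plain Python. The helper names is_sequence_type and check_sequence_at are hypothetical, and this is not how ONNX Runtime implements its validation.

```python
# Toy model of the type check behind the INVALID_GRAPH error.
# Hypothetical helpers; the type strings mirror those in the error message.


def is_sequence_type(type_str: str) -> bool:
    """Return True for ONNX-style sequence types such as 'seq(tensor(float))'."""
    return type_str.startswith("seq(") and type_str.endswith(")")


def check_sequence_at(input_type: str) -> str:
    """Return the element type SequenceAt would produce, or raise TypeError."""
    if not is_sequence_type(input_type):
        raise TypeError(
            f"Type '{input_type}' of input parameter of operator (SequenceAt) is invalid."
        )
    # Strip the outer 'seq(' and the trailing ')' to get the element type.
    return input_type[len("seq("):-1]


# Accepted: a sequence of float tensors yields a float tensor element.
element_type = check_sequence_at("seq(tensor(float))")

# Rejected: a plain tensor, which is what the graph in this issue supplies.
try:
    check_sequence_at("tensor(float)")
    rejected = False
except TypeError as exc:
    rejected = True
    message = str(exc)
```

Under this reading, the model would load only if the custom op's output were registered as a sequence type, or if the SequenceAt nodes were not emitted in the first place.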
This is what my op looks like. There seems to be no way to explicitly specify the output as a sequence:
@onnx_op(
    op_type="cuequivariance_ops::tensor_product_uniform_1d",
    inputs=[
        PyOp.dt_float,  # in0
        PyOp.dt_float,  # in1
        PyOp.dt_float,  # in2 (optional)
        PyOp.dt_float,  # in3 (optional)
        PyOp.dt_float,  # in4 (optional)
        PyOp.dt_float,  # in5 (optional)
        PyOp.dt_float,  # in6 (optional)
        PyOp.dt_int64,  # indexes (optional)
    ],
    attrs={
        "name": PyCustomOpDef.dt_string,
        ....
        "batch_size": PyCustomOpDef.dt_int64,
    }
)
def _(in0, in1, in2=None, in3=None, in4=None, in5=None, in6=None, indexes=None,
      **kwargs) -> Sequence[Tensor]:
    args = [in2, in3, in4, in5, in6, indexes]
    args = maybe_from_numpy(args)
    result = torch.ops.cuequivariance_ops.tensor_product_uniform_1d_jit(
        *args,
        *(kwargs.values()),
    )
    return result
torch.ops.cuequivariance_ops.tensor_product_uniform_1d_jit returns List[Tensor].
Are there any workarounds?