
Unable to handle custom operators that return a list of tensors #922

@borisfom

Description


I am getting the following error when trying to define an extension for my custom op, which returns a list of tensors:

    E           onnxruntime.capi.onnxruntime_pybind11_state.InvalidGraph: [ONNXRuntimeError] : 10 : INVALID_GRAPH : Load model from test.onnx failed:This is an invalid model. Type Error: Type 'tensor(float)' of input parameter (tensor_product_uniform_1d_jit) of operator (SequenceAt) in node (n0_2) is invalid.

This is what my op looks like; there seems to be no way to explicitly declare the output as a sequence:

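    from typing import Sequence

    import torch
    from torch import Tensor
    from onnxruntime_extensions import onnx_op, PyOp, PyCustomOpDef

    @onnx_op(
        op_type="cuequivariance_ops::tensor_product_uniform_1d",
        inputs=[
            PyOp.dt_float,  # in0
            PyOp.dt_float,  # in1
            PyOp.dt_float,  # in2 (optional)
            PyOp.dt_float,  # in3 (optional)
            PyOp.dt_float,  # in4 (optional)
            PyOp.dt_float,  # in5 (optional)
            PyOp.dt_float,  # in6 (optional)
            PyOp.dt_int64,  # indexes (optional)
        ],
        attrs={
            "name": PyCustomOpDef.dt_string,
            # ... (remaining attributes elided)
            "batch_size": PyCustomOpDef.dt_int64,
        },
        # No outputs= list is given, so the decorator's default outputs apply.
    )
    def _(in0, in1, in2=None, in3=None, in4=None, in5=None, in6=None, indexes=None,
          **kwargs) -> Sequence[Tensor]:
        # Forward every tensor input, including the two required ones (in0, in1).
        args = [in0, in1, in2, in3, in4, in5, in6, indexes]
        # maybe_from_numpy is a helper defined elsewhere (not shown in this report).
        args = maybe_from_numpy(args)
        # Attributes arrive as keyword arguments and are passed on positionally.
        result = torch.ops.cuequivariance_ops.tensor_product_uniform_1d_jit(
            *args,
            *(kwargs.values()),
        )
        return result  # a List[Tensor]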
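If I read onnxruntime_extensions correctly, onnx_op falls back to a single dt_float output when no outputs= list is given, which would match the checker seeing tensor(float) where SequenceAt expects a sequence. One possible workaround, assuming the kernel always returns the same number N of tensors for a given model, is to declare that many outputs (outputs=[PyOp.dt_float] * N) and return a tuple instead of a list. This sketch shows only the list-to-tuple adapter in plain Python; as_fixed_outputs, split_halves and N are hypothetical names, and the onnx_op wiring is left as a comment because it needs onnxruntime_extensions and the custom kernel.

```python
import functools
from typing import Any, Callable, Sequence, Tuple


def as_fixed_outputs(num_outputs: int):
    """Hypothetical adapter: make a list-returning kernel return exactly
    `num_outputs` values as a tuple, one per declared op output."""

    def decorate(kernel: Callable[..., Sequence[Any]]) -> Callable[..., Tuple[Any, ...]]:
        @functools.wraps(kernel)
        def wrapper(*args: Any, **kwargs: Any) -> Tuple[Any, ...]:
            results = list(kernel(*args, **kwargs))
            # The output count is fixed when the op is declared, so a mismatch
            # is reported instead of silently padding or truncating.
            if len(results) != num_outputs:
                raise ValueError(
                    f"kernel returned {len(results)} tensors, "
                    f"but {num_outputs} outputs are declared"
                )
            return tuple(results)

        return wrapper

    return decorate


# Stand-in kernel (hypothetical) that returns a list, like the torch op above.
@as_fixed_outputs(2)
def split_halves(xs):
    mid = len(xs) // 2
    return [xs[:mid], xs[mid:]]


print(split_halves([1.0, 2.0, 3.0, 4.0]))  # ([1.0, 2.0], [3.0, 4.0])

# Hypothetical wiring with onnxruntime_extensions (N = tensors the kernel returns):
# @onnx_op(op_type="cuequivariance_ops::tensor_product_uniform_1d",
#          inputs=[...], outputs=[PyOp.dt_float] * N, attrs={...})
# @as_fixed_outputs(N)
# def _(in0, in1, ...): ...
```

If a downstream node such as SequenceAt still needs a sequence, the N separate outputs could be combined in the graph with the standard ONNX SequenceConstruct operator; whether the exporter that produced test.onnx can be made to emit that wiring is not something I can confirm.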

torch.ops.cuequivariance_ops.tensor_product_uniform_1d_jit returns List[Tensor].
Are there any workarounds?
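When the number of tensors is not known ahead of time, another option (a hypothetical sketch, not an onnxruntime_extensions feature) is to pack the list into two fixed outputs, outputs=[PyOp.dt_float, PyOp.dt_int64]: one flat float32 tensor holding all values and one int64 tensor of per-tensor element counts. The standard ONNX SplitToSequence operator, given the counts as its split input, should turn the pair back into the seq(tensor(float)) that SequenceAt accepts. pack_tensors and unpack_tensors are illustrative names; unpack_tensors mirrors that split in numpy so the round trip can be checked.

```python
from typing import List, Sequence, Tuple

import numpy as np


def pack_tensors(tensors: Sequence[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
    """Flatten a variable-length list of tensors into two fixed outputs:
    a 1-D float32 tensor of all values and an int64 tensor of element counts."""
    flat = [np.asarray(t, dtype=np.float32).reshape(-1) for t in tensors]
    lengths = np.array([f.size for f in flat], dtype=np.int64)
    if not flat:
        return np.zeros(0, dtype=np.float32), lengths
    return np.concatenate(flat), lengths


def unpack_tensors(packed: np.ndarray, lengths: np.ndarray) -> List[np.ndarray]:
    """Inverse split, mirroring SplitToSequence(packed, split=lengths, axis=0).
    Original shapes are not restored: every chunk comes back 1-D."""
    if lengths.size == 0:
        return []
    # Split points are the running totals, excluding the last one (the full length).
    offsets = np.cumsum(lengths)[:-1]
    return np.split(packed, offsets)


a = np.arange(4, dtype=np.float32).reshape(2, 2)
b = np.array([10.0, 11.0], dtype=np.float32)
packed, lengths = pack_tensors([a, b])
print(lengths)  # [4 2]
print([c.tolist() for c in unpack_tensors(packed, lengths)])  # [[0.0, 1.0, 2.0, 3.0], [10.0, 11.0]]
```

Because the shapes are flattened away, restoring them would need either a per-element Reshape downstream or an extra int64 output carrying the original shapes.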
