diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td index 3ec69c33b4bb9..90feec9ed8d6b 100644 --- a/mlir/test/mlir-tblgen/op-python-bindings.td +++ b/mlir/test/mlir-tblgen/op-python-bindings.td @@ -60,7 +60,7 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands", Optional:$variadic2); } -// CHECK: def attr_sized_operands(variadic1, non_variadic, *, variadic2=None, loc=None, ip=None) +// CHECK: def attr_sized_operands(variadic1, non_variadic, *, variadic2=None, loc=None, ip=None) -> AttrSizedOperandsOp: // CHECK: return AttrSizedOperandsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip) // CHECK: @_ods_cext.register_operation(_Dialect) @@ -108,8 +108,9 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results", Variadic:$variadic2); } -// CHECK: def attr_sized_results(variadic1, non_variadic, variadic2, *, loc=None, ip=None) -// CHECK: return _get_op_result_or_op_results(AttrSizedResultsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)) +// CHECK: def attr_sized_results(variadic1, non_variadic, variadic2, *, loc=None, ip=None) -> _Union[_ods_ir.OpResult, _ods_ir.OpResultList, AttrSizedResultsOp]: +// CHECK: op = AttrSizedResultsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip); results = op.results +// CHECK: return results if len(results) > 1 else (results[0] if len(results) == 1 else op) // CHECK: @_ods_cext.register_operation(_Dialect) @@ -159,7 +160,7 @@ def AttributedOp : TestOp<"attributed_op"> { UnitAttr:$unitAttr, I32Attr:$in); } -// CHECK: def attributed_op(i32attr, in_, *, optional_f32_attr=None, unit_attr=None, loc=None, ip=None) +// CHECK: def attributed_op(i32attr, in_, *, optional_f32_attr=None, unit_attr=None, loc=None, ip=None) -> AttributedOp: // CHECK: return AttributedOp(i32attr=i32attr, in_=in_, optionalF32Attr=optional_f32_attr, unitAttr=unit_attr, loc=loc, ip=ip) // CHECK: @_ods_cext.register_operation(_Dialect) @@ -196,7 +197,7 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> { let arguments = (ins I32, UnitAttr:$in, F32, OptionalAttr:$is); } -// CHECK: def attributed_op_with_operands(_gen_arg_0, _gen_arg_2, *, in_=None, is_=None, loc=None, ip=None) +// CHECK: def attributed_op_with_operands(_gen_arg_0, _gen_arg_2, *, in_=None, is_=None, loc=None, ip=None) -> AttributedOpWithOperands // CHECK: return AttributedOpWithOperands(_gen_arg_0=_gen_arg_0, _gen_arg_2=_gen_arg_2, in_=in_, is_=is_, loc=loc, ip=ip) // CHECK: @_ods_cext.register_operation(_Dialect) @@ -221,7 +222,7 @@ def DefaultValuedAttrsOp : TestOp<"default_valued_attrs"> { let results = (outs); } -// CHECK: def default_valued_attrs(*, arr=None, unsupported=None, loc=None, ip=None) +// CHECK: def default_valued_attrs(*, arr=None, unsupported=None, loc=None, ip=None) -> DefaultValuedAttrsOp: // CHECK: return DefaultValuedAttrsOp(arr=arr, unsupported=unsupported, loc=loc, ip=ip) // CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_op" @@ -239,7 +240,7 @@ def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResu let results = (outs AnyType:$res, AnyType); } -// CHECK: def derive_result_types_op(type_, *, results=None, loc=None, ip=None) +// CHECK: def derive_result_types_op(type_, *, results=None, loc=None, ip=None) -> _ods_ir.OpResultList: // CHECK: return DeriveResultTypesOp(type_=type_, results=results, loc=loc, ip=ip).results // CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_variadic_op" @@ -249,8 +250,9 @@ def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [Fir let results = (outs AnyType:$res, Variadic); } -// CHECK: def derive_result_types_variadic_op(res, _gen_res_1, type_, *, loc=None, ip=None) -// CHECK: return _get_op_result_or_op_results(DeriveResultTypesVariadicOp(res=res, _gen_res_1=_gen_res_1, type_=type_, loc=loc, ip=ip)) +// CHECK: def derive_result_types_variadic_op(res, _gen_res_1, type_, *, loc=None, ip=None) -> _Union[_ods_ir.OpResult, _ods_ir.OpResultList, DeriveResultTypesVariadicOp]: +// CHECK: op = DeriveResultTypesVariadicOp(res=res, _gen_res_1=_gen_res_1, type_=type_, loc=loc, ip=ip); results = op.results +// CHECK: return results if len(results) > 1 else (results[0] if len(results) == 1 else op) // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class EmptyOp(_ods_ir.OpView): @@ -267,7 +269,7 @@ def EmptyOp : TestOp<"empty">; // CHECK: attributes=attributes, results=results, operands=operands, // CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip) -// CHECK: def empty(*, loc=None, ip=None) +// CHECK: def empty(*, loc=None, ip=None) -> EmptyOp: // CHECK: return EmptyOp(loc=loc, ip=ip) // CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_implied_op" @@ -281,7 +283,7 @@ def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> { let results = (outs I32:$i32, F32:$f32); } -// CHECK: def infer_result_types_implied_op(*, results=None, loc=None, ip=None) +// CHECK: def infer_result_types_implied_op(*, results=None, loc=None, ip=None) -> _ods_ir.OpResultList: // CHECK: return InferResultTypesImpliedOp(results=results, loc=loc, ip=ip).results // CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_op" @@ -295,7 +297,7 @@ def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]> let results = (outs AnyType, AnyType, AnyType); } -// CHECK: def infer_result_types_op(*, results=None, loc=None, ip=None) +// CHECK: def infer_result_types_op(*, results=None, loc=None, ip=None) -> _ods_ir.OpResultList: // CHECK: return InferResultTypesOp(results=results, loc=loc, ip=ip).results // CHECK: @_ods_cext.register_operation(_Dialect) @@ -334,7 +336,7 @@ def MissingNamesOp : TestOp<"missing_names"> { let results = (outs I32:$i32, AnyFloat, I64:$i64); } -// CHECK: def missing_names(i32, _gen_res_1, i64, _gen_arg_0, f32, _gen_arg_2, *, loc=None, ip=None) +// CHECK: def missing_names(i32, _gen_res_1, i64, _gen_arg_0, f32, _gen_arg_2, *, loc=None, ip=None) -> _ods_ir.OpResultList: // CHECK: return MissingNamesOp(i32=i32, _gen_res_1=_gen_res_1, i64=i64, _gen_arg_0=_gen_arg_0, f32=f32, _gen_arg_2=_gen_arg_2, loc=loc, ip=ip).results // CHECK: @_ods_cext.register_operation(_Dialect) @@ -366,7 +368,7 @@ def OneOptionalOperandOp : TestOp<"one_optional_operand"> { // CHECK: return None if len(self.operation.operands) < 2 else self.operation.operands[1] } -// CHECK: def one_optional_operand(non_optional, *, optional=None, loc=None, ip=None) +// CHECK: def one_optional_operand(non_optional, *, optional=None, loc=None, ip=None) -> OneOptionalOperandOp: // CHECK: return OneOptionalOperandOp(non_optional=non_optional, optional=optional, loc=loc, ip=ip) // CHECK: @_ods_cext.register_operation(_Dialect) @@ -399,7 +401,7 @@ def OneVariadicOperandOp : TestOp<"one_variadic_operand"> { let arguments = (ins AnyType:$non_variadic, Variadic:$variadic); } -// CHECK: def one_variadic_operand(non_variadic, variadic, *, loc=None, ip=None) +// CHECK: def one_variadic_operand(non_variadic, variadic, *, loc=None, ip=None) -> OneVariadicOperandOp: // CHECK: return OneVariadicOperandOp(non_variadic=non_variadic, variadic=variadic, loc=loc, ip=ip) // CHECK: @_ods_cext.register_operation(_Dialect) @@ -433,8 +435,9 @@ def OneVariadicResultOp : TestOp<"one_variadic_result"> { let results = (outs Variadic:$variadic, AnyType:$non_variadic); } -// CHECK: def one_variadic_result(variadic, non_variadic, *, loc=None, ip=None) -// CHECK: return _get_op_result_or_op_results(OneVariadicResultOp(variadic=variadic, non_variadic=non_variadic, loc=loc, ip=ip)) +// CHECK: def one_variadic_result(variadic, non_variadic, *, loc=None, ip=None) -> _Union[_ods_ir.OpResult, _ods_ir.OpResultList, OneVariadicResultOp]: +// CHECK: op = OneVariadicResultOp(variadic=variadic, non_variadic=non_variadic, loc=loc, ip=ip); results = op.results +// CHECK: return results if len(results) > 1 else (results[0] if len(results) == 1 else op) // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class PythonKeywordOp(_ods_ir.OpView): @@ -458,7 +461,7 @@ def PythonKeywordOp : TestOp<"python_keyword"> { let arguments = (ins AnyType:$in); } -// CHECK: def python_keyword(in_, *, loc=None, ip=None) +// CHECK: def python_keyword(in_, *, loc=None, ip=None) -> PythonKeywordOp: // CHECK: return PythonKeywordOp(in_=in_, loc=loc, ip=ip) // CHECK-LABEL: OPERATION_NAME = "test.same_results" @@ -471,8 +474,8 @@ def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> { let results = (outs AnyType:$res); } -// CHECK: def same_results(in1, in2, *, results=None, loc=None, ip=None) -// CHECK: return SameResultsOp(in1=in1, in2=in2, results=results, loc=loc, ip=ip) +// CHECK: def same_results(in1, in2, *, results=None, loc=None, ip=None) -> _ods_ir.OpResult: +// CHECK: return SameResultsOp(in1=in1, in2=in2, results=results, loc=loc, ip=ip).result // CHECK-LABEL: OPERATION_NAME = "test.same_results_variadic" def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResultType]> { @@ -481,8 +484,9 @@ def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResu let results = (outs Variadic:$res); } -// CHECK: def same_results_variadic(res, in1, in2, *, loc=None, ip=None) -// CHECK: return _get_op_result_or_op_results(SameResultsVariadicOp(res=res, in1=in1, in2=in2, loc=loc, ip=ip)) +// CHECK: def same_results_variadic(res, in1, in2, *, loc=None, ip=None) -> _Union[_ods_ir.OpResult, _ods_ir.OpResultList, SameResultsVariadicOp]: +// CHECK: op = SameResultsVariadicOp(res=res, in1=in1, in2=in2, loc=loc, ip=ip); results = op.results +// CHECK: return results if len(results) > 1 else (results[0] if len(results) == 1 else op) // CHECK: @_ods_cext.register_operation(_Dialect) @@ -508,7 +512,7 @@ def SameVariadicOperandSizeOp : TestOp<"same_variadic_operand", Variadic:$variadic2); } -// CHECK: def same_variadic_operand(variadic1, non_variadic, variadic2, *, loc=None, ip=None) +// CHECK: def same_variadic_operand(variadic1, non_variadic, variadic2, *, loc=None, ip=None) -> SameVariadicOperandSizeOp: // CHECK: return SameVariadicOperandSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip) // CHECK: @_ods_cext.register_operation(_Dialect) @@ -534,8 +538,9 @@ def SameVariadicResultSizeOp : TestOp<"same_variadic_result", Variadic:$variadic2); } -// CHECK: def same_variadic_result(variadic1, non_variadic, variadic2, *, loc=None, ip=None) -// CHECK: return _get_op_result_or_op_results(SameVariadicResultSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)) +// CHECK: def same_variadic_result(variadic1, non_variadic, variadic2, *, loc=None, ip=None) -> _Union[_ods_ir.OpResult, _ods_ir.OpResultList, SameVariadicResultSizeOp]: +// CHECK: op = SameVariadicResultSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip); results = op.results +// CHECK: return results if len(results) > 1 else (results[0] if len(results) == 1 else op) // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class SimpleOp(_ods_ir.OpView): @@ -575,7 +580,7 @@ def SimpleOp : TestOp<"simple"> { let results = (outs I64:$i64, AnyFloat:$f64); } -// CHECK: def simple(i64, f64, i32, f32, *, loc=None, ip=None) +// CHECK: def simple(i64, f64, i32, f32, *, loc=None, ip=None) -> _ods_ir.OpResultList: // CHECK: return SimpleOp(i64=i64, f64=f64, i32=i32, f32=f32, loc=loc, ip=ip).results // CHECK: class VariadicAndNormalRegionOp(_ods_ir.OpView): @@ -603,7 +608,7 @@ def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> { // CHECK: return self.regions[2:] } -// CHECK: def variadic_and_normal_region(num_variadic, *, loc=None, ip=None) +// CHECK: def variadic_and_normal_region(num_variadic, *, loc=None, ip=None) -> VariadicAndNormalRegionOp: // CHECK: return VariadicAndNormalRegionOp(num_variadic=num_variadic, loc=loc, ip=ip) // CHECK: class VariadicRegionOp(_ods_ir.OpView): @@ -627,7 +632,7 @@ def VariadicRegionOp : TestOp<"variadic_region"> { // CHECK: return self.regions[0:] } -// CHECK: def variadic_region(num_variadic, *, loc=None, ip=None) +// CHECK: def variadic_region(num_variadic, *, loc=None, ip=None) -> VariadicRegionOp: // CHECK: return VariadicRegionOp(num_variadic=num_variadic, loc=loc, ip=ip) // CHECK: @_ods_cext.register_operation(_Dialect) @@ -636,7 +641,7 @@ def VariadicRegionOp : TestOp<"variadic_region"> { def WithSpecialCharactersOp : TestOp<"123with--special.characters"> { } -// CHECK: def _123with__special_characters(*, loc=None, ip=None) +// CHECK: def _123with__special_characters(*, loc=None, ip=None) -> WithSpecialCharactersOp: // CHECK: return WithSpecialCharactersOp(loc=loc, ip=ip) // CHECK: @_ods_cext.register_operation(_Dialect) @@ -651,11 +656,11 @@ def WithSuccessorsOp : TestOp<"with_successors"> { VariadicSuccessor:$successors); } -// CHECK: def with_successors(successor, successors, *, loc=None, ip=None) +// CHECK: def with_successors(successor, successors, *, loc=None, ip=None) -> WithSuccessorsOp: // CHECK: return WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip) // CHECK: class snake_case(_ods_ir.OpView): // CHECK-LABEL: OPERATION_NAME = "test.snake_case" def already_snake_case : TestOp<"snake_case"> {} -// CHECK: def snake_case_(*, loc=None, ip=None) +// CHECK: def snake_case_(*, loc=None, ip=None) -> snake_case: // CHECK: return snake_case(loc=loc, ip=ip) diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py index 68262822ca6b5..17aaef7e1b9f4 100644 --- a/mlir/test/python/dialects/python_test.py +++ b/mlir/test/python/dialects/python_test.py @@ -1,7 +1,9 @@ # RUN: %PYTHON %s pybind11 | FileCheck %s # RUN: %PYTHON %s nanobind | FileCheck %s - +import inspect import sys +from typing import Union + from mlir.ir import * import mlir.dialects.func as func import mlir.dialects.python_test as test @@ -323,6 +325,7 @@ def resultTypesDefinedByTraits(): # CHECK: f32 index print(no_infer.single.type, no_infer.doubled.type) + # CHECK-LABEL: TEST: testOptionalOperandOp @run def testOptionalOperandOp(): @@ -594,6 +597,17 @@ def testInferTypeOpInterface(): # CHECK: f32 print(two_operands.result.type) + assert ( + inspect.signature( + test.infer_results_variadic_inputs_op + ).return_annotation + is OpResult + ) + assert isinstance( + test.infer_results_variadic_inputs_op(single=zero, doubled=zero), + OpResult, + ) + # CHECK-LABEL: TEST: testVariadicOperandAccess @run @@ -621,6 +635,15 @@ def values(lst): # CHECK: ['Value(%{{.*}} = arith.constant 3 : i32)', 'Value(%{{.*}} = arith.constant 4 : i32)'] print(values(variadic_operands.variadic2)) + assert ( + inspect.signature(test.same_variadic_operand).return_annotation + is test.SameVariadicOperandSizeOp + ) + assert isinstance( + test.same_variadic_operand([zero, one], two, [three, four]), + test.SameVariadicOperandSizeOp, + ) + # CHECK-LABEL: TEST: testVariadicResultAccess @run @@ -642,6 +665,15 @@ def types(lst): # CHECK: [IntegerType(i3), IntegerType(i4)] print(types(op.variadic2)) + assert ( + inspect.signature(test.same_variadic_result_vfv).return_annotation + is Union[OpResult, OpResultList, test.SameVariadicResultSizeOpVFV] + ) + assert isinstance( + test.same_variadic_result_vfv([i[0], i[1]], i[2], [i[3], i[4]]), + OpResultList, + ) + # Test Variadic-Variadic-Variadic op = test.SameVariadicResultSizeOpVVV( [i[0], i[1]], [i[2], i[3]], [i[4], i[5]] @@ -713,3 +745,12 @@ def types(lst): print(types(op.variadic2)) # CHECK: i4 print(op.non_variadic3.type) + + assert ( + inspect.signature(test.results_variadic).return_annotation + is Union[OpResult, OpResultList, test.ResultsVariadicOp] + ) + assert isinstance( + test.results_variadic([i[0]]), + OpResult, + ) diff --git a/mlir/test/python/ir/auto_location.py b/mlir/test/python/ir/auto_location.py index 01b5542119b4e..a063aa972cc48 100644 --- a/mlir/test/python/ir/auto_location.py +++ b/mlir/test/python/ir/auto_location.py @@ -51,7 +51,7 @@ def testInferLocations(): _cext.globals.register_traceback_file_inclusion(_arith_ops_gen.__file__) three = arith.constant(IndexType.get(), 3) # fmt: off - # CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":397:4 to :235) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":52:16 to :50) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":13:4 to :7) at ""("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":26:1 to :4))))) + # CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":396:4 to :235) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":52:16 to :50) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":13:4 to :7) at ""("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":26:1 to :4))))) # fmt: on print(three.location) @@ -60,14 +60,14 @@ def foo(): print(four.location) # fmt: off - # CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":397:4 to :235) at callsite("testInferLocations..foo"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":59:19 to :53) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":65:8 to :13) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":13:4 to :7) at ""("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":26:1 to :4)))))) + # CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":396:4 to :235) at callsite("testInferLocations..foo"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":59:19 to :53) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":65:8 to :13) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":13:4 to :7) at ""("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":26:1 to :4)))))) # fmt: on foo() _cext.globals.register_traceback_file_exclusion(__file__) # fmt: off - # CHECK: loc("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":397:4 to :235)) + # CHECK: loc("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":396:4 to :235)) # fmt: on foo() diff --git a/mlir/test/python/python_test_ops.td b/mlir/test/python/python_test_ops.td index 026e64a3cfc19..1e94b94dc714b 100644 --- a/mlir/test/python/python_test_ops.td +++ b/mlir/test/python/python_test_ops.td @@ -265,4 +265,8 @@ def SameVariadicResultSizeOpFVFVF : TestOp<"same_variadic_result_fvfvf", AnyType:$non_variadic3); } +def ResultsVariadicOp : TestOp<"results_variadic"> { + let results = (outs Variadic:$res); +} + #endif // PYTHON_TEST_OPS diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp index 6a7aa9e3432d5..21f712e85e6c0 100644 --- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp +++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp @@ -36,7 +36,6 @@ from ._ods_common import _cext as _ods_cext from ._ods_common import ( equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, - get_op_result_or_op_results as _get_op_result_or_op_results, get_op_results_or_values as _get_op_results_or_values, segmented_accessor as _ods_segmented_accessor, ) @@ -276,8 +275,9 @@ def {0}({2}) -> {4}: )Py"; constexpr const char *valueBuilderVariadicTemplate = R"Py( -def {0}({2}) -> {4}: - return _get_op_result_or_op_results({1}({3})) +def {0}({2}) -> _Union[_ods_ir.OpResult, _ods_ir.OpResultList, {1}]: + op = {1}({3}); results = op.results + return results if len(results) > 1 else (results[0] if len(results) == 1 else op) )Py"; static llvm::cl::OptionCategory @@ -1013,21 +1013,18 @@ static void emitValueBuilder(const Operator &op, nameWithoutDialect += "_"; std::string params = llvm::join(valueBuilderParams, ", "); std::string args = llvm::join(opBuilderArgs, ", "); - const char *type = - (op.getNumResults() > 1 - ? "_Sequence[_ods_ir.Value]" - : (op.getNumResults() > 0 ? "_ods_ir.Value" : "_ods_ir.Operation")); - if (op.getNumVariableLengthResults() > 0) { + if (op.getNumVariableLengthResults()) { os << formatv(valueBuilderVariadicTemplate, nameWithoutDialect, - op.getCppClassName(), params, args, type); + op.getCppClassName(), params, args); } else { - const char *results; - if (op.getNumResults() == 0) { - results = ""; + std::string type = op.getCppClassName().str(); + const char *results = ""; + if (op.getNumResults() > 1) { + type = "_ods_ir.OpResultList"; + results = ".results"; } else if (op.getNumResults() == 1) { + type = "_ods_ir.OpResult"; results = ".result"; - } else { - results = ".results"; } os << formatv(valueBuilderTemplate, nameWithoutDialect, op.getCppClassName(), params, args, type, results);