diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td index ba85cb8406b31..632046389e12c 100644 --- a/mlir/test/mlir-tblgen/op-python-bindings.td +++ b/mlir/test/mlir-tblgen/op-python-bindings.td @@ -60,7 +60,7 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands", } // CHECK: def attr_sized_operands(variadic1, non_variadic, *, variadic2=None, loc=None, ip=None) -// CHECK: return _get_op_result_or_op_results(AttrSizedOperandsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)) +// CHECK: return AttrSizedOperandsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip) // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class AttrSizedResultsOp(_ods_ir.OpView): @@ -157,7 +157,7 @@ def AttributedOp : TestOp<"attributed_op"> { } // CHECK: def attributed_op(i32attr, in_, *, optional_f32_attr=None, unit_attr=None, loc=None, ip=None) -// CHECK: return _get_op_result_or_op_results(AttributedOp(i32attr=i32attr, in_=in_, optionalF32Attr=optional_f32_attr, unitAttr=unit_attr, loc=loc, ip=ip)) +// CHECK: return AttributedOp(i32attr=i32attr, in_=in_, optionalF32Attr=optional_f32_attr, unitAttr=unit_attr, loc=loc, ip=ip) // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class AttributedOpWithOperands(_ods_ir.OpView): @@ -193,7 +193,7 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> { } // CHECK: def attributed_op_with_operands(_gen_arg_0, _gen_arg_2, *, in_=None, is_=None, loc=None, ip=None) -// CHECK: return _get_op_result_or_op_results(AttributedOpWithOperands(_gen_arg_0=_gen_arg_0, _gen_arg_2=_gen_arg_2, in_=in_, is_=is_, loc=loc, ip=ip)) +// CHECK: return AttributedOpWithOperands(_gen_arg_0=_gen_arg_0, _gen_arg_2=_gen_arg_2, in_=in_, is_=is_, loc=loc, ip=ip) // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class DefaultValuedAttrsOp(_ods_ir.OpView): @@ -217,7 +217,7 @@ def DefaultValuedAttrsOp : TestOp<"default_valued_attrs"> { } // CHECK: def default_valued_attrs(*, arr=None, unsupported=None, loc=None, ip=None) -// CHECK: return _get_op_result_or_op_results(DefaultValuedAttrsOp(arr=arr, unsupported=unsupported, loc=loc, ip=ip)) +// CHECK: return DefaultValuedAttrsOp(arr=arr, unsupported=unsupported, loc=loc, ip=ip) // CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_op" def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResultType]> { @@ -235,7 +235,7 @@ def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResu } // CHECK: def derive_result_types_op(type_, *, loc=None, ip=None) -// CHECK: return _get_op_result_or_op_results(DeriveResultTypesOp(type_=type_, loc=loc, ip=ip)) +// CHECK: return DeriveResultTypesOp(type_=type_, loc=loc, ip=ip).results // CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_variadic_op" def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [FirstAttrDerivedResultType]> { @@ -262,7 +262,7 @@ def EmptyOp : TestOp<"empty">; // CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)) // CHECK: def empty(*, loc=None, ip=None) -// CHECK: return _get_op_result_or_op_results(EmptyOp(loc=loc, ip=ip)) +// CHECK: return EmptyOp(loc=loc, ip=ip) // CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_implied_op" def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> { @@ -275,7 +275,7 @@ def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> { } // CHECK: def infer_result_types_implied_op(*, loc=None, ip=None) -// CHECK: return _get_op_result_or_op_results(InferResultTypesImpliedOp(loc=loc, ip=ip)) +// CHECK: return InferResultTypesImpliedOp(loc=loc, ip=ip).results // CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_op" def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]> { @@ -288,7 +288,7 @@ def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]> } // CHECK: def infer_result_types_op(*, loc=None, ip=None) -// CHECK: return _get_op_result_or_op_results(InferResultTypesOp(loc=loc, ip=ip)) +// CHECK: return InferResultTypesOp(loc=loc, ip=ip).results // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class MissingNamesOp(_ods_ir.OpView): @@ -326,7 +326,7 @@ def MissingNamesOp : TestOp<"missing_names"> { } // CHECK: def missing_names(i32, _gen_res_1, i64, _gen_arg_0, f32, _gen_arg_2, *, loc=None, ip=None) -// CHECK: return _get_op_result_or_op_results(MissingNamesOp(i32=i32, _gen_res_1=_gen_res_1, i64=i64, _gen_arg_0=_gen_arg_0, f32=f32, _gen_arg_2=_gen_arg_2, loc=loc, ip=ip)) +// CHECK: return MissingNamesOp(i32=i32, _gen_res_1=_gen_res_1, i64=i64, _gen_arg_0=_gen_arg_0, f32=f32, _gen_arg_2=_gen_arg_2, loc=loc, ip=ip).results // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class OneOptionalOperandOp(_ods_ir.OpView): @@ -357,7 +357,7 @@ def OneOptionalOperandOp : TestOp<"one_optional_operand"> { } // CHECK: def one_optional_operand(non_optional, *, optional=None, loc=None, ip=None) -// CHECK: return _get_op_result_or_op_results(OneOptionalOperandOp(non_optional=non_optional, optional=optional, loc=loc, ip=ip)) +// CHECK: return OneOptionalOperandOp(non_optional=non_optional, optional=optional, loc=loc, ip=ip) // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class OneVariadicOperandOp(_ods_ir.OpView): @@ -389,7 +389,7 @@ def OneVariadicOperandOp : TestOp<"one_variadic_operand"> { } // CHECK: def one_variadic_operand(non_variadic, variadic, *, loc=None, ip=None) -// CHECK: return _get_op_result_or_op_results(OneVariadicOperandOp(non_variadic=non_variadic, variadic=variadic, loc=loc, ip=ip)) +// CHECK: return OneVariadicOperandOp(non_variadic=non_variadic, variadic=variadic, loc=loc, ip=ip) // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class OneVariadicResultOp(_ods_ir.OpView): @@ -446,7 +446,7 @@ def PythonKeywordOp : TestOp<"python_keyword"> { } // CHECK: def python_keyword(in_, *, loc=None, ip=None) -// CHECK: return _get_op_result_or_op_results(PythonKeywordOp(in_=in_, loc=loc, ip=ip)) +// CHECK: return PythonKeywordOp(in_=in_, loc=loc, ip=ip) // CHECK-LABEL: OPERATION_NAME = "test.same_results" def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> { @@ -460,7 +460,7 @@ def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> { } // CHECK: def same_results(in1, in2, *, loc=None, ip=None) -// CHECK: return _get_op_result_or_op_results(SameResultsOp(in1=in1, in2=in2, loc=loc, ip=ip)) +// CHECK: return SameResultsOp(in1=in1, in2=in2, loc=loc, ip=ip) // CHECK-LABEL: OPERATION_NAME = "test.same_results_variadic" def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResultType]> { @@ -497,7 +497,7 @@ def SameVariadicOperandSizeOp : TestOp<"same_variadic_operand", } // CHECK: def same_variadic_operand(variadic1, non_variadic, variadic2, *, loc=None, ip=None) -// CHECK: return _get_op_result_or_op_results(SameVariadicOperandSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)) +// CHECK: return SameVariadicOperandSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip) // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class SameVariadicResultSizeOp(_ods_ir.OpView): @@ -563,7 +563,7 @@ def SimpleOp : TestOp<"simple"> { } // CHECK: def simple(i64, f64, i32, f32, *, loc=None, ip=None) -// CHECK: return _get_op_result_or_op_results(SimpleOp(i64=i64, f64=f64, i32=i32, f32=f32, loc=loc, ip=ip)) +// CHECK: return SimpleOp(i64=i64, f64=f64, i32=i32, f32=f32, loc=loc, ip=ip).results // CHECK: class VariadicAndNormalRegionOp(_ods_ir.OpView): // CHECK-LABEL: OPERATION_NAME = "test.variadic_and_normal_region" @@ -590,7 +590,7 @@ def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> { } // CHECK: def variadic_and_normal_region(num_variadic, *, loc=None, ip=None) -// CHECK: return _get_op_result_or_op_results(VariadicAndNormalRegionOp(num_variadic=num_variadic, loc=loc, ip=ip)) +// CHECK: return VariadicAndNormalRegionOp(num_variadic=num_variadic, loc=loc, ip=ip) // CHECK: class VariadicRegionOp(_ods_ir.OpView): // CHECK-LABEL: OPERATION_NAME = "test.variadic_region" @@ -613,7 +613,7 @@ def VariadicRegionOp : TestOp<"variadic_region"> { } // CHECK: def variadic_region(num_variadic, *, loc=None, ip=None) -// CHECK: return _get_op_result_or_op_results(VariadicRegionOp(num_variadic=num_variadic, loc=loc, ip=ip)) +// CHECK: return VariadicRegionOp(num_variadic=num_variadic, loc=loc, ip=ip) // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class WithSpecialCharactersOp(_ods_ir.OpView): @@ -622,7 +622,7 @@ def WithSpecialCharactersOp : TestOp<"123with--special.characters"> { } // CHECK: def _123with__special_characters(*, loc=None, ip=None) -// CHECK: return _get_op_result_or_op_results(WithSpecialCharactersOp(loc=loc, ip=ip)) +// CHECK: return WithSpecialCharactersOp(loc=loc, ip=ip) // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class WithSuccessorsOp(_ods_ir.OpView): @@ -637,4 +637,4 @@ def WithSuccessorsOp : TestOp<"with_successors"> { } // CHECK: def with_successors(successor, successors, *, loc=None, ip=None) -// CHECK: return _get_op_result_or_op_results(WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip)) +// CHECK: return WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip) diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp index 0c5c936f5adde..5019b69d91127 100644 --- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp +++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp @@ -271,6 +271,11 @@ constexpr const char *regionAccessorTemplate = R"Py( )Py"; constexpr const char *valueBuilderTemplate = R"Py( +def {0}({2}) -> {4}: + return {1}({3}){5} +)Py"; + +constexpr const char *valueBuilderVariadicTemplate = R"Py( def {0}({2}) -> {4}: return _get_op_result_or_op_results({1}({3})) )Py"; @@ -992,15 +997,29 @@ static void emitValueBuilder(const Operator &op, auto lhs = *llvm::split(arg, "=").begin(); return (lhs + "=" + llvm::convertToSnakeFromCamelCase(lhs)).str(); }); - std::string nameWithoutDialect = - op.getOperationName().substr(op.getOperationName().find('.') + 1); - os << formatv( - valueBuilderTemplate, sanitizeName(nameWithoutDialect), - op.getCppClassName(), llvm::join(valueBuilderParams, ", "), - llvm::join(opBuilderArgs, ", "), + std::string nameWithoutDialect = sanitizeName( + op.getOperationName().substr(op.getOperationName().find('.') + 1)); + std::string params = llvm::join(valueBuilderParams, ", "); + std::string args = llvm::join(opBuilderArgs, ", "); + const char *type = (op.getNumResults() > 1 ? "_Sequence[_ods_ir.Value]" - : (op.getNumResults() > 0 ? "_ods_ir.Value" : "_ods_ir.Operation"))); + : (op.getNumResults() > 0 ? "_ods_ir.Value" : "_ods_ir.Operation")); + if (op.getNumVariableLengthResults() > 0) { + os << formatv(valueBuilderVariadicTemplate, nameWithoutDialect, + op.getCppClassName(), params, args, type); + } else { + const char *results; + if (op.getNumResults() == 0) { + results = ""; + } else if (op.getNumResults() == 1) { + results = ".result"; + } else { + results = ".results"; + } + os << formatv(valueBuilderTemplate, nameWithoutDialect, + op.getCppClassName(), params, args, type, results); + } } /// Emits bindings for a specific Op to the given output stream.