diff --git a/mlir/python/mlir/dialects/_func_ops_ext.py b/mlir/python/mlir/dialects/_func_ops_ext.py index 6fe3ff5302e261..79577463d9199b 100644 --- a/mlir/python/mlir/dialects/_func_ops_ext.py +++ b/mlir/python/mlir/dialects/_func_ops_ext.py @@ -58,7 +58,7 @@ def __init__(self, type = TypeAttr.get(type) sym_visibility = StringAttr.get( str(visibility)) if visibility is not None else None - super().__init__(sym_name, type, sym_visibility, loc=loc, ip=ip) + super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip) if body_builder: entry_block = self.add_entry_block() with InsertionPoint(entry_block): diff --git a/mlir/python/mlir/dialects/_ml_program_ops_ext.py b/mlir/python/mlir/dialects/_ml_program_ops_ext.py index a3df7ff0336072..8db82cf81c6788 100644 --- a/mlir/python/mlir/dialects/_ml_program_ops_ext.py +++ b/mlir/python/mlir/dialects/_ml_program_ops_ext.py @@ -48,7 +48,7 @@ def __init__(self, type = TypeAttr.get(type) sym_visibility = StringAttr.get( str(visibility)) if visibility is not None else None - super().__init__(sym_name, type, sym_visibility, loc=loc, ip=ip) + super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip) if body_builder: entry_block = self.add_entry_block() with InsertionPoint(entry_block): diff --git a/mlir/python/mlir/dialects/_pdl_ops_ext.py b/mlir/python/mlir/dialects/_pdl_ops_ext.py index fb5b519c7c0229..bb63fe64dd035e 100644 --- a/mlir/python/mlir/dialects/_pdl_ops_ext.py +++ b/mlir/python/mlir/dialects/_pdl_ops_ext.py @@ -93,7 +93,7 @@ def __init__(self, ip=None): type = type if type is None else _get_value(type) result = pdl.AttributeType.get() - super().__init__(result, type, value, loc=loc, ip=ip) + super().__init__(result, type=type, value=value, loc=loc, ip=ip) class EraseOp: @@ -118,7 +118,7 @@ def __init__(self, ip=None): type = type if type is None else _get_value(type) result = pdl.ValueType.get() - super().__init__(result, type, loc=loc, ip=ip) + super().__init__(result, type=type, loc=loc, ip=ip) class OperandsOp: @@ -131,7 +131,7 @@ def __init__(self, ip=None): types = types if types is None else _get_value(types) result = pdl.RangeType.get(pdl.ValueType.get()) - super().__init__(result, types, loc=loc, ip=ip) + super().__init__(result, type=types, loc=loc, ip=ip) class OperationOp: @@ -155,7 +155,7 @@ def __init__(self, attributeNames = ArrayAttr.get(attributeNames) types = _get_values(types) result = pdl.OperationType.get() - super().__init__(result, name, args, attributeValues, attributeNames, types, loc=loc, ip=ip) + super().__init__(result, args, attributeValues, attributeNames, types, name=name, loc=loc, ip=ip) class PatternOp: @@ -170,7 +170,7 @@ def __init__(self, """Creates an PDL `pattern` operation.""" name_attr = None if name is None else _get_str_attr(name) benefit_attr = _get_int_attr(16, benefit) - super().__init__(benefit_attr, name_attr, loc=loc, ip=ip) + super().__init__(benefit_attr, sym_name=name_attr, loc=loc, ip=ip) self.regions[0].blocks.append() @property @@ -192,7 +192,7 @@ def __init__(self, op = _get_value(op) with_op = with_op if with_op is None else _get_value(with_op) with_values = _get_values(with_values) - super().__init__(op, with_op, with_values, loc=loc, ip=ip) + super().__init__(op, with_values, replOperation=with_op, loc=loc, ip=ip) class ResultOp: @@ -222,7 +222,7 @@ def __init__(self, ip=None): parent = _get_value(parent) index = index if index is None else _get_int_attr(32, index) - super().__init__(result, parent, index, loc=loc, ip=ip) + super().__init__(result, parent, index=index, loc=loc, ip=ip) class RewriteOp: @@ -238,7 +238,7 @@ def __init__(self, root = root if root is None else _get_value(root) name = name if name is None else _get_str_attr(name) args = _get_values(args) - super().__init__(root, name, args, loc=loc, ip=ip) + super().__init__(args, root=root,name=name, loc=loc, ip=ip) def add_body(self): """Add body (block) to the rewrite.""" @@ -261,7 +261,7 @@ def __init__(self, ip=None): type = type if type is None else _get_type_attr(type) result = pdl.TypeType.get() - super().__init__(result, type, loc=loc, ip=ip) + super().__init__(result, type=type, loc=loc, ip=ip) class TypesOp: @@ -275,4 +275,4 @@ def __init__(self, types = _get_array_attr([_get_type_attr(ty) for ty in types]) types = None if not types else types result = pdl.RangeType.get(pdl.TypeType.get()) - super().__init__(result, types, loc=loc, ip=ip) + super().__init__(result, types=types, loc=loc, ip=ip) diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td index 59b5dec83c0308..f744ce501b1064 100644 --- a/mlir/test/mlir-tblgen/op-python-bindings.td +++ b/mlir/test/mlir-tblgen/op-python-bindings.td @@ -21,7 +21,7 @@ class TestOp traits = []> : // CHECK: _ODS_OPERAND_SEGMENTS = [-1,1,0,] def AttrSizedOperandsOp : TestOp<"attr_sized_operands", [AttrSizedOperandSegments]> { - // CHECK: def __init__(self, variadic1, non_variadic, variadic2, *, loc=None, ip=None): + // CHECK: def __init__(self, variadic1, non_variadic, *, variadic2=None, loc=None, ip=None): // CHECK: operands = [] // CHECK: results = [] // CHECK: attributes = {} @@ -110,7 +110,7 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results", // CHECK-NOT: _ODS_OPERAND_SEGMENTS // CHECK-NOT: _ODS_RESULT_SEGMENTS def AttributedOp : TestOp<"attributed_op"> { - // CHECK: def __init__(self, i32attr, optionalF32Attr, unitAttr, in_, *, loc=None, ip=None): + // CHECK: def __init__(self, i32attr, in_, *, optionalF32Attr=None, unitAttr=None, loc=None, ip=None): // CHECK: operands = [] // CHECK: results = [] // CHECK: attributes = {} @@ -152,7 +152,7 @@ def AttributedOp : TestOp<"attributed_op"> { // CHECK-NOT: _ODS_OPERAND_SEGMENTS // CHECK-NOT: _ODS_RESULT_SEGMENTS def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> { - // CHECK: def __init__(self, _gen_arg_0, in_, _gen_arg_2, is_, *, loc=None, ip=None): + // CHECK: def __init__(self, _gen_arg_0, _gen_arg_2, *, in_=None, is_=None, loc=None, ip=None): // CHECK: operands = [] // CHECK: results = [] // CHECK: attributes = {} @@ -286,7 +286,7 @@ def MissingNamesOp : TestOp<"missing_names"> { // CHECK-NOT: _ODS_RESULT_SEGMENTS def OneOptionalOperandOp : TestOp<"one_optional_operand"> { let arguments = (ins AnyType:$non_optional, Optional:$optional); - // CHECK: def __init__(self, non_optional, optional, *, loc=None, ip=None): + // CHECK: def __init__(self, non_optional, *, optional=None, loc=None, ip=None): // CHECK: operands = [] // CHECK: results = [] // CHECK: attributes = {} diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py index e7b1f44a3ad8e9..c73fce23d3c49c 100644 --- a/mlir/test/python/dialects/python_test.py +++ b/mlir/test/python/dialects/python_test.py @@ -28,13 +28,13 @@ def testAttributes(): # CHECK-DAG: optional_i32 = 2 : i32 # CHECK-DAG: unit # CHECK: } - op = test.AttributedOp(one, two, unit) + op = test.AttributedOp(one, optional_i32=two, unit=unit) print(f"{op}") # CHECK: "python_test.attributed_op"() { # CHECK: mandatory_i32 = 2 : i32 # CHECK: } - op2 = test.AttributedOp(two, None, None) + op2 = test.AttributedOp(two) print(f"{op2}") # @@ -218,11 +218,11 @@ def testOptionalOperandOp(): module = Module.create() with InsertionPoint(module.body): - op1 = test.OptionalOperandOp(None) + op1 = test.OptionalOperandOp() # CHECK: op1.input is None: True print(f"op1.input is None: {op1.input is None}") - op2 = test.OptionalOperandOp(op1) + op2 = test.OptionalOperandOp(input=op1) # CHECK: op2.input is None: False print(f"op2.input is None: {op2.input is None}") diff --git a/mlir/test/python/dialects/vector.py b/mlir/test/python/dialects/vector.py index c31579545e6e7f..8f8d7f19191cfa 100644 --- a/mlir/test/python/dialects/vector.py +++ b/mlir/test/python/dialects/vector.py @@ -46,9 +46,9 @@ def testTransferReadOp(): with InsertionPoint(f.add_entry_block()): A, zero, padding, mask = f.arguments vector.TransferReadOp(vector_type, A, [zero, zero], identity_map_attr, - padding, mask, None) + padding, mask=mask) vector.TransferReadOp(vector_type, A, [zero, zero], identity_map_attr, - padding, None, None) + padding) func.ReturnOp([]) # CHECK: @transfer_read(%[[MEM:.*]]: memref, %[[IDX:.*]]: index, diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp index 16fccff973ca76..83d2acce3ba2c9 100644 --- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp +++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp @@ -620,6 +620,13 @@ populateBuilderArgs(const Operator &op, if (!op.getArg(i).is()) operandNames.push_back(name); } +} + +/// Populates `builderArgs` with the Python-compatible names of builder function +/// successor arguments. Additionally, `successorArgNames` is also populated. +static void populateBuilderArgsSuccessors( + const Operator &op, llvm::SmallVectorImpl &builderArgs, + llvm::SmallVectorImpl &successorArgNames) { for (int i = 0, e = op.getNumSuccessors(); i < e; ++i) { NamedSuccessor successor = op.getSuccessor(i); @@ -857,6 +864,8 @@ static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) { populateBuilderArgsResults(op, builderArgs); size_t numResultArgs = builderArgs.size(); populateBuilderArgs(op, builderArgs, operandArgNames, successorArgNames); + size_t numOperandAttrArgs = builderArgs.size() - numResultArgs; + populateBuilderArgsSuccessors(op, builderArgs, successorArgNames); populateBuilderLinesOperand(op, operandArgNames, builderLines); populateBuilderLinesAttr( @@ -868,10 +877,53 @@ static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) { populateBuilderLinesSuccessors(op, successorArgNames, builderLines); populateBuilderRegions(op, builderArgs, builderLines); - builderArgs.push_back("*"); - builderArgs.push_back("loc=None"); - builderArgs.push_back("ip=None"); - os << llvm::formatv(initTemplate, llvm::join(builderArgs, ", "), + // Layout of builderArgs vector elements: + // [ result_args operand_attr_args successor_args regions ] + + // Determine whether the argument corresponding to a given index into the + // builderArgs vector is a python keyword argument or not. + auto isKeywordArgFn = [&](size_t builderArgIndex) -> bool { + // All result, successor, and region arguments are positional arguments. + if ((builderArgIndex < numResultArgs) || + (builderArgIndex >= (numResultArgs + numOperandAttrArgs))) + return false; + // Keyword arguments: + // - optional named attributes (including unit attributes) + // - default-valued named attributes + // - optional operands + Argument a = op.getArg(builderArgIndex - numResultArgs); + if (auto *nattr = a.dyn_cast()) + return (nattr->attr.isOptional() || nattr->attr.hasDefaultValue()); + else if (auto *ntype = a.dyn_cast()) + return ntype->isOptional(); + else + return false; + }; + + // StringRefs in functionArgs refer to strings allocated by builderArgs. + llvm::SmallVector functionArgs; + + // Add positional arguments. + for (size_t i = 0, cnt = builderArgs.size(); i < cnt; ++i) { + if (!isKeywordArgFn(i)) + functionArgs.push_back(builderArgs[i]); + } + + // Add a bare '*' to indicate that all following arguments must be keyword + // arguments. + functionArgs.push_back("*"); + + // Add a default 'None' value to each keyword arg string, and then add to the + // function args list. + for (size_t i = 0, cnt = builderArgs.size(); i < cnt; ++i) { + if (isKeywordArgFn(i)) { + builderArgs[i].append("=None"); + functionArgs.push_back(builderArgs[i]); + } + } + functionArgs.push_back("loc=None"); + functionArgs.push_back("ip=None"); + os << llvm::formatv(initTemplate, llvm::join(functionArgs, ", "), llvm::join(builderLines, "\n ")); }