Skip to content

Commit

Permalink
[mlir][tablegen] Generate default attr values in Python bindings
Browse files Browse the repository at this point in the history
When specifying an op attribute with a default value (via DefaultValuedAttr), the default value is a string of C++ code. In the general case, the default value of such an attribute cannot be translated to Python when generating the bindings. However, we can hard-code default Python values for frequently-used C++ default values.

This change adds a Python default value for empty ArrayAttrs.

Differential Revision: https://reviews.llvm.org/D127750
  • Loading branch information
matthias-springer committed Jun 15, 2022
1 parent 9fc0aa4 commit 989d2b5
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 1 deletion.
22 changes: 21 additions & 1 deletion mlir/test/mlir-tblgen/op-python-bindings.td
Expand Up @@ -179,6 +179,27 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
let arguments = (ins I32, UnitAttr:$in, F32, OptionalAttr<F32Attr>:$is);
}

// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class DefaultValuedAttrsOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.default_valued_attrs"
def DefaultValuedAttrsOp : TestOp<"default_valued_attrs"> {
// CHECK: def __init__(self, *, arr=None, unsupported=None, loc=None, ip=None):
// CHECK: operands = []
// CHECK: results = []
// CHECK: attributes = {}
// CHECK: regions = None
// CHECK: attributes["arr"] = arr if arr is not None else _ods_ir.ArrayAttr.get([])
// CHECK: unsupported is not None, "attribute unsupported must be specified"
// CHECK: _ods_successors = None
// CHECK: super().__init__(self.build_generic(
// CHECK: attributes=attributes, results=results, operands=operands,
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip))

let arguments = (ins DefaultValuedAttr<I64ArrayAttr, "{}">:$arr,
DefaultValuedAttr<I64ArrayAttr, "dummy_func()">:$unsupported);
let results = (outs);
}

// CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_op"
def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResultType]> {
// CHECK: def __init__(self, type, *, loc=None, ip=None):
Expand Down Expand Up @@ -544,4 +565,3 @@ def WithSuccessorsOp : TestOp<"with_successors"> {
let successors = (successor AnySuccessor:$successor,
VariadicSuccessor<AnySuccessor>:$successors);
}

50 changes: 50 additions & 0 deletions mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
Expand Up @@ -11,6 +11,7 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Support/LogicalResult.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/ADT/StringSet.h"
Expand Down Expand Up @@ -542,6 +543,21 @@ constexpr const char *initAttributeTemplate = R"Py(attributes["{0}"] = {1})Py";
constexpr const char *initOptionalAttributeTemplate =
R"Py(if {1} is not None: attributes["{0}"] = {1})Py";

/// Template for setting an attribute with a default value in the operation
/// builder.
/// {0} is the attribute name;
/// {1} is the builder argument name;
/// {2} is the default value.
constexpr const char *initDefaultValuedAttributeTemplate =
R"Py(attributes["{0}"] = {1} if {1} is not None else {2})Py";

/// Template for asserting that an attribute value was provided when calling a
/// builder.
/// {0} is the attribute name;
/// {1} is the builder argument name.
constexpr const char *assertAttributeValueSpecified =
R"Py(assert {1} is not None, "attribute {0} must be specified")Py";

constexpr const char *initUnitAttributeTemplate =
R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get(
_ods_get_default_loc_context(loc)))Py";
Expand Down Expand Up @@ -647,6 +663,21 @@ static void populateBuilderArgsSuccessors(
}
}

/// Generates Python code for the default value of the given attribute.
static FailureOr<std::string> getAttributeDefaultValue(Attribute attr) {
assert(attr.hasDefaultValue() && "expected attribute with default value");
StringRef storageType = attr.getStorageType().trim();
StringRef defaultValCpp = attr.getDefaultValue().trim();

// A list of commonly used attribute types and default values for which
// we can generate Python code. Extend as needed.
if (storageType.equals("::mlir::ArrayAttr") && defaultValCpp.equals("{}"))
return std::string("_ods_ir.ArrayAttr.get([])");

// No match: Cannot generate Python code.
return failure();
}

/// Populates `builderLines` with additional lines that are required in the
/// builder to set up operation attributes. `argNames` is expected to contain
/// the names of builder arguments that correspond to op arguments, i.e. to the
Expand All @@ -669,6 +700,25 @@ populateBuilderLinesAttr(const Operator &op,
continue;
}

// Attributes with default value are handled specially.
if (attribute->attr.hasDefaultValue()) {
// In case we cannot generate Python code for the default value, the
// attribute must be specified by the user.
FailureOr<std::string> defaultValPy =
getAttributeDefaultValue(attribute->attr);
if (succeeded(defaultValPy)) {
builderLines.push_back(llvm::formatv(initDefaultValuedAttributeTemplate,
attribute->name, argNames[i],
*defaultValPy));
} else {
builderLines.push_back(llvm::formatv(assertAttributeValueSpecified,
attribute->name, argNames[i]));
builderLines.push_back(
llvm::formatv(initAttributeTemplate, attribute->name, argNames[i]));
}
continue;
}

builderLines.push_back(llvm::formatv(attribute->attr.isOptional()
? initOptionalAttributeTemplate
: initAttributeTemplate,
Expand Down

0 comments on commit 989d2b5

Please sign in to comment.