Skip to content

Commit

Permalink
Revert "Revert "[mlir][py] Enable building ops with raw inputs""
Browse files Browse the repository at this point in the history
Fix Python 3.6.9 issue encountered due to type checking here. Will
add back in follow up.

This reverts commit 1f47fee.
  • Loading branch information
jpienaar committed Dec 22, 2022
1 parent 02f4cfa commit b57acb9
Show file tree
Hide file tree
Showing 8 changed files with 197 additions and 36 deletions.
28 changes: 28 additions & 0 deletions mlir/docs/Bindings/Python.md
Expand Up @@ -743,6 +743,34 @@ with Context():
dictionary = DictAttr.get({"array": array, "unit": UnitAttr.get()})
```

Custom builders for Attributes to be used during Operation creation can be
registered by way of the `register_attribute_builder`. In particular the
following is how a custom builder is registered for `I32Attr`:

```python
@register_attribute_builder("I32Attr")
def _i32Attr(x: int, context: Context):
return IntegerAttr.get(
IntegerType.get_signless(32, context=context), x)
```

This allows to invoke op creation of an op with a `I32Attr` with

```python
foo.Op(30)
```

The registration is based on the ODS name but registry is via pure python
method. Only single custom builder is allowed to be registered per ODS attribute
type (e.g., I32Attr can have only one, which can correspond to multiple of the
underlying IntegerAttr type).

instead of

```python
foo.Op(IntegerAttr.get(IndexType.get_signless(32, context=context), 30))
```

## Style

In general, for the core parts of MLIR, the Python bindings should be largely
Expand Down
12 changes: 12 additions & 0 deletions mlir/lib/Bindings/Python/Globals.h
Expand Up @@ -58,6 +58,12 @@ class PyGlobals {
/// have a DIALECT_NAMESPACE attribute.
pybind11::object registerDialectDecorator(pybind11::object pyClass);

/// Adds a user-friendly Attribute builder.
/// Raises an exception if the mapping already exists.
/// This is intended to be called by implementation code.
void registerAttributeBuilder(const std::string &attributeKind,
pybind11::function pyFunc);

/// Adds a concrete implementation dialect class.
/// Raises an exception if the mapping already exists.
/// This is intended to be called by implementation code.
Expand All @@ -71,6 +77,10 @@ class PyGlobals {
pybind11::object pyClass,
pybind11::object rawOpViewClass);

/// Returns the custom Attribute builder for Attribute kind.
std::optional<pybind11::function>
lookupAttributeBuilder(const std::string &attributeKind);

/// Looks up a registered dialect class by namespace. Note that this may
/// trigger loading of the defining module and can arbitrarily re-enter.
llvm::Optional<pybind11::object>
Expand All @@ -92,6 +102,8 @@ class PyGlobals {
/// Map of operation name to custom subclass that directly initializes
/// the OpView base class (bypassing the user class constructor).
llvm::StringMap<pybind11::object> rawOpViewClassMap;
/// Map of attribute ODS name to custom builder.
llvm::StringMap<pybind11::function> attributeBuilderMap;

/// Set of dialect namespaces that we have attempted to import implementation
/// modules for.
Expand Down
26 changes: 26 additions & 0 deletions mlir/lib/Bindings/Python/IRCore.cpp
Expand Up @@ -194,6 +194,29 @@ struct PyGlobalDebugFlag {
}
};

struct PyAttrBuilderMap {
static bool dunderContains(const std::string &attributeKind) {
return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value();
}
static py::function dundeGetItemNamed(const std::string &attributeKind) {
auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind);
if (!builder)
throw py::key_error();
return *builder;
}
static void dundeSetItemNamed(const std::string &attributeKind,
py::function func) {
PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func));
}

static void bind(py::module &m) {
py::class_<PyAttrBuilderMap>(m, "AttrBuilder", py::module_local())
.def_static("contains", &PyAttrBuilderMap::dunderContains)
.def_static("get", &PyAttrBuilderMap::dundeGetItemNamed)
.def_static("insert", &PyAttrBuilderMap::dundeSetItemNamed);
}
};

//------------------------------------------------------------------------------
// Collections.
//------------------------------------------------------------------------------
Expand Down Expand Up @@ -3283,4 +3306,7 @@ void mlir::python::populateIRCore(py::module &m) {

// Debug bindings.
PyGlobalDebugFlag::bind(m);

// Attribute builder getter.
PyAttrBuilderMap::bind(m);
}
27 changes: 27 additions & 0 deletions mlir/lib/Bindings/Python/IRModule.cpp
Expand Up @@ -60,6 +60,17 @@ void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
loadedDialectModulesCache.insert(dialectNamespace);
}

void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,
py::function pyFunc) {
py::function &found = attributeBuilderMap[attributeKind];
if (found) {
throw std::runtime_error((llvm::Twine("Attribute builder for '") +
attributeKind + "' is already registered")
.str());
}
found = std::move(pyFunc);
}

void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
py::object pyClass) {
py::object &found = dialectClassMap[dialectNamespace];
Expand All @@ -84,6 +95,22 @@ void PyGlobals::registerOperationImpl(const std::string &operationName,
rawOpViewClassMap[operationName] = std::move(rawOpViewClass);
}

std::optional<py::function>
PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) {
// Fast match against the class map first (common case).
const auto foundIt = attributeBuilderMap.find(attributeKind);
if (foundIt != attributeBuilderMap.end()) {
if (foundIt->second.is_none())
return std::nullopt;
assert(foundIt->second && "py::function is defined");
return foundIt->second;
}

// Not found and loading did not yield a registration. Negative cache.
attributeBuilderMap[attributeKind] = py::none();
return std::nullopt;
}

llvm::Optional<py::object>
PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
loadDialectModule(dialectNamespace);
Expand Down
41 changes: 41 additions & 0 deletions mlir/python/mlir/ir.py
Expand Up @@ -4,3 +4,44 @@

from ._mlir_libs._mlir.ir import *
from ._mlir_libs._mlir.ir import _GlobalDebug


# Convenience decorator for registering user-friendly Attribute builders.
def register_attribute_builder(kind):
def decorator_builder(func):
AttrBuilder.insert(kind, func)
return func
return decorator_builder


@register_attribute_builder("BoolAttr")
def _boolAttr(x, context):
return BoolAttr.get(x, context=context)

@register_attribute_builder("IndexAttr")
def _indexAttr(x, context):
return IntegerAttr.get(IndexType.get(context=context), x)

@register_attribute_builder("I32Attr")
def _i32Attr(x, context):
return IntegerAttr.get(
IntegerType.get_signless(32, context=context), x)

@register_attribute_builder("I64Attr")
def _i64Attr(x, context):
return IntegerAttr.get(
IntegerType.get_signless(64, context=context), x)

@register_attribute_builder("SymbolNameAttr")
def _symbolNameAttr(x, context):
return StringAttr.get(x, context=context)

try:
import numpy as np
@register_attribute_builder("IndexElementsAttr")
def _indexElementsAttr(x, context):
return DenseElementsAttr.get(
np.array(x, dtype=np.int64), type=IndexType.get(context=context),
context=context)
except ImportError:
pass
19 changes: 11 additions & 8 deletions mlir/test/mlir-tblgen/op-python-bindings.td
Expand Up @@ -115,11 +115,14 @@ def AttributedOp : TestOp<"attributed_op"> {
// CHECK: results = []
// CHECK: attributes = {}
// CHECK: regions = None
// CHECK: attributes["i32attr"] = i32attr
// CHECK: if optionalF32Attr is not None: attributes["optionalF32Attr"] = optionalF32Attr
// CHECK: attributes["i32attr"] = (i32attr if (
// CHECK-NEXT: issubclass(type(i32attr), _ods_ir.Attribute) or
// CHECK-NEXT: not _ods_ir.AttrBuilder.contains('I32Attr')
// CHECK-NEXT: _ods_ir.AttrBuilder.get('I32Attr')(i32attr, context=_ods_context)
// CHECK: if optionalF32Attr is not None: attributes["optionalF32Attr"] = (optionalF32Attr
// CHECK: if bool(unitAttr): attributes["unitAttr"] = _ods_ir.UnitAttr.get(
// CHECK: _ods_get_default_loc_context(loc))
// CHECK: attributes["in"] = in_
// CHECK: attributes["in"] = (in_
// CHECK: _ods_successors = None
// CHECK: super().__init__(self.build_generic(
// CHECK: attributes=attributes, results=results, operands=operands,
Expand Down Expand Up @@ -161,7 +164,7 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
// CHECK: operands.append(_get_op_result_or_value(_gen_arg_2))
// CHECK: if bool(in_): attributes["in"] = _ods_ir.UnitAttr.get(
// CHECK: _ods_get_default_loc_context(loc))
// CHECK: if is_ is not None: attributes["is"] = is_
// CHECK: if is_ is not None: attributes["is"] = (is_
// CHECK: _ods_successors = None
// CHECK: super().__init__(self.build_generic(
// CHECK: attributes=attributes, results=results, operands=operands,
Expand All @@ -188,8 +191,8 @@ def DefaultValuedAttrsOp : TestOp<"default_valued_attrs"> {
// CHECK: results = []
// CHECK: attributes = {}
// CHECK: regions = None
// CHECK: if arr is not None: attributes["arr"] = arr
// CHECK: if unsupported is not None: attributes["unsupported"] = unsupported
// CHECK: if arr is not None: attributes["arr"] = (arr
// CHECK: if unsupported is not None: attributes["unsupported"] = (unsupported
// CHECK: _ods_successors = None
// CHECK: super().__init__(self.build_generic(
// CHECK: attributes=attributes, results=results, operands=operands,
Expand All @@ -202,7 +205,7 @@ def DefaultValuedAttrsOp : TestOp<"default_valued_attrs"> {

// CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_op"
def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResultType]> {
// CHECK: def __init__(self, type, *, loc=None, ip=None):
// CHECK: def __init__(self, type_, *, loc=None, ip=None):
// CHECK: operands = []
// CHECK: results = []
// CHECK: _ods_result_type_source_attr = attributes["type"]
Expand All @@ -217,7 +220,7 @@ def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResu

// CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_variadic_op"
def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [FirstAttrDerivedResultType]> {
// CHECK: def __init__(self, res, _gen_res_1, type, *, loc=None, ip=None):
// CHECK: def __init__(self, res, _gen_res_1, type_, *, loc=None, ip=None):
let arguments = (ins TypeAttr:$type);
let results = (outs AnyType:$res, Variadic<AnyType>);
}
Expand Down
13 changes: 11 additions & 2 deletions mlir/test/python/dialects/shape.py
Expand Up @@ -22,9 +22,18 @@ def testConstShape():
@func.FuncOp.from_py_func(
RankedTensorType.get((12, ShapedType.get_dynamic_size()), f32))
def const_shape_tensor(arg):
shape.ConstWitnessOp(False)
shape.ConstSizeOp(30)
shape.ConstSizeOp(IntegerAttr.get(IndexType.get(), 40))
shape.ConstShapeOp([1, 2])
return shape.ConstShapeOp(
DenseElementsAttr.get(np.array([10, 20], dtype=np.int64), type=IndexType.get()))
DenseElementsAttr.get(
np.array([3, 4], dtype=np.int64), type=IndexType.get()))

# CHECK-LABEL: func @const_shape_tensor(%arg0: tensor<12x?xf32>)
# CHECK: shape.const_shape [10, 20] : tensor<2xindex>
# CHECK-DAG: shape.const_witness false
# CHECK-DAG: shape.const_size 30
# CHECK-DAG: shape.const_size 40
# CHECK-DAG: shape.const_shape [1, 2] : tensor<2xindex>
# CHECK-DAG: shape.const_shape [3, 4] : tensor<2xindex>
print(module)
67 changes: 41 additions & 26 deletions mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
Expand Up @@ -280,15 +280,16 @@ static llvm::cl::opt<std::string> clDialectExtensionName(

using AttributeClasses = DenseMap<StringRef, StringRef>;

/// Checks whether `str` is a Python keyword.
static bool isPythonKeyword(StringRef str) {
static llvm::StringSet<> keywords(
{"and", "as", "assert", "break", "class", "continue",
"def", "del", "elif", "else", "except", "finally",
"for", "from", "global", "if", "import", "in",
"is", "lambda", "nonlocal", "not", "or", "pass",
"raise", "return", "try", "while", "with", "yield"});
return keywords.contains(str);
/// Checks whether `str` is a Python keyword or would shadow builtin function.
static bool isPythonReserved(StringRef str) {
static llvm::StringSet<> reserved(
{"and", "as", "assert", "break", "callable", "class",
"continue", "def", "del", "elif", "else", "except",
"finally", "for", "from", "global", "if", "import",
"in", "is", "lambda", "nonlocal", "not", "or",
"pass", "raise", "return", "issubclass", "try", "type",
"while", "with", "yield"});
return reserved.contains(str);
}

/// Checks whether `str` would shadow a generated variable or attribute
Expand All @@ -306,7 +307,7 @@ static bool isODSReserved(StringRef str) {
/// (does not change the `name` if it already is suitable) and returns the
/// modified version.
static std::string sanitizeName(StringRef name) {
if (isPythonKeyword(name) || isODSReserved(name))
if (isPythonReserved(name) || isODSReserved(name))
return (name + "_").str();
return name.str();
}
Expand Down Expand Up @@ -531,16 +532,30 @@ constexpr const char *multiOperandAppendPackTemplate =
"operands.append(_get_op_results_or_values({0}))";
constexpr const char *multiResultAppendTemplate = "results.extend({0})";

/// Template for setting an attribute in the operation builder.
/// {0} is the attribute name;
/// {1} is the builder argument name.
constexpr const char *initAttributeTemplate = R"Py(attributes["{0}"] = {1})Py";

/// Template for setting an optional attribute in the operation builder.
/// {0} is the attribute name;
/// {1} is the builder argument name.
constexpr const char *initOptionalAttributeTemplate =
R"Py(if {1} is not None: attributes["{0}"] = {1})Py";
/// Template for attribute builder from raw input in the operation builder.
/// {0} is the builder argument name;
/// {1} is the attribute builder from raw;
/// {2} is the attribute builder from raw.
/// Use the value the user passed in if either it is already an Attribute or
/// there is no method registered to make it an Attribute.
constexpr const char *initAttributeWithBuilderTemplate =
R"Py(attributes["{1}"] = ({0} if (
issubclass(type({0}), _ods_ir.Attribute) or
not _ods_ir.AttrBuilder.contains('{2}')) else
_ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py";

/// Template for attribute builder from raw input for optional attribute in the
/// operation builder.
/// {0} is the builder argument name;
/// {1} is the attribute builder from raw;
/// {2} is the attribute builder from raw.
/// Use the value the user passed in if either it is already an Attribute or
/// there is no method registered to make it an Attribute.
constexpr const char *initOptionalAttributeWithBuilderTemplate =
R"Py(if {0} is not None: attributes["{1}"] = ({0} if (
issubclass(type({0}), _ods_ir.Attribute) or
not _ods_ir.AttrBuilder.contains('{2}')) else
_ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py";

constexpr const char *initUnitAttributeTemplate =
R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get(
Expand Down Expand Up @@ -656,6 +671,7 @@ static void
populateBuilderLinesAttr(const Operator &op,
llvm::ArrayRef<std::string> argNames,
llvm::SmallVectorImpl<std::string> &builderLines) {
builderLines.push_back("_ods_context = _ods_get_default_loc_context(loc)");
for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
Argument arg = op.getArg(i);
auto *attribute = arg.dyn_cast<NamedAttribute *>();
Expand All @@ -670,10 +686,10 @@ populateBuilderLinesAttr(const Operator &op,
}

builderLines.push_back(llvm::formatv(
(attribute->attr.isOptional() || attribute->attr.hasDefaultValue())
? initOptionalAttributeTemplate
: initAttributeTemplate,
attribute->name, argNames[i]));
attribute->attr.isOptional() || attribute->attr.hasDefaultValue()
? initOptionalAttributeWithBuilderTemplate
: initAttributeWithBuilderTemplate,
argNames[i], attribute->name, attribute->attr.getAttrDefName()));
}
}

Expand Down Expand Up @@ -753,8 +769,7 @@ constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})";
/// corresponding interface:
/// - {0} is the name of the class for which the types are inferred.
constexpr const char *inferTypeInterfaceTemplate =
R"PY(_ods_context = _ods_get_default_loc_context(loc)
results = _ods_ir.InferTypeOpInterface({0}).inferReturnTypes(
R"PY(results = _ods_ir.InferTypeOpInterface({0}).inferReturnTypes(
operands=operands,
attributes=_ods_ir.DictAttr.get(attributes, context=_ods_context),
context=_ods_context,
Expand Down

0 comments on commit b57acb9

Please sign in to comment.