Skip to content

Commit

Permalink
[mlir] Make attributes mutable in Python bindings
Browse files Browse the repository at this point in the history
Attributes represent additional data about an operation and are intended to be
modifiable during the lifetime of the operation. In the dialect-specific Python
bindings, attributes are exposed as properties on the operation class. Allow
for assigning values to these properties. Also support creating new and
deleting existing attributes through the generic "attributes" property of an
operation. Any validity checking must be performed by the op verifier after the
mutation, similarly to C++. Operations are not invalidated in the process: no
dangling pointers can be created as all attributes are owned by the context and
will remain live even if they are not used in any operation.

Introduce a Python Test dialect by analogy with the Test dialect and to avoid
polluting the latter with Python-specific constructs. Use this dialect to
implement a test for the attribute access and mutation API.

Reviewed By: stellaraccident, mehdi_amini

Differential Revision: https://reviews.llvm.org/D91652
  • Loading branch information
ftynse committed Nov 24, 2020
1 parent 0b2d84f commit 029e199
Show file tree
Hide file tree
Showing 8 changed files with 253 additions and 10 deletions.
16 changes: 15 additions & 1 deletion mlir/lib/Bindings/Python/IRModules.cpp
Expand Up @@ -1306,6 +1306,18 @@ class PyOpAttributeMap {
return PyNamedAttribute(namedAttr.attribute, std::string(namedAttr.name));
}

void dunderSetItem(const std::string &name, PyAttribute attr) {
mlirOperationSetAttributeByName(operation->get(), name.c_str(), attr.attr);
}

void dunderDelItem(const std::string &name) {
int removed =
mlirOperationRemoveAttributeByName(operation->get(), name.c_str());
if (!removed)
throw SetPyError(PyExc_KeyError,
"attempt to delete a non-existent attribute");
}

intptr_t dunderLen() {
return mlirOperationGetNumAttributes(operation->get());
}
Expand All @@ -1320,7 +1332,9 @@ class PyOpAttributeMap {
.def("__contains__", &PyOpAttributeMap::dunderContains)
.def("__len__", &PyOpAttributeMap::dunderLen)
.def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
.def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed);
.def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed)
.def("__setitem__", &PyOpAttributeMap::dunderSetItem)
.def("__delitem__", &PyOpAttributeMap::dunderDelItem);
}

private:
Expand Down
3 changes: 3 additions & 0 deletions mlir/test/Bindings/CMakeLists.txt
@@ -0,0 +1,3 @@
if(MLIR_BINDINGS_PYTHON_ENABLED)
add_subdirectory(Python)
endif()
4 changes: 4 additions & 0 deletions mlir/test/Bindings/Python/CMakeLists.txt
@@ -0,0 +1,4 @@
include(AddMLIRPythonExtension)
add_mlir_dialect_python_bindings(MLIRBindingsPythonTestOps
python_test_ops.td
python_test)
128 changes: 128 additions & 0 deletions mlir/test/Bindings/Python/dialects/python_test.py
@@ -0,0 +1,128 @@
# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
import mlir.dialects.python_test as test

def run(f):
print("\nTEST:", f.__name__)
f()

# CHECK-LABEL: TEST: testAttributes
def testAttributes():
with Context(), Location.unknown():
#
# Check op construction with attributes.
#

i32 = IntegerType.get_signless(32)
one = IntegerAttr.get(i32, 1)
two = IntegerAttr.get(i32, 2)
unit = UnitAttr.get()

# CHECK: "python_test.attributed_op"() {
# CHECK-DAG: mandatory_i32 = 1 : i32
# CHECK-DAG: optional_i32 = 2 : i32
# CHECK-DAG: unit
# CHECK: }
op = test.AttributedOp(one, two, unit)
print(f"{op}")

# CHECK: "python_test.attributed_op"() {
# CHECK: mandatory_i32 = 2 : i32
# CHECK: }
op2 = test.AttributedOp(two, None, None)
print(f"{op2}")

#
# Check generic "attributes" access and mutation.
#

assert "additional" not in op.attributes

# CHECK: "python_test.attributed_op"() {
# CHECK-DAG: additional = 1 : i32
# CHECK-DAG: mandatory_i32 = 2 : i32
# CHECK: }
op2.attributes["additional"] = one
print(f"{op2}")

# CHECK: "python_test.attributed_op"() {
# CHECK-DAG: additional = 2 : i32
# CHECK-DAG: mandatory_i32 = 2 : i32
# CHECK: }
op2.attributes["additional"] = two
print(f"{op2}")

# CHECK: "python_test.attributed_op"() {
# CHECK-NOT: additional = 2 : i32
# CHECK: mandatory_i32 = 2 : i32
# CHECK: }
del op2.attributes["additional"]
print(f"{op2}")

try:
print(op.attributes["additional"])
except KeyError:
pass
else:
assert False, "expected KeyError on unknown attribute key"

#
# Check accessors to defined attributes.
#

# CHECK: Mandatory: 1
# CHECK: Optional: 2
# CHECK: Unit: True
print(f"Mandatory: {op.mandatory_i32.value}")
print(f"Optional: {op.optional_i32.value}")
print(f"Unit: {op.unit}")

# CHECK: Mandatory: 2
# CHECK: Optional: None
# CHECK: Unit: False
print(f"Mandatory: {op2.mandatory_i32.value}")
print(f"Optional: {op2.optional_i32}")
print(f"Unit: {op2.unit}")

# CHECK: Mandatory: 2
# CHECK: Optional: None
# CHECK: Unit: False
op.mandatory_i32 = two
op.optional_i32 = None
op.unit = False
print(f"Mandatory: {op.mandatory_i32.value}")
print(f"Optional: {op.optional_i32}")
print(f"Unit: {op.unit}")
assert "optional_i32" not in op.attributes
assert "unit" not in op.attributes

try:
op.mandatory_i32 = None
except ValueError:
pass
else:
assert False, "expected ValueError on setting a mandatory attribute to None"

# CHECK: Optional: 2
op.optional_i32 = two
print(f"Optional: {op.optional_i32.value}")

# CHECK: Optional: None
del op.optional_i32
print(f"Optional: {op.optional_i32}")

# CHECK: Unit: False
op.unit = None
print(f"Unit: {op.unit}")
assert "unit" not in op.attributes

# CHECK: Unit: True
op.unit = True
print(f"Unit: {op.unit}")

# CHECK: Unit: False
del op.unit
print(f"Unit: {op.unit}")

run(testAttributes)
1 change: 1 addition & 0 deletions mlir/test/Bindings/Python/lit.local.cfg
@@ -1,2 +1,3 @@
if not config.enable_bindings_python:
config.unsupported = True
config.excludes.add('python_test_ops.td')
28 changes: 28 additions & 0 deletions mlir/test/Bindings/Python/python_test_ops.td
@@ -0,0 +1,28 @@
//===-- python_test_ops.td - Python test Op definitions ----*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef PYTHON_TEST_OPS
#define PYTHON_TEST_OPS

include "mlir/IR/OpBase.td"
include "../../../lib/Bindings/Python/Attributes.td"

def Python_Test_Dialect : Dialect {
let name = "python_test";
let cppNamespace = "PythonTest";
}
class TestOp<string mnemonic, list<OpTrait> traits = []>
: Op<Python_Test_Dialect, mnemonic, traits>;

def AttributedOp : TestOp<"attributed_op"> {
let arguments = (ins I32Attr:$mandatory_i32,
OptionalAttr<I32Attr>:$optional_i32,
UnitAttr:$unit);
}

#endif // PYTHON_TEST_OPS
2 changes: 2 additions & 0 deletions mlir/test/CMakeLists.txt
@@ -1,3 +1,4 @@
add_subdirectory(Bindings)
add_subdirectory(CAPI)
add_subdirectory(EDSC)
add_subdirectory(mlir-cpu-runner)
Expand Down Expand Up @@ -100,6 +101,7 @@ endif()
if(MLIR_BINDINGS_PYTHON_ENABLED)
list(APPEND MLIR_TEST_DEPENDS
MLIRBindingsPythonExtension
MLIRBindingsPythonTestOps
MLIRTransformsBindingsPythonExtension
)
endif()
Expand Down
81 changes: 72 additions & 9 deletions mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
Expand Up @@ -167,7 +167,7 @@ constexpr const char *optionalAttributeGetterTemplate = R"Py(
return {1}(self.operation.attributes["{2}"])
)Py";

/// Template for a accessing a unit operation attribute, returns True of the
/// Template for a getter of a unit operation attribute, returns True of the
/// unit attribute is present, False otherwise (unit attributes have meaning
/// by mere presence):
/// {0} is the name of the attribute sanitized for Python,
Expand All @@ -178,6 +178,53 @@ constexpr const char *unitAttributeGetterTemplate = R"Py(
return "{1}" in self.operation.attributes
)Py";

/// Template for an operation attribute setter:
/// {0} is the name of the attribute sanitized for Python;
/// {1} is the original name of the attribute.
constexpr const char *attributeSetterTemplate = R"Py(
@{0}.setter
def {0}(self, value):
if value is None:
raise ValueError("'None' not allowed as value for mandatory attributes")
self.operation.attributes["{1}"] = value
)Py";

/// Template for a setter of an optional operation attribute, setting to None
/// removes the attribute:
/// {0} is the name of the attribute sanitized for Python;
/// {1} is the original name of the attribute.
constexpr const char *optionalAttributeSetterTemplate = R"Py(
@{0}.setter
def {0}(self, value):
if value is not None:
self.operation.attributes["{1}"] = value
elif "{1}" in self.operation.attributes:
del self.operation.attributes["{1}"]
)Py";

/// Template for a setter of a unit operation attribute, setting to None or
/// False removes the attribute:
/// {0} is the name of the attribute sanitized for Python;
/// {1} is the original name of the attribute.
constexpr const char *unitAttributeSetterTemplate = R"Py(
@{0}.setter
def {0}(self, value):
if bool(value):
self.operation.attributes["{1}"] = _ir.UnitAttr.get()
elif "{1}" in self.operation.attributes:
del self.operation.attributes["{1}"]
)Py";

/// Template for a deleter of an optional or a unit operation attribute, removes
/// the attribute from the operation:
/// {0} is the name of the attribute sanitized for Python;
/// {1} is the original name of the attribute.
constexpr const char *attributeDeleterTemplate = R"Py(
@{0}.deleter
def {0}(self):
del self.operation.attributes["{1}"]
)Py";

static llvm::cl::OptionCategory
clOpPythonBindingCat("Options for -gen-python-op-bindings");

Expand Down Expand Up @@ -351,23 +398,39 @@ static void emitAttributeAccessors(const Operator &op,
if (namedAttr.name.empty())
continue;

std::string sanitizedName = sanitizeName(namedAttr.name);

// Unit attributes are handled specially.
if (namedAttr.attr.getStorageType().trim().equals("::mlir::UnitAttr")) {
os << llvm::formatv(unitAttributeGetterTemplate,
sanitizeName(namedAttr.name), namedAttr.name);
os << llvm::formatv(unitAttributeGetterTemplate, sanitizedName,
namedAttr.name);
os << llvm::formatv(unitAttributeSetterTemplate, sanitizedName,
namedAttr.name);
os << llvm::formatv(attributeDeleterTemplate, sanitizedName,
namedAttr.name);
continue;
}

// Other kinds of attributes need a mapping to a Python type.
if (!attributeClasses.count(namedAttr.attr.getStorageType().trim()))
continue;

os << llvm::formatv(
namedAttr.attr.isOptional() ? optionalAttributeGetterTemplate
: attributeGetterTemplate,
sanitizeName(namedAttr.name),
attributeClasses.lookup(namedAttr.attr.getStorageType()),
namedAttr.name);
StringRef pythonType =
attributeClasses.lookup(namedAttr.attr.getStorageType());
if (namedAttr.attr.isOptional()) {
os << llvm::formatv(optionalAttributeGetterTemplate, sanitizedName,
pythonType, namedAttr.name);
os << llvm::formatv(optionalAttributeSetterTemplate, sanitizedName,
namedAttr.name);
os << llvm::formatv(attributeDeleterTemplate, sanitizedName,
namedAttr.name);
} else {
os << llvm::formatv(attributeGetterTemplate, sanitizedName, pythonType,
namedAttr.name);
os << llvm::formatv(attributeSetterTemplate, sanitizedName,
namedAttr.name);
// Non-optional attributes cannot be deleted.
}
}
}

Expand Down

0 comments on commit 029e199

Please sign in to comment.