-
Notifications
You must be signed in to change notification settings - Fork 11.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir] Add Python bindings for DenseResourceElementsAttr. #66319
[mlir] Add Python bindings for DenseResourceElementsAttr. #66319
Conversation
Only construction and type casting are implemented. The method to create is explicitly named "unsafe" and the documentation calls out what the caller is responsible for. There really isn't a better way to do this and retain the power-user feature this represents.
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir ChangesOnly construction and type casting are implemented. The method to create is explicitly named "unsafe" and the documentation calls out what the caller is responsible for. There really isn't a better way to do this and retain the power-user feature this represents. -- Full diff: https://github.com//pull/66319.diff4 Files Affected:
diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h index 93c4ed5692ef26d..47ba9b68f2700dd 100644 --- a/mlir/include/mlir-c/BuiltinAttributes.h +++ b/mlir/include/mlir-c/BuiltinAttributes.h @@ -558,6 +558,9 @@ mlirDenseElementsAttrGetRawData(MlirAttribute attr); // Resource blob attributes. //===----------------------------------------------------------------------===// +MLIR_CAPI_EXPORTED bool +mlirAttributeIsADenseResourceElements(MlirAttribute attr); + MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseBoolResourceElementsAttrGet( MlirType shapedType, MlirStringRef name, intptr_t numElements, const int *elements); diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index 105d2cecf20a193..31092bda4893cbc 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -72,6 +72,31 @@ or 255), then a splat will be created. type or if the buffer does not meet expectations. )"; +static const char kDenseResourceElementsAttrGetUnsafeDocstring[] = + R"(Gets a DenseResourceElementsAttr from a Python buffer or array. + +This function is extremely unsafe and must be used when strict invariants +are met: + +* The contents of the buffer matches the contiguous data layout implied + by the given ShapedType. +* The caller arranges to keep the backing memory alive and valid for the + duration of any use of the attribute. + +Args: + buffer: The array or buffer to convert. + name: Name to provide to the resource (may be changed upon collision). + type: The explicit ShapedType to construct the attribute with. + context: Explicit context, if not from context manager. + +Returns: + DenseResourceElementsAttr on success. + +Raises: + ValueError: If the type of the buffer or array cannot be matched to an MLIR + type or if the buffer does not meet expectations. +)"; + namespace { static MlirStringRef toMlirStringRef(const std::string &s) { @@ -997,6 +1022,51 @@ class PyDenseIntElementsAttribute } }; +class PyDenseResourceElementsAttribute + : public PyConcreteAttribute<PyDenseResourceElementsAttribute> { +public: + static constexpr IsAFunctionTy isaFunction = + mlirAttributeIsADenseResourceElements; + static constexpr const char *pyClassName = "DenseResourceElementsAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + + static PyDenseResourceElementsAttribute + getFromBuffer(py::buffer buffer, std::string name, PyType type, + DefaultingPyMlirContext contextWrapper) { + // Request a contiguous view. In exotic cases, this will cause a copy. + int flags = PyBUF_ND; + Py_buffer view; + if (PyObject_GetBuffer(buffer.ptr(), &view, flags) != 0) { + throw py::error_already_set(); + } + auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); }); + + if (!mlirTypeIsAShaped(type)) { + throw std::invalid_argument( + "Constructing a DenseResourceElementsAttr requires a ShapedType"); + } + size_t rawBufferSize = view.len; + MlirAttribute attr = mlirUnmanagedDenseBlobResourceElementsAttrGet( + type, toMlirStringRef(name), view.buf, rawBufferSize); + if (mlirAttributeIsNull(attr)) { + throw std::invalid_argument( + "DenseResourceElementsAttr could not be constructed from the given " + "buffer. " + "This may mean that the Python buffer layout does not match that " + "MLIR expected layout and is a bug."); + } + return PyDenseResourceElementsAttribute(contextWrapper->getRef(), attr); + } + + static void bindDerived(ClassTy &c) { + c.def_static("get_unsafe_from_buffer", + PyDenseResourceElementsAttribute::getFromBuffer, + py::arg("array"), py::arg("name"), py::arg("type"), + py::arg("context") = py::none(), + kDenseResourceElementsAttrGetUnsafeDocstring); + } +}; + class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> { public: static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary; @@ -1273,6 +1343,7 @@ void mlir::python::populateIRAttributes(py::module &m) { PyGlobals::get().registerTypeCaster( mlirDenseIntOrFPElementsAttrGetTypeID(), pybind11::cpp_function(denseIntOrFPElementsAttributeCaster)); + PyDenseResourceElementsAttribute::bind(m); PyDictAttribute::bind(m); PySymbolRefAttribute::bind(m); diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp index 84a958d01d2eb14..c3a5057acafef06 100644 --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -770,6 +770,10 @@ const void *mlirDenseElementsAttrGetRawData(MlirAttribute attr) { // Resource blob attributes. //===----------------------------------------------------------------------===// +bool mlirAttributeIsADenseResourceElements(MlirAttribute attr) { + return llvm::isa<DenseResourceElementsAttr>(unwrap(attr)); +} + template <typename U, typename T> static MlirAttribute getDenseResource(MlirType shapedType, MlirStringRef name, intptr_t numElements, const T *elements) { diff --git a/mlir/test/python/ir/array_attributes.py b/mlir/test/python/ir/array_attributes.py index 452d860861d783a..61a1935a0bae3c4 100644 --- a/mlir/test/python/ir/array_attributes.py +++ b/mlir/test/python/ir/array_attributes.py @@ -417,3 +417,24 @@ def testGetDenseElementsIndex(): print(arr) # CHECK: True print(arr.dtype == np.int64) + + +# CHECK-LABEL: TEST: testGetDenseResourceElementsAttr +@run +def testGetDenseResourceElementsAttr(): + with Context(), Location.unknown(): + element_type = IntegerType.get_signless(32) + tensor_type = RankedTensorType.get((2, 3), element_type) + array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32) + resource = DenseResourceElementsAttr.get_unsafe_from_buffer( + array, "from_py", tensor_type + ) + module = Module.parse("module {}") + module.operation.attributes["test.resource"] = resource + # CHECK: test.resource = dense_resource<from_py> : tensor<2x3xi32> + # CHECK: from_py: "0x01000000010000000200000003000000040000000500000006000000" + print(module) + + # Verifies type casting. + # CHECK: dense_resource<from_py> : tensor<2x3xi32> + print(DenseResourceElementsAttr(module.operation.attributes["test.resource"])) |
PTAL - one question back to you. |
PTAL - Updated with explicit alignment and mutability. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good! I had same question about unmanaged in name given allows for specifying deletion here (and first time I've seen weakref used in python :-))
Only construction and type casting are implemented. The method to create is explicitly named "unsafe" and the documentation calls out what the caller is responsible for. There really isn't a better way to do this and retain the power-user feature this represents.
I'm seeing a failure on s390x that I think is caused by this change:
|
We can refactor the test to use native byte order. But this landed six months ago. Does anyone even use this stuff on BE? None of the devs have any practical way to test. |
Not sure, but we build and package mlir for s390x on Fedora. There is an s390x buildbot for mlir, but I'm not sure why it didn't catch this failure. |
There is a bot for s390x: I'm surprised we didn't catch this sooner. |
I'm traveling this weekend but will see what I can do when I'm at a keyboard. |
@stellaraccident It's not urgent we can work around the test failure easily. |
Only construction and type casting are implemented. The method to create is explicitly named "unsafe" and the documentation calls out what the caller is responsible for. There really isn't a better way to do this and retain the power-user feature this represents.