-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[MLIR][Python] add unchecked gettors #160954
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir Author: Maksim Levental (makslevental) ChangesSome of the current gettors required passing locations (i.e., there be an active location) because they're using the "checked" API. This PR renames those gettors (explicitly advertising the checked aspect) and adds "unchecked" gettors which only require an active context. Full diff: https://github.com/llvm/llvm-project/pull/160954.diff 3 Files Affected:
diff --git a/mlir/lib/Bindings/Python/DialectLLVM.cpp b/mlir/lib/Bindings/Python/DialectLLVM.cpp
index 55b9331270cdc..b044965f6ac1a 100644
--- a/mlir/lib/Bindings/Python/DialectLLVM.cpp
+++ b/mlir/lib/Bindings/Python/DialectLLVM.cpp
@@ -33,21 +33,37 @@ static void populateDialectLLVMSubmodule(const nanobind::module_ &m) {
auto llvmStructType =
mlir_type_subclass(m, "StructType", mlirTypeIsALLVMStructType);
- llvmStructType.def_classmethod(
- "get_literal",
- [](const nb::object &cls, const std::vector<MlirType> &elements,
- bool packed, MlirLocation loc) {
- CollectDiagnosticsToStringScope scope(mlirLocationGetContext(loc));
-
- MlirType type = mlirLLVMStructTypeLiteralGetChecked(
- loc, elements.size(), elements.data(), packed);
- if (mlirTypeIsNull(type)) {
- throw nb::value_error(scope.takeMessage().c_str());
- }
- return cls(type);
- },
- "cls"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
- "loc"_a = nb::none());
+ llvmStructType
+ .def_classmethod(
+ "get_literal_checked",
+ [](const nb::object &cls, const std::vector<MlirType> &elements,
+ bool packed, MlirLocation loc) {
+ CollectDiagnosticsToStringScope scope(mlirLocationGetContext(loc));
+
+ MlirType type = mlirLLVMStructTypeLiteralGetChecked(
+ loc, elements.size(), elements.data(), packed);
+ if (mlirTypeIsNull(type)) {
+ throw nb::value_error(scope.takeMessage().c_str());
+ }
+ return cls(type);
+ },
+ "cls"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
+ "loc"_a = nb::none())
+ .def_classmethod(
+ "get_literal",
+ [](const nb::object &cls, const std::vector<MlirType> &elements,
+ bool packed, MlirContext context) {
+ CollectDiagnosticsToStringScope scope(context);
+
+ MlirType type = mlirLLVMStructTypeLiteralGet(
+ context, elements.size(), elements.data(), packed);
+ if (mlirTypeIsNull(type)) {
+ throw nb::value_error(scope.takeMessage().c_str());
+ }
+ return cls(type);
+ },
+ "cls"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
+ "context"_a = nb::none());
llvmStructType.def_classmethod(
"get_identified",
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index c77653f97e6dd..24e92ffffe8ae 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -565,7 +565,7 @@ class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
static void bindDerived(ClassTy &c) {
c.def_static(
- "get",
+ "get_checked",
[](PyType &type, double value, DefaultingPyLocation loc) {
PyMlirContext::ErrorCapture errors(loc->getContext());
MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value);
@@ -575,6 +575,18 @@ class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
},
nb::arg("type"), nb::arg("value"), nb::arg("loc") = nb::none(),
"Gets an uniqued float point attribute associated to a type");
+ c.def_static(
+ "get",
+ [](PyType &type, double value, DefaultingPyMlirContext context) {
+ PyMlirContext::ErrorCapture errors(context->getRef());
+ MlirAttribute attr =
+ mlirFloatAttrDoubleGet(context.get()->get(), type, value);
+ if (mlirAttributeIsNull(attr))
+ throw MLIRError("Invalid attribute", errors.take());
+ return PyFloatAttribute(type.getContext(), attr);
+ },
+ nb::arg("type"), nb::arg("value"), nb::arg("context") = nb::none(),
+ "Gets an uniqued float point attribute associated to a type");
c.def_static(
"get_f32",
[](double value, DefaultingPyMlirContext context) {
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 07dc00521833f..0238a24708962 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -643,7 +643,12 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
nb::arg("element_type"), nb::kw_only(),
nb::arg("scalable") = nb::none(),
nb::arg("scalable_dims") = nb::none(),
- nb::arg("loc") = nb::none(), "Create a vector type")
+ nb::arg("context") = nb::none(), "Create a vector type")
+ .def_static("get", &PyVectorType::getChecked, nb::arg("shape"),
+ nb::arg("element_type"), nb::kw_only(),
+ nb::arg("scalable") = nb::none(),
+ nb::arg("scalable_dims") = nb::none(),
+ nb::arg("loc") = nb::none(), "Create a vector type")
.def_prop_ro(
"scalable",
[](MlirType self) { return mlirVectorTypeIsScalable(self); })
@@ -658,10 +663,11 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
}
private:
- static PyVectorType get(std::vector<int64_t> shape, PyType &elementType,
- std::optional<nb::list> scalable,
- std::optional<std::vector<int64_t>> scalableDims,
- DefaultingPyLocation loc) {
+ static PyVectorType
+ getChecked(std::vector<int64_t> shape, PyType &elementType,
+ std::optional<nb::list> scalable,
+ std::optional<std::vector<int64_t>> scalableDims,
+ DefaultingPyLocation loc) {
if (scalable && scalableDims) {
throw nb::value_error("'scalable' and 'scalable_dims' kwargs "
"are mutually exclusive.");
@@ -696,6 +702,42 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
throw MLIRError("Invalid type", errors.take());
return PyVectorType(elementType.getContext(), type);
}
+
+ static PyVectorType get(std::vector<int64_t> shape, PyType &elementType,
+ std::optional<nb::list> scalable,
+ std::optional<std::vector<int64_t>> scalableDims,
+ DefaultingPyMlirContext context) {
+ if (scalable && scalableDims) {
+ throw nb::value_error("'scalable' and 'scalable_dims' kwargs "
+ "are mutually exclusive.");
+ }
+
+ PyMlirContext::ErrorCapture errors(context->getRef());
+ MlirType type;
+ if (scalable) {
+ if (scalable->size() != shape.size())
+ throw nb::value_error("Expected len(scalable) == len(shape).");
+
+ SmallVector<bool> scalableDimFlags = llvm::to_vector(llvm::map_range(
+ *scalable, [](const nb::handle &h) { return nb::cast<bool>(h); }));
+ type = mlirVectorTypeGetScalable(shape.size(), shape.data(),
+ scalableDimFlags.data(), elementType);
+ } else if (scalableDims) {
+ SmallVector<bool> scalableDimFlags(shape.size(), false);
+ for (int64_t dim : *scalableDims) {
+ if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0)
+ throw nb::value_error("Scalable dimension index out of bounds.");
+ scalableDimFlags[dim] = true;
+ }
+ type = mlirVectorTypeGetScalable(shape.size(), shape.data(),
+ scalableDimFlags.data(), elementType);
+ } else {
+ type = mlirVectorTypeGet(shape.size(), shape.data(), elementType);
+ }
+ if (mlirTypeIsNull(type))
+ throw MLIRError("Invalid type", errors.take());
+ return PyVectorType(elementType.getContext(), type);
+ }
};
/// Ranked Tensor Type subclass - RankedTensorType.
@@ -710,7 +752,7 @@ class PyRankedTensorType
static void bindDerived(ClassTy &c) {
c.def_static(
- "get",
+ "get_checked",
[](std::vector<int64_t> shape, PyType &elementType,
std::optional<PyAttribute> &encodingAttr, DefaultingPyLocation loc) {
PyMlirContext::ErrorCapture errors(loc->getContext());
@@ -724,6 +766,22 @@ class PyRankedTensorType
nb::arg("shape"), nb::arg("element_type"),
nb::arg("encoding") = nb::none(), nb::arg("loc") = nb::none(),
"Create a ranked tensor type");
+ c.def_static(
+ "get",
+ [](std::vector<int64_t> shape, PyType &elementType,
+ std::optional<PyAttribute> &encodingAttr,
+ DefaultingPyMlirContext context) {
+ PyMlirContext::ErrorCapture errors(context->getRef());
+ MlirType t = mlirRankedTensorTypeGet(
+ shape.size(), shape.data(), elementType,
+ encodingAttr ? encodingAttr->get() : mlirAttributeGetNull());
+ if (mlirTypeIsNull(t))
+ throw MLIRError("Invalid type", errors.take());
+ return PyRankedTensorType(elementType.getContext(), t);
+ },
+ nb::arg("shape"), nb::arg("element_type"),
+ nb::arg("encoding") = nb::none(), nb::arg("context") = nb::none(),
+ "Create a ranked tensor type");
c.def_prop_ro(
"encoding",
[](PyRankedTensorType &self)
@@ -758,6 +816,17 @@ class PyUnrankedTensorType
},
nb::arg("element_type"), nb::arg("loc") = nb::none(),
"Create a unranked tensor type");
+ c.def_static(
+ "get_checked",
+ [](PyType &elementType, DefaultingPyMlirContext context) {
+ PyMlirContext::ErrorCapture errors(context->getRef());
+ MlirType t = mlirUnrankedTensorTypeGet(elementType);
+ if (mlirTypeIsNull(t))
+ throw MLIRError("Invalid type", errors.take());
+ return PyUnrankedTensorType(elementType.getContext(), t);
+ },
+ nb::arg("element_type"), nb::arg("context") = nb::none(),
+ "Create a unranked tensor type");
}
};
@@ -772,7 +841,7 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
static void bindDerived(ClassTy &c) {
c.def_static(
- "get",
+ "get_checked",
[](std::vector<int64_t> shape, PyType &elementType,
PyAttribute *layout, PyAttribute *memorySpace,
DefaultingPyLocation loc) {
@@ -790,6 +859,27 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
nb::arg("shape"), nb::arg("element_type"),
nb::arg("layout") = nb::none(), nb::arg("memory_space") = nb::none(),
nb::arg("loc") = nb::none(), "Create a memref type")
+ .def_static(
+ "get",
+ [](std::vector<int64_t> shape, PyType &elementType,
+ PyAttribute *layout, PyAttribute *memorySpace,
+ DefaultingPyMlirContext context) {
+ PyMlirContext::ErrorCapture errors(context->getRef());
+ MlirAttribute layoutAttr =
+ layout ? *layout : mlirAttributeGetNull();
+ MlirAttribute memSpaceAttr =
+ memorySpace ? *memorySpace : mlirAttributeGetNull();
+ MlirType t =
+ mlirMemRefTypeGet(elementType, shape.size(), shape.data(),
+ layoutAttr, memSpaceAttr);
+ if (mlirTypeIsNull(t))
+ throw MLIRError("Invalid type", errors.take());
+ return PyMemRefType(elementType.getContext(), t);
+ },
+ nb::arg("shape"), nb::arg("element_type"),
+ nb::arg("layout") = nb::none(),
+ nb::arg("memory_space") = nb::none(),
+ nb::arg("context") = nb::none(), "Create a memref type")
.def_prop_ro(
"layout",
[](PyMemRefType &self) -> nb::typed<nb::object, PyAttribute> {
@@ -842,7 +932,7 @@ class PyUnrankedMemRefType
static void bindDerived(ClassTy &c) {
c.def_static(
- "get",
+ "get_checked",
[](PyType &elementType, PyAttribute *memorySpace,
DefaultingPyLocation loc) {
PyMlirContext::ErrorCapture errors(loc->getContext());
@@ -858,6 +948,32 @@ class PyUnrankedMemRefType
},
nb::arg("element_type"), nb::arg("memory_space").none(),
nb::arg("loc") = nb::none(), "Create a unranked memref type")
+ .def_prop_ro(
+ "memory_space",
+ [](PyUnrankedMemRefType &self)
+ -> std::optional<nb::typed<nb::object, PyAttribute>> {
+ MlirAttribute a = mlirUnrankedMemrefGetMemorySpace(self);
+ if (mlirAttributeIsNull(a))
+ return std::nullopt;
+ return PyAttribute(self.getContext(), a).maybeDownCast();
+ },
+ "Returns the memory space of the given Unranked MemRef type.")
+ .def_static(
+ "get",
+ [](PyType &elementType, PyAttribute *memorySpace,
+ DefaultingPyMlirContext context) {
+ PyMlirContext::ErrorCapture errors(context->getRef());
+ MlirAttribute memSpaceAttr = {};
+ if (memorySpace)
+ memSpaceAttr = *memorySpace;
+
+ MlirType t = mlirUnrankedMemRefTypeGet(elementType, memSpaceAttr);
+ if (mlirTypeIsNull(t))
+ throw MLIRError("Invalid type", errors.take());
+ return PyUnrankedMemRefType(elementType.getContext(), t);
+ },
+ nb::arg("element_type"), nb::arg("memory_space").none(),
+ nb::arg("context") = nb::none(), "Create a unranked memref type")
.def_prop_ro(
"memory_space",
[](PyUnrankedMemRefType &self)
|
2a2b067
to
3fb3522
Compare
3fb3522
to
cd137e0
Compare
cd137e0
to
bec4ede
Compare
✅ With the latest revision this PR passed the Python code formatter. |
bec4ede
to
bfa9fef
Compare
0490bb6
to
bcc8bca
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm, so the loc is used to emit errors when the check fails? The change looks good : )
yea that's correct |
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/169/builds/15384 Here is the relevant piece of the build log for the reference
|
Some of the current gettors require passing locations (i.e., there be an active location) because they're using the "checked" APIs. This PR adds "unchecked" gettors which only require an active context.
Some of the current gettors require passing locations (i.e., there be an active location) because they're using the "checked" APIs. This PR adds "unchecked" gettors which only require an active context.