From bfa9fefcafeb8abfa8d21b48ce74c7af4a20464f Mon Sep 17 00:00:00 2001 From: makslevental Date: Fri, 26 Sep 2025 14:42:04 -0700 Subject: [PATCH 1/2] [MLIR][Python] rename checked gettors and add unchecked gettors --- mlir/lib/Bindings/Python/DialectLLVM.cpp | 46 ++++++--- mlir/lib/Bindings/Python/IRAttributes.cpp | 12 +++ mlir/lib/Bindings/Python/IRTypes.cpp | 116 +++++++++++++++++++++- 3 files changed, 154 insertions(+), 20 deletions(-) diff --git a/mlir/lib/Bindings/Python/DialectLLVM.cpp b/mlir/lib/Bindings/Python/DialectLLVM.cpp index 55b9331270cdc..38de4a0e329a0 100644 --- a/mlir/lib/Bindings/Python/DialectLLVM.cpp +++ b/mlir/lib/Bindings/Python/DialectLLVM.cpp @@ -33,21 +33,37 @@ static void populateDialectLLVMSubmodule(const nanobind::module_ &m) { auto llvmStructType = mlir_type_subclass(m, "StructType", mlirTypeIsALLVMStructType); - llvmStructType.def_classmethod( - "get_literal", - [](const nb::object &cls, const std::vector &elements, - bool packed, MlirLocation loc) { - CollectDiagnosticsToStringScope scope(mlirLocationGetContext(loc)); - - MlirType type = mlirLLVMStructTypeLiteralGetChecked( - loc, elements.size(), elements.data(), packed); - if (mlirTypeIsNull(type)) { - throw nb::value_error(scope.takeMessage().c_str()); - } - return cls(type); - }, - "cls"_a, "elements"_a, nb::kw_only(), "packed"_a = false, - "loc"_a = nb::none()); + llvmStructType + .def_classmethod( + "get_literal", + [](const nb::object &cls, const std::vector &elements, + bool packed, MlirLocation loc) { + CollectDiagnosticsToStringScope scope(mlirLocationGetContext(loc)); + + MlirType type = mlirLLVMStructTypeLiteralGetChecked( + loc, elements.size(), elements.data(), packed); + if (mlirTypeIsNull(type)) { + throw nb::value_error(scope.takeMessage().c_str()); + } + return cls(type); + }, + "cls"_a, "elements"_a, nb::kw_only(), "packed"_a = false, + "loc"_a = nb::none()) + .def_classmethod( + "get_literal_unchecked", + [](const nb::object &cls, const std::vector &elements, + bool packed, MlirContext context) { + CollectDiagnosticsToStringScope scope(context); + + MlirType type = mlirLLVMStructTypeLiteralGet( + context, elements.size(), elements.data(), packed); + if (mlirTypeIsNull(type)) { + throw nb::value_error(scope.takeMessage().c_str()); + } + return cls(type); + }, + "cls"_a, "elements"_a, nb::kw_only(), "packed"_a = false, + "context"_a = nb::none()); llvmStructType.def_classmethod( "get_identified", diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index c77653f97e6dd..045c0fbf4630f 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -575,6 +575,18 @@ class PyFloatAttribute : public PyConcreteAttribute { }, nb::arg("type"), nb::arg("value"), nb::arg("loc") = nb::none(), "Gets an uniqued float point attribute associated to a type"); + c.def_static( + "get_unchecked", + [](PyType &type, double value, DefaultingPyMlirContext context) { + PyMlirContext::ErrorCapture errors(context->getRef()); + MlirAttribute attr = + mlirFloatAttrDoubleGet(context.get()->get(), type, value); + if (mlirAttributeIsNull(attr)) + throw MLIRError("Invalid attribute", errors.take()); + return PyFloatAttribute(type.getContext(), attr); + }, + nb::arg("type"), nb::arg("value"), nb::arg("context") = nb::none(), + "Gets an uniqued float point attribute associated to a type"); c.def_static( "get_f32", [](double value, DefaultingPyMlirContext context) { diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp index 07dc00521833f..3488d92250b45 100644 --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -639,11 +639,16 @@ class PyVectorType : public PyConcreteType { using PyConcreteType::PyConcreteType; static void bindDerived(ClassTy &c) { - c.def_static("get", &PyVectorType::get, nb::arg("shape"), + c.def_static("get", &PyVectorType::getChecked, nb::arg("shape"), nb::arg("element_type"), nb::kw_only(), nb::arg("scalable") = nb::none(), nb::arg("scalable_dims") = nb::none(), nb::arg("loc") = nb::none(), "Create a vector type") + .def_static("get_unchecked", &PyVectorType::get, nb::arg("shape"), + nb::arg("element_type"), nb::kw_only(), + nb::arg("scalable") = nb::none(), + nb::arg("scalable_dims") = nb::none(), + nb::arg("context") = nb::none(), "Create a vector type") .def_prop_ro( "scalable", [](MlirType self) { return mlirVectorTypeIsScalable(self); }) @@ -658,10 +663,11 @@ class PyVectorType : public PyConcreteType { } private: - static PyVectorType get(std::vector shape, PyType &elementType, - std::optional scalable, - std::optional> scalableDims, - DefaultingPyLocation loc) { + static PyVectorType + getChecked(std::vector shape, PyType &elementType, + std::optional scalable, + std::optional> scalableDims, + DefaultingPyLocation loc) { if (scalable && scalableDims) { throw nb::value_error("'scalable' and 'scalable_dims' kwargs " "are mutually exclusive."); @@ -696,6 +702,42 @@ class PyVectorType : public PyConcreteType { throw MLIRError("Invalid type", errors.take()); return PyVectorType(elementType.getContext(), type); } + + static PyVectorType get(std::vector shape, PyType &elementType, + std::optional scalable, + std::optional> scalableDims, + DefaultingPyMlirContext context) { + if (scalable && scalableDims) { + throw nb::value_error("'scalable' and 'scalable_dims' kwargs " + "are mutually exclusive."); + } + + PyMlirContext::ErrorCapture errors(context->getRef()); + MlirType type; + if (scalable) { + if (scalable->size() != shape.size()) + throw nb::value_error("Expected len(scalable) == len(shape)."); + + SmallVector scalableDimFlags = llvm::to_vector(llvm::map_range( + *scalable, [](const nb::handle &h) { return nb::cast(h); })); + type = mlirVectorTypeGetScalable(shape.size(), shape.data(), + scalableDimFlags.data(), elementType); + } else if (scalableDims) { + SmallVector scalableDimFlags(shape.size(), false); + for (int64_t dim : *scalableDims) { + if (static_cast(dim) >= scalableDimFlags.size() || dim < 0) + throw nb::value_error("Scalable dimension index out of bounds."); + scalableDimFlags[dim] = true; + } + type = mlirVectorTypeGetScalable(shape.size(), shape.data(), + scalableDimFlags.data(), elementType); + } else { + type = mlirVectorTypeGet(shape.size(), shape.data(), elementType); + } + if (mlirTypeIsNull(type)) + throw MLIRError("Invalid type", errors.take()); + return PyVectorType(elementType.getContext(), type); + } }; /// Ranked Tensor Type subclass - RankedTensorType. @@ -724,6 +766,22 @@ class PyRankedTensorType nb::arg("shape"), nb::arg("element_type"), nb::arg("encoding") = nb::none(), nb::arg("loc") = nb::none(), "Create a ranked tensor type"); + c.def_static( + "get_unchecked", + [](std::vector shape, PyType &elementType, + std::optional &encodingAttr, + DefaultingPyMlirContext context) { + PyMlirContext::ErrorCapture errors(context->getRef()); + MlirType t = mlirRankedTensorTypeGet( + shape.size(), shape.data(), elementType, + encodingAttr ? encodingAttr->get() : mlirAttributeGetNull()); + if (mlirTypeIsNull(t)) + throw MLIRError("Invalid type", errors.take()); + return PyRankedTensorType(elementType.getContext(), t); + }, + nb::arg("shape"), nb::arg("element_type"), + nb::arg("encoding") = nb::none(), nb::arg("context") = nb::none(), + "Create a ranked tensor type"); c.def_prop_ro( "encoding", [](PyRankedTensorType &self) @@ -758,6 +816,17 @@ class PyUnrankedTensorType }, nb::arg("element_type"), nb::arg("loc") = nb::none(), "Create a unranked tensor type"); + c.def_static( + "get_unchecked", + [](PyType &elementType, DefaultingPyMlirContext context) { + PyMlirContext::ErrorCapture errors(context->getRef()); + MlirType t = mlirUnrankedTensorTypeGet(elementType); + if (mlirTypeIsNull(t)) + throw MLIRError("Invalid type", errors.take()); + return PyUnrankedTensorType(elementType.getContext(), t); + }, + nb::arg("element_type"), nb::arg("context") = nb::none(), + "Create a unranked tensor type"); } }; @@ -790,6 +859,27 @@ class PyMemRefType : public PyConcreteType { nb::arg("shape"), nb::arg("element_type"), nb::arg("layout") = nb::none(), nb::arg("memory_space") = nb::none(), nb::arg("loc") = nb::none(), "Create a memref type") + .def_static( + "get_unchecked", + [](std::vector shape, PyType &elementType, + PyAttribute *layout, PyAttribute *memorySpace, + DefaultingPyMlirContext context) { + PyMlirContext::ErrorCapture errors(context->getRef()); + MlirAttribute layoutAttr = + layout ? *layout : mlirAttributeGetNull(); + MlirAttribute memSpaceAttr = + memorySpace ? *memorySpace : mlirAttributeGetNull(); + MlirType t = + mlirMemRefTypeGet(elementType, shape.size(), shape.data(), + layoutAttr, memSpaceAttr); + if (mlirTypeIsNull(t)) + throw MLIRError("Invalid type", errors.take()); + return PyMemRefType(elementType.getContext(), t); + }, + nb::arg("shape"), nb::arg("element_type"), + nb::arg("layout") = nb::none(), + nb::arg("memory_space") = nb::none(), + nb::arg("context") = nb::none(), "Create a memref type") .def_prop_ro( "layout", [](PyMemRefType &self) -> nb::typed { @@ -858,6 +948,22 @@ class PyUnrankedMemRefType }, nb::arg("element_type"), nb::arg("memory_space").none(), nb::arg("loc") = nb::none(), "Create a unranked memref type") + .def_static( + "get_unchecked", + [](PyType &elementType, PyAttribute *memorySpace, + DefaultingPyMlirContext context) { + PyMlirContext::ErrorCapture errors(context->getRef()); + MlirAttribute memSpaceAttr = {}; + if (memorySpace) + memSpaceAttr = *memorySpace; + + MlirType t = mlirUnrankedMemRefTypeGet(elementType, memSpaceAttr); + if (mlirTypeIsNull(t)) + throw MLIRError("Invalid type", errors.take()); + return PyUnrankedMemRefType(elementType.getContext(), t); + }, + nb::arg("element_type"), nb::arg("memory_space").none(), + nb::arg("context") = nb::none(), "Create a unranked memref type") .def_prop_ro( "memory_space", [](PyUnrankedMemRefType &self) From bcc8bcad3b7a6dc65b40892fa52d33a2ff813bb8 Mon Sep 17 00:00:00 2001 From: makslevental Date: Fri, 26 Sep 2025 16:15:13 -0700 Subject: [PATCH 2/2] run ci --- mlir/test/python/ir/builtin_types.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py index b42bfd9bc6587..54863253fc770 100644 --- a/mlir/test/python/ir/builtin_types.py +++ b/mlir/test/python/ir/builtin_types.py @@ -371,11 +371,16 @@ def testAbstractShapedType(): # CHECK-LABEL: TEST: testVectorType @run def testVectorType(): + shape = [2, 3] + with Context(): + f32 = F32Type.get() + # CHECK: unchecked vector type: vector<2x3xf32> + print("unchecked vector type:", VectorType.get_unchecked(shape, f32)) + with Context(), Location.unknown(): f32 = F32Type.get() - shape = [2, 3] - # CHECK: vector type: vector<2x3xf32> - print("vector type:", VectorType.get(shape, f32)) + # CHECK: checked vector type: vector<2x3xf32> + print("checked vector type:", VectorType.get(shape, f32)) none = NoneType.get() try: