diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h index 53c55b7086ce8..4ff5061b945aa 100644 --- a/mlir/include/mlir/Bindings/Python/IRCore.h +++ b/mlir/include/mlir/Bindings/Python/IRCore.h @@ -931,6 +931,7 @@ class MLIR_PYTHON_API_EXPORTED PyConcreteType : public BaseTy { using ClassTy = nanobind::class_; using IsAFunctionTy = bool (*)(MlirType); using GetTypeIDFunctionTy = MlirTypeID (*)(); + using Base = PyConcreteType; static constexpr GetTypeIDFunctionTy getTypeIdFunction = nullptr; PyConcreteType() = default; @@ -1063,6 +1064,7 @@ class MLIR_PYTHON_API_EXPORTED PyConcreteAttribute : public BaseTy { using IsAFunctionTy = bool (*)(MlirAttribute); using GetTypeIDFunctionTy = MlirTypeID (*)(); static constexpr GetTypeIDFunctionTy getTypeIdFunction = nullptr; + using Base = PyConcreteAttribute; PyConcreteAttribute() = default; PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr) @@ -1521,6 +1523,7 @@ class MLIR_PYTHON_API_EXPORTED PyConcreteValue : public PyValue { using IsAFunctionTy = bool (*)(MlirValue); using GetTypeIDFunctionTy = MlirTypeID (*)(); static constexpr GetTypeIDFunctionTy getTypeIdFunction = nullptr; + using Base = PyConcreteValue; PyConcreteValue() = default; PyConcreteValue(PyOperationRef operationRef, MlirValue value) diff --git a/mlir/lib/Bindings/Python/DialectAMDGPU.cpp b/mlir/lib/Bindings/Python/DialectAMDGPU.cpp index 26ffc0e427e41..0a0cb1d158abc 100644 --- a/mlir/lib/Bindings/Python/DialectAMDGPU.cpp +++ b/mlir/lib/Bindings/Python/DialectAMDGPU.cpp @@ -8,58 +8,97 @@ #include "mlir-c/Dialect/AMDGPU.h" #include "mlir-c/IR.h" +#include "mlir/Bindings/Python/IRCore.h" #include "mlir/Bindings/Python/Nanobind.h" #include "mlir/Bindings/Python/NanobindAdaptors.h" #include "nanobind/nanobind.h" namespace nb = nanobind; using namespace llvm; -using namespace mlir; -using namespace mlir::python; using namespace mlir::python::nanobind_adaptors; -static void populateDialectAMDGPUSubmodule(const nb::module_ &m) { - auto amdgpuTDMBaseType = - mlir_type_subclass(m, "TDMBaseType", mlirTypeIsAAMDGPUTDMBaseType, - mlirAMDGPUTDMBaseTypeGetTypeID); +namespace mlir { +namespace python { +namespace MLIR_BINDINGS_PYTHON_DOMAIN { +namespace amdgpu { +struct TDMBaseType : PyConcreteType { + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAAMDGPUTDMBaseType; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirAMDGPUTDMBaseTypeGetTypeID; + static constexpr const char *pyClassName = "TDMBaseType"; + using Base::Base; - amdgpuTDMBaseType.def_classmethod( - "get", - [](const nb::object &cls, MlirType elementType, MlirContext ctx) { - return cls(mlirAMDGPUTDMBaseTypeGet(ctx, elementType)); - }, - "Gets an instance of TDMBaseType in the same context", nb::arg("cls"), - nb::arg("element_type"), nb::arg("ctx") = nb::none()); + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](const PyType &elementType, DefaultingPyMlirContext context) { + return TDMBaseType( + context->getRef(), + mlirAMDGPUTDMBaseTypeGet(context.get()->get(), elementType)); + }, + "Gets an instance of TDMBaseType in the same context", + nb::arg("element_type"), nb::arg("context").none() = nb::none()); + } +}; - auto amdgpuTDMDescriptorType = mlir_type_subclass( - m, "TDMDescriptorType", mlirTypeIsAAMDGPUTDMDescriptorType, - mlirAMDGPUTDMDescriptorTypeGetTypeID); +struct TDMDescriptorType : PyConcreteType { + static constexpr IsAFunctionTy isaFunction = + mlirTypeIsAAMDGPUTDMDescriptorType; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirAMDGPUTDMDescriptorTypeGetTypeID; + static constexpr const char *pyClassName = "TDMDescriptorType"; + using Base::Base; - amdgpuTDMDescriptorType.def_classmethod( - "get", - [](const nb::object &cls, MlirContext ctx) { - return cls(mlirAMDGPUTDMDescriptorTypeGet(ctx)); - }, - "Gets an instance of TDMDescriptorType in the same context", - nb::arg("cls"), nb::arg("ctx") = nb::none()); + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + return TDMDescriptorType( + context->getRef(), + mlirAMDGPUTDMDescriptorTypeGet(context.get()->get())); + }, + "Gets an instance of TDMDescriptorType in the same context", + nb::arg("context").none() = nb::none()); + } +}; - auto amdgpuTDMGatherBaseType = mlir_type_subclass( - m, "TDMGatherBaseType", mlirTypeIsAAMDGPUTDMGatherBaseType, - mlirAMDGPUTDMGatherBaseTypeGetTypeID); +struct TDMGatherBaseType : PyConcreteType { + static constexpr IsAFunctionTy isaFunction = + mlirTypeIsAAMDGPUTDMGatherBaseType; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirAMDGPUTDMGatherBaseTypeGetTypeID; + static constexpr const char *pyClassName = "TDMGatherBaseType"; + using Base::Base; - amdgpuTDMGatherBaseType.def_classmethod( - "get", - [](const nb::object &cls, MlirType elementType, MlirType indexType, - MlirContext ctx) { - return cls(mlirAMDGPUTDMGatherBaseTypeGet(ctx, elementType, indexType)); - }, - "Gets an instance of TDMGatherBaseType in the same context", - nb::arg("cls"), nb::arg("element_type"), nb::arg("index_type"), - nb::arg("ctx") = nb::none()); + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](const PyType &elementType, const PyType &indexType, + DefaultingPyMlirContext context) { + return TDMGatherBaseType( + context->getRef(), + mlirAMDGPUTDMGatherBaseTypeGet(context.get()->get(), elementType, + indexType)); + }, + "Gets an instance of TDMGatherBaseType in the same context", + nb::arg("element_type"), nb::arg("index_type"), + nb::arg("context").none() = nb::none()); + } }; +static void populateDialectAMDGPUSubmodule(nb::module_ &m) { + TDMBaseType::bind(m); + TDMDescriptorType::bind(m); + TDMGatherBaseType::bind(m); +} +} // namespace amdgpu +} // namespace MLIR_BINDINGS_PYTHON_DOMAIN +} // namespace python +} // namespace mlir + NB_MODULE(_mlirDialectsAMDGPU, m) { m.doc() = "MLIR AMDGPU dialect."; - populateDialectAMDGPUSubmodule(m); + mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::amdgpu:: + populateDialectAMDGPUSubmodule(m); } diff --git a/mlir/lib/Bindings/Python/DialectGPU.cpp b/mlir/lib/Bindings/Python/DialectGPU.cpp index 2568d535edb5a..21475e9c8189d 100644 --- a/mlir/lib/Bindings/Python/DialectGPU.cpp +++ b/mlir/lib/Bindings/Python/DialectGPU.cpp @@ -9,83 +9,110 @@ #include "mlir-c/Dialect/GPU.h" #include "mlir-c/IR.h" #include "mlir-c/Support.h" +#include "mlir/Bindings/Python/IRAttributes.h" +#include "mlir/Bindings/Python/IRCore.h" #include "mlir/Bindings/Python/Nanobind.h" #include "mlir/Bindings/Python/NanobindAdaptors.h" namespace nb = nanobind; using namespace nanobind::literals; - -using namespace mlir; -using namespace mlir::python; using namespace mlir::python::nanobind_adaptors; +namespace mlir { +namespace python { +namespace MLIR_BINDINGS_PYTHON_DOMAIN { +namespace gpu { // ----------------------------------------------------------------------------- -// Module initialization. +// AsyncTokenType // ----------------------------------------------------------------------------- -NB_MODULE(_mlirDialectsGPU, m) { - m.doc() = "MLIR GPU Dialect"; - //===-------------------------------------------------------------------===// - // AsyncTokenType - //===-------------------------------------------------------------------===// +struct AsyncTokenType : PyConcreteType { + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAGPUAsyncTokenType; + static constexpr const char *pyClassName = "AsyncTokenType"; + using Base::Base; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + return AsyncTokenType(context->getRef(), + mlirGPUAsyncTokenTypeGet(context.get()->get())); + }, + "Gets an instance of AsyncTokenType in the same context", + nb::arg("context").none() = nb::none()); + } +}; + +//===-------------------------------------------------------------------===// +// ObjectAttr +//===-------------------------------------------------------------------===// + +struct ObjectAttr : PyConcreteAttribute { + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAGPUObjectAttr; + static constexpr const char *pyClassName = "ObjectAttr"; + using Base::Base; - auto mlirGPUAsyncTokenType = - mlir_type_subclass(m, "AsyncTokenType", mlirTypeIsAGPUAsyncTokenType); + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](const PyAttribute &target, uint32_t format, const nb::bytes &object, + std::optional mlirObjectProps, + std::optional mlirKernelsAttr, + DefaultingPyMlirContext context) { + MlirStringRef objectStrRef = mlirStringRefCreate( + static_cast(const_cast(object.data())), + object.size()); + return ObjectAttr( + context->getRef(), + mlirGPUObjectAttrGetWithKernels( + mlirAttributeGetContext(target), target, format, objectStrRef, + mlirObjectProps.has_value() ? *mlirObjectProps + : MlirAttribute{nullptr}, + mlirKernelsAttr.has_value() ? *mlirKernelsAttr + : MlirAttribute{nullptr})); + }, + "target"_a, "format"_a, "object"_a, "properties"_a = nb::none(), + "kernels"_a = nb::none(), "context"_a = nb::none(), + "Gets a gpu.object from parameters."); - mlirGPUAsyncTokenType.def_classmethod( - "get", - [](const nb::object &cls, MlirContext ctx) { - return cls(mlirGPUAsyncTokenTypeGet(ctx)); - }, - "Gets an instance of AsyncTokenType in the same context", nb::arg("cls"), - nb::arg("ctx") = nb::none()); + c.def_prop_ro("target", [](ObjectAttr &self) { + return PyAttribute(self.getContext(), mlirGPUObjectAttrGetTarget(self)); + }); + c.def_prop_ro("format", [](const ObjectAttr &self) { + return mlirGPUObjectAttrGetFormat(self); + }); + c.def_prop_ro("object", [](const ObjectAttr &self) { + MlirStringRef stringRef = mlirGPUObjectAttrGetObject(self); + return nb::bytes(stringRef.data, stringRef.length); + }); + c.def_prop_ro( + "properties", [](ObjectAttr &self) -> std::optional { + if (mlirGPUObjectAttrHasProperties(self)) + return PyDictAttribute(self.getContext(), + mlirGPUObjectAttrGetProperties(self)); + return std::nullopt; + }); + c.def_prop_ro("kernels", + [](ObjectAttr &self) -> std::optional { + if (mlirGPUObjectAttrHasKernels(self)) + return PyAttribute(self.getContext(), + mlirGPUObjectAttrGetKernels(self)); + return std::nullopt; + }); + } +}; +} // namespace gpu +} // namespace MLIR_BINDINGS_PYTHON_DOMAIN +} // namespace python +} // namespace mlir - //===-------------------------------------------------------------------===// - // ObjectAttr - //===-------------------------------------------------------------------===// +// ----------------------------------------------------------------------------- +// Module initialization. +// ----------------------------------------------------------------------------- + +NB_MODULE(_mlirDialectsGPU, m) { + m.doc() = "MLIR GPU Dialect"; - mlir_attribute_subclass(m, "ObjectAttr", mlirAttributeIsAGPUObjectAttr) - .def_classmethod( - "get", - [](const nb::object &cls, MlirAttribute target, uint32_t format, - const nb::bytes &object, - std::optional mlirObjectProps, - std::optional mlirKernelsAttr) { - MlirStringRef objectStrRef = mlirStringRefCreate( - static_cast(const_cast(object.data())), - object.size()); - return cls(mlirGPUObjectAttrGetWithKernels( - mlirAttributeGetContext(target), target, format, objectStrRef, - mlirObjectProps.has_value() ? *mlirObjectProps - : MlirAttribute{nullptr}, - mlirKernelsAttr.has_value() ? *mlirKernelsAttr - : MlirAttribute{nullptr})); - }, - "cls"_a, "target"_a, "format"_a, "object"_a, - "properties"_a = nb::none(), "kernels"_a = nb::none(), - "Gets a gpu.object from parameters.") - .def_property_readonly( - "target", - [](MlirAttribute self) { return mlirGPUObjectAttrGetTarget(self); }) - .def_property_readonly( - "format", - [](MlirAttribute self) { return mlirGPUObjectAttrGetFormat(self); }) - .def_property_readonly( - "object", - [](MlirAttribute self) { - MlirStringRef stringRef = mlirGPUObjectAttrGetObject(self); - return nb::bytes(stringRef.data, stringRef.length); - }) - .def_property_readonly("properties", - [](MlirAttribute self) -> nb::object { - if (mlirGPUObjectAttrHasProperties(self)) - return nb::cast( - mlirGPUObjectAttrGetProperties(self)); - return nb::none(); - }) - .def_property_readonly("kernels", [](MlirAttribute self) -> nb::object { - if (mlirGPUObjectAttrHasKernels(self)) - return nb::cast(mlirGPUObjectAttrGetKernels(self)); - return nb::none(); - }); + mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::gpu::AsyncTokenType::bind(m); + mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::gpu::ObjectAttr::bind(m); } diff --git a/mlir/lib/Bindings/Python/DialectLLVM.cpp b/mlir/lib/Bindings/Python/DialectLLVM.cpp index 05681cecf82b3..6056bb3265911 100644 --- a/mlir/lib/Bindings/Python/DialectLLVM.cpp +++ b/mlir/lib/Bindings/Python/DialectLLVM.cpp @@ -13,163 +13,204 @@ #include "mlir-c/Support.h" #include "mlir-c/Target/LLVMIR.h" #include "mlir/Bindings/Python/Diagnostics.h" +#include "mlir/Bindings/Python/IRCore.h" #include "mlir/Bindings/Python/Nanobind.h" #include "mlir/Bindings/Python/NanobindAdaptors.h" namespace nb = nanobind; using namespace nanobind::literals; - using namespace llvm; using namespace mlir; -using namespace mlir::python; using namespace mlir::python::nanobind_adaptors; -static void populateDialectLLVMSubmodule(nanobind::module_ &m) { - - //===--------------------------------------------------------------------===// - // StructType - //===--------------------------------------------------------------------===// - - auto llvmStructType = mlir_type_subclass( - m, "StructType", mlirTypeIsALLVMStructType, mlirLLVMStructTypeGetTypeID); - - llvmStructType - .def_classmethod( - "get_literal", - [](const nb::object &cls, const std::vector &elements, - bool packed, MlirLocation loc) { - CollectDiagnosticsToStringScope scope(mlirLocationGetContext(loc)); - - MlirType type = mlirLLVMStructTypeLiteralGetChecked( - loc, elements.size(), elements.data(), packed); - if (mlirTypeIsNull(type)) { - throw nb::value_error(scope.takeMessage().c_str()); - } - return cls(type); - }, - "cls"_a, "elements"_a, nb::kw_only(), "packed"_a = false, - "loc"_a = nb::none()) - .def_classmethod( - "get_literal_unchecked", - [](const nb::object &cls, const std::vector &elements, - bool packed, MlirContext context) { - CollectDiagnosticsToStringScope scope(context); - - MlirType type = mlirLLVMStructTypeLiteralGet( - context, elements.size(), elements.data(), packed); - if (mlirTypeIsNull(type)) { - throw nb::value_error(scope.takeMessage().c_str()); - } - return cls(type); - }, - "cls"_a, "elements"_a, nb::kw_only(), "packed"_a = false, - "context"_a = nb::none()); - - llvmStructType.def_classmethod( - "get_identified", - [](const nb::object &cls, const std::string &name, MlirContext context) { - return cls(mlirLLVMStructTypeIdentifiedGet( - context, mlirStringRefCreate(name.data(), name.size()))); - }, - "cls"_a, "name"_a, nb::kw_only(), "context"_a = nb::none()); +namespace mlir { +namespace python { +namespace MLIR_BINDINGS_PYTHON_DOMAIN { +namespace llvm { +//===--------------------------------------------------------------------===// +// StructType +//===--------------------------------------------------------------------===// + +struct StructType : PyConcreteType { + static constexpr IsAFunctionTy isaFunction = mlirTypeIsALLVMStructType; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirLLVMStructTypeGetTypeID; + static constexpr const char *pyClassName = "StructType"; + using Base::Base; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get_literal", + [](const std::vector &elements, bool packed, MlirLocation loc, + DefaultingPyMlirContext context) { + python::CollectDiagnosticsToStringScope scope( + mlirLocationGetContext(loc)); + std::vector elements_(elements.size()); + std::copy(elements.begin(), elements.end(), elements_.begin()); + + MlirType type = mlirLLVMStructTypeLiteralGetChecked( + loc, elements.size(), elements_.data(), packed); + if (mlirTypeIsNull(type)) { + throw nb::value_error(scope.takeMessage().c_str()); + } + return StructType(context->getRef(), type); + }, + "elements"_a, nb::kw_only(), "packed"_a = false, "loc"_a = nb::none(), + "context"_a = nb::none()); + + c.def_static( + "get_literal_unchecked", + [](const std::vector &elements, bool packed, + DefaultingPyMlirContext context) { + python::CollectDiagnosticsToStringScope scope(context.get()->get()); + + std::vector elements_(elements.size()); + std::copy(elements.begin(), elements.end(), elements_.begin()); + + MlirType type = mlirLLVMStructTypeLiteralGet( + context.get()->get(), elements.size(), elements_.data(), packed); + if (mlirTypeIsNull(type)) { + throw nb::value_error(scope.takeMessage().c_str()); + } + return StructType(context->getRef(), type); + }, + "elements"_a, nb::kw_only(), "packed"_a = false, + "context"_a = nb::none()); + + c.def_static( + "get_identified", + [](const std::string &name, DefaultingPyMlirContext context) { + return StructType(context->getRef(), + mlirLLVMStructTypeIdentifiedGet( + context.get()->get(), + mlirStringRefCreate(name.data(), name.size()))); + }, + "name"_a, nb::kw_only(), "context"_a = nb::none()); + + c.def_static( + "get_opaque", + [](const std::string &name, DefaultingPyMlirContext context) { + return StructType(context->getRef(), + mlirLLVMStructTypeOpaqueGet( + context.get()->get(), + mlirStringRefCreate(name.data(), name.size()))); + }, + "name"_a, "context"_a = nb::none()); + + c.def( + "set_body", + [](const StructType &self, const std::vector &elements, + bool packed) { + std::vector elements_(elements.size()); + std::copy(elements.begin(), elements.end(), elements_.begin()); + MlirLogicalResult result = mlirLLVMStructTypeSetBody( + self, elements.size(), elements_.data(), packed); + if (!mlirLogicalResultIsSuccess(result)) { + throw nb::value_error( + "Struct body already set to different content."); + } + }, + "elements"_a, nb::kw_only(), "packed"_a = false); + + c.def_static( + "new_identified", + [](const std::string &name, const std::vector &elements, + bool packed, DefaultingPyMlirContext context) { + std::vector elements_(elements.size()); + std::copy(elements.begin(), elements.end(), elements_.begin()); + return StructType(context->getRef(), + mlirLLVMStructTypeIdentifiedNewGet( + context.get()->get(), + mlirStringRefCreate(name.data(), name.length()), + elements.size(), elements_.data(), packed)); + }, + "name"_a, "elements"_a, nb::kw_only(), "packed"_a = false, + "context"_a = nb::none()); + + c.def_prop_ro( + "name", [](const StructType &type) -> std::optional { + if (mlirLLVMStructTypeIsLiteral(type)) + return std::nullopt; + + MlirStringRef stringRef = mlirLLVMStructTypeGetIdentifier(type); + return StringRef(stringRef.data, stringRef.length).str(); + }); + + c.def_prop_ro("body", [](const StructType &type) -> nb::object { + // Don't crash in absence of a body. + if (mlirLLVMStructTypeIsOpaque(type)) + return nb::none(); + + nb::list body; + for (intptr_t i = 0, e = mlirLLVMStructTypeGetNumElementTypes(type); + i < e; ++i) { + body.append(mlirLLVMStructTypeGetElementType(type, i)); + } + return body; + }); + + c.def_prop_ro("packed", [](const StructType &type) { + return mlirLLVMStructTypeIsPacked(type); + }); + + c.def_prop_ro("opaque", [](const StructType &type) { + return mlirLLVMStructTypeIsOpaque(type); + }); + } +}; + +//===--------------------------------------------------------------------===// +// PointerType +//===--------------------------------------------------------------------===// + +struct PointerType : PyConcreteType { + static constexpr IsAFunctionTy isaFunction = mlirTypeIsALLVMPointerType; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirLLVMPointerTypeGetTypeID; + static constexpr const char *pyClassName = "PointerType"; + using Base::Base; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](std::optional addressSpace, + DefaultingPyMlirContext context) { + python::CollectDiagnosticsToStringScope scope(context.get()->get()); + MlirType type = mlirLLVMPointerTypeGet( + context.get()->get(), + addressSpace.has_value() ? *addressSpace : 0); + if (mlirTypeIsNull(type)) { + throw nb::value_error(scope.takeMessage().c_str()); + } + return PointerType(context->getRef(), type); + }, + "address_space"_a = nb::none(), nb::kw_only(), + "context"_a = nb::none()); + c.def_prop_ro("address_space", [](const PointerType &type) { + return mlirLLVMPointerTypeGetAddressSpace(type); + }); + } +}; - llvmStructType.def_classmethod( - "get_opaque", - [](const nb::object &cls, const std::string &name, MlirContext context) { - return cls(mlirLLVMStructTypeOpaqueGet( - context, mlirStringRefCreate(name.data(), name.size()))); - }, - "cls"_a, "name"_a, "context"_a = nb::none()); - - llvmStructType.def( - "set_body", - [](MlirType self, const std::vector &elements, bool packed) { - MlirLogicalResult result = mlirLLVMStructTypeSetBody( - self, elements.size(), elements.data(), packed); - if (!mlirLogicalResultIsSuccess(result)) { - throw nb::value_error( - "Struct body already set to different content."); - } - }, - "elements"_a, nb::kw_only(), "packed"_a = false); - - llvmStructType.def_classmethod( - "new_identified", - [](const nb::object &cls, const std::string &name, - const std::vector &elements, bool packed, MlirContext ctx) { - return cls(mlirLLVMStructTypeIdentifiedNewGet( - ctx, mlirStringRefCreate(name.data(), name.length()), - elements.size(), elements.data(), packed)); - }, - "cls"_a, "name"_a, "elements"_a, nb::kw_only(), "packed"_a = false, - "context"_a = nb::none()); - - llvmStructType.def_property_readonly( - "name", [](MlirType type) -> std::optional { - if (mlirLLVMStructTypeIsLiteral(type)) - return std::nullopt; - - MlirStringRef stringRef = mlirLLVMStructTypeGetIdentifier(type); - return StringRef(stringRef.data, stringRef.length).str(); - }); - - llvmStructType.def_property_readonly("body", [](MlirType type) -> nb::object { - // Don't crash in absence of a body. - if (mlirLLVMStructTypeIsOpaque(type)) - return nb::none(); - - nb::list body; - for (intptr_t i = 0, e = mlirLLVMStructTypeGetNumElementTypes(type); i < e; - ++i) { - body.append(mlirLLVMStructTypeGetElementType(type, i)); - } - return body; - }); - - llvmStructType.def_property_readonly( - "packed", [](MlirType type) { return mlirLLVMStructTypeIsPacked(type); }); - - llvmStructType.def_property_readonly( - "opaque", [](MlirType type) { return mlirLLVMStructTypeIsOpaque(type); }); - - //===--------------------------------------------------------------------===// - // PointerType - //===--------------------------------------------------------------------===// - - mlir_type_subclass(m, "PointerType", mlirTypeIsALLVMPointerType, - mlirLLVMPointerTypeGetTypeID) - .def_classmethod( - "get", - [](const nb::object &cls, std::optional addressSpace, - MlirContext context) { - CollectDiagnosticsToStringScope scope(context); - MlirType type = mlirLLVMPointerTypeGet( - context, addressSpace.has_value() ? *addressSpace : 0); - if (mlirTypeIsNull(type)) { - throw nb::value_error(scope.takeMessage().c_str()); - } - return cls(type); - }, - "cls"_a, "address_space"_a = nb::none(), nb::kw_only(), - "context"_a = nb::none()) - .def_property_readonly("address_space", [](MlirType type) { - return mlirLLVMPointerTypeGetAddressSpace(type); - }); +static void populateDialectLLVMSubmodule(nanobind::module_ &m) { + StructType::bind(m); + PointerType::bind(m); m.def( "translate_module_to_llvmir", - [](MlirOperation module) { + [](const PyOperation &module) { return mlirTranslateModuleToLLVMIRToString(module); }, - // clang-format off - nb::sig("def translate_module_to_llvmir(module: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ") -> str"), - // clang-format on "module"_a, nb::rv_policy::take_ownership); } +} // namespace llvm +} // namespace MLIR_BINDINGS_PYTHON_DOMAIN +} // namespace python +} // namespace mlir NB_MODULE(_mlirDialectsLLVM, m) { m.doc() = "MLIR LLVM Dialect"; - populateDialectLLVMSubmodule(m); + python::mlir::llvm::populateDialectLLVMSubmodule(m); } diff --git a/mlir/lib/Bindings/Python/DialectNVGPU.cpp b/mlir/lib/Bindings/Python/DialectNVGPU.cpp index 18917416412c1..325db6d41c460 100644 --- a/mlir/lib/Bindings/Python/DialectNVGPU.cpp +++ b/mlir/lib/Bindings/Python/DialectNVGPU.cpp @@ -8,34 +8,48 @@ #include "mlir-c/Dialect/NVGPU.h" #include "mlir-c/IR.h" +#include "mlir/Bindings/Python/IRCore.h" #include "mlir/Bindings/Python/Nanobind.h" #include "mlir/Bindings/Python/NanobindAdaptors.h" namespace nb = nanobind; using namespace llvm; -using namespace mlir; -using namespace mlir::python; using namespace mlir::python::nanobind_adaptors; -static void populateDialectNVGPUSubmodule(const nb::module_ &m) { - auto nvgpuTensorMapDescriptorType = mlir_type_subclass( - m, "TensorMapDescriptorType", mlirTypeIsANVGPUTensorMapDescriptorType); +namespace mlir { +namespace python { +namespace MLIR_BINDINGS_PYTHON_DOMAIN { +namespace nvgpu { +struct TensorMapDescriptorType : PyConcreteType { + static constexpr IsAFunctionTy isaFunction = + mlirTypeIsANVGPUTensorMapDescriptorType; + static constexpr const char *pyClassName = "TensorMapDescriptorType"; + using Base::Base; - nvgpuTensorMapDescriptorType.def_classmethod( - "get", - [](const nb::object &cls, MlirType tensorMemrefType, int swizzle, - int l2promo, int oobFill, int interleave, MlirContext ctx) { - return cls(mlirNVGPUTensorMapDescriptorTypeGet( - ctx, tensorMemrefType, swizzle, l2promo, oobFill, interleave)); - }, - "Gets an instance of TensorMapDescriptorType in the same context", - nb::arg("cls"), nb::arg("tensor_type"), nb::arg("swizzle"), - nb::arg("l2promo"), nb::arg("oob_fill"), nb::arg("interleave"), - nb::arg("ctx") = nb::none()); -} + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](const PyType &tensorMemrefType, int swizzle, int l2promo, + int oobFill, int interleave, DefaultingPyMlirContext context) { + return TensorMapDescriptorType( + context->getRef(), mlirNVGPUTensorMapDescriptorTypeGet( + context.get()->get(), tensorMemrefType, + swizzle, l2promo, oobFill, interleave)); + }, + "Gets an instance of TensorMapDescriptorType in the same context", + nb::arg("tensor_type"), nb::arg("swizzle"), nb::arg("l2promo"), + nb::arg("oob_fill"), nb::arg("interleave"), + nb::arg("context").none() = nb::none()); + } +}; +} // namespace nvgpu +} // namespace MLIR_BINDINGS_PYTHON_DOMAIN +} // namespace python +} // namespace mlir NB_MODULE(_mlirDialectsNVGPU, m) { m.doc() = "MLIR NVGPU dialect."; - populateDialectNVGPUSubmodule(m); + mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::nvgpu::TensorMapDescriptorType:: + bind(m); } diff --git a/mlir/lib/Bindings/Python/DialectPDL.cpp b/mlir/lib/Bindings/Python/DialectPDL.cpp index 1acb41080f711..ac72734ea5c21 100644 --- a/mlir/lib/Bindings/Python/DialectPDL.cpp +++ b/mlir/lib/Bindings/Python/DialectPDL.cpp @@ -8,98 +8,158 @@ #include "mlir-c/Dialect/PDL.h" #include "mlir-c/IR.h" +#include "mlir/Bindings/Python/IRCore.h" #include "mlir/Bindings/Python/Nanobind.h" #include "mlir/Bindings/Python/NanobindAdaptors.h" namespace nb = nanobind; using namespace llvm; -using namespace mlir; -using namespace mlir::python; using namespace mlir::python::nanobind_adaptors; -static void populateDialectPDLSubmodule(const nanobind::module_ &m) { - //===-------------------------------------------------------------------===// - // PDLType - //===-------------------------------------------------------------------===// - - auto pdlType = mlir_type_subclass(m, "PDLType", mlirTypeIsAPDLType); - - //===-------------------------------------------------------------------===// - // AttributeType - //===-------------------------------------------------------------------===// - - auto attributeType = - mlir_type_subclass(m, "AttributeType", mlirTypeIsAPDLAttributeType); - attributeType.def_classmethod( - "get", - [](const nb::object &cls, MlirContext ctx) { - return cls(mlirPDLAttributeTypeGet(ctx)); - }, - "Get an instance of AttributeType in given context.", nb::arg("cls"), - nb::arg("context") = nb::none()); - - //===-------------------------------------------------------------------===// - // OperationType - //===-------------------------------------------------------------------===// - - auto operationType = - mlir_type_subclass(m, "OperationType", mlirTypeIsAPDLOperationType); - operationType.def_classmethod( - "get", - [](const nb::object &cls, MlirContext ctx) { - return cls(mlirPDLOperationTypeGet(ctx)); - }, - "Get an instance of OperationType in given context.", nb::arg("cls"), - nb::arg("context") = nb::none()); - - //===-------------------------------------------------------------------===// - // RangeType - //===-------------------------------------------------------------------===// - - auto rangeType = mlir_type_subclass(m, "RangeType", mlirTypeIsAPDLRangeType); - rangeType.def_classmethod( - "get", - [](const nb::object &cls, MlirType elementType) { - return cls(mlirPDLRangeTypeGet(elementType)); - }, - "Gets an instance of RangeType in the same context as the provided " - "element type.", - nb::arg("cls"), nb::arg("element_type")); - rangeType.def_property_readonly( - "element_type", - [](MlirType type) { return mlirPDLRangeTypeGetElementType(type); }, - nb::sig( - "def element_type(self) -> " MAKE_MLIR_PYTHON_QUALNAME("ir.Type")), - "Get the element type."); - - //===-------------------------------------------------------------------===// - // TypeType - //===-------------------------------------------------------------------===// - - auto typeType = mlir_type_subclass(m, "TypeType", mlirTypeIsAPDLTypeType); - typeType.def_classmethod( - "get", - [](const nb::object &cls, MlirContext ctx) { - return cls(mlirPDLTypeTypeGet(ctx)); - }, - "Get an instance of TypeType in given context.", nb::arg("cls"), - nb::arg("context") = nb::none()); - - //===-------------------------------------------------------------------===// - // ValueType - //===-------------------------------------------------------------------===// - - auto valueType = mlir_type_subclass(m, "ValueType", mlirTypeIsAPDLValueType); - valueType.def_classmethod( - "get", - [](const nb::object &cls, MlirContext ctx) { - return cls(mlirPDLValueTypeGet(ctx)); - }, - "Get an instance of TypeType in given context.", nb::arg("cls"), - nb::arg("context") = nb::none()); +namespace mlir { +namespace python { +namespace MLIR_BINDINGS_PYTHON_DOMAIN { +namespace pdl { + +//===-------------------------------------------------------------------===// +// PDLType +//===-------------------------------------------------------------------===// + +struct PDLType : PyConcreteType { + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAPDLType; + static constexpr const char *pyClassName = "PDLType"; + using Base::Base; + + static void bindDerived(ClassTy &c) {} +}; + +//===-------------------------------------------------------------------===// +// AttributeType +//===-------------------------------------------------------------------===// + +struct AttributeType : PyConcreteType { + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAPDLAttributeType; + static constexpr const char *pyClassName = "AttributeType"; + using Base::Base; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + return AttributeType(context->getRef(), + mlirPDLAttributeTypeGet(context.get()->get())); + }, + "Get an instance of AttributeType in given context.", + nb::arg("context").none() = nb::none()); + } +}; + +//===-------------------------------------------------------------------===// +// OperationType +//===-------------------------------------------------------------------===// + +struct OperationType : PyConcreteType { + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAPDLOperationType; + static constexpr const char *pyClassName = "OperationType"; + using Base::Base; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + return OperationType(context->getRef(), + mlirPDLOperationTypeGet(context.get()->get())); + }, + "Get an instance of OperationType in given context.", + nb::arg("context").none() = nb::none()); + } +}; + +//===-------------------------------------------------------------------===// +// RangeType +//===-------------------------------------------------------------------===// + +struct RangeType : PyConcreteType { + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAPDLRangeType; + static constexpr const char *pyClassName = "RangeType"; + using Base::Base; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](const PyType &elementType, DefaultingPyMlirContext context) { + return RangeType(context->getRef(), mlirPDLRangeTypeGet(elementType)); + }, + "Gets an instance of RangeType in the same context as the provided " + "element type.", + nb::arg("element_type"), nb::arg("context").none() = nb::none()); + c.def_prop_ro( + "element_type", + [](RangeType &type) { + return PyType(type.getContext(), + mlirPDLRangeTypeGetElementType(type)); + }, + "Get the element type."); + } +}; + +//===-------------------------------------------------------------------===// +// TypeType +//===-------------------------------------------------------------------===// + +struct TypeType : PyConcreteType { + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAPDLTypeType; + static constexpr const char *pyClassName = "TypeType"; + using Base::Base; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + return TypeType(context->getRef(), + mlirPDLTypeTypeGet(context.get()->get())); + }, + "Get an instance of TypeType in given context.", + nb::arg("context").none() = nb::none()); + } +}; + +//===-------------------------------------------------------------------===// +// ValueType +//===-------------------------------------------------------------------===// + +struct ValueType : PyConcreteType { + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAPDLValueType; + static constexpr const char *pyClassName = "ValueType"; + using Base::Base; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + return ValueType(context->getRef(), + mlirPDLValueTypeGet(context.get()->get())); + }, + "Get an instance of TypeType in given context.", + nb::arg("context").none() = nb::none()); + } +}; + +static void populateDialectPDLSubmodule(nanobind::module_ &m) { + PDLType::bind(m); + AttributeType::bind(m); + OperationType::bind(m); + RangeType::bind(m); + TypeType::bind(m); + ValueType::bind(m); } +} // namespace pdl +} // namespace MLIR_BINDINGS_PYTHON_DOMAIN +} // namespace python +} // namespace mlir NB_MODULE(_mlirDialectsPDL, m) { m.doc() = "MLIR PDL dialect."; - populateDialectPDLSubmodule(m); + mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::pdl::populateDialectPDLSubmodule( + m); } diff --git a/mlir/lib/Bindings/Python/DialectQuant.cpp b/mlir/lib/Bindings/Python/DialectQuant.cpp index a5220fcc00604..9c6a15c97134d 100644 --- a/mlir/lib/Bindings/Python/DialectQuant.cpp +++ b/mlir/lib/Bindings/Python/DialectQuant.cpp @@ -6,385 +6,495 @@ // //===----------------------------------------------------------------------===// -#include #include -#include "mlir-c/BuiltinAttributes.h" #include "mlir-c/Dialect/Quant.h" #include "mlir-c/IR.h" +#include "mlir/Bindings/Python/IRCore.h" #include "mlir/Bindings/Python/Nanobind.h" #include "mlir/Bindings/Python/NanobindAdaptors.h" +#include + namespace nb = nanobind; using namespace llvm; -using namespace mlir; using namespace mlir::python::nanobind_adaptors; -static void populateDialectQuantSubmodule(const nb::module_ &m) { - //===-------------------------------------------------------------------===// - // QuantizedType - //===-------------------------------------------------------------------===// +namespace mlir { +namespace python { +namespace MLIR_BINDINGS_PYTHON_DOMAIN { +namespace quant { +//===-------------------------------------------------------------------===// +// QuantizedType +//===-------------------------------------------------------------------===// + +struct QuantizedType : PyConcreteType { + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAQuantizedType; + static constexpr const char *pyClassName = "QuantizedType"; + using Base::Base; + + static void bindDerived(ClassTy &c) { + c.def_static( + "default_minimum_for_integer", + [](bool isSigned, unsigned integralWidth) { + return mlirQuantizedTypeGetDefaultMinimumForInteger(isSigned, + integralWidth); + }, + "Default minimum value for the integer with the specified signedness " + "and " + "bit width.", + nb::arg("is_signed"), nb::arg("integral_width")); + c.def_static( + "default_maximum_for_integer", + [](bool isSigned, unsigned integralWidth) { + return mlirQuantizedTypeGetDefaultMaximumForInteger(isSigned, + integralWidth); + }, + "Default maximum value for the integer with the specified signedness " + "and " + "bit width.", + nb::arg("is_signed"), nb::arg("integral_width")); + c.def_prop_ro( + "expressed_type", + [](QuantizedType &type) { + return PyType(type.getContext(), + mlirQuantizedTypeGetExpressedType(type)); + }, + "Type expressed by this quantized type."); + c.def_prop_ro( + "flags", + [](const QuantizedType &type) { + return mlirQuantizedTypeGetFlags(type); + }, + "Flags of this quantized type (named accessors should be preferred to " + "this)"); + c.def_prop_ro( + "is_signed", + [](const QuantizedType &type) { + return mlirQuantizedTypeIsSigned(type); + }, + "Signedness of this quantized type."); + c.def_prop_ro( + "storage_type", + [](QuantizedType &type) { + return PyType(type.getContext(), + mlirQuantizedTypeGetStorageType(type)); + }, + "Storage type backing this quantized type."); + c.def_prop_ro( + "storage_type_min", + [](const QuantizedType &type) { + return mlirQuantizedTypeGetStorageTypeMin(type); + }, + "The minimum value held by the storage type of this quantized type."); + c.def_prop_ro( + "storage_type_max", + [](const QuantizedType &type) { + return mlirQuantizedTypeGetStorageTypeMax(type); + }, + "The maximum value held by the storage type of this quantized type."); + c.def_prop_ro( + "storage_type_integral_width", + [](const QuantizedType &type) { + return mlirQuantizedTypeGetStorageTypeIntegralWidth(type); + }, + "The bitwidth of the storage type of this quantized type."); + c.def( + "is_compatible_expressed_type", + [](const QuantizedType &type, const PyType &candidate) { + return mlirQuantizedTypeIsCompatibleExpressedType(type, candidate); + }, + "Checks whether the candidate type can be expressed by this quantized " + "type.", + nb::arg("candidate")); + c.def_prop_ro( + "quantized_element_type", + [](QuantizedType &type) { + return PyType(type.getContext(), + mlirQuantizedTypeGetQuantizedElementType(type)); + }, + "Element type of this quantized type expressed as quantized type."); + c.def( + "cast_from_storage_type", + [](QuantizedType &type, const PyType &candidate) { + MlirType castResult = + mlirQuantizedTypeCastFromStorageType(type, candidate); + if (!mlirTypeIsNull(castResult)) + return QuantizedType(type.getContext(), castResult); + throw nb::type_error("Invalid cast."); + }, + "Casts from a type based on the storage type of this quantized type to " + "a " + "corresponding type based on the quantized type. Raises TypeError if " + "the " + "cast is not valid.", + nb::arg("candidate")); + c.def_static( + "cast_to_storage_type", + [](QuantizedType &type) { + MlirType castResult = mlirQuantizedTypeCastToStorageType(type); + if (!mlirTypeIsNull(castResult)) + return PyType(type.getContext(), castResult); + throw nb::type_error("Invalid cast."); + }, + "Casts from a type based on a quantized type to a corresponding type " + "based on the storage type of this quantized type. Raises TypeError if " + "the cast is not valid.", + nb::arg("type")); + c.def( + "cast_from_expressed_type", + [](QuantizedType &type, const PyType &candidate) { + MlirType castResult = + mlirQuantizedTypeCastFromExpressedType(type, candidate); + if (!mlirTypeIsNull(castResult)) + return PyType(type.getContext(), castResult); + throw nb::type_error("Invalid cast."); + }, + "Casts from a type based on the expressed type of this quantized type " + "to " + "a corresponding type based on the quantized type. Raises TypeError if " + "the cast is not valid.", + nb::arg("candidate")); + c.def_static( + "cast_to_expressed_type", + [](QuantizedType &type) { + MlirType castResult = mlirQuantizedTypeCastToExpressedType(type); + if (!mlirTypeIsNull(castResult)) + return PyType(type.getContext(), castResult); + throw nb::type_error("Invalid cast."); + }, + "Casts from a type based on a quantized type to a corresponding type " + "based on the expressed type of this quantized type. Raises TypeError " + "if " + "the cast is not valid.", + nb::arg("type")); + c.def( + "cast_expressed_to_storage_type", + [](QuantizedType &type, const PyType &candidate) { + MlirType castResult = + mlirQuantizedTypeCastExpressedToStorageType(type, candidate); + if (!mlirTypeIsNull(castResult)) + return PyType(type.getContext(), castResult); + throw nb::type_error("Invalid cast."); + }, + "Casts from a type based on the expressed type of this quantized type " + "to " + "a corresponding type based on the storage type. Raises TypeError if " + "the " + "cast is not valid.", + nb::arg("candidate")); + } +}; + +//===-------------------------------------------------------------------===// +// AnyQuantizedType +//===-------------------------------------------------------------------===// + +struct AnyQuantizedType : PyConcreteType { + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAAnyQuantizedType; + static constexpr const char *pyClassName = "AnyQuantizedType"; + using Base::Base; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](unsigned flags, const PyType &storageType, + const PyType &expressedType, int64_t storageTypeMin, + int64_t storageTypeMax, DefaultingPyMlirContext context) { + return AnyQuantizedType( + context->getRef(), + mlirAnyQuantizedTypeGet(flags, storageType, expressedType, + storageTypeMin, storageTypeMax)); + }, + "Gets an instance of AnyQuantizedType in the same context as the " + "provided storage type.", + nb::arg("flags"), nb::arg("storage_type"), nb::arg("expressed_type"), + nb::arg("storage_type_min"), nb::arg("storage_type_max"), + nb::arg("context") = nb::none()); + } +}; + +//===-------------------------------------------------------------------===// +// UniformQuantizedType +//===-------------------------------------------------------------------===// + +struct UniformQuantizedType + : PyConcreteType { + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUniformQuantizedType; + static constexpr const char *pyClassName = "UniformQuantizedType"; + using Base::Base; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](unsigned flags, const PyType &storageType, + const PyType &expressedType, double scale, int64_t zeroPoint, + int64_t storageTypeMin, int64_t storageTypeMax, + DefaultingPyMlirContext context) { + return UniformQuantizedType( + context->getRef(), + mlirUniformQuantizedTypeGet(flags, storageType, expressedType, + scale, zeroPoint, storageTypeMin, + storageTypeMax)); + }, + "Gets an instance of UniformQuantizedType in the same context as the " + "provided storage type.", + nb::arg("flags"), nb::arg("storage_type"), nb::arg("expressed_type"), + nb::arg("scale"), nb::arg("zero_point"), nb::arg("storage_type_min"), + nb::arg("storage_type_max"), nb::arg("context") = nb::none()); + c.def_prop_ro( + "scale", + [](const UniformQuantizedType &type) { + return mlirUniformQuantizedTypeGetScale(type); + }, + "The scale designates the difference between the real values " + "corresponding to consecutive quantized values differing by 1."); + c.def_prop_ro( + "zero_point", + [](const UniformQuantizedType &type) { + return mlirUniformQuantizedTypeGetZeroPoint(type); + }, + "The storage value corresponding to the real value 0 in the affine " + "equation."); + c.def_prop_ro( + "is_fixed_point", + [](const UniformQuantizedType &type) { + return mlirUniformQuantizedTypeIsFixedPoint(type); + }, + "Fixed point values are real numbers divided by a scale."); + } +}; + +//===-------------------------------------------------------------------===// +// UniformQuantizedPerAxisType +//===-------------------------------------------------------------------===// + +struct UniformQuantizedPerAxisType + : PyConcreteType { + static constexpr IsAFunctionTy isaFunction = + mlirTypeIsAUniformQuantizedPerAxisType; + static constexpr const char *pyClassName = "UniformQuantizedPerAxisType"; + using Base::Base; - auto quantizedType = - mlir_type_subclass(m, "QuantizedType", mlirTypeIsAQuantizedType); - quantizedType.def_staticmethod( - "default_minimum_for_integer", - [](bool isSigned, unsigned integralWidth) { - return mlirQuantizedTypeGetDefaultMinimumForInteger(isSigned, - integralWidth); - }, - "Default minimum value for the integer with the specified signedness and " - "bit width.", - nb::arg("is_signed"), nb::arg("integral_width")); - quantizedType.def_staticmethod( - "default_maximum_for_integer", - [](bool isSigned, unsigned integralWidth) { - return mlirQuantizedTypeGetDefaultMaximumForInteger(isSigned, - integralWidth); - }, - "Default maximum value for the integer with the specified signedness and " - "bit width.", - nb::arg("is_signed"), nb::arg("integral_width")); - quantizedType.def_property_readonly( - "expressed_type", - [](MlirType type) { return mlirQuantizedTypeGetExpressedType(type); }, - "Type expressed by this quantized type."); - quantizedType.def_property_readonly( - "flags", [](MlirType type) { return mlirQuantizedTypeGetFlags(type); }, - "Flags of this quantized type (named accessors should be preferred to " - "this)"); - quantizedType.def_property_readonly( - "is_signed", - [](MlirType type) { return mlirQuantizedTypeIsSigned(type); }, - "Signedness of this quantized type."); - quantizedType.def_property_readonly( - "storage_type", - [](MlirType type) { return mlirQuantizedTypeGetStorageType(type); }, - "Storage type backing this quantized type."); - quantizedType.def_property_readonly( - "storage_type_min", - [](MlirType type) { return mlirQuantizedTypeGetStorageTypeMin(type); }, - "The minimum value held by the storage type of this quantized type."); - quantizedType.def_property_readonly( - "storage_type_max", - [](MlirType type) { return mlirQuantizedTypeGetStorageTypeMax(type); }, - "The maximum value held by the storage type of this quantized type."); - quantizedType.def_property_readonly( - "storage_type_integral_width", - [](MlirType type) { - return mlirQuantizedTypeGetStorageTypeIntegralWidth(type); - }, - "The bitwidth of the storage type of this quantized type."); - quantizedType.def( - "is_compatible_expressed_type", - [](MlirType type, MlirType candidate) { - return mlirQuantizedTypeIsCompatibleExpressedType(type, candidate); - }, - "Checks whether the candidate type can be expressed by this quantized " - "type.", - nb::arg("candidate")); - quantizedType.def_property_readonly( - "quantized_element_type", - [](MlirType type) { - return mlirQuantizedTypeGetQuantizedElementType(type); - }, - "Element type of this quantized type expressed as quantized type."); - quantizedType.def( - "cast_from_storage_type", - [](MlirType type, MlirType candidate) { - MlirType castResult = - mlirQuantizedTypeCastFromStorageType(type, candidate); - if (!mlirTypeIsNull(castResult)) - return castResult; - throw nb::type_error("Invalid cast."); - }, - "Casts from a type based on the storage type of this quantized type to a " - "corresponding type based on the quantized type. Raises TypeError if the " - "cast is not valid.", - nb::arg("candidate")); - quantizedType.def_staticmethod( - "cast_to_storage_type", - [](MlirType type) { - MlirType castResult = mlirQuantizedTypeCastToStorageType(type); - if (!mlirTypeIsNull(castResult)) - return castResult; - throw nb::type_error("Invalid cast."); - }, - "Casts from a type based on a quantized type to a corresponding type " - "based on the storage type of this quantized type. Raises TypeError if " - "the cast is not valid.", - nb::arg("type")); - quantizedType.def( - "cast_from_expressed_type", - [](MlirType type, MlirType candidate) { - MlirType castResult = - mlirQuantizedTypeCastFromExpressedType(type, candidate); - if (!mlirTypeIsNull(castResult)) - return castResult; - throw nb::type_error("Invalid cast."); - }, - "Casts from a type based on the expressed type of this quantized type to " - "a corresponding type based on the quantized type. Raises TypeError if " - "the cast is not valid.", - nb::arg("candidate")); - quantizedType.def_staticmethod( - "cast_to_expressed_type", - [](MlirType type) { - MlirType castResult = mlirQuantizedTypeCastToExpressedType(type); - if (!mlirTypeIsNull(castResult)) - return castResult; - throw nb::type_error("Invalid cast."); - }, - "Casts from a type based on a quantized type to a corresponding type " - "based on the expressed type of this quantized type. Raises TypeError if " - "the cast is not valid.", - nb::arg("type")); - quantizedType.def( - "cast_expressed_to_storage_type", - [](MlirType type, MlirType candidate) { - MlirType castResult = - mlirQuantizedTypeCastExpressedToStorageType(type, candidate); - if (!mlirTypeIsNull(castResult)) - return castResult; - throw nb::type_error("Invalid cast."); - }, - "Casts from a type based on the expressed type of this quantized type to " - "a corresponding type based on the storage type. Raises TypeError if the " - "cast is not valid.", - nb::arg("candidate")); + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](unsigned flags, const PyType &storageType, + const PyType &expressedType, std::vector scales, + std::vector zeroPoints, int32_t quantizedDimension, + int64_t storageTypeMin, int64_t storageTypeMax, + DefaultingPyMlirContext context) { + if (scales.size() != zeroPoints.size()) + throw nb::value_error( + "Mismatching number of scales and zero points."); + auto nDims = static_cast(scales.size()); + return UniformQuantizedPerAxisType( + context->getRef(), + mlirUniformQuantizedPerAxisTypeGet( + flags, storageType, expressedType, nDims, scales.data(), + zeroPoints.data(), quantizedDimension, storageTypeMin, + storageTypeMax)); + }, + "Gets an instance of UniformQuantizedPerAxisType in the same context " + "as " + "the provided storage type.", + nb::arg("flags"), nb::arg("storage_type"), nb::arg("expressed_type"), + nb::arg("scales"), nb::arg("zero_points"), + nb::arg("quantized_dimension"), nb::arg("storage_type_min"), + nb::arg("storage_type_max"), nb::arg("context") = nb::none()); + c.def_prop_ro( + "scales", + [](const UniformQuantizedPerAxisType &type) { + intptr_t nDim = mlirUniformQuantizedPerAxisTypeGetNumDims(type); + std::vector scales; + scales.reserve(nDim); + for (intptr_t i = 0; i < nDim; ++i) { + double scale = mlirUniformQuantizedPerAxisTypeGetScale(type, i); + scales.push_back(scale); + } + return scales; + }, + "The scales designate the difference between the real values " + "corresponding to consecutive quantized values differing by 1. The ith " + "scale corresponds to the ith slice in the quantized_dimension."); + c.def_prop_ro( + "zero_points", + [](const UniformQuantizedPerAxisType &type) { + intptr_t nDim = mlirUniformQuantizedPerAxisTypeGetNumDims(type); + std::vector zeroPoints; + zeroPoints.reserve(nDim); + for (intptr_t i = 0; i < nDim; ++i) { + int64_t zeroPoint = + mlirUniformQuantizedPerAxisTypeGetZeroPoint(type, i); + zeroPoints.push_back(zeroPoint); + } + return zeroPoints; + }, + "the storage values corresponding to the real value 0 in the affine " + "equation. The ith zero point corresponds to the ith slice in the " + "quantized_dimension."); + c.def_prop_ro( + "quantized_dimension", + [](const UniformQuantizedPerAxisType &type) { + return mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(type); + }, + "Specifies the dimension of the shape that the scales and zero points " + "correspond to."); + c.def_prop_ro( + "is_fixed_point", + [](const UniformQuantizedPerAxisType &type) { + return mlirUniformQuantizedPerAxisTypeIsFixedPoint(type); + }, + "Fixed point values are real numbers divided by a scale."); + } +}; - quantizedType.get_class().attr("FLAG_SIGNED") = - mlirQuantizedTypeGetSignedFlag(); +//===-------------------------------------------------------------------===// +// UniformQuantizedSubChannelType +//===-------------------------------------------------------------------===// - //===-------------------------------------------------------------------===// - // AnyQuantizedType - //===-------------------------------------------------------------------===// +struct UniformQuantizedSubChannelType + : PyConcreteType { + static constexpr IsAFunctionTy isaFunction = + mlirTypeIsAUniformQuantizedSubChannelType; + static constexpr const char *pyClassName = "UniformQuantizedSubChannelType"; + using Base::Base; - auto anyQuantizedType = - mlir_type_subclass(m, "AnyQuantizedType", mlirTypeIsAAnyQuantizedType, - quantizedType.get_class()); - anyQuantizedType.def_classmethod( - "get", - [](const nb::object &cls, unsigned flags, MlirType storageType, - MlirType expressedType, int64_t storageTypeMin, - int64_t storageTypeMax) { - return cls(mlirAnyQuantizedTypeGet(flags, storageType, expressedType, - storageTypeMin, storageTypeMax)); - }, - "Gets an instance of AnyQuantizedType in the same context as the " - "provided storage type.", - nb::arg("cls"), nb::arg("flags"), nb::arg("storage_type"), - nb::arg("expressed_type"), nb::arg("storage_type_min"), - nb::arg("storage_type_max")); + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](unsigned flags, const PyType &storageType, + const PyType &expressedType, MlirAttribute scales, + MlirAttribute zeroPoints, std::vector quantizedDimensions, + std::vector blockSizes, int64_t storageTypeMin, + int64_t storageTypeMax, DefaultingPyMlirContext context) { + return UniformQuantizedSubChannelType( + context->getRef(), + mlirUniformQuantizedSubChannelTypeGet( + flags, storageType, expressedType, scales, zeroPoints, + static_cast(blockSizes.size()), + quantizedDimensions.data(), blockSizes.data(), storageTypeMin, + storageTypeMax)); + }, + "Gets an instance of UniformQuantizedSubChannel in the same context as " + "the provided storage type.", + nb::arg("flags"), nb::arg("storage_type"), nb::arg("expressed_type"), + nb::arg("scales"), nb::arg("zero_points"), + nb::arg("quantized_dimensions"), nb::arg("block_sizes"), + nb::arg("storage_type_min"), nb::arg("storage_type_max"), + nb::arg("context") = nb::none()); + c.def_prop_ro( + "quantized_dimensions", + [](const UniformQuantizedSubChannelType &type) { + intptr_t nDim = + mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(type); + std::vector quantizedDimensions; + quantizedDimensions.reserve(nDim); + for (intptr_t i = 0; i < nDim; ++i) { + quantizedDimensions.push_back( + mlirUniformQuantizedSubChannelTypeGetQuantizedDimension(type, + i)); + } + return quantizedDimensions; + }, + "Gets the quantized dimensions. Each element in the returned list " + "represents an axis of the quantized data tensor that has a specified " + "block size. The order of elements corresponds to the order of block " + "sizes returned by 'block_sizes' method. It means that the data tensor " + "is quantized along the i-th dimension in the returned list using the " + "i-th block size from block_sizes method."); + c.def_prop_ro( + "block_sizes", + [](const UniformQuantizedSubChannelType &type) { + intptr_t nDim = + mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(type); + std::vector blockSizes; + blockSizes.reserve(nDim); + for (intptr_t i = 0; i < nDim; ++i) { + blockSizes.push_back( + mlirUniformQuantizedSubChannelTypeGetBlockSize(type, i)); + } + return blockSizes; + }, + "Gets the block sizes for the quantized dimensions. The i-th element " + "in " + "the returned list corresponds to the block size for the i-th " + "dimension " + "in the list returned by quantized_dimensions method."); + c.def_prop_ro( + "scales", + [](UniformQuantizedSubChannelType &type) { + return PyDenseElementsAttribute( + type.getContext(), + mlirUniformQuantizedSubChannelTypeGetScales(type)); + }, + "The scales of the quantized type."); + c.def_prop_ro( + "zero_points", + [](UniformQuantizedSubChannelType &type) { + return PyDenseElementsAttribute( + type.getContext(), + mlirUniformQuantizedSubChannelTypeGetZeroPoints(type)); + }, + "The zero points of the quantized type."); + } +}; - //===-------------------------------------------------------------------===// - // UniformQuantizedType - //===-------------------------------------------------------------------===// +//===-------------------------------------------------------------------===// +// CalibratedQuantizedType +//===-------------------------------------------------------------------===// - auto uniformQuantizedType = mlir_type_subclass( - m, "UniformQuantizedType", mlirTypeIsAUniformQuantizedType, - quantizedType.get_class()); - uniformQuantizedType.def_classmethod( - "get", - [](const nb::object &cls, unsigned flags, MlirType storageType, - MlirType expressedType, double scale, int64_t zeroPoint, - int64_t storageTypeMin, int64_t storageTypeMax) { - return cls(mlirUniformQuantizedTypeGet(flags, storageType, - expressedType, scale, zeroPoint, - storageTypeMin, storageTypeMax)); - }, - "Gets an instance of UniformQuantizedType in the same context as the " - "provided storage type.", - nb::arg("cls"), nb::arg("flags"), nb::arg("storage_type"), - nb::arg("expressed_type"), nb::arg("scale"), nb::arg("zero_point"), - nb::arg("storage_type_min"), nb::arg("storage_type_max")); - uniformQuantizedType.def_property_readonly( - "scale", - [](MlirType type) { return mlirUniformQuantizedTypeGetScale(type); }, - "The scale designates the difference between the real values " - "corresponding to consecutive quantized values differing by 1."); - uniformQuantizedType.def_property_readonly( - "zero_point", - [](MlirType type) { return mlirUniformQuantizedTypeGetZeroPoint(type); }, - "The storage value corresponding to the real value 0 in the affine " - "equation."); - uniformQuantizedType.def_property_readonly( - "is_fixed_point", - [](MlirType type) { return mlirUniformQuantizedTypeIsFixedPoint(type); }, - "Fixed point values are real numbers divided by a scale."); +struct CalibratedQuantizedType + : PyConcreteType { + static constexpr IsAFunctionTy isaFunction = + mlirTypeIsACalibratedQuantizedType; + static constexpr const char *pyClassName = "CalibratedQuantizedType"; + using Base::Base; - //===-------------------------------------------------------------------===// - // UniformQuantizedPerAxisType - //===-------------------------------------------------------------------===// - auto uniformQuantizedPerAxisType = mlir_type_subclass( - m, "UniformQuantizedPerAxisType", mlirTypeIsAUniformQuantizedPerAxisType, - quantizedType.get_class()); - uniformQuantizedPerAxisType.def_classmethod( - "get", - [](const nb::object &cls, unsigned flags, MlirType storageType, - MlirType expressedType, std::vector scales, - std::vector zeroPoints, int32_t quantizedDimension, - int64_t storageTypeMin, int64_t storageTypeMax) { - if (scales.size() != zeroPoints.size()) - throw nb::value_error( - "Mismatching number of scales and zero points."); - auto nDims = static_cast(scales.size()); - return cls(mlirUniformQuantizedPerAxisTypeGet( - flags, storageType, expressedType, nDims, scales.data(), - zeroPoints.data(), quantizedDimension, storageTypeMin, - storageTypeMax)); - }, - "Gets an instance of UniformQuantizedPerAxisType in the same context as " - "the provided storage type.", - nb::arg("cls"), nb::arg("flags"), nb::arg("storage_type"), - nb::arg("expressed_type"), nb::arg("scales"), nb::arg("zero_points"), - nb::arg("quantized_dimension"), nb::arg("storage_type_min"), - nb::arg("storage_type_max")); - uniformQuantizedPerAxisType.def_property_readonly( - "scales", - [](MlirType type) { - intptr_t nDim = mlirUniformQuantizedPerAxisTypeGetNumDims(type); - std::vector scales; - scales.reserve(nDim); - for (intptr_t i = 0; i < nDim; ++i) { - double scale = mlirUniformQuantizedPerAxisTypeGetScale(type, i); - scales.push_back(scale); - } - return scales; - }, - "The scales designate the difference between the real values " - "corresponding to consecutive quantized values differing by 1. The ith " - "scale corresponds to the ith slice in the quantized_dimension."); - uniformQuantizedPerAxisType.def_property_readonly( - "zero_points", - [](MlirType type) { - intptr_t nDim = mlirUniformQuantizedPerAxisTypeGetNumDims(type); - std::vector zeroPoints; - zeroPoints.reserve(nDim); - for (intptr_t i = 0; i < nDim; ++i) { - int64_t zeroPoint = - mlirUniformQuantizedPerAxisTypeGetZeroPoint(type, i); - zeroPoints.push_back(zeroPoint); - } - return zeroPoints; - }, - "the storage values corresponding to the real value 0 in the affine " - "equation. The ith zero point corresponds to the ith slice in the " - "quantized_dimension."); - uniformQuantizedPerAxisType.def_property_readonly( - "quantized_dimension", - [](MlirType type) { - return mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(type); - }, - "Specifies the dimension of the shape that the scales and zero points " - "correspond to."); - uniformQuantizedPerAxisType.def_property_readonly( - "is_fixed_point", - [](MlirType type) { - return mlirUniformQuantizedPerAxisTypeIsFixedPoint(type); - }, - "Fixed point values are real numbers divided by a scale."); + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](const PyType &expressedType, double min, double max, + DefaultingPyMlirContext context) { + return CalibratedQuantizedType( + context->getRef(), + mlirCalibratedQuantizedTypeGet(expressedType, min, max)); + }, + "Gets an instance of CalibratedQuantizedType in the same context as " + "the " + "provided expressed type.", + nb::arg("expressed_type"), nb::arg("min"), nb::arg("max"), + nb::arg("context") = nb::none()); + c.def_prop_ro("min", [](const PyType &type) { + return mlirCalibratedQuantizedTypeGetMin(type); + }); + c.def_prop_ro("max", [](const PyType &type) { + return mlirCalibratedQuantizedTypeGetMax(type); + }); + } +}; - //===-------------------------------------------------------------------===// - // UniformQuantizedSubChannelType - //===-------------------------------------------------------------------===// - auto uniformQuantizedSubChannelType = mlir_type_subclass( - m, "UniformQuantizedSubChannelType", - mlirTypeIsAUniformQuantizedSubChannelType, quantizedType.get_class()); - uniformQuantizedSubChannelType.def_classmethod( - "get", - [](const nb::object &cls, unsigned flags, MlirType storageType, - MlirType expressedType, MlirAttribute scales, MlirAttribute zeroPoints, - std::vector quantizedDimensions, - std::vector blockSizes, int64_t storageTypeMin, - int64_t storageTypeMax) { - return cls(mlirUniformQuantizedSubChannelTypeGet( - flags, storageType, expressedType, scales, zeroPoints, - static_cast(blockSizes.size()), - quantizedDimensions.data(), blockSizes.data(), storageTypeMin, - storageTypeMax)); - }, - "Gets an instance of UniformQuantizedSubChannel in the same context as " - "the provided storage type.", - nb::arg("cls"), nb::arg("flags"), nb::arg("storage_type"), - nb::arg("expressed_type"), nb::arg("scales"), nb::arg("zero_points"), - nb::arg("quantized_dimensions"), nb::arg("block_sizes"), - nb::arg("storage_type_min"), nb::arg("storage_type_max")); - uniformQuantizedSubChannelType.def_property_readonly( - "quantized_dimensions", - [](MlirType type) { - intptr_t nDim = - mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(type); - std::vector quantizedDimensions; - quantizedDimensions.reserve(nDim); - for (intptr_t i = 0; i < nDim; ++i) { - quantizedDimensions.push_back( - mlirUniformQuantizedSubChannelTypeGetQuantizedDimension(type, i)); - } - return quantizedDimensions; - }, - "Gets the quantized dimensions. Each element in the returned list " - "represents an axis of the quantized data tensor that has a specified " - "block size. The order of elements corresponds to the order of block " - "sizes returned by 'block_sizes' method. It means that the data tensor " - "is quantized along the i-th dimension in the returned list using the " - "i-th block size from block_sizes method."); - uniformQuantizedSubChannelType.def_property_readonly( - "block_sizes", - [](MlirType type) { - intptr_t nDim = - mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(type); - std::vector blockSizes; - blockSizes.reserve(nDim); - for (intptr_t i = 0; i < nDim; ++i) { - blockSizes.push_back( - mlirUniformQuantizedSubChannelTypeGetBlockSize(type, i)); - } - return blockSizes; - }, - "Gets the block sizes for the quantized dimensions. The i-th element in " - "the returned list corresponds to the block size for the i-th dimension " - "in the list returned by quantized_dimensions method."); - uniformQuantizedSubChannelType.def_property_readonly( - "scales", - [](MlirType type) -> MlirAttribute { - return mlirUniformQuantizedSubChannelTypeGetScales(type); - }, - "The scales of the quantized type."); - uniformQuantizedSubChannelType.def_property_readonly( - "zero_points", - [](MlirType type) -> MlirAttribute { - return mlirUniformQuantizedSubChannelTypeGetZeroPoints(type); - }, - "The zero points of the quantized type."); +static void populateDialectQuantSubmodule(nb::module_ &m) { + QuantizedType::bind(m); - //===-------------------------------------------------------------------===// - // CalibratedQuantizedType - //===-------------------------------------------------------------------===// + // Set the FLAG_SIGNED class attribute after binding QuantizedType + auto quantizedTypeClass = m.attr("QuantizedType"); + quantizedTypeClass.attr("FLAG_SIGNED") = mlirQuantizedTypeGetSignedFlag(); - auto calibratedQuantizedType = mlir_type_subclass( - m, "CalibratedQuantizedType", mlirTypeIsACalibratedQuantizedType, - quantizedType.get_class()); - calibratedQuantizedType.def_classmethod( - "get", - [](const nb::object &cls, MlirType expressedType, double min, - double max) { - return cls(mlirCalibratedQuantizedTypeGet(expressedType, min, max)); - }, - "Gets an instance of CalibratedQuantizedType in the same context as the " - "provided expressed type.", - nb::arg("cls"), nb::arg("expressed_type"), nb::arg("min"), - nb::arg("max")); - calibratedQuantizedType.def_property_readonly("min", [](MlirType type) { - return mlirCalibratedQuantizedTypeGetMin(type); - }); - calibratedQuantizedType.def_property_readonly("max", [](MlirType type) { - return mlirCalibratedQuantizedTypeGetMax(type); - }); + AnyQuantizedType::bind(m); + UniformQuantizedType::bind(m); + UniformQuantizedPerAxisType::bind(m); + UniformQuantizedSubChannelType::bind(m); + CalibratedQuantizedType::bind(m); } +} // namespace quant +} // namespace MLIR_BINDINGS_PYTHON_DOMAIN +} // namespace python +} // namespace mlir NB_MODULE(_mlirDialectsQuant, m) { m.doc() = "MLIR Quantization dialect"; - populateDialectQuantSubmodule(m); + mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::quant:: + populateDialectQuantSubmodule(m); } diff --git a/mlir/lib/Bindings/Python/DialectSMT.cpp b/mlir/lib/Bindings/Python/DialectSMT.cpp index a87918a05b126..2edf9bb48c18d 100644 --- a/mlir/lib/Bindings/Python/DialectSMT.cpp +++ b/mlir/lib/Bindings/Python/DialectSMT.cpp @@ -13,44 +13,77 @@ #include "mlir-c/Support.h" #include "mlir-c/Target/ExportSMTLIB.h" #include "mlir/Bindings/Python/Diagnostics.h" +#include "mlir/Bindings/Python/IRCore.h" #include "mlir/Bindings/Python/Nanobind.h" #include "mlir/Bindings/Python/NanobindAdaptors.h" namespace nb = nanobind; using namespace nanobind::literals; - using namespace mlir; -using namespace mlir::python; using namespace mlir::python::nanobind_adaptors; -static void populateDialectSMTSubmodule(nanobind::module_ &m) { +namespace mlir { +namespace python { +namespace MLIR_BINDINGS_PYTHON_DOMAIN { +namespace smt { +struct BoolType : PyConcreteType { + static constexpr IsAFunctionTy isaFunction = mlirSMTTypeIsABool; + static constexpr const char *pyClassName = "BoolType"; + using Base::Base; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + return BoolType(context->getRef(), + mlirSMTTypeGetBool(context.get()->get())); + }, + nb::arg("context").none() = nb::none()); + } +}; + +struct BitVectorType : PyConcreteType { + static constexpr IsAFunctionTy isaFunction = mlirSMTTypeIsABitVector; + static constexpr const char *pyClassName = "BitVectorType"; + using Base::Base; - auto smtBoolType = - mlir_type_subclass(m, "BoolType", mlirSMTTypeIsABool) - .def_staticmethod( - "get", - [](MlirContext context) { return mlirSMTTypeGetBool(context); }, - "context"_a = nb::none()); - auto smtBitVectorType = - mlir_type_subclass(m, "BitVectorType", mlirSMTTypeIsABitVector) - .def_staticmethod( - "get", - [](int32_t width, MlirContext context) { - return mlirSMTTypeGetBitVector(context, width); - }, - "width"_a, "context"_a = nb::none()); - auto smtIntType = - mlir_type_subclass(m, "IntType", mlirSMTTypeIsAInt) - .def_staticmethod( - "get", - [](MlirContext context) { return mlirSMTTypeGetInt(context); }, - "context"_a = nb::none()); + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](int32_t width, DefaultingPyMlirContext context) { + return BitVectorType( + context->getRef(), + mlirSMTTypeGetBitVector(context.get()->get(), width)); + }, + nb::arg("width"), nb::arg("context").none() = nb::none()); + } +}; + +struct IntType : PyConcreteType { + static constexpr IsAFunctionTy isaFunction = mlirSMTTypeIsAInt; + static constexpr const char *pyClassName = "IntType"; + using Base::Base; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + return IntType(context->getRef(), + mlirSMTTypeGetInt(context.get()->get())); + }, + nb::arg("context").none() = nb::none()); + } +}; + +static void populateDialectSMTSubmodule(nanobind::module_ &m) { + BoolType::bind(m); + BitVectorType::bind(m); + IntType::bind(m); auto exportSMTLIB = [](MlirOperation module, bool inlineSingleUseValues, bool indentLetBody) { - mlir::python::CollectDiagnosticsToStringScope scope( - mlirOperationGetContext(module)); + CollectDiagnosticsToStringScope scope(mlirOperationGetContext(module)); PyPrintAccumulator printAccum; MlirLogicalResult result = mlirTranslateOperationToSMTLIB( module, printAccum.getCallback(), printAccum.getUserData(), @@ -64,7 +97,7 @@ static void populateDialectSMTSubmodule(nanobind::module_ &m) { m.def( "export_smtlib", - [&exportSMTLIB](MlirOperation module, bool inlineSingleUseValues, + [&exportSMTLIB](const PyOperation &module, bool inlineSingleUseValues, bool indentLetBody) { return exportSMTLIB(module, inlineSingleUseValues, indentLetBody); }, @@ -72,17 +105,21 @@ static void populateDialectSMTSubmodule(nanobind::module_ &m) { "indent_let_body"_a = false); m.def( "export_smtlib", - [&exportSMTLIB](MlirModule module, bool inlineSingleUseValues, + [&exportSMTLIB](PyModule &module, bool inlineSingleUseValues, bool indentLetBody) { - return exportSMTLIB(mlirModuleGetOperation(module), + return exportSMTLIB(mlirModuleGetOperation(module.get()), inlineSingleUseValues, indentLetBody); }, "module"_a, "inline_single_use_values"_a = false, "indent_let_body"_a = false); } +} // namespace smt +} // namespace MLIR_BINDINGS_PYTHON_DOMAIN +} // namespace python +} // namespace mlir NB_MODULE(_mlirDialectsSMT, m) { m.doc() = "MLIR SMT Dialect"; - populateDialectSMTSubmodule(m); + python::mlir::smt::populateDialectSMTSubmodule(m); } diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp index 00b65ee9745dc..dbb067a3cff81 100644 --- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp +++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp @@ -12,137 +12,179 @@ #include "mlir-c/AffineMap.h" #include "mlir-c/Dialect/SparseTensor.h" #include "mlir-c/IR.h" +#include "mlir/Bindings/Python/IRCore.h" #include "mlir/Bindings/Python/Nanobind.h" #include "mlir/Bindings/Python/NanobindAdaptors.h" namespace nb = nanobind; using namespace llvm; -using namespace mlir; using namespace mlir::python::nanobind_adaptors; -static void populateDialectSparseTensorSubmodule(const nb::module_ &m) { - nb::enum_(m, "LevelFormat", nb::is_arithmetic(), - nb::is_flag()) +namespace mlir { +namespace python { +namespace MLIR_BINDINGS_PYTHON_DOMAIN { +namespace sparse_tensor { + +enum PySparseTensorLevelFormat : std::underlying_type_t< + MlirSparseTensorLevelFormat> { + MLIR_SPARSE_TENSOR_LEVEL_DENSE = MLIR_SPARSE_TENSOR_LEVEL_DENSE, + MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M = MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M, + MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED = MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED, + MLIR_SPARSE_TENSOR_LEVEL_SINGLETON = MLIR_SPARSE_TENSOR_LEVEL_SINGLETON, + MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED = + MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED +}; + +enum PySparseTensorLevelPropertyNondefault : std::underlying_type_t< + MlirSparseTensorLevelPropertyNondefault> { + MLIR_SPARSE_PROPERTY_NON_ORDERED = MLIR_SPARSE_PROPERTY_NON_ORDERED, + MLIR_SPARSE_PROPERTY_NON_UNIQUE = MLIR_SPARSE_PROPERTY_NON_UNIQUE, + MLIR_SPARSE_PROPERTY_SOA = MLIR_SPARSE_PROPERTY_SOA, +}; + +struct EncodingAttr : PyConcreteAttribute { + static constexpr IsAFunctionTy isaFunction = + mlirAttributeIsASparseTensorEncodingAttr; + static constexpr const char *pyClassName = "EncodingAttr"; + using Base::Base; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](std::vector lvlTypes, + std::optional dimToLvl, + std::optional lvlToDim, int posWidth, int crdWidth, + std::optional explicitVal, + std::optional implicitVal, + DefaultingPyMlirContext context) { + return EncodingAttr( + context->getRef(), + mlirSparseTensorEncodingAttrGet( + context.get()->get(), lvlTypes.size(), lvlTypes.data(), + dimToLvl ? *dimToLvl : MlirAffineMap{nullptr}, + lvlToDim ? *lvlToDim : MlirAffineMap{nullptr}, posWidth, + crdWidth, explicitVal ? *explicitVal : MlirAttribute{nullptr}, + implicitVal ? *implicitVal : MlirAttribute{nullptr})); + }, + nb::arg("lvl_types"), nb::arg("dim_to_lvl").none(), + nb::arg("lvl_to_dim").none(), nb::arg("pos_width"), + nb::arg("crd_width"), nb::arg("explicit_val") = nb::none(), + nb::arg("implicit_val") = nb::none(), nb::arg("context") = nb::none(), + "Gets a sparse_tensor.encoding from parameters."); + + c.def_static( + "build_level_type", + [](PySparseTensorLevelFormat lvlFmt, + const std::vector &properties, + unsigned n, unsigned m) { + std::vector props; + props.reserve(properties.size()); + for (auto prop : properties) { + props.push_back( + static_cast(prop)); + } + return mlirSparseTensorEncodingAttrBuildLvlType( + static_cast(lvlFmt), props.data(), + props.size(), n, m); + }, + nb::arg("lvl_fmt"), + nb::arg("properties") = + std::vector(), + nb::arg("n") = 0, nb::arg("m") = 0, + "Builds a sparse_tensor.encoding.level_type from parameters."); + + c.def_prop_ro("lvl_types", [](const EncodingAttr &self) { + const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self); + std::vector ret; + ret.reserve(lvlRank); + for (int l = 0; l < lvlRank; ++l) + ret.push_back(mlirSparseTensorEncodingAttrGetLvlType(self, l)); + return ret; + }); + + c.def_prop_ro( + "dim_to_lvl", [](EncodingAttr &self) -> std::optional { + MlirAffineMap ret = mlirSparseTensorEncodingAttrGetDimToLvl(self); + if (mlirAffineMapIsNull(ret)) + return {}; + return PyAffineMap(self.getContext(), ret); + }); + + c.def_prop_ro( + "lvl_to_dim", [](EncodingAttr &self) -> std::optional { + MlirAffineMap ret = mlirSparseTensorEncodingAttrGetLvlToDim(self); + if (mlirAffineMapIsNull(ret)) + return {}; + return PyAffineMap(self.getContext(), ret); + }); + + c.def_prop_ro("pos_width", mlirSparseTensorEncodingAttrGetPosWidth); + c.def_prop_ro("crd_width", mlirSparseTensorEncodingAttrGetCrdWidth); + + c.def_prop_ro( + "explicit_val", [](EncodingAttr &self) -> std::optional { + MlirAttribute ret = mlirSparseTensorEncodingAttrGetExplicitVal(self); + if (mlirAttributeIsNull(ret)) + return {}; + return PyAttribute(self.getContext(), ret); + }); + + c.def_prop_ro( + "implicit_val", [](EncodingAttr &self) -> std::optional { + MlirAttribute ret = mlirSparseTensorEncodingAttrGetImplicitVal(self); + if (mlirAttributeIsNull(ret)) + return {}; + return PyAttribute(self.getContext(), ret); + }); + + c.def_prop_ro("structured_n", [](const EncodingAttr &self) -> unsigned { + const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self); + return mlirSparseTensorEncodingAttrGetStructuredN( + mlirSparseTensorEncodingAttrGetLvlType(self, lvlRank - 1)); + }); + + c.def_prop_ro("structured_m", [](const EncodingAttr &self) -> unsigned { + const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self); + return mlirSparseTensorEncodingAttrGetStructuredM( + mlirSparseTensorEncodingAttrGetLvlType(self, lvlRank - 1)); + }); + + c.def_prop_ro("lvl_formats_enum", [](const EncodingAttr &self) { + const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self); + std::vector ret; + ret.reserve(lvlRank); + + for (int l = 0; l < lvlRank; l++) + ret.push_back(static_cast( + mlirSparseTensorEncodingAttrGetLvlFmt(self, l))); + return ret; + }); + } +}; + +static void populateDialectSparseTensorSubmodule(nb::module_ &m) { + nb::enum_(m, "LevelFormat", nb::is_arithmetic(), + nb::is_flag()) .value("dense", MLIR_SPARSE_TENSOR_LEVEL_DENSE) .value("n_out_of_m", MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M) .value("compressed", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED) .value("singleton", MLIR_SPARSE_TENSOR_LEVEL_SINGLETON) .value("loose_compressed", MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED); - nb::enum_(m, "LevelProperty") + nb::enum_(m, "LevelProperty") .value("non_ordered", MLIR_SPARSE_PROPERTY_NON_ORDERED) .value("non_unique", MLIR_SPARSE_PROPERTY_NON_UNIQUE) .value("soa", MLIR_SPARSE_PROPERTY_SOA); - mlir_attribute_subclass(m, "EncodingAttr", - mlirAttributeIsASparseTensorEncodingAttr) - .def_classmethod( - "get", - [](const nb::object &cls, - std::vector lvlTypes, - std::optional dimToLvl, - std::optional lvlToDim, int posWidth, int crdWidth, - std::optional explicitVal, - std::optional implicitVal, MlirContext context) { - return cls(mlirSparseTensorEncodingAttrGet( - context, lvlTypes.size(), lvlTypes.data(), - dimToLvl ? *dimToLvl : MlirAffineMap{nullptr}, - lvlToDim ? *lvlToDim : MlirAffineMap{nullptr}, posWidth, - crdWidth, explicitVal ? *explicitVal : MlirAttribute{nullptr}, - implicitVal ? *implicitVal : MlirAttribute{nullptr})); - }, - nb::arg("cls"), nb::arg("lvl_types"), nb::arg("dim_to_lvl").none(), - nb::arg("lvl_to_dim").none(), nb::arg("pos_width"), - nb::arg("crd_width"), nb::arg("explicit_val") = nb::none(), - nb::arg("implicit_val") = nb::none(), nb::arg("context") = nb::none(), - "Gets a sparse_tensor.encoding from parameters.") - .def_classmethod( - "build_level_type", - [](const nb::object &cls, MlirSparseTensorLevelFormat lvlFmt, - const std::vector - &properties, - unsigned n, unsigned m) { - return mlirSparseTensorEncodingAttrBuildLvlType( - lvlFmt, properties.data(), properties.size(), n, m); - }, - nb::arg("cls"), nb::arg("lvl_fmt"), - nb::arg("properties") = - std::vector(), - nb::arg("n") = 0, nb::arg("m") = 0, - "Builds a sparse_tensor.encoding.level_type from parameters.") - .def_property_readonly( - "lvl_types", - [](MlirAttribute self) { - const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self); - std::vector ret; - ret.reserve(lvlRank); - for (int l = 0; l < lvlRank; ++l) - ret.push_back(mlirSparseTensorEncodingAttrGetLvlType(self, l)); - return ret; - }) - .def_property_readonly( - "dim_to_lvl", - [](MlirAttribute self) -> std::optional { - MlirAffineMap ret = mlirSparseTensorEncodingAttrGetDimToLvl(self); - if (mlirAffineMapIsNull(ret)) - return {}; - return ret; - }) - .def_property_readonly( - "lvl_to_dim", - [](MlirAttribute self) -> std::optional { - MlirAffineMap ret = mlirSparseTensorEncodingAttrGetLvlToDim(self); - if (mlirAffineMapIsNull(ret)) - return {}; - return ret; - }) - .def_property_readonly("pos_width", - mlirSparseTensorEncodingAttrGetPosWidth) - .def_property_readonly("crd_width", - mlirSparseTensorEncodingAttrGetCrdWidth) - .def_property_readonly( - "explicit_val", - [](MlirAttribute self) -> std::optional { - MlirAttribute ret = - mlirSparseTensorEncodingAttrGetExplicitVal(self); - if (mlirAttributeIsNull(ret)) - return {}; - return ret; - }) - .def_property_readonly( - "implicit_val", - [](MlirAttribute self) -> std::optional { - MlirAttribute ret = - mlirSparseTensorEncodingAttrGetImplicitVal(self); - if (mlirAttributeIsNull(ret)) - return {}; - return ret; - }) - .def_property_readonly( - "structured_n", - [](MlirAttribute self) -> unsigned { - const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self); - return mlirSparseTensorEncodingAttrGetStructuredN( - mlirSparseTensorEncodingAttrGetLvlType(self, lvlRank - 1)); - }) - .def_property_readonly( - "structured_m", - [](MlirAttribute self) -> unsigned { - const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self); - return mlirSparseTensorEncodingAttrGetStructuredM( - mlirSparseTensorEncodingAttrGetLvlType(self, lvlRank - 1)); - }) - .def_property_readonly("lvl_formats_enum", [](MlirAttribute self) { - const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self); - std::vector ret; - ret.reserve(lvlRank); - for (int l = 0; l < lvlRank; l++) - ret.push_back(mlirSparseTensorEncodingAttrGetLvlFmt(self, l)); - return ret; - }); + EncodingAttr::bind(m); } +} // namespace sparse_tensor +} // namespace MLIR_BINDINGS_PYTHON_DOMAIN +} // namespace python +} // namespace mlir NB_MODULE(_mlirDialectsSparseTensor, m) { m.doc() = "MLIR SparseTensor dialect."; - populateDialectSparseTensorSubmodule(m); + mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::sparse_tensor:: + populateDialectSparseTensorSubmodule(m); } diff --git a/mlir/lib/Bindings/Python/DialectTransform.cpp b/mlir/lib/Bindings/Python/DialectTransform.cpp index 150c69953d960..1d0b5527882fd 100644 --- a/mlir/lib/Bindings/Python/DialectTransform.cpp +++ b/mlir/lib/Bindings/Python/DialectTransform.cpp @@ -11,112 +11,165 @@ #include "mlir-c/Dialect/Transform.h" #include "mlir-c/IR.h" #include "mlir-c/Support.h" +#include "mlir/Bindings/Python/IRCore.h" #include "mlir/Bindings/Python/Nanobind.h" #include "mlir/Bindings/Python/NanobindAdaptors.h" namespace nb = nanobind; -using namespace mlir; -using namespace mlir::python; using namespace mlir::python::nanobind_adaptors; -static void populateDialectTransformSubmodule(const nb::module_ &m) { - //===-------------------------------------------------------------------===// - // AnyOpType - //===-------------------------------------------------------------------===// - - auto anyOpType = - mlir_type_subclass(m, "AnyOpType", mlirTypeIsATransformAnyOpType, - mlirTransformAnyOpTypeGetTypeID); - anyOpType.def_classmethod( - "get", - [](const nb::object &cls, MlirContext ctx) { - return cls(mlirTransformAnyOpTypeGet(ctx)); - }, - "Get an instance of AnyOpType in the given context.", nb::arg("cls"), - nb::arg("context") = nb::none()); - - //===-------------------------------------------------------------------===// - // AnyParamType - //===-------------------------------------------------------------------===// - - auto anyParamType = - mlir_type_subclass(m, "AnyParamType", mlirTypeIsATransformAnyParamType, - mlirTransformAnyParamTypeGetTypeID); - anyParamType.def_classmethod( - "get", - [](const nb::object &cls, MlirContext ctx) { - return cls(mlirTransformAnyParamTypeGet(ctx)); - }, - "Get an instance of AnyParamType in the given context.", nb::arg("cls"), - nb::arg("context") = nb::none()); - - //===-------------------------------------------------------------------===// - // AnyValueType - //===-------------------------------------------------------------------===// - - auto anyValueType = - mlir_type_subclass(m, "AnyValueType", mlirTypeIsATransformAnyValueType, - mlirTransformAnyValueTypeGetTypeID); - anyValueType.def_classmethod( - "get", - [](const nb::object &cls, MlirContext ctx) { - return cls(mlirTransformAnyValueTypeGet(ctx)); - }, - "Get an instance of AnyValueType in the given context.", nb::arg("cls"), - nb::arg("context") = nb::none()); - - //===-------------------------------------------------------------------===// - // OperationType - //===-------------------------------------------------------------------===// - - auto operationType = - mlir_type_subclass(m, "OperationType", mlirTypeIsATransformOperationType, - mlirTransformOperationTypeGetTypeID); - operationType.def_classmethod( - "get", - [](const nb::object &cls, const std::string &operationName, - MlirContext ctx) { - MlirStringRef cOperationName = - mlirStringRefCreate(operationName.data(), operationName.size()); - return cls(mlirTransformOperationTypeGet(ctx, cOperationName)); - }, - "Get an instance of OperationType for the given kind in the given " - "context", - nb::arg("cls"), nb::arg("operation_name"), - nb::arg("context") = nb::none()); - operationType.def_property_readonly( - "operation_name", - [](MlirType type) { - MlirStringRef operationName = - mlirTransformOperationTypeGetOperationName(type); - return nb::str(operationName.data, operationName.length); - }, - "Get the name of the payload operation accepted by the handle."); - - //===-------------------------------------------------------------------===// - // ParamType - //===-------------------------------------------------------------------===// - - auto paramType = - mlir_type_subclass(m, "ParamType", mlirTypeIsATransformParamType, - mlirTransformParamTypeGetTypeID); - paramType.def_classmethod( - "get", - [](const nb::object &cls, MlirType type, MlirContext ctx) { - return cls(mlirTransformParamTypeGet(ctx, type)); - }, - "Get an instance of ParamType for the given type in the given context.", - nb::arg("cls"), nb::arg("type"), nb::arg("context") = nb::none()); - paramType.def_property_readonly( - "type", - [](MlirType type) { - MlirType paramType = mlirTransformParamTypeGetType(type); - return paramType; - }, - "Get the type this ParamType is associated with."); +namespace mlir { +namespace python { +namespace MLIR_BINDINGS_PYTHON_DOMAIN { +namespace transform { +//===-------------------------------------------------------------------===// +// AnyOpType +//===-------------------------------------------------------------------===// + +struct AnyOpType : PyConcreteType { + static constexpr IsAFunctionTy isaFunction = mlirTypeIsATransformAnyOpType; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirTransformAnyOpTypeGetTypeID; + static constexpr const char *pyClassName = "AnyOpType"; + using Base::Base; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + return AnyOpType(context->getRef(), + mlirTransformAnyOpTypeGet(context.get()->get())); + }, + "Get an instance of AnyOpType in the given context.", + nb::arg("context").none() = nb::none()); + } +}; + +//===-------------------------------------------------------------------===// +// AnyParamType +//===-------------------------------------------------------------------===// + +struct AnyParamType : PyConcreteType { + static constexpr IsAFunctionTy isaFunction = mlirTypeIsATransformAnyParamType; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirTransformAnyParamTypeGetTypeID; + static constexpr const char *pyClassName = "AnyParamType"; + using Base::Base; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + return AnyParamType(context->getRef(), mlirTransformAnyParamTypeGet( + context.get()->get())); + }, + "Get an instance of AnyParamType in the given context.", + nb::arg("context").none() = nb::none()); + } +}; + +//===-------------------------------------------------------------------===// +// AnyValueType +//===-------------------------------------------------------------------===// + +struct AnyValueType : PyConcreteType { + static constexpr IsAFunctionTy isaFunction = mlirTypeIsATransformAnyValueType; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirTransformAnyValueTypeGetTypeID; + static constexpr const char *pyClassName = "AnyValueType"; + using Base::Base; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + return AnyValueType(context->getRef(), mlirTransformAnyValueTypeGet( + context.get()->get())); + }, + "Get an instance of AnyValueType in the given context.", + nb::arg("context").none() = nb::none()); + } +}; + +//===-------------------------------------------------------------------===// +// OperationType +//===-------------------------------------------------------------------===// + +struct OperationType : PyConcreteType { + static constexpr IsAFunctionTy isaFunction = + mlirTypeIsATransformOperationType; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirTransformOperationTypeGetTypeID; + static constexpr const char *pyClassName = "OperationType"; + using Base::Base; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](const std::string &operationName, DefaultingPyMlirContext context) { + MlirStringRef cOperationName = + mlirStringRefCreate(operationName.data(), operationName.size()); + return OperationType(context->getRef(), + mlirTransformOperationTypeGet( + context.get()->get(), cOperationName)); + }, + "Get an instance of OperationType for the given kind in the given " + "context", + nb::arg("operation_name"), nb::arg("context").none() = nb::none()); + c.def_prop_ro( + "operation_name", + [](const OperationType &type) { + MlirStringRef operationName = + mlirTransformOperationTypeGetOperationName(type); + return nb::str(operationName.data, operationName.length); + }, + "Get the name of the payload operation accepted by the handle."); + } +}; + +//===-------------------------------------------------------------------===// +// ParamType +//===-------------------------------------------------------------------===// + +struct ParamType : PyConcreteType { + static constexpr IsAFunctionTy isaFunction = mlirTypeIsATransformParamType; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirTransformParamTypeGetTypeID; + static constexpr const char *pyClassName = "ParamType"; + using Base::Base; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](const PyType &type, DefaultingPyMlirContext context) { + return ParamType(context->getRef(), mlirTransformParamTypeGet( + context.get()->get(), type)); + }, + "Get an instance of ParamType for the given type in the given context.", + nb::arg("type"), nb::arg("context").none() = nb::none()); + c.def_prop_ro( + "type", + [](ParamType type) { + return PyType(type.getContext(), mlirTransformParamTypeGetType(type)); + }, + "Get the type this ParamType is associated with."); + } +}; + +static void populateDialectTransformSubmodule(nb::module_ &m) { + AnyOpType::bind(m); + AnyParamType::bind(m); + AnyValueType::bind(m); + OperationType::bind(m); + ParamType::bind(m); } +} // namespace transform +} // namespace MLIR_BINDINGS_PYTHON_DOMAIN +} // namespace python +} // namespace mlir NB_MODULE(_mlirDialectsTransform, m) { m.doc() = "MLIR Transform dialect."; - populateDialectTransformSubmodule(m); + mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::transform:: + populateDialectTransformSubmodule(m); } diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp index c282f4b6996e5..9830c277ac147 100644 --- a/mlir/lib/Bindings/Python/Rewrite.cpp +++ b/mlir/lib/Bindings/Python/Rewrite.cpp @@ -55,7 +55,7 @@ class PyPatternRewriter { mlirRewriterBaseReplaceOpWithValues(base, op, values.size(), values.data()); } - void eraseOp(MlirOperation op) { mlirRewriterBaseEraseOp(base, op); } + void eraseOp(const PyOperation &op) { mlirRewriterBaseEraseOp(base, op); } private: MlirRewriterBase base; @@ -342,38 +342,29 @@ void populateRewriteSubmodule(nb::module_ &m) { //---------------------------------------------------------------------------- // Mapping of the PatternRewriter //---------------------------------------------------------------------------- - nb:: - class_(m, "PatternRewriter") - .def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint, - "The current insertion point of the PatternRewriter.") - .def( - "replace_op", - [](PyPatternRewriter &self, MlirOperation op, - MlirOperation newOp) { self.replaceOp(op, newOp); }, - "Replace an operation with a new operation.", nb::arg("op"), - nb::arg("new_op"), - // clang-format off - nb::sig("def replace_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ", new_op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ") -> None") - // clang-format on - ) - .def( - "replace_op", - [](PyPatternRewriter &self, MlirOperation op, - const std::vector &values) { - self.replaceOp(op, values); - }, - "Replace an operation with a list of values.", nb::arg("op"), - nb::arg("values"), - // clang-format off - nb::sig("def replace_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ", values: list[" MAKE_MLIR_PYTHON_QUALNAME("ir.Value") "]) -> None") - // clang-format on - ) - .def("erase_op", &PyPatternRewriter::eraseOp, "Erase an operation.", - nb::arg("op"), - // clang-format off - nb::sig("def erase_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ") -> None") - // clang-format on - ); + nb::class_(m, "PatternRewriter") + .def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint, + "The current insertion point of the PatternRewriter.") + .def( + "replace_op", + [](PyPatternRewriter &self, PyOperationBase &op, + PyOperationBase &newOp) { + self.replaceOp(op.getOperation(), newOp.getOperation()); + }, + "Replace an operation with a new operation.", nb::arg("op"), + nb::arg("new_op")) + .def( + "replace_op", + [](PyPatternRewriter &self, PyOperationBase &op, + const std::vector &values) { + std::vector values_(values.size()); + std::copy(values.begin(), values.end(), values_.begin()); + self.replaceOp(op.getOperation(), values_); + }, + "Replace an operation with a list of values.", nb::arg("op"), + nb::arg("values")) + .def("erase_op", &PyPatternRewriter::eraseOp, "Erase an operation.", + nb::arg("op")); //---------------------------------------------------------------------------- // Mapping of the RewritePatternSet @@ -428,42 +419,21 @@ void populateRewriteSubmodule(nb::module_ &m) { //---------------------------------------------------------------------------- #if MLIR_ENABLE_PDL_IN_PATTERNMATCH nb::class_(m, "PDLResultList") - .def( - "append", - [](PyMlirPDLResultList results, const PyValue &value) { - mlirPDLResultListPushBackValue(results, value); - }, - // clang-format off - nb::sig("def append(self, value: " MAKE_MLIR_PYTHON_QUALNAME("ir.Value") ")") - // clang-format on - ) - .def( - "append", - [](PyMlirPDLResultList results, const PyOperation &op) { - mlirPDLResultListPushBackOperation(results, op); - }, - // clang-format off - nb::sig("def append(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ")") - // clang-format on - ) - .def( - "append", - [](PyMlirPDLResultList results, const PyType &type) { - mlirPDLResultListPushBackType(results, type); - }, - // clang-format off - nb::sig("def append(self, type: " MAKE_MLIR_PYTHON_QUALNAME("ir.Type") ")") - // clang-format on - ) - .def( - "append", - [](PyMlirPDLResultList results, const PyAttribute &attr) { - mlirPDLResultListPushBackAttribute(results, attr); - }, - // clang-format off - nb::sig("def append(self, attr: " MAKE_MLIR_PYTHON_QUALNAME("ir.Attribute") ")") - // clang-format on - ); + .def("append", + [](PyMlirPDLResultList results, const PyValue &value) { + mlirPDLResultListPushBackValue(results, value); + }) + .def("append", + [](PyMlirPDLResultList results, const PyOperation &op) { + mlirPDLResultListPushBackOperation(results, op); + }) + .def("append", + [](PyMlirPDLResultList results, const PyType &type) { + mlirPDLResultListPushBackType(results, type); + }) + .def("append", [](PyMlirPDLResultList results, const PyAttribute &attr) { + mlirPDLResultListPushBackAttribute(results, attr); + }); nb::class_(m, "PDLModule") .def( "__init__", @@ -471,9 +441,6 @@ void populateRewriteSubmodule(nb::module_ &m) { new (&self) PyPDLPatternModule( mlirPDLPatternModuleFromModule(module.get())); }, - // clang-format off - nb::sig("def __init__(self, module: " MAKE_MLIR_PYTHON_QUALNAME("ir.Module") ") -> None"), - // clang-format on "module"_a, "Create a PDL module from the given module.") .def( "__init__", @@ -481,9 +448,6 @@ void populateRewriteSubmodule(nb::module_ &m) { new (&self) PyPDLPatternModule( mlirPDLPatternModuleFromModule(module.get())); }, - // clang-format off - nb::sig("def __init__(self, module: " MAKE_MLIR_PYTHON_QUALNAME("ir.Module") ") -> None"), - // clang-format on "module"_a, "Create a PDL module from the given module.") .def( "freeze", @@ -552,9 +516,6 @@ void populateRewriteSubmodule(nb::module_ &m) { throw std::runtime_error("pattern application failed to converge"); }, "module"_a, "set"_a, - // clang-format off - nb::sig("def apply_patterns_and_fold_greedily(module: " MAKE_MLIR_PYTHON_QUALNAME("ir.Module") ", set: FrozenRewritePatternSet) -> None"), - // clang-format on "Applys the given patterns to the given module greedily while folding " "results.") .def( @@ -568,9 +529,6 @@ void populateRewriteSubmodule(nb::module_ &m) { "pattern application failed to converge"); }, "op"_a, "set"_a, - // clang-format off - nb::sig("def apply_patterns_and_fold_greedily(op: " MAKE_MLIR_PYTHON_QUALNAME("ir._OperationBase") ", set: FrozenRewritePatternSet) -> None"), - // clang-format on "Applys the given patterns to the given op greedily while folding " "results.") .def( @@ -579,9 +537,6 @@ void populateRewriteSubmodule(nb::module_ &m) { mlirWalkAndApplyPatterns(op.getOperation(), set.get()); }, "op"_a, "set"_a, - // clang-format off - nb::sig("def walk_and_apply_patterns(op: " MAKE_MLIR_PYTHON_QUALNAME("ir._OperationBase") ", set: FrozenRewritePatternSet) -> None"), - // clang-format on "Applies the given patterns to the given op by a fast walk-based " "driver."); } diff --git a/mlir/python/mlir/dialects/transform/extras/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py index 8d045cad7a4a3..b4d19878056db 100644 --- a/mlir/python/mlir/dialects/transform/extras/__init__.py +++ b/mlir/python/mlir/dialects/transform/extras/__init__.py @@ -43,8 +43,9 @@ def __init__( self.parent = parent self.children = children if children is not None else [] -@ir.register_value_caster(AnyOpType.get_static_typeid()) -@ir.register_value_caster(OperationType.get_static_typeid()) + +@ir.register_value_caster(AnyOpType.static_typeid) +@ir.register_value_caster(OperationType.static_typeid) class OpHandle(Handle): """ Wrapper around a transform operation handle with methods to chain further @@ -132,8 +133,8 @@ def print(self, name: Optional[str] = None) -> "OpHandle": return self -@ir.register_value_caster(AnyParamType.get_static_typeid()) -@ir.register_value_caster(ParamType.get_static_typeid()) +@ir.register_value_caster(AnyParamType.static_typeid) +@ir.register_value_caster(ParamType.static_typeid) class ParamHandle(Handle): """Wrapper around a transform param handle.""" @@ -147,7 +148,7 @@ def __init__( super().__init__(v, parent=parent, children=children) -@ir.register_value_caster(AnyValueType.get_static_typeid()) +@ir.register_value_caster(AnyValueType.static_typeid) class ValueHandle(Handle): """ Wrapper around a transform value handle with methods to chain further diff --git a/mlir/test/python/dialects/pdl_types.py b/mlir/test/python/dialects/pdl_types.py index dfba2a36b8980..f75428d295c9c 100644 --- a/mlir/test/python/dialects/pdl_types.py +++ b/mlir/test/python/dialects/pdl_types.py @@ -5,149 +5,149 @@ def run(f): - print("\nTEST:", f.__name__) - f() - return f + print("\nTEST:", f.__name__) + f() + return f # CHECK-LABEL: TEST: test_attribute_type @run def test_attribute_type(): - with Context(): - parsedType = Type.parse("!pdl.attribute") - constructedType = pdl.AttributeType.get() + with Context(): + parsedType = Type.parse("!pdl.attribute") + constructedType = pdl.AttributeType.get() - assert pdl.AttributeType.isinstance(parsedType) - assert not pdl.OperationType.isinstance(parsedType) - assert not pdl.RangeType.isinstance(parsedType) - assert not pdl.TypeType.isinstance(parsedType) - assert not pdl.ValueType.isinstance(parsedType) + assert pdl.AttributeType.isinstance(parsedType) + assert not pdl.OperationType.isinstance(parsedType) + assert not pdl.RangeType.isinstance(parsedType) + assert not pdl.TypeType.isinstance(parsedType) + assert not pdl.ValueType.isinstance(parsedType) - assert pdl.AttributeType.isinstance(constructedType) - assert not pdl.OperationType.isinstance(constructedType) - assert not pdl.RangeType.isinstance(constructedType) - assert not pdl.TypeType.isinstance(constructedType) - assert not pdl.ValueType.isinstance(constructedType) + assert pdl.AttributeType.isinstance(constructedType) + assert not pdl.OperationType.isinstance(constructedType) + assert not pdl.RangeType.isinstance(constructedType) + assert not pdl.TypeType.isinstance(constructedType) + assert not pdl.ValueType.isinstance(constructedType) - assert parsedType == constructedType + assert parsedType == constructedType - # CHECK: !pdl.attribute - print(parsedType) - # CHECK: !pdl.attribute - print(constructedType) + # CHECK: !pdl.attribute + print(parsedType) + # CHECK: !pdl.attribute + print(constructedType) # CHECK-LABEL: TEST: test_operation_type @run def test_operation_type(): - with Context(): - parsedType = Type.parse("!pdl.operation") - constructedType = pdl.OperationType.get() + with Context(): + parsedType = Type.parse("!pdl.operation") + constructedType = pdl.OperationType.get() - assert not pdl.AttributeType.isinstance(parsedType) - assert pdl.OperationType.isinstance(parsedType) - assert not pdl.RangeType.isinstance(parsedType) - assert not pdl.TypeType.isinstance(parsedType) - assert not pdl.ValueType.isinstance(parsedType) + assert not pdl.AttributeType.isinstance(parsedType) + assert pdl.OperationType.isinstance(parsedType) + assert not pdl.RangeType.isinstance(parsedType) + assert not pdl.TypeType.isinstance(parsedType) + assert not pdl.ValueType.isinstance(parsedType) - assert not pdl.AttributeType.isinstance(constructedType) - assert pdl.OperationType.isinstance(constructedType) - assert not pdl.RangeType.isinstance(constructedType) - assert not pdl.TypeType.isinstance(constructedType) - assert not pdl.ValueType.isinstance(constructedType) + assert not pdl.AttributeType.isinstance(constructedType) + assert pdl.OperationType.isinstance(constructedType) + assert not pdl.RangeType.isinstance(constructedType) + assert not pdl.TypeType.isinstance(constructedType) + assert not pdl.ValueType.isinstance(constructedType) - assert parsedType == constructedType + assert parsedType == constructedType - # CHECK: !pdl.operation - print(parsedType) - # CHECK: !pdl.operation - print(constructedType) + # CHECK: !pdl.operation + print(parsedType) + # CHECK: !pdl.operation + print(constructedType) # CHECK-LABEL: TEST: test_range_type @run def test_range_type(): - with Context(): - typeType = Type.parse("!pdl.type") - parsedType = Type.parse("!pdl.range") - constructedType = pdl.RangeType.get(typeType) - elementType = constructedType.element_type - - assert not pdl.AttributeType.isinstance(parsedType) - assert not pdl.OperationType.isinstance(parsedType) - assert pdl.RangeType.isinstance(parsedType) - assert not pdl.TypeType.isinstance(parsedType) - assert not pdl.ValueType.isinstance(parsedType) - - assert not pdl.AttributeType.isinstance(constructedType) - assert not pdl.OperationType.isinstance(constructedType) - assert pdl.RangeType.isinstance(constructedType) - assert not pdl.TypeType.isinstance(constructedType) - assert not pdl.ValueType.isinstance(constructedType) - - assert parsedType == constructedType - assert elementType == typeType - - # CHECK: !pdl.range - print(parsedType) - # CHECK: !pdl.range - print(constructedType) - # CHECK: !pdl.type - print(elementType) + with Context(): + typeType = Type.parse("!pdl.type") + parsedType = Type.parse("!pdl.range") + constructedType = pdl.RangeType.get(typeType) + elementType = constructedType.element_type + + assert not pdl.AttributeType.isinstance(parsedType) + assert not pdl.OperationType.isinstance(parsedType) + assert pdl.RangeType.isinstance(parsedType) + assert not pdl.TypeType.isinstance(parsedType) + assert not pdl.ValueType.isinstance(parsedType) + + assert not pdl.AttributeType.isinstance(constructedType) + assert not pdl.OperationType.isinstance(constructedType) + assert pdl.RangeType.isinstance(constructedType) + assert not pdl.TypeType.isinstance(constructedType) + assert not pdl.ValueType.isinstance(constructedType) + + assert parsedType == constructedType + assert elementType == typeType + + # CHECK: !pdl.range + print(parsedType) + # CHECK: !pdl.range + print(constructedType) + # CHECK: !pdl.type + print(elementType) # CHECK-LABEL: TEST: test_type_type @run def test_type_type(): - with Context(): - parsedType = Type.parse("!pdl.type") - constructedType = pdl.TypeType.get() + with Context(): + parsedType = Type.parse("!pdl.type") + constructedType = pdl.TypeType.get() - assert not pdl.AttributeType.isinstance(parsedType) - assert not pdl.OperationType.isinstance(parsedType) - assert not pdl.RangeType.isinstance(parsedType) - assert pdl.TypeType.isinstance(parsedType) - assert not pdl.ValueType.isinstance(parsedType) + assert not pdl.AttributeType.isinstance(parsedType) + assert not pdl.OperationType.isinstance(parsedType) + assert not pdl.RangeType.isinstance(parsedType) + assert pdl.TypeType.isinstance(parsedType) + assert not pdl.ValueType.isinstance(parsedType) - assert not pdl.AttributeType.isinstance(constructedType) - assert not pdl.OperationType.isinstance(constructedType) - assert not pdl.RangeType.isinstance(constructedType) - assert pdl.TypeType.isinstance(constructedType) - assert not pdl.ValueType.isinstance(constructedType) + assert not pdl.AttributeType.isinstance(constructedType) + assert not pdl.OperationType.isinstance(constructedType) + assert not pdl.RangeType.isinstance(constructedType) + assert pdl.TypeType.isinstance(constructedType) + assert not pdl.ValueType.isinstance(constructedType) - assert parsedType == constructedType + assert parsedType == constructedType - # CHECK: !pdl.type - print(parsedType) - # CHECK: !pdl.type - print(constructedType) + # CHECK: !pdl.type + print(parsedType) + # CHECK: !pdl.type + print(constructedType) # CHECK-LABEL: TEST: test_value_type @run def test_value_type(): - with Context(): - parsedType = Type.parse("!pdl.value") - constructedType = pdl.ValueType.get() + with Context(): + parsedType = Type.parse("!pdl.value") + constructedType = pdl.ValueType.get() - assert not pdl.AttributeType.isinstance(parsedType) - assert not pdl.OperationType.isinstance(parsedType) - assert not pdl.RangeType.isinstance(parsedType) - assert not pdl.TypeType.isinstance(parsedType) - assert pdl.ValueType.isinstance(parsedType) + assert not pdl.AttributeType.isinstance(parsedType) + assert not pdl.OperationType.isinstance(parsedType) + assert not pdl.RangeType.isinstance(parsedType) + assert not pdl.TypeType.isinstance(parsedType) + assert pdl.ValueType.isinstance(parsedType) - assert not pdl.AttributeType.isinstance(constructedType) - assert not pdl.OperationType.isinstance(constructedType) - assert not pdl.RangeType.isinstance(constructedType) - assert not pdl.TypeType.isinstance(constructedType) - assert pdl.ValueType.isinstance(constructedType) + assert not pdl.AttributeType.isinstance(constructedType) + assert not pdl.OperationType.isinstance(constructedType) + assert not pdl.RangeType.isinstance(constructedType) + assert not pdl.TypeType.isinstance(constructedType) + assert pdl.ValueType.isinstance(constructedType) - assert parsedType == constructedType + assert parsedType == constructedType - # CHECK: !pdl.value - print(parsedType) - # CHECK: !pdl.value - print(constructedType) + # CHECK: !pdl.value + print(parsedType) + # CHECK: !pdl.value + print(constructedType) # CHECK-LABEL: TEST: test_type_without_context @@ -157,7 +157,10 @@ def test_type_without_context(): # should raise an exception but not crash. try: constructedType = pdl.ValueType.get() - except TypeError: - pass + except RuntimeError as e: + assert ( + "An MLIR function requires a Context but none was provided in the call or from the surrounding environment" + in e.args[0] + ) else: assert False, "Expected TypeError to be raised."