diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index 212228fbac91e..404c4d842e02c 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -485,13 +485,14 @@ class PyArrayAttribute : public PyConcreteAttribute { PyArrayAttributeIterator &dunderIter() { return *this; } - nb::object dunderNext() { + nb::typed dunderNext() { // TODO: Throw is an inefficient way to stop iteration. if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) throw nb::stop_iteration(); - return PyAttribute(this->attr.getContext(), - mlirArrayAttrGetElement(attr.get(), nextIndex++)) - .maybeDownCast(); + return nb::cast>( + PyAttribute(this->attr.getContext(), + mlirArrayAttrGetElement(attr.get(), nextIndex++)) + .maybeDownCast()); } static void bind(nb::module_ &m) { @@ -524,13 +525,13 @@ class PyArrayAttribute : public PyConcreteAttribute { }, nb::arg("attributes"), nb::arg("context") = nb::none(), "Gets a uniqued Array attribute"); - c.def( - "__getitem__", - [](PyArrayAttribute &arr, intptr_t i) { - if (i >= mlirArrayAttrGetNumElements(arr)) - throw nb::index_error("ArrayAttribute index out of range"); - return PyAttribute(arr.getContext(), arr.getItem(i)).maybeDownCast(); - }) + c.def("__getitem__", + [](PyArrayAttribute &arr, intptr_t i) { + if (i >= mlirArrayAttrGetNumElements(arr)) + throw nb::index_error("ArrayAttribute index out of range"); + return nb::cast>( + PyAttribute(arr.getContext(), arr.getItem(i)).maybeDownCast()); + }) .def("__len__", [](const PyArrayAttribute &arr) { return mlirArrayAttrGetNumElements(arr); @@ -1014,9 +1015,10 @@ class PyDenseElementsAttribute if (!mlirDenseElementsAttrIsSplat(self)) throw nb::value_error( "get_splat_value called on a non-splat attribute"); - return PyAttribute(self.getContext(), - mlirDenseElementsAttrGetSplatValue(self)) - .maybeDownCast(); + return nb::cast>( + PyAttribute(self.getContext(), + mlirDenseElementsAttrGetSplatValue(self)) + .maybeDownCast()); }); } @@ -1527,7 +1529,8 @@ class PyDictAttribute : public PyConcreteAttribute { mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name)); if (mlirAttributeIsNull(attr)) throw nb::key_error("attempt to access a non-existent attribute"); - return PyAttribute(self.getContext(), attr).maybeDownCast(); + return nb::cast>( + PyAttribute(self.getContext(), attr).maybeDownCast()); }); c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) { if (index < 0 || index >= self.dunderLen()) { @@ -1595,8 +1598,9 @@ class PyTypeAttribute : public PyConcreteAttribute { nb::arg("value"), nb::arg("context") = nb::none(), "Gets a uniqued Type attribute"); c.def_prop_ro("value", [](PyTypeAttribute &self) { - return PyType(self.getContext(), mlirTypeAttrGetValue(self.get())) - .maybeDownCast(); + return nb::cast>( + PyType(self.getContext(), mlirTypeAttrGetValue(self.get())) + .maybeDownCast()); }); } }; diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 81386f2227a7f..5a6edfa737fd7 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -528,7 +528,8 @@ class PyOperationIterator { static void bind(nb::module_ &m) { nb::class_(m, "OperationIterator") .def("__iter__", &PyOperationIterator::dunderIter) - .def("__next__", &PyOperationIterator::dunderNext); + .def("__next__", &PyOperationIterator::dunderNext, + nb::sig("def __next__(self) -> OpView")); } private: @@ -1604,8 +1605,9 @@ class PyConcreteValue : public PyValue { return DerivedTy::isaFunction(otherValue); }, nb::arg("other_value")); - cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, - [](DerivedTy &self) { return self.maybeDownCast(); }); + cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, [](DerivedTy &self) { + return nb::cast>(self.maybeDownCast()); + }); DerivedTy::bindDerived(cls); } @@ -1638,14 +1640,15 @@ class PyOpResult : public PyConcreteValue { /// Returns the list of types of the values held by container. template -static std::vector getValueTypes(Container &container, - PyMlirContextRef &context) { - std::vector result; +static std::vector> +getValueTypes(Container &container, PyMlirContextRef &context) { + std::vector> result; result.reserve(container.size()); for (int i = 0, e = container.size(); i < e; ++i) { - result.push_back(PyType(context->getRef(), - mlirValueGetType(container.getElement(i).get())) - .maybeDownCast()); + result.push_back(nb::cast>( + PyType(context->getRef(), + mlirValueGetType(container.getElement(i).get())) + .maybeDownCast())); } return result; } @@ -2677,13 +2680,15 @@ class PyOpAttributeMap { PyOpAttributeMap(PyOperationRef operation) : operation(std::move(operation)) {} - nb::object dunderGetItemNamed(const std::string &name) { + nb::typed + dunderGetItemNamed(const std::string &name) { MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(), toMlirStringRef(name)); if (mlirAttributeIsNull(attr)) { throw nb::key_error("attempt to access a non-existent attribute"); } - return PyAttribute(operation->getContext(), attr).maybeDownCast(); + return nb::cast>( + PyAttribute(operation->getContext(), attr).maybeDownCast()); } PyNamedAttribute dunderGetItemIndexed(intptr_t index) { @@ -2961,14 +2966,17 @@ void mlir::python::populateIRCore(nb::module_ &m) { new (&self) PyMlirContext(context); }) .def_static("_get_live_count", &PyMlirContext::getLiveCount) - .def("_get_context_again", - [](PyMlirContext &self) { - PyMlirContextRef ref = PyMlirContext::forContext(self.get()); - return ref.releaseObject(); - }) + .def( + "_get_context_again", + [](PyMlirContext &self) { + PyMlirContextRef ref = PyMlirContext::forContext(self.get()); + return ref.releaseObject(); + }, + nb::sig("def _get_context_again(self) -> Context")) .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount) .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule) - .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule) + .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule, + nb::sig("def _CAPICreate(self) -> Context")) .def("__enter__", &PyMlirContext::contextEnter) .def("__exit__", &PyMlirContext::contextExit, nb::arg("exc_type").none(), nb::arg("exc_value").none(), nb::arg("traceback").none()) @@ -3463,8 +3471,9 @@ void mlir::python::populateIRCore(nb::module_ &m) { "result", [](PyOperationBase &self) { auto &operation = self.getOperation(); - return PyOpResult(operation.getRef(), getUniqueResult(operation)) - .maybeDownCast(); + return nb::cast>( + PyOpResult(operation.getRef(), getUniqueResult(operation)) + .maybeDownCast()); }, "Shortcut to get an op result if it has only one (throws an error " "otherwise).") @@ -3988,7 +3997,8 @@ void mlir::python::populateIRCore(nb::module_ &m) { context->get(), toMlirStringRef(attrSpec)); if (mlirAttributeIsNull(attr)) throw MLIRError("Unable to parse attribute", errors.take()); - return PyAttribute(context.get()->getRef(), attr).maybeDownCast(); + return nb::cast>( + PyAttribute(context.get()->getRef(), attr).maybeDownCast()); }, nb::arg("asm"), nb::arg("context") = nb::none(), "Parses an attribute from an assembly form. Raises an MLIRError on " @@ -3999,9 +4009,9 @@ void mlir::python::populateIRCore(nb::module_ &m) { "Context that owns the Attribute") .def_prop_ro("type", [](PyAttribute &self) { - return PyType(self.getContext(), - mlirAttributeGetType(self)) - .maybeDownCast(); + return nb::cast>( + PyType(self.getContext(), mlirAttributeGetType(self)) + .maybeDownCast()); }) .def( "get_named", @@ -4049,7 +4059,9 @@ void mlir::python::populateIRCore(nb::module_ &m) { "mlirTypeID was expected to be non-null."); return PyTypeID(mlirTypeID); }) - .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, &PyAttribute::maybeDownCast); + .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, [](PyAttribute &self) { + nb::cast>(self.maybeDownCast()); + }); //---------------------------------------------------------------------------- // Mapping of PyNamedAttribute @@ -4100,7 +4112,8 @@ void mlir::python::populateIRCore(nb::module_ &m) { mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec)); if (mlirTypeIsNull(type)) throw MLIRError("Unable to parse type", errors.take()); - return PyType(context.get()->getRef(), type).maybeDownCast(); + return nb::cast>( + PyType(context.get()->getRef(), type).maybeDownCast()); }, nb::arg("asm"), nb::arg("context") = nb::none(), kContextParseTypeDocstring) @@ -4139,7 +4152,11 @@ void mlir::python::populateIRCore(nb::module_ &m) { printAccum.parts.append(")"); return printAccum.join(); }) - .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, &PyType::maybeDownCast) + .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, + [](PyType &self) { + return nb::cast>( + self.maybeDownCast()); + }) .def_prop_ro("typeid", [](PyType &self) { MlirTypeID mlirTypeID = mlirTypeGetTypeID(self); if (!mlirTypeIDIsNull(mlirTypeID)) @@ -4267,9 +4284,10 @@ void mlir::python::populateIRCore(nb::module_ &m) { nb::arg("state"), kGetNameAsOperand) .def_prop_ro("type", [](PyValue &self) { - return PyType(self.getParentOperation()->getContext(), - mlirValueGetType(self.get())) - .maybeDownCast(); + return nb::cast>( + PyType(self.getParentOperation()->getContext(), + mlirValueGetType(self.get())) + .maybeDownCast()); }) .def( "set_type", @@ -4305,7 +4323,11 @@ void mlir::python::populateIRCore(nb::module_ &m) { }, nb::arg("with_"), nb::arg("exceptions"), kValueReplaceAllUsesExceptDocstring) - .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, &PyValue::maybeDownCast) + .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, + [](PyValue &self) { + return nb::cast>( + self.maybeDownCast()); + }) .def_prop_ro( "location", [](MlirValue self) { diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 6e97c00d478f1..dc9913bc5ebb2 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -1102,8 +1102,9 @@ class PyConcreteAttribute : public BaseTy { }, nanobind::arg("other")); cls.def_prop_ro("type", [](PyAttribute &attr) { - return PyType(attr.getContext(), mlirAttributeGetType(attr)) - .maybeDownCast(); + return nanobind::cast>( + PyType(attr.getContext(), mlirAttributeGetType(attr)) + .maybeDownCast()); }); cls.def_prop_ro_static( "static_typeid", diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp index a7aa1c65c6c43..a228ca4418c4a 100644 --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -502,8 +502,9 @@ class PyComplexType : public PyConcreteType { c.def_prop_ro( "element_type", [](PyComplexType &self) { - return PyType(self.getContext(), mlirComplexTypeGetElementType(self)) - .maybeDownCast(); + return nb::cast>( + PyType(self.getContext(), mlirComplexTypeGetElementType(self)) + .maybeDownCast()); }, "Returns element type."); } @@ -516,8 +517,9 @@ void mlir::PyShapedType::bindDerived(ClassTy &c) { c.def_prop_ro( "element_type", [](PyShapedType &self) { - return PyType(self.getContext(), mlirShapedTypeGetElementType(self)) - .maybeDownCast(); + return nb::cast>( + PyType(self.getContext(), mlirShapedTypeGetElementType(self)) + .maybeDownCast()); }, "Returns the element type of the shaped type."); c.def_prop_ro( @@ -898,11 +900,21 @@ class PyTupleType : public PyConcreteType { }, nb::arg("elements"), nb::arg("context") = nb::none(), "Create a tuple type"); + c.def_static( + "get_tuple", + [](std::vector elements, DefaultingPyMlirContext context) { + MlirType t = mlirTupleTypeGet(context->get(), elements.size(), + elements.data()); + return PyTupleType(context->getRef(), t); + }, + nb::arg("elements"), nb::arg("context") = nb::none(), + "Create a tuple type"); c.def( "get_type", [](PyTupleType &self, intptr_t pos) { - return PyType(self.getContext(), mlirTupleTypeGetType(self, pos)) - .maybeDownCast(); + return nb::cast>( + PyType(self.getContext(), mlirTupleTypeGetType(self, pos)) + .maybeDownCast()); }, nb::arg("pos"), "Returns the pos-th type in the tuple type."); c.def_prop_ro( @@ -926,23 +938,17 @@ class PyFunctionType : public PyConcreteType { static void bindDerived(ClassTy &c) { c.def_static( "get", - [](std::vector inputs, std::vector results, + [](std::vector inputs, std::vector results, DefaultingPyMlirContext context) { - std::vector mlirInputs; - mlirInputs.reserve(inputs.size()); - for (const auto &input : inputs) - mlirInputs.push_back(input.get()); - std::vector mlirResults; - mlirResults.reserve(results.size()); - for (const auto &result : results) - mlirResults.push_back(result.get()); - - MlirType t = mlirFunctionTypeGet(context->get(), inputs.size(), - mlirInputs.data(), results.size(), - mlirResults.data()); + MlirType t = + mlirFunctionTypeGet(context->get(), inputs.size(), inputs.data(), + results.size(), results.data()); return PyFunctionType(context->getRef(), t); }, nb::arg("inputs"), nb::arg("results"), nb::arg("context") = nb::none(), + // clang-format off + nb::sig("def get(inputs: Sequence[Type], results: Sequence[Type], context: mlir.ir.Context | None = None) -> FunctionType"), + // clang-format on "Gets a FunctionType from a list of input and result types"); c.def_prop_ro( "inputs",