Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 21 additions & 17 deletions mlir/lib/Bindings/Python/IRAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -485,13 +485,14 @@ class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {

PyArrayAttributeIterator &dunderIter() { return *this; }

nb::object dunderNext() {
nb::typed<nb::object, PyAttribute> dunderNext() {
// TODO: Throw is an inefficient way to stop iteration.
if (nextIndex >= mlirArrayAttrGetNumElements(attr.get()))
throw nb::stop_iteration();
return PyAttribute(this->attr.getContext(),
mlirArrayAttrGetElement(attr.get(), nextIndex++))
.maybeDownCast();
return nb::cast<nb::typed<nb::object, PyAttribute>>(
PyAttribute(this->attr.getContext(),
mlirArrayAttrGetElement(attr.get(), nextIndex++))
.maybeDownCast());
}

static void bind(nb::module_ &m) {
Expand Down Expand Up @@ -524,13 +525,13 @@ class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
},
nb::arg("attributes"), nb::arg("context") = nb::none(),
"Gets a uniqued Array attribute");
c.def(
"__getitem__",
[](PyArrayAttribute &arr, intptr_t i) {
if (i >= mlirArrayAttrGetNumElements(arr))
throw nb::index_error("ArrayAttribute index out of range");
return PyAttribute(arr.getContext(), arr.getItem(i)).maybeDownCast();
})
c.def("__getitem__",
[](PyArrayAttribute &arr, intptr_t i) {
if (i >= mlirArrayAttrGetNumElements(arr))
throw nb::index_error("ArrayAttribute index out of range");
return nb::cast<nb::typed<nb::object, PyAttribute>>(
PyAttribute(arr.getContext(), arr.getItem(i)).maybeDownCast());
})
.def("__len__",
[](const PyArrayAttribute &arr) {
return mlirArrayAttrGetNumElements(arr);
Expand Down Expand Up @@ -1014,9 +1015,10 @@ class PyDenseElementsAttribute
if (!mlirDenseElementsAttrIsSplat(self))
throw nb::value_error(
"get_splat_value called on a non-splat attribute");
return PyAttribute(self.getContext(),
mlirDenseElementsAttrGetSplatValue(self))
.maybeDownCast();
return nb::cast<nb::typed<nb::object, PyAttribute>>(
PyAttribute(self.getContext(),
mlirDenseElementsAttrGetSplatValue(self))
.maybeDownCast());
});
}

Expand Down Expand Up @@ -1527,7 +1529,8 @@ class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
if (mlirAttributeIsNull(attr))
throw nb::key_error("attempt to access a non-existent attribute");
return PyAttribute(self.getContext(), attr).maybeDownCast();
return nb::cast<nb::typed<nb::object, PyAttribute>>(
PyAttribute(self.getContext(), attr).maybeDownCast());
});
c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
if (index < 0 || index >= self.dunderLen()) {
Expand Down Expand Up @@ -1595,8 +1598,9 @@ class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
nb::arg("value"), nb::arg("context") = nb::none(),
"Gets a uniqued Type attribute");
c.def_prop_ro("value", [](PyTypeAttribute &self) {
return PyType(self.getContext(), mlirTypeAttrGetValue(self.get()))
.maybeDownCast();
return nb::cast<nb::typed<nb::object, PyType>>(
PyType(self.getContext(), mlirTypeAttrGetValue(self.get()))
.maybeDownCast());
});
}
};
Expand Down
82 changes: 52 additions & 30 deletions mlir/lib/Bindings/Python/IRCore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,8 @@ class PyOperationIterator {
static void bind(nb::module_ &m) {
nb::class_<PyOperationIterator>(m, "OperationIterator")
.def("__iter__", &PyOperationIterator::dunderIter)
.def("__next__", &PyOperationIterator::dunderNext);
.def("__next__", &PyOperationIterator::dunderNext,
nb::sig("def __next__(self) -> OpView"));
}

private:
Expand Down Expand Up @@ -1604,8 +1605,9 @@ class PyConcreteValue : public PyValue {
return DerivedTy::isaFunction(otherValue);
},
nb::arg("other_value"));
cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
[](DerivedTy &self) { return self.maybeDownCast(); });
cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, [](DerivedTy &self) {
return nb::cast<nb::typed<nb::object, DerivedTy>>(self.maybeDownCast());
});
DerivedTy::bindDerived(cls);
}

Expand Down Expand Up @@ -1638,14 +1640,15 @@ class PyOpResult : public PyConcreteValue<PyOpResult> {

/// Returns the list of types of the values held by container.
template <typename Container>
static std::vector<nb::object> getValueTypes(Container &container,
PyMlirContextRef &context) {
std::vector<nb::object> result;
static std::vector<nb::typed<nb::object, PyType>>
getValueTypes(Container &container, PyMlirContextRef &context) {
std::vector<nb::typed<nb::object, PyType>> result;
result.reserve(container.size());
for (int i = 0, e = container.size(); i < e; ++i) {
result.push_back(PyType(context->getRef(),
mlirValueGetType(container.getElement(i).get()))
.maybeDownCast());
result.push_back(nb::cast<nb::typed<nb::object, PyType>>(
PyType(context->getRef(),
mlirValueGetType(container.getElement(i).get()))
.maybeDownCast()));
}
return result;
}
Expand Down Expand Up @@ -2677,13 +2680,15 @@ class PyOpAttributeMap {
PyOpAttributeMap(PyOperationRef operation)
: operation(std::move(operation)) {}

nb::object dunderGetItemNamed(const std::string &name) {
nb::typed<nb::object, PyAttribute>
dunderGetItemNamed(const std::string &name) {
MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
toMlirStringRef(name));
if (mlirAttributeIsNull(attr)) {
throw nb::key_error("attempt to access a non-existent attribute");
}
return PyAttribute(operation->getContext(), attr).maybeDownCast();
return nb::cast<nb::typed<nb::object, PyAttribute>>(
PyAttribute(operation->getContext(), attr).maybeDownCast());
}

PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
Expand Down Expand Up @@ -2961,14 +2966,17 @@ void mlir::python::populateIRCore(nb::module_ &m) {
new (&self) PyMlirContext(context);
})
.def_static("_get_live_count", &PyMlirContext::getLiveCount)
.def("_get_context_again",
[](PyMlirContext &self) {
PyMlirContextRef ref = PyMlirContext::forContext(self.get());
return ref.releaseObject();
})
.def(
"_get_context_again",
[](PyMlirContext &self) {
PyMlirContextRef ref = PyMlirContext::forContext(self.get());
return ref.releaseObject();
},
nb::sig("def _get_context_again(self) -> Context"))
.def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
.def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule)
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule,
nb::sig("def _CAPICreate(self) -> Context"))
.def("__enter__", &PyMlirContext::contextEnter)
.def("__exit__", &PyMlirContext::contextExit, nb::arg("exc_type").none(),
nb::arg("exc_value").none(), nb::arg("traceback").none())
Expand Down Expand Up @@ -3463,8 +3471,9 @@ void mlir::python::populateIRCore(nb::module_ &m) {
"result",
[](PyOperationBase &self) {
auto &operation = self.getOperation();
return PyOpResult(operation.getRef(), getUniqueResult(operation))
.maybeDownCast();
return nb::cast<nb::typed<nb::object, PyOpResult>>(
PyOpResult(operation.getRef(), getUniqueResult(operation))
.maybeDownCast());
},
"Shortcut to get an op result if it has only one (throws an error "
"otherwise).")
Expand Down Expand Up @@ -3988,7 +3997,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
context->get(), toMlirStringRef(attrSpec));
if (mlirAttributeIsNull(attr))
throw MLIRError("Unable to parse attribute", errors.take());
return PyAttribute(context.get()->getRef(), attr).maybeDownCast();
return nb::cast<nb::typed<nb::object, PyAttribute>>(
PyAttribute(context.get()->getRef(), attr).maybeDownCast());
},
nb::arg("asm"), nb::arg("context") = nb::none(),
"Parses an attribute from an assembly form. Raises an MLIRError on "
Expand All @@ -3999,9 +4009,9 @@ void mlir::python::populateIRCore(nb::module_ &m) {
"Context that owns the Attribute")
.def_prop_ro("type",
[](PyAttribute &self) {
return PyType(self.getContext(),
mlirAttributeGetType(self))
.maybeDownCast();
return nb::cast<nb::typed<nb::object, PyType>>(
PyType(self.getContext(), mlirAttributeGetType(self))
.maybeDownCast());
})
.def(
"get_named",
Expand Down Expand Up @@ -4049,7 +4059,9 @@ void mlir::python::populateIRCore(nb::module_ &m) {
"mlirTypeID was expected to be non-null.");
return PyTypeID(mlirTypeID);
})
.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, &PyAttribute::maybeDownCast);
.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, [](PyAttribute &self) {
nb::cast<nb::typed<nb::object, PyAttribute>>(self.maybeDownCast());
});

//----------------------------------------------------------------------------
// Mapping of PyNamedAttribute
Expand Down Expand Up @@ -4100,7 +4112,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
if (mlirTypeIsNull(type))
throw MLIRError("Unable to parse type", errors.take());
return PyType(context.get()->getRef(), type).maybeDownCast();
return nb::cast<nb::typed<nb::object, PyType>>(
PyType(context.get()->getRef(), type).maybeDownCast());
},
nb::arg("asm"), nb::arg("context") = nb::none(),
kContextParseTypeDocstring)
Expand Down Expand Up @@ -4139,7 +4152,11 @@ void mlir::python::populateIRCore(nb::module_ &m) {
printAccum.parts.append(")");
return printAccum.join();
})
.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, &PyType::maybeDownCast)
.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
[](PyType &self) {
return nb::cast<nb::typed<nb::object, PyType>>(
self.maybeDownCast());
})
.def_prop_ro("typeid", [](PyType &self) {
MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
if (!mlirTypeIDIsNull(mlirTypeID))
Expand Down Expand Up @@ -4267,9 +4284,10 @@ void mlir::python::populateIRCore(nb::module_ &m) {
nb::arg("state"), kGetNameAsOperand)
.def_prop_ro("type",
[](PyValue &self) {
return PyType(self.getParentOperation()->getContext(),
mlirValueGetType(self.get()))
.maybeDownCast();
return nb::cast<nb::typed<nb::object, PyType>>(
PyType(self.getParentOperation()->getContext(),
mlirValueGetType(self.get()))
.maybeDownCast());
})
.def(
"set_type",
Expand Down Expand Up @@ -4305,7 +4323,11 @@ void mlir::python::populateIRCore(nb::module_ &m) {
},
nb::arg("with_"), nb::arg("exceptions"),
kValueReplaceAllUsesExceptDocstring)
.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, &PyValue::maybeDownCast)
.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
[](PyValue &self) {
return nb::cast<nb::typed<nb::object, PyValue>>(
self.maybeDownCast());
})
.def_prop_ro(
"location",
[](MlirValue self) {
Expand Down
5 changes: 3 additions & 2 deletions mlir/lib/Bindings/Python/IRModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -1102,8 +1102,9 @@ class PyConcreteAttribute : public BaseTy {
},
nanobind::arg("other"));
cls.def_prop_ro("type", [](PyAttribute &attr) {
return PyType(attr.getContext(), mlirAttributeGetType(attr))
.maybeDownCast();
return nanobind::cast<nanobind::typed<nanobind::object, PyType>>(
PyType(attr.getContext(), mlirAttributeGetType(attr))
.maybeDownCast());
});
cls.def_prop_ro_static(
"static_typeid",
Expand Down
44 changes: 25 additions & 19 deletions mlir/lib/Bindings/Python/IRTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -502,8 +502,9 @@ class PyComplexType : public PyConcreteType<PyComplexType> {
c.def_prop_ro(
"element_type",
[](PyComplexType &self) {
return PyType(self.getContext(), mlirComplexTypeGetElementType(self))
.maybeDownCast();
return nb::cast<nb::typed<nb::object, PyType>>(
PyType(self.getContext(), mlirComplexTypeGetElementType(self))
.maybeDownCast());
},
"Returns element type.");
}
Expand All @@ -516,8 +517,9 @@ void mlir::PyShapedType::bindDerived(ClassTy &c) {
c.def_prop_ro(
"element_type",
[](PyShapedType &self) {
return PyType(self.getContext(), mlirShapedTypeGetElementType(self))
.maybeDownCast();
return nb::cast<nb::typed<nb::object, PyType>>(
PyType(self.getContext(), mlirShapedTypeGetElementType(self))
.maybeDownCast());
},
"Returns the element type of the shaped type.");
c.def_prop_ro(
Expand Down Expand Up @@ -898,11 +900,21 @@ class PyTupleType : public PyConcreteType<PyTupleType> {
},
nb::arg("elements"), nb::arg("context") = nb::none(),
"Create a tuple type");
c.def_static(
"get_tuple",
[](std::vector<MlirType> elements, DefaultingPyMlirContext context) {
MlirType t = mlirTupleTypeGet(context->get(), elements.size(),
elements.data());
return PyTupleType(context->getRef(), t);
},
nb::arg("elements"), nb::arg("context") = nb::none(),
"Create a tuple type");
c.def(
"get_type",
[](PyTupleType &self, intptr_t pos) {
return PyType(self.getContext(), mlirTupleTypeGetType(self, pos))
.maybeDownCast();
return nb::cast<nb::typed<nb::object, PyType>>(
PyType(self.getContext(), mlirTupleTypeGetType(self, pos))
.maybeDownCast());
},
nb::arg("pos"), "Returns the pos-th type in the tuple type.");
c.def_prop_ro(
Expand All @@ -926,23 +938,17 @@ class PyFunctionType : public PyConcreteType<PyFunctionType> {
static void bindDerived(ClassTy &c) {
c.def_static(
"get",
[](std::vector<PyType> inputs, std::vector<PyType> results,
[](std::vector<MlirType> inputs, std::vector<MlirType> results,
DefaultingPyMlirContext context) {
std::vector<MlirType> mlirInputs;
mlirInputs.reserve(inputs.size());
for (const auto &input : inputs)
mlirInputs.push_back(input.get());
std::vector<MlirType> mlirResults;
mlirResults.reserve(results.size());
for (const auto &result : results)
mlirResults.push_back(result.get());

MlirType t = mlirFunctionTypeGet(context->get(), inputs.size(),
mlirInputs.data(), results.size(),
mlirResults.data());
MlirType t =
mlirFunctionTypeGet(context->get(), inputs.size(), inputs.data(),
results.size(), results.data());
return PyFunctionType(context->getRef(), t);
},
nb::arg("inputs"), nb::arg("results"), nb::arg("context") = nb::none(),
// clang-format off
nb::sig("def get(inputs: Sequence[Type], results: Sequence[Type], context: mlir.ir.Context | None = None) -> FunctionType"),
// clang-format on
"Gets a FunctionType from a list of input and result types");
c.def_prop_ro(
"inputs",
Expand Down
Loading