diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h index 330318683c15e..bd0c715813bd9 100644 --- a/mlir/include/mlir/Bindings/Python/IRCore.h +++ b/mlir/include/mlir/Bindings/Python/IRCore.h @@ -1374,7 +1374,7 @@ class MLIR_PYTHON_API_EXPORTED PyRegionIterator { PyRegionIterator &dunderIter() { return *this; } - PyRegion dunderNext(); + nanobind::typed dunderNext(); static void bind(nanobind::module_ &m); @@ -1417,7 +1417,7 @@ class MLIR_PYTHON_API_EXPORTED PyBlockIterator { PyBlockIterator &dunderIter() { return *this; } - PyBlock dunderNext(); + nanobind::typed dunderNext(); static void bind(nanobind::module_ &m); @@ -1508,7 +1508,7 @@ class MLIR_PYTHON_API_EXPORTED PyOpOperandIterator { PyOpOperandIterator &dunderIter() { return *this; } - PyOpOperand dunderNext(); + nanobind::typed dunderNext(); static void bind(nanobind::module_ &m); diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index 3685ff0d602e2..b2e9d9887e098 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -226,8 +226,11 @@ PyArrayAttribute::PyArrayAttributeIterator::dunderNext() { // TODO: Throw is an inefficient way to stop iteration. if (PyArrayAttribute::PyArrayAttributeIterator::nextIndex >= mlirArrayAttrGetNumElements( - PyArrayAttribute::PyArrayAttributeIterator::attr.get())) - throw nb::stop_iteration(); + PyArrayAttribute::PyArrayAttributeIterator::attr.get())) { + PyErr_SetNone(PyExc_StopIteration); + // python functions should return NULL after setting any exception + return nb::object(); + } return PyAttribute( this->PyArrayAttribute::PyArrayAttributeIterator::attr .getContext(), diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 19db41fae4fe2..1fe35d0f3fae5 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -194,13 +194,15 @@ nb::object PyBlock::getCapsule() { // Collections. //------------------------------------------------------------------------------ -PyRegion PyRegionIterator::dunderNext() { +nb::typed PyRegionIterator::dunderNext() { operation->checkValid(); if (nextIndex >= mlirOperationGetNumRegions(operation->get())) { - throw nb::stop_iteration(); + PyErr_SetNone(PyExc_StopIteration); + // python functions should return NULL after setting any exception + return nb::object(); } MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++); - return PyRegion(operation, region); + return nb::cast(PyRegion(operation, region)); } void PyRegionIterator::bind(nb::module_ &m) { @@ -244,15 +246,17 @@ PyRegionList PyRegionList::slice(intptr_t startIndex, intptr_t length, return PyRegionList(operation, startIndex, length, step); } -PyBlock PyBlockIterator::dunderNext() { +nb::typed PyBlockIterator::dunderNext() { operation->checkValid(); if (mlirBlockIsNull(next)) { - throw nb::stop_iteration(); + PyErr_SetNone(PyExc_StopIteration); + // python functions should return NULL after setting any exception + return nb::object(); } PyBlock returnBlock(operation, next); next = mlirBlockGetNextInRegion(next); - return returnBlock; + return nb::cast(returnBlock); } void PyBlockIterator::bind(nb::module_ &m) { @@ -327,7 +331,9 @@ void PyBlockList::bind(nb::module_ &m) { nb::typed PyOperationIterator::dunderNext() { parentOperation->checkValid(); if (mlirOperationIsNull(next)) { - throw nb::stop_iteration(); + PyErr_SetNone(PyExc_StopIteration); + // python functions should return NULL after setting any exception + return nb::object(); } PyOperationRef returnOperation = @@ -410,13 +416,16 @@ void PyOpOperand::bind(nb::module_ &m) { "Returns the operand number in the owning operation."); } -PyOpOperand PyOpOperandIterator::dunderNext() { - if (mlirOpOperandIsNull(opOperand)) - throw nb::stop_iteration(); +nb::typed PyOpOperandIterator::dunderNext() { + if (mlirOpOperandIsNull(opOperand)) { + PyErr_SetNone(PyExc_StopIteration); + // python functions should return NULL after setting any exception + return nb::object(); + } PyOpOperand returnOpOperand(opOperand); opOperand = mlirOpOperandGetNextUse(opOperand); - return returnOpOperand; + return nb::cast(returnOpOperand); } void PyOpOperandIterator::bind(nb::module_ &m) {