diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 12793f7dd15be..dc41aaea3261c 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -361,37 +361,45 @@ class PyRegionIterator { /// Regions of an op are fixed length and indexed numerically so are represented /// with a sequence-like container. -class PyRegionList { +class PyRegionList : public Sliceable { public: - PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {} + static constexpr const char *pyClassName = "RegionSequence"; + + PyRegionList(PyOperationRef operation, intptr_t startIndex = 0, + intptr_t length = -1, intptr_t step = 1) + : Sliceable(startIndex, + length == -1 ? mlirOperationGetNumRegions(operation->get()) + : length, + step), + operation(std::move(operation)) {} PyRegionIterator dunderIter() { operation->checkValid(); return PyRegionIterator(operation); } - intptr_t dunderLen() { + static void bindDerived(ClassTy &c) { + c.def("__iter__", &PyRegionList::dunderIter); + } + +private: + /// Give the parent CRTP class access to hook implementations below. + friend class Sliceable; + + intptr_t getRawNumElements() { operation->checkValid(); return mlirOperationGetNumRegions(operation->get()); } - PyRegion dunderGetItem(intptr_t index) { - // dunderLen checks validity. - if (index < 0 || index >= dunderLen()) { - throw nb::index_error("attempt to access out of bounds region"); - } - MlirRegion region = mlirOperationGetRegion(operation->get(), index); - return PyRegion(operation, region); + PyRegion getRawElement(intptr_t pos) { + operation->checkValid(); + return PyRegion(operation, mlirOperationGetRegion(operation->get(), pos)); } - static void bind(nb::module_ &m) { - nb::class_(m, "RegionSequence") - .def("__len__", &PyRegionList::dunderLen) - .def("__iter__", &PyRegionList::dunderIter) - .def("__getitem__", &PyRegionList::dunderGetItem); + PyRegionList slice(intptr_t startIndex, intptr_t length, intptr_t step) { + return PyRegionList(operation, startIndex, length, step); } -private: PyOperationRef operation; }; @@ -450,6 +458,9 @@ class PyBlockList { PyBlock dunderGetItem(intptr_t index) { operation->checkValid(); + if (index < 0) { + index += dunderLen(); + } if (index < 0) { throw nb::index_error("attempt to access out of bounds block"); } @@ -546,6 +557,9 @@ class PyOperationList { nb::object dunderGetItem(intptr_t index) { parentOperation->checkValid(); + if (index < 0) { + index += dunderLen(); + } if (index < 0) { throw nb::index_error("attempt to access out of bounds operation"); } @@ -2629,6 +2643,9 @@ class PyOpAttributeMap { } PyNamedAttribute dunderGetItemIndexed(intptr_t index) { + if (index < 0) { + index += dunderLen(); + } if (index < 0 || index >= dunderLen()) { throw nb::index_error("attempt to access out of bounds attribute"); } diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi index c93de2fe3154e..c60ff72ff9fd4 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi @@ -2466,7 +2466,10 @@ class RegionIterator: def __next__(self) -> Region: ... class RegionSequence: + @overload def __getitem__(self, arg0: int) -> Region: ... + @overload + def __getitem__(self, arg0: slice) -> Sequence[Region]: ... def __iter__(self) -> RegionIterator: ... def __len__(self) -> int: ... diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py index dd2731ba2e1f1..b08fe98397fbc 100644 --- a/mlir/test/python/ir/operation.py +++ b/mlir/test/python/ir/operation.py @@ -44,7 +44,7 @@ def testTraverseOpRegionBlockIterators(): op = module.operation assert op.context is ctx # Get the block using iterators off of the named collections. - regions = list(op.regions) + regions = list(op.regions[:]) blocks = list(regions[0].blocks) # CHECK: MODULE REGIONS=1 BLOCKS=1 print(f"MODULE REGIONS={len(regions)} BLOCKS={len(blocks)}") @@ -86,8 +86,24 @@ def walk_operations(indent, op): # CHECK: Block iter: