diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index d90f27bd037e6..40a466beee159 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -204,8 +204,8 @@ namespace { class PyRegionIterator { public: - PyRegionIterator(PyOperationRef operation) - : operation(std::move(operation)) {} + PyRegionIterator(PyOperationRef operation, int nextIndex) + : operation(std::move(operation)), nextIndex(nextIndex) {} PyRegionIterator &dunderIter() { return *this; } @@ -228,7 +228,7 @@ class PyRegionIterator { private: PyOperationRef operation; - int nextIndex = 0; + intptr_t nextIndex = 0; }; /// Regions of an op are fixed length and indexed numerically so are represented @@ -247,7 +247,7 @@ class PyRegionList : public Sliceable { PyRegionIterator dunderIter() { operation->checkValid(); - return PyRegionIterator(operation); + return PyRegionIterator(operation, startIndex); } static void bindDerived(ClassTy &c) { diff --git a/mlir/lib/Bindings/Python/NanobindUtils.h b/mlir/lib/Bindings/Python/NanobindUtils.h index 64ea4329f65f1..658e8ad5330ef 100644 --- a/mlir/lib/Bindings/Python/NanobindUtils.h +++ b/mlir/lib/Bindings/Python/NanobindUtils.h @@ -395,7 +395,6 @@ class Sliceable { /// Hook for derived classes willing to bind more methods. static void bindDerived(ClassTy &) {} -private: intptr_t startIndex; intptr_t length; intptr_t step; diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py index f5fa4dad856f8..1bdd345d98c05 100644 --- a/mlir/test/python/ir/operation.py +++ b/mlir/test/python/ir/operation.py @@ -2,12 +2,12 @@ import gc import io -import itertools from tempfile import NamedTemporaryFile from mlir.ir import * from mlir.dialects.builtin import ModuleOp -from mlir.dialects import arith +from mlir.dialects import arith, func, scf from mlir.dialects._ods_common import _cext +from mlir.extras import types as T def run(f): @@ -1199,3 +1199,27 @@ def testGetOwnerConcreteOpview(): r = arith.AddIOp(a, a, overflowFlags=arith.IntegerOverflowFlags.nsw) for u in a.result.uses: assert isinstance(u.owner, arith.AddIOp) + + +# CHECK-LABEL: TEST: testIndexSwitch +@run +def testIndexSwitch(): + with Context() as ctx, Location.unknown(): + i32 = T.i32() + module = Module.create() + with InsertionPoint(module.body): + + @func.FuncOp.from_py_func(T.index()) + def index_switch(index): + c1 = arith.constant(i32, 1) + switch_op = scf.IndexSwitchOp( + results_=[i32], arg=index, cases=range(3), num_caseRegions=3 + ) + + assert len(switch_op.regions) == 4 + assert len(switch_op.regions[2:]) == 2 + assert len([i for i in switch_op.regions[2:]]) == 2 + assert len(switch_op.caseRegions) == 3 + assert len([i for i in switch_op.caseRegions]) == 3 + assert len(switch_op.caseRegions[1:]) == 2 + assert len([i for i in switch_op.caseRegions[1:]]) == 2