Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions mlir/lib/Bindings/Python/IRCore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,8 @@ namespace {

class PyRegionIterator {
public:
PyRegionIterator(PyOperationRef operation)
: operation(std::move(operation)) {}
PyRegionIterator(PyOperationRef operation, int nextIndex)
: operation(std::move(operation)), nextIndex(nextIndex) {}

PyRegionIterator &dunderIter() { return *this; }

Expand All @@ -228,7 +228,7 @@ class PyRegionIterator {

private:
PyOperationRef operation;
int nextIndex = 0;
intptr_t nextIndex = 0;
};

/// Regions of an op are fixed length and indexed numerically so are represented
Expand All @@ -247,7 +247,7 @@ class PyRegionList : public Sliceable<PyRegionList, PyRegion> {

PyRegionIterator dunderIter() {
operation->checkValid();
return PyRegionIterator(operation);
return PyRegionIterator(operation, startIndex);
}

static void bindDerived(ClassTy &c) {
Expand Down
1 change: 0 additions & 1 deletion mlir/lib/Bindings/Python/NanobindUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,6 @@ class Sliceable {
/// Hook for derived classes willing to bind more methods.
static void bindDerived(ClassTy &) {}

private:
intptr_t startIndex;
intptr_t length;
intptr_t step;
Expand Down
28 changes: 26 additions & 2 deletions mlir/test/python/ir/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

import gc
import io
import itertools
from tempfile import NamedTemporaryFile
from mlir.ir import *
from mlir.dialects.builtin import ModuleOp
from mlir.dialects import arith
from mlir.dialects import arith, func, scf
from mlir.dialects._ods_common import _cext
from mlir.extras import types as T


def run(f):
Expand Down Expand Up @@ -1199,3 +1199,27 @@ def testGetOwnerConcreteOpview():
r = arith.AddIOp(a, a, overflowFlags=arith.IntegerOverflowFlags.nsw)
for u in a.result.uses:
assert isinstance(u.owner, arith.AddIOp)


# CHECK-LABEL: TEST: testIndexSwitch
@run
def testIndexSwitch():
with Context() as ctx, Location.unknown():
i32 = T.i32()
module = Module.create()
with InsertionPoint(module.body):

@func.FuncOp.from_py_func(T.index())
def index_switch(index):
c1 = arith.constant(i32, 1)
switch_op = scf.IndexSwitchOp(
results_=[i32], arg=index, cases=range(3), num_caseRegions=3
)

assert len(switch_op.regions) == 4
assert len(switch_op.regions[2:]) == 2
assert len([i for i in switch_op.regions[2:]]) == 2
assert len(switch_op.caseRegions) == 3
assert len([i for i in switch_op.caseRegions]) == 3
assert len(switch_op.caseRegions[1:]) == 2
assert len([i for i in switch_op.caseRegions[1:]]) == 2
Loading