From 9886e3edd6c161b3348901f61487d0d6b131cd45 Mon Sep 17 00:00:00 2001 From: Asher Mancinelli Date: Mon, 10 Nov 2025 20:13:10 -0800 Subject: [PATCH 1/6] [mlir][python] Wrappers for scf.index_switch The C++ index switch op has utilies for getCaseBlock(int i) and getDefaultBlock(), so these have been added. Optional builder args have been added for the default case and each switch case. The list comprehensions for accessing case regions are due to what appears to be a bug in RegionSequence; using a comprehension with explicit indices circumvents this. The same paradigm is used for get_case_block(i: int), but this is unavoidable. --- mlir/python/mlir/dialects/scf.py | 76 +++++++++++++++++++ mlir/test/python/dialects/scf.py | 126 ++++++++++++++++++++++++++++++- 2 files changed, 198 insertions(+), 4 deletions(-) diff --git a/mlir/python/mlir/dialects/scf.py b/mlir/python/mlir/dialects/scf.py index 678ceeebac204..59ccbce147be3 100644 --- a/mlir/python/mlir/dialects/scf.py +++ b/mlir/python/mlir/dialects/scf.py @@ -6,12 +6,14 @@ from ._scf_ops_gen import * from ._scf_ops_gen import _Dialect from .arith import constant +import builtins try: from ..ir import * from ._ods_common import ( get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values, + get_op_result_or_op_results as _get_op_result_or_op_results, _cext as _ods_cext, ) except ImportError as e: @@ -254,3 +256,77 @@ def for_( yield iv, iter_args[0], for_op.results[0] else: yield iv + + +@_ods_cext.register_operation(_Dialect, replace=True) +class IndexSwitchOp(IndexSwitchOp): + __doc__ = IndexSwitchOp.__doc__ + + def __init__( + self, + results_, + arg, + cases, + case_body_builder=None, + default_body_builder=None, + loc=None, + ip=None, + ): + cases = DenseI64ArrayAttr.get(cases) + super().__init__( + results_, arg, cases, num_caseRegions=len(cases), loc=loc, ip=ip + ) + for region in self.regions: + region.blocks.append() + + if default_body_builder is not None: + with InsertionPoint(self.default_block): + default_body_builder(self) + + if case_body_builder is not None: + for i, case in enumerate(cases): + with InsertionPoint(self.case_block(i)): + case_body_builder(self, i, self.cases[i]) + + @builtins.property + def default_region(self) -> Region: + return self.regions[0] + + @builtins.property + def default_block(self) -> Block: + return self.default_region.blocks[0] + + @builtins.property + def case_regions(self) -> Sequence[Region]: + return [self.regions[1 + i] for i in range(len(self.cases))] + + def case_region(self, i: int) -> Region: + return self.case_regions[i] + + @builtins.property + def case_blocks(self) -> Sequence[Block]: + return [region.blocks[0] for region in self.case_regions] + + def case_block(self, i: int) -> Block: + return self.case_regions[i].blocks[0] + + +def index_switch( + results_, + arg, + cases, + case_body_builder=None, + default_body_builder=None, + loc=None, + ip=None, +) -> Union[OpResult, OpResultList, IndexSwitchOp]: + op = IndexSwitchOp( + results_=results_, + arg=arg, + cases=cases, + case_body_builder=case_body_builder, + default_body_builder=default_body_builder, + loc=loc, + ip=ip, + ) + return _get_op_result_or_op_results(op) diff --git a/mlir/test/python/dialects/scf.py b/mlir/test/python/dialects/scf.py index 62d11d5e189c8..11d207b4a5e07 100644 --- a/mlir/test/python/dialects/scf.py +++ b/mlir/test/python/dialects/scf.py @@ -1,10 +1,14 @@ # RUN: %PYTHON %s | FileCheck %s from mlir.ir import * -from mlir.dialects import arith -from mlir.dialects import func -from mlir.dialects import memref -from mlir.dialects import scf +from mlir.extras import types as T +from mlir.dialects import ( + arith, + func, + memref, + scf, + cf, +) from mlir.passmanager import PassManager @@ -355,3 +359,117 @@ def simple_if_else(cond): # CHECK: scf.yield %[[TWO]], %[[THREE]] # CHECK: arith.addi %[[RET]]#0, %[[RET]]#1 # CHECK: return + + +@constructAndPrintInModule +def testIndexSwitch(): + i32 = T.i32() + + @func.FuncOp.from_py_func(T.index(), results=[i32]) + def index_switch(index): + c1 = arith.constant(i32, 1) + c0 = arith.constant(i32, 0) + value = arith.constant(i32, 5) + switch_op = scf.IndexSwitchOp([i32], index, range(3)) + + assert switch_op.regions[0] == switch_op.default_region + assert switch_op.regions[1] == switch_op.case_regions[0] + assert switch_op.regions[1] == switch_op.case_region(0) + assert len(switch_op.case_regions) == 3 + assert len(switch_op.regions) == 4 + + with InsertionPoint(switch_op.default_block): + cf.assert_(arith.constant(T.bool(), 0), "Whoops!") + scf.yield_([c1]) + + for i, block in enumerate(switch_op.case_blocks): + with InsertionPoint(block): + scf.yield_([arith.constant(i32, i)]) + + func.return_([switch_op.results[0]]) + + return index_switch + + +# CHECK-LABEL: func.func @index_switch( +# CHECK-SAME: %[[ARG0:.*]]: index) -> i32 { +# CHECK: %[[CONSTANT_0:.*]] = arith.constant 1 : i32 +# CHECK: %[[CONSTANT_1:.*]] = arith.constant 0 : i32 +# CHECK: %[[CONSTANT_2:.*]] = arith.constant 5 : i32 +# CHECK: %[[INDEX_SWITCH_0:.*]] = scf.index_switch %[[ARG0]] -> i32 +# CHECK: case 0 { +# CHECK: %[[CONSTANT_3:.*]] = arith.constant 0 : i32 +# CHECK: scf.yield %[[CONSTANT_3]] : i32 +# CHECK: } +# CHECK: case 1 { +# CHECK: %[[CONSTANT_4:.*]] = arith.constant 1 : i32 +# CHECK: scf.yield %[[CONSTANT_4]] : i32 +# CHECK: } +# CHECK: case 2 { +# CHECK: %[[CONSTANT_5:.*]] = arith.constant 2 : i32 +# CHECK: scf.yield %[[CONSTANT_5]] : i32 +# CHECK: } +# CHECK: default { +# CHECK: %[[CONSTANT_6:.*]] = arith.constant false +# CHECK: cf.assert %[[CONSTANT_6]], "Whoops!" +# CHECK: scf.yield %[[CONSTANT_0]] : i32 +# CHECK: } +# CHECK: return %[[INDEX_SWITCH_0]] : i32 +# CHECK: } + + +@constructAndPrintInModule +def testIndexSwitchWithBodyBuilders(): + i32 = T.i32() + + @func.FuncOp.from_py_func(T.index(), results=[i32]) + def index_switch(index): + c1 = arith.constant(i32, 1) + c0 = arith.constant(i32, 0) + value = arith.constant(i32, 5) + + def default_body_builder(switch_op): + cf.assert_(arith.constant(T.bool(), 0), "Whoops!") + scf.yield_([c1]) + + def case_body_builder(switch_op, case_index: int, case_value: int): + scf.yield_([arith.constant(i32, case_value)]) + + result = scf.index_switch( + results_=[i32], + arg=index, + cases=range(3), + case_body_builder=case_body_builder, + default_body_builder=default_body_builder, + ) + + func.return_([result]) + + return index_switch + + +# CHECK-LABEL: func.func @index_switch( +# CHECK-SAME: %[[ARG0:.*]]: index) -> i32 { +# CHECK: %[[CONSTANT_0:.*]] = arith.constant 1 : i32 +# CHECK: %[[CONSTANT_1:.*]] = arith.constant 0 : i32 +# CHECK: %[[CONSTANT_2:.*]] = arith.constant 5 : i32 +# CHECK: %[[INDEX_SWITCH_0:.*]] = scf.index_switch %[[ARG0]] -> i32 +# CHECK: case 0 { +# CHECK: %[[CONSTANT_3:.*]] = arith.constant 0 : i32 +# CHECK: scf.yield %[[CONSTANT_3]] : i32 +# CHECK: } +# CHECK: case 1 { +# CHECK: %[[CONSTANT_4:.*]] = arith.constant 1 : i32 +# CHECK: scf.yield %[[CONSTANT_4]] : i32 +# CHECK: } +# CHECK: case 2 { +# CHECK: %[[CONSTANT_5:.*]] = arith.constant 2 : i32 +# CHECK: scf.yield %[[CONSTANT_5]] : i32 +# CHECK: } +# CHECK: default { +# CHECK: %[[CONSTANT_6:.*]] = arith.constant false +# CHECK: cf.assert %[[CONSTANT_6]], "Whoops!" +# CHECK: scf.yield %[[CONSTANT_0]] : i32 +# CHECK: } +# CHECK: return %[[INDEX_SWITCH_0]] : i32 +# CHECK: } From 854729950e2fb816e4ac7055b160ee7eb49b98a6 Mon Sep 17 00:00:00 2001 From: Asher Mancinelli Date: Tue, 11 Nov 2025 07:48:28 -0800 Subject: [PATCH 2/6] Update region slicing after #167466 --- mlir/python/mlir/dialects/scf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/python/mlir/dialects/scf.py b/mlir/python/mlir/dialects/scf.py index 59ccbce147be3..d0dcda1fbc1e9 100644 --- a/mlir/python/mlir/dialects/scf.py +++ b/mlir/python/mlir/dialects/scf.py @@ -298,7 +298,7 @@ def default_block(self) -> Block: @builtins.property def case_regions(self) -> Sequence[Region]: - return [self.regions[1 + i] for i in range(len(self.cases))] + return self.regions[1:] def case_region(self, i: int) -> Region: return self.case_regions[i] From 12dc35771c22ca78a7c08b260305bed8b699992f Mon Sep 17 00:00:00 2001 From: Asher Mancinelli Date: Tue, 11 Nov 2025 08:02:14 -0800 Subject: [PATCH 3/6] Update test with new IndexSwitchOp interface --- mlir/test/python/ir/operation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py index 1bdd345d98c05..5585ee7c82e04 100644 --- a/mlir/test/python/ir/operation.py +++ b/mlir/test/python/ir/operation.py @@ -1213,7 +1213,7 @@ def testIndexSwitch(): def index_switch(index): c1 = arith.constant(i32, 1) switch_op = scf.IndexSwitchOp( - results_=[i32], arg=index, cases=range(3), num_caseRegions=3 + results_=[i32], arg=index, cases=range(3) ) assert len(switch_op.regions) == 4 From ee73bdd8972f4b64622f406d0fa86ee5db130911 Mon Sep 17 00:00:00 2001 From: Asher Mancinelli Date: Tue, 11 Nov 2025 08:11:40 -0800 Subject: [PATCH 4/6] Format --- mlir/test/python/ir/operation.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py index 5585ee7c82e04..9dcc1d4c1eb08 100644 --- a/mlir/test/python/ir/operation.py +++ b/mlir/test/python/ir/operation.py @@ -1212,9 +1212,7 @@ def testIndexSwitch(): @func.FuncOp.from_py_func(T.index()) def index_switch(index): c1 = arith.constant(i32, 1) - switch_op = scf.IndexSwitchOp( - results_=[i32], arg=index, cases=range(3) - ) + switch_op = scf.IndexSwitchOp(results_=[i32], arg=index, cases=range(3)) assert len(switch_op.regions) == 4 assert len(switch_op.regions[2:]) == 2 From 32cb7996a0d66e1c01e28e543ccb0a7ea17dc3f4 Mon Sep 17 00:00:00 2001 From: Asher Mancinelli Date: Tue, 11 Nov 2025 09:11:58 -0800 Subject: [PATCH 5/6] Address nits --- mlir/python/mlir/dialects/scf.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/mlir/python/mlir/dialects/scf.py b/mlir/python/mlir/dialects/scf.py index d0dcda1fbc1e9..9e22df3dd50a9 100644 --- a/mlir/python/mlir/dialects/scf.py +++ b/mlir/python/mlir/dialects/scf.py @@ -6,7 +6,6 @@ from ._scf_ops_gen import * from ._scf_ops_gen import _Dialect from .arith import constant -import builtins try: from ..ir import * @@ -264,7 +263,7 @@ class IndexSwitchOp(IndexSwitchOp): def __init__( self, - results_, + results, arg, cases, case_body_builder=None, @@ -274,7 +273,7 @@ def __init__( ): cases = DenseI64ArrayAttr.get(cases) super().__init__( - results_, arg, cases, num_caseRegions=len(cases), loc=loc, ip=ip + results, arg, cases, num_caseRegions=len(cases), loc=loc, ip=ip ) for region in self.regions: region.blocks.append() @@ -288,22 +287,22 @@ def __init__( with InsertionPoint(self.case_block(i)): case_body_builder(self, i, self.cases[i]) - @builtins.property + @property def default_region(self) -> Region: return self.regions[0] - @builtins.property + @property def default_block(self) -> Block: return self.default_region.blocks[0] - @builtins.property + @property def case_regions(self) -> Sequence[Region]: return self.regions[1:] def case_region(self, i: int) -> Region: return self.case_regions[i] - @builtins.property + @property def case_blocks(self) -> Sequence[Block]: return [region.blocks[0] for region in self.case_regions] @@ -312,7 +311,7 @@ def case_block(self, i: int) -> Block: def index_switch( - results_, + results, arg, cases, case_body_builder=None, @@ -321,7 +320,7 @@ def index_switch( ip=None, ) -> Union[OpResult, OpResultList, IndexSwitchOp]: op = IndexSwitchOp( - results_=results_, + results=results, arg=arg, cases=cases, case_body_builder=case_body_builder, From d82b0a1a17b3e0bfa569405651c40100d2954c74 Mon Sep 17 00:00:00 2001 From: Asher Mancinelli Date: Tue, 11 Nov 2025 09:24:32 -0800 Subject: [PATCH 6/6] Update tests to use non-suffixed kwarg --- mlir/test/python/dialects/scf.py | 2 +- mlir/test/python/ir/operation.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/test/python/dialects/scf.py b/mlir/test/python/dialects/scf.py index 11d207b4a5e07..0c0c9b986562b 100644 --- a/mlir/test/python/dialects/scf.py +++ b/mlir/test/python/dialects/scf.py @@ -436,7 +436,7 @@ def case_body_builder(switch_op, case_index: int, case_value: int): scf.yield_([arith.constant(i32, case_value)]) result = scf.index_switch( - results_=[i32], + results=[i32], arg=index, cases=range(3), case_body_builder=case_body_builder, diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py index 9dcc1d4c1eb08..66ba5d28e49b2 100644 --- a/mlir/test/python/ir/operation.py +++ b/mlir/test/python/ir/operation.py @@ -1212,7 +1212,7 @@ def testIndexSwitch(): @func.FuncOp.from_py_func(T.index()) def index_switch(index): c1 = arith.constant(i32, 1) - switch_op = scf.IndexSwitchOp(results_=[i32], arg=index, cases=range(3)) + switch_op = scf.IndexSwitchOp(results=[i32], arg=index, cases=range(3)) assert len(switch_op.regions) == 4 assert len(switch_op.regions[2:]) == 2