diff --git a/mlir/python/mlir/dialects/scf.py b/mlir/python/mlir/dialects/scf.py index 678ceeebac204..9e22df3dd50a9 100644 --- a/mlir/python/mlir/dialects/scf.py +++ b/mlir/python/mlir/dialects/scf.py @@ -12,6 +12,7 @@ from ._ods_common import ( get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values, + get_op_result_or_op_results as _get_op_result_or_op_results, _cext as _ods_cext, ) except ImportError as e: @@ -254,3 +255,77 @@ def for_( yield iv, iter_args[0], for_op.results[0] else: yield iv + + +@_ods_cext.register_operation(_Dialect, replace=True) +class IndexSwitchOp(IndexSwitchOp): + __doc__ = IndexSwitchOp.__doc__ + + def __init__( + self, + results, + arg, + cases, + case_body_builder=None, + default_body_builder=None, + loc=None, + ip=None, + ): + cases = DenseI64ArrayAttr.get(cases) + super().__init__( + results, arg, cases, num_caseRegions=len(cases), loc=loc, ip=ip + ) + for region in self.regions: + region.blocks.append() + + if default_body_builder is not None: + with InsertionPoint(self.default_block): + default_body_builder(self) + + if case_body_builder is not None: + for i, case in enumerate(cases): + with InsertionPoint(self.case_block(i)): + case_body_builder(self, i, self.cases[i]) + + @property + def default_region(self) -> Region: + return self.regions[0] + + @property + def default_block(self) -> Block: + return self.default_region.blocks[0] + + @property + def case_regions(self) -> Sequence[Region]: + return self.regions[1:] + + def case_region(self, i: int) -> Region: + return self.case_regions[i] + + @property + def case_blocks(self) -> Sequence[Block]: + return [region.blocks[0] for region in self.case_regions] + + def case_block(self, i: int) -> Block: + return self.case_regions[i].blocks[0] + + +def index_switch( + results, + arg, + cases, + case_body_builder=None, + default_body_builder=None, + loc=None, + ip=None, +) -> Union[OpResult, OpResultList, IndexSwitchOp]: + op = IndexSwitchOp( + results=results, + arg=arg, + cases=cases, + case_body_builder=case_body_builder, + default_body_builder=default_body_builder, + loc=loc, + ip=ip, + ) + return _get_op_result_or_op_results(op) diff --git a/mlir/test/python/dialects/scf.py b/mlir/test/python/dialects/scf.py index 62d11d5e189c8..0c0c9b986562b 100644 --- a/mlir/test/python/dialects/scf.py +++ b/mlir/test/python/dialects/scf.py @@ -1,10 +1,14 @@ # RUN: %PYTHON %s | FileCheck %s from mlir.ir import * -from mlir.dialects import arith -from mlir.dialects import func -from mlir.dialects import memref -from mlir.dialects import scf +from mlir.extras import types as T +from mlir.dialects import ( + arith, + func, + memref, + scf, + cf, +) from mlir.passmanager import PassManager @@ -355,3 +359,117 @@ def simple_if_else(cond): # CHECK: scf.yield %[[TWO]], %[[THREE]] # CHECK: arith.addi %[[RET]]#0, %[[RET]]#1 # CHECK: return + + +@constructAndPrintInModule +def testIndexSwitch(): + i32 = T.i32() + + @func.FuncOp.from_py_func(T.index(), results=[i32]) + def index_switch(index): + c1 = arith.constant(i32, 1) + c0 = arith.constant(i32, 0) + value = arith.constant(i32, 5) + switch_op = scf.IndexSwitchOp([i32], index, range(3)) + + assert switch_op.regions[0] == switch_op.default_region + assert switch_op.regions[1] == switch_op.case_regions[0] + assert switch_op.regions[1] == switch_op.case_region(0) + assert len(switch_op.case_regions) == 3 + assert len(switch_op.regions) == 4 + + with InsertionPoint(switch_op.default_block): + cf.assert_(arith.constant(T.bool(), 0), "Whoops!") + scf.yield_([c1]) + + for i, block in enumerate(switch_op.case_blocks): + with InsertionPoint(block): + scf.yield_([arith.constant(i32, i)]) + + func.return_([switch_op.results[0]]) + + return index_switch + + +# CHECK-LABEL: func.func @index_switch( +# CHECK-SAME: %[[ARG0:.*]]: index) -> i32 { +# CHECK: %[[CONSTANT_0:.*]] = arith.constant 1 : i32 +# CHECK: %[[CONSTANT_1:.*]] = arith.constant 0 : i32 +# CHECK: %[[CONSTANT_2:.*]] = arith.constant 5 : i32 +# CHECK: %[[INDEX_SWITCH_0:.*]] = scf.index_switch %[[ARG0]] -> i32 +# CHECK: case 0 { +# CHECK: %[[CONSTANT_3:.*]] = arith.constant 0 : i32 +# CHECK: scf.yield %[[CONSTANT_3]] : i32 +# CHECK: } +# CHECK: case 1 { +# CHECK: %[[CONSTANT_4:.*]] = arith.constant 1 : i32 +# CHECK: scf.yield %[[CONSTANT_4]] : i32 +# CHECK: } +# CHECK: case 2 { +# CHECK: %[[CONSTANT_5:.*]] = arith.constant 2 : i32 +# CHECK: scf.yield %[[CONSTANT_5]] : i32 +# CHECK: } +# CHECK: default { +# CHECK: %[[CONSTANT_6:.*]] = arith.constant false +# CHECK: cf.assert %[[CONSTANT_6]], "Whoops!" +# CHECK: scf.yield %[[CONSTANT_0]] : i32 +# CHECK: } +# CHECK: return %[[INDEX_SWITCH_0]] : i32 +# CHECK: } + + +@constructAndPrintInModule +def testIndexSwitchWithBodyBuilders(): + i32 = T.i32() + + @func.FuncOp.from_py_func(T.index(), results=[i32]) + def index_switch(index): + c1 = arith.constant(i32, 1) + c0 = arith.constant(i32, 0) + value = arith.constant(i32, 5) + + def default_body_builder(switch_op): + cf.assert_(arith.constant(T.bool(), 0), "Whoops!") + scf.yield_([c1]) + + def case_body_builder(switch_op, case_index: int, case_value: int): + scf.yield_([arith.constant(i32, case_value)]) + + result = scf.index_switch( + results=[i32], + arg=index, + cases=range(3), + case_body_builder=case_body_builder, + default_body_builder=default_body_builder, + ) + + func.return_([result]) + + return index_switch + + +# CHECK-LABEL: func.func @index_switch( +# CHECK-SAME: %[[ARG0:.*]]: index) -> i32 { +# CHECK: %[[CONSTANT_0:.*]] = arith.constant 1 : i32 +# CHECK: %[[CONSTANT_1:.*]] = arith.constant 0 : i32 +# CHECK: %[[CONSTANT_2:.*]] = arith.constant 5 : i32 +# CHECK: %[[INDEX_SWITCH_0:.*]] = scf.index_switch %[[ARG0]] -> i32 +# CHECK: case 0 { +# CHECK: %[[CONSTANT_3:.*]] = arith.constant 0 : i32 +# CHECK: scf.yield %[[CONSTANT_3]] : i32 +# CHECK: } +# CHECK: case 1 { +# CHECK: %[[CONSTANT_4:.*]] = arith.constant 1 : i32 +# CHECK: scf.yield %[[CONSTANT_4]] : i32 +# CHECK: } +# CHECK: case 2 { +# CHECK: %[[CONSTANT_5:.*]] = arith.constant 2 : i32 +# CHECK: scf.yield %[[CONSTANT_5]] : i32 +# CHECK: } +# CHECK: default { +# CHECK: %[[CONSTANT_6:.*]] = arith.constant false +# CHECK: cf.assert %[[CONSTANT_6]], "Whoops!" +# CHECK: scf.yield %[[CONSTANT_0]] : i32 +# CHECK: } +# CHECK: return %[[INDEX_SWITCH_0]] : i32 +# CHECK: } diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py index 1bdd345d98c05..66ba5d28e49b2 100644 --- a/mlir/test/python/ir/operation.py +++ b/mlir/test/python/ir/operation.py @@ -1212,9 +1212,7 @@ def testIndexSwitch(): @func.FuncOp.from_py_func(T.index()) def index_switch(index): c1 = arith.constant(i32, 1) - switch_op = scf.IndexSwitchOp( - results_=[i32], arg=index, cases=range(3), num_caseRegions=3 - ) + switch_op = scf.IndexSwitchOp(results=[i32], arg=index, cases=range(3)) assert len(switch_op.regions) == 4 assert len(switch_op.regions[2:]) == 2