Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 75 additions & 0 deletions mlir/python/mlir/dialects/scf.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ._ods_common import (
get_op_result_or_value as _get_op_result_or_value,
get_op_results_or_values as _get_op_results_or_values,
get_op_result_or_op_results as _get_op_result_or_op_results,
_cext as _ods_cext,
)
except ImportError as e:
Expand Down Expand Up @@ -254,3 +255,77 @@ def for_(
yield iv, iter_args[0], for_op.results[0]
else:
yield iv


@_ods_cext.register_operation(_Dialect, replace=True)
class IndexSwitchOp(IndexSwitchOp):
__doc__ = IndexSwitchOp.__doc__

def __init__(
self,
results,
arg,
cases,
case_body_builder=None,
default_body_builder=None,
loc=None,
ip=None,
):
cases = DenseI64ArrayAttr.get(cases)
super().__init__(
results, arg, cases, num_caseRegions=len(cases), loc=loc, ip=ip
)
for region in self.regions:
region.blocks.append()

if default_body_builder is not None:
with InsertionPoint(self.default_block):
default_body_builder(self)

if case_body_builder is not None:
for i, case in enumerate(cases):
with InsertionPoint(self.case_block(i)):
case_body_builder(self, i, self.cases[i])

@property
def default_region(self) -> Region:
return self.regions[0]

@property
def default_block(self) -> Block:
return self.default_region.blocks[0]

@property
def case_regions(self) -> Sequence[Region]:
return self.regions[1:]

def case_region(self, i: int) -> Region:
return self.case_regions[i]

@property
def case_blocks(self) -> Sequence[Block]:
return [region.blocks[0] for region in self.case_regions]

def case_block(self, i: int) -> Block:
return self.case_regions[i].blocks[0]


def index_switch(
results,
arg,
cases,
case_body_builder=None,
default_body_builder=None,
loc=None,
ip=None,
) -> Union[OpResult, OpResultList, IndexSwitchOp]:
op = IndexSwitchOp(
results=results,
arg=arg,
cases=cases,
case_body_builder=case_body_builder,
default_body_builder=default_body_builder,
loc=loc,
ip=ip,
)
return _get_op_result_or_op_results(op)
126 changes: 122 additions & 4 deletions mlir/test/python/dialects/scf.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
from mlir.dialects import arith
from mlir.dialects import func
from mlir.dialects import memref
from mlir.dialects import scf
from mlir.extras import types as T
from mlir.dialects import (
arith,
func,
memref,
scf,
cf,
)
from mlir.passmanager import PassManager


Expand Down Expand Up @@ -355,3 +359,117 @@ def simple_if_else(cond):
# CHECK: scf.yield %[[TWO]], %[[THREE]]
# CHECK: arith.addi %[[RET]]#0, %[[RET]]#1
# CHECK: return


@constructAndPrintInModule
def testIndexSwitch():
i32 = T.i32()

@func.FuncOp.from_py_func(T.index(), results=[i32])
def index_switch(index):
c1 = arith.constant(i32, 1)
c0 = arith.constant(i32, 0)
value = arith.constant(i32, 5)
switch_op = scf.IndexSwitchOp([i32], index, range(3))

assert switch_op.regions[0] == switch_op.default_region
assert switch_op.regions[1] == switch_op.case_regions[0]
assert switch_op.regions[1] == switch_op.case_region(0)
assert len(switch_op.case_regions) == 3
assert len(switch_op.regions) == 4

with InsertionPoint(switch_op.default_block):
cf.assert_(arith.constant(T.bool(), 0), "Whoops!")
scf.yield_([c1])

for i, block in enumerate(switch_op.case_blocks):
with InsertionPoint(block):
scf.yield_([arith.constant(i32, i)])

func.return_([switch_op.results[0]])

return index_switch


# CHECK-LABEL: func.func @index_switch(
# CHECK-SAME: %[[ARG0:.*]]: index) -> i32 {
# CHECK: %[[CONSTANT_0:.*]] = arith.constant 1 : i32
# CHECK: %[[CONSTANT_1:.*]] = arith.constant 0 : i32
# CHECK: %[[CONSTANT_2:.*]] = arith.constant 5 : i32
# CHECK: %[[INDEX_SWITCH_0:.*]] = scf.index_switch %[[ARG0]] -> i32
# CHECK: case 0 {
# CHECK: %[[CONSTANT_3:.*]] = arith.constant 0 : i32
# CHECK: scf.yield %[[CONSTANT_3]] : i32
# CHECK: }
# CHECK: case 1 {
# CHECK: %[[CONSTANT_4:.*]] = arith.constant 1 : i32
# CHECK: scf.yield %[[CONSTANT_4]] : i32
# CHECK: }
# CHECK: case 2 {
# CHECK: %[[CONSTANT_5:.*]] = arith.constant 2 : i32
# CHECK: scf.yield %[[CONSTANT_5]] : i32
# CHECK: }
# CHECK: default {
# CHECK: %[[CONSTANT_6:.*]] = arith.constant false
# CHECK: cf.assert %[[CONSTANT_6]], "Whoops!"
# CHECK: scf.yield %[[CONSTANT_0]] : i32
# CHECK: }
# CHECK: return %[[INDEX_SWITCH_0]] : i32
# CHECK: }


@constructAndPrintInModule
def testIndexSwitchWithBodyBuilders():
i32 = T.i32()

@func.FuncOp.from_py_func(T.index(), results=[i32])
def index_switch(index):
c1 = arith.constant(i32, 1)
c0 = arith.constant(i32, 0)
value = arith.constant(i32, 5)

def default_body_builder(switch_op):
cf.assert_(arith.constant(T.bool(), 0), "Whoops!")
scf.yield_([c1])

def case_body_builder(switch_op, case_index: int, case_value: int):
scf.yield_([arith.constant(i32, case_value)])

result = scf.index_switch(
results=[i32],
arg=index,
cases=range(3),
case_body_builder=case_body_builder,
default_body_builder=default_body_builder,
)

func.return_([result])

return index_switch


# CHECK-LABEL: func.func @index_switch(
# CHECK-SAME: %[[ARG0:.*]]: index) -> i32 {
# CHECK: %[[CONSTANT_0:.*]] = arith.constant 1 : i32
# CHECK: %[[CONSTANT_1:.*]] = arith.constant 0 : i32
# CHECK: %[[CONSTANT_2:.*]] = arith.constant 5 : i32
# CHECK: %[[INDEX_SWITCH_0:.*]] = scf.index_switch %[[ARG0]] -> i32
# CHECK: case 0 {
# CHECK: %[[CONSTANT_3:.*]] = arith.constant 0 : i32
# CHECK: scf.yield %[[CONSTANT_3]] : i32
# CHECK: }
# CHECK: case 1 {
# CHECK: %[[CONSTANT_4:.*]] = arith.constant 1 : i32
# CHECK: scf.yield %[[CONSTANT_4]] : i32
# CHECK: }
# CHECK: case 2 {
# CHECK: %[[CONSTANT_5:.*]] = arith.constant 2 : i32
# CHECK: scf.yield %[[CONSTANT_5]] : i32
# CHECK: }
# CHECK: default {
# CHECK: %[[CONSTANT_6:.*]] = arith.constant false
# CHECK: cf.assert %[[CONSTANT_6]], "Whoops!"
# CHECK: scf.yield %[[CONSTANT_0]] : i32
# CHECK: }
# CHECK: return %[[INDEX_SWITCH_0]] : i32
# CHECK: }
4 changes: 1 addition & 3 deletions mlir/test/python/ir/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1212,9 +1212,7 @@ def testIndexSwitch():
@func.FuncOp.from_py_func(T.index())
def index_switch(index):
c1 = arith.constant(i32, 1)
switch_op = scf.IndexSwitchOp(
results_=[i32], arg=index, cases=range(3), num_caseRegions=3
)
switch_op = scf.IndexSwitchOp(results=[i32], arg=index, cases=range(3))

assert len(switch_op.regions) == 4
assert len(switch_op.regions[2:]) == 2
Expand Down
Loading