Skip to content

Commit 5657f93

Browse files
author
Butygin
committed
[mlir] Canonicalize IfOp with trivial then and else bodies to list of SelectOp's
* Do we need a threshold on maximum number of Yeild arguments processed (maximum number of SelectOp's to be generated)? * Had to modify some old IfOp tests to not get optimized by this pattern Differential Revision: https://reviews.llvm.org/D98592
1 parent 319d093 commit 5657f93

File tree

2 files changed

+135
-1
lines changed

2 files changed

+135
-1
lines changed

mlir/lib/Dialect/SCF/SCF.cpp

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -934,11 +934,49 @@ struct RemoveStaticCondition : public OpRewritePattern<IfOp> {
934934
return success();
935935
}
936936
};
937+
938+
struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
939+
using OpRewritePattern<IfOp>::OpRewritePattern;
940+
941+
LogicalResult matchAndRewrite(IfOp op,
942+
PatternRewriter &rewriter) const override {
943+
if (op->getNumResults() == 0)
944+
return failure();
945+
946+
if (!llvm::hasSingleElement(op.thenRegion().front()) ||
947+
!llvm::hasSingleElement(op.elseRegion().front()))
948+
return failure();
949+
950+
auto cond = op.condition();
951+
auto thenYieldArgs =
952+
cast<scf::YieldOp>(op.thenRegion().front().getTerminator())
953+
.getOperands();
954+
auto elseYieldArgs =
955+
cast<scf::YieldOp>(op.elseRegion().front().getTerminator())
956+
.getOperands();
957+
SmallVector<Value> results(op->getNumResults());
958+
assert(thenYieldArgs.size() == results.size());
959+
assert(elseYieldArgs.size() == results.size());
960+
for (auto it : llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
961+
Value trueVal = std::get<0>(it.value());
962+
Value falseVal = std::get<1>(it.value());
963+
if (trueVal == falseVal)
964+
results[it.index()] = trueVal;
965+
else
966+
results[it.index()] =
967+
rewriter.create<SelectOp>(op.getLoc(), cond, trueVal, falseVal);
968+
}
969+
970+
rewriter.replaceOp(op, results);
971+
return success();
972+
}
973+
};
937974
} // namespace
938975

939976
void IfOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
940977
MLIRContext *context) {
941-
results.insert<RemoveUnusedResults, RemoveStaticCondition>(context);
978+
results.insert<RemoveUnusedResults, RemoveStaticCondition,
979+
ConvertTrivialIfToSelect>(context);
942980
}
943981

944982
//===----------------------------------------------------------------------===//

mlir/test/Dialect/SCF/canonicalize.mlir

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,12 @@ func @single_iteration(%A: memref<?x?x?xi32>) {
3535

3636
// -----
3737

38+
func private @side_effect()
3839
func @one_unused(%cond: i1) -> (index) {
3940
%c0 = constant 0 : index
4041
%c1 = constant 1 : index
4142
%0, %1 = scf.if %cond -> (index, index) {
43+
call @side_effect() : () -> ()
4244
scf.yield %c0, %c1 : index, index
4345
} else {
4446
scf.yield %c0, %c1 : index, index
@@ -49,6 +51,7 @@ func @one_unused(%cond: i1) -> (index) {
4951
// CHECK-LABEL: func @one_unused
5052
// CHECK: [[C0:%.*]] = constant 1 : index
5153
// CHECK: [[V0:%.*]] = scf.if %{{.*}} -> (index) {
54+
// CHECK: call @side_effect() : () -> ()
5255
// CHECK: scf.yield [[C0]] : index
5356
// CHECK: } else
5457
// CHECK: scf.yield [[C0]] : index
@@ -57,11 +60,13 @@ func @one_unused(%cond: i1) -> (index) {
5760

5861
// -----
5962

63+
func private @side_effect()
6064
func @nested_unused(%cond1: i1, %cond2: i1) -> (index) {
6165
%c0 = constant 0 : index
6266
%c1 = constant 1 : index
6367
%0, %1 = scf.if %cond1 -> (index, index) {
6468
%2, %3 = scf.if %cond2 -> (index, index) {
69+
call @side_effect() : () -> ()
6570
scf.yield %c0, %c1 : index, index
6671
} else {
6772
scf.yield %c0, %c1 : index, index
@@ -77,6 +82,7 @@ func @nested_unused(%cond1: i1, %cond2: i1) -> (index) {
7782
// CHECK: [[C0:%.*]] = constant 1 : index
7883
// CHECK: [[V0:%.*]] = scf.if {{.*}} -> (index) {
7984
// CHECK: [[V1:%.*]] = scf.if {{.*}} -> (index) {
85+
// CHECK: call @side_effect() : () -> ()
8086
// CHECK: scf.yield [[C0]] : index
8187
// CHECK: } else
8288
// CHECK: scf.yield [[C0]] : index
@@ -113,6 +119,96 @@ func @all_unused(%cond: i1) {
113119

114120
// -----
115121

122+
func @empty_if1(%cond: i1) {
123+
scf.if %cond {
124+
scf.yield
125+
}
126+
return
127+
}
128+
129+
// CHECK-LABEL: func @empty_if1
130+
// CHECK-NOT: scf.if
131+
// CHECK: return
132+
133+
// -----
134+
135+
func @empty_if2(%cond: i1) {
136+
scf.if %cond {
137+
scf.yield
138+
} else {
139+
scf.yield
140+
}
141+
return
142+
}
143+
144+
// CHECK-LABEL: func @empty_if2
145+
// CHECK-NOT: scf.if
146+
// CHECK: return
147+
148+
// -----
149+
150+
func @to_select1(%cond: i1) -> index {
151+
%c0 = constant 0 : index
152+
%c1 = constant 1 : index
153+
%0 = scf.if %cond -> index {
154+
scf.yield %c0 : index
155+
} else {
156+
scf.yield %c1 : index
157+
}
158+
return %0 : index
159+
}
160+
161+
// CHECK-LABEL: func @to_select1
162+
// CHECK: [[C0:%.*]] = constant 0 : index
163+
// CHECK: [[C1:%.*]] = constant 1 : index
164+
// CHECK: [[V0:%.*]] = select {{.*}}, [[C0]], [[C1]]
165+
// CHECK: return [[V0]] : index
166+
167+
// -----
168+
169+
func @to_select_same_val(%cond: i1) -> (index, index) {
170+
%c0 = constant 0 : index
171+
%c1 = constant 1 : index
172+
%0, %1 = scf.if %cond -> (index, index) {
173+
scf.yield %c0, %c1 : index, index
174+
} else {
175+
scf.yield %c1, %c1 : index, index
176+
}
177+
return %0, %1 : index, index
178+
}
179+
180+
// CHECK-LABEL: func @to_select_same_val
181+
// CHECK: [[C0:%.*]] = constant 0 : index
182+
// CHECK: [[C1:%.*]] = constant 1 : index
183+
// CHECK: [[V0:%.*]] = select {{.*}}, [[C0]], [[C1]]
184+
// CHECK: return [[V0]], [[C1]] : index, index
185+
186+
// -----
187+
188+
func @to_select2(%cond: i1) -> (index, index) {
189+
%c0 = constant 0 : index
190+
%c1 = constant 1 : index
191+
%c2 = constant 2 : index
192+
%c3 = constant 3 : index
193+
%0, %1 = scf.if %cond -> (index, index) {
194+
scf.yield %c0, %c1 : index, index
195+
} else {
196+
scf.yield %c2, %c3 : index, index
197+
}
198+
return %0, %1 : index, index
199+
}
200+
201+
// CHECK-LABEL: func @to_select2
202+
// CHECK: [[C0:%.*]] = constant 0 : index
203+
// CHECK: [[C1:%.*]] = constant 1 : index
204+
// CHECK: [[C2:%.*]] = constant 2 : index
205+
// CHECK: [[C3:%.*]] = constant 3 : index
206+
// CHECK: [[V0:%.*]] = select {{.*}}, [[C0]], [[C2]]
207+
// CHECK: [[V1:%.*]] = select {{.*}}, [[C1]], [[C3]]
208+
// CHECK: return [[V0]], [[V1]] : index
209+
210+
// -----
211+
116212
func private @make_i32() -> i32
117213

118214
func @for_yields_2(%lb : index, %ub : index, %step : index) -> i32 {

0 commit comments

Comments
 (0)