Skip to content

Commit

Permalink
[MLIR][SCF] Simplify scf.if by swapping regions if condition is a not
Browse files Browse the repository at this point in the history
Given an if of the form, simplify it by eliminating the not and swapping the regions

scf.if not(c) {
  yield origTrue
} else {
  yield origFalse
}

becomes

scf.if c {
  yield origFalse
} else {
  yield origTrue
}

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D116990
  • Loading branch information
wsmoses committed Jan 11, 2022
1 parent 37a1291 commit 5443d2e
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 2 deletions.
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/SCF/SCFOps.td
Expand Up @@ -411,7 +411,7 @@ def IfOp : SCF_Op<"if",
void getNumRegionInvocations(ArrayRef<Attribute> operands,
SmallVectorImpl<int64_t> &countPerRegion);
}];

let hasFolder = 1;
let hasCanonicalizer = 1;
}

Expand Down
26 changes: 25 additions & 1 deletion mlir/lib/Dialect/SCF/SCF.cpp
Expand Up @@ -13,10 +13,10 @@
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Transforms/InliningUtils.h"

using namespace mlir;
using namespace mlir::scf;

Expand Down Expand Up @@ -1199,6 +1199,30 @@ void IfOp::getNumRegionInvocations(ArrayRef<Attribute> operands,
}
}

LogicalResult IfOp::fold(ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results) {
// if (!c) then A() else B() -> if c then B() else A()
if (getElseRegion().empty())
return failure();

arith::XOrIOp xorStmt = getCondition().getDefiningOp<arith::XOrIOp>();
if (!xorStmt)
return failure();

if (!matchPattern(xorStmt.getRhs(), m_One()))
return failure();

getConditionMutable().assign(xorStmt.getLhs());
Block *thenBlock = &getThenRegion().front();
// It would be nicer to use iplist::swap, but that has no implemented
// callbacks See: https://llvm.org/doxygen/ilist_8h_source.html#l00224
getThenRegion().getBlocks().splice(getThenRegion().getBlocks().begin(),
getElseRegion().getBlocks());
getElseRegion().getBlocks().splice(getElseRegion().getBlocks().begin(),
getThenRegion().getBlocks(), thenBlock);
return success();
}

namespace {
// Pattern to remove unused IfOp results.
struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
Expand Down
23 changes: 23 additions & 0 deletions mlir/test/Dialect/SCF/canonicalize.mlir
Expand Up @@ -447,6 +447,29 @@ func @merge_nested_if(%arg0: i1, %arg1: i1) {

// -----

// CHECK-LABEL: func @if_condition_swap
// CHECK-NEXT: %{{.*}} = scf.if %arg0 -> (index) {
// CHECK-NEXT: %[[i1:.+]] = "test.origFalse"() : () -> index
// CHECK-NEXT: scf.yield %[[i1]] : index
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[i2:.+]] = "test.origTrue"() : () -> index
// CHECK-NEXT: scf.yield %[[i2]] : index
// CHECK-NEXT: }
func @if_condition_swap(%cond: i1) -> index {
%true = arith.constant true
%not = arith.xori %cond, %true : i1
%0 = scf.if %not -> (index) {
%1 = "test.origTrue"() : () -> index
scf.yield %1 : index
} else {
%1 = "test.origFalse"() : () -> index
scf.yield %1 : index
}
return %0 : index
}

// -----

// CHECK-LABEL: @remove_zero_iteration_loop
func @remove_zero_iteration_loop() {
%c42 = arith.constant 42 : index
Expand Down

0 comments on commit 5443d2e

Please sign in to comment.