Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,6 @@ def SCFForLoopSpecialization : Pass<"scf-for-loop-specialization"> {
let constructor = "mlir::createForLoopSpecializationPass()";
}

def SCFIfConditionPropagation : Pass<"scf-if-condition-propagation"> {
let summary = "Replace usages of if condition with true/false constants in "
"the conditional regions";
let dependentDialects = ["arith::ArithDialect"];
}

def SCFParallelLoopFusion : Pass<"scf-parallel-loop-fusion"> {
let summary = "Fuse adjacent parallel loops";
let constructor = "mlir::createParallelLoopFusionPass()";
Expand Down
64 changes: 62 additions & 2 deletions mlir/lib/Dialect/SCF/IR/SCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2453,6 +2453,65 @@ struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
}
};

/// Allow the true region of an if to assume the condition is true
/// and vice versa. For example:
///
/// scf.if %cmp {
/// print(%cmp)
/// }
///
/// becomes
///
/// scf.if %cmp {
/// print(true)
/// }
///
struct ConditionPropagation : public OpRewritePattern<IfOp> {
using OpRewritePattern<IfOp>::OpRewritePattern;

LogicalResult matchAndRewrite(IfOp op,
PatternRewriter &rewriter) const override {
// Early exit if the condition is constant since replacing a constant
// in the body with another constant isn't a simplification.
if (matchPattern(op.getCondition(), m_Constant()))
return failure();

bool changed = false;
mlir::Type i1Ty = rewriter.getI1Type();

// These variables serve to prevent creating duplicate constants
// and hold constant true or false values.
Value constantTrue = nullptr;
Value constantFalse = nullptr;

for (OpOperand &use :
llvm::make_early_inc_range(op.getCondition().getUses())) {
if (op.getThenRegion().isAncestor(use.getOwner()->getParentRegion())) {
changed = true;

if (!constantTrue)
constantTrue = rewriter.create<arith::ConstantOp>(
op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1));

rewriter.modifyOpInPlace(use.getOwner(),
[&]() { use.set(constantTrue); });
} else if (op.getElseRegion().isAncestor(
use.getOwner()->getParentRegion())) {
changed = true;

if (!constantFalse)
constantFalse = rewriter.create<arith::ConstantOp>(
op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0));

rewriter.modifyOpInPlace(use.getOwner(),
[&]() { use.set(constantFalse); });
}
}

return success(changed);
}
};

/// Remove any statements from an if that are equivalent to the condition
/// or its negation. For example:
///
Expand Down Expand Up @@ -2835,8 +2894,9 @@ struct CombineNestedIfs : public OpRewritePattern<IfOp> {

void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<CombineIfs, CombineNestedIfs, ConvertTrivialIfToSelect,
RemoveEmptyElseBranch, RemoveStaticCondition, RemoveUnusedResults,
results.add<CombineIfs, CombineNestedIfs, ConditionPropagation,
ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
RemoveStaticCondition, RemoveUnusedResults,
ReplaceIfYieldWithConditionOrValue>(context);
}

Expand Down
1 change: 0 additions & 1 deletion mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ add_mlir_dialect_library(MLIRSCFTransforms
ForallToFor.cpp
ForallToParallel.cpp
ForToWhile.cpp
IfConditionPropagation.cpp
LoopCanonicalization.cpp
LoopPipelining.cpp
LoopRangeFolding.cpp
Expand Down
98 changes: 0 additions & 98 deletions mlir/lib/Dialect/SCF/Transforms/IfConditionPropagation.cpp

This file was deleted.

35 changes: 35 additions & 0 deletions mlir/test/Dialect/SCF/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -867,6 +867,41 @@ func.func @matmul_on_tensors(%t0: tensor<32x1024xf32>) -> tensor<?x?xf32> {

// -----

// CHECK-LABEL: @cond_prop
func.func @cond_prop(%arg0 : i1) -> index {
%res = scf.if %arg0 -> index {
%res1 = scf.if %arg0 -> index {
%v1 = "test.get_some_value1"() : () -> index
scf.yield %v1 : index
} else {
%v2 = "test.get_some_value2"() : () -> index
scf.yield %v2 : index
}
scf.yield %res1 : index
} else {
%res2 = scf.if %arg0 -> index {
%v3 = "test.get_some_value3"() : () -> index
scf.yield %v3 : index
} else {
%v4 = "test.get_some_value4"() : () -> index
scf.yield %v4 : index
}
scf.yield %res2 : index
}
return %res : index
}
// CHECK-NEXT: %[[if:.+]] = scf.if %arg0 -> (index) {
// CHECK-NEXT: %[[c1:.+]] = "test.get_some_value1"() : () -> index
// CHECK-NEXT: scf.yield %[[c1]] : index
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[c4:.+]] = "test.get_some_value4"() : () -> index
// CHECK-NEXT: scf.yield %[[c4]] : index
// CHECK-NEXT: }
// CHECK-NEXT: return %[[if]] : index
// CHECK-NEXT:}

// -----

// CHECK-LABEL: @replace_if_with_cond1
func.func @replace_if_with_cond1(%arg0 : i1) -> (i32, i1) {
%true = arith.constant true
Expand Down
34 changes: 0 additions & 34 deletions mlir/test/Dialect/SCF/if-cond-prop.mlir

This file was deleted.