diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp index e025cf6d831f49..cd2b8a0b450555 100644 --- a/mlir/lib/Dialect/SCF/SCF.cpp +++ b/mlir/lib/Dialect/SCF/SCF.cpp @@ -713,8 +713,9 @@ struct ForOpIterArgsFolder : public OpRewritePattern { } }; -/// Rewriting pattern that erases loops that are known not to iterate and -/// replaces single-iteration loops with their bodies. +/// Rewriting pattern that erases loops that are known not to iterate, replaces +/// single-iteration loops with their bodies, and removes empty loops that +/// iterate at least once and only return values defined outside of the loop. struct SimplifyTrivialLoops : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -756,7 +757,19 @@ struct SimplifyTrivialLoops : public OpRewritePattern { return success(); } - return failure(); + // Now we are left with loops that have more than 1 iterations. + Block &block = op.getRegion().front(); + if (!llvm::hasSingleElement(block)) + return failure(); + // If the loop is empty, iterates at least once, and only returns values + // defined outside of the loop, remove it and replace it with yield values. + auto yieldOp = cast(block.getTerminator()); + auto yieldOperands = yieldOp.getOperands(); + if (llvm::any_of(yieldOperands, + [&](Value v) { return !op.isDefinedOutsideOfLoop(v); })) + return failure(); + rewriter.replaceOp(op, yieldOperands); + return success(); } }; diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir index 955f6bbac7b614..6edf9ff9ddff74 100644 --- a/mlir/test/Dialect/SCF/canonicalize.mlir +++ b/mlir/test/Dialect/SCF/canonicalize.mlir @@ -363,6 +363,26 @@ func @for_yields_3(%lb : index, %ub : index, %step : index) -> (i32, i32, i32) { // ----- +// Test that an empty loop which iterates at least once and only returns +// values defined outside of the loop is folded away. +func @for_yields_4() -> i32 { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %a = arith.constant 3 : i32 + %b = arith.constant 4 : i32 + %r = scf.for %i = %c0 to %c2 step %c1 iter_args(%0 = %a) -> i32 { + scf.yield %b : i32 + } + return %r : i32 +} + +// CHECK-LABEL: func @for_yields_4 +// CHECK-NEXT: %[[b:.*]] = arith.constant 4 : i32 +// CHECK-NEXT: return %[[b]] : i32 + +// ----- + // CHECK-LABEL: @replace_true_if func @replace_true_if() { %true = arith.constant true