[SCF] Added canonicalizer for recursively dead uses of iter_args. #191085
[SCF] Added canonicalizer for recursively dead uses of iter_args. #191085 (vzakhari wants to merge 1 commit into
Conversation
This pattern may appear after Mem2Reg, which conservatively returns live values of the memory slots from loops. If those values are not used, we can get rid of the loops' results and corresponding iter_args. Co-authored-by: Claude Opus 4.6
|
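@llvm/pr-subscribers-mlir Author: Slava Zakharin (vzakhari). Changes: This pattern may appear after Mem2Reg, which conservatively returns live values of the memory slots from loops. If those values are not used, we can get rid of the loops' results and corresponding iter_args. Co-authored-by: Claude Opus 4.6. Full diff: https://github.com/llvm/llvm-project/pull/191085.diff (3 files affected):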
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 9f4f4dc9f58e6..db50b2c1d3314 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1002,11 +1002,153 @@ struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
return failure();
}
};
+
+/// Remove iter_arg/result pairs from scf.for when the result is unused and
+/// the corresponding iter_arg block argument is "effectively unused" --
+/// meaning it has no uses, or its only uses are as init operands for nested
+/// scf.for iter_args whose block arguments are also effectively unused.
+///
+/// This handles cases that the generic RemoveDeadRegionBranchOpSuccessorInputs
+/// pattern cannot, specifically when inner loop results are used by outer
+/// loop yields creating cross-loop use chains that appear live but are
+/// semantically dead.
+///
+/// Example:
+/// %r = scf.for %i = %lb to %ub step %s iter_args(%a = %init) -> (f32) {
+/// %inner = scf.for %j = %lb to %ub step %s
+/// iter_args(%b = %a) -> (f32) {
+/// // %b is unused in the body
+/// scf.yield %val : f32
+/// }
+/// scf.yield %inner : f32
+/// }
+/// // %r is unused
+///
+/// After canonicalization:
+/// scf.for %i = %lb to %ub step %s {
+/// scf.for %j = %lb to %ub step %s {
+/// // body without iter_args
+/// }
+/// }
+struct ForOpUnusedIterArgElimination : public OpRewritePattern<ForOp> {
+ using OpRewritePattern<ForOp>::OpRewritePattern;
+
+ /// Returns true if `blockArg` (an iter_arg of `parentFor`) is effectively
+ /// unused: it has no uses, or every use is an init operand of a nested
+ /// scf.for such that (a) the matching inner iter_arg is itself effectively
+ /// unused (checked recursively), and (b) the matching inner result is used
+ /// only by `parentFor`'s yield, at positions set in `parentCandidates`.
+ static bool isBlockArgEffectivelyUnused(Value blockArg, ForOp parentFor,
+ const BitVector &parentCandidates) {
+ if (blockArg.use_empty())
+ return true;
+
+ for (OpOperand &use : blockArg.getUses()) {
+ auto innerFor = dyn_cast<ForOp>(use.getOwner());
+ if (!innerFor)
+ return false;
+
+ unsigned opNum = use.getOperandNumber();
+ if (opNum < innerFor.getNumControlOperands())
+ return false;
+
+ unsigned innerIdx = opNum - innerFor.getNumControlOperands();
+ Value innerResult = innerFor.getResult(innerIdx);
+
+ // Every use of the inner result must be an operand of parentFor's
+ // yield at a candidate position; any other user makes this check fail.
+ for (OpOperand &resUse : innerResult.getUses()) {
+ auto yieldOp = dyn_cast<YieldOp>(resUse.getOwner());
+ if (!yieldOp || yieldOp->getParentOp() != parentFor.getOperation())
+ return false;
+ if (!parentCandidates.test(resUse.getOperandNumber()))
+ return false;
+ }
+
+ // Candidate positions for the recursive check on innerFor: innerIdx is
+ // assumed dead once the parent is rewritten, plus any already-unused result.
+ BitVector innerCandidates(innerFor.getNumResults(), false);
+ innerCandidates.set(innerIdx);
+ for (unsigned j = 0; j < innerFor.getNumResults(); ++j) {
+ if (innerFor.getResult(j).use_empty())
+ innerCandidates.set(j);
+ }
+
+ Value innerBlockArg = innerFor.getRegionIterArg(innerIdx);
+ if (!isBlockArgEffectivelyUnused(innerBlockArg, innerFor,
+ innerCandidates))
+ return false;
+ }
+ return true;
+ }
+
+ LogicalResult matchAndRewrite(ForOp forOp,
+ PatternRewriter &rewriter) const override {
+ unsigned numResults = forOp.getNumResults();
+ if (numResults == 0)
+ return failure();
+
+ // Step 1: Collect the result positions that have no uses.
+ BitVector candidates(numResults, false);
+ for (unsigned i = 0; i < numResults; ++i) {
+ if (forOp.getResult(i).use_empty())
+ candidates.set(i);
+ }
+ if (candidates.none())
+ return failure();
+
+ // Step 2: Keep candidates whose block arg has uses but is effectively
+ // unused; use-free block args are skipped and left to the generic patterns.
+ BitVector toRemove(numResults, false);
+ for (unsigned i : candidates.set_bits()) {
+ Value blockArg = forOp.getRegionIterArg(i);
+ if (blockArg.use_empty())
+ continue;
+ if (isBlockArgEffectivelyUnused(blockArg, forOp, candidates))
+ toRemove.set(i);
+ }
+ if (toRemove.none())
+ return failure();
+
+ // Step 3: Redirect remaining uses of each removed block arg to its init
+ // value, which dominates the body. NOTE(review): Step 2 validates against
+ // `candidates`, a superset of `toRemove` -- confirm this is still safe.
+ for (unsigned i : toRemove.set_bits()) {
+ Value blockArg = forOp.getRegionIterArg(i);
+ Value initVal = forOp.getInitArgs()[i];
+ rewriter.replaceAllUsesWith(blockArg, initVal);
+ }
+
+ // Step 4: Erase the yield operands at the removed positions.
+ auto yieldOp = cast<YieldOp>(forOp.getBody()->getTerminator());
+ rewriter.modifyOpInPlace(yieldOp,
+ [&]() { yieldOp->eraseOperands(toRemove); });
+
+ // Step 5: Erase the matching block arguments (offset by the induction vars).
+ BitVector blockArgsToErase(forOp.getBody()->getNumArguments(), false);
+ for (unsigned i : toRemove.set_bits())
+ blockArgsToErase.set(i + forOp.getNumInductionVars());
+ rewriter.modifyOpInPlace(
+ forOp, [&]() { forOp.getBody()->eraseArguments(blockArgsToErase); });
+
+ // Step 6: Erase the matching init operands (offset by the control operands).
+ BitVector initOperandsToErase(forOp->getNumOperands(), false);
+ for (unsigned i : toRemove.set_bits())
+ initOperandsToErase.set(i + forOp.getNumControlOperands());
+ rewriter.modifyOpInPlace(
+ forOp, [&]() { forOp->eraseOperands(initOperandsToErase); });
+
+ // Step 7: Erase the loop results at the removed positions.
+ rewriter.eraseOpResults(forOp, toRemove);
+
+ return success();
+ }
+};
} // namespace
void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ForOpTensorCastFolder>(context);
+ results.add<ForOpTensorCastFolder, ForOpUnusedIterArgElimination>(context);
populateRegionBranchOpInterfaceCanonicalizationPatterns(
results, ForOp::getOperationName());
populateRegionBranchOpInterfaceInliningPattern(
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index c324d34942bf8..145a3417eec04 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -2360,3 +2360,151 @@ func.func @fold_tensor_cast_into_forall_non_sequential_writes(
// %0#0 contains %arg1 data; %0#1 contains %arg0 data.
return %0#0, %0#1 : tensor<?x32xf32>, tensor<?x32xf32>
}
+
+// -----
+
+// Test: nested for loops with effectively-dead iter_args.
+// The outer result is unused, the outer block arg feeds the inner iter_arg
+// whose block arg is also unused. Both iter_args should be removed.
+
+// CHECK-LABEL: func @nested_for_effectively_dead_iter_args
+// CHECK-SAME: (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index, %[[MEM:.*]]: memref<f32>)
+// CHECK: %[[CST:.*]] = arith.constant
+// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
+// CHECK: scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
+// CHECK: memref.store %[[CST]], %[[MEM]][]
+// CHECK: }
+// CHECK: }
+// CHECK: return
+func.func @nested_for_effectively_dead_iter_args(
+ %lb: index, %ub: index, %step: index, %mem: memref<f32>) {
+ %cst = arith.constant 1.0 : f32
+ %init = arith.constant 0.0 : f32
+ %outer = scf.for %i = %lb to %ub step %step iter_args(%a = %init) -> (f32) {
+ %inner = scf.for %j = %lb to %ub step %step iter_args(%b = %a) -> (f32) {
+ memref.store %cst, %mem[] : memref<f32>
+ scf.yield %cst : f32
+ }
+ scf.yield %inner : f32
+ }
+ return
+}
+
+// -----
+
+// Test: three levels of nesting with effectively-dead iter_args.
+
+// CHECK-LABEL: func @triple_nested_effectively_dead_iter_args
+// CHECK: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
+// CHECK: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
+// CHECK: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
+// CHECK: memref.store
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: return
+func.func @triple_nested_effectively_dead_iter_args(
+ %lb: index, %ub: index, %step: index, %mem: memref<f32>) {
+ %cst = arith.constant 1.0 : f32
+ %init = arith.constant 0.0 : f32
+ %outer = scf.for %i = %lb to %ub step %step iter_args(%a = %init) -> (f32) {
+ %mid = scf.for %j = %lb to %ub step %step iter_args(%b = %a) -> (f32) {
+ %inner = scf.for %k = %lb to %ub step %step iter_args(%c = %b) -> (f32) {
+ memref.store %cst, %mem[] : memref<f32>
+ scf.yield %cst : f32
+ }
+ scf.yield %inner : f32
+ }
+ scf.yield %mid : f32
+ }
+ return
+}
+
+// -----
+
+// Negative test: block arg is used in the loop body.
+// The iter_arg must NOT be removed even though the result is unused.
+
+// CHECK-LABEL: func @iter_arg_used_in_body
+// CHECK: scf.for {{.*}} iter_args
+// CHECK: memref.store
+// CHECK: scf.yield
+// CHECK: }
+func.func @iter_arg_used_in_body(
+ %lb: index, %ub: index, %step: index, %mem: memref<f32>) {
+ %init = arith.constant 0.0 : f32
+ %r = scf.for %i = %lb to %ub step %step iter_args(%a = %init) -> (f32) {
+ memref.store %a, %mem[] : memref<f32>
+ %next = arith.addf %a, %a : f32
+ scf.yield %next : f32
+ }
+ return
+}
+
+// -----
+
+// Test: two chained for loops where the second loop's result is unused and
+// its iter_arg chain through an inner loop is effectively dead.
+
+// CHECK-LABEL: func @chained_for_effectively_dead
+// CHECK-SAME: (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index, %[[MEM:.*]]: memref<f32>)
+// CHECK: %[[CST:.*]] = arith.constant
+// CHECK: scf.for %{{.*}} = %[[LB]] to %{{.*}} step %[[STEP]] {
+// CHECK: scf.for %{{.*}} = %[[LB]] to %{{.*}} step %[[STEP]] {
+// CHECK: memref.store
+// CHECK: }
+// CHECK: }
+// CHECK: scf.for %{{.*}} = %[[LB]] to %{{.*}} step %[[STEP]] {
+// CHECK: scf.for %{{.*}} = %[[LB]] to %{{.*}} step %[[STEP]] {
+// CHECK: memref.store
+// CHECK: }
+// CHECK: }
+// CHECK: return
+func.func @chained_for_effectively_dead(
+ %lb: index, %ub: index, %step: index, %mem: memref<f32>) {
+ %cst = arith.constant 1.0 : f32
+ %init = arith.constant 0.0 : f32
+ %first = scf.for %i = %lb to %ub step %step iter_args(%a = %init) -> (f32) {
+ %inner1 = scf.for %j = %lb to %ub step %step iter_args(%b = %a) -> (f32) {
+ memref.store %cst, %mem[] : memref<f32>
+ scf.yield %cst : f32
+ }
+ scf.yield %inner1 : f32
+ }
+ %second = scf.for %i = %lb to %ub step %step iter_args(%a = %first) -> (f32) {
+ %inner2 = scf.for %j = %lb to %ub step %step iter_args(%b = %a) -> (f32) {
+ memref.store %cst, %mem[] : memref<f32>
+ scf.yield %cst : f32
+ }
+ scf.yield %inner2 : f32
+ }
+ return
+}
+
+// -----
+
+// Test: 2-level loop nest with two iter_args, both effectively unused.
+
+// CHECK-LABEL: func @nested_for_two_iter_args
+// CHECK-SAME: (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index, %[[MEM:.*]]: memref<f32>)
+// CHECK: scf.for %{{.*}} = %[[LB]] to %[[UB]] step %[[STEP]] {
+// CHECK: scf.for %{{.*}} = %[[LB]] to %[[UB]] step %[[STEP]] {
+// CHECK: memref.store
+// CHECK: }
+// CHECK: }
+// CHECK: return
+func.func @nested_for_two_iter_args(
+ %lb: index, %ub: index, %step: index, %mem: memref<f32>) {
+ %c0 = arith.constant 0.0 : f32
+ %c1 = arith.constant 1.0 : f32
+ %r:2 = scf.for %i = %lb to %ub step %step
+ iter_args(%a = %c0, %b = %c1) -> (f32, f32) {
+ %inner:2 = scf.for %j = %lb to %ub step %step
+ iter_args(%x = %a, %y = %b) -> (f32, f32) {
+ memref.store %c1, %mem[] : memref<f32>
+ scf.yield %c1, %c0 : f32, f32
+ }
+ scf.yield %inner#0, %inner#1 : f32, f32
+ }
+ return
+}
diff --git a/mlir/test/Transforms/mem2reg-with-canonicalization.mlir b/mlir/test/Transforms/mem2reg-with-canonicalization.mlir
new file mode 100644
index 0000000000000..44e315cf620d3
--- /dev/null
+++ b/mlir/test/Transforms/mem2reg-with-canonicalization.mlir
@@ -0,0 +1,62 @@
+// RUN: mlir-opt %s --pass-pipeline='builtin.module(any(mem2reg))' | FileCheck %s --check-prefix=MEM2REG
+// RUN: mlir-opt %s --pass-pipeline='builtin.module(any(mem2reg,canonicalize))' | FileCheck %s --check-prefix=CANON
+
+// Two loop nests share the same alloca, causing the first loop's result
+// to chain into the second loop's init -- demonstrating the cross-loop
+// use chain that the generic canonicalization patterns cannot handle.
+
+// MEM2REG-LABEL: func.func @redundant_iter_args
+// MEM2REG: %[[POISON:.*]] = ub.poison : f32
+// MEM2REG: %[[CST:.*]] = arith.constant 1.000000e+00 : f32
+// MEM2REG: %[[R1:.*]] = scf.for {{.*}} iter_args(%{{.*}} = %[[POISON]]) -> (f32) {
+// MEM2REG: %[[R1I:.*]] = scf.for {{.*}} iter_args(%{{.*}} = %{{.*}}) -> (f32) {
+// MEM2REG: memref.store %[[CST]],
+// MEM2REG: scf.yield %[[CST]] : f32
+// MEM2REG: }
+// MEM2REG: scf.yield %[[R1I]] : f32
+// MEM2REG: }
+// MEM2REG: scf.for {{.*}} iter_args(%{{.*}} = %[[R1]]) -> (f32) {
+// MEM2REG: scf.for {{.*}} iter_args(%{{.*}} = %{{.*}}) -> (f32) {
+// MEM2REG: memref.store %[[CST]],
+// MEM2REG: scf.yield %[[CST]] : f32
+// MEM2REG: }
+// MEM2REG: }
+
+// CANON-LABEL: func.func @redundant_iter_args
+// CANON-SAME: (%[[N:.*]]: index, %[[MEM:.*]]: memref<f32>)
+// CANON: %[[CST:.*]] = arith.constant 1.000000e+00 : f32
+// CANON: scf.for %{{.*}} = %{{.*}} to %[[N]] step %{{.*}} {
+// CANON: scf.for %{{.*}} = %{{.*}} to %[[N]] step %{{.*}} {
+// CANON-NOT: iter_args
+// CANON: memref.store %[[CST]], %[[MEM]][] : memref<f32>
+// CANON: }
+// CANON: }
+// CANON: scf.for %{{.*}} = %{{.*}} to %[[N]] step %{{.*}} {
+// CANON: scf.for %{{.*}} = %{{.*}} to %[[N]] step %{{.*}} {
+// CANON-NOT: iter_args
+// CANON: memref.store %[[CST]], %[[MEM]][] : memref<f32>
+// CANON: }
+// CANON: }
+// CANON: return
+
+func.func @redundant_iter_args(%n: index, %mem: memref<f32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %cst = arith.constant 1.0 : f32
+ %tmp = memref.alloca() : memref<f32>
+ scf.for %i = %c0 to %n step %c1 {
+ scf.for %j = %c0 to %n step %c1 {
+ memref.store %cst, %tmp[] : memref<f32>
+ %v = memref.load %tmp[] : memref<f32>
+ memref.store %v, %mem[] : memref<f32>
+ }
+ }
+ scf.for %i = %c0 to %n step %c1 {
+ scf.for %j = %c0 to %n step %c1 {
+ memref.store %cst, %tmp[] : memref<f32>
+ %v = memref.load %tmp[] : memref<f32>
+ memref.store %v, %mem[] : memref<f32>
+ }
+ }
+ return
+}
|
|
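@llvm/pr-subscribers-mlir-scf Author: Slava Zakharin (vzakhari). Changes: This pattern may appear after Mem2Reg, which conservatively returns live values of the memory slots from loops. If those values are not used, we can get rid of the loops' results and corresponding iter_args. Co-authored-by: Claude Opus 4.6. Full diff: https://github.com/llvm/llvm-project/pull/191085.diff (3 files affected):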
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 9f4f4dc9f58e6..db50b2c1d3314 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1002,11 +1002,153 @@ struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
return failure();
}
};
+
+/// Remove iter_arg/result pairs from scf.for when the result is unused and
+/// the corresponding iter_arg block argument is "effectively unused" --
+/// meaning it has no uses, or its only uses are as init operands for nested
+/// scf.for iter_args whose block arguments are also effectively unused.
+///
+/// This handles cases that the generic RemoveDeadRegionBranchOpSuccessorInputs
+/// pattern cannot, specifically when inner loop results are used by outer
+/// loop yields creating cross-loop use chains that appear live but are
+/// semantically dead.
+///
+/// Example:
+/// %r = scf.for %i = %lb to %ub step %s iter_args(%a = %init) -> (f32) {
+/// %inner = scf.for %j = %lb to %ub step %s
+/// iter_args(%b = %a) -> (f32) {
+/// // %b is unused in the body
+/// scf.yield %val : f32
+/// }
+/// scf.yield %inner : f32
+/// }
+/// // %r is unused
+///
+/// After canonicalization:
+/// scf.for %i = %lb to %ub step %s {
+/// scf.for %j = %lb to %ub step %s {
+/// // body without iter_args
+/// }
+/// }
+struct ForOpUnusedIterArgElimination : public OpRewritePattern<ForOp> {
+ using OpRewritePattern<ForOp>::OpRewritePattern;
+
+ /// Returns true if `blockArg` (an iter_arg of `parentFor`) is effectively
+ /// unused: it has no uses, or every use is an init operand of a nested
+ /// scf.for such that (a) the matching inner iter_arg is itself effectively
+ /// unused (checked recursively), and (b) the matching inner result is used
+ /// only by `parentFor`'s yield, at positions set in `parentCandidates`.
+ static bool isBlockArgEffectivelyUnused(Value blockArg, ForOp parentFor,
+ const BitVector &parentCandidates) {
+ if (blockArg.use_empty())
+ return true;
+
+ for (OpOperand &use : blockArg.getUses()) {
+ auto innerFor = dyn_cast<ForOp>(use.getOwner());
+ if (!innerFor)
+ return false;
+
+ unsigned opNum = use.getOperandNumber();
+ if (opNum < innerFor.getNumControlOperands())
+ return false;
+
+ unsigned innerIdx = opNum - innerFor.getNumControlOperands();
+ Value innerResult = innerFor.getResult(innerIdx);
+
+ // Every use of the inner result must be an operand of parentFor's
+ // yield at a candidate position; any other user makes this check fail.
+ for (OpOperand &resUse : innerResult.getUses()) {
+ auto yieldOp = dyn_cast<YieldOp>(resUse.getOwner());
+ if (!yieldOp || yieldOp->getParentOp() != parentFor.getOperation())
+ return false;
+ if (!parentCandidates.test(resUse.getOperandNumber()))
+ return false;
+ }
+
+ // Candidate positions for the recursive check on innerFor: innerIdx is
+ // assumed dead once the parent is rewritten, plus any already-unused result.
+ BitVector innerCandidates(innerFor.getNumResults(), false);
+ innerCandidates.set(innerIdx);
+ for (unsigned j = 0; j < innerFor.getNumResults(); ++j) {
+ if (innerFor.getResult(j).use_empty())
+ innerCandidates.set(j);
+ }
+
+ Value innerBlockArg = innerFor.getRegionIterArg(innerIdx);
+ if (!isBlockArgEffectivelyUnused(innerBlockArg, innerFor,
+ innerCandidates))
+ return false;
+ }
+ return true;
+ }
+
+ LogicalResult matchAndRewrite(ForOp forOp,
+ PatternRewriter &rewriter) const override {
+ unsigned numResults = forOp.getNumResults();
+ if (numResults == 0)
+ return failure();
+
+ // Step 1: Collect the result positions that have no uses.
+ BitVector candidates(numResults, false);
+ for (unsigned i = 0; i < numResults; ++i) {
+ if (forOp.getResult(i).use_empty())
+ candidates.set(i);
+ }
+ if (candidates.none())
+ return failure();
+
+ // Step 2: Keep candidates whose block arg has uses but is effectively
+ // unused; use-free block args are skipped and left to the generic patterns.
+ BitVector toRemove(numResults, false);
+ for (unsigned i : candidates.set_bits()) {
+ Value blockArg = forOp.getRegionIterArg(i);
+ if (blockArg.use_empty())
+ continue;
+ if (isBlockArgEffectivelyUnused(blockArg, forOp, candidates))
+ toRemove.set(i);
+ }
+ if (toRemove.none())
+ return failure();
+
+ // Step 3: Redirect remaining uses of each removed block arg to its init
+ // value, which dominates the body. NOTE(review): Step 2 validates against
+ // `candidates`, a superset of `toRemove` -- confirm this is still safe.
+ for (unsigned i : toRemove.set_bits()) {
+ Value blockArg = forOp.getRegionIterArg(i);
+ Value initVal = forOp.getInitArgs()[i];
+ rewriter.replaceAllUsesWith(blockArg, initVal);
+ }
+
+ // Step 4: Erase the yield operands at the removed positions.
+ auto yieldOp = cast<YieldOp>(forOp.getBody()->getTerminator());
+ rewriter.modifyOpInPlace(yieldOp,
+ [&]() { yieldOp->eraseOperands(toRemove); });
+
+ // Step 5: Erase the matching block arguments (offset by the induction vars).
+ BitVector blockArgsToErase(forOp.getBody()->getNumArguments(), false);
+ for (unsigned i : toRemove.set_bits())
+ blockArgsToErase.set(i + forOp.getNumInductionVars());
+ rewriter.modifyOpInPlace(
+ forOp, [&]() { forOp.getBody()->eraseArguments(blockArgsToErase); });
+
+ // Step 6: Erase the matching init operands (offset by the control operands).
+ BitVector initOperandsToErase(forOp->getNumOperands(), false);
+ for (unsigned i : toRemove.set_bits())
+ initOperandsToErase.set(i + forOp.getNumControlOperands());
+ rewriter.modifyOpInPlace(
+ forOp, [&]() { forOp->eraseOperands(initOperandsToErase); });
+
+ // Step 7: Erase the loop results at the removed positions.
+ rewriter.eraseOpResults(forOp, toRemove);
+
+ return success();
+ }
+};
} // namespace
void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ForOpTensorCastFolder>(context);
+ results.add<ForOpTensorCastFolder, ForOpUnusedIterArgElimination>(context);
populateRegionBranchOpInterfaceCanonicalizationPatterns(
results, ForOp::getOperationName());
populateRegionBranchOpInterfaceInliningPattern(
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index c324d34942bf8..145a3417eec04 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -2360,3 +2360,151 @@ func.func @fold_tensor_cast_into_forall_non_sequential_writes(
// %0#0 contains %arg1 data; %0#1 contains %arg0 data.
return %0#0, %0#1 : tensor<?x32xf32>, tensor<?x32xf32>
}
+
+// -----
+
+// Test: nested for loops with effectively-dead iter_args.
+// The outer result is unused, the outer block arg feeds the inner iter_arg
+// whose block arg is also unused. Both iter_args should be removed.
+
+// CHECK-LABEL: func @nested_for_effectively_dead_iter_args
+// CHECK-SAME: (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index, %[[MEM:.*]]: memref<f32>)
+// CHECK: %[[CST:.*]] = arith.constant
+// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
+// CHECK: scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
+// CHECK: memref.store %[[CST]], %[[MEM]][]
+// CHECK: }
+// CHECK: }
+// CHECK: return
+func.func @nested_for_effectively_dead_iter_args(
+ %lb: index, %ub: index, %step: index, %mem: memref<f32>) {
+ %cst = arith.constant 1.0 : f32
+ %init = arith.constant 0.0 : f32
+ %outer = scf.for %i = %lb to %ub step %step iter_args(%a = %init) -> (f32) {
+ %inner = scf.for %j = %lb to %ub step %step iter_args(%b = %a) -> (f32) {
+ memref.store %cst, %mem[] : memref<f32>
+ scf.yield %cst : f32
+ }
+ scf.yield %inner : f32
+ }
+ return
+}
+
+// -----
+
+// Test: three levels of nesting with effectively-dead iter_args.
+
+// CHECK-LABEL: func @triple_nested_effectively_dead_iter_args
+// CHECK: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
+// CHECK: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
+// CHECK: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
+// CHECK: memref.store
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: return
+func.func @triple_nested_effectively_dead_iter_args(
+ %lb: index, %ub: index, %step: index, %mem: memref<f32>) {
+ %cst = arith.constant 1.0 : f32
+ %init = arith.constant 0.0 : f32
+ %outer = scf.for %i = %lb to %ub step %step iter_args(%a = %init) -> (f32) {
+ %mid = scf.for %j = %lb to %ub step %step iter_args(%b = %a) -> (f32) {
+ %inner = scf.for %k = %lb to %ub step %step iter_args(%c = %b) -> (f32) {
+ memref.store %cst, %mem[] : memref<f32>
+ scf.yield %cst : f32
+ }
+ scf.yield %inner : f32
+ }
+ scf.yield %mid : f32
+ }
+ return
+}
+
+// -----
+
+// Negative test: block arg is used in the loop body.
+// The iter_arg must NOT be removed even though the result is unused.
+
+// CHECK-LABEL: func @iter_arg_used_in_body
+// CHECK: scf.for {{.*}} iter_args
+// CHECK: memref.store
+// CHECK: scf.yield
+// CHECK: }
+func.func @iter_arg_used_in_body(
+ %lb: index, %ub: index, %step: index, %mem: memref<f32>) {
+ %init = arith.constant 0.0 : f32
+ %r = scf.for %i = %lb to %ub step %step iter_args(%a = %init) -> (f32) {
+ memref.store %a, %mem[] : memref<f32>
+ %next = arith.addf %a, %a : f32
+ scf.yield %next : f32
+ }
+ return
+}
+
+// -----
+
+// Test: two chained for loops where the second loop's result is unused and
+// its iter_arg chain through an inner loop is effectively dead.
+
+// CHECK-LABEL: func @chained_for_effectively_dead
+// CHECK-SAME: (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index, %[[MEM:.*]]: memref<f32>)
+// CHECK: %[[CST:.*]] = arith.constant
+// CHECK: scf.for %{{.*}} = %[[LB]] to %{{.*}} step %[[STEP]] {
+// CHECK: scf.for %{{.*}} = %[[LB]] to %{{.*}} step %[[STEP]] {
+// CHECK: memref.store
+// CHECK: }
+// CHECK: }
+// CHECK: scf.for %{{.*}} = %[[LB]] to %{{.*}} step %[[STEP]] {
+// CHECK: scf.for %{{.*}} = %[[LB]] to %{{.*}} step %[[STEP]] {
+// CHECK: memref.store
+// CHECK: }
+// CHECK: }
+// CHECK: return
+func.func @chained_for_effectively_dead(
+ %lb: index, %ub: index, %step: index, %mem: memref<f32>) {
+ %cst = arith.constant 1.0 : f32
+ %init = arith.constant 0.0 : f32
+ %first = scf.for %i = %lb to %ub step %step iter_args(%a = %init) -> (f32) {
+ %inner1 = scf.for %j = %lb to %ub step %step iter_args(%b = %a) -> (f32) {
+ memref.store %cst, %mem[] : memref<f32>
+ scf.yield %cst : f32
+ }
+ scf.yield %inner1 : f32
+ }
+ %second = scf.for %i = %lb to %ub step %step iter_args(%a = %first) -> (f32) {
+ %inner2 = scf.for %j = %lb to %ub step %step iter_args(%b = %a) -> (f32) {
+ memref.store %cst, %mem[] : memref<f32>
+ scf.yield %cst : f32
+ }
+ scf.yield %inner2 : f32
+ }
+ return
+}
+
+// -----
+
+// Test: 2-level loop nest with two iter_args, both effectively unused.
+
+// CHECK-LABEL: func @nested_for_two_iter_args
+// CHECK-SAME: (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index, %[[MEM:.*]]: memref<f32>)
+// CHECK: scf.for %{{.*}} = %[[LB]] to %[[UB]] step %[[STEP]] {
+// CHECK: scf.for %{{.*}} = %[[LB]] to %[[UB]] step %[[STEP]] {
+// CHECK: memref.store
+// CHECK: }
+// CHECK: }
+// CHECK: return
+func.func @nested_for_two_iter_args(
+ %lb: index, %ub: index, %step: index, %mem: memref<f32>) {
+ %c0 = arith.constant 0.0 : f32
+ %c1 = arith.constant 1.0 : f32
+ %r:2 = scf.for %i = %lb to %ub step %step
+ iter_args(%a = %c0, %b = %c1) -> (f32, f32) {
+ %inner:2 = scf.for %j = %lb to %ub step %step
+ iter_args(%x = %a, %y = %b) -> (f32, f32) {
+ memref.store %c1, %mem[] : memref<f32>
+ scf.yield %c1, %c0 : f32, f32
+ }
+ scf.yield %inner#0, %inner#1 : f32, f32
+ }
+ return
+}
diff --git a/mlir/test/Transforms/mem2reg-with-canonicalization.mlir b/mlir/test/Transforms/mem2reg-with-canonicalization.mlir
new file mode 100644
index 0000000000000..44e315cf620d3
--- /dev/null
+++ b/mlir/test/Transforms/mem2reg-with-canonicalization.mlir
@@ -0,0 +1,62 @@
+// RUN: mlir-opt %s --pass-pipeline='builtin.module(any(mem2reg))' | FileCheck %s --check-prefix=MEM2REG
+// RUN: mlir-opt %s --pass-pipeline='builtin.module(any(mem2reg,canonicalize))' | FileCheck %s --check-prefix=CANON
+
+// Two loop nests share the same alloca, causing the first loop's result
+// to chain into the second loop's init -- demonstrating the cross-loop
+// use chain that the generic canonicalization patterns cannot handle.
+
+// MEM2REG-LABEL: func.func @redundant_iter_args
+// MEM2REG: %[[POISON:.*]] = ub.poison : f32
+// MEM2REG: %[[CST:.*]] = arith.constant 1.000000e+00 : f32
+// MEM2REG: %[[R1:.*]] = scf.for {{.*}} iter_args(%{{.*}} = %[[POISON]]) -> (f32) {
+// MEM2REG: %[[R1I:.*]] = scf.for {{.*}} iter_args(%{{.*}} = %{{.*}}) -> (f32) {
+// MEM2REG: memref.store %[[CST]],
+// MEM2REG: scf.yield %[[CST]] : f32
+// MEM2REG: }
+// MEM2REG: scf.yield %[[R1I]] : f32
+// MEM2REG: }
+// MEM2REG: scf.for {{.*}} iter_args(%{{.*}} = %[[R1]]) -> (f32) {
+// MEM2REG: scf.for {{.*}} iter_args(%{{.*}} = %{{.*}}) -> (f32) {
+// MEM2REG: memref.store %[[CST]],
+// MEM2REG: scf.yield %[[CST]] : f32
+// MEM2REG: }
+// MEM2REG: }
+
+// CANON-LABEL: func.func @redundant_iter_args
+// CANON-SAME: (%[[N:.*]]: index, %[[MEM:.*]]: memref<f32>)
+// CANON: %[[CST:.*]] = arith.constant 1.000000e+00 : f32
+// CANON: scf.for %{{.*}} = %{{.*}} to %[[N]] step %{{.*}} {
+// CANON: scf.for %{{.*}} = %{{.*}} to %[[N]] step %{{.*}} {
+// CANON-NOT: iter_args
+// CANON: memref.store %[[CST]], %[[MEM]][] : memref<f32>
+// CANON: }
+// CANON: }
+// CANON: scf.for %{{.*}} = %{{.*}} to %[[N]] step %{{.*}} {
+// CANON: scf.for %{{.*}} = %{{.*}} to %[[N]] step %{{.*}} {
+// CANON-NOT: iter_args
+// CANON: memref.store %[[CST]], %[[MEM]][] : memref<f32>
+// CANON: }
+// CANON: }
+// CANON: return
+
+func.func @redundant_iter_args(%n: index, %mem: memref<f32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %cst = arith.constant 1.0 : f32
+ %tmp = memref.alloca() : memref<f32>
+ scf.for %i = %c0 to %n step %c1 {
+ scf.for %j = %c0 to %n step %c1 {
+ memref.store %cst, %tmp[] : memref<f32>
+ %v = memref.load %tmp[] : memref<f32>
+ memref.store %v, %mem[] : memref<f32>
+ }
+ }
+ scf.for %i = %c0 to %n step %c1 {
+ scf.for %j = %c0 to %n step %c1 {
+ memref.store %cst, %tmp[] : memref<f32>
+ %v = memref.load %tmp[] : memref<f32>
+ memref.store %v, %mem[] : memref<f32>
+ }
+ }
+ return
+}
|
| return true; | ||
|
|
||
| for (OpOperand &use : blockArg.getUses()) { | ||
| auto innerFor = dyn_cast<ForOp>(use.getOwner()); |
There was a problem hiding this comment.
That seems completely ad-hoc to me: we should rely on inside-out processing to simplify first the inner ops and then the outer ones. Hardcoding ForOp makes this all not composable.
joker-eph
left a comment
There was a problem hiding this comment.
I believe the intent is that populateRegionBranchOpInterfaceCanonicalizationPatterns should handle this. We need to investigate why it isn't the case.
I actually had a look: the issue is that we can only remove trivially dead iter_args, not ones that are transitively dead. Right now this kind of IR simplification is in scope for the |
|
I just tried |
|
That said this makes me think there is an opportunity for a "isRecursivelyTriviallyDead()" utility to check on whether a value is actually transitively used by a side-effecting operation, and if not we could replace it with a poison or something like that. |
|
Thank you both for the review and the insights! I did look into the RegionBranchOpInterface canonicalizers and they did not handle it. I missed the fact that |
|
This is a redundant change. |
This pattern may appear after Mem2Reg, which conservatively
returns live values of the memory slots from loops.
If those values are not used, we can get rid of the loops'
results and corresponding iter_args.
Co-authored-by: Claude Opus 4.6