Skip to content

[SCF] Added canonicalizer for recursively dead uses of iter_args.#191085

Closed
vzakhari wants to merge 1 commit into
llvm:mainfrom
vzakhari:scf_canon_dead_iter_args
Closed

[SCF] Added canonicalizer for recursively dead uses of iter_args.#191085
vzakhari wants to merge 1 commit into
llvm:mainfrom
vzakhari:scf_canon_dead_iter_args

Conversation

@vzakhari
Copy link
Copy Markdown
Contributor

@vzakhari vzakhari commented Apr 9, 2026

This pattern may appear after Mem2Reg, which conservatively
returns live values of the memory slots from loops.
If those values are not used, we can get rid of the loops'
results and corresponding iter_args.

Co-authored-by: Claude Opus 4.6

This pattern may appear after Mem2Reg, which conservatively
returns live values of the memory slots from loops.
If those values are not used, we can get rid of the loops'
results and corresponding iter_args.

Co-authored-by: Claude Opus 4.6
@llvmbot
Copy link
Copy Markdown
Member

llvmbot commented Apr 9, 2026

@llvm/pr-subscribers-mlir

Author: Slava Zakharin (vzakhari)

Changes

This pattern may appear after Mem2Reg, which conservatively
returns live values of the memory slots from loops.
If those values are not used, we can get rid of the loops'
results and corresponding iter_args.

Co-authored-by: Claude Opus 4.6


Full diff: https://github.com/llvm/llvm-project/pull/191085.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/SCF/IR/SCF.cpp (+143-1)
  • (modified) mlir/test/Dialect/SCF/canonicalize.mlir (+148)
  • (added) mlir/test/Transforms/mem2reg-with-canonicalization.mlir (+62)
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 9f4f4dc9f58e6..db50b2c1d3314 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1002,11 +1002,153 @@ struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
     return failure();
   }
 };
+
+/// Remove iter_arg/result pairs from scf.for when the result is unused and
+/// the corresponding iter_arg block argument is "effectively unused" --
+/// meaning it has no uses, or its only uses are as init operands for nested
+/// scf.for iter_args whose block arguments are also effectively unused.
+///
+/// This handles cases that the generic RemoveDeadRegionBranchOpSuccessorInputs
+/// pattern cannot, specifically when inner loop results are used by outer
+/// loop yields creating cross-loop use chains that appear live but are
+/// semantically dead.
+///
+/// Example:
+///   %r = scf.for %i = %lb to %ub step %s iter_args(%a = %init) -> (f32) {
+///     %inner = scf.for %j = %lb to %ub step %s
+///         iter_args(%b = %a) -> (f32) {
+///       // %b is unused in the body
+///       scf.yield %val : f32
+///     }
+///     scf.yield %inner : f32
+///   }
+///   // %r is unused
+///
+/// After canonicalization:
+///   scf.for %i = %lb to %ub step %s {
+///     scf.for %j = %lb to %ub step %s {
+///       // body without iter_args
+///     }
+///   }
+struct ForOpUnusedIterArgElimination : public OpRewritePattern<ForOp> {
+  using OpRewritePattern<ForOp>::OpRewritePattern;
+
+  /// Check if a block argument is effectively unused. A block argument is
+  /// effectively unused if it has no uses, or all its uses are init operands
+  /// for nested scf.for iter_args where: (a) the inner block arg is also
+  /// effectively unused, and (b) the inner for's result at that position is
+  /// only used as yield operands at positions we are removing from parentFor.
+  static bool isBlockArgEffectivelyUnused(Value blockArg, ForOp parentFor,
+                                          const BitVector &parentCandidates) {
+    if (blockArg.use_empty())
+      return true;
+
+    for (OpOperand &use : blockArg.getUses()) {
+      auto innerFor = dyn_cast<ForOp>(use.getOwner());
+      if (!innerFor)
+        return false;
+
+      unsigned opNum = use.getOperandNumber();
+      if (opNum < innerFor.getNumControlOperands())
+        return false;
+
+      unsigned innerIdx = opNum - innerFor.getNumControlOperands();
+      Value innerResult = innerFor.getResult(innerIdx);
+
+      // Inner result must only be used at yield positions of parentFor
+      // that are candidates for removal.
+      for (OpOperand &resUse : innerResult.getUses()) {
+        auto yieldOp = dyn_cast<YieldOp>(resUse.getOwner());
+        if (!yieldOp || yieldOp->getParentOp() != parentFor.getOperation())
+          return false;
+        if (!parentCandidates.test(resUse.getOperandNumber()))
+          return false;
+      }
+
+      // Build candidate positions for the inner for: innerIdx is a candidate
+      // because its result will become unused after the parent transformation.
+      BitVector innerCandidates(innerFor.getNumResults(), false);
+      innerCandidates.set(innerIdx);
+      for (unsigned j = 0; j < innerFor.getNumResults(); ++j) {
+        if (innerFor.getResult(j).use_empty())
+          innerCandidates.set(j);
+      }
+
+      Value innerBlockArg = innerFor.getRegionIterArg(innerIdx);
+      if (!isBlockArgEffectivelyUnused(innerBlockArg, innerFor,
+                                       innerCandidates))
+        return false;
+    }
+    return true;
+  }
+
+  LogicalResult matchAndRewrite(ForOp forOp,
+                                PatternRewriter &rewriter) const override {
+    unsigned numResults = forOp.getNumResults();
+    if (numResults == 0)
+      return failure();
+
+    // Step 1: Find candidate positions (result unused).
+    BitVector candidates(numResults, false);
+    for (unsigned i = 0; i < numResults; ++i) {
+      if (forOp.getResult(i).use_empty())
+        candidates.set(i);
+    }
+    if (candidates.none())
+      return failure();
+
+    // Step 2: For each candidate, verify the block arg is effectively unused
+    // but not trivially unused (the generic patterns handle the trivial case).
+    BitVector toRemove(numResults, false);
+    for (unsigned i : candidates.set_bits()) {
+      Value blockArg = forOp.getRegionIterArg(i);
+      if (blockArg.use_empty())
+        continue;
+      if (isBlockArgEffectivelyUnused(blockArg, forOp, candidates))
+        toRemove.set(i);
+    }
+    if (toRemove.none())
+      return failure();
+
+    // Step 3: Replace block arg uses with init values. This is safe because
+    // the block arg is effectively unused and the init value dominates the
+    // body.
+    for (unsigned i : toRemove.set_bits()) {
+      Value blockArg = forOp.getRegionIterArg(i);
+      Value initVal = forOp.getInitArgs()[i];
+      rewriter.replaceAllUsesWith(blockArg, initVal);
+    }
+
+    // Step 4: Erase yield operands for removed positions.
+    auto yieldOp = cast<YieldOp>(forOp.getBody()->getTerminator());
+    rewriter.modifyOpInPlace(yieldOp,
+                             [&]() { yieldOp->eraseOperands(toRemove); });
+
+    // Step 5: Erase block arguments for removed positions.
+    BitVector blockArgsToErase(forOp.getBody()->getNumArguments(), false);
+    for (unsigned i : toRemove.set_bits())
+      blockArgsToErase.set(i + forOp.getNumInductionVars());
+    rewriter.modifyOpInPlace(
+        forOp, [&]() { forOp.getBody()->eraseArguments(blockArgsToErase); });
+
+    // Step 6: Erase init operands for removed positions.
+    BitVector initOperandsToErase(forOp->getNumOperands(), false);
+    for (unsigned i : toRemove.set_bits())
+      initOperandsToErase.set(i + forOp.getNumControlOperands());
+    rewriter.modifyOpInPlace(
+        forOp, [&]() { forOp->eraseOperands(initOperandsToErase); });
+
+    // Step 7: Erase results for removed positions.
+    rewriter.eraseOpResults(forOp, toRemove);
+
+    return success();
+  }
+};
 } // namespace
 
 void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
-  results.add<ForOpTensorCastFolder>(context);
+  results.add<ForOpTensorCastFolder, ForOpUnusedIterArgElimination>(context);
   populateRegionBranchOpInterfaceCanonicalizationPatterns(
       results, ForOp::getOperationName());
   populateRegionBranchOpInterfaceInliningPattern(
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index c324d34942bf8..145a3417eec04 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -2360,3 +2360,151 @@ func.func @fold_tensor_cast_into_forall_non_sequential_writes(
   // %0#0 contains %arg1 data; %0#1 contains %arg0 data.
   return %0#0, %0#1 : tensor<?x32xf32>, tensor<?x32xf32>
 }
+
+// -----
+
+// Test: nested for loops with effectively-dead iter_args.
+// The outer result is unused, the outer block arg feeds the inner iter_arg
+// whose block arg is also unused. Both iter_args should be removed.
+
+// CHECK-LABEL: func @nested_for_effectively_dead_iter_args
+//  CHECK-SAME:   (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index, %[[MEM:.*]]: memref<f32>)
+//       CHECK:   %[[CST:.*]] = arith.constant
+//       CHECK:   scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
+//       CHECK:     scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
+//       CHECK:       memref.store %[[CST]], %[[MEM]][]
+//       CHECK:     }
+//       CHECK:   }
+//       CHECK:   return
+func.func @nested_for_effectively_dead_iter_args(
+    %lb: index, %ub: index, %step: index, %mem: memref<f32>) {
+  %cst = arith.constant 1.0 : f32
+  %init = arith.constant 0.0 : f32
+  %outer = scf.for %i = %lb to %ub step %step iter_args(%a = %init) -> (f32) {
+    %inner = scf.for %j = %lb to %ub step %step iter_args(%b = %a) -> (f32) {
+      memref.store %cst, %mem[] : memref<f32>
+      scf.yield %cst : f32
+    }
+    scf.yield %inner : f32
+  }
+  return
+}
+
+// -----
+
+// Test: three levels of nesting with effectively-dead iter_args.
+
+// CHECK-LABEL: func @triple_nested_effectively_dead_iter_args
+//       CHECK:   scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
+//       CHECK:     scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
+//       CHECK:       scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
+//       CHECK:         memref.store
+//       CHECK:       }
+//       CHECK:     }
+//       CHECK:   }
+//       CHECK:   return
+func.func @triple_nested_effectively_dead_iter_args(
+    %lb: index, %ub: index, %step: index, %mem: memref<f32>) {
+  %cst = arith.constant 1.0 : f32
+  %init = arith.constant 0.0 : f32
+  %outer = scf.for %i = %lb to %ub step %step iter_args(%a = %init) -> (f32) {
+    %mid = scf.for %j = %lb to %ub step %step iter_args(%b = %a) -> (f32) {
+      %inner = scf.for %k = %lb to %ub step %step iter_args(%c = %b) -> (f32) {
+        memref.store %cst, %mem[] : memref<f32>
+        scf.yield %cst : f32
+      }
+      scf.yield %inner : f32
+    }
+    scf.yield %mid : f32
+  }
+  return
+}
+
+// -----
+
+// Negative test: block arg is used in the loop body.
+// The iter_arg must NOT be removed even though the result is unused.
+
+// CHECK-LABEL: func @iter_arg_used_in_body
+//       CHECK:   scf.for {{.*}} iter_args
+//       CHECK:     memref.store
+//       CHECK:     scf.yield
+//       CHECK:   }
+func.func @iter_arg_used_in_body(
+    %lb: index, %ub: index, %step: index, %mem: memref<f32>) {
+  %init = arith.constant 0.0 : f32
+  %r = scf.for %i = %lb to %ub step %step iter_args(%a = %init) -> (f32) {
+    memref.store %a, %mem[] : memref<f32>
+    %next = arith.addf %a, %a : f32
+    scf.yield %next : f32
+  }
+  return
+}
+
+// -----
+
+// Test: two chained for loops where the second loop's result is unused and
+// its iter_arg chain through an inner loop is effectively dead.
+
+// CHECK-LABEL: func @chained_for_effectively_dead
+//  CHECK-SAME:   (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index, %[[MEM:.*]]: memref<f32>)
+//       CHECK:   %[[CST:.*]] = arith.constant
+//       CHECK:   scf.for %{{.*}} = %[[LB]] to %{{.*}} step %[[STEP]] {
+//       CHECK:     scf.for %{{.*}} = %[[LB]] to %{{.*}} step %[[STEP]] {
+//       CHECK:       memref.store
+//       CHECK:     }
+//       CHECK:   }
+//       CHECK:   scf.for %{{.*}} = %[[LB]] to %{{.*}} step %[[STEP]] {
+//       CHECK:     scf.for %{{.*}} = %[[LB]] to %{{.*}} step %[[STEP]] {
+//       CHECK:       memref.store
+//       CHECK:     }
+//       CHECK:   }
+//       CHECK:   return
+func.func @chained_for_effectively_dead(
+    %lb: index, %ub: index, %step: index, %mem: memref<f32>) {
+  %cst = arith.constant 1.0 : f32
+  %init = arith.constant 0.0 : f32
+  %first = scf.for %i = %lb to %ub step %step iter_args(%a = %init) -> (f32) {
+    %inner1 = scf.for %j = %lb to %ub step %step iter_args(%b = %a) -> (f32) {
+      memref.store %cst, %mem[] : memref<f32>
+      scf.yield %cst : f32
+    }
+    scf.yield %inner1 : f32
+  }
+  %second = scf.for %i = %lb to %ub step %step iter_args(%a = %first) -> (f32) {
+    %inner2 = scf.for %j = %lb to %ub step %step iter_args(%b = %a) -> (f32) {
+      memref.store %cst, %mem[] : memref<f32>
+      scf.yield %cst : f32
+    }
+    scf.yield %inner2 : f32
+  }
+  return
+}
+
+// -----
+
+// Test: 2-level loop nest with two iter_args, both effectively unused.
+
+// CHECK-LABEL: func @nested_for_two_iter_args
+//  CHECK-SAME:   (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index, %[[MEM:.*]]: memref<f32>)
+//       CHECK:   scf.for %{{.*}} = %[[LB]] to %[[UB]] step %[[STEP]] {
+//       CHECK:     scf.for %{{.*}} = %[[LB]] to %[[UB]] step %[[STEP]] {
+//       CHECK:       memref.store
+//       CHECK:     }
+//       CHECK:   }
+//       CHECK:   return
+func.func @nested_for_two_iter_args(
+    %lb: index, %ub: index, %step: index, %mem: memref<f32>) {
+  %c0 = arith.constant 0.0 : f32
+  %c1 = arith.constant 1.0 : f32
+  %r:2 = scf.for %i = %lb to %ub step %step
+      iter_args(%a = %c0, %b = %c1) -> (f32, f32) {
+    %inner:2 = scf.for %j = %lb to %ub step %step
+        iter_args(%x = %a, %y = %b) -> (f32, f32) {
+      memref.store %c1, %mem[] : memref<f32>
+      scf.yield %c1, %c0 : f32, f32
+    }
+    scf.yield %inner#0, %inner#1 : f32, f32
+  }
+  return
+}
diff --git a/mlir/test/Transforms/mem2reg-with-canonicalization.mlir b/mlir/test/Transforms/mem2reg-with-canonicalization.mlir
new file mode 100644
index 0000000000000..44e315cf620d3
--- /dev/null
+++ b/mlir/test/Transforms/mem2reg-with-canonicalization.mlir
@@ -0,0 +1,62 @@
+// RUN: mlir-opt %s --pass-pipeline='builtin.module(any(mem2reg))' | FileCheck %s --check-prefix=MEM2REG
+// RUN: mlir-opt %s --pass-pipeline='builtin.module(any(mem2reg,canonicalize))' | FileCheck %s --check-prefix=CANON
+
+// Two loop nests share the same alloca, causing the first loop's result
+// to chain into the second loop's init -- demonstrating the cross-loop
+// use chain that the generic canonicalization patterns cannot handle.
+
+// MEM2REG-LABEL: func.func @redundant_iter_args
+//       MEM2REG:   %[[POISON:.*]] = ub.poison : f32
+//       MEM2REG:   %[[CST:.*]] = arith.constant 1.000000e+00 : f32
+//       MEM2REG:   %[[R1:.*]] = scf.for {{.*}} iter_args(%{{.*}} = %[[POISON]]) -> (f32) {
+//       MEM2REG:     %[[R1I:.*]] = scf.for {{.*}} iter_args(%{{.*}} = %{{.*}}) -> (f32) {
+//       MEM2REG:       memref.store %[[CST]],
+//       MEM2REG:       scf.yield %[[CST]] : f32
+//       MEM2REG:     }
+//       MEM2REG:     scf.yield %[[R1I]] : f32
+//       MEM2REG:   }
+//       MEM2REG:   scf.for {{.*}} iter_args(%{{.*}} = %[[R1]]) -> (f32) {
+//       MEM2REG:     scf.for {{.*}} iter_args(%{{.*}} = %{{.*}}) -> (f32) {
+//       MEM2REG:       memref.store %[[CST]],
+//       MEM2REG:       scf.yield %[[CST]] : f32
+//       MEM2REG:     }
+//       MEM2REG:   }
+
+// CANON-LABEL: func.func @redundant_iter_args
+//  CANON-SAME:   (%[[N:.*]]: index, %[[MEM:.*]]: memref<f32>)
+//       CANON:   %[[CST:.*]] = arith.constant 1.000000e+00 : f32
+//       CANON:   scf.for %{{.*}} = %{{.*}} to %[[N]] step %{{.*}} {
+//       CANON:     scf.for %{{.*}} = %{{.*}} to %[[N]] step %{{.*}} {
+//   CANON-NOT:       iter_args
+//       CANON:       memref.store %[[CST]], %[[MEM]][] : memref<f32>
+//       CANON:     }
+//       CANON:   }
+//       CANON:   scf.for %{{.*}} = %{{.*}} to %[[N]] step %{{.*}} {
+//       CANON:     scf.for %{{.*}} = %{{.*}} to %[[N]] step %{{.*}} {
+//   CANON-NOT:       iter_args
+//       CANON:       memref.store %[[CST]], %[[MEM]][] : memref<f32>
+//       CANON:     }
+//       CANON:   }
+//       CANON:   return
+
+func.func @redundant_iter_args(%n: index, %mem: memref<f32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %cst = arith.constant 1.0 : f32
+  %tmp = memref.alloca() : memref<f32>
+  scf.for %i = %c0 to %n step %c1 {
+    scf.for %j = %c0 to %n step %c1 {
+      memref.store %cst, %tmp[] : memref<f32>
+      %v = memref.load %tmp[] : memref<f32>
+      memref.store %v, %mem[] : memref<f32>
+    }
+  }
+  scf.for %i = %c0 to %n step %c1 {
+    scf.for %j = %c0 to %n step %c1 {
+      memref.store %cst, %tmp[] : memref<f32>
+      %v = memref.load %tmp[] : memref<f32>
+      memref.store %v, %mem[] : memref<f32>
+    }
+  }
+  return
+}

@llvmbot
Copy link
Copy Markdown
Member

llvmbot commented Apr 9, 2026

@llvm/pr-subscribers-mlir-scf

Author: Slava Zakharin (vzakhari)

Changes

This pattern may appear after Mem2Reg, which conservatively
returns live values of the memory slots from loops.
If those values are not used, we can get rid of the loops'
results and corresponding iter_args.

Co-authored-by: Claude Opus 4.6


Full diff: https://github.com/llvm/llvm-project/pull/191085.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/SCF/IR/SCF.cpp (+143-1)
  • (modified) mlir/test/Dialect/SCF/canonicalize.mlir (+148)
  • (added) mlir/test/Transforms/mem2reg-with-canonicalization.mlir (+62)
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 9f4f4dc9f58e6..db50b2c1d3314 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1002,11 +1002,153 @@ struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
     return failure();
   }
 };
+
+/// Remove iter_arg/result pairs from scf.for when the result is unused and
+/// the corresponding iter_arg block argument is "effectively unused" --
+/// meaning it has no uses, or its only uses are as init operands for nested
+/// scf.for iter_args whose block arguments are also effectively unused.
+///
+/// This handles cases that the generic RemoveDeadRegionBranchOpSuccessorInputs
+/// pattern cannot, specifically when inner loop results are used by outer
+/// loop yields creating cross-loop use chains that appear live but are
+/// semantically dead.
+///
+/// Example:
+///   %r = scf.for %i = %lb to %ub step %s iter_args(%a = %init) -> (f32) {
+///     %inner = scf.for %j = %lb to %ub step %s
+///         iter_args(%b = %a) -> (f32) {
+///       // %b is unused in the body
+///       scf.yield %val : f32
+///     }
+///     scf.yield %inner : f32
+///   }
+///   // %r is unused
+///
+/// After canonicalization:
+///   scf.for %i = %lb to %ub step %s {
+///     scf.for %j = %lb to %ub step %s {
+///       // body without iter_args
+///     }
+///   }
+struct ForOpUnusedIterArgElimination : public OpRewritePattern<ForOp> {
+  using OpRewritePattern<ForOp>::OpRewritePattern;
+
+  /// Check if a block argument is effectively unused. A block argument is
+  /// effectively unused if it has no uses, or all its uses are init operands
+  /// for nested scf.for iter_args where: (a) the inner block arg is also
+  /// effectively unused, and (b) the inner for's result at that position is
+  /// only used as yield operands at positions we are removing from parentFor.
+  static bool isBlockArgEffectivelyUnused(Value blockArg, ForOp parentFor,
+                                          const BitVector &parentCandidates) {
+    if (blockArg.use_empty())
+      return true;
+
+    for (OpOperand &use : blockArg.getUses()) {
+      auto innerFor = dyn_cast<ForOp>(use.getOwner());
+      if (!innerFor)
+        return false;
+
+      unsigned opNum = use.getOperandNumber();
+      if (opNum < innerFor.getNumControlOperands())
+        return false;
+
+      unsigned innerIdx = opNum - innerFor.getNumControlOperands();
+      Value innerResult = innerFor.getResult(innerIdx);
+
+      // Inner result must only be used at yield positions of parentFor
+      // that are candidates for removal.
+      for (OpOperand &resUse : innerResult.getUses()) {
+        auto yieldOp = dyn_cast<YieldOp>(resUse.getOwner());
+        if (!yieldOp || yieldOp->getParentOp() != parentFor.getOperation())
+          return false;
+        if (!parentCandidates.test(resUse.getOperandNumber()))
+          return false;
+      }
+
+      // Build candidate positions for the inner for: innerIdx is a candidate
+      // because its result will become unused after the parent transformation.
+      BitVector innerCandidates(innerFor.getNumResults(), false);
+      innerCandidates.set(innerIdx);
+      for (unsigned j = 0; j < innerFor.getNumResults(); ++j) {
+        if (innerFor.getResult(j).use_empty())
+          innerCandidates.set(j);
+      }
+
+      Value innerBlockArg = innerFor.getRegionIterArg(innerIdx);
+      if (!isBlockArgEffectivelyUnused(innerBlockArg, innerFor,
+                                       innerCandidates))
+        return false;
+    }
+    return true;
+  }
+
+  LogicalResult matchAndRewrite(ForOp forOp,
+                                PatternRewriter &rewriter) const override {
+    unsigned numResults = forOp.getNumResults();
+    if (numResults == 0)
+      return failure();
+
+    // Step 1: Find candidate positions (result unused).
+    BitVector candidates(numResults, false);
+    for (unsigned i = 0; i < numResults; ++i) {
+      if (forOp.getResult(i).use_empty())
+        candidates.set(i);
+    }
+    if (candidates.none())
+      return failure();
+
+    // Step 2: For each candidate, verify the block arg is effectively unused
+    // but not trivially unused (the generic patterns handle the trivial case).
+    BitVector toRemove(numResults, false);
+    for (unsigned i : candidates.set_bits()) {
+      Value blockArg = forOp.getRegionIterArg(i);
+      if (blockArg.use_empty())
+        continue;
+      if (isBlockArgEffectivelyUnused(blockArg, forOp, candidates))
+        toRemove.set(i);
+    }
+    if (toRemove.none())
+      return failure();
+
+    // Step 3: Replace block arg uses with init values. This is safe because
+    // the block arg is effectively unused and the init value dominates the
+    // body.
+    for (unsigned i : toRemove.set_bits()) {
+      Value blockArg = forOp.getRegionIterArg(i);
+      Value initVal = forOp.getInitArgs()[i];
+      rewriter.replaceAllUsesWith(blockArg, initVal);
+    }
+
+    // Step 4: Erase yield operands for removed positions.
+    auto yieldOp = cast<YieldOp>(forOp.getBody()->getTerminator());
+    rewriter.modifyOpInPlace(yieldOp,
+                             [&]() { yieldOp->eraseOperands(toRemove); });
+
+    // Step 5: Erase block arguments for removed positions.
+    BitVector blockArgsToErase(forOp.getBody()->getNumArguments(), false);
+    for (unsigned i : toRemove.set_bits())
+      blockArgsToErase.set(i + forOp.getNumInductionVars());
+    rewriter.modifyOpInPlace(
+        forOp, [&]() { forOp.getBody()->eraseArguments(blockArgsToErase); });
+
+    // Step 6: Erase init operands for removed positions.
+    BitVector initOperandsToErase(forOp->getNumOperands(), false);
+    for (unsigned i : toRemove.set_bits())
+      initOperandsToErase.set(i + forOp.getNumControlOperands());
+    rewriter.modifyOpInPlace(
+        forOp, [&]() { forOp->eraseOperands(initOperandsToErase); });
+
+    // Step 7: Erase results for removed positions.
+    rewriter.eraseOpResults(forOp, toRemove);
+
+    return success();
+  }
+};
 } // namespace
 
 void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
-  results.add<ForOpTensorCastFolder>(context);
+  results.add<ForOpTensorCastFolder, ForOpUnusedIterArgElimination>(context);
   populateRegionBranchOpInterfaceCanonicalizationPatterns(
       results, ForOp::getOperationName());
   populateRegionBranchOpInterfaceInliningPattern(
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index c324d34942bf8..145a3417eec04 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -2360,3 +2360,151 @@ func.func @fold_tensor_cast_into_forall_non_sequential_writes(
   // %0#0 contains %arg1 data; %0#1 contains %arg0 data.
   return %0#0, %0#1 : tensor<?x32xf32>, tensor<?x32xf32>
 }
+
+// -----
+
+// Test: nested for loops with effectively-dead iter_args.
+// The outer result is unused, the outer block arg feeds the inner iter_arg
+// whose block arg is also unused. Both iter_args should be removed.
+
+// CHECK-LABEL: func @nested_for_effectively_dead_iter_args
+//  CHECK-SAME:   (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index, %[[MEM:.*]]: memref<f32>)
+//       CHECK:   %[[CST:.*]] = arith.constant
+//       CHECK:   scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
+//       CHECK:     scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
+//       CHECK:       memref.store %[[CST]], %[[MEM]][]
+//       CHECK:     }
+//       CHECK:   }
+//       CHECK:   return
+func.func @nested_for_effectively_dead_iter_args(
+    %lb: index, %ub: index, %step: index, %mem: memref<f32>) {
+  %cst = arith.constant 1.0 : f32
+  %init = arith.constant 0.0 : f32
+  %outer = scf.for %i = %lb to %ub step %step iter_args(%a = %init) -> (f32) {
+    %inner = scf.for %j = %lb to %ub step %step iter_args(%b = %a) -> (f32) {
+      memref.store %cst, %mem[] : memref<f32>
+      scf.yield %cst : f32
+    }
+    scf.yield %inner : f32
+  }
+  return
+}
+
+// -----
+
+// Test: three levels of nesting with effectively-dead iter_args.
+
+// CHECK-LABEL: func @triple_nested_effectively_dead_iter_args
+//       CHECK:   scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
+//       CHECK:     scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
+//       CHECK:       scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
+//       CHECK:         memref.store
+//       CHECK:       }
+//       CHECK:     }
+//       CHECK:   }
+//       CHECK:   return
+func.func @triple_nested_effectively_dead_iter_args(
+    %lb: index, %ub: index, %step: index, %mem: memref<f32>) {
+  %cst = arith.constant 1.0 : f32
+  %init = arith.constant 0.0 : f32
+  %outer = scf.for %i = %lb to %ub step %step iter_args(%a = %init) -> (f32) {
+    %mid = scf.for %j = %lb to %ub step %step iter_args(%b = %a) -> (f32) {
+      %inner = scf.for %k = %lb to %ub step %step iter_args(%c = %b) -> (f32) {
+        memref.store %cst, %mem[] : memref<f32>
+        scf.yield %cst : f32
+      }
+      scf.yield %inner : f32
+    }
+    scf.yield %mid : f32
+  }
+  return
+}
+
+// -----
+
+// Negative test: block arg is used in the loop body.
+// The iter_arg must NOT be removed even though the result is unused.
+
+// CHECK-LABEL: func @iter_arg_used_in_body
+//       CHECK:   scf.for {{.*}} iter_args
+//       CHECK:     memref.store
+//       CHECK:     scf.yield
+//       CHECK:   }
+func.func @iter_arg_used_in_body(
+    %lb: index, %ub: index, %step: index, %mem: memref<f32>) {
+  %init = arith.constant 0.0 : f32
+  %r = scf.for %i = %lb to %ub step %step iter_args(%a = %init) -> (f32) {
+    memref.store %a, %mem[] : memref<f32>
+    %next = arith.addf %a, %a : f32
+    scf.yield %next : f32
+  }
+  return
+}
+
+// -----
+
+// Test: two chained for loops where the second loop's result is unused and
+// its iter_arg chain through an inner loop is effectively dead.
+
+// CHECK-LABEL: func @chained_for_effectively_dead
+//  CHECK-SAME:   (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index, %[[MEM:.*]]: memref<f32>)
+//       CHECK:   %[[CST:.*]] = arith.constant
+//       CHECK:   scf.for %{{.*}} = %[[LB]] to %{{.*}} step %[[STEP]] {
+//       CHECK:     scf.for %{{.*}} = %[[LB]] to %{{.*}} step %[[STEP]] {
+//       CHECK:       memref.store
+//       CHECK:     }
+//       CHECK:   }
+//       CHECK:   scf.for %{{.*}} = %[[LB]] to %{{.*}} step %[[STEP]] {
+//       CHECK:     scf.for %{{.*}} = %[[LB]] to %{{.*}} step %[[STEP]] {
+//       CHECK:       memref.store
+//       CHECK:     }
+//       CHECK:   }
+//       CHECK:   return
+func.func @chained_for_effectively_dead(
+    %lb: index, %ub: index, %step: index, %mem: memref<f32>) {
+  %cst = arith.constant 1.0 : f32
+  %init = arith.constant 0.0 : f32
+  %first = scf.for %i = %lb to %ub step %step iter_args(%a = %init) -> (f32) {
+    %inner1 = scf.for %j = %lb to %ub step %step iter_args(%b = %a) -> (f32) {
+      memref.store %cst, %mem[] : memref<f32>
+      scf.yield %cst : f32
+    }
+    scf.yield %inner1 : f32
+  }
+  %second = scf.for %i = %lb to %ub step %step iter_args(%a = %first) -> (f32) {
+    %inner2 = scf.for %j = %lb to %ub step %step iter_args(%b = %a) -> (f32) {
+      memref.store %cst, %mem[] : memref<f32>
+      scf.yield %cst : f32
+    }
+    scf.yield %inner2 : f32
+  }
+  return
+}
+
+// -----
+
+// Test: 2-level loop nest with two iter_args, both effectively unused.
+
+// CHECK-LABEL: func @nested_for_two_iter_args
+//  CHECK-SAME:   (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index, %[[MEM:.*]]: memref<f32>)
+//       CHECK:   scf.for %{{.*}} = %[[LB]] to %[[UB]] step %[[STEP]] {
+//       CHECK:     scf.for %{{.*}} = %[[LB]] to %[[UB]] step %[[STEP]] {
+//       CHECK:       memref.store
+//       CHECK:     }
+//       CHECK:   }
+//       CHECK:   return
+func.func @nested_for_two_iter_args(
+    %lb: index, %ub: index, %step: index, %mem: memref<f32>) {
+  %c0 = arith.constant 0.0 : f32
+  %c1 = arith.constant 1.0 : f32
+  %r:2 = scf.for %i = %lb to %ub step %step
+      iter_args(%a = %c0, %b = %c1) -> (f32, f32) {
+    %inner:2 = scf.for %j = %lb to %ub step %step
+        iter_args(%x = %a, %y = %b) -> (f32, f32) {
+      memref.store %c1, %mem[] : memref<f32>
+      scf.yield %c1, %c0 : f32, f32
+    }
+    scf.yield %inner#0, %inner#1 : f32, f32
+  }
+  return
+}
diff --git a/mlir/test/Transforms/mem2reg-with-canonicalization.mlir b/mlir/test/Transforms/mem2reg-with-canonicalization.mlir
new file mode 100644
index 0000000000000..44e315cf620d3
--- /dev/null
+++ b/mlir/test/Transforms/mem2reg-with-canonicalization.mlir
@@ -0,0 +1,62 @@
+// RUN: mlir-opt %s --pass-pipeline='builtin.module(any(mem2reg))' | FileCheck %s --check-prefix=MEM2REG
+// RUN: mlir-opt %s --pass-pipeline='builtin.module(any(mem2reg,canonicalize))' | FileCheck %s --check-prefix=CANON
+
+// Two loop nests share the same alloca, causing the first loop's result
+// to chain into the second loop's init -- demonstrating the cross-loop
+// use chain that the generic canonicalization patterns cannot handle.
+
+// MEM2REG-LABEL: func.func @redundant_iter_args
+//       MEM2REG:   %[[POISON:.*]] = ub.poison : f32
+//       MEM2REG:   %[[CST:.*]] = arith.constant 1.000000e+00 : f32
+//       MEM2REG:   %[[R1:.*]] = scf.for {{.*}} iter_args(%{{.*}} = %[[POISON]]) -> (f32) {
+//       MEM2REG:     %[[R1I:.*]] = scf.for {{.*}} iter_args(%{{.*}} = %{{.*}}) -> (f32) {
+//       MEM2REG:       memref.store %[[CST]],
+//       MEM2REG:       scf.yield %[[CST]] : f32
+//       MEM2REG:     }
+//       MEM2REG:     scf.yield %[[R1I]] : f32
+//       MEM2REG:   }
+//       MEM2REG:   scf.for {{.*}} iter_args(%{{.*}} = %[[R1]]) -> (f32) {
+//       MEM2REG:     scf.for {{.*}} iter_args(%{{.*}} = %{{.*}}) -> (f32) {
+//       MEM2REG:       memref.store %[[CST]],
+//       MEM2REG:       scf.yield %[[CST]] : f32
+//       MEM2REG:     }
+//       MEM2REG:   }
+
+// CANON-LABEL: func.func @redundant_iter_args
+//  CANON-SAME:   (%[[N:.*]]: index, %[[MEM:.*]]: memref<f32>)
+//       CANON:   %[[CST:.*]] = arith.constant 1.000000e+00 : f32
+//       CANON:   scf.for %{{.*}} = %{{.*}} to %[[N]] step %{{.*}} {
+//       CANON:     scf.for %{{.*}} = %{{.*}} to %[[N]] step %{{.*}} {
+//   CANON-NOT:       iter_args
+//       CANON:       memref.store %[[CST]], %[[MEM]][] : memref<f32>
+//       CANON:     }
+//       CANON:   }
+//       CANON:   scf.for %{{.*}} = %{{.*}} to %[[N]] step %{{.*}} {
+//       CANON:     scf.for %{{.*}} = %{{.*}} to %[[N]] step %{{.*}} {
+//   CANON-NOT:       iter_args
+//       CANON:       memref.store %[[CST]], %[[MEM]][] : memref<f32>
+//       CANON:     }
+//       CANON:   }
+//       CANON:   return
+
+func.func @redundant_iter_args(%n: index, %mem: memref<f32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %cst = arith.constant 1.0 : f32
+  %tmp = memref.alloca() : memref<f32>
+  scf.for %i = %c0 to %n step %c1 {
+    scf.for %j = %c0 to %n step %c1 {
+      memref.store %cst, %tmp[] : memref<f32>
+      %v = memref.load %tmp[] : memref<f32>
+      memref.store %v, %mem[] : memref<f32>
+    }
+  }
+  scf.for %i = %c0 to %n step %c1 {
+    scf.for %j = %c0 to %n step %c1 {
+      memref.store %cst, %tmp[] : memref<f32>
+      %v = memref.load %tmp[] : memref<f32>
+      memref.store %v, %mem[] : memref<f32>
+    }
+  }
+  return
+}

return true;

for (OpOperand &use : blockArg.getUses()) {
auto innerFor = dyn_cast<ForOp>(use.getOwner());
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That seems completely ad-hoc to me: we should rely on inside-out processing to simplify first the inner ops and then the outer ones. Hardcoding ForOp makes this all not composable.

Copy link
Copy Markdown
Contributor

@joker-eph joker-eph left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe the intent is that populateRegionBranchOpInterfaceCanonicalizationPatterns should handle this. We need to investigate why it isn't the case.

@joker-eph
Copy link
Copy Markdown
Contributor

I believe the intent is that populateRegionBranchOpInterfaceCanonicalizationPatterns should handle this. We need to investigate why it isn't the case.

I actually had a look, the issue is that we can only remove trivially dead iter_args. Not ones that are transitively dead.

Right now this kind of IR simplifications is in scope for the remove-dead-values pass.

@matthias-springer
Copy link
Copy Markdown
Member

I just tried -remove-dead-values with the @chained_for_effectively_dead test case from this PR. It removes all iter_args.

@joker-eph
Copy link
Copy Markdown
Contributor

That said this makes me think there is an opportunity for a "isRecursivelyTriviallyDead()" utility to check on whether a value is actually transitively used by a side-effecting operation, and if not we could replace it with a poison or something like that.

@vzakhari
Copy link
Copy Markdown
Contributor Author

vzakhari commented Apr 9, 2026

Thank you both for the review and the insights! I did look into the RegionBranchOpInterface canonicalizers and they did not handle it. I missed the fact that remove-dead-values should handle this. Thank you!

@vzakhari
Copy link
Copy Markdown
Contributor Author

vzakhari commented Apr 9, 2026

This is a redundant change.

@vzakhari vzakhari closed this Apr 9, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants