Skip to content

Conversation

FranciscoThiesen
Copy link
Contributor

Problem

RemoveDeadValues can legally drop dead function arguments on private func.func callees. But call-sites to such functions aren't fixed if the call operation keeps its call arguments in a segmented operand group (i.ie, uses AttrSizedOperandSegments), unless the call op implements getArgOperandsMutable and the RDV pass actually uses it.

Fix

When RDV decides to drop callee function args, it should, for each call-site that implements CallOpInterface, shrink the call's argument segment via getArgOperandsMutable() using the same dead-arg indices. This keeps both the flat operand list and the operand_segment_sizes attribute in sync (that's what MutableOperandRange does when bound to the segment).

Note

This change is a no-op for:

  • call ops without segment operands (they still get their flat operands erased via the generic path)
  • call ops whose calle args weren't dropped (public, external, non-func-func, unresolved symbol, etc)
  • llvm.call/llvm.invoke (RDV doesn't drop llvm.func args

Copy link

Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Sep 23, 2025
@llvmbot
Copy link
Member

llvmbot commented Sep 23, 2025

@llvm/pr-subscribers-llvm-adt
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Francisco Geiman Thiesen (FranciscoThiesen)

Changes

Problem

RemoveDeadValues can legally drop dead function arguments on private func.func callees. But call-sites to such functions aren't fixed if the call operation keeps its call arguments in a segmented operand group (i.ie, uses AttrSizedOperandSegments), unless the call op implements getArgOperandsMutable and the RDV pass actually uses it.

Fix

When RDV decides to drop callee function args, it should, for each call-site that implements CallOpInterface, shrink the call's argument segment via getArgOperandsMutable() using the same dead-arg indices. This keeps both the flat operand list and the operand_segment_sizes attribute in sync (that's what MutableOperandRange does when bound to the segment).

Note

This change is a no-op for:

  • call ops without segment operands (they still get their flat operands erased via the generic path)
  • call ops whose calle args weren't dropped (public, external, non-func-func, unresolved symbol, etc)
  • llvm.call/llvm.invoke (RDV doesn't drop llvm.func args

Full diff: https://github.com/llvm/llvm-project/pull/160415.diff

4 Files Affected:

  • (modified) mlir/lib/Transforms/RemoveDeadValues.cpp (+50-14)
  • (added) mlir/test/Transforms/remove-dead-values-call-segments.mlir (+23)
  • (modified) mlir/test/lib/Dialect/Test/TestDialect.cpp (+44)
  • (modified) mlir/test/lib/Dialect/Test/TestOps.td (+43)
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 0e84b6dd17f29..0655adaad5f5f 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -306,19 +306,17 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
       nonLiveSet.insert(arg);
     }
 
-  // Do (2).
+  // Do (2). (Skip creating generic operand cleanup entries for call ops.
+  // Call arguments will be removed in the call-site specific segment-aware
+  // cleanup, avoiding generic eraseOperands bitvector mechanics.)
   SymbolTable::UseRange uses = *funcOp.getSymbolUses(module);
   for (SymbolTable::SymbolUse use : uses) {
     Operation *callOp = use.getUser();
     assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
-    // The number of operands in the call op may not match the number of
-    // arguments in the func op.
-    BitVector nonLiveCallOperands(callOp->getNumOperands(), false);
-    SmallVector<OpOperand *> callOpOperands =
-        operandsToOpOperands(cast<CallOpInterface>(callOp).getArgOperands());
-    for (int index : nonLiveArgs.set_bits())
-      nonLiveCallOperands.set(callOpOperands[index]->getOperandNumber());
-    cl.operands.push_back({callOp, nonLiveCallOperands});
+    // Push an empty operand cleanup entry so that call-site specific logic in
+    // cleanUpDeadVals runs (it keys off CallOpInterface). The BitVector is
+    // intentionally all false to avoid generic erasure.
+    cl.operands.push_back({callOp, BitVector(callOp->getNumOperands(), false)});
   }
 
   // Do (3).
@@ -746,6 +744,10 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
 
   // 3. Functions
   LDBG() << "Cleaning up " << list.functions.size() << " functions";
+  // Record which function arguments were erased so we can shrink call-site
+  // argument segments for CallOpInterface operations (e.g. ops using
+  // AttrSizedOperandSegments) in the next phase.
+  DenseMap<Operation *, BitVector> erasedFuncArgs;
   for (auto &f : list.functions) {
     LDBG() << "Cleaning up function: " << f.funcOp.getOperation()->getName();
     LDBG() << "  Erasing " << f.nonLiveArgs.count() << " non-live arguments";
@@ -754,17 +756,51 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
     // Some functions may not allow erasing arguments or results. These calls
     // return failure in such cases without modifying the function, so it's okay
     // to proceed.
-    (void)f.funcOp.eraseArguments(f.nonLiveArgs);
+    if (succeeded(f.funcOp.eraseArguments(f.nonLiveArgs))) {
+      // Record only if we actually erased something.
+      if (f.nonLiveArgs.any())
+        erasedFuncArgs.try_emplace(f.funcOp.getOperation(), f.nonLiveArgs);
+    }
     (void)f.funcOp.eraseResults(f.nonLiveRets);
   }
 
   // 4. Operands
   LDBG() << "Cleaning up " << list.operands.size() << " operand lists";
   for (OperationToCleanup &o : list.operands) {
-    if (o.op->getNumOperands() > 0) {
-      LDBG() << "Erasing " << o.nonLive.count()
-             << " non-live operands from operation: "
-             << OpWithFlags(o.op, OpPrintingFlags().skipRegions());
+    if (auto call = dyn_cast<CallOpInterface>(o.op)) {
+      if (SymbolRefAttr sym = call.getCallableForCallee().dyn_cast<SymbolRefAttr>()) {
+        Operation *callee = SymbolTable::lookupNearestSymbolFrom(o.op, sym);
+        auto it = erasedFuncArgs.find(callee);
+        if (it != erasedFuncArgs.end()) {
+          const BitVector &deadArgIdxs = it->second;
+          MutableOperandRange args = call.getArgOperandsMutable();
+          // First, erase the call arguments corresponding to erased callee args.
+          for (int i = static_cast<int>(args.size()) - 1; i >= 0; --i) {
+            if (i < static_cast<int>(deadArgIdxs.size()) && deadArgIdxs.test(i))
+              args.erase(i);
+          }
+          // If this operand cleanup entry also has a generic nonLive bitvector,
+          // clear bits for call arguments we already erased above to avoid
+          // double-erasing (which could impact other segments of ops with
+          // AttrSizedOperandSegments).
+          if (o.nonLive.any()) {
+            // Map the argument logical index to the operand number(s) recorded.
+            SmallVector<OpOperand *> callOperands =
+                operandsToOpOperands(call.getArgOperands());
+            for (int argIdx : deadArgIdxs.set_bits()) {
+              if (argIdx < static_cast<int>(callOperands.size())) {
+                unsigned operandNumber = callOperands[argIdx]->getOperandNumber();
+                if (operandNumber < o.nonLive.size())
+                  o.nonLive.reset(operandNumber);
+              }
+            }
+          }
+        }
+      }
+    }
+    // Only perform generic operand erasure for non-call ops; for call ops we
+    // already handled argument removals via the segment-aware path above.
+    if (!isa<CallOpInterface>(o.op) && o.nonLive.any()) {
       o.op->eraseOperands(o.nonLive);
     }
   }
diff --git a/mlir/test/Transforms/remove-dead-values-call-segments.mlir b/mlir/test/Transforms/remove-dead-values-call-segments.mlir
new file mode 100644
index 0000000000000..fed9cabbd2ee8
--- /dev/null
+++ b/mlir/test/Transforms/remove-dead-values-call-segments.mlir
@@ -0,0 +1,23 @@
+// RUN: mlir-opt --split-input-file --remove-dead-values --mlir-print-op-generic %s | FileCheck %s --check-prefix=GEN
+
+// -----
+// Private callee: both args become dead after internal DCE; RDV drops callee
+// args and shrinks the *args* segment on the call-site to zero; sizes kept in
+// sync.
+
+module {
+  func.func private @callee(%x: i32, %y: i32) {
+    %u = arith.addi %x, %x : i32   // %y is dead
+    return
+  }
+
+  func.func @caller(%a: i32, %b: i32) {
+    // args segment initially has 2 operands.
+    "test.call_with_segments"(%a, %b) { callee = @callee,
+      operandSegmentSizes = array<i32: 0, 2, 0> } : (i32, i32) -> ()
+    return
+  }
+}
+
+// GEN: "test.call_with_segments"() <{callee = @callee, operandSegmentSizes = array<i32: 0, 0, 0>}> : () -> ()
+//       ^ args shrank from 2 -> 0
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 987e8f3654ce8..5016ab6b94cdb 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -431,3 +431,47 @@ void TestDialect::getCanonicalizationPatterns(
     RewritePatternSet &results) const {
   results.add(&dialectCanonicalizationPattern);
 }
+
+//===----------------------------------------------------------------------===//
+// TestCallWithSegmentsOp
+//===----------------------------------------------------------------------===//
+// The op `test.call_with_segments` models a call-like operation whose operands
+// are divided into 3 variadic segments: `prefix`, `args`, and `suffix`.
+// Only the middle segment represents the actual call arguments. The op uses
+// the AttrSizedOperandSegments trait, so we can derive segment boundaries from
+// the generated `operandSegmentSizes` attribute. We provide custom helpers to
+// expose the logical call arguments as both a read-only range and a mutable
+// range bound to the proper segment so that insertion/erasure updates the
+// attribute automatically.
+
+// Segment layout indices in the DenseI32ArrayAttr: [prefix, args, suffix].
+static constexpr unsigned kTestCallWithSegmentsArgsSegIndex = 1;
+
+Operation::operand_range CallWithSegmentsOp::getArgOperands() {
+  // Leverage generated getters for segment sizes: slice between prefix and
+  // suffix using current operand list.
+  return getOperation()->getOperands().slice(getPrefix().size(),
+                                             getArgs().size());
+}
+
+MutableOperandRange CallWithSegmentsOp::getArgOperandsMutable() {
+  Operation *op = getOperation();
+
+  // Obtain the canonical segment size attribute name for this op.
+  auto segName =
+      CallWithSegmentsOp::getOperandSegmentSizesAttrName(op->getName());
+  auto sizesAttr = op->getAttrOfType<DenseI32ArrayAttr>(segName);
+  assert(sizesAttr && "missing operandSegmentSizes attribute on op");
+
+  // Compute the start and length of the args segment from the prefix size and
+  // args size stored in the attribute.
+  auto sizes = sizesAttr.asArrayRef();
+  unsigned start = static_cast<unsigned>(sizes[0]); // prefix size
+  unsigned len = static_cast<unsigned>(sizes[1]);    // args size
+
+  NamedAttribute segNamed(segName, sizesAttr);
+  MutableOperandRange::OperandSegment binding{kTestCallWithSegmentsArgsSegIndex,
+                                              segNamed};
+
+  return MutableOperandRange(op, start, len, {binding});
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index d9bbb3261febc..6ea27187655ee 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -3746,4 +3746,47 @@ def TestOpWithSuccessorRef : TEST_Op<"dummy_op_with_successor_ref"> {
   }];
 }
 
+def CallWithSegmentsOp : TEST_Op<"call_with_segments",
+    [AttrSizedOperandSegments,
+  DeclareOpInterfaceMethods<CallOpInterface>]> {
+  let summary = "test call op with segmented args";
+  let arguments = (ins
+    FlatSymbolRefAttr:$callee,
+    Variadic<AnyType>:$prefix,   // non-arg segment (e.g., 'in')
+    Variadic<AnyType>:$args,     // <-- the call *arguments* segment
+    Variadic<AnyType>:$suffix    // non-arg segment (e.g., 'out')
+  );
+  let results = (outs);
+  let assemblyFormat = [{
+    $callee `(` $prefix `:` type($prefix) `)`
+            `(` $args `:` type($args) `)`
+            `(` $suffix `:` type($suffix) `)` attr-dict
+  }];
+
+  // Provide stub implementations for the ArgAndResultAttrsOpInterface.
+  let extraClassDeclaration = [{
+    ::mlir::ArrayAttr getArgAttrsAttr() { return {}; }
+    ::mlir::ArrayAttr getResAttrsAttr() { return {}; }
+    void setArgAttrsAttr(::mlir::ArrayAttr) {}
+    void setResAttrsAttr(::mlir::ArrayAttr) {}
+    ::mlir::Attribute removeArgAttrsAttr() { return {}; }
+    ::mlir::Attribute removeResAttrsAttr() { return {}; }
+  }];
+
+  let extraClassDefinition = [{
+    ::mlir::CallInterfaceCallable $cppClass::getCallableForCallee() {
+      if (auto sym = (*this)->getAttrOfType<::mlir::SymbolRefAttr>("callee"))
+        return ::mlir::CallInterfaceCallable(sym);
+      return ::mlir::CallInterfaceCallable();
+    }
+    void $cppClass::setCalleeFromCallable(::mlir::CallInterfaceCallable callee) {
+      if (auto sym = callee.dyn_cast<::mlir::SymbolRefAttr>())
+        (*this)->setAttr("callee", sym);
+      else
+        (*this)->removeAttr("callee");
+    }
+  }];
+}
+
+
 #endif // TEST_OPS

Copy link

github-actions bot commented Sep 24, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you split ADT changes to a separate PR and add unit tests?

@FranciscoThiesen
Copy link
Contributor Author

Can you split ADT changes to a separate PR and add unit tests?

I've split the ADT work into (#160726), could you please review @kuhar ?

kuhar added a commit that referenced this pull request Sep 25, 2025
This feature is needed for #160415 , kuhar suggested that I split that
PR into 2 so that the ADT work is checking in first.

---------

Co-authored-by: Jakub Kuderski <kubakuderski@gmail.com>
@FranciscoThiesen
Copy link
Contributor Author

@joker-eph @kuhar ADT functionality was merged, can I get a re-review here? d8a2965

@kuhar kuhar removed their request for review September 25, 2025 22:49
@FranciscoThiesen
Copy link
Contributor Author

@joker-eph can you help me merge?

@joker-eph joker-eph merged commit 3e746bd into llvm:main Sep 26, 2025
9 checks passed
Copy link

@FranciscoThiesen Congratulations on having your first Pull Request (PR) merged into the LLVM Project!

Your changes will be combined with recent changes from other authors, then tested by our build bots. If there is a problem with a build, you may receive a report in an email or a comment on this PR.

Please check whether problems have been caused by your change specifically, as the builds can include changes from many authors. It is not uncommon for your change to be included in a build that fails due to someone else's changes, or infrastructure issues.

How to do this, and the rest of the post-merge process, is covered in detail here.

If your change does cause a problem, it may be reverted, or you can revert it yourself. This is a normal part of LLVM development. You can fix your changes and open a new PR to merge them again.

If you don't get any reports, no action is required from you. Your changes are working as expected, well done!

mahesh-attarde pushed a commit to mahesh-attarde/llvm-project that referenced this pull request Oct 3, 2025
…160726)

This feature is needed for llvm#160415 , kuhar suggested that I split that
PR into 2 so that the ADT work is checking in first.

---------

Co-authored-by: Jakub Kuderski <kubakuderski@gmail.com>
mahesh-attarde pushed a commit to mahesh-attarde/llvm-project that referenced this pull request Oct 3, 2025
## Problem

`RemoveDeadValues` can legally drop dead function arguments on private
`func.func` callees. But call-sites to such functions aren't fixed if
the call operation keeps its call arguments in a **segmented operand
group** (i.ie, uses `AttrSizedOperandSegments`), unless the call op
implements `getArgOperandsMutable` and the RDV pass actually uses it.

## Fix
When RDV decides to drop callee function args, it should, for each
call-site that implements `CallOpInterface`, **shrink the call's
argument segment** via `getArgOperandsMutable()` using the same dead-arg
indices. This keeps both the flat operand list and the
`operand_segment_sizes` attribute in sync (that's what
`MutableOperandRange` does when bound to the segment).

## Note
This change is a no-op for:
* call ops without segment operands (they still get their flat operands
erased via the generic path)
* call ops whose calle args weren't dropped (public, external,
non-`func-func`, unresolved symbol, etc)
* `llvm.call`/`llvm.invoke` (RDV doesn't drop `llvm.func` args

---------

Co-authored-by: Mehdi Amini <joker.eph@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
llvm:adt mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants