-
Notifications
You must be signed in to change notification settings - Fork 15.2k
Allowing RDV to call getArgOperandsMutable()
#160415
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allowing RDV to call getArgOperandsMutable()
#160415
Conversation
…he AttrSizedOperandSegments trait
Thank you for submitting a Pull Request (PR) to the LLVM Project! This PR will be automatically labeled and the relevant teams will be notified. If you wish to, you can add reviewers by using the "Reviewers" section on this page. If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers. If you have further questions, they may be answered by the LLVM GitHub User Guide. You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums. |
@llvm/pr-subscribers-llvm-adt @llvm/pr-subscribers-mlir-core Author: Francisco Geiman Thiesen (FranciscoThiesen) ChangesProblem
FixWhen RDV decides to drop callee function args, it should, for each call-site that implements NoteThis change is a no-op for:
Full diff: https://github.com/llvm/llvm-project/pull/160415.diff 4 Files Affected:
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 0e84b6dd17f29..0655adaad5f5f 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -306,19 +306,17 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
nonLiveSet.insert(arg);
}
- // Do (2).
+ // Do (2). (Skip creating generic operand cleanup entries for call ops.
+ // Call arguments will be removed in the call-site specific segment-aware
+ // cleanup, avoiding generic eraseOperands bitvector mechanics.)
SymbolTable::UseRange uses = *funcOp.getSymbolUses(module);
for (SymbolTable::SymbolUse use : uses) {
Operation *callOp = use.getUser();
assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
- // The number of operands in the call op may not match the number of
- // arguments in the func op.
- BitVector nonLiveCallOperands(callOp->getNumOperands(), false);
- SmallVector<OpOperand *> callOpOperands =
- operandsToOpOperands(cast<CallOpInterface>(callOp).getArgOperands());
- for (int index : nonLiveArgs.set_bits())
- nonLiveCallOperands.set(callOpOperands[index]->getOperandNumber());
- cl.operands.push_back({callOp, nonLiveCallOperands});
+ // Push an empty operand cleanup entry so that call-site specific logic in
+ // cleanUpDeadVals runs (it keys off CallOpInterface). The BitVector is
+ // intentionally all false to avoid generic erasure.
+ cl.operands.push_back({callOp, BitVector(callOp->getNumOperands(), false)});
}
// Do (3).
@@ -746,6 +744,10 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
// 3. Functions
LDBG() << "Cleaning up " << list.functions.size() << " functions";
+ // Record which function arguments were erased so we can shrink call-site
+ // argument segments for CallOpInterface operations (e.g. ops using
+ // AttrSizedOperandSegments) in the next phase.
+ DenseMap<Operation *, BitVector> erasedFuncArgs;
for (auto &f : list.functions) {
LDBG() << "Cleaning up function: " << f.funcOp.getOperation()->getName();
LDBG() << " Erasing " << f.nonLiveArgs.count() << " non-live arguments";
@@ -754,17 +756,51 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
// Some functions may not allow erasing arguments or results. These calls
// return failure in such cases without modifying the function, so it's okay
// to proceed.
- (void)f.funcOp.eraseArguments(f.nonLiveArgs);
+ if (succeeded(f.funcOp.eraseArguments(f.nonLiveArgs))) {
+ // Record only if we actually erased something.
+ if (f.nonLiveArgs.any())
+ erasedFuncArgs.try_emplace(f.funcOp.getOperation(), f.nonLiveArgs);
+ }
(void)f.funcOp.eraseResults(f.nonLiveRets);
}
// 4. Operands
LDBG() << "Cleaning up " << list.operands.size() << " operand lists";
for (OperationToCleanup &o : list.operands) {
- if (o.op->getNumOperands() > 0) {
- LDBG() << "Erasing " << o.nonLive.count()
- << " non-live operands from operation: "
- << OpWithFlags(o.op, OpPrintingFlags().skipRegions());
+ if (auto call = dyn_cast<CallOpInterface>(o.op)) {
+ if (SymbolRefAttr sym = call.getCallableForCallee().dyn_cast<SymbolRefAttr>()) {
+ Operation *callee = SymbolTable::lookupNearestSymbolFrom(o.op, sym);
+ auto it = erasedFuncArgs.find(callee);
+ if (it != erasedFuncArgs.end()) {
+ const BitVector &deadArgIdxs = it->second;
+ MutableOperandRange args = call.getArgOperandsMutable();
+ // First, erase the call arguments corresponding to erased callee args.
+ for (int i = static_cast<int>(args.size()) - 1; i >= 0; --i) {
+ if (i < static_cast<int>(deadArgIdxs.size()) && deadArgIdxs.test(i))
+ args.erase(i);
+ }
+ // If this operand cleanup entry also has a generic nonLive bitvector,
+ // clear bits for call arguments we already erased above to avoid
+ // double-erasing (which could impact other segments of ops with
+ // AttrSizedOperandSegments).
+ if (o.nonLive.any()) {
+ // Map the argument logical index to the operand number(s) recorded.
+ SmallVector<OpOperand *> callOperands =
+ operandsToOpOperands(call.getArgOperands());
+ for (int argIdx : deadArgIdxs.set_bits()) {
+ if (argIdx < static_cast<int>(callOperands.size())) {
+ unsigned operandNumber = callOperands[argIdx]->getOperandNumber();
+ if (operandNumber < o.nonLive.size())
+ o.nonLive.reset(operandNumber);
+ }
+ }
+ }
+ }
+ }
+ }
+ // Only perform generic operand erasure for non-call ops; for call ops we
+ // already handled argument removals via the segment-aware path above.
+ if (!isa<CallOpInterface>(o.op) && o.nonLive.any()) {
o.op->eraseOperands(o.nonLive);
}
}
diff --git a/mlir/test/Transforms/remove-dead-values-call-segments.mlir b/mlir/test/Transforms/remove-dead-values-call-segments.mlir
new file mode 100644
index 0000000000000..fed9cabbd2ee8
--- /dev/null
+++ b/mlir/test/Transforms/remove-dead-values-call-segments.mlir
@@ -0,0 +1,23 @@
+// RUN: mlir-opt --split-input-file --remove-dead-values --mlir-print-op-generic %s | FileCheck %s --check-prefix=GEN
+
+// -----
+// Private callee: both args become dead after internal DCE; RDV drops callee
+// args and shrinks the *args* segment on the call-site to zero; sizes kept in
+// sync.
+
+module {
+ func.func private @callee(%x: i32, %y: i32) {
+ %u = arith.addi %x, %x : i32 // %y is dead
+ return
+ }
+
+ func.func @caller(%a: i32, %b: i32) {
+ // args segment initially has 2 operands.
+ "test.call_with_segments"(%a, %b) { callee = @callee,
+ operandSegmentSizes = array<i32: 0, 2, 0> } : (i32, i32) -> ()
+ return
+ }
+}
+
+// GEN: "test.call_with_segments"() <{callee = @callee, operandSegmentSizes = array<i32: 0, 0, 0>}> : () -> ()
+// ^ args shrank from 2 -> 0
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 987e8f3654ce8..5016ab6b94cdb 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -431,3 +431,47 @@ void TestDialect::getCanonicalizationPatterns(
RewritePatternSet &results) const {
results.add(&dialectCanonicalizationPattern);
}
+
+//===----------------------------------------------------------------------===//
+// TestCallWithSegmentsOp
+//===----------------------------------------------------------------------===//
+// The op `test.call_with_segments` models a call-like operation whose operands
+// are divided into 3 variadic segments: `prefix`, `args`, and `suffix`.
+// Only the middle segment represents the actual call arguments. The op uses
+// the AttrSizedOperandSegments trait, so we can derive segment boundaries from
+// the generated `operandSegmentSizes` attribute. We provide custom helpers to
+// expose the logical call arguments as both a read-only range and a mutable
+// range bound to the proper segment so that insertion/erasure updates the
+// attribute automatically.
+
+// Segment layout indices in the DenseI32ArrayAttr: [prefix, args, suffix].
+static constexpr unsigned kTestCallWithSegmentsArgsSegIndex = 1;
+
+Operation::operand_range CallWithSegmentsOp::getArgOperands() {
+ // Leverage generated getters for segment sizes: slice between prefix and
+ // suffix using current operand list.
+ return getOperation()->getOperands().slice(getPrefix().size(),
+ getArgs().size());
+}
+
+MutableOperandRange CallWithSegmentsOp::getArgOperandsMutable() {
+ Operation *op = getOperation();
+
+ // Obtain the canonical segment size attribute name for this op.
+ auto segName =
+ CallWithSegmentsOp::getOperandSegmentSizesAttrName(op->getName());
+ auto sizesAttr = op->getAttrOfType<DenseI32ArrayAttr>(segName);
+ assert(sizesAttr && "missing operandSegmentSizes attribute on op");
+
+ // Compute the start and length of the args segment from the prefix size and
+ // args size stored in the attribute.
+ auto sizes = sizesAttr.asArrayRef();
+ unsigned start = static_cast<unsigned>(sizes[0]); // prefix size
+ unsigned len = static_cast<unsigned>(sizes[1]); // args size
+
+ NamedAttribute segNamed(segName, sizesAttr);
+ MutableOperandRange::OperandSegment binding{kTestCallWithSegmentsArgsSegIndex,
+ segNamed};
+
+ return MutableOperandRange(op, start, len, {binding});
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index d9bbb3261febc..6ea27187655ee 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -3746,4 +3746,47 @@ def TestOpWithSuccessorRef : TEST_Op<"dummy_op_with_successor_ref"> {
}];
}
+def CallWithSegmentsOp : TEST_Op<"call_with_segments",
+ [AttrSizedOperandSegments,
+ DeclareOpInterfaceMethods<CallOpInterface>]> {
+ let summary = "test call op with segmented args";
+ let arguments = (ins
+ FlatSymbolRefAttr:$callee,
+ Variadic<AnyType>:$prefix, // non-arg segment (e.g., 'in')
+ Variadic<AnyType>:$args, // <-- the call *arguments* segment
+ Variadic<AnyType>:$suffix // non-arg segment (e.g., 'out')
+ );
+ let results = (outs);
+ let assemblyFormat = [{
+ $callee `(` $prefix `:` type($prefix) `)`
+ `(` $args `:` type($args) `)`
+ `(` $suffix `:` type($suffix) `)` attr-dict
+ }];
+
+ // Provide stub implementations for the ArgAndResultAttrsOpInterface.
+ let extraClassDeclaration = [{
+ ::mlir::ArrayAttr getArgAttrsAttr() { return {}; }
+ ::mlir::ArrayAttr getResAttrsAttr() { return {}; }
+ void setArgAttrsAttr(::mlir::ArrayAttr) {}
+ void setResAttrsAttr(::mlir::ArrayAttr) {}
+ ::mlir::Attribute removeArgAttrsAttr() { return {}; }
+ ::mlir::Attribute removeResAttrsAttr() { return {}; }
+ }];
+
+ let extraClassDefinition = [{
+ ::mlir::CallInterfaceCallable $cppClass::getCallableForCallee() {
+ if (auto sym = (*this)->getAttrOfType<::mlir::SymbolRefAttr>("callee"))
+ return ::mlir::CallInterfaceCallable(sym);
+ return ::mlir::CallInterfaceCallable();
+ }
+ void $cppClass::setCalleeFromCallable(::mlir::CallInterfaceCallable callee) {
+ if (auto sym = callee.dyn_cast<::mlir::SymbolRefAttr>())
+ (*this)->setAttr("callee", sym);
+ else
+ (*this)->removeAttr("callee");
+ }
+ }];
+}
+
+
#endif // TEST_OPS
|
…tion Co-authored-by: Mehdi Amini <joker.eph@gmail.com>
✅ With the latest revision this PR passed the C/C++ code formatter. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you split ADT changes to a separate PR and add unit tests?
This feature is needed for #160415 , kuhar suggested that I split that PR into 2 so that the ADT work is checking in first. --------- Co-authored-by: Jakub Kuderski <kubakuderski@gmail.com>
@joker-eph @kuhar ADT functionality was merged, can I get a re-review here? d8a2965 |
@joker-eph can you help me merge? |
@FranciscoThiesen Congratulations on having your first Pull Request (PR) merged into the LLVM Project! Your changes will be combined with recent changes from other authors, then tested by our build bots. If there is a problem with a build, you may receive a report in an email or a comment on this PR. Please check whether problems have been caused by your change specifically, as the builds can include changes from many authors. It is not uncommon for your change to be included in a build that fails due to someone else's changes, or infrastructure issues. How to do this, and the rest of the post-merge process, is covered in detail here. If your change does cause a problem, it may be reverted, or you can revert it yourself. This is a normal part of LLVM development. You can fix your changes and open a new PR to merge them again. If you don't get any reports, no action is required from you. Your changes are working as expected, well done! |
…160726) This feature is needed for llvm#160415 , kuhar suggested that I split that PR into 2 so that the ADT work is checking in first. --------- Co-authored-by: Jakub Kuderski <kubakuderski@gmail.com>
## Problem `RemoveDeadValues` can legally drop dead function arguments on private `func.func` callees. But call-sites to such functions aren't fixed if the call operation keeps its call arguments in a **segmented operand group** (i.ie, uses `AttrSizedOperandSegments`), unless the call op implements `getArgOperandsMutable` and the RDV pass actually uses it. ## Fix When RDV decides to drop callee function args, it should, for each call-site that implements `CallOpInterface`, **shrink the call's argument segment** via `getArgOperandsMutable()` using the same dead-arg indices. This keeps both the flat operand list and the `operand_segment_sizes` attribute in sync (that's what `MutableOperandRange` does when bound to the segment). ## Note This change is a no-op for: * call ops without segment operands (they still get their flat operands erased via the generic path) * call ops whose calle args weren't dropped (public, external, non-`func-func`, unresolved symbol, etc) * `llvm.call`/`llvm.invoke` (RDV doesn't drop `llvm.func` args --------- Co-authored-by: Mehdi Amini <joker.eph@gmail.com>
Problem
RemoveDeadValues
can legally drop dead function arguments on privatefunc.func
callees. But call-sites to such functions aren't fixed if the call operation keeps its call arguments in a segmented operand group (i.ie, usesAttrSizedOperandSegments
), unless the call op implementsgetArgOperandsMutable
and the RDV pass actually uses it.Fix
When RDV decides to drop callee function args, it should, for each call-site that implements
CallOpInterface
, shrink the call's argument segment viagetArgOperandsMutable()
using the same dead-arg indices. This keeps both the flat operand list and theoperand_segment_sizes
attribute in sync (that's whatMutableOperandRange
does when bound to the segment).Note
This change is a no-op for:
func-func
, unresolved symbol, etc)llvm.call
/llvm.invoke
(RDV doesn't dropllvm.func
args