diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp index 0e84b6dd17f29..fe74f3d4632a6 100644 --- a/mlir/lib/Transforms/RemoveDeadValues.cpp +++ b/mlir/lib/Transforms/RemoveDeadValues.cpp @@ -88,6 +88,8 @@ struct FunctionToCleanUp { struct OperationToCleanup { Operation *op; BitVector nonLive; + Operation *callee = + nullptr; // Optional: For CallOpInterface ops, stores the callee function }; struct BlockArgsToCleanup { @@ -306,19 +308,19 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module, nonLiveSet.insert(arg); } - // Do (2). + // Do (2). (Skip creating generic operand cleanup entries for call ops. + // Call arguments will be removed in the call-site specific segment-aware + // cleanup, avoiding generic eraseOperands bitvector mechanics.) SymbolTable::UseRange uses = *funcOp.getSymbolUses(module); for (SymbolTable::SymbolUse use : uses) { Operation *callOp = use.getUser(); assert(isa(callOp) && "expected a call-like user"); - // The number of operands in the call op may not match the number of - // arguments in the func op. - BitVector nonLiveCallOperands(callOp->getNumOperands(), false); - SmallVector callOpOperands = - operandsToOpOperands(cast(callOp).getArgOperands()); - for (int index : nonLiveArgs.set_bits()) - nonLiveCallOperands.set(callOpOperands[index]->getOperandNumber()); - cl.operands.push_back({callOp, nonLiveCallOperands}); + // Push an empty operand cleanup entry so that call-site specific logic in + // cleanUpDeadVals runs (it keys off CallOpInterface). The BitVector is + // intentionally all false to avoid generic erasure. + // Store the funcOp as the callee to avoid expensive symbol lookup later. + cl.operands.push_back({callOp, BitVector(callOp->getNumOperands(), false), + funcOp.getOperation()}); } // Do (3). @@ -746,6 +748,10 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) { // 3. Functions LDBG() << "Cleaning up " << list.functions.size() << " functions"; + // Record which function arguments were erased so we can shrink call-site + // argument segments for CallOpInterface operations (e.g. ops using + // AttrSizedOperandSegments) in the next phase. + DenseMap erasedFuncArgs; for (auto &f : list.functions) { LDBG() << "Cleaning up function: " << f.funcOp.getOperation()->getName(); LDBG() << " Erasing " << f.nonLiveArgs.count() << " non-live arguments"; @@ -754,17 +760,52 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) { // Some functions may not allow erasing arguments or results. These calls // return failure in such cases without modifying the function, so it's okay // to proceed. - (void)f.funcOp.eraseArguments(f.nonLiveArgs); + if (succeeded(f.funcOp.eraseArguments(f.nonLiveArgs))) { + // Record only if we actually erased something. + if (f.nonLiveArgs.any()) + erasedFuncArgs.try_emplace(f.funcOp.getOperation(), f.nonLiveArgs); + } (void)f.funcOp.eraseResults(f.nonLiveRets); } // 4. Operands LDBG() << "Cleaning up " << list.operands.size() << " operand lists"; for (OperationToCleanup &o : list.operands) { - if (o.op->getNumOperands() > 0) { - LDBG() << "Erasing " << o.nonLive.count() - << " non-live operands from operation: " - << OpWithFlags(o.op, OpPrintingFlags().skipRegions()); + // Handle call-specific cleanup only when we have a cached callee reference. + // This avoids expensive symbol lookup and is defensive against future + // changes. + bool handledAsCall = false; + if (o.callee && isa(o.op)) { + auto call = cast(o.op); + auto it = erasedFuncArgs.find(o.callee); + if (it != erasedFuncArgs.end()) { + const BitVector &deadArgIdxs = it->second; + MutableOperandRange args = call.getArgOperandsMutable(); + // First, erase the call arguments corresponding to erased callee + // args. We iterate backwards to preserve indices. + for (unsigned argIdx : llvm::reverse(deadArgIdxs.set_bits())) + args.erase(argIdx); + // If this operand cleanup entry also has a generic nonLive bitvector, + // clear bits for call arguments we already erased above to avoid + // double-erasing (which could impact other segments of ops with + // AttrSizedOperandSegments). + if (o.nonLive.any()) { + // Map the argument logical index to the operand number(s) recorded. + int operandOffset = call.getArgOperands().getBeginOperandIndex(); + for (int argIdx : deadArgIdxs.set_bits()) { + int operandNumber = operandOffset + argIdx; + if (operandNumber < static_cast(o.nonLive.size())) + o.nonLive.reset(operandNumber); + } + } + handledAsCall = true; + } + } + // Perform generic operand erasure for: + // - Non-call operations + // - Call operations without cached callee (where handledAsCall is false) + // But skip call operations that were already handled via segment-aware path + if (!handledAsCall && o.nonLive.any()) { o.op->eraseOperands(o.nonLive); } } diff --git a/mlir/test/Transforms/remove-dead-values-call-segments.mlir b/mlir/test/Transforms/remove-dead-values-call-segments.mlir new file mode 100644 index 0000000000000..fed9cabbd2ee8 --- /dev/null +++ b/mlir/test/Transforms/remove-dead-values-call-segments.mlir @@ -0,0 +1,23 @@ +// RUN: mlir-opt --split-input-file --remove-dead-values --mlir-print-op-generic %s | FileCheck %s --check-prefix=GEN + +// ----- +// Private callee: both args become dead after internal DCE; RDV drops callee +// args and shrinks the *args* segment on the call-site to zero; sizes kept in +// sync. + +module { + func.func private @callee(%x: i32, %y: i32) { + %u = arith.addi %x, %x : i32 // %y is dead + return + } + + func.func @caller(%a: i32, %b: i32) { + // args segment initially has 2 operands. + "test.call_with_segments"(%a, %b) { callee = @callee, + operandSegmentSizes = array } : (i32, i32) -> () + return + } +} + +// GEN: "test.call_with_segments"() <{callee = @callee, operandSegmentSizes = array}> : () -> () +// ^ args shrank from 2 -> 0 diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp index 987e8f3654ce8..21d75f58b0a3a 100644 --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -431,3 +431,47 @@ void TestDialect::getCanonicalizationPatterns( RewritePatternSet &results) const { results.add(&dialectCanonicalizationPattern); } + +//===----------------------------------------------------------------------===// +// TestCallWithSegmentsOp +//===----------------------------------------------------------------------===// +// The op `test.call_with_segments` models a call-like operation whose operands +// are divided into 3 variadic segments: `prefix`, `args`, and `suffix`. +// Only the middle segment represents the actual call arguments. The op uses +// the AttrSizedOperandSegments trait, so we can derive segment boundaries from +// the generated `operandSegmentSizes` attribute. We provide custom helpers to +// expose the logical call arguments as both a read-only range and a mutable +// range bound to the proper segment so that insertion/erasure updates the +// attribute automatically. + +// Segment layout indices in the DenseI32ArrayAttr: [prefix, args, suffix]. +static constexpr unsigned kTestCallWithSegmentsArgsSegIndex = 1; + +Operation::operand_range CallWithSegmentsOp::getArgOperands() { + // Leverage generated getters for segment sizes: slice between prefix and + // suffix using current operand list. + return getOperation()->getOperands().slice(getPrefix().size(), + getArgs().size()); +} + +MutableOperandRange CallWithSegmentsOp::getArgOperandsMutable() { + Operation *op = getOperation(); + + // Obtain the canonical segment size attribute name for this op. + auto segName = + CallWithSegmentsOp::getOperandSegmentSizesAttrName(op->getName()); + auto sizesAttr = op->getAttrOfType(segName); + assert(sizesAttr && "missing operandSegmentSizes attribute on op"); + + // Compute the start and length of the args segment from the prefix size and + // args size stored in the attribute. + auto sizes = sizesAttr.asArrayRef(); + unsigned start = static_cast(sizes[0]); // prefix size + unsigned len = static_cast(sizes[1]); // args size + + NamedAttribute segNamed(segName, sizesAttr); + MutableOperandRange::OperandSegment binding{kTestCallWithSegmentsArgsSegIndex, + segNamed}; + + return MutableOperandRange(op, start, len, {binding}); +} diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index d9bbb3261febc..6ea27187655ee 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -3746,4 +3746,47 @@ def TestOpWithSuccessorRef : TEST_Op<"dummy_op_with_successor_ref"> { }]; } +def CallWithSegmentsOp : TEST_Op<"call_with_segments", + [AttrSizedOperandSegments, + DeclareOpInterfaceMethods]> { + let summary = "test call op with segmented args"; + let arguments = (ins + FlatSymbolRefAttr:$callee, + Variadic:$prefix, // non-arg segment (e.g., 'in') + Variadic:$args, // <-- the call *arguments* segment + Variadic:$suffix // non-arg segment (e.g., 'out') + ); + let results = (outs); + let assemblyFormat = [{ + $callee `(` $prefix `:` type($prefix) `)` + `(` $args `:` type($args) `)` + `(` $suffix `:` type($suffix) `)` attr-dict + }]; + + // Provide stub implementations for the ArgAndResultAttrsOpInterface. + let extraClassDeclaration = [{ + ::mlir::ArrayAttr getArgAttrsAttr() { return {}; } + ::mlir::ArrayAttr getResAttrsAttr() { return {}; } + void setArgAttrsAttr(::mlir::ArrayAttr) {} + void setResAttrsAttr(::mlir::ArrayAttr) {} + ::mlir::Attribute removeArgAttrsAttr() { return {}; } + ::mlir::Attribute removeResAttrsAttr() { return {}; } + }]; + + let extraClassDefinition = [{ + ::mlir::CallInterfaceCallable $cppClass::getCallableForCallee() { + if (auto sym = (*this)->getAttrOfType<::mlir::SymbolRefAttr>("callee")) + return ::mlir::CallInterfaceCallable(sym); + return ::mlir::CallInterfaceCallable(); + } + void $cppClass::setCalleeFromCallable(::mlir::CallInterfaceCallable callee) { + if (auto sym = callee.dyn_cast<::mlir::SymbolRefAttr>()) + (*this)->setAttr("callee", sym); + else + (*this)->removeAttr("callee"); + } + }]; +} + + #endif // TEST_OPS