From 19a9c64d49906845faed6e81effed356a0020ec1 Mon Sep 17 00:00:00 2001 From: Francisco Geiman Thiesen Date: Fri, 12 Sep 2025 22:19:46 +0000 Subject: [PATCH 1/6] Adding changes to RDV +small repro case for dialect with callOp and the AttrSizedOperandSegments trait --- mlir/lib/Transforms/RemoveDeadValues.cpp | 64 +++++++++++++++---- .../remove-dead-values-call-segments.mlir | 23 +++++++ mlir/test/lib/Dialect/Test/TestDialect.cpp | 44 +++++++++++++ mlir/test/lib/Dialect/Test/TestOps.td | 43 +++++++++++++ 4 files changed, 160 insertions(+), 14 deletions(-) create mode 100644 mlir/test/Transforms/remove-dead-values-call-segments.mlir diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp index 0e84b6dd17f29..0655adaad5f5f 100644 --- a/mlir/lib/Transforms/RemoveDeadValues.cpp +++ b/mlir/lib/Transforms/RemoveDeadValues.cpp @@ -306,19 +306,17 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module, nonLiveSet.insert(arg); } - // Do (2). + // Do (2). (Skip creating generic operand cleanup entries for call ops. + // Call arguments will be removed in the call-site specific segment-aware + // cleanup, avoiding generic eraseOperands bitvector mechanics.) SymbolTable::UseRange uses = *funcOp.getSymbolUses(module); for (SymbolTable::SymbolUse use : uses) { Operation *callOp = use.getUser(); assert(isa(callOp) && "expected a call-like user"); - // The number of operands in the call op may not match the number of - // arguments in the func op. - BitVector nonLiveCallOperands(callOp->getNumOperands(), false); - SmallVector callOpOperands = - operandsToOpOperands(cast(callOp).getArgOperands()); - for (int index : nonLiveArgs.set_bits()) - nonLiveCallOperands.set(callOpOperands[index]->getOperandNumber()); - cl.operands.push_back({callOp, nonLiveCallOperands}); + // Push an empty operand cleanup entry so that call-site specific logic in + // cleanUpDeadVals runs (it keys off CallOpInterface). The BitVector is + // intentionally all false to avoid generic erasure. + cl.operands.push_back({callOp, BitVector(callOp->getNumOperands(), false)}); } // Do (3). @@ -746,6 +744,10 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) { // 3. Functions LDBG() << "Cleaning up " << list.functions.size() << " functions"; + // Record which function arguments were erased so we can shrink call-site + // argument segments for CallOpInterface operations (e.g. ops using + // AttrSizedOperandSegments) in the next phase. + DenseMap erasedFuncArgs; for (auto &f : list.functions) { LDBG() << "Cleaning up function: " << f.funcOp.getOperation()->getName(); LDBG() << " Erasing " << f.nonLiveArgs.count() << " non-live arguments"; @@ -754,17 +756,51 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) { // Some functions may not allow erasing arguments or results. These calls // return failure in such cases without modifying the function, so it's okay // to proceed. - (void)f.funcOp.eraseArguments(f.nonLiveArgs); + if (succeeded(f.funcOp.eraseArguments(f.nonLiveArgs))) { + // Record only if we actually erased something. + if (f.nonLiveArgs.any()) + erasedFuncArgs.try_emplace(f.funcOp.getOperation(), f.nonLiveArgs); + } (void)f.funcOp.eraseResults(f.nonLiveRets); } // 4. Operands LDBG() << "Cleaning up " << list.operands.size() << " operand lists"; for (OperationToCleanup &o : list.operands) { - if (o.op->getNumOperands() > 0) { - LDBG() << "Erasing " << o.nonLive.count() - << " non-live operands from operation: " - << OpWithFlags(o.op, OpPrintingFlags().skipRegions()); + if (auto call = dyn_cast(o.op)) { + if (SymbolRefAttr sym = call.getCallableForCallee().dyn_cast()) { + Operation *callee = SymbolTable::lookupNearestSymbolFrom(o.op, sym); + auto it = erasedFuncArgs.find(callee); + if (it != erasedFuncArgs.end()) { + const BitVector &deadArgIdxs = it->second; + MutableOperandRange args = call.getArgOperandsMutable(); + // First, erase the call arguments corresponding to erased callee args. + for (int i = static_cast(args.size()) - 1; i >= 0; --i) { + if (i < static_cast(deadArgIdxs.size()) && deadArgIdxs.test(i)) + args.erase(i); + } + // If this operand cleanup entry also has a generic nonLive bitvector, + // clear bits for call arguments we already erased above to avoid + // double-erasing (which could impact other segments of ops with + // AttrSizedOperandSegments). + if (o.nonLive.any()) { + // Map the argument logical index to the operand number(s) recorded. + SmallVector callOperands = + operandsToOpOperands(call.getArgOperands()); + for (int argIdx : deadArgIdxs.set_bits()) { + if (argIdx < static_cast(callOperands.size())) { + unsigned operandNumber = callOperands[argIdx]->getOperandNumber(); + if (operandNumber < o.nonLive.size()) + o.nonLive.reset(operandNumber); + } + } + } + } + } + } + // Only perform generic operand erasure for non-call ops; for call ops we + // already handled argument removals via the segment-aware path above. + if (!isa(o.op) && o.nonLive.any()) { o.op->eraseOperands(o.nonLive); } } diff --git a/mlir/test/Transforms/remove-dead-values-call-segments.mlir b/mlir/test/Transforms/remove-dead-values-call-segments.mlir new file mode 100644 index 0000000000000..fed9cabbd2ee8 --- /dev/null +++ b/mlir/test/Transforms/remove-dead-values-call-segments.mlir @@ -0,0 +1,23 @@ +// RUN: mlir-opt --split-input-file --remove-dead-values --mlir-print-op-generic %s | FileCheck %s --check-prefix=GEN + +// ----- +// Private callee: both args become dead after internal DCE; RDV drops callee +// args and shrinks the *args* segment on the call-site to zero; sizes kept in +// sync. + +module { + func.func private @callee(%x: i32, %y: i32) { + %u = arith.addi %x, %x : i32 // %y is dead + return + } + + func.func @caller(%a: i32, %b: i32) { + // args segment initially has 2 operands. + "test.call_with_segments"(%a, %b) { callee = @callee, + operandSegmentSizes = array } : (i32, i32) -> () + return + } +} + +// GEN: "test.call_with_segments"() <{callee = @callee, operandSegmentSizes = array}> : () -> () +// ^ args shrank from 2 -> 0 diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp index 987e8f3654ce8..5016ab6b94cdb 100644 --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -431,3 +431,47 @@ void TestDialect::getCanonicalizationPatterns( RewritePatternSet &results) const { results.add(&dialectCanonicalizationPattern); } + +//===----------------------------------------------------------------------===// +// TestCallWithSegmentsOp +//===----------------------------------------------------------------------===// +// The op `test.call_with_segments` models a call-like operation whose operands +// are divided into 3 variadic segments: `prefix`, `args`, and `suffix`. +// Only the middle segment represents the actual call arguments. The op uses +// the AttrSizedOperandSegments trait, so we can derive segment boundaries from +// the generated `operandSegmentSizes` attribute. We provide custom helpers to +// expose the logical call arguments as both a read-only range and a mutable +// range bound to the proper segment so that insertion/erasure updates the +// attribute automatically. + +// Segment layout indices in the DenseI32ArrayAttr: [prefix, args, suffix]. +static constexpr unsigned kTestCallWithSegmentsArgsSegIndex = 1; + +Operation::operand_range CallWithSegmentsOp::getArgOperands() { + // Leverage generated getters for segment sizes: slice between prefix and + // suffix using current operand list. + return getOperation()->getOperands().slice(getPrefix().size(), + getArgs().size()); +} + +MutableOperandRange CallWithSegmentsOp::getArgOperandsMutable() { + Operation *op = getOperation(); + + // Obtain the canonical segment size attribute name for this op. + auto segName = + CallWithSegmentsOp::getOperandSegmentSizesAttrName(op->getName()); + auto sizesAttr = op->getAttrOfType(segName); + assert(sizesAttr && "missing operandSegmentSizes attribute on op"); + + // Compute the start and length of the args segment from the prefix size and + // args size stored in the attribute. + auto sizes = sizesAttr.asArrayRef(); + unsigned start = static_cast(sizes[0]); // prefix size + unsigned len = static_cast(sizes[1]); // args size + + NamedAttribute segNamed(segName, sizesAttr); + MutableOperandRange::OperandSegment binding{kTestCallWithSegmentsArgsSegIndex, + segNamed}; + + return MutableOperandRange(op, start, len, {binding}); +} diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 5564264ed8b0b..a459385129909 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -3745,4 +3745,47 @@ def TestOpWithSuccessorRef : TEST_Op<"dummy_op_with_successor_ref"> { }]; } +def CallWithSegmentsOp : TEST_Op<"call_with_segments", + [AttrSizedOperandSegments, + DeclareOpInterfaceMethods]> { + let summary = "test call op with segmented args"; + let arguments = (ins + FlatSymbolRefAttr:$callee, + Variadic:$prefix, // non-arg segment (e.g., 'in') + Variadic:$args, // <-- the call *arguments* segment + Variadic:$suffix // non-arg segment (e.g., 'out') + ); + let results = (outs); + let assemblyFormat = [{ + $callee `(` $prefix `:` type($prefix) `)` + `(` $args `:` type($args) `)` + `(` $suffix `:` type($suffix) `)` attr-dict + }]; + + // Provide stub implementations for the ArgAndResultAttrsOpInterface. + let extraClassDeclaration = [{ + ::mlir::ArrayAttr getArgAttrsAttr() { return {}; } + ::mlir::ArrayAttr getResAttrsAttr() { return {}; } + void setArgAttrsAttr(::mlir::ArrayAttr) {} + void setResAttrsAttr(::mlir::ArrayAttr) {} + ::mlir::Attribute removeArgAttrsAttr() { return {}; } + ::mlir::Attribute removeResAttrsAttr() { return {}; } + }]; + + let extraClassDefinition = [{ + ::mlir::CallInterfaceCallable $cppClass::getCallableForCallee() { + if (auto sym = (*this)->getAttrOfType<::mlir::SymbolRefAttr>("callee")) + return ::mlir::CallInterfaceCallable(sym); + return ::mlir::CallInterfaceCallable(); + } + void $cppClass::setCalleeFromCallable(::mlir::CallInterfaceCallable callee) { + if (auto sym = callee.dyn_cast<::mlir::SymbolRefAttr>()) + (*this)->setAttr("callee", sym); + else + (*this)->removeAttr("callee"); + } + }]; +} + + #endif // TEST_OPS From 78ae4f1905e567dd66aebb488864e29194e69e84 Mon Sep 17 00:00:00 2001 From: Francisco Geiman Thiesen Date: Wed, 24 Sep 2025 14:56:42 -0600 Subject: [PATCH 2/6] Update mlir/lib/Transforms/RemoveDeadValues.cpp with joker-eph suggestion Co-authored-by: Mehdi Amini --- mlir/lib/Transforms/RemoveDeadValues.cpp | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp index 0655adaad5f5f..03c02859366b2 100644 --- a/mlir/lib/Transforms/RemoveDeadValues.cpp +++ b/mlir/lib/Transforms/RemoveDeadValues.cpp @@ -785,14 +785,11 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) { // AttrSizedOperandSegments). if (o.nonLive.any()) { // Map the argument logical index to the operand number(s) recorded. - SmallVector callOperands = - operandsToOpOperands(call.getArgOperands()); + int operandOffset = call.getArgOperands().getBeginOperandIndex(); for (int argIdx : deadArgIdxs.set_bits()) { - if (argIdx < static_cast(callOperands.size())) { - unsigned operandNumber = callOperands[argIdx]->getOperandNumber(); - if (operandNumber < o.nonLive.size()) - o.nonLive.reset(operandNumber); - } + int operandNumber = operandOffset + argIdx; + if (operandNumber < o.nonLive.size()) + o.nonLive.reset(operandNumber); } } } From e3d5d543d9c0580acee8081897054018687ac5ae Mon Sep 17 00:00:00 2001 From: Francisco Geiman Thiesen Date: Wed, 24 Sep 2025 15:49:27 -0600 Subject: [PATCH 3/6] Avoiding the expensive symbol look-up --- mlir/lib/Transforms/RemoveDeadValues.cpp | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp index 03c02859366b2..18d75a93195a2 100644 --- a/mlir/lib/Transforms/RemoveDeadValues.cpp +++ b/mlir/lib/Transforms/RemoveDeadValues.cpp @@ -88,6 +88,7 @@ struct FunctionToCleanUp { struct OperationToCleanup { Operation *op; BitVector nonLive; + Operation *callee = nullptr; // Optional: For CallOpInterface ops, stores the callee function }; struct BlockArgsToCleanup { @@ -316,7 +317,8 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module, // Push an empty operand cleanup entry so that call-site specific logic in // cleanUpDeadVals runs (it keys off CallOpInterface). The BitVector is // intentionally all false to avoid generic erasure. - cl.operands.push_back({callOp, BitVector(callOp->getNumOperands(), false)}); + // Store the funcOp as the callee to avoid expensive symbol lookup later. + cl.operands.push_back({callOp, BitVector(callOp->getNumOperands(), false), funcOp.getOperation()}); } // Do (3). @@ -768,9 +770,9 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) { LDBG() << "Cleaning up " << list.operands.size() << " operand lists"; for (OperationToCleanup &o : list.operands) { if (auto call = dyn_cast(o.op)) { - if (SymbolRefAttr sym = call.getCallableForCallee().dyn_cast()) { - Operation *callee = SymbolTable::lookupNearestSymbolFrom(o.op, sym); - auto it = erasedFuncArgs.find(callee); + // Use the stored callee reference if available, avoiding expensive symbol lookup + if (o.callee) { + auto it = erasedFuncArgs.find(o.callee); if (it != erasedFuncArgs.end()) { const BitVector &deadArgIdxs = it->second; MutableOperandRange args = call.getArgOperandsMutable(); @@ -788,7 +790,7 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) { int operandOffset = call.getArgOperands().getBeginOperandIndex(); for (int argIdx : deadArgIdxs.set_bits()) { int operandNumber = operandOffset + argIdx; - if (operandNumber < o.nonLive.size()) + if (operandNumber < static_cast(o.nonLive.size())) o.nonLive.reset(operandNumber); } } From 6ea2caca51d16d1e296de53751b39870fa30c611 Mon Sep 17 00:00:00 2001 From: Francisco Geiman Thiesen Date: Wed, 24 Sep 2025 16:26:13 -0600 Subject: [PATCH 4/6] Clang formatting --- mlir/lib/Transforms/RemoveDeadValues.cpp | 12 ++++++++---- mlir/test/lib/Dialect/Test/TestDialect.cpp | 2 +- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp index 18d75a93195a2..01b5522572769 100644 --- a/mlir/lib/Transforms/RemoveDeadValues.cpp +++ b/mlir/lib/Transforms/RemoveDeadValues.cpp @@ -88,7 +88,8 @@ struct FunctionToCleanUp { struct OperationToCleanup { Operation *op; BitVector nonLive; - Operation *callee = nullptr; // Optional: For CallOpInterface ops, stores the callee function + Operation *callee = + nullptr; // Optional: For CallOpInterface ops, stores the callee function }; struct BlockArgsToCleanup { @@ -318,7 +319,8 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module, // cleanUpDeadVals runs (it keys off CallOpInterface). The BitVector is // intentionally all false to avoid generic erasure. // Store the funcOp as the callee to avoid expensive symbol lookup later. - cl.operands.push_back({callOp, BitVector(callOp->getNumOperands(), false), funcOp.getOperation()}); + cl.operands.push_back({callOp, BitVector(callOp->getNumOperands(), false), + funcOp.getOperation()}); } // Do (3). @@ -770,13 +772,15 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) { LDBG() << "Cleaning up " << list.operands.size() << " operand lists"; for (OperationToCleanup &o : list.operands) { if (auto call = dyn_cast(o.op)) { - // Use the stored callee reference if available, avoiding expensive symbol lookup + // Use the stored callee reference if available, avoiding expensive symbol + // lookup if (o.callee) { auto it = erasedFuncArgs.find(o.callee); if (it != erasedFuncArgs.end()) { const BitVector &deadArgIdxs = it->second; MutableOperandRange args = call.getArgOperandsMutable(); - // First, erase the call arguments corresponding to erased callee args. + // First, erase the call arguments corresponding to erased callee + // args. for (int i = static_cast(args.size()) - 1; i >= 0; --i) { if (i < static_cast(deadArgIdxs.size()) && deadArgIdxs.test(i)) args.erase(i); diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp index 5016ab6b94cdb..21d75f58b0a3a 100644 --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -467,7 +467,7 @@ MutableOperandRange CallWithSegmentsOp::getArgOperandsMutable() { // args size stored in the attribute. auto sizes = sizesAttr.asArrayRef(); unsigned start = static_cast(sizes[0]); // prefix size - unsigned len = static_cast(sizes[1]); // args size + unsigned len = static_cast(sizes[1]); // args size NamedAttribute segNamed(segName, sizesAttr); MutableOperandRange::OperandSegment binding{kTestCallWithSegmentsArgsSegIndex, From a71642f30851f62c71d71ffd4ce777ea0fd3b08c Mon Sep 17 00:00:00 2001 From: Francisco Geiman Thiesen Date: Wed, 24 Sep 2025 19:55:50 -0700 Subject: [PATCH 5/6] Making assumption explicit --- mlir/lib/Transforms/RemoveDeadValues.cpp | 61 +++++++++++++----------- 1 file changed, 32 insertions(+), 29 deletions(-) diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp index 01b5522572769..3f4cb7e22fa6e 100644 --- a/mlir/lib/Transforms/RemoveDeadValues.cpp +++ b/mlir/lib/Transforms/RemoveDeadValues.cpp @@ -771,39 +771,42 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) { // 4. Operands LDBG() << "Cleaning up " << list.operands.size() << " operand lists"; for (OperationToCleanup &o : list.operands) { - if (auto call = dyn_cast(o.op)) { - // Use the stored callee reference if available, avoiding expensive symbol - // lookup - if (o.callee) { - auto it = erasedFuncArgs.find(o.callee); - if (it != erasedFuncArgs.end()) { - const BitVector &deadArgIdxs = it->second; - MutableOperandRange args = call.getArgOperandsMutable(); - // First, erase the call arguments corresponding to erased callee - // args. - for (int i = static_cast(args.size()) - 1; i >= 0; --i) { - if (i < static_cast(deadArgIdxs.size()) && deadArgIdxs.test(i)) - args.erase(i); - } - // If this operand cleanup entry also has a generic nonLive bitvector, - // clear bits for call arguments we already erased above to avoid - // double-erasing (which could impact other segments of ops with - // AttrSizedOperandSegments). - if (o.nonLive.any()) { - // Map the argument logical index to the operand number(s) recorded. - int operandOffset = call.getArgOperands().getBeginOperandIndex(); - for (int argIdx : deadArgIdxs.set_bits()) { - int operandNumber = operandOffset + argIdx; - if (operandNumber < static_cast(o.nonLive.size())) - o.nonLive.reset(operandNumber); - } + // Handle call-specific cleanup only when we have a cached callee reference. + // This avoids expensive symbol lookup and is defensive against future changes. + bool handledAsCall = false; + if (o.callee && isa(o.op)) { + auto call = cast(o.op); + auto it = erasedFuncArgs.find(o.callee); + if (it != erasedFuncArgs.end()) { + const BitVector &deadArgIdxs = it->second; + MutableOperandRange args = call.getArgOperandsMutable(); + // First, erase the call arguments corresponding to erased callee + // args. + for (int i = static_cast(args.size()) - 1; i >= 0; --i) { + if (i < static_cast(deadArgIdxs.size()) && deadArgIdxs.test(i)) + args.erase(i); + } + // If this operand cleanup entry also has a generic nonLive bitvector, + // clear bits for call arguments we already erased above to avoid + // double-erasing (which could impact other segments of ops with + // AttrSizedOperandSegments). + if (o.nonLive.any()) { + // Map the argument logical index to the operand number(s) recorded. + int operandOffset = call.getArgOperands().getBeginOperandIndex(); + for (int argIdx : deadArgIdxs.set_bits()) { + int operandNumber = operandOffset + argIdx; + if (operandNumber < static_cast(o.nonLive.size())) + o.nonLive.reset(operandNumber); } } + handledAsCall = true; } } - // Only perform generic operand erasure for non-call ops; for call ops we - // already handled argument removals via the segment-aware path above. - if (!isa(o.op) && o.nonLive.any()) { + // Perform generic operand erasure for: + // - Non-call operations + // - Call operations without cached callee (where handledAsCall is false) + // But skip call operations that were already handled via segment-aware path + if (!handledAsCall && o.nonLive.any()) { o.op->eraseOperands(o.nonLive); } } From d2a3dbede2e18e58183621df5e39584dc1bf7e2c Mon Sep 17 00:00:00 2001 From: Francisco Geiman Thiesen Date: Wed, 24 Sep 2025 21:26:41 -0700 Subject: [PATCH 6/6] Adding bidirectional iterator support to const_set_bits_iterator_impl --- llvm/include/llvm/ADT/BitVector.h | 32 +++++++++++++++++++++--- mlir/lib/Transforms/RemoveDeadValues.cpp | 11 ++++---- 2 files changed, 33 insertions(+), 10 deletions(-) diff --git a/llvm/include/llvm/ADT/BitVector.h b/llvm/include/llvm/ADT/BitVector.h index 72da2343fae13..a6e2a397c661e 100644 --- a/llvm/include/llvm/ADT/BitVector.h +++ b/llvm/include/llvm/ADT/BitVector.h @@ -40,12 +40,25 @@ template class const_set_bits_iterator_impl { Current = Parent.find_next(Current); } + void retreat() { + // For bidirectional iteration to work with reverse_iterator, + // we need to handle the case where Current might be at end (-1) + // or at a position where we need to find the previous set bit. + if (Current == -1) { + // If we're at the end, go to the last set bit + Current = Parent.find_last(); + } else { + // Otherwise find the previous set bit before Current + Current = Parent.find_prev(Current); + } + } + public: - using iterator_category = std::forward_iterator_tag; + using iterator_category = std::bidirectional_iterator_tag; using difference_type = std::ptrdiff_t; - using value_type = int; - using pointer = value_type*; - using reference = value_type&; + using value_type = unsigned; + using pointer = const value_type*; + using reference = value_type; const_set_bits_iterator_impl(const BitVectorT &Parent, int Current) : Parent(Parent), Current(Current) {} @@ -64,6 +77,17 @@ template class const_set_bits_iterator_impl { return *this; } + const_set_bits_iterator_impl operator--(int) { + auto Prev = *this; + retreat(); + return Prev; + } + + const_set_bits_iterator_impl &operator--() { + retreat(); + return *this; + } + unsigned operator*() const { return Current; } bool operator==(const const_set_bits_iterator_impl &Other) const { diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp index 3f4cb7e22fa6e..fe74f3d4632a6 100644 --- a/mlir/lib/Transforms/RemoveDeadValues.cpp +++ b/mlir/lib/Transforms/RemoveDeadValues.cpp @@ -772,7 +772,8 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) { LDBG() << "Cleaning up " << list.operands.size() << " operand lists"; for (OperationToCleanup &o : list.operands) { // Handle call-specific cleanup only when we have a cached callee reference. - // This avoids expensive symbol lookup and is defensive against future changes. + // This avoids expensive symbol lookup and is defensive against future + // changes. bool handledAsCall = false; if (o.callee && isa(o.op)) { auto call = cast(o.op); @@ -781,11 +782,9 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) { const BitVector &deadArgIdxs = it->second; MutableOperandRange args = call.getArgOperandsMutable(); // First, erase the call arguments corresponding to erased callee - // args. - for (int i = static_cast(args.size()) - 1; i >= 0; --i) { - if (i < static_cast(deadArgIdxs.size()) && deadArgIdxs.test(i)) - args.erase(i); - } + // args. We iterate backwards to preserve indices. + for (unsigned argIdx : llvm::reverse(deadArgIdxs.set_bits())) + args.erase(argIdx); // If this operand cleanup entry also has a generic nonLive bitvector, // clear bits for call arguments we already erased above to avoid // double-erasing (which could impact other segments of ops with