diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp index 2f95531455b2b..8464b633a2625 100644 --- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp +++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp @@ -6,6 +6,7 @@ // //===----------------------------------------------------------------------===// +#include #include #include "mlir/IR/BuiltinTypes.h" @@ -930,18 +931,12 @@ struct RemoveDeadRegionBranchOpSuccessorInputs : public RewritePattern { } }; -/// Return "true" if the two values are owned by the same operation or block. -static bool haveSameOwner(Value a, Value b) { - void *aOwner, *bOwner; - if (auto arg = dyn_cast(a)) - aOwner = arg.getOwner(); - else - aOwner = a.getDefiningOp(); - if (auto arg = dyn_cast(b)) - bOwner = arg.getOwner(); - else - bOwner = b.getDefiningOp(); - return aOwner == bOwner; +/// Return the "owner" of a value: the parent block for block arguments, the +/// defining op for op results. +static void *getOwnerOfValue(Value value) { + if (auto arg = dyn_cast(value)) + return arg.getOwner(); + return value.getDefiningOp(); } /// Get the block argument or op result number of the given value. @@ -1006,39 +1001,58 @@ struct RemoveDuplicateSuccessorInputUses : public RewritePattern { return getArgOrResultNumber(a) < getArgOrResultNumber(b); }); - // Check every distinct pair of successor inputs for duplicates. Replace - // `input2` with `input1` if they are duplicates. + // Group inputs by their operand "signature" to find duplicates. Two + // successor inputs are duplicates if each predecessor (region branch point) + // forwards the same value for both. Let n = number of successor inputs and + // k = number of predecessors per input. Instead of comparing every pair of + // inputs (O(n² * k)), we build a signature for each input and group them + // via a std::map. + // + // A signature is a sorted list of (predecessor, forwarded value) pairs. + // Within each group, all but the first (canonical) input are replaced with + // the canonical one. + using SigEntry = std::pair; + using Signature = SmallVector; + auto sigEntryLess = [](const SigEntry &a, const SigEntry &b) { + if (a.first != b.first) + return a.first < b.first; + return a.second.getAsOpaquePointer() < b.second.getAsOpaquePointer(); + }; + // The map key is (signature, owner). Two inputs are duplicates only if they + // have the same signature AND the same owner (block or defining op). This + // ensures we track one canonical per owner group. + using MapKey = std::pair; + auto mapKeyLess = [&](const MapKey &a, const MapKey &b) { + if (a.second != b.second) + return a.second < b.second; + return std::lexicographical_compare(a.first.begin(), a.first.end(), + b.first.begin(), b.first.end(), + sigEntryLess); + }; + std::map signatureToCanonical( + mapKeyLess); bool changed = false; - unsigned numInputs = inputs.size(); - for (auto i : llvm::seq(0, numInputs)) { - Value input1 = inputs[i]; - for (auto j : llvm::seq(i + 1, numInputs)) { - Value input2 = inputs[j]; - // Nothing to do if input2 is already dead. - if (input2.use_empty()) + // Total complexity: O(n * k * max(log k, log n)). For each input, sorting + // the signature costs O(k log k) and the std::map lookup costs O(k log n). + for (Value input : inputs) { + // Gather the predecessor value for each predecessor (region branch + // point) and sort them to form this input's signature. + Signature sig; + for (OpOperand *operand : inputsToOperands[input]) + sig.emplace_back(operand->getOwner(), operand->get()); + llvm::sort(sig, sigEntryLess); + + void *owner = getOwnerOfValue(input); + + auto [it, inserted] = signatureToCanonical.try_emplace( + MapKey{std::move(sig), owner}, input); + if (!inserted) { + Value canonical = it->second; + // Nothing to do if input is already dead. + if (input.use_empty()) continue; - // Replace only values that belong to the same block / operation. - // This implies that the two values are either both block arguments or - // both op results. - if (!haveSameOwner(input1, input2)) - continue; - - // Gather the predecessor value for each predecessor (region branch - // point). The two inputs are duplicates if each predecessor forwards - // the same value. - llvm::SmallDenseMap operands1, operands2; - for (OpOperand *operand : inputsToOperands[input1]) { - assert(!operands1.contains(operand->getOwner())); - operands1[operand->getOwner()] = operand->get(); - } - for (OpOperand *operand : inputsToOperands[input2]) { - assert(!operands2.contains(operand->getOwner())); - operands2[operand->getOwner()] = operand->get(); - } - if (operands1 == operands2) { - rewriter.replaceAllUsesWith(input2, input1); - changed = true; - } + rewriter.replaceAllUsesWith(input, canonical); + changed = true; } } return success(changed);