diff --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp index 8989ec7e1d95d..5ce32b14185ed 100644 --- a/mlir/lib/Transforms/Inliner.cpp +++ b/mlir/lib/Transforms/Inliner.cpp @@ -19,6 +19,7 @@ #include "mlir/Interfaces/CallInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Pass/PassManager.h" +#include "mlir/Support/DebugStringHelper.h" #include "mlir/Transforms/InliningUtils.h" #include "mlir/Transforms/Passes.h" #include "llvm/ADT/SCCIterator.h" @@ -364,6 +365,31 @@ static void collectCallOps(iterator_range blocks, //===----------------------------------------------------------------------===// // Inliner //===----------------------------------------------------------------------===// + +#ifndef NDEBUG +static std::string getNodeName(CallOpInterface op) { + if (auto sym = op.getCallableForCallee().dyn_cast()) + return debugString(op); + return "_unnamed_callee_"; +} +#endif + +/// Return true if the specified `inlineHistoryID` indicates an inline history +/// that already includes `node`. +static bool inlineHistoryIncludes( + CallGraphNode *node, Optional inlineHistoryID, + MutableArrayRef>> + inlineHistory) { + while (inlineHistoryID.has_value()) { + assert(inlineHistoryID.value() < inlineHistory.size() && + "Invalid inline history ID"); + if (inlineHistory[inlineHistoryID.value()].first == node) + return true; + inlineHistoryID = inlineHistory[inlineHistoryID.value()].second; + } + return false; +} + namespace { /// This class provides a specialization of the main inlining interface. struct Inliner : public InlinerInterface { @@ -454,23 +480,43 @@ static LogicalResult inlineCallsInSCC(Inliner &inliner, CGUseList &useList, } } + // When inlining a callee produces new call sites, we want to keep track of + // the fact that they were inlined from the callee. This allows us to avoid + // infinite inlining. + using InlineHistoryT = Optional; + SmallVector, 8> inlineHistory; + std::vector callHistory(calls.size(), InlineHistoryT{}); + + LLVM_DEBUG({ + llvm::dbgs() << "* Inliner: Initial calls in SCC are: {\n"; + for (unsigned i = 0, e = calls.size(); i < e; ++i) + llvm::dbgs() << " " << i << ". " << calls[i].call << ",\n"; + llvm::dbgs() << "}\n"; + }); + // Try to inline each of the call operations. Don't cache the end iterator // here as more calls may be added during inlining. bool inlinedAnyCalls = false; - for (unsigned i = 0; i != calls.size(); ++i) { + for (unsigned i = 0; i < calls.size(); ++i) { if (deadNodes.contains(calls[i].sourceNode)) continue; ResolvedCall it = calls[i]; - bool doInline = shouldInline(it); + + InlineHistoryT inlineHistoryID = callHistory[i]; + bool inHistory = + inlineHistoryIncludes(it.targetNode, inlineHistoryID, inlineHistory); + bool doInline = !inHistory && shouldInline(it); CallOpInterface call = it.call; LLVM_DEBUG({ if (doInline) - llvm::dbgs() << "* Inlining call: " << call << "\n"; + llvm::dbgs() << "* Inlining call: " << i << ". " << call << "\n"; else - llvm::dbgs() << "* Not inlining call: " << call << "\n"; + llvm::dbgs() << "* Not inlining call: " << i << ". " << call << "\n"; }); if (!doInline) continue; + + unsigned prevSize = calls.size(); Region *targetRegion = it.targetNode->getCallableRegion(); // If this is the last call to the target node and the node is discardable, @@ -486,6 +532,29 @@ static LogicalResult inlineCallsInSCC(Inliner &inliner, CGUseList &useList, } inlinedAnyCalls = true; + // Create a inline history entry for this inlined call, so that we remember + // that new callsites came about due to inlining Callee. + InlineHistoryT newInlineHistoryID{inlineHistory.size()}; + inlineHistory.push_back(std::make_pair(it.targetNode, inlineHistoryID)); + + auto historyToString = [](InlineHistoryT h) { + return h.has_value() ? std::to_string(h.value()) : "root"; + }; + (void)historyToString; + LLVM_DEBUG(llvm::dbgs() + << "* new inlineHistory entry: " << newInlineHistoryID << ". [" + << getNodeName(call) << ", " << historyToString(inlineHistoryID) + << "]\n"); + + for (unsigned k = prevSize; k != calls.size(); ++k) { + callHistory.push_back(newInlineHistoryID); + LLVM_DEBUG(llvm::dbgs() << "* new call " << k << " {" << calls[i].call + << "}\n with historyID = " << newInlineHistoryID + << ", added due to inlining of\n call {" << call + << "}\n with historyID = " + << historyToString(inlineHistoryID) << "\n"); + } + // If the inlining was successful, Merge the new uses into the source node. useList.dropCallUses(it.sourceNode, call.getOperation(), cg); useList.mergeUsesAfterInlining(it.targetNode, it.sourceNode); diff --git a/mlir/test/Transforms/inlining-recursive.mlir b/mlir/test/Transforms/inlining-recursive.mlir new file mode 100644 index 0000000000000..a02fe69133ad8 --- /dev/null +++ b/mlir/test/Transforms/inlining-recursive.mlir @@ -0,0 +1,25 @@ +// RUN: mlir-opt %s -inline='default-pipeline=''' | FileCheck %s +// RUN: mlir-opt %s --mlir-disable-threading -inline='default-pipeline=''' | FileCheck %s + +// CHECK-LABEL: func.func @foo0 +func.func @foo0(%arg0 : i32) -> i32 { + // CHECK: call @foo1 + // CHECK: } + %0 = arith.constant 0 : i32 + %1 = arith.cmpi eq, %arg0, %0 : i32 + cf.cond_br %1, ^exit, ^tail +^exit: + return %0 : i32 +^tail: + %3 = call @foo1(%arg0) : (i32) -> i32 + return %3 : i32 +} + +// CHECK-LABEL: func.func @foo1 +func.func @foo1(%arg0 : i32) -> i32 { + // CHECK: call @foo1 + %0 = arith.constant 1 : i32 + %1 = arith.subi %arg0, %0 : i32 + %2 = call @foo0(%1) : (i32) -> i32 + return %2 : i32 +}