diff --git a/mlir/include/mlir/Analysis/CallGraph.h b/mlir/include/mlir/Analysis/CallGraph.h index 6b486ba2d1ac6..cd25151da4c07 100644 --- a/mlir/include/mlir/Analysis/CallGraph.h +++ b/mlir/include/mlir/Analysis/CallGraph.h @@ -23,6 +23,7 @@ #include "llvm/ADT/SetVector.h" namespace mlir { +class CallOpInterface; struct CallInterfaceCallable; class Operation; class Region; @@ -188,11 +189,8 @@ class CallGraph { } /// Resolve the callable for given callee to a node in the callgraph, or the - /// external node if a valid node was not resolved. 'from' provides an anchor - /// for symbol table lookups, and is only required if the callable is a symbol - /// reference. - CallGraphNode *resolveCallable(CallInterfaceCallable callable, - Operation *from = nullptr) const; + /// external node if a valid node was not resolved. + CallGraphNode *resolveCallable(CallOpInterface call) const; /// An iterator over the nodes of the graph. using iterator = NodeIterator; diff --git a/mlir/include/mlir/Analysis/CallInterfaces.h b/mlir/include/mlir/Analysis/CallInterfaces.h index e0d9ba2fcbdc1..ebefc88e21a95 100644 --- a/mlir/include/mlir/Analysis/CallInterfaces.h +++ b/mlir/include/mlir/Analysis/CallInterfaces.h @@ -14,11 +14,10 @@ #ifndef MLIR_ANALYSIS_CALLINTERFACES_H #define MLIR_ANALYSIS_CALLINTERFACES_H -#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/SymbolTable.h" #include "llvm/ADT/PointerUnion.h" namespace mlir { - /// A callable is either a symbol, or an SSA value, that is referenced by a /// call-like operation. This represents the destination of the call. struct CallInterfaceCallable : public PointerUnion { diff --git a/mlir/include/mlir/Analysis/CallInterfaces.td b/mlir/include/mlir/Analysis/CallInterfaces.td index 7dd58a48fa1da..2bc59c224b4f7 100644 --- a/mlir/include/mlir/Analysis/CallInterfaces.td +++ b/mlir/include/mlir/Analysis/CallInterfaces.td @@ -44,6 +44,18 @@ def CallOpInterface : OpInterface<"CallOpInterface"> { }], "Operation::operand_range", "getArgOperands" >, + InterfaceMethod<[{ + Resolve the callable operation for given callee to a + CallableOpInterface, or nullptr if a valid callable was not resolved. + }], + "Operation *", "resolveCallable", (ins), [{ + // If the callable isn't a value, lookup the symbol reference. + CallInterfaceCallable callable = op.getCallableForCallee(); + if (auto symbolRef = callable.dyn_cast()) + return SymbolTable::lookupNearestSymbolFrom(op, symbolRef); + return callable.get().getDefiningOp(); + }] + >, ]; } diff --git a/mlir/lib/Analysis/CallGraph.cpp b/mlir/lib/Analysis/CallGraph.cpp index 0f3bfbb653c8e..d61b2359e6911 100644 --- a/mlir/lib/Analysis/CallGraph.cpp +++ b/mlir/lib/Analysis/CallGraph.cpp @@ -79,10 +79,8 @@ static void computeCallGraph(Operation *op, CallGraph &cg, // If there is no parent node, we ignore this operation. Even if this // operation was a call, there would be no callgraph node to attribute it // to. - if (!resolveCalls || !parentNode) - return; - parentNode->addCallEdge( - cg.resolveCallable(call.getCallableForCallee(), op)); + if (resolveCalls && parentNode) + parentNode->addCallEdge(cg.resolveCallable(call)); return; } @@ -141,23 +139,11 @@ CallGraphNode *CallGraph::lookupNode(Region *region) const { /// Resolve the callable for given callee to a node in the callgraph, or the /// external node if a valid node was not resolved. -CallGraphNode *CallGraph::resolveCallable(CallInterfaceCallable callable, - Operation *from) const { - // Get the callee operation from the callable. - Operation *callee; - if (auto symbolRef = callable.dyn_cast()) - callee = SymbolTable::lookupNearestSymbolFrom(from, symbolRef); - else - callee = callable.get().getDefiningOp(); - - // If the callee is non-null and is a valid callable object, try to get the - // called region from it. - if (callee && callee->getNumRegions()) { - if (auto callableOp = dyn_cast_or_null(callee)) { - if (auto *node = lookupNode(callableOp.getCallableRegion())) - return node; - } - } +CallGraphNode *CallGraph::resolveCallable(CallOpInterface call) const { + Operation *callable = call.resolveCallable(); + if (auto callableOp = dyn_cast_or_null(callable)) + if (auto *node = lookupNode(callableOp.getCallableRegion())) + return node; // If we don't have a valid direct region, this is an external call. return getExternalNode(); diff --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp index 3b68553666576..b6fcf8bc3941c 100644 --- a/mlir/lib/Transforms/Inliner.cpp +++ b/mlir/lib/Transforms/Inliner.cpp @@ -86,15 +86,14 @@ static void collectCallOps(iterator_range blocks, while (!worklist.empty()) { for (Operation &op : *worklist.pop_back_val()) { if (auto call = dyn_cast(op)) { - CallInterfaceCallable callable = call.getCallableForCallee(); - // TODO(riverriddle) Support inlining nested call references. + CallInterfaceCallable callable = call.getCallableForCallee(); if (SymbolRefAttr symRef = callable.dyn_cast()) { if (!symRef.isa()) continue; } - CallGraphNode *node = cg.resolveCallable(callable, &op); + CallGraphNode *node = cg.resolveCallable(call); if (!node->isExternal()) calls.emplace_back(call, node); continue;