Skip to content

Commit c774840

Browse files
committed
[mlir] Update the CallGraph for nested symbol references, and simplify CallableOpInterface
Summary: This enables tracking calls that cross symbol table boundaries. It also simplifies some of the implementation details of CallableOpInterface, i.e. there can only be one region within the callable operation. Depends On D72042 Reviewed By: jpienaar Differential Revision: https://reviews.llvm.org/D72043
1 parent 6fca03f commit c774840

File tree

7 files changed

+82
-100
lines changed

7 files changed

+82
-100
lines changed

mlir/include/mlir/Analysis/CallInterfaces.td

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -54,29 +54,23 @@ def CallableOpInterface : OpInterface<"CallableOpInterface"> {
5454
be a target for a call-like operation (those providing the CallOpInterface
5555
above). These operations may be traditional functional operation
5656
`func @foo(...)`, as well as function producing operations
57-
`%foo = dialect.create_function(...)`. These operations may produce multiple
58-
callable regions, or subroutines.
57+
`%foo = dialect.create_function(...)`. These operations may only contain a
58+
single region, or subroutine.
5959
}];
6060

6161
let methods = [
6262
InterfaceMethod<[{
63-
Returns a region on the current operation that the given callable refers
64-
to. This may return null in the case of an external callable object,
65-
e.g. an external function.
63+
Returns the region on the current operation that is callable. This may
64+
return null in the case of an external callable object, e.g. an external
65+
function.
6666
}],
67-
"Region *", "getCallableRegion", (ins "CallInterfaceCallable":$callable)
67+
"Region *", "getCallableRegion"
6868
>,
6969
InterfaceMethod<[{
70-
Returns all of the callable regions of this operation.
71-
}],
72-
"void", "getCallableRegions",
73-
(ins "SmallVectorImpl<Region *> &":$callables)
74-
>,
75-
InterfaceMethod<[{
76-
Returns the results types that the given callable region produces when
70+
Returns the results types that the callable region produces when
7771
executed.
7872
}],
79-
"ArrayRef<Type>", "getCallableResults", (ins "Region *":$callable)
73+
"ArrayRef<Type>", "getCallableResults"
8074
>,
8175
];
8276
}

mlir/include/mlir/IR/Function.h

Lines changed: 6 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -122,26 +122,13 @@ class FuncOp : public Op<FuncOp, OpTrait::ZeroOperands, OpTrait::ZeroResult,
122122
// CallableOpInterface
123123
//===--------------------------------------------------------------------===//
124124

125-
/// Returns a region on the current operation that the given callable refers
126-
/// to. This may return null in the case of an external callable object, e.g.
127-
/// an external function.
128-
Region *getCallableRegion(CallInterfaceCallable callable) {
129-
assert(callable.get<SymbolRefAttr>().getLeafReference() == getName());
130-
return isExternal() ? nullptr : &getBody();
131-
}
132-
133-
/// Returns all of the callable regions of this operation.
134-
void getCallableRegions(SmallVectorImpl<Region *> &callables) {
135-
if (!isExternal())
136-
callables.push_back(&getBody());
137-
}
125+
/// Returns the region on the current operation that is callable. This may
126+
/// return null in the case of an external callable object, e.g. an external
127+
/// function.
128+
Region *getCallableRegion() { return isExternal() ? nullptr : &getBody(); }
138129

139-
/// Returns the results types that the given callable region produces when
140-
/// executed.
141-
ArrayRef<Type> getCallableResults(Region *region) {
142-
assert(!isExternal() && region == &getBody() && "invalid callable");
143-
return getType().getResults();
144-
}
130+
/// Returns the results types that the callable region produces when executed.
131+
ArrayRef<Type> getCallableResults() { return getType().getResults(); }
145132

146133
private:
147134
// This trait needs access to the hooks defined below.

mlir/lib/Analysis/CallGraph.cpp

Lines changed: 25 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -74,67 +74,38 @@ void CallGraphNode::addEdge(CallGraphNode *node, Edge::Kind kind) {
7474
/// Recursively compute the callgraph edges for the given operation. Computed
7575
/// edges are placed into the given callgraph object.
7676
static void computeCallGraph(Operation *op, CallGraph &cg,
77-
CallGraphNode *parentNode);
78-
79-
/// Compute the set of callgraph nodes that are created by regions nested within
80-
/// 'op'.
81-
static void computeCallables(Operation *op, CallGraph &cg,
82-
CallGraphNode *parentNode) {
83-
if (op->getNumRegions() == 0)
77+
CallGraphNode *parentNode, bool resolveCalls) {
78+
if (CallOpInterface call = dyn_cast<CallOpInterface>(op)) {
79+
// If there is no parent node, we ignore this operation. Even if this
80+
// operation was a call, there would be no callgraph node to attribute it
81+
// to.
82+
if (!resolveCalls || !parentNode)
83+
return;
84+
parentNode->addCallEdge(
85+
cg.resolveCallable(call.getCallableForCallee(), op));
8486
return;
85-
if (auto callableOp = dyn_cast<CallableOpInterface>(op)) {
86-
SmallVector<Region *, 1> callables;
87-
callableOp.getCallableRegions(callables);
88-
for (auto *callableRegion : callables)
89-
cg.getOrAddNode(callableRegion, parentNode);
9087
}
91-
}
9288

93-
/// Recursively compute the callgraph edges within the given region. Computed
94-
/// edges are placed into the given callgraph object.
95-
static void computeCallGraph(Region &region, CallGraph &cg,
96-
CallGraphNode *parentNode) {
97-
// Iterate over the nested operations twice:
98-
/// One to fully create nodes in the for each callable region of a nested
99-
/// operation;
100-
for (auto &block : region)
101-
for (auto &nested : block)
102-
computeCallables(&nested, cg, parentNode);
103-
104-
/// And another to recursively compute the callgraph.
105-
for (auto &block : region)
106-
for (auto &nested : block)
107-
computeCallGraph(&nested, cg, parentNode);
108-
}
109-
110-
/// Recursively compute the callgraph edges for the given operation. Computed
111-
/// edges are placed into the given callgraph object.
112-
static void computeCallGraph(Operation *op, CallGraph &cg,
113-
CallGraphNode *parentNode) {
11489
// Compute the callgraph nodes and edges for each of the nested operations.
115-
auto isCallable = isa<CallableOpInterface>(op);
116-
for (auto &region : op->getRegions()) {
117-
// Check to see if this region is a callable node, if so this is the parent
118-
// node of the nested region.
119-
CallGraphNode *nestedParentNode;
120-
if (!isCallable || !(nestedParentNode = cg.lookupNode(&region)))
121-
nestedParentNode = parentNode;
122-
computeCallGraph(region, cg, nestedParentNode);
90+
if (CallableOpInterface callable = dyn_cast<CallableOpInterface>(op)) {
91+
if (auto *callableRegion = callable.getCallableRegion())
92+
parentNode = cg.getOrAddNode(callableRegion, parentNode);
93+
else
94+
return;
12395
}
12496

125-
// If there is no parent node, we ignore this operation. Even if this
126-
// operation was a call, there would be no callgraph node to attribute it to.
127-
if (!parentNode)
128-
return;
129-
130-
// If this is a call operation, resolve the callee.
131-
if (auto call = dyn_cast<CallOpInterface>(op))
132-
parentNode->addCallEdge(
133-
cg.resolveCallable(call.getCallableForCallee(), op));
97+
for (Region &region : op->getRegions())
98+
for (Block &block : region)
99+
for (Operation &nested : block)
100+
computeCallGraph(&nested, cg, parentNode, resolveCalls);
134101
}
135102

136103
CallGraph::CallGraph(Operation *op) : externalNode(/*callableRegion=*/nullptr) {
137-
computeCallGraph(op, *this, /*parentNode=*/nullptr);
104+
// Make two passes over the graph, one to compute the callables and one to
105+
// resolve the calls. We split these up as we may have nested callable objects
106+
// that need to be reserved before the calls.
107+
computeCallGraph(op, *this, /*parentNode=*/nullptr, /*resolveCalls=*/false);
108+
computeCallGraph(op, *this, /*parentNode=*/nullptr, /*resolveCalls=*/true);
138109
}
139110

140111
/// Get or add a call graph node for the given region.
@@ -175,17 +146,15 @@ CallGraphNode *CallGraph::resolveCallable(CallInterfaceCallable callable,
175146
// Get the callee operation from the callable.
176147
Operation *callee;
177148
if (auto symbolRef = callable.dyn_cast<SymbolRefAttr>())
178-
// TODO(riverriddle) Support nested references.
179-
callee = SymbolTable::lookupNearestSymbolFrom(from,
180-
symbolRef.getRootReference());
149+
callee = SymbolTable::lookupNearestSymbolFrom(from, symbolRef);
181150
else
182151
callee = callable.get<Value>().getDefiningOp();
183152

184153
// If the callee is non-null and is a valid callable object, try to get the
185154
// called region from it.
186155
if (callee && callee->getNumRegions()) {
187156
if (auto callableOp = dyn_cast_or_null<CallableOpInterface>(callee)) {
188-
if (auto *node = lookupNode(callableOp.getCallableRegion(callable)))
157+
if (auto *node = lookupNode(callableOp.getCallableRegion()))
189158
return node;
190159
}
191160
}

mlir/lib/Transforms/Inliner.cpp

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,15 @@ static void collectCallOps(iterator_range<Region::iterator> blocks,
8686
while (!worklist.empty()) {
8787
for (Operation &op : *worklist.pop_back_val()) {
8888
if (auto call = dyn_cast<CallOpInterface>(op)) {
89-
CallGraphNode *node =
90-
cg.resolveCallable(call.getCallableForCallee(), &op);
89+
CallInterfaceCallable callable = call.getCallableForCallee();
90+
91+
// TODO(riverriddle) Support inlining nested call references.
92+
if (SymbolRefAttr symRef = callable.dyn_cast<SymbolRefAttr>()) {
93+
if (!symRef.isa<FlatSymbolRefAttr>())
94+
continue;
95+
}
96+
97+
CallGraphNode *node = cg.resolveCallable(callable, &op);
9198
if (!node->isExternal())
9299
calls.emplace_back(call, node);
93100
continue;
@@ -274,6 +281,15 @@ struct InlinerPass : public OperationPass<InlinerPass> {
274281
CallGraph &cg = getAnalysis<CallGraph>();
275282
auto *context = &getContext();
276283

284+
// The inliner should only be run on operations that define a symbol table,
285+
// as the callgraph will need to resolve references.
286+
Operation *op = getOperation();
287+
if (!op->hasTrait<OpTrait::SymbolTable>()) {
288+
op->emitOpError() << " was scheduled to run under the inliner, but does "
289+
"not define a symbol table";
290+
return signalPassFailure();
291+
}
292+
277293
// Collect a set of canonicalization patterns to use when simplifying
278294
// callable regions within an SCC.
279295
OwningRewritePatternList canonPatterns;

mlir/lib/Transforms/Utils/InliningUtils.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ LogicalResult mlir::inlineCall(InlinerInterface &interface,
284284
if (src->empty())
285285
return failure();
286286
auto *entryBlock = &src->front();
287-
ArrayRef<Type> callableResultTypes = callable.getCallableResults(src);
287+
ArrayRef<Type> callableResultTypes = callable.getCallableResults();
288288

289289
// Make sure that the number of arguments and results matchup between the call
290290
// and the region.

mlir/test/Analysis/test-callgraph.mlir

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt %s -test-print-callgraph 2>&1 | FileCheck %s --dump-input-on-failure
1+
// RUN: mlir-opt %s -test-print-callgraph -split-input-file 2>&1 | FileCheck %s --dump-input-on-failure
22

33
// CHECK-LABEL: Testing : "simple"
44
module attributes {test.name = "simple"} {
@@ -50,3 +50,22 @@ module attributes {test.name = "simple"} {
5050
return
5151
}
5252
}
53+
54+
// -----
55+
56+
// CHECK-LABEL: Testing : "nested"
57+
module attributes {test.name = "nested"} {
58+
module @nested_module {
59+
// CHECK: Node{{.*}}func_a
60+
func @func_a() {
61+
return
62+
}
63+
}
64+
65+
// CHECK: Node{{.*}}func_b
66+
// CHECK: Call-Edge{{.*}}func_a
67+
func @func_b() {
68+
"test.conversion_call_op"() { callee = @nested_module::@func_a } : () -> ()
69+
return
70+
}
71+
}

mlir/test/lib/TestDialect/TestOps.td

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ def SizedRegionOp : TEST_Op<"sized_region_op", []> {
230230

231231
def ConversionCallOp : TEST_Op<"conversion_call_op",
232232
[CallOpInterface]> {
233-
let arguments = (ins Variadic<AnyType>:$inputs, FlatSymbolRefAttr:$callee);
233+
let arguments = (ins Variadic<AnyType>:$inputs, SymbolRefAttr:$callee);
234234
let results = (outs Variadic<AnyType>);
235235

236236
let extraClassDeclaration = [{
@@ -239,7 +239,7 @@ def ConversionCallOp : TEST_Op<"conversion_call_op",
239239

240240
/// Return the callee of this operation.
241241
CallInterfaceCallable getCallableForCallee() {
242-
return getAttrOfType<FlatSymbolRefAttr>("callee");
242+
return getAttrOfType<SymbolRefAttr>("callee");
243243
}
244244
}];
245245
}
@@ -250,11 +250,8 @@ def FunctionalRegionOp : TEST_Op<"functional_region_op",
250250
let results = (outs FunctionType);
251251

252252
let extraClassDeclaration = [{
253-
Region *getCallableRegion(CallInterfaceCallable) { return &body(); }
254-
void getCallableRegions(SmallVectorImpl<Region *> &callables) {
255-
callables.push_back(&body());
256-
}
257-
ArrayRef<Type> getCallableResults(Region *) {
253+
Region *getCallableRegion() { return &body(); }
254+
ArrayRef<Type> getCallableResults() {
258255
return getType().cast<FunctionType>().getResults();
259256
}
260257
}];

0 commit comments

Comments
 (0)