Skip to content

Commit 13ae9ea

Browse files
authored
[MLIR] Avoid resolving callable outside the analysis scope in DeadCodeAnalysis (#155088)
We are using the symbol table machinery to lookup for a callable, but when the analysis scope if a function, such lookup will resolve outside of the scope. This can lead to race-condition issues since other passes may operate in parallel on the sibling functions. The callable would be discarded right after the lookup (we check the analysis scope), so avoiding the lookup is NFC. For the DataFlow solver, we're looking at the top-level operation, and if it isn't a SymbolTable we disable the interprocedural optimization in the solver config directly. This strategy isn't NFC but seems reasonnable and does not encounter any change in behavior in practice in tree. Fix #154948
1 parent 5d8d98c commit 13ae9ea

File tree

5 files changed

+62
-21
lines changed

5 files changed

+62
-21
lines changed

mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,13 @@ class DeadCodeAnalysis : public DataFlowAnalysis {
229229
/// considered an external callable.
230230
Operation *analysisScope;
231231

232+
/// Whether the analysis scope has a symbol table. This is used to avoid
233+
/// resolving callables outside the analysis scope.
234+
/// It is updated when recursing into a region in case where the top-level
235+
/// operation does not have a symbol table, but one is encountered in a nested
236+
/// region.
237+
bool hasSymbolTable = false;
238+
232239
/// A symbol table used for O(1) symbol lookups during simplification.
233240
SymbolTableCollection symbolTable;
234241
};

mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "mlir/Interfaces/CallInterfaces.h"
2323
#include "mlir/Interfaces/ControlFlowInterfaces.h"
2424
#include "mlir/Support/LLVM.h"
25+
#include "llvm/ADT/ScopeExit.h"
2526
#include "llvm/Support/Casting.h"
2627
#include "llvm/Support/Debug.h"
2728
#include "llvm/Support/DebugLog.h"
@@ -159,6 +160,7 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) {
159160
LDBG() << "[init] Entering initializeSymbolCallables for top-level op: "
160161
<< OpWithFlags(top, OpPrintingFlags().skipRegions());
161162
analysisScope = top;
163+
hasSymbolTable = top->hasTrait<OpTrait::SymbolTable>();
162164
auto walkFn = [&](Operation *symTable, bool allUsesVisible) {
163165
LDBG() << "[init] Processing symbol table op: "
164166
<< OpWithFlags(symTable, OpPrintingFlags().skipRegions());
@@ -260,14 +262,25 @@ LogicalResult DeadCodeAnalysis::initializeRecursively(Operation *op) {
260262
return failure();
261263
}
262264
// Recurse on nested operations.
263-
for (Region &region : op->getRegions()) {
264-
LDBG() << "[init] Recursing into region of op: "
265-
<< OpWithFlags(op, OpPrintingFlags().skipRegions());
266-
for (Operation &nestedOp : region.getOps()) {
267-
LDBG() << "[init] Recursing into nested op: "
268-
<< OpWithFlags(&nestedOp, OpPrintingFlags().skipRegions());
269-
if (failed(initializeRecursively(&nestedOp)))
270-
return failure();
265+
if (op->getNumRegions()) {
266+
// If we haven't seen a symbol table yet, check if the current operation
267+
// has one. If so, update the flag to allow for resolving callables in
268+
// nested regions.
269+
bool savedHasSymbolTable = hasSymbolTable;
270+
auto restoreHasSymbolTable =
271+
llvm::make_scope_exit([&]() { hasSymbolTable = savedHasSymbolTable; });
272+
if (!hasSymbolTable && op->hasTrait<OpTrait::SymbolTable>())
273+
hasSymbolTable = true;
274+
275+
for (Region &region : op->getRegions()) {
276+
LDBG() << "[init] Recursing into region of op: "
277+
<< OpWithFlags(op, OpPrintingFlags().skipRegions());
278+
for (Operation &nestedOp : region.getOps()) {
279+
LDBG() << "[init] Recursing into nested op: "
280+
<< OpWithFlags(&nestedOp, OpPrintingFlags().skipRegions());
281+
if (failed(initializeRecursively(&nestedOp)))
282+
return failure();
283+
}
271284
}
272285
}
273286
LDBG() << "[init] Finished initializeRecursively for op: "
@@ -388,7 +401,13 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint *point) {
388401
void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) {
389402
LDBG() << "visitCallOperation: "
390403
<< OpWithFlags(call.getOperation(), OpPrintingFlags().skipRegions());
391-
Operation *callableOp = call.resolveCallableInTable(&symbolTable);
404+
405+
Operation *callableOp = nullptr;
406+
if (hasSymbolTable)
407+
callableOp = call.resolveCallableInTable(&symbolTable);
408+
else
409+
LDBG()
410+
<< "No symbol table present in analysis scope, can't resolve callable";
392411

393412
// A call to a externally-defined callable has unknown predecessors.
394413
const auto isExternalCallable = [this](Operation *op) {

mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,12 @@ void AbstractDenseForwardDataFlowAnalysis::visitCallOperation(
6464
AbstractDenseLattice *after) {
6565
// Allow for customizing the behavior of calls to external symbols, including
6666
// when the analysis is explicitly marked as non-interprocedural.
67-
auto callable =
68-
dyn_cast_if_present<CallableOpInterface>(call.resolveCallable());
69-
if (!getSolverConfig().isInterprocedural() ||
70-
(callable && !callable.getCallableRegion())) {
67+
auto isExternalCallable = [&]() {
68+
auto callable =
69+
dyn_cast_if_present<CallableOpInterface>(call.resolveCallable());
70+
return callable && !callable.getCallableRegion();
71+
};
72+
if (!getSolverConfig().isInterprocedural() || isExternalCallable()) {
7173
return visitCallControlFlowTransfer(
7274
call, CallControlFlowAction::ExternalCallee, before, after);
7375
}
@@ -290,19 +292,23 @@ AbstractDenseBackwardDataFlowAnalysis::visit(ProgramPoint *point) {
290292
void AbstractDenseBackwardDataFlowAnalysis::visitCallOperation(
291293
CallOpInterface call, const AbstractDenseLattice &after,
292294
AbstractDenseLattice *before) {
295+
// If the solver is not interprocedural, let the hook handle it as an external
296+
// callee.
297+
if (!getSolverConfig().isInterprocedural())
298+
return visitCallControlFlowTransfer(
299+
call, CallControlFlowAction::ExternalCallee, after, before);
300+
293301
// Find the callee.
294302
Operation *callee = call.resolveCallableInTable(&symbolTable);
295303

296304
auto callable = dyn_cast_or_null<CallableOpInterface>(callee);
297305
// No region means the callee is only declared in this module.
298306
// If that is the case or if the solver is not interprocedural,
299307
// let the hook handle it.
300-
if (!getSolverConfig().isInterprocedural() ||
301-
(callable && (!callable.getCallableRegion() ||
302-
callable.getCallableRegion()->empty()))) {
308+
if (callable &&
309+
(!callable.getCallableRegion() || callable.getCallableRegion()->empty()))
303310
return visitCallControlFlowTransfer(
304311
call, CallControlFlowAction::ExternalCallee, after, before);
305-
}
306312

307313
if (!callable)
308314
return setToExitState(before);

mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -228,10 +228,12 @@ LogicalResult AbstractSparseForwardDataFlowAnalysis::visitCallOperation(
228228
ArrayRef<AbstractSparseLattice *> resultLattices) {
229229
// If the call operation is to an external function, attempt to infer the
230230
// results from the call arguments.
231-
auto callable =
232-
dyn_cast_if_present<CallableOpInterface>(call.resolveCallable());
233-
if (!getSolverConfig().isInterprocedural() ||
234-
(callable && !callable.getCallableRegion())) {
231+
auto isExternalCallable = [&]() {
232+
auto callable =
233+
dyn_cast_if_present<CallableOpInterface>(call.resolveCallable());
234+
return callable && !callable.getCallableRegion();
235+
};
236+
if (!getSolverConfig().isInterprocedural() || isExternalCallable()) {
235237
visitExternalCallImpl(call, operandLattices, resultLattices);
236238
return success();
237239
}

mlir/lib/Analysis/DataFlowFramework.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "mlir/Analysis/DataFlowFramework.h"
1010
#include "mlir/IR/Location.h"
1111
#include "mlir/IR/Operation.h"
12+
#include "mlir/IR/SymbolTable.h"
1213
#include "mlir/IR/Value.h"
1314
#include "llvm/ADT/ScopeExit.h"
1415
#include "llvm/ADT/iterator.h"
@@ -109,6 +110,12 @@ LogicalResult DataFlowSolver::initializeAndRun(Operation *top) {
109110
isRunning = true;
110111
auto guard = llvm::make_scope_exit([&]() { isRunning = false; });
111112

113+
bool isInterprocedural = config.isInterprocedural();
114+
auto restoreInterprocedural = llvm::make_scope_exit(
115+
[&]() { config.setInterprocedural(isInterprocedural); });
116+
if (isInterprocedural && !top->hasTrait<OpTrait::SymbolTable>())
117+
config.setInterprocedural(false);
118+
112119
// Initialize equivalent lattice anchors.
113120
for (DataFlowAnalysis &analysis : llvm::make_pointee_range(childAnalyses)) {
114121
analysis.initializeEquivalentLatticeAnchor(top);

0 commit comments

Comments
 (0)