|
22 | 22 | #include "mlir/Interfaces/CallInterfaces.h"
|
23 | 23 | #include "mlir/Interfaces/ControlFlowInterfaces.h"
|
24 | 24 | #include "mlir/Support/LLVM.h"
|
| 25 | +#include "llvm/ADT/ScopeExit.h" |
25 | 26 | #include "llvm/Support/Casting.h"
|
26 | 27 | #include "llvm/Support/Debug.h"
|
27 | 28 | #include "llvm/Support/DebugLog.h"
|
@@ -159,6 +160,7 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) {
|
159 | 160 | LDBG() << "[init] Entering initializeSymbolCallables for top-level op: "
|
160 | 161 | << OpWithFlags(top, OpPrintingFlags().skipRegions());
|
161 | 162 | analysisScope = top;
|
| 163 | + hasSymbolTable = top->hasTrait<OpTrait::SymbolTable>(); |
162 | 164 | auto walkFn = [&](Operation *symTable, bool allUsesVisible) {
|
163 | 165 | LDBG() << "[init] Processing symbol table op: "
|
164 | 166 | << OpWithFlags(symTable, OpPrintingFlags().skipRegions());
|
@@ -260,14 +262,25 @@ LogicalResult DeadCodeAnalysis::initializeRecursively(Operation *op) {
|
260 | 262 | return failure();
|
261 | 263 | }
|
262 | 264 | // Recurse on nested operations.
|
263 |
| - for (Region ®ion : op->getRegions()) { |
264 |
| - LDBG() << "[init] Recursing into region of op: " |
265 |
| - << OpWithFlags(op, OpPrintingFlags().skipRegions()); |
266 |
| - for (Operation &nestedOp : region.getOps()) { |
267 |
| - LDBG() << "[init] Recursing into nested op: " |
268 |
| - << OpWithFlags(&nestedOp, OpPrintingFlags().skipRegions()); |
269 |
| - if (failed(initializeRecursively(&nestedOp))) |
270 |
| - return failure(); |
| 265 | + if (op->getNumRegions()) { |
| 266 | + // If we haven't seen a symbol table yet, check if the current operation |
| 267 | + // has one. If so, update the flag to allow for resolving callables in |
| 268 | + // nested regions. |
| 269 | + bool savedHasSymbolTable = hasSymbolTable; |
| 270 | + auto restoreHasSymbolTable = |
| 271 | + llvm::make_scope_exit([&]() { hasSymbolTable = savedHasSymbolTable; }); |
| 272 | + if (!hasSymbolTable && op->hasTrait<OpTrait::SymbolTable>()) |
| 273 | + hasSymbolTable = true; |
| 274 | + |
| 275 | + for (Region ®ion : op->getRegions()) { |
| 276 | + LDBG() << "[init] Recursing into region of op: " |
| 277 | + << OpWithFlags(op, OpPrintingFlags().skipRegions()); |
| 278 | + for (Operation &nestedOp : region.getOps()) { |
| 279 | + LDBG() << "[init] Recursing into nested op: " |
| 280 | + << OpWithFlags(&nestedOp, OpPrintingFlags().skipRegions()); |
| 281 | + if (failed(initializeRecursively(&nestedOp))) |
| 282 | + return failure(); |
| 283 | + } |
271 | 284 | }
|
272 | 285 | }
|
273 | 286 | LDBG() << "[init] Finished initializeRecursively for op: "
|
@@ -388,7 +401,13 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint *point) {
|
388 | 401 | void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) {
|
389 | 402 | LDBG() << "visitCallOperation: "
|
390 | 403 | << OpWithFlags(call.getOperation(), OpPrintingFlags().skipRegions());
|
391 |
| - Operation *callableOp = call.resolveCallableInTable(&symbolTable); |
| 404 | + |
| 405 | + Operation *callableOp = nullptr; |
| 406 | + if (hasSymbolTable) |
| 407 | + callableOp = call.resolveCallableInTable(&symbolTable); |
| 408 | + else |
| 409 | + LDBG() |
| 410 | + << "No symbol table present in analysis scope, can't resolve callable"; |
392 | 411 |
|
393 | 412 | // A call to a externally-defined callable has unknown predecessors.
|
394 | 413 | const auto isExternalCallable = [this](Operation *op) {
|
|
0 commit comments