diff --git a/mlir/lib/Dialect/MLProgram/Transforms/PipelineGlobalOps.cpp b/mlir/lib/Dialect/MLProgram/Transforms/PipelineGlobalOps.cpp index 54fa15722febe..db06d2beed94a 100644 --- a/mlir/lib/Dialect/MLProgram/Transforms/PipelineGlobalOps.cpp +++ b/mlir/lib/Dialect/MLProgram/Transforms/PipelineGlobalOps.cpp @@ -33,7 +33,7 @@ class MLProgramPipelineGlobals llvm::DenseMap> storeSymbolsMap; }; -// Traverses upwards searchign for the operation mapped by the symbol. +// Traverses upwards searching for the operation mapped by the symbol. static Operation *getFromSymbol(Operation *baseOp, SymbolRefAttr symbol) { for (auto *op = baseOp; op; op = op->getParentOp()) { auto *lookup = SymbolTable::lookupNearestSymbolFrom(op, symbol); @@ -57,6 +57,9 @@ LogicalResult MLProgramPipelineGlobals::buildGlobalMap(ModuleOp module) { auto symbol = mlir::dyn_cast(callable); auto *func = getFromSymbol(op, symbol); + // If the callee cannot be resolved, we cannot safely analyze the IR. + if (!func) + return WalkResult::interrupt(); callableMap[symbol] = func; } return WalkResult::advance(); @@ -95,8 +98,13 @@ LogicalResult MLProgramPipelineGlobals::buildGlobalMap(ModuleOp module) { llvm::DenseSet storeSymbols; for (size_t i = 0; i < work.size(); ++i) { - callableMap[work[i]]->walk([&](CallOpInterface call) { - auto symbol = dyn_cast(call.getCallableForCallee()); + // Defensive: symbols in `work` should always be in `callableMap` since + // buildGlobalMap interrupted on any unresolvable callee, but use find to + // avoid inserting null entries via operator[]. + auto it = callableMap.find(work[i]); + assert(it != callableMap.end() && "Expected callable in callableMap"); + it->second->walk([&](CallOpInterface call) { + auto symbol = cast(call.getCallableForCallee()); if (visited.insert(symbol).second) work.push_back(symbol); }); diff --git a/mlir/test/Dialect/MLProgram/pipeline-globals.mlir b/mlir/test/Dialect/MLProgram/pipeline-globals.mlir index a5c9b3e890558..a7b52efe08ca3 100644 --- a/mlir/test/Dialect/MLProgram/pipeline-globals.mlir +++ b/mlir/test/Dialect/MLProgram/pipeline-globals.mlir @@ -219,6 +219,28 @@ func.func @call_indirect_load() { // ----- +// Calling via a CallOpInterface op whose callee symbol is not defined in this +// module should not crash - the pass should bail out gracefully. +// See https://github.com/llvm/llvm-project/issues/109649 + +// CHECK-LABEL: @global_variable +ml_program.global private mutable @global_variable(dense<4> : tensor<4xi32>) : tensor<4xi32> + +// CHECK-LABEL: @call_with_unresolvable_callee +func.func @call_with_unresolvable_callee(%arg0: memref) { + // Both loads must be preserved; the pass conservatively bails out when it + // encounters a call whose callee symbol cannot be resolved. + // CHECK: ml_program.global_load @global_variable + %0 = ml_program.global_load @global_variable : tensor<4xi32> + // @callee is not defined anywhere in this module. + test.call_and_store @callee(%arg0), %arg0 {store_before_call = false} : (memref, memref) -> () + // CHECK: ml_program.global_load @global_variable + %1 = ml_program.global_load @global_variable : tensor<4xi32> + func.return +} + +// ----- + // CHECK-LABEL: @global_variable ml_program.global private mutable @global_variable(dense<4> : tensor<4xi32>) : tensor<4xi32>