From 63d403b153bbc23c3451f14ba63b410fc253a559 Mon Sep 17 00:00:00 2001 From: mencotton Date: Fri, 28 Nov 2025 02:40:59 +0900 Subject: [PATCH] [mlir][gpu] Avoid kernel outlining crash on invalid symbol refs --- .../GPU/Transforms/KernelOutlining.cpp | 31 ++++++++++++++----- .../Dialect/GPU/outlining-invalid-symbol.mlir | 29 +++++++++++++++++ 2 files changed, 53 insertions(+), 7 deletions(-) create mode 100644 mlir/test/Dialect/GPU/outlining-invalid-symbol.mlir diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp index 97adad64d78c4..dace8fa38dc6d 100644 --- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp @@ -371,7 +371,11 @@ class GpuKernelOutliningPass // Create nested module and insert outlinedFunc. The module will // originally get the same name as the function, but may be renamed on // insertion into the parent module. - auto kernelModule = createKernelModule(op, outlinedFunc, symbolTable); + auto kernelModuleOrFailure = + createKernelModule(op, outlinedFunc, symbolTable); + if (failed(kernelModuleOrFailure)) + return WalkResult::interrupt(); + auto kernelModule = *kernelModuleOrFailure; symbolTable.insert(kernelModule, insertPt); // Potentially changes signature, pulling in constants. @@ -392,9 +396,9 @@ class GpuKernelOutliningPass private: /// Returns a gpu.module containing kernelFunc and all callees (recursive). - gpu::GPUModuleOp createKernelModule(gpu::LaunchOp gpuLaunchOp, - gpu::GPUFuncOp kernelFunc, - const SymbolTable &parentSymbolTable) { + FailureOr + createKernelModule(gpu::LaunchOp gpuLaunchOp, gpu::GPUFuncOp kernelFunc, + const SymbolTable &parentSymbolTable) { // TODO: This code cannot use an OpBuilder because it must be inserted into // a SymbolTable by the caller. SymbolTable needs to be refactored to // prevent manual building of Ops with symbols in code using SymbolTables @@ -431,12 +435,25 @@ class GpuKernelOutliningPass if (std::optional symbolUses = SymbolTable::getSymbolUses(symbolDefWorklist.pop_back_val())) { for (SymbolTable::SymbolUse symbolUse : *symbolUses) { - StringAttr symbolName = symbolUse.getSymbolRef().getLeafReference(); + SymbolRefAttr symbolRef = symbolUse.getSymbolRef(); + StringAttr symbolName = symbolRef.getLeafReference(); if (symbolTable.lookup(symbolName)) continue; - Operation *symbolDefClone = - parentSymbolTable.lookup(symbolName)->clone(); + Operation *symbolDef = + SymbolTable::lookupSymbolIn(parentSymbolTable.getOp(), symbolRef); + if (!symbolDef) { + if (isa(symbolRef)) { + return symbolUse.getUser()->emitOpError( + "failed to outline gpu kernel: symbol '" + + symbolName.getValue() + "' not found"); + } + return symbolUse.getUser()->emitOpError( + "failed to outline gpu kernel: " + "found invalid symbol reference: ") + << symbolRef; + } + Operation *symbolDefClone = symbolDef->clone(); symbolDefWorklist.push_back(symbolDefClone); symbolTable.insert(symbolDefClone); } diff --git a/mlir/test/Dialect/GPU/outlining-invalid-symbol.mlir b/mlir/test/Dialect/GPU/outlining-invalid-symbol.mlir new file mode 100644 index 0000000000000..5cd290fc396c3 --- /dev/null +++ b/mlir/test/Dialect/GPU/outlining-invalid-symbol.mlir @@ -0,0 +1,29 @@ +// RUN: mlir-opt -gpu-kernel-outlining -verify-diagnostics -split-input-file %s + +module attributes {gpu.container_module} { + func.func @kernel_crash() { + %c1 = arith.constant 1 : index + gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1) + threads(%tx, %ty, %tz) in (%block_x = %c1, %block_y = %c1, %block_z = %c1) { + // expected-error@+1 {{failed to outline gpu kernel: symbol 'unknown_func' not found}} + "test.op"() {symbol = @unknown_func} : () -> () + gpu.terminator + } + return + } +} + +// ----- + +module attributes {gpu.container_module} { + func.func @kernel_invalid_ref() { + %c1 = arith.constant 1 : index + gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1) + threads(%tx, %ty, %tz) in (%block_x = %c1, %block_y = %c1, %block_z = %c1) { + // expected-error@+1 {{failed to outline gpu kernel: found invalid symbol reference: @nested::@ref}} + "test.op"() {symbol = @nested::@ref} : () -> () + gpu.terminator + } + return + } +}