Skip to content

Conversation

@Men-cotton
Copy link
Contributor

Fixes: #64639

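Handle nested and unresolvable symbol references during GPU kernel outlining so that the pass emits a diagnostic instead of crashing. Previously, `createKernelModule` looked up only the leaf name of each symbol use in the parent symbol table and called `clone()` on the result without checking it for null. It now resolves the full symbol reference with `SymbolTable::lookupSymbolIn`; if resolution fails, it emits an error on the using operation ("symbol '...' not found" for flat references, "found invalid symbol reference" otherwise) and outlining stops.
Add a test (`mlir/test/Dialect/GPU/outlining-invalid-symbol.mlir`) covering both cases: an unknown flat symbol and an invalid nested reference.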

CC: @fabianmcg

@llvmbot
Copy link
Member

llvmbot commented Nov 27, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-gpu

Author: Men-cotton (Men-cotton)

Changes

Fixes: #64639

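Handle nested and unresolvable symbol references during GPU kernel outlining so that the pass emits a diagnostic instead of crashing. Previously, `createKernelModule` looked up only the leaf name of each symbol use in the parent symbol table and called `clone()` on the result without checking it for null. It now resolves the full symbol reference with `SymbolTable::lookupSymbolIn`; if resolution fails, it emits an error on the using operation ("symbol '...' not found" for flat references, "found invalid symbol reference" otherwise) and outlining stops.
Add a test (`mlir/test/Dialect/GPU/outlining-invalid-symbol.mlir`) covering both cases: an unknown flat symbol and an invalid nested reference.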

CC: @fabianmcg


Full diff: https://github.com/llvm/llvm-project/pull/169843.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp (+24-7)
  • (added) mlir/test/Dialect/GPU/outlining-invalid-symbol.mlir (+29)
diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
index 97adad64d78c4..dace8fa38dc6d 100644
--- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
@@ -371,7 +371,11 @@ class GpuKernelOutliningPass
         // Create nested module and insert outlinedFunc. The module will
         // originally get the same name as the function, but may be renamed on
         // insertion into the parent module.
-        auto kernelModule = createKernelModule(op, outlinedFunc, symbolTable);
+        auto kernelModuleOrFailure =
+            createKernelModule(op, outlinedFunc, symbolTable);
+        if (failed(kernelModuleOrFailure))
+          return WalkResult::interrupt();
+        auto kernelModule = *kernelModuleOrFailure;
         symbolTable.insert(kernelModule, insertPt);
 
         // Potentially changes signature, pulling in constants.
@@ -392,9 +396,9 @@ class GpuKernelOutliningPass
 
 private:
   /// Returns a gpu.module containing kernelFunc and all callees (recursive).
-  gpu::GPUModuleOp createKernelModule(gpu::LaunchOp gpuLaunchOp,
-                                      gpu::GPUFuncOp kernelFunc,
-                                      const SymbolTable &parentSymbolTable) {
+  FailureOr<gpu::GPUModuleOp>
+  createKernelModule(gpu::LaunchOp gpuLaunchOp, gpu::GPUFuncOp kernelFunc,
+                     const SymbolTable &parentSymbolTable) {
     // TODO: This code cannot use an OpBuilder because it must be inserted into
     // a SymbolTable by the caller. SymbolTable needs to be refactored to
     // prevent manual building of Ops with symbols in code using SymbolTables
@@ -431,12 +435,25 @@ class GpuKernelOutliningPass
       if (std::optional<SymbolTable::UseRange> symbolUses =
               SymbolTable::getSymbolUses(symbolDefWorklist.pop_back_val())) {
         for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
-          StringAttr symbolName = symbolUse.getSymbolRef().getLeafReference();
+          SymbolRefAttr symbolRef = symbolUse.getSymbolRef();
+          StringAttr symbolName = symbolRef.getLeafReference();
           if (symbolTable.lookup(symbolName))
             continue;
 
-          Operation *symbolDefClone =
-              parentSymbolTable.lookup(symbolName)->clone();
+          Operation *symbolDef =
+              SymbolTable::lookupSymbolIn(parentSymbolTable.getOp(), symbolRef);
+          if (!symbolDef) {
+            if (isa<FlatSymbolRefAttr>(symbolRef)) {
+              return symbolUse.getUser()->emitOpError(
+                  "failed to outline gpu kernel: symbol '" +
+                  symbolName.getValue() + "' not found");
+            }
+            return symbolUse.getUser()->emitOpError(
+                       "failed to outline gpu kernel: "
+                       "found invalid symbol reference: ")
+                   << symbolRef;
+          }
+          Operation *symbolDefClone = symbolDef->clone();
           symbolDefWorklist.push_back(symbolDefClone);
           symbolTable.insert(symbolDefClone);
         }
diff --git a/mlir/test/Dialect/GPU/outlining-invalid-symbol.mlir b/mlir/test/Dialect/GPU/outlining-invalid-symbol.mlir
new file mode 100644
index 0000000000000..5cd290fc396c3
--- /dev/null
+++ b/mlir/test/Dialect/GPU/outlining-invalid-symbol.mlir
@@ -0,0 +1,29 @@
+// RUN: mlir-opt -gpu-kernel-outlining -verify-diagnostics -split-input-file %s
+
+module attributes {gpu.container_module} {
+  func.func @kernel_crash() {
+    %c1 = arith.constant 1 : index
+    gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
+               threads(%tx, %ty, %tz) in (%block_x = %c1, %block_y = %c1, %block_z = %c1) {
+      // expected-error@+1 {{failed to outline gpu kernel: symbol 'unknown_func' not found}}
+      "test.op"() {symbol = @unknown_func} : () -> ()
+      gpu.terminator
+    }
+    return
+  }
+}
+
+// -----
+
+module attributes {gpu.container_module} {
+  func.func @kernel_invalid_ref() {
+    %c1 = arith.constant 1 : index
+    gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
+               threads(%tx, %ty, %tz) in (%block_x = %c1, %block_y = %c1, %block_z = %c1) {
+      // expected-error@+1 {{failed to outline gpu kernel: found invalid symbol reference: @nested::@ref}}
+      "test.op"() {symbol = @nested::@ref} : () -> ()
+      gpu.terminator
+    }
+    return
+  }
+}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[mlir] --gpu-kernel-outlining crashed with assertion failure "cast<Ty>() argument of incompatible type!"

2 participants