diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp index d4978ca768747..97adad64d78c4 100644 --- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp @@ -431,8 +431,7 @@ class GpuKernelOutliningPass if (std::optional symbolUses = SymbolTable::getSymbolUses(symbolDefWorklist.pop_back_val())) { for (SymbolTable::SymbolUse symbolUse : *symbolUses) { - StringRef symbolName = - cast(symbolUse.getSymbolRef()).getValue(); + StringAttr symbolName = symbolUse.getSymbolRef().getLeafReference(); if (symbolTable.lookup(symbolName)) continue; diff --git a/mlir/test/Dialect/GPU/outlining.mlir b/mlir/test/Dialect/GPU/outlining.mlir index 04901182a80f5..e14521ab9bb5c 100644 --- a/mlir/test/Dialect/GPU/outlining.mlir +++ b/mlir/test/Dialect/GPU/outlining.mlir @@ -634,3 +634,29 @@ func.func @testNoAttributes() { } return } + +// ----- + +// Verify that nested `gpu.launch` ops are each outlined into their own `gpu.module`. + +// CHECK-LABEL: func.func @nested_launch( +// CHECK-SAME: %[[ARG0:.*]]: index) { +// CHECK: gpu.launch_func @nested_launch_kernel_0::@nested_launch_kernel blocks in (%[[ARG0]], %[[ARG0]], %[[ARG0]]) threads in (%[[ARG0]], %[[ARG0]], %[[ARG0]]) args(%[[ARG0]] : index) +// CHECK: gpu.module @nested_launch_kernel +// CHECK: gpu.func @nested_launch_kernel() kernel +// CHECK: "some_op" +// CHECK: gpu.module @nested_launch_kernel_0 +// CHECK: gpu.func @nested_launch_kernel(%[[VAL_0:.*]]: index) kernel +// CHECK: gpu.launch_func @nested_launch_kernel::@nested_launch_kernel blocks in (%[[VAL_0]], %[[VAL_0]], %[[VAL_0]]) threads in (%[[VAL_0]], %[[VAL_0]], %[[VAL_0]]) +func.func @nested_launch(%sz : index) { + gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %sz, %grid_y = %sz, %grid_z = %sz) + threads(%tx, %ty, %tz) in (%block_x = %sz, %block_y = %sz, %block_z = %sz) { + gpu.launch blocks(%bx1, %by1, %bz1) in (%grid_x1 = %sz, %grid_y1 = %sz, %grid_z1 = %sz) + threads(%tx1, %ty1, %tz1) in 
(%block_x1 = %sz, %block_y1 = %sz, %block_z1 = %sz) { + "some_op"(%bx1, %tx1) : (index, index) -> () + gpu.terminator + } + gpu.terminator + } + return +}