diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp index a73afbcb6474b..2285d2695db4e 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp @@ -20,20 +20,20 @@ using namespace mlir; -LLVM::LLVMFuncOp mlir::getOrDefineFunction(gpu::GPUModuleOp moduleOp, - Location loc, OpBuilder &b, - StringRef name, +LLVM::LLVMFuncOp mlir::getOrDefineFunction(Operation *moduleOp, Location loc, + OpBuilder &b, StringRef name, LLVM::LLVMFunctionType type) { - LLVM::LLVMFuncOp ret; - if (!(ret = moduleOp.template lookupSymbol(name))) { - OpBuilder::InsertionGuard guard(b); - b.setInsertionPointToStart(moduleOp.getBody()); - ret = LLVM::LLVMFuncOp::create(b, loc, name, type, LLVM::Linkage::External); - } - return ret; + auto existing = dyn_cast_or_null( + SymbolTable::lookupSymbolIn(moduleOp, name)); + if (existing) + return existing; + + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToStart(&moduleOp->getRegion(0).front()); + return LLVM::LLVMFuncOp::create(b, loc, name, type, LLVM::Linkage::External); } -static SmallString<16> getUniqueSymbolName(gpu::GPUModuleOp moduleOp, +static SmallString<16> getUniqueSymbolName(Operation *moduleOp, StringRef prefix) { // Get a unique global name. 
unsigned stringNumber = 0; @@ -41,15 +41,16 @@ static SmallString<16> getUniqueSymbolName(gpu::GPUModuleOp moduleOp, do { stringConstName.clear(); (prefix + Twine(stringNumber++)).toStringRef(stringConstName); - } while (moduleOp.lookupSymbol(stringConstName)); + } while (SymbolTable::lookupSymbolIn(moduleOp, stringConstName)); return stringConstName; } -LLVM::GlobalOp -mlir::getOrCreateStringConstant(OpBuilder &b, Location loc, - gpu::GPUModuleOp moduleOp, Type llvmI8, - StringRef namePrefix, StringRef str, - uint64_t alignment, unsigned addrSpace) { +LLVM::GlobalOp mlir::getOrCreateStringConstant(OpBuilder &b, Location loc, + Operation *moduleOp, Type llvmI8, + StringRef namePrefix, + StringRef str, + uint64_t alignment, + unsigned addrSpace) { llvm::SmallString<20> nullTermStr(str); nullTermStr.push_back('\0'); // Null terminate for C auto globalType = @@ -57,7 +58,7 @@ mlir::getOrCreateStringConstant(OpBuilder &b, Location loc, StringAttr attr = b.getStringAttr(nullTermStr); // Try to find existing global. - for (auto globalOp : moduleOp.getOps()) + for (auto globalOp : moduleOp->getRegion(0).getOps()) if (globalOp.getGlobalType() == globalType && globalOp.getConstant() && globalOp.getValueAttr() == attr && globalOp.getAlignment().value_or(0) == alignment && @@ -66,7 +67,7 @@ mlir::getOrCreateStringConstant(OpBuilder &b, Location loc, // Not found: create new global. 
OpBuilder::InsertionGuard guard(b); - b.setInsertionPointToStart(moduleOp.getBody()); + b.setInsertionPointToStart(&moduleOp->getRegion(0).front()); SmallString<16> name = getUniqueSymbolName(moduleOp, namePrefix); return LLVM::GlobalOp::create(b, loc, globalType, /*isConstant=*/true, LLVM::Linkage::Internal, @@ -396,10 +397,11 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite( auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext()); mlir::Type llvmI32 = typeConverter->convertType(rewriter.getI32Type()); mlir::Type llvmI64 = typeConverter->convertType(rewriter.getI64Type()); - // Note: this is the GPUModule op, not the ModuleOp that surrounds it - // This ensures that global constants and declarations are placed within - // the device code, not the host code - auto moduleOp = gpuPrintfOp->getParentOfType(); + + Operation *moduleOp = gpuPrintfOp->getParentWithTrait(); + if (!moduleOp) + return rewriter.notifyMatchFailure(gpuPrintfOp, + "Couldn't find a parent module"); auto ocklBegin = getOrDefineFunction(moduleOp, loc, rewriter, "__ockl_printf_begin", @@ -496,10 +498,10 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite( mlir::Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace); - // Note: this is the GPUModule op, not the ModuleOp that surrounds it - // This ensures that global constants and declarations are placed within - // the device code, not the host code - auto moduleOp = gpuPrintfOp->getParentOfType(); + Operation *moduleOp = gpuPrintfOp->getParentWithTrait(); + if (!moduleOp) + return rewriter.notifyMatchFailure(gpuPrintfOp, + "Couldn't find a parent module"); auto printfType = LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType}, @@ -541,10 +543,10 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite( mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8)); mlir::Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext()); - // Note: this is the GPUModule 
op, not the ModuleOp that surrounds it - // This ensures that global constants and declarations are placed within - // the device code, not the host code - auto moduleOp = gpuPrintfOp->getParentOfType(); + Operation *moduleOp = gpuPrintfOp->getParentWithTrait(); + if (!moduleOp) + return rewriter.notifyMatchFailure(gpuPrintfOp, + "Couldn't find a parent module"); // Create a valid global location removing any metadata attached to the // location as debug info metadata inside of a function cannot be used outside diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h index e17b06379988c..66d3bb40a8f5a 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h +++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h @@ -18,15 +18,18 @@ namespace mlir { // Helper Functions //===----------------------------------------------------------------------===// +/// Note that these functions don't take a `SymbolTable` because GPU module +/// lowerings can have name collisions as an intermediate state. + /// Find or create an external function declaration in the given module. -LLVM::LLVMFuncOp getOrDefineFunction(gpu::GPUModuleOp moduleOp, Location loc, +LLVM::LLVMFuncOp getOrDefineFunction(Operation *moduleOp, Location loc, OpBuilder &b, StringRef name, LLVM::LLVMFunctionType type); /// Create a global that contains the given string. If a global with the same /// string already exists in the module, return that global. 
LLVM::GlobalOp getOrCreateStringConstant(OpBuilder &b, Location loc, - gpu::GPUModuleOp moduleOp, Type llvmI8, + Operation *moduleOp, Type llvmI8, StringRef namePrefix, StringRef str, uint64_t alignment = 0, unsigned addrSpace = 0); diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-hip.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-hip.mlir index 2dc6a5ab2a86c..32da31202b688 100644 --- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-hip.mlir +++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-hip.mlir @@ -1,5 +1,6 @@ // RUN: mlir-opt %s -convert-gpu-to-rocdl='runtime=HIP' -split-input-file | FileCheck %s +// CHECK-LABEL: gpu.module @test_module gpu.module @test_module { // CHECK-DAG: llvm.mlir.global internal constant @[[$PRINT_GLOBAL0:[A-Za-z0-9_]+]]("Hello, world\0A\00") // CHECK-DAG: llvm.mlir.global internal constant @[[$PRINT_GLOBAL1:[A-Za-z0-9_]+]]("Hello: %d\0A\00") @@ -40,3 +41,38 @@ gpu.module @test_module { gpu.return } } + +// ----- + +// The builtin.module we're targeting is wrapped in a fake gpu.module +// because the convert-gpu-to-rocdl pass only runs on `gpu.module` ops, +// even though the printf patterns could run in other contexts. 
+ +// CHECK-LABEL: gpu.module @fake_gpu_module_for_test +// CHECK-LABEL: builtin.module @test_module +gpu.module @fake_gpu_module_for_test { +builtin.module @test_module { + // CHECK-DAG: llvm.mlir.global internal constant @[[$PRINT_GLOBAL1:[A-Za-z0-9_]+]]("Hello: %d\0A\00") + // CHECK-DAG: llvm.func @__ockl_printf_append_args(i64, i32, i64, i64, i64, i64, i64, i64, i64, i32) -> i64 + // CHECK-DAG: llvm.func @__ockl_printf_append_string_n(i64, !llvm.ptr, i64, i32) -> i64 + // CHECK-DAG: llvm.func @__ockl_printf_begin(i64) -> i64 + + // CHECK-LABEL: llvm.func @test_printf + // CHECK: (%[[ARG0:.*]]: i32) + llvm.func @test_printf(%arg0: i32) { + // CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64 + // CHECK-NEXT: %[[DESC0:.*]] = llvm.call @__ockl_printf_begin(%0) : (i64) -> i64 + // CHECK-NEXT: %[[FORMATSTR:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL1]] : !llvm.ptr + // CHECK-NEXT: %[[FORMATSTART:.*]] = llvm.getelementptr %[[FORMATSTR]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<11 x i8> + // CHECK-NEXT: %[[FORMATLEN:.*]] = llvm.mlir.constant(11 : i64) : i64 + // CHECK-NEXT: %[[ISLAST:.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK-NEXT: %[[ISNTLAST:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK-NEXT: %[[DESC1:.*]] = llvm.call @__ockl_printf_append_string_n(%[[DESC0]], %[[FORMATSTART]], %[[FORMATLEN]], %[[ISNTLAST]]) : (i64, !llvm.ptr, i64, i32) -> i64 + // CHECK-NEXT: %[[NARGS1:.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK-NEXT: %[[ARG0_64:.*]] = llvm.zext %[[ARG0]] : i32 to i64 + // CHECK-NEXT: %{{.*}} = llvm.call @__ockl_printf_append_args(%[[DESC1]], %[[NARGS1]], %[[ARG0_64]], %[[CST0]], %[[CST0]], %[[CST0]], %[[CST0]], %[[CST0]], %[[CST0]], %[[ISLAST]]) : (i64, i32, i64, i64, i64, i64, i64, i64, i64, i32) -> i64 + gpu.printf "Hello: %d\n", %arg0 : i32 + llvm.return + } +} +}