-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir][GPU] Generalize gpu.printf to not need gpu.module #161266
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][GPU] Generalize gpu.printf to not need gpu.module #161266
Conversation
In order to make the gpu.printf => [various LLVM calls] passes less order-dependent and to allow downstreams that don't use gpu.module to use gpu.printf, allow the lowerings for such prints to target the nearest `builtin.module` if a `gpu.module` cannot be found.
@llvm/pr-subscribers-mlir Author: Krzysztof Drewniak (krzysz00) Changes: In order to make the gpu.printf => [various LLVM calls] passes less order-dependent and to allow downstreams that don't use gpu.module to use gpu.printf, allow the lowerings for such prints to target the nearest `builtin.module` if a `gpu.module` cannot be found. Full diff: https://github.com/llvm/llvm-project/pull/161266.diff 3 Files Affected:
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index a73afbcb6474b..78bdbbfc61836 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -20,20 +20,20 @@
using namespace mlir;
-LLVM::LLVMFuncOp mlir::getOrDefineFunction(gpu::GPUModuleOp moduleOp,
- Location loc, OpBuilder &b,
- StringRef name,
+LLVM::LLVMFuncOp mlir::getOrDefineFunction(Operation *moduleOp, Location loc,
+ OpBuilder &b, StringRef name,
LLVM::LLVMFunctionType type) {
LLVM::LLVMFuncOp ret;
- if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) {
+ if (!(ret = dyn_cast_or_null<LLVM::LLVMFuncOp>(
+ SymbolTable::lookupSymbolIn(moduleOp, name)))) {
OpBuilder::InsertionGuard guard(b);
- b.setInsertionPointToStart(moduleOp.getBody());
+ b.setInsertionPointToStart(&moduleOp->getRegion(0).front());
ret = LLVM::LLVMFuncOp::create(b, loc, name, type, LLVM::Linkage::External);
}
return ret;
}
-static SmallString<16> getUniqueSymbolName(gpu::GPUModuleOp moduleOp,
+static SmallString<16> getUniqueSymbolName(Operation *moduleOp,
StringRef prefix) {
// Get a unique global name.
unsigned stringNumber = 0;
@@ -41,15 +41,16 @@ static SmallString<16> getUniqueSymbolName(gpu::GPUModuleOp moduleOp,
do {
stringConstName.clear();
(prefix + Twine(stringNumber++)).toStringRef(stringConstName);
- } while (moduleOp.lookupSymbol(stringConstName));
+ } while (SymbolTable::lookupSymbolIn(moduleOp, stringConstName));
return stringConstName;
}
-LLVM::GlobalOp
-mlir::getOrCreateStringConstant(OpBuilder &b, Location loc,
- gpu::GPUModuleOp moduleOp, Type llvmI8,
- StringRef namePrefix, StringRef str,
- uint64_t alignment, unsigned addrSpace) {
+LLVM::GlobalOp mlir::getOrCreateStringConstant(OpBuilder &b, Location loc,
+ Operation *moduleOp, Type llvmI8,
+ StringRef namePrefix,
+ StringRef str,
+ uint64_t alignment,
+ unsigned addrSpace) {
llvm::SmallString<20> nullTermStr(str);
nullTermStr.push_back('\0'); // Null terminate for C
auto globalType =
@@ -57,7 +58,7 @@ mlir::getOrCreateStringConstant(OpBuilder &b, Location loc,
StringAttr attr = b.getStringAttr(nullTermStr);
// Try to find existing global.
- for (auto globalOp : moduleOp.getOps<LLVM::GlobalOp>())
+ for (auto globalOp : moduleOp->getRegion(0).getOps<LLVM::GlobalOp>())
if (globalOp.getGlobalType() == globalType && globalOp.getConstant() &&
globalOp.getValueAttr() == attr &&
globalOp.getAlignment().value_or(0) == alignment &&
@@ -66,7 +67,7 @@ mlir::getOrCreateStringConstant(OpBuilder &b, Location loc,
// Not found: create new global.
OpBuilder::InsertionGuard guard(b);
- b.setInsertionPointToStart(moduleOp.getBody());
+ b.setInsertionPointToStart(&moduleOp->getRegion(0).front());
SmallString<16> name = getUniqueSymbolName(moduleOp, namePrefix);
return LLVM::GlobalOp::create(b, loc, globalType,
/*isConstant=*/true, LLVM::Linkage::Internal,
@@ -398,8 +399,15 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
mlir::Type llvmI64 = typeConverter->convertType(rewriter.getI64Type());
// Note: this is the GPUModule op, not the ModuleOp that surrounds it
// This ensures that global constants and declarations are placed within
- // the device code, not the host code
- auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
+ // the device code, not the host code.
+ Operation *moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
+ // However, if the `gpu.module` is already lowered or for compilers that don't
+ // use `gpu.module`, fall back to `builtin.module`.
+ if (!moduleOp)
+ moduleOp = gpuPrintfOp->getParentOfType<ModuleOp>();
+ if (!moduleOp)
+ return rewriter.notifyMatchFailure(gpuPrintfOp,
+ "Couldn't find a parent module");
auto ocklBegin =
getOrDefineFunction(moduleOp, loc, rewriter, "__ockl_printf_begin",
@@ -499,7 +507,14 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
// Note: this is the GPUModule op, not the ModuleOp that surrounds it
// This ensures that global constants and declarations are placed within
// the device code, not the host code
- auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
+ Operation *moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
+ // However, if the `gpu.module` is already lowered or for compilers that don't
+ // use `gpu.module`, fall back to `builtin.module`.
+ if (!moduleOp)
+ moduleOp = gpuPrintfOp->getParentOfType<ModuleOp>();
+ if (!moduleOp)
+ return rewriter.notifyMatchFailure(gpuPrintfOp,
+ "Couldn't find a parent module");
auto printfType =
LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType},
@@ -544,7 +559,14 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
// Note: this is the GPUModule op, not the ModuleOp that surrounds it
// This ensures that global constants and declarations are placed within
// the device code, not the host code
- auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
+ Operation *moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
+ // However, if the `gpu.module` is already lowered or for compilers that don't
+ // use `gpu.module`, fall back to `builtin.module`.
+ if (!moduleOp)
+ moduleOp = gpuPrintfOp->getParentOfType<ModuleOp>();
+ if (!moduleOp)
+ return rewriter.notifyMatchFailure(gpuPrintfOp,
+ "Couldn't find a parent module");
// Create a valid global location removing any metadata attached to the
// location as debug info metadata inside of a function cannot be used outside
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
index e17b06379988c..5eceb96e5234b 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
@@ -19,14 +19,14 @@ namespace mlir {
//===----------------------------------------------------------------------===//
/// Find or create an external function declaration in the given module.
-LLVM::LLVMFuncOp getOrDefineFunction(gpu::GPUModuleOp moduleOp, Location loc,
+LLVM::LLVMFuncOp getOrDefineFunction(Operation *moduleOp, Location loc,
OpBuilder &b, StringRef name,
LLVM::LLVMFunctionType type);
/// Create a global that contains the given string. If a global with the same
/// string already exists in the module, return that global.
LLVM::GlobalOp getOrCreateStringConstant(OpBuilder &b, Location loc,
- gpu::GPUModuleOp moduleOp, Type llvmI8,
+ Operation *moduleOp, Type llvmI8,
StringRef namePrefix, StringRef str,
uint64_t alignment = 0,
unsigned addrSpace = 0);
diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-hip.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-hip.mlir
index 2dc6a5ab2a86c..e23e4eb5ed3f2 100644
--- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-hip.mlir
+++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-hip.mlir
@@ -1,5 +1,6 @@
// RUN: mlir-opt %s -convert-gpu-to-rocdl='runtime=HIP' -split-input-file | FileCheck %s
+// CHECK-LABEL: gpu.module @test_module
gpu.module @test_module {
// CHECK-DAG: llvm.mlir.global internal constant @[[$PRINT_GLOBAL0:[A-Za-z0-9_]+]]("Hello, world\0A\00")
// CHECK-DAG: llvm.mlir.global internal constant @[[$PRINT_GLOBAL1:[A-Za-z0-9_]+]]("Hello: %d\0A\00")
|
@llvm/pr-subscribers-mlir-gpu Author: Krzysztof Drewniak (krzysz00) Changes: In order to make the gpu.printf => [various LLVM calls] passes less order-dependent and to allow downstreams that don't use gpu.module to use gpu.printf, allow the lowerings for such prints to target the nearest `builtin.module` if a `gpu.module` cannot be found. Full diff: https://github.com/llvm/llvm-project/pull/161266.diff 3 Files Affected:
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index a73afbcb6474b..78bdbbfc61836 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -20,20 +20,20 @@
using namespace mlir;
-LLVM::LLVMFuncOp mlir::getOrDefineFunction(gpu::GPUModuleOp moduleOp,
- Location loc, OpBuilder &b,
- StringRef name,
+LLVM::LLVMFuncOp mlir::getOrDefineFunction(Operation *moduleOp, Location loc,
+ OpBuilder &b, StringRef name,
LLVM::LLVMFunctionType type) {
LLVM::LLVMFuncOp ret;
- if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) {
+ if (!(ret = dyn_cast_or_null<LLVM::LLVMFuncOp>(
+ SymbolTable::lookupSymbolIn(moduleOp, name)))) {
OpBuilder::InsertionGuard guard(b);
- b.setInsertionPointToStart(moduleOp.getBody());
+ b.setInsertionPointToStart(&moduleOp->getRegion(0).front());
ret = LLVM::LLVMFuncOp::create(b, loc, name, type, LLVM::Linkage::External);
}
return ret;
}
-static SmallString<16> getUniqueSymbolName(gpu::GPUModuleOp moduleOp,
+static SmallString<16> getUniqueSymbolName(Operation *moduleOp,
StringRef prefix) {
// Get a unique global name.
unsigned stringNumber = 0;
@@ -41,15 +41,16 @@ static SmallString<16> getUniqueSymbolName(gpu::GPUModuleOp moduleOp,
do {
stringConstName.clear();
(prefix + Twine(stringNumber++)).toStringRef(stringConstName);
- } while (moduleOp.lookupSymbol(stringConstName));
+ } while (SymbolTable::lookupSymbolIn(moduleOp, stringConstName));
return stringConstName;
}
-LLVM::GlobalOp
-mlir::getOrCreateStringConstant(OpBuilder &b, Location loc,
- gpu::GPUModuleOp moduleOp, Type llvmI8,
- StringRef namePrefix, StringRef str,
- uint64_t alignment, unsigned addrSpace) {
+LLVM::GlobalOp mlir::getOrCreateStringConstant(OpBuilder &b, Location loc,
+ Operation *moduleOp, Type llvmI8,
+ StringRef namePrefix,
+ StringRef str,
+ uint64_t alignment,
+ unsigned addrSpace) {
llvm::SmallString<20> nullTermStr(str);
nullTermStr.push_back('\0'); // Null terminate for C
auto globalType =
@@ -57,7 +58,7 @@ mlir::getOrCreateStringConstant(OpBuilder &b, Location loc,
StringAttr attr = b.getStringAttr(nullTermStr);
// Try to find existing global.
- for (auto globalOp : moduleOp.getOps<LLVM::GlobalOp>())
+ for (auto globalOp : moduleOp->getRegion(0).getOps<LLVM::GlobalOp>())
if (globalOp.getGlobalType() == globalType && globalOp.getConstant() &&
globalOp.getValueAttr() == attr &&
globalOp.getAlignment().value_or(0) == alignment &&
@@ -66,7 +67,7 @@ mlir::getOrCreateStringConstant(OpBuilder &b, Location loc,
// Not found: create new global.
OpBuilder::InsertionGuard guard(b);
- b.setInsertionPointToStart(moduleOp.getBody());
+ b.setInsertionPointToStart(&moduleOp->getRegion(0).front());
SmallString<16> name = getUniqueSymbolName(moduleOp, namePrefix);
return LLVM::GlobalOp::create(b, loc, globalType,
/*isConstant=*/true, LLVM::Linkage::Internal,
@@ -398,8 +399,15 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
mlir::Type llvmI64 = typeConverter->convertType(rewriter.getI64Type());
// Note: this is the GPUModule op, not the ModuleOp that surrounds it
// This ensures that global constants and declarations are placed within
- // the device code, not the host code
- auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
+ // the device code, not the host code.
+ Operation *moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
+ // However, if the `gpu.module` is already lowered or for compilers that don't
+ // use `gpu.module`, fall back to `builtin.module`.
+ if (!moduleOp)
+ moduleOp = gpuPrintfOp->getParentOfType<ModuleOp>();
+ if (!moduleOp)
+ return rewriter.notifyMatchFailure(gpuPrintfOp,
+ "Couldn't find a parent module");
auto ocklBegin =
getOrDefineFunction(moduleOp, loc, rewriter, "__ockl_printf_begin",
@@ -499,7 +507,14 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
// Note: this is the GPUModule op, not the ModuleOp that surrounds it
// This ensures that global constants and declarations are placed within
// the device code, not the host code
- auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
+ Operation *moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
+ // However, if the `gpu.module` is already lowered or for compilers that don't
+ // use `gpu.module`, fall back to `builtin.module`.
+ if (!moduleOp)
+ moduleOp = gpuPrintfOp->getParentOfType<ModuleOp>();
+ if (!moduleOp)
+ return rewriter.notifyMatchFailure(gpuPrintfOp,
+ "Couldn't find a parent module");
auto printfType =
LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType},
@@ -544,7 +559,14 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
// Note: this is the GPUModule op, not the ModuleOp that surrounds it
// This ensures that global constants and declarations are placed within
// the device code, not the host code
- auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
+ Operation *moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
+ // However, if the `gpu.module` is already lowered or for compilers that don't
+ // use `gpu.module`, fall back to `builtin.module`.
+ if (!moduleOp)
+ moduleOp = gpuPrintfOp->getParentOfType<ModuleOp>();
+ if (!moduleOp)
+ return rewriter.notifyMatchFailure(gpuPrintfOp,
+ "Couldn't find a parent module");
// Create a valid global location removing any metadata attached to the
// location as debug info metadata inside of a function cannot be used outside
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
index e17b06379988c..5eceb96e5234b 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
@@ -19,14 +19,14 @@ namespace mlir {
//===----------------------------------------------------------------------===//
/// Find or create an external function declaration in the given module.
-LLVM::LLVMFuncOp getOrDefineFunction(gpu::GPUModuleOp moduleOp, Location loc,
+LLVM::LLVMFuncOp getOrDefineFunction(Operation *moduleOp, Location loc,
OpBuilder &b, StringRef name,
LLVM::LLVMFunctionType type);
/// Create a global that contains the given string. If a global with the same
/// string already exists in the module, return that global.
LLVM::GlobalOp getOrCreateStringConstant(OpBuilder &b, Location loc,
- gpu::GPUModuleOp moduleOp, Type llvmI8,
+ Operation *moduleOp, Type llvmI8,
StringRef namePrefix, StringRef str,
uint64_t alignment = 0,
unsigned addrSpace = 0);
diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-hip.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-hip.mlir
index 2dc6a5ab2a86c..e23e4eb5ed3f2 100644
--- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-hip.mlir
+++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-hip.mlir
@@ -1,5 +1,6 @@
// RUN: mlir-opt %s -convert-gpu-to-rocdl='runtime=HIP' -split-input-file | FileCheck %s
+// CHECK-LABEL: gpu.module @test_module
gpu.module @test_module {
// CHECK-DAG: llvm.mlir.global internal constant @[[$PRINT_GLOBAL0:[A-Za-z0-9_]+]]("Hello, world\0A\00")
// CHECK-DAG: llvm.mlir.global internal constant @[[$PRINT_GLOBAL1:[A-Za-z0-9_]+]]("Hello: %d\0A\00")
|
In order to make the gpu.printf => [various LLVM calls] passes less order-dependent and to allow downstreams that don't use gpu.module to use gpu.printf, allow the lowerings for such prints to target the nearest `SymbolTable` instead.
In order to make the gpu.printf => [various LLVM calls] passes less order-dependent and to allow downstreams that don't use gpu.module to use gpu.printf, allow the lowerings for such prints to target the nearest
SymbolTable
instead.