Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 33 additions & 31 deletions mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,44 +20,45 @@

using namespace mlir;

LLVM::LLVMFuncOp mlir::getOrDefineFunction(gpu::GPUModuleOp moduleOp,
Location loc, OpBuilder &b,
StringRef name,
LLVM::LLVMFuncOp mlir::getOrDefineFunction(Operation *moduleOp, Location loc,
OpBuilder &b, StringRef name,
LLVM::LLVMFunctionType type) {
LLVM::LLVMFuncOp ret;
if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) {
OpBuilder::InsertionGuard guard(b);
b.setInsertionPointToStart(moduleOp.getBody());
ret = LLVM::LLVMFuncOp::create(b, loc, name, type, LLVM::Linkage::External);
}
return ret;
auto existing = dyn_cast_or_null<LLVM::LLVMFuncOp>(
SymbolTable::lookupSymbolIn(moduleOp, name));
if (existing)
return existing;

OpBuilder::InsertionGuard guard(b);
b.setInsertionPointToStart(&moduleOp->getRegion(0).front());
return LLVM::LLVMFuncOp::create(b, loc, name, type, LLVM::Linkage::External);
}

static SmallString<16> getUniqueSymbolName(gpu::GPUModuleOp moduleOp,
static SmallString<16> getUniqueSymbolName(Operation *moduleOp,
StringRef prefix) {
// Get a unique global name.
unsigned stringNumber = 0;
SmallString<16> stringConstName;
do {
stringConstName.clear();
(prefix + Twine(stringNumber++)).toStringRef(stringConstName);
} while (moduleOp.lookupSymbol(stringConstName));
} while (SymbolTable::lookupSymbolIn(moduleOp, stringConstName));
return stringConstName;
}

LLVM::GlobalOp
mlir::getOrCreateStringConstant(OpBuilder &b, Location loc,
gpu::GPUModuleOp moduleOp, Type llvmI8,
StringRef namePrefix, StringRef str,
uint64_t alignment, unsigned addrSpace) {
LLVM::GlobalOp mlir::getOrCreateStringConstant(OpBuilder &b, Location loc,
Operation *moduleOp, Type llvmI8,
StringRef namePrefix,
StringRef str,
uint64_t alignment,
unsigned addrSpace) {
llvm::SmallString<20> nullTermStr(str);
nullTermStr.push_back('\0'); // Null terminate for C
auto globalType =
LLVM::LLVMArrayType::get(llvmI8, nullTermStr.size_in_bytes());
StringAttr attr = b.getStringAttr(nullTermStr);

// Try to find existing global.
for (auto globalOp : moduleOp.getOps<LLVM::GlobalOp>())
for (auto globalOp : moduleOp->getRegion(0).getOps<LLVM::GlobalOp>())
if (globalOp.getGlobalType() == globalType && globalOp.getConstant() &&
globalOp.getValueAttr() == attr &&
globalOp.getAlignment().value_or(0) == alignment &&
Expand All @@ -66,7 +67,7 @@ mlir::getOrCreateStringConstant(OpBuilder &b, Location loc,

// Not found: create new global.
OpBuilder::InsertionGuard guard(b);
b.setInsertionPointToStart(moduleOp.getBody());
b.setInsertionPointToStart(&moduleOp->getRegion(0).front());
SmallString<16> name = getUniqueSymbolName(moduleOp, namePrefix);
return LLVM::GlobalOp::create(b, loc, globalType,
/*isConstant=*/true, LLVM::Linkage::Internal,
Expand Down Expand Up @@ -396,10 +397,11 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
mlir::Type llvmI32 = typeConverter->convertType(rewriter.getI32Type());
mlir::Type llvmI64 = typeConverter->convertType(rewriter.getI64Type());
// Note: this is the GPUModule op, not the ModuleOp that surrounds it
// This ensures that global constants and declarations are placed within
// the device code, not the host code
auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();

Operation *moduleOp = gpuPrintfOp->getParentWithTrait<OpTrait::SymbolTable>();
if (!moduleOp)
return rewriter.notifyMatchFailure(gpuPrintfOp,
"Couldn't find a parent module");

auto ocklBegin =
getOrDefineFunction(moduleOp, loc, rewriter, "__ockl_printf_begin",
Expand Down Expand Up @@ -496,10 +498,10 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
mlir::Type ptrType =
LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);

// Note: this is the GPUModule op, not the ModuleOp that surrounds it
// This ensures that global constants and declarations are placed within
// the device code, not the host code
auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
Operation *moduleOp = gpuPrintfOp->getParentWithTrait<OpTrait::SymbolTable>();
if (!moduleOp)
return rewriter.notifyMatchFailure(gpuPrintfOp,
"Couldn't find a parent module");

auto printfType =
LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType},
Expand Down Expand Up @@ -541,10 +543,10 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8));
mlir::Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());

// Note: this is the GPUModule op, not the ModuleOp that surrounds it
// This ensures that global constants and declarations are placed within
// the device code, not the host code
auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
Operation *moduleOp = gpuPrintfOp->getParentWithTrait<OpTrait::SymbolTable>();
if (!moduleOp)
return rewriter.notifyMatchFailure(gpuPrintfOp,
"Couldn't find a parent module");

// Create a valid global location removing any metadata attached to the
// location as debug info metadata inside of a function cannot be used outside
Expand Down
7 changes: 5 additions & 2 deletions mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,18 @@ namespace mlir {
// Helper Functions
//===----------------------------------------------------------------------===//

/// Note that these functions don't take a `SymbolTable` because GPU module
/// lowerings can have name collisions as an intermediate state.

/// Find or create an external function declaration in the given module.
LLVM::LLVMFuncOp getOrDefineFunction(gpu::GPUModuleOp moduleOp, Location loc,
LLVM::LLVMFuncOp getOrDefineFunction(Operation *moduleOp, Location loc,
OpBuilder &b, StringRef name,
LLVM::LLVMFunctionType type);

/// Create a global that contains the given string. If a global with the same
/// string already exists in the module, return that global.
LLVM::GlobalOp getOrCreateStringConstant(OpBuilder &b, Location loc,
gpu::GPUModuleOp moduleOp, Type llvmI8,
Operation *moduleOp, Type llvmI8,
StringRef namePrefix, StringRef str,
uint64_t alignment = 0,
unsigned addrSpace = 0);
Expand Down
36 changes: 36 additions & 0 deletions mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-hip.mlir
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// RUN: mlir-opt %s -convert-gpu-to-rocdl='runtime=HIP' -split-input-file | FileCheck %s

// CHECK-LABEL: gpu.module @test_module
gpu.module @test_module {
// CHECK-DAG: llvm.mlir.global internal constant @[[$PRINT_GLOBAL0:[A-Za-z0-9_]+]]("Hello, world\0A\00")
// CHECK-DAG: llvm.mlir.global internal constant @[[$PRINT_GLOBAL1:[A-Za-z0-9_]+]]("Hello: %d\0A\00")
Expand Down Expand Up @@ -40,3 +41,38 @@ gpu.module @test_module {
gpu.return
}
}

// -----

// The bulitin.module we're targetting is wrapped in a fake gpu.module
// because the convert-gpu-to-rocdl pass only runs an `gpu.module` ops,
// even though the printf patterns could run in other contexts.

// CHECK-LABEL: gpu.module @fake_gpu_module_for_test
// CHECK-LABEL: builtin.module @test_module
gpu.module @fake_gpu_module_for_test {
builtin.module @test_module {
// CHECK-DAG: llvm.mlir.global internal constant @[[$PRINT_GLOBAL1:[A-Za-z0-9_]+]]("Hello: %d\0A\00")
// CHECK-DAG: llvm.func @__ockl_printf_append_args(i64, i32, i64, i64, i64, i64, i64, i64, i64, i32) -> i64
// CHECK-DAG: llvm.func @__ockl_printf_append_string_n(i64, !llvm.ptr, i64, i32) -> i64
// CHECK-DAG: llvm.func @__ockl_printf_begin(i64) -> i64

// CHECK-LABEL: llvm.func @test_printf
// CHECK: (%[[ARG0:.*]]: i32)
llvm.func @test_printf(%arg0: i32) {
// CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
// CHECK-NEXT: %[[DESC0:.*]] = llvm.call @__ockl_printf_begin(%0) : (i64) -> i64
// CHECK-NEXT: %[[FORMATSTR:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL1]] : !llvm.ptr
// CHECK-NEXT: %[[FORMATSTART:.*]] = llvm.getelementptr %[[FORMATSTR]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<11 x i8>
// CHECK-NEXT: %[[FORMATLEN:.*]] = llvm.mlir.constant(11 : i64) : i64
// CHECK-NEXT: %[[ISLAST:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK-NEXT: %[[ISNTLAST:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK-NEXT: %[[DESC1:.*]] = llvm.call @__ockl_printf_append_string_n(%[[DESC0]], %[[FORMATSTART]], %[[FORMATLEN]], %[[ISNTLAST]]) : (i64, !llvm.ptr, i64, i32) -> i64
// CHECK-NEXT: %[[NARGS1:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK-NEXT: %[[ARG0_64:.*]] = llvm.zext %[[ARG0]] : i32 to i64
// CHECK-NEXT: %{{.*}} = llvm.call @__ockl_printf_append_args(%[[DESC1]], %[[NARGS1]], %[[ARG0_64]], %[[CST0]], %[[CST0]], %[[CST0]], %[[CST0]], %[[CST0]], %[[CST0]], %[[ISLAST]]) : (i64, i32, i64, i64, i64, i64, i64, i64, i64, i32) -> i64
gpu.printf "Hello: %d\n", %arg0 : i32
llvm.return
}
}
}