Skip to content

Commit

Permalink
ConvertLaunchFuncToCudaCalls: use LLVM dialect globals
Browse files Browse the repository at this point in the history
This conversion has been using a stack-allocated array of i8 to store the
null-terminated kernel name in order to pass it to the CUDA wrappers expecting
a C string because the LLVM dialect was missing support for globals.  Now that
the support is introduced, use a global instead.

Refactor global string construction from GenerateCubinAccessors into a common
utility function living in the LLVM namespace.

PiperOrigin-RevId: 264382489
  • Loading branch information
ftynse authored and tensorflower-gardener committed Aug 20, 2019
1 parent 0d82a29 commit 006fcce
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 61 deletions.
11 changes: 10 additions & 1 deletion mlir/include/mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h
Expand Up @@ -17,15 +17,23 @@
#ifndef MLIR_CONVERSION_GPUTOCUDA_GPUTOCUDAPASS_H_
#define MLIR_CONVERSION_GPUTOCUDA_GPUTOCUDAPASS_H_

#include "mlir/Support/LLVM.h"
#include <functional>
#include <memory>
#include <string>
#include <vector>

namespace mlir {

class ModulePassBase;
class FuncOp;
class Location;
class ModulePassBase;
class OpBuilder;
class Value;

namespace LLVM {
class LLVMDialect;
}

using OwnedCubin = std::unique_ptr<std::vector<char>>;
using CubinGenerator = std::function<OwnedCubin(const std::string &, FuncOp &)>;
Expand Down Expand Up @@ -53,6 +61,7 @@ std::unique_ptr<ModulePassBase> createConvertGpuLaunchFuncToCudaCallsPass();
/// Creates a pass to augment a module with getter functions for all contained
/// cubins as encoded via the 'nvvm.cubin' attribute.
std::unique_ptr<ModulePassBase> createGenerateCubinAccessorPass();

} // namespace mlir

#endif // MLIR_CONVERSION_GPUTOCUDA_GPUTOCUDAPASS_H_
7 changes: 7 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
Expand Up @@ -179,6 +179,13 @@ class LLVMDialect : public Dialect {
std::unique_ptr<detail::LLVMDialectImpl> impl;
};

/// Create an LLVM global named "name" containing the string "value" in the
/// module that contains the insertion point of "builder". Obtain the address of
/// that global and use it to compute the address of the first character in the
/// string (these operations are inserted at the builder insertion point), and
/// return that address. The global holds exactly the characters of "value"; no
/// terminator is appended, so callers that need a C string must include the
/// trailing '\0' in "value" themselves.
Value *createGlobalString(Location loc, OpBuilder &builder, StringRef name,
StringRef value, LLVM::LLVMDialect *llvmDialect);

} // end namespace LLVM
} // end namespace mlir

Expand Down
54 changes: 20 additions & 34 deletions mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp
Expand Up @@ -39,6 +39,7 @@
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;

Expand Down Expand Up @@ -253,43 +254,28 @@ GpuLaunchFuncToCudaCallsPass::setupParamsArray(gpu::LaunchFuncOp launchOp,
return array;
}

// Generates LLVM IR that produces a value representing the name of the
// given kernel function. The generated IR consists essentially of the
// following:
// Generates an LLVM IR dialect global that contains the name of the given
// kernel function as a C string, and returns a pointer to its beginning.
// The code is essentially:
//
// %0 = alloca(strlen(name) + 1)
// %0[0] = constant name[0]
// ...
// %0[n] = constant name[n]
// %0[n+1] = 0
// llvm.global constant @kernel_name("function_name\00")
// func(...) {
// %0 = llvm.addressof @kernel_name
// %1 = llvm.constant (0 : index)
// %2 = llvm.getelementptr %0[%1, %1] : !llvm<"i8*">
// }
Value *GpuLaunchFuncToCudaCallsPass::generateKernelNameConstant(
FuncOp kernelFunction, Location &loc, OpBuilder &builder) {
// NOTE(review): this span is taken from a rendered diff without +/- markers,
// so two bodies appear back to back. The statements up to and including
// `return kernelName;` build the name in a stack allocation (presumably the
// removed implementation); the statements after it build a global through
// LLVM::createGlobalString (presumably the replacement). Confirm against the
// repository before relying on either half.
// TODO(herhut): Make this a constant once this is supported.
// i32 constant holding the length of the kernel name plus one extra element
// for the terminating zero.
auto kernelNameSize = builder.create<LLVM::ConstantOp>(
loc, getInt32Type(),
builder.getI32IntegerAttr(kernelFunction.getName().size() + 1));
// Stack buffer of that many elements, allocated with alignment 1.
auto kernelName = builder.create<LLVM::AllocaOp>(
loc, getPointerType(), kernelNameSize, /*alignment=*/1);
// Emit one (index constant, GEP, i8 constant, store) sequence per character
// of the kernel name, writing the character at its own position.
for (auto byte : llvm::enumerate(kernelFunction.getName())) {
auto index = builder.create<LLVM::ConstantOp>(
loc, getInt32Type(), builder.getI32IntegerAttr(byte.index()));
auto gep = builder.create<LLVM::GEPOp>(loc, getPointerType(), kernelName,
ArrayRef<Value *>{index});
auto value = builder.create<LLVM::ConstantOp>(
loc, getInt8Type(),
builder.getIntegerAttr(builder.getIntegerType(8), byte.value()));
builder.create<LLVM::StoreOp>(loc, value, gep);
}
// Add trailing zero to terminate string.
// The zero is stored at index size(), i.e. right after the last character.
auto index = builder.create<LLVM::ConstantOp>(
loc, getInt32Type(),
builder.getI32IntegerAttr(kernelFunction.getName().size()));
auto gep = builder.create<LLVM::GEPOp>(loc, getPointerType(), kernelName,
ArrayRef<Value *>{index});
auto value = builder.create<LLVM::ConstantOp>(
loc, getInt8Type(), builder.getIntegerAttr(builder.getIntegerType(8), 0));
builder.create<LLVM::StoreOp>(loc, value, gep);
return kernelName;
// Make sure the trailing zero is included in the constant.
// The name is copied into a local buffer so that the '\0' can be appended.
std::vector<char> kernelName(kernelFunction.getName().begin(),
kernelFunction.getName().end());
kernelName.push_back('\0');

// The global is named "<kernel function name>_kernel_name".
std::string globalName =
llvm::formatv("{0}_kernel_name", kernelFunction.getName());
// The StringRef is built with an explicit length so that the embedded '\0'
// is part of the value handed to createGlobalString.
return LLVM::createGlobalString(
loc, builder, globalName, StringRef(kernelName.data(), kernelName.size()),
llvmDialect);
}

// Emits LLVM IR to launch a kernel function. Expects the module that contains
Expand Down
27 changes: 9 additions & 18 deletions mlir/lib/Conversion/GPUToCUDA/GenerateCubinAccessors.cpp
Expand Up @@ -20,6 +20,7 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Attributes.h"
Expand Down Expand Up @@ -63,35 +64,25 @@ class GpuGenerateCubinAccessorsPass
auto module = orig.getParentOfType<ModuleOp>();
assert(module && "function must belong to a module");

// Create a global at the top of the module.
OpBuilder moduleBuilder(module.getBody(), module.getBody()->begin());
auto type = LLVM::LLVMType::getArrayTy(
LLVM::LLVMType::getInt8Ty(llvmDialect), blob.getValue().size());
nameBuffer.append(kCubinStorageSuffix);
auto cubinGlobalString = moduleBuilder.create<LLVM::GlobalOp>(
loc, type, /*isConstant=*/true, StringRef(nameBuffer), blob);

// Insert the getter function just after the original function.
OpBuilder moduleBuilder(module.getBody(), module.getBody()->begin());
moduleBuilder.setInsertionPoint(orig.getOperation()->getNextNode());
auto getterType = moduleBuilder.getFunctionType(
llvm::None, LLVM::LLVMType::getInt8PtrTy(llvmDialect));
// Drop the storage suffix before appending the getter suffix.
nameBuffer.resize(orig.getName().size());
nameBuffer.append(kCubinGetterSuffix);
auto result = moduleBuilder.create<FuncOp>(
loc, StringRef(nameBuffer), getterType, ArrayRef<NamedAttribute>());
Block *entryBlock = result.addEntryBlock();

// Drop the getter suffix before appending the storage suffix.
nameBuffer.resize(orig.getName().size());
nameBuffer.append(kCubinStorageSuffix);

// Obtain the address of the first character of the global string containing
// the cubin and return from the getter (addressof will return [? x i8]*).
// the cubin and return from the getter.
OpBuilder builder(entryBlock);
Value *cubinGlobalStringPtr =
builder.create<LLVM::AddressOfOp>(loc, cubinGlobalString);
Value *cst0 = builder.create<LLVM::ConstantOp>(
loc, getIndexType(), builder.getIntegerAttr(builder.getIndexType(), 0));
Value *startPtr = builder.create<LLVM::GEPOp>(
loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), cubinGlobalStringPtr,
ArrayRef<Value *>({cst0, cst0}));
Value *startPtr = LLVM::createGlobalString(
loc, builder, StringRef(nameBuffer), blob.getValue(), llvmDialect);
builder.create<LLVM::ReturnOp>(loc, startPtr);

// Store the name of the getter on the function for easier lookup.
Expand Down
31 changes: 31 additions & 0 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
Expand Up @@ -1397,3 +1397,34 @@ LLVMType LLVMType::getVectorTy(LLVMType elementType, unsigned numElements) {
LLVMType LLVMType::getVoidTy(LLVMDialect *dialect) {
return dialect->impl->voidTy;
}

//===----------------------------------------------------------------------===//
// Utility functions.
//===----------------------------------------------------------------------===//

/// Materializes `value` as a constant LLVM dialect global named `name` and
/// returns a pointer to its first character.
///
/// The global is emitted through a separate builder constructed on the body
/// region of the ModuleOp that encloses the insertion point of `builder`. The
/// addressof, zero constant and getelementptr operations that compute the
/// returned pointer are emitted at the insertion point of `builder` itself.
/// The global has type [value.size() x i8]; nothing is appended to `value`.
Value *mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder,
                                      StringRef name, StringRef value,
                                      LLVM::LLVMDialect *llvmDialect) {
  // Locate the module that encloses the current insertion point.
  auto *insertionBlock = builder.getInsertionBlock();
  assert(insertionBlock && insertionBlock->getParentOp() &&
         "expected builder to point to a block constained in an op");
  auto *enclosingOp = insertionBlock->getParentOp();
  auto module = enclosingOp->getParentOfType<ModuleOp>();
  assert(module && "builder points to an op outside of a module");

  // Emit the constant global, typed as an i8 array with one element per
  // character of `value`, using a builder set up on the module body region.
  auto i8Type = LLVM::LLVMType::getInt8Ty(llvmDialect);
  auto arrayType = LLVM::LLVMType::getArrayTy(i8Type, value.size());
  auto contents = builder.getStringAttr(value);
  OpBuilder moduleBuilder(module.getBodyRegion());
  auto global = moduleBuilder.create<LLVM::GlobalOp>(
      loc, arrayType, /*isConstant=*/true, name, contents);

  // Take the address of the global and index it with [0, 0] to obtain an i8*
  // to the first character.
  Value *globalPtr = builder.create<LLVM::AddressOfOp>(loc, global);
  auto i64Type = LLVM::LLVMType::getInt64Ty(llvmDialect);
  auto zeroAttr = builder.getIntegerAttr(builder.getIndexType(), 0);
  Value *cst0 = builder.create<LLVM::ConstantOp>(loc, i64Type, zeroAttr);
  Value *zeroIndices[] = {cst0, cst0};
  auto i8PtrType = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
  return builder.create<LLVM::GEPOp>(loc, i8PtrType, globalPtr,
                                     ArrayRef<Value *>(zeroIndices));
}
18 changes: 10 additions & 8 deletions mlir/test/Conversion/GPUToCUDA/lower-launch-func-to-cuda.mlir
@@ -1,5 +1,7 @@
// RUN: mlir-opt %s --launch-func-to-cuda | FileCheck %s

// CHECK: llvm.global constant @[[kernel_name:.*]]("kernel\00")

func @cubin_getter() -> !llvm<"i8*">

func @kernel(!llvm.float, !llvm<"float*">)
Expand All @@ -11,15 +13,15 @@ func @foo() {
%1 = "op"() : () -> (!llvm<"float*">)
%cst = constant 8 : index

// CHECK: %5 = llvm.alloca %4 x !llvm<"i8*"> : (!llvm.i32) -> !llvm<"i8**">
// CHECK: %6 = llvm.call @mcuModuleLoad(%5, %3) : (!llvm<"i8**">, !llvm<"i8*">) -> !llvm.i32
// CHECK: %32 = llvm.alloca %31 x !llvm<"i8*"> : (!llvm.i32) -> !llvm<"i8**">
// CHECK: %33 = llvm.call @mcuModuleGetFunction(%32, %7, %9) : (!llvm<"i8**">, !llvm<"i8*">, !llvm<"i8*">) -> !llvm.i32
// CHECK: %34 = llvm.call @mcuGetStreamHelper() : () -> !llvm<"i8*">
// CHECK: %48 = llvm.call @mcuLaunchKernel(%35, %c8, %c8, %c8, %c8, %c8, %c8, %2, %34, %38, %47) : (!llvm<"i8*">, index, index, index, index, index, index, !llvm.i32, !llvm<"i8*">, !llvm<"i8**">, !llvm<"i8**">) -> !llvm.i32
// CHECK: %49 = llvm.call @mcuStreamSynchronize(%34) : (!llvm<"i8*">) -> !llvm.i32
// CHECK: [[module_ptr:%.*]] = llvm.alloca {{.*}} x !llvm<"i8*"> : (!llvm.i32) -> !llvm<"i8**">
// CHECK: llvm.call @mcuModuleLoad([[module_ptr]], {{.*}}) : (!llvm<"i8**">, !llvm<"i8*">) -> !llvm.i32
// CHECK: [[func_ptr:%.*]] = llvm.alloca {{.*}} x !llvm<"i8*"> : (!llvm.i32) -> !llvm<"i8**">
// CHECK: llvm.call @mcuModuleGetFunction([[func_ptr]], {{.*}}, {{.*}}) : (!llvm<"i8**">, !llvm<"i8*">, !llvm<"i8*">) -> !llvm.i32
// CHECK: llvm.call @mcuGetStreamHelper
// CHECK: llvm.call @mcuLaunchKernel
// CHECK: llvm.call @mcuStreamSynchronize
"gpu.launch_func"(%cst, %cst, %cst, %cst, %cst, %cst, %0, %1) { kernel = @kernel }
: (index, index, index, index, index, index, !llvm.float, !llvm<"float*">) -> ()

return
}
}

0 comments on commit 006fcce

Please sign in to comment.