Skip to content

Commit

Permalink
[MLIR][mlir-spirv-cpu-runner] A pass to emulate a call to kernel in LLVM
Browse files Browse the repository at this point in the history
This patch introduces a pass for running
`mlir-spirv-cpu-runner` - LowerHostCodeToLLVMPass.

This pass emulates `gpu.launch_func` call in LLVM dialect and lowers
the host module code to LLVM. It removes the `gpu.module`, creates a
sequence of global variables that are later linked to the varables
in the kernel module, as well as a series of copies to/from
them to emulate the memory transfer to/from the host or to/from the
device sides. It also converts the remaining Standard dialect into
LLVM dialect, emitting C wrappers.

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D86112
  • Loading branch information
georgemitenkov authored and antiagainst committed Oct 26, 2020
1 parent efa9aaa commit cae4067
Show file tree
Hide file tree
Showing 5 changed files with 373 additions and 0 deletions.
6 changes: 6 additions & 0 deletions mlir/include/mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,12 @@ def GpuToLLVMConversionPass : Pass<"gpu-to-llvm", "ModuleOp"> {
];
}

def LowerHostCodeToLLVM : Pass<"lower-host-to-llvm", "ModuleOp"> {
let summary = "Lowers the host module code and `gpu.launch_func` to LLVM";
let constructor = "mlir::createLowerHostCodeToLLVMPass()";
let dependentDialects = ["LLVM::LLVMDialect"];
}

//===----------------------------------------------------------------------===//
// GPUToNVVM
//===----------------------------------------------------------------------===//
Expand Down
10 changes: 10 additions & 0 deletions mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,16 @@ class ModuleOp;
template <typename T>
class OperationPass;

/// Creates a pass to emulate `gpu.launch_func` call in LLVM dialect and lower
/// the host module code to LLVM.
///
/// This transformation creates a sequence of global variables that are later
/// linked to the varables in the kernel module, and a series of copies to/from
/// them to emulate the memory transfer from the host or to the device sides. It
/// also converts the remaining Standard dialect into LLVM dialect, emitting C
/// wrappers.
std::unique_ptr<OperationPass<ModuleOp>> createLowerHostCodeToLLVMPass();

/// Creates a pass to convert SPIR-V operations to the LLVMIR dialect.
std::unique_ptr<OperationPass<ModuleOp>> createConvertSPIRVToLLVMPass();

Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/Conversion/SPIRVToLLVM/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
add_mlir_conversion_library(MLIRSPIRVToLLVM
ConvertLaunchFuncToLLVMCalls.cpp
ConvertSPIRVToLLVM.cpp
ConvertSPIRVToLLVMPass.cpp

Expand All @@ -10,6 +11,7 @@ add_mlir_conversion_library(MLIRSPIRVToLLVM
intrinsics_gen

LINK_LIBS PUBLIC
MLIRGPU
MLIRSPIRV
MLIRLLVMIR
MLIRStandardToLLVM
Expand Down
307 changes: 307 additions & 0 deletions mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,307 @@
//===- ConvertLaunchFuncToLLVMCalls.cpp - MLIR GPU launch to LLVM pass ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements passes to convert `gpu.launch_func` op into a sequence
// of LLVM calls that emulate the host and device sides.
//
//===----------------------------------------------------------------------===//

#include "../PassDetail.h"
#include "mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h"
#include "mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Transforms/DialectConversion.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;

static constexpr const char kSPIRVModule[] = "__spv__";

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Returns the string name of the `DescriptorSet` decoration.
static std::string descriptorSetName() {
return llvm::convertToSnakeFromCamelCase(
stringifyDecoration(spirv::Decoration::DescriptorSet));
}

/// Returns the string name of the `Binding` decoration.
static std::string bindingName() {
return llvm::convertToSnakeFromCamelCase(
stringifyDecoration(spirv::Decoration::Binding));
}

/// Calculates the index of the kernel's operand that is represented by the
/// given global variable with the `bind` attribute. We assume that the index of
/// each kernel's operand is mapped to (descriptorSet, binding) by the map:
/// i -> (0, i)
/// which is implemented under `LowerABIAttributesPass`.
static unsigned calculateGlobalIndex(spirv::GlobalVariableOp op) {
IntegerAttr binding = op.getAttrOfType<IntegerAttr>(bindingName());
return binding.getInt();
}

/// Copies the given number of bytes from src to dst pointers.
static void copy(Location loc, Value dst, Value src, Value size,
OpBuilder &builder) {
MLIRContext *context = builder.getContext();
auto llvmI1Type = LLVM::LLVMType::getInt1Ty(context);
Value isVolatile = builder.create<LLVM::ConstantOp>(
loc, llvmI1Type, builder.getBoolAttr(false));
builder.create<LLVM::MemcpyOp>(loc, dst, src, size, isVolatile);
}

/// Encodes the binding and descriptor set numbers into a new symbolic name.
/// The name is specified by
/// {kernel_module_name}_{variable_name}_descriptor_set{ds}_binding{b}
/// to avoid symbolic conflicts, where 'ds' and 'b' are descriptor set and
/// binding numbers.
static std::string
createGlobalVariableWithBindName(spirv::GlobalVariableOp op,
StringRef kernelModuleName) {
IntegerAttr descriptorSet =
op.getAttrOfType<IntegerAttr>(descriptorSetName());
IntegerAttr binding = op.getAttrOfType<IntegerAttr>(bindingName());
return llvm::formatv("{0}_{1}_descriptor_set{2}_binding{3}",
kernelModuleName.str(), op.sym_name().str(),
std::to_string(descriptorSet.getInt()),
std::to_string(binding.getInt()));
}

/// Returns true if the given global variable has both a descriptor set number
/// and a binding number.
static bool hasDescriptorSetAndBinding(spirv::GlobalVariableOp op) {
IntegerAttr descriptorSet =
op.getAttrOfType<IntegerAttr>(descriptorSetName());
IntegerAttr binding = op.getAttrOfType<IntegerAttr>(bindingName());
return descriptorSet && binding;
}

/// Fills `globalVariableMap` with SPIR-V global variables that represent kernel
/// arguments from the given SPIR-V module. We assume that the module contains a
/// single entry point function. Hence, all `spv.globalVariable`s with a bind
/// attribute are kernel arguments.
static LogicalResult getKernelGlobalVariables(
spirv::ModuleOp module,
DenseMap<uint32_t, spirv::GlobalVariableOp> &globalVariableMap) {
auto entryPoints = module.getOps<spirv::EntryPointOp>();
if (!llvm::hasSingleElement(entryPoints)) {
return module.emitError(
"The module must contain exactly one entry point function");
}
auto globalVariables = module.getOps<spirv::GlobalVariableOp>();
for (auto globalOp : globalVariables) {
if (hasDescriptorSetAndBinding(globalOp))
globalVariableMap[calculateGlobalIndex(globalOp)] = globalOp;
}
return success();
}

/// Encodes the SPIR-V module's symbolic name into the name of the entry point
/// function.
static LogicalResult encodeKernelName(spirv::ModuleOp module) {
StringRef spvModuleName = module.sym_name().getValue();
// We already know that the module contains exactly one entry point function
// based on `getKernelGlobalVariables()` call. Update this function's name
// to:
// {spv_module_name}_{function_name}
auto entryPoint = *module.getOps<spirv::EntryPointOp>().begin();
StringRef funcName = entryPoint.fn();
auto funcOp = module.lookupSymbol<spirv::FuncOp>(funcName);
std::string newFuncName = spvModuleName.str() + "_" + funcName.str();
if (failed(SymbolTable::replaceAllSymbolUses(funcOp, newFuncName, module)))
return failure();
SymbolTable::setSymbolName(funcOp, newFuncName);
return success();
}

//===----------------------------------------------------------------------===//
// Conversion patterns
//===----------------------------------------------------------------------===//

namespace {

/// Structure to group information about the variables being copied.
struct CopyInfo {
Value dst;
Value src;
Value size;
};

/// This pattern emulates a call to the kernel in LLVM dialect. For that, we
/// copy the data to the global variable (emulating device side), call the
/// kernel as a normal void LLVM function, and copy the data back (emulating the
/// host side).
class GPULaunchLowering : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> {
using ConvertOpToLLVMPattern<gpu::LaunchFuncOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
gpu::LaunchFuncOp launchOp = cast<gpu::LaunchFuncOp>(op);
MLIRContext *context = rewriter.getContext();
auto module = launchOp.getParentOfType<ModuleOp>();

// Get the SPIR-V module that represents the gpu kernel module. The module
// is named:
// __spv__{kernel_module_name}
// based on GPU to SPIR-V conversion.
StringRef kernelModuleName = launchOp.getKernelModuleName();
std::string spvModuleName = kSPIRVModule + kernelModuleName.str();
auto spvModule = module.lookupSymbol<spirv::ModuleOp>(spvModuleName);
if (!spvModule) {
return launchOp.emitOpError("SPIR-V kernel module '")
<< spvModuleName << "' is not found";
}

// Declare kernel function in the main module so that it later can be linked
// with its definition from the kernel module. We know that the kernel
// function would have no arguments and the data is passed via global
// variables. The name of the kernel will be
// {spv_module_name}_{kernel_function_name}
// to avoid symbolic name conflicts.
StringRef kernelFuncName = launchOp.getKernelName();
std::string newKernelFuncName = spvModuleName + "_" + kernelFuncName.str();
auto kernelFunc = module.lookupSymbol<LLVM::LLVMFuncOp>(newKernelFuncName);
if (!kernelFunc) {
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(module.getBody());
kernelFunc = rewriter.create<LLVM::LLVMFuncOp>(
rewriter.getUnknownLoc(), newKernelFuncName,
LLVM::LLVMType::getFunctionTy(LLVM::LLVMType::getVoidTy(context),
ArrayRef<LLVM::LLVMType>(),
/*isVarArg=*/false));
rewriter.setInsertionPoint(launchOp);
}

// Get all global variables associated with the kernel operands.
DenseMap<uint32_t, spirv::GlobalVariableOp> globalVariableMap;
if (failed(getKernelGlobalVariables(spvModule, globalVariableMap)))
return failure();

// Traverse kernel operands that were converted to MemRefDescriptors. For
// each operand, create a global variable and copy data from operand to it.
Location loc = launchOp.getLoc();
SmallVector<CopyInfo, 4> copyInfo;
auto numKernelOperands = launchOp.getNumKernelOperands();
auto kernelOperands = operands.take_back(numKernelOperands);
for (auto operand : llvm::enumerate(kernelOperands)) {
// Check if the kernel's opernad is a ranked memref.
auto memRefType = launchOp.getKernelOperand(operand.index())
.getType()
.dyn_cast<MemRefType>();
if (!memRefType)
return failure();

// Calculate the size of the memref and get the pointer to the allocated
// buffer.
SmallVector<Value, 4> sizes;
getMemRefDescriptorSizes(loc, memRefType, operand.value(), rewriter,
sizes);
Value size = getCumulativeSizeInBytes(loc, memRefType.getElementType(),
sizes, rewriter);
MemRefDescriptor descriptor(operand.value());
Value src = descriptor.allocatedPtr(rewriter, loc);

// Get the global variable in the SPIR-V module that is associated with
// the kernel operand. Construct its new name and create a corresponding
// LLVM dialect global variable.
spirv::GlobalVariableOp spirvGlobal = globalVariableMap[operand.index()];
auto pointeeType =
spirvGlobal.type().cast<spirv::PointerType>().getPointeeType();
auto dstGlobalType = typeConverter.convertType(pointeeType);
if (!dstGlobalType)
return failure();
std::string name =
createGlobalVariableWithBindName(spirvGlobal, spvModuleName);
// Check if this variable has already been created.
auto dstGlobal = module.lookupSymbol<LLVM::GlobalOp>(name);
if (!dstGlobal) {
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(module.getBody());
dstGlobal = rewriter.create<LLVM::GlobalOp>(
loc, dstGlobalType.cast<LLVM::LLVMType>(),
/*isConstant=*/false, LLVM::Linkage::Linkonce, name, Attribute());
rewriter.setInsertionPoint(launchOp);
}

// Copy the data from src operand pointer to dst global variable. Save
// src, dst and size so that we can copy data back after emulating the
// kernel call.
Value dst = rewriter.create<LLVM::AddressOfOp>(loc, dstGlobal);
copy(loc, dst, src, size, rewriter);

CopyInfo info;
info.dst = dst;
info.src = src;
info.size = size;
copyInfo.push_back(info);
}
// Create a call to the kernel and copy the data back.
rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, kernelFunc,
ArrayRef<Value>());
for (CopyInfo info : copyInfo)
copy(loc, info.src, info.dst, info.size, rewriter);
return success();
}
};

class LowerHostCodeToLLVM
: public LowerHostCodeToLLVMBase<LowerHostCodeToLLVM> {
public:
void runOnOperation() override {
ModuleOp module = getOperation();

// Erase the GPU module.
for (auto gpuModule :
llvm::make_early_inc_range(module.getOps<gpu::GPUModuleOp>()))
gpuModule.erase();

// Specify options to lower Standard to LLVM and pull in the conversion
// patterns.
LowerToLLVMOptions options = {
/*useBarePtrCallConv=*/false,
/*emitCWrappers=*/true,
/*indexBitwidth=*/kDeriveIndexBitwidthFromDataLayout};
auto *context = module.getContext();
OwningRewritePatternList patterns;
LLVMTypeConverter typeConverter(context, options);
populateStdToLLVMConversionPatterns(typeConverter, patterns);
patterns.insert<GPULaunchLowering>(typeConverter);

// Pull in SPIR-V type conversion patterns to convert SPIR-V global
// variable's type to LLVM dialect type.
populateSPIRVToLLVMTypeConversion(typeConverter);

ConversionTarget target(*context);
target.addLegalDialect<LLVM::LLVMDialect>();
if (failed(applyPartialConversion(module, target, patterns)))
signalPassFailure();

// Finally, modify the kernel function in SPIR-V modules to avoid symbolic
// conflicts.
for (auto spvModule : module.getOps<spirv::ModuleOp>())
encodeKernelName(spvModule);
}
};
} // namespace

std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
mlir::createLowerHostCodeToLLVMPass() {
return std::make_unique<LowerHostCodeToLLVM>();
}
Loading

0 comments on commit cae4067

Please sign in to comment.