Skip to content

Commit

Permalink
[mlir] JitRunner: add a config option to register symbols with Execut…
Browse files Browse the repository at this point in the history
…ionEngine at runtime

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D90264
  • Loading branch information
ezhulenev committed Oct 27, 2020
1 parent 50dfa19 commit f6c9f6e
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 55 deletions.
43 changes: 28 additions & 15 deletions mlir/include/mlir/ExecutionEngine/JitRunner.h
Expand Up @@ -18,29 +18,42 @@
#ifndef MLIR_SUPPORT_JITRUNNER_H_
#define MLIR_SUPPORT_JITRUNNER_H_

#include "mlir/IR/Module.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/Module.h"
#include "llvm/ExecutionEngine/Orc/Core.h"

namespace mlir {
namespace llvm {
class Module;
class LLVMContext;

using TranslationCallback = llvm::function_ref<std::unique_ptr<llvm::Module>(
ModuleOp, llvm::LLVMContext &)>;
namespace orc {
class MangleAndInterner;
} // namespace orc
} // namespace llvm

namespace mlir {

class ModuleOp;
struct LogicalResult;

struct JitRunnerConfig {
/// MLIR transformer applied after parsing the input into MLIR IR and before
/// passing the MLIR module to the ExecutionEngine.
llvm::function_ref<LogicalResult(mlir::ModuleOp)> mlirTransformer = nullptr;

/// A custom function that is passed to ExecutionEngine. It processes MLIR
/// module and creates LLVM IR module.
llvm::function_ref<std::unique_ptr<llvm::Module>(ModuleOp,
llvm::LLVMContext &)>
llvmModuleBuilder = nullptr;

/// A callback to register symbols with ExecutionEngine at runtime.
llvm::function_ref<llvm::orc::SymbolMap(llvm::orc::MangleAndInterner)>
runtimesymbolMap = nullptr;
};

// Entry point for all CPU runners. Expects the common argc/argv arguments for
// standard C++ main functions, `mlirTransformer` and `llvmModuleBuilder`.
/// `mlirTransformer` is applied after parsing the input into MLIR IR and before
/// passing the MLIR module to the ExecutionEngine.
/// `llvmModuleBuilder` is a custom function that is passed to ExecutionEngine.
/// It processes MLIR module and creates LLVM IR module.
int JitRunnerMain(
int argc, char **argv,
llvm::function_ref<LogicalResult(mlir::ModuleOp)> mlirTransformer,
TranslationCallback llvmModuleBuilder = nullptr);
// standard C++ main functions.
int JitRunnerMain(int argc, char **argv, JitRunnerConfig config = {});

} // namespace mlir

Expand Down
85 changes: 49 additions & 36 deletions mlir/lib/ExecutionEngine/JitRunner.cpp
Expand Up @@ -92,6 +92,23 @@ struct Options {
"object-filename",
llvm::cl::desc("Dump JITted-compiled object to file <input file>.o")};
};

struct CompileAndExecuteConfig {
/// LLVM module transformer that is passed to ExecutionEngine.
llvm::function_ref<llvm::Error(llvm::Module *)> transformer;

/// A custom function that is passed to ExecutionEngine. It processes MLIR
/// module and creates LLVM IR module.
llvm::function_ref<std::unique_ptr<llvm::Module>(ModuleOp,
llvm::LLVMContext &)>
llvmModuleBuilder;

/// A custom function that is passed to ExecutinEngine to register symbols at
/// runtime.
llvm::function_ref<llvm::orc::SymbolMap(llvm::orc::MangleAndInterner)>
runtimeSymbolMap;
};

} // end anonymous namespace

static OwningModuleRef parseMLIRInput(StringRef inputFilename,
Expand Down Expand Up @@ -131,23 +148,25 @@ static Optional<unsigned> getCommandLineOptLevel(Options &options) {
}

// JIT-compile the given module and run "entryPoint" with "args" as arguments.
static Error
compileAndExecute(Options &options, ModuleOp module,
TranslationCallback llvmModuleBuilder, StringRef entryPoint,
std::function<llvm::Error(llvm::Module *)> transformer,
void **args) {
static Error compileAndExecute(Options &options, ModuleOp module,
StringRef entryPoint,
CompileAndExecuteConfig config, void **args) {
Optional<llvm::CodeGenOpt::Level> jitCodeGenOptLevel;
if (auto clOptLevel = getCommandLineOptLevel(options))
jitCodeGenOptLevel =
static_cast<llvm::CodeGenOpt::Level>(clOptLevel.getValue());
SmallVector<StringRef, 4> libs(options.clSharedLibs.begin(),
options.clSharedLibs.end());
auto expectedEngine = mlir::ExecutionEngine::create(
module, llvmModuleBuilder, transformer, jitCodeGenOptLevel, libs);
module, config.llvmModuleBuilder, config.transformer, jitCodeGenOptLevel,
libs);
if (!expectedEngine)
return expectedEngine.takeError();

auto engine = std::move(*expectedEngine);
if (config.runtimeSymbolMap)
engine->registerSymbols(config.runtimeSymbolMap);

auto expectedFPtr = engine->lookup(entryPoint);
if (!expectedFPtr)
return expectedFPtr.takeError();
Expand All @@ -163,16 +182,14 @@ compileAndExecute(Options &options, ModuleOp module,
return Error::success();
}

static Error compileAndExecuteVoidFunction(
Options &options, ModuleOp module, TranslationCallback llvmModuleBuilder,
StringRef entryPoint,
std::function<llvm::Error(llvm::Module *)> transformer) {
static Error compileAndExecuteVoidFunction(Options &options, ModuleOp module,
StringRef entryPoint,
CompileAndExecuteConfig config) {
auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint);
if (!mainFunction || mainFunction.empty())
return make_string_error("entry point not found");
void *empty = nullptr;
return compileAndExecute(options, module, llvmModuleBuilder, entryPoint,
transformer, &empty);
return compileAndExecute(options, module, entryPoint, config, &empty);
}

template <typename Type>
Expand All @@ -196,10 +213,9 @@ Error checkCompatibleReturnType<float>(LLVM::LLVMFuncOp mainFunction) {
return Error::success();
}
template <typename Type>
Error compileAndExecuteSingleReturnFunction(
Options &options, ModuleOp module, TranslationCallback llvmModuleBuilder,
StringRef entryPoint,
std::function<llvm::Error(llvm::Module *)> transformer) {
Error compileAndExecuteSingleReturnFunction(Options &options, ModuleOp module,
StringRef entryPoint,
CompileAndExecuteConfig config) {
auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint);
if (!mainFunction || mainFunction.isExternal())
return make_string_error("entry point not found");
Expand All @@ -215,8 +231,8 @@ Error compileAndExecuteSingleReturnFunction(
void *data;
} data;
data.data = &res;
if (auto error = compileAndExecute(options, module, llvmModuleBuilder,
entryPoint, transformer, (void **)&data))
if (auto error = compileAndExecute(options, module, entryPoint, config,
(void **)&data))
return error;

// Intentional printing of the output so we can test.
Expand All @@ -226,15 +242,8 @@ Error compileAndExecuteSingleReturnFunction(
}

/// Entry point for all CPU runners. Expects the common argc/argv arguments for
/// standard C++ main functions, `mlirTransformer` and `llvmModuleBuilder`.
/// `mlirTransformer` is applied after parsing the input into MLIR IR and before
/// passing the MLIR module to the ExecutionEngine.
/// `llvmModuleBuilder` is a custom function that is passed to ExecutionEngine.
/// It processes MLIR module and creates LLVM IR module.
int mlir::JitRunnerMain(
int argc, char **argv,
function_ref<LogicalResult(mlir::ModuleOp)> mlirTransformer,
TranslationCallback llvmModuleBuilder) {
/// standard C++ main functions.
int mlir::JitRunnerMain(int argc, char **argv, JitRunnerConfig config) {
// Create the options struct containing the command line options for the
// runner. This must come before the command line options are parsed.
Options options;
Expand Down Expand Up @@ -274,8 +283,8 @@ int mlir::JitRunnerMain(
return 1;
}

if (mlirTransformer)
if (failed(mlirTransformer(m.get())))
if (config.mlirTransformer)
if (failed(config.mlirTransformer(m.get())))
return EXIT_FAILURE;

auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost();
Expand All @@ -292,10 +301,14 @@ int mlir::JitRunnerMain(
auto transformer = mlir::makeLLVMPassesTransformer(
passes, optLevel, /*targetMachine=*/tmOrError->get(), optPosition);

CompileAndExecuteConfig compileAndExecuteConfig;
compileAndExecuteConfig.transformer = transformer;
compileAndExecuteConfig.llvmModuleBuilder = config.llvmModuleBuilder;
compileAndExecuteConfig.runtimeSymbolMap = config.runtimesymbolMap;

// Get the function used to compile and execute the module.
using CompileAndExecuteFnT =
Error (*)(Options &, ModuleOp, TranslationCallback, StringRef,
std::function<llvm::Error(llvm::Module *)>);
Error (*)(Options &, ModuleOp, StringRef, CompileAndExecuteConfig);
auto compileAndExecuteFn =
StringSwitch<CompileAndExecuteFnT>(options.mainFuncType.getValue())
.Case("i32", compileAndExecuteSingleReturnFunction<int32_t>)
Expand All @@ -304,11 +317,11 @@ int mlir::JitRunnerMain(
.Case("void", compileAndExecuteVoidFunction)
.Default(nullptr);

Error error =
compileAndExecuteFn
? compileAndExecuteFn(options, m.get(), llvmModuleBuilder,
options.mainFuncName.getValue(), transformer)
: make_string_error("unsupported function type");
Error error = compileAndExecuteFn
? compileAndExecuteFn(options, m.get(),
options.mainFuncName.getValue(),
compileAndExecuteConfig)
: make_string_error("unsupported function type");

int exitCode = EXIT_SUCCESS;
llvm::handleAllErrors(std::move(error),
Expand Down
2 changes: 1 addition & 1 deletion mlir/tools/mlir-cpu-runner/mlir-cpu-runner.cpp
Expand Up @@ -24,5 +24,5 @@ int main(int argc, char **argv) {
llvm::InitializeNativeTargetAsmPrinter();
mlir::initializeLLVMPasses();

return mlir::JitRunnerMain(argc, argv, nullptr);
return mlir::JitRunnerMain(argc, argv);
}
6 changes: 5 additions & 1 deletion mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp
Expand Up @@ -136,5 +136,9 @@ int main(int argc, char **argv) {
LLVMInitializeNVPTXAsmPrinter();

mlir::initializeLLVMPasses();
return mlir::JitRunnerMain(argc, argv, &runMLIRPasses);

mlir::JitRunnerConfig jitRunnerConfig;
jitRunnerConfig.mlirTransformer = &runMLIRPasses;

return mlir::JitRunnerMain(argc, argv, jitRunnerConfig);
}
6 changes: 5 additions & 1 deletion mlir/tools/mlir-spirv-cpu-runner/mlir-spirv-cpu-runner.cpp
Expand Up @@ -86,5 +86,9 @@ int main(int argc, char **argv) {
llvm::InitializeNativeTargetAsmPrinter();
mlir::initializeLLVMPasses();

return mlir::JitRunnerMain(argc, argv, &runMLIRPasses, &convertMLIRModule);
mlir::JitRunnerConfig jitRunnerConfig;
jitRunnerConfig.mlirTransformer = &runMLIRPasses;
jitRunnerConfig.llvmModuleBuilder = &convertMLIRModule;

return mlir::JitRunnerMain(argc, argv, jitRunnerConfig);
}
5 changes: 4 additions & 1 deletion mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
Expand Up @@ -58,5 +58,8 @@ int main(int argc, char **argv) {
llvm::InitializeNativeTargetAsmPrinter();
mlir::initializeLLVMPasses();

return mlir::JitRunnerMain(argc, argv, &runMLIRPasses);
mlir::JitRunnerConfig jitRunnerConfig;
jitRunnerConfig.mlirTransformer = &runMLIRPasses;

return mlir::JitRunnerMain(argc, argv, jitRunnerConfig);
}

0 comments on commit f6c9f6e

Please sign in to comment.