Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allows the user to specify a module ID to emit unique functions #1694

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/Builder/FrontendDialectTransformer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1346,7 +1346,11 @@ class FrontendGenImpl {
*/
func::FuncOp importGraph(const onnx::GraphProto &graph) {
const std::string &name = "main_graph";
auto mainFunc = func::FuncOp::create(UnknownLoc(), name,
const std::string &moduleId = options_.moduleId;

std::string moduleSuffix = moduleId.empty() ? "" : ("_" + moduleId);

auto mainFunc = func::FuncOp::create(UnknownLoc(), name + moduleSuffix,
/*type=*/builder_.getFunctionType({}, {}), /*attrs=*/{});
module_.push_back(mainFunc);
// Create and set insertion point to entry block.
Expand Down
2 changes: 2 additions & 0 deletions src/Builder/FrontendDialectTransformer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ struct ImportOptions {
// Directory to look for external data if any tensor has external
// data location. If empty then external data is disabled.
std::string externalDataDir = "";
// The module ID
std::string moduleId = "";
};

/*!
Expand Down
5 changes: 5 additions & 0 deletions src/Compiler/CompilerOptions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,11 @@ llvm::cl::opt<std::string> customEnvFlags("customEnvFlags",
llvm::cl::value_desc("option env var"), llvm::cl::init("ONNX_MLIR_FLAGS"),
llvm::cl::cat(OnnxMlirOptions));

llvm::cl::opt<std::string> moduleId("moduleId",
llvm::cl::desc(
"Optional module ID that's added as a suffix to all emitted functions"),
llvm::cl::cat(OnnxMlirOptions));

llvm::cl::opt<std::string> mtriple("mtriple",
llvm::cl::desc("Override target triple for module"),
llvm::cl::value_desc("LLVM target triple"), llvm::cl::cat(OnnxMlirOptions),
Expand Down
1 change: 1 addition & 0 deletions src/Compiler/CompilerOptions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ extern llvm::cl::opt<bool> preserveMLIR;
extern llvm::cl::opt<bool> useOnnxModelTypes;
extern llvm::cl::opt<int> repeatOnnxTransform;
extern llvm::cl::opt<std::string> shapeInformation;
extern llvm::cl::opt<std::string> moduleId;
extern llvm::cl::opt<onnx_mlir::OptLevel> OptimizationLevel;
extern llvm::cl::opt<std::string> customEnvFlags;
extern llvm::cl::opt<std::string> mtriple;
Expand Down
27 changes: 19 additions & 8 deletions src/Compiler/CompilerUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "mlir/Support/FileUtilities.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/MC/TargetRegistry.h"
Expand All @@ -25,6 +26,7 @@
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/ToolOutputFile.h"
#include "llvm/Target/TargetMachine.h"
#include <string>

#include "ExternalUtil.hpp"

Expand Down Expand Up @@ -297,22 +299,24 @@ static void tailorLLVMIR(llvm::Module &llvmModule) {
llvm::MDString::get(ctx, PRODUCT_ID));
#endif

// Annotate functions to be accessible from DLL on Windows.
std::string moduleSuffix = moduleId.empty() ? "" : ("_" + moduleId);

// Annotate functions to be accessible from DLL on Windows.
#ifdef _WIN32
SmallVector<StringRef, 4> exportedFuncs;
SmallVector<std::string, 4> exportedFuncs;
// Signature functions.
exportedFuncs.emplace_back(StringRef("omInputSignature"));
exportedFuncs.emplace_back(StringRef("omOutputSignature"));
exportedFuncs.emplace_back(StringRef("omQueryEntryPoints"));
exportedFuncs.emplace_back("omInputSignature" + moduleSuffix);
exportedFuncs.emplace_back("omOutputSignature" + moduleSuffix);
exportedFuncs.emplace_back("omQueryEntryPoints" + moduleSuffix);
// Entry point funtions.
if (llvm::GlobalVariable *GV =
llvmModule.getNamedGlobal(StringRef("_entry_point_arrays"))) {
llvmModule.getNamedGlobal("_entry_point_arrays" + moduleSuffix)) {
if (GV->isConstant() && GV->hasDefinitiveInitializer()) {
llvm::Constant *initializer = GV->getInitializer();
llvm::ArrayType *AT = dyn_cast<llvm::ArrayType>(initializer->getType());
for (uint64_t i = 0; i < AT->getNumElements() - 1; ++i) {
llvm::GlobalVariable *entryGV = llvmModule.getNamedGlobal(
StringRef("_entry_point_" + std::to_string(i)));
"_entry_point_" + std::to_string(i) + moduleSuffix);
if (entryGV->isConstant()) {
llvm::ConstantDataSequential *entry =
dyn_cast<llvm::ConstantDataSequential>(entryGV->getInitializer());
Expand All @@ -321,11 +325,12 @@ static void tailorLLVMIR(llvm::Module &llvmModule) {
}
}
}
for (StringRef funcName : exportedFuncs)
for (StringRef funcName : exportedFuncs) {
if (llvm::GlobalValue *GV = llvmModule.getNamedValue(funcName)) {
GV->setDSOLocal(true);
GV->setDLLStorageClass(llvm::GlobalValue::DLLExportStorageClass);
}
}
#endif
}

Expand Down Expand Up @@ -642,6 +647,7 @@ int processInputFile(StringRef inputFilename, mlir::MLIRContext &context,
options.invokeOnnxVersionConverter = invokeOnnxVersionConverter;
options.shapeInformation = shapeInformation;
options.externalDataDir = dirName(inputFilename);
options.moduleId = moduleId;
return ImportFrontendModelFile(
inputFilename, context, module, errorMessage, options);
} else if (inputIsMLIR)
Expand All @@ -657,6 +663,7 @@ int processInputArray(const void *onnxBuffer, int bufferSize,
options.useOnnxModelTypes = useOnnxModelTypes;
options.invokeOnnxVersionConverter = invokeOnnxVersionConverter;
options.shapeInformation = shapeInformation;
options.moduleId = moduleId;
return ImportFrontendModelArray(
onnxBuffer, bufferSize, context, module, errorMessage, options);
}
Expand Down Expand Up @@ -838,6 +845,10 @@ static int setupModule(mlir::OwningOpRef<ModuleOp> &module,
moduleOp.setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(),
StringAttr::get(&context, getDataLayout(loc)));

if (!moduleId.empty()) {
moduleOp.setAttr("onnx-mlir.moduleId", StringAttr::get(&context, moduleId));
}

// Set the module target accelerators.
SmallVector<Attribute, 2> accelsAttr;
for (auto *accel : onnx_mlir::accel::Accelerator::getAccelerators()) {
Expand Down
18 changes: 14 additions & 4 deletions src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,16 @@
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Endian.h"
#include <string>

#include "onnx/onnx_pb.h"

Expand Down Expand Up @@ -206,6 +209,12 @@ void genSignatureFunction(ModuleOp &module,
OpBuilder b(context);
MultiDialectBuilder<LLVMBuilder> create(b, loc);

std::string moduleSuffix;
if (auto moduleIdAttr =
module->getAttrOfType<StringAttr>("onnx-mlir.moduleId")) {
moduleSuffix = "_" + moduleIdAttr.str();
}

// Common information.
Type i8Type = IntegerType::get(context, 8);
Type i32Type = IntegerType::get(context, 32);
Expand All @@ -222,8 +231,8 @@ void genSignatureFunction(ModuleOp &module,
b.setInsertionPointToEnd(module.getBody());
auto arrayType = LLVM::LLVMArrayType::get(i8PtrTy, entryGlobalOps.size() + 1);
LLVM::GlobalOp entryArrayOp = create.llvm.globalOp(arrayType,
/*isConstant=*/true, LLVM::Linkage::Internal, "_entry_point_arrays",
Attribute());
/*isConstant=*/true, LLVM::Linkage::Internal,
"_entry_point_arrays" + moduleSuffix, Attribute());
{ // Fill the initializer with pointers to entry point constants.
Region &region = entryArrayOp.getInitializerRegion();
Block *block = b.createBlock(&region);
Expand Down Expand Up @@ -259,7 +268,7 @@ void genSignatureFunction(ModuleOp &module,
Type llvmFnType =
LLVM::LLVMFunctionType::get(i8PtrPtrTy, {i64PtrTy}, false);
LLVM::LLVMFuncOp funcOp =
create.llvm.func("omQueryEntryPoints", llvmFnType);
create.llvm.func("omQueryEntryPoints" + moduleSuffix, llvmFnType);
// Emit the body of the function.
Block *entryBlock = funcOp.addEntryBlock();
OpBuilder::InsertionGuard bodyGuard(b);
Expand Down Expand Up @@ -296,7 +305,8 @@ void genSignatureFunction(ModuleOp &module,
b.setInsertionPointToEnd(module.getBody());
// 1. Emit the function type.
Type llvmFnType = LLVM::LLVMFunctionType::get(i8PtrTy, {i8PtrTy}, false);
LLVM::LLVMFuncOp funcOp = create.llvm.func(funcNames[i], llvmFnType);
LLVM::LLVMFuncOp funcOp =
create.llvm.func(funcNames[i] + moduleSuffix, llvmFnType);

// 2. Emit the body of the function.
Block *entryBlock = funcOp.addEntryBlock();
Expand Down
12 changes: 9 additions & 3 deletions src/Conversion/KrnlToLLVM/KrnlEntryPoint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,12 @@ class KrnlEntryPointOpLowering : public OpRewritePattern<KrnlEntryPointOp> {
// Common information.
Type i8Type = IntegerType::get(context, 8);

std::string moduleSuffix;
if (auto moduleIdAttr =
module->getAttrOfType<StringAttr>("onnx-mlir.moduleId")) {
moduleSuffix = "_" + moduleIdAttr.str();
}

// A helper function to emit a global constant operation storing a string.
auto emitGlobalOp = [&context, &i8Type, &create](
std::string name, std::string value) {
Expand Down Expand Up @@ -568,19 +574,19 @@ class KrnlEntryPointOpLowering : public OpRewritePattern<KrnlEntryPointOp> {
b.setInsertionPointToStart(module.getBody());
// Global constants for entry point names.
std::string entryVarName =
"_entry_point_" + std::to_string(KRNL_ENTRY_POINT_ID);
"_entry_point_" + std::to_string(KRNL_ENTRY_POINT_ID) + moduleSuffix;
KRNL_ENTRY_POINT_ID++;
LLVM::GlobalOp entryGlobalOp =
emitGlobalOp(entryVarName, terminatedEntryPointName);
entryGlobalOps.emplace_back(entryGlobalOp);

// Global constants for input signatures.
std::string inSigVarName = entryVarName + "_in_sig";
std::string inSigVarName = entryVarName + "_in_sig" + moduleSuffix;
LLVM::GlobalOp inSigGlobalOp = emitGlobalOp(inSigVarName, inSignature);
inSigGlobalOps.emplace_back(inSigGlobalOp);

// Global constants for output signatures.
std::string outSigVarName = entryVarName + "_out_sig";
std::string outSigVarName = entryVarName + "_out_sig" + moduleSuffix;
LLVM::GlobalOp outSigGlobalOp = emitGlobalOp(outSigVarName, outSignature);
outSigGlobalOps.emplace_back(outSigGlobalOp);
}
Expand Down