Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
13faa33
add cpuruntime dialect
Menooker May 14, 2024
161848e
format
Menooker May 14, 2024
447ef12
add dependency
Menooker May 14, 2024
a73dcc1
fix new MLIR
Menooker May 14, 2024
1cfede8
add
Menooker May 15, 2024
57ba92e
Merge remote-tracking branch 'origin/main' into yijie/cpuruntime
Menooker May 15, 2024
4d25de6
Merge branch 'yijie/cpuruntime' into yijie/pipeline
Menooker May 15, 2024
475faf8
update
Menooker May 15, 2024
0ac087d
fix
Menooker May 15, 2024
74b0d34
remove at exit
Menooker May 16, 2024
2cebba9
fix lint
Menooker May 16, 2024
d1b35a1
Merge branch 'yijie/cpuruntime' into yijie/pipeline
Menooker May 16, 2024
34d10ea
Add kmp_* wrapper for gomp environment
Menooker May 16, 2024
55c1043
Merge remote-tracking branch 'origin' into yijie/pipeline
Menooker May 16, 2024
e1490bb
Merge branch 'yijie/fake_omp' into yijie/pipeline
Menooker May 16, 2024
80a597f
fix
Menooker May 16, 2024
0b4332b
fix
Menooker May 16, 2024
c43f481
Merge branch 'main' into yijie/fake_omp
Menooker May 23, 2024
b1c79a2
add wraper
Menooker May 23, 2024
382171b
fix lint
Menooker May 23, 2024
ef75da8
Merge branch 'yijie/fake_omp' of https://github.com/intel/graph-compi…
Menooker May 23, 2024
f1fd0ae
fix
Menooker May 23, 2024
a773ea6
f
Menooker May 23, 2024
84933c2
fix
Menooker May 23, 2024
4cca4df
add reference
Menooker May 23, 2024
678cef9
enable const cache
Menooker May 24, 2024
c12156c
reduce size
Menooker May 24, 2024
e24b1df
rename
Menooker May 28, 2024
34064f3
Merge branch 'main' of https://github.com/intel/graph-compiler into y…
Menooker May 28, 2024
70c5e97
Merge branch 'main' of https://github.com/intel/graph-compiler into y…
Menooker May 28, 2024
1e06c98
fix license.py
Menooker May 28, 2024
3f656b7
Merge branch 'yijie/fake_omp' into yijie/pipeline
Menooker May 28, 2024
24cee01
Merge branch 'yijie/pipeline' into yijie/mainfunc_wrapper
Menooker May 28, 2024
7c32bc5
fix
Menooker May 28, 2024
4540fb6
fix lint
Menooker May 28, 2024
381677a
fix comments
Menooker May 28, 2024
fdfc53e
Merge branch 'main' of https://github.com/intel/graph-compiler into y…
Menooker May 29, 2024
60042e1
Merge branch 'yijie/pipeline' into yijie/mainfunc_wrapper
Menooker May 29, 2024
b54b310
fix
Menooker May 29, 2024
9d04cd2
format
Menooker May 29, 2024
206c3f3
cleanup
Menooker May 30, 2024
bc9a7ad
refine options
Menooker May 30, 2024
bc5c9de
fmt
Menooker May 30, 2024
60fb17d
Merge branch 'main' of https://github.com/intel/graph-compiler into y…
Menooker May 30, 2024
8fda8fa
Merge branch 'main' of https://github.com/intel/graph-compiler into y…
Menooker Jun 11, 2024
529c403
rebase
Menooker Jun 11, 2024
58d6639
fix comments
Menooker Jun 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions include/gc/ExecutionEngine/Driver/Driver.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
//===-- Driver.h - The top-level MLIR compiler driver -----------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef GC_EXECUTIONENGINE_DRIVER_DRIVER_H
#define GC_EXECUTIONENGINE_DRIVER_DRIVER_H

#include "mlir/ExecutionEngine/CRunnerUtils.h"
#include "mlir/ExecutionEngine/ExecutionEngine.h"
#include <memory>
#include <string_view>

namespace mlir {
class DialectRegistry;
namespace gc {

const DialectRegistry &initCompilerAndGetDialects();

// the pointers to XXXMemRefType
using GeneralMemrefPtr = void *;
using JitModuleFuncT = void (*)(void **);

struct DriverOptions {
/// the optimization level for the LLVM-JIT
llvm::CodeGenOptLevel jitCodeGenOptLevel = llvm::CodeGenOptLevel::Aggressive;
/// whether to run the MLIR transformation passes
bool runTransforms = true;
/// todo: target machine, etc.
};

class JitModule {
public:
static llvm::Expected<std::shared_ptr<JitModule>>
create(Operation *op, const DriverOptions &options = {});

/// args should be an array of XXXMemrefType*
void call(GeneralMemrefPtr *args, std::size_t numArgs) {
// Silly code, MLIR execution engine requires pointers of real args as
// inputs
llvm::SmallVector<void *, 32> realargs;
realargs.reserve(numArgs);
for (size_t i = 0; i < numArgs; i++) {
realargs.push_back(&args[i]);
}
compute(realargs.data());
}

/// directly call compute(). args should be an array of void*. args[i] should
/// be a pointer to the real data. For passing memref, users need to 1) create
/// a pointer to XXXMemrefType 2) store the pointer to pointer to
/// XXXMemrefType in args[i]
void callRaw(void **args) { compute(args); }

JitModule(std::unique_ptr<ExecutionEngine> engine, JitModuleFuncT compute);
~JitModule();

private:
std::unique_ptr<ExecutionEngine> engine;
JitModuleFuncT compute;
};

} // namespace gc
} // namespace mlir

#endif
4 changes: 3 additions & 1 deletion lib/gc/Dialect/CPURuntime/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
gc_set_mlir_link_components(MLIR_LINK_COMPONENTS MLIRFuncDialect)

add_mlir_dialect_library(MLIRCPURuntimeDialect
CPURuntimeDialect.cpp
CPURuntimeOps.cpp
Expand All @@ -10,5 +12,5 @@ add_mlir_dialect_library(MLIRCPURuntimeDialect
MLIRCPURuntimePassesIncGen

LINK_LIBS PUBLIC
MLIRFuncDialect
${MLIR_LINK_COMPONENTS}
)
4 changes: 3 additions & 1 deletion lib/gc/Dialect/CPURuntime/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
gc_set_mlir_link_components(MLIR_LINK_COMPONENTS MLIRFuncDialect)

add_mlir_dialect_library(MLIRCPURuntimeTransforms
CPURuntimeToLLVM.cpp

Expand All @@ -8,7 +10,7 @@ add_mlir_dialect_library(MLIRCPURuntimeTransforms
MLIRCPURuntimePassesIncGen

LINK_LIBS PUBLIC
MLIRFuncDialect
${MLIR_LINK_COMPONENTS}
MLIRCPURuntimeDialect
)

Expand Down
1 change: 1 addition & 0 deletions lib/gc/ExecutionEngine/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
add_subdirectory(CPURuntime)
add_subdirectory(Driver)
41 changes: 41 additions & 0 deletions lib/gc/ExecutionEngine/Driver/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
if(GC_DEV_LINK_LLVM_DYLIB)
set(LLVM_LINK_COMPONENTS
LLVM
)
get_property(dialect_libs GLOBAL PROPERTY GC_DIALECT_LIBS)
get_property(conversion_libs GLOBAL PROPERTY GC_PASS_LIBS)
set(MLIR_LINK_COMPONENTS
MLIR
MLIRExecutionEngineShared
)
else()
set(LLVM_LINK_COMPONENTS
Core
Support
nativecodegen
native
)
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
set(MLIR_LINK_COMPONENTS
MLIRBuiltinToLLVMIRTranslation
MLIRExecutionEngine
MLIRLLVMDialect
MLIRLLVMToLLVMIRTranslation
MLIRToLLVMIRTranslationRegistration
)
endif()

add_mlir_library(GCJitWrapper
Driver.cpp

ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include

LINK_LIBS PUBLIC
${MLIR_LINK_COMPONENTS}
${dialect_libs}
${conversion_libs}
GCPasses
)

82 changes: 82 additions & 0 deletions lib/gc/ExecutionEngine/Driver/Driver.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
//===-- Driver.cpp - Top-level MLIR compiler driver -------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "gc/ExecutionEngine/Driver/Driver.h"
#include "gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.h"
#include "gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h"
#include "gc/Transforms/Passes.h"
#include "mlir/InitAllDialects.h"
#include "mlir/InitAllPasses.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Target/LLVMIR/Dialect/All.h"
#include "string.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/TargetSelect.h"

namespace mlir {
namespace gc {

static DialectRegistry initDialects() {
mlir::registerAllPasses();
mlir::gc::registerGraphCompilerPasses();
mlir::cpuruntime::registerCPURuntimePasses();
mlir::DialectRegistry registry;
registry.insert<mlir::cpuruntime::CPURuntimeDialect>();
mlir::registerAllDialects(registry);
mlir::cpuruntime::registerConvertCPURuntimeToLLVMInterface(registry);
registry.insert<mlir::onednn_graph::OneDNNGraphDialect>();
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter();
llvm::InitializeNativeTargetAsmParser();
mlir::registerAllToLLVMIRTranslations(registry);
return registry;
}

const DialectRegistry &initCompilerAndGetDialects() {
static DialectRegistry reg = initDialects();
return reg;
}

static const char defaultComputeName[] = "_mlir_ciface_compute";

llvm::Expected<std::shared_ptr<JitModule>>
JitModule::create(Operation *op, const DriverOptions &options) {
if (options.runTransforms) {
mlir::PassManager pm{op->getContext()};
populateCPUPipeline(pm);
if (auto result = pm.run(op); failed(result)) {
return llvm::make_error<llvm::StringError>(
"MLIR pass error", llvm::inconvertibleErrorCode());
}
}
ExecutionEngineOptions exeOptions;
exeOptions.jitCodeGenOptLevel = options.jitCodeGenOptLevel;
std::unique_ptr<llvm::TargetMachine> tm = nullptr;
auto exec = ExecutionEngine::create(op, exeOptions, std::move(tm));
if (!exec) {
return exec.takeError();
}
auto &engine = *exec;
JitModuleFuncT compute;
{
auto expectCompute = engine->lookupPacked(defaultComputeName);
if (!expectCompute) {
return expectCompute.takeError();
}
compute = *expectCompute;
}
return std::make_shared<JitModule>(std::move(engine), compute);
}

JitModule::JitModule(std::unique_ptr<ExecutionEngine> engine,
JitModuleFuncT compute)
: engine{std::move(engine)}, compute{compute} {}
JitModule::~JitModule() = default;

} // namespace gc
} // namespace mlir
1 change: 1 addition & 0 deletions lib/gc/Transforms/Pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/Transforms/Passes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
Expand Down
1 change: 1 addition & 0 deletions test/mlir/unittests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ function(add_mlir_unittest test_dirname)
endfunction()

add_subdirectory(Example)
add_subdirectory(ExecutionEngine)

7 changes: 7 additions & 0 deletions test/mlir/unittests/ExecutionEngine/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
add_mlir_unittest(GCExecutionEngineTests
JitWrapper.cpp
)
target_link_libraries(GCExecutionEngineTests
PRIVATE
GCJitWrapper
GCCpuRuntime)
70 changes: 70 additions & 0 deletions test/mlir/unittests/ExecutionEngine/JitWrapper.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
//===-- JitWrapper.cpp - Wrapper for JIT ------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "gc/ExecutionEngine/Driver/Driver.h"
#include "mlir/AsmParser/AsmParser.h"
#include "mlir/ExecutionEngine/MemRefUtils.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/PassManager.h"
#include "llvm/Support/ErrorOr.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include "gtest/gtest.h"
#include <memory>

using namespace mlir;

static const char code1[] = R"mlir(
module {
llvm.mlir.global constant @__num_orig_num_args(3 : i32) : i32
func.func @compute(%a: tensor<128xf32>, %b: tensor<128xf32>) -> tensor<128xf32> attributes { llvm.emit_c_interface } {
%out = tensor.empty() : tensor<128xf32>
%2 = linalg.add ins(%a, %b : tensor<128xf32>,tensor<128xf32>) outs(%out : tensor<128xf32>) -> tensor<128xf32>
return %2 : tensor<128xf32>
}
}
)mlir";

extern "C" {
extern int gc_runtime_keep_alive;
}

TEST(ExecutionEngine, JitWrapper) {
gc_runtime_keep_alive = 0;
MLIRContext ctx{gc::initCompilerAndGetDialects()};
std::unique_ptr<llvm::MemoryBuffer> ir_buffer =
llvm::MemoryBuffer::getMemBuffer(code1);
// Parse the input mlir.
llvm::SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(ir_buffer), llvm::SMLoc());
mlir::OwningOpRef<mlir::ModuleOp> module =
mlir::parseSourceFile<mlir::ModuleOp>(sourceMgr, &ctx);
ASSERT_TRUE(module);
auto jited = gc::JitModule::create(module.get());
bool jit_success = static_cast<bool>(jited);
if (!jit_success) {
auto err = jited.takeError();
llvm::errs() << err;
llvm::consumeError(std::move(err));
}
ASSERT_TRUE(jit_success);
OwningMemRef<float, 1> bufA{
{128}, {128}, [](float &ptr, ArrayRef<int64_t>) { ptr = 1.0f; }};
OwningMemRef<float, 1> bufB{
{128}, {128}, [](float &ptr, ArrayRef<int64_t> idx) { ptr = idx[0]; }};
OwningMemRef<float, 1> bufC{{128}, {128}};
void *args[] = {&*bufA, &*bufB, &*bufC};
jited.get()->call(args, 3);
for (int i = 0; i < 128; i++) {
ASSERT_EQ(bufC[{i}], 1.0f + i);
}
}