Skip to content

Commit

Permalink
Implement Pass and Dialect plugins for mlir-opt
Browse files Browse the repository at this point in the history
Implementation of Pass and Dialect Plugins that mirrors LLVM Pass Plugin
implementation from the new pass manager.

Currently the implementation only supports using the pass-pipeline option
for adding passes. This restriction is imposed by the `PassPipelineCLParser`
variable in mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:114 that loads the
parse options statically before parsing the cmd line args.

```
mlir-opt stanalone-plugin.mlir --load-dialect-plugin=lib/libStandalonePlugin.so --pass-pipeline="builtin.module(standalone-switch-bar-foo)"
```

Reviewed By: rriddle, mehdi_amini

Differential Revision: https://reviews.llvm.org/D147053
  • Loading branch information
xblang-project authored and joker-eph committed Apr 7, 2023
1 parent 713e815 commit 5e2afe5
Show file tree
Hide file tree
Showing 22 changed files with 591 additions and 5 deletions.
1 change: 1 addition & 0 deletions mlir/examples/standalone/CMakeLists.txt
Expand Up @@ -52,4 +52,5 @@ if(MLIR_ENABLE_BINDINGS_PYTHON)
endif()
add_subdirectory(test)
add_subdirectory(standalone-opt)
add_subdirectory(standalone-plugin)
add_subdirectory(standalone-translate)
4 changes: 4 additions & 0 deletions mlir/examples/standalone/include/Standalone/CMakeLists.txt
@@ -1,3 +1,7 @@
add_mlir_dialect(StandaloneOps standalone)
add_mlir_doc(StandaloneDialect StandaloneDialect Standalone/ -gen-dialect-doc)
add_mlir_doc(StandaloneOps StandaloneOps Standalone/ -gen-op-doc)

set(LLVM_TARGET_DEFINITIONS StandalonePasses.td)
mlir_tablegen(StandalonePasses.h.inc --gen-pass-decls)
add_public_tablegen_target(MLIRStandalonePassesIncGen)
26 changes: 26 additions & 0 deletions mlir/examples/standalone/include/Standalone/StandalonePasses.h
@@ -0,0 +1,26 @@
//===- StandalonePasses.h - Standalone passes ------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef STANDALONE_STANDALONEPASSES_H
#define STANDALONE_STANDALONEPASSES_H

#include "Standalone/StandaloneDialect.h"
#include "Standalone/StandaloneOps.h"
#include "mlir/Pass/Pass.h"
#include <memory>

namespace mlir {
namespace standalone {
#define GEN_PASS_DECL
#include "Standalone/StandalonePasses.h.inc"

#define GEN_PASS_REGISTRATION
#include "Standalone/StandalonePasses.h.inc"
} // namespace standalone
} // namespace mlir

#endif
30 changes: 30 additions & 0 deletions mlir/examples/standalone/include/Standalone/StandalonePasses.td
@@ -0,0 +1,30 @@
//===- StandalonePsss.td - Standalone dialect passes -------*- tablegen -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef STANDALONE_PASS
#define STANDALONE_PASS

include "mlir/Pass/PassBase.td"

def StandaloneSwitchBarFoo: Pass<"standalone-switch-bar-foo", "::mlir::ModuleOp"> {
let summary = "Switches the name of a FuncOp named `bar` to `foo` and folds.";
let description = [{
Switches the name of a FuncOp named `bar` to `foo` and folds.
```
func.func @bar() {
return
}
// Gets transformed to:
func.func @foo() {
return
}
```
}];
}

#endif // STANDALONE_PASS
3 changes: 3 additions & 0 deletions mlir/examples/standalone/lib/Standalone/CMakeLists.txt
Expand Up @@ -2,14 +2,17 @@ add_mlir_dialect_library(MLIRStandalone
StandaloneTypes.cpp
StandaloneDialect.cpp
StandaloneOps.cpp
StandalonePasses.cpp

ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/Standalone

DEPENDS
MLIRStandaloneOpsIncGen
MLIRStandalonePassesIncGen

LINK_LIBS PUBLIC
MLIRIR
MLIRInferTypeOpInterface
MLIRFuncDialect
)
48 changes: 48 additions & 0 deletions mlir/examples/standalone/lib/Standalone/StandalonePasses.cpp
@@ -0,0 +1,48 @@
//===- StandalonePasses.cpp - Standalone passes -----------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "Standalone/StandalonePasses.h"

namespace mlir::standalone {
#define GEN_PASS_DEF_STANDALONESWITCHBARFOO
#include "Standalone/StandalonePasses.h.inc"

namespace {
class StandaloneSwitchBarFooRewriter : public OpRewritePattern<func::FuncOp> {
public:
using OpRewritePattern<func::FuncOp>::OpRewritePattern;
LogicalResult matchAndRewrite(func::FuncOp op,
PatternRewriter &rewriter) const final {
if (op.getSymName() == "bar") {
rewriter.updateRootInPlace(op, [&op]() { op.setSymName("foo"); });
return success();
}
return failure();
}
};

class StandaloneSwitchBarFoo
: public impl::StandaloneSwitchBarFooBase<StandaloneSwitchBarFoo> {
public:
using impl::StandaloneSwitchBarFooBase<
StandaloneSwitchBarFoo>::StandaloneSwitchBarFooBase;
void runOnOperation() final {
RewritePatternSet patterns(&getContext());
patterns.add<StandaloneSwitchBarFooRewriter>(&getContext());
FrozenRewritePatternSet patternSet(std::move(patterns));
if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet)))
signalPassFailure();
}
};
} // namespace
} // namespace mlir::standalone
2 changes: 2 additions & 0 deletions mlir/examples/standalone/standalone-opt/standalone-opt.cpp
Expand Up @@ -21,9 +21,11 @@
#include "llvm/Support/ToolOutputFile.h"

#include "Standalone/StandaloneDialect.h"
#include "Standalone/StandalonePasses.h"

int main(int argc, char **argv) {
mlir::registerAllPasses();
mlir::standalone::registerPasses();
// TODO: Register standalone passes here.

mlir::DialectRegistry registry;
Expand Down
22 changes: 22 additions & 0 deletions mlir/examples/standalone/standalone-plugin/CMakeLists.txt
@@ -0,0 +1,22 @@
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
set(LIBS
MLIRIR
MLIRPass
MLIRPluginsLib
MLIRStandalone
MLIRTransformUtils
)

add_mlir_dialect_library(StandalonePlugin
SHARED
standalone-plugin.cpp

DEPENDS
MLIRStandalone
)

llvm_update_compile_flags(StandalonePlugin)
target_link_libraries(StandalonePlugin PRIVATE ${LIBS})

mlir_check_all_link_libraries(StandalonePlugin)
39 changes: 39 additions & 0 deletions mlir/examples/standalone/standalone-plugin/standalone-plugin.cpp
@@ -0,0 +1,39 @@
//===- standalone-plugin.cpp ------------------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/InitAllDialects.h"
#include "mlir/InitAllPasses.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Tools/Plugins/DialectPlugin.h"

#include "Standalone/StandaloneDialect.h"
#include "Standalone/StandalonePasses.h"

using namespace mlir;

/// Dialect plugin registration mechanism.
/// Observe that it also allows to register passes.
/// Necessary symbol to register the dialect plugin.
extern "C" LLVM_ATTRIBUTE_WEAK DialectPluginLibraryInfo
mlirGetDialectPluginInfo() {
return {MLIR_PLUGIN_API_VERSION, "Standalone", LLVM_VERSION_STRING,
[](DialectRegistry *registry) {
registry->insert<mlir::standalone::StandaloneDialect>();
mlir::standalone::registerPasses();
}};
}

/// Pass plugin registration mechanism.
/// Necessary symbol to register the pass plugin.
extern "C" LLVM_ATTRIBUTE_WEAK PassPluginLibraryInfo mlirGetPassPluginInfo() {
return {MLIR_PLUGIN_API_VERSION, "StandalonePasses", LLVM_VERSION_STRING,
[]() { mlir::standalone::registerPasses(); }};
}
@@ -0,0 +1,13 @@
// RUN: mlir-opt %s --load-pass-plugin=%standalone_libs/libStandalonePlugin.so --pass-pipeline="builtin.module(standalone-switch-bar-foo)" | FileCheck %s

module {
// CHECK-LABEL: func @foo()
func.func @bar() {
return
}

// CHECK-LABEL: func @abar()
func.func @abar() {
return
}
}
13 changes: 13 additions & 0 deletions mlir/examples/standalone/test/Standalone/standalone-plugin.mlir
@@ -0,0 +1,13 @@
// RUN: mlir-opt %s --load-dialect-plugin=%standalone_libs/libStandalonePlugin.so --pass-pipeline="builtin.module(standalone-switch-bar-foo)" | FileCheck %s

module {
// CHECK-LABEL: func @foo()
func.func @bar() {
return
}

// CHECK-LABEL: func @standalone_types(%arg0: !standalone.custom<"10">)
func.func @standalone_types(%arg0: !standalone.custom<"10">) {
return
}
}
4 changes: 4 additions & 0 deletions mlir/examples/standalone/test/lit.cfg.py
Expand Up @@ -44,12 +44,16 @@
# test_exec_root: The root path where tests should be run.
config.test_exec_root = os.path.join(config.standalone_obj_root, 'test')
config.standalone_tools_dir = os.path.join(config.standalone_obj_root, 'bin')
config.standalone_libs_dir = os.path.join(config.standalone_obj_root, 'lib')

config.substitutions.append(('%standalone_libs', config.standalone_libs_dir))

# Tweak the PATH to include the tools dir.
llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True)

tool_dirs = [config.standalone_tools_dir, config.llvm_tools_dir]
tools = [
'mlir-opt',
'standalone-capi-test',
'standalone-opt',
'standalone-translate',
Expand Down
106 changes: 106 additions & 0 deletions mlir/include/mlir/Tools/Plugins/DialectPlugin.h
@@ -0,0 +1,106 @@
//===- mlir/Tools/Plugins/DialectPlugin.h - Public Plugin API -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This defines the public entry point for dialect plugins.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_TOOLS_PLUGINS_DIALECTPLUGIN_H
#define MLIR_TOOLS_PLUGINS_DIALECTPLUGIN_H

#include "mlir/IR/DialectRegistry.h"
#include "mlir/Tools/Plugins/PassPlugin.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/DynamicLibrary.h"
#include "llvm/Support/Error.h"
#include <cstdint>
#include <string>

namespace mlir {
extern "C" {
/// Information about the plugin required to load its dialects & passes
///
/// This struct defines the core interface for dialect plugins and is supposed
/// to be filled out by plugin implementors. MLIR-side users of a plugin are
/// expected to use the \c DialectPlugin class below to interface with it.
struct DialectPluginLibraryInfo {
/// The API version understood by this plugin, usually
/// \c MLIR_PLUGIN_API_VERSION
uint32_t apiVersion;
/// A meaningful name of the plugin.
const char *pluginName;
/// The version of the plugin.
const char *pluginVersion;

/// The callback for registering dialect plugin with a \c DialectRegistry
/// instance
void (*registerDialectRegistryCallbacks)(DialectRegistry *);
};
}

/// A loaded dialect plugin.
///
/// An instance of this class wraps a loaded dialect plugin and gives access to
/// its interface defined by the \c DialectPluginLibraryInfo it exposes.
class DialectPlugin {
public:
/// Attempts to load a dialect plugin from a given file.
///
/// \returns Returns an error if either the library cannot be found or loaded,
/// there is no public entry point, or the plugin implements the wrong API
/// version.
static llvm::Expected<DialectPlugin> load(const std::string &filename);

/// Get the filename of the loaded plugin.
StringRef getFilename() const { return filename; }

/// Get the plugin name
StringRef getPluginName() const { return info.pluginName; }

/// Get the plugin version
StringRef getPluginVersion() const { return info.pluginVersion; }

/// Get the plugin API version
uint32_t getAPIVersion() const { return info.apiVersion; }

/// Invoke the DialectRegistry callback registration
void
registerDialectRegistryCallbacks(DialectRegistry &dialectRegistry) const {
info.registerDialectRegistryCallbacks(&dialectRegistry);
}

private:
DialectPlugin(const std::string &filename,
const llvm::sys::DynamicLibrary &library)
: filename(filename), library(library), info() {}

std::string filename;
llvm::sys::DynamicLibrary library;
DialectPluginLibraryInfo info;
};
} // namespace mlir

/// The public entry point for a dialect plugin.
///
/// When a plugin is loaded by the driver, it will call this entry point to
/// obtain information about this plugin and about how to register its dialects.
/// This function needs to be implemented by the plugin, see the example below:
///
/// ```
/// extern "C" ::mlir::DialectPluginLibraryInfo LLVM_ATTRIBUTE_WEAK
/// mlirGetDialectPluginInfo() {
/// return {
/// MLIR_PLUGIN_API_VERSION, "MyPlugin", "v0.1", [](DialectRegistry) { ... }
/// };
/// }
/// ```
extern "C" ::mlir::DialectPluginLibraryInfo LLVM_ATTRIBUTE_WEAK
mlirGetDialectPluginInfo();

#endif /* MLIR_TOOLS_PLUGINS_DIALECTPLUGIN_H */

0 comments on commit 5e2afe5

Please sign in to comment.