Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions mlir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,24 @@ if(MLIR_ENABLE_BINDINGS_PYTHON)
mlir_configure_python_dev_packages()
endif()

#-------------------------------------------------------------------------------
# MLIR Pattern Catalog Generator Configuration
# Requires:
# RTTI to be enabled (set with -DLLVM_ENABLE_RTTI=ON)
# When enabled, causes all rewriter patterns to dump their type names and the
# names of affected operations, which can be used to build a search index
# mapping operations to patterns.
#-------------------------------------------------------------------------------

set(MLIR_ENABLE_CATALOG_GENERATOR 0 CACHE BOOL
"Enables construction of a catalog of rewrite patterns.")

if (MLIR_ENABLE_CATALOG_GENERATOR)
message(STATUS "Enabling MLIR pattern catalog generator")
add_definitions(-DMLIR_ENABLE_CATALOG_GENERATOR)
add_definitions(-DLLVM_ENABLE_RTTI)
endif()

set(CMAKE_INCLUDE_CURRENT_DIR ON)

include_directories(BEFORE
Expand Down Expand Up @@ -322,3 +340,4 @@ endif()
if(MLIR_STANDALONE_BUILD)
llvm_distribution_add_targets()
endif()

1 change: 1 addition & 0 deletions mlir/cmake/modules/MLIRConfig.cmake.in
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ set(MLIR_IRDL_TO_CPP_EXE "@MLIR_CONFIG_IRDL_TO_CPP_EXE@")
set(MLIR_INSTALL_AGGREGATE_OBJECTS "@MLIR_INSTALL_AGGREGATE_OBJECTS@")
set(MLIR_ENABLE_BINDINGS_PYTHON "@MLIR_ENABLE_BINDINGS_PYTHON@")
set(MLIR_ENABLE_EXECUTION_ENGINE "@MLIR_ENABLE_EXECUTION_ENGINE@")
set(MLIR_ENABLE_CATALOG_GENERATOR "@MLIR_ENABLE_CATALOG_GENERATOR@")

set_property(GLOBAL PROPERTY MLIR_ALL_LIBS "@MLIR_ALL_LIBS@")
set_property(GLOBAL PROPERTY MLIR_DIALECT_LIBS "@MLIR_DIALECT_LIBS@")
Expand Down
74 changes: 74 additions & 0 deletions mlir/include/mlir/IR/PatternMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,80 @@ class RewriterBase : public OpBuilder {
RewriterBase::Listener *rewriteListener;
};

struct CatalogingListener : public RewriterBase::ForwardingListener {
CatalogingListener(OpBuilder::Listener *listener,
const std::string &patternName, raw_ostream &os,
std::mutex &writeMutex)
: RewriterBase::ForwardingListener(listener), patternName(patternName),
os(os), writeMutex(writeMutex) {}

void notifyOperationInserted(Operation *op, InsertPoint previous) override {
{
std::lock_guard<std::mutex> lock(writeMutex);
os << patternName << " | notifyOperationInserted"
<< " | " << op->getName() << "\n";
os.flush();
}
ForwardingListener::notifyOperationInserted(op, previous);
}

void notifyOperationModified(Operation *op) override {
{
std::lock_guard<std::mutex> lock(writeMutex);
os << patternName << " | notifyOperationModified"
<< " | " << op->getName() << "\n";
os.flush();
}
ForwardingListener::notifyOperationModified(op);
}

void notifyOperationReplaced(Operation *op, Operation *newOp) override {
{
std::lock_guard<std::mutex> lock(writeMutex);
os << patternName << " | notifyOperationReplaced (with op)"
<< " | " << op->getName() << " | " << newOp->getName() << "\n";
os.flush();
}
ForwardingListener::notifyOperationReplaced(op, newOp);
}

void notifyOperationReplaced(Operation *op,
ValueRange replacement) override {
{
std::lock_guard<std::mutex> lock(writeMutex);
os << patternName << " | notifyOperationReplaced (with values)"
<< " | " << op->getName() << "\n";
os.flush();
}
ForwardingListener::notifyOperationReplaced(op, replacement);
}

void notifyOperationErased(Operation *op) override {
{
std::lock_guard<std::mutex> lock(writeMutex);
os << patternName << " | notifyOperationErased"
<< " | " << op->getName() << "\n";
os.flush();
}
ForwardingListener::notifyOperationErased(op);
}

void notifyPatternBegin(const Pattern &pattern, Operation *op) override {
{
std::lock_guard<std::mutex> lock(writeMutex);
os << patternName << " | notifyPatternBegin"
<< " | " << op->getName() << "\n";
os.flush();
}
ForwardingListener::notifyPatternBegin(pattern, op);
}

private:
const std::string &patternName;
raw_ostream &os;
std::mutex &writeMutex;
};

/// Move the blocks that belong to "region" before the given position in
/// another region "parent". The two regions must be different. The caller
/// is responsible for creating or updating the operation transferring flow
Expand Down
49 changes: 47 additions & 2 deletions mlir/lib/Rewrite/PatternApplicator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,19 @@
#include "ByteCode.h"
#include "llvm/Support/Debug.h"

#ifdef MLIR_ENABLE_CATALOG_GENERATOR
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/raw_ostream.h"
#include <cxxabi.h>
#include <mutex>
#endif

#define DEBUG_TYPE "pattern-application"

#ifdef MLIR_ENABLE_CATALOG_GENERATOR
static std::mutex catalogWriteMutex;
#endif

using namespace mlir;
using namespace mlir::detail;

Expand Down Expand Up @@ -152,6 +163,16 @@ LogicalResult PatternApplicator::matchAndRewrite(
unsigned anyIt = 0, anyE = anyOpPatterns.size();
unsigned pdlIt = 0, pdlE = pdlMatches.size();
LogicalResult result = failure();
#ifdef MLIR_ENABLE_CATALOG_GENERATOR
std::error_code ec;
llvm::raw_fd_ostream catalogOs("pattern_catalog.txt", ec,
llvm::sys::fs::OF_Append);
if (ec) {
op->emitError("Failed to open pattern catalog file: " + ec.message());
return failure();
}
#endif

do {
// Find the next pattern with the highest benefit.
const Pattern *bestPattern = nullptr;
Expand Down Expand Up @@ -206,14 +227,38 @@ LogicalResult PatternApplicator::matchAndRewrite(
} else {
LLVM_DEBUG(llvm::dbgs() << "Trying to match \""
<< bestPattern->getDebugName() << "\"\n");

const auto *pattern =
static_cast<const RewritePattern *>(bestPattern);
result = pattern->matchAndRewrite(op, rewriter);

#ifdef MLIR_ENABLE_CATALOG_GENERATOR
OpBuilder::Listener *oldListener = rewriter.getListener();
int status;
const char *mangledPatternName = typeid(*pattern).name();
char *demangled = abi::__cxa_demangle(mangledPatternName, nullptr,
nullptr, &status);
std::string demangledPatternName;
if (status == 0 && demangled) {
demangledPatternName = demangled;
free(demangled);
} else {
// Fallback in case demangling fails.
demangledPatternName = mangledPatternName;
}

RewriterBase::CatalogingListener *catalogingListener =
new RewriterBase::CatalogingListener(
oldListener, demangledPatternName, catalogOs,
catalogWriteMutex);
rewriter.setListener(catalogingListener);
#endif
result = pattern->matchAndRewrite(op, rewriter);
LLVM_DEBUG(llvm::dbgs()
<< "\"" << bestPattern->getDebugName() << "\" result "
<< succeeded(result) << "\n");
#ifdef MLIR_ENABLE_CATALOG_GENERATOR
rewriter.setListener(oldListener);
delete catalogingListener;
#endif
}

// Process the result of the pattern application.
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,7 @@ static LogicalResult processBuffer(raw_ostream &os,
// Create a context just for the current buffer. Disable threading on creation
// since we'll inject the thread-pool separately.
MLIRContext context(registry, MLIRContext::Threading::DISABLED);

if (threadPool)
context.setThreadPool(*threadPool);

Expand Down
1 change: 1 addition & 0 deletions utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ expand_template(
"@MLIR_ENABLE_SPIRV_CPU_RUNNER@": "0",
"@MLIR_ENABLE_VULKAN_RUNNER@": "0",
"@MLIR_ENABLE_BINDINGS_PYTHON@": "0",
"@MLIR_ENABLE_CATALOG_GENERATOR@": "0",
"@MLIR_RUN_AMX_TESTS@": "0",
"@MLIR_RUN_ARM_SVE_TESTS@": "0",
"@MLIR_RUN_ARM_SME_TESTS@": "0",
Expand Down