diff --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt index 44493b75b8a8c..bd30d94e1ccb4 100644 --- a/mlir/CMakeLists.txt +++ b/mlir/CMakeLists.txt @@ -202,6 +202,24 @@ if(MLIR_ENABLE_BINDINGS_PYTHON) mlir_configure_python_dev_packages() endif() +#------------------------------------------------------------------------------- +# MLIR Pattern Catalog Generator Configuration +# Requires: +# RTTI to be enabled (set with -DLLVM_ENABLE_RTTI=ON) +# When enabled, causes all rewriter patterns to dump their type names and the +# names of affected operations, which can be used to build a search index +# mapping operations to patterns. +#------------------------------------------------------------------------------- + +set(MLIR_ENABLE_CATALOG_GENERATOR 0 CACHE BOOL + "Enables construction of a catalog of rewrite patterns.") + +if (MLIR_ENABLE_CATALOG_GENERATOR) + message(STATUS "Enabling MLIR pattern catalog generator") + add_definitions(-DMLIR_ENABLE_CATALOG_GENERATOR) + add_definitions(-DLLVM_ENABLE_RTTI) +endif() + set(CMAKE_INCLUDE_CURRENT_DIR ON) include_directories(BEFORE @@ -322,3 +340,4 @@ endif() if(MLIR_STANDALONE_BUILD) llvm_distribution_add_targets() endif() + diff --git a/mlir/cmake/modules/MLIRConfig.cmake.in b/mlir/cmake/modules/MLIRConfig.cmake.in index 71f3e028b1e88..f4ae70a22b3d2 100644 --- a/mlir/cmake/modules/MLIRConfig.cmake.in +++ b/mlir/cmake/modules/MLIRConfig.cmake.in @@ -16,6 +16,7 @@ set(MLIR_IRDL_TO_CPP_EXE "@MLIR_CONFIG_IRDL_TO_CPP_EXE@") set(MLIR_INSTALL_AGGREGATE_OBJECTS "@MLIR_INSTALL_AGGREGATE_OBJECTS@") set(MLIR_ENABLE_BINDINGS_PYTHON "@MLIR_ENABLE_BINDINGS_PYTHON@") set(MLIR_ENABLE_EXECUTION_ENGINE "@MLIR_ENABLE_EXECUTION_ENGINE@") +set(MLIR_ENABLE_CATALOG_GENERATOR "@MLIR_ENABLE_CATALOG_GENERATOR@") set_property(GLOBAL PROPERTY MLIR_ALL_LIBS "@MLIR_ALL_LIBS@") set_property(GLOBAL PROPERTY MLIR_DIALECT_LIBS "@MLIR_DIALECT_LIBS@") diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index 10cfe851765dc..141b3c6806ed8 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -475,6 +475,80 @@ class RewriterBase : public OpBuilder { RewriterBase::Listener *rewriteListener; }; + struct CatalogingListener : public RewriterBase::ForwardingListener { + CatalogingListener(OpBuilder::Listener *listener, + const std::string &patternName, raw_ostream &os, + std::mutex &writeMutex) + : RewriterBase::ForwardingListener(listener), patternName(patternName), + os(os), writeMutex(writeMutex) {} + + void notifyOperationInserted(Operation *op, InsertPoint previous) override { + { + std::lock_guard lock(writeMutex); + os << patternName << " | notifyOperationInserted" + << " | " << op->getName() << "\n"; + os.flush(); + } + ForwardingListener::notifyOperationInserted(op, previous); + } + + void notifyOperationModified(Operation *op) override { + { + std::lock_guard lock(writeMutex); + os << patternName << " | notifyOperationModified" + << " | " << op->getName() << "\n"; + os.flush(); + } + ForwardingListener::notifyOperationModified(op); + } + + void notifyOperationReplaced(Operation *op, Operation *newOp) override { + { + std::lock_guard lock(writeMutex); + os << patternName << " | notifyOperationReplaced (with op)" + << " | " << op->getName() << " | " << newOp->getName() << "\n"; + os.flush(); + } + ForwardingListener::notifyOperationReplaced(op, newOp); + } + + void notifyOperationReplaced(Operation *op, + ValueRange replacement) override { + { + std::lock_guard lock(writeMutex); + os << patternName << " | notifyOperationReplaced (with values)" + << " | " << op->getName() << "\n"; + os.flush(); + } + ForwardingListener::notifyOperationReplaced(op, replacement); + } + + void notifyOperationErased(Operation *op) override { + { + std::lock_guard lock(writeMutex); + os << patternName << " | notifyOperationErased" + << " | " << op->getName() << "\n"; + os.flush(); + } + ForwardingListener::notifyOperationErased(op); + } + + void notifyPatternBegin(const Pattern &pattern, Operation *op) override { + { + std::lock_guard lock(writeMutex); + os << patternName << " | notifyPatternBegin" + << " | " << op->getName() << "\n"; + os.flush(); + } + ForwardingListener::notifyPatternBegin(pattern, op); + } + + private: + const std::string &patternName; + raw_ostream &os; + std::mutex &writeMutex; + }; + /// Move the blocks that belong to "region" before the given position in /// another region "parent". The two regions must be different. The caller /// is responsible for creating or updating the operation transferring flow diff --git a/mlir/lib/Rewrite/PatternApplicator.cpp b/mlir/lib/Rewrite/PatternApplicator.cpp index 4a12183492fd4..c66aaf267881f 100644 --- a/mlir/lib/Rewrite/PatternApplicator.cpp +++ b/mlir/lib/Rewrite/PatternApplicator.cpp @@ -15,8 +15,19 @@ #include "ByteCode.h" #include "llvm/Support/Debug.h" +#ifdef MLIR_ENABLE_CATALOG_GENERATOR +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/raw_ostream.h" +#include +#include +#endif + #define DEBUG_TYPE "pattern-application" +#ifdef MLIR_ENABLE_CATALOG_GENERATOR +static std::mutex catalogWriteMutex; +#endif + using namespace mlir; using namespace mlir::detail; @@ -152,6 +163,16 @@ LogicalResult PatternApplicator::matchAndRewrite( unsigned anyIt = 0, anyE = anyOpPatterns.size(); unsigned pdlIt = 0, pdlE = pdlMatches.size(); LogicalResult result = failure(); +#ifdef MLIR_ENABLE_CATALOG_GENERATOR + std::error_code ec; + llvm::raw_fd_ostream catalogOs("pattern_catalog.txt", ec, + llvm::sys::fs::OF_Append); + if (ec) { + op->emitError("Failed to open pattern catalog file: " + ec.message()); + return failure(); + } +#endif + do { // Find the next pattern with the highest benefit. const Pattern *bestPattern = nullptr; @@ -206,14 +227,38 @@ LogicalResult PatternApplicator::matchAndRewrite( } else { LLVM_DEBUG(llvm::dbgs() << "Trying to match \"" << bestPattern->getDebugName() << "\"\n"); - const auto *pattern = static_cast(bestPattern); - result = pattern->matchAndRewrite(op, rewriter); +#ifdef MLIR_ENABLE_CATALOG_GENERATOR + OpBuilder::Listener *oldListener = rewriter.getListener(); + int status; + const char *mangledPatternName = typeid(*pattern).name(); + char *demangled = abi::__cxa_demangle(mangledPatternName, nullptr, + nullptr, &status); + std::string demangledPatternName; + if (status == 0 && demangled) { + demangledPatternName = demangled; + free(demangled); + } else { + // Fallback in case demangling fails. + demangledPatternName = mangledPatternName; + } + + RewriterBase::CatalogingListener *catalogingListener = + new RewriterBase::CatalogingListener( + oldListener, demangledPatternName, catalogOs, + catalogWriteMutex); + rewriter.setListener(catalogingListener); +#endif + result = pattern->matchAndRewrite(op, rewriter); LLVM_DEBUG(llvm::dbgs() << "\"" << bestPattern->getDebugName() << "\" result " << succeeded(result) << "\n"); +#ifdef MLIR_ENABLE_CATALOG_GENERATOR + rewriter.setListener(oldListener); + delete catalogingListener; +#endif } // Process the result of the pattern application. diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp index 31e0caa768113..3f491932f3989 100644 --- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp +++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp @@ -527,6 +527,7 @@ static LogicalResult processBuffer(raw_ostream &os, // Create a context just for the current buffer. Disable threading on creation // since we'll inject the thread-pool separately. MLIRContext context(registry, MLIRContext::Threading::DISABLED); + if (threadPool) context.setThreadPool(*threadPool); diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel index 23d89f41a3a45..281b2566304b0 100644 --- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel @@ -44,6 +44,7 @@ expand_template( "@MLIR_ENABLE_SPIRV_CPU_RUNNER@": "0", "@MLIR_ENABLE_VULKAN_RUNNER@": "0", "@MLIR_ENABLE_BINDINGS_PYTHON@": "0", + "@MLIR_ENABLE_CATALOG_GENERATOR@": "0", "@MLIR_RUN_AMX_TESTS@": "0", "@MLIR_RUN_ARM_SVE_TESTS@": "0", "@MLIR_RUN_ARM_SME_TESTS@": "0",