diff --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt index dcc068e4097c5..3de8677aefe90 100644 --- a/mlir/CMakeLists.txt +++ b/mlir/CMakeLists.txt @@ -133,6 +133,8 @@ set(MLIR_ENABLE_NVPTXCOMPILER 0 CACHE BOOL "Statically link the nvptxlibrary instead of calling ptxas as a subprocess \ for compiling PTX to cubin") +set(MLIR_ENABLE_PDL_IN_PATTERNMATCH 1 CACHE BOOL "Enable PDL in PatternMatch") + option(MLIR_INCLUDE_TESTS "Generate build targets for the MLIR unit tests." ${LLVM_INCLUDE_TESTS}) @@ -178,10 +180,9 @@ include_directories( ${MLIR_INCLUDE_DIR}) # Adding tools/mlir-tblgen here as calling add_tablegen sets some variables like # MLIR_TABLEGEN_EXE in PARENT_SCOPE which gets lost if that folder is included # from another directory like tools -add_subdirectory(tools/mlir-tblgen) add_subdirectory(tools/mlir-linalg-ods-gen) add_subdirectory(tools/mlir-pdll) - +add_subdirectory(tools/mlir-tblgen) set(MLIR_TABLEGEN_EXE "${MLIR_TABLEGEN_EXE}" CACHE INTERNAL "") set(MLIR_TABLEGEN_TARGET "${MLIR_TABLEGEN_TARGET}" CACHE INTERNAL "") set(MLIR_LINALG_ODS_YAML_GEN_TABLEGEN_EXE "${MLIR_LINALG_ODS_YAML_GEN_TABLEGEN_EXE}" CACHE INTERNAL "") diff --git a/mlir/examples/minimal-opt/README.md b/mlir/examples/minimal-opt/README.md index b8a455f7a7966..1bc54b8367cc5 100644 --- a/mlir/examples/minimal-opt/README.md +++ b/mlir/examples/minimal-opt/README.md @@ -14,10 +14,10 @@ Below are some example measurements taken at the time of the LLVM 17 release, using clang-14 on a X86 Ubuntu and [bloaty](https://github.com/google/bloaty). | | Base | Os | Oz | Os LTO | Oz LTO | -| :-----------------------------: | ------ | ------ | ------ | ------ | ------ | -| `mlir-cat` | 1018kB | 836KB | 879KB | 697KB | 649KB | -| `mlir-minimal-opt` | 1.54MB | 1.25MB | 1.29MB | 1.10MB | 1.00MB | -| `mlir-minimal-opt-canonicalize` | 2.24MB | 1.81MB | 1.86MB | 1.62MB | 1.48MB | +| :------------------------------: | ------ | ------ | ------ | ------ | ------ | +| `mlir-cat` | 1024KB | 840KB | 885KB | 706KB | 657KB | +| `mlir-minimal-opt` | 1.62MB | 1.32MB | 1.36MB | 1.17MB | 1.07MB | +| `mlir-minimal-opt-canonicalize` | 1.83MB | 1.40MB | 1.45MB | 1.25MB | 1.14MB | Base configuration: @@ -32,6 +32,7 @@ cmake ../llvm/ -G Ninja \ -DCMAKE_CXX_COMPILER=clang++ \ -DLLVM_ENABLE_LLD=ON \ -DLLVM_ENABLE_BACKTRACES=OFF \ + -DMLIR_ENABLE_PDL_IN_PATTERNMATCH=OFF \ -DCMAKE_EXE_LINKER_FLAGS_RELWITHDEBINFO=-Wl,-icf=all ``` diff --git a/mlir/include/mlir/Config/mlir-config.h.cmake b/mlir/include/mlir/Config/mlir-config.h.cmake index efa77b2e5ce5d..e152a36c0ce0c 100644 --- a/mlir/include/mlir/Config/mlir-config.h.cmake +++ b/mlir/include/mlir/Config/mlir-config.h.cmake @@ -26,4 +26,7 @@ numeric seed that is passed to the random number generator. */ #cmakedefine MLIR_GREEDY_REWRITE_RANDOMIZER_SEED ${MLIR_GREEDY_REWRITE_RANDOMIZER_SEED} +/* If set, enables PDL usage. */ +#cmakedefine01 MLIR_ENABLE_PDL_IN_PATTERNMATCH + #endif diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h index 74f9c977b7028..e228229302cff 100644 --- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h @@ -15,6 +15,7 @@ #define MLIR_CONVERSION_LLVMCOMMON_TYPECONVERTER_H #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/Transforms/DialectConversion.h" namespace mlir { diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h index a28b27e4e1581..4603953cb40fa 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h @@ -29,6 +29,7 @@ #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/VectorInterfaces.h" #include "mlir/Interfaces/ViewLikeInterface.h" +#include "llvm/ADT/SetVector.h" #include "llvm/ADT/StringExtras.h" // Pull in all enum type definitions and utility function declarations. diff --git a/mlir/include/mlir/IR/PDLPatternMatch.h.inc b/mlir/include/mlir/IR/PDLPatternMatch.h.inc new file mode 100644 index 0000000000000..a215da8cb6431 --- /dev/null +++ b/mlir/include/mlir/IR/PDLPatternMatch.h.inc @@ -0,0 +1,995 @@ +//===- PDLPatternMatch.h - PDLPatternMatcher classes -------==---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_PDLPATTERNMATCH_H +#define MLIR_IR_PDLPATTERNMATCH_H + +#include "mlir/Config/mlir-config.h" + +#if MLIR_ENABLE_PDL_IN_PATTERNMATCH +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" + +namespace mlir { +//===----------------------------------------------------------------------===// +// PDL Patterns +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// PDLValue + +/// Storage type of byte-code interpreter values. These are passed to constraint +/// functions as arguments. +class PDLValue { +public: + /// The underlying kind of a PDL value. + enum class Kind { Attribute, Operation, Type, TypeRange, Value, ValueRange }; + + /// Construct a new PDL value. + PDLValue(const PDLValue &other) = default; + PDLValue(std::nullptr_t = nullptr) {} + PDLValue(Attribute value) + : value(value.getAsOpaquePointer()), kind(Kind::Attribute) {} + PDLValue(Operation *value) : value(value), kind(Kind::Operation) {} + PDLValue(Type value) : value(value.getAsOpaquePointer()), kind(Kind::Type) {} + PDLValue(TypeRange *value) : value(value), kind(Kind::TypeRange) {} + PDLValue(Value value) + : value(value.getAsOpaquePointer()), kind(Kind::Value) {} + PDLValue(ValueRange *value) : value(value), kind(Kind::ValueRange) {} + + /// Returns true if the type of the held value is `T`. + template + bool isa() const { + assert(value && "isa<> used on a null value"); + return kind == getKindOf(); + } + + /// Attempt to dynamically cast this value to type `T`, returns null if this + /// value is not an instance of `T`. + template ::value, T, std::optional>> + ResultT dyn_cast() const { + return isa() ? castImpl() : ResultT(); + } + + /// Cast this value to type `T`, asserts if this value is not an instance of + /// `T`. + template + T cast() const { + assert(isa() && "expected value to be of type `T`"); + return castImpl(); + } + + /// Get an opaque pointer to the value. + const void *getAsOpaquePointer() const { return value; } + + /// Return if this value is null or not. + explicit operator bool() const { return value; } + + /// Return the kind of this value. + Kind getKind() const { return kind; } + + /// Print this value to the provided output stream. + void print(raw_ostream &os) const; + + /// Print the specified value kind to an output stream. + static void print(raw_ostream &os, Kind kind); + +private: + /// Find the index of a given type in a range of other types. + template + struct index_of_t; + template + struct index_of_t : std::integral_constant {}; + template + struct index_of_t + : std::integral_constant::value> {}; + + /// Return the kind used for the given T. + template + static Kind getKindOf() { + return static_cast(index_of_t::value); + } + + /// The internal implementation of `cast`, that returns the underlying value + /// as the given type `T`. + template + std::enable_if_t::value, T> + castImpl() const { + return T::getFromOpaquePointer(value); + } + template + std::enable_if_t::value, T> + castImpl() const { + return *reinterpret_cast(const_cast(value)); + } + template + std::enable_if_t::value, T> castImpl() const { + return reinterpret_cast(const_cast(value)); + } + + /// The internal opaque representation of a PDLValue. + const void *value{nullptr}; + /// The kind of the opaque value. + Kind kind{Kind::Attribute}; +}; + +inline raw_ostream &operator<<(raw_ostream &os, PDLValue value) { + value.print(os); + return os; +} + +inline raw_ostream &operator<<(raw_ostream &os, PDLValue::Kind kind) { + PDLValue::print(os, kind); + return os; +} + +//===----------------------------------------------------------------------===// +// PDLResultList + +/// The class represents a list of PDL results, returned by a native rewrite +/// method. It provides the mechanism with which to pass PDLValues back to the +/// PDL bytecode. +class PDLResultList { +public: + /// Push a new Attribute value onto the result list. + void push_back(Attribute value) { results.push_back(value); } + + /// Push a new Operation onto the result list. + void push_back(Operation *value) { results.push_back(value); } + + /// Push a new Type onto the result list. + void push_back(Type value) { results.push_back(value); } + + /// Push a new TypeRange onto the result list. + void push_back(TypeRange value) { + // The lifetime of a TypeRange can't be guaranteed, so we'll need to + // allocate a storage for it. + llvm::OwningArrayRef storage(value.size()); + llvm::copy(value, storage.begin()); + allocatedTypeRanges.emplace_back(std::move(storage)); + typeRanges.push_back(allocatedTypeRanges.back()); + results.push_back(&typeRanges.back()); + } + void push_back(ValueTypeRange value) { + typeRanges.push_back(value); + results.push_back(&typeRanges.back()); + } + void push_back(ValueTypeRange value) { + typeRanges.push_back(value); + results.push_back(&typeRanges.back()); + } + + /// Push a new Value onto the result list. + void push_back(Value value) { results.push_back(value); } + + /// Push a new ValueRange onto the result list. + void push_back(ValueRange value) { + // The lifetime of a ValueRange can't be guaranteed, so we'll need to + // allocate a storage for it. + llvm::OwningArrayRef storage(value.size()); + llvm::copy(value, storage.begin()); + allocatedValueRanges.emplace_back(std::move(storage)); + valueRanges.push_back(allocatedValueRanges.back()); + results.push_back(&valueRanges.back()); + } + void push_back(OperandRange value) { + valueRanges.push_back(value); + results.push_back(&valueRanges.back()); + } + void push_back(ResultRange value) { + valueRanges.push_back(value); + results.push_back(&valueRanges.back()); + } + +protected: + /// Create a new result list with the expected number of results. + PDLResultList(unsigned maxNumResults) { + // For now just reserve enough space for all of the results. We could do + // separate counts per range type, but it isn't really worth it unless there + // are a "large" number of results. + typeRanges.reserve(maxNumResults); + valueRanges.reserve(maxNumResults); + } + + /// The PDL results held by this list. + SmallVector results; + /// Memory used to store ranges held by the list. + SmallVector typeRanges; + SmallVector valueRanges; + /// Memory allocated to store ranges in the result list whose lifetime was + /// generated in the native function. + SmallVector> allocatedTypeRanges; + SmallVector> allocatedValueRanges; +}; + +//===----------------------------------------------------------------------===// +// PDLPatternConfig + +/// An individual configuration for a pattern, which can be accessed by native +/// functions via the PDLPatternConfigSet. This allows for injecting additional +/// configuration into PDL patterns that is specific to certain compilation +/// flows. +class PDLPatternConfig { +public: + virtual ~PDLPatternConfig() = default; + + /// Hooks that are invoked at the beginning and end of a rewrite of a matched + /// pattern. These can be used to setup any specific state necessary for the + /// rewrite. + virtual void notifyRewriteBegin(PatternRewriter &rewriter) {} + virtual void notifyRewriteEnd(PatternRewriter &rewriter) {} + + /// Return the TypeID that represents this configuration. + TypeID getTypeID() const { return id; } + +protected: + PDLPatternConfig(TypeID id) : id(id) {} + +private: + TypeID id; +}; + +/// This class provides a base class for users implementing a type of pattern +/// configuration. +template +class PDLPatternConfigBase : public PDLPatternConfig { +public: + /// Support LLVM style casting. + static bool classof(const PDLPatternConfig *config) { + return config->getTypeID() == getConfigID(); + } + + /// Return the type id used for this configuration. + static TypeID getConfigID() { return TypeID::get(); } + +protected: + PDLPatternConfigBase() : PDLPatternConfig(getConfigID()) {} +}; + +/// This class contains a set of configurations for a specific pattern. +/// Configurations are uniqued by TypeID, meaning that only one configuration of +/// each type is allowed. +class PDLPatternConfigSet { +public: + PDLPatternConfigSet() = default; + + /// Construct a set with the given configurations. + template + PDLPatternConfigSet(ConfigsT &&...configs) { + (addConfig(std::forward(configs)), ...); + } + + /// Get the configuration defined by the given type. Asserts that the + /// configuration of the provided type exists. + template + const T &get() const { + const T *config = tryGet(); + assert(config && "configuration not found"); + return *config; + } + + /// Get the configuration defined by the given type, returns nullptr if the + /// configuration does not exist. + template + const T *tryGet() const { + for (const auto &configIt : configs) + if (const T *config = dyn_cast(configIt.get())) + return config; + return nullptr; + } + + /// Notify the configurations within this set at the beginning or end of a + /// rewrite of a matched pattern. + void notifyRewriteBegin(PatternRewriter &rewriter) { + for (const auto &config : configs) + config->notifyRewriteBegin(rewriter); + } + void notifyRewriteEnd(PatternRewriter &rewriter) { + for (const auto &config : configs) + config->notifyRewriteEnd(rewriter); + } + +protected: + /// Add a configuration to the set. + template + void addConfig(T &&config) { + assert(!tryGet>() && "configuration already exists"); + configs.emplace_back( + std::make_unique>(std::forward(config))); + } + + /// The set of configurations for this pattern. This uses a vector instead of + /// a map with the expectation that the number of configurations per set is + /// small (<= 1). + SmallVector> configs; +}; + +//===----------------------------------------------------------------------===// +// PDLPatternModule + +/// A generic PDL pattern constraint function. This function applies a +/// constraint to a given set of opaque PDLValue entities. Returns success if +/// the constraint successfully held, failure otherwise. +using PDLConstraintFunction = + std::function)>; +/// A native PDL rewrite function. This function performs a rewrite on the +/// given set of values. Any results from this rewrite that should be passed +/// back to PDL should be added to the provided result list. This method is only +/// invoked when the corresponding match was successful. Returns failure if an +/// invariant of the rewrite was broken (certain rewriters may recover from +/// partial pattern application). +using PDLRewriteFunction = std::function)>; + +namespace detail { +namespace pdl_function_builder { +/// A utility variable that always resolves to false. This is useful for static +/// asserts that are always false, but only should fire in certain templated +/// constructs. For example, if a templated function should never be called, the +/// function could be defined as: +/// +/// template +/// void foo() { +/// static_assert(always_false, "This function should never be called"); +/// } +/// +template +constexpr bool always_false = false; + +//===----------------------------------------------------------------------===// +// PDL Function Builder: Type Processing +//===----------------------------------------------------------------------===// + +/// This struct provides a convenient way to determine how to process a given +/// type as either a PDL parameter, or a result value. This allows for +/// supporting complex types in constraint and rewrite functions, without +/// requiring the user to hand-write the necessary glue code themselves. +/// Specializations of this class should implement the following methods to +/// enable support as a PDL argument or result type: +/// +/// static LogicalResult verifyAsArg( +/// function_ref errorFn, PDLValue pdlValue, +/// size_t argIdx); +/// +/// * This method verifies that the given PDLValue is valid for use as a +/// value of `T`. +/// +/// static T processAsArg(PDLValue pdlValue); +/// +/// * This method processes the given PDLValue as a value of `T`. +/// +/// static void processAsResult(PatternRewriter &, PDLResultList &results, +/// const T &value); +/// +/// * This method processes the given value of `T` as the result of a +/// function invocation. The method should package the value into an +/// appropriate form and append it to the given result list. +/// +/// If the type `T` is based on a higher order value, consider using +/// `ProcessPDLValueBasedOn` as a base class of the specialization to simplify +/// the implementation. +/// +template +struct ProcessPDLValue; + +/// This struct provides a simplified model for processing types that are based +/// on another type, e.g. APInt is based on the handling for IntegerAttr. This +/// allows for building the necessary processing functions on top of the base +/// value instead of a PDLValue. Derived users should implement the following +/// (which subsume the ProcessPDLValue variants): +/// +/// static LogicalResult verifyAsArg( +/// function_ref errorFn, +/// const BaseT &baseValue, size_t argIdx); +/// +/// * This method verifies that the given PDLValue is valid for use as a +/// value of `T`. +/// +/// static T processAsArg(BaseT baseValue); +/// +/// * This method processes the given base value as a value of `T`. +/// +template +struct ProcessPDLValueBasedOn { + static LogicalResult + verifyAsArg(function_ref errorFn, + PDLValue pdlValue, size_t argIdx) { + // Verify the base class before continuing. + if (failed(ProcessPDLValue::verifyAsArg(errorFn, pdlValue, argIdx))) + return failure(); + return ProcessPDLValue::verifyAsArg( + errorFn, ProcessPDLValue::processAsArg(pdlValue), argIdx); + } + static T processAsArg(PDLValue pdlValue) { + return ProcessPDLValue::processAsArg( + ProcessPDLValue::processAsArg(pdlValue)); + } + + /// Explicitly add the expected parent API to ensure the parent class + /// implements the necessary API (and doesn't implicitly inherit it from + /// somewhere else). + static LogicalResult + verifyAsArg(function_ref errorFn, BaseT value, + size_t argIdx) { + return success(); + } + static T processAsArg(BaseT baseValue); +}; + +/// This struct provides a simplified model for processing types that have +/// "builtin" PDLValue support: +/// * Attribute, Operation *, Type, TypeRange, ValueRange +template +struct ProcessBuiltinPDLValue { + static LogicalResult + verifyAsArg(function_ref errorFn, + PDLValue pdlValue, size_t argIdx) { + if (pdlValue) + return success(); + return errorFn("expected a non-null value for argument " + Twine(argIdx) + + " of type: " + llvm::getTypeName()); + } + + static T processAsArg(PDLValue pdlValue) { return pdlValue.cast(); } + static void processAsResult(PatternRewriter &, PDLResultList &results, + T value) { + results.push_back(value); + } +}; + +/// This struct provides a simplified model for processing types that inherit +/// from builtin PDLValue types. For example, derived attributes like +/// IntegerAttr, derived types like IntegerType, derived operations like +/// ModuleOp, Interfaces, etc. +template +struct ProcessDerivedPDLValue : public ProcessPDLValueBasedOn { + static LogicalResult + verifyAsArg(function_ref errorFn, + BaseT baseValue, size_t argIdx) { + return TypeSwitch(baseValue) + .Case([&](T) { return success(); }) + .Default([&](BaseT) { + return errorFn("expected argument " + Twine(argIdx) + + " to be of type: " + llvm::getTypeName()); + }); + } + using ProcessPDLValueBasedOn::verifyAsArg; + + static T processAsArg(BaseT baseValue) { + return baseValue.template cast(); + } + using ProcessPDLValueBasedOn::processAsArg; + + static void processAsResult(PatternRewriter &, PDLResultList &results, + T value) { + results.push_back(value); + } +}; + +//===----------------------------------------------------------------------===// +// Attribute + +template <> +struct ProcessPDLValue : public ProcessBuiltinPDLValue {}; +template +struct ProcessPDLValue::value>> + : public ProcessDerivedPDLValue {}; + +/// Handling for various Attribute value types. +template <> +struct ProcessPDLValue + : public ProcessPDLValueBasedOn { + static StringRef processAsArg(StringAttr value) { return value.getValue(); } + using ProcessPDLValueBasedOn::processAsArg; + + static void processAsResult(PatternRewriter &rewriter, PDLResultList &results, + StringRef value) { + results.push_back(rewriter.getStringAttr(value)); + } +}; +template <> +struct ProcessPDLValue + : public ProcessPDLValueBasedOn { + template + static std::string processAsArg(T value) { + static_assert(always_false, + "`std::string` arguments require a string copy, use " + "`StringRef` for string-like arguments instead"); + return {}; + } + static void processAsResult(PatternRewriter &rewriter, PDLResultList &results, + StringRef value) { + results.push_back(rewriter.getStringAttr(value)); + } +}; + +//===----------------------------------------------------------------------===// +// Operation + +template <> +struct ProcessPDLValue + : public ProcessBuiltinPDLValue {}; +template +struct ProcessPDLValue::value>> + : public ProcessDerivedPDLValue { + static T processAsArg(Operation *value) { return cast(value); } +}; + +//===----------------------------------------------------------------------===// +// Type + +template <> +struct ProcessPDLValue : public ProcessBuiltinPDLValue {}; +template +struct ProcessPDLValue::value>> + : public ProcessDerivedPDLValue {}; + +//===----------------------------------------------------------------------===// +// TypeRange + +template <> +struct ProcessPDLValue : public ProcessBuiltinPDLValue {}; +template <> +struct ProcessPDLValue> { + static void processAsResult(PatternRewriter &, PDLResultList &results, + ValueTypeRange types) { + results.push_back(types); + } +}; +template <> +struct ProcessPDLValue> { + static void processAsResult(PatternRewriter &, PDLResultList &results, + ValueTypeRange types) { + results.push_back(types); + } +}; +template +struct ProcessPDLValue> { + static void processAsResult(PatternRewriter &, PDLResultList &results, + SmallVector values) { + results.push_back(TypeRange(values)); + } +}; + +//===----------------------------------------------------------------------===// +// Value + +template <> +struct ProcessPDLValue : public ProcessBuiltinPDLValue {}; + +//===----------------------------------------------------------------------===// +// ValueRange + +template <> +struct ProcessPDLValue : public ProcessBuiltinPDLValue { +}; +template <> +struct ProcessPDLValue { + static void processAsResult(PatternRewriter &, PDLResultList &results, + OperandRange values) { + results.push_back(values); + } +}; +template <> +struct ProcessPDLValue { + static void processAsResult(PatternRewriter &, PDLResultList &results, + ResultRange values) { + results.push_back(values); + } +}; +template +struct ProcessPDLValue> { + static void processAsResult(PatternRewriter &, PDLResultList &results, + SmallVector values) { + results.push_back(ValueRange(values)); + } +}; + +//===----------------------------------------------------------------------===// +// PDL Function Builder: Argument Handling +//===----------------------------------------------------------------------===// + +/// Validate the given PDLValues match the constraints defined by the argument +/// types of the given function. In the case of failure, a match failure +/// diagnostic is emitted. +/// FIXME: This should be completely removed in favor of `assertArgs`, but PDL +/// does not currently preserve Constraint application ordering. +template +LogicalResult verifyAsArgs(PatternRewriter &rewriter, ArrayRef values, + std::index_sequence) { + using FnTraitsT = llvm::function_traits; + + auto errorFn = [&](const Twine &msg) { + return rewriter.notifyMatchFailure(rewriter.getUnknownLoc(), msg); + }; + return success( + (succeeded(ProcessPDLValue>:: + verifyAsArg(errorFn, values[I], I)) && + ...)); +} + +/// Assert that the given PDLValues match the constraints defined by the +/// arguments of the given function. In the case of failure, a fatal error +/// is emitted. +template +void assertArgs(PatternRewriter &rewriter, ArrayRef values, + std::index_sequence) { + // We only want to do verification in debug builds, same as with `assert`. +#if LLVM_ENABLE_ABI_BREAKING_CHECKS + using FnTraitsT = llvm::function_traits; + auto errorFn = [&](const Twine &msg) -> LogicalResult { + llvm::report_fatal_error(msg); + }; + (void)errorFn; + assert((succeeded(ProcessPDLValue>:: + verifyAsArg(errorFn, values[I], I)) && + ...)); +#endif + (void)values; +} + +//===----------------------------------------------------------------------===// +// PDL Function Builder: Results Handling +//===----------------------------------------------------------------------===// + +/// Store a single result within the result list. +template +static LogicalResult processResults(PatternRewriter &rewriter, + PDLResultList &results, T &&value) { + ProcessPDLValue::processAsResult(rewriter, results, + std::forward(value)); + return success(); +} + +/// Store a std::pair<> as individual results within the result list. +template +static LogicalResult processResults(PatternRewriter &rewriter, + PDLResultList &results, + std::pair &&pair) { + if (failed(processResults(rewriter, results, std::move(pair.first))) || + failed(processResults(rewriter, results, std::move(pair.second)))) + return failure(); + return success(); +} + +/// Store a std::tuple<> as individual results within the result list. +template +static LogicalResult processResults(PatternRewriter &rewriter, + PDLResultList &results, + std::tuple &&tuple) { + auto applyFn = [&](auto &&...args) { + return (succeeded(processResults(rewriter, results, std::move(args))) && + ...); + }; + return success(std::apply(applyFn, std::move(tuple))); +} + +/// Handle LogicalResult propagation. +inline LogicalResult processResults(PatternRewriter &rewriter, + PDLResultList &results, + LogicalResult &&result) { + return result; +} +template +static LogicalResult processResults(PatternRewriter &rewriter, + PDLResultList &results, + FailureOr &&result) { + if (failed(result)) + return failure(); + return processResults(rewriter, results, std::move(*result)); +} + +//===----------------------------------------------------------------------===// +// PDL Constraint Builder +//===----------------------------------------------------------------------===// + +/// Process the arguments of a native constraint and invoke it. +template > +typename FnTraitsT::result_t +processArgsAndInvokeConstraint(PDLFnT &fn, PatternRewriter &rewriter, + ArrayRef values, + std::index_sequence) { + return fn( + rewriter, + (ProcessPDLValue>::processAsArg( + values[I]))...); +} + +/// Build a constraint function from the given function `ConstraintFnT`. This +/// allows for enabling the user to define simpler, more direct constraint +/// functions without needing to handle the low-level PDL goop. +/// +/// If the constraint function is already in the correct form, we just forward +/// it directly. +template +std::enable_if_t< + std::is_convertible::value, + PDLConstraintFunction> +buildConstraintFn(ConstraintFnT &&constraintFn) { + return std::forward(constraintFn); +} +/// Otherwise, we generate a wrapper that will unpack the PDLValues in the form +/// we desire. +template +std::enable_if_t< + !std::is_convertible::value, + PDLConstraintFunction> +buildConstraintFn(ConstraintFnT &&constraintFn) { + return [constraintFn = std::forward(constraintFn)]( + PatternRewriter &rewriter, + ArrayRef values) -> LogicalResult { + auto argIndices = std::make_index_sequence< + llvm::function_traits::num_args - 1>(); + if (failed(verifyAsArgs(rewriter, values, argIndices))) + return failure(); + return processArgsAndInvokeConstraint(constraintFn, rewriter, values, + argIndices); + }; +} + +//===----------------------------------------------------------------------===// +// PDL Rewrite Builder +//===----------------------------------------------------------------------===// + +/// Process the arguments of a native rewrite and invoke it. +/// This overload handles the case of no return values. +template > +std::enable_if_t::value, + LogicalResult> +processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter, + PDLResultList &, ArrayRef values, + std::index_sequence) { + fn(rewriter, + (ProcessPDLValue>::processAsArg( + values[I]))...); + return success(); +} +/// This overload handles the case of return values, which need to be packaged +/// into the result list. +template > +std::enable_if_t::value, + LogicalResult> +processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter, + PDLResultList &results, ArrayRef values, + std::index_sequence) { + return processResults( + rewriter, results, + fn(rewriter, (ProcessPDLValue>:: + processAsArg(values[I]))...)); + (void)values; +} + +/// Build a rewrite function from the given function `RewriteFnT`. This +/// allows for enabling the user to define simpler, more direct rewrite +/// functions without needing to handle the low-level PDL goop. +/// +/// If the rewrite function is already in the correct form, we just forward +/// it directly. +template +std::enable_if_t::value, + PDLRewriteFunction> +buildRewriteFn(RewriteFnT &&rewriteFn) { + return std::forward(rewriteFn); +} +/// Otherwise, we generate a wrapper that will unpack the PDLValues in the form +/// we desire. +template +std::enable_if_t::value, + PDLRewriteFunction> +buildRewriteFn(RewriteFnT &&rewriteFn) { + return [rewriteFn = std::forward(rewriteFn)]( + PatternRewriter &rewriter, PDLResultList &results, + ArrayRef values) { + auto argIndices = + std::make_index_sequence::num_args - + 1>(); + assertArgs(rewriter, values, argIndices); + return processArgsAndInvokeRewrite(rewriteFn, rewriter, results, values, + argIndices); + }; +} + +} // namespace pdl_function_builder +} // namespace detail + +//===----------------------------------------------------------------------===// +// PDLPatternModule + +/// This class contains all of the necessary data for a set of PDL patterns, or +/// pattern rewrites specified in the form of the PDL dialect. This PDL module +/// contained by this pattern may contain any number of `pdl.pattern` +/// operations. +class PDLPatternModule { +public: + PDLPatternModule() = default; + + /// Construct a PDL pattern with the given module and configurations. + PDLPatternModule(OwningOpRef module) + : pdlModule(std::move(module)) {} + template + PDLPatternModule(OwningOpRef module, ConfigsT &&...patternConfigs) + : PDLPatternModule(std::move(module)) { + auto configSet = std::make_unique( + std::forward(patternConfigs)...); + attachConfigToPatterns(*pdlModule, *configSet); + configs.emplace_back(std::move(configSet)); + } + + /// Merge the state in `other` into this pattern module. + void mergeIn(PDLPatternModule &&other); + + /// Return the internal PDL module of this pattern. + ModuleOp getModule() { return pdlModule.get(); } + + /// Return the MLIR context of this pattern. + MLIRContext *getContext() { return getModule()->getContext(); } + + //===--------------------------------------------------------------------===// + // Function Registry + + /// Register a constraint function with PDL. A constraint function may be + /// specified in one of two ways: + /// + /// * `LogicalResult (PatternRewriter &, ArrayRef)` + /// + /// In this overload the arguments of the constraint function are passed via + /// the low-level PDLValue form. + /// + /// * `LogicalResult (PatternRewriter &, ValueTs... values)` + /// + /// In this form the arguments of the constraint function are passed via the + /// expected high level C++ type. In this form, the framework will + /// automatically unwrap PDLValues and convert them to the expected ValueTs. + /// For example, if the constraint function accepts a `Operation *`, the + /// framework will automatically cast the input PDLValue. In the case of a + /// `StringRef`, the framework will automatically unwrap the argument as a + /// StringAttr and pass the underlying string value. To see the full list of + /// supported types, or to see how to add handling for custom types, view + /// the definition of `ProcessPDLValue` above. + void registerConstraintFunction(StringRef name, + PDLConstraintFunction constraintFn); + template + void registerConstraintFunction(StringRef name, + ConstraintFnT &&constraintFn) { + registerConstraintFunction(name, + detail::pdl_function_builder::buildConstraintFn( + std::forward(constraintFn))); + } + + /// Register a rewrite function with PDL. A rewrite function may be specified + /// in one of two ways: + /// + /// * `void (PatternRewriter &, PDLResultList &, ArrayRef)` + /// + /// In this overload the arguments of the constraint function are passed via + /// the low-level PDLValue form, and the results are manually appended to + /// the given result list. + /// + /// * `ResultT (PatternRewriter &, ValueTs... values)` + /// + /// In this form the arguments and result of the rewrite function are passed + /// via the expected high level C++ type. In this form, the framework will + /// automatically unwrap the PDLValues arguments and convert them to the + /// expected ValueTs. It will also automatically handle the processing and + /// packaging of the result value to the result list. For example, if the + /// rewrite function takes a `Operation *`, the framework will automatically + /// cast the input PDLValue. In the case of a `StringRef`, the framework + /// will automatically unwrap the argument as a StringAttr and pass the + /// underlying string value. In the reverse case, if the rewrite returns a + /// StringRef or std::string, it will automatically package this as a + /// StringAttr and append it to the result list. To see the full list of + /// supported types, or to see how to add handling for custom types, view + /// the definition of `ProcessPDLValue` above. + void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn); + template + void registerRewriteFunction(StringRef name, RewriteFnT &&rewriteFn) { + registerRewriteFunction(name, detail::pdl_function_builder::buildRewriteFn( + std::forward(rewriteFn))); + } + + /// Return the set of the registered constraint functions. + const llvm::StringMap &getConstraintFunctions() const { + return constraintFunctions; + } + llvm::StringMap takeConstraintFunctions() { + return constraintFunctions; + } + /// Return the set of the registered rewrite functions. + const llvm::StringMap &getRewriteFunctions() const { + return rewriteFunctions; + } + llvm::StringMap takeRewriteFunctions() { + return rewriteFunctions; + } + + /// Return the set of the registered pattern configs. + SmallVector> takeConfigs() { + return std::move(configs); + } + DenseMap takeConfigMap() { + return std::move(configMap); + } + + /// Clear out the patterns and functions within this module. + void clear() { + pdlModule = nullptr; + constraintFunctions.clear(); + rewriteFunctions.clear(); + } + +private: + /// Attach the given pattern config set to the patterns defined within the + /// given module. + void attachConfigToPatterns(ModuleOp module, PDLPatternConfigSet &configSet); + + /// The module containing the `pdl.pattern` operations. + OwningOpRef pdlModule; + + /// The set of configuration sets referenced by patterns within `pdlModule`. + SmallVector> configs; + DenseMap configMap; + + /// The external functions referenced from within the PDL module. + llvm::StringMap constraintFunctions; + llvm::StringMap rewriteFunctions; +}; +} // namespace mlir + +#else + +namespace mlir { +// Stubs for when PDL in pattern rewrites is not enabled. + +class PDLValue { +public: + template + T dyn_cast() const { + return nullptr; + } +}; +class PDLResultList {}; +using PDLConstraintFunction = + std::function)>; +using PDLRewriteFunction = std::function)>; + +class PDLPatternModule { +public: + PDLPatternModule() = default; + + PDLPatternModule(OwningOpRef /*module*/) {} + MLIRContext *getContext() { + llvm_unreachable("Error: PDL for rewrites when PDL is not enabled"); + } + void mergeIn(PDLPatternModule &&other) {} + void clear() {} + template + void registerConstraintFunction(StringRef name, + ConstraintFnT &&constraintFn) {} + void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn) {} + template + void registerRewriteFunction(StringRef name, RewriteFnT &&rewriteFn) {} + const llvm::StringMap &getConstraintFunctions() const { + return constraintFunctions; + } + +private: + llvm::StringMap constraintFunctions; +}; + +} // namespace mlir +#endif + +#endif // MLIR_IR_PDLPATTERNMATCH_H diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index 6625ef553eba2..9b4fa65bff49e 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -735,932 +735,12 @@ class PatternRewriter : public RewriterBase { virtual bool canRecoverFromRewriteFailure() const { return false; } }; -//===----------------------------------------------------------------------===// -// PDL Patterns -//===----------------------------------------------------------------------===// - -//===----------------------------------------------------------------------===// -// PDLValue - -/// Storage type of byte-code interpreter values. These are passed to constraint -/// functions as arguments. -class PDLValue { -public: - /// The underlying kind of a PDL value. - enum class Kind { Attribute, Operation, Type, TypeRange, Value, ValueRange }; - - /// Construct a new PDL value. - PDLValue(const PDLValue &other) = default; - PDLValue(std::nullptr_t = nullptr) {} - PDLValue(Attribute value) - : value(value.getAsOpaquePointer()), kind(Kind::Attribute) {} - PDLValue(Operation *value) : value(value), kind(Kind::Operation) {} - PDLValue(Type value) : value(value.getAsOpaquePointer()), kind(Kind::Type) {} - PDLValue(TypeRange *value) : value(value), kind(Kind::TypeRange) {} - PDLValue(Value value) - : value(value.getAsOpaquePointer()), kind(Kind::Value) {} - PDLValue(ValueRange *value) : value(value), kind(Kind::ValueRange) {} - - /// Returns true if the type of the held value is `T`. - template - bool isa() const { - assert(value && "isa<> used on a null value"); - return kind == getKindOf(); - } - - /// Attempt to dynamically cast this value to type `T`, returns null if this - /// value is not an instance of `T`. - template ::value, T, std::optional>> - ResultT dyn_cast() const { - return isa() ? castImpl() : ResultT(); - } - - /// Cast this value to type `T`, asserts if this value is not an instance of - /// `T`. - template - T cast() const { - assert(isa() && "expected value to be of type `T`"); - return castImpl(); - } - - /// Get an opaque pointer to the value. - const void *getAsOpaquePointer() const { return value; } - - /// Return if this value is null or not. - explicit operator bool() const { return value; } - - /// Return the kind of this value. - Kind getKind() const { return kind; } - - /// Print this value to the provided output stream. - void print(raw_ostream &os) const; - - /// Print the specified value kind to an output stream. - static void print(raw_ostream &os, Kind kind); - -private: - /// Find the index of a given type in a range of other types. - template - struct index_of_t; - template - struct index_of_t : std::integral_constant {}; - template - struct index_of_t - : std::integral_constant::value> {}; - - /// Return the kind used for the given T. - template - static Kind getKindOf() { - return static_cast(index_of_t::value); - } - - /// The internal implementation of `cast`, that returns the underlying value - /// as the given type `T`. - template - std::enable_if_t::value, T> - castImpl() const { - return T::getFromOpaquePointer(value); - } - template - std::enable_if_t::value, T> - castImpl() const { - return *reinterpret_cast(const_cast(value)); - } - template - std::enable_if_t::value, T> castImpl() const { - return reinterpret_cast(const_cast(value)); - } - - /// The internal opaque representation of a PDLValue. - const void *value{nullptr}; - /// The kind of the opaque value. - Kind kind{Kind::Attribute}; -}; - -inline raw_ostream &operator<<(raw_ostream &os, PDLValue value) { - value.print(os); - return os; -} - -inline raw_ostream &operator<<(raw_ostream &os, PDLValue::Kind kind) { - PDLValue::print(os, kind); - return os; -} - -//===----------------------------------------------------------------------===// -// PDLResultList - -/// The class represents a list of PDL results, returned by a native rewrite -/// method. It provides the mechanism with which to pass PDLValues back to the -/// PDL bytecode. -class PDLResultList { -public: - /// Push a new Attribute value onto the result list. - void push_back(Attribute value) { results.push_back(value); } - - /// Push a new Operation onto the result list. - void push_back(Operation *value) { results.push_back(value); } - - /// Push a new Type onto the result list. - void push_back(Type value) { results.push_back(value); } - - /// Push a new TypeRange onto the result list. - void push_back(TypeRange value) { - // The lifetime of a TypeRange can't be guaranteed, so we'll need to - // allocate a storage for it. - llvm::OwningArrayRef storage(value.size()); - llvm::copy(value, storage.begin()); - allocatedTypeRanges.emplace_back(std::move(storage)); - typeRanges.push_back(allocatedTypeRanges.back()); - results.push_back(&typeRanges.back()); - } - void push_back(ValueTypeRange value) { - typeRanges.push_back(value); - results.push_back(&typeRanges.back()); - } - void push_back(ValueTypeRange value) { - typeRanges.push_back(value); - results.push_back(&typeRanges.back()); - } - - /// Push a new Value onto the result list. - void push_back(Value value) { results.push_back(value); } - - /// Push a new ValueRange onto the result list. - void push_back(ValueRange value) { - // The lifetime of a ValueRange can't be guaranteed, so we'll need to - // allocate a storage for it. - llvm::OwningArrayRef storage(value.size()); - llvm::copy(value, storage.begin()); - allocatedValueRanges.emplace_back(std::move(storage)); - valueRanges.push_back(allocatedValueRanges.back()); - results.push_back(&valueRanges.back()); - } - void push_back(OperandRange value) { - valueRanges.push_back(value); - results.push_back(&valueRanges.back()); - } - void push_back(ResultRange value) { - valueRanges.push_back(value); - results.push_back(&valueRanges.back()); - } - -protected: - /// Create a new result list with the expected number of results. - PDLResultList(unsigned maxNumResults) { - // For now just reserve enough space for all of the results. We could do - // separate counts per range type, but it isn't really worth it unless there - // are a "large" number of results. - typeRanges.reserve(maxNumResults); - valueRanges.reserve(maxNumResults); - } - - /// The PDL results held by this list. - SmallVector results; - /// Memory used to store ranges held by the list. - SmallVector typeRanges; - SmallVector valueRanges; - /// Memory allocated to store ranges in the result list whose lifetime was - /// generated in the native function. - SmallVector> allocatedTypeRanges; - SmallVector> allocatedValueRanges; -}; - -//===----------------------------------------------------------------------===// -// PDLPatternConfig - -/// An individual configuration for a pattern, which can be accessed by native -/// functions via the PDLPatternConfigSet. This allows for injecting additional -/// configuration into PDL patterns that is specific to certain compilation -/// flows. -class PDLPatternConfig { -public: - virtual ~PDLPatternConfig() = default; - - /// Hooks that are invoked at the beginning and end of a rewrite of a matched - /// pattern. These can be used to setup any specific state necessary for the - /// rewrite. - virtual void notifyRewriteBegin(PatternRewriter &rewriter) {} - virtual void notifyRewriteEnd(PatternRewriter &rewriter) {} - - /// Return the TypeID that represents this configuration. - TypeID getTypeID() const { return id; } - -protected: - PDLPatternConfig(TypeID id) : id(id) {} - -private: - TypeID id; -}; - -/// This class provides a base class for users implementing a type of pattern -/// configuration. -template -class PDLPatternConfigBase : public PDLPatternConfig { -public: - /// Support LLVM style casting. - static bool classof(const PDLPatternConfig *config) { - return config->getTypeID() == getConfigID(); - } - - /// Return the type id used for this configuration. - static TypeID getConfigID() { return TypeID::get(); } - -protected: - PDLPatternConfigBase() : PDLPatternConfig(getConfigID()) {} -}; - -/// This class contains a set of configurations for a specific pattern. -/// Configurations are uniqued by TypeID, meaning that only one configuration of -/// each type is allowed. -class PDLPatternConfigSet { -public: - PDLPatternConfigSet() = default; - - /// Construct a set with the given configurations. - template - PDLPatternConfigSet(ConfigsT &&...configs) { - (addConfig(std::forward(configs)), ...); - } - - /// Get the configuration defined by the given type. Asserts that the - /// configuration of the provided type exists. - template - const T &get() const { - const T *config = tryGet(); - assert(config && "configuration not found"); - return *config; - } - - /// Get the configuration defined by the given type, returns nullptr if the - /// configuration does not exist. - template - const T *tryGet() const { - for (const auto &configIt : configs) - if (const T *config = dyn_cast(configIt.get())) - return config; - return nullptr; - } - - /// Notify the configurations within this set at the beginning or end of a - /// rewrite of a matched pattern. - void notifyRewriteBegin(PatternRewriter &rewriter) { - for (const auto &config : configs) - config->notifyRewriteBegin(rewriter); - } - void notifyRewriteEnd(PatternRewriter &rewriter) { - for (const auto &config : configs) - config->notifyRewriteEnd(rewriter); - } - -protected: - /// Add a configuration to the set. - template - void addConfig(T &&config) { - assert(!tryGet>() && "configuration already exists"); - configs.emplace_back( - std::make_unique>(std::forward(config))); - } - - /// The set of configurations for this pattern. This uses a vector instead of - /// a map with the expectation that the number of configurations per set is - /// small (<= 1). - SmallVector> configs; -}; - -//===----------------------------------------------------------------------===// -// PDLPatternModule - -/// A generic PDL pattern constraint function. This function applies a -/// constraint to a given set of opaque PDLValue entities. Returns success if -/// the constraint successfully held, failure otherwise. -using PDLConstraintFunction = - std::function)>; -/// A native PDL rewrite function. This function performs a rewrite on the -/// given set of values. Any results from this rewrite that should be passed -/// back to PDL should be added to the provided result list. This method is only -/// invoked when the corresponding match was successful. Returns failure if an -/// invariant of the rewrite was broken (certain rewriters may recover from -/// partial pattern application). -using PDLRewriteFunction = std::function)>; - -namespace detail { -namespace pdl_function_builder { -/// A utility variable that always resolves to false. This is useful for static -/// asserts that are always false, but only should fire in certain templated -/// constructs. For example, if a templated function should never be called, the -/// function could be defined as: -/// -/// template -/// void foo() { -/// static_assert(always_false, "This function should never be called"); -/// } -/// -template -constexpr bool always_false = false; - -//===----------------------------------------------------------------------===// -// PDL Function Builder: Type Processing -//===----------------------------------------------------------------------===// - -/// This struct provides a convenient way to determine how to process a given -/// type as either a PDL parameter, or a result value. This allows for -/// supporting complex types in constraint and rewrite functions, without -/// requiring the user to hand-write the necessary glue code themselves. -/// Specializations of this class should implement the following methods to -/// enable support as a PDL argument or result type: -/// -/// static LogicalResult verifyAsArg( -/// function_ref errorFn, PDLValue pdlValue, -/// size_t argIdx); -/// -/// * This method verifies that the given PDLValue is valid for use as a -/// value of `T`. -/// -/// static T processAsArg(PDLValue pdlValue); -/// -/// * This method processes the given PDLValue as a value of `T`. -/// -/// static void processAsResult(PatternRewriter &, PDLResultList &results, -/// const T &value); -/// -/// * This method processes the given value of `T` as the result of a -/// function invocation. The method should package the value into an -/// appropriate form and append it to the given result list. -/// -/// If the type `T` is based on a higher order value, consider using -/// `ProcessPDLValueBasedOn` as a base class of the specialization to simplify -/// the implementation. -/// -template -struct ProcessPDLValue; - -/// This struct provides a simplified model for processing types that are based -/// on another type, e.g. APInt is based on the handling for IntegerAttr. This -/// allows for building the necessary processing functions on top of the base -/// value instead of a PDLValue. Derived users should implement the following -/// (which subsume the ProcessPDLValue variants): -/// -/// static LogicalResult verifyAsArg( -/// function_ref errorFn, -/// const BaseT &baseValue, size_t argIdx); -/// -/// * This method verifies that the given PDLValue is valid for use as a -/// value of `T`. -/// -/// static T processAsArg(BaseT baseValue); -/// -/// * This method processes the given base value as a value of `T`. -/// -template -struct ProcessPDLValueBasedOn { - static LogicalResult - verifyAsArg(function_ref errorFn, - PDLValue pdlValue, size_t argIdx) { - // Verify the base class before continuing. - if (failed(ProcessPDLValue::verifyAsArg(errorFn, pdlValue, argIdx))) - return failure(); - return ProcessPDLValue::verifyAsArg( - errorFn, ProcessPDLValue::processAsArg(pdlValue), argIdx); - } - static T processAsArg(PDLValue pdlValue) { - return ProcessPDLValue::processAsArg( - ProcessPDLValue::processAsArg(pdlValue)); - } - - /// Explicitly add the expected parent API to ensure the parent class - /// implements the necessary API (and doesn't implicitly inherit it from - /// somewhere else). - static LogicalResult - verifyAsArg(function_ref errorFn, BaseT value, - size_t argIdx) { - return success(); - } - static T processAsArg(BaseT baseValue); -}; - -/// This struct provides a simplified model for processing types that have -/// "builtin" PDLValue support: -/// * Attribute, Operation *, Type, TypeRange, ValueRange -template -struct ProcessBuiltinPDLValue { - static LogicalResult - verifyAsArg(function_ref errorFn, - PDLValue pdlValue, size_t argIdx) { - if (pdlValue) - return success(); - return errorFn("expected a non-null value for argument " + Twine(argIdx) + - " of type: " + llvm::getTypeName()); - } - - static T processAsArg(PDLValue pdlValue) { return pdlValue.cast(); } - static void processAsResult(PatternRewriter &, PDLResultList &results, - T value) { - results.push_back(value); - } -}; - -/// This struct provides a simplified model for processing types that inherit -/// from builtin PDLValue types. For example, derived attributes like -/// IntegerAttr, derived types like IntegerType, derived operations like -/// ModuleOp, Interfaces, etc. -template -struct ProcessDerivedPDLValue : public ProcessPDLValueBasedOn { - static LogicalResult - verifyAsArg(function_ref errorFn, - BaseT baseValue, size_t argIdx) { - return TypeSwitch(baseValue) - .Case([&](T) { return success(); }) - .Default([&](BaseT) { - return errorFn("expected argument " + Twine(argIdx) + - " to be of type: " + llvm::getTypeName()); - }); - } - using ProcessPDLValueBasedOn::verifyAsArg; - - static T processAsArg(BaseT baseValue) { - return baseValue.template cast(); - } - using ProcessPDLValueBasedOn::processAsArg; - - static void processAsResult(PatternRewriter &, PDLResultList &results, - T value) { - results.push_back(value); - } -}; - -//===----------------------------------------------------------------------===// -// Attribute - -template <> -struct ProcessPDLValue : public ProcessBuiltinPDLValue {}; -template -struct ProcessPDLValue::value>> - : public ProcessDerivedPDLValue {}; - -/// Handling for various Attribute value types. -template <> -struct ProcessPDLValue - : public ProcessPDLValueBasedOn { - static StringRef processAsArg(StringAttr value) { return value.getValue(); } - using ProcessPDLValueBasedOn::processAsArg; - - static void processAsResult(PatternRewriter &rewriter, PDLResultList &results, - StringRef value) { - results.push_back(rewriter.getStringAttr(value)); - } -}; -template <> -struct ProcessPDLValue - : public ProcessPDLValueBasedOn { - template - static std::string processAsArg(T value) { - static_assert(always_false, - "`std::string` arguments require a string copy, use " - "`StringRef` for string-like arguments instead"); - return {}; - } - static void processAsResult(PatternRewriter &rewriter, PDLResultList &results, - StringRef value) { - results.push_back(rewriter.getStringAttr(value)); - } -}; - -//===----------------------------------------------------------------------===// -// Operation - -template <> -struct ProcessPDLValue - : public ProcessBuiltinPDLValue {}; -template -struct ProcessPDLValue::value>> - : public ProcessDerivedPDLValue { - static T processAsArg(Operation *value) { return cast(value); } -}; - -//===----------------------------------------------------------------------===// -// Type - -template <> -struct ProcessPDLValue : public ProcessBuiltinPDLValue {}; -template -struct ProcessPDLValue::value>> - : public ProcessDerivedPDLValue {}; - -//===----------------------------------------------------------------------===// -// TypeRange - -template <> -struct ProcessPDLValue : public ProcessBuiltinPDLValue {}; -template <> -struct ProcessPDLValue> { - static void processAsResult(PatternRewriter &, PDLResultList &results, - ValueTypeRange types) { - results.push_back(types); - } -}; -template <> -struct ProcessPDLValue> { - static void processAsResult(PatternRewriter &, PDLResultList &results, - ValueTypeRange types) { - results.push_back(types); - } -}; -template -struct ProcessPDLValue> { - static void processAsResult(PatternRewriter &, PDLResultList &results, - SmallVector values) { - results.push_back(TypeRange(values)); - } -}; - -//===----------------------------------------------------------------------===// -// Value - -template <> -struct ProcessPDLValue : public ProcessBuiltinPDLValue {}; - -//===----------------------------------------------------------------------===// -// ValueRange - -template <> -struct ProcessPDLValue : public ProcessBuiltinPDLValue { -}; -template <> -struct ProcessPDLValue { - static void processAsResult(PatternRewriter &, PDLResultList &results, - OperandRange values) { - results.push_back(values); - } -}; -template <> -struct ProcessPDLValue { - static void processAsResult(PatternRewriter &, PDLResultList &results, - ResultRange values) { - results.push_back(values); - } -}; -template -struct ProcessPDLValue> { - static void processAsResult(PatternRewriter &, PDLResultList &results, - SmallVector values) { - results.push_back(ValueRange(values)); - } -}; - -//===----------------------------------------------------------------------===// -// PDL Function Builder: Argument Handling -//===----------------------------------------------------------------------===// - -/// Validate the given PDLValues match the constraints defined by the argument -/// types of the given function. In the case of failure, a match failure -/// diagnostic is emitted. -/// FIXME: This should be completely removed in favor of `assertArgs`, but PDL -/// does not currently preserve Constraint application ordering. -template -LogicalResult verifyAsArgs(PatternRewriter &rewriter, ArrayRef values, - std::index_sequence) { - using FnTraitsT = llvm::function_traits; - - auto errorFn = [&](const Twine &msg) { - return rewriter.notifyMatchFailure(rewriter.getUnknownLoc(), msg); - }; - return success( - (succeeded(ProcessPDLValue>:: - verifyAsArg(errorFn, values[I], I)) && - ...)); -} - -/// Assert that the given PDLValues match the constraints defined by the -/// arguments of the given function. In the case of failure, a fatal error -/// is emitted. -template -void assertArgs(PatternRewriter &rewriter, ArrayRef values, - std::index_sequence) { - // We only want to do verification in debug builds, same as with `assert`. -#if LLVM_ENABLE_ABI_BREAKING_CHECKS - using FnTraitsT = llvm::function_traits; - auto errorFn = [&](const Twine &msg) -> LogicalResult { - llvm::report_fatal_error(msg); - }; - (void)errorFn; - assert((succeeded(ProcessPDLValue>:: - verifyAsArg(errorFn, values[I], I)) && - ...)); -#endif - (void)values; -} - -//===----------------------------------------------------------------------===// -// PDL Function Builder: Results Handling -//===----------------------------------------------------------------------===// - -/// Store a single result within the result list. -template -static LogicalResult processResults(PatternRewriter &rewriter, - PDLResultList &results, T &&value) { - ProcessPDLValue::processAsResult(rewriter, results, - std::forward(value)); - return success(); -} - -/// Store a std::pair<> as individual results within the result list. -template -static LogicalResult processResults(PatternRewriter &rewriter, - PDLResultList &results, - std::pair &&pair) { - if (failed(processResults(rewriter, results, std::move(pair.first))) || - failed(processResults(rewriter, results, std::move(pair.second)))) - return failure(); - return success(); -} - -/// Store a std::tuple<> as individual results within the result list. -template -static LogicalResult processResults(PatternRewriter &rewriter, - PDLResultList &results, - std::tuple &&tuple) { - auto applyFn = [&](auto &&...args) { - return (succeeded(processResults(rewriter, results, std::move(args))) && - ...); - }; - return success(std::apply(applyFn, std::move(tuple))); -} - -/// Handle LogicalResult propagation. -inline LogicalResult processResults(PatternRewriter &rewriter, - PDLResultList &results, - LogicalResult &&result) { - return result; -} -template -static LogicalResult processResults(PatternRewriter &rewriter, - PDLResultList &results, - FailureOr &&result) { - if (failed(result)) - return failure(); - return processResults(rewriter, results, std::move(*result)); -} - -//===----------------------------------------------------------------------===// -// PDL Constraint Builder -//===----------------------------------------------------------------------===// - -/// Process the arguments of a native constraint and invoke it. -template > -typename FnTraitsT::result_t -processArgsAndInvokeConstraint(PDLFnT &fn, PatternRewriter &rewriter, - ArrayRef values, - std::index_sequence) { - return fn( - rewriter, - (ProcessPDLValue>::processAsArg( - values[I]))...); -} - -/// Build a constraint function from the given function `ConstraintFnT`. This -/// allows for enabling the user to define simpler, more direct constraint -/// functions without needing to handle the low-level PDL goop. -/// -/// If the constraint function is already in the correct form, we just forward -/// it directly. -template -std::enable_if_t< - std::is_convertible::value, - PDLConstraintFunction> -buildConstraintFn(ConstraintFnT &&constraintFn) { - return std::forward(constraintFn); -} -/// Otherwise, we generate a wrapper that will unpack the PDLValues in the form -/// we desire. -template -std::enable_if_t< - !std::is_convertible::value, - PDLConstraintFunction> -buildConstraintFn(ConstraintFnT &&constraintFn) { - return [constraintFn = std::forward(constraintFn)]( - PatternRewriter &rewriter, - ArrayRef values) -> LogicalResult { - auto argIndices = std::make_index_sequence< - llvm::function_traits::num_args - 1>(); - if (failed(verifyAsArgs(rewriter, values, argIndices))) - return failure(); - return processArgsAndInvokeConstraint(constraintFn, rewriter, values, - argIndices); - }; -} - -//===----------------------------------------------------------------------===// -// PDL Rewrite Builder -//===----------------------------------------------------------------------===// - -/// Process the arguments of a native rewrite and invoke it. -/// This overload handles the case of no return values. -template > -std::enable_if_t::value, - LogicalResult> -processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter, - PDLResultList &, ArrayRef values, - std::index_sequence) { - fn(rewriter, - (ProcessPDLValue>::processAsArg( - values[I]))...); - return success(); -} -/// This overload handles the case of return values, which need to be packaged -/// into the result list. -template > -std::enable_if_t::value, - LogicalResult> -processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter, - PDLResultList &results, ArrayRef values, - std::index_sequence) { - return processResults( - rewriter, results, - fn(rewriter, (ProcessPDLValue>:: - processAsArg(values[I]))...)); - (void)values; -} - -/// Build a rewrite function from the given function `RewriteFnT`. This -/// allows for enabling the user to define simpler, more direct rewrite -/// functions without needing to handle the low-level PDL goop. -/// -/// If the rewrite function is already in the correct form, we just forward -/// it directly. -template -std::enable_if_t::value, - PDLRewriteFunction> -buildRewriteFn(RewriteFnT &&rewriteFn) { - return std::forward(rewriteFn); -} -/// Otherwise, we generate a wrapper that will unpack the PDLValues in the form -/// we desire. -template -std::enable_if_t::value, - PDLRewriteFunction> -buildRewriteFn(RewriteFnT &&rewriteFn) { - return [rewriteFn = std::forward(rewriteFn)]( - PatternRewriter &rewriter, PDLResultList &results, - ArrayRef values) { - auto argIndices = - std::make_index_sequence::num_args - - 1>(); - assertArgs(rewriter, values, argIndices); - return processArgsAndInvokeRewrite(rewriteFn, rewriter, results, values, - argIndices); - }; -} - -} // namespace pdl_function_builder -} // namespace detail - -//===----------------------------------------------------------------------===// -// PDLPatternModule - -/// This class contains all of the necessary data for a set of PDL patterns, or -/// pattern rewrites specified in the form of the PDL dialect. This PDL module -/// contained by this pattern may contain any number of `pdl.pattern` -/// operations. -class PDLPatternModule { -public: - PDLPatternModule() = default; - - /// Construct a PDL pattern with the given module and configurations. - PDLPatternModule(OwningOpRef module) - : pdlModule(std::move(module)) {} - template - PDLPatternModule(OwningOpRef module, ConfigsT &&...patternConfigs) - : PDLPatternModule(std::move(module)) { - auto configSet = std::make_unique( - std::forward(patternConfigs)...); - attachConfigToPatterns(*pdlModule, *configSet); - configs.emplace_back(std::move(configSet)); - } - - /// Merge the state in `other` into this pattern module. - void mergeIn(PDLPatternModule &&other); - - /// Return the internal PDL module of this pattern. - ModuleOp getModule() { return pdlModule.get(); } - - //===--------------------------------------------------------------------===// - // Function Registry - - /// Register a constraint function with PDL. A constraint function may be - /// specified in one of two ways: - /// - /// * `LogicalResult (PatternRewriter &, ArrayRef)` - /// - /// In this overload the arguments of the constraint function are passed via - /// the low-level PDLValue form. - /// - /// * `LogicalResult (PatternRewriter &, ValueTs... values)` - /// - /// In this form the arguments of the constraint function are passed via the - /// expected high level C++ type. In this form, the framework will - /// automatically unwrap PDLValues and convert them to the expected ValueTs. - /// For example, if the constraint function accepts a `Operation *`, the - /// framework will automatically cast the input PDLValue. In the case of a - /// `StringRef`, the framework will automatically unwrap the argument as a - /// StringAttr and pass the underlying string value. To see the full list of - /// supported types, or to see how to add handling for custom types, view - /// the definition of `ProcessPDLValue` above. - void registerConstraintFunction(StringRef name, - PDLConstraintFunction constraintFn); - template - void registerConstraintFunction(StringRef name, - ConstraintFnT &&constraintFn) { - registerConstraintFunction(name, - detail::pdl_function_builder::buildConstraintFn( - std::forward(constraintFn))); - } - - /// Register a rewrite function with PDL. A rewrite function may be specified - /// in one of two ways: - /// - /// * `void (PatternRewriter &, PDLResultList &, ArrayRef)` - /// - /// In this overload the arguments of the constraint function are passed via - /// the low-level PDLValue form, and the results are manually appended to - /// the given result list. - /// - /// * `ResultT (PatternRewriter &, ValueTs... values)` - /// - /// In this form the arguments and result of the rewrite function are passed - /// via the expected high level C++ type. In this form, the framework will - /// automatically unwrap the PDLValues arguments and convert them to the - /// expected ValueTs. It will also automatically handle the processing and - /// packaging of the result value to the result list. For example, if the - /// rewrite function takes a `Operation *`, the framework will automatically - /// cast the input PDLValue. In the case of a `StringRef`, the framework - /// will automatically unwrap the argument as a StringAttr and pass the - /// underlying string value. In the reverse case, if the rewrite returns a - /// StringRef or std::string, it will automatically package this as a - /// StringAttr and append it to the result list. To see the full list of - /// supported types, or to see how to add handling for custom types, view - /// the definition of `ProcessPDLValue` above. - void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn); - template - void registerRewriteFunction(StringRef name, RewriteFnT &&rewriteFn) { - registerRewriteFunction(name, detail::pdl_function_builder::buildRewriteFn( - std::forward(rewriteFn))); - } - - /// Return the set of the registered constraint functions. - const llvm::StringMap &getConstraintFunctions() const { - return constraintFunctions; - } - llvm::StringMap takeConstraintFunctions() { - return constraintFunctions; - } - /// Return the set of the registered rewrite functions. - const llvm::StringMap &getRewriteFunctions() const { - return rewriteFunctions; - } - llvm::StringMap takeRewriteFunctions() { - return rewriteFunctions; - } - - /// Return the set of the registered pattern configs. - SmallVector> takeConfigs() { - return std::move(configs); - } - DenseMap takeConfigMap() { - return std::move(configMap); - } - - /// Clear out the patterns and functions within this module. - void clear() { - pdlModule = nullptr; - constraintFunctions.clear(); - rewriteFunctions.clear(); - } - -private: - /// Attach the given pattern config set to the patterns defined within the - /// given module. - void attachConfigToPatterns(ModuleOp module, PDLPatternConfigSet &configSet); - - /// The module containing the `pdl.pattern` operations. - OwningOpRef pdlModule; +} // namespace mlir - /// The set of configuration sets referenced by patterns within `pdlModule`. - SmallVector> configs; - DenseMap configMap; +// Optionally expose PDL pattern matching methods. +#include "PDLPatternMatch.h.inc" - /// The external functions referenced from within the PDL module. - llvm::StringMap constraintFunctions; - llvm::StringMap rewriteFunctions; -}; +namespace mlir { //===----------------------------------------------------------------------===// // RewritePatternSet @@ -1679,8 +759,7 @@ class RewritePatternSet { nativePatterns.emplace_back(std::move(pattern)); } RewritePatternSet(PDLPatternModule &&pattern) - : context(pattern.getModule()->getContext()), - pdlPatterns(std::move(pattern)) {} + : context(pattern.getContext()), pdlPatterns(std::move(pattern)) {} MLIRContext *getContext() const { return context; } @@ -1853,6 +932,7 @@ class RewritePatternSet { pattern->addDebugLabels(debugLabels); nativePatterns.emplace_back(std::move(pattern)); } + template std::enable_if_t::value> addImpl(ArrayRef debugLabels, Args &&...args) { @@ -1863,6 +943,9 @@ class RewritePatternSet { MLIRContext *const context; NativePatternListT nativePatterns; + + // Patterns expressed with PDL. This will compile to a stub class when PDL is + // not enabled. PDLPatternModule pdlPatterns; }; diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 6de981d35c8c3..c5725e9c85625 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -13,6 +13,7 @@ #ifndef MLIR_TRANSFORMS_DIALECTCONVERSION_H_ #define MLIR_TRANSFORMS_DIALECTCONVERSION_H_ +#include "mlir/Config/mlir-config.h" #include "mlir/Rewrite/FrozenRewritePatternSet.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/StringMap.h" @@ -1015,6 +1016,7 @@ class ConversionTarget { MLIRContext &ctx; }; +#if MLIR_ENABLE_PDL_IN_PATTERNMATCH //===----------------------------------------------------------------------===// // PDL Configuration //===----------------------------------------------------------------------===// @@ -1044,6 +1046,19 @@ class PDLConversionConfig final /// Register the dialect conversion PDL functions with the given pattern set. void registerConversionPDLFunctions(RewritePatternSet &patterns); +#else + +// Stubs for when PDL in rewriting is not enabled. + +inline void registerConversionPDLFunctions(RewritePatternSet &patterns) {} + +class PDLConversionConfig final { +public: + PDLConversionConfig(const TypeConverter * /*converter*/) {} +}; + +#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH + //===----------------------------------------------------------------------===// // Op Conversion Entry Points //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/TransformOps/CMakeLists.txt index 7f7b348b17ae6..be5eb73b91229 100644 --- a/mlir/lib/Dialect/Bufferization/TransformOps/CMakeLists.txt +++ b/mlir/lib/Dialect/Bufferization/TransformOps/CMakeLists.txt @@ -14,7 +14,6 @@ add_mlir_dialect_library(MLIRBufferizationTransformOps MLIRFunctionInterfaces MLIRLinalgDialect MLIRParser - MLIRPDLDialect MLIRSideEffectInterfaces MLIRTransformDialect ) diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt index 21bba11b85117..a155b7c5ecade 100644 --- a/mlir/lib/IR/CMakeLists.txt +++ b/mlir/lib/IR/CMakeLists.txt @@ -61,3 +61,10 @@ add_mlir_library(MLIRIR LINK_LIBS PUBLIC MLIRSupport ) + +if(MLIR_ENABLE_PDL_IN_PATTERNMATCH) + add_subdirectory(PDL) + target_link_libraries(MLIRIR PUBLIC + MLIRIRPDLPatternMatch) +endif() + diff --git a/mlir/lib/IR/PDL/CMakeLists.txt b/mlir/lib/IR/PDL/CMakeLists.txt new file mode 100644 index 0000000000000..08b7fe36fac09 --- /dev/null +++ b/mlir/lib/IR/PDL/CMakeLists.txt @@ -0,0 +1,7 @@ +add_mlir_library(MLIRIRPDLPatternMatch + PDLPatternMatch.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/IR +) + diff --git a/mlir/lib/IR/PDL/PDLPatternMatch.cpp b/mlir/lib/IR/PDL/PDLPatternMatch.cpp new file mode 100644 index 0000000000000..da07cc462a5a1 --- /dev/null +++ b/mlir/lib/IR/PDL/PDLPatternMatch.cpp @@ -0,0 +1,133 @@ +//===- PDLPatternMatch.cpp - Base classes for PDL pattern match +//------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Iterators.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/RegionKindInterface.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// PDLValue +//===----------------------------------------------------------------------===// + +void PDLValue::print(raw_ostream &os) const { + if (!value) { + os << ""; + return; + } + switch (kind) { + case Kind::Attribute: + os << cast(); + break; + case Kind::Operation: + os << *cast(); + break; + case Kind::Type: + os << cast(); + break; + case Kind::TypeRange: + llvm::interleaveComma(cast(), os); + break; + case Kind::Value: + os << cast(); + break; + case Kind::ValueRange: + llvm::interleaveComma(cast(), os); + break; + } +} + +void PDLValue::print(raw_ostream &os, Kind kind) { + switch (kind) { + case Kind::Attribute: + os << "Attribute"; + break; + case Kind::Operation: + os << "Operation"; + break; + case Kind::Type: + os << "Type"; + break; + case Kind::TypeRange: + os << "TypeRange"; + break; + case Kind::Value: + os << "Value"; + break; + case Kind::ValueRange: + os << "ValueRange"; + break; + } +} + +//===----------------------------------------------------------------------===// +// PDLPatternModule +//===----------------------------------------------------------------------===// + +void PDLPatternModule::mergeIn(PDLPatternModule &&other) { + // Ignore the other module if it has no patterns. + if (!other.pdlModule) + return; + + // Steal the functions and config of the other module. + for (auto &it : other.constraintFunctions) + registerConstraintFunction(it.first(), std::move(it.second)); + for (auto &it : other.rewriteFunctions) + registerRewriteFunction(it.first(), std::move(it.second)); + for (auto &it : other.configs) + configs.emplace_back(std::move(it)); + for (auto &it : other.configMap) + configMap.insert(it); + + // Steal the other state if we have no patterns. + if (!pdlModule) { + pdlModule = std::move(other.pdlModule); + return; + } + + // Merge the pattern operations from the other module into this one. + Block *block = pdlModule->getBody(); + block->getOperations().splice(block->end(), + other.pdlModule->getBody()->getOperations()); +} + +void PDLPatternModule::attachConfigToPatterns(ModuleOp module, + PDLPatternConfigSet &configSet) { + // Attach the configuration to the symbols within the module. We only add + // to symbols to avoid hardcoding any specific operation names here (given + // that we don't depend on any PDL dialect). We can't use + // cast here because patterns may be optional symbols. + module->walk([&](Operation *op) { + if (op->hasTrait()) + configMap[op] = &configSet; + }); +} + +//===----------------------------------------------------------------------===// +// Function Registry + +void PDLPatternModule::registerConstraintFunction( + StringRef name, PDLConstraintFunction constraintFn) { + // TODO: Is it possible to diagnose when `name` is already registered to + // a function that is not equivalent to `constraintFn`? + // Allow existing mappings in the case multiple patterns depend on the same + // constraint. + constraintFunctions.try_emplace(name, std::move(constraintFn)); +} + +void PDLPatternModule::registerRewriteFunction(StringRef name, + PDLRewriteFunction rewriteFn) { + // TODO: Is it possible to diagnose when `name` is already registered to + // a function that is not equivalent to `rewriteFn`? + // Allow existing mappings in the case multiple patterns depend on the same + // rewrite. + rewriteFunctions.try_emplace(name, std::move(rewriteFn)); +} diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp index 5e9b9b2a810a4..5e788cdb4897d 100644 --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/IR/PatternMatch.h" +#include "mlir/Config/mlir-config.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/Iterators.h" #include "mlir/IR/RegionKindInterface.h" @@ -97,124 +98,6 @@ LogicalResult RewritePattern::match(Operation *op) const { /// Out-of-line vtable anchor. void RewritePattern::anchor() {} -//===----------------------------------------------------------------------===// -// PDLValue -//===----------------------------------------------------------------------===// - -void PDLValue::print(raw_ostream &os) const { - if (!value) { - os << ""; - return; - } - switch (kind) { - case Kind::Attribute: - os << cast(); - break; - case Kind::Operation: - os << *cast(); - break; - case Kind::Type: - os << cast(); - break; - case Kind::TypeRange: - llvm::interleaveComma(cast(), os); - break; - case Kind::Value: - os << cast(); - break; - case Kind::ValueRange: - llvm::interleaveComma(cast(), os); - break; - } -} - -void PDLValue::print(raw_ostream &os, Kind kind) { - switch (kind) { - case Kind::Attribute: - os << "Attribute"; - break; - case Kind::Operation: - os << "Operation"; - break; - case Kind::Type: - os << "Type"; - break; - case Kind::TypeRange: - os << "TypeRange"; - break; - case Kind::Value: - os << "Value"; - break; - case Kind::ValueRange: - os << "ValueRange"; - break; - } -} - -//===----------------------------------------------------------------------===// -// PDLPatternModule -//===----------------------------------------------------------------------===// - -void PDLPatternModule::mergeIn(PDLPatternModule &&other) { - // Ignore the other module if it has no patterns. - if (!other.pdlModule) - return; - - // Steal the functions and config of the other module. - for (auto &it : other.constraintFunctions) - registerConstraintFunction(it.first(), std::move(it.second)); - for (auto &it : other.rewriteFunctions) - registerRewriteFunction(it.first(), std::move(it.second)); - for (auto &it : other.configs) - configs.emplace_back(std::move(it)); - for (auto &it : other.configMap) - configMap.insert(it); - - // Steal the other state if we have no patterns. - if (!pdlModule) { - pdlModule = std::move(other.pdlModule); - return; - } - - // Merge the pattern operations from the other module into this one. - Block *block = pdlModule->getBody(); - block->getOperations().splice(block->end(), - other.pdlModule->getBody()->getOperations()); -} - -void PDLPatternModule::attachConfigToPatterns(ModuleOp module, - PDLPatternConfigSet &configSet) { - // Attach the configuration to the symbols within the module. We only add - // to symbols to avoid hardcoding any specific operation names here (given - // that we don't depend on any PDL dialect). We can't use - // cast here because patterns may be optional symbols. - module->walk([&](Operation *op) { - if (op->hasTrait()) - configMap[op] = &configSet; - }); -} - -//===----------------------------------------------------------------------===// -// Function Registry - -void PDLPatternModule::registerConstraintFunction( - StringRef name, PDLConstraintFunction constraintFn) { - // TODO: Is it possible to diagnose when `name` is already registered to - // a function that is not equivalent to `constraintFn`? - // Allow existing mappings in the case multiple patterns depend on the same - // constraint. - constraintFunctions.try_emplace(name, std::move(constraintFn)); -} - -void PDLPatternModule::registerRewriteFunction(StringRef name, - PDLRewriteFunction rewriteFn) { - // TODO: Is it possible to diagnose when `name` is already registered to - // a function that is not equivalent to `rewriteFn`? - // Allow existing mappings in the case multiple patterns depend on the same - // rewrite. - rewriteFunctions.try_emplace(name, std::move(rewriteFn)); -} - //===----------------------------------------------------------------------===// // RewriterBase //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Rewrite/ByteCode.h b/mlir/lib/Rewrite/ByteCode.h index 4d43fe636bd1f..4aceac7ed3a4c 100644 --- a/mlir/lib/Rewrite/ByteCode.h +++ b/mlir/lib/Rewrite/ByteCode.h @@ -16,6 +16,8 @@ #include "mlir/IR/PatternMatch.h" +#if MLIR_ENABLE_PDL_IN_PATTERNMATCH + namespace mlir { namespace pdl_interp { class RecordMatchOp; @@ -224,4 +226,38 @@ class PDLByteCode { } // namespace detail } // namespace mlir +#else + +namespace mlir::detail { + +class PDLByteCodeMutableState { +public: + void cleanupAfterMatchAndRewrite() {} + void updatePatternBenefit(unsigned patternIndex, PatternBenefit benefit) {} +}; + +class PDLByteCodePattern : public Pattern {}; + +class PDLByteCode { +public: + struct MatchResult { + const PDLByteCodePattern *pattern = nullptr; + PatternBenefit benefit; + }; + + void initializeMutableState(PDLByteCodeMutableState &state) const {} + void match(Operation *op, PatternRewriter &rewriter, + SmallVectorImpl &matches, + PDLByteCodeMutableState &state) const {} + LogicalResult rewrite(PatternRewriter &rewriter, const MatchResult &match, + PDLByteCodeMutableState &state) const { + return failure(); + } + ArrayRef getPatterns() const { return {}; } +}; + +} // namespace mlir::detail + +#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH + #endif // MLIR_REWRITE_BYTECODE_H_ diff --git a/mlir/lib/Rewrite/CMakeLists.txt b/mlir/lib/Rewrite/CMakeLists.txt index e0395be6cd6f5..a6c39406aa4b3 100644 --- a/mlir/lib/Rewrite/CMakeLists.txt +++ b/mlir/lib/Rewrite/CMakeLists.txt @@ -1,5 +1,6 @@ +set(LLVM_OPTIONAL_SOURCES ByteCode.cpp) + add_mlir_library(MLIRRewrite - ByteCode.cpp FrozenRewritePatternSet.cpp PatternApplicator.cpp @@ -11,8 +12,31 @@ add_mlir_library(MLIRRewrite LINK_LIBS PUBLIC MLIRIR - MLIRPDLDialect - MLIRPDLInterpDialect - MLIRPDLToPDLInterp MLIRSideEffectInterfaces ) + +if(MLIR_ENABLE_PDL_IN_PATTERNMATCH) + add_mlir_library(MLIRRewritePDL + ByteCode.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Rewrite + + DEPENDS + mlir-generic-headers + + LINK_LIBS PUBLIC + MLIRIR + MLIRPDLDialect + MLIRPDLInterpDialect + MLIRPDLToPDLInterp + MLIRSideEffectInterfaces + ) + + target_link_libraries(MLIRRewrite PUBLIC + MLIRPDLDialect + MLIRPDLInterpDialect + MLIRPDLToPDLInterp + MLIRRewritePDL) +endif() + diff --git a/mlir/lib/Rewrite/FrozenRewritePatternSet.cpp b/mlir/lib/Rewrite/FrozenRewritePatternSet.cpp index 43840d1e8cec2..17fe02df9f66c 100644 --- a/mlir/lib/Rewrite/FrozenRewritePatternSet.cpp +++ b/mlir/lib/Rewrite/FrozenRewritePatternSet.cpp @@ -8,8 +8,6 @@ #include "mlir/Rewrite/FrozenRewritePatternSet.h" #include "ByteCode.h" -#include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h" -#include "mlir/Dialect/PDL/IR/PDLOps.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" @@ -17,6 +15,11 @@ using namespace mlir; +// Include the PDL rewrite support. +#if MLIR_ENABLE_PDL_IN_PATTERNMATCH +#include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h" +#include "mlir/Dialect/PDL/IR/PDLOps.h" + static LogicalResult convertPDLToPDLInterp(ModuleOp pdlModule, DenseMap &configMap) { @@ -48,6 +51,7 @@ convertPDLToPDLInterp(ModuleOp pdlModule, pdlModule.getBody()->walk(simplifyFn); return success(); } +#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH //===----------------------------------------------------------------------===// // FrozenRewritePatternSet @@ -121,6 +125,7 @@ FrozenRewritePatternSet::FrozenRewritePatternSet( impl->nativeAnyOpPatterns.push_back(std::move(pat)); } +#if MLIR_ENABLE_PDL_IN_PATTERNMATCH // Generate the bytecode for the PDL patterns if any were provided. PDLPatternModule &pdlPatterns = patterns.getPDLPatterns(); ModuleOp pdlModule = pdlPatterns.getModule(); @@ -137,6 +142,7 @@ FrozenRewritePatternSet::FrozenRewritePatternSet( pdlModule, pdlPatterns.takeConfigs(), configMap, pdlPatterns.takeConstraintFunctions(), pdlPatterns.takeRewriteFunctions()); +#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH } FrozenRewritePatternSet::~FrozenRewritePatternSet() = default; diff --git a/mlir/lib/Rewrite/PatternApplicator.cpp b/mlir/lib/Rewrite/PatternApplicator.cpp index 08d6ee618ac69..0064eb84aba84 100644 --- a/mlir/lib/Rewrite/PatternApplicator.cpp +++ b/mlir/lib/Rewrite/PatternApplicator.cpp @@ -152,7 +152,6 @@ LogicalResult PatternApplicator::matchAndRewrite( // Find the next pattern with the highest benefit. const Pattern *bestPattern = nullptr; unsigned *bestPatternIt = &opIt; - const PDLByteCode::MatchResult *pdlMatch = nullptr; /// Operation specific patterns. if (opIt < opE) @@ -164,6 +163,8 @@ LogicalResult PatternApplicator::matchAndRewrite( bestPatternIt = &anyIt; bestPattern = anyOpPatterns[anyIt]; } + + const PDLByteCode::MatchResult *pdlMatch = nullptr; /// PDL patterns. if (pdlIt < pdlE && (!bestPattern || bestPattern->getBenefit() < pdlMatches[pdlIt].benefit)) { @@ -171,6 +172,7 @@ LogicalResult PatternApplicator::matchAndRewrite( pdlMatch = &pdlMatches[pdlIt]; bestPattern = pdlMatch->pattern; } + if (!bestPattern) break; diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 4d2afe462b928..85433d088dcbf 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Transforms/DialectConversion.h" +#include "mlir/Config/mlir-config.h" #include "mlir/IR/Block.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" @@ -3312,6 +3313,7 @@ auto ConversionTarget::getOpInfo(OperationName op) const return std::nullopt; } +#if MLIR_ENABLE_PDL_IN_PATTERNMATCH //===----------------------------------------------------------------------===// // PDL Configuration //===----------------------------------------------------------------------===// @@ -3382,6 +3384,7 @@ void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) { return std::move(remappedTypes); }); } +#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH //===----------------------------------------------------------------------===// // Op Conversion Entry Points diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt index 3f312164cb1f3..7ec4c8f0963a2 100644 --- a/mlir/test/CMakeLists.txt +++ b/mlir/test/CMakeLists.txt @@ -97,16 +97,13 @@ set(MLIR_TEST_DEPENDS mlir-capi-ir-test mlir-capi-llvm-test mlir-capi-pass-test - mlir-capi-pdl-test mlir-capi-quant-test mlir-capi-sparse-tensor-test mlir-capi-transform-test mlir-capi-translation-test mlir-linalg-ods-yaml-gen mlir-lsp-server - mlir-pdll-lsp-server mlir-opt - mlir-pdll mlir-query mlir-reduce mlir-tblgen @@ -115,6 +112,12 @@ set(MLIR_TEST_DEPENDS tblgen-to-irdl ) +set(MLIR_TEST_DEPENDS ${MLIR_TEST_DEPENDS} + mlir-capi-pdl-test + mlir-pdll-lsp-server + mlir-pdll + ) + # The native target may not be enabled, in this case we won't # run tests that involves executing on the host: do not build # useless binaries. @@ -159,9 +162,10 @@ if(LLVM_BUILD_EXAMPLES) toyc-ch3 toyc-ch4 toyc-ch5 + ) + list(APPEND MLIR_TEST_DEPENDS transform-opt-ch2 transform-opt-ch3 - mlir-minimal-opt ) if(MLIR_ENABLE_EXECUTION_ENGINE) list(APPEND MLIR_TEST_DEPENDS diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt index e032ce7200fbf..2a3a8608db544 100644 --- a/mlir/test/lib/Transforms/CMakeLists.txt +++ b/mlir/test/lib/Transforms/CMakeLists.txt @@ -1,3 +1,8 @@ +set(LLVM_OPTIONAL_SOURCES + TestDialectConversion.cpp) +set(MLIRTestTransformsPDLDep) +set(MLIRTestTransformsPDLSrc) +if(MLIR_ENABLE_PDL_IN_PATTERNMATCH) add_mlir_pdll_library(MLIRTestDialectConversionPDLLPatternsIncGen TestDialectConversion.pdll TestDialectConversionPDLLPatterns.h.inc @@ -6,17 +11,22 @@ add_mlir_pdll_library(MLIRTestDialectConversionPDLLPatternsIncGen ${CMAKE_CURRENT_SOURCE_DIR}/../Dialect/Test ${CMAKE_CURRENT_BINARY_DIR}/../Dialect/Test ) + set(MLIRTestTransformsPDLSrc + TestDialectConversion.cpp) + set(MLIRTestTransformsPDLDep + MLIRTestDialectConversionPDLLPatternsIncGen) +endif() # Exclude tests from libMLIR.so add_mlir_library(MLIRTestTransforms TestCommutativityUtils.cpp TestConstantFold.cpp TestControlFlowSink.cpp - TestDialectConversion.cpp TestInlining.cpp TestIntRangeInference.cpp TestMakeIsolatedFromAbove.cpp TestTopologicalSort.cpp + ${MLIRTestTransformsPDLSrc} EXCLUDE_FROM_LIBMLIR @@ -24,7 +34,7 @@ add_mlir_library(MLIRTestTransforms ${MLIR_MAIN_INCLUDE_DIR}/mlir/Transforms DEPENDS - MLIRTestDialectConversionPDLLPatternsIncGen + ${MLIRTestTransformsPDLDep} LINK_LIBS PUBLIC MLIRAnalysis diff --git a/mlir/tools/mlir-lsp-server/CMakeLists.txt b/mlir/tools/mlir-lsp-server/CMakeLists.txt index e90ccf17af17f..9664f6b94844e 100644 --- a/mlir/tools/mlir-lsp-server/CMakeLists.txt +++ b/mlir/tools/mlir-lsp-server/CMakeLists.txt @@ -21,10 +21,12 @@ if(MLIR_INCLUDE_TESTS) MLIRTestIR MLIRTestPass MLIRTestReducer + ) + set(test_libs + ${test_libs} MLIRTestRewrite MLIRTestTransformDialect - MLIRTestTransforms - ) + MLIRTestTransforms) endif() set(LIBS diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt index b6ada66d32188..15317a119c154 100644 --- a/mlir/tools/mlir-opt/CMakeLists.txt +++ b/mlir/tools/mlir-opt/CMakeLists.txt @@ -38,16 +38,18 @@ if(MLIR_INCLUDE_TESTS) MLIRTestIR MLIRTestOneToNTypeConversionPass MLIRTestPass - MLIRTestPDLL MLIRTestReducer - MLIRTestRewrite - MLIRTestTransformDialect MLIRTestTransforms MLIRTilingInterfaceTestPasses MLIRVectorTestPasses MLIRTestVectorToSPIRV MLIRLLVMTestPasses ) + set(test_libs ${test_libs} + MLIRTestPDLL + MLIRTestRewrite + MLIRTestTransformDialect + ) endif() set(LIBS diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index f7a5b3183b50b..bf8f3b7aa21d1 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -85,7 +85,9 @@ void registerTestDataLayoutQuery(); void registerTestDeadCodeAnalysisPass(); void registerTestDecomposeCallGraphTypes(); void registerTestDiagnosticsPass(); +#if MLIR_ENABLE_PDL_IN_PATTERNMATCH void registerTestDialectConversionPasses(); +#endif void registerTestDominancePass(); void registerTestDynamicPipelinePass(); void registerTestEmulateNarrowTypePass(); @@ -147,8 +149,8 @@ void registerTestNvgpuLowerings(); namespace test { void registerTestDialect(DialectRegistry &); -void registerTestTransformDialectExtension(DialectRegistry &); void registerTestDynDialect(DialectRegistry &); +void registerTestTransformDialectExtension(DialectRegistry &); } // namespace test #ifdef MLIR_INCLUDE_TESTS @@ -260,6 +262,9 @@ void registerTestPasses() { mlir::test::registerTestVectorReductionToSPIRVDotProd(); mlir::test::registerTestNvgpuLowerings(); mlir::test::registerTestWrittenToPass(); +#if MLIR_ENABLE_PDL_IN_PATTERNMATCH + mlir::test::registerTestDialectConversionPasses(); +#endif } #endif diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index 2a56b2d6f0373..2a72bf965e544 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -35,6 +35,7 @@ expand_template( substitutions = { "#cmakedefine01 MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS": "#define MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS 0", "#cmakedefine MLIR_GREEDY_REWRITE_RANDOMIZER_SEED ${MLIR_GREEDY_REWRITE_RANDOMIZER_SEED}": "/* #undef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED */", + "#cmakedefine01 MLIR_ENABLE_PDL_IN_PATTERNMATCH": "#define MLIR_ENABLE_PDL_IN_PATTERNMATCH 1", }, template = "include/mlir/Config/mlir-config.h.cmake", ) @@ -318,11 +319,13 @@ cc_library( srcs = glob([ "lib/IR/*.cpp", "lib/IR/*.h", + "lib/IR/PDL/*.cpp", "lib/Bytecode/Reader/*.h", "lib/Bytecode/Writer/*.h", "lib/Bytecode/*.h", ]) + [ "lib/Bytecode/BytecodeOpInterface.cpp", + "include/mlir/IR/PDLPatternMatch.h.inc", ], hdrs = glob([ "include/mlir/IR/*.h", @@ -345,6 +348,7 @@ cc_library( ":BuiltinTypesIncGen", ":BytecodeOpInterfaceIncGen", ":CallOpInterfacesIncGen", + ":config", ":DataLayoutInterfacesIncGen", ":InferTypeOpInterfaceIncGen", ":OpAsmInterfaceIncGen",