Skip to content

Commit

Permalink
[mlir:PDL] Add support for DialectConversion with pattern configurations
Browse files Browse the repository at this point in the history
Up until now PDL(L) has not supported dialect conversion because we had no
way of remapping values or integrating with type conversions. This commit
rectifies that by adding a new "pattern configuration" concept to PDL. This
essentially allows for attaching external configurations to patterns, which
can hook into pattern events (for now just the scope of a rewrite, but we
could also pass configs to native rewrites as well). This allows for injecting
the type converter into the conversion pattern rewriter.

Differential Revision: https://reviews.llvm.org/D133142
  • Loading branch information
River707 committed Nov 8, 2022
1 parent f3a86a2 commit 8c66344
Show file tree
Hide file tree
Showing 19 changed files with 669 additions and 95 deletions.
10 changes: 9 additions & 1 deletion mlir/include/mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h
Expand Up @@ -13,19 +13,27 @@
#ifndef MLIR_CONVERSION_PDLTOPDLINTERP_PDLTOPDLINTERP_H
#define MLIR_CONVERSION_PDLTOPDLINTERP_PDLTOPDLINTERP_H

#include <memory>
#include "mlir/Support/LLVM.h"

namespace mlir {
class ModuleOp;
class Operation;
template <typename OpT>
class OperationPass;
class PDLPatternConfigSet;

#define GEN_PASS_DECL_CONVERTPDLTOPDLINTERP
#include "mlir/Conversion/Passes.h.inc"

/// Creates and returns a pass to convert PDL ops to PDL interpreter ops.
std::unique_ptr<OperationPass<ModuleOp>> createPDLToPDLInterpPass();

/// Creates and returns a pass to convert PDL ops to PDL interpreter ops.
/// `configMap` holds a map of the configurations for each pattern being
/// compiled.
std::unique_ptr<OperationPass<ModuleOp>> createPDLToPDLInterpPass(
DenseMap<Operation *, PDLPatternConfigSet *> &configMap);

} // namespace mlir

#endif // MLIR_CONVERSION_PDLTOPDLINTERP_PDLTOPDLINTERP_H
219 changes: 197 additions & 22 deletions mlir/include/mlir/IR/PatternMatch.h
Expand Up @@ -600,10 +600,16 @@ class IRRewriter : public RewriterBase {
class PatternRewriter : public RewriterBase {
public:
using RewriterBase::RewriterBase;

/// A hook used to indicate if the pattern rewriter can recover from failure
/// during the rewrite stage of a pattern. For example, if the pattern
/// rewriter supports rollback, it may progress smoothly even if IR was
/// changed during the rewrite.
virtual bool canRecoverFromRewriteFailure() const { return false; }
};

//===----------------------------------------------------------------------===//
// PDLPatternModule
// PDL Patterns
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -796,6 +802,108 @@ class PDLResultList {
SmallVector<llvm::OwningArrayRef<Value>> allocatedValueRanges;
};

//===----------------------------------------------------------------------===//
// PDLPatternConfig

/// An individual configuration for a pattern, which can be accessed by native
/// functions via the PDLPatternConfigSet. This allows for injecting additional
/// configuration into PDL patterns that is specific to certain compilation
/// flows.
class PDLPatternConfig {
public:
virtual ~PDLPatternConfig() = default;

/// Hooks that are invoked at the beginning and end of a rewrite of a matched
/// pattern. These can be used to setup any specific state necessary for the
/// rewrite.
virtual void notifyRewriteBegin(PatternRewriter &rewriter) {}
virtual void notifyRewriteEnd(PatternRewriter &rewriter) {}

/// Return the TypeID that represents this configuration.
TypeID getTypeID() const { return id; }

protected:
PDLPatternConfig(TypeID id) : id(id) {}

private:
TypeID id;
};

/// This class provides a base class for users implementing a type of pattern
/// configuration.
template <typename T>
class PDLPatternConfigBase : public PDLPatternConfig {
public:
/// Support LLVM style casting.
static bool classof(const PDLPatternConfig *config) {
return config->getTypeID() == getConfigID();
}

/// Return the type id used for this configuration.
static TypeID getConfigID() { return TypeID::get<T>(); }

protected:
PDLPatternConfigBase() : PDLPatternConfig(getConfigID()) {}
};

/// This class contains a set of configurations for a specific pattern.
/// Configurations are uniqued by TypeID, meaning that only one configuration of
/// each type is allowed.
class PDLPatternConfigSet {
public:
PDLPatternConfigSet() = default;

/// Construct a set with the given configurations.
template <typename... ConfigsT>
PDLPatternConfigSet(ConfigsT &&...configs) {
(addConfig(std::forward<ConfigsT>(configs)), ...);
}

/// Get the configuration defined by the given type. Asserts that the
/// configuration of the provided type exists.
template <typename T>
const T &get() const {
const T *config = tryGet<T>();
assert(config && "configuration not found");
return *config;
}

/// Get the configuration defined by the given type, returns nullptr if the
/// configuration does not exist.
template <typename T>
const T *tryGet() const {
for (const auto &configIt : configs)
if (const T *config = dyn_cast<T>(configIt.get()))
return config;
return nullptr;
}

/// Notify the configurations within this set at the beginning or end of a
/// rewrite of a matched pattern.
void notifyRewriteBegin(PatternRewriter &rewriter) {
for (const auto &config : configs)
config->notifyRewriteBegin(rewriter);
}
void notifyRewriteEnd(PatternRewriter &rewriter) {
for (const auto &config : configs)
config->notifyRewriteEnd(rewriter);
}

protected:
/// Add a configuration to the set.
template <typename T>
void addConfig(T &&config) {
assert(!tryGet<std::decay_t<T>>() && "configuration already exists");
configs.emplace_back(
std::make_unique<std::decay_t<T>>(std::forward<T>(config)));
}

/// The set of configurations for this pattern. This uses a vector instead of
/// a map with the expectation that the number of configurations per set is
/// small (<= 1).
SmallVector<std::unique_ptr<PDLPatternConfig>> configs;
};

//===----------------------------------------------------------------------===//
// PDLPatternModule

Expand All @@ -807,9 +915,11 @@ using PDLConstraintFunction =
/// A native PDL rewrite function. This function performs a rewrite on the
/// given set of values. Any results from this rewrite that should be passed
/// back to PDL should be added to the provided result list. This method is only
/// invoked when the corresponding match was successful.
using PDLRewriteFunction =
std::function<void(PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
/// invoked when the corresponding match was successful. Returns failure if an
/// invariant of the rewrite was broken (certain rewriters may recover from
/// partial pattern application).
using PDLRewriteFunction = std::function<LogicalResult(
PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;

namespace detail {
namespace pdl_function_builder {
Expand Down Expand Up @@ -1034,6 +1144,13 @@ struct ProcessPDLValue<ValueTypeRange<ResultRange>> {
results.push_back(types);
}
};
template <unsigned N>
struct ProcessPDLValue<SmallVector<Type, N>> {
static void processAsResult(PatternRewriter &, PDLResultList &results,
SmallVector<Type, N> values) {
results.push_back(TypeRange(values));
}
};

//===----------------------------------------------------------------------===//
// Value
Expand Down Expand Up @@ -1061,6 +1178,13 @@ struct ProcessPDLValue<ResultRange> {
results.push_back(values);
}
};
template <unsigned N>
struct ProcessPDLValue<SmallVector<Value, N>> {
static void processAsResult(PatternRewriter &, PDLResultList &results,
SmallVector<Value, N> values) {
results.push_back(ValueRange(values));
}
};

//===----------------------------------------------------------------------===//
// PDL Function Builder: Argument Handling
Expand Down Expand Up @@ -1111,28 +1235,49 @@ void assertArgs(PatternRewriter &rewriter, ArrayRef<PDLValue> values,

/// Store a single result within the result list.
template <typename T>
static void processResults(PatternRewriter &rewriter, PDLResultList &results,
T &&value) {
static LogicalResult processResults(PatternRewriter &rewriter,
PDLResultList &results, T &&value) {
ProcessPDLValue<T>::processAsResult(rewriter, results,
std::forward<T>(value));
return success();
}

/// Store a std::pair<> as individual results within the result list.
template <typename T1, typename T2>
static void processResults(PatternRewriter &rewriter, PDLResultList &results,
std::pair<T1, T2> &&pair) {
processResults(rewriter, results, std::move(pair.first));
processResults(rewriter, results, std::move(pair.second));
static LogicalResult processResults(PatternRewriter &rewriter,
PDLResultList &results,
std::pair<T1, T2> &&pair) {
if (failed(processResults(rewriter, results, std::move(pair.first))) ||
failed(processResults(rewriter, results, std::move(pair.second))))
return failure();
return success();
}

/// Store a std::tuple<> as individual results within the result list.
template <typename... Ts>
static void processResults(PatternRewriter &rewriter, PDLResultList &results,
std::tuple<Ts...> &&tuple) {
static LogicalResult processResults(PatternRewriter &rewriter,
PDLResultList &results,
std::tuple<Ts...> &&tuple) {
auto applyFn = [&](auto &&...args) {
(processResults(rewriter, results, std::move(args)), ...);
return (succeeded(processResults(rewriter, results, std::move(args))) &&
...);
};
std::apply(applyFn, std::move(tuple));
return success(std::apply(applyFn, std::move(tuple)));
}

/// Handle LogicalResult propagation.
inline LogicalResult processResults(PatternRewriter &rewriter,
PDLResultList &results,
LogicalResult &&result) {
return result;
}
template <typename T>
static LogicalResult processResults(PatternRewriter &rewriter,
PDLResultList &results,
FailureOr<T> &&result) {
if (failed(result))
return failure();
return processResults(rewriter, results, std::move(*result));
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1192,23 +1337,26 @@ buildConstraintFn(ConstraintFnT &&constraintFn) {
/// This overload handles the case of no return values.
template <typename PDLFnT, std::size_t... I,
typename FnTraitsT = llvm::function_traits<PDLFnT>>
std::enable_if_t<std::is_same<typename FnTraitsT::result_t, void>::value>
std::enable_if_t<std::is_same<typename FnTraitsT::result_t, void>::value,
LogicalResult>
processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter,
PDLResultList &, ArrayRef<PDLValue> values,
std::index_sequence<I...>) {
fn(rewriter,
(ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::processAsArg(
values[I]))...);
return success();
}
/// This overload handles the case of return values, which need to be packaged
/// into the result list.
template <typename PDLFnT, std::size_t... I,
typename FnTraitsT = llvm::function_traits<PDLFnT>>
std::enable_if_t<!std::is_same<typename FnTraitsT::result_t, void>::value>
std::enable_if_t<!std::is_same<typename FnTraitsT::result_t, void>::value,
LogicalResult>
processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter,
PDLResultList &results, ArrayRef<PDLValue> values,
std::index_sequence<I...>) {
processResults(
return processResults(
rewriter, results,
fn(rewriter, (ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::
processAsArg(values[I]))...));
Expand Down Expand Up @@ -1240,14 +1388,17 @@ buildRewriteFn(RewriteFnT &&rewriteFn) {
std::make_index_sequence<llvm::function_traits<RewriteFnT>::num_args -
1>();
assertArgs<RewriteFnT>(rewriter, values, argIndices);
processArgsAndInvokeRewrite(rewriteFn, rewriter, results, values,
argIndices);
return processArgsAndInvokeRewrite(rewriteFn, rewriter, results, values,
argIndices);
};
}

} // namespace pdl_function_builder
} // namespace detail

//===----------------------------------------------------------------------===//
// PDLPatternModule

/// This class contains all of the necessary data for a set of PDL patterns, or
/// pattern rewrites specified in the form of the PDL dialect. This PDL module
/// contained by this pattern may contain any number of `pdl.pattern`
Expand All @@ -1256,9 +1407,17 @@ class PDLPatternModule {
public:
PDLPatternModule() = default;

/// Construct a PDL pattern with the given module.
PDLPatternModule(OwningOpRef<ModuleOp> pdlModule)
: pdlModule(std::move(pdlModule)) {}
/// Construct a PDL pattern with the given module and configurations.
PDLPatternModule(OwningOpRef<ModuleOp> module)
: pdlModule(std::move(module)) {}
template <typename... ConfigsT>
PDLPatternModule(OwningOpRef<ModuleOp> module, ConfigsT &&...patternConfigs)
: PDLPatternModule(std::move(module)) {
auto configSet = std::make_unique<PDLPatternConfigSet>(
std::forward<ConfigsT>(patternConfigs)...);
attachConfigToPatterns(*pdlModule, *configSet);
configs.emplace_back(std::move(configSet));
}

/// Merge the state in `other` into this pattern module.
void mergeIn(PDLPatternModule &&other);
Expand Down Expand Up @@ -1344,6 +1503,14 @@ class PDLPatternModule {
return rewriteFunctions;
}

/// Return the set of the registered pattern configs.
SmallVector<std::unique_ptr<PDLPatternConfigSet>> takeConfigs() {
return std::move(configs);
}
DenseMap<Operation *, PDLPatternConfigSet *> takeConfigMap() {
return std::move(configMap);
}

/// Clear out the patterns and functions within this module.
void clear() {
pdlModule = nullptr;
Expand All @@ -1352,9 +1519,17 @@ class PDLPatternModule {
}

private:
/// Attach the given pattern config set to the patterns defined within the
/// given module.
void attachConfigToPatterns(ModuleOp module, PDLPatternConfigSet &configSet);

/// The module containing the `pdl.pattern` operations.
OwningOpRef<ModuleOp> pdlModule;

/// The set of configuration sets referenced by patterns within `pdlModule`.
SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs;
DenseMap<Operation *, PDLPatternConfigSet *> configMap;

/// The external functions referenced from within the PDL module.
llvm::StringMap<PDLConstraintFunction> constraintFunctions;
llvm::StringMap<PDLRewriteFunction> rewriteFunctions;
Expand Down

0 comments on commit 8c66344

Please sign in to comment.