Skip to content

Commit 8c66344

Browse files
committed
[mlir:PDL] Add support for DialectConversion with pattern configurations
Up until now PDL(L) has not supported dialect conversion because we had no way of remapping values or integrating with type conversions. This commit rectifies that by adding a new "pattern configuration" concept to PDL. This essentially allows for attaching external configurations to patterns, which can hook into pattern events (for now just the scope of a rewrite, but we could also pass configs to native rewrites as well). This allows for injecting the type converter into the conversion pattern rewriter. Differential Revision: https://reviews.llvm.org/D133142
1 parent f3a86a2 commit 8c66344

File tree

19 files changed

+669
-95
lines changed

19 files changed

+669
-95
lines changed

mlir/include/mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,27 @@
1313
#ifndef MLIR_CONVERSION_PDLTOPDLINTERP_PDLTOPDLINTERP_H
1414
#define MLIR_CONVERSION_PDLTOPDLINTERP_PDLTOPDLINTERP_H
1515

16-
#include <memory>
16+
#include "mlir/Support/LLVM.h"
1717

1818
namespace mlir {
1919
class ModuleOp;
20+
class Operation;
2021
template <typename OpT>
2122
class OperationPass;
23+
class PDLPatternConfigSet;
2224

2325
#define GEN_PASS_DECL_CONVERTPDLTOPDLINTERP
2426
#include "mlir/Conversion/Passes.h.inc"
2527

2628
/// Creates and returns a pass to convert PDL ops to PDL interpreter ops.
2729
std::unique_ptr<OperationPass<ModuleOp>> createPDLToPDLInterpPass();
2830

31+
/// Creates and returns a pass to convert PDL ops to PDL interpreter ops.
32+
/// `configMap` holds a map of the configurations for each pattern being
33+
/// compiled.
34+
std::unique_ptr<OperationPass<ModuleOp>> createPDLToPDLInterpPass(
35+
DenseMap<Operation *, PDLPatternConfigSet *> &configMap);
36+
2937
} // namespace mlir
3038

3139
#endif // MLIR_CONVERSION_PDLTOPDLINTERP_PDLTOPDLINTERP_H

mlir/include/mlir/IR/PatternMatch.h

Lines changed: 197 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -600,10 +600,16 @@ class IRRewriter : public RewriterBase {
600600
class PatternRewriter : public RewriterBase {
601601
public:
602602
using RewriterBase::RewriterBase;
603+
604+
/// A hook used to indicate if the pattern rewriter can recover from failure
605+
/// during the rewrite stage of a pattern. For example, if the pattern
606+
/// rewriter supports rollback, it may progress smoothly even if IR was
607+
/// changed during the rewrite.
608+
virtual bool canRecoverFromRewriteFailure() const { return false; }
603609
};
604610

605611
//===----------------------------------------------------------------------===//
606-
// PDLPatternModule
612+
// PDL Patterns
607613
//===----------------------------------------------------------------------===//
608614

609615
//===----------------------------------------------------------------------===//
@@ -796,6 +802,108 @@ class PDLResultList {
796802
SmallVector<llvm::OwningArrayRef<Value>> allocatedValueRanges;
797803
};
798804

805+
//===----------------------------------------------------------------------===//
806+
// PDLPatternConfig
807+
808+
/// An individual configuration for a pattern, which can be accessed by native
809+
/// functions via the PDLPatternConfigSet. This allows for injecting additional
810+
/// configuration into PDL patterns that is specific to certain compilation
811+
/// flows.
812+
class PDLPatternConfig {
813+
public:
814+
virtual ~PDLPatternConfig() = default;
815+
816+
/// Hooks that are invoked at the beginning and end of a rewrite of a matched
817+
/// pattern. These can be used to setup any specific state necessary for the
818+
/// rewrite.
819+
virtual void notifyRewriteBegin(PatternRewriter &rewriter) {}
820+
virtual void notifyRewriteEnd(PatternRewriter &rewriter) {}
821+
822+
/// Return the TypeID that represents this configuration.
823+
TypeID getTypeID() const { return id; }
824+
825+
protected:
826+
PDLPatternConfig(TypeID id) : id(id) {}
827+
828+
private:
829+
TypeID id;
830+
};
831+
832+
/// This class provides a base class for users implementing a type of pattern
833+
/// configuration.
834+
template <typename T>
835+
class PDLPatternConfigBase : public PDLPatternConfig {
836+
public:
837+
/// Support LLVM style casting.
838+
static bool classof(const PDLPatternConfig *config) {
839+
return config->getTypeID() == getConfigID();
840+
}
841+
842+
/// Return the type id used for this configuration.
843+
static TypeID getConfigID() { return TypeID::get<T>(); }
844+
845+
protected:
846+
PDLPatternConfigBase() : PDLPatternConfig(getConfigID()) {}
847+
};
848+
849+
/// This class contains a set of configurations for a specific pattern.
850+
/// Configurations are uniqued by TypeID, meaning that only one configuration of
851+
/// each type is allowed.
852+
class PDLPatternConfigSet {
853+
public:
854+
PDLPatternConfigSet() = default;
855+
856+
/// Construct a set with the given configurations.
857+
template <typename... ConfigsT>
858+
PDLPatternConfigSet(ConfigsT &&...configs) {
859+
(addConfig(std::forward<ConfigsT>(configs)), ...);
860+
}
861+
862+
/// Get the configuration defined by the given type. Asserts that the
863+
/// configuration of the provided type exists.
864+
template <typename T>
865+
const T &get() const {
866+
const T *config = tryGet<T>();
867+
assert(config && "configuration not found");
868+
return *config;
869+
}
870+
871+
/// Get the configuration defined by the given type, returns nullptr if the
872+
/// configuration does not exist.
873+
template <typename T>
874+
const T *tryGet() const {
875+
for (const auto &configIt : configs)
876+
if (const T *config = dyn_cast<T>(configIt.get()))
877+
return config;
878+
return nullptr;
879+
}
880+
881+
/// Notify the configurations within this set at the beginning or end of a
882+
/// rewrite of a matched pattern.
883+
void notifyRewriteBegin(PatternRewriter &rewriter) {
884+
for (const auto &config : configs)
885+
config->notifyRewriteBegin(rewriter);
886+
}
887+
void notifyRewriteEnd(PatternRewriter &rewriter) {
888+
for (const auto &config : configs)
889+
config->notifyRewriteEnd(rewriter);
890+
}
891+
892+
protected:
893+
/// Add a configuration to the set.
894+
template <typename T>
895+
void addConfig(T &&config) {
896+
assert(!tryGet<std::decay_t<T>>() && "configuration already exists");
897+
configs.emplace_back(
898+
std::make_unique<std::decay_t<T>>(std::forward<T>(config)));
899+
}
900+
901+
/// The set of configurations for this pattern. This uses a vector instead of
902+
/// a map with the expectation that the number of configurations per set is
903+
/// small (<= 1).
904+
SmallVector<std::unique_ptr<PDLPatternConfig>> configs;
905+
};
906+
799907
//===----------------------------------------------------------------------===//
800908
// PDLPatternModule
801909

@@ -807,9 +915,11 @@ using PDLConstraintFunction =
807915
/// A native PDL rewrite function. This function performs a rewrite on the
808916
/// given set of values. Any results from this rewrite that should be passed
809917
/// back to PDL should be added to the provided result list. This method is only
810-
/// invoked when the corresponding match was successful.
811-
using PDLRewriteFunction =
812-
std::function<void(PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
918+
/// invoked when the corresponding match was successful. Returns failure if an
919+
/// invariant of the rewrite was broken (certain rewriters may recover from
920+
/// partial pattern application).
921+
using PDLRewriteFunction = std::function<LogicalResult(
922+
PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
813923

814924
namespace detail {
815925
namespace pdl_function_builder {
@@ -1034,6 +1144,13 @@ struct ProcessPDLValue<ValueTypeRange<ResultRange>> {
10341144
results.push_back(types);
10351145
}
10361146
};
1147+
template <unsigned N>
1148+
struct ProcessPDLValue<SmallVector<Type, N>> {
1149+
static void processAsResult(PatternRewriter &, PDLResultList &results,
1150+
SmallVector<Type, N> values) {
1151+
results.push_back(TypeRange(values));
1152+
}
1153+
};
10371154

10381155
//===----------------------------------------------------------------------===//
10391156
// Value
@@ -1061,6 +1178,13 @@ struct ProcessPDLValue<ResultRange> {
10611178
results.push_back(values);
10621179
}
10631180
};
1181+
template <unsigned N>
1182+
struct ProcessPDLValue<SmallVector<Value, N>> {
1183+
static void processAsResult(PatternRewriter &, PDLResultList &results,
1184+
SmallVector<Value, N> values) {
1185+
results.push_back(ValueRange(values));
1186+
}
1187+
};
10641188

10651189
//===----------------------------------------------------------------------===//
10661190
// PDL Function Builder: Argument Handling
@@ -1111,28 +1235,49 @@ void assertArgs(PatternRewriter &rewriter, ArrayRef<PDLValue> values,
11111235

11121236
/// Store a single result within the result list.
11131237
template <typename T>
1114-
static void processResults(PatternRewriter &rewriter, PDLResultList &results,
1115-
T &&value) {
1238+
static LogicalResult processResults(PatternRewriter &rewriter,
1239+
PDLResultList &results, T &&value) {
11161240
ProcessPDLValue<T>::processAsResult(rewriter, results,
11171241
std::forward<T>(value));
1242+
return success();
11181243
}
11191244

11201245
/// Store a std::pair<> as individual results within the result list.
11211246
template <typename T1, typename T2>
1122-
static void processResults(PatternRewriter &rewriter, PDLResultList &results,
1123-
std::pair<T1, T2> &&pair) {
1124-
processResults(rewriter, results, std::move(pair.first));
1125-
processResults(rewriter, results, std::move(pair.second));
1247+
static LogicalResult processResults(PatternRewriter &rewriter,
1248+
PDLResultList &results,
1249+
std::pair<T1, T2> &&pair) {
1250+
if (failed(processResults(rewriter, results, std::move(pair.first))) ||
1251+
failed(processResults(rewriter, results, std::move(pair.second))))
1252+
return failure();
1253+
return success();
11261254
}
11271255

11281256
/// Store a std::tuple<> as individual results within the result list.
11291257
template <typename... Ts>
1130-
static void processResults(PatternRewriter &rewriter, PDLResultList &results,
1131-
std::tuple<Ts...> &&tuple) {
1258+
static LogicalResult processResults(PatternRewriter &rewriter,
1259+
PDLResultList &results,
1260+
std::tuple<Ts...> &&tuple) {
11321261
auto applyFn = [&](auto &&...args) {
1133-
(processResults(rewriter, results, std::move(args)), ...);
1262+
return (succeeded(processResults(rewriter, results, std::move(args))) &&
1263+
...);
11341264
};
1135-
std::apply(applyFn, std::move(tuple));
1265+
return success(std::apply(applyFn, std::move(tuple)));
1266+
}
1267+
1268+
/// Handle LogicalResult propagation.
1269+
inline LogicalResult processResults(PatternRewriter &rewriter,
1270+
PDLResultList &results,
1271+
LogicalResult &&result) {
1272+
return result;
1273+
}
1274+
template <typename T>
1275+
static LogicalResult processResults(PatternRewriter &rewriter,
1276+
PDLResultList &results,
1277+
FailureOr<T> &&result) {
1278+
if (failed(result))
1279+
return failure();
1280+
return processResults(rewriter, results, std::move(*result));
11361281
}
11371282

11381283
//===----------------------------------------------------------------------===//
@@ -1192,23 +1337,26 @@ buildConstraintFn(ConstraintFnT &&constraintFn) {
11921337
/// This overload handles the case of no return values.
11931338
template <typename PDLFnT, std::size_t... I,
11941339
typename FnTraitsT = llvm::function_traits<PDLFnT>>
1195-
std::enable_if_t<std::is_same<typename FnTraitsT::result_t, void>::value>
1340+
std::enable_if_t<std::is_same<typename FnTraitsT::result_t, void>::value,
1341+
LogicalResult>
11961342
processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter,
11971343
PDLResultList &, ArrayRef<PDLValue> values,
11981344
std::index_sequence<I...>) {
11991345
fn(rewriter,
12001346
(ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::processAsArg(
12011347
values[I]))...);
1348+
return success();
12021349
}
12031350
/// This overload handles the case of return values, which need to be packaged
12041351
/// into the result list.
12051352
template <typename PDLFnT, std::size_t... I,
12061353
typename FnTraitsT = llvm::function_traits<PDLFnT>>
1207-
std::enable_if_t<!std::is_same<typename FnTraitsT::result_t, void>::value>
1354+
std::enable_if_t<!std::is_same<typename FnTraitsT::result_t, void>::value,
1355+
LogicalResult>
12081356
processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter,
12091357
PDLResultList &results, ArrayRef<PDLValue> values,
12101358
std::index_sequence<I...>) {
1211-
processResults(
1359+
return processResults(
12121360
rewriter, results,
12131361
fn(rewriter, (ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::
12141362
processAsArg(values[I]))...));
@@ -1240,14 +1388,17 @@ buildRewriteFn(RewriteFnT &&rewriteFn) {
12401388
std::make_index_sequence<llvm::function_traits<RewriteFnT>::num_args -
12411389
1>();
12421390
assertArgs<RewriteFnT>(rewriter, values, argIndices);
1243-
processArgsAndInvokeRewrite(rewriteFn, rewriter, results, values,
1244-
argIndices);
1391+
return processArgsAndInvokeRewrite(rewriteFn, rewriter, results, values,
1392+
argIndices);
12451393
};
12461394
}
12471395

12481396
} // namespace pdl_function_builder
12491397
} // namespace detail
12501398

1399+
//===----------------------------------------------------------------------===//
1400+
// PDLPatternModule
1401+
12511402
/// This class contains all of the necessary data for a set of PDL patterns, or
12521403
/// pattern rewrites specified in the form of the PDL dialect. This PDL module
12531404
/// contained by this pattern may contain any number of `pdl.pattern`
@@ -1256,9 +1407,17 @@ class PDLPatternModule {
12561407
public:
12571408
PDLPatternModule() = default;
12581409

1259-
/// Construct a PDL pattern with the given module.
1260-
PDLPatternModule(OwningOpRef<ModuleOp> pdlModule)
1261-
: pdlModule(std::move(pdlModule)) {}
1410+
/// Construct a PDL pattern with the given module and configurations.
1411+
PDLPatternModule(OwningOpRef<ModuleOp> module)
1412+
: pdlModule(std::move(module)) {}
1413+
template <typename... ConfigsT>
1414+
PDLPatternModule(OwningOpRef<ModuleOp> module, ConfigsT &&...patternConfigs)
1415+
: PDLPatternModule(std::move(module)) {
1416+
auto configSet = std::make_unique<PDLPatternConfigSet>(
1417+
std::forward<ConfigsT>(patternConfigs)...);
1418+
attachConfigToPatterns(*pdlModule, *configSet);
1419+
configs.emplace_back(std::move(configSet));
1420+
}
12621421

12631422
/// Merge the state in `other` into this pattern module.
12641423
void mergeIn(PDLPatternModule &&other);
@@ -1344,6 +1503,14 @@ class PDLPatternModule {
13441503
return rewriteFunctions;
13451504
}
13461505

1506+
/// Return the set of the registered pattern configs.
1507+
SmallVector<std::unique_ptr<PDLPatternConfigSet>> takeConfigs() {
1508+
return std::move(configs);
1509+
}
1510+
DenseMap<Operation *, PDLPatternConfigSet *> takeConfigMap() {
1511+
return std::move(configMap);
1512+
}
1513+
13471514
/// Clear out the patterns and functions within this module.
13481515
void clear() {
13491516
pdlModule = nullptr;
@@ -1352,9 +1519,17 @@ class PDLPatternModule {
13521519
}
13531520

13541521
private:
1522+
/// Attach the given pattern config set to the patterns defined within the
1523+
/// given module.
1524+
void attachConfigToPatterns(ModuleOp module, PDLPatternConfigSet &configSet);
1525+
13551526
/// The module containing the `pdl.pattern` operations.
13561527
OwningOpRef<ModuleOp> pdlModule;
13571528

1529+
/// The set of configuration sets referenced by patterns within `pdlModule`.
1530+
SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs;
1531+
DenseMap<Operation *, PDLPatternConfigSet *> configMap;
1532+
13581533
/// The external functions referenced from within the PDL module.
13591534
llvm::StringMap<PDLConstraintFunction> constraintFunctions;
13601535
llvm::StringMap<PDLRewriteFunction> rewriteFunctions;

0 commit comments

Comments
 (0)