@@ -600,10 +600,16 @@ class IRRewriter : public RewriterBase {
600600class PatternRewriter : public RewriterBase {
601601public:
602602 using RewriterBase::RewriterBase;
603+
604+ // / A hook used to indicate if the pattern rewriter can recover from failure
605+ // / during the rewrite stage of a pattern. For example, if the pattern
606+ // / rewriter supports rollback, it may progress smoothly even if IR was
607+ // / changed during the rewrite.
608+ virtual bool canRecoverFromRewriteFailure () const { return false ; }
603609};
604610
605611// ===----------------------------------------------------------------------===//
606- // PDLPatternModule
612+ // PDL Patterns
607613// ===----------------------------------------------------------------------===//
608614
609615// ===----------------------------------------------------------------------===//
@@ -796,6 +802,108 @@ class PDLResultList {
796802 SmallVector<llvm::OwningArrayRef<Value>> allocatedValueRanges;
797803};
798804
805+ // ===----------------------------------------------------------------------===//
806+ // PDLPatternConfig
807+
808+ // / An individual configuration for a pattern, which can be accessed by native
809+ // / functions via the PDLPatternConfigSet. This allows for injecting additional
810+ // / configuration into PDL patterns that is specific to certain compilation
811+ // / flows.
812+ class PDLPatternConfig {
813+ public:
814+ virtual ~PDLPatternConfig () = default ;
815+
816+ // / Hooks that are invoked at the beginning and end of a rewrite of a matched
817+ // / pattern. These can be used to setup any specific state necessary for the
818+ // / rewrite.
819+ virtual void notifyRewriteBegin (PatternRewriter &rewriter) {}
820+ virtual void notifyRewriteEnd (PatternRewriter &rewriter) {}
821+
822+ // / Return the TypeID that represents this configuration.
823+ TypeID getTypeID () const { return id; }
824+
825+ protected:
826+ PDLPatternConfig (TypeID id) : id(id) {}
827+
828+ private:
829+ TypeID id;
830+ };
831+
832+ // / This class provides a base class for users implementing a type of pattern
833+ // / configuration.
834+ template <typename T>
835+ class PDLPatternConfigBase : public PDLPatternConfig {
836+ public:
837+ // / Support LLVM style casting.
838+ static bool classof (const PDLPatternConfig *config) {
839+ return config->getTypeID () == getConfigID ();
840+ }
841+
842+ // / Return the type id used for this configuration.
843+ static TypeID getConfigID () { return TypeID::get<T>(); }
844+
845+ protected:
846+ PDLPatternConfigBase () : PDLPatternConfig(getConfigID()) {}
847+ };
848+
849+ // / This class contains a set of configurations for a specific pattern.
850+ // / Configurations are uniqued by TypeID, meaning that only one configuration of
851+ // / each type is allowed.
852+ class PDLPatternConfigSet {
853+ public:
854+ PDLPatternConfigSet () = default ;
855+
856+ // / Construct a set with the given configurations.
857+ template <typename ... ConfigsT>
858+ PDLPatternConfigSet (ConfigsT &&...configs) {
859+ (addConfig (std::forward<ConfigsT>(configs)), ...);
860+ }
861+
862+ // / Get the configuration defined by the given type. Asserts that the
863+ // / configuration of the provided type exists.
864+ template <typename T>
865+ const T &get () const {
866+ const T *config = tryGet<T>();
867+ assert (config && " configuration not found" );
868+ return *config;
869+ }
870+
871+ // / Get the configuration defined by the given type, returns nullptr if the
872+ // / configuration does not exist.
873+ template <typename T>
874+ const T *tryGet () const {
875+ for (const auto &configIt : configs)
876+ if (const T *config = dyn_cast<T>(configIt.get ()))
877+ return config;
878+ return nullptr ;
879+ }
880+
881+ // / Notify the configurations within this set at the beginning or end of a
882+ // / rewrite of a matched pattern.
883+ void notifyRewriteBegin (PatternRewriter &rewriter) {
884+ for (const auto &config : configs)
885+ config->notifyRewriteBegin (rewriter);
886+ }
887+ void notifyRewriteEnd (PatternRewriter &rewriter) {
888+ for (const auto &config : configs)
889+ config->notifyRewriteEnd (rewriter);
890+ }
891+
892+ protected:
893+ // / Add a configuration to the set.
894+ template <typename T>
895+ void addConfig (T &&config) {
896+ assert (!tryGet<std::decay_t <T>>() && " configuration already exists" );
897+ configs.emplace_back (
898+ std::make_unique<std::decay_t <T>>(std::forward<T>(config)));
899+ }
900+
901+ // / The set of configurations for this pattern. This uses a vector instead of
902+ // / a map with the expectation that the number of configurations per set is
903+ // / small (<= 1).
904+ SmallVector<std::unique_ptr<PDLPatternConfig>> configs;
905+ };
906+
799907// ===----------------------------------------------------------------------===//
800908// PDLPatternModule
801909
@@ -807,9 +915,11 @@ using PDLConstraintFunction =
807915// / A native PDL rewrite function. This function performs a rewrite on the
808916// / given set of values. Any results from this rewrite that should be passed
809917// / back to PDL should be added to the provided result list. This method is only
810- // / invoked when the corresponding match was successful.
811- using PDLRewriteFunction =
812- std::function<void (PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
918+ // / invoked when the corresponding match was successful. Returns failure if an
919+ // / invariant of the rewrite was broken (certain rewriters may recover from
920+ // / partial pattern application).
921+ using PDLRewriteFunction = std::function<LogicalResult(
922+ PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
813923
814924namespace detail {
815925namespace pdl_function_builder {
@@ -1034,6 +1144,13 @@ struct ProcessPDLValue<ValueTypeRange<ResultRange>> {
10341144 results.push_back (types);
10351145 }
10361146};
1147+ template <unsigned N>
1148+ struct ProcessPDLValue <SmallVector<Type, N>> {
1149+ static void processAsResult (PatternRewriter &, PDLResultList &results,
1150+ SmallVector<Type, N> values) {
1151+ results.push_back (TypeRange (values));
1152+ }
1153+ };
10371154
10381155// ===----------------------------------------------------------------------===//
10391156// Value
@@ -1061,6 +1178,13 @@ struct ProcessPDLValue<ResultRange> {
10611178 results.push_back (values);
10621179 }
10631180};
1181+ template <unsigned N>
1182+ struct ProcessPDLValue <SmallVector<Value, N>> {
1183+ static void processAsResult (PatternRewriter &, PDLResultList &results,
1184+ SmallVector<Value, N> values) {
1185+ results.push_back (ValueRange (values));
1186+ }
1187+ };
10641188
10651189// ===----------------------------------------------------------------------===//
10661190// PDL Function Builder: Argument Handling
@@ -1111,28 +1235,49 @@ void assertArgs(PatternRewriter &rewriter, ArrayRef<PDLValue> values,
11111235
11121236// / Store a single result within the result list.
11131237template <typename T>
1114- static void processResults (PatternRewriter &rewriter, PDLResultList &results ,
1115- T &&value) {
1238+ static LogicalResult processResults (PatternRewriter &rewriter,
1239+ PDLResultList &results, T &&value) {
11161240 ProcessPDLValue<T>::processAsResult (rewriter, results,
11171241 std::forward<T>(value));
1242+ return success ();
11181243}
11191244
11201245// / Store a std::pair<> as individual results within the result list.
11211246template <typename T1, typename T2>
1122- static void processResults (PatternRewriter &rewriter, PDLResultList &results,
1123- std::pair<T1, T2> &&pair) {
1124- processResults (rewriter, results, std::move (pair.first ));
1125- processResults (rewriter, results, std::move (pair.second ));
1247+ static LogicalResult processResults (PatternRewriter &rewriter,
1248+ PDLResultList &results,
1249+ std::pair<T1, T2> &&pair) {
1250+ if (failed (processResults (rewriter, results, std::move (pair.first ))) ||
1251+ failed (processResults (rewriter, results, std::move (pair.second ))))
1252+ return failure ();
1253+ return success ();
11261254}
11271255
11281256// / Store a std::tuple<> as individual results within the result list.
11291257template <typename ... Ts>
1130- static void processResults (PatternRewriter &rewriter, PDLResultList &results,
1131- std::tuple<Ts...> &&tuple) {
1258+ static LogicalResult processResults (PatternRewriter &rewriter,
1259+ PDLResultList &results,
1260+ std::tuple<Ts...> &&tuple) {
11321261 auto applyFn = [&](auto &&...args ) {
1133- (processResults (rewriter, results, std::move (args)), ...);
1262+ return (succeeded (processResults (rewriter, results, std::move (args))) &&
1263+ ...);
11341264 };
1135- std::apply (applyFn, std::move (tuple));
1265+ return success (std::apply (applyFn, std::move (tuple)));
1266+ }
1267+
1268+ // / Handle LogicalResult propagation.
1269+ inline LogicalResult processResults (PatternRewriter &rewriter,
1270+ PDLResultList &results,
1271+ LogicalResult &&result) {
1272+ return result;
1273+ }
1274+ template <typename T>
1275+ static LogicalResult processResults (PatternRewriter &rewriter,
1276+ PDLResultList &results,
1277+ FailureOr<T> &&result) {
1278+ if (failed (result))
1279+ return failure ();
1280+ return processResults (rewriter, results, std::move (*result));
11361281}
11371282
11381283// ===----------------------------------------------------------------------===//
@@ -1192,23 +1337,26 @@ buildConstraintFn(ConstraintFnT &&constraintFn) {
11921337// / This overload handles the case of no return values.
11931338template <typename PDLFnT, std::size_t ... I,
11941339 typename FnTraitsT = llvm::function_traits<PDLFnT>>
1195- std::enable_if_t <std::is_same<typename FnTraitsT::result_t , void >::value>
1340+ std::enable_if_t <std::is_same<typename FnTraitsT::result_t , void >::value,
1341+ LogicalResult>
11961342processArgsAndInvokeRewrite (PDLFnT &fn, PatternRewriter &rewriter,
11971343 PDLResultList &, ArrayRef<PDLValue> values,
11981344 std::index_sequence<I...>) {
11991345 fn (rewriter,
12001346 (ProcessPDLValue<typename FnTraitsT::template arg_t <I + 1 >>::processAsArg (
12011347 values[I]))...);
1348+ return success ();
12021349}
12031350// / This overload handles the case of return values, which need to be packaged
12041351// / into the result list.
12051352template <typename PDLFnT, std::size_t ... I,
12061353 typename FnTraitsT = llvm::function_traits<PDLFnT>>
1207- std::enable_if_t <!std::is_same<typename FnTraitsT::result_t , void >::value>
1354+ std::enable_if_t <!std::is_same<typename FnTraitsT::result_t , void >::value,
1355+ LogicalResult>
12081356processArgsAndInvokeRewrite (PDLFnT &fn, PatternRewriter &rewriter,
12091357 PDLResultList &results, ArrayRef<PDLValue> values,
12101358 std::index_sequence<I...>) {
1211- processResults (
1359+ return processResults (
12121360 rewriter, results,
12131361 fn (rewriter, (ProcessPDLValue<typename FnTraitsT::template arg_t <I + 1 >>::
12141362 processAsArg (values[I]))...));
@@ -1240,14 +1388,17 @@ buildRewriteFn(RewriteFnT &&rewriteFn) {
12401388 std::make_index_sequence<llvm::function_traits<RewriteFnT>::num_args -
12411389 1 >();
12421390 assertArgs<RewriteFnT>(rewriter, values, argIndices);
1243- processArgsAndInvokeRewrite (rewriteFn, rewriter, results, values,
1244- argIndices);
1391+ return processArgsAndInvokeRewrite (rewriteFn, rewriter, results, values,
1392+ argIndices);
12451393 };
12461394}
12471395
12481396} // namespace pdl_function_builder
12491397} // namespace detail
12501398
1399+ // ===----------------------------------------------------------------------===//
1400+ // PDLPatternModule
1401+
12511402// / This class contains all of the necessary data for a set of PDL patterns, or
12521403// / pattern rewrites specified in the form of the PDL dialect. This PDL module
12531404// / contained by this pattern may contain any number of `pdl.pattern`
@@ -1256,9 +1407,17 @@ class PDLPatternModule {
12561407public:
12571408 PDLPatternModule () = default ;
12581409
1259- // / Construct a PDL pattern with the given module.
1260- PDLPatternModule (OwningOpRef<ModuleOp> pdlModule)
1261- : pdlModule(std::move(pdlModule)) {}
1410+ // / Construct a PDL pattern with the given module and configurations.
1411+ PDLPatternModule (OwningOpRef<ModuleOp> module )
1412+ : pdlModule(std::move(module )) {}
1413+ template <typename ... ConfigsT>
1414+ PDLPatternModule (OwningOpRef<ModuleOp> module , ConfigsT &&...patternConfigs)
1415+ : PDLPatternModule(std::move(module )) {
1416+ auto configSet = std::make_unique<PDLPatternConfigSet>(
1417+ std::forward<ConfigsT>(patternConfigs)...);
1418+ attachConfigToPatterns (*pdlModule, *configSet);
1419+ configs.emplace_back (std::move (configSet));
1420+ }
12621421
12631422 // / Merge the state in `other` into this pattern module.
12641423 void mergeIn (PDLPatternModule &&other);
@@ -1344,6 +1503,14 @@ class PDLPatternModule {
13441503 return rewriteFunctions;
13451504 }
13461505
1506+ // / Return the set of the registered pattern configs.
1507+ SmallVector<std::unique_ptr<PDLPatternConfigSet>> takeConfigs () {
1508+ return std::move (configs);
1509+ }
1510+ DenseMap<Operation *, PDLPatternConfigSet *> takeConfigMap () {
1511+ return std::move (configMap);
1512+ }
1513+
13471514 // / Clear out the patterns and functions within this module.
13481515 void clear () {
13491516 pdlModule = nullptr ;
@@ -1352,9 +1519,17 @@ class PDLPatternModule {
13521519 }
13531520
13541521private:
1522+ // / Attach the given pattern config set to the patterns defined within the
1523+ // / given module.
1524+ void attachConfigToPatterns (ModuleOp module , PDLPatternConfigSet &configSet);
1525+
13551526 // / The module containing the `pdl.pattern` operations.
13561527 OwningOpRef<ModuleOp> pdlModule;
13571528
1529+ // / The set of configuration sets referenced by patterns within `pdlModule`.
1530+ SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs;
1531+ DenseMap<Operation *, PDLPatternConfigSet *> configMap;
1532+
13581533 // / The external functions referenced from within the PDL module.
13591534 llvm::StringMap<PDLConstraintFunction> constraintFunctions;
13601535 llvm::StringMap<PDLRewriteFunction> rewriteFunctions;
0 commit comments