Skip to content

Commit

Permalink
[mlir:PDL] Fix bugs in PDLPatternModule merging
Browse files Browse the repository at this point in the history
* Constraints/Rewrites registered before a pattern was added were dropped
* Constraints/Rewrites may be registered multiple times (if different pattern sets depend on them)
* ModuleOp no longer has a terminator, so we shouldn't be removing the terminator from it

Differential Revision: https://reviews.llvm.org/D114816
  • Loading branch information
River707 committed Dec 10, 2021
1 parent 98f5bd3 commit 06c3b9c
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 20 deletions.
33 changes: 17 additions & 16 deletions mlir/lib/IR/PatternMatch.cpp
Expand Up @@ -157,22 +157,21 @@ void PDLPatternModule::mergeIn(PDLPatternModule &&other) {
// Ignore the other module if it has no patterns.
if (!other.pdlModule)
return;

// Steal the functions of the other module.
for (auto &it : other.constraintFunctions)
registerConstraintFunction(it.first(), std::move(it.second));
for (auto &it : other.rewriteFunctions)
registerRewriteFunction(it.first(), std::move(it.second));

// Steal the other state if we have no patterns.
if (!pdlModule) {
constraintFunctions = std::move(other.constraintFunctions);
rewriteFunctions = std::move(other.rewriteFunctions);
pdlModule = std::move(other.pdlModule);
return;
}
// Steal the functions of the other module.
for (auto &it : constraintFunctions)
registerConstraintFunction(it.first(), std::move(it.second));
for (auto &it : rewriteFunctions)
registerRewriteFunction(it.first(), std::move(it.second));

// Merge the pattern operations from the other module into this one.
Block *block = pdlModule->getBody();
block->getTerminator()->erase();
block->getOperations().splice(block->end(),
other.pdlModule->getBody()->getOperations());
}
Expand All @@ -182,18 +181,20 @@ void PDLPatternModule::mergeIn(PDLPatternModule &&other) {

void PDLPatternModule::registerConstraintFunction(
StringRef name, PDLConstraintFunction constraintFn) {
auto it = constraintFunctions.try_emplace(name, std::move(constraintFn));
(void)it;
assert(it.second &&
"constraint with the given name has already been registered");
// TODO: Is it possible to diagnose when `name` is already registered to
// a function that is not equivalent to `constraintFn`?
// Allow existing mappings in the case multiple patterns depend on the same
// constraint.
constraintFunctions.try_emplace(name, std::move(constraintFn));
}

void PDLPatternModule::registerRewriteFunction(StringRef name,
PDLRewriteFunction rewriteFn) {
auto it = rewriteFunctions.try_emplace(name, std::move(rewriteFn));
(void)it;
assert(it.second && "native rewrite function with the given name has "
"already been registered");
// TODO: Is it possible to diagnose when `name` is already registered to
// a function that is not equivalent to `rewriteFn`?
// Allow existing mappings in the case multiple patterns depend on the same
// rewrite.
rewriteFunctions.try_emplace(name, std::move(rewriteFn));
}

//===----------------------------------------------------------------------===//
Expand Down
21 changes: 17 additions & 4 deletions mlir/test/lib/Rewrite/TestPDLByteCode.cpp
Expand Up @@ -87,22 +87,35 @@ struct TestPDLByteCodePass
if (!patternModule || !irModule)
return;

RewritePatternSet patternList(module->getContext());

// Register ahead of time to test when functions are registered without a
// pattern.
patternList.getPDLPatterns().registerConstraintFunction(
"multi_entity_constraint", customMultiEntityConstraint);
patternList.getPDLPatterns().registerConstraintFunction(
"single_entity_constraint", customSingleEntityConstraint);

// Process the pattern module.
patternModule.getOperation()->remove();
PDLPatternModule pdlPattern(patternModule);

// Note: This constraint was already registered, but we re-register here to
// ensure that duplication registration is allowed (the duplicate mapping
// will be ignored). This tests that we support separating the registration
// of library functions from the construction of patterns, and also that we
// allow multiple patterns to depend on the same library functions (without
// asserting/crashing).
pdlPattern.registerConstraintFunction("multi_entity_constraint",
customMultiEntityConstraint);
pdlPattern.registerConstraintFunction("single_entity_constraint",
customSingleEntityConstraint);
pdlPattern.registerConstraintFunction("multi_entity_var_constraint",
customMultiEntityVariadicConstraint);
pdlPattern.registerRewriteFunction("creator", customCreate);
pdlPattern.registerRewriteFunction("var_creator",
customVariadicResultCreate);
pdlPattern.registerRewriteFunction("type_creator", customCreateType);
pdlPattern.registerRewriteFunction("rewriter", customRewriter);

RewritePatternSet patternList(std::move(pdlPattern));
patternList.add(std::move(pdlPattern));

// Invoke the pattern driver with the provided patterns.
(void)applyPatternsAndFoldGreedily(irModule.getBodyRegion(),
Expand Down

0 comments on commit 06c3b9c

Please sign in to comment.