253 changes: 169 additions & 84 deletions mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp

Large diffs are not rendered by default.

62 changes: 52 additions & 10 deletions mlir/lib/Conversion/PDLToPDLInterp/Predicate.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,9 @@ enum Kind : unsigned {

// Answers.
AttributeAnswer,
TrueAnswer,
FalseAnswer,
OperationNameAnswer,
TrueAnswer,
TypeAnswer,
UnsignedAnswer,
};
Expand Down Expand Up @@ -216,24 +217,45 @@ struct OperandGroupPosition

/// An operation position describes an operation node in the IR. Other position
/// kinds are formed with respect to an operation position.
struct OperationPosition : public PredicateBase<OperationPosition, Position,
std::pair<Position *, unsigned>,
Predicates::OperationPos> {
struct OperationPosition
: public PredicateBase<OperationPosition, Position,
std::tuple<Position *, Optional<unsigned>, unsigned>,
Predicates::OperationPos> {
static constexpr unsigned kDown = std::numeric_limits<unsigned>::max();

explicit OperationPosition(const KeyTy &key) : Base(key) {
parent = key.first;
parent = std::get<0>(key);
}

/// Returns a hash suitable for the given keytype.
static llvm::hash_code hashKey(const KeyTy &key) {
return llvm::hash_value(key);
}

/// Gets the root position.
static OperationPosition *getRoot(StorageUniquer &uniquer) {
return Base::get(uniquer, nullptr, 0);
return Base::get(uniquer, nullptr, kDown, 0);
}
/// Gets an operation position with the given parent.

/// Gets an downward operation position with the given parent.
static OperationPosition *get(StorageUniquer &uniquer, Position *parent) {
return Base::get(uniquer, parent, parent->getOperationDepth() + 1);
return Base::get(uniquer, parent, kDown, parent->getOperationDepth() + 1);
}

/// Gets an upward operation position with the given parent and operand.
static OperationPosition *get(StorageUniquer &uniquer, Position *parent,
Optional<unsigned> operand) {
return Base::get(uniquer, parent, operand, parent->getOperationDepth() + 1);
}

/// Returns the operand index for an upward operation position.
Optional<unsigned> getIndex() const { return std::get<1>(key); }

/// Returns if this operation position is upward, accepting an input.
bool isUpward() const { return getIndex().getValueOr(0) != kDown; }

/// Returns the depth of this position.
unsigned getDepth() const { return key.second; }
unsigned getDepth() const { return std::get<2>(key); }

/// Returns if this operation position corresponds to the root.
bool isRoot() const { return getDepth() == 0; }
Expand Down Expand Up @@ -346,6 +368,12 @@ struct TrueAnswer
using Base::Base;
};

/// An Answer representing a boolean 'false' value.
struct FalseAnswer
: PredicateBase<FalseAnswer, Qualifier, void, Predicates::FalseAnswer> {
using Base::Base;
};

/// An Answer representing a `Type` value. The value is stored as either a
/// TypeAttr, or an ArrayAttr of TypeAttr.
struct TypeAnswer : public PredicateBase<TypeAnswer, Qualifier, Attribute,
Expand Down Expand Up @@ -445,6 +473,7 @@ class PredicateUniquer : public StorageUniquer {
registerParametricStorageType<OperationNameAnswer>();
registerParametricStorageType<TypeAnswer>();
registerParametricStorageType<UnsignedAnswer>();
registerSingletonStorageType<FalseAnswer>();
registerSingletonStorageType<TrueAnswer>();

// Register the types of Answers with the uniquer.
Expand Down Expand Up @@ -485,6 +514,14 @@ class PredicateBuilder {
return OperationPosition::get(uniquer, p);
}

/// Returns the position of operation using the value at the given index.
OperationPosition *getUsersOp(Position *p, Optional<unsigned> operand) {
assert((isa<OperandPosition, OperandGroupPosition, ResultPosition,
ResultGroupPosition>(p)) &&
"expected result position");
return OperationPosition::get(uniquer, p, operand);
}

/// Returns an attribute position for an attribute of the given operation.
Position *getAttribute(OperationPosition *p, StringRef name) {
return AttributePosition::get(uniquer, p, StringAttr::get(ctx, name));
Expand Down Expand Up @@ -536,11 +573,16 @@ class PredicateBuilder {
AttributeAnswer::get(uniquer, attr)};
}

/// Create a predicate comparing two values.
/// Create a predicate checking if two values are equal.
Predicate getEqualTo(Position *pos) {
return {EqualToQuestion::get(uniquer, pos), TrueAnswer::get(uniquer)};
}

/// Create a predicate checking if two values are not equal.
Predicate getNotEqualTo(Position *pos) {
return {EqualToQuestion::get(uniquer, pos), FalseAnswer::get(uniquer)};
}

/// Create a predicate that applies a generic constraint.
Predicate getConstraint(StringRef name, ArrayRef<Position *> pos,
Attribute params) {
Expand Down
361 changes: 340 additions & 21 deletions mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp

Large diffs are not rendered by default.

8 changes: 7 additions & 1 deletion mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ struct ExitNode : public MatcherNode {
/// matched. This does not terminate the matcher, as there may be multiple
/// successful matches.
struct SuccessNode : public MatcherNode {
explicit SuccessNode(pdl::PatternOp pattern,
explicit SuccessNode(pdl::PatternOp pattern, Value root,
std::unique_ptr<MatcherNode> failureNode);

/// Returns if the given matcher node is an instance of this class, used to
Expand All @@ -164,10 +164,16 @@ struct SuccessNode : public MatcherNode {
/// Return the high level pattern operation that is matched with this node.
pdl::PatternOp getPattern() const { return pattern; }

/// Return the chosen root of the pattern.
Value getRoot() const { return root; }

private:
/// The high level pattern operation that was successfully matched with this
/// node.
pdl::PatternOp pattern;

/// The chosen root of the pattern.
Value root;
};

//===----------------------------------------------------------------------===//
Expand Down
229 changes: 229 additions & 0 deletions mlir/lib/Conversion/PDLToPDLInterp/RootOrdering.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
//===- RootOrdering.cpp - Optimal root ordering ---------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// An implementation of Edmonds' optimal branching algorithm. This is a
// directed analogue of the minimum spanning tree problem for a given root.
//
//===----------------------------------------------------------------------===//

#include "RootOrdering.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SmallVector.h"
#include <queue>
#include <utility>

using namespace mlir;
using namespace mlir::pdl_to_pdl_interp;

/// Returns the cycle implied by the specified parent relation, starting at the
/// given node.
static SmallVector<Value> getCycle(const DenseMap<Value, Value> &parents,
Value rep) {
SmallVector<Value> cycle;
Value node = rep;
do {
cycle.push_back(node);
node = parents.lookup(node);
assert(node && "got an empty value in the cycle");
} while (node != rep);
return cycle;
}

/// Contracts the specified cycle in the given graph in-place.
/// The parentsCost map specifies, for each node in the cycle, the lowest cost
/// among the edges entering that node. Then, the nodes in the cycle C are
/// replaced with a single node v_C (the first node in the cycle). All edges
/// (u, v) entering the cycle, v \in C, are replaced with a single edge
/// (u, v_C) with an appropriately chosen cost, and the selected node v is
/// marked in the output map actualTarget[u]. All edges (u, v) leaving the
/// cycle, u \in C, are replaced with a single edge (v_C, v), and the selected
/// node u is marked in the ouptut map actualSource[v].
static void contract(RootOrderingGraph &graph, ArrayRef<Value> cycle,
const DenseMap<Value, unsigned> &parentCosts,
DenseMap<Value, Value> &actualSource,
DenseMap<Value, Value> &actualTarget) {
Value rep = cycle.front();
DenseSet<Value> cycleSet(cycle.begin(), cycle.end());

// Now, contract the cycle, marking the actual sources and targets.
DenseMap<Value, RootOrderingCost> repCosts;
for (auto outer = graph.begin(), e = graph.end(); outer != e; ++outer) {
Value target = outer->first;
if (cycleSet.contains(target)) {
// Target in the cycle => edges incoming to the cycle or within the cycle.
unsigned parentCost = parentCosts.lookup(target);
for (const auto &inner : outer->second) {
Value source = inner.first;
// Ignore edges within the cycle.
if (cycleSet.contains(source))
continue;

// Edge incoming to the cycle.
std::pair<unsigned, unsigned> cost = inner.second.cost;
assert(parentCost <= cost.first && "invalid parent cost");

// Subtract the cost of the parent within the cycle from the cost of
// the edge incoming to the cycle. This update ensures that the cost
// of the minimum-weight spanning arborescence of the entire graph is
// the cost of arborescence for the contracted graph plus the cost of
// the cycle, no matter which edge in the cycle we choose to drop.
cost.first -= parentCost;
auto it = repCosts.find(source);
if (it == repCosts.end() || it->second.cost > cost) {
actualTarget[source] = target;
// Do not bother populating the connector (the connector is only
// relevant for the final traversal, not for the optimal branching).
repCosts[source].cost = cost;
}
}
// Erase the node in the cycle.
graph.erase(outer);
} else {
// Target not in cycle => edges going away from or unrelated to the cycle.
DenseMap<Value, RootOrderingCost> &costs = outer->second;
Value bestSource;
std::pair<unsigned, unsigned> bestCost;
auto inner = costs.begin(), inner_e = costs.end();
while (inner != inner_e) {
Value source = inner->first;
if (cycleSet.contains(source)) {
// Going-away edge => get its cost and erase it.
if (!bestSource || bestCost > inner->second.cost) {
bestSource = source;
bestCost = inner->second.cost;
}
costs.erase(inner++);
} else {
++inner;
}
}

// There were going-away edges, contract them.
if (bestSource) {
costs[rep].cost = bestCost;
actualSource[target] = bestSource;
}
}
}

// Store the edges to the representative.
graph[rep] = std::move(repCosts);
}

OptimalBranching::OptimalBranching(RootOrderingGraph graph, Value root)
: graph(std::move(graph)), root(root) {}

unsigned OptimalBranching::solve() {
// Initialize the parents and total cost.
parents.clear();
parents[root] = Value();
unsigned totalCost = 0;

// A map that stores the cost of the optimal local choice for each node
// in a directed cycle. This map is cleared every time we seed the search.
DenseMap<Value, unsigned> parentCosts;
parentCosts.reserve(graph.size());

// Determine if the optimal local choice results in an acyclic graph. This is
// done by computing the optimal local choice and traversing up the computed
// parents. On success, `parents` will contain the parent of each node.
for (const auto &outer : graph) {
Value node = outer.first;
if (parents.count(node)) // already visited
continue;

// Follow the trail of best sources until we reach an already visited node.
// The code will assert if we cannot reach an already visited node, i.e.,
// the graph is not strongly connected.
parentCosts.clear();
do {
auto it = graph.find(node);
assert(it != graph.end() && "the graph is not strongly connected");

Value &bestSource = parents[node];
unsigned &bestCost = parentCosts[node];
for (const auto &inner : it->second) {
const RootOrderingCost &cost = inner.second;
if (!bestSource /* initial */ || bestCost > cost.cost.first) {
bestSource = inner.first;
bestCost = cost.cost.first;
}
}
assert(bestSource && "the graph is not strongly connected");
node = bestSource;
totalCost += bestCost;
} while (!parents.count(node));

// If we reached a non-root node, we have a cycle.
if (parentCosts.count(node)) {
// Determine the cycle starting at the representative node.
SmallVector<Value> cycle = getCycle(parents, node);

// The following maps disambiguate the source / target of the edges
// going out of / into the cycle.
DenseMap<Value, Value> actualSource, actualTarget;

// Contract the cycle and recurse.
contract(graph, cycle, parentCosts, actualSource, actualTarget);
totalCost = solve();

// Redirect the going-away edges.
for (auto &p : parents)
if (p.second == node)
// The parent is the node representating the cycle; replace it
// with the actual (best) source in the cycle.
p.second = actualSource.lookup(p.first);

// Redirect the unique incoming edge and copy the cycle.
Value parent = parents.lookup(node);
Value entry = actualTarget.lookup(parent);
cycle.push_back(node); // complete the cycle
for (size_t i = 0, e = cycle.size() - 1; i < e; ++i) {
totalCost += parentCosts.lookup(cycle[i]);
if (cycle[i] == entry)
parents[cycle[i]] = parent; // break the cycle
else
parents[cycle[i]] = cycle[i + 1];
}

// `parents` has a complete solution.
break;
}
}

return totalCost;
}

OptimalBranching::EdgeList
OptimalBranching::preOrderTraversal(ArrayRef<Value> nodes) const {
// Invert the parent mapping.
DenseMap<Value, std::vector<Value>> children;
for (Value node : nodes) {
if (node != root) {
Value parent = parents.lookup(node);
assert(parent && "invalid parent");
children[parent].push_back(node);
}
}

// The result which simultaneously acts as a queue.
EdgeList result;
result.reserve(nodes.size());
result.emplace_back(root, Value());

// Perform a BFS, pushing into the queue.
for (size_t i = 0; i < result.size(); ++i) {
Value node = result[i].first;
for (Value child : children[node])
result.emplace_back(child, node);
}

return result;
}
137 changes: 137 additions & 0 deletions mlir/lib/Conversion/PDLToPDLInterp/RootOrdering.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
//===- RootOrdering.h - Optimal root ordering ------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains definition for a cost graph over candidate roots and
// an implementation of an algorithm to determine the optimal ordering over
// these roots. Each edge in this graph indicates that the target root can be
// connected (via a chain of positions) to the source root, and their cost
// indicates the estimated cost of such traversal. The optimal root ordering
// is then formulated as that of finding a spanning arborescence (i.e., a
// directed spanning tree) of minimal weight.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_LIB_CONVERSION_PDLTOPDLINTERP_ROOTORDERING_H_
#define MLIR_LIB_CONVERSION_PDLTOPDLINTERP_ROOTORDERING_H_

#include "mlir/IR/Value.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include <functional>
#include <vector>

namespace mlir {
namespace pdl_to_pdl_interp {

/// The information associated with an edge in the cost graph. Each node in
/// the cost graph corresponds to a candidate root detected in the pdl.pattern,
/// and each edge in the cost graph corresponds to connecting the two candidate
/// roots via a chain of operations. The cost of an edge is the smallest number
/// of upward traversals required to go from the source to the target root, and
/// the connector is a `Value` in the intersection of the two subtrees rooted at
/// the source and target root that results in that smallest number of upward
/// traversals. Consider the following pattern with 3 roots op3, op4, and op5:
///
/// argA ---> op1 ---> op2 ---> op3 ---> res3
/// ^ ^
/// | |
/// argB argC
/// | |
/// v v
/// res4 <--- op4 op5 ---> res5
/// ^ ^
/// | |
/// op6 op7
///
/// The cost of the edge op3 -> op4 is 1 (the upward traversal argB -> op4),
/// with argB being the connector `Value` and similarly for op3 -> op5 (cost 1,
/// connector argC). The cost of the edge op4 -> op3 is 3 (upward traversals
/// argB -> op1 -> op2 -> op3, connector argB), while the cost of edge op5 ->
/// op3 is 2 (uwpard traversals argC -> op2 -> op3). There are no edges between
/// op4 and op5 in the cost graph, because the subtrees rooted at these two
/// roots do not intersect. It is easy to see that the optimal root for this
/// pattern is op3, resulting in the spanning arborescence op3 -> {op4, op5}.
struct RootOrderingCost {
/// The depth of the connector `Value` w.r.t. the target root.
///
/// This is a pair where the first entry is the actual cost, and the second
/// entry is a priority for breaking ties (with 0 being the highest).
/// Typically, the priority is a unique edge ID.
std::pair<unsigned, unsigned> cost;

/// The connector value in the intersection of the two subtrees rooted at
/// the source and target root that results in that smallest depth w.r.t.
/// the target root.
Value connector;
};

/// A directed graph representing the cost of ordering the roots in the
/// predicate tree. It is represented as an adjacency map, where the outer map
/// is indexed by the target node, and the inner map is indexed by the source
/// node. Each edge is associated with a cost and the underlying connector
/// value.
using RootOrderingGraph = DenseMap<Value, DenseMap<Value, RootOrderingCost>>;

/// The optimal branching algorithm solver. This solver accepts a graph and the
/// root in its constructor, and is invoked via the solve() member function.
/// This is a direct implementation of the Edmonds' algorithm, see
/// https://en.wikipedia.org/wiki/Edmonds%27_algorithm. The worst-case
/// computational complexity of this algorithm is O(N^3), for a single root.
/// The PDL-to-PDLInterp lowering calls this N times (once for each candidate
/// root), so the overall complexity root ordering is O(N^4). If needed, this
/// could be reduced to O(N^3) with a more efficient algorithm. However, note
/// that the underlying implementation is very efficient, and N in our
/// instances tends to be very small (<10).
class OptimalBranching {
public:
/// A list of edges (child, parent).
using EdgeList = std::vector<std::pair<Value, Value>>;

/// Constructs the solver for the given graph and root value.
OptimalBranching(RootOrderingGraph graph, Value root);

/// Runs the Edmonds' algorithm for the current `graph`, returning the total
/// cost of the minimum-weight spanning arborescence (sum of the edge costs).
/// This function first determines the optimal local choice of the parents
/// and stores this choice in the `parents` mapping. If this choice results
/// in an acyclic graph, the function returns immediately. Otherwise, it
/// takes an arbitrary cycle, contracts it, and recurses on the new graph
/// (which is guaranteed to have fewer nodes than we began with). After we
/// return from recursion, we redirect the edges to/from the contracted node,
/// so the `parents` map contains a valid solution for the current graph.
unsigned solve();

/// Returns the computed parent map. This is the unique predecessor for each
/// node (root) in the optimal branching.
const DenseMap<Value, Value> &getRootOrderingParents() const {
return parents;
}

/// Returns the computed edges as visited in the preorder traversal.
/// The specified array determines the order for breaking any ties.
EdgeList preOrderTraversal(ArrayRef<Value> nodes) const;

private:
/// The graph whose optimal branching we wish to determine.
RootOrderingGraph graph;

/// The root of the optimal branching.
Value root;

/// The computed parent mapping. This is the unique predecessor for each node
/// in the optimal branching. The keys of this map correspond to the keys of
/// the outer map of the input graph, and each value is one of the keys of
/// the inner map for this node. Also used as an intermediate (possibly
/// cyclical) result in the optimal branching algorithm.
DenseMap<Value, Value> parents;
};

} // end namespace pdl_to_pdl_interp
} // end namespace mlir

#endif // MLIR_CONVERSION_PDLTOPDLINTERP_ROOTORDERING_H_
166 changes: 106 additions & 60 deletions mlir/lib/Dialect/PDL/IR/PDL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::pdl;
Expand All @@ -34,41 +35,55 @@ void PDLDialect::initialize() {
// PDL Operations
//===----------------------------------------------------------------------===//

/// Returns true if the given operation is used by a "binding" pdl operation
/// within the main matcher body of a `pdl.pattern`.
static bool hasBindingUseInMatcher(Operation *op, Block *matcherBlock) {
for (OpOperand &use : op->getUses()) {
Operation *user = use.getOwner();
if (user->getBlock() != matcherBlock)
continue;
if (isa<AttributeOp, OperandOp, OperandsOp, OperationOp>(user))
return true;
// Only the first operand of RewriteOp may be bound to, i.e. the root
// operation of the pattern.
if (isa<RewriteOp>(user) && use.getOperandNumber() == 0)
return true;
/// Returns true if the given operation is used by a "binding" pdl operation.
static bool hasBindingUse(Operation *op) {
for (Operation *user : op->getUsers())
// A result by itself is not binding, it must also be bound.
if (isa<ResultOp, ResultsOp>(user) &&
hasBindingUseInMatcher(user, matcherBlock))
if (!isa<ResultOp, ResultsOp>(user) || hasBindingUse(user))
return true;
}
return false;
}

/// Returns success if the given operation is used by a "binding" pdl operation
/// within the main matcher body of a `pdl.pattern`. On failure, emits an error
/// with the given context message.
static LogicalResult
verifyHasBindingUseInMatcher(Operation *op,
StringRef bindableContextStr = "`pdl.operation`") {
// If the pattern is not a pattern, there is nothing to do.
/// Returns success if the given operation is not in the main matcher body or
/// is used by a "binding" operation. On failure, emits an error.
static LogicalResult verifyHasBindingUse(Operation *op) {
// If the parent is not a pattern, there is nothing to do.
if (!isa<PatternOp>(op->getParentOp()))
return success();
if (hasBindingUseInMatcher(op, op->getBlock()))
if (hasBindingUse(op))
return success();
return op->emitOpError()
<< "expected a bindable (i.e. " << bindableContextStr
<< ") user when defined in the matcher body of a `pdl.pattern`";
return op->emitOpError(
"expected a bindable user when defined in the matcher body of a "
"`pdl.pattern`");
}

/// Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s)
/// connected to the given operation.
static void visit(Operation *op, DenseSet<Operation *> &visited) {
// If the parent is not a pattern, there is nothing to do.
if (!isa<PatternOp>(op->getParentOp()) || isa<RewriteOp>(op))
return;

// Ignore if already visited.
if (visited.contains(op))
return;

// Mark as visited.
visited.insert(op);

// Traverse the operands / parent.
TypeSwitch<Operation *>(op)
.Case<OperationOp>([&visited](auto operation) {
for (Value operand : operation.operands())
visit(operand.getDefiningOp(), visited);
})
.Case<ResultOp, ResultsOp>([&visited](auto result) {
visit(result.parent().getDefiningOp(), visited);
});

// Traverse the users.
for (Operation *user : op->getUsers())
visit(user, visited);
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -104,24 +119,20 @@ static LogicalResult verify(AttributeOp op) {
"`pdl.rewrite`");
if (attrValue && attrType)
return op.emitOpError("expected only one of [`type`, `value`] to be set");
return verifyHasBindingUseInMatcher(op);
return verifyHasBindingUse(op);
}

//===----------------------------------------------------------------------===//
// pdl::OperandOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(OperandOp op) {
return verifyHasBindingUseInMatcher(op);
}
static LogicalResult verify(OperandOp op) { return verifyHasBindingUse(op); }

//===----------------------------------------------------------------------===//
// pdl::OperandsOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(OperandsOp op) {
return verifyHasBindingUseInMatcher(op);
}
static LogicalResult verify(OperandsOp op) { return verifyHasBindingUse(op); }

//===----------------------------------------------------------------------===//
// pdl::OperationOp
Expand Down Expand Up @@ -237,7 +248,7 @@ static LogicalResult verify(OperationOp op) {
return failure();
}

return verifyHasBindingUseInMatcher(op, "`pdl.operation` or `pdl.rewrite`");
return verifyHasBindingUse(op);
}

bool OperationOp::hasTypeInference() {
Expand All @@ -256,15 +267,16 @@ bool OperationOp::hasTypeInference() {

static LogicalResult verify(PatternOp pattern) {
Region &body = pattern.body();
auto *term = body.front().getTerminator();
if (!isa<RewriteOp>(term)) {
Operation *term = body.front().getTerminator();
auto rewrite_op = dyn_cast<RewriteOp>(term);
if (!rewrite_op) {
return pattern.emitOpError("expected body to terminate with `pdl.rewrite`")
.attachNote(term->getLoc())
.append("see terminator defined here");
}

// Check that all values defined in the top-level pattern are referenced at
// least once in the source tree.
// Check that all values defined in the top-level pattern belong to the PDL
// dialect.
WalkResult result = body.walk([&](Operation *op) -> WalkResult {
if (!isa_and_nonnull<PDLDialect>(op->getDialect())) {
pattern
Expand All @@ -275,15 +287,61 @@ static LogicalResult verify(PatternOp pattern) {
}
return WalkResult::advance();
});
return failure(result.wasInterrupted());
if (result.wasInterrupted())
return failure();

// Check that there is at least one operation.
if (body.front().getOps<OperationOp>().empty())
return pattern.emitOpError(
"the pattern must contain at least one `pdl.operation`");

// Determine if the operations within the pdl.pattern form a connected
// component. This is determined by starting the search from the first
// operand/result/operation and visiting their users / parents / operands.
// We limit our attention to operations that have a user in pdl.rewrite,
// those that do not will be detected via other means (expected bindable
// user).
bool first = true;
DenseSet<Operation *> visited;
for (Operation &op : body.front()) {
// The following are the operations forming the connected component.
if (!isa<OperandOp, OperandsOp, ResultOp, ResultsOp, OperationOp>(op))
continue;

// Determine if the operation has a user in `pdl.rewrite`.
bool hasUserInRewrite = false;
for (Operation *user : op.getUsers()) {
Region *region = user->getParentRegion();
if (isa<RewriteOp>(user) ||
(region && isa<RewriteOp>(region->getParentOp()))) {
hasUserInRewrite = true;
break;
}
}

// If the operation does not have a user in `pdl.rewrite`, ignore it.
if (!hasUserInRewrite)
continue;

if (first) {
// For the first operation, invoke visit.
visit(&op, visited);
first = false;
} else if (!visited.count(&op)) {
// For the subsequent operations, check if already visited.
return pattern
.emitOpError("the operations must form a connected component")
.attachNote(op.getLoc())
.append("see a disconnected value / operation here");
}
}

return success();
}

void PatternOp::build(OpBuilder &builder, OperationState &state,
Optional<StringRef> rootKind, Optional<uint16_t> benefit,
Optional<StringRef> name) {
build(builder, state,
rootKind ? builder.getStringAttr(*rootKind) : StringAttr(),
builder.getI16IntegerAttr(benefit ? *benefit : 0),
Optional<uint16_t> benefit, Optional<StringRef> name) {
build(builder, state, builder.getI16IntegerAttr(benefit ? *benefit : 0),
name ? builder.getStringAttr(*name) : StringAttr());
state.regions[0]->emplaceBlock();
}
Expand All @@ -293,13 +351,6 @@ RewriteOp PatternOp::getRewriter() {
return cast<RewriteOp>(body().front().getTerminator());
}

/// Return the root operation kind that this pattern matches, or None if
/// there isn't a specific root.
Optional<StringRef> PatternOp::getRootKind() {
OperationOp rootOp = cast<OperationOp>(getRewriter().root().getDefiningOp());
return rootOp.name();
}

//===----------------------------------------------------------------------===//
// pdl::ReplaceOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -380,18 +431,13 @@ static LogicalResult verify(RewriteOp op) {
// pdl::TypeOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(TypeOp op) {
return verifyHasBindingUseInMatcher(
op, "`pdl.attribute`, `pdl.operand`, or `pdl.operation`");
}
static LogicalResult verify(TypeOp op) { return verifyHasBindingUse(op); }

//===----------------------------------------------------------------------===//
// pdl::TypesOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(TypesOp op) {
return verifyHasBindingUseInMatcher(op, "`pdl.operands`, or `pdl.operation`");
}
static LogicalResult verify(TypesOp op) { return verifyHasBindingUse(op); }

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
Expand Down
23 changes: 23 additions & 0 deletions mlir/lib/IR/PatternMatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,29 @@ void PDLValue::print(raw_ostream &os) const {
}
}

void PDLValue::print(raw_ostream &os, Kind kind) {
switch (kind) {
case Kind::Attribute:
os << "Attribute";
break;
case Kind::Operation:
os << "Operation";
break;
case Kind::Type:
os << "Type";
break;
case Kind::TypeRange:
os << "TypeRange";
break;
case Kind::Value:
os << "Value";
break;
case Kind::ValueRange:
os << "ValueRange";
break;
}
}

//===----------------------------------------------------------------------===//
// PDLPatternModule
//===----------------------------------------------------------------------===//
Expand Down
453 changes: 391 additions & 62 deletions mlir/lib/Rewrite/ByteCode.cpp

Large diffs are not rendered by default.

16 changes: 16 additions & 0 deletions mlir/lib/Rewrite/ByteCode.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class PDLByteCode;
/// entries. ByteCodeAddr refers to size of indices into the bytecode.
using ByteCodeField = uint16_t;
using ByteCodeAddr = uint32_t;
using OwningOpRange = llvm::OwningArrayRef<Operation *>;

//===----------------------------------------------------------------------===//
// PDLByteCodePattern
Expand Down Expand Up @@ -79,6 +80,12 @@ class PDLByteCodeMutableState {
/// of the bytecode.
std::vector<const void *> memory;

/// A mutable block of memory used during the matching and rewriting phase of
/// the bytecode to store ranges of operations. These are always stored by
/// owning references, because at no point in the execution of the byte code
/// we get an indexed range (view) of operations.
std::vector<OwningOpRange> opRangeMemory;

/// A mutable block of memory used during the matching and rewriting phase of
/// the bytecode to store ranges of types.
std::vector<TypeRange> typeRangeMemory;
Expand All @@ -93,6 +100,11 @@ class PDLByteCodeMutableState {
/// interpreter to provide a guaranteed lifetime.
std::vector<llvm::OwningArrayRef<Value>> allocatedValueRangeMemory;

/// The current index of ranges being iterated over for each level of nesting.
/// These are always maintained at 0 for the loops that are not active, so we
/// do not need to have a separate initialization phase for each loop.
std::vector<unsigned> loopIndex;

/// The up-to-date benefits of the patterns held by the bytecode. The order
/// of this array corresponds 1-1 with the array of patterns in `PDLByteCode`.
std::vector<PatternBenefit> currentPatternBenefits;
Expand Down Expand Up @@ -188,8 +200,12 @@ class PDLByteCode {
ByteCodeField maxValueMemoryIndex = 0;

/// The maximum number of different types of ranges.
ByteCodeField maxOpRangeCount = 0;
ByteCodeField maxTypeRangeCount = 0;
ByteCodeField maxValueRangeCount = 0;

/// The maximum number of nested loops.
ByteCodeField maxLoopLevel = 0;
};

} // end namespace detail
Expand Down
167 changes: 166 additions & 1 deletion mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ module @switch_result_count_at_least {
// -----

// CHECK-LABEL: module @predicate_ordering
module @predicate_ordering {
module @predicate_ordering {
// Check that the result is checked for null first, before applying the
// constraint. The null check is prevalent in both patterns, so should be
// prioritized first.
Expand All @@ -408,3 +408,168 @@ module @predicate_ordering {
pdl.rewrite %apply with "rewriter"
}
}


// -----

// CHECK-LABEL: module @multi_root
module @multi_root {
// Check the lowering of a simple two-root pattern.
// This checks that we correctly generate the pdl_interp.choose_op operation
// and tie the break between %root1 and %root2 in favor of %root1.

// CHECK: func @matcher(%[[ROOT1:.*]]: !pdl.operation)
// CHECK-DAG: %[[VAL1:.*]] = pdl_interp.get_operand 0 of %[[ROOT1]]
// CHECK-DAG: %[[OP1:.*]] = pdl_interp.get_defining_op of %[[VAL1]]
// CHECK-DAG: %[[OPS:.*]] = pdl_interp.get_users of %[[VAL1]] : !pdl.value
// CHECK-DAG: pdl_interp.foreach %[[ROOT2:.*]] : !pdl.operation in %[[OPS]]
// CHECK-DAG: %[[OPERANDS:.*]] = pdl_interp.get_operands 0 of %[[ROOT2]]
// CHECK-DAG: pdl_interp.are_equal %[[VAL1]], %[[OPERANDS]] : !pdl.value -> ^{{.*}}, ^[[CONTINUE:.*]]
// CHECK-DAG: pdl_interp.continue
// CHECK-DAG: %[[VAL2:.*]] = pdl_interp.get_operand 1 of %[[ROOT2]]
// CHECK-DAG: %[[OP2:.*]] = pdl_interp.get_defining_op of %[[VAL2]]
// CHECK-DAG: pdl_interp.is_not_null %[[OP1]] : !pdl.operation -> ^{{.*}}, ^[[CONTINUE]]
// CHECK-DAG: pdl_interp.is_not_null %[[OP2]] : !pdl.operation
// CHECK-DAG: pdl_interp.is_not_null %[[VAL1]] : !pdl.value
// CHECK-DAG: pdl_interp.is_not_null %[[VAL2]] : !pdl.value
// CHECK-DAG: pdl_interp.is_not_null %[[ROOT2]] : !pdl.operation
// CHECK-DAG: pdl_interp.are_equal %[[ROOT2]], %[[ROOT1]] : !pdl.operation -> ^[[CONTINUE]]

pdl.pattern @rewrite_multi_root : benefit(1) {
%input1 = pdl.operand
%input2 = pdl.operand
%type = pdl.type
%op1 = pdl.operation(%input1 : !pdl.value) -> (%type : !pdl.type)
%val1 = pdl.result 0 of %op1
%root1 = pdl.operation(%val1 : !pdl.value)
%op2 = pdl.operation(%input2 : !pdl.value) -> (%type : !pdl.type)
%val2 = pdl.result 0 of %op2
%root2 = pdl.operation(%val1, %val2 : !pdl.value, !pdl.value)
pdl.rewrite %root1 with "rewriter"(%root2 : !pdl.operation)
}
}


// -----

// CHECK-LABEL: module @overlapping_roots
module @overlapping_roots {
// Check the lowering of a degenerate two-root pattern, where one root
// is in the subtree rooted at another.

// CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
// CHECK-DAG: %[[VAL:.*]] = pdl_interp.get_operand 0 of %[[ROOT]]
// CHECK-DAG: %[[OP:.*]] = pdl_interp.get_defining_op of %[[VAL]]
// CHECK-DAG: %[[INPUT1:.*]] = pdl_interp.get_operand 0 of %[[OP]]
// CHECK-DAG: %[[INPUT2:.*]] = pdl_interp.get_operand 1 of %[[OP]]
// CHECK-DAG: pdl_interp.is_not_null %[[VAL]] : !pdl.value
// CHECK-DAG: pdl_interp.is_not_null %[[OP]] : !pdl.operation
// CHECK-DAG: pdl_interp.is_not_null %[[INPUT1]] : !pdl.value
// CHECK-DAG: pdl_interp.is_not_null %[[INPUT2]] : !pdl.value

pdl.pattern @rewrite_overlapping_roots : benefit(1) {
%input1 = pdl.operand
%input2 = pdl.operand
%type = pdl.type
%op = pdl.operation(%input1, %input2 : !pdl.value, !pdl.value) -> (%type : !pdl.type)
%val = pdl.result 0 of %op
%root = pdl.operation(%val : !pdl.value)
pdl.rewrite with "rewriter"(%root : !pdl.operation)
}
}

// -----

// CHECK-LABEL: module @force_overlapped_root
module @force_overlapped_root {
// Check the lowering of a degenerate two-root pattern, where one root
// is in the subtree rooted at another, and we are forced to use this
// root as the root of the search tree.

// CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
// CHECK-DAG: %[[VAL:.*]] = pdl_interp.get_result 0 of %[[ROOT]]
// CHECK-DAG: pdl_interp.check_operand_count of %[[ROOT]] is 2
// CHECK-DAG: pdl_interp.check_result_count of %[[ROOT]] is 1
// CHECK-DAG: %[[INPUT2:.*]] = pdl_interp.get_operand 1 of %[[ROOT]]
// CHECK-DAG: pdl_interp.is_not_null %[[INPUT2]] : !pdl.value
// CHECK-DAG: %[[INPUT1:.*]] = pdl_interp.get_operand 0 of %[[ROOT]]
// CHECK-DAG: pdl_interp.is_not_null %[[INPUT1]] : !pdl.value
// CHECK-DAG: %[[OPS:.*]] = pdl_interp.get_users of %[[VAL]] : !pdl.value
// CHECK-DAG: pdl_interp.foreach %[[OP:.*]] : !pdl.operation in %[[OPS]]
// CHECK-DAG: pdl_interp.is_not_null %[[OP]] : !pdl.operation
// CHECK-DAG: pdl_interp.check_operand_count of %[[OP]] is 1

pdl.pattern @rewrite_forced_overlapped_root : benefit(1) {
%input1 = pdl.operand
%input2 = pdl.operand
%type = pdl.type
%root = pdl.operation(%input1, %input2 : !pdl.value, !pdl.value) -> (%type : !pdl.type)
%val = pdl.result 0 of %root
%op = pdl.operation(%val : !pdl.value)
pdl.rewrite %root with "rewriter"(%op : !pdl.operation)
}
}

// -----

// CHECK-LABEL: module @variadic_results_all
module @variadic_results_all {
// Check the correct lowering when using all results of an operation
// and passing it them as operands to another operation.

// CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
// CHECK-DAG: pdl_interp.check_operand_count of %[[ROOT]] is 0
// CHECK-DAG: %[[VALS:.*]] = pdl_interp.get_results of %[[ROOT]] : !pdl.range<value>
// CHECK-DAG: %[[VAL0:.*]] = pdl_interp.extract 0 of %[[VALS]]
// CHECK-DAG: %[[OPS:.*]] = pdl_interp.get_users of %[[VAL0]] : !pdl.value
// CHECK-DAG: pdl_interp.foreach %[[OP:.*]] : !pdl.operation in %[[OPS]]
// CHECK-DAG: %[[OPERANDS:.*]] = pdl_interp.get_operands of %[[OP]]
// CHECK-DAG pdl_interp.are_equal %[[VALS]], %[[OPERANDS]] -> ^{{.*}}, ^[[CONTINUE:.*]]
// CHECK-DAG: pdl_interp.is_not_null %[[OP]]
// CHECK-DAG: pdl_interp.check_result_count of %[[OP]] is 0
pdl.pattern @variadic_results_all : benefit(1) {
%types = pdl.types
%root = pdl.operation -> (%types : !pdl.range<type>)
%vals = pdl.results of %root
%op = pdl.operation(%vals : !pdl.range<value>)
pdl.rewrite %root with "rewriter"(%op : !pdl.operation)
}
}

// -----

// CHECK-LABEL: module @variadic_results_at
module @variadic_results_at {
// Check the correct lowering when using selected results of an operation
// and passing it them as an operand to another operation.

// CHECK: func @matcher(%[[ROOT1:.*]]: !pdl.operation)
// CHECK-DAG: %[[VALS:.*]] = pdl_interp.get_operands 0 of %[[ROOT1]] : !pdl.range<value>
// CHECK-DAG: %[[OP:.*]] = pdl_interp.get_defining_op of %[[VALS]] : !pdl.range<value>
// CHECK-DAG: pdl_interp.is_not_null %[[OP]] : !pdl.operation
// CHECK-DAG: pdl_interp.check_operand_count of %[[ROOT1]] is at_least 1
// CHECK-DAG: pdl_interp.check_result_count of %[[ROOT1]] is 0
// CHECK-DAG: %[[VAL:.*]] = pdl_interp.get_operands 1 of %[[ROOT1]] : !pdl.value
// CHECK-DAG: pdl_interp.is_not_null %[[VAL]]
// CHECK-DAG: pdl_interp.is_not_null %[[VALS]]
// CHECK-DAG: %[[VAL0:.*]] = pdl_interp.extract 0 of %[[VALS]]
// CHECK-DAG: %[[ROOTS2:.*]] = pdl_interp.get_users of %[[VAL0]] : !pdl.value
// CHECK-DAG: pdl_interp.foreach %[[ROOT2:.*]] : !pdl.operation in %[[ROOTS2]] {
// CHECK-DAG: %[[OPERANDS:.*]] = pdl_interp.get_operands 1 of %[[ROOT2]]
// CHECK-DAG: pdl_interp.are_equal %[[VALS]], %[[OPERANDS]] : !pdl.range<value> -> ^{{.*}}, ^[[CONTINUE:.*]]
// CHECK-DAG: pdl_interp.is_not_null %[[ROOT2]]
// CHECK-DAG: pdl_interp.check_operand_count of %[[ROOT2]] is at_least 1
// CHECK-DAG: pdl_interp.check_result_count of %[[ROOT2]] is 0
// CHECK-DAG: pdl_interp.check_operand_count of %[[OP]] is 0
// CHECK-DAG: pdl_interp.check_result_count of %[[OP]] is at_least 1
pdl.pattern @variadic_results_at : benefit(1) {
%type = pdl.type
%types = pdl.types
%val = pdl.operand
%op = pdl.operation -> (%types, %type : !pdl.range<type>, !pdl.type)
%vals = pdl.results 0 of %op -> !pdl.range<value>
%root1 = pdl.operation(%vals, %val : !pdl.range<value>, !pdl.value)
%root2 = pdl.operation(%val, %vals : !pdl.value, !pdl.range<value>)
pdl.rewrite with "rewriter"(%root1, %root2 : !pdl.operation, !pdl.operation)
}
}
61 changes: 51 additions & 10 deletions mlir/test/Dialect/PDL/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ pdl.pattern : benefit(1) {
// -----

pdl.pattern : benefit(1) {
// expected-error@below {{expected a bindable (i.e. `pdl.operation`) user when defined in the matcher body of a `pdl.pattern`}}
// expected-error@below {{expected a bindable user when defined in the matcher body of a `pdl.pattern`}}
%unused = pdl.attribute

%op = pdl.operation "foo.op"
Expand All @@ -81,7 +81,7 @@ pdl.pattern : benefit(1) {
//===----------------------------------------------------------------------===//

pdl.pattern : benefit(1) {
// expected-error@below {{expected a bindable (i.e. `pdl.operation`) user when defined in the matcher body of a `pdl.pattern`}}
// expected-error@below {{expected a bindable user when defined in the matcher body of a `pdl.pattern`}}
%unused = pdl.operand

%op = pdl.operation "foo.op"
Expand All @@ -95,7 +95,7 @@ pdl.pattern : benefit(1) {
//===----------------------------------------------------------------------===//

pdl.pattern : benefit(1) {
// expected-error@below {{expected a bindable (i.e. `pdl.operation`) user when defined in the matcher body of a `pdl.pattern`}}
// expected-error@below {{expected a bindable user when defined in the matcher body of a `pdl.pattern`}}
%unused = pdl.operands

%op = pdl.operation "foo.op"
Expand Down Expand Up @@ -143,7 +143,7 @@ pdl.pattern : benefit(1) {
// -----

pdl.pattern : benefit(1) {
// expected-error@below {{expected a bindable (i.e. `pdl.operation` or `pdl.rewrite`) user when defined in the matcher body of a `pdl.pattern`}}
// expected-error@below {{expected a bindable user when defined in the matcher body of a `pdl.pattern`}}
%unused = pdl.operation "foo.op"

%op = pdl.operation "foo.op"
Expand All @@ -164,6 +164,12 @@ pdl.pattern : benefit(1) {

// -----

// expected-error@below {{the pattern must contain at least one `pdl.operation`}}
pdl.pattern : benefit(1) {
pdl.rewrite with "foo"
}

// -----
// expected-error@below {{expected only `pdl` operations within the pattern body}}
pdl.pattern : benefit(1) {
// expected-note@below {{see non-`pdl` operation defined here}}
Expand All @@ -173,6 +179,32 @@ pdl.pattern : benefit(1) {
pdl.rewrite %root with "foo"
}

// -----
// expected-error@below {{the operations must form a connected component}}
pdl.pattern : benefit(1) {
%op1 = pdl.operation "foo.op"
%op2 = pdl.operation "bar.op"
// expected-note@below {{see a disconnected value / operation here}}
%val = pdl.result 0 of %op2
pdl.rewrite %op1 with "foo"(%val : !pdl.value)
}

// -----
// expected-error@below {{the operations must form a connected component}}
pdl.pattern : benefit(1) {
%type = pdl.type
%op1 = pdl.operation "foo.op" -> (%type : !pdl.type)
%val = pdl.result 0 of %op1
%op2 = pdl.operation "bar.op"(%val : !pdl.value)
// expected-note@below {{see a disconnected value / operation here}}
%op3 = pdl.operation "baz.op"
pdl.rewrite {
pdl.erase %op1
pdl.erase %op2
pdl.erase %op3
}
}

// -----

pdl.pattern : benefit(1) {
Expand Down Expand Up @@ -212,7 +244,9 @@ pdl.pattern : benefit(1) {
%op = pdl.operation "foo.op"

// expected-error@below {{expected rewrite region to be non-empty if external name is not specified}}
"pdl.rewrite"(%op) ({}) : (!pdl.operation) -> ()
"pdl.rewrite"(%op) ({}) {
operand_segment_sizes = dense<[1,0]> : vector<2xi32>
} : (!pdl.operation) -> ()
}

// -----
Expand All @@ -223,7 +257,9 @@ pdl.pattern : benefit(1) {
// expected-error@below {{expected no external arguments when the rewrite is specified inline}}
"pdl.rewrite"(%op, %op) ({
^bb1:
}) : (!pdl.operation, !pdl.operation) -> ()
}) {
operand_segment_sizes = dense<1> : vector<2xi32>
}: (!pdl.operation, !pdl.operation) -> ()
}

// -----
Expand All @@ -234,7 +270,9 @@ pdl.pattern : benefit(1) {
// expected-error@below {{expected no external constant parameters when the rewrite is specified inline}}
"pdl.rewrite"(%op) ({
^bb1:
}) {externalConstParams = []} : (!pdl.operation) -> ()
}) {
operand_segment_sizes = dense<[1,0]> : vector<2xi32>,
externalConstParams = []} : (!pdl.operation) -> ()
}

// -----
Expand All @@ -245,7 +283,10 @@ pdl.pattern : benefit(1) {
// expected-error@below {{expected rewrite region to be empty when rewrite is external}}
"pdl.rewrite"(%op) ({
^bb1:
}) {name = "foo"} : (!pdl.operation) -> ()
}) {
name = "foo",
operand_segment_sizes = dense<[1,0]> : vector<2xi32>
} : (!pdl.operation) -> ()
}

// -----
Expand All @@ -255,7 +296,7 @@ pdl.pattern : benefit(1) {
//===----------------------------------------------------------------------===//

pdl.pattern : benefit(1) {
// expected-error@below {{expected a bindable (i.e. `pdl.attribute`, `pdl.operand`, or `pdl.operation`) user when defined in the matcher body of a `pdl.pattern`}}
// expected-error@below {{expected a bindable user when defined in the matcher body of a `pdl.pattern`}}
%unused = pdl.type

%op = pdl.operation "foo.op"
Expand All @@ -269,7 +310,7 @@ pdl.pattern : benefit(1) {
//===----------------------------------------------------------------------===//

pdl.pattern : benefit(1) {
// expected-error@below {{expected a bindable (i.e. `pdl.operands`, or `pdl.operation`) user when defined in the matcher body of a `pdl.pattern`}}
// expected-error@below {{expected a bindable user when defined in the matcher body of a `pdl.pattern`}}
%unused = pdl.types

%op = pdl.operation "foo.op"
Expand Down
30 changes: 30 additions & 0 deletions mlir/test/Dialect/PDL/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,36 @@ pdl.pattern @rewrite_with_args_and_params : benefit(1) {

// -----

pdl.pattern @rewrite_multi_root_optimal : benefit(2) {
%input1 = pdl.operand
%input2 = pdl.operand
%type = pdl.type
%op1 = pdl.operation(%input1 : !pdl.value) -> (%type : !pdl.type)
%val1 = pdl.result 0 of %op1
%root1 = pdl.operation(%val1 : !pdl.value)
%op2 = pdl.operation(%input2 : !pdl.value) -> (%type : !pdl.type)
%val2 = pdl.result 0 of %op2
%root2 = pdl.operation(%val1, %val2 : !pdl.value, !pdl.value)
pdl.rewrite with "rewriter"["I am param"](%root1, %root2 : !pdl.operation, !pdl.operation)
}

// -----

pdl.pattern @rewrite_multi_root_forced : benefit(2) {
%input1 = pdl.operand
%input2 = pdl.operand
%type = pdl.type
%op1 = pdl.operation(%input1 : !pdl.value) -> (%type : !pdl.type)
%val1 = pdl.result 0 of %op1
%root1 = pdl.operation(%val1 : !pdl.value)
%op2 = pdl.operation(%input2 : !pdl.value) -> (%type : !pdl.type)
%val2 = pdl.result 0 of %op2
%root2 = pdl.operation(%val1, %val2 : !pdl.value, !pdl.value)
pdl.rewrite %root1 with "rewriter"["I am param"](%root2 : !pdl.operation)
}

// -----

// Check that the result type of an operation within a rewrite can be inferred
// from a pdl.replace.
pdl.pattern @infer_type_from_operation_replace : benefit(1) {
Expand Down
271 changes: 271 additions & 0 deletions mlir/test/Rewrite/pdl-bytecode.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,12 @@ module @ir attributes { test.check_types_1 } {

// -----

//===----------------------------------------------------------------------===//
// pdl_interp::ContinueOp
//===----------------------------------------------------------------------===//

// Fully tested within the tests for other operations.

//===----------------------------------------------------------------------===//
// pdl_interp::CreateAttributeOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -576,12 +582,277 @@ module @ir attributes { test.create_type_1 } {

// Fully tested within the tests for other operations.

//===----------------------------------------------------------------------===//
// pdl_interp::ExtractOp
//===----------------------------------------------------------------------===//

module @patterns {
func @matcher(%root : !pdl.operation) {
%val = pdl_interp.get_result 0 of %root
%ops = pdl_interp.get_users of %val : !pdl.value
%op1 = pdl_interp.extract 1 of %ops : !pdl.operation
pdl_interp.is_not_null %op1 : !pdl.operation -> ^success, ^end
^success:
pdl_interp.record_match @rewriters::@success(%op1 : !pdl.operation) : benefit(1), loc([%root]) -> ^end
^end:
pdl_interp.finalize
}

module @rewriters {
func @success(%matched : !pdl.operation) {
%op = pdl_interp.create_operation "test.success"
pdl_interp.erase %matched
pdl_interp.finalize
}
}
}

// CHECK-LABEL: test.extract_op
// CHECK: "test.success"
// CHECK: %[[OPERAND:.*]] = "test.op"
// CHECK: "test.op"(%[[OPERAND]])
module @ir attributes { test.extract_op } {
%operand = "test.op"() : () -> i32
"test.op"(%operand) : (i32) -> (i32)
"test.op"(%operand, %operand) : (i32, i32) -> (i32)
}

// -----

module @patterns {
func @matcher(%root : !pdl.operation) {
%vals = pdl_interp.get_results of %root : !pdl.range<value>
%types = pdl_interp.get_value_type of %vals : !pdl.range<type>
%type1 = pdl_interp.extract 1 of %types : !pdl.type
pdl_interp.is_not_null %type1 : !pdl.type -> ^success, ^end
^success:
pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
^end:
pdl_interp.finalize
}

module @rewriters {
func @success(%matched : !pdl.operation) {
%op = pdl_interp.create_operation "test.success"
pdl_interp.erase %matched
pdl_interp.finalize
}
}
}

// CHECK-LABEL: test.extract_type
// CHECK: %[[OPERAND:.*]] = "test.op"
// CHECK: "test.success"
// CHECK: "test.op"(%[[OPERAND]])
module @ir attributes { test.extract_type } {
%operand = "test.op"() : () -> i32
"test.op"(%operand) : (i32) -> (i32, i32)
"test.op"(%operand) : (i32) -> (i32)
}

// -----

module @patterns {
func @matcher(%root : !pdl.operation) {
%vals = pdl_interp.get_results of %root : !pdl.range<value>
%val1 = pdl_interp.extract 1 of %vals : !pdl.value
pdl_interp.is_not_null %val1 : !pdl.value -> ^success, ^end
^success:
pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
^end:
pdl_interp.finalize
}

module @rewriters {
func @success(%matched : !pdl.operation) {
%op = pdl_interp.create_operation "test.success"
pdl_interp.erase %matched
pdl_interp.finalize
}
}
}

// CHECK-LABEL: test.extract_value
// CHECK: %[[OPERAND:.*]] = "test.op"
// CHECK: "test.success"
// CHECK: "test.op"(%[[OPERAND]])
module @ir attributes { test.extract_value } {
%operand = "test.op"() : () -> i32
"test.op"(%operand) : (i32) -> (i32, i32)
"test.op"(%operand) : (i32) -> (i32)
}

// -----

//===----------------------------------------------------------------------===//
// pdl_interp::FinalizeOp
//===----------------------------------------------------------------------===//

// Fully tested within the tests for other operations.

//===----------------------------------------------------------------------===//
// pdl_interp::ForEachOp
//===----------------------------------------------------------------------===//

module @patterns {
func @matcher(%root : !pdl.operation) {
%val1 = pdl_interp.get_result 0 of %root
%ops1 = pdl_interp.get_users of %val1 : !pdl.value
pdl_interp.foreach %op1 : !pdl.operation in %ops1 {
%val2 = pdl_interp.get_result 0 of %op1
%ops2 = pdl_interp.get_users of %val2 : !pdl.value
pdl_interp.foreach %op2 : !pdl.operation in %ops2 {
pdl_interp.record_match @rewriters::@success(%op2 : !pdl.operation) : benefit(1), loc([%root]) -> ^cont
^cont:
pdl_interp.continue
} -> ^cont
^cont:
pdl_interp.continue
} -> ^end
^end:
pdl_interp.finalize
}

module @rewriters {
func @success(%matched : !pdl.operation) {
%op = pdl_interp.create_operation "test.success"
pdl_interp.erase %matched
pdl_interp.finalize
}
}
}

// CHECK-LABEL: test.foreach
// CHECK: "test.success"
// CHECK: "test.success"
// CHECK: "test.success"
// CHECK: "test.success"
// CHECK: %[[ROOT:.*]] = "test.op"
// CHECK: %[[VALA:.*]] = "test.op"(%[[ROOT]])
// CHECK: %[[VALB:.*]] = "test.op"(%[[ROOT]])
module @ir attributes { test.foreach } {
%root = "test.op"() : () -> i32
%valA = "test.op"(%root) : (i32) -> (i32)
"test.op"(%valA) : (i32) -> (i32)
"test.op"(%valA) : (i32) -> (i32)
%valB = "test.op"(%root) : (i32) -> (i32)
"test.op"(%valB) : (i32) -> (i32)
"test.op"(%valB) : (i32) -> (i32)
}

// -----

//===----------------------------------------------------------------------===//
// pdl_interp::GetUsersOp
//===----------------------------------------------------------------------===//

module @patterns {
func @matcher(%root : !pdl.operation) {
%val = pdl_interp.get_result 0 of %root
%ops = pdl_interp.get_users of %val : !pdl.value
pdl_interp.foreach %op : !pdl.operation in %ops {
pdl_interp.record_match @rewriters::@success(%op : !pdl.operation) : benefit(1), loc([%root]) -> ^cont
^cont:
pdl_interp.continue
} -> ^end
^end:
pdl_interp.finalize
}

module @rewriters {
func @success(%matched : !pdl.operation) {
%op = pdl_interp.create_operation "test.success"
pdl_interp.erase %matched
pdl_interp.finalize
}
}
}

// CHECK-LABEL: test.get_users_of_value
// CHECK: "test.success"
// CHECK: "test.success"
// CHECK: %[[OPERAND:.*]] = "test.op"
module @ir attributes { test.get_users_of_value } {
%operand = "test.op"() : () -> i32
"test.op"(%operand) : (i32) -> (i32)
"test.op"(%operand, %operand) : (i32, i32) -> (i32)
}

// -----

module @patterns {
func @matcher(%root : !pdl.operation) {
pdl_interp.check_result_count of %root is at_least 2 -> ^next, ^end
^next:
%vals = pdl_interp.get_results of %root : !pdl.range<value>
%ops = pdl_interp.get_users of %vals : !pdl.range<value>
pdl_interp.foreach %op : !pdl.operation in %ops {
pdl_interp.record_match @rewriters::@success(%op : !pdl.operation) : benefit(1), loc([%root]) -> ^cont
^cont:
pdl_interp.continue
} -> ^end
^end:
pdl_interp.finalize
}

module @rewriters {
func @success(%matched : !pdl.operation) {
%op = pdl_interp.create_operation "test.success"
pdl_interp.erase %matched
pdl_interp.finalize
}
}
}

// CHECK-LABEL: test.get_all_users_of_range
// CHECK: "test.success"
// CHECK: "test.success"
// CHECK: %[[OPERANDS:.*]]:2 = "test.op"
module @ir attributes { test.get_all_users_of_range } {
%operands:2 = "test.op"() : () -> (i32, i32)
"test.op"(%operands#0) : (i32) -> (i32)
"test.op"(%operands#1) : (i32) -> (i32)
}

// -----

module @patterns {
func @matcher(%root : !pdl.operation) {
pdl_interp.check_result_count of %root is at_least 2 -> ^next, ^end
^next:
%vals = pdl_interp.get_results of %root : !pdl.range<value>
%val = pdl_interp.extract 0 of %vals : !pdl.value
%ops = pdl_interp.get_users of %val : !pdl.value
pdl_interp.foreach %op : !pdl.operation in %ops {
pdl_interp.record_match @rewriters::@success(%op : !pdl.operation) : benefit(1), loc([%root]) -> ^cont
^cont:
pdl_interp.continue
} -> ^end
^end:
pdl_interp.finalize
}

module @rewriters {
func @success(%matched : !pdl.operation) {
%op = pdl_interp.create_operation "test.success"
pdl_interp.erase %matched
pdl_interp.finalize
}
}
}

// CHECK-LABEL: test.get_first_users_of_range
// CHECK: "test.success"
// CHECK: %[[OPERANDS:.*]]:2 = "test.op"
// CHECK: "test.op"
module @ir attributes { test.get_first_users_of_range } {
%operands:2 = "test.op"() : () -> (i32, i32)
"test.op"(%operands#0) : (i32) -> (i32)
"test.op"(%operands#1) : (i32) -> (i32)
}

// -----

//===----------------------------------------------------------------------===//
// pdl_interp::GetAttributeOp
//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions mlir/unittests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ function(add_mlir_unittest test_dirname)
endfunction()

add_subdirectory(Analysis)
add_subdirectory(Conversion)
add_subdirectory(Dialect)
add_subdirectory(ExecutionEngine)
add_subdirectory(Interfaces)
Expand Down
1 change: 1 addition & 0 deletions mlir/unittests/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
add_subdirectory(PDLToPDLInterp)
8 changes: 8 additions & 0 deletions mlir/unittests/Conversion/PDLToPDLInterp/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
add_mlir_unittest(MLIRPDLToPDLInterpTests
RootOrderingTest.cpp
)
target_link_libraries(MLIRPDLToPDLInterpTests
PRIVATE
MLIRStandard
MLIRPDLToPDLInterp
)
106 changes: 106 additions & 0 deletions mlir/unittests/Conversion/PDLToPDLInterp/RootOrderingTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
//===- RootOrderingTest.cpp - unit tests for optimal branching ------------===//
//
// Part of the LLVM Project, under the Apache License v[1].0 with LLVM
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "../lib/Conversion/PDLToPDLInterp/RootOrdering.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "gtest/gtest.h"

using namespace mlir;
using namespace mlir::pdl_to_pdl_interp;

namespace {

//===----------------------------------------------------------------------===//
// Test Fixture
//===----------------------------------------------------------------------===//

/// The test fixture for constructing root ordering tests and verifying results.
/// This fixture constructs the test values v. The test populates the graph
/// with the desired costs and then calls check(), passing the expeted optimal
/// cost and the list of edges in the preorder traversal of the optimal
/// branching.
class RootOrderingTest : public ::testing::Test {
protected:
RootOrderingTest() {
context.loadDialect<StandardOpsDialect>();
createValues();
}

/// Creates the test values.
void createValues() {
OpBuilder builder(&context);
for (int i = 0; i < 4; ++i)
v[i] = builder.create<ConstantOp>(builder.getUnknownLoc(),
builder.getI32IntegerAttr(i));
}

/// Checks that optimal branching on graph has the given cost and
/// its preorder traversal results in the specified edges.
void check(unsigned cost, OptimalBranching::EdgeList edges) {
OptimalBranching opt(graph, v[0]);
EXPECT_EQ(opt.solve(), cost);
EXPECT_EQ(opt.preOrderTraversal({v, v + edges.size()}), edges);
for (std::pair<Value, Value> edge : edges)
EXPECT_EQ(opt.getRootOrderingParents().lookup(edge.first), edge.second);
}

protected:
/// The context for creating the values.
MLIRContext context;

/// Values used in the graph definition. We always use leading `n` values.
Value v[4];

/// The graph being tested on.
RootOrderingGraph graph;
};

//===----------------------------------------------------------------------===//
// Simple 3-node graphs
//===----------------------------------------------------------------------===//

TEST_F(RootOrderingTest, simpleA) {
graph[v[1]][v[0]].cost = {1, 10};
graph[v[2]][v[0]].cost = {1, 11};
graph[v[1]][v[2]].cost = {2, 12};
graph[v[2]][v[1]].cost = {2, 13};
check(2, {{v[0], {}}, {v[1], v[0]}, {v[2], v[0]}});
}

TEST_F(RootOrderingTest, simpleB) {
graph[v[1]][v[0]].cost = {1, 10};
graph[v[2]][v[0]].cost = {2, 11};
graph[v[1]][v[2]].cost = {1, 12};
graph[v[2]][v[1]].cost = {1, 13};
check(2, {{v[0], {}}, {v[1], v[0]}, {v[2], v[1]}});
}

TEST_F(RootOrderingTest, simpleC) {
graph[v[1]][v[0]].cost = {2, 10};
graph[v[2]][v[0]].cost = {2, 11};
graph[v[1]][v[2]].cost = {1, 12};
graph[v[2]][v[1]].cost = {1, 13};
check(3, {{v[0], {}}, {v[1], v[0]}, {v[2], v[1]}});
}

//===----------------------------------------------------------------------===//
// Graph for testing contraction
//===----------------------------------------------------------------------===//

TEST_F(RootOrderingTest, contraction) {
graph[v[1]][v[0]].cost = {10, 0};
graph[v[2]][v[0]].cost = {5, 0};
graph[v[2]][v[1]].cost = {1, 0};
graph[v[3]][v[2]].cost = {2, 0};
graph[v[1]][v[3]].cost = {3, 0};
check(10, {{v[0], {}}, {v[2], v[0]}, {v[3], v[2]}, {v[1], v[3]}});
}

} // end namespace