Skip to content

Commit

Permalink
[mlir][PDL] Add support for variadic operands and results in the PDL …
Browse files Browse the repository at this point in the history
…Interpreter

This revision extends the PDL Interpreter dialect to add support for variadic operands and results, with ranges of these values represented via the recently added !pdl.range type. To support this extension, three new operations have been added that closely match the single variant:
* pdl_interp.check_types : Compare a range of types with a known range.
* pdl_interp.create_types : Create a constant range of types.
* pdl_interp.get_operands : Get a range of operands from an operation.
* pdl_interp.get_results : Get a range of results from an operation.
* pdl_interp.switch_types : Switch on a range of types.

This revision handles adding support in the interpreter dialect and the conversion from PDL to PDLInterp. Support for variadic operands and results in the bytecode will be added in a followup revision.

Differential Revision: https://reviews.llvm.org/D95722
  • Loading branch information
River707 committed Mar 16, 2021
1 parent 1eb6994 commit 3a833a0
Show file tree
Hide file tree
Showing 15 changed files with 1,105 additions and 388 deletions.
315 changes: 267 additions & 48 deletions mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion mlir/include/mlir/IR/OpBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -1439,7 +1439,7 @@ class TypedArrayAttrBase<Attr element, string summary>: ArrayAttrBase<
CPred<"$_self.isa<::mlir::ArrayAttr>()">,
// Guarantee all elements satisfy the constraints from `element`
Concat<"::llvm::all_of($_self.cast<::mlir::ArrayAttr>(), "
"[](::mlir::Attribute attr) { return ",
"[&](::mlir::Attribute attr) { return ",
SubstLeaves<"$_self", "attr", element.predicate>,
"; })">]>,
summary> {
Expand Down
293 changes: 206 additions & 87 deletions mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp

Large diffs are not rendered by default.

25 changes: 11 additions & 14 deletions mlir/lib/Conversion/PDLToPDLInterp/Predicate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@ using namespace mlir::pdl_to_pdl_interp;

Position::~Position() {}

/// Returns the depth of the first ancestor operation position.
unsigned Position::getOperationDepth() const {
if (const auto *operationPos = dyn_cast<OperationPosition>(this))
return operationPos->getDepth();
return parent->getOperationDepth();
}

//===----------------------------------------------------------------------===//
// AttributePosition

Expand All @@ -32,18 +39,8 @@ OperandPosition::OperandPosition(const KeyTy &key) : Base(key) {
}

//===----------------------------------------------------------------------===//
// OperationPosition

OperationPosition *OperationPosition::get(StorageUniquer &uniquer,
ArrayRef<unsigned> index) {
assert(!index.empty() && "expected at least two indices");

// Set the parent position if this isn't the root.
Position *parent = nullptr;
if (index.size() > 1) {
auto *node = OperationPosition::get(uniquer, index.drop_back());
parent = OperandPosition::get(uniquer, std::make_pair(node, index.back()));
}
return uniquer.get<OperationPosition>(
[parent](OperationPosition *node) { node->parent = parent; }, index);
// OperandGroupPosition

OperandGroupPosition::OperandGroupPosition(const KeyTy &key) : Base(key) {
parent = std::get<0>(key);
}
171 changes: 122 additions & 49 deletions mlir/lib/Conversion/PDLToPDLInterp/Predicate.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,20 @@ enum Kind : unsigned {
/// Positions, ordered by decreasing priority.
OperationPos,
OperandPos,
OperandGroupPos,
AttributePos,
ResultPos,
ResultGroupPos,
TypePos,

// Questions, ordered by dependency and decreasing priority.
IsNotNullQuestion,
OperationNameQuestion,
TypeQuestion,
AttributeQuestion,
OperandCountAtLeastQuestion,
OperandCountQuestion,
ResultCountAtLeastQuestion,
ResultCountQuestion,
EqualToQuestion,
ConstraintQuestion,
Expand Down Expand Up @@ -129,21 +133,15 @@ struct OperationPosition;
/// predicates, and assists generating bytecode and memory management.
///
/// Operation positions form the base of other positions, which are formed
/// relative to a parent operation, e.g. OperandPosition<[0] -> 1>. Operations
/// are indexed by child index: [0, 1, 2] refers to the 3rd child of the 2nd
/// child of the root operation.
///
/// Positions are linked to their parent position, which describes how to obtain
/// a positional value. As a concrete example, getting OperationPosition<[0, 1]>
/// would be `root->getOperand(1)->getDefiningOp()`, so its parent is
/// OperandPosition<[0] -> 1>, whose parent is OperationPosition<[0]>.
/// relative to a parent operation. Operations are anchored at Operand nodes,
/// except for the root operation which is parentless.
class Position : public StorageUniquer::BaseStorage {
public:
explicit Position(Predicates::Kind kind) : kind(kind) {}
virtual ~Position();

/// Returns the base node position. This is an array of indices.
virtual ArrayRef<unsigned> getIndex() const = 0;
/// Returns the depth of the first ancestor operation position.
unsigned getOperationDepth() const;

/// Returns the parent position. The root operation position has no parent.
Position *getParent() const { return parent; }
Expand All @@ -170,9 +168,6 @@ struct AttributePosition
Predicates::AttributePos> {
explicit AttributePosition(const KeyTy &key);

/// Returns the index of this position.
ArrayRef<unsigned> getIndex() const final { return parent->getIndex(); }

/// Returns the attribute name of this position.
Identifier getName() const { return key.second; }
};
Expand All @@ -187,42 +182,61 @@ struct OperandPosition
Predicates::OperandPos> {
explicit OperandPosition(const KeyTy &key);

/// Returns the index of this position.
ArrayRef<unsigned> getIndex() const final { return parent->getIndex(); }

/// Returns the operand number of this position.
unsigned getOperandNumber() const { return key.second; }
};

//===----------------------------------------------------------------------===//
// OperandGroupPosition

/// A position describing an operand group of an operation.
struct OperandGroupPosition
: public PredicateBase<
OperandGroupPosition, Position,
std::tuple<OperationPosition *, Optional<unsigned>, bool>,
Predicates::OperandGroupPos> {
explicit OperandGroupPosition(const KeyTy &key);

/// Returns a hash suitable for the given keytype.
static llvm::hash_code hashKey(const KeyTy &key) {
return llvm::hash_value(key);
}

/// Returns the group number of this position. If None, this group refers to
/// all operands.
Optional<unsigned> getOperandGroupNumber() const { return std::get<1>(key); }

/// Returns if the operand group has unknown size. If false, the operand group
/// has at max one element.
bool isVariadic() const { return std::get<2>(key); }
};

//===----------------------------------------------------------------------===//
// OperationPosition

/// An operation position describes an operation node in the IR. Other position
/// kinds are formed with respect to an operation position.
struct OperationPosition
: public PredicateBase<OperationPosition, Position, ArrayRef<unsigned>,
Predicates::OperationPos> {
using Base::Base;
struct OperationPosition : public PredicateBase<OperationPosition, Position,
std::pair<Position *, unsigned>,
Predicates::OperationPos> {
explicit OperationPosition(const KeyTy &key) : Base(key) {
parent = key.first;
}

/// Gets the root position, which is always [0].
/// Gets the root position.
static OperationPosition *getRoot(StorageUniquer &uniquer) {
return get(uniquer, ArrayRef<unsigned>(0));
return Base::get(uniquer, nullptr, 0);
}
/// Gets a node position for the given index.
static OperationPosition *get(StorageUniquer &uniquer,
ArrayRef<unsigned> index);

/// Constructs an instance with the given storage allocator.
static OperationPosition *construct(StorageUniquer::StorageAllocator &alloc,
ArrayRef<unsigned> key) {
return Base::construct(alloc, alloc.copyInto(key));
/// Gets an operation position with the given parent.
static OperationPosition *get(StorageUniquer &uniquer, Position *parent) {
return Base::get(uniquer, parent, parent->getOperationDepth() + 1);
}

/// Returns the index of this position.
ArrayRef<unsigned> getIndex() const final { return key; }
/// Returns the depth of this position.
unsigned getDepth() const { return key.second; }

/// Returns if this operation position corresponds to the root.
bool isRoot() const { return key.size() == 1 && key[0] == 0; }
bool isRoot() const { return getDepth() == 0; }
};

//===----------------------------------------------------------------------===//
Expand All @@ -235,13 +249,37 @@ struct ResultPosition
Predicates::ResultPos> {
explicit ResultPosition(const KeyTy &key) : Base(key) { parent = key.first; }

/// Returns the index of this position.
ArrayRef<unsigned> getIndex() const final { return key.first->getIndex(); }

/// Returns the result number of this position.
unsigned getResultNumber() const { return key.second; }
};

//===----------------------------------------------------------------------===//
// ResultGroupPosition

/// A position describing a result group of an operation.
struct ResultGroupPosition
: public PredicateBase<
ResultGroupPosition, Position,
std::tuple<OperationPosition *, Optional<unsigned>, bool>,
Predicates::ResultGroupPos> {
explicit ResultGroupPosition(const KeyTy &key) : Base(key) {
parent = std::get<0>(key);
}

/// Returns a hash suitable for the given keytype.
static llvm::hash_code hashKey(const KeyTy &key) {
return llvm::hash_value(key);
}

/// Returns the group number of this position. If None, this group refers to
/// all results.
Optional<unsigned> getResultGroupNumber() const { return std::get<1>(key); }

/// Returns if the result group has unknown size. If false, the result group
/// has at max one element.
bool isVariadic() const { return std::get<2>(key); }
};

//===----------------------------------------------------------------------===//
// TypePosition

Expand All @@ -250,14 +288,11 @@ struct ResultPosition
struct TypePosition : public PredicateBase<TypePosition, Position, Position *,
Predicates::TypePos> {
explicit TypePosition(const KeyTy &key) : Base(key) {
assert((isa<AttributePosition>(key) || isa<OperandPosition>(key) ||
isa<ResultPosition>(key)) &&
assert((isa<AttributePosition, OperandPosition, OperandGroupPosition,
ResultPosition, ResultGroupPosition>(key)) &&
"expected parent to be an attribute, operand, or result");
parent = key;
}

/// Returns the index of this position.
ArrayRef<unsigned> getIndex() const final { return key->getIndex(); }
};

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -311,8 +346,9 @@ struct TrueAnswer
using Base::Base;
};

/// An Answer representing a `Type` value.
struct TypeAnswer : public PredicateBase<TypeAnswer, Qualifier, Type,
/// An Answer representing a `Type` value. The value is stored as either a
/// TypeAttr, or an ArrayAttr of TypeAttr.
struct TypeAnswer : public PredicateBase<TypeAnswer, Qualifier, Attribute,
Predicates::TypeAnswer> {
using Base::Base;
};
Expand Down Expand Up @@ -365,6 +401,9 @@ struct IsNotNullQuestion
struct OperandCountQuestion
: public PredicateBase<OperandCountQuestion, Qualifier, void,
Predicates::OperandCountQuestion> {};
struct OperandCountAtLeastQuestion
: public PredicateBase<OperandCountAtLeastQuestion, Qualifier, void,
Predicates::OperandCountAtLeastQuestion> {};

/// Compare the name of an operation with a known value.
struct OperationNameQuestion
Expand All @@ -375,6 +414,9 @@ struct OperationNameQuestion
struct ResultCountQuestion
: public PredicateBase<ResultCountQuestion, Qualifier, void,
Predicates::ResultCountQuestion> {};
struct ResultCountAtLeastQuestion
: public PredicateBase<ResultCountAtLeastQuestion, Qualifier, void,
Predicates::ResultCountAtLeastQuestion> {};

/// Compare the type of an attribute or value with a known type.
struct TypeQuestion : public PredicateBase<TypeQuestion, Qualifier, void,
Expand All @@ -392,8 +434,10 @@ class PredicateUniquer : public StorageUniquer {
// Register the types of Positions with the uniquer.
registerParametricStorageType<AttributePosition>();
registerParametricStorageType<OperandPosition>();
registerParametricStorageType<OperandGroupPosition>();
registerParametricStorageType<OperationPosition>();
registerParametricStorageType<ResultPosition>();
registerParametricStorageType<ResultGroupPosition>();
registerParametricStorageType<TypePosition>();

// Register the types of Questions with the uniquer.
Expand All @@ -409,8 +453,10 @@ class PredicateUniquer : public StorageUniquer {
registerSingletonStorageType<AttributeQuestion>();
registerSingletonStorageType<IsNotNullQuestion>();
registerSingletonStorageType<OperandCountQuestion>();
registerSingletonStorageType<OperandCountAtLeastQuestion>();
registerSingletonStorageType<OperationNameQuestion>();
registerSingletonStorageType<ResultCountQuestion>();
registerSingletonStorageType<ResultCountAtLeastQuestion>();
registerSingletonStorageType<TypeQuestion>();
}
};
Expand All @@ -433,10 +479,10 @@ class PredicateBuilder {
Position *getRoot() { return OperationPosition::getRoot(uniquer); }

/// Returns the parent position defining the value held by the given operand.
OperationPosition *getParent(OperandPosition *p) {
std::vector<unsigned> index = p->getIndex();
index.push_back(p->getOperandNumber());
return OperationPosition::get(uniquer, index);
OperationPosition *getOperandDefiningOp(Position *p) {
assert((isa<OperandPosition, OperandGroupPosition>(p)) &&
"expected operand position");
return OperationPosition::get(uniquer, p);
}

/// Returns an attribute position for an attribute of the given operation.
Expand All @@ -449,11 +495,29 @@ class PredicateBuilder {
return OperandPosition::get(uniquer, p, operand);
}

/// Returns a position for a group of operands of the given operation.
Position *getOperandGroup(OperationPosition *p, Optional<unsigned> group,
bool isVariadic) {
return OperandGroupPosition::get(uniquer, p, group, isVariadic);
}
Position *getAllOperands(OperationPosition *p) {
return getOperandGroup(p, /*group=*/llvm::None, /*isVariadic=*/true);
}

/// Returns a result position for a result of the given operation.
Position *getResult(OperationPosition *p, unsigned result) {
return ResultPosition::get(uniquer, p, result);
}

/// Returns a position for a group of results of the given operation.
Position *getResultGroup(OperationPosition *p, Optional<unsigned> group,
bool isVariadic) {
return ResultGroupPosition::get(uniquer, p, group, isVariadic);
}
Position *getAllResults(OperationPosition *p) {
return getResultGroup(p, /*group=*/llvm::None, /*isVariadic=*/true);
}

/// Returns a type position for the given entity.
Position *getType(Position *p) { return TypePosition::get(uniquer, p); }

Expand Down Expand Up @@ -496,6 +560,10 @@ class PredicateBuilder {
return {OperandCountQuestion::get(uniquer),
UnsignedAnswer::get(uniquer, count)};
}
Predicate getOperandCountAtLeast(unsigned count) {
return {OperandCountAtLeastQuestion::get(uniquer),
UnsignedAnswer::get(uniquer, count)};
}

/// Create a predicate comparing the name of an operation to a known value.
Predicate getOperationName(StringRef name) {
Expand All @@ -509,10 +577,15 @@ class PredicateBuilder {
return {ResultCountQuestion::get(uniquer),
UnsignedAnswer::get(uniquer, count)};
}
Predicate getResultCountAtLeast(unsigned count) {
return {ResultCountAtLeastQuestion::get(uniquer),
UnsignedAnswer::get(uniquer, count)};
}

/// Create a predicate comparing the type of an attribute or value to a known
/// type.
Predicate getTypeConstraint(Type type) {
/// type. The value is stored as either a TypeAttr, or an ArrayAttr of
/// TypeAttr.
Predicate getTypeConstraint(Attribute type) {
return {TypeQuestion::get(uniquer), TypeAnswer::get(uniquer, type)};
}

Expand Down
Loading

0 comments on commit 3a833a0

Please sign in to comment.