375 changes: 300 additions & 75 deletions mlir/include/mlir/Dialect/PDL/IR/PDLOps.td

Large diffs are not rendered by default.

14 changes: 14 additions & 0 deletions mlir/include/mlir/Dialect/PDL/IR/PDLTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -101,4 +101,18 @@ def PDL_AnyType : Type<
CPred<"$_self.isa<::mlir::pdl::PDLType>()">, "pdl type",
"::mlir::pdl::PDLType">;

// A range of positional values of one of the provided types.
class PDL_RangeOf<Type positionalType> :
ContainerType<AnyTypeOf<[positionalType]>, PDL_Range.predicate,
"$_self.cast<::mlir::pdl::RangeType>().getElementType()",
"range", "::mlir::pdl::RangeType">,
BuildableType<"::mlir::pdl::RangeType::get(" # positionalType.builderCall #
")">;

// Either a positional value or a range of positional values for a given type.
class PDL_InstOrRangeOf<Type positionalType> :
AnyTypeOf<[positionalType, PDL_RangeOf<positionalType>],
"single element or range of " # positionalType.summary,
"::mlir::pdl::PDLType">;

#endif // MLIR_DIALECT_PDL_IR_PDLTYPES
372 changes: 281 additions & 91 deletions mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion mlir/include/mlir/IR/OpBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -1439,7 +1439,7 @@ class TypedArrayAttrBase<Attr element, string summary>: ArrayAttrBase<
CPred<"$_self.isa<::mlir::ArrayAttr>()">,
// Guarantee all elements satisfy the constraints from `element`
Concat<"::llvm::all_of($_self.cast<::mlir::ArrayAttr>(), "
"[](::mlir::Attribute attr) { return ",
"[&](::mlir::Attribute attr) { return ",
SubstLeaves<"$_self", "attr", element.predicate>,
"; })">]>,
summary> {
Expand Down
215 changes: 154 additions & 61 deletions mlir/include/mlir/IR/PatternMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -238,70 +238,178 @@ struct OpRewritePattern : public RewritePattern {
/// Storage type of byte-code interpreter values. These are passed to constraint
/// functions as arguments.
class PDLValue {
/// The internal implementation type when the value is an Attribute,
/// Operation*, or Type. See `impl` below for more details.
using AttrOpTypeImplT = llvm::PointerUnion<Attribute, Operation *, Type>;

public:
PDLValue(const PDLValue &other) : impl(other.impl) {}
PDLValue(std::nullptr_t = nullptr) : impl() {}
PDLValue(Attribute value) : impl(value) {}
PDLValue(Operation *value) : impl(value) {}
PDLValue(Type value) : impl(value) {}
PDLValue(Value value) : impl(value) {}
/// The underlying kind of a PDL value.
enum class Kind { Attribute, Operation, Type, TypeRange, Value, ValueRange };

/// Construct a new PDL value.
PDLValue(const PDLValue &other) = default;
PDLValue(std::nullptr_t = nullptr) : value(nullptr), kind(Kind::Attribute) {}
PDLValue(Attribute value)
: value(value.getAsOpaquePointer()), kind(Kind::Attribute) {}
PDLValue(Operation *value) : value(value), kind(Kind::Operation) {}
PDLValue(Type value) : value(value.getAsOpaquePointer()), kind(Kind::Type) {}
PDLValue(TypeRange *value) : value(value), kind(Kind::TypeRange) {}
PDLValue(Value value)
: value(value.getAsOpaquePointer()), kind(Kind::Value) {}
PDLValue(ValueRange *value) : value(value), kind(Kind::ValueRange) {}

/// Returns true if the type of the held value is `T`.
template <typename T>
std::enable_if_t<std::is_same<T, Value>::value, bool> isa() const {
return impl.is<Value>();
}
template <typename T>
std::enable_if_t<!std::is_same<T, Value>::value, bool> isa() const {
auto attrOpTypeImpl = impl.dyn_cast<AttrOpTypeImplT>();
return attrOpTypeImpl && attrOpTypeImpl.is<T>();
template <typename T> bool isa() const {
assert(value && "isa<> used on a null value");
return kind == getKindOf<T>();
}

/// Attempt to dynamically cast this value to type `T`, returns null if this
/// value is not an instance of `T`.
template <typename T>
std::enable_if_t<std::is_same<T, Value>::value, T> dyn_cast() const {
return impl.dyn_cast<T>();
}
template <typename T>
std::enable_if_t<!std::is_same<T, Value>::value, T> dyn_cast() const {
auto attrOpTypeImpl = impl.dyn_cast<AttrOpTypeImplT>();
return attrOpTypeImpl && attrOpTypeImpl.dyn_cast<T>();
template <typename T,
typename ResultT = std::conditional_t<
std::is_convertible<T, bool>::value, T, Optional<T>>>
ResultT dyn_cast() const {
return isa<T>() ? castImpl<T>() : ResultT();
}

/// Cast this value to type `T`, asserts if this value is not an instance of
/// `T`.
template <typename T>
std::enable_if_t<std::is_same<T, Value>::value, T> cast() const {
return impl.get<T>();
}
template <typename T>
std::enable_if_t<!std::is_same<T, Value>::value, T> cast() const {
return impl.get<AttrOpTypeImplT>().get<T>();
template <typename T> T cast() const {
assert(isa<T>() && "expected value to be of type `T`");
return castImpl<T>();
}

/// Get an opaque pointer to the value.
void *getAsOpaquePointer() { return impl.getOpaqueValue(); }
const void *getAsOpaquePointer() const { return value; }

/// Return if this value is null or not.
explicit operator bool() const { return value; }

/// Return the kind of this value.
Kind getKind() const { return kind; }

/// Print this value to the provided output stream.
void print(raw_ostream &os);
void print(raw_ostream &os) const;

private:
/// The internal opaque representation of a PDLValue. We use a nested
/// PointerUnion structure here because `Value` only has 1 low bit
/// available, where as the remaining types all have 3.
llvm::PointerUnion<AttrOpTypeImplT, Value> impl;
/// Find the index of a given type in a range of other types.
template <typename...> struct index_of_t;
template <typename T, typename... R>
struct index_of_t<T, T, R...> : std::integral_constant<size_t, 0> {};
template <typename T, typename F, typename... R>
struct index_of_t<T, F, R...>
: std::integral_constant<size_t, 1 + index_of_t<T, R...>::value> {};

/// Return the kind used for the given T.
template <typename T> static Kind getKindOf() {
return static_cast<Kind>(index_of_t<T, Attribute, Operation *, Type,
TypeRange, Value, ValueRange>::value);
}

/// The internal implementation of `cast`, that returns the underlying value
/// as the given type `T`.
template <typename T>
std::enable_if_t<llvm::is_one_of<T, Attribute, Type, Value>::value, T>
castImpl() const {
return T::getFromOpaquePointer(value);
}
template <typename T>
std::enable_if_t<llvm::is_one_of<T, TypeRange, ValueRange>::value, T>
castImpl() const {
return *reinterpret_cast<T *>(const_cast<void *>(value));
}
template <typename T>
std::enable_if_t<std::is_pointer<T>::value, T> castImpl() const {
return reinterpret_cast<T>(const_cast<void *>(value));
}

/// The internal opaque representation of a PDLValue.
const void *value;
/// The kind of the opaque value.
Kind kind;
};

inline raw_ostream &operator<<(raw_ostream &os, PDLValue value) {
value.print(os);
return os;
}

//===----------------------------------------------------------------------===//
// PDLResultList

/// The class represents a list of PDL results, returned by a native rewrite
/// method. It provides the mechanism with which to pass PDLValues back to the
/// PDL bytecode.
class PDLResultList {
public:
/// Push a new Attribute value onto the result list.
void push_back(Attribute value) { results.push_back(value); }

/// Push a new Operation onto the result list.
void push_back(Operation *value) { results.push_back(value); }

/// Push a new Type onto the result list.
void push_back(Type value) { results.push_back(value); }

/// Push a new TypeRange onto the result list.
void push_back(TypeRange value) {
// The lifetime of a TypeRange can't be guaranteed, so we'll need to
// allocate a storage for it.
llvm::OwningArrayRef<Type> storage(value.size());
llvm::copy(value, storage.begin());
allocatedTypeRanges.emplace_back(std::move(storage));
typeRanges.push_back(allocatedTypeRanges.back());
results.push_back(&typeRanges.back());
}
void push_back(ValueTypeRange<OperandRange> value) {
typeRanges.push_back(value);
results.push_back(&typeRanges.back());
}
void push_back(ValueTypeRange<ResultRange> value) {
typeRanges.push_back(value);
results.push_back(&typeRanges.back());
}

/// Push a new Value onto the result list.
void push_back(Value value) { results.push_back(value); }

/// Push a new ValueRange onto the result list.
void push_back(ValueRange value) {
// The lifetime of a ValueRange can't be guaranteed, so we'll need to
// allocate a storage for it.
llvm::OwningArrayRef<Value> storage(value.size());
llvm::copy(value, storage.begin());
allocatedValueRanges.emplace_back(std::move(storage));
valueRanges.push_back(allocatedValueRanges.back());
results.push_back(&valueRanges.back());
}
void push_back(OperandRange value) {
valueRanges.push_back(value);
results.push_back(&valueRanges.back());
}
void push_back(ResultRange value) {
valueRanges.push_back(value);
results.push_back(&valueRanges.back());
}

protected:
/// Create a new result list with the expected number of results.
PDLResultList(unsigned maxNumResults) {
// For now just reserve enough space for all of the results. We could do
// separate counts per range type, but it isn't really worth it unless there
// are a "large" number of results.
typeRanges.reserve(maxNumResults);
valueRanges.reserve(maxNumResults);
}

/// The PDL results held by this list.
SmallVector<PDLValue> results;
/// Memory used to store ranges held by the list.
SmallVector<TypeRange> typeRanges;
SmallVector<ValueRange> valueRanges;
/// Memory allocated to store ranges in the result list whose lifetime was
/// generated in the native function.
SmallVector<llvm::OwningArrayRef<Type>> allocatedTypeRanges;
SmallVector<llvm::OwningArrayRef<Value>> allocatedValueRanges;
};

//===----------------------------------------------------------------------===//
// PDLPatternModule

Expand All @@ -311,16 +419,13 @@ inline raw_ostream &operator<<(raw_ostream &os, PDLValue value) {
/// success if the constraint successfully held, failure otherwise.
using PDLConstraintFunction = std::function<LogicalResult(
ArrayRef<PDLValue>, ArrayAttr, PatternRewriter &)>;
/// A native PDL creation function. This function creates a new PDLValue given
/// a set of existing PDL values, a set of constant parameters specified in
/// Attribute form, and a PatternRewriter. Returns the newly created PDLValue.
using PDLCreateFunction =
std::function<PDLValue(ArrayRef<PDLValue>, ArrayAttr, PatternRewriter &)>;
/// A native PDL rewrite function. This function rewrites the given root
/// operation using the provided PatternRewriter. This method is only invoked
/// when the corresponding match was successful.
using PDLRewriteFunction = std::function<void(Operation *, ArrayRef<PDLValue>,
ArrayAttr, PatternRewriter &)>;
/// A native PDL rewrite function. This function performs a rewrite on the
/// given set of values and constant parameters. Any results from this rewrite
/// that should be passed back to PDL should be added to the provided result
/// list. This method is only invoked when the corresponding match was
/// successful.
using PDLRewriteFunction = std::function<void(
ArrayRef<PDLValue>, ArrayAttr, PatternRewriter &, PDLResultList &)>;
/// A generic PDL pattern constraint function. This function applies a
/// constraint to a given opaque PDLValue entity. The second parameter is a set
/// of constant value parameters specified in Attribute form. Returns success if
Expand Down Expand Up @@ -367,9 +472,6 @@ class PDLPatternModule {
});
}

/// Register a creation function.
void registerCreateFunction(StringRef name, PDLCreateFunction createFn);

/// Register a rewrite function.
void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn);

Expand All @@ -380,13 +482,6 @@ class PDLPatternModule {
llvm::StringMap<PDLConstraintFunction> takeConstraintFunctions() {
return constraintFunctions;
}
/// Return the set of the registered create functions.
const llvm::StringMap<PDLCreateFunction> &getCreateFunctions() const {
return createFunctions;
}
llvm::StringMap<PDLCreateFunction> takeCreateFunctions() {
return createFunctions;
}
/// Return the set of the registered rewrite functions.
const llvm::StringMap<PDLRewriteFunction> &getRewriteFunctions() const {
return rewriteFunctions;
Expand All @@ -399,7 +494,6 @@ class PDLPatternModule {
void clear() {
pdlModule = nullptr;
constraintFunctions.clear();
createFunctions.clear();
rewriteFunctions.clear();
}

Expand All @@ -409,7 +503,6 @@ class PDLPatternModule {

/// The external functions referenced from within the PDL module.
llvm::StringMap<PDLConstraintFunction> constraintFunctions;
llvm::StringMap<PDLCreateFunction> createFunctions;
llvm::StringMap<PDLRewriteFunction> rewriteFunctions;
};

Expand Down
6 changes: 6 additions & 0 deletions mlir/include/mlir/IR/TypeRange.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,12 @@ inline ::llvm::hash_code hash_value(TypeRange arg) {
return ::llvm::hash_combine_range(arg.begin(), arg.end());
}

/// Emit a type range to the given output stream.
inline raw_ostream &operator<<(raw_ostream &os, const TypeRange &types) {
llvm::interleaveComma(types, os);
return os;
}

//===----------------------------------------------------------------------===//
// ValueTypeRange

Expand Down
367 changes: 248 additions & 119 deletions mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp

Large diffs are not rendered by default.

25 changes: 11 additions & 14 deletions mlir/lib/Conversion/PDLToPDLInterp/Predicate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@ using namespace mlir::pdl_to_pdl_interp;

Position::~Position() {}

/// Returns the depth of the first ancestor operation position.
unsigned Position::getOperationDepth() const {
if (const auto *operationPos = dyn_cast<OperationPosition>(this))
return operationPos->getDepth();
return parent->getOperationDepth();
}

//===----------------------------------------------------------------------===//
// AttributePosition

Expand All @@ -32,18 +39,8 @@ OperandPosition::OperandPosition(const KeyTy &key) : Base(key) {
}

//===----------------------------------------------------------------------===//
// OperationPosition

OperationPosition *OperationPosition::get(StorageUniquer &uniquer,
ArrayRef<unsigned> index) {
assert(!index.empty() && "expected at least two indices");

// Set the parent position if this isn't the root.
Position *parent = nullptr;
if (index.size() > 1) {
auto *node = OperationPosition::get(uniquer, index.drop_back());
parent = OperandPosition::get(uniquer, std::make_pair(node, index.back()));
}
return uniquer.get<OperationPosition>(
[parent](OperationPosition *node) { node->parent = parent; }, index);
// OperandGroupPosition

OperandGroupPosition::OperandGroupPosition(const KeyTy &key) : Base(key) {
parent = std::get<0>(key);
}
171 changes: 122 additions & 49 deletions mlir/lib/Conversion/PDLToPDLInterp/Predicate.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,20 @@ enum Kind : unsigned {
/// Positions, ordered by decreasing priority.
OperationPos,
OperandPos,
OperandGroupPos,
AttributePos,
ResultPos,
ResultGroupPos,
TypePos,

// Questions, ordered by dependency and decreasing priority.
IsNotNullQuestion,
OperationNameQuestion,
TypeQuestion,
AttributeQuestion,
OperandCountAtLeastQuestion,
OperandCountQuestion,
ResultCountAtLeastQuestion,
ResultCountQuestion,
EqualToQuestion,
ConstraintQuestion,
Expand Down Expand Up @@ -129,21 +133,15 @@ struct OperationPosition;
/// predicates, and assists generating bytecode and memory management.
///
/// Operation positions form the base of other positions, which are formed
/// relative to a parent operation, e.g. OperandPosition<[0] -> 1>. Operations
/// are indexed by child index: [0, 1, 2] refers to the 3rd child of the 2nd
/// child of the root operation.
///
/// Positions are linked to their parent position, which describes how to obtain
/// a positional value. As a concrete example, getting OperationPosition<[0, 1]>
/// would be `root->getOperand(1)->getDefiningOp()`, so its parent is
/// OperandPosition<[0] -> 1>, whose parent is OperationPosition<[0]>.
/// relative to a parent operation. Operations are anchored at Operand nodes,
/// except for the root operation which is parentless.
class Position : public StorageUniquer::BaseStorage {
public:
explicit Position(Predicates::Kind kind) : kind(kind) {}
virtual ~Position();

/// Returns the base node position. This is an array of indices.
virtual ArrayRef<unsigned> getIndex() const = 0;
/// Returns the depth of the first ancestor operation position.
unsigned getOperationDepth() const;

/// Returns the parent position. The root operation position has no parent.
Position *getParent() const { return parent; }
Expand All @@ -170,9 +168,6 @@ struct AttributePosition
Predicates::AttributePos> {
explicit AttributePosition(const KeyTy &key);

/// Returns the index of this position.
ArrayRef<unsigned> getIndex() const final { return parent->getIndex(); }

/// Returns the attribute name of this position.
Identifier getName() const { return key.second; }
};
Expand All @@ -187,42 +182,61 @@ struct OperandPosition
Predicates::OperandPos> {
explicit OperandPosition(const KeyTy &key);

/// Returns the index of this position.
ArrayRef<unsigned> getIndex() const final { return parent->getIndex(); }

/// Returns the operand number of this position.
unsigned getOperandNumber() const { return key.second; }
};

//===----------------------------------------------------------------------===//
// OperandGroupPosition

/// A position describing an operand group of an operation.
struct OperandGroupPosition
: public PredicateBase<
OperandGroupPosition, Position,
std::tuple<OperationPosition *, Optional<unsigned>, bool>,
Predicates::OperandGroupPos> {
explicit OperandGroupPosition(const KeyTy &key);

/// Returns a hash suitable for the given keytype.
static llvm::hash_code hashKey(const KeyTy &key) {
return llvm::hash_value(key);
}

/// Returns the group number of this position. If None, this group refers to
/// all operands.
Optional<unsigned> getOperandGroupNumber() const { return std::get<1>(key); }

/// Returns if the operand group has unknown size. If false, the operand group
/// has at max one element.
bool isVariadic() const { return std::get<2>(key); }
};

//===----------------------------------------------------------------------===//
// OperationPosition

/// An operation position describes an operation node in the IR. Other position
/// kinds are formed with respect to an operation position.
struct OperationPosition
: public PredicateBase<OperationPosition, Position, ArrayRef<unsigned>,
Predicates::OperationPos> {
using Base::Base;
struct OperationPosition : public PredicateBase<OperationPosition, Position,
std::pair<Position *, unsigned>,
Predicates::OperationPos> {
explicit OperationPosition(const KeyTy &key) : Base(key) {
parent = key.first;
}

/// Gets the root position, which is always [0].
/// Gets the root position.
static OperationPosition *getRoot(StorageUniquer &uniquer) {
return get(uniquer, ArrayRef<unsigned>(0));
return Base::get(uniquer, nullptr, 0);
}
/// Gets a node position for the given index.
static OperationPosition *get(StorageUniquer &uniquer,
ArrayRef<unsigned> index);

/// Constructs an instance with the given storage allocator.
static OperationPosition *construct(StorageUniquer::StorageAllocator &alloc,
ArrayRef<unsigned> key) {
return Base::construct(alloc, alloc.copyInto(key));
/// Gets an operation position with the given parent.
static OperationPosition *get(StorageUniquer &uniquer, Position *parent) {
return Base::get(uniquer, parent, parent->getOperationDepth() + 1);
}

/// Returns the index of this position.
ArrayRef<unsigned> getIndex() const final { return key; }
/// Returns the depth of this position.
unsigned getDepth() const { return key.second; }

/// Returns if this operation position corresponds to the root.
bool isRoot() const { return key.size() == 1 && key[0] == 0; }
bool isRoot() const { return getDepth() == 0; }
};

//===----------------------------------------------------------------------===//
Expand All @@ -235,13 +249,37 @@ struct ResultPosition
Predicates::ResultPos> {
explicit ResultPosition(const KeyTy &key) : Base(key) { parent = key.first; }

/// Returns the index of this position.
ArrayRef<unsigned> getIndex() const final { return key.first->getIndex(); }

/// Returns the result number of this position.
unsigned getResultNumber() const { return key.second; }
};

//===----------------------------------------------------------------------===//
// ResultGroupPosition

/// A position describing a result group of an operation.
struct ResultGroupPosition
: public PredicateBase<
ResultGroupPosition, Position,
std::tuple<OperationPosition *, Optional<unsigned>, bool>,
Predicates::ResultGroupPos> {
explicit ResultGroupPosition(const KeyTy &key) : Base(key) {
parent = std::get<0>(key);
}

/// Returns a hash suitable for the given keytype.
static llvm::hash_code hashKey(const KeyTy &key) {
return llvm::hash_value(key);
}

/// Returns the group number of this position. If None, this group refers to
/// all results.
Optional<unsigned> getResultGroupNumber() const { return std::get<1>(key); }

/// Returns if the result group has unknown size. If false, the result group
/// has at max one element.
bool isVariadic() const { return std::get<2>(key); }
};

//===----------------------------------------------------------------------===//
// TypePosition

Expand All @@ -250,14 +288,11 @@ struct ResultPosition
struct TypePosition : public PredicateBase<TypePosition, Position, Position *,
Predicates::TypePos> {
explicit TypePosition(const KeyTy &key) : Base(key) {
assert((isa<AttributePosition>(key) || isa<OperandPosition>(key) ||
isa<ResultPosition>(key)) &&
assert((isa<AttributePosition, OperandPosition, OperandGroupPosition,
ResultPosition, ResultGroupPosition>(key)) &&
"expected parent to be an attribute, operand, or result");
parent = key;
}

/// Returns the index of this position.
ArrayRef<unsigned> getIndex() const final { return key->getIndex(); }
};

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -311,8 +346,9 @@ struct TrueAnswer
using Base::Base;
};

/// An Answer representing a `Type` value.
struct TypeAnswer : public PredicateBase<TypeAnswer, Qualifier, Type,
/// An Answer representing a `Type` value. The value is stored as either a
/// TypeAttr, or an ArrayAttr of TypeAttr.
struct TypeAnswer : public PredicateBase<TypeAnswer, Qualifier, Attribute,
Predicates::TypeAnswer> {
using Base::Base;
};
Expand Down Expand Up @@ -365,6 +401,9 @@ struct IsNotNullQuestion
struct OperandCountQuestion
: public PredicateBase<OperandCountQuestion, Qualifier, void,
Predicates::OperandCountQuestion> {};
struct OperandCountAtLeastQuestion
: public PredicateBase<OperandCountAtLeastQuestion, Qualifier, void,
Predicates::OperandCountAtLeastQuestion> {};

/// Compare the name of an operation with a known value.
struct OperationNameQuestion
Expand All @@ -375,6 +414,9 @@ struct OperationNameQuestion
struct ResultCountQuestion
: public PredicateBase<ResultCountQuestion, Qualifier, void,
Predicates::ResultCountQuestion> {};
struct ResultCountAtLeastQuestion
: public PredicateBase<ResultCountAtLeastQuestion, Qualifier, void,
Predicates::ResultCountAtLeastQuestion> {};

/// Compare the type of an attribute or value with a known type.
struct TypeQuestion : public PredicateBase<TypeQuestion, Qualifier, void,
Expand All @@ -392,8 +434,10 @@ class PredicateUniquer : public StorageUniquer {
// Register the types of Positions with the uniquer.
registerParametricStorageType<AttributePosition>();
registerParametricStorageType<OperandPosition>();
registerParametricStorageType<OperandGroupPosition>();
registerParametricStorageType<OperationPosition>();
registerParametricStorageType<ResultPosition>();
registerParametricStorageType<ResultGroupPosition>();
registerParametricStorageType<TypePosition>();

// Register the types of Questions with the uniquer.
Expand All @@ -409,8 +453,10 @@ class PredicateUniquer : public StorageUniquer {
registerSingletonStorageType<AttributeQuestion>();
registerSingletonStorageType<IsNotNullQuestion>();
registerSingletonStorageType<OperandCountQuestion>();
registerSingletonStorageType<OperandCountAtLeastQuestion>();
registerSingletonStorageType<OperationNameQuestion>();
registerSingletonStorageType<ResultCountQuestion>();
registerSingletonStorageType<ResultCountAtLeastQuestion>();
registerSingletonStorageType<TypeQuestion>();
}
};
Expand All @@ -433,10 +479,10 @@ class PredicateBuilder {
Position *getRoot() { return OperationPosition::getRoot(uniquer); }

/// Returns the parent position defining the value held by the given operand.
Position *getParent(OperandPosition *p) {
std::vector<unsigned> index = p->getIndex();
index.push_back(p->getOperandNumber());
return OperationPosition::get(uniquer, index);
OperationPosition *getOperandDefiningOp(Position *p) {
assert((isa<OperandPosition, OperandGroupPosition>(p)) &&
"expected operand position");
return OperationPosition::get(uniquer, p);
}

/// Returns an attribute position for an attribute of the given operation.
Expand All @@ -449,11 +495,29 @@ class PredicateBuilder {
return OperandPosition::get(uniquer, p, operand);
}

/// Returns a position for a group of operands of the given operation.
Position *getOperandGroup(OperationPosition *p, Optional<unsigned> group,
bool isVariadic) {
return OperandGroupPosition::get(uniquer, p, group, isVariadic);
}
Position *getAllOperands(OperationPosition *p) {
return getOperandGroup(p, /*group=*/llvm::None, /*isVariadic=*/true);
}

/// Returns a result position for a result of the given operation.
Position *getResult(OperationPosition *p, unsigned result) {
return ResultPosition::get(uniquer, p, result);
}

/// Returns a position for a group of results of the given operation.
Position *getResultGroup(OperationPosition *p, Optional<unsigned> group,
bool isVariadic) {
return ResultGroupPosition::get(uniquer, p, group, isVariadic);
}
Position *getAllResults(OperationPosition *p) {
return getResultGroup(p, /*group=*/llvm::None, /*isVariadic=*/true);
}

/// Returns a type position for the given entity.
Position *getType(Position *p) { return TypePosition::get(uniquer, p); }

Expand Down Expand Up @@ -496,6 +560,10 @@ class PredicateBuilder {
return {OperandCountQuestion::get(uniquer),
UnsignedAnswer::get(uniquer, count)};
}
Predicate getOperandCountAtLeast(unsigned count) {
return {OperandCountAtLeastQuestion::get(uniquer),
UnsignedAnswer::get(uniquer, count)};
}

/// Create a predicate comparing the name of an operation to a known value.
Predicate getOperationName(StringRef name) {
Expand All @@ -509,10 +577,15 @@ class PredicateBuilder {
return {ResultCountQuestion::get(uniquer),
UnsignedAnswer::get(uniquer, count)};
}
Predicate getResultCountAtLeast(unsigned count) {
return {ResultCountAtLeastQuestion::get(uniquer),
UnsignedAnswer::get(uniquer, count)};
}

/// Create a predicate comparing the type of an attribute or value to a known
/// type.
Predicate getTypeConstraint(Type type) {
/// type. The value is stored as either a TypeAttr, or an ArrayAttr of
/// TypeAttr.
Predicate getTypeConstraint(Attribute type) {
return {TypeQuestion::get(uniquer), TypeAnswer::get(uniquer, type)};
}

Expand Down
375 changes: 253 additions & 122 deletions mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,12 @@ struct SwitchNode : public MatcherNode {
using ChildMapT = llvm::MapVector<Qualifier *, std::unique_ptr<MatcherNode>>;
ChildMapT &getChildren() { return children; }

/// Returns the child at the given index.
std::pair<Qualifier *, std::unique_ptr<MatcherNode>> &getChild(unsigned i) {
assert(i < children.size() && "invalid child index");
return *std::next(children.begin(), i);
}

private:
/// Switch predicate "answers" select the child. Answers that are not found
/// default to the failure node.
Expand Down
251 changes: 112 additions & 139 deletions mlir/lib/Dialect/PDL/IR/PDL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,36 +28,67 @@ void PDLDialect::initialize() {
registerTypes();
}

//===----------------------------------------------------------------------===//
// PDL Operations
//===----------------------------------------------------------------------===//

/// Returns true if the given operation is used by a "binding" pdl operation
/// within the main matcher body of a `pdl.pattern`.
static bool hasBindingUseInMatcher(Operation *op, Block *matcherBlock) {
for (OpOperand &use : op->getUses()) {
Operation *user = use.getOwner();
if (user->getBlock() != matcherBlock)
continue;
if (isa<AttributeOp, OperandOp, OperandsOp, OperationOp>(user))
return true;
// Only the first operand of RewriteOp may be bound to, i.e. the root
// operation of the pattern.
if (isa<RewriteOp>(user) && use.getOperandNumber() == 0)
return true;
// A result by itself is not binding, it must also be bound.
if (isa<ResultOp, ResultsOp>(user) &&
hasBindingUseInMatcher(user, matcherBlock))
return true;
}
return false;
}

/// Returns success if the given operation is used by a "binding" pdl operation
/// within the main matcher body of a `pdl.pattern`. On failure, emits an error
/// with the given context message.
static LogicalResult
verifyHasBindingUseInMatcher(Operation *op,
StringRef bindableContextStr = "`pdl.operation`") {
// If the pattern is not a pattern, there is nothing to do.
if (!isa<PatternOp>(op->getParentOp()))
return success();
Block *matcherBlock = op->getBlock();
for (Operation *user : op->getUsers()) {
if (user->getBlock() != matcherBlock)
continue;
if (isa<AttributeOp, OperandOp, OperationOp, RewriteOp>(user))
return success();
}
if (hasBindingUseInMatcher(op, op->getBlock()))
return success();
return op->emitOpError()
<< "expected a bindable (i.e. " << bindableContextStr
<< ") user when defined in the matcher body of a `pdl.pattern`";
}

//===----------------------------------------------------------------------===//
// pdl::ApplyConstraintOp
// pdl::ApplyNativeConstraintOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(ApplyConstraintOp op) {
static LogicalResult verify(ApplyNativeConstraintOp op) {
if (op.getNumOperands() == 0)
return op.emitOpError("expected at least one argument");
return success();
}

//===----------------------------------------------------------------------===//
// pdl::ApplyNativeRewriteOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(ApplyNativeRewriteOp op) {
if (op.getNumOperands() == 0 && op.getNumResults() == 0)
return op.emitOpError("expected at least one argument or result");
return success();
}

//===----------------------------------------------------------------------===//
// pdl::AttributeOp
//===----------------------------------------------------------------------===//
Expand All @@ -83,109 +114,53 @@ static LogicalResult verify(OperandOp op) {
}

//===----------------------------------------------------------------------===//
// pdl::OperationOp
// pdl::OperandsOp
//===----------------------------------------------------------------------===//

static ParseResult parseOperationOp(OpAsmParser &p, OperationState &state) {
Builder &builder = p.getBuilder();

// Parse the optional operation name.
bool startsWithOperands = succeeded(p.parseOptionalLParen());
bool startsWithAttributes =
!startsWithOperands && succeeded(p.parseOptionalLBrace());
bool startsWithOpName = false;
if (!startsWithAttributes && !startsWithOperands) {
StringAttr opName;
OptionalParseResult opNameResult =
p.parseOptionalAttribute(opName, "name", state.attributes);
startsWithOpName = opNameResult.hasValue();
if (startsWithOpName && failed(*opNameResult))
return failure();
}
static LogicalResult verify(OperandsOp op) {
return verifyHasBindingUseInMatcher(op);
}

// Parse the operands.
SmallVector<OpAsmParser::OperandType, 4> operands;
if (startsWithOperands ||
(!startsWithAttributes && succeeded(p.parseOptionalLParen()))) {
if (p.parseOperandList(operands) || p.parseRParen() ||
p.resolveOperands(operands, builder.getType<ValueType>(),
state.operands))
return failure();
}
//===----------------------------------------------------------------------===//
// pdl::OperationOp
//===----------------------------------------------------------------------===//

// Parse the attributes.
static ParseResult parseOperationOpAttributes(
OpAsmParser &p, SmallVectorImpl<OpAsmParser::OperandType> &attrOperands,
ArrayAttr &attrNamesAttr) {
Builder &builder = p.getBuilder();
SmallVector<Attribute, 4> attrNames;
if (startsWithAttributes || succeeded(p.parseOptionalLBrace())) {
SmallVector<OpAsmParser::OperandType, 4> attrOps;
if (succeeded(p.parseOptionalLBrace())) {
do {
StringAttr nameAttr;
OpAsmParser::OperandType operand;
if (p.parseAttribute(nameAttr) || p.parseEqual() ||
p.parseOperand(operand))
return failure();
attrNames.push_back(nameAttr);
attrOps.push_back(operand);
attrOperands.push_back(operand);
} while (succeeded(p.parseOptionalComma()));

if (p.parseRBrace() ||
p.resolveOperands(attrOps, builder.getType<AttributeType>(),
state.operands))
if (p.parseRBrace())
return failure();
}
state.addAttribute("attributeNames", builder.getArrayAttr(attrNames));
state.addTypes(builder.getType<OperationType>());

// Parse the result types.
SmallVector<OpAsmParser::OperandType, 4> opResultTypes;
if (succeeded(p.parseOptionalArrow())) {
if (p.parseOperandList(opResultTypes) ||
p.resolveOperands(opResultTypes, builder.getType<TypeType>(),
state.operands))
return failure();
state.types.append(opResultTypes.size(), builder.getType<ValueType>());
}

if (p.parseOptionalAttrDict(state.attributes))
return failure();

int32_t operandSegmentSizes[] = {static_cast<int32_t>(operands.size()),
static_cast<int32_t>(attrNames.size()),
static_cast<int32_t>(opResultTypes.size())};
state.addAttribute("operand_segment_sizes",
builder.getI32VectorAttr(operandSegmentSizes));
attrNamesAttr = builder.getArrayAttr(attrNames);
return success();
}

static void print(OpAsmPrinter &p, OperationOp op) {
p << "pdl.operation ";
if (Optional<StringRef> name = op.name())
p << '"' << *name << '"';

auto operandValues = op.operands();
if (!operandValues.empty())
p << '(' << operandValues << ')';

// Emit the optional attributes.
ArrayAttr attrNames = op.attributeNames();
if (!attrNames.empty()) {
Operation::operand_range attrArgs = op.attributes();
p << " {";
interleaveComma(llvm::seq<int>(0, attrNames.size()), p,
[&](int i) { p << attrNames[i] << " = " << attrArgs[i]; });
p << '}';
}

// Print the result type constraints of the operation.
if (!op.results().empty())
p << " -> " << op.types();
p.printOptionalAttrDict(op->getAttrs(),
{"attributeNames", "name", "operand_segment_sizes"});
static void printOperationOpAttributes(OpAsmPrinter &p, OperationOp op,
OperandRange attrArgs,
ArrayAttr attrNames) {
if (attrNames.empty())
return;
p << " {";
interleaveComma(llvm::seq<int>(0, attrNames.size()), p,
[&](int i) { p << attrNames[i] << " = " << attrArgs[i]; });
p << '}';
}

/// Verifies that the result types of this operation, defined within a
/// `pdl.rewrite`, can be inferred.
static LogicalResult verifyResultTypesAreInferrable(OperationOp op,
ResultRange opResults,
OperandRange resultTypes) {
// Functor that returns if the given use can be used to infer a type.
Block *rewriterBlock = op->getBlock();
Expand All @@ -207,36 +182,33 @@ static LogicalResult verifyResultTypesAreInferrable(OperationOp op,
return success();

// Otherwise, make sure each of the types can be inferred.
for (int i : llvm::seq<int>(0, opResults.size())) {
Operation *resultTypeOp = resultTypes[i].getDefiningOp();
for (auto it : llvm::enumerate(resultTypes)) {
Operation *resultTypeOp = it.value().getDefiningOp();
assert(resultTypeOp && "expected valid result type operation");

// If the op was defined by a `create_native`, it is guaranteed to be
// If the op was defined by a `apply_native_rewrite`, it is guaranteed to be
// usable.
if (isa<CreateNativeOp>(resultTypeOp))
continue;

// If the type is already constrained, there is nothing to do.
TypeOp typeOp = cast<TypeOp>(resultTypeOp);
if (typeOp.type())
if (isa<ApplyNativeRewriteOp>(resultTypeOp))
continue;

// If the type operation was defined in the matcher and constrains the
// result of an input operation, it can be used.
auto constrainsInputOp = [rewriterBlock](Operation *user) {
return user->getBlock() != rewriterBlock && isa<OperationOp>(user);
};
if (llvm::any_of(typeOp.getResult().getUsers(), constrainsInputOp))
continue;
if (TypeOp typeOp = dyn_cast<TypeOp>(resultTypeOp)) {
if (typeOp.type() || llvm::any_of(typeOp->getUsers(), constrainsInputOp))
continue;
} else if (TypesOp typeOp = dyn_cast<TypesOp>(resultTypeOp)) {
if (typeOp.types() || llvm::any_of(typeOp->getUsers(), constrainsInputOp))
continue;
}

// Otherwise, check to see if any uses of the result can infer the type.
if (llvm::any_of(opResults[i].getUses(), canInferTypeFromUse))
continue;
return op
.emitOpError("must have inferable or constrained result types when "
"nested within `pdl.rewrite`")
.attachNote()
.append("result type #", i, " was not constrained");
.append("result type #", it.index(), " was not constrained");
}
return success();
}
Expand All @@ -256,19 +228,10 @@ static LogicalResult verify(OperationOp op) {
<< " values";
}

OperandRange resultTypes = op.types();
auto opResults = op.results();
if (resultTypes.size() != opResults.size()) {
return op.emitOpError() << "expected the same number of result values and "
"result type constraints, got "
<< opResults.size() << " results and "
<< resultTypes.size() << " constraints";
}

// If the operation is within a rewrite body and doesn't have type inference,
// ensure that the result types can be resolved.
if (isWithinRewrite && !op.hasTypeInference()) {
if (failed(verifyResultTypesAreInferrable(op, opResults, resultTypes)))
if (failed(verifyResultTypesAreInferrable(op, op.types())))
return failure();
}

Expand Down Expand Up @@ -341,37 +304,39 @@ Optional<StringRef> PatternOp::getRootKind() {
//===----------------------------------------------------------------------===//

static LogicalResult verify(ReplaceOp op) {
auto sourceOp = cast<OperationOp>(op.operation().getDefiningOp());
auto sourceOpResults = sourceOp.results();
auto replValues = op.replValues();

if (Value replOpVal = op.replOperation()) {
auto replOp = cast<OperationOp>(replOpVal.getDefiningOp());
auto replOpResults = replOp.results();
if (sourceOpResults.size() != replOpResults.size()) {
return op.emitOpError()
<< "expected source operation to have the same number of results "
"as the replacement operation, replacement operation provided "
<< replOpResults.size() << " but expected "
<< sourceOpResults.size();
}
if (op.replOperation() && !op.replValues().empty())
return op.emitOpError() << "expected no replacement values to be provided"
" when the replacement operation is present";
return success();
}

if (!replValues.empty()) {
return op.emitOpError() << "expected no replacement values to be provided"
" when the replacement operation is present";
}
//===----------------------------------------------------------------------===//
// pdl::ResultsOp
//===----------------------------------------------------------------------===//

static ParseResult parseResultsValueType(OpAsmParser &p, IntegerAttr index,
Type &resultType) {
if (!index) {
resultType = RangeType::get(p.getBuilder().getType<ValueType>());
return success();
}
if (p.parseArrow() || p.parseType(resultType))
return failure();
return success();
}

if (sourceOpResults.size() != replValues.size()) {
return op.emitOpError()
<< "expected source operation to have the same number of results "
"as the provided replacement values, found "
<< replValues.size() << " replacement values but expected "
<< sourceOpResults.size();
}
static void printResultsValueType(OpAsmPrinter &p, ResultsOp op,
IntegerAttr index, Type resultType) {
if (index)
p << " -> " << resultType;
}

static LogicalResult verify(ResultsOp op) {
if (!op.index() && op.getType().isa<pdl::ValueType>()) {
return op.emitOpError() << "expected `pdl.range<value>` result type when "
"no index is specified, but got: "
<< op.getType();
}
return success();
}

Expand Down Expand Up @@ -419,6 +384,14 @@ static LogicalResult verify(TypeOp op) {
op, "`pdl.attribute`, `pdl.operand`, or `pdl.operation`");
}

//===----------------------------------------------------------------------===//
// pdl::TypesOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(TypesOp op) {
return verifyHasBindingUseInMatcher(op, "`pdl.operands`, or `pdl.operation`");
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
Expand Down
89 changes: 24 additions & 65 deletions mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,89 +29,48 @@ void PDLInterpDialect::initialize() {
// pdl_interp::CreateOperationOp
//===----------------------------------------------------------------------===//

static ParseResult parseCreateOperationOp(OpAsmParser &p,
OperationState &state) {
if (p.parseOptionalAttrDict(state.attributes))
return failure();
static ParseResult parseCreateOperationOpAttributes(
OpAsmParser &p, SmallVectorImpl<OpAsmParser::OperandType> &attrOperands,
ArrayAttr &attrNamesAttr) {
Builder &builder = p.getBuilder();

// Parse the operation name.
StringAttr opName;
if (p.parseAttribute(opName, "name", state.attributes))
return failure();

// Parse the operands.
SmallVector<OpAsmParser::OperandType, 4> operands;
if (p.parseLParen() || p.parseOperandList(operands) || p.parseRParen() ||
p.resolveOperands(operands, builder.getType<pdl::ValueType>(),
state.operands))
return failure();

// Parse the attributes.
SmallVector<Attribute, 4> attrNames;
if (succeeded(p.parseOptionalLBrace())) {
SmallVector<OpAsmParser::OperandType, 4> attrOps;
do {
StringAttr nameAttr;
OpAsmParser::OperandType operand;
if (p.parseAttribute(nameAttr) || p.parseEqual() ||
p.parseOperand(operand))
return failure();
attrNames.push_back(nameAttr);
attrOps.push_back(operand);
attrOperands.push_back(operand);
} while (succeeded(p.parseOptionalComma()));

if (p.parseRBrace() ||
p.resolveOperands(attrOps, builder.getType<pdl::AttributeType>(),
state.operands))
return failure();
}
state.addAttribute("attributeNames", builder.getArrayAttr(attrNames));
state.addTypes(builder.getType<pdl::OperationType>());

// Parse the result types.
SmallVector<OpAsmParser::OperandType, 4> opResultTypes;
if (p.parseArrow())
return failure();
if (succeeded(p.parseOptionalLParen())) {
if (p.parseRParen())
if (p.parseRBrace())
return failure();
} else if (p.parseOperandList(opResultTypes) ||
p.resolveOperands(opResultTypes, builder.getType<pdl::TypeType>(),
state.operands)) {
return failure();
}

int32_t operandSegmentSizes[] = {static_cast<int32_t>(operands.size()),
static_cast<int32_t>(attrNames.size()),
static_cast<int32_t>(opResultTypes.size())};
state.addAttribute("operand_segment_sizes",
builder.getI32VectorAttr(operandSegmentSizes));
attrNamesAttr = builder.getArrayAttr(attrNames);
return success();
}

static void print(OpAsmPrinter &p, CreateOperationOp op) {
p << "pdl_interp.create_operation ";
p.printOptionalAttrDict(op->getAttrs(),
{"attributeNames", "name", "operand_segment_sizes"});
p << '"' << op.name() << "\"(" << op.operands() << ')';
static void printCreateOperationOpAttributes(OpAsmPrinter &p,
CreateOperationOp op,
OperandRange attrArgs,
ArrayAttr attrNames) {
if (attrNames.empty())
return;
p << " {";
interleaveComma(llvm::seq<int>(0, attrNames.size()), p,
[&](int i) { p << attrNames[i] << " = " << attrArgs[i]; });
p << '}';
}

// Emit the optional attributes.
ArrayAttr attrNames = op.attributeNames();
if (!attrNames.empty()) {
Operation::operand_range attrArgs = op.attributes();
p << " {";
interleaveComma(llvm::seq<int>(0, attrNames.size()), p,
[&](int i) { p << attrNames[i] << " = " << attrArgs[i]; });
p << '}';
}
//===----------------------------------------------------------------------===//
// pdl_interp::GetValueTypeOp
//===----------------------------------------------------------------------===//

// Print the result type constraints of the operation.
auto types = op.types();
if (types.empty())
p << " -> ()";
else
p << " -> " << op.types();
/// Given the result type of a `GetValueTypeOp`, return the expected input type.
static Type getGetValueTypeOpValueType(Type type) {
Type valueTy = pdl::ValueType::get(type.getContext());
return type.isa<pdl::RangeType>() ? pdl::RangeType::get(valueTy) : valueTy;
}

//===----------------------------------------------------------------------===//
Expand Down
46 changes: 23 additions & 23 deletions mlir/lib/IR/PatternMatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,22 +73,31 @@ void RewritePattern::anchor() {}
// PDLValue
//===----------------------------------------------------------------------===//

void PDLValue::print(raw_ostream &os) {
if (!impl) {
os << "<Null-PDLValue>";
void PDLValue::print(raw_ostream &os) const {
if (!value) {
os << "<NULL-PDLValue>";
return;
}
if (Value val = impl.dyn_cast<Value>()) {
os << val;
return;
switch (kind) {
case Kind::Attribute:
os << cast<Attribute>();
break;
case Kind::Operation:
os << *cast<Operation *>();
break;
case Kind::Type:
os << cast<Type>();
break;
case Kind::TypeRange:
llvm::interleaveComma(cast<TypeRange>(), os);
break;
case Kind::Value:
os << cast<Value>();
break;
case Kind::ValueRange:
llvm::interleaveComma(cast<ValueRange>(), os);
break;
}
AttrOpTypeImplT aotImpl = impl.get<AttrOpTypeImplT>();
if (Attribute attr = aotImpl.dyn_cast<Attribute>())
os << attr;
else if (Operation *op = aotImpl.dyn_cast<Operation *>())
os << *op;
else
os << aotImpl.get<Type>();
}

//===----------------------------------------------------------------------===//
Expand All @@ -102,16 +111,13 @@ void PDLPatternModule::mergeIn(PDLPatternModule &&other) {
// Steal the other state if we have no patterns.
if (!pdlModule) {
constraintFunctions = std::move(other.constraintFunctions);
createFunctions = std::move(other.createFunctions);
rewriteFunctions = std::move(other.rewriteFunctions);
pdlModule = std::move(other.pdlModule);
return;
}
// Steal the functions of the other module.
for (auto &it : constraintFunctions)
registerConstraintFunction(it.first(), std::move(it.second));
for (auto &it : createFunctions)
registerCreateFunction(it.first(), std::move(it.second));
for (auto &it : rewriteFunctions)
registerRewriteFunction(it.first(), std::move(it.second));

Expand All @@ -132,13 +138,7 @@ void PDLPatternModule::registerConstraintFunction(
assert(it.second &&
"constraint with the given name has already been registered");
}
void PDLPatternModule::registerCreateFunction(StringRef name,
PDLCreateFunction createFn) {
auto it = createFunctions.try_emplace(name, std::move(createFn));
(void)it;
assert(it.second && "native create function with the given name has "
"already been registered");
}

void PDLPatternModule::registerRewriteFunction(StringRef name,
PDLRewriteFunction rewriteFn) {
auto it = rewriteFunctions.try_emplace(name, std::move(rewriteFn));
Expand Down
822 changes: 663 additions & 159 deletions mlir/lib/Rewrite/ByteCode.cpp

Large diffs are not rendered by default.

41 changes: 33 additions & 8 deletions mlir/lib/Rewrite/ByteCode.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ namespace detail {
class PDLByteCode;

/// Use generic bytecode types. ByteCodeField refers to the actual bytecode
/// entries (set to uint8_t for "byte" bytecode). ByteCodeAddr refers to size of
/// indices into the bytecode. Correctness is checked with static asserts.
/// entries. ByteCodeAddr refers to size of indices into the bytecode.
using ByteCodeField = uint16_t;
using ByteCodeAddr = uint32_t;

Expand Down Expand Up @@ -62,14 +61,16 @@ class PDLByteCodePattern : public Pattern {
/// threads/drivers.
class PDLByteCodeMutableState {
public:
/// Initialize the state from a bytecode instance.
void initialize(PDLByteCode &bytecode);

/// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
/// to the position of the pattern within the range returned by
/// `PDLByteCode::getPatterns`.
void updatePatternBenefit(unsigned patternIndex, PatternBenefit benefit);

/// Cleanup any allocated state after a match/rewrite has been completed. This
/// method should be called irregardless of whether the match+rewrite was a
/// success or not.
void cleanupAfterMatchAndRewrite();

private:
/// Allow access to data fields.
friend class PDLByteCode;
Expand All @@ -78,6 +79,20 @@ class PDLByteCodeMutableState {
/// of the bytecode.
std::vector<const void *> memory;

/// A mutable block of memory used during the matching and rewriting phase of
/// the bytecode to store ranges of types.
std::vector<TypeRange> typeRangeMemory;
/// A set of type ranges that have been allocated by the byte code interpreter
/// to provide a guaranteed lifetime.
std::vector<llvm::OwningArrayRef<Type>> allocatedTypeRangeMemory;

/// A mutable block of memory used during the matching and rewriting phase of
/// the bytecode to store ranges of values.
std::vector<ValueRange> valueRangeMemory;
/// A set of value ranges that have been allocated by the byte code
/// interpreter to provide a guaranteed lifetime.
std::vector<llvm::OwningArrayRef<Value>> allocatedValueRangeMemory;

/// The up-to-date benefits of the patterns held by the bytecode. The order
/// of this array corresponds 1-1 with the array of patterns in `PDLByteCode`.
std::vector<PatternBenefit> currentPatternBenefits;
Expand All @@ -98,11 +113,19 @@ class PDLByteCode {
MatchResult(Location loc, const PDLByteCodePattern &pattern,
PatternBenefit benefit)
: location(loc), pattern(&pattern), benefit(benefit) {}
MatchResult(const MatchResult &) = delete;
MatchResult &operator=(const MatchResult &) = delete;
MatchResult(MatchResult &&other) = default;
MatchResult &operator=(MatchResult &&) = default;

/// The location of operations to be replaced.
Location location;
/// Memory values defined in the matcher that are passed to the rewriter.
SmallVector<const void *, 4> values;
SmallVector<const void *> values;
/// Memory used for the range input values.
SmallVector<TypeRange, 0> typeRangeValues;
SmallVector<ValueRange, 0> valueRangeValues;

/// The originating pattern that was matched. This is always non-null, but
/// represented with a pointer to allow for assignment.
const PDLByteCodePattern *pattern;
Expand All @@ -114,7 +137,6 @@ class PDLByteCode {
/// the PDL interpreter dialect.
PDLByteCode(ModuleOp module,
llvm::StringMap<PDLConstraintFunction> constraintFns,
llvm::StringMap<PDLCreateFunction> createFns,
llvm::StringMap<PDLRewriteFunction> rewriteFns);

/// Return the patterns held by the bytecode.
Expand Down Expand Up @@ -160,11 +182,14 @@ class PDLByteCode {

/// A set of user defined functions invoked via PDL.
std::vector<PDLConstraintFunction> constraintFunctions;
std::vector<PDLCreateFunction> createFunctions;
std::vector<PDLRewriteFunction> rewriteFunctions;

/// The maximum memory index used by a value.
ByteCodeField maxValueMemoryIndex = 0;

/// The maximum number of different types of ranges.
ByteCodeField maxTypeRangeCount = 0;
ByteCodeField maxValueRangeCount = 0;
};

} // end namespace detail
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Rewrite/FrozenRewritePatternList.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ FrozenRewritePatternList::FrozenRewritePatternList(
// Generate the pdl bytecode.
impl->pdlByteCode = std::make_unique<detail::PDLByteCode>(
pdlModule, pdlPatterns.takeConstraintFunctions(),
pdlPatterns.takeCreateFunctions(), pdlPatterns.takeRewriteFunctions());
pdlPatterns.takeRewriteFunctions());
}

FrozenRewritePatternList::~FrozenRewritePatternList() {}
57 changes: 37 additions & 20 deletions mlir/lib/Rewrite/PatternApplicator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,29 +129,40 @@ LogicalResult PatternApplicator::matchAndRewrite(

// Process the patterns for that match the specific operation type, and any
// operation type in an interleaved fashion.
auto opIt = opPatterns.begin(), opE = opPatterns.end();
auto anyIt = anyOpPatterns.begin(), anyE = anyOpPatterns.end();
auto pdlIt = pdlMatches.begin(), pdlE = pdlMatches.end();
while (true) {
unsigned opIt = 0, opE = opPatterns.size();
unsigned anyIt = 0, anyE = anyOpPatterns.size();
unsigned pdlIt = 0, pdlE = pdlMatches.size();
LogicalResult result = failure();
do {
// Find the next pattern with the highest benefit.
const Pattern *bestPattern = nullptr;
unsigned *bestPatternIt = &opIt;
const PDLByteCode::MatchResult *pdlMatch = nullptr;

/// Operation specific patterns.
if (opIt != opE)
bestPattern = *(opIt++);
if (opIt < opE)
bestPattern = opPatterns[opIt];
/// Operation agnostic patterns.
if (anyIt != anyE &&
(!bestPattern || bestPattern->getBenefit() < (*anyIt)->getBenefit()))
bestPattern = *(anyIt++);
if (anyIt < anyE &&
(!bestPattern ||
bestPattern->getBenefit() < anyOpPatterns[anyIt]->getBenefit())) {
bestPatternIt = &anyIt;
bestPattern = anyOpPatterns[anyIt];
}
/// PDL patterns.
if (pdlIt != pdlE &&
(!bestPattern || bestPattern->getBenefit() < pdlIt->benefit)) {
pdlMatch = pdlIt;
bestPattern = (pdlIt++)->pattern;
if (pdlIt < pdlE && (!bestPattern || bestPattern->getBenefit() <
pdlMatches[pdlIt].benefit)) {
bestPatternIt = &pdlIt;
pdlMatch = &pdlMatches[pdlIt];
bestPattern = pdlMatch->pattern;
}
if (!bestPattern)
break;

// Update the pattern iterator on failure so that this pattern isn't
// attempted again.
++(*bestPatternIt);

// Check that the pattern can be applied.
if (canApply && !canApply(*bestPattern))
continue;
Expand All @@ -160,19 +171,25 @@ LogicalResult PatternApplicator::matchAndRewrite(
// benefit, so if we match we can immediately rewrite. For PDL patterns, the
// match has already been performed, we just need to rewrite.
rewriter.setInsertionPoint(op);
LogicalResult result = success();
if (pdlMatch) {
bytecode->rewrite(rewriter, *pdlMatch, *mutableByteCodeState);
result = success(!onSuccess || succeeded(onSuccess(*bestPattern)));

} else {
result = static_cast<const RewritePattern *>(bestPattern)
->matchAndRewrite(op, rewriter);
const auto *pattern = static_cast<const RewritePattern *>(bestPattern);
result = pattern->matchAndRewrite(op, rewriter);
if (succeeded(result) && onSuccess && failed(onSuccess(*pattern)))
result = failure();
}
if (succeeded(result) && (!onSuccess || succeeded(onSuccess(*bestPattern))))
return success();
if (succeeded(result))
break;

// Perform any necessary cleanups.
if (onFailure)
onFailure(*bestPattern);
}
return failure();
} while (true);

if (mutableByteCodeState)
mutableByteCodeState->cleanupAfterMatchAndRewrite();
return result;
}
39 changes: 23 additions & 16 deletions mlir/lib/TableGen/Predicate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,23 @@ namespace {
using Subst = std::pair<StringRef, StringRef>;
} // end anonymous namespace

/// Perform the given substitutions on 'str' in-place.
static void performSubstitutions(std::string &str,
ArrayRef<Subst> substitutions) {
// Apply all parent substitutions from innermost to outermost.
for (const auto &subst : llvm::reverse(substitutions)) {
auto pos = str.find(std::string(subst.first));
while (pos != std::string::npos) {
str.replace(pos, subst.first.size(), std::string(subst.second));
// Skip the newly inserted substring, which itself may consider the
// pattern to match.
pos += subst.second.size();
// Find the next possible match position.
pos = str.find(std::string(subst.first), pos);
}
}
}

// Build the predicate tree starting from the top-level predicate, which may
// have children, and perform leaf substitutions inplace. Note that after
// substitution, nodes are still pointing to the original TableGen record.
Expand All @@ -147,19 +164,7 @@ buildPredicateTree(const Pred &root,
rootNode->predicate = &root;
if (!root.isCombined()) {
rootNode->expr = root.getCondition();
// Apply all parent substitutions from innermost to outermost.
for (const auto &subst : llvm::reverse(substitutions)) {
auto pos = rootNode->expr.find(std::string(subst.first));
while (pos != std::string::npos) {
rootNode->expr.replace(pos, subst.first.size(),
std::string(subst.second));
// Skip the newly inserted substring, which itself may consider the
// pattern to match.
pos += subst.second.size();
// Find the next possible match position.
pos = rootNode->expr.find(std::string(subst.first), pos);
}
}
performSubstitutions(rootNode->expr, substitutions);
return rootNode;
}

Expand All @@ -170,12 +175,14 @@ buildPredicateTree(const Pred &root,
const auto &substPred = static_cast<const SubstLeavesPred &>(root);
allSubstitutions.push_back(
{substPred.getPattern(), substPred.getReplacement()});
}
// If the current predicate is a ConcatPred, record the prefix and suffix.
else if (rootNode->kind == PredCombinerKind::Concat) {

// If the current predicate is a ConcatPred, record the prefix and suffix.
} else if (rootNode->kind == PredCombinerKind::Concat) {
const auto &concatPred = static_cast<const ConcatPred &>(root);
rootNode->prefix = std::string(concatPred.getPrefix());
performSubstitutions(rootNode->prefix, substitutions);
rootNode->suffix = std::string(concatPred.getSuffix());
performSubstitutions(rootNode->suffix, substitutions);
}

// Build child subtrees.
Expand Down
275 changes: 256 additions & 19 deletions mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ module @simple {

// CHECK: module @rewriters
// CHECK: func @pdl_generated_rewriter(%[[REWRITE_ROOT:.*]]: !pdl.operation)
// CHECK: pdl_interp.apply_rewrite "rewriter" on %[[REWRITE_ROOT]]
// CHECK: pdl_interp.apply_rewrite "rewriter"(%[[REWRITE_ROOT]]
// CHECK: pdl_interp.finalize
pdl.pattern : benefit(1) {
%root = pdl.operation "foo.op"()
%root = pdl.operation "foo.op"
pdl.rewrite %root with "rewriter"
}
}
Expand Down Expand Up @@ -63,15 +63,16 @@ module @constraints {
// CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
// CHECK-DAG: %[[INPUT:.*]] = pdl_interp.get_operand 0 of %[[ROOT]]
// CHECK-DAG: %[[INPUT1:.*]] = pdl_interp.get_operand 1 of %[[ROOT]]
// CHECK: pdl_interp.apply_constraint "multi_constraint" [true](%[[INPUT]], %[[INPUT1]] : !pdl.value, !pdl.value)
// CHECK-DAG: %[[RESULT:.*]] = pdl_interp.get_result 0 of %[[ROOT]]
// CHECK: pdl_interp.apply_constraint "multi_constraint" [true](%[[INPUT]], %[[INPUT1]], %[[RESULT]]

pdl.pattern : benefit(1) {
%input0 = pdl.operand
%input1 = pdl.operand
%root = pdl.operation(%input0, %input1 : !pdl.value, !pdl.value)
%result0 = pdl.result 0 of %root

pdl.apply_constraint "multi_constraint"[true](%input0, %input1 : !pdl.value, !pdl.value)

%root = pdl.operation(%input0, %input1)
pdl.apply_native_constraint "multi_constraint"[true](%input0, %input1, %result0 : !pdl.value, !pdl.value, !pdl.value)
pdl.rewrite %root with "rewriter"
}
}
Expand All @@ -95,7 +96,60 @@ module @inputs {
pdl.pattern : benefit(1) {
%type = pdl.type : i64
%input = pdl.operand : %type
%root = pdl.operation(%input, %input)
%root = pdl.operation(%input, %input : !pdl.value, !pdl.value)
pdl.rewrite %root with "rewriter"
}
}

// -----

// CHECK-LABEL: module @variadic_inputs
module @variadic_inputs {
// CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
// CHECK-DAG: pdl_interp.check_operand_count of %[[ROOT]] is at_least 2

// The first operand has a known index.
// CHECK-DAG: %[[INPUT:.*]] = pdl_interp.get_operand 0 of %[[ROOT]]
// CHECK-DAG: pdl_interp.is_not_null %[[INPUT]] : !pdl.value

// The second operand is a group of unknown size, with a type constraint.
// CHECK-DAG: %[[VAR_INPUTS:.*]] = pdl_interp.get_operands 1 of %[[ROOT]] : !pdl.range<value>
// CHECK-DAG: pdl_interp.is_not_null %[[VAR_INPUTS]] : !pdl.range<value>

// CHECK-DAG: %[[INPUT_TYPE:.*]] = pdl_interp.get_value_type of %[[VAR_INPUTS]] : !pdl.range<type>
// CHECK-DAG: pdl_interp.check_types %[[INPUT_TYPE]] are [i64]

// The third operand is at an unknown offset due to operand 2, but is expected
// to be of size 1.
// CHECK-DAG: %[[INPUT2:.*]] = pdl_interp.get_operands 2 of %[[ROOT]] : !pdl.value
// CHECK-DAG: pdl_interp.are_equal %[[INPUT]], %[[INPUT2]] : !pdl.value
pdl.pattern : benefit(1) {
%types = pdl.types : [i64]
%inputs = pdl.operands : %types
%input = pdl.operand
%root = pdl.operation(%input, %inputs, %input : !pdl.value, !pdl.range<value>, !pdl.value)
pdl.rewrite %root with "rewriter"
}
}

// -----

// CHECK-LABEL: module @single_operand_range
module @single_operand_range {
// CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)

// Check that the operand range is treated as all of the operands of the
// operation.
// CHECK-DAG: %[[RESULTS:.*]] = pdl_interp.get_operands of %[[ROOT]]
// CHECK-DAG: %[[RESULT_TYPES:.*]] = pdl_interp.get_value_type of %[[RESULTS]] : !pdl.range<type>
// CHECK-DAG: pdl_interp.check_types %[[RESULT_TYPES]] are [i64]

// The operand count is unknown, so there is no need to check for it.
// CHECK-NOT: pdl_interp.check_operand_count
pdl.pattern : benefit(1) {
%types = pdl.types : [i64]
%operands = pdl.operands : %types
%root = pdl.operation(%operands : !pdl.range<value>)
pdl.rewrite %root with "rewriter"
}
}
Expand All @@ -107,43 +161,226 @@ module @results {
// CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
// CHECK: pdl_interp.check_result_count of %[[ROOT]] is 2

// Get the input and check the type.
// Get the result and check the type.
// CHECK-DAG: %[[RESULT:.*]] = pdl_interp.get_result 0 of %[[ROOT]]
// CHECK-DAG: pdl_interp.is_not_null %[[RESULT]] : !pdl.value
// CHECK-DAG: %[[RESULT_TYPE:.*]] = pdl_interp.get_value_type of %[[RESULT]]
// CHECK-DAG: pdl_interp.check_type %[[RESULT_TYPE]] is i32

// Get the second operand and check that it is equal to the first.
// CHECK-DAG: %[[RESULT1:.*]] = pdl_interp.get_result 1 of %[[ROOT]]
// CHECK-NOT: pdl_interp.get_value_type of %[[RESULT1]]
// The second result doesn't have any constraints, so we don't generate an
// access for it.
// CHECK-NOT: pdl_interp.get_result 1 of %[[ROOT]]
pdl.pattern : benefit(1) {
%type1 = pdl.type : i32
%type2 = pdl.type
%root, %results:2 = pdl.operation -> %type1, %type2
%root = pdl.operation -> (%type1, %type2 : !pdl.type, !pdl.type)
pdl.rewrite %root with "rewriter"
}
}

// -----

// CHECK-LABEL: module @switch_result_types
module @switch_result_types {
// CHECK-LABEL: module @variadic_results
module @variadic_results {
// CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
// CHECK-DAG: pdl_interp.check_result_count of %[[ROOT]] is at_least 2

// The first result has a known index.
// CHECK-DAG: %[[RESULT:.*]] = pdl_interp.get_result 0 of %[[ROOT]]
// CHECK-DAG: pdl_interp.is_not_null %[[RESULT]] : !pdl.value

// The second result is a group of unknown size, with a type constraint.
// CHECK-DAG: %[[VAR_RESULTS:.*]] = pdl_interp.get_results 1 of %[[ROOT]] : !pdl.range<value>
// CHECK-DAG: pdl_interp.is_not_null %[[VAR_RESULTS]] : !pdl.range<value>

// CHECK-DAG: %[[RESULT_TYPE:.*]] = pdl_interp.get_value_type of %[[VAR_RESULTS]] : !pdl.range<type>
// CHECK-DAG: pdl_interp.check_types %[[RESULT_TYPE]] are [i64]

// The third result is at an unknown offset due to result 1, but is expected
// to be of size 1.
// CHECK-DAG: %[[RESULT2:.*]] = pdl_interp.get_results 2 of %[[ROOT]] : !pdl.value
// CHECK-DAG: pdl_interp.is_not_null %[[RESULT2]] : !pdl.value
pdl.pattern : benefit(1) {
%types = pdl.types : [i64]
%type = pdl.type
%root = pdl.operation -> (%type, %types, %type : !pdl.type, !pdl.range<type>, !pdl.type)
pdl.rewrite %root with "rewriter"
}
}

// -----

// CHECK-LABEL: module @single_result_range
module @single_result_range {
// CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)

// Check that the result range is treated as all of the results of the
// operation.
// CHECK-DAG: %[[RESULTS:.*]] = pdl_interp.get_results of %[[ROOT]]
// CHECK-DAG: %[[RESULT_TYPES:.*]] = pdl_interp.get_value_type of %[[RESULTS]] : !pdl.range<type>
// CHECK-DAG: pdl_interp.check_types %[[RESULT_TYPES]] are [i64]

// The result count is unknown, so there is no need to check for it.
// CHECK-NOT: pdl_interp.check_result_count
pdl.pattern : benefit(1) {
%types = pdl.types : [i64]
%root = pdl.operation -> (%types : !pdl.range<type>)
pdl.rewrite %root with "rewriter"
}
}

// -----

// CHECK-LABEL: module @results_as_operands
module @results_as_operands {
// CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)

// Get the first result and check it matches the first operand.
// CHECK-DAG: %[[OPERAND_0:.*]] = pdl_interp.get_operand 0 of %[[ROOT]]
// CHECK-DAG: %[[DEF_OP_0:.*]] = pdl_interp.get_defining_op of %[[OPERAND_0]]
// CHECK-DAG: %[[RESULT_0:.*]] = pdl_interp.get_result 0 of %[[DEF_OP_0]]
// CHECK-DAG: pdl_interp.are_equal %[[RESULT_0]], %[[OPERAND_0]]

// Get the second result and check it matches the second operand.
// CHECK-DAG: %[[OPERAND_1:.*]] = pdl_interp.get_operand 1 of %[[ROOT]]
// CHECK-DAG: %[[DEF_OP_1:.*]] = pdl_interp.get_defining_op of %[[OPERAND_1]]
// CHECK-DAG: %[[RESULT_1:.*]] = pdl_interp.get_result 1 of %[[DEF_OP_1]]
// CHECK-DAG: pdl_interp.are_equal %[[RESULT_1]], %[[OPERAND_1]]

// Check that the parent operation of both results is the same.
// CHECK-DAG: pdl_interp.are_equal %[[DEF_OP_0]], %[[DEF_OP_1]]

pdl.pattern : benefit(1) {
%type1 = pdl.type : i32
%type2 = pdl.type
%inputOp = pdl.operation -> (%type1, %type2 : !pdl.type, !pdl.type)
%result1 = pdl.result 0 of %inputOp
%result2 = pdl.result 1 of %inputOp

%root = pdl.operation(%result1, %result2 : !pdl.value, !pdl.value)
pdl.rewrite %root with "rewriter"
}
}

// -----

// CHECK-LABEL: module @single_result_range_as_operands
module @single_result_range_as_operands {
// CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
// CHECK-DAG: %[[OPERANDS:.*]] = pdl_interp.get_operands of %[[ROOT]] : !pdl.range<value>
// CHECK-DAG: %[[OP:.*]] = pdl_interp.get_defining_op of %[[OPERANDS]] : !pdl.range<value>
// CHECK-DAG: pdl_interp.is_not_null %[[OP]]
// CHECK-DAG: %[[RESULTS:.*]] = pdl_interp.get_results of %[[OP]] : !pdl.range<value>
// CHECK-DAG: pdl_interp.are_equal %[[RESULTS]], %[[OPERANDS]] : !pdl.range<value>

pdl.pattern : benefit(1) {
%types = pdl.types
%inputOp = pdl.operation -> (%types : !pdl.range<type>)
%results = pdl.results of %inputOp

%root = pdl.operation(%results : !pdl.range<value>)
pdl.rewrite %root with "rewriter"
}
}

// -----

// CHECK-LABEL: module @switch_single_result_type
module @switch_single_result_type {
// CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
// CHECK: %[[RESULT:.*]] = pdl_interp.get_result 0 of %[[ROOT]]
// CHECK: %[[RESULT_TYPE:.*]] = pdl_interp.get_value_type of %[[RESULT]]
// CHECK: pdl_interp.switch_type %[[RESULT_TYPE]] to [i32, i64]
pdl.pattern : benefit(1) {
%type = pdl.type : i32
%root, %result = pdl.operation -> %type
%root = pdl.operation -> (%type : !pdl.type)
pdl.rewrite %root with "rewriter"
}
pdl.pattern : benefit(1) {
%type = pdl.type : i64
%root, %result = pdl.operation -> %type
%root = pdl.operation -> (%type : !pdl.type)
pdl.rewrite %root with "rewriter"
}
}

// -----

// CHECK-LABEL: module @switch_result_types
module @switch_result_types {
// CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
// CHECK: %[[RESULTS:.*]] = pdl_interp.get_results of %[[ROOT]]
// CHECK: %[[RESULT_TYPES:.*]] = pdl_interp.get_value_type of %[[RESULTS]]
// CHECK: pdl_interp.switch_types %[[RESULT_TYPES]] to {{\[\[}}i32], [i64, i32]]
pdl.pattern : benefit(1) {
%types = pdl.types : [i32]
%root = pdl.operation -> (%types : !pdl.range<type>)
pdl.rewrite %root with "rewriter"
}
pdl.pattern : benefit(1) {
%types = pdl.types : [i64, i32]
%root = pdl.operation -> (%types : !pdl.range<type>)
pdl.rewrite %root with "rewriter"
}
}

// -----

// CHECK-LABEL: module @switch_operand_count_at_least
module @switch_operand_count_at_least {
// Check that when there are multiple "at_least" checks, the failure branch
// goes to the next one in increasing order.

// CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
// CHECK: pdl_interp.check_operand_count of %[[ROOT]] is at_least 1 -> ^[[PATTERN_1_NEXT_BLOCK:.*]],
// CHECK: ^bb2:
// CHECK-NEXT: pdl_interp.check_operand_count of %[[ROOT]] is at_least 2
// CHECK: ^[[PATTERN_1_NEXT_BLOCK]]:
// CHECK-NEXT: {{.*}} -> ^{{.*}}, ^bb2
pdl.pattern : benefit(1) {
%operand = pdl.operand
%operands = pdl.operands
%root = pdl.operation(%operand, %operands : !pdl.value, !pdl.range<value>)
pdl.rewrite %root with "rewriter"
}
pdl.pattern : benefit(1) {
%operand = pdl.operand
%operand2 = pdl.operand
%operands = pdl.operands
%root = pdl.operation(%operand, %operand2, %operands : !pdl.value, !pdl.value, !pdl.range<value>)
pdl.rewrite %root with "rewriter"
}
}

// -----

// CHECK-LABEL: module @switch_result_count_at_least
module @switch_result_count_at_least {
// Check that when there are multiple "at_least" checks, the failure branch
// goes to the next one in increasing order.

// CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
// CHECK: pdl_interp.check_result_count of %[[ROOT]] is at_least 1 -> ^[[PATTERN_1_NEXT_BLOCK:.*]],
// CHECK: ^[[PATTERN_2_BLOCK:[a-zA-Z_0-9]*]]:
// CHECK: pdl_interp.check_result_count of %[[ROOT]] is at_least 2
// CHECK: ^[[PATTERN_1_NEXT_BLOCK]]:
// CHECK-NEXT: pdl_interp.get_result
// CHECK-NEXT: pdl_interp.is_not_null {{.*}} -> ^{{.*}}, ^[[PATTERN_2_BLOCK]]
pdl.pattern : benefit(1) {
%type = pdl.type
%types = pdl.types
%root = pdl.operation -> (%type, %types : !pdl.type, !pdl.range<type>)
pdl.rewrite %root with "rewriter"
}
pdl.pattern : benefit(1) {
%type = pdl.type
%type2 = pdl.type
%types = pdl.types
%root = pdl.operation -> (%type, %type2, %types : !pdl.type, !pdl.type, !pdl.range<type>)
pdl.rewrite %root with "rewriter"
}
}


// -----

// CHECK-LABEL: module @predicate_ordering
Expand All @@ -160,14 +397,14 @@ module @predicate_ordering {

pdl.pattern : benefit(1) {
%resultType = pdl.type
pdl.apply_constraint "typeConstraint"[](%resultType : !pdl.type)
%root, %result = pdl.operation -> %resultType
pdl.apply_native_constraint "typeConstraint"[](%resultType : !pdl.type)
%root = pdl.operation -> (%resultType : !pdl.type)
pdl.rewrite %root with "rewriter"
}

pdl.pattern : benefit(1) {
%resultType = pdl.type
%apply, %applyRes = pdl.operation -> %resultType
%apply = pdl.operation -> (%resultType : !pdl.type)
pdl.rewrite %apply with "rewriter"
}
}
116 changes: 70 additions & 46 deletions mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
module @external {
// CHECK: module @rewriters
// CHECK: func @pdl_generated_rewriter(%[[ROOT:.*]]: !pdl.operation, %[[INPUT:.*]]: !pdl.value)
// CHECK: pdl_interp.apply_rewrite "rewriter" [true](%[[INPUT]] : !pdl.value) on %[[ROOT]]
// CHECK: pdl_interp.apply_rewrite "rewriter" [true](%[[ROOT]], %[[INPUT]] : !pdl.operation, !pdl.value)
pdl.pattern : benefit(1) {
%input = pdl.operand
%root = pdl.operation "foo.op"(%input)
%root = pdl.operation "foo.op"(%input : !pdl.value)
pdl.rewrite %root with "rewriter"[true](%input : !pdl.value)
}
}
Expand Down Expand Up @@ -37,7 +37,7 @@ module @operation_attributes {
// CHECK: module @rewriters
// CHECK: func @pdl_generated_rewriter(%[[ATTR:.*]]: !pdl.attribute, %[[ROOT:.*]]: !pdl.operation)
// CHECK: %[[ATTR1:.*]] = pdl_interp.create_attribute true
// CHECK: pdl_interp.create_operation "foo.op"() {"attr" = %[[ATTR]], "attr1" = %[[ATTR1]]}
// CHECK: pdl_interp.create_operation "foo.op" {"attr" = %[[ATTR]], "attr1" = %[[ATTR1]]}
pdl.pattern : benefit(1) {
%attr = pdl.attribute
%root = pdl.operation "foo.op" {"attr" = %attr}
Expand All @@ -55,16 +55,17 @@ module @operation_attributes {
module @operation_operands {
// CHECK: module @rewriters
// CHECK: func @pdl_generated_rewriter(%[[OPERAND:.*]]: !pdl.value, %[[ROOT:.*]]: !pdl.operation)
// CHECK: %[[NEWOP:.*]] = pdl_interp.create_operation "foo.op"(%[[OPERAND]])
// CHECK: %[[NEWOP:.*]] = pdl_interp.create_operation "foo.op"(%[[OPERAND]] : !pdl.value)
// CHECK: %[[OPERAND1:.*]] = pdl_interp.get_result 0 of %[[NEWOP]]
// CHECK: pdl_interp.create_operation "foo.op2"(%[[OPERAND1]])
// CHECK: pdl_interp.create_operation "foo.op2"(%[[OPERAND1]] : !pdl.value)
pdl.pattern : benefit(1) {
%operand = pdl.operand
%root = pdl.operation "foo.op"(%operand)
%root = pdl.operation "foo.op"(%operand : !pdl.value)
pdl.rewrite %root {
%type = pdl.type : i32
%newOp, %result = pdl.operation "foo.op"(%operand) -> %type
%newOp1 = pdl.operation "foo.op2"(%result)
%newOp = pdl.operation "foo.op"(%operand : !pdl.value) -> (%type : !pdl.type)
%result = pdl.result 0 of %newOp
%newOp1 = pdl.operation "foo.op2"(%result : !pdl.value)
pdl.erase %root
}
}
Expand All @@ -76,71 +77,90 @@ module @operation_operands {
module @operation_operands {
// CHECK: module @rewriters
// CHECK: func @pdl_generated_rewriter(%[[OPERAND:.*]]: !pdl.value, %[[ROOT:.*]]: !pdl.operation)
// CHECK: %[[NEWOP:.*]] = pdl_interp.create_operation "foo.op"(%[[OPERAND]])
// CHECK: %[[NEWOP:.*]] = pdl_interp.create_operation "foo.op"(%[[OPERAND]] : !pdl.value)
// CHECK: %[[OPERAND1:.*]] = pdl_interp.get_result 0 of %[[NEWOP]]
// CHECK: pdl_interp.create_operation "foo.op2"(%[[OPERAND1]])
// CHECK: pdl_interp.create_operation "foo.op2"(%[[OPERAND1]] : !pdl.value)
pdl.pattern : benefit(1) {
%operand = pdl.operand
%root = pdl.operation "foo.op"(%operand)
%root = pdl.operation "foo.op"(%operand : !pdl.value)
pdl.rewrite %root {
%type = pdl.type : i32
%newOp, %result = pdl.operation "foo.op"(%operand) -> %type
%newOp1 = pdl.operation "foo.op2"(%result)
%newOp = pdl.operation "foo.op"(%operand : !pdl.value) -> (%type : !pdl.type)
%result = pdl.result 0 of %newOp
%newOp1 = pdl.operation "foo.op2"(%result : !pdl.value)
pdl.erase %root
}
}
}

// -----

// CHECK-LABEL: module @operation_result_types
module @operation_result_types {
// CHECK-LABEL: module @operation_infer_types_from_replaceop
module @operation_infer_types_from_replaceop {
// CHECK: module @rewriters
// CHECK: func @pdl_generated_rewriter(%[[TYPE:.*]]: !pdl.type, %[[TYPE1:.*]]: !pdl.type
// CHECK: pdl_interp.create_operation "foo.op"() -> %[[TYPE]], %[[TYPE1]]
// CHECK: func @pdl_generated_rewriter(%[[ROOT:.*]]: !pdl.operation
// CHECK: %[[RESULTS:.*]] = pdl_interp.get_results of %[[ROOT]]
// CHECK: %[[RESULT_TYPES:.*]] = pdl_interp.get_value_type of %[[RESULTS]]
// CHECK: pdl_interp.create_operation "foo.op" -> (%[[RESULT_TYPES]] : !pdl.range<type>)
pdl.pattern : benefit(1) {
%rootType = pdl.type
%rootType1 = pdl.type
%root, %results:2 = pdl.operation "foo.op" -> %rootType, %rootType1
%root = pdl.operation "foo.op" -> (%rootType, %rootType1 : !pdl.type, !pdl.type)
pdl.rewrite %root {
%newType1 = pdl.type
%newOp, %newResults:2 = pdl.operation "foo.op" -> %rootType, %newType1
%newOp = pdl.operation "foo.op" -> (%rootType, %newType1 : !pdl.type, !pdl.type)
pdl.replace %root with %newOp
}
}
}

// -----

// CHECK-LABEL: module @operation_result_types_infer_from_value_replacement
module @operation_result_types_infer_from_value_replacement {
// CHECK-LABEL: module @operation_infer_types_from_otherop_individual_results
module @operation_infer_types_from_otherop_individual_results {
// CHECK: module @rewriters
// CHECK: func @pdl_generated_rewriter(%[[TYPE:.*]]: !pdl.type
// CHECK: pdl_interp.create_operation "foo.op"() -> %[[TYPE]]
// CHECK: func @pdl_generated_rewriter(%[[TYPE:.*]]: !pdl.type, %[[TYPES:.*]]: !pdl.range<type>
// CHECK: pdl_interp.create_operation "foo.op" -> (%[[TYPE]], %[[TYPES]] : !pdl.type, !pdl.range<type>)
pdl.pattern : benefit(1) {
%rootType = pdl.type
%root, %result = pdl.operation "foo.op" -> %rootType
%rootTypes = pdl.types
%root = pdl.operation "foo.op" -> (%rootType, %rootTypes : !pdl.type, !pdl.range<type>)
pdl.rewrite %root {
%newType = pdl.type
%newOp, %newResult = pdl.operation "foo.op" -> %newType
pdl.replace %root with (%newResult)
%newOp = pdl.operation "foo.op" -> (%rootType, %rootTypes : !pdl.type, !pdl.range<type>)
}
}
}

// -----

// CHECK-LABEL: module @operation_infer_types_from_otherop_results
module @operation_infer_types_from_otherop_results {
// CHECK: module @rewriters
// CHECK: func @pdl_generated_rewriter(%[[TYPES:.*]]: !pdl.range<type>
// CHECK: pdl_interp.create_operation "foo.op" -> (%[[TYPES]] : !pdl.range<type>)
pdl.pattern : benefit(1) {
%rootTypes = pdl.types
%root = pdl.operation "foo.op" -> (%rootTypes : !pdl.range<type>)
pdl.rewrite %root {
%newOp = pdl.operation "foo.op" -> (%rootTypes : !pdl.range<type>)
}
}
}

// -----

// CHECK-LABEL: module @replace_with_op
module @replace_with_op {
// CHECK: module @rewriters
// CHECK: func @pdl_generated_rewriter(%[[ROOT:.*]]: !pdl.operation)
// CHECK: %[[NEWOP:.*]] = pdl_interp.create_operation
// CHECK: %[[OP_RESULT:.*]] = pdl_interp.get_result 0 of %[[NEWOP]]
// CHECK: pdl_interp.replace %[[ROOT]] with(%[[OP_RESULT]])
// CHECK: %[[RESULTS:.*]] = pdl_interp.get_results of %[[NEWOP]]
// CHECK: pdl_interp.replace %[[ROOT]] with (%[[RESULTS]] : !pdl.range<value>)
pdl.pattern : benefit(1) {
%type = pdl.type : i32
%root, %result = pdl.operation "foo.op" -> %type
%root = pdl.operation "foo.op" -> (%type : !pdl.type)
pdl.rewrite %root {
%newOp, %newResult = pdl.operation "foo.op" -> %type
%newOp = pdl.operation "foo.op" -> (%type : !pdl.type)
pdl.replace %root with %newOp
}
}
Expand All @@ -151,16 +171,21 @@ module @replace_with_op {
// CHECK-LABEL: module @replace_with_values
module @replace_with_values {
// CHECK: module @rewriters
// CHECK: func @pdl_generated_rewriter(%[[ROOT:.*]]: !pdl.operation)
// CHECK: func @pdl_generated_rewriter({{.*}}, %[[ROOT:.*]]: !pdl.operation)
// CHECK: %[[NEWOP:.*]] = pdl_interp.create_operation
// CHECK: %[[OP_RESULT:.*]] = pdl_interp.get_result 0 of %[[NEWOP]]
// CHECK: pdl_interp.replace %[[ROOT]] with(%[[OP_RESULT]])
// CHECK: %[[RESULT:.*]] = pdl_interp.get_result 0 of %[[NEWOP]]
// CHECK: %[[RESULTS:.*]] = pdl_interp.get_results 1 of %[[NEWOP]] : !pdl.range<value>
// CHECK: %[[RESULTS_2:.*]] = pdl_interp.get_results 2 of %[[NEWOP]] : !pdl.value
// CHECK: pdl_interp.replace %[[ROOT]] with (%[[RESULT]], %[[RESULTS]], %[[RESULTS_2]] : !pdl.value, !pdl.range<value>, !pdl.value)
pdl.pattern : benefit(1) {
%type = pdl.type : i32
%root, %result = pdl.operation "foo.op" -> %type
%types = pdl.types
%root = pdl.operation "foo.op" -> (%types : !pdl.range<type>)
pdl.rewrite %root {
%newOp, %newResult = pdl.operation "foo.op" -> %type
pdl.replace %root with (%newResult)
%newOp = pdl.operation "foo.op" -> (%types : !pdl.range<type>)
%newResult = pdl.result 0 of %newOp
%newResults = pdl.results 1 of %newOp -> !pdl.range<value>
%newResults2 = pdl.results 2 of %newOp -> !pdl.value
pdl.replace %root with (%newResult, %newResults, %newResults2 : !pdl.value, !pdl.range<value>, !pdl.value)
}
}
}
Expand All @@ -184,19 +209,18 @@ module @replace_with_no_results {

// -----

// CHECK-LABEL: module @create_native
module @create_native {
// CHECK-LABEL: module @apply_native_rewrite
module @apply_native_rewrite {
// CHECK: module @rewriters
// CHECK: func @pdl_generated_rewriter(%[[ROOT:.*]]: !pdl.operation)
// CHECK: %[[TYPE:.*]] = pdl_interp.create_native "functor" [true](%[[ROOT]] : !pdl.operation) : !pdl.type
// CHECK: pdl_interp.create_operation "foo.op"() -> %[[TYPE]]
// CHECK: %[[TYPE:.*]] = pdl_interp.apply_rewrite "functor" [true](%[[ROOT]] : !pdl.operation) : !pdl.type
// CHECK: pdl_interp.create_operation "foo.op" -> (%[[TYPE]] : !pdl.type)
pdl.pattern : benefit(1) {
%type = pdl.type
%root, %result = pdl.operation "foo.op" -> %type
%root = pdl.operation "foo.op" -> (%type : !pdl.type)
pdl.rewrite %root {
%newType = pdl.create_native "functor"[true](%root : !pdl.operation) : !pdl.type
%newOp, %newResult = pdl.operation "foo.op" -> %newType
pdl.replace %root with %newOp
%newType = pdl.apply_native_rewrite "functor"[true](%root : !pdl.operation) : !pdl.type
%newOp = pdl.operation "foo.op" -> (%newType : !pdl.type)
}
}
}
88 changes: 57 additions & 31 deletions mlir/test/Dialect/PDL/invalid.mlir
Original file line number Diff line number Diff line change
@@ -1,19 +1,33 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics

//===----------------------------------------------------------------------===//
// pdl::ApplyConstraintOp
// pdl::ApplyNativeConstraintOp
//===----------------------------------------------------------------------===//

pdl.pattern : benefit(1) {
%op = pdl.operation "foo.op"

// expected-error@below {{expected at least one argument}}
"pdl.apply_constraint"() {name = "foo", params = []} : () -> ()
"pdl.apply_native_constraint"() {name = "foo", params = []} : () -> ()
pdl.rewrite %op with "rewriter"
}

// -----

//===----------------------------------------------------------------------===//
// pdl::ApplyNativeRewriteOp
//===----------------------------------------------------------------------===//

pdl.pattern : benefit(1) {
%op = pdl.operation "foo.op"
pdl.rewrite %op {
// expected-error@below {{expected at least one argument}}
"pdl.apply_native_rewrite"() {name = "foo", params = []} : () -> ()
}
}

// -----

//===----------------------------------------------------------------------===//
// pdl::AttributeOp
//===----------------------------------------------------------------------===//
Expand All @@ -24,7 +38,7 @@ pdl.pattern : benefit(1) {
// expected-error@below {{expected only one of [`type`, `value`] to be set}}
%attr = pdl.attribute : %type 10

%op, %result = pdl.operation "foo.op" {"attr" = %attr} -> %type
%op = pdl.operation "foo.op" {"attr" = %attr} -> (%type : !pdl.type)
pdl.rewrite %op with "rewriter"
}

Expand Down Expand Up @@ -76,6 +90,20 @@ pdl.pattern : benefit(1) {

// -----

//===----------------------------------------------------------------------===//
// pdl::OperandsOp
//===----------------------------------------------------------------------===//

pdl.pattern : benefit(1) {
// expected-error@below {{expected a bindable (i.e. `pdl.operation`) user when defined in the matcher body of a `pdl.pattern`}}
%unused = pdl.operands

%op = pdl.operation "foo.op"
pdl.rewrite %op with "rewriter"
}

// -----

//===----------------------------------------------------------------------===//
// pdl::OperationOp
//===----------------------------------------------------------------------===//
Expand All @@ -102,13 +130,13 @@ pdl.pattern : benefit(1) {
// -----

pdl.pattern : benefit(1) {
%op = pdl.operation "foo.op"()
%op = pdl.operation "foo.op"
pdl.rewrite %op {
%type = pdl.type

// expected-error@below {{op must have inferable or constrained result types when nested within `pdl.rewrite`}}
// expected-note@below {{result type #0 was not constrained}}
%newOp, %result = pdl.operation "foo.op" -> %type
%newOp = pdl.operation "foo.op" -> (%type : !pdl.type)
}
}

Expand Down Expand Up @@ -147,28 +175,12 @@ pdl.pattern : benefit(1) {

// -----

//===----------------------------------------------------------------------===//
// pdl::ReplaceOp
//===----------------------------------------------------------------------===//

pdl.pattern : benefit(1) {
%root = pdl.operation "foo.op"
pdl.rewrite %root {
%type = pdl.type : i32
%newOp, %newResult = pdl.operation "foo.op" -> %type

// expected-error@below {{to have the same number of results as the replacement operation}}
pdl.replace %root with %newOp
}
}

// -----

pdl.pattern : benefit(1) {
%type = pdl.type : i32
%root, %oldResult = pdl.operation "foo.op" -> %type
%root = pdl.operation "foo.op" -> (%type : !pdl.type)
pdl.rewrite %root {
%newOp, %newResult = pdl.operation "foo.op" -> %type
%newOp = pdl.operation "foo.op" -> (%type : !pdl.type)
%newResult = pdl.result 0 of %newOp

// expected-error@below {{expected no replacement values to be provided when the replacement operation is present}}
"pdl.replace"(%root, %newOp, %newResult) {
Expand All @@ -179,15 +191,15 @@ pdl.pattern : benefit(1) {

// -----

//===----------------------------------------------------------------------===//
// pdl::ResultsOp
//===----------------------------------------------------------------------===//

pdl.pattern : benefit(1) {
%root = pdl.operation "foo.op"
pdl.rewrite %root {
%type = pdl.type : i32
%newOp, %newResult = pdl.operation "foo.op" -> %type

// expected-error@below {{to have the same number of results as the provided replacement values}}
pdl.replace %root with (%newResult)
}
// expected-error@below {{expected `pdl.range<value>` result type when no index is specified, but got: '!pdl.value'}}
%results = "pdl.results"(%root) : (!pdl.operation) -> !pdl.value
pdl.rewrite %root with "rewriter"
}

// -----
Expand Down Expand Up @@ -252,3 +264,17 @@ pdl.pattern : benefit(1) {
%op = pdl.operation "foo.op"
pdl.rewrite %op with "rewriter"
}

// -----

//===----------------------------------------------------------------------===//
// pdl::TypesOp
//===----------------------------------------------------------------------===//

pdl.pattern : benefit(1) {
// expected-error@below {{expected a bindable (i.e. `pdl.operands`, or `pdl.operation`) user when defined in the matcher body of a `pdl.pattern`}}
%unused = pdl.types

%op = pdl.operation "foo.op"
pdl.rewrite %op with "rewriter"
}
33 changes: 16 additions & 17 deletions mlir/test/Dialect/PDL/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,20 @@ pdl.pattern @operations : benefit(1) {
// Operation with attributes and results.
%attribute = pdl.attribute
%type = pdl.type
%op0, %op0_result = pdl.operation {"attr" = %attribute} -> %type
%op0 = pdl.operation {"attr" = %attribute} -> (%type : !pdl.type)
%op0_result = pdl.result 0 of %op0

// Operation with input.
%input = pdl.operand
%root = pdl.operation(%op0_result, %input)
%root = pdl.operation(%op0_result, %input : !pdl.value, !pdl.value)
pdl.rewrite %root with "rewriter"
}

// -----

pdl.pattern @rewrite_with_args : benefit(1) {
%input = pdl.operand
%root = pdl.operation(%input)
%root = pdl.operation(%input : !pdl.value)
pdl.rewrite %root with "rewriter"(%input : !pdl.value)
}

Expand All @@ -35,7 +36,7 @@ pdl.pattern @rewrite_with_params : benefit(1) {

pdl.pattern @rewrite_with_args_and_params : benefit(1) {
%input = pdl.operand
%root = pdl.operation(%input)
%root = pdl.operation(%input : !pdl.value)
pdl.rewrite %root with "rewriter"["I am param"](%input : !pdl.value)
}

Expand All @@ -46,38 +47,36 @@ pdl.pattern @rewrite_with_args_and_params : benefit(1) {
pdl.pattern @infer_type_from_operation_replace : benefit(1) {
%type1 = pdl.type : i32
%type2 = pdl.type
%root, %results:2 = pdl.operation -> %type1, %type2
%root = pdl.operation -> (%type1, %type2 : !pdl.type, !pdl.type)
pdl.rewrite %root {
%type3 = pdl.type
%newOp, %newResults:2 = pdl.operation "foo.op" -> %type1, %type3
%newOp = pdl.operation "foo.op" -> (%type1, %type3 : !pdl.type, !pdl.type)
pdl.replace %root with %newOp
}
}

// -----

// Check that the result type of an operation within a rewrite can be inferred
// from a pdl.replace.
pdl.pattern @infer_type_from_result_replace : benefit(1) {
// from types used within the match block.
pdl.pattern @infer_type_from_type_used_in_match : benefit(1) {
%type1 = pdl.type : i32
%type2 = pdl.type
%root, %results:2 = pdl.operation -> %type1, %type2
%root = pdl.operation -> (%type1, %type2 : !pdl.type, !pdl.type)
pdl.rewrite %root {
%type3 = pdl.type
%newOp, %newResults:2 = pdl.operation "foo.op" -> %type1, %type3
pdl.replace %root with (%newResults#0, %newResults#1)
%newOp = pdl.operation "foo.op" -> (%type1, %type2 : !pdl.type, !pdl.type)
}
}

// -----

// Check that the result type of an operation within a rewrite can be inferred
// from a pdl.replace.
// from types used within the match block.
pdl.pattern @infer_type_from_type_used_in_match : benefit(1) {
%type1 = pdl.type : i32
%type2 = pdl.type
%root, %results:2 = pdl.operation -> %type1, %type2
%types = pdl.types
%root = pdl.operation -> (%types : !pdl.range<type>)
pdl.rewrite %root {
%newOp, %newResults:2 = pdl.operation "foo.op" -> %type1, %type2
%otherTypes = pdl.types : [i32, i64]
%newOp = pdl.operation "foo.op" -> (%types, %otherTypes : !pdl.range<type>, !pdl.range<type>)
}
}
8 changes: 4 additions & 4 deletions mlir/test/Dialect/PDLInterp/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,16 @@ func @operations(%attribute: !pdl.attribute,
%input: !pdl.value,
%type: !pdl.type) {
// attributes, operands, and results
%op0 = pdl_interp.create_operation "foo.op"(%input) {"attr" = %attribute} -> %type
%op0 = pdl_interp.create_operation "foo.op"(%input : !pdl.value) {"attr" = %attribute} -> (%type : !pdl.type)

// attributes, and results
%op1 = pdl_interp.create_operation "foo.op"() {"attr" = %attribute} -> %type
%op1 = pdl_interp.create_operation "foo.op" {"attr" = %attribute} -> (%type : !pdl.type)

// attributes
%op2 = pdl_interp.create_operation "foo.op"() {"attr" = %attribute, "attr1" = %attribute} -> ()
%op2 = pdl_interp.create_operation "foo.op" {"attr" = %attribute, "attr1" = %attribute}

// operands, and results
%op3 = pdl_interp.create_operation "foo.op"(%input) -> %type
%op3 = pdl_interp.create_operation "foo.op"(%input : !pdl.value) -> (%type : !pdl.type)

pdl_interp.finalize
}
Loading