Skip to content

Commit

Permalink
[mlir:PDL] Allow non-bound pdl.attribute/pdl.type operations that cre…
Browse files Browse the repository at this point in the history
…ate constants

This allows for passing in these attributes/types to constraints/rewrites as arguments.

Differential Revision: https://reviews.llvm.org/D114817
  • Loading branch information
River707 committed Dec 10, 2021
1 parent 06c3b9c commit 233e947
Show file tree
Hide file tree
Showing 7 changed files with 172 additions and 16 deletions.
7 changes: 7 additions & 0 deletions mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,13 @@ def PDLInterp_CreateTypesOp : PDLInterp_Op<"create_types", [NoSideEffect]> {
let arguments = (ins TypeArrayAttr:$value);
let results = (outs PDL_RangeOf<PDL_Type>:$result);
let assemblyFormat = "$value attr-dict";

let builders = [
OpBuilder<(ins "ArrayAttr":$type), [{
build($_builder, $_state,
pdl::RangeType::get($_builder.getType<pdl::TypeType>()), type);
}]>
];
}

//===----------------------------------------------------------------------===//
Expand Down
30 changes: 24 additions & 6 deletions mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,10 +237,12 @@ Value PatternLowering::getValueAt(Block *&currentBlock, Position *pos) {
return val;

// Get the value for the parent position.
Value parentVal = getValueAt(currentBlock, pos->getParent());
Value parentVal;
if (Position *parent = pos->getParent())
parentVal = getValueAt(currentBlock, pos->getParent());

// TODO: Use a location from the position.
Location loc = parentVal.getLoc();
Location loc = parentVal ? parentVal.getLoc() : builder.getUnknownLoc();
builder.setInsertionPointToEnd(currentBlock);
Value value;
switch (pos->getKind()) {
Expand Down Expand Up @@ -331,6 +333,22 @@ Value PatternLowering::getValueAt(Block *&currentBlock, Position *pos) {
parentVal, resPos->getResultGroupNumber());
break;
}
case Predicates::AttributeLiteralPos: {
auto *attrPos = cast<AttributeLiteralPosition>(pos);
value =
builder.create<pdl_interp::CreateAttributeOp>(loc, attrPos->getValue());
break;
}
case Predicates::TypeLiteralPos: {
auto *typePos = cast<TypeLiteralPosition>(pos);
Attribute rawTypeAttr = typePos->getValue();
if (TypeAttr typeAttr = rawTypeAttr.dyn_cast<TypeAttr>())
value = builder.create<pdl_interp::CreateTypeOp>(loc, typeAttr);
else
value = builder.create<pdl_interp::CreateTypesOp>(
loc, rawTypeAttr.cast<ArrayAttr>());
break;
}
default:
llvm_unreachable("Generating unknown Position getter");
break;
Expand All @@ -353,7 +371,7 @@ void PatternLowering::generate(BoolNode *boolNode, Block *&currentBlock,
if (auto *equalToQuestion = dyn_cast<EqualToQuestion>(question)) {
args = {getValueAt(currentBlock, equalToQuestion->getValue())};
} else if (auto *cstQuestion = dyn_cast<ConstraintQuestion>(question)) {
for (Position *position : std::get<1>(cstQuestion->getValue()))
for (Position *position : cstQuestion->getArgs())
args.push_back(getValueAt(currentBlock, position));
}

Expand Down Expand Up @@ -413,10 +431,10 @@ void PatternLowering::generate(BoolNode *boolNode, Block *&currentBlock,
break;
}
case Predicates::ConstraintQuestion: {
auto value = cast<ConstraintQuestion>(question)->getValue();
auto *cstQuestion = cast<ConstraintQuestion>(question);
builder.create<pdl_interp::ApplyConstraintOp>(
loc, std::get<0>(value), args, std::get<2>(value).cast<ArrayAttr>(),
success, failure);
loc, cstQuestion->getName(), args, cstQuestion->getParams(), success,
failure);
break;
}
default:
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/PDLToPDLInterp/Predicate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ Position::~Position() {}
unsigned Position::getOperationDepth() const {
if (const auto *operationPos = dyn_cast<OperationPosition>(this))
return operationPos->getDepth();
return parent->getOperationDepth();
return parent ? parent->getOperationDepth() : 0;
}

//===----------------------------------------------------------------------===//
Expand Down
47 changes: 47 additions & 0 deletions mlir/lib/Conversion/PDLToPDLInterp/Predicate.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ enum Kind : unsigned {
ResultPos,
ResultGroupPos,
TypePos,
AttributeLiteralPos,
TypeLiteralPos,

// Questions, ordered by dependency and decreasing priority.
IsNotNullQuestion,
Expand Down Expand Up @@ -173,6 +175,16 @@ struct AttributePosition
StringAttr getName() const { return key.second; }
};

//===----------------------------------------------------------------------===//
// AttributeLiteralPosition

/// A position describing a literal attribute.
struct AttributeLiteralPosition
: public PredicateBase<AttributeLiteralPosition, Position, Attribute,
Predicates::AttributeLiteralPos> {
using PredicateBase::PredicateBase;
};

//===----------------------------------------------------------------------===//
// OperandPosition

Expand Down Expand Up @@ -317,6 +329,17 @@ struct TypePosition : public PredicateBase<TypePosition, Position, Position *,
}
};

//===----------------------------------------------------------------------===//
// TypeLiteralPosition

/// A position describing a literal type or type range. The value is stored as
/// either a TypeAttr, or an ArrayAttr of TypeAttr.
struct TypeLiteralPosition
: public PredicateBase<TypeLiteralPosition, Position, Attribute,
Predicates::TypeLiteralPos> {
using PredicateBase::PredicateBase;
};

//===----------------------------------------------------------------------===//
// Qualifiers
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -404,6 +427,17 @@ struct ConstraintQuestion
Predicates::ConstraintQuestion> {
using Base::Base;

/// Return the name of the constraint.
StringRef getName() const { return std::get<0>(key); }

/// Return the arguments of the constraint.
ArrayRef<Position *> getArgs() const { return std::get<1>(key); }

/// Return the constant parameters of the constraint.
ArrayAttr getParams() const {
return std::get<2>(key).dyn_cast_or_null<ArrayAttr>();
}

/// Construct an instance with the given storage allocator.
static ConstraintQuestion *construct(StorageUniquer::StorageAllocator &alloc,
KeyTy key) {
Expand Down Expand Up @@ -461,12 +495,14 @@ class PredicateUniquer : public StorageUniquer {
PredicateUniquer() {
// Register the types of Positions with the uniquer.
registerParametricStorageType<AttributePosition>();
registerParametricStorageType<AttributeLiteralPosition>();
registerParametricStorageType<OperandPosition>();
registerParametricStorageType<OperandGroupPosition>();
registerParametricStorageType<OperationPosition>();
registerParametricStorageType<ResultPosition>();
registerParametricStorageType<ResultGroupPosition>();
registerParametricStorageType<TypePosition>();
registerParametricStorageType<TypeLiteralPosition>();

// Register the types of Questions with the uniquer.
registerParametricStorageType<AttributeAnswer>();
Expand Down Expand Up @@ -527,6 +563,11 @@ class PredicateBuilder {
return AttributePosition::get(uniquer, p, StringAttr::get(ctx, name));
}

/// Returns an attribute position for the given attribute.
Position *getAttributeLiteral(Attribute attr) {
return AttributeLiteralPosition::get(uniquer, attr);
}

/// Returns an operand position for an operand of the given operation.
Position *getOperand(OperationPosition *p, unsigned operand) {
return OperandPosition::get(uniquer, p, operand);
Expand Down Expand Up @@ -558,6 +599,12 @@ class PredicateBuilder {
/// Returns a type position for the given entity.
Position *getType(Position *p) { return TypePosition::get(uniquer, p); }

/// Returns a type position for the given type value. The value is stored
/// as either a TypeAttr, or an ArrayAttr of TypeAttr.
Position *getTypeLiteral(Attribute attr) {
return TypeLiteralPosition::get(uniquer, attr);
}

//===--------------------------------------------------------------------===//
// Qualifiers
//===--------------------------------------------------------------------===//
Expand Down
38 changes: 36 additions & 2 deletions mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,18 @@ static void getTreePredicates(std::vector<PositionalPredicate> &predList,
.Default([](auto *) { llvm_unreachable("unexpected position kind"); });
}

/// Collect all of the predicates related to constraints within the given
/// pattern operation.
static void getAttributePredicates(pdl::AttributeOp op,
std::vector<PositionalPredicate> &predList,
PredicateBuilder &builder,
DenseMap<Value, Position *> &inputs) {
Position *&attrPos = inputs[op];
if (attrPos)
return;
Attribute value = op.valueAttr();
assert(value && "expected non-tree `pdl.attribute` to contain a value");
attrPos = builder.getAttributeLiteral(value);
}

static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op,
std::vector<PositionalPredicate> &predList,
PredicateBuilder &builder,
Expand Down Expand Up @@ -296,6 +306,19 @@ static void getResultPredicates(pdl::ResultsOp op,
predList.emplace_back(resultPos, builder.getIsNotNull());
}

static void getTypePredicates(Value typeValue,
function_ref<Attribute()> typeAttrFn,
PredicateBuilder &builder,
DenseMap<Value, Position *> &inputs) {
Position *&typePos = inputs[typeValue];
if (typePos)
return;
Attribute typeAttr = typeAttrFn();
assert(typeAttr &&
"expected non-tree `pdl.type`/`pdl.types` to contain a value");
typePos = builder.getTypeLiteral(typeAttr);
}

/// Collect all of the predicates that cannot be determined via walking the
/// tree.
static void getNonTreePredicates(pdl::PatternOp pattern,
Expand All @@ -304,11 +327,22 @@ static void getNonTreePredicates(pdl::PatternOp pattern,
DenseMap<Value, Position *> &inputs) {
for (Operation &op : pattern.body().getOps()) {
TypeSwitch<Operation *>(&op)
.Case([&](pdl::AttributeOp attrOp) {
getAttributePredicates(attrOp, predList, builder, inputs);
})
.Case<pdl::ApplyNativeConstraintOp>([&](auto constraintOp) {
getConstraintPredicates(constraintOp, predList, builder, inputs);
})
.Case<pdl::ResultOp, pdl::ResultsOp>([&](auto resultOp) {
getResultPredicates(resultOp, predList, builder, inputs);
})
.Case([&](pdl::TypeOp typeOp) {
getTypePredicates(
typeOp, [&] { return typeOp.typeAttr(); }, builder, inputs);
})
.Case([&](pdl::TypesOp typeOp) {
getTypePredicates(
typeOp, [&] { return typeOp.typesAttr(); }, builder, inputs);
});
}
}
Expand Down
25 changes: 18 additions & 7 deletions mlir/lib/Dialect/PDL/IR/PDL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,15 @@ static LogicalResult verify(AttributeOp op) {
Value attrType = op.type();
Optional<Attribute> attrValue = op.value();

if (!attrValue && isa<RewriteOp>(op->getParentOp()))
return op.emitOpError("expected constant value when specified within a "
"`pdl.rewrite`");
if (attrValue && attrType)
if (!attrValue) {
if (isa<RewriteOp>(op->getParentOp()))
return op.emitOpError("expected constant value when specified within a "
"`pdl.rewrite`");
return verifyHasBindingUse(op);
}
if (attrType)
return op.emitOpError("expected only one of [`type`, `value`] to be set");
return verifyHasBindingUse(op);
return success();
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -431,13 +434,21 @@ static LogicalResult verify(RewriteOp op) {
// pdl::TypeOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(TypeOp op) { return verifyHasBindingUse(op); }
static LogicalResult verify(TypeOp op) {
if (!op.typeAttr())
return verifyHasBindingUse(op);
return success();
}

//===----------------------------------------------------------------------===//
// pdl::TypesOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(TypesOp op) { return verifyHasBindingUse(op); }
static LogicalResult verify(TypesOp op) {
if (!op.typesAttr())
return verifyHasBindingUse(op);
return success();
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
Expand Down
39 changes: 39 additions & 0 deletions mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -573,3 +573,42 @@ module @variadic_results_at {
pdl.rewrite with "rewriter"(%root1, %root2 : !pdl.operation, !pdl.operation)
}
}

// -----

// CHECK-LABEL: module @attribute_literal
module @attribute_literal {
// CHECK: func @matcher(%{{.*}}: !pdl.operation)
// CHECK: %[[ATTR:.*]] = pdl_interp.create_attribute 10 : i64
// CHECK: pdl_interp.apply_constraint "constraint"(%[[ATTR]] : !pdl.attribute)

// Check the correct lowering of an attribute that hasn't been bound.
pdl.pattern : benefit(1) {
%attr = pdl.attribute 10
pdl.apply_native_constraint "constraint"(%attr: !pdl.attribute)

%root = pdl.operation
pdl.rewrite %root with "rewriter"
}
}

// -----

// CHECK-LABEL: module @type_literal
module @type_literal {
// CHECK: func @matcher(%{{.*}}: !pdl.operation)
// CHECK: %[[TYPE:.*]] = pdl_interp.create_type i32
// CHECK: %[[TYPES:.*]] = pdl_interp.create_types [i32, i64]
// CHECK: pdl_interp.apply_constraint "constraint"(%[[TYPE]], %[[TYPES]] : !pdl.type, !pdl.range<type>)

// Check the correct lowering of a type that hasn't been bound.
pdl.pattern : benefit(1) {
%type = pdl.type : i32
%types = pdl.types : [i32, i64]
pdl.apply_native_constraint "constraint"(%type, %types: !pdl.type, !pdl.range<type>)

%root = pdl.operation
pdl.rewrite %root with "rewriter"
}
}

0 comments on commit 233e947

Please sign in to comment.