Skip to content

Commit

Permalink
[mlir] switch transform dialect ops to use TransformTypeInterface
Browse files Browse the repository at this point in the history
Use the recently introduced TransformTypeInterface instead of hardcoding
the PDLOperationType. This will allow the operations to use more
specific transform types to express pre/post-conditions in the future.
It requires the syntax and Python op construction API to be updated.
Dialect extensions will be switched separately.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D135584
  • Loading branch information
ftynse committed Oct 11, 2022
1 parent b586d56 commit 6fe0309
Show file tree
Hide file tree
Showing 41 changed files with 482 additions and 505 deletions.
85 changes: 45 additions & 40 deletions mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@

include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Dialect/PDL/IR/PDLTypes.td"
include "mlir/Dialect/Transform/IR/TransformAttrs.td"
include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/IR/TransformEffects.td"
Expand Down Expand Up @@ -88,22 +88,22 @@ def AlternativesOp : TransformDialectOp<"alternatives",
```
}];

let arguments = (ins Optional<PDL_Operation>:$scope);
let results = (outs Variadic<AnyType>:$results);
let arguments = (ins Optional<TransformTypeInterface>:$scope);
let results = (outs Variadic<TransformTypeInterface>:$results);
let regions = (region VariadicRegion<SizedRegion<1>>:$alternatives);

let assemblyFormat =
"($scope^)? (`->` type($results)^)? attr-dict-with-keyword regions";
"($scope^ `:` type($scope))? (`->` type($results)^)? "
"attr-dict-with-keyword regions";
let hasVerifier = 1;
}

def CastOp : TransformDialectOp<"cast",
[TransformOpInterface, TransformEachOpTrait,
DeclareOpInterfaceMethods<CastOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
// TODO: temporarily fallback support for casting from PDL_Operation type.
let arguments = (ins AnyType:$input);
let results = (outs AnyType:$output);
let arguments = (ins TransformTypeInterface:$input);
let results = (outs TransformTypeInterface:$output);
let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";

let extraClassDeclaration = [{
Expand Down Expand Up @@ -143,10 +143,11 @@ def ForeachOp : TransformDialectOp<"foreach",
merged and mapped to the same resulting handle.
}];

let arguments = (ins PDL_Operation:$target);
let results = (outs Variadic<PDL_Operation>:$results);
let arguments = (ins TransformTypeInterface:$target);
let results = (outs Variadic<TransformTypeInterface>:$results);
let regions = (region SizedRegion<1>:$body);
let assemblyFormat = "$target (`->` type($results)^)? $body attr-dict";
let assemblyFormat =
"$target `:` type($target) (`->` type($results)^)? $body attr-dict";
let hasVerifier = 1;

let extraClassDeclaration = [{
Expand Down Expand Up @@ -182,9 +183,10 @@ def GetClosestIsolatedParentOp : TransformDialectOp<"get_closest_isolated_parent
on the further transformation applied to the handle produced here.
}];

let arguments = (ins PDL_Operation:$target);
let results = (outs PDL_Operation:$parent);
let assemblyFormat = "$target attr-dict";
let arguments = (ins TransformTypeInterface:$target);
let results = (outs TransformTypeInterface:$parent);
let assemblyFormat =
"$target attr-dict `:` functional-type(operands, results)";
}

def GetProducerOfOperand : TransformDialectOp<"get_producer_of_operand",
Expand All @@ -200,15 +202,17 @@ def GetProducerOfOperand : TransformDialectOp<"get_producer_of_operand",
computational operations, which can be empty.
}];

let arguments = (ins PDL_Operation:$target,
let arguments = (ins TransformTypeInterface:$target,
I64Attr:$operand_number);
let results = (outs PDL_Operation:$parent);
let assemblyFormat = "$target `[` $operand_number `]` attr-dict";
let results = (outs TransformTypeInterface:$parent);
let assemblyFormat = "$target `[` $operand_number `]` attr-dict `:` "
"functional-type(operands, results)";
}

def MergeHandlesOp : TransformDialectOp<"merge_handles",
[DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
SameOperandsAndResultType]> {
let summary = "Merges handles into one pointing to the union of payload ops";
let description = [{
Creates a new Transform IR handle value that points to the same Payload IR
Expand All @@ -221,10 +225,10 @@ def MergeHandlesOp : TransformDialectOp<"merge_handles",
same or different handles. Consumes the operands and produces a new handle.
}];

let arguments = (ins Variadic<PDL_Operation>:$handles,
let arguments = (ins Variadic<TransformTypeInterface>:$handles,
UnitAttr:$deduplicate);
let results = (outs PDL_Operation:$result);
let assemblyFormat = "($deduplicate^)? $handles attr-dict";
let results = (outs TransformTypeInterface:$result);
let assemblyFormat = "($deduplicate^)? $handles attr-dict `:` type($result)";
let hasFolder = 1;
}

Expand All @@ -246,13 +250,12 @@ def SplitHandlesOp : TransformDialectOp<"split_handles",
operations contained in the source `handle`. Otherwise it silently fails.
}];

let arguments = (ins PDL_Operation:$handle,
let arguments = (ins TransformTypeInterface:$handle,
I64Attr:$num_result_handles);
let results = (outs Variadic<PDL_Operation>:$results);
let results = (outs Variadic<TransformTypeInterface>:$results);
let assemblyFormat = [{
$handle `in` `[` $num_result_handles `]`
custom<StaticNumPDLResults>(type($results), ref($num_result_handles))
attr-dict
attr-dict `:` functional-type(operands, results)
}];
}

Expand All @@ -278,17 +281,19 @@ def PDLMatchOp : TransformDialectOp<"pdl_match",
}];

let arguments = (ins
Arg<PDL_Operation, "Payload IR scope to match within">:$root,
Arg<TransformTypeInterface, "Payload IR scope to match within">:$root,
SymbolRefAttr:$pattern_name);
let results = (outs
Res<PDL_Operation, "Handle to the matched Payload IR ops">:$matched);
Res<TransformTypeInterface, "Handle to the matched Payload IR ops">:$matched);

let assemblyFormat = "$pattern_name `in` $root attr-dict";
let assemblyFormat = "$pattern_name `in` $root attr-dict `:` "
"functional-type(operands, results)";
}

def ReplicateOp : TransformDialectOp<"replicate",
[DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
AllTypesMatch<["handles", "replicated"]>]> {
let summary = "Lists payload ops multiple times in the new handle";
let description = [{
Produces a new handle associated with a list of payload IR ops that is
Expand All @@ -314,12 +319,11 @@ def ReplicateOp : TransformDialectOp<"replicate",
MergeHandlesOp can be used to construct arbitrary lists with repetitions.
}];

let arguments = (ins PDL_Operation:$pattern,
Variadic<PDL_Operation>:$handles);
let results = (outs Variadic<PDL_Operation>:$replicated);
let assemblyFormat =
"`num` `(` $pattern `)` $handles "
"custom<PDLOpTypedResults>(type($replicated), ref($handles)) attr-dict";
let arguments = (ins TransformTypeInterface:$pattern,
Variadic<TransformTypeInterface>:$handles);
let results = (outs Variadic<TransformTypeInterface>:$replicated);
let assemblyFormat = "`num` `(` $pattern `)` $handles attr-dict `:` "
"type($pattern) `,` type($handles)";
}

def SequenceOp : TransformDialectOp<"sequence",
Expand Down Expand Up @@ -358,12 +362,13 @@ def SequenceOp : TransformDialectOp<"sequence",
}];

let arguments = (ins FailurePropagationMode:$failure_propagation_mode,
Optional<PDL_Operation>:$root);
let results = (outs Variadic<AnyType>:$results);
Optional<TransformTypeInterface>:$root);
let results = (outs Variadic<TransformTypeInterface>:$results);
let regions = (region SizedRegion<1>:$body);

let assemblyFormat =
"($root^)? `failures` `(` $failure_propagation_mode `)` attr-dict-with-keyword regions (`:` type($results)^)?";
"($root^ `:` type($root))? (`->` type($results)^)? `failures` `(` "
"$failure_propagation_mode `)` attr-dict-with-keyword regions";

let extraClassDeclaration = [{
/// Allow the dialect prefix to be omitted.
Expand Down Expand Up @@ -414,10 +419,10 @@ def WithPDLPatternsOp : TransformDialectOp<"with_pdl_patterns",
}];

let arguments = (ins
Arg<Optional<PDL_Operation>, "Root operation of the Payload IR",
Arg<Optional<TransformTypeInterface>, "Root operation of the Payload IR",
[TransformMappingRead]>:$root);
let regions = (region SizedRegion<1>:$body);
let assemblyFormat = "($root^)? attr-dict-with-keyword regions";
let assemblyFormat = "($root^ `:` type($root))? attr-dict-with-keyword regions";

let hasVerifier = 1;

Expand All @@ -436,7 +441,7 @@ def YieldOp : TransformDialectOp<"yield", [Terminator]> {
}];

let arguments = (ins
Arg<Variadic<AnyType>, "Operation handles yielded back to the parent",
Arg<Variadic<TransformTypeInterface>, "Operation handles yielded back to the parent",
[TransformMappingRead]>:$operands);
let assemblyFormat = "operands attr-dict (`:` type($operands)^)?";

Expand Down
14 changes: 14 additions & 0 deletions mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,17 @@ void transform::detail::checkImplementsTransformTypeInterface(
}
#endif // NDEBUG

namespace {
struct PDLOperationTypeTransformTypeInterfaceImpl
: public transform::TransformTypeInterface::ExternalModel<
PDLOperationTypeTransformTypeInterfaceImpl, pdl::OperationType> {
DiagnosedSilenceableFailure
checkPayload(Type type, Location loc, ArrayRef<Operation *> payload) const {
return DiagnosedSilenceableFailure::success();
}
};
} // namespace

void transform::TransformDialect::initialize() {
// Using the checked versions to enable the same assertions as for the ops
// from extensions.
Expand All @@ -53,6 +64,9 @@ void transform::TransformDialect::initialize() {
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Transform/IR/TransformTypes.cpp.inc"
>();

pdl::OperationType::attachInterface<
PDLOperationTypeTransformTypeInterfaceImpl>(*getContext());
}

void transform::TransformDialect::mergeInPDLMatchHooks(
Expand Down
30 changes: 13 additions & 17 deletions mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Operation.h"
Expand Down Expand Up @@ -71,12 +70,11 @@ transform::TransformState::setPayloadOps(Value value,
if (value.use_empty())
return success();

if (auto iface = value.getType().dyn_cast<TransformTypeInterface>()) {
DiagnosedSilenceableFailure result =
iface.checkPayload(value.getLoc(), targets);
if (failed(result.checkAndReport()))
return failure();
}
auto iface = value.getType().cast<TransformTypeInterface>();
DiagnosedSilenceableFailure result =
iface.checkPayload(value.getLoc(), targets);
if (failed(result.checkAndReport()))
return failure();

// Setting new payload for the value without cleaning it first is a misuse of
// the API, assert here.
Expand Down Expand Up @@ -128,12 +126,11 @@ LogicalResult transform::TransformState::updatePayloadOps(
}
}

if (auto iface = value.getType().dyn_cast<TransformTypeInterface>()) {
DiagnosedSilenceableFailure result =
iface.checkPayload(value.getLoc(), updated);
if (failed(result.checkAndReport()))
return failure();
}
auto iface = value.getType().cast<TransformTypeInterface>();
DiagnosedSilenceableFailure result =
iface.checkPayload(value.getLoc(), updated);
if (failed(result.checkAndReport()))
return failure();

std::swap(association, updated);
return success();
Expand Down Expand Up @@ -369,10 +366,9 @@ transform::detail::verifyPossibleTopLevelTransformOpTrait(Operation *op) {

Block *body = &bodyRegion->front();
if (body->getNumArguments() != 1 ||
!body->getArgumentTypes()[0].isa<pdl::OperationType>()) {
return op->emitOpError()
<< "expects the entry block to have one argument of type "
<< pdl::OperationType::get(op->getContext());
!body->getArgumentTypes()[0].isa<TransformTypeInterface>()) {
return op->emitOpError() << "expects the entry block to have one argument "
"of type implementing TransformTypeInterface";
}

if (auto *parent =
Expand Down
38 changes: 4 additions & 34 deletions mlir/lib/Dialect/Transform/IR/TransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "mlir/Dialect/PDL/IR/PDLOps.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
Expand All @@ -24,31 +25,6 @@

using namespace mlir;

/// Custom parser for ReplicateOp.
static ParseResult parsePDLOpTypedResults(
OpAsmParser &parser, SmallVectorImpl<Type> &types,
const SmallVectorImpl<OpAsmParser::UnresolvedOperand> &handles) {
types.resize(handles.size(), pdl::OperationType::get(parser.getContext()));
return success();
}

/// Custom printer for ReplicateOp.
static void printPDLOpTypedResults(OpAsmPrinter &, Operation *, TypeRange,
ValueRange) {}

/// Custom parser for SplitHandlesOp.
static ParseResult parseStaticNumPDLResults(OpAsmParser &parser,
SmallVectorImpl<Type> &types,
IntegerAttr numHandlesAttr) {
types.resize(numHandlesAttr.getInt(),
pdl::OperationType::get(parser.getContext()));
return success();
}

/// Custom printer for SplitHandlesOp.
static void printStaticNumPDLResults(OpAsmPrinter &, Operation *, TypeRange,
IntegerAttr) {}

#define GET_OP_CLASSES
#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"

Expand Down Expand Up @@ -269,13 +245,6 @@ transform::AlternativesOp::apply(transform::TransformResults &results,
LogicalResult transform::AlternativesOp::verify() {
for (Region &alternative : getAlternatives()) {
Block &block = alternative.front();
if (block.getNumArguments() != 1 ||
!block.getArgument(0).getType().isa<pdl::OperationType>()) {
return emitOpError()
<< "expects region blocks to have one operand of type "
<< pdl::OperationType::get(getContext());
}

Operation *terminator = block.getTerminator();
if (terminator->getOperands().getTypes() != getResults().getTypes()) {
InFlightDiagnostic diag = emitOpError()
Expand Down Expand Up @@ -403,8 +372,9 @@ LogicalResult transform::ForeachOp::verify() {
return emitOpError() << "expects the same number of results as the "
"terminator has operands";
for (Value v : yieldOp.getOperands())
if (!v.getType().isa<pdl::OperationType>())
return yieldOp->emitOpError("expects only PDL_Operation operands");
if (!v.getType().isa<TransformTypeInterface>())
return yieldOp->emitOpError(
"expects operands to have types implementing TransformTypeInterface");
return success();
}

Expand Down

0 comments on commit 6fe0309

Please sign in to comment.