55 changes: 55 additions & 0 deletions mlir/lib/Dialect/Transform/IR/TransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,31 @@

using namespace mlir;

/// Custom parser for ReplicateOp.
static ParseResult parsePDLOpTypedResults(
OpAsmParser &parser, SmallVectorImpl<Type> &types,
const SmallVectorImpl<OpAsmParser::UnresolvedOperand> &handles) {
types.resize(handles.size(), pdl::OperationType::get(parser.getContext()));
return success();
}

/// Custom printer for ReplicateOp.
static void printPDLOpTypedResults(OpAsmPrinter &, Operation *, TypeRange,
ValueRange) {}

/// Custom parser for SplitHandlesOp.
static ParseResult parseStaticNumPDLResults(OpAsmParser &parser,
SmallVectorImpl<Type> &types,
IntegerAttr numHandlesAttr) {
types.resize(numHandlesAttr.getInt(),
pdl::OperationType::get(parser.getContext()));
return success();
}

/// Custom printer for SplitHandlesOp.
static void printStaticNumPDLResults(OpAsmPrinter &, Operation *, TypeRange,
IntegerAttr) {}

#define GET_OP_CLASSES
#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"

Expand Down Expand Up @@ -452,6 +467,46 @@ OpFoldResult transform::MergeHandlesOp::fold(ArrayRef<Attribute> operands) {
return getHandles().front();
}

//===----------------------------------------------------------------------===//
// SplitHandlesOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::SplitHandlesOp::apply(transform::TransformResults &results,
transform::TransformState &state) {
int64_t numResultHandles =
getHandle() ? state.getPayloadOps(getHandle()).size() : 0;
int64_t expectedNumResultHandles = getNumResultHandles();
if (numResultHandles != expectedNumResultHandles) {
// Failing case needs to propagate gracefully for both suppress and
// propagate modes.
for (int64_t idx = 0; idx < expectedNumResultHandles; ++idx)
results.set(getResults()[idx].cast<OpResult>(), {});
// Empty input handle corner case: always propagates empty handles in both
// suppress and propagate modes.
if (numResultHandles == 0)
return DiagnosedSilenceableFailure::success();
// If the input handle was not empty and the number of result handles does
// not match, this is a legit silenceable error.
return emitSilenceableError()
<< getHandle() << " expected to contain " << expectedNumResultHandles
<< " operation handles but it only contains " << numResultHandles
<< " handles";
}
// Normal successful case.
for (auto en : llvm::enumerate(state.getPayloadOps(getHandle())))
results.set(getResults()[en.index()].cast<OpResult>(), en.value());
return DiagnosedSilenceableFailure::success();
}

void transform::SplitHandlesOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
consumesHandle(getHandle(), effects);
producesHandle(getResults(), effects);
// There are no effects on the Payload IR as this is only a handle
// manipulation.
}

//===----------------------------------------------------------------------===//
// PDLMatchOp
//===----------------------------------------------------------------------===//
Expand Down
39 changes: 39 additions & 0 deletions mlir/test/Dialect/Transform/test-interpreter.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -761,3 +761,42 @@ transform.sequence failures(propagate) {

}

// -----

func.func @split_handles(%a: index, %b: index, %c: index) {
%0 = arith.muli %a, %b : index
%1 = arith.muli %a, %c : index
return
}

transform.sequence failures(propagate) {
^bb1(%fun: !pdl.operation):
%muli = transform.structured.match ops{["arith.muli"]} in %fun
%h:2 = split_handles %muli in [2]
// expected-remark @below {{1}}
transform.test_print_number_of_associated_payload_ir_ops %h#0
%muli_2 = transform.structured.match ops{["arith.muli"]} in %fun
// expected-error @below {{expected to contain 3 operation handles but it only contains 2 handles}}
%h_2:3 = split_handles %muli_2 in [3]
}

// -----

func.func @split_handles(%a: index, %b: index, %c: index) {
%0 = arith.muli %a, %b : index
%1 = arith.muli %a, %c : index
return
}

transform.sequence failures(suppress) {
^bb1(%fun: !pdl.operation):
%muli = transform.structured.match ops{["arith.muli"]} in %fun
%h:2 = split_handles %muli in [2]
// expected-remark @below {{1}}
transform.test_print_number_of_associated_payload_ir_ops %h#0
%muli_2 = transform.structured.match ops{["arith.muli"]} in %fun
// Silenceable failure and all handles are now empty.
%h_2:3 = split_handles %muli_2 in [3]
// expected-remark @below {{0}}
transform.test_print_number_of_associated_payload_ir_ops %h_2#0
}
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,8 @@ mlir::test::TestMixedSuccessAndSilenceableOp::applyToOne(
DiagnosedSilenceableFailure
mlir::test::TestPrintNumberOfAssociatedPayloadIROps::apply(
transform::TransformResults &results, transform::TransformState &state) {
if (!getHandle())
emitRemark() << 0;
emitRemark() << state.getPayloadOps(getHandle()).size();
return DiagnosedSilenceableFailure::success();
}
Expand Down