Skip to content

Commit

Permalink
[mlir][transform] Simplify TrackingListener test case
Browse files Browse the repository at this point in the history
Use the default TrackingListener. No need to set up a derived listener just for the test case. This revision is in preparation of a future change that adds a TrackingRewriter infrastructure.

Differential Revision: https://reviews.llvm.org/D152446
  • Loading branch information
matthias-springer committed Jun 9, 2023
1 parent e967638 commit 1b390f5
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 37 deletions.
15 changes: 8 additions & 7 deletions mlir/test/Dialect/Transform/test-interpreter.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1601,19 +1601,20 @@ module attributes { transform.with_named_sequence } {
// -----

// CHECK-LABEL: func @test_tracked_rewrite() {
// CHECK-NEXT: "test.update_mapping"() {original_op = "test.replace_me"}
// CHECK-NEXT: "test.drop_mapping"() {original_op = "test.replace_me"}
// CHECK-NEXT: "test.update_mapping"() {original_op = "test.replace_me"}
// CHECK-NEXT: transform.test_dummy_payload_op {new_op} : () -> i1
// CHECK-NEXT: transform.test_dummy_payload_op {new_op} : () -> i1
// CHECK-NEXT: return
// CHECK-NEXT: }
func.func @test_tracked_rewrite() {
%0 = "test.replace_me"() {replacement = "test.update_mapping"} : () -> (i1)
%1 = "test.replace_me"() {replacement = "test.drop_mapping"} : () -> (i1)
%2 = "test.replace_me"() {replacement = "test.update_mapping"} : () -> (i1)
%0 = transform.test_dummy_payload_op {replace_me} : () -> (i1)
%1 = transform.test_dummy_payload_op {erase_me} : () -> (i1)
%2 = transform.test_dummy_payload_op {replace_me} : () -> (i1)
func.return
}

transform.sequence failures(propagate) {
^bb1(%arg1: !transform.any_op):
%0 = transform.structured.match ops{["test.replace_me"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%0 = transform.structured.match ops{["transform.test_dummy_payload_op"]} in %arg1 : (!transform.any_op) -> !transform.any_op
// expected-remark @below {{2 iterations}}
transform.test_tracked_rewrite %0 : (!transform.any_op) -> ()
// One replacement op (test.drop_mapping) is dropped from the mapping.
Expand Down
47 changes: 17 additions & 30 deletions mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -687,53 +687,40 @@ void mlir::test::TestTrackedRewriteOp::getEffects(
transform::modifiesPayload(effects);
}

namespace {
/// A TrackingListener for test cases. When the replacement op is
/// "test.update_mapping", it is considered as a replacement op in the transform
/// state mapping. Otherwise, it is not and the original op is simply removed
/// from the mapping.
class TestTrackingListener : public transform::TrackingListener {
using transform::TrackingListener::TrackingListener;

protected:
FailureOr<Operation *>
findReplacementOp(Operation *op, ValueRange newValues) const override {
if (newValues.size() != 1)
return failure();
Operation *replacement = newValues[0].getDefiningOp();
if (!replacement)
return failure();
if (replacement->getName().getStringRef() != "test.update_mapping")
return failure();
return replacement;
}
};
} // namespace
void mlir::test::TestDummyPayloadOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
for (OpResult result : getResults())
transform::producesHandle(result, effects);
}

DiagnosedSilenceableFailure
mlir::test::TestTrackedRewriteOp::apply(transform::TransformResults &results,
transform::TransformState &state) {
TestTrackingListener listener(state, *this);
transform::ErrorCheckingTrackingListener listener(state, *this);
IRRewriter rewriter(getContext(), &listener);
int64_t numIterations = 0;

// `getPayloadOps` returns an iterator that skips ops that are erased in the
// loop body. Replacement ops are not enumerated.
for (Operation *op : state.getPayloadOps(getIn())) {
++numIterations;
rewriter.setInsertionPointToEnd(op->getBlock());
(void)op;

// Erase all payload ops. The outer loop should have only one iteration.
for (Operation *op : state.getPayloadOps(getIn())) {
if (op->getName().getStringRef() != "test.replace_me")
rewriter.setInsertionPoint(op);
if (op->hasAttr("erase_me")) {
rewriter.eraseOp(op);
continue;
auto replacementName = op->getAttrOfType<StringAttr>("replacement");
if (!replacementName)
}
if (!op->hasAttr("replace_me")) {
continue;
}

SmallVector<NamedAttribute> attributes;
attributes.emplace_back(rewriter.getStringAttr("original_op"),
op->getName().getIdentifier());
OperationState opState(op->getLoc(), replacementName,
attributes.emplace_back(rewriter.getStringAttr("new_op"),
rewriter.getUnitAttr());
OperationState opState(op->getLoc(), op->getName().getIdentifier(),
/*operands=*/ValueRange(),
/*types=*/op->getResultTypes(), attributes);
Operation *newOp = rewriter.create(opState);
Expand Down
23 changes: 23 additions & 0 deletions mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,29 @@ def TestRequiredMemoryEffectsOp
let cppNamespace = "::mlir::test";
}

// This op is used as a payload op. It must be a registered op, so that it can
// be created with "RewriterBase::replaceOpWithNewOp" (needed for a test case).
// Since only TransformOpInterface can be injected into the transform dialect,
// this op implements the interface, even though it is not used as a transform
// op.
def TestDummyPayloadOp
: Op<Transform_Dialect, "test_dummy_payload_op",
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
TransformOpInterface]> {
let arguments = (ins Variadic<AnyType>:$args);
let results = (outs Variadic<AnyType>:$outs);
let assemblyFormat = "$args attr-dict `:` functional-type(operands, results)";
let cppNamespace = "::mlir::test";

let extraClassDeclaration = [{
DiagnosedSilenceableFailure apply(transform::TransformResults &results,
transform::TransformState &state) {
llvm_unreachable("op should not be used as a transform");
return DiagnosedSilenceableFailure::definiteFailure();
}
}];
}

def TestTrackedRewriteOp
: Op<Transform_Dialect, "test_tracked_rewrite",
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
Expand Down

0 comments on commit 1b390f5

Please sign in to comment.