Skip to content

Commit

Permalink
[mlir][transform] Utilize op interface instead of tensor::TrackingLis…
Browse files Browse the repository at this point in the history
…tener

Add a new interface `FindPayloadReplacementOpInterface` to specify ops that should be skipped when looking for payload replacement ops. Such ops are typically metadata-only ops.

With this change, we no longer need to maintain a custom TrackingListener in the tensor dialect.

Note: `CastOpInterface` by itself is not sufficient. Some metadata-only ops such as "tensor.reshape" are not casts, and it would be incorrect for them to implement the `CastOpInterface`.

Differential Revision: https://reviews.llvm.org/D151888
  • Loading branch information
matthias-springer committed Jun 2, 2023
1 parent a584f0f commit 000bc58
Show file tree
Hide file tree
Showing 9 changed files with 161 additions and 64 deletions.
16 changes: 2 additions & 14 deletions mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h
Expand Up @@ -18,21 +18,9 @@ namespace mlir {
class DialectRegistry;

namespace tensor {

/// A specialized TrackingListener for transform ops that operate on tensor IR.
/// This listener skips cast-like tensor ops when looking for payload op
/// replacements.
class TrackingListener : public transform::TrackingListener {
public:
using transform::TrackingListener::TrackingListener;

protected:
Operation *findReplacementOp(Operation *op,
ValueRange newValues) const override;
};

void registerTransformDialectExtension(DialectRegistry &registry);

void registerFindPayloadReplacementOpInterfaceExternalModels(
DialectRegistry &registry);
} // namespace tensor
} // namespace mlir

Expand Down
35 changes: 35 additions & 0 deletions mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
Expand Up @@ -193,4 +193,39 @@ def ParamProducerTransformOpTrait : NativeOpTrait<"ParamProducerTransformOpTrait
let cppNamespace = "::mlir::transform";
}

def FindPayloadReplacementOpInterface
: OpInterface<"FindPayloadReplacementOpInterface"> {
let description = [{
This interface is queried by the `TrackingListener` and can be implemented
by payload ops to indicate that the lookup should be continue with its
operands when looking for payload op replacements.

Example: Consider the case where a tracked "test.foo" payload op is replaced
with a new "test.foo" op, but wrapped in a "tensor.reshape" op. In that
case, the mapping of the original "test.foo" op should be updated with the
new "test.foo" op. A "tensor.reshape" is a metadata-only op that should be
skipped when inspecting the replacement values of the original "test.foo"
op. More details can be found at `TrackingListener` documentation.

Note: Ops that implement `CastOpInterface` do not need to implement this
interface. Such ops are skipped by default. This interface should be
implemented by cast-like/metadata-only ops that cannot implement
`CastOpInterface`.
}];

let cppNamespace = "::mlir::transform";

let methods = [
InterfaceMethod<
/*desc=*/[{
Return the operands at which the lookup for replacement payload ops
should continue.
}],
/*returnType=*/"::llvm::SmallVector<::mlir::Value>",
/*name=*/"getNextOperands",
/*arguments=*/(ins)
>,
];
}

#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORM_INTERFACES_TD
39 changes: 39 additions & 0 deletions mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
Expand Up @@ -50,6 +50,45 @@ class TrackingListener : public RewriterBase::Listener,
/// replaced with the given values. By default, if all values are defined by
/// the same op, which also has the same type as the given op, that defining
/// op is used as a replacement.
///
/// Example: A tracked "linalg.generic" with two results is replaced with two
/// values defined by (another) "linalg.generic". It is reasonable to assume
/// that the replacement "linalg.generic" represents the same "computation".
/// Therefore, the payload op mapping is updated to the defining op of the
/// replacement values.
///
/// Counter Example: A "linalg.generic" is replaced with values defined by an
/// "scf.for". Without further investigation, the relationship between the
/// "linalg.generic" and the "scf.for" is unclear. They may not represent the
/// same computation; e.g., there may be tiled "linalg.generic" inside the
/// loop body that represents the original computation. Therefore, the
/// TrackingListener is conservative by default: it drops the mapping and
/// triggers the "payload replacement not found" notification.
///
/// If no replacement op could be found according to the rules mentioned
/// above, this function tries to skip over cast-like ops that implement
/// `CastOpInterface`.
///
/// Example: A tracked "linalg.generic" is replaced with "linalg.generic",
/// wrapped in a "tensor.cast". A cast is a metadata-only operation and it is
/// reasonable to assume that the wrapped "linalg.generic" represents the same
/// computation as the original "linalg.generic". The mapping is updated
/// accordingly.
///
/// Certain ops (typically also metadata-only ops) are not considered casts,
/// but should be skipped nonetheless. Such ops should implement
/// `FindPayloadReplacementOpInterface` to specify with which operands the
/// lookup should continue.
///
/// Example: A tracked "linalg.generic" is replaced with "linalg.generic",
/// wrapped in a "tensor.reshape". A reshape is a metadata-only operation but
/// not cast. (Implementing `CastOpInterface` would be incorrect and cause
/// invalid foldings.) However, due to its `FindPayloadReplacementOpInterface`
/// implementation, the replacement op lookup continues with the wrapped
/// "linalg.generic" and the mapping is updated accordingly.
///
/// Derived classes may override `findReplacementOp` to specify custom
/// replacement rules.
virtual Operation *findReplacementOp(Operation *op,
ValueRange newValues) const;

Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/InitAllDialects.h
Expand Up @@ -159,6 +159,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
shape::registerBufferizableOpInterfaceExternalModels(registry);
sparse_tensor::registerBufferizableOpInterfaceExternalModels(registry);
tensor::registerBufferizableOpInterfaceExternalModels(registry);
tensor::registerFindPayloadReplacementOpInterfaceExternalModels(registry);
tensor::registerInferTypeOpInterfaceExternalModels(registry);
tensor::registerTilingInterfaceExternalModels(registry);
tensor::registerValueBoundsOpInterfaceExternalModels(registry);
Expand Down
92 changes: 55 additions & 37 deletions mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
Expand Up @@ -15,50 +15,68 @@
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace tensor;

//===----------------------------------------------------------------------===//
// TrackingListener
// FindPayloadReplacementOpInterface implementations
//===----------------------------------------------------------------------===//

Operation *
tensor::TrackingListener::findReplacementOp(Operation *op,
ValueRange newValues) const {
SmallVector<Value> values(newValues.begin(), newValues.end());
do {
if (Operation *replacement =
transform::TrackingListener::findReplacementOp(op, values))
return replacement;

Operation *defOp = getCommonDefiningOp(values);
if (!defOp)
return nullptr;

// Skip cast-like operations.
values.clear();
llvm::TypeSwitch<Operation *>(defOp)
.Case<CastOp>([&](CastOp op) { values.push_back(op.getSource()); })
.Case<CollapseShapeOp>(
[&](CollapseShapeOp op) { values.push_back(op.getSrc()); })
.Case<ExpandShapeOp>(
[&](ExpandShapeOp op) { values.push_back(op.getSrc()); })
.Case<ReshapeOp>(
[&](ReshapeOp op) { values.push_back(op.getSource()); })
.Case<InsertSliceOp>([&](InsertSliceOp op) {
if (isCastLikeInsertSliceOp(op))
values.push_back(op.getSource());
})
.Case<ExtractSliceOp>([&](ExtractSliceOp op) {
if (isCastLikeExtractSliceOp(op))
values.push_back(op.getSource());
})
.Default([](Operation *op) {});
} while (!values.empty());

return nullptr;
namespace {
struct ExtractSliceOpReplacementInterface
: public transform::FindPayloadReplacementOpInterface::ExternalModel<
ExtractSliceOpReplacementInterface, tensor::ExtractSliceOp> {
SmallVector<Value> getNextOperands(Operation *op) const {
auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
if (!isCastLikeExtractSliceOp(extractSliceOp))
return {};
return {extractSliceOp.getSource()};
}
};

struct InsertSliceOpReplacementInterface
: public transform::FindPayloadReplacementOpInterface::ExternalModel<
InsertSliceOpReplacementInterface, tensor::InsertSliceOp> {
SmallVector<Value> getNextOperands(Operation *op) const {
auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
if (!isCastLikeInsertSliceOp(insertSliceOp))
return {};
return {insertSliceOp.getSource()};
}
};

struct ReshapeOpReplacementInterface
: public transform::FindPayloadReplacementOpInterface::ExternalModel<
ReshapeOpReplacementInterface, tensor::ReshapeOp> {
SmallVector<Value> getNextOperands(Operation *op) const {
auto reshapeOp = cast<tensor::ReshapeOp>(op);
return {reshapeOp.getSource()};
}
};

template <typename ConcreteOp>
struct ReassociativeReshapeOpReplacementInterface
: public transform::FindPayloadReplacementOpInterface::ExternalModel<
ReassociativeReshapeOpReplacementInterface<ConcreteOp>, ConcreteOp> {
SmallVector<Value> getNextOperands(Operation *op) const {
auto reshapeOp = cast<ConcreteOp>(op);
return {reshapeOp.getSrc()};
}
};
} // namespace

void tensor::registerFindPayloadReplacementOpInterfaceExternalModels(
DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
CollapseShapeOp::attachInterface<
ReassociativeReshapeOpReplacementInterface<CollapseShapeOp>>(*ctx);
ExpandShapeOp::attachInterface<
ReassociativeReshapeOpReplacementInterface<ExpandShapeOp>>(*ctx);
ExtractSliceOp::attachInterface<ExtractSliceOpReplacementInterface>(*ctx);
InsertSliceOp::attachInterface<InsertSliceOpReplacementInterface>(*ctx);
ReshapeOp::attachInterface<ReshapeOpReplacementInterface>(*ctx);
});
}

//===----------------------------------------------------------------------===//
Expand Down
34 changes: 26 additions & 8 deletions mlir/lib/Dialect/Transform/IR/TransformOps.cpp
Expand Up @@ -74,17 +74,35 @@ transform::TrackingListener::findReplacementOp(Operation *op,
ValueRange newValues) const {
assert(op->getNumResults() == newValues.size() &&
"invalid number of replacement values");
SmallVector<Value> values(newValues.begin(), newValues.end());

// If the replacement values belong to different ops, drop the mapping.
Operation *defOp = getCommonDefiningOp(newValues);
if (!defOp)
return nullptr;
do {
// If the replacement values belong to different ops, drop the mapping.
Operation *defOp = getCommonDefiningOp(values);
if (!defOp)
return nullptr;

// If the replacement op has a different type, drop the mapping.
if (op->getName() != defOp->getName())
return nullptr;
// If the defining op has the same type, we take it as a replacement.
if (op->getName() == defOp->getName())
return defOp;

return defOp;
values.clear();

// Skip through ops that implement FindPayloadReplacementOpInterface.
if (auto findReplacementOpInterface =
dyn_cast<FindPayloadReplacementOpInterface>(defOp)) {
values.assign(findReplacementOpInterface.getNextOperands());
continue;
}

// Skip through ops that implement CastOpInterface.
if (isa<CastOpInterface>(defOp)) {
values.assign(defOp->getOperands().begin(), defOp->getOperands().end());
continue;
}
} while (!values.empty());

return nullptr;
}

LogicalResult transform::TrackingListener::notifyMatchFailure(
Expand Down
1 change: 0 additions & 1 deletion mlir/test/lib/Dialect/Tensor/CMakeLists.txt
Expand Up @@ -10,7 +10,6 @@ add_mlir_library(MLIRTensorTestPasses
MLIRPass
MLIRSCFDialect
MLIRTensorDialect
MLIRTensorTransformOps
MLIRTensorTransforms
MLIRTransformDialect
MLIRTransforms
Expand Down
6 changes: 3 additions & 3 deletions mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
Expand Up @@ -14,10 +14,10 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"
#include "mlir/Dialect/Tensor/Transforms/TransformUtils.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

Expand Down Expand Up @@ -296,9 +296,9 @@ applyRewriteExtractFromCollapseShapePatterns(Operation *rootOp,
}

namespace {
class DummyTrackingListener : public tensor::TrackingListener {
class DummyTrackingListener : public transform::TrackingListener {
public:
using tensor::TrackingListener::TrackingListener;
using transform::TrackingListener::TrackingListener;

// Expose `findReplacementOp` as a public function, so that it can be tested.
Operation *getReplacementOp(Operation *op, ValueRange newValues) const {
Expand Down
1 change: 0 additions & 1 deletion utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
Expand Up @@ -857,7 +857,6 @@ cc_library(
"//mlir:Pass",
"//mlir:SCFDialect",
"//mlir:TensorDialect",
"//mlir:TensorTransformOps",
"//mlir:TensorTransforms",
"//mlir:TransformDialect",
"//mlir:Transforms",
Expand Down

0 comments on commit 000bc58

Please sign in to comment.