Skip to content

Commit

Permalink
[mlir][interfaces] Remove getDestinationOperands from TilingInterface
Browse files Browse the repository at this point in the history
`getDestinationOperands` was almost a duplicate of `DestinationStyleOpInterface::getOutputOperands`. Now that the interface has been moved to mlir/Interfaces, it is no longer needed.

Differential Revision: https://reviews.llvm.org/D136240
  • Loading branch information
matthias-springer committed Oct 24, 2022
1 parent a1317be commit b169643
Show file tree
Hide file tree
Showing 13 changed files with 142 additions and 68 deletions.
12 changes: 12 additions & 0 deletions mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,18 @@ Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc,
Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc,
Value tensor, Value dest);

/// This is a helper function for DestinationStyleOpInterface. If there is a
/// destination operand for the given OpResult, return that operand. Otherwise,
/// return an empty tensor (`tensor.empty`) with the shape of the OpResult.
/// Dynamic dimensions are queried via ReifyRankedShapedTypeOpInterface.
FailureOr<Value> getOrCreateDestination(OpBuilder &b, Location loc,
OpResult opResult);

/// This is a helper function for DestinationStyleOpInterface. Get or create
/// destinations for every tensor OpResult of the given op.
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op,
SmallVector<Value> &result);

/// Function to control the folding of constant and extract slice
using ControlConstantExtractSliceFusionFn = std::function<bool(ExtractSliceOp)>;

Expand Down
17 changes: 14 additions & 3 deletions mlir/include/mlir/Interfaces/DestinationStyleOpInterface.td
Original file line number Diff line number Diff line change
Expand Up @@ -199,18 +199,18 @@ def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> {
/*args=*/(ins "OpOperand *":$opOperand),
/*methodBody=*/"",
/*defaultImplementation=*/[{
assert(opOperand->getOwner() == this->getOperation());
assert(opOperand->getOwner() == $_op.getOperation());
return !opOperand->get().getType().template isa<ShapedType>();
}]
>,
InterfaceMethod<
/*desc=*/"Return the result tied to `opOperand`.",
/*desc=*/"Return the OpResult that is tied to the given OpOperand.",
/*retTy=*/"OpResult",
/*methodName=*/"getTiedOpResult",
/*args=*/(ins "OpOperand *":$opOperand),
/*methodBody=*/"",
/*defaultImplementation=*/[{
assert(opOperand->getOwner() == this->getOperation());
assert(opOperand->getOwner() == $_op.getOperation());

auto [start, end] = $_op.getOutputsPositionRange();
int64_t resultIndex = opOperand->getOperandNumber() - start;
Expand All @@ -219,6 +219,17 @@ def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> {
return $_op->getResult(resultIndex);
}]
>,
InterfaceMethod<
/*desc=*/"Return the OpOperand that is tied to the given OpResult.",
/*retTy=*/"OpOperand *",
/*methodName=*/"getTiedOpOperand",
/*args=*/(ins "OpResult":$opResult),
/*methodBody=*/"",
/*defaultImplementation=*/[{
assert(opResult.getDefiningOp() == $_op.getOperation());
return $_op.getOutputOperand(opResult.getResultNumber());
}]
>,
//===------------------------------------------------------------------===//
// Other interface methods.
//===------------------------------------------------------------------===//
Expand Down
15 changes: 0 additions & 15 deletions mlir/include/mlir/Interfaces/TilingInterface.td
Original file line number Diff line number Diff line change
Expand Up @@ -24,21 +24,6 @@ def TilingInterface : OpInterface<"TilingInterface"> {
}];
let cppNamespace = "::mlir";
let methods = [
InterfaceMethod<
/*desc=*/[{
Returns a list of operands into which the result of the
tiled implementation is written into. With `tensor`
operands, this will be used as the initial tensor into which
the tiled results are inserted into. With `memref` operands,
this will be the operand into which the result of the tiled
operation is written into.
}],
/*retType=*/"SmallVector<Value>",
/*methodName=*/"getDestinationOperands",
/*args=*/(ins "OpBuilder &":$b),
/*methodBody=*/"",
/*defaultImplementation=*/"return ValueRange{};"
>,
InterfaceMethod<
/*desc=*/[{
Returns a list of iterator types that describe the number of loops.
Expand Down
14 changes: 11 additions & 3 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -375,10 +375,18 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
int64_t resultNumber = pUse->get().cast<OpResult>().getResultNumber();
LLVM_DEBUG(llvm::dbgs() << "resultNumber: " << resultNumber << "\n");

auto destinationOperands = tileableProducer.getDestinationOperands(rewriter);
// Gather destination tensors.
SmallVector<Value> destinationTensors;
if (failed(tensor::getOrCreateDestinations(
rewriter, tileableProducer->getLoc(), tileableProducer,
destinationTensors))) {
diag.attachNote(tileableProducer->getLoc())
<< "failed to get destination tensors for: " << *tileableProducer;
return nullptr;
}

BlockAndValueMapping bvm;
bvm.map(destinationOperands[resultNumber], bbArg);
bvm.map(destinationTensors[resultNumber], bbArg);
auto tileableProducerClone =
cast<TilingInterface>(rewriter.clone(*tileableProducer, bvm));
auto scopeGuard =
Expand All @@ -403,7 +411,7 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
// Replace the use in containingOp.
rewriter.updateRootInPlace(containingOp, [&]() {
containingOp->setOperand(pUse->getOperandNumber(),
destinationOperands.front());
destinationTensors.front());
});

return fusedOp;
Expand Down
12 changes: 9 additions & 3 deletions mlir/lib/Dialect/Linalg/Transforms/Split.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,18 @@ linalg::splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension,
return {op, TilingInterface()};
}

// Compute destination tensors.
SmallVector<Value> destinationTensors;
LogicalResult destStatus = tensor::getOrCreateDestinations(
rewriter, op.getLoc(), op, destinationTensors);
(void)destStatus;
assert(succeeded(destStatus) && "failed to get destination tensors");

// Create the first part.
SmallVector<Value> firstResults;
TilingInterface firstPart = createSplitPart(
rewriter, op.getLoc(), op, offsets, sizes,
op.getDestinationOperands(rewriter), dimension, minSplitPoint,
iterationSpace[dimension].offset, firstResults);
rewriter, op.getLoc(), op, offsets, sizes, destinationTensors, dimension,
minSplitPoint, iterationSpace[dimension].offset, firstResults);

// Need to pretend that the original op now takes as operands firstResults,
// otherwise tiling interface implementation will take the wrong value to
Expand Down
16 changes: 11 additions & 5 deletions mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,11 @@ static FailureOr<ForeachThreadTilingResult> tileToForeachThreadOpImpl(
auto hasStrideOne = [](Range r) { return !isConstantIntValue(r.stride, 1); };
if (llvm::any_of(loopRanges, hasStrideOne))
return op->emitOpError("only stride-1 supported atm");
auto dest = op.getDestinationOperands(b);

// Gather destination tensors.
SmallVector<Value> dest;
if (failed(tensor::getOrCreateDestinations(b, loc, op, dest)))
return op->emitOpError("failed to get destination tensors");

SmallVector<OpFoldResult> nonZeroNumThreads =
llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) {
Expand Down Expand Up @@ -622,11 +626,13 @@ static LogicalResult tilePadOp(RewriterBase &builder, tensor::PadOp op,
getValueOrCreateConstantIndexOp(builder, loc, tileSizes[i]));
}
}
// Generate loop nest: One loop per dimension.
SmallVector<Value> destOperand =
tilingInterface.getDestinationOperands(builder);
SmallVector<Value> destinationTensors;
if (failed(tensor::getOrCreateDestinations(builder, loc, tilingInterface,
destinationTensors)))
return failure();

loopNest = mlir::scf::buildLoopNest(
builder, loc, lbs, /*ubs=*/dims, steps, ValueRange(destOperand),
builder, loc, lbs, /*ubs=*/dims, steps, ValueRange(destinationTensors),
[&](OpBuilder &b, Location loc, ValueRange localIvs,
ValueRange iterArgs) -> scf::ValueVector {
// Compute offsets and sizes of ExtractSliceOp.
Expand Down
5 changes: 0 additions & 5 deletions mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,6 @@ template <typename LinalgOpTy>
struct LinalgOpTilingInterface
: public TilingInterface::ExternalModel<LinalgOpTilingInterface<LinalgOpTy>,
LinalgOpTy> {
/// Return the destination operands.
SmallVector<Value> getDestinationOperands(Operation *op, OpBuilder &b) const {
return cast<LinalgOp>(op).getOutputOperands();
}

/// Return the loop iterator type.
SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
LinalgOpTy concreteOp = cast<LinalgOpTy>(op);
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
MLIRArithDialect
MLIRBufferizationDialect
MLIRBufferizationTransforms
MLIRDestinationStyleOpInterface
MLIRDialectUtils
MLIRIR
MLIRMemRefDialect
Expand Down
46 changes: 27 additions & 19 deletions mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "llvm/Support/Debug.h"

Expand Down Expand Up @@ -274,6 +275,12 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
op, "missing tile size computation function");
}

// Get destination tensors.
SmallVector<Value> destinationTensors;
if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op,
destinationTensors)))
return rewriter.notifyMatchFailure(op, "failed to get destinations");

// 1. Get the range of the loops that are represented by the operation.
SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
size_t numLoops = iterationDomain.size();
Expand Down Expand Up @@ -378,17 +385,21 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
}
}

FailureOr<SmallVector<Value>> replacementOr =
yieldTiledValues(rewriter, op.getDestinationOperands(rewriter),
tilingResult.tiledOp->getResults(), resultOffsetsList,
resultSizesList, tilingResult.loops);
FailureOr<SmallVector<Value>> replacementOr = yieldTiledValues(
rewriter, destinationTensors, tilingResult.tiledOp->getResults(),
resultOffsetsList, resultSizesList, tilingResult.loops);
if (failed(replacementOr))
return rewriter.notifyMatchFailure(op, "failed to yield replacement");
if (auto tiledInterfaceOp = dyn_cast<TilingInterface>(tilingResult.tiledOp)) {

if (auto dstOp =
dyn_cast<DestinationStyleOpInterface>(tilingResult.tiledOp)) {
auto innerMostLoop = tilingResult.loops.back();
updateDestinationOperandsForTiledOp(
rewriter, tiledInterfaceOp.getDestinationOperands(rewriter),
innerMostLoop.getRegionIterArgs());
SmallVector<Value> destinationTensors = dstOp.getOutputOperands();
assert(destinationTensors.size() ==
innerMostLoop.getRegionIterArgs().size() &&
"unexpected number of outputs");
updateDestinationOperandsForTiledOp(rewriter, destinationTensors,
innerMostLoop.getRegionIterArgs());
}

tilingResult.replacements = replacementOr.value();
Expand Down Expand Up @@ -567,20 +578,17 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
}
if (iterArgNumber) {
int64_t resultNumber = fusableProducer.getResultNumber();
if (auto producerOp =
dyn_cast<TilingInterface>(fusableProducer.getOwner())) {
SmallVector<Value> destination =
producerOp.getDestinationOperands(rewriter);
outerMostLoop.setIterArg(iterArgNumber.value(),
destination[resultNumber]);
if (auto dstOp = dyn_cast<DestinationStyleOpInterface>(
fusableProducer.getOwner())) {
outerMostLoop.setIterArg(
iterArgNumber.value(),
dstOp.getTiedOpOperand(fusableProducer)->get());
}
if (auto tiledAndFusedInterfaceOp =
fusedProducerValue.value().getDefiningOp<TilingInterface>()) {
if (auto dstOp = fusedProducerValue.value()
.getDefiningOp<DestinationStyleOpInterface>()) {
scf::ForOp innerMostLoop = tileAndFuseResult.loops.back();
SmallVector<Value> destination =
tiledAndFusedInterfaceOp.getDestinationOperands(rewriter);
updateDestinationOperandsForTiledOp(
rewriter, destination[resultNumber],
rewriter, dstOp.getOutputOperand(resultNumber)->get(),
innerMostLoop.getRegionIterArgs()[iterArgNumber.value()]);
}
}
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ add_mlir_dialect_library(MLIRTensorDialect
MLIRArithUtils
MLIRCastInterfaces
MLIRComplexDialect
MLIRDestinationStyleOpInterface
MLIRDialectUtils
MLIRIR
MLIRInferTypeOpInterface
Expand Down
54 changes: 54 additions & 0 deletions mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
Expand Down Expand Up @@ -54,6 +55,59 @@ SmallVector<OpFoldResult> tensor::getMixedSizes(OpBuilder &builder,
return result;
}

FailureOr<Value> tensor::getOrCreateDestination(OpBuilder &b, Location loc,
OpResult opResult) {
auto tensorType = opResult.getType().dyn_cast<TensorType>();
assert(tensorType && "expected tensor type");

// If the op has a destination, it implements DestinationStyleOpInterface and
// we can query the destination operand from that interface.
auto destOp = opResult.getDefiningOp<DestinationStyleOpInterface>();
if (destOp)
return destOp.getTiedOpOperand(opResult)->get();

// Otherwise, create a new destination tensor with the same shape.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(opResult.getDefiningOp());

// Compute sizes.
SmallVector<OpFoldResult> mixedSizes;
if (!tensorType.hasStaticShape()) {
// Dynamic shape: Query ReifyRankedShapedTypeOpInterface.
ReifiedRankedShapedTypeDims reifiedShapes;
ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface =
dyn_cast<ReifyRankedShapedTypeOpInterface>(opResult.getDefiningOp());
if (!reifyShapedTypeInterface)
return failure();
if (failed(reifyShapedTypeInterface.reifyResultShapes(b, reifiedShapes)))
return failure();
mixedSizes = getAsOpFoldResult(reifiedShapes[opResult.getResultNumber()]);
} else {
// Static shape: Take static sizes directly.
for (int64_t sz : tensorType.getShape())
mixedSizes.push_back(b.getIndexAttr(sz));
}

// Create empty tensor.
Value emptyTensor =
b.create<tensor::EmptyOp>(loc, mixedSizes, tensorType.getElementType());
return emptyTensor;
}

LogicalResult tensor::getOrCreateDestinations(OpBuilder &b, Location loc,
Operation *op,
SmallVector<Value> &result) {
for (OpResult opResult : op->getResults()) {
if (opResult.getType().isa<TensorType>()) {
FailureOr<Value> destination = getOrCreateDestination(b, loc, opResult);
if (failed(destination))
return failure();
result.push_back(*destination);
}
}
return success();
}

//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//
Expand Down
15 changes: 0 additions & 15 deletions mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,21 +21,6 @@ namespace {

struct PadOpTiling : public TilingInterface::ExternalModel<PadOpTiling, PadOp> {

SmallVector<Value> getDestinationOperands(Operation *op, OpBuilder &b) const {
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(op);
ReifiedRankedShapedTypeDims reifiedShapes;
ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface =
dyn_cast<ReifyRankedShapedTypeOpInterface>(op);
(void)reifyShapedTypeInterface.reifyResultShapes(b, reifiedShapes);

auto padOp = cast<PadOp>(op);
SmallVector<OpFoldResult> mixedSizes = getAsOpFoldResult(reifiedShapes[0]);
Value emptyTensor = b.create<EmptyOp>(
op->getLoc(), mixedSizes, padOp.getResultType().getElementType());
return {emptyTensor};
}

SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
auto padOp = cast<PadOp>(op);
SmallVector<utils::IteratorType> iteratorTypes(
Expand Down
2 changes: 2 additions & 0 deletions utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -1902,6 +1902,7 @@ cc_library(
":ArithUtils",
":BufferizationDialect",
":BufferizationTransforms",
":DestinationStyleOpInterface",
":DialectUtils",
":FuncDialect",
":IR",
Expand Down Expand Up @@ -5189,6 +5190,7 @@ cc_library(
":CastOpInterfaces",
":ComplexDialect",
":ControlFlowInterfaces",
":DestinationStyleOpInterface",
":DialectUtils",
":IR",
":InferTypeOpInterface",
Expand Down

0 comments on commit b169643

Please sign in to comment.