Skip to content

Commit

Permalink
[mlir][vector][bufferize] Implement DestinationStyleOpInterface on Tr…
Browse files Browse the repository at this point in the history
…ansferWriteOp

This simplifies the BufferizableOpInterface implementation of vector.transfer_write.

Differential Revision: https://reviews.llvm.org/D136348
  • Loading branch information
matthias-springer committed Oct 27, 2022
1 parent d858447 commit bf531f2
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 29 deletions.
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/VectorInterfaces.h"
Expand Down
8 changes: 7 additions & 1 deletion mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
include "mlir/Dialect/Vector/Interfaces/MaskingInterfaces.td"
include "mlir/IR/EnumAttr.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/VectorInterfaces.td"
Expand Down Expand Up @@ -1270,7 +1271,8 @@ def Vector_TransferWriteOp :
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
DeclareOpInterfaceMethods<MaskableOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
AttrSizedOperandSegments
AttrSizedOperandSegments,
DestinationStyleOpInterface
]>,
Arguments<(ins AnyVectorOfAnyRank:$vector,
AnyShaped:$source,
Expand Down Expand Up @@ -1393,6 +1395,10 @@ def Vector_TransferWriteOp :
/// This method is added to maintain uniformity with load/store
/// ops of other dialects.
Value getValue() { return getVector(); }

std::pair<int64_t, int64_t> getOutputsPositionRange() {
return {1, 2}; // `source` operand
}
}];

let hasFolder = 1;
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Vector/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ add_mlir_dialect_library(MLIRVectorDialect
MLIRArithDialect
MLIRControlFlowInterfaces
MLIRDataLayoutInterfaces
MLIRDestinationStyleOpInterface
MLIRDialectUtils
MLIRIR
MLIRMaskingInterfaces
Expand Down
34 changes: 6 additions & 28 deletions mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
Expand Down Expand Up @@ -63,35 +64,12 @@ struct TransferReadOpInterface

/// Bufferization of vector.transfer_write. Replace with a new
/// vector.transfer_write that operates on a memref.
///
/// Note: DstBufferizableOpInterfaceExternalModel provides many default method
/// implementations for DestinationStyle ops.
struct TransferWriteOpInterface
: public BufferizableOpInterface::ExternalModel<TransferWriteOpInterface,
vector::TransferWriteOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
assert(opOperand.get().getType().isa<TensorType>() &&
"only tensor types expected");
return true;
}

bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
assert(opOperand.get().getType().isa<TensorType>() &&
"only tensor types expected");
return true;
}

SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
assert(opOperand.get().getType().isa<TensorType>() &&
"only tensor types expected");
return {op->getOpResult(0)};
}

BufferRelation bufferRelation(Operation *op, OpResult opResult,
const AnalysisState &state) const {
return BufferRelation::Equivalent;
}

: public DstBufferizableOpInterfaceExternalModel<TransferWriteOpInterface,
vector::TransferWriteOp> {
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
auto writeOp = cast<vector::TransferWriteOp>(op);
Expand Down
2 changes: 2 additions & 0 deletions utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -3246,6 +3246,7 @@ cc_library(
":ArithDialect",
":ArithUtils",
":ControlFlowInterfaces",
":DestinationStyleOpInterface",
":DialectUtils",
":IR",
":InferTypeOpInterface",
Expand Down Expand Up @@ -8211,6 +8212,7 @@ td_library(
includes = ["include"],
deps = [
":ControlFlowInterfacesTdFiles",
":DestinationStyleOpInterfaceTdFiles",
":InferTypeOpInterfaceTdFiles",
":MaskingInterfacesTdFiles",
":OpBaseTdFiles",
Expand Down

0 comments on commit bf531f2

Please sign in to comment.