-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir][vector] Use source
as the source argument name
#158258
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][vector] Use source
as the source argument name
#158258
Conversation
@llvm/pr-subscribers-mlir-nvgpu @llvm/pr-subscribers-mlir-gpu Author: Andrzej Warzyński (banach-space) ChangesThis patch updates the following ops to use
This change ensures naming consistency with the "builders" for these Ops Specifically, it ensures that we use Patch is 34.69 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/158258.diff 18 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 65ba7e0ad549f..f52075886326b 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -675,7 +675,7 @@ def Vector_ExtractOp :
}];
let arguments = (ins
- AnyVectorOfAnyRank:$vector,
+ AnyVectorOfAnyRank:$source,
Variadic<Index>:$dynamic_position,
DenseI64ArrayAttr:$static_position
);
@@ -692,7 +692,7 @@ def Vector_ExtractOp :
let extraClassDeclaration = extraPoisonClassDeclaration # [{
VectorType getSourceVectorType() {
- return ::llvm::cast<VectorType>(getVector().getType());
+ return ::llvm::cast<VectorType>(getSource().getType());
}
/// Return a vector with all the static and dynamic position indices.
@@ -709,12 +709,17 @@ def Vector_ExtractOp :
bool hasDynamicPosition() {
return !getDynamicPosition().empty();
}
+
+ /// Wrapper for getSource, which replaced getVector.
+ [[deprecated("Use getSource instead!")]] ::mlir::Value getVector() {
+ return getSource();
+ }
}];
let assemblyFormat = [{
- $vector ``
+ $source ``
custom<DynamicIndexList>($dynamic_position, $static_position)
- attr-dict `:` type($result) `from` type($vector)
+ attr-dict `:` type($result) `from` type($source)
}];
let hasCanonicalizer = 1;
@@ -972,6 +977,10 @@ def Vector_ScalableInsertOp :
VectorType getDestVectorType() {
return ::llvm::cast<VectorType>(getDest().getType());
}
+ /// Wrapper for getSource, which replaced getVector.
+ [[deprecated("Use getSource instead!")]] ::mlir::Value getVector() {
+ return getSource();
+ }
}];
}
@@ -1023,6 +1032,10 @@ def Vector_ScalableExtractOp :
VectorType getResultVectorType() {
return ::llvm::cast<VectorType>(getResult().getType());
}
+ /// Wrapper for getSource, which replaced getVector.
+ [[deprecated("Use getSource instead!")]] ::mlir::Value getVector() {
+ return getSource();
+ }
}];
}
@@ -1174,7 +1187,7 @@ def Vector_ExtractStridedSliceOp :
Vector_Op<"extract_strided_slice", [Pure,
PredOpTrait<"operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>]>,
- Arguments<(ins AnyVectorOfNonZeroRank:$vector, I64ArrayAttr:$offsets,
+ Arguments<(ins AnyVectorOfNonZeroRank:$source, I64ArrayAttr:$offsets,
I64ArrayAttr:$sizes, I64ArrayAttr:$strides)>,
Results<(outs AnyVectorOfNonZeroRank)> {
let summary = "extract_strided_slice operation";
@@ -1209,7 +1222,7 @@ def Vector_ExtractStridedSliceOp :
];
let extraClassDeclaration = [{
VectorType getSourceVectorType() {
- return ::llvm::cast<VectorType>(getVector().getType());
+ return ::llvm::cast<VectorType>(getSource().getType());
}
void getOffsets(SmallVectorImpl<int64_t> &results);
bool hasNonUnitStrides() {
@@ -1217,11 +1230,15 @@ def Vector_ExtractStridedSliceOp :
return ::llvm::cast<IntegerAttr>(attr).getInt() != 1;
});
}
+ /// Wrapper for getSource, which replaced getVector.
+ [[deprecated("Use getSource instead!")]] ::mlir::Value getVector() {
+ return getSource();
+ }
}];
let hasCanonicalizer = 1;
let hasFolder = 1;
let hasVerifier = 1;
- let assemblyFormat = "$vector attr-dict `:` type($vector) `to` type(results)";
+ let assemblyFormat = "$source attr-dict `:` type($source) `to` type(results)";
}
// TODO: Tighten semantics so that masks and inbounds can't be used
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 9efa34a9a3acc..4e1da39c29260 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -462,7 +462,7 @@ struct VectorExtractToArmSMELowering
auto loc = extractOp.getLoc();
auto position = extractOp.getMixedPosition();
- Value sourceVector = extractOp.getVector();
+ Value sourceVector = extractOp.getSource();
// Extract entire vector. Should be handled by folder, but just to be safe.
if (position.empty()) {
@@ -692,7 +692,7 @@ struct ExtractFromCreateMaskToPselLowering
return rewriter.notifyMatchFailure(extractOp, "result not VectorType");
auto createMaskOp =
- extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
+ extractOp.getSource().getDefiningOp<vector::CreateMaskOp>();
if (!createMaskOp)
return rewriter.notifyMatchFailure(extractOp, "source not CreateMaskOp");
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 1d1904f717335..e1a22ec6d799b 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -962,7 +962,7 @@ convertExtractStridedSlice(RewriterBase &rewriter,
return rewriter.notifyMatchFailure(op, "no mmaSyncFragmentInfo");
// Find the vector.transer_read whose result vector is being sliced.
- auto transferReadOp = op.getVector().getDefiningOp<vector::TransferReadOp>();
+ auto transferReadOp = op.getSource().getDefiningOp<vector::TransferReadOp>();
if (!transferReadOp)
return rewriter.notifyMatchFailure(op, "no transfer read");
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 9fbac4925dc1d..e7266740894b1 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1131,7 +1131,7 @@ class VectorExtractOpConversion
positionVec.push_back(rewriter.getZeroAttr(idxType));
}
- Value extracted = adaptor.getVector();
+ Value extracted = adaptor.getSource();
if (extractsAggregate) {
ArrayRef<OpFoldResult> position(positionVec);
if (extractsScalar) {
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 508f4e25326eb..c45c45e4712f3 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -1414,7 +1414,7 @@ struct UnrollTransferWriteConversion
/// Return the vector from which newly generated ExtracOps will extract.
Value getDataVector(TransferWriteOp xferOp) const {
if (auto extractOp = getExtractOp(xferOp))
- return extractOp.getVector();
+ return extractOp.getSource();
return xferOp.getVector();
}
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index c861935b4bc18..1c311d0312aaa 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -189,8 +189,8 @@ struct VectorExtractOpConvert final
if (!dstType)
return failure();
- if (isa<spirv::ScalarType>(adaptor.getVector().getType())) {
- rewriter.replaceOp(extractOp, adaptor.getVector());
+ if (isa<spirv::ScalarType>(adaptor.getSource().getType())) {
+ rewriter.replaceOp(extractOp, adaptor.getSource());
return success();
}
@@ -201,7 +201,7 @@ struct VectorExtractOpConvert final
extractOp,
"Static use of poison index handled elsewhere (folded to poison)");
rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
- extractOp, dstType, adaptor.getVector(),
+ extractOp, dstType, adaptor.getSource(),
rewriter.getI32ArrayAttr(id.value()));
} else {
Value sanitizedIndex = sanitizeDynamicIndex(
@@ -209,7 +209,7 @@ struct VectorExtractOpConvert final
vector::ExtractOp::kPoisonIndex,
extractOp.getSourceVectorType().getNumElements());
rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
- extractOp, dstType, adaptor.getVector(), sanitizedIndex);
+ extractOp, dstType, adaptor.getSource(), sanitizedIndex);
}
return success();
}
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
index 9bf026563c255..9196d2ef79592 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
@@ -445,7 +445,7 @@ struct SwapVectorExtractOfArithExtend
return rewriter.notifyMatchFailure(
extractOp, "extracted type is not a 1-D scalable vector type");
- auto *extendOp = extractOp.getVector().getDefiningOp();
+ auto *extendOp = extractOp.getSource().getDefiningOp();
if (!isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(
extendOp))
return rewriter.notifyMatchFailure(extractOp,
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 1c0eced43dc00..576c92b375ff3 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -542,7 +542,7 @@ struct FoldExtractFromVectorOfSMELikeCreateMasks
PatternRewriter &rewriter) const override {
auto loc = extractOp.getLoc();
auto createMaskOp =
- extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
+ extractOp.getSource().getDefiningOp<vector::CreateMaskOp>();
if (!createMaskOp)
return rewriter.notifyMatchFailure(
extractOp, "extract not from vector.create_mask op");
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 36434cf2d2ae2..e1dc40d6d37d9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -105,7 +105,7 @@ void mlir::linalg::hoistRedundantVectorBroadcasts(RewriterBase &rewriter,
return WalkResult::advance();
// Check that the vector to extract from is a BlockArgument.
- auto blockArg = dyn_cast<BlockArgument>(extractOp.getVector());
+ auto blockArg = dyn_cast<BlockArgument>(extractOp.getSource());
if (!blockArg)
return WalkResult::advance();
@@ -141,7 +141,7 @@ void mlir::linalg::hoistRedundantVectorBroadcasts(RewriterBase &rewriter,
return WalkResult::advance();
rewriter.modifyOpInPlace(broadcast, [&] {
- extractOp.getVectorMutable().assign(initArg->get());
+ extractOp.getSourceMutable().assign(initArg->get());
});
loop.moveOutOfLoop(extractOp);
rewriter.moveOpAfter(broadcast, loop);
diff --git a/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp b/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp
index b392ffeb13de6..050bbac2293e9 100644
--- a/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp
+++ b/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp
@@ -71,7 +71,7 @@ static FailureOr<TransferMask> getMaskOp(Operation *loadOp) {
if (auto extractOp =
transferRead.getMask().getDefiningOp<vector::ExtractOp>())
if (auto maskOp =
- extractOp.getVector().getDefiningOp<vector::CreateMaskOp>())
+ extractOp.getSource().getDefiningOp<vector::CreateMaskOp>())
return TransferMask{maskOp,
SmallVector<int64_t>(extractOp.getStaticPosition())};
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 85e485c28c74e..8d6e263934fb4 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1309,7 +1309,7 @@ LogicalResult
ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
ExtractOp::Adaptor adaptor,
SmallVectorImpl<Type> &inferredReturnTypes) {
- auto vectorType = llvm::cast<VectorType>(adaptor.getVector().getType());
+ auto vectorType = llvm::cast<VectorType>(adaptor.getSource().getType());
if (static_cast<int64_t>(adaptor.getStaticPosition().size()) ==
vectorType.getRank()) {
inferredReturnTypes.push_back(vectorType.getElementType());
@@ -1379,7 +1379,7 @@ static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
/// Fold the result of chains of ExtractOp in place by simply concatenating the
/// positions.
static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
- if (!extractOp.getVector().getDefiningOp<ExtractOp>())
+ if (!extractOp.getSource().getDefiningOp<ExtractOp>())
return failure();
// TODO: Canonicalization for dynamic position not implemented yet.
@@ -1390,7 +1390,7 @@ static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
ExtractOp currentOp = extractOp;
ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
globalPosition.append(extrPos.rbegin(), extrPos.rend());
- while (ExtractOp nextOp = currentOp.getVector().getDefiningOp<ExtractOp>()) {
+ while (ExtractOp nextOp = currentOp.getSource().getDefiningOp<ExtractOp>()) {
currentOp = nextOp;
// TODO: Canonicalization for dynamic position not implemented yet.
if (currentOp.hasDynamicPosition())
@@ -1398,7 +1398,7 @@ static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
globalPosition.append(extrPos.rbegin(), extrPos.rend());
}
- extractOp.setOperand(0, currentOp.getVector());
+ extractOp.setOperand(0, currentOp.getSource());
// OpBuilder is only used as a helper to build an I64ArrayAttr.
OpBuilder b(extractOp.getContext());
std::reverse(globalPosition.begin(), globalPosition.end());
@@ -1584,7 +1584,7 @@ Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
return Value();
// If we can't fold (either internal transposition, or nothing to fold), bail.
- bool nothingToFold = (source == extractOp.getVector());
+ bool nothingToFold = (source == extractOp.getSource());
if (nothingToFold || !canFold())
return Value();
@@ -1592,7 +1592,7 @@ Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
OpBuilder b(extractOp.getContext());
extractOp.setStaticPosition(
ArrayRef(extractPosition).take_front(extractedRank));
- extractOp.getVectorMutable().assign(source);
+ extractOp.getSourceMutable().assign(source);
return extractOp.getResult();
}
@@ -1602,7 +1602,7 @@ Value ExtractFromInsertTransposeChainState::fold() {
if (extractOp.hasDynamicPosition())
return Value();
- Value valueToExtractFrom = extractOp.getVector();
+ Value valueToExtractFrom = extractOp.getSource();
updateStateForNextIteration(valueToExtractFrom);
while (nextInsertOp || nextTransposeOp) {
// Case 1. If we hit a transpose, just compose the map and iterate.
@@ -1693,7 +1693,7 @@ static bool isBroadcastLike(Operation *op) {
/// `extract` shape.
static Value foldExtractFromBroadcast(ExtractOp extractOp) {
- Operation *defOp = extractOp.getVector().getDefiningOp();
+ Operation *defOp = extractOp.getSource().getDefiningOp();
if (!defOp || !isBroadcastLike(defOp))
return Value();
@@ -1762,7 +1762,7 @@ static Value foldExtractFromShuffle(ExtractOp extractOp) {
if (extractOp.hasDynamicPosition())
return Value();
- auto shuffleOp = extractOp.getVector().getDefiningOp<ShuffleOp>();
+ auto shuffleOp = extractOp.getSource().getDefiningOp<ShuffleOp>();
if (!shuffleOp)
return Value();
@@ -1793,7 +1793,7 @@ static Value foldExtractFromShapeCast(ExtractOp extractOp) {
if (extractOp.hasDynamicPosition())
return Value();
- auto shapeCastOp = extractOp.getVector().getDefiningOp<vector::ShapeCastOp>();
+ auto shapeCastOp = extractOp.getSource().getDefiningOp<vector::ShapeCastOp>();
if (!shapeCastOp)
return Value();
@@ -1859,7 +1859,7 @@ static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
return Value();
auto extractStridedSliceOp =
- extractOp.getVector().getDefiningOp<vector::ExtractStridedSliceOp>();
+ extractOp.getSource().getDefiningOp<vector::ExtractStridedSliceOp>();
if (!extractStridedSliceOp)
return Value();
@@ -1896,7 +1896,7 @@ static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
assert(extractedPos.size() >= sliceOffsets.size());
for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
extractedPos[i] = extractedPos[i] + sliceOffsets[i];
- extractOp.getVectorMutable().assign(extractStridedSliceOp.getVector());
+ extractOp.getSourceMutable().assign(extractStridedSliceOp.getSource());
// OpBuilder is only used as a helper to build an I64ArrayAttr.
OpBuilder b(extractOp.getContext());
@@ -1914,7 +1914,7 @@ static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
llvm::isa<VectorType>(extractOp.getType())
? llvm::cast<VectorType>(extractOp.getType()).getRank()
: 0;
- auto insertOp = extractOp.getVector().getDefiningOp<InsertStridedSliceOp>();
+ auto insertOp = extractOp.getSource().getDefiningOp<InsertStridedSliceOp>();
if (!insertOp)
return Value();
@@ -1966,7 +1966,7 @@ static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
insertRankDiff))
return Value();
}
- extractOp.getVectorMutable().assign(insertOp.getValueToStore());
+ extractOp.getSourceMutable().assign(insertOp.getValueToStore());
// OpBuilder is only used as a helper to build an I64ArrayAttr.
OpBuilder b(extractOp.getContext());
extractOp.setStaticPosition(offsetDiffs);
@@ -1991,7 +1991,7 @@ static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
return {};
// Look for extract(from_elements).
- auto fromElementsOp = extractOp.getVector().getDefiningOp<FromElementsOp>();
+ auto fromElementsOp = extractOp.getSource().getDefiningOp<FromElementsOp>();
if (!fromElementsOp)
return {};
@@ -2142,20 +2142,20 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
// Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
// Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
// mismatch).
- if (getNumIndices() == 0 && getVector().getType() == getResult().getType())
- return getVector();
- if (auto res = foldPoisonSrcExtractOp(adaptor.getVector()))
+ if (getNumIndices() == 0 && getSource().getType() == getResult().getType())
+ return getSource();
+ if (auto res = foldPoisonSrcExtractOp(adaptor.getSource()))
return res;
// Fold `arith.constant` indices into the `vector.extract` operation.
// Do not stop here as this fold may enable subsequent folds that require
// constant indices.
- SmallVector<Value> operands = {getVector()};
+ SmallVector<Value> operands = {getSource()};
auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);
if (auto res = foldPoisonIndexInsertExtractOp(
getContext(), adaptor.getStaticPosition(), kPoisonIndex))
return res;
- if (auto res = foldDenseElementsAttrSrcExtractOp(*this, adaptor.getVector()))
+ if (auto res = foldDenseElementsAttrSrcExtractOp(*this, adaptor.getSource()))
return res;
if (succeeded(foldExtractOpFromExtractChain(*this)))
return getResult();
@@ -2187,7 +2187,7 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
LogicalResult matchAndRewrite(ExtractOp extractOp,
PatternRewriter &rewriter) const override {
- Operation *defOp = extractOp.getVector().getDefiningOp();
+ Operation *defOp = extractOp.getSource().getDefiningOp();
VectorType outType = dyn_cast<VectorType>(extractOp.getType());
if (!defOp || !isBroadcastLike(defOp) || !outType)
return failure();
@@ -2210,7 +2210,7 @@ class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
LogicalResult matchAndRewrite(ExtractOp extractOp,
PatternRewriter &rewriter) const override {
auto createMaskOp =
- extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
+ extractOp.getSource().getDefiningOp<vector::CreateMaskOp>();
if (!createMaskOp)
return failure();
@@ -2271,7 +2271,7 @@ class ExtractOpFromCr...
[truncated]
|
@llvm/pr-subscribers-mlir-linalg Author: Andrzej Warzyński (banach-space) ChangesThis patch updates the following ops to use
This change ensures naming consistency with the "builders" for these Ops Specifically, it ensures that we use Patch is 34.69 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/158258.diff 18 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 65ba7e0ad549f..f52075886326b 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -675,7 +675,7 @@ def Vector_ExtractOp :
}];
let arguments = (ins
- AnyVectorOfAnyRank:$vector,
+ AnyVectorOfAnyRank:$source,
Variadic<Index>:$dynamic_position,
DenseI64ArrayAttr:$static_position
);
@@ -692,7 +692,7 @@ def Vector_ExtractOp :
let extraClassDeclaration = extraPoisonClassDeclaration # [{
VectorType getSourceVectorType() {
- return ::llvm::cast<VectorType>(getVector().getType());
+ return ::llvm::cast<VectorType>(getSource().getType());
}
/// Return a vector with all the static and dynamic position indices.
@@ -709,12 +709,17 @@ def Vector_ExtractOp :
bool hasDynamicPosition() {
return !getDynamicPosition().empty();
}
+
+ /// Wrapper for getSource, which replaced getVector.
+ [[deprecated("Use getSource instead!")]] ::mlir::Value getVector() {
+ return getSource();
+ }
}];
let assemblyFormat = [{
- $vector ``
+ $source ``
custom<DynamicIndexList>($dynamic_position, $static_position)
- attr-dict `:` type($result) `from` type($vector)
+ attr-dict `:` type($result) `from` type($source)
}];
let hasCanonicalizer = 1;
@@ -972,6 +977,10 @@ def Vector_ScalableInsertOp :
VectorType getDestVectorType() {
return ::llvm::cast<VectorType>(getDest().getType());
}
+ /// Wrapper for getSource, which replaced getVector.
+ [[deprecated("Use getSource instead!")]] ::mlir::Value getVector() {
+ return getSource();
+ }
}];
}
@@ -1023,6 +1032,10 @@ def Vector_ScalableExtractOp :
VectorType getResultVectorType() {
return ::llvm::cast<VectorType>(getResult().getType());
}
+ /// Wrapper for getSource, which replaced getVector.
+ [[deprecated("Use getSource instead!")]] ::mlir::Value getVector() {
+ return getSource();
+ }
}];
}
@@ -1174,7 +1187,7 @@ def Vector_ExtractStridedSliceOp :
Vector_Op<"extract_strided_slice", [Pure,
PredOpTrait<"operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>]>,
- Arguments<(ins AnyVectorOfNonZeroRank:$vector, I64ArrayAttr:$offsets,
+ Arguments<(ins AnyVectorOfNonZeroRank:$source, I64ArrayAttr:$offsets,
I64ArrayAttr:$sizes, I64ArrayAttr:$strides)>,
Results<(outs AnyVectorOfNonZeroRank)> {
let summary = "extract_strided_slice operation";
@@ -1209,7 +1222,7 @@ def Vector_ExtractStridedSliceOp :
];
let extraClassDeclaration = [{
VectorType getSourceVectorType() {
- return ::llvm::cast<VectorType>(getVector().getType());
+ return ::llvm::cast<VectorType>(getSource().getType());
}
void getOffsets(SmallVectorImpl<int64_t> &results);
bool hasNonUnitStrides() {
@@ -1217,11 +1230,15 @@ def Vector_ExtractStridedSliceOp :
return ::llvm::cast<IntegerAttr>(attr).getInt() != 1;
});
}
+ /// Wrapper for getSource, which replaced getVector.
+ [[deprecated("Use getSource instead!")]] ::mlir::Value getVector() {
+ return getSource();
+ }
}];
let hasCanonicalizer = 1;
let hasFolder = 1;
let hasVerifier = 1;
- let assemblyFormat = "$vector attr-dict `:` type($vector) `to` type(results)";
+ let assemblyFormat = "$source attr-dict `:` type($source) `to` type(results)";
}
// TODO: Tighten semantics so that masks and inbounds can't be used
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 9efa34a9a3acc..4e1da39c29260 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -462,7 +462,7 @@ struct VectorExtractToArmSMELowering
auto loc = extractOp.getLoc();
auto position = extractOp.getMixedPosition();
- Value sourceVector = extractOp.getVector();
+ Value sourceVector = extractOp.getSource();
// Extract entire vector. Should be handled by folder, but just to be safe.
if (position.empty()) {
@@ -692,7 +692,7 @@ struct ExtractFromCreateMaskToPselLowering
return rewriter.notifyMatchFailure(extractOp, "result not VectorType");
auto createMaskOp =
- extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
+ extractOp.getSource().getDefiningOp<vector::CreateMaskOp>();
if (!createMaskOp)
return rewriter.notifyMatchFailure(extractOp, "source not CreateMaskOp");
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 1d1904f717335..e1a22ec6d799b 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -962,7 +962,7 @@ convertExtractStridedSlice(RewriterBase &rewriter,
return rewriter.notifyMatchFailure(op, "no mmaSyncFragmentInfo");
// Find the vector.transer_read whose result vector is being sliced.
- auto transferReadOp = op.getVector().getDefiningOp<vector::TransferReadOp>();
+ auto transferReadOp = op.getSource().getDefiningOp<vector::TransferReadOp>();
if (!transferReadOp)
return rewriter.notifyMatchFailure(op, "no transfer read");
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 9fbac4925dc1d..e7266740894b1 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1131,7 +1131,7 @@ class VectorExtractOpConversion
positionVec.push_back(rewriter.getZeroAttr(idxType));
}
- Value extracted = adaptor.getVector();
+ Value extracted = adaptor.getSource();
if (extractsAggregate) {
ArrayRef<OpFoldResult> position(positionVec);
if (extractsScalar) {
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 508f4e25326eb..c45c45e4712f3 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -1414,7 +1414,7 @@ struct UnrollTransferWriteConversion
/// Return the vector from which newly generated ExtracOps will extract.
Value getDataVector(TransferWriteOp xferOp) const {
if (auto extractOp = getExtractOp(xferOp))
- return extractOp.getVector();
+ return extractOp.getSource();
return xferOp.getVector();
}
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index c861935b4bc18..1c311d0312aaa 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -189,8 +189,8 @@ struct VectorExtractOpConvert final
if (!dstType)
return failure();
- if (isa<spirv::ScalarType>(adaptor.getVector().getType())) {
- rewriter.replaceOp(extractOp, adaptor.getVector());
+ if (isa<spirv::ScalarType>(adaptor.getSource().getType())) {
+ rewriter.replaceOp(extractOp, adaptor.getSource());
return success();
}
@@ -201,7 +201,7 @@ struct VectorExtractOpConvert final
extractOp,
"Static use of poison index handled elsewhere (folded to poison)");
rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
- extractOp, dstType, adaptor.getVector(),
+ extractOp, dstType, adaptor.getSource(),
rewriter.getI32ArrayAttr(id.value()));
} else {
Value sanitizedIndex = sanitizeDynamicIndex(
@@ -209,7 +209,7 @@ struct VectorExtractOpConvert final
vector::ExtractOp::kPoisonIndex,
extractOp.getSourceVectorType().getNumElements());
rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
- extractOp, dstType, adaptor.getVector(), sanitizedIndex);
+ extractOp, dstType, adaptor.getSource(), sanitizedIndex);
}
return success();
}
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
index 9bf026563c255..9196d2ef79592 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
@@ -445,7 +445,7 @@ struct SwapVectorExtractOfArithExtend
return rewriter.notifyMatchFailure(
extractOp, "extracted type is not a 1-D scalable vector type");
- auto *extendOp = extractOp.getVector().getDefiningOp();
+ auto *extendOp = extractOp.getSource().getDefiningOp();
if (!isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(
extendOp))
return rewriter.notifyMatchFailure(extractOp,
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 1c0eced43dc00..576c92b375ff3 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -542,7 +542,7 @@ struct FoldExtractFromVectorOfSMELikeCreateMasks
PatternRewriter &rewriter) const override {
auto loc = extractOp.getLoc();
auto createMaskOp =
- extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
+ extractOp.getSource().getDefiningOp<vector::CreateMaskOp>();
if (!createMaskOp)
return rewriter.notifyMatchFailure(
extractOp, "extract not from vector.create_mask op");
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 36434cf2d2ae2..e1dc40d6d37d9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -105,7 +105,7 @@ void mlir::linalg::hoistRedundantVectorBroadcasts(RewriterBase &rewriter,
return WalkResult::advance();
// Check that the vector to extract from is a BlockArgument.
- auto blockArg = dyn_cast<BlockArgument>(extractOp.getVector());
+ auto blockArg = dyn_cast<BlockArgument>(extractOp.getSource());
if (!blockArg)
return WalkResult::advance();
@@ -141,7 +141,7 @@ void mlir::linalg::hoistRedundantVectorBroadcasts(RewriterBase &rewriter,
return WalkResult::advance();
rewriter.modifyOpInPlace(broadcast, [&] {
- extractOp.getVectorMutable().assign(initArg->get());
+ extractOp.getSourceMutable().assign(initArg->get());
});
loop.moveOutOfLoop(extractOp);
rewriter.moveOpAfter(broadcast, loop);
diff --git a/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp b/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp
index b392ffeb13de6..050bbac2293e9 100644
--- a/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp
+++ b/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp
@@ -71,7 +71,7 @@ static FailureOr<TransferMask> getMaskOp(Operation *loadOp) {
if (auto extractOp =
transferRead.getMask().getDefiningOp<vector::ExtractOp>())
if (auto maskOp =
- extractOp.getVector().getDefiningOp<vector::CreateMaskOp>())
+ extractOp.getSource().getDefiningOp<vector::CreateMaskOp>())
return TransferMask{maskOp,
SmallVector<int64_t>(extractOp.getStaticPosition())};
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 85e485c28c74e..8d6e263934fb4 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1309,7 +1309,7 @@ LogicalResult
ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
ExtractOp::Adaptor adaptor,
SmallVectorImpl<Type> &inferredReturnTypes) {
- auto vectorType = llvm::cast<VectorType>(adaptor.getVector().getType());
+ auto vectorType = llvm::cast<VectorType>(adaptor.getSource().getType());
if (static_cast<int64_t>(adaptor.getStaticPosition().size()) ==
vectorType.getRank()) {
inferredReturnTypes.push_back(vectorType.getElementType());
@@ -1379,7 +1379,7 @@ static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
/// Fold the result of chains of ExtractOp in place by simply concatenating the
/// positions.
static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
- if (!extractOp.getVector().getDefiningOp<ExtractOp>())
+ if (!extractOp.getSource().getDefiningOp<ExtractOp>())
return failure();
// TODO: Canonicalization for dynamic position not implemented yet.
@@ -1390,7 +1390,7 @@ static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
ExtractOp currentOp = extractOp;
ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
globalPosition.append(extrPos.rbegin(), extrPos.rend());
- while (ExtractOp nextOp = currentOp.getVector().getDefiningOp<ExtractOp>()) {
+ while (ExtractOp nextOp = currentOp.getSource().getDefiningOp<ExtractOp>()) {
currentOp = nextOp;
// TODO: Canonicalization for dynamic position not implemented yet.
if (currentOp.hasDynamicPosition())
@@ -1398,7 +1398,7 @@ static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
globalPosition.append(extrPos.rbegin(), extrPos.rend());
}
- extractOp.setOperand(0, currentOp.getVector());
+ extractOp.setOperand(0, currentOp.getSource());
// OpBuilder is only used as a helper to build an I64ArrayAttr.
OpBuilder b(extractOp.getContext());
std::reverse(globalPosition.begin(), globalPosition.end());
@@ -1584,7 +1584,7 @@ Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
return Value();
// If we can't fold (either internal transposition, or nothing to fold), bail.
- bool nothingToFold = (source == extractOp.getVector());
+ bool nothingToFold = (source == extractOp.getSource());
if (nothingToFold || !canFold())
return Value();
@@ -1592,7 +1592,7 @@ Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
OpBuilder b(extractOp.getContext());
extractOp.setStaticPosition(
ArrayRef(extractPosition).take_front(extractedRank));
- extractOp.getVectorMutable().assign(source);
+ extractOp.getSourceMutable().assign(source);
return extractOp.getResult();
}
@@ -1602,7 +1602,7 @@ Value ExtractFromInsertTransposeChainState::fold() {
if (extractOp.hasDynamicPosition())
return Value();
- Value valueToExtractFrom = extractOp.getVector();
+ Value valueToExtractFrom = extractOp.getSource();
updateStateForNextIteration(valueToExtractFrom);
while (nextInsertOp || nextTransposeOp) {
// Case 1. If we hit a transpose, just compose the map and iterate.
@@ -1693,7 +1693,7 @@ static bool isBroadcastLike(Operation *op) {
/// `extract` shape.
static Value foldExtractFromBroadcast(ExtractOp extractOp) {
- Operation *defOp = extractOp.getVector().getDefiningOp();
+ Operation *defOp = extractOp.getSource().getDefiningOp();
if (!defOp || !isBroadcastLike(defOp))
return Value();
@@ -1762,7 +1762,7 @@ static Value foldExtractFromShuffle(ExtractOp extractOp) {
if (extractOp.hasDynamicPosition())
return Value();
- auto shuffleOp = extractOp.getVector().getDefiningOp<ShuffleOp>();
+ auto shuffleOp = extractOp.getSource().getDefiningOp<ShuffleOp>();
if (!shuffleOp)
return Value();
@@ -1793,7 +1793,7 @@ static Value foldExtractFromShapeCast(ExtractOp extractOp) {
if (extractOp.hasDynamicPosition())
return Value();
- auto shapeCastOp = extractOp.getVector().getDefiningOp<vector::ShapeCastOp>();
+ auto shapeCastOp = extractOp.getSource().getDefiningOp<vector::ShapeCastOp>();
if (!shapeCastOp)
return Value();
@@ -1859,7 +1859,7 @@ static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
return Value();
auto extractStridedSliceOp =
- extractOp.getVector().getDefiningOp<vector::ExtractStridedSliceOp>();
+ extractOp.getSource().getDefiningOp<vector::ExtractStridedSliceOp>();
if (!extractStridedSliceOp)
return Value();
@@ -1896,7 +1896,7 @@ static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
assert(extractedPos.size() >= sliceOffsets.size());
for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
extractedPos[i] = extractedPos[i] + sliceOffsets[i];
- extractOp.getVectorMutable().assign(extractStridedSliceOp.getVector());
+ extractOp.getSourceMutable().assign(extractStridedSliceOp.getSource());
// OpBuilder is only used as a helper to build an I64ArrayAttr.
OpBuilder b(extractOp.getContext());
@@ -1914,7 +1914,7 @@ static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
llvm::isa<VectorType>(extractOp.getType())
? llvm::cast<VectorType>(extractOp.getType()).getRank()
: 0;
- auto insertOp = extractOp.getVector().getDefiningOp<InsertStridedSliceOp>();
+ auto insertOp = extractOp.getSource().getDefiningOp<InsertStridedSliceOp>();
if (!insertOp)
return Value();
@@ -1966,7 +1966,7 @@ static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
insertRankDiff))
return Value();
}
- extractOp.getVectorMutable().assign(insertOp.getValueToStore());
+ extractOp.getSourceMutable().assign(insertOp.getValueToStore());
// OpBuilder is only used as a helper to build an I64ArrayAttr.
OpBuilder b(extractOp.getContext());
extractOp.setStaticPosition(offsetDiffs);
@@ -1991,7 +1991,7 @@ static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
return {};
// Look for extract(from_elements).
- auto fromElementsOp = extractOp.getVector().getDefiningOp<FromElementsOp>();
+ auto fromElementsOp = extractOp.getSource().getDefiningOp<FromElementsOp>();
if (!fromElementsOp)
return {};
@@ -2142,20 +2142,20 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
// Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
// Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
// mismatch).
- if (getNumIndices() == 0 && getVector().getType() == getResult().getType())
- return getVector();
- if (auto res = foldPoisonSrcExtractOp(adaptor.getVector()))
+ if (getNumIndices() == 0 && getSource().getType() == getResult().getType())
+ return getSource();
+ if (auto res = foldPoisonSrcExtractOp(adaptor.getSource()))
return res;
// Fold `arith.constant` indices into the `vector.extract` operation.
// Do not stop here as this fold may enable subsequent folds that require
// constant indices.
- SmallVector<Value> operands = {getVector()};
+ SmallVector<Value> operands = {getSource()};
auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);
if (auto res = foldPoisonIndexInsertExtractOp(
getContext(), adaptor.getStaticPosition(), kPoisonIndex))
return res;
- if (auto res = foldDenseElementsAttrSrcExtractOp(*this, adaptor.getVector()))
+ if (auto res = foldDenseElementsAttrSrcExtractOp(*this, adaptor.getSource()))
return res;
if (succeeded(foldExtractOpFromExtractChain(*this)))
return getResult();
@@ -2187,7 +2187,7 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
LogicalResult matchAndRewrite(ExtractOp extractOp,
PatternRewriter &rewriter) const override {
- Operation *defOp = extractOp.getVector().getDefiningOp();
+ Operation *defOp = extractOp.getSource().getDefiningOp();
VectorType outType = dyn_cast<VectorType>(extractOp.getType());
if (!defOp || !isBroadcastLike(defOp) || !outType)
return failure();
@@ -2210,7 +2210,7 @@ class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
LogicalResult matchAndRewrite(ExtractOp extractOp,
PatternRewriter &rewriter) const override {
auto createMaskOp =
- extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
+ extractOp.getSource().getDefiningOp<vector::CreateMaskOp>();
if (!createMaskOp)
return failure();
@@ -2271,7 +2271,7 @@ class ExtractOpFromCr...
[truncated]
|
This patch updates the following ops to use `source` (instead of `vector`) as the name for their source argument: * `vector.extract` * `vector.scalable.extract` * `vector.extract_strided_slice` This change ensures naming consistency with the "builders" for these Ops that already use the name `source` rather than `vector`. It also addresses part of: * llvm#131602 Specifically, it ensures that we use `source` and `dest` for read and write operations, respectively (as opposed to `vector` and `dest`).
f713b49
to
36fe297
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Should be NFC with the deprecated wrapper.
This patch updates the following ops to use
source
(instead ofvector
)as the name for their source argument:
vector.extract
vector.scalable.extract
vector.extract_strided_slice
This change ensures naming consistency with the "builders" for these Ops
that already use the name
source
rather thanvector
. It alsoaddresses part of:
Specifically, it ensures that we use
source
anddest
for read andwrite operations, respectively (as opposed to
vector
anddest
).