From 36fe297dd3d5537bedb9d737e356dac2d98fa726 Mon Sep 17 00:00:00 2001 From: Andrzej Warzynski Date: Fri, 12 Sep 2025 09:44:53 +0000 Subject: [PATCH] [mlir][vector] Use `source` as the source argument name This patch updates the following ops to use `source` (instead of `vector`) as the name for their source argument: * `vector.extract` * `vector.scalable.extract` * `vector.extract_strided_slice` This change ensures naming consistency with the "builders" for these Ops that already use the name `source` rather than `vector`. It also addresses part of: * https://github.com/llvm/llvm-project/issues/131602 Specifically, it ensures that we use `source` and `dest` for read and write operations, respectively (as opposed to `vector` and `dest`). --- .../mlir/Dialect/Vector/IR/VectorOps.td | 27 ++++++-- .../VectorToArmSME/VectorToArmSME.cpp | 4 +- .../Conversion/VectorToGPU/VectorToGPU.cpp | 2 +- .../VectorToLLVM/ConvertVectorToLLVM.cpp | 2 +- .../Conversion/VectorToSCF/VectorToSCF.cpp | 2 +- .../VectorToSPIRV/VectorToSPIRV.cpp | 8 +-- .../ArmSME/Transforms/OuterProductFusion.cpp | 2 +- .../ArmSME/Transforms/VectorLegalization.cpp | 2 +- .../Dialect/Linalg/Transforms/Hoisting.cpp | 4 +- .../NVGPU/Transforms/CreateAsyncGroups.cpp | 2 +- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 68 +++++++++---------- .../Vector/Transforms/VectorDistribute.cpp | 8 +-- .../Transforms/VectorDropLeadUnitDim.cpp | 2 +- .../Transforms/VectorEmulateNarrowType.cpp | 2 +- ...sertExtractStridedSliceRewritePatterns.cpp | 8 +-- .../Vector/Transforms/VectorLinearize.cpp | 8 +-- .../Transforms/VectorTransferOpTransforms.cpp | 2 +- .../Vector/Transforms/VectorTransforms.cpp | 8 +-- 18 files changed, 87 insertions(+), 74 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 65ba7e0ad549f..f6b151f6ed644 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -675,7 +675,7 @@ def Vector_ExtractOp : }]; let arguments = (ins - AnyVectorOfAnyRank:$vector, + AnyVectorOfAnyRank:$source, Variadic:$dynamic_position, DenseI64ArrayAttr:$static_position ); @@ -692,7 +692,7 @@ def Vector_ExtractOp : let extraClassDeclaration = extraPoisonClassDeclaration # [{ VectorType getSourceVectorType() { - return ::llvm::cast(getVector().getType()); + return ::llvm::cast(getSource().getType()); } /// Return a vector with all the static and dynamic position indices. @@ -709,12 +709,17 @@ def Vector_ExtractOp : bool hasDynamicPosition() { return !getDynamicPosition().empty(); } + + /// Wrapper for getSource, which replaced getVector. + [[deprecated("Use getSource instead!")]] ::mlir::Value getVector() { + return getSource(); + } }]; let assemblyFormat = [{ - $vector `` + $source `` custom($dynamic_position, $static_position) - attr-dict `:` type($result) `from` type($vector) + attr-dict `:` type($result) `from` type($source) }]; let hasCanonicalizer = 1; @@ -1023,6 +1028,10 @@ def Vector_ScalableExtractOp : VectorType getResultVectorType() { return ::llvm::cast(getResult().getType()); } + /// Wrapper for getSource, which replaced getVector. + [[deprecated("Use getSource instead!")]] ::mlir::Value getVector() { + return getSource(); + } }]; } @@ -1174,7 +1183,7 @@ def Vector_ExtractStridedSliceOp : Vector_Op<"extract_strided_slice", [Pure, PredOpTrait<"operand and result have same element type", TCresVTEtIsSameAsOpBase<0, 0>>]>, - Arguments<(ins AnyVectorOfNonZeroRank:$vector, I64ArrayAttr:$offsets, + Arguments<(ins AnyVectorOfNonZeroRank:$source, I64ArrayAttr:$offsets, I64ArrayAttr:$sizes, I64ArrayAttr:$strides)>, Results<(outs AnyVectorOfNonZeroRank)> { let summary = "extract_strided_slice operation"; @@ -1209,7 +1218,7 @@ def Vector_ExtractStridedSliceOp : ]; let extraClassDeclaration = [{ VectorType getSourceVectorType() { - return ::llvm::cast(getVector().getType()); + return ::llvm::cast(getSource().getType()); } void getOffsets(SmallVectorImpl &results); bool hasNonUnitStrides() { @@ -1217,11 +1226,15 @@ def Vector_ExtractStridedSliceOp : return ::llvm::cast(attr).getInt() != 1; }); } + /// Wrapper for getSource, which replaced getVector. + [[deprecated("Use getSource instead!")]] ::mlir::Value getVector() { + return getSource(); + } }]; let hasCanonicalizer = 1; let hasFolder = 1; let hasVerifier = 1; - let assemblyFormat = "$vector attr-dict `:` type($vector) `to` type(results)"; + let assemblyFormat = "$source attr-dict `:` type($source) `to` type(results)"; } // TODO: Tighten semantics so that masks and inbounds can't be used diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp index 9efa34a9a3acc..4e1da39c29260 100644 --- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp +++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp @@ -462,7 +462,7 @@ struct VectorExtractToArmSMELowering auto loc = extractOp.getLoc(); auto position = extractOp.getMixedPosition(); - Value sourceVector = extractOp.getVector(); + Value sourceVector = extractOp.getSource(); // Extract entire vector. Should be handled by folder, but just to be safe. if (position.empty()) { @@ -692,7 +692,7 @@ struct ExtractFromCreateMaskToPselLowering return rewriter.notifyMatchFailure(extractOp, "result not VectorType"); auto createMaskOp = - extractOp.getVector().getDefiningOp(); + extractOp.getSource().getDefiningOp(); if (!createMaskOp) return rewriter.notifyMatchFailure(extractOp, "source not CreateMaskOp"); diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp index 1d1904f717335..e1a22ec6d799b 100644 --- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp +++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp @@ -962,7 +962,7 @@ convertExtractStridedSlice(RewriterBase &rewriter, return rewriter.notifyMatchFailure(op, "no mmaSyncFragmentInfo"); // Find the vector.transer_read whose result vector is being sliced. - auto transferReadOp = op.getVector().getDefiningOp(); + auto transferReadOp = op.getSource().getDefiningOp(); if (!transferReadOp) return rewriter.notifyMatchFailure(op, "no transfer read"); diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 9fbac4925dc1d..e7266740894b1 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1131,7 +1131,7 @@ class VectorExtractOpConversion positionVec.push_back(rewriter.getZeroAttr(idxType)); } - Value extracted = adaptor.getVector(); + Value extracted = adaptor.getSource(); if (extractsAggregate) { ArrayRef position(positionVec); if (extractsScalar) { diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp index 508f4e25326eb..c45c45e4712f3 100644 --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -1414,7 +1414,7 @@ struct UnrollTransferWriteConversion /// Return the vector from which newly generated ExtracOps will extract. Value getDataVector(TransferWriteOp xferOp) const { if (auto extractOp = getExtractOp(xferOp)) - return extractOp.getVector(); + return extractOp.getSource(); return xferOp.getVector(); } diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index c861935b4bc18..1c311d0312aaa 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -189,8 +189,8 @@ struct VectorExtractOpConvert final if (!dstType) return failure(); - if (isa(adaptor.getVector().getType())) { - rewriter.replaceOp(extractOp, adaptor.getVector()); + if (isa(adaptor.getSource().getType())) { + rewriter.replaceOp(extractOp, adaptor.getSource()); return success(); } @@ -201,7 +201,7 @@ struct VectorExtractOpConvert final extractOp, "Static use of poison index handled elsewhere (folded to poison)"); rewriter.replaceOpWithNewOp( - extractOp, dstType, adaptor.getVector(), + extractOp, dstType, adaptor.getSource(), rewriter.getI32ArrayAttr(id.value())); } else { Value sanitizedIndex = sanitizeDynamicIndex( @@ -209,7 +209,7 @@ struct VectorExtractOpConvert final vector::ExtractOp::kPoisonIndex, extractOp.getSourceVectorType().getNumElements()); rewriter.replaceOpWithNewOp( - extractOp, dstType, adaptor.getVector(), sanitizedIndex); + extractOp, dstType, adaptor.getSource(), sanitizedIndex); } return success(); } diff --git a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp index 9bf026563c255..9196d2ef79592 100644 --- a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp +++ b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp @@ -445,7 +445,7 @@ struct SwapVectorExtractOfArithExtend return rewriter.notifyMatchFailure( extractOp, "extracted type is not a 1-D scalable vector type"); - auto *extendOp = extractOp.getVector().getDefiningOp(); + auto *extendOp = extractOp.getSource().getDefiningOp(); if (!isa_and_present( extendOp)) return rewriter.notifyMatchFailure(extractOp, diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp index 1c0eced43dc00..576c92b375ff3 100644 --- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp +++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp @@ -542,7 +542,7 @@ struct FoldExtractFromVectorOfSMELikeCreateMasks PatternRewriter &rewriter) const override { auto loc = extractOp.getLoc(); auto createMaskOp = - extractOp.getVector().getDefiningOp(); + extractOp.getSource().getDefiningOp(); if (!createMaskOp) return rewriter.notifyMatchFailure( extractOp, "extract not from vector.create_mask op"); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp index 36434cf2d2ae2..e1dc40d6d37d9 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp @@ -105,7 +105,7 @@ void mlir::linalg::hoistRedundantVectorBroadcasts(RewriterBase &rewriter, return WalkResult::advance(); // Check that the vector to extract from is a BlockArgument. - auto blockArg = dyn_cast(extractOp.getVector()); + auto blockArg = dyn_cast(extractOp.getSource()); if (!blockArg) return WalkResult::advance(); @@ -141,7 +141,7 @@ void mlir::linalg::hoistRedundantVectorBroadcasts(RewriterBase &rewriter, return WalkResult::advance(); rewriter.modifyOpInPlace(broadcast, [&] { - extractOp.getVectorMutable().assign(initArg->get()); + extractOp.getSourceMutable().assign(initArg->get()); }); loop.moveOutOfLoop(extractOp); rewriter.moveOpAfter(broadcast, loop); diff --git a/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp b/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp index b392ffeb13de6..050bbac2293e9 100644 --- a/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp +++ b/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp @@ -71,7 +71,7 @@ static FailureOr getMaskOp(Operation *loadOp) { if (auto extractOp = transferRead.getMask().getDefiningOp()) if (auto maskOp = - extractOp.getVector().getDefiningOp()) + extractOp.getSource().getDefiningOp()) return TransferMask{maskOp, SmallVector(extractOp.getStaticPosition())}; diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 85e485c28c74e..8d6e263934fb4 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -1309,7 +1309,7 @@ LogicalResult ExtractOp::inferReturnTypes(MLIRContext *, std::optional, ExtractOp::Adaptor adaptor, SmallVectorImpl &inferredReturnTypes) { - auto vectorType = llvm::cast(adaptor.getVector().getType()); + auto vectorType = llvm::cast(adaptor.getSource().getType()); if (static_cast(adaptor.getStaticPosition().size()) == vectorType.getRank()) { inferredReturnTypes.push_back(vectorType.getElementType()); @@ -1379,7 +1379,7 @@ static SmallVector extractVector(ArrayAttr arrayAttr) { /// Fold the result of chains of ExtractOp in place by simply concatenating the /// positions. static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) { - if (!extractOp.getVector().getDefiningOp()) + if (!extractOp.getSource().getDefiningOp()) return failure(); // TODO: Canonicalization for dynamic position not implemented yet. @@ -1390,7 +1390,7 @@ static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) { ExtractOp currentOp = extractOp; ArrayRef extrPos = currentOp.getStaticPosition(); globalPosition.append(extrPos.rbegin(), extrPos.rend()); - while (ExtractOp nextOp = currentOp.getVector().getDefiningOp()) { + while (ExtractOp nextOp = currentOp.getSource().getDefiningOp()) { currentOp = nextOp; // TODO: Canonicalization for dynamic position not implemented yet. if (currentOp.hasDynamicPosition()) @@ -1398,7 +1398,7 @@ static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) { ArrayRef extrPos = currentOp.getStaticPosition(); globalPosition.append(extrPos.rbegin(), extrPos.rend()); } - extractOp.setOperand(0, currentOp.getVector()); + extractOp.setOperand(0, currentOp.getSource()); // OpBuilder is only used as a helper to build an I64ArrayAttr. OpBuilder b(extractOp.getContext()); std::reverse(globalPosition.begin(), globalPosition.end()); @@ -1584,7 +1584,7 @@ Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace( return Value(); // If we can't fold (either internal transposition, or nothing to fold), bail. - bool nothingToFold = (source == extractOp.getVector()); + bool nothingToFold = (source == extractOp.getSource()); if (nothingToFold || !canFold()) return Value(); @@ -1592,7 +1592,7 @@ Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace( OpBuilder b(extractOp.getContext()); extractOp.setStaticPosition( ArrayRef(extractPosition).take_front(extractedRank)); - extractOp.getVectorMutable().assign(source); + extractOp.getSourceMutable().assign(source); return extractOp.getResult(); } @@ -1602,7 +1602,7 @@ Value ExtractFromInsertTransposeChainState::fold() { if (extractOp.hasDynamicPosition()) return Value(); - Value valueToExtractFrom = extractOp.getVector(); + Value valueToExtractFrom = extractOp.getSource(); updateStateForNextIteration(valueToExtractFrom); while (nextInsertOp || nextTransposeOp) { // Case 1. If we hit a transpose, just compose the map and iterate. @@ -1693,7 +1693,7 @@ static bool isBroadcastLike(Operation *op) { /// `extract` shape. static Value foldExtractFromBroadcast(ExtractOp extractOp) { - Operation *defOp = extractOp.getVector().getDefiningOp(); + Operation *defOp = extractOp.getSource().getDefiningOp(); if (!defOp || !isBroadcastLike(defOp)) return Value(); @@ -1762,7 +1762,7 @@ static Value foldExtractFromShuffle(ExtractOp extractOp) { if (extractOp.hasDynamicPosition()) return Value(); - auto shuffleOp = extractOp.getVector().getDefiningOp(); + auto shuffleOp = extractOp.getSource().getDefiningOp(); if (!shuffleOp) return Value(); @@ -1793,7 +1793,7 @@ static Value foldExtractFromShapeCast(ExtractOp extractOp) { if (extractOp.hasDynamicPosition()) return Value(); - auto shapeCastOp = extractOp.getVector().getDefiningOp(); + auto shapeCastOp = extractOp.getSource().getDefiningOp(); if (!shapeCastOp) return Value(); @@ -1859,7 +1859,7 @@ static Value foldExtractFromExtractStrided(ExtractOp extractOp) { return Value(); auto extractStridedSliceOp = - extractOp.getVector().getDefiningOp(); + extractOp.getSource().getDefiningOp(); if (!extractStridedSliceOp) return Value(); @@ -1896,7 +1896,7 @@ static Value foldExtractFromExtractStrided(ExtractOp extractOp) { assert(extractedPos.size() >= sliceOffsets.size()); for (size_t i = 0, e = sliceOffsets.size(); i < e; i++) extractedPos[i] = extractedPos[i] + sliceOffsets[i]; - extractOp.getVectorMutable().assign(extractStridedSliceOp.getVector()); + extractOp.getSourceMutable().assign(extractStridedSliceOp.getSource()); // OpBuilder is only used as a helper to build an I64ArrayAttr. OpBuilder b(extractOp.getContext()); @@ -1914,7 +1914,7 @@ static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) { llvm::isa(extractOp.getType()) ? llvm::cast(extractOp.getType()).getRank() : 0; - auto insertOp = extractOp.getVector().getDefiningOp(); + auto insertOp = extractOp.getSource().getDefiningOp(); if (!insertOp) return Value(); @@ -1966,7 +1966,7 @@ static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) { insertRankDiff)) return Value(); } - extractOp.getVectorMutable().assign(insertOp.getValueToStore()); + extractOp.getSourceMutable().assign(insertOp.getValueToStore()); // OpBuilder is only used as a helper to build an I64ArrayAttr. OpBuilder b(extractOp.getContext()); extractOp.setStaticPosition(offsetDiffs); @@ -1991,7 +1991,7 @@ static Value foldScalarExtractFromFromElements(ExtractOp extractOp) { return {}; // Look for extract(from_elements). - auto fromElementsOp = extractOp.getVector().getDefiningOp(); + auto fromElementsOp = extractOp.getSource().getDefiningOp(); if (!fromElementsOp) return {}; @@ -2142,20 +2142,20 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) { // Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v. // Note: Do not fold "vector.extract %v[] : f32 from vector" (type // mismatch). - if (getNumIndices() == 0 && getVector().getType() == getResult().getType()) - return getVector(); - if (auto res = foldPoisonSrcExtractOp(adaptor.getVector())) + if (getNumIndices() == 0 && getSource().getType() == getResult().getType()) + return getSource(); + if (auto res = foldPoisonSrcExtractOp(adaptor.getSource())) return res; // Fold `arith.constant` indices into the `vector.extract` operation. // Do not stop here as this fold may enable subsequent folds that require // constant indices. - SmallVector operands = {getVector()}; + SmallVector operands = {getSource()}; auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands); if (auto res = foldPoisonIndexInsertExtractOp( getContext(), adaptor.getStaticPosition(), kPoisonIndex)) return res; - if (auto res = foldDenseElementsAttrSrcExtractOp(*this, adaptor.getVector())) + if (auto res = foldDenseElementsAttrSrcExtractOp(*this, adaptor.getSource())) return res; if (succeeded(foldExtractOpFromExtractChain(*this))) return getResult(); @@ -2187,7 +2187,7 @@ class ExtractOpFromBroadcast final : public OpRewritePattern { LogicalResult matchAndRewrite(ExtractOp extractOp, PatternRewriter &rewriter) const override { - Operation *defOp = extractOp.getVector().getDefiningOp(); + Operation *defOp = extractOp.getSource().getDefiningOp(); VectorType outType = dyn_cast(extractOp.getType()); if (!defOp || !isBroadcastLike(defOp) || !outType) return failure(); @@ -2210,7 +2210,7 @@ class ExtractOpFromCreateMask final : public OpRewritePattern { LogicalResult matchAndRewrite(ExtractOp extractOp, PatternRewriter &rewriter) const override { auto createMaskOp = - extractOp.getVector().getDefiningOp(); + extractOp.getSource().getDefiningOp(); if (!createMaskOp) return failure(); @@ -2271,7 +2271,7 @@ class ExtractOpFromCreateMask final : public OpRewritePattern { // does not change. LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp, PatternRewriter &rewriter) { - auto castOp = extractOp.getVector().getDefiningOp(); + auto castOp = extractOp.getSource().getDefiningOp(); if (!castOp) return failure(); @@ -2306,7 +2306,7 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp, return failure(); // Look for extracts from a from_elements op. - auto fromElementsOp = extractOp.getVector().getDefiningOp(); + auto fromElementsOp = extractOp.getSource().getDefiningOp(); if (!fromElementsOp) return failure(); VectorType inputType = fromElementsOp.getType(); @@ -2558,8 +2558,8 @@ class FromElementsToShapeCast : public OpRewritePattern { // Check condition (i) by checking that all elements have the same source // as the first element. if (insertIndex == 0) { - source = extractOp.getVector(); - } else if (extractOp.getVector() != source) { + source = extractOp.getSource(); + } else if (extractOp.getSource() != source) { return rewriter.notifyMatchFailure(fromElements, "element from different vector"); } @@ -4095,7 +4095,7 @@ foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) { ArrayAttr extractOffsets = op.getOffsets(); ArrayAttr extractStrides = op.getStrides(); ArrayAttr extractSizes = op.getSizes(); - auto insertOp = op.getVector().getDefiningOp(); + auto insertOp = op.getSource().getDefiningOp(); while (insertOp) { if (op.getSourceVectorType().getRank() != insertOp.getSourceVectorType().getRank()) @@ -4199,17 +4199,17 @@ foldExtractStridedSliceNonSplatConstant(ExtractStridedSliceOp op, OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) { if (getSourceVectorType() == getResult().getType()) - return getVector(); + return getSource(); if (succeeded(foldExtractStridedOpFromInsertChain(*this))) return getResult(); // ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp. if (auto splat = - llvm::dyn_cast_if_present(adaptor.getVector())) + llvm::dyn_cast_if_present(adaptor.getSource())) DenseElementsAttr::get(getType(), splat.getSplatValue()); // ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp. - return foldExtractStridedSliceNonSplatConstant(*this, adaptor.getVector()); + return foldExtractStridedSliceNonSplatConstant(*this, adaptor.getSource()); } void ExtractStridedSliceOp::getOffsets(SmallVectorImpl &results) { @@ -4241,7 +4241,7 @@ class StridedSliceCreateMaskFolder final // Return if 'extractStridedSliceOp' operand is not defined by a // CreateMaskOp. auto createMaskOp = - extractStridedSliceOp.getVector().getDefiningOp(); + extractStridedSliceOp.getSource().getDefiningOp(); if (!createMaskOp) return failure(); // Return if 'extractStridedSliceOp' has non-unit strides. @@ -4298,7 +4298,7 @@ class StridedSliceConstantMaskFolder final PatternRewriter &rewriter) const override { // Return if 'extractStridedSliceOp' operand is not defined by a // ConstantMaskOp. - auto *defOp = extractStridedSliceOp.getVector().getDefiningOp(); + auto *defOp = extractStridedSliceOp.getSource().getDefiningOp(); auto constantMaskOp = dyn_cast_or_null(defOp); if (!constantMaskOp) return failure(); @@ -4351,7 +4351,7 @@ class StridedSliceBroadcast final LogicalResult matchAndRewrite(ExtractStridedSliceOp op, PatternRewriter &rewriter) const override { - auto broadcast = op.getVector().getDefiningOp(); + auto broadcast = op.getSource().getDefiningOp(); if (!broadcast) return failure(); auto srcVecType = @@ -4403,7 +4403,7 @@ class StridedSliceSplat final : public OpRewritePattern { LogicalResult matchAndRewrite(ExtractStridedSliceOp op, PatternRewriter &rewriter) const override { - Value splat = getScalarSplatSource(op.getVector()); + Value splat = getScalarSplatSource(op.getSource()); if (!splat) return failure(); rewriter.replaceOpWithNewOp(op, op.getType(), splat); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp index 995a2595e5fbb..e95338f7d18be 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -1341,7 +1341,7 @@ struct WarpOpExtractStridedSlice : public WarpDistributionPattern { VectorType::get(newDistributedShape, distributedType.getElementType()); SmallVector newRetIndices; WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( - rewriter, warpOp, {extractOp.getVector()}, {newDistributedType}, + rewriter, warpOp, {extractOp.getSource()}, {newDistributedType}, newRetIndices); rewriter.setInsertionPointAfter(newWarpOp); SmallVector distributedSizes = llvm::map_to_vector( @@ -1395,7 +1395,7 @@ struct WarpOpExtract : public WarpDistributionPattern { // the 1d case). SmallVector newRetIndices; WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( - rewriter, warpOp, {extractOp.getVector()}, + rewriter, warpOp, {extractOp.getSource()}, {extractOp.getSourceVectorType()}, newRetIndices); rewriter.setInsertionPointAfter(newWarpOp); Value distributedVec = newWarpOp->getResult(newRetIndices[0]); @@ -1424,7 +1424,7 @@ struct WarpOpExtract : public WarpDistributionPattern { VectorType::get(newDistributedShape, distributedType.getElementType()); SmallVector newRetIndices; WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( - rewriter, warpOp, {extractOp.getVector()}, {newDistributedType}, + rewriter, warpOp, {extractOp.getSource()}, {newDistributedType}, newRetIndices); rewriter.setInsertionPointAfter(newWarpOp); Value distributedVec = newWarpOp->getResult(newRetIndices[0]); @@ -1478,7 +1478,7 @@ struct WarpOpExtractScalar : public WarpDistributionPattern { distributedVecType = extractSrcType; } // Yield source vector and position (if present) from warp op. - SmallVector additionalResults{extractOp.getVector()}; + SmallVector additionalResults{extractOp.getSource()}; SmallVector additionalResultTypes{distributedVecType}; additionalResults.append( SmallVector(extractOp.getDynamicPosition())); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp index 9889d7f221fe6..cab12894487e2 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp @@ -78,7 +78,7 @@ struct CastAwayExtractStridedSliceLeadingOneDim Location loc = extractOp.getLoc(); Value newSrcVector = vector::ExtractOp::create( - rewriter, loc, extractOp.getVector(), splatZero(dropCount)); + rewriter, loc, extractOp.getSource(), splatZero(dropCount)); // The offsets/sizes/strides attribute can have a less number of elements // than the input vector's rank: it is meant for the leading dimensions. diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp index f78e579d6c099..852eff2d2b909 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp @@ -94,7 +94,7 @@ static FailureOr getCompressedMaskOp(OpBuilder &rewriter, !isa( maskOp)) { if (auto extractOp = dyn_cast(maskOp)) { - maskOp = extractOp.getVector().getDefiningOp(); + maskOp = extractOp.getSource().getDefiningOp(); extractOps.push_back(extractOp); } } diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp index cbb9d4bbf0b1f..f6d6555f4c6e2 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp @@ -213,8 +213,8 @@ class Convert1DExtractStridedSliceIntoShuffle for (int64_t off = offset, e = offset + size * stride; off < e; off += stride) offsets.push_back(off); - rewriter.replaceOpWithNewOp(op, dstType, op.getVector(), - op.getVector(), offsets); + rewriter.replaceOpWithNewOp(op, dstType, op.getSource(), + op.getSource(), offsets); return success(); } }; @@ -250,7 +250,7 @@ class Convert1DExtractStridedSliceIntoExtractInsertChain final SmallVector elements; elements.reserve(size); for (int64_t i = offset, e = offset + size * stride; i < e; i += stride) - elements.push_back(ExtractOp::create(rewriter, loc, op.getVector(), i)); + elements.push_back(ExtractOp::create(rewriter, loc, op.getSource(), i)); Value result = arith::ConstantOp::create( rewriter, loc, rewriter.getZeroAttr(op.getType())); @@ -306,7 +306,7 @@ class DecomposeNDExtractStridedSlice Value res = BroadcastOp::create(rewriter, loc, dstType, zero); for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e; off += stride, ++idx) { - Value one = ExtractOp::create(rewriter, loc, op.getVector(), off); + Value one = ExtractOp::create(rewriter, loc, op.getSource(), off); Value extracted = ExtractStridedSliceOp::create( rewriter, loc, one, getI64SubArray(op.getOffsets(), /* dropFront=*/1), getI64SubArray(op.getSizes(), /* dropFront=*/1), diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index 12acf4b3f07f5..82bac8c499028 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -252,7 +252,7 @@ struct LinearizeVectorExtractStridedSlice final SmallVector indices = getStridedSliceInsertionIndices( outputShape, inputShape, offsets.value()); - Value srcVector = adaptor.getVector(); + Value srcVector = adaptor.getSource(); rewriter.replaceOpWithNewOp( extractStridedSliceOp, flatOutputType, srcVector, srcVector, indices); return success(); @@ -438,8 +438,8 @@ struct LinearizeVectorExtract final return rewriter.notifyMatchFailure(extractOp, "dynamic position is not supported."); - llvm::ArrayRef shape = extractOp.getVector().getType().getShape(); - int64_t size = extractOp.getVector().getType().getNumElements(); + llvm::ArrayRef shape = extractOp.getSource().getType().getShape(); + int64_t size = extractOp.getSource().getType().getNumElements(); // Compute linearized offset. int64_t linearizedOffset = 0; @@ -449,7 +449,7 @@ struct LinearizeVectorExtract final linearizedOffset += offsets[i] * size; } - Value srcVector = adaptor.getVector(); + Value srcVector = adaptor.getSource(); if (!isa(extractOp.getType())) { // Scalar case: generate a 1-D extract. Value result = rewriter.createOrFold( diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp index 1874a3477a16c..c364a8b54167c 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp @@ -1007,7 +1007,7 @@ class RewriteScalarExtractOfTransferRead LogicalResult matchAndRewrite(vector::ExtractOp extractOp, PatternRewriter &rewriter) const override { // Match phase. - auto xferOp = extractOp.getVector().getDefiningOp(); + auto xferOp = extractOp.getSource().getDefiningOp(); if (!xferOp) return failure(); // Check that we are extracting a scalar and not a sub-vector. diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index dbb5eb398fae6..866f789ec6a39 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -576,7 +576,7 @@ struct BubbleDownVectorBitCastForExtract if (extractOp.getSourceVectorType().getRank() != 1) return failure(); - auto castOp = extractOp.getVector().getDefiningOp(); + auto castOp = extractOp.getSource().getDefiningOp(); if (!castOp) return failure(); @@ -647,7 +647,7 @@ struct BubbleDownBitCastForStridedSliceExtract LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp, PatternRewriter &rewriter) const override { - auto castOp = extractOp.getVector().getDefiningOp(); + auto castOp = extractOp.getSource().getDefiningOp(); if (!castOp) return failure(); @@ -1135,7 +1135,7 @@ class ExtractOpFromElementwise final LogicalResult matchAndRewrite(vector::ExtractOp op, PatternRewriter &rewriter) const override { - Operation *eltwise = op.getVector().getDefiningOp(); + Operation *eltwise = op.getSource().getDefiningOp(); // TODO: vector::FMAOp is not an ElemetwiseMappable even if it claims to be, // as it doesn't support scalars. @@ -1210,7 +1210,7 @@ class ExtractOpFromLoad final : public OpRewritePattern { LogicalResult matchAndRewrite(vector::ExtractOp op, PatternRewriter &rewriter) const override { - auto loadOp = op.getVector().getDefiningOp(); + auto loadOp = op.getSource().getDefiningOp(); if (!loadOp) return rewriter.notifyMatchFailure(op, "expected a load op");