Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 20 additions & 7 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -675,7 +675,7 @@ def Vector_ExtractOp :
}];

let arguments = (ins
AnyVectorOfAnyRank:$vector,
AnyVectorOfAnyRank:$source,
Variadic<Index>:$dynamic_position,
DenseI64ArrayAttr:$static_position
);
Expand All @@ -692,7 +692,7 @@ def Vector_ExtractOp :

let extraClassDeclaration = extraPoisonClassDeclaration # [{
VectorType getSourceVectorType() {
return ::llvm::cast<VectorType>(getVector().getType());
return ::llvm::cast<VectorType>(getSource().getType());
}

/// Return a vector with all the static and dynamic position indices.
Expand All @@ -709,12 +709,17 @@ def Vector_ExtractOp :
bool hasDynamicPosition() {
return !getDynamicPosition().empty();
}

/// Wrapper for getSource, which replaced getVector.
[[deprecated("Use getSource instead!")]] ::mlir::Value getVector() {
return getSource();
}
}];

let assemblyFormat = [{
$vector ``
$source ``
custom<DynamicIndexList>($dynamic_position, $static_position)
attr-dict `:` type($result) `from` type($vector)
attr-dict `:` type($result) `from` type($source)
}];

let hasCanonicalizer = 1;
Expand Down Expand Up @@ -1023,6 +1028,10 @@ def Vector_ScalableExtractOp :
VectorType getResultVectorType() {
return ::llvm::cast<VectorType>(getResult().getType());
}
/// Wrapper for getSource, which replaced getVector.
[[deprecated("Use getSource instead!")]] ::mlir::Value getVector() {
return getSource();
}
}];
}

Expand Down Expand Up @@ -1174,7 +1183,7 @@ def Vector_ExtractStridedSliceOp :
Vector_Op<"extract_strided_slice", [Pure,
PredOpTrait<"operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>]>,
Arguments<(ins AnyVectorOfNonZeroRank:$vector, I64ArrayAttr:$offsets,
Arguments<(ins AnyVectorOfNonZeroRank:$source, I64ArrayAttr:$offsets,
I64ArrayAttr:$sizes, I64ArrayAttr:$strides)>,
Results<(outs AnyVectorOfNonZeroRank)> {
let summary = "extract_strided_slice operation";
Expand Down Expand Up @@ -1209,19 +1218,23 @@ def Vector_ExtractStridedSliceOp :
];
let extraClassDeclaration = [{
VectorType getSourceVectorType() {
return ::llvm::cast<VectorType>(getVector().getType());
return ::llvm::cast<VectorType>(getSource().getType());
}
void getOffsets(SmallVectorImpl<int64_t> &results);
bool hasNonUnitStrides() {
return llvm::any_of(getStrides(), [](Attribute attr) {
return ::llvm::cast<IntegerAttr>(attr).getInt() != 1;
});
}
/// Wrapper for getSource, which replaced getVector.
[[deprecated("Use getSource instead!")]] ::mlir::Value getVector() {
return getSource();
}
}];
let hasCanonicalizer = 1;
let hasFolder = 1;
let hasVerifier = 1;
let assemblyFormat = "$vector attr-dict `:` type($vector) `to` type(results)";
let assemblyFormat = "$source attr-dict `:` type($source) `to` type(results)";
}

// TODO: Tighten semantics so that masks and inbounds can't be used
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ struct VectorExtractToArmSMELowering
auto loc = extractOp.getLoc();
auto position = extractOp.getMixedPosition();

Value sourceVector = extractOp.getVector();
Value sourceVector = extractOp.getSource();

// Extract entire vector. Should be handled by folder, but just to be safe.
if (position.empty()) {
Expand Down Expand Up @@ -692,7 +692,7 @@ struct ExtractFromCreateMaskToPselLowering
return rewriter.notifyMatchFailure(extractOp, "result not VectorType");

auto createMaskOp =
extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
extractOp.getSource().getDefiningOp<vector::CreateMaskOp>();
if (!createMaskOp)
return rewriter.notifyMatchFailure(extractOp, "source not CreateMaskOp");

Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -962,7 +962,7 @@ convertExtractStridedSlice(RewriterBase &rewriter,
return rewriter.notifyMatchFailure(op, "no mmaSyncFragmentInfo");

// Find the vector.transer_read whose result vector is being sliced.
auto transferReadOp = op.getVector().getDefiningOp<vector::TransferReadOp>();
auto transferReadOp = op.getSource().getDefiningOp<vector::TransferReadOp>();
if (!transferReadOp)
return rewriter.notifyMatchFailure(op, "no transfer read");

Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1131,7 +1131,7 @@ class VectorExtractOpConversion
positionVec.push_back(rewriter.getZeroAttr(idxType));
}

Value extracted = adaptor.getVector();
Value extracted = adaptor.getSource();
if (extractsAggregate) {
ArrayRef<OpFoldResult> position(positionVec);
if (extractsScalar) {
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1414,7 +1414,7 @@ struct UnrollTransferWriteConversion
/// Return the vector from which newly generated ExtracOps will extract.
Value getDataVector(TransferWriteOp xferOp) const {
if (auto extractOp = getExtractOp(xferOp))
return extractOp.getVector();
return extractOp.getSource();
return xferOp.getVector();
}

Expand Down
8 changes: 4 additions & 4 deletions mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,8 @@ struct VectorExtractOpConvert final
if (!dstType)
return failure();

if (isa<spirv::ScalarType>(adaptor.getVector().getType())) {
rewriter.replaceOp(extractOp, adaptor.getVector());
if (isa<spirv::ScalarType>(adaptor.getSource().getType())) {
rewriter.replaceOp(extractOp, adaptor.getSource());
return success();
}

Expand All @@ -201,15 +201,15 @@ struct VectorExtractOpConvert final
extractOp,
"Static use of poison index handled elsewhere (folded to poison)");
rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
extractOp, dstType, adaptor.getVector(),
extractOp, dstType, adaptor.getSource(),
rewriter.getI32ArrayAttr(id.value()));
} else {
Value sanitizedIndex = sanitizeDynamicIndex(
rewriter, extractOp.getLoc(), adaptor.getDynamicPosition()[0],
vector::ExtractOp::kPoisonIndex,
extractOp.getSourceVectorType().getNumElements());
rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
extractOp, dstType, adaptor.getVector(), sanitizedIndex);
extractOp, dstType, adaptor.getSource(), sanitizedIndex);
}
return success();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ struct SwapVectorExtractOfArithExtend
return rewriter.notifyMatchFailure(
extractOp, "extracted type is not a 1-D scalable vector type");

auto *extendOp = extractOp.getVector().getDefiningOp();
auto *extendOp = extractOp.getSource().getDefiningOp();
if (!isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(
extendOp))
return rewriter.notifyMatchFailure(extractOp,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,7 @@ struct FoldExtractFromVectorOfSMELikeCreateMasks
PatternRewriter &rewriter) const override {
auto loc = extractOp.getLoc();
auto createMaskOp =
extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
extractOp.getSource().getDefiningOp<vector::CreateMaskOp>();
if (!createMaskOp)
return rewriter.notifyMatchFailure(
extractOp, "extract not from vector.create_mask op");
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ void mlir::linalg::hoistRedundantVectorBroadcasts(RewriterBase &rewriter,
return WalkResult::advance();

// Check that the vector to extract from is a BlockArgument.
auto blockArg = dyn_cast<BlockArgument>(extractOp.getVector());
auto blockArg = dyn_cast<BlockArgument>(extractOp.getSource());
if (!blockArg)
return WalkResult::advance();

Expand Down Expand Up @@ -141,7 +141,7 @@ void mlir::linalg::hoistRedundantVectorBroadcasts(RewriterBase &rewriter,
return WalkResult::advance();

rewriter.modifyOpInPlace(broadcast, [&] {
extractOp.getVectorMutable().assign(initArg->get());
extractOp.getSourceMutable().assign(initArg->get());
});
loop.moveOutOfLoop(extractOp);
rewriter.moveOpAfter(broadcast, loop);
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ static FailureOr<TransferMask> getMaskOp(Operation *loadOp) {
if (auto extractOp =
transferRead.getMask().getDefiningOp<vector::ExtractOp>())
if (auto maskOp =
extractOp.getVector().getDefiningOp<vector::CreateMaskOp>())
extractOp.getSource().getDefiningOp<vector::CreateMaskOp>())
return TransferMask{maskOp,
SmallVector<int64_t>(extractOp.getStaticPosition())};

Expand Down
Loading