diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index a57aadcdcc5b0..491e7510113f7 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -219,18 +219,16 @@ bool isLinearizableVector(VectorType type);
 
 /// Creates a TransferReadOp from `source`.
 ///
-/// The shape of the vector to read is specified via `inputVectorSizes`. If the
-/// shape of the output vector differs from the shape of the value being read,
-/// masking is used to avoid out-of-bounds accesses. Set
+/// If the shape of the vector to read differs from the shape of the value
+/// being read, masking is used to avoid out-of-bounds accesses. Set
 /// `useInBoundsInsteadOfMasking` to `true` to use the "in_bounds" attribute
 /// instead of explicit masks.
 ///
 /// Note: all read offsets are set to 0.
 Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source,
-                             ArrayRef<int64_t> inputVectorSizes,
+                             const VectorType &vecToReadTy,
                              std::optional<Value> padValue = std::nullopt,
-                             bool useInBoundsInsteadOfMasking = false,
-                             ArrayRef<bool> inputScalableVecDims = {});
+                             bool useInBoundsInsteadOfMasking = false);
 
 /// Returns success if `inputVectorSizes` is a valid masking configuraion for
 /// given `shape`, i.e., it meets:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 591bae5d7a157..b3168d5baaf47 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1887,9 +1887,8 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
 
   // Create masked TransferReadOp.
   auto maskedRead = vector::createReadOrMaskedRead(
-      rewriter, loc, packOp.getSource(), readVecType.getShape(), padValue,
-      useInBoundsInsteadOfMasking,
-      /*inputScalableVecSizes=*/{});
+      rewriter, loc, packOp.getSource(), readVecType, padValue,
+      useInBoundsInsteadOfMasking);
 
   // Create ShapeCastOp.
   auto expandedVecType = VectorType::get(writeVecSizesUnpermuted,
@@ -1976,9 +1975,12 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
   }
 
   // -- Generate the read operation --
+  VectorType readVecType =
+      VectorType::get(readVectorSizes, unpackTensorType.getElementType(),
+                      readScalableVectorFlags);
   Value readResult = vector::createReadOrMaskedRead(
-      rewriter, loc, unpackOp.getSource(), readVectorSizes, std::nullopt,
-      useInBoundsInsteadOfMasking, readScalableVectorFlags);
+      rewriter, loc, unpackOp.getSource(), readVecType, std::nullopt,
+      useInBoundsInsteadOfMasking);
 
   // -- Generate the transpose operation --
   PackingMetadata packMetadata;
@@ -2024,9 +2026,10 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
           .reifyResultShapes(rewriter, reifiedReturnShapes);
   (void)status; // prevent unused variable warning on non-assert builds
   assert(succeeded(status) && "failed to reify result shapes");
+  auto readType = VectorType::get(inputVectorSizes, padValue.getType());
   auto maskedRead = vector::createReadOrMaskedRead(
-      rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
-      /*useInBoundsInsteadOfMasking=*/false, /*inputScalableVecSizes=*/{});
+      rewriter, loc, padOp.getSource(), readType, padValue,
+      /*useInBoundsInsteadOfMasking=*/false);
 
   // Create Xfer write Op
   Value dest = tensor::EmptyOp::create(rewriter, loc, reifiedReturnShapes[0],
@@ -2221,9 +2224,9 @@ vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state,
         state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
 
     Value read = mlir::vector::createReadOrMaskedRead(
-        rewriter, loc, opOperand.get(), readType.getShape(),
+        rewriter, loc, opOperand.get(), readType,
        /*padding=*/arith::getZeroConstant(rewriter, loc, elemType),
-        /*useInBoundsInsteadOfMasking=*/false, readType.getScalableDims());
+        /*useInBoundsInsteadOfMasking=*/false);
     vecOperands.push_back(read);
   }
 
@@ -3164,9 +3167,8 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
   SmallVector<Value> readIndices(
       vecType.getRank(), arith::ConstantIndexOp::create(rewriter, loc, 0));
   Value read = mlir::vector::createReadOrMaskedRead(
-      rewriter, loc, source, vecType.getShape(), padValue,
-      /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty(),
-      /*inputScalableVecSizes=*/{});
+      rewriter, loc, source, vecType, padValue,
+      /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty());
 
   // Create write
   auto writeIndices =
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 025ee9a04a1de..23178c9abb78e 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -318,50 +318,51 @@ bool vector::isLinearizableVector(VectorType type) {
 
 Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
                                      Value source,
-                                     ArrayRef<int64_t> inputVectorSizes,
+                                     const VectorType &vecToReadTy,
                                      std::optional<Value> padValue,
-                                     bool useInBoundsInsteadOfMasking,
-                                     ArrayRef<bool> inputScalableVecDims) {
-  assert(!llvm::is_contained(inputVectorSizes, ShapedType::kDynamic) &&
+                                     bool useInBoundsInsteadOfMasking) {
+  assert(!llvm::is_contained(vecToReadTy.getShape(),
+                             ShapedType::kDynamic) &&
          "invalid input vector sizes");
   auto sourceShapedType = cast<ShapedType>(source.getType());
   auto sourceShape = sourceShapedType.getShape();
-  assert(sourceShape.size() == inputVectorSizes.size() &&
+
+  int64_t vecToReadRank = vecToReadTy.getRank();
+  auto vecToReadShape = vecToReadTy.getShape();
+
+  assert(sourceShape.size() == static_cast<size_t>(vecToReadRank) &&
          "expected same ranks.");
-  auto vectorType =
-      VectorType::get(inputVectorSizes, sourceShapedType.getElementType(),
-                      inputScalableVecDims);
   assert((!padValue.has_value() ||
           padValue.value().getType() == sourceShapedType.getElementType()) &&
          "expected same pad element type to match source element type");
-  int64_t readRank = inputVectorSizes.size();
+
   auto zero = arith::ConstantIndexOp::create(builder, loc, 0);
-  SmallVector<bool> inBoundsVal(readRank, true);
+  SmallVector<bool> inBoundsVal(vecToReadRank, true);
   if (useInBoundsInsteadOfMasking) {
     // Update the inBounds attribute.
     // FIXME: This computation is too weak - it ignores the read indices.
-    for (unsigned i = 0; i < readRank; i++)
-      inBoundsVal[i] = (sourceShape[i] == inputVectorSizes[i]) &&
+    for (unsigned i = 0; i < vecToReadRank; i++)
+      inBoundsVal[i] = (sourceShape[i] == vecToReadShape[i]) &&
                        ShapedType::isStatic(sourceShape[i]);
   }
   auto transferReadOp = vector::TransferReadOp::create(
       builder, loc,
-      /*vectorType=*/vectorType,
+      /*vectorType=*/vecToReadTy,
       /*source=*/source,
-      /*indices=*/SmallVector<Value>(readRank, zero),
+      /*indices=*/SmallVector<Value>(vecToReadRank, zero),
       /*padding=*/padValue,
       /*inBounds=*/inBoundsVal);
 
-  if (llvm::equal(inputVectorSizes, sourceShape) || useInBoundsInsteadOfMasking)
+  if (llvm::equal(vecToReadTy.getShape(), sourceShape) ||
+      useInBoundsInsteadOfMasking)
     return transferReadOp;
 
   SmallVector<OpFoldResult> mixedSourceDims =
       isa<MemRefType>(source.getType())
          ? memref::getMixedSizes(builder, loc, source)
          : tensor::getMixedSizes(builder, loc, source);
 
-  auto maskType = VectorType::get(inputVectorSizes, builder.getI1Type(),
-                                  inputScalableVecDims);
+  auto maskType = vecToReadTy.cloneWith(/*shape=*/{}, builder.getI1Type());
   Value mask =
       vector::CreateMaskOp::create(builder, loc, maskType, mixedSourceDims);
   return mlir::vector::maskOperation(builder, transferReadOp, mask)
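
For callers, the change folds the separate `inputVectorSizes` and
`inputScalableVecDims` arguments into the one `VectorType` they always
described. A minimal before/after sketch of a call site (illustrative only:
`builder`, `loc` and an f32-element `source` are assumed to be in scope, and
the sizes and scalable flags below are made up):

  // Before: the vector shape and the scalable-dim flags travelled separately,
  // and the helper rebuilt the VectorType internally from the source's
  // element type.
  Value readOld = vector::createReadOrMaskedRead(
      builder, loc, source, /*inputVectorSizes=*/{8, 16},
      /*padValue=*/std::nullopt,
      /*useInBoundsInsteadOfMasking=*/false,
      /*inputScalableVecDims=*/{false, true});

  // After: the same information arrives as a single VectorType, so the helper
  // can no longer disagree with the caller about element type or scalability.
  VectorType vecToReadTy = VectorType::get(
      {8, 16}, builder.getF32Type(), /*scalableDims=*/{false, true});
  Value readNew = vector::createReadOrMaskedRead(
      builder, loc, source, vecToReadTy,
      /*padValue=*/std::nullopt,
      /*useInBoundsInsteadOfMasking=*/false);

On the masked path, the mask type is now derived via
`vecToReadTy.cloneWith(/*shape=*/{}, builder.getI1Type())`, which keeps the
read vector's shape and scalable dims and only swaps the element type for i1,
so masked scalable reads need no extra argument.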