-
Notifications
You must be signed in to change notification settings - Fork 14.9k
[mlir][vector] Simplify createReadOrMaskedRead #163736
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: users/banach-space/linalg/vectorize_pack
Are you sure you want to change the base?
[mlir][vector] Simplify createReadOrMaskedRead #163736
Conversation
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir-vector Author: Andrzej Warzyński (banach-space) Changes
Full diff: https://github.com/llvm/llvm-project/pull/163736.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index a57aadcdcc5b0..2e6fab30e5120 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -219,18 +219,16 @@ bool isLinearizableVector(VectorType type);
/// Creates a TransferReadOp from `source`.
///
-/// The shape of the vector to read is specified via `inputVectorSizes`. If the
-/// shape of the output vector differs from the shape of the value being read,
-/// masking is used to avoid out-of-bounds accesses. Set
+/// If the shape of vector to read differs from the shape of the value being
+/// read, masking is used to avoid out-of-bounds accesses. Set
/// `useInBoundsInsteadOfMasking` to `true` to use the "in_bounds" attribute
/// instead of explicit masks.
///
/// Note: all read offsets are set to 0.
Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source,
- ArrayRef<int64_t> inputVectorSizes,
+ VectorType &vecToReadTy,
std::optional<Value> padValue = std::nullopt,
- bool useInBoundsInsteadOfMasking = false,
- ArrayRef<bool> inputScalableVecDims = {});
+ bool useInBoundsInsteadOfMasking = false);
/// Returns success if `inputVectorSizes` is a valid masking configuraion for
/// given `shape`, i.e., it meets:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index c15e6330fd8c3..32246e3a11433 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1886,9 +1886,8 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
// Create masked TransferReadOp.
auto maskedRead = vector::createReadOrMaskedRead(
- rewriter, loc, packOp.getSource(), readVecType.getShape(), padValue,
- useInBoundsInsteadOfMasking,
- /*inputScalableVecSizes=*/{});
+ rewriter, loc, packOp.getSource(), readVecType, padValue,
+ useInBoundsInsteadOfMasking);
// Create ShapeCastOp.
auto expandedVecType =
@@ -1975,9 +1974,12 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
}
// -- Generate the read operation --
+ VectorType readVecType =
+ VectorType::get(readVectorSizes, unpackTensorType.getElementType(),
+ readScalableVectorFlags);
Value readResult = vector::createReadOrMaskedRead(
- rewriter, loc, unpackOp.getSource(), readVectorSizes, std::nullopt,
- useInBoundsInsteadOfMasking, readScalableVectorFlags);
+ rewriter, loc, unpackOp.getSource(), readVecType, std::nullopt,
+ useInBoundsInsteadOfMasking);
// -- Generate the transpose operation --
PackingMetadata packMetadata;
@@ -2023,9 +2025,10 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
.reifyResultShapes(rewriter, reifiedReturnShapes);
(void)status; // prevent unused variable warning on non-assert builds
assert(succeeded(status) && "failed to reify result shapes");
+ auto readType = VectorType::get(inputVectorSizes, padValue.getType());
auto maskedRead = vector::createReadOrMaskedRead(
- rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
- /*useInBoundsInsteadOfMasking=*/false, /*inputScalableVecSizes=*/{});
+ rewriter, loc, padOp.getSource(), readType, padValue,
+ /*useInBoundsInsteadOfMasking=*/false);
// Create Xfer write Op
Value dest = tensor::EmptyOp::create(rewriter, loc, reifiedReturnShapes[0],
@@ -2220,9 +2223,9 @@ vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state,
state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
Value read = mlir::vector::createReadOrMaskedRead(
- rewriter, loc, opOperand.get(), readType.getShape(),
+ rewriter, loc, opOperand.get(), readType,
/*padding=*/arith::getZeroConstant(rewriter, loc, elemType),
- /*useInBoundsInsteadOfMasking=*/false, readType.getScalableDims());
+ /*useInBoundsInsteadOfMasking=*/false);
vecOperands.push_back(read);
}
@@ -3163,9 +3166,8 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
SmallVector<Value> readIndices(
vecType.getRank(), arith::ConstantIndexOp::create(rewriter, loc, 0));
Value read = mlir::vector::createReadOrMaskedRead(
- rewriter, loc, source, vecType.getShape(), padValue,
- /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty(),
- /*inputScalableVecSizes=*/{});
+ rewriter, loc, source, vecType, padValue,
+ /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty());
// Create write
auto writeIndices =
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 025ee9a04a1de..d73ccd1c41b66 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -317,51 +317,51 @@ bool vector::isLinearizableVector(VectorType type) {
}
Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
- Value source,
- ArrayRef<int64_t> inputVectorSizes,
+ Value source, VectorType &vecToReadTy,
std::optional<Value> padValue,
- bool useInBoundsInsteadOfMasking,
- ArrayRef<bool> inputScalableVecDims) {
- assert(!llvm::is_contained(inputVectorSizes, ShapedType::kDynamic) &&
+ bool useInBoundsInsteadOfMasking) {
+ assert(!llvm::is_contained(vecToReadTy.getScalableDims(),
+ ShapedType::kDynamic) &&
"invalid input vector sizes");
auto sourceShapedType = cast<ShapedType>(source.getType());
auto sourceShape = sourceShapedType.getShape();
- assert(sourceShape.size() == inputVectorSizes.size() &&
+
+ int64_t vecToReadRank = vecToReadTy.getRank();
+ auto vecToReadShape = vecToReadTy.getShape();
+
+ assert(sourceShape.size() == static_cast<size_t>(vecToReadRank) &&
"expected same ranks.");
- auto vectorType =
- VectorType::get(inputVectorSizes, sourceShapedType.getElementType(),
- inputScalableVecDims);
assert((!padValue.has_value() ||
padValue.value().getType() == sourceShapedType.getElementType()) &&
"expected same pad element type to match source element type");
- int64_t readRank = inputVectorSizes.size();
+
auto zero = arith::ConstantIndexOp::create(builder, loc, 0);
- SmallVector<bool> inBoundsVal(readRank, true);
+ SmallVector<bool> inBoundsVal(vecToReadRank, true);
if (useInBoundsInsteadOfMasking) {
// Update the inBounds attribute.
// FIXME: This computation is too weak - it ignores the read indices.
- for (unsigned i = 0; i < readRank; i++)
- inBoundsVal[i] = (sourceShape[i] == inputVectorSizes[i]) &&
+ for (unsigned i = 0; i < vecToReadRank; i++)
+ inBoundsVal[i] = (sourceShape[i] == vecToReadShape[i]) &&
ShapedType::isStatic(sourceShape[i]);
}
auto transferReadOp = vector::TransferReadOp::create(
builder, loc,
- /*vectorType=*/vectorType,
+ /*vectorType=*/vecToReadTy,
/*source=*/source,
- /*indices=*/SmallVector<Value>(readRank, zero),
+ /*indices=*/SmallVector<Value>(vecToReadRank, zero),
/*padding=*/padValue,
/*inBounds=*/inBoundsVal);
- if (llvm::equal(inputVectorSizes, sourceShape) || useInBoundsInsteadOfMasking)
+ if (llvm::equal(vecToReadTy.getShape(), sourceShape) ||
+ useInBoundsInsteadOfMasking)
return transferReadOp;
SmallVector<OpFoldResult> mixedSourceDims =
isa<MemRefType>(source.getType())
? memref::getMixedSizes(builder, loc, source)
: tensor::getMixedSizes(builder, loc, source);
- auto maskType = VectorType::get(inputVectorSizes, builder.getI1Type(),
- inputScalableVecDims);
+ auto maskType = vecToReadTy.clone(builder.getI1Type());
Value mask =
vector::CreateMaskOp::create(builder, loc, maskType, mixedSourceDims);
return mlir::vector::maskOperation(builder, transferReadOp, mask)
|
a26b730
to
bea41f5
Compare
✅ With the latest revision this PR passed the C/C++ code formatter. |
Simplify `createReadOrMaskedRead` to only require _one_ argument to specify the vector type to read (passed as `VectorType`) instead of passing vector-sizes and scalable-flags independently (i.e. _two_ arguments).
0ad9ece
to
98a3fac
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Cleaner code is always good. 😃
This looks good to me, but it's a change of API, so best to leave a bit longer to make sure users have time to react.
@arun-thmn FYI |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall looks fine 👍
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! Add [NFC] before landing?
Simplify
createReadOrMaskedRead
to only require one argument tospecify the vector type to read (passed as
VectorType
) instead of passing vector-sizes and scalable-flags independently (i.e. two
arguments).