
Conversation

@banach-space banach-space commented Oct 16, 2025

Simplify `createReadOrMaskedRead` to require only _one_ argument
specifying the vector type to read (passed as a `VectorType`), instead
of two independent arguments (vector sizes and scalable flags).

@banach-space banach-space changed the base branch from main to users/banach-space/linalg/vectorize_pack October 16, 2025 10:26

llvmbot commented Oct 16, 2025

@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Andrzej Warzyński (banach-space)

Changes
  • [mlir][linalg] Update vectorization of linalg.pack
  • [mlir][vector] Simplify createReadOrMaskedRead

Full diff: https://github.com/llvm/llvm-project/pull/163736.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h (+4-6)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+14-12)
  • (modified) mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp (+18-18)
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index a57aadcdcc5b0..2e6fab30e5120 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -219,18 +219,16 @@ bool isLinearizableVector(VectorType type);
 
 /// Creates a TransferReadOp from `source`.
 ///
-/// The shape of the vector to read is specified via `inputVectorSizes`. If the
-/// shape of the output vector differs from the shape of the value being read,
-/// masking is used to avoid out-of-bounds accesses. Set
+/// If the shape of vector to read differs from the shape of the value being
+/// read, masking is used to avoid out-of-bounds accesses. Set
 /// `useInBoundsInsteadOfMasking` to `true` to use the "in_bounds" attribute
 /// instead of explicit masks.
 ///
 /// Note: all read offsets are set to 0.
 Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source,
-                             ArrayRef<int64_t> inputVectorSizes,
+                             VectorType &vecToReadTy,
                              std::optional<Value> padValue = std::nullopt,
-                             bool useInBoundsInsteadOfMasking = false,
-                             ArrayRef<bool> inputScalableVecDims = {});
+                             bool useInBoundsInsteadOfMasking = false);
 
 /// Returns success if `inputVectorSizes` is a valid masking configuraion for
 /// given `shape`, i.e., it meets:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index c15e6330fd8c3..32246e3a11433 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1886,9 +1886,8 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
 
   // Create masked TransferReadOp.
   auto maskedRead = vector::createReadOrMaskedRead(
-      rewriter, loc, packOp.getSource(), readVecType.getShape(), padValue,
-      useInBoundsInsteadOfMasking,
-      /*inputScalableVecSizes=*/{});
+      rewriter, loc, packOp.getSource(), readVecType, padValue,
+      useInBoundsInsteadOfMasking);
 
   // Create ShapeCastOp.
   auto expandedVecType =
@@ -1975,9 +1974,12 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
   }
 
   // -- Generate the read operation --
+  VectorType readVecType =
+      VectorType::get(readVectorSizes, unpackTensorType.getElementType(),
+                      readScalableVectorFlags);
   Value readResult = vector::createReadOrMaskedRead(
-      rewriter, loc, unpackOp.getSource(), readVectorSizes, std::nullopt,
-      useInBoundsInsteadOfMasking, readScalableVectorFlags);
+      rewriter, loc, unpackOp.getSource(), readVecType, std::nullopt,
+      useInBoundsInsteadOfMasking);
 
   // -- Generate the transpose operation --
   PackingMetadata packMetadata;
@@ -2023,9 +2025,10 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
           .reifyResultShapes(rewriter, reifiedReturnShapes);
   (void)status; // prevent unused variable warning on non-assert builds
   assert(succeeded(status) && "failed to reify result shapes");
+  auto readType = VectorType::get(inputVectorSizes, padValue.getType());
   auto maskedRead = vector::createReadOrMaskedRead(
-      rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
-      /*useInBoundsInsteadOfMasking=*/false, /*inputScalableVecSizes=*/{});
+      rewriter, loc, padOp.getSource(), readType, padValue,
+      /*useInBoundsInsteadOfMasking=*/false);
 
   // Create Xfer write Op
   Value dest = tensor::EmptyOp::create(rewriter, loc, reifiedReturnShapes[0],
@@ -2220,9 +2223,9 @@ vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state,
         state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
 
     Value read = mlir::vector::createReadOrMaskedRead(
-        rewriter, loc, opOperand.get(), readType.getShape(),
+        rewriter, loc, opOperand.get(), readType,
         /*padding=*/arith::getZeroConstant(rewriter, loc, elemType),
-        /*useInBoundsInsteadOfMasking=*/false, readType.getScalableDims());
+        /*useInBoundsInsteadOfMasking=*/false);
     vecOperands.push_back(read);
   }
 
@@ -3163,9 +3166,8 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
   SmallVector<Value> readIndices(
       vecType.getRank(), arith::ConstantIndexOp::create(rewriter, loc, 0));
   Value read = mlir::vector::createReadOrMaskedRead(
-      rewriter, loc, source, vecType.getShape(), padValue,
-      /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty(),
-      /*inputScalableVecSizes=*/{});
+      rewriter, loc, source, vecType, padValue,
+      /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty());
 
   // Create write
   auto writeIndices =
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 025ee9a04a1de..d73ccd1c41b66 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -317,51 +317,51 @@ bool vector::isLinearizableVector(VectorType type) {
 }
 
 Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
-                                     Value source,
-                                     ArrayRef<int64_t> inputVectorSizes,
+                                     Value source, VectorType &vecToReadTy,
                                      std::optional<Value> padValue,
-                                     bool useInBoundsInsteadOfMasking,
-                                     ArrayRef<bool> inputScalableVecDims) {
-  assert(!llvm::is_contained(inputVectorSizes, ShapedType::kDynamic) &&
+                                     bool useInBoundsInsteadOfMasking) {
+  assert(!llvm::is_contained(vecToReadTy.getShape(),
+                             ShapedType::kDynamic) &&
+         "invalid input vector sizes");
   auto sourceShapedType = cast<ShapedType>(source.getType());
   auto sourceShape = sourceShapedType.getShape();
-  assert(sourceShape.size() == inputVectorSizes.size() &&
+
+  int64_t vecToReadRank = vecToReadTy.getRank();
+  auto vecToReadShape = vecToReadTy.getShape();
+
+  assert(sourceShape.size() == static_cast<size_t>(vecToReadRank) &&
          "expected same ranks.");
-  auto vectorType =
-      VectorType::get(inputVectorSizes, sourceShapedType.getElementType(),
-                      inputScalableVecDims);
   assert((!padValue.has_value() ||
           padValue.value().getType() == sourceShapedType.getElementType()) &&
          "expected same pad element type to match source element type");
-  int64_t readRank = inputVectorSizes.size();
+
   auto zero = arith::ConstantIndexOp::create(builder, loc, 0);
-  SmallVector<bool> inBoundsVal(readRank, true);
+  SmallVector<bool> inBoundsVal(vecToReadRank, true);
 
   if (useInBoundsInsteadOfMasking) {
     // Update the inBounds attribute.
     // FIXME: This computation is too weak - it ignores the read indices.
-    for (unsigned i = 0; i < readRank; i++)
-      inBoundsVal[i] = (sourceShape[i] == inputVectorSizes[i]) &&
+    for (unsigned i = 0; i < vecToReadRank; i++)
+      inBoundsVal[i] = (sourceShape[i] == vecToReadShape[i]) &&
                        ShapedType::isStatic(sourceShape[i]);
   }
   auto transferReadOp = vector::TransferReadOp::create(
       builder, loc,
-      /*vectorType=*/vectorType,
+      /*vectorType=*/vecToReadTy,
       /*source=*/source,
-      /*indices=*/SmallVector<Value>(readRank, zero),
+      /*indices=*/SmallVector<Value>(vecToReadRank, zero),
       /*padding=*/padValue,
       /*inBounds=*/inBoundsVal);
 
-  if (llvm::equal(inputVectorSizes, sourceShape) || useInBoundsInsteadOfMasking)
+  if (llvm::equal(vecToReadTy.getShape(), sourceShape) ||
+      useInBoundsInsteadOfMasking)
     return transferReadOp;
   SmallVector<OpFoldResult> mixedSourceDims =
       isa<MemRefType>(source.getType())
           ? memref::getMixedSizes(builder, loc, source)
           : tensor::getMixedSizes(builder, loc, source);
 
-  auto maskType = VectorType::get(inputVectorSizes, builder.getI1Type(),
-                                  inputScalableVecDims);
+  auto maskType = vecToReadTy.clone(builder.getI1Type());
   Value mask =
       vector::CreateMaskOp::create(builder, loc, maskType, mixedSourceDims);
   return mlir::vector::maskOperation(builder, transferReadOp, mask)

@banach-space banach-space changed the title users/banach space/vector/simplify create masked read [mlir][vector] Simplify createReadOrMaskedRead Oct 16, 2025
@banach-space banach-space force-pushed the users/banach-space/linalg/vectorize_pack branch from a26b730 to bea41f5 Compare October 16, 2025 10:28

github-actions bot commented Oct 16, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@banach-space banach-space force-pushed the users/banach-space/vector/simplify_create_masked_read branch from 0ad9ece to 98a3fac Compare October 16, 2025 10:28

@rengolin rengolin left a comment


Cleaner code is always good. 😃

This looks good to me, but it's a change of API, so best to leave a bit longer to make sure users have time to react.

@rengolin rengolin requested a review from adam-smnk October 16, 2025 12:06
@rengolin

@arun-thmn FYI


@adam-smnk adam-smnk left a comment


Overall looks fine 👍


@dcaballe dcaballe left a comment


LGTM! Add [NFC] before landing?
