
Make createReadOrMaskedRead and isValidMaskedInputVector vector utilities #89119

Merged
merged 8 commits into from
Apr 23, 2024

Conversation

LLITCHEV
Contributor

@LLITCHEV LLITCHEV commented Apr 17, 2024

Made the createReadOrMaskedRead and isValidMaskedInputVector utility functions accessible outside of their compilation unit. Needed by the new IREE TopK implementation.

Made the createReadOrMaskedRead a utility function - to be accessible
outside of the CU. Needed by the IREE new TopK implementation.

Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be
notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this is not working for you, it is probably because you do not have write
permissions for the repository. In which case you can instead tag reviewers by
name in a comment by using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review
by "pinging" the PR, i.e. adding a comment saying "Ping". The common courtesy ping rate
is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.

@llvmbot
Collaborator

llvmbot commented Apr 17, 2024

@llvm/pr-subscribers-mlir-vector
@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Lubomir Litchev (LLITCHEV)

Changes

Made createReadOrMaskedRead a utility function so that it is accessible outside of the compilation unit. Needed by the new IREE TopK implementation.


Full diff: https://github.com/llvm/llvm-project/pull/89119.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+6)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp (+29)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (-40)
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index feb3b3f03cf538..f4c56b671e9d7e 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1616,6 +1616,12 @@ void populateSplitReductionPattern(
     const ControlSplitReductionFn &controlSplitReductionFn,
     bool useAlloc = false);
 
+/// Create a TransferReadOp from `source` with static shape `readShape`. If the
+/// vector type for the read is not the same as the type of `source`, then a
+/// mask is created on the read.
+Value createReadOrMaskedRead(OpBuilder &builder, Location loc,
+                                    Value source, ArrayRef<int64_t> readShape,
+                                    Value padValue);
 } // namespace linalg
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index a17bc8e4cd318f..b32ebfc380fcfb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1593,3 +1593,32 @@ void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
       DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp, PoolingNcwMaxOp>>(
       patterns.getContext(), benefit);
 }
+
+Value mlir::linalg::createReadOrMaskedRead(OpBuilder &builder, Location loc,
+                                    Value source, ArrayRef<int64_t> readShape,
+                                    Value padValue) {
+  assert(llvm::none_of(readShape,
+                       [](int64_t s) { return s == ShapedType::kDynamic; }));
+  auto sourceShape = dyn_cast<ShapedType>(source.getType()).getShape();
+  assert(sourceShape.size() == readShape.size());
+  auto maskType = VectorType::get(readShape, builder.getI1Type());
+  auto vectorType = VectorType::get(readShape, padValue.getType());
+  int64_t readRank = readShape.size();
+  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+  auto transferReadOp = builder.create<vector::TransferReadOp>(
+      loc,
+      /*vectorType=*/vectorType,
+      /*source=*/source,
+      /*indices=*/SmallVector<Value>(readRank, zero),
+      /*padding=*/padValue,
+      /*inBounds=*/SmallVector<bool>(readRank, true));
+  if (llvm::equal(readShape, sourceShape)) {
+    return transferReadOp;
+  }
+  SmallVector<OpFoldResult> mixedSourceDims =
+      tensor::getMixedSizes(builder, loc, source);
+  Value mask =
+      builder.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
+  return mlir::vector::maskOperation(builder, transferReadOp, mask)
+      ->getResult(0);
+}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index df61381432921b..e2ca5e14377286 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1410,46 +1410,6 @@ static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp,
   return applyPermutation(destShape, tensor::getPackInverseDestPerm(packOp));
 }
 
-/// Create a TransferReadOp from `source` with static shape `readShape`. If the
-/// vector type for the read is not the same as the type of `source`, then a
-/// mask is created on the read.  If `doMasking` parameter is set to false we
-/// update the `inBounds` attribute instead of masking.
-static Value createReadOrMaskedRead(OpBuilder &builder, Location loc,
-                                    Value source, ArrayRef<int64_t> readShape,
-                                    Value padValue, bool doMasking = true) {
-  assert(llvm::none_of(readShape,
-                       [](int64_t s) { return s == ShapedType::kDynamic; }));
-  auto sourceShape = dyn_cast<ShapedType>(source.getType()).getShape();
-  assert(sourceShape.size() == readShape.size());
-  auto maskType = VectorType::get(readShape, builder.getI1Type());
-  auto vectorType = VectorType::get(readShape, padValue.getType());
-  int64_t readRank = readShape.size();
-  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
-  SmallVector<bool> inBoundsVal(readRank, true);
-  if (!doMasking) {
-    // Update the inBounds attribute.
-    for (unsigned i = 0; i < readRank; i++)
-      inBoundsVal[i] = sourceShape[i] == readShape[i];
-  }
-  auto transferReadOp = builder.create<vector::TransferReadOp>(
-      loc,
-      /*vectorType=*/vectorType,
-      /*source=*/source,
-      /*indices=*/SmallVector<Value>(readRank, zero),
-      /*padding=*/padValue,
-      /*inBounds=*/inBoundsVal);
-
-  if (llvm::equal(readShape, sourceShape) || !doMasking) {
-    return transferReadOp;
-  }
-  SmallVector<OpFoldResult> mixedSourceDims =
-      tensor::getMixedSizes(builder, loc, source);
-  Value mask =
-      builder.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
-  return mlir::vector::maskOperation(builder, transferReadOp, mask)
-      ->getResult(0);
-}
-
 /// Given an input, the mixed destSizes, and the vector sizes for vectorization,
 /// create an empty destination tensor and create a TransferWriteOp from the
 /// input to the empty tensor. If the destination shape is not the same as the

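The core decision the moved `createReadOrMaskedRead` makes can be sanity-checked with a small standalone C++ sketch (no MLIR dependencies; `kDynamic` is a stand-in sentinel and `needsMask` is a hypothetical helper name, not the real API):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Stand-in for ShapedType::kDynamic (the actual value is an MLIR detail).
constexpr int64_t kDynamic = INT64_MIN;

// Models the decision in createReadOrMaskedRead: the read shape must be
// static and of the same rank as the source; a mask is created exactly
// when the read shape differs from the source shape.
bool needsMask(const std::vector<int64_t> &sourceShape,
               const std::vector<int64_t> &readShape) {
  for (int64_t s : readShape)
    assert(s != kDynamic && "expected static shape");
  assert(sourceShape.size() == readShape.size() && "expected same rank");
  return sourceShape != readShape;
}
```

This mirrors the asserts at the top of the function and its `llvm::equal(readShape, sourceShape)` early exit, which returns the plain `vector.transfer_read` without wrapping it in a `vector.mask`.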

github-actions bot commented Apr 17, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

Member

@matthias-springer matthias-springer left a comment


This is just moving code, but I think it's worth cleaning it up a bit.

ArrayRef<int64_t> readShape,
Value padValue, bool doMasking) {
assert(llvm::none_of(readShape,
[](int64_t s) { return s == ShapedType::kDynamic; }));
Member

nit: add && "expected static shape" for better error messages

Contributor Author

Fixed.

Value padValue, bool doMasking) {
assert(llvm::none_of(readShape,
[](int64_t s) { return s == ShapedType::kDynamic; }));
auto sourceShape = dyn_cast<ShapedType>(source.getType()).getShape();
Member

Turn into cast<...> because there is no check that the dyn_cast succeeded.

Contributor Author

Fixed.

if (!doMasking) {
// Update the inBounds attribute.
for (unsigned i = 0; i < readRank; i++)
inBoundsVal[i] = sourceShape[i] == readShape[i];
Member

Can we change this so that the respective in_bounds is set to "false" if readShape > sourceShape or sourceShape is dynamic?

Contributor Author

Changed.

@@ -1616,6 +1616,12 @@ void populateSplitReductionPattern(
const ControlSplitReductionFn &controlSplitReductionFn,
bool useAlloc = false);

/// Create a TransferReadOp from `source` with static shape `readShape`. If the
Member

Can you update the comment and describe doMasking. (Maybe I'd also rename it to just masking.)

Contributor

Could you also include a comment on indices? These are hard-coded to be 0 - that's worth documenting.

(Maybe I'd also rename it to just masking.)

This is a bit of bikesheding and I don't mind that much ... But I do like when a bool variable includes a verb so that it's effectively a question with a binary answer (i.e. "yes"/"no") :) Personally I'd use enableMasking instead. Again, don't really mind 😅

Contributor Author

Changed to enableMasking.

assert(llvm::none_of(readShape,
[](int64_t s) { return s == ShapedType::kDynamic; }));
auto sourceShape = dyn_cast<ShapedType>(source.getType()).getShape();
assert(sourceShape.size() == readShape.size());
Member

Add && "expected same rank"

Contributor Author

Added.

auto sourceShape = dyn_cast<ShapedType>(source.getType()).getShape();
assert(sourceShape.size() == readShape.size());
auto maskType = VectorType::get(readShape, builder.getI1Type());
auto vectorType = VectorType::get(readShape, padValue.getType());
Member

I'd also put an assert(padValue.getType() == shapedType.getElementType() && "expected same pad element type to match source element type").

Contributor Author

Added.

/*padding=*/padValue,
/*inBounds=*/inBoundsVal);

if (llvm::equal(readShape, sourceShape) || !doMasking) {
Member

nit: Trivial braces not needed

Contributor Author

Fixed.

Contributor

@banach-space banach-space left a comment


Shouldn't createReadOrMaskedRead be a Vector dialect Util?

@LLITCHEV LLITCHEV requested a review from jpienaar April 18, 2024 18:02
Contributor

@hanhanW hanhanW left a comment


I think you also need to refactor isValidMaskedInputVector for your topk implementation?

@matthias-springer
Member

Shouldn't createReadOrMaskedRead be a Vector dialect Util?

Oh right, this shouldn't be in Linalg.

Moved code around and other CR.
@LLITCHEV LLITCHEV changed the title from "Make createReadOrMaskedRead a utility" to "Make createReadOrMaskedRead and isValidMaskedInputVector vector utilities" Apr 19, 2024
@LLITCHEV
Contributor Author

This is just moving code, but I think it's worth cleaning it up a bit.

Fixed.

@LLITCHEV
Contributor Author

Shouldn't createReadOrMaskedRead be a Vector dialect Util?

Oh right, this shouldn't be in Linalg.

Fixed.

@LLITCHEV
Contributor Author

I think you also need to refactor isValidMaskedInputVector for your topk implementation?

I thought to do two separate PRs, but there is no reason why not to do it in one. Thanks!

@LLITCHEV
Contributor Author

Shouldn't createReadOrMaskedRead be a Vector dialect Util?

Oh right, this shouldn't be in Linalg.

Moved to VectorUtils.

Comment on lines 344 to 345
inBoundsVal[i] = (sourceShape[i] == readShape[i]) &&
!ShapedType::isDynamic(sourceShape[] i);
Contributor

Why not:

Suggested change
inBoundsVal[i] = (sourceShape[i] == readShape[i]) &&
!ShapedType::isDynamic(sourceShape[] i);
inBoundsVal[i] = (sourceShape[i] == readShape[i]) &&
!ShapedType::isDynamic(sourceShape[i]);

?

Contributor Author

@LLITCHEV LLITCHEV Apr 19, 2024


inBoundsVal[i] = (sourceShape[i] == readShape[i]) &&
!ShapedType::isDynamic(sourceShape[i]);

Yes... Fixed... That was a typo. :) Not sure how it got there; I formatted after building, so maybe an accidental undo. The CI caught the issue. :)

Comment on lines 355 to 356
if (llvm::equal(readShape, sourceShape) || !doMasking)
return transferReadOp;
Contributor

Suggested change
if (llvm::equal(readShape, sourceShape) || !doMasking)
return transferReadOp;
if (llvm::equal(readShape, sourceShape) || !enableMasking)
return transferReadOp;

Contributor Author

Same as above. :) Fixed.

Comment on lines 183 to 187
/// Create a TransferReadOp from `source` with static shape `readShape`. If the
/// vector type for the read is not the same as the type of `source`, then a
/// mask is created on the read.
/// enableMasking if false, the inBoundsVal values are set properly, based on
/// the rank dimensions of the source and destination tensors.
Contributor

Please update this comment.

  1. Fix formatting
  2. "then a mask is created on the read." - the mask is never created if enableMasking is false.
  3. Missing note on indices being hard-coded as 0.

Point 2. is rather counter-intuitive to me. @pashu123, you introduced doMasking in #88249. Looking at the implementation and you PR, it feels like you meant something like bool useInBoundsInsteadOfMasking rather than doMasking/enableMasking? As in:

BEFORE #88249:

if (llvm::equal(readShape, sourceShape))
  // use masking

AFTER #88249:

if (llvm::equal(readShape, sourceShape) && !useInBoundsInsteadOfMasking)
  // use masking

Put differently, the extra logic added to createReadOrMaskedRead in #88249 reads to me as:

  • "In some cases I want to disable masking and use in_bounds attr instead", rather than
  • "I want a dedicated toggle to enable/disable masking".

I'm asking to better understand the current logic 😅
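Under that reading, the before/after behaviour condenses into two small predicates (a sketch; the function names are illustrative, only the flag name comes from the discussion):

```cpp
// BEFORE #88249: a mask is created whenever the read and source
// shapes differ.
bool usesMaskBefore(bool shapesEqual) { return !shapesEqual; }

// AFTER #88249: differing shapes still trigger a mask, unless the caller
// opted into encoding bounds via the in_bounds attribute instead.
bool usesMaskAfter(bool shapesEqual, bool useInBoundsInsteadOfMasking) {
  return !shapesEqual && !useInBoundsInsteadOfMasking;
}
```

That is, the flag never enables masking; it only suppresses it, which is why `useInBoundsInsteadOfMasking` reads more accurately than `doMasking`/`enableMasking`.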

Member

I feel the same way. Using useInBoundsInsteadOfMasking provides a better understanding of the entire context. I can modify the variables.

Contributor Author

@banach-space To me this is just a potential optimization (maybe I'm wrong). It looks like the code is setting up for a later pass to turn the transfer_read into a simple load. But if we do masking, or the bounds of a dimension differ, that optimization can't be done. So it looked to me that when there is no masking and the bounds are the same, the code prepares for a later pass to emit a load instead of a transfer, skipping the masking op, or maybe just removes an unnecessary op. Maybe I misunderstood the purpose of it too. That's what I thought. :)
The rest is fixed, I hope. :)
Thanks!

Contributor Author

useInBoundsInsteadOfMasking

I renamed them, since I'm moving the code as utility methods. Thanks!

@pashu123
Member

@LLITCHEV If you could rename the variables here.

/// mask is created on the read, if use of mask is specified or the bounds on a
/// dimension are different.
///
/// `enableMasking` if false, the inBoundsVal values are set properly, based on
Member

We need to update the doc comments w.r.t useInBoundsInsteadOfMasking.

Contributor Author

@LLITCHEV LLITCHEV Apr 19, 2024

Oops... Missed that one. Good catch. Fixed. Thanks!

@LLITCHEV
Contributor Author

@LLITCHEV If you could rename the variables here.

Renamed.

@LLITCHEV LLITCHEV marked this pull request as ready for review April 22, 2024 17:28
Contributor

@banach-space banach-space left a comment


Thanks for addressing my comments 🙏🏻

@LLITCHEV LLITCHEV merged commit 30d4f6a into llvm:main Apr 23, 2024
6 checks passed

@LLITCHEV Congratulations on having your first Pull Request (PR) merged into the LLVM Project!

Your changes will be combined with recent changes from other authors, then tested
by our build bots. If there is a problem with a build, you may receive a report in an email or a comment on this PR.

Please check whether problems have been caused by your change specifically, as
the builds can include changes from many authors. It is not uncommon for your
change to be included in a build that fails due to someone else's changes, or
infrastructure issues.

How to do this, and the rest of the post-merge process, is covered in detail here.

If your change does cause a problem, it may be reverted, or you can revert it yourself.
This is a normal part of LLVM development. You can fix your changes and open a new PR to merge them again.

If you don't get any reports, no action is required from you. Your changes are working as expected, well done!

LLITCHEV added a commit to LLITCHEV/llvm-project that referenced this pull request Apr 25, 2024
The PR llvm#89119 renamed a flag, which inverted its previous meaning, but the
logic was not updated to reflect that.
Fixed the logic to reflect the inversion of the meaning of the name.