
[MLIR][Vector] Add DropUnitDimFromBroadcastOp pattern #92938

Open · wants to merge 3 commits into base: main
Conversation

nujaa
Contributor

@nujaa nujaa commented May 21, 2024

This MR is part of a series of MRs aiming to generalize DropUnitDimFromElementwiseOps to other ops.
This commit implements DropUnitDimFromBroadcastOp to target vector::BroadcastOp.
This change stems from improving the lowering of contraction ops for Arm SME, where we end up with inner unit dimensions on MulOp, BroadcastOp and TransposeOp, preventing the generation of outer products.
Discussed here.

@llvmbot
Collaborator

llvmbot commented May 21, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Hugo Trachino (nujaa)

Changes

This MR is part of a series of MRs aiming to generalize DropUnitDimFromElementwiseOps to other ops.
This commit implements DropUnitDimFromBroadcastOp to target vector::BroadcastOp.

Discussed here.


Full diff: https://github.com/llvm/llvm-project/pull/92938.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+62-2)
  • (modified) mlir/test/Dialect/Vector/vector-transfer-flatten.mlir (+55)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index f29eba90c3ceb..a8494eac3e5aa 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1695,6 +1695,66 @@ struct DropUnitDimFromElementwiseOps final
   }
 };
 
+/// Drops non-scalable unit dimensions that are shared between the source and
+/// result of a vector.broadcast, using shape_casts.
+/// The newly inserted shape_cast ops fold (before the op) and then restore
+/// the unit dims (after the op). The source type is required to be a vector.
+///
+/// Ex:
+/// ```
+///  %bc = vector.broadcast %arg0 : vector<1x4xf32> to vector<1x3x1x4xf32>
+///  %cast = vector.shape_cast %bc : vector<1x3x1x4xf32> to vector<1x3x4xf32>
+/// ```
+///
+/// Gets converted to:
+///
+/// ```
+///  %sc_arg = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
+///  %bc = vector.broadcast %arg : vector<4xf32> to vector<1x3x4xf32>
+///  %cast_new = vector.shape_cast %bc : vector<1x3x4xf32> to
+///    vector<1x3x1x4xf32>
+///  %cast = vector.shape_cast %cast_new : vector<1x3x1x4xf32> to
+///    vector<1x3x4xf32>
+/// ```
+/// %cast_new and %cast can be folded away.
+struct DropUnitDimFromBroadcastOp final
+    : public OpRewritePattern<vector::BroadcastOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
+                                PatternRewriter &rewriter) const override {
+    auto srcVT = dyn_cast<VectorType>(broadcastOp.getSourceType());
+    if (!srcVT)
+      return failure();
+    auto resVT = broadcastOp.getResultVectorType();
+    VectorType newSrcVT = srcVT;
+    VectorType newResVT = resVT;
+    auto broadcastedUnitDims = broadcastOp.computeBroadcastedUnitDims();
+    // Reversing allows us to remove dims from the back without keeping track of
+    // removed dimensions.
+    for (const auto &dim : llvm::enumerate(llvm::reverse(srcVT.getShape()))) {
+      if (dim.value() == 1 &&
+          !srcVT.getScalableDims()[srcVT.getRank() - dim.index() - 1] &&
+          !broadcastedUnitDims.contains(srcVT.getRank() - dim.index() - 1)) {
+        newSrcVT = VectorType::Builder(newSrcVT).dropDim(srcVT.getRank() -
+                                                         dim.index() - 1);
+        newResVT = VectorType::Builder(newResVT).dropDim(resVT.getRank() -
+                                                         dim.index() - 1);
+      }
+    }
+
+    if (newSrcVT == srcVT)
+      return failure();
+    auto loc = broadcastOp->getLoc();
+    auto newSource = rewriter.create<vector::ShapeCastOp>(
+        loc, newSrcVT, broadcastOp.getSource());
+    auto newOp = rewriter.create<vector::BroadcastOp>(loc, newResVT, newSource);
+    rewriter.replaceOpWithNewOp<ShapeCastOp>(broadcastOp, resVT,
+                                             newOp.getResult());
+    return success();
+  }
+};
+
 /// Pattern to eliminate redundant zero-constants added to reduction operands.
 /// It's enough for there to be one initial zero value, so we can eliminate the
 /// extra ones that feed into `vector.reduction <add>`. These get created by the
@@ -1819,8 +1879,8 @@ void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
 
 void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<DropUnitDimFromElementwiseOps, ShapeCastOpFolder>(
-      patterns.getContext(), benefit);
+  patterns.add<DropUnitDimFromElementwiseOps, DropUnitDimFromBroadcastOp,
+               ShapeCastOpFolder>(patterns.getContext(), benefit);
 }
 
 void mlir::vector::populateBubbleVectorBitCastOpPatterns(
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 788ae9ac044ed..f1fc443b9d4bd 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -460,6 +460,61 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
 //   CHECK-128B-NOT:   memref.collapse_shape
 
 
+// -----
+
+func.func @drop_broadcast_unit_dim(%arg0 : vector<1x[1]x3x1xf128>) -> vector<4x1x[1]x3x1xf128> {
+  %bc = vector.broadcast %arg0 : vector<1x[1]x3x1xf128> to vector<4x1x[1]x3x1xf128>
+  return %bc : vector<4x1x[1]x3x1xf128>
+}
+
+// CHECK-LABEL:   func.func @drop_broadcast_unit_dim(
+// CHECK-SAME:      %[[VAL_0:.*]]: vector<1x[1]x3x1xf128>{{.*}}-> vector<4x1x[1]x3x1xf128> {
+// CHECK:           %[[VAL_1:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x[1]x3x1xf128> to vector<[1]x3xf128>
+// CHECK:           %[[VAL_2:.*]] = vector.broadcast %[[VAL_1]] : vector<[1]x3xf128> to vector<4x[1]x3xf128>
+// CHECK:           %[[VAL_3:.*]] = vector.shape_cast %[[VAL_2]] : vector<4x[1]x3xf128> to vector<4x1x[1]x3x1xf128>
+// CHECK:           return %[[VAL_3]] : vector<4x1x[1]x3x1xf128>
+
+// -----
+
+func.func @drop_broadcasted_only_unit_dim(%arg0 : vector<1xf32>) -> vector<1x1xf32> {
+  %bc = vector.broadcast %arg0 : vector<1xf32> to vector<1x1xf32>
+  return %bc : vector<1x1xf32>
+}
+
+// CHECK-LABEL:   func.func @drop_broadcasted_only_unit_dim(
+// CHECK-SAME:      %[[VAL_0:.*]]: vector<1xf32>) -> vector<1x1xf32> {
+// CHECK:           %[[VAL_1:.*]] = vector.shape_cast %[[VAL_0]] : vector<1xf32> to vector<f32>
+// CHECK:           %[[VAL_2:.*]] = vector.broadcast %[[VAL_1]] : vector<f32> to vector<1xf32>
+// CHECK:           %[[VAL_3:.*]] = vector.shape_cast %[[VAL_2]] :  vector<1xf32> to vector<1x1xf32>
+// CHECK:           return %[[VAL_3]] : vector<1x1xf32>
+
+// -----
+
+// Unit dimensions generated by the broadcast itself are not dropped, as we
+// prefer a single broadcast over a broadcast plus a shape_cast.
+func.func @drop_broadcast_generated_unit_dim(%arg0 : vector<4xf32>) -> vector<3x1x4xf32> {
+  %bc = vector.broadcast %arg0 : vector<4xf32> to vector<3x1x4xf32>
+  return %bc : vector<3x1x4xf32>
+}
+
+// CHECK-LABEL:   func.func @drop_broadcast_generated_unit_dim(
+// CHECK-SAME:      %[[VAL_0:.*]]: vector<4xf32>{{.*}}-> vector<3x1x4xf32> {
+// CHECK:           %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<4xf32> to vector<3x1x4xf32>
+// CHECK:           return %[[VAL_1]] : vector<3x1x4xf32>
+
+// -----
+
+// A broadcasted unit dimension cannot be dropped, as doing so would cause a type mismatch.
+func.func @drop_broadcasted_unit_dim(%arg0 : vector<2x1x4xf32>) -> vector<2x3x4xf32> {
+  %bc = vector.broadcast %arg0 : vector<2x1x4xf32> to vector<2x3x4xf32>
+  return %bc : vector<2x3x4xf32>
+}
+// CHECK-LABEL:   func.func @drop_broadcasted_unit_dim(
+// CHECK-SAME:      %[[VAL_0:.*]]: vector<2x1x4xf32>{{.*}}-> vector<2x3x4xf32> {
+// CHECK:           %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<2x1x4xf32> to vector<2x3x4xf32>
+// CHECK:           return %[[VAL_1]] : vector<2x3x4xf32>
+
+
 // -----
 
 func.func @regression_non_contiguous_dim_read(%subview : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>,

@nujaa
Contributor Author

nujaa commented May 24, 2024

CC @banach-space @MacDue .

@banach-space
Contributor

Hi Hugo, thanks for sending this!

This MR is part of a list of MRs aiming to generalize DropUnitDimFromElementwiseOps for other ops.

I think that it would be good to take a step back and discuss whether this is actually needed. IIUC, you are trying to solve the problem outlined here:

As I pointed out in my reply in that thread:

Solutions: Should we match broadcast(+transpose)+mulf+addf to generate a contractionOp, or is it the developer’s responsibility to use cse / canonicalize with care?

IMHO, no.

Also, as @dcaballe hinted:

As mentioned multiple times in the past, the vector.multi_reduction -> vector.contract step should be removed in favor of direct linalg.matmul -> vector.contract direct lowering.

Before moving ahead with this - what's your long-term goal? If we manage to get rid of vector.multi_reduction from

  • the linalg.matmul -> vector.outer_product lowering path,

then, IIUC, this change won't be needed. Given that getting rid of vector.multi_reduction is the long-term design goal (so it's bound to happen at some point), I'm wondering whether we shouldn't re-focus on that instead?

@nujaa
Contributor Author

nujaa commented May 28, 2024

FYI, I also have one last pattern ready, which needs to go through some internal process before upstreaming. It matches a series of elementwise ops to generate a vector.outerproduct, not a contraction. Would that still be lifting?

IIUC, this change won't be needed. Given that getting rid of vector.multi_reduction is the long-term design goal (so it's bound to happen at some point), I'm wondering whether we shouldn't re-focus on that instead?

It does make sense. I can likely give it a look over the week. Happy to talk implementation details on a dedicated thread or something.

@nujaa nujaa force-pushed the hugo.DropUnitDimFromBroadcastOp branch from 69e5e4c to cebfd74 Compare June 20, 2024 13:06
auto broadcastedUnitDims = broadcastOp.computeBroadcastedUnitDims();
// Reversing allows us to remove dims from the back without keeping track of
// removed dimensions.
for (const auto &dim :
Member

Suggested change
for (const auto &dim :
for (const auto [index, dim] :

Comment on lines 1745 to 1750
if (dim.value() == 1 &&
!srcVecTy.getScalableDims()[srcVecTy.getRank() - dim.index() - 1] &&
!broadcastedUnitDims.contains(srcVecTy.getRank() - dim.index() - 1)) {
srcVecTyBuilder.dropDim(srcVecTy.getRank() - dim.index() - 1);
resVecTyBuilder.dropDim(resVecTy.getRank() - dim.index() - 1);
}
Member

This would be easier to read with some names:

Suggested change
if (dim.value() == 1 &&
!srcVecTy.getScalableDims()[srcVecTy.getRank() - dim.index() - 1] &&
!broadcastedUnitDims.contains(srcVecTy.getRank() - dim.index() - 1)) {
srcVecTyBuilder.dropDim(srcVecTy.getRank() - dim.index() - 1);
resVecTyBuilder.dropDim(resVecTy.getRank() - dim.index() - 1);
}
auto sourceDimIndex = srcVecTy.getRank() - index - 1;
auto resultDimIndex = resVecTy.getRank() - index - 1;
if (dim == 1 && !srcVecTy.getScalableDims()[sourceDimIndex] &&
!broadcastedUnitDims.contains(sourceDimIndex)) {
srcVecTyBuilder.dropDim(sourceDimIndex);
resVecTyBuilder.dropDim(resultDimIndex);
}

Comment on lines +1738 to +1739
auto srcVecTyBuilder = VectorType::Builder(srcVecTy);
auto resVecTyBuilder = VectorType::Builder(resVecTy);
Member

Builders may be less efficient than just appending the dims not dropped to a new vector (but this is probably not much of a concern given the number of dims is normally < 5-ish).

Member

(no need to change this -- just a note)

Contributor Author

I agree. I tried implementing it by appending the shapes to keep, but considering one also has to rebuild the scalableDims for both the new source and result type, I found it was generating lots of code that dropDim keeps hidden.
Also, for some reason, I was able to generate the base of the new resultShape by creating a subvector of it with:

SmallVector<int64_t> newResShape =
        llvm::to_vector(resVecTy.getShape().drop_back(srcVecTy.getRank()));

but for the scalable dims I get errors like the one below, and I don't think I should be changing the behaviour of SmallVector. I suspect it comes from the way scalable dims are defined.

llvm/include/llvm/ADT/SmallVector.h:1317:11: error: type 'decltype(__cont.begin())' (aka 'const bool *') cannot be narrowed to 'bool' in initializer list [-Wc++11-narrowing]

To fix it, I needed to build an ugly vector with explicit casts like:

    SmallVector<bool> newResScalableDims = {
        static_cast<bool>(resVecTy.getScalableDims().begin()),
        static_cast<bool>(resVecTy.getScalableDims().drop_back(srcVecTy.getRank()).end())};

If you want, I can push my solution on top and we revert it if we prefer it as it currently is.

Member

Those static_casts look very suspect 😅. It looks like that's just going to make a SmallVector of two (likely true) bools. I think keeping it as-is is fine as it's simpler, and likely not really a performance concern (vector types are normally small).

Member
@MacDue MacDue Jun 20, 2024

Btw, I've been trying to make writing code like this easier for scalable dims (for a while now 😅). With my current attempt #96236, I think you'd be able to rewrite this as (untested!!!):

auto srcDims = VectorDimList::from(srcVecTy);
auto resDims = VectorDimList::from(resVecTy);
auto rankDiff = resDims.size() - srcDims.size();

SmallVector<VectorDim> newSrcDims;
SmallVector<VectorDim> newResDims(resDims.takeFront(rankDiff));

auto broadcastedUnitDims = broadcastOp.computeBroadcastedUnitDims();
for (auto [idx, dim] : llvm::enumerate(srcDims)) {
  if (dim != VectorDim::getFixed(1) || broadcastedUnitDims.contains(idx)) {
    newSrcDims.push_back(dim);
    newResDims.push_back(resDims[idx + rankDiff]);
  }
}

auto newSourceType = ScalableVectorType::get(newSrcDims, srcVecTy.getElementType());
auto newResultType = ScalableVectorType::get(newResDims, srcVecTy.getElementType());

Please take a look at the PR if you think it'd be useful :)

Contributor Author

That indeed looks way simpler; a very timely PR. I probably won't be able to review it at the moment, as I need to sort some things out today, but I hope to give it a look. (I'll be off for a while, do not wait for my review.)

Member
@MacDue MacDue left a comment

Generally LGTM, just a few more nits:

Comment on lines +1754 to +1758
if (VectorType(srcVecTyBuilder) == srcVecTy)
return failure();
auto loc = broadcastOp->getLoc();
auto newSource = rewriter.create<vector::ShapeCastOp>(
loc, VectorType(srcVecTyBuilder), broadcastOp.getSource());
Member

nit: avoid constructing the new vector type twice:

Suggested change
if (VectorType(srcVecTyBuilder) == srcVecTy)
return failure();
auto loc = broadcastOp->getLoc();
auto newSource = rewriter.create<vector::ShapeCastOp>(
loc, VectorType(srcVecTyBuilder), broadcastOp.getSource());
auto newSrcVecTy = VectorType(srcVecTyBuilder);
if (newSrcVecTy == srcVecTy)
return failure();
auto loc = broadcastOp->getLoc();
auto newSource = rewriter.create<vector::ShapeCastOp>(
loc, newSrcVecTy, broadcastOp.getSource());

}

// CHECK-LABEL: func.func @drop_broadcast_unit_dim(
// CHECK-SAME: %[[VAL_0:.*]]: vector<1x[1]x3x1xf128>{{.*}}-> vector<4x1x[1]x3x1xf128> {
Member

nit: Use better names than the generated VAL_* :)

@dcaballe
Contributor

I'm a bit confused by this PR. Isn't a vector.broadcast one of the "boundary" operations that we use to restore the original shape of a vector we dropped a unit dimension from?
