-
Notifications
You must be signed in to change notification settings - Fork 11.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR][Vector] Add DropUnitDimFromBroadcastOp pattern #92938
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector Author: Hugo Trachino (nujaa) Changes: This MR is part of a list of MRs aiming to generalize DropUnitDimFromElementwiseOps for other ops. This commit implements DropUnitDimFromBroadcastOp to target vector::BroadcastOp. Discussed here. Full diff: https://github.com/llvm/llvm-project/pull/92938.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index f29eba90c3ceb..a8494eac3e5aa 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1695,6 +1695,66 @@ struct DropUnitDimFromElementwiseOps final
}
};
+/// Drops unit non scalable dimensions inside a broadcastOp which are shared
+/// among source and result with shape_casts.
+/// The newly inserted shape_cast Ops fold (before Op) and then
+/// restore the unit dim after Op. Source type is required to be a vector.
+///
+/// Ex:
+/// ```
+/// %bc = vector.broadcast %arg0 : vector<1x4xf32> to vector<1x3x1x4xf32>
+/// %cast = vector.shape_cast %bc : vector<1x3x1x4xf32> to vector<1x3x4xf32>
+/// ```
+///
+/// Gets converted to:
+///
+/// ```
+/// %sc_arg = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
+/// %bc = vector.broadcast %arg : vector<4xf32> to vector<1x3x4xf32>
+/// %cast_new = vector.shape_cast %bc : vector<1x3x4xf32> to
+/// vector<1x3x1x4xf32>
+/// %cast = vector.shape_cast %cast_new : vector<1x3x1x4xf32> to
+/// vector<1x3x4xf32>
+/// ```
+/// %cast_new and %cast can be folded away.
+struct DropUnitDimFromBroadcastOp final
+ : public OpRewritePattern<vector::BroadcastOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
+ PatternRewriter &rewriter) const override {
+ auto srcVT = dyn_cast<VectorType>(broadcastOp.getSourceType());
+ if (!srcVT)
+ return failure();
+ auto resVT = broadcastOp.getResultVectorType();
+ VectorType newSrcVT = srcVT;
+ VectorType newResVT = resVT;
+ auto broadcastedUnitDims = broadcastOp.computeBroadcastedUnitDims();
+ // Reversing allows us to remove dims from the back without keeping track of
+ // removed dimensions.
+ for (const auto &dim : llvm::enumerate(llvm::reverse(srcVT.getShape()))) {
+ if (dim.value() == 1 &&
+ !srcVT.getScalableDims()[srcVT.getRank() - dim.index() - 1] &&
+ !broadcastedUnitDims.contains(srcVT.getRank() - dim.index() - 1)) {
+ newSrcVT = VectorType::Builder(newSrcVT).dropDim(srcVT.getRank() -
+ dim.index() - 1);
+ newResVT = VectorType::Builder(newResVT).dropDim(resVT.getRank() -
+ dim.index() - 1);
+ }
+ }
+
+ if (newSrcVT == srcVT)
+ return failure();
+ auto loc = broadcastOp->getLoc();
+ auto newSource = rewriter.create<vector::ShapeCastOp>(
+ loc, newSrcVT, broadcastOp.getSource());
+ auto newOp = rewriter.create<vector::BroadcastOp>(loc, newResVT, newSource);
+ rewriter.replaceOpWithNewOp<ShapeCastOp>(broadcastOp, resVT,
+ newOp.getResult());
+ return success();
+ }
+};
+
/// Pattern to eliminate redundant zero-constants added to reduction operands.
/// It's enough for there to be one initial zero value, so we can eliminate the
/// extra ones that feed into `vector.reduction <add>`. These get created by the
@@ -1819,8 +1879,8 @@ void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
- patterns.add<DropUnitDimFromElementwiseOps, ShapeCastOpFolder>(
- patterns.getContext(), benefit);
+ patterns.add<DropUnitDimFromElementwiseOps, DropUnitDimFromBroadcastOp,
+ ShapeCastOpFolder>(patterns.getContext(), benefit);
}
void mlir::vector::populateBubbleVectorBitCastOpPatterns(
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 788ae9ac044ed..f1fc443b9d4bd 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -460,6 +460,61 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
// CHECK-128B-NOT: memref.collapse_shape
+// -----
+
+func.func @drop_broadcast_unit_dim(%arg0 : vector<1x[1]x3x1xf128>) -> vector<4x1x[1]x3x1xf128> {
+ %bc = vector.broadcast %arg0 : vector<1x[1]x3x1xf128> to vector<4x1x[1]x3x1xf128>
+ return %bc : vector<4x1x[1]x3x1xf128>
+}
+
+// CHECK-LABEL: func.func @drop_broadcast_unit_dim(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<1x[1]x3x1xf128>{{.*}}-> vector<4x1x[1]x3x1xf128> {
+// CHECK: %[[VAL_1:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x[1]x3x1xf128> to vector<[1]x3xf128>
+// CHECK: %[[VAL_2:.*]] = vector.broadcast %[[VAL_1]] : vector<[1]x3xf128> to vector<4x[1]x3xf128>
+// CHECK: %[[VAL_3:.*]] = vector.shape_cast %[[VAL_2]] : vector<4x[1]x3xf128> to vector<4x1x[1]x3x1xf128>
+// CHECK: return %[[VAL_3]] : vector<4x1x[1]x3x1xf128>
+
+// -----
+
+func.func @drop_broadcasted_only_unit_dim(%arg0 : vector<1xf32>) -> vector<1x1xf32> {
+ %bc = vector.broadcast %arg0 : vector<1xf32> to vector<1x1xf32>
+ return %bc : vector<1x1xf32>
+}
+
+// CHECK-LABEL: func.func @drop_broadcasted_only_unit_dim(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<1xf32>) -> vector<1x1xf32> {
+// CHECK: %[[VAL_1:.*]] = vector.shape_cast %[[VAL_0]] : vector<1xf32> to vector<f32>
+// CHECK: %[[VAL_2:.*]] = vector.broadcast %[[VAL_1]] : vector<f32> to vector<1xf32>
+// CHECK: %[[VAL_3:.*]] = vector.shape_cast %[[VAL_2]] : vector<1xf32> to vector<1x1xf32>
+// CHECK: return %[[VAL_3]] : vector<1x1xf32>
+
+// -----
+
+// Generated unit dimensions through broadcasts are not dropped as we prefer to have a
+// single broadcast rather than a broadcast and a shape_cast.
+func.func @drop_broadcast_generated_unit_dim(%arg0 : vector<4xf32>) -> vector<3x1x4xf32> {
+ %bc = vector.broadcast %arg0 : vector<4xf32> to vector<3x1x4xf32>
+ return %bc : vector<3x1x4xf32>
+}
+
+// CHECK-LABEL: func.func @drop_broadcast_generated_unit_dim(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<4xf32>{{.*}}-> vector<3x1x4xf32> {
+// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<4xf32> to vector<3x1x4xf32>
+// CHECK: return %[[VAL_1]] : vector<3x1x4xf32>
+
+// -----
+
+// A broadcasted unit dimension cannot be dropped to prevent type mismatch.
+func.func @drop_broadcasted_unit_dim(%arg0 : vector<2x1x4xf32>) -> vector<2x3x4xf32> {
+ %bc = vector.broadcast %arg0 : vector<2x1x4xf32> to vector<2x3x4xf32>
+ return %bc : vector<2x3x4xf32>
+}
+// CHECK-LABEL: func.func @drop_broadcasted_unit_dim(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<2x1x4xf32>{{.*}}-> vector<2x3x4xf32> {
+// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<2x1x4xf32> to vector<2x3x4xf32>
+// CHECK: return %[[VAL_1]] : vector<2x3x4xf32>
+
+
// -----
func.func @regression_non_contiguous_dim_read(%subview : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>,
|
CC @banach-space @MacDue . |
Hi Hugo, thanks for sending this!
I think that it would be good to take a step back and discuss whether this is actually needed. IIUC, you are trying to solve the problem outlined here: As I pointed out in my reply in that thread:
Also, as @dcaballe hinted:
Before moving ahead with this - what's your long-term goal? If we manage to get rid of
then, IIUC, this change won't be needed. Given that getting rid of |
FYI, I also have one last pattern ready which needs to go through some internal process before upstreaming. It matches a series of elementwise ops to generate
It does make sense. I can likely give it a look over the week. Happy to talk implementation details on a dedicated thread or something. |
69e5e4c
to
cebfd74
Compare
auto broadcastedUnitDims = broadcastOp.computeBroadcastedUnitDims(); | ||
// Reversing allows us to remove dims from the back without keeping track of | ||
// removed dimensions. | ||
for (const auto &dim : |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for (const auto &dim : | |
for (const auto [index, dim] : |
if (dim.value() == 1 && | ||
!srcVecTy.getScalableDims()[srcVecTy.getRank() - dim.index() - 1] && | ||
!broadcastedUnitDims.contains(srcVecTy.getRank() - dim.index() - 1)) { | ||
srcVecTyBuilder.dropDim(srcVecTy.getRank() - dim.index() - 1); | ||
resVecTyBuilder.dropDim(resVecTy.getRank() - dim.index() - 1); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This would be easier to read with some names:
if (dim.value() == 1 && | |
!srcVecTy.getScalableDims()[srcVecTy.getRank() - dim.index() - 1] && | |
!broadcastedUnitDims.contains(srcVecTy.getRank() - dim.index() - 1)) { | |
srcVecTyBuilder.dropDim(srcVecTy.getRank() - dim.index() - 1); | |
resVecTyBuilder.dropDim(resVecTy.getRank() - dim.index() - 1); | |
} | |
auto sourceDimIndex = srcVecTy.getRank() - index - 1; | |
auto resultDimIndex = resVecTy.getRank() - index - 1; | |
if (dim == 1 && !srcVecTy.getScalableDims()[sourceDimIndex] && | |
!broadcastedUnitDims.contains(sourceDimIndex)) { | |
srcVecTyBuilder.dropDim(sourceDimIndex); | |
resVecTyBuilder.dropDim(resultDimIndex); | |
} |
auto srcVecTyBuilder = VectorType::Builder(srcVecTy); | ||
auto resVecTyBuilder = VectorType::Builder(resVecTy); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Builders may be less efficient than just appending the dims not dropped to a new vector (but this is probably not much of a concern given the number of dims is normally < 5-ish).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(no need to change this -- just a note)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree, I tried implementing it with shapes to be appended. But, considering one also has to rebuild the scalable dims for both the new source and result type, I found it was generating lots of code which is hidden thanks to dropDim.
Also, for some reason, I was able to generate the base of the new resultShape creating a subvector of it with :
SmallVector<int64_t> newResShape =
llvm::to_vector(resVecTy.getShape().drop_back(srcVecTy.getRank()));
but for Scalable Dims I get some errors like this, and I don't think I should be changing the behaviour of SmallVector. I suspect it comes from the way ScalableDims are defined.
llvm/include/llvm/ADT/SmallVector.h:1317:11: error: type 'decltype(__cont.begin())' (aka 'const bool *') cannot be narrowed to 'bool' in initializer list [-Wc++11-narrowing]
In order to fix it, I needed to create an ugly vector inserting explicit casts like
SmallVector<bool> newResScalableDims = {
static_cast<bool>(resVecTy.getScalableDims().begin()),
static_cast<bool>(resVecTy.getScalableDims().drop_back(srcVecTy.getRank()).end())};
If you want, I can push my solution on top and we revert it if we prefer it as it currently is.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Those static_casts
look very suspect 😅. It looks like that's just going to make a SmallVector of two (likely true) bools. I think keeping it as-is is fine as it's simpler, and likely not really a performance concern (vector types are normally small).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Btw, I've been trying to make writing code like this easier for scalable dims (for a while now 😅). With my current attempt #96236, I think you'd be able to rewrite this as (untested!!!):
auto srcDims = VectorDimList::from(srcVecTy);
auto resDims = VectorDimList::from(resVecTy);
auto rankDiff = resDims.size() - srcDims.size();
SmallVector<VectorDim> newSrcDims;
SmallVector<VectorDim> newResDims(resDims.takeFront(rankDiff));
auto broadcastedUnitDims = broadcastOp.computeBroadcastedUnitDims();
for (auto [idx, dim] : llvm::enumerate(srcDims)) {
if (dim != VectorDim::getFixed(1) || broadcastedUnitDims.contains(idx)) {
newSrcDims.push_back(dim);
newResDims.push_back(resDims[idx + rankDiff]);
}
}
auto newSourceType = ScalableVectorType::get(newSrcDims, srcVecTy.getElementType());
auto newResultType = ScalableVectorType::get(newResDims, srcVecTy.getElementType());
Please take a look at the PR if you think it'd be useful :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks indeed way simpler, very fortunate PR you submitted. I will probably not be able to review it at the moment as I need to sort some things out today. But I hope to give it a look. (I'll be off for a while, do not wait for my review).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Generally LGTM, just a few more nits:
if (VectorType(srcVecTyBuilder) == srcVecTy) | ||
return failure(); | ||
auto loc = broadcastOp->getLoc(); | ||
auto newSource = rewriter.create<vector::ShapeCastOp>( | ||
loc, VectorType(srcVecTyBuilder), broadcastOp.getSource()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: avoid constructing the new vector type twice:
if (VectorType(srcVecTyBuilder) == srcVecTy) | |
return failure(); | |
auto loc = broadcastOp->getLoc(); | |
auto newSource = rewriter.create<vector::ShapeCastOp>( | |
loc, VectorType(srcVecTyBuilder), broadcastOp.getSource()); | |
auto newSrcVecTy = VectorType(srcVecTyBuilder); | |
if (newSrcVecTy == srcVecTy) | |
return failure(); | |
auto loc = broadcastOp->getLoc(); | |
auto newSource = rewriter.create<vector::ShapeCastOp>( | |
loc, newSrcVecTy, broadcastOp.getSource()); |
} | ||
|
||
// CHECK-LABEL: func.func @drop_broadcast_unit_dim( | ||
// CHECK-SAME: %[[VAL_0:.*]]: vector<1x[1]x3x1xf128>{{.*}}-> vector<4x1x[1]x3x1xf128> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: Use better names than the generated VAL_*
:)
I'm a bit confused by this PR. Isn't a |
This MR is part of a list of MRs aiming to generalize
DropUnitDimFromElementwiseOps
for other ops.This commit implements
DropUnitDimFromBroadcastOp
to targetvector::BroadcastOp
.This change stems from improving lowering of contractionOps for Arm SME. Where we end up with inner unit dimensions on MulOp, BroadcastOp and TransposeOp, preventing the generation of outerproducts.
Discussed here.