-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[InstCombine] Transform vector.reduce.add
and splat
into multiplication
#161020
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…32 %0, 2` Fixes llvm#160066 Whenever we have a vector with all the same elemnts, created with `insertelement` and `shufflevector` and the result type's element number is a power of two and we sum the vector, we have a multiplication by a power of two, which can be replaced with a left shift.
This should not be limited to powers of two. You can just emit a multiply and it will get folded to a shift in the power of two case. |
Thank you very much for your review @nikic . I am really happy that you have suggested to optimize the non power of two cases. It was fun implementig those too. :)
I am also open to any further potential improvement idea for this patch. |
vector.reduce.add (splat %0, 4)
into shl i32 %0, 2
vector.reduce.add
and splat
into multiplication
There is one lldb failure on Linux. I think that is just a flaky test case, which isn't caused by this PR. I will retrigger the CI. |
@llvm/pr-subscribers-llvm-transforms Author: Gábor Spaits (spaits) ChangesFixes #160066 Whenever we have a vector with all the same elemnts, created with Full diff: https://github.com/llvm/llvm-project/pull/161020.diff 2 Files Affected:
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 6ad493772d170..74c263e86f4a4 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -64,6 +64,7 @@
#include "llvm/Support/KnownBits.h"
#include "llvm/Support/KnownFPClass.h"
#include "llvm/Support/MathExtras.h"
+#include "llvm/Support/TypeSize.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/InstCombine/InstCombiner.h"
#include "llvm/Transforms/Utils/AssumeBundleBuilder.h"
@@ -3761,6 +3762,41 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
return replaceInstUsesWith(CI, Res);
}
}
+
+ // Handle the case where a value is multiplied by a power of two.
+ // For example:
+ // %2 = insertelement <4 x i32> poison, i32 %0, i64 0
+ // %3 = shufflevector <4 x i32> %2, poison, <4 x i32> zeroinitializer
+ // %4 = tail call i32 @llvm.vector.reduce.add.v4i32(%3)
+ // =>
+ // %2 = shl i32 %0, 2
+ assert(Arg->getType()->isVectorTy() &&
+ "The vector.reduce.add intrinsic's argument must be a vector!");
+
+ if (Value *Splat = getSplatValue(Arg)) {
+ // It is only a multiplication if we add the same element over and over.
+ ElementCount ReducedVectorElementCount =
+ static_cast<VectorType *>(Arg->getType())->getElementCount();
+ if (ReducedVectorElementCount.isFixed()) {
+ unsigned VectorSize = ReducedVectorElementCount.getFixedValue();
+ Type *SplatType = Splat->getType();
+ unsigned SplatTypeWidth = SplatType->getIntegerBitWidth();
+ Value *Res;
+ // Power of two is a special case. We can just use a left shif here.
+ if (isPowerOf2_32(VectorSize)) {
+ unsigned Pow2 = Log2_32(VectorSize);
+ Res = Builder.CreateShl(
+ Splat, Constant::getIntegerValue(SplatType,
+ APInt(SplatTypeWidth, Pow2)));
+ return replaceInstUsesWith(CI, Res);
+ }
+ // Otherwise just multiply.
+ Res = Builder.CreateMul(
+ Splat, Constant::getIntegerValue(
+ SplatType, APInt(SplatTypeWidth, VectorSize)));
+ return replaceInstUsesWith(CI, Res);
+ }
+ }
}
[[fallthrough]];
}
diff --git a/llvm/test/Transforms/InstCombine/vector-reductions.ll b/llvm/test/Transforms/InstCombine/vector-reductions.ll
index 10f4aca72dbc7..e071415d2d6c1 100644
--- a/llvm/test/Transforms/InstCombine/vector-reductions.ll
+++ b/llvm/test/Transforms/InstCombine/vector-reductions.ll
@@ -308,3 +308,93 @@ define i32 @diff_of_sums_type_mismatch2(<8 x i32> %v0, <4 x i32> %v1) {
%r = sub i32 %r0, %r1
ret i32 %r
}
+
+define i32 @constant_multiplied_at_0(i32 %0) {
+; CHECK-LABEL: @constant_multiplied_at_0(
+; CHECK-NEXT: [[TMP2:%.*]] = shl i32 [[TMP0:%.*]], 2
+; CHECK-NEXT: ret i32 [[TMP2]]
+;
+ %2 = insertelement <4 x i32> poison, i32 %0, i64 0
+ %3 = shufflevector <4 x i32> %2, <4 x i32> poison, <4 x i32> zeroinitializer
+ %4 = tail call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %3)
+ ret i32 %4
+}
+
+define i64 @constant_multiplied_at_0_64bits(i64 %0) {
+; CHECK-LABEL: @constant_multiplied_at_0_64bits(
+; CHECK-NEXT: [[TMP2:%.*]] = shl i64 [[TMP0:%.*]], 2
+; CHECK-NEXT: ret i64 [[TMP2]]
+;
+ %2 = insertelement <4 x i64> poison, i64 %0, i64 0
+ %3 = shufflevector <4 x i64> %2, <4 x i64> poison, <4 x i32> zeroinitializer
+ %4 = tail call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> %3)
+ ret i64 %4
+}
+
+define i32 @constant_multiplied_at_0_two_pow8(i32 %0) {
+; CHECK-LABEL: @constant_multiplied_at_0_two_pow8(
+; CHECK-NEXT: [[TMP2:%.*]] = shl i32 [[TMP0:%.*]], 3
+; CHECK-NEXT: ret i32 [[TMP2]]
+;
+ %2 = insertelement <4 x i32> poison, i32 %0, i64 0
+ %3 = shufflevector <4 x i32> %2, <4 x i32> poison, <8 x i32> zeroinitializer
+ %4 = tail call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> %3)
+ ret i32 %4
+}
+
+
+define i32 @constant_multiplied_at_0_two_pow16(i32 %0) {
+; CHECK-LABEL: @constant_multiplied_at_0_two_pow16(
+; CHECK-NEXT: [[TMP2:%.*]] = shl i32 [[TMP0:%.*]], 4
+; CHECK-NEXT: ret i32 [[TMP2]]
+;
+ %2 = insertelement <4 x i32> poison, i32 %0, i64 0
+ %3 = shufflevector <4 x i32> %2, <4 x i32> poison, <16 x i32> zeroinitializer
+ %4 = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %3)
+ ret i32 %4
+}
+
+
+define i32 @constant_multiplied_at_1(i32 %0) {
+; CHECK-LABEL: @constant_multiplied_at_1(
+; CHECK-NEXT: [[TMP2:%.*]] = shl i32 [[TMP0:%.*]], 2
+; CHECK-NEXT: ret i32 [[TMP2]]
+;
+ %2 = insertelement <4 x i32> poison, i32 %0, i64 1
+ %3 = shufflevector <4 x i32> %2, <4 x i32> poison,
+ <4 x i32> <i32 1, i32 1, i32 1, i32 1>
+ %4 = tail call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %3)
+ ret i32 %4
+}
+
+define i32 @negative_constant_multiplied_at_1(i32 %0) {
+; CHECK-LABEL: @negative_constant_multiplied_at_1(
+; CHECK-NEXT: ret i32 poison
+;
+ %2 = insertelement <4 x i32> poison, i32 %0, i64 1
+ %3 = shufflevector <4 x i32> %2, <4 x i32> poison, <4 x i32> zeroinitializer
+ %4 = tail call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %3)
+ ret i32 %4
+}
+
+define i32 @constant_multiplied_non_power_of_2(i32 %0) {
+; CHECK-LABEL: @constant_multiplied_non_power_of_2(
+; CHECK-NEXT: [[TMP2:%.*]] = mul i32 [[TMP0:%.*]], 6
+; CHECK-NEXT: ret i32 [[TMP2]]
+;
+ %2 = insertelement <4 x i32> poison, i32 %0, i64 0
+ %3 = shufflevector <4 x i32> %2, <4 x i32> poison, <6 x i32> zeroinitializer
+ %4 = tail call i32 @llvm.vector.reduce.add.v6i32(<6 x i32> %3)
+ ret i32 %4
+}
+
+define i64 @constant_multiplied_non_power_of_2_i64(i64 %0) {
+; CHECK-LABEL: @constant_multiplied_non_power_of_2_i64(
+; CHECK-NEXT: [[TMP2:%.*]] = mul i64 [[TMP0:%.*]], 6
+; CHECK-NEXT: ret i64 [[TMP2]]
+;
+ %2 = insertelement <4 x i64> poison, i64 %0, i64 0
+ %3 = shufflevector <4 x i64> %2, <4 x i64> poison, <6 x i32> zeroinitializer
+ %4 = tail call i64 @llvm.vector.reduce.add.v6i64(<6 x i64> %3)
+ ret i64 %4
+}
|
Thank you very much for your review @XChy . I have addressed your comments. |
@zyw-bot mfuzz |
assert(Arg->getType()->isVectorTy() && | ||
"The vector.reduce.add intrinsic's argument must be a vector!"); | ||
ElementCount ReducedVectorElementCount = | ||
static_cast<VectorType *>(Arg->getType())->getElementCount(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
static_cast<VectorType *>(Arg->getType())->getElementCount(); | |
cast<VectorType>(Arg->getType())->getElementCount(); |
And remove the assert.
Value *Res = | ||
Builder.CreateMul(Splat, ConstantInt::get(SplatType, VectorSize)); | ||
return replaceInstUsesWith(CI, Res); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Value *Res = | |
Builder.CreateMul(Splat, ConstantInt::get(SplatType, VectorSize)); | |
return replaceInstUsesWith(CI, Res); | |
return BinaryOperator::CreateMul(Splat, ConstantInt::get(SplatType, VectorSize)); |
; CHECK-NEXT: [[TMP3:%.*]] = shufflevector <2 x i1> [[TMP2]], <2 x i1> poison, <2 x i32> zeroinitializer | ||
; CHECK-NEXT: [[TMP4:%.*]] = bitcast <2 x i1> [[TMP3]] to i2 | ||
; CHECK-NEXT: [[TMP5:%.*]] = call range(i2 0, -1) i2 @llvm.ctpop.i2(i2 [[TMP4]]) | ||
; CHECK-NEXT: [[TMP6:%.*]] = trunc i2 [[TMP5]] to i1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No need for so many i1 tests that don't hit this code path anyway. I'd suggest adding additional i2 tests instead, which make it a bit clearer what is going on (e.g. v5i2 and v6i2).
ret i2 %4 | ||
} | ||
|
||
define i2 @constant_multiplied_5xi2(i2 %0) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ret i2 %4 | ||
} | ||
|
||
define i2 @constant_multiplied_7xi2(i2 %0) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ret i2 %4 | ||
} | ||
|
||
define i2 @constant_multiplied_6xi2(i2 %0) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM except a nit, cheers.
✅ With the latest revision this PR passed the C/C++ code formatter. |
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/76/builds/13019 Here is the relevant piece of the build log for the reference
|
…cation (llvm#161020) Fixes llvm#160066 Whenever we have a vector with all the same elemnts, created with `insertelement` and `shufflevector` and we sum the vector, we have a multiplication.
Fixes #160066
Whenever we have a vector with all the same elemnts, created with
insertelement
andshufflevector
and we sum the vector, we have a multiplication.