-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[ExpandVectorPredication] Support vp.merge in foldEVLIntoMask. #157195
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-llvm-transforms Author: Craig Topper (topperc) ChangesPartial fix for #157184. Full diff: https://github.com/llvm/llvm-project/pull/157195.diff 2 Files Affected:
diff --git a/llvm/lib/CodeGen/ExpandVectorPredication.cpp b/llvm/lib/CodeGen/ExpandVectorPredication.cpp
index 753c656007703..5f79bd7b2dd13 100644
--- a/llvm/lib/CodeGen/ExpandVectorPredication.cpp
+++ b/llvm/lib/CodeGen/ExpandVectorPredication.cpp
@@ -527,6 +527,12 @@ std::pair<Value *, bool> CachingVPExpander::foldEVLIntoMask(VPIntrinsic &VPI) {
// Only VP intrinsics can have an %evl parameter.
Value *OldMaskParam = VPI.getMaskParam();
+ if (!OldMaskParam) {
+ assert(VPI.getIntrinsicID() == Intrinsic::vp_merge &&
+ "Unexpected VP intrinsic without mask operand");
+ OldMaskParam = VPI.getArgOperand(0);
+ }
+
Value *OldEVLParam = VPI.getVectorLengthParam();
assert(OldMaskParam && "no mask param to fold the vl param into");
assert(OldEVLParam && "no EVL param to fold away");
@@ -538,7 +544,10 @@ std::pair<Value *, bool> CachingVPExpander::foldEVLIntoMask(VPIntrinsic &VPI) {
ElementCount ElemCount = VPI.getStaticVectorLength();
Value *VLMask = convertEVLToMask(Builder, OldEVLParam, ElemCount);
Value *NewMaskParam = Builder.CreateAnd(VLMask, OldMaskParam);
- VPI.setMaskParam(NewMaskParam);
+ if (VPI.getIntrinsicID() == Intrinsic::vp_merge)
+ VPI.setArgOperand(0, NewMaskParam);
+ else
+ VPI.setMaskParam(NewMaskParam);
// Drop the %evl parameter.
discardEVLParameter(VPI);
diff --git a/llvm/test/Transforms/PreISelIntrinsicLowering/expand-vp.ll b/llvm/test/Transforms/PreISelIntrinsicLowering/expand-vp.ll
index 6eaf98f893bfa..fe7d725439060 100644
--- a/llvm/test/Transforms/PreISelIntrinsicLowering/expand-vp.ll
+++ b/llvm/test/Transforms/PreISelIntrinsicLowering/expand-vp.ll
@@ -68,6 +68,7 @@ define void @test_vp_int_v8(<8 x i32> %i0, <8 x i32> %i1, <8 x i32> %i2, <8 x i3
%rE = call <8 x i32> @llvm.vp.ashr.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n)
%rF = call <8 x i32> @llvm.vp.lshr.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n)
%r10 = call <8 x i32> @llvm.vp.shl.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n)
+ %r11 = call <8 x i32> @llvm.vp.merge.v8i32(<8 x i1> %m, <8 x i32> %i0, <8 x i32> %i1, i32 %n)
ret void
}
@@ -111,6 +112,7 @@ define void @test_vp_int_vscale(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1,
%rE = call <vscale x 4 x i32> @llvm.vp.ashr.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> %m, i32 %n)
%rF = call <vscale x 4 x i32> @llvm.vp.lshr.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> %m, i32 %n)
%r10 = call <vscale x 4 x i32> @llvm.vp.shl.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> %m, i32 %n)
+ %r11 = call <vscale x 4 x i32> @llvm.vp.merge.nxv4i32(<vscale x 4 x i1> %m, <vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, i32 %n)
ret void
}
@@ -322,6 +324,7 @@ define void @test_vp_cmp_v8(<8 x i32> %i0, <8 x i32> %i1, <8 x float> %f0, <8 x
; LEGAL_LEGAL-NEXT: %rE = call <8 x i32> @llvm.vp.ashr.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n)
; LEGAL_LEGAL-NEXT: %rF = call <8 x i32> @llvm.vp.lshr.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n)
; LEGAL_LEGAL-NEXT: %r10 = call <8 x i32> @llvm.vp.shl.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n)
+; LEGAL_LEGAL-NEXT: %r11 = call <8 x i32> @llvm.vp.merge.v8i32(<8 x i1> %m, <8 x i32> %i0, <8 x i32> %i1, i32 %n)
; LEGAL_LEGAL-NEXT: ret void
; LEGAL_LEGAL:define void @test_vp_int_vscale(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i32> %i2, <vscale x 4 x i32> %f3, <vscale x 4 x i1> %m, i32 %n) {
@@ -342,6 +345,7 @@ define void @test_vp_cmp_v8(<8 x i32> %i0, <8 x i32> %i1, <8 x float> %f0, <8 x
; LEGAL_LEGAL-NEXT: %rE = call <vscale x 4 x i32> @llvm.vp.ashr.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> %m, i32 %n)
; LEGAL_LEGAL-NEXT: %rF = call <vscale x 4 x i32> @llvm.vp.lshr.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> %m, i32 %n)
; LEGAL_LEGAL-NEXT: %r10 = call <vscale x 4 x i32> @llvm.vp.shl.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> %m, i32 %n)
+; LEGAL_LEGAL-NEXT: %r11 = call <vscale x 4 x i32> @llvm.vp.merge.nxv4i32(<vscale x 4 x i1> %m, <vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, i32 %n)
; LEGAL_LEGAL-NEXT: ret void
; LEGAL_LEGAL: define void @test_vp_reduce_int_v4(i32 %start, <4 x i32> %vi, <4 x i1> %m, i32 %n) {
@@ -415,6 +419,11 @@ define void @test_vp_cmp_v8(<8 x i32> %i0, <8 x i32> %i1, <8 x float> %f0, <8 x
; DISCARD_LEGAL-NEXT: %rE = call <8 x i32> @llvm.vp.ashr.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 8)
; DISCARD_LEGAL-NEXT: %rF = call <8 x i32> @llvm.vp.lshr.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 8)
; DISCARD_LEGAL-NEXT: %r10 = call <8 x i32> @llvm.vp.shl.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 8)
+; DISCARD_LEGAL-NEXT: [[NSPLATINS2:%.+]] = insertelement <8 x i32> poison, i32 %n, i64 0
+; DISCARD_LEGAL-NEXT: [[NSPLAT2:%.+]] = shufflevector <8 x i32> [[NSPLATINS2]], <8 x i32> poison, <8 x i32> zeroinitializer
+; DISCARD_LEGAL-NEXT: [[EVLMASK2:%.+]] = icmp ult <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>, [[NSPLAT2]]
+; DISCARD_LEGAL-NEXT: [[NEWMASK2:%.+]] = and <8 x i1> [[EVLMASK2]], %m
+; DISCARD_LEGAL-NEXT: %r11 = call <8 x i32> @llvm.vp.merge.v8i32(<8 x i1> [[NEWMASK2]], <8 x i32> %i0, <8 x i32> %i1, i32 8)
; DISCARD_LEGAL-NEXT: ret void
; TODO compute vscale only once and use caching.
@@ -431,7 +440,8 @@ define void @test_vp_cmp_v8(<8 x i32> %i0, <8 x i32> %i1, <8 x float> %f0, <8 x
; DISCARD_LEGAL: [[NEWM:%.+]] = and <vscale x 4 x i1> [[EVLM]], %m
; DISCARD_LEGAL: %r3 = call <vscale x 4 x i32> @llvm.vp.sdiv.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> [[NEWM]], i32 %scalable_size{{.*}})
; DISCARD_LEGAL-NOT: %{{.+}} = call <vscale x 4 x i32> @llvm.vp.{{.*}}, i32 %n)
-; DISCARD_LEGAL: ret void
+; DISCARD_LEGAL: %r11 = call <vscale x 4 x i32> @llvm.vp.merge.nxv4i32(<vscale x 4 x i1> %{{.*}}, <vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, i32 %scalable_size{{.*}})
+; DISCARD_LEGAL-NEXT: ret void
; DISCARD_LEGAL: define void @test_vp_reduce_int_v4(i32 %start, <4 x i32> %vi, <4 x i1> %m, i32 %n) {
; DISCARD_LEGAL-NEXT: [[NSPLATINS:%.+]] = insertelement <4 x i32> poison, i32 %n, i64 0
@@ -503,6 +513,7 @@ define void @test_vp_cmp_v8(<8 x i32> %i0, <8 x i32> %i1, <8 x float> %f0, <8 x
; CONVERT_LEGAL-NOT: %{{.+}} = call <8 x i32> @llvm.vp.ashr.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 8)
; CONVERT_LEGAL-NOT: %{{.+}} = call <8 x i32> @llvm.vp.lshr.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 8)
; CONVERT_LEGAL-NOT: %{{.+}} = call <8 x i32> @llvm.vp.shl.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 8)
+; CONVERT_LEGAL: %r11 = call <8 x i32> @llvm.vp.merge.v8i32(<8 x i1> %{{.*}}, <8 x i32> %i0, <8 x i32> %i1, i32 8)
; CONVERT_LEGAL: ret void
; Similar to %evl discard, %mask legal but make sure the first VP intrinsic has a legal expansion
@@ -513,6 +524,7 @@ define void @test_vp_cmp_v8(<8 x i32> %i0, <8 x i32> %i1, <8 x float> %f0, <8 x
; CONVERT_LEGAL-NEXT: %scalable_size = mul nuw i32 %vscale, 4
; CONVERT_LEGAL-NEXT: %r0 = call <vscale x 4 x i32> @llvm.vp.add.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> [[NEWM]], i32 %scalable_size)
; CONVERT_LEGAL-NOT: %{{.*}} = call <vscale x 4 x i32> @llvm.vp.{{.*}}, i32 %n)
+; CONVERT_LEGAL: %r11 = call <vscale x 4 x i32> @llvm.vp.merge.nxv4i32(<vscale x 4 x i1> %{{.*}}, <vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, i32 %scalable_size{{.*}})
; CONVERT_LEGAL: ret void
; CONVERT_LEGAL: define void @test_vp_reduce_int_v4(i32 %start, <4 x i32> %vi, <4 x i1> %m, i32 %n) {
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
// Only VP intrinsics can have an %evl parameter. | ||
Value *OldMaskParam = VPI.getMaskParam(); | ||
if (!OldMaskParam) { | ||
assert(VPI.getIntrinsicID() == Intrinsic::vp_merge && |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
vp_select doesn't have a mask param either. Should we handle it here too?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
vp_select is speculatable so the equivalebt test from the bug report didn’t get here.
Partial fix for #157184. It still crashes later in SelectionDAG.