From 1493a742fb9c939b8b9a5247153506dd8aa2ab8a Mon Sep 17 00:00:00 2001
From: Craig Topper
Date: Fri, 5 Sep 2025 16:09:25 -0700
Subject: [PATCH] [ExpandVectorPredication] Support vp.merge in foldEVLIntoMask.

Partial fix for #157184.
---
 llvm/lib/CodeGen/ExpandVectorPredication.cpp | 11 ++++++++++-
 .../PreISelIntrinsicLowering/expand-vp.ll    | 14 +++++++++++++-
 2 files changed, 23 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/CodeGen/ExpandVectorPredication.cpp b/llvm/lib/CodeGen/ExpandVectorPredication.cpp
index 753c656007703..5f79bd7b2dd13 100644
--- a/llvm/lib/CodeGen/ExpandVectorPredication.cpp
+++ b/llvm/lib/CodeGen/ExpandVectorPredication.cpp
@@ -527,6 +527,12 @@ std::pair CachingVPExpander::foldEVLIntoMask(VPIntrinsic &VPI) {
 
   // Only VP intrinsics can have an %evl parameter.
   Value *OldMaskParam = VPI.getMaskParam();
+  if (!OldMaskParam) {
+    assert(VPI.getIntrinsicID() == Intrinsic::vp_merge &&
+           "Unexpected VP intrinsic without mask operand");
+    OldMaskParam = VPI.getArgOperand(0);
+  }
+
   Value *OldEVLParam = VPI.getVectorLengthParam();
   assert(OldMaskParam && "no mask param to fold the vl param into");
   assert(OldEVLParam && "no EVL param to fold away");
@@ -538,7 +544,10 @@ std::pair CachingVPExpander::foldEVLIntoMask(VPIntrinsic &VPI) {
   ElementCount ElemCount = VPI.getStaticVectorLength();
   Value *VLMask = convertEVLToMask(Builder, OldEVLParam, ElemCount);
   Value *NewMaskParam = Builder.CreateAnd(VLMask, OldMaskParam);
-  VPI.setMaskParam(NewMaskParam);
+  if (VPI.getIntrinsicID() == Intrinsic::vp_merge)
+    VPI.setArgOperand(0, NewMaskParam);
+  else
+    VPI.setMaskParam(NewMaskParam);
 
   // Drop the %evl parameter.
   discardEVLParameter(VPI);
diff --git a/llvm/test/Transforms/PreISelIntrinsicLowering/expand-vp.ll b/llvm/test/Transforms/PreISelIntrinsicLowering/expand-vp.ll
index 6eaf98f893bfa..fe7d725439060 100644
--- a/llvm/test/Transforms/PreISelIntrinsicLowering/expand-vp.ll
+++ b/llvm/test/Transforms/PreISelIntrinsicLowering/expand-vp.ll
@@ -68,6 +68,7 @@ define void @test_vp_int_v8(<8 x i32> %i0, <8 x i32> %i1, <8 x i32> %i2, <8 x i3
   %rE = call <8 x i32> @llvm.vp.ashr.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n)
   %rF = call <8 x i32> @llvm.vp.lshr.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n)
   %r10 = call <8 x i32> @llvm.vp.shl.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n)
+  %r11 = call <8 x i32> @llvm.vp.merge.v8i32(<8 x i1> %m, <8 x i32> %i0, <8 x i32> %i1, i32 %n)
   ret void
 }
 
@@ -111,6 +112,7 @@ define void @test_vp_int_vscale(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1,
   %rE = call <vscale x 4 x i32> @llvm.vp.ashr.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> %m, i32 %n)
   %rF = call <vscale x 4 x i32> @llvm.vp.lshr.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> %m, i32 %n)
   %r10 = call <vscale x 4 x i32> @llvm.vp.shl.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> %m, i32 %n)
+  %r11 = call <vscale x 4 x i32> @llvm.vp.merge.nxv4i32(<vscale x 4 x i1> %m, <vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, i32 %n)
   ret void
 }
 
@@ -322,6 +324,7 @@ define void @test_vp_cmp_v8(<8 x i32> %i0, <8 x i32> %i1, <8 x float> %f0, <8 x
 ; LEGAL_LEGAL-NEXT: %rE = call <8 x i32> @llvm.vp.ashr.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n)
 ; LEGAL_LEGAL-NEXT: %rF = call <8 x i32> @llvm.vp.lshr.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n)
 ; LEGAL_LEGAL-NEXT: %r10 = call <8 x i32> @llvm.vp.shl.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n)
+; LEGAL_LEGAL-NEXT: %r11 = call <8 x i32> @llvm.vp.merge.v8i32(<8 x i1> %m, <8 x i32> %i0, <8 x i32> %i1, i32 %n)
 ; LEGAL_LEGAL-NEXT: ret void
 
 ; LEGAL_LEGAL:define void @test_vp_int_vscale(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i32> %i2, <vscale x 4 x float> %f3, <vscale x 4 x i1> %m, i32 %n) {
@@ -342,6 +345,7 @@ define void @test_vp_cmp_v8(<8 x i32> %i0, <8 x i32> %i1, <8 x float> %f0, <8 x
 ; LEGAL_LEGAL-NEXT: %rE = call <vscale x 4 x i32> @llvm.vp.ashr.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> %m, i32 %n)
 ; LEGAL_LEGAL-NEXT: %rF = call <vscale x 4 x i32> @llvm.vp.lshr.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> %m, i32 %n)
 ; LEGAL_LEGAL-NEXT: %r10 = call <vscale x 4 x i32> @llvm.vp.shl.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> %m, i32 %n)
+; LEGAL_LEGAL-NEXT: %r11 = call <vscale x 4 x i32> @llvm.vp.merge.nxv4i32(<vscale x 4 x i1> %m, <vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, i32 %n)
 ; LEGAL_LEGAL-NEXT: ret void
 
 ; LEGAL_LEGAL: define void @test_vp_reduce_int_v4(i32 %start, <4 x i32> %vi, <4 x i1> %m, i32 %n) {
@@ -415,6 +419,11 @@ define void @test_vp_cmp_v8(<8 x i32> %i0, <8 x i32> %i1, <8 x float> %f0, <8 x
 ; DISCARD_LEGAL-NEXT: %rE = call <8 x i32> @llvm.vp.ashr.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 8)
 ; DISCARD_LEGAL-NEXT: %rF = call <8 x i32> @llvm.vp.lshr.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 8)
 ; DISCARD_LEGAL-NEXT: %r10 = call <8 x i32> @llvm.vp.shl.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 8)
+; DISCARD_LEGAL-NEXT: [[NSPLATINS2:%.+]] = insertelement <8 x i32> poison, i32 %n, i64 0
+; DISCARD_LEGAL-NEXT: [[NSPLAT2:%.+]] = shufflevector <8 x i32> [[NSPLATINS2]], <8 x i32> poison, <8 x i32> zeroinitializer
+; DISCARD_LEGAL-NEXT: [[EVLMASK2:%.+]] = icmp ult <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>, [[NSPLAT2]]
+; DISCARD_LEGAL-NEXT: [[NEWMASK2:%.+]] = and <8 x i1> [[EVLMASK2]], %m
+; DISCARD_LEGAL-NEXT: %r11 = call <8 x i32> @llvm.vp.merge.v8i32(<8 x i1> [[NEWMASK2]], <8 x i32> %i0, <8 x i32> %i1, i32 8)
 ; DISCARD_LEGAL-NEXT: ret void
 
 ; TODO compute vscale only once and use caching.
@@ -431,7 +440,8 @@ define void @test_vp_cmp_v8(<8 x i32> %i0, <8 x i32> %i1, <8 x float> %f0, <8 x
 ; DISCARD_LEGAL: [[NEWM:%.+]] = and <vscale x 4 x i1> [[EVLM]], %m
 ; DISCARD_LEGAL: %r3 = call <vscale x 4 x i32> @llvm.vp.sdiv.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> [[NEWM]], i32 %scalable_size{{.*}})
 ; DISCARD_LEGAL-NOT: %{{.+}} = call @llvm.vp.{{.*}}, i32 %n)
-; DISCARD_LEGAL: ret void
+; DISCARD_LEGAL: %r11 = call <vscale x 4 x i32> @llvm.vp.merge.nxv4i32(<vscale x 4 x i1> %{{.*}}, <vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, i32 %scalable_size{{.*}})
+; DISCARD_LEGAL-NEXT: ret void
 
 ; DISCARD_LEGAL: define void @test_vp_reduce_int_v4(i32 %start, <4 x i32> %vi, <4 x i1> %m, i32 %n) {
 ; DISCARD_LEGAL-NEXT: [[NSPLATINS:%.+]] = insertelement <4 x i32> poison, i32 %n, i64 0
@@ -503,6 +513,7 @@ define void @test_vp_cmp_v8(<8 x i32> %i0, <8 x i32> %i1, <8 x float> %f0, <8 x
 ; CONVERT_LEGAL-NOT: %{{.+}} = call <8 x i32> @llvm.vp.ashr.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 8)
 ; CONVERT_LEGAL-NOT: %{{.+}} = call <8 x i32> @llvm.vp.lshr.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 8)
 ; CONVERT_LEGAL-NOT: %{{.+}} = call <8 x i32> @llvm.vp.shl.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 8)
+; CONVERT_LEGAL: %r11 = call <8 x i32> @llvm.vp.merge.v8i32(<8 x i1> %{{.*}}, <8 x i32> %i0, <8 x i32> %i1, i32 8)
 ; CONVERT_LEGAL: ret void
 
 ; Similar to %evl discard, %mask legal but make sure the first VP intrinsic has a legal expansion
@@ -513,6 +524,7 @@ define void @test_vp_cmp_v8(<8 x i32> %i0, <8 x i32> %i1, <8 x float> %f0, <8 x
 ; CONVERT_LEGAL-NEXT: %scalable_size = mul nuw i32 %vscale, 4
 ; CONVERT_LEGAL-NEXT: %r0 = call <vscale x 4 x i32> @llvm.vp.add.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> [[NEWM]], i32 %scalable_size)
 ; CONVERT_LEGAL-NOT: %{{.*}} = call @llvm.vp.{{.*}}, i32 %n)
+; CONVERT_LEGAL: %r11 = call <vscale x 4 x i32> @llvm.vp.merge.nxv4i32(<vscale x 4 x i1> %{{.*}}, <vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, i32 %scalable_size{{.*}})
 ; CONVERT_LEGAL: ret void
 
 ; CONVERT_LEGAL: define void @test_vp_reduce_int_v4(i32 %start, <4 x i32> %vi, <4 x i1> %m, i32 %n) {
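
Note (illustrative, not part of the generated patch): the DISCARD_LEGAL checks above
correspond to a transform along these lines for the fixed-width case. Because vp.merge
has no separate mask operand, the %evl is converted to a lane mask, ANDed into the
condition (first) operand, and then replaced by the static vector length. This is a
minimal hand-written sketch; the value names below are made up for readability:

    ; before foldEVLIntoMask
    %r11 = call <8 x i32> @llvm.vp.merge.v8i32(<8 x i1> %m, <8 x i32> %i0, <8 x i32> %i1, i32 %n)

    ; after foldEVLIntoMask + discardEVLParameter (sketch)
    %n.ins    = insertelement <8 x i32> poison, i32 %n, i64 0
    %n.splat  = shufflevector <8 x i32> %n.ins, <8 x i32> poison, <8 x i32> zeroinitializer
    %evl.mask = icmp ult <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>, %n.splat
    %new.mask = and <8 x i1> %evl.mask, %m
    %r11 = call <8 x i32> @llvm.vp.merge.v8i32(<8 x i1> %new.mask, <8 x i32> %i0, <8 x i32> %i1, i32 8)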