diff --git a/llvm/lib/CodeGen/ExpandVectorPredication.cpp b/llvm/lib/CodeGen/ExpandVectorPredication.cpp
index 673c4f7434d6c..75e8fe4965025 100644
--- a/llvm/lib/CodeGen/ExpandVectorPredication.cpp
+++ b/llvm/lib/CodeGen/ExpandVectorPredication.cpp
@@ -521,7 +521,8 @@ bool CachingVPExpander::foldEVLIntoMask(VPIntrinsic &VPI) {
   // Only VP intrinsics can have an %evl parameter.
   Value *OldMaskParam = VPI.getMaskParam();
   if (!OldMaskParam) {
-    assert(VPI.getIntrinsicID() == Intrinsic::vp_merge &&
+    assert((VPI.getIntrinsicID() == Intrinsic::vp_merge ||
+            VPI.getIntrinsicID() == Intrinsic::vp_select) &&
            "Unexpected VP intrinsic without mask operand");
     OldMaskParam = VPI.getArgOperand(0);
   }
@@ -537,7 +538,8 @@ bool CachingVPExpander::foldEVLIntoMask(VPIntrinsic &VPI) {
   ElementCount ElemCount = VPI.getStaticVectorLength();
   Value *VLMask = convertEVLToMask(Builder, OldEVLParam, ElemCount);
   Value *NewMaskParam = Builder.CreateAnd(VLMask, OldMaskParam);
-  if (VPI.getIntrinsicID() == Intrinsic::vp_merge)
+  if (VPI.getIntrinsicID() == Intrinsic::vp_merge ||
+      VPI.getIntrinsicID() == Intrinsic::vp_select)
     VPI.setArgOperand(0, NewMaskParam);
   else
     VPI.setMaskParam(NewMaskParam);
diff --git a/llvm/test/Transforms/PreISelIntrinsicLowering/expand-vp.ll b/llvm/test/Transforms/PreISelIntrinsicLowering/expand-vp.ll
index fe7d725439060..0c3a7c681c4d0 100644
--- a/llvm/test/Transforms/PreISelIntrinsicLowering/expand-vp.ll
+++ b/llvm/test/Transforms/PreISelIntrinsicLowering/expand-vp.ll
@@ -69,6 +69,7 @@ define void @test_vp_int_v8(<8 x i32> %i0, <8 x i32> %i1, <8 x i32> %i2, <8 x i3
   %rF = call <8 x i32> @llvm.vp.lshr.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n)
   %r10 = call <8 x i32> @llvm.vp.shl.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n)
   %r11 = call <8 x i32> @llvm.vp.merge.v8i32(<8 x i1> %m, <8 x i32> %i0, <8 x i32> %i1, i32 %n)
+  %r12 = call <8 x i32> @llvm.vp.select.v8i32(<8 x i1> %m, <8 x i32> %i0, <8 x i32> %i1, i32 %n)
   ret void
 }
 
@@ -113,6 +114,7 @@ define void @test_vp_int_vscale(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1,
   %rF = call <vscale x 4 x i32> @llvm.vp.lshr.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> %m, i32 %n)
   %r10 = call <vscale x 4 x i32> @llvm.vp.shl.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> %m, i32 %n)
   %r11 = call <vscale x 4 x i32> @llvm.vp.merge.nxv4i32(<vscale x 4 x i1> %m, <vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, i32 %n)
+  %r12 = call <vscale x 4 x i32> @llvm.vp.select.nxv4i32(<vscale x 4 x i1> %m, <vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, i32 %n)
   ret void
 }
 
@@ -325,6 +327,7 @@ define void @test_vp_cmp_v8(<8 x i32> %i0, <8 x i32> %i1, <8 x float> %f0, <8 x
 ; LEGAL_LEGAL-NEXT: %rF = call <8 x i32> @llvm.vp.lshr.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n)
 ; LEGAL_LEGAL-NEXT: %r10 = call <8 x i32> @llvm.vp.shl.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n)
 ; LEGAL_LEGAL-NEXT: %r11 = call <8 x i32> @llvm.vp.merge.v8i32(<8 x i1> %m, <8 x i32> %i0, <8 x i32> %i1, i32 %n)
+; LEGAL_LEGAL-NEXT: %r12 = call <8 x i32> @llvm.vp.select.v8i32(<8 x i1> %m, <8 x i32> %i0, <8 x i32> %i1, i32 %n)
 ; LEGAL_LEGAL-NEXT: ret void
 
 ; LEGAL_LEGAL:define void @test_vp_int_vscale(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i32> %i2, <vscale x 4 x i32> %f3, <vscale x 4 x i1> %m, i32 %n) {
@@ -346,6 +349,7 @@ define void @test_vp_cmp_v8(<8 x i32> %i0, <8 x i32> %i1, <8 x float> %f0, <8 x
 ; LEGAL_LEGAL-NEXT: %rF = call <vscale x 4 x i32> @llvm.vp.lshr.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> %m, i32 %n)
 ; LEGAL_LEGAL-NEXT: %r10 = call <vscale x 4 x i32> @llvm.vp.shl.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> %m, i32 %n)
 ; LEGAL_LEGAL-NEXT: %r11 = call <vscale x 4 x i32> @llvm.vp.merge.nxv4i32(<vscale x 4 x i1> %m, <vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, i32 %n)
+; LEGAL_LEGAL-NEXT: %r12 = call <vscale x 4 x i32> @llvm.vp.select.nxv4i32(<vscale x 4 x i1> %m, <vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, i32 %n)
 ; LEGAL_LEGAL-NEXT: ret void
 
 ; LEGAL_LEGAL: define void @test_vp_reduce_int_v4(i32 %start, <4 x i32> %vi, <4 x i1> %m, i32 %n) {
@@ -424,6 +428,7 @@ define void @test_vp_cmp_v8(<8 x i32> %i0, <8 x i32> %i1, <8 x float> %f0, <8 x
 ; DISCARD_LEGAL-NEXT: [[EVLMASK2:%.+]] = icmp ult <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>, [[NSPLAT2]]
 ; DISCARD_LEGAL-NEXT: [[NEWMASK2:%.+]] = and <8 x i1> [[EVLMASK2]], %m
 ; DISCARD_LEGAL-NEXT: %r11 = call <8 x i32> @llvm.vp.merge.v8i32(<8 x i1> [[NEWMASK2]], <8 x i32> %i0, <8 x i32> %i1, i32 8)
+; DISCARD_LEGAL-NEXT: %r12 = call <8 x i32> @llvm.vp.select.v8i32(<8 x i1> %m, <8 x i32> %i0, <8 x i32> %i1, i32 8)
 ; DISCARD_LEGAL-NEXT: ret void
 
 ; TODO compute vscale only once and use caching.
@@ -441,6 +446,7 @@ define void @test_vp_cmp_v8(<8 x i32> %i0, <8 x i32> %i1, <8 x float> %f0, <8 x
 ; DISCARD_LEGAL: %r3 = call <vscale x 4 x i32> @llvm.vp.sdiv.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> [[NEWM]], i32 %scalable_size{{.*}})
 ; DISCARD_LEGAL-NOT: %{{.+}} = call @llvm.vp.{{.*}}, i32 %n)
 ; DISCARD_LEGAL: %r11 = call <vscale x 4 x i32> @llvm.vp.merge.nxv4i32(<vscale x 4 x i1> %{{.*}}, <vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, i32 %scalable_size{{.*}})
+; DISCARD_LEGAL: %r12 = call <vscale x 4 x i32> @llvm.vp.select.nxv4i32(<vscale x 4 x i1> %m, <vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, i32 %scalable_size{{.*}})
 ; DISCARD_LEGAL-NEXT: ret void
 
 ; DISCARD_LEGAL: define void @test_vp_reduce_int_v4(i32 %start, <4 x i32> %vi, <4 x i1> %m, i32 %n) {
@@ -514,6 +520,7 @@ define void @test_vp_cmp_v8(<8 x i32> %i0, <8 x i32> %i1, <8 x float> %f0, <8 x
 ; CONVERT_LEGAL-NOT: %{{.+}} = call <8 x i32> @llvm.vp.lshr.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 8)
 ; CONVERT_LEGAL-NOT: %{{.+}} = call <8 x i32> @llvm.vp.shl.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 8)
 ; CONVERT_LEGAL: %r11 = call <8 x i32> @llvm.vp.merge.v8i32(<8 x i1> %{{.*}}, <8 x i32> %i0, <8 x i32> %i1, i32 8)
+; CONVERT_LEGAL: %r12 = call <8 x i32> @llvm.vp.select.v8i32(<8 x i1> %{{.*}}, <8 x i32> %i0, <8 x i32> %i1, i32 8)
 ; CONVERT_LEGAL: ret void
 
 ; Similar to %evl discard, %mask legal but make sure the first VP intrinsic has a legal expansion
@@ -525,6 +532,7 @@ define void @test_vp_cmp_v8(<8 x i32> %i0, <8 x i32> %i1, <8 x float> %f0, <8 x
 ; CONVERT_LEGAL-NEXT: %r0 = call <vscale x 4 x i32> @llvm.vp.add.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> [[NEWM]], i32 %scalable_size)
 ; CONVERT_LEGAL-NOT: %{{.*}} = call @llvm.vp.{{.*}}, i32 %n)
 ; CONVERT_LEGAL: %r11 = call <vscale x 4 x i32> @llvm.vp.merge.nxv4i32(<vscale x 4 x i1> %{{.*}}, <vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, i32 %scalable_size{{.*}})
+; CONVERT_LEGAL: %r12 = call <vscale x 4 x i32> @llvm.vp.select.nxv4i32(<vscale x 4 x i1> %{{.*}}, <vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, i32 %scalable_size{{.*}})
 ; CONVERT_LEGAL: ret void
 
 ; CONVERT_LEGAL: define void @test_vp_reduce_int_v4(i32 %start, <4 x i32> %vi, <4 x i1> %m, i32 %n) {
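Note (not part of the patch): a minimal sketch of what the EVL-folding strategy produces for llvm.vp.select after this change, mirroring the fixed-width CONVERT_LEGAL checks above; the value names (%n.ins, %n.splat, %evl.mask, %new.mask) and the function name are hypothetical, only the shape of the rewrite is intended.

  declare <8 x i32> @llvm.vp.select.v8i32(<8 x i1>, <8 x i32>, <8 x i32>, i32)

  define <8 x i32> @sketch_vp_select_fold(<8 x i1> %m, <8 x i32> %i0, <8 x i32> %i1, i32 %n) {
    ; Turn the %evl operand into a per-lane mask: lane index < %n.
    %n.ins    = insertelement <8 x i32> poison, i32 %n, i64 0
    %n.splat  = shufflevector <8 x i32> %n.ins, <8 x i32> poison, <8 x i32> zeroinitializer
    %evl.mask = icmp ult <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>, %n.splat
    ; AND the EVL mask into the select condition (operand 0), then pass the full
    ; vector length (8) as the new %evl, as foldEVLIntoMask now does for vp.select.
    %new.mask = and <8 x i1> %evl.mask, %m
    %r = call <8 x i32> @llvm.vp.select.v8i32(<8 x i1> %new.mask, <8 x i32> %i0, <8 x i32> %i1, i32 8)
    ret <8 x i32> %r
  }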