[DAGCombiner] add fold for vselect based on mask of signbit
(X s< 0) ? Y : 0 --> (X s>> BW-1) & Y

We canonicalize to the icmp+select form in IR, and we already have this fold
for scalar select in SDAG, so I think it's an oversight that we don't have
the fold for vectors. It seems neutral for AArch64 and saves some instructions
on x86.
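
As a sanity check on the underlying identity: an arithmetic shift right by BW-1 splats the sign bit, producing all-ones when X is negative and all-zeros otherwise, so the AND yields Y or 0. A minimal standalone check of the scalar case (illustrative only, not part of this patch; it relies on arithmetic right shift of negative signed values, which C++20 guarantees and mainstream compilers have long provided):

#include <cassert>
#include <cstdint>

int main() {
  // Spot-check (X s< 0) ? Y : 0  ==  (X s>> BW-1) & Y for 32-bit ints.
  const int32_t Samples[] = {INT32_MIN, -7, -1, 0, 1, 42, INT32_MAX};
  for (int32_t X : Samples)
    for (int32_t Y : Samples) {
      int32_t Sel = (X < 0) ? Y : 0;
      int32_t Mask = X >> 31; // sign-bit splat: -1 if X < 0, else 0
      assert(Sel == (Mask & Y));
    }
  return 0;
}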

Whether we should also have the sibling folds for the inverse condition or
all-ones true value may depend on target-specific factors such as whether
there's an "and-not" instruction.
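
For illustration, those sibling folds might look like the sketch below if added to the same helper. This is hypothetical code, not part of this commit; it reuses existing SDAG utilities (isAllOnesOrAllOnesSplat, DAG.getNOT), and the helper's early exit on a non-zero false operand would have to be relaxed for the all-ones case.

// Hypothetical sibling folds, in the style of foldVSelectToSignBitSplatMask.
SDLoc DL(N);
SDValue ShiftAmt = DAG.getConstant(VT.getScalarSizeInBits() - 1, DL, VT);

// (Cond0 s< 0) ? -1 : N2 --> (Cond0 s>> BW-1) | N2
if (CC == ISD::SETLT && isNullOrNullSplat(Cond1) &&
    isAllOnesOrAllOnesSplat(N1)) {
  SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, Cond0, ShiftAmt);
  return DAG.getNode(ISD::OR, DL, VT, Sra, N2);
}

// (Cond0 s< 0) ? 0 : N2 --> ~(Cond0 s>> BW-1) & N2 (an "and-not")
if (CC == ISD::SETLT && isNullOrNullSplat(Cond1) && isNullOrNullSplat(N1)) {
  SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, Cond0, ShiftAmt);
  SDValue Not = DAG.getNOT(DL, Sra, VT);
  return DAG.getNode(ISD::AND, DL, VT, Not, N2);
}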

Differential Revision: https://reviews.llvm.org/D113212
rotateright committed Nov 5, 2021
1 parent 1e7afa2 commit 4fc1fc4
Showing 6 changed files with 121 additions and 165 deletions.
29 changes: 29 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -9544,6 +9544,31 @@ static SDValue foldBoolSelectToLogic(SDNode *N, SelectionDAG &DAG) {
   return SDValue();
 }
 
+static SDValue foldVSelectToSignBitSplatMask(SDNode *N, SelectionDAG &DAG) {
+  SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
+  SDValue N2 = N->getOperand(2);
+  EVT VT = N->getValueType(0);
+  if (N0.getOpcode() != ISD::SETCC || !N0.hasOneUse() || !isNullOrNullSplat(N2))
+    return SDValue();
+
+  SDValue Cond0 = N0.getOperand(0);
+  SDValue Cond1 = N0.getOperand(1);
+  ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();
+  if (VT != Cond0.getValueType())
+    return SDValue();
+
+  // (Cond0 s< 0) ? N1 : 0 --> (Cond0 s>> BW-1) & N1
+  if (CC == ISD::SETLT && isNullOrNullSplat(Cond1)) {
+    SDLoc DL(N);
+    SDValue ShiftAmt = DAG.getConstant(VT.getScalarSizeInBits() - 1, DL, VT);
+    SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, Cond0, ShiftAmt);
+    return DAG.getNode(ISD::AND, DL, VT, Sra, N1);
+  }
+
+  return SDValue();
+}
+
 SDValue DAGCombiner::visitSELECT(SDNode *N) {
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
@@ -10234,6 +10259,10 @@ SDValue DAGCombiner::visitVSELECT(SDNode *N) {
   if (SDValue V = foldVSelectOfConstants(N))
     return V;
 
+  if (hasOperation(ISD::SRA, VT))
+    if (SDValue V = foldVSelectToSignBitSplatMask(N, DAG))
+      return V;
+
   return SDValue();
 }
16 changes: 8 additions & 8 deletions llvm/test/CodeGen/AArch64/vselect-constants.ll
@@ -196,8 +196,8 @@ define <4 x i32> @cmp_sel_0_or_1_vec(<4 x i32> %x, <4 x i32> %y) {
 define <16 x i8> @signbit_mask_v16i8(<16 x i8> %a, <16 x i8> %b) {
 ; CHECK-LABEL: signbit_mask_v16i8:
 ; CHECK: // %bb.0:
-; CHECK-NEXT: cmlt v0.16b, v0.16b, #0
-; CHECK-NEXT: and v0.16b, v1.16b, v0.16b
+; CHECK-NEXT: sshr v0.16b, v0.16b, #7
+; CHECK-NEXT: and v0.16b, v0.16b, v1.16b
 ; CHECK-NEXT: ret
 %cond = icmp slt <16 x i8> %a, zeroinitializer
 %r = select <16 x i1> %cond, <16 x i8> %b, <16 x i8> zeroinitializer
@@ -207,8 +207,8 @@ define <16 x i8> @signbit_mask_v16i8(<16 x i8> %a, <16 x i8> %b) {
 define <8 x i16> @signbit_mask_v8i16(<8 x i16> %a, <8 x i16> %b) {
 ; CHECK-LABEL: signbit_mask_v8i16:
 ; CHECK: // %bb.0:
-; CHECK-NEXT: cmlt v0.8h, v0.8h, #0
-; CHECK-NEXT: and v0.16b, v1.16b, v0.16b
+; CHECK-NEXT: sshr v0.8h, v0.8h, #15
+; CHECK-NEXT: and v0.16b, v0.16b, v1.16b
 ; CHECK-NEXT: ret
 %cond = icmp slt <8 x i16> %a, zeroinitializer
 %r = select <8 x i1> %cond, <8 x i16> %b, <8 x i16> zeroinitializer
@@ -218,8 +218,8 @@ define <8 x i16> @signbit_mask_v8i16(<8 x i16> %a, <8 x i16> %b) {
 define <4 x i32> @signbit_mask_v4i32(<4 x i32> %a, <4 x i32> %b) {
 ; CHECK-LABEL: signbit_mask_v4i32:
 ; CHECK: // %bb.0:
-; CHECK-NEXT: cmlt v0.4s, v0.4s, #0
-; CHECK-NEXT: and v0.16b, v1.16b, v0.16b
+; CHECK-NEXT: sshr v0.4s, v0.4s, #31
+; CHECK-NEXT: and v0.16b, v0.16b, v1.16b
 ; CHECK-NEXT: ret
 %cond = icmp slt <4 x i32> %a, zeroinitializer
 %r = select <4 x i1> %cond, <4 x i32> %b, <4 x i32> zeroinitializer
@@ -229,8 +229,8 @@ define <4 x i32> @signbit_mask_v4i32(<4 x i32> %a, <4 x i32> %b) {
 define <2 x i64> @signbit_mask_v2i64(<2 x i64> %a, <2 x i64> %b) {
 ; CHECK-LABEL: signbit_mask_v2i64:
 ; CHECK: // %bb.0:
-; CHECK-NEXT: cmlt v0.2d, v0.2d, #0
-; CHECK-NEXT: and v0.16b, v1.16b, v0.16b
+; CHECK-NEXT: sshr v0.2d, v0.2d, #63
+; CHECK-NEXT: and v0.16b, v0.16b, v1.16b
 ; CHECK-NEXT: ret
 %cond = icmp slt <2 x i64> %a, zeroinitializer
 %r = select <2 x i1> %cond, <2 x i64> %b, <2 x i64> zeroinitializer
15 changes: 6 additions & 9 deletions llvm/test/CodeGen/Thumb2/mve-vselect-constants.ll
@@ -137,9 +137,8 @@ define arm_aapcs_vfpcc <4 x i32> @cmp_sel_0_or_1_vec(<4 x i32> %x, <4 x i32> %y)
 define arm_aapcs_vfpcc <16 x i8> @signbit_mask_v16i8(<16 x i8> %a, <16 x i8> %b) {
 ; CHECK-LABEL: signbit_mask_v16i8:
 ; CHECK: @ %bb.0:
-; CHECK-NEXT: vmov.i32 q2, #0x0
-; CHECK-NEXT: vcmp.s8 lt, q0, zr
-; CHECK-NEXT: vpsel q0, q1, q2
+; CHECK-NEXT: vshr.s8 q0, q0, #7
+; CHECK-NEXT: vand q0, q0, q1
 ; CHECK-NEXT: bx lr
 %cond = icmp slt <16 x i8> %a, zeroinitializer
 %r = select <16 x i1> %cond, <16 x i8> %b, <16 x i8> zeroinitializer
@@ -149,9 +148,8 @@ define arm_aapcs_vfpcc <16 x i8> @signbit_mask_v16i8(<16 x i8> %a, <16 x i8> %b)
 define arm_aapcs_vfpcc <8 x i16> @signbit_mask_v8i16(<8 x i16> %a, <8 x i16> %b) {
 ; CHECK-LABEL: signbit_mask_v8i16:
 ; CHECK: @ %bb.0:
-; CHECK-NEXT: vmov.i32 q2, #0x0
-; CHECK-NEXT: vcmp.s16 lt, q0, zr
-; CHECK-NEXT: vpsel q0, q1, q2
+; CHECK-NEXT: vshr.s16 q0, q0, #15
+; CHECK-NEXT: vand q0, q0, q1
 ; CHECK-NEXT: bx lr
 %cond = icmp slt <8 x i16> %a, zeroinitializer
 %r = select <8 x i1> %cond, <8 x i16> %b, <8 x i16> zeroinitializer
@@ -161,9 +159,8 @@ define arm_aapcs_vfpcc <8 x i16> @signbit_mask_v8i16(<8 x i16> %a, <8 x i16> %b)
 define arm_aapcs_vfpcc <4 x i32> @signbit_mask_v4i32(<4 x i32> %a, <4 x i32> %b) {
 ; CHECK-LABEL: signbit_mask_v4i32:
 ; CHECK: @ %bb.0:
-; CHECK-NEXT: vmov.i32 q2, #0x0
-; CHECK-NEXT: vcmp.s32 lt, q0, zr
-; CHECK-NEXT: vpsel q0, q1, q2
+; CHECK-NEXT: vshr.s32 q0, q0, #31
+; CHECK-NEXT: vand q0, q0, q1
 ; CHECK-NEXT: bx lr
 %cond = icmp slt <4 x i32> %a, zeroinitializer
 %r = select <4 x i1> %cond, <4 x i32> %b, <4 x i32> zeroinitializer
40 changes: 12 additions & 28 deletions llvm/test/CodeGen/X86/avx512-logic.ll
@@ -907,20 +907,12 @@ define <8 x i64> @ternlog_xor_and_mask(<8 x i64> %x, <8 x i64> %y) {
 }
 
 define <16 x i32> @ternlog_maskz_or_and_mask(<16 x i32> %x, <16 x i32> %y, <16 x i32> %mask) {
-; KNL-LABEL: ternlog_maskz_or_and_mask:
-; KNL: ## %bb.0:
-; KNL-NEXT: vpxor %xmm3, %xmm3, %xmm3
-; KNL-NEXT: vpcmpgtd %zmm2, %zmm3, %k1
-; KNL-NEXT: vpandd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm0, %zmm0
-; KNL-NEXT: vpord %zmm1, %zmm0, %zmm0 {%k1} {z}
-; KNL-NEXT: retq
-;
-; SKX-LABEL: ternlog_maskz_or_and_mask:
-; SKX: ## %bb.0:
-; SKX-NEXT: vpmovd2m %zmm2, %k1
-; SKX-NEXT: vandps {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm0, %zmm0
-; SKX-NEXT: vorps %zmm1, %zmm0, %zmm0 {%k1} {z}
-; SKX-NEXT: retq
+; ALL-LABEL: ternlog_maskz_or_and_mask:
+; ALL: ## %bb.0:
+; ALL-NEXT: vpandd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm0, %zmm3
+; ALL-NEXT: vpsrad $31, %zmm2, %zmm0
+; ALL-NEXT: vpternlogd $224, %zmm1, %zmm3, %zmm0
+; ALL-NEXT: retq
 %m = icmp slt <16 x i32> %mask, zeroinitializer
 %a = and <16 x i32> %x, <i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255>
 %b = or <16 x i32> %a, %y
@@ -929,20 +921,12 @@ define <16 x i32> @ternlog_maskz_or_and_mask(<16 x i32> %x, <16 x i32> %y, <16 x
 }
 
 define <8 x i64> @ternlog_maskz_xor_and_mask(<8 x i64> %x, <8 x i64> %y, <8 x i64> %mask) {
-; KNL-LABEL: ternlog_maskz_xor_and_mask:
-; KNL: ## %bb.0:
-; KNL-NEXT: vpxor %xmm3, %xmm3, %xmm3
-; KNL-NEXT: vpcmpgtq %zmm2, %zmm3, %k1
-; KNL-NEXT: vpandq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm0, %zmm0
-; KNL-NEXT: vpxorq %zmm1, %zmm0, %zmm0 {%k1} {z}
-; KNL-NEXT: retq
-;
-; SKX-LABEL: ternlog_maskz_xor_and_mask:
-; SKX: ## %bb.0:
-; SKX-NEXT: vpmovq2m %zmm2, %k1
-; SKX-NEXT: vandpd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm0, %zmm0
-; SKX-NEXT: vxorpd %zmm1, %zmm0, %zmm0 {%k1} {z}
-; SKX-NEXT: retq
+; ALL-LABEL: ternlog_maskz_xor_and_mask:
+; ALL: ## %bb.0:
+; ALL-NEXT: vpandq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm0, %zmm3
+; ALL-NEXT: vpsraq $63, %zmm2, %zmm0
+; ALL-NEXT: vpternlogq $96, %zmm1, %zmm3, %zmm0
+; ALL-NEXT: retq
 %m = icmp slt <8 x i64> %mask, zeroinitializer
 %a = and <8 x i64> %x, <i64 4294967295, i64 4294967295, i64 4294967295, i64 4294967295, i64 4294967295, i64 4294967295, i64 4294967295, i64 4294967295>
 %b = xor <8 x i64> %a, %y
80 changes: 24 additions & 56 deletions llvm/test/CodeGen/X86/avx512vl-logic.ll
@@ -1077,20 +1077,12 @@ define <4 x i64> @ternlog_xor_and_mask_ymm(<4 x i64> %x, <4 x i64> %y) {
 }
 
 define <4 x i32> @ternlog_maskz_or_and_mask(<4 x i32> %x, <4 x i32> %y, <4 x i32> %z, <4 x i32> %mask) {
-; KNL-LABEL: ternlog_maskz_or_and_mask:
-; KNL: ## %bb.0:
-; KNL-NEXT: vpxor %xmm2, %xmm2, %xmm2
-; KNL-NEXT: vpcmpgtd %xmm3, %xmm2, %k1
-; KNL-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
-; KNL-NEXT: vpord %xmm1, %xmm0, %xmm0 {%k1} {z}
-; KNL-NEXT: retq
-;
-; SKX-LABEL: ternlog_maskz_or_and_mask:
-; SKX: ## %bb.0:
-; SKX-NEXT: vpmovd2m %xmm3, %k1
-; SKX-NEXT: vandps {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
-; SKX-NEXT: vorps %xmm1, %xmm0, %xmm0 {%k1} {z}
-; SKX-NEXT: retq
+; CHECK-LABEL: ternlog_maskz_or_and_mask:
+; CHECK: ## %bb.0:
+; CHECK-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm2
+; CHECK-NEXT: vpsrad $31, %xmm3, %xmm0
+; CHECK-NEXT: vpternlogd $224, %xmm1, %xmm2, %xmm0
+; CHECK-NEXT: retq
 %m = icmp slt <4 x i32> %mask, zeroinitializer
 %a = and <4 x i32> %x, <i32 255, i32 255, i32 255, i32 255>
 %b = or <4 x i32> %a, %y
@@ -1099,20 +1091,12 @@ define <4 x i32> @ternlog_maskz_or_and_mask(<4 x i32> %x, <4 x i32> %y, <4 x i32
 }
 
 define <8 x i32> @ternlog_maskz_or_and_mask_ymm(<8 x i32> %x, <8 x i32> %y, <8 x i32> %mask) {
-; KNL-LABEL: ternlog_maskz_or_and_mask_ymm:
-; KNL: ## %bb.0:
-; KNL-NEXT: vpxor %xmm3, %xmm3, %xmm3
-; KNL-NEXT: vpcmpgtd %ymm2, %ymm3, %k1
-; KNL-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm0, %ymm0
-; KNL-NEXT: vpord %ymm1, %ymm0, %ymm0 {%k1} {z}
-; KNL-NEXT: retq
-;
-; SKX-LABEL: ternlog_maskz_or_and_mask_ymm:
-; SKX: ## %bb.0:
-; SKX-NEXT: vpmovd2m %ymm2, %k1
-; SKX-NEXT: vandps {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm0, %ymm0
-; SKX-NEXT: vorps %ymm1, %ymm0, %ymm0 {%k1} {z}
-; SKX-NEXT: retq
+; CHECK-LABEL: ternlog_maskz_or_and_mask_ymm:
+; CHECK: ## %bb.0:
+; CHECK-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm0, %ymm3
+; CHECK-NEXT: vpsrad $31, %ymm2, %ymm0
+; CHECK-NEXT: vpternlogd $224, %ymm1, %ymm3, %ymm0
+; CHECK-NEXT: retq
 %m = icmp slt <8 x i32> %mask, zeroinitializer
 %a = and <8 x i32> %x, <i32 -16777216, i32 -16777216, i32 -16777216, i32 -16777216, i32 -16777216, i32 -16777216, i32 -16777216, i32 -16777216>
 %b = or <8 x i32> %a, %y
@@ -1121,20 +1105,12 @@ define <8 x i32> @ternlog_maskz_or_and_mask_ymm(<8 x i32> %x, <8 x i32> %y, <8 x
 }
 
 define <2 x i64> @ternlog_maskz_xor_and_mask(<2 x i64> %x, <2 x i64> %y, <2 x i64> %mask) {
-; KNL-LABEL: ternlog_maskz_xor_and_mask:
-; KNL: ## %bb.0:
-; KNL-NEXT: vpxor %xmm3, %xmm3, %xmm3
-; KNL-NEXT: vpcmpgtq %xmm2, %xmm3, %k1
-; KNL-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
-; KNL-NEXT: vpxorq %xmm1, %xmm0, %xmm0 {%k1} {z}
-; KNL-NEXT: retq
-;
-; SKX-LABEL: ternlog_maskz_xor_and_mask:
-; SKX: ## %bb.0:
-; SKX-NEXT: vpmovq2m %xmm2, %k1
-; SKX-NEXT: vandpd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
-; SKX-NEXT: vxorpd %xmm1, %xmm0, %xmm0 {%k1} {z}
-; SKX-NEXT: retq
+; CHECK-LABEL: ternlog_maskz_xor_and_mask:
+; CHECK: ## %bb.0:
+; CHECK-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm3
+; CHECK-NEXT: vpsraq $63, %xmm2, %xmm0
+; CHECK-NEXT: vpternlogq $96, %xmm1, %xmm3, %xmm0
+; CHECK-NEXT: retq
 %m = icmp slt <2 x i64> %mask, zeroinitializer
 %a = and <2 x i64> %x, <i64 1099511627775, i64 1099511627775>
 %b = xor <2 x i64> %a, %y
@@ -1143,20 +1119,12 @@ define <2 x i64> @ternlog_maskz_xor_and_mask(<2 x i64> %x, <2 x i64> %y, <2 x i6
 }
 
 define <4 x i64> @ternlog_maskz_xor_and_mask_ymm(<4 x i64> %x, <4 x i64> %y, <4 x i64> %mask) {
-; KNL-LABEL: ternlog_maskz_xor_and_mask_ymm:
-; KNL: ## %bb.0:
-; KNL-NEXT: vpxor %xmm3, %xmm3, %xmm3
-; KNL-NEXT: vpcmpgtq %ymm2, %ymm3, %k1
-; KNL-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm0, %ymm0
-; KNL-NEXT: vpxorq %ymm1, %ymm0, %ymm0 {%k1} {z}
-; KNL-NEXT: retq
-;
-; SKX-LABEL: ternlog_maskz_xor_and_mask_ymm:
-; SKX: ## %bb.0:
-; SKX-NEXT: vpmovq2m %ymm2, %k1
-; SKX-NEXT: vandpd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm0, %ymm0
-; SKX-NEXT: vxorpd %ymm1, %ymm0, %ymm0 {%k1} {z}
-; SKX-NEXT: retq
+; CHECK-LABEL: ternlog_maskz_xor_and_mask_ymm:
+; CHECK: ## %bb.0:
+; CHECK-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm0, %ymm3
+; CHECK-NEXT: vpsraq $63, %ymm2, %ymm0
+; CHECK-NEXT: vpternlogq $96, %ymm1, %ymm3, %ymm0
+; CHECK-NEXT: retq
 %m = icmp slt <4 x i64> %mask, zeroinitializer
 %a = and <4 x i64> %x, <i64 72057594037927935, i64 72057594037927935, i64 72057594037927935, i64 72057594037927935>
 %b = xor <4 x i64> %a, %y
