Commit

[X86][AVX] Fix handling of out-of-bounds shift amounts in AVX2 vector logical shift nodes #83840 (#86922)

Resolve #83840
SahilPatidar committed Jul 12, 2024
1 parent 4a02b0b commit c0a5bf8
Showing 3 changed files with 500 additions and 2 deletions.
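
Before the diff itself, here is a minimal sketch of the kind of IR this change targets (the function name is illustrative and <4 x i32> is just one of the types supportedVectorVarShift accepts): an unsigned variable shift whose out-of-bounds lanes are explicitly clamped to zero. AVX2's VPSLLVD/VPSRLVD already produce zero for such lanes, so the select can be dropped and the whole pattern lowered to a single variable shift.

define <4 x i32> @shl_with_oob_clamp(<4 x i32> %x, <4 x i32> %amt) {
  ; Lanes with %amt >= 32 are forced to zero by the select, matching the
  ; behaviour of AVX2 vpsllvd, so this can lower to a single vpsllvd.
  %inbounds = icmp ult <4 x i32> %amt, <i32 32, i32 32, i32 32, i32 32>
  %shl = shl <4 x i32> %x, %amt
  %res = select <4 x i1> %inbounds, <4 x i32> %shl, <4 x i32> zeroinitializer
  ret <4 x i32> %res
}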
69 changes: 67 additions & 2 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -46181,6 +46181,32 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG,
}
}

// Exploits AVX2 VSHLV/VSRLV instructions for efficient unsigned vector shifts
// with out-of-bounds clamping.
//
// Unlike general shift instructions (SHL/SRL), AVX2's VSHLV/VSRLV handle
// shift amounts that exceed the element bitwidth: any lane whose shift amount
// is greater than or equal to the bitwidth produces zero rather than an
// undefined result. This matches the select patterns below, which clamp
// out-of-bounds lanes to zero, so the select can be folded into the shift.
if (N->getOpcode() == ISD::VSELECT &&
(LHS.getOpcode() == ISD::SRL || LHS.getOpcode() == ISD::SHL) &&
supportedVectorVarShift(VT, Subtarget, LHS.getOpcode())) {
APInt SV;
// fold select(icmp_ult(amt,BW),shl(x,amt),0) -> avx2 psllv(x,amt)
// fold select(icmp_ult(amt,BW),srl(x,amt),0) -> avx2 psrlv(x,amt)
if (Cond.getOpcode() == ISD::SETCC &&
Cond.getOperand(0) == LHS.getOperand(1) &&
cast<CondCodeSDNode>(Cond.getOperand(2))->get() == ISD::SETULT &&
ISD::isConstantSplatVector(Cond.getOperand(1).getNode(), SV) &&
ISD::isConstantSplatVectorAllZeros(RHS.getNode()) &&
SV == VT.getScalarSizeInBits()) {
return DAG.getNode(LHS.getOpcode() == ISD::SRL ? X86ISD::VSRLV
: X86ISD::VSHLV,
DL, VT, LHS.getOperand(0), LHS.getOperand(1));
}
}

// Early exit check
if (!TLI.isTypeLegal(VT) || isSoftF16(VT, Subtarget))
return SDValue();
@@ -47991,12 +48017,32 @@ static SDValue combineShiftToPMULH(SDNode *N, SelectionDAG &DAG,
return DAG.getNode(ExtOpc, DL, VT, Mulh);
}

static SDValue combineShiftLeft(SDNode *N, SelectionDAG &DAG) {
static SDValue combineShiftLeft(SDNode *N, SelectionDAG &DAG,
const X86Subtarget &Subtarget) {
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
EVT VT = N0.getValueType();

// Exploits AVX2 VSHLV/VSRLV instructions for efficient unsigned vector shifts
// with out-of-bounds clamping.
if (N0.getOpcode() == ISD::VSELECT &&
supportedVectorVarShift(VT, Subtarget, ISD::SHL)) {
SDValue Cond = N0.getOperand(0);
SDValue N00 = N0.getOperand(1);
SDValue N01 = N0.getOperand(2);
APInt SV;
// fold shl(select(icmp_ult(amt,BW),x,0),amt) -> avx2 psllv(x,amt)
if (Cond.getOpcode() == ISD::SETCC && Cond.getOperand(0) == N1 &&
cast<CondCodeSDNode>(Cond.getOperand(2))->get() == ISD::SETULT &&
ISD::isConstantSplatVector(Cond.getOperand(1).getNode(), SV) &&
ISD::isConstantSplatVectorAllZeros(N01.getNode()) &&
SV == VT.getScalarSizeInBits()) {
SDLoc DL(N);
return DAG.getNode(X86ISD::VSHLV, DL, VT, N00, N1);
}
}

// fold (shl (and (setcc_c), c1), c2) -> (and setcc_c, (c1 << c2))
// since the result of setcc_c is all zero's or all ones.
if (VT.isInteger() && !VT.isVector() &&
@@ -48115,6 +48161,25 @@ static SDValue combineShiftRightLogical(SDNode *N, SelectionDAG &DAG,
if (SDValue V = combineShiftToPMULH(N, DAG, Subtarget))
return V;

// Exploits AVX2 VSHLV/VSRLV instructions for efficient unsigned vector shifts
// with out-of-bounds clamping.
if (N0.getOpcode() == ISD::VSELECT &&
supportedVectorVarShift(VT, Subtarget, ISD::SRL)) {
SDValue Cond = N0.getOperand(0);
SDValue N00 = N0.getOperand(1);
SDValue N01 = N0.getOperand(2);
APInt SV;
// fold srl(select(icmp_ult(amt,BW),x,0),amt) -> avx2 psrlv(x,amt)
if (Cond.getOpcode() == ISD::SETCC && Cond.getOperand(0) == N1 &&
cast<CondCodeSDNode>(Cond.getOperand(2))->get() == ISD::SETULT &&
ISD::isConstantSplatVector(Cond.getOperand(1).getNode(), SV) &&
ISD::isConstantSplatVectorAllZeros(N01.getNode()) &&
SV == VT.getScalarSizeInBits()) {
SDLoc DL(N);
return DAG.getNode(X86ISD::VSRLV, DL, VT, N00, N1);
}
}

// Only do this on the last DAG combine as it can interfere with other
// combines.
if (!DCI.isAfterLegalizeDAG())
@@ -57613,7 +57678,7 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
case X86ISD::SBB: return combineSBB(N, DAG);
case X86ISD::ADC: return combineADC(N, DAG, DCI);
case ISD::MUL: return combineMul(N, DAG, DCI, Subtarget);
case ISD::SHL: return combineShiftLeft(N, DAG);
case ISD::SHL: return combineShiftLeft(N, DAG, Subtarget);
case ISD::SRA: return combineShiftRightArithmetic(N, DAG, Subtarget);
case ISD::SRL: return combineShiftRightLogical(N, DAG, DCI, Subtarget);
case ISD::AND: return combineAnd(N, DAG, DCI, Subtarget);
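
The combine-shl.ll tests below exercise the SHL side; the combineShiftRightLogical hunk above targets the analogous SRL pattern, sketched here with an illustrative function name (the commit's srl test coverage is not rendered on this page).

define <4 x i32> @srl_of_clamped_value(<4 x i32> %x, <4 x i32> %amt) {
  ; For in-bounds lanes this is a plain lshr; lanes with %amt >= 32 are
  ; poison in IR, so lowering the whole pattern to AVX2 vpsrlvd (which
  ; returns zero for such lanes) is a valid refinement.
  %inbounds = icmp ult <4 x i32> %amt, <i32 32, i32 32, i32 32, i32 32>
  %clamped = select <4 x i1> %inbounds, <4 x i32> %x, <4 x i32> zeroinitializer
  %res = lshr <4 x i32> %clamped, %amt
  ret <4 x i32> %res
}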
184 changes: 184 additions & 0 deletions llvm/test/CodeGen/X86/combine-shl.ll
@@ -929,3 +929,187 @@ define <4 x i32> @combine_vec_add_shuffle_shl(<4 x i32> %a0) {
%3 = add <4 x i32> %2, <i32 3, i32 3, i32 3, i32 3>
ret <4 x i32> %3
}

define <4 x i32> @combine_vec_shl_clamped1(<4 x i32> %sh, <4 x i32> %amt) {
; SSE2-LABEL: combine_vec_shl_clamped1:
; SSE2: # %bb.0:
; SSE2-NEXT: movdqa {{.*#+}} xmm2 = [2147483648,2147483648,2147483648,2147483648]
; SSE2-NEXT: pxor %xmm1, %xmm2
; SSE2-NEXT: pcmpgtd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm2
; SSE2-NEXT: pslld $23, %xmm1
; SSE2-NEXT: paddd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1
; SSE2-NEXT: cvttps2dq %xmm1, %xmm1
; SSE2-NEXT: pshufd {{.*#+}} xmm3 = xmm0[1,1,3,3]
; SSE2-NEXT: pmuludq %xmm1, %xmm0
; SSE2-NEXT: pshufd {{.*#+}} xmm0 = xmm0[0,2,2,3]
; SSE2-NEXT: pshufd {{.*#+}} xmm1 = xmm1[1,1,3,3]
; SSE2-NEXT: pmuludq %xmm3, %xmm1
; SSE2-NEXT: pshufd {{.*#+}} xmm1 = xmm1[0,2,2,3]
; SSE2-NEXT: punpckldq {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1]
; SSE2-NEXT: pandn %xmm0, %xmm2
; SSE2-NEXT: movdqa %xmm2, %xmm0
; SSE2-NEXT: retq
;
; SSE41-LABEL: combine_vec_shl_clamped1:
; SSE41: # %bb.0:
; SSE41-NEXT: pmovsxbd {{.*#+}} xmm2 = [31,31,31,31]
; SSE41-NEXT: pminud %xmm1, %xmm2
; SSE41-NEXT: pcmpeqd %xmm1, %xmm2
; SSE41-NEXT: pslld $23, %xmm1
; SSE41-NEXT: paddd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1
; SSE41-NEXT: cvttps2dq %xmm1, %xmm1
; SSE41-NEXT: pmulld %xmm1, %xmm0
; SSE41-NEXT: pand %xmm2, %xmm0
; SSE41-NEXT: retq
;
; AVX-LABEL: combine_vec_shl_clamped1:
; AVX: # %bb.0:
; AVX-NEXT: vpsllvd %xmm1, %xmm0, %xmm0
; AVX-NEXT: retq
%cmp.i = icmp ult <4 x i32> %amt, <i32 32, i32 32, i32 32, i32 32>
%shl = shl <4 x i32> %sh, %amt
%1 = select <4 x i1> %cmp.i, <4 x i32> %shl, <4 x i32> zeroinitializer
ret <4 x i32> %1
}

define <4 x i32> @combine_vec_shl_clamped2(<4 x i32> %sh, <4 x i32> %amt) {
; SSE2-LABEL: combine_vec_shl_clamped2:
; SSE2: # %bb.0:
; SSE2-NEXT: movdqa {{.*#+}} xmm2 = [2147483648,2147483648,2147483648,2147483648]
; SSE2-NEXT: pxor %xmm1, %xmm2
; SSE2-NEXT: pcmpgtd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm2
; SSE2-NEXT: pandn %xmm0, %xmm2
; SSE2-NEXT: pslld $23, %xmm1
; SSE2-NEXT: paddd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1
; SSE2-NEXT: cvttps2dq %xmm1, %xmm1
; SSE2-NEXT: pshufd {{.*#+}} xmm3 = xmm2[1,1,3,3]
; SSE2-NEXT: pmuludq %xmm1, %xmm2
; SSE2-NEXT: pshufd {{.*#+}} xmm0 = xmm2[0,2,2,3]
; SSE2-NEXT: pshufd {{.*#+}} xmm1 = xmm1[1,1,3,3]
; SSE2-NEXT: pmuludq %xmm3, %xmm1
; SSE2-NEXT: pshufd {{.*#+}} xmm1 = xmm1[0,2,2,3]
; SSE2-NEXT: punpckldq {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1]
; SSE2-NEXT: retq
;
; SSE41-LABEL: combine_vec_shl_clamped2:
; SSE41: # %bb.0:
; SSE41-NEXT: pmovsxbd {{.*#+}} xmm2 = [31,31,31,31]
; SSE41-NEXT: pminud %xmm1, %xmm2
; SSE41-NEXT: pcmpeqd %xmm1, %xmm2
; SSE41-NEXT: pand %xmm2, %xmm0
; SSE41-NEXT: pslld $23, %xmm1
; SSE41-NEXT: paddd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1
; SSE41-NEXT: cvttps2dq %xmm1, %xmm1
; SSE41-NEXT: pmulld %xmm1, %xmm0
; SSE41-NEXT: retq
;
; AVX-LABEL: combine_vec_shl_clamped2:
; AVX: # %bb.0:
; AVX-NEXT: vpsllvd %xmm1, %xmm0, %xmm0
; AVX-NEXT: retq
%cmp.i = icmp ult <4 x i32> %amt, <i32 32, i32 32, i32 32, i32 32>
%1 = select <4 x i1> %cmp.i, <4 x i32> %sh, <4 x i32> zeroinitializer
%shl = shl <4 x i32> %1, %amt
ret <4 x i32> %shl
}

define <4 x i32> @combine_vec_shl_commuted_clamped(<4 x i32> %sh, <4 x i32> %amt) {
; SSE2-LABEL: combine_vec_shl_commuted_clamped:
; SSE2: # %bb.0:
; SSE2-NEXT: movdqa {{.*#+}} xmm2 = [2147483648,2147483648,2147483648,2147483648]
; SSE2-NEXT: pxor %xmm1, %xmm2
; SSE2-NEXT: pcmpgtd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm2
; SSE2-NEXT: pandn %xmm0, %xmm2
; SSE2-NEXT: pslld $23, %xmm1
; SSE2-NEXT: paddd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1
; SSE2-NEXT: cvttps2dq %xmm1, %xmm1
; SSE2-NEXT: pshufd {{.*#+}} xmm3 = xmm2[1,1,3,3]
; SSE2-NEXT: pmuludq %xmm1, %xmm2
; SSE2-NEXT: pshufd {{.*#+}} xmm0 = xmm2[0,2,2,3]
; SSE2-NEXT: pshufd {{.*#+}} xmm1 = xmm1[1,1,3,3]
; SSE2-NEXT: pmuludq %xmm3, %xmm1
; SSE2-NEXT: pshufd {{.*#+}} xmm1 = xmm1[0,2,2,3]
; SSE2-NEXT: punpckldq {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1]
; SSE2-NEXT: retq
;
; SSE41-LABEL: combine_vec_shl_commuted_clamped:
; SSE41: # %bb.0:
; SSE41-NEXT: pmovsxbd {{.*#+}} xmm2 = [31,31,31,31]
; SSE41-NEXT: pminud %xmm1, %xmm2
; SSE41-NEXT: pcmpeqd %xmm1, %xmm2
; SSE41-NEXT: pand %xmm2, %xmm0
; SSE41-NEXT: pslld $23, %xmm1
; SSE41-NEXT: paddd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1
; SSE41-NEXT: cvttps2dq %xmm1, %xmm1
; SSE41-NEXT: pmulld %xmm1, %xmm0
; SSE41-NEXT: retq
;
; AVX2-LABEL: combine_vec_shl_commuted_clamped:
; AVX2: # %bb.0:
; AVX2-NEXT: vpbroadcastd {{.*#+}} xmm2 = [31,31,31,31]
; AVX2-NEXT: vpminud %xmm2, %xmm1, %xmm2
; AVX2-NEXT: vpcmpeqd %xmm2, %xmm1, %xmm2
; AVX2-NEXT: vpand %xmm0, %xmm2, %xmm0
; AVX2-NEXT: vpsllvd %xmm1, %xmm0, %xmm0
; AVX2-NEXT: retq
;
; AVX512-LABEL: combine_vec_shl_commuted_clamped:
; AVX512: # %bb.0:
; AVX512-NEXT: vpsllvd %xmm1, %xmm0, %xmm0
; AVX512-NEXT: retq
%cmp.i = icmp uge <4 x i32> %amt, <i32 32, i32 32, i32 32, i32 32>
%1 = select <4 x i1> %cmp.i, <4 x i32> zeroinitializer, <4 x i32> %sh
%shl = shl <4 x i32> %1, %amt
ret <4 x i32> %shl
}

define <4 x i32> @combine_vec_shl_commuted_clamped1(<4 x i32> %sh, <4 x i32> %amt) {
; SSE2-LABEL: combine_vec_shl_commuted_clamped1:
; SSE2: # %bb.0:
; SSE2-NEXT: movdqa %xmm1, %xmm2
; SSE2-NEXT: pslld $23, %xmm2
; SSE2-NEXT: paddd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm2
; SSE2-NEXT: cvttps2dq %xmm2, %xmm2
; SSE2-NEXT: pshufd {{.*#+}} xmm3 = xmm0[1,1,3,3]
; SSE2-NEXT: pmuludq %xmm2, %xmm0
; SSE2-NEXT: pshufd {{.*#+}} xmm0 = xmm0[0,2,2,3]
; SSE2-NEXT: pshufd {{.*#+}} xmm2 = xmm2[1,1,3,3]
; SSE2-NEXT: pmuludq %xmm3, %xmm2
; SSE2-NEXT: pshufd {{.*#+}} xmm2 = xmm2[0,2,2,3]
; SSE2-NEXT: punpckldq {{.*#+}} xmm0 = xmm0[0],xmm2[0],xmm0[1],xmm2[1]
; SSE2-NEXT: pxor {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1
; SSE2-NEXT: pcmpgtd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1
; SSE2-NEXT: pandn %xmm0, %xmm1
; SSE2-NEXT: movdqa %xmm1, %xmm0
; SSE2-NEXT: retq
;
; SSE41-LABEL: combine_vec_shl_commuted_clamped1:
; SSE41: # %bb.0:
; SSE41-NEXT: pmovsxbd {{.*#+}} xmm2 = [31,31,31,31]
; SSE41-NEXT: pminud %xmm1, %xmm2
; SSE41-NEXT: pcmpeqd %xmm1, %xmm2
; SSE41-NEXT: pslld $23, %xmm1
; SSE41-NEXT: paddd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1
; SSE41-NEXT: cvttps2dq %xmm1, %xmm1
; SSE41-NEXT: pmulld %xmm1, %xmm0
; SSE41-NEXT: pand %xmm2, %xmm0
; SSE41-NEXT: retq
;
; AVX2-LABEL: combine_vec_shl_commuted_clamped1:
; AVX2: # %bb.0:
; AVX2-NEXT: vpbroadcastd {{.*#+}} xmm2 = [31,31,31,31]
; AVX2-NEXT: vpsllvd %xmm1, %xmm0, %xmm0
; AVX2-NEXT: vpminud %xmm2, %xmm1, %xmm2
; AVX2-NEXT: vpcmpeqd %xmm2, %xmm1, %xmm1
; AVX2-NEXT: vpand %xmm0, %xmm1, %xmm0
; AVX2-NEXT: retq
;
; AVX512-LABEL: combine_vec_shl_commuted_clamped1:
; AVX512: # %bb.0:
; AVX512-NEXT: vpsllvd %xmm1, %xmm0, %xmm0
; AVX512-NEXT: retq
%cmp.i = icmp uge <4 x i32> %amt, <i32 32, i32 32, i32 32, i32 32>
%shl = shl <4 x i32> %sh, %amt
%1 = select <4 x i1> %cmp.i, <4 x i32> zeroinitializer, <4 x i32> %shl
ret <4 x i32> %1
}
