[X86][SSE] Vectorized i8 and i16 shift operators
This patch ensures that SHL/SRL/SRA shifts for i8 and i16 vectors avoid scalarization. It builds on the existing vectorized i8 SHL implementation - which moves the shift-amount bits up to the sign-bit position and performs the 4-, 2- and 1-bit shifts as separate select steps - with several improvements:

1 - SSE41 targets can use (v)pblendvb directly with the sign bit, instead of performing a comparison to feed into a VSELECT node.
2 - pre-SSE41 targets were masking + comparing against a 0x80 constant - we avoid this by using the fact that a set sign bit means a negative integer, which can be compared against zero and then fed into VSELECT, avoiding the need for a constant mask (zero generation is much cheaper). A scalar model of the resulting select loop is sketched after this list.
3 - SRA i8 needs each byte unpacked to the upper byte of an i16 so that the i16 psraw instruction can be used for sign extension - we have to do more work than for SHL/SRL, but perf tests indicate that this is still beneficial (a scalar model of the unpack trick follows the testing note below).
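
For illustration, a minimal scalar model of one i8 lane of the SHL/SRL select loop - the helper names here are hypothetical and not part of the patch; blendOnSignBit stands in for (v)pblendvb on SSE41, or the PCMPGT-against-zero + VSELECT pattern on older targets, and the vector lowering runs this logic across all lanes in parallel:

#include <cstdint>

// Select 'b' when the sign bit of 'mask' is set, else keep 'a' - the scalar
// analogue of PBLENDVB / PCMPGT(zero, mask) feeding a VSELECT.
static uint8_t blendOnSignBit(uint8_t mask, uint8_t a, uint8_t b) {
  return (mask & 0x80) ? b : a;
}

// Computes r << amt for amt in [0, 7], one lane of the vectorized lowering.
uint8_t shl8(uint8_t r, uint8_t amt) {
  uint8_t m = (uint8_t)(amt << 5);              // bit 2 of amt at the sign bit
  r = blendOnSignBit(m, r, (uint8_t)(r << 4));  // conditionally shift by 4
  m = (uint8_t)(m + m);                         // a += a: bit 1 at the sign bit
  r = blendOnSignBit(m, r, (uint8_t)(r << 2));  // conditionally shift by 2
  m = (uint8_t)(m + m);                         // bit 0 at the sign bit
  r = blendOnSignBit(m, r, (uint8_t)(r << 1));  // conditionally shift by 1
  return r;
}

Each m + m step walks the next shift-amount bit into the sign-bit position, so three selects cover every amount from 0 to 7 (4*b2 + 2*b1 + b0).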

The i16 implementation is similar to, but simpler than, the i8 case - we need 8-, 4-, 2- and 1-bit shift steps, but less shift masking is involved. Note that SSE41's use of (v)pblendvb requires the i16 shift amount to be splatted to both bytes of each element.
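
The same idea for a single i16 lane, again as a sketch with hypothetical helper names; the comment notes how the real SSE41 path has to build its mask so that both bytes of each lane carry the same select bit:

#include <cstdint>

static uint16_t blendOnSignBit16(uint16_t mask, uint16_t a, uint16_t b) {
  return (mask & 0x8000) ? b : a;  // sign-splat or PBLENDVB analogue
}

// Computes r << amt for amt in [0, 15]: four select steps of 8/4/2/1 bits.
uint16_t shl16(uint16_t r, uint16_t amt) {
  uint16_t m = (uint16_t)(amt << 12);  // bit 3 of amt at the sign bit
  // The SSE41 lowering instead uses (amt << 4) | (amt << 12) so the low
  // byte's sign bit mirrors the high byte's, as PBLENDVB selects per byte.
  r = blendOnSignBit16(m, r, (uint16_t)(r << 8));
  m = (uint16_t)(m + m);
  r = blendOnSignBit16(m, r, (uint16_t)(r << 4));
  m = (uint16_t)(m + m);
  r = blendOnSignBit16(m, r, (uint16_t)(r << 2));
  m = (uint16_t)(m + m);
  r = blendOnSignBit16(m, r, (uint16_t)(r << 1));
  return r;
}

With four steps the total shift is 8*b3 + 4*b2 + 2*b1 + b0, i.e. any amount from 0 to 15.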

Tested on SSE2, SSE41 and AVX machines.
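
And a scalar model of the SRA unpack trick from point 3 above, with hypothetical names: the byte is placed in the high half of an i16 so that psraw performs a true sign extension, then shifted back down to leave a zero upper byte for PACKUSWB. This assumes right-shifting a negative value is arithmetic, as it is on mainstream compilers.

#include <cstdint>

// One lane of the SRA i8 lowering: unpack to the high byte of an i16,
// arithmetic shift, then logical shift back down and pack.
uint8_t ashr8(uint8_t r, unsigned amt) {
  int16_t wide = (int16_t)((uint16_t)r << 8);  // punpck*bw with r on top
  wide = (int16_t)(wide >> amt);               // psraw: correct sign bits
  return (uint8_t)((uint16_t)wide >> 8);       // psrlw $8, then packuswb
}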

Differential Revision: http://reviews.llvm.org/D9474

llvm-svn: 239509
RKSimon committed Jun 11, 2015
1 parent 2e8ffa3 commit 5965680
Showing 8 changed files with 706 additions and 1,381 deletions.
192 changes: 164 additions & 28 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -17012,36 +17012,111 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget* Subtarget,
     }
   }
 
-  if (VT == MVT::v16i8 && Op->getOpcode() == ISD::SHL) {
-    // Turn 'a' into a mask suitable for VSELECT: a = a << 5;
-    Op = DAG.getNode(ISD::SHL, dl, VT, Amt, DAG.getConstant(5, dl, VT));
-
-    SDValue VSelM = DAG.getConstant(0x80, dl, VT);
-    SDValue OpVSel = DAG.getNode(ISD::AND, dl, VT, VSelM, Op);
-    OpVSel = DAG.getNode(X86ISD::PCMPEQ, dl, VT, OpVSel, VSelM);
-
-    // r = VSELECT(r, shl(r, 4), a);
-    SDValue M = DAG.getNode(ISD::SHL, dl, VT, R, DAG.getConstant(4, dl, VT));
-    R = DAG.getNode(ISD::VSELECT, dl, VT, OpVSel, M, R);
-
-    // a += a
-    Op = DAG.getNode(ISD::ADD, dl, VT, Op, Op);
-    OpVSel = DAG.getNode(ISD::AND, dl, VT, VSelM, Op);
-    OpVSel = DAG.getNode(X86ISD::PCMPEQ, dl, VT, OpVSel, VSelM);
-
-    // r = VSELECT(r, shl(r, 2), a);
-    M = DAG.getNode(ISD::SHL, dl, VT, R, DAG.getConstant(2, dl, VT));
-    R = DAG.getNode(ISD::VSELECT, dl, VT, OpVSel, M, R);
+  if (VT == MVT::v16i8 || (VT == MVT::v32i8 && Subtarget->hasInt256())) {
+    MVT ExtVT = MVT::getVectorVT(MVT::i16, VT.getVectorNumElements() / 2);
+    unsigned ShiftOpcode = Op->getOpcode();
+
+    auto SignBitSelect = [&](MVT SelVT, SDValue Sel, SDValue V0, SDValue V1) {
+      // On SSE41 targets we make use of the fact that VSELECT lowers
+      // to PBLENDVB which selects bytes based just on the sign bit.
+      if (Subtarget->hasSSE41()) {
+        V0 = DAG.getBitcast(VT, V0);
+        V1 = DAG.getBitcast(VT, V1);
+        Sel = DAG.getBitcast(VT, Sel);
+        return DAG.getBitcast(SelVT,
+                              DAG.getNode(ISD::VSELECT, dl, VT, Sel, V0, V1));
+      }
+      // On pre-SSE41 targets we test for the sign bit by comparing to
+      // zero - a negative value will set all bits of the lanes to true
+      // and VSELECT uses that in its OR(AND(V0,C),AND(V1,~C)) lowering.
+      SDValue Z = getZeroVector(SelVT, Subtarget, DAG, dl);
+      SDValue C = DAG.getNode(X86ISD::PCMPGT, dl, SelVT, Z, Sel);
+      return DAG.getNode(ISD::VSELECT, dl, SelVT, C, V0, V1);
+    };
 
-    // a += a
-    Op = DAG.getNode(ISD::ADD, dl, VT, Op, Op);
-    OpVSel = DAG.getNode(ISD::AND, dl, VT, VSelM, Op);
-    OpVSel = DAG.getNode(X86ISD::PCMPEQ, dl, VT, OpVSel, VSelM);
+    // Turn 'a' into a mask suitable for VSELECT: a = a << 5;
+    // We can safely do this using i16 shifts as we're only interested in
+    // the 3 lower bits of each byte.
+    Amt = DAG.getBitcast(ExtVT, Amt);
+    Amt = DAG.getNode(ISD::SHL, dl, ExtVT, Amt, DAG.getConstant(5, dl, ExtVT));
+    Amt = DAG.getBitcast(VT, Amt);
+
+    if (Op->getOpcode() == ISD::SHL || Op->getOpcode() == ISD::SRL) {
+      // r = VSELECT(r, shift(r, 4), a);
+      SDValue M =
+          DAG.getNode(ShiftOpcode, dl, VT, R, DAG.getConstant(4, dl, VT));
+      R = SignBitSelect(VT, Amt, M, R);
+
+      // a += a
+      Amt = DAG.getNode(ISD::ADD, dl, VT, Amt, Amt);
+
+      // r = VSELECT(r, shift(r, 2), a);
+      M = DAG.getNode(ShiftOpcode, dl, VT, R, DAG.getConstant(2, dl, VT));
+      R = SignBitSelect(VT, Amt, M, R);
+
+      // a += a
+      Amt = DAG.getNode(ISD::ADD, dl, VT, Amt, Amt);
+
+      // return VSELECT(r, shift(r, 1), a);
+      M = DAG.getNode(ShiftOpcode, dl, VT, R, DAG.getConstant(1, dl, VT));
+      R = SignBitSelect(VT, Amt, M, R);
+      return R;
+    }
 
-    // return VSELECT(r, r+r, a);
-    R = DAG.getNode(ISD::VSELECT, dl, VT, OpVSel,
-                    DAG.getNode(ISD::ADD, dl, VT, R, R), R);
-    return R;
+    if (Op->getOpcode() == ISD::SRA) {
+      // For SRA we need to unpack each byte to the higher byte of an i16
+      // vector so we can correctly sign extend. We don't care what happens
+      // to the lower byte.
+      SDValue ALo = DAG.getNode(X86ISD::UNPCKL, dl, VT, DAG.getUNDEF(VT), Amt);
+      SDValue AHi = DAG.getNode(X86ISD::UNPCKH, dl, VT, DAG.getUNDEF(VT), Amt);
+      SDValue RLo = DAG.getNode(X86ISD::UNPCKL, dl, VT, DAG.getUNDEF(VT), R);
+      SDValue RHi = DAG.getNode(X86ISD::UNPCKH, dl, VT, DAG.getUNDEF(VT), R);
+      ALo = DAG.getBitcast(ExtVT, ALo);
+      AHi = DAG.getBitcast(ExtVT, AHi);
+      RLo = DAG.getBitcast(ExtVT, RLo);
+      RHi = DAG.getBitcast(ExtVT, RHi);
+
+      // r = VSELECT(r, shift(r, 4), a);
+      SDValue MLo = DAG.getNode(ShiftOpcode, dl, ExtVT, RLo,
+                                DAG.getConstant(4, dl, ExtVT));
+      SDValue MHi = DAG.getNode(ShiftOpcode, dl, ExtVT, RHi,
+                                DAG.getConstant(4, dl, ExtVT));
+      RLo = SignBitSelect(ExtVT, ALo, MLo, RLo);
+      RHi = SignBitSelect(ExtVT, AHi, MHi, RHi);
+
+      // a += a
+      ALo = DAG.getNode(ISD::ADD, dl, ExtVT, ALo, ALo);
+      AHi = DAG.getNode(ISD::ADD, dl, ExtVT, AHi, AHi);
+
+      // r = VSELECT(r, shift(r, 2), a);
+      MLo = DAG.getNode(ShiftOpcode, dl, ExtVT, RLo,
+                        DAG.getConstant(2, dl, ExtVT));
+      MHi = DAG.getNode(ShiftOpcode, dl, ExtVT, RHi,
+                        DAG.getConstant(2, dl, ExtVT));
+      RLo = SignBitSelect(ExtVT, ALo, MLo, RLo);
+      RHi = SignBitSelect(ExtVT, AHi, MHi, RHi);
+
+      // a += a
+      ALo = DAG.getNode(ISD::ADD, dl, ExtVT, ALo, ALo);
+      AHi = DAG.getNode(ISD::ADD, dl, ExtVT, AHi, AHi);
+
+      // r = VSELECT(r, shift(r, 1), a);
+      MLo = DAG.getNode(ShiftOpcode, dl, ExtVT, RLo,
+                        DAG.getConstant(1, dl, ExtVT));
+      MHi = DAG.getNode(ShiftOpcode, dl, ExtVT, RHi,
+                        DAG.getConstant(1, dl, ExtVT));
+      RLo = SignBitSelect(ExtVT, ALo, MLo, RLo);
+      RHi = SignBitSelect(ExtVT, AHi, MHi, RHi);
+
+      // Logical shift the result back to the lower byte, leaving a zero
+      // upper byte, meaning that we can safely pack with PACKUSWB.
+      RLo =
+          DAG.getNode(ISD::SRL, dl, ExtVT, RLo, DAG.getConstant(8, dl, ExtVT));
+      RHi =
+          DAG.getNode(ISD::SRL, dl, ExtVT, RHi, DAG.getConstant(8, dl, ExtVT));
+      return DAG.getNode(X86ISD::PACKUS, dl, VT, RLo, RHi);
+    }
  }
 
   // It's worth extending once and using the v8i32 shifts for 16-bit types, but
@@ -17075,6 +17150,67 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget* Subtarget,
     return DAG.getNode(X86ISD::PACKUS, dl, VT, Lo, Hi);
   }
 
+  if (VT == MVT::v8i16) {
+    unsigned ShiftOpcode = Op->getOpcode();
+
+    auto SignBitSelect = [&](SDValue Sel, SDValue V0, SDValue V1) {
+      // On SSE41 targets we make use of the fact that VSELECT lowers
+      // to PBLENDVB which selects bytes based just on the sign bit.
+      if (Subtarget->hasSSE41()) {
+        MVT ExtVT = MVT::getVectorVT(MVT::i8, VT.getVectorNumElements() * 2);
+        V0 = DAG.getBitcast(ExtVT, V0);
+        V1 = DAG.getBitcast(ExtVT, V1);
+        Sel = DAG.getBitcast(ExtVT, Sel);
+        return DAG.getBitcast(
+            VT, DAG.getNode(ISD::VSELECT, dl, ExtVT, Sel, V0, V1));
+      }
+      // On pre-SSE41 targets we splat the sign bit - a negative value will
+      // set all bits of the lanes to true and VSELECT uses that in
+      // its OR(AND(V0,C),AND(V1,~C)) lowering.
+      SDValue C =
+          DAG.getNode(ISD::SRA, dl, VT, Sel, DAG.getConstant(15, dl, VT));
+      return DAG.getNode(ISD::VSELECT, dl, VT, C, V0, V1);
+    };
+
+    // Turn 'a' into a mask suitable for VSELECT: a = a << 12;
+    if (Subtarget->hasSSE41()) {
+      // On SSE41 targets we need to replicate the shift mask in both
+      // bytes for PBLENDVB.
+      Amt = DAG.getNode(
+          ISD::OR, dl, VT,
+          DAG.getNode(ISD::SHL, dl, VT, Amt, DAG.getConstant(4, dl, VT)),
+          DAG.getNode(ISD::SHL, dl, VT, Amt, DAG.getConstant(12, dl, VT)));
+    } else {
+      Amt = DAG.getNode(ISD::SHL, dl, VT, Amt, DAG.getConstant(12, dl, VT));
+    }
+
+    // r = VSELECT(r, shift(r, 8), a);
+    SDValue M = DAG.getNode(ShiftOpcode, dl, VT, R, DAG.getConstant(8, dl, VT));
+    R = SignBitSelect(Amt, M, R);
+
+    // a += a
+    Amt = DAG.getNode(ISD::ADD, dl, VT, Amt, Amt);
+
+    // r = VSELECT(r, shift(r, 4), a);
+    M = DAG.getNode(ShiftOpcode, dl, VT, R, DAG.getConstant(4, dl, VT));
+    R = SignBitSelect(Amt, M, R);
+
+    // a += a
+    Amt = DAG.getNode(ISD::ADD, dl, VT, Amt, Amt);
+
+    // r = VSELECT(r, shift(r, 2), a);
+    M = DAG.getNode(ShiftOpcode, dl, VT, R, DAG.getConstant(2, dl, VT));
+    R = SignBitSelect(Amt, M, R);
+
+    // a += a
+    Amt = DAG.getNode(ISD::ADD, dl, VT, Amt, Amt);
+
+    // return VSELECT(r, shift(r, 1), a);
+    M = DAG.getNode(ShiftOpcode, dl, VT, R, DAG.getConstant(1, dl, VT));
+    R = SignBitSelect(Amt, M, R);
+    return R;
+  }
+
   // Decompose 256-bit shifts into smaller 128-bit shifts.
   if (VT.is256BitVector()) {
     unsigned NumElems = VT.getVectorNumElements();
20 changes: 10 additions & 10 deletions llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -153,13 +153,13 @@ unsigned X86TTIImpl::getArithmeticInstrCost(
     { ISD::SHL,  MVT::v4i64,     1 },
     { ISD::SRL,  MVT::v4i64,     1 },
 
-    { ISD::SHL,  MVT::v32i8,    42 }, // cmpeqb sequence.
+    { ISD::SHL,  MVT::v32i8,    11 }, // vpblendvb sequence.
     { ISD::SHL,  MVT::v16i16,   10 }, // extend/vpsrlvd/pack sequence.
 
-    { ISD::SRL,  MVT::v32i8,  32*10 }, // Scalarized.
+    { ISD::SRL,  MVT::v32i8,    11 }, // vpblendvb sequence.
     { ISD::SRL,  MVT::v16i16,   10 }, // extend/vpsrlvd/pack sequence.
 
-    { ISD::SRA,  MVT::v32i8,  32*10 }, // Scalarized.
+    { ISD::SRA,  MVT::v32i8,    24 }, // vpblendvb sequence.
     { ISD::SRA,  MVT::v16i16,   10 }, // extend/vpsravd/pack sequence.
     { ISD::SRA,  MVT::v4i64,  4*10 }, // Scalarized.
 
@@ -253,19 +253,19 @@ unsigned X86TTIImpl::getArithmeticInstrCost(
     // to ISel. The cost model must return worst case assumptions because it is
     // used for vectorization and we don't want to make vectorized code worse
     // than scalar code.
-    { ISD::SHL,  MVT::v16i8,    30 }, // cmpeqb sequence.
-    { ISD::SHL,  MVT::v8i16,  8*10 }, // Scalarized.
-    { ISD::SHL,  MVT::v4i32,   2*5 }, // We optimized this using mul.
+    { ISD::SHL,  MVT::v16i8,    26 }, // cmpgtb sequence.
+    { ISD::SHL,  MVT::v8i16,    32 }, // cmpgtb sequence.
+    { ISD::SHL,  MVT::v4i32,   2*5 }, // We optimized this using mul.
     { ISD::SHL,  MVT::v2i64,  2*10 }, // Scalarized.
     { ISD::SHL,  MVT::v4i64,  4*10 }, // Scalarized.
 
-    { ISD::SRL,  MVT::v16i8, 16*10 }, // Scalarized.
-    { ISD::SRL,  MVT::v8i16,  8*10 }, // Scalarized.
+    { ISD::SRL,  MVT::v16i8,    26 }, // cmpgtb sequence.
+    { ISD::SRL,  MVT::v8i16,    32 }, // cmpgtb sequence.
     { ISD::SRL,  MVT::v4i32,  4*10 }, // Scalarized.
     { ISD::SRL,  MVT::v2i64,  2*10 }, // Scalarized.
 
-    { ISD::SRA,  MVT::v16i8, 16*10 }, // Scalarized.
-    { ISD::SRA,  MVT::v8i16,  8*10 }, // Scalarized.
+    { ISD::SRA,  MVT::v16i8,    54 }, // unpacked cmpgtb sequence.
+    { ISD::SRA,  MVT::v8i16,    32 }, // cmpgtb sequence.
     { ISD::SRA,  MVT::v4i32,  4*10 }, // Scalarized.
     { ISD::SRA,  MVT::v2i64,  2*10 }, // Scalarized.
24 changes: 12 additions & 12 deletions llvm/test/Analysis/CostModel/X86/testshiftashr.ll
@@ -29,9 +29,9 @@ entry:
 define %shifttype8i16 @shift8i16(%shifttype8i16 %a, %shifttype8i16 %b) {
 entry:
   ; SSE2: shift8i16
-  ; SSE2: cost of 80 {{.*}} ashr
+  ; SSE2: cost of 32 {{.*}} ashr
   ; SSE2-CODEGEN: shift8i16
-  ; SSE2-CODEGEN: sarw %cl
+  ; SSE2-CODEGEN: psraw
 
   %0 = ashr %shifttype8i16 %a , %b
   ret %shifttype8i16 %0
@@ -41,9 +41,9 @@ entry:
 define %shifttype16i16 @shift16i16(%shifttype16i16 %a, %shifttype16i16 %b) {
 entry:
   ; SSE2: shift16i16
-  ; SSE2: cost of 160 {{.*}} ashr
+  ; SSE2: cost of 64 {{.*}} ashr
   ; SSE2-CODEGEN: shift16i16
-  ; SSE2-CODEGEN: sarw %cl
+  ; SSE2-CODEGEN: psraw
 
   %0 = ashr %shifttype16i16 %a , %b
   ret %shifttype16i16 %0
@@ -53,9 +53,9 @@ entry:
 define %shifttype32i16 @shift32i16(%shifttype32i16 %a, %shifttype32i16 %b) {
 entry:
   ; SSE2: shift32i16
-  ; SSE2: cost of 320 {{.*}} ashr
+  ; SSE2: cost of 128 {{.*}} ashr
   ; SSE2-CODEGEN: shift32i16
-  ; SSE2-CODEGEN: sarw %cl
+  ; SSE2-CODEGEN: psraw
 
   %0 = ashr %shifttype32i16 %a , %b
   ret %shifttype32i16 %0
@@ -209,9 +209,9 @@ entry:
 define %shifttype8i8 @shift8i8(%shifttype8i8 %a, %shifttype8i8 %b) {
 entry:
   ; SSE2: shift8i8
-  ; SSE2: cost of 80 {{.*}} ashr
+  ; SSE2: cost of 32 {{.*}} ashr
   ; SSE2-CODEGEN: shift8i8
-  ; SSE2-CODEGEN: sarw %cl
+  ; SSE2-CODEGEN: psraw
 
   %0 = ashr %shifttype8i8 %a , %b
   ret %shifttype8i8 %0
@@ -221,9 +221,9 @@ entry:
 define %shifttype16i8 @shift16i8(%shifttype16i8 %a, %shifttype16i8 %b) {
 entry:
   ; SSE2: shift16i8
-  ; SSE2: cost of 160 {{.*}} ashr
+  ; SSE2: cost of 54 {{.*}} ashr
   ; SSE2-CODEGEN: shift16i8
-  ; SSE2-CODEGEN: sarb %cl
+  ; SSE2-CODEGEN: psraw
 
   %0 = ashr %shifttype16i8 %a , %b
   ret %shifttype16i8 %0
@@ -233,9 +233,9 @@ entry:
 define %shifttype32i8 @shift32i8(%shifttype32i8 %a, %shifttype32i8 %b) {
 entry:
   ; SSE2: shift32i8
-  ; SSE2: cost of 320 {{.*}} ashr
+  ; SSE2: cost of 108 {{.*}} ashr
   ; SSE2-CODEGEN: shift32i8
-  ; SSE2-CODEGEN: sarb %cl
+  ; SSE2-CODEGEN: psraw
 
   %0 = ashr %shifttype32i8 %a , %b
   ret %shifttype32i8 %0
24 changes: 12 additions & 12 deletions llvm/test/Analysis/CostModel/X86/testshiftlshr.ll
@@ -29,9 +29,9 @@ entry:
 define %shifttype8i16 @shift8i16(%shifttype8i16 %a, %shifttype8i16 %b) {
 entry:
   ; SSE2: shift8i16
-  ; SSE2: cost of 80 {{.*}} lshr
+  ; SSE2: cost of 32 {{.*}} lshr
   ; SSE2-CODEGEN: shift8i16
-  ; SSE2-CODEGEN: shrl %cl
+  ; SSE2-CODEGEN: psrlw
 
   %0 = lshr %shifttype8i16 %a , %b
   ret %shifttype8i16 %0
@@ -41,9 +41,9 @@ entry:
 define %shifttype16i16 @shift16i16(%shifttype16i16 %a, %shifttype16i16 %b) {
 entry:
   ; SSE2: shift16i16
-  ; SSE2: cost of 160 {{.*}} lshr
+  ; SSE2: cost of 64 {{.*}} lshr
   ; SSE2-CODEGEN: shift16i16
-  ; SSE2-CODEGEN: shrl %cl
+  ; SSE2-CODEGEN: psrlw
 
   %0 = lshr %shifttype16i16 %a , %b
   ret %shifttype16i16 %0
@@ -53,9 +53,9 @@ entry:
 define %shifttype32i16 @shift32i16(%shifttype32i16 %a, %shifttype32i16 %b) {
 entry:
   ; SSE2: shift32i16
-  ; SSE2: cost of 320 {{.*}} lshr
+  ; SSE2: cost of 128 {{.*}} lshr
   ; SSE2-CODEGEN: shift32i16
-  ; SSE2-CODEGEN: shrl %cl
+  ; SSE2-CODEGEN: psrlw
 
   %0 = lshr %shifttype32i16 %a , %b
   ret %shifttype32i16 %0
@@ -209,9 +209,9 @@ entry:
 define %shifttype8i8 @shift8i8(%shifttype8i8 %a, %shifttype8i8 %b) {
 entry:
   ; SSE2: shift8i8
-  ; SSE2: cost of 80 {{.*}} lshr
+  ; SSE2: cost of 32 {{.*}} lshr
   ; SSE2-CODEGEN: shift8i8
-  ; SSE2-CODEGEN: shrl %cl
+  ; SSE2-CODEGEN: psrlw
 
   %0 = lshr %shifttype8i8 %a , %b
   ret %shifttype8i8 %0
@@ -221,9 +221,9 @@ entry:
 define %shifttype16i8 @shift16i8(%shifttype16i8 %a, %shifttype16i8 %b) {
 entry:
   ; SSE2: shift16i8
-  ; SSE2: cost of 160 {{.*}} lshr
+  ; SSE2: cost of 26 {{.*}} lshr
   ; SSE2-CODEGEN: shift16i8
-  ; SSE2-CODEGEN: shrb %cl
+  ; SSE2-CODEGEN: psrlw
 
   %0 = lshr %shifttype16i8 %a , %b
   ret %shifttype16i8 %0
@@ -233,9 +233,9 @@ entry:
 define %shifttype32i8 @shift32i8(%shifttype32i8 %a, %shifttype32i8 %b) {
 entry:
   ; SSE2: shift32i8
-  ; SSE2: cost of 320 {{.*}} lshr
+  ; SSE2: cost of 52 {{.*}} lshr
   ; SSE2-CODEGEN: shift32i8
-  ; SSE2-CODEGEN: shrb %cl
+  ; SSE2-CODEGEN: psrlw
 
   %0 = lshr %shifttype32i8 %a , %b
   ret %shifttype32i8 %0
