Skip to content

Commit

Permalink
[InstCombine] Fix SSE2/AVX2 vector logical shift by constant
Browse files Browse the repository at this point in the history
This patch fixes the SSE2/AVX2 vector-shift-by-constant InstCombine fold so that it correctly handles the fact that the shift amount is formed from the entire lower 64 bits of the count operand, not just from its lowest element as the code currently assumes.

e.g.

%1 = tail call <4 x i32> @llvm.x86.sse2.psrl.d(<4 x i32> %v, <4 x i32> <i32 15, i32 15, i32 15, i32 15>)

In this case, (V)PSRLD doesn't perform an lshr by 15 but in fact attempts to shift by 64424509455 ((15 << 32) | 15); since that is far larger than the 32-bit element width, the result is zero.

In addition, this patch recognizes shift-by-zero when the count is a ConstantAggregateZero (PR23821).

Differential Revision: http://reviews.llvm.org/D11760

llvm-svn: 244341
  • Loading branch information
RKSimon committed Aug 7, 2015
1 parent 855ea0f commit 3815c16
Show file tree
Hide file tree
Showing 2 changed files with 264 additions and 100 deletions.
55 changes: 39 additions & 16 deletions llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,33 +200,56 @@ Instruction *InstCombiner::SimplifyMemSet(MemSetInst *MI) {
// Tries to fold an x86 packed logical-shift intrinsic call whose shift count
// (argument 1) is a constant into a plain IR shl/lshr, the unshifted input, or
// an all-zero vector. Returns nullptr when the count is not one of the
// recognized constant kinds. ShiftLeft selects shl over lshr.
//
// NOTE(review): this listing looks like a flattened diff with the +/- markers
// lost, so pre-patch and post-patch lines are interleaved (CDV, CInt and Count
// are each declared twice). The [pre-patch]/[post-patch] labels below are
// inferred from that duplication and from the commit message - confirm them
// against the real file before relying on them.
static Value *SimplifyX86immshift(const IntrinsicInst &II,
InstCombiner::BuilderTy &Builder,
bool ShiftLeft) {
// [pre-patch] Only ConstantDataVector and ConstantInt counts were recognized.
// Simplify if count is constant. To 0 if >= BitWidth,
// otherwise to shl/lshr.
auto CDV = dyn_cast<ConstantDataVector>(II.getArgOperand(1));
auto CInt = dyn_cast<ConstantInt>(II.getArgOperand(1));
if (!CDV && !CInt)
// [post-patch] A ConstantAggregateZero count is accepted as well.
// Simplify if count is constant.
auto Arg1 = II.getArgOperand(1);
auto CAZ = dyn_cast<ConstantAggregateZero>(Arg1);
auto CDV = dyn_cast<ConstantDataVector>(Arg1);
auto CInt = dyn_cast<ConstantInt>(Arg1);
if (!CAZ && !CDV && !CInt)
return nullptr;
// [pre-patch] The count was read from element 0 of the vector only.
ConstantInt *Count;
if (CDV)
Count = cast<ConstantInt>(CDV->getElementAsConstant(0));
else
Count = CInt;

// [post-patch] Count starts as a 64-bit zero; it keeps that value when the
// operand is a ConstantAggregateZero, since neither branch below runs then.
APInt Count(64, 0);
if (CDV) {
// SSE2/AVX2 uses all the first 64-bits of the 128-bit vector
// operand to compute the shift amount.
auto VT = cast<VectorType>(CDV->getType());
unsigned BitWidth = VT->getElementType()->getPrimitiveSizeInBits();
assert((64 % BitWidth) == 0 && "Unexpected packed shift size");
// Number of count-vector elements that make up the low 64 bits.
unsigned NumSubElts = 64 / BitWidth;

// Concatenate the sub-elements to create the 64-bit value.
// Elements are visited from index NumSubElts-1 down to 0, so element 0 ends
// up in the least significant bits of Count.
for (unsigned i = 0; i != NumSubElts; ++i) {
unsigned SubEltIdx = (NumSubElts - 1) - i;
auto SubElt = cast<ConstantInt>(CDV->getElementAsConstant(SubEltIdx));
Count = Count.shl(BitWidth);
Count |= SubElt->getValue().zextOrTrunc(64);
}
}
// Scalar count: take the integer's value as-is.
else if (CInt)
Count = CInt->getValue();

// Vec is the value being shifted; VT/SVT are its vector and element types.
auto Vec = II.getArgOperand(0);
auto VT = cast<VectorType>(Vec->getType());
auto SVT = VT->getElementType();
// [pre-patch] Out-of-range check on the single-element count.
if (Count->getZExtValue() > (SVT->getPrimitiveSizeInBits() - 1))
return ConstantAggregateZero::get(VT);

unsigned VWidth = VT->getNumElements();
unsigned BitWidth = SVT->getPrimitiveSizeInBits();

// [post-patch] The two early-outs below operate on the combined APInt count.
// If shift-by-zero then just return the original value.
if (Count == 0)
return Vec;

// Handle cases when Shift >= BitWidth - just return zero.
if (Count.uge(BitWidth))
return ConstantAggregateZero::get(VT);

// Get a constant vector of the same type as the first operand.
// [pre-patch] Element constant built from the ConstantInt count.
auto VTCI = ConstantInt::get(VT->getElementType(), Count->getZExtValue());
// [post-patch] Resize the count to the element width, then splat it across
// all VWidth lanes.
auto ShiftAmt = ConstantInt::get(SVT, Count.zextOrTrunc(BitWidth));
auto ShiftVec = Builder.CreateVectorSplat(VWidth, ShiftAmt);

// For each shift direction the first return is the [pre-patch] form and the
// second is its [post-patch] replacement using the shared ShiftVec.
if (ShiftLeft)
return Builder.CreateShl(Vec, Builder.CreateVectorSplat(VWidth, VTCI));
return Builder.CreateShl(Vec, ShiftVec);

return Builder.CreateLShr(Vec, Builder.CreateVectorSplat(VWidth, VTCI));
return Builder.CreateLShr(Vec, ShiftVec);
}

static Value *SimplifyX86extend(const IntrinsicInst &II,
Expand Down
Loading

0 comments on commit 3815c16

Please sign in to comment.