
[X86][AVX] Fix handling of out-of-bounds shift amounts in AVX2 vector shift nodes #84426

Merged 1 commit into llvm:main on Mar 15, 2024

Conversation

SahilPatidar (Contributor)

Resolve #83840

SahilPatidar (Contributor Author)

@RKSimon

      ShiftAmt == VT.getScalarSizeInBits() - 1) {
    SDValue ShrAmtVal = UMinNode->getOperand(0);
    SDLoc DL(N);
    return DAG.getNode(N->getOpcode(), DL, N->getVTList(), N0, ShrAmtVal);
Collaborator

The N->getOpcode() is just the regular ISD::SRA opcode - we need to use the X86ISD::VSRAV equivalent, but only if it's safe to do so - we can check that with supportedVectorShiftWithImm.

Contributor Author

I tried the following code:

unsigned int Opcode = N->getOpcode();
if (supportedVectorShiftWithImm(VT, Subtarget, X86ISD::VSRAV)) {
  Opcode = X86ISD::VSRAV;
}

But with this code, the above test crashes:

LLVM ERROR: Cannot select: t13: v4i32 = X86ISD::VSRAV t2, t4
  t2: v4i32,ch = CopyFromReg t0, Register:v4i32 %0
    t1: v4i32 = Register %0
  t4: v4i32,ch = CopyFromReg t0, Register:v4i32 %1
    t3: v4i32 = Register %1
In function: combine_vec_ashr_out_of_bound
PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
Stack dump:
0.	Program arguments: build/bin/llc -mtriple=x86_64-unknown-unknown -mattr=+sse4.1
1.	Running pass 'Function Pass Manager' on module '<stdin>'.
2.	Running pass 'X86 DAG->DAG Instruction Selection' on function '@combine_vec_ashr_out_of_bound'
 #0 0x0000000107ec3c14 llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) (/Users/sahilpatidar/Desktop/llvm/llvm-project/build/bin/llc+0x105c37c14)
 #1 0x0000000107ec417c PrintStackTraceSignalHandler(void*) (/Users/sahilpatidar/Desktop/llvm/llvm-project/build/bin/llc+0x105c3817c)
 #2 0x0000000107ec1d58 llvm::sys::RunSignalHandlers() (/Users/sahilpatidar/Desktop/llvm/llvm-project/build/bin/llc+0x105c35d58)
 #3 0x0000000107ec5db4 SignalHandler(int) (/Users/sahilpatidar/Desktop/llvm/llvm-project/build/bin/llc+0x105c39db4)
 #4 0x000000018d945a24 (/usr/lib/system/libsystem_platform.dylib+0x18046da24)
 #5 0x000000018d915cc0 (/usr/lib/system/libsystem_pthread.dylib+0x18043dcc0)
 #6 0x000000018d821a40 (/usr/lib/system/libsystem_c.dylib+0x180349a40)
 #7 0x0000000107d53964 llvm::report_fatal_error(llvm::Twine const&, bool) (/Users/sahilpatidar/Desktop/llvm/llvm-project/build/bin/llc+0x105ac7964)
 #8 0x0000000107bb2154 llvm::SmallVectorTemplateCommon<(anonymous namespace)::MatchScope, void>::back() (/Users/sahilpatidar/Desktop/llvm/llvm-project/build/bin/llc+0x105926154)
 #9 0x0000000107bae860 llvm::SelectionDAGISel::SelectCodeCommon(llvm::SDNode*, unsigned char const*, unsigned int) (/Users/sahilpatidar/Desktop/llvm/llvm-project/build/bin/llc+0x105922860)
#10 0x00000001048779a4 (anonymous namespace)::X86DAGToDAGISel::SelectCode(llvm::SDNode*) (/Users/sahilpatidar/Desktop/llvm/llvm-project/build/bin/llc+0x1025eb9a4)
#11 0x000000010486c1e4 (anonymous namespace)::X86DAGToDAGISel::Select(llvm::SDNode*) (/Users/sahilpatidar/Desktop/llvm/llvm-project/build/bin/llc+0x1025e01e4)
#12 0x0000000107b9ff2c llvm::SelectionDAGISel::DoInstructionSelection() (/Users/sahilpatidar/Desktop/llvm/llvm-project/build/bin/llc+0x105913f2c)
#13 0x0000000107b9eb5c llvm::SelectionDAGISel::CodeGenAndEmitDAG() (/Users/sahilpatidar/Desktop/llvm/llvm-project/build/bin/llc+0x105912b5c)
#14 0x0000000107b9d708 llvm::SelectionDAGISel::SelectBasicBlock(llvm::ilist_iterator_w_bits<llvm::ilist_detail::node_options<llvm::Instruction, true, false, void, true>, false, true>, llvm::ilist_iterator_w_bits<llvm::ilist_detail::node_options<llvm::Instruction, true, false, void, true>, false, true>, bool&) (/Users/sahilpatidar/Desktop/llvm/llvm-project/build/bin/llc+0x105911708)
#15 0x0000000107b9be54 llvm::SelectionDAGISel::SelectAllBasicBlocks(llvm::Function const&) (/Users/sahilpatidar/Desktop/llvm/llvm-project/build/bin/llc+0x10590fe54)
#16 0x0000000107b993a0 llvm::SelectionDAGISel::runOnMachineFunction(llvm::MachineFunction&) (/Users/sahilpatidar/Desktop/llvm/llvm-project/build/bin/llc+0x10590d3a0)
#17 0x000000010485cad4 (anonymous namespace)::X86DAGToDAGISel::runOnMachineFunction(llvm::MachineFunction&) (/Users/sahilpatidar/Desktop/llvm/llvm-project/build/bin/llc+0x1025d0ad4)
#18 0x0000000106031ce0 llvm::MachineFunctionPass::runOnFunction(llvm::Function&) (/Users/sahilpatidar/Desktop/llvm/llvm-project/build/bin/llc+0x103da5ce0)
#19 0x0000000106b6bfb0 llvm::FPPassManager::runOnFunction(llvm::Function&) (/Users/sahilpatidar/Desktop/llvm/llvm-project/build/bin/llc+0x1048dffb0)
#20 0x0000000106b731e8 llvm::FPPassManager::runOnModule(llvm::Module&) (/Users/sahilpatidar/Desktop/llvm/llvm-project/build/bin/llc+0x1048e71e8)
#21 0x0000000106b6c844 (anonymous namespace)::MPPassManager::runOnModule(llvm::Module&) (/Users/sahilpatidar/Desktop/llvm/llvm-project/build/bin/llc+0x1048e0844)
#22 0x0000000106b6c3d0 llvm::legacy::PassManagerImpl::run(llvm::Module&) (/Users/sahilpatidar/Desktop/llvm/llvm-project/build/bin/llc+0x1048e03d0)
#23 0x0000000106b735e8 llvm::legacy::PassManager::run(llvm::Module&) (/Users/sahilpatidar/Desktop/llvm/llvm-project/build/bin/llc+0x1048e75e8)
#24 0x00000001022925a4 compileModule(char**, llvm::LLVMContext&) (/Users/sahilpatidar/Desktop/llvm/llvm-project/build/bin/llc+0x1000065a4)

Collaborator

Shouldn't it be supportedVectorShiftWithImm(VT, Subtarget, ISD::SRA)?

Contributor Author

Oops, I misunderstood.

Collaborator

It might be useful to add an assertion inside supportedVectorShiftWithImm that checks that Opcode is ISD::SHL/SRA/SRL.

Contributor Author

Just to be clear: does that mean that if a Subtarget natively supports ISD::SRA for vector shifts, we should use ISD::SRA?

Collaborator

No, because ISD::SRA nodes treat out-of-bounds shift amounts as undefined - we must replace it with an X86ISD::VSRAV node, as we want the explicit behaviour explained in the issue - which means we can only perform this fold if X86ISD::VSRAV is legal (vXi32 on AVX2, vXi64 on AVX512F, vXi16 on AVX512BW) - supportedVectorShiftWithImm should handle that for us.
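
As a rough illustration (a scalar C++ model written for this explanation, not code from the PR or from LLVM), the per-lane semantics being relied on for the vXi32 case look like this:

#include <algorithm>
#include <cstdint>

// What the IR pattern computes: ashr(x, umin(y, 31)). The umin clamp is
// required in IR because ashr, like >> in C++, is undefined for shift
// amounts >= the bit width (assuming >> is an arithmetic shift on signed
// values, as on mainstream targets).
int32_t ashr_with_umin_clamp(int32_t x, uint32_t y) {
  uint32_t amt = std::min<uint32_t>(y, 31u);
  return x >> amt;
}

// What one X86ISD::VSRAV (vpsravd) lane computes: shift amounts greater
// than 31 yield the sign-filled value, i.e. the same result as shifting by
// 31, so the clamp can be dropped once VSRAV is known to be legal for the
// vector type.
int32_t vsrav_lane(int32_t x, uint32_t y) {
  return x >> std::min<uint32_t>(y, 31u);
}

Both functions return the same value for every input, which is why ashr(x, umin(y, 31)) can be rewritten as a single variable arithmetic shift when the target provides one.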

Contributor Author

Okay, here is the modified code:

if (supportedVectorShiftWithImm(VT, Subtarget, ISD::SRA) &&
    UMinNode->getOpcode() == ISD::UMIN &&
    ISD::isConstantSplatVector(UMinNode->getOperand(1).getNode(), ShiftAmt) &&
    ShiftAmt == VT.getScalarSizeInBits() - 1) {
  SDValue ShrAmtVal = UMinNode->getOperand(0);
  SDLoc DL(N);
  return DAG.getNode(X86ISD::VSRAV, DL, N->getVTList(), N0, ShrAmtVal);
}

With this code, it still crashes.

Collaborator

I'm sorry but that should have been supportedVectorVarShift :(

@@ -28927,6 +28927,9 @@ SDValue X86TargetLowering::LowerWin64_INT128_TO_FP(SDValue Op,
// supported by the Subtarget
static bool supportedVectorShiftWithImm(EVT VT, const X86Subtarget &Subtarget,
                                        unsigned Opcode) {
  assert((Opcode == ISD::SHL || Opcode == ISD::SRA || Opcode == ISD::SRL) &&
         "Unexpected Opcode!");
Collaborator

Please rebase as I already added these asserts at 7b90a67

github-actions bot commented Mar 11, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

%1 = tail call <4 x i32> @llvm.umin.v4i32(<4 x i32> %y, <4 x i32> <i32 31, i32 31, i32 31, i32 31>)
%2 = ashr <4 x i32> %x, %1
ret <4 x i32> %2
}
Collaborator

Please can you add <8 x i16> and <4 x i64> test coverage - that should check for differences in the AVX2/AVX512 cases

RKSimon (Collaborator) left a comment

A few more comments, but this is ready to move from draft to general review

UMinNode->getOpcode() == ISD::UMIN &&
ISD::isConstantSplatVector(UMinNode->getOperand(1).getNode(), ShiftAmt) &&
ShiftAmt == VT.getScalarSizeInBits() - 1) {
SDValue ShrAmtVal = UMinNode->getOperand(0);
Collaborator

SDValue ShrAmtVal = N1.getOperand(0);

; AVX512-NEXT: vpminuw {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm1
; AVX512-NEXT: vpsravw %xmm1, %xmm0, %xmm0
; AVX512-NEXT: retq
%1 = tail call <8 x i16> @llvm.umin.v4i16(<8 x i16> %y, <8 x i16> <i16 31, i16 31, i16 31, i16 31, i16 31, i16 31, i16 31, i16 31>)
Collaborator

Why 31?

Contributor Author

I am sorry, I didn't notice; it needs to be 15.

Contributor Author

And does it also need to be changed from v4i16 to v8i16?

Collaborator

v8i16

llvmbot (Collaborator) commented Mar 12, 2024

@llvm/pr-subscribers-backend-x86

Author: None (SahilPatidar)

Changes

Resolve #83840


Full diff: https://github.com/llvm/llvm-project/pull/84426.diff

2 Files Affected:

  • (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+10)
  • (modified) llvm/test/CodeGen/X86/combine-sra.ll (+273)
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index a74901958ac056..f625844d7a796c 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -47334,6 +47334,16 @@ static SDValue combineShiftRightArithmetic(SDNode *N, SelectionDAG &DAG,
   if (SDValue V = combineShiftToPMULH(N, DAG, Subtarget))
     return V;
 
+  APInt ShiftAmt;
+  if (supportedVectorVarShift(VT, Subtarget, ISD::SRA) &&
+      N1.getOpcode() == ISD::UMIN &&
+      ISD::isConstantSplatVector(N1.getOperand(1).getNode(), ShiftAmt) &&
+      ShiftAmt == VT.getScalarSizeInBits() - 1) {
+    SDValue ShrAmtVal = N1.getOperand(0);
+    SDLoc DL(N);
+    return DAG.getNode(X86ISD::VSRAV, DL, N->getVTList(), N0, ShrAmtVal);
+  }
+
   // fold (ashr (shl, a, [56,48,32,24,16]), SarConst)
   // into (shl, (sext (a), [56,48,32,24,16] - SarConst)) or
   // into (lshr, (sext (a), SarConst - [56,48,32,24,16]))
diff --git a/llvm/test/CodeGen/X86/combine-sra.ll b/llvm/test/CodeGen/X86/combine-sra.ll
index 0675ced68d7a7a..655478b1c803a9 100644
--- a/llvm/test/CodeGen/X86/combine-sra.ll
+++ b/llvm/test/CodeGen/X86/combine-sra.ll
@@ -521,3 +521,276 @@ define <4 x i32> @combine_vec_ashr_positive_splat(<4 x i32> %x, <4 x i32> %y) {
   %2 = ashr <4 x i32> %1, <i32 10, i32 10, i32 10, i32 10>
   ret <4 x i32> %2
 }
+
+define <8 x i16> @combine_vec8i16_ashr_out_of_bound(<8 x i16> %x, <8 x i16> %y) {
+; SSE2-LABEL: combine_vec8i16_ashr_out_of_bound:
+; SSE2:       # %bb.0:
+; SSE2-NEXT:    movdqa %xmm1, %xmm2
+; SSE2-NEXT:    psubusw {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm2
+; SSE2-NEXT:    psubw %xmm2, %xmm1
+; SSE2-NEXT:    psllw $12, %xmm1
+; SSE2-NEXT:    movdqa %xmm1, %xmm2
+; SSE2-NEXT:    psraw $15, %xmm2
+; SSE2-NEXT:    movdqa %xmm2, %xmm3
+; SSE2-NEXT:    pandn %xmm0, %xmm3
+; SSE2-NEXT:    psraw $8, %xmm0
+; SSE2-NEXT:    pand %xmm2, %xmm0
+; SSE2-NEXT:    por %xmm3, %xmm0
+; SSE2-NEXT:    paddw %xmm1, %xmm1
+; SSE2-NEXT:    movdqa %xmm1, %xmm2
+; SSE2-NEXT:    psraw $15, %xmm2
+; SSE2-NEXT:    movdqa %xmm2, %xmm3
+; SSE2-NEXT:    pandn %xmm0, %xmm3
+; SSE2-NEXT:    psraw $4, %xmm0
+; SSE2-NEXT:    pand %xmm2, %xmm0
+; SSE2-NEXT:    por %xmm3, %xmm0
+; SSE2-NEXT:    paddw %xmm1, %xmm1
+; SSE2-NEXT:    movdqa %xmm1, %xmm2
+; SSE2-NEXT:    psraw $15, %xmm2
+; SSE2-NEXT:    movdqa %xmm2, %xmm3
+; SSE2-NEXT:    pandn %xmm0, %xmm3
+; SSE2-NEXT:    psraw $2, %xmm0
+; SSE2-NEXT:    pand %xmm2, %xmm0
+; SSE2-NEXT:    por %xmm3, %xmm0
+; SSE2-NEXT:    paddw %xmm1, %xmm1
+; SSE2-NEXT:    psraw $15, %xmm1
+; SSE2-NEXT:    movdqa %xmm1, %xmm2
+; SSE2-NEXT:    pandn %xmm0, %xmm2
+; SSE2-NEXT:    psraw $1, %xmm0
+; SSE2-NEXT:    pand %xmm1, %xmm0
+; SSE2-NEXT:    por %xmm2, %xmm0
+; SSE2-NEXT:    retq
+;
+; SSE41-LABEL: combine_vec8i16_ashr_out_of_bound:
+; SSE41:       # %bb.0:
+; SSE41-NEXT:    movdqa %xmm0, %xmm2
+; SSE41-NEXT:    pminuw {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1
+; SSE41-NEXT:    movdqa %xmm1, %xmm0
+; SSE41-NEXT:    psllw $12, %xmm0
+; SSE41-NEXT:    psllw $4, %xmm1
+; SSE41-NEXT:    por %xmm1, %xmm0
+; SSE41-NEXT:    movdqa %xmm0, %xmm1
+; SSE41-NEXT:    paddw %xmm0, %xmm1
+; SSE41-NEXT:    movdqa %xmm2, %xmm3
+; SSE41-NEXT:    psraw $8, %xmm3
+; SSE41-NEXT:    pblendvb %xmm0, %xmm3, %xmm2
+; SSE41-NEXT:    movdqa %xmm2, %xmm3
+; SSE41-NEXT:    psraw $4, %xmm3
+; SSE41-NEXT:    movdqa %xmm1, %xmm0
+; SSE41-NEXT:    pblendvb %xmm0, %xmm3, %xmm2
+; SSE41-NEXT:    movdqa %xmm2, %xmm3
+; SSE41-NEXT:    psraw $2, %xmm3
+; SSE41-NEXT:    paddw %xmm1, %xmm1
+; SSE41-NEXT:    movdqa %xmm1, %xmm0
+; SSE41-NEXT:    pblendvb %xmm0, %xmm3, %xmm2
+; SSE41-NEXT:    movdqa %xmm2, %xmm3
+; SSE41-NEXT:    psraw $1, %xmm3
+; SSE41-NEXT:    paddw %xmm1, %xmm1
+; SSE41-NEXT:    movdqa %xmm1, %xmm0
+; SSE41-NEXT:    pblendvb %xmm0, %xmm3, %xmm2
+; SSE41-NEXT:    movdqa %xmm2, %xmm0
+; SSE41-NEXT:    retq
+;
+; AVX2-LABEL: combine_vec8i16_ashr_out_of_bound:
+; AVX2:       # %bb.0:
+; AVX2-NEXT:    vpminuw {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm1
+; AVX2-NEXT:    vpmovsxwd %xmm0, %ymm0
+; AVX2-NEXT:    vpmovzxwd {{.*#+}} ymm1 = xmm1[0],zero,xmm1[1],zero,xmm1[2],zero,xmm1[3],zero,xmm1[4],zero,xmm1[5],zero,xmm1[6],zero,xmm1[7],zero
+; AVX2-NEXT:    vpsravd %ymm1, %ymm0, %ymm0
+; AVX2-NEXT:    vextracti128 $1, %ymm0, %xmm1
+; AVX2-NEXT:    vpackssdw %xmm1, %xmm0, %xmm0
+; AVX2-NEXT:    vzeroupper
+; AVX2-NEXT:    retq
+;
+; AVX512-LABEL: combine_vec8i16_ashr_out_of_bound:
+; AVX512:       # %bb.0:
+; AVX512-NEXT:    vpsravw %xmm1, %xmm0, %xmm0
+; AVX512-NEXT:    retq
+  %1 = tail call <8 x i16> @llvm.umin.v8i16(<8 x i16> %y, <8 x i16> <i16 15, i16 15, i16 15, i16 15, i16 15, i16 15, i16 15, i16 15>)
+  %2 = ashr <8 x i16> %x, %1
+  ret <8 x i16> %2
+}
+
+define <4 x i32> @combine_vec4i32_ashr_out_of_bound(<4 x i32> %x, <4 x i32> %y) {
+; SSE2-LABEL: combine_vec4i32_ashr_out_of_bound:
+; SSE2:       # %bb.0:
+; SSE2-NEXT:    movdqa {{.*#+}} xmm2 = [2147483648,2147483648,2147483648,2147483648]
+; SSE2-NEXT:    pxor %xmm1, %xmm2
+; SSE2-NEXT:    pcmpgtd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm2
+; SSE2-NEXT:    movdqa %xmm2, %xmm3
+; SSE2-NEXT:    pandn %xmm1, %xmm3
+; SSE2-NEXT:    psrld $27, %xmm2
+; SSE2-NEXT:    por %xmm3, %xmm2
+; SSE2-NEXT:    pshuflw {{.*#+}} xmm1 = xmm2[2,3,3,3,4,5,6,7]
+; SSE2-NEXT:    movdqa %xmm0, %xmm3
+; SSE2-NEXT:    psrad %xmm1, %xmm3
+; SSE2-NEXT:    pshuflw {{.*#+}} xmm4 = xmm2[0,1,1,1,4,5,6,7]
+; SSE2-NEXT:    movdqa %xmm0, %xmm1
+; SSE2-NEXT:    psrad %xmm4, %xmm1
+; SSE2-NEXT:    punpcklqdq {{.*#+}} xmm1 = xmm1[0],xmm3[0]
+; SSE2-NEXT:    pshufd {{.*#+}} xmm2 = xmm2[2,3,2,3]
+; SSE2-NEXT:    pshuflw {{.*#+}} xmm3 = xmm2[2,3,3,3,4,5,6,7]
+; SSE2-NEXT:    movdqa %xmm0, %xmm4
+; SSE2-NEXT:    psrad %xmm3, %xmm4
+; SSE2-NEXT:    pshuflw {{.*#+}} xmm2 = xmm2[0,1,1,1,4,5,6,7]
+; SSE2-NEXT:    psrad %xmm2, %xmm0
+; SSE2-NEXT:    punpckhqdq {{.*#+}} xmm0 = xmm0[1],xmm4[1]
+; SSE2-NEXT:    shufps {{.*#+}} xmm1 = xmm1[0,3],xmm0[0,3]
+; SSE2-NEXT:    movaps %xmm1, %xmm0
+; SSE2-NEXT:    retq
+;
+; SSE41-LABEL: combine_vec4i32_ashr_out_of_bound:
+; SSE41:       # %bb.0:
+; SSE41-NEXT:    pminud {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1
+; SSE41-NEXT:    pshuflw {{.*#+}} xmm2 = xmm1[2,3,3,3,4,5,6,7]
+; SSE41-NEXT:    movdqa %xmm0, %xmm3
+; SSE41-NEXT:    psrad %xmm2, %xmm3
+; SSE41-NEXT:    pshufd {{.*#+}} xmm2 = xmm1[2,3,2,3]
+; SSE41-NEXT:    pshuflw {{.*#+}} xmm4 = xmm2[2,3,3,3,4,5,6,7]
+; SSE41-NEXT:    movdqa %xmm0, %xmm5
+; SSE41-NEXT:    psrad %xmm4, %xmm5
+; SSE41-NEXT:    pblendw {{.*#+}} xmm5 = xmm3[0,1,2,3],xmm5[4,5,6,7]
+; SSE41-NEXT:    pshuflw {{.*#+}} xmm1 = xmm1[0,1,1,1,4,5,6,7]
+; SSE41-NEXT:    movdqa %xmm0, %xmm3
+; SSE41-NEXT:    psrad %xmm1, %xmm3
+; SSE41-NEXT:    pshuflw {{.*#+}} xmm1 = xmm2[0,1,1,1,4,5,6,7]
+; SSE41-NEXT:    psrad %xmm1, %xmm0
+; SSE41-NEXT:    pblendw {{.*#+}} xmm0 = xmm3[0,1,2,3],xmm0[4,5,6,7]
+; SSE41-NEXT:    pblendw {{.*#+}} xmm0 = xmm0[0,1],xmm5[2,3],xmm0[4,5],xmm5[6,7]
+; SSE41-NEXT:    retq
+;
+; AVX-LABEL: combine_vec4i32_ashr_out_of_bound:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vpsravd %xmm1, %xmm0, %xmm0
+; AVX-NEXT:    retq
+  %1 = tail call <4 x i32> @llvm.umin.v4i32(<4 x i32> %y, <4 x i32> <i32 31, i32 31, i32 31, i32 31>)
+  %2 = ashr <4 x i32> %x, %1
+  ret <4 x i32> %2
+}
+
+define <4 x i64> @combine_vec4i64_ashr_out_of_bound(<4 x i64> %x, <4 x i64> %y) {
+; SSE2-LABEL: combine_vec4i64_ashr_out_of_bound:
+; SSE2:       # %bb.0:
+; SSE2-NEXT:    movdqa {{.*#+}} xmm5 = [9223372039002259456,9223372039002259456]
+; SSE2-NEXT:    movdqa %xmm3, %xmm4
+; SSE2-NEXT:    pxor %xmm5, %xmm4
+; SSE2-NEXT:    pshufd {{.*#+}} xmm6 = xmm4[0,0,2,2]
+; SSE2-NEXT:    movdqa {{.*#+}} xmm7 = [2147483711,2147483711,2147483711,2147483711]
+; SSE2-NEXT:    movdqa %xmm7, %xmm8
+; SSE2-NEXT:    pcmpgtd %xmm6, %xmm8
+; SSE2-NEXT:    pshufd {{.*#+}} xmm4 = xmm4[1,1,3,3]
+; SSE2-NEXT:    pcmpeqd %xmm5, %xmm4
+; SSE2-NEXT:    pand %xmm8, %xmm4
+; SSE2-NEXT:    movdqa {{.*#+}} xmm6 = [63,63]
+; SSE2-NEXT:    pand %xmm4, %xmm3
+; SSE2-NEXT:    pandn %xmm6, %xmm4
+; SSE2-NEXT:    por %xmm3, %xmm4
+; SSE2-NEXT:    movdqa %xmm2, %xmm3
+; SSE2-NEXT:    pxor %xmm5, %xmm3
+; SSE2-NEXT:    pshufd {{.*#+}} xmm8 = xmm3[0,0,2,2]
+; SSE2-NEXT:    pcmpgtd %xmm8, %xmm7
+; SSE2-NEXT:    pshufd {{.*#+}} xmm3 = xmm3[1,1,3,3]
+; SSE2-NEXT:    pcmpeqd %xmm5, %xmm3
+; SSE2-NEXT:    pand %xmm7, %xmm3
+; SSE2-NEXT:    pand %xmm3, %xmm2
+; SSE2-NEXT:    pandn %xmm6, %xmm3
+; SSE2-NEXT:    por %xmm2, %xmm3
+; SSE2-NEXT:    movdqa {{.*#+}} xmm2 = [9223372036854775808,9223372036854775808]
+; SSE2-NEXT:    movdqa %xmm2, %xmm5
+; SSE2-NEXT:    psrlq %xmm3, %xmm5
+; SSE2-NEXT:    pshufd {{.*#+}} xmm6 = xmm3[2,3,2,3]
+; SSE2-NEXT:    movdqa %xmm2, %xmm7
+; SSE2-NEXT:    psrlq %xmm6, %xmm7
+; SSE2-NEXT:    movsd {{.*#+}} xmm7 = xmm5[0],xmm7[1]
+; SSE2-NEXT:    movdqa %xmm0, %xmm5
+; SSE2-NEXT:    psrlq %xmm3, %xmm5
+; SSE2-NEXT:    psrlq %xmm6, %xmm0
+; SSE2-NEXT:    movsd {{.*#+}} xmm0 = xmm5[0],xmm0[1]
+; SSE2-NEXT:    xorpd %xmm7, %xmm0
+; SSE2-NEXT:    psubq %xmm7, %xmm0
+; SSE2-NEXT:    movdqa %xmm2, %xmm3
+; SSE2-NEXT:    psrlq %xmm4, %xmm3
+; SSE2-NEXT:    pshufd {{.*#+}} xmm5 = xmm4[2,3,2,3]
+; SSE2-NEXT:    psrlq %xmm5, %xmm2
+; SSE2-NEXT:    movsd {{.*#+}} xmm2 = xmm3[0],xmm2[1]
+; SSE2-NEXT:    movdqa %xmm1, %xmm3
+; SSE2-NEXT:    psrlq %xmm4, %xmm3
+; SSE2-NEXT:    psrlq %xmm5, %xmm1
+; SSE2-NEXT:    movsd {{.*#+}} xmm1 = xmm3[0],xmm1[1]
+; SSE2-NEXT:    xorpd %xmm2, %xmm1
+; SSE2-NEXT:    psubq %xmm2, %xmm1
+; SSE2-NEXT:    retq
+;
+; SSE41-LABEL: combine_vec4i64_ashr_out_of_bound:
+; SSE41:       # %bb.0:
+; SSE41-NEXT:    movdqa %xmm0, %xmm4
+; SSE41-NEXT:    movdqa {{.*#+}} xmm7 = [9223372039002259456,9223372039002259456]
+; SSE41-NEXT:    movdqa %xmm3, %xmm0
+; SSE41-NEXT:    pxor %xmm7, %xmm0
+; SSE41-NEXT:    movdqa {{.*#+}} xmm8 = [9223372039002259519,9223372039002259519]
+; SSE41-NEXT:    movdqa %xmm8, %xmm6
+; SSE41-NEXT:    pcmpeqd %xmm0, %xmm6
+; SSE41-NEXT:    pshufd {{.*#+}} xmm9 = xmm0[0,0,2,2]
+; SSE41-NEXT:    movdqa {{.*#+}} xmm5 = [2147483711,2147483711,2147483711,2147483711]
+; SSE41-NEXT:    movdqa %xmm5, %xmm0
+; SSE41-NEXT:    pcmpgtd %xmm9, %xmm0
+; SSE41-NEXT:    pand %xmm6, %xmm0
+; SSE41-NEXT:    movapd {{.*#+}} xmm9 = [63,63]
+; SSE41-NEXT:    movapd %xmm9, %xmm6
+; SSE41-NEXT:    blendvpd %xmm0, %xmm3, %xmm6
+; SSE41-NEXT:    pxor %xmm2, %xmm7
+; SSE41-NEXT:    pcmpeqd %xmm7, %xmm8
+; SSE41-NEXT:    pshufd {{.*#+}} xmm0 = xmm7[0,0,2,2]
+; SSE41-NEXT:    pcmpgtd %xmm0, %xmm5
+; SSE41-NEXT:    pand %xmm8, %xmm5
+; SSE41-NEXT:    movdqa %xmm5, %xmm0
+; SSE41-NEXT:    blendvpd %xmm0, %xmm2, %xmm9
+; SSE41-NEXT:    movdqa {{.*#+}} xmm0 = [9223372036854775808,9223372036854775808]
+; SSE41-NEXT:    movdqa %xmm0, %xmm2
+; SSE41-NEXT:    psrlq %xmm9, %xmm2
+; SSE41-NEXT:    pshufd {{.*#+}} xmm3 = xmm9[2,3,2,3]
+; SSE41-NEXT:    movdqa %xmm0, %xmm5
+; SSE41-NEXT:    psrlq %xmm3, %xmm5
+; SSE41-NEXT:    pblendw {{.*#+}} xmm5 = xmm2[0,1,2,3],xmm5[4,5,6,7]
+; SSE41-NEXT:    movdqa %xmm4, %xmm2
+; SSE41-NEXT:    psrlq %xmm9, %xmm2
+; SSE41-NEXT:    psrlq %xmm3, %xmm4
+; SSE41-NEXT:    pblendw {{.*#+}} xmm4 = xmm2[0,1,2,3],xmm4[4,5,6,7]
+; SSE41-NEXT:    pxor %xmm5, %xmm4
+; SSE41-NEXT:    psubq %xmm5, %xmm4
+; SSE41-NEXT:    movdqa %xmm0, %xmm2
+; SSE41-NEXT:    psrlq %xmm6, %xmm2
+; SSE41-NEXT:    pshufd {{.*#+}} xmm3 = xmm6[2,3,2,3]
+; SSE41-NEXT:    psrlq %xmm3, %xmm0
+; SSE41-NEXT:    pblendw {{.*#+}} xmm0 = xmm2[0,1,2,3],xmm0[4,5,6,7]
+; SSE41-NEXT:    movdqa %xmm1, %xmm2
+; SSE41-NEXT:    psrlq %xmm6, %xmm2
+; SSE41-NEXT:    psrlq %xmm3, %xmm1
+; SSE41-NEXT:    pblendw {{.*#+}} xmm1 = xmm2[0,1,2,3],xmm1[4,5,6,7]
+; SSE41-NEXT:    pxor %xmm0, %xmm1
+; SSE41-NEXT:    psubq %xmm0, %xmm1
+; SSE41-NEXT:    movdqa %xmm4, %xmm0
+; SSE41-NEXT:    retq
+;
+; AVX2-LABEL: combine_vec4i64_ashr_out_of_bound:
+; AVX2:       # %bb.0:
+; AVX2-NEXT:    vpbroadcastq {{.*#+}} ymm2 = [9223372036854775808,9223372036854775808,9223372036854775808,9223372036854775808]
+; AVX2-NEXT:    vpxor %ymm2, %ymm1, %ymm3
+; AVX2-NEXT:    vpbroadcastq {{.*#+}} ymm4 = [9223372036854775870,9223372036854775870,9223372036854775870,9223372036854775870]
+; AVX2-NEXT:    vpcmpgtq %ymm4, %ymm3, %ymm3
+; AVX2-NEXT:    vbroadcastsd {{.*#+}} ymm4 = [63,63,63,63]
+; AVX2-NEXT:    vblendvpd %ymm3, %ymm4, %ymm1, %ymm1
+; AVX2-NEXT:    vpsrlvq %ymm1, %ymm2, %ymm2
+; AVX2-NEXT:    vpsrlvq %ymm1, %ymm0, %ymm0
+; AVX2-NEXT:    vpxor %ymm2, %ymm0, %ymm0
+; AVX2-NEXT:    vpsubq %ymm2, %ymm0, %ymm0
+; AVX2-NEXT:    retq
+;
+; AVX512-LABEL: combine_vec4i64_ashr_out_of_bound:
+; AVX512:       # %bb.0:
+; AVX512-NEXT:    vpsravq %ymm1, %ymm0, %ymm0
+; AVX512-NEXT:    retq
+  %1 = tail call <4 x i64> @llvm.umin.v4i64(<4 x i64> %y, <4 x i64> <i64 63, i64 63, i64 63, i64 63>)
+  %2 = ashr <4 x i64> %x, %1
+  ret <4 x i64> %2
+}

RKSimon (Collaborator) left a comment

one final comment

@@ -521,3 +521,276 @@ define <4 x i32> @combine_vec_ashr_positive_splat(<4 x i32> %x, <4 x i32> %y) {
%2 = ashr <4 x i32> %1, <i32 10, i32 10, i32 10, i32 10>
ret <4 x i32> %2
}

define <8 x i16> @combine_vec8i16_ashr_out_of_bound(<8 x i16> %x, <8 x i16> %y) {
Collaborator

Maybe replace "out_of_bound" with "clamped"?

RKSimon (Collaborator) left a comment

LGTM

RKSimon merged commit 1865655 into llvm:main Mar 15, 2024
3 of 4 checks passed
Successfully merging this pull request may close these issues:

[X86][AVX] Recognise out of bounds AVX2 shift amounts