Skip to content

Commit

Permalink
[X86] Twist shuffle mask when fold HOP(SHUFFLE(X,Y),SHUFFLE(X,Y)) -> …
Browse files Browse the repository at this point in the history
…SHUFFLE(HOP(X,Y))

This patch fixes PR50823.

The shuffle mask should be twisted twice before gotten the correct one due to the difference between inner HOP and outer.

Reviewed By: RKSimon

Differential Revision: https://reviews.llvm.org/D104903
  • Loading branch information
phoebewang committed Jul 5, 2021
1 parent 681aa57 commit 9ab99f7
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 6 deletions.
7 changes: 4 additions & 3 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
Expand Up @@ -43709,9 +43709,10 @@ static SDValue combineHorizOpWithShuffle(SDNode *N, SelectionDAG &DAG,
ShuffleVectorSDNode::commuteMask(ScaledMask1);
}
if ((Op00 == Op10) && (Op01 == Op11)) {
SmallVector<int, 4> ShuffleMask;
ShuffleMask.append(ScaledMask0.begin(), ScaledMask0.end());
ShuffleMask.append(ScaledMask1.begin(), ScaledMask1.end());
const int Map[4] = {0, 2, 1, 3};
SmallVector<int, 4> ShuffleMask(
{Map[ScaledMask0[0]], Map[ScaledMask1[0]], Map[ScaledMask0[1]],
Map[ScaledMask1[1]]});
MVT ShufVT = VT.isFloatingPoint() ? MVT::v4f64 : MVT::v4i64;
SDValue Res = DAG.getNode(Opcode, DL, VT, DAG.getBitcast(SrcVT, Op00),
DAG.getBitcast(SrcVT, Op01));
Expand Down
4 changes: 2 additions & 2 deletions llvm/test/CodeGen/X86/haddsub-undef.ll
Expand Up @@ -1169,7 +1169,7 @@ define <4 x double> @PR34724_add_v4f64_u123(<4 x double> %0, <4 x double> %1) {
; AVX512-FAST: # %bb.0:
; AVX512-FAST-NEXT: vextractf128 $1, %ymm0, %xmm0
; AVX512-FAST-NEXT: vhaddpd %ymm1, %ymm0, %ymm0
; AVX512-FAST-NEXT: vpermpd {{.*#+}} ymm0 = ymm0[0,2,0,3]
; AVX512-FAST-NEXT: vpermpd {{.*#+}} ymm0 = ymm0[0,0,1,3]
; AVX512-FAST-NEXT: retq
%3 = shufflevector <4 x double> %0, <4 x double> %1, <2 x i32> <i32 2, i32 4>
%4 = shufflevector <4 x double> %0, <4 x double> %1, <2 x i32> <i32 3, i32 5>
Expand Down Expand Up @@ -1270,7 +1270,7 @@ define <4 x double> @PR34724_add_v4f64_01u3(<4 x double> %0, <4 x double> %1) {
; AVX512-FAST-LABEL: PR34724_add_v4f64_01u3:
; AVX512-FAST: # %bb.0:
; AVX512-FAST-NEXT: vhaddpd %ymm1, %ymm0, %ymm0
; AVX512-FAST-NEXT: vpermpd {{.*#+}} ymm0 = ymm0[0,3,1,3]
; AVX512-FAST-NEXT: vpermpd {{.*#+}} ymm0 = ymm0[0,2,3,3]
; AVX512-FAST-NEXT: retq
%3 = shufflevector <4 x double> %0, <4 x double> undef, <2 x i32> <i32 0, i32 2>
%4 = shufflevector <4 x double> %0, <4 x double> undef, <2 x i32> <i32 1, i32 3>
Expand Down
2 changes: 1 addition & 1 deletion llvm/test/CodeGen/X86/packss.ll
Expand Up @@ -370,7 +370,7 @@ define <32 x i8> @packsswb_icmp_zero_trunc_256(<16 x i16> %a0) {
; AVX2-NEXT: vpxor %xmm1, %xmm1, %xmm1
; AVX2-NEXT: vpcmpeqw %ymm1, %ymm0, %ymm0
; AVX2-NEXT: vpacksswb %ymm0, %ymm1, %ymm0
; AVX2-NEXT: vpermq {{.*#+}} ymm0 = ymm0[0,0,2,3]
; AVX2-NEXT: vpermq {{.*#+}} ymm0 = ymm0[0,1,0,3]
; AVX2-NEXT: ret{{[l|q]}}
%1 = icmp eq <16 x i16> %a0, zeroinitializer
%2 = sext <16 x i1> %1 to <16 x i16>
Expand Down
35 changes: 35 additions & 0 deletions llvm/test/CodeGen/X86/pr50823.ll
@@ -0,0 +1,35 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc < %s -mtriple=x86_64-unknown -mcpu=core-avx2 | FileCheck %s

%v8_uniform_FVector3 = type { float, float, float }

declare <8 x float> @llvm.x86.avx.hadd.ps.256(<8 x float>, <8 x float>)

define void @foo(%v8_uniform_FVector3* %Out, float* %In, <8 x i32> %__mask) {
; CHECK-LABEL: foo:
; CHECK: # %bb.0: # %allocas
; CHECK-NEXT: vmovups (%rsi), %xmm0
; CHECK-NEXT: vhaddps 32(%rsi), %xmm0, %xmm0
; CHECK-NEXT: vpermpd {{.*#+}} ymm0 = ymm0[0,0,1,1]
; CHECK-NEXT: vhaddps %ymm0, %ymm0, %ymm0
; CHECK-NEXT: vextractf128 $1, %ymm0, %xmm1
; CHECK-NEXT: vaddss %xmm1, %xmm0, %xmm0
; CHECK-NEXT: vmovss %xmm0, (%rdi)
; CHECK-NEXT: vzeroupper
; CHECK-NEXT: retq
allocas:
%ptr_cast_for_load = bitcast float* %In to <8 x float>*
%ptr_masked_load74 = load <8 x float>, <8 x float>* %ptr_cast_for_load, align 4
%ptr8096 = getelementptr float, float* %In, i64 8
%ptr_cast_for_load81 = bitcast float* %ptr8096 to <8 x float>*
%ptr80_masked_load82 = load <8 x float>, <8 x float>* %ptr_cast_for_load81, align 4
%ret_7.i.i = shufflevector <8 x float> %ptr_masked_load74, <8 x float> %ptr80_masked_load82, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 8, i32 9, i32 10, i32 11>
%Out_load19 = getelementptr %v8_uniform_FVector3, %v8_uniform_FVector3* %Out, i64 0, i32 0
%v1.i.i100 = tail call <8 x float> @llvm.x86.avx.hadd.ps.256(<8 x float> %ret_7.i.i, <8 x float> %ret_7.i.i)
%v2.i.i101 = tail call <8 x float> @llvm.x86.avx.hadd.ps.256(<8 x float> %v1.i.i100, <8 x float> %v1.i.i100)
%scalar1.i.i102 = extractelement <8 x float> %v2.i.i101, i32 0
%scalar2.i.i103 = extractelement <8 x float> %v2.i.i101, i32 4
%sum.i.i104 = fadd float %scalar1.i.i102, %scalar2.i.i103
store float %sum.i.i104, float* %Out_load19, align 4
ret void
}

0 comments on commit 9ab99f7

Please sign in to comment.