Skip to content

Commit

Permalink
[AArch64] Convert concat(uhadd(a,b), uhadd(c,d)) to uhadd(concat(a,c)…
Browse files Browse the repository at this point in the history
…, concat(b,d)) (#80674)

We can convert concat(v4i16 uhadd(a,b), v4i16 uhadd(c,d)) to v8i16
uhadd(concat(a,c), concat(b,d)), which can lead to further
simplifications.
  • Loading branch information
Rin18 committed Feb 6, 2024
1 parent c302909 commit 7f292b8
Show file tree
Hide file tree
Showing 3 changed files with 190 additions and 104 deletions.
49 changes: 11 additions & 38 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18299,50 +18299,23 @@ static SDValue performConcatVectorsCombine(SDNode *N,
if (DCI.isBeforeLegalizeOps())
return SDValue();

// Optimise concat_vectors of two [us]avgceils or [us]avgfloors that use
// extracted subvectors from the same original vectors. Combine these into a
// single avg that operates on the two original vectors.
// avgceil is the target independant name for rhadd, avgfloor is a hadd.
// Example:
// (concat_vectors (v8i8 (avgceils (extract_subvector (v16i8 OpA, <0>),
// extract_subvector (v16i8 OpB, <0>))),
// (v8i8 (avgceils (extract_subvector (v16i8 OpA, <8>),
// extract_subvector (v16i8 OpB, <8>)))))
// ->
// (v16i8(avgceils(v16i8 OpA, v16i8 OpB)))
if (N->getNumOperands() == 2 && N0Opc == N1Opc &&
// Optimise concat_vectors of two [us]avgceils or [us]avgfloors with a 128-bit
// destination size, combine into an avg of two contacts of the source
// vectors. eg: concat(uhadd(a,b), uhadd(c, d)) -> uhadd(concat(a, c),
// concat(b, d))
if (N->getNumOperands() == 2 && N0Opc == N1Opc && VT.is128BitVector() &&
(N0Opc == ISD::AVGCEILU || N0Opc == ISD::AVGCEILS ||
N0Opc == ISD::AVGFLOORU || N0Opc == ISD::AVGFLOORS)) {
N0Opc == ISD::AVGFLOORU || N0Opc == ISD::AVGFLOORS) &&
N0->hasOneUse() && N1->hasOneUse()) {
SDValue N00 = N0->getOperand(0);
SDValue N01 = N0->getOperand(1);
SDValue N10 = N1->getOperand(0);
SDValue N11 = N1->getOperand(1);

EVT N00VT = N00.getValueType();
EVT N10VT = N10.getValueType();

if (N00->getOpcode() == ISD::EXTRACT_SUBVECTOR &&
N01->getOpcode() == ISD::EXTRACT_SUBVECTOR &&
N10->getOpcode() == ISD::EXTRACT_SUBVECTOR &&
N11->getOpcode() == ISD::EXTRACT_SUBVECTOR && N00VT == N10VT) {
SDValue N00Source = N00->getOperand(0);
SDValue N01Source = N01->getOperand(0);
SDValue N10Source = N10->getOperand(0);
SDValue N11Source = N11->getOperand(0);

if (N00Source == N10Source && N01Source == N11Source &&
N00Source.getValueType() == VT && N01Source.getValueType() == VT) {
assert(N0.getValueType() == N1.getValueType());

uint64_t N00Index = N00.getConstantOperandVal(1);
uint64_t N01Index = N01.getConstantOperandVal(1);
uint64_t N10Index = N10.getConstantOperandVal(1);
uint64_t N11Index = N11.getConstantOperandVal(1);

if (N00Index == N01Index && N10Index == N11Index && N00Index == 0 &&
N10Index == N00VT.getVectorNumElements())
return DAG.getNode(N0Opc, dl, VT, N00Source, N01Source);
}
if (!N00.isUndef() && !N01.isUndef() && !N10.isUndef() && !N11.isUndef()) {
SDValue Concat0 = DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, N00, N10);
SDValue Concat1 = DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, N01, N11);
return DAG.getNode(N0Opc, dl, VT, Concat0, Concat1);
}
}

Expand Down
93 changes: 27 additions & 66 deletions llvm/test/CodeGen/AArch64/avoid-pre-trunc.ll
Original file line number Diff line number Diff line change
@@ -1,75 +1,36 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
; RUN: llc -mtriple=aarch64 < %s | FileCheck %s

define i32 @lower_lshr(<4 x i32> %a, <4 x i32> %b, <4 x i32> %c, <4 x i32> %d, <4 x i32> %e, <4 x i32> %f, <4 x i32> %g, <4 x i32> %h) {
; CHECK-LABEL: lower_lshr:
; CHECK: // %bb.0:
; CHECK-NEXT: addv s0, v0.4s
; CHECK-NEXT: addv s1, v1.4s
; CHECK-NEXT: addv s4, v4.4s
; CHECK-NEXT: addv s5, v5.4s
; CHECK-NEXT: addv s2, v2.4s
; CHECK-NEXT: addv s6, v6.4s
; CHECK-NEXT: mov v0.s[1], v1.s[0]
; CHECK-NEXT: addv s1, v3.4s
; CHECK-NEXT: addv s3, v7.4s
; CHECK-NEXT: mov v4.s[1], v5.s[0]
; CHECK-NEXT: mov v0.s[2], v2.s[0]
; CHECK-NEXT: mov v4.s[2], v6.s[0]
; CHECK-NEXT: mov v0.s[3], v1.s[0]
; CHECK-NEXT: mov v4.s[3], v3.s[0]
; CHECK-NEXT: xtn v1.4h, v0.4s
; CHECK-NEXT: shrn v0.4h, v0.4s, #16
; CHECK-NEXT: xtn v2.4h, v4.4s
; CHECK-NEXT: shrn v3.4h, v4.4s, #16
; CHECK-NEXT: uhadd v0.4h, v1.4h, v0.4h
; CHECK-NEXT: uhadd v1.4h, v2.4h, v3.4h
; CHECK-NEXT: mov v0.d[1], v1.d[0]
; CHECK-NEXT: uaddlv s0, v0.8h
; CHECK-NEXT: fmov w0, s0
; CHECK-NEXT: ret
%l87 = tail call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %a)
%l174 = tail call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %b)
%l257 = tail call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %c)
%l340 = tail call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %d)
%l427 = tail call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %e)
%l514 = tail call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %f)
%l597 = tail call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %g)
%l680 = tail call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %h)
%l681 = insertelement <8 x i32> poison, i32 %l87, i32 0
%l682 = insertelement <8 x i32> %l681, i32 %l174, i32 1
%l683 = insertelement <8 x i32> %l682, i32 %l257, i32 2
%l684 = insertelement <8 x i32> %l683, i32 %l340, i32 3
%l685 = insertelement <8 x i32> %l684, i32 %l427, i32 4
%l686 = insertelement <8 x i32> %l685, i32 %l514, i32 5
%l687 = insertelement <8 x i32> %l686, i32 %l597, i32 6
%l688 = insertelement <8 x i32> %l687, i32 %l680, i32 7
%l689 = and <8 x i32> %l688, <i32 65535, i32 65535, i32 65535, i32 65535, i32 65535, i32 65535, i32 65535, i32 65535>
%l690 = lshr <8 x i32> %l688, <i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16>
%l691 = add nuw nsw <8 x i32> %l689, %l690
%l692 = lshr <8 x i32> %l691, <i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1>
%l693 = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> %l692)
ret i32 %l693
}
declare i32 @llvm.vector.reduce.add.v4i32(<4 x i32>)
declare i32 @llvm.vector.reduce.add.v8i32(<8 x i32>)

define <16 x i8> @lower_trunc_16xi8(i16 %a, i16 %b, i16 %c, i16 %d, i16 %e, i16 %f, i16 %g, i16 %h, i16 %i, i16 %j, i16 %k, i16 %l, i16 %m, i16 %n, i16 %o, i16 %p) {
; CHECK-LABEL: lower_trunc_16xi8:
; CHECK: // %bb.0:
; CHECK-NEXT: fmov s0, w0
; CHECK-NEXT: add x8, sp, #56
; CHECK-NEXT: ld1r { v1.8h }, [x8]
; CHECK-NEXT: ldr h1, [sp]
; CHECK-NEXT: add x8, sp, #8
; CHECK-NEXT: ld1 { v1.h }[1], [x8]
; CHECK-NEXT: add x8, sp, #16
; CHECK-NEXT: mov v0.h[1], w1
; CHECK-NEXT: add v3.8h, v1.8h, v1.8h
; CHECK-NEXT: ld1 { v1.h }[2], [x8]
; CHECK-NEXT: add x8, sp, #24
; CHECK-NEXT: mov v0.h[2], w2
; CHECK-NEXT: ld1 { v1.h }[3], [x8]
; CHECK-NEXT: add x8, sp, #32
; CHECK-NEXT: mov v0.h[3], w3
; CHECK-NEXT: ld1 { v1.h }[4], [x8]
; CHECK-NEXT: add x8, sp, #40
; CHECK-NEXT: ld1 { v1.h }[5], [x8]
; CHECK-NEXT: add x8, sp, #48
; CHECK-NEXT: mov v0.h[4], w4
; CHECK-NEXT: ld1 { v1.h }[6], [x8]
; CHECK-NEXT: add x8, sp, #56
; CHECK-NEXT: mov v0.h[5], w5
; CHECK-NEXT: ld1 { v1.h }[7], [x8]
; CHECK-NEXT: mov v0.h[6], w6
; CHECK-NEXT: add v2.8h, v0.8h, v0.8h
; CHECK-NEXT: add v2.8h, v1.8h, v1.8h
; CHECK-NEXT: mov v0.h[7], w7
; CHECK-NEXT: add v3.8h, v0.8h, v0.8h
; CHECK-NEXT: uzp1 v0.16b, v0.16b, v1.16b
; CHECK-NEXT: uzp1 v1.16b, v2.16b, v3.16b
; CHECK-NEXT: uzp1 v1.16b, v3.16b, v2.16b
; CHECK-NEXT: eor v0.16b, v0.16b, v1.16b
; CHECK-NEXT: ret
%a1 = insertelement <16 x i16> poison, i16 %a, i16 0
Expand All @@ -80,14 +41,14 @@ define <16 x i8> @lower_trunc_16xi8(i16 %a, i16 %b, i16 %c, i16 %d, i16 %e, i16
%f1 = insertelement <16 x i16> %e1, i16 %f, i16 5
%g1 = insertelement <16 x i16> %f1, i16 %g, i16 6
%h1 = insertelement <16 x i16> %g1, i16 %h, i16 7
%i1 = insertelement <16 x i16> %f1, i16 %i, i16 8
%j1 = insertelement <16 x i16> %g1, i16 %j, i16 9
%k1 = insertelement <16 x i16> %f1, i16 %k, i16 10
%l1 = insertelement <16 x i16> %g1, i16 %l, i16 11
%m1 = insertelement <16 x i16> %f1, i16 %m, i16 12
%n1 = insertelement <16 x i16> %g1, i16 %n, i16 13
%o1 = insertelement <16 x i16> %f1, i16 %o, i16 14
%p1 = insertelement <16 x i16> %g1, i16 %p, i16 15
%i1 = insertelement <16 x i16> %h1, i16 %i, i16 8
%j1 = insertelement <16 x i16> %i1, i16 %j, i16 9
%k1 = insertelement <16 x i16> %j1, i16 %k, i16 10
%l1 = insertelement <16 x i16> %k1, i16 %l, i16 11
%m1 = insertelement <16 x i16> %l1, i16 %m, i16 12
%n1 = insertelement <16 x i16> %m1, i16 %n, i16 13
%o1 = insertelement <16 x i16> %n1, i16 %o, i16 14
%p1 = insertelement <16 x i16> %o1, i16 %p, i16 15
%t = trunc <16 x i16> %p1 to <16 x i8>
%s = add <16 x i16> %p1, %p1
%t2 = trunc <16 x i16> %s to <16 x i8>
Expand Down
152 changes: 152 additions & 0 deletions llvm/test/CodeGen/AArch64/concat-vector-add-combine.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
; RUN: llc -mtriple=aarch64 < %s | FileCheck %s

define i16 @combine_add_16xi16(i16 %a, i16 %b, i16 %c, i16 %d, i16 %e, i16 %f, i16 %g, i16 %h, i16 %i, i16 %j, i16 %k, i16 %l, i16 %m, i16 %n, i16 %o, i16 %p) {
; CHECK-LABEL: combine_add_16xi16:
; CHECK: // %bb.0:
; CHECK-NEXT: fmov s0, w0
; CHECK-NEXT: ldr h1, [sp]
; CHECK-NEXT: add x8, sp, #8
; CHECK-NEXT: ld1 { v1.h }[1], [x8]
; CHECK-NEXT: add x8, sp, #16
; CHECK-NEXT: mov v0.h[1], w1
; CHECK-NEXT: ld1 { v1.h }[2], [x8]
; CHECK-NEXT: add x8, sp, #24
; CHECK-NEXT: mov v0.h[2], w2
; CHECK-NEXT: ld1 { v1.h }[3], [x8]
; CHECK-NEXT: add x8, sp, #32
; CHECK-NEXT: mov v0.h[3], w3
; CHECK-NEXT: ld1 { v1.h }[4], [x8]
; CHECK-NEXT: add x8, sp, #40
; CHECK-NEXT: ld1 { v1.h }[5], [x8]
; CHECK-NEXT: add x8, sp, #48
; CHECK-NEXT: mov v0.h[4], w4
; CHECK-NEXT: ld1 { v1.h }[6], [x8]
; CHECK-NEXT: add x8, sp, #56
; CHECK-NEXT: mov v0.h[5], w5
; CHECK-NEXT: ld1 { v1.h }[7], [x8]
; CHECK-NEXT: mov v0.h[6], w6
; CHECK-NEXT: mov v0.h[7], w7
; CHECK-NEXT: uzp2 v2.16b, v0.16b, v1.16b
; CHECK-NEXT: uzp1 v0.16b, v0.16b, v1.16b
; CHECK-NEXT: uhadd v0.16b, v0.16b, v2.16b
; CHECK-NEXT: uaddlv h0, v0.16b
; CHECK-NEXT: umov w0, v0.h[0]
; CHECK-NEXT: ret
%a1 = insertelement <16 x i16> poison, i16 %a, i16 0
%b1 = insertelement <16 x i16> %a1, i16 %b, i16 1
%c1 = insertelement <16 x i16> %b1, i16 %c, i16 2
%d1 = insertelement <16 x i16> %c1, i16 %d, i16 3
%e1 = insertelement <16 x i16> %d1, i16 %e, i16 4
%f1 = insertelement <16 x i16> %e1, i16 %f, i16 5
%g1 = insertelement <16 x i16> %f1, i16 %g, i16 6
%h1 = insertelement <16 x i16> %g1, i16 %h, i16 7
%i1 = insertelement <16 x i16> %h1, i16 %i, i16 8
%j1 = insertelement <16 x i16> %i1, i16 %j, i16 9
%k1 = insertelement <16 x i16> %j1, i16 %k, i16 10
%l1 = insertelement <16 x i16> %k1, i16 %l, i16 11
%m1 = insertelement <16 x i16> %l1, i16 %m, i16 12
%n1 = insertelement <16 x i16> %m1, i16 %n, i16 13
%o1 = insertelement <16 x i16> %n1, i16 %o, i16 14
%p1 = insertelement <16 x i16> %o1, i16 %p, i16 15
%x = and <16 x i16> %p1, <i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255>
%sh1 = lshr <16 x i16> %p1, <i16 8, i16 8, i16 8, i16 8, i16 8, i16 8, i16 8, i16 8, i16 8, i16 8, i16 8, i16 8, i16 8, i16 8, i16 8, i16 8>
%s = add nuw nsw <16 x i16> %x, %sh1
%sh2 = lshr <16 x i16> %s, <i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1>
%res = call i16 @llvm.vector.reduce.add.v16i16(<16 x i16> %sh2)
ret i16 %res
}

define i32 @combine_add_8xi32(i32 %a, i32 %b, i32 %c, i32 %d, i32 %e, i32 %f, i32 %g, i32 %h) local_unnamed_addr #0 {
; CHECK-LABEL: combine_add_8xi32:
; CHECK: // %bb.0:
; CHECK-NEXT: fmov s0, w4
; CHECK-NEXT: fmov s1, w0
; CHECK-NEXT: mov v0.s[1], w5
; CHECK-NEXT: mov v1.s[1], w1
; CHECK-NEXT: mov v0.s[2], w6
; CHECK-NEXT: mov v1.s[2], w2
; CHECK-NEXT: mov v0.s[3], w7
; CHECK-NEXT: mov v1.s[3], w3
; CHECK-NEXT: uzp2 v2.8h, v1.8h, v0.8h
; CHECK-NEXT: uzp1 v0.8h, v1.8h, v0.8h
; CHECK-NEXT: uhadd v0.8h, v0.8h, v2.8h
; CHECK-NEXT: uaddlv s0, v0.8h
; CHECK-NEXT: fmov w0, s0
; CHECK-NEXT: ret
%a1 = insertelement <8 x i32> poison, i32 %a, i32 0
%b1 = insertelement <8 x i32> %a1, i32 %b, i32 1
%c1 = insertelement <8 x i32> %b1, i32 %c, i32 2
%d1 = insertelement <8 x i32> %c1, i32 %d, i32 3
%e1 = insertelement <8 x i32> %d1, i32 %e, i32 4
%f1 = insertelement <8 x i32> %e1, i32 %f, i32 5
%g1 = insertelement <8 x i32> %f1, i32 %g, i32 6
%h1 = insertelement <8 x i32> %g1, i32 %h, i32 7
%x = and <8 x i32> %h1, <i32 65535, i32 65535, i32 65535, i32 65535, i32 65535, i32 65535, i32 65535, i32 65535>
%sh1 = lshr <8 x i32> %h1, <i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16>
%s = add nuw nsw <8 x i32> %x, %sh1
%sh2 = lshr <8 x i32> %s, <i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1>
%res = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> %sh2)
ret i32 %res
}

define i32 @combine_undef_add_8xi32(i32 %a, i32 %b, i32 %c, i32 %d) local_unnamed_addr #0 {
; CHECK-LABEL: combine_undef_add_8xi32:
; CHECK: // %bb.0:
; CHECK-NEXT: fmov s1, w0
; CHECK-NEXT: movi v0.2d, #0000000000000000
; CHECK-NEXT: mov v1.s[1], w1
; CHECK-NEXT: uhadd v0.4h, v0.4h, v0.4h
; CHECK-NEXT: mov v1.s[2], w2
; CHECK-NEXT: mov v1.s[3], w3
; CHECK-NEXT: xtn v2.4h, v1.4s
; CHECK-NEXT: shrn v1.4h, v1.4s, #16
; CHECK-NEXT: uhadd v1.4h, v2.4h, v1.4h
; CHECK-NEXT: mov v1.d[1], v0.d[0]
; CHECK-NEXT: uaddlv s0, v1.8h
; CHECK-NEXT: fmov w0, s0
; CHECK-NEXT: ret
%a1 = insertelement <8 x i32> poison, i32 %a, i32 0
%b1 = insertelement <8 x i32> %a1, i32 %b, i32 1
%c1 = insertelement <8 x i32> %b1, i32 %c, i32 2
%d1 = insertelement <8 x i32> %c1, i32 %d, i32 3
%e1 = insertelement <8 x i32> %d1, i32 undef, i32 4
%f1 = insertelement <8 x i32> %e1, i32 undef, i32 5
%g1 = insertelement <8 x i32> %f1, i32 undef, i32 6
%h1 = insertelement <8 x i32> %g1, i32 undef, i32 7
%x = and <8 x i32> %h1, <i32 65535, i32 65535, i32 65535, i32 65535, i32 65535, i32 65535, i32 65535, i32 65535>
%sh1 = lshr <8 x i32> %h1, <i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16, i32 16>
%s = add nuw nsw <8 x i32> %x, %sh1
%sh2 = lshr <8 x i32> %s, <i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1>
%res = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> %sh2)
ret i32 %res
}

define i64 @combine_add_4xi64(i64 %a, i64 %b, i64 %c, i64 %d) local_unnamed_addr #0 {
; CHECK-LABEL: combine_add_4xi64:
; CHECK: // %bb.0:
; CHECK-NEXT: fmov d0, x2
; CHECK-NEXT: fmov d1, x0
; CHECK-NEXT: mov v0.d[1], x3
; CHECK-NEXT: mov v1.d[1], x1
; CHECK-NEXT: uzp2 v2.4s, v1.4s, v0.4s
; CHECK-NEXT: uzp1 v0.4s, v1.4s, v0.4s
; CHECK-NEXT: uhadd v0.4s, v0.4s, v2.4s
; CHECK-NEXT: uaddlv d0, v0.4s
; CHECK-NEXT: fmov x0, d0
; CHECK-NEXT: ret
%a1 = insertelement <4 x i64> poison, i64 %a, i64 0
%b1 = insertelement <4 x i64> %a1, i64 %b, i64 1
%c1 = insertelement <4 x i64> %b1, i64 %c, i64 2
%d1 = insertelement <4 x i64> %c1, i64 %d, i64 3
%x = and <4 x i64> %d1, <i64 4294967295, i64 4294967295, i64 4294967295, i64 4294967295>
%sh1 = lshr <4 x i64> %d1, <i64 32, i64 32, i64 32, i64 32>
%s = add nuw nsw <4 x i64> %x, %sh1
%sh2 = lshr <4 x i64> %s, <i64 1, i64 1, i64 1, i64 1>
%res = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> %sh2)
ret i64 %res
}

declare i16 @llvm.vector.reduce.add.v16i16(<16 x i16>)
declare i32 @llvm.vector.reduce.add.v8i32(<8 x i32>)
declare i64 @llvm.vector.reduce.add.v4i64(<4 x i64>)

0 comments on commit 7f292b8

Please sign in to comment.