Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22790,6 +22790,62 @@ static SDValue combineSVEBitSel(unsigned IID, SDNode *N, SelectionDAG &DAG) {
}
}

/// Optimize patterns where we insert zeros into vector lanes before faddv.
static SDValue tryCombineFADDVWithZero(SDNode *N, SelectionDAG &DAG) {
assert(getIntrinsicID(N) == Intrinsic::aarch64_neon_faddv &&
"Expected NEON faddv intrinsic");
SDLoc DL(N);
SDValue Vec = N->getOperand(1);
EVT VT = Vec.getValueType();
EVT EltVT = VT.getVectorElementType();
unsigned NumElts = VT.getVectorNumElements();
APInt DemandedElts = APInt::getAllOnes(NumElts);
APInt KnownZeroElts = DAG.computeVectorKnownZeroElements(Vec, DemandedElts);
unsigned NumZeroElts = KnownZeroElts.popcount();
// No element is known to be +0.0, fallback to the TableGen pattern.
if (NumZeroElts == 0)
return SDValue();
// All elements are +0.0, just return zero.
if (NumZeroElts == NumElts)
return DAG.getConstantFP(0.0, DL, EltVT);

// At least one element is +0.0, so it is worth to decompose the reduction
// into fadd's. FADDV is a pairwise reduction, so we need to respect the
// order of the elements in the vector.

// Check if we can output a signed zero.
// This avoid the scenario where all the added values are -0.0 except the +0.0
// element we chose to ignore.
SDNodeFlags Flags = N->getFlags();
bool IsSignedZeroSafe = Flags.hasNoSignedZeros() ||
DAG.allUsesSignedZeroInsensitive(SDValue(N, 0));
if (!IsSignedZeroSafe)
return SDValue();

// Extract all elements.
SmallVector<SDValue, 4> Elts;
for (unsigned I = 0; I < NumElts; I++) {
Elts.push_back(DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Vec,
DAG.getConstant(I, DL, MVT::i64)));
}
// Perform pairwise reduction.
while (Elts.size() > 1) {
SmallVector<SDValue, 2> NewElts;
for (unsigned I = 0; I < Elts.size(); I += 2) {
if (!KnownZeroElts[I] && !KnownZeroElts[I + 1]) {
NewElts.push_back(
DAG.getNode(ISD::FADD, DL, EltVT, Elts[I], Elts[I + 1]));
} else if (KnownZeroElts[I]) {
NewElts.push_back(Elts[I + 1]);
} else if (KnownZeroElts[I + 1]) {
NewElts.push_back(Elts[I]);
}
}
Elts = std::move(NewElts);
}
return Elts[0];
}

static SDValue performIntrinsicCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
const AArch64Subtarget *Subtarget) {
Expand All @@ -22813,6 +22869,8 @@ static SDValue performIntrinsicCombine(SDNode *N,
return combineAcrossLanesIntrinsic(AArch64ISD::SMAXV, N, DAG);
case Intrinsic::aarch64_neon_umaxv:
return combineAcrossLanesIntrinsic(AArch64ISD::UMAXV, N, DAG);
case Intrinsic::aarch64_neon_faddv:
return tryCombineFADDVWithZero(N, DAG);
case Intrinsic::aarch64_neon_fmax:
return DAG.getNode(ISD::FMAXIMUM, SDLoc(N), N->getValueType(0),
N->getOperand(1), N->getOperand(2));
Expand Down
82 changes: 82 additions & 0 deletions llvm/test/CodeGen/AArch64/faddv.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc -mtriple=aarch64 < %s | FileCheck %s

; Test element at index 0 is zero.
define float @test_v2f32_element_0_zero(<2 x float> %vec) {
; CHECK-LABEL: test_v2f32_element_0_zero:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: // kill: def $d0 killed $d0 def $q0
; CHECK-NEXT: mov s0, v0.s[1]
; CHECK-NEXT: ret
entry:
%with_zero = insertelement <2 x float> %vec, float 0.0, i64 0
%sum = call nsz float @llvm.aarch64.neon.faddv.f32.v2f32(<2 x float> %with_zero)
ret float %sum
}

; Test element at index 3 is zero.
define float @test_v4f32_element_3_zero(<4 x float> %vec) {
; CHECK-LABEL: test_v4f32_element_3_zero:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: mov s1, v0.s[2]
; CHECK-NEXT: faddp s0, v0.2s
; CHECK-NEXT: fadd s0, s0, s1
; CHECK-NEXT: fabs s0, s0
; CHECK-NEXT: ret
entry:
%with_zero = insertelement <4 x float> %vec, float 0.0, i64 3
%sum = call float @llvm.aarch64.neon.faddv.f32.v4f32(<4 x float> %with_zero)
%abs = call float @llvm.fabs.f32(float %sum)
ret float %abs
}

; Test elements at index 0 and 2 are zero.
define float @test_v4f32_elements_0_2_zero(<4 x float> %vec) {
; CHECK-LABEL: test_v4f32_elements_0_2_zero:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: mov s0, v0.s[3]
; CHECK-NEXT: fabs s0, s0
; CHECK-NEXT: ret
entry:
%zero1 = insertelement <4 x float> %vec, float 0.0, i64 0
%zero2 = insertelement <4 x float> %zero1, float 0.0, i64 2
%sum = call float @llvm.aarch64.neon.faddv.f32.v4f32(<4 x float> %zero2)
%abs = call float @llvm.fabs.f32(float %sum)
ret float %abs
}

; Test all elements are zero.
define float @test_v4f32_all_zero(<4 x float> %vec) {
; CHECK-LABEL: test_v4f32_all_zero:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: movi d0, #0000000000000000
; CHECK-NEXT: ret
entry:
%zero1 = insertelement <4 x float> %vec, float 0.0, i64 0
%zero2 = insertelement <4 x float> %zero1, float 0.0, i64 1
%zero3 = insertelement <4 x float> %zero2, float 0.0, i64 2
%zero4 = insertelement <4 x float> %zero3, float 0.0, i64 3
%sum = call float @llvm.aarch64.neon.faddv.f32.v4f32(<4 x float> %zero4)
ret float %sum
}

; Test element at index 0 is zero.
define double @test_v2f64_element_0_zero(<2 x double> %vec) {
; CHECK-LABEL: test_v2f64_element_0_zero:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: mov d0, v0.d[1]
; CHECK-NEXT: fabs d0, d0
; CHECK-NEXT: ret
entry:
%with_zero = insertelement <2 x double> %vec, double 0.0, i64 0
%sum = call double @llvm.aarch64.neon.faddv.f64.v2f64(<2 x double> %with_zero)
%abs = call double @llvm.fabs.f64(double %sum)
ret double %abs
}

declare float @llvm.fabs.f32(float)
declare double @llvm.fabs.f64(double)

declare float @llvm.aarch64.neon.faddv.f32.v2f32(<2 x float>)
declare float @llvm.aarch64.neon.faddv.f32.v4f32(<4 x float>)
declare double @llvm.aarch64.neon.faddv.f64.v2f64(<2 x double>)
Loading