Skip to content

Commit

Permalink
[AArch64][SVE] Fold fadda(ptrue, x, select(mask, y, -0.0)) into fadda…
Browse files Browse the repository at this point in the history
…(mask, x, y)

This patch adds an SVE pattern to recognize the use of a select with an
fadda in the form fadda(ptrue, x, select(mask, y, -0.0)). In this case
the select can be folded away, with the select mask used as the
predicate for fadda: -0.0 is the identity value for fadd, so adding -0.0
for a lane is equivalent to leaving that lane inactive. This improves the
codegen when vectorizing loops with ordered fp reductions.

Differential Revision: https://reviews.llvm.org/D129623
  • Loading branch information
RosieSumpter committed Jul 19, 2022
1 parent 106d695 commit 05d424d
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 4 deletions.
4 changes: 4 additions & 0 deletions llvm/lib/Target/AArch64/AArch64InstrFormats.td
Expand Up @@ -1234,6 +1234,10 @@ def fpimm0 : FPImmLeaf<fAny, [{
return Imm.isExactlyValue(+0.0);
}]>;

// Floating-point immediate leaf (any FP type, via fAny) that matches only when
// the immediate is exactly -0.0. Counterpart of fpimm0, which tests for +0.0.
def fpimm_minus0 : FPImmLeaf<fAny, [{
return Imm.isExactlyValue(-0.0);
}]>;

def fpimm_half : FPImmLeaf<fAny, [{
return Imm.isExactlyValue(+0.5);
}]>;
Expand Down
16 changes: 12 additions & 4 deletions llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
Expand Up @@ -278,10 +278,18 @@ def AArch64scvtf_mt : SDNode<"AArch64ISD::SINT_TO_FP_MERGE_PASSTHRU", SDT_AArch
def AArch64fcvtzu_mt : SDNode<"AArch64ISD::FCVTZU_MERGE_PASSTHRU", SDT_AArch64FCVT>;
def AArch64fcvtzs_mt : SDNode<"AArch64ISD::FCVTZS_MERGE_PASSTHRU", SDT_AArch64FCVT>;

def SDT_AArch64ReduceWithInit : SDTypeProfile<1, 3, [SDTCisVec<1>, SDTCisVec<3>]>;
def AArch64clasta_n : SDNode<"AArch64ISD::CLASTA_N", SDT_AArch64ReduceWithInit>;
def AArch64clastb_n : SDNode<"AArch64ISD::CLASTB_N", SDT_AArch64ReduceWithInit>;
def AArch64fadda_p : SDNode<"AArch64ISD::FADDA_PRED", SDT_AArch64ReduceWithInit>;
// Type profile shared by the three nodes below: one result and three operands.
// Constraint indices count the result as 0, so index 1 is the first operand
// and index 3 is the third operand. The first operand must be a vector with
// i1 elements, the third operand must be a vector, and the two must have the
// same number of elements.
def SDT_AArch64ReduceWithInit : SDTypeProfile<1, 3,
[SDTCisVec<1>, SDTCVecEltisVT<1,i1>, SDTCisVec<3>, SDTCisSameNumEltsAs<1,3>]>;
def AArch64clasta_n : SDNode<"AArch64ISD::CLASTA_N", SDT_AArch64ReduceWithInit>;
def AArch64clastb_n : SDNode<"AArch64ISD::CLASTB_N", SDT_AArch64ReduceWithInit>;
// Raw FADDA_PRED node. It carries the _node suffix so that the AArch64fadda_p
// name can be used for a PatFrags that wraps this node.
def AArch64fadda_p_node : SDNode<"AArch64ISD::FADDA_PRED", SDT_AArch64ReduceWithInit>;

// AArch64fadda_p matches any of three forms ($op1, $op2, $op3 are the
// fragment's operands):
//   1. the plain node: fadda_p_node($op1, $op2, $op3);
//   2. fadda_p_node(all-active predicate, $op2,
//                   vselect($op1, $op3, splat(f32 -0.0)));
//   3. the same shape as (2) with an f64 -0.0 splat.
// In forms (2) and (3) the vselect mask is bound to $op1 and the selected
// vector to $op3, so a pattern written against AArch64fadda_p sees the select
// mask as the predicate and the select itself is folded away.
// Only f32 and f64 splats are listed; there is no f16 alternative here.
def AArch64fadda_p : PatFrags<(ops node:$op1, node:$op2, node:$op3),
[(AArch64fadda_p_node node:$op1, node:$op2, node:$op3),
(AArch64fadda_p_node (SVEAllActive), node:$op2,
(vselect node:$op1, node:$op3, (splat_vector (f32 fpimm_minus0)))),
(AArch64fadda_p_node (SVEAllActive), node:$op2,
(vselect node:$op1, node:$op3, (splat_vector (f64 fpimm_minus0))))]>;

def SDT_AArch64PTest : SDTypeProfile<0, 2, [SDTCisVec<0>, SDTCisSameAs<0,1>]>;
def AArch64ptest : SDNode<"AArch64ISD::PTEST", SDT_AArch64PTest>;
Expand Down
112 changes: 112 additions & 0 deletions llvm/test/CodeGen/AArch64/sve-fadda-select.ll
@@ -0,0 +1,112 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve < %s | FileCheck %s

; Fold fadda(ptrue, x, select(mask, y, -0.0)) -> fadda(mask, x, y)

; <vscale x 2 x float>: ordered fadd reduction of select(%mask, %y, splat(-0.0)).
; The expected output contains no sel: a single fadda predicated on p0 (the
; incoming %mask), i.e. the fold applies.
define float @pred_fadda_nxv2f32(float %x, <vscale x 2 x float> %y, <vscale x 2 x i1> %mask) {
; CHECK-LABEL: pred_fadda_nxv2f32:
; CHECK: // %bb.0:
; CHECK-NEXT: // kill: def $s0 killed $s0 def $z0
; CHECK-NEXT: fadda s0, p0, s0, z1.s
; CHECK-NEXT: // kill: def $s0 killed $s0 killed $z0
; CHECK-NEXT: ret
%i = insertelement <vscale x 2 x float> poison, float -0.000000e+00, i32 0
%minus0 = shufflevector <vscale x 2 x float> %i, <vscale x 2 x float> poison, <vscale x 2 x i32> zeroinitializer
%sel = select <vscale x 2 x i1> %mask, <vscale x 2 x float> %y, <vscale x 2 x float> %minus0
%fadda = call float @llvm.vector.reduce.fadd.nxv2f32(float %x, <vscale x 2 x float> %sel)
ret float %fadda
}

; <vscale x 4 x float>: ordered fadd reduction of select(%mask, %y, splat(-0.0)).
; The expected output contains no sel: a single fadda predicated on p0 (the
; incoming %mask), i.e. the fold applies.
define float @pred_fadda_nxv4f32(float %x, <vscale x 4 x float> %y, <vscale x 4 x i1> %mask) {
; CHECK-LABEL: pred_fadda_nxv4f32:
; CHECK: // %bb.0:
; CHECK-NEXT: // kill: def $s0 killed $s0 def $z0
; CHECK-NEXT: fadda s0, p0, s0, z1.s
; CHECK-NEXT: // kill: def $s0 killed $s0 killed $z0
; CHECK-NEXT: ret
%i = insertelement <vscale x 4 x float> poison, float -0.000000e+00, i32 0
%minus0 = shufflevector <vscale x 4 x float> %i, <vscale x 4 x float> poison, <vscale x 4 x i32> zeroinitializer
%sel = select <vscale x 4 x i1> %mask, <vscale x 4 x float> %y, <vscale x 4 x float> %minus0
%fadda = call float @llvm.vector.reduce.fadd.nxv4f32(float %x, <vscale x 4 x float> %sel)
ret float %fadda
}

; <vscale x 2 x double>: ordered fadd reduction of select(%mask, %y, splat(-0.0)).
; The expected output contains no sel: a single fadda predicated on p0 (the
; incoming %mask), i.e. the fold applies.
define double @pred_fadda_nxv2f64(double %x, <vscale x 2 x double> %y, <vscale x 2 x i1> %mask) {
; CHECK-LABEL: pred_fadda_nxv2f64:
; CHECK: // %bb.0:
; CHECK-NEXT: // kill: def $d0 killed $d0 def $z0
; CHECK-NEXT: fadda d0, p0, d0, z1.d
; CHECK-NEXT: // kill: def $d0 killed $d0 killed $z0
; CHECK-NEXT: ret
%i = insertelement <vscale x 2 x double> poison, double -0.000000e+00, i32 0
%minus0 = shufflevector <vscale x 2 x double> %i, <vscale x 2 x double> poison, <vscale x 2 x i32> zeroinitializer
%sel = select <vscale x 2 x i1> %mask, <vscale x 2 x double> %y, <vscale x 2 x double> %minus0
%fadda = call double @llvm.vector.reduce.fadd.nxv2f64(double %x, <vscale x 2 x double> %sel)
ret double %fadda
}

; Currently the folding doesn't work for f16 element types, since -0.0 is not treated as a legal f16 immediate.

; <vscale x 2 x half>: same IR shape as the f32/f64 tests, but the fold does not
; apply. The expected output loads the -0.0 splat (ld1rh from .LCPI3_0), keeps
; the sel on p0, and runs fadda under the all-true predicate p1.
define half @pred_fadda_nxv2f16(half %x, <vscale x 2 x half> %y, <vscale x 2 x i1> %mask) {
; CHECK-LABEL: pred_fadda_nxv2f16:
; CHECK: // %bb.0:
; CHECK-NEXT: adrp x8, .LCPI3_0
; CHECK-NEXT: add x8, x8, :lo12:.LCPI3_0
; CHECK-NEXT: ptrue p1.d
; CHECK-NEXT: // kill: def $h0 killed $h0 def $z0
; CHECK-NEXT: ld1rh { z2.d }, p1/z, [x8]
; CHECK-NEXT: sel z1.d, p0, z1.d, z2.d
; CHECK-NEXT: fadda h0, p1, h0, z1.h
; CHECK-NEXT: // kill: def $h0 killed $h0 killed $z0
; CHECK-NEXT: ret
%i = insertelement <vscale x 2 x half> poison, half -0.000000e+00, i32 0
%minus0 = shufflevector <vscale x 2 x half> %i, <vscale x 2 x half> poison, <vscale x 2 x i32> zeroinitializer
%sel = select <vscale x 2 x i1> %mask, <vscale x 2 x half> %y, <vscale x 2 x half> %minus0
%fadda = call half @llvm.vector.reduce.fadd.nxv2f16(half %x, <vscale x 2 x half> %sel)
ret half %fadda
}

; <vscale x 4 x half>: same IR shape as the f32/f64 tests, but the fold does not
; apply. The expected output loads the -0.0 splat (ld1rh from .LCPI4_0), keeps
; the sel on p0, and runs fadda under the all-true predicate p1.
define half @pred_fadda_nxv4f16(half %x, <vscale x 4 x half> %y, <vscale x 4 x i1> %mask) {
; CHECK-LABEL: pred_fadda_nxv4f16:
; CHECK: // %bb.0:
; CHECK-NEXT: adrp x8, .LCPI4_0
; CHECK-NEXT: add x8, x8, :lo12:.LCPI4_0
; CHECK-NEXT: ptrue p1.s
; CHECK-NEXT: // kill: def $h0 killed $h0 def $z0
; CHECK-NEXT: ld1rh { z2.s }, p1/z, [x8]
; CHECK-NEXT: sel z1.s, p0, z1.s, z2.s
; CHECK-NEXT: fadda h0, p1, h0, z1.h
; CHECK-NEXT: // kill: def $h0 killed $h0 killed $z0
; CHECK-NEXT: ret
%i = insertelement <vscale x 4 x half> poison, half -0.000000e+00, i32 0
%minus0 = shufflevector <vscale x 4 x half> %i, <vscale x 4 x half> poison, <vscale x 4 x i32> zeroinitializer
%sel = select <vscale x 4 x i1> %mask, <vscale x 4 x half> %y, <vscale x 4 x half> %minus0
%fadda = call half @llvm.vector.reduce.fadd.nxv4f16(half %x, <vscale x 4 x half> %sel)
ret half %fadda
}

; <vscale x 8 x half>: same IR shape as the f32/f64 tests, but the fold does not
; apply. The expected output loads the -0.0 splat (ld1rh from .LCPI5_0), keeps
; the sel on p0, and runs fadda under the all-true predicate p1.
define half @pred_fadda_nxv8f16(half %x, <vscale x 8 x half> %y, <vscale x 8 x i1> %mask) {
; CHECK-LABEL: pred_fadda_nxv8f16:
; CHECK: // %bb.0:
; CHECK-NEXT: adrp x8, .LCPI5_0
; CHECK-NEXT: add x8, x8, :lo12:.LCPI5_0
; CHECK-NEXT: ptrue p1.h
; CHECK-NEXT: // kill: def $h0 killed $h0 def $z0
; CHECK-NEXT: ld1rh { z2.h }, p1/z, [x8]
; CHECK-NEXT: sel z1.h, p0, z1.h, z2.h
; CHECK-NEXT: fadda h0, p1, h0, z1.h
; CHECK-NEXT: // kill: def $h0 killed $h0 killed $z0
; CHECK-NEXT: ret
%i = insertelement <vscale x 8 x half> poison, half -0.000000e+00, i32 0
%minus0 = shufflevector <vscale x 8 x half> %i, <vscale x 8 x half> poison, <vscale x 8 x i32> zeroinitializer
%sel = select <vscale x 8 x i1> %mask, <vscale x 8 x half> %y, <vscale x 8 x half> %minus0
%fadda = call half @llvm.vector.reduce.fadd.nxv8f16(half %x, <vscale x 8 x half> %sel)
ret half %fadda
}

declare float @llvm.vector.reduce.fadd.nxv2f32(float, <vscale x 2 x float>)
declare float @llvm.vector.reduce.fadd.nxv4f32(float, <vscale x 4 x float>)
declare double @llvm.vector.reduce.fadd.nxv2f64(double, <vscale x 2 x double>)
declare half @llvm.vector.reduce.fadd.nxv2f16(half, <vscale x 2 x half>)
declare half @llvm.vector.reduce.fadd.nxv4f16(half, <vscale x 4 x half>)
declare half @llvm.vector.reduce.fadd.nxv8f16(half, <vscale x 8 x half>)

0 comments on commit 05d424d

Please sign in to comment.