Skip to content

Commit

Permalink
[AArch64] Optimize fp64 <-> fp16 SIMD conversions
Browse files Browse the repository at this point in the history
Legalization would result in needless scalarization. Add some
DAGCombines to fix this up.
  • Loading branch information
majnemer committed Mar 8, 2024
1 parent e963d07 commit 5f935e9
Show file tree
Hide file tree
Showing 9 changed files with 252 additions and 170 deletions.
95 changes: 93 additions & 2 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4507,13 +4507,16 @@ SDValue AArch64TargetLowering::LowerINT_TO_FP(SDValue Op,
};

if (Op.getValueType() == MVT::bf16) {
unsigned MaxWidth = IsSigned
? DAG.ComputeMaxSignificantBits(SrcVal)
: DAG.computeKnownBits(SrcVal).countMaxActiveBits();
// bf16 conversions are promoted to f32 when converting from i16.
if (DAG.ComputeMaxSignificantBits(SrcVal) <= 24) {
if (MaxWidth <= 24) {
return IntToFpViaPromotion(MVT::f32);
}

// bf16 conversions are promoted to f64 when converting from i32.
if (DAG.ComputeMaxSignificantBits(SrcVal) <= 53) {
if (MaxWidth <= 53) {
return IntToFpViaPromotion(MVT::f64);
}

Expand Down Expand Up @@ -19376,6 +19379,94 @@ static SDValue performBuildVectorCombine(SDNode *N,
SDLoc DL(N);
EVT VT = N->getValueType(0);

// Fold a v4f16/v4bf16 BUILD_VECTOR whose lanes are scalar FP_ROUNDs of the
// elements of v2f64 vectors into vector operations:
//   FP_ROUND(CONCAT_VECTORS(FCVTXN(Lo), FCVTXN(Hi) or undef))
// FCVTXN narrows f64 -> f32 (round-to-odd per the Arm ISA reference), and a
// single vector FP_ROUND then narrows v4f32 to the half-width result type.
if (VT == MVT::v4f16 || VT == MVT::v4bf16) {
  // Returns the v2f64 vector V when Lo is FP_ROUND(extract_elt(V, 0), Flag)
  // and Hi is FP_ROUND(extract_elt(V, 1), Flag) with one shared constant Flag.
  // Returns a null SDValue when the pair does not have that exact shape.
  auto MatchRoundedF64Pair = [](SDValue Lo, SDValue Hi) -> SDValue {
    if (Lo->getOpcode() != ISD::FP_ROUND || Hi->getOpcode() != ISD::FP_ROUND)
      return SDValue();
    // Both rounds must carry the same constant flag operand.
    if (!isa<ConstantSDNode>(Lo->getOperand(1)) ||
        !isa<ConstantSDNode>(Hi->getOperand(1)) ||
        Lo->getConstantOperandVal(1) != Hi->getConstantOperandVal(1))
      return SDValue();
    SDValue LoLane = Lo->getOperand(0), HiLane = Hi->getOperand(0);
    if (LoLane->getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
        HiLane->getOpcode() != ISD::EXTRACT_VECTOR_ELT)
      return SDValue();
    // Constant index.
    if (!isa<ConstantSDNode>(LoLane->getOperand(1)) ||
        !isa<ConstantSDNode>(HiLane->getOperand(1)))
      return SDValue();
    // Lanes 0 and 1 of the same source vector, in that order.
    if (LoLane->getOperand(0) != HiLane->getOperand(0) ||
        LoLane->getConstantOperandVal(1) != 0 ||
        HiLane->getConstantOperandVal(1) != 1)
      return SDValue();
    // FCVTXN below produces v2f32, so the source must be exactly v2f64.
    SDValue SrcVec = LoLane->getOperand(0);
    if (SrcVec.getValueType() != MVT::v2f64)
      return SDValue();
    return SrcVec;
  };

  SDValue Elt0 = N->getOperand(0), Elt1 = N->getOperand(1),
          Elt2 = N->getOperand(2), Elt3 = N->getOperand(3);
  if (SDValue LowLanesSrcVec = MatchRoundedF64Pair(Elt0, Elt1)) {
    SDValue HighLanes;
    if (Elt2->getOpcode() == ISD::UNDEF && Elt3->getOpcode() == ISD::UNDEF) {
      HighLanes = DAG.getUNDEF(MVT::v2f32);
    } else {
      SDValue HighLanesSrcVec = MatchRoundedF64Pair(Elt2, Elt3);
      // The combined FP_ROUND reuses lane 0's flag operand, so the high pair
      // must have been rounded with that same flag; otherwise the flag would
      // be applied to lanes it was never established for.
      if (HighLanesSrcVec &&
          Elt2->getConstantOperandVal(1) == Elt0->getConstantOperandVal(1))
        HighLanes =
            DAG.getNode(AArch64ISD::FCVTXN, DL, MVT::v2f32, HighLanesSrcVec);
    }
    if (HighLanes) {
      SDValue DoubleToSingleSticky =
          DAG.getNode(AArch64ISD::FCVTXN, DL, MVT::v2f32, LowLanesSrcVec);
      SDValue Concat = DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v4f32,
                                   DoubleToSingleSticky, HighLanes);
      return DAG.getNode(ISD::FP_ROUND, DL, VT, Concat, Elt0->getOperand(1));
    }
  }
}

// Fold a v2f64 BUILD_VECTOR of FP_EXTENDs of two adjacent lanes of one
// v4f16/v4bf16 vector into vector operations: FP_EXTEND the whole source to
// v4f32, take the matching two-element subvector, then FP_EXTEND it to v2f64.
if (VT == MVT::v2f64) {
  auto TryFoldHalfPairExtend = [&]() -> SDValue {
    SDValue LoExt = N->getOperand(0), HiExt = N->getOperand(1);
    if (LoExt->getOpcode() != ISD::FP_EXTEND ||
        HiExt->getOpcode() != ISD::FP_EXTEND)
      return SDValue();

    SDValue LoLane = LoExt->getOperand(0), HiLane = HiExt->getOperand(0);
    if (LoLane->getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
        HiLane->getOpcode() != ISD::EXTRACT_VECTOR_ELT)
      return SDValue();

    // Both lanes must be extracted from the same vector.
    if (!(LoLane->getOperand(0) == HiLane->getOperand(0)))
      return SDValue();

    // Constant index.
    if (!isa<ConstantSDNode>(LoLane->getOperand(1)) ||
        !isa<ConstantSDNode>(HiLane->getOperand(1)))
      return SDValue();

    // The two indices must be consecutive.
    auto LoIdx = LoLane->getConstantOperandVal(1);
    auto HiIdx = HiLane->getConstantOperandVal(1);
    if (LoIdx + 1 != HiIdx)
      return SDValue();

    // EXTRACT_SUBVECTOR requires that Idx be a constant multiple of
    // ResultType's known minimum vector length.
    if (LoIdx % VT.getVectorMinNumElements() != 0)
      return SDValue();

    SDValue SrcVec = LoLane->getOperand(0);
    if (SrcVec.getValueType() != MVT::v4f16 &&
        SrcVec.getValueType() != MVT::v4bf16)
      return SDValue();

    SDValue HalfToSingle = DAG.getNode(ISD::FP_EXTEND, DL, MVT::v4f32, SrcVec);
    SDValue SubvectorIdx = LoLane->getOperand(1);
    SDValue Extract = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL,
                                  VT.changeVectorElementType(MVT::f32),
                                  HalfToSingle, SubvectorIdx);
    return DAG.getNode(ISD::FP_EXTEND, DL, VT, Extract);
  };
  if (SDValue Folded = TryFoldHalfPairExtend())
    return Folded;
}

// A build vector of two extracted elements is equivalent to an
// extract subvector where the inner vector is any-extended to the
// extract_vector_elt VT.
Expand Down
4 changes: 2 additions & 2 deletions llvm/lib/Target/AArch64/AArch64InstrFormats.td
Original file line number Diff line number Diff line change
Expand Up @@ -6832,7 +6832,7 @@ multiclass SIMDFPNarrowTwoVector<bit U, bit S, bits<5> opc, string asm> {
}

multiclass SIMDFPInexactCvtTwoVector<bit U, bit S, bits<5> opc, string asm,
Intrinsic OpNode> {
SDPatternOperator OpNode> {
def v2f32 : BaseSIMDFPCvtTwoVector<0, U, {S,1}, opc, V64, V128,
asm, ".2s", ".2d",
[(set (v2f32 V64:$Rd), (OpNode (v2f64 V128:$Rn)))]>;
Expand Down Expand Up @@ -7547,7 +7547,7 @@ class BaseSIMDCmpTwoScalar<bit U, bits<2> size, bits<2> size2, bits<5> opcode,
let mayRaiseFPException = 1, Uses = [FPCR] in
class SIMDInexactCvtTwoScalar<bits<5> opcode, string asm>
: I<(outs FPR32:$Rd), (ins FPR64:$Rn), asm, "\t$Rd, $Rn", "",
[(set (f32 FPR32:$Rd), (AArch64fcvtxn (f64 FPR64:$Rn)))]>,
[(set (f32 FPR32:$Rd), (AArch64fcvtxnsdr (f64 FPR64:$Rn)))]>,
Sched<[WriteVd]> {
bits<5> Rd;
bits<5> Rn;
Expand Down
11 changes: 7 additions & 4 deletions llvm/lib/Target/AArch64/AArch64InstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -757,9 +757,12 @@ def AArch64fcmlez: SDNode<"AArch64ISD::FCMLEz", SDT_AArch64fcmpz>;
def AArch64fcmltz: SDNode<"AArch64ISD::FCMLTz", SDT_AArch64fcmpz>;

def AArch64fcvtxn_n: SDNode<"AArch64ISD::FCVTXN", SDTFPRoundOp>;
def AArch64fcvtxn: PatFrags<(ops node:$Rn),
[(f32 (int_aarch64_sisd_fcvtxn (f64 node:$Rn))),
(f32 (AArch64fcvtxn_n (f64 node:$Rn)))]>;
def AArch64fcvtxnsdr: PatFrags<(ops node:$Rn),
[(f32 (int_aarch64_sisd_fcvtxn (f64 node:$Rn))),
(f32 (AArch64fcvtxn_n (f64 node:$Rn)))]>;
def AArch64fcvtxnv: PatFrags<(ops node:$Rn),
[(int_aarch64_neon_fcvtxn node:$Rn),
(AArch64fcvtxn_n node:$Rn)]>;

def AArch64bici: SDNode<"AArch64ISD::BICi", SDT_AArch64vecimm>;
def AArch64orri: SDNode<"AArch64ISD::ORRi", SDT_AArch64vecimm>;
Expand Down Expand Up @@ -5042,7 +5045,7 @@ def : Pat<(concat_vectors V64:$Rd, (v4f16 (any_fpround (v4f32 V128:$Rn)))),
defm FCVTPS : SIMDTwoVectorFPToInt<0,1,0b11010, "fcvtps",int_aarch64_neon_fcvtps>;
defm FCVTPU : SIMDTwoVectorFPToInt<1,1,0b11010, "fcvtpu",int_aarch64_neon_fcvtpu>;
defm FCVTXN : SIMDFPInexactCvtTwoVector<1, 0, 0b10110, "fcvtxn",
int_aarch64_neon_fcvtxn>;
AArch64fcvtxnv>;
defm FCVTZS : SIMDTwoVectorFPToInt<0, 1, 0b11011, "fcvtzs", any_fp_to_sint>;
defm FCVTZU : SIMDTwoVectorFPToInt<1, 1, 0b11011, "fcvtzu", any_fp_to_uint>;

Expand Down
56 changes: 56 additions & 0 deletions llvm/test/CodeGen/AArch64/arm64-vcvt_f.ll
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,60 @@ define <2 x float> @test_vcvt_f32_f64(<2 x double> %v) nounwind readnone ssp {
ret <2 x float> %vcvt1.i
}

; FALLBACK-NOT: remark{{.*}}G_FPEXT{{.*}}(in function: test_vcvt_bf16_f64)
; FALLBACK-NOT: remark{{.*}}fpext{{.*}}(in function: test_vcvt_bf16_f64)
; fptrunc of <2 x double> to <2 x bfloat>. In every run configuration checked
; below, the expected code begins with one vector fcvtxn (f64 -> f32) and
; finishes the f32 -> bf16 step with integer vector ops ending in shrn; no
; per-lane scalar fcvt appears.
define <2 x bfloat> @test_vcvt_bf16_f64(<2 x double> %v) nounwind readnone ssp {
; GENERIC-LABEL: test_vcvt_bf16_f64:
; GENERIC: // %bb.0:
; GENERIC-NEXT: fcvtxn v0.2s, v0.2d
; GENERIC-NEXT: movi.4s v1, #127, msl #8
; GENERIC-NEXT: movi.4s v2, #1
; GENERIC-NEXT: ushr.4s v3, v0, #16
; GENERIC-NEXT: add.4s v1, v0, v1
; GENERIC-NEXT: and.16b v2, v3, v2
; GENERIC-NEXT: add.4s v1, v2, v1
; GENERIC-NEXT: fcmeq.4s v2, v0, v0
; GENERIC-NEXT: orr.4s v0, #64, lsl #16
; GENERIC-NEXT: bit.16b v0, v1, v2
; GENERIC-NEXT: shrn.4h v0, v0, #16
; GENERIC-NEXT: ret
;
; FAST-LABEL: test_vcvt_bf16_f64:
; FAST: // %bb.0:
; FAST-NEXT: fcvtxn v1.2s, v0.2d
; FAST-NEXT: // implicit-def: $q0
; FAST-NEXT: fmov d0, d1
; FAST-NEXT: ushr.4s v1, v0, #16
; FAST-NEXT: movi.4s v2, #1
; FAST-NEXT: and.16b v1, v1, v2
; FAST-NEXT: add.4s v1, v1, v0
; FAST-NEXT: movi.4s v2, #127, msl #8
; FAST-NEXT: add.4s v1, v1, v2
; FAST-NEXT: mov.16b v2, v0
; FAST-NEXT: orr.4s v2, #64, lsl #16
; FAST-NEXT: fcmeq.4s v0, v0, v0
; FAST-NEXT: bsl.16b v0, v1, v2
; FAST-NEXT: shrn.4h v0, v0, #16
; FAST-NEXT: ret
;
; GISEL-LABEL: test_vcvt_bf16_f64:
; GISEL: // %bb.0:
; GISEL-NEXT: fcvtxn v0.2s, v0.2d
; GISEL-NEXT: movi.4s v1, #127, msl #8
; GISEL-NEXT: movi.4s v2, #1
; GISEL-NEXT: ushr.4s v3, v0, #16
; GISEL-NEXT: add.4s v1, v0, v1
; GISEL-NEXT: and.16b v2, v3, v2
; GISEL-NEXT: add.4s v1, v2, v1
; GISEL-NEXT: fcmeq.4s v2, v0, v0
; GISEL-NEXT: orr.4s v0, #64, lsl #16
; GISEL-NEXT: bit.16b v0, v1, v2
; GISEL-NEXT: shrn.4h v0, v0, #16
; GISEL-NEXT: ret
%vcvt1.i = fptrunc <2 x double> %v to <2 x bfloat>
ret <2 x bfloat> %vcvt1.i
}

define half @test_vcvt_f16_f32(<1 x float> %x) {
; GENERIC-LABEL: test_vcvt_f16_f32:
; GENERIC: // %bb.0:
Expand Down Expand Up @@ -350,3 +404,5 @@ define float @from_half(i16 %in) {

declare float @llvm.convert.from.fp16.f32(i16) #1
declare i16 @llvm.convert.to.fp16.f32(float) #1
;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
; FALLBACK: {{.*}}
50 changes: 12 additions & 38 deletions llvm/test/CodeGen/AArch64/fp16-v8-instructions.ll
Original file line number Diff line number Diff line change
Expand Up @@ -312,25 +312,12 @@ define <8 x half> @s_to_h(<8 x float> %a) {
define <8 x half> @d_to_h(<8 x double> %a) {
; CHECK-LABEL: d_to_h:
; CHECK: // %bb.0:
; CHECK-NEXT: mov d5, v0.d[1]
; CHECK-NEXT: fcvt h0, d0
; CHECK-NEXT: fcvt h4, d1
; CHECK-NEXT: mov d1, v1.d[1]
; CHECK-NEXT: fcvt h5, d5
; CHECK-NEXT: fcvt h1, d1
; CHECK-NEXT: mov v0.h[1], v5.h[0]
; CHECK-NEXT: mov v0.h[2], v4.h[0]
; CHECK-NEXT: mov v0.h[3], v1.h[0]
; CHECK-NEXT: fcvt h1, d2
; CHECK-NEXT: mov d2, v2.d[1]
; CHECK-NEXT: mov v0.h[4], v1.h[0]
; CHECK-NEXT: fcvt h1, d2
; CHECK-NEXT: mov d2, v3.d[1]
; CHECK-NEXT: mov v0.h[5], v1.h[0]
; CHECK-NEXT: fcvt h1, d3
; CHECK-NEXT: mov v0.h[6], v1.h[0]
; CHECK-NEXT: fcvt h1, d2
; CHECK-NEXT: mov v0.h[7], v1.h[0]
; CHECK-NEXT: fcvtxn v0.2s, v0.2d
; CHECK-NEXT: fcvtxn v2.2s, v2.2d
; CHECK-NEXT: fcvtxn2 v0.4s, v1.2d
; CHECK-NEXT: fcvtxn2 v2.4s, v3.2d
; CHECK-NEXT: fcvtn v0.4h, v0.4s
; CHECK-NEXT: fcvtn2 v0.8h, v2.4s
; CHECK-NEXT: ret
%1 = fptrunc <8 x double> %a to <8 x half>
ret <8 x half> %1
Expand All @@ -349,25 +336,12 @@ define <8 x float> @h_to_s(<8 x half> %a) {
define <8 x double> @h_to_d(<8 x half> %a) {
; CHECK-LABEL: h_to_d:
; CHECK: // %bb.0:
; CHECK-NEXT: ext v2.16b, v0.16b, v0.16b, #8
; CHECK-NEXT: mov h1, v0.h[1]
; CHECK-NEXT: mov h3, v0.h[3]
; CHECK-NEXT: mov h4, v0.h[2]
; CHECK-NEXT: fcvt d0, h0
; CHECK-NEXT: mov h5, v2.h[1]
; CHECK-NEXT: mov h6, v2.h[3]
; CHECK-NEXT: mov h7, v2.h[2]
; CHECK-NEXT: fcvt d16, h1
; CHECK-NEXT: fcvt d17, h3
; CHECK-NEXT: fcvt d1, h4
; CHECK-NEXT: fcvt d2, h2
; CHECK-NEXT: fcvt d4, h5
; CHECK-NEXT: fcvt d5, h6
; CHECK-NEXT: fcvt d3, h7
; CHECK-NEXT: mov v0.d[1], v16.d[0]
; CHECK-NEXT: mov v1.d[1], v17.d[0]
; CHECK-NEXT: mov v2.d[1], v4.d[0]
; CHECK-NEXT: mov v3.d[1], v5.d[0]
; CHECK-NEXT: fcvtl v1.4s, v0.4h
; CHECK-NEXT: fcvtl2 v2.4s, v0.8h
; CHECK-NEXT: fcvtl v0.2d, v1.2s
; CHECK-NEXT: fcvtl2 v3.2d, v2.4s
; CHECK-NEXT: fcvtl2 v1.2d, v1.4s
; CHECK-NEXT: fcvtl v2.2d, v2.2s
; CHECK-NEXT: ret
%1 = fpext <8 x half> %a to <8 x double>
ret <8 x double> %1
Expand Down
64 changes: 37 additions & 27 deletions llvm/test/CodeGen/AArch64/fpext.ll
Original file line number Diff line number Diff line change
Expand Up @@ -85,29 +85,46 @@ entry:
}

define <2 x double> @fpext_v2f16_v2f64(<2 x half> %a) {
; CHECK-LABEL: fpext_v2f16_v2f64:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: // kill: def $d0 killed $d0 def $q0
; CHECK-NEXT: mov h1, v0.h[1]
; CHECK-NEXT: fcvt d0, h0
; CHECK-NEXT: fcvt d1, h1
; CHECK-NEXT: mov v0.d[1], v1.d[0]
; CHECK-NEXT: ret
; CHECK-SD-LABEL: fpext_v2f16_v2f64:
; CHECK-SD: // %bb.0: // %entry
; CHECK-SD-NEXT: fcvtl v0.4s, v0.4h
; CHECK-SD-NEXT: fcvtl v0.2d, v0.2s
; CHECK-SD-NEXT: ret
;
; CHECK-GI-LABEL: fpext_v2f16_v2f64:
; CHECK-GI: // %bb.0: // %entry
; CHECK-GI-NEXT: // kill: def $d0 killed $d0 def $q0
; CHECK-GI-NEXT: mov h1, v0.h[1]
; CHECK-GI-NEXT: fcvt d0, h0
; CHECK-GI-NEXT: fcvt d1, h1
; CHECK-GI-NEXT: mov v0.d[1], v1.d[0]
; CHECK-GI-NEXT: ret
entry:
%c = fpext <2 x half> %a to <2 x double>
ret <2 x double> %c
}

define <3 x double> @fpext_v3f16_v3f64(<3 x half> %a) {
; CHECK-LABEL: fpext_v3f16_v3f64:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: // kill: def $d0 killed $d0 def $q0
; CHECK-NEXT: mov h1, v0.h[1]
; CHECK-NEXT: mov h2, v0.h[2]
; CHECK-NEXT: fcvt d0, h0
; CHECK-NEXT: fcvt d1, h1
; CHECK-NEXT: fcvt d2, h2
; CHECK-NEXT: ret
; CHECK-SD-LABEL: fpext_v3f16_v3f64:
; CHECK-SD: // %bb.0: // %entry
; CHECK-SD-NEXT: fcvtl v1.4s, v0.4h
; CHECK-SD-NEXT: fcvtl v0.2d, v1.2s
; CHECK-SD-NEXT: fcvtl2 v2.2d, v1.4s
; CHECK-SD-NEXT: // kill: def $d2 killed $d2 killed $q2
; CHECK-SD-NEXT: ext v1.16b, v0.16b, v0.16b, #8
; CHECK-SD-NEXT: // kill: def $d0 killed $d0 killed $q0
; CHECK-SD-NEXT: // kill: def $d1 killed $d1 killed $q1
; CHECK-SD-NEXT: ret
;
; CHECK-GI-LABEL: fpext_v3f16_v3f64:
; CHECK-GI: // %bb.0: // %entry
; CHECK-GI-NEXT: // kill: def $d0 killed $d0 def $q0
; CHECK-GI-NEXT: mov h1, v0.h[1]
; CHECK-GI-NEXT: mov h2, v0.h[2]
; CHECK-GI-NEXT: fcvt d0, h0
; CHECK-GI-NEXT: fcvt d1, h1
; CHECK-GI-NEXT: fcvt d2, h2
; CHECK-GI-NEXT: ret
entry:
%c = fpext <3 x half> %a to <3 x double>
ret <3 x double> %c
Expand All @@ -116,16 +133,9 @@ entry:
define <4 x double> @fpext_v4f16_v4f64(<4 x half> %a) {
; CHECK-SD-LABEL: fpext_v4f16_v4f64:
; CHECK-SD: // %bb.0: // %entry
; CHECK-SD-NEXT: // kill: def $d0 killed $d0 def $q0
; CHECK-SD-NEXT: mov h1, v0.h[1]
; CHECK-SD-NEXT: mov h2, v0.h[3]
; CHECK-SD-NEXT: mov h3, v0.h[2]
; CHECK-SD-NEXT: fcvt d0, h0
; CHECK-SD-NEXT: fcvt d4, h1
; CHECK-SD-NEXT: fcvt d2, h2
; CHECK-SD-NEXT: fcvt d1, h3
; CHECK-SD-NEXT: mov v0.d[1], v4.d[0]
; CHECK-SD-NEXT: mov v1.d[1], v2.d[0]
; CHECK-SD-NEXT: fcvtl v0.4s, v0.4h
; CHECK-SD-NEXT: fcvtl2 v1.2d, v0.4s
; CHECK-SD-NEXT: fcvtl v0.2d, v0.2s
; CHECK-SD-NEXT: ret
;
; CHECK-GI-LABEL: fpext_v4f16_v4f64:
Expand Down

0 comments on commit 5f935e9

Please sign in to comment.