From 9aa8dd52093712c9f01f28fa29487c00a21dac83 Mon Sep 17 00:00:00 2001
From: Damian Heaton
Date: Thu, 13 Nov 2025 15:42:11 +0000
Subject: [PATCH 1/2] Combine vector FNEG+FMA into `FNML[A|S]`

This allows FNEG + FMA sequences to be combined into a single
operation, with `FNML[A|S]`, `FNMAD`, or `FNMSB` selected depending on
the operand order.
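For example (illustrative IR mirroring the fmsub_nxv2f64 test added
below; the value names are placeholders), a multiply-add with a negated
accumulator such as

  %neg = fneg <vscale x 2 x double> %c
  %res = call <vscale x 2 x double> @llvm.fmuladd(<vscale x 2 x double> %a,
                                                  <vscale x 2 x double> %b,
                                                  <vscale x 2 x double> %neg)

is now selected to a single predicated instruction:

  fnmsb z0.d, p0/m, z1.d, z2.d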
---
 .../Target/AArch64/AArch64ISelLowering.cpp     |  50 ++++
 llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td |   8 +-
 llvm/test/CodeGen/AArch64/sve-fmsub.ll         | 276 ++++++++++++++++++
 3 files changed, 332 insertions(+), 2 deletions(-)
 create mode 100644 llvm/test/CodeGen/AArch64/sve-fmsub.ll

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 7b51f453b4974..79625dd766085 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1176,6 +1176,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
   setTargetDAGCombine(ISD::VECTOR_DEINTERLEAVE);
   setTargetDAGCombine(ISD::CTPOP);
 
+  setTargetDAGCombine(ISD::FMA);
+
   // In case of strict alignment, avoid an excessive number of byte wide stores.
   MaxStoresPerMemsetOptSize = 8;
   MaxStoresPerMemset =
@@ -20444,6 +20446,52 @@ static SDValue performFADDCombine(SDNode *N,
   return SDValue();
 }
 
+static SDValue performFMACombine(SDNode *N,
+                                 TargetLowering::DAGCombinerInfo &DCI,
+                                 const AArch64Subtarget *Subtarget) {
+  SelectionDAG &DAG = DCI.DAG;
+  SDValue Op1 = N->getOperand(0);
+  SDValue Op2 = N->getOperand(1);
+  SDValue Op3 = N->getOperand(2);
+  EVT VT = N->getValueType(0);
+  SDLoc DL(N);
+
+  // fma(a, b, neg(c)) -> fnmls(a, b, c)
+  // fma(neg(a), b, neg(c)) -> fnmla(a, b, c)
+  // fma(a, neg(b), neg(c)) -> fnmla(a, b, c)
+  if (VT.isVector() && DAG.getTargetLoweringInfo().isTypeLegal(VT) &&
+      (Subtarget->hasSVE() || Subtarget->hasSME())) {
+    if (Op3.getOpcode() == ISD::FNEG) {
+      unsigned int Opcode;
+      if (Op1.getOpcode() == ISD::FNEG) {
+        Op1 = Op1.getOperand(0);
+        Opcode = AArch64ISD::FNMLA_PRED;
+      } else if (Op2.getOpcode() == ISD::FNEG) {
+        Op2 = Op2.getOperand(0);
+        Opcode = AArch64ISD::FNMLA_PRED;
+      } else {
+        Opcode = AArch64ISD::FNMLS_PRED;
+      }
+      Op3 = Op3.getOperand(0);
+      auto Pg = getPredicateForVector(DAG, DL, VT);
+      if (VT.isFixedLengthVector()) {
+        assert(DAG.getTargetLoweringInfo().isTypeLegal(VT) &&
+               "Expected only legal fixed-width types");
+        EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
+        Op1 = convertToScalableVector(DAG, ContainerVT, Op1);
+        Op2 = convertToScalableVector(DAG, ContainerVT, Op2);
+        Op3 = convertToScalableVector(DAG, ContainerVT, Op3);
+        auto ScalableRes =
+            DAG.getNode(Opcode, DL, ContainerVT, Pg, Op1, Op2, Op3);
+        return convertFromScalableVector(DAG, VT, ScalableRes);
+      }
+      return DAG.getNode(Opcode, DL, VT, Pg, Op1, Op2, Op3);
+    }
+  }
+
+  return SDValue();
+}
+
 static bool hasPairwiseAdd(unsigned Opcode, EVT VT, bool FullFP16) {
   switch (Opcode) {
   case ISD::STRICT_FADD:
@@ -27977,6 +28025,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
     return performANDCombine(N, DCI);
   case ISD::FADD:
     return performFADDCombine(N, DCI);
+  case ISD::FMA:
+    return performFMACombine(N, DCI, Subtarget);
   case ISD::INTRINSIC_WO_CHAIN:
     return performIntrinsicCombine(N, DCI, Subtarget);
   case ISD::ANY_EXTEND:
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index c8c21c4822ffe..4640719cda43c 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -240,6 +240,8 @@ def AArch64udiv_p : SDNode<"AArch64ISD::UDIV_PRED", SDT_AArch64Arith>;
 def AArch64umax_p : SDNode<"AArch64ISD::UMAX_PRED", SDT_AArch64Arith>;
 def AArch64umin_p : SDNode<"AArch64ISD::UMIN_PRED", SDT_AArch64Arith>;
 def AArch64umulh_p : SDNode<"AArch64ISD::MULHU_PRED", SDT_AArch64Arith>;
+def AArch64fnmla_p_node : SDNode<"AArch64ISD::FNMLA_PRED", SDT_AArch64FMA>;
+def AArch64fnmls_p_node : SDNode<"AArch64ISD::FNMLS_PRED", SDT_AArch64FMA>;
 
 def AArch64fadd_p_contract : PatFrag<(ops node:$op1, node:$op2, node:$op3),
                                      (AArch64fadd_p node:$op1, node:$op2, node:$op3), [{
@@ -460,12 +462,14 @@ def AArch64fmlsidx : PatFrags<(ops node:$acc, node:$op1, node:$op2, node:$idx),
 
 def AArch64fnmla_p : PatFrags<(ops node:$pg, node:$za, node:$zn, node:$zm),
-                              [(int_aarch64_sve_fnmla_u node:$pg, node:$za, node:$zn, node:$zm),
+                              [(AArch64fnmla_p_node node:$pg, node:$zn, node:$zm, node:$za),
+                               (int_aarch64_sve_fnmla_u node:$pg, node:$za, node:$zn, node:$zm),
                                (AArch64fma_p node:$pg, (AArch64fneg_mt node:$pg, node:$zn, (undef)), node:$zm, (AArch64fneg_mt node:$pg, node:$za, (undef))),
                                (AArch64fneg_mt_nsz node:$pg, (AArch64fma_p node:$pg, node:$zn, node:$zm, node:$za), (undef))]>;
 
 def AArch64fnmls_p : PatFrags<(ops node:$pg, node:$za, node:$zn, node:$zm),
-                              [(int_aarch64_sve_fnmls_u node:$pg, node:$za, node:$zn, node:$zm),
+                              [(AArch64fnmls_p_node node:$pg, node:$zn, node:$zm, node:$za),
+                               (int_aarch64_sve_fnmls_u node:$pg, node:$za, node:$zn, node:$zm),
                                (AArch64fma_p node:$pg, node:$zn, node:$zm, (AArch64fneg_mt node:$pg, node:$za, (undef)))]>;
 
 def AArch64fsubr_p : PatFrag<(ops node:$pg, node:$op1, node:$op2),
diff --git a/llvm/test/CodeGen/AArch64/sve-fmsub.ll b/llvm/test/CodeGen/AArch64/sve-fmsub.ll
new file mode 100644
index 0000000000000..721066038769c
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-fmsub.ll
@@ -0,0 +1,276 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
+; RUN: llc -mtriple=aarch64 -mattr=+v9a,+sve2,+crypto,+bf16,+sm4,+i8mm,+sve2-bitperm,+sve2-sha3,+sve2-aes,+sve2-sm4 %s -o - | FileCheck %s --check-prefixes=CHECK
+
+define <vscale x 2 x double> @fmsub_nxv2f64(<vscale x 2 x double> %a, <vscale x 2 x double> %b, <vscale x 2 x double> %c) {
+; CHECK-LABEL: fmsub_nxv2f64:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.d
+; CHECK-NEXT: fnmsb z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT: ret
+entry:
+  %neg = fneg <vscale x 2 x double> %c
+  %0 = tail call <vscale x 2 x double> @llvm.fmuladd(<vscale x 2 x double> %a, <vscale x 2 x double> %b, <vscale x 2 x double> %neg)
+  ret <vscale x 2 x double> %0
+}
+
+define <vscale x 4 x float> @fmsub_nxv4f32(<vscale x 4 x float> %a, <vscale x 4 x float> %b, <vscale x 4 x float> %c) {
+; CHECK-LABEL: fmsub_nxv4f32:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.s
+; CHECK-NEXT: fnmsb z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT: ret
+entry:
+  %neg = fneg <vscale x 4 x float> %c
+  %0 = tail call <vscale x 4 x float> @llvm.fmuladd(<vscale x 4 x float> %a, <vscale x 4 x float> %b, <vscale x 4 x float> %neg)
+  ret <vscale x 4 x float> %0
+}
+
+define <vscale x 8 x half> @fmsub_nxv8f16(<vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c) {
+; CHECK-LABEL: fmsub_nxv8f16:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.h
+; CHECK-NEXT: fnmsb z0.h, p0/m, z1.h, z2.h
+; CHECK-NEXT: ret
+entry:
+  %neg = fneg <vscale x 8 x half> %c
+  %0 = tail call <vscale x 8 x half> @llvm.fmuladd(<vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %neg)
+  ret <vscale x 8 x half> %0
+}
+
+define <2 x double> @fmsub_v2f64(<2 x double> %a, <2 x double> %b, <2 x double> %c) {
+; CHECK-LABEL: fmsub_v2f64:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.d, vl2
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: fnmsb z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+entry:
+  %neg = fneg <2 x double> %c
+  %0 = tail call <2 x double> @llvm.fmuladd(<2 x double> %a, <2 x double> %b, <2 x double> %neg)
+  ret <2 x double> %0
+}
+
+define <4 x float> @fmsub_v4f32(<4 x float> %a, <4 x float> %b, <4 x float> %c) {
+; CHECK-LABEL: fmsub_v4f32:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.s, vl4
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: fnmsb z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+entry:
+  %neg = fneg <4 x float> %c
+  %0 = tail call <4 x float> @llvm.fmuladd(<4 x float> %a, <4 x float> %b, <4 x float> %neg)
+  ret <4 x float> %0
+}
+
+define <8 x half> @fmsub_v8f16(<8 x half> %a, <8 x half> %b, <8 x half> %c) {
+; CHECK-LABEL: fmsub_v8f16:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.h, vl8
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: fnmsb z0.h, p0/m, z1.h, z2.h
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+entry:
+  %neg = fneg <8 x half> %c
+  %0 = tail call <8 x half> @llvm.fmuladd(<8 x half> %a, <8 x half> %b, <8 x half> %neg)
+  ret <8 x half> %0
+}
+
+
+define <2 x double> @fmsub_flipped_v2f64(<2 x double> %c, <2 x double> %a, <2 x double> %b) {
+; CHECK-LABEL: fmsub_flipped_v2f64:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.d, vl2
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: fnmls z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+entry:
+  %neg = fneg <2 x double> %c
+  %0 = tail call <2 x double> @llvm.fmuladd(<2 x double> %a, <2 x double> %b, <2 x double> %neg)
+  ret <2 x double> %0
+}
+
+define <4 x float> @fmsub_flipped_v4f32(<4 x float> %c, <4 x float> %a, <4 x float> %b) {
+; CHECK-LABEL: fmsub_flipped_v4f32:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.s, vl4
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: fnmls z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+entry:
+  %neg = fneg <4 x float> %c
+  %0 = tail call <4 x float> @llvm.fmuladd(<4 x float> %a, <4 x float> %b, <4 x float> %neg)
+  ret <4 x float> %0
+}
+
+define <8 x half> @fmsub_flipped_v8f16(<8 x half> %c, <8 x half> %a, <8 x half> %b) {
+; CHECK-LABEL: fmsub_flipped_v8f16:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.h, vl8
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: fnmls z0.h, p0/m, z1.h, z2.h
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+entry:
+  %neg = fneg <8 x half> %c
+  %0 = tail call <8 x half> @llvm.fmuladd(<8 x half> %a, <8 x half> %b, <8 x half> %neg)
+  ret <8 x half> %0
+}
+
+define <vscale x 2 x double> @fnmsub_nxv2f64(<vscale x 2 x double> %a, <vscale x 2 x double> %b, <vscale x 2 x double> %c) {
+; CHECK-LABEL: fnmsub_nxv2f64:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.d
+; CHECK-NEXT: fnmad z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT: ret
+entry:
+  %neg = fneg <vscale x 2 x double> %a
+  %neg1 = fneg <vscale x 2 x double> %c
+  %0 = tail call <vscale x 2 x double> @llvm.fmuladd(<vscale x 2 x double> %neg, <vscale x 2 x double> %b, <vscale x 2 x double> %neg1)
+  ret <vscale x 2 x double> %0
+}
+
+define <vscale x 4 x float> @fnmsub_nxv4f32(<vscale x 4 x float> %a, <vscale x 4 x float> %b, <vscale x 4 x float> %c) {
+; CHECK-LABEL: fnmsub_nxv4f32:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.s
+; CHECK-NEXT: fnmad z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT: ret
+entry:
+  %neg = fneg <vscale x 4 x float> %a
+  %neg1 = fneg <vscale x 4 x float> %c
+  %0 = tail call <vscale x 4 x float> @llvm.fmuladd(<vscale x 4 x float> %neg, <vscale x 4 x float> %b, <vscale x 4 x float> %neg1)
+  ret <vscale x 4 x float> %0
+}
+
+define <vscale x 8 x half> @fnmsub_nxv8f16(<vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c) {
+; CHECK-LABEL: fnmsub_nxv8f16:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.h
+; CHECK-NEXT: fnmad z0.h, p0/m, z1.h, z2.h
+; CHECK-NEXT: ret
+entry:
+  %neg = fneg <vscale x 8 x half> %a
+  %neg1 = fneg <vscale x 8 x half> %c
+  %0 = tail call <vscale x 8 x half> @llvm.fmuladd(<vscale x 8 x half> %neg, <vscale x 8 x half> %b, <vscale x 8 x half> %neg1)
+  ret <vscale x 8 x half> %0
+}
+
+define <2 x double> @fnmsub_v2f64(<2 x double> %a, <2 x double> %b, <2 x double> %c) {
+; CHECK-LABEL: fnmsub_v2f64:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.d, vl2
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: fnmad z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+entry:
+  %neg = fneg <2 x double> %a
+  %neg1 = fneg <2 x double> %c
+  %0 = tail call <2 x double> @llvm.fmuladd(<2 x double> %neg, <2 x double> %b, <2 x double> %neg1)
+  ret <2 x double> %0
+}
+
+define <4 x float> @fnmsub_v4f32(<4 x float> %a, <4 x float> %b, <4 x float> %c) {
+; CHECK-LABEL: fnmsub_v4f32:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.s, vl4
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: fnmad z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+entry:
+  %neg = fneg <4 x float> %a
+  %neg1 = fneg <4 x float> %c
+  %0 = tail call <4 x float> @llvm.fmuladd(<4 x float> %neg, <4 x float> %b, <4 x float> %neg1)
+  ret <4 x float> %0
+}
+
+define <8 x half> @fnmsub_v8f16(<8 x half> %a, <8 x half> %b, <8 x half> %c) {
+; CHECK-LABEL: fnmsub_v8f16:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.h, vl8
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: fnmad z0.h, p0/m, z1.h, z2.h
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+entry:
+  %neg = fneg <8 x half> %a
+  %neg1 = fneg <8 x half> %c
+  %0 = tail call <8 x half> @llvm.fmuladd(<8 x half> %neg, <8 x half> %b, <8 x half> %neg1)
+  ret <8 x half> %0
+}
+
+define <2 x double> @fnmsub_flipped_v2f64(<2 x double> %c, <2 x double> %a, <2 x double> %b) {
+; CHECK-LABEL: fnmsub_flipped_v2f64:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.d, vl2
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: fnmla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+entry:
+  %neg = fneg <2 x double> %a
+  %neg1 = fneg <2 x double> %c
+  %0 = tail call <2 x double> @llvm.fmuladd(<2 x double> %neg, <2 x double> %b, <2 x double> %neg1)
+  ret <2 x double> %0
+}
+
+define <4 x float> @fnmsub_flipped_v4f32(<4 x float> %c, <4 x float> %a, <4 x float> %b) {
+; CHECK-LABEL: fnmsub_flipped_v4f32:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.s, vl4
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: fnmla z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+entry:
+  %neg = fneg <4 x float> %a
+  %neg1 = fneg <4 x float> %c
+  %0 = tail call <4 x float> @llvm.fmuladd(<4 x float> %neg, <4 x float> %b, <4 x float> %neg1)
+  ret <4 x float> %0
+}
+
+define <8 x half> @fnmsub_flipped_v8f16(<8 x half> %c, <8 x half> %a, <8 x half> %b) {
+; CHECK-LABEL: fnmsub_flipped_v8f16:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.h, vl8
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: fnmla z0.h, p0/m, z1.h, z2.h
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+entry:
+  %neg = fneg <8 x half> %a
+  %neg1 = fneg <8 x half> %c
+  %0 = tail call <8 x half> @llvm.fmuladd(<8 x half> %neg, <8 x half> %b, <8 x half> %neg1)
+  ret <8 x half> %0
+}

From d4e4360f27eb9536c3f74b1e8802cd600216bba2 Mon Sep 17 00:00:00 2001
From: Damian Heaton
Date: Tue, 18 Nov 2025 12:54:41 +0000
Subject: [PATCH 2/2] Address review comments

---
 .../Target/AArch64/AArch64ISelLowering.cpp    |  61 +++++-----
 .../lib/Target/AArch64/AArch64SVEInstrInfo.td |   4 +-
 llvm/test/CodeGen/AArch64/sve-fmsub.ll        | 115 +++++++++++++++++-
 3 files changed, 143 insertions(+), 37 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 79625dd766085..08aec2c2cb79b 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -20450,46 +20450,41 @@ static SDValue performFMACombine(SDNode *N,
                                  TargetLowering::DAGCombinerInfo &DCI,
                                  const AArch64Subtarget *Subtarget) {
   SelectionDAG &DAG = DCI.DAG;
-  SDValue Op1 = N->getOperand(0);
-  SDValue Op2 = N->getOperand(1);
-  SDValue Op3 = N->getOperand(2);
+  SDValue OpA = N->getOperand(0);
+  SDValue OpB = N->getOperand(1);
+  SDValue OpC = N->getOperand(2);
   EVT VT = N->getValueType(0);
   SDLoc DL(N);
 
   // fma(a, b, neg(c)) -> fnmls(a, b, c)
   // fma(neg(a), b, neg(c)) -> fnmla(a, b, c)
   // fma(a, neg(b), neg(c)) -> fnmla(a, b, c)
-  if (VT.isVector() && DAG.getTargetLoweringInfo().isTypeLegal(VT) &&
-      (Subtarget->hasSVE() || Subtarget->hasSME())) {
-    if (Op3.getOpcode() == ISD::FNEG) {
-      unsigned int Opcode;
-      if (Op1.getOpcode() == ISD::FNEG) {
-        Op1 = Op1.getOperand(0);
-        Opcode = AArch64ISD::FNMLA_PRED;
-      } else if (Op2.getOpcode() == ISD::FNEG) {
-        Op2 = Op2.getOperand(0);
-        Opcode = AArch64ISD::FNMLA_PRED;
-      } else {
-        Opcode = AArch64ISD::FNMLS_PRED;
-      }
-      Op3 = Op3.getOperand(0);
-      auto Pg = getPredicateForVector(DAG, DL, VT);
-      if (VT.isFixedLengthVector()) {
-        assert(DAG.getTargetLoweringInfo().isTypeLegal(VT) &&
-               "Expected only legal fixed-width types");
-        EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
-        Op1 = convertToScalableVector(DAG, ContainerVT, Op1);
-        Op2 = convertToScalableVector(DAG, ContainerVT, Op2);
-        Op3 = convertToScalableVector(DAG, ContainerVT, Op3);
-        auto ScalableRes =
-            DAG.getNode(Opcode, DL, ContainerVT, Pg, Op1, Op2, Op3);
-        return convertFromScalableVector(DAG, VT, ScalableRes);
-      }
-      return DAG.getNode(Opcode, DL, VT, Pg, Op1, Op2, Op3);
-    }
+  if (!VT.isVector() || !DAG.getTargetLoweringInfo().isTypeLegal(VT) ||
+      !Subtarget->isSVEorStreamingSVEAvailable() ||
+      OpC.getOpcode() != ISD::FNEG) {
+    return SDValue();
+  }
+  unsigned int Opcode;
+  if (OpA.getOpcode() == ISD::FNEG) {
+    OpA = OpA.getOperand(0);
+    Opcode = AArch64ISD::FNMLA_PRED;
+  } else if (OpB.getOpcode() == ISD::FNEG) {
+    OpB = OpB.getOperand(0);
+    Opcode = AArch64ISD::FNMLA_PRED;
+  } else {
+    Opcode = AArch64ISD::FNMLS_PRED;
   }
-
-  return SDValue();
+  OpC = OpC.getOperand(0);
+  auto Pg = getPredicateForVector(DAG, DL, VT);
+  if (VT.isFixedLengthVector()) {
+    EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
+    OpA = convertToScalableVector(DAG, ContainerVT, OpA);
+    OpB = convertToScalableVector(DAG, ContainerVT, OpB);
+    OpC = convertToScalableVector(DAG, ContainerVT, OpC);
+    auto ScalableRes = DAG.getNode(Opcode, DL, ContainerVT, Pg, OpA, OpB, OpC);
+    return convertFromScalableVector(DAG, VT, ScalableRes);
+  }
+  return DAG.getNode(Opcode, DL, VT, Pg, OpA, OpB, OpC);
 }
 
 static bool hasPairwiseAdd(unsigned Opcode, EVT VT, bool FullFP16) {
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 4640719cda43c..2d90123d37e01 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -464,13 +464,11 @@ def AArch64fmlsidx : PatFrags<(ops node:$acc, node:$op1, node:$op2, node:$idx),
 
 def AArch64fnmla_p : PatFrags<(ops node:$pg, node:$za, node:$zn, node:$zm),
                               [(AArch64fnmla_p_node node:$pg, node:$zn, node:$zm, node:$za),
                                (int_aarch64_sve_fnmla_u node:$pg, node:$za, node:$zn, node:$zm),
-                               (AArch64fma_p node:$pg, (AArch64fneg_mt node:$pg, node:$zn, (undef)), node:$zm, (AArch64fneg_mt node:$pg, node:$za, (undef))),
                                (AArch64fneg_mt_nsz node:$pg, (AArch64fma_p node:$pg, node:$zn, node:$zm, node:$za), (undef))]>;
 
 def AArch64fnmls_p : PatFrags<(ops node:$pg, node:$za, node:$zn, node:$zm),
                               [(AArch64fnmls_p_node node:$pg, node:$zn, node:$zm, node:$za),
-                               (int_aarch64_sve_fnmls_u node:$pg, node:$za, node:$zn, node:$zm),
-                               (AArch64fma_p node:$pg, node:$zn, node:$zm, (AArch64fneg_mt node:$pg, node:$za, (undef)))]>;
+                               (int_aarch64_sve_fnmls_u node:$pg, node:$za, node:$zn, node:$zm)]>;
 
 def AArch64fsubr_p : PatFrag<(ops node:$pg, node:$op1, node:$op2),
                              (AArch64fsub_p node:$pg, node:$op2, node:$op1)>;
diff --git a/llvm/test/CodeGen/AArch64/sve-fmsub.ll b/llvm/test/CodeGen/AArch64/sve-fmsub.ll
index 721066038769c..29dbb87f1b875 100644
--- a/llvm/test/CodeGen/AArch64/sve-fmsub.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fmsub.ll
@@ -1,5 +1,8 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
-; RUN: llc -mtriple=aarch64 -mattr=+v9a,+sve2,+crypto,+bf16,+sm4,+i8mm,+sve2-bitperm,+sve2-sha3,+sve2-aes,+sve2-sm4 %s -o - | FileCheck %s --check-prefixes=CHECK
+; RUN: llc -mattr=+sve %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-SVE
+; RUN: llc -mattr=+sme -force-streaming %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-SME
+
+target triple = "aarch64"
 
 define <vscale x 2 x double> @fmsub_nxv2f64(<vscale x 2 x double> %a, <vscale x 2 x double> %b, <vscale x 2 x double> %c) {
 ; CHECK-LABEL: fmsub_nxv2f64:
@@ -274,3 +277,113 @@ entry:
   %0 = tail call <8 x half> @llvm.fmuladd(<8 x half> %neg, <8 x half> %b, <8 x half> %neg1)
   ret <8 x half> %0
 }
+
+; Illegal types
+
+define <vscale x 3 x float> @fmsub_illegal_nxv3f32(<vscale x 3 x float> %a, <vscale x 3 x float> %b, <vscale x 3 x float> %c) {
+; CHECK-LABEL: fmsub_illegal_nxv3f32:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.s
+; CHECK-NEXT: fnmsb z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT: ret
+entry:
+  %neg = fneg <vscale x 3 x float> %c
+  %0 = tail call <vscale x 3 x float> @llvm.fmuladd(<vscale x 3 x float> %a, <vscale x 3 x float> %b, <vscale x 3 x float> %neg)
+  ret <vscale x 3 x float> %0
+}
+
+define <1 x double> @fmsub_illegal_v1f64(<1 x double> %a, <1 x double> %b, <1 x double> %c) {
+; CHECK-SVE-LABEL: fmsub_illegal_v1f64:
+; CHECK-SVE: // %bb.0: // %entry
+; CHECK-SVE-NEXT: ptrue p0.d, vl1
+; CHECK-SVE-NEXT: // kill: def $d0 killed $d0 def $z0
+; CHECK-SVE-NEXT: // kill: def $d2 killed $d2 def $z2
+; CHECK-SVE-NEXT: // kill: def $d1 killed $d1 def $z1
+; CHECK-SVE-NEXT: fnmsb z0.d, p0/m, z1.d, z2.d
+; CHECK-SVE-NEXT: // kill: def $d0 killed $d0 killed $z0
+; CHECK-SVE-NEXT: ret
+;
+; CHECK-SME-LABEL: fmsub_illegal_v1f64:
+; CHECK-SME: // %bb.0: // %entry
+; CHECK-SME-NEXT: str x29, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-SME-NEXT: addvl sp, sp, #-1
+; CHECK-SME-NEXT: .cfi_escape 0x0f, 0x08, 0x8f, 0x10, 0x92, 0x2e, 0x00, 0x38, 0x1e, 0x22 // sp + 16 + 8 * VG
+; CHECK-SME-NEXT: .cfi_offset w29, -16
+; CHECK-SME-NEXT: ptrue p0.d, vl1
+; CHECK-SME-NEXT: // kill: def $d0 killed $d0 def $z0
+; CHECK-SME-NEXT: // kill: def $d2 killed $d2 def $z2
+; CHECK-SME-NEXT: // kill: def $d1 killed $d1 def $z1
+; CHECK-SME-NEXT: fnmsb z0.d, p0/m, z1.d, z2.d
+; CHECK-SME-NEXT: str z0, [sp]
+; CHECK-SME-NEXT: ldr d0, [sp]
+; CHECK-SME-NEXT: addvl sp, sp, #1
+; CHECK-SME-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-SME-NEXT: ret
+entry:
+  %neg = fneg <1 x double> %c
+  %0 = tail call <1 x double> @llvm.fmuladd(<1 x double> %a, <1 x double> %b, <1 x double> %neg)
+  ret <1 x double> %0
+}
+
+define <3 x float> @fmsub_flipped_illegal_v3f32(<3 x float> %c, <3 x float> %a, <3 x float> %b) {
+; CHECK-LABEL: fmsub_flipped_illegal_v3f32:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.s, vl4
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: fnmls z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+entry:
+  %neg = fneg <3 x float> %c
+  %0 = tail call <3 x float> @llvm.fmuladd(<3 x float> %a, <3 x float> %b, <3 x float> %neg)
+  ret <3 x float> %0
+}
+
+define <vscale x 7 x half> @fnmsub_illegal_nxv7f16(<vscale x 7 x half> %a, <vscale x 7 x half> %b, <vscale x 7 x half> %c) {
+; CHECK-LABEL: fnmsub_illegal_nxv7f16:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.h
+; CHECK-NEXT: fnmad z0.h, p0/m, z1.h, z2.h
+; CHECK-NEXT: ret
+entry:
+  %neg = fneg <vscale x 7 x half> %a
+  %neg1 = fneg <vscale x 7 x half> %c
+  %0 = tail call <vscale x 7 x half> @llvm.fmuladd(<vscale x 7 x half> %neg, <vscale x 7 x half> %b, <vscale x 7 x half> %neg1)
+  ret <vscale x 7 x half> %0
+}
+
+define <3 x float> @fnmsub_illegal_v3f32(<3 x float> %a, <3 x float> %b, <3 x float> %c) {
+; CHECK-LABEL: fnmsub_illegal_v3f32:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.s, vl4
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: fnmad z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+entry:
+  %neg = fneg <3 x float> %a
+  %neg1 = fneg <3 x float> %c
+  %0 = tail call <3 x float> @llvm.fmuladd(<3 x float> %neg, <3 x float> %b, <3 x float> %neg1)
+  ret <3 x float> %0
+}
+
+define <7 x half> @fnmsub_flipped_illegal_v7f16(<7 x half> %c, <7 x half> %a, <7 x half> %b) {
+; CHECK-LABEL: fnmsub_flipped_illegal_v7f16:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.h, vl8
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: fnmla z0.h, p0/m, z1.h, z2.h
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+entry:
+  %neg = fneg <7 x half> %a
+  %neg1 = fneg <7 x half> %c
+  %0 = tail call <7 x half> @llvm.fmuladd(<7 x half> %neg, <7 x half> %b, <7 x half> %neg1)
+  ret <7 x half> %0
+}