diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index ad404e8dab2ad..0710c654a95df 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -204,10 +204,18 @@ def AArch64umax_p  : SDNode<"AArch64ISD::UMAX_PRED",    SDT_AArch64Arith>;
 def AArch64umin_p  : SDNode<"AArch64ISD::UMIN_PRED",    SDT_AArch64Arith>;
 def AArch64umulh_p : SDNode<"AArch64ISD::MULHU_PRED",   SDT_AArch64Arith>;
 
+def AArch64fadd_p_contract : PatFrag<(ops node:$op1, node:$op2, node:$op3),
+                                     (AArch64fadd_p node:$op1, node:$op2, node:$op3), [{
+  return N->getFlags().hasAllowContract();
+}]>;
 def AArch64fadd_p_nsz : PatFrag<(ops node:$op1, node:$op2, node:$op3),
                                 (AArch64fadd_p node:$op1, node:$op2, node:$op3), [{
   return N->getFlags().hasNoSignedZeros();
 }]>;
+def AArch64fsub_p_contract : PatFrag<(ops node:$op1, node:$op2, node:$op3),
+                                     (AArch64fsub_p node:$op1, node:$op2, node:$op3), [{
+  return N->getFlags().hasAllowContract();
+}]>;
 def AArch64fsub_p_nsz : PatFrag<(ops node:$op1, node:$op2, node:$op3),
                                 (AArch64fsub_p node:$op1, node:$op2, node:$op3), [{
   return N->getFlags().hasNoSignedZeros();
@@ -363,14 +371,12 @@ def AArch64fabd_p : PatFrags<(ops node:$pg, node:$op1, node:$op2),
                               (AArch64fabs_mt node:$pg, (AArch64fsub_p node:$pg, node:$op1, node:$op2), undef)]>;
 
 def AArch64fmla_p : PatFrags<(ops node:$pg, node:$za, node:$zn, node:$zm),
-                             [(AArch64fma_p node:$pg, node:$zn, node:$zm, node:$za),
-                              (vselect node:$pg, (AArch64fma_p (AArch64ptrue 31), node:$zn, node:$zm, node:$za), node:$za)]>;
+                             [(AArch64fma_p node:$pg, node:$zn, node:$zm, node:$za)]>;
 
 def AArch64fmls_p : PatFrags<(ops node:$pg, node:$za, node:$zn, node:$zm),
                              [(int_aarch64_sve_fmls_u node:$pg, node:$za, node:$zn, node:$zm),
                               (AArch64fma_p node:$pg, (AArch64fneg_mt node:$pg, node:$zn, (undef)), node:$zm, node:$za),
-                              (AArch64fma_p node:$pg, node:$zm, (AArch64fneg_mt node:$pg, node:$zn, (undef)), node:$za),
-                              (vselect node:$pg, (AArch64fma_p (AArch64ptrue 31), (AArch64fneg_mt (AArch64ptrue 31), node:$zn, (undef)), node:$zm, node:$za), node:$za)]>;
+                              (AArch64fma_p node:$pg, node:$zm, (AArch64fneg_mt node:$pg, node:$zn, (undef)), node:$za)]>;
 
 def AArch64fnmla_p : PatFrags<(ops node:$pg, node:$za, node:$zn, node:$zm),
                               [(int_aarch64_sve_fnmla_u node:$pg, node:$za, node:$zn, node:$zm),
@@ -423,18 +429,15 @@ def AArch64eor3 : PatFrags<(ops node:$op1, node:$op2, node:$op3),
                            [(int_aarch64_sve_eor3 node:$op1, node:$op2, node:$op3),
                             (xor node:$op1, (xor node:$op2, node:$op3))]>;
 
-class fma_patfrags<SDPatternOperator intrinsic, SDPatternOperator add>
-    : PatFrags<(ops node:$pred, node:$op1, node:$op2, node:$op3),
-               [(intrinsic node:$pred, node:$op1, node:$op2, node:$op3),
-                (vselect node:$pred, (add (SVEAllActive), node:$op1, (AArch64fmul_p_oneuse (SVEAllActive), node:$op2, node:$op3)), node:$op1)],
-[{
-  if (N->getOpcode() == ISD::VSELECT)
-    return N->getOperand(1)->getFlags().hasAllowContract();
-  return true; // it's the intrinsic
-}]>;
+def AArch64fmla_m1 : PatFrags<(ops node:$pg, node:$za, node:$zn, node:$zm),
+                              [(int_aarch64_sve_fmla node:$pg, node:$za, node:$zn, node:$zm),
+                               (vselect node:$pg, (AArch64fadd_p_contract (SVEAllActive), node:$za, (AArch64fmul_p_oneuse (SVEAllActive), node:$zn, node:$zm)), node:$za),
+                               (vselect node:$pg, (AArch64fma_p (SVEAllActive), node:$zn, node:$zm, node:$za), node:$za)]>;
 
-def AArch64fmla_m1 : fma_patfrags<int_aarch64_sve_fmla, AArch64fadd_p>;
-def AArch64fmls_m1 : fma_patfrags<int_aarch64_sve_fmls, AArch64fsub_p>;
+def AArch64fmls_m1 : PatFrags<(ops node:$pg, node:$za, node:$zn, node:$zm),
+                              [(int_aarch64_sve_fmls node:$pg, node:$za, node:$zn, node:$zm),
+                               (vselect node:$pg, (AArch64fsub_p_contract (SVEAllActive), node:$za, (AArch64fmul_p_oneuse (SVEAllActive), node:$zn, node:$zm)), node:$za),
+                               (vselect node:$pg, (AArch64fma_p (SVEAllActive), (AArch64fneg_mt (SVEAllActive), node:$zn, (undef)), node:$zm, node:$za), node:$za)]>;
 
 def AArch64add_m1 : VSelectUnpredOrPassthruPatFrags<int_aarch64_sve_add, add>;
 def AArch64sub_m1 : VSelectUnpredOrPassthruPatFrags<int_aarch64_sve_sub, sub>;
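Note on the AArch64SVEInstrInfo.td hunks: the vselect alternatives that used to hang off AArch64fmla_p/AArch64fmls_p, together with the C++ predicate block in the old fma_patfrags class, are folded directly into the AArch64fmla_m1/AArch64fmls_m1 PatFrags, with the contract-flag check now expressed through the dedicated AArch64fadd_p_contract/AArch64fsub_p_contract fragments. As a rough sketch with hypothetical value names, the AArch64fadd_p_contract alternative of AArch64fmla_m1 corresponds to IR of this shape once the fadd/fmul have been lowered to their predicated SDAG forms:

    ; Sketch only, names hypothetical. The fmul must have a single use
    ; (AArch64fmul_p_oneuse) and the fadd must carry the contract flag
    ; (AArch64fadd_p_contract) for the fragment to match.
    %mul = fmul contract <vscale x 4 x float> %zn, %zm
    %add = fadd contract <vscale x 4 x float> %za, %mul
    %res = select <vscale x 4 x i1> %pg, <vscale x 4 x float> %add, <vscale x 4 x float> %za

A match selects FMLA with %za as the destructive accumulator, so inactive lanes keep %za, which is exactly what the select requires.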
diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td
index 118862b8c317c..c4c0dca114ce7 100644
--- a/llvm/lib/Target/AArch64/SVEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td
@@ -2317,7 +2317,10 @@ multiclass sve_fp_3op_p_zds_a<bits<2> opc, string asm, string Ps,
            SVEPseudo2Instr<Ps # _D, 1>, SVEInstr2Rev<NAME # _D, revname # _D, isReverseInstr>;
 
   def : SVE_4_Op_Pat<nxv8f16, op, nxv8i1, nxv8f16, nxv8f16, nxv8f16, !cast<Instruction>(NAME # _H)>;
+  def : SVE_4_Op_Pat<nxv4f16, op, nxv4i1, nxv4f16, nxv4f16, nxv4f16, !cast<Instruction>(NAME # _H)>;
+  def : SVE_4_Op_Pat<nxv2f16, op, nxv2i1, nxv2f16, nxv2f16, nxv2f16, !cast<Instruction>(NAME # _H)>;
   def : SVE_4_Op_Pat<nxv4f32, op, nxv4i1, nxv4f32, nxv4f32, nxv4f32, !cast<Instruction>(NAME # _S)>;
+  def : SVE_4_Op_Pat<nxv2f32, op, nxv2i1, nxv2f32, nxv2f32, nxv2f32, !cast<Instruction>(NAME # _S)>;
   def : SVE_4_Op_Pat<nxv2f64, op, nxv2i1, nxv2f64, nxv2f64, nxv2f64, !cast<Instruction>(NAME # _D)>;
 }
 
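Note on the SVEInstrFormats.td hunk: the multiclass previously instantiated SVE_4_Op_Pat only for the packed vector types (nxv8f16, nxv4f32, nxv2f64); the three added instantiations cover the unpacked types nxv4f16, nxv2f16 and nxv2f32, so the _m1 fragments above can also fire for them. A minimal sketch, with hypothetical names, of IR the new nxv4f16 pattern is intended to catch via the _H variant of the instruction:

    ; Sketch only: an unpacked-half masked fma that should now select the
    ; predicated _H instruction instead of going through a wider type.
    %mul.add = call <vscale x 4 x half> @llvm.fma.nxv4f16(<vscale x 4 x half> %m1, <vscale x 4 x half> %m2, <vscale x 4 x half> %acc)
    %res = select <vscale x 4 x i1> %pg, <vscale x 4 x half> %mul.add, <vscale x 4 x half> %acc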
diff --git a/llvm/test/CodeGen/AArch64/sve-fp-combine.ll b/llvm/test/CodeGen/AArch64/sve-fp-combine.ll
index 14471584bf286..e53f76f651212 100644
--- a/llvm/test/CodeGen/AArch64/sve-fp-combine.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fp-combine.ll
@@ -1271,7 +1271,8 @@ define <vscale x 4 x float> @fadd_sel_fmul_no_contract_s(<vscale x 4 x float> %a
 define <vscale x 8 x half> @fma_sel_h_different_arg_order(<vscale x 8 x i1> %pred, <vscale x 8 x half> %m1, <vscale x 8 x half> %m2, <vscale x 8 x half> %acc) {
 ; CHECK-LABEL: fma_sel_h_different_arg_order:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    fmad z0.h, p0/m, z1.h, z2.h
+; CHECK-NEXT:    fmla z2.h, p0/m, z0.h, z1.h
+; CHECK-NEXT:    mov z0.d, z2.d
 ; CHECK-NEXT:    ret
   %mul.add = call <vscale x 8 x half> @llvm.fma.nxv8f16(<vscale x 8 x half> %m1, <vscale x 8 x half> %m2, <vscale x 8 x half> %acc)
   %masked.mul.add = select <vscale x 8 x i1> %pred, <vscale x 8 x half> %mul.add, <vscale x 8 x half> %acc
@@ -1281,7 +1282,8 @@ define <vscale x 8 x half> @fma_sel_h_different_arg_order(<vscale x 8 x i1> %pre
 define <vscale x 4 x float> @fma_sel_s_different_arg_order(<vscale x 4 x i1> %pred, <vscale x 4 x float> %m1, <vscale x 4 x float> %m2, <vscale x 4 x float> %acc) {
 ; CHECK-LABEL: fma_sel_s_different_arg_order:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    fmad z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT:    fmla z2.s, p0/m, z0.s, z1.s
+; CHECK-NEXT:    mov z0.d, z2.d
 ; CHECK-NEXT:    ret
   %mul.add = call <vscale x 4 x float> @llvm.fma.nxv4f32(<vscale x 4 x float> %m1, <vscale x 4 x float> %m2, <vscale x 4 x float> %acc)
   %masked.mul.add = select <vscale x 4 x i1> %pred, <vscale x 4 x float> %mul.add, <vscale x 4 x float> %acc
@@ -1291,7 +1293,8 @@ define <vscale x 4 x float> @fma_sel_s_different_arg_order(<vscale x 4 x i1> %pr
 define <vscale x 2 x double> @fma_sel_d_different_arg_order(<vscale x 2 x i1> %pred, <vscale x 2 x double> %m1, <vscale x 2 x double> %m2, <vscale x 2 x double> %acc) {
 ; CHECK-LABEL: fma_sel_d_different_arg_order:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    fmad z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT:    fmla z2.d, p0/m, z0.d, z1.d
+; CHECK-NEXT:    mov z0.d, z2.d
 ; CHECK-NEXT:    ret
   %mul.add = call <vscale x 2 x double> @llvm.fma.nxv2f64(<vscale x 2 x double> %m1, <vscale x 2 x double> %m2, <vscale x 2 x double> %acc)
   %masked.mul.add = select <vscale x 2 x i1> %pred, <vscale x 2 x double> %mul.add, <vscale x 2 x double> %acc
@@ -1301,7 +1304,8 @@ define <vscale x 2 x double> @fma_sel_d_different_arg_order(<vscale x 2 x i1> %p
 define <vscale x 8 x half> @fnma_sel_h_different_arg_order(<vscale x 8 x i1> %pred, <vscale x 8 x half> %m1, <vscale x 8 x half> %m2, <vscale x 8 x half> %acc) {
 ; CHECK-LABEL: fnma_sel_h_different_arg_order:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    fmsb z0.h, p0/m, z1.h, z2.h
+; CHECK-NEXT:    fmls z2.h, p0/m, z0.h, z1.h
+; CHECK-NEXT:    mov z0.d, z2.d
 ; CHECK-NEXT:    ret
   %neg_m1 = fneg contract <vscale x 8 x half> %m1
   %mul.add = call <vscale x 8 x half> @llvm.fma.nxv8f16(<vscale x 8 x half> %neg_m1, <vscale x 8 x half> %m2, <vscale x 8 x half> %acc)
@@ -1312,7 +1316,8 @@ define <vscale x 8 x half> @fnma_sel_h_different_arg_order(<vscale x 8 x i1> %pr
 define <vscale x 4 x float> @fnma_sel_s_different_arg_order(<vscale x 4 x i1> %pred, <vscale x 4 x float> %m1, <vscale x 4 x float> %m2, <vscale x 4 x float> %acc) {
 ; CHECK-LABEL: fnma_sel_s_different_arg_order:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    fmsb z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT:    fmls z2.s, p0/m, z0.s, z1.s
+; CHECK-NEXT:    mov z0.d, z2.d
 ; CHECK-NEXT:    ret
   %neg_m1 = fneg contract <vscale x 4 x float> %m1
   %mul.add = call <vscale x 4 x float> @llvm.fma.nxv4f32(<vscale x 4 x float> %neg_m1, <vscale x 4 x float> %m2, <vscale x 4 x float> %acc)
@@ -1323,7 +1328,8 @@ define <vscale x 4 x float> @fnma_sel_s_different_arg_order(<vscale x 4 x i1> %p
 define <vscale x 2 x double> @fnma_sel_d_different_arg_order(<vscale x 2 x i1> %pred, <vscale x 2 x double> %m1, <vscale x 2 x double> %m2, <vscale x 2 x double> %acc) {
 ; CHECK-LABEL: fnma_sel_d_different_arg_order:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    fmsb z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT:    fmls z2.d, p0/m, z0.d, z1.d
+; CHECK-NEXT:    mov z0.d, z2.d
 ; CHECK-NEXT:    ret
   %neg_m1 = fneg contract <vscale x 2 x double> %m1
   %mul.add = call <vscale x 2 x double> @llvm.fma.nxv2f64(<vscale x 2 x double> %neg_m1, <vscale x 2 x double> %m2, <vscale x 2 x double> %acc)
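Note on the test deltas: FMAD is destructive in the multiplicand, so fmad z0, p0/m, z1, z2 computes z0 = z0 * z1 + z2 in active lanes but leaves inactive lanes holding z0, i.e. %m1. The selects in these tests require inactive lanes to hold %acc, so the old fmad lowering did not preserve the select's inactive-lane values. FMLA is destructive in the accumulator, so fmla z2, p0/m, z0, z1 leaves inactive lanes holding z2 = %acc, matching the select at the cost of a mov to return the result in z0. Lane-wise, for the half case (p0 = %pred, z0 = %m1, z1 = %m2, z2 = %acc):

    ; Worked lane semantics for a lane i, not generated output:
    ;   fmad z0.h, p0/m, z1.h, z2.h : active   z0[i] = z0[i] * z1[i] + z2[i]
    ;                                 inactive z0[i] keeps %m1[i]  (not %acc[i])
    ;   fmla z2.h, p0/m, z0.h, z1.h : active   z2[i] = z2[i] + z0[i] * z1[i]
    ;                                 inactive z2[i] keeps %acc[i] (what the select needs)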