diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td index 1e30735b7a56a..36c9cb6c1d94f 100644 --- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td @@ -707,16 +707,14 @@ let Predicates = [HasSVE_or_SME] in { defm SDOT_ZZZ : sve_intx_dot<0b0, "sdot", AArch64sdot>; defm UDOT_ZZZ : sve_intx_dot<0b1, "udot", AArch64udot>; - let Predicates = [HasSVE_or_SME] in { - def : Pat<(nxv4i32 (partial_reduce_umla nxv4i32:$Acc, nxv16i8:$MulLHS, nxv16i8:$MulRHS)), - (UDOT_ZZZ_BtoS $Acc, $MulLHS, $MulRHS)>; - def : Pat<(nxv4i32 (partial_reduce_smla nxv4i32:$Acc, nxv16i8:$MulLHS, nxv16i8:$MulRHS)), - (SDOT_ZZZ_BtoS $Acc, $MulLHS, $MulRHS)>; - def : Pat<(nxv2i64 (partial_reduce_umla nxv2i64:$Acc, nxv8i16:$MulLHS, nxv8i16:$MulRHS)), - (UDOT_ZZZ_HtoD $Acc, $MulLHS, $MulRHS)>; - def : Pat<(nxv2i64 (partial_reduce_smla nxv2i64:$Acc, nxv8i16:$MulLHS, nxv8i16:$MulRHS)), - (SDOT_ZZZ_HtoD $Acc, $MulLHS, $MulRHS)>; - } // End HasSVE_or_SME + def : Pat<(nxv4i32 (partial_reduce_umla nxv4i32:$Acc, nxv16i8:$MulLHS, nxv16i8:$MulRHS)), + (UDOT_ZZZ_BtoS $Acc, $MulLHS, $MulRHS)>; + def : Pat<(nxv4i32 (partial_reduce_smla nxv4i32:$Acc, nxv16i8:$MulLHS, nxv16i8:$MulRHS)), + (SDOT_ZZZ_BtoS $Acc, $MulLHS, $MulRHS)>; + def : Pat<(nxv2i64 (partial_reduce_umla nxv2i64:$Acc, nxv8i16:$MulLHS, nxv8i16:$MulRHS)), + (UDOT_ZZZ_HtoD $Acc, $MulLHS, $MulRHS)>; + def : Pat<(nxv2i64 (partial_reduce_smla nxv2i64:$Acc, nxv8i16:$MulLHS, nxv8i16:$MulRHS)), + (SDOT_ZZZ_HtoD $Acc, $MulLHS, $MulRHS)>; defm SDOT_ZZZI : sve_intx_dot_by_indexed_elem<0b0, "sdot", int_aarch64_sve_sdot_lane>; defm UDOT_ZZZI : sve_intx_dot_by_indexed_elem<0b1, "udot", int_aarch64_sve_udot_lane>; @@ -3646,6 +3644,9 @@ let Predicates = [HasSVE_or_SME, HasMatMulInt8] in { defm USDOT_ZZZ : sve_int_dot_mixed<"usdot", AArch64usdot>; defm USDOT_ZZZI : sve_int_dot_mixed_indexed<0, "usdot", int_aarch64_sve_usdot_lane>; defm SUDOT_ZZZI : sve_int_dot_mixed_indexed<1, "sudot", int_aarch64_sve_sudot_lane>; + + def : Pat<(nxv4i32 (partial_reduce_sumla nxv4i32:$Acc, nxv16i8:$LHS, nxv16i8:$RHS)), + (USDOT_ZZZ $Acc, $RHS, $LHS)>; } // End HasSVE_or_SME, HasMatMulInt8 let Predicates = [HasSVE, HasMatMulFP32] in { @@ -3752,6 +3753,19 @@ let Predicates = [HasSVE2_or_SME] in { defm UMLSLB_ZZZ : sve2_int_mla_long<0b10110, "umlslb", int_aarch64_sve_umlslb>; defm UMLSLT_ZZZ : sve2_int_mla_long<0b10111, "umlslt", int_aarch64_sve_umlslt>; + def : Pat<(nxv2i64 (partial_reduce_umla nxv2i64:$Acc, nxv4i32:$LHS, nxv4i32:$RHS)), + (UMLALT_ZZZ_D (UMLALB_ZZZ_D $Acc, $LHS, $RHS), $LHS, $RHS)>; + def : Pat<(nxv2i64 (partial_reduce_smla nxv2i64:$Acc, nxv4i32:$LHS, nxv4i32:$RHS)), + (SMLALT_ZZZ_D (SMLALB_ZZZ_D $Acc, $LHS, $RHS), $LHS, $RHS)>; + def : Pat<(nxv4i32 (partial_reduce_umla nxv4i32:$Acc, nxv8i16:$LHS, nxv8i16:$RHS)), + (UMLALT_ZZZ_S (UMLALB_ZZZ_S $Acc, $LHS, $RHS), $LHS, $RHS)>; + def : Pat<(nxv4i32 (partial_reduce_smla nxv4i32:$Acc, nxv8i16:$LHS, nxv8i16:$RHS)), + (SMLALT_ZZZ_S (SMLALB_ZZZ_S $Acc, $LHS, $RHS), $LHS, $RHS)>; + def : Pat<(nxv8i16 (partial_reduce_umla nxv8i16:$Acc, nxv16i8:$LHS, nxv16i8:$RHS)), + (UMLALT_ZZZ_H (UMLALB_ZZZ_H $Acc, $LHS, $RHS), $LHS, $RHS)>; + def : Pat<(nxv8i16 (partial_reduce_smla nxv8i16:$Acc, nxv16i8:$LHS, nxv16i8:$RHS)), + (SMLALT_ZZZ_H (SMLALB_ZZZ_H $Acc, $LHS, $RHS), $LHS, $RHS)>; + // SVE2 saturating multiply-add long (indexed) defm SQDMLALB_ZZZI : sve2_int_mla_long_by_indexed_elem<0b0100, "sqdmlalb", int_aarch64_sve_sqdmlalb_lane>; defm SQDMLALT_ZZZI : sve2_int_mla_long_by_indexed_elem<0b0101, "sqdmlalt", int_aarch64_sve_sqdmlalt_lane>; @@ -3880,19 +3894,6 @@ let Predicates = [HasSVE2_or_SME] in { def : Pat<(nxv8i16 (partial_reduce_smla nxv8i16:$Acc, nxv16i8:$Input, (nxv16i8 (splat_vector (i32 1))))), (SADDWT_ZZZ_H (SADDWB_ZZZ_H $Acc, $Input), $Input)>; - def : Pat<(nxv2i64 (partial_reduce_umla nxv2i64:$Acc, nxv4i32:$LHS, nxv4i32:$RHS)), - (UMLALT_ZZZ_D (UMLALB_ZZZ_D $Acc, $LHS, $RHS), $LHS, $RHS)>; - def : Pat<(nxv2i64 (partial_reduce_smla nxv2i64:$Acc, nxv4i32:$LHS, nxv4i32:$RHS)), - (SMLALT_ZZZ_D (SMLALB_ZZZ_D $Acc, $LHS, $RHS), $LHS, $RHS)>; - def : Pat<(nxv4i32 (partial_reduce_umla nxv4i32:$Acc, nxv8i16:$LHS, nxv8i16:$RHS)), - (UMLALT_ZZZ_S (UMLALB_ZZZ_S $Acc, $LHS, $RHS), $LHS, $RHS)>; - def : Pat<(nxv4i32 (partial_reduce_smla nxv4i32:$Acc, nxv8i16:$LHS, nxv8i16:$RHS)), - (SMLALT_ZZZ_S (SMLALB_ZZZ_S $Acc, $LHS, $RHS), $LHS, $RHS)>; - def : Pat<(nxv8i16 (partial_reduce_umla nxv8i16:$Acc, nxv16i8:$LHS, nxv16i8:$RHS)), - (UMLALT_ZZZ_H (UMLALB_ZZZ_H $Acc, $LHS, $RHS), $LHS, $RHS)>; - def : Pat<(nxv8i16 (partial_reduce_smla nxv8i16:$Acc, nxv16i8:$LHS, nxv16i8:$RHS)), - (SMLALT_ZZZ_H (SMLALB_ZZZ_H $Acc, $LHS, $RHS), $LHS, $RHS)>; - // SVE2 integer multiply long defm SQDMULLB_ZZZ : sve2_wide_int_arith_long<0b11000, "sqdmullb", int_aarch64_sve_sqdmullb>; defm SQDMULLT_ZZZ : sve2_wide_int_arith_long<0b11001, "sqdmullt", int_aarch64_sve_sqdmullt>; @@ -4200,11 +4201,6 @@ let Predicates = [HasSVEAES2, HasNonStreamingSVE_or_SSVE_AES] in { def PMULL_2ZZZ_Q : sve_crypto_pmull_multi<"pmull">; } -let Predicates = [HasSVE_or_SME, HasMatMulInt8] in { - def : Pat<(nxv4i32 (partial_reduce_sumla nxv4i32:$Acc, nxv16i8:$LHS, nxv16i8:$RHS)), - (USDOT_ZZZ $Acc, $RHS, $LHS)>; - } // End HasSVE_or_SME, HasMatMulInt8 - //===----------------------------------------------------------------------===// // SME or SVE2.1 instructions //===----------------------------------------------------------------------===// @@ -4238,12 +4234,10 @@ defm UDOT_ZZZ_HtoS : sve2p1_two_way_dot_vv<"udot", 0b1, int_aarch64_sve_udot_x2 defm SDOT_ZZZI_HtoS : sve2p1_two_way_dot_vvi<"sdot", 0b0, int_aarch64_sve_sdot_lane_x2>; defm UDOT_ZZZI_HtoS : sve2p1_two_way_dot_vvi<"udot", 0b1, int_aarch64_sve_udot_lane_x2>; -let Predicates = [HasSVE2p1_or_SME2] in { - def : Pat<(nxv4i32 (partial_reduce_umla nxv4i32:$Acc, nxv8i16:$MulLHS, nxv8i16:$MulRHS)), - (UDOT_ZZZ_HtoS $Acc, $MulLHS, $MulRHS)>; - def : Pat<(nxv4i32 (partial_reduce_smla nxv4i32:$Acc, nxv8i16:$MulLHS, nxv8i16:$MulRHS)), - (SDOT_ZZZ_HtoS $Acc, $MulLHS, $MulRHS)>; -} // End HasSVE2p1_or_SME2 +def : Pat<(nxv4i32 (partial_reduce_umla nxv4i32:$Acc, nxv8i16:$MulLHS, nxv8i16:$MulRHS)), + (UDOT_ZZZ_HtoS $Acc, $MulLHS, $MulRHS)>; +def : Pat<(nxv4i32 (partial_reduce_smla nxv4i32:$Acc, nxv8i16:$MulLHS, nxv8i16:$MulRHS)), + (SDOT_ZZZ_HtoS $Acc, $MulLHS, $MulRHS)>; defm SQCVTN_Z2Z_StoH : sve2p1_multi_vec_extract_narrow<"sqcvtn", 0b00, int_aarch64_sve_sqcvtn_x2>; defm UQCVTN_Z2Z_StoH : sve2p1_multi_vec_extract_narrow<"uqcvtn", 0b01, int_aarch64_sve_uqcvtn_x2>;