Skip to content

Commit

Permalink
[SME2][AArch64] Add multi-multi multiply-add long long intrinsics
Browse files Browse the repository at this point in the history
Adds intrinsics for the following SME2 instructions (2 & 4 vectors):
 - smlall
 - smlsll
 - umlall
 - umlsll
 - usmlall

NOTE: These intrinsics are still in development and are subject to future changes.

Reviewed By: david-arm

Differential Revision: https://reviews.llvm.org/D143277
  • Loading branch information
kmclaughlin-arm committed Feb 20, 2023
1 parent 8b3091b commit 028c722
Show file tree
Hide file tree
Showing 4 changed files with 436 additions and 22 deletions.
6 changes: 6 additions & 0 deletions llvm/include/llvm/IR/IntrinsicsAArch64.td
Expand Up @@ -3064,6 +3064,9 @@ let TargetPrefix = "aarch64" in {
def int_aarch64_sme_ # ty # instr # _ # za # _single_vg4x1 : SME2_Matrix_ArrayVector_Single_Single_Intrinsic;
def int_aarch64_sme_ # ty # instr # _ # za # _single_vg4x2 : SME2_Matrix_ArrayVector_VG2_Multi_Single_Intrinsic;
def int_aarch64_sme_ # ty # instr # _ # za # _single_vg4x4 : SME2_Matrix_ArrayVector_VG4_Multi_Single_Intrinsic;

def int_aarch64_sme_ # ty # instr # _ # za # _vg4x2 : SME2_Matrix_ArrayVector_VG2_Multi_Multi_Intrinsic;
def int_aarch64_sme_ # ty # instr # _ # za # _vg4x4 : SME2_Matrix_ArrayVector_VG4_Multi_Multi_Intrinsic;
}
}
}
Expand All @@ -3075,6 +3078,9 @@ let TargetPrefix = "aarch64" in {
def int_aarch64_sme_usmla_za32_single_vg4x2 : SME2_Matrix_ArrayVector_VG2_Multi_Single_Intrinsic;
def int_aarch64_sme_usmla_za32_single_vg4x4 : SME2_Matrix_ArrayVector_VG4_Multi_Single_Intrinsic;

def int_aarch64_sme_usmla_za32_vg4x2 : SME2_Matrix_ArrayVector_VG2_Multi_Multi_Intrinsic;
def int_aarch64_sme_usmla_za32_vg4x4 : SME2_Matrix_ArrayVector_VG4_Multi_Multi_Intrinsic;

// Multi-vector signed saturating doubling multiply high

def int_aarch64_sve_sqdmulh_single_vgx2 : SME2_VG2_Multi_Single_Intrinsic;
Expand Down
36 changes: 18 additions & 18 deletions llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
Expand Up @@ -518,35 +518,35 @@ defm SMLALL_VG4_M4ZZI_BtoS : sme2_mla_ll_array_vg4_index_32b<"smlall", 0b000>;
defm SMLALL_MZZ_BtoS : sme2_mla_ll_array_single<"smlall", 0b0000, MatrixOp32, ZPR8, ZPR4b8, nxv16i8, int_aarch64_sme_smla_za32_single_vg4x1>;
defm SMLALL_VG2_M2ZZ_BtoS : sme2_mla_ll_array_vg2_single<"smlall", 0b00000, MatrixOp32, ZZ_b, ZPR4b8, nxv16i8, int_aarch64_sme_smla_za32_single_vg4x2>;
defm SMLALL_VG4_M4ZZ_BtoS : sme2_mla_ll_array_vg4_single<"smlall", 0b01000, MatrixOp32, ZZZZ_b, ZPR4b8, nxv16i8, int_aarch64_sme_smla_za32_single_vg4x4>;
defm SMLALL_VG2_M2Z2Z_BtoS : sme2_mla_ll_array_vg2_multi<"smlall", 0b0000, MatrixOp32, ZZ_b_mul_r>;
defm SMLALL_VG4_M4Z4Z_BtoS : sme2_mla_ll_array_vg4_multi<"smlall", 0b0000, MatrixOp32, ZZZZ_b_mul_r>;
defm SMLALL_VG2_M2Z2Z_BtoS : sme2_mla_ll_array_vg2_multi<"smlall", 0b0000, MatrixOp32, ZZ_b_mul_r, nxv16i8, int_aarch64_sme_smla_za32_vg4x2>;
defm SMLALL_VG4_M4Z4Z_BtoS : sme2_mla_ll_array_vg4_multi<"smlall", 0b0000, MatrixOp32, ZZZZ_b_mul_r, nxv16i8, int_aarch64_sme_smla_za32_vg4x4>;

def USMLALL_MZZI_BtoS : sme2_mla_ll_array_index_32b<"usmlall", 0b001>;
defm USMLALL_VG2_M2ZZI_BtoS : sme2_mla_ll_array_vg2_index_32b<"usmlall", 0b100>;
defm USMLALL_VG4_M4ZZI_BtoS : sme2_mla_ll_array_vg4_index_32b<"usmlall", 0b100>;
defm USMLALL_MZZ_BtoS : sme2_mla_ll_array_single<"usmlall", 0b0001, MatrixOp32, ZPR8, ZPR4b8, nxv16i8, int_aarch64_sme_usmla_za32_single_vg4x1>;
defm USMLALL_VG2_M2ZZ_BtoS : sme2_mla_ll_array_vg2_single<"usmlall", 0b00001, MatrixOp32, ZZ_b, ZPR4b8, nxv16i8, int_aarch64_sme_usmla_za32_single_vg4x2>;
defm USMLALL_VG4_M4ZZ_BtoS : sme2_mla_ll_array_vg4_single<"usmlall", 0b01001, MatrixOp32, ZZZZ_b, ZPR4b8, nxv16i8, int_aarch64_sme_usmla_za32_single_vg4x4>;
defm USMLALL_VG2_M2Z2Z_BtoS : sme2_mla_ll_array_vg2_multi<"usmlall", 0b0001, MatrixOp32, ZZ_b_mul_r>;
defm USMLALL_VG4_M4Z4Z_BtoS : sme2_mla_ll_array_vg4_multi<"usmlall", 0b0001, MatrixOp32, ZZZZ_b_mul_r>;
defm USMLALL_VG2_M2Z2Z_BtoS : sme2_mla_ll_array_vg2_multi<"usmlall", 0b0001, MatrixOp32, ZZ_b_mul_r, nxv16i8, int_aarch64_sme_usmla_za32_vg4x2>;
defm USMLALL_VG4_M4Z4Z_BtoS : sme2_mla_ll_array_vg4_multi<"usmlall", 0b0001, MatrixOp32, ZZZZ_b_mul_r, nxv16i8, int_aarch64_sme_usmla_za32_vg4x4>;

def SMLSLL_MZZI_BtoS : sme2_mla_ll_array_index_32b<"smlsll", 0b010>;
defm SMLSLL_VG2_M2ZZI_BtoS : sme2_mla_ll_array_vg2_index_32b<"smlsll", 0b001>;
defm SMLSLL_VG4_M4ZZI_BtoS : sme2_mla_ll_array_vg4_index_32b<"smlsll", 0b001>;
defm SMLSLL_MZZ_BtoS : sme2_mla_ll_array_single<"smlsll", 0b0010, MatrixOp32, ZPR8, ZPR4b8, nxv16i8, int_aarch64_sme_smls_za32_single_vg4x1>;
defm SMLSLL_VG2_M2ZZ_BtoS : sme2_mla_ll_array_vg2_single<"smlsll", 0b00010, MatrixOp32, ZZ_b, ZPR4b8, nxv16i8, int_aarch64_sme_smls_za32_single_vg4x2>;
defm SMLSLL_VG4_M4ZZ_BtoS : sme2_mla_ll_array_vg4_single<"smlsll", 0b01010, MatrixOp32, ZZZZ_b, ZPR4b8, nxv16i8, int_aarch64_sme_smls_za32_single_vg4x4>;
defm SMLSLL_VG2_M2Z2Z_BtoS : sme2_mla_ll_array_vg2_multi<"smlsll", 0b0010, MatrixOp32, ZZ_b_mul_r>;
defm SMLSLL_VG4_M4Z4Z_BtoS : sme2_mla_ll_array_vg4_multi<"smlsll", 0b0010, MatrixOp32, ZZZZ_b_mul_r>;
defm SMLSLL_VG2_M2Z2Z_BtoS : sme2_mla_ll_array_vg2_multi<"smlsll", 0b0010, MatrixOp32, ZZ_b_mul_r, nxv16i8, int_aarch64_sme_smls_za32_vg4x2>;
defm SMLSLL_VG4_M4Z4Z_BtoS : sme2_mla_ll_array_vg4_multi<"smlsll", 0b0010, MatrixOp32, ZZZZ_b_mul_r, nxv16i8, int_aarch64_sme_smls_za32_vg4x4>;

def UMLALL_MZZI_BtoS : sme2_mla_ll_array_index_32b<"umlall", 0b100>;
defm UMLALL_VG2_M2ZZI_BtoS : sme2_mla_ll_array_vg2_index_32b<"umlall", 0b010>;
defm UMLALL_VG4_M4ZZI_BtoS : sme2_mla_ll_array_vg4_index_32b<"umlall", 0b010>;
defm UMLALL_MZZ_BtoS : sme2_mla_ll_array_single<"umlall", 0b0100, MatrixOp32, ZPR8, ZPR4b8, nxv16i8, int_aarch64_sme_umla_za32_single_vg4x1>;
defm UMLALL_VG2_M2ZZ_BtoS : sme2_mla_ll_array_vg2_single<"umlall", 0b00100, MatrixOp32, ZZ_b, ZPR4b8, nxv16i8, int_aarch64_sme_umla_za32_single_vg4x2>;
defm UMLALL_VG4_M4ZZ_BtoS : sme2_mla_ll_array_vg4_single<"umlall", 0b01100, MatrixOp32, ZZZZ_b, ZPR4b8, nxv16i8, int_aarch64_sme_umla_za32_single_vg4x4>;
defm UMLALL_VG2_M2Z2Z_BtoS : sme2_mla_ll_array_vg2_multi<"umlall", 0b0100, MatrixOp32, ZZ_b_mul_r>;
defm UMLALL_VG4_M4Z4Z_BtoS : sme2_mla_ll_array_vg4_multi<"umlall", 0b0100, MatrixOp32, ZZZZ_b_mul_r>;
defm UMLALL_VG2_M2Z2Z_BtoS : sme2_mla_ll_array_vg2_multi<"umlall", 0b0100, MatrixOp32, ZZ_b_mul_r, nxv16i8, int_aarch64_sme_umla_za32_vg4x2>;
defm UMLALL_VG4_M4Z4Z_BtoS : sme2_mla_ll_array_vg4_multi<"umlall", 0b0100, MatrixOp32, ZZZZ_b_mul_r, nxv16i8, int_aarch64_sme_umla_za32_vg4x4>;

def SUMLALL_MZZI_BtoS : sme2_mla_ll_array_index_32b<"sumlall", 0b101>;
defm SUMLALL_VG2_M2ZZI_BtoS : sme2_mla_ll_array_vg2_index_32b<"sumlall", 0b110>;
Expand All @@ -560,8 +560,8 @@ defm UMLSLL_VG4_M4ZZI_BtoS : sme2_mla_ll_array_vg4_index_32b<"umlsll", 0b011>;
defm UMLSLL_MZZ_BtoS : sme2_mla_ll_array_single<"umlsll", 0b0110, MatrixOp32, ZPR8, ZPR4b8, nxv16i8, int_aarch64_sme_umls_za32_single_vg4x1>;
defm UMLSLL_VG2_M2ZZ_BtoS : sme2_mla_ll_array_vg2_single<"umlsll", 0b00110, MatrixOp32, ZZ_b, ZPR4b8, nxv16i8, int_aarch64_sme_umls_za32_single_vg4x2>;
defm UMLSLL_VG4_M4ZZ_BtoS : sme2_mla_ll_array_vg4_single<"umlsll", 0b01110, MatrixOp32, ZZZZ_b, ZPR4b8, nxv16i8, int_aarch64_sme_umls_za32_single_vg4x4>;
defm UMLSLL_VG2_M2Z2Z_BtoS : sme2_mla_ll_array_vg2_multi<"umlsll", 0b0110, MatrixOp32, ZZ_b_mul_r>;
defm UMLSLL_VG4_M4Z4Z_BtoS : sme2_mla_ll_array_vg4_multi<"umlsll", 0b0110, MatrixOp32, ZZZZ_b_mul_r>;
defm UMLSLL_VG2_M2Z2Z_BtoS : sme2_mla_ll_array_vg2_multi<"umlsll", 0b0110, MatrixOp32, ZZ_b_mul_r, nxv16i8, int_aarch64_sme_umls_za32_vg4x2>;
defm UMLSLL_VG4_M4Z4Z_BtoS : sme2_mla_ll_array_vg4_multi<"umlsll", 0b0110, MatrixOp32, ZZZZ_b_mul_r, nxv16i8, int_aarch64_sme_umls_za32_vg4x4>;

defm BMOPA_MPPZZ_S : sme2_int_bmopx_tile<"bmopa", 0b100, int_aarch64_sme_bmopa_za32>;
defm BMOPS_MPPZZ_S : sme2_int_bmopx_tile<"bmops", 0b101, int_aarch64_sme_bmops_za32>;
Expand Down Expand Up @@ -745,35 +745,35 @@ defm SMLALL_VG4_M4ZZI_HtoD : sme2_mla_ll_array_vg4_index_64b<"smlall", 0b00>;
defm SMLALL_MZZ_HtoD : sme2_mla_ll_array_single<"smlall", 0b1000, MatrixOp64, ZPR16, ZPR4b16, nxv8i16, int_aarch64_sme_smla_za64_single_vg4x1>;
defm SMLALL_VG2_M2ZZ_HtoD : sme2_mla_ll_array_vg2_single<"smlall", 0b10000, MatrixOp64, ZZ_h, ZPR4b16, nxv8i16, int_aarch64_sme_smla_za64_single_vg4x2>;
defm SMLALL_VG4_M4ZZ_HtoD : sme2_mla_ll_array_vg4_single<"smlall", 0b11000, MatrixOp64, ZZZZ_h, ZPR4b16, nxv8i16, int_aarch64_sme_smla_za64_single_vg4x4>;
defm SMLALL_VG2_M2Z2Z_HtoD : sme2_mla_ll_array_vg2_multi<"smlall", 0b1000, MatrixOp64, ZZ_h_mul_r>;
defm SMLALL_VG4_M4Z4Z_HtoD : sme2_mla_ll_array_vg4_multi<"smlall", 0b1000, MatrixOp64, ZZZZ_h_mul_r>;
defm SMLALL_VG2_M2Z2Z_HtoD : sme2_mla_ll_array_vg2_multi<"smlall", 0b1000, MatrixOp64, ZZ_h_mul_r, nxv8i16, int_aarch64_sme_smla_za64_vg4x2>;
defm SMLALL_VG4_M4Z4Z_HtoD : sme2_mla_ll_array_vg4_multi<"smlall", 0b1000, MatrixOp64, ZZZZ_h_mul_r, nxv8i16, int_aarch64_sme_smla_za64_vg4x4>;

def SMLSLL_MZZI_HtoD : sme2_mla_ll_array_index_64b<"smlsll", 0b01>;
defm SMLSLL_VG2_M2ZZI_HtoD : sme2_mla_ll_array_vg2_index_64b<"smlsll", 0b01>;
defm SMLSLL_VG4_M4ZZI_HtoD : sme2_mla_ll_array_vg4_index_64b<"smlsll", 0b01>;
defm SMLSLL_MZZ_HtoD : sme2_mla_ll_array_single<"smlsll", 0b1010, MatrixOp64, ZPR16, ZPR4b16, nxv8i16, int_aarch64_sme_smls_za64_single_vg4x1>;
defm SMLSLL_VG2_M2ZZ_HtoD : sme2_mla_ll_array_vg2_single<"smlsll", 0b10010, MatrixOp64, ZZ_h, ZPR4b16, nxv8i16, int_aarch64_sme_smls_za64_single_vg4x2>;
defm SMLSLL_VG4_M4ZZ_HtoD : sme2_mla_ll_array_vg4_single<"smlsll", 0b11010, MatrixOp64, ZZZZ_h, ZPR4b16, nxv8i16, int_aarch64_sme_smls_za64_single_vg4x4>;
defm SMLSLL_VG2_M2Z2Z_HtoD : sme2_mla_ll_array_vg2_multi<"smlsll", 0b1010, MatrixOp64, ZZ_h_mul_r>;
defm SMLSLL_VG4_M4Z4Z_HtoD : sme2_mla_ll_array_vg4_multi<"smlsll", 0b1010, MatrixOp64, ZZZZ_h_mul_r>;
defm SMLSLL_VG2_M2Z2Z_HtoD : sme2_mla_ll_array_vg2_multi<"smlsll", 0b1010, MatrixOp64, ZZ_h_mul_r, nxv8i16, int_aarch64_sme_smls_za64_vg4x2>;
defm SMLSLL_VG4_M4Z4Z_HtoD : sme2_mla_ll_array_vg4_multi<"smlsll", 0b1010, MatrixOp64, ZZZZ_h_mul_r, nxv8i16, int_aarch64_sme_smls_za64_vg4x4>;

def UMLALL_MZZI_HtoD : sme2_mla_ll_array_index_64b<"umlall", 0b10>;
defm UMLALL_VG2_M2ZZI_HtoD : sme2_mla_ll_array_vg2_index_64b<"umlall", 0b10>;
defm UMLALL_VG4_M4ZZI_HtoD : sme2_mla_ll_array_vg4_index_64b<"umlall", 0b10>;
defm UMLALL_MZZ_HtoD : sme2_mla_ll_array_single<"umlall", 0b1100, MatrixOp64, ZPR16, ZPR4b16, nxv8i16, int_aarch64_sme_umla_za64_single_vg4x1>;
defm UMLALL_VG2_M2ZZ_HtoD : sme2_mla_ll_array_vg2_single<"umlall", 0b10100, MatrixOp64, ZZ_h, ZPR4b16, nxv8i16, int_aarch64_sme_umla_za64_single_vg4x2>;
defm UMLALL_VG4_M4ZZ_HtoD : sme2_mla_ll_array_vg4_single<"umlall", 0b11100, MatrixOp64, ZZZZ_h, ZPR4b16, nxv8i16, int_aarch64_sme_umla_za64_single_vg4x4>;
defm UMLALL_VG2_M2Z2Z_HtoD : sme2_mla_ll_array_vg2_multi<"umlall", 0b1100, MatrixOp64, ZZ_h_mul_r>;
defm UMLALL_VG4_M4Z4Z_HtoD : sme2_mla_ll_array_vg4_multi<"umlall", 0b1100, MatrixOp64, ZZZZ_h_mul_r>;
defm UMLALL_VG2_M2Z2Z_HtoD : sme2_mla_ll_array_vg2_multi<"umlall", 0b1100, MatrixOp64, ZZ_h_mul_r, nxv8i16, int_aarch64_sme_umla_za64_vg4x2>;
defm UMLALL_VG4_M4Z4Z_HtoD : sme2_mla_ll_array_vg4_multi<"umlall", 0b1100, MatrixOp64, ZZZZ_h_mul_r, nxv8i16, int_aarch64_sme_umla_za64_vg4x4>;

def UMLSLL_MZZI_HtoD : sme2_mla_ll_array_index_64b<"umlsll", 0b11>;
defm UMLSLL_VG2_M2ZZI_HtoD : sme2_mla_ll_array_vg2_index_64b<"umlsll", 0b11>;
defm UMLSLL_VG4_M4ZZI_HtoD : sme2_mla_ll_array_vg4_index_64b<"umlsll", 0b11>;
defm UMLSLL_MZZ_HtoD : sme2_mla_ll_array_single<"umlsll", 0b1110, MatrixOp64, ZPR16, ZPR4b16, nxv8i16, int_aarch64_sme_umls_za64_single_vg4x1>;
defm UMLSLL_VG2_M2ZZ_HtoD : sme2_mla_ll_array_vg2_single<"umlsll", 0b10110, MatrixOp64, ZZ_h, ZPR4b16, nxv8i16, int_aarch64_sme_umls_za64_single_vg4x2>;
defm UMLSLL_VG4_M4ZZ_HtoD : sme2_mla_ll_array_vg4_single<"umlsll", 0b11110, MatrixOp64, ZZZZ_h, ZPR4b16, nxv8i16, int_aarch64_sme_umls_za64_single_vg4x4>;
defm UMLSLL_VG2_M2Z2Z_HtoD : sme2_mla_ll_array_vg2_multi<"umlsll", 0b1110, MatrixOp64, ZZ_h_mul_r>;
defm UMLSLL_VG4_M4Z4Z_HtoD : sme2_mla_ll_array_vg4_multi<"umlsll", 0b1110, MatrixOp64, ZZZZ_h_mul_r>;
defm UMLSLL_VG2_M2Z2Z_HtoD : sme2_mla_ll_array_vg2_multi<"umlsll", 0b1110, MatrixOp64, ZZ_h_mul_r, nxv8i16, int_aarch64_sme_umls_za64_vg4x2>;
defm UMLSLL_VG4_M4Z4Z_HtoD : sme2_mla_ll_array_vg4_multi<"umlsll", 0b1110, MatrixOp64, ZZZZ_h_mul_r, nxv8i16, int_aarch64_sme_umls_za64_vg4x4>;
}

let Predicates = [HasSME2, HasSMEF64F64] in {
Expand Down
18 changes: 14 additions & 4 deletions llvm/lib/Target/AArch64/SMEInstrFormats.td
Expand Up @@ -2817,8 +2817,13 @@ class sme2_mla_ll_array_vg2_multi<bits<4> op, MatrixOperand matrix_ty,

multiclass sme2_mla_ll_array_vg2_multi<string mnemonic, bits<4> op,
MatrixOperand matrix_ty,
RegisterOperand vector_ty> {
def NAME : sme2_mla_ll_array_vg2_multi<op, matrix_ty, vector_ty, mnemonic>;
RegisterOperand vector_ty,
ValueType vt, SDPatternOperator intrinsic> {
def NAME : sme2_mla_ll_array_vg2_multi<op, matrix_ty, vector_ty, mnemonic>, SMEPseudo2Instr<NAME, 1>;

def _PSEUDO : sme2_za_array_2op_multi_multi_pseudo<NAME, uimm1s4range, vector_ty, SMEMatrixArray>;

def : SME2_ZA_TwoOp_VG2_Multi_Multi_Pat<NAME, intrinsic, uimm1s4range, vt, tileslicerange1s4>;

def : InstAlias<mnemonic # "\t$ZAda[$Rv, $imm], $Zn, $Zm",
(!cast<Instruction>(NAME) matrix_ty:$ZAda, MatrixIndexGPR32Op8_11:$Rv, uimm1s4range:$imm, vector_ty:$Zn, vector_ty:$Zm), 0>;
Expand Down Expand Up @@ -2856,8 +2861,13 @@ class sme2_mla_ll_array_vg4_multi<bits<4> op,MatrixOperand matrix_ty,

multiclass sme2_mla_ll_array_vg4_multi<string mnemonic, bits<4> op,
MatrixOperand matrix_ty,
RegisterOperand vector_ty> {
def NAME : sme2_mla_ll_array_vg4_multi<op, matrix_ty, vector_ty, mnemonic>;
RegisterOperand vector_ty,
ValueType vt, SDPatternOperator intrinsic> {
def NAME : sme2_mla_ll_array_vg4_multi<op, matrix_ty, vector_ty, mnemonic>, SMEPseudo2Instr<NAME, 1>;

def _PSEUDO : sme2_za_array_2op_multi_multi_pseudo<NAME, uimm1s4range, vector_ty, SMEMatrixArray>;

def : SME2_ZA_TwoOp_VG4_Multi_Multi_Pat<NAME, intrinsic, uimm1s4range, vt, tileslicerange1s4>;

def : InstAlias<mnemonic # "\t$ZAda[$Rv, $imm], $Zn, $Zm",
(!cast<Instruction>(NAME) matrix_ty:$ZAda, MatrixIndexGPR32Op8_11:$Rv, uimm1s4range:$imm, vector_ty:$Zn, vector_ty:$Zm), 0>;
Expand Down

0 comments on commit 028c722

Please sign in to comment.