Skip to content

Commit

Permalink
[AArch64][SME2] Add multi-vector saturating rounding shift right intr…
Browse files Browse the repository at this point in the history
…insics

Adds IR intrinsics for the following SME2 instructions:
 - sqrshr (2 and 4 vector)
 - uqrshr (2 and 4 vector)
 - sqrshru (2 and 4 vector)
 - sqrshrn (4 vector)
 - uqrshrn (4 vector)
 - sqrshrun (4 vector)

Also adds intrinsics for the following SVE2p1 instructions:
 - sqrshrn (2 vector)
 - uqrshrn (2 vector)
 - sqrshrun (2 vector)

NOTE: These intrinsics are still in development and are subject to future changes.

Reviewed By: CarolineConcatto

Differential Revision: https://reviews.llvm.org/D142466
  • Loading branch information
kmclaughlin-arm committed Jan 26, 2023
1 parent 0488ecb commit 034ad11
Show file tree
Hide file tree
Showing 7 changed files with 350 additions and 20 deletions.
31 changes: 31 additions & 0 deletions llvm/include/llvm/IR/IntrinsicsAArch64.td
Original file line number Diff line number Diff line change
Expand Up @@ -2771,6 +2771,19 @@ let TargetPrefix = "aarch64" in {
LLVMMatchType<0>, llvm_i32_ty],
[ImmArg<ArgIndex<6>>]>;

class SME2_VG2_Multi_Imm_Intrinsic
: DefaultAttrsIntrinsic<[LLVMSubdivide2VectorType<0>],
[llvm_anyvector_ty, LLVMMatchType<0>,
llvm_i32_ty],
[IntrNoMem, ImmArg<ArgIndex<2>>]>;

class SME2_VG4_Multi_Imm_Intrinsic
: DefaultAttrsIntrinsic<[LLVMSubdivide4VectorType<0>],
[llvm_anyvector_ty, LLVMMatchType<0>,
LLVMMatchType<0>, LLVMMatchType<0>,
llvm_i32_ty],
[IntrNoMem, ImmArg<ArgIndex<4>>]>;

class SME2_ZA_Write_VG2_Intrinsic
: DefaultAttrsIntrinsic<[],
[llvm_i32_ty,
Expand Down Expand Up @@ -2932,6 +2945,24 @@ let TargetPrefix = "aarch64" in {
def int_aarch64_sme_fmla_lane_vg1x4 : SME2_Matrix_ArrayVector_VG4_Multi_Index_Intrinsic;
def int_aarch64_sme_fmls_lane_vg1x4 : SME2_Matrix_ArrayVector_VG4_Multi_Index_Intrinsic;

// Multi-vector saturating rounding shift right intrinsics

def int_aarch64_sve_sqrshr_x2 : SME2_VG2_Multi_Imm_Intrinsic;
def int_aarch64_sve_uqrshr_x2 : SME2_VG2_Multi_Imm_Intrinsic;
def int_aarch64_sve_sqrshr_x4 : SME2_VG4_Multi_Imm_Intrinsic;
def int_aarch64_sve_uqrshr_x4 : SME2_VG4_Multi_Imm_Intrinsic;

def int_aarch64_sve_sqrshrn_x2 : SME2_VG2_Multi_Imm_Intrinsic;
def int_aarch64_sve_uqrshrn_x2 : SME2_VG2_Multi_Imm_Intrinsic;
def int_aarch64_sve_sqrshrn_x4 : SME2_VG4_Multi_Imm_Intrinsic;
def int_aarch64_sve_uqrshrn_x4 : SME2_VG4_Multi_Imm_Intrinsic;

def int_aarch64_sve_sqrshru_x2 : SME2_VG2_Multi_Imm_Intrinsic;
def int_aarch64_sve_sqrshru_x4 : SME2_VG4_Multi_Imm_Intrinsic;

def int_aarch64_sve_sqrshrun_x2 : SME2_VG2_Multi_Imm_Intrinsic;
def int_aarch64_sve_sqrshrun_x4 : SME2_VG4_Multi_Imm_Intrinsic;

//
// Multi-vector multiply-add/subtract long
//
Expand Down
18 changes: 9 additions & 9 deletions llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -619,18 +619,18 @@ defm MOVA_VG4_MXI4Z : sme2_mova_vec_to_array_vg4_multi<"mova", int_aarch64_sme_
defm MOVA_VG2_2ZMXI : sme2_mova_array_to_vec_vg2_multi<0b000, "mova">;
defm MOVA_VG4_4ZMXI : sme2_mova_array_to_vec_vg4_multi<0b1000, "mova">;

defm SQRSHR_VG2_Z2ZI : sme2_sat_shift_vector_vg2<"sqrshr", 0b0, 0b0>;
defm SQRSHR_VG4_Z4ZI : sme2_sat_shift_vector_vg4<"sqrshr", 0b000>;
defm SQRSHR_VG2_Z2ZI : sme2_sat_shift_vector_vg2<"sqrshr", 0b0, 0b0, int_aarch64_sve_sqrshr_x2>;
defm SQRSHR_VG4_Z4ZI : sme2_sat_shift_vector_vg4<"sqrshr", 0b000, int_aarch64_sve_sqrshr_x4>;

defm UQRSHR_VG2_Z2ZI : sme2_sat_shift_vector_vg2<"uqrshr", 0b0, 0b1>;
defm UQRSHR_VG4_Z4ZI : sme2_sat_shift_vector_vg4<"uqrshr", 0b001>;
defm UQRSHR_VG2_Z2ZI : sme2_sat_shift_vector_vg2<"uqrshr", 0b0, 0b1, int_aarch64_sve_uqrshr_x2>;
defm UQRSHR_VG4_Z4ZI : sme2_sat_shift_vector_vg4<"uqrshr", 0b001, int_aarch64_sve_uqrshr_x4>;

defm SQRSHRU_VG2_Z2ZI : sme2_sat_shift_vector_vg2<"sqrshru", 0b1, 0b0>;
defm SQRSHRU_VG4_Z4ZI : sme2_sat_shift_vector_vg4<"sqrshru", 0b010>;
defm SQRSHRU_VG2_Z2ZI : sme2_sat_shift_vector_vg2<"sqrshru", 0b1, 0b0, int_aarch64_sve_sqrshru_x2>;
defm SQRSHRU_VG4_Z4ZI : sme2_sat_shift_vector_vg4<"sqrshru", 0b010, int_aarch64_sve_sqrshru_x4>;

defm SQRSHRN_VG4_Z4ZI : sme2_sat_shift_vector_vg4<"sqrshrn", 0b100>;
defm UQRSHRN_VG4_Z4ZI : sme2_sat_shift_vector_vg4<"uqrshrn", 0b101>;
defm SQRSHRUN_VG4_Z4ZI : sme2_sat_shift_vector_vg4<"sqrshrun", 0b110>;
defm SQRSHRN_VG4_Z4ZI : sme2_sat_shift_vector_vg4<"sqrshrn", 0b100, int_aarch64_sve_sqrshrn_x4>;
defm UQRSHRN_VG4_Z4ZI : sme2_sat_shift_vector_vg4<"uqrshrn", 0b101, int_aarch64_sve_uqrshrn_x4>;
defm SQRSHRUN_VG4_Z4ZI : sme2_sat_shift_vector_vg4<"sqrshrun", 0b110, int_aarch64_sve_sqrshrun_x4>;

defm SEL_VG2_2ZP2Z2Z: sme2_sel_vector_vg2<"sel">;
defm SEL_VG4_4ZP4Z4Z: sme2_sel_vector_vg4<"sel">;
Expand Down
6 changes: 3 additions & 3 deletions llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -3694,9 +3694,9 @@ defm PTRUE_C : sve2p1_ptrue_pn<"ptrue">;
defm SQCVTN_Z2Z_StoH : sve2p1_multi_vec_extract_narrow<"sqcvtn", 0b00, int_aarch64_sve_sqcvtn_x2>;
defm UQCVTN_Z2Z_StoH : sve2p1_multi_vec_extract_narrow<"uqcvtn", 0b01, int_aarch64_sve_uqcvtn_x2>;
defm SQCVTUN_Z2Z_StoH : sve2p1_multi_vec_extract_narrow<"sqcvtun", 0b10, int_aarch64_sve_sqcvtun_x2>;
defm SQRSHRN_Z2ZI_StoH : sve2p1_multi_vec_shift_narrow<"sqrshrn", 0b101>;
defm UQRSHRN_Z2ZI_StoH : sve2p1_multi_vec_shift_narrow<"uqrshrn", 0b111>;
defm SQRSHRUN_Z2ZI_StoH : sve2p1_multi_vec_shift_narrow<"sqrshrun", 0b001>;
defm SQRSHRN_Z2ZI_StoH : sve2p1_multi_vec_shift_narrow<"sqrshrn", 0b101, int_aarch64_sve_sqrshrn_x2>;
defm UQRSHRN_Z2ZI_StoH : sve2p1_multi_vec_shift_narrow<"uqrshrn", 0b111, int_aarch64_sve_uqrshrn_x2>;
defm SQRSHRUN_Z2ZI_StoH : sve2p1_multi_vec_shift_narrow<"sqrshrun", 0b001, int_aarch64_sve_sqrshrun_x2>;

// Load to two registers
def LD1B_2Z : sve2p1_mem_cld_ss_2z<"ld1b", 0b00, 0b0, ZZ_b_mul_r, GPR64shifted8>;
Expand Down
24 changes: 19 additions & 5 deletions llvm/lib/Target/AArch64/SMEInstrFormats.td
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,15 @@ class SME2_ZA_TwoOp_VG4_Multi_Index_Pat<string name, SDPatternOperator intrinsic
(REG_SEQUENCE ZPR4Mul4, vt:$Zn1, zsub0, vt:$Zn2, zsub1, vt:$Zn3, zsub2, vt:$Zn4, zsub3),
zpr_ty:$Zm, imm_ty:$i)>;

class SME2_Sat_Shift_VG2_Pat<string name, SDPatternOperator intrinsic, ValueType out_vt, ValueType in_vt, Operand imm_ty>
: Pat<(out_vt (intrinsic in_vt:$Zn1, in_vt:$Zn2, (i32 imm_ty:$i))),
(!cast<Instruction>(name) (REG_SEQUENCE ZPR2Mul2, in_vt:$Zn1, zsub0, in_vt:$Zn2, zsub1), imm_ty:$i)>;

class SME2_Sat_Shift_VG4_Pat<string name, SDPatternOperator intrinsic, ValueType out_vt, ValueType in_vt, Operand imm_ty>
: Pat<(out_vt (intrinsic in_vt:$Zn1, in_vt:$Zn2, in_vt:$Zn3, in_vt:$Zn4, (i32 imm_ty:$i))),
(!cast<Instruction>(name) (REG_SEQUENCE ZPR4Mul4, in_vt:$Zn1, zsub0, in_vt:$Zn2, zsub1, in_vt:$Zn3, zsub2, in_vt:$Zn4, zsub3),
imm_ty:$i)>;

class SME2_Cvt_VG4_Pat<string name, SDPatternOperator intrinsic, ValueType out_vt, ValueType in_vt>
: Pat<(out_vt (intrinsic in_vt:$Zn1, in_vt:$Zn2, in_vt:$Zn3, in_vt:$Zn4)),
(!cast<Instruction>(name) (REG_SEQUENCE ZPR4Mul4, in_vt:$Zn1, zsub0, in_vt:$Zn2, zsub1, in_vt:$Zn3, zsub2, in_vt:$Zn4, zsub3))>;
Expand Down Expand Up @@ -3999,7 +4008,7 @@ multiclass sme2_mova_array_to_vec_vg4_multi<bits<4> opc, string mnemonic> {
//===----------------------------------------------------------------------===//
// SME2 multi-vec saturating shift right narrow
class sme2_sat_shift_vector_vg2<string mnemonic, bit op, bit u>
: I<(outs ZPR16:$Zd), (ins ZZ_s_mul_r:$Zn, vecshiftR16:$imm4),
: I<(outs ZPR16:$Zd), (ins ZZ_s_mul_r:$Zn, tvecshiftR16:$imm4),
mnemonic, "\t$Zd, $Zn, $imm4",
"", []>, Sched<[]> {
bits<4> imm4;
Expand All @@ -4014,8 +4023,10 @@ class sme2_sat_shift_vector_vg2<string mnemonic, bit op, bit u>
let Inst{4-0} = Zd;
}

multiclass sme2_sat_shift_vector_vg2<string mnemonic, bit op, bit u> {
multiclass sme2_sat_shift_vector_vg2<string mnemonic, bit op, bit u, SDPatternOperator intrinsic> {
def _H : sme2_sat_shift_vector_vg2<mnemonic, op, u>;

def : SME2_Sat_Shift_VG2_Pat<NAME # _H, intrinsic, nxv8i16, nxv4i32, tvecshiftR16>;
}

class sme2_sat_shift_vector_vg4<bits<2> sz, bits<3> op, ZPRRegOp zpr_ty,
Expand All @@ -4037,18 +4048,21 @@ class sme2_sat_shift_vector_vg4<bits<2> sz, bits<3> op, ZPRRegOp zpr_ty,
let Inst{4-0} = Zd;
}

multiclass sme2_sat_shift_vector_vg4<string mnemonic, bits<3> op> {
def _B : sme2_sat_shift_vector_vg4<{0,1}, op, ZPR8, ZZZZ_s_mul_r, vecshiftR32,
multiclass sme2_sat_shift_vector_vg4<string mnemonic, bits<3> op, SDPatternOperator intrinsic> {
def _B : sme2_sat_shift_vector_vg4<{0,1}, op, ZPR8, ZZZZ_s_mul_r, tvecshiftR32,
mnemonic>{
bits<5> imm;
let Inst{20-16} = imm;
}
def _H : sme2_sat_shift_vector_vg4<{1,?}, op, ZPR16, ZZZZ_d_mul_r, vecshiftR64,
def _H : sme2_sat_shift_vector_vg4<{1,?}, op, ZPR16, ZZZZ_d_mul_r, tvecshiftR64,
mnemonic> {
bits<6> imm;
let Inst{22} = imm{5};
let Inst{20-16} = imm{4-0};
}

def : SME2_Sat_Shift_VG4_Pat<NAME # _B, intrinsic, nxv16i8, nxv4i32, tvecshiftR32>;
def : SME2_Sat_Shift_VG4_Pat<NAME # _H, intrinsic, nxv8i16, nxv2i64, tvecshiftR64>;
}

//===----------------------------------------------------------------------===//
Expand Down
12 changes: 9 additions & 3 deletions llvm/lib/Target/AArch64/SVEInstrFormats.td
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,10 @@ class SVE_Shift_Add_All_Active_Pat<ValueType vtd, SDPatternOperator op, ValueTyp
: Pat<(vtd (add vt1:$Op1, (op (pt (SVEAllActive)), vt2:$Op2, vt3:$Op3))),
(inst $Op1, $Op2, $Op3)>;

class SVE2p1_Sat_Shift_VG2_Pat<string name, SDPatternOperator intrinsic, ValueType out_vt, ValueType in_vt, Operand imm_ty>
: Pat<(out_vt (intrinsic in_vt:$Zn1, in_vt:$Zn2, (i32 imm_ty:$i))),
(!cast<Instruction>(name) (REG_SEQUENCE ZPR2Mul2, in_vt:$Zn1, zsub0, in_vt:$Zn2, zsub1), imm_ty:$i)>;

class SVE2p1_Cvt_VG2_Pat<string name, SDPatternOperator intrinsic, ValueType out_vt, ValueType in_vt>
: Pat<(out_vt (intrinsic in_vt:$Zn1, in_vt:$Zn2)),
(!cast<Instruction>(name) (REG_SEQUENCE ZPR2Mul2, in_vt:$Zn1, zsub0, in_vt:$Zn2, zsub1))>;
Expand Down Expand Up @@ -9126,7 +9130,7 @@ multiclass sve2p1_multi_vec_extract_narrow<string mnemonic, bits<2> opc, SDPatte

// SVE2 multi-vec shift narrow
class sve2p1_multi_vec_shift_narrow<string mnemonic, bits<3> opc, bits<2> tsz>
: I<(outs ZPR16:$Zd), (ins ZZ_s_mul_r:$Zn, vecshiftR16:$imm4),
: I<(outs ZPR16:$Zd), (ins ZZ_s_mul_r:$Zn, tvecshiftR16:$imm4),
mnemonic, "\t$Zd, $Zn, $imm4",
"", []>, Sched<[]> {
bits<5> Zd;
Expand All @@ -9147,8 +9151,10 @@ class sve2p1_multi_vec_shift_narrow<string mnemonic, bits<3> opc, bits<2> tsz>
let hasSideEffects = 0;
}

multiclass sve2p1_multi_vec_shift_narrow<string mnemonic, bits<3> opc> {
def : sve2p1_multi_vec_shift_narrow<mnemonic, opc, 0b01>;
multiclass sve2p1_multi_vec_shift_narrow<string mnemonic, bits<3> opc, SDPatternOperator intrinsic> {
def NAME : sve2p1_multi_vec_shift_narrow<mnemonic, opc, 0b01>;

def : SVE2p1_Sat_Shift_VG2_Pat<NAME, intrinsic, nxv8i16, nxv4i32, tvecshiftR16>;
}


Expand Down
Loading

0 comments on commit 034ad11

Please sign in to comment.