[AArch64][SME] Remove immediate argument restriction for svldr and svstr (#68565)

The svldr_vnum and svstr_vnum builtins always modify the base register
and tile slice and pass an immediate offset of zero to the instruction,
even when the offset provided to the builtin is an immediate. This patch
optimises the builtins' output when the offset is an immediate: the
offset is passed directly to the instruction, so the base register and
tile slice updates are no longer needed.
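
As a rough illustration of the intended effect (reconstructed from the test updates in this commit; the function below is an editorial sketch, not part of the commit):

```cpp
// Hedged sketch based on the tests below; the function name is invented.
#include <arm_sme_draft_spec_subject_to_change.h>

void example(uint32_t slice_base, const void *ptr) {
  // Before this patch, clang expanded the immediate offset itself:
  //   %svlb      = call i64 @llvm.aarch64.sme.cntsb()
  //   %mulvl     = mul i64 %svlb, 15
  //   %ptr2      = getelementptr i8, ptr %ptr, i64 %mulvl
  //   %tileslice = add i32 %slice_base, 15
  //   call void @llvm.aarch64.sme.ldr(i32 %tileslice, ptr %ptr2)
  // After this patch, the immediate is passed straight through:
  //   call void @llvm.aarch64.sme.ldr(i32 %slice_base, ptr %ptr, i32 15)
  svldr_vnum_za(slice_base, ptr, 15);
}
```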
SamTebbs33 committed Nov 20, 2023
1 parent bfd3734 commit f7b5c25
Showing 11 changed files with 569 additions and 169 deletions.
16 changes: 4 additions & 12 deletions clang/lib/CodeGen/CGBuiltin.cpp
@@ -9886,18 +9886,10 @@ Value *CodeGenFunction::EmitSMEZero(const SVETypeFlags &TypeFlags,
 Value *CodeGenFunction::EmitSMELdrStr(const SVETypeFlags &TypeFlags,
                                       SmallVectorImpl<Value *> &Ops,
                                       unsigned IntID) {
-  if (Ops.size() == 3) {
-    Function *Cntsb = CGM.getIntrinsic(Intrinsic::aarch64_sme_cntsb);
-    llvm::Value *CntsbCall = Builder.CreateCall(Cntsb, {}, "svlb");
-
-    llvm::Value *VecNum = Ops[2];
-    llvm::Value *MulVL = Builder.CreateMul(CntsbCall, VecNum, "mulvl");
-
-    Ops[1] = Builder.CreateGEP(Int8Ty, Ops[1], MulVL);
-    Ops[0] = Builder.CreateAdd(
-        Ops[0], Builder.CreateIntCast(VecNum, Int32Ty, true), "tileslice");
-    Ops.erase(&Ops[2]);
-  }
+  if (Ops.size() == 2)
+    Ops.push_back(Builder.getInt32(0));
+  else
+    Ops[2] = Builder.CreateIntCast(Ops[2], Int32Ty, true);
   Function *F = CGM.getIntrinsic(IntID, {});
   return Builder.CreateCall(F, Ops);
 }
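
For reference, a minimal self-contained sketch of what the simplified codegen path now emits (editorial, not commit code; it mirrors `EmitSMELdrStr`'s handling of the two-operand `svldr_za` form, where an explicit zero vnum is appended):

```cpp
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicsAArch64.h"
#include "llvm/IR/Module.h"

// Sketch only: build a call to the widened 3-operand llvm.aarch64.sme.ldr.
// The helper name and parameters are invented for illustration.
static void emitSmeLdr(llvm::IRBuilder<> &Builder, llvm::Module &M,
                       llvm::Value *SliceBase, llvm::Value *BasePtr) {
  llvm::Function *Ldr = llvm::Intrinsic::getDeclaration(
      &M, llvm::Intrinsic::aarch64_sme_ldr);
  llvm::Value *Args[] = {SliceBase,            // i32 tile slice base
                         BasePtr,              // ptr to the ZA save area
                         Builder.getInt32(0)}; // i32 vnum, now always present
  Builder.CreateCall(Ldr, Args);
}
```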
95 changes: 31 additions & 64 deletions clang/test/CodeGen/aarch64-sme-intrinsics/acle_sme_ldr.c
@@ -6,86 +6,53 @@

 #include <arm_sme_draft_spec_subject_to_change.h>
 
-// CHECK-C-LABEL: define dso_local void @test_svldr_vnum_za(
-// CHECK-C-SAME: i32 noundef [[SLICE_BASE:%.*]], ptr noundef [[PTR:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
-// CHECK-C-NEXT: entry:
-// CHECK-C-NEXT: tail call void @llvm.aarch64.sme.ldr(i32 [[SLICE_BASE]], ptr [[PTR]])
-// CHECK-C-NEXT: ret void
-//
-// CHECK-CXX-LABEL: define dso_local void @_Z18test_svldr_vnum_zajPKv(
-// CHECK-CXX-SAME: i32 noundef [[SLICE_BASE:%.*]], ptr noundef [[PTR:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
-// CHECK-CXX-NEXT: entry:
-// CHECK-CXX-NEXT: tail call void @llvm.aarch64.sme.ldr(i32 [[SLICE_BASE]], ptr [[PTR]])
-// CHECK-CXX-NEXT: ret void
+// CHECK-C-LABEL: @test_svldr_vnum_za(
+// CHECK-CXX-LABEL: @_Z18test_svldr_vnum_zajPKv(
+// CHECK-NEXT: entry:
+// CHECK-NEXT: tail call void @llvm.aarch64.sme.ldr(i32 [[SLICE_BASE:%.*]], ptr [[PTR:%.*]], i32 0)
+// CHECK-NEXT: ret void
 //
 void test_svldr_vnum_za(uint32_t slice_base, const void *ptr) {
   svldr_vnum_za(slice_base, ptr, 0);
 }
 
-// CHECK-C-LABEL: define dso_local void @test_svldr_vnum_za_1(
-// CHECK-C-SAME: i32 noundef [[SLICE_BASE:%.*]], ptr noundef [[PTR:%.*]]) local_unnamed_addr #[[ATTR0]] {
-// CHECK-C-NEXT: entry:
-// CHECK-C-NEXT: [[SVLB:%.*]] = tail call i64 @llvm.aarch64.sme.cntsb()
-// CHECK-C-NEXT: [[MULVL:%.*]] = mul i64 [[SVLB]], 15
-// CHECK-C-NEXT: [[TMP0:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[MULVL]]
-// CHECK-C-NEXT: [[TILESLICE:%.*]] = add i32 [[SLICE_BASE]], 15
-// CHECK-C-NEXT: tail call void @llvm.aarch64.sme.ldr(i32 [[TILESLICE]], ptr [[TMP0]])
-// CHECK-C-NEXT: ret void
-//
-// CHECK-CXX-LABEL: define dso_local void @_Z20test_svldr_vnum_za_1jPKv(
-// CHECK-CXX-SAME: i32 noundef [[SLICE_BASE:%.*]], ptr noundef [[PTR:%.*]]) local_unnamed_addr #[[ATTR0]] {
-// CHECK-CXX-NEXT: entry:
-// CHECK-CXX-NEXT: [[SVLB:%.*]] = tail call i64 @llvm.aarch64.sme.cntsb()
-// CHECK-CXX-NEXT: [[MULVL:%.*]] = mul i64 [[SVLB]], 15
-// CHECK-CXX-NEXT: [[TMP0:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[MULVL]]
-// CHECK-CXX-NEXT: [[TILESLICE:%.*]] = add i32 [[SLICE_BASE]], 15
-// CHECK-CXX-NEXT: tail call void @llvm.aarch64.sme.ldr(i32 [[TILESLICE]], ptr [[TMP0]])
-// CHECK-CXX-NEXT: ret void
+// CHECK-C-LABEL: @test_svldr_vnum_za_1(
+// CHECK-CXX-LABEL: @_Z20test_svldr_vnum_za_1jPKv(
+// CHECK-NEXT: entry:
+// CHECK-NEXT: tail call void @llvm.aarch64.sme.ldr(i32 [[SLICE_BASE:%.*]], ptr [[PTR:%.*]], i32 15)
+// CHECK-NEXT: ret void
 //
 void test_svldr_vnum_za_1(uint32_t slice_base, const void *ptr) {
   svldr_vnum_za(slice_base, ptr, 15);
 }
 
-// CHECK-C-LABEL: define dso_local void @test_svldr_za(
-// CHECK-C-SAME: i32 noundef [[SLICE_BASE:%.*]], ptr noundef [[PTR:%.*]]) local_unnamed_addr #[[ATTR0]] {
-// CHECK-C-NEXT: entry:
-// CHECK-C-NEXT: tail call void @llvm.aarch64.sme.ldr(i32 [[SLICE_BASE]], ptr [[PTR]])
-// CHECK-C-NEXT: ret void
-//
-// CHECK-CXX-LABEL: define dso_local void @_Z13test_svldr_zajPKv(
-// CHECK-CXX-SAME: i32 noundef [[SLICE_BASE:%.*]], ptr noundef [[PTR:%.*]]) local_unnamed_addr #[[ATTR0]] {
-// CHECK-CXX-NEXT: entry:
-// CHECK-CXX-NEXT: tail call void @llvm.aarch64.sme.ldr(i32 [[SLICE_BASE]], ptr [[PTR]])
-// CHECK-CXX-NEXT: ret void
+// CHECK-C-LABEL: @test_svldr_za(
+// CHECK-CXX-LABEL: @_Z13test_svldr_zajPKv(
+// CHECK-NEXT: entry:
+// CHECK-NEXT: tail call void @llvm.aarch64.sme.ldr(i32 [[SLICE_BASE:%.*]], ptr [[PTR:%.*]], i32 0)
+// CHECK-NEXT: ret void
 //
 void test_svldr_za(uint32_t slice_base, const void *ptr) {
   svldr_za(slice_base, ptr);
 }
 
-// CHECK-C-LABEL: define dso_local void @test_svldr_vnum_za_var(
-// CHECK-C-SAME: i32 noundef [[SLICE_BASE:%.*]], ptr noundef [[PTR:%.*]], i64 noundef [[VNUM:%.*]]) local_unnamed_addr #[[ATTR0]] {
-// CHECK-C-NEXT: entry:
-// CHECK-C-NEXT: [[SVLB:%.*]] = tail call i64 @llvm.aarch64.sme.cntsb()
-// CHECK-C-NEXT: [[MULVL:%.*]] = mul i64 [[SVLB]], [[VNUM]]
-// CHECK-C-NEXT: [[TMP0:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[MULVL]]
-// CHECK-C-NEXT: [[TMP1:%.*]] = trunc i64 [[VNUM]] to i32
-// CHECK-C-NEXT: [[TILESLICE:%.*]] = add i32 [[TMP1]], [[SLICE_BASE]]
-// CHECK-C-NEXT: tail call void @llvm.aarch64.sme.ldr(i32 [[TILESLICE]], ptr [[TMP0]])
-// CHECK-C-NEXT: ret void
-//
-// CHECK-CXX-LABEL: define dso_local void @_Z22test_svldr_vnum_za_varjPKvl(
-// CHECK-CXX-SAME: i32 noundef [[SLICE_BASE:%.*]], ptr noundef [[PTR:%.*]], i64 noundef [[VNUM:%.*]]) local_unnamed_addr #[[ATTR0]] {
-// CHECK-CXX-NEXT: entry:
-// CHECK-CXX-NEXT: [[SVLB:%.*]] = tail call i64 @llvm.aarch64.sme.cntsb()
-// CHECK-CXX-NEXT: [[MULVL:%.*]] = mul i64 [[SVLB]], [[VNUM]]
-// CHECK-CXX-NEXT: [[TMP0:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[MULVL]]
-// CHECK-CXX-NEXT: [[TMP1:%.*]] = trunc i64 [[VNUM]] to i32
-// CHECK-CXX-NEXT: [[TILESLICE:%.*]] = add i32 [[TMP1]], [[SLICE_BASE]]
-// CHECK-CXX-NEXT: tail call void @llvm.aarch64.sme.ldr(i32 [[TILESLICE]], ptr [[TMP0]])
-// CHECK-CXX-NEXT: ret void
+// CHECK-C-LABEL: @test_svldr_vnum_za_var(
+// CHECK-CXX-LABEL: @_Z22test_svldr_vnum_za_varjPKvl(
+// CHECK-NEXT: entry:
+// CHECK-NEXT: [[TMP0:%.*]] = trunc i64 [[VNUM:%.*]] to i32
+// CHECK-NEXT: tail call void @llvm.aarch64.sme.ldr(i32 [[SLICE_BASE:%.*]], ptr [[PTR:%.*]], i32 [[TMP0]])
+// CHECK-NEXT: ret void
 //
 void test_svldr_vnum_za_var(uint32_t slice_base, const void *ptr, int64_t vnum) {
   svldr_vnum_za(slice_base, ptr, vnum);
 }
-//// NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
-// CHECK: {{.*}}
+
+// CHECK-C-LABEL: @test_svldr_vnum_za_2(
+// CHECK-CXX-LABEL: @_Z20test_svldr_vnum_za_2jPKv(
+// CHECK-NEXT: entry:
+// CHECK-NEXT: tail call void @llvm.aarch64.sme.ldr(i32 [[SLICE_BASE:%.*]], ptr [[PTR:%.*]], i32 16)
+// CHECK-NEXT: ret void
+//
+void test_svldr_vnum_za_2(uint32_t slice_base, const void *ptr) {
+  svldr_vnum_za(slice_base, ptr, 16);
+}
95 changes: 31 additions & 64 deletions clang/test/CodeGen/aarch64-sme-intrinsics/acle_sme_str.c
@@ -6,86 +6,53 @@

 #include <arm_sme_draft_spec_subject_to_change.h>
 
-// CHECK-C-LABEL: define dso_local void @test_svstr_vnum_za(
-// CHECK-C-SAME: i32 noundef [[SLICE_BASE:%.*]], ptr noundef [[PTR:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
-// CHECK-C-NEXT: entry:
-// CHECK-C-NEXT: tail call void @llvm.aarch64.sme.str(i32 [[SLICE_BASE]], ptr [[PTR]])
-// CHECK-C-NEXT: ret void
-//
-// CHECK-CXX-LABEL: define dso_local void @_Z18test_svstr_vnum_zajPv(
-// CHECK-CXX-SAME: i32 noundef [[SLICE_BASE:%.*]], ptr noundef [[PTR:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
-// CHECK-CXX-NEXT: entry:
-// CHECK-CXX-NEXT: tail call void @llvm.aarch64.sme.str(i32 [[SLICE_BASE]], ptr [[PTR]])
-// CHECK-CXX-NEXT: ret void
+// CHECK-C-LABEL: @test_svstr_vnum_za(
+// CHECK-CXX-LABEL: @_Z18test_svstr_vnum_zajPv(
+// CHECK-NEXT: entry:
+// CHECK-NEXT: tail call void @llvm.aarch64.sme.str(i32 [[SLICE_BASE:%.*]], ptr [[PTR:%.*]], i32 0)
+// CHECK-NEXT: ret void
 //
 void test_svstr_vnum_za(uint32_t slice_base, void *ptr) {
   svstr_vnum_za(slice_base, ptr, 0);
 }
 
-// CHECK-C-LABEL: define dso_local void @test_svstr_vnum_za_1(
-// CHECK-C-SAME: i32 noundef [[SLICE_BASE:%.*]], ptr noundef [[PTR:%.*]]) local_unnamed_addr #[[ATTR0]] {
-// CHECK-C-NEXT: entry:
-// CHECK-C-NEXT: [[SVLB:%.*]] = tail call i64 @llvm.aarch64.sme.cntsb()
-// CHECK-C-NEXT: [[MULVL:%.*]] = mul i64 [[SVLB]], 15
-// CHECK-C-NEXT: [[TMP0:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[MULVL]]
-// CHECK-C-NEXT: [[TILESLICE:%.*]] = add i32 [[SLICE_BASE]], 15
-// CHECK-C-NEXT: tail call void @llvm.aarch64.sme.str(i32 [[TILESLICE]], ptr [[TMP0]])
-// CHECK-C-NEXT: ret void
-//
-// CHECK-CXX-LABEL: define dso_local void @_Z20test_svstr_vnum_za_1jPv(
-// CHECK-CXX-SAME: i32 noundef [[SLICE_BASE:%.*]], ptr noundef [[PTR:%.*]]) local_unnamed_addr #[[ATTR0]] {
-// CHECK-CXX-NEXT: entry:
-// CHECK-CXX-NEXT: [[SVLB:%.*]] = tail call i64 @llvm.aarch64.sme.cntsb()
-// CHECK-CXX-NEXT: [[MULVL:%.*]] = mul i64 [[SVLB]], 15
-// CHECK-CXX-NEXT: [[TMP0:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[MULVL]]
-// CHECK-CXX-NEXT: [[TILESLICE:%.*]] = add i32 [[SLICE_BASE]], 15
-// CHECK-CXX-NEXT: tail call void @llvm.aarch64.sme.str(i32 [[TILESLICE]], ptr [[TMP0]])
-// CHECK-CXX-NEXT: ret void
+// CHECK-C-LABEL: @test_svstr_vnum_za_1(
+// CHECK-CXX-LABEL: @_Z20test_svstr_vnum_za_1jPv(
+// CHECK-NEXT: entry:
+// CHECK-NEXT: tail call void @llvm.aarch64.sme.str(i32 [[SLICE_BASE:%.*]], ptr [[PTR:%.*]], i32 15)
+// CHECK-NEXT: ret void
 //
 void test_svstr_vnum_za_1(uint32_t slice_base, void *ptr) {
   svstr_vnum_za(slice_base, ptr, 15);
 }
 
-// CHECK-C-LABEL: define dso_local void @test_svstr_za(
-// CHECK-C-SAME: i32 noundef [[SLICE_BASE:%.*]], ptr noundef [[PTR:%.*]]) local_unnamed_addr #[[ATTR0]] {
-// CHECK-C-NEXT: entry:
-// CHECK-C-NEXT: tail call void @llvm.aarch64.sme.str(i32 [[SLICE_BASE]], ptr [[PTR]])
-// CHECK-C-NEXT: ret void
-//
-// CHECK-CXX-LABEL: define dso_local void @_Z13test_svstr_zajPv(
-// CHECK-CXX-SAME: i32 noundef [[SLICE_BASE:%.*]], ptr noundef [[PTR:%.*]]) local_unnamed_addr #[[ATTR0]] {
-// CHECK-CXX-NEXT: entry:
-// CHECK-CXX-NEXT: tail call void @llvm.aarch64.sme.str(i32 [[SLICE_BASE]], ptr [[PTR]])
-// CHECK-CXX-NEXT: ret void
+// CHECK-C-LABEL: @test_svstr_za(
+// CHECK-CXX-LABEL: @_Z13test_svstr_zajPv(
+// CHECK-NEXT: entry:
+// CHECK-NEXT: tail call void @llvm.aarch64.sme.str(i32 [[SLICE_BASE:%.*]], ptr [[PTR:%.*]], i32 0)
+// CHECK-NEXT: ret void
 //
 void test_svstr_za(uint32_t slice_base, void *ptr) {
   svstr_za(slice_base, ptr);
 }
 
-// CHECK-C-LABEL: define dso_local void @test_svstr_vnum_za_var(
-// CHECK-C-SAME: i32 noundef [[SLICE_BASE:%.*]], ptr noundef [[PTR:%.*]], i64 noundef [[VNUM:%.*]]) local_unnamed_addr #[[ATTR0]] {
-// CHECK-C-NEXT: entry:
-// CHECK-C-NEXT: [[SVLB:%.*]] = tail call i64 @llvm.aarch64.sme.cntsb()
-// CHECK-C-NEXT: [[MULVL:%.*]] = mul i64 [[SVLB]], [[VNUM]]
-// CHECK-C-NEXT: [[TMP0:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[MULVL]]
-// CHECK-C-NEXT: [[TMP1:%.*]] = trunc i64 [[VNUM]] to i32
-// CHECK-C-NEXT: [[TILESLICE:%.*]] = add i32 [[TMP1]], [[SLICE_BASE]]
-// CHECK-C-NEXT: tail call void @llvm.aarch64.sme.str(i32 [[TILESLICE]], ptr [[TMP0]])
-// CHECK-C-NEXT: ret void
-//
-// CHECK-CXX-LABEL: define dso_local void @_Z22test_svstr_vnum_za_varjPvl(
-// CHECK-CXX-SAME: i32 noundef [[SLICE_BASE:%.*]], ptr noundef [[PTR:%.*]], i64 noundef [[VNUM:%.*]]) local_unnamed_addr #[[ATTR0]] {
-// CHECK-CXX-NEXT: entry:
-// CHECK-CXX-NEXT: [[SVLB:%.*]] = tail call i64 @llvm.aarch64.sme.cntsb()
-// CHECK-CXX-NEXT: [[MULVL:%.*]] = mul i64 [[SVLB]], [[VNUM]]
-// CHECK-CXX-NEXT: [[TMP0:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[MULVL]]
-// CHECK-CXX-NEXT: [[TMP1:%.*]] = trunc i64 [[VNUM]] to i32
-// CHECK-CXX-NEXT: [[TILESLICE:%.*]] = add i32 [[TMP1]], [[SLICE_BASE]]
-// CHECK-CXX-NEXT: tail call void @llvm.aarch64.sme.str(i32 [[TILESLICE]], ptr [[TMP0]])
-// CHECK-CXX-NEXT: ret void
+// CHECK-C-LABEL: @test_svstr_vnum_za_var(
+// CHECK-CXX-LABEL: @_Z22test_svstr_vnum_za_varjPvl(
+// CHECK-NEXT: entry:
+// CHECK-NEXT: [[TMP0:%.*]] = trunc i64 [[VNUM:%.*]] to i32
+// CHECK-NEXT: tail call void @llvm.aarch64.sme.str(i32 [[SLICE_BASE:%.*]], ptr [[PTR:%.*]], i32 [[TMP0]])
+// CHECK-NEXT: ret void
 //
 void test_svstr_vnum_za_var(uint32_t slice_base, void *ptr, int64_t vnum) {
   svstr_vnum_za(slice_base, ptr, vnum);
 }
-//// NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
-// CHECK: {{.*}}
+
+// CHECK-C-LABEL: @test_svstr_vnum_za_2(
+// CHECK-CXX-LABEL: @_Z20test_svstr_vnum_za_2jPv(
+// CHECK-NEXT: entry:
+// CHECK-NEXT: tail call void @llvm.aarch64.sme.str(i32 [[SLICE_BASE:%.*]], ptr [[PTR:%.*]], i32 16)
+// CHECK-NEXT: ret void
+//
+void test_svstr_vnum_za_2(uint32_t slice_base, void *ptr) {
  svstr_vnum_za(slice_base, ptr, 16);
}
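
Taken together, the ldr and str tests pin down the new contract: clang always emits the three-operand intrinsic and leaves offset legalisation to the backend. A hedged usage sketch (function name invented; attribute requirements elided as in these tests):

```cpp
#include <arm_sme_draft_spec_subject_to_change.h>

// Illustrative only: successive spills at consecutive vnums. Each call lowers
// to a single 3-operand @llvm.aarch64.sme.str; in-range immediates (0-15)
// fold into the instruction's offset field, and for larger immediates the
// backend can share one base/slice update per group of 16 (see the
// LowerSMELdrStr lowering below).
void spill_two_rows(uint32_t slice_base, void *buf) {
  svstr_vnum_za(slice_base, buf, 16); // base += 16*svl, slice += 16, imm 0
  svstr_vnum_za(slice_base, buf, 17); // same register update, imm 1
}
```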
13 changes: 9 additions & 4 deletions llvm/include/llvm/IR/IntrinsicsAArch64.td
@@ -2679,10 +2679,10 @@ let TargetPrefix = "aarch64" in {
   def int_aarch64_sme_st1q_vert : SME_Load_Store_Intrinsic<llvm_nxv1i1_ty>;
 
   // Spill + fill
-  def int_aarch64_sme_ldr : DefaultAttrsIntrinsic<
-      [], [llvm_i32_ty, llvm_ptr_ty]>;
-  def int_aarch64_sme_str : DefaultAttrsIntrinsic<
-      [], [llvm_i32_ty, llvm_ptr_ty]>;
+  class SME_LDR_STR_ZA_Intrinsic
+      : DefaultAttrsIntrinsic<[], [llvm_i32_ty, llvm_ptr_ty, llvm_i32_ty]>;
+  def int_aarch64_sme_ldr : SME_LDR_STR_ZA_Intrinsic;
+  def int_aarch64_sme_str : SME_LDR_STR_ZA_Intrinsic;
 
   class SME_TileToVector_Intrinsic
       : DefaultAttrsIntrinsic<[llvm_anyvector_ty],
@@ -3454,4 +3454,9 @@ let TargetPrefix = "aarch64" in {
   def int_aarch64_sve_sel_x2 : SVE2_VG2_Sel_Intrinsic;
   def int_aarch64_sve_sel_x4 : SVE2_VG4_Sel_Intrinsic;
 
+  class SME_LDR_STR_ZT_Intrinsic
+      : DefaultAttrsIntrinsic<[], [llvm_i32_ty, llvm_ptr_ty]>;
+  def int_aarch64_sme_ldr_zt : SME_LDR_STR_ZT_Intrinsic;
+  def int_aarch64_sme_str_zt : SME_LDR_STR_ZT_Intrinsic;
+
 }
90 changes: 90 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2406,6 +2406,8 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(AArch64ISD::FCMP)
     MAKE_CASE(AArch64ISD::STRICT_FCMP)
     MAKE_CASE(AArch64ISD::STRICT_FCMPE)
+    MAKE_CASE(AArch64ISD::SME_ZA_LDR)
+    MAKE_CASE(AArch64ISD::SME_ZA_STR)
     MAKE_CASE(AArch64ISD::DUP)
     MAKE_CASE(AArch64ISD::DUPLANE8)
     MAKE_CASE(AArch64ISD::DUPLANE16)
@@ -4830,6 +4832,90 @@ SDValue AArch64TargetLowering::getPStateSM(SelectionDAG &DAG, SDValue Chain,
                      Mask);
 }
 
+// Lower an SME LDR/STR ZA intrinsic.
+// Case 1: If the vector number (vecnum) is an immediate in range, it gets
+// folded into the instruction:
+//   ldr(%tileslice, %ptr, 11) -> ldr [%tileslice, 11], [%ptr, 11]
+// Case 2: If the vecnum is not an immediate, then it is used to modify the
+// base and tile slice registers:
+//   ldr(%tileslice, %ptr, %vecnum)
+//   ->
+//   %svl = rdsvl
+//   %ptr2 = %ptr + %svl * %vecnum
+//   %tileslice2 = %tileslice + %vecnum
+//   ldr [%tileslice2, 0], [%ptr2, 0]
+// Case 3: If the vecnum is an immediate out of range, then the same is done
+// as in case 2, but the base and slice registers are modified by the greatest
+// multiple of 16 not exceeding the vecnum and the remainder is folded into
+// the instruction. This means that successive loads and stores that are
+// offset from each other can share the same base and slice register updates:
+//   ldr(%tileslice, %ptr, 22)
+//   ldr(%tileslice, %ptr, 23)
+//   ->
+//   %svl = rdsvl
+//   %ptr2 = %ptr + %svl * 16
+//   %tileslice2 = %tileslice + 16
+//   ldr [%tileslice2, 6], [%ptr2, 6]
+//   ldr [%tileslice2, 7], [%ptr2, 7]
+// Case 4: If the vecnum is an add of an immediate, the non-immediate operand
+// is handled as in case 2 and the immediate is folded into the instruction:
+//   ldr(%tileslice, %ptr, %vecnum + 7)
+//   ldr(%tileslice, %ptr, %vecnum + 8)
+//   ->
+//   %svl = rdsvl
+//   %ptr2 = %ptr + %svl * %vecnum
+//   %tileslice2 = %tileslice + %vecnum
+//   ldr [%tileslice2, 7], [%ptr2, 7]
+//   ldr [%tileslice2, 8], [%ptr2, 8]
+// Case 5: An add of an immediate out of range is also handled, using the
+// same remainder logic as in case 3.
+SDValue LowerSMELdrStr(SDValue N, SelectionDAG &DAG, bool IsLoad) {
+  SDLoc DL(N);
+
+  SDValue TileSlice = N->getOperand(2);
+  SDValue Base = N->getOperand(3);
+  SDValue VecNum = N->getOperand(4);
+  int32_t ConstAddend = 0;
+  SDValue VarAddend = VecNum;
+
+  // If the vnum is an add of an immediate, we can fold it into the instruction
+  if (VecNum.getOpcode() == ISD::ADD &&
+      isa<ConstantSDNode>(VecNum.getOperand(1))) {
+    ConstAddend = cast<ConstantSDNode>(VecNum.getOperand(1))->getSExtValue();
+    VarAddend = VecNum.getOperand(0);
+  } else if (auto ImmNode = dyn_cast<ConstantSDNode>(VecNum)) {
+    ConstAddend = ImmNode->getSExtValue();
+    VarAddend = SDValue();
+  }
+
+  int32_t ImmAddend = ConstAddend % 16;
+  if (int32_t C = (ConstAddend - ImmAddend)) {
+    SDValue CVal = DAG.getTargetConstant(C, DL, MVT::i32);
+    VarAddend = VarAddend
+                    ? DAG.getNode(ISD::ADD, DL, MVT::i32, {VarAddend, CVal})
+                    : CVal;
+  }
+
+  if (VarAddend) {
+    // Get the vector length that will be multiplied by vnum
+    auto SVL = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64,
+                           DAG.getConstant(1, DL, MVT::i32));
+
+    // Multiply SVL and vnum then add it to the base
+    SDValue Mul = DAG.getNode(
+        ISD::MUL, DL, MVT::i64,
+        {SVL, DAG.getNode(ISD::SIGN_EXTEND, DL, MVT::i64, VarAddend)});
+    Base = DAG.getNode(ISD::ADD, DL, MVT::i64, {Base, Mul});
+    // Just add vnum to the tileslice
+    TileSlice = DAG.getNode(ISD::ADD, DL, MVT::i32, {TileSlice, VarAddend});
+  }
+
+  return DAG.getNode(IsLoad ? AArch64ISD::SME_ZA_LDR : AArch64ISD::SME_ZA_STR,
+                     DL, MVT::Other,
+                     {/*Chain=*/N.getOperand(0), TileSlice, Base,
+                      DAG.getTargetConstant(ImmAddend, DL, MVT::i32)});
+}
+
 SDValue AArch64TargetLowering::LowerINTRINSIC_VOID(SDValue Op,
                                                    SelectionDAG &DAG) const {
   unsigned IntNo = Op.getConstantOperandVal(1);
@@ -4853,6 +4939,10 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_VOID(SDValue Op,
     return DAG.getNode(AArch64ISD::PREFETCH, DL, MVT::Other, Chain,
                        DAG.getTargetConstant(PrfOp, DL, MVT::i32), Addr);
   }
+  case Intrinsic::aarch64_sme_str:
+  case Intrinsic::aarch64_sme_ldr: {
+    return LowerSMELdrStr(Op, DAG, IntNo == Intrinsic::aarch64_sme_ldr);
+  }
   case Intrinsic::aarch64_sme_za_enable:
     return DAG.getNode(
         AArch64ISD::SMSTART, DL, MVT::Other,
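
To make the case analysis above concrete, here is a small standalone sketch (editorial, not commit code) of the split that LowerSMELdrStr applies to a constant vecnum:

```cpp
#include <cassert>
#include <cstdint>

// Editorial sketch of LowerSMELdrStr's constant handling: a constant vecnum
// is split into a register update (a multiple of 16, applied to the tile
// slice and, scaled by SVL, to the base pointer) plus an immediate that fits
// the instruction's 4-bit offset field.
struct VecNumSplit {
  int32_t RegUpdate; // added to the slice register; base += RegUpdate * svl
  int32_t ImmOffset; // folded into the ldr/str instruction itself
};

static VecNumSplit splitVecNum(int32_t ConstAddend) {
  int32_t ImmOffset = ConstAddend % 16; // in [0, 15] for non-negative vecnum
  return {ConstAddend - ImmOffset, ImmOffset};
}

int main() {
  // An in-range vecnum folds entirely into the instruction (case 1).
  assert(splitVecNum(11).RegUpdate == 0 && splitVecNum(11).ImmOffset == 11);
  // Neighbouring out-of-range vecnums share one register update (16), so
  // consecutive loads/stores can reuse the same base and slice values (case 3).
  assert(splitVecNum(22).RegUpdate == 16 && splitVecNum(22).ImmOffset == 6);
  assert(splitVecNum(23).RegUpdate == 16 && splitVecNum(23).ImmOffset == 7);
  return 0;
}
```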
