Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AArch64][SME] Allow memory operations lowering to custom SME functions. #79263

Merged
merged 11 commits into from
Apr 9, 2024

Conversation

dtemirbulatov
Copy link
Contributor

This change allows to lower memcpy, memset, memmove to custom SME version provided by LibRT.

@llvmbot
Copy link
Collaborator

llvmbot commented Jan 24, 2024

@llvm/pr-subscribers-backend-aarch64

Author: Dinar Temirbulatov (dtemirbulatov)

Changes

This change allows to lower memcpy, memset, memmove to custom SME version provided by LibRT.


Full diff: https://github.com/llvm/llvm-project/pull/79263.diff

5 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+8-2)
  • (modified) llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp (+72)
  • (modified) llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.h (+4)
  • (modified) llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp (+3)
  • (added) llvm/test/CodeGen/AArch64/sme2-mops.ll (+66)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 332fb37655288c..64936b9c86ac1b 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -7659,8 +7659,14 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
   SMEAttrs CalleeAttrs, CallerAttrs(MF.getFunction());
   if (CLI.CB)
     CalleeAttrs = SMEAttrs(*CLI.CB);
-  else if (auto *ES = dyn_cast<ExternalSymbolSDNode>(CLI.Callee))
-    CalleeAttrs = SMEAttrs(ES->getSymbol());
+  else if (auto *ES = dyn_cast<ExternalSymbolSDNode>(CLI.Callee)) {
+    if (StringRef(ES->getSymbol()) == StringRef("__arm_sc_memcpy")) {
+      auto Attrs = AttributeList().addFnAttribute(
+          *DAG.getContext(), "aarch64_pstate_sm_compatible");
+      CalleeAttrs = SMEAttrs(Attrs);
+    } else
+      CalleeAttrs = SMEAttrs(ES->getSymbol());
+  }
 
   auto DescribeCallsite =
       [&](OptimizationRemarkAnalysis &R) -> OptimizationRemarkAnalysis & {
diff --git a/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp b/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp
index 9e43f206efcf78..fff4e2333194e3 100644
--- a/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp
@@ -76,12 +76,74 @@ SDValue AArch64SelectionDAGInfo::EmitMOPS(AArch64ISD::NodeType SDOpcode,
   }
 }
 
+SDValue AArch64SelectionDAGInfo::EmitSpecializedLibcall(
+    SelectionDAG &DAG, const SDLoc &DL, SDValue Chain, SDValue Dst, SDValue Src,
+    SDValue Size, RTLIB::Libcall LC) const {
+  const AArch64Subtarget &STI =
+      DAG.getMachineFunction().getSubtarget<AArch64Subtarget>();
+  const AArch64TargetLowering *TLI = STI.getTargetLowering();
+  TargetLowering::ArgListTy Args;
+  TargetLowering::ArgListEntry Entry;
+  Entry.Ty = DAG.getDataLayout().getIntPtrType(*DAG.getContext());
+  Entry.Node = Dst;
+  Args.push_back(Entry);
+
+  enum { SME_MEMCPY = 0, SME_MEMMOVE, SME_MEMSET } SMELibcall;
+  switch (LC) {
+  case RTLIB::MEMCPY:
+    SMELibcall = SME_MEMCPY;
+    Entry.Node = Src;
+    Args.push_back(Entry);
+    break;
+  case RTLIB::MEMMOVE:
+    SMELibcall = SME_MEMMOVE;
+    Entry.Node = Src;
+    Args.push_back(Entry);
+    break;
+  case RTLIB::MEMSET:
+    SMELibcall = SME_MEMSET;
+    if (Src.getValueType().bitsGT(MVT::i32))
+      Src = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Src);
+    else if (Src.getValueType().bitsLT(MVT::i32))
+      Src = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i32, Src);
+    Entry.Node = Src;
+    Entry.Ty = Type::getInt32Ty(*DAG.getContext());
+    Entry.IsSExt = false;
+    Args.push_back(Entry);
+    break;
+  default:
+    return SDValue();
+  }
+  Entry.Node = Size;
+  Args.push_back(Entry);
+  char const *FunctionNames[3] = {"__arm_sc_memcpy", "__arm_sc_memmove",
+                                  "__arm_sc_memset"};
+
+  TargetLowering::CallLoweringInfo CLI(DAG);
+  CLI.setDebugLoc(DL)
+      .setChain(Chain)
+      .setLibCallee(
+          TLI->getLibcallCallingConv(RTLIB::MEMCPY),
+          Type::getVoidTy(*DAG.getContext()),
+          DAG.getExternalSymbol(FunctionNames[SMELibcall],
+                                TLI->getPointerTy(DAG.getDataLayout())),
+          std::move(Args))
+      .setDiscardResult();
+  std::pair<SDValue, SDValue> CallResult = TLI->LowerCallTo(CLI);
+  return CallResult.second;
+}
+
 SDValue AArch64SelectionDAGInfo::EmitTargetCodeForMemcpy(
     SelectionDAG &DAG, const SDLoc &DL, SDValue Chain, SDValue Dst, SDValue Src,
     SDValue Size, Align Alignment, bool isVolatile, bool AlwaysInline,
     MachinePointerInfo DstPtrInfo, MachinePointerInfo SrcPtrInfo) const {
   const AArch64Subtarget &STI =
       DAG.getMachineFunction().getSubtarget<AArch64Subtarget>();
+
+  SMEAttrs Attrs(DAG.getMachineFunction().getFunction());
+  if (Attrs.hasStreamingBody() || Attrs.hasStreamingCompatibleInterface())
+    return EmitSpecializedLibcall(DAG, DL, Chain, Dst, Src, Size,
+                                  RTLIB::MEMCPY);
   if (STI.hasMOPS())
     return EmitMOPS(AArch64ISD::MOPS_MEMCOPY, DAG, DL, Chain, Dst, Src, Size,
                     Alignment, isVolatile, DstPtrInfo, SrcPtrInfo);
@@ -95,6 +157,11 @@ SDValue AArch64SelectionDAGInfo::EmitTargetCodeForMemset(
   const AArch64Subtarget &STI =
       DAG.getMachineFunction().getSubtarget<AArch64Subtarget>();
 
+  SMEAttrs Attrs(DAG.getMachineFunction().getFunction());
+  if (Attrs.hasStreamingBody() || Attrs.hasStreamingCompatibleInterface())
+    return EmitSpecializedLibcall(DAG, dl, Chain, Dst, Src, Size,
+                                  RTLIB::MEMSET);
+
   if (STI.hasMOPS()) {
     return EmitMOPS(AArch64ISD::MOPS_MEMSET, DAG, dl, Chain, Dst, Src, Size,
                     Alignment, isVolatile, DstPtrInfo, MachinePointerInfo{});
@@ -108,6 +175,11 @@ SDValue AArch64SelectionDAGInfo::EmitTargetCodeForMemmove(
     MachinePointerInfo DstPtrInfo, MachinePointerInfo SrcPtrInfo) const {
   const AArch64Subtarget &STI =
       DAG.getMachineFunction().getSubtarget<AArch64Subtarget>();
+
+  SMEAttrs Attrs(DAG.getMachineFunction().getFunction());
+  if (Attrs.hasStreamingBody() || Attrs.hasStreamingCompatibleInterface())
+    return EmitSpecializedLibcall(DAG, dl, Chain, Dst, Src, Size,
+                                  RTLIB::MEMMOVE);
   if (STI.hasMOPS()) {
     return EmitMOPS(AArch64ISD::MOPS_MEMMOVE, DAG, dl, Chain, Dst, Src, Size,
                     Alignment, isVolatile, DstPtrInfo, SrcPtrInfo);
diff --git a/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.h b/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.h
index 73f93724d6fc73..9c55c21f3c3202 100644
--- a/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.h
@@ -47,6 +47,10 @@ class AArch64SelectionDAGInfo : public SelectionDAGTargetInfo {
                                   SDValue Chain, SDValue Op1, SDValue Op2,
                                   MachinePointerInfo DstPtrInfo,
                                   bool ZeroData) const override;
+
+  SDValue EmitSpecializedLibcall(SelectionDAG &DAG, const SDLoc &DL,
+                                 SDValue Chain, SDValue Dst, SDValue Src,
+                                 SDValue Size, RTLIB::Libcall LC) const;
 };
 }
 
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
index 3ee54e5df0a13d..5080e4a0b4f9a2 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
@@ -51,6 +51,9 @@ SMEAttrs::SMEAttrs(StringRef FuncName) : Bitmask(0) {
   if (FuncName == "__arm_tpidr2_restore")
     Bitmask |= (SMEAttrs::SM_Compatible | SMEAttrs::ZA_Shared |
                 SMEAttrs::SME_ABI_Routine);
+  if (FuncName == "__arm_sc_memcpy" || FuncName == "__arm_sc_memset" ||
+      FuncName == "__arm_sc_memmove")
+    Bitmask |= SMEAttrs::SM_Compatible;
 }
 
 SMEAttrs::SMEAttrs(const AttributeList &Attrs) {
diff --git a/llvm/test/CodeGen/AArch64/sme2-mops.ll b/llvm/test/CodeGen/AArch64/sme2-mops.ll
new file mode 100644
index 00000000000000..0599bc61a52f73
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sme2-mops.ll
@@ -0,0 +1,66 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme2 -verify-machineinstrs < %s | FileCheck %s
+
+@dst = global [512 x i8] zeroinitializer, align 1
+@src = global [512 x i8] zeroinitializer, align 1
+
+define void @sc_memcpy(i64 noundef %n) "aarch64_pstate_sm_compatible" {
+; CHECK-LABEL: sc_memcpy:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    .cfi_def_cfa_offset 16
+; CHECK-NEXT:    .cfi_offset w30, -16
+; CHECK-NEXT:    mov x2, x0
+; CHECK-NEXT:    adrp x0, :got:dst
+; CHECK-NEXT:    adrp x1, :got:src
+; CHECK-NEXT:    ldr x0, [x0, :got_lo12:dst]
+; CHECK-NEXT:    ldr x1, [x1, :got_lo12:src]
+; CHECK-NEXT:    bl __arm_sc_memcpy
+; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    ret
+entry:
+  tail call void @llvm.memcpy.p0.p0.i64(ptr align 1 @dst, ptr nonnull align 1 @src, i64 %n, i1 false)
+  ret void
+}
+
+define void @sc_memset(i64 noundef %n) "aarch64_pstate_sm_compatible" {
+; CHECK-LABEL: sc_memset:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    .cfi_def_cfa_offset 16
+; CHECK-NEXT:    .cfi_offset w30, -16
+; CHECK-NEXT:    mov x2, x0
+; CHECK-NEXT:    adrp x0, :got:dst
+; CHECK-NEXT:    mov w1, #2 // =0x2
+; CHECK-NEXT:    ldr x0, [x0, :got_lo12:dst]
+; CHECK-NEXT:    // kill: def $w2 killed $w2 killed $x2
+; CHECK-NEXT:    bl __arm_sc_memset
+; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    ret
+entry:
+  tail call void @llvm.memset.p0.i64(ptr align 1 @dst, i8 2, i64 %n, i1 false)
+  ret void
+}
+
+define void @sc_memmove(i64 noundef %n) "aarch64_pstate_sm_compatible" {
+; CHECK-LABEL: sc_memmove:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    .cfi_def_cfa_offset 16
+; CHECK-NEXT:    .cfi_offset w30, -16
+; CHECK-NEXT:    mov x2, x0
+; CHECK-NEXT:    adrp x0, :got:dst
+; CHECK-NEXT:    adrp x1, :got:src
+; CHECK-NEXT:    ldr x0, [x0, :got_lo12:dst]
+; CHECK-NEXT:    ldr x1, [x1, :got_lo12:src]
+; CHECK-NEXT:    bl __arm_sc_memmove
+; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    ret
+entry:
+  tail call void @llvm.memmove.p0.p0.i64(ptr align 1 @dst, ptr nonnull align 1 @src, i64 %n, i1 false)
+  ret void
+}
+
+declare void @llvm.memset.p0.i64(ptr nocapture writeonly, i8, i64, i1 immarg)
+declare void @llvm.memcpy.p0.p0.i64(ptr nocapture writeonly, ptr nocapture readonly, i64, i1 immarg)
+declare void @llvm.memmove.p0.p0.i64(ptr nocapture writeonly, ptr nocapture readonly, i64, i1 immarg)

Copy link
Collaborator

@SamTebbs33 SamTebbs33 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks good to me, but perhaps wait before merging to give someone more knowledgeable about specialised library calls a chance to review.

@dtemirbulatov dtemirbulatov force-pushed the sme-memops-lower branch 2 times, most recently from a863f74 to 2ef16b0 Compare January 25, 2024 22:04
Entry.Node = Dst;
Args.push_back(Entry);

enum { SME_MEMCPY = 0, SME_MEMMOVE, SME_MEMSET } SMELibcall;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I find this indirection using the enum and it indexing into an array very confusing, why wouldn't you just write:

case RTLIB::MEMCPY:
  Symbol = DAG.getExternalSymbol("__arm_sc_memcpy", PtrVT);

?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

else if (auto *ES = dyn_cast<ExternalSymbolSDNode>(CLI.Callee))
CalleeAttrs = SMEAttrs(ES->getSymbol());
else if (auto *ES = dyn_cast<ExternalSymbolSDNode>(CLI.Callee)) {
if (StringRef(ES->getSymbol()) == StringRef("__arm_sc_memcpy")) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is missing __arm_sc_memmove and __arm_sc_memset ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

if (StringRef(ES->getSymbol()) == StringRef("__arm_sc_memcpy")) {
auto Attrs = AttributeList().addFnAttribute(
*DAG.getContext(), "aarch64_pstate_sm_compatible");
CalleeAttrs = SMEAttrs(Attrs);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this change actually needed when you also make this change in SMEAttrs::SMEAttrs(StringRef FuncName)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Comment on lines 105 to 108
if (Src.getValueType().bitsGT(MVT::i32))
Src = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Src);
else if (Src.getValueType().bitsLT(MVT::i32))
Src = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i32, Src);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DAG.getZExtOrTrunc ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

case RTLIB::MEMSET:
SMELibcall = SME_MEMSET;
if (Src.getValueType().bitsGT(MVT::i32))
Src = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Src);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure when this would be needed? I tried to write a test case which would require truncating Src, but found that I could only pass a source type of i8 to the memset intrinsic.
I might have missed something though, if there is a case where we would need to truncate Src please can you add a test for it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I have a testcase where I could see this oprtation is required:

void foo(size_t n)  __arm_streaming_compatible
{
  memset(dst, 2, n);
}

IR:

define dso_local void @foo(i64 noundef %n) local_unnamed_addr #0 {
entry:
  tail call void @llvm.memset.p0.i64(ptr nonnull align 1 @dst, i8 2, i64 %n, i1 false)
  ret void
}

(gdb) p Src->dump()
t4: i8 = Constant<2>
$1 = void
(gdb) n
107	    Entry.Node = Src;
(gdb) p Src->dump()
t6: i32 = Constant<2>
$2 = void

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for sharing this test @dtemirbulatov.
I think in your example, Src still needs to be extended? The case I was trying to write was one where Src requires truncating, for example from i64 to i32.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kmclaughlin-arm , I was not able to produce truncating test, somehow. Is it ok to add zero-extend test for this?

llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp Outdated Show resolved Hide resolved
TargetLowering::CallLoweringInfo CLI(DAG);
CLI.setDebugLoc(DL)
.setChain(Chain)
.setLibCallee(TLI->getLibcallCallingConv(RTLIB::MEMCPY),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be:

Suggested change
.setLibCallee(TLI->getLibcallCallingConv(RTLIB::MEMCPY),
.setLibCallee(TLI->getLibcallCallingConv(LC),

?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

.setChain(Chain)
.setLibCallee(TLI->getLibcallCallingConv(RTLIB::MEMCPY),
Type::getVoidTy(*DAG.getContext()), Symbol, std::move(Args))
.setDiscardResult();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the reasoning behind adding setDiscardResult() ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, thanks, done.

SDValue AArch64SelectionDAGInfo::EmitTargetCodeForMemcpy(
SelectionDAG &DAG, const SDLoc &DL, SDValue Chain, SDValue Dst, SDValue Src,
SDValue Size, Align Alignment, bool isVolatile, bool AlwaysInline,
MachinePointerInfo DstPtrInfo, MachinePointerInfo SrcPtrInfo) const {
const AArch64Subtarget &STI =
DAG.getMachineFunction().getSubtarget<AArch64Subtarget>();

SMEAttrs Attrs(DAG.getMachineFunction().getFunction());
if (Attrs.hasStreamingBody() || Attrs.hasStreamingCompatibleInterface())
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a test for "aarch64_pstate_sm_enabled" and also a test for "aarch64_pstate_sm_body" ? I don't think you're currently supporting the first case.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

@@ -0,0 +1,66 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme2 -verify-machineinstrs < %s | FileCheck %s
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should have a way to disable this behaviour, in case there are any issues with this approach. Can you add an option for this, and change this test to have two RUN lines (one with it enabled, one with it disabled)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

@@ -51,6 +51,9 @@ SMEAttrs::SMEAttrs(StringRef FuncName) : Bitmask(0) {
if (FuncName == "__arm_tpidr2_restore")
Bitmask |= (SMEAttrs::SM_Compatible | SMEAttrs::ZA_Shared |
SMEAttrs::SME_ABI_Routine);
if (FuncName == "__arm_sc_memcpy" || FuncName == "__arm_sc_memset" ||
FuncName == "__arm_sc_memmove")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

__arm_sc_memchr is also streaming-compatible.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

@@ -15,6 +15,12 @@ using namespace llvm;

#define DEBUG_TYPE "aarch64-selectiondag-info"

static cl::opt<bool>
EnableSMEMops("aarch64-enable-sme-mops", cl::Hidden,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
EnableSMEMops("aarch64-enable-sme-mops", cl::Hidden,
LowerToSMERoutines("aarch64-lower-to-sme-routines", cl::Hidden,

Better to rename this option. 'mops' has an overloaded meaning because of FEAT_MOPS.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Comment on lines 161 to 162
if (Attrs.hasStreamingBody() || Attrs.hasStreamingCompatibleInterface() ||
Attrs.hasStreamingInterface())
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (Attrs.hasStreamingBody() || Attrs.hasStreamingCompatibleInterface() ||
Attrs.hasStreamingInterface())
if (!Attrs.hasNonStreamingInterfaceAndBody())

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Comment on lines 418 to 427
; CHECK-NEXT: .cfi_def_cfa_offset 80
; CHECK-NEXT: .cfi_offset w30, -16
; CHECK-NEXT: .cfi_offset b8, -24
; CHECK-NEXT: .cfi_offset b9, -32
; CHECK-NEXT: .cfi_offset b10, -40
; CHECK-NEXT: .cfi_offset b11, -48
; CHECK-NEXT: .cfi_offset b12, -56
; CHECK-NEXT: .cfi_offset b13, -64
; CHECK-NEXT: .cfi_offset b14, -72
; CHECK-NEXT: .cfi_offset b15, -80
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: can you add nounwind to the tests, to reduce the number of CHECK lines?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

@dst = global [512 x i8] zeroinitializer, align 1
@src = global [512 x i8] zeroinitializer, align 1

define void @sc_memcpy(i64 noundef %n) "aarch64_pstate_sm_compatible" {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:
rather than having {streaming, streaming-compatible, locally-streaming} x {memcpy, memset, memmove} combinations, can you do {streaming} x {memcpy, memset, memmove} + {streaming-compatible, locally-streaming} x {memcpy} ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

TargetLowering::ArgListTy Args;
TargetLowering::ArgListEntry Entry;
SDValue Symbol;
Entry.Ty = DAG.getDataLayout().getIntPtrType(*DAG.getContext());
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be: PointerType::getUnqual(*DAG.getContext()) instead?
IntPtrTy is different from an opaque pointer type.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

DAG.getMachineFunction().getSubtarget<AArch64Subtarget>();
const AArch64TargetLowering *TLI = STI.getTargetLowering();
TargetLowering::ArgListTy Args;
TargetLowering::ArgListEntry Entry;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: rename to:

Suggested change
TargetLowering::ArgListEntry Entry;
TargetLowering::ArgListEntry DstEntry;

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

default:
return SDValue();
}
Entry.Node = Size;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: use a new variable, e.g.

Suggested change
Entry.Node = Size;
TargetLowering::ArgListEntry SizeEntry;
SizeEntry.Node = Size;

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

@@ -0,0 +1,552 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please rename this file, because mops has an overloaded meaning. What about streaming-compatible-memory-ops.ll ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

@@ -0,0 +1,289 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme2 -verify-machineinstrs < %s | FileCheck %s -check-prefixes=CHECK
; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme2 -verify-machineinstrs -aarch64-lower-to-sme-routines=false < %s | FileCheck %s -check-prefixes=NO_SME_MOPS
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

Suggested change
; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme2 -verify-machineinstrs -aarch64-lower-to-sme-routines=false < %s | FileCheck %s -check-prefixes=NO_SME_MOPS
; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme2 -verify-machineinstrs -aarch64-lower-to-sme-routines=false < %s | FileCheck %s -check-prefixes=CHECK-NO-SME-ROUTINES

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.


TargetLowering::CallLoweringInfo CLI(DAG);
CLI.setDebugLoc(DL).setChain(Chain).setLibCallee(
TLI->getLibcallCallingConv(LC), Type::getVoidTy(*DAG.getContext()),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you missed this suggestion.

TLI->getLibcallCallingConv(LC), Type::getVoidTy(*DAG.getContext()),
Symbol, std::move(Args));
std::pair<SDValue, SDValue> CallResult = TLI->LowerCallTo(CLI);
return CallResult.second;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you missed this suggestion.

EVT Ty = TLI->getPointerTy(DAG.getDataLayout());
PointerType *RetTy;

if (!LowerToSMERoutines)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please move this condition to where this function is called.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

}
TargetLowering::ArgListEntry SizeEntry;
SizeEntry.Node = Size;
SizeEntry.Ty = PointerType::getUnqual(*DAG.getContext());
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this correct? Should 'size' have a pointer type? I would expect it to be an integer type of the same width as a pointer type.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

DstEntry.Ty = PointerType::getUnqual(*DAG.getContext());
DstEntry.Node = Dst;
Args.push_back(DstEntry);
EVT Ty = TLI->getPointerTy(DAG.getDataLayout());
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

Suggested change
EVT Ty = TLI->getPointerTy(DAG.getDataLayout());
EVT PointerVT = TLI->getPointerTy(DAG.getDataLayout());

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

const AArch64TargetLowering *TLI = STI.getTargetLowering();
TargetLowering::ArgListTy Args;
TargetLowering::ArgListEntry DstEntry;
SDValue Symbol;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: move closer to switch.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

@@ -76,15 +82,85 @@ SDValue AArch64SelectionDAGInfo::EmitMOPS(AArch64ISD::NodeType SDOpcode,
}
}

SDValue AArch64SelectionDAGInfo::EmitSpecializedLibcall(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see how you changed the handling of different return types, but we still have the code for the 'size' that is uniform for the three case-statements at the end, but which isn't generic for other libcalls (if we ever add them). It's probably better to just rename this function to what it's meant to be: EmitStreamingCompatibleMemLibCall and then fix RetTy to be PointerType::getUnqual(), rather than setting it in each case statement.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Comment on lines 145 to 146
std::pair<SDValue, SDValue> CallResult = TLI->LowerCallTo(CLI);
return (isa<PointerType>(RetTy) ? CallResult.second : CallResult.first);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you name this function appropriately (see my suggestion above), this can be reverted back to be:

Suggested change
std::pair<SDValue, SDValue> CallResult = TLI->LowerCallTo(CLI);
return (isa<PointerType>(RetTy) ? CallResult.second : CallResult.first);
return TLI->LowerCallTo(CLI).second;

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Copy link
Collaborator

@sdesmalen-arm sdesmalen-arm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@dtemirbulatov dtemirbulatov merged commit 528943f into llvm:main Apr 9, 2024
4 checks passed
@dtemirbulatov dtemirbulatov deleted the sme-memops-lower branch April 10, 2024 07:51
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

6 participants