Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AArch64][SVE] Handle consecutive Predicates in CC_AArch64_Custom_Block #90122

Merged
merged 5 commits into from
May 24, 2024

Conversation

zhaoshiz
Copy link
Contributor

For 2d masks as function arguments, even in [1 x <vscale x 4 x i1>] type, they're flagged as InConsecutiveRegs. This fix checks for mask types and allocate them to P registers.

@llvmbot
Copy link
Collaborator

llvmbot commented Apr 25, 2024

@llvm/pr-subscribers-backend-aarch64

Author: Zhaoshi Zheng (zhaoshiz)

Changes

For 2d masks as function arguments, even in [1 x <vscale x 4 x i1>] type, they're flagged as InConsecutiveRegs. This fix checks for mask types and allocate them to P registers.


Full diff: https://github.com/llvm/llvm-project/pull/90122.diff

2 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64CallingConvention.cpp (+11-3)
  • (modified) llvm/test/CodeGen/AArch64/sve-calling-convention.ll (+22)
diff --git a/llvm/lib/Target/AArch64/AArch64CallingConvention.cpp b/llvm/lib/Target/AArch64/AArch64CallingConvention.cpp
index bfcafc6442d241..9a2838992eb02d 100644
--- a/llvm/lib/Target/AArch64/AArch64CallingConvention.cpp
+++ b/llvm/lib/Target/AArch64/AArch64CallingConvention.cpp
@@ -38,6 +38,8 @@ static const MCPhysReg QRegList[] = {AArch64::Q0, AArch64::Q1, AArch64::Q2,
 static const MCPhysReg ZRegList[] = {AArch64::Z0, AArch64::Z1, AArch64::Z2,
                                      AArch64::Z3, AArch64::Z4, AArch64::Z5,
                                      AArch64::Z6, AArch64::Z7};
+static const MCPhysReg PRegList[] = {AArch64::P0, AArch64::P1, AArch64::P2,
+                                     AArch64::P3};
 
 static bool finishStackBlock(SmallVectorImpl<CCValAssign> &PendingMembers,
                              MVT LocVT, ISD::ArgFlagsTy &ArgFlags,
@@ -140,9 +142,15 @@ static bool CC_AArch64_Custom_Block(unsigned &ValNo, MVT &ValVT, MVT &LocVT,
     RegList = DRegList;
   else if (LocVT.SimpleTy == MVT::f128 || LocVT.is128BitVector())
     RegList = QRegList;
-  else if (LocVT.isScalableVector())
-    RegList = ZRegList;
-  else {
+  else if (LocVT.isScalableVector()) {
+    // Scalable masks should be pass by Predicate registers.
+    if (LocVT == MVT::nxv1i1 || LocVT == MVT::nxv2i1 || LocVT == MVT::nxv4i1 ||
+        LocVT == MVT::nxv8i1 || LocVT == MVT::nxv16i1 ||
+        LocVT == MVT::aarch64svcount)
+      RegList = PRegList;
+    else
+      RegList = ZRegList;
+  } else {
     // Not an array we want to split up after all.
     return false;
   }
diff --git a/llvm/test/CodeGen/AArch64/sve-calling-convention.ll b/llvm/test/CodeGen/AArch64/sve-calling-convention.ll
index 0a45244f12be54..a0eee24275f1e8 100644
--- a/llvm/test/CodeGen/AArch64/sve-calling-convention.ll
+++ b/llvm/test/CodeGen/AArch64/sve-calling-convention.ll
@@ -128,6 +128,14 @@ define <vscale x 4 x i1> @sve_signature_pred(<vscale x 4 x i1> %arg1, <vscale x
   ret <vscale x 4 x i1> %arg2
 }
 
+; CHECK-LABEL: name: sve_signature_pred_2d
+; CHECK: [[RES:%[0-9]+]]:ppr = COPY $p1
+; CHECK: $p0 = COPY [[RES]]
+; CHECK: RET_ReallyLR implicit $p0
+define [1 x <vscale x 4 x i1>] @sve_signature_pred_2d([1 x <vscale x 4 x i1>] %arg1, [1 x <vscale x 4 x i1>] %arg2) nounwind {
+  ret [1 x <vscale x 4 x i1>] %arg2
+}
+
 ; CHECK-LABEL: name: sve_signature_vec_caller
 ; CHECK-DAG: [[ARG2:%[0-9]+]]:zpr = COPY $z1
 ; CHECK-DAG: [[ARG1:%[0-9]+]]:zpr = COPY $z0
@@ -156,6 +164,20 @@ define <vscale x 4 x i1> @sve_signature_pred_caller(<vscale x 4 x i1> %arg1, <vs
   ret <vscale x 4 x i1> %res
 }
 
+; CHECK-LABEL: name: sve_signature_pred_2d_caller
+; CHECK-DAG: [[ARG2:%[0-9]+]]:ppr = COPY $p1
+; CHECK-DAG: [[ARG1:%[0-9]+]]:ppr = COPY $p0
+; CHECK-DAG: $p0 = COPY [[ARG2]]
+; CHECK-DAG: $p1 = COPY [[ARG1]]
+; CHECK-NEXT: BL @sve_signature_pred_2d, csr_aarch64_sve_aapcs
+; CHECK: [[RES:%[0-9]+]]:ppr = COPY $p0
+; CHECK: $p0 = COPY [[RES]]
+; CHECK: RET_ReallyLR implicit $p0
+define [1 x <vscale x 4 x i1>] @sve_signature_pred_2d_caller([1 x <vscale x 4 x i1>] %arg1, [1 x <vscale x 4 x i1>] %arg2) nounwind {
+  %res = call [1 x <vscale x 4 x i1>] @sve_signature_pred_2d([1 x <vscale x 4 x i1>] %arg2, [1 x <vscale x 4 x i1>] %arg1)
+  ret [1 x <vscale x 4 x i1>] %res
+}
+
 ; Test that functions returning or taking SVE arguments use the correct
 ; callee-saved set when using the default C calling convention (as opposed
 ; to aarch64_sve_vector_pcs)

Copy link
Collaborator

@c-rhodes c-rhodes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the patch.

I was wondering where this code kicks in for Z regs as I couldn't see any tests with arrays, but it's used for tuples of scalable vectors like in @foo1 in llvm/test/CodeGen/AArch64/sve-calling-convention-mixed.ll.

LGTM cheers, but I'll add another reviewer since I don't work on the backend much at the moment. So please wait before landing.

@zhaoshiz
Copy link
Contributor Author

Thanks for the patch.

I was wondering where this code kicks in for Z regs as I couldn't see any tests with arrays, but it's used for tuples of scalable vectors like in @foo1 in llvm/test/CodeGen/AArch64/sve-calling-convention-mixed.ll.

LGTM cheers, but I'll add another reviewer since I don't work on the backend much at the moment. So please wait before landing.

Thanks for reviewing. I'm working on scalable vectorizing linalg.reduce in MLIR and run into this case when I try to generate SVE instructions with mask types like <2x[4]xi1>. Without this fix, a mask of such type is assigned to Z registers and the backend crashes later.

Copy link
Collaborator

@sdesmalen-arm sdesmalen-arm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On lines 58-68, there is some code that suggests the calling convention has additional requirements that need to be taken into account:

    // we cannot allocate enough registers for the tuple we should still leave
    // any remaining registers unallocated. However, when we call the
    // CCAssignFn again we want it to behave as if all remaining registers are
    // allocated. This will force the code to pass the tuple indirectly in
    // accordance with the PCS.
    bool RegsAllocated[8];
    for (int I = 0; I < 8; I++) {
      RegsAllocated[I] = State.isAllocated(ZRegList[I]);
      State.AllocateReg(ZRegList[I]);
    }

This applies to both Z registers and P registers. I believe this corresponds to
AAPCS (parameter passing) Step C.7 and C.8.

Could you also add some tests for this?

For 2d masks as function arguments, even in [1 x <vscale x 4 x i1>] type,
they're flagged as InConsecutiveRegs. This fix checks for mask types and
allocate them to P registers.
Per AAPCS64, when P0~P4 are exhausted or not able to hold a scalable predicate
argument, the argument is allocated to the stack and a pointer is passed
to the callee.

For consecutive predicates in types like [M x <vscale x N x i1>], we
should handle them in the same way as Z registers, as shown by:
https://reviews.llvm.org/D71216

Reference:
https://github.com/ARM-software/abi-aa/blob/main/aapcs64/aapcs64.rst#parameter-passing
@zhaoshiz zhaoshiz reopened this May 10, 2024
…tack

Adding a test cast where smaller consecutive predicates (arg1 and arg3 below)
are passed by P0~P3 and larger one (arg2 below) is passed through the stack:

[2 x <vscale x 16 x i1>] callee (
[2 x <vscale x 16 x i1>] arg1,
[4 x <vscale x 16 x i1>] arg2,
[2 x <vscale x 16 x i1>] arg3)
@zhaoshiz
Copy link
Contributor Author

On lines 58-68, there is some code that suggests the calling convention has additional requirements that need to be taken into account:

    // we cannot allocate enough registers for the tuple we should still leave
    // any remaining registers unallocated. However, when we call the
    // CCAssignFn again we want it to behave as if all remaining registers are
    // allocated. This will force the code to pass the tuple indirectly in
    // accordance with the PCS.
    bool RegsAllocated[8];
    for (int I = 0; I < 8; I++) {
      RegsAllocated[I] = State.isAllocated(ZRegList[I]);
      State.AllocateReg(ZRegList[I]);
    }

This applies to both Z registers and P registers. I believe this corresponds to AAPCS (parameter passing) Step C.7 and C.8.

Could you also add some tests for this?

I've added similar code to handle P registers in the case a predicate argument is passed indirectly through the stack.
But I ran into some issues when changing values of M and N in predicate types [M x ]:

  1. when M=1, the argument has both flags InConsectiveRegs and InConsectutiveRegsLast set, which will trigger assertions: for callee or caller at:
    https://github.com/llvm/llvm-project/blob/3dcd604eb1d6612fda667793dbb52c5dfaa5fc4f/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp#L7217
    https://github.com/llvm/llvm-project/blob/3dcd604eb1d6612fda667793dbb52c5dfaa5fc4f/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp#L8214
    Is checking (ValueVTs.size() > 1) before return true from AArch64TargetLowering::functionArgumentNeedsConsecutiveRegisters(), a good fixe or more needs to be done to handle M=1?
    https://github.com/llvm/llvm-project/blob/3dcd604eb1d6612fda667793dbb52c5dfaa5fc4f/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp#L26132

  2. when N is not 16, DAGSelection fails with cannot select errors due to nxv8i1, nxv4i1, and nxv1i1 patterns of STR_PXI/LDR_PXI are remove from AArch64SVEInstrInfo.td:

    defm Pat_Store_P16 : unpred_store_predicate<nxv16i1, STR_PXI>;
    defm Pat_Load_P16 : unpred_load_predicate<nxv16i1, LDR_PXI>;
    in https://reviews.llvm.org/D88994. @efriedma-quic, can you comment on complications of adding them back?

diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 93b02b2d692e..96ad0717d483 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -2120,9 +2120,6 @@ let Predicates = [HasSVEorStreamingSVE] in {
}

defm Pat_Store_P16 : unpred_store_predicate<nxv16i1, STR_PXI>;

  • defm Pat_Store_P8 : unpred_store_predicate<nxv8i1, STR_PXI>;

  • defm Pat_Store_P4 : unpred_store_predicate<nxv4i1, STR_PXI>;

  • defm Pat_Store_P2 : unpred_store_predicate<nxv2i1, STR_PXI>;

    multiclass unpred_load_predicate<ValueType Ty, Instruction Load> {
    def _fi : Pat<(Ty (load (am_sve_fi GPR64sp:$base, simm9:$offset))),
    @@ -2133,9 +2130,6 @@ let Predicates = [HasSVEorStreamingSVE] in {
    }

    defm Pat_Load_P16 : unpred_load_predicate<nxv16i1, LDR_PXI>;

  • defm Pat_Load_P8 : unpred_load_predicate<nxv8i1, LDR_PXI>;

  • defm Pat_Load_P4 : unpred_load_predicate<nxv4i1, LDR_PXI>;

  • defm Pat_Load_P2 : unpred_load_predicate<nxv2i1, LDR_PXI>;

@efriedma-quic
Copy link
Collaborator

efriedma-quic commented May 11, 2024

STR_PXI/LDR_PXI don't have the same semantics as an IR "load"/"store" instruction for types other than nxv16i1, so the patterns don't make sense.

If you want to store a SVE predicate, you need cast to nxv16i1 (AArch64ISD::REINTERPRET_CAST) before you store.

…calable vector

Function arguemnts in types of 1-element arrays of scalable vector,
e.g. [1 x <vscale x 16 x i1>], are flagged as both InConsecutiveRegs and
InConsecutiveRegsLast. This triggers asserstions when lowering the argument
through the stack. Remove those assertions since existing code can handle
1-element arrays types as well.
@zhaoshiz
Copy link
Contributor Author

@sdesmalen-arm I've added test cases where scalable predicate args are passed by reference through the stack. I also think remove the assertions triggered by [1 x <vscale x 16 x i1>] args in LowerFormalArguments() and LowerCall().
@efriedma-quic thanks for the pointer, will address load/store of non-nvx16i1 types in the furture.

@zhaoshiz
Copy link
Contributor Author

gentle ping...

Copy link
Collaborator

@efriedma-quic efriedma-quic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a test for <vscale x 32 x i1>, to make sure that works. (I think it should, but just to verify.)

Otherwise LGTM

…2i1 types

Check that a <vscale x 32 x i1> predicate argument is assgined to two
P registers and passed through the stacks if not enough P registers are
available.

Also renamed functions that have array of scalable predicate arguments with
explicit argument types, e.g.: 2xv16i1.
@zhaoshiz
Copy link
Contributor Author

If no objection or more test cases needed I'd like to merge this in next 24 hours.

@zhaoshiz zhaoshiz merged commit f492471 into llvm:main May 24, 2024
7 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

5 participants