Recommit [RISCV] RISCV vector calling convention (2/2) (#79096) #87736

Merged: 1 commit into llvm:main on Apr 16, 2024

Conversation

@4vtomat (Member) commented Apr 5, 2024

Bug fix: handle RVV return types in the calling convention correctly.
Return values are handled in the same way as function arguments.
One thing to note: if a type can be broken down into homogeneous
vector types, e.g. {<vscale x 4 x i32>, {<vscale x 4 x i32>, <vscale x 4 x i32>}},
it is considered a vector tuple type and needs to be handled by the
tuple type rule.
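
A minimal sketch of what that means (hypothetical function name; the register placement is inferred from the caller_nested test added below, where the three fields arrive in v8, v10, and v12):

  ; Hypothetical illustration of the tuple rule: this nested struct is
  ; homogeneous in <vscale x 4 x i32> (an LMUL=2 register group), so it is
  ; treated as a 3-element vector tuple and returned in consecutive
  ; register groups:
  ;   field 0 -> v8-v9, field 1 -> v10-v11, field 2 -> v12-v13
  declare {<vscale x 4 x i32>, {<vscale x 4 x i32>, <vscale x 4 x i32>}} @returns_nested()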

@llvmbot (Collaborator) commented Apr 5, 2024

@llvm/pr-subscribers-backend-risc-v

Author: Brandon Wu (4vtomat)

Changes

Return values are handled in the same way as function arguments.
One thing to note: if a type can be broken down into homogeneous
vector types, e.g. {<vscale x 4 x i32>, {<vscale x 4 x i32>, <vscale x 4 x i32>}},
it is considered a vector tuple type and needs to be handled by the
tuple type rule.


Full diff: https://github.com/llvm/llvm-project/pull/87736.diff

5 Files Affected:

  • (modified) llvm/lib/CodeGen/TargetLoweringBase.cpp (+10-2)
  • (modified) llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp (+6-4)
  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+89-9)
  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.h (+6-3)
  • (modified) llvm/test/CodeGen/RISCV/rvv/calling-conv.ll (+116)
diff --git a/llvm/lib/CodeGen/TargetLoweringBase.cpp b/llvm/lib/CodeGen/TargetLoweringBase.cpp
index f64ded4f2cf965..6e7b67ded23c84 100644
--- a/llvm/lib/CodeGen/TargetLoweringBase.cpp
+++ b/llvm/lib/CodeGen/TargetLoweringBase.cpp
@@ -1809,8 +1809,16 @@ void llvm::GetReturnInfo(CallingConv::ID CC, Type *ReturnType,
     else if (attr.hasRetAttr(Attribute::ZExt))
       Flags.setZExt();
 
-    for (unsigned i = 0; i < NumParts; ++i)
-      Outs.push_back(ISD::OutputArg(Flags, PartVT, VT, /*isfixed=*/true, 0, 0));
+    for (unsigned i = 0; i < NumParts; ++i) {
+      ISD::ArgFlagsTy OutFlags = Flags;
+      if (NumParts > 1 && i == 0)
+        OutFlags.setSplit();
+      else if (i == NumParts - 1 && i != 0)
+        OutFlags.setSplitEnd();
+
+      Outs.push_back(
+          ISD::OutputArg(OutFlags, PartVT, VT, /*isfixed=*/true, 0, 0));
+    }
   }
 }
 
diff --git a/llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp b/llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp
index 8af4bc658409d4..c18892ac62f247 100644
--- a/llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp
+++ b/llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp
@@ -409,7 +409,7 @@ bool RISCVCallLowering::lowerReturnVal(MachineIRBuilder &MIRBuilder,
   splitToValueTypes(OrigRetInfo, SplitRetInfos, DL, CC);
 
   RVVArgDispatcher Dispatcher{&MF, getTLI<RISCVTargetLowering>(),
-                              F.getReturnType()};
+                              ArrayRef(F.getReturnType())};
   RISCVOutgoingValueAssigner Assigner(
       CC == CallingConv::Fast ? RISCV::CC_RISCV_FastCC : RISCV::CC_RISCV,
       /*IsRet=*/true, Dispatcher);
@@ -538,7 +538,8 @@ bool RISCVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
     ++Index;
   }
 
-  RVVArgDispatcher Dispatcher{&MF, getTLI<RISCVTargetLowering>(), TypeList};
+  RVVArgDispatcher Dispatcher{&MF, getTLI<RISCVTargetLowering>(),
+                              ArrayRef(TypeList)};
   RISCVIncomingValueAssigner Assigner(
       CC == CallingConv::Fast ? RISCV::CC_RISCV_FastCC : RISCV::CC_RISCV,
       /*IsRet=*/false, Dispatcher);
@@ -603,7 +604,8 @@ bool RISCVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
   const TargetRegisterInfo *TRI = Subtarget.getRegisterInfo();
   Call.addRegMask(TRI->getCallPreservedMask(MF, Info.CallConv));
 
-  RVVArgDispatcher ArgDispatcher{&MF, getTLI<RISCVTargetLowering>(), TypeList};
+  RVVArgDispatcher ArgDispatcher{&MF, getTLI<RISCVTargetLowering>(),
+                                 ArrayRef(TypeList)};
   RISCVOutgoingValueAssigner ArgAssigner(
       CC == CallingConv::Fast ? RISCV::CC_RISCV_FastCC : RISCV::CC_RISCV,
       /*IsRet=*/false, ArgDispatcher);
@@ -635,7 +637,7 @@ bool RISCVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
   splitToValueTypes(Info.OrigRet, SplitRetInfos, DL, CC);
 
   RVVArgDispatcher RetDispatcher{&MF, getTLI<RISCVTargetLowering>(),
-                                 F.getReturnType()};
+                                 ArrayRef(F.getReturnType())};
   RISCVIncomingValueAssigner RetAssigner(
       CC == CallingConv::Fast ? RISCV::CC_RISCV_FastCC : RISCV::CC_RISCV,
       /*IsRet=*/true, RetDispatcher);
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 279d8a435a04ca..fb5027406b0808 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -18327,13 +18327,15 @@ void RISCVTargetLowering::analyzeInputArgs(
   unsigned NumArgs = Ins.size();
   FunctionType *FType = MF.getFunction().getFunctionType();
 
-  SmallVector<Type *, 4> TypeList;
-  if (IsRet)
-    TypeList.push_back(MF.getFunction().getReturnType());
-  else
+  RVVArgDispatcher Dispatcher;
+  if (IsRet) {
+    Dispatcher = RVVArgDispatcher{&MF, this, ArrayRef(Ins)};
+  } else {
+    SmallVector<Type *, 4> TypeList;
     for (const Argument &Arg : MF.getFunction().args())
       TypeList.push_back(Arg.getType());
-  RVVArgDispatcher Dispatcher{&MF, this, TypeList};
+    Dispatcher = RVVArgDispatcher{&MF, this, ArrayRef(TypeList)};
+  }
 
   for (unsigned i = 0; i != NumArgs; ++i) {
     MVT ArgVT = Ins[i].VT;
@@ -18368,7 +18370,7 @@ void RISCVTargetLowering::analyzeOutputArgs(
   else if (CLI)
     for (const TargetLowering::ArgListEntry &Arg : CLI->getArgs())
       TypeList.push_back(Arg.Ty);
-  RVVArgDispatcher Dispatcher{&MF, this, TypeList};
+  RVVArgDispatcher Dispatcher{&MF, this, ArrayRef(TypeList)};
 
   for (unsigned i = 0; i != NumArgs; i++) {
     MVT ArgVT = Outs[i].VT;
@@ -19272,7 +19274,7 @@ bool RISCVTargetLowering::CanLowerReturn(
   SmallVector<CCValAssign, 16> RVLocs;
   CCState CCInfo(CallConv, IsVarArg, MF, RVLocs, Context);
 
-  RVVArgDispatcher Dispatcher{&MF, this, MF.getFunction().getReturnType()};
+  RVVArgDispatcher Dispatcher{&MF, this, ArrayRef(Outs)};
 
   for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
     MVT VT = Outs[i].VT;
@@ -21077,7 +21079,82 @@ unsigned RISCVTargetLowering::getMinimumJumpTableEntries() const {
   return Subtarget.getMinimumJumpTableEntries();
 }
 
-void RVVArgDispatcher::constructArgInfos(ArrayRef<Type *> TypeList) {
+// Handle single arg such as return value.
+template <typename Arg>
+void RVVArgDispatcher::constructArgInfos(ArrayRef<Arg> ArgList) {
+  // This lambda determines whether an array of types are constructed by
+  // homogeneous vector types.
+  auto isHomogeneousScalableVectorType = [&](ArrayRef<Arg> ArgList) {
+    // First, extract the first element in the argument type.
+    MVT FirstArgRegType;
+    unsigned FirstArgElements = 0;
+    typename SmallVectorImpl<Arg>::const_iterator It;
+    for (It = ArgList.begin(); It != ArgList.end(); ++It) {
+      FirstArgRegType = It->VT;
+      ++FirstArgElements;
+      if (!It->Flags.isSplit() || It->Flags.isSplitEnd())
+        break;
+    }
+    ++It;
+
+    // Return if this argument type contains only 1 element, or it's not a
+    // vector type.
+    if (It == ArgList.end() || !FirstArgRegType.isScalableVector())
+      return false;
+
+    // Second, check if the following elements in this argument type are all the
+    // same.
+    MVT ArgRegType;
+    unsigned ArgElements = 0;
+    bool IsPart = false;
+    for (; It != ArgList.end(); ++It) {
+      ArgRegType = It->VT;
+      ++ArgElements;
+      if ((!It->Flags.isSplit() && !IsPart) || It->Flags.isSplitEnd()) {
+        if (ArgRegType != FirstArgRegType || ArgElements != FirstArgElements)
+          return false;
+
+        IsPart = false;
+        ArgElements = 0;
+        continue;
+      }
+
+      IsPart = true;
+    }
+
+    return true;
+  };
+
+  if (isHomogeneousScalableVectorType(ArgList)) {
+    // Handle as tuple type
+    RVVArgInfos.push_back({(unsigned)ArgList.size(), ArgList[0].VT, false});
+  } else {
+    // Handle as normal vector type
+    bool FirstVMaskAssigned = false;
+    for (const auto &OutArg : ArgList) {
+      MVT RegisterVT = OutArg.VT;
+
+      // Skip non-RVV register type
+      if (!RegisterVT.isVector())
+        continue;
+
+      if (RegisterVT.isFixedLengthVector())
+        RegisterVT = TLI->getContainerForFixedLengthVector(RegisterVT);
+
+      if (!FirstVMaskAssigned && RegisterVT.getVectorElementType() == MVT::i1) {
+        RVVArgInfos.push_back({1, RegisterVT, true});
+        FirstVMaskAssigned = true;
+        continue;
+      }
+
+      RVVArgInfos.push_back({1, RegisterVT, false});
+    }
+  }
+}
+
+// Handle multiple args.
+template <>
+void RVVArgDispatcher::constructArgInfos<Type *>(ArrayRef<Type *> TypeList) {
   const DataLayout &DL = MF->getDataLayout();
   const Function &F = MF->getFunction();
   LLVMContext &Context = F.getContext();
@@ -21090,8 +21167,11 @@ void RVVArgDispatcher::constructArgInfos(ArrayRef<Type *> TypeList) {
       EVT VT = TLI->getValueType(DL, ElemTy);
       MVT RegisterVT =
           TLI->getRegisterTypeForCallingConv(Context, F.getCallingConv(), VT);
+      unsigned NumRegs =
+          TLI->getNumRegistersForCallingConv(Context, F.getCallingConv(), VT);
 
-      RVVArgInfos.push_back({STy->getNumElements(), RegisterVT, false});
+      RVVArgInfos.push_back(
+          {NumRegs * STy->getNumElements(), RegisterVT, false});
     } else {
       SmallVector<EVT, 4> ValueVTs;
       ComputeValueVTs(*TLI, DL, Ty, ValueVTs);
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index c28552354bf422..a2456f2fab66b1 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -1041,13 +1041,16 @@ class RVVArgDispatcher {
     bool FirstVMask = false;
   };
 
+  template <typename Arg>
   RVVArgDispatcher(const MachineFunction *MF, const RISCVTargetLowering *TLI,
-                   ArrayRef<Type *> TypeList)
+                   ArrayRef<Arg> ArgList)
       : MF(MF), TLI(TLI) {
-    constructArgInfos(TypeList);
+    constructArgInfos(ArgList);
     compute();
   }
 
+  RVVArgDispatcher() = default;
+
   MCPhysReg getNextPhysReg();
 
 private:
@@ -1059,7 +1062,7 @@ class RVVArgDispatcher {
 
   unsigned CurIdx = 0;
 
-  void constructArgInfos(ArrayRef<Type *> TypeList);
+  template <typename Arg> void constructArgInfos(ArrayRef<Arg> Ret);
   void compute();
   void allocatePhysReg(unsigned NF = 1, unsigned LMul = 1,
                        unsigned StartReg = 0);
diff --git a/llvm/test/CodeGen/RISCV/rvv/calling-conv.ll b/llvm/test/CodeGen/RISCV/rvv/calling-conv.ll
index 90edb994ce8222..647d3158b6167f 100644
--- a/llvm/test/CodeGen/RISCV/rvv/calling-conv.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/calling-conv.ll
@@ -249,3 +249,119 @@ define <vscale x 1 x i64> @case4_2(<vscale x 1 x i64> %0, {<vscale x 8 x i64>, <
   %add = add <vscale x 1 x i64> %0, %2
   ret <vscale x 1 x i64> %add
 }
+
+declare <vscale x 1 x i64> @callee1()
+declare void @callee2(<vscale x 1 x i64>)
+declare void @callee3(<vscale x 4 x i32>)
+define void @caller() {
+; RV32-LABEL: caller:
+; RV32:       # %bb.0:
+; RV32-NEXT:    addi sp, sp, -16
+; RV32-NEXT:    .cfi_def_cfa_offset 16
+; RV32-NEXT:    sw ra, 12(sp) # 4-byte Folded Spill
+; RV32-NEXT:    .cfi_offset ra, -4
+; RV32-NEXT:    call callee1
+; RV32-NEXT:    vsetvli a0, zero, e64, m1, ta, ma
+; RV32-NEXT:    vadd.vv v8, v8, v8
+; RV32-NEXT:    call callee2
+; RV32-NEXT:    lw ra, 12(sp) # 4-byte Folded Reload
+; RV32-NEXT:    addi sp, sp, 16
+; RV32-NEXT:    ret
+;
+; RV64-LABEL: caller:
+; RV64:       # %bb.0:
+; RV64-NEXT:    addi sp, sp, -16
+; RV64-NEXT:    .cfi_def_cfa_offset 16
+; RV64-NEXT:    sd ra, 8(sp) # 8-byte Folded Spill
+; RV64-NEXT:    .cfi_offset ra, -8
+; RV64-NEXT:    call callee1
+; RV64-NEXT:    vsetvli a0, zero, e64, m1, ta, ma
+; RV64-NEXT:    vadd.vv v8, v8, v8
+; RV64-NEXT:    call callee2
+; RV64-NEXT:    ld ra, 8(sp) # 8-byte Folded Reload
+; RV64-NEXT:    addi sp, sp, 16
+; RV64-NEXT:    ret
+  %a = call <vscale x 1 x i64> @callee1()
+  %add = add <vscale x 1 x i64> %a, %a
+  call void @callee2(<vscale x 1 x i64> %add)
+  ret void
+}
+
+declare {<vscale x 4 x i32>, <vscale x 4 x i32>} @callee_tuple()
+define void @caller_tuple() {
+; RV32-LABEL: caller_tuple:
+; RV32:       # %bb.0:
+; RV32-NEXT:    addi sp, sp, -16
+; RV32-NEXT:    .cfi_def_cfa_offset 16
+; RV32-NEXT:    sw ra, 12(sp) # 4-byte Folded Spill
+; RV32-NEXT:    .cfi_offset ra, -4
+; RV32-NEXT:    call callee_tuple
+; RV32-NEXT:    vsetvli a0, zero, e32, m2, ta, ma
+; RV32-NEXT:    vadd.vv v8, v8, v10
+; RV32-NEXT:    call callee3
+; RV32-NEXT:    lw ra, 12(sp) # 4-byte Folded Reload
+; RV32-NEXT:    addi sp, sp, 16
+; RV32-NEXT:    ret
+;
+; RV64-LABEL: caller_tuple:
+; RV64:       # %bb.0:
+; RV64-NEXT:    addi sp, sp, -16
+; RV64-NEXT:    .cfi_def_cfa_offset 16
+; RV64-NEXT:    sd ra, 8(sp) # 8-byte Folded Spill
+; RV64-NEXT:    .cfi_offset ra, -8
+; RV64-NEXT:    call callee_tuple
+; RV64-NEXT:    vsetvli a0, zero, e32, m2, ta, ma
+; RV64-NEXT:    vadd.vv v8, v8, v10
+; RV64-NEXT:    call callee3
+; RV64-NEXT:    ld ra, 8(sp) # 8-byte Folded Reload
+; RV64-NEXT:    addi sp, sp, 16
+; RV64-NEXT:    ret
+  %a = call {<vscale x 4 x i32>, <vscale x 4 x i32>} @callee_tuple()
+  %b = extractvalue {<vscale x 4 x i32>, <vscale x 4 x i32>} %a, 0
+  %c = extractvalue {<vscale x 4 x i32>, <vscale x 4 x i32>} %a, 1
+  %add = add <vscale x 4 x i32> %b, %c
+  call void @callee3(<vscale x 4 x i32> %add)
+  ret void
+}
+
+declare {<vscale x 4 x i32>, {<vscale x 4 x i32>, <vscale x 4 x i32>}} @callee_nested()
+define void @caller_nested() {
+; RV32-LABEL: caller_nested:
+; RV32:       # %bb.0:
+; RV32-NEXT:    addi sp, sp, -16
+; RV32-NEXT:    .cfi_def_cfa_offset 16
+; RV32-NEXT:    sw ra, 12(sp) # 4-byte Folded Spill
+; RV32-NEXT:    .cfi_offset ra, -4
+; RV32-NEXT:    call callee_nested
+; RV32-NEXT:    vsetvli a0, zero, e32, m2, ta, ma
+; RV32-NEXT:    vadd.vv v8, v8, v10
+; RV32-NEXT:    vadd.vv v8, v8, v12
+; RV32-NEXT:    call callee3
+; RV32-NEXT:    lw ra, 12(sp) # 4-byte Folded Reload
+; RV32-NEXT:    addi sp, sp, 16
+; RV32-NEXT:    ret
+;
+; RV64-LABEL: caller_nested:
+; RV64:       # %bb.0:
+; RV64-NEXT:    addi sp, sp, -16
+; RV64-NEXT:    .cfi_def_cfa_offset 16
+; RV64-NEXT:    sd ra, 8(sp) # 8-byte Folded Spill
+; RV64-NEXT:    .cfi_offset ra, -8
+; RV64-NEXT:    call callee_nested
+; RV64-NEXT:    vsetvli a0, zero, e32, m2, ta, ma
+; RV64-NEXT:    vadd.vv v8, v8, v10
+; RV64-NEXT:    vadd.vv v8, v8, v12
+; RV64-NEXT:    call callee3
+; RV64-NEXT:    ld ra, 8(sp) # 8-byte Folded Reload
+; RV64-NEXT:    addi sp, sp, 16
+; RV64-NEXT:    ret
+  %a = call {<vscale x 4 x i32>, {<vscale x 4 x i32>, <vscale x 4 x i32>}} @callee_nested()
+  %b = extractvalue {<vscale x 4 x i32>, {<vscale x 4 x i32>, <vscale x 4 x i32>}} %a, 0
+  %c = extractvalue {<vscale x 4 x i32>, {<vscale x 4 x i32>, <vscale x 4 x i32>}} %a, 1
+  %c0 = extractvalue {<vscale x 4 x i32>, <vscale x 4 x i32>} %c, 0
+  %c1 = extractvalue {<vscale x 4 x i32>, <vscale x 4 x i32>} %c, 1
+  %add0 = add <vscale x 4 x i32> %b, %c0
+  %add1 = add <vscale x 4 x i32> %add0, %c1
+  call void @callee3(<vscale x 4 x i32> %add1)
+  ret void
+}

@4vtomat (Member Author) commented Apr 5, 2024

This PR resolves #87323.

@4vtomat (Member Author) commented Apr 12, 2024

Integrated the "Recommit [RISCV] RISCV vector calling convention (2/2)" change and the bug fix into one patch.

@4vtomat 4vtomat changed the title [RISCV] Handle RVV return type in calling convention correctly Recommit [RISCV] RISCV vector calling convention (2/2) (#79096) Apr 12, 2024
Review thread on the isHomogeneousScalableVectorType loop in llvm/lib/Target/RISCV/RISCVISelLowering.cpp:

    for (; It != ArgList.end(); ++It) {
      ArgRegType = It->VT;
      ++ArgElements;
      if ((!It->Flags.isSplit() && !IsPart) || It->Flags.isSplitEnd()) {
Collaborator:

Do we have tests for the split case? That should only happen for types larger than LMUL=8 I think? Do we even want to consider those as tuple types?

Collaborator:

Are all types in a split guaranteed to be the same type or do we need to check all of the types?

Member Author:

> Do we have tests for the split case? That should only happen for types larger than LMUL=8 I think? Do we even want to consider those as tuple types?

You are right, and I think the split case doesn't need to be considered a tuple type: we only have 16 argument vregs, and the smallest type that gets split already occupies LMUL=16, so an argument with more than one split part can never be passed entirely in registers.
I guess it can be handled as a normal vector type.
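
For reference, a minimal sketch of the split case under discussion (hypothetical function name; the comments describe my understanding of how LLVM legalizes vectors wider than LMUL=8, based on the discussion above):

  ; Hypothetical: <vscale x 32 x i32> does not fit in an LMUL=8 register
  ; group, so type legalization splits it into two <vscale x 16 x i32>
  ; (LMUL=8) parts; the first part carries the Split flag and the last
  ; carries SplitEnd. Those two parts alone would occupy all 16 argument
  ; registers (v8-v23), so per the discussion above such a type is not
  ; treated as a tuple and can be handled as a normal vector type.
  declare void @takes_oversized(<vscale x 32 x i32>)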

@topperc (Collaborator) left a comment:

LGTM

@4vtomat 4vtomat merged commit 91dd844 into llvm:main Apr 16, 2024
3 of 4 checks passed
@4vtomat 4vtomat deleted the issue_87323 branch April 16, 2024 11:59