[RISCV] Support vector types in combination with fastcc
This patch extends the RISC-V lowering of the 'fastcc' calling
convention to vector types, both fixed-length and scalable. Without this
patch, any function passing or returning vector types by value would
fail with a compiler error.

Vectors are handled in 'fastcc' much as they are in the default calling
convention, the notable difference being the extended set of scalar
general-purpose registers (GPRs) that can be used to pass vectors indirectly.
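
As a minimal sketch of what this enables (the function name, element type,
and vector length below are illustrative, not taken from the patch's tests),
a 'fastcc' function can now take and return scalable-vector values directly:

; Hypothetical scalable-vector function using the fastcc calling convention;
; before this patch, lowering such a signature failed with a compiler error.
define fastcc <vscale x 4 x i32> @vec_add(<vscale x 4 x i32> %a,
                                          <vscale x 4 x i32> %b) {
  %sum = add <vscale x 4 x i32> %a, %b
  ret <vscale x 4 x i32> %sum
}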

Reviewed By: HsiangKai

Differential Revision: https://reviews.llvm.org/D102505
frasercrmck committed Jun 1, 2021
1 parent 82f92e3 commit 4f500c4
Showing 4 changed files with 1,284 additions and 49 deletions.
129 changes: 83 additions & 46 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -20,7 +20,6 @@
#include "RISCVTargetMachine.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/CodeGen/CallingConvLower.h"
#include "llvm/CodeGen/MachineFrameInfo.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
@@ -6572,6 +6571,27 @@ static bool CC_RISCVAssign2XLen(unsigned XLen, CCState &State, CCValAssign VA1,
return false;
}

static unsigned allocateRVVReg(MVT ValVT, unsigned ValNo,
Optional<unsigned> FirstMaskArgument,
CCState &State, const RISCVTargetLowering &TLI) {
const TargetRegisterClass *RC = TLI.getRegClassFor(ValVT);
if (RC == &RISCV::VRRegClass) {
// Assign the first mask argument to V0.
// This is an interim calling convention and it may be changed in the
// future.
if (FirstMaskArgument.hasValue() && ValNo == FirstMaskArgument.getValue())
return State.AllocateReg(RISCV::V0);
return State.AllocateReg(ArgVRs);
}
if (RC == &RISCV::VRM2RegClass)
return State.AllocateReg(ArgVRM2s);
if (RC == &RISCV::VRM4RegClass)
return State.AllocateReg(ArgVRM4s);
if (RC == &RISCV::VRM8RegClass)
return State.AllocateReg(ArgVRM8s);
llvm_unreachable("Unhandled register class for ValueType");
}

// Implements the RISC-V calling convention. Returns true upon failure.
static bool CC_RISCV(const DataLayout &DL, RISCVABI::ABI ABI, unsigned ValNo,
MVT ValVT, MVT LocVT, CCValAssign::LocInfo LocInfo,
@@ -6720,26 +6740,7 @@ static bool CC_RISCV(const DataLayout &DL, RISCVABI::ABI ABI, unsigned ValNo,
else if (ValVT == MVT::f64 && !UseGPRForF64)
Reg = State.AllocateReg(ArgFPR64s);
else if (ValVT.isVector()) {
const TargetRegisterClass *RC = TLI.getRegClassFor(ValVT);
if (RC == &RISCV::VRRegClass) {
// Assign the first mask argument to V0.
// This is an interim calling convention and it may be changed in the
// future.
if (FirstMaskArgument.hasValue() &&
ValNo == FirstMaskArgument.getValue()) {
Reg = State.AllocateReg(RISCV::V0);
} else {
Reg = State.AllocateReg(ArgVRs);
}
} else if (RC == &RISCV::VRM2RegClass) {
Reg = State.AllocateReg(ArgVRM2s);
} else if (RC == &RISCV::VRM4RegClass) {
Reg = State.AllocateReg(ArgVRM4s);
} else if (RC == &RISCV::VRM8RegClass) {
Reg = State.AllocateReg(ArgVRM8s);
} else {
llvm_unreachable("Unhandled class register for ValueType");
}
Reg = allocateRVVReg(ValVT, ValNo, FirstMaskArgument, State, TLI);
if (!Reg) {
// For return values, the vector must be passed fully via registers or
// via the stack.
@@ -6818,7 +6819,8 @@ static Optional<unsigned> preAssignMask(const ArgTy &Args) {

void RISCVTargetLowering::analyzeInputArgs(
MachineFunction &MF, CCState &CCInfo,
const SmallVectorImpl<ISD::InputArg> &Ins, bool IsRet) const {
const SmallVectorImpl<ISD::InputArg> &Ins, bool IsRet,
RISCVCCAssignFn Fn) const {
unsigned NumArgs = Ins.size();
FunctionType *FType = MF.getFunction().getFunctionType();

@@ -6837,9 +6839,9 @@ void RISCVTargetLowering::analyzeInputArgs(
ArgTy = FType->getParamType(Ins[i].getOrigArgIndex());

RISCVABI::ABI ABI = MF.getSubtarget<RISCVSubtarget>().getTargetABI();
if (CC_RISCV(MF.getDataLayout(), ABI, i, ArgVT, ArgVT, CCValAssign::Full,
ArgFlags, CCInfo, /*IsFixed=*/true, IsRet, ArgTy, *this,
FirstMaskArgument)) {
if (Fn(MF.getDataLayout(), ABI, i, ArgVT, ArgVT, CCValAssign::Full,
ArgFlags, CCInfo, /*IsFixed=*/true, IsRet, ArgTy, *this,
FirstMaskArgument)) {
LLVM_DEBUG(dbgs() << "InputArg #" << i << " has unhandled type "
<< EVT(ArgVT).getEVTString() << '\n');
llvm_unreachable(nullptr);
@@ -6850,7 +6852,7 @@ void RISCVTargetLowering::analyzeOutputArgs(
void RISCVTargetLowering::analyzeOutputArgs(
MachineFunction &MF, CCState &CCInfo,
const SmallVectorImpl<ISD::OutputArg> &Outs, bool IsRet,
CallLoweringInfo *CLI) const {
CallLoweringInfo *CLI, RISCVCCAssignFn Fn) const {
unsigned NumArgs = Outs.size();

Optional<unsigned> FirstMaskArgument;
@@ -6863,9 +6865,9 @@
Type *OrigTy = CLI ? CLI->getArgs()[Outs[i].OrigArgIndex].Ty : nullptr;

RISCVABI::ABI ABI = MF.getSubtarget<RISCVSubtarget>().getTargetABI();
if (CC_RISCV(MF.getDataLayout(), ABI, i, ArgVT, ArgVT, CCValAssign::Full,
ArgFlags, CCInfo, Outs[i].IsFixed, IsRet, OrigTy, *this,
FirstMaskArgument)) {
if (Fn(MF.getDataLayout(), ABI, i, ArgVT, ArgVT, CCValAssign::Full,
ArgFlags, CCInfo, Outs[i].IsFixed, IsRet, OrigTy, *this,
FirstMaskArgument)) {
LLVM_DEBUG(dbgs() << "OutputArg #" << i << " has unhandled type "
<< EVT(ArgVT).getEVTString() << "\n");
llvm_unreachable(nullptr);
@@ -7010,16 +7012,21 @@ static SDValue unpackF64OnRV32DSoftABI(SelectionDAG &DAG, SDValue Chain,

// FastCC has less than 1% performance improvement for some particular
// benchmarks. But theoretically, it may have benefit for some cases.
static bool CC_RISCV_FastCC(unsigned ValNo, MVT ValVT, MVT LocVT,
static bool CC_RISCV_FastCC(const DataLayout &DL, RISCVABI::ABI ABI,
unsigned ValNo, MVT ValVT, MVT LocVT,
CCValAssign::LocInfo LocInfo,
ISD::ArgFlagsTy ArgFlags, CCState &State) {
ISD::ArgFlagsTy ArgFlags, CCState &State,
bool IsFixed, bool IsRet, Type *OrigTy,
const RISCVTargetLowering &TLI,
Optional<unsigned> FirstMaskArgument) {

// X5 and X6 might be used for save-restore libcall.
static const MCPhysReg GPRList[] = {
RISCV::X10, RISCV::X11, RISCV::X12, RISCV::X13, RISCV::X14,
RISCV::X15, RISCV::X16, RISCV::X17, RISCV::X7, RISCV::X28,
RISCV::X29, RISCV::X30, RISCV::X31};

if (LocVT == MVT::i32 || LocVT == MVT::i64) {
// X5 and X6 might be used for save-restore libcall.
static const MCPhysReg GPRList[] = {
RISCV::X10, RISCV::X11, RISCV::X12, RISCV::X13, RISCV::X14,
RISCV::X15, RISCV::X16, RISCV::X17, RISCV::X7, RISCV::X28,
RISCV::X29, RISCV::X30, RISCV::X31};
if (unsigned Reg = State.AllocateReg(GPRList)) {
State.addLoc(CCValAssign::getReg(ValNo, ValVT, Reg, LocVT, LocInfo));
return false;
@@ -7074,6 +7081,36 @@ static bool CC_RISCV_FastCC(unsigned ValNo, MVT ValVT, MVT LocVT,
return false;
}

if (LocVT.isVector()) {
if (unsigned Reg =
allocateRVVReg(ValVT, ValNo, FirstMaskArgument, State, TLI)) {
// Fixed-length vectors are located in the corresponding scalable-vector
// container types.
if (ValVT.isFixedLengthVector())
LocVT = TLI.getContainerForFixedLengthVector(LocVT);
State.addLoc(CCValAssign::getReg(ValNo, ValVT, Reg, LocVT, LocInfo));
} else {
// Try and pass the address via a "fast" GPR.
if (unsigned GPRReg = State.AllocateReg(GPRList)) {
LocInfo = CCValAssign::Indirect;
LocVT = TLI.getSubtarget().getXLenVT();
State.addLoc(CCValAssign::getReg(ValNo, ValVT, GPRReg, LocVT, LocInfo));
} else if (ValVT.isFixedLengthVector()) {
auto StackAlign =
MaybeAlign(ValVT.getScalarSizeInBits() / 8).valueOrOne();
unsigned StackOffset =
State.AllocateStack(ValVT.getStoreSize(), StackAlign);
State.addLoc(
CCValAssign::getMem(ValNo, ValVT, StackOffset, LocVT, LocInfo));
} else {
// Can't pass scalable vectors on the stack.
return true;
}
}

return false;
}

return true; // CC didn't match.
}

@@ -7166,12 +7203,12 @@ SDValue RISCVTargetLowering::LowerFormalArguments(
SmallVector<CCValAssign, 16> ArgLocs;
CCState CCInfo(CallConv, IsVarArg, MF, ArgLocs, *DAG.getContext());

if (CallConv == CallingConv::Fast)
CCInfo.AnalyzeFormalArguments(Ins, CC_RISCV_FastCC);
else if (CallConv == CallingConv::GHC)
if (CallConv == CallingConv::GHC)
CCInfo.AnalyzeFormalArguments(Ins, CC_RISCV_GHC);
else
analyzeInputArgs(MF, CCInfo, Ins, /*IsRet=*/false);
analyzeInputArgs(MF, CCInfo, Ins, /*IsRet=*/false,
CallConv == CallingConv::Fast ? CC_RISCV_FastCC
: CC_RISCV);

for (unsigned i = 0, e = ArgLocs.size(); i != e; ++i) {
CCValAssign &VA = ArgLocs[i];
@@ -7378,12 +7415,12 @@ SDValue RISCVTargetLowering::LowerCall(CallLoweringInfo &CLI,
SmallVector<CCValAssign, 16> ArgLocs;
CCState ArgCCInfo(CallConv, IsVarArg, MF, ArgLocs, *DAG.getContext());

if (CallConv == CallingConv::Fast)
ArgCCInfo.AnalyzeCallOperands(Outs, CC_RISCV_FastCC);
else if (CallConv == CallingConv::GHC)
if (CallConv == CallingConv::GHC)
ArgCCInfo.AnalyzeCallOperands(Outs, CC_RISCV_GHC);
else
analyzeOutputArgs(MF, ArgCCInfo, Outs, /*IsRet=*/false, &CLI);
analyzeOutputArgs(MF, ArgCCInfo, Outs, /*IsRet=*/false, &CLI,
CallConv == CallingConv::Fast ? CC_RISCV_FastCC
: CC_RISCV);

// Check if it's really possible to do a tail call.
if (IsTailCall)
@@ -7628,7 +7665,7 @@ SDValue RISCVTargetLowering::LowerCall(CallLoweringInfo &CLI,
// Assign locations to each value returned by this call.
SmallVector<CCValAssign, 16> RVLocs;
CCState RetCCInfo(CallConv, IsVarArg, MF, RVLocs, *DAG.getContext());
analyzeInputArgs(MF, RetCCInfo, Ins, /*IsRet=*/true);
analyzeInputArgs(MF, RetCCInfo, Ins, /*IsRet=*/true, CC_RISCV);

// Copy all of the result registers out of their specified physreg.
for (auto &VA : RVLocs) {
@@ -7696,7 +7733,7 @@ RISCVTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
*DAG.getContext());

analyzeOutputArgs(DAG.getMachineFunction(), CCInfo, Outs, /*IsRet=*/true,
nullptr);
nullptr, CC_RISCV);

if (CallConv == CallingConv::GHC && !RVLocs.empty())
report_fatal_error("GHC functions return void only");
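To make the new fall-back path in CC_RISCV_FastCC above concrete, here is a
hedged LLVM IR sketch. The function and value names are hypothetical, and the
register assignments assume the ArgVRM8s list referenced above holds the two
LMUL-8 groups v8m8 and v16m8. With three LMUL-8 arguments, the first two are
passed directly in vector registers and the third is passed indirectly, with
its address carried in one of the "fast" GPRs; only if those GPRs were also
exhausted would a scalable-vector argument be rejected, since it cannot be
passed on the stack.

; Hypothetical fastcc function with three LMUL=8 scalable-vector arguments.
; Assuming ArgVRM8s = {v8m8, v16m8}, %a and %b would travel in v8 and v16,
; while %c would be passed indirectly via a pointer in a fastcc GPR.
define fastcc <vscale x 16 x i32> @sum3(<vscale x 16 x i32> %a,
                                        <vscale x 16 x i32> %b,
                                        <vscale x 16 x i32> %c) {
  %t = add <vscale x 16 x i32> %a, %b
  %r = add <vscale x 16 x i32> %t, %c
  ret <vscale x 16 x i32> %r
}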
19 changes: 16 additions & 3 deletions llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -15,6 +15,7 @@
#define LLVM_LIB_TARGET_RISCV_RISCVISELLOWERING_H

#include "RISCV.h"
#include "llvm/CodeGen/CallingConvLower.h"
#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/TargetLowering.h"

@@ -484,12 +485,24 @@ class RISCVTargetLowering : public TargetLowering {
bool shouldRemoveExtendFromGSIndex(EVT VT) const override;

private:
/// RISCVCCAssignFn - This target-specific function extends the default
/// CCValAssign with additional information used to lower RISC-V calling
/// conventions.
typedef bool RISCVCCAssignFn(const DataLayout &DL, RISCVABI::ABI,
unsigned ValNo, MVT ValVT, MVT LocVT,
CCValAssign::LocInfo LocInfo,
ISD::ArgFlagsTy ArgFlags, CCState &State,
bool IsFixed, bool IsRet, Type *OrigTy,
const RISCVTargetLowering &TLI,
Optional<unsigned> FirstMaskArgument);

void analyzeInputArgs(MachineFunction &MF, CCState &CCInfo,
const SmallVectorImpl<ISD::InputArg> &Ins,
bool IsRet) const;
const SmallVectorImpl<ISD::InputArg> &Ins, bool IsRet,
RISCVCCAssignFn Fn) const;
void analyzeOutputArgs(MachineFunction &MF, CCState &CCInfo,
const SmallVectorImpl<ISD::OutputArg> &Outs,
bool IsRet, CallLoweringInfo *CLI) const;
bool IsRet, CallLoweringInfo *CLI,
RISCVCCAssignFn Fn) const;

template <class NodeTy>
SDValue getAddr(NodeTy *N, SelectionDAG &DAG, bool IsLocal = true) const;