Skip to content

Commit

Permalink
[RISCV][GlobalISel] Add lowerReturn for calling conv
Browse files Browse the repository at this point in the history
Add minimal support to lower return, and introduce an OutgoingValueHandler and an OutgoingValueAssigner for returns.

Supports return values with integer, pointer and aggregate types.

(Update of D69808 - avoiding commandeering that revision)

Co-authored By: lewis-revill

Reviewed By: arsenm

Differential Revision: https://reviews.llvm.org/D117318
  • Loading branch information
nitinjohnraj committed May 23, 2023
1 parent d55982a commit 991ecfb
Show file tree
Hide file tree
Showing 5 changed files with 271 additions and 17 deletions.
99 changes: 96 additions & 3 deletions llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp
Expand Up @@ -14,22 +14,115 @@

#include "RISCVCallLowering.h"
#include "RISCVISelLowering.h"
#include "RISCVSubtarget.h"
#include "llvm/CodeGen/Analysis.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"

using namespace llvm;

namespace {

struct RISCVOutgoingValueAssigner : public CallLowering::OutgoingValueAssigner {
private:
// The function used internally to assign args - we ignore the AssignFn stored
// by OutgoingValueAssigner since RISC-V implements its CC using a custom
// function with a different signature.
RISCVTargetLowering::RISCVCCAssignFn *RISCVAssignFn;

// Whether this is assigning args for a return.
bool IsRet;

public:
RISCVOutgoingValueAssigner(
RISCVTargetLowering::RISCVCCAssignFn *RISCVAssignFn_, bool IsRet)
: CallLowering::OutgoingValueAssigner(nullptr),
RISCVAssignFn(RISCVAssignFn_), IsRet(IsRet) {}

bool assignArg(unsigned ValNo, EVT OrigVT, MVT ValVT, MVT LocVT,
CCValAssign::LocInfo LocInfo,
const CallLowering::ArgInfo &Info, ISD::ArgFlagsTy Flags,
CCState &State) override {
MachineFunction &MF = State.getMachineFunction();
const DataLayout &DL = MF.getDataLayout();
const RISCVSubtarget &Subtarget = MF.getSubtarget<RISCVSubtarget>();

return RISCVAssignFn(DL, Subtarget.getTargetABI(), ValNo, ValVT, LocVT,
LocInfo, Flags, State, /*IsFixed=*/true, IsRet,
Info.Ty, *Subtarget.getTargetLowering(),
/*FirstMaskArgument=*/std::nullopt);
}
};

struct RISCVOutgoingValueHandler : public CallLowering::OutgoingValueHandler {
RISCVOutgoingValueHandler(MachineIRBuilder &B, MachineRegisterInfo &MRI,
MachineInstrBuilder MIB)
: OutgoingValueHandler(B, MRI), MIB(MIB) {}

MachineInstrBuilder MIB;

Register getStackAddress(uint64_t MemSize, int64_t Offset,
MachinePointerInfo &MPO,
ISD::ArgFlagsTy Flags) override {
llvm_unreachable("not implemented");
}

void assignValueToAddress(Register ValVReg, Register Addr, LLT MemTy,
MachinePointerInfo &MPO, CCValAssign &VA) override {
llvm_unreachable("not implemented");
}

void assignValueToReg(Register ValVReg, Register PhysReg,
CCValAssign VA) override {
Register ExtReg = extendRegister(ValVReg, VA);
MIRBuilder.buildCopy(PhysReg, ExtReg);
MIB.addUse(PhysReg, RegState::Implicit);
}
};

} // namespace

RISCVCallLowering::RISCVCallLowering(const RISCVTargetLowering &TLI)
: CallLowering(&TLI) {}

bool RISCVCallLowering::lowerReturnVal(MachineIRBuilder &MIRBuilder,
const Value *Val,
ArrayRef<Register> VRegs,
MachineInstrBuilder &Ret) const {
if (!Val)
return true;

// TODO: Only integer, pointer and aggregate types are supported now.
if (!Val->getType()->isIntOrPtrTy() && !Val->getType()->isAggregateType())
return false;

MachineFunction &MF = MIRBuilder.getMF();
const DataLayout &DL = MF.getDataLayout();
const Function &F = MF.getFunction();
CallingConv::ID CC = F.getCallingConv();

ArgInfo OrigRetInfo(VRegs, Val->getType(), 0);
setArgFlags(OrigRetInfo, AttributeList::ReturnIndex, DL, F);

SmallVector<ArgInfo, 4> SplitRetInfos;
splitToValueTypes(OrigRetInfo, SplitRetInfos, DL, CC);

RISCVOutgoingValueAssigner Assigner(
CC == CallingConv::Fast ? RISCV::CC_RISCV_FastCC : RISCV::CC_RISCV,
/*IsRet=*/true);
RISCVOutgoingValueHandler Handler(MIRBuilder, MF.getRegInfo(), Ret);
return determineAndHandleAssignments(Handler, Assigner, SplitRetInfos,
MIRBuilder, CC, F.isVarArg());
}

bool RISCVCallLowering::lowerReturn(MachineIRBuilder &MIRBuilder,
const Value *Val, ArrayRef<Register> VRegs,
FunctionLoweringInfo &FLI) const {

assert(!Val == VRegs.empty() && "Return value without a vreg");
MachineInstrBuilder Ret = MIRBuilder.buildInstrNoInsert(RISCV::PseudoRET);

if (Val != nullptr) {
if (!lowerReturnVal(MIRBuilder, Val, VRegs, Ret))
return false;
}

MIRBuilder.insertInstr(Ret);
return true;
}
Expand Down
7 changes: 6 additions & 1 deletion llvm/lib/Target/RISCV/GISel/RISCVCallLowering.h
Expand Up @@ -16,10 +16,11 @@

#include "llvm/CodeGen/CallingConvLower.h"
#include "llvm/CodeGen/GlobalISel/CallLowering.h"
#include "llvm/CodeGen/ValueTypes.h"

namespace llvm {

class MachineInstrBuilder;
class MachineIRBuilder;
class RISCVTargetLowering;

class RISCVCallLowering : public CallLowering {
Expand All @@ -37,6 +38,10 @@ class RISCVCallLowering : public CallLowering {

bool lowerCall(MachineIRBuilder &MIRBuilder,
CallLoweringInfo &Info) const override;

private:
bool lowerReturnVal(MachineIRBuilder &MIRBuilder, const Value *Val,
ArrayRef<Register> VRegs, MachineInstrBuilder &Ret) const;
};

} // end namespace llvm
Expand Down
24 changes: 12 additions & 12 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Expand Up @@ -13722,7 +13722,7 @@ static unsigned allocateRVVReg(MVT ValVT, unsigned ValNo,
}

// Implements the RISC-V calling convention. Returns true upon failure.
static bool CC_RISCV(const DataLayout &DL, RISCVABI::ABI ABI, unsigned ValNo,
bool RISCV::CC_RISCV(const DataLayout &DL, RISCVABI::ABI ABI, unsigned ValNo,
MVT ValVT, MVT LocVT, CCValAssign::LocInfo LocInfo,
ISD::ArgFlagsTy ArgFlags, CCState &State, bool IsFixed,
bool IsRet, Type *OrigTy, const RISCVTargetLowering &TLI,
Expand Down Expand Up @@ -14175,7 +14175,7 @@ static SDValue unpackF64OnRV32DSoftABI(SelectionDAG &DAG, SDValue Chain,

// FastCC has less than 1% performance improvement for some particular
// benchmark. But theoretically, it may has benenfit for some cases.
static bool CC_RISCV_FastCC(const DataLayout &DL, RISCVABI::ABI ABI,
bool RISCV::CC_RISCV_FastCC(const DataLayout &DL, RISCVABI::ABI ABI,
unsigned ValNo, MVT ValVT, MVT LocVT,
CCValAssign::LocInfo LocInfo,
ISD::ArgFlagsTy ArgFlags, CCState &State,
Expand Down Expand Up @@ -14277,7 +14277,7 @@ static bool CC_RISCV_FastCC(const DataLayout &DL, RISCVABI::ABI ABI,
return true; // CC didn't match.
}

static bool CC_RISCV_GHC(unsigned ValNo, MVT ValVT, MVT LocVT,
bool RISCV::CC_RISCV_GHC(unsigned ValNo, MVT ValVT, MVT LocVT,
CCValAssign::LocInfo LocInfo,
ISD::ArgFlagsTy ArgFlags, CCState &State) {

Expand Down Expand Up @@ -14371,11 +14371,11 @@ SDValue RISCVTargetLowering::LowerFormalArguments(
CCState CCInfo(CallConv, IsVarArg, MF, ArgLocs, *DAG.getContext());

if (CallConv == CallingConv::GHC)
CCInfo.AnalyzeFormalArguments(Ins, CC_RISCV_GHC);
CCInfo.AnalyzeFormalArguments(Ins, RISCV::CC_RISCV_GHC);
else
analyzeInputArgs(MF, CCInfo, Ins, /*IsRet=*/false,
CallConv == CallingConv::Fast ? CC_RISCV_FastCC
: CC_RISCV);
CallConv == CallingConv::Fast ? RISCV::CC_RISCV_FastCC
: RISCV::CC_RISCV);

for (unsigned i = 0, e = ArgLocs.size(); i != e; ++i) {
CCValAssign &VA = ArgLocs[i];
Expand Down Expand Up @@ -14576,11 +14576,11 @@ SDValue RISCVTargetLowering::LowerCall(CallLoweringInfo &CLI,
CCState ArgCCInfo(CallConv, IsVarArg, MF, ArgLocs, *DAG.getContext());

if (CallConv == CallingConv::GHC)
ArgCCInfo.AnalyzeCallOperands(Outs, CC_RISCV_GHC);
ArgCCInfo.AnalyzeCallOperands(Outs, RISCV::CC_RISCV_GHC);
else
analyzeOutputArgs(MF, ArgCCInfo, Outs, /*IsRet=*/false, &CLI,
CallConv == CallingConv::Fast ? CC_RISCV_FastCC
: CC_RISCV);
CallConv == CallingConv::Fast ? RISCV::CC_RISCV_FastCC
: RISCV::CC_RISCV);

// Check if it's really possible to do a tail call.
if (IsTailCall)
Expand Down Expand Up @@ -14824,7 +14824,7 @@ SDValue RISCVTargetLowering::LowerCall(CallLoweringInfo &CLI,
// Assign locations to each value returned by this call.
SmallVector<CCValAssign, 16> RVLocs;
CCState RetCCInfo(CallConv, IsVarArg, MF, RVLocs, *DAG.getContext());
analyzeInputArgs(MF, RetCCInfo, Ins, /*IsRet=*/true, CC_RISCV);
analyzeInputArgs(MF, RetCCInfo, Ins, /*IsRet=*/true, RISCV::CC_RISCV);

// Copy all of the result registers out of their specified physreg.
for (auto &VA : RVLocs) {
Expand Down Expand Up @@ -14867,7 +14867,7 @@ bool RISCVTargetLowering::CanLowerReturn(
MVT VT = Outs[i].VT;
ISD::ArgFlagsTy ArgFlags = Outs[i].Flags;
RISCVABI::ABI ABI = MF.getSubtarget<RISCVSubtarget>().getTargetABI();
if (CC_RISCV(MF.getDataLayout(), ABI, i, VT, VT, CCValAssign::Full,
if (RISCV::CC_RISCV(MF.getDataLayout(), ABI, i, VT, VT, CCValAssign::Full,
ArgFlags, CCInfo, /*IsFixed=*/true, /*IsRet=*/true, nullptr,
*this, FirstMaskArgument))
return false;
Expand All @@ -14892,7 +14892,7 @@ RISCVTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
*DAG.getContext());

analyzeOutputArgs(DAG.getMachineFunction(), CCInfo, Outs, /*IsRet=*/true,
nullptr, CC_RISCV);
nullptr, RISCV::CC_RISCV);

if (CallConv == CallingConv::GHC && !RVLocs.empty())
report_fatal_error("GHC functions return void only");
Expand Down
22 changes: 21 additions & 1 deletion llvm/lib/Target/RISCV/RISCVISelLowering.h
Expand Up @@ -735,7 +735,6 @@ class RISCVTargetLowering : public TargetLowering {
bool lowerInterleavedStore(StoreInst *SI, ShuffleVectorInst *SVI,
unsigned Factor) const override;

private:
/// RISCVCCAssignFn - This target-specific function extends the default
/// CCValAssign with additional information used to lower RISC-V calling
/// conventions.
Expand All @@ -747,6 +746,7 @@ class RISCVTargetLowering : public TargetLowering {
const RISCVTargetLowering &TLI,
std::optional<unsigned> FirstMaskArgument);

private:
void analyzeInputArgs(MachineFunction &MF, CCState &CCInfo,
const SmallVectorImpl<ISD::InputArg> &Ins, bool IsRet,
RISCVCCAssignFn Fn) const;
Expand Down Expand Up @@ -874,6 +874,26 @@ class RISCVTargetLowering : public TargetLowering {
/// faster than two FDIVs.
unsigned combineRepeatedFPDivisors() const override;
};

namespace RISCV {

bool CC_RISCV(const DataLayout &DL, RISCVABI::ABI ABI, unsigned ValNo,
MVT ValVT, MVT LocVT, CCValAssign::LocInfo LocInfo,
ISD::ArgFlagsTy ArgFlags, CCState &State, bool IsFixed,
bool IsRet, Type *OrigTy, const RISCVTargetLowering &TLI,
std::optional<unsigned> FirstMaskArgument);

bool CC_RISCV_FastCC(const DataLayout &DL, RISCVABI::ABI ABI, unsigned ValNo,
MVT ValVT, MVT LocVT, CCValAssign::LocInfo LocInfo,
ISD::ArgFlagsTy ArgFlags, CCState &State, bool IsFixed,
bool IsRet, Type *OrigTy, const RISCVTargetLowering &TLI,
std::optional<unsigned> FirstMaskArgument);

bool CC_RISCV_GHC(unsigned ValNo, MVT ValVT, MVT LocVT,
CCValAssign::LocInfo LocInfo, ISD::ArgFlagsTy ArgFlags,
CCState &State);
} // end namespace RISCV

namespace RISCVVIntrinsicsTable {

struct RISCVVIntrinsicInfo {
Expand Down

0 comments on commit 991ecfb

Please sign in to comment.