Skip to content

Commit

Permalink
[RISCV] Lower the tail pseudoinstruction
Browse files Browse the repository at this point in the history
This patch lowers the tail pseudoinstruction. This has been modeled after ARM's
tail call opt.

llvm-svn: 333137
  • Loading branch information
Mandeep Singh Grang committed May 23, 2018
1 parent 767d92e commit ddcb956
Show file tree
Hide file tree
Showing 9 changed files with 381 additions and 17 deletions.
14 changes: 7 additions & 7 deletions llvm/lib/Target/RISCV/MCTargetDesc/RISCVMCCodeEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,13 @@ MCCodeEmitter *llvm::createRISCVMCCodeEmitter(const MCInstrInfo &MCII,
return new RISCVMCCodeEmitter(Ctx, MCII);
}

// Expand PseudoCALL to AUIPC and JALR with relocation types.
// We expand PseudoCALL while encoding, meaning AUIPC and JALR won't go through
// RISCV MC to MC compressed instruction transformation. This is acceptable
// because AUIPC has no 16-bit form and C_JALR have no immediate operand field.
// We let linker relaxation deal with it. When linker relaxation enabled,
// AUIPC and JALR have chance relax to JAL. If C extension is enabled,
// JAL has chance relax to C_JAL.
// Expand PseudoCALL and PseudoTAIL to AUIPC and JALR with relocation types.
// We expand PseudoCALL and PseudoTAIL while encoding, meaning AUIPC and JALR
// won't go through RISCV MC to MC compressed instruction transformation. This
// is acceptable because AUIPC has no 16-bit form and C_JALR have no immediate
// operand field. We let linker relaxation deal with it. When linker
// relaxation enabled, AUIPC and JALR have chance relax to JAL. If C extension
// is enabled, JAL has chance relax to C_JAL.
void RISCVMCCodeEmitter::expandFunctionCall(const MCInst &MI, raw_ostream &OS,
SmallVectorImpl<MCFixup> &Fixups,
const MCSubtargetInfo &STI) const {
Expand Down
125 changes: 117 additions & 8 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "RISCVRegisterInfo.h"
#include "RISCVSubtarget.h"
#include "RISCVTargetMachine.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/CodeGen/CallingConvLower.h"
#include "llvm/CodeGen/MachineFrameInfo.h"
#include "llvm/CodeGen/MachineFunction.h"
Expand All @@ -36,6 +37,8 @@ using namespace llvm;

#define DEBUG_TYPE "riscv-lower"

STATISTIC(NumTailCalls, "Number of tail calls");

RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
const RISCVSubtarget &STI)
: TargetLowering(TM), Subtarget(STI) {
Expand Down Expand Up @@ -1076,6 +1079,88 @@ SDValue RISCVTargetLowering::LowerFormalArguments(
return Chain;
}

/// IsEligibleForTailCallOptimization - Check whether the call is eligible
/// for tail call optimization.
/// Note: This is modelled after ARM's IsEligibleForTailCallOptimization.
bool RISCVTargetLowering::IsEligibleForTailCallOptimization(
CCState &CCInfo, CallLoweringInfo &CLI, MachineFunction &MF,
const SmallVector<CCValAssign, 16> &ArgLocs) const {

auto &Callee = CLI.Callee;
auto CalleeCC = CLI.CallConv;
auto IsVarArg = CLI.IsVarArg;
auto &Outs = CLI.Outs;
auto &Caller = MF.getFunction();
auto CallerCC = Caller.getCallingConv();

// Do not tail call opt functions with "disable-tail-calls" attribute.
if (Caller.getFnAttribute("disable-tail-calls").getValueAsString() == "true")
return false;

// Exception-handling functions need a special set of instructions to
// indicate a return to the hardware. Tail-calling another function would
// probably break this.
// TODO: The "interrupt" attribute isn't currently defined by RISC-V. This
// should be expanded as new function attributes are introduced.
if (Caller.hasFnAttribute("interrupt"))
return false;

// Do not tail call opt functions with varargs.
if (IsVarArg)
return false;

// Do not tail call opt if the stack is used to pass parameters.
if (CCInfo.getNextStackOffset() != 0)
return false;

// Do not tail call opt if any parameters need to be passed indirectly.
// Since long doubles (fp128) and i128 are larger than 2*XLEN, they are
// passed indirectly. So the address of the value will be passed in a
// register, or if not available, then the address is put on the stack. In
// order to pass indirectly, space on the stack often needs to be allocated
// in order to store the value. In this case the CCInfo.getNextStackOffset()
// != 0 check is not enough and we need to check if any CCValAssign ArgsLocs
// are passed CCValAssign::Indirect.
for (auto &VA : ArgLocs)
if (VA.getLocInfo() == CCValAssign::Indirect)
return false;

// Do not tail call opt if either caller or callee uses struct return
// semantics.
auto IsCallerStructRet = Caller.hasStructRetAttr();
auto IsCalleeStructRet = Outs.empty() ? false : Outs[0].Flags.isSRet();
if (IsCallerStructRet || IsCalleeStructRet)
return false;

// Externally-defined functions with weak linkage should not be
// tail-called. The behaviour of branch instructions in this situation (as
// used for tail calls) is implementation-defined, so we cannot rely on the
// linker replacing the tail call with a return.
if (GlobalAddressSDNode *G = dyn_cast<GlobalAddressSDNode>(Callee)) {
const GlobalValue *GV = G->getGlobal();
if (GV->hasExternalWeakLinkage())
return false;
}

// The callee has to preserve all registers the caller needs to preserve.
const RISCVRegisterInfo *TRI = Subtarget.getRegisterInfo();
const uint32_t *CallerPreserved = TRI->getCallPreservedMask(MF, CallerCC);
if (CalleeCC != CallerCC) {
const uint32_t *CalleePreserved = TRI->getCallPreservedMask(MF, CalleeCC);
if (!TRI->regmaskSubsetEqual(CallerPreserved, CalleePreserved))
return false;
}

// Byval parameters hand the function a pointer directly into the stack area
// we want to reuse during a tail call. Working around this *is* possible
// but less efficient and uglier in LowerCall.
for (auto &Arg : Outs)
if (Arg.Flags.isByVal())
return false;

return true;
}

// Lower a call to a callseq_start + CALL + callseq_end chain, and add input
// and output parameter nodes.
SDValue RISCVTargetLowering::LowerCall(CallLoweringInfo &CLI,
Expand All @@ -1087,7 +1172,7 @@ SDValue RISCVTargetLowering::LowerCall(CallLoweringInfo &CLI,
SmallVectorImpl<ISD::InputArg> &Ins = CLI.Ins;
SDValue Chain = CLI.Chain;
SDValue Callee = CLI.Callee;
CLI.IsTailCall = false;
bool &IsTailCall = CLI.IsTailCall;
CallingConv::ID CallConv = CLI.CallConv;
bool IsVarArg = CLI.IsVarArg;
EVT PtrVT = getPointerTy(DAG.getDataLayout());
Expand All @@ -1100,6 +1185,17 @@ SDValue RISCVTargetLowering::LowerCall(CallLoweringInfo &CLI,
CCState ArgCCInfo(CallConv, IsVarArg, MF, ArgLocs, *DAG.getContext());
analyzeOutputArgs(MF, ArgCCInfo, Outs, /*IsRet=*/false, &CLI);

// Check if it's really possible to do a tail call.
if (IsTailCall)
IsTailCall = IsEligibleForTailCallOptimization(ArgCCInfo, CLI, MF,
ArgLocs);

if (IsTailCall)
++NumTailCalls;
else if (CLI.CS && CLI.CS.isMustTailCall())
report_fatal_error("failed to perform tail call elimination on a call "
"site marked musttail");

// Get a count of how many bytes are to be pushed on the stack.
unsigned NumBytes = ArgCCInfo.getNextStackOffset();

Expand All @@ -1121,12 +1217,13 @@ SDValue RISCVTargetLowering::LowerCall(CallLoweringInfo &CLI,
Chain = DAG.getMemcpy(Chain, DL, FIPtr, Arg, SizeNode, Align,
/*IsVolatile=*/false,
/*AlwaysInline=*/false,
/*isTailCall=*/false, MachinePointerInfo(),
IsTailCall, MachinePointerInfo(),
MachinePointerInfo());
ByValArgs.push_back(FIPtr);
}

Chain = DAG.getCALLSEQ_START(Chain, NumBytes, 0, CLI.DL);
if (!IsTailCall)
Chain = DAG.getCALLSEQ_START(Chain, NumBytes, 0, CLI.DL);

// Copy argument values to their designated locations.
SmallVector<std::pair<unsigned, SDValue>, 8> RegsToPass;
Expand Down Expand Up @@ -1213,6 +1310,8 @@ SDValue RISCVTargetLowering::LowerCall(CallLoweringInfo &CLI,
RegsToPass.push_back(std::make_pair(VA.getLocReg(), ArgValue));
} else {
assert(VA.isMemLoc() && "Argument not register or memory");
assert(!IsTailCall && "Tail call not allowed if stack is used "
"for passing parameters");

// Work out the address of the stack slot.
if (!StackPtr.getNode())
Expand Down Expand Up @@ -1258,18 +1357,26 @@ SDValue RISCVTargetLowering::LowerCall(CallLoweringInfo &CLI,
for (auto &Reg : RegsToPass)
Ops.push_back(DAG.getRegister(Reg.first, Reg.second.getValueType()));

// Add a register mask operand representing the call-preserved registers.
const TargetRegisterInfo *TRI = Subtarget.getRegisterInfo();
const uint32_t *Mask = TRI->getCallPreservedMask(MF, CallConv);
assert(Mask && "Missing call preserved mask for calling convention");
Ops.push_back(DAG.getRegisterMask(Mask));
if (!IsTailCall) {
// Add a register mask operand representing the call-preserved registers.
const TargetRegisterInfo *TRI = Subtarget.getRegisterInfo();
const uint32_t *Mask = TRI->getCallPreservedMask(MF, CallConv);
assert(Mask && "Missing call preserved mask for calling convention");
Ops.push_back(DAG.getRegisterMask(Mask));
}

// Glue the call to the argument copies, if any.
if (Glue.getNode())
Ops.push_back(Glue);

// Emit the call.
SDVTList NodeTys = DAG.getVTList(MVT::Other, MVT::Glue);

if (IsTailCall) {
MF.getFrameInfo().setHasTailCall();
return DAG.getNode(RISCVISD::TAIL, DL, NodeTys, Ops);
}

Chain = DAG.getNode(RISCVISD::CALL, DL, NodeTys, Ops);
Glue = Chain.getValue(1);

Expand Down Expand Up @@ -1425,6 +1532,8 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
return "RISCVISD::BuildPairF64";
case RISCVISD::SplitF64:
return "RISCVISD::SplitF64";
case RISCVISD::TAIL:
return "RISCVISD::TAIL";
}
return nullptr;
}
Expand Down
7 changes: 6 additions & 1 deletion llvm/lib/Target/RISCV/RISCVISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ enum NodeType : unsigned {
CALL,
SELECT_CC,
BuildPairF64,
SplitF64
SplitF64,
TAIL
};
}

Expand Down Expand Up @@ -100,6 +101,10 @@ class RISCVTargetLowering : public TargetLowering {
SDValue lowerVASTART(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerFRAMEADDR(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerRETURNADDR(SDValue Op, SelectionDAG &DAG) const;

bool IsEligibleForTailCallOptimization(CCState &CCInfo,
CallLoweringInfo &CLI, MachineFunction &MF,
const SmallVector<CCValAssign, 16> &ArgLocs) const;
};
}

Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,7 @@ unsigned RISCVInstrInfo::getInstSizeInBytes(const MachineInstr &MI) const {
case TargetOpcode::DBG_VALUE:
return 0;
case RISCV::PseudoCALL:
case RISCV::PseudoTAIL:
return 8;
case TargetOpcode::INLINEASM: {
const MachineFunction &MF = *MI.getParent()->getParent();
Expand Down
14 changes: 13 additions & 1 deletion llvm/lib/Target/RISCV/RISCVInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ def RetFlag : SDNode<"RISCVISD::RET_FLAG", SDTNone,
[SDNPHasChain, SDNPOptInGlue, SDNPVariadic]>;
def SelectCC : SDNode<"RISCVISD::SELECT_CC", SDT_RISCVSelectCC,
[SDNPInGlue]>;
def Tail : SDNode<"RISCVISD::TAIL", SDT_RISCVCall,
[SDNPHasChain, SDNPOptInGlue, SDNPOutGlue,
SDNPVariadic]>;

//===----------------------------------------------------------------------===//
// Operand and SDNode transformation definitions.
Expand Down Expand Up @@ -665,11 +668,20 @@ def PseudoRET : Pseudo<(outs), (ins), [(RetFlag)]>,
// expand to auipc and jalr while encoding.
// Define AsmString to print "tail" when compile with -S flag.
let isCall = 1, isTerminator = 1, isReturn = 1, isBarrier = 1, Uses = [X2],
hasSideEffects = 0, mayLoad = 0, mayStore = 0, isCodeGenOnly = 0 in
isCodeGenOnly = 0 in
def PseudoTAIL : Pseudo<(outs), (ins bare_symbol:$dst), []> {
let AsmString = "tail\t$dst";
}

let isCall = 1, isTerminator = 1, isReturn = 1, isBarrier = 1, Uses = [X2] in
def PseudoTAILIndirect : Pseudo<(outs), (ins GPRTC:$rs1), [(Tail GPRTC:$rs1)]>,
PseudoInstExpansion<(JALR X0, GPR:$rs1, 0)>;

def : Pat<(Tail (iPTR tglobaladdr:$dst)),
(PseudoTAIL texternalsym:$dst)>;
def : Pat<(Tail (iPTR texternalsym:$dst)),
(PseudoTAIL texternalsym:$dst)>;

/// Loads

multiclass LdPat<PatFrag LoadOp, RVInst Inst> {
Expand Down
13 changes: 13 additions & 0 deletions llvm/lib/Target/RISCV/RISCVRegisterInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,19 @@ def GPRC : RegisterClass<"RISCV", [XLenVT], 32, (add
[RegInfo<32,32,32>, RegInfo<64,64,64>, RegInfo<32,32,32>]>;
}

// For indirect tail calls, we can't use callee-saved registers, as they are
// restored to the saved value before the tail call, which would clobber a call
// address.
def GPRTC : RegisterClass<"RISCV", [XLenVT], 32, (add
(sequence "X%u", 5, 7),
(sequence "X%u", 10, 17),
(sequence "X%u", 28, 31)
)> {
let RegInfos = RegInfoByHwMode<
[RV32, RV64, DefaultMode],
[RegInfo<32,32,32>, RegInfo<64,64,64>, RegInfo<32,32,32>]>;
}

def SP : RegisterClass<"RISCV", [XLenVT], 32, (add X2)> {
let RegInfos = RegInfoByHwMode<
[RV32, RV64, DefaultMode],
Expand Down
56 changes: 56 additions & 0 deletions llvm/test/CodeGen/RISCV/disable-tail-calls.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
; Check that command line option "-disable-tail-calls" overrides function
; attribute "disable-tail-calls".

; RUN: llc < %s -mtriple=riscv32-unknown-elf \
; RUN: | FileCheck %s --check-prefixes=CALLER1,NOTAIL
; RUN: llc < %s -mtriple=riscv32-unknown-elf -disable-tail-calls \
; RUN: | FileCheck %s --check-prefixes=CALLER1,NOTAIL
; RUN: llc < %s -mtriple=riscv32-unknown-elf -disable-tail-calls=false \
; RUN: | FileCheck %s --check-prefixes=CALLER1,TAIL

; RUN: llc < %s -mtriple=riscv32-unknown-elf \
; RUN: | FileCheck %s --check-prefixes=CALLER2,TAIL
; RUN: llc < %s -mtriple=riscv32-unknown-elf -disable-tail-calls \
; RUN: | FileCheck %s --check-prefixes=CALLER2,NOTAIL
; RUN: llc < %s -mtriple=riscv32-unknown-elf -disable-tail-calls=false \
; RUN: | FileCheck %s --check-prefixes=CALLER2,TAIL

; RUN: llc < %s -mtriple=riscv32-unknown-elf \
; RUN: | FileCheck %s --check-prefixes=CALLER3,TAIL
; RUN: llc < %s -mtriple=riscv32-unknown-elf -disable-tail-calls \
; RUN: | FileCheck %s --check-prefixes=CALLER3,NOTAIL
; RUN: llc < %s -mtriple=riscv32-unknown-elf -disable-tail-calls=false \
; RUN: | FileCheck %s --check-prefixes=CALLER3,TAIL

; CALLER1-LABEL: {{\_?}}caller1
; CALLER2-LABEL: {{\_?}}caller2
; CALLER3-LABEL: {{\_?}}caller3
; NOTAIL-NOT: tail callee
; NOTAIL: call callee
; TAIL: tail callee
; TAIL-NOT: call callee

; Function with attribute #0 = { "disable-tail-calls"="true" }
define i32 @caller1(i32 %a) #0 {
entry:
%call = tail call i32 @callee(i32 %a)
ret i32 %call
}

; Function with attribute #1 = { "disable-tail-calls"="false" }
define i32 @caller2(i32 %a) #0 {
entry:
%call = tail call i32 @callee(i32 %a)
ret i32 %call
}

define i32 @caller3(i32 %a) {
entry:
%call = tail call i32 @callee(i32 %a)
ret i32 %call
}

declare i32 @callee(i32)

attributes #0 = { "disable-tail-calls"="true" }
attributes #1 = { "disable-tail-calls"="false" }
20 changes: 20 additions & 0 deletions llvm/test/CodeGen/RISCV/musttail-call.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
; Check that we error out if tail is not possible but call is marked as mustail.

; RUN: not llc -mtriple riscv32-unknown-linux-gnu -o - %s \
; RUN: 2>&1 | FileCheck %s
; RUN: not llc -mtriple riscv32-unknown-elf -o - %s \
; RUN: 2>&1 | FileCheck %s
; RUN: not llc -mtriple riscv64-unknown-linux-gnu -o - %s \
; RUN: 2>&1 | FileCheck %s
; RUN: not llc -mtriple riscv64-unknown-elf -o - %s \
; RUN: 2>&1 | FileCheck %s

%struct.A = type { i32 }

declare void @callee_musttail(%struct.A* sret %a)
define void @caller_musttail(%struct.A* sret %a) {
; CHECK: LLVM ERROR: failed to perform tail call elimination on a call site marked musttail
entry:
musttail call void @callee_musttail(%struct.A* sret %a)
ret void
}
Loading

0 comments on commit ddcb956

Please sign in to comment.