Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2135,6 +2135,21 @@ bool NVPTXDAGToDAGISel::tryStoreRetval(SDNode *N) {
NVPTX::StoreRetvalI8, NVPTX::StoreRetvalI16,
NVPTX::StoreRetvalI32, NVPTX::StoreRetvalI64,
NVPTX::StoreRetvalF32, NVPTX::StoreRetvalF64);
if (Opcode == NVPTX::StoreRetvalI8) {
// Fine tune the opcode depending on the size of the operand.
// This helps to avoid creating redundant COPY instructions in
// InstrEmitter::AddRegisterOperand().
switch (Ops[0].getSimpleValueType().SimpleTy) {
default:
break;
case MVT::i32:
Opcode = NVPTX::StoreRetvalI8TruncI32;
break;
case MVT::i64:
Opcode = NVPTX::StoreRetvalI8TruncI64;
break;
}
}
break;
case 2:
Opcode = pickOpcodeForVT(Mem->getMemoryVT().getSimpleVT().SimpleTy,
Expand Down Expand Up @@ -2211,6 +2226,21 @@ bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) {
NVPTX::StoreParamI8, NVPTX::StoreParamI16,
NVPTX::StoreParamI32, NVPTX::StoreParamI64,
NVPTX::StoreParamF32, NVPTX::StoreParamF64);
if (Opcode == NVPTX::StoreParamI8) {
// Fine tune the opcode depending on the size of the operand.
// This helps to avoid creating redundant COPY instructions in
// InstrEmitter::AddRegisterOperand().
switch (Ops[0].getSimpleValueType().SimpleTy) {
default:
break;
case MVT::i32:
Opcode = NVPTX::StoreParamI8TruncI32;
break;
case MVT::i64:
Opcode = NVPTX::StoreParamI8TruncI64;
break;
}
}
break;
case 2:
Opcode = pickOpcodeForVT(Mem->getMemoryVT().getSimpleVT().SimpleTy,
Expand Down
257 changes: 237 additions & 20 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/Alignment.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CodeGen.h"
#include "llvm/Support/CommandLine.h"
Expand All @@ -59,6 +60,7 @@
#include <cmath>
#include <cstdint>
#include <iterator>
#include <optional>
#include <sstream>
#include <string>
#include <utility>
Expand Down Expand Up @@ -1529,6 +1531,105 @@ Align NVPTXTargetLowering::getArgumentAlignment(const CallBase *CB, Type *Ty,
return DL.getABITypeAlign(Ty);
}

// Remap a floating-point (or packed-FP) element type to the integer type of
// the same bit width, so that the byte-wise shift/mask logic used by the
// unaligned load/store lowering helpers can operate on it.
//
// Returns true and updates \p ElementType when a remap happened; otherwise
// leaves the type untouched and returns false.
static bool adjustElementType(EVT &ElementType) {
  const MVT::SimpleValueType SimpleTy = ElementType.getSimpleVT().SimpleTy;
  if (SimpleTy == MVT::f16 || SimpleTy == MVT::bf16) {
    ElementType = MVT::i16;
    return true;
  }
  if (SimpleTy == MVT::f32 || SimpleTy == MVT::v2f16 ||
      SimpleTy == MVT::v2bf16) {
    ElementType = MVT::i32;
    return true;
  }
  if (SimpleTy == MVT::f64) {
    ElementType = MVT::i64;
    return true;
  }
  return false;
}

// Use byte-stores when the param address of the argument value is unaligned.
// This may happen when the argument value is a field of a packed structure.
//
// This is called in LowerCall() when passing the param values.
//
// Emits one st.param.b8 per byte of \p StVal at byte offsets
// Offset .. Offset + size - 1 of argument \p ArgID, threading the chain and
// glue (\p InGlue is updated in place) so the stores stay ordered relative to
// the surrounding call sequence. Returns the new chain.
static SDValue LowerUnalignedStoreParam(SelectionDAG &DAG, SDValue Chain,
                                        uint64_t Offset, EVT ElementType,
                                        SDValue StVal, SDValue &InGlue,
                                        unsigned ArgID, const SDLoc &dl) {
  // Bit logic only works on integer types; reinterpret FP values as integers
  // of the same width first.
  if (adjustElementType(ElementType))
    StVal = DAG.getNode(ISD::BITCAST, dl, ElementType, StVal);

  // Store each byte.
  SDVTList StoreVTs = DAG.getVTList(MVT::Other, MVT::Glue);
  for (unsigned i = 0, n = ElementType.getSizeInBits() / 8; i < n; i++) {
    // Shift byte i down into the least-significant byte position.
    SDValue ShiftVal = DAG.getNode(ISD::SRL, dl, ElementType, StVal,
                                   DAG.getConstant(i * 8, dl, MVT::i32));
    SDValue StoreOperands[] = {Chain, DAG.getConstant(ArgID, dl, MVT::i32),
                               DAG.getConstant(Offset + i, dl, MVT::i32),
                               ShiftVal, InGlue};
    // Truncating store of only the low byte via
    //     st.param.b8
    // The register type can be larger than b8.
    Chain = DAG.getMemIntrinsicNode(
        NVPTXISD::StoreParam, dl, StoreVTs, StoreOperands, MVT::i8,
        MachinePointerInfo(), Align(1), MachineMemOperand::MOStore);
    InGlue = Chain.getValue(1);
  }
  return Chain;
}

// Use byte-loads when the param address of the returned value is unaligned.
// This may happen when the returned value is a field of a packed structure.
//
// Emits one ld.param.b8 per byte at offsets Offset .. Offset + size - 1 of
// return slot 1 and reassembles the full value with zext/and/shl/or. Each raw
// load is wrapped in a placeholder ProxyReg node that is appended to
// \p TempProxyRegOps; the caller is expected to replace those placeholders
// with properly glued ProxyReg nodes after CALLSEQ_END is emitted. Chain and
// glue are updated in place. Returns the reassembled value.
static SDValue
LowerUnalignedLoadRetParam(SelectionDAG &DAG, SDValue &Chain, uint64_t Offset,
                           EVT ElementType, SDValue &InGlue,
                           SmallVectorImpl<SDValue> &TempProxyRegOps,
                           const SDLoc &dl) {
  // Bit logic only works on integer types; reassemble in the same-width
  // integer type and bitcast back at the end if needed.
  EVT MergedType = ElementType;
  adjustElementType(MergedType);

  // Load each byte and construct the whole value. Initial value is 0.
  SDValue RetVal = DAG.getConstant(0, dl, MergedType);
  // LoadParamMemI8 loads into an i16 register only.
  SDVTList LoadVTs = DAG.getVTList(MVT::i16, MVT::Other, MVT::Glue);
  for (unsigned i = 0, n = ElementType.getSizeInBits() / 8; i < n; i++) {
    SDValue LoadOperands[] = {Chain, DAG.getConstant(1, dl, MVT::i32),
                              DAG.getConstant(Offset + i, dl, MVT::i32),
                              InGlue};
    // This will be selected to LoadParamMemI8.
    SDValue LdVal =
        DAG.getMemIntrinsicNode(NVPTXISD::LoadParam, dl, LoadVTs, LoadOperands,
                                MVT::i8, MachinePointerInfo(), Align(1));
    SDValue TmpLdVal = LdVal.getValue(0);
    Chain = LdVal.getValue(1);
    InGlue = LdVal.getValue(2);

    // Placeholder ProxyReg; fixed up by the caller once CALLSEQ_END exists.
    TmpLdVal = DAG.getNode(NVPTXISD::ProxyReg, dl,
                           TmpLdVal.getSimpleValueType(), TmpLdVal);
    TempProxyRegOps.push_back(TmpLdVal);

    SDValue CMask = DAG.getConstant(255, dl, MergedType);
    SDValue CShift = DAG.getConstant(i * 8, dl, MVT::i32);
    // Need to extend the i16 register to the whole width.
    TmpLdVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MergedType, TmpLdVal);
    // Mask off the high bits. Leave only the lower 8 bits,
    // because we are using loadparam.b8.
    TmpLdVal = DAG.getNode(ISD::AND, dl, MergedType, TmpLdVal, CMask);
    // Shift the byte into position and merge it into the result.
    TmpLdVal = DAG.getNode(ISD::SHL, dl, MergedType, TmpLdVal, CShift);
    RetVal = DAG.getNode(ISD::OR, dl, MergedType, RetVal, TmpLdVal);
  }
  if (ElementType != MergedType)
    RetVal = DAG.getNode(ISD::BITCAST, dl, ElementType, RetVal);

  return RetVal;
}

SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
SmallVectorImpl<SDValue> &InVals) const {

Expand Down Expand Up @@ -1680,17 +1781,6 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
if (NeedAlign)
PartAlign = commonAlignment(ArgAlign, CurOffset);

// New store.
if (VectorInfo[j] & PVF_FIRST) {
assert(StoreOperands.empty() && "Unfinished preceding store.");
StoreOperands.push_back(Chain);
StoreOperands.push_back(
DAG.getConstant(IsVAArg ? FirstVAArg : ParamCount, dl, MVT::i32));
StoreOperands.push_back(DAG.getConstant(
IsByVal ? CurOffset + VAOffset : (IsVAArg ? VAOffset : CurOffset),
dl, MVT::i32));
}

SDValue StVal = OutVals[OIdx];

MVT PromotedVT;
Expand Down Expand Up @@ -1723,6 +1813,35 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
StVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, StVal);
}

// If we have a PVF_SCALAR entry, it may not be sufficiently aligned for a
// scalar store. In such cases, fall back to byte stores.
if (VectorInfo[j] == PVF_SCALAR && !IsVAArg && PartAlign.has_value() &&
PartAlign.value() <
DL.getABITypeAlign(EltVT.getTypeForEVT(*DAG.getContext()))) {
assert(StoreOperands.empty() && "Unfinished preceeding store.");
Chain = LowerUnalignedStoreParam(
DAG, Chain, IsByVal ? CurOffset + VAOffset : CurOffset, EltVT,
StVal, InGlue, ParamCount, dl);

// LowerUnalignedStoreParam took care of inserting the necessary nodes
// into the SDAG, so just move on to the next element.
if (!IsByVal)
++OIdx;
continue;
}

// New store.
if (VectorInfo[j] & PVF_FIRST) {
assert(StoreOperands.empty() && "Unfinished preceding store.");
StoreOperands.push_back(Chain);
StoreOperands.push_back(
DAG.getConstant(IsVAArg ? FirstVAArg : ParamCount, dl, MVT::i32));

StoreOperands.push_back(DAG.getConstant(
IsByVal ? CurOffset + VAOffset : (IsVAArg ? VAOffset : CurOffset),
dl, MVT::i32));
}

// Record the value to store.
StoreOperands.push_back(StVal);

Expand Down Expand Up @@ -1923,6 +2042,14 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,

SmallVector<SDValue, 16> ProxyRegOps;
SmallVector<std::optional<MVT>, 16> ProxyRegTruncates;
// An item of the vector is filled if the element does not need a ProxyReg
// operation on it and should be added to InVals as is. ProxyRegOps and
// ProxyRegTruncates contain empty/none items at the same index.
SmallVector<SDValue, 16> RetElts;
// Temporary ProxyReg operations inserted in `LowerUnalignedLoadRetParam()`
// to consume the values of the `LoadParam`s; they are replaced later, when
// `CALLSEQ_END` is added.
SmallVector<SDValue, 16> TempProxyRegOps;

// Generate loads from param memory/moves from registers for result
if (Ins.size() > 0) {
Expand Down Expand Up @@ -1966,6 +2093,22 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
EltType = MVT::i16;
}

// If we have a PVF_SCALAR entry, it may not be sufficiently aligned for a
// scalar load. In such cases, fall back to byte loads.
if (VectorInfo[i] == PVF_SCALAR && RetTy->isAggregateType() &&
EltAlign < DL.getABITypeAlign(
TheLoadType.getTypeForEVT(*DAG.getContext()))) {
assert(VecIdx == -1 && LoadVTs.empty() && "Orphaned operand list.");
SDValue Ret = LowerUnalignedLoadRetParam(
DAG, Chain, Offsets[i], TheLoadType, InGlue, TempProxyRegOps, dl);
ProxyRegOps.push_back(SDValue());
ProxyRegTruncates.push_back(std::optional<MVT>());
RetElts.resize(i);
RetElts.push_back(Ret);

continue;
}

// Record index of the very first element of the vector.
if (VectorInfo[i] & PVF_FIRST) {
assert(VecIdx == -1 && LoadVTs.empty() && "Orphaned operand list.");
Expand Down Expand Up @@ -2028,6 +2171,11 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
// will not get lost. Otherwise, during libcalls expansion, the nodes can become
// dangling.
for (unsigned i = 0; i < ProxyRegOps.size(); ++i) {
if (i < RetElts.size() && RetElts[i]) {
InVals.push_back(RetElts[i]);
continue;
}

SDValue Ret = DAG.getNode(
NVPTXISD::ProxyReg, dl,
DAG.getVTList(ProxyRegOps[i].getSimpleValueType(), MVT::Other, MVT::Glue),
Expand All @@ -2044,6 +2192,18 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
InVals.push_back(Ret);
}

for (SDValue &T : TempProxyRegOps) {
SDValue Repl = DAG.getNode(
NVPTXISD::ProxyReg, dl,
DAG.getVTList(T.getSimpleValueType(), MVT::Other, MVT::Glue),
{Chain, T.getOperand(0), InGlue});
DAG.ReplaceAllUsesWith(T, Repl);
DAG.RemoveDeadNode(T.getNode());

Chain = Repl.getValue(1);
InGlue = Repl.getValue(2);
}

// set isTailCall to false for now, until we figure out how to express
// tail call optimization in PTX
isTailCall = false;
Expand Down Expand Up @@ -3045,9 +3205,20 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
DAG.getConstant(Offsets[VecIdx], dl, PtrVT));
Value *srcValue = Constant::getNullValue(PointerType::get(
EltVT.getTypeForEVT(F->getContext()), ADDRESS_SPACE_PARAM));

const MaybeAlign PartAlign = [&]() -> MaybeAlign {
if (aggregateIsPacked)
return Align(1);
if (NumElts != 1)
return std::nullopt;
Align PartAlign =
(Offsets[parti] == 0 && PAL.getParamAlignment(i))
? PAL.getParamAlignment(i).value()
: DL.getABITypeAlign(EltVT.getTypeForEVT(F->getContext()));
return commonAlignment(PartAlign, Offsets[parti]);
}();
SDValue P = DAG.getLoad(VecVT, dl, Root, VecAddr,
MachinePointerInfo(srcValue),
MaybeAlign(aggregateIsPacked ? 1 : 0),
MachinePointerInfo(srcValue), PartAlign,
MachineMemOperand::MODereferenceable |
MachineMemOperand::MOInvariant);
if (P.getNode())
Expand Down Expand Up @@ -3113,6 +3284,33 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
return Chain;
}

// Use byte-stores when the param address of the return value is unaligned.
// This may happen when the return value is a field of a packed structure.
//
// Emits one st.param.b8 (NVPTXISD::StoreRetval) per byte of \p RetVal at byte
// offsets Offset .. Offset + size - 1 of the return slot. Unlike the
// StoreParam variant, retval stores carry no glue. Returns the new chain.
static SDValue LowerUnalignedStoreRet(SelectionDAG &DAG, SDValue Chain,
                                      uint64_t Offset, EVT ElementType,
                                      SDValue RetVal, const SDLoc &dl) {
  // Bit logic only works on integer types; reinterpret FP values as integers
  // of the same width first.
  if (adjustElementType(ElementType))
    RetVal = DAG.getNode(ISD::BITCAST, dl, ElementType, RetVal);

  // Store each byte.
  for (unsigned i = 0, n = ElementType.getSizeInBits() / 8; i < n; i++) {
    // Shift byte i down into the least-significant byte position.
    SDValue ShiftVal = DAG.getNode(ISD::SRL, dl, ElementType, RetVal,
                                   DAG.getConstant(i * 8, dl, MVT::i32));
    SDValue StoreOperands[] = {Chain, DAG.getConstant(Offset + i, dl, MVT::i32),
                               ShiftVal};
    // Truncating store of only the low byte via
    //     st.param.b8
    // The register type can be larger than b8.
    Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreRetval, dl,
                                    DAG.getVTList(MVT::Other), StoreOperands,
                                    MVT::i8, MachinePointerInfo(), std::nullopt,
                                    MachineMemOperand::MOStore);
  }
  return Chain;
}

SDValue
NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
bool isVarArg,
Expand Down Expand Up @@ -3162,13 +3360,6 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,

SmallVector<SDValue, 6> StoreOperands;
for (unsigned i = 0, e = VTs.size(); i != e; ++i) {
// New load/store. Record chain and offset operands.
if (VectorInfo[i] & PVF_FIRST) {
assert(StoreOperands.empty() && "Orphaned operand list.");
StoreOperands.push_back(Chain);
StoreOperands.push_back(DAG.getConstant(Offsets[i], dl, MVT::i32));
}

SDValue OutVal = OutVals[i];
SDValue RetVal = PromotedOutVals[i];

Expand All @@ -3182,6 +3373,32 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
RetVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, RetVal);
}

// If we have a PVF_SCALAR entry, it may not even be sufficiently aligned
// for a scalar store. In such cases, fall back to byte stores.
if (VectorInfo[i] == PVF_SCALAR && RetTy->isAggregateType()) {
EVT ElementType = ExtendIntegerRetVal ? MVT::i32 : VTs[i];
Align ElementTypeAlign =
DL.getABITypeAlign(ElementType.getTypeForEVT(RetTy->getContext()));
Align ElementAlign =
commonAlignment(DL.getABITypeAlign(RetTy), Offsets[i]);
if (ElementAlign < ElementTypeAlign) {
assert(StoreOperands.empty() && "Orphaned operand list.");
Chain = LowerUnalignedStoreRet(DAG, Chain, Offsets[i], ElementType,
RetVal, dl);

// The call to LowerUnalignedStoreRet inserted the necessary SDAG nodes
// into the graph, so just move on to the next element.
continue;
}
}

// New load/store. Record chain and offset operands.
if (VectorInfo[i] & PVF_FIRST) {
assert(StoreOperands.empty() && "Orphaned operand list.");
StoreOperands.push_back(Chain);
StoreOperands.push_back(DAG.getConstant(Offsets[i], dl, MVT::i32));
}

// Record the value to return.
StoreOperands.push_back(RetVal);

Expand Down
Loading