Skip to content

Commit

Permalink
[SelectionDAG] Unify scalarizeVectorLoad and VectorLegalizer::ExpandLoad
Browse files Browse the repository at this point in the history
The two code paths have the same goal, legalizing a load of a non-byte-sized vector by loading the "flattened" representation in memory, slicing off each single element and then building a vector out of those pieces.

The technique employed by `ExpandLoad`  is slightly more convoluted and produces slightly better codegen on ARM, AMDGPU and x86 but suffers from some bugs (D78480) and is wrong for BE machines.

Differential Revision: https://reviews.llvm.org/D79096
  • Loading branch information
LemonBoy authored and topperc committed May 2, 2020
1 parent 3542384 commit 6d103ca
Show file tree
Hide file tree
Showing 12 changed files with 972 additions and 1,031 deletions.
126 changes: 1 addition & 125 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
Expand Up @@ -702,131 +702,7 @@ void VectorLegalizer::PromoteFP_TO_INT(SDNode *Node,

std::pair<SDValue, SDValue> VectorLegalizer::ExpandLoad(SDNode *N) {
LoadSDNode *LD = cast<LoadSDNode>(N);

EVT SrcVT = LD->getMemoryVT();
EVT SrcEltVT = SrcVT.getScalarType();
unsigned NumElem = SrcVT.getVectorNumElements();

SDValue NewChain;
SDValue Value;
if (SrcVT.getVectorNumElements() > 1 && !SrcEltVT.isByteSized()) {
SDLoc dl(N);

SmallVector<SDValue, 8> Vals;
SmallVector<SDValue, 8> LoadChains;

EVT DstEltVT = LD->getValueType(0).getScalarType();
SDValue Chain = LD->getChain();
SDValue BasePTR = LD->getBasePtr();
ISD::LoadExtType ExtType = LD->getExtensionType();

// When elements in a vector is not byte-addressable, we cannot directly
// load each element by advancing pointer, which could only address bytes.
// Instead, we load all significant words, mask bits off, and concatenate
// them to form each element. Finally, they are extended to destination
// scalar type to build the destination vector.
EVT WideVT = TLI.getPointerTy(DAG.getDataLayout());

assert(WideVT.isRound() &&
"Could not handle the sophisticated case when the widest integer is"
" not power of 2.");
assert(WideVT.bitsGE(SrcEltVT) &&
"Type is not legalized?");

unsigned WideBytes = WideVT.getStoreSize();
unsigned Offset = 0;
unsigned RemainingBytes = SrcVT.getStoreSize();
SmallVector<SDValue, 8> LoadVals;
while (RemainingBytes > 0) {
SDValue ScalarLoad;
unsigned LoadBytes = WideBytes;

if (RemainingBytes >= LoadBytes) {
ScalarLoad = DAG.getLoad(
WideVT, dl, Chain, BasePTR,
LD->getPointerInfo().getWithOffset(Offset), LD->getOriginalAlign(),
LD->getMemOperand()->getFlags(), LD->getAAInfo());
} else {
EVT LoadVT = WideVT;
while (RemainingBytes < LoadBytes) {
LoadBytes >>= 1; // Reduce the load size by half.
LoadVT = EVT::getIntegerVT(*DAG.getContext(), LoadBytes << 3);
}
ScalarLoad =
DAG.getExtLoad(ISD::EXTLOAD, dl, WideVT, Chain, BasePTR,
LD->getPointerInfo().getWithOffset(Offset), LoadVT,
LD->getOriginalAlign(),
LD->getMemOperand()->getFlags(), LD->getAAInfo());
}

RemainingBytes -= LoadBytes;
Offset += LoadBytes;

BasePTR = DAG.getObjectPtrOffset(dl, BasePTR, LoadBytes);

LoadVals.push_back(ScalarLoad.getValue(0));
LoadChains.push_back(ScalarLoad.getValue(1));
}

unsigned BitOffset = 0;
unsigned WideIdx = 0;
unsigned WideBits = WideVT.getSizeInBits();

// Extract bits, pack and extend/trunc them into destination type.
unsigned SrcEltBits = SrcEltVT.getSizeInBits();
SDValue SrcEltBitMask = DAG.getConstant(
APInt::getLowBitsSet(WideBits, SrcEltBits), dl, WideVT);

for (unsigned Idx = 0; Idx != NumElem; ++Idx) {
assert(BitOffset < WideBits && "Unexpected offset!");

SDValue ShAmt = DAG.getConstant(
BitOffset, dl, TLI.getShiftAmountTy(WideVT, DAG.getDataLayout()));
SDValue Lo = DAG.getNode(ISD::SRL, dl, WideVT, LoadVals[WideIdx], ShAmt);

BitOffset += SrcEltBits;
if (BitOffset >= WideBits) {
WideIdx++;
BitOffset -= WideBits;
if (BitOffset > 0) {
ShAmt = DAG.getConstant(
SrcEltBits - BitOffset, dl,
TLI.getShiftAmountTy(WideVT, DAG.getDataLayout()));
SDValue Hi =
DAG.getNode(ISD::SHL, dl, WideVT, LoadVals[WideIdx], ShAmt);
Lo = DAG.getNode(ISD::OR, dl, WideVT, Lo, Hi);
}
}

Lo = DAG.getNode(ISD::AND, dl, WideVT, Lo, SrcEltBitMask);

switch (ExtType) {
default: llvm_unreachable("Unknown extended-load op!");
case ISD::EXTLOAD:
Lo = DAG.getAnyExtOrTrunc(Lo, dl, DstEltVT);
break;
case ISD::ZEXTLOAD:
Lo = DAG.getZExtOrTrunc(Lo, dl, DstEltVT);
break;
case ISD::SEXTLOAD:
ShAmt =
DAG.getConstant(WideBits - SrcEltBits, dl,
TLI.getShiftAmountTy(WideVT, DAG.getDataLayout()));
Lo = DAG.getNode(ISD::SHL, dl, WideVT, Lo, ShAmt);
Lo = DAG.getNode(ISD::SRA, dl, WideVT, Lo, ShAmt);
Lo = DAG.getSExtOrTrunc(Lo, dl, DstEltVT);
break;
}
Vals.push_back(Lo);
}

NewChain = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, LoadChains);
Value = DAG.getBuildVector(N->getValueType(0), dl, Vals);
} else {
std::tie(Value, NewChain) = TLI.scalarizeVectorLoad(LD, DAG);
}

return std::make_pair(Value, NewChain);
return TLI.scalarizeVectorLoad(LD, DAG);
}

SDValue VectorLegalizer::ExpandStore(SDNode *N) {
Expand Down
33 changes: 23 additions & 10 deletions llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
Expand Up @@ -6620,27 +6620,40 @@ TargetLowering::scalarizeVectorLoad(LoadSDNode *LD,
// elements that are byte-sized must therefore be stored as an integer
// built out of the extracted vector elements.
if (!SrcEltVT.isByteSized()) {
unsigned NumBits = SrcVT.getSizeInBits();
EVT IntVT = EVT::getIntegerVT(*DAG.getContext(), NumBits);
unsigned NumLoadBits = SrcVT.getStoreSizeInBits();
EVT LoadVT = EVT::getIntegerVT(*DAG.getContext(), NumLoadBits);

unsigned NumSrcBits = SrcVT.getSizeInBits();
EVT SrcIntVT = EVT::getIntegerVT(*DAG.getContext(), NumSrcBits);

SDValue Load = DAG.getLoad(IntVT, SL, Chain, BasePTR, LD->getPointerInfo(),
LD->getAlignment(),
LD->getMemOperand()->getFlags(),
LD->getAAInfo());
unsigned SrcEltBits = SrcEltVT.getSizeInBits();
SDValue SrcEltBitMask = DAG.getConstant(
APInt::getLowBitsSet(NumLoadBits, SrcEltBits), SL, LoadVT);

// Load the whole vector and avoid masking off the top bits as it makes
// the codegen worse.
SDValue Load =
DAG.getExtLoad(ISD::EXTLOAD, SL, LoadVT, Chain, BasePTR,
LD->getPointerInfo(), SrcIntVT, LD->getAlignment(),
LD->getMemOperand()->getFlags(), LD->getAAInfo());

SmallVector<SDValue, 8> Vals;
for (unsigned Idx = 0; Idx < NumElem; ++Idx) {
unsigned ShiftIntoIdx =
(DAG.getDataLayout().isBigEndian() ? (NumElem - 1) - Idx : Idx);
SDValue ShiftAmount =
DAG.getConstant(ShiftIntoIdx * SrcEltVT.getSizeInBits(), SL, IntVT);
SDValue ShiftedElt =
DAG.getNode(ISD::SRL, SL, IntVT, Load, ShiftAmount);
SDValue Scalar = DAG.getNode(ISD::TRUNCATE, SL, SrcEltVT, ShiftedElt);
DAG.getConstant(ShiftIntoIdx * SrcEltVT.getSizeInBits(), SL,
getShiftAmountTy(LoadVT, DAG.getDataLayout()));
SDValue ShiftedElt = DAG.getNode(ISD::SRL, SL, LoadVT, Load, ShiftAmount);
SDValue Elt =
DAG.getNode(ISD::AND, SL, LoadVT, ShiftedElt, SrcEltBitMask);
SDValue Scalar = DAG.getNode(ISD::TRUNCATE, SL, SrcEltVT, Elt);

if (ExtType != ISD::NON_EXTLOAD) {
unsigned ExtendOp = ISD::getExtForLoadExtType(false, ExtType);
Scalar = DAG.getNode(ExtendOp, SL, DstEltVT, Scalar);
}

Vals.push_back(Scalar);
}

Expand Down

0 comments on commit 6d103ca

Please sign in to comment.