[DAGCombine] DAGTypeLegalizer::GenWidenVectorLoads(): make use of dereferenceability knowledge

jayfoad committed Jul 29, 2021
1 parent 00b0f1e commit ed4a53d
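
The one-line summary is terse, so briefly: widening a vector load (for example, legalizing a 3-element load through a legal 4-element type) reads bytes past the end of the original access, which is only safe if those extra bytes are known to be dereferenceable. Previously only the load's alignment served as that evidence; this patch also consults MachinePointerInfo::isDereferenceable(). The standalone C++ sketch below mirrors how the patch derives the dereferenceable byte count for a simple, fixed-width load; the function and parameter names are hypothetical, not LLVM API.

#include <algorithm>

// Illustrative sketch only; not the LLVM implementation.
unsigned deriveDereferenceableBytes(unsigned LoadBytes,  // bytes the load reads
                                    unsigned AlignBytes, // load alignment in bytes
                                    unsigned WidenBytes, // bytes of the widened type
                                    bool PtrDerefForWidenBytes) {
  // An aligned access can be widened up to its alignment without faulting
  // (the pre-existing heuristic), and the load proves its own bytes.
  unsigned Bytes = std::max(AlignBytes, LoadBytes);
  // New in this patch: if the pointer is known dereferenceable for the whole
  // widened width, allow widening all the way.
  if (PtrDerefForWidenBytes)
    Bytes = std::max(Bytes, WidenBytes);
  return Bytes;
}

FindMemType then treats this byte count the way it previously treated the alignment hint: a wider memory type is acceptable only while its width fits within the dereferenceable bytes and within Width + WidenEx.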
Showing 3 changed files with 151 additions and 143 deletions.
63 changes: 47 additions & 16 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -5144,17 +5144,19 @@ SDValue DAGTypeLegalizer::WidenVecOp_VSELECT(SDNode *N) {
 // TLI: Target lowering used to determine legal types.
 // Width: Width left need to load/store.
 // WidenVT: The widen vector type to load to/store from
-// Align: If 0, don't allow use of a wider type
-// WidenEx: If Align is not 0, the amount additional we can load/store from.
+// NumDereferenceableBytes: If 0, don't allow use of a wider type
+// WidenEx: If NumDereferenceableBytes is not 0,
+//          the additional amount we have to load/store.
 
-static EVT FindMemType(SelectionDAG& DAG, const TargetLowering &TLI,
+static EVT FindMemType(SelectionDAG &DAG, const TargetLowering &TLI,
                        unsigned Width, EVT WidenVT,
-                       unsigned Align = 0, unsigned WidenEx = 0) {
+                       unsigned NumDereferenceableBytes = 0,
+                       unsigned WidenEx = 0) {
   EVT WidenEltVT = WidenVT.getVectorElementType();
   const bool Scalable = WidenVT.isScalableVector();
   unsigned WidenWidth = WidenVT.getSizeInBits().getKnownMinSize();
   unsigned WidenEltWidth = WidenEltVT.getSizeInBits();
-  unsigned AlignInBits = Align*8;
+  unsigned NumDereferenceableBits = NumDereferenceableBytes * 8;
 
   // If we have one element to load/store, return it.
   EVT RetVT = WidenEltVT;
@@ -5174,8 +5176,9 @@ static EVT FindMemType(SelectionDAG& DAG, const TargetLowering &TLI,
          Action == TargetLowering::TypePromoteInteger) &&
         (WidenWidth % MemVTWidth) == 0 &&
         isPowerOf2_32(WidenWidth / MemVTWidth) &&
-        (MemVTWidth <= Width ||
-         (Align!=0 && MemVTWidth<=AlignInBits && MemVTWidth<=Width+WidenEx))) {
+        (MemVTWidth <= Width || (NumDereferenceableBytes != 0 &&
+                                 MemVTWidth <= NumDereferenceableBits &&
+                                 MemVTWidth <= Width + WidenEx))) {
       if (MemVTWidth == WidenWidth)
         return MemVT;
       RetVT = MemVT;
@@ -5197,8 +5200,9 @@ static EVT FindMemType(SelectionDAG& DAG, const TargetLowering &TLI,
         WidenEltVT == MemVT.getVectorElementType() &&
         (WidenWidth % MemVTWidth) == 0 &&
         isPowerOf2_32(WidenWidth / MemVTWidth) &&
-        (MemVTWidth <= Width ||
-         (Align!=0 && MemVTWidth<=AlignInBits && MemVTWidth<=Width+WidenEx))) {
+        (MemVTWidth <= Width || (NumDereferenceableBytes != 0 &&
+                                 MemVTWidth <= NumDereferenceableBits &&
+                                 MemVTWidth <= Width + WidenEx))) {
       if (RetVT.getFixedSizeInBits() < MemVTWidth || MemVT == WidenVT)
         return MemVT;
     }
@@ -5264,13 +5268,24 @@ SDValue DAGTypeLegalizer::GenWidenVectorLoads(SmallVectorImpl<SDValue> &LdChain,
   TypeSize LdWidth = LdVT.getSizeInBits();
   TypeSize WidenWidth = WidenVT.getSizeInBits();
   TypeSize WidthDiff = WidenWidth - LdWidth;
-  // Allow wider loads if they are sufficiently aligned to avoid memory faults
-  // and if the original load is simple.
-  unsigned LdAlign = (!LD->isSimple()) ? 0 : LD->getAlignment();
+  unsigned NumDereferenceableBytes = 0;
+  // Allow wider loads if the original load is simple and we can dereference
+  // padding bytes.
+  if (LD->isSimple()) {
+    NumDereferenceableBytes = LD->getAlignment();
+    if (!LdWidth.isScalable())
+      NumDereferenceableBytes =
+          std::max<unsigned>(NumDereferenceableBytes, LdWidth / 8);
+    if (!WidenWidth.isScalable() && NumDereferenceableBytes < WidenWidth / 8 &&
+        LD->getPointerInfo().isDereferenceable(
+            WidenWidth / 8, *DAG.getContext(), DAG.getDataLayout()))
+      NumDereferenceableBytes =
+          std::max<unsigned>(NumDereferenceableBytes, WidenWidth / 8);
+  }
 
   // Find the vector type that can load from.
-  EVT NewVT = FindMemType(DAG, TLI, LdWidth.getKnownMinSize(), WidenVT, LdAlign,
-                          WidthDiff.getKnownMinSize());
+  EVT NewVT = FindMemType(DAG, TLI, LdWidth.getKnownMinSize(), WidenVT,
+                          NumDereferenceableBytes, WidthDiff.getKnownMinSize());
   TypeSize NewVTWidth = NewVT.getSizeInBits();
   SDValue LdOp = DAG.getLoad(NewVT, dl, Chain, BasePtr, LD->getPointerInfo(),
                              LD->getOriginalAlign(), MMOFlags, AAInfo);
@@ -5306,13 +5321,29 @@ SDValue DAGTypeLegalizer::GenWidenVectorLoads(SmallVectorImpl<SDValue> &LdChain,
   MachinePointerInfo MPI = LD->getPointerInfo();
   do {
     LdWidth -= NewVTWidth;
+    if (LD->isSimple()) {
+      if (!NewVTWidth.isScalable()) {
+        if (NumDereferenceableBytes > NewVTWidth / 8)
+          NumDereferenceableBytes -= NewVTWidth / 8;
+        else
+          NumDereferenceableBytes = 0;
+        NumDereferenceableBytes = std::max<unsigned>(
+            NumDereferenceableBytes,
+            commonAlignment(cast<LoadSDNode>(LdOp)->getOriginalAlign(),
+                            cast<LoadSDNode>(LdOp)->getSrcValueOffset() +
+                                NewVTWidth)
+                .value());
+      } else
+        NumDereferenceableBytes = 0; // FIXME
+    }
+
     IncrementPointer(cast<LoadSDNode>(LdOp), NewVT, MPI, BasePtr,
                      &ScaledOffset);
 
     if (TypeSize::isKnownLT(LdWidth, NewVTWidth)) {
       // The current type we are using is too large. Find a better size.
-      NewVT = FindMemType(DAG, TLI, LdWidth.getKnownMinSize(), WidenVT, LdAlign,
-                          WidthDiff.getKnownMinSize());
+      NewVT = FindMemType(DAG, TLI, LdWidth.getKnownMinSize(), WidenVT,
+                          NumDereferenceableBytes, WidthDiff.getKnownMinSize());
       NewVTWidth = NewVT.getSizeInBits();
     }
 
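One subtlety in the last hunk: after each partial load the loop advances the pointer, so the count of bytes still known dereferenceable shrinks by the bytes just consumed, but it can be raised back to the alignment the advanced pointer is known to retain (the common alignment of the original alignment and the offset walked so far). A hedged standalone sketch of that per-iteration update for fixed-width chunks, using hypothetical names rather than the LLVM types:

#include <algorithm>
#include <cstdint>

// Illustrative sketch only; not the LLVM implementation.
//   DerefBytes: bytes known dereferenceable before this chunk was loaded.
//   ChunkBytes: bytes consumed by the partial load just emitted.
//   OrigAlign:  alignment in bytes (a power of two) of the original pointer.
//   Offset:     total bytes advanced from the original pointer so far,
//               including the chunk just emitted.
uint64_t updateDereferenceableBytes(uint64_t DerefBytes, uint64_t ChunkBytes,
                                    uint64_t OrigAlign, uint64_t Offset) {
  // Whatever was proven dereferenceable shrinks by the bytes just consumed.
  uint64_t Remaining = DerefBytes > ChunkBytes ? DerefBytes - ChunkBytes : 0;

  // The advanced pointer is still aligned to the largest power of two dividing
  // both the original alignment and the offset, and an aligned access may be
  // widened up to that alignment without faulting.
  uint64_t CommonAlign = OrigAlign;
  while (CommonAlign > 1 && Offset % CommonAlign != 0)
    CommonAlign /= 2;

  return std::max(Remaining, CommonAlign);
}

For scalable vector chunks the patch simply drops the information (the // FIXME branch sets the count to zero), since the number of bytes per chunk is not a compile-time constant there.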
