[DAGCombine] Refactor DAGCombiner::ReduceLoadWidth. NFCI
Update code comments in DAGCombiner::ReduceLoadWidth and refactor
the handling of SRL a bit. The refactoring is done with the intent
of adding support for folding away SRA by using SEXTLOAD in a
follow-up patch.

The function is also renamed as DAGCombiner::reduceLoadWidth.

Differential Revision: https://reviews.llvm.org/D117104
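
A rough source-level sketch of the two folds in question (the SRL/ZEXTLOAD case handled by this patch and the SRA/SEXTLOAD case planned for the follow-up), assuming a little-endian target, two's-complement narrowing, and an arithmetic right shift for signed integers:

  // Shifting the high byte of a 32-bit load down to the low bits is the same
  // as loading just that byte: zero-extended for a logical shift (ZEXTLOAD),
  // sign-extended for an arithmetic shift (SEXTLOAD).
  #include <cassert>
  #include <cstdint>
  #include <cstring>

  int main() {
    unsigned char Buf[4] = {0x12, 0x34, 0x56, 0xF0}; // 0xF0563412 on LE

    uint32_t Wide;
    std::memcpy(&Wide, Buf, 4);

    // srl + trunc == zero-extending 8-bit load of the top byte.
    assert((Wide >> 24) == Buf[3]);

    // sra == sign-extending 8-bit load of the top byte (the follow-up case).
    assert((int32_t(Wide) >> 24) == int8_t(Buf[3]));
  }
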
bjope committed Jan 16, 2022
1 parent 37e6496 commit 9f237c9
Showing 1 changed file with 104 additions and 81 deletions.
185 changes: 104 additions & 81 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -593,7 +593,7 @@ namespace {
SDValue MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL);
SDValue MatchLoadCombine(SDNode *N);
SDValue mergeTruncStores(StoreSDNode *N);
SDValue ReduceLoadWidth(SDNode *N);
SDValue reduceLoadWidth(SDNode *N);
SDValue ReduceLoadOpStoreWidth(SDNode *N);
SDValue splitMergedValStore(StoreSDNode *ST);
SDValue TransformFPLoadStorePair(SDNode *N);
@@ -5624,7 +5624,7 @@ bool DAGCombiner::BackwardsPropagateMask(SDNode *N) {
if (And.getOpcode() == ISD::AND)
And = SDValue(
DAG.UpdateNodeOperands(And.getNode(), SDValue(Load, 0), MaskOp), 0);
SDValue NewLoad = ReduceLoadWidth(And.getNode());
SDValue NewLoad = reduceLoadWidth(And.getNode());
assert(NewLoad &&
"Shouldn't be masking the load if it can't be narrowed");
CombineTo(Load, NewLoad, NewLoad.getValue(1));
@@ -6024,7 +6024,7 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
if (!VT.isVector() && N1C && (N0.getOpcode() == ISD::LOAD ||
(N0.getOpcode() == ISD::ANY_EXTEND &&
N0.getOperand(0).getOpcode() == ISD::LOAD))) {
if (SDValue Res = ReduceLoadWidth(N)) {
if (SDValue Res = reduceLoadWidth(N)) {
LoadSDNode *LN0 = N0->getOpcode() == ISD::ANY_EXTEND
? cast<LoadSDNode>(N0.getOperand(0)) : cast<LoadSDNode>(N0);
AddToWorklist(N);
@@ -9140,7 +9140,7 @@ SDValue DAGCombiner::visitSRL(SDNode *N) {
return NewSRL;

// Attempt to convert a srl of a load into a narrower zero-extending load.
if (SDValue NarrowLoad = ReduceLoadWidth(N))
if (SDValue NarrowLoad = reduceLoadWidth(N))
return NarrowLoad;

// Here is a common situation. We want to optimize:
@@ -11357,7 +11357,7 @@ SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) {
if (N0.getOpcode() == ISD::TRUNCATE) {
// fold (sext (truncate (load x))) -> (sext (smaller load x))
// fold (sext (truncate (srl (load x), c))) -> (sext (smaller load (x+c/n)))
if (SDValue NarrowLoad = ReduceLoadWidth(N0.getNode())) {
if (SDValue NarrowLoad = reduceLoadWidth(N0.getNode())) {
SDNode *oye = N0.getOperand(0).getNode();
if (NarrowLoad.getNode() != N0.getNode()) {
CombineTo(N0.getNode(), NarrowLoad);
@@ -11621,7 +11621,7 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) {
if (N0.getOpcode() == ISD::TRUNCATE) {
// fold (zext (truncate (load x))) -> (zext (smaller load x))
// fold (zext (truncate (srl (load x), c))) -> (zext (smaller load (x+c/n)))
if (SDValue NarrowLoad = ReduceLoadWidth(N0.getNode())) {
if (SDValue NarrowLoad = reduceLoadWidth(N0.getNode())) {
SDNode *oye = N0.getOperand(0).getNode();
if (NarrowLoad.getNode() != N0.getNode()) {
CombineTo(N0.getNode(), NarrowLoad);
@@ -11864,7 +11864,7 @@ SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) {
// fold (aext (truncate (load x))) -> (aext (smaller load x))
// fold (aext (truncate (srl (load x), c))) -> (aext (small load (x+c/n)))
if (N0.getOpcode() == ISD::TRUNCATE) {
if (SDValue NarrowLoad = ReduceLoadWidth(N0.getNode())) {
if (SDValue NarrowLoad = reduceLoadWidth(N0.getNode())) {
SDNode *oye = N0.getOperand(0).getNode();
if (NarrowLoad.getNode() != N0.getNode()) {
CombineTo(N0.getNode(), NarrowLoad);
@@ -12095,13 +12095,10 @@ SDValue DAGCombiner::visitAssertAlign(SDNode *N) {
return SDValue();
}

/// If the result of a wider load is shifted to right of N bits and then
/// truncated to a narrower type and where N is a multiple of number of bits of
/// the narrower type, transform it to a narrower load from address + N / num of
/// bits of new type. Also narrow the load if the result is masked with an AND
/// to effectively produce a smaller type. If the result is to be extended, also
/// fold the extension to form a extending load.
SDValue DAGCombiner::ReduceLoadWidth(SDNode *N) {
/// If the result of a load is shifted/masked/truncated to an effectively
/// narrower type, try to transform the load to a narrower type and/or
/// use an extending load.
SDValue DAGCombiner::reduceLoadWidth(SDNode *N) {
unsigned Opc = N->getOpcode();

ISD::LoadExtType ExtType = ISD::NON_EXTLOAD;
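
A small sketch of the masking and truncating triggers described in the updated comment (little-endian layout assumed; values are illustrative):

  #include <cassert>
  #include <cstdint>
  #include <cstring>

  int main() {
    unsigned char Buf[4] = {0xAB, 0xCD, 0x00, 0x00};

    uint32_t Wide;
    std::memcpy(&Wide, Buf, 4);
    uint16_t Narrow;
    std::memcpy(&Narrow, Buf, 2); // the narrower load

    // Masking with a low-bit constant or truncating only needs the low two
    // bytes, so a 16-bit (zero-extending) load is enough.
    assert((Wide & 0xFFFF) == Narrow);
    assert(uint16_t(Wide) == Narrow);
  }
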
@@ -12113,7 +12110,14 @@ SDValue DAGCombiner::ReduceLoadWidth(SDNode *N) {
if (VT.isVector())
return SDValue();

// The ShAmt variable is used to indicate that we've consumed a right
// shift. I.e. we want to narrow the width of the load by not loading the
// ShAmt least significant bits.
unsigned ShAmt = 0;
// A special case is when the least significant bits from the load are masked
// away, but using an AND rather than a right shift. HasShiftedOffset is used
// to indicate that the narrowed load should be left-shifted ShAmt bits to get
// the result.
bool HasShiftedOffset = false;
// Special case: SIGN_EXTEND_INREG is basically truncating to ExtVT then
// extended to VT.
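
A minimal sketch of that SIGN_EXTEND_INREG equivalence (arithmetic right shift and two's-complement conversions assumed):

  #include <cassert>
  #include <cstdint>

  int main() {
    uint32_t X = 0xDEADBEEF;                   // VT = i32, ExtVT = i8
    int32_t ViaTrunc = int8_t(X);              // truncate to i8, sign-extend to i32
    int32_t ViaInreg = int32_t(X << 24) >> 24; // sign_extend_inreg i8 written as shifts
    assert(ViaTrunc == ViaInreg);
  }
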
@@ -12122,23 +12126,29 @@ SDValue DAGCombiner::ReduceLoadWidth(SDNode *N) {
ExtVT = cast<VTSDNode>(N->getOperand(1))->getVT();
} else if (Opc == ISD::SRL) {
// Another special-case: SRL is basically zero-extending a narrower value,
// or it maybe shifting a higher subword, half or byte into the lowest
// or it may be shifting a higher subword, half or byte into the lowest
// bits.
ExtType = ISD::ZEXTLOAD;
N0 = SDValue(N, 0);

auto *LN0 = dyn_cast<LoadSDNode>(N0.getOperand(0));
auto *N01 = dyn_cast<ConstantSDNode>(N0.getOperand(1));
if (!N01 || !LN0)
// Only handle shift with constant shift amount, and the shiftee must be a
// load.
auto *LN = dyn_cast<LoadSDNode>(N0);
auto *N1C = dyn_cast<ConstantSDNode>(N->getOperand(1));
if (!N1C || !LN)
return SDValue();
// If the shift amount is larger than the memory type then we're not
// accessing any of the loaded bytes.
ShAmt = N1C->getZExtValue();
uint64_t MemoryWidth = LN->getMemoryVT().getScalarSizeInBits();
if (MemoryWidth <= ShAmt)
return SDValue();
// Attempt to fold away the SRL by using ZEXTLOAD.
ExtType = ISD::ZEXTLOAD;
ExtVT = EVT::getIntegerVT(*DAG.getContext(), MemoryWidth - ShAmt);
// If original load is a SEXTLOAD then we can't simply replace it by a
// ZEXTLOAD (we could potentially replace it by a more narrow SEXTLOAD
// followed by a ZEXT, but that is not handled at the moment).
if (LN->getExtensionType() == ISD::SEXTLOAD)
return SDValue();

uint64_t ShiftAmt = N01->getZExtValue();
uint64_t MemoryWidth = LN0->getMemoryVT().getScalarSizeInBits();
if (LN0->getExtensionType() != ISD::SEXTLOAD && MemoryWidth > ShiftAmt)
ExtVT = EVT::getIntegerVT(*DAG.getContext(), MemoryWidth - ShiftAmt);
else
ExtVT = EVT::getIntegerVT(*DAG.getContext(),
VT.getScalarSizeInBits() - ShiftAmt);
} else if (Opc == ISD::AND) {
// An AND with a constant mask is the same as a truncate + zero-extend.
auto AndC = dyn_cast<ConstantSDNode>(N->getOperand(1));
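
A tiny worked instance of the MemoryWidth/ShAmt arithmetic above (the helper name is just for illustration):

  #include <cassert>

  // Width of the zero-extending load that can replace (srl (load), ShAmt),
  // or 0 when the shift consumes all loaded bits and the fold must be skipped.
  unsigned narrowedLoadWidth(unsigned MemoryWidth, unsigned ShAmt) {
    return ShAmt < MemoryWidth ? MemoryWidth - ShAmt : 0;
  }

  int main() {
    assert(narrowedLoadWidth(32, 24) == 8);  // srl i32, 24 -> zextload i8
    assert(narrowedLoadWidth(32, 16) == 16); // srl i32, 16 -> zextload i16
    assert(narrowedLoadWidth(32, 32) == 0);  // nothing loaded is accessed
  }
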
@@ -12161,55 +12171,73 @@ SDValue DAGCombiner::ReduceLoadWidth(SDNode *N) {
ExtVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits);
}

if (N0.getOpcode() == ISD::SRL && N0.hasOneUse()) {
SDValue SRL = N0;
if (auto *ConstShift = dyn_cast<ConstantSDNode>(SRL.getOperand(1))) {
ShAmt = ConstShift->getZExtValue();
unsigned EVTBits = ExtVT.getScalarSizeInBits();
// Is the shift amount a multiple of size of VT?
if ((ShAmt & (EVTBits-1)) == 0) {
N0 = N0.getOperand(0);
// Is the load width a multiple of size of VT?
if ((N0.getScalarValueSizeInBits() & (EVTBits - 1)) != 0)
return SDValue();
}
// In case Opc==SRL we've already prepared ExtVT/ExtType/ShAmt based on doing
// a right shift. Here we redo some of those checks, to possibly adjust the
// ExtVT even further based on "a masking AND". We could also end up here for
// other reasons (e.g. based on Opc==TRUNCATE) and that is why some checks
// need to be done here as well.
if (Opc == ISD::SRL || N0.getOpcode() == ISD::SRL) {
SDValue SRL = Opc == ISD::SRL ? SDValue(N, 0) : N0;
// Bail out when the SRL has more than one use. This is done for historical
// (undocumented) reasons. Maybe the intent was to guard the AND-masking
// check below? And maybe it could be non-profitable to do the transform
// when the SRL has multiple uses and we get here with Opc!=ISD::SRL?
// FIXME: Can't we just skip this check for the Opc==ISD::SRL case?
if (!SRL.hasOneUse())
return SDValue();

// Only handle shift with constant shift amount, and the shiftee must be a
// load.
auto *LN = dyn_cast<LoadSDNode>(SRL.getOperand(0));
auto *SRL1C = dyn_cast<ConstantSDNode>(SRL.getOperand(1));
if (!SRL1C || !LN)
return SDValue();

// At this point, we must have a load or else we can't do the transform.
auto *LN0 = dyn_cast<LoadSDNode>(N0);
if (!LN0) return SDValue();
// If the shift amount is larger than the input type then we're not
// accessing any of the loaded bytes. If the load was a zextload/extload
// then the result of the shift+trunc is zero/undef (handled elsewhere).
ShAmt = SRL1C->getZExtValue();
if (ShAmt >= LN->getMemoryVT().getSizeInBits())
return SDValue();

// Because a SRL must be assumed to *need* to zero-extend the high bits
// (as opposed to anyext the high bits), we can't combine the zextload
// lowering of SRL and an sextload.
if (LN0->getExtensionType() == ISD::SEXTLOAD)
return SDValue();
// Because a SRL must be assumed to *need* to zero-extend the high bits
// (as opposed to anyext the high bits), we can't combine the zextload
// lowering of SRL and an sextload.
if (LN->getExtensionType() == ISD::SEXTLOAD)
return SDValue();

// If the shift amount is larger than the input type then we're not
// accessing any of the loaded bytes. If the load was a zextload/extload
// then the result of the shift+trunc is zero/undef (handled elsewhere).
if (ShAmt >= LN0->getMemoryVT().getSizeInBits())
return SDValue();
unsigned ExtVTBits = ExtVT.getScalarSizeInBits();
// Is the shift amount a multiple of size of ExtVT?
if ((ShAmt & (ExtVTBits - 1)) != 0)
return SDValue();
// Is the load width a multiple of size of ExtVT?
if ((SRL.getScalarValueSizeInBits() & (ExtVTBits - 1)) != 0)
return SDValue();

// If the SRL is only used by a masking AND, we may be able to adjust
// the ExtVT to make the AND redundant.
SDNode *Mask = *(SRL->use_begin());
if (Mask->getOpcode() == ISD::AND &&
isa<ConstantSDNode>(Mask->getOperand(1))) {
const APInt& ShiftMask = Mask->getConstantOperandAPInt(1);
if (ShiftMask.isMask()) {
EVT MaskedVT = EVT::getIntegerVT(*DAG.getContext(),
ShiftMask.countTrailingOnes());
// If the mask is smaller, recompute the type.
if ((ExtVT.getScalarSizeInBits() > MaskedVT.getScalarSizeInBits()) &&
TLI.isLoadExtLegal(ExtType, N0.getValueType(), MaskedVT))
ExtVT = MaskedVT;
}
// If the SRL is only used by a masking AND, we may be able to adjust
// the ExtVT to make the AND redundant.
SDNode *Mask = *(SRL->use_begin());
if (SRL.hasOneUse() && Mask->getOpcode() == ISD::AND &&
isa<ConstantSDNode>(Mask->getOperand(1))) {
const APInt& ShiftMask = Mask->getConstantOperandAPInt(1);
if (ShiftMask.isMask()) {
EVT MaskedVT = EVT::getIntegerVT(*DAG.getContext(),
ShiftMask.countTrailingOnes());
// If the mask is smaller, recompute the type.
if ((ExtVTBits > MaskedVT.getScalarSizeInBits()) &&
TLI.isLoadExtLegal(ExtType, SRL.getValueType(), MaskedVT))
ExtVT = MaskedVT;
}
}

N0 = SRL.getOperand(0);
}

// If the load is shifted left (and the result isn't shifted back right),
// we can fold the truncate through the shift.
// If the load is shifted left (and the result isn't shifted back right), we
// can fold a truncate through the shift. The typical scenario is that N
// points at a TRUNCATE here so the attempted fold is:
// (truncate (shl (load x), c))) -> (shl (narrow load x), c)
// ShLeftAmt will indicate how much a narrowed load should be shifted left.
unsigned ShLeftAmt = 0;
if (ShAmt == 0 && N0.getOpcode() == ISD::SHL && N0.hasOneUse() &&
ExtVT == VT && TLI.isNarrowingProfitable(N0.getValueType(), VT)) {
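
A source-level sketch of the truncate-through-shl fold described in the comment above (little-endian layout assumed):

  #include <cassert>
  #include <cstdint>
  #include <cstring>

  int main() {
    unsigned char Buf[4] = {0x21, 0x43, 0x65, 0x87};

    uint32_t Wide;
    std::memcpy(&Wide, Buf, 4);
    uint16_t Narrow;
    std::memcpy(&Narrow, Buf, 2); // narrower load of the same address

    // Truncating the shifted value only depends on the low bits, so the load
    // itself can be narrowed and the left shift applied afterwards.
    unsigned C = 4;
    assert(uint16_t(Wide << C) == uint16_t(Narrow << C));
  }
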
@@ -12237,12 +12265,12 @@ SDValue DAGCombiner::ReduceLoadWidth(SDNode *N) {
return LVTStoreBits - EVTStoreBits - ShAmt;
};

// For big endian targets, we need to adjust the offset to the pointer to
// load the correct bytes.
if (DAG.getDataLayout().isBigEndian())
ShAmt = AdjustBigEndianShift(ShAmt);
// We need to adjust the pointer to the load by ShAmt bits in order to load
// the correct bytes.
unsigned PtrAdjustmentInBits =
DAG.getDataLayout().isBigEndian() ? AdjustBigEndianShift(ShAmt) : ShAmt;

uint64_t PtrOff = ShAmt / 8;
uint64_t PtrOff = PtrAdjustmentInBits / 8;
Align NewAlign = commonAlignment(LN0->getAlign(), PtrOff);
SDLoc DL(LN0);
// The original load itself didn't wrap, so an offset within it doesn't.
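
A worked instance of the pointer adjustment above (the helper is illustrative, not the patch's AdjustBigEndianShift):

  #include <cassert>

  // Byte offset added to the load pointer when narrowing a LoadBits-wide load
  // to NarrowBits, skipping the ShAmt least significant bits of the value.
  unsigned pointerOffsetInBytes(unsigned LoadBits, unsigned NarrowBits,
                                unsigned ShAmt, bool BigEndian) {
    unsigned Bits = BigEndian ? LoadBits - NarrowBits - ShAmt : ShAmt;
    return Bits / 8;
  }

  int main() {
    // i32 load narrowed to i8 with ShAmt = 16: byte 2 on little-endian,
    // byte 1 on big-endian.
    assert(pointerOffsetInBytes(32, 8, 16, /*BigEndian=*/false) == 2);
    assert(pointerOffsetInBytes(32, 8, 16, /*BigEndian=*/true) == 1);
  }
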
@@ -12285,11 +12313,6 @@ SDValue DAGCombiner::ReduceLoadWidth(SDNode *N) {
}

if (HasShiftedOffset) {
// Recalculate the shift amount after it has been altered to calculate
// the offset.
if (DAG.getDataLayout().isBigEndian())
ShAmt = AdjustBigEndianShift(ShAmt);

// We're using a shifted mask, so the load now has an offset. This means
// that data has been loaded into the lower bytes than it would have been
// before, so we need to shl the loaded data into the correct position in the
@@ -12382,7 +12405,7 @@ SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) {

// fold (sext_in_reg (load x)) -> (smaller sextload x)
// fold (sext_in_reg (srl (load x), c)) -> (smaller sextload (x+c/evtbits))
if (SDValue NarrowLoad = ReduceLoadWidth(N))
if (SDValue NarrowLoad = reduceLoadWidth(N))
return NarrowLoad;

// fold (sext_in_reg (srl X, 24), i8) -> (sra X, 24)
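
A quick sketch of the (sext_in_reg (srl X, 24), i8) -> (sra X, 24) identity noted above (arithmetic right shift for signed values assumed):

  #include <cassert>
  #include <cstdint>

  int main() {
    uint32_t X = 0x80ABCDEF;
    int32_t ViaInreg = int8_t(X >> 24);  // srl by 24, then sign_extend_inreg i8
    int32_t ViaSra   = int32_t(X) >> 24; // sra by 24
    assert(ViaInreg == ViaSra);
  }
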
@@ -12669,7 +12692,7 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
// fold (truncate (load x)) -> (smaller load x)
// fold (truncate (srl (load x), c)) -> (smaller load (x+c/evtbits))
if (!LegalTypes || TLI.isTypeDesirableForOp(N0.getOpcode(), VT)) {
if (SDValue Reduced = ReduceLoadWidth(N))
if (SDValue Reduced = reduceLoadWidth(N))
return Reduced;

// Handle the case where the load remains an extending load even