8 changes: 8 additions & 0 deletions llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -3346,6 +3346,14 @@ namespace ISD {
         Ld->getAddressingMode() == ISD::UNINDEXED;
}

/// Returns true if the specified node is a non-extending and unindexed
/// masked store.
inline bool isNormalMaskedStore(const SDNode *N) {
  auto *St = dyn_cast<MaskedStoreSDNode>(N);
  return St && !St->isTruncatingStore() &&
         St->getAddressingMode() == ISD::UNINDEXED;
}

/// Attempt to match a unary predicate against a scalar/splat constant or
/// every element of a constant BUILD_VECTOR.
/// If AllowUndef is true, then UNDEF elements will pass nullptr to Match.
142 changes: 106 additions & 36 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -24632,6 +24632,105 @@ static SDValue performSTORECombine(SDNode *N,
  return SDValue();
}

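// Returns true if N is a CONCAT_VECTORS node whose operands are, in order,
// the results of a single VECTOR_INTERLEAVE node. On success the interleave's
// input vectors are appended to Ops.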
static bool
isSequentialConcatOfVectorInterleave(SDNode *N, SmallVectorImpl<SDValue> &Ops) {
  if (N->getOpcode() != ISD::CONCAT_VECTORS)
    return false;

  unsigned NumParts = N->getNumOperands();

  // We should be concatenating each sequential result from a
  // VECTOR_INTERLEAVE.
  SDNode *InterleaveOp = N->getOperand(0).getNode();
  if (InterleaveOp->getOpcode() != ISD::VECTOR_INTERLEAVE ||
      InterleaveOp->getNumOperands() != NumParts)
    return false;

  for (unsigned I = 0; I < NumParts; I++)
    if (N->getOperand(I) != SDValue(InterleaveOp, I))
      return false;

  Ops.append(InterleaveOp->op_begin(), InterleaveOp->op_end());
  return true;
}

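// Given the wide mask of an interleaved masked load/store, returns the
// equivalent mask for a single part: either the common input of a
// concatenated VECTOR_INTERLEAVE of RequiredNumParts identical masks, or a
// narrower splat of the same scalar. Returns an empty SDValue if the wide
// mask matches neither pattern.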
static SDValue getNarrowMaskForInterleavedOps(SelectionDAG &DAG, SDLoc &DL,
                                              SDValue WideMask,
                                              unsigned RequiredNumParts) {
  if (WideMask->getOpcode() == ISD::CONCAT_VECTORS) {
    SmallVector<SDValue, 4> MaskInterleaveOps;
    if (!isSequentialConcatOfVectorInterleave(WideMask.getNode(),
                                              MaskInterleaveOps))
      return SDValue();

    if (MaskInterleaveOps.size() != RequiredNumParts)
      return SDValue();

    // Make sure the inputs to the vector interleave are identical.
    if (!llvm::all_equal(MaskInterleaveOps))
      return SDValue();

    return MaskInterleaveOps[0];
  }

  if (WideMask->getOpcode() != ISD::SPLAT_VECTOR)
    return SDValue();

  ElementCount EC = WideMask.getValueType().getVectorElementCount();
  assert(EC.isKnownMultipleOf(RequiredNumParts) &&
         "Expected element count divisible by number of parts");
  EC = EC.divideCoefficientBy(RequiredNumParts);
  return DAG.getNode(ISD::SPLAT_VECTOR, DL, MVT::getVectorVT(MVT::i1, EC),
                     WideMask->getOperand(0));
}

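// Try to replace a masked store of an interleaved value (a CONCAT_VECTORS of
// the results of a VECTOR_INTERLEAVE) with the corresponding SVE structured
// store intrinsic (st2/st4), provided a matching narrow mask can be recovered
// from the wide store mask.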
static SDValue performInterleavedMaskedStoreCombine(
    SDNode *N, TargetLowering::DAGCombinerInfo &DCI, SelectionDAG &DAG) {
  if (!DCI.isBeforeLegalize())
    return SDValue();

  MaskedStoreSDNode *MST = cast<MaskedStoreSDNode>(N);
  SDValue WideValue = MST->getValue();

  // Bail out if the stored value has an unexpected number of uses, since we'll
  // have to perform manual interleaving and may as well just use normal masked
  // stores. Also, discard masked stores that are truncating or indexed.
  if (!WideValue.hasOneUse() || !ISD::isNormalMaskedStore(MST) ||
      !MST->isSimple() || !MST->getOffset().isUndef())
    return SDValue();

  SmallVector<SDValue, 4> ValueInterleaveOps;
  if (!isSequentialConcatOfVectorInterleave(WideValue.getNode(),
                                            ValueInterleaveOps))
    return SDValue();

  unsigned NumParts = ValueInterleaveOps.size();
  if (NumParts != 2 && NumParts != 4)
    return SDValue();

  // At the moment we're unlikely to see a fixed-width vector interleave as
  // we usually generate shuffles instead.
  EVT SubVecTy = ValueInterleaveOps[0].getValueType();
  if (!SubVecTy.isScalableVT() ||
      SubVecTy.getSizeInBits().getKnownMinValue() != 128 ||
      !DAG.getTargetLoweringInfo().isTypeLegal(SubVecTy))
    return SDValue();

  SDLoc DL(N);
  SDValue NarrowMask =
      getNarrowMaskForInterleavedOps(DAG, DL, MST->getMask(), NumParts);
  if (!NarrowMask)
    return SDValue();

  const Intrinsic::ID IID =
      NumParts == 2 ? Intrinsic::aarch64_sve_st2 : Intrinsic::aarch64_sve_st4;
  SmallVector<SDValue, 8> NewStOps;
  NewStOps.append({MST->getChain(), DAG.getConstant(IID, DL, MVT::i32)});
  NewStOps.append(ValueInterleaveOps);
  NewStOps.append({NarrowMask, MST->getBasePtr()});
  return DAG.getNode(ISD::INTRINSIC_VOID, DL, MVT::Other, NewStOps);
}

static SDValue performMSTORECombine(SDNode *N,
                                    TargetLowering::DAGCombinerInfo &DCI,
                                    SelectionDAG &DAG,
@@ -24641,6 +24740,9 @@ static SDValue performMSTORECombine(SDNode *N,
  SDValue Mask = MST->getMask();
  SDLoc DL(N);

  if (SDValue Res = performInterleavedMaskedStoreCombine(N, DCI, DAG))
    return Res;

  // If this is a UZP1 followed by a masked store, fold this into a masked
  // truncating store. We can do this even if this is already a masked
  // truncstore.
@@ -27274,43 +27376,11 @@ static SDValue performVectorDeinterleaveCombine(
    return SDValue();

  // Now prove that the mask is an interleave of identical masks.
-  SDValue Mask = MaskedLoad->getMask();
-  if (Mask->getOpcode() != ISD::SPLAT_VECTOR &&
-      Mask->getOpcode() != ISD::CONCAT_VECTORS)
-    return SDValue();
-
-  SDValue NarrowMask;
  SDLoc DL(N);
-  if (Mask->getOpcode() == ISD::CONCAT_VECTORS) {
-    if (Mask->getNumOperands() != NumParts)
-      return SDValue();
-
-    // We should be concatenating each sequential result from a
-    // VECTOR_INTERLEAVE.
-    SDNode *InterleaveOp = Mask->getOperand(0).getNode();
-    if (InterleaveOp->getOpcode() != ISD::VECTOR_INTERLEAVE ||
-        InterleaveOp->getNumOperands() != NumParts)
-      return SDValue();
-
-    for (unsigned I = 0; I < NumParts; I++) {
-      if (Mask.getOperand(I) != SDValue(InterleaveOp, I))
-        return SDValue();
-    }
-
-    // Make sure the inputs to the vector interleave are identical.
-    if (!llvm::all_equal(InterleaveOp->op_values()))
-      return SDValue();
-
-    NarrowMask = InterleaveOp->getOperand(0);
-  } else { // ISD::SPLAT_VECTOR
-    ElementCount EC = Mask.getValueType().getVectorElementCount();
-    assert(EC.isKnownMultipleOf(NumParts) &&
-           "Expected element count divisible by number of parts");
-    EC = EC.divideCoefficientBy(NumParts);
-    NarrowMask =
-        DAG.getNode(ISD::SPLAT_VECTOR, DL, MVT::getVectorVT(MVT::i1, EC),
-                    Mask->getOperand(0));
-  }
+  SDValue NarrowMask =
+      getNarrowMaskForInterleavedOps(DAG, DL, MaskedLoad->getMask(), NumParts);
+  if (!NarrowMask)
+    return SDValue();

  const Intrinsic::ID IID = NumParts == 2 ? Intrinsic::aarch64_sve_ld2_sret
                                          : Intrinsic::aarch64_sve_ld4_sret;