Skip to content

Commit

Permalink
[X86] Add X86ISD::SUBV_BROADCAST_LOAD and begin removing X86ISD::SUBV…
Browse files Browse the repository at this point in the history
…_BROADCAST (PR38969)

Subvector broadcasts are only load instructions, yet X86ISD::SUBV_BROADCAST treats them more generally, requiring a lot of fallback tablegen patterns.

This initial patch replaces constant vector lowering inside lowerBuildVectorAsBroadcast with direct X86ISD::SUBV_BROADCAST_LOAD loads which helps us merge a number of equivalent loads/broadcasts.

As well as general plumbing/analysis additions for SUBV_BROADCAST_LOAD, I needed to wrap SelectionDAG::makeEquivalentMemoryOrdering so it can handle result chains from non-generic LoadSDNode nodes.

Later patches will continue to replace X86ISD::SUBV_BROADCAST usage.

Differential Revision: https://reviews.llvm.org/D92645
  • Loading branch information
RKSimon committed Dec 17, 2020
1 parent 352cba2 commit cdb692e
Show file tree
Hide file tree
Showing 9 changed files with 305 additions and 150 deletions.
9 changes: 8 additions & 1 deletion llvm/include/llvm/CodeGen/SelectionDAG.h
Original file line number Diff line number Diff line change
Expand Up @@ -1591,7 +1591,14 @@ class SelectionDAG {
/// chain to the token factor. This ensures that the new memory node will have
/// the same relative memory dependency position as the old load. Returns the
/// new merged load chain.
SDValue makeEquivalentMemoryOrdering(LoadSDNode *Old, SDValue New);
SDValue makeEquivalentMemoryOrdering(SDValue OldChain, SDValue NewMemOpChain);

/// If an existing load has uses of its chain, create a token factor node with
/// that chain and the new memory node's chain and update users of the old
/// chain to the token factor. This ensures that the new memory node will have
/// the same relative memory dependency position as the old load. Returns the
/// new merged load chain.
SDValue makeEquivalentMemoryOrdering(LoadSDNode *OldLoad, SDValue NewMemOp);

/// Topological-sort the AllNodes list and a
/// assign a unique node id for each node in the DAG based on their
Expand Down
27 changes: 17 additions & 10 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8966,25 +8966,32 @@ void SelectionDAG::AddDbgLabel(SDDbgLabel *DB) {
DbgInfo->add(DB);
}

SDValue SelectionDAG::makeEquivalentMemoryOrdering(LoadSDNode *OldLoad,
SDValue NewMemOp) {
assert(isa<MemSDNode>(NewMemOp.getNode()) && "Expected a memop node");
SDValue SelectionDAG::makeEquivalentMemoryOrdering(SDValue OldChain,
SDValue NewMemOpChain) {
assert(isa<MemSDNode>(NewMemOpChain) && "Expected a memop node");
assert(NewMemOpChain.getValueType() == MVT::Other && "Expected a token VT");
// The new memory operation must have the same position as the old load in
// terms of memory dependency. Create a TokenFactor for the old load and new
// memory operation and update uses of the old load's output chain to use that
// TokenFactor.
SDValue OldChain = SDValue(OldLoad, 1);
SDValue NewChain = SDValue(NewMemOp.getNode(), 1);
if (OldChain == NewChain || !OldLoad->hasAnyUseOfValue(1))
return NewChain;
if (OldChain == NewMemOpChain || OldChain.use_empty())
return NewMemOpChain;

SDValue TokenFactor =
getNode(ISD::TokenFactor, SDLoc(OldLoad), MVT::Other, OldChain, NewChain);
SDValue TokenFactor = getNode(ISD::TokenFactor, SDLoc(OldChain), MVT::Other,
OldChain, NewMemOpChain);
ReplaceAllUsesOfValueWith(OldChain, TokenFactor);
UpdateNodeOperands(TokenFactor.getNode(), OldChain, NewChain);
UpdateNodeOperands(TokenFactor.getNode(), OldChain, NewMemOpChain);
return TokenFactor;
}

SDValue SelectionDAG::makeEquivalentMemoryOrdering(LoadSDNode *OldLoad,
SDValue NewMemOp) {
assert(isa<MemSDNode>(NewMemOp.getNode()) && "Expected a memop node");
SDValue OldChain = SDValue(OldLoad, 1);
SDValue NewMemOpChain = NewMemOp.getValue(1);
return makeEquivalentMemoryOrdering(OldChain, NewMemOpChain);
}

SDValue SelectionDAG::getSymbolFunctionGlobalAddress(SDValue Op,
Function **OutFunction) {
assert(isa<ExternalSymbolSDNode>(Op) && "Node should be an ExternalSymbol");
Expand Down
121 changes: 98 additions & 23 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6644,15 +6644,30 @@ static bool getTargetConstantBitsFromNode(SDValue Op, unsigned EltSizeInBits,
}

// Extract constant bits from a subvector broadcast.
if (Op.getOpcode() == X86ISD::SUBV_BROADCAST) {
SmallVector<APInt, 16> SubEltBits;
if (getTargetConstantBitsFromNode(Op.getOperand(0), EltSizeInBits,
UndefElts, SubEltBits, AllowWholeUndefs,
AllowPartialUndefs)) {
UndefElts = APInt::getSplat(NumElts, UndefElts);
while (EltBits.size() < NumElts)
EltBits.append(SubEltBits.begin(), SubEltBits.end());
return true;
if (Op.getOpcode() == X86ISD::SUBV_BROADCAST_LOAD) {
auto *MemIntr = cast<MemIntrinsicSDNode>(Op);
SDValue Ptr = MemIntr->getBasePtr();
if (const Constant *Cst = getTargetConstantFromBasePtr(Ptr)) {
Type *CstTy = Cst->getType();
unsigned CstSizeInBits = CstTy->getPrimitiveSizeInBits();
if (!CstTy->isVectorTy() || (SizeInBits % CstSizeInBits) != 0)
return false;
unsigned SubEltSizeInBits = CstTy->getScalarSizeInBits();
unsigned NumSubElts = CstSizeInBits / SubEltSizeInBits;
unsigned NumSubVecs = SizeInBits / CstSizeInBits;
APInt UndefSubElts(NumSubElts, 0);
SmallVector<APInt, 64> SubEltBits(NumSubElts * NumSubVecs,
APInt(SubEltSizeInBits, 0));
for (unsigned i = 0; i != NumSubElts; ++i) {
if (!CollectConstantBits(Cst->getAggregateElement(i), SubEltBits[i],
UndefSubElts, i))
return false;
for (unsigned j = 1; j != NumSubVecs; ++j)
SubEltBits[i + (j * NumSubElts)] = SubEltBits[i];
}
UndefSubElts = APInt::getSplat(NumSubVecs * UndefSubElts.getBitWidth(),
UndefSubElts);
return CastBitData(UndefSubElts, SubEltBits);
}
}

Expand Down Expand Up @@ -8802,17 +8817,19 @@ static SDValue lowerBuildVectorAsBroadcast(BuildVectorSDNode *BVOp,
}
if (SplatBitSize > 64) {
// Load the vector of constants and broadcast it.
MVT CVT = VT.getScalarType();
Constant *VecC = getConstantVector(VT, SplatValue, SplatBitSize,
*Ctx);
SDValue VCP = DAG.getConstantPool(VecC, PVT);
unsigned NumElm = SplatBitSize / VT.getScalarSizeInBits();
MVT VVT = MVT::getVectorVT(VT.getScalarType(), NumElm);
Align Alignment = cast<ConstantPoolSDNode>(VCP)->getAlign();
Ld = DAG.getLoad(
MVT::getVectorVT(CVT, NumElm), dl, DAG.getEntryNode(), VCP,
MachinePointerInfo::getConstantPool(DAG.getMachineFunction()),
Alignment);
return DAG.getNode(X86ISD::SUBV_BROADCAST, dl, VT, Ld);
SDVTList Tys = DAG.getVTList(VT, MVT::Other);
SDValue Ops[] = {DAG.getEntryNode(), VCP};
MachinePointerInfo MPI =
MachinePointerInfo::getConstantPool(DAG.getMachineFunction());
return DAG.getMemIntrinsicNode(
X86ISD::SUBV_BROADCAST_LOAD, dl, Tys, Ops, VVT, MPI, Alignment,
MachineMemOperand::MOLoad);
}
}
}
Expand Down Expand Up @@ -30929,6 +30946,7 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(VBROADCAST_LOAD)
NODE_NAME_CASE(VBROADCASTM)
NODE_NAME_CASE(SUBV_BROADCAST)
NODE_NAME_CASE(SUBV_BROADCAST_LOAD)
NODE_NAME_CASE(VPERMILPV)
NODE_NAME_CASE(VPERMILPI)
NODE_NAME_CASE(VPERM2X128)
Expand Down Expand Up @@ -38056,6 +38074,34 @@ bool X86TargetLowering::SimplifyDemandedVectorEltsForTargetNode(
}
return TLO.CombineTo(Op, insertSubVector(TLO.DAG.getUNDEF(VT), Src, 0,
TLO.DAG, DL, ExtSizeInBits));
}
case X86ISD::SUBV_BROADCAST_LOAD: {
auto *MemIntr = cast<MemIntrinsicSDNode>(Op);
EVT MemVT = MemIntr->getMemoryVT();
if (ExtSizeInBits == MemVT.getStoreSizeInBits()) {
SDLoc DL(Op);
SDValue Ld =
TLO.DAG.getLoad(MemVT, DL, MemIntr->getChain(),
MemIntr->getBasePtr(), MemIntr->getMemOperand());
TLO.DAG.makeEquivalentMemoryOrdering(SDValue(MemIntr, 1),
Ld.getValue(1));
return TLO.CombineTo(Op, insertSubVector(TLO.DAG.getUNDEF(VT), Ld, 0,
TLO.DAG, DL, ExtSizeInBits));
} else if ((ExtSizeInBits % MemVT.getStoreSizeInBits()) == 0) {
SDLoc DL(Op);
EVT BcstVT = EVT::getVectorVT(*TLO.DAG.getContext(), VT.getScalarType(),
ExtSizeInBits / VT.getScalarSizeInBits());
SDVTList Tys = TLO.DAG.getVTList(BcstVT, MVT::Other);
SDValue Ops[] = {MemIntr->getOperand(0), MemIntr->getOperand(1)};
SDValue Bcst =
TLO.DAG.getMemIntrinsicNode(X86ISD::SUBV_BROADCAST_LOAD, DL, Tys,
Ops, MemVT, MemIntr->getMemOperand());
TLO.DAG.makeEquivalentMemoryOrdering(SDValue(MemIntr, 1),
Bcst.getValue(1));
return TLO.CombineTo(Op, insertSubVector(TLO.DAG.getUNDEF(VT), Bcst, 0,
TLO.DAG, DL, ExtSizeInBits));
}
break;
}
// Byte shifts by immediate.
case X86ISD::VSHLDQ:
Expand Down Expand Up @@ -44606,6 +44652,29 @@ static SDValue combineLoad(SDNode *N, SelectionDAG &DAG,
}
}

// If we also broadcast this as a subvector to a wider type, then just extract
// the lowest subvector.
if (Ext == ISD::NON_EXTLOAD && Subtarget.hasAVX() && Ld->isSimple() &&
(RegVT.is128BitVector() || RegVT.is256BitVector())) {
SDValue Ptr = Ld->getBasePtr();
SDValue Chain = Ld->getChain();
for (SDNode *User : Ptr->uses()) {
if (User != N && User->getOpcode() == X86ISD::SUBV_BROADCAST_LOAD &&
cast<MemIntrinsicSDNode>(User)->getBasePtr() == Ptr &&
cast<MemIntrinsicSDNode>(User)->getChain() == Chain &&
cast<MemIntrinsicSDNode>(User)->getMemoryVT().getSizeInBits() ==
MemVT.getSizeInBits() &&
!User->hasAnyUseOfValue(1) &&
User->getValueSizeInBits(0).getFixedSize() >
RegVT.getFixedSizeInBits()) {
SDValue Extract = extractSubVector(SDValue(User, 0), 0, DAG, SDLoc(N),
RegVT.getSizeInBits());
Extract = DAG.getBitcast(RegVT, Extract);
return DCI.CombineTo(N, Extract, SDValue(User, 1));
}
}
}

// Cast ptr32 and ptr64 pointers to the default address space before a load.
unsigned AddrSpace = Ld->getAddressSpace();
if (AddrSpace == X86AS::PTR64 || AddrSpace == X86AS::PTR32_SPTR ||
Expand Down Expand Up @@ -49321,7 +49390,8 @@ static SDValue combineExtractSubvector(SDNode *N, SelectionDAG &DAG,
// extract the lowest subvector instead which should allow
// SimplifyDemandedVectorElts do more simplifications.
if (IdxVal != 0 && (InVec.getOpcode() == X86ISD::VBROADCAST ||
InVec.getOpcode() == X86ISD::VBROADCAST_LOAD))
InVec.getOpcode() == X86ISD::VBROADCAST_LOAD ||
InVec.getOpcode() == X86ISD::SUBV_BROADCAST_LOAD))
return extractSubVector(InVec, 0, DAG, SDLoc(N), SizeInBits);

// If we're extracting a broadcasted subvector, just use the source.
Expand Down Expand Up @@ -49687,11 +49757,15 @@ static SDValue combineFP_EXTEND(SDNode *N, SelectionDAG &DAG,
return DAG.getNode(ISD::FP_EXTEND, dl, VT, Cvt);
}

// Try to find a larger VBROADCAST_LOAD that we can extract from. Limit this to
// cases where the loads have the same input chain and the output chains are
// unused. This avoids any memory ordering issues.
static SDValue combineVBROADCAST_LOAD(SDNode *N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI) {
// Try to find a larger VBROADCAST_LOAD/SUBV_BROADCAST_LOAD that we can extract
// from. Limit this to cases where the loads have the same input chain and the
// output chains are unused. This avoids any memory ordering issues.
static SDValue combineBROADCAST_LOAD(SDNode *N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI) {
assert((N->getOpcode() == X86ISD::VBROADCAST_LOAD ||
N->getOpcode() == X86ISD::SUBV_BROADCAST_LOAD) &&
"Unknown broadcast load type");

// Only do this if the chain result is unused.
if (N->hasAnyUseOfValue(1))
return SDValue();
Expand All @@ -49706,7 +49780,7 @@ static SDValue combineVBROADCAST_LOAD(SDNode *N, SelectionDAG &DAG,
// Look at other users of our base pointer and try to find a wider broadcast.
// The input chain and the size of the memory VT must match.
for (SDNode *User : Ptr->uses())
if (User != N && User->getOpcode() == X86ISD::VBROADCAST_LOAD &&
if (User != N && User->getOpcode() == N->getOpcode() &&
cast<MemIntrinsicSDNode>(User)->getBasePtr() == Ptr &&
cast<MemIntrinsicSDNode>(User)->getChain() == Chain &&
cast<MemIntrinsicSDNode>(User)->getMemoryVT().getSizeInBits() ==
Expand Down Expand Up @@ -49963,7 +50037,8 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
case ISD::STRICT_FP_EXTEND:
case ISD::FP_EXTEND: return combineFP_EXTEND(N, DAG, Subtarget);
case ISD::FP_ROUND: return combineFP_ROUND(N, DAG, Subtarget);
case X86ISD::VBROADCAST_LOAD: return combineVBROADCAST_LOAD(N, DAG, DCI);
case X86ISD::VBROADCAST_LOAD:
case X86ISD::SUBV_BROADCAST_LOAD: return combineBROADCAST_LOAD(N, DAG, DCI);
case X86ISD::MOVDQ2Q: return combineMOVDQ2Q(N, DAG);
case X86ISD::PDEP: return combinePDEP(N, DAG, DCI);
}
Expand Down
5 changes: 4 additions & 1 deletion llvm/lib/Target/X86/X86ISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -776,9 +776,12 @@ namespace llvm {
// extract_vector_elt, store.
VEXTRACT_STORE,

// scalar broadcast from memory
// scalar broadcast from memory.
VBROADCAST_LOAD,

// subvector broadcast from memory.
SUBV_BROADCAST_LOAD,

// Store FP control world into i16 memory.
FNSTCW16m,

Expand Down
39 changes: 39 additions & 0 deletions llvm/lib/Target/X86/X86InstrAVX512.td
Original file line number Diff line number Diff line change
Expand Up @@ -1456,6 +1456,32 @@ defm VBROADCASTF64X4 : avx512_subvec_broadcast_rm<0x1b, "vbroadcastf64x4",
EVEX_V512, EVEX_CD8<64, CD8VT4>;

let Predicates = [HasAVX512] in {
def : Pat<(v8f64 (X86SubVBroadcastld256 addr:$src)),
(VBROADCASTF64X4rm addr:$src)>;
def : Pat<(v16f32 (X86SubVBroadcastld256 addr:$src)),
(VBROADCASTF64X4rm addr:$src)>;
def : Pat<(v8i64 (X86SubVBroadcastld256 addr:$src)),
(VBROADCASTI64X4rm addr:$src)>;
def : Pat<(v16i32 (X86SubVBroadcastld256 addr:$src)),
(VBROADCASTI64X4rm addr:$src)>;
def : Pat<(v32i16 (X86SubVBroadcastld256 addr:$src)),
(VBROADCASTI64X4rm addr:$src)>;
def : Pat<(v64i8 (X86SubVBroadcastld256 addr:$src)),
(VBROADCASTI64X4rm addr:$src)>;

def : Pat<(v8f64 (X86SubVBroadcastld128 addr:$src)),
(VBROADCASTF32X4rm addr:$src)>;
def : Pat<(v16f32 (X86SubVBroadcastld128 addr:$src)),
(VBROADCASTF32X4rm addr:$src)>;
def : Pat<(v8i64 (X86SubVBroadcastld128 addr:$src)),
(VBROADCASTI32X4rm addr:$src)>;
def : Pat<(v16i32 (X86SubVBroadcastld128 addr:$src)),
(VBROADCASTI32X4rm addr:$src)>;
def : Pat<(v32i16 (X86SubVBroadcastld128 addr:$src)),
(VBROADCASTI32X4rm addr:$src)>;
def : Pat<(v64i8 (X86SubVBroadcastld128 addr:$src)),
(VBROADCASTI32X4rm addr:$src)>;

def : Pat<(v16f32 (X86SubVBroadcast (loadv8f32 addr:$src))),
(VBROADCASTF64X4rm addr:$src)>;
def : Pat<(v16i32 (X86SubVBroadcast (loadv8i32 addr:$src))),
Expand Down Expand Up @@ -1539,6 +1565,19 @@ defm VBROADCASTF32X4Z256 : avx512_subvec_broadcast_rm<0x1a, "vbroadcastf32x4",
v8f32x_info, v4f32x_info>,
EVEX_V256, EVEX_CD8<32, CD8VT4>;

def : Pat<(v4f64 (X86SubVBroadcastld128 addr:$src)),
(VBROADCASTF32X4Z256rm addr:$src)>;
def : Pat<(v8f32 (X86SubVBroadcastld128 addr:$src)),
(VBROADCASTF32X4Z256rm addr:$src)>;
def : Pat<(v4i64 (X86SubVBroadcastld128 addr:$src)),
(VBROADCASTI32X4Z256rm addr:$src)>;
def : Pat<(v8i32 (X86SubVBroadcastld128 addr:$src)),
(VBROADCASTI32X4Z256rm addr:$src)>;
def : Pat<(v16i16 (X86SubVBroadcastld128 addr:$src)),
(VBROADCASTI32X4Z256rm addr:$src)>;
def : Pat<(v32i8 (X86SubVBroadcastld128 addr:$src)),
(VBROADCASTI32X4Z256rm addr:$src)>;

def : Pat<(v4f64 (X86SubVBroadcast (loadv2f64 addr:$src))),
(VBROADCASTF32X4Z256rm addr:$src)>;
def : Pat<(v4i64 (X86SubVBroadcast (loadv2i64 addr:$src))),
Expand Down
12 changes: 12 additions & 0 deletions llvm/lib/Target/X86/X86InstrFragmentsSIMD.td
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ def X86vextractst : SDNode<"X86ISD::VEXTRACT_STORE", SDTStore,
[SDNPHasChain, SDNPMayStore, SDNPMemOperand]>;
def X86VBroadcastld : SDNode<"X86ISD::VBROADCAST_LOAD", SDTLoad,
[SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>;
def X86SubVBroadcastld : SDNode<"X86ISD::SUBV_BROADCAST_LOAD", SDTLoad,
[SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>;

def SDTVtrunc : SDTypeProfile<1, 1, [SDTCisVec<0>, SDTCisVec<1>,
SDTCisInt<0>, SDTCisInt<1>,
Expand Down Expand Up @@ -965,6 +967,16 @@ def X86VBroadcastld64 : PatFrag<(ops node:$src),
return cast<MemIntrinsicSDNode>(N)->getMemoryVT().getStoreSize() == 8;
}]>;

def X86SubVBroadcastld128 : PatFrag<(ops node:$src),
(X86SubVBroadcastld node:$src), [{
return cast<MemIntrinsicSDNode>(N)->getMemoryVT().getStoreSize() == 16;
}]>;

def X86SubVBroadcastld256 : PatFrag<(ops node:$src),
(X86SubVBroadcastld node:$src), [{
return cast<MemIntrinsicSDNode>(N)->getMemoryVT().getStoreSize() == 32;
}]>;

// Scalar SSE intrinsic fragments to match several different types of loads.
// Used by scalar SSE intrinsic instructions which have 128 bit types, but
// only load a single element.
Expand Down
14 changes: 14 additions & 0 deletions llvm/lib/Target/X86/X86InstrSSE.td
Original file line number Diff line number Diff line change
Expand Up @@ -7016,6 +7016,11 @@ def VBROADCASTF128 : AVX8I<0x1A, MRMSrcMem, (outs VR256:$dst),
Sched<[SchedWriteFShuffle.XMM.Folded]>, VEX, VEX_L;

let Predicates = [HasAVX, NoVLX] in {
def : Pat<(v4f64 (X86SubVBroadcastld128 addr:$src)),
(VBROADCASTF128 addr:$src)>;
def : Pat<(v8f32 (X86SubVBroadcastld128 addr:$src)),
(VBROADCASTF128 addr:$src)>;

def : Pat<(v4f64 (X86SubVBroadcast (loadv2f64 addr:$src))),
(VBROADCASTF128 addr:$src)>;
def : Pat<(v8f32 (X86SubVBroadcast (loadv4f32 addr:$src))),
Expand All @@ -7025,6 +7030,15 @@ def : Pat<(v8f32 (X86SubVBroadcast (loadv4f32 addr:$src))),
// NOTE: We're using FP instructions here, but execution domain fixing can
// convert to integer when profitable.
let Predicates = [HasAVX, NoVLX] in {
def : Pat<(v4i64 (X86SubVBroadcastld128 addr:$src)),
(VBROADCASTF128 addr:$src)>;
def : Pat<(v8i32 (X86SubVBroadcastld128 addr:$src)),
(VBROADCASTF128 addr:$src)>;
def : Pat<(v16i16 (X86SubVBroadcastld128 addr:$src)),
(VBROADCASTF128 addr:$src)>;
def : Pat<(v32i8 (X86SubVBroadcastld128 addr:$src)),
(VBROADCASTF128 addr:$src)>;

def : Pat<(v4i64 (X86SubVBroadcast (loadv2i64 addr:$src))),
(VBROADCASTF128 addr:$src)>;
def : Pat<(v8i32 (X86SubVBroadcast (loadv4i32 addr:$src))),
Expand Down
Loading

0 comments on commit cdb692e

Please sign in to comment.