Skip to content

Commit

Permalink
[X86] Directly emit VBROADCAST_LOAD from constant pool in lowerBuildV…
Browse files Browse the repository at this point in the history
…ectorAsBroadcast

Also add a DAG combine to combine different sized broadcasts from
constant pool to avoid a regression.

Differential Revision: https://reviews.llvm.org/D75412
  • Loading branch information
topperc committed Mar 3, 2020
1 parent 22dd235 commit 56cd3bc
Showing 1 changed file with 58 additions and 17 deletions.
75 changes: 58 additions & 17 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
Expand Up @@ -8558,12 +8558,14 @@ static SDValue lowerBuildVectorAsBroadcast(BuildVectorSDNode *BVOp,
unsigned Repeat = VT.getSizeInBits() / SplatBitSize;

unsigned Alignment = cast<ConstantPoolSDNode>(CP)->getAlignment();
Ld = DAG.getLoad(
CVT, dl, DAG.getEntryNode(), CP,
MachinePointerInfo::getConstantPool(DAG.getMachineFunction()),
Alignment);
SDValue Brdcst = DAG.getNode(X86ISD::VBROADCAST, dl,
MVT::getVectorVT(CVT, Repeat), Ld);
SDVTList Tys =
DAG.getVTList(MVT::getVectorVT(CVT, Repeat), MVT::Other);
SDValue Ops[] = {DAG.getEntryNode(), CP};
MachinePointerInfo MPI =
MachinePointerInfo::getConstantPool(DAG.getMachineFunction());
SDValue Brdcst = DAG.getMemIntrinsicNode(
X86ISD::VBROADCAST_LOAD, dl, Tys, Ops, CVT, MPI, Alignment,
MachineMemOperand::MOLoad);
return DAG.getBitcast(VT, Brdcst);
} else if (SplatBitSize == 32 || SplatBitSize == 64) {
// Splatted value can fit in one FLOAT constant in constant pool.
Expand All @@ -8582,12 +8584,14 @@ static SDValue lowerBuildVectorAsBroadcast(BuildVectorSDNode *BVOp,
unsigned Repeat = VT.getSizeInBits() / SplatBitSize;

unsigned Alignment = cast<ConstantPoolSDNode>(CP)->getAlignment();
Ld = DAG.getLoad(
CVT, dl, DAG.getEntryNode(), CP,
MachinePointerInfo::getConstantPool(DAG.getMachineFunction()),
Alignment);
SDValue Brdcst = DAG.getNode(X86ISD::VBROADCAST, dl,
MVT::getVectorVT(CVT, Repeat), Ld);
SDVTList Tys =
DAG.getVTList(MVT::getVectorVT(CVT, Repeat), MVT::Other);
SDValue Ops[] = {DAG.getEntryNode(), CP};
MachinePointerInfo MPI =
MachinePointerInfo::getConstantPool(DAG.getMachineFunction());
SDValue Brdcst = DAG.getMemIntrinsicNode(
X86ISD::VBROADCAST_LOAD, dl, Tys, Ops, CVT, MPI, Alignment,
MachineMemOperand::MOLoad);
return DAG.getBitcast(VT, Brdcst);
} else if (SplatBitSize > 64) {
// Load the vector of constants and broadcast it.
Expand Down Expand Up @@ -8667,12 +8671,13 @@ static SDValue lowerBuildVectorAsBroadcast(BuildVectorSDNode *BVOp,
SDValue CP =
DAG.getConstantPool(C, TLI.getPointerTy(DAG.getDataLayout()));
unsigned Alignment = cast<ConstantPoolSDNode>(CP)->getAlignment();
Ld = DAG.getLoad(
CVT, dl, DAG.getEntryNode(), CP,
MachinePointerInfo::getConstantPool(DAG.getMachineFunction()),
Alignment);

return DAG.getNode(X86ISD::VBROADCAST, dl, VT, Ld);
SDVTList Tys = DAG.getVTList(VT, MVT::Other);
SDValue Ops[] = {DAG.getEntryNode(), CP};
MachinePointerInfo MPI =
MachinePointerInfo::getConstantPool(DAG.getMachineFunction());
return DAG.getMemIntrinsicNode(X86ISD::VBROADCAST_LOAD, dl, Tys, Ops, CVT,
MPI, Alignment, MachineMemOperand::MOLoad);
}
}

Expand Down Expand Up @@ -46828,6 +46833,41 @@ static SDValue combineFP_EXTEND(SDNode *N, SelectionDAG &DAG,
return DAG.getNode(ISD::FP_EXTEND, dl, VT, Cvt);
}

// Try to find a larger VBROADCAST_LOAD that we can extract from. Limit this to
// cases where the loads have the same input chain and the output chains are
// unused. This avoids any memory ordering issues.
static SDValue combineVBROADCAST_LOAD(SDNode *N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI) {
// Only do this if the chain result is unused.
if (N->hasAnyUseOfValue(1))
return SDValue();

auto *MemIntrin = cast<MemIntrinsicSDNode>(N);

SDValue Ptr = MemIntrin->getBasePtr();
SDValue Chain = MemIntrin->getChain();
EVT VT = N->getSimpleValueType(0);
EVT MemVT = MemIntrin->getMemoryVT();

// Look at other users of our base pointer and try to find a wider broadcast.
// The input chain and the size of the memory VT must match.
for (SDNode *User : Ptr->uses())
if (User != N && User->getOpcode() == X86ISD::VBROADCAST_LOAD &&
cast<MemIntrinsicSDNode>(User)->getBasePtr() == Ptr &&
cast<MemIntrinsicSDNode>(User)->getChain() == Chain &&
cast<MemIntrinsicSDNode>(User)->getMemoryVT().getSizeInBits() ==
MemVT.getSizeInBits() &&
!User->hasAnyUseOfValue(1) &&
User->getValueSizeInBits(0) > VT.getSizeInBits()) {
SDValue Extract = extractSubVector(SDValue(User, 0), 0, DAG, SDLoc(N),
VT.getSizeInBits());
Extract = DAG.getBitcast(VT, Extract);
return DCI.CombineTo(N, Extract, SDValue(User, 1));
}

return SDValue();
}

static SDValue combineFP_ROUND(SDNode *N, SelectionDAG &DAG,
const X86Subtarget &Subtarget) {
if (!Subtarget.hasF16C() || Subtarget.useSoftFloat())
Expand Down Expand Up @@ -47027,6 +47067,7 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
case ISD::STRICT_FP_EXTEND:
case ISD::FP_EXTEND: return combineFP_EXTEND(N, DAG, Subtarget);
case ISD::FP_ROUND: return combineFP_ROUND(N, DAG, Subtarget);
case X86ISD::VBROADCAST_LOAD: return combineVBROADCAST_LOAD(N, DAG, DCI);
}

return SDValue();
Expand Down

0 comments on commit 56cd3bc

Please sign in to comment.