diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index cef768e95dda2e..4ca7f803132f84 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -8558,12 +8558,14 @@ static SDValue lowerBuildVectorAsBroadcast(BuildVectorSDNode *BVOp, unsigned Repeat = VT.getSizeInBits() / SplatBitSize; unsigned Alignment = cast(CP)->getAlignment(); - Ld = DAG.getLoad( - CVT, dl, DAG.getEntryNode(), CP, - MachinePointerInfo::getConstantPool(DAG.getMachineFunction()), - Alignment); - SDValue Brdcst = DAG.getNode(X86ISD::VBROADCAST, dl, - MVT::getVectorVT(CVT, Repeat), Ld); + SDVTList Tys = + DAG.getVTList(MVT::getVectorVT(CVT, Repeat), MVT::Other); + SDValue Ops[] = {DAG.getEntryNode(), CP}; + MachinePointerInfo MPI = + MachinePointerInfo::getConstantPool(DAG.getMachineFunction()); + SDValue Brdcst = DAG.getMemIntrinsicNode( + X86ISD::VBROADCAST_LOAD, dl, Tys, Ops, CVT, MPI, Alignment, + MachineMemOperand::MOLoad); return DAG.getBitcast(VT, Brdcst); } else if (SplatBitSize == 32 || SplatBitSize == 64) { // Splatted value can fit in one FLOAT constant in constant pool. 
@@ -8582,12 +8584,14 @@ static SDValue lowerBuildVectorAsBroadcast(BuildVectorSDNode *BVOp, unsigned Repeat = VT.getSizeInBits() / SplatBitSize; unsigned Alignment = cast(CP)->getAlignment(); - Ld = DAG.getLoad( - CVT, dl, DAG.getEntryNode(), CP, - MachinePointerInfo::getConstantPool(DAG.getMachineFunction()), - Alignment); - SDValue Brdcst = DAG.getNode(X86ISD::VBROADCAST, dl, - MVT::getVectorVT(CVT, Repeat), Ld); + SDVTList Tys = + DAG.getVTList(MVT::getVectorVT(CVT, Repeat), MVT::Other); + SDValue Ops[] = {DAG.getEntryNode(), CP}; + MachinePointerInfo MPI = + MachinePointerInfo::getConstantPool(DAG.getMachineFunction()); + SDValue Brdcst = DAG.getMemIntrinsicNode( + X86ISD::VBROADCAST_LOAD, dl, Tys, Ops, CVT, MPI, Alignment, + MachineMemOperand::MOLoad); return DAG.getBitcast(VT, Brdcst); } else if (SplatBitSize > 64) { // Load the vector of constants and broadcast it. @@ -8667,12 +8671,13 @@ static SDValue lowerBuildVectorAsBroadcast(BuildVectorSDNode *BVOp, SDValue CP = DAG.getConstantPool(C, TLI.getPointerTy(DAG.getDataLayout())); unsigned Alignment = cast(CP)->getAlignment(); - Ld = DAG.getLoad( - CVT, dl, DAG.getEntryNode(), CP, - MachinePointerInfo::getConstantPool(DAG.getMachineFunction()), - Alignment); - return DAG.getNode(X86ISD::VBROADCAST, dl, VT, Ld); + SDVTList Tys = DAG.getVTList(VT, MVT::Other); + SDValue Ops[] = {DAG.getEntryNode(), CP}; + MachinePointerInfo MPI = + MachinePointerInfo::getConstantPool(DAG.getMachineFunction()); + return DAG.getMemIntrinsicNode(X86ISD::VBROADCAST_LOAD, dl, Tys, Ops, CVT, + MPI, Alignment, MachineMemOperand::MOLoad); } } @@ -46828,6 +46833,41 @@ static SDValue combineFP_EXTEND(SDNode *N, SelectionDAG &DAG, return DAG.getNode(ISD::FP_EXTEND, dl, VT, Cvt); } +// Try to find a larger VBROADCAST_LOAD that we can extract from. Limit this to +// cases where the loads have the same input chain and the output chains are +// unused. This avoids any memory ordering issues. 
+static SDValue combineVBROADCAST_LOAD(SDNode *N, SelectionDAG &DAG,
+                                      TargetLowering::DAGCombinerInfo &DCI) {
+  // Only do this if N's chain result (value 1) is unused.
+  if (N->hasAnyUseOfValue(1))
+    return SDValue();
+
+  auto *MemIntrin = cast<MemIntrinsicSDNode>(N);
+
+  SDValue Ptr = MemIntrin->getBasePtr();
+  SDValue Chain = MemIntrin->getChain();
+  EVT VT = N->getSimpleValueType(0);
+  EVT MemVT = MemIntrin->getMemoryVT();
+
+  // Look at other users of our base pointer and try to find a wider broadcast.
+  // The input chain and the size of the memory VT must match.
+  for (SDNode *User : Ptr->uses())
+    if (User != N && User->getOpcode() == X86ISD::VBROADCAST_LOAD &&
+        cast<MemIntrinsicSDNode>(User)->getBasePtr() == Ptr &&
+        cast<MemIntrinsicSDNode>(User)->getChain() == Chain &&
+        cast<MemIntrinsicSDNode>(User)->getMemoryVT().getSizeInBits() ==
+            MemVT.getSizeInBits() &&
+        !User->hasAnyUseOfValue(1) &&
+        User->getValueSizeInBits(0) > VT.getSizeInBits()) {
+      SDValue Extract = extractSubVector(SDValue(User, 0), 0, DAG, SDLoc(N),
+                                         VT.getSizeInBits());
+      Extract = DAG.getBitcast(VT, Extract);
+      return DCI.CombineTo(N, Extract, SDValue(User, 1));
+    }
+
+  return SDValue();
+}
+
 static SDValue combineFP_ROUND(SDNode *N, SelectionDAG &DAG,
                                const X86Subtarget &Subtarget) {
   if (!Subtarget.hasF16C() || Subtarget.useSoftFloat())
@@ -47027,6 +47067,7 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
   case ISD::STRICT_FP_EXTEND:
   case ISD::FP_EXTEND: return combineFP_EXTEND(N, DAG, Subtarget);
   case ISD::FP_ROUND: return combineFP_ROUND(N, DAG, Subtarget);
+  case X86ISD::VBROADCAST_LOAD: return combineVBROADCAST_LOAD(N, DAG, DCI);
   }
 
   return SDValue();