diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index cd56529bfa0fd..e0679f5f27d8c 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -49959,18 +49959,17 @@ static SDValue combineLoad(SDNode *N, SelectionDAG &DAG, SDValue Ptr = Ld->getBasePtr(); SDValue Chain = Ld->getChain(); for (SDNode *User : Chain->uses()) { - if (User != N && + auto *UserLd = dyn_cast(User); + if (User != N && UserLd && (User->getOpcode() == X86ISD::SUBV_BROADCAST_LOAD || User->getOpcode() == X86ISD::VBROADCAST_LOAD || ISD::isNormalLoad(User)) && - cast(User)->getChain() == Chain && - !User->hasAnyUseOfValue(1) && + UserLd->getChain() == Chain && !User->hasAnyUseOfValue(1) && User->getValueSizeInBits(0).getFixedValue() > RegVT.getFixedSizeInBits()) { if (User->getOpcode() == X86ISD::SUBV_BROADCAST_LOAD && - cast(User)->getBasePtr() == Ptr && - cast(User)->getMemoryVT().getSizeInBits() == - MemVT.getSizeInBits()) { + UserLd->getBasePtr() == Ptr && + UserLd->getMemoryVT().getSizeInBits() == MemVT.getSizeInBits()) { SDValue Extract = extractSubVector(SDValue(User, 0), 0, DAG, SDLoc(N), RegVT.getSizeInBits()); Extract = DAG.getBitcast(RegVT, Extract); @@ -49989,7 +49988,7 @@ static SDValue combineLoad(SDNode *N, SelectionDAG &DAG, // See if we are loading a constant that matches in the lower // bits of a longer constant (but from a different constant pool ptr). EVT UserVT = User->getValueType(0); - SDValue UserPtr = cast(User)->getBasePtr(); + SDValue UserPtr = UserLd->getBasePtr(); const Constant *LdC = getTargetConstantFromBasePtr(Ptr); const Constant *UserC = getTargetConstantFromBasePtr(UserPtr); if (LdC && UserC && UserPtr != Ptr) {