Skip to content

Commit

Permalink
Recommit r344877 "[X86] Stop promoting integer loads to vXi64"
Browse files Browse the repository at this point in the history
I've included a fix to DAGCombiner::ForwardStoreValueToDirectLoad that I believe will prevent the previous miscompile.

Original commit message:

Theoretically this was done to simplify the amount of isel patterns that were needed. But it also meant a substantial number of our isel patterns have to match an explicit bitcast. By making the vXi32/vXi16/vXi8 types legal for loads, DAG combiner should be able to change the load type to remove the bitcast.

I had to add some additional plain load instruction patterns and a few other special cases, but overall the isel table has reduced in size by ~12000 bytes. So it looks like this promotion was hurting us more than helping.

I still have one crash in vector-trunc.ll that I'm hoping @RKSimon can help with. It seems to relate to using getTargetConstantFromNode on a load that was shrunk due to an extract_subvector combine after the constant pool entry was created. So we end up decoding more mask elements than the load size.

I'm hoping this patch will simplify the number of patterns needed to remove the and/or/xor promotion.

Reviewers: RKSimon, spatel

Reviewed By: RKSimon

Subscribers: llvm-commits, RKSimon

Differential Revision: https://reviews.llvm.org/D53306

llvm-svn: 344965
  • Loading branch information
topperc committed Oct 22, 2018
1 parent 21a62e2 commit c8e183f
Show file tree
Hide file tree
Showing 16 changed files with 712 additions and 590 deletions.
3 changes: 2 additions & 1 deletion llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
Expand Up @@ -12897,7 +12897,8 @@ SDValue DAGCombiner::ForwardStoreValueToDirectLoad(LoadSDNode *LD) {
if (!isTypeLegal(LDMemType))
continue;
if (STMemType != LDMemType) {
if (numVectorEltsOrZero(STMemType) == numVectorEltsOrZero(LDMemType) &&
// TODO: Support vectors? This requires extract_subvector/bitcast.
if (!STMemType.isVector() && !LDMemType.isVector() &&
STMemType.isInteger() && LDMemType.isInteger())
Val = DAG.getNode(ISD::TRUNCATE, SDLoc(LD), LDMemType, Val);
else
Expand Down
28 changes: 10 additions & 18 deletions llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
Expand Up @@ -2890,21 +2890,17 @@ MachineSDNode *X86DAGToDAGISel::emitPCMPISTR(unsigned ROpc, unsigned MOpc,
const ConstantInt *Val = cast<ConstantSDNode>(Imm)->getConstantIntValue();
Imm = CurDAG->getTargetConstant(*Val, SDLoc(Node), Imm.getValueType());

// If there is a load, it will be behind a bitcast. We don't need to check
// alignment on this load.
// Try to fold a load. No need to check alignment.
SDValue Tmp0, Tmp1, Tmp2, Tmp3, Tmp4;
if (MayFoldLoad && N1->getOpcode() == ISD::BITCAST && N1->hasOneUse() &&
tryFoldLoad(Node, N1.getNode(), N1.getOperand(0), Tmp0, Tmp1, Tmp2,
Tmp3, Tmp4)) {
SDValue Load = N1.getOperand(0);
if (MayFoldLoad && tryFoldLoad(Node, N1, Tmp0, Tmp1, Tmp2, Tmp3, Tmp4)) {
SDValue Ops[] = { N0, Tmp0, Tmp1, Tmp2, Tmp3, Tmp4, Imm,
Load.getOperand(0) };
N1.getOperand(0) };
SDVTList VTs = CurDAG->getVTList(VT, MVT::i32, MVT::Other);
MachineSDNode *CNode = CurDAG->getMachineNode(MOpc, dl, VTs, Ops);
// Update the chain.
ReplaceUses(Load.getValue(1), SDValue(CNode, 2));
ReplaceUses(N1.getValue(1), SDValue(CNode, 2));
// Record the mem-refs
CurDAG->setNodeMemRefs(CNode, {cast<LoadSDNode>(Load)->getMemOperand()});
CurDAG->setNodeMemRefs(CNode, {cast<LoadSDNode>(N1)->getMemOperand()});
return CNode;
}

Expand All @@ -2927,22 +2923,18 @@ MachineSDNode *X86DAGToDAGISel::emitPCMPESTR(unsigned ROpc, unsigned MOpc,
const ConstantInt *Val = cast<ConstantSDNode>(Imm)->getConstantIntValue();
Imm = CurDAG->getTargetConstant(*Val, SDLoc(Node), Imm.getValueType());

// If there is a load, it will be behind a bitcast. We don't need to check
// alignment on this load.
// Try to fold a load. No need to check alignment.
SDValue Tmp0, Tmp1, Tmp2, Tmp3, Tmp4;
if (MayFoldLoad && N2->getOpcode() == ISD::BITCAST && N2->hasOneUse() &&
tryFoldLoad(Node, N2.getNode(), N2.getOperand(0), Tmp0, Tmp1, Tmp2,
Tmp3, Tmp4)) {
SDValue Load = N2.getOperand(0);
if (MayFoldLoad && tryFoldLoad(Node, N2, Tmp0, Tmp1, Tmp2, Tmp3, Tmp4)) {
SDValue Ops[] = { N0, Tmp0, Tmp1, Tmp2, Tmp3, Tmp4, Imm,
Load.getOperand(0), InFlag };
N2.getOperand(0), InFlag };
SDVTList VTs = CurDAG->getVTList(VT, MVT::i32, MVT::Other, MVT::Glue);
MachineSDNode *CNode = CurDAG->getMachineNode(MOpc, dl, VTs, Ops);
InFlag = SDValue(CNode, 3);
// Update the chain.
ReplaceUses(Load.getValue(1), SDValue(CNode, 2));
ReplaceUses(N2.getValue(1), SDValue(CNode, 2));
// Record the mem-refs
CurDAG->setNodeMemRefs(CNode, {cast<LoadSDNode>(Load)->getMemOperand()});
CurDAG->setNodeMemRefs(CNode, {cast<LoadSDNode>(N2)->getMemOperand()});
return CNode;
}

Expand Down
28 changes: 7 additions & 21 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
Expand Up @@ -869,11 +869,6 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Custom);
}

// Promote v16i8, v8i16, v4i32 load, select, and, or, xor to v2i64.
for (auto VT : { MVT::v16i8, MVT::v8i16, MVT::v4i32 }) {
setOperationPromotedToType(ISD::LOAD, VT, MVT::v2i64);
}

// Custom lower v2i64 and v2f64 selects.
setOperationAction(ISD::SELECT, MVT::v2f64, Custom);
setOperationAction(ISD::SELECT, MVT::v2i64, Custom);
Expand Down Expand Up @@ -1178,11 +1173,6 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
if (HasInt256)
setOperationAction(ISD::VSELECT, MVT::v32i8, Legal);

// Promote v32i8, v16i16, v8i32 select, and, or, xor to v4i64.
for (auto VT : { MVT::v32i8, MVT::v16i16, MVT::v8i32 }) {
setOperationPromotedToType(ISD::LOAD, VT, MVT::v4i64);
}

if (HasInt256) {
// Custom legalize 2x32 to get a little better code.
setOperationAction(ISD::MGATHER, MVT::v2f32, Custom);
Expand Down Expand Up @@ -1419,10 +1409,6 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
setOperationAction(ISD::MGATHER, VT, Custom);
setOperationAction(ISD::MSCATTER, VT, Custom);
}
for (auto VT : { MVT::v64i8, MVT::v32i16, MVT::v16i32 }) {
setOperationPromotedToType(ISD::LOAD, VT, MVT::v8i64);
}

// Need to custom split v32i16/v64i8 bitcasts.
if (!Subtarget.hasBWI()) {
setOperationAction(ISD::BITCAST, MVT::v32i16, Custom);
Expand Down Expand Up @@ -5539,7 +5525,7 @@ static const Constant *getTargetConstantFromNode(SDValue Op) {
if (!CNode || CNode->isMachineConstantPoolEntry() || CNode->getOffset() != 0)
return nullptr;

return dyn_cast<Constant>(CNode->getConstVal());
return CNode->getConstVal();
}

// Extract raw constant bits from constant pools.
Expand Down Expand Up @@ -6045,7 +6031,7 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero,
break;
}
if (auto *C = getTargetConstantFromNode(MaskNode)) {
DecodeVPERMILPMask(C, MaskEltSize, Mask);
DecodeVPERMILPMask(C, MaskEltSize, VT.getSizeInBits(), Mask);
break;
}
return false;
Expand All @@ -6062,7 +6048,7 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero,
break;
}
if (auto *C = getTargetConstantFromNode(MaskNode)) {
DecodePSHUFBMask(C, Mask);
DecodePSHUFBMask(C, VT.getSizeInBits(), Mask);
break;
}
return false;
Expand Down Expand Up @@ -6124,7 +6110,7 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero,
break;
}
if (auto *C = getTargetConstantFromNode(MaskNode)) {
DecodeVPERMIL2PMask(C, CtrlImm, MaskEltSize, Mask);
DecodeVPERMIL2PMask(C, CtrlImm, MaskEltSize, VT.getSizeInBits(), Mask);
break;
}
}
Expand All @@ -6141,7 +6127,7 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero,
break;
}
if (auto *C = getTargetConstantFromNode(MaskNode)) {
DecodeVPPERMMask(C, Mask);
DecodeVPPERMMask(C, VT.getSizeInBits(), Mask);
break;
}
return false;
Expand All @@ -6158,7 +6144,7 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero,
break;
}
if (auto *C = getTargetConstantFromNode(MaskNode)) {
DecodeVPERMVMask(C, MaskEltSize, Mask);
DecodeVPERMVMask(C, MaskEltSize, VT.getSizeInBits(), Mask);
break;
}
return false;
Expand All @@ -6172,7 +6158,7 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero,
Ops.push_back(N->getOperand(2));
SDValue MaskNode = N->getOperand(1);
if (auto *C = getTargetConstantFromNode(MaskNode)) {
DecodeVPERMV3Mask(C, MaskEltSize, Mask);
DecodeVPERMV3Mask(C, MaskEltSize, VT.getSizeInBits(), Mask);
break;
}
return false;
Expand Down

0 comments on commit c8e183f

Please sign in to comment.