Skip to content

Commit

Permalink
[SelectionDAG][X86] Explicitly store the scale in the gather/scatter …
Browse files Browse the repository at this point in the history
…ISD nodes

Currently we infer the scale at isel time by checking whether the base is a constant 0. If it is, we assume the scale is 1; otherwise we take it from the element size of the passthru or stored value. This seems a little weird, and I think it makes more sense to make the scale explicit in the DAG rather than doing tricky things in the backend.

Most of this patch is just making sure we copy the scale around everywhere.

Differential Revision: https://reviews.llvm.org/D40055

llvm-svn: 322210
  • Loading branch information
topperc committed Jan 10, 2018
1 parent 4c3ea80 commit af4eb17
Show file tree
Hide file tree
Showing 10 changed files with 229 additions and 42 deletions.
5 changes: 3 additions & 2 deletions llvm/include/llvm/CodeGen/SelectionDAGNodes.h
Expand Up @@ -2120,13 +2120,14 @@ class MaskedGatherScatterSDNode : public MemSDNode {
: MemSDNode(NodeTy, Order, dl, VTs, MemVT, MMO) {}

// In both nodes the mask is operand 2, the base pointer operand 3, the index operand 4:
// MaskedGatherSDNode (Chain, src0, mask, base, index), src0 is a passthru value
// MaskedScatterSDNode (Chain, value, mask, base, index)
// MaskedGatherSDNode (Chain, passthru, mask, base, index, scale)
// MaskedScatterSDNode (Chain, value, mask, base, index, scale)
// Mask is a vector of i1 elements
// Returns operand 3, the "base" slot in the operand layout listed above.
const SDValue &getBasePtr() const { return getOperand(3); }
// Returns operand 4, the "index" slot in the operand layout listed above.
const SDValue &getIndex() const { return getOperand(4); }
// Returns operand 2, the "mask" slot in the operand layout listed above.
const SDValue &getMask() const { return getOperand(2); }
// Returns operand 1: the passthru value for a gather, the stored value for a
// scatter (per the operand layout listed above).
const SDValue &getValue() const { return getOperand(1); }
// Returns operand 5, the "scale" slot in the operand layout listed above.
const SDValue &getScale() const { return getOperand(5); }

static bool classof(const SDNode *N) {
return N->getOpcode() == ISD::MGATHER ||
Expand Down
14 changes: 8 additions & 6 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
Expand Up @@ -6726,6 +6726,7 @@ SDValue DAGCombiner::visitMSCATTER(SDNode *N) {
SDValue DataLo, DataHi;
std::tie(DataLo, DataHi) = DAG.SplitVector(Data, DL);

SDValue Scale = MSC->getScale();
SDValue BasePtr = MSC->getBasePtr();
SDValue IndexLo, IndexHi;
std::tie(IndexLo, IndexHi) = DAG.SplitVector(MSC->getIndex(), DL);
Expand All @@ -6735,11 +6736,11 @@ SDValue DAGCombiner::visitMSCATTER(SDNode *N) {
MachineMemOperand::MOStore, LoMemVT.getStoreSize(),
Alignment, MSC->getAAInfo(), MSC->getRanges());

SDValue OpsLo[] = { Chain, DataLo, MaskLo, BasePtr, IndexLo };
SDValue OpsLo[] = { Chain, DataLo, MaskLo, BasePtr, IndexLo, Scale };
Lo = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), DataLo.getValueType(),
DL, OpsLo, MMO);

SDValue OpsHi[] = {Chain, DataHi, MaskHi, BasePtr, IndexHi};
SDValue OpsHi[] = { Chain, DataHi, MaskHi, BasePtr, IndexHi, Scale };
Hi = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), DataHi.getValueType(),
DL, OpsHi, MMO);

Expand Down Expand Up @@ -6859,6 +6860,7 @@ SDValue DAGCombiner::visitMGATHER(SDNode *N) {
EVT LoMemVT, HiMemVT;
std::tie(LoMemVT, HiMemVT) = DAG.GetSplitDestVTs(MemoryVT);

SDValue Scale = MGT->getScale();
SDValue BasePtr = MGT->getBasePtr();
SDValue Index = MGT->getIndex();
SDValue IndexLo, IndexHi;
Expand All @@ -6869,13 +6871,13 @@ SDValue DAGCombiner::visitMGATHER(SDNode *N) {
MachineMemOperand::MOLoad, LoMemVT.getStoreSize(),
Alignment, MGT->getAAInfo(), MGT->getRanges());

SDValue OpsLo[] = { Chain, Src0Lo, MaskLo, BasePtr, IndexLo };
SDValue OpsLo[] = { Chain, Src0Lo, MaskLo, BasePtr, IndexLo, Scale };
Lo = DAG.getMaskedGather(DAG.getVTList(LoVT, MVT::Other), LoVT, DL, OpsLo,
MMO);
MMO);

SDValue OpsHi[] = {Chain, Src0Hi, MaskHi, BasePtr, IndexHi};
SDValue OpsHi[] = { Chain, Src0Hi, MaskHi, BasePtr, IndexHi, Scale };
Hi = DAG.getMaskedGather(DAG.getVTList(HiVT, MVT::Other), HiVT, DL, OpsHi,
MMO);
MMO);

AddToWorklist(Lo.getNode());
AddToWorklist(Hi.getNode());
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
Expand Up @@ -501,7 +501,7 @@ SDValue DAGTypeLegalizer::PromoteIntRes_MGATHER(MaskedGatherSDNode *N) {

SDLoc dl(N);
SDValue Ops[] = {N->getChain(), ExtSrc0, N->getMask(), N->getBasePtr(),
N->getIndex()};
N->getIndex(), N->getScale() };
SDValue Res = DAG.getMaskedGather(DAG.getVTList(NVT, MVT::Other),
N->getMemoryVT(), dl, Ops,
N->getMemOperand());
Expand Down
22 changes: 14 additions & 8 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
Expand Up @@ -1238,6 +1238,7 @@ void DAGTypeLegalizer::SplitVecRes_MGATHER(MaskedGatherSDNode *MGT,
SDValue Mask = MGT->getMask();
SDValue Src0 = MGT->getValue();
SDValue Index = MGT->getIndex();
SDValue Scale = MGT->getScale();
unsigned Alignment = MGT->getOriginalAlignment();

// Split Mask operand
Expand Down Expand Up @@ -1269,11 +1270,11 @@ void DAGTypeLegalizer::SplitVecRes_MGATHER(MaskedGatherSDNode *MGT,
MachineMemOperand::MOLoad, LoMemVT.getStoreSize(),
Alignment, MGT->getAAInfo(), MGT->getRanges());

SDValue OpsLo[] = {Ch, Src0Lo, MaskLo, Ptr, IndexLo};
SDValue OpsLo[] = {Ch, Src0Lo, MaskLo, Ptr, IndexLo, Scale};
Lo = DAG.getMaskedGather(DAG.getVTList(LoVT, MVT::Other), LoVT, dl, OpsLo,
MMO);

SDValue OpsHi[] = {Ch, Src0Hi, MaskHi, Ptr, IndexHi};
SDValue OpsHi[] = {Ch, Src0Hi, MaskHi, Ptr, IndexHi, Scale};
Hi = DAG.getMaskedGather(DAG.getVTList(HiVT, MVT::Other), HiVT, dl, OpsHi,
MMO);

Expand Down Expand Up @@ -1816,6 +1817,7 @@ SDValue DAGTypeLegalizer::SplitVecOp_MGATHER(MaskedGatherSDNode *MGT,
SDValue Ch = MGT->getChain();
SDValue Ptr = MGT->getBasePtr();
SDValue Index = MGT->getIndex();
SDValue Scale = MGT->getScale();
SDValue Mask = MGT->getMask();
SDValue Src0 = MGT->getValue();
unsigned Alignment = MGT->getOriginalAlignment();
Expand Down Expand Up @@ -1848,7 +1850,7 @@ SDValue DAGTypeLegalizer::SplitVecOp_MGATHER(MaskedGatherSDNode *MGT,
MachineMemOperand::MOLoad, LoMemVT.getStoreSize(),
Alignment, MGT->getAAInfo(), MGT->getRanges());

SDValue OpsLo[] = {Ch, Src0Lo, MaskLo, Ptr, IndexLo};
SDValue OpsLo[] = {Ch, Src0Lo, MaskLo, Ptr, IndexLo, Scale};
SDValue Lo = DAG.getMaskedGather(DAG.getVTList(LoVT, MVT::Other), LoVT, dl,
OpsLo, MMO);

Expand All @@ -1858,7 +1860,7 @@ SDValue DAGTypeLegalizer::SplitVecOp_MGATHER(MaskedGatherSDNode *MGT,
Alignment, MGT->getAAInfo(),
MGT->getRanges());

SDValue OpsHi[] = {Ch, Src0Hi, MaskHi, Ptr, IndexHi};
SDValue OpsHi[] = {Ch, Src0Hi, MaskHi, Ptr, IndexHi, Scale};
SDValue Hi = DAG.getMaskedGather(DAG.getVTList(HiVT, MVT::Other), HiVT, dl,
OpsHi, MMO);

Expand Down Expand Up @@ -1941,6 +1943,7 @@ SDValue DAGTypeLegalizer::SplitVecOp_MSCATTER(MaskedScatterSDNode *N,
SDValue Ptr = N->getBasePtr();
SDValue Mask = N->getMask();
SDValue Index = N->getIndex();
SDValue Scale = N->getScale();
SDValue Data = N->getValue();
EVT MemoryVT = N->getMemoryVT();
unsigned Alignment = N->getOriginalAlignment();
Expand Down Expand Up @@ -1976,7 +1979,7 @@ SDValue DAGTypeLegalizer::SplitVecOp_MSCATTER(MaskedScatterSDNode *N,
MachineMemOperand::MOStore, LoMemVT.getStoreSize(),
Alignment, N->getAAInfo(), N->getRanges());

SDValue OpsLo[] = {Ch, DataLo, MaskLo, Ptr, IndexLo};
SDValue OpsLo[] = {Ch, DataLo, MaskLo, Ptr, IndexLo, Scale};
Lo = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), DataLo.getValueType(),
DL, OpsLo, MMO);

Expand All @@ -1988,7 +1991,7 @@ SDValue DAGTypeLegalizer::SplitVecOp_MSCATTER(MaskedScatterSDNode *N,
// The order of the Scatter operation after split is well defined. The "Hi"
// part comes after the "Lo". So these two operations should be chained one
// after another.
SDValue OpsHi[] = {Lo, DataHi, MaskHi, Ptr, IndexHi};
SDValue OpsHi[] = {Lo, DataHi, MaskHi, Ptr, IndexHi, Scale};
return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), DataHi.getValueType(),
DL, OpsHi, MMO);
}
Expand Down Expand Up @@ -2954,6 +2957,7 @@ SDValue DAGTypeLegalizer::WidenVecRes_MGATHER(MaskedGatherSDNode *N) {
SDValue Mask = N->getMask();
EVT MaskVT = Mask.getValueType();
SDValue Src0 = GetWidenedVector(N->getValue());
SDValue Scale = N->getScale();
unsigned NumElts = WideVT.getVectorNumElements();
SDLoc dl(N);

Expand All @@ -2969,7 +2973,7 @@ SDValue DAGTypeLegalizer::WidenVecRes_MGATHER(MaskedGatherSDNode *N) {
Index.getValueType().getScalarType(),
NumElts);
Index = ModifyToType(Index, WideIndexVT);
SDValue Ops[] = { N->getChain(), Src0, Mask, N->getBasePtr(), Index };
SDValue Ops[] = { N->getChain(), Src0, Mask, N->getBasePtr(), Index, Scale };
SDValue Res = DAG.getMaskedGather(DAG.getVTList(WideVT, MVT::Other),
N->getMemoryVT(), dl, Ops,
N->getMemOperand());
Expand Down Expand Up @@ -3593,6 +3597,7 @@ SDValue DAGTypeLegalizer::WidenVecOp_MSCATTER(SDNode *N, unsigned OpNo) {
SDValue DataOp = MSC->getValue();
SDValue Mask = MSC->getMask();
EVT MaskVT = Mask.getValueType();
SDValue Scale = MSC->getScale();

// Widen the value.
SDValue WideVal = GetWidenedVector(DataOp);
Expand All @@ -3612,7 +3617,8 @@ SDValue DAGTypeLegalizer::WidenVecOp_MSCATTER(SDNode *N, unsigned OpNo) {
NumElts);
Index = ModifyToType(Index, WideIndexVT);

SDValue Ops[] = {MSC->getChain(), WideVal, Mask, MSC->getBasePtr(), Index};
SDValue Ops[] = {MSC->getChain(), WideVal, Mask, MSC->getBasePtr(), Index,
Scale};
return DAG.getMaskedScatter(DAG.getVTList(MVT::Other),
MSC->getMemoryVT(), dl, Ops,
MSC->getMemOperand());
Expand Down
10 changes: 8 additions & 2 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
Expand Up @@ -6208,7 +6208,7 @@ SDValue SelectionDAG::getMaskedStore(SDValue Chain, const SDLoc &dl,
SDValue SelectionDAG::getMaskedGather(SDVTList VTs, EVT VT, const SDLoc &dl,
ArrayRef<SDValue> Ops,
MachineMemOperand *MMO) {
assert(Ops.size() == 5 && "Incompatible number of operands");
assert(Ops.size() == 6 && "Incompatible number of operands");

FoldingSetNodeID ID;
AddNodeIDNode(ID, ISD::MGATHER, VTs, Ops);
Expand All @@ -6234,6 +6234,9 @@ SDValue SelectionDAG::getMaskedGather(SDVTList VTs, EVT VT, const SDLoc &dl,
assert(N->getIndex().getValueType().getVectorNumElements() ==
N->getValueType(0).getVectorNumElements() &&
"Vector width mismatch between index and data");
assert(isa<ConstantSDNode>(N->getScale()) &&
cast<ConstantSDNode>(N->getScale())->getAPIntValue().isPowerOf2() &&
"Scale should be a constant power of 2");

CSEMap.InsertNode(N, IP);
InsertNode(N);
Expand All @@ -6245,7 +6248,7 @@ SDValue SelectionDAG::getMaskedGather(SDVTList VTs, EVT VT, const SDLoc &dl,
SDValue SelectionDAG::getMaskedScatter(SDVTList VTs, EVT VT, const SDLoc &dl,
ArrayRef<SDValue> Ops,
MachineMemOperand *MMO) {
assert(Ops.size() == 5 && "Incompatible number of operands");
assert(Ops.size() == 6 && "Incompatible number of operands");

FoldingSetNodeID ID;
AddNodeIDNode(ID, ISD::MSCATTER, VTs, Ops);
Expand All @@ -6268,6 +6271,9 @@ SDValue SelectionDAG::getMaskedScatter(SDVTList VTs, EVT VT, const SDLoc &dl,
assert(N->getIndex().getValueType().getVectorNumElements() ==
N->getValue().getValueType().getVectorNumElements() &&
"Vector width mismatch between index and data");
assert(isa<ConstantSDNode>(N->getScale()) &&
cast<ConstantSDNode>(N->getScale())->getAPIntValue().isPowerOf2() &&
"Scale should be a constant power of 2");

CSEMap.InsertNode(N, IP);
InsertNode(N);
Expand Down
22 changes: 15 additions & 7 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
Expand Up @@ -3867,7 +3867,7 @@ void SelectionDAGBuilder::visitMaskedStore(const CallInst &I,
// extract the splat value and use it as a uniform base.
// In all other cases the function returns 'false'.
static bool getUniformBase(const Value* &Ptr, SDValue& Base, SDValue& Index,
SelectionDAGBuilder* SDB) {
SDValue &Scale, SelectionDAGBuilder* SDB) {
SelectionDAG& DAG = SDB->DAG;
LLVMContext &Context = *DAG.getContext();

Expand Down Expand Up @@ -3897,6 +3897,10 @@ static bool getUniformBase(const Value* &Ptr, SDValue& Base, SDValue& Index,
if (!SDB->findValue(Ptr) || !SDB->findValue(IndexVal))
return false;

const TargetLowering &TLI = DAG.getTargetLoweringInfo();
const DataLayout &DL = DAG.getDataLayout();
Scale = DAG.getTargetConstant(DL.getTypeAllocSize(GEP->getResultElementType()),
SDB->getCurSDLoc(), TLI.getPointerTy(DL));
Base = SDB->getValue(Ptr);
Index = SDB->getValue(IndexVal);

Expand Down Expand Up @@ -3926,19 +3930,21 @@ void SelectionDAGBuilder::visitMaskedScatter(const CallInst &I) {

SDValue Base;
SDValue Index;
SDValue Scale;
const Value *BasePtr = Ptr;
bool UniformBase = getUniformBase(BasePtr, Base, Index, this);
bool UniformBase = getUniformBase(BasePtr, Base, Index, Scale, this);

const Value *MemOpBasePtr = UniformBase ? BasePtr : nullptr;
MachineMemOperand *MMO = DAG.getMachineFunction().
getMachineMemOperand(MachinePointerInfo(MemOpBasePtr),
MachineMemOperand::MOStore, VT.getStoreSize(),
Alignment, AAInfo);
if (!UniformBase) {
Base = DAG.getTargetConstant(0, sdl, TLI.getPointerTy(DAG.getDataLayout()));
Base = DAG.getConstant(0, sdl, TLI.getPointerTy(DAG.getDataLayout()));
Index = getValue(Ptr);
Scale = DAG.getTargetConstant(1, sdl, TLI.getPointerTy(DAG.getDataLayout()));
}
SDValue Ops[] = { getRoot(), Src0, Mask, Base, Index };
SDValue Ops[] = { getRoot(), Src0, Mask, Base, Index, Scale };
SDValue Scatter = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), VT, sdl,
Ops, MMO);
DAG.setRoot(Scatter);
Expand Down Expand Up @@ -4025,8 +4031,9 @@ void SelectionDAGBuilder::visitMaskedGather(const CallInst &I) {
SDValue Root = DAG.getRoot();
SDValue Base;
SDValue Index;
SDValue Scale;
const Value *BasePtr = Ptr;
bool UniformBase = getUniformBase(BasePtr, Base, Index, this);
bool UniformBase = getUniformBase(BasePtr, Base, Index, Scale, this);
bool ConstantMemory = false;
if (UniformBase &&
AA && AA->pointsToConstantMemory(MemoryLocation(
Expand All @@ -4044,10 +4051,11 @@ void SelectionDAGBuilder::visitMaskedGather(const CallInst &I) {
Alignment, AAInfo, Ranges);

if (!UniformBase) {
Base = DAG.getTargetConstant(0, sdl, TLI.getPointerTy(DAG.getDataLayout()));
Base = DAG.getConstant(0, sdl, TLI.getPointerTy(DAG.getDataLayout()));
Index = getValue(Ptr);
Scale = DAG.getTargetConstant(1, sdl, TLI.getPointerTy(DAG.getDataLayout()));
}
SDValue Ops[] = { Root, Src0, Mask, Base, Index };
SDValue Ops[] = { Root, Src0, Mask, Base, Index, Scale };
SDValue Gather = DAG.getMaskedGather(DAG.getVTList(VT, MVT::Other), VT, sdl,
Ops, MMO);

Expand Down
18 changes: 9 additions & 9 deletions llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
Expand Up @@ -1508,6 +1508,12 @@ bool X86DAGToDAGISel::matchAddressBase(SDValue N, X86ISelAddressMode &AM) {
bool X86DAGToDAGISel::matchVectorAddress(SDValue N, X86ISelAddressMode &AM) {
// TODO: Support other operations.
switch (N.getOpcode()) {
case ISD::Constant: {
uint64_t Val = cast<ConstantSDNode>(N)->getSExtValue();
if (!foldOffsetIntoAddress(Val, AM))
return false;
break;
}
case X86ISD::Wrapper:
if (!matchWrapper(N, AM))
return false;
Expand All @@ -1523,7 +1529,7 @@ bool X86DAGToDAGISel::selectVectorAddr(SDNode *Parent, SDValue N, SDValue &Base,
X86ISelAddressMode AM;
auto *Mgs = cast<X86MaskedGatherScatterSDNode>(Parent);
AM.IndexReg = Mgs->getIndex();
AM.Scale = Mgs->getValue().getScalarValueSizeInBits() / 8;
AM.Scale = cast<ConstantSDNode>(Mgs->getScale())->getZExtValue();

unsigned AddrSpace = cast<MemSDNode>(Parent)->getPointerInfo().getAddrSpace();
// AddrSpace 256 -> GS, 257 -> FS, 258 -> SS.
Expand All @@ -1534,14 +1540,8 @@ bool X86DAGToDAGISel::selectVectorAddr(SDNode *Parent, SDValue N, SDValue &Base,
if (AddrSpace == 258)
AM.Segment = CurDAG->getRegister(X86::SS, MVT::i16);

// If Base is 0, the whole address is in index and the Scale is 1
if (isa<ConstantSDNode>(N)) {
assert(cast<ConstantSDNode>(N)->isNullValue() &&
"Unexpected base in gather/scatter");
AM.Scale = 1;
}
// Otherwise, try to match into the base and displacement fields.
else if (matchVectorAddress(N, AM))
// Try to match into the base and displacement fields.
if (matchVectorAddress(N, AM))
return false;

MVT VT = N.getSimpleValueType();
Expand Down

0 comments on commit af4eb17

Please sign in to comment.