Skip to content

Commit

Permalink
[SVE][CodeGen] Add the isTruncatingStore flag to MSCATTER
Browse files Browse the repository at this point in the history
This patch adds the IsTruncatingStore flag to MaskedScatterSDNode, set by getMaskedScatter().
Updated SelectionDAGDumper::print_details for MaskedScatterSDNode to print
the details of masked scatters (is truncating, signed or scaled).

This is the first in a series of patches which adds support for scalable masked scatters

Reviewed By: sdesmalen

Differential Revision: https://reviews.llvm.org/D90939
  • Loading branch information
kmclaughlin-arm committed Nov 11, 2020
1 parent 9ff7011 commit ffbbfc7
Show file tree
Hide file tree
Showing 8 changed files with 54 additions and 16 deletions.
3 changes: 2 additions & 1 deletion llvm/include/llvm/CodeGen/SelectionDAG.h
Expand Up @@ -1361,7 +1361,8 @@ class SelectionDAG {
ISD::MemIndexType IndexType);
SDValue getMaskedScatter(SDVTList VTs, EVT VT, const SDLoc &dl,
ArrayRef<SDValue> Ops, MachineMemOperand *MMO,
ISD::MemIndexType IndexType);
ISD::MemIndexType IndexType,
bool IsTruncating = false);

/// Construct a node to track a Value* through the backend.
SDValue getSrcValue(const Value *v);
Expand Down
12 changes: 10 additions & 2 deletions llvm/include/llvm/CodeGen/SelectionDAGNodes.h
Expand Up @@ -523,6 +523,7 @@ BEGIN_TWO_BYTE_PACK()
class StoreSDNodeBitfields {
friend class StoreSDNode;
friend class MaskedStoreSDNode;
friend class MaskedScatterSDNode;

uint16_t : NumLSBaseSDNodeBits;

Expand Down Expand Up @@ -2441,9 +2442,16 @@ class MaskedScatterSDNode : public MaskedGatherScatterSDNode {

MaskedScatterSDNode(unsigned Order, const DebugLoc &dl, SDVTList VTs,
EVT MemVT, MachineMemOperand *MMO,
ISD::MemIndexType IndexType)
ISD::MemIndexType IndexType, bool IsTrunc)
: MaskedGatherScatterSDNode(ISD::MSCATTER, Order, dl, VTs, MemVT, MMO,
IndexType) {}
IndexType) {
StoreSDNodeBits.IsTruncating = IsTrunc;
}

/// Return true if the op does a truncation before store.
/// For integers this is the same as doing a TRUNCATE and storing the result.
/// For floats, it is the same as doing an FP_ROUND and storing the result.
bool isTruncatingStore() const { return StoreSDNodeBits.IsTruncating; }

const SDValue &getValue() const { return getOperand(1); }

Expand Down
11 changes: 9 additions & 2 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
Expand Up @@ -1851,6 +1851,7 @@ SDValue DAGTypeLegalizer::PromoteIntOp_MGATHER(MaskedGatherSDNode *N,

SDValue DAGTypeLegalizer::PromoteIntOp_MSCATTER(MaskedScatterSDNode *N,
unsigned OpNo) {
bool TruncateStore = N->isTruncatingStore();
SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end());
if (OpNo == 2) {
// The Mask
Expand All @@ -1863,9 +1864,15 @@ SDValue DAGTypeLegalizer::PromoteIntOp_MSCATTER(MaskedScatterSDNode *N,
NewOps[OpNo] = SExtPromotedInteger(N->getOperand(OpNo));
else
NewOps[OpNo] = ZExtPromotedInteger(N->getOperand(OpNo));
} else

} else {
NewOps[OpNo] = GetPromotedInteger(N->getOperand(OpNo));
return SDValue(DAG.UpdateNodeOperands(N, NewOps), 0);
TruncateStore = true;
}

return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), N->getMemoryVT(),
SDLoc(N), NewOps, N->getMemOperand(),
N->getIndexType(), TruncateStore);
}

SDValue DAGTypeLegalizer::PromoteIntOp_TRUNCATE(SDNode *N) {
Expand Down
17 changes: 12 additions & 5 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
Expand Up @@ -2496,11 +2496,15 @@ SDValue DAGTypeLegalizer::SplitVecOp_MSCATTER(MaskedScatterSDNode *N,
SDValue Index = N->getIndex();
SDValue Scale = N->getScale();
SDValue Data = N->getValue();
EVT MemoryVT = N->getMemoryVT();
Align Alignment = N->getOriginalAlign();
SDLoc DL(N);

// Split all operands

EVT LoMemVT, HiMemVT;
std::tie(LoMemVT, HiMemVT) = DAG.GetSplitDestVTs(MemoryVT);

SDValue DataLo, DataHi;
if (getTypeAction(Data.getValueType()) == TargetLowering::TypeSplitVector)
// Split Data operand
Expand Down Expand Up @@ -2531,15 +2535,17 @@ SDValue DAGTypeLegalizer::SplitVecOp_MSCATTER(MaskedScatterSDNode *N,
MemoryLocation::UnknownSize, Alignment, N->getAAInfo(), N->getRanges());

SDValue OpsLo[] = {Ch, DataLo, MaskLo, Ptr, IndexLo, Scale};
Lo = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), DataLo.getValueType(),
DL, OpsLo, MMO, N->getIndexType());
Lo = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), LoMemVT,
DL, OpsLo, MMO, N->getIndexType(),
N->isTruncatingStore());

// The order of the Scatter operation after split is well defined. The "Hi"
// part comes after the "Lo". So these two operations should be chained one
// after another.
SDValue OpsHi[] = {Lo, DataHi, MaskHi, Ptr, IndexHi, Scale};
return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), DataHi.getValueType(),
DL, OpsHi, MMO, N->getIndexType());
return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), HiMemVT,
DL, OpsHi, MMO, N->getIndexType(),
N->isTruncatingStore());
}

SDValue DAGTypeLegalizer::SplitVecOp_STORE(StoreSDNode *N, unsigned OpNo) {
Expand Down Expand Up @@ -4717,7 +4723,8 @@ SDValue DAGTypeLegalizer::WidenVecOp_MSCATTER(SDNode *N, unsigned OpNo) {
Scale};
return DAG.getMaskedScatter(DAG.getVTList(MVT::Other),
MSC->getMemoryVT(), SDLoc(N), Ops,
MSC->getMemOperand(), MSC->getIndexType());
MSC->getMemOperand(), MSC->getIndexType(),
MSC->isTruncatingStore());
}

SDValue DAGTypeLegalizer::WidenVecOp_SETCC(SDNode *N) {
Expand Down
8 changes: 5 additions & 3 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
Expand Up @@ -7340,22 +7340,24 @@ SDValue SelectionDAG::getMaskedGather(SDVTList VTs, EVT VT, const SDLoc &dl,
SDValue SelectionDAG::getMaskedScatter(SDVTList VTs, EVT VT, const SDLoc &dl,
ArrayRef<SDValue> Ops,
MachineMemOperand *MMO,
ISD::MemIndexType IndexType) {
ISD::MemIndexType IndexType,
bool IsTrunc) {
assert(Ops.size() == 6 && "Incompatible number of operands");

FoldingSetNodeID ID;
AddNodeIDNode(ID, ISD::MSCATTER, VTs, Ops);
ID.AddInteger(VT.getRawBits());
ID.AddInteger(getSyntheticNodeSubclassData<MaskedScatterSDNode>(
dl.getIROrder(), VTs, VT, MMO, IndexType));
dl.getIROrder(), VTs, VT, MMO, IndexType, IsTrunc));
ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
void *IP = nullptr;
if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
cast<MaskedScatterSDNode>(E)->refineAlignment(MMO);
return SDValue(E, 0);
}

auto *N = newSDNode<MaskedScatterSDNode>(dl.getIROrder(), dl.getDebugLoc(),
VTs, VT, MMO, IndexType);
VTs, VT, MMO, IndexType, IsTrunc);
createOperands(N, Ops);

assert(N->getMask().getValueType().getVectorNumElements() ==
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
Expand Up @@ -4302,7 +4302,7 @@ void SelectionDAGBuilder::visitMaskedScatter(const CallInst &I) {
}
SDValue Ops[] = { getMemoryRoot(), Src0, Mask, Base, Index, Scale };
SDValue Scatter = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), VT, sdl,
Ops, MMO, IndexType);
Ops, MMO, IndexType, false);
DAG.setRoot(Scatter);
setValue(&I, Scatter);
}
Expand Down
14 changes: 13 additions & 1 deletion llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
Expand Up @@ -735,7 +735,19 @@ void SDNode::print_details(raw_ostream &OS, const SelectionDAG *G) const {
OS << ", compressing";

OS << ">";
} else if (const MemSDNode* M = dyn_cast<MemSDNode>(this)) {
} else if (const auto *MScatter = dyn_cast<MaskedScatterSDNode>(this)) {
OS << "<";
printMemOperand(OS, *MScatter->getMemOperand(), G);

if (MScatter->isTruncatingStore())
OS << ", trunc to " << MScatter->getMemoryVT().getEVTString();

auto Signed = MScatter->isIndexSigned() ? "signed" : "unsigned";
auto Scaled = MScatter->isIndexScaled() ? "scaled" : "unscaled";
OS << ", " << Signed << " " << Scaled << " offset";

OS << ">";
} else if (const MemSDNode *M = dyn_cast<MemSDNode>(this)) {
OS << "<";
printMemOperand(OS, *M->getMemOperand(), G);
OS << ">";
Expand Down
3 changes: 2 additions & 1 deletion llvm/lib/Target/X86/X86ISelLowering.cpp
Expand Up @@ -47560,7 +47560,8 @@ static SDValue rebuildGatherScatter(MaskedGatherScatterSDNode *GorS,
return DAG.getMaskedScatter(Scatter->getVTList(),
Scatter->getMemoryVT(), DL,
Ops, Scatter->getMemOperand(),
Scatter->getIndexType());
Scatter->getIndexType(),
Scatter->isTruncatingStore());
}

static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
Expand Down

0 comments on commit ffbbfc7

Please sign in to comment.