[SVE][CodeGen] Lower scalable masked gathers
Lowers the llvm.masked.gather intrinsic (scalar plus vector addressing mode only).

Changes in this patch:
- Add custom lowering for MGATHER, using getGatherVecOpcode() to choose the
  appropriate gather load opcode.
- Improve codegen with refineIndexType/refineUniformBase, added in D90942.
- Tests added for gather loads with 32 & 64-bit scaled & unscaled offsets.
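
As a reference for what is being lowered: a masked gather loads, for each
active lane, from a shared scalar base plus a per-lane index, and keeps the
pass-through value in inactive lanes. A minimal scalar C++ sketch of these
semantics (names hypothetical, not part of the patch):

    #include <cstddef>
    #include <cstdint>

    // Scalar model of llvm.masked.gather (scalar base + vector of indices):
    // active lanes load Base[Index[i]]; inactive lanes keep PassThru[i].
    void maskedGather(const int64_t *Base, const int64_t *Index,
                      const bool *Mask, const int64_t *PassThru,
                      int64_t *Result, std::size_t NumLanes) {
      for (std::size_t I = 0; I < NumLanes; ++I)
        Result[I] = Mask[I] ? Base[Index[I]] : PassThru[I];
    }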

Reviewed By: sdesmalen

Differential Revision: https://reviews.llvm.org/D91092
kmclaughlin-arm committed Dec 7, 2020
1 parent 9806181 commit f6dd32f
Showing 12 changed files with 1,516 additions and 12 deletions.
13 changes: 9 additions & 4 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -1746,6 +1746,7 @@ void DAGTypeLegalizer::SplitVecRes_MGATHER(MaskedGatherSDNode *MGT,
SDValue PassThru = MGT->getPassThru();
SDValue Index = MGT->getIndex();
SDValue Scale = MGT->getScale();
EVT MemoryVT = MGT->getMemoryVT();
Align Alignment = MGT->getOriginalAlign();

// Split Mask operand
@@ -1759,6 +1760,10 @@ void DAGTypeLegalizer::SplitVecRes_MGATHER(MaskedGatherSDNode *MGT,
std::tie(MaskLo, MaskHi) = DAG.SplitVector(Mask, dl);
}

EVT LoMemVT, HiMemVT;
// Split MemoryVT
std::tie(LoMemVT, HiMemVT) = DAG.GetSplitDestVTs(MemoryVT);

SDValue PassThruLo, PassThruHi;
if (getTypeAction(PassThru.getValueType()) == TargetLowering::TypeSplitVector)
GetSplitVector(PassThru, PassThruLo, PassThruHi);
@@ -1777,11 +1782,11 @@ void DAGTypeLegalizer::SplitVecRes_MGATHER(MaskedGatherSDNode *MGT,
MGT->getRanges());

SDValue OpsLo[] = {Ch, PassThruLo, MaskLo, Ptr, IndexLo, Scale};
- Lo = DAG.getMaskedGather(DAG.getVTList(LoVT, MVT::Other), LoVT, dl, OpsLo,
+ Lo = DAG.getMaskedGather(DAG.getVTList(LoVT, MVT::Other), LoMemVT, dl, OpsLo,
MMO, MGT->getIndexType());

SDValue OpsHi[] = {Ch, PassThruHi, MaskHi, Ptr, IndexHi, Scale};
- Hi = DAG.getMaskedGather(DAG.getVTList(HiVT, MVT::Other), HiVT, dl, OpsHi,
+ Hi = DAG.getMaskedGather(DAG.getVTList(HiVT, MVT::Other), HiMemVT, dl, OpsHi,
MMO, MGT->getIndexType());

// Build a factor node to remember that this load is independent of the
@@ -2421,11 +2426,11 @@ SDValue DAGTypeLegalizer::SplitVecOp_MGATHER(MaskedGatherSDNode *MGT,
MGT->getRanges());

SDValue OpsLo[] = {Ch, PassThruLo, MaskLo, Ptr, IndexLo, Scale};
- SDValue Lo = DAG.getMaskedGather(DAG.getVTList(LoVT, MVT::Other), LoVT, dl,
+ SDValue Lo = DAG.getMaskedGather(DAG.getVTList(LoVT, MVT::Other), LoMemVT, dl,
OpsLo, MMO, MGT->getIndexType());

SDValue OpsHi[] = {Ch, PassThruHi, MaskHi, Ptr, IndexHi, Scale};
- SDValue Hi = DAG.getMaskedGather(DAG.getVTList(HiVT, MVT::Other), HiVT, dl,
+ SDValue Hi = DAG.getMaskedGather(DAG.getVTList(HiVT, MVT::Other), HiMemVT, dl,
OpsHi, MMO, MGT->getIndexType());

// Build a factor node to remember that this load is independent of the
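The hunks above thread the split memory type (LoMemVT/HiMemVT) into each half
gather instead of reusing the full-width value types. A minimal sketch of the
element-count arithmetic behind GetSplitDestVTs, using a simplified stand-in
for LLVM's ElementCount (names hypothetical): each half keeps the scalable
flag and halves the known minimum lane count, so e.g. nxv4i32 splits into two
nxv2i32 halves.

    #include <cassert>
    #include <utility>

    // Simplified stand-in for LLVM's ElementCount: a lane count that is
    // either fixed or a multiple of an unknown hardware factor (scalable).
    struct EC {
      unsigned Min;   // known minimum number of lanes
      bool Scalable;  // true for types such as <vscale x 4 x i32>
    };

    // Model of splitting a vector type: each half keeps the scalable flag
    // and takes half of the known minimum lane count.
    std::pair<EC, EC> splitLanes(EC Count) {
      assert(Count.Min % 2 == 0 && "odd counts are handled differently");
      EC Half{Count.Min / 2, Count.Scalable};
      return {Half, Half};
    }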
13 changes: 9 additions & 4 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -7310,17 +7310,22 @@ SDValue SelectionDAG::getMaskedGather(SDVTList VTs, EVT VT, const SDLoc &dl,
return SDValue(E, 0);
}

IndexType = TLI->getCanonicalIndexType(IndexType, VT, Ops[4]);
auto *N = newSDNode<MaskedGatherSDNode>(dl.getIROrder(), dl.getDebugLoc(),
VTs, VT, MMO, IndexType);
createOperands(N, Ops);

assert(N->getPassThru().getValueType() == N->getValueType(0) &&
"Incompatible type of the PassThru value in MaskedGatherSDNode");
- assert(N->getMask().getValueType().getVectorNumElements() ==
-        N->getValueType(0).getVectorNumElements() &&
+ assert(N->getMask().getValueType().getVectorElementCount() ==
+        N->getValueType(0).getVectorElementCount() &&
"Vector width mismatch between mask and data");
- assert(N->getIndex().getValueType().getVectorNumElements() >=
-        N->getValueType(0).getVectorNumElements() &&
+ assert(N->getIndex().getValueType().getVectorElementCount().isScalable() ==
+        N->getValueType(0).getVectorElementCount().isScalable() &&
+        "Scalable flags of index and data do not match");
+ assert(ElementCount::isKnownGE(
+            N->getIndex().getValueType().getVectorElementCount(),
+            N->getValueType(0).getVectorElementCount()) &&
"Vector width mismatch between index and data");
assert(isa<ConstantSDNode>(N->getScale()) &&
cast<ConstantSDNode>(N->getScale())->getAPIntValue().isPowerOf2() &&
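The reworked asserts above compare ElementCounts rather than raw lane counts,
because a scalable count and a fixed count are not directly comparable. A
sketch of the comparison rule, reusing the EC stand-in from the previous
sketch:

    struct EC { unsigned Min; bool Scalable; };  // as in the sketch above

    // Model of ElementCount::isKnownGE: counts with mismatched scalable
    // flags are not provably ordered at compile time, which is why the new
    // code first asserts that index and data agree on scalability.
    bool isKnownGE(EC LHS, EC RHS) {
      return LHS.Scalable == RHS.Scalable && LHS.Min >= RHS.Min;
    }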
2 changes: 1 addition & 1 deletion llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -4416,7 +4416,7 @@ void SelectionDAGBuilder::visitMaskedGather(const CallInst &I) {
if (!UniformBase) {
Base = DAG.getConstant(0, sdl, TLI.getPointerTy(DAG.getDataLayout()));
Index = getValue(Ptr);
- IndexType = ISD::SIGNED_SCALED;
+ IndexType = ISD::SIGNED_UNSCALED;
Scale = DAG.getTargetConstant(1, sdl, TLI.getPointerTy(DAG.getDataLayout()));
}
SDValue Ops[] = { Root, Src0, Mask, Base, Index, Scale };
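The one-line change above fixes the index type for the case where no uniform
base can be extracted. A sketch of the assumed addressing model encoded by
ISD::MemIndexType: the *_SCALED forms multiply the index by the element size,
while the *_UNSCALED forms use a scale of 1; with Base == 0 and Index holding
complete pointers, only the unscaled interpretation is correct.

    #include <cstdint>

    // Assumed model of the address computed for one gather lane. In the
    // !UniformBase case above, Base is 0 and Index[i] is a full pointer,
    // so scaling it by the element size would compute the wrong address.
    uint64_t laneAddress(uint64_t Base, int64_t Index, uint64_t Scale) {
      return Base + static_cast<uint64_t>(Index) * Scale;
    }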
97 changes: 94 additions & 3 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -113,6 +113,16 @@ EnableOptimizeLogicalImm("aarch64-enable-logical-imm", cl::Hidden,
"optimization"),
cl::init(true));

// Temporary option added for the purpose of testing functionality added
// to DAGCombiner.cpp in D92230. It is expected that this can be removed
// in future when both implementations will be based off MGATHER rather
// than the GLD1 nodes added for the SVE gather load intrinsics.
static cl::opt<bool>
EnableCombineMGatherIntrinsics("aarch64-enable-mgather-combine", cl::Hidden,
cl::desc("Combine extends of AArch64 masked "
"gather intrinsics"),
cl::init(true));
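
Assuming LLVM's standard cl::opt command-line handling, the combines can be
toggled while both implementations coexist, e.g. by passing
-aarch64-enable-mgather-combine=false to llc.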

/// Value type used for condition codes.
static const MVT MVT_CC = MVT::i32;

@@ -1059,6 +1069,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::SINT_TO_FP, VT, Custom);
setOperationAction(ISD::FP_TO_UINT, VT, Custom);
setOperationAction(ISD::FP_TO_SINT, VT, Custom);
setOperationAction(ISD::MGATHER, VT, Custom);
setOperationAction(ISD::MSCATTER, VT, Custom);
setOperationAction(ISD::MUL, VT, Custom);
setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
@@ -1111,6 +1122,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
MVT::nxv4f32, MVT::nxv2f64}) {
setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
setOperationAction(ISD::MGATHER, VT, Custom);
setOperationAction(ISD::MSCATTER, VT, Custom);
setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
setOperationAction(ISD::SELECT, VT, Custom);
@@ -3775,6 +3787,29 @@ bool AArch64TargetLowering::isVectorLoadExtDesirable(SDValue ExtVal) const {
return ExtVal.getValueType().isScalableVector();
}

unsigned getGatherVecOpcode(bool IsScaled, bool IsSigned, bool NeedsExtend) {
std::map<std::tuple<bool, bool, bool>, unsigned> AddrModes = {
{std::make_tuple(/*Scaled*/ false, /*Signed*/ false, /*Extend*/ false),
AArch64ISD::GLD1_MERGE_ZERO},
{std::make_tuple(/*Scaled*/ false, /*Signed*/ false, /*Extend*/ true),
AArch64ISD::GLD1_UXTW_MERGE_ZERO},
{std::make_tuple(/*Scaled*/ false, /*Signed*/ true, /*Extend*/ false),
AArch64ISD::GLD1_MERGE_ZERO},
{std::make_tuple(/*Scaled*/ false, /*Signed*/ true, /*Extend*/ true),
AArch64ISD::GLD1_SXTW_MERGE_ZERO},
{std::make_tuple(/*Scaled*/ true, /*Signed*/ false, /*Extend*/ false),
AArch64ISD::GLD1_SCALED_MERGE_ZERO},
{std::make_tuple(/*Scaled*/ true, /*Signed*/ false, /*Extend*/ true),
AArch64ISD::GLD1_UXTW_SCALED_MERGE_ZERO},
{std::make_tuple(/*Scaled*/ true, /*Signed*/ true, /*Extend*/ false),
AArch64ISD::GLD1_SCALED_MERGE_ZERO},
{std::make_tuple(/*Scaled*/ true, /*Signed*/ true, /*Extend*/ true),
AArch64ISD::GLD1_SXTW_SCALED_MERGE_ZERO},
};
auto Key = std::make_tuple(IsScaled, IsSigned, NeedsExtend);
return AddrModes.find(Key)->second;
}
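
A usage example for the table above (a sketch, not part of the patch): an
unscaled gather whose 32-bit index vector is sign-extended selects the SXTW
variant.

    // IsScaled=false, IsSigned=true, NeedsExtend=true maps to
    // AArch64ISD::GLD1_SXTW_MERGE_ZERO in the AddrModes table above.
    unsigned Opc = getGatherVecOpcode(/*IsScaled=*/false, /*IsSigned=*/true,
                                      /*NeedsExtend=*/true);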

unsigned getScatterVecOpcode(bool IsScaled, bool IsSigned, bool NeedsExtend) {
std::map<std::tuple<bool, bool, bool>, unsigned> AddrModes = {
{std::make_tuple(/*Scaled*/ false, /*Signed*/ false, /*Extend*/ false),
@@ -3798,7 +3833,7 @@ unsigned getScatterVecOpcode(bool IsScaled, bool IsSigned, bool NeedsExtend) {
return AddrModes.find(Key)->second;
}

- bool getScatterIndexIsExtended(SDValue Index) {
+ bool getGatherScatterIndexIsExtended(SDValue Index) {
unsigned Opcode = Index.getOpcode();
if (Opcode == ISD::SIGN_EXTEND_INREG)
return true;
@@ -3816,6 +3851,54 @@ bool getScatterIndexIsExtended(SDValue Index) {
return false;
}

SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op,
SelectionDAG &DAG) const {
SDLoc DL(Op);
MaskedGatherSDNode *MGT = cast<MaskedGatherSDNode>(Op);
assert(MGT && "Can only custom lower gather load nodes");

SDValue Index = MGT->getIndex();
SDValue Chain = MGT->getChain();
SDValue PassThru = MGT->getPassThru();
SDValue Mask = MGT->getMask();
SDValue BasePtr = MGT->getBasePtr();

ISD::MemIndexType IndexType = MGT->getIndexType();
bool IsScaled =
IndexType == ISD::SIGNED_SCALED || IndexType == ISD::UNSIGNED_SCALED;
bool IsSigned =
IndexType == ISD::SIGNED_SCALED || IndexType == ISD::SIGNED_UNSCALED;
bool IdxNeedsExtend =
getGatherScatterIndexIsExtended(Index) ||
Index.getSimpleValueType().getVectorElementType() == MVT::i32;

EVT VT = PassThru.getSimpleValueType();
EVT MemVT = MGT->getMemoryVT();
SDValue InputVT = DAG.getValueType(MemVT);

if (VT.getVectorElementType() == MVT::bf16 &&
!static_cast<const AArch64Subtarget &>(DAG.getSubtarget()).hasBF16())
return SDValue();

// Handle FP data
if (VT.isFloatingPoint()) {
VT = VT.changeVectorElementTypeToInteger();
ElementCount EC = VT.getVectorElementCount();
auto ScalarIntVT =
MVT::getIntegerVT(AArch64::SVEBitsPerBlock / EC.getKnownMinValue());
PassThru = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL,
MVT::getVectorVT(ScalarIntVT, EC), PassThru);

InputVT = DAG.getValueType(MemVT.changeVectorElementTypeToInteger());
}

SDVTList VTs = DAG.getVTList(PassThru.getSimpleValueType(), MVT::Other);

SDValue Ops[] = {Chain, Mask, BasePtr, Index, InputVT, PassThru};
return DAG.getNode(getGatherVecOpcode(IsScaled, IsSigned, IdxNeedsExtend), DL,
VTs, Ops);
}
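
A worked example for the floating-point path above, assuming a nxv2f64
gather: the known minimum lane count is 2, so the integer container type is
MVT::getIntegerVT(128 / 2), i.e. i64, and the pass-through is reinterpreted
as nxv2i64 before the gather node is built.

    #include <cstdio>

    int main() {
      const unsigned SVEBitsPerBlock = 128; // minimum SVE register width
      const unsigned KnownMinLanes = 2;     // nxv2f64: at least two lanes
      // Mirrors MVT::getIntegerVT(SVEBitsPerBlock / EC.getKnownMinValue()):
      std::printf("i%u\n", SVEBitsPerBlock / KnownMinLanes); // prints "i64"
    }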

SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op,
SelectionDAG &DAG) const {
SDLoc DL(Op);
@@ -3834,7 +3917,7 @@ SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op,
bool IsSigned =
IndexType == ISD::SIGNED_SCALED || IndexType == ISD::SIGNED_UNSCALED;
bool NeedsExtend =
- getScatterIndexIsExtended(Index) ||
+ getGatherScatterIndexIsExtended(Index) ||
Index.getSimpleValueType().getVectorElementType() == MVT::i32;

EVT VT = StoreVal.getSimpleValueType();
@@ -3858,7 +3941,7 @@ SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op,
InputVT = DAG.getValueType(MemVT.changeVectorElementTypeToInteger());
}

- if (getScatterIndexIsExtended(Index))
+ if (getGatherScatterIndexIsExtended(Index))
Index = Index.getOperand(0);

SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, InputVT};
@@ -4159,6 +4242,8 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
return LowerINTRINSIC_WO_CHAIN(Op, DAG);
case ISD::STORE:
return LowerSTORE(Op, DAG);
case ISD::MGATHER:
return LowerMGATHER(Op, DAG);
case ISD::MSCATTER:
return LowerMSCATTER(Op, DAG);
case ISD::VECREDUCE_SEQ_FADD:
@@ -12019,6 +12104,9 @@ static SDValue performSVEAndCombine(SDNode *N,
return DAG.getNode(Opc, DL, N->getValueType(0), And);
}

if (!EnableCombineMGatherIntrinsics)
return SDValue();

SDValue Mask = N->getOperand(1);

if (!Src.hasOneUse())
@@ -14982,6 +15070,9 @@ performSignExtendInRegCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
return DAG.getNode(SOpc, DL, N->getValueType(0), Ext);
}

if (!EnableCombineMGatherIntrinsics)
return SDValue();

// SVE load nodes (e.g. AArch64ISD::GLD1) are straightforward candidates
// for DAG Combine with SIGN_EXTEND_INREG. Bail out for all other nodes.
unsigned NewOpc;
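Note that returning an empty SDValue from these combines signals to the DAG
combiner that no transformation was applied, so with the flag disabled the
extends are left to the target-independent DAGCombiner changes from D92230
instead.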
1 change: 1 addition & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -805,6 +805,7 @@ class AArch64TargetLowering : public TargetLowering {

SDValue LowerSTORE(SDValue Op, SelectionDAG &DAG) const;

SDValue LowerMGATHER(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerMSCATTER(SDValue Op, SelectionDAG &DAG) const;

SDValue LowerINTRINSIC_WO_CHAIN(SDValue Op, SelectionDAG &DAG) const;
