[SVE][CodeGen] Improve codegen of scalable masked scatters
If the scatter store is able to perform the sign/zero extend of
its index, the extend is folded into the instruction by
refineIndexType(). Additionally, refineUniformBase() recovers the
base pointer and index from an add + splat_vector.

Reviewed By: sdesmalen

Differential Revision: https://reviews.llvm.org/D90942
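
For illustration, a hedged sketch (mirroring the masked_scatter_nxv4i32_zext
test updated below; the function name and typed-pointer intrinsic mangling are
assumptions, not part of this patch) of a scatter whose zero-extended i32
offsets can now be folded into the store, so a single
st1w { z0.s }, p0, [x0, z1.s, uxtw #2] is selected instead of unpacking the
operands into two .d-based scatters:

define void @scatter_zext_sketch(<vscale x 4 x i32> %data, i32* %base,
                                 <vscale x 4 x i32> %indexes, <vscale x 4 x i1> %masks) {
  ; The zero-extend of the indexes is folded into the scatter's uxtw
  ; addressing mode rather than being widened in the DAG.
  %ext = zext <vscale x 4 x i32> %indexes to <vscale x 4 x i64>
  %ptrs = getelementptr i32, i32* %base, <vscale x 4 x i64> %ext
  call void @llvm.masked.scatter.nxv4i32.nxv4p0i32(<vscale x 4 x i32> %data, <vscale x 4 x i32*> %ptrs, i32 4, <vscale x 4 x i1> %masks)
  ret void
}
declare void @llvm.masked.scatter.nxv4i32.nxv4p0i32(<vscale x 4 x i32>, <vscale x 4 x i32*>, i32, <vscale x 4 x i1>)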
kmclaughlin-arm committed Nov 13, 2020
1 parent 08016ac commit 306c8ab
Showing 7 changed files with 118 additions and 370 deletions.
4 changes: 4 additions & 0 deletions llvm/include/llvm/CodeGen/TargetLowering.h
@@ -1318,6 +1318,10 @@ class TargetLoweringBase {
getIndexedMaskedStoreAction(IdxMode, VT.getSimpleVT()) == Custom);
}

// Returns true if VT is a legal index type for masked gathers/scatters
// on this target
virtual bool shouldRemoveExtendFromGSIndex(EVT VT) const { return false; }

/// Return how the condition code should be treated: either it is legal, needs
/// to be expanded to some other code sequence, or the target has a custom
/// expander for it.
58 changes: 58 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -9399,16 +9399,74 @@ static SDValue ConvertSelectToConcatVector(SDNode *N, SelectionDAG &DAG) {
TopHalf->isNullValue() ? RHS->getOperand(1) : LHS->getOperand(1));
}

bool refineUniformBase(SDValue &BasePtr, SDValue &Index, SelectionDAG &DAG) {
if (!isNullConstant(BasePtr) || Index.getOpcode() != ISD::ADD)
return false;

// For now we check only the LHS of the add.
SDValue LHS = Index.getOperand(0);
SDValue SplatVal = DAG.getSplatValue(LHS);
if (!SplatVal)
return false;

BasePtr = SplatVal;
Index = Index.getOperand(1);
return true;
}

// Fold sext/zext of index into index type.
bool refineIndexType(MaskedScatterSDNode *MSC, SDValue &Index, bool Scaled,
SelectionDAG &DAG) {
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
SDValue Op = Index.getOperand(0);

if (Index.getOpcode() == ISD::ZERO_EXTEND) {
MSC->setIndexType(Scaled ? ISD::UNSIGNED_SCALED : ISD::UNSIGNED_UNSCALED);
if (TLI.shouldRemoveExtendFromGSIndex(Op.getValueType())) {
Index = Op;
return true;
}
}

if (Index.getOpcode() == ISD::SIGN_EXTEND) {
MSC->setIndexType(Scaled ? ISD::SIGNED_SCALED : ISD::SIGNED_UNSCALED);
if (TLI.shouldRemoveExtendFromGSIndex(Op.getValueType())) {
Index = Op;
return true;
}
}

return false;
}

SDValue DAGCombiner::visitMSCATTER(SDNode *N) {
MaskedScatterSDNode *MSC = cast<MaskedScatterSDNode>(N);
SDValue Mask = MSC->getMask();
SDValue Chain = MSC->getChain();
SDValue Index = MSC->getIndex();
SDValue Scale = MSC->getScale();
SDValue StoreVal = MSC->getValue();
SDValue BasePtr = MSC->getBasePtr();
SDLoc DL(N);

// Zap scatters with a zero mask.
if (ISD::isBuildVectorAllZeros(Mask.getNode()))
return Chain;

if (refineUniformBase(BasePtr, Index, DAG)) {
SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
return DAG.getMaskedScatter(
DAG.getVTList(MVT::Other), StoreVal.getValueType(), DL, Ops,
MSC->getMemOperand(), MSC->getIndexType(), MSC->isTruncatingStore());
}

if (refineIndexType(MSC, Index, MSC->isIndexScaled(), DAG)) {
SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
return DAG.getMaskedScatter(
DAG.getVTList(MVT::Other), StoreVal.getValueType(), DL, Ops,
MSC->getMemOperand(), MSC->getIndexType(), MSC->isTruncatingStore());
}

return SDValue();
}

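As a hedged illustration of refineUniformBase() (a sketch, not one of the
tests changed in this patch): a scatter whose pointers are a splatted scalar
base plus a vector of offsets can reach the combiner with a null base pointer
and an index of the form add(splat_vector(%base), %offsets); the combine then
moves the splatted value back into the base-pointer operand and keeps only the
offsets as the index.

define void @scatter_uniform_base_sketch(<vscale x 2 x i8> %data, i8* %base,
                                         <vscale x 2 x i64> %offsets, <vscale x 2 x i1> %mask) {
  ; Depending on how the GEP is lowered, %ptrs may appear in the DAG as
  ; add(splat_vector(%base), %offsets) with a null scatter base pointer;
  ; refineUniformBase() rewrites this to base = %base, index = %offsets.
  %ptrs = getelementptr i8, i8* %base, <vscale x 2 x i64> %offsets
  call void @llvm.masked.scatter.nxv2i8.nxv2p0i8(<vscale x 2 x i8> %data, <vscale x 2 x i8*> %ptrs, i32 1, <vscale x 2 x i1> %mask)
  ret void
}
declare void @llvm.masked.scatter.nxv2i8.nxv2p0i8(<vscale x 2 x i8>, <vscale x 2 x i8*>, i32, <vscale x 2 x i1>)
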
13 changes: 9 additions & 4 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -3705,6 +3705,14 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
}
}

bool AArch64TargetLowering::shouldRemoveExtendFromGSIndex(EVT VT) const {
if (VT.getVectorElementType() == MVT::i32 &&
VT.getVectorElementCount().getKnownMinValue() >= 4)
return true;

return false;
}

bool AArch64TargetLowering::isVectorLoadExtDesirable(SDValue ExtVal) const {
return ExtVal.getValueType().isScalableVector();
}
@@ -3792,11 +3800,8 @@ SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op,
InputVT = DAG.getValueType(MemVT.changeVectorElementTypeToInteger());
}

if (getScatterIndexIsExtended(Index)) {
if (Index.getOpcode() == ISD::AND)
IsSigned = false;
if (getScatterIndexIsExtended(Index))
Index = Index.getOperand(0);
}

SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, InputVT};
return DAG.getNode(getScatterVecOpcode(IsScaled, IsSigned, NeedsExtend), DL,
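The shouldRemoveExtendFromGSIndex implementation above only asks the combiner
to drop the extend when the narrower index type still fits one of the
32-bit-index scatter forms: i32 elements with a known minimum of at least four
lanes (e.g. nxv4i32). A hedged sketch of a case it declines (function name
assumed; the behaviour is inferred from the unchanged nxv2 tests): with only
two index lanes the extend is kept in the DAG, and LowerMSCATTER still folds
it into the sxtw/uxtw addressing mode of the .d-based scatter.

define void @scatter_nxv2i32_sext_sketch(<vscale x 2 x i32> %data, i32* %base,
                                         <vscale x 2 x i32> %indexes, <vscale x 2 x i1> %masks) {
  ; shouldRemoveExtendFromGSIndex(nxv2i32) returns false (element count < 4),
  ; so the sign-extend survives the combine and is expected to be folded
  ; during lowering into something like:
  ;   st1w { z0.d }, p0, [x0, z1.d, sxtw #2]
  %ext = sext <vscale x 2 x i32> %indexes to <vscale x 2 x i64>
  %ptrs = getelementptr i32, i32* %base, <vscale x 2 x i64> %ext
  call void @llvm.masked.scatter.nxv2i32.nxv2p0i32(<vscale x 2 x i32> %data, <vscale x 2 x i32*> %ptrs, i32 4, <vscale x 2 x i1> %masks)
  ret void
}
declare void @llvm.masked.scatter.nxv2i32.nxv2p0i32(<vscale x 2 x i32>, <vscale x 2 x i32*>, i32, <vscale x 2 x i1>)
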
1 change: 1 addition & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -980,6 +980,7 @@ class AArch64TargetLowering : public TargetLowering {
return TargetLowering::getInlineAsmMemConstraint(ConstraintCode);
}

bool shouldRemoveExtendFromGSIndex(EVT VT) const override;
bool isVectorLoadExtDesirable(SDValue ExtVal) const override;
bool isUsedByReturnOnly(SDNode *N, SDValue &Chain) const override;
bool mayBeEmittedAsTailCall(const CallInst *CI) const override;
100 changes: 10 additions & 90 deletions llvm/test/CodeGen/AArch64/sve-masked-scatter-32b-scaled.ll
@@ -166,15 +166,7 @@ define void @masked_scatter_nxv2f64_zext(<vscale x 2 x double> %data, double* %b
define void @masked_scatter_nxv4i16_sext(<vscale x 4 x i16> %data, i16* %base, <vscale x 4 x i32> %indexes, <vscale x 4 x i1> %masks) nounwind {
; CHECK-LABEL: masked_scatter_nxv4i16_sext:
; CHECK: // %bb.0:
; CHECK-NEXT: pfalse p1.b
; CHECK-NEXT: uunpkhi z2.d, z1.s
; CHECK-NEXT: uunpklo z1.d, z1.s
; CHECK-NEXT: uunpklo z3.d, z0.s
; CHECK-NEXT: uunpkhi z0.d, z0.s
; CHECK-NEXT: zip1 p2.s, p0.s, p1.s
; CHECK-NEXT: zip2 p0.s, p0.s, p1.s
; CHECK-NEXT: st1h { z3.d }, p2, [x0, z1.d, sxtw #1]
; CHECK-NEXT: st1h { z0.d }, p0, [x0, z2.d, sxtw #1]
; CHECK-NEXT: st1h { z0.s }, p0, [x0, z1.s, sxtw #1]
; CHECK-NEXT: ret
%ext = sext <vscale x 4 x i32> %indexes to <vscale x 4 x i64>
%ptrs = getelementptr i16, i16* %base, <vscale x 4 x i64> %ext
@@ -185,15 +177,7 @@ define void @masked_scatter_nxv4i16_sext(<vscale x 4 x i16> %data, i16* %base, <
define void @masked_scatter_nxv4i32_sext(<vscale x 4 x i32> %data, i32* %base, <vscale x 4 x i32> %indexes, <vscale x 4 x i1> %masks) nounwind {
; CHECK-LABEL: masked_scatter_nxv4i32_sext:
; CHECK: // %bb.0:
; CHECK-NEXT: pfalse p1.b
; CHECK-NEXT: uunpkhi z2.d, z1.s
; CHECK-NEXT: uunpklo z1.d, z1.s
; CHECK-NEXT: uunpklo z3.d, z0.s
; CHECK-NEXT: uunpkhi z0.d, z0.s
; CHECK-NEXT: zip1 p2.s, p0.s, p1.s
; CHECK-NEXT: zip2 p0.s, p0.s, p1.s
; CHECK-NEXT: st1w { z3.d }, p2, [x0, z1.d, sxtw #2]
; CHECK-NEXT: st1w { z0.d }, p0, [x0, z2.d, sxtw #2]
; CHECK-NEXT: st1w { z0.s }, p0, [x0, z1.s, sxtw #2]
; CHECK-NEXT: ret
%ext = sext <vscale x 4 x i32> %indexes to <vscale x 4 x i64>
%ptrs = getelementptr i32, i32* %base, <vscale x 4 x i64> %ext
@@ -204,15 +188,7 @@ define void @masked_scatter_nxv4i32_sext(<vscale x 4 x i32> %data, i32* %base, <
define void @masked_scatter_nxv4f16_sext(<vscale x 4 x half> %data, half* %base, <vscale x 4 x i32> %indexes, <vscale x 4 x i1> %masks) nounwind {
; CHECK-LABEL: masked_scatter_nxv4f16_sext:
; CHECK: // %bb.0:
; CHECK-NEXT: pfalse p1.b
; CHECK-NEXT: uunpkhi z2.d, z1.s
; CHECK-NEXT: uunpklo z1.d, z1.s
; CHECK-NEXT: uunpklo z3.d, z0.s
; CHECK-NEXT: uunpkhi z0.d, z0.s
; CHECK-NEXT: zip1 p2.s, p0.s, p1.s
; CHECK-NEXT: zip2 p0.s, p0.s, p1.s
; CHECK-NEXT: st1h { z3.d }, p2, [x0, z1.d, sxtw #1]
; CHECK-NEXT: st1h { z0.d }, p0, [x0, z2.d, sxtw #1]
; CHECK-NEXT: st1h { z0.s }, p0, [x0, z1.s, sxtw #1]
; CHECK-NEXT: ret
%ext = sext <vscale x 4 x i32> %indexes to <vscale x 4 x i64>
%ptrs = getelementptr half, half* %base, <vscale x 4 x i64> %ext
@@ -223,15 +199,7 @@ define void @masked_scatter_nxv4f16_sext(<vscale x 4 x half> %data, half* %base,
define void @masked_scatter_nxv4bf16_sext(<vscale x 4 x bfloat> %data, bfloat* %base, <vscale x 4 x i32> %indexes, <vscale x 4 x i1> %masks) nounwind #0 {
; CHECK-LABEL: masked_scatter_nxv4bf16_sext:
; CHECK: // %bb.0:
; CHECK-NEXT: pfalse p1.b
; CHECK-NEXT: uunpkhi z2.d, z1.s
; CHECK-NEXT: uunpklo z1.d, z1.s
; CHECK-NEXT: uunpklo z3.d, z0.s
; CHECK-NEXT: uunpkhi z0.d, z0.s
; CHECK-NEXT: zip1 p2.s, p0.s, p1.s
; CHECK-NEXT: zip2 p0.s, p0.s, p1.s
; CHECK-NEXT: st1h { z3.d }, p2, [x0, z1.d, sxtw #1]
; CHECK-NEXT: st1h { z0.d }, p0, [x0, z2.d, sxtw #1]
; CHECK-NEXT: st1h { z0.s }, p0, [x0, z1.s, sxtw #1]
; CHECK-NEXT: ret
%ext = sext <vscale x 4 x i32> %indexes to <vscale x 4 x i64>
%ptrs = getelementptr bfloat, bfloat* %base, <vscale x 4 x i64> %ext
@@ -242,15 +210,7 @@ define void @masked_scatter_nxv4bf16_sext(<vscale x 4 x bfloat> %data, bfloat* %
define void @masked_scatter_nxv4f32_sext(<vscale x 4 x float> %data, float* %base, <vscale x 4 x i32> %indexes, <vscale x 4 x i1> %masks) nounwind #0 {
; CHECK-LABEL: masked_scatter_nxv4f32_sext:
; CHECK: // %bb.0:
; CHECK-NEXT: pfalse p1.b
; CHECK-NEXT: uunpkhi z2.d, z1.s
; CHECK-NEXT: uunpklo z1.d, z1.s
; CHECK-NEXT: uunpklo z3.d, z0.s
; CHECK-NEXT: uunpkhi z0.d, z0.s
; CHECK-NEXT: zip1 p2.s, p0.s, p1.s
; CHECK-NEXT: zip2 p0.s, p0.s, p1.s
; CHECK-NEXT: st1w { z3.d }, p2, [x0, z1.d, sxtw #2]
; CHECK-NEXT: st1w { z0.d }, p0, [x0, z2.d, sxtw #2]
; CHECK-NEXT: st1w { z0.s }, p0, [x0, z1.s, sxtw #2]
; CHECK-NEXT: ret
%ext = sext <vscale x 4 x i32> %indexes to <vscale x 4 x i64>
%ptrs = getelementptr float, float* %base, <vscale x 4 x i64> %ext
@@ -261,15 +221,7 @@ define void @masked_scatter_nxv4f32_sext(<vscale x 4 x float> %data, float* %bas
define void @masked_scatter_nxv4i16_zext(<vscale x 4 x i16> %data, i16* %base, <vscale x 4 x i32> %indexes, <vscale x 4 x i1> %masks) nounwind {
; CHECK-LABEL: masked_scatter_nxv4i16_zext:
; CHECK: // %bb.0:
; CHECK-NEXT: pfalse p1.b
; CHECK-NEXT: uunpkhi z2.d, z1.s
; CHECK-NEXT: uunpklo z1.d, z1.s
; CHECK-NEXT: uunpklo z3.d, z0.s
; CHECK-NEXT: uunpkhi z0.d, z0.s
; CHECK-NEXT: zip1 p2.s, p0.s, p1.s
; CHECK-NEXT: zip2 p0.s, p0.s, p1.s
; CHECK-NEXT: st1h { z3.d }, p2, [x0, z1.d, uxtw #1]
; CHECK-NEXT: st1h { z0.d }, p0, [x0, z2.d, uxtw #1]
; CHECK-NEXT: st1h { z0.s }, p0, [x0, z1.s, uxtw #1]
; CHECK-NEXT: ret
%ext = zext <vscale x 4 x i32> %indexes to <vscale x 4 x i64>
%ptrs = getelementptr i16, i16* %base, <vscale x 4 x i64> %ext
@@ -280,15 +232,7 @@ define void @masked_scatter_nxv4i16_zext(<vscale x 4 x i16> %data, i16* %base, <
define void @masked_scatter_nxv4i32_zext(<vscale x 4 x i32> %data, i32* %base, <vscale x 4 x i32> %indexes, <vscale x 4 x i1> %masks) nounwind {
; CHECK-LABEL: masked_scatter_nxv4i32_zext:
; CHECK: // %bb.0:
; CHECK-NEXT: pfalse p1.b
; CHECK-NEXT: uunpkhi z2.d, z1.s
; CHECK-NEXT: uunpklo z1.d, z1.s
; CHECK-NEXT: uunpklo z3.d, z0.s
; CHECK-NEXT: uunpkhi z0.d, z0.s
; CHECK-NEXT: zip1 p2.s, p0.s, p1.s
; CHECK-NEXT: zip2 p0.s, p0.s, p1.s
; CHECK-NEXT: st1w { z3.d }, p2, [x0, z1.d, uxtw #2]
; CHECK-NEXT: st1w { z0.d }, p0, [x0, z2.d, uxtw #2]
; CHECK-NEXT: st1w { z0.s }, p0, [x0, z1.s, uxtw #2]
; CHECK-NEXT: ret
%ext = zext <vscale x 4 x i32> %indexes to <vscale x 4 x i64>
%ptrs = getelementptr i32, i32* %base, <vscale x 4 x i64> %ext
@@ -299,15 +243,7 @@ define void @masked_scatter_nxv4i32_zext(<vscale x 4 x i32> %data, i32* %base, <
define void @masked_scatter_nxv4f16_zext(<vscale x 4 x half> %data, half* %base, <vscale x 4 x i32> %indexes, <vscale x 4 x i1> %masks) nounwind {
; CHECK-LABEL: masked_scatter_nxv4f16_zext:
; CHECK: // %bb.0:
; CHECK-NEXT: pfalse p1.b
; CHECK-NEXT: uunpkhi z2.d, z1.s
; CHECK-NEXT: uunpklo z1.d, z1.s
; CHECK-NEXT: uunpklo z3.d, z0.s
; CHECK-NEXT: uunpkhi z0.d, z0.s
; CHECK-NEXT: zip1 p2.s, p0.s, p1.s
; CHECK-NEXT: zip2 p0.s, p0.s, p1.s
; CHECK-NEXT: st1h { z3.d }, p2, [x0, z1.d, uxtw #1]
; CHECK-NEXT: st1h { z0.d }, p0, [x0, z2.d, uxtw #1]
; CHECK-NEXT: st1h { z0.s }, p0, [x0, z1.s, uxtw #1]
; CHECK-NEXT: ret
%ext = zext <vscale x 4 x i32> %indexes to <vscale x 4 x i64>
%ptrs = getelementptr half, half* %base, <vscale x 4 x i64> %ext
@@ -318,15 +254,7 @@ define void @masked_scatter_nxv4f16_zext(<vscale x 4 x half> %data, half* %base,
define void @masked_scatter_nxv4bf16_zext(<vscale x 4 x bfloat> %data, bfloat* %base, <vscale x 4 x i32> %indexes, <vscale x 4 x i1> %masks) nounwind #0 {
; CHECK-LABEL: masked_scatter_nxv4bf16_zext:
; CHECK: // %bb.0:
; CHECK-NEXT: pfalse p1.b
; CHECK-NEXT: uunpkhi z2.d, z1.s
; CHECK-NEXT: uunpklo z1.d, z1.s
; CHECK-NEXT: uunpklo z3.d, z0.s
; CHECK-NEXT: uunpkhi z0.d, z0.s
; CHECK-NEXT: zip1 p2.s, p0.s, p1.s
; CHECK-NEXT: zip2 p0.s, p0.s, p1.s
; CHECK-NEXT: st1h { z3.d }, p2, [x0, z1.d, uxtw #1]
; CHECK-NEXT: st1h { z0.d }, p0, [x0, z2.d, uxtw #1]
; CHECK-NEXT: st1h { z0.s }, p0, [x0, z1.s, uxtw #1]
; CHECK-NEXT: ret
%ext = zext <vscale x 4 x i32> %indexes to <vscale x 4 x i64>
%ptrs = getelementptr bfloat, bfloat* %base, <vscale x 4 x i64> %ext
@@ -337,15 +265,7 @@ define void @masked_scatter_nxv4bf16_zext(<vscale x 4 x bfloat> %data, bfloat* %
define void @masked_scatter_nxv4f32_zext(<vscale x 4 x float> %data, float* %base, <vscale x 4 x i32> %indexes, <vscale x 4 x i1> %masks) nounwind #0 {
; CHECK-LABEL: masked_scatter_nxv4f32_zext:
; CHECK: // %bb.0:
; CHECK-NEXT: pfalse p1.b
; CHECK-NEXT: uunpkhi z2.d, z1.s
; CHECK-NEXT: uunpklo z1.d, z1.s
; CHECK-NEXT: uunpklo z3.d, z0.s
; CHECK-NEXT: uunpkhi z0.d, z0.s
; CHECK-NEXT: zip1 p2.s, p0.s, p1.s
; CHECK-NEXT: zip2 p0.s, p0.s, p1.s
; CHECK-NEXT: st1w { z3.d }, p2, [x0, z1.d, uxtw #2]
; CHECK-NEXT: st1w { z0.d }, p0, [x0, z2.d, uxtw #2]
; CHECK-NEXT: st1w { z0.s }, p0, [x0, z1.s, uxtw #2]
; CHECK-NEXT: ret
%ext = zext <vscale x 4 x i32> %indexes to <vscale x 4 x i64>
%ptrs = getelementptr float, float* %base, <vscale x 4 x i64> %ext
