[AArch64][SVE] Handle more cases in findMoreOptimalIndexType.
This patch addresses @paulwalker-arm's comment on D117900 to update/write
the by-ref operands only if the function returns true. It also handles a
few more cases where a series of added offsets can be folded into the base
pointer, rather than just looking at a single offset.

Reviewed By: paulwalker-arm

Differential Revision: https://reviews.llvm.org/D119728
sdesmalen-arm committed Feb 28, 2022
1 parent ee95fe5 commit 201e368
Showing 2 changed files with 120 additions and 35 deletions.
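As an illustration of the new folding (a sketch only, written in the step/splat notation used by the comments in the patch, and following the scatter_f16_index_add_add test added below, where the offsets end up scaled by 16, the size of [8 x half]):

    BasePtr = %base
    Index   = step(1) + splat(%offset) + splat(%offset2)
      -> fold splat(%offset2) into the base
    BasePtr = %base + %offset2 * 16
    Index   = step(1) + splat(%offset)
      -> fold splat(%offset) into the base
    BasePtr = %base + (%offset + %offset2) * 16
    Index   = step(1)

Once the index is reduced to a plain step vector, the existing logic can narrow it to a 32-bit index. The test's CHECK lines reflect this: the two "add x9, ..., lsl #4" instructions build the folded base and "index z1.s, #0, w8" materialises the step.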
108 changes: 73 additions & 35 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -16476,55 +16476,90 @@ static SDValue performSTORECombine(SDNode *N,
   return SDValue();
 }
 
-// Analyse the specified address returning true if a more optimal addressing
-// mode is available. When returning true all parameters are updated to reflect
-// their recommended values.
-static bool findMoreOptimalIndexType(const MaskedGatherScatterSDNode *N,
-                                     SDValue &BasePtr, SDValue &Index,
-                                     ISD::MemIndexType &IndexType,
-                                     SelectionDAG &DAG) {
-  // Only consider element types that are pointer sized as smaller types can
-  // be easily promoted.
+/// \return true if part of the index was folded into the Base.
+static bool foldIndexIntoBase(SDValue &BasePtr, SDValue &Index, SDValue Scale,
+                              SDLoc DL, SelectionDAG &DAG) {
+  // This function assumes a vector of i64 indices.
   EVT IndexVT = Index.getValueType();
-  if (IndexVT.getVectorElementType() != MVT::i64 || IndexVT == MVT::nxv2i64)
+  if (!IndexVT.isVector() || IndexVT.getVectorElementType() != MVT::i64)
     return false;
 
-  int64_t Stride = 0;
-  SDLoc DL(N);
-  // Index = step(const) + splat(offset)
-  if (Index.getOpcode() == ISD::ADD &&
-      Index.getOperand(0).getOpcode() == ISD::STEP_VECTOR) {
-    SDValue StepVector = Index.getOperand(0);
+  // Simplify:
+  //   BasePtr = Ptr
+  //   Index = X + splat(Offset)
+  // ->
+  //   BasePtr = Ptr + Offset * scale.
+  //   Index = X
+  if (Index.getOpcode() == ISD::ADD) {
     if (auto Offset = DAG.getSplatValue(Index.getOperand(1))) {
-      Stride = cast<ConstantSDNode>(StepVector.getOperand(0))->getSExtValue();
-      Offset = DAG.getNode(ISD::MUL, DL, MVT::i64, Offset, N->getScale());
+      Offset = DAG.getNode(ISD::MUL, DL, MVT::i64, Offset, Scale);
       BasePtr = DAG.getNode(ISD::ADD, DL, MVT::i64, BasePtr, Offset);
+      Index = Index.getOperand(0);
+      return true;
     }
   }
 
-  // Index = shl((step(const) + splat(offset))), splat(shift))
+  // Simplify:
+  //   BasePtr = Ptr
+  //   Index = (X + splat(Offset)) << splat(Shift)
+  // ->
+  //   BasePtr = Ptr + (Offset << Shift) * scale)
+  //   Index = X << splat(shift)
   if (Index.getOpcode() == ISD::SHL &&
-      Index.getOperand(0).getOpcode() == ISD::ADD &&
-      Index.getOperand(0).getOperand(0).getOpcode() == ISD::STEP_VECTOR) {
+      Index.getOperand(0).getOpcode() == ISD::ADD) {
     SDValue Add = Index.getOperand(0);
     SDValue ShiftOp = Index.getOperand(1);
-    SDValue StepOp = Add.getOperand(0);
     SDValue OffsetOp = Add.getOperand(1);
-    if (auto *Shift =
-            dyn_cast_or_null<ConstantSDNode>(DAG.getSplatValue(ShiftOp)))
+    if (auto Shift = DAG.getSplatValue(ShiftOp))
       if (auto Offset = DAG.getSplatValue(OffsetOp)) {
-        int64_t Step =
-            cast<ConstantSDNode>(StepOp.getOperand(0))->getSExtValue();
-        // Stride does not scale explicitly by 'Scale', because it happens in
-        // the gather/scatter addressing mode.
-        Stride = Step << Shift->getSExtValue();
-        // BasePtr = BasePtr + ((Offset * Scale) << Shift)
-        Offset = DAG.getNode(ISD::MUL, DL, MVT::i64, Offset, N->getScale());
-        Offset = DAG.getNode(ISD::SHL, DL, MVT::i64, Offset, SDValue(Shift, 0));
+        Offset = DAG.getNode(ISD::SHL, DL, MVT::i64, Offset, Shift);
+        Offset = DAG.getNode(ISD::MUL, DL, MVT::i64, Offset, Scale);
         BasePtr = DAG.getNode(ISD::ADD, DL, MVT::i64, BasePtr, Offset);
+        Index = DAG.getNode(ISD::SHL, DL, Index.getValueType(),
+                            Add.getOperand(0), ShiftOp);
+        return true;
       }
   }
 
+  return false;
+}
+
+// Analyse the specified address returning true if a more optimal addressing
+// mode is available. When returning true all parameters are updated to reflect
+// their recommended values.
+static bool findMoreOptimalIndexType(const MaskedGatherScatterSDNode *N,
+                                     SDValue &BasePtr, SDValue &Index,
+                                     SelectionDAG &DAG) {
+  // Only consider element types that are pointer sized as smaller types can
+  // be easily promoted.
+  EVT IndexVT = Index.getValueType();
+  if (IndexVT.getVectorElementType() != MVT::i64 || IndexVT == MVT::nxv2i64)
+    return false;
+
+  // Try to iteratively fold parts of the index into the base pointer to
+  // simplify the index as much as possible.
+  SDValue NewBasePtr = BasePtr, NewIndex = Index;
+  while (foldIndexIntoBase(NewBasePtr, NewIndex, N->getScale(), SDLoc(N), DAG))
+    ;
+
+  // Match:
+  //   Index = step(const)
+  int64_t Stride = 0;
+  if (NewIndex.getOpcode() == ISD::STEP_VECTOR)
+    Stride = cast<ConstantSDNode>(NewIndex.getOperand(0))->getSExtValue();
+
+  // Match:
+  //   Index = step(const) << shift(const)
+  else if (NewIndex.getOpcode() == ISD::SHL &&
+           NewIndex.getOperand(0).getOpcode() == ISD::STEP_VECTOR) {
+    SDValue RHS = NewIndex.getOperand(1);
+    if (auto *Shift =
+            dyn_cast_or_null<ConstantSDNode>(DAG.getSplatValue(RHS))) {
+      int64_t Step = (int64_t)NewIndex.getOperand(0).getConstantOperandVal(0);
+      Stride = Step << Shift->getZExtValue();
+    }
+  }
+
   // Return early because no supported pattern is found.
   if (Stride == 0)
     return false;
@@ -16545,8 +16580,11 @@ static bool findMoreOptimalIndexType(const MaskedGatherScatterSDNode *N,
     return false;
 
   EVT NewIndexVT = IndexVT.changeVectorElementType(MVT::i32);
-  Index = DAG.getNode(ISD::STEP_VECTOR, DL, NewIndexVT,
-                      DAG.getTargetConstant(Stride, DL, MVT::i32));
+  // Stride does not scale explicitly by 'Scale', because it happens in
+  // the gather/scatter addressing mode.
+  Index = DAG.getNode(ISD::STEP_VECTOR, SDLoc(N), NewIndexVT,
+                      DAG.getTargetConstant(Stride, SDLoc(N), MVT::i32));
+  BasePtr = NewBasePtr;
   return true;
 }
 
@@ -16566,7 +16604,7 @@ static SDValue performMaskedGatherScatterCombine(
   SDValue BasePtr = MGS->getBasePtr();
   ISD::MemIndexType IndexType = MGS->getIndexType();
 
-  if (!findMoreOptimalIndexType(MGS, BasePtr, Index, IndexType, DAG))
+  if (!findMoreOptimalIndexType(MGS, BasePtr, Index, DAG))
     return SDValue();
 
   // Here we catch such cases early and change MGATHER's IndexType to allow
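The same iterative folding also handles the shifted form before the existing step(const) << shift(const) match runs. A sketch for the scatter_f16_index_add_add_mul test added below, whose multiply by a splat of 8 reaches the DAG as a shift by splat(3):

    BasePtr = %base
    Index   = (step(1) + splat(%offset) + splat(%offset2)) << splat(3)
      -> fold each splat into the base as (Offset << 3) * scale
    BasePtr = %base + ((%offset + %offset2) << 3) * 16
    Index   = step(1) << splat(3)

The remaining step(1) << splat(3) is then matched as step(const) << shift(const), and the test's CHECK lines show the result: the base is advanced with two "add x9, ..., lsl #7" instructions (128 = 8 * 16) and the index becomes "index z1.s, #0, w8" with w8 = 128, the element scale of 16 having been folded into the unscaled offsets used by the final st1h.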
47 changes: 47 additions & 0 deletions llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll
@@ -283,7 +283,54 @@ define void @scatter_f16_index_offset_8([8 x half]* %base, i64 %offset, <vscale
   ret void
 }
 
+; stepvector is hidden further behind GEP and two adds.
+define void @scatter_f16_index_add_add([8 x half]* %base, i64 %offset, i64 %offset2, <vscale x 4 x i1> %pg, <vscale x 4 x half> %data) #0 {
+; CHECK-LABEL: scatter_f16_index_add_add:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w8, #16
+; CHECK-NEXT:    add x9, x0, x2, lsl #4
+; CHECK-NEXT:    add x9, x9, x1, lsl #4
+; CHECK-NEXT:    index z1.s, #0, w8
+; CHECK-NEXT:    st1h { z0.s }, p0, [x9, z1.s, sxtw]
+; CHECK-NEXT:    ret
+  %splat.offset.ins = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
+  %splat.offset = shufflevector <vscale x 4 x i64> %splat.offset.ins, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %splat.offset2.ins = insertelement <vscale x 4 x i64> undef, i64 %offset2, i32 0
+  %splat.offset2 = shufflevector <vscale x 4 x i64> %splat.offset2.ins, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %step = call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
+  %add1 = add <vscale x 4 x i64> %splat.offset, %step
+  %add2 = add <vscale x 4 x i64> %add1, %splat.offset2
+  %gep = getelementptr [8 x half], [8 x half]* %base, <vscale x 4 x i64> %add2
+  %gep.bc = bitcast <vscale x 4 x [8 x half]*> %gep to <vscale x 4 x half*>
+  call void @llvm.masked.scatter.nxv4f16(<vscale x 4 x half> %data, <vscale x 4 x half*> %gep.bc, i32 2, <vscale x 4 x i1> %pg)
+  ret void
+}
+
+; stepvector is hidden further behind GEP two adds and a shift.
+define void @scatter_f16_index_add_add_mul([8 x half]* %base, i64 %offset, i64 %offset2, <vscale x 4 x i1> %pg, <vscale x 4 x half> %data) #0 {
+; CHECK-LABEL: scatter_f16_index_add_add_mul:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w8, #128
+; CHECK-NEXT:    add x9, x0, x2, lsl #7
+; CHECK-NEXT:    add x9, x9, x1, lsl #7
+; CHECK-NEXT:    index z1.s, #0, w8
+; CHECK-NEXT:    st1h { z0.s }, p0, [x9, z1.s, sxtw]
+; CHECK-NEXT:    ret
+  %splat.offset.ins = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
+  %splat.offset = shufflevector <vscale x 4 x i64> %splat.offset.ins, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %splat.offset2.ins = insertelement <vscale x 4 x i64> undef, i64 %offset2, i32 0
+  %splat.offset2 = shufflevector <vscale x 4 x i64> %splat.offset2.ins, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %step = call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
+  %add1 = add <vscale x 4 x i64> %splat.offset, %step
+  %add2 = add <vscale x 4 x i64> %add1, %splat.offset2
+  %splat.const8.ins = insertelement <vscale x 4 x i64> undef, i64 8, i32 0
+  %splat.const8 = shufflevector <vscale x 4 x i64> %splat.const8.ins, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %mul = mul <vscale x 4 x i64> %add2, %splat.const8
+  %gep = getelementptr [8 x half], [8 x half]* %base, <vscale x 4 x i64> %mul
+  %gep.bc = bitcast <vscale x 4 x [8 x half]*> %gep to <vscale x 4 x half*>
+  call void @llvm.masked.scatter.nxv4f16(<vscale x 4 x half> %data, <vscale x 4 x half*> %gep.bc, i32 2, <vscale x 4 x i1> %pg)
+  ret void
+}
 attributes #0 = { "target-features"="+sve" vscale_range(1, 16) }
 
 declare <vscale x 4 x float> @llvm.masked.gather.nxv4f32(<vscale x 4 x float*>, i32, <vscale x 4 x i1>, <vscale x 4 x float>)