[AArch64][SVE] Add more folds to make use of gather/scatter with 32-bit indices

In AArch64ISelLowering.cpp this patch implements the following fold:

1) GEP (%ptr, SHL (stepvector(A) + splat(%offset), splat(B)))
   into GEP (%ptr + (%offset << B), step_vector(A << B))

The above transform simplifies the index operand so that it can be
expressed as a vector of i32 elements, which allows a single
gather/scatter instruction to be used instead of two.
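As a sanity check on the arithmetic behind the fold, here is a minimal standalone
sketch (not part of the patch; Base, Offset, Step, Shift and Scale are illustrative
values) confirming that the folded base and stride reproduce the per-lane addresses
of the original shifted index:

// Minimal standalone sketch: verify that the folded base and stride
// reproduce the per-lane addresses of the original
// shl(add(step(Step), splat(Offset)), splat(Shift)) index.
#include <cassert>
#include <cstdint>

int main() {
  const int64_t Base = 0x1000, Offset = 3, Step = 1, Shift = 2, Scale = 4;
  for (int64_t Lane = 0; Lane < 8; ++Lane) {
    // Original addressing: Base + ((Step*Lane + Offset) << Shift) * Scale.
    int64_t Orig = Base + (((Step * Lane + Offset) << Shift) * Scale);
    // Folded form: the splatted offset is folded into the scalar base and the
    // stride becomes Step << Shift; Scale is still applied per lane by the
    // gather/scatter addressing mode.
    int64_t NewBase = Base + ((Offset * Scale) << Shift);
    int64_t Folded = NewBase + ((Step << Shift) * Lane) * Scale;
    assert(Orig == Folded && "fold must preserve per-lane addresses");
  }
  return 0;
}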

Patch by Paul Walker (@paulwalker-arm).

Depends on D117900

Differential Revision: https://reviews.llvm.org/D118345
CarolineConcatto committed Feb 3, 2022
1 parent 8ada962 commit 961e954
Showing 2 changed files with 106 additions and 0 deletions.
23 changes: 23 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -16387,6 +16387,29 @@ static bool findMoreOptimalIndexType(const MaskedGatherScatterSDNode *N,
}
}

  // Index = shl(step(const) + splat(offset), splat(shift))
  if (Index.getOpcode() == ISD::SHL &&
      Index.getOperand(0).getOpcode() == ISD::ADD &&
      Index.getOperand(0).getOperand(0).getOpcode() == ISD::STEP_VECTOR) {
    SDValue Add = Index.getOperand(0);
    SDValue ShiftOp = Index.getOperand(1);
    SDValue StepOp = Add.getOperand(0);
    SDValue OffsetOp = Add.getOperand(1);
    if (auto *Shift =
            dyn_cast_or_null<ConstantSDNode>(DAG.getSplatValue(ShiftOp)))
      if (auto Offset = DAG.getSplatValue(OffsetOp)) {
        int64_t Step =
            cast<ConstantSDNode>(StepOp.getOperand(0))->getSExtValue();
        // Stride does not scale explicitly by 'Scale', because it happens in
        // the gather/scatter addressing mode.
        Stride = Step << Shift->getSExtValue();
        // BasePtr = BasePtr + ((Offset * Scale) << Shift)
        Offset = DAG.getNode(ISD::MUL, DL, MVT::i64, Offset, N->getScale());
        Offset = DAG.getNode(ISD::SHL, DL, MVT::i64, Offset, SDValue(Shift, 0));
        BasePtr = DAG.getNode(ISD::ADD, DL, MVT::i64, BasePtr, Offset);
      }
  }

  // Return early because no supported pattern is found.
  if (Stride == 0)
    return false;
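For context, a scalar loop of roughly the shape below (purely illustrative; it is
not taken from the patch or its tests) is the kind of source that, once vectorized
for SVE, produces the shl(add(step_vector, splat(offset)), splat(shift)) index this
combine matches; the row layout and byte stride of 8 mirror the [8 x i8] tests in
the next file:

#include <cstdint>

// Illustrative only: each iteration loads the first byte of an 8-byte row,
// i.e. from byte offset (offset + i) * 8, giving a gathered index of the
// form step_vector + splat(offset), scaled by a constant.
void gather_like(int8_t *dst, const int8_t (*rows)[8], int64_t offset,
                 int64_t n) {
  for (int64_t i = 0; i < n; ++i)
    dst[i] = rows[offset + i][0];
}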
83 changes: 83 additions & 0 deletions llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll
@@ -201,9 +201,92 @@ define void @scatter_i8_index_stride_too_big(i8* %base, i64 %offset, <vscale x 4
ret void
}

; Ensure the resulting load is "vscale x 4" wide, despite the offset giving the
; impression the gather must be split due to its <vscale x 4 x i64> offset.
; gather_i8(base, index(offset, 8 * sizeof(i8)))
define <vscale x 4 x i8> @gather_8i8_index_offset_8([8 x i8]* %base, i64 %offset, <vscale x 4 x i1> %pg) #0 {
; CHECK-LABEL: gather_8i8_index_offset_8:
; CHECK: // %bb.0:
; CHECK-NEXT: add x8, x0, x1, lsl #3
; CHECK-NEXT: index z0.s, #0, #8
; CHECK-NEXT: ld1sb { z0.s }, p0/z, [x8, z0.s, sxtw]
; CHECK-NEXT: ret
%t0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
%t1 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
%step = call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
%t2 = add <vscale x 4 x i64> %t1, %step
%t3 = getelementptr [8 x i8], [8 x i8]* %base, <vscale x 4 x i64> %t2
%t4 = bitcast <vscale x 4 x [8 x i8]*> %t3 to <vscale x 4 x i8*>
%load = call <vscale x 4 x i8> @llvm.masked.gather.nxv4i8(<vscale x 4 x i8*> %t4, i32 4, <vscale x 4 x i1> %pg, <vscale x 4 x i8> undef)
ret <vscale x 4 x i8> %load
}
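; For the test above, the splatted offset is folded into the scalar base
; ("add x8, x0, x1, lsl #3" computes base + offset * 8) and the index is
; materialised as a 32-bit step vector of stride 8 ("index z0.s, #0, #8"),
; so a single gather using sign-extended 32-bit indices suffices.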

; Ensure the resulting load is "vscale x 4" wide, despite the offset giving the
; impression the gather must be split due to its <vscale x 4 x i64> offset.
; gather_f32(base, index(offset, 8 * sizeof(float)))
define <vscale x 4 x float> @gather_f32_index_offset_8([8 x float]* %base, i64 %offset, <vscale x 4 x i1> %pg) #0 {
; CHECK-LABEL: gather_f32_index_offset_8:
; CHECK: // %bb.0:
; CHECK-NEXT: mov w8, #32
; CHECK-NEXT: add x9, x0, x1, lsl #5
; CHECK-NEXT: index z0.s, #0, w8
; CHECK-NEXT: ld1w { z0.s }, p0/z, [x9, z0.s, sxtw]
; CHECK-NEXT: ret
%t0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
%t1 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
%step = call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
%t2 = add <vscale x 4 x i64> %t1, %step
%t3 = getelementptr [8 x float], [8 x float]* %base, <vscale x 4 x i64> %t2
%t4 = bitcast <vscale x 4 x [8 x float]*> %t3 to <vscale x 4 x float*>
%load = call <vscale x 4 x float> @llvm.masked.gather.nxv4f32(<vscale x 4 x float*> %t4, i32 4, <vscale x 4 x i1> %pg, <vscale x 4 x float> undef)
ret <vscale x 4 x float> %load
}

; Ensure the resulting store is "vscale x 4" wide, despite the offset giving the
; impression the scatter must be split due to its <vscale x 4 x i64> offset.
; scatter_i8(base, index(offset, 8 * sizeof(i8)))
define void @scatter_i8_index_offset_8([8 x i8]* %base, i64 %offset, <vscale x 4 x i1> %pg, <vscale x 4 x i8> %data) #0 {
; CHECK-LABEL: scatter_i8_index_offset_8:
; CHECK: // %bb.0:
; CHECK-NEXT: add x8, x0, x1, lsl #3
; CHECK-NEXT: index z1.s, #0, #8
; CHECK-NEXT: st1b { z0.s }, p0, [x8, z1.s, sxtw]
; CHECK-NEXT: ret
%t0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
%t1 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
%step = call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
%t2 = add <vscale x 4 x i64> %t1, %step
%t3 = getelementptr [8 x i8], [8 x i8]* %base, <vscale x 4 x i64> %t2
%t4 = bitcast <vscale x 4 x [8 x i8]*> %t3 to <vscale x 4 x i8*>
call void @llvm.masked.scatter.nxv4i8(<vscale x 4 x i8> %data, <vscale x 4 x i8*> %t4, i32 2, <vscale x 4 x i1> %pg)
ret void
}

; Ensure the resulting store is "vscale x 4" wide, despite the offset giving the
; impression the scatter must be split due to its <vscale x 4 x i64> offset.
; scatter_f16(base, index(offset, 8 * sizeof(half)))
define void @scatter_f16_index_offset_8([8 x half]* %base, i64 %offset, <vscale x 4 x i1> %pg, <vscale x 4 x half> %data) #0 {
; CHECK-LABEL: scatter_f16_index_offset_8:
; CHECK: // %bb.0:
; CHECK-NEXT: mov w8, #16
; CHECK-NEXT: add x9, x0, x1, lsl #4
; CHECK-NEXT: index z1.s, #0, w8
; CHECK-NEXT: st1h { z0.s }, p0, [x9, z1.s, sxtw]
; CHECK-NEXT: ret
%t0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
%t1 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
%step = call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
%t2 = add <vscale x 4 x i64> %t1, %step
%t3 = getelementptr [8 x half], [8 x half]* %base, <vscale x 4 x i64> %t2
%t4 = bitcast <vscale x 4 x [8 x half]*> %t3 to <vscale x 4 x half*>
call void @llvm.masked.scatter.nxv4f16(<vscale x 4 x half> %data, <vscale x 4 x half*> %t4, i32 2, <vscale x 4 x i1> %pg)
ret void
}


attributes #0 = { "target-features"="+sve" vscale_range(1, 16) }

declare <vscale x 4 x float> @llvm.masked.gather.nxv4f32(<vscale x 4 x float*>, i32, <vscale x 4 x i1>, <vscale x 4 x float>)

declare <vscale x 4 x i8> @llvm.masked.gather.nxv4i8(<vscale x 4 x i8*>, i32, <vscale x 4 x i1>, <vscale x 4 x i8>)
declare void @llvm.masked.scatter.nxv4i8(<vscale x 4 x i8>, <vscale x 4 x i8*>, i32, <vscale x 4 x i1>)
