[SVE][LSR] Teach LSR to enable simple scaled-index addressing mode generation for SVE.

Currently, Loop Strength Reduce does not handle loops with a scalable stride very well.

Take a loop vectorized with the scalable vector type <vscale x 8 x i16> for instance
(see the added test/CodeGen/AArch64/sve-lsr-scaled-index-addressing-mode.ll).

Memory accesses are incremented by "16*vscale", while the induction variable is incremented
by "8*vscale". The scaling factor "2" needs to be extracted to build the candidate formula,
i.e. "reg(%in) + 2*reg({0,+,(8*vscale)})", so that the addrec register reg({0,+,(8*vscale)})
can be reused between the Address and ICmpZero LSRUses, enabling selection of the optimal solution.
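
To make the arithmetic concrete, here is a minimal standalone C++ check of why the
extracted scale is 2 for this loop (illustration only, not LLVM code; the variable
names are invented): each <vscale x 8 x i16> access advances 16*vscale bytes per
iteration while the induction variable advances 8*vscale elements, and their exact
quotient equals the i16 element size in bytes.

#include <cassert>
#include <cstdint>

int main() {
  // vscale is unknown at compile time; the quotient is the same for any value.
  for (uint64_t VScale = 1; VScale <= 16; ++VScale) {
    uint64_t AccessStrideBytes = 16 * VScale; // bytes per <vscale x 8 x i16> access
    uint64_t IVStep = 8 * VScale;             // induction-variable step (elements)
    // The scale LSR must extract so the address can be written as
    // reg(%in) + Scale * reg({0,+,(8*vscale)}):
    assert(AccessStrideBytes % IVStep == 0);
    uint64_t Scale = AccessStrideBytes / IVStep;
    assert(Scale == 2); // == sizeof(i16), the element size in bytes
  }
  return 0;
}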

This patch allows LSR's getExactSDiv to recognize special cases like "C1*X*Y /s C2*X*Y"
and pull out "C1 /s C2" as the scaling factor whenever possible. Without this change, LSR
misses candidate formulae with the proper scale factor and cannot leverage the target's
scaled-index addressing mode.
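
The new special case can be summarized outside of LLVM as follows. This is a minimal
standalone sketch (the MulExpr struct and exactConstQuotient helper are hypothetical,
not the SCEV API): if both multiplies share the same non-constant factors and the
leading constants divide exactly, the quotient C1/C2 is the usable scale.

#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Toy stand-in for a SCEV multiply expression: a leading constant times a list
// of symbolic factors, e.g. 16*vscale is represented as {16, {"vscale"}}.
struct MulExpr {
  int64_t Const;
  std::vector<std::string> Factors;
};

// Returns C1 /s C2 for C1*X*Y /s C2*X*Y when the symbolic factors match and the
// division has no remainder; otherwise nothing (mirroring getExactSDiv's
// "exact division or null" contract).
std::optional<int64_t> exactConstQuotient(const MulExpr &LHS, const MulExpr &RHS) {
  if (LHS.Factors != RHS.Factors)
    return std::nullopt;
  if (RHS.Const == 0 || LHS.Const % RHS.Const != 0)
    return std::nullopt;
  return LHS.Const / RHS.Const;
}

For the loop above, exactConstQuotient({16, {"vscale"}}, {8, {"vscale"}}) yields 2,
which becomes the scale in the candidate formula.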

Note: This patch doesn't fully fix AArch64's isLegalAddressingMode for scalable
vectors, but it allows simple valid scales to pass through.
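
For illustration, a rough standalone sketch of what the relaxed check accepts
(hypothetical names, mirroring the AArch64ISelLowering.cpp change below): base
register, no immediate offset, and a scale that is either absent or equal to the
element size in bytes, which is the form SVE's scaled register-offset accesses such
as "ld1h { z1.h }, p0/z, [x0, x8, lsl #1]" can encode.

#include <cstdint>

struct AddrMode {
  bool HasBaseReg;
  int64_t BaseOffs;
  int64_t Scale;
};

// Accept base-register addressing with either no scale or a scale equal to the
// element size in bytes of the scalable vector being accessed.
bool isLegalScalableAddrMode(const AddrMode &AM, uint64_t ElemSizeBytes) {
  return AM.HasBaseReg && AM.BaseOffs == 0 &&
         (AM.Scale == 0 || (uint64_t)AM.Scale == ElemSizeBytes);
}

// For <vscale x 8 x i16>, ElemSizeBytes is 2, so only Scale 0 or 2 is accepted.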

Reviewed By: sdesmalen

Differential Revision: https://reviews.llvm.org/D103939
huihzhang committed Jun 14, 2021
1 parent 7a7c007 commit 1c096bf
Showing 4 changed files with 189 additions and 6 deletions.
8 changes: 6 additions & 2 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -11808,8 +11808,12 @@ bool AArch64TargetLowering::isLegalAddressingMode(const DataLayout &DL,
return false;

// FIXME: Update this method to support scalable addressing modes.
if (isa<ScalableVectorType>(Ty))
return AM.HasBaseReg && !AM.BaseOffs && !AM.Scale;
if (isa<ScalableVectorType>(Ty)) {
uint64_t VecElemNumBytes =
DL.getTypeSizeInBits(cast<VectorType>(Ty)->getElementType()) / 8;
return AM.HasBaseReg && !AM.BaseOffs &&
(AM.Scale == 0 || (uint64_t)AM.Scale == VecElemNumBytes);
}

// check reg + imm case:
// i.e., reg + 0, reg + imm9, reg + SIZE_IN_BYTES * uimm12
17 changes: 16 additions & 1 deletion llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
@@ -665,7 +665,7 @@ static bool isMulSExtable(const SCEVMulExpr *M, ScalarEvolution &SE) {

/// Return an expression for LHS /s RHS, if it can be determined and if the
/// remainder is known to be zero, or null otherwise. If IgnoreSignificantBits
/// is true, expressions like (X * Y) /s Y are simplified to Y, ignoring that
/// is true, expressions like (X * Y) /s Y are simplified to X, ignoring that
/// the multiplication may overflow, which is useful when the result will be
/// used in a context where the most significant bits are ignored.
static const SCEV *getExactSDiv(const SCEV *LHS, const SCEV *RHS,
@@ -733,6 +733,21 @@ static const SCEV *getExactSDiv(const SCEV *LHS, const SCEV *RHS,
// Check for a multiply operand that we can pull RHS out of.
if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(LHS)) {
if (IgnoreSignificantBits || isMulSExtable(Mul, SE)) {
// Handle special case C1*X*Y /s C2*X*Y.
if (const SCEVMulExpr *MulRHS = dyn_cast<SCEVMulExpr>(RHS)) {
if (IgnoreSignificantBits || isMulSExtable(MulRHS, SE)) {
const SCEVConstant *LC = dyn_cast<SCEVConstant>(Mul->getOperand(0));
const SCEVConstant *RC =
dyn_cast<SCEVConstant>(MulRHS->getOperand(0));
if (LC && RC) {
SmallVector<const SCEV *, 4> LOps(drop_begin(Mul->operands()));
SmallVector<const SCEV *, 4> ROps(drop_begin(MulRHS->operands()));
if (LOps == ROps)
return getExactSDiv(LC, RC, SE, IgnoreSignificantBits);
}
}
}

SmallVector<const SCEV *, 4> Ops;
bool Found = false;
for (const SCEV *S : Mul->operands()) {
5 changes: 2 additions & 3 deletions llvm/test/CodeGen/AArch64/sve-fold-vscale.ll
@@ -1,9 +1,8 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve -disable-lsr < %s | FileCheck %s
; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve < %s | FileCheck %s

; Check that vscale call is recognised by load/store reg/reg pattern and
; partially folded, with the rest pulled out of the loop. This requires LSR to
; be disabled, which is something that will be addressed at a later date.
; partially folded, with the rest pulled out of the loop.

define void @ld1w_reg_loop([32000 x i32]* %addr) {
; CHECK-LABEL: ld1w_reg_loop:
165 changes: 165 additions & 0 deletions llvm/test/CodeGen/AArch64/sve-lsr-scaled-index-addressing-mode.ll
@@ -0,0 +1,165 @@
; RUN: opt -S -loop-reduce < %s | FileCheck %s --check-prefix=IR
; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve < %s | FileCheck %s --check-prefix=ASM
; Note: To update this test, please run utils/update_test_checks.py and utils/update_llc_test_checks.py separately on opt/llc run line.

target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
target triple = "aarch64-linux-gnu"

; These tests check that the IR coming out of LSR does not cast input/output pointer from i16* to i8* type.
; And scaled-index addressing mode is leveraged in the generated assembly, i.e. ld1h { z1.h }, p0/z, [x0, x8, lsl #1].

define void @ld_st_nxv8i16(i16* %in, i16* %out) {
; IR-LABEL: @ld_st_nxv8i16(
; IR-NEXT: entry:
; IR-NEXT: br label [[LOOP_PH:%.*]]
; IR: loop.ph:
; IR-NEXT: [[P_VEC_SPLATINSERT:%.*]] = insertelement <vscale x 8 x i16> undef, i16 3, i32 0
; IR-NEXT: [[P_VEC_SPLAT:%.*]] = shufflevector <vscale x 8 x i16> [[P_VEC_SPLATINSERT]], <vscale x 8 x i16> undef, <vscale x 8 x i32> zeroinitializer
; IR-NEXT: [[VSCALE:%.*]] = call i64 @llvm.vscale.i64()
; IR-NEXT: [[SCALED_VF:%.*]] = shl i64 [[VSCALE]], 3
; IR-NEXT: br label [[LOOP:%.*]]
; IR: loop:
; IR-NEXT: [[INDVAR:%.*]] = phi i64 [ 0, [[LOOP_PH]] ], [ [[INDVAR_NEXT:%.*]], [[LOOP]] ]
; IR-NEXT: [[SCEVGEP2:%.*]] = getelementptr i16, i16* [[IN:%.*]], i64 [[INDVAR]]
; IR-NEXT: [[SCEVGEP23:%.*]] = bitcast i16* [[SCEVGEP2]] to <vscale x 8 x i16>*
; IR-NEXT: [[SCEVGEP:%.*]] = getelementptr i16, i16* [[OUT:%.*]], i64 [[INDVAR]]
; IR-NEXT: [[SCEVGEP1:%.*]] = bitcast i16* [[SCEVGEP]] to <vscale x 8 x i16>*
; IR-NEXT: [[VAL:%.*]] = load <vscale x 8 x i16>, <vscale x 8 x i16>* [[SCEVGEP23]], align 16
; IR-NEXT: [[ADDP_VEC:%.*]] = add <vscale x 8 x i16> [[VAL]], [[P_VEC_SPLAT]]
; IR-NEXT: store <vscale x 8 x i16> [[ADDP_VEC]], <vscale x 8 x i16>* [[SCEVGEP1]], align 16
; IR-NEXT: [[INDVAR_NEXT]] = add nsw i64 [[INDVAR]], [[SCALED_VF]]
; IR-NEXT: [[EXIT_COND:%.*]] = icmp eq i64 [[INDVAR_NEXT]], 1024
; IR-NEXT: br i1 [[EXIT_COND]], label [[LOOP_EXIT:%.*]], label [[LOOP]]
; IR: loop.exit:
; IR-NEXT: br label [[EXIT:%.*]]
; IR: exit:
; IR-NEXT: ret void
;
; ASM-LABEL: ld_st_nxv8i16:
; ASM: // %bb.0: // %entry
; ASM-NEXT: mov x8, xzr
; ASM-NEXT: mov z0.h, #3 // =0x3
; ASM-NEXT: cnth x9
; ASM-NEXT: ptrue p0.h
; ASM-NEXT: .LBB0_1: // %loop
; ASM-NEXT: // =>This Inner Loop Header: Depth=1
; ASM-NEXT: ld1h { z1.h }, p0/z, [x0, x8, lsl #1]
; ASM-NEXT: add z1.h, z1.h, z0.h
; ASM-NEXT: st1h { z1.h }, p0, [x1, x8, lsl #1]
; ASM-NEXT: add x8, x8, x9
; ASM-NEXT: cmp x8, #1024 // =1024
; ASM-NEXT: b.ne .LBB0_1
; ASM-NEXT: // %bb.2: // %exit
; ASM-NEXT: ret
entry:
br label %loop.ph

loop.ph:
%p_vec.splatinsert = insertelement <vscale x 8 x i16> undef, i16 3, i32 0
%p_vec.splat = shufflevector <vscale x 8 x i16> %p_vec.splatinsert, <vscale x 8 x i16> undef, <vscale x 8 x i32> zeroinitializer
%vscale = call i64 @llvm.vscale.i64()
%scaled_vf = shl i64 %vscale, 3
br label %loop

loop: ; preds = %loop, %loop.ph
%indvar = phi i64 [ 0, %loop.ph ], [ %indvar.next, %loop ]
%ptr.in = getelementptr inbounds i16, i16* %in, i64 %indvar
%ptr.out = getelementptr inbounds i16, i16* %out, i64 %indvar
%in.ptrcast = bitcast i16* %ptr.in to <vscale x 8 x i16>*
%out.ptrcast = bitcast i16* %ptr.out to <vscale x 8 x i16>*
%val = load <vscale x 8 x i16>, <vscale x 8 x i16>* %in.ptrcast, align 16
%addp_vec = add <vscale x 8 x i16> %val, %p_vec.splat
store <vscale x 8 x i16> %addp_vec, <vscale x 8 x i16>* %out.ptrcast, align 16
%indvar.next = add nsw i64 %indvar, %scaled_vf
%exit.cond = icmp eq i64 %indvar.next, 1024
br i1 %exit.cond, label %loop.exit, label %loop

loop.exit: ; preds = %loop
br label %exit

exit:
ret void
}

define void @masked_ld_st_nxv8i16(i16* %in, i16* %out, i64 %n) {
; IR-LABEL: @masked_ld_st_nxv8i16(
; IR-NEXT: entry:
; IR-NEXT: br label [[LOOP_PH:%.*]]
; IR: loop.ph:
; IR-NEXT: [[P_VEC_SPLATINSERT:%.*]] = insertelement <vscale x 8 x i16> undef, i16 3, i32 0
; IR-NEXT: [[P_VEC_SPLAT:%.*]] = shufflevector <vscale x 8 x i16> [[P_VEC_SPLATINSERT]], <vscale x 8 x i16> undef, <vscale x 8 x i32> zeroinitializer
; IR-NEXT: [[PTRUE_VEC_SPLATINSERT:%.*]] = insertelement <vscale x 8 x i1> undef, i1 true, i32 0
; IR-NEXT: [[PTRUE_VEC_SPLAT:%.*]] = shufflevector <vscale x 8 x i1> [[PTRUE_VEC_SPLATINSERT]], <vscale x 8 x i1> undef, <vscale x 8 x i32> zeroinitializer
; IR-NEXT: [[VSCALE:%.*]] = call i64 @llvm.vscale.i64()
; IR-NEXT: [[SCALED_VF:%.*]] = shl i64 [[VSCALE]], 3
; IR-NEXT: br label [[LOOP:%.*]]
; IR: loop:
; IR-NEXT: [[INDVAR:%.*]] = phi i64 [ 0, [[LOOP_PH]] ], [ [[INDVAR_NEXT:%.*]], [[LOOP]] ]
; IR-NEXT: [[SCEVGEP2:%.*]] = getelementptr i16, i16* [[IN:%.*]], i64 [[INDVAR]]
; IR-NEXT: [[SCEVGEP23:%.*]] = bitcast i16* [[SCEVGEP2]] to <vscale x 8 x i16>*
; IR-NEXT: [[SCEVGEP:%.*]] = getelementptr i16, i16* [[OUT:%.*]], i64 [[INDVAR]]
; IR-NEXT: [[SCEVGEP1:%.*]] = bitcast i16* [[SCEVGEP]] to <vscale x 8 x i16>*
; IR-NEXT: [[VAL:%.*]] = call <vscale x 8 x i16> @llvm.masked.load.nxv8i16.p0nxv8i16(<vscale x 8 x i16>* [[SCEVGEP23]], i32 4, <vscale x 8 x i1> [[PTRUE_VEC_SPLAT]], <vscale x 8 x i16> undef)
; IR-NEXT: [[ADDP_VEC:%.*]] = add <vscale x 8 x i16> [[VAL]], [[P_VEC_SPLAT]]
; IR-NEXT: call void @llvm.masked.store.nxv8i16.p0nxv8i16(<vscale x 8 x i16> [[ADDP_VEC]], <vscale x 8 x i16>* [[SCEVGEP1]], i32 4, <vscale x 8 x i1> [[PTRUE_VEC_SPLAT]])
; IR-NEXT: [[INDVAR_NEXT]] = add nsw i64 [[INDVAR]], [[SCALED_VF]]
; IR-NEXT: [[EXIT_COND:%.*]] = icmp eq i64 [[N:%.*]], [[INDVAR_NEXT]]
; IR-NEXT: br i1 [[EXIT_COND]], label [[LOOP_EXIT:%.*]], label [[LOOP]]
; IR: loop.exit:
; IR-NEXT: br label [[EXIT:%.*]]
; IR: exit:
; IR-NEXT: ret void
;
; ASM-LABEL: masked_ld_st_nxv8i16:
; ASM: // %bb.0: // %entry
; ASM-NEXT: mov x8, xzr
; ASM-NEXT: mov z0.h, #3 // =0x3
; ASM-NEXT: ptrue p0.h
; ASM-NEXT: cnth x9
; ASM-NEXT: .LBB1_1: // %loop
; ASM-NEXT: // =>This Inner Loop Header: Depth=1
; ASM-NEXT: ld1h { z1.h }, p0/z, [x0, x8, lsl #1]
; ASM-NEXT: add z1.h, z1.h, z0.h
; ASM-NEXT: st1h { z1.h }, p0, [x1, x8, lsl #1]
; ASM-NEXT: add x8, x8, x9
; ASM-NEXT: cmp x2, x8
; ASM-NEXT: b.ne .LBB1_1
; ASM-NEXT: // %bb.2: // %exit
; ASM-NEXT: ret
entry:
br label %loop.ph

loop.ph:
%p_vec.splatinsert = insertelement <vscale x 8 x i16> undef, i16 3, i32 0
%p_vec.splat = shufflevector <vscale x 8 x i16> %p_vec.splatinsert, <vscale x 8 x i16> undef, <vscale x 8 x i32> zeroinitializer
%ptrue_vec.splatinsert = insertelement <vscale x 8 x i1> undef, i1 true, i32 0
%ptrue_vec.splat = shufflevector <vscale x 8 x i1> %ptrue_vec.splatinsert, <vscale x 8 x i1> undef, <vscale x 8 x i32> zeroinitializer
%vscale = call i64 @llvm.vscale.i64()
%scaled_vf = shl i64 %vscale, 3
br label %loop

loop: ; preds = %loop, %loop.ph
%indvar = phi i64 [ 0, %loop.ph ], [ %indvar.next, %loop ]
%ptr.in = getelementptr inbounds i16, i16* %in, i64 %indvar
%ptr.out = getelementptr inbounds i16, i16* %out, i64 %indvar
%in.ptrcast = bitcast i16* %ptr.in to <vscale x 8 x i16>*
%out.ptrcast = bitcast i16* %ptr.out to <vscale x 8 x i16>*
%val = call <vscale x 8 x i16> @llvm.masked.load.nxv8i16.p0nxv8i16(<vscale x 8 x i16>* %in.ptrcast, i32 4, <vscale x 8 x i1> %ptrue_vec.splat, <vscale x 8 x i16> undef)
%addp_vec = add <vscale x 8 x i16> %val, %p_vec.splat
call void @llvm.masked.store.nxv8i16.p0nxv8i16(<vscale x 8 x i16> %addp_vec, <vscale x 8 x i16>* %out.ptrcast, i32 4, <vscale x 8 x i1> %ptrue_vec.splat)
%indvar.next = add nsw i64 %indvar, %scaled_vf
%exit.cond = icmp eq i64 %indvar.next, %n
br i1 %exit.cond, label %loop.exit, label %loop

loop.exit: ; preds = %loop
br label %exit

exit:
ret void
}

declare i64 @llvm.vscale.i64()

declare <vscale x 8 x i16> @llvm.masked.load.nxv8i16.p0nxv8i16(<vscale x 8 x i16>*, i32 immarg, <vscale x 8 x i1>, <vscale x 8 x i16>)

declare void @llvm.masked.store.nxv8i16.p0nxv8i16(<vscale x 8 x i16>, <vscale x 8 x i16>*, i32 immarg, <vscale x 8 x i1>)
