[LSR][AArch64] Optimize chain generation based on legal addressing modes #94453

Merged 1 commit on Jun 10, 2024

72 changes: 58 additions & 14 deletions llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
@@ -1256,7 +1256,8 @@ static bool isAMCompletelyFolded(const TargetTransformInfo &TTI,
LSRUse::KindType Kind, MemAccessTy AccessTy,
GlobalValue *BaseGV, int64_t BaseOffset,
bool HasBaseReg, int64_t Scale,
Instruction *Fixup = nullptr);
Instruction *Fixup = nullptr,
int64_t ScalableOffset = 0);

static unsigned getSetupCost(const SCEV *Reg, unsigned Depth) {
if (isa<SCEVUnknown>(Reg) || isa<SCEVConstant>(Reg))
@@ -1675,16 +1676,18 @@ static bool isAMCompletelyFolded(const TargetTransformInfo &TTI,
LSRUse::KindType Kind, MemAccessTy AccessTy,
GlobalValue *BaseGV, int64_t BaseOffset,
bool HasBaseReg, int64_t Scale,
Instruction *Fixup/*= nullptr*/) {
Instruction *Fixup /* = nullptr */,
int64_t ScalableOffset) {
switch (Kind) {
case LSRUse::Address:
return TTI.isLegalAddressingMode(AccessTy.MemTy, BaseGV, BaseOffset,
HasBaseReg, Scale, AccessTy.AddrSpace, Fixup);
HasBaseReg, Scale, AccessTy.AddrSpace,
Fixup, ScalableOffset);

case LSRUse::ICmpZero:
// There's not even a target hook for querying whether it would be legal to
// fold a GV into an ICmp.
if (BaseGV)
if (BaseGV || ScalableOffset != 0)
return false;

// ICmp only has two operands; don't allow more than two non-trivial parts.
Expand Down Expand Up @@ -1715,11 +1718,12 @@ static bool isAMCompletelyFolded(const TargetTransformInfo &TTI,

case LSRUse::Basic:
// Only handle single-register values.
return !BaseGV && Scale == 0 && BaseOffset == 0;
return !BaseGV && Scale == 0 && BaseOffset == 0 && ScalableOffset == 0;

case LSRUse::Special:
// Special case Basic to handle -1 scales.
return !BaseGV && (Scale == 0 || Scale == -1) && BaseOffset == 0;
return !BaseGV && (Scale == 0 || Scale == -1) && BaseOffset == 0 &&
ScalableOffset == 0;
}

llvm_unreachable("Invalid LSRUse Kind!");
@@ -1843,7 +1847,7 @@ static InstructionCost getScalingFactorCost(const TargetTransformInfo &TTI,
static bool isAlwaysFoldable(const TargetTransformInfo &TTI,
LSRUse::KindType Kind, MemAccessTy AccessTy,
GlobalValue *BaseGV, int64_t BaseOffset,
bool HasBaseReg) {
bool HasBaseReg, int64_t ScalableOffset = 0) {
// Fast-path: zero is always foldable.
if (BaseOffset == 0 && !BaseGV) return true;

@@ -1859,7 +1863,7 @@ static bool isAlwaysFoldable(const TargetTransformInfo &TTI,
}

return isAMCompletelyFolded(TTI, Kind, AccessTy, BaseGV, BaseOffset,
HasBaseReg, Scale);
HasBaseReg, Scale, nullptr, ScalableOffset);
}

static bool isAlwaysFoldable(const TargetTransformInfo &TTI,
@@ -3165,16 +3169,30 @@ void LSRInstance::FinalizeChain(IVChain &Chain) {
static bool canFoldIVIncExpr(const SCEV *IncExpr, Instruction *UserInst,
Value *Operand, const TargetTransformInfo &TTI) {
const SCEVConstant *IncConst = dyn_cast<SCEVConstant>(IncExpr);
if (!IncConst || !isAddressUse(TTI, UserInst, Operand))
return false;
int64_t IncOffset = 0;
int64_t ScalableOffset = 0;
if (IncConst) {
if (IncConst && IncConst->getAPInt().getSignificantBits() > 64)
return false;
IncOffset = IncConst->getValue()->getSExtValue();
} else {
// Look for mul(vscale, constant), to detect ScalableOffset.
auto *IncVScale = dyn_cast<SCEVMulExpr>(IncExpr);
if (!IncVScale || IncVScale->getNumOperands() != 2 ||
!isa<SCEVVScale>(IncVScale->getOperand(1)))
return false;
auto *Scale = dyn_cast<SCEVConstant>(IncVScale->getOperand(0));
if (!Scale || Scale->getType()->getScalarSizeInBits() > 64)
return false;
ScalableOffset = Scale->getValue()->getSExtValue();
}

if (IncConst->getAPInt().getSignificantBits() > 64)
if (!isAddressUse(TTI, UserInst, Operand))
return false;

MemAccessTy AccessTy = getAccessType(TTI, UserInst, Operand);
int64_t IncOffset = IncConst->getValue()->getSExtValue();
if (!isAlwaysFoldable(TTI, LSRUse::Address, AccessTy, /*BaseGV=*/nullptr,
IncOffset, /*HasBaseReg=*/false))
IncOffset, /*HasBaseReg=*/false, ScalableOffset))
return false;

return true;
@@ -3220,6 +3238,10 @@ void LSRInstance::GenerateIVChain(const IVChain &Chain,
Type *IVTy = IVSrc->getType();
Type *IntTy = SE.getEffectiveSCEVType(IVTy);
const SCEV *LeftOverExpr = nullptr;
const SCEV *Accum = SE.getZero(IntTy);
SmallVector<std::pair<const SCEV *, Value *>> Bases;
Bases.emplace_back(Accum, IVSrc);

for (const IVInc &Inc : Chain) {
Instruction *InsertPt = Inc.UserInst;
if (isa<PHINode>(InsertPt))
@@ -3232,10 +3254,31 @@ void LSRInstance::GenerateIVChain(const IVChain &Chain,
// IncExpr was the result of subtraction of two narrow values, so must
// be signed.
const SCEV *IncExpr = SE.getNoopOrSignExtend(Inc.IncExpr, IntTy);
Accum = SE.getAddExpr(Accum, IncExpr);
LeftOverExpr = LeftOverExpr ?
SE.getAddExpr(LeftOverExpr, IncExpr) : IncExpr;
}
if (LeftOverExpr && !LeftOverExpr->isZero()) {

// Look through each base to see if any can produce a nice addressing mode.
bool FoundBase = false;
for (auto [MapScev, MapIVOper] : reverse(Bases)) {
const SCEV *Remainder = SE.getMinusSCEV(Accum, MapScev);
if (canFoldIVIncExpr(Remainder, Inc.UserInst, Inc.IVOperand, TTI)) {
if (!Remainder->isZero()) {
Rewriter.clearPostInc();
Value *IncV = Rewriter.expandCodeFor(Remainder, IntTy, InsertPt);
const SCEV *IVOperExpr =
SE.getAddExpr(SE.getUnknown(MapIVOper), SE.getUnknown(IncV));
IVOper = Rewriter.expandCodeFor(IVOperExpr, IVTy, InsertPt);
} else {
IVOper = MapIVOper;
}

FoundBase = true;
break;
}
}
if (!FoundBase && LeftOverExpr && !LeftOverExpr->isZero()) {
// Expand the IV increment.
Rewriter.clearPostInc();
Value *IncV = Rewriter.expandCodeFor(LeftOverExpr, IntTy, InsertPt);
@@ -3246,6 +3289,7 @@ void LSRInstance::GenerateIVChain(const IVChain &Chain,
// If an IV increment can't be folded, use it as the next IV value.
if (!canFoldIVIncExpr(LeftOverExpr, Inc.UserInst, Inc.IVOperand, TTI)) {
assert(IVTy == IVOper->getType() && "inconsistent IV increment type");
Bases.emplace_back(Accum, IVOper);
IVSrc = IVOper;
LeftOverExpr = nullptr;
}
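In short, the patch teaches LSR's IV-chain handling about scalable offsets: isAMCompletelyFolded and isAlwaysFoldable gain a ScalableOffset parameter that is forwarded to TTI.isLegalAddressingMode, canFoldIVIncExpr now recognises increments of the form C * vscale, and GenerateIVChain remembers earlier chain bases so each increment can be rewritten against whichever base gives a legal addressing mode. A minimal sketch of the increment pattern match, assuming only the ScalarEvolution classes already used in the diff above (SCEVConstant, SCEVMulExpr, SCEVVScale); the helper name matchChainIncrement is illustrative and not part of the patch:

#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include <cstdint>
using namespace llvm;

// Sketch of the increment shapes canFoldIVIncExpr accepts after this change:
// a plain constant keeps the existing fixed-offset path, while a
// (C * vscale) increment is reported as a scalable offset instead.
static bool matchChainIncrement(const SCEV *IncExpr, int64_t &IncOffset,
                                int64_t &ScalableOffset) {
  IncOffset = 0;
  ScalableOffset = 0;
  if (auto *IncConst = dyn_cast<SCEVConstant>(IncExpr)) {
    if (IncConst->getAPInt().getSignificantBits() > 64)
      return false;
    IncOffset = IncConst->getValue()->getSExtValue();
    return true;
  }
  // Look for mul(constant, vscale), i.e. a VL-scaled step such as 16 * vscale.
  auto *Mul = dyn_cast<SCEVMulExpr>(IncExpr);
  if (!Mul || Mul->getNumOperands() != 2 ||
      !isa<SCEVVScale>(Mul->getOperand(1)))
    return false;
  auto *Scale = dyn_cast<SCEVConstant>(Mul->getOperand(0));
  if (!Scale || Scale->getType()->getScalarSizeInBits() > 64)
    return false;
  ScalableOffset = Scale->getValue()->getSExtValue();
  return true;
}

Whichever offset is matched is then passed to isAlwaysFoldable and on to TTI.isLegalAddressingMode. The effect on AArch64 is visible in the test update below: scalable offsets that fit the SVE reg + imm, mul vl form are folded into the load/store addressing modes (for example [x16, #1, mul vl]) instead of being materialised through a separate addvl'd base register.
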
86 changes: 40 additions & 46 deletions llvm/test/CodeGen/AArch64/sve-lsrchain.ll
@@ -14,24 +14,22 @@ define void @test(ptr nocapture noundef readonly %kernel, i32 noundef %kw, float
; CHECK-NEXT: // %bb.2: // %for.body.us.preheader
; CHECK-NEXT: ptrue p0.h
; CHECK-NEXT: add x11, x2, x11, lsl #1
; CHECK-NEXT: mov x12, #-16 // =0xfffffffffffffff0
; CHECK-NEXT: ptrue p1.b
; CHECK-NEXT: mov w8, wzr
; CHECK-NEXT: ptrue p1.b
; CHECK-NEXT: mov x9, xzr
; CHECK-NEXT: mov w10, wzr
; CHECK-NEXT: addvl x12, x12, #1
; CHECK-NEXT: mov x13, #4 // =0x4
; CHECK-NEXT: mov x14, #8 // =0x8
; CHECK-NEXT: mov x12, #4 // =0x4
; CHECK-NEXT: mov x13, #8 // =0x8
; CHECK-NEXT: .LBB0_3: // %for.body.us
; CHECK-NEXT: // =>This Loop Header: Depth=1
; CHECK-NEXT: // Child Loop BB0_4 Depth 2
; CHECK-NEXT: add x15, x0, x9, lsl #2
; CHECK-NEXT: sbfiz x16, x8, #1, #32
; CHECK-NEXT: mov x17, x2
; CHECK-NEXT: ldp s0, s1, [x15]
; CHECK-NEXT: add x16, x16, #8
; CHECK-NEXT: ldp s2, s3, [x15, #8]
; CHECK-NEXT: ubfiz x15, x8, #1, #32
; CHECK-NEXT: add x14, x0, x9, lsl #2
; CHECK-NEXT: sbfiz x15, x8, #1, #32
; CHECK-NEXT: mov x16, x2
; CHECK-NEXT: ldp s0, s1, [x14]
; CHECK-NEXT: add x15, x15, #8
; CHECK-NEXT: ldp s2, s3, [x14, #8]
; CHECK-NEXT: ubfiz x14, x8, #1, #32
; CHECK-NEXT: fcvt h0, s0
; CHECK-NEXT: fcvt h1, s1
; CHECK-NEXT: fcvt h2, s2
@@ -43,56 +41,52 @@ define void @test(ptr nocapture noundef readonly %kernel, i32 noundef %kw, float
; CHECK-NEXT: .LBB0_4: // %for.cond.i.preheader.us
; CHECK-NEXT: // Parent Loop BB0_3 Depth=1
; CHECK-NEXT: // => This Inner Loop Header: Depth=2
; CHECK-NEXT: ld1b { z4.b }, p1/z, [x17, x15]
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x17]
; CHECK-NEXT: add x18, x17, x16
; CHECK-NEXT: add x3, x17, x15
; CHECK-NEXT: ld1b { z4.b }, p1/z, [x16, x14]
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x16]
; CHECK-NEXT: add x17, x16, x15
; CHECK-NEXT: add x18, x16, x14
; CHECK-NEXT: add x3, x17, #8
; CHECK-NEXT: add x4, x17, #16
; CHECK-NEXT: fmad z4.h, p0/m, z0.h, z5.h
; CHECK-NEXT: ld1b { z5.b }, p1/z, [x17, x16]
; CHECK-NEXT: ld1b { z5.b }, p1/z, [x16, x15]
; CHECK-NEXT: fmla z4.h, p0/m, z5.h, z1.h
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x18, x13, lsl #1]
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x17, x12, lsl #1]
; CHECK-NEXT: fmla z4.h, p0/m, z5.h, z2.h
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x18, x14, lsl #1]
; CHECK-NEXT: add x18, x18, #16
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x17, x13, lsl #1]
; CHECK-NEXT: fmla z4.h, p0/m, z5.h, z3.h
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x17, #1, mul vl]
; CHECK-NEXT: st1h { z4.h }, p0, [x17]
; CHECK-NEXT: ld1h { z4.h }, p0/z, [x3, #1, mul vl]
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x16, #1, mul vl]
; CHECK-NEXT: st1h { z4.h }, p0, [x16]
; CHECK-NEXT: ld1h { z4.h }, p0/z, [x18, #1, mul vl]
; CHECK-NEXT: fmad z4.h, p0/m, z0.h, z5.h
; CHECK-NEXT: ld1b { z5.b }, p1/z, [x18, x12]
; CHECK-NEXT: add x18, x18, x12
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x17, #1, mul vl]
; CHECK-NEXT: fmla z4.h, p0/m, z5.h, z1.h
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x18, x13, lsl #1]
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x3, #1, mul vl]
; CHECK-NEXT: fmla z4.h, p0/m, z5.h, z2.h
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x18, x14, lsl #1]
; CHECK-NEXT: add x18, x18, #16
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x4, #1, mul vl]
; CHECK-NEXT: fmla z4.h, p0/m, z5.h, z3.h
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x17, #2, mul vl]
; CHECK-NEXT: st1h { z4.h }, p0, [x17, #1, mul vl]
; CHECK-NEXT: ld1h { z4.h }, p0/z, [x3, #2, mul vl]
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x16, #2, mul vl]
; CHECK-NEXT: st1h { z4.h }, p0, [x16, #1, mul vl]
; CHECK-NEXT: ld1h { z4.h }, p0/z, [x18, #2, mul vl]
; CHECK-NEXT: fmad z4.h, p0/m, z0.h, z5.h
; CHECK-NEXT: ld1b { z5.b }, p1/z, [x18, x12]
; CHECK-NEXT: add x18, x18, x12
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x17, #2, mul vl]
; CHECK-NEXT: fmla z4.h, p0/m, z5.h, z1.h
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x18, x13, lsl #1]
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x3, #2, mul vl]
; CHECK-NEXT: fmla z4.h, p0/m, z5.h, z2.h
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x18, x14, lsl #1]
; CHECK-NEXT: add x18, x18, #16
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x4, #2, mul vl]
; CHECK-NEXT: fmla z4.h, p0/m, z5.h, z3.h
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x17, #3, mul vl]
; CHECK-NEXT: st1h { z4.h }, p0, [x17, #2, mul vl]
; CHECK-NEXT: ld1h { z4.h }, p0/z, [x3, #3, mul vl]
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x16, #3, mul vl]
; CHECK-NEXT: st1h { z4.h }, p0, [x16, #2, mul vl]
; CHECK-NEXT: ld1h { z4.h }, p0/z, [x18, #3, mul vl]
; CHECK-NEXT: fmad z4.h, p0/m, z0.h, z5.h
; CHECK-NEXT: ld1b { z5.b }, p1/z, [x18, x12]
; CHECK-NEXT: add x18, x18, x12
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x17, #3, mul vl]
; CHECK-NEXT: fmla z4.h, p0/m, z5.h, z1.h
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x18, x13, lsl #1]
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x3, #3, mul vl]
; CHECK-NEXT: fmla z4.h, p0/m, z5.h, z2.h
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x18, x14, lsl #1]
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x4, #3, mul vl]
; CHECK-NEXT: fmla z4.h, p0/m, z5.h, z3.h
; CHECK-NEXT: st1h { z4.h }, p0, [x17, #3, mul vl]
; CHECK-NEXT: addvl x17, x17, #4
; CHECK-NEXT: cmp x17, x11
; CHECK-NEXT: st1h { z4.h }, p0, [x16, #3, mul vl]
; CHECK-NEXT: addvl x16, x16, #4
; CHECK-NEXT: cmp x16, x11
; CHECK-NEXT: b.lo .LBB0_4
; CHECK-NEXT: // %bb.5: // %while.cond.i..exit_crit_edge.us
; CHECK-NEXT: // in Loop: Header=BB0_3 Depth=1