Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LSR][AArch64] Optimize chain generation based on legal addressing modes #94453

Merged
merged 1 commit into from
Jun 10, 2024

Conversation

davemgreen
Copy link
Collaborator

LSR will generate chains of related instructions with a known increment between them. With SVE, in the case of the test case, this can include increments like 'vscale * 16 + 8'. The idea of this patch is if we have a '+8' increment already calculated in the chain, we can generate a (legal) '+ vscale*16' addressing mode from it, allowing us to use the '[x16, #1, mul vl]' addressing mode instructions.

In order to do this we keep track of the known 'bases' when generating chains in GenerateIVChain, checking for each if the accumulated increment expression from the base neatly folds into a legal addressing mode. If they do not we fall back to the existing LeftOverExpr, whether it is legal or not.

This is mostly orthogonal to #88124, dealing with the generation of chains as opposed to rest of LSR. The existing vscale addressing mode work has greatly helped compared to the last time I looked at this, allowing us to check that the addressing modes are indeed legal.

…des.

LSR will generate chains of related instructions with a known increment between
them. With SVE, in the case of the test case, this can include increments like
'vscale * 16 + 8'.  The idea of this patch is if we have a '+8' increment
already calculated in the chain, we can generate a (legal) '+ vscale*16'
addressing mode from it, allowing us to use the '[x16, llvm#1, mul vl]' addressing
mode instructions.

In order to do this we keep track of the known 'bases' when generating chains
in GenerateIVChain, checking for each if the accumulated increment expression
neatly folds into a legal addressing mode. If they do not we fall back to the
existing LeftOverExpr, whether it is legal or not.

This is mostly orthogonal to llvm#88124, dealing with the generation of chains as
opposed to rest of LSR. The existing vscale addressing mode work has greatly
helped compared to the last time I looked at this, allowing us to check that
the addressing modes are indeed legal.
@llvmbot
Copy link
Collaborator

llvmbot commented Jun 5, 2024

@llvm/pr-subscribers-backend-aarch64

@llvm/pr-subscribers-llvm-transforms

Author: David Green (davemgreen)

Changes

LSR will generate chains of related instructions with a known increment between them. With SVE, in the case of the test case, this can include increments like 'vscale * 16 + 8'. The idea of this patch is if we have a '+8' increment already calculated in the chain, we can generate a (legal) '+ vscale*16' addressing mode from it, allowing us to use the '[x16, #1, mul vl]' addressing mode instructions.

In order to do this we keep track of the known 'bases' when generating chains in GenerateIVChain, checking for each if the accumulated increment expression from the base neatly folds into a legal addressing mode. If they do not we fall back to the existing LeftOverExpr, whether it is legal or not.

This is mostly orthogonal to #88124, dealing with the generation of chains as opposed to rest of LSR. The existing vscale addressing mode work has greatly helped compared to the last time I looked at this, allowing us to check that the addressing modes are indeed legal.


Full diff: https://github.com/llvm/llvm-project/pull/94453.diff

2 Files Affected:

  • (modified) llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp (+58-14)
  • (modified) llvm/test/CodeGen/AArch64/sve-lsrchain.ll (+40-46)
diff --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
index 35a17d6060c94..2ff3e5de7656d 100644
--- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
@@ -1256,7 +1256,8 @@ static bool isAMCompletelyFolded(const TargetTransformInfo &TTI,
                                  LSRUse::KindType Kind, MemAccessTy AccessTy,
                                  GlobalValue *BaseGV, int64_t BaseOffset,
                                  bool HasBaseReg, int64_t Scale,
-                                 Instruction *Fixup = nullptr);
+                                 Instruction *Fixup = nullptr,
+                                 int64_t ScalableOffset = 0);
 
 static unsigned getSetupCost(const SCEV *Reg, unsigned Depth) {
   if (isa<SCEVUnknown>(Reg) || isa<SCEVConstant>(Reg))
@@ -1675,16 +1676,18 @@ static bool isAMCompletelyFolded(const TargetTransformInfo &TTI,
                                  LSRUse::KindType Kind, MemAccessTy AccessTy,
                                  GlobalValue *BaseGV, int64_t BaseOffset,
                                  bool HasBaseReg, int64_t Scale,
-                                 Instruction *Fixup/*= nullptr*/) {
+                                 Instruction *Fixup /* = nullptr */,
+                                 int64_t ScalableOffset) {
   switch (Kind) {
   case LSRUse::Address:
     return TTI.isLegalAddressingMode(AccessTy.MemTy, BaseGV, BaseOffset,
-                                     HasBaseReg, Scale, AccessTy.AddrSpace, Fixup);
+                                     HasBaseReg, Scale, AccessTy.AddrSpace,
+                                     Fixup, ScalableOffset);
 
   case LSRUse::ICmpZero:
     // There's not even a target hook for querying whether it would be legal to
     // fold a GV into an ICmp.
-    if (BaseGV)
+    if (BaseGV || ScalableOffset != 0)
       return false;
 
     // ICmp only has two operands; don't allow more than two non-trivial parts.
@@ -1715,11 +1718,12 @@ static bool isAMCompletelyFolded(const TargetTransformInfo &TTI,
 
   case LSRUse::Basic:
     // Only handle single-register values.
-    return !BaseGV && Scale == 0 && BaseOffset == 0;
+    return !BaseGV && Scale == 0 && BaseOffset == 0 && ScalableOffset == 0;
 
   case LSRUse::Special:
     // Special case Basic to handle -1 scales.
-    return !BaseGV && (Scale == 0 || Scale == -1) && BaseOffset == 0;
+    return !BaseGV && (Scale == 0 || Scale == -1) && BaseOffset == 0 &&
+           ScalableOffset == 0;
   }
 
   llvm_unreachable("Invalid LSRUse Kind!");
@@ -1843,7 +1847,7 @@ static InstructionCost getScalingFactorCost(const TargetTransformInfo &TTI,
 static bool isAlwaysFoldable(const TargetTransformInfo &TTI,
                              LSRUse::KindType Kind, MemAccessTy AccessTy,
                              GlobalValue *BaseGV, int64_t BaseOffset,
-                             bool HasBaseReg) {
+                             bool HasBaseReg, int64_t ScalableOffset = 0) {
   // Fast-path: zero is always foldable.
   if (BaseOffset == 0 && !BaseGV) return true;
 
@@ -1859,7 +1863,7 @@ static bool isAlwaysFoldable(const TargetTransformInfo &TTI,
   }
 
   return isAMCompletelyFolded(TTI, Kind, AccessTy, BaseGV, BaseOffset,
-                              HasBaseReg, Scale);
+                              HasBaseReg, Scale, nullptr, ScalableOffset);
 }
 
 static bool isAlwaysFoldable(const TargetTransformInfo &TTI,
@@ -3165,16 +3169,30 @@ void LSRInstance::FinalizeChain(IVChain &Chain) {
 static bool canFoldIVIncExpr(const SCEV *IncExpr, Instruction *UserInst,
                              Value *Operand, const TargetTransformInfo &TTI) {
   const SCEVConstant *IncConst = dyn_cast<SCEVConstant>(IncExpr);
-  if (!IncConst || !isAddressUse(TTI, UserInst, Operand))
-    return false;
+  int64_t IncOffset = 0;
+  int64_t ScalableOffset = 0;
+  if (IncConst) {
+    if (IncConst && IncConst->getAPInt().getSignificantBits() > 64)
+      return false;
+    IncOffset = IncConst->getValue()->getSExtValue();
+  } else {
+    // Look for mul(vscale, constant), to detect ScalableOffset.
+    auto *IncVScale = dyn_cast<SCEVMulExpr>(IncExpr);
+    if (!IncVScale || IncVScale->getNumOperands() != 2 ||
+        !isa<SCEVVScale>(IncVScale->getOperand(1)))
+      return false;
+    auto *Scale = dyn_cast<SCEVConstant>(IncVScale->getOperand(0));
+    if (!Scale || Scale->getType()->getScalarSizeInBits() > 64)
+      return false;
+    ScalableOffset = Scale->getValue()->getSExtValue();
+  }
 
-  if (IncConst->getAPInt().getSignificantBits() > 64)
+  if (!isAddressUse(TTI, UserInst, Operand))
     return false;
 
   MemAccessTy AccessTy = getAccessType(TTI, UserInst, Operand);
-  int64_t IncOffset = IncConst->getValue()->getSExtValue();
   if (!isAlwaysFoldable(TTI, LSRUse::Address, AccessTy, /*BaseGV=*/nullptr,
-                        IncOffset, /*HasBaseReg=*/false))
+                        IncOffset, /*HasBaseReg=*/false, ScalableOffset))
     return false;
 
   return true;
@@ -3220,6 +3238,10 @@ void LSRInstance::GenerateIVChain(const IVChain &Chain,
   Type *IVTy = IVSrc->getType();
   Type *IntTy = SE.getEffectiveSCEVType(IVTy);
   const SCEV *LeftOverExpr = nullptr;
+  const SCEV *Accum = SE.getZero(IntTy);
+  SmallVector<std::pair<const SCEV *, Value *>> Bases;
+  Bases.emplace_back(Accum, IVSrc);
+
   for (const IVInc &Inc : Chain) {
     Instruction *InsertPt = Inc.UserInst;
     if (isa<PHINode>(InsertPt))
@@ -3232,10 +3254,31 @@ void LSRInstance::GenerateIVChain(const IVChain &Chain,
       // IncExpr was the result of subtraction of two narrow values, so must
       // be signed.
       const SCEV *IncExpr = SE.getNoopOrSignExtend(Inc.IncExpr, IntTy);
+      Accum = SE.getAddExpr(Accum, IncExpr);
       LeftOverExpr = LeftOverExpr ?
         SE.getAddExpr(LeftOverExpr, IncExpr) : IncExpr;
     }
-    if (LeftOverExpr && !LeftOverExpr->isZero()) {
+
+    // Look through each base to see if any can produce a nice addressing mode.
+    bool FoundBase = false;
+    for (auto [MapScev, MapIVOper] : reverse(Bases)) {
+      const SCEV *Remainder = SE.getMinusSCEV(Accum, MapScev);
+      if (canFoldIVIncExpr(Remainder, Inc.UserInst, Inc.IVOperand, TTI)) {
+        if (!Remainder->isZero()) {
+          Rewriter.clearPostInc();
+          Value *IncV = Rewriter.expandCodeFor(Remainder, IntTy, InsertPt);
+          const SCEV *IVOperExpr =
+              SE.getAddExpr(SE.getUnknown(MapIVOper), SE.getUnknown(IncV));
+          IVOper = Rewriter.expandCodeFor(IVOperExpr, IVTy, InsertPt);
+        } else {
+          IVOper = MapIVOper;
+        }
+
+        FoundBase = true;
+        break;
+      }
+    }
+    if (!FoundBase && LeftOverExpr && !LeftOverExpr->isZero()) {
       // Expand the IV increment.
       Rewriter.clearPostInc();
       Value *IncV = Rewriter.expandCodeFor(LeftOverExpr, IntTy, InsertPt);
@@ -3246,6 +3289,7 @@ void LSRInstance::GenerateIVChain(const IVChain &Chain,
       // If an IV increment can't be folded, use it as the next IV value.
       if (!canFoldIVIncExpr(LeftOverExpr, Inc.UserInst, Inc.IVOperand, TTI)) {
         assert(IVTy == IVOper->getType() && "inconsistent IV increment type");
+        Bases.emplace_back(Accum, IVOper);
         IVSrc = IVOper;
         LeftOverExpr = nullptr;
       }
diff --git a/llvm/test/CodeGen/AArch64/sve-lsrchain.ll b/llvm/test/CodeGen/AArch64/sve-lsrchain.ll
index 9c7bffb921ce2..1931cfc2ef51d 100644
--- a/llvm/test/CodeGen/AArch64/sve-lsrchain.ll
+++ b/llvm/test/CodeGen/AArch64/sve-lsrchain.ll
@@ -14,24 +14,22 @@ define void @test(ptr nocapture noundef readonly %kernel, i32 noundef %kw, float
 ; CHECK-NEXT:  // %bb.2: // %for.body.us.preheader
 ; CHECK-NEXT:    ptrue p0.h
 ; CHECK-NEXT:    add x11, x2, x11, lsl #1
-; CHECK-NEXT:    mov x12, #-16 // =0xfffffffffffffff0
-; CHECK-NEXT:    ptrue p1.b
 ; CHECK-NEXT:    mov w8, wzr
+; CHECK-NEXT:    ptrue p1.b
 ; CHECK-NEXT:    mov x9, xzr
 ; CHECK-NEXT:    mov w10, wzr
-; CHECK-NEXT:    addvl x12, x12, #1
-; CHECK-NEXT:    mov x13, #4 // =0x4
-; CHECK-NEXT:    mov x14, #8 // =0x8
+; CHECK-NEXT:    mov x12, #4 // =0x4
+; CHECK-NEXT:    mov x13, #8 // =0x8
 ; CHECK-NEXT:  .LBB0_3: // %for.body.us
 ; CHECK-NEXT:    // =>This Loop Header: Depth=1
 ; CHECK-NEXT:    // Child Loop BB0_4 Depth 2
-; CHECK-NEXT:    add x15, x0, x9, lsl #2
-; CHECK-NEXT:    sbfiz x16, x8, #1, #32
-; CHECK-NEXT:    mov x17, x2
-; CHECK-NEXT:    ldp s0, s1, [x15]
-; CHECK-NEXT:    add x16, x16, #8
-; CHECK-NEXT:    ldp s2, s3, [x15, #8]
-; CHECK-NEXT:    ubfiz x15, x8, #1, #32
+; CHECK-NEXT:    add x14, x0, x9, lsl #2
+; CHECK-NEXT:    sbfiz x15, x8, #1, #32
+; CHECK-NEXT:    mov x16, x2
+; CHECK-NEXT:    ldp s0, s1, [x14]
+; CHECK-NEXT:    add x15, x15, #8
+; CHECK-NEXT:    ldp s2, s3, [x14, #8]
+; CHECK-NEXT:    ubfiz x14, x8, #1, #32
 ; CHECK-NEXT:    fcvt h0, s0
 ; CHECK-NEXT:    fcvt h1, s1
 ; CHECK-NEXT:    fcvt h2, s2
@@ -43,56 +41,52 @@ define void @test(ptr nocapture noundef readonly %kernel, i32 noundef %kw, float
 ; CHECK-NEXT:  .LBB0_4: // %for.cond.i.preheader.us
 ; CHECK-NEXT:    // Parent Loop BB0_3 Depth=1
 ; CHECK-NEXT:    // => This Inner Loop Header: Depth=2
-; CHECK-NEXT:    ld1b { z4.b }, p1/z, [x17, x15]
-; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x17]
-; CHECK-NEXT:    add x18, x17, x16
-; CHECK-NEXT:    add x3, x17, x15
+; CHECK-NEXT:    ld1b { z4.b }, p1/z, [x16, x14]
+; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x16]
+; CHECK-NEXT:    add x17, x16, x15
+; CHECK-NEXT:    add x18, x16, x14
+; CHECK-NEXT:    add x3, x17, #8
+; CHECK-NEXT:    add x4, x17, #16
 ; CHECK-NEXT:    fmad z4.h, p0/m, z0.h, z5.h
-; CHECK-NEXT:    ld1b { z5.b }, p1/z, [x17, x16]
+; CHECK-NEXT:    ld1b { z5.b }, p1/z, [x16, x15]
 ; CHECK-NEXT:    fmla z4.h, p0/m, z5.h, z1.h
-; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x18, x13, lsl #1]
+; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x17, x12, lsl #1]
 ; CHECK-NEXT:    fmla z4.h, p0/m, z5.h, z2.h
-; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x18, x14, lsl #1]
-; CHECK-NEXT:    add x18, x18, #16
+; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x17, x13, lsl #1]
 ; CHECK-NEXT:    fmla z4.h, p0/m, z5.h, z3.h
-; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x17, #1, mul vl]
-; CHECK-NEXT:    st1h { z4.h }, p0, [x17]
-; CHECK-NEXT:    ld1h { z4.h }, p0/z, [x3, #1, mul vl]
+; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x16, #1, mul vl]
+; CHECK-NEXT:    st1h { z4.h }, p0, [x16]
+; CHECK-NEXT:    ld1h { z4.h }, p0/z, [x18, #1, mul vl]
 ; CHECK-NEXT:    fmad z4.h, p0/m, z0.h, z5.h
-; CHECK-NEXT:    ld1b { z5.b }, p1/z, [x18, x12]
-; CHECK-NEXT:    add x18, x18, x12
+; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x17, #1, mul vl]
 ; CHECK-NEXT:    fmla z4.h, p0/m, z5.h, z1.h
-; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x18, x13, lsl #1]
+; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x3, #1, mul vl]
 ; CHECK-NEXT:    fmla z4.h, p0/m, z5.h, z2.h
-; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x18, x14, lsl #1]
-; CHECK-NEXT:    add x18, x18, #16
+; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x4, #1, mul vl]
 ; CHECK-NEXT:    fmla z4.h, p0/m, z5.h, z3.h
-; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x17, #2, mul vl]
-; CHECK-NEXT:    st1h { z4.h }, p0, [x17, #1, mul vl]
-; CHECK-NEXT:    ld1h { z4.h }, p0/z, [x3, #2, mul vl]
+; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x16, #2, mul vl]
+; CHECK-NEXT:    st1h { z4.h }, p0, [x16, #1, mul vl]
+; CHECK-NEXT:    ld1h { z4.h }, p0/z, [x18, #2, mul vl]
 ; CHECK-NEXT:    fmad z4.h, p0/m, z0.h, z5.h
-; CHECK-NEXT:    ld1b { z5.b }, p1/z, [x18, x12]
-; CHECK-NEXT:    add x18, x18, x12
+; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x17, #2, mul vl]
 ; CHECK-NEXT:    fmla z4.h, p0/m, z5.h, z1.h
-; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x18, x13, lsl #1]
+; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x3, #2, mul vl]
 ; CHECK-NEXT:    fmla z4.h, p0/m, z5.h, z2.h
-; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x18, x14, lsl #1]
-; CHECK-NEXT:    add x18, x18, #16
+; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x4, #2, mul vl]
 ; CHECK-NEXT:    fmla z4.h, p0/m, z5.h, z3.h
-; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x17, #3, mul vl]
-; CHECK-NEXT:    st1h { z4.h }, p0, [x17, #2, mul vl]
-; CHECK-NEXT:    ld1h { z4.h }, p0/z, [x3, #3, mul vl]
+; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x16, #3, mul vl]
+; CHECK-NEXT:    st1h { z4.h }, p0, [x16, #2, mul vl]
+; CHECK-NEXT:    ld1h { z4.h }, p0/z, [x18, #3, mul vl]
 ; CHECK-NEXT:    fmad z4.h, p0/m, z0.h, z5.h
-; CHECK-NEXT:    ld1b { z5.b }, p1/z, [x18, x12]
-; CHECK-NEXT:    add x18, x18, x12
+; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x17, #3, mul vl]
 ; CHECK-NEXT:    fmla z4.h, p0/m, z5.h, z1.h
-; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x18, x13, lsl #1]
+; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x3, #3, mul vl]
 ; CHECK-NEXT:    fmla z4.h, p0/m, z5.h, z2.h
-; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x18, x14, lsl #1]
+; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x4, #3, mul vl]
 ; CHECK-NEXT:    fmla z4.h, p0/m, z5.h, z3.h
-; CHECK-NEXT:    st1h { z4.h }, p0, [x17, #3, mul vl]
-; CHECK-NEXT:    addvl x17, x17, #4
-; CHECK-NEXT:    cmp x17, x11
+; CHECK-NEXT:    st1h { z4.h }, p0, [x16, #3, mul vl]
+; CHECK-NEXT:    addvl x16, x16, #4
+; CHECK-NEXT:    cmp x16, x11
 ; CHECK-NEXT:    b.lo .LBB0_4
 ; CHECK-NEXT:  // %bb.5: // %while.cond.i..exit_crit_edge.us
 ; CHECK-NEXT:    // in Loop: Header=BB0_3 Depth=1

@davemgreen davemgreen merged commit c7308d4 into llvm:main Jun 10, 2024
10 checks passed
@davemgreen davemgreen deleted the gh-lsr-svechain branch June 10, 2024 19:35
Lukacma pushed a commit to Lukacma/llvm-project that referenced this pull request Jun 12, 2024
…des (llvm#94453)

LSR will generate chains of related instructions with a known increment
between them. With SVE, in the case of the test case, this can include
increments like 'vscale * 16 + 8'. The idea of this patch is if we have
a '+8' increment already calculated in the chain, we can generate a
(legal) '+ vscale*16' addressing mode from it, allowing us to use the
'[x16, llvm#1, mul vl]' addressing mode instructions.

In order to do this we keep track of the known 'bases' when generating
chains in GenerateIVChain, checking for each if the accumulated
increment expression from the base neatly folds into a legal addressing
mode. If they do not we fall back to the existing LeftOverExpr, whether
it is legal or not.

This is mostly orthogonal to llvm#88124, dealing with the generation of
chains as opposed to rest of LSR. The existing vscale addressing mode
work has greatly helped compared to the last time I looked at this,
allowing us to check that the addressing modes are indeed legal.
@HerrCai0907 HerrCai0907 mentioned this pull request Jun 13, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants