
Conversation

@nikic (Contributor) commented Sep 8, 2025

Currently the foldCmpLoadFromIndexedGlobal() fold only supports a single GEP between the load and the global. However, in ptradd representation, the offset computation may be split across multiple GEPs. In particular, PR #151333 will split off constant offset GEPs.

To support this, add a new helper decomposeLinearExpression(), which decomposes a pointer into a linear expression of the form BasePtr + Index * Scale + Offset.
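
A minimal usage sketch (not part of the patch; the LinearExpression fields and the decomposeLinearExpression() signature match the diff below, while inspectLoadPointer is a hypothetical caller):

// Sketch only: consuming the new helper for the pointer operand of a load.
// For "%p = getelementptr inbounds i32, ptr %base, i64 %i" followed by
// "%q = getelementptr inbounds i8, ptr %p, i64 8", decomposition of %q
// yields BasePtr = %base, Index = %i, Scale = 4, Offset = 8.
#include "llvm/Analysis/Loads.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Instructions.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;

void inspectLoadPointer(LoadInst &LI, const DataLayout &DL) {
  LinearExpression Expr = decomposeLinearExpression(DL, LI.getPointerOperand());
  errs() << "base: " << *Expr.BasePtr << "\n";
  if (Expr.Index) // Index is null when the address is just BasePtr + Offset.
    errs() << "index: " << *Expr.Index << ", scale: " << Expr.Scale << "\n";
  errs() << "offset: " << Expr.Offset << "\n";
}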

I plan to also extend this helper to look through mul/shl on the index and use it in more places that currently use collectOffset() to extract a single index * scale. This will make sure such optimizations are not affected by the ptradd migration.
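
For illustration, a sketch of how such a mul/shl look-through might be shaped; this is not part of this patch, assumes the LinearExpression fields added below, and ignores wrap-flag handling:

// Hypothetical follow-up: peel a shl off the index so that
//   %scaled = shl nuw nsw i64 %i, 3
//   %p = getelementptr inbounds i8, ptr %base, i64 %scaled
// decomposes with Index = %i and Scale = 8 instead of Index = %scaled.
#include "llvm/Analysis/Loads.h"
#include "llvm/IR/PatternMatch.h"

using namespace llvm;
using namespace llvm::PatternMatch;

static void peelShlFromIndex(LinearExpression &Expr) {
  const APInt *ShAmt;
  Value *X;
  if (Expr.Index && match(Expr.Index, m_Shl(m_Value(X), m_APInt(ShAmt)))) {
    // Fold the shift into the scale. A real implementation would also have
    // to check the shl's nuw/nsw flags before preserving the GEP flags.
    Expr.Index = X;
    Expr.Scale <<= ShAmt->getZExtValue();
  }
}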

@nikic requested review from davemgreen, dtcxzyw and XChy September 8, 2025 13:19
@llvmbot added the llvm:instcombine, llvm:analysis and llvm:transforms labels Sep 8, 2025
@llvmbot (Member) commented Sep 8, 2025

@llvm/pr-subscribers-llvm-transforms

@llvm/pr-subscribers-llvm-analysis

Author: Nikita Popov (nikic)

Changes

Currently this fold only supports a single GEP. However, in ptradd representation, it may be split across multiple GEPs. In particular, PR #151333 will split off constant offset GEPs.

To support this, add a new helper decomposeLinearExpression(), which decomposes a pointer into a linear expression of the form BasePtr + Index * Scale + Offset.

I plan to also extend this helper to look through mul/shl on the index and use it in more places that currently use collectOffset() to extract a single index * scale. This will make sure such optimizations are not affected by the ptradd migration.


Full diff: https://github.com/llvm/llvm-project/pull/157447.diff

5 Files Affected:

  • (modified) llvm/include/llvm/Analysis/Loads.h (+22)
  • (modified) llvm/lib/Analysis/Loads.cpp (+64)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp (+22-28)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineInternal.h (+1-1)
  • (modified) llvm/test/Transforms/InstCombine/load-cmp.ll (+59)
diff --git a/llvm/include/llvm/Analysis/Loads.h b/llvm/include/llvm/Analysis/Loads.h
index 4cec8115709d6..4893c0c9934d2 100644
--- a/llvm/include/llvm/Analysis/Loads.h
+++ b/llvm/include/llvm/Analysis/Loads.h
@@ -13,7 +13,9 @@
 #ifndef LLVM_ANALYSIS_LOADS_H
 #define LLVM_ANALYSIS_LOADS_H
 
+#include "llvm/ADT/APInt.h"
 #include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/GEPNoWrapFlags.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Compiler.h"
 
@@ -193,6 +195,26 @@ LLVM_ABI bool canReplacePointersIfEqual(const Value *From, const Value *To,
                                         const DataLayout &DL);
 LLVM_ABI bool canReplacePointersInUseIfEqual(const Use &U, const Value *To,
                                              const DataLayout &DL);
+
+/// Linear expression BasePtr + Index * Scale + Offset.
+/// Index, Scale and Offset all have the same bit width, which matches the
+/// pointer index size of BasePtr.
+/// Index may be nullptr if Scale is 0.
+struct LinearExpression {
+  Value *BasePtr;
+  Value *Index = nullptr;
+  APInt Scale;
+  APInt Offset;
+  GEPNoWrapFlags Flags = GEPNoWrapFlags::all();
+
+  LinearExpression(Value *BasePtr, unsigned BitWidth)
+      : BasePtr(BasePtr), Scale(BitWidth, 0), Offset(BitWidth, 0) {}
+};
+
+/// Decompose a pointer into a linear expression. This may look through
+/// multiple GEPs.
+LLVM_ABI LinearExpression decomposeLinearExpression(const DataLayout &DL,
+                                                    Value *Ptr);
 }
 
 #endif
diff --git a/llvm/lib/Analysis/Loads.cpp b/llvm/lib/Analysis/Loads.cpp
index 47fdb41e8b445..0c4e3a2e3b233 100644
--- a/llvm/lib/Analysis/Loads.cpp
+++ b/llvm/lib/Analysis/Loads.cpp
@@ -21,6 +21,7 @@
 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/IR/DataLayout.h"
+#include "llvm/IR/GetElementPtrTypeIterator.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Operator.h"
 
@@ -876,3 +877,66 @@ bool llvm::isReadOnlyLoop(
   }
   return true;
 }
+
+LinearExpression llvm::decomposeLinearExpression(const DataLayout &DL,
+                                                 Value *Ptr) {
+  assert(Ptr->getType()->isPointerTy() && "Must be called with pointer arg");
+
+  unsigned BitWidth = DL.getIndexTypeSizeInBits(Ptr->getType());
+  LinearExpression Expr(Ptr, BitWidth);
+
+  while (true) {
+    auto *GEP = dyn_cast<GEPOperator>(Expr.BasePtr);
+    if (!GEP || GEP->getSourceElementType()->isScalableTy())
+      return Expr;
+
+    Value *VarIndex = nullptr;
+    for (Value *Index : GEP->indices()) {
+      if (isa<ConstantInt>(Index))
+        continue;
+      // Only allow a single variable index. We do not bother to handle the
+      // case of the same variable index appearing multiple times.
+      if (Expr.Index || VarIndex)
+        return Expr;
+      VarIndex = Index;
+    }
+
+    // Don't return non-canonical indexes.
+    if (VarIndex && !VarIndex->getType()->isIntegerTy(BitWidth))
+      return Expr;
+
+    // We have verified that we can fully handle this GEP, so we can update Expr
+    // members past this point.
+    Expr.BasePtr = GEP->getPointerOperand();
+    Expr.Flags = Expr.Flags.intersectForOffsetAdd(GEP->getNoWrapFlags());
+    for (gep_type_iterator GTI = gep_type_begin(GEP), GTE = gep_type_end(GEP);
+         GTI != GTE; ++GTI) {
+      Value *Index = GTI.getOperand();
+      if (auto *ConstOffset = dyn_cast<ConstantInt>(Index)) {
+        if (ConstOffset->isZero())
+          continue;
+        if (StructType *STy = GTI.getStructTypeOrNull()) {
+          unsigned ElementIdx = ConstOffset->getZExtValue();
+          const StructLayout *SL = DL.getStructLayout(STy);
+          Expr.Offset += SL->getElementOffset(ElementIdx);
+          continue;
+        }
+        // Truncate if type size exceeds index space.
+        APInt IndexedSize(BitWidth, GTI.getSequentialElementStride(DL),
+                          /*isSigned=*/false,
+                          /*implicitTrunc=*/true);
+        Expr.Offset += ConstOffset->getValue() * IndexedSize;
+        continue;
+      }
+
+      // FIXME: Also look through a mul/shl in the index.
+      assert(Expr.Index == nullptr && "Shouldn't have index yet");
+      Expr.Index = Index;
+      // Truncate if type size exceeds index space.
+      Expr.Scale = APInt(BitWidth, GTI.getSequentialElementStride(DL),
+                         /*isSigned=*/false, /*implicitTrunc=*/true);
+    }
+  }
+
+  return Expr;
+}
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 01b0da3469c18..6a778726df729 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -19,6 +19,7 @@
 #include "llvm/Analysis/CmpInstAnalysis.h"
 #include "llvm/Analysis/ConstantFolding.h"
 #include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/Analysis/Loads.h"
 #include "llvm/Analysis/Utils/Local.h"
 #include "llvm/Analysis/VectorUtils.h"
 #include "llvm/IR/ConstantRange.h"
@@ -110,30 +111,27 @@ static bool isSignTest(ICmpInst::Predicate &Pred, const APInt &C) {
 /// If AndCst is non-null, then the loaded value is masked with that constant
 /// before doing the comparison. This handles cases like "A[i]&4 == 0".
 Instruction *InstCombinerImpl::foldCmpLoadFromIndexedGlobal(
-    LoadInst *LI, GetElementPtrInst *GEP, GlobalVariable *GV, CmpInst &ICI,
-    ConstantInt *AndCst) {
-  if (LI->isVolatile() || !GV->isConstant() || !GV->hasDefinitiveInitializer())
+    LoadInst *LI, GetElementPtrInst *GEP, CmpInst &ICI, ConstantInt *AndCst) {
+  auto *GV = dyn_cast<GlobalVariable>(getUnderlyingObject(GEP));
+  if (LI->isVolatile() || !GV || !GV->isConstant() ||
+      !GV->hasDefinitiveInitializer())
     return nullptr;
 
-  Constant *Init = GV->getInitializer();
-  TypeSize GlobalSize = DL.getTypeAllocSize(Init->getType());
   Type *EltTy = LI->getType();
   TypeSize EltSize = DL.getTypeStoreSize(EltTy);
   if (EltSize.isScalable())
     return nullptr;
 
-  unsigned IndexBW = DL.getIndexTypeSizeInBits(GEP->getType());
-  SmallMapVector<Value *, APInt, 4> VarOffsets;
-  APInt ConstOffset(IndexBW, 0);
-  if (!GEP->collectOffset(DL, IndexBW, VarOffsets, ConstOffset) ||
-      VarOffsets.size() != 1 || IndexBW > 64)
+  LinearExpression Expr = decomposeLinearExpression(DL, GEP);
+  if (!Expr.Index || Expr.BasePtr != GV || Expr.Offset.getBitWidth() > 64)
     return nullptr;
 
-  Value *Idx = VarOffsets.front().first;
-  const APInt &Stride = VarOffsets.front().second;
-  // If the index type is non-canonical, wait for it to be canonicalized.
-  if (Idx->getType()->getScalarSizeInBits() != IndexBW)
-    return nullptr;
+  Constant *Init = GV->getInitializer();
+  TypeSize GlobalSize = DL.getTypeAllocSize(Init->getType());
+
+  Value *Idx = Expr.Index;
+  const APInt &Stride = Expr.Scale;
+  const APInt &ConstOffset = Expr.Offset;
 
   // Allow an additional context offset, but only within the stride.
   if (!ConstOffset.ult(Stride))
@@ -280,7 +278,7 @@ Instruction *InstCombinerImpl::foldCmpLoadFromIndexedGlobal(
   // comparison is false if Idx was 0x80..00.
   // We need to erase the highest countTrailingZeros(ElementSize) bits of Idx.
   auto MaskIdx = [&](Value *Idx) {
-    if (!GEP->isInBounds() && Stride.countr_zero() != 0) {
+    if (!Expr.Flags.isInBounds() && Stride.countr_zero() != 0) {
       Value *Mask = Constant::getAllOnesValue(Idx->getType());
       Mask = Builder.CreateLShr(Mask, Stride.countr_zero());
       Idx = Builder.CreateAnd(Idx, Mask);
@@ -1958,10 +1956,8 @@ Instruction *InstCombinerImpl::foldICmpAndConstant(ICmpInst &Cmp,
   if (auto *C2 = dyn_cast<ConstantInt>(Y))
     if (auto *LI = dyn_cast<LoadInst>(X))
       if (auto *GEP = dyn_cast<GetElementPtrInst>(LI->getOperand(0)))
-        if (auto *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0)))
-          if (Instruction *Res =
-                  foldCmpLoadFromIndexedGlobal(LI, GEP, GV, Cmp, C2))
-            return Res;
+        if (Instruction *Res = foldCmpLoadFromIndexedGlobal(LI, GEP, Cmp, C2))
+          return Res;
 
   if (!Cmp.isEquality())
     return nullptr;
@@ -4314,10 +4310,9 @@ Instruction *InstCombinerImpl::foldICmpInstWithConstantNotInt(ICmpInst &I) {
     // Try to optimize things like "A[i] > 4" to index computations.
     if (GetElementPtrInst *GEP =
             dyn_cast<GetElementPtrInst>(LHSI->getOperand(0)))
-      if (GlobalVariable *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0)))
-        if (Instruction *Res =
-                foldCmpLoadFromIndexedGlobal(cast<LoadInst>(LHSI), GEP, GV, I))
-          return Res;
+      if (Instruction *Res =
+              foldCmpLoadFromIndexedGlobal(cast<LoadInst>(LHSI), GEP, I))
+        return Res;
     break;
   }
 
@@ -8798,10 +8793,9 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) {
       break;
     case Instruction::Load:
       if (auto *GEP = dyn_cast<GetElementPtrInst>(LHSI->getOperand(0)))
-        if (auto *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0)))
-          if (Instruction *Res = foldCmpLoadFromIndexedGlobal(
-                  cast<LoadInst>(LHSI), GEP, GV, I))
-            return Res;
+        if (Instruction *Res =
+                foldCmpLoadFromIndexedGlobal(cast<LoadInst>(LHSI), GEP, I))
+          return Res;
       break;
     case Instruction::FPTrunc:
       if (Instruction *NV = foldFCmpFpTrunc(I, *LHSI, *RHSC))
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 2340028ce93dc..8c64d6398eca9 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -710,7 +710,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
   bool foldAllocaCmp(AllocaInst *Alloca);
   Instruction *foldCmpLoadFromIndexedGlobal(LoadInst *LI,
                                             GetElementPtrInst *GEP,
-                                            GlobalVariable *GV, CmpInst &ICI,
+                                            CmpInst &ICI,
                                             ConstantInt *AndCst = nullptr);
   Instruction *foldFCmpIntToFPConst(FCmpInst &I, Instruction *LHSI,
                                     Constant *RHSC);
diff --git a/llvm/test/Transforms/InstCombine/load-cmp.ll b/llvm/test/Transforms/InstCombine/load-cmp.ll
index 79faefbd5df56..f7d788110b1bc 100644
--- a/llvm/test/Transforms/InstCombine/load-cmp.ll
+++ b/llvm/test/Transforms/InstCombine/load-cmp.ll
@@ -419,6 +419,32 @@ define i1 @load_vs_array_type_mismatch_offset1(i32 %idx) {
   ret i1 %cmp
 }
 
+define i1 @load_vs_array_type_mismatch_offset1_separate_gep(i32 %idx) {
+; CHECK-LABEL: @load_vs_array_type_mismatch_offset1_separate_gep(
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[IDX:%.*]], -3
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[TMP1]], 1
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %gep1 = getelementptr inbounds {i16, i16}, ptr @g_i16_1, i32 %idx
+  %gep2 = getelementptr inbounds i8, ptr %gep1, i32 2
+  %load = load i16, ptr %gep2
+  %cmp = icmp eq i16 %load, 0
+  ret i1 %cmp
+}
+
+define i1 @load_vs_array_type_mismatch_offset1_separate_gep_swapped(i32 %idx) {
+; CHECK-LABEL: @load_vs_array_type_mismatch_offset1_separate_gep_swapped(
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[IDX:%.*]], -3
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[TMP1]], 1
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %gep1 = getelementptr inbounds i8, ptr @g_i16_1, i32 2
+  %gep2 = getelementptr inbounds {i16, i16}, ptr %gep1, i32 %idx
+  %load = load i16, ptr %gep2
+  %cmp = icmp eq i16 %load, 0
+  ret i1 %cmp
+}
+
 @g_i16_2 = internal constant [8 x i16] [i16 1, i16 0, i16 0, i16 1, i16 1, i16 0, i16 0, i16 1]
 
 ; idx == 0 || idx == 2
@@ -554,3 +580,36 @@ entry:
   %cond = icmp ult i32 %isOK, 5
   ret i1 %cond
 }
+
+define i1 @cmp_load_multiple_indices(i32 %idx, i32 %idx2) {
+; CHECK-LABEL: @cmp_load_multiple_indices(
+; CHECK-NEXT:    [[GEP1:%.*]] = getelementptr inbounds i16, ptr @g_i16_1, i32 [[IDX:%.*]]
+; CHECK-NEXT:    [[GEP2:%.*]] = getelementptr inbounds i16, ptr [[GEP1]], i32 [[IDX2:%.*]]
+; CHECK-NEXT:    [[GEP3:%.*]] = getelementptr inbounds nuw i8, ptr [[GEP2]], i32 2
+; CHECK-NEXT:    [[LOAD:%.*]] = load i16, ptr [[GEP3]], align 2
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i16 [[LOAD]], 0
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %gep1 = getelementptr inbounds i16, ptr @g_i16_1, i32 %idx
+  %gep2 = getelementptr inbounds i16, ptr %gep1, i32 %idx2
+  %gep3 = getelementptr inbounds i8, ptr %gep2, i32 2
+  %load = load i16, ptr %gep3
+  %cmp = icmp eq i16 %load, 0
+  ret i1 %cmp
+}
+
+define i1 @cmp_load_multiple_indices2(i32 %idx, i32 %idx2) {
+; CHECK-LABEL: @cmp_load_multiple_indices2(
+; CHECK-NEXT:    [[GEP1_SPLIT:%.*]] = getelementptr inbounds [1 x i16], ptr @g_i16_1, i32 [[IDX:%.*]]
+; CHECK-NEXT:    [[GEP1:%.*]] = getelementptr inbounds i16, ptr [[GEP1_SPLIT]], i32 [[IDX2:%.*]]
+; CHECK-NEXT:    [[GEP2:%.*]] = getelementptr inbounds nuw i8, ptr [[GEP1]], i32 2
+; CHECK-NEXT:    [[LOAD:%.*]] = load i16, ptr [[GEP2]], align 2
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i16 [[LOAD]], 0
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %gep1 = getelementptr inbounds [1 x i16], ptr @g_i16_1, i32 %idx, i32 %idx2
+  %gep2 = getelementptr inbounds i8, ptr %gep1, i32 2
+  %load = load i16, ptr %gep2
+  %cmp = icmp eq i16 %load, 0
+  ret i1 %cmp
+}

@XChy (Member) commented Sep 9, 2025

@zyw-bot mfuzz

@XChy (Member) left a comment


LGTM

@nikic merged commit 1cbb35e into llvm:main Sep 9, 2025
13 checks passed
@nikic deleted the indexed-global-multi-gep branch September 9, 2025 14:50