[InstCombine] Support GEP chains in foldCmpLoadFromIndexedGlobal() #157447
Currently this fold only supports a single GEP. However, in ptradd representation, it may be split across multiple GEPs. In particular, PR #151333 will split off constant offset GEPs.

To support this, add a new helper decomposeLinearExpression(), which decomposes a pointer into a linear expression of the form BasePtr + Index * Scale + Offset.

I plan to also extend this helper to look through mul/shl on the index and use it in more places that currently use collectOffset() to extract a single index * scale. This will make sure such optimizations are not affected by the ptradd migration.
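For context, a minimal sketch of how a caller might consume the new helper. matchIndexedGlobal is a hypothetical wrapper written for illustration only; the LinearExpression fields and the decomposeLinearExpression() signature are the ones added to Loads.h in this patch.

#include "llvm/Analysis/Loads.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/Value.h"

using namespace llvm;

// Hypothetical caller (not part of this patch): if Ptr decomposes to
// GV + Index * Scale + Offset with the constant offset contained within one
// stride, return the global and hand the pieces back. This mirrors the checks
// foldCmpLoadFromIndexedGlobal() performs after this change. For example,
//   %gep1 = getelementptr inbounds {i16, i16}, ptr @g, i32 %idx
//   %gep2 = getelementptr inbounds i8, ptr %gep1, i32 2
// decomposes to @g + %idx * 4 + 2.
static GlobalVariable *matchIndexedGlobal(const DataLayout &DL, Value *Ptr,
                                          Value *&Index, APInt &Scale,
                                          APInt &Offset) {
  LinearExpression Expr = decomposeLinearExpression(DL, Ptr);
  auto *GV = dyn_cast<GlobalVariable>(Expr.BasePtr);
  if (!GV || !Expr.Index || !Expr.Offset.ult(Expr.Scale))
    return nullptr;
  Index = Expr.Index;
  Scale = Expr.Scale;
  Offset = Expr.Offset;
  return GV;
}

This is only meant to show the intended contract of decomposeLinearExpression(); the actual fold additionally bounds the index width and requires the global to be a constant with a definitive initializer.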
@llvm/pr-subscribers-llvm-transforms @llvm/pr-subscribers-llvm-analysis
Author: Nikita Popov (nikic)
Full diff: https://github.com/llvm/llvm-project/pull/157447.diff
5 Files Affected:
diff --git a/llvm/include/llvm/Analysis/Loads.h b/llvm/include/llvm/Analysis/Loads.h
index 4cec8115709d6..4893c0c9934d2 100644
--- a/llvm/include/llvm/Analysis/Loads.h
+++ b/llvm/include/llvm/Analysis/Loads.h
@@ -13,7 +13,9 @@
#ifndef LLVM_ANALYSIS_LOADS_H
#define LLVM_ANALYSIS_LOADS_H
+#include "llvm/ADT/APInt.h"
#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/GEPNoWrapFlags.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Compiler.h"
@@ -193,6 +195,26 @@ LLVM_ABI bool canReplacePointersIfEqual(const Value *From, const Value *To,
const DataLayout &DL);
LLVM_ABI bool canReplacePointersInUseIfEqual(const Use &U, const Value *To,
const DataLayout &DL);
+
+/// Linear expression BasePtr + Index * Scale + Offset.
+/// Index, Scale and Offset all have the same bit width, which matches the
+/// pointer index size of BasePtr.
+/// Index may be nullptr if Scale is 0.
+struct LinearExpression {
+ Value *BasePtr;
+ Value *Index = nullptr;
+ APInt Scale;
+ APInt Offset;
+ GEPNoWrapFlags Flags = GEPNoWrapFlags::all();
+
+ LinearExpression(Value *BasePtr, unsigned BitWidth)
+ : BasePtr(BasePtr), Scale(BitWidth, 0), Offset(BitWidth, 0) {}
+};
+
+/// Decompose a pointer into a linear expression. This may look through
+/// multiple GEPs.
+LLVM_ABI LinearExpression decomposeLinearExpression(const DataLayout &DL,
+ Value *Ptr);
}
#endif
diff --git a/llvm/lib/Analysis/Loads.cpp b/llvm/lib/Analysis/Loads.cpp
index 47fdb41e8b445..0c4e3a2e3b233 100644
--- a/llvm/lib/Analysis/Loads.cpp
+++ b/llvm/lib/Analysis/Loads.cpp
@@ -21,6 +21,7 @@
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/GetElementPtrTypeIterator.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Operator.h"
@@ -876,3 +877,66 @@ bool llvm::isReadOnlyLoop(
}
return true;
}
+
+LinearExpression llvm::decomposeLinearExpression(const DataLayout &DL,
+ Value *Ptr) {
+ assert(Ptr->getType()->isPointerTy() && "Must be called with pointer arg");
+
+ unsigned BitWidth = DL.getIndexTypeSizeInBits(Ptr->getType());
+ LinearExpression Expr(Ptr, BitWidth);
+
+ while (true) {
+ auto *GEP = dyn_cast<GEPOperator>(Expr.BasePtr);
+ if (!GEP || GEP->getSourceElementType()->isScalableTy())
+ return Expr;
+
+ Value *VarIndex = nullptr;
+ for (Value *Index : GEP->indices()) {
+ if (isa<ConstantInt>(Index))
+ continue;
+ // Only allow a single variable index. We do not bother to handle the
+ // case of the same variable index appearing multiple times.
+ if (Expr.Index || VarIndex)
+ return Expr;
+ VarIndex = Index;
+ }
+
+ // Don't return non-canonical indexes.
+ if (VarIndex && !VarIndex->getType()->isIntegerTy(BitWidth))
+ return Expr;
+
+ // We have verified that we can fully handle this GEP, so we can update Expr
+ // members past this point.
+ Expr.BasePtr = GEP->getPointerOperand();
+ Expr.Flags = Expr.Flags.intersectForOffsetAdd(GEP->getNoWrapFlags());
+ for (gep_type_iterator GTI = gep_type_begin(GEP), GTE = gep_type_end(GEP);
+ GTI != GTE; ++GTI) {
+ Value *Index = GTI.getOperand();
+ if (auto *ConstOffset = dyn_cast<ConstantInt>(Index)) {
+ if (ConstOffset->isZero())
+ continue;
+ if (StructType *STy = GTI.getStructTypeOrNull()) {
+ unsigned ElementIdx = ConstOffset->getZExtValue();
+ const StructLayout *SL = DL.getStructLayout(STy);
+ Expr.Offset += SL->getElementOffset(ElementIdx);
+ continue;
+ }
+ // Truncate if type size exceeds index space.
+ APInt IndexedSize(BitWidth, GTI.getSequentialElementStride(DL),
+ /*isSigned=*/false,
+ /*implicitTrunc=*/true);
+ Expr.Offset += ConstOffset->getValue() * IndexedSize;
+ continue;
+ }
+
+ // FIXME: Also look through a mul/shl in the index.
+ assert(Expr.Index == nullptr && "Shouldn't have index yet");
+ Expr.Index = Index;
+ // Truncate if type size exceeds index space.
+ Expr.Scale = APInt(BitWidth, GTI.getSequentialElementStride(DL),
+ /*isSigned=*/false, /*implicitTrunc=*/true);
+ }
+ }
+
+ return Expr;
+}
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 01b0da3469c18..6a778726df729 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -19,6 +19,7 @@
#include "llvm/Analysis/CmpInstAnalysis.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/Analysis/Loads.h"
#include "llvm/Analysis/Utils/Local.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/ConstantRange.h"
@@ -110,30 +111,27 @@ static bool isSignTest(ICmpInst::Predicate &Pred, const APInt &C) {
/// If AndCst is non-null, then the loaded value is masked with that constant
/// before doing the comparison. This handles cases like "A[i]&4 == 0".
Instruction *InstCombinerImpl::foldCmpLoadFromIndexedGlobal(
- LoadInst *LI, GetElementPtrInst *GEP, GlobalVariable *GV, CmpInst &ICI,
- ConstantInt *AndCst) {
- if (LI->isVolatile() || !GV->isConstant() || !GV->hasDefinitiveInitializer())
+ LoadInst *LI, GetElementPtrInst *GEP, CmpInst &ICI, ConstantInt *AndCst) {
+ auto *GV = dyn_cast<GlobalVariable>(getUnderlyingObject(GEP));
+ if (LI->isVolatile() || !GV || !GV->isConstant() ||
+ !GV->hasDefinitiveInitializer())
return nullptr;
- Constant *Init = GV->getInitializer();
- TypeSize GlobalSize = DL.getTypeAllocSize(Init->getType());
Type *EltTy = LI->getType();
TypeSize EltSize = DL.getTypeStoreSize(EltTy);
if (EltSize.isScalable())
return nullptr;
- unsigned IndexBW = DL.getIndexTypeSizeInBits(GEP->getType());
- SmallMapVector<Value *, APInt, 4> VarOffsets;
- APInt ConstOffset(IndexBW, 0);
- if (!GEP->collectOffset(DL, IndexBW, VarOffsets, ConstOffset) ||
- VarOffsets.size() != 1 || IndexBW > 64)
+ LinearExpression Expr = decomposeLinearExpression(DL, GEP);
+ if (!Expr.Index || Expr.BasePtr != GV || Expr.Offset.getBitWidth() > 64)
return nullptr;
- Value *Idx = VarOffsets.front().first;
- const APInt &Stride = VarOffsets.front().second;
- // If the index type is non-canonical, wait for it to be canonicalized.
- if (Idx->getType()->getScalarSizeInBits() != IndexBW)
- return nullptr;
+ Constant *Init = GV->getInitializer();
+ TypeSize GlobalSize = DL.getTypeAllocSize(Init->getType());
+
+ Value *Idx = Expr.Index;
+ const APInt &Stride = Expr.Scale;
+ const APInt &ConstOffset = Expr.Offset;
// Allow an additional context offset, but only within the stride.
if (!ConstOffset.ult(Stride))
@@ -280,7 +278,7 @@ Instruction *InstCombinerImpl::foldCmpLoadFromIndexedGlobal(
// comparison is false if Idx was 0x80..00.
// We need to erase the highest countTrailingZeros(ElementSize) bits of Idx.
auto MaskIdx = [&](Value *Idx) {
- if (!GEP->isInBounds() && Stride.countr_zero() != 0) {
+ if (!Expr.Flags.isInBounds() && Stride.countr_zero() != 0) {
Value *Mask = Constant::getAllOnesValue(Idx->getType());
Mask = Builder.CreateLShr(Mask, Stride.countr_zero());
Idx = Builder.CreateAnd(Idx, Mask);
@@ -1958,10 +1956,8 @@ Instruction *InstCombinerImpl::foldICmpAndConstant(ICmpInst &Cmp,
if (auto *C2 = dyn_cast<ConstantInt>(Y))
if (auto *LI = dyn_cast<LoadInst>(X))
if (auto *GEP = dyn_cast<GetElementPtrInst>(LI->getOperand(0)))
- if (auto *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0)))
- if (Instruction *Res =
- foldCmpLoadFromIndexedGlobal(LI, GEP, GV, Cmp, C2))
- return Res;
+ if (Instruction *Res = foldCmpLoadFromIndexedGlobal(LI, GEP, Cmp, C2))
+ return Res;
if (!Cmp.isEquality())
return nullptr;
@@ -4314,10 +4310,9 @@ Instruction *InstCombinerImpl::foldICmpInstWithConstantNotInt(ICmpInst &I) {
// Try to optimize things like "A[i] > 4" to index computations.
if (GetElementPtrInst *GEP =
dyn_cast<GetElementPtrInst>(LHSI->getOperand(0)))
- if (GlobalVariable *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0)))
- if (Instruction *Res =
- foldCmpLoadFromIndexedGlobal(cast<LoadInst>(LHSI), GEP, GV, I))
- return Res;
+ if (Instruction *Res =
+ foldCmpLoadFromIndexedGlobal(cast<LoadInst>(LHSI), GEP, I))
+ return Res;
break;
}
@@ -8798,10 +8793,9 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) {
break;
case Instruction::Load:
if (auto *GEP = dyn_cast<GetElementPtrInst>(LHSI->getOperand(0)))
- if (auto *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0)))
- if (Instruction *Res = foldCmpLoadFromIndexedGlobal(
- cast<LoadInst>(LHSI), GEP, GV, I))
- return Res;
+ if (Instruction *Res =
+ foldCmpLoadFromIndexedGlobal(cast<LoadInst>(LHSI), GEP, I))
+ return Res;
break;
case Instruction::FPTrunc:
if (Instruction *NV = foldFCmpFpTrunc(I, *LHSI, *RHSC))
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 2340028ce93dc..8c64d6398eca9 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -710,7 +710,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
bool foldAllocaCmp(AllocaInst *Alloca);
Instruction *foldCmpLoadFromIndexedGlobal(LoadInst *LI,
GetElementPtrInst *GEP,
- GlobalVariable *GV, CmpInst &ICI,
+ CmpInst &ICI,
ConstantInt *AndCst = nullptr);
Instruction *foldFCmpIntToFPConst(FCmpInst &I, Instruction *LHSI,
Constant *RHSC);
diff --git a/llvm/test/Transforms/InstCombine/load-cmp.ll b/llvm/test/Transforms/InstCombine/load-cmp.ll
index 79faefbd5df56..f7d788110b1bc 100644
--- a/llvm/test/Transforms/InstCombine/load-cmp.ll
+++ b/llvm/test/Transforms/InstCombine/load-cmp.ll
@@ -419,6 +419,32 @@ define i1 @load_vs_array_type_mismatch_offset1(i32 %idx) {
ret i1 %cmp
}
+define i1 @load_vs_array_type_mismatch_offset1_separate_gep(i32 %idx) {
+; CHECK-LABEL: @load_vs_array_type_mismatch_offset1_separate_gep(
+; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IDX:%.*]], -3
+; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[TMP1]], 1
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+ %gep1 = getelementptr inbounds {i16, i16}, ptr @g_i16_1, i32 %idx
+ %gep2 = getelementptr inbounds i8, ptr %gep1, i32 2
+ %load = load i16, ptr %gep2
+ %cmp = icmp eq i16 %load, 0
+ ret i1 %cmp
+}
+
+define i1 @load_vs_array_type_mismatch_offset1_separate_gep_swapped(i32 %idx) {
+; CHECK-LABEL: @load_vs_array_type_mismatch_offset1_separate_gep_swapped(
+; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IDX:%.*]], -3
+; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[TMP1]], 1
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+ %gep1 = getelementptr inbounds i8, ptr @g_i16_1, i32 2
+ %gep2 = getelementptr inbounds {i16, i16}, ptr %gep1, i32 %idx
+ %load = load i16, ptr %gep2
+ %cmp = icmp eq i16 %load, 0
+ ret i1 %cmp
+}
+
@g_i16_2 = internal constant [8 x i16] [i16 1, i16 0, i16 0, i16 1, i16 1, i16 0, i16 0, i16 1]
; idx == 0 || idx == 2
@@ -554,3 +580,36 @@ entry:
%cond = icmp ult i32 %isOK, 5
ret i1 %cond
}
+
+define i1 @cmp_load_multiple_indices(i32 %idx, i32 %idx2) {
+; CHECK-LABEL: @cmp_load_multiple_indices(
+; CHECK-NEXT: [[GEP1:%.*]] = getelementptr inbounds i16, ptr @g_i16_1, i32 [[IDX:%.*]]
+; CHECK-NEXT: [[GEP2:%.*]] = getelementptr inbounds i16, ptr [[GEP1]], i32 [[IDX2:%.*]]
+; CHECK-NEXT: [[GEP3:%.*]] = getelementptr inbounds nuw i8, ptr [[GEP2]], i32 2
+; CHECK-NEXT: [[LOAD:%.*]] = load i16, ptr [[GEP3]], align 2
+; CHECK-NEXT: [[CMP:%.*]] = icmp eq i16 [[LOAD]], 0
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+ %gep1 = getelementptr inbounds i16, ptr @g_i16_1, i32 %idx
+ %gep2 = getelementptr inbounds i16, ptr %gep1, i32 %idx2
+ %gep3 = getelementptr inbounds i8, ptr %gep2, i32 2
+ %load = load i16, ptr %gep3
+ %cmp = icmp eq i16 %load, 0
+ ret i1 %cmp
+}
+
+define i1 @cmp_load_multiple_indices2(i32 %idx, i32 %idx2) {
+; CHECK-LABEL: @cmp_load_multiple_indices2(
+; CHECK-NEXT: [[GEP1_SPLIT:%.*]] = getelementptr inbounds [1 x i16], ptr @g_i16_1, i32 [[IDX:%.*]]
+; CHECK-NEXT: [[GEP1:%.*]] = getelementptr inbounds i16, ptr [[GEP1_SPLIT]], i32 [[IDX2:%.*]]
+; CHECK-NEXT: [[GEP2:%.*]] = getelementptr inbounds nuw i8, ptr [[GEP1]], i32 2
+; CHECK-NEXT: [[LOAD:%.*]] = load i16, ptr [[GEP2]], align 2
+; CHECK-NEXT: [[CMP:%.*]] = icmp eq i16 [[LOAD]], 0
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+ %gep1 = getelementptr inbounds [1 x i16], ptr @g_i16_1, i32 %idx, i32 %idx2
+ %gep2 = getelementptr inbounds i8, ptr %gep1, i32 2
+ %load = load i16, ptr %gep2
+ %cmp = icmp eq i16 %load, 0
+ ret i1 %cmp
+}
@zyw-bot mfuzz
LGTM