diff --git a/llvm/include/llvm/Analysis/Loads.h b/llvm/include/llvm/Analysis/Loads.h
index 4cec8115709d6..4893c0c9934d2 100644
--- a/llvm/include/llvm/Analysis/Loads.h
+++ b/llvm/include/llvm/Analysis/Loads.h
@@ -13,7 +13,9 @@
 #ifndef LLVM_ANALYSIS_LOADS_H
 #define LLVM_ANALYSIS_LOADS_H
 
+#include "llvm/ADT/APInt.h"
 #include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/GEPNoWrapFlags.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Compiler.h"
 
@@ -193,6 +195,26 @@ LLVM_ABI bool canReplacePointersIfEqual(const Value *From, const Value *To,
                                         const DataLayout &DL);
 LLVM_ABI bool canReplacePointersInUseIfEqual(const Use &U, const Value *To,
                                              const DataLayout &DL);
+
+/// Linear expression BasePtr + Index * Scale + Offset.
+/// Index, Scale and Offset all have the same bit width, which matches the
+/// pointer index size of BasePtr.
+/// Index may be nullptr if Scale is 0.
+struct LinearExpression {
+  Value *BasePtr;
+  Value *Index = nullptr;
+  APInt Scale;
+  APInt Offset;
+  GEPNoWrapFlags Flags = GEPNoWrapFlags::all();
+
+  LinearExpression(Value *BasePtr, unsigned BitWidth)
+      : BasePtr(BasePtr), Scale(BitWidth, 0), Offset(BitWidth, 0) {}
+};
+
+/// Decompose a pointer into a linear expression. This may look through
+/// multiple GEPs.
+LLVM_ABI LinearExpression decomposeLinearExpression(const DataLayout &DL,
+                                                    Value *Ptr);
 }
 
 #endif
diff --git a/llvm/lib/Analysis/Loads.cpp b/llvm/lib/Analysis/Loads.cpp
index 47fdb41e8b445..0c4e3a2e3b233 100644
--- a/llvm/lib/Analysis/Loads.cpp
+++ b/llvm/lib/Analysis/Loads.cpp
@@ -21,6 +21,7 @@
 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/IR/DataLayout.h"
+#include "llvm/IR/GetElementPtrTypeIterator.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Operator.h"
 
@@ -876,3 +877,66 @@ bool llvm::isReadOnlyLoop(
   }
   return true;
 }
+
+LinearExpression llvm::decomposeLinearExpression(const DataLayout &DL,
+                                                 Value *Ptr) {
+  assert(Ptr->getType()->isPointerTy() && "Must be called with pointer arg");
+
+  unsigned BitWidth = DL.getIndexTypeSizeInBits(Ptr->getType());
+  LinearExpression Expr(Ptr, BitWidth);
+
+  while (true) {
+    auto *GEP = dyn_cast<GEPOperator>(Expr.BasePtr);
+    if (!GEP || GEP->getSourceElementType()->isScalableTy())
+      return Expr;
+
+    Value *VarIndex = nullptr;
+    for (Value *Index : GEP->indices()) {
+      if (isa<ConstantInt>(Index))
+        continue;
+      // Only allow a single variable index. We do not bother to handle the
+      // case of the same variable index appearing multiple times.
+      if (Expr.Index || VarIndex)
+        return Expr;
+      VarIndex = Index;
+    }
+
+    // Don't return non-canonical indexes.
+    if (VarIndex && !VarIndex->getType()->isIntegerTy(BitWidth))
+      return Expr;
+
+    // We have verified that we can fully handle this GEP, so we can update
+    // Expr members past this point.
+    Expr.BasePtr = GEP->getPointerOperand();
+    Expr.Flags = Expr.Flags.intersectForOffsetAdd(GEP->getNoWrapFlags());
+    for (gep_type_iterator GTI = gep_type_begin(GEP), GTE = gep_type_end(GEP);
+         GTI != GTE; ++GTI) {
+      Value *Index = GTI.getOperand();
+      if (auto *ConstOffset = dyn_cast<ConstantInt>(Index)) {
+        if (ConstOffset->isZero())
+          continue;
+        if (StructType *STy = GTI.getStructTypeOrNull()) {
+          unsigned ElementIdx = ConstOffset->getZExtValue();
+          const StructLayout *SL = DL.getStructLayout(STy);
+          Expr.Offset += SL->getElementOffset(ElementIdx);
+          continue;
+        }
+        // Truncate if type size exceeds index space.
+        APInt IndexedSize(BitWidth, GTI.getSequentialElementStride(DL),
+                          /*isSigned=*/false,
+                          /*implicitTrunc=*/true);
+        Expr.Offset += ConstOffset->getValue() * IndexedSize;
+        continue;
+      }
+
+      // FIXME: Also look through a mul/shl in the index.
+      assert(Expr.Index == nullptr && "Shouldn't have index yet");
+      Expr.Index = Index;
+      // Truncate if type size exceeds index space.
+      Expr.Scale = APInt(BitWidth, GTI.getSequentialElementStride(DL),
+                         /*isSigned=*/false, /*implicitTrunc=*/true);
+    }
+  }
+
+  return Expr;
+}
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 07da12a3ab2a4..99ea04816681c 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -19,6 +19,7 @@
 #include "llvm/Analysis/CmpInstAnalysis.h"
 #include "llvm/Analysis/ConstantFolding.h"
 #include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/Analysis/Loads.h"
 #include "llvm/Analysis/Utils/Local.h"
 #include "llvm/Analysis/VectorUtils.h"
 #include "llvm/IR/ConstantRange.h"
@@ -110,30 +111,27 @@ static bool isSignTest(ICmpInst::Predicate &Pred, const APInt &C) {
 /// If AndCst is non-null, then the loaded value is masked with that constant
 /// before doing the comparison. This handles cases like "A[i]&4 == 0".
 Instruction *InstCombinerImpl::foldCmpLoadFromIndexedGlobal(
-    LoadInst *LI, GetElementPtrInst *GEP, GlobalVariable *GV, CmpInst &ICI,
-    ConstantInt *AndCst) {
-  if (LI->isVolatile() || !GV->isConstant() || !GV->hasDefinitiveInitializer())
+    LoadInst *LI, GetElementPtrInst *GEP, CmpInst &ICI, ConstantInt *AndCst) {
+  auto *GV = dyn_cast<GlobalVariable>(getUnderlyingObject(GEP));
+  if (LI->isVolatile() || !GV || !GV->isConstant() ||
+      !GV->hasDefinitiveInitializer())
     return nullptr;
 
-  Constant *Init = GV->getInitializer();
-  TypeSize GlobalSize = DL.getTypeAllocSize(Init->getType());
   Type *EltTy = LI->getType();
   TypeSize EltSize = DL.getTypeStoreSize(EltTy);
   if (EltSize.isScalable())
     return nullptr;
 
-  unsigned IndexBW = DL.getIndexTypeSizeInBits(GEP->getType());
-  SmallMapVector<Value *, APInt, 4> VarOffsets;
-  APInt ConstOffset(IndexBW, 0);
-  if (!GEP->collectOffset(DL, IndexBW, VarOffsets, ConstOffset) ||
-      VarOffsets.size() != 1 || IndexBW > 64)
+  LinearExpression Expr = decomposeLinearExpression(DL, GEP);
+  if (!Expr.Index || Expr.BasePtr != GV || Expr.Offset.getBitWidth() > 64)
     return nullptr;
 
-  Value *Idx = VarOffsets.front().first;
-  const APInt &Stride = VarOffsets.front().second;
-  // If the index type is non-canonical, wait for it to be canonicalized.
-  if (Idx->getType()->getScalarSizeInBits() != IndexBW)
-    return nullptr;
+  Constant *Init = GV->getInitializer();
+  TypeSize GlobalSize = DL.getTypeAllocSize(Init->getType());
+
+  Value *Idx = Expr.Index;
+  const APInt &Stride = Expr.Scale;
+  const APInt &ConstOffset = Expr.Offset;
 
   // Allow an additional context offset, but only within the stride.
   if (!ConstOffset.ult(Stride))
@@ -280,7 +278,7 @@ Instruction *InstCombinerImpl::foldCmpLoadFromIndexedGlobal(
   // comparison is false if Idx was 0x80..00.
   // We need to erase the highest countTrailingZeros(ElementSize) bits of Idx.
   auto MaskIdx = [&](Value *Idx) {
-    if (!GEP->isInBounds() && Stride.countr_zero() != 0) {
+    if (!Expr.Flags.isInBounds() && Stride.countr_zero() != 0) {
       Value *Mask = Constant::getAllOnesValue(Idx->getType());
       Mask = Builder.CreateLShr(Mask, Stride.countr_zero());
       Idx = Builder.CreateAnd(Idx, Mask);
@@ -1958,10 +1956,8 @@ Instruction *InstCombinerImpl::foldICmpAndConstant(ICmpInst &Cmp,
   if (auto *C2 = dyn_cast<ConstantInt>(Y))
     if (auto *LI = dyn_cast<LoadInst>(X))
       if (auto *GEP = dyn_cast<GetElementPtrInst>(LI->getOperand(0)))
-        if (auto *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0)))
-          if (Instruction *Res =
-                  foldCmpLoadFromIndexedGlobal(LI, GEP, GV, Cmp, C2))
-            return Res;
+        if (Instruction *Res = foldCmpLoadFromIndexedGlobal(LI, GEP, Cmp, C2))
+          return Res;
 
   if (!Cmp.isEquality())
     return nullptr;
@@ -4314,10 +4310,9 @@ Instruction *InstCombinerImpl::foldICmpInstWithConstantNotInt(ICmpInst &I) {
 
     // Try to optimize things like "A[i] > 4" to index computations.
     if (GetElementPtrInst *GEP =
            dyn_cast<GetElementPtrInst>(LHSI->getOperand(0)))
-      if (GlobalVariable *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0)))
-        if (Instruction *Res =
-                foldCmpLoadFromIndexedGlobal(cast<LoadInst>(LHSI), GEP, GV, I))
-          return Res;
+      if (Instruction *Res =
+              foldCmpLoadFromIndexedGlobal(cast<LoadInst>(LHSI), GEP, I))
+        return Res;
     break;
   }
@@ -8798,10 +8793,9 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) {
       break;
     case Instruction::Load:
       if (auto *GEP = dyn_cast<GetElementPtrInst>(LHSI->getOperand(0)))
-        if (auto *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0)))
-          if (Instruction *Res = foldCmpLoadFromIndexedGlobal(
-                  cast<LoadInst>(LHSI), GEP, GV, I))
-            return Res;
+        if (Instruction *Res =
+                foldCmpLoadFromIndexedGlobal(cast<LoadInst>(LHSI), GEP, I))
+          return Res;
       break;
     case Instruction::FPTrunc:
       if (Instruction *NV = foldFCmpFpTrunc(I, *LHSI, *RHSC))
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index d3d23130b6fc4..7a979c16da501 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -693,7 +693,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
   bool foldAllocaCmp(AllocaInst *Alloca);
   Instruction *foldCmpLoadFromIndexedGlobal(LoadInst *LI,
                                             GetElementPtrInst *GEP,
-                                            GlobalVariable *GV, CmpInst &ICI,
+                                            CmpInst &ICI,
                                             ConstantInt *AndCst = nullptr);
   Instruction *foldFCmpIntToFPConst(FCmpInst &I, Instruction *LHSI,
                                     Constant *RHSC);
diff --git a/llvm/test/Transforms/InstCombine/load-cmp.ll b/llvm/test/Transforms/InstCombine/load-cmp.ll
index 79faefbd5df56..f7d788110b1bc 100644
--- a/llvm/test/Transforms/InstCombine/load-cmp.ll
+++ b/llvm/test/Transforms/InstCombine/load-cmp.ll
@@ -419,6 +419,32 @@ define i1 @load_vs_array_type_mismatch_offset1(i32 %idx) {
   ret i1 %cmp
 }
 
+define i1 @load_vs_array_type_mismatch_offset1_separate_gep(i32 %idx) {
+; CHECK-LABEL: @load_vs_array_type_mismatch_offset1_separate_gep(
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[IDX:%.*]], -3
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[TMP1]], 1
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %gep1 = getelementptr inbounds {i16, i16}, ptr @g_i16_1, i32 %idx
+  %gep2 = getelementptr inbounds i8, ptr %gep1, i32 2
+  %load = load i16, ptr %gep2
+  %cmp = icmp eq i16 %load, 0
+  ret i1 %cmp
+}
+
+define i1 @load_vs_array_type_mismatch_offset1_separate_gep_swapped(i32 %idx) {
+; CHECK-LABEL: @load_vs_array_type_mismatch_offset1_separate_gep_swapped(
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[IDX:%.*]], -3
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[TMP1]], 1
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %gep1 = getelementptr inbounds i8, ptr @g_i16_1, i32 2
+  %gep2 = getelementptr inbounds {i16, i16}, ptr %gep1, i32 %idx
+  %load = load i16, ptr %gep2
+  %cmp = icmp eq i16 %load, 0
+  ret i1 %cmp
+}
+
 @g_i16_2 = internal constant [8 x i16] [i16 1, i16 0, i16 0, i16 1, i16 1, i16 0, i16 0, i16 1]
 
 ; idx == 0 || idx == 2
@@ -554,3 +580,36 @@ entry:
   %cond = icmp ult i32 %isOK, 5
   ret i1 %cond
 }
+
+define i1 @cmp_load_multiple_indices(i32 %idx, i32 %idx2) {
+; CHECK-LABEL: @cmp_load_multiple_indices(
+; CHECK-NEXT:    [[GEP1:%.*]] = getelementptr inbounds i16, ptr @g_i16_1, i32 [[IDX:%.*]]
+; CHECK-NEXT:    [[GEP2:%.*]] = getelementptr inbounds i16, ptr [[GEP1]], i32 [[IDX2:%.*]]
+; CHECK-NEXT:    [[GEP3:%.*]] = getelementptr inbounds nuw i8, ptr [[GEP2]], i32 2
+; CHECK-NEXT:    [[LOAD:%.*]] = load i16, ptr [[GEP3]], align 2
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i16 [[LOAD]], 0
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %gep1 = getelementptr inbounds i16, ptr @g_i16_1, i32 %idx
+  %gep2 = getelementptr inbounds i16, ptr %gep1, i32 %idx2
+  %gep3 = getelementptr inbounds i8, ptr %gep2, i32 2
+  %load = load i16, ptr %gep3
+  %cmp = icmp eq i16 %load, 0
+  ret i1 %cmp
+}
+
+define i1 @cmp_load_multiple_indices2(i32 %idx, i32 %idx2) {
+; CHECK-LABEL: @cmp_load_multiple_indices2(
+; CHECK-NEXT:    [[GEP1_SPLIT:%.*]] = getelementptr inbounds [1 x i16], ptr @g_i16_1, i32 [[IDX:%.*]]
+; CHECK-NEXT:    [[GEP1:%.*]] = getelementptr inbounds i16, ptr [[GEP1_SPLIT]], i32 [[IDX2:%.*]]
+; CHECK-NEXT:    [[GEP2:%.*]] = getelementptr inbounds nuw i8, ptr [[GEP1]], i32 2
+; CHECK-NEXT:    [[LOAD:%.*]] = load i16, ptr [[GEP2]], align 2
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i16 [[LOAD]], 0
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %gep1 = getelementptr inbounds [1 x i16], ptr @g_i16_1, i32 %idx, i32 %idx2
+  %gep2 = getelementptr inbounds i8, ptr %gep1, i32 2
+  %load = load i16, ptr %gep2
+  %cmp = icmp eq i16 %load, 0
+  ret i1 %cmp
+}
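
Note (illustrative, not part of the patch): the sketch below shows how a caller might consume the new decomposeLinearExpression() helper and the LinearExpression it returns. The wrapper function describeAddress and the errs() dump are assumptions made for the example; only decomposeLinearExpression(), LinearExpression, and GEPNoWrapFlags come from the changes above.

// Illustrative sketch only: describeAddress is a hypothetical helper and is
// not part of this patch. decomposeLinearExpression() and LinearExpression
// are the additions to llvm/Analysis/Loads.h shown above.
#include "llvm/Analysis/Loads.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Instructions.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;

// Decompose the address of a load into BasePtr + Index * Scale + Offset and
// dump the pieces. Index may be nullptr when the address is a pure constant
// offset from the base (i.e. Scale == 0).
static void describeAddress(LoadInst &LI, const DataLayout &DL) {
  LinearExpression Expr =
      decomposeLinearExpression(DL, LI.getPointerOperand());
  errs() << "base: " << *Expr.BasePtr << "\n";
  if (Expr.Index)
    errs() << "index: " << *Expr.Index << ", scale: " << Expr.Scale << "\n";
  errs() << "offset: " << Expr.Offset << ", inbounds: "
         << (Expr.Flags.isInBounds() ? "yes" : "no") << "\n";
}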