Merged
22 changes: 22 additions & 0 deletions llvm/include/llvm/Analysis/Loads.h
@@ -13,7 +13,9 @@
#ifndef LLVM_ANALYSIS_LOADS_H
#define LLVM_ANALYSIS_LOADS_H

#include "llvm/ADT/APInt.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/GEPNoWrapFlags.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Compiler.h"

@@ -193,6 +195,26 @@ LLVM_ABI bool canReplacePointersIfEqual(const Value *From, const Value *To,
const DataLayout &DL);
LLVM_ABI bool canReplacePointersInUseIfEqual(const Use &U, const Value *To,
const DataLayout &DL);

/// Linear expression BasePtr + Index * Scale + Offset.
/// Index, Scale and Offset all have the same bit width, which matches the
/// pointer index size of BasePtr.
/// Index may be nullptr if Scale is 0.
struct LinearExpression {
  Value *BasePtr;
  Value *Index = nullptr;
  APInt Scale;
  APInt Offset;
  GEPNoWrapFlags Flags = GEPNoWrapFlags::all();

  LinearExpression(Value *BasePtr, unsigned BitWidth)
      : BasePtr(BasePtr), Scale(BitWidth, 0), Offset(BitWidth, 0) {}
};

/// Decompose a pointer into a linear expression. This may look through
/// multiple GEPs.
LLVM_ABI LinearExpression decomposeLinearExpression(const DataLayout &DL,
                                                    Value *Ptr);
}

#endif
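For orientation, here is a minimal sketch (not part of the diff) of how a transform could consume the new API; the helper name isConstantOffsetFromBase and its use are purely illustrative:

#include "llvm/Analysis/Loads.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Value.h"

using namespace llvm;

// Hypothetical helper: returns true if Ptr addresses ExpectedBase at a purely
// constant byte offset, i.e. the decomposition found no variable index.
static bool isConstantOffsetFromBase(const DataLayout &DL, Value *Ptr,
                                     Value *ExpectedBase) {
  // Ptr == Expr.BasePtr + Expr.Index * Expr.Scale + Expr.Offset, where Index
  // may be null when Scale is 0.
  LinearExpression Expr = decomposeLinearExpression(DL, Ptr);
  return !Expr.Index && Expr.BasePtr == ExpectedBase;
}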
64 changes: 64 additions & 0 deletions llvm/lib/Analysis/Loads.cpp
@@ -21,6 +21,7 @@
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/GetElementPtrTypeIterator.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Operator.h"

@@ -876,3 +877,66 @@ bool llvm::isReadOnlyLoop(
}
return true;
}

LinearExpression llvm::decomposeLinearExpression(const DataLayout &DL,
                                                  Value *Ptr) {
  assert(Ptr->getType()->isPointerTy() && "Must be called with pointer arg");

  unsigned BitWidth = DL.getIndexTypeSizeInBits(Ptr->getType());
  LinearExpression Expr(Ptr, BitWidth);

  while (true) {
    auto *GEP = dyn_cast<GEPOperator>(Expr.BasePtr);
    if (!GEP || GEP->getSourceElementType()->isScalableTy())
      return Expr;

    Value *VarIndex = nullptr;
    for (Value *Index : GEP->indices()) {
      if (isa<ConstantInt>(Index))
        continue;
      // Only allow a single variable index. We do not bother to handle the
      // case of the same variable index appearing multiple times.
      if (Expr.Index || VarIndex)
        return Expr;
      VarIndex = Index;
    }

    // Don't return non-canonical indexes.
    if (VarIndex && !VarIndex->getType()->isIntegerTy(BitWidth))
      return Expr;

    // We have verified that we can fully handle this GEP, so we can update Expr
    // members past this point.
    Expr.BasePtr = GEP->getPointerOperand();
    Expr.Flags = Expr.Flags.intersectForOffsetAdd(GEP->getNoWrapFlags());
    for (gep_type_iterator GTI = gep_type_begin(GEP), GTE = gep_type_end(GEP);
         GTI != GTE; ++GTI) {
      Value *Index = GTI.getOperand();
      if (auto *ConstOffset = dyn_cast<ConstantInt>(Index)) {
        if (ConstOffset->isZero())
          continue;
        if (StructType *STy = GTI.getStructTypeOrNull()) {
          unsigned ElementIdx = ConstOffset->getZExtValue();
          const StructLayout *SL = DL.getStructLayout(STy);
          Expr.Offset += SL->getElementOffset(ElementIdx);
          continue;
        }
        // Truncate if type size exceeds index space.
        APInt IndexedSize(BitWidth, GTI.getSequentialElementStride(DL),
                          /*isSigned=*/false,
                          /*implicitTrunc=*/true);
        Expr.Offset += ConstOffset->getValue() * IndexedSize;
        continue;
      }

      // FIXME: Also look through a mul/shl in the index.
      assert(Expr.Index == nullptr && "Shouldn't have index yet");
      Expr.Index = Index;
      // Truncate if type size exceeds index space.
      Expr.Scale = APInt(BitWidth, GTI.getSequentialElementStride(DL),
                         /*isSigned=*/false, /*implicitTrunc=*/true);
    }
  }

  return Expr;
}
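To make the decomposition concrete, here is a standalone, unit-test-style sketch (not part of the diff) of what the routine above is expected to compute for a two-GEP chain; it assumes the default (empty) data layout, where pointer indices are 64 bits wide, and the expected values follow from the logic above rather than from any test in the patch:

#include "llvm/Analysis/Loads.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/SourceMgr.h"
#include <cassert>

using namespace llvm;

int main() {
  LLVMContext Ctx;
  SMDiagnostic Err;
  // %p2 computes @g + %i * 4 + 8: a variable-index GEP followed by a
  // constant byte-offset GEP.
  std::unique_ptr<Module> M = parseAssemblyString(R"(
    @g = global [16 x i32] zeroinitializer
    define ptr @example(i64 %i) {
      %p1 = getelementptr inbounds [16 x i32], ptr @g, i64 0, i64 %i
      %p2 = getelementptr inbounds i8, ptr %p1, i64 8
      ret ptr %p2
    }
  )", Err, Ctx);
  assert(M && "failed to parse example IR");

  Function &F = *M->getFunction("example");
  auto *Ret = cast<ReturnInst>(F.getEntryBlock().getTerminator());
  Value *Ptr = Ret->getReturnValue();

  LinearExpression Expr = decomposeLinearExpression(M->getDataLayout(), Ptr);
  // Both GEPs are looked through: the base is the global, the variable index
  // is %i scaled by the i32 stride (4 bytes), plus the constant i8 offset.
  assert(Expr.BasePtr == M->getNamedGlobal("g"));
  assert(Expr.Index == F.getArg(0));
  assert(Expr.Scale == 4 && Expr.Offset == 8);
  return 0;
}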
50 changes: 22 additions & 28 deletions llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -19,6 +19,7 @@
#include "llvm/Analysis/CmpInstAnalysis.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/Analysis/Loads.h"
#include "llvm/Analysis/Utils/Local.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/ConstantRange.h"
@@ -110,30 +111,27 @@ static bool isSignTest(ICmpInst::Predicate &Pred, const APInt &C) {
 /// If AndCst is non-null, then the loaded value is masked with that constant
 /// before doing the comparison. This handles cases like "A[i]&4 == 0".
 Instruction *InstCombinerImpl::foldCmpLoadFromIndexedGlobal(
-    LoadInst *LI, GetElementPtrInst *GEP, GlobalVariable *GV, CmpInst &ICI,
-    ConstantInt *AndCst) {
-  if (LI->isVolatile() || !GV->isConstant() || !GV->hasDefinitiveInitializer())
+    LoadInst *LI, GetElementPtrInst *GEP, CmpInst &ICI, ConstantInt *AndCst) {
+  auto *GV = dyn_cast<GlobalVariable>(getUnderlyingObject(GEP));
+  if (LI->isVolatile() || !GV || !GV->isConstant() ||
+      !GV->hasDefinitiveInitializer())
     return nullptr;
 
-  Constant *Init = GV->getInitializer();
-  TypeSize GlobalSize = DL.getTypeAllocSize(Init->getType());
   Type *EltTy = LI->getType();
   TypeSize EltSize = DL.getTypeStoreSize(EltTy);
   if (EltSize.isScalable())
     return nullptr;
 
-  unsigned IndexBW = DL.getIndexTypeSizeInBits(GEP->getType());
-  SmallMapVector<Value *, APInt, 4> VarOffsets;
-  APInt ConstOffset(IndexBW, 0);
-  if (!GEP->collectOffset(DL, IndexBW, VarOffsets, ConstOffset) ||
-      VarOffsets.size() != 1 || IndexBW > 64)
+  LinearExpression Expr = decomposeLinearExpression(DL, GEP);
+  if (!Expr.Index || Expr.BasePtr != GV || Expr.Offset.getBitWidth() > 64)
     return nullptr;
 
-  Value *Idx = VarOffsets.front().first;
-  const APInt &Stride = VarOffsets.front().second;
-  // If the index type is non-canonical, wait for it to be canonicalized.
-  if (Idx->getType()->getScalarSizeInBits() != IndexBW)
-    return nullptr;
+  Constant *Init = GV->getInitializer();
+  TypeSize GlobalSize = DL.getTypeAllocSize(Init->getType());
+
+  Value *Idx = Expr.Index;
+  const APInt &Stride = Expr.Scale;
+  const APInt &ConstOffset = Expr.Offset;
 
   // Allow an additional context offset, but only within the stride.
   if (!ConstOffset.ult(Stride))
@@ -280,7 +278,7 @@ Instruction *InstCombinerImpl::foldCmpLoadFromIndexedGlobal(
   // comparison is false if Idx was 0x80..00.
   // We need to erase the highest countTrailingZeros(ElementSize) bits of Idx.
   auto MaskIdx = [&](Value *Idx) {
-    if (!GEP->isInBounds() && Stride.countr_zero() != 0) {
+    if (!Expr.Flags.isInBounds() && Stride.countr_zero() != 0) {
       Value *Mask = Constant::getAllOnesValue(Idx->getType());
       Mask = Builder.CreateLShr(Mask, Stride.countr_zero());
       Idx = Builder.CreateAnd(Idx, Mask);
@@ -1958,10 +1956,8 @@ Instruction *InstCombinerImpl::foldICmpAndConstant(ICmpInst &Cmp,
   if (auto *C2 = dyn_cast<ConstantInt>(Y))
     if (auto *LI = dyn_cast<LoadInst>(X))
       if (auto *GEP = dyn_cast<GetElementPtrInst>(LI->getOperand(0)))
-        if (auto *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0)))
-          if (Instruction *Res =
-                  foldCmpLoadFromIndexedGlobal(LI, GEP, GV, Cmp, C2))
-            return Res;
+        if (Instruction *Res = foldCmpLoadFromIndexedGlobal(LI, GEP, Cmp, C2))
+          return Res;
 
   if (!Cmp.isEquality())
     return nullptr;
@@ -4314,10 +4310,9 @@ Instruction *InstCombinerImpl::foldICmpInstWithConstantNotInt(ICmpInst &I) {
     // Try to optimize things like "A[i] > 4" to index computations.
     if (GetElementPtrInst *GEP =
             dyn_cast<GetElementPtrInst>(LHSI->getOperand(0)))
-      if (GlobalVariable *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0)))
-        if (Instruction *Res =
-                foldCmpLoadFromIndexedGlobal(cast<LoadInst>(LHSI), GEP, GV, I))
-          return Res;
+      if (Instruction *Res =
+              foldCmpLoadFromIndexedGlobal(cast<LoadInst>(LHSI), GEP, I))
+        return Res;
     break;
   }
 
@@ -8798,10 +8793,9 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) {
       break;
     case Instruction::Load:
       if (auto *GEP = dyn_cast<GetElementPtrInst>(LHSI->getOperand(0)))
-        if (auto *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0)))
-          if (Instruction *Res = foldCmpLoadFromIndexedGlobal(
-                  cast<LoadInst>(LHSI), GEP, GV, I))
-            return Res;
+        if (Instruction *Res =
+                foldCmpLoadFromIndexedGlobal(cast<LoadInst>(LHSI), GEP, I))
+          return Res;
       break;
     case Instruction::FPTrunc:
       if (Instruction *NV = foldFCmpFpTrunc(I, *LHSI, *RHSC))
2 changes: 1 addition & 1 deletion llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -710,7 +710,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
   bool foldAllocaCmp(AllocaInst *Alloca);
   Instruction *foldCmpLoadFromIndexedGlobal(LoadInst *LI,
                                             GetElementPtrInst *GEP,
-                                            GlobalVariable *GV, CmpInst &ICI,
+                                            CmpInst &ICI,
                                             ConstantInt *AndCst = nullptr);
   Instruction *foldFCmpIntToFPConst(FCmpInst &I, Instruction *LHSI,
                                     Constant *RHSC);
59 changes: 59 additions & 0 deletions llvm/test/Transforms/InstCombine/load-cmp.ll
@@ -419,6 +419,32 @@ define i1 @load_vs_array_type_mismatch_offset1(i32 %idx) {
ret i1 %cmp
}

define i1 @load_vs_array_type_mismatch_offset1_separate_gep(i32 %idx) {
; CHECK-LABEL: @load_vs_array_type_mismatch_offset1_separate_gep(
; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IDX:%.*]], -3
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[TMP1]], 1
; CHECK-NEXT: ret i1 [[CMP]]
;
%gep1 = getelementptr inbounds {i16, i16}, ptr @g_i16_1, i32 %idx
%gep2 = getelementptr inbounds i8, ptr %gep1, i32 2
%load = load i16, ptr %gep2
%cmp = icmp eq i16 %load, 0
ret i1 %cmp
}

define i1 @load_vs_array_type_mismatch_offset1_separate_gep_swapped(i32 %idx) {
; CHECK-LABEL: @load_vs_array_type_mismatch_offset1_separate_gep_swapped(
; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IDX:%.*]], -3
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[TMP1]], 1
; CHECK-NEXT: ret i1 [[CMP]]
;
%gep1 = getelementptr inbounds i8, ptr @g_i16_1, i32 2
%gep2 = getelementptr inbounds {i16, i16}, ptr %gep1, i32 %idx
%load = load i16, ptr %gep2
%cmp = icmp eq i16 %load, 0
ret i1 %cmp
}

@g_i16_2 = internal constant [8 x i16] [i16 1, i16 0, i16 0, i16 1, i16 1, i16 0, i16 0, i16 1]

; idx == 0 || idx == 2
@@ -554,3 +580,36 @@ entry:
%cond = icmp ult i32 %isOK, 5
ret i1 %cond
}

define i1 @cmp_load_multiple_indices(i32 %idx, i32 %idx2) {
; CHECK-LABEL: @cmp_load_multiple_indices(
; CHECK-NEXT: [[GEP1:%.*]] = getelementptr inbounds i16, ptr @g_i16_1, i32 [[IDX:%.*]]
; CHECK-NEXT: [[GEP2:%.*]] = getelementptr inbounds i16, ptr [[GEP1]], i32 [[IDX2:%.*]]
; CHECK-NEXT: [[GEP3:%.*]] = getelementptr inbounds nuw i8, ptr [[GEP2]], i32 2
; CHECK-NEXT: [[LOAD:%.*]] = load i16, ptr [[GEP3]], align 2
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i16 [[LOAD]], 0
; CHECK-NEXT: ret i1 [[CMP]]
;
%gep1 = getelementptr inbounds i16, ptr @g_i16_1, i32 %idx
%gep2 = getelementptr inbounds i16, ptr %gep1, i32 %idx2
%gep3 = getelementptr inbounds i8, ptr %gep2, i32 2
%load = load i16, ptr %gep3
%cmp = icmp eq i16 %load, 0
ret i1 %cmp
}

define i1 @cmp_load_multiple_indices2(i32 %idx, i32 %idx2) {
; CHECK-LABEL: @cmp_load_multiple_indices2(
; CHECK-NEXT: [[GEP1_SPLIT:%.*]] = getelementptr inbounds [1 x i16], ptr @g_i16_1, i32 [[IDX:%.*]]
; CHECK-NEXT: [[GEP1:%.*]] = getelementptr inbounds i16, ptr [[GEP1_SPLIT]], i32 [[IDX2:%.*]]
; CHECK-NEXT: [[GEP2:%.*]] = getelementptr inbounds nuw i8, ptr [[GEP1]], i32 2
; CHECK-NEXT: [[LOAD:%.*]] = load i16, ptr [[GEP2]], align 2
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i16 [[LOAD]], 0
; CHECK-NEXT: ret i1 [[CMP]]
;
%gep1 = getelementptr inbounds [1 x i16], ptr @g_i16_1, i32 %idx, i32 %idx2
%gep2 = getelementptr inbounds i8, ptr %gep1, i32 2
%load = load i16, ptr %gep2
%cmp = icmp eq i16 %load, 0
ret i1 %cmp
}