Skip to content

Commit

Permalink
[InstCombine] Fold strncmp of constant arrays and variable size
Browse files Browse the repository at this point in the history
Extend the solution accepted in D127766 to strncmp and simplify
strncmp(A, B, N) calls with constant A and B and variable N to
the equivalent of

  N <= Pos ? 0 : (A < B ? -1 : B < A ? +1 : 0)

where Pos is the offset of the first mismatch between A and B or,
if A and B are equal strings, the offset of their terminating null
character.

Reviewed By: courbet

Differential Revision: https://reviews.llvm.org/D128089
  • Loading branch information
msebor committed Jun 28, 2022
1 parent e263a76 commit 8827679
Show file tree
Hide file tree
Showing 4 changed files with 507 additions and 41 deletions.
57 changes: 31 additions & 26 deletions llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
Expand Up @@ -428,6 +428,12 @@ Value *LibCallSimplifier::optimizeStrCmp(CallInst *CI, IRBuilderBase &B) {
return nullptr;
}

// Optimize a memcmp or, when StrNCmp is true, strncmp call CI with constant
// arrays LHS and RHS and nonconstant Size.
// Forward-declared here because optimizeStrNCmp below calls it ahead of its
// definition further down in this file.
// The result is a value that replaces the call; the memcmp call site tests
// the result for null before using it, so a null result presumably means
// that no folding was done.
static Value *optimizeMemCmpVarSize(CallInst *CI, Value *LHS, Value *RHS,
Value *Size, bool StrNCmp,
IRBuilderBase &B, const DataLayout &DL);

Value *LibCallSimplifier::optimizeStrNCmp(CallInst *CI, IRBuilderBase &B) {
Value *Str1P = CI->getArgOperand(0);
Value *Str2P = CI->getArgOperand(1);
Expand All @@ -442,7 +448,7 @@ Value *LibCallSimplifier::optimizeStrNCmp(CallInst *CI, IRBuilderBase &B) {
if (ConstantInt *LengthArg = dyn_cast<ConstantInt>(Size))
Length = LengthArg->getZExtValue();
else
return nullptr;
return optimizeMemCmpVarSize(CI, Str1P, Str2P, Size, true, B, DL);

if (Length == 0) // strncmp(x,y,0) -> 0
return ConstantInt::get(CI->getType(), 0);
Expand Down Expand Up @@ -1155,11 +1161,11 @@ Value *LibCallSimplifier::optimizeMemChr(CallInst *CI, IRBuilderBase &B) {
CI->getType());
}

// Optimize a memcmp call CI with constant arrays LHS and RHS and either
// nonconstant Size or constant size known to be in bounds.
// Optimize a memcmp or, when StrNCmp is true, strncmp call CI with constant
// arrays LHS and RHS and nonconstant Size.
static Value *optimizeMemCmpVarSize(CallInst *CI, Value *LHS, Value *RHS,
Value *Size, IRBuilderBase &B,
const DataLayout &DL) {
Value *Size, bool StrNCmp,
IRBuilderBase &B, const DataLayout &DL) {
if (LHS == RHS) // memcmp(s,s,x) -> 0
return Constant::getNullValue(CI->getType());

Expand All @@ -1173,30 +1179,29 @@ static Value *optimizeMemCmpVarSize(CallInst *CI, Value *LHS, Value *RHS,
// N <= Pos ? 0 : (A < B ? -1 : B < A ? +1 : 0)
// where Pos is the first mismatch between A and B, determined below.

uint64_t Pos = 0;
Value *Zero = ConstantInt::get(CI->getType(), 0);

uint64_t MinSize = std::min(LStr.size(), RStr.size());
for (uint64_t Pos = 0; Pos < MinSize; ++Pos) {
if (LStr[Pos] != RStr[Pos]) {
Value *MaxSize = ConstantInt::get(Size->getType(), Pos);
Value *Cmp = B.CreateICmp(ICmpInst::ICMP_ULE, Size, MaxSize);
typedef unsigned char UChar;
int IRes = UChar(LStr[Pos]) < UChar(RStr[Pos]) ? -1 : 1;
Value *Res = ConstantInt::get(CI->getType(), IRes);
return B.CreateSelect(Cmp, Zero, Res);
for (uint64_t MinSize = std::min(LStr.size(), RStr.size()); ; ++Pos) {
if (Pos == MinSize ||
(StrNCmp && (LStr[Pos] == '\0' && RStr[Pos] == '\0'))) {
// One array is a leading part of the other of equal or greater
// size, or for strncmp, the arrays are equal strings.
// Fold the result to zero. Size is assumed to be in bounds, since
// otherwise the call would be undefined.
return Zero;
}
}

if (auto *SizeC = dyn_cast<ConstantInt>(Size))
if (MinSize < SizeC->getZExtValue())
// Fail if the bound happens to be constant and excessive and
// let sanitizers catch it.
return nullptr;
if (LStr[Pos] != RStr[Pos])
break;
}

// One array is a leading part of the other of equal or greater size.
// Fold the result to zero. Nonconstant size is assumed to be in bounds,
// since otherwise the call would be undefined.
return Zero;
// Normalize the result.
typedef unsigned char UChar;
int IRes = UChar(LStr[Pos]) < UChar(RStr[Pos]) ? -1 : 1;
Value *MaxSize = ConstantInt::get(Size->getType(), Pos);
Value *Cmp = B.CreateICmp(ICmpInst::ICMP_ULE, Size, MaxSize);
Value *Res = ConstantInt::get(CI->getType(), IRes);
return B.CreateSelect(Cmp, Zero, Res);
}

// Optimize a memcmp call CI with constant size Len.
Expand Down Expand Up @@ -1265,7 +1270,7 @@ Value *LibCallSimplifier::optimizeMemCmpBCmpCommon(CallInst *CI,

annotateNonNullAndDereferenceable(CI, {0, 1}, Size, DL);

if (Value *Res = optimizeMemCmpVarSize(CI, LHS, RHS, Size, B, DL))
if (Value *Res = optimizeMemCmpVarSize(CI, LHS, RHS, Size, false, B, DL))
return Res;

// Handle constant Size.
Expand Down
28 changes: 13 additions & 15 deletions llvm/test/Transforms/InstCombine/memcmp-4.ll
Expand Up @@ -17,14 +17,14 @@ declare i32 @memcmp(i8*, i8*, i64)
; value (analogous to strncmp) is safer than letting a SIMD library
; implementation return a bogus value.

define void @fold_memcmp_too_big(i32* %pcmp) {
; BE-LABEL: @fold_memcmp_too_big(
define void @fold_memcmp_mismatch_too_big(i32* %pcmp) {
; BE-LABEL: @fold_memcmp_mismatch_too_big(
; BE-NEXT: store i32 -1, i32* [[PCMP:%.*]], align 4
; BE-NEXT: [[PSTOR_CB:%.*]] = getelementptr i32, i32* [[PCMP]], i64 1
; BE-NEXT: store i32 1, i32* [[PSTOR_CB]], align 4
; BE-NEXT: ret void
;
; LE-LABEL: @fold_memcmp_too_big(
; LE-LABEL: @fold_memcmp_mismatch_too_big(
; LE-NEXT: store i32 -1, i32* [[PCMP:%.*]], align 4
; LE-NEXT: [[PSTOR_CB:%.*]] = getelementptr i32, i32* [[PCMP]], i64 1
; LE-NEXT: store i32 1, i32* [[PSTOR_CB]], align 4
Expand All @@ -47,23 +47,21 @@ define void @fold_memcmp_too_big(i32* %pcmp) {
}


; Don't fold calls with excessive byte counts of arrays with the same bytes.
; Fold even calls with excessive byte counts of arrays with matching bytes.
; Like in the instances above, this is preferable to letting the undefined
; calls take place, although it does prevent sanitizers from detecting them.

define void @call_memcmp_too_big(i32* %pcmp) {
; BE-LABEL: @call_memcmp_too_big(
; BE-NEXT: [[CMP_AB_9:%.*]] = call i32 @memcmp(i8* noundef nonnull dereferenceable(9) bitcast ([4 x i16]* @ia16a to i8*), i8* noundef nonnull dereferenceable(9) bitcast ([5 x i16]* @ia16b to i8*), i64 9)
; BE-NEXT: store i32 [[CMP_AB_9]], i32* [[PCMP:%.*]], align 4
; BE-NEXT: [[CMP_AB_M1:%.*]] = call i32 @memcmp(i8* noundef nonnull dereferenceable(18446744073709551615) bitcast ([4 x i16]* @ia16a to i8*), i8* noundef nonnull dereferenceable(18446744073709551615) bitcast ([5 x i16]* @ia16b to i8*), i64 -1)
define void @fold_memcmp_match_too_big(i32* %pcmp) {
; BE-LABEL: @fold_memcmp_match_too_big(
; BE-NEXT: store i32 0, i32* [[PCMP:%.*]], align 4
; BE-NEXT: [[PSTOR_AB_M1:%.*]] = getelementptr i32, i32* [[PCMP]], i64 1
; BE-NEXT: store i32 [[CMP_AB_M1]], i32* [[PSTOR_AB_M1]], align 4
; BE-NEXT: store i32 0, i32* [[PSTOR_AB_M1]], align 4
; BE-NEXT: ret void
;
; LE-LABEL: @call_memcmp_too_big(
; LE-NEXT: [[CMP_AB_9:%.*]] = call i32 @memcmp(i8* noundef nonnull dereferenceable(9) bitcast ([4 x i16]* @ia16a to i8*), i8* noundef nonnull dereferenceable(9) bitcast ([5 x i16]* @ia16b to i8*), i64 9)
; LE-NEXT: store i32 [[CMP_AB_9]], i32* [[PCMP:%.*]], align 4
; LE-NEXT: [[CMP_AB_M1:%.*]] = call i32 @memcmp(i8* noundef nonnull dereferenceable(18446744073709551615) bitcast ([4 x i16]* @ia16a to i8*), i8* noundef nonnull dereferenceable(18446744073709551615) bitcast ([5 x i16]* @ia16b to i8*), i64 -1)
; LE-LABEL: @fold_memcmp_match_too_big(
; LE-NEXT: store i32 0, i32* [[PCMP:%.*]], align 4
; LE-NEXT: [[PSTOR_AB_M1:%.*]] = getelementptr i32, i32* [[PCMP]], i64 1
; LE-NEXT: store i32 [[CMP_AB_M1]], i32* [[PSTOR_AB_M1]], align 4
; LE-NEXT: store i32 0, i32* [[PSTOR_AB_M1]], align 4
; LE-NEXT: ret void
;
%p0 = getelementptr [4 x i16], [4 x i16]* @ia16a, i64 0, i64 0
Expand Down

0 comments on commit 8827679

Please sign in to comment.