Skip to content

Commit 39b0cbe

Browse files
authored
[IndVarSimplify] Allow predicateLoopExit on some loops with thread-local writes (#155901)
This is important to optimize patterns that frequently appear with bounds checks: ``` for (int i = 0; i < N; ++i) { bar[i] = foo[i] + 123; } ``` which gets roughly turned into ``` for (int i = 0; i < N; ++i) { if (i >= size of foo) ubsan.trap(); if (i >= size of bar) ubsan.trap(); bar[i] = foo[i] + 123; } ``` Motivating example: https://github.com/google/boringssl/blob/main/crypto/fipsmodule/hmac/hmac.cc.inc#L138 I hand-verified the assembly and confirmed that this optimization removes the check in the loop. This also allowed the loop to be vectorized. Alive2: https://alive2.llvm.org/ce/z/3qMdLF I did a `stage2-check-all` for both normal and `-DBOOTSTRAP_CMAKE_C[XX]_FLAGS="-fsanitize=array-bounds -fsanitize-trap=all"`. I also ran some Google-internal tests with `fsanitize=array-bounds`. Everything passes.
1 parent 7fe0691 commit 39b0cbe

File tree

3 files changed

+794
-14
lines changed

3 files changed

+794
-14
lines changed

llvm/lib/Transforms/Scalar/IndVarSimplify.cpp

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
#include "llvm/IR/InstrTypes.h"
5454
#include "llvm/IR/Instruction.h"
5555
#include "llvm/IR/Instructions.h"
56+
#include "llvm/IR/IntrinsicInst.h"
5657
#include "llvm/IR/Intrinsics.h"
5758
#include "llvm/IR/PassManager.h"
5859
#include "llvm/IR/PatternMatch.h"
@@ -117,6 +118,10 @@ static cl::opt<bool>
117118
LoopPredication("indvars-predicate-loops", cl::Hidden, cl::init(true),
118119
cl::desc("Predicate conditions in read only loops"));
119120

121+
static cl::opt<bool> LoopPredicationTraps(
122+
"indvars-predicate-loop-traps", cl::Hidden, cl::init(true),
123+
cl::desc("Predicate conditions that trap in loops with only local writes"));
124+
120125
static cl::opt<bool>
121126
AllowIVWidening("indvars-widen-indvars", cl::Hidden, cl::init(true),
122127
cl::desc("Allow widening of indvars to eliminate s/zext"));
@@ -1704,6 +1709,24 @@ bool IndVarSimplify::optimizeLoopExits(Loop *L, SCEVExpander &Rewriter) {
17041709
return Changed;
17051710
}
17061711

1712+
static bool crashingBBWithoutEffect(const BasicBlock &BB) {
1713+
return llvm::all_of(BB, [](const Instruction &I) {
1714+
// TODO: for now this is overly restrictive, to make sure nothing in this
1715+
// BB can depend on the loop body.
1716+
// It's not enough to check for !I.mayHaveSideEffects(), because e.g. a
1717+
// load does not have a side effect, but we could have
1718+
// %a = load ptr, ptr %ptr
1719+
// %b = load i32, ptr %a
1720+
// Now if the loop stored a non-nullptr to %a, we could cause a nullptr
1721+
// dereference by skipping over loop iterations.
1722+
if (const auto *CB = dyn_cast<CallBase>(&I)) {
1723+
if (CB->onlyAccessesInaccessibleMemory())
1724+
return true;
1725+
}
1726+
return isa<UnreachableInst>(I);
1727+
});
1728+
}
1729+
17071730
bool IndVarSimplify::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) {
17081731
SmallVector<BasicBlock*, 16> ExitingBlocks;
17091732
L->getExitingBlocks(ExitingBlocks);
@@ -1816,11 +1839,25 @@ bool IndVarSimplify::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) {
18161839
// suggestions on how to improve this? I can obviously bail out for outer
18171840
// loops, but that seems less than ideal. MemorySSA can find memory writes,
18181841
// is that enough for *all* side effects?
1842+
bool HasThreadLocalSideEffects = false;
18191843
for (BasicBlock *BB : L->blocks())
18201844
for (auto &I : *BB)
18211845
// TODO:isGuaranteedToTransfer
1822-
if (I.mayHaveSideEffects())
1823-
return false;
1846+
if (I.mayHaveSideEffects()) {
1847+
if (!LoopPredicationTraps)
1848+
return false;
1849+
HasThreadLocalSideEffects = true;
1850+
if (StoreInst *SI = dyn_cast<StoreInst>(&I)) {
1851+
// Simple stores cannot be observed by other threads.
1852+
// If HasThreadLocalSideEffects is set, we check
1853+
// crashingBBWithoutEffect to make sure that the crashing BB cannot
1854+
// observe them either.
1855+
if (!SI->isSimple())
1856+
return false;
1857+
} else {
1858+
return false;
1859+
}
1860+
}
18241861

18251862
bool Changed = false;
18261863
// Finally, do the actual predication for all predicatable blocks. A couple
@@ -1840,6 +1877,19 @@ bool IndVarSimplify::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) {
18401877
const SCEV *ExitCount = SE->getExitCount(L, ExitingBB);
18411878

18421879
auto *BI = cast<BranchInst>(ExitingBB->getTerminator());
1880+
if (HasThreadLocalSideEffects) {
1881+
const BasicBlock *Unreachable = nullptr;
1882+
for (const BasicBlock *Succ : BI->successors()) {
1883+
if (isa<UnreachableInst>(Succ->getTerminator()))
1884+
Unreachable = Succ;
1885+
}
1886+
// Exit BB which have one branch back into the loop and another one to
1887+
// a trap can still be optimized, because local side effects cannot
1888+
// be observed in the exit case (the trap). We could be smarter about
1889+
// this, but for now lets pattern match common cases that directly trap.
1890+
if (Unreachable == nullptr || !crashingBBWithoutEffect(*Unreachable))
1891+
return Changed;
1892+
}
18431893
Value *NewCond;
18441894
if (ExitCount == ExactBTC) {
18451895
NewCond = L->contains(BI->getSuccessor(0)) ?

llvm/test/Transforms/IndVarSimplify/X86/overflow-intrinsics.ll

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,7 @@ define void @f_sadd_overflow(ptr %a) {
6060
; CHECK-NEXT: [[INDVARS_IV:%.*]] = phi i64 [ [[INDVARS_IV_NEXT:%.*]], %[[CONT:.*]] ], [ 2147483645, %[[ENTRY]] ]
6161
; CHECK-NEXT: [[ARRAYIDX:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 [[INDVARS_IV]]
6262
; CHECK-NEXT: store i8 0, ptr [[ARRAYIDX]], align 1
63-
; CHECK-NEXT: [[EXITCOND:%.*]] = icmp eq i64 [[INDVARS_IV]], 2147483647
64-
; CHECK-NEXT: br i1 [[EXITCOND]], label %[[TRAP:.*]], label %[[CONT]], !nosanitize [[META0]]
63+
; CHECK-NEXT: br i1 true, label %[[TRAP:.*]], label %[[CONT]], !nosanitize [[META0]]
6564
; CHECK: [[TRAP]]:
6665
; CHECK-NEXT: tail call void @llvm.trap(), !nosanitize [[META0]]
6766
; CHECK-NEXT: unreachable, !nosanitize [[META0]]
@@ -150,8 +149,7 @@ define void @f_uadd_overflow(ptr %a) {
150149
; CHECK-NEXT: [[INDVARS_IV:%.*]] = phi i64 [ [[INDVARS_IV_NEXT:%.*]], %[[CONT:.*]] ], [ -6, %[[ENTRY]] ]
151150
; CHECK-NEXT: [[ARRAYIDX:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 [[INDVARS_IV]]
152151
; CHECK-NEXT: store i8 0, ptr [[ARRAYIDX]], align 1
153-
; CHECK-NEXT: [[EXITCOND:%.*]] = icmp eq i64 [[INDVARS_IV]], -1
154-
; CHECK-NEXT: br i1 [[EXITCOND]], label %[[TRAP:.*]], label %[[CONT]], !nosanitize [[META0]]
152+
; CHECK-NEXT: br i1 true, label %[[TRAP:.*]], label %[[CONT]], !nosanitize [[META0]]
155153
; CHECK: [[TRAP]]:
156154
; CHECK-NEXT: tail call void @llvm.trap(), !nosanitize [[META0]]
157155
; CHECK-NEXT: unreachable, !nosanitize [[META0]]
@@ -243,10 +241,7 @@ define void @f_ssub_overflow(ptr nocapture %a) {
243241
; CHECK-NEXT: [[INDVARS_IV:%.*]] = phi i64 [ [[INDVARS_IV_NEXT:%.*]], %[[CONT:.*]] ], [ -2147483642, %[[ENTRY]] ]
244242
; CHECK-NEXT: [[ARRAYIDX:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 [[INDVARS_IV]]
245243
; CHECK-NEXT: store i8 0, ptr [[ARRAYIDX]], align 1
246-
; CHECK-NEXT: [[TMP0:%.*]] = trunc nsw i64 [[INDVARS_IV]] to i32
247-
; CHECK-NEXT: [[TMP1:%.*]] = tail call { i32, i1 } @llvm.ssub.with.overflow.i32(i32 [[TMP0]], i32 1)
248-
; CHECK-NEXT: [[TMP2:%.*]] = extractvalue { i32, i1 } [[TMP1]], 1
249-
; CHECK-NEXT: br i1 [[TMP2]], label %[[TRAP:.*]], label %[[CONT]], !nosanitize [[META0]]
244+
; CHECK-NEXT: br i1 true, label %[[TRAP:.*]], label %[[CONT]], !nosanitize [[META0]]
250245
; CHECK: [[TRAP]]:
251246
; CHECK-NEXT: tail call void @llvm.trap(), !nosanitize [[META0]]
252247
; CHECK-NEXT: unreachable, !nosanitize [[META0]]
@@ -339,10 +334,7 @@ define void @f_usub_overflow(ptr nocapture %a) {
339334
; CHECK-NEXT: [[INDVARS_IV:%.*]] = phi i64 [ [[INDVARS_IV_NEXT:%.*]], %[[CONT:.*]] ], [ 15, %[[ENTRY]] ]
340335
; CHECK-NEXT: [[ARRAYIDX:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 [[INDVARS_IV]]
341336
; CHECK-NEXT: store i8 0, ptr [[ARRAYIDX]], align 1
342-
; CHECK-NEXT: [[TMP0:%.*]] = trunc nuw nsw i64 [[INDVARS_IV]] to i32
343-
; CHECK-NEXT: [[TMP1:%.*]] = tail call { i32, i1 } @llvm.usub.with.overflow.i32(i32 [[TMP0]], i32 1)
344-
; CHECK-NEXT: [[TMP2:%.*]] = extractvalue { i32, i1 } [[TMP1]], 1
345-
; CHECK-NEXT: br i1 [[TMP2]], label %[[TRAP:.*]], label %[[CONT]], !nosanitize [[META0]]
337+
; CHECK-NEXT: br i1 true, label %[[TRAP:.*]], label %[[CONT]], !nosanitize [[META0]]
346338
; CHECK: [[TRAP]]:
347339
; CHECK-NEXT: tail call void @llvm.trap(), !nosanitize [[META0]]
348340
; CHECK-NEXT: unreachable, !nosanitize [[META0]]

0 commit comments

Comments
 (0)