Skip to content

Commit

Permalink
[LoopUnroll] Add check to Latch's terminator in UnrollRuntimeLoopRema…
Browse files Browse the repository at this point in the history
…inder

In this patch, I'm adding an extra check to the Latch's terminator in llvm::UnrollRuntimeLoopRemainder,
similar to how it is already done in the llvm::UnrollLoop.

The compiler would crash if this function is called with a malformed loop.

Patch by Rodrigo Caetano Rocha!

Differential Revision: https://reviews.llvm.org/D51486

llvm-svn: 342958
  • Loading branch information
davemgreen committed Sep 25, 2018
1 parent 029aa8e commit 9108c2b
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 5 deletions.
24 changes: 19 additions & 5 deletions llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp
Expand Up @@ -545,13 +545,27 @@ bool llvm::UnrollRuntimeLoopRemainder(Loop *L, unsigned Count,
BasicBlock *Header = L->getHeader();

BranchInst *LatchBR = cast<BranchInst>(Latch->getTerminator());

if (!LatchBR || LatchBR->isUnconditional()) {
// The loop-rotate pass can be helpful to avoid this in many cases.
LLVM_DEBUG(
dbgs()
<< "Loop latch not terminated by a conditional branch.\n");
return false;
}

unsigned ExitIndex = LatchBR->getSuccessor(0) == Header ? 1 : 0;
BasicBlock *LatchExit = LatchBR->getSuccessor(ExitIndex);
// Cloning the loop basic blocks (`CloneLoopBlocks`) requires that one of the
// targets of the Latch be an exit block out of the loop. This needs
// to be guaranteed by the callers of UnrollRuntimeLoopRemainder.
assert(!L->contains(LatchExit) &&
"one of the loop latch successors should be the exit block!");

if (L->contains(LatchExit)) {
// Cloning the loop basic blocks (`CloneLoopBlocks`) requires that one of the
// targets of the Latch be an exit block out of the loop.
LLVM_DEBUG(
dbgs()
<< "One of the loop latch successors must be the exit block.\n");
return false;
}

// These are exit blocks other than the target of the latch exiting block.
SmallVector<BasicBlock *, 4> OtherExits;
bool isMultiExitUnrollingEnabled =
Expand Down
27 changes: 27 additions & 0 deletions llvm/test/Transforms/LoopUnroll/runtime-loop-non-exiting-latch.ll
@@ -0,0 +1,27 @@
; REQUIRES: asserts
; RUN: opt < %s -S -loop-unroll -unroll-runtime=true -unroll-allow-remainder=true -unroll-count=4

; Make sure that the runtime unroll does not break with a non-exiting latch.
define i32 @test(i32* %a, i32* %b, i32* %c, i64 %n) {
entry:
br label %while.cond

while.cond: ; preds = %while.body, %entry
%i.0 = phi i64 [ 0, %entry ], [ %inc, %while.body ]
%cmp = icmp slt i64 %i.0, %n
br i1 %cmp, label %while.body, label %while.end

while.body: ; preds = %while.cond
%arrayidx = getelementptr inbounds i32, i32* %b, i64 %i.0
%0 = load i32, i32* %arrayidx
%arrayidx1 = getelementptr inbounds i32, i32* %c, i64 %i.0
%1 = load i32, i32* %arrayidx1
%mul = mul nsw i32 %0, %1
%arrayidx2 = getelementptr inbounds i32, i32* %a, i64 %i.0
store i32 %mul, i32* %arrayidx2
%inc = add nsw i64 %i.0, 1
br label %while.cond

while.end: ; preds = %while.cond
ret i32 0
}
1 change: 1 addition & 0 deletions llvm/unittests/Transforms/Utils/CMakeLists.txt
Expand Up @@ -15,5 +15,6 @@ add_llvm_unittest(UtilsTests
IntegerDivisionTest.cpp
LocalTest.cpp
SSAUpdaterBulkTest.cpp
UnrollLoopTest.cpp
ValueMapperTest.cpp
)
76 changes: 76 additions & 0 deletions llvm/unittests/Transforms/Utils/UnrollLoopTest.cpp
@@ -0,0 +1,76 @@
//===- UnrollLoopTest.cpp - Unit tests for UnrollLoop ---------------------===//
//
// The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Utils/UnrollLoop.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/Support/SourceMgr.h"
#include "gtest/gtest.h"

using namespace llvm;

static std::unique_ptr<Module> parseIR(LLVMContext &C, const char *IR) {
SMDiagnostic Err;
std::unique_ptr<Module> Mod = parseAssemblyString(IR, Err, C);
if (!Mod)
Err.print("UnrollLoopTests", errs());
return Mod;
}

TEST(LoopUnrollRuntime, Latch) {
LLVMContext C;

std::unique_ptr<Module> M = parseIR(
C,
R"(define i32 @test(i32* %a, i32* %b, i32* %c, i64 %n) {
entry:
br label %while.cond
while.cond: ; preds = %while.body, %entry
%i.0 = phi i64 [ 0, %entry ], [ %inc, %while.body ]
%cmp = icmp slt i64 %i.0, %n
br i1 %cmp, label %while.body, label %while.end
while.body: ; preds = %while.cond
%arrayidx = getelementptr inbounds i32, i32* %b, i64 %i.0
%0 = load i32, i32* %arrayidx
%arrayidx1 = getelementptr inbounds i32, i32* %c, i64 %i.0
%1 = load i32, i32* %arrayidx1
%mul = mul nsw i32 %0, %1
%arrayidx2 = getelementptr inbounds i32, i32* %a, i64 %i.0
store i32 %mul, i32* %arrayidx2
%inc = add nsw i64 %i.0, 1
br label %while.cond
while.end: ; preds = %while.cond
ret i32 0
})"
);

auto *F = M->getFunction("test");
DominatorTree DT(*F);
LoopInfo LI(DT);
AssumptionCache AC(*F);
TargetLibraryInfoImpl TLII;
TargetLibraryInfo TLI(TLII);
ScalarEvolution SE(*F, TLI, AC, DT, LI);

Loop *L = *LI.begin();

bool PreserveLCSSA = L->isRecursivelyLCSSAForm(DT,LI);

bool ret = UnrollRuntimeLoopRemainder(L, 4, true, false, false, &LI, &SE, &DT, &AC, PreserveLCSSA);
EXPECT_FALSE(ret);
}

0 comments on commit 9108c2b

Please sign in to comment.