Skip to content

Commit

Permalink
[attrs] Handle convergent CallSites.
Browse files Browse the repository at this point in the history
Summary:
Previously we had a notion of convergent functions but not of convergent
calls.  This is insufficient to correctly analyze calls where the target
is unknown, e.g. indirect calls.

Now a call is convergent if it targets a known-convergent function, or
if it's explicitly marked as convergent.  As usual, we can remove
convergent where we can prove that no convergent operations are
performed in the call.

Reviewers: chandlerc, jingyue

Subscribers: hfinkel, jhen, tra, llvm-commits

Differential Revision: http://reviews.llvm.org/D17317

llvm-svn: 261544
  • Loading branch information
Justin Lebar committed Feb 22, 2016
1 parent f62b165 commit 7bf9187
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 56 deletions.
62 changes: 25 additions & 37 deletions llvm/lib/Transforms/IPO/FunctionAttrs.cpp
Expand Up @@ -903,49 +903,37 @@ static bool addNonNullAttrs(const SCCNodeSet &SCCNodes,
return MadeChange;
}

/// Removes convergent attributes where we can prove that none of the SCC's
/// callees are themselves convergent. Returns true if successful at removing
/// the attribute.
/// Remove the convergent attribute from all functions in the SCC if every
/// callsite within the SCC is not convergent (except for calls to functions
/// within the SCC). Returns true if changes were made.
static bool removeConvergentAttrs(const SCCNodeSet &SCCNodes) {
// Determines whether a function can be made non-convergent, ignoring all
// other functions in SCC. (A function can *actually* be made non-convergent
// only if all functions in its SCC can be made convergent.)
auto CanRemoveConvergent = [&](Function *F) {
if (!F->isConvergent())
return true;

// Can't remove convergent from declarations.
if (F->isDeclaration())
return false;

for (Instruction &I : instructions(*F))
if (auto CS = CallSite(&I)) {
// Can't remove convergent if any of F's callees -- ignoring functions
// in the SCC itself -- are convergent. This needs to consider both
// function calls and intrinsic calls. We also assume indirect calls
// might call a convergent function.
// FIXME: We should revisit this when we put convergent onto calls
// instead of functions so that indirect calls which should be
// convergent are required to be marked as such.
Function *Callee = CS.getCalledFunction();
if (!Callee || (SCCNodes.count(Callee) == 0 && Callee->isConvergent()))
return false;
}

return true;
};
// No point checking if none of SCCNodes is convergent.
if (!llvm::any_of(SCCNodes, [](Function *F) { return F->isConvergent(); }))
return false;

// We can remove the convergent attr from functions in the SCC if they all
// can be made non-convergent (because they call only non-convergent
// functions, other than each other).
if (!llvm::all_of(SCCNodes, CanRemoveConvergent))
// Can't remove convergent from function declarations.
if (llvm::any_of(SCCNodes, [](Function *F) { return F->isDeclaration(); }))
return false;

// If we got here, all of the SCC's callees are non-convergent. Therefore all
// of the SCC's functions can be marked as non-convergent.
// Can't remove convergent if any of our functions has a convergent call to a
// function not in the SCC.
for (Function *F : SCCNodes)
for (Instruction &I : instructions(*F)) {
CallSite CS(&I);
// Bail if is CS a convergent call to a function not in the SCC.
if (CS && CS.isConvergent() &&
SCCNodes.count(CS.getCalledFunction()) == 0)
return false;
}

// If we got here, all of the calls the SCC makes to functions not in the SCC
// are non-convergent. Therefore all of the SCC's functions can also be made
// non-convergent. We'll remove the attr from the callsites in
// InstCombineCalls.
for (Function *F : SCCNodes) {
if (F->isConvergent())
DEBUG(dbgs() << "Removing convergent attr from " << F->getName() << "\n");
DEBUG(dbgs() << "Removing convergent attr from fn " << F->getName()
<< "\n");
F->setNotConvergent();
}
return true;
Expand Down
11 changes: 10 additions & 1 deletion llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
Expand Up @@ -2070,7 +2070,15 @@ Instruction *InstCombiner::visitCallSite(CallSite CS) {
if (!isa<Function>(Callee) && transformConstExprCastCall(CS))
return nullptr;

if (Function *CalleeF = dyn_cast<Function>(Callee))
if (Function *CalleeF = dyn_cast<Function>(Callee)) {
// Remove the convergent attr on calls when the callee is not convergent.
if (CS.isConvergent() && !CalleeF->isConvergent()) {
DEBUG(dbgs() << "Removing convergent attr from instr "
<< CS.getInstruction() << "\n");
CS.setNotConvergent();
return CS.getInstruction();
}

// If the call and callee calling conventions don't match, this call must
// be unreachable, as the call is undefined.
if (CalleeF->getCallingConv() != CS.getCallingConv() &&
Expand All @@ -2095,6 +2103,7 @@ Instruction *InstCombiner::visitCallSite(CallSite CS) {
Constant::getNullValue(CalleeF->getType()));
return nullptr;
}
}

if (isa<ConstantPointerNull>(Callee) || isa<UndefValue>(Callee)) {
// If CS does not return void then replaceAllUsesWith undef.
Expand Down
48 changes: 30 additions & 18 deletions llvm/test/Transforms/FunctionAttrs/convergent.ll
@@ -1,4 +1,4 @@
; RUN: opt < %s -basicaa -functionattrs -rpo-functionattrs -S | FileCheck %s
; RUN: opt -functionattrs -S < %s | FileCheck %s

; CHECK: Function Attrs
; CHECK-NOT: convergent
Expand All @@ -24,18 +24,39 @@ declare i32 @k() convergent
; CHECK-SAME: convergent
; CHECK-NEXT: define i32 @extern()
define i32 @extern() convergent {
%a = call i32 @k()
%a = call i32 @k() convergent
ret i32 %a
}

; Convergent should not be removed on the function here. Although the call is
; not explicitly convergent, it picks up the convergent attr from the callee.
;
; CHECK: Function Attrs
; CHECK-SAME: convergent
; CHECK-NEXT: define i32 @call_extern()
define i32 @call_extern() convergent {
%a = call i32 @extern()
; CHECK-NEXT: define i32 @extern_non_convergent_call()
define i32 @extern_non_convergent_call() convergent {
%a = call i32 @k()
ret i32 %a
}

; CHECK: Function Attrs
; CHECK-SAME: convergent
; CHECK-NEXT: define i32 @indirect_convergent_call(
define i32 @indirect_convergent_call(i32 ()* %f) convergent {
%a = call i32 %f() convergent
ret i32 %a
}
; Give indirect_non_convergent_call the norecurse attribute so we get a
; "Function Attrs" comment in the output.
;
; CHECK: Function Attrs
; CHECK-NOT: convergent
; CHECK-NEXT: define i32 @indirect_non_convergent_call(
define i32 @indirect_non_convergent_call(i32 ()* %f) convergent norecurse {
%a = call i32 %f()
ret i32 %a
}

; CHECK: Function Attrs
; CHECK-SAME: convergent
; CHECK-NEXT: declare void @llvm.cuda.syncthreads()
Expand All @@ -45,41 +66,32 @@ declare void @llvm.cuda.syncthreads() convergent
; CHECK-SAME: convergent
; CHECK-NEXT: define i32 @intrinsic()
define i32 @intrinsic() convergent {
; Implicitly convergent, because the intrinsic is convergent.
call void @llvm.cuda.syncthreads()
ret i32 0
}

@xyz = global i32 ()* null
; CHECK: Function Attrs
; CHECK-SAME: convergent
; CHECK-NEXT: define i32 @functionptr()
define i32 @functionptr() convergent {
%1 = load i32 ()*, i32 ()** @xyz
%2 = call i32 %1()
ret i32 %2
}

; CHECK: Function Attrs
; CHECK-NOT: convergent
; CHECK-NEXT: define i32 @recursive1()
define i32 @recursive1() convergent {
%a = call i32 @recursive2()
%a = call i32 @recursive2() convergent
ret i32 %a
}

; CHECK: Function Attrs
; CHECK-NOT: convergent
; CHECK-NEXT: define i32 @recursive2()
define i32 @recursive2() convergent {
%a = call i32 @recursive1()
%a = call i32 @recursive1() convergent
ret i32 %a
}

; CHECK: Function Attrs
; CHECK-SAME: convergent
; CHECK-NEXT: define i32 @noopt()
define i32 @noopt() convergent optnone noinline {
%a = call i32 @noopt_friend()
%a = call i32 @noopt_friend() convergent
ret i32 0
}

Expand Down
33 changes: 33 additions & 0 deletions llvm/test/Transforms/InstCombine/convergent.ll
@@ -0,0 +1,33 @@
; RUN: opt -instcombine -S < %s | FileCheck %s

declare i32 @k() convergent
declare i32 @f()

define i32 @extern() {
; Convergent attr shouldn't be removed here; k is convergent.
; CHECK: call i32 @k() [[CONVERGENT_ATTR:#[0-9]+]]
%a = call i32 @k() convergent
ret i32 %a
}

define i32 @extern_no_attr() {
; Convergent attr shouldn't be added here, even though k is convergent.
; CHECK: call i32 @k(){{$}}
%a = call i32 @k()
ret i32 %a
}

define i32 @no_extern() {
; Convergent should be removed here, as the target is convergent.
; CHECK: call i32 @f(){{$}}
%a = call i32 @f() convergent
ret i32 %a
}

define i32 @indirect_call(i32 ()* %f) {
; CHECK call i32 %f() [[CONVERGENT_ATTR]]
%a = call i32 %f() convergent
ret i32 %a
}

; CHECK: [[CONVERGENT_ATTR]] = { convergent }

0 comments on commit 7bf9187

Please sign in to comment.