Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NFC][CallPromotionUtils]Extract a helper function versionCallSiteWithCond from versionCallSite #81181

Merged
merged 7 commits into from
May 14, 2024
39 changes: 24 additions & 15 deletions llvm/lib/Transforms/Utils/CallPromotionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,10 +188,9 @@ static void createRetBitCast(CallBase &CB, Type *RetTy, CastInst **RetBitCast) {
/// Predicate and clone the given call site.
///
/// This function creates an if-then-else structure at the location of the call
/// site. The "if" condition compares the call site's called value to the given
/// callee. The original call site is moved into the "else" block, and a clone
/// of the call site is placed in the "then" block. The cloned instruction is
/// returned.
/// site. The "if" condition is specified by `Cond`. The original call site is
/// moved into the "else" block, and a clone of the call site is placed in the
/// "then" block. The cloned instruction is returned.
///
/// For example, the call instruction below:
///
Expand All @@ -202,7 +201,7 @@ static void createRetBitCast(CallBase &CB, Type *RetTy, CastInst **RetBitCast) {
/// Is replace by the following:
///
/// orig_bb:
/// %cond = icmp eq i32 ()* %ptr, @func
/// %cond = Cond
/// br i1 %cond, %then_bb, %else_bb
///
/// then_bb:
Expand Down Expand Up @@ -232,7 +231,7 @@ static void createRetBitCast(CallBase &CB, Type *RetTy, CastInst **RetBitCast) {
/// Is replace by the following:
///
/// orig_bb:
/// %cond = icmp eq i32 ()* %ptr, @func
/// %cond = Cond
/// br i1 %cond, %then_bb, %else_bb
///
/// then_bb:
Expand Down Expand Up @@ -267,7 +266,7 @@ static void createRetBitCast(CallBase &CB, Type *RetTy, CastInst **RetBitCast) {
/// Is replaced by the following:
///
/// cond_bb:
/// %cond = icmp eq i32 ()* %ptr, @func
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure these conditions should be completely removed. Perhaps just show something like %cond = Cond.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done.

/// %cond = Cond
/// br i1 %cond, %then_bb, %orig_bb
///
/// then_bb:
Expand All @@ -280,19 +279,13 @@ static void createRetBitCast(CallBase &CB, Type *RetTy, CastInst **RetBitCast) {
/// ; The original call instruction stays in its original block.
/// %t0 = musttail call i32 %ptr()
/// ret %t0
CallBase &llvm::versionCallSite(CallBase &CB, Value *Callee,
MDNode *BranchWeights) {
static CallBase &versionCallSiteWithCond(CallBase &CB, Value *Cond,
MDNode *BranchWeights) {

IRBuilder<> Builder(&CB);
CallBase *OrigInst = &CB;
BasicBlock *OrigBlock = OrigInst->getParent();

// Create the compare. The called value and callee must have the same type to
// be compared.
if (CB.getCalledOperand()->getType() != Callee->getType())
Callee = Builder.CreateBitCast(Callee, CB.getCalledOperand()->getType());
auto *Cond = Builder.CreateICmpEQ(CB.getCalledOperand(), Callee);

if (OrigInst->isMustTailCall()) {
// Create an if-then structure. The original instruction stays in its block,
// and a clone of the original instruction is placed in the "then" block.
Expand Down Expand Up @@ -380,6 +373,22 @@ CallBase &llvm::versionCallSite(CallBase &CB, Value *Callee,
return *NewInst;
}

// Predicate and clone the given call site using condition `CB.callee ==
// Callee`. See the comment `versionCallSiteWithCond` for the transformation.
CallBase &llvm::versionCallSite(CallBase &CB, Value *Callee,
MDNode *BranchWeights) {

IRBuilder<> Builder(&CB);

// Create the compare. The called value and callee must have the same type to
// be compared.
if (CB.getCalledOperand()->getType() != Callee->getType())
Callee = Builder.CreateBitCast(Callee, CB.getCalledOperand()->getType());
auto *Cond = Builder.CreateICmpEQ(CB.getCalledOperand(), Callee);

return versionCallSiteWithCond(CB, Cond, BranchWeights);
}

bool llvm::isLegalToPromote(const CallBase &CB, Function *Callee,
const char **FailureReason) {
assert(!CB.getCalledFunction() && "Only indirect call sites can be promoted");
Expand Down
Loading