Patch summary:
  * Move the "is this the default (suspend) edge of a switch on
    llvm.coro.suspend in a presplit coroutine" check out of CFGMST.h into a
    shared helper, llvm::isPresplitCoroSuspendExitEdge (BasicBlockUtils).
  * PGOCounterPromoter only records a loop exit block (and an insertion point
    in it) when none of the block's incoming edges is such an edge.
  * Add a lit test running pgo-instr-gen, instrprof and coro-split with counter
    promotion enabled, and a unit test for the new helper.

diff --git a/llvm/include/llvm/Transforms/Instrumentation/CFGMST.h b/llvm/include/llvm/Transforms/Instrumentation/CFGMST.h
index cd2ae61334d0f..682ae877f7ae3 100644
--- a/llvm/include/llvm/Transforms/Instrumentation/CFGMST.h
+++ b/llvm/include/llvm/Transforms/Instrumentation/CFGMST.h
@@ -100,20 +100,11 @@ template <class Edge, class BBInfo> class CFGMST {
     //   i8 0, label %await.ready
     //   i8 1, label %exit
     // ]
-    const BasicBlock *EdgeTarget = E->DestBB;
-    if (!EdgeTarget)
+    if (!E->DestBB)
       return;
     assert(E->SrcBB);
-    const Function *F = EdgeTarget->getParent();
-    if (!F->isPresplitCoroutine())
-      return;
-
-    const Instruction *TI = E->SrcBB->getTerminator();
-    if (auto *SWInst = dyn_cast<SwitchInst>(TI))
-      if (auto *Intrinsic = dyn_cast<IntrinsicInst>(SWInst->getCondition()))
-        if (Intrinsic->getIntrinsicID() == Intrinsic::coro_suspend &&
-            SWInst->getDefaultDest() == EdgeTarget)
-          E->Removed = true;
+    if (llvm::isPresplitCoroSuspendExitEdge(*E->SrcBB, *E->DestBB))
+      E->Removed = true;
   }
 
   // Traverse the CFG using a stack. Find all the edges and assign the weight.
diff --git a/llvm/include/llvm/Transforms/Utils/BasicBlockUtils.h b/llvm/include/llvm/Transforms/Utils/BasicBlockUtils.h
index e6dde450b7df9..e650ac80efbcd 100644
--- a/llvm/include/llvm/Transforms/Utils/BasicBlockUtils.h
+++ b/llvm/include/llvm/Transforms/Utils/BasicBlockUtils.h
@@ -705,6 +705,25 @@ void InvertBranch(BranchInst *PBI, IRBuilderBase &Builder);
 // Check whether the function only has simple terminator:
 // br/brcond/unreachable/ret
 bool hasOnlySimpleTerminator(const Function &F);
+
+// Returns true if these basic blocks belong to a presplit coroutine and the
+// edge corresponds to the 'default' case in the switch statement in the
+// pattern:
+//
+//  %0 = call i8 @llvm.coro.suspend(token none, i1 false)
+//  switch i8 %0, label %suspend [i8 0, label %resume
+//                                i8 1, label %cleanup]
+//
+// i.e. the edge to the `%suspend` BB. This edge is special in that it will
+// be elided by coroutine lowering (coro-split), and the `%suspend` BB needs
+// to be kept as-is. It's not a real CFG edge - post-lowering, it will end
+// up being a `ret`, and it must be thus lowerable to support symmetric
+// transfer. For example:
+//  - this edge is not a loop exit edge if encountered in a loop (and should
+//    be ignored)
+//  - must not be split for PGO instrumentation, for example.
+bool isPresplitCoroSuspendExitEdge(const BasicBlock &Src,
+                                   const BasicBlock &Dest);
 } // end namespace llvm
 
 #endif // LLVM_TRANSFORMS_UTILS_BASICBLOCKUTILS_H
diff --git a/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp b/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp
index 480817a23d2c2..5033aa8f046f5 100644
--- a/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp
+++ b/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp
@@ -14,6 +14,7 @@
 
 #include "llvm/Transforms/Instrumentation/InstrProfiling.h"
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/Twine.h"
@@ -23,6 +24,7 @@
 #include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/IR/Attributes.h"
 #include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/CFG.h"
 #include "llvm/IR/Constant.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DIBuilder.h"
@@ -48,6 +50,7 @@
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/TargetParser/Triple.h"
 #include "llvm/Transforms/Instrumentation/PGOInstrumentation.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/ModuleUtils.h"
 #include "llvm/Transforms/Utils/SSAUpdater.h"
 #include <algorithm>
@@ -243,7 +246,10 @@ class PGOCounterPromoter {
       return;
 
     for (BasicBlock *ExitBlock : LoopExitBlocks) {
-      if (BlockSet.insert(ExitBlock).second) {
+      if (BlockSet.insert(ExitBlock).second &&
+          llvm::none_of(predecessors(ExitBlock), [&](const BasicBlock *Pred) {
+            return llvm::isPresplitCoroSuspendExitEdge(*Pred, *ExitBlock);
+          })) {
         ExitBlocks.push_back(ExitBlock);
         InsertPts.push_back(&*ExitBlock->getFirstInsertionPt());
       }
diff --git a/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp b/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp
index 05ff4efb7b944..c4c5e415166e5 100644
--- a/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp
+++ b/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp
@@ -2086,3 +2086,15 @@ bool llvm::hasOnlySimpleTerminator(const Function &F) {
   }
   return true;
 }
+
+bool llvm::isPresplitCoroSuspendExitEdge(const BasicBlock &Src,
+                                         const BasicBlock &Dest) {
+  assert(Src.getParent() == Dest.getParent());
+  if (!Src.getParent()->isPresplitCoroutine())
+    return false;
+  if (auto *SW = dyn_cast<SwitchInst>(Src.getTerminator()))
+    if (auto *Intr = dyn_cast<IntrinsicInst>(SW->getCondition()))
+      return Intr->getIntrinsicID() == Intrinsic::coro_suspend &&
+             SW->getDefaultDest() == &Dest;
+  return false;
+}
diff --git a/llvm/test/Transforms/Coroutines/coro-split-musttail-chain-pgo-counter-promo.ll b/llvm/test/Transforms/Coroutines/coro-split-musttail-chain-pgo-counter-promo.ll
new file mode 100644
index 0000000000000..ddd293eed2409
--- /dev/null
+++ b/llvm/test/Transforms/Coroutines/coro-split-musttail-chain-pgo-counter-promo.ll
@@ -0,0 +1,175 @@
+; REQUIRES: x86-registered-target
+; RUN: opt -passes='pgo-instr-gen,instrprof,coro-split' -do-counter-promotion=true -S < %s | FileCheck %s
+
+; CHECK-LABEL: define internal fastcc void @f.resume
+; CHECK: musttail call fastcc void
+; CHECK-NEXT: ret void
+; CHECK: musttail call fastcc void
+; CHECK-NEXT: ret void
+; CHECK-LABEL: define internal fastcc void @f.destroy
+target triple = "x86_64-grtev4-linux-gnu"
+
+%CoroutinePromise = type { ptr, i64, [8 x i8], ptr}
+%Awaitable.1 = type { ptr }
+%Awaitable.2 = type { ptr, ptr }
+
+declare void @await_suspend(ptr noundef nonnull align 1 dereferenceable(1), ptr) local_unnamed_addr
+declare ptr @await_transform_await_suspend(ptr noundef nonnull align 8 dereferenceable(16), ptr) local_unnamed_addr
+declare void @destroy_frame_slowpath(ptr noundef nonnull align 16 dereferenceable(32)) local_unnamed_addr
+declare ptr @other_coro();
+declare void @heap_delete(ptr noundef, i64 noundef, i64 noundef) local_unnamed_addr
+declare noundef nonnull ptr @heap_allocate(i64 noundef, i64 noundef) local_unnamed_addr
+
+declare void @llvm.assume(i1 noundef)
+declare i64 @llvm.coro.align.i64()
+declare i1 @llvm.coro.alloc(token)
+declare ptr @llvm.coro.begin(token, ptr writeonly)
+declare i1 @llvm.coro.end(ptr, i1, token)
+declare ptr @llvm.coro.free(token, ptr nocapture readonly)
+declare token @llvm.coro.id(i32, ptr readnone, ptr nocapture readonly, ptr)
+declare token @llvm.coro.save(ptr)
+declare i64 @llvm.coro.size.i64()
+declare ptr @llvm.coro.subfn.addr(ptr nocapture readonly, i8)
+declare i8 @llvm.coro.suspend(token, i1)
+declare void @llvm.instrprof.increment(ptr, i64, i32, i32)
+declare void @llvm.instrprof.value.profile(ptr, i64, i64, i32, i32)
+declare void @llvm.lifetime.start.p0(i64 immarg, ptr nocapture)
+declare void @llvm.lifetime.end.p0(i64 immarg, ptr nocapture)
+
+; Function Attrs: noinline nounwind presplitcoroutine uwtable
+define ptr @f(i32 %0) presplitcoroutine align 32 {
+  %2 = alloca i32, align 8
+  %3 = alloca %CoroutinePromise, align 16
+  %4 = alloca %Awaitable.1, align 8
+  %5 = alloca %Awaitable.2, align 8
+  %6 = call token @llvm.coro.id(i32 8, ptr nonnull %3, ptr nonnull @f, ptr null)
+  %7 = call i1 @llvm.coro.alloc(token %6)
+  br i1 %7, label %8, label %12
+
+8:                                                ; preds = %1
+  %9 = call i64 @llvm.coro.size.i64()
+  %10 = call i64 @llvm.coro.align.i64()
+  %11 = call noalias noundef nonnull ptr @heap_allocate(i64 noundef %9, i64 noundef %10) #27
+  call void @llvm.assume(i1 true) [ "align"(ptr %11, i64 %10) ]
+  br label %12
+
+12:                                               ; preds = %8, %1
+  %13 = phi ptr [ null, %1 ], [ %11, %8 ]
+  %14 = call ptr @llvm.coro.begin(token %6, ptr %13) #28
+  call void @llvm.lifetime.start.p0(i64 32, ptr nonnull %3) #9
+  store ptr null, ptr %3, align 16
+  %15 = getelementptr inbounds {ptr, i64}, ptr %3, i64 0, i32 1
+  store i64 0, ptr %15, align 8
+  call void @llvm.lifetime.start.p0(i64 8, ptr nonnull %4) #9
+  store ptr %3, ptr %4, align 8
+  %16 = call token @llvm.coro.save(ptr null)
+  call void @await_suspend(ptr noundef nonnull align 1 dereferenceable(1) %4, ptr %14) #9
+  %17 = call i8 @llvm.coro.suspend(token %16, i1 false)
+  switch i8 %17, label %61 [
+    i8 0, label %18
+    i8 1, label %21
+  ]
+
+18:                                               ; preds = %12
+  call void @llvm.lifetime.end.p0(i64 8, ptr nonnull %4) #9
+  %19 = icmp slt i32 0, %0
+  br i1 %19, label %20, label %36
+
+20:                                               ; preds = %18
+  br label %22
+
+21:                                               ; preds = %12
+  call void @llvm.lifetime.end.p0(i64 8, ptr nonnull %4) #9
+  br label %54
+
+22:                                               ; preds = %20, %31
+  %23 = phi i32 [ 0, %20 ], [ %32, %31 ]
+  call void @llvm.lifetime.start.p0(i64 16, ptr nonnull %5) #9
+  %24 = call ptr @other_coro()
+  store ptr %3, ptr %5, align 8
+  %25 = getelementptr inbounds { ptr, ptr }, ptr %5, i64 0, i32 1
+  store ptr %24, ptr %25, align 8
+  %26 = call token @llvm.coro.save(ptr null)
+  %27 = call ptr @await_transform_await_suspend(ptr noundef nonnull align 8 dereferenceable(16) %5, ptr %14)
+  %28 = call ptr @llvm.coro.subfn.addr(ptr %27, i8 0)
+  %29 = ptrtoint ptr %28 to i64
+  call fastcc void %28(ptr %27) #9
+  %30 = call i8 @llvm.coro.suspend(token %26, i1 false)
+  switch i8 %30, label %60 [
+    i8 0, label %31
+    i8 1, label %34
+  ]
+
+31:                                               ; preds = %22
+  call void @llvm.lifetime.end.p0(i64 16, ptr nonnull %5) #9
+  %32 = add nuw nsw i32 %23, 1
+  %33 = icmp slt i32 %32, %0
+  br i1 %33, label %22, label %35, !llvm.loop !0
+
+34:                                               ; preds = %22
+  call void @llvm.lifetime.end.p0(i64 16, ptr nonnull %5) #9
+  br label %54
+
+35:                                               ; preds = %31
+  br label %36
+
+36:                                               ; preds = %35, %18
+  %37 = call token @llvm.coro.save(ptr null)
+  %38 = getelementptr inbounds i8, ptr %14, i64 16
+  %39 = getelementptr inbounds i8, ptr %14, i64 32
+  %40 = load i64, ptr %39, align 8
+  %41 = load ptr, ptr %38, align 16
+  %42 = icmp eq ptr %41, null
+  br i1 %42, label %43, label %46
+
+43:                                               ; preds = %36
+  %44 = call ptr @llvm.coro.subfn.addr(ptr nonnull %14, i8 1)
+  %45 = ptrtoint ptr %44 to i64
+  call fastcc void %44(ptr nonnull %14) #9
+  br label %47
+
+46:                                               ; preds = %36
+  call void @destroy_frame_slowpath(ptr noundef nonnull align 16 dereferenceable(32) %38) #9
+  br label %47
+
+47:                                               ; preds = %43, %46
+  %48 = inttoptr i64 %40 to ptr
+  %49 = call ptr @llvm.coro.subfn.addr(ptr %48, i8 0)
+  %50 = ptrtoint ptr %49 to i64
+  call fastcc void %49(ptr %48) #9
+  %51 = call i8 @llvm.coro.suspend(token %37, i1 true) #28
+  switch i8 %51, label %61 [
+    i8 0, label %53
+    i8 1, label %52
+  ]
+
+52:                                               ; preds = %47
+  br label %54
+
+53:                                               ; preds = %47
+  call void @llvm.lifetime.start.p0(i64 16, ptr nonnull %2) #9
+  unreachable
+
+54:                                               ; preds = %52, %34, %21
+  call void @llvm.lifetime.end.p0(i64 32, ptr nonnull %3) #9
+  %55 = call ptr @llvm.coro.free(token %6, ptr %14)
+  %56 = icmp eq ptr %55, null
+  br i1 %56, label %61, label %57
+
+57:                                               ; preds = %54
+  %58 = call i64 @llvm.coro.size.i64()
+  %59 = call i64 @llvm.coro.align.i64()
+  call void @heap_delete(ptr noundef nonnull %55, i64 noundef %58, i64 noundef %59) #9
+  br label %61
+
+60:                                               ; preds = %22
+  br label %61
+
+61:                                               ; preds = %60, %57, %54, %47, %12
+  %62 = getelementptr inbounds i8, ptr %3, i64 -16
+  %63 = call i1 @llvm.coro.end(ptr null, i1 false, token none) #28
+  ret ptr %62
+}
+
+!0 = distinct !{!0, !1}
+!1 = !{!"llvm.loop.mustprogress"}
diff --git a/llvm/unittests/Transforms/Utils/BasicBlockUtilsTest.cpp b/llvm/unittests/Transforms/Utils/BasicBlockUtilsTest.cpp
index 2da38c60044be..5152cbe19c215 100644
--- a/llvm/unittests/Transforms/Utils/BasicBlockUtilsTest.cpp
+++ b/llvm/unittests/Transforms/Utils/BasicBlockUtilsTest.cpp
@@ -611,3 +611,64 @@ switch i32 %0, label %LD [
   EXPECT_EQ(BranchProbability::getRaw(1),
             BPI.getEdgeProbability(EntryBB, UnreachableBB));
 }
+
+TEST(BasicBlockUtils, IsPresplitCoroSuspendExitTest) {
+  LLVMContext C;
+  std::unique_ptr<Module> M = parseIR(C, R"IR(
+define void @positive_case(i32 %0) #0 {
+entry:
+  %save = call token @llvm.coro.save(ptr null)
+  %suspend = call i8 @llvm.coro.suspend(token %save, i1 false)
+  switch i8 %suspend, label %exit [
+    i8 0, label %resume
+    i8 1, label %destroy
+  ]
+resume:
+  ret void
+destroy:
+  ret void
+exit:
+  call i1 @llvm.coro.end(ptr null, i1 false, token none)
+  ret void
+}
+
+define void @notpresplit(i32 %0) {
+entry:
+  %save = call token @llvm.coro.save(ptr null)
+  %suspend = call i8 @llvm.coro.suspend(token %save, i1 false)
+  switch i8 %suspend, label %exit [
+    i8 0, label %resume
+    i8 1, label %destroy
+  ]
+resume:
+  ret void
+destroy:
+  ret void
+exit:
+  call i1 @llvm.coro.end(ptr null, i1 false, token none)
+  ret void
+}
+
+declare token @llvm.coro.save(ptr)
+declare i8 @llvm.coro.suspend(token, i1)
+declare i1 @llvm.coro.end(ptr, i1, token)
+
+attributes #0 = { presplitcoroutine }
+)IR");
+
+  auto FindExit = [](const Function &F) -> const BasicBlock * {
+    for (const auto &BB : F)
+      if (BB.getName() == "exit")
+        return &BB;
+    return nullptr;
+  };
+  Function *P = M->getFunction("positive_case");
+  const auto &ExitP = *FindExit(*P);
+  EXPECT_TRUE(llvm::isPresplitCoroSuspendExitEdge(*ExitP.getSinglePredecessor(),
+                                                  ExitP));
+
+  Function *N = M->getFunction("notpresplit");
+  const auto &ExitN = *FindExit(*N);
+  EXPECT_FALSE(llvm::isPresplitCoroSuspendExitEdge(
+      *ExitN.getSinglePredecessor(), ExitN));
+}