Skip to content

Commit

Permalink
[Coroutines] Modify CoroFrame materializable into a callback
Browse files Browse the repository at this point in the history
This change makes it possible to optionally provide a different callback to
determine if an instruction is materializable.

By default the behaviour is unchanged.

Differential Revision: https://reviews.llvm.org/D142621
  • Loading branch information
dstutt committed Feb 13, 2023
1 parent 3e51af9 commit c4f7cc8
Show file tree
Hide file tree
Showing 7 changed files with 245 additions and 21 deletions.
8 changes: 7 additions & 1 deletion llvm/include/llvm/Transforms/Coroutines/CoroSplit.h
Expand Up @@ -22,7 +22,13 @@
namespace llvm {

struct CoroSplitPass : PassInfoMixin<CoroSplitPass> {
CoroSplitPass(bool OptimizeFrame = false) : OptimizeFrame(OptimizeFrame) {}
const std::function<bool(Instruction &)> MaterializableCallback;

CoroSplitPass(bool OptimizeFrame = false);
CoroSplitPass(std::function<bool(Instruction &)> MaterializableCallback,
bool OptimizeFrame = false)
: MaterializableCallback(MaterializableCallback),
OptimizeFrame(OptimizeFrame) {}

PreservedAnalyses run(LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM,
LazyCallGraph &CG, CGSCCUpdateResult &UR);
Expand Down
31 changes: 19 additions & 12 deletions llvm/lib/Transforms/Coroutines/CoroFrame.cpp
Expand Up @@ -318,8 +318,6 @@ SuspendCrossingInfo::SuspendCrossingInfo(Function &F, coro::Shape &Shape)
LLVM_DEBUG(dump());
}

static bool materializable(Instruction &V);

namespace {

// RematGraph is used to construct a DAG for rematerializable instructions
Expand All @@ -342,9 +340,12 @@ struct RematGraph {
using RematNodeMap =
SmallMapVector<Instruction *, std::unique_ptr<RematNode>, 8>;
RematNodeMap Remats;
const std::function<bool(Instruction &)> &MaterializableCallback;
SuspendCrossingInfo &Checker;

RematGraph(Instruction *I, SuspendCrossingInfo &Checker) : Checker(Checker) {
RematGraph(const std::function<bool(Instruction &)> &MaterializableCallback,
Instruction *I, SuspendCrossingInfo &Checker)
: MaterializableCallback(MaterializableCallback), Checker(Checker) {
std::unique_ptr<RematNode> FirstNode = std::make_unique<RematNode>(I);
EntryNode = FirstNode.get();
std::deque<std::unique_ptr<RematNode>> WorkList;
Expand All @@ -367,7 +368,7 @@ struct RematGraph {
Remats[N->Node] = std::move(NUPtr);
for (auto &Def : N->Node->operands()) {
Instruction *D = dyn_cast<Instruction>(Def.get());
if (!D || !materializable(*D) ||
if (!D || !MaterializableCallback(*D) ||
!Checker.isDefinitionAcrossSuspend(*D, FirstUse))
continue;

Expand Down Expand Up @@ -2211,11 +2212,12 @@ static void rewritePHIs(Function &F) {
rewritePHIs(*BB);
}

/// Default materializable callback
// Check for instructions that we can recreate on resume as opposed to spill
// the result into a coroutine frame.
static bool materializable(Instruction &V) {
return isa<CastInst>(&V) || isa<GetElementPtrInst>(&V) ||
isa<BinaryOperator>(&V) || isa<CmpInst>(&V) || isa<SelectInst>(&V);
bool coro::defaultMaterializable(Instruction &V) {
return (isa<CastInst>(&V) || isa<GetElementPtrInst>(&V) ||
isa<BinaryOperator>(&V) || isa<CmpInst>(&V) || isa<SelectInst>(&V));
}

// Check for structural coroutine intrinsics that should not be spilled into
Expand Down Expand Up @@ -2887,14 +2889,16 @@ void coro::salvageDebugInfo(
}
}

static void doRematerializations(Function &F, SuspendCrossingInfo &Checker) {
static void doRematerializations(
Function &F, SuspendCrossingInfo &Checker,
const std::function<bool(Instruction &)> &MaterializableCallback) {
SpillInfo Spills;

// See if there are materializable instructions across suspend points
// We record these as the starting point to also identify materializable
// defs of uses in these operations
for (Instruction &I : instructions(F)) {
if (!materializable(I))
if (!MaterializableCallback(I))
continue;
for (User *U : I.users())
if (Checker.isDefinitionAcrossSuspend(I, U))
Expand Down Expand Up @@ -2925,7 +2929,8 @@ static void doRematerializations(Function &F, SuspendCrossingInfo &Checker) {
continue;

// Constructor creates the whole RematGraph for the given Use
auto RematUPtr = std::make_unique<RematGraph>(U, Checker);
auto RematUPtr =
std::make_unique<RematGraph>(MaterializableCallback, U, Checker);

LLVM_DEBUG(dbgs() << "***** Next remat group *****\n";
ReversePostOrderTraversal<RematGraph *> RPOT(RematUPtr.get());
Expand All @@ -2943,7 +2948,9 @@ static void doRematerializations(Function &F, SuspendCrossingInfo &Checker) {
rewriteMaterializableInstructions(AllRemats);
}

void coro::buildCoroutineFrame(Function &F, Shape &Shape) {
void coro::buildCoroutineFrame(
Function &F, Shape &Shape,
const std::function<bool(Instruction &)> &MaterializableCallback) {
// Don't eliminate swifterror in async functions that won't be split.
if (Shape.ABI != coro::ABI::Async || !Shape.CoroSuspends.empty())
eliminateSwiftError(F, Shape);
Expand Down Expand Up @@ -2994,7 +3001,7 @@ void coro::buildCoroutineFrame(Function &F, Shape &Shape) {
// Build suspend crossing info.
SuspendCrossingInfo Checker(F, Shape);

doRematerializations(F, Checker);
doRematerializations(F, Checker, MaterializableCallback);

FrameDataInfo FrameData;
SmallVector<CoroAllocaAllocInst*, 4> LocalAllocas;
Expand Down
5 changes: 4 additions & 1 deletion llvm/lib/Transforms/Coroutines/CoroInternal.h
Expand Up @@ -261,7 +261,10 @@ struct LLVM_LIBRARY_VISIBILITY Shape {
void buildFrom(Function &F);
};

void buildCoroutineFrame(Function &F, Shape &Shape);
bool defaultMaterializable(Instruction &V);
void buildCoroutineFrame(
Function &F, Shape &Shape,
const std::function<bool(Instruction &)> &MaterializableCallback);
CallInst *createMustTailCall(DebugLoc Loc, Function *MustTailCallFn,
ArrayRef<Value *> Arguments, IRBuilder<> &);
} // End namespace coro.
Expand Down
19 changes: 12 additions & 7 deletions llvm/lib/Transforms/Coroutines/CoroSplit.cpp
Expand Up @@ -1929,10 +1929,10 @@ namespace {
};
}

static coro::Shape splitCoroutine(Function &F,
SmallVectorImpl<Function *> &Clones,
TargetTransformInfo &TTI,
bool OptimizeFrame) {
static coro::Shape
splitCoroutine(Function &F, SmallVectorImpl<Function *> &Clones,
TargetTransformInfo &TTI, bool OptimizeFrame,
std::function<bool(Instruction &)> MaterializableCallback) {
PrettyStackTraceFunction prettyStackTrace(F);

// The suspend-crossing algorithm in buildCoroutineFrame gets tripped
Expand All @@ -1944,7 +1944,7 @@ static coro::Shape splitCoroutine(Function &F,
return Shape;

simplifySuspendPoints(Shape);
buildCoroutineFrame(F, Shape);
buildCoroutineFrame(F, Shape, MaterializableCallback);
replaceFrameSizeAndAlignment(Shape);

// If there are no suspend points, no split required, just remove
Expand Down Expand Up @@ -2104,6 +2104,10 @@ static void addPrepareFunction(const Module &M,
Fns.push_back(PrepareFn);
}

// Constructor taking only OptimizeFrame: installs coro::defaultMaterializable
// as the materializable callback and stores the OptimizeFrame flag.
CoroSplitPass::CoroSplitPass(bool OptimizeFrame)
: MaterializableCallback(coro::defaultMaterializable),
OptimizeFrame(OptimizeFrame) {}

PreservedAnalyses CoroSplitPass::run(LazyCallGraph::SCC &C,
CGSCCAnalysisManager &AM,
LazyCallGraph &CG, CGSCCUpdateResult &UR) {
Expand Down Expand Up @@ -2142,8 +2146,9 @@ PreservedAnalyses CoroSplitPass::run(LazyCallGraph::SCC &C,
F.setSplittedCoroutine();

SmallVector<Function *, 4> Clones;
const coro::Shape Shape = splitCoroutine(
F, Clones, FAM.getResult<TargetIRAnalysis>(F), OptimizeFrame);
const coro::Shape Shape =
splitCoroutine(F, Clones, FAM.getResult<TargetIRAnalysis>(F),
OptimizeFrame, MaterializableCallback);
updateCallGraphAfterCoroutineSplit(*N, Shape, Clones, C, CG, AM, UR, FAM);

if (!Shape.CoroSuspends.empty()) {
Expand Down
1 change: 1 addition & 0 deletions llvm/unittests/Transforms/CMakeLists.txt
@@ -1,3 +1,4 @@
add_subdirectory(Coroutines)
add_subdirectory(IPO)
add_subdirectory(Scalar)
add_subdirectory(Utils)
Expand Down
18 changes: 18 additions & 0 deletions llvm/unittests/Transforms/Coroutines/CMakeLists.txt
@@ -0,0 +1,18 @@
# LLVM component list for this directory's targets.
# NOTE(review): presumably consumed by add_llvm_unittest() below to select the
# LLVM libraries to link; that macro is defined outside this file - confirm.
set(LLVM_LINK_COMPONENTS
Analysis
AsmParser
Core
Coroutines
Passes
Support
TargetParser
TransformUtils
)

# Declare the CoroTests unit-test target; ExtraRematTest.cpp is its only
# listed source file.
add_llvm_unittest(CoroTests
ExtraRematTest.cpp
)

# Link LLVMTestingSupport into CoroTests as a PRIVATE dependency.
target_link_libraries(CoroTests PRIVATE LLVMTestingSupport)

# Set the FOLDER property of the CoroTests target.
set_property(TARGET CoroTests PROPERTY FOLDER "Tests/UnitTests/TransformTests")
184 changes: 184 additions & 0 deletions llvm/unittests/Transforms/Coroutines/ExtraRematTest.cpp
@@ -0,0 +1,184 @@
//===- ExtraRematTest.cpp - Coroutines unit tests -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/Module.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Testing/Support/Error.h"
#include "llvm/Transforms/Coroutines/CoroSplit.h"
#include "gtest/gtest.h"

using namespace llvm;

namespace {

struct ExtraRematTest : public testing::Test {
LLVMContext Ctx;
ModulePassManager MPM;
PassBuilder PB;
LoopAnalysisManager LAM;
FunctionAnalysisManager FAM;
CGSCCAnalysisManager CGAM;
ModuleAnalysisManager MAM;
LLVMContext Context;
std::unique_ptr<Module> M;

ExtraRematTest() {
PB.registerModuleAnalyses(MAM);
PB.registerCGSCCAnalyses(CGAM);
PB.registerFunctionAnalyses(FAM);
PB.registerLoopAnalyses(LAM);
PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
}

BasicBlock *getBasicBlockByName(Function *F, StringRef Name) const {
for (BasicBlock &BB : *F) {
if (BB.getName() == Name)
return &BB;
}
return nullptr;
}

CallInst *getCallByName(BasicBlock *BB, StringRef Name) const {
for (Instruction &I : *BB) {
if (CallInst *CI = dyn_cast<CallInst>(&I))
if (CI->getCalledFunction()->getName() == Name)
return CI;
}
return nullptr;
}

void ParseAssembly(const StringRef IR) {
SMDiagnostic Error;
M = parseAssemblyString(IR, Error, Context);
std::string errMsg;
raw_string_ostream os(errMsg);
Error.print("", os);

// A failure here means that the test itself is buggy.
if (!M)
report_fatal_error(os.str().c_str());
}
};

// IR shared by the tests below: a coroutine @f with two suspend points.
// %val2 (the result of calling @should.remat) is defined before the first
// suspend and used after it in resume1 and resume2. Each switch dispatches on
// the result of the suspend that immediately precedes it.
StringRef Text = R"(
define ptr @f(i32 %n) presplitcoroutine {
entry:
%id = call token @llvm.coro.id(i32 0, ptr null, ptr null, ptr null)
%size = call i32 @llvm.coro.size.i32()
%alloc = call ptr @malloc(i32 %size)
%hdl = call ptr @llvm.coro.begin(token %id, ptr %alloc)
%inc1 = add i32 %n, 1
%val2 = call i32 @should.remat(i32 %inc1)
%sp1 = call i8 @llvm.coro.suspend(token none, i1 false)
switch i8 %sp1, label %suspend [i8 0, label %resume1
i8 1, label %cleanup]
resume1:
%inc2 = add i32 %val2, 1
%sp2 = call i8 @llvm.coro.suspend(token none, i1 false)
switch i8 %sp2, label %suspend [i8 0, label %resume2
i8 1, label %cleanup]
resume2:
call void @print(i32 %val2)
call void @print(i32 %inc2)
br label %cleanup
cleanup:
%mem = call ptr @llvm.coro.free(token %id, ptr %hdl)
call void @free(ptr %mem)
br label %suspend
suspend:
call i1 @llvm.coro.end(ptr %hdl, i1 0)
ret ptr %hdl
}
declare ptr @llvm.coro.free(token, ptr)
declare i32 @llvm.coro.size.i32()
declare i8 @llvm.coro.suspend(token, i1)
declare void @llvm.coro.resume(ptr)
declare void @llvm.coro.destroy(ptr)
declare token @llvm.coro.id(i32, ptr, ptr, ptr)
declare i1 @llvm.coro.alloc(token)
declare ptr @llvm.coro.begin(token, ptr)
declare i1 @llvm.coro.end(ptr, i1)
declare i32 @should.remat(i32)
declare noalias ptr @malloc(i32)
declare void @print(i32)
declare void @free(ptr)
)";

// Materializable callback used by the tests: accepts casts, GEPs, binary
// operators, compares and selects, and additionally any call whose direct
// callee has a name starting with "should.remat".
bool ExtraMaterializable(Instruction &I) {
  const bool IsBasicKind = isa<CastInst>(&I) || isa<GetElementPtrInst>(&I) ||
                           isa<BinaryOperator>(&I) || isa<CmpInst>(&I) ||
                           isa<SelectInst>(&I);
  if (IsBasicKind)
    return true;

  // Anything that is not a call has no further way to qualify.
  auto *Call = dyn_cast<CallInst>(&I);
  if (!Call)
    return false;

  // Only calls with a known direct callee can be matched by name.
  auto *Callee = Call->getCalledFunction();
  return Callee && Callee->getName().startswith("should.remat");
}

// Runs CoroSplit with a default-constructed CoroSplitPass and checks that the
// call to should.remat does NOT appear in block resume1 of f.resume.
TEST_F(ExtraRematTest, TestCoroRematDefault) {
ParseAssembly(Text);

ASSERT_TRUE(M);

CGSCCPassManager CGPM;
CGPM.addPass(CoroSplitPass());
MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor(std::move(CGPM)));
MPM.run(*M, MAM);

// Verify that the extra rematerializable instruction has not been
// rematerialized
Function *F = M->getFunction("f.resume");
ASSERT_TRUE(F) << "could not find split function f.resume";

BasicBlock *Resume1 = getBasicBlockByName(F, "resume1");
ASSERT_TRUE(Resume1)
<< "could not find expected BB resume1 in split function";

// With default materialization the call to should.remat should not have
// been rematerialized into resume1
CallInst *CI = getCallByName(Resume1, "should.remat");
ASSERT_FALSE(CI);
}

// Runs CoroSplit with ExtraMaterializable installed as the materializable
// callback and checks that the call to should.remat appears in block resume1
// of f.resume.
TEST_F(ExtraRematTest, TestCoroRematWithCallback) {
  ParseAssembly(Text);
  ASSERT_TRUE(M);

  // CoroSplit is a CGSCC pass, so wrap it in a module-to-CGSCC adaptor before
  // handing it to the module pass manager.
  std::function<bool(Instruction &)> Callback = ExtraMaterializable;
  CGSCCPassManager SCCPasses;
  SCCPasses.addPass(CoroSplitPass(Callback));
  MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor(std::move(SCCPasses)));
  MPM.run(*M, MAM);

  // Locate block resume1 inside the split resume function.
  Function *ResumeFn = M->getFunction("f.resume");
  ASSERT_TRUE(ResumeFn) << "could not find split function f.resume";

  BasicBlock *ResumeBB = getBasicBlockByName(ResumeFn, "resume1");
  ASSERT_TRUE(ResumeBB)
      << "could not find expected BB resume1 in split function";

  // The callback accepts calls to should.remat, so the call is expected to
  // have been rematerialized into resume1.
  CallInst *RematCall = getCallByName(ResumeBB, "should.remat");
  ASSERT_TRUE(RematCall);
}
} // namespace

0 comments on commit c4f7cc8

Please sign in to comment.