Skip to content

Commit

Permalink
[Polly][Ast] Partial refactoring of IslAst and IslAstInfo to use isl+…
Browse files Browse the repository at this point in the history
…+. NFC.

Polly use algorithms from the Integer Set Library (isl), which is a library written in C and which is incompatible with the rest of the LLVM as it is written in C++.

Changes made:
 - Refactoring the following methods of class `IslAst`
  - `getAst()` `getRunCondition()` `buildRunCondition()`
  - Removed the destructor in favor of the default one
 - Change the type of the attribute `IslAst.RunCondition` to `isl::ast_expr`
 - Change the type of the attribute `IslAst.Root` to `isl::ast_node`
 - Change the order of attributes in class `IslAst` to reflect the data dependencies so that the destructor won't complain
 - Refactoring the following methods of class `IslAstInfo`
  - `getAst()` `getRunCondition()`

Reviewed By: Meinersbur

Differential Revision: https://reviews.llvm.org/D100265
  • Loading branch information
patacca authored and pull[bot] committed Feb 16, 2024
1 parent 778a337 commit 1464953
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 67 deletions.
16 changes: 7 additions & 9 deletions polly/include/polly/CodeGen/IslAst.h
Expand Up @@ -37,34 +37,32 @@ class IslAst {
IslAst &operator=(const IslAst &) = delete;
IslAst(IslAst &&);
IslAst &operator=(IslAst &&) = delete;
~IslAst();

static IslAst create(Scop &Scop, const Dependences &D);

/// Print a source code representation of the program.
void pprint(raw_ostream &OS);

__isl_give isl_ast_node *getAst();
isl::ast_node getAst();

const std::shared_ptr<isl_ctx> getSharedIslCtx() const { return Ctx; }

/// Get the run-time conditions for the Scop.
__isl_give isl_ast_expr *getRunCondition();
isl::ast_expr getRunCondition();

/// Build run-time condition for scop.
///
/// @param S The scop to build the condition for.
/// @param Build The isl_build object to use to build the condition.
///
/// @returns An ast expression that describes the necessary run-time check.
static isl_ast_expr *buildRunCondition(Scop &S,
__isl_keep isl_ast_build *Build);
static isl::ast_expr buildRunCondition(Scop &S, const isl::ast_build &Build);

private:
Scop &S;
isl_ast_node *Root = nullptr;
isl_ast_expr *RunCondition = nullptr;
std::shared_ptr<isl_ctx> Ctx;
isl::ast_expr RunCondition;
isl::ast_node Root;

IslAst(Scop &Scop);

Expand Down Expand Up @@ -120,15 +118,15 @@ class IslAstInfo {
IslAst &getIslAst() { return Ast; }

/// Return a copy of the AST root node.
__isl_give isl_ast_node *getAst();
isl::ast_node getAst();

/// Get the run condition.
///
/// Only if the run condition evaluates at run-time to a non-zero value, the
/// assumptions that have been taken hold. If the run condition evaluates to
/// zero/false some assumptions do not hold and the original code needs to
/// be executed.
__isl_give isl_ast_expr *getRunCondition();
isl::ast_expr getRunCondition();

void print(raw_ostream &O);

Expand Down
10 changes: 4 additions & 6 deletions polly/lib/CodeGen/CodeGeneration.cpp
Expand Up @@ -188,8 +188,8 @@ static bool CodeGen(Scop &S, IslAstInfo &AI, LoopInfo &LI, DominatorTree &DT,
}

// Check if we created an isl_ast root node, otherwise exit.
isl_ast_node *AstRoot = Ast.getAst();
if (!AstRoot)
isl::ast_node AstRoot = Ast.getAst();
if (AstRoot.is_null())
return false;

// Collect statistics. Do it before we modify the IR to avoid having it any
Expand Down Expand Up @@ -266,11 +266,9 @@ static bool CodeGen(Scop &S, IslAstInfo &AI, LoopInfo &LI, DominatorTree &DT,
assert(ExitingBB);
DT.changeImmediateDominator(MergeBlock, ExitingBB);
DT.eraseNode(ExitingBlock);

isl_ast_node_free(AstRoot);
} else {
NodeBuilder.addParameters(S.getContext().release());
Value *RTC = NodeBuilder.createRTC(AI.getRunCondition());
Value *RTC = NodeBuilder.createRTC(AI.getRunCondition().release());

Builder.GetInsertBlock()->getTerminator()->setOperand(0, RTC);

Expand All @@ -282,7 +280,7 @@ static bool CodeGen(Scop &S, IslAstInfo &AI, LoopInfo &LI, DominatorTree &DT,
// between polly.start and polly.exiting (at this point).
Builder.SetInsertPoint(StartBlock->getTerminator());

NodeBuilder.create(AstRoot);
NodeBuilder.create(AstRoot.release());
NodeBuilder.finalize();
fixRegionInfo(*EnteringBB->getParent(), *R->getParent(), RI);

Expand Down
81 changes: 33 additions & 48 deletions polly/lib/CodeGen/IslAst.cpp
Expand Up @@ -398,23 +398,22 @@ static isl::ast_expr buildCondition(Scop &S, isl::ast_build Build,
return NonAliasGroup;
}

__isl_give isl_ast_expr *
IslAst::buildRunCondition(Scop &S, __isl_keep isl_ast_build *Build) {
isl_ast_expr *RunCondition;
isl::ast_expr IslAst::buildRunCondition(Scop &S, const isl::ast_build &Build) {
isl::ast_expr RunCondition;

// The conditions that need to be checked at run-time for this scop are
// available as an isl_set in the runtime check context from which we can
// directly derive a run-time condition.
auto *PosCond =
isl_ast_build_expr_from_set(Build, S.getAssumedContext().release());
auto PosCond = Build.expr_from(S.getAssumedContext());
if (S.hasTrivialInvalidContext()) {
RunCondition = PosCond;
RunCondition = std::move(PosCond);
} else {
auto *ZeroV = isl_val_zero(isl_ast_build_get_ctx(Build));
auto *NegCond =
isl_ast_build_expr_from_set(Build, S.getInvalidContext().release());
auto *NotNegCond = isl_ast_expr_eq(isl_ast_expr_from_val(ZeroV), NegCond);
RunCondition = isl_ast_expr_and(PosCond, NotNegCond);
auto ZeroV = isl::val::zero(Build.get_ctx());
auto NegCond = Build.expr_from(S.getInvalidContext());
auto NotNegCond =
isl::ast_expr::from_val(std::move(ZeroV)).eq(std::move(NegCond));
RunCondition =
isl::manage(isl_ast_expr_and(PosCond.release(), NotNegCond.release()));
}

// Create the alias checks from the minimal/maximal accesses in each alias
Expand All @@ -429,15 +428,13 @@ IslAst::buildRunCondition(Scop &S, __isl_keep isl_ast_build *Build) {
for (auto RWAccIt0 = MinMaxReadWrite.begin(); RWAccIt0 != RWAccEnd;
++RWAccIt0) {
for (auto RWAccIt1 = RWAccIt0 + 1; RWAccIt1 != RWAccEnd; ++RWAccIt1)
RunCondition = isl_ast_expr_and(
RunCondition,
buildCondition(S, isl::manage_copy(Build), RWAccIt0, RWAccIt1)
.release());
RunCondition = isl::manage(isl_ast_expr_and(
RunCondition.release(),
buildCondition(S, Build, RWAccIt0, RWAccIt1).release()));
for (const Scop::MinMaxAccessTy &ROAccIt : MinMaxReadOnly)
RunCondition = isl_ast_expr_and(
RunCondition,
buildCondition(S, isl::manage_copy(Build), RWAccIt0, &ROAccIt)
.release());
RunCondition = isl::manage(isl_ast_expr_and(
RunCondition.release(),
buildCondition(S, Build, RWAccIt0, &ROAccIt).release()));
}
}

Expand Down Expand Up @@ -465,10 +462,10 @@ static bool benefitsFromPolly(Scop &Scop, bool PerformParallelTest) {
}

/// Collect statistics for the syntax tree rooted at @p Ast.
static void walkAstForStatistics(__isl_keep isl_ast_node *Ast) {
assert(Ast);
static void walkAstForStatistics(const isl::ast_node &Ast) {
assert(!Ast.is_null());
isl_ast_node_foreach_descendant_top_down(
Ast,
Ast.get(),
[](__isl_keep isl_ast_node *Node, void *User) -> isl_bool {
switch (isl_ast_node_get_type(Node)) {
case isl_ast_node_for:
Expand Down Expand Up @@ -502,15 +499,8 @@ static void walkAstForStatistics(__isl_keep isl_ast_node *Ast) {
IslAst::IslAst(Scop &Scop) : S(Scop), Ctx(Scop.getSharedIslCtx()) {}

IslAst::IslAst(IslAst &&O)
: S(O.S), Root(O.Root), RunCondition(O.RunCondition), Ctx(O.Ctx) {
O.Root = nullptr;
O.RunCondition = nullptr;
}

IslAst::~IslAst() {
isl_ast_node_free(Root);
isl_ast_expr_free(RunCondition);
}
: S(O.S), Ctx(O.Ctx), RunCondition(std::move(O.RunCondition)),
Root(std::move(O.Root)) {}

void IslAst::init(const Dependences &D) {
bool PerformParallelTest = PollyParallel || DetectParallel ||
Expand Down Expand Up @@ -557,9 +547,10 @@ void IslAst::init(const Dependences &D) {
&BuildInfo);
}

RunCondition = buildRunCondition(S, Build);
RunCondition = buildRunCondition(S, isl::manage_copy(Build));

Root = isl_ast_build_node_from_schedule(Build, S.getScheduleTree().release());
Root = isl::manage(
isl_ast_build_node_from_schedule(Build, S.getScheduleTree().release()));
walkAstForStatistics(Root);

isl_ast_build_free(Build);
Expand All @@ -571,15 +562,11 @@ IslAst IslAst::create(Scop &Scop, const Dependences &D) {
return Ast;
}

__isl_give isl_ast_node *IslAst::getAst() { return isl_ast_node_copy(Root); }
__isl_give isl_ast_expr *IslAst::getRunCondition() {
return isl_ast_expr_copy(RunCondition);
}
isl::ast_node IslAst::getAst() { return Root; }
isl::ast_expr IslAst::getRunCondition() { return RunCondition; }

__isl_give isl_ast_node *IslAstInfo::getAst() { return Ast.getAst(); }
__isl_give isl_ast_expr *IslAstInfo::getRunCondition() {
return Ast.getRunCondition();
}
isl::ast_node IslAstInfo::getAst() { return Ast.getAst(); }
isl::ast_expr IslAstInfo::getRunCondition() { return Ast.getRunCondition(); }

IslAstUserPayload *IslAstInfo::getNodePayload(const isl::ast_node &Node) {
isl::id Id = Node.get_annotation();
Expand Down Expand Up @@ -745,12 +732,12 @@ static __isl_give isl_printer *cbPrintUser(__isl_take isl_printer *P,

void IslAstInfo::print(raw_ostream &OS) {
isl_ast_print_options *Options;
isl_ast_node *RootNode = Ast.getAst();
isl::ast_node RootNode = Ast.getAst();
Function &F = S.getFunction();

OS << ":: isl ast :: " << F.getName() << " :: " << S.getNameStr() << "\n";

if (!RootNode) {
if (RootNode.is_null()) {
OS << ":: isl ast generation and code generation was skipped!\n\n";
OS << ":: This is either because no useful optimizations could be applied "
"(use -polly-process-unprofitable to enforce code generation) or "
Expand All @@ -760,7 +747,7 @@ void IslAstInfo::print(raw_ostream &OS) {
return;
}

isl_ast_expr *RunCondition = Ast.getRunCondition();
isl::ast_expr RunCondition = Ast.getRunCondition();
char *RtCStr, *AstStr;

Options = isl_ast_print_options_alloc(S.getIslCtx().get());
Expand All @@ -772,11 +759,11 @@ void IslAstInfo::print(raw_ostream &OS) {

isl_printer *P = isl_printer_to_str(S.getIslCtx().get());
P = isl_printer_set_output_format(P, ISL_FORMAT_C);
P = isl_printer_print_ast_expr(P, RunCondition);
P = isl_printer_print_ast_expr(P, RunCondition.get());
RtCStr = isl_printer_get_str(P);
P = isl_printer_flush(P);
P = isl_printer_indent(P, 4);
P = isl_ast_node_print(RootNode, P, Options);
P = isl_ast_node_print(RootNode.get(), P, Options);
AstStr = isl_printer_get_str(P);

auto *Schedule = S.getScheduleTree().release();
Expand All @@ -793,9 +780,7 @@ void IslAstInfo::print(raw_ostream &OS) {
free(RtCStr);
free(AstStr);

isl_ast_expr_free(RunCondition);
isl_schedule_free(Schedule);
isl_ast_node_free(RootNode);
isl_printer_free(P);
}

Expand Down
9 changes: 5 additions & 4 deletions polly/lib/CodeGen/PPCGCodeGeneration.cpp
Expand Up @@ -3506,9 +3506,11 @@ class PPCGCodeGeneration : public ScopPass {
Builder.SetInsertPoint(SplitBlock->getTerminator());

isl_ast_build *Build = isl_ast_build_alloc(S->getIslCtx().get());
isl_ast_expr *Condition = IslAst::buildRunCondition(*S, Build);
isl::ast_expr Condition =
IslAst::buildRunCondition(*S, isl::manage_copy(Build));
isl_ast_expr *SufficientCompute = createSufficientComputeCheck(*S, Build);
Condition = isl_ast_expr_and(Condition, SufficientCompute);
Condition =
isl::manage(isl_ast_expr_and(Condition.release(), SufficientCompute));
isl_ast_build_free(Build);

// preload invariant loads. Note: This should happen before the RTC
Expand All @@ -3535,7 +3537,6 @@ class PPCGCodeGeneration : public ScopPass {

DT->changeImmediateDominator(MergeBlock, ExitingBB);
DT->eraseNode(ExitingBlock);
isl_ast_expr_free(Condition);
isl_ast_node_free(Root);
} else {

Expand All @@ -3556,7 +3557,7 @@ class PPCGCodeGeneration : public ScopPass {
}

NodeBuilder.addParameters(S->getContext().release());
Value *RTC = NodeBuilder.createRTC(Condition);
Value *RTC = NodeBuilder.createRTC(Condition.release());
Builder.GetInsertBlock()->getTerminator()->setOperand(0, RTC);

Builder.SetInsertPoint(&*StartBlock->begin());
Expand Down

0 comments on commit 1464953

Please sign in to comment.