Skip to content

Commit

Permalink
implement loops
Browse files Browse the repository at this point in the history
  • Loading branch information
fbs committed Mar 25, 2020
1 parent a62f94f commit cdb5a9b
Show file tree
Hide file tree
Showing 9 changed files with 151 additions and 40 deletions.
8 changes: 8 additions & 0 deletions src/ast/ast.h
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,14 @@ class While : public Statement {
location loc;

void accept(Visitor &v) override;

~While() {
delete cond;
if (stmts)
for (Statement *s : *stmts)
delete s;
delete stmts;
}
};

class AttachPoint : public Node {
Expand Down
97 changes: 79 additions & 18 deletions src/ast/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1393,39 +1393,67 @@ void CodegenLLVM::visit(AssignVarStatement &assignment)

void CodegenLLVM::visit(If &if_block)
{
bool has_jump = false;
Function *parent = b_.GetInsertBlock()->getParent();
BasicBlock *if_true = BasicBlock::Create(module_->getContext(), "if_stmt", parent);
BasicBlock *if_false = BasicBlock::Create(module_->getContext(), "else_stmt", parent);
BasicBlock *if_true = BasicBlock::Create(module_->getContext(), "if_body", parent);
BasicBlock *if_end = BasicBlock::Create(module_->getContext(), "if_end", parent);
BasicBlock *if_else = nullptr;

if_block.cond->accept(*this);
Value *cond = expr_;
Value *cond = b_.CreateICmpNE(expr_, b_.getInt64(0), "true_cond");

b_.CreateCondBr(b_.CreateICmpNE(cond, b_.getInt64(0), "true_cond"), if_true, if_false);
// 3 possible flows:
//
// if condition is true
// parent -> if_body -> if_end
//
// if condition is false, no else
// parent -> if_end
//
// if condition is false, with else
// parent -> if_else -> if_end
//
if (if_block.else_stmts)
{
// LLVM doesn't accept empty labels, only create when needed
if_else = BasicBlock::Create(module_->getContext(), "else_body", parent);
b_.CreateCondBr(cond, if_true, if_else);
}
else
{
b_.CreateCondBr(cond, if_true, if_end);
}

b_.SetInsertPoint(if_true);
for (Statement *stmt : *if_block.stmts)
{
stmt->accept(*this);
if (dynamic_cast<Jump *>(stmt)) {
has_jump = true;
// Instructions after a jump can be skipped, they're dead
break;
}
}
// If a Jump statement was found a branch has already been
// taken. Otherwise it must happen now.
if (!has_jump)
b_.CreateBr(if_end);

b_.SetInsertPoint(if_end);

if (if_block.else_stmts)
{
BasicBlock *done = BasicBlock::Create(module_->getContext(), "done", parent);
b_.CreateBr(done);

b_.SetInsertPoint(if_false);
b_.SetInsertPoint(if_else);
for (Statement *stmt : *if_block.else_stmts)
{
stmt->accept(*this);
if (dynamic_cast<Jump *>(stmt)) {
b_.SetInsertPoint(if_end);
return;
}
}
b_.CreateBr(done);

b_.SetInsertPoint(done);
}
else
{
b_.CreateBr(if_false);
b_.SetInsertPoint(if_false);
b_.CreateBr(if_end);
b_.SetInsertPoint(if_end);
}
}

Expand All @@ -1441,12 +1469,45 @@ void CodegenLLVM::visit(Unroll &unroll)

void CodegenLLVM::visit(Jump &jump)
{
return;
if (jump.ident == "continue") {
b_.CreateBr(std::get<0>(loops_.back()));
}
else if (jump.ident == "break") {
b_.CreateBr(std::get<1>(loops_.back()));
}
else {
throw new std::runtime_error("Unknown jump: " + jump.ident);
}
}

void CodegenLLVM::visit(While &while_block)
{
return;
Function *parent = b_.GetInsertBlock()->getParent();
BasicBlock *while_cond = BasicBlock::Create(module_->getContext(), "while_cond", parent);
BasicBlock *while_body = BasicBlock::Create(module_->getContext(), "while_body", parent);
BasicBlock *while_end = BasicBlock::Create(module_->getContext(), "while_end", parent);

loops_.push_back(std::make_tuple(while_cond, while_end));

b_.CreateBr(while_cond);

b_.SetInsertPoint(while_cond);
while_block.cond->accept(*this);
auto *cond = b_.CreateICmpNE(expr_, b_.getInt64(0), "true_cond");
b_.CreateCondBr(cond, while_body, while_end);

b_.SetInsertPoint(while_body);
for (Statement *stmt : *while_block.stmts)
{
stmt->accept(*this);
if (dynamic_cast<Jump *>(stmt))
goto exit;
}
b_.CreateBr(while_cond);

exit:
b_.SetInsertPoint(while_end);
loops_.pop_back();
}

void CodegenLLVM::visit(Predicate &pred)
Expand Down
2 changes: 2 additions & 0 deletions src/ast/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ class CodegenLLVM : public Visitor {
int cat_id_ = 0;
uint64_t join_id_ = 0;
int system_id_ = 0;

std::vector<std::tuple<BasicBlock *, BasicBlock *>> loops_;
};

} // namespace ast
Expand Down
43 changes: 41 additions & 2 deletions src/ast/semantic_analyser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -699,6 +699,9 @@ void SemanticAnalyser::visit(Call &call)
else if (call.func == "print") {
check_assignment(call, false, false, false);
if (check_varargs(call, 1, 3)) {
if (in_loop > 0) {
error("print() is asynchronous and cannot be used in loops", call.loc);
}
auto &arg = *call.vargs->at(0);
if (!arg.is_map)
error("print() expects a map to be provided", call.loc);
Expand Down Expand Up @@ -1244,11 +1247,13 @@ void SemanticAnalyser::visit(If &if_block)

for (Statement *stmt : *if_block.stmts) {
stmt->accept(*this);
// TODO: Warn about extra instructions after a break?
}

if (if_block.else_stmts) {
for (Statement *stmt : *if_block.else_stmts) {
stmt->accept(*this);
// TODO: Warn about extra instructions after a break?
}
}
}
Expand All @@ -1268,18 +1273,52 @@ void SemanticAnalyser::visit(Unroll &unroll)
for (Statement *stmt : *unroll.stmts)
{
stmt->accept(*this);
// TODO: Jumps should be blocked?
}
}
}

void SemanticAnalyser::visit(Jump &jump)
{
error(jump.ident + " has not yet been implemented", jump.loc);
if (jump.ident == "return")
{
error("return is a reserved keyword", jump.loc);
}
else if (jump.ident == "continue" || jump.ident == "break")
{
if (in_loop == 0)
{
error(jump.ident + " used outside of a loop", jump.loc);
}
}
else
{
error("Unknown jump: '" + jump.ident + "'", jump.loc);
}
}

void SemanticAnalyser::visit(While &while_block)
{
error("While has not yet been implemented", while_block.loc);
if (is_final_pass() && !feature_.has_loop()) {
warning("Kernel does not support bounded loops. Depending"
" on LLVMs loop unroll to generate loadable code."
, while_block.loc);
}

while_block.cond->accept(*this);

in_loop++;
for (Statement *stmt : *while_block.stmts)
{
stmt->accept(*this);
if (auto *s = dynamic_cast<Jump *>(stmt))
{
// TODO: Should we warn here?
// Having a break/continue here is pointless
// Extra instructions after those will be skipped
}
}
in_loop--;
}

void SemanticAnalyser::visit(FieldAccess &acc)
Expand Down
1 change: 1 addition & 0 deletions src/ast/semantic_analyser.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ class SemanticAnalyser : public Visitor {
std::map<std::string, MapKey> map_key_;
std::map<std::string, ExpressionList> map_args_;
std::unordered_set<StackType> needs_stackid_maps_;
uint32_t in_loop = 0;
bool needs_join_map_ = false;
bool needs_elapsed_map_ = false;
bool has_begin_probe_ = false;
Expand Down
16 changes: 8 additions & 8 deletions tests/codegen/llvm/if_else_printf.ll
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ entry:
%printf_args = alloca %printf_t, align 8
%get_pid_tgid = tail call i64 inttoptr (i64 14 to i64 ()*)()
%1 = icmp ugt i64 %get_pid_tgid, 47244640255
br i1 %1, label %if_stmt, label %else_stmt
br i1 %1, label %if_body, label %else_body

if_stmt: ; preds = %entry
if_body: ; preds = %entry
%2 = bitcast %printf_t* %printf_args to i8*
call void @llvm.lifetime.start.p0i8(i64 -1, i8* nonnull %2)
%3 = getelementptr inbounds %printf_t, %printf_t* %printf_args, i64 0, i32 0
Expand All @@ -29,9 +29,12 @@ if_stmt: ; preds = %entry
%get_cpu_id = tail call i64 inttoptr (i64 8 to i64 ()*)()
%perf_event_output = call i64 inttoptr (i64 25 to i64 (i8*, i64, i64, %printf_t*, i64)*)(i8* %0, i64 %pseudo, i64 %get_cpu_id, %printf_t* nonnull %printf_args, i64 8)
call void @llvm.lifetime.end.p0i8(i64 -1, i8* nonnull %2)
br label %done
br label %if_end

else_stmt: ; preds = %entry
if_end: ; preds = %else_body, %if_body
ret i64 0

else_body: ; preds = %entry
%4 = bitcast %printf_t.0* %printf_args1 to i8*
call void @llvm.lifetime.start.p0i8(i64 -1, i8* nonnull %4)
%5 = getelementptr inbounds %printf_t.0, %printf_t.0* %printf_args1, i64 0, i32 0
Expand All @@ -40,10 +43,7 @@ else_stmt: ; preds = %entry
%get_cpu_id3 = tail call i64 inttoptr (i64 8 to i64 ()*)()
%perf_event_output4 = call i64 inttoptr (i64 25 to i64 (i8*, i64, i64, %printf_t.0*, i64)*)(i8* %0, i64 %pseudo2, i64 %get_cpu_id3, %printf_t.0* nonnull %printf_args1, i64 8)
call void @llvm.lifetime.end.p0i8(i64 -1, i8* nonnull %4)
br label %done

done: ; preds = %else_stmt, %if_stmt
ret i64 0
br label %if_end
}

; Function Attrs: argmemonly nounwind
Expand Down
12 changes: 6 additions & 6 deletions tests/codegen/llvm/if_nested_printf.ll
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,18 @@ entry:
%printf_args = alloca %printf_t, align 8
%get_pid_tgid = tail call i64 inttoptr (i64 14 to i64 ()*)()
%1 = icmp ugt i64 %get_pid_tgid, 42953967927295
br i1 %1, label %if_stmt, label %else_stmt
br i1 %1, label %if_body, label %if_end

if_stmt: ; preds = %entry
if_body: ; preds = %entry
%get_pid_tgid3 = tail call i64 inttoptr (i64 14 to i64 ()*)()
%.lobit = and i64 %get_pid_tgid3, 4294967296
%true_cond4 = icmp eq i64 %.lobit, 0
br i1 %true_cond4, label %if_stmt1, label %else_stmt
br i1 %true_cond4, label %if_body1, label %if_end

else_stmt: ; preds = %if_stmt, %if_stmt1, %entry
if_end: ; preds = %if_body, %if_body1, %entry
ret i64 0

if_stmt1: ; preds = %if_stmt
if_body1: ; preds = %if_body
%2 = bitcast %printf_t* %printf_args to i8*
call void @llvm.lifetime.start.p0i8(i64 -1, i8* nonnull %2)
%3 = getelementptr inbounds %printf_t, %printf_t* %printf_args, i64 0, i32 0
Expand All @@ -36,7 +36,7 @@ if_stmt1: ; preds = %if_stmt
%get_cpu_id = tail call i64 inttoptr (i64 8 to i64 ()*)()
%perf_event_output = call i64 inttoptr (i64 25 to i64 (i8*, i64, i64, %printf_t*, i64)*)(i8* %0, i64 %pseudo, i64 %get_cpu_id, %printf_t* nonnull %printf_args, i64 8)
call void @llvm.lifetime.end.p0i8(i64 -1, i8* nonnull %2)
br label %else_stmt
br label %if_end
}

; Function Attrs: argmemonly nounwind
Expand Down
8 changes: 4 additions & 4 deletions tests/codegen/llvm/if_printf.ll
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ entry:
%printf_args = alloca %printf_t, align 8
%get_pid_tgid = tail call i64 inttoptr (i64 14 to i64 ()*)()
%1 = icmp ugt i64 %get_pid_tgid, 42953967927295
br i1 %1, label %if_stmt, label %else_stmt
br i1 %1, label %if_body, label %if_end

if_stmt: ; preds = %entry
if_body: ; preds = %entry
%2 = bitcast %printf_t* %printf_args to i8*
call void @llvm.lifetime.start.p0i8(i64 -1, i8* nonnull %2)
%3 = getelementptr inbounds %printf_t, %printf_t* %printf_args, i64 0, i32 0
Expand All @@ -31,9 +31,9 @@ if_stmt: ; preds = %entry
%get_cpu_id = tail call i64 inttoptr (i64 8 to i64 ()*)()
%perf_event_output = call i64 inttoptr (i64 25 to i64 (i8*, i64, i64, %printf_t*, i64)*)(i8* %0, i64 %pseudo, i64 %get_cpu_id, %printf_t* nonnull %printf_args, i64 16)
call void @llvm.lifetime.end.p0i8(i64 -1, i8* nonnull %2)
br label %else_stmt
br label %if_end

else_stmt: ; preds = %if_stmt, %entry
if_end: ; preds = %if_body, %entry
ret i64 0
}

Expand Down
4 changes: 2 additions & 2 deletions tests/semantic_analyser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1490,8 +1490,8 @@ TEST(semantic_analyser, type_ctx)

TEST(semantic_analyser, while_loop)
{
test("i:s:1 { $a = 1; while ($a < 10) { $a++ }}", 0);
test("i:s:1 { $a = 1; while (1) { if($a > 50) { break } $a++ }}", 0);
// test("i:s:1 { $a = 1; while ($a < 10) { $a++ }}", 0);
// test("i:s:1 { $a = 1; while (1) { if($a > 50) { break } $a++ }}", 0);
}

} // namespace semantic_analyser
Expand Down

0 comments on commit cdb5a9b

Please sign in to comment.