Skip to content

Commit

Permalink
implement loops
Browse files Browse the repository at this point in the history
  • Loading branch information
fbs committed Mar 26, 2020
1 parent a62f94f commit ca2342b
Show file tree
Hide file tree
Showing 12 changed files with 223 additions and 41 deletions.
8 changes: 8 additions & 0 deletions src/ast/ast.h
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,14 @@ class While : public Statement {
location loc;

void accept(Visitor &v) override;

~While() {
delete cond;
if (stmts)
for (Statement *s : *stmts)
delete s;
delete stmts;
}
};

class AttachPoint : public Node {
Expand Down
107 changes: 86 additions & 21 deletions src/ast/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1394,38 +1394,51 @@ void CodegenLLVM::visit(AssignVarStatement &assignment)
void CodegenLLVM::visit(If &if_block)
{
Function *parent = b_.GetInsertBlock()->getParent();
BasicBlock *if_true = BasicBlock::Create(module_->getContext(), "if_stmt", parent);
BasicBlock *if_false = BasicBlock::Create(module_->getContext(), "else_stmt", parent);
BasicBlock *if_true = BasicBlock::Create(module_->getContext(), "if_body", parent);
BasicBlock *if_end = BasicBlock::Create(module_->getContext(), "if_end", parent);
BasicBlock *if_else = nullptr;

if_block.cond->accept(*this);
Value *cond = expr_;
Value *cond = b_.CreateICmpNE(expr_, b_.getInt64(0), "true_cond");

b_.CreateCondBr(b_.CreateICmpNE(cond, b_.getInt64(0), "true_cond"), if_true, if_false);
// 3 possible flows:
//
// if condition is true
// parent -> if_body -> if_end
//
// if condition is false, no else
// parent -> if_end
//
// if condition is false, with else
// parent -> if_else -> if_end
//
if (if_block.else_stmts)
{
// LLVM doesn't accept empty basic block, only create when needed
if_else = BasicBlock::Create(module_->getContext(), "else_body", parent);
b_.CreateCondBr(cond, if_true, if_else);
}
else
{
b_.CreateCondBr(cond, if_true, if_end);
}

b_.SetInsertPoint(if_true);
for (Statement *stmt : *if_block.stmts)
{
stmt->accept(*this);
}

b_.CreateBr(if_end);

b_.SetInsertPoint(if_end);

if (if_block.else_stmts)
{
BasicBlock *done = BasicBlock::Create(module_->getContext(), "done", parent);
b_.CreateBr(done);

b_.SetInsertPoint(if_false);
b_.SetInsertPoint(if_else);
for (Statement *stmt : *if_block.else_stmts)
{
stmt->accept(*this);
}
b_.CreateBr(done);

b_.SetInsertPoint(done);
}
else
{
b_.CreateBr(if_false);
b_.SetInsertPoint(if_false);
b_.CreateBr(if_end);
b_.SetInsertPoint(if_end);
}
}

Expand All @@ -1441,12 +1454,64 @@ void CodegenLLVM::visit(Unroll &unroll)

void CodegenLLVM::visit(Jump &jump)
{
return;
if (jump.ident == "continue") {
b_.CreateBr(std::get<0>(loops_.back()));
}
else if (jump.ident == "break") {
b_.CreateBr(std::get<1>(loops_.back()));
}
else {
throw new std::runtime_error("Unknown jump: " + jump.ident);
}

// LLVM doesn't like having instructions after an unconditional branch (segv)
// This can be avoided by putting all instructions in a unreachable basicblock
// which will be optimize out.
//
// e.g. in the case of `while (..) { $i++; break; $i++ }` the ir will be:
//
// while_body:
// ...
// br label %while_end
//
// while_end:
// ...
//
// unreach:
// $i++
// br label %while_cond
//

Function *parent = b_.GetInsertBlock()->getParent();
BasicBlock *unreach = BasicBlock::Create(module_->getContext(), "unreach", parent);
b_.SetInsertPoint(unreach);
}

void CodegenLLVM::visit(While &while_block)
{
return;
Function *parent = b_.GetInsertBlock()->getParent();
BasicBlock *while_cond = BasicBlock::Create(module_->getContext(), "while_cond", parent);
BasicBlock *while_body = BasicBlock::Create(module_->getContext(), "while_body", parent);
BasicBlock *while_end = BasicBlock::Create(module_->getContext(), "while_end", parent);

loops_.push_back(std::make_tuple(while_cond, while_end));

b_.CreateBr(while_cond);

b_.SetInsertPoint(while_cond);
while_block.cond->accept(*this);
auto *cond = b_.CreateICmpNE(expr_, b_.getInt64(0), "true_cond");
b_.CreateCondBr(cond, while_body, while_end);

b_.SetInsertPoint(while_body);
for (Statement *stmt : *while_block.stmts)
{
stmt->accept(*this);
}
b_.CreateBr(while_cond);

b_.SetInsertPoint(while_end);
loops_.pop_back();
}

void CodegenLLVM::visit(Predicate &pred)
Expand Down
2 changes: 2 additions & 0 deletions src/ast/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ class CodegenLLVM : public Visitor {
int cat_id_ = 0;
uint64_t join_id_ = 0;
int system_id_ = 0;

std::vector<std::tuple<BasicBlock *, BasicBlock *>> loops_;
};

} // namespace ast
Expand Down
48 changes: 46 additions & 2 deletions src/ast/semantic_analyser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -699,6 +699,9 @@ void SemanticAnalyser::visit(Call &call)
else if (call.func == "print") {
check_assignment(call, false, false, false);
if (check_varargs(call, 1, 3)) {
if (in_loop > 0) {
error("print() is asynchronous and cannot be used in loops", call.loc);
}
auto &arg = *call.vargs->at(0);
if (!arg.is_map)
error("print() expects a map to be provided", call.loc);
Expand Down Expand Up @@ -1244,11 +1247,13 @@ void SemanticAnalyser::visit(If &if_block)

for (Statement *stmt : *if_block.stmts) {
stmt->accept(*this);
// TODO: Warn about extra instructions after a break?
}

if (if_block.else_stmts) {
for (Statement *stmt : *if_block.else_stmts) {
stmt->accept(*this);
// TODO: Warn about extra instructions after a break?
}
}
}
Expand All @@ -1268,18 +1273,57 @@ void SemanticAnalyser::visit(Unroll &unroll)
for (Statement *stmt : *unroll.stmts)
{
stmt->accept(*this);
// TODO: Jumps should be blocked?
}
}
}

void SemanticAnalyser::visit(Jump &jump)
{
error(jump.ident + " has not yet been implemented", jump.loc);
if (jump.ident == "return")
{
error("return is a reserved keyword", jump.loc);
}
else if (jump.ident == "continue" || jump.ident == "break")
{
if (in_loop == 0)
{
error(jump.ident + " used outside of a loop", jump.loc);
}
}
else
{
error("Unknown jump: '" + jump.ident + "'", jump.loc);
}
}

void SemanticAnalyser::visit(While &while_block)
{
error("While has not yet been implemented", while_block.loc);
if (is_final_pass() && !feature_.has_loop()) {
warning("Kernel does not support bounded loops. Depending"
" on LLVMs loop unroll to generate loadable code."
, while_block.loc);
}

while_block.cond->accept(*this);

in_loop++;
size_t size = while_block.stmts->size();
for (size_t i = 0; i < size; i++)
{
auto stmt = while_block.stmts->at(i);
stmt->accept(*this);
if (is_final_pass())
{
auto *jump = dynamic_cast<Jump *>(stmt);
if (jump && i < (size - 1))
{
warning("All code after a '" + jump->ident + "' is unreachable.",
jump->loc);
}
}
}
in_loop--;
}

void SemanticAnalyser::visit(FieldAccess &acc)
Expand Down
1 change: 1 addition & 0 deletions src/ast/semantic_analyser.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ class SemanticAnalyser : public Visitor {
std::map<std::string, MapKey> map_key_;
std::map<std::string, ExpressionList> map_args_;
std::unordered_set<StackType> needs_stackid_maps_;
uint32_t in_loop = 0;
bool needs_join_map_ = false;
bool needs_elapsed_map_ = false;
bool has_begin_probe_ = false;
Expand Down
14 changes: 14 additions & 0 deletions tests/codegen/basic_while_loop.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#include "common.h"

namespace bpftrace {
namespace test {
namespace codegen {

TEST(codegen, basic_while_loop)
{
test("i:s:1 { $a = 1; while ($a <= 150) { @=$a++; }}", NAME);
}

} // namespace codegen
} // namespace test
} // namespace bpftrace
42 changes: 42 additions & 0 deletions tests/codegen/llvm/basic_while_loop.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
; ModuleID = 'bpftrace'
source_filename = "bpftrace"
target datalayout = "e-m:e-p:64:64-i64:64-n32:64-S128"
target triple = "bpf-pc-linux"

; Function Attrs: nounwind
declare i64 @llvm.bpf.pseudo(i64, i64) #0

; Function Attrs: argmemonly nounwind
declare void @llvm.lifetime.start.p0i8(i64 immarg, i8* nocapture) #1

define i64 @"interval:s:1"(i8* nocapture readnone) local_unnamed_addr section "s_interval:s:1_1" {
entry:
%"@_val" = alloca i64, align 8
%"@_key" = alloca i64, align 8
%1 = bitcast i64* %"@_key" to i8*
%2 = bitcast i64* %"@_val" to i8*
br label %while_body

while_body: ; preds = %while_body, %entry
%"$a.01" = phi i64 [ 1, %entry ], [ %3, %while_body ]
%3 = add nuw nsw i64 %"$a.01", 1
call void @llvm.lifetime.start.p0i8(i64 -1, i8* nonnull %1)
store i64 0, i64* %"@_key", align 8
call void @llvm.lifetime.start.p0i8(i64 -1, i8* nonnull %2)
store i64 %"$a.01", i64* %"@_val", align 8
%pseudo = call i64 @llvm.bpf.pseudo(i64 1, i64 1)
%update_elem = call i64 inttoptr (i64 2 to i64 (i64, i64*, i64*, i64)*)(i64 %pseudo, i64* nonnull %"@_key", i64* nonnull %"@_val", i64 0)
call void @llvm.lifetime.end.p0i8(i64 -1, i8* nonnull %1)
call void @llvm.lifetime.end.p0i8(i64 -1, i8* nonnull %2)
%exitcond = icmp eq i64 %3, 151
br i1 %exitcond, label %while_end, label %while_body

while_end: ; preds = %while_body
ret i64 0
}

; Function Attrs: argmemonly nounwind
declare void @llvm.lifetime.end.p0i8(i64 immarg, i8* nocapture) #1

attributes #0 = { nounwind }
attributes #1 = { argmemonly nounwind }
16 changes: 8 additions & 8 deletions tests/codegen/llvm/if_else_printf.ll
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ entry:
%printf_args = alloca %printf_t, align 8
%get_pid_tgid = tail call i64 inttoptr (i64 14 to i64 ()*)()
%1 = icmp ugt i64 %get_pid_tgid, 47244640255
br i1 %1, label %if_stmt, label %else_stmt
br i1 %1, label %if_body, label %else_body

if_stmt: ; preds = %entry
if_body: ; preds = %entry
%2 = bitcast %printf_t* %printf_args to i8*
call void @llvm.lifetime.start.p0i8(i64 -1, i8* nonnull %2)
%3 = getelementptr inbounds %printf_t, %printf_t* %printf_args, i64 0, i32 0
Expand All @@ -29,9 +29,12 @@ if_stmt: ; preds = %entry
%get_cpu_id = tail call i64 inttoptr (i64 8 to i64 ()*)()
%perf_event_output = call i64 inttoptr (i64 25 to i64 (i8*, i64, i64, %printf_t*, i64)*)(i8* %0, i64 %pseudo, i64 %get_cpu_id, %printf_t* nonnull %printf_args, i64 8)
call void @llvm.lifetime.end.p0i8(i64 -1, i8* nonnull %2)
br label %done
br label %if_end

else_stmt: ; preds = %entry
if_end: ; preds = %else_body, %if_body
ret i64 0

else_body: ; preds = %entry
%4 = bitcast %printf_t.0* %printf_args1 to i8*
call void @llvm.lifetime.start.p0i8(i64 -1, i8* nonnull %4)
%5 = getelementptr inbounds %printf_t.0, %printf_t.0* %printf_args1, i64 0, i32 0
Expand All @@ -40,10 +43,7 @@ else_stmt: ; preds = %entry
%get_cpu_id3 = tail call i64 inttoptr (i64 8 to i64 ()*)()
%perf_event_output4 = call i64 inttoptr (i64 25 to i64 (i8*, i64, i64, %printf_t.0*, i64)*)(i8* %0, i64 %pseudo2, i64 %get_cpu_id3, %printf_t.0* nonnull %printf_args1, i64 8)
call void @llvm.lifetime.end.p0i8(i64 -1, i8* nonnull %4)
br label %done

done: ; preds = %else_stmt, %if_stmt
ret i64 0
br label %if_end
}

; Function Attrs: argmemonly nounwind
Expand Down
12 changes: 6 additions & 6 deletions tests/codegen/llvm/if_nested_printf.ll
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,18 @@ entry:
%printf_args = alloca %printf_t, align 8
%get_pid_tgid = tail call i64 inttoptr (i64 14 to i64 ()*)()
%1 = icmp ugt i64 %get_pid_tgid, 42953967927295
br i1 %1, label %if_stmt, label %else_stmt
br i1 %1, label %if_body, label %if_end

if_stmt: ; preds = %entry
if_body: ; preds = %entry
%get_pid_tgid3 = tail call i64 inttoptr (i64 14 to i64 ()*)()
%.lobit = and i64 %get_pid_tgid3, 4294967296
%true_cond4 = icmp eq i64 %.lobit, 0
br i1 %true_cond4, label %if_stmt1, label %else_stmt
br i1 %true_cond4, label %if_body1, label %if_end

else_stmt: ; preds = %if_stmt, %if_stmt1, %entry
if_end: ; preds = %if_body, %if_body1, %entry
ret i64 0

if_stmt1: ; preds = %if_stmt
if_body1: ; preds = %if_body
%2 = bitcast %printf_t* %printf_args to i8*
call void @llvm.lifetime.start.p0i8(i64 -1, i8* nonnull %2)
%3 = getelementptr inbounds %printf_t, %printf_t* %printf_args, i64 0, i32 0
Expand All @@ -36,7 +36,7 @@ if_stmt1: ; preds = %if_stmt
%get_cpu_id = tail call i64 inttoptr (i64 8 to i64 ()*)()
%perf_event_output = call i64 inttoptr (i64 25 to i64 (i8*, i64, i64, %printf_t*, i64)*)(i8* %0, i64 %pseudo, i64 %get_cpu_id, %printf_t* nonnull %printf_args, i64 8)
call void @llvm.lifetime.end.p0i8(i64 -1, i8* nonnull %2)
br label %else_stmt
br label %if_end
}

; Function Attrs: argmemonly nounwind
Expand Down
Loading

0 comments on commit ca2342b

Please sign in to comment.