Skip to content

Commit

Permalink
For loops: Allow sharing variables with main program
Browse files Browse the repository at this point in the history
1. Determine which variables need to be shared with the loop callback
2. Pack pointers to them into a context struct
3. Pass pointer to the context struct to the callback function
4. In the callback, override the shared variables so that they read and
   write through the context pointers instead of directly from their
   original addresses

See the comment in semantic_analyser.cpp for pseudo code of this
transformation.
  • Loading branch information
ajor committed Mar 22, 2024
1 parent 96afe1e commit f00dedd
Show file tree
Hide file tree
Showing 12 changed files with 333 additions and 93 deletions.
2 changes: 2 additions & 0 deletions src/ast/ast.h
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,8 @@ class For : public Statement {
Expression *expr = nullptr;
StatementList *stmts = nullptr;

SizedType ctx_type;

private:
For(const For &other);
};
Expand Down
91 changes: 70 additions & 21 deletions src/ast/passes/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1284,10 +1284,10 @@ void CodegenLLVM::visit(Variable &var)
// Arrays and structs are not memcopied for local variables
if (needMemcpy(var.type) &&
!(var.type.IsArrayTy() || var.type.IsRecordTy())) {
expr_ = variables_[var.ident];
expr_ = variables_[var.ident].value;
} else {
auto *var_alloca = variables_[var.ident];
expr_ = b_.CreateLoad(var_alloca->getAllocatedType(), var_alloca);
auto &var_llvm = variables_[var.ident];
expr_ = b_.CreateLoad(var_llvm.type, var_llvm.value);
}
}

Expand Down Expand Up @@ -2159,20 +2159,20 @@ void CodegenLLVM::visit(AssignVarStatement &assignment)
}

AllocaInst *val = b_.CreateAllocaBPFInit(alloca_type, var.ident);
variables_[var.ident] = val;
variables_[var.ident] = VariableLLVM{ val, val->getAllocatedType() };
}

if (var.type.IsArrayTy() || var.type.IsRecordTy()) {
// For arrays and structs, only the pointer is stored
b_.CreateStore(b_.CreatePtrToInt(expr_, b_.getInt64Ty()),
variables_[var.ident]);
variables_[var.ident].value);
// Extend lifetime of RHS up to the end of probe
scoped_del.disarm();
} else if (needMemcpy(var.type)) {
b_.CREATE_MEMCPY(
variables_[var.ident], expr_, assignment.expr->type.GetSize(), 1);
variables_[var.ident].value, expr_, assignment.expr->type.GetSize(), 1);
} else {
b_.CreateStore(expr_, variables_[var.ident]);
b_.CreateStore(expr_, variables_[var.ident].value);
}
}

Expand Down Expand Up @@ -2332,8 +2332,32 @@ void CodegenLLVM::visit(For &f)
auto &map = static_cast<Map &>(*f.expr);

Value *ctx = b_.getInt64(0);
llvm::Type *ctx_t = nullptr;

const auto &ctx_fields = f.ctx_type.GetFields();
if (!ctx_fields.empty()) {
// Pack pointers to variables into context struct for use in the callback
std::vector<llvm::Type *> ctx_field_types(ctx_fields.size(),
b_.GET_PTR_TY());
ctx_t = b_.GetStructType("ctx_t", ctx_field_types);
ctx = b_.CreateAllocaBPF(ctx_t, "ctx");

for (size_t i = 0; i < ctx_fields.size(); i++) {
const auto &field = ctx_fields[i];
auto *field_expr = variables_[field.name].value;
auto *ctx_field_ptr = b_.CreateGEP(
ctx_t, ctx, { b_.getInt64(0), b_.getInt32(i) }, "ctx." + field.name);
#if LLVM_VERSION_MAJOR < 15
// An extra cast is required for older LLVM versions, pre-opaque-pointers
ctx_field_ptr = b_.CreatePointerCast(
ctx_field_ptr, field_expr->getType()->getPointerTo());
#endif
b_.CreateStore(field_expr, ctx_field_ptr);
}
}

b_.CreateForEachMapElem(
ctx_, map, createForEachMapCallback(*f.decl, *f.stmts), ctx, f.loc);
ctx_, map, createForEachMapCallback(f, ctx_t), ctx, f.loc);
}

void CodegenLLVM::visit(Predicate &pred)
Expand Down Expand Up @@ -2447,7 +2471,8 @@ void CodegenLLVM::visit(Subprog &subprog)
for (SubprogArg *arg : *subprog.args) {
auto alloca = b_.CreateAllocaBPF(b_.GetType(arg->type), arg->name());
b_.CreateStore(func->getArg(arg_index + 1), alloca);
variables_.insert({ arg->name(), alloca });
variables_[arg->name()] = VariableLLVM{ alloca,
alloca->getAllocatedType() };
++arg_index;
}

Expand Down Expand Up @@ -3844,14 +3869,14 @@ void CodegenLLVM::createIncDec(Unop &unop)
b_.CreateLifetimeEnd(newval);
} else if (unop.expr->is_variable) {
Variable &var = static_cast<Variable &>(*unop.expr);
Value *oldval = b_.CreateLoad(variables_[var.ident]->getAllocatedType(),
variables_[var.ident]);
Value *oldval = b_.CreateLoad(variables_[var.ident].type,
variables_[var.ident].value);
Value *newval;
if (is_increment)
newval = b_.CreateAdd(oldval, b_.GetIntSameSize(step, oldval));
else
newval = b_.CreateSub(oldval, b_.GetIntSameSize(step, oldval));
b_.CreateStore(newval, variables_[var.ident]);
b_.CreateStore(newval, variables_[var.ident].value);

if (unop.is_post_op)
expr_ = oldval;
Expand Down Expand Up @@ -3905,9 +3930,7 @@ Function *CodegenLLVM::createMapLenCallback()
return callback;
}

Function *CodegenLLVM::createForEachMapCallback(
const Variable &decl,
const std::vector<Statement *> &stmts)
Function *CodegenLLVM::createForEachMapCallback(const For &f, llvm::Type *ctx_t)
{
/*
* Create a callback function suitable for passing to bpf_for_each_map_elem,
Expand Down Expand Up @@ -3938,29 +3961,55 @@ Function *CodegenLLVM::createForEachMapCallback(
auto *bb = BasicBlock::Create(module_->getContext(), "", callback);
b_.SetInsertPoint(bb);

auto &key_type = decl.type.GetField(0).type;
auto &key_type = f.decl->type.GetField(0).type;
Value *key = callback->getArg(1);
if (!onStack(key_type)) {
key = b_.CreateLoad(b_.GetType(key_type), key, "key");
}

auto &val_type = decl.type.GetField(1).type;
auto &val_type = f.decl->type.GetField(1).type;
Value *val = callback->getArg(2);
if (!onStack(val_type)) {
val = b_.CreateLoad(b_.GetType(val_type), val, "val");
}

// Create decl variable for use in this iteration of the loop
variables_[decl.ident] = createTuple(
decl.type, { { key, &decl.loc }, { val, &decl.loc } }, decl.ident);
AllocaInst *tuple = createTuple(f.decl->type,
{ { key, &f.decl->loc },
{ val, &f.decl->loc } },
f.decl->ident);
variables_[f.decl->ident] = VariableLLVM{ tuple, tuple->getAllocatedType() };

// 1. Save original locations of variables which will form part of the
// callback context
// 2. Replace variable expressions with those from the context
Value *ctx = callback->getArg(3);
const auto &ctx_fields = f.ctx_type.GetFields();
std::unordered_map<std::string, Value *> orig_ctx_vars;
for (size_t i = 0; i < ctx_fields.size(); i++) {
const auto &field = ctx_fields[i];
orig_ctx_vars[field.name] = variables_[field.name].value;

for (Statement *stmt : stmts) {
auto *ctx_field_ptr = b_.CreateGEP(
ctx_t, ctx, { b_.getInt64(0), b_.getInt32(i) }, "ctx." + field.name);
variables_[field.name].value = b_.CreateLoad(b_.GET_PTR_TY(),
ctx_field_ptr,
field.name);
}

// Generate code for the loop body
for (Statement *stmt : *f.stmts) {
auto scoped_del = accept(stmt);
}
b_.CreateRet(b_.getInt64(0));

// Restore original non-context variables
for (const auto &[ident, expr] : orig_ctx_vars) {
variables_[ident].value = expr;
}

// Decl variable is not valid beyond this for loop
variables_.erase(decl.ident);
variables_.erase(f.decl->ident);

b_.restoreIP(saved_ip);
return callback;
Expand Down
9 changes: 6 additions & 3 deletions src/ast/passes/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,7 @@ class CodegenLLVM : public Visitor {
void createIncDec(Unop &unop);

Function *createMapLenCallback();
Function *createForEachMapCallback(const Variable &decl,
const std::vector<Statement *> &stmts);
Function *createForEachMapCallback(const For &f, llvm::Type *ctx_t);

// Return a lambda that has captured-by-value CodegenLLVM's async id state
// (ie `printf_id_`, `mapped_printf_id_`, etc.). Running the returned lambda
Expand Down Expand Up @@ -259,7 +258,11 @@ class CodegenLLVM : public Visitor {
int current_usdt_location_index_{ 0 };
bool inside_subprog_ = false;

std::map<std::string, AllocaInst *> variables_;
struct VariableLLVM {
llvm::Value *value;
llvm::Type *type;
};
std::map<std::string, VariableLLVM> variables_;
int printf_id_ = 0;
int mapped_printf_id_ = 0;
int time_id_ = 0;
Expand Down
17 changes: 13 additions & 4 deletions src/ast/passes/printer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <sstream>

#include "ast/ast.h"
#include "struct.h"

namespace bpftrace {
namespace ast {
Expand Down Expand Up @@ -351,19 +352,27 @@ void Printer::visit(While &while_block)
}
}

void Printer::visit(For &for_loop)
void Printer::visit(For &f)
{
std::string indent(depth_, ' ');
out_ << indent << "for" << std::endl;

++depth_;
if (f.ctx_type.IsRecordTy() && !f.ctx_type.GetFields().empty()) {
out_ << indent << " ctx\n";
for (const auto &field : f.ctx_type.GetFields()) {
out_ << indent << " " << field.name << type(field.type) << "\n";
}
}

out_ << indent << " decl\n";
print(for_loop.decl);
print(f.decl);

out_ << indent << " expr\n";
print(for_loop.expr);
print(f.expr);

out_ << indent << " stmts\n";
for (Statement *stmt : *for_loop.stmts) {
for (Statement *stmt : *f.stmts) {
print(stmt);
}
--depth_;
Expand Down
2 changes: 1 addition & 1 deletion src/ast/passes/printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class Printer : public Visitor {
void visit(If &if_block) override;
void visit(Unroll &unroll) override;
void visit(While &while_block) override;
void visit(For &for_loop) override;
void visit(For &f) override;
void visit(Config &config) override;
void visit(Jump &jump) override;
void visit(Predicate &pred) override;
Expand Down

0 comments on commit f00dedd

Please sign in to comment.