diff --git a/src/ast/ast.h b/src/ast/ast.h index 17e55e568e26..c0eccd893cb4 100644 --- a/src/ast/ast.h +++ b/src/ast/ast.h @@ -571,6 +571,8 @@ class For : public Statement { Expression *expr = nullptr; StatementList *stmts = nullptr; + SizedType ctx_type; + private: For(const For &other); }; diff --git a/src/ast/passes/codegen_llvm.cpp b/src/ast/passes/codegen_llvm.cpp index b3950ae338b5..8c205cc19bb9 100644 --- a/src/ast/passes/codegen_llvm.cpp +++ b/src/ast/passes/codegen_llvm.cpp @@ -1284,10 +1284,10 @@ void CodegenLLVM::visit(Variable &var) // Arrays and structs are not memcopied for local variables if (needMemcpy(var.type) && !(var.type.IsArrayTy() || var.type.IsRecordTy())) { - expr_ = variables_[var.ident]; + expr_ = variables_[var.ident].value; } else { - auto *var_alloca = variables_[var.ident]; - expr_ = b_.CreateLoad(var_alloca->getAllocatedType(), var_alloca); + auto &var_llvm = variables_[var.ident]; + expr_ = b_.CreateLoad(var_llvm.type, var_llvm.value); } } @@ -2159,20 +2159,20 @@ void CodegenLLVM::visit(AssignVarStatement &assignment) } AllocaInst *val = b_.CreateAllocaBPFInit(alloca_type, var.ident); - variables_[var.ident] = val; + variables_[var.ident] = VariableLLVM{ val, val->getAllocatedType() }; } if (var.type.IsArrayTy() || var.type.IsRecordTy()) { // For arrays and structs, only the pointer is stored b_.CreateStore(b_.CreatePtrToInt(expr_, b_.getInt64Ty()), - variables_[var.ident]); + variables_[var.ident].value); // Extend lifetime of RHS up to the end of probe scoped_del.disarm(); } else if (needMemcpy(var.type)) { b_.CREATE_MEMCPY( - variables_[var.ident], expr_, assignment.expr->type.GetSize(), 1); + variables_[var.ident].value, expr_, assignment.expr->type.GetSize(), 1); } else { - b_.CreateStore(expr_, variables_[var.ident]); + b_.CreateStore(expr_, variables_[var.ident].value); } } @@ -2332,8 +2332,32 @@ void CodegenLLVM::visit(For &f) auto &map = static_cast(*f.expr); Value *ctx = b_.getInt64(0); + llvm::Type *ctx_t = nullptr; + + 
const auto &ctx_fields = f.ctx_type.GetFields(); + if (!ctx_fields.empty()) { + // Pack pointers to variables into context struct for use in the callback + std::vector ctx_field_types(ctx_fields.size(), + b_.GET_PTR_TY()); + ctx_t = b_.GetStructType("ctx_t", ctx_field_types); + ctx = b_.CreateAllocaBPF(ctx_t, "ctx"); + + for (size_t i = 0; i < ctx_fields.size(); i++) { + const auto &field = ctx_fields[i]; + auto *field_expr = variables_[field.name].value; + auto *ctx_field_ptr = b_.CreateGEP( + ctx_t, ctx, { b_.getInt64(0), b_.getInt32(i) }, "ctx." + field.name); +#if LLVM_VERSION_MAJOR < 15 + // An extra cast is required for older LLVM versions, pre-opaque-pointers + ctx_field_ptr = b_.CreatePointerCast( + ctx_field_ptr, field_expr->getType()->getPointerTo()); +#endif + b_.CreateStore(field_expr, ctx_field_ptr); + } + } + b_.CreateForEachMapElem( - ctx_, map, createForEachMapCallback(*f.decl, *f.stmts), ctx, f.loc); + ctx_, map, createForEachMapCallback(f, ctx_t), ctx, f.loc); } void CodegenLLVM::visit(Predicate &pred) @@ -2447,7 +2471,8 @@ void CodegenLLVM::visit(Subprog &subprog) for (SubprogArg *arg : *subprog.args) { auto alloca = b_.CreateAllocaBPF(b_.GetType(arg->type), arg->name()); b_.CreateStore(func->getArg(arg_index + 1), alloca); - variables_.insert({ arg->name(), alloca }); + variables_[arg->name()] = VariableLLVM{ alloca, + alloca->getAllocatedType() }; ++arg_index; } @@ -3844,14 +3869,14 @@ void CodegenLLVM::createIncDec(Unop &unop) b_.CreateLifetimeEnd(newval); } else if (unop.expr->is_variable) { Variable &var = static_cast(*unop.expr); - Value *oldval = b_.CreateLoad(variables_[var.ident]->getAllocatedType(), - variables_[var.ident]); + Value *oldval = b_.CreateLoad(variables_[var.ident].type, + variables_[var.ident].value); Value *newval; if (is_increment) newval = b_.CreateAdd(oldval, b_.GetIntSameSize(step, oldval)); else newval = b_.CreateSub(oldval, b_.GetIntSameSize(step, oldval)); - b_.CreateStore(newval, variables_[var.ident]); + 
b_.CreateStore(newval, variables_[var.ident].value); if (unop.is_post_op) expr_ = oldval; @@ -3905,9 +3930,7 @@ Function *CodegenLLVM::createMapLenCallback() return callback; } -Function *CodegenLLVM::createForEachMapCallback( - const Variable &decl, - const std::vector &stmts) +Function *CodegenLLVM::createForEachMapCallback(const For &f, llvm::Type *ctx_t) { /* * Create a callback function suitable for passing to bpf_for_each_map_elem, @@ -3938,29 +3961,55 @@ Function *CodegenLLVM::createForEachMapCallback( auto *bb = BasicBlock::Create(module_->getContext(), "", callback); b_.SetInsertPoint(bb); - auto &key_type = decl.type.GetField(0).type; + auto &key_type = f.decl->type.GetField(0).type; Value *key = callback->getArg(1); if (!onStack(key_type)) { key = b_.CreateLoad(b_.GetType(key_type), key, "key"); } - auto &val_type = decl.type.GetField(1).type; + auto &val_type = f.decl->type.GetField(1).type; Value *val = callback->getArg(2); if (!onStack(val_type)) { val = b_.CreateLoad(b_.GetType(val_type), val, "val"); } // Create decl variable for use in this iteration of the loop - variables_[decl.ident] = createTuple( - decl.type, { { key, &decl.loc }, { val, &decl.loc } }, decl.ident); + AllocaInst *tuple = createTuple(f.decl->type, + { { key, &f.decl->loc }, + { val, &f.decl->loc } }, + f.decl->ident); + variables_[f.decl->ident] = VariableLLVM{ tuple, tuple->getAllocatedType() }; + + // 1. Save original locations of variables which will form part of the + // callback context + // 2. Replace variable expressions with those from the context + Value *ctx = callback->getArg(3); + const auto &ctx_fields = f.ctx_type.GetFields(); + std::unordered_map orig_ctx_vars; + for (size_t i = 0; i < ctx_fields.size(); i++) { + const auto &field = ctx_fields[i]; + orig_ctx_vars[field.name] = variables_[field.name].value; - for (Statement *stmt : stmts) { + auto *ctx_field_ptr = b_.CreateGEP( + ctx_t, ctx, { b_.getInt64(0), b_.getInt32(i) }, "ctx." 
+ field.name); + variables_[field.name].value = b_.CreateLoad(b_.GET_PTR_TY(), + ctx_field_ptr, + field.name); + } + + // Generate code for the loop body + for (Statement *stmt : *f.stmts) { auto scoped_del = accept(stmt); } b_.CreateRet(b_.getInt64(0)); + // Restore original non-context variables + for (const auto &[ident, expr] : orig_ctx_vars) { + variables_[ident].value = expr; + } + // Decl variable is not valid beyond this for loop - variables_.erase(decl.ident); + variables_.erase(f.decl->ident); b_.restoreIP(saved_ip); return callback; diff --git a/src/ast/passes/codegen_llvm.h b/src/ast/passes/codegen_llvm.h index 84a8af32ed7f..effb639d71b3 100644 --- a/src/ast/passes/codegen_llvm.h +++ b/src/ast/passes/codegen_llvm.h @@ -221,8 +221,7 @@ class CodegenLLVM : public Visitor { void createIncDec(Unop &unop); Function *createMapLenCallback(); - Function *createForEachMapCallback(const Variable &decl, - const std::vector &stmts); + Function *createForEachMapCallback(const For &f, llvm::Type *ctx_t); // Return a lambda that has captured-by-value CodegenLLVM's async id state // (ie `printf_id_`, `mapped_printf_id_`, etc.). 
Running the returned lambda @@ -259,7 +258,11 @@ class CodegenLLVM : public Visitor { int current_usdt_location_index_{ 0 }; bool inside_subprog_ = false; - std::map variables_; + struct VariableLLVM { + llvm::Value *value; + llvm::Type *type; + }; + std::map variables_; int printf_id_ = 0; int mapped_printf_id_ = 0; int time_id_ = 0; diff --git a/src/ast/passes/printer.cpp b/src/ast/passes/printer.cpp index d93dd4b5f8c5..70b125804585 100644 --- a/src/ast/passes/printer.cpp +++ b/src/ast/passes/printer.cpp @@ -6,6 +6,7 @@ #include #include "ast/ast.h" +#include "struct.h" namespace bpftrace { namespace ast { @@ -351,19 +352,27 @@ void Printer::visit(While &while_block) } } -void Printer::visit(For &for_loop) +void Printer::visit(For &f) { std::string indent(depth_, ' '); out_ << indent << "for" << std::endl; ++depth_; + if (f.ctx_type.IsRecordTy() && !f.ctx_type.GetFields().empty()) { + out_ << indent << " ctx\n"; + for (const auto &field : f.ctx_type.GetFields()) { + out_ << indent << " " << field.name << type(field.type) << "\n"; + } + } + out_ << indent << " decl\n"; - print(for_loop.decl); + print(f.decl); + out_ << indent << " expr\n"; - print(for_loop.expr); + print(f.expr); out_ << indent << " stmts\n"; - for (Statement *stmt : *for_loop.stmts) { + for (Statement *stmt : *f.stmts) { print(stmt); } --depth_; diff --git a/src/ast/passes/printer.h b/src/ast/passes/printer.h index 5ea6ed1a92aa..5c2736c5ba84 100644 --- a/src/ast/passes/printer.h +++ b/src/ast/passes/printer.h @@ -40,7 +40,7 @@ class Printer : public Visitor { void visit(If &if_block) override; void visit(Unroll &unroll) override; void visit(While &while_block) override; - void visit(For &for_loop) override; + void visit(For &f) override; void visit(Config &config) override; void visit(Jump &jump) override; void visit(Predicate &pred) override; diff --git a/src/ast/passes/semantic_analyser.cpp b/src/ast/passes/semantic_analyser.cpp index 856bc82d5618..49db564a5a22 100644 --- 
a/src/ast/passes/semantic_analyser.cpp +++ b/src/ast/passes/semantic_analyser.cpp @@ -1999,10 +1999,6 @@ void SemanticAnalyser::visit(For &f) * For-loops are implemented using the bpf_for_each_map_elem helper function, * which requires them to be rewritten into a callback style. * - * In this first implementation, we do not allow any variables to be shared - * between the main probe and the loop's body. Maps are global so are not a - * problem. - * * Pseudo code for the transformation we apply: * * Before: @@ -2025,6 +2021,62 @@ void SemanticAnalyser::visit(For &f) * $kv = ((uint64)key, (uint64)value); * [LOOP BODY] * } + * + * + * To allow variables to be shared between the loop callback and the main + * program, some extra steps are taken: + * + * 1. Determine which variables need to be shared with the loop callback + * 2. Pack pointers to them into a context struct + * 3. Pass pointer to the context struct to the callback function + * 4. In the callback, override the shared variables so that they read and + * write through the context pointers instead of directly from their + * original addresses + * + * Example transformation with context: + * + * Before: + * PROBE { + * $str = "hello"; + * $not_shared = 2; + * $len = 0; + * @map[11, 12] = "c"; + * for ($kv : @map) { + * print($str); + * $len++; + * } + * print($len); + * print($not_shared); + * } + * + * After: + * struct ctx_t { + * string *str; + * uint64 *len; + * }; + * PROBE { + * $str = "hello"; + * $not_shared = 2; + * $len = 0; + * @map[11, 12] = "c"; + * + * ctx_t ctx { .str = &$str, .len = &$len }; + * bpf_for_each_map_elem(@map, &map_for_each_cb, &ctx, 0); + * + * print($len); + * print($not_shared); + * } + * long map_for_each_cb(bpf_map *map, + * const void *key, + * void *value, + * void *ctx) { + * $kv = (((uint64, uint64))key, (string)value); + * $str = ((ctx_t*)ctx)->str; + * $len = ((ctx_t*)ctx)->len; + * + * print($str); + * $len++; + * } */ // Validate decl @@ -2042,17 +2094,6 @@ void 
SemanticAnalyser::visit(For &f) Map &map = static_cast(*f.expr); // Validate body - CollectNodes vars_referenced; - for (auto *stmt : *f.stmts) { - vars_referenced.run(*stmt); - } - for (const Variable &var : vars_referenced.nodes()) { - if (variable_val_[scope_].find(var.ident) != variable_val_[scope_].end()) { - LOG(ERROR, var.loc, err_) << "Variables defined outside of a for-loop " - "can not be accessed in the loop's scope"; - } - } - // This could be relaxed in the future: CollectNodes jumps; for (auto *stmt : *f.stmts) { @@ -2069,6 +2110,35 @@ void SemanticAnalyser::visit(For &f) if (has_error()) return; + // Collect a list of unique variables which are referenced in the loop's body + // and declared before the loop. These will be passed into the loop callback + // function as the context parameter. + CollectNodes vars_referenced; + std::unordered_set var_set; + for (auto *stmt : *f.stmts) { + const auto &live_vars = variable_val_[scope_]; + vars_referenced.run(*stmt, [&live_vars, &var_set](const auto &var) { + if (live_vars.find(var.ident) == live_vars.end()) + return false; + if (var_set.find(var.ident) != var_set.end()) + return false; + var_set.insert(var.ident); + return true; + }); + } + + // Collect a list of variables which are used in the loop without having been + // used before. This is a hack to simulate block scoping in the absence of the + // real thing (#3017). 
+ CollectNodes new_vars; + for (auto *stmt : *f.stmts) { + const auto &live_vars = variable_val_[scope_]; + new_vars.run(*stmt, [&live_vars](const auto &var) { + return live_vars.find(var.ident) == live_vars.end(); + }); + } + + // Create type for the loop's decl // Iterating over a map provides a tuple: (map_key, map_val) auto *mapkey = get_map_key_type(map); auto *mapval = get_map_type(map); @@ -2098,9 +2168,21 @@ void SemanticAnalyser::visit(For &f) variable_val_[scope_].erase(decl_name); // Variables declared in a for-loop are not valid beyond it - for (const Variable &var : vars_referenced.nodes()) { + for (const Variable &var : new_vars.nodes()) { variable_val_[scope_].erase(var.ident); } + + // Finally, create the context tuple now that all variables inside the loop + // have been visited. + std::vector ctx_types; + std::vector ctx_idents; + for (const Variable &var : vars_referenced.nodes()) { + ctx_types.push_back(CreatePointer(var.type, AddrSpace::none)); // addr space + // is bpf? + ctx_idents.push_back(var.ident); + } + f.ctx_type = CreateRecord( + "", bpftrace_.structs.AddAnonStruct(ctx_types, ctx_idents)); } void SemanticAnalyser::visit(FieldAccess &acc) diff --git a/src/struct.cpp b/src/struct.cpp index 6c2648c86a42..44149dfd14ff 100644 --- a/src/struct.cpp +++ b/src/struct.cpp @@ -56,25 +56,32 @@ bool Bitfield::operator!=(const Bitfield &other) const return !(*this == other); } -std::unique_ptr Struct::CreateTuple(std::vector fields) +// Creates a struct or tuple with the given field types. +// If field_names is empty then all fields will be created without names. 
+std::unique_ptr Struct::CreateRecord( + const std::vector &fields, + const std::vector &field_names) { + assert(field_names.empty() || field_names.size() == fields.size()); + // See llvm::StructLayout::StructLayout source - std::unique_ptr tuple(new Struct(0)); + auto record = std::make_unique(0); ssize_t offset = 0; ssize_t struct_align = 1; - for (auto &field : fields) { + for (size_t i = 0; i < fields.size(); i++) { + const auto &field = fields[i]; auto align = field.GetInTupleAlignment(); struct_align = std::max(align, struct_align); auto size = field.GetSize(); auto padding = (align - (offset % align)) % align; if (padding) - tuple->padded = true; + record->padded = true; offset += padding; - tuple->fields.push_back(Field{ - .name = "", + record->fields.push_back(Field{ + .name = field_names.empty() ? "" : std::string{ field_names[i] }, .type = field, .offset = offset, .bitfield = std::nullopt, @@ -85,10 +92,16 @@ std::unique_ptr Struct::CreateTuple(std::vector fields) auto padding = (struct_align - (offset % struct_align)) % struct_align; - tuple->size = offset + padding; - tuple->align = struct_align; + record->size = offset + padding; + record->align = struct_align; - return tuple; + return record; +} + +std::unique_ptr Struct::CreateTuple( + const std::vector &fields) +{ + return CreateRecord(fields, {}); } void Struct::Dump(std::ostream &os) @@ -193,15 +206,24 @@ bool StructManager::Has(const std::string &name) const return struct_map_.find(name) != struct_map_.end(); } -std::weak_ptr StructManager::AddTuple(std::vector fields) +std::weak_ptr StructManager::AddAnonStruct( + const std::vector &fields, + const std::vector &field_names) +{ + auto t = anonymous_types_.insert(Struct::CreateRecord(fields, field_names)); + return *t.first; +} + +std::weak_ptr StructManager::AddTuple( + const std::vector &fields) { - auto t = tuples_.insert(Struct::CreateTuple(std::move(fields))); + auto t = anonymous_types_.insert(Struct::CreateTuple(fields)); return *t.first; 
} size_t StructManager::GetTuplesCnt() const { - return tuples_.size(); + return anonymous_types_.size(); } const Field *StructManager::GetProbeArg(const ast::Probe &probe, diff --git a/src/struct.h b/src/struct.h index bcb427c9e029..b5083ee58fee 100644 --- a/src/struct.h +++ b/src/struct.h @@ -97,7 +97,11 @@ struct Struct { bool HasFields() const; void ClearFields(); - static std::unique_ptr CreateTuple(std::vector fields); + static std::unique_ptr CreateRecord( + const std::vector &fields, + const std::vector &field_names); + static std::unique_ptr CreateTuple( + const std::vector &fields); void Dump(std::ostream &os); bool operator==(const Struct &rhs) const @@ -166,8 +170,10 @@ class StructManager { bool allow_override = true); bool Has(const std::string &name) const; - // tuples set manipulation - std::weak_ptr AddTuple(std::vector fields); + std::weak_ptr AddAnonStruct( + const std::vector &fields, + const std::vector &field_names); + std::weak_ptr AddTuple(const std::vector &fields); size_t GetTuplesCnt() const; // probe args lookup @@ -176,7 +182,7 @@ class StructManager { private: std::map> struct_map_; - std::unordered_set> tuples_; + std::unordered_set> anonymous_types_; }; } // namespace bpftrace diff --git a/tests/codegen/for_map_variables.cpp b/tests/codegen/for_map_variables.cpp new file mode 100644 index 000000000000..279aee6ad985 --- /dev/null +++ b/tests/codegen/for_map_variables.cpp @@ -0,0 +1,23 @@ +#include "common.h" + +namespace bpftrace::test::codegen { + +TEST(codegen, for_map_variables) +{ + test(R"( + BEGIN + { + @map[16] = 32; + $var1 = 123; + $var2 = "abc"; + $var3 = "def"; + for ($kv : @map) { + $var1++; + print($var3); + } + @len = $var1; + })", + NAME); +} + +} // namespace bpftrace::test::codegen diff --git a/tests/codegen/llvm/for_map_variables.ll b/tests/codegen/llvm/for_map_variables.ll new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/runtime/for b/tests/runtime/for index df9524a91502..7bc51fb23334 100644 
--- a/tests/runtime/for +++ b/tests/runtime/for @@ -69,3 +69,23 @@ EXPECT @[0, 0, 1]: 1 EXPECT @[1, 0, 0]: 1 EXPECT_NONE @[0, 1, 0]: 1 TIMEOUT 5 + +NAME variable context read only +PROG BEGIN { @map[0] = 0; $var = 123; for ($kv : @map) { print($var); } exit(); } +EXPECT 123 +TIMEOUT 5 + +NAME variable context update +PROG BEGIN { @map[0] = 0; @map[1] = 1; $var = 123; for ($kv : @map) { print($var); $var *= 2; } print($var); exit(); } +EXPECT_REGEX ^123\n246\n492$ +TIMEOUT 5 + +NAME variable context string +PROG BEGIN { @map[0] = 0; @map[1] = 1; $var = "abc"; for ($kv : @map) { print($var); $var = "def"; } print($var); exit(); } +EXPECT_REGEX ^abc\ndef\ndef$ +TIMEOUT 5 + +NAME variable context multiple +PROG BEGIN { @map[0] = 0; $var1 = 123; $var2 = "abc"; for ($kv : @map) { print(($var1, $var2)); } exit(); } +EXPECT (123, abc) +TIMEOUT 5 diff --git a/tests/semantic_analyser.cpp b/tests/semantic_analyser.cpp index f2255d6de832..5fd8df040dcc 100644 --- a/tests/semantic_analyser.cpp +++ b/tests/semantic_analyser.cpp @@ -197,12 +197,21 @@ void test(std::string_view input, std::string_view expected_ast) Driver driver(*bpftrace); test(*bpftrace, true, driver, input, 0, {}, true, false); - if (!expected_ast.empty() && expected_ast[0] == '\n') + if (expected_ast[0] == '\n') expected_ast.remove_prefix(1); // Remove initial '\n' std::ostringstream out; ast::Printer printer(out); printer.print(driver.root.get()); + + if (expected_ast[0] == '*' && expected_ast[expected_ast.size() - 1] == '*') { + // Remove globs from beginning and end + expected_ast.remove_prefix(1); + expected_ast.remove_suffix(1); + EXPECT_THAT(out.str(), HasSubstr(expected_ast)); + return; + } + EXPECT_EQ(expected_ast, out.str()); } @@ -3431,10 +3440,9 @@ stdin:4:11-15: ERROR: Loop declaration shadows existing variable: $kv )"); } -TEST(semantic_analyser, for_loop_variables) +TEST(semantic_analyser, for_loop_variables_read_only) { - // Read-only - test_error(R"( + test(R"( BEGIN { $var = 0; @map[0] = 1; @@ 
-3443,54 +3451,70 @@ TEST(semantic_analyser, for_loop_variables) } print($var); })", - R"( -stdin:5:9-19: ERROR: Variables defined outside of a for-loop can not be accessed in the loop's scope - print($var); - ~~~~~~~~~~ -)"); + R"(* + for + ctx + $var :: [int64 *] + decl +*)"); +} - // Modified after loop - test_error(R"( +TEST(semantic_analyser, for_loop_variables_modified_during_loop) +{ + test(R"( BEGIN { $var = 0; @map[0] = 1; for ($kv : @map) { - print($var); + $var++; } - $var = 1; print($var); })", - R"( -stdin:5:9-19: ERROR: Variables defined outside of a for-loop can not be accessed in the loop's scope - print($var); - ~~~~~~~~~~ -)"); + R"(* + for + ctx + $var :: [int64 *] + decl +*)"); +} - // Modified during loop - test_error(R"( +TEST(semantic_analyser, for_loop_variables_created_in_loop) +{ + // $var should not appear in ctx + test(R"( BEGIN { - $var = 0; @map[0] = 1; for ($kv : @map) { - $var++; + $var = 2; + print($var); } - print($var); })", - R"( -stdin:5:9-13: ERROR: Variables defined outside of a for-loop can not be accessed in the loop's scope - $var++; - ~~~~ -)"); + R"(* + for + decl +*)"); +} - // Created in loop +TEST(semantic_analyser, for_loop_variables_multiple) +{ test(R"( BEGIN { @map[0] = 1; + $var1 = 123; + $var2 = "abc"; + $var3 = "def"; for ($kv : @map) { - $var = 2; - print($var); + $var1 = 456; + print($var3); } - })"); + })", + R"(* + for + ctx + $var1 :: [int64 *] + $var3 :: [string[4] *] + decl +*)"); } TEST(semantic_analyser, for_loop_variables_created_in_loop_used_after)