diff --git a/src/ast/codegen_llvm.cpp b/src/ast/codegen_llvm.cpp index b1da28fe39f..04f5ac7883f 100644 --- a/src/ast/codegen_llvm.cpp +++ b/src/ast/codegen_llvm.cpp @@ -1298,40 +1298,11 @@ void CodegenLLVM::visit(Probe &probe) int starting_time_id_ = time_id_; int starting_join_id_ = join_id_; - for (auto &attach_point : *probe.attach_points) { + for (auto attach_point : *probe.attach_points) { current_attach_point_ = attach_point; - std::set matches; - switch (probetype(attach_point->provider)) - { - case ProbeType::kprobe: - case ProbeType::kretprobe: - matches = bpftrace_.find_wildcard_matches(attach_point->target, - attach_point->func, - "/sys/kernel/debug/tracing/available_filter_functions"); - break; - case ProbeType::uprobe: - case ProbeType::uretprobe: - { - auto symbol_stream = std::istringstream(bpftrace_.extract_func_symbols_from_path(attach_point->target)); - matches = bpftrace_.find_wildcard_matches("", attach_point->func, symbol_stream); - break; - } - case ProbeType::tracepoint: - matches = bpftrace_.find_wildcard_matches(attach_point->target, - attach_point->func, - "/sys/kernel/debug/tracing/available_events"); - break; - case ProbeType::usdt: - { - auto usdt_symbol_stream = USDTHelper::probe_stream(bpftrace_.pid_, attach_point->target); - matches = bpftrace_.find_usdt_wildcard_matches(attach_point->ns, attach_point->func, usdt_symbol_stream); - break; - } - default: - std::cerr << "Wildcard matches aren't available on probe type '" - << attach_point->provider << "'" << std::endl; - return; - } + + auto matches = bpftrace_.find_wildcard_matches(*attach_point); + tracepoint_struct_ = ""; for (auto &match_ : matches) { printf_id_ = starting_printf_id_; diff --git a/src/ast/semantic_analyser.cpp b/src/ast/semantic_analyser.cpp index ea69a59d0cb..6ee3d0933ae 100644 --- a/src/ast/semantic_analyser.cpp +++ b/src/ast/semantic_analyser.cpp @@ -168,12 +168,14 @@ void SemanticAnalyser::visit(Builtin &builtin) * 2. sets is_tparg so that codegen does the real type setting after * expansion. */ - std::set matches; - matches = bpftrace_.find_wildcard_matches(attach_point->target, - attach_point->func, - "/sys/kernel/debug/tracing/available_events"); + auto symbol_stream = bpftrace_.get_symbols_from_file( + "/sys/kernel/debug/tracing/available_events"); + auto matches = bpftrace_.find_wildcard_matches(attach_point->target, + attach_point->func, + *symbol_stream); for (auto &match : matches) { - std::string tracepoint_struct = TracepointFormatParser::get_struct_name(attach_point->target, match); + std::string tracepoint_struct = TracepointFormatParser::get_struct_name( + attach_point->target, match); Struct &cstruct = bpftrace_.structs_[tracepoint_struct]; builtin.type = SizedType(Type::cast, cstruct.size, tracepoint_struct); builtin.type.is_pointer = true; @@ -786,9 +788,11 @@ void SemanticAnalyser::visit(FieldAccess &acc) for (AttachPoint *attach_point : *probe_->attach_points) { assert(probetype(attach_point->provider) == ProbeType::tracepoint); - std::set matches = bpftrace_.find_wildcard_matches( - attach_point->target, attach_point->func, + auto symbol_stream = bpftrace_.get_symbols_from_file( "/sys/kernel/debug/tracing/available_events"); + auto matches = bpftrace_.find_wildcard_matches(attach_point->target, + attach_point->func, + *symbol_stream); for (auto &match : matches) { std::string tracepoint_struct = TracepointFormatParser::get_struct_name(attach_point->target, diff --git a/src/bpftrace.cpp b/src/bpftrace.cpp index d10b65f765a..075d0fbff02 100644 --- a/src/bpftrace.cpp +++ b/src/bpftrace.cpp @@ -98,38 +98,15 @@ int BPFtrace::add_probe(ast::Probe &p) if (attach_point->need_expansion && (has_wildcard(attach_point->func) || underspecified_usdt_probe)) { std::set matches; - switch (probetype(attach_point->provider)) + try { - case ProbeType::kprobe: - case ProbeType::kretprobe: - matches = find_wildcard_matches(attach_point->target, - attach_point->func, - "/sys/kernel/debug/tracing/available_filter_functions"); - break; - case ProbeType::uprobe: - case ProbeType::uretprobe: - { - auto symbol_stream = std::istringstream(extract_func_symbols_from_path(attach_point->target)); - matches = find_wildcard_matches("", attach_point->func, symbol_stream); - break; - } - case ProbeType::tracepoint: - matches = find_wildcard_matches(attach_point->target, - attach_point->func, - "/sys/kernel/debug/tracing/available_events"); - break; - case ProbeType::usdt: - { - auto usdt_symbol_stream = USDTHelper::probe_stream(pid_, attach_point->target); - matches = find_usdt_wildcard_matches(attach_point->ns, attach_point->func, usdt_symbol_stream); - break; - } - default: - std::cerr << "Wildcard matches aren't available on probe type '" - << attach_point->provider << "'" << std::endl; - return 1; + matches = find_wildcard_matches(*attach_point); + } + catch (const WildcardException &e) + { + std::cerr << e.what() << std::endl; + return 1; } - attach_funcs.insert(attach_funcs.end(), matches.begin(), matches.end()); } else @@ -175,39 +152,71 @@ int BPFtrace::add_probe(ast::Probe &p) return 0; } -// FIXME should this really be a separate function? -std::set BPFtrace::find_usdt_wildcard_matches(const std::string &prefix, const std::string &func, std::istream &symbol_name_stream) +std::set BPFtrace::find_wildcard_matches( + const ast::AttachPoint &attach_point) const { - // Turn glob into a regex - std::string search_str = func; - if (prefix == "") - search_str = "*:" + func; - else - search_str = prefix + ":" + func; - auto regex_str = "(" + std::regex_replace(search_str, std::regex("\\*"), "[^\\s]*") + ")"; - regex_str = "^" + regex_str + "$"; - - std::regex func_regex(regex_str); - std::smatch match; + std::unique_ptr symbol_stream; + std::string prefix, func; - std::string line; - std::set matches; - while (std::getline(symbol_name_stream, line)) + switch (probetype(attach_point.provider)) { - if (std::regex_search(line, match, func_regex)) + case ProbeType::kprobe: + case ProbeType::kretprobe: { - assert(match.size() == 2); - // skip the ".part.N" kprobe variants, as they can't be traced: - if (std::strstr(match.str(1).c_str(), ".part.") == NULL) - { - matches.insert(match[1]); - } + symbol_stream = get_symbols_from_file( + "/sys/kernel/debug/tracing/available_filter_functions"); + prefix = ""; + func = attach_point.func; + break; + } + case ProbeType::uprobe: + case ProbeType::uretprobe: + { + symbol_stream = std::make_unique( + extract_func_symbols_from_path(attach_point.target)); + prefix = ""; + func = attach_point.func; + break; + } + case ProbeType::tracepoint: + { + symbol_stream = get_symbols_from_file( + "/sys/kernel/debug/tracing/available_events"); + prefix = attach_point.target; + func = attach_point.func; + break; + } + case ProbeType::usdt: + { + symbol_stream = get_symbols_from_usdt(pid_, attach_point.target); + prefix = ""; + if (attach_point.ns == "") + func = "*:" + attach_point.func; + else + func = attach_point.ns + ":" + attach_point.func; + break; + } + default: + { + throw WildcardException("Wildcard matches aren't available on probe type '" + + attach_point.provider + "'"); } } - return matches; + + return find_wildcard_matches(prefix, func, *symbol_stream); } -std::set BPFtrace::find_wildcard_matches(const std::string &prefix, const std::string &func, std::istream &symbol_name_stream) +/* + * Finds all matches of func in the provided input stream. + * + * If an optional prefix is provided, lines must start with it to count as a + * match, but the prefix is stripped from entries in the result set. + * Wildcard tokens ("*") are accepted in func. + */ +std::set BPFtrace::find_wildcard_matches( + const std::string &prefix, + const std::string &func, + std::istream &symbol_stream) const { if (!has_wildcard(func)) return std::set({func}); @@ -220,7 +229,7 @@ std::set BPFtrace::find_wildcard_matches(const std::string &prefix, std::string line; std::set matches; std::string full_prefix = prefix.empty() ? "" : (prefix + ":"); - while (std::getline(symbol_name_stream, line)) + while (std::getline(symbol_stream, line)) { if (!full_prefix.empty()) { if (line.find(full_prefix, 0) != 0) @@ -240,26 +249,39 @@ std::set BPFtrace::find_wildcard_matches(const std::string &prefix, return matches; } -std::set BPFtrace::find_wildcard_matches(const std::string &prefix, const std::string &func, const std::string &file_name) +std::unique_ptr BPFtrace::get_symbols_from_file(const std::string &path) const { - if (!has_wildcard(func)) - return std::set({func}); - std::ifstream file(file_name); - if (file.fail()) + auto file = std::make_unique(path); + if (file->fail()) { - throw std::runtime_error("Could not read symbols from \"" + file_name + "\", err=" + std::to_string(errno)); + throw std::runtime_error("Could not read symbols from " + path + + ": " + strerror(errno)); } - std::stringstream symbol_name_stream; - std::string line; - while (file >> line) + return file; +} + +std::unique_ptr BPFtrace::get_symbols_from_usdt( + int pid, + const std::string &target) const +{ + std::string probes; + usdt_probe_list usdt_probes; + + if (pid > 0) + usdt_probes = USDTHelper::probes_for_pid(pid); + else + usdt_probes = USDTHelper::probes_for_path(target); + + for (auto const& usdt_probe : usdt_probes) { - symbol_name_stream << line << std::endl; + std::string path = std::get(usdt_probe); + std::string provider = std::get(usdt_probe); + std::string fname = std::get(usdt_probe); + probes += provider + ":" + fname + "\n"; } - file.close(); - - return find_wildcard_matches(prefix, func, symbol_name_stream); + return std::make_unique(probes); } int BPFtrace::num_probes() const diff --git a/src/bpftrace.h b/src/bpftrace.h index 5f40bc3050e..4d1655d7548 100644 --- a/src/bpftrace.h +++ b/src/bpftrace.h @@ -48,6 +48,20 @@ inline DebugLevel operator++(DebugLevel& level, int) return level; } +class WildcardException : public std::exception +{ +public: + WildcardException(const std::string &msg) : msg_(msg) {} + + const char *what() const noexcept override + { + return msg_.c_str(); + } + +private: + std::string msg_; +}; + class BPFtrace { public: @@ -70,7 +84,7 @@ class BPFtrace std::string resolve_uid(uintptr_t addr) const; uint64_t resolve_kname(const std::string &name) const; uint64_t resolve_uname(const std::string &name, const std::string &path) const; - std::string extract_func_symbols_from_path(const std::string &path) const; + virtual std::string extract_func_symbols_from_path(const std::string &path) const; std::string resolve_probe(uint64_t probe_id) const; uint64_t resolve_cgroupid(const std::string &path) const; std::vector> get_arg_values(const std::vector &args, uint8_t* arg_data); @@ -101,11 +115,20 @@ class BPFtrace bool demangle_cpp_symbols = true; bool safe_mode = true; - static void sort_by_key(std::vector key_args, - std::vector, std::vector>> &values_by_key); - virtual std::set find_usdt_wildcard_matches(const std::string &prefix, const std::string &func, std::istream &symbol_name_stream); - virtual std::set find_wildcard_matches(const std::string &prefix, const std::string &func, std::istream &symbol_name_stream); - virtual std::set find_wildcard_matches(const std::string &prefix, const std::string &attach_point, const std::string &file_name); + static void sort_by_key( + std::vector key_args, + std::vector, + std::vector>> &values_by_key); + std::set find_wildcard_matches( + const ast::AttachPoint &attach_point) const; + std::set find_wildcard_matches( + const std::string &prefix, + const std::string &func, + std::istream &symbol_stream) const; + virtual std::unique_ptr get_symbols_from_file(const std::string &path) const; + virtual std::unique_ptr get_symbols_from_usdt( + int pid, + const std::string &target) const; protected: std::vector probes_; diff --git a/src/utils.cpp b/src/utils.cpp index f7260fc8154..dadb9e89bdd 100644 --- a/src/utils.cpp +++ b/src/utils.cpp @@ -50,7 +50,11 @@ static void usdt_probe_each(struct bcc_usdt *usdt_probe) usdt_provider_cache[usdt_probe->provider].push_back(std::make_tuple(usdt_probe->bin_path, usdt_probe->provider, usdt_probe->name)); } -usdt_probe_entry USDTHelper::find(int pid, std::string target, std::string provider, std::string name) +usdt_probe_entry USDTHelper::find( + int pid, + const std::string &target, + const std::string &provider, + const std::string &name) { if (pid > 0) @@ -68,7 +72,7 @@ usdt_probe_entry USDTHelper::find(int pid, std::string target, std::string provi } } -usdt_probe_list USDTHelper::probes_for_provider(std::string provider) +usdt_probe_list USDTHelper::probes_for_provider(const std::string &provider) { usdt_probe_list probes; @@ -93,7 +97,7 @@ usdt_probe_list USDTHelper::probes_for_pid(int pid) return probes; } -usdt_probe_list USDTHelper::probes_for_path(std::string path) +usdt_probe_list USDTHelper::probes_for_path(const std::string &path) { read_probes_for_path(path); @@ -105,27 +109,6 @@ usdt_probe_list USDTHelper::probes_for_path(std::string path) return probes; } -std::istringstream USDTHelper::probe_stream(int pid, std::string target) -{ - std::string probes; - usdt_probe_list usdt_probes; - - if (pid > 0) - usdt_probes = probes_for_pid(pid); - else - usdt_probes = probes_for_path(target); - - for (auto const& usdt_probe : usdt_probes) - { - std::string path = std::get(usdt_probe); - std::string provider = std::get(usdt_probe); - std::string fname = std::get(usdt_probe); - probes += provider + ":" + fname + "\n"; - } - - return std::istringstream(probes); -} - void USDTHelper::read_probes_for_pid(int pid) { if(provider_cache_loaded) @@ -149,7 +132,7 @@ void USDTHelper::read_probes_for_pid(int pid) } } -void USDTHelper::read_probes_for_path(std::string path) +void USDTHelper::read_probes_for_path(const std::string &path) { if(provider_cache_loaded) return; diff --git a/src/utils.h b/src/utils.h index 9be19546591..e5fa1d29c58 100644 --- a/src/utils.h +++ b/src/utils.h @@ -17,14 +17,16 @@ typedef std::vector usdt_probe_list; class USDTHelper { public: - static usdt_probe_entry find(int pid, std::string target, std::string provider, std::string name); - static usdt_probe_list probes_for_provider(std::string provider); + static usdt_probe_entry find( + int pid, + const std::string &target, + const std::string &provider, + const std::string &name); + static usdt_probe_list probes_for_provider(const std::string &provider); static usdt_probe_list probes_for_pid(int pid); - static usdt_probe_list probes_for_path(std::string path); - static std::istringstream probe_stream(int pid, std::string target); -private: + static usdt_probe_list probes_for_path(const std::string &path); static void read_probes_for_pid(int pid); - static void read_probes_for_path(std::string path); + static void read_probes_for_path(const std::string &path); }; // Hack used to suppress build warning related to #474 diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index e27878a14af..884eba027b9 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -8,6 +8,7 @@ add_executable(bpftrace_test clang_parser.cpp codegen.cpp main.cpp + mocks.cpp parser.cpp probe.cpp semantic_analyser.cpp diff --git a/tests/bpftrace.cpp b/tests/bpftrace.cpp index 1278b7f4e29..ff8df599445 100644 --- a/tests/bpftrace.cpp +++ b/tests/bpftrace.cpp @@ -1,30 +1,13 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" #include "bpftrace.h" +#include "mocks.h" namespace bpftrace { namespace test { namespace bpftrace { -class MockBPFtrace : public BPFtrace { -public: - MOCK_METHOD3(find_wildcard_matches, std::set( - const std::string &prefix, - const std::string &func, - const std::string &file_name)); - std::vector get_probes() - { - return probes_; - } - std::vector get_special_probes() - { - return special_probes_; - } -}; - -using ::testing::_; using ::testing::ContainerEq; -using ::testing::Return; using ::testing::StrictMock; void check_kprobe(Probe &p, const std::string &attach_point, const std::string &orig_name) @@ -106,9 +89,9 @@ TEST(bpftrace, add_begin_probe) ast::Probe probe(&attach_points, nullptr, nullptr); StrictMock bpftrace; - EXPECT_EQ(0, bpftrace.add_probe(probe)); - EXPECT_EQ(0U, bpftrace.get_probes().size()); - EXPECT_EQ(1U, bpftrace.get_special_probes().size()); + ASSERT_EQ(0, bpftrace.add_probe(probe)); + ASSERT_EQ(0U, bpftrace.get_probes().size()); + ASSERT_EQ(1U, bpftrace.get_special_probes().size()); check_special_probe(bpftrace.get_special_probes().at(0), "BEGIN_trigger", "BEGIN"); } @@ -120,9 +103,9 @@ TEST(bpftrace, add_end_probe) ast::Probe probe(&attach_points, nullptr, nullptr); StrictMock bpftrace; - EXPECT_EQ(0, bpftrace.add_probe(probe)); - EXPECT_EQ(0U, bpftrace.get_probes().size()); - EXPECT_EQ(1U, bpftrace.get_special_probes().size()); + ASSERT_EQ(0, bpftrace.add_probe(probe)); + ASSERT_EQ(0U, bpftrace.get_probes().size()); + ASSERT_EQ(1U, bpftrace.get_special_probes().size()); check_special_probe(bpftrace.get_special_probes().at(0), "END_trigger", "END"); } @@ -134,9 +117,9 @@ TEST(bpftrace, add_probes_single) ast::Probe probe(&attach_points, nullptr, nullptr); StrictMock bpftrace; - EXPECT_EQ(0, bpftrace.add_probe(probe)); - EXPECT_EQ(1U, bpftrace.get_probes().size()); - EXPECT_EQ(0U, bpftrace.get_special_probes().size()); + ASSERT_EQ(0, bpftrace.add_probe(probe)); + ASSERT_EQ(1U, bpftrace.get_probes().size()); + ASSERT_EQ(0U, bpftrace.get_special_probes().size()); check_kprobe(bpftrace.get_probes().at(0), "sys_read", "kprobe:sys_read"); } @@ -149,41 +132,15 @@ TEST(bpftrace, add_probes_multiple) ast::Probe probe(&attach_points, nullptr, nullptr); StrictMock bpftrace; - EXPECT_EQ(0, bpftrace.add_probe(probe)); - EXPECT_EQ(2U, bpftrace.get_probes().size()); - EXPECT_EQ(0U, bpftrace.get_special_probes().size()); + ASSERT_EQ(0, bpftrace.add_probe(probe)); + ASSERT_EQ(2U, bpftrace.get_probes().size()); + ASSERT_EQ(0U, bpftrace.get_special_probes().size()); std::string probe_orig_name = "kprobe:sys_read,kprobe:sys_write"; check_kprobe(bpftrace.get_probes().at(0), "sys_read", probe_orig_name); check_kprobe(bpftrace.get_probes().at(1), "sys_write", probe_orig_name); } -TEST(bpftrace, add_probes_character_class) -{ - ast::AttachPoint a1("kprobe", "[Ss]y[Ss]_read"); - ast::AttachPoint a2("kprobe", "sys_write"); - ast::AttachPointList attach_points = { &a1, &a2 }; - ast::Probe probe(&attach_points, nullptr, nullptr); - - StrictMock bpftrace; - std::set matches = { "SyS_read", "sys_read" }; - ON_CALL(bpftrace, find_wildcard_matches(_, _, _)) - .WillByDefault(Return(matches)); - EXPECT_CALL(bpftrace, - find_wildcard_matches("", "[Ss]y[Ss]_read", - "/sys/kernel/debug/tracing/available_filter_functions")) - .Times(1); - - EXPECT_EQ(0, bpftrace.add_probe(probe)); - EXPECT_EQ(3U, bpftrace.get_probes().size()); - EXPECT_EQ(0U, bpftrace.get_special_probes().size()); - - std::string probe_orig_name = "kprobe:[Ss]y[Ss]_read,kprobe:sys_write"; - check_kprobe(bpftrace.get_probes().at(0), "SyS_read", probe_orig_name); - check_kprobe(bpftrace.get_probes().at(1), "sys_read", probe_orig_name); - check_kprobe(bpftrace.get_probes().at(2), "sys_write", probe_orig_name); -} - TEST(bpftrace, add_probes_wildcard) { ast::AttachPoint a1("kprobe", "sys_read"); @@ -192,50 +149,44 @@ TEST(bpftrace, add_probes_wildcard) ast::AttachPointList attach_points = { &a1, &a2, &a3 }; ast::Probe probe(&attach_points, nullptr, nullptr); - StrictMock bpftrace; - std::set matches = { "my_one", "my_two" }; - ON_CALL(bpftrace, find_wildcard_matches(_, _, _)) - .WillByDefault(Return(matches)); - EXPECT_CALL(bpftrace, - find_wildcard_matches("", "my_*", + auto bpftrace = get_strict_mock_bpftrace(); + EXPECT_CALL(*bpftrace, + get_symbols_from_file( "/sys/kernel/debug/tracing/available_filter_functions")) .Times(1); - EXPECT_EQ(0, bpftrace.add_probe(probe)); - EXPECT_EQ(4U, bpftrace.get_probes().size()); - EXPECT_EQ(0U, bpftrace.get_special_probes().size()); + ASSERT_EQ(0, bpftrace->add_probe(probe)); + ASSERT_EQ(4U, bpftrace->get_probes().size()); + ASSERT_EQ(0U, bpftrace->get_special_probes().size()); std::string probe_orig_name = "kprobe:sys_read,kprobe:my_*,kprobe:sys_write"; - check_kprobe(bpftrace.get_probes().at(0), "sys_read", probe_orig_name); - check_kprobe(bpftrace.get_probes().at(1), "my_one", probe_orig_name); - check_kprobe(bpftrace.get_probes().at(2), "my_two", probe_orig_name); - check_kprobe(bpftrace.get_probes().at(3), "sys_write", probe_orig_name); + check_kprobe(bpftrace->get_probes().at(0), "sys_read", probe_orig_name); + check_kprobe(bpftrace->get_probes().at(1), "my_one", probe_orig_name); + check_kprobe(bpftrace->get_probes().at(2), "my_two", probe_orig_name); + check_kprobe(bpftrace->get_probes().at(3), "sys_write", probe_orig_name); } TEST(bpftrace, add_probes_wildcard_no_matches) { ast::AttachPoint a1("kprobe", "sys_read"); - ast::AttachPoint a2("kprobe", "my_*"); + ast::AttachPoint a2("kprobe", "not_here_*"); ast::AttachPoint a3("kprobe", "sys_write"); ast::AttachPointList attach_points = { &a1, &a2, &a3 }; ast::Probe probe(&attach_points, nullptr, nullptr); - StrictMock bpftrace; - std::set matches; - ON_CALL(bpftrace, find_wildcard_matches(_, _, _)) - .WillByDefault(Return(matches)); - EXPECT_CALL(bpftrace, - find_wildcard_matches("", "my_*", + auto bpftrace = get_strict_mock_bpftrace(); + EXPECT_CALL(*bpftrace, + get_symbols_from_file( "/sys/kernel/debug/tracing/available_filter_functions")) .Times(1); - EXPECT_EQ(0, bpftrace.add_probe(probe)); - EXPECT_EQ(2U, bpftrace.get_probes().size()); - EXPECT_EQ(0U, bpftrace.get_special_probes().size()); + ASSERT_EQ(0, bpftrace->add_probe(probe)); + ASSERT_EQ(2U, bpftrace->get_probes().size()); + ASSERT_EQ(0U, bpftrace->get_special_probes().size()); - std::string probe_orig_name = "kprobe:sys_read,kprobe:my_*,kprobe:sys_write"; - check_kprobe(bpftrace.get_probes().at(0), "sys_read", probe_orig_name); - check_kprobe(bpftrace.get_probes().at(1), "sys_write", probe_orig_name); + std::string probe_orig_name = "kprobe:sys_read,kprobe:not_here_*,kprobe:sys_write"; + check_kprobe(bpftrace->get_probes().at(0), "sys_read", probe_orig_name); + check_kprobe(bpftrace->get_probes().at(1), "sys_write", probe_orig_name); } TEST(bpftrace, add_probes_uprobe) @@ -246,64 +197,115 @@ TEST(bpftrace, add_probes_uprobe) StrictMock bpftrace; - EXPECT_EQ(0, bpftrace.add_probe(probe)); - EXPECT_EQ(1U, bpftrace.get_probes().size()); - EXPECT_EQ(0U, bpftrace.get_special_probes().size()); + ASSERT_EQ(0, bpftrace.add_probe(probe)); + ASSERT_EQ(1U, bpftrace.get_probes().size()); + ASSERT_EQ(0U, bpftrace.get_special_probes().size()); check_uprobe(bpftrace.get_probes().at(0), "/bin/sh", "foo", "uprobe:/bin/sh:foo"); } -TEST(bpftrace, add_probes_usdt) +TEST(bpftrace, add_probes_uprobe_wildcard) { - ast::AttachPoint a("usdt", "/bin/sh", "foo", "bar", false); + ast::AttachPoint a("uprobe", "/bin/sh", "*open", true); ast::AttachPointList attach_points = { &a }; ast::Probe probe(&attach_points, nullptr, nullptr); - StrictMock bpftrace; + auto bpftrace = get_strict_mock_bpftrace(); + EXPECT_CALL(*bpftrace, extract_func_symbols_from_path("/bin/sh")).Times(1); + + ASSERT_EQ(0, bpftrace->add_probe(probe)); + ASSERT_EQ(2U, bpftrace->get_probes().size()); + ASSERT_EQ(0U, bpftrace->get_special_probes().size()); - EXPECT_EQ(0, bpftrace.add_probe(probe)); - EXPECT_EQ(1U, bpftrace.get_probes().size()); - EXPECT_EQ(0U, bpftrace.get_special_probes().size()); - check_usdt(bpftrace.get_probes().at(0), "/bin/sh", "foo", "bar", "usdt:/bin/sh:foo:bar"); + std::string probe_orig_name = "uprobe:/bin/sh:*open"; + check_uprobe(bpftrace->get_probes().at(0), "/bin/sh", "first_open", probe_orig_name); + check_uprobe(bpftrace->get_probes().at(1), "/bin/sh", "second_open", probe_orig_name); } -TEST(bpftrace, add_probes_uprobe_wildcard) +TEST(bpftrace, add_probes_uprobe_wildcard_no_matches) { - ast::AttachPoint a("uprobe", "/bin/grep", "*open", true); + ast::AttachPoint a("uprobe", "/bin/sh", "foo*", true); ast::AttachPointList attach_points = { &a }; ast::Probe probe(&attach_points, nullptr, nullptr); - StrictMock bpftrace; + auto bpftrace = get_strict_mock_bpftrace(); + EXPECT_CALL(*bpftrace, extract_func_symbols_from_path("/bin/sh")).Times(1); - EXPECT_EQ(bpftrace.add_probe(probe), 0); - EXPECT_EQ(1U, bpftrace.get_probes().size()); - EXPECT_EQ(0U, bpftrace.get_special_probes().size()); + ASSERT_EQ(0, bpftrace->add_probe(probe)); + ASSERT_EQ(0U, bpftrace->get_probes().size()); + ASSERT_EQ(0U, bpftrace->get_special_probes().size()); } -TEST(bpftrace, add_probes_uprobe_wildcard_no_matches) +TEST(bpftrace, add_probes_uprobe_string_literal) { - ast::AttachPoint a("uprobe", "/bin/sh", "foo*", true); + ast::AttachPoint a("uprobe", "/bin/sh", "foo*", false); ast::AttachPointList attach_points = { &a }; ast::Probe probe(&attach_points, nullptr, nullptr); StrictMock bpftrace; - EXPECT_EQ(bpftrace.add_probe(probe), 0); - EXPECT_EQ(0U, bpftrace.get_probes().size()); - EXPECT_EQ(0U, bpftrace.get_special_probes().size()); + ASSERT_EQ(0, bpftrace.add_probe(probe)); + ASSERT_EQ(1U, bpftrace.get_probes().size()); + ASSERT_EQ(0U, bpftrace.get_special_probes().size()); + check_uprobe(bpftrace.get_probes().at(0), "/bin/sh", "foo*", "uprobe:/bin/sh:foo*"); } -TEST(bpftrace, add_probes_uprobe_string_literal) +TEST(bpftrace, add_probes_usdt) { - ast::AttachPoint a("uprobe", "/bin/sh", "foo*", false); + ast::AttachPoint a("usdt", "/bin/sh", "prov1", "mytp", false); ast::AttachPointList attach_points = { &a }; ast::Probe probe(&attach_points, nullptr, nullptr); StrictMock bpftrace; - EXPECT_EQ(0, bpftrace.add_probe(probe)); - EXPECT_EQ(1U, bpftrace.get_probes().size()); - EXPECT_EQ(0U, bpftrace.get_special_probes().size()); - check_uprobe(bpftrace.get_probes().at(0), "/bin/sh", "foo*", "uprobe:/bin/sh:foo*"); + ASSERT_EQ(0, bpftrace.add_probe(probe)); + ASSERT_EQ(1U, bpftrace.get_probes().size()); + ASSERT_EQ(0U, bpftrace.get_special_probes().size()); + check_usdt(bpftrace.get_probes().at(0), + "/bin/sh", "prov1", "mytp", + "usdt:/bin/sh:prov1:mytp"); +} + +TEST(bpftrace, add_probes_usdt_wildcard) +{ + ast::AttachPoint a("usdt", "/bin/sh", "prov*", "tp*", true); + ast::AttachPointList attach_points = { &a }; + ast::Probe probe(&attach_points, nullptr, nullptr); + + auto bpftrace = get_strict_mock_bpftrace(); + EXPECT_CALL(*bpftrace, get_symbols_from_usdt(0, "/bin/sh")).Times(1); + + ASSERT_EQ(0, bpftrace->add_probe(probe)); + ASSERT_EQ(3U, bpftrace->get_probes().size()); + ASSERT_EQ(0U, bpftrace->get_special_probes().size()); + check_usdt(bpftrace->get_probes().at(0), + "/bin/sh", "prov1", "tp1", + "usdt:/bin/sh:prov1:tp1"); + check_usdt(bpftrace->get_probes().at(1), + "/bin/sh", "prov1", "tp2", + "usdt:/bin/sh:prov1:tp2"); + check_usdt(bpftrace->get_probes().at(2), + "/bin/sh", "prov2", "tp", + "usdt:/bin/sh:prov2:tp"); +} + +TEST(bpftrace, add_probes_usdt_empty_namespace) +{ + ast::AttachPoint a("usdt", "/bin/sh", "", "tp", true); + ast::AttachPointList attach_points = { &a }; + ast::Probe probe(&attach_points, nullptr, nullptr); + + auto bpftrace = get_strict_mock_bpftrace(); + EXPECT_CALL(*bpftrace, get_symbols_from_usdt(0, "/bin/sh")).Times(1); + + ASSERT_EQ(0, bpftrace->add_probe(probe)); + ASSERT_EQ(2U, bpftrace->get_probes().size()); + ASSERT_EQ(0U, bpftrace->get_special_probes().size()); + check_usdt(bpftrace->get_probes().at(0), + "/bin/sh", "nahprov", "tp", + "usdt:/bin/sh:nahprov:tp"); + check_usdt(bpftrace->get_probes().at(1), + "/bin/sh", "prov2", "tp", + "usdt:/bin/sh:prov2:tp"); } TEST(bpftrace, add_probes_tracepoint) @@ -314,9 +316,9 @@ TEST(bpftrace, add_probes_tracepoint) StrictMock bpftrace; - EXPECT_EQ(0, bpftrace.add_probe(probe)); - EXPECT_EQ(1U, bpftrace.get_probes().size()); - EXPECT_EQ(0U, bpftrace.get_special_probes().size()); + ASSERT_EQ(0, bpftrace.add_probe(probe)); + ASSERT_EQ(1U, bpftrace.get_probes().size()); + ASSERT_EQ(0U, bpftrace.get_special_probes().size()); std::string probe_orig_name = "tracepoint:sched:sched_switch"; check_tracepoint(bpftrace.get_probes().at(0), "sched", "sched_switch", probe_orig_name); @@ -328,22 +330,19 @@ TEST(bpftrace, add_probes_tracepoint_wildcard) ast::AttachPointList attach_points = { &a }; ast::Probe probe(&attach_points, nullptr, nullptr); - StrictMock bpftrace; + auto bpftrace = get_strict_mock_bpftrace(); std::set matches = { "sched_one", "sched_two" }; - ON_CALL(bpftrace, find_wildcard_matches(_, _, _)) - .WillByDefault(Return(matches)); - EXPECT_CALL(bpftrace, - find_wildcard_matches("sched", "sched_*", - "/sys/kernel/debug/tracing/available_events")) + EXPECT_CALL(*bpftrace, + get_symbols_from_file("/sys/kernel/debug/tracing/available_events")) .Times(1); - EXPECT_EQ(bpftrace.add_probe(probe), 0); - EXPECT_EQ(2U, bpftrace.get_probes().size()); - EXPECT_EQ(0U, bpftrace.get_special_probes().size()); + ASSERT_EQ(0, bpftrace->add_probe(probe)); + ASSERT_EQ(2U, bpftrace->get_probes().size()); + ASSERT_EQ(0U, bpftrace->get_special_probes().size()); std::string probe_orig_name = "tracepoint:sched:sched_*"; - check_tracepoint(bpftrace.get_probes().at(0), "sched", "sched_one", probe_orig_name); - check_tracepoint(bpftrace.get_probes().at(1), "sched", "sched_two", probe_orig_name); + check_tracepoint(bpftrace->get_probes().at(0), "sched", "sched_one", probe_orig_name); + check_tracepoint(bpftrace->get_probes().at(1), "sched", "sched_two", probe_orig_name); } TEST(bpftrace, add_probes_tracepoint_wildcard_no_matches) @@ -352,18 +351,14 @@ TEST(bpftrace, add_probes_tracepoint_wildcard_no_matches) ast::AttachPointList attach_points = { &a }; ast::Probe probe(&attach_points, nullptr, nullptr); - StrictMock bpftrace; - std::set matches; - ON_CALL(bpftrace, find_wildcard_matches(_, _, _)) - .WillByDefault(Return(matches)); - EXPECT_CALL(bpftrace, - find_wildcard_matches("typo", "typo_*", - "/sys/kernel/debug/tracing/available_events")) + auto bpftrace = get_strict_mock_bpftrace(); + EXPECT_CALL(*bpftrace, + get_symbols_from_file("/sys/kernel/debug/tracing/available_events")) .Times(1); - EXPECT_EQ(0, bpftrace.add_probe(probe)); - EXPECT_EQ(0U, bpftrace.get_probes().size()); - EXPECT_EQ(0U, bpftrace.get_special_probes().size()); + ASSERT_EQ(0, bpftrace->add_probe(probe)); + ASSERT_EQ(0U, bpftrace->get_probes().size()); + ASSERT_EQ(0U, bpftrace->get_special_probes().size()); } TEST(bpftrace, add_probes_profile) @@ -374,9 +369,9 @@ TEST(bpftrace, add_probes_profile) StrictMock bpftrace; - EXPECT_EQ(0, bpftrace.add_probe(probe)); - EXPECT_EQ(1U, bpftrace.get_probes().size()); - EXPECT_EQ(0U, bpftrace.get_special_probes().size()); + ASSERT_EQ(0, bpftrace.add_probe(probe)); + ASSERT_EQ(1U, bpftrace.get_probes().size()); + ASSERT_EQ(0U, bpftrace.get_special_probes().size()); std::string probe_orig_name = "profile:ms:997"; check_profile(bpftrace.get_probes().at(0), "ms", 997, probe_orig_name); @@ -390,9 +385,9 @@ TEST(bpftrace, add_probes_interval) StrictMock bpftrace; - EXPECT_EQ(0, bpftrace.add_probe(probe)); - EXPECT_EQ(1U, bpftrace.get_probes().size()); - EXPECT_EQ(0U, bpftrace.get_special_probes().size()); + ASSERT_EQ(0, bpftrace.add_probe(probe)); + ASSERT_EQ(1U, bpftrace.get_probes().size()); + ASSERT_EQ(0U, bpftrace.get_special_probes().size()); std::string probe_orig_name = "interval:s:1"; check_interval(bpftrace.get_probes().at(0), "s", 1, probe_orig_name); @@ -406,9 +401,9 @@ TEST(bpftrace, add_probes_software) StrictMock bpftrace; - EXPECT_EQ(0, bpftrace.add_probe(probe)); - EXPECT_EQ(1U, bpftrace.get_probes().size()); - EXPECT_EQ(0U, bpftrace.get_special_probes().size()); + ASSERT_EQ(0, bpftrace.add_probe(probe)); + ASSERT_EQ(1U, bpftrace.get_probes().size()); + ASSERT_EQ(0U, bpftrace.get_special_probes().size()); std::string probe_orig_name = "software:faults:1000"; check_software(bpftrace.get_probes().at(0), "faults", 1000, probe_orig_name); @@ -422,9 +417,9 @@ TEST(bpftrace, add_probes_hardware) StrictMock bpftrace; - EXPECT_EQ(0, bpftrace.add_probe(probe)); - EXPECT_EQ(1U, bpftrace.get_probes().size()); - EXPECT_EQ(0U, bpftrace.get_special_probes().size()); + ASSERT_EQ(0, bpftrace.add_probe(probe)); + ASSERT_EQ(1U, bpftrace.get_probes().size()); + ASSERT_EQ(0U, bpftrace.get_special_probes().size()); std::string probe_orig_name = "hardware:cache-references:1000000"; check_hardware(bpftrace.get_probes().at(0), "cache-references", 1000000, probe_orig_name); @@ -438,7 +433,7 @@ TEST(bpftrace, invalid_provider) StrictMock bpftrace; - EXPECT_EQ(0, bpftrace.add_probe(probe)); + ASSERT_EQ(0, bpftrace.add_probe(probe)); } std::pair, std::vector> key_value_pair_int(std::vector key, int val) diff --git a/tests/mocks.cpp b/tests/mocks.cpp new file mode 100644 index 00000000000..725adc488e7 --- /dev/null +++ b/tests/mocks.cpp @@ -0,0 +1,72 @@ +#include "mocks.h" + +namespace bpftrace { +namespace test { + +using ::testing::_; +using ::testing::NiceMock; +using ::testing::Return; +using ::testing::StrictMock; + +void setup_mock_bpftrace(MockBPFtrace &bpftrace) +{ + ON_CALL(bpftrace, + get_symbols_from_file("/sys/kernel/debug/tracing/available_filter_functions")) + .WillByDefault([](const std::string &) + { + std::string ksyms = "SyS_read\n" + "sys_read\n" + "sys_write\n" + "my_one\n" + "my_two\n"; + auto myval = std::unique_ptr(new std::istringstream(ksyms)); + printf("doing ok\n"); + return myval; + }); + + ON_CALL(bpftrace, + get_symbols_from_file("/sys/kernel/debug/tracing/available_events")) + .WillByDefault([](const std::string &) + { + std::string tracepoints = "sched:sched_one\n" + "sched:sched_two\n" + "sched:foo\n" + "notsched:bar\n"; + return std::unique_ptr(new std::istringstream(tracepoints)); + }); + + std::string usyms = "first_open\n" + "second_open\n" + "open_as_well\n" + "something_else\n"; + ON_CALL(bpftrace, extract_func_symbols_from_path(_)) + .WillByDefault(Return(usyms)); + + ON_CALL(bpftrace, get_symbols_from_usdt(_, _)) + .WillByDefault([](int, const std::string &) + { + std::string usdt_syms = "prov1:tp1\n" + "prov1:tp2\n" + "prov2:tp\n" + "prov2:notatp\n" + "nahprov:tp\n"; + return std::unique_ptr(new std::istringstream(usdt_syms)); + }); +} + +std::unique_ptr get_mock_bpftrace() +{ + auto bpftrace = std::make_unique>(); + setup_mock_bpftrace(*bpftrace); + return bpftrace; +} + +std::unique_ptr get_strict_mock_bpftrace() +{ + auto bpftrace = std::make_unique>(); + setup_mock_bpftrace(*bpftrace); + return bpftrace; +} + +} // namespace test +} // namespace bpftrace diff --git a/tests/mocks.h b/tests/mocks.h new file mode 100644 index 00000000000..f8c6a34805e --- /dev/null +++ b/tests/mocks.h @@ -0,0 +1,29 @@ +#include "gmock/gmock.h" +#include "bpftrace.h" + +namespace bpftrace { +namespace test { + +class MockBPFtrace : public BPFtrace { +public: + MOCK_CONST_METHOD1(get_symbols_from_file, + std::unique_ptr(const std::string &path)); + MOCK_CONST_METHOD2(get_symbols_from_usdt, + std::unique_ptr(int pid, const std::string &target)); + MOCK_CONST_METHOD1(extract_func_symbols_from_path, + std::string(const std::string &path)); + std::vector get_probes() + { + return probes_; + } + std::vector get_special_probes() + { + return special_probes_; + } +}; + +std::unique_ptr get_mock_bpftrace(); +std::unique_ptr get_strict_mock_bpftrace(); + +} // namespace test +} // namespace bpftrace diff --git a/tests/probe.cpp b/tests/probe.cpp index 7303c69e290..c37c654d521 100644 --- a/tests/probe.cpp +++ b/tests/probe.cpp @@ -6,14 +6,12 @@ #include "codegen_llvm.h" #include "driver.h" #include "fake_map.h" +#include "mocks.h" #include "semantic_analyser.h" -namespace bpftrace -{ -namespace test -{ -namespace probe -{ +namespace bpftrace { +namespace test { +namespace probe { using bpftrace::ast::AttachPoint; using bpftrace::ast::AttachPointList; @@ -21,20 +19,20 @@ using bpftrace::ast::Probe; void gen_bytecode(const std::string &input, std::stringstream &out) { - BPFtrace bpftrace; - Driver driver(bpftrace); + auto bpftrace = get_mock_bpftrace(); + Driver driver(*bpftrace); FakeMap::next_mapfd_ = 1; ASSERT_EQ(driver.parse_str(input), 0); ClangParser clang; - clang.parse(driver.root_, bpftrace); + clang.parse(driver.root_, *bpftrace); - ast::SemanticAnalyser semantics(driver.root_, bpftrace); + ast::SemanticAnalyser semantics(driver.root_, *bpftrace); ASSERT_EQ(semantics.analyse(), 0); ASSERT_EQ(semantics.create_maps(true), 0); - ast::CodegenLLVM codegen(driver.root_, bpftrace); + ast::CodegenLLVM codegen(driver.root_, *bpftrace); codegen.compile(DebugLevel::kDebug, out); } diff --git a/tests/semantic_analyser.cpp b/tests/semantic_analyser.cpp index 2f582b4906e..7613fc04aaa 100644 --- a/tests/semantic_analyser.cpp +++ b/tests/semantic_analyser.cpp @@ -3,6 +3,7 @@ #include "bpftrace.h" #include "clang_parser.h" #include "driver.h" +#include "mocks.h" #include "semantic_analyser.h" namespace bpftrace { @@ -46,17 +47,17 @@ void test(Driver &driver, int expected_result=0, bool safe_mode = true) { - BPFtrace bpftrace; - test(bpftrace, driver, input, expected_result, safe_mode); + auto bpftrace = get_mock_bpftrace(); + test(*bpftrace, driver, input, expected_result, safe_mode); } void test(const std::string &input, int expected_result=0, bool safe_mode = true) { - BPFtrace bpftrace; - Driver driver(bpftrace); - test(bpftrace, driver, input, expected_result, safe_mode); + auto bpftrace = get_mock_bpftrace(); + Driver driver(*bpftrace); + test(*bpftrace, driver, input, expected_result, safe_mode); } TEST(semantic_analyser, builtin_variables)