diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 33ac27ff5ca00..676df8c6516be 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -511,7 +511,7 @@ struct test_result {
 };
 
 // Printer classes for different output formats
-enum class test_status_t { NOT_SUPPORTED, OK, FAIL };
+enum class test_status_t { NOT_SUPPORTED, OK, FAIL, SKIPPED };
 
 struct test_operation_info {
     std::string op_name;
@@ -687,6 +687,8 @@ struct printer {
     virtual void print_backend_status(const backend_status_info & info) { (void) info; }
 
     virtual void print_overall_summary(const overall_summary_info & info) { (void) info; }
+
+    virtual void print_failed_tests(const std::vector<std::string> & failed_tests) { (void) failed_tests; }
 };
 
 struct console_printer : public printer {
@@ -804,6 +806,17 @@ struct console_printer : public printer {
         }
     }
 
+    void print_failed_tests(const std::vector<std::string> & failed_tests) override {
+        if (failed_tests.empty()) {
+            return;
+        }
+
+        printf("\nFailing tests:\n");
+        for (const auto & test_name : failed_tests) {
+            printf("  %s\n", test_name.c_str());
+        }
+    }
+
   private:
     void print_test_console(const test_result & result) {
         printf("  %s(%s): ", result.op_name.c_str(), result.op_params.c_str());
@@ -1056,6 +1069,8 @@ struct test_case {
 
     std::vector<ggml_tensor *> sentinels;
 
+    std::string current_op_name;
+
     void add_sentinel(ggml_context * ctx) {
         if (mode == MODE_PERF || mode == MODE_GRAD || mode == MODE_SUPPORT) {
             return;
@@ -1127,7 +1142,10 @@ struct test_case {
         }
     }
 
-    bool eval(ggml_backend_t backend1, ggml_backend_t backend2, const char * op_names_filter, printer * output_printer) {
+    test_status_t eval(ggml_backend_t backend1,
+                       ggml_backend_t backend2,
+                       const char * op_names_filter,
+                       printer * output_printer) {
         mode = MODE_TEST;
 
         ggml_init_params params = {
@@ -1144,11 +1162,12 @@ struct test_case {
         add_sentinel(ctx);
 
         ggml_tensor * out = build_graph(ctx);
-        std::string current_op_name = op_desc(out);
+        current_op_name = op_desc(out);
+
         if (!matches_filter(out, op_names_filter)) {
             //printf("  %s: skipping\n", op_desc(out).c_str());
             ggml_free(ctx);
-            return true;
+            return test_status_t::SKIPPED;
         }
 
         // check if the backends support the ops
@@ -1172,7 +1191,7 @@ struct test_case {
             }
 
             ggml_free(ctx);
-            return true;
+            return test_status_t::NOT_SUPPORTED;
         }
 
         // post-graph sentinel
@@ -1184,7 +1203,7 @@ struct test_case {
         if (buf == NULL) {
             printf("failed to allocate tensors [%s] ", ggml_backend_name(backend1));
             ggml_free(ctx);
-            return false;
+            return test_status_t::FAIL;
         }
 
         // build graph
@@ -1289,7 +1308,7 @@ struct test_case {
             output_printer->print_test_result(result);
         }
 
-        return test_passed;
+        return test_passed ? test_status_t::OK : test_status_t::FAIL;
     }
 
     bool eval_perf(ggml_backend_t backend, const char * op_names_filter, printer * output_printer) {
@@ -1306,7 +1325,7 @@ struct test_case {
         GGML_ASSERT(ctx);
         ggml_tensor * out = build_graph(ctx.get());
-        std::string current_op_name = op_desc(out);
+        current_op_name = op_desc(out);
         if (!matches_filter(out, op_names_filter)) {
             //printf("  %s: skipping\n", op_desc(out).c_str());
             return true;
         }
@@ -1435,8 +1454,9 @@ struct test_case {
         ggml_context_ptr ctx(ggml_init(params)); // smart ptr
         GGML_ASSERT(ctx);
 
-        ggml_tensor * out = build_graph(ctx.get());
-        std::string current_op_name = op_desc(out);
+        ggml_tensor * out = build_graph(ctx.get());
+        current_op_name = op_desc(out);
+
         if (!matches_filter(out, op_names_filter)) {
             return true;
         }
@@ -7356,16 +7376,26 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         }
 
         size_t n_ok = 0;
+        size_t tests_run = 0;
+        std::vector<std::string> failed_tests;
         for (auto & test : test_cases) {
-            if (test->eval(backend, backend_cpu, op_names_filter, output_printer)) {
+            test_status_t status = test->eval(backend, backend_cpu, op_names_filter, output_printer);
+            if (status == test_status_t::SKIPPED || status == test_status_t::NOT_SUPPORTED) {
+                continue;
+            }
+            tests_run++;
+            if (status == test_status_t::OK) {
                 n_ok++;
+            } else if (status == test_status_t::FAIL) {
+                failed_tests.push_back(test->current_op_name + "(" + test->vars() + ")");
             }
         }
-        output_printer->print_summary(test_summary_info(n_ok, test_cases.size(), false));
+        output_printer->print_summary(test_summary_info(n_ok, tests_run, false));
+        output_printer->print_failed_tests(failed_tests);
 
         ggml_backend_free(backend_cpu);
 
-        return n_ok == test_cases.size();
+        return n_ok == tests_run;
     }
 
     if (mode == MODE_GRAD) {