diff --git a/src/core/RuntimeErrorCollector.cpp b/src/core/RuntimeErrorCollector.cpp index b08a463c58c..55d6f469671 100644 --- a/src/core/RuntimeErrorCollector.cpp +++ b/src/core/RuntimeErrorCollector.cpp @@ -36,11 +36,10 @@ RuntimeErrorCollector::RuntimeErrorCollector(boost::mpi::communicator comm) : m_comm(std::move(comm)) {} RuntimeErrorCollector::~RuntimeErrorCollector() { - if (!m_errors.empty()) + if (!m_errors.empty()) { + /* Print remaining error messages on destruction */ std::cerr << "There were unhandled errors.\n"; - /* Print remaining error messages on destruction */ - for (auto const &e : m_errors) { - std::cerr << e.format() << std::endl; + flush(); } } @@ -108,6 +107,13 @@ int RuntimeErrorCollector::count(RuntimeError::ErrorLevel level) { void RuntimeErrorCollector::clear() { m_errors.clear(); } +void RuntimeErrorCollector::flush() { + for (auto const &e : m_errors) { + std::cerr << e.format() << std::endl; + } + this->clear(); +} + std::vector RuntimeErrorCollector::gather() { std::vector all_errors{}; std::swap(all_errors, m_errors); diff --git a/src/core/RuntimeErrorCollector.hpp b/src/core/RuntimeErrorCollector.hpp index 388a66eee24..e99a4b9bebb 100644 --- a/src/core/RuntimeErrorCollector.hpp +++ b/src/core/RuntimeErrorCollector.hpp @@ -74,6 +74,11 @@ class RuntimeErrorCollector { */ void clear(); + /** + * @brief Flush error messages to standard error. + */ + void flush(); + std::vector gather(); void gather_local(); diff --git a/src/core/errorhandling.cpp b/src/core/errorhandling.cpp index edd346882d5..2373237c958 100644 --- a/src/core/errorhandling.cpp +++ b/src/core/errorhandling.cpp @@ -62,7 +62,7 @@ RuntimeErrorStream _runtimeMessageStream(RuntimeError::ErrorLevel level, return {*runtimeErrorCollector, level, file, line, function}; } -void mpi_gather_runtime_errors_local() { +static void mpi_gather_runtime_errors_local() { runtimeErrorCollector->gather_local(); } @@ -89,3 +89,7 @@ int check_runtime_errors(boost::mpi::communicator const &comm) { return boost::mpi::all_reduce(comm, check_runtime_errors_local(), std::plus()); } + +void flush_runtime_errors_local() { + ErrorHandling::runtimeErrorCollector->flush(); +} diff --git a/src/core/errorhandling.hpp b/src/core/errorhandling.hpp index 132783cc7e6..e73d9d1d336 100644 --- a/src/core/errorhandling.hpp +++ b/src/core/errorhandling.hpp @@ -72,6 +72,14 @@ int check_runtime_errors(boost::mpi::communicator const &comm); */ int check_runtime_errors_local(); +/** + * @brief Flush runtime errors to standard error on the local node. + * This is used to clear pending runtime error messages when the + * call site is handling an exception that needs to be re-thrown + * instead of being queued as an additional runtime error message. + */ +void flush_runtime_errors_local(); + namespace ErrorHandling { /** * @brief Initialize the error collection system. diff --git a/src/script_interface/CMakeLists.txt b/src/script_interface/CMakeLists.txt index 956f697fa62..b906335fb3d 100644 --- a/src/script_interface/CMakeLists.txt +++ b/src/script_interface/CMakeLists.txt @@ -1,7 +1,7 @@ add_library( ScriptInterface SHARED initialize.cpp ObjectHandle.cpp object_container_mpi_guard.cpp - GlobalContext.cpp ContextManager.cpp) + GlobalContext.cpp ContextManager.cpp ParallelExceptionHandler.cpp) add_subdirectory(accumulators) add_subdirectory(bond_breakage) diff --git a/src/script_interface/Context.hpp b/src/script_interface/Context.hpp index 3593cefe064..e4ac4becb9e 100644 --- a/src/script_interface/Context.hpp +++ b/src/script_interface/Context.hpp @@ -98,6 +98,7 @@ class Context : public std::enable_shared_from_this { virtual boost::string_ref name(const ObjectHandle *o) const = 0; virtual bool is_head_node() const = 0; + virtual void parallel_try_catch(std::function const &cb) const = 0; virtual ~Context() = default; }; diff --git a/src/script_interface/ContextManager.cpp b/src/script_interface/ContextManager.cpp index 691a39f9676..6eb37948526 100644 --- a/src/script_interface/ContextManager.cpp +++ b/src/script_interface/ContextManager.cpp @@ -55,8 +55,8 @@ std::string ContextManager::serialize(const ObjectHandle *o) const { ContextManager::ContextManager(Communication::MpiCallbacks &callbacks, const Utils::Factory &factory) { - auto const mpi_rank = callbacks.comm().rank(); - auto local_context = std::make_shared(factory, mpi_rank); + auto local_context = + std::make_shared(factory, callbacks.comm()); /* If there is only one node, we can treat all objects as local, and thus * never invoke any callback. */ diff --git a/src/script_interface/ContextManager.hpp b/src/script_interface/ContextManager.hpp index 40d0906ce3f..d2a07dc1f78 100644 --- a/src/script_interface/ContextManager.hpp +++ b/src/script_interface/ContextManager.hpp @@ -36,7 +36,7 @@ #include "Context.hpp" #include "Variant.hpp" -#include "MpiCallbacks.hpp" +#include "core/MpiCallbacks.hpp" #include diff --git a/src/script_interface/GlobalContext.hpp b/src/script_interface/GlobalContext.hpp index 5dd712f83e5..cf72beb5d8c 100644 --- a/src/script_interface/GlobalContext.hpp +++ b/src/script_interface/GlobalContext.hpp @@ -29,16 +29,19 @@ #include "Context.hpp" #include "LocalContext.hpp" -#include "MpiCallbacks.hpp" #include "ObjectHandle.hpp" +#include "ParallelExceptionHandler.hpp" #include "packed_variant.hpp" +#include "core/MpiCallbacks.hpp" + #include #include #include #include +#include #include #include #include @@ -61,7 +64,6 @@ namespace ScriptInterface { class GlobalContext : public Context { using ObjectId = std::size_t; -private: /* Instances on this node that are managed by the * head node. */ std::unordered_map m_local_objects; @@ -70,7 +72,8 @@ class GlobalContext : public Context { bool m_is_head_node; -private: + ParallelExceptionHandler m_parallel_exception_handler; + Communication::CallbackHandle cb_make_handle; @@ -87,6 +90,7 @@ class GlobalContext : public Context { std::shared_ptr node_local_context) : m_local_objects(), m_node_local_context(std::move(node_local_context)), m_is_head_node(callbacks.comm().rank() == 0), + m_parallel_exception_handler(callbacks.comm()), cb_make_handle(&callbacks, [this](ObjectId id, const std::string &name, const PackedMap ¶meters) { @@ -162,6 +166,9 @@ class GlobalContext : public Context { boost::string_ref name(const ObjectHandle *o) const override; bool is_head_node() const override { return m_is_head_node; } + void parallel_try_catch(std::function const &cb) const override { + m_parallel_exception_handler.parallel_try_catch(cb); + } }; } // namespace ScriptInterface diff --git a/src/script_interface/LocalContext.hpp b/src/script_interface/LocalContext.hpp index e057c26d0f2..617cd321d1b 100644 --- a/src/script_interface/LocalContext.hpp +++ b/src/script_interface/LocalContext.hpp @@ -21,11 +21,15 @@ #include "Context.hpp" #include "ObjectHandle.hpp" +#include "ParallelExceptionHandler.hpp" #include +#include + #include #include +#include #include #include @@ -39,10 +43,13 @@ namespace ScriptInterface { class LocalContext : public Context { Utils::Factory m_factory; bool m_is_head_node; + ParallelExceptionHandler m_parallel_exception_handler; public: - LocalContext(Utils::Factory factory, int mpi_rank) - : m_factory(std::move(factory)), m_is_head_node(mpi_rank == 0) {} + LocalContext(Utils::Factory factory, + boost::mpi::communicator const &comm) + : m_factory(std::move(factory)), m_is_head_node(comm.rank() == 0), + m_parallel_exception_handler(comm) {} const Utils::Factory &factory() const { return m_factory; } @@ -68,6 +75,9 @@ class LocalContext : public Context { } bool is_head_node() const override { return m_is_head_node; } + void parallel_try_catch(std::function const &cb) const override { + m_parallel_exception_handler.parallel_try_catch(cb); + } }; } // namespace ScriptInterface diff --git a/src/script_interface/ParallelExceptionHandler.cpp b/src/script_interface/ParallelExceptionHandler.cpp new file mode 100644 index 00000000000..87c290e8700 --- /dev/null +++ b/src/script_interface/ParallelExceptionHandler.cpp @@ -0,0 +1,85 @@ +/* + * Copyright (C) 2022 The ESPResSo project + * + * This file is part of ESPResSo. + * + * ESPResSo is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * ESPResSo is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +#include "ParallelExceptionHandler.hpp" + +#include "Exception.hpp" + +#include "core/MpiCallbacks.hpp" +#include "core/RuntimeError.hpp" +#include "core/communication.hpp" +#include "core/errorhandling.hpp" + +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace ScriptInterface { + +void ParallelExceptionHandler::handle_impl(std::exception const *error) const { + auto const head_node = 0; + auto const this_node = m_comm.rank(); + + enum : unsigned char { + NO_RANK_FAILED = 0u, + SOME_RANK_FAILED = 1u, + THIS_RANK_SUCCESS = 0u, + THIS_RANK_FAILED = 1u, + MAIN_RANK_FAILED = 2u, + }; + auto const this_fail_flag = + ((error) + ? ((this_node == head_node) ? MAIN_RANK_FAILED : THIS_RANK_FAILED) + : THIS_RANK_SUCCESS); + auto const fail_flag = boost::mpi::all_reduce( + m_comm, static_cast(this_fail_flag), std::bit_or<>()); + auto const main_rank_failed = fail_flag & MAIN_RANK_FAILED; + auto const some_rank_failed = fail_flag & SOME_RANK_FAILED; + + if (main_rank_failed) { + flush_runtime_errors_local(); + if (this_node == head_node) { + throw; + } + throw Exception(""); + } + + if (some_rank_failed) { + flush_runtime_errors_local(); + std::vector messages; + std::string this_message{(error) ? error->what() : ""}; + boost::mpi::gather(m_comm, this_message, messages, head_node); + if (this_node == head_node) { + std::string error_message{"an error occurred on one or more MPI ranks:"}; + for (std::size_t i = 0; i < messages.size(); ++i) { + error_message += "\n rank " + std::to_string(i) + ": " + messages[i]; + } + throw std::runtime_error(error_message.c_str()); + } + throw Exception(""); + } +} + +} // namespace ScriptInterface diff --git a/src/script_interface/ParallelExceptionHandler.hpp b/src/script_interface/ParallelExceptionHandler.hpp new file mode 100644 index 00000000000..2ed79a7c022 --- /dev/null +++ b/src/script_interface/ParallelExceptionHandler.hpp @@ -0,0 +1,89 @@ +/* + * Copyright (C) 2022 The ESPResSo project + * + * This file is part of ESPResSo. + * + * ESPResSo is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * ESPResSo is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +#ifndef ESPRESSO_SCRIPT_INTERFACE_PARALLEL_EXCEPTION_HANDLER_HPP +#define ESPRESSO_SCRIPT_INTERFACE_PARALLEL_EXCEPTION_HANDLER_HPP + +#include "core/errorhandling.hpp" + +#include + +#include +#include +#include + +namespace ScriptInterface { +/** + * Handle exceptions thrown in MPI parallel code. + * + * Instantiate this class inside the catch block and after the catch block, + * like so: + * @code{.cpp} + * boost::mpi::communicator world; + * auto handler = ScriptInterface::ParallelExceptionHandler{world}; + * std::shared_ptr obj; + * context()->parallel_try_catch([&obj]() { + * obj = std::make_shared(2., true); + * }); + * @endcode + * + * Exceptions are handled as follows: + * * the main rank throws: re-throw on main rank and throw @ref Exception + * on all other ranks + * * one or more of the worker nodes throw: collect error messages from + * worker nodes and throw them on the main rank as a @c std::runtime_error, + * throw @ref Exception on all other ranks + * + * Throwing a @ref Exception guarantees that the partially initialized script + * interface object won't be registered in the @ref GlobalContext dictionary; + * this is the only side-effect on worker nodes, since the exception itself + * is otherwise silently ignored. On the main rank, the thrown exception is + * converted to a Python exception. + */ +class ParallelExceptionHandler { +public: + ParallelExceptionHandler(boost::mpi::communicator comm) + : m_comm(std::move(comm)) {} + + /** + * @brief Handle exceptions in synchronous code. + * Error messages queued in the runtime error collector are flushed + * to standard error if the code throws on any rank. + * @pre Must be called on all ranks. + * @pre The @p callback cannot invoke remote functions from the + * @ref Communication::MpiCallbacks framework due to blocking + * communication (risk of MPI deadlock on worker nodes). + * @param[in] callback Callback to execute synchronously on all ranks. + */ + template + void parallel_try_catch(std::function const &callback) const { + try { + callback(); + } catch (T const &error) { + handle_impl(&error); + } + handle_impl(nullptr); + } + +private: + void handle_impl(std::exception const *error) const; + boost::mpi::communicator m_comm; +}; +} // namespace ScriptInterface + +#endif diff --git a/src/script_interface/cell_system/CellSystem.hpp b/src/script_interface/cell_system/CellSystem.hpp index 8237296ec61..eba6cb17e58 100644 --- a/src/script_interface/cell_system/CellSystem.hpp +++ b/src/script_interface/cell_system/CellSystem.hpp @@ -85,9 +85,9 @@ class CellSystem : public AutoParameters { {"use_verlet_lists", cell_structure.use_verlet_list}, {"node_grid", [this](Variant const &v) { - auto const error_msg = std::string("Parameter 'node_grid'"); - auto const vec = get_value>(v); - try { + context()->parallel_try_catch([&v]() { + auto const error_msg = std::string("Parameter 'node_grid'"); + auto const vec = get_value>(v); if (vec.size() != 3ul) { throw std::invalid_argument(error_msg + " must be 3 ints"); } @@ -102,12 +102,7 @@ class CellSystem : public AutoParameters { } ::node_grid = new_node_grid; on_nodegrid_change(); - } catch (...) { - if (context()->is_head_node()) { - throw; - } - throw Exception(""); - } + }); }, []() { return pack_vector(::node_grid); }}, {"skin", @@ -200,8 +195,8 @@ class CellSystem : public AutoParameters { } if (name == "get_pairs") { std::vector out; - std::vector> pair_list; - try { + context()->parallel_try_catch([¶ms, &out]() { + std::vector> pair_list; auto const distance = get_value(params, "distance"); if (boost::get(¶ms.at("types")) != nullptr) { auto const key = get_value(params, "types"); @@ -218,12 +213,7 @@ class CellSystem : public AutoParameters { [](std::pair const &pair) { return std::vector{pair.first, pair.second}; }); - } catch (std::exception const &err) { - if (context()->is_head_node()) { - throw; - } - throw Exception(""); - } + }); return out; } if (name == "non_bonded_loop_trace") { diff --git a/src/script_interface/collision_detection/CollisionDetection.hpp b/src/script_interface/collision_detection/CollisionDetection.hpp index 0322d1444c9..59843905073 100644 --- a/src/script_interface/collision_detection/CollisionDetection.hpp +++ b/src/script_interface/collision_detection/CollisionDetection.hpp @@ -102,26 +102,25 @@ class CollisionDetection : public AutoParameters { Variant do_call_method(const std::string &name, const VariantMap ¶ms) override { if (name == "instantiate") { - auto collision_params_backup = ::collision_params; - try { - // check provided parameters - check_input_parameters(params); - // set parameters - ::collision_params = Collision_parameters(); - for (auto const &kv : params) { - do_set_parameter(get_value(kv.first), kv.second); - } - // sanitize parameters and calculate derived parameters - ::collision_params.initialize(); - return none; - } catch (...) { - // restore original parameters and re-throw exception - ::collision_params = collision_params_backup; - if (context()->is_head_node()) { + context()->parallel_try_catch([this, ¶ms]() { + auto collision_params_backup = ::collision_params; + try { + // check provided parameters + check_input_parameters(params); + // set parameters + ::collision_params = Collision_parameters(); + for (auto const &kv : params) { + do_set_parameter(get_value(kv.first), kv.second); + } + // sanitize parameters and calculate derived parameters + ::collision_params.initialize(); + return none; + } catch (...) { + // restore original parameters and re-throw exception + ::collision_params = collision_params_backup; throw; } - throw Exception(""); - } + }); } if (name == "params_for_mode") { auto const name = get_value(params, "mode"); diff --git a/src/script_interface/tests/CMakeLists.txt b/src/script_interface/tests/CMakeLists.txt index 540f0bb0013..6c619ca271a 100644 --- a/src/script_interface/tests/CMakeLists.txt +++ b/src/script_interface/tests/CMakeLists.txt @@ -33,6 +33,9 @@ unit_test(NAME LocalContext_test SRC LocalContext_test.cpp DEPENDS unit_test(NAME GlobalContext_test SRC GlobalContext_test.cpp DEPENDS ScriptInterface Boost::mpi MPI::MPI_CXX NUM_PROC 2) unit_test(NAME Exception_test SRC Exception_test.cpp DEPENDS ScriptInterface) +unit_test(NAME ParallelExceptionHandler_test SRC + ParallelExceptionHandler_test.cpp DEPENDS ScriptInterface Boost::mpi + MPI::MPI_CXX NUM_PROC 2) unit_test(NAME packed_variant_test SRC packed_variant_test.cpp DEPENDS ScriptInterface) unit_test(NAME ObjectList_test SRC ObjectList_test.cpp DEPENDS ScriptInterface) diff --git a/src/script_interface/tests/GlobalContext_test.cpp b/src/script_interface/tests/GlobalContext_test.cpp index acba288c88f..e5d7759514a 100644 --- a/src/script_interface/tests/GlobalContext_test.cpp +++ b/src/script_interface/tests/GlobalContext_test.cpp @@ -22,10 +22,11 @@ #define BOOST_TEST_DYN_LINK #include -#include - #include "script_interface/GlobalContext.hpp" +#include +#include + #include #include #include @@ -55,9 +56,10 @@ struct Dummy : si::ObjectHandle { auto make_global_context(Communication::MpiCallbacks &cb) { Utils::Factory factory; factory.register_new("Dummy"); + boost::mpi::communicator comm; return std::make_shared( - cb, std::make_shared(factory, 0)); + cb, std::make_shared(factory, comm)); } BOOST_AUTO_TEST_CASE(GlobalContext_make_shared) { diff --git a/src/script_interface/tests/LocalContext_test.cpp b/src/script_interface/tests/LocalContext_test.cpp index 2a3c9007274..85bbd1708ab 100644 --- a/src/script_interface/tests/LocalContext_test.cpp +++ b/src/script_interface/tests/LocalContext_test.cpp @@ -17,12 +17,16 @@ * along with this program. If not, see . */ +#define BOOST_TEST_NO_MAIN #define BOOST_TEST_MODULE ScriptInterface::LocalContext test #define BOOST_TEST_DYN_LINK #include #include "script_interface/LocalContext.hpp" +#include +#include + #include #include #include @@ -57,7 +61,8 @@ auto factory = []() { }(); BOOST_AUTO_TEST_CASE(LocalContext_make_shared) { - auto ctx = std::make_shared(factory, 0); + boost::mpi::communicator comm; + auto ctx = std::make_shared(factory, comm); auto res = ctx->make_shared("Dummy", {}); BOOST_REQUIRE(res != nullptr); @@ -66,7 +71,8 @@ BOOST_AUTO_TEST_CASE(LocalContext_make_shared) { } BOOST_AUTO_TEST_CASE(LocalContext_serialization) { - auto ctx = std::make_shared(factory, 0); + boost::mpi::communicator comm; + auto ctx = std::make_shared(factory, comm); auto const serialized = [&]() { auto d1 = ctx->make_shared("Dummy", {}); @@ -94,3 +100,9 @@ BOOST_AUTO_TEST_CASE(LocalContext_serialization) { BOOST_CHECK_EQUAL(boost::get(d3->get_parameter("id")), 3); } } + +int main(int argc, char **argv) { + boost::mpi::environment mpi_env(argc, argv); + + return boost::unit_test::unit_test_main(init_unit_test, argc, argv); +} diff --git a/src/script_interface/tests/ObjectHandle_test.cpp b/src/script_interface/tests/ObjectHandle_test.cpp index 7ed7b1aa6b6..0e0581baabc 100644 --- a/src/script_interface/tests/ObjectHandle_test.cpp +++ b/src/script_interface/tests/ObjectHandle_test.cpp @@ -194,6 +194,7 @@ struct LogContext : public Context { } bool is_head_node() const override { return true; } + void parallel_try_catch(std::function const &) const override {} }; } // namespace Testing @@ -248,4 +249,5 @@ BOOST_AUTO_TEST_CASE(interface_) { auto o = log_ctx->make_shared({}, {}); BOOST_CHECK(log_ctx->is_head_node()); BOOST_CHECK_EQUAL(log_ctx->name(o.get()), "Dummy"); + static_cast(log_ctx->parallel_try_catch([]() {})); } diff --git a/src/script_interface/tests/ObjectList_test.cpp b/src/script_interface/tests/ObjectList_test.cpp index b755576bde3..e4fbb6f7f77 100644 --- a/src/script_interface/tests/ObjectList_test.cpp +++ b/src/script_interface/tests/ObjectList_test.cpp @@ -31,6 +31,7 @@ #include "core/communication.hpp" #include +#include #include #include @@ -93,7 +94,8 @@ BOOST_AUTO_TEST_CASE(serialization) { Utils::Factory f; f.register_new("ObjectHandle"); f.register_new("ObjectList"); - auto ctx = std::make_shared(f, 0); + boost::mpi::communicator comm; + auto ctx = std::make_shared(f, comm); // A list of some elements auto list = std::dynamic_pointer_cast( ctx->make_shared("ObjectList", {})); diff --git a/src/script_interface/tests/ObjectMap_test.cpp b/src/script_interface/tests/ObjectMap_test.cpp index 58a7363fe4f..5236fa6d358 100644 --- a/src/script_interface/tests/ObjectMap_test.cpp +++ b/src/script_interface/tests/ObjectMap_test.cpp @@ -30,6 +30,7 @@ #include "core/communication.hpp" #include +#include #include #include @@ -101,7 +102,8 @@ BOOST_AUTO_TEST_CASE(serialization) { Utils::Factory f; f.register_new("ObjectHandle"); f.register_new("ObjectMap"); - auto ctx = std::make_shared(f, 0); + boost::mpi::communicator comm; + auto ctx = std::make_shared(f, comm); // A list of some elements auto map = std::dynamic_pointer_cast( ctx->make_shared("ObjectMap", {})); diff --git a/src/script_interface/tests/ParallelExceptionHandler_test.cpp b/src/script_interface/tests/ParallelExceptionHandler_test.cpp new file mode 100644 index 00000000000..0d912efc398 --- /dev/null +++ b/src/script_interface/tests/ParallelExceptionHandler_test.cpp @@ -0,0 +1,124 @@ +/* + * Copyright (C) 2022 The ESPResSo project + * + * This file is part of ESPResSo. + * + * ESPResSo is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * ESPResSo is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +#define BOOST_TEST_NO_MAIN +#define BOOST_TEST_MODULE ScriptInterface::ParallelExceptionHandler test +#define BOOST_TEST_DYN_LINK +#include + +#include "script_interface/Exception.hpp" +#include "script_interface/ParallelExceptionHandler.hpp" + +#include "core/MpiCallbacks.hpp" +#include "core/errorhandling.hpp" + +#include +#include + +#include +#include +#include + +namespace utf = boost::unit_test; + +namespace Testing { +struct Error : public std::exception {}; +struct Warning : public std::exception {}; +} // namespace Testing + +/** Decorator to skip tests running on only 1 MPI rank. */ +struct if_parallel_test { + boost::test_tools::assertion_result operator()(utf::test_unit_id) { + boost::mpi::communicator world; + return world.size() >= 2; + } +}; + +BOOST_TEST_DECORATOR(*utf::precondition(if_parallel_test())) +BOOST_AUTO_TEST_CASE(parallel_exceptions) { + boost::mpi::communicator world; + Communication::MpiCallbacks callbacks{world}; + ErrorHandling::init_error_handling(callbacks); + auto handler = ScriptInterface::ParallelExceptionHandler{world}; + + { + // exception on main rank -> re-throw on main rank + bool rethrown = false; + bool converted = false; + try { + handler.parallel_try_catch( + []() { throw Testing::Error(); }); + } catch (Testing::Error const &err) { + rethrown = true; + } catch (ScriptInterface::Exception const &err) { + converted = true; + } + if (world.rank() == 0) { + BOOST_CHECK(rethrown); + } else { + BOOST_CHECK(converted); + } + } + { + // exception of an unknown type: not caught + bool unhandled = false; + try { + handler.parallel_try_catch( + []() { throw Testing::Warning(); }); + } catch (Testing::Warning const &err) { + unhandled = true; + } + BOOST_CHECK(unhandled); + } + { + // exception on worker rank -> communicate to main rank + bool communicated = false; + bool converted = false; + try { + handler.parallel_try_catch([&world]() { + runtimeErrorMsg() << "harmless message"; + if (world.rank() != 0) { + throw Testing::Error(); + } + }); + } catch (std::runtime_error const &err) { + communicated = true; + if (world.rank() == 0) { + BOOST_CHECK_EQUAL(err.what(), + "an error occurred on one or more MPI ranks:\n rank " + "0: \n rank 1: std::exception"); + } + } catch (ScriptInterface::Exception const &err) { + converted = true; + } + if (world.rank() == 0) { + BOOST_CHECK(communicated); + } else { + BOOST_CHECK(converted); + } + // runtime error messages are printed to stderr and cleared + BOOST_CHECK_EQUAL(check_runtime_errors_local(), 0); + } +} + +int main(int argc, char **argv) { + boost::mpi::environment mpi_env(argc, argv); + + return boost::unit_test::unit_test_main(init_unit_test, argc, argv); +}