Skip to content

Commit

Permalink
script_interface: MPI-safe exception mechanism
Browse files Browse the repository at this point in the history
Make C++ exceptions from core classes safe in a MPI-parallel context
via a parallel exception handler. When an exception occurs, re-throw
the exception on the head node and throw a ScriptInterface::Exception
on worker nodes.
  • Loading branch information
jngrad committed May 17, 2022
1 parent 5c2435b commit 0cf245a
Show file tree
Hide file tree
Showing 21 changed files with 407 additions and 56 deletions.
14 changes: 10 additions & 4 deletions src/core/RuntimeErrorCollector.cpp
Expand Up @@ -36,11 +36,10 @@ RuntimeErrorCollector::RuntimeErrorCollector(boost::mpi::communicator comm)
: m_comm(std::move(comm)) {}

RuntimeErrorCollector::~RuntimeErrorCollector() {
if (!m_errors.empty())
if (!m_errors.empty()) {
/* Print remaining error messages on destruction */
std::cerr << "There were unhandled errors.\n";
/* Print remaining error messages on destruction */
for (auto const &e : m_errors) {
std::cerr << e.format() << std::endl;
flush();
}
}

Expand Down Expand Up @@ -108,6 +107,13 @@ int RuntimeErrorCollector::count(RuntimeError::ErrorLevel level) {

void RuntimeErrorCollector::clear() { m_errors.clear(); }

void RuntimeErrorCollector::flush() {
for (auto const &e : m_errors) {
std::cerr << e.format() << std::endl;
}
this->clear();
}

std::vector<RuntimeError> RuntimeErrorCollector::gather() {
std::vector<RuntimeError> all_errors{};
std::swap(all_errors, m_errors);
Expand Down
5 changes: 5 additions & 0 deletions src/core/RuntimeErrorCollector.hpp
Expand Up @@ -74,6 +74,11 @@ class RuntimeErrorCollector {
*/
void clear();

/**
* @brief Flush error messages to standard error.
*/
void flush();

std::vector<RuntimeError> gather();
void gather_local();

Expand Down
6 changes: 5 additions & 1 deletion src/core/errorhandling.cpp
Expand Up @@ -62,7 +62,7 @@ RuntimeErrorStream _runtimeMessageStream(RuntimeError::ErrorLevel level,
return {*runtimeErrorCollector, level, file, line, function};
}

void mpi_gather_runtime_errors_local() {
static void mpi_gather_runtime_errors_local() {
runtimeErrorCollector->gather_local();
}

Expand All @@ -89,3 +89,7 @@ int check_runtime_errors(boost::mpi::communicator const &comm) {
return boost::mpi::all_reduce(comm, check_runtime_errors_local(),
std::plus<int>());
}

void flush_runtime_errors_local() {
ErrorHandling::runtimeErrorCollector->flush();
}
8 changes: 8 additions & 0 deletions src/core/errorhandling.hpp
Expand Up @@ -72,6 +72,14 @@ int check_runtime_errors(boost::mpi::communicator const &comm);
*/
int check_runtime_errors_local();

/**
* @brief Flush runtime errors to standard error on the local node.
* This is used to clear pending runtime error messages when the
* call site is handling an exception that needs to be re-thrown
* instead of being queued as an additional runtime error message.
*/
void flush_runtime_errors_local();

namespace ErrorHandling {
/**
* @brief Initialize the error collection system.
Expand Down
2 changes: 1 addition & 1 deletion src/script_interface/CMakeLists.txt
@@ -1,7 +1,7 @@
add_library(
ScriptInterface SHARED
initialize.cpp ObjectHandle.cpp object_container_mpi_guard.cpp
GlobalContext.cpp ContextManager.cpp)
GlobalContext.cpp ContextManager.cpp ParallelExceptionHandler.cpp)

add_subdirectory(accumulators)
add_subdirectory(bond_breakage)
Expand Down
1 change: 1 addition & 0 deletions src/script_interface/Context.hpp
Expand Up @@ -98,6 +98,7 @@ class Context : public std::enable_shared_from_this<Context> {
virtual boost::string_ref name(const ObjectHandle *o) const = 0;

virtual bool is_head_node() const = 0;
virtual void parallel_try_catch(std::function<void()> const &cb) const = 0;

virtual ~Context() = default;
};
Expand Down
4 changes: 2 additions & 2 deletions src/script_interface/ContextManager.cpp
Expand Up @@ -55,8 +55,8 @@ std::string ContextManager::serialize(const ObjectHandle *o) const {

ContextManager::ContextManager(Communication::MpiCallbacks &callbacks,
const Utils::Factory<ObjectHandle> &factory) {
auto const mpi_rank = callbacks.comm().rank();
auto local_context = std::make_shared<LocalContext>(factory, mpi_rank);
auto local_context =
std::make_shared<LocalContext>(factory, callbacks.comm());

/* If there is only one node, we can treat all objects as local, and thus
* never invoke any callback. */
Expand Down
2 changes: 1 addition & 1 deletion src/script_interface/ContextManager.hpp
Expand Up @@ -36,7 +36,7 @@
#include "Context.hpp"
#include "Variant.hpp"

#include "MpiCallbacks.hpp"
#include "core/MpiCallbacks.hpp"

#include <utils/Factory.hpp>

Expand Down
13 changes: 10 additions & 3 deletions src/script_interface/GlobalContext.hpp
Expand Up @@ -29,16 +29,19 @@

#include "Context.hpp"
#include "LocalContext.hpp"
#include "MpiCallbacks.hpp"
#include "ObjectHandle.hpp"
#include "ParallelExceptionHandler.hpp"
#include "packed_variant.hpp"

#include "core/MpiCallbacks.hpp"

#include <utils/Factory.hpp>

#include <boost/serialization/utility.hpp>

#include <cstddef>
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
Expand All @@ -61,7 +64,6 @@ namespace ScriptInterface {
class GlobalContext : public Context {
using ObjectId = std::size_t;

private:
/* Instances on this node that are managed by the
* head node. */
std::unordered_map<ObjectId, ObjectRef> m_local_objects;
Expand All @@ -70,7 +72,8 @@ class GlobalContext : public Context {

bool m_is_head_node;

private:
ParallelExceptionHandler m_parallel_exception_handler;

Communication::CallbackHandle<ObjectId, const std::string &,
const PackedMap &>
cb_make_handle;
Expand All @@ -87,6 +90,7 @@ class GlobalContext : public Context {
std::shared_ptr<LocalContext> node_local_context)
: m_local_objects(), m_node_local_context(std::move(node_local_context)),
m_is_head_node(callbacks.comm().rank() == 0),
m_parallel_exception_handler(callbacks.comm()),
cb_make_handle(&callbacks,
[this](ObjectId id, const std::string &name,
const PackedMap &parameters) {
Expand Down Expand Up @@ -162,6 +166,9 @@ class GlobalContext : public Context {
boost::string_ref name(const ObjectHandle *o) const override;

bool is_head_node() const override { return m_is_head_node; }
void parallel_try_catch(std::function<void()> const &cb) const override {
m_parallel_exception_handler.parallel_try_catch<std::exception>(cb);
}
};
} // namespace ScriptInterface

Expand Down
14 changes: 12 additions & 2 deletions src/script_interface/LocalContext.hpp
Expand Up @@ -21,11 +21,15 @@

#include "Context.hpp"
#include "ObjectHandle.hpp"
#include "ParallelExceptionHandler.hpp"

#include <utils/Factory.hpp>

#include <boost/mpi/communicator.hpp>

#include <cassert>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>

Expand All @@ -39,10 +43,13 @@ namespace ScriptInterface {
class LocalContext : public Context {
Utils::Factory<ObjectHandle> m_factory;
bool m_is_head_node;
ParallelExceptionHandler m_parallel_exception_handler;

public:
LocalContext(Utils::Factory<ObjectHandle> factory, int mpi_rank)
: m_factory(std::move(factory)), m_is_head_node(mpi_rank == 0) {}
LocalContext(Utils::Factory<ObjectHandle> factory,
boost::mpi::communicator const &comm)
: m_factory(std::move(factory)), m_is_head_node(comm.rank() == 0),
m_parallel_exception_handler(comm) {}

const Utils::Factory<ObjectHandle> &factory() const { return m_factory; }

Expand All @@ -68,6 +75,9 @@ class LocalContext : public Context {
}

bool is_head_node() const override { return m_is_head_node; }
void parallel_try_catch(std::function<void()> const &cb) const override {
m_parallel_exception_handler.parallel_try_catch<std::exception>(cb);
}
};
} // namespace ScriptInterface

Expand Down
85 changes: 85 additions & 0 deletions src/script_interface/ParallelExceptionHandler.cpp
@@ -0,0 +1,85 @@
/*
* Copyright (C) 2022 The ESPResSo project
*
* This file is part of ESPResSo.
*
* ESPResSo is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* ESPResSo is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

#include "ParallelExceptionHandler.hpp"

#include "Exception.hpp"

#include "core/MpiCallbacks.hpp"
#include "core/RuntimeError.hpp"
#include "core/communication.hpp"
#include "core/errorhandling.hpp"

#include <boost/mpi/collectives.hpp>
#include <boost/serialization/string.hpp>

#include <cassert>
#include <cstddef>
#include <functional>
#include <stdexcept>
#include <string>
#include <vector>

namespace ScriptInterface {

void ParallelExceptionHandler::handle_impl(std::exception const *error) const {
auto const head_node = 0;
auto const this_node = m_comm.rank();

enum : unsigned char {
NO_RANK_FAILED = 0u,
SOME_RANK_FAILED = 1u,
THIS_RANK_SUCCESS = 0u,
THIS_RANK_FAILED = 1u,
MAIN_RANK_FAILED = 2u,
};
auto const this_fail_flag =
((error)
? ((this_node == head_node) ? MAIN_RANK_FAILED : THIS_RANK_FAILED)
: THIS_RANK_SUCCESS);
auto const fail_flag = boost::mpi::all_reduce(
m_comm, static_cast<unsigned char>(this_fail_flag), std::bit_or<>());
auto const main_rank_failed = fail_flag & MAIN_RANK_FAILED;
auto const some_rank_failed = fail_flag & SOME_RANK_FAILED;

if (main_rank_failed) {
flush_runtime_errors_local();
if (this_node == head_node) {
throw;
}
throw Exception("");
}

if (some_rank_failed) {
flush_runtime_errors_local();
std::vector<std::string> messages;
std::string this_message{(error) ? error->what() : ""};
boost::mpi::gather(m_comm, this_message, messages, head_node);
if (this_node == head_node) {
std::string error_message{"an error occurred on one or more MPI ranks:"};
for (std::size_t i = 0; i < messages.size(); ++i) {
error_message += "\n rank " + std::to_string(i) + ": " + messages[i];
}
throw std::runtime_error(error_message.c_str());
}
throw Exception("");
}
}

} // namespace ScriptInterface
89 changes: 89 additions & 0 deletions src/script_interface/ParallelExceptionHandler.hpp
@@ -0,0 +1,89 @@
/*
* Copyright (C) 2022 The ESPResSo project
*
* This file is part of ESPResSo.
*
* ESPResSo is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* ESPResSo is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef ESPRESSO_SCRIPT_INTERFACE_PARALLEL_EXCEPTION_HANDLER_HPP
#define ESPRESSO_SCRIPT_INTERFACE_PARALLEL_EXCEPTION_HANDLER_HPP

#include "core/errorhandling.hpp"

#include <boost/mpi/communicator.hpp>

#include <stdexcept>
#include <string>
#include <utility>

namespace ScriptInterface {
/**
* Handle exceptions thrown in MPI parallel code.
*
* Instantiate this class inside the catch block and after the catch block,
* like so:
* @code{.cpp}
* boost::mpi::communicator world;
* auto handler = ScriptInterface::ParallelExceptionHandler{world};
* std::shared_ptr<MyClass> obj;
* context()->parallel_try_catch([&obj]() {
* obj = std::make_shared<MyClass>(2., true);
* });
* @endcode
*
* Exceptions are handled as follows:
* * the main rank throws: re-throw on main rank and throw @ref Exception
* on all other ranks
* * one or more of the worker nodes throw: collect error messages from
* worker nodes and throw them on the main rank as a @c std::runtime_error,
* throw @ref Exception on all other ranks
*
* Throwing a @ref Exception guarantees that the partially initialized script
* interface object won't be registered in the @ref GlobalContext dictionary;
* this is the only side-effect on worker nodes, since the exception itself
* is otherwise silently ignored. On the main rank, the thrown exception is
* converted to a Python exception.
*/
class ParallelExceptionHandler {
public:
ParallelExceptionHandler(boost::mpi::communicator comm)
: m_comm(std::move(comm)) {}

/**
* @brief Handle exceptions in synchronous code.
* Error messages queued in the runtime error collector are flushed
* to standard error if the code throws on any rank.
* @pre Must be called on all ranks.
* @pre The @p callback cannot invoke remote functions from the
* @ref Communication::MpiCallbacks framework due to blocking
* communication (risk of MPI deadlock on worker nodes).
* @param[in] callback Callback to execute synchronously on all ranks.
*/
template <typename T>
void parallel_try_catch(std::function<void()> const &callback) const {
try {
callback();
} catch (T const &error) {
handle_impl(&error);
}
handle_impl(nullptr);
}

private:
void handle_impl(std::exception const *error) const;
boost::mpi::communicator m_comm;
};
} // namespace ScriptInterface

#endif

0 comments on commit 0cf245a

Please sign in to comment.