forked from AmbaPant/mantid
-
Notifications
You must be signed in to change notification settings - Fork 1
/
ParallelRunner.h
64 lines (52 loc) · 2.07 KB
/
ParallelRunner.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
// Mantid Repository : https://github.com/mantidproject/mantid
//
// Copyright © 2017 ISIS Rutherford Appleton Laboratory UKRI,
// NScD Oak Ridge National Laboratory, European Spallation Source,
// Institut Laue - Langevin & CSNS, Institute of High Energy Physics, CAS
// SPDX - License - Identifier: GPL - 3.0 +
#pragma once
#include "MantidParallel/Communicator.h"
#include "MantidParallel/DllConfig.h"

#include <functional>
#include <memory>
#include <thread>
#include <utility>
#include <vector>
namespace ParallelTestHelpers {
/** Runs a callable in parallel. This is mainly a helper for testing code with
  MPI calls. ParallelRunner passes a Communicator as first argument to the
  callable. In runs with only a single MPI rank the callable is executed in
  threads to mimic MPI ranks.
*/
class ParallelRunner {
public:
  /// Default constructor. Defined out of line (presumably in
  /// ParallelRunner.cpp); how it picks the backends is not visible here.
  ParallelRunner();
  /// Constructor taking a thread count, presumably the number of ranks to
  /// emulate with threads -- confirm against the out-of-line definition.
  /// NOTE(review): not `explicit`, so an int converts implicitly to a
  /// ParallelRunner; left unchanged to avoid breaking existing callers.
  ParallelRunner(const int threads);
  /// Defined out of line; presumably the number of (emulated) ranks.
  int size() const;

  /** Call `f` once on the calling thread.
   * `f` receives a Communicator constructed from the serial backend with
   * rank 0, followed by `args`.
   * @param f callable taking a Communicator as its first argument.
   * @param args extra arguments, perfectly forwarded to `f`.
   */
  template <class Function, class... Args> void runSerial(Function &&f, Args &&...args);

  /** Call `f` once per rank.
   * Without a threading backend, `f` is called once on the calling thread
   * with a default-constructed Communicator and the forwarded `args`.
   * Otherwise one std::thread is started per rank of the threading backend.
   * Each thread receives Communicator(backend, rank) and its own copies of
   * `f` and `args`. All threads are joined before returning.
   * @param f callable taking a Communicator as its first argument. In the
   *   threaded path it is copied into every thread, so it must be
   *   copy-constructible. std::thread invokes it with rvalue arguments, so
   *   its parameters must bind to rvalues (by value or const&). An
   *   exception escaping `f` on a worker thread calls std::terminate.
   * @param args extra arguments passed after the Communicator.
   */
  template <class Function, class... Args> void runParallel(Function &&f, Args &&...args);

private:
  std::shared_ptr<Mantid::Parallel::detail::ThreadingBackend> m_backend;
  std::shared_ptr<Mantid::Parallel::detail::ThreadingBackend> m_serialBackend;
};

template <class Function, class... Args> void ParallelRunner::runSerial(Function &&f, Args &&...args) {
  f(Mantid::Parallel::Communicator(m_serialBackend, 0), std::forward<Args>(args)...);
}

template <class Function, class... Args> void ParallelRunner::runParallel(Function &&f, Args &&...args) {
  if (!m_backend) {
    // No threading backend: a single call with a default Communicator.
    Mantid::Parallel::Communicator comm;
    f(comm, std::forward<Args>(args)...);
    return;
  }

  const auto rankCount = m_backend->size();
  std::vector<std::thread> threads;
  for (int t = 0; t < rankCount; ++t) {
    Mantid::Parallel::Communicator comm(m_backend, t);
    // Pass `f` and `args` as lvalues: std::thread decay-copies them, so each
    // thread gets its own copy. Forwarding them here would move rvalue
    // arguments into the first thread and hand every later thread a
    // moved-from object.
    threads.emplace_back(f, comm, args...);
  }
  for (auto &thread : threads) {
    thread.join();
  }
}

/** Convenience wrapper: construct a default ParallelRunner and call its
 * runParallel member with all arguments forwarded unchanged.
 */
template <class... Args> void runParallel(Args &&...args) {
  ParallelRunner runner;
  runner.runParallel(std::forward<Args>(args)...);
}
} // namespace ParallelTestHelpers