forked from AmbaPant/mantid
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Collectives.h
96 lines (83 loc) · 2.92 KB
/
Collectives.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
// Mantid Repository : https://github.com/mantidproject/mantid
//
// Copyright © 2017 ISIS Rutherford Appleton Laboratory UKRI,
// NScD Oak Ridge National Laboratory, European Spallation Source,
// Institut Laue - Langevin & CSNS, Institute of High Energy Physics, CAS
// SPDX - License - Identifier: GPL - 3.0 +
#pragma once

#include "MantidParallel/Communicator.h"
#include "MantidParallel/DllConfig.h"
#include "MantidParallel/Nonblocking.h"

#include <stdexcept>
#include <utility>
#include <vector>

#ifdef MPI_EXPERIMENTAL
#include <boost/mpi/collectives.hpp>
#endif
namespace Mantid {
namespace Parallel {
/** Wrapper for boost::mpi::gather and other collective communication. For
non-MPI builds an equivalent implementation with reduced functionality is
provided.
@author Simon Heybrock
@date 2017
*/
namespace detail {
/// Gathers in_value from every rank of comm into out_values on the root rank.
/// Non-root ranks send their value to root and leave out_values untouched;
/// on root, out_values is resized to comm.size() and out_values[r] holds the
/// value contributed by rank r.
template <typename T> void gather(const Communicator &comm, const T &in_value, std::vector<T> &out_values, int root) {
  const int tag{0};
  if (comm.rank() == root) {
    out_values.resize(comm.size());
    out_values[root] = in_value;
    // Collect one value from every other rank, indexed by sender.
    for (int src = 0; src < comm.size(); ++src) {
      if (src != root)
        comm.recv(src, tag, out_values[src]);
    }
  } else {
    comm.send(root, tag, in_value);
  }
}
/// Gather overload for non-root ranks: sends in_value to root without
/// collecting any output. Calling this overload on the root rank itself is a
/// programming error (root must use the overload with the output vector).
template <typename T> void gather(const Communicator &comm, const T &in_value, int root) {
  if (comm.rank() == root)
    throw std::logic_error("Parallel::gather on root rank without output argument.");
  const int tag{0};
  comm.send(root, tag, in_value);
}
/// All-gather: performs a gather rooted at each rank in turn, so every rank
/// ends up with the values from all ranks. Forwarding args repeatedly inside
/// the loop is safe here because detail::gather takes its arguments by
/// reference (const T& / std::vector<T>&), so nothing is moved from.
template <typename... T> void all_gather(const Communicator &comm, T &&...args) {
  const int nRanks = comm.size();
  for (int root = 0; root < nRanks; ++root)
    gather(comm, std::forward<T>(args)..., root);
}
template <typename T>
void all_to_all(const Communicator &comm, const std::vector<T> &in_values, std::vector<T> &out_values) {
int tag{0};
out_values.resize(comm.size());
std::vector<Request> requests;
for (int rank = 0; rank < comm.size(); ++rank)
requests.emplace_back(comm.irecv(rank, tag, out_values[rank]));
for (int rank = 0; rank < comm.size(); ++rank)
comm.send(rank, tag, in_values[rank]);
wait_all(requests.begin(), requests.end());
}
} // namespace detail
/// Public gather entry point. In MPI builds, dispatches to boost::mpi::gather
/// unless comm carries its own (in-house) backend; otherwise the fallback
/// implementation in detail:: is used.
template <typename... T> void gather(const Communicator &comm, T &&...args) {
#ifdef MPI_EXPERIMENTAL
  if (!comm.hasBackend()) {
    boost::mpi::gather(comm, std::forward<T>(args)...);
    return;
  }
#endif
  detail::gather(comm, std::forward<T>(args)...);
}
/// Public all_gather entry point. In MPI builds, dispatches to
/// boost::mpi::all_gather unless comm carries its own (in-house) backend;
/// otherwise the fallback implementation in detail:: is used.
template <typename... T> void all_gather(const Communicator &comm, T &&...args) {
#ifdef MPI_EXPERIMENTAL
  if (!comm.hasBackend()) {
    boost::mpi::all_gather(comm, std::forward<T>(args)...);
    return;
  }
#endif
  detail::all_gather(comm, std::forward<T>(args)...);
}
/// Public all_to_all entry point. In MPI builds, dispatches to
/// boost::mpi::all_to_all unless comm carries its own (in-house) backend;
/// otherwise the fallback implementation in detail:: is used.
template <typename... T> void all_to_all(const Communicator &comm, T &&...args) {
#ifdef MPI_EXPERIMENTAL
  if (!comm.hasBackend()) {
    boost::mpi::all_to_all(comm, std::forward<T>(args)...);
    return;
  }
#endif
  detail::all_to_all(comm, std::forward<T>(args)...);
}
} // namespace Parallel
} // namespace Mantid