Commit

Merge pull request #16371 from bangerth/exscan

Introduce a function Utilities::MPI::partial_and_total_sum().

drwells committed Dec 21, 2023
2 parents 02bb9b2 + 4451782 commit 37527b0
Showing 6 changed files with 87 additions and 122 deletions.
3 changes: 3 additions & 0 deletions doc/news/changes/minor/20231219Bangerth
@@ -0,0 +1,3 @@
New: There is now a function Utilities::MPI::partial_and_total_sum().
<br>
(Wolfgang Bangerth, 2023/12/19)
55 changes: 55 additions & 0 deletions include/deal.II/base/mpi.h
@@ -1416,6 +1416,23 @@ namespace Utilities
const std::function<T(const T &, const T &)> &combiner,
const unsigned int root_process = 0);


/**
* For each process $p$ on a communicator with $P$ processes, compute both
* the (exclusive) partial sum $\sum_{i=0}^{p-1} v_i$ and the total
* sum $\sum_{i=0}^{P-1} v_i$, and return these two values as a pair.
 * The former is computed via the `MPI_Exscan` function; in that context,
 * the partial sum is typically called the "(exclusive) scan" of the values
 * $v_p$ provided by the individual processes. The term "prefix sum" is
 * also used.
*
* This function is only available if `T` is a type natively supported
* by MPI.
*/
template <typename T, typename = std::enable_if_t<is_mpi_type<T> == true>>
std::pair<T, T>
partial_and_total_sum(const T &value, const MPI_Comm comm);


/**
* A function that combines values @p local_value from all processes
* via a user-specified binary operation @p combiner and distributes the
@@ -1971,6 +1988,44 @@ namespace Utilities



template <typename T, typename>
std::pair<T, T>
partial_and_total_sum(const T &value, const MPI_Comm comm)
{
# ifndef DEAL_II_WITH_MPI
(void)comm;
return {0, value};
# else
if (Utilities::MPI::n_mpi_processes(comm) == 1)
return {0, value};
else
{
T prefix = {};

// First obtain every process's prefix sum:
int ierr =
MPI_Exscan(&value,
&prefix,
1,
Utilities::MPI::mpi_type_id_for_type<decltype(value)>,
MPI_SUM,
comm);
AssertThrowMPI(ierr);

// Then we also need the total sum. We could obtain it by
// calling Utilities::MPI::sum(), but it is cheaper if we
// broadcast it from the last process, which can compute it
// from its own prefix sum plus its own value.
const T sum = Utilities::MPI::broadcast(
comm, prefix + value, Utilities::MPI::n_mpi_processes(comm) - 1);

return {prefix, sum};
}
# endif
}



template <typename T>
std::vector<T>
all_gather(const MPI_Comm comm, const T &object)
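A minimal usage sketch of the new function; this stand-alone test program is illustrative only and not part of the commit:

#include <deal.II/base/mpi.h>

#include <iostream>

int main(int argc, char *argv[])
{
  // Initialize MPI; it is finalized automatically when 'mpi' goes out of scope.
  dealii::Utilities::MPI::MPI_InitFinalize mpi(argc, argv, 1);

  const MPI_Comm comm = MPI_COMM_WORLD;

  // Every process contributes one value, here simply its rank plus one.
  const unsigned int my_value =
    dealii::Utilities::MPI::this_mpi_process(comm) + 1;

  // 'prefix' is the sum of the values on all lower-ranked processes
  // (zero on rank 0); 'total' is the sum over all processes.
  const auto [prefix, total] =
    dealii::Utilities::MPI::partial_and_total_sum(my_value, comm);

  std::cout << "prefix sum = " << prefix << ", total sum = " << total
            << std::endl;
}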
43 changes: 4 additions & 39 deletions include/deal.II/lac/sparse_matrix_tools.h
@@ -143,42 +143,6 @@ namespace SparseMatrixTools

namespace internal
{
template <typename T>
std::tuple<T, T>
compute_prefix_sum(const T &value, const MPI_Comm comm)
{
# ifndef DEAL_II_WITH_MPI
(void)comm;
return {0, value};
# else
if (Utilities::MPI::n_mpi_processes(comm) == 1)
return {0, value};
else
{
T prefix = {};

// First obtain every process's prefix sum:
int ierr =
MPI_Exscan(&value,
&prefix,
1,
Utilities::MPI::mpi_type_id_for_type<decltype(value)>,
MPI_SUM,
comm);
AssertThrowMPI(ierr);

// Then we also need the total sum. We could obtain it by
// calling Utilities::MPI::sum(), but it is cheaper if we
// broadcast it from the last process, which can compute it
// from its own prefix sum plus its own value.
T sum = Utilities::MPI::broadcast(
comm, prefix + value, Utilities::MPI::n_mpi_processes(comm) - 1);

return {prefix, sum};
}
# endif
}

template <typename T>
using get_mpi_communicator_t =
decltype(std::declval<const T>().get_mpi_communicator());
@@ -246,8 +210,9 @@ namespace SparseMatrixTools
{
std::vector<unsigned int> dummy(locally_active_dofs.n_elements());

const auto local_size = get_local_size(system_matrix);
const auto [prefix_sum, total_sum] = compute_prefix_sum(local_size, comm);
const auto local_size = get_local_size(system_matrix);
const auto [prefix_sum, total_sum] =
Utilities::MPI::partial_and_total_sum(local_size, comm);
IndexSet locally_owned_dofs(total_sum);
locally_owned_dofs.add_range(prefix_sum, prefix_sum + local_size);

@@ -495,7 +460,7 @@ namespace SparseMatrixTools
{
// 0) determine which rows are locally owned and which ones are remote
const auto local_size = internal::get_local_size(system_matrix);
const auto prefix_sum = internal::compute_prefix_sum(
const auto prefix_sum = Utilities::MPI::partial_and_total_sum(
local_size, internal::get_mpi_communicator(system_matrix));
IndexSet locally_owned_dofs(std::get<1>(prefix_sum));
locally_owned_dofs.add_range(std::get<0>(prefix_sum),
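Both call sites above use the returned pair in the same way: a per-process row count becomes a contiguous, globally consistent index range. A self-contained sketch of that pattern (the helper name make_owned_range is made up for illustration):

#include <deal.II/base/index_set.h>
#include <deal.II/base/mpi.h>
#include <deal.II/base/types.h>

// Given the number of rows this process owns, return the IndexSet
// [prefix, prefix + n_local_rows) within a global range of size 'total'.
dealii::IndexSet
make_owned_range(const dealii::types::global_dof_index n_local_rows,
                 const MPI_Comm                         comm)
{
  const auto [prefix, total] =
    dealii::Utilities::MPI::partial_and_total_sum(n_local_rows, comm);

  dealii::IndexSet owned(total);
  owned.add_range(prefix, prefix + n_local_rows);
  owned.compress();
  return owned;
}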
28 changes: 5 additions & 23 deletions source/distributed/repartitioning_policy_tools.cc
@@ -304,29 +304,11 @@ namespace RepartitioningPolicyTools
for (const auto &weight : weights)
process_local_weight += weight;

// determine partial sum of weights of this process
std::uint64_t process_local_weight_offset = 0;

int ierr = MPI_Exscan(
&process_local_weight,
&process_local_weight_offset,
1,
Utilities::MPI::mpi_type_id_for_type<decltype(process_local_weight)>,
MPI_SUM,
tria->get_communicator());
AssertThrowMPI(ierr);

// total weight of all processes
std::uint64_t total_weight =
process_local_weight_offset + process_local_weight;

ierr =
MPI_Bcast(&total_weight,
1,
Utilities::MPI::mpi_type_id_for_type<decltype(total_weight)>,
n_subdomains - 1,
mpi_communicator);
AssertThrowMPI(ierr);
// determine partial sum of weights of this process, as well as the total
// weight
const auto [process_local_weight_offset, total_weight] =
Utilities::MPI::partial_and_total_sum(process_local_weight,
tria->get_communicator());

// set up partition
LinearAlgebra::distributed::Vector<double> partition(partitioner);
50 changes: 13 additions & 37 deletions source/dofs/dof_handler_policy.cc
@@ -3694,14 +3694,10 @@ namespace internal

// --------- Phase 4: shift indices so that each processor has a unique
// range of indices
dealii::types::global_dof_index my_shift = 0;
const int ierr = MPI_Exscan(&n_locally_owned_dofs,
&my_shift,
1,
DEAL_II_DOF_INDEX_MPI_TYPE,
MPI_SUM,
triangulation->get_communicator());
AssertThrowMPI(ierr);
const auto [my_shift, n_global_dofs] =
Utilities::MPI::partial_and_total_sum(
n_locally_owned_dofs, triangulation->get_communicator());


// make dof indices globally consecutive
Implementation::enumerate_dof_indices_for_renumbering(
@@ -3715,11 +3711,6 @@ namespace internal
*dof_handler,
/*check_validity=*/false);

// now a little bit of housekeeping
const dealii::types::global_dof_index n_global_dofs =
Utilities::MPI::sum(n_locally_owned_dofs,
triangulation->get_communicator());

NumberCache number_cache;
number_cache.n_global_dofs = n_global_dofs;
number_cache.n_locally_owned_dofs = n_locally_owned_dofs;
@@ -3899,33 +3890,17 @@ namespace internal

//* 3. communicate local dofcount and shift ids to make
// them unique
dealii::types::global_dof_index my_shift = 0;
int ierr = MPI_Exscan(&level_number_cache.n_locally_owned_dofs,
&my_shift,
1,
DEAL_II_DOF_INDEX_MPI_TYPE,
MPI_SUM,
triangulation->get_communicator());
AssertThrowMPI(ierr);

// The last processor knows about the total number of dofs, so we
// can use a cheaper broadcast rather than an MPI_Allreduce via
// MPI::sum().
level_number_cache.n_global_dofs =
my_shift + level_number_cache.n_locally_owned_dofs;
ierr = MPI_Bcast(&level_number_cache.n_global_dofs,
1,
DEAL_II_DOF_INDEX_MPI_TYPE,
Utilities::MPI::n_mpi_processes(
triangulation->get_communicator()) -
1,
triangulation->get_communicator());
AssertThrowMPI(ierr);
const auto [my_shift, n_global_dofs] =
Utilities::MPI::partial_and_total_sum(
level_number_cache.n_locally_owned_dofs,
triangulation->get_communicator());
level_number_cache.n_global_dofs = n_global_dofs;

// assign appropriate indices
types::global_dof_index next_free_index = my_shift;
for (types::global_dof_index &index : renumbering)
if (index == enumeration_dof_index)
index = my_shift++;
index = next_free_index++;

// now re-enumerate all dofs to this shifted and condensed
// numbering form. we renumber some dofs as invalid, so
@@ -3943,7 +3918,8 @@ namespace internal
level_number_cache.locally_owned_dofs =
IndexSet(level_number_cache.n_global_dofs);
level_number_cache.locally_owned_dofs.add_range(
my_shift - level_number_cache.n_locally_owned_dofs, my_shift);
next_free_index - level_number_cache.n_locally_owned_dofs,
next_free_index);
level_number_cache.locally_owned_dofs.compress();

number_caches.emplace_back(level_number_cache);
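The change above keeps my_shift intact and instead advances a separate counter; the underlying pattern — handing out consecutive global indices starting at the exclusive prefix sum — looks roughly like the following sketch (the marker value and function name are made up for illustration):

#include <deal.II/base/mpi.h>
#include <deal.II/base/types.h>

#include <limits>
#include <vector>

// Replace every 'not yet numbered' marker in 'renumbering' by consecutive
// global indices, starting at this process's exclusive prefix sum.
void
assign_consecutive_indices(
  std::vector<dealii::types::global_dof_index> &renumbering,
  const dealii::types::global_dof_index         n_locally_owned,
  const MPI_Comm                                comm)
{
  const dealii::types::global_dof_index not_yet_numbered =
    std::numeric_limits<dealii::types::global_dof_index>::max();

  const auto [my_shift, n_global_dofs] =
    dealii::Utilities::MPI::partial_and_total_sum(n_locally_owned, comm);
  (void)n_global_dofs; // would be stored in the NumberCache

  dealii::types::global_dof_index next_free_index = my_shift;
  for (auto &index : renumbering)
    if (index == not_yet_numbered)
      index = next_free_index++;
  // Afterwards, this process owns the half-open index range
  // [next_free_index - n_locally_owned, next_free_index).
}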
30 changes: 7 additions & 23 deletions source/particles/generators.cc
@@ -312,18 +312,21 @@ namespace Particles
cumulative_cell_weights.back() :
0.0;

double global_weight_integral;

double local_start_weight = numbers::signaling_nan<double>();
double global_weight_integral = numbers::signaling_nan<double>();

if (const auto tria =
dynamic_cast<const parallel::TriangulationBase<dim, spacedim> *>(
&triangulation))
{
global_weight_integral =
Utilities::MPI::sum(local_weight_integral,
tria->get_communicator());
std::tie(local_start_weight, global_weight_integral) =
Utilities::MPI::partial_and_total_sum(local_weight_integral,
tria->get_communicator());
}
else
{
local_start_weight = 0;
global_weight_integral = local_weight_integral;
}

Expand All @@ -337,25 +340,6 @@ namespace Particles
"part of the domain; also check the syntax of "
"the function."));

// Determine the starting weight of this process, which is the sum of
// the weights of all processes with a lower rank
double local_start_weight = 0.0;

#ifdef DEAL_II_WITH_MPI
if (const auto tria =
dynamic_cast<const parallel::TriangulationBase<dim, spacedim> *>(
&triangulation))
{
const int ierr = MPI_Exscan(&local_weight_integral,
&local_start_weight,
1,
MPI_DOUBLE,
MPI_SUM,
tria->get_communicator());
AssertThrowMPI(ierr);
}
#endif

// Calculate start id
start_particle_id =
std::llround(static_cast<double>(n_particles_to_create) *
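To illustrate how the returned pair is used here (illustrative numbers, not from the commit): with three processes whose local weight integrals are 2.0, 3.0, and 5.0 and with n_particles_to_create = 100, partial_and_total_sum() yields the pairs (0, 10), (2, 10), and (5, 10), so the three processes start their particle ids at llround(100 * 0/10) = 0, llround(100 * 2/10) = 20, and llround(100 * 5/10) = 50, respectively.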
