Skip to content

Commit

Permalink
Address the pesky 'MPI_Comm as void*' problem with the SUNDIALS inter…
Browse files Browse the repository at this point in the history
…faces.
  • Loading branch information
bangerth committed Apr 19, 2023
1 parent 6d2f3e5 commit d7f759e
Showing 1 changed file with 99 additions and 46 deletions.
145 changes: 99 additions & 46 deletions include/deal.II/sundials/n_vector.templates.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,21 @@ namespace SUNDIALS
const VectorType *
get() const;

/**
* Return a reference to a copy of the communicator the vector uses.
* This function exists because the N_Vector
* interface requires a function that returns a `void*` pointing
* to the communicator object -- so somewhere, we need to have an
* address to point to. The issue is that our vectors typically
* return a *copy* of the communicator, rather than a reference to
* the communicator they use, and so there is only a temporary
* object and no address we can point to. To work around this
* requirement, this class stores a copy of the communicator,
* and this function here returns a reference to this copy.
*/
const MPI_Comm &
get_mpi_communicator() const;

private:
using PointerType =
std::unique_ptr<VectorType, std::function<void(VectorType *)>>;
Expand All @@ -107,6 +122,18 @@ namespace SUNDIALS
*/
PointerType vector;

/**
* A copy of the communicator the vector uses, initialized in the
* constructor of this class. We store this because the N_Vector
* interface requires a function that returns a `void*` pointing
* to the communicator object -- so somewhere, we need to have an
* address to point to. The issue is that our vectors typically
* return a *copy* of the communicator, rather than a reference to
* the communicator they use, and so there is only a temporary
* object and no address we can point to.
*/
MPI_Comm comm;

/**
* Flag storing whether the stored pointer is to be treated as const. If
* the pointer passed in the constructor was indeed const, it is cast away
Expand Down Expand Up @@ -261,30 +288,17 @@ namespace SUNDIALS
void
add_constant(N_Vector x, realtype b, N_Vector z);

template <typename VectorType,
std::enable_if_t<!IsBlockVector<VectorType>::value, int> = 0>
const MPI_Comm &
get_communicator(N_Vector v);

template <typename VectorType,
std::enable_if_t<IsBlockVector<VectorType>::value, int> = 0>
template <typename VectorType>
const MPI_Comm &
get_communicator(N_Vector v);

/**
* Sundials likes a void* but we want to use the above functions
* internally with a safe type.
*/
template <typename VectorType,
std::enable_if_t<is_serial_vector<VectorType>::value, int> = 0>
inline void *
get_communicator_as_void_ptr(N_Vector v);

template <typename VectorType,
std::enable_if_t<!is_serial_vector<VectorType>::value, int> = 0>
template <typename VectorType>
inline void *
get_communicator_as_void_ptr(N_Vector v);

} // namespace NVectorOperations
} // namespace internal
} // namespace SUNDIALS
Expand All @@ -297,18 +311,65 @@ namespace SUNDIALS
{
namespace internal
{
namespace
{
template <typename VectorType,
std::enable_if_t<is_serial_vector<VectorType>::value, int> = 0>
MPI_Comm
get_mpi_communicator_from_vector(const VectorType &)
{
return MPI_COMM_SELF;
}



template <typename VectorType,
std::enable_if_t<!is_serial_vector<VectorType>::value &&
!IsBlockVector<VectorType>::value,
int> = 0>
MPI_Comm
get_mpi_communicator_from_vector(const VectorType &v)
{
# ifndef DEAL_II_WITH_MPI
(void)v;
return MPI_COMM_SELF;
# else
return v.get_mpi_communicator();
# endif
}



template <typename VectorType,
std::enable_if_t<!is_serial_vector<VectorType>::value &&
IsBlockVector<VectorType>::value,
int> = 0>
MPI_Comm
get_mpi_communicator_from_vector(const VectorType &v)
{
# ifndef DEAL_II_WITH_MPI
(void)v;
return MPI_COMM_SELF;
# else
return v.block(0).get_mpi_communicator();
# endif
}
} // namespace


template <typename VectorType>
NVectorContent<VectorType>::NVectorContent()
: vector(typename VectorMemory<VectorType>::Pointer(mem))
, comm(get_mpi_communicator_from_vector(*vector))
, is_const(false)
{}



template <typename VectorType>
NVectorContent<VectorType>::NVectorContent(VectorType *vector)
: vector(vector,
[](VectorType *) { /* not owning memory -> don't free*/ })
, comm(get_mpi_communicator_from_vector(*vector))
, is_const(false)
{}

Expand All @@ -318,6 +379,7 @@ namespace SUNDIALS
NVectorContent<VectorType>::NVectorContent(const VectorType *vector)
: vector(const_cast<VectorType *>(vector),
[](VectorType *) { /* not owning memory -> don't free*/ })
, comm(get_mpi_communicator_from_vector(*vector))
, is_const(true)
{}

Expand Down Expand Up @@ -349,6 +411,14 @@ namespace SUNDIALS



template <typename VectorType>
const MPI_Comm &
NVectorContent<VectorType>::get_mpi_communicator() const
{
return comm;
}


# if DEAL_II_SUNDIALS_VERSION_GTE(6, 0, 0)
template <typename VectorType>
NVectorView<VectorType>
Expand Down Expand Up @@ -528,50 +598,33 @@ namespace SUNDIALS



template <typename VectorType,
std::enable_if_t<IsBlockVector<VectorType>::value, int>>
const MPI_Comm &
get_communicator(N_Vector v)
{
return unwrap_nvector_const<VectorType>(v)
->block(0)
.get_mpi_communicator();
}



template <typename VectorType,
std::enable_if_t<!IsBlockVector<VectorType>::value, int>>
template <typename VectorType>
const MPI_Comm &
get_communicator(N_Vector v)
{
return unwrap_nvector_const<VectorType>(v)->get_mpi_communicator();
}



template <typename VectorType,
std::enable_if_t<is_serial_vector<VectorType>::value, int>>
void *get_communicator_as_void_ptr(N_Vector)
{
// required by SUNDIALS: MPI-unaware vectors should return the nullptr
// as comm
return nullptr;
Assert(v != nullptr, ExcInternalError());
Assert(v->content != nullptr, ExcInternalError());
auto *pContent =
reinterpret_cast<NVectorContent<VectorType> *>(v->content);
return pContent->get_mpi_communicator();
}



template <typename VectorType,
std::enable_if_t<!is_serial_vector<VectorType>::value, int>>
template <typename VectorType>
void *
get_communicator_as_void_ptr(N_Vector v)
{
# ifndef DEAL_II_WITH_MPI
(void)v;
return nullptr;
# else
// We need to cast away const here, as SUNDIALS demands a pure `void *`.
return &(const_cast<MPI_Comm &>(get_communicator<VectorType>(v)));
if (is_serial_vector<VectorType>::value == false)
// We need to cast away const here, as SUNDIALS demands a pure
// `void*`.
return &(const_cast<MPI_Comm &>(get_communicator<VectorType>(v)));
else
return nullptr;
# endif
}

Expand Down

0 comments on commit d7f759e

Please sign in to comment.