Skip to content

Commit

Permalink
Merge pull request #15113 from bangerth/nvector
Browse files Browse the repository at this point in the history
Address the pesky 'MPI_Comm as void*' problem with the SUNDIALS interfaces
  • Loading branch information
tjhei committed Apr 20, 2023
2 parents 5d0c939 + d717b6c commit 70d13b7
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 62 deletions.
165 changes: 112 additions & 53 deletions include/deal.II/sundials/n_vector.templates.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ namespace SUNDIALS
*
* @note This constructor is intended for the N_VClone() call of SUNDIALS.
*/
NVectorContent();
NVectorContent(const MPI_Comm comm);

/**
* Non-const access to the stored vector. Only allowed if a constructor
Expand All @@ -92,6 +92,21 @@ namespace SUNDIALS
const VectorType *
get() const;

/**
* Return a reference to a copy of the communicator the vector uses.
* This function exists because the N_Vector
* interface requires a function that returns a `void*` pointing
* to the communicator object -- so somewhere, we need to have an
* address to point to. The issue is that our vectors typically
* return a *copy* of the communicator, rather than a reference to
* the communicator they use, and so there is only a temporary
* object and no address we can point to. To work around this
* requirement, this class stores a copy of the communicator,
* and this function here returns a reference to this copy.
*/
const MPI_Comm &
get_mpi_communicator() const;

private:
using PointerType =
std::unique_ptr<VectorType, std::function<void(VectorType *)>>;
Expand All @@ -107,6 +122,18 @@ namespace SUNDIALS
*/
PointerType vector;

/**
* A copy of the communicator the vector uses, initialized in the
* constructor of this class. We store this because the N_Vector
* interface requires a function that returns a `void*` pointing
* to the communicator object -- so somewhere, we need to have an
* address to point to. The issue is that our vectors typically
* return a *copy* of the communicator, rather than a reference to
* the communicator they use, and so there is only a temporary
* object and no address we can point to.
*/
MPI_Comm comm;

/**
* Flag storing whether the stored pointer is to be treated as const. If
* the pointer passed in the constructor was indeed const, it is cast away
Expand Down Expand Up @@ -261,30 +288,17 @@ namespace SUNDIALS
void
add_constant(N_Vector x, realtype b, N_Vector z);

template <typename VectorType,
std::enable_if_t<!IsBlockVector<VectorType>::value, int> = 0>
const MPI_Comm &
get_communicator(N_Vector v);

template <typename VectorType,
std::enable_if_t<IsBlockVector<VectorType>::value, int> = 0>
template <typename VectorType>
const MPI_Comm &
get_communicator(N_Vector v);

/**
* Sundials likes a void* but we want to use the above functions
* internally with a safe type.
*/
template <typename VectorType,
std::enable_if_t<is_serial_vector<VectorType>::value, int> = 0>
inline void *
get_communicator_as_void_ptr(N_Vector v);

template <typename VectorType,
std::enable_if_t<!is_serial_vector<VectorType>::value, int> = 0>
template <typename VectorType>
inline void *
get_communicator_as_void_ptr(N_Vector v);

} // namespace NVectorOperations
} // namespace internal
} // namespace SUNDIALS
Expand All @@ -297,18 +311,68 @@ namespace SUNDIALS
{
namespace internal
{
namespace
{
template <typename VectorType,
std::enable_if_t<is_serial_vector<VectorType>::value, int> = 0>
MPI_Comm
get_mpi_communicator_from_vector(const VectorType &)
{
return MPI_COMM_SELF;
}



template <typename VectorType,
std::enable_if_t<!is_serial_vector<VectorType>::value &&
!IsBlockVector<VectorType>::value,
int> = 0>
MPI_Comm
get_mpi_communicator_from_vector(const VectorType &v)
{
# ifndef DEAL_II_WITH_MPI
(void)v;
return MPI_COMM_SELF;
# else
return v.get_mpi_communicator();
# endif
}



template <typename VectorType,
std::enable_if_t<!is_serial_vector<VectorType>::value &&
IsBlockVector<VectorType>::value,
int> = 0>
MPI_Comm
get_mpi_communicator_from_vector(const VectorType &v)
{
# ifndef DEAL_II_WITH_MPI
(void)v;
return MPI_COMM_SELF;
# else
Assert(v.n_blocks() > 0,
ExcMessage("You cannot ask a block vector without blocks "
"for its MPI communicator."));
return v.block(0).get_mpi_communicator();
# endif
}
} // namespace


template <typename VectorType>
NVectorContent<VectorType>::NVectorContent()
NVectorContent<VectorType>::NVectorContent(const MPI_Comm comm)
: vector(typename VectorMemory<VectorType>::Pointer(mem))
, comm(comm)
, is_const(false)
{}



template <typename VectorType>
NVectorContent<VectorType>::NVectorContent(VectorType *vector)
: vector(vector,
[](VectorType *) { /* not owning memory -> don't free*/ })
, comm(get_mpi_communicator_from_vector(*vector))
, is_const(false)
{}

Expand All @@ -318,6 +382,7 @@ namespace SUNDIALS
NVectorContent<VectorType>::NVectorContent(const VectorType *vector)
: vector(const_cast<VectorType *>(vector),
[](VectorType *) { /* not owning memory -> don't free*/ })
, comm(get_mpi_communicator_from_vector(*vector))
, is_const(true)
{}

Expand Down Expand Up @@ -349,6 +414,14 @@ namespace SUNDIALS



template <typename VectorType>
const MPI_Comm &
NVectorContent<VectorType>::get_mpi_communicator() const
{
return comm;
}


# if DEAL_II_SUNDIALS_VERSION_GTE(6, 0, 0)
template <typename VectorType>
NVectorView<VectorType>
Expand Down Expand Up @@ -492,14 +565,17 @@ namespace SUNDIALS
{
N_Vector v = clone_empty(w);

// the corresponding delete is called in destroy()
auto cloned = new NVectorContent<VectorType>();
auto *w_dealii = unwrap_nvector_const<VectorType>(w);

// reinit the cloned vector based on the layout of the source vector
cloned->get()->reinit(*w_dealii);
v->content = cloned;
// Create the vector; the corresponding delete is called in destroy()
auto cloned = new NVectorContent<VectorType>(
get_mpi_communicator_from_vector(*w_dealii));

// Then also copy the structure and values:
*cloned->get() = *w_dealii;

// Finally set the cloned object in 'v':
v->content = cloned;
return v;
}

Expand Down Expand Up @@ -528,50 +604,33 @@ namespace SUNDIALS



template <typename VectorType,
std::enable_if_t<IsBlockVector<VectorType>::value, int>>
const MPI_Comm &
get_communicator(N_Vector v)
{
return unwrap_nvector_const<VectorType>(v)
->block(0)
.get_mpi_communicator();
}



template <typename VectorType,
std::enable_if_t<!IsBlockVector<VectorType>::value, int>>
template <typename VectorType>
const MPI_Comm &
get_communicator(N_Vector v)
{
return unwrap_nvector_const<VectorType>(v)->get_mpi_communicator();
}



template <typename VectorType,
std::enable_if_t<is_serial_vector<VectorType>::value, int>>
void *get_communicator_as_void_ptr(N_Vector)
{
// required by SUNDIALS: MPI-unaware vectors should return the nullptr
// as comm
return nullptr;
Assert(v != nullptr, ExcInternalError());
Assert(v->content != nullptr, ExcInternalError());
auto *pContent =
reinterpret_cast<NVectorContent<VectorType> *>(v->content);
return pContent->get_mpi_communicator();
}



template <typename VectorType,
std::enable_if_t<!is_serial_vector<VectorType>::value, int>>
template <typename VectorType>
void *
get_communicator_as_void_ptr(N_Vector v)
{
# ifndef DEAL_II_WITH_MPI
(void)v;
return nullptr;
# else
// We need to cast away const here, as SUNDIALS demands a pure `void *`.
return &(const_cast<MPI_Comm &>(get_communicator<VectorType>(v)));
if (is_serial_vector<VectorType>::value == false)
// We need to cast away const here, as SUNDIALS demands a pure
// `void*`.
return &(const_cast<MPI_Comm &>(get_communicator<VectorType>(v)));
else
return nullptr;
# endif
}

Expand Down
18 changes: 9 additions & 9 deletions tests/sundials/n_vector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,13 @@ namespace
*/
template <typename VectorType>
VectorType
create_test_vector(double value = 0.0);
create_test_vector(const double value = 0.0);



template <>
Vector<double>
create_test_vector(double value)
create_test_vector(const double value)
{
Vector<double> vector(3 /*size*/);
vector = value;
Expand All @@ -94,7 +94,7 @@ namespace

template <>
BlockVector<double>
create_test_vector(double value)
create_test_vector(const double value)
{
const int num_blocks = 2;
const int size_block = 3;
Expand All @@ -105,7 +105,7 @@ namespace

template <>
LinearAlgebra::distributed::Vector<double>
create_test_vector(double value)
create_test_vector(const double value)
{
IndexSet local_dofs = create_parallel_index_set();
LinearAlgebra::distributed::Vector<double> vector(local_dofs,
Expand All @@ -116,7 +116,7 @@ namespace

template <>
LinearAlgebra::distributed::BlockVector<double>
create_test_vector(double value)
create_test_vector(const double value)
{
const unsigned n_processes =
Utilities::MPI::n_mpi_processes(MPI_COMM_WORLD);
Expand All @@ -133,7 +133,7 @@ namespace

template <>
TrilinosWrappers::MPI::Vector
create_test_vector(double value)
create_test_vector(const double value)
{
IndexSet local_dofs = create_parallel_index_set();

Expand All @@ -144,7 +144,7 @@ namespace

template <>
TrilinosWrappers::MPI::BlockVector
create_test_vector(double value)
create_test_vector(const double value)
{
const unsigned n_processes =
Utilities::MPI::n_mpi_processes(MPI_COMM_WORLD);
Expand All @@ -159,7 +159,7 @@ namespace

template <>
PETScWrappers::MPI::Vector
create_test_vector(double value)
create_test_vector(const double value)
{
IndexSet local_dofs = create_parallel_index_set();

Expand All @@ -170,7 +170,7 @@ namespace

template <>
PETScWrappers::MPI::BlockVector
create_test_vector(double value)
create_test_vector(const double value)
{
const unsigned n_processes =
Utilities::MPI::n_mpi_processes(MPI_COMM_WORLD);
Expand Down

0 comments on commit 70d13b7

Please sign in to comment.