Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Address the pesky 'MPI_Comm as void*' problem with the SUNDIALS interfaces #15113

Merged
merged 4 commits into from
Apr 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
165 changes: 112 additions & 53 deletions include/deal.II/sundials/n_vector.templates.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ namespace SUNDIALS
*
* @note This constructor is intended for the N_VClone() call of SUNDIALS.
*/
NVectorContent();
NVectorContent(const MPI_Comm comm);

/**
* Non-const access to the stored vector. Only allowed if a constructor
Expand All @@ -92,6 +92,21 @@ namespace SUNDIALS
const VectorType *
get() const;

/**
* Return a reference to a copy of the communicator the vector uses.
* This function exists because the N_Vector
* interface requires a function that returns a `void*` pointing
* to the communicator object -- so somewhere, we need to have an
* address to point to. The issue is that our vectors typically
* return a *copy* of the communicator, rather than a reference to
* the communicator they use, and so there is only a temporary
* object and no address we can point to. To work around this
* requirement, this class stores a copy of the communicator,
* and this function here returns a reference to this copy.
*/
const MPI_Comm &
get_mpi_communicator() const;

private:
using PointerType =
std::unique_ptr<VectorType, std::function<void(VectorType *)>>;
Expand All @@ -107,6 +122,18 @@ namespace SUNDIALS
*/
PointerType vector;

/**
* A copy of the communicator the vector uses, initialized in the
* constructor of this class. We store this because the N_Vector
* interface requires a function that returns a `void*` pointing
* to the communicator object -- so somewhere, we need to have an
* address to point to. The issue is that our vectors typically
* return a *copy* of the communicator, rather than a reference to
* the communicator they use, and so there is only a temporary
* object and no address we can point to.
*/
MPI_Comm comm;

/**
* Flag storing whether the stored pointer is to be treated as const. If
* the pointer passed in the constructor was indeed const, it is cast away
Expand Down Expand Up @@ -261,30 +288,17 @@ namespace SUNDIALS
void
add_constant(N_Vector x, realtype b, N_Vector z);

template <typename VectorType,
std::enable_if_t<!IsBlockVector<VectorType>::value, int> = 0>
const MPI_Comm &
get_communicator(N_Vector v);

template <typename VectorType,
std::enable_if_t<IsBlockVector<VectorType>::value, int> = 0>
template <typename VectorType>
const MPI_Comm &
get_communicator(N_Vector v);

/**
* Sundials likes a void* but we want to use the above functions
* internally with a safe type.
*/
template <typename VectorType,
std::enable_if_t<is_serial_vector<VectorType>::value, int> = 0>
inline void *
get_communicator_as_void_ptr(N_Vector v);

template <typename VectorType,
std::enable_if_t<!is_serial_vector<VectorType>::value, int> = 0>
template <typename VectorType>
inline void *
get_communicator_as_void_ptr(N_Vector v);

} // namespace NVectorOperations
} // namespace internal
} // namespace SUNDIALS
Expand All @@ -297,18 +311,68 @@ namespace SUNDIALS
{
namespace internal
{
namespace
{
template <typename VectorType,
std::enable_if_t<is_serial_vector<VectorType>::value, int> = 0>
MPI_Comm
get_mpi_communicator_from_vector(const VectorType &)
{
return MPI_COMM_SELF;
}



template <typename VectorType,
std::enable_if_t<!is_serial_vector<VectorType>::value &&
!IsBlockVector<VectorType>::value,
int> = 0>
MPI_Comm
get_mpi_communicator_from_vector(const VectorType &v)
{
# ifndef DEAL_II_WITH_MPI
(void)v;
return MPI_COMM_SELF;
# else
return v.get_mpi_communicator();
# endif
}



template <typename VectorType,
std::enable_if_t<!is_serial_vector<VectorType>::value &&
IsBlockVector<VectorType>::value,
int> = 0>
MPI_Comm
get_mpi_communicator_from_vector(const VectorType &v)
{
# ifndef DEAL_II_WITH_MPI
(void)v;
return MPI_COMM_SELF;
# else
Assert(v.n_blocks() > 0,
ExcMessage("You cannot ask a block vector without blocks "
"for its MPI communicator."));
return v.block(0).get_mpi_communicator();
# endif
}
} // namespace


template <typename VectorType>
NVectorContent<VectorType>::NVectorContent()
NVectorContent<VectorType>::NVectorContent(const MPI_Comm comm)
: vector(typename VectorMemory<VectorType>::Pointer(mem))
, comm(comm)
, is_const(false)
{}



template <typename VectorType>
NVectorContent<VectorType>::NVectorContent(VectorType *vector)
: vector(vector,
[](VectorType *) { /* not owning memory -> don't free*/ })
, comm(get_mpi_communicator_from_vector(*vector))
, is_const(false)
{}

Expand All @@ -318,6 +382,7 @@ namespace SUNDIALS
NVectorContent<VectorType>::NVectorContent(const VectorType *vector)
: vector(const_cast<VectorType *>(vector),
[](VectorType *) { /* not owning memory -> don't free*/ })
, comm(get_mpi_communicator_from_vector(*vector))
, is_const(true)
{}

Expand Down Expand Up @@ -349,6 +414,14 @@ namespace SUNDIALS



template <typename VectorType>
const MPI_Comm &
NVectorContent<VectorType>::get_mpi_communicator() const
{
return comm;
}


# if DEAL_II_SUNDIALS_VERSION_GTE(6, 0, 0)
template <typename VectorType>
NVectorView<VectorType>
Expand Down Expand Up @@ -492,14 +565,17 @@ namespace SUNDIALS
{
N_Vector v = clone_empty(w);

// the corresponding delete is called in destroy()
auto cloned = new NVectorContent<VectorType>();
auto *w_dealii = unwrap_nvector_const<VectorType>(w);

// reinit the cloned vector based on the layout of the source vector
cloned->get()->reinit(*w_dealii);
v->content = cloned;
// Create the vector; the corresponding delete is called in destroy()
auto cloned = new NVectorContent<VectorType>(
get_mpi_communicator_from_vector(*w_dealii));

// Then also copy the structure and values:
*cloned->get() = *w_dealii;

// Finally set the cloned object in 'v':
v->content = cloned;
return v;
}

Expand Down Expand Up @@ -528,50 +604,33 @@ namespace SUNDIALS



template <typename VectorType,
std::enable_if_t<IsBlockVector<VectorType>::value, int>>
const MPI_Comm &
get_communicator(N_Vector v)
{
return unwrap_nvector_const<VectorType>(v)
->block(0)
.get_mpi_communicator();
}



template <typename VectorType,
std::enable_if_t<!IsBlockVector<VectorType>::value, int>>
template <typename VectorType>
const MPI_Comm &
get_communicator(N_Vector v)
{
return unwrap_nvector_const<VectorType>(v)->get_mpi_communicator();
}



template <typename VectorType,
std::enable_if_t<is_serial_vector<VectorType>::value, int>>
void *get_communicator_as_void_ptr(N_Vector)
{
// required by SUNDIALS: MPI-unaware vectors should return the nullptr
// as comm
return nullptr;
Assert(v != nullptr, ExcInternalError());
Assert(v->content != nullptr, ExcInternalError());
auto *pContent =
reinterpret_cast<NVectorContent<VectorType> *>(v->content);
return pContent->get_mpi_communicator();
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did

      template <typename VectorType,
                std::enable_if_t<is_serial_vector<VectorType>::value, int>>
      void *get_communicator_as_void_ptr(N_Vector)
      {
        // required by SUNDIALS: MPI-unaware vectors should return the nullptr
        // as comm
        return nullptr;
      }

disappear? Is &MPI_COMM_SELF==nullptr?

edit: Clarified "this".

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There were two specializations of the function, one for serial and one for parallel vectors (enabled or disabled via std::enable_if) because the former do not have a get_communicator() member function.

I've outsourced that decision to the get_mpi_communicator() free function above now that returns MPI_COMM_SELF for serial vectors. Are you saying that we should just return nullptr if a vector is serial, regardless of whether we configured for MPI or not?

Copy link
Member Author

@bangerth bangerth Apr 19, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe something like this in the current place?

diff --git a/include/deal.II/sundials/n_vector.templates.h b/include/deal.II/sundials/n_vector.templates.h
index e5e9db4bd2..e4a42d659a 100644
--- a/include/deal.II/sundials/n_vector.templates.h
+++ b/include/deal.II/sundials/n_vector.templates.h
@@ -619,8 +619,12 @@ namespace SUNDIALS
         (void)v;
         return nullptr;
 #  else
-        // We need to cast away const here, as SUNDIALS demands a pure `void *`.
-        return &(const_cast<MPI_Comm &>(get_communicator<VectorType>(v)));
+        if (is_serial_vector<VectorType>::value == false)
+          // We need to cast away const here, as SUNDIALS demands a pure `void
+          // *`.
+          return &(const_cast<MPI_Comm &>(get_communicator<VectorType>(v)));
+        else
+          return nullptr;
 #  endif
       }
 

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I convinced myself that that is the right way to go and pushed an updated commit.




template <typename VectorType,
std::enable_if_t<!is_serial_vector<VectorType>::value, int>>
template <typename VectorType>
void *
get_communicator_as_void_ptr(N_Vector v)
{
# ifndef DEAL_II_WITH_MPI
(void)v;
return nullptr;
# else
// We need to cast away const here, as SUNDIALS demands a pure `void *`.
return &(const_cast<MPI_Comm &>(get_communicator<VectorType>(v)));
if (is_serial_vector<VectorType>::value == false)
// We need to cast away const here, as SUNDIALS demands a pure
// `void*`.
return &(const_cast<MPI_Comm &>(get_communicator<VectorType>(v)));
else
return nullptr;
# endif
}

Expand Down
18 changes: 9 additions & 9 deletions tests/sundials/n_vector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,13 @@ namespace
*/
template <typename VectorType>
VectorType
create_test_vector(double value = 0.0);
create_test_vector(const double value = 0.0);



template <>
Vector<double>
create_test_vector(double value)
create_test_vector(const double value)
{
Vector<double> vector(3 /*size*/);
vector = value;
Expand All @@ -94,7 +94,7 @@ namespace

template <>
BlockVector<double>
create_test_vector(double value)
create_test_vector(const double value)
{
const int num_blocks = 2;
const int size_block = 3;
Expand All @@ -105,7 +105,7 @@ namespace

template <>
LinearAlgebra::distributed::Vector<double>
create_test_vector(double value)
create_test_vector(const double value)
{
IndexSet local_dofs = create_parallel_index_set();
LinearAlgebra::distributed::Vector<double> vector(local_dofs,
Expand All @@ -116,7 +116,7 @@ namespace

template <>
LinearAlgebra::distributed::BlockVector<double>
create_test_vector(double value)
create_test_vector(const double value)
{
const unsigned n_processes =
Utilities::MPI::n_mpi_processes(MPI_COMM_WORLD);
Expand All @@ -133,7 +133,7 @@ namespace

template <>
TrilinosWrappers::MPI::Vector
create_test_vector(double value)
create_test_vector(const double value)
{
IndexSet local_dofs = create_parallel_index_set();

Expand All @@ -144,7 +144,7 @@ namespace

template <>
TrilinosWrappers::MPI::BlockVector
create_test_vector(double value)
create_test_vector(const double value)
{
const unsigned n_processes =
Utilities::MPI::n_mpi_processes(MPI_COMM_WORLD);
Expand All @@ -159,7 +159,7 @@ namespace

template <>
PETScWrappers::MPI::Vector
create_test_vector(double value)
create_test_vector(const double value)
{
IndexSet local_dofs = create_parallel_index_set();

Expand All @@ -170,7 +170,7 @@ namespace

template <>
PETScWrappers::MPI::BlockVector
create_test_vector(double value)
create_test_vector(const double value)
{
const unsigned n_processes =
Utilities::MPI::n_mpi_processes(MPI_COMM_WORLD);
Expand Down