diff --git a/src/solvers/petsc_linear_solver.C b/src/solvers/petsc_linear_solver.C index 81bea18e840..65a51b2b75b 100644 --- a/src/solvers/petsc_linear_solver.C +++ b/src/solvers/petsc_linear_solver.C @@ -36,6 +36,10 @@ #include "libmesh/enum_solver_type.h" #include "libmesh/enum_convergence_flags.h" +#ifdef LIBMESH_HAVE_PETSC_HYPRE +#include +#endif + // C++ includes #include #include @@ -571,6 +575,19 @@ PetscLinearSolver::solve_base (SparseMatrix * matrix, // Allow command line options to override anything set programmatically. LibmeshPetscCall(KSPSetFromOptions(_ksp)); +#if defined(LIBMESH_HAVE_PETSC_HYPRE) && !PETSC_VERSION_LESS_THAN(3,12,0) && defined(PETSC_HAVE_HYPRE_DEVICE) + { + // Make sure hypre has been initialized + LibmeshPetscCallExternal(HYPRE_Initialize); + PetscScalar * dummyarray; + PetscMemType mtype; + LibmeshPetscCall(VecGetArrayAndMemType(solution->vec(), &dummyarray, &mtype)); + LibmeshPetscCall(VecRestoreArrayAndMemType(solution->vec(), &dummyarray)); + if (PetscMemTypeHost(mtype)) + LibmeshPetscCallExternal(HYPRE_SetMemoryLocation, HYPRE_MEMORY_HOST); + } +#endif + // If the SolverConfiguration object is provided, use it to override // solver options. if (this->_solver_configuration) diff --git a/src/solvers/petsc_nonlinear_solver.C b/src/solvers/petsc_nonlinear_solver.C index 79e3a980707..37a8578661f 100644 --- a/src/solvers/petsc_nonlinear_solver.C +++ b/src/solvers/petsc_nonlinear_solver.C @@ -1073,22 +1073,17 @@ PetscNonlinearSolver::solve (SparseMatrix & pre_in, // System Preconditi #endif LibmeshPetscCall(SNESSetFromOptions(_snes)); -#if defined(LIBMESH_HAVE_PETSC_HYPRE) && !PETSC_VERSION_LESS_THAN(3,12,0) - // The above call set our PC type. If we're a hypre type we have to ensure that hypre is deployed - // in the same memory space as our vector types - PC pc; - LibmeshPetscCall(KSPGetPC(ksp, &pc)); - PetscBool is_hypre; - LibmeshPetscCall(PetscObjectTypeCompare((PetscObject)pc, PCHYPRE, &is_hypre)); - if (is_hypre == PETSC_TRUE) - { - PetscScalar * dummyarray; - PetscMemType mtype; - LibmeshPetscCall(VecGetArrayAndMemType(x->vec(), &dummyarray, &mtype)); - LibmeshPetscCall(VecRestoreArrayAndMemType(x->vec(), &dummyarray)); - if (PetscMemTypeHost(mtype)) - LibmeshPetscCallExternal(HYPRE_SetMemoryLocation, HYPRE_MEMORY_HOST); - } +#if defined(LIBMESH_HAVE_PETSC_HYPRE) && !PETSC_VERSION_LESS_THAN(3,12,0) && defined(PETSC_HAVE_HYPRE_DEVICE) + { + // Make sure hypre has been initialized + LibmeshPetscCallExternal(HYPRE_Initialize); + PetscScalar * dummyarray; + PetscMemType mtype; + LibmeshPetscCall(VecGetArrayAndMemType(x->vec(), &dummyarray, &mtype)); + LibmeshPetscCall(VecRestoreArrayAndMemType(x->vec(), &dummyarray)); + if (PetscMemTypeHost(mtype)) + LibmeshPetscCallExternal(HYPRE_SetMemoryLocation, HYPRE_MEMORY_HOST); + } #endif if (this->user_presolve)