Skip to content

Commit

Permalink
Compatibility with all Sundials versions
Browse files Browse the repository at this point in the history
  • Loading branch information
gkogekar authored and ischoegl committed Jun 14, 2023
1 parent 34e0996 commit 8678ec8
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 15 deletions.
3 changes: 2 additions & 1 deletion include/cantera/numerics/IDAIntegrator.h
Expand Up @@ -12,6 +12,7 @@
#include "cantera/numerics/Integrator.h"
#include "cantera/base/ctexceptions.h"
#include "sundials/sundials_nvector.h"
#include "cantera/numerics/SundialsContext.h"

namespace Cantera
{
Expand Down Expand Up @@ -88,7 +89,7 @@ class IDAIntegrator : public Integrator
void* m_ida_mem; //!< Pointer to the IDA memory for the problem
void* m_linsol; //!< Sundials linear solver object
void* m_linsol_matrix; //!< matrix used by Sundials
void * m_ctx; //!< contex object used by Sundials
SundialsContext m_sundials_ctx; //!< SUNContext object for Sundials>=6.0
FuncEval* m_func;
double m_t0;
double m_time; //!< The current integrator time
Expand Down
70 changes: 56 additions & 14 deletions src/numerics/IDAIntegrator.cpp
Expand Up @@ -9,6 +9,19 @@
#include "cantera/numerics/sundials_headers.h"

using namespace std;

namespace {

N_Vector newNVector(size_t N, Cantera::SundialsContext& context)
{
#if CT_SUNDIALS_VERSION >= 60
return N_VNew_Serial(static_cast<sd_size_t>(N), context.get());
#else
return N_VNew_Serial(static_cast<sd_size_t>(N));
#endif
}

} // end anonymous namespace
namespace Cantera
{

Expand Down Expand Up @@ -115,7 +128,7 @@ void IDAIntegrator::setTolerances(double reltol, size_t n, double* abstol)
if (m_abstol) {
N_VDestroy_Serial(m_abstol);
}
m_abstol = N_VNew_Serial(static_cast<sd_size_t>(n), (SUNContext) m_ctx);
m_abstol = newNVector(static_cast<sd_size_t>(n), m_sundials_ctx);
}
for (size_t i=0; i<n; i++) {
NV_Ith_S(m_abstol, i) = abstol[i];
Expand Down Expand Up @@ -232,13 +245,13 @@ void IDAIntegrator::initialize(double t0, FuncEval& func)
if (m_y) {
N_VDestroy_Serial(m_y); // free solution vector if already allocated
}
m_y = N_VNew_Serial(static_cast<sd_size_t>(m_neq), (SUNContext) m_ctx); // allocate solution vector
m_y = newNVector(static_cast<sd_size_t>(m_neq), m_sundials_ctx);
N_VConst(0.0, m_y);

if (m_ydot) {
N_VDestroy_Serial(m_ydot); // free derivative vector if already allocated
}
m_ydot = N_VNew_Serial(m_neq, (SUNContext) m_ctx);
m_ydot = newNVector(m_neq, m_sundials_ctx);
N_VConst(0.0, m_ydot);

// check abs tolerance array size
Expand All @@ -250,7 +263,7 @@ void IDAIntegrator::initialize(double t0, FuncEval& func)
if (m_constraints) {
N_VDestroy_Serial(m_constraints);
}
m_constraints = N_VNew_Serial(static_cast<sd_size_t>(m_neq), (SUNContext) m_ctx);
m_constraints = newNVector(static_cast<sd_size_t>(m_neq), m_sundials_ctx);
// set the constraints
func.getConstraints(NV_DATA_S(m_constraints));

Expand All @@ -263,7 +276,11 @@ void IDAIntegrator::initialize(double t0, FuncEval& func)
}

//! Create the IDA solver
m_ida_mem = IDACreate((SUNContext) m_ctx);
#if CT_SUNDIALS_VERSION >= 60
m_ida_mem = IDACreate(m_sundials_ctx.get());
#else
m_ida_mem = IDACreate();
#endif
if (!m_ida_mem) {
throw CanteraError("IDAIntegrator::initialize",
"IDACreate failed.");
Expand Down Expand Up @@ -350,11 +367,24 @@ void IDAIntegrator::applyOptions()
#if CT_SUNDIALS_VERSION >= 30
SUNLinSolFree((SUNLinearSolver) m_linsol);
SUNMatDestroy((SUNMatrix) m_linsol_matrix);
m_linsol_matrix = SUNDenseMatrix(N, N, (SUNContext) m_ctx);
#if CT_SUNDIALS_VERSION >= 60
m_linsol_matrix = SUNDenseMatrix(N, N, m_sundials_ctx.get());
#else
m_linsol_matrix = SUNDenseMatrix(N, N);
#endif
#if CT_SUNDIALS_VERSION >= 60
m_linsol_matrix = SUNDenseMatrix(N, N, m_sundials_ctx.get());
#else
m_linsol_matrix = SUNDenseMatrix(N, N);
#endif
#if CT_SUNDIALS_USE_LAPACK
m_linsol = SUNLapackDense(m_y, (SUNMatrix) m_linsol_matrix);
#else
m_linsol = SUNLinSol_Dense(m_y, (SUNMatrix) m_linsol_matrix, (SUNContext) m_ctx);
#if CT_SUNDIALS_VERSION >= 60
m_linsol = SUNLinSol_Dense(m_y, (SUNMatrix) m_linsol_matrix, m_sundials_ctx.get());
#else
m_linsol = SUNLinSol_Dense(m_y, (SUNMatrix) m_linsol_matrix);
#endif
#endif
IDASetLinearSolver(m_ida_mem, (SUNLinearSolver) m_linsol,
(SUNMatrix) m_linsol_matrix);
Expand All @@ -370,7 +400,11 @@ void IDAIntegrator::applyOptions()
"Cannot use a diagonal matrix with IDA.");
} else if (m_type == GMRES) {
#if CT_SUNDIALS_VERSION >= 30
m_linsol = SUNLinSol_SPGMR(m_y, PREC_NONE, 0, (SUNContext) m_ctx);
#if CT_SUNDIALS_VERSION >= 60
m_linsol = SUNLinSol_SPGMR(m_y, PREC_NONE, 0, m_sundials_ctx.get());
#else
m_linsol = SUNLinSol_SPGMR(m_y, PREC_NONE, 0);
#endif
IDASpilsSetLinearSolver(m_ida_mem, (SUNLinearSolver) m_linsol);
#else
IDASpgmr(m_ida_mem, 0);
Expand All @@ -382,11 +416,19 @@ void IDAIntegrator::applyOptions()
#if CT_SUNDIALS_VERSION >= 30
SUNLinSolFree((SUNLinearSolver) m_linsol);
SUNMatDestroy((SUNMatrix) m_linsol_matrix);
m_linsol_matrix = SUNBandMatrix(N, nu, nl, (SUNContext) m_ctx);
#if CT_SUNDIALS_VERSION >= 60
m_linsol_matrix = SUNBandMatrix(N, nu, nl, m_sundials_ctx.get());
#else
m_linsol_matrix = SUNBandMatrix(N, nu, nl);
#endif
#if CT_SUNDIALS_USE_LAPACK
m_linsol = SUNLapackBand(m_y, (SUNMatrix) m_linsol_matrix);
#else
m_linsol = SUNLinSol_Band(m_y, (SUNMatrix) m_linsol_matrix, (SUNContext) m_ctx);
#if CT_SUNDIALS_VERSION >= 60
m_linsol = SUNLinSol_Band(m_y, (SUNMatrix) m_linsol_matrix, m_sundials_ctx.get());
#else
m_linsol = SUNLinSol_Band(m_y, (SUNMatrix) m_linsol_matrix);
#endif
#endif
IDASetLinearSolver(m_ida_mem, (SUNLinearSolver) m_linsol,
(SUNMatrix) m_linsol_matrix);
Expand Down Expand Up @@ -442,13 +484,13 @@ void IDAIntegrator::sensInit(double t0, FuncEval& func)
m_np = func.nparams();
m_sens_ok = false;

N_Vector y = N_VNew_Serial(static_cast<sd_size_t>(func.neq()), (SUNContext) m_ctx);
N_Vector y = newNVector(static_cast<sd_size_t>(func.neq()), m_sundials_ctx);
m_yS = N_VCloneVectorArray_Serial(static_cast<sd_size_t>(m_np), y);
for (size_t n = 0; n < m_np; n++) {
N_VConst(0.0, m_yS[n]);
}
N_VDestroy_Serial(y);
N_Vector ydot = N_VNew_Serial(static_cast<sd_size_t>(func.neq()), (SUNContext) m_ctx);
N_Vector ydot = newNVector(static_cast<sd_size_t>(func.neq()), m_sundials_ctx);
m_ySdot = N_VCloneVectorArray_Serial(static_cast<sd_size_t>(m_np), ydot);
for (size_t n = 0; n < m_np; n++) {
N_VConst(0.0, m_ySdot[n]);
Expand Down Expand Up @@ -545,8 +587,8 @@ double IDAIntegrator::sensitivity(size_t k, size_t p)

string IDAIntegrator::getErrorInfo(int N)
{
N_Vector errs = N_VNew_Serial(static_cast<sd_size_t>(m_neq), (SUNContext) m_ctx);
N_Vector errw = N_VNew_Serial(static_cast<sd_size_t>(m_neq), (SUNContext) m_ctx);
N_Vector errs = newNVector(static_cast<sd_size_t>(m_neq), m_sundials_ctx);
N_Vector errw = newNVector(static_cast<sd_size_t>(m_neq), m_sundials_ctx);
IDAGetErrWeights(m_ida_mem, errw);
IDAGetEstLocalErrors(m_ida_mem, errs);

Expand Down

0 comments on commit 8678ec8

Please sign in to comment.