Skip to content

Commit

Permalink
IDA solver modifications to be compatible with the latest
Browse files Browse the repository at this point in the history
Sundials versions

Co-authored-by: Nick Curtis <arghdos@gmail.com>
  • Loading branch information
2 people authored and ischoegl committed Jun 14, 2023
1 parent 568cf48 commit f1922af
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 16 deletions.
1 change: 1 addition & 0 deletions include/cantera/numerics/IDAIntegrator.h
Expand Up @@ -88,6 +88,7 @@ class IDAIntegrator : public Integrator
void* m_ida_mem; //!< Pointer to the IDA memory for the problem
void* m_linsol; //!< Sundials linear solver object
void* m_linsol_matrix; //!< matrix used by Sundials
void * m_ctx; //!< contex object used by Sundials
FuncEval* m_func;
double m_t0;
double m_time; //!< The current integrator time
Expand Down
32 changes: 16 additions & 16 deletions src/numerics/IDAIntegrator.cpp
Expand Up @@ -115,7 +115,7 @@ void IDAIntegrator::setTolerances(double reltol, size_t n, double* abstol)
if (m_abstol) {
N_VDestroy_Serial(m_abstol);
}
m_abstol = N_VNew_Serial(static_cast<sd_size_t>(n));
m_abstol = N_VNew_Serial(static_cast<sd_size_t>(n), (SUNContext) m_ctx);
}
for (size_t i=0; i<n; i++) {
NV_Ith_S(m_abstol, i) = abstol[i];
Expand Down Expand Up @@ -232,14 +232,14 @@ void IDAIntegrator::initialize(double t0, FuncEval& func)
if (m_y) {
N_VDestroy_Serial(m_y); // free solution vector if already allocated
}
m_y = N_VNew_Serial(static_cast<sd_size_t>(m_neq)); // allocate solution vector
m_y = N_VNew_Serial(static_cast<sd_size_t>(m_neq), (SUNContext) m_ctx); // allocate solution vector
N_VConst(0.0, m_y);

if (m_ydot)
{
N_VDestroy_Serial(m_ydot); // free derivative vector if already allocated
}
m_ydot = N_VNew_Serial(m_neq);
m_ydot = N_VNew_Serial(m_neq, (SUNContext) m_ctx);
N_VConst(0.0, m_ydot);

// check abs tolerance array size
Expand All @@ -251,7 +251,7 @@ void IDAIntegrator::initialize(double t0, FuncEval& func)
if (m_constraints) {
N_VDestroy_Serial(m_constraints);
}
m_constraints = N_VNew_Serial(static_cast<sd_size_t>(m_neq));
m_constraints = N_VNew_Serial(static_cast<sd_size_t>(m_neq), (SUNContext) m_ctx);
// set the constraints
func.getConstraints(NV_DATA_S(m_constraints));

Expand All @@ -264,7 +264,7 @@ void IDAIntegrator::initialize(double t0, FuncEval& func)
}

//! Create the IDA solver
m_ida_mem = IDACreate();
m_ida_mem = IDACreate((SUNContext) m_ctx);
if (!m_ida_mem) {
throw CanteraError("IDAIntegrator::initialize",
"IDACreate failed.");
Expand Down Expand Up @@ -351,13 +351,13 @@ void IDAIntegrator::applyOptions()
#if CT_SUNDIALS_VERSION >= 30
SUNLinSolFree((SUNLinearSolver) m_linsol);
SUNMatDestroy((SUNMatrix) m_linsol_matrix);
m_linsol_matrix = SUNDenseMatrix(N, N);
m_linsol_matrix = SUNDenseMatrix(N, N, (SUNContext) m_ctx);
#if CT_SUNDIALS_USE_LAPACK
m_linsol = SUNLapackDense(m_y, (SUNMatrix) m_linsol_matrix);
#else
m_linsol = SUNDenseLinearSolver(m_y, (SUNMatrix) m_linsol_matrix);
m_linsol = SUNLinSol_Dense(m_y, (SUNMatrix) m_linsol_matrix, (SUNContext) m_ctx);
#endif
IDADlsSetLinearSolver(m_ida_mem, (SUNLinearSolver) m_linsol,
IDASetLinearSolver(m_ida_mem, (SUNLinearSolver) m_linsol,
(SUNMatrix) m_linsol_matrix);
#else
#if CT_SUNDIALS_USE_LAPACK
Expand All @@ -371,7 +371,7 @@ void IDAIntegrator::applyOptions()
"Cannot use a diagonal matrix with IDA.");
} else if (m_type == GMRES) {
#if CT_SUNDIALS_VERSION >= 30
m_linsol = SUNSPGMR(m_y, PREC_NONE, 0);
m_linsol = SUNLinSol_SPGMR(m_y, PREC_NONE, 0, (SUNContext) m_ctx);
IDASpilsSetLinearSolver(m_ida_mem, (SUNLinearSolver) m_linsol);
#else
IDASpgmr(m_ida_mem, PREC_NONE, 0);
Expand All @@ -383,13 +383,13 @@ void IDAIntegrator::applyOptions()
#if CT_SUNDIALS_VERSION >= 30
SUNLinSolFree((SUNLinearSolver) m_linsol);
SUNMatDestroy((SUNMatrix) m_linsol_matrix);
m_linsol_matrix = SUNBandMatrix(N, nu, nl, nu+nl);
m_linsol_matrix = SUNBandMatrix(N, nu, nl, (SUNContext) m_ctx);
#if CT_SUNDIALS_USE_LAPACK
m_linsol = SUNLapackBand(m_y, (SUNMatrix) m_linsol_matrix);
#else
m_linsol = SUNBandLinearSolver(m_y, (SUNMatrix) m_linsol_matrix);
m_linsol = SUNLinSol_Band(m_y, (SUNMatrix) m_linsol_matrix, (SUNContext) m_ctx);
#endif
IDADlsSetLinearSolver(m_ida_mem, (SUNLinearSolver) m_linsol,
IDASetLinearSolver(m_ida_mem, (SUNLinearSolver) m_linsol,
(SUNMatrix) m_linsol_matrix);
#else
#if CT_SUNDIALS_USE_LAPACK
Expand Down Expand Up @@ -443,13 +443,13 @@ void IDAIntegrator::sensInit(double t0, FuncEval& func)
m_np = func.nparams();
m_sens_ok = false;

N_Vector y = N_VNew_Serial(static_cast<sd_size_t>(func.neq()));
N_Vector y = N_VNew_Serial(static_cast<sd_size_t>(func.neq()), (SUNContext) m_ctx);
m_yS = N_VCloneVectorArray_Serial(static_cast<sd_size_t>(m_np), y);
for (size_t n = 0; n < m_np; n++) {
N_VConst(0.0, m_yS[n]);
}
N_VDestroy_Serial(y);
N_Vector ydot = N_VNew_Serial(static_cast<sd_size_t>(func.neq()));
N_Vector ydot = N_VNew_Serial(static_cast<sd_size_t>(func.neq()), (SUNContext) m_ctx);
m_ySdot = N_VCloneVectorArray_Serial(static_cast<sd_size_t>(m_np), ydot);
for (size_t n = 0; n < m_np; n++) {
N_VConst(0.0, m_ySdot[n]);
Expand Down Expand Up @@ -546,8 +546,8 @@ double IDAIntegrator::sensitivity(size_t k, size_t p)

string IDAIntegrator::getErrorInfo(int N)
{
N_Vector errs = N_VNew_Serial(static_cast<sd_size_t>(m_neq));
N_Vector errw = N_VNew_Serial(static_cast<sd_size_t>(m_neq));
N_Vector errs = N_VNew_Serial(static_cast<sd_size_t>(m_neq), (SUNContext) m_ctx);
N_Vector errw = N_VNew_Serial(static_cast<sd_size_t>(m_neq), (SUNContext) m_ctx);
IDAGetErrWeights(m_ida_mem, errw);
IDAGetEstLocalErrors(m_ida_mem, errs);

Expand Down

0 comments on commit f1922af

Please sign in to comment.