Skip to content

Commit

Permalink
Issue #3682 Separating getting x,z,q from Integrator::advance
Browse files Browse the repository at this point in the history
  • Loading branch information
jaeandersson committed May 9, 2024
1 parent 9c12782 commit 2c77ba1
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 22 deletions.
26 changes: 19 additions & 7 deletions casadi/core/integrator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,10 @@ int Integrator::eval(const double** arg, double** res,
// Advance solution
if (verbose_) casadi_message("Integrating forward to output time " + str(m->k) + ": t_next = "
+ str(m->t_next) + ", t_stop = " + str(m->t_stop));
advance(m, x, z, q);
advance(m);
get_x(m, x);
get_z(m, z);
get_q(m, q);
if (x) x += nx_;
if (z) z += nz_;
if (q) q += nq_;
Expand Down Expand Up @@ -1875,8 +1878,7 @@ int FixedStepIntegrator::init_mem(void* mem) const {
return 0;
}

void FixedStepIntegrator::advance(IntegratorMemory* mem,
double* x, double* z, double* q) const {
void FixedStepIntegrator::advance(IntegratorMemory* mem) const {
auto m = static_cast<FixedStepMemory*>(mem);

// State at previous step
Expand Down Expand Up @@ -1908,10 +1910,8 @@ void FixedStepIntegrator::advance(IntegratorMemory* mem,
}
}

// Return to user
casadi_copy(m->x, nx_, x);
casadi_copy(m->v + nv_ - nz_, nz_, z);
casadi_copy(m->q, nq_, q);
// Save algebraic variables
casadi_copy(m->v + nv_ - nz_, nz_, m->z);
}

void FixedStepIntegrator::retreat(IntegratorMemory* mem, const double* u,
Expand Down Expand Up @@ -2347,6 +2347,18 @@ void Integrator::set_u(IntegratorMemory* m, const double* u) const {
casadi_copy(u, nu_, m->u);
}

void Integrator::get_q(IntegratorMemory* m, double* q) const {
casadi_copy(m->q, nq_, q);
}

void Integrator::get_x(IntegratorMemory* m, double* x) const {
casadi_copy(m->x, nx_, x);
}

void Integrator::get_z(IntegratorMemory* m, double* z) const {
casadi_copy(m->z, nz_, z);
}

void Integrator::reset(IntegratorMemory* m) const {
}

Expand Down
13 changes: 11 additions & 2 deletions casadi/core/integrator_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,15 @@ Integrator : public OracleFunction, public PluginInterface<Integrator> {
// Set the controls
void set_u(IntegratorMemory* m, const double* u) const;

// Get the quadrature states
void get_q(IntegratorMemory* m, double* q) const;

// Get the differential states
void get_x(IntegratorMemory* m, double* x) const;

// Get the algebraic variables
void get_z(IntegratorMemory* m, double* z) const;

/** \brief Reset the forward problem
\identifier{25a} */
Expand All @@ -178,7 +187,7 @@ Integrator : public OracleFunction, public PluginInterface<Integrator> {
/** \brief Advance solution in time
\identifier{25c} */
virtual void advance(IntegratorMemory* mem, double* x, double* z, double* q) const = 0;
virtual void advance(IntegratorMemory* mem) const = 0;

/** \brief Reset the backward problem
Expand Down Expand Up @@ -559,7 +568,7 @@ class CASADI_EXPORT FixedStepIntegrator : public Integrator {
/** \brief Advance solution in time
\identifier{25j} */
void advance(IntegratorMemory* mem, double* x, double* z, double* q) const override;
void advance(IntegratorMemory* mem) const override;

/// Reset the backward problem and take time to tf
void resetB(IntegratorMemory* mem) const override;
Expand Down
6 changes: 2 additions & 4 deletions casadi/interfaces/sundials/cvodes_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -250,8 +250,7 @@ void CvodesInterface::reset(IntegratorMemory* mem) const {
}
}

void CvodesInterface::advance(IntegratorMemory* mem,
double* x, double* z, double* q) const {
void CvodesInterface::advance(IntegratorMemory* mem) const {
auto m = to_mem(mem);

// Do not integrate past change in input signals or past the end
Expand All @@ -278,8 +277,7 @@ void CvodesInterface::advance(IntegratorMemory* mem,

// Set function outputs
casadi_copy(NV_DATA_S(m->v_xz), nx_, m->x);
casadi_copy(m->x, nx_, x);
casadi_copy(NV_DATA_S(m->v_q), nq_, q);
casadi_copy(NV_DATA_S(m->v_q), nq_, m->q);

// Get stats
THROWING(CVodeGetIntegratorStats, m->mem, &m->nsteps, &m->nfevals, &m->nlinsetups,
Expand Down
3 changes: 1 addition & 2 deletions casadi/interfaces/sundials/cvodes_interface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,7 @@ class CASADI_INTEGRATOR_CVODES_EXPORT CvodesInterface : public SundialsInterface
void reset(IntegratorMemory* mem) const override;

/** \brief Advance solution in time */
void advance(IntegratorMemory* mem,
double* x, double* z, double* q) const override;
void advance(IntegratorMemory* mem) const override;

/** \brief Introduce an impulse into the backwards integration at the current time */
void impulseB(IntegratorMemory* mem,
Expand Down
8 changes: 2 additions & 6 deletions casadi/interfaces/sundials/idas_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -386,8 +386,7 @@ void IdasInterface::reset(IntegratorMemory* mem) const {
if (nadj_ > 0) THROWING(IDAAdjReInit, m->mem);
}

void IdasInterface::advance(IntegratorMemory* mem,
double* x, double* z, double* q) const {
void IdasInterface::advance(IntegratorMemory* mem) const {
auto m = to_mem(mem);

// Do not integrate past change in input signals or past the end
Expand All @@ -409,15 +408,12 @@ void IdasInterface::advance(IntegratorMemory* mem,

// Set function outputs
casadi_copy(NV_DATA_S(m->v_xz), nx_ + nz_, m->x);
casadi_copy(m->x, nx_, x);
casadi_copy(m->x + nx_, nz_, z);
casadi_copy(NV_DATA_S(m->v_q), nq_, q);
casadi_copy(NV_DATA_S(m->v_q), nq_, m->q);

// Get stats
THROWING(IDAGetIntegratorStats, m->mem, &m->nsteps, &m->nfevals, &m->nlinsetups,
&m->netfails, &m->qlast, &m->qcur, &m->hinused, &m->hlast, &m->hcur, &m->tcur);
THROWING(IDAGetNonlinSolvStats, m->mem, &m->nniters, &m->nncfails);

}

void IdasInterface::resetB(IntegratorMemory* mem) const {
Expand Down
2 changes: 1 addition & 1 deletion casadi/interfaces/sundials/idas_interface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ class CASADI_INTEGRATOR_IDAS_EXPORT IdasInterface : public SundialsInterface {
void reset(IntegratorMemory* mem) const override;

/** \brief Advance solution in time */
void advance(IntegratorMemory* mem, double* x, double* z, double* q) const override;
void advance(IntegratorMemory* mem) const override;

/** \brief Reset the backward problem and take time to tf */
void resetB(IntegratorMemory* mem) const override;
Expand Down

0 comments on commit 2c77ba1

Please sign in to comment.