Skip to content

Commit

Permalink
Issue #3682 Refactored Integrator work vectors
Browse files Browse the repository at this point in the history
rx -> adj_x in IntegratorMemory
  • Loading branch information
jaeandersson committed May 8, 2024
1 parent f2c49cb commit fa72bf2
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 33 deletions.
38 changes: 19 additions & 19 deletions casadi/core/integrator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ int Integrator::eval(const double** arg, double** res,
double* z = res[INTEGRATOR_ZF];
double* q = res[INTEGRATOR_QF];
double* adj_x = res[INTEGRATOR_ADJ_X0];
double* rq = res[INTEGRATOR_ADJ_P];
double* adj_p = res[INTEGRATOR_ADJ_P];
double* adj_u = res[INTEGRATOR_ADJ_U];
res += INTEGRATOR_NUM_OUT;

Expand Down Expand Up @@ -429,15 +429,15 @@ int Integrator::eval(const double** arg, double** res,
if (m->k > 0) {
retreat(m, u, 0, 0, adj_u);
} else {
retreat(m, u, adj_x, rq, adj_u);
retreat(m, u, adj_x, adj_p, adj_u);
}
} else {
if (verbose_) casadi_message("No adjoint seeds from output time " + str(m->k)
+ ": t_next = " + str(m->t_next) + ", t_stop = " + str(m->t_stop));
casadi_clear(adj_u, nuq_);
if (m->k == 0) {
casadi_clear(adj_x, nrx_);
casadi_clear(rq, nrq_);
casadi_clear(adj_p, nrq_);
}
}
}
Expand Down Expand Up @@ -691,7 +691,7 @@ void Integrator::init(const Dict& opts) {
alloc_w(nx_, true); // ode
alloc_w(nrx_ + nrz_, true); // adj_x, rz
alloc_w(nrx_, true); // adj_ode
alloc_w(nrq_, true); // rq
alloc_w(nrq_, true); // adj_p
alloc_w(nx_+nz_); // Sparsity::sp_solve
alloc_w(nrx_+nrz_); // Sparsity::sp_solve
alloc_w(np_, true); // p
Expand All @@ -713,7 +713,7 @@ void Integrator::set_work(void* mem, const double**& arg, double**& res,
m->adj_x = w; w += nrx_; // doubles as adj_xz
m->adj_z = w; w += nrz_;
m->adj_ode = w; w += nrx_;
m->rq = w; w += nrq_;
m->adj_p = w; w += nrq_;
m->p = w; w += np_;
m->u = w; w += nu_;
m->adj_q = w; w += nrp_;
Expand Down Expand Up @@ -978,7 +978,7 @@ int Integrator::sp_forward(const bvec_t** arg, bvec_t** res,
bvec_t *adj_x = w; w += nrx_;
bvec_t *rz = w; w += nrz_;
bvec_t *adj_ode = w; w += nrx_;
bvec_t *rq = w; w += nrq_;
bvec_t *adj_p = w; w += nrq_;

// Memory struct for function calls below
SpForwardMem m = {arg, res, iw, w};
Expand Down Expand Up @@ -1049,10 +1049,10 @@ int Integrator::sp_forward(const bvec_t** arg, bvec_t** res,

// Propagate to quadratures
if ((nrq_ > 0 && adj_p0) || (nuq_ > 0 && adj_u)) {
if (bquad_sp_forward(&m, ode, alg, p, u, adj_x, rz, adj_qf, rq, adj_u)) return 1;
if (bquad_sp_forward(&m, ode, alg, p, u, adj_x, rz, adj_qf, adj_p, adj_u)) return 1;
// Sum contributions to adj_p0
if (adj_p0) {
for (casadi_int i = 0; i < nrq_; ++i) adj_p0[i] |= rq[i];
for (casadi_int i = 0; i < nrq_; ++i) adj_p0[i] |= adj_p[i];
}
}

Expand Down Expand Up @@ -1242,7 +1242,7 @@ int Integrator::sp_reverse(bvec_t** arg, bvec_t** res,
bvec_t *adj_x = w; w += nrx_;
bvec_t *rz = w; w += nrz_;
bvec_t *adj_ode = w; w += nrx_;
bvec_t *rq = w; w += nrq_;
bvec_t *adj_p = w; w += nrq_;

// Memory struct for function calls below
SpReverseMem m = {arg, res, iw, w};
Expand All @@ -1263,12 +1263,12 @@ int Integrator::sp_reverse(bvec_t** arg, bvec_t** res,
std::fill_n(rz, nrz_, 0);

// Save adj_p0: See note below
if (adj_p0) std::copy_n(adj_p0, nrq_, rq);
if (adj_p0) std::copy_n(adj_p0, nrq_, adj_p);

// Step backwards through backward problem
for (casadi_int k = 0; k < nt(); ++k) {
// Restore adj_p0: See note below
if (adj_p0) std::copy_n(rq, nrq_, adj_p0);
if (adj_p0) std::copy_n(adj_p, nrq_, adj_p0);

// Add impulse from adj_xf
if (adj_xf) {
Expand Down Expand Up @@ -1798,7 +1798,7 @@ void FixedStepIntegrator::init(const Dict& opts) {
// Work vectors, backward problem
alloc_w(nrv_, true); // rv
alloc_w(nuq_, true); // adj_u
alloc_w(nrq_, true); // rq_prev
alloc_w(nrq_, true); // adj_p_prev
alloc_w(nuq_, true); // adj_u_prev

// Allocate tape if backward states are present
Expand All @@ -1824,7 +1824,7 @@ void FixedStepIntegrator::set_work(void* mem, const double**& arg, double**& res
// Work vectors, backward problem
m->rv = w; w += nrv_;
m->adj_u = w; w += nuq_;
m->rq_prev = w; w += nrq_;
m->adj_p_prev = w; w += nrq_;
m->adj_u_prev = w; w += nuq_;

// Allocate tape if backward states are present
Expand Down Expand Up @@ -1881,7 +1881,7 @@ void FixedStepIntegrator::advance(IntegratorMemory* mem,
}

void FixedStepIntegrator::retreat(IntegratorMemory* mem, const double* u,
double* adj_x, double* rq, double* adj_u) const {
double* adj_x, double* adj_p, double* adj_u) const {
auto m = static_cast<FixedStepMemory*>(mem);

// Set controls
Expand All @@ -1898,23 +1898,23 @@ void FixedStepIntegrator::retreat(IntegratorMemory* mem, const double* u,

// Update the previous step
casadi_copy(m->adj_x, nrx_, m->adj_ode);
casadi_copy(m->rq, nrq_, m->rq_prev);
casadi_copy(m->adj_p, nrq_, m->adj_p_prev);
casadi_copy(m->adj_u, nuq_, m->adj_u_prev);

// Take step
casadi_int tapeind = disc_[m->k] + j;
stepB(m, t, h,
m->x_tape + nx_ * tapeind, m->x_tape + nx_ * (tapeind + 1),
m->v_tape + nv_ * tapeind,
m->adj_ode, m->rv, m->adj_x, m->rq, m->adj_u);
m->adj_ode, m->rv, m->adj_x, m->adj_p, m->adj_u);
casadi_clear(m->rv, nrv_);
casadi_axpy(nrq_, 1., m->rq_prev, m->rq);
casadi_axpy(nrq_, 1., m->adj_p_prev, m->adj_p);
casadi_axpy(nuq_, 1., m->adj_u_prev, m->adj_u);
}

// Return to user
casadi_copy(m->adj_x, nrx_, adj_x);
casadi_copy(m->rq, nrq_, rq);
casadi_copy(m->adj_p, nrq_, adj_p);
casadi_copy(m->adj_u, nuq_, adj_u);
}

Expand Down Expand Up @@ -2034,7 +2034,7 @@ void FixedStepIntegrator::resetB(IntegratorMemory* mem) const {
casadi_clear(m->adj_x, nrx_);

// Reset summation states
casadi_clear(m->rq, nrq_);
casadi_clear(m->adj_p, nrq_);
casadi_clear(m->adj_u, nuq_);

// Update backwards dependent variables
Expand Down
8 changes: 4 additions & 4 deletions casadi/core/integrator_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ namespace casadi {
\identifier{1lp} */
struct CASADI_EXPORT IntegratorMemory : public OracleMemory {
// Work vectors
double *x, *z, *adj_x, *adj_z, *rq, *ode, *adj_ode, *p, *u, *adj_q;
double *x, *z, *adj_x, *adj_z, *adj_p, *ode, *adj_ode, *p, *u, *adj_q;
// Current control interval
casadi_int k;
// Current time
Expand Down Expand Up @@ -177,7 +177,7 @@ Integrator : public OracleFunction, public PluginInterface<Integrator> {
\identifier{25g} */
virtual void retreat(IntegratorMemory* mem, const double* u,
double* adj_x, double* rq, double* adj_u) const = 0;
double* adj_x, double* adj_p, double* adj_u) const = 0;

/** \brief evaluate
Expand Down Expand Up @@ -473,7 +473,7 @@ struct CASADI_EXPORT FixedStepMemory : public IntegratorMemory {
double *v, *q, *v_prev, *q_prev;

/// Work vectors, backward problem
double *rv, *adj_u, *rq_prev, *adj_u_prev;
double *rv, *adj_u, *adj_p_prev, *adj_u_prev;

/// State and dependent variables at all times
double *x_tape, *v_tape;
Expand Down Expand Up @@ -550,7 +550,7 @@ class CASADI_EXPORT FixedStepIntegrator : public Integrator {
\identifier{25k} */
void retreat(IntegratorMemory* mem, const double* u,
double* adj_x, double* rq, double* adj_u) const override;
double* adj_x, double* adj_p, double* adj_u) const override;

/// Take integrator step forward
void stepF(FixedStepMemory* m, double t, double h,
Expand Down
8 changes: 4 additions & 4 deletions casadi/interfaces/sundials/cvodes_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ void CvodesInterface::impulseB(IntegratorMemory* mem,
}

void CvodesInterface::retreat(IntegratorMemory* mem, const double* u,
double* rx, double* rq, double* uq) const {
double* adj_x, double* adj_p, double* adj_u) const {
auto m = to_mem(mem);

// Set controls
Expand All @@ -365,9 +365,9 @@ void CvodesInterface::retreat(IntegratorMemory* mem, const double* u,
}

// Save outputs
casadi_copy(NV_DATA_S(m->rxz), nrx_, rx);
casadi_copy(NV_DATA_S(m->ruq), nrq_, rq);
casadi_copy(NV_DATA_S(m->ruq) + nrq_, nuq_, uq);
casadi_copy(NV_DATA_S(m->rxz), nrx_, adj_x);
casadi_copy(NV_DATA_S(m->ruq), nrq_, adj_p);
casadi_copy(NV_DATA_S(m->ruq) + nrq_, nuq_, adj_u);

// Get stats
CVodeMem cv_mem = static_cast<CVodeMem>(m->mem);
Expand Down
2 changes: 1 addition & 1 deletion casadi/interfaces/sundials/cvodes_interface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ class CASADI_INTEGRATOR_CVODES_EXPORT CvodesInterface : public SundialsInterface

/** \brief Retreat solution in time */
void retreat(IntegratorMemory* mem, const double* u,
double* rx, double* rq, double* uq) const override;
double* adj_x, double* adj_p, double* adj_u) const override;

/** \brief Cast to memory object */
static CvodesMemory* to_mem(void *mem) {
Expand Down
8 changes: 4 additions & 4 deletions casadi/interfaces/sundials/idas_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,7 @@ void IdasInterface::impulseB(IntegratorMemory* mem,
}

void IdasInterface::retreat(IntegratorMemory* mem, const double* u,
double* rx, double* rq, double* uq) const {
double* adj_x, double* adj_p, double* adj_u) const {
auto m = to_mem(mem);

// Set controls
Expand All @@ -569,9 +569,9 @@ void IdasInterface::retreat(IntegratorMemory* mem, const double* u,
}

// Save outputs
casadi_copy(NV_DATA_S(m->rxz), nrx_, rx);
casadi_copy(NV_DATA_S(m->ruq), nrq_, rq);
casadi_copy(NV_DATA_S(m->ruq) + nrq_, nuq_, uq);
casadi_copy(NV_DATA_S(m->rxz), nrx_, adj_x);
casadi_copy(NV_DATA_S(m->ruq), nrq_, adj_p);
casadi_copy(NV_DATA_S(m->ruq) + nrq_, nuq_, adj_u);

// Get stats
IDAMem IDA_mem = IDAMem(m->mem);
Expand Down
2 changes: 1 addition & 1 deletion casadi/interfaces/sundials/idas_interface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ class CASADI_INTEGRATOR_IDAS_EXPORT IdasInterface : public SundialsInterface {

/** \brief Retreat solution in time */
void retreat(IntegratorMemory* mem, const double* u,
double* rx, double* rq, double* uq) const override;
double* adj_x, double* adj_p, double* adj_u) const override;

/** \brief Cast to memory object */
static IdasMemory* to_mem(void *mem) {
Expand Down

0 comments on commit fa72bf2

Please sign in to comment.