Skip to content

Commit

Permalink
Issue #3682 Refactored Integrator work vectors
Browse files Browse the repository at this point in the history
rp -> adj_qf
  • Loading branch information
jaeandersson committed May 8, 2024
1 parent fbd3a92 commit 9993291
Showing 1 changed file with 19 additions and 16 deletions.
35 changes: 19 additions & 16 deletions casadi/core/integrator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -689,14 +689,16 @@ void Integrator::init(const Dict& opts) {
// Work vectors for sparsity pattern propagation: Can be reused in derived classes
alloc_w(nx_ + nz_, true); // x, z
alloc_w(nx_, true); // ode
alloc_w(nrx_ + nrz_, true); // adj_x, rz
alloc_w(nrx_, true); // adj_ode
alloc_w(nrq_, true); // adj_p
alloc_w(nx_+nz_); // Sparsity::sp_solve
alloc_w(nrx_+nrz_); // Sparsity::sp_solve
alloc_w(np_, true); // p
alloc_w(nu_, true); // u
alloc_w(nrp_, true); // rp

alloc_w(nrx_ + nrz_, true); // adj_x, adj_z
alloc_w(nrx_, true); // adj_ode
alloc_w(nrq_, true); // adj_p
alloc_w(nrp_, true); // adj_q

alloc_w(nx_ + nz_); // Sparsity::sp_solve
alloc_w(nrx_ + nrz_); // Sparsity::sp_solve
}

void Integrator::set_work(void* mem, const double**& arg, double**& res,
Expand All @@ -710,12 +712,13 @@ void Integrator::set_work(void* mem, const double**& arg, double**& res,
m->x = w; w += nx_; // doubles as xz
m->z = w; w += nz_;
m->ode = w; w += nx_;
m->p = w; w += np_;
m->u = w; w += nu_;

m->adj_x = w; w += nrx_; // doubles as adj_xz
m->adj_z = w; w += nrz_;
m->adj_ode = w; w += nrx_;
m->adj_p = w; w += nrq_;
m->p = w; w += np_;
m->u = w; w += nu_;
m->adj_q = w; w += nrp_;
}

Expand Down Expand Up @@ -1120,7 +1123,7 @@ int Integrator::fquad_sp_reverse(SpReverseMem* m, bvec_t* x, bvec_t* z,
}

int Integrator::bdae_sp_reverse(SpReverseMem* m, bvec_t* x, bvec_t* z,
bvec_t* p, bvec_t* u, bvec_t* adj_ode, bvec_t* rp,
bvec_t* p, bvec_t* u, bvec_t* adj_ode, bvec_t* adj_quad,
bvec_t* adj_x, bvec_t* adj_z) const {
// Nondifferentiated inputs
m->arg[BDYN_T] = nullptr; // t
Expand All @@ -1134,7 +1137,7 @@ int Integrator::bdae_sp_reverse(SpReverseMem* m, bvec_t* x, bvec_t* z,
m->arg[BDYN_OUT_ZERO] = nullptr; // out_zero
m->arg[BDYN_ADJ_ODE] = adj_ode; // adj_ode
m->arg[BDYN_ADJ_ALG] = nullptr; // adj_alg
m->arg[BDYN_ADJ_QUAD] = rp; // adj_quad
m->arg[BDYN_ADJ_QUAD] = adj_quad; // adj_quad
m->arg[BDYN_ADJ_ZERO] = nullptr; // adj_zero
// Propagate through sensitivities
for (casadi_int i = 0; i < nfwd_; ++i) {
Expand All @@ -1155,7 +1158,7 @@ int Integrator::bdae_sp_reverse(SpReverseMem* m, bvec_t* x, bvec_t* z,
adj_ode + (i + 1) * nrx1_ * nadj_; // fwd:adj_ode
m->arg[BDYN_NUM_IN + BDAE_NUM_OUT + BDYN_ADJ_ALG] = nullptr; // fwd:adj_alg
m->arg[BDYN_NUM_IN + BDAE_NUM_OUT + BDYN_ADJ_QUAD] =
rp + (i + 1) * nrz1_ * nadj_; // fwd:adj_quad
adj_quad + (i + 1) * nrz1_ * nadj_; // fwd:adj_quad
m->arg[BDYN_NUM_IN + BDAE_NUM_OUT + BDYN_ADJ_ZERO] = nullptr; // fwd:adj_zero
if (calc_sp_reverse(forward_name("daeB", 1), m->arg, m->res, m->iw, m->w)) return 1;
}
Expand Down Expand Up @@ -1223,7 +1226,7 @@ int Integrator::sp_reverse(bvec_t** arg, bvec_t** res,
bvec_t* p = arg[INTEGRATOR_P];
bvec_t* u = arg[INTEGRATOR_U];
bvec_t* adj_xf = arg[INTEGRATOR_ADJ_XF];
bvec_t* rp = arg[INTEGRATOR_ADJ_QF];
bvec_t* adj_qf = arg[INTEGRATOR_ADJ_QF];
arg += n_in_;

// Outputs
Expand Down Expand Up @@ -1278,27 +1281,27 @@ int Integrator::sp_reverse(bvec_t** arg, bvec_t** res,

// Get dependencies from backward quadratures
if ((nrq_ > 0 && adj_p0) || (nuq_ > 0 && adj_u)) {
if (bquad_sp_reverse(&m, ode, alg, p, u, adj_x, adj_z, rp, adj_p0, adj_u)) return 1;
if (bquad_sp_reverse(&m, ode, alg, p, u, adj_x, adj_z, adj_qf, adj_p0, adj_u)) return 1;
}

// Propagate interdependencies
std::fill_n(w, nrx_ + nrz_, 0);
sp_jac_rdae_.spsolve(w, adj_x, true);
std::copy_n(w, nrx_+nrz_, adj_x);
std::copy_n(w, nrx_ + nrz_, adj_x);

// Direct dependency adj_ode -> adj_x
std::copy_n(adj_x, nrx_, adj_ode);

// Indirect dependency via g
if (bdae_sp_reverse(&m, ode, alg, p, u, adj_ode, rp, adj_x, adj_z)) return 1;
if (bdae_sp_reverse(&m, ode, alg, p, u, adj_ode, adj_qf, adj_x, adj_z)) return 1;

// Update adj_x, adj_z
std::copy_n(adj_ode, nrx_, adj_x);
std::fill_n(adj_z, nrz_, 0);

// Shift time
if (adj_xf) adj_xf += nrx_;
if (rp) rp += nrp_;
if (adj_qf) adj_qf += nrp_;
if (adj_u) adj_u += nuq_;
if (u) u += nu_;
}
Expand Down

0 comments on commit 9993291

Please sign in to comment.