diff --git a/casadi/core/integrator.cpp b/casadi/core/integrator.cpp index fc7a7412b3..10331fc641 100644 --- a/casadi/core/integrator.cpp +++ b/casadi/core/integrator.cpp @@ -689,14 +689,16 @@ void Integrator::init(const Dict& opts) { // Work vectors for sparsity pattern propagation: Can be reused in derived classes alloc_w(nx_ + nz_, true); // x, z alloc_w(nx_, true); // ode - alloc_w(nrx_ + nrz_, true); // adj_x, rz - alloc_w(nrx_, true); // adj_ode - alloc_w(nrq_, true); // adj_p - alloc_w(nx_+nz_); // Sparsity::sp_solve - alloc_w(nrx_+nrz_); // Sparsity::sp_solve alloc_w(np_, true); // p alloc_w(nu_, true); // u - alloc_w(nrp_, true); // rp + + alloc_w(nrx_ + nrz_, true); // adj_x, adj_z + alloc_w(nrx_, true); // adj_ode + alloc_w(nrq_, true); // adj_p + alloc_w(nrp_, true); // adj_q + + alloc_w(nx_ + nz_); // Sparsity::sp_solve + alloc_w(nrx_ + nrz_); // Sparsity::sp_solve } void Integrator::set_work(void* mem, const double**& arg, double**& res, @@ -710,12 +712,13 @@ void Integrator::set_work(void* mem, const double**& arg, double**& res, m->x = w; w += nx_; // doubles as xz m->z = w; w += nz_; m->ode = w; w += nx_; + m->p = w; w += np_; + m->u = w; w += nu_; + m->adj_x = w; w += nrx_; // doubles as adj_xz m->adj_z = w; w += nrz_; m->adj_ode = w; w += nrx_; m->adj_p = w; w += nrq_; - m->p = w; w += np_; - m->u = w; w += nu_; m->adj_q = w; w += nrp_; } @@ -1120,7 +1123,7 @@ int Integrator::fquad_sp_reverse(SpReverseMem* m, bvec_t* x, bvec_t* z, } int Integrator::bdae_sp_reverse(SpReverseMem* m, bvec_t* x, bvec_t* z, - bvec_t* p, bvec_t* u, bvec_t* adj_ode, bvec_t* rp, + bvec_t* p, bvec_t* u, bvec_t* adj_ode, bvec_t* adj_quad, bvec_t* adj_x, bvec_t* adj_z) const { // Nondifferentiated inputs m->arg[BDYN_T] = nullptr; // t @@ -1134,7 +1137,7 @@ int Integrator::bdae_sp_reverse(SpReverseMem* m, bvec_t* x, bvec_t* z, m->arg[BDYN_OUT_ZERO] = nullptr; // out_zero m->arg[BDYN_ADJ_ODE] = adj_ode; // adj_ode m->arg[BDYN_ADJ_ALG] = nullptr; // adj_alg - m->arg[BDYN_ADJ_QUAD] = rp; // adj_quad + m->arg[BDYN_ADJ_QUAD] = adj_quad; // adj_quad m->arg[BDYN_ADJ_ZERO] = nullptr; // adj_zero // Propagate through sensitivities for (casadi_int i = 0; i < nfwd_; ++i) { @@ -1155,7 +1158,7 @@ int Integrator::bdae_sp_reverse(SpReverseMem* m, bvec_t* x, bvec_t* z, adj_ode + (i + 1) * nrx1_ * nadj_; // fwd:adj_ode m->arg[BDYN_NUM_IN + BDAE_NUM_OUT + BDYN_ADJ_ALG] = nullptr; // fwd:adj_alg m->arg[BDYN_NUM_IN + BDAE_NUM_OUT + BDYN_ADJ_QUAD] = - rp + (i + 1) * nrz1_ * nadj_; // fwd:adj_quad + adj_quad + (i + 1) * nrz1_ * nadj_; // fwd:adj_quad m->arg[BDYN_NUM_IN + BDAE_NUM_OUT + BDYN_ADJ_ZERO] = nullptr; // fwd:adj_zero if (calc_sp_reverse(forward_name("daeB", 1), m->arg, m->res, m->iw, m->w)) return 1; } @@ -1223,7 +1226,7 @@ int Integrator::sp_reverse(bvec_t** arg, bvec_t** res, bvec_t* p = arg[INTEGRATOR_P]; bvec_t* u = arg[INTEGRATOR_U]; bvec_t* adj_xf = arg[INTEGRATOR_ADJ_XF]; - bvec_t* rp = arg[INTEGRATOR_ADJ_QF]; + bvec_t* adj_qf = arg[INTEGRATOR_ADJ_QF]; arg += n_in_; // Outputs @@ -1278,19 +1281,19 @@ int Integrator::sp_reverse(bvec_t** arg, bvec_t** res, // Get dependencies from backward quadratures if ((nrq_ > 0 && adj_p0) || (nuq_ > 0 && adj_u)) { - if (bquad_sp_reverse(&m, ode, alg, p, u, adj_x, adj_z, rp, adj_p0, adj_u)) return 1; + if (bquad_sp_reverse(&m, ode, alg, p, u, adj_x, adj_z, adj_qf, adj_p0, adj_u)) return 1; } // Propagate interdependencies std::fill_n(w, nrx_ + nrz_, 0); sp_jac_rdae_.spsolve(w, adj_x, true); - std::copy_n(w, nrx_+nrz_, adj_x); + std::copy_n(w, nrx_ + nrz_, adj_x); // Direct dependency adj_ode -> adj_x std::copy_n(adj_x, nrx_, adj_ode); // Indirect dependency via g - if (bdae_sp_reverse(&m, ode, alg, p, u, adj_ode, rp, adj_x, adj_z)) return 1; + if (bdae_sp_reverse(&m, ode, alg, p, u, adj_ode, adj_qf, adj_x, adj_z)) return 1; // Update adj_x, adj_z std::copy_n(adj_ode, nrx_, adj_x); @@ -1298,7 +1301,7 @@ int Integrator::sp_reverse(bvec_t** arg, bvec_t** res, // Shift time if (adj_xf) adj_xf += nrx_; - if (rp) rp += nrp_; + if (adj_qf) adj_qf += nrp_; if (adj_u) adj_u += nuq_; if (u) u += nu_; }