Skip to content

Commit

Permalink
Issue #3682 Implemented prediction of zero crossings
Browse files Browse the repository at this point in the history
By linearizing zero crossing function in the time direction.
Not implemented for algebraic states, which would require the calculation of z_dot
using the implicit function theorem
  • Loading branch information
jaeandersson committed May 8, 2024
1 parent 8939d61 commit e9f983f
Showing 1 changed file with 78 additions and 10 deletions.
88 changes: 78 additions & 10 deletions casadi/core/integrator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -376,9 +376,11 @@ int Integrator::eval(const double** arg, double** res,
for (m->k = 0; m->k < nt(); ++m->k) {
// Update stopping time, if needed
if (m->k > k_stop) k_stop = next_stop(m->k, u);
// Advance solution
m->t_next = tout_[m->k];
m->t_stop = tout_[k_stop];
// Events handling
if (next_event(m, p, u)) return 1;
// Advance solution
if (verbose_) casadi_message("Integrating forward to output time " + str(m->k) + ": t_next = "
+ str(m->t_next) + ", t_stop = " + str(m->t_stop));
advance(m, u, x, z, q);
Expand Down Expand Up @@ -645,9 +647,18 @@ void Integrator::init(const Dict& opts) {
nrq_ = nrq1_ * nadj_ * (1 + nfwd_);
nuq_ = nuq1_ * nadj_ * (1 + nfwd_);

// Length of tmp1, tmp2 vectors
ntmp_ = nx_ + nz_;
ntmp_ = std::max(ntmp_, nrx_ + nrz_);
ntmp_ = std::max(ntmp_, ne_);

// Call the base class method
OracleFunction::init(opts);

// Instantiate functions, forward and backward problem
set_function(oracle_, "dae");
if (nadj_ > 0) set_function(rdae_, "rdae");

// Create problem functions, forward problem
create_function("daeF", dyn_in(), dae_out());
if (nq_ > 0) create_function("quadF", dyn_in(), quad_out());
Expand All @@ -656,8 +667,10 @@ void Integrator::init(const Dict& opts) {
create_forward("daeF", 1);
if (nq_ > 0) create_forward("quadF", 1);
}
// Zero-crossing function
if (ne_ > 0) create_function("zero", dyn_in(), zero_out());
// Event detection requires linearization of the zero-crossing function in the time direction
if (ne_ > 0) {
create_forward("dae", 1);
}

// Create problem functions, backward problem
if (nadj_ > 0) {
Expand Down Expand Up @@ -689,12 +702,13 @@ void Integrator::init(const Dict& opts) {
alloc_w(nx_ + nz_, true); // x, z
alloc_w(np_, true); // p
alloc_w(nu_, true); // u
alloc_w(ne_, true); // e

alloc_w(nrx_ + nrz_, true); // adj_x, adj_z
alloc_w(nrq_, true); // adj_p
alloc_w(nrp_, true); // adj_q

alloc_w(2 * std::max(nx_ + nz_, nrx_ + nrz_), true); // tmp1, tmp2
alloc_w(2 * ntmp_, true); // tmp1, tmp2

alloc_w(nx_ + nz_); // Sparsity::sp_solve
alloc_w(nrx_ + nrz_); // Sparsity::sp_solve
Expand All @@ -712,14 +726,15 @@ void Integrator::set_work(void* mem, const double**& arg, double**& res,
m->z = w; w += nz_;
m->p = w; w += np_;
m->u = w; w += nu_;
m->e = w; w += ne_;

m->adj_x = w; w += nrx_; // doubles as adj_xz
m->adj_z = w; w += nrz_;
m->adj_p = w; w += nrq_;
m->adj_q = w; w += nrp_;

m->tmp1 = w; w += std::max(nx_ + nz_, nrx_ + nrz_);
m->tmp2 = w; w += std::max(nx_ + nz_, nrx_ + nrz_);
m->tmp1 = w; w += ntmp_;
m->tmp2 = w; w += ntmp_;
}


Expand Down Expand Up @@ -1760,10 +1775,6 @@ void FixedStepIntegrator::init(const Dict& opts) {
// Call the base class init
Integrator::init(opts);

// Instantiate functions, forward and backward problem
set_function(oracle_, "dae");
if (nadj_ > 0) set_function(rdae_, "rdae");

// Read options
for (auto&& op : opts) {
if (op.first=="number_of_finite_elements") {
Expand Down Expand Up @@ -2336,6 +2347,63 @@ casadi_int Integrator::next_stop(casadi_int k, const double* u) const {
return k;
}

int Integrator::next_event(IntegratorMemory* m, const double* p, const double* u) const {
// Event time same as stopping time, by default
m->t_event = m->t_stop;
m->event_index = -1;
// Quick return if no events
if (ne_ == 0) return 0;
// Evaluate the DAE and zero crossing function
m->arg[DYN_T] = &m->t; // t
m->arg[DYN_X] = m->x; // x
m->arg[DYN_Z] = m->z; // z
m->arg[DYN_P] = p; // p
m->arg[DYN_U] = u; // u
m->res[DYN_ODE] = m->tmp1; // ode
m->res[DYN_ALG] = m->tmp1 + nx_; // alg
m->res[DYN_QUAD] = nullptr; // quad
m->res[DYN_ZERO] = m->e; // quad
if (calc_function(m, "dae")) return 1;
// Calculate de_dt using by forward mode AD applied to zero crossing function
// Note: Currently ignoring dependency propagation via algebraic equations
double dt_dt = 1;
double *de_dt = m->tmp2;
m->arg[DYN_NUM_IN + DYN_ODE] = m->tmp1; // out:ode
m->arg[DYN_NUM_IN + DYN_ALG] = m->tmp1 + nx_; // out:alg
m->arg[DYN_NUM_IN + DYN_QUAD] = nullptr; // out:quad
m->arg[DYN_NUM_IN + DYN_ZERO] = m->e; // out:zero
m->arg[DYN_NUM_IN + DYN_NUM_OUT + DYN_T] = &dt_dt; // fwd:t
m->arg[DYN_NUM_IN + DYN_NUM_OUT + DYN_X] = m->tmp1; // fwd:x
m->arg[DYN_NUM_IN + DYN_NUM_OUT + DYN_Z] = nullptr; // fwd:z
m->arg[DYN_NUM_IN + DYN_NUM_OUT + DYN_P] = nullptr; // fwd:p
m->arg[DYN_NUM_IN + DYN_NUM_OUT + DYN_U] = nullptr; // fwd:u
m->res[DYN_ODE] = nullptr; // fwd:ode
m->res[DYN_ALG] = nullptr; // fwd:alg
m->res[DYN_QUAD] = nullptr; // fwd:quad
m->res[DYN_ZERO] = de_dt; // fwd:zero
if (calc_function(m, forward_name("dae", 1))) return 1;
// Find the next event, if any
for (casadi_int i = 0; i < ne_; ++i) {
// Check if zero crossing function is negative and moving in the positive direction
if (m->e[i] < 0 && de_dt[i] > 0) {
// Projected zero-crossing time
double t = m->t - m->e[i] / de_dt[i];
// Save if closer than current t_event
if (t < m->t_event) {
m->t_event = t;
m->event_index = i;
}
}
}
// Just print the results for now
if (m->event_index >= 0) {
casadi_warning("Projected zero crossing for index " + str(m->event_index)
+ " at t = " + str(m->t_event));
}

return 0;
}

casadi_int Integrator::next_stopB(casadi_int k, const double* u) const {
// Integrate till the beginning if no input signals
if (nu_ == 0 || u == 0) return -1;
Expand Down

0 comments on commit e9f983f

Please sign in to comment.