Skip to content

Commit

Permalink
Issue #3682 Adding zero crossing function to DaeBuilder DAE function
Browse files Browse the repository at this point in the history
  • Loading branch information
jaeandersson committed May 6, 2024
1 parent 6b8d7c6 commit 1f743d0
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 3 deletions.
25 changes: 25 additions & 0 deletions casadi/core/dae_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,18 @@ std::vector<std::string> DaeBuilder::q() const {
return name((*this)->q_);
}

std::vector<std::string> DaeBuilder::e() const {
return name((*this)->e_);
}

std::vector<MX> DaeBuilder::quad() const {
return (*this)->quad();
}

std::vector<MX> DaeBuilder::zero() const {
return (*this)->zero();
}

std::vector<std::string> DaeBuilder::y() const {
return name((*this)->y_);
}
Expand Down Expand Up @@ -191,6 +199,10 @@ casadi_int DaeBuilder::nq() const {
return (*this)->q_.size();
}

casadi_int DaeBuilder::ne() const {
return (*this)->e_.size();
}

casadi_int DaeBuilder::ny() const {
return (*this)->y_.size();
}
Expand Down Expand Up @@ -388,6 +400,10 @@ void DaeBuilder::register_y(const std::string& name) {
(*this)->y_.push_back(find(name));
}

void DaeBuilder::register_e(const std::string& name) {
(*this)->e_.push_back(find(name));
}

void DaeBuilder::clear_all(const std::string& v) {
try {
(*this)->clear_all(v);
Expand Down Expand Up @@ -492,6 +508,15 @@ MX DaeBuilder::add_y(const std::string& name, const MX& new_ydef) {
}
}

MX DaeBuilder::add_e(const std::string& name, const MX& new_edef) {
try {
return (*this)->add_e(name, new_edef);
} catch (std::exception& e) {
THROW_ERROR("add_e", e.what());
return MX();
}
}

void DaeBuilder::add_when(const MX& cond, const MX& lhs, const MX& rhs) {
(*this)->when_cond_.push_back(cond);
(*this)->when_lhs_.push_back(lhs);
Expand Down
13 changes: 13 additions & 0 deletions casadi/core/dae_builder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,17 @@ class CASADI_EXPORT DaeBuilder
\identifier{5j} */
std::vector<std::string> q() const;

/** \brief Event indicators */
std::vector<std::string> e() const;

/** \brief Quadrature equations
\identifier{5k} */
std::vector<MX> quad() const;

/** \brief Zero-crossing functions */
std::vector<MX> zero() const;

/** \brief Output variables
\identifier{5l} */
Expand Down Expand Up @@ -236,6 +242,9 @@ class CASADI_EXPORT DaeBuilder
\identifier{67} */
casadi_int nq() const;

/** \brief Event indicators */
casadi_int ne() const;

/** \brief Output variables
\identifier{68} */
Expand Down Expand Up @@ -301,6 +310,9 @@ class CASADI_EXPORT DaeBuilder
/// Add a new output
MX add_y(const std::string& name, const MX& new_ydef);

/// Add a new event indicator
MX add_e(const std::string& name, const MX& new_edef);

/// Specify the ordinary differential equation for a state
void set_ode(const std::string& name, const MX& ode_rhs);

Expand Down Expand Up @@ -345,6 +357,7 @@ class CASADI_EXPORT DaeBuilder
void register_d(const std::string& name);
void register_w(const std::string& name);
void register_y(const std::string& name);
void register_e(const std::string& name);
///@}

/** @name Manipulation
Expand Down
27 changes: 25 additions & 2 deletions casadi/core/dae_builder_internal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1564,6 +1564,7 @@ std::string to_string(DaeBuilderInternal::DaeBuilderInternalIn v) {
case DaeBuilderInternal::DAE_BUILDER_D: return "d";
case DaeBuilderInternal::DAE_BUILDER_W: return "w";
case DaeBuilderInternal::DAE_BUILDER_Y: return "y";
case DaeBuilderInternal::DAE_BUILDER_E: return "e";
default: break;
}
return "";
Expand All @@ -1574,6 +1575,7 @@ std::string to_string(DaeBuilderInternal::DaeBuilderInternalOut v) {
case DaeBuilderInternal::DAE_BUILDER_ODE: return "ode";
case DaeBuilderInternal::DAE_BUILDER_ALG: return "alg";
case DaeBuilderInternal::DAE_BUILDER_QUAD: return "quad";
case DaeBuilderInternal::DAE_BUILDER_ZERO: return "zero";
case DaeBuilderInternal::DAE_BUILDER_DDEF: return "ddef";
case DaeBuilderInternal::DAE_BUILDER_WDEF: return "wdef";
case DaeBuilderInternal::DAE_BUILDER_YDEF: return "ydef";
Expand All @@ -1594,6 +1596,7 @@ std::vector<MX> DaeBuilderInternal::input(DaeBuilderInternalIn ind) const {
case DAE_BUILDER_Z: return var(z_);
case DAE_BUILDER_Q: return var(q_);
case DAE_BUILDER_Y: return var(y_);
case DAE_BUILDER_E: return var(e_);
default: return std::vector<MX>{};
}
}
Expand All @@ -1611,6 +1614,7 @@ std::vector<MX> DaeBuilderInternal::output(DaeBuilderInternalOut ind) const {
case DAE_BUILDER_ODE: return ode();
case DAE_BUILDER_ALG: return alg();
case DAE_BUILDER_QUAD: return quad();
case DAE_BUILDER_ZERO: return zero();
case DAE_BUILDER_DDEF: return ddef();
case DAE_BUILDER_WDEF: return wdef();
case DAE_BUILDER_YDEF: return ydef();
Expand Down Expand Up @@ -2395,6 +2399,14 @@ std::vector<MX> DaeBuilderInternal::quad() const {
return ret;
}


std::vector<MX> DaeBuilderInternal::zero() const {
std::vector<MX> ret;
ret.reserve(e_.size());
for (size_t v : e_) ret.push_back(variable(v).beq);
return ret;
}

std::vector<MX> DaeBuilderInternal::init_lhs() const {
std::vector<MX> ret;
ret.reserve(init_.size());
Expand Down Expand Up @@ -2504,6 +2516,15 @@ MX DaeBuilderInternal::add_y(const std::string& name, const MX& new_ydef) {
return v.v;
}

MX DaeBuilderInternal::add_e(const std::string& name, const MX& new_edef) {
Variable& v = new_variable(name);
v.v = MX::sym(name);
v.causality = Causality::OUTPUT;
v.beq = new_edef;
e_.push_back(v.index);
return v.v;
}

void DaeBuilderInternal::set_ode(const std::string& name, const MX& ode_rhs) {
// Find the state variable
const Variable& x = variable(name);
Expand Down Expand Up @@ -2809,22 +2830,24 @@ void DaeBuilderInternal::import_dynamic_equations(const XmlNode& eqs) {
// Not implemented
casadi_error(n_equ[0].name + " in when equation not supported");
}
// Hack: Turn "when" boolean expression into a zero-crossing expression
// Turn non-snooth zero-crossing expression into a smooth zero-crossing expression
MX zc;
switch (cond.beq.op()) {
case OP_LT:
// x1 < x2 <=> x2 - x1 > 0
zc = cond.beq.dep(1) - cond.beq.dep(0);
break;
default:
casadi_error("Cannot turn " + str(cond.beq) + " into a zero-crossing expression");
casadi_error("Cannot turn " + str(cond.beq) + " into a smooth expression");
}
set_init(cond.name, MX()); // remove initial conditions, if any
auto w_it = std::find(w_.begin(), w_.end(), cond.index);
if (w_it != w_.end()) w_.erase(w_it); // remove from dependent equations
auto x_it = std::find(x_.begin(), x_.end(), cond.index);
if (x_it != x_.end()) x_.erase(x_it); // remove from states

// Create event indicator
add_e(cond.name + "_smooth", zc);
// Add to list of when equations
when_cond_.push_back(zc);
when_lhs_.push_back(lhs);
Expand Down
8 changes: 7 additions & 1 deletion casadi/core/dae_builder_internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,7 @@ class CASADI_EXPORT DaeBuilderInternal : public SharedObjectInternal {
DAE_BUILDER_Z,
DAE_BUILDER_Q,
DAE_BUILDER_Y,
DAE_BUILDER_E,
DAE_BUILDER_NUM_IN
};

Expand All @@ -293,6 +294,7 @@ class CASADI_EXPORT DaeBuilderInternal : public SharedObjectInternal {
DAE_BUILDER_ODE,
DAE_BUILDER_ALG,
DAE_BUILDER_QUAD,
DAE_BUILDER_ZERO,
DAE_BUILDER_DDEF,
DAE_BUILDER_WDEF,
DAE_BUILDER_YDEF,
Expand Down Expand Up @@ -454,7 +456,7 @@ class CASADI_EXPORT DaeBuilderInternal : public SharedObjectInternal {
std::unordered_map<std::string, size_t> varind_;

/// Ordered variables
std::vector<size_t> t_, p_, u_, x_, z_, q_, c_, d_, w_, y_;
std::vector<size_t> t_, p_, u_, x_, z_, q_, c_, d_, w_, y_, e_;

// Initial equations
std::vector<size_t> init_;
Expand Down Expand Up @@ -499,6 +501,9 @@ class CASADI_EXPORT DaeBuilderInternal : public SharedObjectInternal {
\identifier{10} */
std::vector<MX> quad() const;

/** \brief Zero crossing functions */
std::vector<MX> zero() const;

/** \brief Initial conditions, left-hand-side */
std::vector<MX> init_lhs() const;

Expand All @@ -517,6 +522,7 @@ class CASADI_EXPORT DaeBuilderInternal : public SharedObjectInternal {
MX add_d(const std::string& name, const MX& new_ddef);
MX add_w(const std::string& name, const MX& new_wdef);
MX add_y(const std::string& name, const MX& new_ydef);
MX add_e(const std::string& name, const MX& new_edef);
///@}

///@{
Expand Down
8 changes: 8 additions & 0 deletions casadi/core/integrator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,12 @@ void Integrator::init(const Dict& opts) {
nq1_ = oracle_.numel_out(DYN_QUAD);
np1_ = oracle_.numel_in(DYN_P);
nu1_ = oracle_.numel_in(DYN_U);
ne_ = oracle_.numel_out(DYN_ZERO);

// Event support not implemented
if (ne_ > 0) {
casadi_warning("Event support has not yet been implemented");
}

// Consistency checks
casadi_assert(nx1_ > 0, "Ill-posed ODE - no state");
Expand Down Expand Up @@ -650,6 +656,8 @@ void Integrator::init(const Dict& opts) {
create_forward("daeF", 1);
if (nq_ > 0) create_forward("quadF", 1);
}
// Zero-crossing function
if (ne_ > 0) create_function("zero", dyn_in(), zero_out());

// Create problem functions, backward problem
if (nadj_ > 0) {
Expand Down
5 changes: 5 additions & 0 deletions casadi/core/integrator_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,8 @@ Integrator : public OracleFunction, public PluginInterface<Integrator> {
static std::vector<std::string> dae_out() { return {"ode", "alg"}; }
enum QuadOut { QUAD_QUAD, QUAD_NUM_OUT};
static std::vector<std::string> quad_out() { return {"quad"}; }
enum ZeroOut { ZERO_ZERO, ZERO_NUM_OUT};
static std::vector<std::string> zero_out() { return {"zero"}; }
enum BDynIn { BDYN_T, BDYN_X, BDYN_Z, BDYN_P, BDYN_U,
BDYN_OUT_ODE, BDYN_OUT_ALG, BDYN_OUT_QUAD, BDYN_OUT_ZERO,
BDYN_ADJ_ODE, BDYN_ADJ_ALG, BDYN_ADJ_QUAD, BDYN_ADJ_ZERO, BDYN_NUM_IN};
Expand Down Expand Up @@ -335,6 +337,9 @@ Integrator : public OracleFunction, public PluginInterface<Integrator> {
/// Number of controls
casadi_int nu_, nu1_;

/// Number of of zero-crossing functions
casadi_int ne_;

// Nominal values for states
std::vector<double> nom_x_, nom_z_;

Expand Down

0 comments on commit 1f743d0

Please sign in to comment.