Skip to content

Commit

Permalink
fix inhomogeneous ODE solver
Browse files Browse the repository at this point in the history
  • Loading branch information
C.A.P. Linssen committed Jul 24, 2023
1 parent 89ab59d commit 5371e1c
Show file tree
Hide file tree
Showing 20 changed files with 292 additions and 163 deletions.
Binary file modified doc/fig/eq_analysis_0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified doc/fig/eq_analysis_1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified doc/fig/eq_analysis_2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
8 changes: 3 additions & 5 deletions doc/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -381,9 +381,7 @@ Analytic solver generation
If an ODE is constant-coefficient and linear, an analytic solution can be computed. Analytically solvable ODEs can also contain dependencies on other analytically solvable ODEs, but an otherwise analytically tractable ODE cannot depend on an ODE that can only be solved numerically. In the latter case, no analytic solution will be computed.
For example, consider an integrate-and-fire neuron with two alpha-shaped kernels (``I_shape_in`` and ``I_shape_gap``), and one nonlinear kernel (``I_shape_ex``). Each of these kernels can be expressed as a system of ODEs containing two variables. ``I_shape_in`` is specified as a second-order equation, whereas ``I_shape_gap`` is explicitly given as a system of two coupled first-order equations, i.e. as two separate ``dynamics`` entries with names ``I_shape_gap1`` and ``I_shape_gap2``.
Both formulations are mathematically equivalent, and ODE-toolbox treats them the same following input processing.
For example, consider an integrate-and-fire neuron with two alpha-shaped kernels (``I_shape_in`` and ``I_shape_gap``), and one nonlinear kernel (``I_shape_ex``). Each of these kernels can be expressed as a system of ODEs containing two variables. ``I_shape_in`` is specified as a second-order equation, whereas ``I_shape_gap`` is explicitly given as a system of two coupled first-order equations, i.e. as two separate ``dynamics`` entries with names ``I_shape_gap1`` and ``I_shape_gap2``. Both formulations are mathematically equivalent, and ODE-toolbox treats them the same following input processing. The membrane potential ``V_rel`` is expressed relative to zero, making it a homogeneous equation and one that could be analytically solved, if it were not for its dedependence on the quantity ``I_shape_ex`` which itself requires a numeric solver due to its nonlinear dynamics.
During processing, a dependency graph is generated, where each node corresponds to one dynamical variable, and an arrow from node *a* to *b* indicates that *a* depends on the value of *b*. Boxes enclosing nodes mark input shapes that were specified as either a direct function of time or a higher-order differential equation, and were expanded to a system of first-order ODEs.
Expand All @@ -399,7 +397,7 @@ Each variable is subsequently marked according to whether it can, by itself, be
<img src="https://raw.githubusercontent.com/nest/ode-toolbox/master/doc/fig/eq_analysis_1.png" alt="Dependency graph with membrane potential and excitatory and gap junction kernels marked green" width="720" height="383">
In the next step, variables are unmarked as analytically solvable if they depend on other variables that are themselves not analytically solvable. In this example, ``V_abs`` is unmarked as it depends on the nonlinear excitatory kernel.
In the next step, variables are unmarked as analytically solvable if they depend on other variables that are themselves not analytically solvable. In this example, ``V_rel`` is unmarked as it depends on the nonlinear excitatory kernel.
.. raw:: html
Expand Down Expand Up @@ -562,7 +560,7 @@ The file `test_analytic_solver_integration.py <https://github.com/nest/ode-toolb
.. raw:: html
<img src="https://raw.githubusercontent.com/nest/ode-toolbox/master/doc/fig/test_analytic_solver_integration.png" alt="V_abs, i_ex and i_ex' timeseries plots" width="620" height="465">
<img src="https://raw.githubusercontent.com/nest/ode-toolbox/master/doc/fig/test_analytic_solver_integration.png" alt="V_rel, i_ex and i_ex' timeseries plots" width="620" height="465">
The file `test_mixed_integrator_numeric.py <https://github.com/nest/ode-toolbox/blob/master/tests/test_mixed_integrator_numeric.py>`_ contains an integration test, that uses :py:class:`~odetoolbox.mixed_integrator.MixedIntegrator` and the results dictionary from ODE-toolbox to simulate the same integrate-and-fire neuron with alpha-shaped postsynaptic response, but purely numerically (without the use of propagators). In contrast to the :py:class:`~odetoolbox.analytic_integrator.AnalyticIntegrator`, enforcement of upper- and lower bounds is supported, as can be seen in the behaviour of :math:`V_m` in the plot that is generated:
Expand Down
74 changes: 58 additions & 16 deletions odetoolbox/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from sympy.core.expr import Expr as SympyExpr # works for both sympy 1.4 and 1.8

from .config import Config
from .sympy_helpers import SympyPrinter, _is_zero, _is_sympy_type
from .sympy_helpers import _find_in_matrix, _is_zero, _is_sympy_type, SympyPrinter
from .system_of_shapes import SystemOfShapes
from .shapes import MalformedInputException, Shape

Expand Down Expand Up @@ -55,19 +55,39 @@
sympy.Basic.__str__ = lambda self: SympyPrinter().doprint(self)


def _dependency_analysis(shape_sys, shapes, parameters=None):
def _find_analytically_solvable_equations(shape_sys, shapes, parameters=None):
r"""
Find which equations can be solved analytically (and, conversely, which cannot).
Perform dependency analysis and plot dependency graph.
"""
logging.info("Dependency analysis...")
logging.info("Finding analytically solvable equations...")
dependency_edges = shape_sys.get_dependency_edges()
node_is_lin = shape_sys.get_lin_cc_symbols(dependency_edges, parameters=parameters)

if PLOT_DEPENDENCY_GRAPH:
node_is_analytically_solvable = {sym: False for sym in list(shape_sys.x_)}
DependencyGraphPlotter.plot_graph(shapes, dependency_edges, node_is_analytically_solvable, fn="/tmp/ode_dependency_graph.dot")

node_is_analytically_solvable = shape_sys.get_lin_cc_symbols(dependency_edges, parameters=parameters)

if PLOT_DEPENDENCY_GRAPH:
DependencyGraphPlotter.plot_graph(shapes, dependency_edges, node_is_lin, fn="/tmp/ode_dependency_graph_before.dot")
node_is_lin = shape_sys.propagate_lin_cc_judgements(node_is_lin, dependency_edges)
DependencyGraphPlotter.plot_graph(shapes, dependency_edges, node_is_analytically_solvable, fn="/tmp/ode_dependency_graph_analytically_solvable_before_propagated.dot")

# remove inhomogeneous and order > 1 shapes from ``analytic_syms``
for i in range(len(shape_sys.x_)):
if not _is_zero(shape_sys.b_[i]) and shape_sys.shape_order_from_system_matrix(i) > 1:
analytic_syms = [sym for sym in analytic_syms if not sym in shape_sys.get_connected_symbols(i)]

for j in range(len(shape_sys.x_)):
if not i == j and not _is_zero(shape_sys.A_[i, j]) and not _is_zero(shape_sys.b_[_find_in_matrix(shape_sys.x_, shape_sys.x_[j])]) and shape_sys.x_[i] in analytic_syms:
# this shape depends on another ODE that is inhomogeneous -- can't be solved analytically by this version of ODE-toolbox
node_is_analytically_solvable[shape_sys.x_[i]] = False

node_is_analytically_solvable = shape_sys.propagate_lin_cc_judgements(node_is_analytically_solvable, dependency_edges)
if PLOT_DEPENDENCY_GRAPH:
DependencyGraphPlotter.plot_graph(shapes, dependency_edges, node_is_lin, fn="/tmp/ode_dependency_graph.dot")
return dependency_edges, node_is_lin
DependencyGraphPlotter.plot_graph(shapes, dependency_edges, node_is_analytically_solvable, fn="/tmp/ode_dependency_graph_analytically_solvable.dot")

return dependency_edges, node_is_analytically_solvable


def _read_global_config(indict):
Expand Down Expand Up @@ -202,25 +222,20 @@ def _analysis(indict, disable_stiffness_check: bool = False, disable_analytic_so
sys.exit(1)

shape_sys = SystemOfShapes.from_shapes(shapes, parameters=parameters)
dependency_edges, node_is_lin = _dependency_analysis(shape_sys, shapes, parameters=parameters)
_, node_is_analytically_solvable = _find_analytically_solvable_equations(shape_sys, shapes, parameters=parameters)


#
# generate analytical solutions (propagators) where possible
#

solvers_json = []
analytic_solver_json = None
if disable_analytic_solver:
analytic_syms = []
else:
analytic_syms = [node_sym for node_sym, _node_is_lin in node_is_lin.items() if _node_is_lin]

# remove inhomogeneous and order > 1 shapes from ``analytic_syms``
for i in range(shape_sys.A_.shape[0]):
if not _is_zero(shape_sys.b_[i]) and shape_sys.shape_order_from_system_matrix(i) > 1:
analytic_syms = [sym for sym in analytic_syms if not sym in shape_sys.get_connected_symbols(i)]
analytic_syms = [node_sym for node_sym, _node_is_analytically_solvable in node_is_analytically_solvable.items() if _node_is_analytically_solvable]

analytic_solver_json = None
if analytic_syms:
logging.info("Generating propagators for the following symbols: " + ", ".join([str(k) for k in analytic_syms]))
sub_sys = shape_sys.get_sub_system(analytic_syms)
Expand Down Expand Up @@ -346,6 +361,33 @@ def _analysis(indict, disable_stiffness_check: bool = False, disable_analytic_so

logging.info("In ode-toolbox: returning outdict = ")
logging.info(json.dumps(solvers_json, indent=4, sort_keys=True))
# solvers_json = [
# {
# "initial_values": {
# "I_1": "0",
# "I_2": "0",
# "V_m": "0.0"
# },
# "propagators": {
# "__P__I_1__I_1": "exp(-__h/tau_1)",
# "__P__I_2__I_2": "exp(-__h/tau_2)",
# "__P__V_m__I_1": "tau_1*tau_m*(-exp(__h/tau_1) + exp(__h/tau_m))*exp(-__h*(tau_1 + tau_m)/(tau_1*tau_m))/(C_m*(tau_1 - tau_m))",
# "__P__V_m__I_2": "tau_2*tau_m*(-exp(__h/tau_2) + exp(__h/tau_m))*exp(-__h*(tau_2 + tau_m)/(tau_2*tau_m))/(C_m*(tau_2 - tau_m))",
# "__P__V_m__V_m": "exp(-__h/tau_m)"
# },
# "solver": "analytical",
# "state_variables": [
# "I_1",
# "I_2",
# "V_m"
# ],
# "update_expressions": {
# "I_1": "I_inj1*tau_1 + __P__I_1__I_1*(I_1 - I_inj1*tau_1)",
# "I_2": "I_inj2*tau_2 + __P__I_2__I_2*(I_2 - I_inj2*tau_2)",
# "V_m": "E_L + __P__V_m__I_1*(I_1 + I_inj1*tau_1) + __P__V_m__I_2*(I_2 + I_inj2*tau_2) - __P__V_m__V_m*(E_L - V_m)"
# }
# }
# ]

return solvers_json, shape_sys, shapes

Expand Down
9 changes: 5 additions & 4 deletions odetoolbox/analytic_integrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,12 @@ def __init__(self, solver_dict, spike_times: Optional[Dict[str, List[float]]] =


#
# perform substtitution in update expressions ahead of time to save time later
# perform substitution in update expressions ahead of time to save time later
#

for k, v in self.update_expressions.items():
self.update_expressions[k] = self.update_expressions[k].subs(self.subs_dict).subs(self.subs_dict)


#
# autowrap
#
Expand Down Expand Up @@ -212,7 +211,6 @@ def get_value(self, t):

all_spike_times, all_spike_times_sym = self.get_sorted_spike_times()


#
# process spikes between ⟨t_curr, t]
#
Expand Down Expand Up @@ -250,7 +248,6 @@ def get_value(self, t):
self.t_curr = t_curr
self.state_at_t_curr = state_at_t_curr


#
# apply propagator to update the state from `t_curr` to `t`
#
Expand All @@ -260,4 +257,8 @@ def get_value(self, t):
state_at_t_curr = self._update_step(delta_t, state_at_t_curr)
t_curr = t

#
# add inhomogeneous contribution
#

return state_at_t_curr
13 changes: 13 additions & 0 deletions odetoolbox/sympy_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,19 @@ def _custom_simplify_expr(expr: str):
sys.exit(1)


def _find_in_matrix(A, el):
num_rows = A.rows
num_cols = A.cols

# Iterate over the elements of the matrix
for i in range(num_rows):
for j in range(num_cols):
if A[i, j] == el:
return (i, j)

return None


class SympyPrinter(sympy.printing.StrPrinter):

def _print_Exp1(self, expr):
Expand Down
36 changes: 23 additions & 13 deletions odetoolbox/system_of_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ def generate_propagator_solver(self):
r"""
Generate the propagator matrix and symbolic expressions for propagator-based updates; return as JSON.
"""

#
# generate the propagator matrix
#
Expand All @@ -201,6 +202,12 @@ def generate_propagator_solver(self):
for cond in condition:
logging.warning("\t" + r" ∧ ".join([str(k) + " = " + str(v) for k, v in cond.items()]))

logging.info("System of equations:")
logging.info("x = " + str(self.x_))
logging.info("A = " + repr(self.A_))
logging.info("b = " + str(self.b_))
logging.info("c = " + str(self.c_))


#
# generate symbols for each nonzero entry of the propagator matrix
Expand All @@ -223,28 +230,31 @@ def generate_propagator_solver(self):
sym_str = "__P__{}__{}".format(str(self.x_[row]), str(self.x_[col]))
P_sym[row, col] = sympy.parsing.sympy_parser.parse_expr(sym_str, global_dict=Shape._sympy_globals)
P_expr[sym_str] = P[row, col]
if _is_zero(self.b_[col]):
# homogeneous ODE
update_expr_terms.append(sym_str + " * " + str(self.x_[col]))
else:
# inhomogeneous ODE
if _is_zero(self.A_[col, col]):
# of the form x' = const
update_expr_terms.append(sym_str + " * " + str(self.x_[col]) + " + " + Config().output_timestep_symbol + " * " + str(self.b_[col]))
else:
particular_solution = -self.b_[col] / self.A_[col, col]
update_expr_terms.append(sym_str + " * (" + str(self.x_[col]) + " - (" + str(particular_solution) + "))" + " + (" + str(particular_solution) + ")")
# if row != col and not _is_zero(self.b_[col]):
# # the ODE for x_[row] depends on the inhomogeneous ODE of x_[col]. We can't solve this analytically in the general case (even though some specific cases might admit a solution)
# raise PropagatorGenerationException("the ODE for " + str(self.x_[row]) + " depends on the inhomogeneous ODE of " + str(self.x_[col]) + ". We can't solve this analytically in the general case (even though some specific cases might admit a solution)")

update_expr_terms.append(sym_str + " * " + str(self.x_[col]))

if not _is_zero(self.b_[row]):
# this is an inhomogeneous ODE
if _is_zero(self.A_[row, row]):
# of the form x' = const
update_expr_terms.append(Config().output_timestep_symbol + " * " + str(self.b_[col]))
else:
particular_solution = -self.b_[row] / self.A_[row, row]
update_expr_terms.append("-" + sym_str + " * " + str(self.x_[col])) # remove the term (add its inverse) that would have corresponded to a homogeneous solution and that was added in the ``for col...`` loop above
update_expr_terms.append(sym_str + " * (" + str(self.x_[row]) + " - (" + str(particular_solution) + "))" + " + (" + str(particular_solution) + ")")

update_expr[str(self.x_[row])] = " + ".join(update_expr_terms)
update_expr[str(self.x_[row])] = sympy.parsing.sympy_parser.parse_expr(update_expr[str(self.x_[row])], global_dict=Shape._sympy_globals)
if not _is_zero(self.b_[row]):
# only simplify in case an inhomogeneous term is present
update_expr[str(self.x_[row])] = _custom_simplify_expr(update_expr[str(self.x_[row])])
logging.info("update_expr[" + str(self.x_[row]) + "] = " + str(update_expr[str(self.x_[row])]))

all_state_symbols = [str(sym) for sym in self.x_]

initial_values = {sym: str(self.get_initial_value(sym)) for sym in all_state_symbols}

solver_dict = {"propagators": P_expr,
"update_expressions": update_expr,
"state_variables": all_state_symbols,
Expand Down
16 changes: 16 additions & 0 deletions tests/iaf_psc_exp.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
{
"dynamics": [
{
"expression": "I_1' = -I_1 / tau_1 + I_inj1",
"initial_value": "0"
},
{
"expression": "I_2' = -I_2 / tau_2 + I_inj2",
"initial_value": "0"
},
{
"expression": "V_m' = -(V_m - E_L) / tau_m + (I_1 + I_2) / C_m",
"initial_value": "0."
}
]
}
2 changes: 1 addition & 1 deletion tests/mixed_analytic_numerical_no_stiffness.json
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
"initial_value": "e / Tau_syn_gap"
},
{
"expression": "V_abs' = -1/Tau*V_abs**2 + 1/C_m*(I_shape_in + I_shape_ex + I_shape_gap1 + I_e + currents)",
"expression": "V_rel' = -1/Tau*V_rel**2 + 1/C_m*(I_shape_in + I_shape_ex + I_shape_gap1 + I_e + currents)",
"initial_value": "0."
}
]
Expand Down
2 changes: 1 addition & 1 deletion tests/mixed_analytic_numerical_with_stiffness.json
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
"initial_value": "e / Tau_syn_gap"
},
{
"expression": "V_abs' = -1/Tau*V_abs**2 + 1/C_m*(I_shape_in + I_shape_ex + I_shape_gap1 + I_e + currents)",
"expression": "V_rel' = -1/Tau*V_rel**2 + 1/C_m*(I_shape_in + I_shape_ex + I_shape_gap1 + I_e + currents)",
"initial_value": "0."
}
],
Expand Down
File renamed without changes.
22 changes: 5 additions & 17 deletions tests/test_analysis_mixed_analytic_numerical.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,10 @@
# along with NEST. If not, see <http://www.gnu.org/licenses/>.
#

import json
import os
import unittest
import pytest

from tests.test_utils import _open_json

try:
import pygsl
PYGSL_AVAILABLE = True
Expand All @@ -33,30 +32,19 @@
from .context import odetoolbox


def open_json(fname):
absfname = os.path.join(os.path.abspath(os.path.dirname(__file__)), fname)
with open(absfname) as infile:
indict = json.load(infile)
return indict


class TestAnalysisMixedAnalyticNumerical(unittest.TestCase):
class TestAnalysisMixedAnalyticNumerical:

def test_mixed_analytic_numerical_no_stiffness(self):
indict = open_json("mixed_analytic_numerical_no_stiffness.json")
indict = _open_json("mixed_analytic_numerical_no_stiffness.json")
solver_dict = odetoolbox.analysis(indict, disable_stiffness_check=True)
assert len(solver_dict) == 2
assert (solver_dict[0]["solver"] == "analytical" and solver_dict[1]["solver"][:7] == "numeric") \
or (solver_dict[1]["solver"] == "analytical" and solver_dict[0]["solver"][:7] == "numeric")

@pytest.mark.skipif(not PYGSL_AVAILABLE, reason="Cannot run stiffness test if GSL is not installed.")
def test_mixed_analytic_numerical_with_stiffness(self):
indict = open_json("mixed_analytic_numerical_with_stiffness.json")
indict = _open_json("mixed_analytic_numerical_with_stiffness.json")
solver_dict = odetoolbox.analysis(indict, disable_stiffness_check=False)
assert len(solver_dict) == 2
assert (solver_dict[0]["solver"] == "analytical" and solver_dict[1]["solver"][:7] == "numeric") \
or (solver_dict[1]["solver"] == "analytical" and solver_dict[0]["solver"][:7] == "numeric")


if __name__ == '__main__':
unittest.main()
Loading

0 comments on commit 5371e1c

Please sign in to comment.