Skip to content

Commit

Permalink
Merge pull request #62 from clinssen/singularity_detection
Browse files Browse the repository at this point in the history
Add singularity detection feature
  • Loading branch information
clinssen committed Oct 10, 2022
2 parents 12cbf3b + 184a6c1 commit 34e2349
Show file tree
Hide file tree
Showing 8 changed files with 240 additions and 8 deletions.
2 changes: 1 addition & 1 deletion doc/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,7 @@ The propagator matrix :math:`\mathbf{P}` is derived from the system matrix by ma
If the imaginary unit :math:`i` is found in any of the entries in :math:`\mathbf{P}`, fail. This usually indicates an unstable (diverging) dynamical system. Double-check the dynamical equations.
In some cases, elements of :math:`\mathbf{P}` may contain fractions that have a factor of the form :python:`param1 - param2` in their denominator. If at a later stage, the numerical value of :python:`param1` is chosen equal to that of :python:`param2`, a numerical singularity (division by zero) occurs. To avoid this issue, it is necessary to eliminate either :python:`param1` or :python:`param2` in the input, before the propagator matrix is generated.
In some cases, elements of :math:`\mathbf{P}` may contain fractions that have a factor of the form :python:`param1 - param2` in their denominator. If at a later stage, the numerical value of :python:`param1` is chosen equal to that of :python:`param2`, a numerical singularity (division by zero) occurs. To avoid this issue, it is necessary to eliminate either :python:`param1` or :python:`param2` in the input, before the propagator matrix is generated. ODE-toolbox will detect conditions (in this example, :python:`param1 = param2`) under which these singularities can occur. If any conditions were found, log warning messages will be emitted during the computation of the propagator matrix. A condition is only reported if the system matrix :math:`A` is defined under that condition, ensuring that only those conditions are returned that are purely an artifact of the propagator computation.
Computing the update expressions
Expand Down
6 changes: 3 additions & 3 deletions odetoolbox/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def _get_all_first_order_variables(indict) -> Iterable[str]:
return variable_names


def _analysis(indict, disable_stiffness_check: bool=False, disable_analytic_solver: bool=False, preserve_expressions: Union[bool, Iterable[str]]=False, simplify_expression: str="sympy.simplify(expr)", log_level: Union[str, int]=logging.WARNING) -> Tuple[List[Dict], SystemOfShapes, List[Shape]]:
def _analysis(indict, disable_stiffness_check: bool = False, disable_analytic_solver: bool = False, preserve_expressions: Union[bool, Iterable[str]] = False, simplify_expression: str = "sympy.simplify(expr)", log_level: Union[str, int] = logging.WARNING) -> Tuple[List[Dict], SystemOfShapes, List[Shape]]:
r"""
Like analysis(), but additionally returns ``shape_sys`` and ``shapes``.
Expand Down Expand Up @@ -360,7 +360,7 @@ def _analysis(indict, disable_stiffness_check: bool=False, disable_analytic_solv
return solvers_json, shape_sys, shapes


def _init_logging(log_level: Union[str, int]=logging.WARNING):
def _init_logging(log_level: Union[str, int] = logging.WARNING):
"""
Initialise message logging.
Expand All @@ -371,7 +371,7 @@ def _init_logging(log_level: Union[str, int]=logging.WARNING):
logging.getLogger().setLevel(log_level)


def analysis(indict, disable_stiffness_check: bool=False, disable_analytic_solver: bool=False, preserve_expressions: Union[bool, Iterable[str]]=False, simplify_expression: str="sympy.simplify(expr)", log_level: Union[str, int]=logging.WARNING) -> List[Dict]:
def analysis(indict, disable_stiffness_check: bool = False, disable_analytic_solver: bool = False, preserve_expressions: Union[bool, Iterable[str]] = False, simplify_expression: str = "sympy.simplify(expr)", log_level: Union[str, int] = logging.WARNING) -> List[Dict]:
r"""
The main entry point of the ODE-toolbox API.
Expand Down
3 changes: 2 additions & 1 deletion odetoolbox/analytic_integrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@

import sympy
import sympy.matrices
import sympy.utilities
import sympy.utilities.autowrap

from .shapes import Shape
from .integrator import Integrator
Expand Down Expand Up @@ -223,7 +225,6 @@ def get_value(self, t):
if spike_t > t:
break


#
# apply propagator to update the state from `t_curr` to `spike_t`
#
Expand Down
4 changes: 2 additions & 2 deletions odetoolbox/shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class MalformedInputException(Exception):
pass


def is_constant_term(term, parameters: Mapping[sympy.Symbol, str]=None):
def is_constant_term(term, parameters: Mapping[sympy.Symbol, str] = None):
r"""
:return: :python:`True` if and only if this term contains only numerical values and parameters; :python:`False` otherwise.
"""
Expand Down Expand Up @@ -456,7 +456,7 @@ def from_function(cls, symbol: str, definition, max_t=100, max_order=4, all_vari
# `derivatives` is a list of all derivatives of `shape` up to the order we are checking, starting at 0.
derivatives = [definition, sympy.diff(definition, time_symbol)]

logging.info("\nProcessing function-of-time shape " + str(symbol) + " with defining expression = \"" + str(definition) + "\"")
logging.info("\nProcessing function-of-time shape \"" + str(symbol) + "\" with defining expression = \"" + str(definition) + "\"")


#
Expand Down
154 changes: 154 additions & 0 deletions odetoolbox/singularity_detection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
#
# singularity_detection.py
#
# This file is part of the NEST ODE toolbox.
#
# Copyright (C) 2017 The NEST Initiative
#
# The NEST ODE toolbox is free software: you can redistribute it
# and/or modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation, either version 2 of
# the License, or (at your option) any later version.
#
# The NEST ODE toolbox is distributed in the hope that it will be
# useful, but WITHOUT ANY WARRANTY; without even the implied warranty
# of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see <http://www.gnu.org/licenses/>.
#
from typing import Mapping

import sympy
import sympy.parsing.sympy_parser


class SingularityDetection:
r"""Singularity detection for generated propagator matrix.
Some ordinary differential equations (ODEs) can be solved analytically: an expression for the solution can be readily derived by algebraic manipulation. This allows us to formulate an "exact integrator", that yields the next state of the system given the current state and the timestep Δt, to floating point (machine) precision [1]_.
In some cases, an ODE is analytically tractable, but vulnerable to an edge case condition in the generated propagator matrices. Consider the following example: Let the system of ODEs be given by
.. math::
y' = A \cdot y
Then the propagator matrix for a timestep :math:`\Delta t` is
.. math::
P = \exp(A \cdot \Delta t)
which we can use to advance the system
.. math::
y(t + \Delta t) = P \cdot y(t)
If :math:`A` is of the form:
.. math::
\begin{bmatrix}
-a & 0 & 0\\
1 & -a & 0\\
0 & 1 & -b
\end{bmatrix}
Then the generated propagator matrix contains denominators that include the factor :math:`a - b`. When the parameters are chosen such that :math:`a = b`, a singularity (division by zero fault) occurs. However, the singularity is readily avoided if we assume that :math:`a = b` before generating the propagator, i.e. we start out with the matrix
.. math::
\begin{bmatrix}
-a & 0 & 0\\
1 & -a & 0\\
0 & 1 & -a
\end{bmatrix}
The resulting propagator contains no singularities.
This class detects the potential occurrence of such singularities (potential division by zero) in the generated propagator matrix, which occur under certain choices of parameter values. These choices are reported as "conditions" by the ``find_singularities()`` function.
References
----------
.. [1] Stefan Rotter, Markus Diesmann. Exact digital simulation of time-invariant linear systems with applications to neuronal modeling. Neurobiologie und Biophysik, Institut für Biologie III, Universität Freiburg, Freiburg, Germany Biol. Cybern. 81, 381-402 (1999)
"""

@staticmethod
def _is_matrix_defined_under_substitution(A: sympy.Matrix, cond: Mapping) -> bool:
r"""
Function to check if a matrix is defined (i.e. does not contain NaN or infinity) after we perform a given set of subsitutions.
Parameters
----------
A : sympy.Matrix
input matrix
cond : Mapping
mapping from expression that is to be subsituted, to expression to put in its place
"""
for val in sympy.flatten(A):
for expr, subs_expr in cond.items():
if sympy.simplify(val.subs(expr, subs_expr)) in [sympy.nan, sympy.zoo, sympy.oo]:
return False

return True

@staticmethod
def _flatten_conditions(cond):
r"""
Return a list with conditions in the form of dictionaries
"""
lst = []
for i in range(len(cond)):
if cond[i] not in lst:
lst.append(cond[i])

return lst

@staticmethod
def _filter_valid_conditions(cond, A: sympy.Matrix):
filt_cond = []
for i in range(len(cond)): # looping over conditions
if SingularityDetection._is_matrix_defined_under_substitution(A, cond[i]):
filt_cond.append(cond[i])

return filt_cond

@staticmethod
def _generate_singularity_conditions(A: sympy.Matrix):
r"""
The function solve returns a list where each element is a dictionary. And each dictionary entry (condition: expression) corresponds to a condition at which that expression goes to zero.
If the expression is quadratic, like let's say "x**2-1" then the function 'solve() returns two dictionaries in a list. each dictionary corresponds to one solution.
We are then collecting these lists in our own list called 'condition'.
"""
conditions = []
for expr in sympy.flatten(A):
for subexpr in sympy.preorder_traversal(expr): # traversing through the tree
if isinstance(subexpr, sympy.Pow) and subexpr.args[1] < 0: # find expressions of the form 1/x, which is encoded in sympy as x^-1
denom = subexpr.args[0] # extracting the denominator
cond = sympy.solve(denom, denom.free_symbols, dict=True) # ``cond`` here is a list of all those conditions at which the denominator goes to zero
if cond not in conditions:
conditions.extend(cond)

return conditions

@staticmethod
def find_singularities(P: sympy.Matrix, A: sympy.Matrix):
r"""Find singularities in the propagator matrix :math:`P` given the system matrix :math:`A`.
Parameters
----------
P : sympy.Matrix
propagator matrix to check for singularities
A : sympy.Matrix
system matrix
"""
conditions = SingularityDetection._generate_singularity_conditions(P)
conditions = SingularityDetection._flatten_conditions(conditions) # makes a list of conditions with each condition in the form of a dict
conditions = SingularityDetection._filter_valid_conditions(conditions, A) # filters out the invalid conditions (invalid means those for which A is not defined)

return conditions
2 changes: 1 addition & 1 deletion odetoolbox/spike_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def spike_times_from_json(cls, stimuli, sim_time, derivative_symbol="__d") -> Ma


@classmethod
def _generate_homogeneous_poisson_spikes(cls, T: float, rate: float, min_isi: float=1E-6):
def _generate_homogeneous_poisson_spikes(cls, T: float, rate: float, min_isi: float = 1E-6):
r"""
Generate spike trains for the given simulation length. Uses a Poisson distribution to create biologically realistic characteristics of the spike-trains.
Expand Down
9 changes: 9 additions & 0 deletions odetoolbox/system_of_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import sys

from .shapes import Shape
from .singularity_detection import SingularityDetection
from .sympy_printer import _is_zero


Expand Down Expand Up @@ -196,6 +197,14 @@ def generate_propagator_solver(self, output_timestep_symbol: str = "__h"):
if sympy.I in sympy.preorder_traversal(P):
raise PropagatorGenerationException("The imaginary unit was found in the propagator matrix. This can happen if the dynamical system that was passed to ode-toolbox is unstable, i.e. one or more state variables will diverge to minus or positive infinity.")

condition = SingularityDetection.find_singularities(P, self.A_)
if condition:
logging.warning("Under certain conditions, the propagator matrix is singular (contains infinities).")
logging.warning("List of all conditions that result in a singular propagator:")
for cond in condition:
logging.warning("\t" + r" ∧ ".join([str(k) + " = " + str(v) for k, v in cond.items()]))


#
# generate symbols for each nonzero entry of the propagator matrix
#
Expand Down
68 changes: 68 additions & 0 deletions tests/test_singularity_detection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
#
# test_singularity_detection.py
#
# This file is part of the NEST ODE toolbox.
#
# Copyright (C) 2017 The NEST Initiative
#
# The NEST ODE toolbox is free software: you can redistribute it
# and/or modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation, either version 2 of
# the License, or (at your option) any later version.
#
# The NEST ODE toolbox is distributed in the hope that it will be
# useful, but WITHOUT ANY WARRANTY; without even the implied warranty
# of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see <http://www.gnu.org/licenses/>.
#

import numpy as np
import sympy
import pytest

from odetoolbox.singularity_detection import SingularityDetection


class TestSingularityDetection:
r"""Test singularity detection"""

def test_is_matrix_defined_under_substitution(self):
tau_m, tau_r, C, h = sympy.symbols("tau_m, tau_r, C, h")
P = sympy.Matrix([[-1 / tau_r, 0, 0], [1, -1 / tau_r, 0], [0, 1 / C, -1 / tau_m]])
assert SingularityDetection._is_matrix_defined_under_substitution(P, {})
assert SingularityDetection._is_matrix_defined_under_substitution(P, {tau_r: 1})
assert not SingularityDetection._is_matrix_defined_under_substitution(P, {tau_r: 0})

@pytest.mark.parametrize("kernel_to_use", ["alpha", "beta"])
def test_alpha_beta_kernels(self, kernel_to_use: str):
r"""Test correctness of result for simple leaky integrate-and-fire neuron with biexponential postsynaptic kernel"""
if kernel_to_use == "alpha":
tau_m, tau_s, C, h = sympy.symbols("tau_m, tau_s, C, h")
A = sympy.Matrix([[-1 / tau_s, 0, 0], [1, -1 / tau_s, 0], [0, 1 / C, -1 / tau_m]])
elif kernel_to_use == "beta":
tau_m, tau_d, tau_r, C, h = sympy.symbols("tau_m, tau_d, tau_r, C, h")
A = sympy.Matrix([[-1 / tau_d, 0, 0], [1, -1 / tau_r, 0], [0, 1 / C, -1 / tau_m]])

P = sympy.simplify(sympy.exp(A * h)) # Propagator matrix

condition = SingularityDetection._generate_singularity_conditions(P)
condition = SingularityDetection._flatten_conditions(condition) # makes a list of conditions with each condition in the form of a dict
condition = SingularityDetection._filter_valid_conditions(condition, A) # filters out the invalid conditions (invalid means those for which A is not defined)

if kernel_to_use == "alpha":
assert len(condition) == 1
elif kernel_to_use == "beta":
assert len(condition) == 3

def test_more_than_one_solution(self):
r"""Test the case where there is more than one element returned in a solution to an equation; in this example, for a quadratic input equation"""
A = sympy.Matrix([[sympy.parsing.sympy_parser.parse_expr("-1/(tau_s**2 - 3*tau_s - 42)")]])
condition = SingularityDetection._generate_singularity_conditions(A)
assert len(condition) == 2
for cond in condition:
assert sympy.Symbol("tau_s") in cond.keys()
assert cond[sympy.Symbol("tau_s")] == sympy.parsing.sympy_parser.parse_expr("3/2 + sqrt(177)/2") \
or cond[sympy.Symbol("tau_s")] == sympy.parsing.sympy_parser.parse_expr("3/2 - sqrt(177)/2")

0 comments on commit 34e2349

Please sign in to comment.