Skip to content

Commit

Permalink
Merge pull request #418 from brian-team/state_updater_choice
Browse files Browse the repository at this point in the history
Change the system that decides which state updater algorithm is used
  • Loading branch information
mstimberg committed Mar 2, 2015
2 parents fb4b3a2 + 81ac05c commit ee02896
Show file tree
Hide file tree
Showing 7 changed files with 66 additions and 94 deletions.
3 changes: 2 additions & 1 deletion brian2/groups/neurongroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,8 @@ class NeuronGroup(Group, SpikeSource):
'''
add_to_magic_network = True

def __init__(self, N, model, method=None,
def __init__(self, N, model,
method=('linear', 'euler', 'milstein'),
threshold=None,
reset=None,
refractory=False,
Expand Down
2 changes: 1 addition & 1 deletion brian2/spatialneuron/spatialneuron.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def __init__(self, morphology=None, model=None, threshold=None,
threshold_location=None,
dt=None, clock=None, order=0, Cm=0.9 * uF / cm ** 2, Ri=150 * ohm * cm,
name='spatialneuron*', dtype=None, namespace=None,
method=None):
method=('linear', 'exponential_euler', 'rk2', 'milstein')):

# #### Prepare and validate equations
if isinstance(model, basestring):
Expand Down
94 changes: 36 additions & 58 deletions brian2/stateupdaters/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
for example in `NeuronGroup` when no state updater is given explicitly.
'''
from abc import abstractmethod, ABCMeta
import collections

from brian2.utils.logger import get_logger

Expand All @@ -15,8 +16,8 @@
class StateUpdateMethod(object):
__metaclass__ = ABCMeta

#: A list of registered (name, stateupdater) pairs (in the order of priority)
stateupdaters = []
#: A dictionary mapping state updater names to `StateUpdateMethod` objects
stateupdaters = dict()

@abstractmethod
def can_integrate(self, equations, variables):
Expand Down Expand Up @@ -66,57 +67,43 @@ def __call__(self, equations, variables=None):
pass

@staticmethod
def register(name, stateupdater, index=None):
def register(name, stateupdater):
'''
Register a state updater. Registered state updaters will be considered
when no state updater is explicitly given (e.g. in `NeuronGroup`) and
can be referred to via their name.
Register a state updater. Registered state updaters can be referred to
via their name.
Parameters
----------
name : str
A short name for the state updater (e.g. `'euler'`)
stateupdater : `StateUpdaterMethod`
The state updater object, e.g. an `ExplicitStateUpdater`.
index : int, optional
Where in the list of state updaters the given state updater should
be inserted. State updaters have a higher priority of being chosen
automatically if they appear earlier in the list. If no `index` is
given, the state updater will be inserted at the end of the list.
'''

# only deal with lower case names -- we don't want to have 'Euler' and
# 'euler', for example
name = name.lower()
for registered_name, _ in StateUpdateMethod.stateupdaters:
if registered_name == name:
raise ValueError(('A stateupdater with the name "%s" '
'has already been registered') % name)
name = name.lower()
if name in StateUpdateMethod.stateupdaters:
raise ValueError(('A stateupdater with the name "%s" '
'has already been registered') % name)

if not isinstance(stateupdater, StateUpdateMethod):
raise ValueError(('Given stateupdater of type %s does not seem to '
'be a valid stateupdater.' % str(type(stateupdater))))

if not index is None:
try:
index = int(index)
except (TypeError, ValueError):
raise TypeError(('Index argument should be an integer, is '
'of type %s instead.') % type(index))
StateUpdateMethod.stateupdaters.insert(index, (name, stateupdater))
else:
StateUpdateMethod.stateupdaters.append((name, stateupdater))
StateUpdateMethod.stateupdaters[name] = stateupdater

@staticmethod
def determine_stateupdater(equations, variables, method=None):
def determine_stateupdater(equations, variables, method):
'''
Determine a suitable state updater. If a `method` is given, the
state updater with the given name is used. In case it is a callable, it
will be used even if it is a state updater that claims it is not
applicable. If it is a string, the state updater registered with that
name will be used, but in this case an error will be raised if it
claims not to be applicable. If no `method` is given explicitly, the
suitable state updater with the highest priority is used.
claims not to be applicable. If a `method` is a list of names, all the
methods will be tried until one that can integrate the equations is
found.
Parameters
----------
Expand All @@ -125,17 +112,16 @@ def determine_stateupdater(equations, variables, method=None):
variables : `dict`
The dictionary of `Variable` objects, describing the internal
model variables.
method : {callable, str, ``None``}, optional
method : {callable, str, list of str}
A callable usable as a state updater, the name of a registered
state updater or ``None`` (the default)
state updater or a list of names of state updaters.
'''
if hasattr(method, '__call__'):
# if this is a standard state updater, i.e. if it has a
# can_integrate method, check this method and raise a warning if it
# claims not to be applicable.
try:
priority = method.can_integrate(equations, variables)
if priority == 0:
if not method.can_integrate(equations, variables):
logger.warn(('The manually specified state updater '
'claims that it does not support the given '
'equations.'))
Expand All @@ -145,40 +131,32 @@ def determine_stateupdater(equations, variables, method=None):

logger.info('Using manually specified state updater: %r' % method)
return method

if method is not None:
elif isinstance(method, basestring):
method = method.lower() # normalize name to lower case
stateupdater = None
for name, registered_stateupdater in StateUpdateMethod.stateupdaters:
if name == method:
stateupdater = registered_stateupdater
break
stateupdater = StateUpdateMethod.stateupdaters.get(method, None)
if stateupdater is None:
raise ValueError('No state updater with the name "%s" '
'is known' % method)
if not stateupdater.can_integrate(equations, variables):
raise ValueError(('The state updater "%s" cannot be used for '
'the given equations' % method))
return stateupdater
elif isinstance(method, collections.Iterable):
for name in method:
if name not in StateUpdateMethod.stateupdaters:
logger.warn('No state updater with the name "%s" '
'is known' % name, 'unkown_stateupdater')
else:
stateupdater = StateUpdateMethod.stateupdaters[name]
try:
if stateupdater.can_integrate(equations, variables):
logger.info('Using stateupdater "%s"' % name)
return stateupdater
except KeyError:
logger.debug(('It could not be determined whether state '
'updater "%s" is able to integrate the equations, '
'it appears the namespace is not yet complete.'
% name))

# determine the best suitable state updater
best_stateupdater = None
for name, stateupdater in StateUpdateMethod.stateupdaters:
try:
if stateupdater.can_integrate(equations, variables):
best_stateupdater = (name, stateupdater)
break
except KeyError:
logger.debug(('It could not be determined whether state '
'updater "%s" is able to integrate the equations, '
'it appears the namespace is not yet complete.'
% name))

# No suitable state updater has been found
if best_stateupdater is None:
raise ValueError(('No stateupdater that is suitable for the given '
'equations has been found'))

name, stateupdater = best_stateupdater
logger.info('Using stateupdater "%s"' % name)
return stateupdater
4 changes: 3 additions & 1 deletion brian2/synapses/synapses.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,9 @@ class Synapses(Group):
def __init__(self, source, target=None, model=None, pre=None, post=None,
connect=False, delay=None, namespace=None, dtype=None,
codeobj_class=None,
dt=None, clock=None, order=0, method=None, name='synapses*'):
dt=None, clock=None, order=0,
method=('linear', 'euler', 'milstein'),
name='synapses*'):
self._N = 0
Group.__init__(self, dt=dt, clock=clock, when='start', order=order,
name=name)
Expand Down
30 changes: 12 additions & 18 deletions brian2/tests/test_stateupdaters.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ def test_registration():
Test state updater registration.
'''
# Save state before tests
before = list(StateUpdateMethod.stateupdaters)
before = dict(StateUpdateMethod.stateupdaters)

lazy_updater = ExplicitStateUpdater('x_new = x')
StateUpdateMethod.register('lazy', lazy_updater)
Expand Down Expand Up @@ -399,9 +399,6 @@ def test_determination():
# To save some typing
determine_stateupdater = StateUpdateMethod.determine_stateupdater

# Save state before tests
before = list(StateUpdateMethod.stateupdaters)

eqs = Equations('dv/dt = -v / (10*ms) : 1')
# Just make sure that state updaters know about the two state variables
variables = {'v': Variable(name='v', unit=None),
Expand Down Expand Up @@ -481,32 +478,29 @@ def test_determination():
# Automatic state updater choice should return linear for linear equations,
# euler for non-linear, non-stochastic equations and equations with
# additive noise, milstein for equations with multiplicative noise
# Because it is somewhat fragile, the "independent" state updater is not
# included in this list
all_methods = ['linear', 'exponential_euler', 'euler', 'milstein']
eqs = Equations('dv/dt = -v / (10*ms) : 1')
assert determine_stateupdater(eqs, variables) is linear
assert determine_stateupdater(eqs, variables, all_methods) is linear

# This is conditionally linear
eqs = Equations('''dv/dt = -(v + w**2)/ (10*ms) : 1
dw/dt = -w/ (10*ms) : 1''')
assert determine_stateupdater(eqs, variables) is exponential_euler
assert determine_stateupdater(eqs, variables, all_methods) is exponential_euler

eqs = Equations('dv/dt = sin(t) / (10*ms) : 1')
assert determine_stateupdater(eqs, variables) is independent
# # Do not test for now
# eqs = Equations('dv/dt = sin(t) / (10*ms) : 1')
# assert determine_stateupdater(eqs, variables) is independent

eqs = Equations('dv/dt = -sqrt(v) / (10*ms) : 1')
assert determine_stateupdater(eqs, variables) is euler
assert determine_stateupdater(eqs, variables, all_methods) is euler

eqs = Equations('dv/dt = -v / (10*ms) + 0.1*second**-.5*xi: 1')
assert determine_stateupdater(eqs, variables) is euler
assert determine_stateupdater(eqs, variables, all_methods) is euler

eqs = Equations('dv/dt = -v / (10*ms) + v*0.1*second**-.5*xi: 1')
assert determine_stateupdater(eqs, variables) is milstein

# remove all registered state updaters --> automatic choice no longer works
StateUpdateMethod.stateupdaters = {}
assert_raises(ValueError, lambda: determine_stateupdater(eqs, variables))

# reset to state before the test
StateUpdateMethod.stateupdaters = before
assert determine_stateupdater(eqs, variables, all_methods) is milstein

@attr('standalone-compatible')
@with_setup(teardown=restore_device)
Expand Down
7 changes: 0 additions & 7 deletions docs_sphinx/advanced/state_update.rst
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,6 @@ After creating the state updater, it has to be registered with
new_state_updater = ExplicitStateUpdater('...', stochastic='additive')
StateUpdateMethod.register('mymethod', new_state_updater)

The `StateUpdateMethod.register` method also takes an optional ``index``
argument, allowing you to insert the new state updater at an arbitrary
location in the list of state updaters (by default it gets added at the end).
The position in the list determines which state updater is chosen if more than
one is able to integrate the equations: If more than one choice is possible,
the state updater that comes first in the list is chosen.

The preferred way to do write new general state updaters (i.e. state updaters
that cannot be described using the explicit syntax described above) is to
extend the `StateUpdateMethod` class (but this is not strictly necessary, all
Expand Down
20 changes: 12 additions & 8 deletions docs_sphinx/user/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -158,14 +158,18 @@ Numerical integration
---------------------
Differential equations are converted into a sequence of statements that
integrate the equations numerically over a single time step. By default, Brian
choses an integration method automatically, trying to solve the equations
exactly first (which is possible for example for linear equations) and then
resorting to numerical algorithms (see below). It will also take care of integrating
stochastic differential equations appropriately. If you prefer to chose an
integration algorithm yourself, you can do so using the ``method`` keyword for
`NeuronGroup` or `Synapses`. The list of available methods is the following,
if no method is chosen explicitly Brian will try methods starting at the top
until it finds a method than can integrate the given equations:
chooses an integration method automatically, trying to solve the equations
exactly first (for linear equations) and then resorting to numerical algorithms.
It will also take care of integrating stochastic differential equations
appropriately. Each class defines its own list of algorithms it tries to
apply, `NeuronGroup` and `Synapses` will use the first suitable method out of
the methods ``'linear'``, ``'euler'``, and ``'milstein'`` while `SpatialNeuron`
objects will use ``'linear'``, ``'exponential_euler'``, ``'rk2'``, or
``'milstein'``.

If you prefer to chose an integration algorithm yourself, you can do so using
the ``method`` keyword for `NeuronGroup`, `Synapses`, or `SpatialNeuron`.
The complete list of available methods is the following:

* ``'linear'``: exact integration for linear equations
* ``'independent'``: exact integration for a system of independent equations,
Expand Down

0 comments on commit ee02896

Please sign in to comment.