Remove unnecessary sensitivity equations
mstimberg committed Mar 6, 2020
1 parent 75df433 commit 4b8c446
Showing 1 changed file with 70 additions and 6 deletions.
76 changes: 70 additions & 6 deletions brian2modelfitting/fitter.py
@@ -81,7 +81,33 @@ def setup_fit():
     return simulators[get_device().__class__.__name__]
 
 
-def get_sensitivity_equations(group, parameters, namespace=None, level=1):
+def get_sensitivity_equations(group, parameters, namespace=None, level=1,
+                              optimize=True):
+    """
+    Get equations for sensitivity variables.
+
+    Parameters
+    ----------
+    group : `NeuronGroup`
+        The group of neurons that will be simulated.
+    parameters : list of str
+        Names of the parameters that are fit.
+    namespace : dict, optional
+        The namespace to use.
+    level : `int`, optional
+        How much farther to go down in the stack to find the namespace.
+    optimize : bool, optional
+        Whether to remove sensitivity variables from the equations that do
+        not evolve if initialized to zero (e.g. ``dS_x_y/dt = -S_x_y/tau``
+        would be removed). This avoids unnecessary computation but will fail
+        in the rare case that such a sensitivity variable needs to be
+        initialized to a non-zero value. Defaults to ``True``.
+
+    Returns
+    -------
+    sensitivity_eqs : `Equations`
+        The equations for the sensitivity variables.
+    """
     if namespace is None:
         namespace = get_local_namespace(level)
     namespace.update(group.namespace)
@@ -113,15 +139,26 @@ def get_sensitivity_equations(group, parameters, namespace=None, level=1):
         else:
             raise AssertionError(f'Parameter {param} neither in namespace nor variables')
         unit = repr(unit) if not unit.is_dimensionless else '1'
-        new_eqs.append('d{lhs}/dt = {rhs} : {unit}'.format(lhs=name,
+        if optimize:
+            # Check if the equation stays at zero if initialized at zero
+            zeroed = eq.subs(name, sympy.S.Zero)
+            if zeroed == sympy.S.Zero:
+                print(f'removing {sympy_to_str(name)} from equations')
+                # No need to include equation as differential equation
+                if unit == '1':
+                    new_eqs.append(f'{sympy_to_str(name)} = 0 : {unit}')
+                else:
+                    new_eqs.append(f'{sympy_to_str(name)} = 0*{unit} : {unit}')
+                continue
+        new_eqs.append('d{lhs}/dt = {rhs} : {unit}'.format(lhs=sympy_to_str(name),
                                                            rhs=sympy_to_str(eq),
                                                            unit=unit))
     new_eqs = Equations('\n'.join(new_eqs))
     return new_eqs
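The ``optimize`` branch above drops a sensitivity equation exactly when its right-hand side vanishes once the sensitivity variable is set to zero. A minimal standalone sketch of that test, using plain sympy with invented names:

    import sympy

    S_x_y, tau, v = sympy.symbols('S_x_y tau v')

    # dS_x_y/dt = -S_x_y/tau: the right-hand side vanishes when S_x_y is
    # zero, so a zero-initialized sensitivity never moves and the ODE can
    # be replaced by the constant subexpression 'S_x_y = 0'.
    rhs = -S_x_y / tau
    assert rhs.subs(S_x_y, sympy.S.Zero) == sympy.S.Zero

    # dS_x_y/dt = (v - S_x_y)/tau: the right-hand side does not vanish at
    # zero, so the differential equation has to be kept.
    rhs = (v - S_x_y) / tau
    assert rhs.subs(S_x_y, sympy.S.Zero) != sympy.S.Zero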


 def get_sensitivity_init(group, parameters, param_init):
-    '''
+    """
     Calculate the initial values for the sensitivity parameters (necessary if
     initial values are functions of parameters).
@@ -139,7 +176,7 @@ def get_sensitivity_init(group, parameters, param_init):
     sensitivity_init : dict
         Dictionary of expressions to initialize the sensitivity
         parameters.
-    '''
+    """
     sensitivity_dict = {}
     for var_name, expr in param_init.items():
         if not isinstance(expr, str):
@@ -156,6 +193,15 @@ def get_sensitivity_init(group, parameters, param_init):
         for parameter in parameters:
             diffed = sympy_expr.diff(str_to_sympy(parameter))
             if diffed != sympy.S.Zero:
+                if getattr(group.variables[f'S_{var_name}_{parameter}'],
+                           'type', None) == SUBEXPRESSION:
+                    raise NotImplementedError('Sensitivity '
+                                              f'S_{var_name}_{parameter} '
+                                              'is initialized to a non-zero '
+                                              'value, but it has been '
+                                              'removed from the equations. '
+                                              'Set optimize=False to avoid '
+                                              'this.')
                 init_expr = sympy_to_str(diffed)
                 sensitivity_dict[f'S_{var_name}_{parameter}'] = init_expr
     return sensitivity_dict
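For string-valued initial values, the derivative computed here becomes the initial value of the corresponding sensitivity variable. A small sketch of that core step (variable and parameter names are hypothetical):

    import sympy
    from brian2.parsing.sympytools import str_to_sympy, sympy_to_str

    # If v is initialized to an expression that involves the fitted
    # parameter g_max, then S_v_g_max has to start at d(init_expr)/d(g_max).
    param_init = {'v': 'g_max/2 + 3'}
    parameters = ['g_max']

    sensitivity_dict = {}
    for var_name, expr in param_init.items():
        sympy_expr = str_to_sympy(expr)
        for parameter in parameters:
            diffed = sympy_expr.diff(str_to_sympy(parameter))
            if diffed != sympy.S.Zero:
                sensitivity_dict[f'S_{var_name}_{parameter}'] = sympy_to_str(diffed)

    print(sensitivity_dict)  # {'S_v_g_max': '1/2'} or an equivalent form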
@@ -270,7 +316,7 @@ def __init__(self, dt, model, input, output, input_var, output_var,
         self.param_init = param_init
 
     def setup_neuron_group(self, n_neurons, namespace, calc_gradient=False,
-                           name='neurons'):
+                           optimize=True, name='neurons'):
         """
         Setup neuron group, initialize required number of neurons, create
         namespace and initialize the parameters.
@@ -295,6 +341,7 @@ def setup_neuron_group(self, n_neurons, namespace, calc_gradient=False,
         if calc_gradient:
             sensitivity_eqs = get_sensitivity_equations(neurons,
                                                         parameters=self.parameter_names,
+                                                        optimize=optimize,
                                                         namespace=namespace)
             neurons = NeuronGroup(n_neurons, self.model + sensitivity_eqs,
                                   method=self.method,
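To see what the gradient path assembles here, a sketch with an invented single-variable model (assuming get_sensitivity_equations is importable from brian2modelfitting.fitter; the derived equations depend on the model):

    from brian2 import NeuronGroup, ms
    from brian2modelfitting.fitter import get_sensitivity_equations

    tau = 10*ms
    group = NeuronGroup(1, 'dv/dt = (g - v)/tau : 1\ng : 1 (constant)')
    sensitivity_eqs = get_sensitivity_equations(group, parameters=['g'])
    # For this model the sensitivity of v to g evolves away from zero,
    # dS_v_g/dt = (1 - S_v_g)/tau : 1, so its equation is kept; a
    # trivially-zero sensitivity would instead be replaced by a
    # subexpression such as 'S_x_y = 0 : 1' when optimize=True.
    full_model = group.equations + sensitivity_eqs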
@@ -606,7 +653,8 @@ def generate_traces(self, params=None, param_init=None, level=0):
         return fits
 
     def refine(self, params=None, t_start=None, normalization=None,
-               callback='text', calc_gradient=False, level=0, **kwds):
+               callback='text', calc_gradient=False, optimize=True,
+               level=0, **kwds):
         """
         Refine the fitting results with a sequentially operating minimization
         algorithm. Uses the `lmfit <https://lmfit.github.io/lmfit-py/>`_
@@ -639,6 +687,21 @@ def refine(self, params=None, t_start=None, normalization=None,
             ``func(parameters, errors, best_parameters, best_error, index)``.
             If this function returns ``True`` the fitting execution is
             interrupted.
+        calc_gradient : bool, optional
+            Whether to add "sensitivity variables" to the equations that track
+            the sensitivity of the equation variables to the parameters. This
+            information will be used to pass the local gradient of the error
+            with respect to the parameters to the optimization function. This
+            can lead to much faster convergence than with an estimated gradient
+            but comes at the expense of additional computation. Defaults to
+            ``False``.
+        optimize : bool, optional
+            Whether to remove sensitivity variables from the equations that do
+            not evolve if initialized to zero (e.g. ``dS_x_y/dt = -S_x_y/tau``
+            would be removed). This avoids unnecessary computation but will fail
+            in the rare case that such a sensitivity variable needs to be
+            initialized to a non-zero value. Only taken into account if
+            ``calc_gradient`` is ``True``. Defaults to ``True``.
         level : int, optional
             How much farther to go down in the stack to find the namespace.
         kwds
@@ -712,6 +775,7 @@ def refine(self, params=None, t_start=None, normalization=None,
                                                  level=level+1)
         neurons = self.setup_neuron_group(self.n_traces, namespace,
                                           calc_gradient=calc_gradient,
+                                          optimize=optimize,
                                           name='neurons')
         monitored_variables = [self.output_var]
         param_init = dict(self.param_init)
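A hypothetical call showing how the two keywords interact (assuming an already-constructed TraceFitter instance named fitter):

    # Pass analytic gradients to lmfit, but keep every sensitivity ODE in
    # case one of them needs a non-zero initial value.
    refined_params, result = fitter.refine(calc_gradient=True, optimize=False)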
