WIP: Optionally use gradients in TraceFitter.refine
mstimberg committed Feb 27, 2020
1 parent adad721 commit c52c168
Showing 1 changed file with 74 additions and 6 deletions.
brian2modelfitting/fitter.py
@@ -1,8 +1,9 @@
import abc
import numbers

from brian2.parsing.sympytools import sympy_to_str
from brian2.units.fundamentalunits import DIMENSIONLESS, get_dimensions
from numpy import ones, array, arange, concatenate, mean, argmin, nanmin, reshape
from numpy import ones, array, arange, concatenate, mean, argmin, nanmin, reshape, zeros
from brian2 import (NeuronGroup, defaultclock, get_device, Network,
StateMonitor, SpikeMonitor, second, get_local_namespace,
Quantity)
@@ -74,6 +75,45 @@ def setup_fit():
return simulators[get_device().__class__.__name__]


def get_sensitivity_equations(group, parameters, namespace=None, level=1):
    """
    Construct forward-sensitivity equations for the given parameters:
    one differential equation dS_x_p/dt = (J*S + F)_x per state variable
    x and parameter p, where S_x_p = dx/dp, J is the Jacobian of the
    system with respect to its state variables, and F is its derivative
    with respect to the parameter.
    """
    from sympy import Matrix
if namespace is None:
namespace = get_local_namespace(level)
namespace.update(group.namespace)

eqs = group.equations
diff_eqs = eqs.get_substituted_expressions(group.variables)
diff_eq_names = [name for name, _ in diff_eqs]

system = Matrix([diff_eq[1] for diff_eq in diff_eqs])
J = system.jacobian(diff_eq_names)

sensitivity = []
sensitivity_names = []
for parameter in parameters:
F = system.jacobian([parameter])
names = ['S_{}_{}'.format(diff_eq_name, parameter)
for diff_eq_name in diff_eq_names]
sensitivity.append(J * Matrix(names) + F)
sensitivity_names.append(names)

new_eqs = []
for names, sensitivity_eqs, param in zip(sensitivity_names, sensitivity, parameters):
for name, eq, orig_var in zip(names, sensitivity_eqs, diff_eq_names):
if param in namespace:
unit = eqs[orig_var].dim / namespace[param].dim
elif param in group.variables:
unit = eqs[orig_var].dim / group.variables[param].dim
else:
raise AssertionError(f'Parameter {param} neither in namespace nor variables')
unit = repr(unit) if not unit.is_dimensionless else '1'
new_eqs.append('d{lhs}/dt = {rhs} : {unit}'.format(lhs=name,
rhs=sympy_to_str(eq),
unit=unit))
new_eqs = Equations('\n'.join(new_eqs))
return new_eqs
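
The helper above implements forward sensitivity analysis: for a system dx/dt = f(x, theta), the sensitivity S = dx/dtheta obeys dS/dt = J*S + F, with J = df/dx the Jacobian with respect to the state and F = df/dtheta the derivative with respect to the parameter. A minimal standalone sketch of the same construction in plain SymPy, for the toy model dv/dt = -v/tau (the model and all names here are illustrative, not taken from this commit):

    from sympy import Matrix, symbols

    v, tau = symbols('v tau')
    system = Matrix([-v / tau])         # dv/dt = -v/tau
    J = system.jacobian([v])            # Matrix([[-1/tau]])
    F = system.jacobian([tau])          # Matrix([[v/tau**2]])
    S = Matrix([symbols('S_v_tau')])    # S_v_tau = dv/dtau
    dS_dt = J * S + F                   # Matrix([[-S_v_tau/tau + v/tau**2]])

Integrating dS_v_tau/dt = -S_v_tau/tau + v/tau**2 alongside the model therefore yields the derivative of the simulated trace with respect to tau at every time step.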


class Fitter(metaclass=abc.ABCMeta):
"""
Base Fitter class for model fitting applications.
@@ -182,7 +222,8 @@ def __init__(self, dt, model, input, output, input_var, output_var,
"parameter in the model" % param)
self.param_init = param_init

def setup_neuron_group(self, n_neurons, namespace, name='neurons'):
def setup_neuron_group(self, n_neurons, namespace, calc_gradient=False,
name='neurons'):
"""
        Set up the neuron group: initialize the required number of neurons,
        create the namespace, and initialize the parameters.
@@ -204,7 +245,15 @@ def setup_neuron_group(self, n_neurons, namespace, name='neurons'):
threshold=self.threshold, reset=self.reset,
refractory=self.refractory, name=name,
namespace=namespace)

if calc_gradient:
sensitivity_eqs = get_sensitivity_equations(neurons,
parameters=self.parameter_names,
namespace=namespace)
neurons = NeuronGroup(n_neurons, self.model + sensitivity_eqs,
method=self.method,
threshold=self.threshold, reset=self.reset,
refractory=self.refractory, name=name,
namespace=namespace)
return neurons
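
Building the group twice is deliberate: the first NeuronGroup is only needed so that get_sensitivity_equations can take the symbolic Jacobian over the fully substituted equations; the second one integrates the model and its sensitivities together. A hand-written sketch of what the augmented equations could look like for the same toy model (assuming a dimensionless v, so the sensitivity with respect to tau carries units of Hz):

    from brian2 import Equations

    model = Equations('dv/dt = -v / tau : 1')
    # Roughly what get_sensitivity_equations would generate for tau:
    sensitivity = Equations('dS_v_tau/dt = -S_v_tau/tau + v/tau**2 : Hz')
    augmented = model + sensitivity     # equations for the second group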

@abc.abstractmethod
@@ -510,7 +559,7 @@ def generate_traces(self, params=None, param_init=None, level=0):
return fits

def refine(self, params=None, t_start=None, normalization=None,
callback='text', level=0, **kwds):
callback='text', calc_gradient=False, level=0, **kwds):
"""
Refine the fitting results with a sequentially operating minimization
algorithm. Uses the `lmfit <https://lmfit.github.io/lmfit-py/>`_
@@ -612,8 +661,13 @@ def refine(self, params=None, t_start=None, normalization=None,
'output_var': self.output_var},
level=level+1)
neurons = self.setup_neuron_group(self.n_traces, namespace,
calc_gradient=calc_gradient,
name='neurons')
monitor = StateMonitor(neurons, self.output_var, record=True,
monitored_variables = [self.output_var]
if calc_gradient:
monitored_variables += [f'S_{self.output_var}_{p}'
for p in self.parameter_names]
monitor = StateMonitor(neurons, monitored_variables, record=True,
name='monitor')
network = Network(neurons, monitor)
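
With calc_gradient enabled, the monitor records the fitted output variable plus one sensitivity variable per parameter. A small illustration (variable and parameter names are made up):

    # For output_var='v' and parameter_names=('gl', 'tau'):
    monitored = ['v'] + [f'S_v_{p}' for p in ('gl', 'tau')]
    # -> ['v', 'S_v_gl', 'S_v_tau']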

@@ -630,6 +684,16 @@ def _calc_error(params):
residual = trace[:, t_start_steps:] - self.output[:, t_start_steps:]
return residual.flatten() * normalization

        def _calc_gradient(params):
            gradients = []
            for name in self.parameter_names:
                # The sensitivity trace is already the derivative of the
                # residual with respect to the parameter, since the
                # recorded target trace does not depend on the parameters.
                trace = getattr(simulator.networks['refine']['monitor'],
                                f'S_{self.output_var}_{name}_')
                gradients.append(trace[:, t_start_steps:].flatten()
                                 * normalization)
            gradient = array(gradients)
            return gradient.T
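
The array returned by _calc_gradient is the Jacobian of the residual vector: one row per flattened residual sample and one column per parameter, which is the layout scipy.optimize.leastsq expects from Dfun with its default col_deriv=0. A shape-only sketch with arbitrary numbers:

    import numpy as np

    n_traces, n_steps, n_params = 2, 5, 3
    columns = [np.zeros(n_traces * n_steps) for _ in range(n_params)]
    jacobian = np.array(columns).T      # one column per parameter
    assert jacobian.shape == (n_traces * n_steps, n_params)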

tested_parameters = []
errors = []
def _callback_wrapper(params, iter, resid, *args, **kwds):
@@ -642,8 +706,12 @@ def _callback_wrapper(params, iter, resid, *args, **kwds):
best_params = tested_parameters[best_idx]
return callback(params, errors, best_params, best_error, iter)

assert 'Dfun' not in kwds
if calc_gradient:
kwds.update({'Dfun': _calc_gradient})
result = lmfit.minimize(_calc_error, parameters,
iter_cb=_callback_wrapper, **kwds)
iter_cb=_callback_wrapper,
**kwds)

if needs_device_reset:
reset_device()
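
A hypothetical usage of the new option; fitter stands for an already constructed TraceFitter, and lmfit's default leastsq backend forwards extra keyword arguments such as Dfun on to scipy.optimize.leastsq:

    # Hypothetical call; the return values follow TraceFitter.refine's
    # existing interface, which this diff does not show.
    refined_params, lmfit_result = fitter.refine(calc_gradient=True)

With calc_gradient=True the minimizer obtains the Jacobian from the integrated sensitivities instead of estimating it by finite differences, saving the extra simulation runs a numerical approximation would need.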