Add callback to TraceFitter.refine
mstimberg committed Feb 19, 2020
1 parent 664abbf commit 0fb9779
Showing 2 changed files with 26 additions and 5 deletions.
28 changes: 25 additions & 3 deletions brian2modelfitting/fitter.py
@@ -2,7 +2,7 @@
 import numbers
 
 from brian2.units.fundamentalunits import DIMENSIONLESS, get_dimensions
-from numpy import ones, array, arange, concatenate, mean, nanmin, reshape
+from numpy import ones, array, arange, concatenate, mean, argmin, nanmin, reshape
 from brian2 import (NeuronGroup, defaultclock, get_device, Network,
                     StateMonitor, SpikeMonitor, second, get_local_namespace,
                     Quantity)
@@ -509,7 +509,8 @@ def generate_traces(self, params=None, param_init=None, level=0):
                                 param_init=param_init, level=level+1)
         return fits
 
-    def refine(self, params=None, t_start=None, normalization=None, level=0, **kwds):
+    def refine(self, params=None, t_start=None, normalization=None,
+               callback='text', level=0, **kwds):
         """
         Refine the fitting results with a sequentially operating minimization
         algorithm. Uses the `lmfit <https://lmfit.github.io/lmfit-py/>`_
@@ -535,6 +536,12 @@ def refine(self, params=None, t_start=None, normalization=None, level=0, **kwds)
             the size of steps in the parameter space depends on the absolute
             value of the error. If not set, will reuse the `normalization` value
             from the previously used metric.
+        callback: `str` or `~typing.Callable`
+            Either the name of a provided callback function (``text`` or
+            ``progressbar``), or a custom feedback function
+            ``func(parameters, errors, best_parameters, best_error, index)``.
+            If this function returns ``True`` the fitting execution is
+            interrupted.
         level : int, optional
             How much farther to go down in the stack to find the namespace.
         kwds
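
To illustrate the new argument, a user-supplied callback matching the documented signature could stop the refinement once the error is small enough. This is a hypothetical usage sketch: `fitter` stands for an already set-up and previously fitted `TraceFitter` instance and the error threshold is arbitrary; neither is part of this commit.

def stop_when_good_enough(parameters, errors, best_parameters, best_error, index):
    # Called once per optimization iteration with the current parameter values,
    # the list of errors so far, the best parameters/error so far, and the
    # iteration index.
    print(f"iteration {index}: best error so far = {best_error:.3g}")
    # Returning True interrupts the refinement, as documented above.
    return best_error < 1e-3

fitter.refine(callback=stop_when_good_enough)
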
@@ -577,6 +584,8 @@ def refine(self, params=None, t_start=None, normalization=None, level=0, **kwds)
         if normalization is None:
             normalization = getattr(self.metric, 'normalization', 1.)
 
+        callback = callback_setup(callback, None)
+
         # Set up Parameter objects
         parameters = lmfit.Parameters()
         for param_name in self.parameter_names:
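
The `callback_setup` helper resolves the string shortcuts or a user-supplied callable into a single callable. A rough sketch of that kind of dispatch, assuming a simple text printer and a tqdm-backed progress bar (illustrative only, not the package's actual implementation):

def resolve_callback(callback, n_samples):
    # 'text': print a short status line per iteration
    if callback == 'text':
        def text_callback(params, errors, best_params, best_error, index):
            print(f"Iteration {index}: best error {best_error}")
        return text_callback
    # 'progressbar': advance a tqdm bar per iteration
    if callback == 'progressbar':
        import tqdm
        bar = tqdm.tqdm(total=n_samples)
        return lambda params, errors, best_params, best_error, index: bar.update(1)
    # None: do nothing
    if callback is None:
        return lambda *args, **kwds: None
    # anything else is assumed to be a user-provided callable
    return callback
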
@@ -621,7 +630,20 @@ def _calc_error(params):
             residual = trace[:, t_start_steps:] - self.output[:, t_start_steps:]
             return residual.flatten() * normalization
 
-        result = lmfit.minimize(_calc_error, parameters, **kwds)
+        tested_parameters = []
+        errors = []
+        def _callback_wrapper(params, iter, resid, *args, **kwds):
+            error = mean(resid**2)
+            params = {p: float(val) for p, val in params.items()}
+            tested_parameters.append(params)
+            errors.append(error)
+            best_idx = argmin(errors)
+            best_error = errors[best_idx]
+            best_params = tested_parameters[best_idx]
+            return callback(params, errors, best_params, best_error, iter)
+
+        result = lmfit.minimize(_calc_error, parameters,
+                                iter_cb=_callback_wrapper, **kwds)
 
         if needs_device_reset:
             reset_device()
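
The wrapper above hooks into lmfit's per-iteration callback mechanism: `lmfit.minimize` accepts an `iter_cb` function that is called as ``iter_cb(params, iter, resid, *args, **kws)`` after each iteration, and the fit is aborted if it returns ``True``. A self-contained sketch of that mechanism outside the fitter, with toy data and names chosen purely for illustration:

import numpy as np
import lmfit

# Toy data: y = a * exp(-x / b) plus noise.
x = np.linspace(0, 10, 100)
y = 2.5 * np.exp(-x / 3.0) + np.random.normal(0, 0.05, x.size)

def residual(params, x, y):
    v = params.valuesdict()
    return v['a'] * np.exp(-x / v['b']) - y

def iteration_callback(params, iter, resid, *args, **kwds):
    # Mirrors _callback_wrapper: compute the error of the current iteration.
    error = np.mean(resid**2)
    print(f"iteration {iter}: mean squared residual = {error:.4g}")
    # Returning True aborts the fit early.
    return error < 1e-4

params = lmfit.Parameters()
params.add('a', value=1.0, min=0)
params.add('b', value=1.0, min=0)
result = lmfit.minimize(residual, params, args=(x, y), iter_cb=iteration_callback)
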
3 changes: 1 addition & 2 deletions brian2modelfitting/tests/test_utils.py
@@ -20,8 +20,7 @@ def test_callback_none():
 
 
 def test_ProgressBar():
-    pb = ProgressBar(toolbar_width=10)
-    assert_equal(pb.toolbar_width, 10)
+    pb = ProgressBar(total=10)
     assert isinstance(pb.t, tqdm.tqdm)
     pb([1, 2, 3], [1.2, 2.3, 0.1], {'a':3}, 0.1, 2)
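
The updated test builds ``ProgressBar`` with a ``total`` argument, expects a ``tqdm`` instance in its ``t`` attribute, and then calls it with the five callback arguments. A minimal sketch of such a tqdm-backed callback class, consistent with the test but not the package's actual code:

import tqdm

class TqdmCallbackSketch:
    """Progress callback that advances a tqdm bar once per reported iteration."""
    def __init__(self, total=None, **kwds):
        self.t = tqdm.tqdm(total=total, **kwds)

    def __call__(self, params, errors, best_params, best_error, index):
        self.t.update(1)

pb = TqdmCallbackSketch(total=10)
pb([1, 2, 3], [1.2, 2.3, 0.1], {'a': 3}, 0.1, 2)
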

