Skip to content

Commit

Permalink
Allow specifying iter_cb for refine (instead of normal callback)
Browse files Browse the repository at this point in the history
  • Loading branch information
mstimberg committed Mar 5, 2020
1 parent 59ed8fa commit 5110c68
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 6 deletions.
25 changes: 20 additions & 5 deletions brian2modelfitting/fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from brian2.parsing.sympytools import sympy_to_str, str_to_sympy
from brian2.units.fundamentalunits import DIMENSIONLESS, get_dimensions
from numpy import ones, array, arange, concatenate, mean, argmin, nanmin, reshape, zeros
from brian2 import (NeuronGroup, defaultclock, get_device, Network,
from brian2 import (NeuronGroup, defaultclock, get_device, Network,
StateMonitor, SpikeMonitor, second, get_local_namespace,
Quantity)
Quantity, get_logger)
from brian2.input import TimedArray
from brian2.equations.equations import Equations
from brian2.devices import set_device, reset_device, device
Expand All @@ -18,6 +18,8 @@
from .utils import callback_setup, make_dic


logger = get_logger(__name__)

def get_param_dic(params, param_names, n_traces, n_samples):
"""Transform parameters into a dictionary of appropiate size"""
params = array(params)
Expand Down Expand Up @@ -634,7 +636,7 @@ def refine(self, params=None, t_start=None, normalization=None,
if normalization is None:
normalization = getattr(self.metric, 'normalization', 1.)

callback = callback_setup(callback, None)
callback_func = callback_setup(callback, None)

# Set up Parameter objects
parameters = lmfit.Parameters()
Expand Down Expand Up @@ -705,13 +707,26 @@ def _callback_wrapper(params, iter, resid, *args, **kwds):
best_idx = argmin(errors)
best_error = errors[best_idx]
best_params = tested_parameters[best_idx]
return callback(params, errors, best_params, best_error, iter)
return callback_func(params, errors, best_params, best_error, iter)

assert 'Dfun' not in kwds
if calc_gradient:
kwds.update({'Dfun': _calc_gradient})
if 'iter_cb' in kwds:
# Use the given callback but raise a warning if callback is not
# set to None
if callback is not None:
logger.warn('The iter_cb keyword has been specified together '
f'with callback={callback!r}. Only the iter_cb '
'callback will be used. Use the standard '
'callback mechanism or set callback=None to '
'remove this warning.',
name_suffix='iter_cb_callback')
iter_cb = kwds.pop('iter_cb')
else:
iter_cb = _callback_wrapper
result = lmfit.minimize(_calc_error, parameters,
iter_cb=_callback_wrapper,
iter_cb=iter_cb,
**kwds)

if needs_device_reset:
Expand Down
1 change: 0 additions & 1 deletion brian2modelfitting/tests/test_modelfitting_tracefitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,6 @@ def test_fitter_refine_tstart(setup_constant):
t_start=50*dt)

# Fit should be close to 20mV
print(params['c'])
assert np.abs(params['c']*volt - 20*mV) < 1*mV


Expand Down

0 comments on commit 5110c68

Please sign in to comment.