Commit e8945d5

Merge 46105c0 into ed6e426
mstimberg committed Feb 17, 2020
2 parents ed6e426 + 46105c0 commit e8945d5
Showing 14 changed files with 322 additions and 21 deletions.
5 changes: 2 additions & 3 deletions .travis.yml
@@ -2,12 +2,11 @@ dist: xenial
language: python
python:
- "3.6"
- "3.7"
- "3.8"

# command to install dependencies
install:
- pip install pytest-coverage
- pip install coveralls
- pip install pytest-coverage coveralls lmfit
- pip install -r requirements.txt
- pip install .

115 changes: 115 additions & 0 deletions brian2modelfitting/fitter.py
@@ -451,6 +451,9 @@ def __init__(self, model, input_var, input, output_var, output, dt,
super().__init__(dt, model, input, output, input_var, output_var,
n_samples, threshold, reset, refractory, method,
param_init)
# We store the bounds set in TraceFitter.fit, so that TraceFitter.refine
# can reuse them if no bounds are given explicitly
self.bounds = None

if output_var not in self.model.names:
raise NameError("%s is not a model variable" % output_var)
@@ -491,6 +494,7 @@ def fit(self, optimizer, metric=None, n_rounds=1, callback='text',
if not isinstance(metric, TraceMetric):
raise TypeError("You can only use TraceMetric child metric with "
"TraceFitter")
self.bounds = dict(params)
self.best_params, error = super().fit(optimizer, metric, n_rounds,
callback, restart, **params)
return self.best_params, error
@@ -501,6 +505,117 @@ def generate_traces(self, params=None, param_init=None, level=0):
param_init=param_init, level=level+1)
return fits

def refine(self, params=None, t_start=None, level=0, **kwds):
"""
Refine the fitting results with a sequentially operating minimization
algorithm. Uses the `lmfit <https://lmfit.github.io/lmfit-py/>`_
package which itself makes use of
`scipy.optimize <https://docs.scipy.org/doc/scipy/reference/optimize.html>`_.
Has to be called after `~.TraceFitter.fit`, but a call with
``n_rounds=0`` is enough.

Parameters
----------
params : dict, optional
A dictionary with the parameters to use as a starting point for the
refinement. If not given, the best parameters found so far by
`~.TraceFitter.fit` will be used.
t_start : `~brian2.units.fundamentalunits.Quantity`, optional
Initial simulation/model time that should be ignored for the error
calculation. If the metric used in previous fits has a `t_start`
option (this is the case for `MSEMetric`), this will be reused if
not otherwise set. If no such metric has been used previously,
the default is 0s.
level : int, optional
How much farther to go down in the stack to find the namespace.
kwds
Additional arguments can overwrite the bounds for individual
parameters (if not given, the bounds previously specified in the
call to `~.TraceFitter.fit` will be used). All other arguments will
be passed on to `.lmfit.minimize` and can be used to e.g. change the
method, or to specify method-specific arguments.

Returns
-------
parameters : dict
The parameters at the end of the optimization process as a
dictionary.
result : `.lmfit.MinimizerResult`
The result of the optimization process.

Notes
-----
The default method used by `lmfit` is least-squares minimization using
a Levenberg-Marquardt method. Note that there is no support for
specifying a `Metric`: the given output trace(s) will be subtracted
from the simulated trace(s) and passed on to the minimization algorithm.
This method always uses the runtime mode, independent of the currently
selected device.
"""
try:
import lmfit
except ImportError:
raise ImportError('Refinement needs the "lmfit" package.')
if params is None:
if self.best_params is None:
raise TypeError('You need to either specify parameters or run '
'the fit function first.')
params = self.best_params

if t_start is None:
t_start = getattr(self.metric, 't_start', 0*second)

# Set up Parameter objects
parameters = lmfit.Parameters()
for param_name in self.parameter_names:
if param_name not in kwds:
if self.bounds is None:
raise TypeError('You need to either specify bounds for all '
'parameters or run the fit function first.')
min_bound, max_bound = self.bounds[param_name]
else:
min_bound, max_bound = kwds.pop(param_name)
parameters.add(param_name, value=array(params[param_name]),
min=array(min_bound), max=array(max_bound))

needs_device_reset = False
if isinstance(get_device(), CPPStandaloneDevice):
set_device('runtime')
simulator = RuntimeSimulator()
needs_device_reset = True
else:
simulator = self.simulator

namespace = get_full_namespace({'input_var': self.input_traces,
'n_traces': self.n_traces,
'output_var': self.output_var},
level=level+1)
neurons = self.setup_neuron_group(self.n_traces, namespace,
name='neurons')
monitor = StateMonitor(neurons, self.output_var, record=True,
name='monitor')
network = Network(neurons, monitor)

simulator.initialize(network, self.param_init, name='refine')

t_start_steps = int(round(t_start / self.dt))

def _calc_error(params):
simulator.run(self.duration, {p: float(val)
for p, val in params.items()},
self.parameter_names, name='refine')
trace = getattr(simulator.networks['refine']['monitor'],
self.output_var+'_')
return (trace[:, t_start_steps:] - self.output[:, t_start_steps:]).flatten()

result = lmfit.minimize(_calc_error, parameters, **kwds)

if needs_device_reset:
reset_device()

return {p: float(val) for p, val in result.params.items()}, result


class SpikeFitter(Fitter):
def __init__(self, model, input, output, dt, reset, threshold,
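For orientation, a minimal usage sketch of the new two-stage workflow (a global fit followed by refine). This is illustrative rather than part of the diff; the model, data and bounds mirror the constant-model test fixture added in this commit, and the top-level imports from brian2modelfitting are assumed:

import numpy as np
from brian2 import Equations, mV, ms
from brian2modelfitting import TraceFitter, NevergradOptimizer, MSEMetric

dt = 0.1*ms
# Target trace: 10 mV for the first 50 steps, then 20 mV (as in setup_constant)
out_trace = np.hstack([np.ones(50) * 10 * mV, np.ones(50) * 20 * mV])
fitter = TraceFitter(dt=dt,
                     model=Equations('''v = c + x : volt
                                        c : volt (constant)'''),
                     input_var='x', output_var='v',
                     input=(np.zeros(100)*mV)[None, :],
                     output=out_trace[None, :],
                     n_samples=100)

# Stage 1: global search with nevergrad, ignoring the initial 10 mV segment
params, error = fitter.fit(n_rounds=10, optimizer=NevergradOptimizer(),
                           metric=MSEMetric(t_start=50*dt),
                           c=[0*mV, 30*mV])

# Stage 2: local least-squares refinement via lmfit;
# the bounds and t_start from the previous fit() call are reused
refined_params, result = fitter.refine()
# Method-specific options are forwarded to lmfit.minimize
refined_params, result = fitter.refine(method='least_squares')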
136 changes: 132 additions & 4 deletions brian2modelfitting/tests/test_modelfitting_tracefitter.py
@@ -4,6 +4,10 @@
import pytest
import numpy as np
import pandas as pd
try:
import lmfit
except ImportError:
lmfit = None
from numpy.testing.utils import assert_equal
from brian2 import (zeros, Equations, NeuronGroup, StateMonitor, TimedArray,
nS, mV, volt, ms, Quantity, set_device, get_device, Network)
@@ -32,11 +36,15 @@
g : siemens (constant)
'''

constant_model = Equations('''
v = c + x: volt
c : volt (constant)''')

n_opt = NevergradOptimizer()
metric = MSEMetric()


@pytest.fixture()
@pytest.fixture
def setup(request):
dt = 0.01 * ms
tf = TraceFitter(dt=dt,
@@ -54,7 +62,27 @@ def fin():
return dt, tf


@pytest.fixture()
@pytest.fixture
def setup_constant(request):
dt = 0.1 * ms
# Membrane potential is constant at 10mV for first 50 steps, then at 20mV
out_trace = np.hstack([np.ones(50) * 10 * mV, np.ones(50) * 20 * mV])
tf = TraceFitter(dt=dt,
model=constant_model,
input_var='x',
output_var='v',
input=(np.zeros(100)*mV)[None, :],
output=out_trace[None, :],
n_samples=100,)

def fin():
reinit_devices()
request.addfinalizer(fin)

return dt, tf


@pytest.fixture
def setup_online(request):
dt = 0.01 * ms

@@ -73,8 +101,10 @@ def fin():
return dt, otf


@pytest.fixture()
@pytest.fixture
def setup_standalone(request):
# Workaround to avoid issues with Network instances still around
Network.__instances__().clear()
set_device('cpp_standalone', directory=None)
dt = 0.01 * ms
tf = TraceFitter(dt=dt,
@@ -159,7 +189,8 @@ def test_fitter_fit(setup):
results, errors = tf.fit(n_rounds=2,
optimizer=n_opt,
metric=metric,
g=[1*nS, 30*nS])
g=[1*nS, 30*nS],
callback=None)

attr_fit = ['optimizer', 'metric', 'best_params']
for attr in attr_fit:
@@ -191,6 +222,103 @@ def test_fitter_fit_errors(setup):
g=[1*nS, 30*nS])


def test_fitter_fit_tstart(setup_constant):
dt, tf = setup_constant

# Ignore the first 50 steps at 10mV
params, result = tf.fit(n_rounds=10, optimizer=n_opt,
metric=MSEMetric(t_start=50*dt),
c=[0 * mV, 30 * mV])
# Fit should be close to 20mV
assert np.abs(params['c']*volt - 20*mV) < 1*mV

@pytest.mark.skipif(lmfit is None, reason="needs lmfit package")
def test_fitter_refine(setup):
dt, tf = setup
results, errors = tf.fit(n_rounds=2,
optimizer=n_opt,
metric=metric,
g=[1*nS, 30*nS],
callback=None)
# Run refine after running fit
params, result = tf.refine()
assert result.method == 'leastsq'
assert isinstance(params, dict)
assert isinstance(result, lmfit.minimizer.MinimizerResult)

# Pass options to lmfit.minimize
params, result = tf.refine(method='least_squares')
assert result.method == 'least_squares'


@pytest.mark.skipif(lmfit is None, reason="needs lmfit package")
def test_fitter_refine_standalone(setup_standalone):
dt, tf = setup_standalone
results, errors = tf.fit(n_rounds=2,
optimizer=n_opt,
metric=metric,
g=[1*nS, 30*nS],
callback=None)
# Run refine after running fit
params, result = tf.refine()
assert result.method == 'leastsq'
assert isinstance(params, dict)
assert isinstance(result, lmfit.minimizer.MinimizerResult)

# Pass options to lmfit.minimize
params, result = tf.refine(method='least_squares')
assert result.method == 'least_squares'


@pytest.mark.skipif(lmfit is None, reason="needs lmfit package")
def test_fitter_refine_direct(setup):
dt, tf = setup
# Run refine without running fit before
params, result = tf.refine({'g': 5*nS}, g=[1*nS, 30*nS])
assert isinstance(params, dict)
assert isinstance(result, lmfit.minimizer.MinimizerResult)


@pytest.mark.skipif(lmfit is None, reason="needs lmfit package")
def test_fitter_refine_tstart(setup_constant):
dt, tf = setup_constant

# Ignore the first 50 steps at 10mV
params, result = tf.refine({'c': 5*mV}, c=[0 * mV, 30 * mV],
t_start=50*dt)

# Fit should be close to 20mV
print(params['c'])
assert np.abs(params['c']*volt - 20*mV) < 1*mV


@pytest.mark.skipif(lmfit is None, reason="needs lmfit package")
def test_fitter_refine_reuse_tstart(setup_constant):
dt, tf = setup_constant

# Ignore the first 50 steps at 10mV but do not actually fit (0 rounds)
params, result = tf.fit(n_rounds=0, optimizer=n_opt,
metric=MSEMetric(t_start=50*dt),
c=[0 * mV, 30 * mV])
# t_start should be reused
params, result = tf.refine({'c': 5 * mV}, c=[0 * mV, 30 * mV])

# Fit should be close to 20mV
assert np.abs(params['c'] * volt - 20 * mV) < 1 * mV


@pytest.mark.skipif(lmfit is None, reason="needs lmfit package")
def test_fitter_refine_errors(setup):
dt, tf = setup
with pytest.raises(TypeError):
# Missing start parameter
tf.refine(g=[1*nS, 30*nS])

with pytest.raises(TypeError):
# Missing bounds
tf.refine({'g': 5*nS})


def test_fit_restart(setup):
dt, tf = setup
results, errors = tf.fit(n_rounds=2,
2 changes: 2 additions & 0 deletions brian2modelfitting/tests/test_simulation_standalone.py
@@ -42,6 +42,8 @@ def fin():

@pytest.fixture()
def setup_standalone(request):
# Workaround to avoid issues with Network instances still around
Network.__instances__().clear()
set_device('cpp_standalone', directory=None)
dt = 0.1 * ms
duration = 10 * ms
Binary file modified docs_sphinx/_static/hh_best_fit.png
Binary file added docs_sphinx/_static/hh_best_fit_refined.png
Binary file added docs_sphinx/_static/hh_best_fit_refined_zoom.png
Binary file added docs_sphinx/_static/hh_best_fit_zoom.png
Binary file modified docs_sphinx/_static/hh_tutorial_input.png
3 changes: 2 additions & 1 deletion docs_sphinx/conf.py
@@ -97,4 +97,5 @@
'matplotlib': ('http://matplotlib.org/', None),
'numpy': ('http://docs.scipy.org/doc/numpy/', None),
'scipy': ('http://docs.scipy.org/doc/scipy/reference/', None),
'pandas': ('https://pandas.pydata.org/pandas-docs/stable/', None)}
'pandas': ('https://pandas.pydata.org/pandas-docs/stable/', None),
'lmfit': ('https://lmfit.github.io/lmfit-py/', None)}
11 changes: 7 additions & 4 deletions docs_sphinx/introduction/howitworks.rst
@@ -93,12 +93,15 @@ Example of `~brian2modelfitting.fitter.TraceFitter` with all of the necessary arguments:
Remarks
-------
- After performing the first fitting, the user can continue the optimization
with another `~brian2modelfitting.fitter.Fitter.fit()` run.
- After performing the first fitting, the user can continue the optimization
  with another `~brian2modelfitting.fitter.Fitter.fit()` run.

- The number of samples cannot be changed between rounds or `~brian2modelfitting.fitter.Fitter.fit()`
calls, due to the parallelization of the simulations.
- The number of samples cannot be changed between rounds or `~brian2modelfitting.fitter.Fitter.fit()`
  calls, due to the parallelization of the simulations.

.. warning::
User is not allowed to change the optimizer or metric between `~brian2modelfitting.fitter.Fitter.fit()`
calls.

- When using the `~brian2modelfitting.fitter.TraceFitter`, users can apply a standard
  curve-fitting algorithm for refinement by calling `~brian2modelfitting.fitter.TraceFitter.refine`, as sketched below.
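A compact, illustrative sketch of that refinement step (not part of the documented changes); it assumes a `TraceFitter` instance ``fitter`` created as in the example above, with a conductance parameter ``g`` to be fitted, and requires the optional lmfit package:

from brian2 import nS
from brian2modelfitting import NevergradOptimizer, MSEMetric

# Stage 1: global optimization with nevergrad
results, error = fitter.fit(n_rounds=2, optimizer=NevergradOptimizer(),
                            metric=MSEMetric(), g=[1*nS, 30*nS])
# Stage 2: gradient-based local refinement with lmfit (Levenberg-Marquardt
# by default), reusing the bounds passed to fit()
refined_params, info = fitter.refine()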
