
Commit

Merge 906088c into ed6e426
mstimberg committed Feb 19, 2020
2 parents ed6e426 + 906088c commit cce8a86
Showing 18 changed files with 422 additions and 53 deletions.
5 changes: 2 additions & 3 deletions .travis.yml
@@ -2,12 +2,11 @@ dist: xenial
language: python
python:
- "3.6"
- "3.7"
- "3.8"

# command to install dependencies
install:
- pip install pytest-coverage
- pip install coveralls
- pip install pytest-coverage coveralls lmfit
- pip install -r requirements.txt
- pip install .

157 changes: 153 additions & 4 deletions brian2modelfitting/fitter.py
Expand Up @@ -2,7 +2,7 @@
import numbers

from brian2.units.fundamentalunits import DIMENSIONLESS, get_dimensions
from numpy import ones, array, arange, concatenate, mean, nanmin, reshape
from numpy import ones, array, arange, concatenate, mean, argmin, nanmin, reshape
from brian2 import (NeuronGroup, defaultclock, get_device, Network,
StateMonitor, SpikeMonitor, second, get_local_namespace,
Quantity)
@@ -267,8 +267,9 @@ def fit(self, optimizer, metric=None, n_rounds=1, callback='text',
callback: `str` or `~typing.Callable`
Either the name of a provided callback function (``text`` or
``progressbar``), or a custom feedback function
``func(results, errors, parameters, index)``. If this function
returns ``True`` the fitting execution is interrupted.
``func(parameters, errors, best_parameters, best_error, index)``.
If this function returns ``True`` the fitting execution is
interrupted.
restart: bool
Flag that reinitializes the Fitter to reset the optimization.
With ``restart=True``, the user is allowed to change the optimizer and/or metric.
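
The callback signature documented above can be illustrated with a short sketch (the stopping threshold, the print format and the function name are assumptions for illustration, not part of this commit):

    def stop_when_good_enough(parameters, errors, best_parameters, best_error, index):
        # parameters: list of parameter dictionaries evaluated in this round
        # errors: the corresponding errors
        # best_parameters / best_error: best result found so far
        # index: current round number
        print('Round %d: best error so far %.4g' % (index, best_error))
        if best_error < 1e-3:  # hypothetical stopping criterion
            return True        # returning True interrupts the fitting loop
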
@@ -315,8 +316,11 @@ def fit(self, optimizer, metric=None, n_rounds=1, callback='text',
# create output variables
self.best_params = make_dic(self.parameter_names, best_params)
error = nanmin(self.optimizer.errors)
param_dicts = [{p: v for p, v in zip(self.parameter_names,
one_param_set)}
for one_param_set in parameters]

if callback(parameters, errors, best_params, error, index) is True:
if callback(param_dicts, errors, self.best_params, error, index) is True:
break

return self.best_params, error
@@ -451,6 +455,9 @@ def __init__(self, model, input_var, input, output_var, output, dt,
super().__init__(dt, model, input, output, input_var, output_var,
n_samples, threshold, reset, refractory, method,
param_init)
# We store the bounds set in TraceFitter.fit, so that TraceFitter.refine
# can reuse them if no bounds are given
self.bounds = None

if output_var not in self.model.names:
raise NameError("%s is not a model variable" % output_var)
@@ -491,6 +498,7 @@ def fit(self, optimizer, metric=None, n_rounds=1, callback='text',
if not isinstance(metric, TraceMetric):
raise TypeError("You can only use TraceMetric child metric with "
"TraceFitter")
self.bounds = dict(params)
self.best_params, error = super().fit(optimizer, metric, n_rounds,
callback, restart, **params)
return self.best_params, error
@@ -501,6 +509,147 @@ def generate_traces(self, params=None, param_init=None, level=0):
param_init=param_init, level=level+1)
return fits

def refine(self, params=None, t_start=None, normalization=None,
callback='text', level=0, **kwds):
"""
Refine the fitting results with a sequentially operating minimization
algorithm. Uses the `lmfit <https://lmfit.github.io/lmfit-py/>`_
package which itself makes use of
`scipy.optimize <https://docs.scipy.org/doc/scipy/reference/optimize.html>`_.
Has to be called after `~.TraceFitter.fit`, but a call with
``n_rounds=0`` is enough.

Parameters
----------
params : dict, optional
A dictionary with the parameters to use as a starting point for the
refinement. If not given, the best parameters found so far by
`~.TraceFitter.fit` will be used.
t_start : `~brian2.units.fundamentalunits.Quantity`, optional
Initial simulation/model time that should be ignored for the error
calculation. If not set, will reuse the `t_start` value from the
previously used metric.
normalization : float, optional
A normalization factor that will be multiplied with the total error
before handing it to the optimization algorithm. Can be useful if
the algorithm makes assumptions about the scale of errors, e.g. if
the size of steps in the parameter space depends on the absolute
value of the error. If not set, will reuse the `normalization` value
from the previously used metric.
callback: `str` or `~typing.Callable`
Either the name of a provided callback function (``text`` or
``progressbar``), or a custom feedback function
``func(parameters, errors, best_parameters, best_error, index)``.
If this function returns ``True`` the fitting execution is
interrupted.
level : int, optional
How much farther to go down in the stack to find the namespace.
kwds
Additional arguments can overwrite the bounds for individual
parameters (if not given, the bounds previously specified in the
call to `~.TraceFitter.fit` will be used). All other arguments will
be passed on to `.lmfit.minimize` and can be used to e.g. change the
method, or to specify method-specific arguments.

Returns
-------
parameters : dict
The parameters at the end of the optimization process as a
dictionary.
result : `.lmfit.MinimizerResult`
The result of the optimization process.

Notes
-----
The default method used by `lmfit` is least-squares minimization using
a Levenberg-Marquardt method. Note that there is no support for
specifying a `Metric`, the given output trace(s) will be subtracted
from the simulated trace(s) and passed on to the minimization algorithm.
This method always uses the runtime mode, independent of the selection
of the current device.
"""
try:
import lmfit
except ImportError:
raise ImportError('Refinement needs the "lmfit" package.')
if params is None:
if self.best_params is None:
raise TypeError('You need to either specify parameters or run '
'the fit function first.')
params = self.best_params

if t_start is None:
t_start = getattr(self.metric, 't_start', 0*second)
if normalization is None:
normalization = getattr(self.metric, 'normalization', 1.)

callback = callback_setup(callback, None)

# Set up Parameter objects
parameters = lmfit.Parameters()
for param_name in self.parameter_names:
if param_name not in kwds:
if self.bounds is None:
raise TypeError('You need to either specify bounds for all '
'parameters or run the fit function first.')
min_bound, max_bound = self.bounds[param_name]
else:
min_bound, max_bound = kwds.pop(param_name)
parameters.add(param_name, value=array(params[param_name]),
min=array(min_bound), max=array(max_bound))

needs_device_reset = False
if isinstance(get_device(), CPPStandaloneDevice):
set_device('runtime')
simulator = RuntimeSimulator()
needs_device_reset = True
else:
simulator = self.simulator

namespace = get_full_namespace({'input_var': self.input_traces,
'n_traces': self.n_traces,
'output_var': self.output_var},
level=level+1)
neurons = self.setup_neuron_group(self.n_traces, namespace,
name='neurons')
monitor = StateMonitor(neurons, self.output_var, record=True,
name='monitor')
network = Network(neurons, monitor)

simulator.initialize(network, self.param_init, name='refine')

t_start_steps = int(round(t_start / self.dt))

def _calc_error(params):
simulator.run(self.duration, {p: float(val)
for p, val in params.items()},
self.parameter_names, name='refine')
trace = getattr(simulator.networks['refine']['monitor'],
self.output_var+'_')
residual = trace[:, t_start_steps:] - self.output[:, t_start_steps:]
return residual.flatten() * normalization

tested_parameters = []
errors = []
def _callback_wrapper(params, iter, resid, *args, **kwds):
error = mean(resid**2)
params = {p: float(val) for p, val in params.items()}
tested_parameters.append(params)
errors.append(error)
best_idx = argmin(errors)
best_error = errors[best_idx]
best_params = tested_parameters[best_idx]
return callback(params, errors, best_params, best_error, iter)

result = lmfit.minimize(_calc_error, parameters,
iter_cb=_callback_wrapper, **kwds)

if needs_device_reset:
reset_device()

return {p: float(val) for p, val in result.params.items()}, result


class SpikeFitter(Fitter):
def __init__(self, model, input, output, dt, reset, threshold,
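To make the new ``refine`` step concrete, here is a usage sketch. It assumes a ``TraceFitter`` instance ``fitter`` with a single parameter ``gl``; the optimizer, metric, bounds and the lmfit method name are illustrative choices, not prescribed by this commit:

    # Global optimization first; refine() reuses the bounds given here by default
    best_params, error = fitter.fit(n_rounds=2, optimizer=NevergradOptimizer(),
                                    metric=MSEMetric(), gl=[10*nS, 100*nS])
    # Sequential refinement (Levenberg-Marquardt by default), starting from the
    # best parameters found above
    refined_params, result = fitter.refine()
    # Bounds and lmfit options can be overridden explicitly
    refined_params, result = fitter.refine(gl=[20*nS, 80*nS], method='least_squares')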
50 changes: 31 additions & 19 deletions brian2modelfitting/metric.py
@@ -116,16 +116,25 @@ class Metric(metaclass=abc.ABCMeta):
"""

@check_units(t_start=second)
def __init__(self, t_start=0*second, **kwds):
def __init__(self, t_start=0*second, normalization=1., **kwds):
"""
Initialize the metric.

Parameters
----------
t_start: `~brian2.units.fundamentalunits.Quantity`, optional
t_start : `~brian2.units.fundamentalunits.Quantity`, optional
Start of time window considered for calculating the fit error.
normalization : float, optional
A normalization factor that will be used before handing results to
the optimization algorithm. Can be useful if the algorithm makes
assumptions about the scale of errors, e.g. if the size of steps in
the parameter space depends on the absolute value of the error.
Trace-based metrics multiply the factor with the traces itself,
other metrics use it to scale the total error. Not used by default,
i.e. defaults to 1.
"""
self.t_start = t_start
self.normalization = normalization

@abc.abstractmethod
def get_features(self, model_results, target_results, dt):
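
The effect of the new ``normalization`` argument can be illustrated with a small sketch that mirrors the added test (array shapes follow the ``(n_samples, n_traces, n_steps)`` convention of ``TraceMetric.calc``; the concrete values are arbitrary):

    import numpy as np
    from brian2 import ms
    from brian2modelfitting import MSEMetric

    model_traces = np.ones((2, 3, 10))   # two parameter sets, three traces each
    data_traces = np.zeros((3, 10))
    plain = MSEMetric().calc(model_traces, data_traces, 0.1*ms)
    scaled = MSEMetric(normalization=2.).calc(model_traces, data_traces, 0.1*ms)
    # Trace-based metrics multiply the traces themselves, so the squared error
    # grows with the square of the factor: scaled == 4 * plain.
    # Spike-based metrics instead scale the final error by the factor.
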
@@ -217,8 +226,8 @@ def calc(self, model_traces, data_traces, dt):
``(n_samples, )``.
"""
start_steps = int(round(self.t_start/dt))
features = self.get_features(model_traces[:, :, start_steps:],
data_traces[:, start_steps:],
features = self.get_features(model_traces[:, :, start_steps:] * self.normalization,
data_traces[:, start_steps:] * self.normalization,
dt)
errors = self.get_errors(features)

@@ -300,7 +309,7 @@ def calc(self, model_spikes, data_spikes, dt):
model_spikes = relevant_model_spikes
data_spikes = relevant_data_spikes
features = self.get_features(model_spikes, data_spikes, dt)
errors = self.get_errors(features)
errors = self.get_errors(features) * self.normalization

return errors

@@ -337,21 +346,24 @@ class MSEMetric(TraceMetric):
"""
Mean Square Error between goal and calculated output.
"""
def __init__(self, t_start=0*second, normalization=1.):
super(MSEMetric, self).__init__(t_start=t_start)
self.normalization_factor = 1./float(normalization) # A normalization factor for the traces

def get_features(self, model_traces, data_traces, dt):
return (((model_traces - data_traces)*self.normalization_factor)**2).mean(axis=2)
# Note that the traces have already been normalized in
# TraceMetric.calc
return ((model_traces - data_traces)**2).mean(axis=2)

def get_errors(self, features):
return features.mean(axis=1)


class FeatureMetric(TraceMetric):
def __init__(self, stim_times, feat_list, weights=None, combine=None,
t_start=0*second):
super(FeatureMetric, self).__init__(t_start=t_start)
t_start=0*second, normalization=1.):
if normalization != 1:
raise ValueError('Do not set the normalization factor when using '
'the FeatureMetric, use weights instead.')
super(FeatureMetric, self).__init__(t_start=t_start,
normalization=normalization)
self.stim_times = stim_times
if isinstance(self.stim_times[0][0], Quantity):
for n, trace in enumerate(self.stim_times):
@@ -368,10 +380,8 @@ def combine(x, y):
self.combine = combine

if weights is None:
weights = {}
for key in feat_list:
weights[key] = 1
if type(weights) is not dict:
weights = {key: 1 for key in feat_list}
if not isinstance(weights, dict):
raise TypeError("Weights has to be a dictionary!")

self.weights = weights
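
Since ``FeatureMetric`` rejects a global ``normalization`` factor, relative scaling of individual features goes through ``weights`` instead; a sketch (the eFEL feature names, stimulus times and weight values are illustrative):

    from brian2 import ms
    from brian2modelfitting import FeatureMetric

    stim_times = [[50*ms, 150*ms], [50*ms, 150*ms]]   # [start, end] per trace
    metric = FeatureMetric(stim_times, ['voltage_base', 'AP_amplitude'],
                           weights={'voltage_base': 1., 'AP_amplitude': 0.5})
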
@@ -447,7 +457,7 @@ def get_errors(self, features):
sample_error += total
errors.append(sample_error)

return errors
return array(errors) * self.normalization


class GammaFactor(SpikeMetric):
@@ -467,7 +477,8 @@ class GammaFactor(SpikeMetric):
"""

@check_units(delta=second, time=second, t_start=0*second)
def __init__(self, delta, time, t_start=0*second, rate_correction=True):
def __init__(self, delta, time, t_start=0*second, normalization=1.,
rate_correction=True):
"""
Initialize the metric with the coincidence time window ``delta`` and the total simulation ``time``.
@@ -482,7 +493,8 @@ def __init__(self, delta, time, t_start=0*second, rate_correction=True):
rate, following `Clopath et al., Neurocomputing (2007)
<https://doi.org/10.1016/j.neucom.2006.10.047>`_.
"""
super(GammaFactor, self).__init__(t_start=t_start)
super(GammaFactor, self).__init__(t_start=t_start,
normalization=normalization)
self.delta = delta
self.time = time
self.rate_correction = rate_correction
@@ -497,7 +509,7 @@ def get_features(self, traces, output, dt):
rate_correction=self.rate_correction)
gf_for_sample.append(gf)
all_gf.append(gf_for_sample)
return array(all_gf)
return array(all_gf) * self.normalization

def get_errors(self, features):
errors = features.mean(axis=1)
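A brief usage sketch for the extended ``GammaFactor`` (the coincidence window, simulation length and normalization factor are illustrative values):

    from brian2 import ms
    from brian2modelfitting import GammaFactor

    # 2 ms coincidence window for a 60 ms long stimulation; the optional
    # normalization factor scales the gamma-factor based error handed
    # to the optimizer
    metric = GammaFactor(delta=2*ms, time=60*ms, normalization=1.,
                         rate_correction=True)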
20 changes: 20 additions & 0 deletions brian2modelfitting/tests/test_metric.py
@@ -1,6 +1,8 @@
'''
Test the metric class
'''
import pytest

import numpy as np
from numpy.testing.utils import assert_equal, assert_raises, assert_almost_equal
from brian2 import ms, mV
@@ -56,6 +58,16 @@ def test_calc_mse():
np.zeros(5))
assert(np.all(mse.calc(inp, out, 0.1*ms) > 0))

inp = np.vstack([np.ones((1, 3, 10)), np.zeros((1, 3, 10))])
out = np.ones((3, 10))
errors = mse.calc(inp, out, 0.01*ms)
assert_equal(errors, [0, 1])
mse = MSEMetric(normalization=2)
errors = mse.calc(inp, out, 0.01 * ms)
# The normalization factor scales the traces, so the squared error scales
# with the square of the normalization factor
assert_equal(errors, [0, 4])


def test_calc_mse_t_start():
mse = MSEMetric(t_start=1*ms)
@@ -128,6 +140,10 @@ def test_get_features_gamma():
features = gf.get_features(model_spikes, data_spikes, 0.1*ms)
assert_equal(np.shape(features), (2, 2))
assert(np.all(np.array(features) > -1))
normed_gf = GammaFactor(delta=0.5 * ms, time=10 * ms, normalization=2.)
normed_features = normed_gf.get_features(model_spikes, data_spikes,
0.1 * ms)
assert_equal(normed_features, 2*features)

features = gf.get_features([data_spikes]*3, data_spikes, 0.1*ms)
assert_equal(np.shape(features), (3, 2))
@@ -175,6 +191,10 @@ def test_get_features_feature_metric():
inp_times = [[99 * ms, 150 * ms], [49 * ms, 150 * ms]]

# Default comparison: absolute difference
# Check that FeatureMetric rejects the normalization argument
with pytest.raises(ValueError):
feature_metric = FeatureMetric(inp_times, ['voltage_base'],
normalization=2)
feature_metric = FeatureMetric(inp_times, ['voltage_base'])
results = feature_metric.get_features(voltage_model, voltage_target, dt=dt)
assert len(results) == 3
