Skip to content

Commit

Permalink
Merge ab5ef01 into 8ccb68e
Browse files Browse the repository at this point in the history
  • Loading branch information
mstimberg committed Mar 23, 2020
2 parents 8ccb68e + ab5ef01 commit 685573e
Show file tree
Hide file tree
Showing 5 changed files with 116 additions and 4 deletions.
6 changes: 6 additions & 0 deletions brian2modelfitting/fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,6 +762,12 @@ def fit(self, optimizer, metric=None, n_rounds=1, callback='text',
if not isinstance(metric, TraceMetric):
raise TypeError("You can only use TraceMetric child metric with "
"TraceFitter")
if metric.t_weights is not None:
if not metric.t_weights.shape == (self.output.shape[1], ):
raise ValueError("The 't_weights' argument of the metric has "
"to be a one-dimensional array of length "
f"{self.output.shape[1]} but has shape "
f"{metric.t_weights.shape}")
self.bounds = dict(params)
best_params, error = super().fit(optimizer, metric, n_rounds,
callback, restart,
Expand Down
49 changes: 45 additions & 4 deletions brian2modelfitting/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,39 @@ class TraceMetric(Metric):
Input traces have to be shaped into 2D array.
"""

@check_units(t_start=second)
def __init__(self, t_start=0*second, t_weights=None, normalization=1.,
**kwds):
"""
Initialize the metric.
Parameters
----------
t_start : `~brian2.units.fundamentalunits.Quantity`, optional
Start of time window considered for calculating the fit error.
t_weights : `~.ndarray`, optional
A 1-dimensional array of weights for each time point. This array
has to have the same size as the input/output traces that are used
for fitting. A value of 0 means that data points are ignored, and
a value of 1 means that a data point enters into the calculation
as usual. Values of > 1 will emphasize the error by the respective
factor at that point, a value of < 1 will de-emphasize it. Cannot
be combined with ``t_start``.
normalization : float, optional
A normalization term that will be used rescale results before
handing them to the optimization algorithm. Can be useful if the
algorithm makes assumptions about the scale of errors, e.g. if the
size of steps in the parameter space depends on the absolute value
of the error. Trace-based metrics divide the traces itself by the
value, other metrics use it to scale the total error. Not used by
default, i.e. defaults to 1.
"""
if t_weights is not None and t_start != 0*second:
raise ValueError("Cannot use both 't_weigths' and 't_start'.")
super(TraceMetric, self).__init__(t_start=t_start,
normalization=normalization)
self.t_weights = t_weights

def calc(self, model_traces, data_traces, dt):
"""
Perform the error calculation across all parameters,
Expand Down Expand Up @@ -266,9 +299,14 @@ def calc(self, model_traces, data_traces, dt):
``(n_samples, )``.
"""
start_steps = int(round(self.t_start/dt))
features = self.get_features(model_traces[:, :, start_steps:] * float(self.normalization),
data_traces[:, start_steps:] * float(self.normalization),
dt)
if self.t_weights is not None:
features = self.get_features(model_traces * float(self.normalization),
data_traces * float(self.normalization),
dt)
else:
features = self.get_features(model_traces[:, :, start_steps:] * float(self.normalization),
data_traces[:, start_steps:] * float(self.normalization),
dt)
errors = self.get_errors(features)

return errors
Expand Down Expand Up @@ -390,7 +428,10 @@ class MSEMetric(TraceMetric):
def get_features(self, model_traces, data_traces, dt):
# Note that the traces have already beeen normalized in
# TraceMetric.calc
return ((model_traces - data_traces)**2).mean(axis=2)
error = (model_traces - data_traces)**2
if self.t_weights is not None:
error *= self.t_weights
return error.mean(axis=2)

def get_errors(self, features):
return features.mean(axis=1)
Expand Down
31 changes: 31 additions & 0 deletions brian2modelfitting/tests/test_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,37 @@ def test_calc_mse_t_start():
assert_equal(mse.calc(inp, out, 0.1*ms), np.zeros(5))


def test_calc_mse_t_weights():
with pytest.raises(ValueError):
MSEMetric(t_start=1*ms, t_weights=np.ones(20))
weights = np.ones(20)
weights[:10] = 0
mse = MSEMetric(t_weights=weights)
out = np.random.rand(2, 20)
inp = np.random.rand(5, 2, 20)

errors = mse.calc(inp, out, 0.1*ms)
assert_equal(np.shape(errors), (5,))
assert(np.all(errors > 0))
# Everything before 1ms should be ignored, so having the same values for
# the rest should give an error of 0
inp[:, :, 10:] = out[None, :, 10:]
assert_equal(mse.calc(inp, out, 0.1*ms), np.zeros(5))


def test_calc_mse_t_weights2():
mse = MSEMetric()
mse_weighted = MSEMetric(t_weights=2*np.ones(20))
out = np.random.rand(2, 20)
inp = np.random.rand(5, 2, 20)

errors = mse.calc(inp, out, 0.1*ms)
assert_equal(np.shape(errors), (5,))
assert(np.all(errors > 0))
errors_weighted = mse_weighted.calc(inp, out, 0.1*ms)
assert_almost_equal(errors_weighted, 2*errors)


def test_calc_gf():
assert_raises(TypeError, GammaFactor)
assert_raises(DimensionMismatchError, GammaFactor, delta=10*mV)
Expand Down
20 changes: 20 additions & 0 deletions brian2modelfitting/tests/test_modelfitting_tracefitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,26 @@ def test_fitter_fit_tstart(setup_constant):
# Fit should be close to 20mV
assert np.abs(params['c'] - 20*mV) < 1*mV


def test_fitter_fit_tsteps(setup_constant):
dt, tf = setup_constant

with pytest.raises(ValueError):
# Incorrect weight size
tf.fit(n_rounds=10, optimizer=n_opt,
metric=MSEMetric(t_weights=np.ones(101)),
c=[0 * mV, 30 * mV])

# Ignore the first 50 steps at 10mV
weights = np.ones(100)
weights[:50] = 0
params, result = tf.fit(n_rounds=10, optimizer=n_opt,
metric=MSEMetric(t_weights=weights),
c=[0 * mV, 30 * mV])
# Fit should be close to 20mV
assert np.abs(params['c'] - 20*mV) < 1*mV


@pytest.mark.skipif(lmfit is None, reason="needs lmfit package")
def test_fitter_refine(setup):
dt, tf = setup
Expand Down
14 changes: 14 additions & 0 deletions docs_sphinx/metric/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,20 @@ calculation.
metric = MSEMetric(t_start=5*ms)
Alternatively, the user can specify a weight vector emphasizing/de-emphasizing
certain parts of the trace. For example, to ignore the first 5ms and to weigh
the error between 10 and 15ms twice as high as the rest:

.. code:: python
# total trace length = 50ms
weights = np.ones(int(50*ms/dt))
weights[:int(5*ms/dt)] = 0
weights[int(10*ms/dt):int(15*ms/dt)] = 2
metric = MSEMetric(t_weights=weights)
Note that the ``t_weights`` argument cannot be combined with ``t_start``.

In `~brian2modelfitting.fitter.OnlineTraceFitter`,
the mean square error gets calculated in online manner, with no need of
specifying a metric object.
Expand Down

0 comments on commit 685573e

Please sign in to comment.