Skip to content

Commit

Permalink
Turn the normalization factor back into divisive
Browse files Browse the repository at this point in the history
  • Loading branch information
mstimberg committed Mar 5, 2020
1 parent 5110c68 commit 62e0fcb
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 15 deletions.
15 changes: 9 additions & 6 deletions brian2modelfitting/fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,12 +582,13 @@ def refine(self, params=None, t_start=None, normalization=None,
calculation. If not set, will reuse the `t_start` value from the
previously used metric.
normalization : float, optional
A normalization factor that will be multiplied with the total error
before handing it to the optimization algorithm. Can be useful if
the algorithm makes assumptions about the scale of errors, e.g. if
the size of steps in the parameter space depends on the absolute
value of the error. If not set, will reuse the `normalization` value
from the previously used metric.
A normalization term that will be used rescale results before
handing them to the optimization algorithm. Can be useful if the
algorithm makes assumptions about the scale of errors, e.g. if the
size of steps in the parameter space depends on the absolute value
of the error. The difference between simulated and target traces
will be divided by this value. If not set, will reuse the
`normalization` value from the previously used metric.
callback: `str` or `~typing.Callable`
Either the name of a provided callback function (``text`` or
``progressbar``), or a custom feedback function
Expand Down Expand Up @@ -635,6 +636,8 @@ def refine(self, params=None, t_start=None, normalization=None,
t_start = getattr(self.metric, 't_start', 0*second)
if normalization is None:
normalization = getattr(self.metric, 'normalization', 1.)
else:
normalization = 1/normalization

callback_func = callback_setup(callback, None)

Expand Down
16 changes: 8 additions & 8 deletions brian2modelfitting/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,16 +125,16 @@ def __init__(self, t_start=0*second, normalization=1., **kwds):
t_start : `~brian2.units.fundamentalunits.Quantity`, optional
Start of time window considered for calculating the fit error.
normalization : float, optional
A normalization factor that will be used before handing results to
the optimization algorithm. Can be useful if the algorithm makes
assumptions about the scale of errors, e.g. if the size of steps in
the parameter space depends on the absolute value of the error.
Trace-based metrics multiply the factor with the traces itself,
other metrics use it to scale the total error. Not used by default,
i.e. defaults to 1.
A normalization term that will be used rescale results before
handing them to the optimization algorithm. Can be useful if the
algorithm makes assumptions about the scale of errors, e.g. if the
size of steps in the parameter space depends on the absolute value
of the error. Trace-based metrics divide the traces itself by the
value, other metrics use it to scale the total error. Not used by
default, i.e. defaults to 1.
"""
self.t_start = t_start
self.normalization = normalization
self.normalization = 1/normalization

@abc.abstractmethod
def get_features(self, model_results, target_results, dt):
Expand Down
2 changes: 1 addition & 1 deletion brian2modelfitting/tests/test_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def test_calc_mse():
out = np.ones((3, 10))
errors = mse.calc(inp, out, 0.01*ms)
assert_equal(errors, [0, 1])
mse = MSEMetric(normalization=2)
mse = MSEMetric(normalization=1/2)
errors = mse.calc(inp, out, 0.01 * ms)
# The normalization factor scales the traces, so the squared error scales
# with the square of the normalization factor
Expand Down

0 comments on commit 62e0fcb

Please sign in to comment.