diff --git a/brian2modelfitting/fitter.py b/brian2modelfitting/fitter.py index ca841cd0..89b41de7 100644 --- a/brian2modelfitting/fitter.py +++ b/brian2modelfitting/fitter.py @@ -582,12 +582,13 @@ def refine(self, params=None, t_start=None, normalization=None, calculation. If not set, will reuse the `t_start` value from the previously used metric. normalization : float, optional - A normalization factor that will be multiplied with the total error - before handing it to the optimization algorithm. Can be useful if - the algorithm makes assumptions about the scale of errors, e.g. if - the size of steps in the parameter space depends on the absolute - value of the error. If not set, will reuse the `normalization` value - from the previously used metric. + A normalization term that will be used rescale results before + handing them to the optimization algorithm. Can be useful if the + algorithm makes assumptions about the scale of errors, e.g. if the + size of steps in the parameter space depends on the absolute value + of the error. The difference between simulated and target traces + will be divided by this value. If not set, will reuse the + `normalization` value from the previously used metric. callback: `str` or `~typing.Callable` Either the name of a provided callback function (``text`` or ``progressbar``), or a custom feedback function @@ -635,6 +636,8 @@ def refine(self, params=None, t_start=None, normalization=None, t_start = getattr(self.metric, 't_start', 0*second) if normalization is None: normalization = getattr(self.metric, 'normalization', 1.) + else: + normalization = 1/normalization callback_func = callback_setup(callback, None) diff --git a/brian2modelfitting/metric.py b/brian2modelfitting/metric.py index b317c263..d27b720b 100644 --- a/brian2modelfitting/metric.py +++ b/brian2modelfitting/metric.py @@ -125,16 +125,16 @@ def __init__(self, t_start=0*second, normalization=1., **kwds): t_start : `~brian2.units.fundamentalunits.Quantity`, optional Start of time window considered for calculating the fit error. normalization : float, optional - A normalization factor that will be used before handing results to - the optimization algorithm. Can be useful if the algorithm makes - assumptions about the scale of errors, e.g. if the size of steps in - the parameter space depends on the absolute value of the error. - Trace-based metrics multiply the factor with the traces itself, - other metrics use it to scale the total error. Not used by default, - i.e. defaults to 1. + A normalization term that will be used rescale results before + handing them to the optimization algorithm. Can be useful if the + algorithm makes assumptions about the scale of errors, e.g. if the + size of steps in the parameter space depends on the absolute value + of the error. Trace-based metrics divide the traces itself by the + value, other metrics use it to scale the total error. Not used by + default, i.e. defaults to 1. """ self.t_start = t_start - self.normalization = normalization + self.normalization = 1/normalization @abc.abstractmethod def get_features(self, model_results, target_results, dt): diff --git a/brian2modelfitting/tests/test_metric.py b/brian2modelfitting/tests/test_metric.py index afe313b5..33bcc69e 100644 --- a/brian2modelfitting/tests/test_metric.py +++ b/brian2modelfitting/tests/test_metric.py @@ -62,7 +62,7 @@ def test_calc_mse(): out = np.ones((3, 10)) errors = mse.calc(inp, out, 0.01*ms) assert_equal(errors, [0, 1]) - mse = MSEMetric(normalization=2) + mse = MSEMetric(normalization=1/2) errors = mse.calc(inp, out, 0.01 * ms) # The normalization factor scales the traces, so the squared error scales # with the square of the normalization factor