Skip to content

Commit

Permalink
Merge e613f1b into eb8521f
Browse files Browse the repository at this point in the history
  • Loading branch information
mstimberg committed Feb 9, 2021
2 parents eb8521f + e613f1b commit 238bfc6
Show file tree
Hide file tree
Showing 23 changed files with 916 additions and 269 deletions.
632 changes: 472 additions & 160 deletions brian2modelfitting/fitter.py

Large diffs are not rendered by default.

22 changes: 19 additions & 3 deletions brian2modelfitting/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,22 @@ def calc(self, model_results, data_results, dt):
"""
pass

def revert_normalization(self, error):
"""
Revert the normalization to recover the error before normalization.
Parameters
----------
error : Quantity or float
The normalized error.
Returns
-------
raw_error : Quantity or float
The error before normalization
"""
return error / self.normalization


class TraceMetric(Metric):
"""
Expand Down Expand Up @@ -457,13 +473,13 @@ def get_dimensions(self, output_dim):
def get_normalized_dimensions(self, output_dim):
return output_dim**2 * get_dimensions(self.normalization)**2

def revert_normalization(self, error):
return error / self.normalization**2


class FeatureMetric(TraceMetric):
def __init__(self, stim_times, feat_list, weights=None, combine=None,
t_start=0*second, normalization=1.):
if normalization != 1:
raise ValueError('Do not set the normalization factor when using '
'the FeatureMetric, use weights instead.')
super(FeatureMetric, self).__init__(t_start=t_start,
normalization=normalization)
self.stim_times = stim_times
Expand Down
4 changes: 0 additions & 4 deletions brian2modelfitting/tests/test_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,10 +246,6 @@ def test_get_features_feature_metric():
inp_times = [[99 * ms, 150 * ms], [49 * ms, 150 * ms]]

# Default comparison: absolute difference
# Check that FeatureMetric rejects the normalization argument
with pytest.raises(ValueError):
feature_metric = FeatureMetric(inp_times, ['voltage_base'],
normalization=2)
feature_metric = FeatureMetric(inp_times, ['voltage_base'])
results = feature_metric.get_features(voltage_model, voltage_target, dt=dt)
assert len(results) == 3
Expand Down
2 changes: 1 addition & 1 deletion brian2modelfitting/tests/test_modelfitting_spikefitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def test_spikefitter_fit(setup):
for attr in attr_fit:
assert hasattr(sf, attr)

assert isinstance(sf.metric, Metric)
assert len(sf.metric) == 1 and isinstance(sf.metric[0], Metric)
assert isinstance(sf.optimizer, Optimizer)

assert isinstance(results, dict)
Expand Down
Loading

0 comments on commit 238bfc6

Please sign in to comment.