Skip to content

Commit

Permalink
Merge branch 'master' into remove_lmfit
Browse files Browse the repository at this point in the history
  • Loading branch information
mstimberg committed Oct 27, 2021
2 parents c253eda + cfb1f61 commit 16a738f
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
4 changes: 2 additions & 2 deletions brian2modelfitting/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
except ImportError:
warnings.warn('eFEL package not found.')
from itertools import repeat
from brian2 import Hz, second, Quantity, ms, us, get_dimensions
from brian2 import Hz, second, Quantity, ms, us, get_dimensions, mV
from brian2.units.fundamentalunits import check_units, in_unit, DIMENSIONLESS
from numpy import (array, sum, abs, amin, digitize, rint, arange, inf, NaN,
clip, mean)
Expand Down Expand Up @@ -101,7 +101,7 @@ def calc_eFEL(traces, inp_times, feat_list, dt):
time = arange(0, len(trace))*dt/ms
temp_trace = {}
temp_trace['T'] = time
temp_trace['V'] = array(trace, copy=False)
temp_trace['V'] = trace/mV
temp_trace['stim_start'] = [inp_times[i][0]]
temp_trace['stim_end'] = [inp_times[i][1]]
out_traces.append(temp_trace)
Expand Down
12 changes: 6 additions & 6 deletions brian2modelfitting/tests/test_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,8 @@ def test_calc_EFL():
results = calc_eFEL(voltage, inp_times, ['voltage_base'], dt=dt)
assert len(results) == 2
assert all(res.keys() == {'voltage_base'} for res in results)
assert_almost_equal(results[0]['voltage_base'], float(-60*mV))
assert_almost_equal(results[1]['voltage_base'], float(-70*mV))
assert_almost_equal(results[0]['voltage_base'], -60)
assert_almost_equal(results[1]['voltage_base'], -70)


def test_get_features_feature_metric():
Expand All @@ -251,9 +251,9 @@ def test_get_features_feature_metric():
assert len(results) == 3
assert all(isinstance(r, dict) for r in results)
assert all(r.keys() == {'voltage_base'} for r in results)
assert_almost_equal(results[0]['voltage_base'], np.array([2.5*mV, 5*mV]))
assert_almost_equal(results[0]['voltage_base'], [2.5, 5])
assert_almost_equal(results[1]['voltage_base'], [0, 0])
assert_almost_equal(results[2]['voltage_base'], np.array([2.5*mV, 5*mV]))
assert_almost_equal(results[2]['voltage_base'], [2.5, 5])

# Custom comparison: squared difference
feature_metric = FeatureMetric(inp_times, ['voltage_base'],
Expand All @@ -262,9 +262,9 @@ def test_get_features_feature_metric():
assert len(results) == 3
assert all(isinstance(r, dict) for r in results)
assert all(r.keys() == {'voltage_base'} for r in results)
assert_almost_equal(results[0]['voltage_base'], np.array([(2.5*mV)**2, (5*mV)**2]))
assert_almost_equal(results[0]['voltage_base'], [2.5**2, 5**2])
assert_almost_equal(results[1]['voltage_base'], [0, 0])
assert_almost_equal(results[2]['voltage_base'], np.array([(2.5*mV)**2, (5*mV)**2]))
assert_almost_equal(results[2]['voltage_base'], [2.5**2, 5**2])


def test_get_errors_feature_metric():
Expand Down

0 comments on commit 16a738f

Please sign in to comment.