Skip to content

Commit

Permalink
Add a number of tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mstimberg committed Mar 13, 2020
1 parent e0c6449 commit db10a53
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 7 deletions.
2 changes: 1 addition & 1 deletion brian2modelfitting/fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -850,7 +850,7 @@ def _callback_wrapper(params, iter, resid, *args, **kwds):
class SpikeFitter(Fitter):
def __init__(self, model, input, output, dt, reset, threshold,
input_var='I', refractory=False, n_samples=30,
method=None, level=0, param_init=None):
method=None, param_init=None):
"""Initialize the fitter."""
if method is None:
method = 'exponential_euler'
Expand Down
31 changes: 28 additions & 3 deletions brian2modelfitting/tests/test_modelfitting_spikefitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from brian2 import (Equations, NeuronGroup, SpikeMonitor, TimedArray,
nS, nF, mV, ms, nA, amp, run)
from brian2modelfitting import (NevergradOptimizer, SpikeFitter, GammaFactor,
Simulator, Metric, Optimizer)
Simulator, Metric, Optimizer, MSEMetric)
from brian2modelfitting.fitter import get_spikes
from brian2.devices.device import reinit_devices

Expand Down Expand Up @@ -101,14 +101,22 @@ def test_spikefitter_init(setup):
assert isinstance(sf.model, Equations)


def test_tracefitter_init_errors(setup):
def test_spikefitter_init_errors(setup):
dt, _ = setup
with pytest.raises(Exception):
SpikeFitter(model=model, input_var='Exception', dt=dt,
input=inp_trace*amp, output=output,
n_samples=2,
threshold='v > -50*mV',
reset='v = -70*mV',)
reset='v = -70*mV')

with pytest.raises(ValueError):
sf = SpikeFitter(model=model, input_var='I', dt=dt,
input=inp_trace * amp, output=output,
n_samples=2,
threshold='v > -50*mV',
reset='v = -70*mV',
param_init={'V': -70*mV}) # name is "v" not "V"


def test_spikefitter_fit(setup):
Expand All @@ -134,6 +142,23 @@ def test_spikefitter_fit(setup):
assert_equal(results, sf.best_params)


def test_spikefitter_fit_errors(setup):
dt, sf = setup
with pytest.raises(TypeError):
results, errors = sf.fit(n_rounds=2,
optimizer=n_opt,
metric=MSEMetric(),
gL=[20*nS, 40*nS],
C=[0.5*nF, 1.5*nF])
with pytest.raises(TypeError):
results, errors = sf.fit(n_rounds=2,
optimizer=None,
metric=MSEMetric(),
gL=[20*nS, 40*nS],
C=[0.5*nF, 1.5*nF])



def test_spikefitter_param_init(setup):
dt, _ = setup
SpikeFitter(model=model, input_var='I', dt=dt,
Expand Down
55 changes: 52 additions & 3 deletions brian2modelfitting/tests/test_modelfitting_tracefitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,44 @@ def test_fitter_fit(setup):
assert_equal(results, tf.best_params)


def test_fitter_fit_callback(setup):
dt, tf = setup

calls = []
def our_callback(params, errors, best_params, best_error, index):
calls.append(index)
assert all(isinstance(p, dict) for p in params)
assert isinstance(errors, np.ndarray)
assert isinstance(best_params, dict)
assert isinstance(best_error, float)
assert isinstance(index, int)
results, errors = tf.fit(n_rounds=2,
optimizer=n_opt,
metric=metric,
g=[1*nS, 30*nS],
callback=our_callback)
assert len(calls) == 2

# Stop a fit via the callback

calls = []
def our_callback(params, errors, best_params, best_error, index):
calls.append(index)
assert all(isinstance(p, dict) for p in params)
assert isinstance(errors, np.ndarray)
assert isinstance(best_params, dict)
assert isinstance(best_error, float)
assert isinstance(index, int)
return True # stop

results, errors = tf.fit(n_rounds=2,
optimizer=n_opt,
metric=metric,
g=[1*nS, 30*nS],
callback=our_callback)
assert len(calls) == 1


def test_fitter_fit_errors(setup):
dt, tf = setup
with pytest.raises(TypeError):
Expand All @@ -219,7 +257,7 @@ def test_fitter_fit_errors(setup):
with pytest.raises(TypeError):
tf.fit(n_rounds=2,
optimizer=n_opt,
metric=1,
metric=GammaFactor(3*ms, 60*ms), # spike metric
g=[1*nS, 30*nS])


Expand Down Expand Up @@ -570,14 +608,14 @@ def test_onlinetracefitter_init_errors(setup_online):
n_samples=10,
output=output_traces,
output_var='I',
input_var='Exception',)
input_var='Exception')

with pytest.raises(Exception):
OnlineTraceFitter(dt=0.1*ms, model=model, input=input_traces,
n_samples=10,
output=output_traces,
input_var='v',
output_var='Exception',)
output_var='Exception')

with pytest.raises(Exception):
OnlineTraceFitter(dt=0.1*ms, model=model, input=input_traces,
Expand Down Expand Up @@ -606,3 +644,14 @@ def test_onlinetracefitter_fit(setup_online):
assert 'g' in results.keys()

assert_equal(results, otf.best_params)


def test_onlinetracefitter_generate_traces(setup_online):
dt, otf = setup_online
results, errors = otf.fit(n_rounds=2,
optimizer=n_opt,
g=[1 * nS, 30 * nS],
restart=False, )
traces = otf.generate_traces()
assert isinstance(traces, np.ndarray)
assert_equal(np.shape(traces), np.shape(output_traces))

0 comments on commit db10a53

Please sign in to comment.