diff --git a/brian2modelfitting/fitter.py b/brian2modelfitting/fitter.py index fe66bbd5..76d13164 100644 --- a/brian2modelfitting/fitter.py +++ b/brian2modelfitting/fitter.py @@ -850,7 +850,7 @@ def _callback_wrapper(params, iter, resid, *args, **kwds): class SpikeFitter(Fitter): def __init__(self, model, input, output, dt, reset, threshold, input_var='I', refractory=False, n_samples=30, - method=None, level=0, param_init=None): + method=None, param_init=None): """Initialize the fitter.""" if method is None: method = 'exponential_euler' diff --git a/brian2modelfitting/tests/test_modelfitting_spikefitter.py b/brian2modelfitting/tests/test_modelfitting_spikefitter.py index 68538c09..dc3a6ca5 100644 --- a/brian2modelfitting/tests/test_modelfitting_spikefitter.py +++ b/brian2modelfitting/tests/test_modelfitting_spikefitter.py @@ -7,7 +7,7 @@ from brian2 import (Equations, NeuronGroup, SpikeMonitor, TimedArray, nS, nF, mV, ms, nA, amp, run) from brian2modelfitting import (NevergradOptimizer, SpikeFitter, GammaFactor, - Simulator, Metric, Optimizer) + Simulator, Metric, Optimizer, MSEMetric) from brian2modelfitting.fitter import get_spikes from brian2.devices.device import reinit_devices @@ -101,14 +101,22 @@ def test_spikefitter_init(setup): assert isinstance(sf.model, Equations) -def test_tracefitter_init_errors(setup): +def test_spikefitter_init_errors(setup): dt, _ = setup with pytest.raises(Exception): SpikeFitter(model=model, input_var='Exception', dt=dt, input=inp_trace*amp, output=output, n_samples=2, threshold='v > -50*mV', - reset='v = -70*mV',) + reset='v = -70*mV') + + with pytest.raises(ValueError): + sf = SpikeFitter(model=model, input_var='I', dt=dt, + input=inp_trace * amp, output=output, + n_samples=2, + threshold='v > -50*mV', + reset='v = -70*mV', + param_init={'V': -70*mV}) # name is "v" not "V" def test_spikefitter_fit(setup): @@ -134,6 +142,23 @@ def test_spikefitter_fit(setup): assert_equal(results, sf.best_params) +def test_spikefitter_fit_errors(setup): + dt, sf = setup + with pytest.raises(TypeError): + results, errors = sf.fit(n_rounds=2, + optimizer=n_opt, + metric=MSEMetric(), + gL=[20*nS, 40*nS], + C=[0.5*nF, 1.5*nF]) + with pytest.raises(TypeError): + results, errors = sf.fit(n_rounds=2, + optimizer=None, + metric=MSEMetric(), + gL=[20*nS, 40*nS], + C=[0.5*nF, 1.5*nF]) + + + def test_spikefitter_param_init(setup): dt, _ = setup SpikeFitter(model=model, input_var='I', dt=dt, diff --git a/brian2modelfitting/tests/test_modelfitting_tracefitter.py b/brian2modelfitting/tests/test_modelfitting_tracefitter.py index 16f32d9d..9a3f07c5 100644 --- a/brian2modelfitting/tests/test_modelfitting_tracefitter.py +++ b/brian2modelfitting/tests/test_modelfitting_tracefitter.py @@ -208,6 +208,44 @@ def test_fitter_fit(setup): assert_equal(results, tf.best_params) +def test_fitter_fit_callback(setup): + dt, tf = setup + + calls = [] + def our_callback(params, errors, best_params, best_error, index): + calls.append(index) + assert all(isinstance(p, dict) for p in params) + assert isinstance(errors, np.ndarray) + assert isinstance(best_params, dict) + assert isinstance(best_error, float) + assert isinstance(index, int) + results, errors = tf.fit(n_rounds=2, + optimizer=n_opt, + metric=metric, + g=[1*nS, 30*nS], + callback=our_callback) + assert len(calls) == 2 + + # Stop a fit via the callback + + calls = [] + def our_callback(params, errors, best_params, best_error, index): + calls.append(index) + assert all(isinstance(p, dict) for p in params) + assert isinstance(errors, np.ndarray) + assert isinstance(best_params, dict) + assert isinstance(best_error, float) + assert isinstance(index, int) + return True # stop + + results, errors = tf.fit(n_rounds=2, + optimizer=n_opt, + metric=metric, + g=[1*nS, 30*nS], + callback=our_callback) + assert len(calls) == 1 + + def test_fitter_fit_errors(setup): dt, tf = setup with pytest.raises(TypeError): @@ -219,7 +257,7 @@ def test_fitter_fit_errors(setup): with pytest.raises(TypeError): tf.fit(n_rounds=2, optimizer=n_opt, - metric=1, + metric=GammaFactor(3*ms, 60*ms), # spike metric g=[1*nS, 30*nS]) @@ -570,14 +608,14 @@ def test_onlinetracefitter_init_errors(setup_online): n_samples=10, output=output_traces, output_var='I', - input_var='Exception',) + input_var='Exception') with pytest.raises(Exception): OnlineTraceFitter(dt=0.1*ms, model=model, input=input_traces, n_samples=10, output=output_traces, input_var='v', - output_var='Exception',) + output_var='Exception') with pytest.raises(Exception): OnlineTraceFitter(dt=0.1*ms, model=model, input=input_traces, @@ -606,3 +644,14 @@ def test_onlinetracefitter_fit(setup_online): assert 'g' in results.keys() assert_equal(results, otf.best_params) + + +def test_onlinetracefitter_generate_traces(setup_online): + dt, otf = setup_online + results, errors = otf.fit(n_rounds=2, + optimizer=n_opt, + g=[1 * nS, 30 * nS], + restart=False, ) + traces = otf.generate_traces() + assert isinstance(traces, np.ndarray) + assert_equal(np.shape(traces), np.shape(output_traces))