Skip to content

Commit

Permalink
Add test for t_start in OnlineTraceFitter
Browse files Browse the repository at this point in the history
  • Loading branch information
mstimberg committed May 18, 2020
1 parent 813ec73 commit a1a1202
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 4 deletions.
4 changes: 2 additions & 2 deletions brian2modelfitting/fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1080,10 +1080,10 @@ def generate_spikes(self, params=None, param_init=None, level=0):


class OnlineTraceFitter(Fitter):
"""Input nad output have to have the same dimensions."""
def __init__(self, model, input_var, input, output_var, output, dt,
n_samples=30, method=None, reset=None, refractory=False,
threshold=None, level=0, param_init=None, t_start=0*second):
threshold=None, param_init=None,
t_start=0*second):
"""Initialize the fitter."""
super().__init__(dt, model, input, output, input_var, output_var,
n_samples, threshold, reset, refractory, method,
Expand Down
43 changes: 41 additions & 2 deletions brian2modelfitting/tests/test_modelfitting_tracefitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def setup_constant(request):
output_var='v',
input=(np.zeros(100)*mV)[None, :],
output=out_trace[None, :],
n_samples=100,)
n_samples=100)

def fin():
reinit_devices()
Expand All @@ -111,7 +111,26 @@ def setup_online(request):
output_var='I',
input=input_traces,
output=output_traces,
n_samples=10,)
n_samples=10)

def fin():
reinit_devices()
request.addfinalizer(fin)

return dt, otf

@pytest.fixture
def setup_online_constant(request):
dt = 0.1 * ms
# Membrane potential is constant at 10mV for first 50 steps, then at 20mV
out_trace = np.hstack([np.ones(50) * 10, np.ones(50) * 20])*mV
otf = OnlineTraceFitter(dt=dt,
model=constant_model,
input_var='x',
output_var='v',
input=(np.zeros(100) * mV)[None, :],
output=out_trace[None, :],
n_samples=100)

def fin():
reinit_devices()
Expand Down Expand Up @@ -868,3 +887,23 @@ def test_onlinetracefitter_generate_traces(setup_online):
traces = otf.generate_traces()
assert isinstance(traces, np.ndarray)
assert_equal(np.shape(traces), np.shape(output_traces))


def test_onlinetracefitter_fit_tstart():
dt = 0.1 * ms
# Membrane potential is constant at 10mV for first 50 steps, then at 20mV
out_trace = np.hstack([np.ones(50) * 10, np.ones(50) * 20]) * mV
otf = OnlineTraceFitter(dt=dt,
model=constant_model,
input_var='x',
output_var='v',
input=(np.zeros(100) * mV)[None, :],
output=out_trace[None, :],
n_samples=100,
t_start=50*dt)

# Ignore the first 50 steps at 10mV
params, result = otf.fit(n_rounds=10, optimizer=n_opt,
c=[0 * mV, 30 * mV])
# Fit should be close to 20mV
assert np.abs(params['c'] - 20*mV) < 1*mV

0 comments on commit a1a1202

Please sign in to comment.