Skip to content

Commit

Permalink
Update tests for unit support
Browse files Browse the repository at this point in the history
  • Loading branch information
mstimberg committed Mar 16, 2020
1 parent 487fb78 commit fbd864b
Showing 1 changed file with 96 additions and 14 deletions.
110 changes: 96 additions & 14 deletions brian2modelfitting/tests/test_modelfitting_tracefitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,23 @@ def fin():

return dt, tf

@pytest.fixture
def setup_no_units(request):
dt = 0.01 * ms
tf = TraceFitter(dt=dt,
model=model,
input_var='v',
output_var='I',
input=input_traces,
output=output_traces,
n_samples=2,
use_units=False)

def fin():
reinit_devices()
request.addfinalizer(fin)

return dt, tf

@pytest.fixture
def setup_constant(request):
Expand Down Expand Up @@ -161,6 +178,7 @@ def test_tracefitter_init(setup):
assert isinstance(tf.model, Equations)



def test_tracefitter_init_errors(setup):
dt, _ = setup
with pytest.raises(Exception):
Expand Down Expand Up @@ -202,10 +220,37 @@ def test_fitter_fit(setup):
assert isinstance(tf.simulator, Simulator)

assert isinstance(results, dict)
assert all(isinstance(v, Quantity) for v in results.values())
assert isinstance(errors, Quantity)
assert 'g' in results.keys()

assert_equal(results, tf.best_params)
assert_equal(errors, tf.best_error)


def test_fitter_fit_no_units(setup_no_units):
dt, tf = setup_no_units
results, errors = tf.fit(n_rounds=2,
optimizer=n_opt,
metric=metric,
g=[1*nS, 30*nS],
callback=None)

attr_fit = ['optimizer', 'metric', 'best_params']
for attr in attr_fit:
assert hasattr(tf, attr)

assert isinstance(tf.metric, Metric)
assert isinstance(tf.optimizer, Optimizer)
assert isinstance(tf.simulator, Simulator)

assert isinstance(results, dict)
assert all(isinstance(v, float) for v in results.values())
assert isinstance(errors, float)
assert 'g' in results.keys()

assert_equal(results, tf.best_params)
assert_equal(errors, tf.best_error)


def test_fitter_fit_callback(setup):
Expand All @@ -217,7 +262,7 @@ def our_callback(params, errors, best_params, best_error, index):
assert all(isinstance(p, dict) for p in params)
assert isinstance(errors, np.ndarray)
assert isinstance(best_params, dict)
assert isinstance(best_error, float)
assert isinstance(best_error, Quantity)
assert isinstance(index, int)
results, errors = tf.fit(n_rounds=2,
optimizer=n_opt,
Expand All @@ -234,7 +279,7 @@ def our_callback(params, errors, best_params, best_error, index):
assert all(isinstance(p, dict) for p in params)
assert isinstance(errors, np.ndarray)
assert isinstance(best_params, dict)
assert isinstance(best_error, float)
assert isinstance(best_error, Quantity)
assert isinstance(index, int)
return True # stop

Expand Down Expand Up @@ -269,7 +314,7 @@ def test_fitter_fit_tstart(setup_constant):
metric=MSEMetric(t_start=50*dt),
c=[0 * mV, 30 * mV])
# Fit should be close to 20mV
assert np.abs(params['c']*volt - 20*mV) < 1*mV
assert np.abs(params['c'] - 20*mV) < 1*mV

@pytest.mark.skipif(lmfit is None, reason="needs lmfit package")
def test_fitter_refine(setup):
Expand Down Expand Up @@ -376,7 +421,7 @@ def test_fitter_refine_tstart(setup_constant):
t_start=50*dt)

# Fit should be close to 20mV
assert np.abs(params['c']*volt - 20*mV) < 1*mV
assert np.abs(params['c'] - 20*mV) < 1*mV


@pytest.mark.skipif(lmfit is None, reason="needs lmfit package")
Expand All @@ -391,7 +436,7 @@ def test_fitter_refine_reuse_tstart(setup_constant):
params, result = tf.refine({'c': 5 * mV}, c=[0 * mV, 30 * mV])

# Fit should be close to 20mV
assert np.abs(params['c'] * volt - 20 * mV) < 1 * mV
assert np.abs(params['c'] - 20 * mV) < 1 * mV


@pytest.mark.skipif(lmfit is None, reason="needs lmfit package")
Expand All @@ -415,7 +460,7 @@ def our_callback(params, errors, best_params, best_error, index):
assert isinstance(params, dict)
assert isinstance(errors, np.ndarray)
assert isinstance(best_params, dict)
assert isinstance(best_error, float)
assert isinstance(best_error, Quantity)
assert isinstance(index, int)

tf.refine({'g': 5 * nS}, g=[1 * nS, 30 * nS], callback=our_callback)
Expand Down Expand Up @@ -543,7 +588,7 @@ def test_fitter_generate_traces_standalone(setup_standalone):
assert_equal(np.shape(traces), np.shape(output_traces))


def test_fitter_results(setup):
def test_fitter_results(setup, caplog):
dt, tf = setup
best_params, errors = tf.fit(n_rounds=2,
optimizer=n_opt,
Expand All @@ -554,27 +599,64 @@ def test_fitter_results(setup):
params_list = tf.results(format='list')
assert isinstance(params_list, list)
assert isinstance(params_list[0], dict)
print(params_list)
assert isinstance(params_list[0]['g'], Quantity)
assert 'g' in params_list[0].keys()
assert 'errors' in params_list[0].keys()
assert 'error' in params_list[0].keys()
assert_equal(np.shape(params_list), (4,))
assert_equal(len(params_list[0]), 2)
assert have_same_dimensions(params_list[0]['g'].dim, nS)

params_dic = tf.results(format='dict')
assert isinstance(params_dic, dict)
assert 'g' in params_dic.keys()
assert 'errors' in params_dic.keys()
assert 'error' in params_dic.keys()
assert isinstance(params_dic['g'], Quantity)
assert_equal(len(params_dic), 2)
assert_equal(np.shape(params_dic['g']), (4,))
assert_equal(np.shape(params_dic['errors']), (4,))
assert_equal(np.shape(params_dic['error']), (4,))

# Should raise a warning because dataframe cannot have units
assert len(caplog.records) == 0
params_df = tf.results(format='dataframe')
assert len(caplog.records) == 1
assert isinstance(params_df, pd.DataFrame)
assert_equal(params_df.shape, (4, 2))
assert 'g' in params_df.keys()
assert 'error' in params_df.keys()


def test_fitter_results_no_units(setup_no_units, caplog):
dt, tf = setup_no_units
tf.fit(n_rounds=2,
optimizer=n_opt,
metric=metric,
g=[1*nS, 30*nS],
restart=False)

params_list = tf.results(format='list')
assert isinstance(params_list, list)
assert isinstance(params_list[0], dict)
assert isinstance(params_list[0]['g'], float)
assert 'g' in params_list[0].keys()
assert 'error' in params_list[0].keys()
assert_equal(np.shape(params_list), (4,))
assert_equal(len(params_list[0]), 2)

params_dic = tf.results(format='dict')
assert isinstance(params_dic, dict)
assert 'g' in params_dic.keys()
assert 'error' in params_dic.keys()
assert isinstance(params_dic['g'], np.ndarray)
assert_equal(len(params_dic), 2)
assert_equal(np.shape(params_dic['g']), (4,))
assert_equal(np.shape(params_dic['error']), (4,))

params_df = tf.results(format='dataframe')
assert isinstance(params_df, pd.DataFrame)
assert_equal(params_df.shape, (4, 2))
assert 'g' in params_df.keys()
assert 'errors' in params_df.keys()
assert 'error' in params_df.keys()


# OnlineTraceFitter
Expand Down Expand Up @@ -631,16 +713,16 @@ def test_onlinetracefitter_fit(setup_online):
optimizer=n_opt,
g=[1*nS, 30*nS],
restart=False,)

print(otf.best_params)
attr_fit = ['optimizer', 'metric', 'best_params']
for attr in attr_fit:
assert hasattr(otf, attr)

assert otf.metric is None
assert isinstance(otf.metric, MSEMetric)
assert isinstance(otf.optimizer, Optimizer)

assert isinstance(results, dict)
assert isinstance(errors, float)
assert isinstance(errors, Quantity)
assert 'g' in results.keys()

assert_equal(results, otf.best_params)
Expand Down

0 comments on commit fbd864b

Please sign in to comment.