Skip to content

Commit

Permalink
Merge 5ceaafc into 19e407d
Browse files Browse the repository at this point in the history
  • Loading branch information
mstimberg committed Nov 16, 2020
2 parents 19e407d + 5ceaafc commit d63f7ef
Show file tree
Hide file tree
Showing 6 changed files with 428 additions and 146 deletions.
403 changes: 282 additions & 121 deletions brian2modelfitting/fitter.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion brian2modelfitting/tests/test_modelfitting_spikefitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def test_spikefitter_fit(setup):
for attr in attr_fit:
assert hasattr(sf, attr)

assert isinstance(sf.metric, Metric)
assert len(sf.metric) == 1 and isinstance(sf.metric[0], Metric)
assert isinstance(sf.optimizer, Optimizer)

assert isinstance(results, dict)
Expand Down
31 changes: 19 additions & 12 deletions brian2modelfitting/tests/test_modelfitting_tracefitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,9 +223,9 @@ def test_tracefitter_init(setup):
assert isinstance(tf.input_traces, TimedArray)
assert isinstance(tf.model, Equations)

target_var = '{}_target'.format(tf.output_var)
target_var = '{}_target'.format(tf.output_var[0])
assert target_var in tf.model
assert tf.model[target_var].dim is tf.output_dim
assert tf.model[target_var].dim is tf.output_dim[0]


def test_tracefitter_init_errors(setup):
Expand Down Expand Up @@ -303,7 +303,7 @@ def test_fitter_fit(setup):
assert hasattr(tf, attr)
assert tf.simulator.neurons.iteration == 1

assert isinstance(tf.metric, Metric)
assert len(tf.metric) == 1 and isinstance(tf.metric[0], Metric)
assert isinstance(tf.optimizer, Optimizer)
assert isinstance(tf.simulator, Simulator)

Expand All @@ -328,7 +328,7 @@ def test_fitter_fit_no_units(setup_no_units):
for attr in attr_fit:
assert hasattr(tf, attr)

assert isinstance(tf.metric, Metric)
assert len(tf.metric) == 1 and isinstance(tf.metric[0], Metric)
assert isinstance(tf.optimizer, Optimizer)
assert isinstance(tf.simulator, Simulator)

Expand All @@ -345,13 +345,15 @@ def test_fitter_fit_callback(setup):
dt, tf = setup

calls = []
def our_callback(params, errors, best_params, best_error, index):
def our_callback(params, errors, best_params, best_error, index,
additional_info):
calls.append(index)
assert all(isinstance(p, dict) for p in params)
assert isinstance(errors, np.ndarray)
assert isinstance(best_params, dict)
assert isinstance(best_error, Quantity)
assert isinstance(index, int)
assert isinstance(additional_info, dict)
results, errors = tf.fit(n_rounds=2,
optimizer=n_opt,
metric=metric,
Expand All @@ -362,13 +364,15 @@ def our_callback(params, errors, best_params, best_error, index):
# Stop a fit via the callback

calls = []
def our_callback(params, errors, best_params, best_error, index):
def our_callback(params, errors, best_params, best_error, index,
additional_info):
calls.append(index)
assert all(isinstance(p, dict) for p in params)
assert isinstance(errors, np.ndarray)
assert isinstance(best_params, dict)
assert isinstance(best_error, Quantity)
assert isinstance(index, int)
assert isinstance(additional_info, dict)
return True # stop

results, errors = tf.fit(n_rounds=2,
Expand Down Expand Up @@ -583,10 +587,11 @@ def test_fitter_refine_tsteps_normalization(setup_constant):
dt, tf = setup_constant

model_traces = tf.generate(params={'c': 5 * mV})
mse_error = MSEMetric(t_start=50*dt).calc(model_traces[None, : , :], tf.output, dt)
mse_error = MSEMetric(t_start=50*dt).calc(model_traces[None, : , :], tf.output[0], dt)
all_errors = []
def callback(parameters, errors, best_parameters, best_error, index):
all_errors.append(float(errors[0]))
def callback(parameters, errors, best_parameters, best_error, index,
additional_args):
all_errors.append(float(errors[0][0]))
return True # stop simulation

# Ignore the first 50 steps at 10mV
Expand Down Expand Up @@ -666,13 +671,15 @@ def test_fitter_callback(setup, caplog):
dt, tf = setup

calls = []
def our_callback(params, errors, best_params, best_error, index):
def our_callback(params, errors, best_params, best_error, index,
additional_info):
calls.append(index)
assert isinstance(params, dict)
assert isinstance(errors, np.ndarray)
assert isinstance(errors, list)
assert isinstance(best_params, dict)
assert isinstance(best_error, Quantity)
assert isinstance(index, int)
assert isinstance(additional_info, dict)

tf.refine({'g': 5 * nS}, g=[1 * nS, 30 * nS], callback=our_callback)
assert len(calls)
Expand Down Expand Up @@ -970,7 +977,7 @@ def test_onlinetracefitter_fit(setup_online):
for attr in attr_fit:
assert hasattr(otf, attr)

assert isinstance(otf.metric, MSEMetric)
assert len(otf.metric) == 1 and isinstance(otf.metric[0], MSEMetric)
assert isinstance(otf.optimizer, Optimizer)

assert isinstance(results, dict)
Expand Down
13 changes: 7 additions & 6 deletions brian2modelfitting/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,18 @@


def test_callback_text(capsys):
callback_text([1, 2, 3], [1.2, 2.3, 0.1], {'a':3}, 0.1, 2)
callback_text([1, 2, 3], [1.2, 2.3, 0.1], {'a':3}, 0.1, 2, {'output_var': 'v'})


def test_callback_none():
c = callback_none([1, 2, 3], [1.2, 2.3, 0.1], {'a':3}, 0.1, 2)
c = callback_none([1, 2, 3], [1.2, 2.3, 0.1], {'a':3}, 0.1, 2, {'output_var': 'v'})
assert isinstance(c, type(None))


def test_ProgressBar():
pb = ProgressBar(total=10)
assert isinstance(pb.t, tqdm.tqdm)
pb([1, 2, 3], [1.2, 2.3, 0.1], {'a':3}, 0.1, 2)
pb([1, 2, 3], [1.2, 2.3, 0.1], {'a':3}, 0.1, 2, {'output_var': 'v'})


def test_callback_setup():
Expand All @@ -32,15 +32,16 @@ def test_callback_setup():

c = callback_setup(None, 10)
assert callable(c)
x = c([1, 2, 3], [1.2, 2.3, 0.1], {'a':3}, 0.1, 2)
x = c([1, 2, 3], [1.2, 2.3, 0.1], {'a':3}, 0.1, 2, {'output_var': 'v'})
assert x is None

def callback(params, errors, best_params, best_error, index):
def callback(params, errors, best_params, best_error, index,
additional_index):
return params

c = callback_setup(callback, 10)
assert callable(c)
x = c([1, 2, 3], [1.2, 2.3, 0.1], {'a':3}, 0.1, 2)
x = c([1, 2, 3], [1.2, 2.3, 0.1], {'a':3}, 0.1, 2, {'output_var': 'v'})
assert_equal(x, [1, 2, 3])


Expand Down
43 changes: 37 additions & 6 deletions brian2modelfitting/utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,44 @@
from tqdm.autonotebook import tqdm
from types import FunctionType

from brian2.units.fundamentalunits import Quantity
from tqdm.autonotebook import tqdm


def callback_text(params, errors, best_params, best_error, index):

def callback_text(params, errors, best_params, best_error, index, additional_info):
"""Default callback print-out for Fitters"""
param_str = ', '.join([f"{p}={v!s}" for p, v in sorted(best_params.items())])
print(f"Round {index}: Best parameters {param_str} (error: {best_error!s})")
params = []
for p, v in sorted(best_params.items()):
if isinstance(v, Quantity):
params.append(f'{p}={v.in_best_unit(precision=2)}')
else:
params.append(f'{p}={v:.2g}')
param_str = ', '.join(params)
if isinstance(best_error, Quantity):
best_error_str = best_error.in_best_unit(precision=2)
else:
best_error_str = f'{best_error:.2g}'
round = f'Round {index}: '
if (additional_info and
'metric_weights' in additional_info and
len(additional_info['metric_weights'])>1):
errors = []
for weight, error, varname in zip(additional_info['metric_weights'],
additional_info['best_errors'],
additional_info['output_var']):
if isinstance(error, Quantity):
errors.append(f'{weight!s}×{error.in_best_unit(precision=2)} ({varname})')
else:
errors.append(f'{weight!s}×{error:.2g} ({varname})')
error_sum = ' + '.join(errors)
print(f"{round}Best parameters {param_str}\n"
f"{' '*len(round)}Best error: {best_error_str} = {error_sum}")
else:
print(f"{round}Best parameters {param_str}\n"
f"{' '*len(round)}Best error: {best_error_str} ({additional_info['output_var'][0]})")


def callback_none(params, errors, best_params, best_error, index):
def callback_none(params, errors, best_params, best_error, index, additional_info):
"""Non-verbose callback"""
pass

Expand All @@ -18,7 +48,8 @@ class ProgressBar(object):
def __init__(self, total=None, **kwds):
self.t = tqdm(total=total, **kwds)

def __call__(self, params, errors, best_params, best_error, index):
def __call__(self, params, errors, best_params, best_error, index,
additional_info):
self.t.update(1)


Expand Down
82 changes: 82 additions & 0 deletions examples/multiobjective.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Example: multi-objective model fitting with brian2modelfitting.
# A Hodgkin-Huxley-type neuron is simulated with known ("ground truth")
# parameters, and the fitter then recovers those parameters by matching
# TWO recorded variables at once: the membrane potential v and the
# sodium activation gate m.
import pandas as pd
import numpy as np
from brian2 import *
from brian2modelfitting import *
# set_device('cpp_standalone') # recommend for speed
dt = 0.01*ms
defaultclock.dt = dt

# Generate ground truth data
# Biophysical constants used both for the ground-truth simulation and
# (via the equation string) during fitting.
area = 20000*umetre**2
El = -65*mV
EK = -90*mV
ENa = 50*mV
VT = -63*mV
dt = 0.01*ms
# HH-style membrane equations; the four conductance/capacitance values at
# the bottom are declared "(constant)" so the fitter can treat them as
# free parameters.
eqs='''
dv/dt = (gl*(El-v) - g_na*(m*m*m)*h*(v-ENa) - g_kd*(n*n*n*n)*(v-EK) + I)/Cm : volt
dm/dt = 0.32*(mV**-1)*(13.*mV-v+VT)/
(exp((13.*mV-v+VT)/(4.*mV))-1.)/ms*(1-m)-0.28*(mV**-1)*(v-VT-40.*mV)/
(exp((v-VT-40.*mV)/(5.*mV))-1.)/ms*m : 1
dn/dt = 0.032*(mV**-1)*(15.*mV-v+VT)/
(exp((15.*mV-v+VT)/(5.*mV))-1.)/ms*(1.-n)-.5*exp((10.*mV-v+VT)/(40.*mV))/ms*n : 1
dh/dt = 0.128*exp((17.*mV-v+VT)/(18.*mV))/ms*(1.-h)-4./(1+exp((40.*mV-v+VT)/(5.*mV)))/ms*h : 1
g_na : siemens (constant)
g_kd : siemens (constant)
gl : siemens (constant)
Cm : farad (constant)
'''
# Five step-current stimuli (one per column): zero for the first 1000
# time steps, then steps of amplitude 0, 0.25, 0.5, 0.75 and 1 nA.
inp_ar = np.zeros((10000, 5))*nA
inp_ar[1000:, :] = 1.*nA
inp_ar *= (np.arange(5)*0.25)
inp = TimedArray(inp_ar, dt=dt)
# Simulate the "true" model once to create the target traces.
ground_truth = NeuronGroup(5, eqs + 'I = inp(t, i) : amp',
                           method='exponential_euler')
ground_truth.v = El
ground_truth.Cm = (1*ufarad*cm**-2) * area
ground_truth.gl = (5e-5*siemens*cm**-2) * area
ground_truth.g_na = (100*msiemens*cm**-2) * area
ground_truth.g_kd = (30*msiemens*cm**-2) * area
mon = StateMonitor(ground_truth, ['v', 'm'], record=True)
run(100*ms)
ground_truth_v = mon.v[:]
ground_truth_m = mon.m[:]
## Optimization and Metric Choice
n_opt = NevergradOptimizer()
# Ignore the first 5 ms (initial transient) when computing the error.
metric = MSEMetric(t_start=5*ms)

## Fitting
# Two output variables -> two target traces; the fit minimizes a weighted
# sum of the per-variable errors.
fitter = TraceFitter(model=eqs, input_var='I', output_var=['v', 'm'],
                     input=inp_ar.T, output=[ground_truth_v,
                                             ground_truth_m],
                     dt=dt, n_samples=60, param_init={'v': 'El'},
                     method='exponential_euler')

# metric_weights scales the v error by 1/(100 mV)^2 — presumably to make
# the squared-volt error comparable in magnitude to the dimensionless m
# error (confirm against the metric's units convention).
res, error = fitter.fit(n_rounds=20,
                        optimizer=n_opt, metric=metric,
                        metric_weights=[1/(float(100*mV)**2),
                                        1],
                        callback='text',
                        gl=[1e-09 *siemens, 1e-07 *siemens],
                        g_na=[2e-06*siemens, 2e-04*siemens],
                        g_kd=[6e-07*siemens, 6e-05*siemens],
                        Cm=[0.1*ufarad*cm**-2 * area, 2*ufarad*cm**-2 * area])

# Polish the best candidate with gradient-based refinement.
refined_params, _ = fitter.refine(calc_gradient=True)

## Visualization of the results
# params=None uses the best parameters found by fit().
fits = fitter.generate_traces(params=None, param_init={'v': -65*mV})
refined_fits = fitter.generate_traces(params=refined_params, param_init={'v': -65*mV})

# Top row: v traces; bottom row: m traces; one column per stimulus.
fig, ax = plt.subplots(2, ncols=5, figsize=(20, 5), sharex=True, sharey='row')
for idx in range(5):
    ax[0][idx].plot(ground_truth_v[idx]/mV, 'k:', alpha=0.75,
                    label='ground truth')
    ax[0][idx].plot(fits['v'][idx].transpose()/mV, alpha=0.75, label='fit')
    ax[0][idx].plot(refined_fits['v'][idx].transpose() / mV, alpha=0.75,
                    label='refined')
    ax[1][idx].plot(ground_truth_m[idx], 'k:', alpha=0.75)
    ax[1][idx].plot(fits['m'][idx].transpose(), alpha=0.75)
    ax[1][idx].plot(refined_fits['m'][idx].transpose(), alpha=0.75)
ax[0][0].legend(loc='best')
plt.show()

0 comments on commit d63f7ef

Please sign in to comment.