Skip to content

Commit

Permalink
Merge pull request #33 from brian-team/output_var
Browse files Browse the repository at this point in the history
[MRG] provide access to target variable in model
  • Loading branch information
romainbrette committed Apr 23, 2020
2 parents 148c167 + 5b92e23 commit baf0ae6
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 12 deletions.
30 changes: 23 additions & 7 deletions brian2modelfitting/fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from numpy import ones, array, arange, concatenate, mean, argmin, nanmin, reshape, zeros

from brian2.parsing.sympytools import sympy_to_str, str_to_sympy
from brian2.units.fundamentalunits import DIMENSIONLESS, get_dimensions
from brian2.units.fundamentalunits import DIMENSIONLESS, get_dimensions, fail_for_dimension_mismatch
from brian2.utils.stringtools import get_identifiers

from brian2 import (NeuronGroup, defaultclock, get_device, Network,
Expand Down Expand Up @@ -234,7 +234,7 @@ class Fitter(metaclass=abc.ABCMeta):
input : `~numpy.ndarray` or `~brian2.units.fundamentalunits.Quantity`
A 2D array of shape ``(n_traces, time steps)`` given the input that will
be fed into the model.
output : `~numpy.ndarray` or `~brian2.units.fundamentalunits.Quantity` or list
output : `~brian2.units.fundamentalunits.Quantity` or list
Recorded output of the model that the model should reproduce. Should
be a 2D array of the same shape as the input when fitting traces with
`TraceFitter`, a list of spike times when fitting spike trains with
Expand Down Expand Up @@ -298,6 +298,11 @@ def __init__(self, dt, model, input, output, input_var, output_var,
self.output_dim = DIMENSIONLESS
else:
self.output_dim = model[output_var].dim
fail_for_dimension_mismatch(output, self.output_dim,
'The provided target values '
'("output") need to have the same '
'units as the variable '
'{}'.format(output_var))
self.model = model

self.use_units = use_units
Expand All @@ -308,6 +313,17 @@ def __init__(self, dt, model, input, output, input_var, output_var,
input_dim)
self.model += input_eqs

if output_var != 'spikes':
# For approaches that couple the system to the target values,
# provide a convenient variable
output_expr = 'output_var(t, i % n_traces)'
output_dim = ('1' if self.output_dim is DIMENSIONLESS
else repr(self.output_dim))
output_eqs = "{}_target = {} : {}".format(output_var,
output_expr,
output_dim)
self.model += output_eqs

input_traces = TimedArray(input.transpose(), dt=dt)
self.input_traces = input_traces

Expand Down Expand Up @@ -335,7 +351,7 @@ def setup_simulator(self, network_name, n_neurons, output_var, param_init,
level=level+1)
if hasattr(self, 't_start'): # OnlineTraceFitter
namespace['t_start'] = self.t_start
if network_name != 'generate' and self.output_var != 'spikes':
if self.output_var != 'spikes':
namespace['output_var'] = TimedArray(self.output.transpose(),
dt=self.dt)
neurons = self.setup_neuron_group(n_neurons, namespace,
Expand Down Expand Up @@ -400,8 +416,9 @@ def setup_neuron_group(self, n_neurons, namespace, calc_gradient=False,
refractory=self.refractory, name=name,
namespace=namespace, dt=self.dt, **kwds)
if online_error:
neurons.run_regularly('total_error += (' + self.output_var +
'-output_var(t,i % n_traces))**2 * int(t>=t_start)',
neurons.run_regularly('total_error += ({} - {}_target)**2 * '
'int(t>=t_start)'.format(self.output_var,
self.output_var),
when='end')

return neurons
Expand Down Expand Up @@ -971,12 +988,11 @@ def __init__(self, model, input, output, dt, reset, threshold,
"""Initialize the fitter."""
if method is None:
method = 'exponential_euler'
super().__init__(dt, model, input, output, input_var, 'v',
super().__init__(dt, model, input, output, input_var, 'spikes',
n_samples, threshold, reset, refractory, method,
param_init, use_units=use_units)
self.output = [Quantity(o) for o in output]
self.output_ = [array(o) for o in output]
self.output_var = 'spikes'

if param_init:
for param, val in param_init.items():
Expand Down
4 changes: 2 additions & 2 deletions brian2modelfitting/tests/test_modelfitting_spikefitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
metric = GammaFactor(60*ms, 60*ms)


@pytest.fixture()
@pytest.fixture
def setup(request):
dt = 0.01 * ms
sf = SpikeFitter(model=model, input_var='I', dt=dt,
Expand All @@ -45,7 +45,7 @@ def fin():
return dt, sf


@pytest.fixture()
@pytest.fixture
def setup_spikes(request):
def fin():
reinit_devices()
Expand Down
18 changes: 15 additions & 3 deletions brian2modelfitting/tests/test_modelfitting_tracefitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from numpy.testing.utils import assert_equal
from brian2 import (zeros, Equations, NeuronGroup, StateMonitor, TimedArray,
nS, mV, volt, ms, pA, pF, Quantity, set_device, get_device,
Network, have_same_dimensions)
Network, have_same_dimensions, DimensionMismatchError)
from brian2.equations.equations import DIFFERENTIAL_EQUATION, SUBEXPRESSION
from brian2modelfitting import (NevergradOptimizer, TraceFitter, MSEMetric,
OnlineTraceFitter, Simulator, Metric,
Expand Down Expand Up @@ -85,7 +85,7 @@ def fin():
def setup_constant(request):
dt = 0.1 * ms
# Membrane potential is constant at 10mV for first 50 steps, then at 20mV
out_trace = np.hstack([np.ones(50) * 10 * mV, np.ones(50) * 20 * mV])
out_trace = np.hstack([np.ones(50) * 10, np.ones(50) * 20])*mV
tf = TraceFitter(dt=dt,
model=constant_model,
input_var='x',
Expand Down Expand Up @@ -177,6 +177,9 @@ def test_tracefitter_init(setup):
assert isinstance(tf.input_traces, TimedArray)
assert isinstance(tf.model, Equations)

target_var = '{}_target'.format(tf.output_var)
assert target_var in tf.model
assert tf.model[target_var].dim is tf.output_dim


def test_tracefitter_init_errors(setup):
Expand All @@ -202,6 +205,15 @@ def test_tracefitter_init_errors(setup):
input_var='v',
output_var='I',)

with pytest.raises(DimensionMismatchError):
tf = TraceFitter(dt=dt,
model=model,
input_var='v',
output_var='I',
input=input_traces,
output=np.array(output_traces), # no units
n_samples=2)


def test_fitter_fit(setup):
dt, tf = setup
Expand Down Expand Up @@ -378,7 +390,7 @@ def test_fitter_refine_calc_gradient():
def exp_fit(x, a, b):
return a * np.exp(x / b) -70 - a * np.exp(0)
outputs = np.vstack([exp_fit(np.arange(100), 1.2836869755582263, 51.41761887704586),
exp_fit(np.arange(100), 2.567374463239943, 51.417624003833076)])
exp_fit(np.arange(100), 2.567374463239943, 51.417624003833076)])*volt

model = '''
dv/dt = (g_L * (E_L - v) + I_e)/Cm : volt
Expand Down
10 changes: 10 additions & 0 deletions docs_sphinx/features/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -230,3 +230,13 @@ with each simulation.
result, error = fitter.fit(optimizer=optimizer,
n_rounds=1,
gl=[1e-8*siemens*cm**-2 * area, 1e-3*siemens*cm**-2 * area],)
Reference the target values in the equations
--------------------------------------------

A model can refer to the target output values within the equations. For example, if you
are fitting a membrane potential trace *v* (i.e. `output_var='v'`), then the equations
can refer to the target trace as `v_target`. This allows you for example to add a coupling
term like `coupling*(v_target - v)` to the equation for `v`, pulling the trajectory towards the
correct solution.

0 comments on commit baf0ae6

Please sign in to comment.