Skip to content

Commit

Permalink
Merge pull request #34 from brian-team/generate_multiple_variables
Browse files Browse the repository at this point in the history
[MRG] Fitter.generate for multiple variables
  • Loading branch information
romainbrette committed May 8, 2020
2 parents ed28bf9 + c651368 commit 6ee0669
Show file tree
Hide file tree
Showing 7 changed files with 96 additions and 37 deletions.
61 changes: 38 additions & 23 deletions brian2modelfitting/fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,25 +352,30 @@ def setup_simulator(self, network_name, n_neurons, output_var, param_init,
level=level+1)
if hasattr(self, 't_start'): # OnlineTraceFitter
namespace['t_start'] = self.t_start

if self.output_var != 'spikes':
namespace['output_var'] = TimedArray(self.output.transpose(),
dt=self.dt)
neurons = self.setup_neuron_group(n_neurons, namespace,
calc_gradient=calc_gradient,
optimize=optimize,
online_error=online_error)
network = Network(neurons)
if isinstance(output_var, str):
output_var = [output_var]
if 'spikes' in output_var:
network.add(SpikeMonitor(neurons, name='spikemonitor'))

if output_var == 'spikes':
monitor = SpikeMonitor(neurons, name='monitor')
else:
record_vars = [output_var]
if calc_gradient:
record_vars.extend([f'S_{output_var}_{p}'
for p in self.parameter_names])
monitor = StateMonitor(neurons, record_vars, record=True,
name='monitor', dt=self.dt)

network = Network(neurons, monitor)
record_vars = [v for v in output_var if v != 'spikes']
if calc_gradient:
if not len(output_var) == 1:
raise AssertionError('Cannot calculate gradient with multiple '
'output variables.')
record_vars.extend([f'S_{output_var[0]}_{p}'
for p in self.parameter_names])
if len(record_vars):
network.add(StateMonitor(neurons, record_vars, record=True,
name='statemonitor', dt=self.dt))

if calc_gradient:
param_init = dict(param_init)
Expand Down Expand Up @@ -669,18 +674,19 @@ def results(self, format='list', use_units=None):
data = concatenate((params, array(errors)[None, :].transpose()), axis=1)
return DataFrame(data=data, columns=names + ['error'])

def generate(self, params=None, output_var=None, param_init=None, level=0):
def generate(self, output_var=None, params=None, param_init=None, level=0):
"""
Generates traces for best fit of parameters and all inputs.
If provided with other parameters provides those.
Parameters
----------
output_var: str or sequence of str
Name of the output variable to be monitored, or the special name
``spikes`` to record spikes. Can also be a sequence of names to
record multiple variables.
params: dict
Dictionary of parameters to generate fits for.
output_var: str
Name of the output variable to be monitored, or the special name
``spikes`` to record spikes.
param_init: dict
Dictionary of initial values for the model.
level : `int`, optional
Expand All @@ -691,7 +697,9 @@ def generate(self, params=None, output_var=None, param_init=None, level=0):
fit
Either a 2D `.Quantity` with the recorded output variable over time,
with shape <number of input traces> × <number of time steps>, or
a list of spike times for each input trace.
a list of spike times for each input trace. If several names were
given as ``output_var``, then the result is a dictionary with the
names of the variable as the key.
"""
if params is None:
params = self.best_params
Expand All @@ -712,12 +720,19 @@ def generate(self, params=None, output_var=None, param_init=None, level=0):
self.simulator.run(self.duration, param_dic, self.parameter_names,
name='generate')

if not isinstance(output_var, str):
fits = {name: self._simulation_result(name) for name in output_var}
else:
fits = self._simulation_result(output_var)

return fits

def _simulation_result(self, output_var):
if output_var == 'spikes':
fits = get_spikes(self.simulator.monitor,
fits = get_spikes(self.simulator.spikemonitor,
1, self.n_traces)[0] # a single "sample"
else:
fits = getattr(self.simulator.monitor, output_var)[:]

fits = getattr(self.simulator.statemonitor, output_var)[:]
return fits


Expand Down Expand Up @@ -766,7 +781,7 @@ def calc_errors(self, metric):
Returns errors after simulation with StateMonitor.
To be used inside `optim_iter`.
"""
traces = getattr(self.simulator.networks['fit']['monitor'],
traces = getattr(self.simulator.networks['fit']['statemonitor'],
self.output_var+'_')
# Reshape traces for easier calculation of error
traces = reshape(traces, (traces.shape[0]//self.n_traces,
Expand Down Expand Up @@ -916,14 +931,14 @@ def _calc_error(params):
self.parameter_names, self.n_traces, 1)
self.simulator.run(self.duration, param_dic,
self.parameter_names, name='refine')
trace = getattr(self.simulator.monitor, self.output_var+'_')
trace = getattr(self.simulator.statemonitor, self.output_var+'_')
residual = trace[:, t_start_steps:] - self.output_[:, t_start_steps:]
return residual.flatten() * normalization

def _calc_gradient(params):
residuals = []
for name in self.parameter_names:
trace = getattr(self.simulator.monitor,
trace = getattr(self.simulator.statemonitor,
f'S_{self.output_var}_{name}_')
residual = trace[:, t_start_steps:]
residuals.append(residual.flatten() * normalization)
Expand Down Expand Up @@ -1009,7 +1024,7 @@ def calc_errors(self, metric):
Returns errors after simulation with SpikeMonitor.
To be used inside optim_iter.
"""
spikes = get_spikes(self.simulator.networks['fit']['monitor'],
spikes = get_spikes(self.simulator.networks['fit']['spikemonitor'],
self.n_samples, self.n_traces)
errors = metric.calc(spikes, self.output, self.dt)
return errors
Expand Down
12 changes: 7 additions & 5 deletions brian2modelfitting/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ def __init__(self):
self.var_init = None

neurons = property(lambda self: self.networks[self.current_net]['neurons'])
monitor = property(lambda self: self.networks[self.current_net]['monitor'])
statemonitor = property(lambda self: self.networks[self.current_net]['statemonitor'])
spikemonitor = property(lambda self: self.networks[self.current_net]['spikemonitor'])

def initialize(self, network, var_init, name='fit'):
"""
Expand All @@ -67,7 +68,8 @@ def initialize(self, network, var_init, name='fit'):
----------
network: `~brian2.core.network.Network`
Network consisting of a `~brian2.groups.neurongroup.NeuronGroup`
named ``neurons`` and a monitor named ``monitor``.
named ``neurons`` and either a monitor named ``spikemonitor``
or a monitor named ``statemonitor``(or both).
var_init: dict
dictionary to initialize the variable states
name: `str`, optional
Expand All @@ -77,9 +79,9 @@ def initialize(self, network, var_init, name='fit'):
if 'neurons' not in network:
raise KeyError('Expected a group named "neurons" in the '
'network.')
if 'monitor' not in network:
raise KeyError('Expected a monitor named "monitor" in the '
'network.')
if 'statemonitor' not in network and 'spikemonitor' not in network:
raise KeyError('Expected a monitor named "spikemonitor" or '
'"statemonitor" in the network.')
self.networks[name] = network
self.current_net = None # will be set in run
self.var_init = var_init
Expand Down
16 changes: 16 additions & 0 deletions brian2modelfitting/tests/test_modelfitting_spikefitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,3 +201,19 @@ def test_spikefitter_generate(setup):
param_init={'v': -70*mV})
assert isinstance(traces, np.ndarray)
assert_equal(np.shape(traces), np.shape(inp_trace))


def test_spikefitter_generate_multiple_variables(setup):
dt, sf = setup
results, errors = sf.fit(n_rounds=2,
optimizer=n_opt,
metric=metric,
gL=[20*nS, 40*nS],
C=[0.5*nF, 1.5*nF])
recordings = sf.generate(params=None,
output_var=['v', 'spikes'],
param_init={'v': -70*mV})
assert isinstance(recordings, dict)
assert set(recordings.keys()) == {'v', 'spikes'}
assert_equal(np.shape(recordings['v']), np.shape(inp_trace))
assert_equal(np.shape(recordings['spikes'])[0], np.shape(inp_trace)[0])
14 changes: 14 additions & 0 deletions brian2modelfitting/tests/test_modelfitting_tracefitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,20 @@ def test_fitter_generate_traces(setup):
assert_equal(np.shape(traces), np.shape(output_traces))


def test_fitter_generate_traces_multiple_vars(setup):
dt, tf = setup
results, errors = tf.fit(n_rounds=2,
optimizer=n_opt,
metric=metric,
g=[1*nS, 30*nS],
restart=False,)
traces = tf.generate(output_var=['I', 'g'])
assert isinstance(traces, dict)
assert set(traces.keys()) == {'I', 'g'}
assert_equal(np.shape(traces['I']), np.shape(output_traces))
assert_equal(np.shape(traces['g']), np.shape(output_traces))


def test_fitter_generate_traces_standalone(setup_standalone):
dt, tf = setup_standalone
results, errors = tf.fit(n_rounds=2,
Expand Down
8 changes: 4 additions & 4 deletions brian2modelfitting/tests/test_simulation_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def setup(request):
duration = 10 * ms

neurons = NeuronGroup(1, model, name='neurons')
monitor = StateMonitor(neurons, 'I', record=True, name='monitor')
monitor = StateMonitor(neurons, 'I', record=True, name='statemonitor')

net = Network(neurons, monitor)

Expand Down Expand Up @@ -91,7 +91,7 @@ def test_run_simulation_runtime(setup):
rts.initialize(net, var_init=None)

rts.run(duration, {'g': 100, 'E': 10}, ['g', 'E'])
I = getattr(rts.networks['fit']['monitor'], 'I')
I = getattr(rts.statemonitor, 'I')
assert_equal(np.shape(I), (1, duration/dt))


Expand All @@ -100,12 +100,12 @@ def test_run_simulation_runtime_var_init(setup):
start_scope()

neurons = NeuronGroup(1, model2, name='neurons')
monitor = StateMonitor(neurons, 'v', record=True, name='monitor')
monitor = StateMonitor(neurons, 'v', record=True, name='statemonitor')
net = Network(neurons, monitor)

rts = RuntimeSimulator()
rts.initialize(net, var_init={'v': -60*mV})

rts.run(duration, {'gL': 100, 'C': 10}, ['gL', 'C'])
v = getattr(rts.networks['fit']['monitor'], 'v')
v = getattr(rts.statemonitor, 'v')
assert_equal(np.shape(v), (1, duration/dt))
6 changes: 3 additions & 3 deletions brian2modelfitting/tests/test_simulation_standalone.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def setup(request):
duration = 10 * ms

neurons = NeuronGroup(1, model, name='neurons')
monitor = StateMonitor(neurons, 'I', record=True, name='monitor')
monitor = StateMonitor(neurons, 'I', record=True, name='statemonitor')

net = Network(neurons, monitor)

Expand All @@ -48,7 +48,7 @@ def setup_standalone(request):
dt = 0.1 * ms
duration = 10 * ms
neurons = NeuronGroup(1, model, name='neurons')
monitor = StateMonitor(neurons, 'I', record=True, name='monitor')
monitor = StateMonitor(neurons, 'I', record=True, name='statemonitor')

net = Network(neurons, monitor)

Expand Down Expand Up @@ -100,5 +100,5 @@ def test_run_simulation_standalone(setup_standalone):
sas.initialize(net, var_init=None)

sas.run(duration, {'g': 100, 'E': 10}, ['g', 'E'])
I = getattr(sas.monitor, 'I')
I = getattr(sas.statemonitor, 'I')
assert_equal(np.shape(I), (1, duration/dt))
16 changes: 14 additions & 2 deletions docs_sphinx/features/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ any input arguments:
fitter = SpikeFitter(...)
results, error = fitter.fit(...)
spikes = fitter.generate_traces()
spikes = fitter.generate_spikes()
Custom generate
Expand All @@ -118,11 +118,23 @@ arguments:
fitter.generate(params=None, output_var=None, param_init=None, level=0)
Where ``params`` is a dictionary of parameters for which the traces we generate.
``output_var`` provides an option to pick variable for visualization. With
``output_var`` provides an option to pick one or more variable for visualization. With
``param_init``, user can define the initial values for differential equations.
``level`` allows for specification of namespace level from which we get
the constant parameters of the model.

If ``output_var`` is the name of a single variable name (or the special name ``'spikes'``), a single `~.Quantity`
(for normal variables) or a list of spikes time arrays (for ``'spikes'``) will be returned. If a list of names is
provided, then the result is a dictionary with all the results.

.. code:: python
fitter = TraceFitter(...)
results, error = fitter.fit(...)
traces = fitter.generate(output_var=['v', 'h', 'n', 'm'])
v_trace = traces['v']
h_trace = traces['h']
...
Results
Expand Down

0 comments on commit 6ee0669

Please sign in to comment.