Merge remote-tracking branch 'origin/master' into temporal_weighting
# Conflicts:
#	brian2modelfitting/fitter.py
mstimberg committed May 18, 2020
2 parents 0154cff + 6ee0669 commit c136a6c
Showing 9 changed files with 207 additions and 83 deletions.
105 changes: 71 additions & 34 deletions brian2modelfitting/fitter.py
@@ -5,7 +5,7 @@
from numpy import ones, array, arange, concatenate, mean, argmin, nanmin, reshape, zeros

from brian2.parsing.sympytools import sympy_to_str, str_to_sympy
from brian2.units.fundamentalunits import DIMENSIONLESS, get_dimensions
from brian2.units.fundamentalunits import DIMENSIONLESS, get_dimensions, fail_for_dimension_mismatch
from brian2.utils.stringtools import get_identifiers

from brian2 import (NeuronGroup, defaultclock, get_device, Network,
@@ -82,8 +82,9 @@ def setup_fit():
}
if isinstance(get_device(), CPPStandaloneDevice):
if device.has_been_run is True:
build_options = dict(device.build_options)
get_device().reinit()
get_device().activate()
get_device().activate(**build_options)
return simulators[get_device().__class__.__name__]
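For context, a hedged, stand-alone sketch (not part of the brian2modelfitting API) of the device-reset pattern used above: when the C++ standalone device has already been run, its build options are saved before ``reinit()`` and handed back to ``activate()``:

from brian2 import get_device
from brian2.devices.cpp_standalone.device import CPPStandaloneDevice

def reset_standalone_device():
    # If the C++ standalone device has already been run, remember its build
    # options, wipe its state, and re-activate it with the same options so
    # that a later build behaves consistently.
    if isinstance(get_device(), CPPStandaloneDevice) and get_device().has_been_run:
        build_options = dict(get_device().build_options)
        get_device().reinit()
        get_device().activate(**build_options)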


@@ -234,7 +235,7 @@ class Fitter(metaclass=abc.ABCMeta):
input : `~numpy.ndarray` or `~brian2.units.fundamentalunits.Quantity`
A 2D array of shape ``(n_traces, time steps)`` given the input that will
be fed into the model.
output : `~numpy.ndarray` or `~brian2.units.fundamentalunits.Quantity` or list
output : `~brian2.units.fundamentalunits.Quantity` or list
Recorded output of the model that the model should reproduce. Should
be a 2D array of the same shape as the input when fitting traces with
`TraceFitter`, a list of spike times when fitting spike trains with
@@ -293,13 +294,16 @@ def __init__(self, dt, model, input, output, input_var, output_var,
self.refractory = refractory

self.input = input
self.output = Quantity(output)
self.output_ = array(output)
self.output_var = output_var
if output_var == 'spikes':
self.output_dim = DIMENSIONLESS
else:
self.output_dim = model[output_var].dim
fail_for_dimension_mismatch(output, self.output_dim,
'The provided target values '
'("output") need to have the same '
'units as the variable '
'{}'.format(output_var))
self.model = model

self.use_units = use_units
@@ -310,6 +314,17 @@
input_dim)
self.model += input_eqs

if output_var != 'spikes':
# For approaches that couple the system to the target values,
# provide a convenient variable
output_expr = 'output_var(t, i % n_traces)'
output_dim = ('1' if self.output_dim is DIMENSIONLESS
else repr(self.output_dim))
output_eqs = "{}_target = {} : {}".format(output_var,
output_expr,
output_dim)
self.model += output_eqs

input_traces = TimedArray(input.transpose(), dt=dt)
self.input_traces = input_traces
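A small, self-contained Brian2 sketch (all names and values invented for illustration) of what the two pieces above provide: the per-trace target exposed as ``<output_var>_target`` through a 2D TimedArray, plus a squared-distance subexpression of the kind the online error later accumulates in ``setup_neuron_group``:

import numpy as np
from brian2 import NeuronGroup, Network, StateMonitor, TimedArray, ms

dt = 0.1*ms
n_traces, n_steps = 2, 200
target = np.vstack([np.linspace(0, 1, n_steps),   # one target trace ...
                    np.linspace(1, 0, n_steps)])  # ... per input trace

# TimedArray is indexed as (t, column), hence the transpose and "i % n_traces"
output_var = TimedArray(target.T, dt=dt)

eqs = '''dv/dt = (1 - v)/(10*ms) : 1
         v_target = output_var(t, i % n_traces) : 1
         squared_error = (v - v_target)**2 : 1'''
neurons = NeuronGroup(n_traces, eqs, method='exact', dt=dt,
                      namespace={'output_var': output_var,
                                 'n_traces': n_traces})
mon = StateMonitor(neurons, ['v', 'v_target', 'squared_error'], record=True, dt=dt)
Network(neurons, mon).run(n_steps*dt)
print(mon.squared_error.shape)   # (n_traces, n_steps): distance to the target over time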

@@ -337,25 +352,30 @@ def setup_simulator(self, network_name, n_neurons, output_var, param_init,
level=level+1)
if hasattr(self, 't_start'): # OnlineTraceFitter
namespace['t_start'] = self.t_start
if network_name != 'generate':

if self.output_var != 'spikes':
namespace['output_var'] = TimedArray(self.output.transpose(),
dt=self.dt)
neurons = self.setup_neuron_group(n_neurons, namespace,
calc_gradient=calc_gradient,
optimize=optimize,
online_error=online_error)
network = Network(neurons)
if isinstance(output_var, str):
output_var = [output_var]
if 'spikes' in output_var:
network.add(SpikeMonitor(neurons, name='spikemonitor'))

if output_var == 'spikes':
monitor = SpikeMonitor(neurons, name='monitor')
else:
record_vars = [output_var]
if calc_gradient:
record_vars.extend([f'S_{output_var}_{p}'
for p in self.parameter_names])
monitor = StateMonitor(neurons, record_vars, record=True,
name='monitor', dt=self.dt)

network = Network(neurons, monitor)
record_vars = [v for v in output_var if v != 'spikes']
if calc_gradient:
if not len(output_var) == 1:
raise AssertionError('Cannot calculate gradient with multiple '
'output variables.')
record_vars.extend([f'S_{output_var[0]}_{p}'
for p in self.parameter_names])
if len(record_vars):
network.add(StateMonitor(neurons, record_vars, record=True,
name='statemonitor', dt=self.dt))

if calc_gradient:
param_init = dict(param_init)
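A plain-Python mirror of the ``record_vars`` construction above (variable and parameter names are hypothetical): the state monitor records every requested variable except ``spikes``, and, when gradients are requested, one sensitivity variable per fitted parameter named ``S_<output>_<parameter>``:

output_var = ['v']                 # hypothetical fitted variable
parameter_names = ['gL', 'C']      # hypothetical fitted parameters
calc_gradient = True

record_vars = [v for v in output_var if v != 'spikes']
if calc_gradient:
    if len(output_var) != 1:
        raise AssertionError('Cannot calculate gradient with multiple '
                             'output variables.')
    record_vars.extend(f'S_{output_var[0]}_{p}' for p in parameter_names)

print(record_vars)   # ['v', 'S_v_gL', 'S_v_C']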
@@ -402,8 +422,9 @@ def setup_neuron_group(self, n_neurons, namespace, calc_gradient=False,
refractory=self.refractory, name=name,
namespace=namespace, dt=self.dt, **kwds)
if online_error:
neurons.run_regularly('total_error += (' + self.output_var +
'-output_var(t,i % n_traces))**2 * int(t>=t_start)',
neurons.run_regularly('total_error += ({} - {}_target)**2 * '
'int(t>=t_start)'.format(self.output_var,
self.output_var),
when='end')

return neurons
@@ -653,18 +674,19 @@ def results(self, format='list', use_units=None):
data = concatenate((params, array(errors)[None, :].transpose()), axis=1)
return DataFrame(data=data, columns=names + ['error'])

def generate(self, params=None, output_var=None, param_init=None, level=0):
def generate(self, output_var=None, params=None, param_init=None, level=0):
"""
Generates traces for best fit of parameters and all inputs.
If provided with other parameters provides those.
Parameters
----------
output_var: str or sequence of str
Name of the output variable to be monitored, or the special name
``spikes`` to record spikes. Can also be a sequence of names to
record multiple variables.
params: dict
Dictionary of parameters to generate fits for.
output_var: str
Name of the output variable to be monitored, or the special name
``spikes`` to record spikes.
param_init: dict
Dictionary of initial values for the model.
level : `int`, optional
@@ -675,7 +697,9 @@ def generate(self, params=None, output_var=None, param_init=None, level=0):
fit
Either a 2D `.Quantity` with the recorded output variable over time,
with shape <number of input traces> × <number of time steps>, or
a list of spike times for each input trace.
a list of spike times for each input trace. If several names were
given as ``output_var``, the result is a dictionary with the variable
names as keys.
"""
if params is None:
params = self.best_params
@@ -696,12 +720,19 @@
self.simulator.run(self.duration, param_dic, self.parameter_names,
name='generate')

if not isinstance(output_var, str):
fits = {name: self._simulation_result(name) for name in output_var}
else:
fits = self._simulation_result(output_var)

return fits

def _simulation_result(self, output_var):
if output_var == 'spikes':
fits = get_spikes(self.simulator.monitor,
fits = get_spikes(self.simulator.spikemonitor,
1, self.n_traces)[0] # a single "sample"
else:
fits = getattr(self.simulator.monitor, output_var)[:]

fits = getattr(self.simulator.statemonitor, output_var)[:]
return fits
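A hedged usage sketch of the updated ``generate``: ``fitter`` below stands for a TraceFitter on which ``fit()`` has already been called and whose model defines a variable ``v``; it is not constructed here.

def best_fit_results(fitter):
    """Return the best-fit voltage traces and spike times after fitting.

    Usage sketch only: ``fitter`` is assumed to be a TraceFitter on which
    ``fit()`` has already been called, with a threshold/reset so that
    spikes can be recorded alongside ``v``.
    """
    results = fitter.generate(output_var=['v', 'spikes'])
    return results['v'], results['spikes']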


@@ -734,6 +765,8 @@ def __init__(self, model, input_var, input, output_var, output, dt,
super().__init__(dt, model, input, output, input_var, output_var,
n_samples, threshold, reset, refractory, method,
param_init, use_units=use_units)
self.output = Quantity(output)
self.output_ = array(output)
# We store the bounds set in TraceFitter.fit, so that Tracefitter.refine
# can reuse them
self.bounds = None
@@ -748,7 +781,7 @@ def calc_errors(self, metric):
Returns errors after simulation with StateMonitor.
To be used inside `optim_iter`.
"""
traces = getattr(self.simulator.networks['fit']['monitor'],
traces = getattr(self.simulator.networks['fit']['statemonitor'],
self.output_var+'_')
# Reshape traces for easier calculation of error
traces = reshape(traces, (traces.shape[0]//self.n_traces,
@@ -920,17 +953,17 @@ def _calc_error(params):
self.parameter_names, self.n_traces, 1)
self.simulator.run(self.duration, param_dic,
self.parameter_names, name='refine')
trace = getattr(self.simulator.monitor, self.output_var+'_')
trace = getattr(self.simulator.statemonitor, self.output_var+'_')
if t_weights is None:
residual = trace[:, t_start_steps:] - self.output_[:, t_start_steps:]
else:
residual = (trace - self.output_)*t_weights
residual = (trace - self.output_) * t_weights
return residual.flatten() * normalization
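A tiny NumPy illustration (toy numbers, not fitter output) of the two residual modes above: without ``t_weights`` everything before ``t_start`` is discarded, while with ``t_weights`` the full residual is scaled per time step, which is the temporal weighting this branch introduces:

import numpy as np

trace = np.array([[0.0, 0.5, 1.0, 1.5]])     # simulated output (1 trace, 4 steps)
target = np.array([[0.0, 1.0, 1.0, 1.0]])    # recorded target
t_start_steps = 1
t_weights = np.array([0.0, 1.0, 2.0, 1.0])   # one weight per time step

residual_cut = trace[:, t_start_steps:] - target[:, t_start_steps:]
residual_weighted = (trace - target) * t_weights

print(residual_cut)        # [[-0.5  0.   0.5]]
print(residual_weighted)   # [[ 0.  -0.5  0.   0.5]]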

def _calc_gradient(params):
residuals = []
for name in self.parameter_names:
trace = getattr(self.simulator.monitor,
trace = getattr(self.simulator.statemonitor,
f'S_{self.output_var}_{name}_')
if t_weights is None:
residual = trace[:, t_start_steps:]
@@ -999,10 +1032,11 @@ def __init__(self, model, input, output, dt, reset, threshold,
"""Initialize the fitter."""
if method is None:
method = 'exponential_euler'
super().__init__(dt, model, input, output, input_var, 'v',
super().__init__(dt, model, input, output, input_var, 'spikes',
n_samples, threshold, reset, refractory, method,
param_init, use_units=use_units)
self.output_var = 'spikes'
self.output = [Quantity(o) for o in output]
self.output_ = [array(o) for o in output]

if param_init:
for param, val in param_init.items():
@@ -1018,7 +1052,7 @@ def calc_errors(self, metric):
Returns errors after simulation with SpikeMonitor.
To be used inside optim_iter.
"""
spikes = get_spikes(self.simulator.networks['fit']['monitor'],
spikes = get_spikes(self.simulator.networks['fit']['spikemonitor'],
self.n_samples, self.n_traces)
errors = metric.calc(spikes, self.output, self.dt)
return errors
@@ -1050,6 +1084,9 @@ def __init__(self, model, input_var, input, output_var, output, dt,
n_samples, threshold, reset, refractory, method,
param_init)

self.output = Quantity(output)
self.output_ = array(output)

if output_var not in self.model.names:
raise NameError("%s is not a model variable" % output_var)
if output.shape != input.shape:
54 changes: 33 additions & 21 deletions brian2modelfitting/metric.py
@@ -9,8 +9,8 @@
from itertools import repeat
from brian2 import Hz, second, Quantity, ms, us, get_dimensions
from brian2.units.fundamentalunits import check_units, in_unit, DIMENSIONLESS
from numpy import (array, sum, square, reshape, abs, amin, digitize,
rint, arange, atleast_2d, NaN, float64, split, shape,)
from numpy import (array, sum, abs, amin, digitize, rint, arange, inf, NaN,
clip)


def firing_rate(spikes):
@@ -46,7 +46,7 @@ def get_gamma_factor(model, data, delta, time, dt, rate_correction=True):
-------
float
An error based on the Gamma factor. If ``rate_correction`` is used,
then the returned error is :math:`2\frac{\lvert r_\mathrm{data} - r_\mathrm{model}\rvert}{r_\mathrm{data}} - \Gamma`
then the returned error is :math:`1 + 2\frac{\lvert r_\mathrm{data} - r_\mathrm{model}\rvert}{r_\mathrm{data}} - \Gamma`
(with :math:`r_\mathrm{data}` and :math:`r_\mathrm{model}` being the
firing rates in the data/model, and :math:`\Gamma` the coincidence
factor). Without ``rate_correction``, the error is
@@ -88,11 +88,12 @@ def get_gamma_factor(model, data, delta, time, dt, rate_correction=True):
gamma = (coincidences - NCoincAvg)/(norm*(model_length + data_length))

if rate_correction:
rate_term = 2*abs((data_rate - model_rate)/data_rate)
rate_term = 1 + 2*abs((data_rate - model_rate)/data_rate)
else:
rate_term = 1

return rate_term - gamma
return clip(rate_term - gamma, 0, inf)
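A stand-alone toy version of the returned error (plain NumPy, not the library function): with rate correction the term is 1 + 2*|r_data - r_model|/r_data minus the coincidence factor, and clipping at zero means a perfect match gives exactly 0:

import numpy as np

def gamma_error(gamma, data_rate, model_rate, rate_correction=True):
    # Same arithmetic as the return value above, without units or spike trains.
    if rate_correction:
        rate_term = 1 + 2*abs(data_rate - model_rate)/data_rate
    else:
        rate_term = 1
    return np.clip(rate_term - gamma, 0, np.inf)

print(gamma_error(gamma=1.0, data_rate=10.0, model_rate=10.0))  # 0.0, a perfect match
print(gamma_error(gamma=0.0, data_rate=10.0, model_rate=5.0))   # 2.0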


def calc_eFEL(traces, inp_times, feat_list, dt):
out_traces = []
@@ -442,6 +443,7 @@ def get_dimensions(self, output_dim):
def get_normalized_dimensions(self, output_dim):
return output_dim**2 * get_dimensions(self.normalization)**2


class FeatureMetric(TraceMetric):
def __init__(self, stim_times, feat_list, weights=None, combine=None,
t_start=0*second, normalization=1.):
@@ -551,34 +553,44 @@ class GammaFactor(SpikeMetric):
Calculate gamma factors between goal and calculated spike trains, with
precision delta.
References:
Parameters
----------
delta: `~brian2.units.fundamentalunits.Quantity`
time window
time: `~brian2.units.fundamentalunits.Quantity`
total length of experiment
rate_correction: bool
Whether to include an error term that penalizes differences in firing
rate, following `Clopath et al., Neurocomputing (2007)
<https://doi.org/10.1016/j.neucom.2006.10.047>`_. Defaults to
``True``.
Notes
-----
The gamma factor is commonly defined as 1 for a perfect match and 0 for
a match not better than random (negative values are possible if the match
is *worse* than expected by chance). Since we use the gamma factor as an
error to be minimized, the calculated term is actually r - gamma_factor,
where r is 1 if ``rate_correction`` is ``False``, or a rate-difference
dependent term if ``rate_correction`` is ``True``. In both cases, the best
possible error value (i.e. for a perfect match between spike trains) is 0.
References
----------
* `R. Jolivet et al. “A Benchmark Test for a Quantitative Assessment of
Simple Neuron Models.” Journal of Neuroscience Methods, 169, no. 2 (2008):
417–24. <https://doi.org/10.1016/j.jneumeth.2007.11.006>`_
* `C. Clopath et al. “Predicting Neuronal Activity with Simple Models of the
Threshold Type: Adaptive Exponential Integrate-and-Fire Model with
Two Compartments.” Neurocomputing, 70, no. 10 (2007): 1668–73.
<https://doi.org/10.1016/j.neucom.2006.10.047>`_
"""

@check_units(delta=second, time=second, t_start=0*second)
def __init__(self, delta, time, t_start=0*second, normalization=1.,
rate_correction=True):
"""
Initialize the metric with time window delta and time step dt output
Parameters
----------
delta: `~brian2.units.fundamentalunits.Quantity`
time window
time: `~brian2.units.fundamentalunits.Quantity`
total length of experiment
rate_correciton: bool
Whether to include an error term that penalizes differences in firing
rate, following `Clopath et al., Neurocomputing (2007)
<https://doi.org/10.1016/j.neucom.2006.10.047>`_.
"""
super(GammaFactor, self).__init__(t_start=t_start,
normalization=normalization)
self.delta = delta
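A hedged usage sketch of the constructor documented above (window and duration values are arbitrary); such a metric would typically be passed as the ``metric`` argument of ``SpikeFitter.fit``:

from brian2 import ms, second
from brian2modelfitting.metric import GammaFactor

# 2 ms coincidence window over a 1 s recording; rate_correction is on by
# default and adds the firing-rate penalty described in the Notes above.
metric = GammaFactor(delta=2*ms, time=1*second, rate_correction=True)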
12 changes: 7 additions & 5 deletions brian2modelfitting/simulator.py
@@ -57,7 +57,8 @@ def __init__(self):
self.var_init = None

neurons = property(lambda self: self.networks[self.current_net]['neurons'])
monitor = property(lambda self: self.networks[self.current_net]['monitor'])
statemonitor = property(lambda self: self.networks[self.current_net]['statemonitor'])
spikemonitor = property(lambda self: self.networks[self.current_net]['spikemonitor'])

def initialize(self, network, var_init, name='fit'):
"""
Expand All @@ -67,7 +68,8 @@ def initialize(self, network, var_init, name='fit'):
----------
network: `~brian2.core.network.Network`
Network consisting of a `~brian2.groups.neurongroup.NeuronGroup`
named ``neurons`` and a monitor named ``monitor``.
named ``neurons`` and either a monitor named ``spikemonitor``
or a monitor named ``statemonitor`` (or both).
var_init: dict
dictionary to initialize the variable states
name: `str`, optional
@@ -77,9 +79,9 @@
if 'neurons' not in network:
raise KeyError('Expected a group named "neurons" in the '
'network.')
if 'monitor' not in network:
raise KeyError('Expected a monitor named "monitor" in the '
'network.')
if 'statemonitor' not in network and 'spikemonitor' not in network:
raise KeyError('Expected a monitor named "spikemonitor" or '
'"statemonitor" in the network.')
self.networks[name] = network
self.current_net = None # will be set in run
self.var_init = var_init
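The naming requirement can be illustrated with a minimal, self-contained network; the model and values below are placeholders, only the names ``neurons``, ``statemonitor`` and ``spikemonitor`` matter:

from brian2 import Network, NeuronGroup, SpikeMonitor, StateMonitor, ms

neurons = NeuronGroup(3, 'dv/dt = (1.2 - v)/(10*ms) : 1',
                      threshold='v > 1', reset='v = 0',
                      method='exact', name='neurons')
network = Network(neurons,
                  StateMonitor(neurons, 'v', record=True, name='statemonitor'),
                  SpikeMonitor(neurons, name='spikemonitor'))

# Name-based membership tests like the ones in initialize() above
assert 'neurons' in network
assert 'statemonitor' in network and 'spikemonitor' in network

network.run(20*ms)
print(network['statemonitor'].v[:].shape)      # (3, number of time steps)
print(network['spikemonitor'].spike_trains())  # dict: neuron index -> spike times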
