diff --git a/brian2modelfitting/modelfitting/metric.py b/brian2modelfitting/modelfitting/metric.py index 6f596d00..dd3711d5 100644 --- a/brian2modelfitting/modelfitting/metric.py +++ b/brian2modelfitting/modelfitting/metric.py @@ -1,3 +1,4 @@ +import warnings import abc import efel from itertools import repeat @@ -8,7 +9,7 @@ def firing_rate(spikes): - """Raturns rate of the spike train""" + """Returns rate of the spike train""" if len(spikes) < 2: return NaN return (len(spikes) - 1) / (spikes[-1] - spikes[0]) @@ -16,7 +17,7 @@ def firing_rate(spikes): def get_gamma_factor(model, data, delta, time, dt): """ - Calculate gamma factor between model and tagret spike trains, + Calculate gamma factor between model and target spike trains, with precision delta. Parameters @@ -47,8 +48,7 @@ def get_gamma_factor(model, data, delta, time, dt): data_rate = data_length / time model_rate = model_length / time - - if (model_length > 1): + if model_length > 1: bins = .5 * (model[1:] + model[:-1]) indices = digitize(data, bins) diff = abs(data - model[indices]) @@ -90,26 +90,25 @@ class Metric(metaclass=abc.ABCMeta): Metic abstract class to define functions required for a custom metric To be used with modelfitting Fitters. """ - @abc.abstractmethod - def get_features(self, traces, output, n_traces, dt): - """ - Function calculates features / errors for each of the traces and stores - it in an attribute metric.features. - The output of the function has to take shape of (n_samples, n_traces) - or (n_traces, n_samples). + @check_units(t_start=second) + def __init__(self, t_start=0*second, **kwds): + """ + Initialize the metric. Parameters ---------- - traces: 2D array - traces to be evaluated - output: array - goal traces - n_traces: int - number of input traces - dt: Quantity - time step + t_start: Quantity, optional + Start of time window considered for calculating the fit error. + """ + self.t_start = t_start + @abc.abstractmethod + def get_features(self, model_results, target_results, dt): + """ + Function calculates features / errors for each of the input traces. + + The output of the function has to take shape of (n_samples, n_traces). """ pass @@ -117,7 +116,7 @@ def get_features(self, traces, output, n_traces, dt): def get_errors(self, features, n_traces): """ Function weights features/multiple errors into one final error per each - set of parameters and inputs stored metric.errors. + set of parameters. The output of the function has to take shape of (n_samples,). @@ -131,99 +130,179 @@ def get_errors(self, features, n_traces): """ pass - def calc(self, traces, output, n_traces, dt): + +class TraceMetric(Metric): + """ + Input traces have to be shaped into 2D array. + """ + + def calc(self, model_traces, data_traces, dt): """ Perform the error calculation across all parameters, calculate error between each output trace and corresponding - simulation. You can also access metric.features, metric.errors. + simulation. Parameters ---------- - traces: 2D array - traces to be evaluated - output: array - goal traces - n_traces: int - number of input traces + model_traces: ndarray + Traces that should be evaluated and compared to the target data. + Provided as an `.ndarray` of shape (samples, traces, time steps), + where "samples" are the different parameter values that have been + evaluated, and "traces" are the responses of the model to the + different input stimuli. + data_traces: array + The target traces to which the model should be compared. An + `ndarray` of shape (traces, time steps). dt: Quantity - time step + The length of a single time step. Returns ------- - errors: array - weigheted/mean error for each set of parameters + errors: ndarray + Total error for each set of parameters. """ - features = self.get_features(traces, output, n_traces, dt) + start_steps = int(round(self.t_start/dt)) + features = self.get_features(model_traces[:, :, start_steps:], + data_traces[:, start_steps:], + dt) errors = self.get_errors(features) return errors + @abc.abstractmethod + def get_features(self, model_traces, data_traces, dt): + """ + Calculate the features/errors for each simulated trace, by comparing + it to the corresponding data trace. -class TraceMetric(Metric): - """ - Input traces have to be shaped into 2D array. - """ - pass + Parameters + ---------- + model_traces: ndarray + Traces that should be evaluated and compared to the target data. + Provided as an :py:class:`~numpy.ndarray` of shape + ``(n_samples, n_traces, time steps)``, + where ``n_samples`` are the number of different parameter sets that + have been evaluated, and ``n_traces`` are the number of input + stimuli. + data_traces: array + The target traces to which the model should be compared. An + :py:class:`~numpy.ndarray` of shape ``(n_traces, time steps)``. + dt: Quantity + The length of a single time step. + + Returns + ------- + features: ndarray + An :py:class:`~numpy.ndarray` of shape ``(n_samples, n_traces)`` + returning the error/feature value for each simulated trace. + """ + pass class SpikeMetric(Metric): """ - Output spikes contain a list of arrays (possibly of different lengths) - in order to allow different lengths of spike trains. - Example: [array([1, 2, 3]), array([1, 2])] + A metric for comparing the spike trains. """ - pass - - -class MSEMetric(TraceMetric): - __doc__ = "Mean Square Error between goal and calculated output." + \ - Metric.get_features.__doc__ - @check_units(t_start=second) - def __init__(self, t_start=None, **kwds): + def calc(self, model_spikes, data_spikes, dt): """ - Initialize the metric. + Perform the error calculation across all parameters, + calculate error between each output trace and corresponding + simulation. Parameters ---------- - t_start: beggining of time window (Quantity) (optional) + model_spikes: list of list of arrays + A nested list structure for the spikes generated by the model: a + list where each element contains the results for a single parameter + set. Each of these results is a list for each of the input traces, + where the elements of this list are numpy arrays of spike times + (without units, i.e. in seconds). + data_spikes: list of arrays + The target spikes for the fitting, represented in the same way as + ``model_spikes``, i.e. as a list of spike times for each input + stimulus. + dt: Quantity + The length of a single time step. + + Returns + ------- + errors: ndarray + Total error for each set of parameters. + """ - self.t_start = t_start + if self.t_start > 0*second: + relevant_data_spikes = [] + for one_stim in data_spikes: + relevant_data_spikes.append(one_stim[one_stim>float(self.t_start)]) + relevant_model_spikes = [] + for one_sample in model_spikes: + sample_spikes = [] + for one_stim in one_sample: + sample_spikes.append(one_stim[one_stim>float(self.t_start)]) + relevant_model_spikes.append(sample_spikes) + model_spikes = relevant_model_spikes + data_spikes = relevant_data_spikes + features = self.get_features(model_spikes, data_spikes, dt) + errors = self.get_errors(features) - def get_features(self, traces, output, n_traces, dt): - mselist = [] - output = atleast_2d(output) + return errors - if self.t_start is not None: - if not isinstance(dt, Quantity): - raise TypeError("To specify time window you need to also " - "specify dt as Quantity") - t_start = int(self.t_start/dt) - output = output[:, t_start:-1] - traces = traces[:, t_start:-1] + @abc.abstractmethod + def get_features(self, model_spikes, data_spikes, dt): + """ + Calculate the features/errors for each simulated spike train by + comparing it to the corresponding data spike train. + + Parameters + ---------- + model_spikes: list of list of arrays + A nested list structure for the spikes generated by the model: a + list where each element contains the results for a single parameter + set. Each of these results is a list for each of the input traces, + where the elements of this list are numpy arrays of spike times + (without units, i.e. in seconds). + data_spikes: list of arrays + The target spikes for the fitting, represented in the same way as + ``model_spikes``, i.e. as a list of spike times for each input + stimulus. + dt: Quantity + The length of a single time step. - for i in arange(n_traces): - temp_out = output[i] - temp_traces = traces[i::n_traces] + Returns + ------- + features: ndarray + An :py:class:`~numpy.ndarray` of shape ``(n_samples, n_traces)`` + returning the error/feature value for each simulated trace. + """ - for trace in temp_traces: - mse = sum(square(temp_out - trace)) - mselist.append(mse) - feat_arr = reshape(array(mselist), (n_traces, - int(len(mselist)/n_traces))) +class MSEMetric(TraceMetric): + """ + Mean Square Error between goal and calculated output. + """ - return feat_arr + def get_features(self, model_traces, data_traces, dt): + return sum((model_traces - data_traces)**2, axis=2) def get_errors(self, features): - errors = features.mean(axis=0) - return errors + return features.mean(axis=1) class FeatureMetric(TraceMetric): - def __init__(self, traces_times, feat_list, weights=None, combine=None): + def __init__(self, traces_times, feat_list, weights=None, combine=None, + t_start=0*second): + super(FeatureMetric, self).__init__(t_start=t_start) self.traces_times = traces_times + if isinstance(self.traces_times[0][0], Quantity): + for n, trace in enumerate(self.traces_times): + t_start, t_end = trace[0], trace[1] + t_start = t_start / ms + t_end = t_end / ms + self.traces_times[n] = [t_start, t_end] + n_times = shape(self.traces_times)[0] + self.feat_list = feat_list if combine is None: @@ -246,7 +325,7 @@ def check_values(self, feat_list): for k, v in r.items(): if v is None: r[k] = array([99]) - raise Warning('None for key:{}'.format(k)) + warnings.warn('None for key:{}'.format(k)) if (len(r[k])) > 1: raise ValueError("you can only use features that return " "one value") @@ -264,36 +343,21 @@ def feat_to_err(self, d1, d2): return err - def get_features(self, traces, output, n_traces, dt): - if isinstance(self.traces_times[0][0], Quantity): - for n, trace in enumerate(self.traces_times): - t_start, t_end = trace[0], trace[1] - t_start = t_start / ms - t_end = t_end / ms - self.traces_times[n] = [t_start, t_end] - - n_times = shape(self.traces_times)[0] - - if (n_times != (n_traces)): - if (n_times == 1): + def get_features(self, traces, output, dt): + n_samples, n_traces, _ = traces.shape + if len(self.traces_times) != n_traces: + if len(self.traces_times) == 1: self.traces_times = list(repeat(self.traces_times[0], n_traces)) else: raise ValueError("Specify the traces_times variable of appropiate " "size (same as number of traces or 1).") - unit = output.get_best_unit() - output = output/unit - traces = traces/unit self.out_feat = calc_eFEL(output, self.traces_times, self.feat_list, dt) self.check_values(self.out_feat) - sl = int(shape(traces)[0]/n_traces) features = [] - temp_traces = split(traces, sl) - - for ii in arange(sl): - temp_trace = temp_traces[ii] - temp_feat = calc_eFEL(temp_trace, self.traces_times, + for one_sample in traces: + temp_feat = calc_eFEL(one_sample, self.traces_times, self.feat_list, dt) self.check_values(temp_feat) features.append(temp_feat) @@ -315,7 +379,7 @@ def get_errors(self, features): class GammaFactor(SpikeMetric): - __doc__ = """ + """ Calculate gamma factors between goal and calculated spike trains, with precision delta. @@ -323,10 +387,10 @@ class GammaFactor(SpikeMetric): R. Jolivet et al., 'A benchmark test for a quantitative assessment of simple neuron models', Journal of Neuroscience Methods 169, no. 2 (2008): 417-424. - """ + Metric.get_features.__doc__ + """ - @check_units(delta=second, time=second) - def __init__(self, delta, time): + @check_units(delta=second, time=second, t_start=0*second) + def __init__(self, delta, time, t_start=0*second): """ Initialize the metric with time window delta and time step dt output @@ -335,28 +399,21 @@ def __init__(self, delta, time): delta: time window (Quantity) time: total lenght of experiment (Quantity) """ + super(GammaFactor, self).__init__(t_start=t_start) self.delta = delta self.time = time - def get_features(self, traces, output, n_traces, dt): - gamma_factors = [] - if type(output[0]) == float64: - output = atleast_2d(output) - - for i in arange(n_traces): - temp_out = output[i] - temp_traces = traces[i::n_traces] - - for trace in temp_traces: - gf = get_gamma_factor(trace, temp_out, self.delta, self.time, dt) - # gamma_factors.append(abs(1 - gf)) - gamma_factors.append(gf) - - feat_arr = reshape(array(gamma_factors), (n_traces, - int(len(gamma_factors)/n_traces))) - - return feat_arr + def get_features(self, traces, output, dt): + all_gf = [] + for one_sample in traces: + gf_for_sample = [] + for model_response, target_response in zip(one_sample, output): + gf = get_gamma_factor(model_response, target_response, + self.delta, self.time, dt) + gf_for_sample.append(gf) + all_gf.append(gf_for_sample) + return array(all_gf) def get_errors(self, features): - errors = features.mean(axis=0) + errors = features.mean(axis=1) return errors diff --git a/brian2modelfitting/modelfitting/modelfitting.py b/brian2modelfitting/modelfitting/modelfitting.py index fc248ea2..da768e65 100644 --- a/brian2modelfitting/modelfitting/modelfitting.py +++ b/brian2modelfitting/modelfitting/modelfitting.py @@ -1,5 +1,5 @@ import abc -from numpy import ones, array, arange, concatenate, mean, nanmin +from numpy import ones, array, arange, concatenate, mean, nanmin, reshape from brian2 import (NeuronGroup, defaultclock, get_device, Network, StateMonitor, SpikeMonitor, ms, device, second, get_local_namespace, Quantity) @@ -22,18 +22,21 @@ def get_param_dic(params, param_names, n_traces, n_samples): return d -def get_spikes(monitor): +def get_spikes(monitor, n_samples, n_traces): """ Get spikes from spike monitor change format from dict to a list, remove units. """ spike_trains = monitor.spike_trains() - + assert len(spike_trains) == n_samples*n_traces spikes = [] - for i in arange(len(spike_trains)): - spike_list = spike_trains[i] / ms - spikes.append(spike_list) - + i = -1 + for sample in range(n_samples): + sample_spikes = [] + for trace in range(n_traces): + i += 1 + sample_spikes.append(array(spike_trains[i], copy=False)) + spikes.append(sample_spikes) return spikes @@ -129,7 +132,7 @@ def __init__(self, dt, model, input, output, input_var, output_var, self.refractory = refractory self.input = input - self.output = output + self.output = array(output) self.output_var = output_var self.model = model @@ -387,7 +390,8 @@ def generate(self, params=None, output_var=None, param_init=None, level=0): name='neurons_') if output_var == 'spikes': - fits = get_spikes(self.simulator.network['monitor_']) + fits = get_spikes(self.simulator.network['monitor_'], + 1, self.n_traces)[0] # a single "sample" else: fits = getattr(self.simulator.network['monitor_'], output_var) @@ -440,8 +444,13 @@ def calc_errors(self, metric): Returns errors after simulation with StateMonitor. To be used inside optim_iter. """ - traces = getattr(self.simulator.network['monitor'], self.output_var) - errors = metric.calc(traces, self.output, self.n_traces, self.dt) + traces = getattr(self.simulator.network['monitor'], + self.output_var+'_') + # Reshape traces for easier calculation of error + traces = reshape(traces, (traces.shape[0]//self.n_traces, + self.n_traces, + -1)) + errors = metric.calc(traces, self.output, self.dt) return errors def fit(self, optimizer, metric=None, n_rounds=1, callback='text', @@ -500,8 +509,9 @@ def calc_errors(self, metric): Returns errors after simulation with SpikeMonitor. To be used inside optim_iter. """ - spikes = get_spikes(self.simulator.network['monitor']) - errors = metric.calc(spikes, self.output, self.n_traces, self.dt) + spikes = get_spikes(self.simulator.network['monitor'], + self.n_samples, self.n_traces) + errors = metric.calc(spikes, self.output, self.dt) return errors def fit(self, optimizer, metric=None, n_rounds=1, callback='text', diff --git a/brian2modelfitting/tests/test_metric.py b/brian2modelfitting/tests/test_metric.py index 31914ebe..468a66bc 100644 --- a/brian2modelfitting/tests/test_metric.py +++ b/brian2modelfitting/tests/test_metric.py @@ -37,99 +37,87 @@ def test_init(): def test_calc_mse(): mse = MSEMetric() out = np.random.rand(2, 20) - inp = np.random.rand(10, 20) + inp = np.random.rand(5, 2, 20) - errors = mse.calc(inp, out, 2, 0.01*ms) + errors = mse.calc(inp, out, 0.01*ms) assert_equal(np.shape(errors), (5,)) - assert_equal(mse.calc(out, out, 2, 0.1*ms), [0.]) - assert(np.all(mse.calc(inp, out, 2, 0.1*ms) > 0)) + assert_equal(mse.calc(np.tile(out, (5, 1, 1)), out, 0.1*ms), + np.zeros(5)) + assert(np.all(mse.calc(inp, out, 0.1*ms) > 0)) def test_calc_mse_t_start(): mse = MSEMetric(t_start=1*ms) - out = np.random.rand(2, 200) - inp = np.random.rand(10, 200) + out = np.random.rand(2, 20) + inp = np.random.rand(5, 2, 20) - errors = mse.calc(inp, out, 2, 0.1*ms) + errors = mse.calc(inp, out, 0.1*ms) assert_equal(np.shape(errors), (5,)) - assert_equal(mse.calc(out, out, 2, 0.1*ms), [0.]) - assert(np.all(mse.calc(inp, out, 2, 0.1*ms) > 0)) - + assert(np.all(errors > 0)) + # Everything before 1ms should be ignored, so having the same values for + # the rest should give an error of 0 + inp[:, :, 10:] = out[None, :, 10:] + assert_equal(mse.calc(inp, out, 0.1*ms), np.zeros(5)) def test_calc_gf(): assert_raises(TypeError, GammaFactor) assert_raises(DimensionMismatchError, GammaFactor, delta=10) assert_raises(DimensionMismatchError, GammaFactor, time=10) - inp_gf = np.round(np.sort(np.random.rand(10, 5) * 10), 2) + inp_gf = np.round(np.sort(np.random.rand(5, 2, 5) * 10), 2) out_gf = np.round(np.sort(np.random.rand(2, 5) * 10), 2) gf = GammaFactor(delta=10*ms, time=10*ms) - errors = gf.calc(inp_gf, out_gf, 2, 0.1*ms) + errors = gf.calc(inp_gf, out_gf, 0.1*ms) assert_equal(np.shape(errors), (5,)) - assert(np.all(errors > 0)) - errors = gf.calc(out_gf, out_gf, 2, 0.1*ms) - assert_almost_equal(errors, [2.]) + assert(all(errors > 0)) + errors = gf.calc([out_gf]*5, out_gf, 0.1*ms) + assert_almost_equal(errors, np.ones(5)*2) def test_get_features_mse(): mse = MSEMetric() out_mse = np.random.rand(2, 20) - inp_mse = np.random.rand(6, 20) - - features = mse.get_features(inp_mse, out_mse, 2, 0.1*ms) - assert_equal(np.shape(features), (2, 3)) - assert(np.all(np.array(features) > 0)) - - features = mse.get_features(out_mse, out_mse, 2, 0.1*ms) - assert_equal(np.shape(features), (2, 1)) - assert_equal(features, [[0.], [0.]]) - - -def test_get_features_mse_t_start(): - mse = MSEMetric(t_start=1*ms) - out_mse = np.random.rand(2, 200) - inp_mse = np.random.rand(6, 200) + inp_mse = np.random.rand(5, 2, 20) - features = mse.get_features(inp_mse, out_mse, 2, 0.1*ms) - assert_equal(np.shape(features), (2, 3)) + features = mse.get_features(inp_mse, out_mse, 0.1*ms) + assert_equal(np.shape(features), (5, 2)) assert(np.all(np.array(features) > 0)) - features = mse.get_features(out_mse, out_mse, 2, 0.1*ms) - assert_equal(np.shape(features), (2, 1)) - assert_equal(features, [[0.], [0.]]) + features = mse.get_features(np.tile(out_mse, (5, 1, 1)), out_mse, 0.1*ms) + assert_equal(np.shape(features), (5, 2)) + assert_equal(features, np.zeros((5, 2))) def test_get_errors_mse(): mse = MSEMetric() - errors = mse.get_errors(np.random.rand(10, 5)) - print(errors) + errors = mse.get_errors(np.random.rand(5, 10)) assert_equal(np.shape(errors), (5,)) assert(np.all(np.array(errors) > 0)) - errors = mse.get_errors(np.zeros((10, 2))) + errors = mse.get_errors(np.zeros((2, 10))) assert_equal(np.shape(errors), (2,)) assert_equal(errors, [0., 0.]) def test_get_features_gamma(): - inp_gf = np.round(np.sort(np.random.rand(6, 5) * 10), 2) + inp_gf = np.round(np.sort(np.random.rand(3, 2, 5) * 10), 2) out_gf = np.round(np.sort(np.random.rand(2, 5) * 10), 2) gf = GammaFactor(delta=10*ms, time=10*ms) - features = gf.get_features(inp_gf, out_gf, 2, 0.1*ms) - assert_equal(np.shape(features), (2, 3)) + features = gf.get_features(inp_gf, out_gf, 0.1*ms) + assert_equal(np.shape(features), (3, 2)) assert(np.all(np.array(features) > 0)) - features = gf.get_features(out_gf, out_gf, 2, 0.1*ms) - assert_equal(np.shape(features), (2, 1)) - assert_almost_equal(features, [[2.], [2.]]) + features = gf.get_features([out_gf]*3, out_gf, 0.1*ms) + assert_equal(np.shape(features), (3, 2)) + assert_almost_equal(features, np.ones((3, 2))*2) def test_get_errors_gamma(): gf = GammaFactor(delta=10*ms, time=10*ms) - errors = gf.get_errors(np.random.rand(10, 5)) + errors = gf.get_errors(np.random.rand(5, 10)) assert_equal(np.shape(errors), (5,)) assert(np.all(np.array(errors) > 0)) - errors = gf.get_errors(np.zeros((10, 2))) + errors = gf.get_errors(np.zeros((2, 10))) assert_equal(np.shape(errors), (2,)) assert_almost_equal(errors, [0., 0.]) diff --git a/brian2modelfitting/tests/test_modelfitting_spikefitter.py b/brian2modelfitting/tests/test_modelfitting_spikefitter.py index 6b706f90..18968cf7 100644 --- a/brian2modelfitting/tests/test_modelfitting_spikefitter.py +++ b/brian2modelfitting/tests/test_modelfitting_spikefitter.py @@ -66,17 +66,17 @@ def fin(): group.v = -70 * mV spike_mon = SpikeMonitor(group) run(60*ms) - spikes = getattr(spike_mon, 't') / ms + spikes = getattr(spike_mon, 't_') return spike_mon, spikes def test_get_spikes(setup_spikes): spike_mon, spikes = setup_spikes - gs = get_spikes(spike_mon) + gs = get_spikes(spike_mon, 1, 1) assert isinstance(gs, list) - assert isinstance(gs[0], np.ndarray) - assert_equal(gs, [np.array(spikes)]) + assert isinstance(gs[0][0], np.ndarray) + assert_equal(gs, [[np.array(spikes)]]) def test_spikefitter_init(setup): diff --git a/brian2modelfitting/tests/test_utils.py b/brian2modelfitting/tests/test_utils.py index 3130214f..6a180837 100644 --- a/brian2modelfitting/tests/test_utils.py +++ b/brian2modelfitting/tests/test_utils.py @@ -24,7 +24,7 @@ def test_callback_none(): def test_ProgressBar(): pb = ProgressBar(toolbar_width=10) assert_equal(pb.toolbar_width, 10) - assert isinstance(pb.t, tqdm._tqdm.tqdm) + assert isinstance(pb.t, tqdm.tqdm) pb([1, 2, 3], [1.2, 2.3, 0.1], {'a':3}, 0.1, 2) diff --git a/docs_sphinx/conf.py b/docs_sphinx/conf.py index 0350c3a9..f2d9b8bb 100644 --- a/docs_sphinx/conf.py +++ b/docs_sphinx/conf.py @@ -14,7 +14,7 @@ import sys sys.path.insert(0, os.path.abspath('..')) -needs_sphinx = '1.7' +needs_sphinx = '2.0' brian2modelfitting_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), @@ -57,6 +57,9 @@ # The master toctree document. master_doc = 'index' +# autodoc configuration +autodoc_default_options = {'inherited-members': True} + # -- Options for HTML output ------------------------------------------------- # on_rtd is whether we are on readthedocs.org, this line of code grabbed from docs.readthedocs.org on_rtd = os.environ.get('READTHEDOCS', None) == 'True' diff --git a/docs_sphinx/metric/index.rst b/docs_sphinx/metric/index.rst index 740f1830..96fe1620 100644 --- a/docs_sphinx/metric/index.rst +++ b/docs_sphinx/metric/index.rst @@ -1,9 +1,10 @@ Metric ====== -Metric input to specifies the fitness function measuring the performance of the simulation. -This function gets applied on each simulated trace. We have implemented few metrics within -modelfitting. +A *Metric* specifies the fitness function measuring the performance of the +simulation. This function gets applied on each simulated trace. A few metrics +are already implemented and included in the toolbox, but the user can also +provide their own metric. .. contents:: :local: @@ -13,38 +14,44 @@ modelfitting. Mean Square Error ----------------- -:py:class:`~brian2modelfitting.modelfitting.metric.MSEMetric` is implemented to use with :py:class:`~brian2modelfitting.modelfitting.modelfitting.TraceFitter`. Calculated according to well known formula: +:py:class:`~brian2modelfitting.modelfitting.metric.MSEMetric` is provided for +use with :py:class:`~brian2modelfitting.modelfitting.modelfitting.TraceFitter`. +It calculates the mean squared difference between the data and the simulated +trace according to the well known formula: .. math:: MSE ={\frac {1}{n}}\sum _{i=1}^{n}(Y_{i}-{\hat {Y_{i}}})^{2} -To be called in a following way: +It can be initialized in the following way: .. code:: python metric = MSEMetric() -Additionally, :py:class:`~brian2modelfitting.modelfitting.metric.MSEMetric` accepts two optional input arguments -start time ``t_start``. Time steps gets passed from the fitter. The following have to always be provided together and have units -(be a :py:class:`~brian2.units.fundamentalunits.Quantity`). The start time allows the user to measure the error starting -from the provided time (i.e. start of stimulation). +Additionally, :py:class:`~brian2modelfitting.modelfitting.metric.MSEMetric` +accepts an optional input argument start time ``t_start`` (as a +:py:class:`~brian2.units.fundamentalunits.Quantity`). The start time allows the +user to ignore an initial period that will not be included in the error +calculation. .. code:: python metric = MSEMetric(t_start=5*ms) -In :py:class:`~brian2modelfitting.modelfitting.modelfitting.OnlineTraceFitter` mean square error gets calculated in online manner, -with no need of specifying a metric object. +In :py:class:`~brian2modelfitting.modelfitting.modelfitting.OnlineTraceFitter`, +the mean square error gets calculated in online manner, with no need of +specifying a metric object. GammaFactor ----------- +:py:class:`~brian2modelfitting.modelfitting.metric.GammaFactor` is provided for +use with :py:class:`~brian2modelfitting.modelfitting.modelfitting.SpikeFitter` +and measures the coincidence between spike times in the simulated and the target +trace. It is calculcated according to: -:py:class:`~brian2modelfitting.modelfitting.metric.GammaFactor` is implemented to use with :py:class:`~brian2modelfitting.modelfitting.modelfitting.SpikeFitter`. Calculated according to: - - -.. math:: \Gamma = \left (\frac{2}{1-2\delta r_{exp}}\right) \left(\frac{N_{coinc} - 2\delta N_{exp}r_{exp}}{N_{exp} + N_{model}}\right) +.. math:: \Gamma = \left (\frac{2}{1-2\Delta r_{exp}}\right) \left(\frac{N_{coinc} - 2\delta N_{exp}r_{exp}}{N_{exp} + N_{model}}\right) :math:`N_{coinc}` - number of coincidences @@ -52,27 +59,34 @@ GammaFactor :math:`r_{exp}` - average firing rate in experimental train -:math:`2 \delta N_{exp}r_{exp}` - expected number of coincidences with a Poission process +:math:`2 \Delta N_{exp}r_{exp}` - expected number of coincidences with a Poission process For more details on the gamma factor, see -Jolivet et al. 2008, “A benchmark test for a quantitative assessment of simple neuron models”, J. Neurosci. Methods. -(https://www.ncbi.nlm.nih.gov/pubmed/18160135) +`Jolivet et al. 2008, “A benchmark test for a quantitative assessment of simple +neuron models”, J. Neurosci. Methods. +`_ -Upon initialization user has to specify the delta as a :py:class:`~brian2.units.fundamentalunits.Quantity`: +Upon initialization the user has to specify the :math:`\Delta` value, defining +the maximal tolerance for spikes to be considered coincident: .. code:: python metric = GammaFactor(delta=10*ms) +.. warning:: + The ``delta`` parameter has to be smaller than the smallest inter-spike + interval in the spike trains. FeatureMetric ------------- -:py:class:`~brian2modelfitting.modelfitting.metric.FeatureMetric` is implemented to use with :py:class:`~brian2modelfitting.modelfitting.modelfitting.TraceFitter`. -Metric demonstrates a use of feature based metric in the toolbox. Features used for optimization get calculated with use of - -The Electrophys Feature Extract Library (eFEL) library, for which the documentation is available under following link: https://efel.readthedocs.io/en/latest/ +:py:class:`~brian2modelfitting.modelfitting.metric.FeatureMetric` is provided +for use with :py:class:`~brian2modelfitting.modelfitting.modelfitting.TraceFitter`. +This metric allows the user to optimize the match of certain features between +the simulated and the target trace. The features get calculated by Electrophys +Feature Extract Library (eFEL) library, for which the documentation is +available under following link: https://efel.readthedocs.io -To get all of the eFEL features you can run the following code: +To get a list of all the available eFEL features, you can run the following code: .. code:: python @@ -82,73 +96,111 @@ To get all of the eFEL features you can run the following code: .. note:: - User is only allowed to use features that return array of no more than one value. + Currently, only features that are described by a single value are supported + (e.g. the time of the first spike can be used, but not the times of all + spikes). -To define the :py:class:`~brian2modelfitting.modelfitting.metric.FeatureMetric`, user has to define following input parameters: +To use the :py:class:`~brian2modelfitting.modelfitting.metric.FeatureMetric`, +you have to provide the following input parameters: -- ``traces_times`` - list of times indicating start and end of input current, has to be specified for each of input traces, each value has to be a :py:class:`~brian2.units.fundamentalunits.Quantity` +- ``traces_times`` - a list of times indicating start and end of the stimulus + for each of input traces. This information is used by several features, e.g. + the ``voltage_base`` feature will consider the average membrane potential + during the last 10% of time before the stimulus (see the + `eFel documentation `_ + for details). - ``feat_list`` - list of strings with names of features to be used -- ``combine`` - function to be used to compare features between output and simulated traces, (for `combine=None`, subtracts the features) +- ``combine`` - function to be used to compare features between output and + simulated traces, (for ``combine=None``, subtracts the feature values) Example code usage: .. code:: python - traces_times = [[50*ms, 100*ms], [50*ms, 100*ms], [50*ms, 100*ms], [50, 100*ms]] + traces_times = [(50*ms, 100*ms), (50*ms, 100*ms), (50*ms, 100*ms), (50, 100*ms)] feat_list = ['voltage_base', 'time_to_first_spike', 'Spikecount'] metric = FeatureMetric(traces_times, feat_list, combine=None) .. note:: - If times of stimulation are same for all of the traces, user can specify a single list that will be replicated for - ``eFEL`` library: ``traces_times = [[50*ms, 100*ms]]``. - - - + If times of stimulation are the same for all of the traces, then you can + specify a single interval instead: ``traces_times = [(50*ms, 100*ms)]``. Custom Metric ------------- -User is not limited to the provided in the module metrics. Modularity applies -here as well, with one of the two provided abstract classes :py:class:`~brian2modelfitting.modelfitting.metric.TraceMetric` -and :py:class:`~brian2modelfitting.modelfitting.metric.SpikeMetric` prepared for different custom made metrics. +Users are not limited to the metrics provided in the toolbox. If needed, they +can provide their own metric based on one of the abstract classes +:py:class:`~brian2modelfitting.modelfitting.metric.TraceMetric` +and :py:class:`~brian2modelfitting.modelfitting.metric.SpikeMetric`. -New metric will need to have specify following functions: +A new metric will need to specify the following functions: - :py:func:`~brian2modelfitting.modelfitting.metric.Metric.get_features()` - calculates features / errors for each of the traces and stores it in a :py:attr:`~brian2modelfitting.modelfitting.metric.Metric.metric.features` attribute - The output of the function has to take shape of (n_samples, n_traces) or (n_traces, n_samples). + calculates features / errors for each of the simulations. The representation + of the model results and the target data depend on whether traces or spikes + are fitted, see below. - :py:func:`~brian2modelfitting.modelfitting.metric.Metric.get_errors()` - weights features/multiple errors into one final error per each set of parameters and inputs stored in :py:attr:`~brian2modelfitting.modelfitting.metric.Metric.metric.errors` - The output of the function has to take shape of (n_samples,). + weights features/multiple errors into one final error per each set of + parameters and inputs. The features are received as a 2-dimensional + :py:class:`~numpy.ndarray` of shape ``(n_samples, n_traces)`` The output has + to be an array of length ``n_samples``, i.e. one value for each parameter + set. - :py:func:`~brian2modelfitting.modelfitting.metric.Metric.calc()` - performs the error calculation across simulation for all parameters of each round. Specified in the abstract class, can be reused. - + performs the error calculation across simulation for all parameters of each + round. Already implemented in the abstract class and therefore does not + need to be reimplemented necessarily. TraceMetric ~~~~~~~~~~~ -To create a new metric for :py:class:`~brian2modelfitting.modelfitting.modelfitting.TraceFitter`, you have to inherit from :py:class:`~brian2modelfitting.modelfitting.metric.TraceMetric`. -Input and output traces have to be shaped into 2D array. +To create a new metric for +:py:class:`~brian2modelfitting.modelfitting.modelfitting.TraceFitter`, you have +to inherit from :py:class:`~brian2modelfitting.modelfitting.metric.TraceMetric` +and overwrite the :py:meth:`~.TraceMetric.get_features` and/or +:py:meth:`~.TraceMetric.get_errors` method. The model traces for the +:py:meth:`~.TraceMetric.get_features` function are provided as a 3-dimensional +:py:class:`~numpy.ndarray` of shape ``(n_samples, n_traces, time steps)``, +where ``n_samples`` are the number of different parameter sets that have been +evaluated, and ``n_traces`` the number of different stimuli that have been +evaluated for each parameter set. The output of the function has to take the +shape of ``(n_samples, n_traces)``. This array is the input to the +:py:meth:`~.TraceMetric.get_errors` method (see above). .. code:: python class NewTraceMetric(TraceMetric): - def get_features(): + def get_features(self, model_traces, data_traces, dt): ... - def get_errors(): + def get_errors(self, features): ... SpikeMetric ~~~~~~~~~~~ -To create a new metric for :py:class:`~brian2modelfitting.modelfitting.modelfitting.SpikeFitter`., you have to inherit from :py:class:`~brian2modelfitting.modelfitting.metric.SpikeMetric`. -Inputs of the metric have to be 2D array. -Output spikes contain a list of arrays (possibly of different lengths) in order -to allow different lengths of spike trains. - -.. code:: python - - [array([1, 2, 3]), array([1, 2])] +To create a new metric for +:py:class:`~brian2modelfitting.modelfitting.modelfitting.SpikeFitter`, you have +to inherit from :py:class:`~brian2modelfitting.modelfitting.metric.SpikeMetric`. +Inputs of the metric in :py:meth:`~.SpikeMetric.get_features` are a nested list +structure for the spikes generated by the model: a list where each element +contains the results for a single parameter set. Each of these results is a list +for each of the input traces, where the elements of this list are numpy arrays +of spike times (without units, i.e. in seconds). For example, if two parameters +sets and 3 different input stimuli were tested, this structure could look like +this:: + + [ + [array([0.01, 0.5]), array([]), array([])], + [array([0.02]), array([]), array([])] + ] + +This means that the both parameter sets only generate spikes for the first input +stimulus, but the first parameter sets generates two while the second generates +only a single one. + +The target spikes are represented in the same way as a list of spike times for +each input stimulus. The results of the function have to be returned as in +:py:class:`~.TraceMetric`, i.e. as a 2-d array of shape +``(n_samples, n_traces)``. \ No newline at end of file diff --git a/examples/IF_spikefitter.py b/examples/IF_spikefitter.py index 8e21e4c1..0f1b716f 100644 --- a/examples/IF_spikefitter.py +++ b/examples/IF_spikefitter.py @@ -51,7 +51,7 @@ # pass parameters to the NeuronGroup fitter = SpikeFitter(model=eqs_fit, input_var='I', dt=dt, - input=inp_trace * amp, output=out_spikes, + input=inp_trace * amp, output=[out_spikes], n_samples=30, threshold='v > -50*mV', param_init={'v': -70*mV},