Skip to content

Commit

Permalink
Make Gamma factor calculation more consistent
Browse files Browse the repository at this point in the history
  • Loading branch information
mstimberg committed Mar 23, 2020
1 parent 8ccb68e commit 31ea0be
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 30 deletions.
54 changes: 33 additions & 21 deletions brian2modelfitting/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from itertools import repeat
from brian2 import Hz, second, Quantity, ms, us, get_dimensions
from brian2.units.fundamentalunits import check_units, in_unit, DIMENSIONLESS
from numpy import (array, sum, square, reshape, abs, amin, digitize,
rint, arange, atleast_2d, NaN, float64, split, shape,)
from numpy import (array, sum, abs, amin, digitize, rint, arange, inf, NaN,
clip)


def firing_rate(spikes):
Expand Down Expand Up @@ -46,7 +46,7 @@ def get_gamma_factor(model, data, delta, time, dt, rate_correction=True):
-------
float
An error based on the Gamma factor. If ``rate_correction`` is used,
then the returned error is :math:`2\frac{\lvert r_\mathrm{data} - r_\mathrm{model}\rvert}{r_\mathrm{data}} - \Gamma`
then the returned error is :math:`1 + 2\frac{\lvert r_\mathrm{data} - r_\mathrm{model}\rvert}{r_\mathrm{data}} - \Gamma`
(with :math:`r_\mathrm{data}` and :math:`r_\mathrm{model}` being the
firing rates in the data/model, and :math:`\Gamma` the coincidence
factor). Without ``rate_correction``, the error is
Expand Down Expand Up @@ -88,11 +88,12 @@ def get_gamma_factor(model, data, delta, time, dt, rate_correction=True):
gamma = (coincidences - NCoincAvg)/(norm*(model_length + data_length))

if rate_correction:
rate_term = 2*abs((data_rate - model_rate)/data_rate)
rate_term = 1 + 2*abs((data_rate - model_rate)/data_rate)
else:
rate_term = 1

return rate_term - gamma
return clip(rate_term - gamma, 0, inf)


def calc_eFEL(traces, inp_times, feat_list, dt):
out_traces = []
Expand Down Expand Up @@ -401,6 +402,7 @@ def get_dimensions(self, output_dim):
def get_normalized_dimensions(self, output_dim):
return output_dim**2 * get_dimensions(self.normalization)**2


class FeatureMetric(TraceMetric):
def __init__(self, stim_times, feat_list, weights=None, combine=None,
t_start=0*second, normalization=1.):
Expand Down Expand Up @@ -510,34 +512,44 @@ class GammaFactor(SpikeMetric):
Calculate gamma factors between goal and calculated spike trains, with
precision delta.
References:
Parameters
----------
delta: `~brian2.units.fundamentalunits.Quantity`
time window
time: `~brian2.units.fundamentalunits.Quantity`
total length of experiment
rate_correction: bool
Whether to include an error term that penalizes differences in firing
rate, following `Clopath et al., Neurocomputing (2007)
<https://doi.org/10.1016/j.neucom.2006.10.047>`_. Defaults to
``True``.
Notes
-----
The gamma factor is commonly defined as 1 for a perfect match and 0 for
a match not better than random (negative values are possible if the match
is *worse* than expected by chance). Since we use the gamma factor as an
error to be minimized, the calculated term is actually r - gamma_factor,
where r is 1 if ``rate_correction`` is ``False``, or a rate-difference
dependent term if ``rate_correction` is ``True``. In both cases, the best
possible error value (i.e. for a perfect match between spike trains) is 0.
References
----------
* `R. Jolivet et al. “A Benchmark Test for a Quantitative Assessment of
Simple Neuron Models.” Journal of Neuroscience Methods, 169, no. 2 (2008):
417–24. <https://doi.org/10.1016/j.jneumeth.2007.11.006>`_
* `C. Clopath et al. “Predicting Neuronal Activity with Simple Models of the
Threshold Type: Adaptive Exponential Integrate-and-Fire Model with
Two Compartments.” Neurocomputing, 70, no. 10 (2007): 1668–73.
<https://doi.org/10.1016/j.neucom.2006.10.047>`_
"""

@check_units(delta=second, time=second, t_start=0*second)
def __init__(self, delta, time, t_start=0*second, normalization=1.,
rate_correction=True):
"""
Initialize the metric with time window delta and time step dt output
Parameters
----------
delta: `~brian2.units.fundamentalunits.Quantity`
time window
time: `~brian2.units.fundamentalunits.Quantity`
total length of experiment
rate_correciton: bool
Whether to include an error term that penalizes differences in firing
rate, following `Clopath et al., Neurocomputing (2007)
<https://doi.org/10.1016/j.neucom.2006.10.047>`_.
"""
super(GammaFactor, self).__init__(t_start=t_start,
normalization=normalization)
self.delta = delta
Expand Down
18 changes: 9 additions & 9 deletions brian2modelfitting/tests/test_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ def test_get_gamma_factor():
gf1 = get_gamma_factor(src2, trg, delta=0.5*ms, time=12*ms, dt=0.1*ms)
gf2 = get_gamma_factor(src, src2, delta=0.5*ms, time=5*ms, dt=0.1*ms)

assert_equal(gf0, -1)
assert gf1 > 0 # Since data rate = 2 * model rate
assert gf2 > -1
assert_equal(gf0, 0)
assert gf1 > 1 # Since data rate = 2 * model rate
assert gf2 > 0

gf0 = get_gamma_factor(trg, trg, delta=0.5*ms, time=12*ms, dt=0.1*ms,
rate_correction=False)
Expand Down Expand Up @@ -88,15 +88,15 @@ def test_calc_gf():
assert_raises(DimensionMismatchError, GammaFactor, delta=10*mV)
assert_raises(DimensionMismatchError, GammaFactor, time=10)

model_spikes = [[np.array([1, 5, 8]), np.array([2, 3, 8, 9])], # Correct rate
[np.array([1, 5]), np.array([0, 2, 3, 8, 9])]] # Wrong rate
data_spikes = [np.array([0, 5, 9]), np.array([1, 3, 5, 6])]
model_spikes = [[np.array([1, 5, 8])*1e-3, np.array([2, 3, 8, 9])*1e-3], # Correct rate
[np.array([1, 5])*1e-3, np.array([0, 2, 3, 8, 9])*1e-3]] # Wrong rate
data_spikes = [np.array([0, 5, 9])*1e-3, np.array([1, 3, 5, 6])*1e-3]

gf = GammaFactor(delta=0.5*ms, time=10*ms)
errors = gf.calc([data_spikes]*5, data_spikes, 0.1*ms)
assert_almost_equal(errors, np.ones(5)*-1)
assert_almost_equal(errors, np.zeros(5))
errors = gf.calc(model_spikes, data_spikes, 0.1*ms)
assert errors[0] > -1 # correct rate
assert errors[0] > 0 # correct rate
assert errors[1] > errors[0]

gf = GammaFactor(delta=0.5*ms, time=10*ms, rate_correction=False)
Expand Down Expand Up @@ -147,7 +147,7 @@ def test_get_features_gamma():

features = gf.get_features([data_spikes]*3, data_spikes, 0.1*ms)
assert_equal(np.shape(features), (3, 2))
assert_almost_equal(features, np.ones((3, 2))*-1)
assert_almost_equal(features, np.zeros((3, 2)))


def test_get_errors_gamma():
Expand Down

0 comments on commit 31ea0be

Please sign in to comment.