Skip to content

Commit

Permalink
Merge pull request #432 from brian-team/spikegenerator_fixes
Browse files Browse the repository at this point in the history
Small change to the algorithm for `SpikeGeneratorGroup`
  • Loading branch information
thesamovar committed Mar 23, 2015
2 parents bac0ab0 + 294d42f commit f8bcf88
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 46 deletions.
16 changes: 8 additions & 8 deletions brian2/codegen/runtime/cython_rt/templates/spikegenerator.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,7 @@
cdef double _spike_time

# We need some precomputed values that will be used during looping
not_first_spike = {{_lastindex}}[0] > 0
not_end_period = abs(padding_after) > (dt - epsilon)

# If there is a periodicity in the SpikeGenerator, we need to reset the lastindex
# when all spikes have been played and at the end of the period
if not_first_spike and ({{spike_time}}[{{_lastindex}}[0] - 1] > padding_before):
{{_lastindex}}[0] = 0
not_end_period = abs(padding_after) > (dt - epsilon) and abs(padding_after) < (period - epsilon)

for _idx in range({{_lastindex}}[0], _num{{spike_time}}):
_spike_time = {{spike_time}}[_idx]
Expand All @@ -34,6 +28,12 @@
_cpp_numspikes += 1

{{_spikespace}}[N] = _cpp_numspikes
{{_lastindex}}[0] += _cpp_numspikes

# If there is a periodicity in the SpikeGenerator, we need to reset the lastindex
# when all spikes have been played and at the end of the period
if not_end_period:
{{_lastindex}}[0] += _cpp_numspikes
else:
{{_lastindex}}[0] = 0

{% endblock %}
20 changes: 9 additions & 11 deletions brian2/codegen/runtime/numpy_rt/templates/spikegenerator.py_
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,19 @@ _n_spikes = 0
epsilon = 1e-3*dt

# We need some precomputed values that will be used during looping
not_first_spike = {{_lastindex}}[0] > 0
not_end_period = abs(padding_after) > (dt - epsilon)

# If there is a periodicity in the SpikeGenerator, we need to reset the lastindex
# when all spikes have been played and at the end of the period
if not_first_spike and ({{spike_time}}[{{_lastindex}}[0] - 1] > padding_before):
{{_lastindex}}[0] = 0
not_end_period = abs(padding_after) > (dt - epsilon) and abs(padding_after) < (period - epsilon)
_lastindex_before = {{_lastindex}}[0]

if not_end_period:
_n_spikes = np.searchsorted({{spike_time}}[{{_lastindex}}[0]:], padding_after - epsilon, side='right')
_n_spikes = np.searchsorted({{spike_time}}[_lastindex_before:], padding_after - epsilon, side='right')
{{_lastindex}}[0] += _n_spikes
else:
_n_spikes = np.searchsorted({{spike_time}}[{{_lastindex}}[0]:], period, side='right')
_n_spikes = np.searchsorted({{spike_time}}[_lastindex_before:], period, side='right')
# If there is a periodicity in the SpikeGenerator, we need to reset the lastindex
# when all spikes have been played and at the end of the period
{{_lastindex}}[0] = 0

_indices = {{neuron_index}}[{{_lastindex}}[0]:{{_lastindex}}[0]+_n_spikes]
_indices = {{neuron_index}}[_lastindex_before:_lastindex_before+_n_spikes]

{{_spikespace}}[:_n_spikes] = _indices
{{_spikespace}}[-1] = _n_spikes
{{_lastindex}}[0] += _n_spikes
27 changes: 12 additions & 15 deletions brian2/codegen/runtime/weave_rt/templates/spikegenerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,17 @@
{% block maincode %}
{# USES_VARIABLES {_spikespace, N, t, dt, neuron_index, spike_time, period, _lastindex} #}

double padding_before = fmod(t, period);
double padding_after = fmod(t+dt, period);
double epsilon = 1e-3*dt;
const double padding_before = fmod(t, period);
const double padding_after = fmod(t+dt, period);
const double epsilon = 1e-3*dt;

// We need some precomputed values that will be used during looping
bool not_first_spike = ({{_lastindex}}[0] > 0);
bool not_end_period = (fabs(padding_after) > epsilon);
const bool not_end_period = (fabs(padding_after) > epsilon) && (fabs(padding_after) < (period - epsilon));
bool test;

// TODO: We don't deal with more than one spike per neuron yet
long _cpp_numspikes = 0;

// If there is a periodicity in the SpikeGenerator, we need to reset the lastindex
// when all spikes have been played and at the end of the period
if (not_first_spike && ({{spike_time}}[{{_lastindex}}[0] - 1] > padding_before))
{
{{_lastindex}}[0] = 0;
}

for(int _idx={{_lastindex}}[0]; _idx < _numspike_time; _idx++)
{
if (not_end_period)
Expand All @@ -32,11 +24,16 @@
if (test)
break;
{{_spikespace}}[_cpp_numspikes++] = {{neuron_index}}[_idx];
}
}

{{_spikespace}}[N] = _cpp_numspikes;
{{_lastindex}}[0] += _cpp_numspikes;


// If there is a periodicity in the SpikeGenerator, we need to reset the lastindex
// when all spikes have been played and at the end of the period
if (! not_end_period)
{{_lastindex}}[0] = 0;
else
{{_lastindex}}[0] += _cpp_numspikes;

{% endblock %}

Expand Down
23 changes: 11 additions & 12 deletions brian2/devices/cpp_standalone/templates/spikegenerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,19 @@
{% block maincode %}
{# USES_VARIABLES {_spikespace, N, t, dt, neuron_index, spike_time, period, _lastindex } #}

double padding_before = fmod(t, period);
double padding_after = fmod(t+dt, period);
double epsilon = 1e-3*dt;
const double padding_before = fmod(t, period);
const double padding_after = fmod(t+dt, period);
const double epsilon = 1e-3*dt;

// We need some precomputed values that will be used during looping
bool not_first_spike = ({{_lastindex}}[0] > 0);
bool not_end_period = (fabs(padding_after) > epsilon);
const bool not_end_period = (fabs(padding_after) > epsilon) && (fabs(padding_after) < (period - epsilon));
bool test;

// TODO: We don't deal with more than one spike per neuron yet
long _cpp_numspikes = 0;

{{ openmp_pragma('single') }}
{

if (not_first_spike && ({{spike_time}}[{{_lastindex}}[0] - 1] > padding_before))
{
{{_lastindex}}[0] = 0;
}

for(int _idx={{_lastindex}}[0]; _idx < _numspike_time; _idx++)
{
if (not_end_period)
Expand All @@ -37,7 +30,13 @@
}

{{_spikespace}}[N] = _cpp_numspikes;
{{_lastindex}}[0] += _cpp_numspikes;

// If there is a periodicity in the SpikeGenerator, we need to reset the lastindex
// when all spikes have been played and at the end of the period
if (! not_end_period)
{{_lastindex}}[0] = 0;
else
{{_lastindex}}[0] += _cpp_numspikes;
}

{% endblock %}
39 changes: 39 additions & 0 deletions brian2/tests/test_spikegenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,43 @@ def test_spikegenerator_rounding():
net.run(10000*dt)
assert_equal(mon[0].count, np.ones(10000))

@attr('standalone-compatible', 'long')
@with_setup(teardown=restore_device)
def test_spikegenerator_rounding_long():
# all spikes should fall in separate bins
dt = 0.1*ms
N = 1000000
indices = np.zeros(N)
times = np.arange(N)*dt
SG = SpikeGeneratorGroup(1, indices, times, dt=dt)
target = NeuronGroup(1, 'count : 1')
syn = Synapses(SG, target, pre='count+=1', connect=True)
spikes = SpikeMonitor(SG)
mon = StateMonitor(target, 'count', record=0, when='end')
net = Network(SG, spikes, target, syn, mon)
net.run(N*dt, report='text')
assert spikes.count[0] == N, 'expected %d spikes, got %d' % (N, spikes.count[0])
assert all(np.diff(mon[0].count[:]) == 1)

@attr('standalone-compatible', 'long')
@with_setup(teardown=restore_device)
def test_spikegenerator_rounding_period():
# all spikes should fall in separate bins
dt = 0.1*ms
N = 100
repeats = 10000
indices = np.zeros(N)
times = np.arange(N)*dt
SG = SpikeGeneratorGroup(1, indices, times, dt=dt, period=N*dt)
target = NeuronGroup(1, 'count : 1')
syn = Synapses(SG, target, pre='count+=1', connect=True)
spikes = SpikeMonitor(SG)
mon = StateMonitor(target, 'count', record=0, when='end')
net = Network(SG, spikes, target, syn, mon)
net.run(N*repeats*dt, report='text')
#print np.int_(np.round(spikes.t/dt))
assert_equal(spikes.count[0], N*repeats)
assert all(np.diff(mon[0].count[:]) == 1)

@attr('codegen-independent')
@with_setup(teardown=restore_initial_state)
Expand Down Expand Up @@ -215,6 +252,8 @@ def test_spikegenerator_standalone():
test_spikegenerator_period_repeat()
test_spikegenerator_incorrect_period()
test_spikegenerator_rounding()
test_spikegenerator_rounding_long()
test_spikegenerator_rounding_period()
test_spikegenerator_multiple_spikes_per_bin()
test_spikegenerator_standalone()

0 comments on commit f8bcf88

Please sign in to comment.