Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add STDP synapse unit testing #1840

Merged
merged 24 commits into from
Oct 27, 2021
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
96f881a
add STDP synapse unit testing
Nov 11, 2020
e3ac12a
Merge remote-tracking branch 'upstream/master' into test-stdp-synapse
Apr 19, 2021
75f3a4b
allow weight_recorder to record precise spike times
Apr 20, 2021
c8547d3
allow STDP synapse to use precise spike times
Apr 20, 2021
af21d05
fix STDP synapse unit test
Apr 20, 2021
a495dfe
pycodestyle formatting changes
Apr 21, 2021
791bff0
Merge remote-tracking branch 'upstream/master' into test-stdp-synapse
May 3, 2021
a27b55c
fix test conditions and add plots to test_stdp_multiplicity
May 3, 2021
305e82c
fix Python code formatting
May 3, 2021
ff08892
clean up code in unit test (#1840)
May 3, 2021
b718ca5
fix intermittent failure in test_visualisation.py
May 3, 2021
b6cfcee
close matplotlib figures after saving them to file
May 4, 2021
4425620
Merge remote-tracking branch 'upstream/master' into test-stdp-synapse
Jun 14, 2021
c878c9c
clean up STDP multiplicity test
Aug 30, 2021
9c1cfa9
Merge remote-tracking branch 'upstream/master' into test-stdp-synapse
Aug 30, 2021
fb9228d
move synaptic plasticity precise spike timing feature to #2035
Aug 30, 2021
b4921bb
refactor/clean up stdp synapse unit test
Aug 30, 2021
ae9f481
refactor/clean up stdp synapse unit testing
Aug 30, 2021
7fb3e2d
fix pycodestyle
Aug 30, 2021
ae60892
Merge remote-tracking branch 'upstream/master' into test-stdp-synapse
Oct 4, 2021
f4182fd
clean up unit tests following unittest/pytest migration
Oct 4, 2021
14ad08c
fix pycodestyle
Oct 4, 2021
c03797e
clean up unit tests following unittest/pytest migration
Oct 4, 2021
7c7d949
clean up based on review comments
Oct 18, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion models/stdp_synapse.h
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ inline void
stdp_synapse< targetidentifierT >::send( Event& e, thread t, const CommonSynapseProperties& )
{
// synapse STDP depressing/facilitation dynamics
const double t_spike = e.get_stamp().get_ms();
const double t_spike = e.get_stamp().get_ms() - e.get_offset();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@clinssen This change makes me wonder if we have the same error in other synapse models as well.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it does:

double t_spike = e.get_stamp().get_ms();

const double t_spike = e.get_stamp().get_ms();

const double t_spike = e.get_stamp().get_ms();

const double t_spike = e.get_stamp().get_ms();

double t_spike = e.get_stamp().get_ms();

double t_spike = e.get_stamp().get_ms();

double t_spike = e.get_stamp().get_ms();

double t_spike = e.get_stamp().get_ms();

const double t_spike = e.get_stamp().get_ms();

const double t_spike = e.get_stamp().get_ms();

double t_spike = e.get_stamp().get_ms();

const double t_spike = e.get_stamp().get_ms();

const double t_spike = e.get_stamp().get_ms();

double t_spike = e.get_stamp().get_ms();

double t_spike = e.get_stamp().get_ms();

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@clinssen So this means that until now, we have not taken spike time offsets into account in plasticity mechanisms, doesn't it? It does not have any effect on the actual timing of the spike transmitted, since the offset information is passed on to the receiving neuron in the Event object. Could you create a follow-up issue for this? We should involve @abigailm in the discussion.


// use accessor functions (inherited from Connection< >) to obtain delay and
// target
Expand Down
1 change: 1 addition & 0 deletions nestkernel/connector_base_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ Connector< ConnectionT >::send_weight_event( const thread tid,
wr_e.set_port( e.get_port() );
wr_e.set_rport( e.get_rport() );
wr_e.set_stamp( e.get_stamp() );
wr_e.set_offset( e.get_offset() );
wr_e.set_sender( e.get_sender() );
wr_e.set_sender_node_id( kernel().connection_manager.get_source_node_id( tid, syn_id_, lcid ) );
wr_e.set_weight( e.get_weight() );
Expand Down
66 changes: 45 additions & 21 deletions pynest/nest/tests/test_stdp_multiplicity.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,13 @@
import math
import numpy as np

try:
import matplotlib as mpl
import matplotlib.pyplot as plt
DEBUG_PLOTS = True
except Exception:
DEBUG_PLOTS = False


@nest.ll_api.check_stack
class StdpSpikeMultiplicity(unittest.TestCase):
Expand Down Expand Up @@ -157,8 +164,7 @@ def run_protocol(self, pre_post_shift):
# create spike recorder --- debugging only
spikes = nest.Create("spike_recorder")
nest.Connect(
pre_parrot + post_parrot +
pre_parrot_ps + post_parrot_ps,
pre_parrot + post_parrot + pre_parrot_ps + post_parrot_ps,
spikes
)

Expand Down Expand Up @@ -194,33 +200,51 @@ def run_protocol(self, pre_post_shift):
post_weights['parrot'].append(w_post)
post_weights['parrot_ps'].append(w_post_ps)

if DEBUG_PLOTS:
import datetime
fig, ax = plt.subplots(nrows=2)
fig.suptitle("Final obtained weights")
ax[0].plot(post_weights["parrot"], marker="o", label="parrot")
ax[0].plot(post_weights["parrot_ps"], marker="o", label="parrot_ps")
ax[0].set_ylabel("final weight")
ax[0].set_xticklabels([])
ax[1].semilogy(np.abs(np.array(post_weights["parrot"]) - np.array(post_weights["parrot_ps"])),
marker="o", label="error")
ax[1].set_xticks([i for i in range(len(deltas))])
ax[1].set_xticklabels(["{0:.1E}".format(d) for d in deltas])
ax[1].set_xlabel("timestep [ms]")
for _ax in ax:
_ax.grid(True)
_ax.legend()
plt.savefig("/tmp/test_stdp_multiplicity" + str(datetime.datetime.utcnow()) + ".png")
plt.close(fig)
print(post_weights)
return post_weights

def test_ParrotNeuronSTDPProtocolPotentiation(self):
"""Check weight convergence on potentiation."""

post_weights = self.run_protocol(pre_post_shift=10.0)
w_plain = np.array(post_weights['parrot'])
w_precise = np.array(post_weights['parrot_ps'])
# XXX: TODO: use ``@pytest.mark.parametrize`` for this
def _test_stdp_multiplicity(self, pre_post_shift, max_abs_err=1E-6):
"""Check that for smaller and smaller timestep, weights obtained from parrot and precise parrot converge.

assert all(w_plain == w_plain[0]), 'Plain weights differ'
dw = w_precise - w_plain
dwrel = dw[1:] / dw[:-1]
assert all(np.round(dwrel, decimals=3) ==
0.5), 'Precise weights do not converge.'
Enforce a maximum allowed absolute error ``max_abs_err`` between the final weights for the smallest timestep
tested.

def test_ParrotNeuronSTDPProtocolDepression(self):
"""Check weight convergence on depression."""
Enforce that the error should strictly decrease with smaller timestep."""

post_weights = self.run_protocol(pre_post_shift=-10.0)
post_weights = self.run_protocol(pre_post_shift=pre_post_shift)
w_plain = np.array(post_weights['parrot'])
w_precise = np.array(post_weights['parrot_ps'])

assert all(w_plain == w_plain[0]), 'Plain weights differ'
dw = w_precise - w_plain
dwrel = dw[1:] / dw[:-1]
assert all(np.round(dwrel, decimals=3) ==
0.5), 'Precise weights do not converge.'
assert all(w_plain == w_plain[0]), 'Plain weights should be independent of timestep, but differ!'
clinssen marked this conversation as resolved.
Show resolved Hide resolved
abs_err = np.abs(w_precise - w_plain)
assert abs_err[-1] < max_abs_err
assert np.all(np.diff(abs_err) < 1), 'Error should decrease with smaller timestep!'
clinssen marked this conversation as resolved.
Show resolved Hide resolved

def test_stdp_multiplicity(self):
"""Check weight convergence on potentiation and depression.

See also: _test_stdp_multiplicity()."""
self._test_stdp_multiplicity(pre_post_shift=10.) # test potentiation
self._test_stdp_multiplicity(pre_post_shift=-10.) # test depression


def suite():
Expand Down
Loading