Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix STDP k-value error for edge case #2443

Merged
merged 16 commits into from
Apr 5, 2023
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions nestkernel/archiving_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,12 +191,13 @@ nest::ArchivingNode::set_spiketime( Time const& t_sp, double offset )
// - its access counter indicates it has been read out by all connected
// STDP synapses, and
// - there is another, later spike, that is strictly more than
// (max_delay_ + eps) away from the new spike (at t_sp_ms)
// (min_global_delay + max_local_delay + eps) away from the new spike (at t_sp_ms)
while ( history_.size() > 1 )
{
const double next_t_sp = history_[ 1 ].t_;
if ( history_.front().access_counter_ >= n_incoming_
and t_sp_ms - next_t_sp > max_delay_ + kernel().connection_manager.get_stdp_eps() )
and t_sp_ms - next_t_sp
clinssen marked this conversation as resolved.
Show resolved Hide resolved
> max_delay_ + kernel().connection_manager.get_min_delay() + kernel().connection_manager.get_stdp_eps() )
{
history_.pop_front();
}
Expand Down
197 changes: 117 additions & 80 deletions testsuite/pytests/test_stdp_synapse.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
try:
import matplotlib as mpl
import matplotlib.pyplot as plt

DEBUG_PLOTS = True
except Exception:
DEBUG_PLOTS = False
Expand All @@ -42,11 +43,11 @@ class TestSTDPSynapse:
"""

def init_params(self):
self.resolution = 0.1 # [ms]
self.simulation_duration = 1E3 # [ms]
self.resolution = 0.1 # [ms]
self.simulation_duration = 1E3 # [ms]
self.synapse_model = "stdp_synapse"
self.presynaptic_firing_rate = 20. # [ms^-1]
self.postsynaptic_firing_rate = 20. # [ms^-1]
self.presynaptic_firing_rate = 20. # [ms^-1]
self.postsynaptic_firing_rate = 20. # [ms^-1]
self.tau_pre = 16.8
self.tau_post = 33.7
self.init_weight = .5
Expand All @@ -65,6 +66,7 @@ def init_params(self):
}
self.neuron_parameters = {
"tau_minus": self.tau_post,
"t_ref": 1.0
}

# While the random sequences, fairly long, would supposedly
Expand All @@ -75,34 +77,32 @@ def init_params(self):
# append some hardcoded spike sequences:
# pre: 1 5 6 7 9 11 12 13
clinssen marked this conversation as resolved.
Show resolved Hide resolved
# post: 2 3 4 8 9 10 12
clinssen marked this conversation as resolved.
Show resolved Hide resolved
self.hardcoded_pre_times = np.array([1, 5, 6, 7, 9, 11, 12, 13], dtype=float)
self.hardcoded_post_times = np.array([2, 3, 4, 8, 9, 10, 12], dtype=float)
self.hardcoded_trains_length = 2. + max(np.amax(self.hardcoded_pre_times), np.amax(self.hardcoded_post_times))
self.hardcoded_pre_times = np.array(
[1, 5, 6, 7, 9, 11, 12, 13, 14.5, 16.1], dtype=float)
self.hardcoded_post_times = np.array(
[2, 3, 4, 8, 9, 10, 12, 13.2, 15.1, 16.4], dtype=float)
self.hardcoded_trains_length = 2. + max(
np.amax(self.hardcoded_pre_times),
np.amax(self.hardcoded_post_times))

def do_nest_simulation_and_compare_to_reproduced_weight(self, fname_snip):
pre_spikes, post_spikes, t_weight_by_nest, weight_by_nest = self.do_the_nest_simulation()
if DEBUG_PLOTS:
self.plot_weight_evolution(pre_spikes, post_spikes,
t_weight_by_nest,
weight_by_nest,
t_weight_by_nest, weight_by_nest,
fname_snip=fname_snip,
title_snip=self.nest_neuron_model + " (NEST)")

t_weight_reproduced_independently, weight_reproduced_independently = self.reproduce_weight_drift(
pre_spikes, post_spikes,
self.init_weight,
fname_snip=fname_snip)
pre_spikes, post_spikes, self.init_weight, fname_snip=fname_snip)

# ``weight_by_nest`` containts only weight values at pre spike times, ``weight_reproduced_independently``
# contains the weight at pre *and* post times: check that weights are equal only for pre spike times
# contains the weight at pre *and* post times: check that weights are equal for pre spike times
assert len(weight_by_nest) > 0
for idx_pre_spike_nest, t_pre_spike_nest in enumerate(t_weight_by_nest):
idx_pre_spike_reproduced_independently = \
np.argmin((t_pre_spike_nest - t_weight_reproduced_independently)**2)
np.testing.assert_allclose(t_pre_spike_nest,
t_weight_reproduced_independently[idx_pre_spike_reproduced_independently])
np.testing.assert_allclose(weight_by_nest[idx_pre_spike_nest],
weight_reproduced_independently[idx_pre_spike_reproduced_independently])
np.testing.assert_allclose(t_weight_by_nest,
clinssen marked this conversation as resolved.
Show resolved Hide resolved
t_weight_reproduced_independently)
np.testing.assert_allclose(weight_by_nest,
weight_reproduced_independently)

def do_the_nest_simulation(self):
"""
Expand All @@ -122,14 +122,17 @@ def do_the_nest_simulation(self):
"poisson_generator",
2,
params=({"rate": self.presynaptic_firing_rate,
"stop": (self.simulation_duration - self.hardcoded_trains_length)},
"stop": (
self.simulation_duration - self.hardcoded_trains_length)},
{"rate": self.postsynaptic_firing_rate,
"stop": (self.simulation_duration - self.hardcoded_trains_length)}))
"stop": (
self.simulation_duration - self.hardcoded_trains_length)}))
presynaptic_generator = generators[0]
postsynaptic_generator = generators[1]

wr = nest.Create('weight_recorder')
nest.CopyModel(self.synapse_model, self.synapse_model + "_rec", {"weight_recorder": wr})
nest.CopyModel(self.synapse_model, self.synapse_model + "_rec",
{"weight_recorder": wr})

spike_senders = nest.Create(
"spike_generator",
Expand All @@ -145,43 +148,55 @@ def do_the_nest_simulation(self):
# The recorder is to save the randomly generated spike trains.
spike_recorder = nest.Create("spike_recorder")

nest.Connect(presynaptic_generator + pre_spike_generator, presynaptic_neuron,
syn_spec={"synapse_model": "static_synapse", "weight": 9999.})
nest.Connect(postsynaptic_generator + post_spike_generator, postsynaptic_neuron,
syn_spec={"synapse_model": "static_synapse", "weight": 9999.})
nest.Connect(presynaptic_generator + pre_spike_generator,
presynaptic_neuron,
syn_spec={"synapse_model": "static_synapse",
"weight": 9999.})
nest.Connect(postsynaptic_generator + post_spike_generator,
postsynaptic_neuron,
syn_spec={"synapse_model": "static_synapse",
"weight": 9999.})
nest.Connect(presynaptic_neuron + postsynaptic_neuron, spike_recorder,
syn_spec={"synapse_model": "static_synapse"})
# The synapse of interest itself
self.synapse_parameters["synapse_model"] += "_rec"
nest.Connect(presynaptic_neuron, postsynaptic_neuron, syn_spec=self.synapse_parameters)
nest.Connect(presynaptic_neuron, postsynaptic_neuron,
syn_spec=self.synapse_parameters)
self.synapse_parameters["synapse_model"] = self.synapse_model

nest.Simulate(self.simulation_duration)

all_spikes = nest.GetStatus(spike_recorder, keys='events')[0]
pre_spikes = all_spikes['times'][all_spikes['senders'] == presynaptic_neuron.tolist()[0]]
post_spikes = all_spikes['times'][all_spikes['senders'] == postsynaptic_neuron.tolist()[0]]
pre_spikes = all_spikes['times'][
all_spikes['senders'] == presynaptic_neuron.tolist()[0]]
post_spikes = all_spikes['times'][
all_spikes['senders'] == postsynaptic_neuron.tolist()[0]]

t_hist = nest.GetStatus(wr, "events")[0]["times"]
weight = nest.GetStatus(wr, "events")[0]["weights"]

return pre_spikes, post_spikes, t_hist, weight

def reproduce_weight_drift(self, pre_spikes, post_spikes, initial_weight, fname_snip=""):
def reproduce_weight_drift(self, pre_spikes, post_spikes, initial_weight,
fname_snip=""):
"""Independent, self-contained model of STDP"""

def facilitate(w, Kpre, Wmax_=1.):
norm_w = (w / self.synapse_parameters["Wmax"]) + (
self.synapse_parameters["lambda"] * pow(
1 - (w / self.synapse_parameters["Wmax"]), self.synapse_parameters["mu_plus"]) * Kpre)
norm_w = (w / self.synapse_parameters["Wmax"]) + \
(self.synapse_parameters["lambda"] *
pow(1 - (w / self.synapse_parameters["Wmax"]),
self.synapse_parameters["mu_plus"]) * Kpre)
if norm_w < 1.0:
return norm_w * self.synapse_parameters["Wmax"]
else:
return self.synapse_parameters["Wmax"]

def depress(w, Kpost):
norm_w = (w / self.synapse_parameters["Wmax"]) - (
self.synapse_parameters["alpha"] * self.synapse_parameters["lambda"] * pow(
w / self.synapse_parameters["Wmax"], self.synapse_parameters["mu_minus"]) * Kpost)
norm_w = (w / self.synapse_parameters["Wmax"]) - \
(self.synapse_parameters["alpha"] *
self.synapse_parameters["lambda"] *
pow(w / self.synapse_parameters["Wmax"],
self.synapse_parameters["mu_minus"]) * Kpost)
if norm_w > 0.0:
return norm_w * self.synapse_parameters["Wmax"]
else:
Expand Down Expand Up @@ -209,42 +224,52 @@ def Kpost_at_time(t, spikes, inclusive=True):
Kpost *= exp(-(t - t_curr) / self.tau_post)
return Kpost

eps = 1e-6
t = 0.
idx_next_pre_spike = 0
idx_next_post_spike = 0
t_last_pre_spike = -1
t_last_post_spike = -1
Kpre = 0.
weight = initial_weight

t_log = []
w_log = []
Kpre_log = []

# logging
t_log.append(t)
w_log.append(weight)
Kpre_log.append(Kpre)
t_log = list()
w_log = dict()
clinssen marked this conversation as resolved.
Show resolved Hide resolved
Kpre_log = list()
pre_spike_times = list()

post_spikes_delayed = post_spikes + self.dendritic_delay

while t < self.simulation_duration:
idx_next_pre_spike = -1
if np.where((pre_spikes - t) > 0)[0].size > 0:
idx_next_pre_spike = np.where((pre_spikes - t) > 0)[0][0]
if idx_next_pre_spike >= pre_spikes.size:
t_next_pre_spike = -1
else:
t_next_pre_spike = pre_spikes[idx_next_pre_spike]

idx_next_post_spike = -1
if np.where((post_spikes_delayed - t) > 0)[0].size > 0:
idx_next_post_spike = np.where((post_spikes_delayed - t) > 0)[0][0]
if idx_next_post_spike >= post_spikes.size:
t_next_post_spike = -1
else:
t_next_post_spike = post_spikes_delayed[idx_next_post_spike]

if idx_next_pre_spike >= 0 and idx_next_post_spike >= 0 and t_next_post_spike < t_next_pre_spike:
handle_post_spike = True
if t_next_post_spike == -1:
a = 1
JanVogelsang marked this conversation as resolved.
Show resolved Hide resolved

if t_next_post_spike >= 0 and (
t_next_post_spike + eps < t_next_pre_spike or t_next_pre_spike < 0):
JanVogelsang marked this conversation as resolved.
Show resolved Hide resolved
handle_pre_spike = False
elif idx_next_pre_spike >= 0 and idx_next_post_spike >= 0 and t_next_post_spike > t_next_pre_spike:
handle_post_spike = False
handle_post_spike = True
idx_next_post_spike += 1
elif t_next_pre_spike >= 0 and (
t_next_post_spike > t_next_pre_spike + eps or t_next_post_spike < 0):
handle_pre_spike = True
handle_post_spike = False
idx_next_pre_spike += 1
else:
# simultaneous spikes (both true) or no more spikes to process (both false)
handle_post_spike = idx_next_post_spike >= 0
handle_pre_spike = idx_next_pre_spike >= 0
handle_pre_spike = t_next_pre_spike >= 0
handle_post_spike = t_next_post_spike >= 0
idx_next_pre_spike += 1
idx_next_post_spike += 1

# integrate to min(t_next_pre_spike, t_next_post_spike)
t_next = t
Expand All @@ -257,53 +282,63 @@ def Kpost_at_time(t, spikes, inclusive=True):
# no more spikes to process
t_next = self.simulation_duration

'''# max timestep
t_next_ = min(t_next, t + 1E-3)
if t_next != t_next_:
t_next = t_next_
handle_pre_spike = False
handle_post_spike = False'''

h = t_next - t
Kpre *= exp(-h / self.tau_pre)
t = t_next

if handle_post_spike:
# Kpost += 1. <-- not necessary, will call Kpost_at_time(t) later to compute Kpost for any value t
weight = facilitate(weight, Kpre)
if not handle_pre_spike or abs(
t_next_post_spike - t_last_post_spike) > eps:
if abs(t_next_post_spike - t_last_pre_spike) > eps:
weight = facilitate(weight, Kpre)

if handle_pre_spike:
Kpre += 1.
_Kpost = Kpost_at_time(t - self.dendritic_delay, post_spikes, inclusive=False)
weight = depress(weight, _Kpost)
if not handle_post_spike or abs(
t_next_pre_spike - t_last_pre_spike) > eps:
if abs(t_next_pre_spike - t_last_post_spike) > eps:
_Kpost = Kpost_at_time(t - self.dendritic_delay,
post_spikes, inclusive=False)
weight = depress(weight, _Kpost)
t_last_pre_spike = t_next_pre_spike
pre_spike_times.append(t)

if handle_post_spike:
t_last_post_spike = t_next_post_spike

# logging
t_log.append(t)
w_log.append(weight)
Kpre_log.append(Kpre)
w_log[t] = weight
t_log.append(t)

Kpost_log = [Kpost_at_time(t - self.dendritic_delay, post_spikes) for t in t_log]
Kpost_log = [Kpost_at_time(t - self.dendritic_delay, post_spikes) for t
in t_log]
if DEBUG_PLOTS:
self.plot_weight_evolution(pre_spikes, post_spikes, t_log, w_log, Kpre_log, Kpost_log,
fname_snip=fname_snip + "_ref", title_snip="Reference")
self.plot_weight_evolution(pre_spikes, post_spikes, t_log,
w_log.values(), Kpre_log, Kpost_log,
fname_snip=fname_snip + "_ref",
title_snip="Reference")

return t_log, w_log
return pre_spike_times, [w for t, w in w_log.items() if
t in pre_spike_times]

def plot_weight_evolution(self, pre_spikes, post_spikes, t_log, w_log, Kpre_log=None, Kpost_log=None,
def plot_weight_evolution(self, pre_spikes, post_spikes, t_log, w_log,
Kpre_log=None, Kpost_log=None,
fname_snip="", title_snip=""):
fig, ax = plt.subplots(nrows=3)

n_spikes = len(pre_spikes)
for i in range(n_spikes):
ax[0].plot(2 * [pre_spikes[i]], [0, 1], linewidth=2, color="blue", alpha=.4)
ax[0].plot(2 * [pre_spikes[i]], [0, 1], linewidth=2, color="blue",
alpha=.4)
ax[0].set_ylabel("Pre spikes")
ax0_ = ax[0].twinx()
if Kpre_log:
ax0_.plot(t_log, Kpre_log)

n_spikes = len(post_spikes)
for i in range(n_spikes):
ax[1].plot(2 * [post_spikes[i]], [0, 1], linewidth=2, color="red", alpha=.4)
ax[1].plot(2 * [post_spikes[i]], [0, 1], linewidth=2, color="red",
alpha=.4)
ax1_ = ax[1].twinx()
ax[1].set_ylabel("Post spikes")
if Kpost_log:
Expand All @@ -320,7 +355,8 @@ def plot_weight_evolution(self, pre_spikes, post_spikes, t_log, w_log, Kpre_log=
_ax.set_xlim(0., self.simulation_duration)

fig.suptitle(title_snip)
fig.savefig("/tmp/nest_stdp_synapse_test" + fname_snip + ".png", dpi=300)
fig.savefig("/tmp/nest_stdp_synapse_test" + fname_snip + ".png",
dpi=300)
plt.close(fig)

def test_stdp_synapse(self):
Expand All @@ -331,4 +367,5 @@ def test_stdp_synapse(self):
for self.nest_neuron_model in ["iaf_psc_exp", "iaf_cond_exp"]:
fname_snip = "_[nest_neuron_mdl=" + self.nest_neuron_model + "]"
fname_snip += "_[dend_delay=" + str(self.dendritic_delay) + "]"
self.do_nest_simulation_and_compare_to_reproduced_weight(fname_snip=fname_snip)
self.do_nest_simulation_and_compare_to_reproduced_weight(
fname_snip=fname_snip)