Skip to content

Commit

Permalink
clean up code in unit test (#1840)
Browse files Browse the repository at this point in the history
  • Loading branch information
C.A.P. Linssen committed May 3, 2021
1 parent 305e82c commit ff08892
Showing 1 changed file with 32 additions and 40 deletions.
72 changes: 32 additions & 40 deletions pynest/nest/tests/test_stdp_synapse.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,21 @@
from math import exp
import numpy as np

try:
import matplotlib as mpl
import matplotlib.pyplot as plt
DEBUG_PLOTS = True
except Exception:
DEBUG_PLOTS = False


@nest.ll_api.check_stack
class STDPSynapseTest(unittest.TestCase):
"""
XXX TODO
Compare the STDP synaptic plasticity model against a self-contained Python reference.
Random pre and post spike times are generated according to a Poisson distribution; some hard-coded spike times are
added to make sure to test for edge cases such as simultaneous pre- and post spike.
"""

def init_params(self):
Expand Down Expand Up @@ -61,37 +71,29 @@ def init_params(self):
self.hardcoded_post_times = np.array([2, 3, 4, 8, 9, 10, 12.])

def do_nest_simulation_and_compare_to_reproduced_weight(self, fname_snip):
if True:
pre_spikes, post_spikes, t_weight_by_nest, weight_by_nest = self.do_the_nest_simulation()
print("For model to be tested:")
print("\tpre_spikes = " + str(pre_spikes))
print("\tpost_spikes = " + str(post_spikes))

pre_spikes, post_spikes, t_weight_by_nest, weight_by_nest = self.do_the_nest_simulation()
if DEBUG_PLOTS:
self.plot_weight_evolution(pre_spikes, post_spikes,
t_weight_by_nest,
weight_by_nest,
fname_snip=fname_snip,
title_snip=self.nest_neuron_model + " (NEST)")

t_weight_reproduced_independently, weight_reproduced_independently = self.reproduce_weight_drift(
pre_spikes, post_spikes,
self.init_weight,
fname_snip=fname_snip)

# ``weight_by_nest`` containts only weight values at pre spike times, ``weight_reproduced_independently``
# contains the weight at pre *and* post times: check that weights are equal only for pre spike times
assert len(weight_by_nest) > 0
for idx_pre_spike_nest, t_pre_spike_nest in enumerate(t_weight_by_nest):
idx_pre_spike_reproduced_independently = \
np.argmin((t_pre_spike_nest - t_weight_reproduced_independently)**2)
np.testing.assert_allclose(t_pre_spike_nest,
t_weight_reproduced_independently[idx_pre_spike_reproduced_independently])
print("testing equal t = " + str(t_pre_spike_nest))
print("\tweight_by_nest = " + str(weight_by_nest[idx_pre_spike_nest]))
print("\tweight_reproduced_independently = " + str(
weight_reproduced_independently[idx_pre_spike_reproduced_independently]))
np.testing.assert_allclose(weight_by_nest[idx_pre_spike_nest],
weight_reproduced_independently[idx_pre_spike_reproduced_independently])
t_weight_reproduced_independently, weight_reproduced_independently = self.reproduce_weight_drift(
pre_spikes, post_spikes,
self.init_weight,
fname_snip=fname_snip)

# ``weight_by_nest`` containts only weight values at pre spike times, ``weight_reproduced_independently``
# contains the weight at pre *and* post times: check that weights are equal only for pre spike times
assert len(weight_by_nest) > 0
for idx_pre_spike_nest, t_pre_spike_nest in enumerate(t_weight_by_nest):
idx_pre_spike_reproduced_independently = \
np.argmin((t_pre_spike_nest - t_weight_reproduced_independently)**2)
np.testing.assert_allclose(t_pre_spike_nest,
t_weight_reproduced_independently[idx_pre_spike_reproduced_independently])
np.testing.assert_allclose(weight_by_nest[idx_pre_spike_nest],
weight_reproduced_independently[idx_pre_spike_reproduced_independently])

def do_the_nest_simulation(self):
"""
Expand Down Expand Up @@ -165,7 +167,6 @@ def do_the_nest_simulation(self):
self.synapse_parameters["synapse_model"] += "_rec"
nest.Connect(presynaptic_neuron, postsynaptic_neuron, syn_spec=self.synapse_parameters)
self.synapse_parameters["synapse_model"] = self.synapse_model
plastic_synapse_of_interest = nest.GetConnections(synapse_model=self.synapse_model + "_rec")

nest.Simulate(self.simulation_duration)

Expand All @@ -179,7 +180,7 @@ def do_the_nest_simulation(self):
return pre_spikes, post_spikes, t_hist, weight

def reproduce_weight_drift(self, pre_spikes, post_spikes, initial_weight, fname_snip=""):

"""Independent, self-contained model of STDP"""
def facilitate(w, Kpre, Wmax_=1.):
norm_w = (w / self.synapse_parameters["Wmax"]) + (
self.synapse_parameters["lambda"] * pow(
Expand Down Expand Up @@ -236,8 +237,6 @@ def Kpost_at_time(t, spikes, init=1., inclusive=True):
post_spikes_delayed = post_spikes + self.dendritic_delay

while t < self.simulation_duration:
print("t = " + str(t))

idx_next_pre_spike = -1
if np.where((pre_spikes - t) > 0)[0].size > 0:
idx_next_pre_spike = np.where((pre_spikes - t) > 0)[0][0]
Expand Down Expand Up @@ -281,35 +280,28 @@ def Kpost_at_time(t, spikes, init=1., inclusive=True):
t = t_next

if handle_post_spike:
print("Handling post spike at t = " + str(t))
# Kpost += 1. <-- not necessary, will call Kpost_at_time(t) later to compute Kpost for any value t
print("\tFacilitating from " + str(weight), end="")
weight = facilitate(weight, Kpre)
print(" to " + str(weight) + " using Kpre = " + str(Kpre))

if handle_pre_spike:
print("Handling pre spike at t = " + str(t))
Kpre += 1.
_Kpost = Kpost_at_time(t - self.dendritic_delay, post_spikes, init=self.init_weight, inclusive=False)
print("\tDepressing from " + str(weight), end="")
weight = depress(weight, _Kpost)
print(" to " + str(weight) + " using Kpost = " + str(_Kpost))

# logging
t_log.append(t)
w_log.append(weight)
Kpre_log.append(Kpre)

Kpost_log = [Kpost_at_time(t - self.dendritic_delay, post_spikes, init=self.init_weight) for t in t_log]
self.plot_weight_evolution(pre_spikes, post_spikes, t_log, w_log, Kpre_log, Kpost_log,
fname_snip=fname_snip + "_ref", title_snip="Reference")
if DEBUG_PLOTS:
self.plot_weight_evolution(pre_spikes, post_spikes, t_log, w_log, Kpre_log, Kpost_log,
fname_snip=fname_snip + "_ref", title_snip="Reference")

return t_log, w_log

def plot_weight_evolution(self, pre_spikes, post_spikes, t_log, w_log, Kpre_log=None, Kpost_log=None,
fname_snip="", title_snip=""):
import matplotlib.pyplot as plt

fig, ax = plt.subplots(nrows=3)

n_spikes = len(pre_spikes)
Expand Down

0 comments on commit ff08892

Please sign in to comment.