Skip to content

Commit

Permalink
fix Python code formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
C.A.P. Linssen committed May 3, 2021
1 parent a27b55c commit 305e82c
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 15 deletions.
10 changes: 5 additions & 5 deletions pynest/nest/tests/test_stdp_multiplicity.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,7 @@ def run_protocol(self, pre_post_shift):
# create spike recorder --- debugging only
spikes = nest.Create("spike_recorder")
nest.Connect(
pre_parrot + post_parrot +
pre_parrot_ps + post_parrot_ps,
pre_parrot + post_parrot + pre_parrot_ps + post_parrot_ps,
spikes
)

Expand Down Expand Up @@ -209,7 +208,8 @@ def run_protocol(self, pre_post_shift):
ax[0].plot(post_weights["parrot_ps"], marker="o", label="parrot_ps")
ax[0].set_ylabel("final weight")
ax[0].set_xticklabels([])
ax[1].semilogy(np.abs(np.array(post_weights["parrot"]) - np.array(post_weights["parrot_ps"])), marker="o", label="error")
ax[1].semilogy(np.abs(np.array(post_weights["parrot"]) - np.array(post_weights["parrot_ps"])),
marker="o", label="error")
ax[1].set_xticks([i for i in range(len(deltas))])
ax[1].set_xticklabels(["{0:.1E}".format(d) for d in deltas])
ax[1].set_xlabel("timestep [ms]")
Expand All @@ -224,7 +224,8 @@ def run_protocol(self, pre_post_shift):
def _test_stdp_multiplicity(self, pre_post_shift, max_abs_err=1E-6):
"""Check that for smaller and smaller timestep, weights obtained from parrot and precise parrot converge.
Enforce a maximum allowed absolute error ``max_abs_err`` between the final weights for the smallest timestep tested.
Enforce a maximum allowed absolute error ``max_abs_err`` between the final weights for the smallest timestep
tested.
Enforce that the error should strictly decrease with smaller timestep."""

Expand All @@ -245,7 +246,6 @@ def test_stdp_multiplicity(self):
self._test_stdp_multiplicity(pre_post_shift=-10.) # test depression



def suite():

# makeSuite is sort of obsolete http://bugs.python.org/issue2721
Expand Down
32 changes: 22 additions & 10 deletions pynest/nest/tests/test_stdp_synapse.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,19 +78,25 @@ def do_nest_simulation_and_compare_to_reproduced_weight(self, fname_snip):
self.init_weight,
fname_snip=fname_snip)

# ``weight_by_nest`` containts only weight values at pre spike times, ``weight_reproduced_independently`` contains the weight at pre *and* post times: check that weights are equal only for pre spike times
# ``weight_by_nest`` containts only weight values at pre spike times, ``weight_reproduced_independently``
# contains the weight at pre *and* post times: check that weights are equal only for pre spike times
assert len(weight_by_nest) > 0
for idx_pre_spike_nest, t_pre_spike_nest in enumerate(t_weight_by_nest):
idx_pre_spike_reproduced_independently = np.argmin((t_pre_spike_nest - t_weight_reproduced_independently)**2)
np.testing.assert_allclose(t_pre_spike_nest, t_weight_reproduced_independently[idx_pre_spike_reproduced_independently])
idx_pre_spike_reproduced_independently = \
np.argmin((t_pre_spike_nest - t_weight_reproduced_independently)**2)
np.testing.assert_allclose(t_pre_spike_nest,
t_weight_reproduced_independently[idx_pre_spike_reproduced_independently])
print("testing equal t = " + str(t_pre_spike_nest))
print("\tweight_by_nest = " + str(weight_by_nest[idx_pre_spike_nest]))
print("\tweight_reproduced_independently = " + str(weight_reproduced_independently[idx_pre_spike_reproduced_independently]))
np.testing.assert_allclose(weight_by_nest[idx_pre_spike_nest], weight_reproduced_independently[idx_pre_spike_reproduced_independently])
print("\tweight_reproduced_independently = " + str(
weight_reproduced_independently[idx_pre_spike_reproduced_independently]))
np.testing.assert_allclose(weight_by_nest[idx_pre_spike_nest],
weight_reproduced_independently[idx_pre_spike_reproduced_independently])

def do_the_nest_simulation(self):
"""
This function is where calls to NEST reside. Returns the generated pre- and post spike sequences and the resulting weight established by STDP.
This function is where calls to NEST reside. Returns the generated pre- and post spike sequences and the
resulting weight established by STDP.
"""
nest.set_verbosity('M_WARNING')
nest.ResetKernel()
Expand Down Expand Up @@ -175,14 +181,18 @@ def do_the_nest_simulation(self):
def reproduce_weight_drift(self, pre_spikes, post_spikes, initial_weight, fname_snip=""):

def facilitate(w, Kpre, Wmax_=1.):
norm_w = (w / self.synapse_parameters["Wmax"]) + (self.synapse_parameters["lambda"] * pow(1 - (w / self.synapse_parameters["Wmax"]), self.synapse_parameters["mu_plus"]) * Kpre)
norm_w = (w / self.synapse_parameters["Wmax"]) + (
self.synapse_parameters["lambda"] * pow(
1 - (w / self.synapse_parameters["Wmax"]), self.synapse_parameters["mu_plus"]) * Kpre)
if norm_w < 1.0:
return norm_w * self.synapse_parameters["Wmax"]
else:
return self.synapse_parameters["Wmax"]

def depress(w, Kpost):
norm_w = (w / self.synapse_parameters["Wmax"]) - (self.synapse_parameters["alpha"] * self.synapse_parameters["lambda"] * pow(w / self.synapse_parameters["Wmax"], self.synapse_parameters["mu_minus"]) * Kpost)
norm_w = (w / self.synapse_parameters["Wmax"]) - (
self.synapse_parameters["alpha"] * self.synapse_parameters["lambda"] * pow(
w / self.synapse_parameters["Wmax"], self.synapse_parameters["mu_minus"]) * Kpost)
if norm_w > 0.0:
return norm_w * self.synapse_parameters["Wmax"]
else:
Expand Down Expand Up @@ -291,11 +301,13 @@ def Kpost_at_time(t, spikes, init=1., inclusive=True):
Kpre_log.append(Kpre)

Kpost_log = [Kpost_at_time(t - self.dendritic_delay, post_spikes, init=self.init_weight) for t in t_log]
self.plot_weight_evolution(pre_spikes, post_spikes, t_log, w_log, Kpre_log, Kpost_log, fname_snip=fname_snip + "_ref", title_snip="Reference")
self.plot_weight_evolution(pre_spikes, post_spikes, t_log, w_log, Kpre_log, Kpost_log,
fname_snip=fname_snip + "_ref", title_snip="Reference")

return t_log, w_log

def plot_weight_evolution(self, pre_spikes, post_spikes, t_log, w_log, Kpre_log=None, Kpost_log=None, fname_snip="", title_snip=""):
def plot_weight_evolution(self, pre_spikes, post_spikes, t_log, w_log, Kpre_log=None, Kpost_log=None,
fname_snip="", title_snip=""):
import matplotlib.pyplot as plt

fig, ax = plt.subplots(nrows=3)
Expand Down

0 comments on commit 305e82c

Please sign in to comment.