diff --git a/testsuite/pytests/test_jonke_synapse.py b/testsuite/pytests/test_jonke_synapse.py index 686f53a801..67c6c6eb3d 100644 --- a/testsuite/pytests/test_jonke_synapse.py +++ b/testsuite/pytests/test_jonke_synapse.py @@ -23,13 +23,12 @@ Test functionality of the Tetzlaff stdp synapse """ -import unittest import nest import numpy as np @nest.ll_api.check_stack -class JonkeSynapseTest(unittest.TestCase): +class TestJonkeSynapse: """ Test the weight change by STDP. The test is performed by generating two Poisson spike trains, @@ -73,12 +72,12 @@ def test_weight_drift(self): weight_reproduced_independently = self.reproduce_weight_drift( pre_spikes, post_spikes, self.synapse_parameters["weight"]) - self.assertAlmostEqual( + np.testing.assert_almost_equal( weight_reproduced_independently, weight_by_nest, - msg=f"{self.synapse_parameters['synapse_model']} test:\n" + - f"Resulting synaptic weight {weight_by_nest} " + - f"differs from expected {weight_reproduced_independently}") + err_msg=f"{self.synapse_parameters['synapse_model']} test:\n" + + f"Resulting synaptic weight {weight_by_nest} " + + f"differs from expected {weight_reproduced_independently}") def do_the_nest_simulation(self): """ @@ -111,7 +110,7 @@ def do_the_nest_simulation(self): # reveal small differences in the weight change between NEST # and ours, some low-probability events (say, coinciding # spikes) can well not have occurred. To generate and - # test every possible combination of pre/post precedence, we + # test every possible combination of pre/post order, we # append some hardcoded spike sequences: # pre: 1 5 6 7 9 11 12 13 # post: 2 3 4 8 9 10 12 @@ -244,17 +243,3 @@ def depress(self, _delta_t, weight, Kminus): if weight < 0: weight = 0 return weight - - -def suite(): - suite = unittest.TestLoader().loadTestsFromTestCase(JonkeSynapseTest) - return unittest.TestSuite([suite]) - - -def run(): - runner = unittest.TextTestRunner(verbosity=2) - runner.run(suite()) - - -if __name__ == "__main__": - run() diff --git a/testsuite/pytests/test_stdp_multiplicity.py b/testsuite/pytests/test_stdp_multiplicity.py index 091f64d769..80ade69bef 100644 --- a/testsuite/pytests/test_stdp_multiplicity.py +++ b/testsuite/pytests/test_stdp_multiplicity.py @@ -21,14 +21,21 @@ # This script tests the parrot_neuron in NEST. -import nest -import unittest import math +import nest import numpy as np +import pytest + +try: + import matplotlib as mpl + import matplotlib.pyplot as plt + DEBUG_PLOTS = True +except Exception: + DEBUG_PLOTS = False @nest.ll_api.check_stack -class StdpSpikeMultiplicity(unittest.TestCase): +class TestStdpSpikeMultiplicity: """ Test correct handling of spike multiplicity in STDP. @@ -51,23 +58,29 @@ class StdpSpikeMultiplicity(unittest.TestCase): delta, since in this case all spikes are at the end of the step, i.e., all spikes have identical times independent of delta. - 2. We choose delta values that are decrease by factors of 2. The + 2. We choose delta values that are decreased by factors of 2. The plasticity rules depend on spike-time differences through + :: + exp(dT / tau) where dT is the time between pre- and postsynaptic spikes. We construct pre- and postsynaptic spike times so that - dT = pre_post_shift + m * delta + :: - with m * delta < resolution << pre_post_shift. The time-dependence + dT = pre_post_shift + m * delta + + with ``m * delta < resolution << pre_post_shift``. The time-dependence of the plasticity rule is therefore to good approximation linear in delta. - We can thus test as follows: Let w_pl be the weight obtained with the - plain parrot, and w_ps_j the weight obtained with the precise parrot - for delta_j = delta0 / 2^j. Then, + We can thus test as follows: Let ``w_pl`` be the weight obtained with the + plain parrot, and ``w_ps_j`` the weight obtained with the precise parrot + for ``delta_j = delta0 / 2^j``. Then, + + :: ( w_ps_{j+1} - w_pl ) / ( w_ps_j - w_pl ) ~ 0.5 for all j @@ -157,8 +170,7 @@ def run_protocol(self, pre_post_shift): # create spike recorder --- debugging only spikes = nest.Create("spike_recorder") nest.Connect( - pre_parrot + post_parrot + - pre_parrot_ps + post_parrot_ps, + pre_parrot + post_parrot + pre_parrot_ps + post_parrot_ps, spikes ) @@ -194,47 +206,42 @@ def run_protocol(self, pre_post_shift): post_weights['parrot'].append(w_post) post_weights['parrot_ps'].append(w_post_ps) + if DEBUG_PLOTS: + fig, ax = plt.subplots(nrows=2) + fig.suptitle("Final obtained weights") + ax[0].plot(post_weights["parrot"], marker="o", label="parrot") + ax[0].plot(post_weights["parrot_ps"], marker="o", label="parrot_ps") + ax[0].set_ylabel("final weight") + ax[0].set_xticklabels([]) + ax[1].semilogy(np.abs(np.array(post_weights["parrot"]) - np.array(post_weights["parrot_ps"])), + marker="o", label="error") + ax[1].set_xticks([i for i in range(len(deltas))]) + ax[1].set_xticklabels(["{0:.1E}".format(d) for d in deltas]) + ax[1].set_xlabel("timestep [ms]") + for _ax in ax: + _ax.grid(True) + _ax.legend() + plt.savefig("/tmp/test_stdp_multiplicity.png") + plt.close(fig) + print(post_weights) return post_weights - def test_ParrotNeuronSTDPProtocolPotentiation(self): - """Check weight convergence on potentiation.""" - - post_weights = self.run_protocol(pre_post_shift=10.0) - w_plain = np.array(post_weights['parrot']) - w_precise = np.array(post_weights['parrot_ps']) + @pytest.mark.parametrize("pre_post_shift", [10., # test potentiation + -10.]) # test depression + def test_stdp_multiplicity(self, pre_post_shift, max_abs_err=1E-3): + """Check that for smaller and smaller timestep, weights obtained from parrot and precise parrot converge. - assert all(w_plain == w_plain[0]), 'Plain weights differ' - dw = w_precise - w_plain - dwrel = dw[1:] / dw[:-1] - assert all(np.round(dwrel, decimals=3) == - 0.5), 'Precise weights do not converge.' + Enforce a maximum allowed absolute error ``max_abs_err`` between the final weights for the smallest timestep + tested. - def test_ParrotNeuronSTDPProtocolDepression(self): - """Check weight convergence on depression.""" + Enforce that the error should strictly decrease with smaller timestep.""" - post_weights = self.run_protocol(pre_post_shift=-10.0) + post_weights = self.run_protocol(pre_post_shift=pre_post_shift) w_plain = np.array(post_weights['parrot']) w_precise = np.array(post_weights['parrot_ps']) - assert all(w_plain == w_plain[0]), 'Plain weights differ' - dw = w_precise - w_plain - dwrel = dw[1:] / dw[:-1] - assert all(np.round(dwrel, decimals=3) == - 0.5), 'Precise weights do not converge.' - - -def suite(): - - # makeSuite is sort of obsolete http://bugs.python.org/issue2721 - # using loadTestsFromTestCase instead. - suite = unittest.TestLoader().loadTestsFromTestCase(StdpSpikeMultiplicity) - return unittest.TestSuite([suite]) - - -def run(): - runner = unittest.TextTestRunner(verbosity=2) - runner.run(suite()) - - -if __name__ == "__main__": - run() + assert all(w_plain == w_plain[0]), 'Plain weights should be independent of timestep!' + abs_err = np.abs(w_precise - w_plain) + assert abs_err[-1] < max_abs_err, 'Final absolute error is ' + '{0:.2E}'.format(abs_err[-1]) \ + + ' but should be <= ' + '{0:.2E}'.format(max_abs_err) + assert np.all(np.diff(abs_err) < 0), 'Error should decrease with smaller timestep!' diff --git a/testsuite/pytests/test_stdp_nn_synapses.py b/testsuite/pytests/test_stdp_nn_synapses.py index e1dbcd70f3..cdc3ed3733 100644 --- a/testsuite/pytests/test_stdp_nn_synapses.py +++ b/testsuite/pytests/test_stdp_nn_synapses.py @@ -23,12 +23,14 @@ # and stdp_nn_restr_synapse in NEST. import nest -import unittest +import numpy as np +import pytest + from math import exp @nest.ll_api.check_stack -class STDPNNSynapsesTest(unittest.TestCase): +class TestSTDPNNSynapses: """ Test the weight change by STDP with three nearest-neighbour spike pairing schemes. @@ -43,12 +45,12 @@ class STDPNNSynapsesTest(unittest.TestCase): Instead, it directly iterates through the spike history. """ + @pytest.fixture(autouse=True) def setUp(self): self.resolution = 0.1 # [ms] self.presynaptic_firing_rate = 20.0 # [Hz] self.postsynaptic_firing_rate = 20.0 # [Hz] self.simulation_duration = 1e+4 # [ms] - self.hardcoded_trains_length = 15. # [ms] self.synapse_parameters = { "receptor_type": 1, "delay": self.resolution, @@ -66,6 +68,18 @@ def setUp(self): "tau_minus": 33.7 } + # While the random sequences, fairly long, would supposedly + # reveal small differences in the weight change between NEST + # and ours, some low-probability events (say, coinciding + # spikes) can well not have occured. To generate and + # test every possible combination of pre/post order, we + # append some hardcoded spike sequences: + # pre: 1 5 6 7 9 11 12 13 + # post: 2 3 4 8 9 10 12 + self.hardcoded_pre_times = np.array([1, 5, 6, 7, 9, 11, 12, 13], dtype=float) + self.hardcoded_post_times = np.array([2, 3, 4, 8, 9, 10, 12], dtype=float) + self.hardcoded_trains_length = 2. + max(np.amax(self.hardcoded_pre_times), np.amax(self.hardcoded_post_times)) + def do_nest_simulation_and_compare_to_reproduced_weight(self, pairing_scheme): synapse_model = "stdp_" + pairing_scheme + "_synapse" @@ -75,13 +89,13 @@ def do_nest_simulation_and_compare_to_reproduced_weight(self, weight_reproduced_independently = self.reproduce_weight_drift( pre_spikes, post_spikes, self.synapse_parameters["weight"]) - self.assertAlmostEqual( + np.testing.assert_almost_equal( weight_reproduced_independently, weight_by_nest, - msg=synapse_model + " test: " - "Resulting synaptic weight %e " - "differs from expected %e" % ( - weight_by_nest, weight_reproduced_independently)) + err_msg=synapse_model + " test: " + "Resulting synaptic weight %e " + "differs from expected %e" % ( + weight_by_nest, weight_reproduced_independently)) def do_the_nest_simulation(self): """ @@ -93,12 +107,10 @@ def do_the_nest_simulation(self): nest.ResetKernel() nest.resolution = self.resolution - neurons = nest.Create( + presynaptic_neuron, postsynaptic_neuron = nest.Create( "parrot_neuron", 2, params=self.neuron_parameters) - presynaptic_neuron = neurons[0] - postsynaptic_neuron = neurons[1] generators = nest.Create( "poisson_generator", @@ -110,32 +122,13 @@ def do_the_nest_simulation(self): presynaptic_generator = generators[0] postsynaptic_generator = generators[1] - # While the random sequences, fairly long, would supposedly - # reveal small differences in the weight change between NEST - # and ours, some low-probability events (say, coinciding - # spikes) can well not have occured. To generate and - # test every possible combination of pre/post precedence, we - # append some hardcoded spike sequences: - # pre: 1 5 6 7 9 11 12 13 - # post: 2 3 4 8 9 10 12 - ( - hardcoded_pre_times, - hardcoded_post_times - ) = [ - [ - self.simulation_duration - self.hardcoded_trains_length + t - for t in train - ] for train in ( - (1, 5, 6, 7, 9, 11, 12, 13), - (2, 3, 4, 8, 9, 10, 12) - ) - ] - spike_senders = nest.Create( "spike_generator", 2, - params=({"spike_times": hardcoded_pre_times}, - {"spike_times": hardcoded_post_times}) + params=({"spike_times": self.hardcoded_pre_times + + self.simulation_duration - self.hardcoded_trains_length}, + {"spike_times": self.hardcoded_post_times + + self.simulation_duration - self.hardcoded_trains_length}) ) pre_spike_generator = spike_senders[0] post_spike_generator = spike_senders[1] @@ -285,20 +278,3 @@ def test_nn_pre_centered_synapse(self): def test_nn_restr_synapse(self): self.do_nest_simulation_and_compare_to_reproduced_weight("nn_restr") - - -def suite(): - - # makeSuite is sort of obsolete http://bugs.python.org/issue2721 - # using loadTestsFromTestCase instead. - suite = unittest.TestLoader().loadTestsFromTestCase(STDPNNSynapsesTest) - return unittest.TestSuite([suite]) - - -def run(): - runner = unittest.TextTestRunner(verbosity=2) - runner.run(suite()) - - -if __name__ == "__main__": - run() diff --git a/testsuite/pytests/test_stdp_synapse.py b/testsuite/pytests/test_stdp_synapse.py new file mode 100644 index 0000000000..22bcbefe47 --- /dev/null +++ b/testsuite/pytests/test_stdp_synapse.py @@ -0,0 +1,334 @@ +# -*- coding: utf-8 -*- +# +# test_stdp_synapse.py +# +# This file is part of NEST. +# +# Copyright (C) 2004 The NEST Initiative +# +# NEST is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# NEST is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with NEST. If not, see . + +import nest +import pytest +from math import exp +import numpy as np + +try: + import matplotlib as mpl + import matplotlib.pyplot as plt + DEBUG_PLOTS = True +except Exception: + DEBUG_PLOTS = False + + +@nest.ll_api.check_stack +class TestSTDPSynapse: + """ + Compare the STDP synaptic plasticity model against a self-contained Python reference. + + Random pre and post spike times are generated according to a Poisson distribution; some hard-coded spike times are + added to make sure to test for edge cases such as simultaneous pre- and post spike. + """ + + def init_params(self): + self.resolution = 0.1 # [ms] + self.simulation_duration = 1E3 # [ms] + self.synapse_model = "stdp_synapse" + self.presynaptic_firing_rate = 20. # [ms^-1] + self.postsynaptic_firing_rate = 20. # [ms^-1] + self.tau_pre = 16.8 + self.tau_post = 33.7 + self.init_weight = .5 + self.synapse_parameters = { + "synapse_model": self.synapse_model, + "receptor_type": 0, + "delay": self.dendritic_delay, + # STDP constants + "lambda": 0.01, + "alpha": 0.85, + "mu_plus": 0.0, + "mu_minus": 0.0, + "tau_plus": self.tau_pre, + "Wmax": 15.0, + "weight": self.init_weight + } + self.neuron_parameters = { + "tau_minus": self.tau_post, + } + + # While the random sequences, fairly long, would supposedly + # reveal small differences in the weight change between NEST + # and ours, some low-probability events (say, coinciding + # spikes) can well not have occured. To generate and + # test every possible combination of pre/post order, we + # append some hardcoded spike sequences: + # pre: 1 5 6 7 9 11 12 13 + # post: 2 3 4 8 9 10 12 + self.hardcoded_pre_times = np.array([1, 5, 6, 7, 9, 11, 12, 13], dtype=float) + self.hardcoded_post_times = np.array([2, 3, 4, 8, 9, 10, 12], dtype=float) + self.hardcoded_trains_length = 2. + max(np.amax(self.hardcoded_pre_times), np.amax(self.hardcoded_post_times)) + + def do_nest_simulation_and_compare_to_reproduced_weight(self, fname_snip): + pre_spikes, post_spikes, t_weight_by_nest, weight_by_nest = self.do_the_nest_simulation() + if DEBUG_PLOTS: + self.plot_weight_evolution(pre_spikes, post_spikes, + t_weight_by_nest, + weight_by_nest, + fname_snip=fname_snip, + title_snip=self.nest_neuron_model + " (NEST)") + + t_weight_reproduced_independently, weight_reproduced_independently = self.reproduce_weight_drift( + pre_spikes, post_spikes, + self.init_weight, + fname_snip=fname_snip) + + # ``weight_by_nest`` containts only weight values at pre spike times, ``weight_reproduced_independently`` + # contains the weight at pre *and* post times: check that weights are equal only for pre spike times + assert len(weight_by_nest) > 0 + for idx_pre_spike_nest, t_pre_spike_nest in enumerate(t_weight_by_nest): + idx_pre_spike_reproduced_independently = \ + np.argmin((t_pre_spike_nest - t_weight_reproduced_independently)**2) + np.testing.assert_allclose(t_pre_spike_nest, + t_weight_reproduced_independently[idx_pre_spike_reproduced_independently]) + np.testing.assert_allclose(weight_by_nest[idx_pre_spike_nest], + weight_reproduced_independently[idx_pre_spike_reproduced_independently]) + + def do_the_nest_simulation(self): + """ + This function is where calls to NEST reside. Returns the generated pre- and post spike sequences and the + resulting weight established by STDP. + """ + nest.set_verbosity('M_WARNING') + nest.ResetKernel() + nest.SetKernelStatus({'resolution': self.resolution}) + + presynaptic_neuron, postsynaptic_neuron = nest.Create( + self.nest_neuron_model, + 2, + params=self.neuron_parameters) + + generators = nest.Create( + "poisson_generator", + 2, + params=({"rate": self.presynaptic_firing_rate, + "stop": (self.simulation_duration - self.hardcoded_trains_length)}, + {"rate": self.postsynaptic_firing_rate, + "stop": (self.simulation_duration - self.hardcoded_trains_length)})) + presynaptic_generator = generators[0] + postsynaptic_generator = generators[1] + + wr = nest.Create('weight_recorder') + nest.CopyModel(self.synapse_model, self.synapse_model + "_rec", {"weight_recorder": wr}) + + spike_senders = nest.Create( + "spike_generator", + 2, + params=({"spike_times": self.hardcoded_pre_times + + self.simulation_duration - self.hardcoded_trains_length}, + {"spike_times": self.hardcoded_post_times + + self.simulation_duration - self.hardcoded_trains_length}) + ) + pre_spike_generator = spike_senders[0] + post_spike_generator = spike_senders[1] + + # The recorder is to save the randomly generated spike trains. + spike_recorder = nest.Create("spike_recorder") + + nest.Connect(presynaptic_generator + pre_spike_generator, presynaptic_neuron, + syn_spec={"synapse_model": "static_synapse", "weight": 9999.}) + nest.Connect(postsynaptic_generator + post_spike_generator, postsynaptic_neuron, + syn_spec={"synapse_model": "static_synapse", "weight": 9999.}) + nest.Connect(presynaptic_neuron + postsynaptic_neuron, spike_recorder, + syn_spec={"synapse_model": "static_synapse"}) + # The synapse of interest itself + self.synapse_parameters["synapse_model"] += "_rec" + nest.Connect(presynaptic_neuron, postsynaptic_neuron, syn_spec=self.synapse_parameters) + self.synapse_parameters["synapse_model"] = self.synapse_model + + nest.Simulate(self.simulation_duration) + + all_spikes = nest.GetStatus(spike_recorder, keys='events')[0] + pre_spikes = all_spikes['times'][all_spikes['senders'] == presynaptic_neuron.tolist()[0]] + post_spikes = all_spikes['times'][all_spikes['senders'] == postsynaptic_neuron.tolist()[0]] + + t_hist = nest.GetStatus(wr, "events")[0]["times"] + weight = nest.GetStatus(wr, "events")[0]["weights"] + + return pre_spikes, post_spikes, t_hist, weight + + def reproduce_weight_drift(self, pre_spikes, post_spikes, initial_weight, fname_snip=""): + """Independent, self-contained model of STDP""" + def facilitate(w, Kpre, Wmax_=1.): + norm_w = (w / self.synapse_parameters["Wmax"]) + ( + self.synapse_parameters["lambda"] * pow( + 1 - (w / self.synapse_parameters["Wmax"]), self.synapse_parameters["mu_plus"]) * Kpre) + if norm_w < 1.0: + return norm_w * self.synapse_parameters["Wmax"] + else: + return self.synapse_parameters["Wmax"] + + def depress(w, Kpost): + norm_w = (w / self.synapse_parameters["Wmax"]) - ( + self.synapse_parameters["alpha"] * self.synapse_parameters["lambda"] * pow( + w / self.synapse_parameters["Wmax"], self.synapse_parameters["mu_minus"]) * Kpost) + if norm_w > 0.0: + return norm_w * self.synapse_parameters["Wmax"] + else: + return 0. + + def Kpost_at_time(t, spikes, init=1., inclusive=True): + t_curr = 0. + Kpost = 0. + for spike_idx, t_sp in enumerate(spikes): + if t < t_sp: + # integrate to t + Kpost *= exp(-(t - t_curr) / self.tau_post) + return Kpost + # integrate to t_sp + Kpost *= exp(-(t_sp - t_curr) / self.tau_post) + if inclusive: + Kpost += 1. + if t == t_sp: + return Kpost + if not inclusive: + Kpost += 1. + t_curr = t_sp + # if we get here, t > t_last_spike + # integrate to t + Kpost *= exp(-(t - t_curr) / self.tau_post) + return Kpost + + t = 0. + Kpre = 0. + weight = initial_weight + + t_log = [] + w_log = [] + Kpre_log = [] + + # logging + t_log.append(t) + w_log.append(weight) + Kpre_log.append(Kpre) + + post_spikes_delayed = post_spikes + self.dendritic_delay + + while t < self.simulation_duration: + idx_next_pre_spike = -1 + if np.where((pre_spikes - t) > 0)[0].size > 0: + idx_next_pre_spike = np.where((pre_spikes - t) > 0)[0][0] + t_next_pre_spike = pre_spikes[idx_next_pre_spike] + + idx_next_post_spike = -1 + if np.where((post_spikes_delayed - t) > 0)[0].size > 0: + idx_next_post_spike = np.where((post_spikes_delayed - t) > 0)[0][0] + t_next_post_spike = post_spikes_delayed[idx_next_post_spike] + + if idx_next_pre_spike >= 0 and idx_next_post_spike >= 0 and t_next_post_spike < t_next_pre_spike: + handle_post_spike = True + handle_pre_spike = False + elif idx_next_pre_spike >= 0 and idx_next_post_spike >= 0 and t_next_post_spike > t_next_pre_spike: + handle_post_spike = False + handle_pre_spike = True + else: + # simultaneous spikes (both true) or no more spikes to process (both false) + handle_post_spike = idx_next_post_spike >= 0 + handle_pre_spike = idx_next_pre_spike >= 0 + + # integrate to min(t_next_pre_spike, t_next_post_spike) + t_next = t + if handle_pre_spike: + t_next = max(t, t_next_pre_spike) + if handle_post_spike: + t_next = max(t, t_next_post_spike) + + if t_next == t: + # no more spikes to process + t_next = self.simulation_duration + + '''# max timestep + t_next_ = min(t_next, t + 1E-3) + if t_next != t_next_: + t_next = t_next_ + handle_pre_spike = False + handle_post_spike = False''' + + h = t_next - t + Kpre *= exp(-h / self.tau_pre) + t = t_next + + if handle_post_spike: + # Kpost += 1. <-- not necessary, will call Kpost_at_time(t) later to compute Kpost for any value t + weight = facilitate(weight, Kpre) + + if handle_pre_spike: + Kpre += 1. + _Kpost = Kpost_at_time(t - self.dendritic_delay, post_spikes, init=self.init_weight, inclusive=False) + weight = depress(weight, _Kpost) + + # logging + t_log.append(t) + w_log.append(weight) + Kpre_log.append(Kpre) + + Kpost_log = [Kpost_at_time(t - self.dendritic_delay, post_spikes, init=self.init_weight) for t in t_log] + if DEBUG_PLOTS: + self.plot_weight_evolution(pre_spikes, post_spikes, t_log, w_log, Kpre_log, Kpost_log, + fname_snip=fname_snip + "_ref", title_snip="Reference") + + return t_log, w_log + + def plot_weight_evolution(self, pre_spikes, post_spikes, t_log, w_log, Kpre_log=None, Kpost_log=None, + fname_snip="", title_snip=""): + fig, ax = plt.subplots(nrows=3) + + n_spikes = len(pre_spikes) + for i in range(n_spikes): + ax[0].plot(2 * [pre_spikes[i]], [0, 1], linewidth=2, color="blue", alpha=.4) + ax[0].set_ylabel("Pre spikes") + ax0_ = ax[0].twinx() + if Kpre_log: + ax0_.plot(t_log, Kpre_log) + + n_spikes = len(post_spikes) + for i in range(n_spikes): + ax[1].plot(2 * [post_spikes[i]], [0, 1], linewidth=2, color="red", alpha=.4) + ax1_ = ax[1].twinx() + ax[1].set_ylabel("Post spikes") + if Kpost_log: + ax1_.plot(t_log, Kpost_log) + + ax[2].plot(t_log, w_log, marker="o", label="nestml") + ax[2].set_ylabel("w") + + ax[2].set_xlabel("Time [ms]") + for _ax in ax: + _ax.grid(which="major", axis="both") + _ax.grid(which="minor", axis="x", linestyle=":", alpha=.4) + _ax.minorticks_on() + _ax.set_xlim(0., self.simulation_duration) + + fig.suptitle(title_snip) + fig.savefig("/tmp/nest_stdp_synapse_test" + fname_snip + ".png", dpi=300) + plt.close(fig) + + def test_stdp_synapse(self): + self.dendritic_delay = float('nan') + self.init_params() + for self.dendritic_delay in [1., self.resolution]: + self.init_params() + for self.nest_neuron_model in ["iaf_psc_exp", "iaf_cond_exp"]: + fname_snip = "_[nest_neuron_mdl=" + self.nest_neuron_model + "]" + fname_snip += "_[dend_delay=" + str(self.dendritic_delay) + "]" + self.do_nest_simulation_and_compare_to_reproduced_weight(fname_snip=fname_snip) diff --git a/testsuite/pytests/test_visualization.py b/testsuite/pytests/test_visualization.py index b49538b245..69ba6d0f25 100644 --- a/testsuite/pytests/test_visualization.py +++ b/testsuite/pytests/test_visualization.py @@ -23,10 +23,10 @@ Tests for visualization functions. """ -import os -import unittest import nest import numpy as np +import os +import pytest try: import matplotlib.pyplot as plt @@ -50,7 +50,7 @@ HAVE_PANDAS = False -class VisualizationTestCase(unittest.TestCase): +class TestVisualization: def nest_tmpdir(self): """Returns temp dir path from environment, current dir otherwise.""" if 'NEST_DATA_PATH' in os.environ: @@ -58,15 +58,16 @@ def nest_tmpdir(self): else: return '.' + @pytest.fixture(autouse=True) def setUp(self): self.filenames = [] - - def tearDown(self): + yield + # fixture teardown code below for filename in self.filenames: # Cleanup temporary datafiles os.remove(filename) - @unittest.skipIf(not HAVE_PYDOT, 'pydot not found') + @pytest.mark.skipif(not HAVE_PYDOT, reason='pydot not found') def test_plot_network(self): """Test plot_network""" import nest.visualization as nvis @@ -78,24 +79,24 @@ def test_plot_network(self): filename = os.path.join(self.nest_tmpdir(), 'network_plot.png') self.filenames.append(filename) nvis.plot_network(sources + targets, filename) - self.assertTrue(os.path.isfile(filename), 'Plot was not created or not saved') + assert os.path.isfile(filename), 'Plot was not created or not saved' def voltage_trace_verify(self, device): - self.assertIsNotNone(plt._pylab_helpers.Gcf.get_active(), 'No active figure') + assert plt._pylab_helpers.Gcf.get_active() is not None, 'No active figure' ax = plt.gca() vm = device.get('events', 'V_m') for ref_vm, line in zip((vm[::2], vm[1::2]), ax.lines): x_data, y_data = line.get_data() # Check that times are correct - self.assertEqual(list(x_data), list(np.unique(device.get('events', 'times')))) + assert list(x_data) == list(np.unique(device.get('events', 'times'))) # Check that voltmeter data corresponds to the lines in the plot - self.assertTrue(all(np.isclose(ref_vm, y_data))) + assert all(np.isclose(ref_vm, y_data)) plt.close(ax.get_figure()) - @unittest.skipIf(not PLOTTING_POSSIBLE, 'Plotting impossible because matplotlib or display missing') + @pytest.mark.skipif(not PLOTTING_POSSIBLE, reason='Plotting impossible because matplotlib or display missing') def test_voltage_trace_from_device(self): """Test voltage_trace from device""" - import nest.voltage_trace as nvtrace + import nest.voltage_trace nest.ResetKernel() nodes = nest.Create('iaf_psc_alpha', 2) pg = nest.Create('poisson_generator', 1, {'rate': 1000.}) @@ -105,10 +106,11 @@ def test_voltage_trace_from_device(self): nest.Simulate(100) # Test with data from device + plt.close("all") nest.voltage_trace.from_device(device) self.voltage_trace_verify(device) - # Test with fata from file + # Test with data from file vm = device.get('events') data = np.zeros([len(vm['senders']), 3]) data[:, 0] = vm['senders'] @@ -117,6 +119,8 @@ def test_voltage_trace_from_device(self): filename = os.path.join(self.nest_tmpdir(), 'voltage_trace.txt') self.filenames.append(filename) np.savetxt(filename, data) + + plt.close("all") nest.voltage_trace.from_file(filename) self.voltage_trace_verify(device) @@ -138,19 +142,19 @@ def spike_recorder_data_setup(self, to_file=False): return sr def spike_recorder_raster_verify(self, sr_ref): - self.assertIsNotNone(plt._pylab_helpers.Gcf.get_active(), 'No active figure') + assert plt._pylab_helpers.Gcf.get_active() is not None, 'No active figure' fig = plt.gcf() axs = fig.get_axes() x_data, y_data = axs[0].lines[0].get_data() plt.close(fig) # Have to use isclose() because of round-off errors - self.assertEqual(x_data.shape, sr_ref.shape) - self.assertTrue(all(np.isclose(x_data, sr_ref))) + assert x_data.shape == sr_ref.shape + assert all(np.isclose(x_data, sr_ref)) - @unittest.skipIf(not PLOTTING_POSSIBLE, 'Plotting impossible because matplotlib or display missing') + @pytest.mark.skipif(not PLOTTING_POSSIBLE, reason='Plotting impossible because matplotlib or display missing') def test_raster_plot(self): """Test raster_plot""" - import nest.raster_plot as nraster + import nest.raster_plot sr, sr_to_file = self.spike_recorder_data_setup(to_file=True) spikes = sr.get('events') @@ -186,17 +190,7 @@ def test_raster_plot(self): all_extracted = nest.raster_plot.extract_events(data) times_30_to_40_extracted = nest.raster_plot.extract_events(data, time=[30., 40.], sel=[3]) source_2_extracted = nest.raster_plot.extract_events(data, sel=[2]) - self.assertTrue(np.array_equal(all_extracted, data)) - self.assertTrue(np.all(times_30_to_40_extracted[:, 1] >= 30.)) - self.assertTrue(np.all(times_30_to_40_extracted[:, 1] < 40.)) - self.assertEqual(len(source_2_extracted), 0) - - -def suite(): - suite = unittest.makeSuite(VisualizationTestCase, 'test') - return suite - - -if __name__ == "__main__": - runner = unittest.TextTestRunner(verbosity=2) - runner.run(suite()) + assert np.array_equal(all_extracted, data) + assert np.all(times_30_to_40_extracted[:, 1] >= 30.) + assert np.all(times_30_to_40_extracted[:, 1] < 40.) + assert len(source_2_extracted) == 0