From 17a322860b2c98c01e96bcbef2213635ce405f88 Mon Sep 17 00:00:00 2001 From: chaoming Date: Fri, 1 Jul 2022 20:14:25 +0800 Subject: [PATCH 1/2] fix bugs --- brainpy/dyn/synapses/abstract_models.py | 16 ++++++++-------- brainpy/dyn/synapses/biological_models.py | 8 ++++---- brainpy/train/back_propagation.py | 16 ++++------------ .../training/SurrogateGrad_lif_fashion_mnist.py | 2 +- 4 files changed, 17 insertions(+), 25 deletions(-) diff --git a/brainpy/dyn/synapses/abstract_models.py b/brainpy/dyn/synapses/abstract_models.py index f015c973a..bd2b31f0a 100644 --- a/brainpy/dyn/synapses/abstract_models.py +++ b/brainpy/dyn/synapses/abstract_models.py @@ -128,7 +128,7 @@ def __init__( def reset_state(self, batch_size=None): self.output.reset_state(batch_size) - self.stp.reset_state(batch_size) + if self.stp is not None: self.stp.reset_state(batch_size) def update(self, tdi, pre_spike=None): # pre-synaptic spikes @@ -140,7 +140,7 @@ def update(self, tdi, pre_spike=None): # update sub-components self.output.update(tdi) - self.stp.update(tdi, pre_spike) + if self.stp is not None: self.stp.update(tdi, pre_spike) # synaptic values onto the post if isinstance(self.conn, All2All): @@ -312,7 +312,7 @@ def __init__( def reset_state(self, batch_size=None): self.g.value = variable(bm.zeros, batch_size, self.post.num) self.output.reset_state(batch_size) - self.stp.reset_state(batch_size) + if self.stp is not None: self.stp.reset_state(batch_size) def update(self, tdi, pre_spike=None): t, dt = tdi['t'], tdi['dt'] @@ -326,7 +326,7 @@ def update(self, tdi, pre_spike=None): # update sub-components self.output.update(tdi) - self.stp.update(tdi, pre_spike) + if self.stp is not None: self.stp.update(tdi, pre_spike) # post values if isinstance(self.conn, All2All): @@ -503,7 +503,7 @@ def reset_state(self, batch_size=None): self.h.value = variable(bm.zeros, batch_size, self.pre.num) self.g.value = variable(bm.zeros, batch_size, self.pre.num) self.output.reset_state(batch_size) - self.stp.reset_state(batch_size) + if self.stp is not None: self.stp.reset_state(batch_size) def dh(self, h, t): return -h / self.tau_rise @@ -523,7 +523,7 @@ def update(self, tdi, pre_spike=None): # update sub-components self.output.update(tdi) - self.stp.update(tdi, pre_spike) + if self.stp is not None: self.stp.update(tdi, pre_spike) # update synaptic variables self.g.value, self.h.value = self.integral(self.g, self.h, t, dt) @@ -912,7 +912,7 @@ def reset_state(self, batch_size=None): self.g.value = variable(bm.zeros, batch_size, self.pre.num) self.x.value = variable(bm.zeros, batch_size, self.pre.num) self.output.reset_state(batch_size) - self.stp.reset_state(batch_size) + if self.stp is not None: self.stp.reset_state(batch_size) def update(self, tdi, pre_spike=None): t, dt = tdi['t'], tdi['dt'] @@ -924,8 +924,8 @@ def update(self, tdi, pre_spike=None): pre_spike = stop_gradient(pre_spike) # update sub-components - self.stp.update(tdi, pre_spike) self.output.update(tdi) + if self.stp is not None: self.stp.update(tdi, pre_spike) # update synapse variables self.g.value, self.x.value = self.integral(self.g, self.x, t, dt=dt) diff --git a/brainpy/dyn/synapses/biological_models.py b/brainpy/dyn/synapses/biological_models.py index cfa6ef26c..2b7c4d581 100644 --- a/brainpy/dyn/synapses/biological_models.py +++ b/brainpy/dyn/synapses/biological_models.py @@ -202,7 +202,7 @@ def reset_state(self, batch_size=None): self.g = variable(bm.zeros, batch_size, self.pre.num) self.spike_arrival_time = variable(lambda s: bm.ones(s) * -1e7, batch_size, self.pre.num) self.output.reset_state(batch_size) - self.stp.reset_state(batch_size) + if self.stp is not None: self.stp.reset_state(batch_size) def dg(self, g, t, TT): dg = self.alpha * TT * (1 - g) - self.beta * g @@ -220,7 +220,7 @@ def update(self, tdi, pre_spike=None): # update sub-components self.output.update(tdi) - self.stp.update(tdi, pre_spike) + if self.stp is not None: self.stp.update(tdi, pre_spike) # update synaptic variables self.spike_arrival_time.value = bm.where(pre_spike, t, self.spike_arrival_time) @@ -553,8 +553,8 @@ def reset_state(self, batch_size=None): self.g = variable(bm.zeros, batch_size, self.pre.num) self.x = variable(bm.zeros, batch_size, self.pre.num) self.spike_arrival_time = variable(lambda s: bm.ones(s) * -1e7, batch_size, self.pre.num) - self.stp.reset_state(batch_size) self.output.reset_state(batch_size) + if self.stp is not None: self.stp.reset_state(batch_size) def dg(self, g, t, x): return self.alpha1 * x * (1 - g) - self.beta1 * g @@ -574,7 +574,7 @@ def update(self, tdi, pre_spike=None): # update sub-components self.output.update(tdi) - self.stp.update(tdi, pre_spike) + if self.stp is not None: self.stp.update(tdi, pre_spike) # update synapse variables self.spike_arrival_time.value = bm.where(pre_spike, t, self.spike_arrival_time) diff --git a/brainpy/train/back_propagation.py b/brainpy/train/back_propagation.py index 497c3c1fc..0517ddb22 100644 --- a/brainpy/train/back_propagation.py +++ b/brainpy/train/back_propagation.py @@ -232,20 +232,12 @@ def fit( print(msg) t0 = t1 - # # testing set - # if test_data is not None: - # test_data_ = self._get_batchable_data(test_data, batch_size, False) - # for x, y in test_data_: - # if reset_state: - # self.target.reset_state(self._get_batch_size(x)) - # self.reset_state() - # loss = self.f_loss(shared_args)(x, y) - # all_test_losses.append(loss) - # finally self._train_losses = bm.asarray(all_train_losses) - self._train_loss_aux = {k: bm.asarray(v) for k, v in all_train_loss_aux.items()} - # self._test_losses = bm.asarray(all_test_losses) + if all_train_loss_aux is None: + self._train_loss_aux = dict() + else: + self._train_loss_aux = {k: bm.asarray(v) for k, v in all_train_loss_aux.items()} self.progress_bar = true_progress_bar def _get_batchable_data(self, data, num_batch, shuffle=False): diff --git a/examples/training/SurrogateGrad_lif_fashion_mnist.py b/examples/training/SurrogateGrad_lif_fashion_mnist.py index dc6008e68..66730a0df 100644 --- a/examples/training/SurrogateGrad_lif_fashion_mnist.py +++ b/examples/training/SurrogateGrad_lif_fashion_mnist.py @@ -197,7 +197,7 @@ def get_mini_batch_results(model, x_data, y_data, batch_size=128, nb_steps=100, net = SNN(num_in=num_input, num_rec=100, num_out=10) # load the dataset -root = "some_path/" +root = r"E:\data\fashion-mnist" train_dataset = bp.datasets.FashionMNIST(root, train=True, transform=None, From 3b653247bddd07fd753bb191521c7dc99f6e8c7c Mon Sep 17 00:00:00 2001 From: chaoming Date: Fri, 1 Jul 2022 20:22:44 +0800 Subject: [PATCH 2/2] fix bugs --- brainpy/dyn/synapses/abstract_models.py | 15 ++++++++++----- brainpy/dyn/synapses/biological_models.py | 6 ++++-- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/brainpy/dyn/synapses/abstract_models.py b/brainpy/dyn/synapses/abstract_models.py index bd2b31f0a..92f879dea 100644 --- a/brainpy/dyn/synapses/abstract_models.py +++ b/brainpy/dyn/synapses/abstract_models.py @@ -330,10 +330,12 @@ def update(self, tdi, pre_spike=None): # post values if isinstance(self.conn, All2All): - syn_value = self.stp(bm.asarray(pre_spike, dtype=bm.dftype())) + syn_value = bm.asarray(pre_spike, dtype=bm.dftype()) + if self.stp is not None: syn_value = self.stp(syn_value) post_vs = self.syn2post_with_all2all(syn_value, self.g_max) elif isinstance(self.conn, One2One): - syn_value = self.stp(bm.asarray(pre_spike, dtype=bm.dftype())) + syn_value = bm.asarray(pre_spike, dtype=bm.dftype()) + if self.stp is not None: syn_value = self.stp(syn_value) post_vs = self.syn2post_with_one2one(syn_value, self.g_max) else: if self.comp_method == 'sparse': @@ -343,7 +345,8 @@ def update(self, tdi, pre_spike=None): # if not isinstance(self.stp, _NullSynSTP): # raise NotImplementedError() else: - syn_value = self.stp(bm.asarray(pre_spike, dtype=bm.dftype())) + syn_value = bm.asarray(pre_spike, dtype=bm.dftype()) + if self.stp is not None: syn_value = self.stp(syn_value) post_vs = self.syn2post_with_dense(syn_value, self.g_max, self.conn_mask) # updates self.g.value = self.integral(self.g.value, t, dt) + post_vs @@ -530,7 +533,8 @@ def update(self, tdi, pre_spike=None): self.h += pre_spike # post values - syn_value = self.stp(self.g) + syn_value = self.g.value + if self.stp is not None: syn_value = self.stp(syn_value) if isinstance(self.conn, All2All): post_vs = self.syn2post_with_all2all(syn_value, self.g_max) elif isinstance(self.conn, One2One): @@ -932,7 +936,8 @@ def update(self, tdi, pre_spike=None): self.x += pre_spike # post-synaptic value - syn_value = self.stp(self.g) + syn_value = self.g.value + if self.stp is not None: syn_value = self.stp(syn_value) if isinstance(self.conn, All2All): post_vs = self.syn2post_with_all2all(syn_value, self.g_max) elif isinstance(self.conn, One2One): diff --git a/brainpy/dyn/synapses/biological_models.py b/brainpy/dyn/synapses/biological_models.py index 2b7c4d581..fa1b7e1cc 100644 --- a/brainpy/dyn/synapses/biological_models.py +++ b/brainpy/dyn/synapses/biological_models.py @@ -230,7 +230,8 @@ def update(self, tdi, pre_spike=None): self.g.value = self.integral(self.g, t, TT, dt) # post-synaptic values - syn_value = self.stp(self.g) + syn_value = self.g.value + if self.stp is not None: syn_value = self.stp(syn_value) if isinstance(self.conn, All2All): post_vs = self.syn2post_with_all2all(syn_value, self.g_max) elif isinstance(self.conn, One2One): @@ -584,7 +585,8 @@ def update(self, tdi, pre_spike=None): self.g.value, self.x.value = self.integral(self.g, self.x, t, T, dt) # post-synaptic value - syn_value = self.stp(self.g) + syn_value = self.g.value + if self.stp is not None: syn_value = self.stp(syn_value) if isinstance(self.conn, All2All): post_vs = self.syn2post_with_all2all(syn_value, self.g_max) elif isinstance(self.conn, One2One):