31 changes: 18 additions & 13 deletions brainpy/dyn/synapses/abstract_models.py
@@ -128,7 +128,7 @@ def __init__(

def reset_state(self, batch_size=None):
self.output.reset_state(batch_size)
- self.stp.reset_state(batch_size)
+ if self.stp is not None: self.stp.reset_state(batch_size)

def update(self, tdi, pre_spike=None):
# pre-synaptic spikes
@@ -140,7 +140,7 @@ def update(self, tdi, pre_spike=None):

# update sub-components
self.output.update(tdi)
- self.stp.update(tdi, pre_spike)
+ if self.stp is not None: self.stp.update(tdi, pre_spike)

# synaptic values onto the post
if isinstance(self.conn, All2All):
@@ -312,7 +312,7 @@ def __init__(
def reset_state(self, batch_size=None):
self.g.value = variable(bm.zeros, batch_size, self.post.num)
self.output.reset_state(batch_size)
- self.stp.reset_state(batch_size)
+ if self.stp is not None: self.stp.reset_state(batch_size)

def update(self, tdi, pre_spike=None):
t, dt = tdi['t'], tdi['dt']
@@ -326,14 +326,16 @@ def update(self, tdi, pre_spike=None):

# update sub-components
self.output.update(tdi)
- self.stp.update(tdi, pre_spike)
+ if self.stp is not None: self.stp.update(tdi, pre_spike)

# post values
if isinstance(self.conn, All2All):
- syn_value = self.stp(bm.asarray(pre_spike, dtype=bm.dftype()))
+ syn_value = bm.asarray(pre_spike, dtype=bm.dftype())
+ if self.stp is not None: syn_value = self.stp(syn_value)
post_vs = self.syn2post_with_all2all(syn_value, self.g_max)
elif isinstance(self.conn, One2One):
- syn_value = self.stp(bm.asarray(pre_spike, dtype=bm.dftype()))
+ syn_value = bm.asarray(pre_spike, dtype=bm.dftype())
+ if self.stp is not None: syn_value = self.stp(syn_value)
post_vs = self.syn2post_with_one2one(syn_value, self.g_max)
else:
if self.comp_method == 'sparse':
@@ -343,7 +345,8 @@ def update(self, tdi, pre_spike=None):
# if not isinstance(self.stp, _NullSynSTP):
# raise NotImplementedError()
else:
- syn_value = self.stp(bm.asarray(pre_spike, dtype=bm.dftype()))
+ syn_value = bm.asarray(pre_spike, dtype=bm.dftype())
+ if self.stp is not None: syn_value = self.stp(syn_value)
post_vs = self.syn2post_with_dense(syn_value, self.g_max, self.conn_mask)
# updates
self.g.value = self.integral(self.g.value, t, dt) + post_vs
@@ -503,7 +506,7 @@ def reset_state(self, batch_size=None):
self.h.value = variable(bm.zeros, batch_size, self.pre.num)
self.g.value = variable(bm.zeros, batch_size, self.pre.num)
self.output.reset_state(batch_size)
- self.stp.reset_state(batch_size)
+ if self.stp is not None: self.stp.reset_state(batch_size)

def dh(self, h, t):
return -h / self.tau_rise
@@ -523,14 +526,15 @@ def update(self, tdi, pre_spike=None):

# update sub-components
self.output.update(tdi)
- self.stp.update(tdi, pre_spike)
+ if self.stp is not None: self.stp.update(tdi, pre_spike)

# update synaptic variables
self.g.value, self.h.value = self.integral(self.g, self.h, t, dt)
self.h += pre_spike

# post values
- syn_value = self.stp(self.g)
+ syn_value = self.g.value
+ if self.stp is not None: syn_value = self.stp(syn_value)
if isinstance(self.conn, All2All):
post_vs = self.syn2post_with_all2all(syn_value, self.g_max)
elif isinstance(self.conn, One2One):
@@ -912,7 +916,7 @@ def reset_state(self, batch_size=None):
self.g.value = variable(bm.zeros, batch_size, self.pre.num)
self.x.value = variable(bm.zeros, batch_size, self.pre.num)
self.output.reset_state(batch_size)
- self.stp.reset_state(batch_size)
+ if self.stp is not None: self.stp.reset_state(batch_size)

def update(self, tdi, pre_spike=None):
t, dt = tdi['t'], tdi['dt']
@@ -924,15 +928,16 @@ def update(self, tdi, pre_spike=None):
pre_spike = stop_gradient(pre_spike)

# update sub-components
- self.stp.update(tdi, pre_spike)
self.output.update(tdi)
+ if self.stp is not None: self.stp.update(tdi, pre_spike)

# update synapse variables
self.g.value, self.x.value = self.integral(self.g, self.x, t, dt=dt)
self.x += pre_spike

# post-synaptic value
- syn_value = self.stp(self.g)
+ syn_value = self.g.value
+ if self.stp is not None: syn_value = self.stp(syn_value)
if isinstance(self.conn, All2All):
post_vs = self.syn2post_with_all2all(syn_value, self.g_max)
elif isinstance(self.conn, One2One):
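The recurring edit in this file is a null-guard: the short-term plasticity (STP) component is now optional, so every `reset_state`, `update`, and call site checks `self.stp is not None` before touching it. Below is a minimal, self-contained sketch of the pattern; `SimpleSTP` and `SynapseSketch` are illustrative stand-ins, not BrainPy classes.

```python
# Sketch of the optional-STP guard pattern; names are illustrative only.
import numpy as np


class SimpleSTP:
  """Toy short-term plasticity: scales synaptic values by a running efficacy."""

  def __init__(self, num, tau=200.):
    self.num, self.tau = num, tau
    self.u = np.zeros(num)

  def reset_state(self, batch_size=None):
    self.u = np.zeros(self.num)

  def update(self, tdi, pre_spike):
    # decay toward baseline, then facilitate where spikes arrived
    self.u += (-self.u / self.tau) * tdi['dt'] + 0.1 * pre_spike

  def __call__(self, syn_value):
    return syn_value * (1. + self.u)


class SynapseSketch:
  def __init__(self, num, stp=None):
    self.num = num
    self.stp = stp  # may be None: plasticity is optional

  def reset_state(self, batch_size=None):
    if self.stp is not None: self.stp.reset_state(batch_size)

  def update(self, tdi, pre_spike):
    if self.stp is not None: self.stp.update(tdi, pre_spike)
    syn_value = pre_spike.astype(float)
    # guard the call site too: only modulate when plasticity exists
    if self.stp is not None: syn_value = self.stp(syn_value)
    return syn_value


spikes = np.array([1., 0., 1.])
tdi = dict(t=0., dt=0.1)
plain = SynapseSketch(3)                      # stp=None: plain synapse, no crash
plastic = SynapseSketch(3, stp=SimpleSTP(3))  # same code path with plasticity
print(plain.update(tdi, spikes))
print(plastic.update(tdi, spikes))
```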
14 changes: 8 additions & 6 deletions brainpy/dyn/synapses/biological_models.py
@@ -202,7 +202,7 @@ def reset_state(self, batch_size=None):
self.g = variable(bm.zeros, batch_size, self.pre.num)
self.spike_arrival_time = variable(lambda s: bm.ones(s) * -1e7, batch_size, self.pre.num)
self.output.reset_state(batch_size)
- self.stp.reset_state(batch_size)
+ if self.stp is not None: self.stp.reset_state(batch_size)

def dg(self, g, t, TT):
dg = self.alpha * TT * (1 - g) - self.beta * g
@@ -220,7 +220,7 @@ def update(self, tdi, pre_spike=None):

# update sub-components
self.output.update(tdi)
- self.stp.update(tdi, pre_spike)
+ if self.stp is not None: self.stp.update(tdi, pre_spike)

# update synaptic variables
self.spike_arrival_time.value = bm.where(pre_spike, t, self.spike_arrival_time)
@@ -230,7 +230,8 @@ def update(self, tdi, pre_spike=None):
self.g.value = self.integral(self.g, t, TT, dt)

# post-synaptic values
- syn_value = self.stp(self.g)
+ syn_value = self.g.value
+ if self.stp is not None: syn_value = self.stp(syn_value)
if isinstance(self.conn, All2All):
post_vs = self.syn2post_with_all2all(syn_value, self.g_max)
elif isinstance(self.conn, One2One):
@@ -553,8 +554,8 @@ def reset_state(self, batch_size=None):
self.g = variable(bm.zeros, batch_size, self.pre.num)
self.x = variable(bm.zeros, batch_size, self.pre.num)
self.spike_arrival_time = variable(lambda s: bm.ones(s) * -1e7, batch_size, self.pre.num)
- self.stp.reset_state(batch_size)
self.output.reset_state(batch_size)
+ if self.stp is not None: self.stp.reset_state(batch_size)

def dg(self, g, t, x):
return self.alpha1 * x * (1 - g) - self.beta1 * g
@@ -574,7 +575,7 @@ def update(self, tdi, pre_spike=None):

# update sub-components
self.output.update(tdi)
- self.stp.update(tdi, pre_spike)
+ if self.stp is not None: self.stp.update(tdi, pre_spike)

# update synapse variables
self.spike_arrival_time.value = bm.where(pre_spike, t, self.spike_arrival_time)
@@ -584,7 +585,8 @@ def update(self, tdi, pre_spike=None):
self.g.value, self.x.value = self.integral(self.g, self.x, t, T, dt)

# post-synaptic value
- syn_value = self.stp(self.g)
+ syn_value = self.g.value
+ if self.stp is not None: syn_value = self.stp(syn_value)
if isinstance(self.conn, All2All):
post_vs = self.syn2post_with_all2all(syn_value, self.g_max)
elif isinstance(self.conn, One2One):
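The `syn_value` rewrite in both synapse files is behavior-preserving when STP is present and simply skips the modulation when it is absent. A quick equivalence check in plain NumPy; `g` and `stp` here are stand-ins for the conductance variable and the plasticity transform, not BrainPy objects:

```python
import numpy as np

g = np.array([0.2, 0.5, 0.0])  # stand-in for self.g.value
stp = lambda v: 0.8 * v        # stand-in for self.stp(...)

# old form, which crashed when stp was None: syn_value = stp(g)
old = stp(g)

# new form: start from the raw conductance, modulate only if stp exists
syn_value = g.copy()
if stp is not None:
  syn_value = stp(syn_value)
assert np.allclose(old, syn_value)  # identical whenever stp is present

stp = None
syn_value = g.copy()
if stp is not None:
  syn_value = stp(syn_value)
print(syn_value)  # falls back to the raw conductance: [0.2 0.5 0. ]
```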
16 changes: 4 additions & 12 deletions brainpy/train/back_propagation.py
@@ -232,20 +232,12 @@ def fit(
print(msg)
t0 = t1

- # # testing set
- # if test_data is not None:
- #   test_data_ = self._get_batchable_data(test_data, batch_size, False)
- #   for x, y in test_data_:
- #     if reset_state:
- #       self.target.reset_state(self._get_batch_size(x))
- #       self.reset_state()
- #     loss = self.f_loss(shared_args)(x, y)
- #     all_test_losses.append(loss)

# finally
self._train_losses = bm.asarray(all_train_losses)
- self._train_loss_aux = {k: bm.asarray(v) for k, v in all_train_loss_aux.items()}
- # self._test_losses = bm.asarray(all_test_losses)
+ if all_train_loss_aux is None:
+   self._train_loss_aux = dict()
+ else:
+   self._train_loss_aux = {k: bm.asarray(v) for k, v in all_train_loss_aux.items()}
self.progress_bar = true_progress_bar

def _get_batchable_data(self, data, num_batch, shuffle=False):
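The `fit()` change applies the same defensive style to loss bookkeeping: `all_train_loss_aux` may never be populated, and calling `.items()` on `None` raises `AttributeError`. A standalone sketch of the aggregation step, with `np.asarray` standing in for `bm.asarray`:

```python
# Sketch of the defensive aux-loss aggregation; simplified stand-in names.
import numpy as np


def collect_losses(all_train_losses, all_train_loss_aux):
  train_losses = np.asarray(all_train_losses)
  if all_train_loss_aux is None:
    # no auxiliary losses were ever reported: return an empty dict
    train_loss_aux = dict()
  else:
    train_loss_aux = {k: np.asarray(v) for k, v in all_train_loss_aux.items()}
  return train_losses, train_loss_aux


# with auxiliary losses
print(collect_losses([0.9, 0.5], {'l2': [0.1, 0.05]}))
# without them: the old code would raise AttributeError on None.items()
print(collect_losses([0.9, 0.5], None))
```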
2 changes: 1 addition & 1 deletion examples/training/SurrogateGrad_lif_fashion_mnist.py
@@ -197,7 +197,7 @@ def get_mini_batch_results(model, x_data, y_data, batch_size=128, nb_steps=100,
net = SNN(num_in=num_input, num_rec=100, num_out=10)

# load the dataset
root = "some_path/"
root = r"E:\data\fashion-mnist"
train_dataset = bp.datasets.FashionMNIST(root,
train=True,
transform=None,
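One caveat on the new `root`: it is a raw string, so the backslashes in the Windows path are kept literally; without the `r` prefix, the `\f` in `\fashion-mnist` would be parsed as a form-feed escape. It is also the committer's local directory, so anyone running the example should point it at their own dataset location, for instance:

```python
# The example hard-codes a local Windows dataset path. Replace it with your
# own location; os.path.join keeps the script portable across platforms.
import os

root = r"E:\data\fashion-mnist"  # raw string: backslashes are literal
assert "\f" not in root          # non-raw "E:\data\fashion-mnist" would contain a form feed
root = os.path.join(os.path.expanduser("~"), "data", "fashion-mnist")
print(root)
```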