Skip to content

Commit

Permalink
update doc, upgrade reset_state, update projection models (#592)
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Jan 13, 2024
1 parent 947cc74 commit 02b85b2
Show file tree
Hide file tree
Showing 15 changed files with 479 additions and 289 deletions.
2 changes: 1 addition & 1 deletion brainpy/_src/dnn/interoperation_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def initialize_carry(self, rng, batch_dims, size=None, init_fn=None):
raise NotImplementedError

_state_vars = self.model.vars().unique().not_subset(bm.TrainVar)
self.model.reset_state(batch_size=batch_dims)
self.model.reset(batch_size=batch_dims)
return [_state_vars.dict(), 0, 0.]

def setup(self):
Expand Down
17 changes: 17 additions & 0 deletions brainpy/_src/dyn/projections/align_post.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,10 @@ def update(self, x):
self.refs['syn'].add_current(current) # synapse post current
return current

syn = property(lambda self: self.refs['syn'])
out = property(lambda self: self.refs['out'])
post = property(lambda self: self.refs['post'])


class FullProjAlignPostMg(Projection):
"""Full-chain synaptic projection with the align-post reduction and the automatic synapse merging.
Expand Down Expand Up @@ -270,6 +274,12 @@ def update(self):
self.refs['syn'].add_current(current) # synapse post current
return current

syn = property(lambda self: self.refs['syn'])
out = property(lambda self: self.refs['out'])
delay = property(lambda self: self.refs['delay'])
pre = property(lambda self: self.refs['pre'])
post = property(lambda self: self.refs['post'])


class HalfProjAlignPost(Projection):
"""Defining the half-part of synaptic projection with the align-post reduction.
Expand Down Expand Up @@ -363,6 +373,8 @@ def update(self, x):
self.refs['out'].bind_cond(g) # synapse post current
return current

post = property(lambda self: self.refs['post'])


class FullProjAlignPost(Projection):
"""Full-chain synaptic projection with the align-post reduction.
Expand Down Expand Up @@ -488,3 +500,8 @@ def update(self):
g = self.syn(self.comm(x))
self.refs['out'].bind_cond(g) # synapse post current
return g

delay = property(lambda self: self.refs['delay'])
pre = property(lambda self: self.refs['pre'])
post = property(lambda self: self.refs['post'])
out = property(lambda self: self.refs['out'])
23 changes: 23 additions & 0 deletions brainpy/_src/dyn/projections/align_pre.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,12 @@ def update(self, x=None):
self.refs['out'].bind_cond(current)
return current

pre = property(lambda self: self.refs['pre'])
post = property(lambda self: self.refs['post'])
syn = property(lambda self: self.refs['syn'])
delay = property(lambda self: self.refs['delay'])
out = property(lambda self: self.refs['out'])


class FullProjAlignPreDSMg(Projection):
"""Full-chain synaptic projection with the align-pre reduction and delay+synapse updating and merging.
Expand Down Expand Up @@ -326,6 +332,11 @@ def update(self):
self.refs['out'].bind_cond(current)
return current

pre = property(lambda self: self.refs['pre'])
post = property(lambda self: self.refs['post'])
syn = property(lambda self: self.refs['syn'])
out = property(lambda self: self.refs['out'])


class FullProjAlignPreSD(Projection):
"""Full-chain synaptic projection with the align-pre reduction and synapse+delay updating.
Expand Down Expand Up @@ -454,6 +465,12 @@ def update(self, x=None):
self.refs['out'].bind_cond(current)
return current

pre = property(lambda self: self.refs['pre'])
post = property(lambda self: self.refs['post'])
syn = property(lambda self: self.refs['syn'])
delay = property(lambda self: self.refs['delay'])
out = property(lambda self: self.refs['out'])


class FullProjAlignPreDS(Projection):
"""Full-chain synaptic projection with the align-pre reduction and delay+synapse updating.
Expand Down Expand Up @@ -581,3 +598,9 @@ def update(self):
g = self.comm(self.syn(spk))
self.refs['out'].bind_cond(g)
return g

pre = property(lambda self: self.refs['pre'])
post = property(lambda self: self.refs['post'])
delay = property(lambda self: self.refs['delay'])
out = property(lambda self: self.refs['out'])

6 changes: 6 additions & 0 deletions brainpy/_src/dyn/projections/plasticity.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,12 @@ def __init__(
self.A1 = A1
self.A2 = A2

pre = property(lambda self: self.refs['pre'])
post = property(lambda self: self.refs['post'])
syn = property(lambda self: self.refs['syn'])
delay = property(lambda self: self.refs['delay'])
out = property(lambda self: self.refs['out'])

def update(self):
# pre-synaptic spikes
pre_spike = self.refs['delay'].at(self.name) # spike
Expand Down

0 comments on commit 02b85b2

Please sign in to comment.