Skip to content

Commit

Permalink
fix bugs of brainpy.reset_level()
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Dec 15, 2023
1 parent ecb7bde commit a443878
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 7 deletions.
19 changes: 12 additions & 7 deletions brainpy/_src/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,22 @@ def reset_state(target: DynamicalSystem, *args, **kwargs):
Args:
target: The target DynamicalSystem.
"""
nodes = list(target.nodes().subset(DynamicalSystem).not_subset(DynView).not_subset(IonChaDyn).unique().values())
# assign the 'reset_level' to each reset state function
for node in nodes:
if not hasattr(node.reset_state, 'reset_level'):
node.reset_state.reset_level = 0

dynsys.the_top_layer_reset_state = False

try:
nodes = list(target.nodes().subset(DynamicalSystem).not_subset(DynView).not_subset(IonChaDyn).unique().values())
nodes_with_level = []

# reset node whose `reset_state` has no `reset_level`
for node in nodes:
if not hasattr(node.reset_state, 'reset_level'):
node.reset_state(*args, **kwargs)
else:
nodes_with_level.append(node)

# reset the node's states
for l in range(_max_level):
for node in nodes:
for node in nodes_with_level:
if node.reset_state.reset_level == l:
node.reset_state(*args, **kwargs)

Expand Down
30 changes: 30 additions & 0 deletions brainpy/_src/tests/test_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import brainpy as bp

import unittest


class TestResetLevel(unittest.TestCase):

def test1(self):
class Level0(bp.DynamicalSystem):
@bp.reset_level(0)
def reset_state(self, *args, **kwargs):
print('Level 0')

class Level1(bp.DynamicalSystem):
@bp.reset_level(1)
def reset_state(self, *args, **kwargs):
print('Level 1')

class Net(bp.DynamicalSystem):
def __init__(self):
super().__init__()
self.l0 = Level0()
self.l1 = Level1()
self.l0_2 = Level0()
self.l1_2 = Level1()

net = Net()
net.reset()


0 comments on commit a443878

Please sign in to comment.