Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions brainpy/dyn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,18 @@ def reset(self):
"""
raise NotImplementedError('Must implement "reset" function by subclass self.')

def update_local_delays(self):
# update delays
for node in self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique().values():
for name in node.local_delay_vars.keys():
self.global_delay_vars[name].update(self.global_delay_targets[name].value)

def reset_local_delays(self):
# reset delays
for node in self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique().values():
for name in node.local_delay_vars.keys():
self.global_delay_vars[name].reset(self.global_delay_targets[name])


class Container(DynamicalSystem):
"""Container object which is designed to add other instances of DynamicalSystem.
Expand Down Expand Up @@ -725,7 +737,7 @@ def __init__(

def derivative(self, V, t):
Iext = self.input.value * (1e-3 / self.A)
channels = self.nodes(level=1, include_self=False).unique().subset(Channel)
channels = self.nodes(level=1, include_self=False).subset(Channel).unique()
for ch in channels.values():
Iext = Iext + ch.current(V)
return Iext / self.C
Expand All @@ -737,7 +749,7 @@ def reset(self):

def update(self, t, dt):
V = self.integral(self.V.value, t, dt)
channels = self.nodes(level=1, include_self=False).unique().subset(Channel)
channels = self.nodes(level=1, include_self=False).subset(Channel).unique()
for node in channels.values():
node.update(t, dt, self.V.value)
self.spike.value = bm.logical_and(V >= self.V_th, self.V < self.V_th)
Expand Down
10 changes: 5 additions & 5 deletions brainpy/train/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import inspect
from typing import Union, Callable, Optional, Dict

from brainpy.train.algorithms import OfflineAlgorithm, OnlineAlgorithm
from brainpy.dyn.base import DynamicalSystem
from brainpy.train.algorithms import OfflineAlgorithm, OnlineAlgorithm
from brainpy.types import Tensor

__all__ = [
Expand Down Expand Up @@ -78,9 +78,9 @@ def reset(self, batch_size=1):
for node in self.nodes(level=1, include_self=False).unique().subset(TrainingSystem).values():
node.reset(batch_size=batch_size)

def reset_batch_state(self, batch_size=1):
def reset_state(self, batch_size=1):
for node in self.nodes(level=1, include_self=False).unique().subset(TrainingSystem).values():
node.reset_batch_state(batch_size=batch_size)
node.reset_state(batch_size=batch_size)

@not_implemented
def online_init(self):
Expand All @@ -95,7 +95,8 @@ def offline_init(self):
@not_implemented
def online_fit(self,
target: Tensor,
fit_record: Dict[str, Tensor]):
fit_record: Dict[str, Tensor],
shared_args: Dict = None):
raise NotImplementedError('Subclass must implement online_fit() function when using '
'OnlineTrainer.')

Expand All @@ -108,7 +109,6 @@ def offline_fit(self,
'OfflineTrainer.')



class Sequential(TrainingSystem):
def __init__(self, *modules, name: str = None, **kw_modules):
super(Sequential, self).__init__(name=name, trainable=False)
Expand Down
6 changes: 6 additions & 0 deletions brainpy/train/layers/dropout.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,9 @@ def forward(self, x, shared_args=None):
return bm.where(keep_mask, x / self.prob, 0.)
else:
return x

def reset(self, batch_size=1):
pass

def reset_state(self, batch_size=1):
pass
6 changes: 6 additions & 0 deletions brainpy/train/layers/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,12 @@ def forward(self, x, shared_args=None):
self.fit_record['output'] = res
return res

def reset(self, batch_size=1):
pass

def reset_state(self, batch_size=1):
pass

def online_init(self):
if self.b is None:
num_input = self.num_in
Expand Down
4 changes: 2 additions & 2 deletions brainpy/train/layers/nvar.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,9 @@ def __init__(

def reset(self, batch_size=1):
self.idx[0] = 0
self.reset_batch_state(batch_size)
self.reset_state(batch_size)

def reset_batch_state(self, batch_size=1):
def reset_state(self, batch_size=1):
"""Reset the node state which depends on batch size."""
# To store the last inputs.
# Note, the batch axis is not in the first dimension, so we
Expand Down
4 changes: 2 additions & 2 deletions brainpy/train/layers/recurrents.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ def __init__(self,
self.state[:] = self.state2train

def reset(self, batch_size=1):
self.reset_batch_state(batch_size)
self.reset_state(batch_size)

def reset_batch_state(self, batch_size=1):
def reset_state(self, batch_size=1):
self.state._value = init_param(self._state_initializer, (batch_size, self.num_out), allow_none=False)
if self.train_state:
self.state2train.value = init_param(self._state_initializer, self.num_out, allow_none=False)
Expand Down
3 changes: 3 additions & 0 deletions brainpy/train/layers/reservoir.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,9 @@ def __init__(
def reset(self, batch_size=1):
self.state._value = bm.zeros((batch_size,) + self.output_shape).value

def reset_state(self, batch_size=1):
pass

def forward(self, x, shared_args=None):
"""Feedforward output."""
# inputs
Expand Down
6 changes: 3 additions & 3 deletions brainpy/train/runners/back_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def fit(
# training set
for x, y in train_data_:
if reset_state:
self.target.reset_batch_state(check_data_batch_size(x))
self.target.reset_state(check_data_batch_size(x))
loss = self.f_train(shared_args)(x, y)
all_train_losses.append(loss)
train_i += 1
Expand All @@ -194,7 +194,7 @@ def fit(
if test_data_ is not None:
for x, y in test_data_:
if reset_state:
self.target.reset_batch_state(check_data_batch_size(x))
self.target.reset_state(check_data_batch_size(x))
loss = self.f_loss(shared_args)(x, y)
all_test_losses.append(loss)

Expand Down Expand Up @@ -399,7 +399,7 @@ def predict(
num_batch = self._get_xs_batch_size(xs)
# reset the model states
if reset_state:
self.target.reset_batch_state(num_batch)
self.target.reset_state(num_batch)
# init monitor
for key in self.mon.item_contents.keys():
self.mon.item_contents[key] = [] # reshape the monitor items
Expand Down
19 changes: 19 additions & 0 deletions docs/apis/compat_nn.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
``brainpy.compat.nn`` module
============================

.. currentmodule:: brainpy.compat.nn
.. automodule:: brainpy.compat.nn


.. toctree::
:maxdepth: 1

auto/compat/nn_base
auto/compat/nn_operations
auto/compat/nn_graph_flow
auto/compat/nn_runners
auto/compat/nn_algorithms
auto/compat/nn_data_types
auto/compat/nn_nodes_base
auto/compat/nn_nodes_ANN
auto/compat/nn_nodes_RC
19 changes: 0 additions & 19 deletions docs/apis/nn.rst

This file was deleted.

44 changes: 44 additions & 0 deletions docs/auto_generater.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,11 +588,55 @@ def generate_compact_docs(path='apis/auto/compat/'):
filename=os.path.join(path, 'runners.rst'),
header='Runners')

write_module(module_name='brainpy.compat.nn.base',
filename=os.path.join(path, 'nn_base.rst'),
header='Base Classes')
write_module(module_name='brainpy.compat.nn.operations',
filename=os.path.join(path, 'nn_operations.rst'),
header='Node Operations')
write_module(module_name='brainpy.compat.nn.graph_flow',
filename=os.path.join(path, 'nn_graph_flow.rst'),
header='Node Graph Tools')
write_module(module_name='brainpy.compat.nn.datatypes',
filename=os.path.join(path, 'nn_data_types.rst'),
header='Data Types')
module_and_name = [
('rnn_runner', 'Base RNN Runner'),
('rnn_trainer', 'Base RNN Trainer'),
('online_trainer', 'Online RNN Trainer'),
('offline_trainer', 'Offline RNN Trainer'),
('back_propagation', 'Back-propagation Trainer'),
]
write_submodules(module_name='brainpy.compat.nn.runners',
filename=os.path.join(path, 'nn_runners.rst'),
header='Runners and Trainers',
submodule_names=[k[0] for k in module_and_name],
section_names=[k[1] for k in module_and_name])
module_and_name = [
('online', 'Online Training Algorithms'),
('offline', 'Offline Training Algorithms'),
]
write_submodules(module_name='brainpy.compat.nn.algorithms',
filename=os.path.join(path, 'nn_algorithms.rst'),
header='Training Algorithms',
submodule_names=[k[0] for k in module_and_name],
section_names=[k[1] for k in module_and_name])
write_module(module_name='brainpy.compat.nn.nodes.base',
filename=os.path.join(path, 'nn_nodes_base.rst'),
header='Nodes: basic')
write_module(module_name='brainpy.compat.nn.nodes.ANN',
filename=os.path.join(path, 'nn_nodes_ANN.rst'),
header='Nodes: artificial neural network ')
write_module(module_name='brainpy.compat.nn.nodes.RC',
filename=os.path.join(path, 'nn_nodes_RC.rst'),
header='Nodes: reservoir computing')


def generate_math_compact_docs(path='apis/auto/math/'):
if not os.path.exists(path):
os.makedirs(path)


write_module(module_name='brainpy.math.compat.optimizers',
filename=os.path.join(path, 'optimizers.rst'),
header='Optimizers')
Expand Down
2 changes: 1 addition & 1 deletion examples/training/Gauthier_2021_ngrc_lorenz.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def plot_lorenz(ground_truth, predictions):

# Model #
# ----- #
class NGRC(bp.train.TrainNet):
class NGRC(bp.train.TrainingSystem):
def __init__(self, num_in):
super(NGRC, self).__init__()
self.r = bp.train.NVAR(num_in, delay=2, order=2, constant=True)
Expand Down
2 changes: 1 addition & 1 deletion examples/training/echo_state_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def ngrc(num_in=10, num_out=30):
def ngrc_bacth(num_in=10, num_out=30):
model = NGRC(num_in, num_out)
batch_size = 10
model.reset_batch_state(batch_size)
model.reset_state(batch_size)
X = bm.random.random((batch_size, 200, num_in))
Y = bm.random.random((batch_size, 200, num_out))
trainer = bp.train.RidgeTrainer(model, beta=1e-6)
Expand Down
2 changes: 1 addition & 1 deletion examples/training/integrator_rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def loss(predictions, targets, l2_reg=2e-4):
plt.plot(trainer.train_losses.numpy())
plt.show()

model.reset_batch_state(1)
model.reset_state(1)
x, y = build_inputs_and_targets(batch_size=1)
predicts = trainer.predict(x)

Expand Down