diff --git a/brainpy/dyn/base.py b/brainpy/dyn/base.py
index 3a0860c66..11fb1215f 100644
--- a/brainpy/dyn/base.py
+++ b/brainpy/dyn/base.py
@@ -238,6 +238,18 @@ def reset(self):
     """
     raise NotImplementedError('Must implement "reset" function by subclass self.')
 
+  def update_local_delays(self):
+    # update delays
+    for node in self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique().values():
+      for name in node.local_delay_vars.keys():
+        self.global_delay_vars[name].update(self.global_delay_targets[name].value)
+
+  def reset_local_delays(self):
+    # reset delays
+    for node in self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique().values():
+      for name in node.local_delay_vars.keys():
+        self.global_delay_vars[name].reset(self.global_delay_targets[name])
+
 
 class Container(DynamicalSystem):
   """Container object which is designed to add other instances of DynamicalSystem.
@@ -725,7 +737,7 @@ def __init__(
 
   def derivative(self, V, t):
     Iext = self.input.value * (1e-3 / self.A)
-    channels = self.nodes(level=1, include_self=False).unique().subset(Channel)
+    channels = self.nodes(level=1, include_self=False).subset(Channel).unique()
     for ch in channels.values():
       Iext = Iext + ch.current(V)
     return Iext / self.C
@@ -737,7 +749,7 @@ def reset(self):
 
   def update(self, t, dt):
     V = self.integral(self.V.value, t, dt)
-    channels = self.nodes(level=1, include_self=False).unique().subset(Channel)
+    channels = self.nodes(level=1, include_self=False).subset(Channel).unique()
     for node in channels.values():
       node.update(t, dt, self.V.value)
     self.spike.value = bm.logical_and(V >= self.V_th, self.V < self.V_th)
diff --git a/brainpy/train/base.py b/brainpy/train/base.py
index d0c7bd698..1eb1aab37 100644
--- a/brainpy/train/base.py
+++ b/brainpy/train/base.py
@@ -3,8 +3,8 @@
 import inspect
 from typing import Union, Callable, Optional, Dict
 
-from brainpy.train.algorithms import OfflineAlgorithm, OnlineAlgorithm
 from brainpy.dyn.base import DynamicalSystem
+from brainpy.train.algorithms import OfflineAlgorithm, OnlineAlgorithm
 from brainpy.types import Tensor
 
 __all__ = [
@@ -78,9 +78,9 @@ def reset(self, batch_size=1):
     for node in self.nodes(level=1, include_self=False).unique().subset(TrainingSystem).values():
       node.reset(batch_size=batch_size)
 
-  def reset_batch_state(self, batch_size=1):
+  def reset_state(self, batch_size=1):
     for node in self.nodes(level=1, include_self=False).unique().subset(TrainingSystem).values():
-      node.reset_batch_state(batch_size=batch_size)
+      node.reset_state(batch_size=batch_size)
 
   @not_implemented
   def online_init(self):
@@ -95,7 +95,8 @@ def offline_init(self):
   @not_implemented
   def online_fit(self,
                  target: Tensor,
-                 fit_record: Dict[str, Tensor]):
+                 fit_record: Dict[str, Tensor],
+                 shared_args: Dict = None):
     raise NotImplementedError('Subclass must implement online_fit() function when using '
                               'OnlineTrainer.')
 
@@ -108,7 +109,6 @@ def offline_fit(self,
                               'OfflineTrainer.')
 
 
-
 class Sequential(TrainingSystem):
   def __init__(self, *modules, name: str = None, **kw_modules):
     super(Sequential, self).__init__(name=name, trainable=False)
diff --git a/brainpy/train/layers/dropout.py b/brainpy/train/layers/dropout.py
index b4ca7fc5f..e5d9598ae 100644
--- a/brainpy/train/layers/dropout.py
+++ b/brainpy/train/layers/dropout.py
@@ -48,3 +48,9 @@ def forward(self, x, shared_args=None):
       return bm.where(keep_mask, x / self.prob, 0.)
     else:
       return x
+
+  def reset(self, batch_size=1):
+    pass
+
+  def reset_state(self, batch_size=1):
+    pass
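Reviewer note: the two hunks in `brainpy/dyn/base.py` that swap `.unique().subset(Channel)` for `.subset(Channel).unique()` return the same channels but deduplicate after filtering, so `unique()` walks the smaller, type-filtered collection. Below is a minimal sketch of the two operations with dict-like semantics; it is illustrative only, not BrainPy's actual `Collector` implementation:

```python
class Collector(dict):
  """Illustrative stand-in for the node collection returned by self.nodes()."""

  def subset(self, cls):
    # Keep only the entries whose value is an instance of `cls`.
    return Collector((k, v) for k, v in self.items() if isinstance(v, cls))

  def unique(self):
    # Drop duplicate values: the same node may be reachable under two names.
    seen, out = set(), Collector()
    for k, v in self.items():
      if id(v) not in seen:
        seen.add(id(v))
        out[k] = v
    return out
```

Either ordering yields the same result, so the reorder is safe for both `derivative` and `update`.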
diff --git a/brainpy/train/layers/linear.py b/brainpy/train/layers/linear.py
index 0446dc390..977531fdd 100644
--- a/brainpy/train/layers/linear.py
+++ b/brainpy/train/layers/linear.py
@@ -83,6 +83,12 @@ def forward(self, x, shared_args=None):
     self.fit_record['output'] = res
     return res
 
+  def reset(self, batch_size=1):
+    pass
+
+  def reset_state(self, batch_size=1):
+    pass
+
   def online_init(self):
     if self.b is None:
       num_input = self.num_in
diff --git a/brainpy/train/layers/nvar.py b/brainpy/train/layers/nvar.py
index 0d8697c9e..c0cbd1414 100644
--- a/brainpy/train/layers/nvar.py
+++ b/brainpy/train/layers/nvar.py
@@ -117,9 +117,9 @@ def __init__(
 
   def reset(self, batch_size=1):
     self.idx[0] = 0
-    self.reset_batch_state(batch_size)
+    self.reset_state(batch_size)
 
-  def reset_batch_state(self, batch_size=1):
+  def reset_state(self, batch_size=1):
     """Reset the node state which depends on batch size."""
     # To store the last inputs.
     # Note, the batch axis is not in the first dimension, so we
diff --git a/brainpy/train/layers/recurrents.py b/brainpy/train/layers/recurrents.py
index 12de6065a..d48d7a3de 100644
--- a/brainpy/train/layers/recurrents.py
+++ b/brainpy/train/layers/recurrents.py
@@ -43,9 +43,9 @@ def __init__(self,
       self.state[:] = self.state2train
 
   def reset(self, batch_size=1):
-    self.reset_batch_state(batch_size)
+    self.reset_state(batch_size)
 
-  def reset_batch_state(self, batch_size=1):
+  def reset_state(self, batch_size=1):
     self.state._value = init_param(self._state_initializer, (batch_size, self.num_out), allow_none=False)
     if self.train_state:
       self.state2train.value = init_param(self._state_initializer, self.num_out, allow_none=False)
diff --git a/brainpy/train/layers/reservoir.py b/brainpy/train/layers/reservoir.py
index a508fa46e..8d4300ff3 100644
--- a/brainpy/train/layers/reservoir.py
+++ b/brainpy/train/layers/reservoir.py
@@ -199,6 +199,9 @@ def __init__(
   def reset(self, batch_size=1):
     self.state._value = bm.zeros((batch_size,) + self.output_shape).value
 
+  def reset_state(self, batch_size=1):
+    pass
+
   def forward(self, x, shared_args=None):
     """Feedforward output."""
     # inputs
diff --git a/brainpy/train/runners/back_propagation.py b/brainpy/train/runners/back_propagation.py
index 11e6b1483..0a1194946 100644
--- a/brainpy/train/runners/back_propagation.py
+++ b/brainpy/train/runners/back_propagation.py
@@ -180,7 +180,7 @@ def fit(
       # training set
       for x, y in train_data_:
         if reset_state:
-          self.target.reset_batch_state(check_data_batch_size(x))
+          self.target.reset_state(check_data_batch_size(x))
         loss = self.f_train(shared_args)(x, y)
         all_train_losses.append(loss)
         train_i += 1
@@ -194,7 +194,7 @@ def fit(
       if test_data_ is not None:
         for x, y in test_data_:
           if reset_state:
-            self.target.reset_batch_state(check_data_batch_size(x))
+            self.target.reset_state(check_data_batch_size(x))
           loss = self.f_loss(shared_args)(x, y)
           all_test_losses.append(loss)
 
@@ -399,7 +399,7 @@ def predict(
     num_batch = self._get_xs_batch_size(xs)
     # reset the model states
     if reset_state:
-      self.target.reset_batch_state(num_batch)
+      self.target.reset_state(num_batch)
     # init monitor
     for key in self.mon.item_contents.keys():
       self.mon.item_contents[key] = []  # reshape the monitor items
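Reviewer note: after this batch of renames the convention reads cleanly: `reset` restores a module's full state, while `reset_state` re-initializes only the batch-size-dependent state and leaves trained parameters untouched, which is why the trainer hunks above call it once per batch. A hedged sketch of that per-batch pattern; `model`, `train_data`, and `train_step` are placeholders, not names from this diff:

```python
def fit_one_epoch(model, train_data, train_step, reset_state=True):
  # Mirrors the fit loop above: clear/resize batch-dependent buffers
  # before each batch, then run one optimization step on (x, y).
  losses = []
  for x, y in train_data:
    if reset_state:
      model.reset_state(x.shape[0])  # batch axis assumed to come first
    losses.append(train_step(x, y))
  return losses
```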
diff --git a/docs/apis/compat_nn.rst b/docs/apis/compat_nn.rst
new file mode 100644
index 000000000..ad22b3247
--- /dev/null
+++ b/docs/apis/compat_nn.rst
@@ -0,0 +1,19 @@
+``brainpy.compat.nn`` module
+============================
+
+.. currentmodule:: brainpy.compat.nn
+.. automodule:: brainpy.compat.nn
+
+
+.. toctree::
+   :maxdepth: 1
+
+   auto/compat/nn_base
+   auto/compat/nn_operations
+   auto/compat/nn_graph_flow
+   auto/compat/nn_runners
+   auto/compat/nn_algorithms
+   auto/compat/nn_data_types
+   auto/compat/nn_nodes_base
+   auto/compat/nn_nodes_ANN
+   auto/compat/nn_nodes_RC
diff --git a/docs/apis/nn.rst b/docs/apis/nn.rst
deleted file mode 100644
index b83650cbe..000000000
--- a/docs/apis/nn.rst
+++ /dev/null
@@ -1,19 +0,0 @@
-``brainpy.nn`` module
-===========================
-
-.. currentmodule:: brainpy.nn
-.. automodule:: brainpy.nn
-
-
-.. toctree::
-   :maxdepth: 1
-
-   auto/nn/base
-   auto/nn/operations
-   auto/nn/graph_flow
-   auto/nn/runners
-   auto/nn/algorithms
-   auto/nn/data_types
-   auto/nn/nodes_base
-   auto/nn/nodes_ANN
-   auto/nn/nodes_RC
diff --git a/docs/auto_generater.py b/docs/auto_generater.py
index b7b26435a..64a773683 100644
--- a/docs/auto_generater.py
+++ b/docs/auto_generater.py
@@ -588,11 +588,55 @@ def generate_compact_docs(path='apis/auto/compat/'):
                filename=os.path.join(path, 'runners.rst'),
                header='Runners')
 
+  write_module(module_name='brainpy.compat.nn.base',
+               filename=os.path.join(path, 'nn_base.rst'),
+               header='Base Classes')
+  write_module(module_name='brainpy.compat.nn.operations',
+               filename=os.path.join(path, 'nn_operations.rst'),
+               header='Node Operations')
+  write_module(module_name='brainpy.compat.nn.graph_flow',
+               filename=os.path.join(path, 'nn_graph_flow.rst'),
+               header='Node Graph Tools')
+  write_module(module_name='brainpy.compat.nn.datatypes',
+               filename=os.path.join(path, 'nn_data_types.rst'),
+               header='Data Types')
+  module_and_name = [
+    ('rnn_runner', 'Base RNN Runner'),
+    ('rnn_trainer', 'Base RNN Trainer'),
+    ('online_trainer', 'Online RNN Trainer'),
+    ('offline_trainer', 'Offline RNN Trainer'),
+    ('back_propagation', 'Back-propagation Trainer'),
+  ]
+  write_submodules(module_name='brainpy.compat.nn.runners',
+                   filename=os.path.join(path, 'nn_runners.rst'),
+                   header='Runners and Trainers',
+                   submodule_names=[k[0] for k in module_and_name],
+                   section_names=[k[1] for k in module_and_name])
+  module_and_name = [
+    ('online', 'Online Training Algorithms'),
+    ('offline', 'Offline Training Algorithms'),
+  ]
+  write_submodules(module_name='brainpy.compat.nn.algorithms',
+                   filename=os.path.join(path, 'nn_algorithms.rst'),
+                   header='Training Algorithms',
+                   submodule_names=[k[0] for k in module_and_name],
+                   section_names=[k[1] for k in module_and_name])
+  write_module(module_name='brainpy.compat.nn.nodes.base',
+               filename=os.path.join(path, 'nn_nodes_base.rst'),
+               header='Nodes: basic')
+  write_module(module_name='brainpy.compat.nn.nodes.ANN',
+               filename=os.path.join(path, 'nn_nodes_ANN.rst'),
+               header='Nodes: artificial neural network')
+  write_module(module_name='brainpy.compat.nn.nodes.RC',
+               filename=os.path.join(path, 'nn_nodes_RC.rst'),
+               header='Nodes: reservoir computing')
+
 
 def generate_math_compact_docs(path='apis/auto/math/'):
   if not os.path.exists(path):
     os.makedirs(path)
+
   write_module(module_name='brainpy.math.compat.optimizers',
                filename=os.path.join(path, 'optimizers.rst'),
                header='Optimizers')
 
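Reviewer note: the new `write_module`/`write_submodules` calls generate the `auto/compat/nn_*.rst` pages that the new `docs/apis/compat_nn.rst` toctree references, so the two doc changes need to land together. The exact output of `write_module` is not shown in this diff; below is a hypothetical stand-in, assuming it emits a section header plus an `automodule` directive:

```python
import os

def write_module_stub(module_name, filename, header):
  # Hypothetical sketch of what a write_module-style helper might emit:
  # an RST page titled `header` that autodocs `module_name`.
  os.makedirs(os.path.dirname(filename), exist_ok=True)
  with open(filename, 'w') as f:
    f.write(header + '\n')
    f.write('=' * len(header) + '\n\n')
    f.write('.. automodule:: %s\n' % module_name)
    f.write('   :members:\n')

write_module_stub('brainpy.compat.nn.base', 'apis/auto/compat/nn_base.rst', 'Base Classes')
```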
diff --git a/examples/training/Gauthier_2021_ngrc_lorenz.py b/examples/training/Gauthier_2021_ngrc_lorenz.py
index e14dadb87..2d89fa47a 100644
--- a/examples/training/Gauthier_2021_ngrc_lorenz.py
+++ b/examples/training/Gauthier_2021_ngrc_lorenz.py
@@ -105,7 +105,7 @@ def plot_lorenz(ground_truth, predictions):
 # Model #
 # ----- #
 
-class NGRC(bp.train.TrainNet):
+class NGRC(bp.train.TrainingSystem):
   def __init__(self, num_in):
     super(NGRC, self).__init__()
     self.r = bp.train.NVAR(num_in, delay=2, order=2, constant=True)
diff --git a/examples/training/echo_state_network.py b/examples/training/echo_state_network.py
index 2cd6fe048..9f8275000 100644
--- a/examples/training/echo_state_network.py
+++ b/examples/training/echo_state_network.py
@@ -105,7 +105,7 @@ def ngrc(num_in=10, num_out=30):
 def ngrc_bacth(num_in=10, num_out=30):
   model = NGRC(num_in, num_out)
   batch_size = 10
-  model.reset_batch_state(batch_size)
+  model.reset_state(batch_size)
   X = bm.random.random((batch_size, 200, num_in))
   Y = bm.random.random((batch_size, 200, num_out))
   trainer = bp.train.RidgeTrainer(model, beta=1e-6)
diff --git a/examples/training/integrator_rnn.py b/examples/training/integrator_rnn.py
index ec203038c..1bcae3aa5 100644
--- a/examples/training/integrator_rnn.py
+++ b/examples/training/integrator_rnn.py
@@ -68,7 +68,7 @@ def loss(predictions, targets, l2_reg=2e-4):
 
 plt.plot(trainer.train_losses.numpy())
 plt.show()
 
-model.reset_batch_state(1)
+model.reset_state(1)
 x, y = build_inputs_and_targets(batch_size=1)
 predicts = trainer.predict(x)
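Reviewer note: the example updates above double as a migration guide for downstream users; every `reset_batch_state(n)` call becomes `reset_state(n)` with unchanged semantics. Mirroring the `ngrc_bacth` snippet from `echo_state_network.py` (the `NGRC` class is defined earlier in that example file; shapes follow the example):

```python
import brainpy as bp
import brainpy.math as bm

model = NGRC(num_in=10, num_out=30)   # NGRC is defined earlier in the example
batch_size = 10
model.reset_state(batch_size)         # was: model.reset_batch_state(batch_size)
X = bm.random.random((batch_size, 200, 10))
Y = bm.random.random((batch_size, 200, 30))
trainer = bp.train.RidgeTrainer(model, beta=1e-6)
```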