diff --git a/brainpy/__init__.py b/brainpy/__init__.py index 28221aaae..545ffbb33 100644 --- a/brainpy/__init__.py +++ b/brainpy/__init__.py @@ -35,7 +35,9 @@ losses, # loss functions measure, # methods for data analysis datasets, # methods for generating data - inputs) # methods for generating input currents + inputs, # methods for generating input currents + algorithms, # online or offline training algorithms + ) # numerical integrators @@ -58,7 +60,8 @@ rates, # rate models synapses, # synaptic dynamics synouts, # synaptic output - synplast) # synaptic plasticity + synplast, # synaptic plasticity + ) # dynamics training diff --git a/brainpy/analysis/highdim/slow_points.py b/brainpy/analysis/highdim/slow_points.py index 98a73c4f1..2c2d80b38 100644 --- a/brainpy/analysis/highdim/slow_points.py +++ b/brainpy/analysis/highdim/slow_points.py @@ -120,6 +120,7 @@ def __init__( # parameters for `f_cell` is DynamicalSystem instance inputs: Sequence = None, + fun_inputs: Callable = None, t: float = None, dt: float = None, included_vars: Dict[str, bm.Variable] = None, @@ -175,7 +176,7 @@ def __init__( # input function if inputs is not None: inputs = check_and_format_inputs(host=self.target, inputs=inputs) - _input_step, _has_iter = build_inputs(inputs) + _input_step, _has_iter = build_inputs(inputs, fun_inputs) if _has_iter: raise UnsupportedError(f'Do not support iterable inputs when using fixed point finder.') else: diff --git a/brainpy/dyn/base.py b/brainpy/dyn/base.py index 86a047b07..ba59af43d 100644 --- a/brainpy/dyn/base.py +++ b/brainpy/dyn/base.py @@ -231,7 +231,7 @@ def update_local_delays(self, nodes: Union[Sequence, Dict] = None): elif isinstance(nodes, dict): nodes = nodes.values() for node in nodes: - for name in node.local_delay_vars.keys(): + for name in node.local_delay_vars: delay = self.global_delay_data[name][0] target = self.global_delay_data[name][1] delay.update(target.value) @@ -250,7 +250,7 @@ def reset_local_delays(self, nodes: Union[Sequence, Dict] = None): elif isinstance(nodes, dict): nodes = nodes.values() for node in nodes: - for name in node.local_delay_vars.keys(): + for name in node.local_delay_vars: delay = self.global_delay_data[name][0] target = self.global_delay_data[name][1] delay.reset(target.value) @@ -260,15 +260,18 @@ def __del__(self): This function is used to pop out the variables which registered in global delay data. """ - for key in tuple(self.local_delay_vars.keys()): - val = self.global_delay_data.pop(key) - del val - val = self.local_delay_vars.pop(key) - del val - for key in tuple(self.implicit_nodes.keys()): - del self.implicit_nodes[key] - for key in tuple(self.implicit_vars.keys()): - del self.implicit_vars[key] + if hasattr(self, 'local_delay_vars'): + for key in tuple(self.local_delay_vars.keys()): + val = self.global_delay_data.pop(key) + del val + val = self.local_delay_vars.pop(key) + del val + if hasattr(self, 'implicit_nodes'): + for key in tuple(self.implicit_nodes.keys()): + del self.implicit_nodes[key] + if hasattr(self, 'implicit_vars'): + for key in tuple(self.implicit_vars.keys()): + del self.implicit_vars[key] for key in tuple(self.__dict__.keys()): del self.__dict__[key] gc.collect() diff --git a/brainpy/dyn/runners.py b/brainpy/dyn/runners.py index 0bc01702a..9617e1f57 100644 --- a/brainpy/dyn/runners.py +++ b/brainpy/dyn/runners.py @@ -2,7 +2,7 @@ import time from collections.abc import Iterable -from typing import Dict, Union, Sequence +from typing import Dict, Union, Sequence, Callable import jax import jax.numpy as jnp @@ -33,9 +33,9 @@ def check_and_format_inputs(host, inputs): Parameters ---------- host : DynamicalSystem - The host which contains all data. + The host which contains all data. inputs : tuple, list - The inputs of the population. + The inputs of the population. Returns ------- @@ -161,12 +161,30 @@ def check_and_format_inputs(host, inputs): return formatted_inputs -def build_inputs(inputs): +def build_inputs(inputs, fun_inputs): + """Build input function. + + Parameters + ---------- + inputs : tuple, list + The inputs of the population. + fun_inputs: optional, callable + The input function customized by users. + + Returns + ------- + func: callable + The input function. + """ + fix_inputs = {'=': [], '+': [], '-': [], '*': [], '/': []} next_inputs = {'=': [], '+': [], '-': [], '*': [], '/': []} func_inputs = {'=': [], '+': [], '-': [], '*': [], '/': []} array_inputs = {'=': [], '+': [], '-': [], '*': [], '/': []} + if not (fun_inputs is None or callable(fun_inputs)): + raise ValueError + _has_iter_array = False for variable, value, type_, op in inputs: # variable @@ -202,6 +220,8 @@ def _f_ops(ops, var, data): raise ValueError(f'Unknown input operation: {ops}') def func(tdi): + if fun_inputs is not None: + fun_inputs(tdi) for ops, values in fix_inputs.items(): for var, data in values: _f_ops(ops, var, data) @@ -225,6 +245,7 @@ class DSRunner(Runner): ---------- target : DynamicalSystem The target model to run. + inputs : list, tuple The inputs for the target DynamicalSystem. It should be the format of `[(target, value, [type, operation])]`, where `target` is the @@ -239,6 +260,37 @@ class DSRunner(Runner): - ``operation``: should be a string, support `+`, `-`, `*`, `/`, `=`. - Also, if you want to specify multiple inputs, just give multiple ``(target, value, [type, operation])``, for example ``[(target1, value1), (target2, value2)]``. + + fun_inputs: callable + The functional inputs. Manually specify the inputs for the target variables. + This input function should receive one argument `shared` which contains the shared arguments like + time `t`, time step `dt`, and index `i`. + + monitors: None, sequence of str, dict, Monitor + Variables to monitor. + + - A list of string. Like `monitors=['a', 'b', 'c']` + - A list of string with index specification. Like `monitors=[('a', 1), ('b', [1,3,5]), 'c']` + - A dict with the explicit monitor target, like: `monitors={'a': model.spike, 'b': model.V}` + - A dict with the index specification, like: `monitors={'a': (model.spike, 0), 'b': (model.V, [1,2])}` + + fun_monitors: dict + Monitoring variables by callable functions. Should be a dict. + The `key` should be a string for the later retrieval by `runner.mon[key]`. + The `value` should be a callable function which receives two arguments: `t` and `dt`. + + jit: bool, dict + The JIT settings. + + progress_bar: bool + Use progress bar to report the running progress or not? + + dyn_vars: Optional, dict + The dynamically changed variables. Instance of :py:class:`~.Variable`. + + numpy_mon_after_run : bool + When finishing the network running, transform the JAX arrays into numpy ndarray or not? + """ target: DynamicalSystem @@ -246,7 +298,12 @@ class DSRunner(Runner): def __init__( self, target: DynamicalSystem, + + # inputs for target variables inputs: Sequence = (), + fun_inputs: Callable = None, + + # extra info dt: float = None, t0: Union[float, int] = 0., **kwargs @@ -269,11 +326,10 @@ def __init__( # Build the monitor function self._mon_info = self.format_monitors() - # self._monitor_step = self.build_monitors(*self.format_monitors()) # Build input function - inputs = check_and_format_inputs(host=target, inputs=inputs) - self._input_step, _ = build_inputs(inputs) + self._input_step, _ = build_inputs(check_and_format_inputs(host=target, inputs=inputs), + fun_inputs=fun_inputs) # run function self._f_predict_compiled = dict() @@ -581,4 +637,3 @@ def __del__(self): for key in tuple(self._f_predict_compiled.keys()): del self._f_predict_compiled[key] super(DSRunner, self).__del__() - diff --git a/brainpy/math/index_tricks.py b/brainpy/math/index_tricks.py index 5748b2a6b..dd3a1c9fb 100644 --- a/brainpy/math/index_tricks.py +++ b/brainpy/math/index_tricks.py @@ -1,17 +1,3 @@ -# Copyright 2022 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - import abc from jax import core @@ -61,17 +47,18 @@ class _Mgrid(_IndexGrid): Examples: Pass ``[start:stop:step]`` to generate values similar to :func:`jax.numpy.arange`: - >>> jnp.mgrid[0:4:1] + >>> import brainpy.math as bm + >>> bm.mgrid[0:4:1] DeviceArray([0, 1, 2, 3], dtype=int32) Passing an imaginary step generates values similar to :func:`jax.numpy.linspace`: - >>> jnp.mgrid[0:1:4j] + >>> bm.mgrid[0:1:4j] DeviceArray([0. , 0.33333334, 0.6666667 , 1. ], dtype=float32) Multiple slices can be used to create broadcasted grids of indices: - >>> jnp.mgrid[:2, :3] + >>> bm.mgrid[:2, :3] DeviceArray([[[0, 0, 0], [1, 1, 1]], [[0, 1, 2], @@ -96,17 +83,17 @@ class _Ogrid(_IndexGrid): Examples: Pass ``[start:stop:step]`` to generate values similar to :func:`jax.numpy.arange`: - >>> jnp.ogrid[0:4:1] + >>> bm.ogrid[0:4:1] DeviceArray([0, 1, 2, 3], dtype=int32) Passing an imaginary step generates values similar to :func:`jax.numpy.linspace`: - >>> jnp.ogrid[0:1:4j] + >>> bm.ogrid[0:1:4j] DeviceArray([0. , 0.33333334, 0.6666667 , 1. ], dtype=float32) Multiple slices can be used to create sparse grids of indices: - >>> jnp.ogrid[:2, :3] + >>> bm.ogrid[:2, :3] [DeviceArray([[0], [1]], dtype=int32), DeviceArray([[0, 1, 2]], dtype=int32)] @@ -200,13 +187,13 @@ class RClass(_AxisConcat): Examples: Passing slices in the form ``[start:stop:step]`` generates ``jnp.arange`` objects: - >>> jnp.r_[-1:5:1, 0, 0, jnp.array([1,2,3])] + >>> bm.r_[-1:5:1, 0, 0, bm.array([1,2,3])] DeviceArray([-1, 0, 1, 2, 3, 4, 0, 0, 1, 2, 3], dtype=int32) An imaginary value for ``step`` will create a ``jnp.linspace`` object instead, which includes the right endpoint: - >>> jnp.r_[-1:1:6j, 0, jnp.array([1,2,3])] + >>> bm.r_[-1:1:6j, 0, bm.array([1,2,3])] DeviceArray([-1. , -0.6 , -0.20000002, 0.20000005, 0.6 , 1. , 0. , 1. , 2. , 3. ], dtype=float32) @@ -215,11 +202,11 @@ class RClass(_AxisConcat): specify concatenation axis, minimum number of dimensions, and the position of the upgraded array's original dimensions in the resulting array's shape tuple: - >>> jnp.r_['0,2', [1,2,3], [4,5,6]] # concatenate along first axis, 2D output + >>> bm.r_['0,2', [1,2,3], [4,5,6]] # concatenate along first axis, 2D output DeviceArray([[1, 2, 3], [4, 5, 6]], dtype=int32) - >>> jnp.r_['0,2,0', [1,2,3], [4,5,6]] # push last input axis to the front + >>> bm.r_['0,2,0', [1,2,3], [4,5,6]] # push last input axis to the front DeviceArray([[1], [2], [3], @@ -230,7 +217,7 @@ class RClass(_AxisConcat): Negative values for ``trans1d`` offset the last axis towards the start of the shape tuple: - >>> jnp.r_['0,2,-2', [1,2,3], [4,5,6]] + >>> bm.r_['0,2,-2', [1,2,3], [4,5,6]] DeviceArray([[1], [2], [3], @@ -241,10 +228,10 @@ class RClass(_AxisConcat): Use the special directives ``"r"`` or ``"c"`` as the first argument on flat inputs to create an array with an extra row or column axis, respectively: - >>> jnp.r_['r',[1,2,3], [4,5,6]] + >>> bm.r_['r',[1,2,3], [4,5,6]] DeviceArray([[1, 2, 3, 4, 5, 6]], dtype=int32) - >>> jnp.r_['c',[1,2,3], [4,5,6]] + >>> bm.r_['c',[1,2,3], [4,5,6]] DeviceArray([[1], [2], [3], @@ -274,8 +261,8 @@ class CClass(_AxisConcat): Examples: - >>> a = jnp.arange(6).reshape((2,3)) - >>> jnp.c_[a,a] + >>> a = bm.arange(6).reshape((2,3)) + >>> bm.c_[a,a] DeviceArray([[0, 1, 2, 0, 1, 2], [3, 4, 5, 3, 4, 5]], dtype=int32) @@ -283,7 +270,7 @@ class CClass(_AxisConcat): concatenation axis, minimum number of dimensions, and the position of the upgraded array's original dimensions in the resulting array's shape tuple: - >>> jnp.c_['0,2', [1,2,3], [4,5,6]] + >>> bm.c_['0,2', [1,2,3], [4,5,6]] DeviceArray([[1], [2], [3], @@ -291,7 +278,7 @@ class CClass(_AxisConcat): [5], [6]], dtype=int32) - >>> jnp.c_['0,2,-1', [1,2,3], [4,5,6]] + >>> bm.c_['0,2,-1', [1,2,3], [4,5,6]] DeviceArray([[1, 2, 3], [4, 5, 6]], dtype=int32) diff --git a/brainpy/math/jaxarray.py b/brainpy/math/jaxarray.py index 2c68e5556..15de3627e 100644 --- a/brainpy/math/jaxarray.py +++ b/brainpy/math/jaxarray.py @@ -262,9 +262,6 @@ def __imul__(self, oc): self._value = self._value * (oc._value if isinstance(oc, JaxArray) else oc) return self - # def __div__(self, oc): - # return JaxArray(self._value / (oc._value if isinstance(oc, JaxArray) else oc)) - def __rdiv__(self, oc): return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) / self._value) @@ -421,44 +418,12 @@ def block_host_until_ready(self, *args): def block_until_ready(self, *args): self._value.block_until_ready(*args) - # def broadcast(self, operand, sizes): - # """Broadcasts an array, adding new major dimensions. - # - # Wraps XLA's `Broadcast - # `_ - # operator. - # - # Parameters - # ---------- - # operand: an array - # sizes: - # A sequence of integers, giving the sizes of new major dimensions - # to add. - # - # Returns - # ------- - # ary : array - # An array containing the result. - # """ - # raise NotImplementedError - # - # def client(self, *args): - # raise NotImplementedError - # - # def clone(self, *args): - # raise NotImplementedError - # - # def copy_to_device(self, *args): - # raise NotImplementedError - # - # def copy_to_host_async(self, *args): - # raise NotImplementedError - # - # def device(self, *args): - # raise NotImplementedError - # - # def device_buffer(self, *args): - # raise NotImplementedError + def device(self): + raise self.value.device() + + @property + def device_buffer(self): + raise self.value.device_buffer # ----------------------- # # NumPy methods # diff --git a/brainpy/math/operators/op_register.py b/brainpy/math/operators/op_register.py index 3132004cf..0be715ccb 100644 --- a/brainpy/math/operators/op_register.py +++ b/brainpy/math/operators/op_register.py @@ -20,8 +20,8 @@ def register_op( op_name: str, cpu_func: Callable, + out_shapes: Union[Callable, ShapedArray, Sequence[ShapedArray]], gpu_func: Callable = None, - out_shapes: Union[Callable, ShapedArray, Sequence[ShapedArray]] = None, apply_cpu_func_to_gpu: bool = False ): """ diff --git a/brainpy/math/operators/pre2post.py b/brainpy/math/operators/pre2post.py index 66dc8f06d..0bf0e59e0 100644 --- a/brainpy/math/operators/pre2post.py +++ b/brainpy/math/operators/pre2post.py @@ -469,24 +469,20 @@ def pre2post_matmul_mask2(event, conn, mask): # f0 = vmap(lambda i, j: event[i] * (Cl[i] * Cr[:, j]).sum() * (Ml[i] * Mr[:, j]).sum(), in_axes=(0, None)) @partial(vmap, in_axes=(0, None)) def f0(i, j): - return cond(event[i] > 0., - lambda _: event[i] * jnp.sum(Cl[i] * Cr[:, j]) * jnp.sum(Ml[i] * Mr[:, j]), - lambda _: 0., - None) - # fori_loop(0, - # Cr.shape[1], - # lambda x: f0(x[0], x[1]), - # ) + return cond(event[i], + lambda: cond(Ml[i] @ Mr[:, j], + lambda: (Cl[i] * Cr[:, j]).sum(), + lambda: 0.), + lambda: 0.) ii = jnp.arange(Cl.shape[0]) jj = jnp.arange(Cr.shape[1]) - def body(_, j): - r = f0(ii, j).sum() - return 0, r - - _, out = scan(body, 0, jj) - - # f1 = jit(vmap(lambda ii, j: f0(ii, j).sum(), in_axes=(None, 0))) - return out + # def body(_, j): + # r = f0(ii, j).sum() + # return 0, r + # _, out = scan(body, 0, jj) + # return out + f = jit(vmap(lambda j: f0(ii, j).sum())) + return f(jj) diff --git a/brainpy/running/runner.py b/brainpy/running/runner.py index 319061d18..19ce7a81b 100644 --- a/brainpy/running/runner.py +++ b/brainpy/running/runner.py @@ -24,6 +24,7 @@ class Runner(object): ---------- target: Any The target model. + monitors: None, sequence of str, dict, Monitor Variables to monitor. @@ -36,14 +37,18 @@ class Runner(object): Monitoring variables by callable functions. Should be a dict. The `key` should be a string for later retrieval by `runner.mon[key]`. The `value` should be a callable function which receives two arguments: `t` and `dt`. + jit: bool, dict The JIT settings. + progress_bar: bool - Use progress bar or not? + Use progress bar to report the running progress or not? + dyn_vars: Optional, dict The dynamically changed variables. Instance of :py:class:`~.Variable`. + numpy_mon_after_run : bool - Transform the JAX arrays into numpy ndarray or not, when finishing the network running? + When finishing the network running, transform the JAX arrays into numpy ndarray or not? """ mon: DotDict @@ -210,8 +215,9 @@ def build_monitors(self, return_without_idx, return_with_idx, shared_args) -> Ca raise NotImplementedError def __del__(self): - for key in tuple(self.mon.keys()): - del self.mon[key] + if hasattr(self, 'mon'): + for key in tuple(self.mon.keys()): + del self.mon[key] for key in tuple(self.__dict__.keys()): del self.__dict__[key] gc.collect() diff --git a/brainpy/train/__init__.py b/brainpy/train/__init__.py index 88bbf73ce..a5902d3af 100644 --- a/brainpy/train/__init__.py +++ b/brainpy/train/__init__.py @@ -24,6 +24,6 @@ from .base import * from .back_propagation import * -from .online_trainer import * -from .offline_trainer import * +from .online import * +from .offline import * diff --git a/brainpy/train/base.py b/brainpy/train/base.py index b364ff088..6a264448c 100644 --- a/brainpy/train/base.py +++ b/brainpy/train/base.py @@ -14,7 +14,7 @@ from . import constants as c __all__ = [ - 'DSTrainer', 'DSRunner', + 'DSTrainer', ] diff --git a/brainpy/train/offline_trainer.py b/brainpy/train/offline.py similarity index 100% rename from brainpy/train/offline_trainer.py rename to brainpy/train/offline.py diff --git a/brainpy/train/online_trainer.py b/brainpy/train/online.py similarity index 100% rename from brainpy/train/online_trainer.py rename to brainpy/train/online.py diff --git a/changelog.rst b/changelog.rst index a9add84e3..3be71c3bd 100644 --- a/changelog.rst +++ b/changelog.rst @@ -2,16 +2,88 @@ Release notes (brainpy) ####################### -brainpy 2.x (LTS) -***************** +brainpy 2.2.x (LTS) +******************* + +BrainPy 2.2.x is a complete re-design of the framework, +tackling the shortcomings of brainpy 2.1.x generation, +effectively bringing it to industry needs and standards. + + + + + + +brainpy 2.1.x (LTS) +******************* +Version 2.1.12 (2022.05.17) +=========================== + + +Highlights +~~~~~~~~~~ + +This release is excellent. We have made important improvements. + +1. We provide dozens of random sampling in NumPy which are not + supportted in JAX, such as ``brainpy.math.random.bernoulli``, + ``brainpy.math.random.lognormal``, ``brainpy.math.random.binomial``, + ``brainpy.math.random.chisquare``, ``brainpy.math.random.dirichlet``, + ``brainpy.math.random.geometric``, ``brainpy.math.random.f``, + ``brainpy.math.random.hypergeometric``, + ``brainpy.math.random.logseries``, + ``brainpy.math.random.multinomial``, + ``brainpy.math.random.multivariate_normal``, + ``brainpy.math.random.negative_binomial``, + ``brainpy.math.random.noncentral_chisquare``, + ``brainpy.math.random.noncentral_f``, ``brainpy.math.random.power``, + ``brainpy.math.random.rayleigh``, ``brainpy.math.random.triangular``, + ``brainpy.math.random.vonmises``, ``brainpy.math.random.wald``, + ``brainpy.math.random.weibull`` +2. make efficient checking on numerical values. Instead of direct + ``id_tap()`` checking which has large overhead, currently + ``brainpy.tools.check_erro_in_jit()`` is highly efficient. +3. Fix ``JaxArray`` operator errors on ``None`` +4. improve oo-to-function transformation speeds +5. ``io`` works: ``.save_states()`` and ``.load_states()`` + +What’s Changed +~~~~~~~~~~~~~~ + +- support dtype setting in array interchange functions by + [@chaoming0625](https://github.com/chaoming0625) in + `#209 `__ +- fix `#144 `__: + operations on None raise errors by + [@chaoming0625](https://github.com/chaoming0625) in + `#210 `__ +- add tests and new functions for random sampling by + [@c-xy17](https://github.com/c-xy17) in + `#213 `__ +- feat: fix ``io`` for brainpy.Base by + [@chaoming0625](https://github.com/chaoming0625) in + `#211 `__ +- update advanced tutorial documentation by + [@chaoming0625](https://github.com/chaoming0625) in + `#212 `__ +- fix `#149 `__ + (dozens of random samplings in NumPy) and fix JaxArray op errors by + [@chaoming0625](https://github.com/chaoming0625) in + `#216 `__ +- feat: efficient checking on numerical values by + [@chaoming0625](https://github.com/chaoming0625) in + `#217 `__ + +**Full Changelog**: +`V2.1.11...V2.1.12 `__ Version 2.1.11 (2022.05.15) -========================== +=========================== What's Changed @@ -29,7 +101,7 @@ What's Changed Version 2.1.10 (2022.05.05) -========================== +=========================== What's Changed diff --git a/docs/apis/algorithms.rst b/docs/apis/algorithms.rst new file mode 100644 index 000000000..d90f44023 --- /dev/null +++ b/docs/apis/algorithms.rst @@ -0,0 +1,14 @@ +``brainpy.algorithms`` module +============================= + +.. currentmodule:: brainpy.algorithms +.. automodule:: brainpy.algorithms + + +.. toctree:: + :maxdepth: 1 + + auto/algorithms/offline + auto/algorithms/online + auto/algorithms/utils + diff --git a/docs/apis/compat.rst b/docs/apis/compat.rst deleted file mode 100644 index 03d41492c..000000000 --- a/docs/apis/compat.rst +++ /dev/null @@ -1,16 +0,0 @@ -``brainpy.compat`` module -=========================== - -.. currentmodule:: brainpy.compat -.. automodule:: brainpy.compat - - -.. toctree:: - :maxdepth: 1 - - auto/compat/brainobjects - auto/compat/integrators - auto/compat/layers - auto/compat/models - auto/compat/runners - auto/compat/monitor diff --git a/docs/apis/compat_nn.rst b/docs/apis/compat_nn.rst deleted file mode 100644 index ad22b3247..000000000 --- a/docs/apis/compat_nn.rst +++ /dev/null @@ -1,19 +0,0 @@ -``brainpy.compat.nn`` module -============================ - -.. currentmodule:: brainpy.compat.nn -.. automodule:: brainpy.compat.nn - - -.. toctree:: - :maxdepth: 1 - - auto/compat/nn_base - auto/compat/nn_operations - auto/compat/nn_graph_flow - auto/compat/nn_runners - auto/compat/nn_algorithms - auto/compat/nn_data_types - auto/compat/nn_nodes_base - auto/compat/nn_nodes_ANN - auto/compat/nn_nodes_RC diff --git a/docs/apis/dyn.rst b/docs/apis/dyn.rst index 4b0d64853..2a041fb2a 100644 --- a/docs/apis/dyn.rst +++ b/docs/apis/dyn.rst @@ -15,6 +15,7 @@ auto/dyn/synouts auto/dyn/synplast auto/dyn/rates + auto/dyn/layers auto/dyn/channel_base auto/dyn/channel_sodium auto/dyn/channel_potassium diff --git a/docs/apis/math_compat.rst b/docs/apis/math_compat.rst deleted file mode 100644 index fb32766fa..000000000 --- a/docs/apis/math_compat.rst +++ /dev/null @@ -1,12 +0,0 @@ -``brainpy.math.compat`` module -=============================== - -.. currentmodule:: brainpy.math.compat -.. automodule:: brainpy.math.compat - - -.. toctree:: - :maxdepth: 1 - - auto/math/optimizers - auto/math/losses diff --git a/docs/apis/train.rst b/docs/apis/train.rst new file mode 100644 index 000000000..ea3e53dc7 --- /dev/null +++ b/docs/apis/train.rst @@ -0,0 +1,15 @@ +``brainpy.train`` module +======================== + +.. currentmodule:: brainpy.train +.. automodule:: brainpy.train + + +.. toctree:: + :maxdepth: 1 + + auto/train/base + auto/train/online + auto/train/offline + auto/train/back_propagation + diff --git a/docs/auto_generater.py b/docs/auto_generater.py index 8fe673b62..f94653560 100644 --- a/docs/auto_generater.py +++ b/docs/auto_generater.py @@ -5,12 +5,10 @@ import os from brainpy.math import (activations, autograd, controls, function, - jit, parallels, setting, delayvars, operators, - compat) + jit, parallels, setting, delayvars, operators) block_list = ['test', 'register_pytree_node', 'call', 'namedtuple', 'jit', 'wraps', 'index', 'function'] -for module in [jit, autograd, function, controls, activations, parallels, setting, delayvars, compat, - operators]: +for module in [jit, autograd, function, controls, activations, parallels, setting, delayvars, operators]: for k in dir(module): if (not k.startswith('_')) and (not inspect.ismodule(getattr(module, k))): block_list.append(k) @@ -173,6 +171,20 @@ def _section(header, numpy_mod, brainpy_mod, jax_mod, klass=None, is_jax=False): return buf +def generate_algorithm_docs(path='apis/auto/algorithms/'): + if not os.path.exists(path): os.makedirs(path) + + write_module(module_name='brainpy.algorithms.offline', + filename=os.path.join(path, 'offline.rst'), + header='Offline Training Algorithms') + write_module(module_name='brainpy.algorithms.online', + filename=os.path.join(path, 'online.rst'), + header='Online Training Algorithms') + write_module(module_name='brainpy.algorithms.utils', + filename=os.path.join(path, 'utils.rst'), + header='Training Algorithm Utilities') + + def generate_analysis_docs(path='apis/auto/analysis/'): if not os.path.exists(path): os.makedirs(path) @@ -187,6 +199,23 @@ def generate_analysis_docs(path='apis/auto/analysis/'): header='Stability Analysis') +def generate_train_docs(path='apis/auto/train/'): + if not os.path.exists(path): + os.makedirs(path) + write_module(module_name='brainpy.train.base', + filename=os.path.join(path, 'base.rst'), + header='Base Training Class') + write_module(module_name='brainpy.train.online', + filename=os.path.join(path, 'online.rst'), + header='Online Training Method') + write_module(module_name='brainpy.train.offline', + filename=os.path.join(path, 'offline.rst'), + header='Offline Training Method') + write_module(module_name='brainpy.train.back_propagation', + filename=os.path.join(path, 'back_propagation.rst'), + header='Back-propagation Training Method') + + def generate_base_docs(path='apis/auto/'): if not os.path.exists(path): os.makedirs(path) @@ -275,8 +304,10 @@ def generate_dyn_docs(path='apis/auto/dyn/'): section_names=[a[1] for a in module_and_name]) module_and_name = [ - ('biological_models', 'Biological Models'), ('abstract_models', 'Abstract Models'), + ('biological_models', 'Biological Models'), + ('couplings', 'Coupling Models'), + ('gap_junction', 'Gap Junction Models'), ('learning_rules', 'Learning Rule Models'), ] write_submodules(module_name='brainpy.dyn.synapses', @@ -293,7 +324,6 @@ def generate_dyn_docs(path='apis/auto/dyn/'): module_and_name = [ ('populations', 'Population Models'), - ('couplings', 'Coupling Models'), ] write_submodules(module_name='brainpy.dyn.rates', filename=os.path.join(path, 'rates.rst'), @@ -389,8 +419,8 @@ def generate_losses_docs(path='apis/auto/'): os.makedirs(path) module_and_name = [ - ('Comparison', 'comparison'), - ('Regularization', 'regularization'), + ('comparison', 'Comparison', ), + ('regularization', 'Regularization', ), ] write_submodules(module_name='brainpy.losses', filename=os.path.join(path, 'losses.rst'), @@ -474,6 +504,59 @@ def generate_measure_docs(path='apis/auto/'): header='``brainpy.measure`` module') + +def generate_optimizers_docs(path='apis/auto/'): + if not os.path.exists(path): + os.makedirs(path) + + module_and_name = [ + ('optimizer', 'Optimizers'), + ('scheduler', 'Schedulers'), + ] + write_submodules(module_name='brainpy.optimizers', + filename=os.path.join(path, 'optimizers.rst'), + header='``brainpy.optimizers`` module', + submodule_names=[k[0] for k in module_and_name], + section_names=[k[1] for k in module_and_name]) + + +def generate_running_docs(path='apis/auto/'): + if not os.path.exists(path): + os.makedirs(path) + + module_and_name = [ + ('multiprocess', 'Parallel Pool'), + ('runner', 'Runners') + ] + write_submodules(module_name='brainpy.running', + filename=os.path.join(path, 'running.rst'), + header='``brainpy.running`` module', + submodule_names=[k[0] for k in module_and_name], + section_names=[k[1] for k in module_and_name]) + + +def generate_tools_docs(path='apis/auto/tools/'): + if not os.path.exists(path): + os.makedirs(path) + + write_module(module_name='brainpy.tools.checking', + filename=os.path.join(path, 'checking.rst'), + header='Type Checking') + write_module(module_name='brainpy.tools.codes', + filename=os.path.join(path, 'codes.rst'), + header='Code Tools') + write_module(module_name='brainpy.tools.others', + filename=os.path.join(path, 'others.rst'), + header='Other Tools') + write_module(module_name='brainpy.tools.errors', + filename=os.path.join(path, 'errors.rst'), + header='Error Tools') + + +# ---------- # +# Deprecated # +# ---------- # + def generate_nn_docs(path='apis/auto/nn/'): if not os.path.exists(path): os.makedirs(path) @@ -524,56 +607,6 @@ def generate_nn_docs(path='apis/auto/nn/'): filename=os.path.join(path, 'nodes_RC.rst'), header='Nodes: reservoir computing') - -def generate_optimizers_docs(path='apis/auto/'): - if not os.path.exists(path): - os.makedirs(path) - - module_and_name = [ - ('optimizer', 'Optimizers'), - ('scheduler', 'Schedulers'), - ] - write_submodules(module_name='brainpy.optimizers', - filename=os.path.join(path, 'optimizers.rst'), - header='``brainpy.optimizers`` module', - submodule_names=[k[0] for k in module_and_name], - section_names=[k[1] for k in module_and_name]) - - -def generate_running_docs(path='apis/auto/'): - if not os.path.exists(path): - os.makedirs(path) - - module_and_name = [ - ('monitor', 'Monitors'), - ('parallel', 'Parallel Pool'), - ('runner', 'Runners') - ] - write_submodules(module_name='brainpy.running', - filename=os.path.join(path, 'running.rst'), - header='``brainpy.running`` module', - submodule_names=[k[0] for k in module_and_name], - section_names=[k[1] for k in module_and_name]) - - -def generate_tools_docs(path='apis/auto/tools/'): - if not os.path.exists(path): - os.makedirs(path) - - write_module(module_name='brainpy.tools.checking', - filename=os.path.join(path, 'checking.rst'), - header='Type Checking') - write_module(module_name='brainpy.tools.codes', - filename=os.path.join(path, 'codes.rst'), - header='Code Tools') - write_module(module_name='brainpy.tools.others', - filename=os.path.join(path, 'others.rst'), - header='Other Tools') - write_module(module_name='brainpy.tools.errors', - filename=os.path.join(path, 'errors.rst'), - header='Error Tools') - - def generate_compact_docs(path='apis/auto/compat/'): if not os.path.exists(path): os.makedirs(path) diff --git a/docs/conf.py b/docs/conf.py index 5bebd4968..00eb80e12 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -21,10 +21,11 @@ from docs import auto_generater auto_generater.generate_base_docs() +auto_generater.generate_analysis_docs() +auto_generater.generate_train_docs() +auto_generater.generate_algorithm_docs() auto_generater.generate_math_docs() auto_generater.generate_dyn_docs() -auto_generater.generate_nn_docs() -auto_generater.generate_analysis_docs() auto_generater.generate_integrators_doc() auto_generater.generate_inputs_docs() auto_generater.generate_running_docs() @@ -35,7 +36,8 @@ auto_generater.generate_measure_docs() auto_generater.generate_datasets_docs() auto_generater.generate_tools_docs() -auto_generater.generate_compact_docs() +# auto_generater.generate_nn_docs() +# auto_generater.generate_compact_docs() # auto_generater.generate_math_compact_docs() diff --git a/docs/index.rst b/docs/index.rst index 0973c3f8e..726c45847 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -91,10 +91,11 @@ The code of BrainPy is open-sourced at GitHub: apis/auto/base.rst apis/math.rst apis/dyn.rst - apis/nn.rst + apis/train.rst apis/analysis.rst apis/integrators.rst apis/datasets.rst + apis/algorithms.rst apis/auto/inputs.rst apis/auto/connect.rst apis/auto/initialize.rst @@ -103,7 +104,6 @@ The code of BrainPy is open-sourced at GitHub: apis/auto/measure.rst apis/auto/running.rst apis/tools.rst - apis/compat.rst apis/auto/changelog-brainpy.rst apis/auto/changelog-brainpylib.rst