diff --git a/brainpy/math/__init__.py b/brainpy/math/__init__.py
index a08837c82..6d69816bd 100644
--- a/brainpy/math/__init__.py
+++ b/brainpy/math/__init__.py
@@ -53,10 +53,9 @@ from . import surrogate
 from .surrogate.compt import *
 
-# JAX transformations for Variable and class objects
+# Variables and objects for object-oriented JAX transformations
 from .object_transform import *
-
 # environment settings
 from .modes import *
 from .environment import *
diff --git a/brainpy/math/ndarray.py b/brainpy/math/ndarray.py
index 14a1bc6b1..2df3c27e7 100644
--- a/brainpy/math/ndarray.py
+++ b/brainpy/math/ndarray.py
@@ -2,12 +2,13 @@
 
 import warnings
-from typing import Optional, Tuple
+from typing import Optional, Tuple as TupleType
 
 import numpy as np
 from jax import numpy as jnp
 from jax.tree_util import register_pytree_node
+
 from brainpy.errors import MathError
 
 __all__ = [
@@ -997,7 +998,7 @@ def __init__(
             f'but the batch axis is set to be {batch_axis}.')
 
   @property
-  def shape_nb(self) -> Tuple[int, ...]:
+  def shape_nb(self) -> TupleType[int, ...]:
     """Shape without batch axis."""
     shape = list(self.value.shape)
     if self.batch_axis is not None:
@@ -1562,7 +1563,6 @@ class BatchVariable(Variable):
   pass
 
 
-
 class VariableView(Variable):
   """A view of a Variable instance.
 
@@ -1742,7 +1742,7 @@ def _jaxarray_unflatten(aux_data, flat_contents):
 
 
 register_pytree_node(Array,
-                     lambda t: ((t.value,), (t._transform_context, )),
+                     lambda t: ((t.value,), (t._transform_context,)),
                      _jaxarray_unflatten)
 
 register_pytree_node(Variable,
@@ -1756,3 +1756,4 @@ def _jaxarray_unflatten(aux_data, flat_contents):
 register_pytree_node(Parameter,
                      lambda t: ((t.value,), None),
                      lambda aux_data, flat_contents: Parameter(*flat_contents))
+
diff --git a/brainpy/math/object_transform/__init__.py b/brainpy/math/object_transform/__init__.py
index 53b64c9ef..12545f566 100644
--- a/brainpy/math/object_transform/__init__.py
+++ b/brainpy/math/object_transform/__init__.py
@@ -30,9 +30,15 @@
   + controls.__all__
   + jit.__all__
   + function.__all__
+  + base_object.__all__
+  + base_transform.__all__
+  + collector.__all__
 )
 
 from .autograd import *
 from .controls import *
 from .jit import *
 from .function import *
+from .base_object import *
+from .base_transform import *
+from .collector import *
diff --git a/brainpy/math/object_transform/base_object.py b/brainpy/math/object_transform/base_object.py
index f5b3b8a85..c78ba20da 100644
--- a/brainpy/math/object_transform/base_object.py
+++ b/brainpy/math/object_transform/base_object.py
@@ -1,26 +1,39 @@
 # -*- coding: utf-8 -*-
 
 import os
-import logging
 import warnings
 from collections import namedtuple
 from typing import Any, Tuple, Callable, Sequence, Dict, Union
 
+import jax
+import numpy as np
+from jax._src.tree_util import _registry
+from jax.tree_util import register_pytree_node
+from jax.tree_util import register_pytree_node_class
+from jax.util import safe_zip
+
 from brainpy import errors
 from .collector import Collector, ArrayCollector
-from ..ndarray import Variable, VariableView, TrainVar
+from ..ndarray import (Array,
+                       Variable,
+                       VariableView,
+                       TrainVar)
 
 StateLoadResult = namedtuple('StateLoadResult', ['missing_keys', 'unexpected_keys'])
 
+
 __all__ = [
-  'check_name_uniqueness',
-  'get_unique_name',
-  'clear_name_cache',
+  # naming
+  'check_name_uniqueness', 'get_unique_name', 'clear_name_cache',
+
+  # objects
   'BrainPyObject', 'Base', 'FunAsObject',
+
+  # variables
+  'numerical_seq', 'object_seq',
+  'numerical_dict', 'object_dict',
 ]
 
-logger = logging.getLogger('brainpy.brainpy_object')
 _name2id = dict()
 _typed_names = {}
 
@@ -59,7 +72,7 @@ def clear_name_cache(ignore_warn=False):
   _name2id.clear()
   _typed_names.clear()
   if not ignore_warn:
-    logger.warning(f'All named models and their ids are cleared.')
+    warnings.warn('All named models and their ids are cleared.', UserWarning)
 
 
 class BrainPyObject(object):
@@ -78,6 +91,11 @@ class BrainPyObject(object):
   _excluded_vars = ()
 
   def __init__(self, name=None):
+    super().__init__()
+    cls = self.__class__
+    if cls not in _registry:
+      register_pytree_node_class(cls)
+
     # check whether the object has a unique name.
     self._name = None
     self._name = self.unique_name(name=name)
@@ -91,15 +109,17 @@ def __init__(self, name=None):
     # which cannot be accessed by self.xxx
     self.implicit_nodes = Collector()
 
-  def __setattr__(self, key, value) -> None:
-    """Overwrite __setattr__ method for non-changeable Variable setting.
+  def __setattr__(self, key: str, value: Any) -> None:
+    """Overwrite ``__setattr__`` to update the values of existing Variables in place.
 
     .. versionadded:: 2.3.1
 
     Parameters
     ----------
     key: str
+      The attribute name.
     value: Any
+      The attribute value.
     """
     if key in self.__dict__:
       val = self.__dict__[key]
@@ -109,19 +129,24 @@ def __setattr__(self, key, value) -> None:
       super().__setattr__(key, value)
 
   def tree_flatten(self):
-    """
+    """Flatten the object as a PyTree.
+
+    The flattening order is determined by the order in which attributes were added.
+
     .. versionadded:: 2.3.1
 
     Returns
     -------
-
+    res: tuple
+      A tuple of dynamical values and static values.
     """
+    dts = (BrainPyObject,) + tuple(dynamical_types)
     dynamic_names = []
     dynamic_values = []
     static_names = []
     static_values = []
     for k, v in self.__dict__.items():
-      if isinstance(v, (ArrayCollector, BrainPyObject, Variable)):
+      if isinstance(v, dts):
         dynamic_names.append(k)
         dynamic_values.append(v)
       else:
@@ -531,3 +556,87 @@ def __repr__(self) -> str:
     node_string = ", \n".join(nodes)
     return (f'{name}(nodes=[{node_string}],\n' +
             " " * (len(name) + 1) + f'num_of_vars={len(self.implicit_vars)})')
+
+
+class numerical_seq(list):
+  """A list to represent a dynamically changed numerical
+  sequence whose elements can be changed during JIT compilation.
+
+  .. note::
+     The element must be numerical, like ``bool``, ``int``, ``float``,
+     ``jax.Array``, ``numpy.ndarray``, ``brainpy.math.Array``.
+  """
+  def append(self, element):
+    if not isinstance(element, (bool, int, float, jax.Array, Array, np.ndarray)):
+      raise TypeError(f'Each element should be a numerical value, but we got {type(element)}.')
+    super().append(element)
+
+  def extend(self, iterable) -> None:
+    for element in iterable:
+      self.append(element)
+
+
+register_pytree_node(numerical_seq,
+                     lambda x: (tuple(x), ()),
+                     lambda _, values: numerical_seq(values))
+
+
+class object_seq(list):
+  """A list to represent a sequence of :py:class:`~.BrainPyObject`.
+
+  .. note::
+     The element must be a :py:class:`~.BrainPyObject`.
+  """
+  def append(self, element):
+    if not isinstance(element, BrainPyObject):
+      raise TypeError(f'Each element should be a {BrainPyObject.__name__}, but we got {type(element)}.')
+    super().append(element)
+
+  def extend(self, iterable) -> None:
+    for element in iterable:
+      self.append(element)
+
+
+register_pytree_node(object_seq,
+                     lambda x: (tuple(x), ()),
+                     lambda _, values: object_seq(values))
+
+
+class numerical_dict(dict):
+  """A dict to represent a dynamically changed numerical
+  dictionary whose values can be changed during JIT compilation.
+
+  .. note::
+     Each key must be a string, and each value must be numerical, including
+     ``bool``, ``int``, ``float``, ``jax.Array``, ``numpy.ndarray``,
+     ``brainpy.math.Array``.
+ """ + def update(self, *args, **kwargs) -> 'numerical_dict': + super().update(*args, **kwargs) + return self + + +register_pytree_node(numerical_dict, + lambda x: (tuple(x.values()), tuple(x.keys())), + lambda keys, values: numerical_dict(safe_zip(keys, values))) + + +class object_dict(dict): + """A dict to represent a dictionary of :py:class:`~.BrainPyObject`. + + .. note:: + Each key must be a string, and each value must be :py:class:`~.BrainPyObject`. + """ + def update(self, *args, **kwargs) -> 'object_dict': + super().update(*args, **kwargs) + return self + + +register_pytree_node(object_dict, + lambda x: (tuple(x.values()), tuple(x.keys())), + lambda keys, values: object_dict(safe_zip(keys, values))) + +dynamical_types = [Variable, + numerical_seq, numerical_dict, + object_seq, object_dict] + diff --git a/brainpy/tools/dicts.py b/brainpy/tools/dicts.py index 6dbd0c4dc..d177daefe 100644 --- a/brainpy/tools/dicts.py +++ b/brainpy/tools/dicts.py @@ -1,5 +1,8 @@ # -*- coding: utf-8 -*- + +from typing import Union, Dict, Sequence + import numpy as np from jax.tree_util import register_pytree_node from jax.util import safe_zip @@ -98,6 +101,119 @@ def update(self, *args, **kwargs): super().update(*args, **kwargs) return self + def __add__(self, other): + """Merging two dicts. + + Parameters + ---------- + other: dict + The other dict instance. + + Returns + ------- + gather: Collector + The new collector. + """ + gather = type(self)(self) + gather.update(other) + return gather + + def __sub__(self, other: Union[Dict, Sequence]): + """Remove other item in the collector. + + Parameters + ---------- + other: dict, sequence + The items to remove. + + Returns + ------- + gather: Collector + The new collector. + """ + if not isinstance(other, (dict, tuple, list)): + raise ValueError(f'Only support dict/tuple/list, but we got {type(other)}.') + gather = type(self)(self) + if isinstance(other, dict): + for key, val in other.items(): + if key in gather: + if id(val) != id(gather[key]): + raise ValueError(f'Cannot remove {key}, because we got two different values: ' + f'{val} != {gather[key]}') + gather.pop(key) + else: + raise ValueError(f'Cannot remove {key}, because we do not find it ' + f'in {self.keys()}.') + elif isinstance(other, (list, tuple)): + id_to_keys = {} + for k, v in self.items(): + id_ = id(v) + if id_ not in id_to_keys: + id_to_keys[id_] = [] + id_to_keys[id_].append(k) + + keys_to_remove = [] + for key in other: + if isinstance(key, str): + keys_to_remove.append(key) + else: + keys_to_remove.extend(id_to_keys[id(key)]) + + for key in set(keys_to_remove): + if key in gather: + gather.pop(key) + else: + raise ValueError(f'Cannot remove {key}, because we do not find it ' + f'in {self.keys()}.') + else: + raise KeyError(f'Unknown type of "other". Only support dict/tuple/list, but we got {type(other)}') + return gather + + def subset(self, var_type): + """Get the subset of the (key, value) pair. + + ``subset()`` can be used to get a subset of some class: + + >>> import brainpy as bp + >>> + >>> some_collector = Collector() + >>> + >>> # get all trainable variables + >>> some_collector.subset(bp.math.TrainVar) + >>> + >>> # get all Variable + >>> some_collector.subset(bp.math.Variable) + + or, it can be used to get a subset of integrators: + + >>> # get all ODE integrators + >>> some_collector.subset(bp.ode.ODEIntegrator) + + Parameters + ---------- + var_type : type + The type/class to match. 
+ """ + gather = type(self)() + for key, value in self.items(): + if isinstance(value, var_type): + gather[key] = value + return gather + + def unique(self): + """Get a new type of collector with unique values. + + If one value is assigned to two or more keys, + then only one pair of (key, value) will be returned. + """ + gather = type(self)() + seen = set() + for k, v in self.items(): + if id(v) not in seen: + seen.add(id(v)) + gather[k] = v + return gather + register_pytree_node( DotDict, diff --git a/docs/core_concept/brainpy_dynamical_system.ipynb b/docs/core_concept/brainpy_dynamical_system.ipynb index 8a78a9713..567a84b6d 100644 --- a/docs/core_concept/brainpy_dynamical_system.ipynb +++ b/docs/core_concept/brainpy_dynamical_system.ipynb @@ -21,7 +21,7 @@ { "cell_type": "markdown", "source": [ - "BrainPy supports models in brain simulation and brain-inspired computing.\n", + "BrainPy supports modelings in brain simulation and brain-inspired computing.\n", "\n", "All these supports are based on one common concept: **Dynamical System** via ``brainpy.DynamicalSystem``.\n", "\n", @@ -71,7 +71,26 @@ { "cell_type": "markdown", "source": [ - "All models used in brain simulation and brain-inspired computing is ``DynamicalSystem``.\n", + "All models used in brain simulation and brain-inspired computing is ``DynamicalSystem``.\n" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "```{note}\n", + "``DynamicalSystem`` is a subclass of ``BrainPyOject``. Therefore it supports to use [object-oriented transformations](./brainpy_transform_concept.ipynb) as stated in the previous tutorial.\n", + "```" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ "\n", "A ``DynamicalSystem`` defines the updating rule of the model at single time step.\n", "\n", @@ -117,6 +136,12 @@ "\n", "First, *all ``DynamicalSystem`` should implement ``.update()`` function*, which receives two arguments:\n", "\n", + "```\n", + "class YourModel(bp.DynamicalSystem):\n", + " def update(self, s, x):\n", + " pass\n", + "```\n", + "\n", "- `s` (or named as others): A dict, to indicate shared arguments across all nodes/layers in the network, like\n", " - the current time ``t``, or\n", " - the current running index ``i``, or\n", @@ -124,7 +149,7 @@ " - the current phase of training or testing ``fit=True/False``.\n", "- `x` (or named as others): The individual input for this node/layer.\n", "\n", - "We call `s` as shared arguments because they are shared and same for all nodes/layers at current time step. On the contrary, different nodes/layers have different input `x`." + "We call `s` as shared arguments because they are same and shared for all nodes/layers. On the contrary, different nodes/layers have different input `x`." 
], "metadata": { "collapsed": false @@ -180,7 +205,7 @@ " def update(self, s, x):\n", " # define how the model states update\n", " # according to the external input\n", - " t, dt = s.get('t'), s.get('dt', bm.dt)\n", + " t, dt = s.get('t'), s.get('dt')\n", " V = self.integral(self.V, t, x, dt=dt)\n", " spike = V >= self.V_th\n", " self.V.value = bm.where(spike, self.V_rest, V)\n", @@ -198,9 +223,9 @@ "\n", "Second, **explicitly consider which computing mode your ``DynamicalSystem`` supports**.\n", "\n", - "Brain simulation usually constructs models without batching dimension (we refer to it as *non-batching mode*, as seen in above LIF model), while brain-inspired computation trains models with a batch of data (*batching mode* or *training mode*).\n", + "Brain simulation usually builds models without batching dimension (we refer to it as *non-batching mode*, as seen in above LIF model), while brain-inspired computation trains models with a batch of data (*batching mode* or *training mode*).\n", "\n", - "So, to write a model applicable to the abroad applications in brain simulation and brain-inspired computing, you need to consider which mode your model supports, one of them, or both of them." + "So, to write a model applicable to abroad applications in brain simulation and brain-inspired computing, you need to consider which mode your model supports, one of them, or both of them." ], "metadata": { "collapsed": false @@ -213,13 +238,13 @@ "\n", "When considering the computing mode, we can program a general LIF model for brain simulation and brain-inspired computing.\n", "\n", - "To overcome the non-differential property of the spike in the LIF model for brain simulation, for the code\n", + "To overcome the non-differential property of the spike in the LIF model for brain simulation, i.e., at the code of\n", "\n", "```python\n", "spike = V >= self.V_th\n", "```\n", "\n", - "LIF models used in brain-inspired computing calculate the spiking state using the surrogate gradient function, i.e., replacing the backward gradient with a smooth function, like\n", + "LIF models used in brain-inspired computing calculate the spiking state using the surrogate gradient function. Usually, we replace the backward gradient of the spike with a smooth function, like\n", "\n", "$$\n", "g'(x) = \\frac{1}{(\\alpha * |x| + 1.) ^ 2}\n", @@ -289,6 +314,15 @@ "collapsed": false } }, + { + "cell_type": "markdown", + "source": [ + "The following code snippet utilizes the LIF model to build an E/I balanced network ``EINet``, which is a classical network model in brain simulation." + ], + "metadata": { + "collapsed": false + } + }, { "cell_type": "code", "execution_count": 21, @@ -327,7 +361,7 @@ { "cell_type": "markdown", "source": [ - "Here the ``EINet`` defines an E/I balanced network which is a classical network model in brain simulation. The following ``AINet`` utilizes the LIF model to construct a model for AI training." + "Moreover, our LIF model can also be used in brain-inspired computing scenario. The following ``AINet`` uses the LIF model to construct a model for AI training." ], "metadata": { "collapsed": false @@ -389,7 +423,7 @@ "source": [ "### 1. ``brainpy.math.for_loop``\n", "\n", - "``for_loop`` is a structural control flow API which runs a function with the looping over the inputs.\n", + "``for_loop`` is a structural control flow API which runs a function with the looping over the inputs. 
Moreover, this API just-in-time compile the looping process into the machine code.\n", "\n", "Suppose we have 200 time steps with the step size of 0.1, we can run the model with:" ], @@ -430,9 +464,9 @@ { "cell_type": "markdown", "source": [ - "### 2. ``brainpy.DSRunner`` and ``brainpy.DSTrainer``\n", + "### 2. ``brainpy.DSRunner``\n", "\n", - "Another way to run the model in BrainPy is using the structural running object ``DSRunner`` and ``DSTrainer``. They provide more flexible way to monitoring the variables in a ``DynamicalSystem``.\n" + "Another way to run the model in BrainPy is using the structural running object ``DSRunner`` and ``DSTrainer``. They provide more flexible way to monitoring the variables in a ``DynamicalSystem``. The details users should refer to the [DSRunner tutorial](../tutorial_simulation/simulation_dsrunner.ipynb).\n" ], "metadata": { "collapsed": false diff --git a/docs/core_concept/brainpy_transform_concept.ipynb b/docs/core_concept/brainpy_transform_concept.ipynb index 1a26b8fd5..ec309369a 100644 --- a/docs/core_concept/brainpy_transform_concept.ipynb +++ b/docs/core_concept/brainpy_transform_concept.ipynb @@ -556,31 +556,6 @@ "![](./imgs/grad_with_loss.png)" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "To inspect what OO transformations currently BrainPy supports, you can use" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": "['grad',\n 'vector_grad',\n 'jacobian',\n 'jacrev',\n 'jacfwd',\n 'hessian',\n 'make_loop',\n 'make_while',\n 'make_cond',\n 'cond',\n 'ifelse',\n 'for_loop',\n 'while_loop',\n 'jit',\n 'to_object']" - }, - "execution_count": 19, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "bm.object_transform.__all__" - ] - }, { "cell_type": "markdown", "metadata": {}, diff --git a/docs/quickstart/analysis.ipynb b/docs/quickstart/analysis.ipynb index 1314b8866..f09160b8d 100644 --- a/docs/quickstart/analysis.ipynb +++ b/docs/quickstart/analysis.ipynb @@ -339,7 +339,7 @@ "execution_count": 8, "outputs": [], "source": [ - "class GJCoupledFHN(bp.dyn.DynamicalSystem):\n", + "class GJCoupledFHN(bp.DynamicalSystem):\n", " def __init__(self, num=4, method='exp_auto'):\n", " super(GJCoupledFHN, self).__init__()\n", "\n", @@ -421,7 +421,7 @@ "\n", "# simulation with an input\n", "Iext = bm.asarray([0., 0., 0., 0.6])\n", - "runner = bp.dyn.DSRunner(model, monitors=['V'], inputs=['Iext', Iext])\n", + "runner = bp.DSRunner(model, monitors=['V'], inputs=['Iext', Iext])\n", "runner.run(300.)\n", "\n", "# visualization\n", diff --git a/docs/quickstart/training.ipynb b/docs/quickstart/training.ipynb index 255729d01..3f853a692 100644 --- a/docs/quickstart/training.ipynb +++ b/docs/quickstart/training.ipynb @@ -983,20 +983,18 @@ " self.num_out = num_out\n", "\n", " # neuron groups\n", - " self.i = bp.neurons.InputGroup(num_in, mode=bp.modes.training)\n", - " self.r = bp.neurons.LIF(num_rec, tau=10, V_reset=0, V_rest=0, V_th=1., mode=bp.modes.training)\n", - " self.o = bp.neurons.LeakyIntegrator(num_out, tau=5, mode=bp.modes.training)\n", + " self.i = bp.neurons.InputGroup(num_in)\n", + " self.r = bp.neurons.LIF(num_rec, tau=10, V_reset=0, V_rest=0, V_th=1.)\n", + " self.o = bp.neurons.LeakyIntegrator(num_out, tau=5)\n", "\n", " # synapse: i->r\n", " self.i2r = bp.synapses.Exponential(self.i, self.r, bp.conn.All2All(), tau=10.,\n", " output=bp.synouts.CUBA(target_var=None),\n", - " g_max=bp.init.KaimingNormal(scale=20.),\n", - " 
mode=bp.modes.training)\n", + " g_max=bp.init.KaimingNormal(scale=20.))\n", " # synapse: r->o\n", " self.r2o = bp.synapses.Exponential(self.r, self.o, bp.conn.All2All(), tau=10.,\n", " output=bp.synouts.CUBA(target_var=None),\n", - " g_max=bp.init.KaimingNormal(scale=20.),\n", - " mode=bp.modes.training)\n", + " g_max=bp.init.KaimingNormal(scale=20.))\n", "\n", " # whole model\n", " self.model = bp.Sequential(self.i, self.i2r, self.r, self.r2o, self.o)\n",
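
For reference, the mechanism behind ``numerical_seq`` and ``object_seq`` above is a ``list`` subclass registered as a JAX pytree node, so its elements become traceable leaves. The following is a minimal standalone sketch of that pattern, assuming only ``jax`` is installed; ``TracedList`` is a hypothetical stand-in name for the classes added in this patch:

```python
# Minimal sketch of the pytree-registered list pattern used by numerical_seq.
# `TracedList` is a hypothetical stand-in name, not part of this patch.
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node


class TracedList(list):
  """A list whose elements are treated as pytree leaves."""


register_pytree_node(TracedList,
                     lambda x: (tuple(x), ()),              # (children, aux_data)
                     lambda _, children: TracedList(children))

xs = TracedList([jnp.zeros(2), jnp.ones(2)])

# tree_map now reaches into the list, and jit/grad can trace its elements.
doubled = jax.tree_util.tree_map(lambda a: a * 2, xs)
print(type(doubled).__name__, doubled)  # TracedList [Array([0., 0.]), Array([2., 2.])]
```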
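The ``Collector`` arithmetic added in ``brainpy/tools/dicts.py`` composes as below. This is a usage sketch, assuming a build that includes this patch and that ``Collector`` is re-exported through ``brainpy.math`` as the new ``from .collector import *`` line suggests:

```python
# Usage sketch for the new Collector __add__/__sub__/subset/unique methods.
# Assumes a brainpy build containing this patch; the bm.Collector export path
# is inferred from the object_transform/__init__.py changes above.
import brainpy.math as bm

c = bm.Collector(a=bm.Variable(bm.zeros(3)),
                 b=bm.TrainVar(bm.ones(3)))

trainable = c.subset(bm.TrainVar)              # -> only the 'b' entry
merged = c + {'c': bm.Variable(bm.zeros(1))}   # __add__ copies self, then updates
rest = merged - ['b']                          # __sub__ removes by key or by value id
dedup = merged.unique()                        # keeps one key per distinct value
```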
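Finally, the effect of registering each ``BrainPyObject`` subclass via ``register_pytree_node_class`` in ``__init__``, together with the new ``tree_flatten`` rule, is that attributes of the dynamical types become visible to JAX while everything else stays static. A sketch of the intended behavior, again assuming this patch is installed and ``BrainPyObject`` is exported through ``brainpy.math``:

```python
# Sketch of the intended pytree behavior of BrainPyObject after this patch.
import jax
import brainpy.math as bm


class Model(bm.BrainPyObject):
  def __init__(self):
    super().__init__()
    self.v = bm.Variable(bm.zeros(2))  # dynamical type -> flattened as a leaf
    self.tau = 10.0                    # anything else -> kept as static data


m = Model()
# The Variable's underlying array appears among the leaves; `tau` does not.
print(jax.tree_util.tree_leaves(m))
```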