Skip to content

Commit

Permalink
brainpy.math.defjvp and brainpy.math.XLACustomOp.defjvp (#554)
Browse files Browse the repository at this point in the history
* [running] fix multiprocessing bugs

* fix tests

* [doc] update doc

* update

* [math] add `brainpy.math.gpu_memory_preallocation()` for controlling GPU memory preallocation

* [math] `clear_buffer_memory` support to clear array and compilation both

* [dyn] compatible old version of `.reset_state()` function

* [setup] update installation info

* [install] upgrade dependency

* updates

* [math] add `brainpy.math.defjvp`, support to define jvp rules for Primitive with multiple results. See examples in `test_ad_support.py`

* [math] add `brainpy.math.XLACustomOp.defjvp`

* [doc] upgrade `brainpy.math.defjvp` docstring
  • Loading branch information
chaoming0625 committed Dec 4, 2023
1 parent 6c599a7 commit 8c28685
Show file tree
Hide file tree
Showing 10 changed files with 423 additions and 46 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
</p>


BrainPy is a flexible, efficient, and extensible framework for computational neuroscience and brain-inspired computation based on the Just-In-Time (JIT) compilation (built on top of [JAX](https://github.com/google/jax), [Numba](https://github.com/numba/numba), and other JIT compilers). It provides an integrative ecosystem for brain dynamics programming, including brain dynamics **building**, **simulation**, **training**, **analysis**, etc.
BrainPy is a flexible, efficient, and extensible framework for computational neuroscience and brain-inspired computation based on the Just-In-Time (JIT) compilation (built on top of [JAX](https://github.com/google/jax), [Taichi](https://github.com/taichi-dev/taichi), [Numba](https://github.com/numba/numba), and others). It provides an integrative ecosystem for brain dynamics programming, including brain dynamics **building**, **simulation**, **training**, **analysis**, etc.

- **Website (documentation and APIs)**: https://brainpy.readthedocs.io/en/latest
- **Source**: https://github.com/brainpy/BrainPy
Expand Down Expand Up @@ -77,6 +77,7 @@ We provide a Binder environment for BrainPy. You can use the following button to
- **[BrainPy](https://github.com/brainpy/BrainPy)**: The solution for the general-purpose brain dynamics programming.
- **[brainpy-examples](https://github.com/brainpy/examples)**: Comprehensive examples of BrainPy computation.
- **[brainpy-datasets](https://github.com/brainpy/datasets)**: Neuromorphic and Cognitive Datasets for Brain Dynamics Modeling.
- [《神经计算建模实战》 (Neural Modeling in Action)](https://github.com/c-xy17/NeuralModeling)
- [第一届神经计算建模与编程培训班 (First Training Course on Neural Modeling and Programming)](https://github.com/brainpy/1st-neural-modeling-and-programming-course)


Expand Down
4 changes: 2 additions & 2 deletions brainpy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

__version__ = "2.4.6"
__version__ = "2.4.6.post2"

# fundamental supporting modules
from brainpy import errors, check, tools
Expand Down Expand Up @@ -75,7 +75,7 @@
)
NeuGroup = NeuGroupNS = dyn.NeuDyn

# shared parameters
# common tools
from brainpy._src.context import (share as share)
from brainpy._src.helpers import (reset_state as reset_state,
save_state as save_state,
Expand Down
38 changes: 19 additions & 19 deletions brainpy/_src/dependency_check.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,41 @@
import os
import sys
from jax.lib import xla_client


__all__ = [
'import_taichi',
'import_brainpylib_cpu_ops',
'import_brainpylib_gpu_ops',
]


_minimal_brainpylib_version = '0.1.10'
_minimal_taichi_version = (1, 7, 0)

taichi = None
brainpylib_cpu_ops = None
brainpylib_gpu_ops = None

taichi_install_info = (f'We need taichi=={_minimal_taichi_version}. '
f'Currently you can install taichi=={_minimal_taichi_version} through:\n\n'
'> pip install taichi==1.7.0 -U')
os.environ["TI_LOG_LEVEL"] = "error"


def import_taichi():
global taichi
if taichi is None:
try:
import taichi as taichi # noqa
except ModuleNotFoundError:
raise ModuleNotFoundError(
'Taichi is needed. Please install taichi through:\n\n'
'> pip install -i https://pypi.taichi.graphics/simple/ taichi-nightly'
)

if taichi.__version__ < _minimal_taichi_version:
raise RuntimeError(
f'We need taichi>={_minimal_taichi_version}. '
f'Currently you can install taichi>={_minimal_taichi_version} through taichi-nightly:\n\n'
'> pip install -i https://pypi.taichi.graphics/simple/ taichi-nightly'
)
with open(os.devnull, 'w') as devnull:
old_stdout = sys.stdout
sys.stdout = devnull
try:
import taichi as taichi # noqa
except ModuleNotFoundError:
raise ModuleNotFoundError(taichi_install_info)
finally:
sys.stdout = old_stdout

if taichi.__version__ != _minimal_taichi_version:
raise RuntimeError(taichi_install_info)
return taichi


Expand Down Expand Up @@ -82,6 +85,3 @@ def import_brainpylib_gpu_ops():
'See https://brainpy.readthedocs.io for installation instructions.')

return brainpylib_gpu_ops



198 changes: 194 additions & 4 deletions brainpy/_src/dyn/projections/plasticity.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ class STDP_Song2000(Projection):
\begin{aligned}
\frac{dw}{dt} & = & -A_{post}\delta(t-t_{sp}) + A_{pre}\delta(t-t_{sp}), \\
\frac{dA_{pre}}{dt} & = & -\frac{A_{pre}}{\tau_s}+A_1\delta(t-t_{sp}), \\
\frac{dA_{post}}{dt} & = & -\frac{A_{post}}{\tau_t}+A_2\delta(t-t_{sp}), \\
\frac{dA_{pre}}{dt} & = & -\frac{A_{pre}}{\tau_s} + A_1\delta(t-t_{sp}), \\
\frac{dA_{post}}{dt} & = & -\frac{A_{post}}{\tau_t} + A_2\delta(t-t_{sp}), \\
\end{aligned}
where :math:`t_{sp}` denotes the spike time and :math:`A_1` is the increment
Expand All @@ -64,8 +64,8 @@ class STDP_Song2000(Projection):
class STDPNet(bp.DynamicalSystem):
def __init__(self, num_pre, num_post):
super().__init__()
self.pre = bp.dyn.LifRef(num_pre, name='neu1')
self.post = bp.dyn.LifRef(num_post, name='neu2')
self.pre = bp.dyn.LifRef(num_pre)
self.post = bp.dyn.LifRef(num_post)
self.syn = bp.dyn.STDP_Song2000(
pre=self.pre,
delay=1.,
Expand Down Expand Up @@ -219,3 +219,193 @@ def update(self):
return current


# class PairedSTDP(Projection):
# r"""Paired spike-time-dependent plasticity model.
#
# This model filters the synaptic currents according to the variables: :math:`w`.
#
# .. math::
#
# I_{syn}^+(t) = I_{syn}^-(t) * w
#
# where :math:`I_{syn}^-(t)` and :math:`I_{syn}^+(t)` are the synaptic currents before
# and after STDP filtering, :math:`w` measures synaptic efficacy because each time a presynaptic neuron emits a pulse,
# the conductance of the synapse will increase by :math:`w`.
#
# The dynamics of :math:`w` is governed by the following equation:
#
# .. math::
#
# \begin{aligned}
# \frac{dw}{dt} & = & -A_{post}\delta(t-t_{sp}) + A_{pre}\delta(t-t_{sp}), \\
# \frac{dA_{pre}}{dt} & = & -\frac{A_{pre}}{\tau_s} + A_1\delta(t-t_{sp}), \\
# \frac{dA_{post}}{dt} & = & -\frac{A_{post}}{\tau_t} + A_2\delta(t-t_{sp}), \\
# \end{aligned}
#
# where :math:`t_{sp}` denotes the spike time and :math:`A_1` is the increment
# of :math:`A_{pre}`, :math:`A_2` is the increment of :math:`A_{post}` produced by a spike.
#
# Here is an example of the usage of this class::
#
# import brainpy as bp
# import brainpy.math as bm
#
# class STDPNet(bp.DynamicalSystem):
# def __init__(self, num_pre, num_post):
# super().__init__()
# self.pre = bp.dyn.LifRef(num_pre)
# self.post = bp.dyn.LifRef(num_post)
# self.syn = bp.dyn.STDP_Song2000(
# pre=self.pre,
# delay=1.,
# comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(1, pre=self.pre.num, post=self.post.num),
# weight=bp.init.Uniform(max_val=0.1)),
# syn=bp.dyn.Expon.desc(self.post.varshape, tau=5.),
# out=bp.dyn.COBA.desc(E=0.),
# post=self.post,
# tau_s=16.8,
# tau_t=33.7,
# A1=0.96,
# A2=0.53,
# )
#
# def update(self, I_pre, I_post):
# self.syn()
# self.pre(I_pre)
# self.post(I_post)
# conductance = self.syn.refs['syn'].g
# Apre = self.syn.refs['pre_trace'].g
# Apost = self.syn.refs['post_trace'].g
# current = self.post.sum_inputs(self.post.V)
# return self.pre.spike, self.post.spike, conductance, Apre, Apost, current, self.syn.comm.weight
#
# duration = 300.
# I_pre = bp.inputs.section_input([0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0],
# [5, 15, 15, 15, 15, 15, 100, 15, 15, 15, 15, 15, duration - 255])
# I_post = bp.inputs.section_input([0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0],
# [10, 15, 15, 15, 15, 15, 90, 15, 15, 15, 15, 15, duration - 250])
#
# net = STDPNet(1, 1)
# def run(i, I_pre, I_post):
# pre_spike, post_spike, g, Apre, Apost, current, W = net.step_run(i, I_pre, I_post)
# return pre_spike, post_spike, g, Apre, Apost, current, W
#
# indices = bm.arange(0, duration, bm.dt)
# pre_spike, post_spike, g, Apre, Apost, current, W = bm.for_loop(run, [indices, I_pre, I_post])
#
# Args:
# tau_s: float. The time constant of :math:`A_{pre}`.
# tau_t: float. The time constant of :math:`A_{post}`.
# A1: float. The increment of :math:`A_{pre}` produced by a spike. Must be a positive value.
# A2: float. The increment of :math:`A_{post}` produced by a spike. Must be a positive value.
# W_max: float. The maximum weight.
# W_min: float. The minimum weight.
# pre: DynamicalSystem. The pre-synaptic neuron group.
# delay: int, float. The pre spike delay length. (ms)
# syn: DynamicalSystem. The synapse model.
# comm: DynamicalSystem. The communication model, for example, dense or sparse connection layers.
# out: DynamicalSystem. The synaptic current output models.
# post: DynamicalSystem. The post-synaptic neuron group.
# out_label: str. The output label.
# name: str. The model name.
# """
#
# def __init__(
# self,
# pre: JointType[DynamicalSystem, SupportAutoDelay],
# delay: Union[None, int, float],
# syn: ParamDescriber[DynamicalSystem],
# comm: JointType[DynamicalSystem, SupportSTDP],
# out: ParamDescriber[JointType[DynamicalSystem, BindCondData]],
# post: DynamicalSystem,
# # synapse parameters
# tau_s: float = 16.8,
# tau_t: float = 33.7,
# lambda_: float = 0.96,
# alpha: float = 0.53,
# mu: float = 0.53,
# W_max: Optional[float] = None,
# W_min: Optional[float] = None,
# # others
# out_label: Optional[str] = None,
# name: Optional[str] = None,
# mode: Optional[bm.Mode] = None,
# ):
# super().__init__(name=name, mode=mode)
#
# # synaptic models
# check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay])
# check.is_instance(comm, JointType[DynamicalSystem, SupportSTDP])
# check.is_instance(syn, ParamDescriber[DynamicalSystem])
# check.is_instance(out, ParamDescriber[JointType[DynamicalSystem, BindCondData]])
# check.is_instance(post, DynamicalSystem)
# self.pre_num = pre.num
# self.post_num = post.num
# self.comm = comm
# self._is_align_post = issubclass(syn.cls, AlignPost)
#
# # delay initialization
# delay_cls = register_delay_by_return(pre)
# delay_cls.register_entry(self.name, delay)
#
# # synapse and output initialization
# if self._is_align_post:
# syn_cls, out_cls = align_post_add_bef_update(out_label, syn_desc=syn, out_desc=out, post=post,
# proj_name=self.name)
# else:
# syn_cls = align_pre2_add_bef_update(syn, delay, delay_cls, self.name + '-pre')
# out_cls = out()
# add_inp_fun(out_label, self.name, out_cls, post)
#
# # references
# self.refs = dict(pre=pre, post=post) # invisible to ``self.nodes()``
# self.refs['delay'] = delay_cls
# self.refs['syn'] = syn_cls # invisible to ``self.node()``
# self.refs['out'] = out_cls # invisible to ``self.node()``
# self.refs['comm'] = comm
#
# # tracing pre-synaptic spikes using Exponential model
# self.refs['pre_trace'] = _init_trace_by_align_pre2(pre, delay, Expon.desc(pre.num, tau=tau_s))
#
# # tracing post-synaptic spikes using Exponential model
# self.refs['post_trace'] = _init_trace_by_align_pre2(post, None, Expon.desc(post.num, tau=tau_t))
#
# # synapse parameters
# self.W_max = W_max
# self.W_min = W_min
# self.tau_s = tau_s
# self.tau_t = tau_t
# self.A1 = A1
# self.A2 = A2
#
# def update(self):
# # pre-synaptic spikes
# pre_spike = self.refs['delay'].at(self.name) # spike
# # pre-synaptic variables
# if self._is_align_post:
# # For AlignPost, we need "pre spikes @ comm matrix" for computing post-synaptic conductance
# x = pre_spike
# else:
# # For AlignPre, we need the "pre synapse variable @ comm matrix" for computing post conductance
# x = _get_return(self.refs['syn'].return_info()) # pre-synaptic variable
#
# # post spikes
# if not hasattr(self.refs['post'], 'spike'):
# raise AttributeError(f'{self} needs a "spike" variable for the post-synaptic neuron group.')
# post_spike = self.refs['post'].spike
#
# # weight updates
# Apost = self.refs['post_trace'].g
# self.comm.stdp_update(on_pre={"spike": pre_spike, "trace": -Apost * self.A2}, w_min=self.W_min, w_max=self.W_max)
# Apre = self.refs['pre_trace'].g
# self.comm.stdp_update(on_post={"spike": post_spike, "trace": Apre * self.A1}, w_min=self.W_min, w_max=self.W_max)
#
# # synaptic currents
# current = self.comm(x)
# if self._is_align_post:
# self.refs['syn'].add_current(current) # synapse post current
# else:
# self.refs['out'].bind_cond(current) # align pre
# return current


52 changes: 52 additions & 0 deletions brainpy/_src/math/op_register/ad_support.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import functools
from functools import partial

from jax import tree_util
from jax.core import Primitive
from jax.interpreters import ad

__all__ = [
'defjvp',
]


def defjvp(primitive, *jvp_rules):
    """Define JVP rules for any JAX primitive.

    This function is similar to ``jax.interpreters.ad.defjvp``.
    However, the JAX built-in one only supports primitives with
    ``multiple_results=False``. ``brainpy.math.defjvp`` enables to define an
    independent JVP rule for each input parameter no matter whether
    ``multiple_results`` is ``False`` or ``True``.

    For examples, please see ``test_ad_support.py``.

    Args:
      primitive: Primitive. The JAX primitive to register the JVP rules for.
      *jvp_rules: The JVP translation rule for each primal. A rule may be
        ``None`` to indicate that the corresponding input has no tangent
        contribution.

    Raises:
      AssertionError: If ``primitive`` is not a JAX ``Primitive`` instance.
    """
    # Explicit raise instead of a bare ``assert`` so the check survives ``-O``.
    if not isinstance(primitive, Primitive):
        raise AssertionError(f'The primitive should be an instance of jax Primitive, but got {primitive}')
    if primitive.multiple_results:
        # JAX's ad.standard_jvp only handles single-output primitives,
        # so multi-output primitives use our own generic implementation.
        ad.primitive_jvps[primitive] = partial(_standard_jvp, jvp_rules, primitive)
    else:
        ad.primitive_jvps[primitive] = partial(ad.standard_jvp, jvp_rules, primitive)


def _standard_jvp(jvp_rules, primitive: Primitive, primals, tangents, **params):
    """Generic JVP implementation for a primitive with ``multiple_results=True``.

    Binds the primitive on the primals to get the primal outputs, evaluates
    each per-input JVP rule on its tangent, and sums every rule's output into
    the final output tangents (starting from symbolic zeros).
    """
    assert primitive.multiple_results
    outs = tuple(primitive.bind(*primals, **params))
    out_tree = tree_util.tree_structure(outs)
    rule_outputs = []
    for rule, tangent in zip(jvp_rules, tangents):
        # A ``None`` rule or a symbolic-zero tangent contributes nothing.
        if rule is None or type(tangent) is ad.Zero:
            continue
        contribution = tuple(rule(tangent, *primals, **params))
        rule_outputs.append(contribution)
        # Each rule must produce tangents matching the primal output structure.
        assert tree_util.tree_structure(contribution) == out_tree
    zero_tangents = tree_util.tree_map(lambda a: ad.Zero.from_value(a), outs)
    return outs, functools.reduce(_add_tangents, rule_outputs, zero_tangents)


def _add_tangents(xs, ys):
return tree_util.tree_map(ad.add_tangents, xs, ys, is_leaf=lambda a: isinstance(a, ad.Zero))

0 comments on commit 8c28685

Please sign in to comment.