Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions brainpy/dyn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,8 +756,8 @@ def __init__(

def __repr__(self):
names = self.__class__.__name__
return (f'{names}(name={self.name}, mode={self.mode}, '
f'{" " * len(names)} pre={self.pre}, '
return (f'{names}(name={self.name}, mode={self.mode}, \n'
f'{" " * len(names)} pre={self.pre}, \n'
f'{" " * len(names)} post={self.post})')

def check_pre_attrs(self, *attrs):
Expand Down
4 changes: 0 additions & 4 deletions brainpy/dyn/layers/dropout.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,6 @@ class Dropout(DynamicalSystem):
In training, to compensate for the fraction of input values dropped (`rate`),
all surviving values are multiplied by `1 / (1 - rate)`.

The parameter `shared_axes` allows to specify a list of axes on which
the mask will be shared: we will use size 1 on those axes for dropout mask
and broadcast it. Sharing reduces randomness, but can save memory.

This layer is active only during training (`mode='train'`). In other
circumstances it is a no-op.

Expand Down
4 changes: 2 additions & 2 deletions brainpy/dyn/synapses/gap_junction.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ def __init__(
conn=conn,
name=name)
# checking
self.check_pre_attrs('V', 'spike')
self.check_post_attrs('V', 'input', 'spike')
self.check_pre_attrs('V')
self.check_post_attrs('V', 'input')

# assert isinstance(self.output, _NullSynOut)
# assert isinstance(self.stp, _NullSynSTP)
Expand Down
42 changes: 37 additions & 5 deletions brainpy/math/jaxarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -885,10 +885,42 @@ def __jax_array__(self):

class Variable(JaxArray):
"""The pointer to specify the dynamical variable.

Initializing an instance of ``Variable`` by two ways:

>>> import brainpy.math as bm
>>> # 1. init a Variable by the concreate data
>>> v1 = bm.Variable(bm.zeros(10))
>>> # 2. init a Variable by the data shape
>>> v2 = bm.Variable(10)

Note that when initializing a `Variable` by the data shape,
all values in this `Variable` will be initialized as zeros.

Parameters
----------
value_or_size: Shape, Array
The value or the size of the value.
dtype:
The type of the data.
batch_axis: optional, int
The batch axis.
"""
__slots__ = ('_value', '_batch_axis')

def __init__(self, value, dtype=None, batch_axis: int = None):
def __init__(
self,
value_or_size,
dtype=None,
batch_axis: int = None
):
if isinstance(value_or_size, int):
value = jnp.zeros(value_or_size, dtype=dtype)
elif isinstance(value_or_size, (tuple, list)) and all([isinstance(s, int) for s in value_or_size]):
value = jnp.zeros(value_or_size, dtype=dtype)
else:
value = value_or_size

super(Variable, self).__init__(value, dtype=dtype)

# check batch axis
Expand Down Expand Up @@ -1464,17 +1496,17 @@ class TrainVar(Variable):
"""
__slots__ = ('_value', '_batch_axis')

def __init__(self, value, dtype=None, batch_axis: int = None):
super(TrainVar, self).__init__(value, dtype=dtype, batch_axis=batch_axis)
def __init__(self, value_or_size, dtype=None, batch_axis: int = None):
super(TrainVar, self).__init__(value_or_size, dtype=dtype, batch_axis=batch_axis)


class Parameter(Variable):
"""The pointer to specify the parameter.
"""
__slots__ = ('_value', '_batch_axis')

def __init__(self, value, dtype=None, batch_axis: int = None):
super(Parameter, self).__init__(value, dtype=dtype, batch_axis=batch_axis)
def __init__(self, value_or_size, dtype=None, batch_axis: int = None):
super(Parameter, self).__init__(value_or_size, dtype=dtype, batch_axis=batch_axis)


register_pytree_node(JaxArray,
Expand Down
11 changes: 11 additions & 0 deletions brainpy/math/tests/test_jaxarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,14 @@ def test_none(self):
ee = a + e


class TestVariable(unittest.TestCase):
def test_variable_init(self):
self.assertTrue(
bm.array_equal(bm.Variable(bm.zeros(10)),
bm.Variable(10))
)
bm.random.seed(123)
self.assertTrue(
not bm.array_equal(bm.Variable(bm.random.rand(10)),
bm.Variable(10))
)
101 changes: 101 additions & 0 deletions examples/simulation/Fazli_2022_gj_coupled_bursting_pituitary_cells.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# -*- coding: utf-8 -*-


""""
Implementation of the paper:

- Fazli, Mehran, and Richard Bertram. "Network Properties of Electrically
Coupled Bursting Pituitary Cells." Frontiers in Endocrinology 13 (2022).
"""

import brainpy as bp
import brainpy.math as bm


class PituitaryCell(bp.NeuGroup):
def __init__(self, size, name=None):
super(PituitaryCell, self).__init__(size, name=name)

# parameter values
self.vn = -5
self.kc = 0.12
self.ff = 0.005
self.vca = 60
self.vk = -75
self.vl = -50.0
self.gk = 2.5
self.cm = 5
self.gbk = 1
self.gca = 2.1
self.gsk = 2
self.vm = -20
self.vb = -5
self.sn = 10
self.sm = 12
self.sbk = 2
self.taun = 30
self.taubk = 5
self.ks = 0.4
self.alpha = 0.0015
self.gl = 0.2

# variables
self.V = bm.Variable(bm.random.random(self.num) * -90 + 20)
self.n = bm.Variable(bm.random.random(self.num) / 2)
self.b = bm.Variable(bm.random.random(self.num) / 2)
self.c = bm.Variable(bm.random.random(self.num))
self.input = bm.Variable(self.num)

# integrators
self.integral = bp.odeint(bp.JointEq(self.dV, self.dn, self.dc, self.db), method='exp_euler')

def dn(self, n, t, V):
ninf = 1 / (1 + bm.exp((self.vn - V) / self.sn))
return (ninf - n) / self.taun

def db(self, b, t, V):
bkinf = 1 / (1 + bm.exp((self.vb - V) / self.sbk))
return (bkinf - b) / self.taubk

def dc(self, c, t, V):
minf = 1 / (1 + bm.exp((self.vm - V) / self.sm))
ica = self.gca * minf * (V - self.vca)
return -self.ff * (self.alpha * ica + self.kc * c)

def dV(self, V, t, n, b, c):
minf = 1 / (1 + bm.exp((self.vm - V) / self.sm))
cinf = c ** 2 / (c ** 2 + self.ks * self.ks)
ica = self.gca * minf * (V - self.vca)
isk = self.gsk * cinf * (V - self.vk)
ibk = self.gbk * b * (V - self.vk)
ikdr = self.gk * n * (V - self.vk)
il = self.gl * (V - self.vl)
return -(ica + isk + ibk + ikdr + il + self.input) / self.cm

def update(self, tdi, x=None):
V, n, c, b = self.integral(self.V.value, self.n.value, self.c.value, self.b.value, tdi.t, tdi.dt)
self.V.value = V
self.n.value = n
self.c.value = c
self.b.value = b

def clear_input(self):
self.input.value = bm.zeros_like(self.input)


class PituitaryNetwork(bp.Network):
def __init__(self, num, gc):
super(PituitaryNetwork, self).__init__()

self.N = PituitaryCell(num)
self.gj = bp.synapses.GapJunction(self.N, self.N, bp.conn.All2All(include_self=False), g_max=gc)


if __name__ == '__main__':
net = PituitaryNetwork(2, 0.002)
runner = bp.DSRunner(net, monitors={'V': net.N.V}, dt=0.5)
runner.run(10 * 1e3)

fig, gs = bp.visualize.get_figure(1, 1, 6, 10)
fig.add_subplot(gs[0, 0])
bp.visualize.line_plot(runner.mon.ts, runner.mon.V, plot_ids=(0, 1), show=True)
3 changes: 2 additions & 1 deletion extensions/brainpylib/atomic_sum.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,8 @@ def _atomic_sum_translation(c, values, pre_ids, post_ids, *, post_num, platform=
shape_with_layout=x_shape(np.dtype(values_dtype), (post_num,), (0,)),
)
elif platform == 'gpu':
if gpu_ops is None: raise ValueError('Cannot find compiled gpu wheels.')
if gpu_ops is None:
raise ValueError('Cannot find compiled gpu wheels.')

opaque = gpu_ops.build_atomic_sum_descriptor(conn_size, post_num)
if values_dim[0] != 1:
Expand Down
60 changes: 30 additions & 30 deletions extensions/brainpylib/event_sum.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,17 @@

from functools import partial

from typing import Union, Tuple
import jax.numpy as jnp
import numpy as np
from jax import core
from jax import core, dtypes
from jax.abstract_arrays import ShapedArray
from jax.interpreters import xla, batching
from jax.lax import scan
from jax.lib import xla_client

from .utils import GPUOperatorNotFound

try:
from . import gpu_ops
except ImportError:
Expand All @@ -26,7 +29,10 @@
_event_sum_prim = core.Primitive("event_sum")


def event_sum(events, pre2post, post_num, values):
def event_sum(events: jnp.ndarray,
pre2post: Tuple[jnp.ndarray, jnp.ndarray],
post_num: int,
values: Union[float, jnp.ndarray]):
# events
if events.dtype != jnp.bool_:
raise ValueError(f'"events" must be a vector of bool, while we got {events.dtype}')
Expand All @@ -39,17 +45,16 @@ def event_sum(events, pre2post, post_num, values):
if indices.dtype != indptr.dtype:
raise ValueError(f"The dtype of pre2post[0] must be equal to that of pre2post[1], "
f"while we got {(indices.dtype, indptr.dtype)}")
if indices.dtype not in [jnp.uint32, jnp.uint64]:
raise ValueError(f'The dtype of pre2post must be uint32 or uint64, while we got {indices.dtype}')
if indices.dtype not in [jnp.uint32, jnp.uint64, jnp.int32, jnp.int64]:
raise ValueError(f'The dtype of pre2post must be integer, while we got {indices.dtype}')

# output value
values = jnp.asarray([values])
if values.dtype not in [jnp.float32, jnp.float64]:
raise ValueError(f'The dtype of "values" must be float32 or float64, while we got {values.dtype}.')
if values.size not in [1, indices.size]:
dtype = values.dtype if isinstance(values, jnp.ndarray) else dtypes.canonicalize_dtype(type(values))
if dtype not in [jnp.float32, jnp.float64]:
raise ValueError(f'The dtype of "values" must be float32 or float64, while we got {dtype}.')
if np.size(values) not in [1, indices.size]:
raise ValueError(f'The size of "values" must be 1 (a scalar) or len(pre2post[0]) (a vector), '
f'while we got {values.size} != 1 != {indices.size}')
values = values.flatten()
f'while we got {np.size(values)} != 1 != {indices.size}')
# bind operator
return _event_sum_prim.bind(events, indices, indptr, values, post_num=post_num)

Expand All @@ -58,34 +63,27 @@ def _event_sum_abstract(events, indices, indptr, values, *, post_num):
return ShapedArray(dtype=values.dtype, shape=(post_num,))


_event_sum_prim.def_abstract_eval(_event_sum_abstract)
_event_sum_prim.def_impl(partial(xla.apply_primitive, _event_sum_prim))


def _event_sum_translation(c, events, indices, indptr, values, *, post_num, platform="cpu"):
# The pre/post shape
# The shape of pre/post
pre_size = np.array(c.get_shape(events).dimensions()[0], dtype=np.uint32)
_pre_shape = x_shape(np.dtype(np.uint32), (), ())
_post_shape = x_shape(np.dtype(np.uint32), (), ())

# The indices shape
indices_shape = c.get_shape(indices)
Itype = indices_shape.element_type()
assert Itype in [np.uint32, np.uint64]

# The value shape
values_shape = c.get_shape(values)
Ftype = values_shape.element_type()
assert Ftype in [np.float32, np.float64]
values_dim = values_shape.dimensions()

# We dispatch a different call depending on the dtype
f_type = b'_f32' if Ftype == np.float32 else b'_f64'
i_type = b'_i32' if Itype == np.uint32 else b'_i64'
f_type = b'_f32' if Ftype in np.float32 else b'_f64'
i_type = b'_i32' if Itype in [np.uint32, np.int32] else b'_i64'

# And then the following is what changes between the GPU and CPU
if platform == "cpu":
v_type = b'_event_sum_homo' if values_dim[0] == 1 else b'_event_sum_heter'
v_type = b'_event_sum_homo' if len(values_dim) == 0 else b'_event_sum_heter'
return x_ops.CustomCallWithLayout(
c,
platform.encode() + v_type + f_type + i_type,
Expand All @@ -103,9 +101,12 @@ def _event_sum_translation(c, events, indices, indptr, values, *, post_num, plat
c.get_shape(values)),
shape_with_layout=x_shape(np.dtype(Ftype), (post_num,), (0,)),
)

# GPU platform
elif platform == 'gpu':
if gpu_ops is None:
raise ValueError('Cannot find compiled gpu wheels.')
raise GPUOperatorNotFound('event_sum')

v_type = b'_event_sum_homo' if values_dim[0] == 1 else b'_event_sum_heter'
opaque = gpu_ops.build_event_sum_descriptor(pre_size, post_num)
return x_ops.CustomCallWithLayout(
Expand All @@ -127,11 +128,7 @@ def _event_sum_translation(c, events, indices, indptr, values, *, post_num, plat
raise ValueError("Unsupported platform, we only support 'cpu' or 'gpu'")


xla.backend_specific_translations["cpu"][_event_sum_prim] = partial(_event_sum_translation, platform="cpu")
xla.backend_specific_translations["gpu"][_event_sum_prim] = partial(_event_sum_translation, platform="gpu")


def _event_sum_batch(args, axes):
def _event_sum_batch(args, axes, *, post_num):
batch_axes, batch_args, non_batch_args = [], {}, {}
for ax_i, ax in enumerate(axes):
if ax is None:
Expand All @@ -143,19 +140,22 @@ def _event_sum_batch(args, axes):
def f(_, x):
pars = tuple([(x[f'ax{i}'] if i in batch_axes else non_batch_args[f'ax{i}'])
for i in range(len(axes))])
return 0, _event_sum_prim.bind(*pars)
return 0, _event_sum_prim.bind(*pars, post_num=post_num)

_, outs = scan(f, 0, batch_args)
return outs, 0


_event_sum_prim.def_abstract_eval(_event_sum_abstract)
_event_sum_prim.def_impl(partial(xla.apply_primitive, _event_sum_prim))
batching.primitive_batchers[_event_sum_prim] = _event_sum_batch

xla.backend_specific_translations["cpu"][_event_sum_prim] = partial(_event_sum_translation, platform="cpu")
xla.backend_specific_translations["gpu"][_event_sum_prim] = partial(_event_sum_translation, platform="gpu")

# ---------------------------
# event sum kernel 2
# ---------------------------


_event_sum2_prim = core.Primitive("event_sum2")


Expand Down
18 changes: 18 additions & 0 deletions extensions/brainpylib/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# -*- coding: utf-8 -*-


__all__ = [
'GPUOperatorNotFound',
]


class GPUOperatorNotFound(Exception):
def __init__(self, name):
super(GPUOperatorNotFound, self).__init__(f'''
GPU operator for "{name}" does not found.

Please compile brainpylib GPU operators with the guidance in the following link:

https://brainpy.readthedocs.io/en/latest/tutorial_advanced/compile_brainpylib.html
''')

Loading