diff --git a/brainpy/connect/base.py b/brainpy/connect/base.py
index 904e96ff2..5ddff168a 100644
--- a/brainpy/connect/base.py
+++ b/brainpy/connect/base.py
@@ -116,13 +116,13 @@ def build_conn(self):
     import brainpy as bp
     class MyConnector(bp.conn.TwoEndConnector):
-      def build_mat(self, pre_size, post_size):
+      def build_mat(self):
         return conn_matrix
-      def build_csr(self, pre_size, post_size):
+      def build_csr(self):
         return post_ids, inptr
-      def build_coo(self, pre_size, post_size):
+      def build_coo(self):
         return pre_ids, post_ids
   """
@@ -196,8 +196,6 @@ def check(self, structures: Union[Tuple, List, str]):
         raise ConnectorError(f'Unknown synapse structure "{n}". '
                              f'Only {SUPPORTED_SYN_STRUCTURE} is supported.')
-
-
   def _return_by_mat(self, structures, mat, all_data: dict):
     assert mat.ndim == 2
     if (CONN_MAT in structures) and (CONN_MAT not in all_data):
@@ -332,70 +330,56 @@ def build_conn(self):
     """
     pass

-  def require(self, *sizes_or_structures):
-    sizes_or_structures = list(sizes_or_structures)
-    pre_size = sizes_or_structures.pop(0) if len(sizes_or_structures) >= 1 else None
-    post_size = sizes_or_structures.pop(0) if len(sizes_or_structures) >= 1 else None
-    structures = sizes_or_structures
-    if isinstance(post_size, str):
-      structures.insert(0, post_size)
-      post_size = None
-    if isinstance(pre_size, str):
-      structures.insert(0, pre_size)
-      pre_size = None
-
-    version2_style = (pre_size is not None) and (post_size is not None)
-    if not version2_style:
-      try:
-        assert self.pre_num is not None and self.post_num is not None
-      except AssertionError:
-        raise ConnectorError(f'self.pre_num or self.post_num is not defined. '
-                             f'Please use self.__call__(pre_size, post_size) '
-                             f'before requiring connection data.')
-    if pre_size is None:
-      pre_size = self.pre_size
-    if post_size is None:
-      post_size = self.post_size
+  def require(self, *structures):
+    try:
+      assert self.pre_num is not None and self.post_num is not None
+    except AssertionError:
+      raise ConnectorError(f'self.pre_num or self.post_num is not defined. '
+                           f'Please use self.__call__() '
+                           f'before requiring connection data.')
     self.check(structures)
     if self.is_version2_style:
       if len(structures) == 1:
         if PRE2POST in structures and not hasattr(self.build_csr, 'not_customized'):
-          return self.build_csr(pre_size, post_size)
+          r = self.build_csr()
+          return bm.asarray(r[0], dtype=IDX_DTYPE), bm.asarray(r[1], dtype=IDX_DTYPE)
         elif CONN_MAT in structures and not hasattr(self.build_mat, 'not_customized'):
-          return self.build_mat(pre_size, post_size)
+          return bm.asarray(self.build_mat(), dtype=MAT_DTYPE)
         elif PRE_IDS in structures and not hasattr(self.build_coo, 'not_customized'):
-          return self.build_coo(pre_size, post_size)[0]
+          return bm.asarray(self.build_coo()[0], dtype=IDX_DTYPE)
         elif POST_IDS in structures and not hasattr(self.build_coo, 'not_customized'):
-          return self.build_coo(pre_size, post_size)[1]
+          return bm.asarray(self.build_coo()[1], dtype=IDX_DTYPE)
       elif len(structures) == 2:
         if PRE_IDS in structures and POST_IDS in structures and not hasattr(self.build_coo, 'not_customized'):
-          return self.build_coo(pre_size, post_size)
+          r = self.build_coo()
+          return bm.asarray(r[0], dtype=IDX_DTYPE), bm.asarray(r[1], dtype=IDX_DTYPE)

     conn_data = dict(csr=None, ij=None, mat=None)
     if not hasattr(self.build_coo, 'not_customized'):
-      conn_data['ij'] = self.build_coo(pre_size, post_size)
+      conn_data['ij'] = self.build_coo()
     elif not hasattr(self.build_csr, 'not_customized'):
-      conn_data['csr'] = self.build_csr(pre_size, post_size)
+      conn_data['csr'] = self.build_csr()
     elif not hasattr(self.build_mat, 'not_customized'):
-      conn_data['mat'] = self.build_mat(pre_size, post_size)
+      conn_data['mat'] = self.build_mat()
+    else:
       conn_data = self.build_conn()
     return self.make_returns(structures, conn_data)

-  def requires(self, *sizes_or_structures):
-    return self.require(*sizes_or_structures)
+  def requires(self, *structures):
+    return self.require(*structures)

   @tools.not_customized
-  def build_mat(self, pre_size=None, post_size=None):
+  def build_mat(self):
     pass

   @tools.not_customized
-  def build_csr(self, pre_size=None, post_size=None):
+  def build_csr(self):
     pass

   @tools.not_customized
-  def build_coo(self, pre_size=None, post_size=None):
+  def build_coo(self):
     pass
@@ -425,7 +409,6 @@ def __call__(self, pre_size, post_size=None):
     else:
       post_size = tuple(post_size)
     self.pre_size, self.post_size = pre_size, post_size
-
     self.pre_num = tools.size2num(self.pre_size)
     self.post_num = tools.size2num(self.post_size)
     return self
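With this refactor a connector binds its sizes once through `__call__` and `require()` receives only structure names. A minimal sketch of the intended call pattern (sizes, probability, and seed are illustrative; it mirrors the updated tests further below):

    import brainpy as bp

    conn = bp.connect.FixedProb(prob=0.1, seed=123)
    conn(pre_size=(10, 20), post_size=(10, 20))          # bind sizes first
    mat = conn.require(bp.connect.CONN_MAT)              # shape (200, 200)
    mat = conn(100, 1000).require(bp.connect.CONN_MAT)   # re-bind, shape (100, 1000)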
diff --git a/brainpy/connect/custom_conn.py b/brainpy/connect/custom_conn.py
index 7cac4cd72..e452061e5 100644
--- a/brainpy/connect/custom_conn.py
+++ b/brainpy/connect/custom_conn.py
@@ -7,7 +7,6 @@
 from brainpy import tools
 from brainpy.errors import ConnectorError
 from .base import *
-from .utils import *

 __all__ = [
   'MatConn',
@@ -34,11 +33,9 @@ def __call__(self, pre_size, post_size):
     assert self.post_num == tools.size2num(post_size)
     return self

-  def build_mat(self, pre_size=None, post_size=None):
-    pre_num = get_pre_num(self, pre_size)
-    post_num = get_post_num(self, post_size)
-    assert self.conn_mat.shape[0] == pre_num
-    assert self.conn_mat.shape[1] == post_num
+  def build_mat(self):
+    assert self.conn_mat.shape[0] == self.pre_num
+    assert self.conn_mat.shape[1] == self.post_num
     return self.conn_mat

@@ -68,14 +65,12 @@ def __call__(self, pre_size, post_size):
                            f'the maximum id ({self.max_post}) of self.post_ids.')
     return self

-  def build_coo(self, pre_size=None, post_size=None):
-    pre_num = get_pre_num(self, pre_size)
-    post_num = get_post_num(self, post_size)
-    if pre_num <= self.max_pre:
-      raise ConnectorError(f'pre_num ({pre_num}) should be greater than '
+  def build_coo(self):
+    if self.pre_num <= self.max_pre:
+      raise ConnectorError(f'pre_num ({self.pre_num}) should be greater than '
                            f'the maximum id ({self.max_pre}) of self.pre_ids.')
-    if post_num <= self.max_post:
-      raise ConnectorError(f'post_num ({post_num}) should be greater than '
+    if self.post_num <= self.max_post:
+      raise ConnectorError(f'post_num ({self.post_num}) should be greater than '
                            f'the maximum id ({self.max_post}) of self.post_ids.')
     return self.pre_ids, self.post_ids

@@ -91,16 +86,12 @@ def __init__(self, indices, inptr):
     self.pre_num = self.inptr.size - 1
     self.max_post = bm.max(self.indices)

-  def build_csr(self, pre_size=None, post_size=None):
-    pre_size = get_pre_size(self, pre_size)
-    post_size = get_post_size(self, post_size)
-    pre_num = np.prod(pre_size)
-    post_num = np.prod(post_size)
-    if pre_num != self.pre_num:
+  def build_csr(self):
+    if self.pre_num != self.inptr.size - 1:
       raise ConnectorError(f'(pre_size, post_size) is inconsistent with '
                            f'the shape of the sparse matrix.')
-    if post_num <= self.max_post:
-      raise ConnectorError(f'post_num ({post_num}) should be greater than '
+    if self.post_num <= self.max_post:
+      raise ConnectorError(f'post_num ({self.post_num}) should be greater than '
                            f'the maximum id ({self.max_post}) of self.post_ids.')
     return self.indices, self.inptr
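The custom connectors now validate against the sizes stored on the instance rather than per-call arguments. A small usage sketch (ids and sizes are illustrative; `IJConn`'s positional arguments follow the class shown above):

    import numpy as np
    import brainpy as bp

    conn = bp.connect.IJConn(np.array([0, 1, 2]), np.array([0, 0, 1]))
    conn = conn(pre_size=5, post_size=3)   # sizes are checked against the max ids
    pre_ids, post_ids = conn.build_coo()   # no size arguments any more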
diff --git a/brainpy/connect/random_conn.py b/brainpy/connect/random_conn.py
index 4daf894f8..bc0eafa33 100644
--- a/brainpy/connect/random_conn.py
+++ b/brainpy/connect/random_conn.py
@@ -1,14 +1,11 @@
 # -*- coding: utf-8 -*-

-import jax
 import numpy as np
 from typing import Optional

-from brainpy import math as bm
 from brainpy.errors import ConnectorError
-from brainpy.tools.others import numba_seed, numba_jit, SUPPORT_NUMBA, format_seed
+from brainpy.tools.others import numba_seed, numba_jit, numba_range, SUPPORT_NUMBA, format_seed
 from .base import *
-from .utils import *

 __all__ = [
   'FixedProb',
@@ -27,7 +24,6 @@ class FixedProb(TwoEndConnector):
   """Connect the post-synaptic neurons with fixed probability.

-  .. versionchanged:: 2.2.3.2

   Parameters
   ----------
@@ -37,11 +33,16 @@ class FixedProb(TwoEndConnector):
     The ratio of pre-synaptic neurons to connect.
   include_self : bool
     Whether to create (i, i) connections.
+  allow_multi_conn: bool
+    Whether to allow one pre-synaptic neuron to connect to multiple post-synaptic neurons.
+
+    .. versionadded:: 2.2.3.2
+
   seed : optional, int
     Seed the random generator.
   """

-  def __init__(self, prob, pre_ratio=1., include_self=True, seed=None):
+  def __init__(self, prob, pre_ratio=1., include_self=True, allow_multi_conn=False, seed=None):
     super(FixedProb, self).__init__()
     assert 0. <= prob <= 1.
     assert 0. <= pre_ratio <= 1.
@@ -49,57 +50,80 @@ def __init__(self, prob, pre_ratio=1., include_self=True, seed=None):
     self.pre_ratio = pre_ratio
     self.include_self = include_self
     self.seed = format_seed(seed)
-    self.rng = bm.random.RandomState(seed=self.seed)
+    self.rng = np.random.RandomState(seed=self.seed)
+    self.allow_multi_conn = allow_multi_conn

   def __repr__(self):
     return (f'{self.__class__.__name__}(prob={self.prob}, pre_ratio={self.pre_ratio}, '
-            f'include_self={self.include_self}, seed={self.seed})')
-
-  def build_mat(self, pre_size=None, post_size=None):
-    pre_num = get_pre_num(self, pre_size)
-    post_num = get_post_num(self, post_size)
-    pre_state = self.rng.rand(pre_num, 1) < self.pre_ratio
-    mat = (self.rng.rand(pre_num, post_num) < self.prob) * pre_state
-    if not self.include_self:
-      bm.fill_diagonal(mat, False)
-    return mat.astype(MAT_DTYPE)
+            f'include_self={self.include_self}, allow_multi_conn={self.allow_multi_conn}, '
+            f'seed={self.seed})')

-  def _iii(self, pre_size, post_size):
-    pre_num = get_pre_num(self, pre_size)
-    post_num = get_post_num(self, post_size)
-    if (not self.include_self) and (pre_num != post_num):
-      raise ConnectorError(f'We found pre_num != post_num ({pre_num} != {post_num}). '
+  def _iii(self):
+    if (not self.include_self) and (self.pre_num != self.post_num):
+      raise ConnectorError(f'We found pre_num != post_num ({self.pre_num} != {self.post_num}). '
                            f'But `include_self` is set to False.')
-    post_num_to_select = int(post_num * self.prob)
-    post_ids = bm.arange(post_num)
+
     if self.pre_ratio < 1.:
-      pre_num_to_select = int(pre_num * self.pre_ratio)
-      pre_ids = self.rng.choice(pre_num, size=pre_num_to_select, replace=False)
+      pre_num_to_select = int(self.pre_num * self.pre_ratio)
+      pre_ids = self.rng.choice(self.pre_num, size=pre_num_to_select, replace=False)
+    else:
+      pre_num_to_select = self.pre_num
+      pre_ids = np.arange(self.pre_num)
+
+    post_num_total = self.post_num
+    post_num_to_select = int(self.post_num * self.prob)
+    if SUPPORT_NUMBA:
+      rng = np.random
+      numba_seed(self.rng.randint(0, int(1e8)))
     else:
-      pre_num_to_select = pre_num
-      pre_ids = bm.arange(pre_num)
+      rng = self.rng

-    @jax.vmap
-    def f(i, key):
-      posts = bm.delete(post_ids, i) if not self.include_self else post_ids
-      return self.rng.permutation(posts, key=key)[:post_num_to_select]
+    if self.allow_multi_conn:
+      selected_post_ids = self.rng.randint(0, post_num_total, (pre_num_to_select, post_num_to_select))

-    selected_post_ids = f(pre_ids, self.rng.split_keys(pre_ids.size)).flatten()
+    else:
+      @numba_jit  # (parallel=True, nogil=True)
+      def single_conn():
+        posts = np.zeros((pre_num_to_select, post_num_to_select), dtype=np.int32)
+        for i in numba_range(pre_num_to_select):
+          posts[i] = rng.choice(post_num_total, post_num_to_select, replace=False)
+        return posts
+
+      selected_post_ids = single_conn()
     return pre_num_to_select, post_num_to_select, selected_post_ids, pre_ids

-  def build_coo(self, pre_size=None, post_size=None):
-    _, post_num_to_select, selected_post_ids, pre_ids = self._iii(pre_size, post_size)
-    selected_pre_ids = bm.repeat(pre_ids, post_num_to_select)
+  def build_coo(self):
+    _, post_num_to_select, selected_post_ids, pre_ids = self._iii()
+    selected_post_ids = selected_post_ids.flatten()
+    selected_pre_ids = np.repeat(pre_ids, post_num_to_select)
+    if not self.include_self:
+      true_ids = selected_pre_ids != selected_post_ids
+      selected_pre_ids = selected_pre_ids[true_ids]
+      selected_post_ids = selected_post_ids[true_ids]
     return selected_pre_ids.astype(IDX_DTYPE), selected_post_ids.astype(IDX_DTYPE)

-  def build_csr(self, pre_size=None, post_size=None):
-    pre_num_to_select, post_num_to_select, selected_post_ids, _ = self._iii(pre_size, post_size)
-    selected_pre_inptr = bm.cumsum(bm.concatenate([bm.zeros(1), bm.ones(pre_num_to_select) * post_num_to_select]))
+  def build_csr(self):
+    pre_num_to_select, post_num_to_select, selected_post_ids, pre_ids = self._iii()
+    pre_nums = np.ones(pre_num_to_select) * post_num_to_select
+    if not self.include_self:
+      true_ids = selected_post_ids == np.reshape(pre_ids, (-1, 1))
+      pre_nums -= np.sum(true_ids, axis=1)
+      selected_post_ids = selected_post_ids.flatten()[np.logical_not(true_ids).flatten()]
+    else:
+      selected_post_ids = selected_post_ids.flatten()
+    selected_pre_inptr = np.cumsum(np.concatenate([np.zeros(1), pre_nums]))
     return selected_post_ids.astype(IDX_DTYPE), selected_pre_inptr.astype(IDX_DTYPE)

+  def build_mat(self):
+    pre_state = self.rng.rand(self.pre_num, 1) < self.pre_ratio
+    mat = (self.rng.rand(self.pre_num, self.post_num) < self.prob) * pre_state
+    if not self.include_self:
+      np.fill_diagonal(mat, False)
+    return mat.astype(MAT_DTYPE)
+
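`_iii` above delegates the replace=False sampling to a numba-compiled loop when numba is available. A standalone sketch of that same sampling pattern, detached from the connector classes (function name and sizes are illustrative):

    import numpy as np
    from brainpy.tools.others import numba_jit, numba_range, numba_seed, SUPPORT_NUMBA

    def sample_posts(pre_num, post_num, n_select, seed=0):
      if SUPPORT_NUMBA:
        numba_seed(seed)  # numba keeps its own RNG state, separate from numpy's

      @numba_jit
      def _sample():
        posts = np.zeros((pre_num, n_select), dtype=np.int32)
        for i in numba_range(pre_num):
          # without replacement: each pre neuron draws n_select distinct posts
          posts[i] = np.random.choice(post_num, n_select, replace=False)
        return posts

      return _sample()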
 class FixedNum(TwoEndConnector):
-  def __init__(self, num, include_self=True, seed=None):
+  def __init__(self, num, include_self=True, allow_multi_conn=False, seed=None):
     super(FixedNum, self).__init__()
     if isinstance(num, int):
       assert num >= 0, '"num" must be a non-negative integer.'
@@ -110,7 +134,8 @@ def __init__(self, num, include_self=True, seed=None):
     self.num = num
     self.seed = format_seed(seed)
     self.include_self = include_self
-    self.rng = bm.random.RandomState(seed=self.seed)
+    self.allow_multi_conn = allow_multi_conn
+    self.rng = np.random.RandomState(seed=self.seed)

   def __repr__(self):
     return f'{self.__class__.__name__}(num={self.num}, include_self={self.include_self}, seed={self.seed})'
@@ -128,28 +153,49 @@ class FixedPreNum(FixedNum):
     Whether to create (i, i) connections.
   seed : None, int
     Seed the random generator.
+  allow_multi_conn: bool
+    Whether to allow one pre-synaptic neuron to connect to multiple post-synaptic neurons.
+
+    .. versionadded:: 2.2.3.2
+
   """

-  def build_coo(self, pre_size=None, post_size=None):
-    pre_num = get_pre_num(self, pre_size)
-    post_num = get_post_num(self, post_size)
-    if isinstance(self.num, int) and self.num > pre_num:
+  def build_coo(self):
+    if isinstance(self.num, int) and self.num > self.pre_num:
       raise ConnectorError(f'"num" must be smaller than "pre_num", '
-                           f'but got {self.num} > {pre_num}')
-    if (not self.include_self) and (pre_num != post_num):
-      raise ConnectorError(f'We found pre_num != post_num ({pre_num} != {post_num}). '
+                           f'but got {self.num} > {self.pre_num}')
+    if (not self.include_self) and (self.pre_num != self.post_num):
+      raise ConnectorError(f'We found pre_num != post_num ({self.pre_num} != {self.post_num}). '
                            f'But `include_self` is set to False.')
-    pre_num_to_select = int(pre_num * self.num) if isinstance(self.num, float) else self.num
-    pre_ids = bm.arange(pre_num)
-    post_ids = bm.arange(post_num)
+    pre_num_to_select = int(self.pre_num * self.num) if isinstance(self.num, float) else self.num
+    pre_num_total = self.pre_num
+    post_num_total = self.post_num
+    seed = self.seed
+    if SUPPORT_NUMBA:
+      rng = np.random
+      numba_seed(self.rng.randint(0, int(1e8)))
+    else:
+      rng = self.rng
+    if self.allow_multi_conn:
+      selected_pre_ids = self.rng.randint(0, pre_num_total, (post_num_total, pre_num_to_select,))
+    else:
+      @numba_jit  # (parallel=True, nogil=True)
+      def single_conn():
+        posts = np.zeros((post_num_total, pre_num_to_select), dtype=np.int32)
+        for i in numba_range(post_num_total):
+          posts[i] = rng.choice(pre_num_total, pre_num_to_select, replace=False)
+        return posts

-    @jax.vmap
-    def f(post_i, key):
-      pres = bm.delete(pre_ids, post_i) if not self.include_self else pre_ids
-      return self.rng.permutation(pres, key=key)[:pre_num_to_select]
+      selected_pre_ids = single_conn()

-    selected_pre_ids = f(post_ids, self.rng.split_keys(post_num)).flatten()
-    selected_post_ids = bm.repeat(post_ids, pre_num_to_select)
+    post_nums = np.ones((post_num_total,), dtype=np.int32) * pre_num_to_select
+    if not self.include_self:
+      true_ids = selected_pre_ids == np.reshape(np.arange(pre_num_total), (-1, 1))
+      post_nums -= np.sum(true_ids, axis=1)
+      selected_pre_ids = selected_pre_ids.flatten()[np.logical_not(true_ids).flatten()]
+    else:
+      selected_pre_ids = selected_pre_ids.flatten()
+    selected_post_ids = np.repeat(np.arange(post_num_total), post_nums)
     return selected_pre_ids.astype(IDX_DTYPE), selected_post_ids.astype(IDX_DTYPE)
@@ -165,39 +211,67 @@ class FixedPostNum(FixedNum):
     Whether to create (i, i) connections.
   seed : None, int
     Seed the random generator.
+  allow_multi_conn: bool
+    Whether to allow one pre-synaptic neuron to connect to multiple post-synaptic neurons.
+
+    .. versionadded:: 2.2.3.2
+
   """

-  def _ii(self, pre_size, post_size):
-    pre_num = get_pre_num(self, pre_size)
-    post_num = get_post_num(self, post_size)
-    if isinstance(self.num, int) and self.num > post_num:
+  def _ii(self):
+    if isinstance(self.num, int) and self.num > self.post_num:
       raise ConnectorError(f'"num" must be smaller than "post_num", '
-                           f'but got {self.num} > {post_num}')
-    if (not self.include_self) and (pre_num != post_num):
-      raise ConnectorError(f'We found pre_num != post_num ({pre_num} != {post_num}). '
+                           f'but got {self.num} > {self.post_num}')
+    if (not self.include_self) and (self.pre_num != self.post_num):
+      raise ConnectorError(f'We found pre_num != post_num ({self.pre_num} != {self.post_num}). '
                            f'But `include_self` is set to False.')
-    post_num_to_select = int(post_num * self.num) if isinstance(self.num, float) else self.num
-    pre_ids = bm.arange(pre_num)
-    post_ids = bm.arange(post_num)
-
-    @jax.vmap
-    def f(pre_i, key):
-      posts = bm.delete(post_ids, pre_i) if not self.include_self else post_ids
-      return self.rng.permutation(posts, key=key)[:post_num_to_select]
+    post_num_to_select = int(self.post_num * self.num) if isinstance(self.num, float) else self.num
+    pre_num_to_select = self.pre_num
+    pre_ids = np.arange(self.pre_num)
+
+    post_num_total = self.post_num
+    if SUPPORT_NUMBA:
+      rng = np.random
+      numba_seed(self.rng.randint(0, int(1e8)))
+    else:
+      rng = self.rng

-    selected_post_ids = f(pre_ids, self.rng.split_keys(pre_num)).flatten()
-    return pre_num, post_num_to_select, selected_post_ids
+    if self.allow_multi_conn:
+      selected_post_ids = self.rng.randint(0, post_num_total, (pre_num_to_select, post_num_to_select))

-  def build_csr(self, pre_size=None, post_size=None):
-    pre_num, post_num, selected_post_ids = self._ii(pre_size, post_size)
-    selected_pre_inptr = bm.cumsum(bm.concatenate([bm.zeros(1), bm.ones(pre_num) * post_num]))
-    return selected_post_ids.astype(IDX_DTYPE), selected_pre_inptr.astype(IDX_DTYPE)
+    else:
+      @numba_jit  # (parallel=True, nogil=True)
+      def single_conn():
+        posts = np.zeros((pre_num_to_select, post_num_to_select), dtype=np.int32)
+        for i in numba_range(pre_num_to_select):
+          posts[i] = rng.choice(post_num_total, post_num_to_select, replace=False)
+        return posts
+
+      selected_post_ids = single_conn()
+    return pre_num_to_select, post_num_to_select, selected_post_ids, pre_ids

-  def build_coo(self, pre_size=None, post_size=None):
-    pre_num, post_num, selected_post_ids = self._ii(pre_size, post_size)
-    selected_pre_ids = bm.repeat(bm.arange(pre_num), post_num)
+  def build_coo(self):
+    _, post_num_to_select, selected_post_ids, pre_ids = self._ii()
+    selected_post_ids = selected_post_ids.flatten()
+    selected_pre_ids = np.repeat(pre_ids, post_num_to_select)
+    if not self.include_self:
+      true_ids = selected_pre_ids != selected_post_ids
+      selected_pre_ids = selected_pre_ids[true_ids]
+      selected_post_ids = selected_post_ids[true_ids]
     return selected_pre_ids.astype(IDX_DTYPE), selected_post_ids.astype(IDX_DTYPE)

+  def build_csr(self):
+    pre_num_to_select, post_num_to_select, selected_post_ids, pre_ids = self._ii()
+    pre_nums = np.ones(pre_num_to_select) * post_num_to_select
+    if not self.include_self:
+      true_ids = selected_post_ids == np.reshape(pre_ids, (-1, 1))
+      pre_nums -= np.sum(true_ids, axis=1)
+      selected_post_ids = selected_post_ids.flatten()[np.logical_not(true_ids).flatten()]
+    else:
+      selected_post_ids = selected_post_ids.flatten()
+    selected_pre_inptr = np.cumsum(np.concatenate([np.zeros(1), pre_nums]))
+    return selected_post_ids.astype(IDX_DTYPE), selected_pre_inptr.astype(IDX_DTYPE)
+
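`build_csr` turns per-row degrees into the CSR `indptr` with a prefix sum. A tiny self-contained illustration of that construction (toy numbers):

    import numpy as np

    row_degrees = np.array([2, 2, 1])        # posts selected per pre neuron
    indices = np.array([1, 2, 0, 2, 1])      # flattened post ids, row by row
    indptr = np.cumsum(np.concatenate([np.zeros(1, dtype=int), row_degrees]))
    # indptr == [0, 2, 4, 5]; row i owns indices[indptr[i]:indptr[i + 1]]
    assert list(indices[indptr[1]:indptr[2]]) == [0, 2]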
 class GaussianProb(OneEndConnector):
   r"""Builds a Gaussian connectivity pattern within a population of neurons,
@@ -253,7 +327,7 @@ def __init__(
     self.include_self = include_self
     self.periodic_boundary = periodic_boundary
     self.seed = format_seed(seed)
-    self.rng = bm.random.RandomState(self.seed)
+    self.rng = np.random.RandomState(self.seed)

   def __repr__(self):
     return (f'{self.__class__.__name__}(sigma={self.sigma}, '
@@ -318,9 +392,9 @@ def build_mat(self, pre_size=None, post_size=None):
     prob_mat /= prob_mat.max()

     # connectivity
-    conn_mat = bm.asarray(prob_mat) >= self.rng.random(prob_mat.shape)
+    conn_mat = np.asarray(prob_mat) >= self.rng.random(prob_mat.shape)
     if not self.include_self:
-      bm.fill_diagonal(conn_mat, False)
+      np.fill_diagonal(conn_mat, False)
     return conn_mat

diff --git a/brainpy/connect/regular_conn.py b/brainpy/connect/regular_conn.py
index 0d3286bb9..253f00998 100644
--- a/brainpy/connect/regular_conn.py
+++ b/brainpy/connect/regular_conn.py
@@ -6,9 +6,7 @@
 from brainpy import math as bm
 from brainpy.errors import ConnectorError
-from brainpy.tools import size2num
 from .base import *
-from .utils import *

 __all__ = [
   'One2One', 'one2one',
@@ -36,31 +34,25 @@ def __call__(self, pre_size, post_size):
                            f'same size, but {self.pre_num} != {self.post_num}.')
     return self

-  def build_coo(self, pre_size=None, post_size=None):
-    pre_num = get_pre_num(self, pre_size)
-    post_num = get_post_num(self, post_size)
-    if pre_num != post_num:
+  def build_coo(self):
+    if self.pre_num != self.post_num:
       raise ConnectorError(f'One2One connection must be defined in two groups with the '
-                           f'same size, but {pre_num} != {post_num}.')
-    return bm.arange(pre_num, dtype=IDX_DTYPE), bm.arange(post_num, dtype=IDX_DTYPE),
+                           f'same size, but {self.pre_num} != {self.post_num}.')
+    return np.arange(self.pre_num, dtype=IDX_DTYPE), np.arange(self.post_num, dtype=IDX_DTYPE),

-  def build_csr(self, pre_size=None, post_size=None):
-    pre_num = get_pre_num(self, pre_size)
-    post_num = get_post_num(self, post_size)
-    if pre_num != post_num:
+  def build_csr(self):
+    if self.pre_num != self.post_num:
       raise ConnectorError(f'One2One connection must be defined in two groups with the '
-                           f'same size, but {pre_num} != {post_num}.')
-    ind = bm.arange(pre_num)
-    indptr = np.arange(pre_num + 1)
-    return bm.asarray(ind, dtype=IDX_DTYPE), bm.arange(indptr, dtype=IDX_DTYPE),
+                           f'same size, but {self.pre_num} != {self.post_num}.')
+    ind = np.arange(self.pre_num)
+    indptr = np.arange(self.pre_num + 1)
+    return np.asarray(ind, dtype=IDX_DTYPE), np.asarray(indptr, dtype=IDX_DTYPE),

-  def build_mat(self, pre_size=None, post_size=None):
-    pre_num = get_pre_num(self, pre_size)
-    post_num = get_post_num(self, post_size)
-    if pre_num != post_num:
+  def build_mat(self):
+    if self.pre_num != self.post_num:
       raise ConnectorError(f'One2One connection must be defined in two groups with the '
-                           f'same size, but {pre_num} != {post_num}.')
-    return bm.fill_diagonal(bm.zeros((pre_num, post_num), dtype=MAT_DTYPE), True)
+                           f'same size, but {self.pre_num} != {self.post_num}.')
+    mat = np.zeros((self.pre_num, self.post_num), dtype=MAT_DTYPE)
+    np.fill_diagonal(mat, True)
+    return mat


 one2one = One2One()
@@ -79,12 +71,10 @@ def __init__(self, include_self=True):
   def __repr__(self):
     return f'{self.__class__.__name__}(include_self={self.include_self})'

-  def build_mat(self, pre_size=None, post_size=None):
-    pre_num = get_pre_num(self, pre_size)
-    post_num = get_post_num(self, post_size)
-    mat = bm.ones((pre_num, post_num), dtype=MAT_DTYPE)
+  def build_mat(self):
+    mat = np.ones((self.pre_num, self.post_num), dtype=MAT_DTYPE)
     if not self.include_self:
-      bm.fill_diagonal(mat, False)
+      np.fill_diagonal(mat, False)
     return mat

@@ -117,22 +107,18 @@ def __init__(
   def __repr__(self):
     return f'{self.__class__.__name__}(include_self={self.include_self}, periodic_boundary={self.periodic_boundary})'

-  def _format(self, pre_size, post_size):
-    pre_size = get_pre_size(self, pre_size)
-    post_size = get_post_size(self, post_size)
-    pre_num = size2num(pre_size)
-    post_num = size2num(post_size)
-    dim = len(post_size)
-    if pre_num != post_num:
+  def _format(self):
+    dim = len(self.post_size)
+    if self.pre_num != self.post_num:
       raise ConnectorError(f'{self.__class__.__name__} is used for connections within '
                            f'a same population. But we detect pre_num != post_num '
-                           f'({pre_num} != {post_num}).')
+                           f'({self.pre_num} != {self.post_num}).')
     # point indices
-    indices = bm.meshgrid(*(bm.arange(size) for size in post_size), indexing='ij')
+    indices = bm.meshgrid(*(bm.arange(size) for size in self.post_size), indexing='ij')
     indices = bm.asarray(indices)
-    indices = indices.reshape(dim, post_num).T
-    lengths = bm.asarray(post_size)
-    return lengths, post_size, dim, indices
+    indices = indices.reshape(dim, self.post_num).T
+    lengths = bm.asarray(self.post_size)
+    return lengths, dim, indices

   def _get_strides(self, dim):
     # increments
@@ -147,8 +133,8 @@ def _select_stride(self, stride: np.ndarray) -> np.ndarray:
   def _select_dist(self, dist: bm.ndarray) -> bm.ndarray:
     raise NotImplementedError

-  def build_mat(self, pre_size=None, post_size=None):
-    sizes, post_size, _, indices = self._format(pre_size, post_size)
+  def build_mat(self):
+    sizes, _, indices = self._format()

     @jax.vmap
     def f_connect(pre_id):
@@ -160,8 +146,8 @@ def f_connect(pre_id):

     return bm.asarray(f_connect(indices), dtype=MAT_DTYPE)

-  def build_coo(self, pre_size=None, post_size=None):
-    sizes, post_size, dim, indices = self._format(pre_size, post_size)
+  def build_coo(self):
+    sizes, dim, indices = self._format()
     strides = self._get_strides(dim)

     @jax.vmap
@@ -186,7 +172,7 @@ def f_connect(pre_id):
       pres = pres.flatten()
       posts = posts.flatten()
     else:
-      strides = bm.asarray(get_size_length(post_size))
+      strides = bm.asarray(get_size_length(self.post_size))
       pres = bm.sum(pres * strides, axis=1)
       posts = bm.sum(posts * strides, axis=1)
     return bm.asarray(pres, dtype=IDX_DTYPE), bm.asarray(posts, dtype=IDX_DTYPE)
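The regular connectors now build with NumPy, and `np.fill_diagonal` mutates its argument in place and returns `None`; that is why `One2One.build_mat` above fills a named array instead of returning the call. A short demonstration:

    import numpy as np

    mat = np.zeros((3, 3), dtype=bool)
    ret = np.fill_diagonal(mat, True)   # in-place update; returns None
    assert ret is None
    assert mat[0, 0] and mat[1, 1] and mat[2, 2]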
diff --git a/brainpy/connect/tests/test_random_conn.py b/brainpy/connect/tests/test_random_conn.py
index 3df5185ea..8744be01b 100644
--- a/brainpy/connect/tests/test_random_conn.py
+++ b/brainpy/connect/tests/test_random_conn.py
@@ -18,12 +18,13 @@ def test_size_consistent(self):

   def test_require_method(self):
     conn2 = bp.connect.FixedProb(prob=0.1, seed=123)
     conn2(pre_size=(10, 20), post_size=(10, 20))
-    mat = conn2.require(100, 1000, bp.connect.CONN_MAT)
-    self.assertTrue(mat.shape == (100, 1000))
-
     mat = conn2.require(bp.connect.CONN_MAT)
     self.assertTrue(mat.shape == (200, 200))

+    mat = conn2(100, 1000).require(bp.connect.CONN_MAT)
+    self.assertTrue(mat.shape == (100, 1000))
+
+

 def test_random_fix_pre1():
   for num in [0.4, 20]:
@@ -34,8 +35,11 @@ def test_random_fix_pre1():
     mat2 = conn2.require(bp.connect.CONN_MAT)

     print()
+    print(f'num = {num}')
     print('conn_mat 1\n', mat1)
+    print(mat1.sum())
     print('conn_mat 2\n', mat2)
+    print(mat2.sum())
     assert bp.math.array_equal(mat1, mat2)

diff --git a/brainpy/connect/tests/test_regular_conn.py b/brainpy/connect/tests/test_regular_conn.py
index b2a4fd41e..4fe2ab85d 100644
--- a/brainpy/connect/tests/test_regular_conn.py
+++ b/brainpy/connect/tests/test_regular_conn.py
@@ -65,9 +65,9 @@ def test_grid_four(self):
     for include_self in [True, False]:
       for size in (10, [10, 10], (4, 4, 5)):
         conn = bp.conn.GridFour(include_self=include_self,
-                                periodic_boundary=periodic_boundary)
-        mat = conn.build_mat(size, size)
-        pre_ids, post_ids = conn.build_coo(size, size)
+                                periodic_boundary=periodic_boundary)(size, size)
+        mat = conn.build_mat()
+        pre_ids, post_ids = conn.build_coo()

         new_mat = bp.math.zeros((np.prod(size), np.prod(size)), dtype=bool)
         new_mat[pre_ids, post_ids] = True
@@ -79,9 +79,9 @@ def test_grid_eight(self):
     for include_self in [True, False]:
       for size in (10, [10, 10], (4, 4, 5)):
         conn = bp.conn.GridEight(include_self=include_self,
-                                 periodic_boundary=periodic_boundary)
-        mat = conn.build_mat(size, size)
-        pre_ids, post_ids = conn.build_coo(size, size)
+                                 periodic_boundary=periodic_boundary)(size, size)
+        mat = conn.build_mat()
+        pre_ids, post_ids = conn.build_coo()

         new_mat = bp.math.zeros((np.prod(size), np.prod(size)), dtype=bool)
         new_mat[pre_ids, post_ids] = True
@@ -94,9 +94,9 @@ def test_grid_N(self):
       for size in (10, [10, 10], (4, 4, 5)):
         conn = bp.conn.GridN(include_self=include_self,
                              periodic_boundary=periodic_boundary,
-                             N=2)
-        mat = conn.build_mat(size, size)
-        pre_ids, post_ids = conn.build_coo(size, size)
+                             N=2)(size, size)
+        mat = conn.build_mat()
+        pre_ids, post_ids = conn.build_coo()

         new_mat = bp.math.zeros((np.prod(size), np.prod(size)), dtype=bool)
         new_mat[pre_ids, post_ids] = True
diff --git a/brainpy/connect/utils.py b/brainpy/connect/utils.py
deleted file mode 100644
index ca7a8e8e9..000000000
--- a/brainpy/connect/utils.py
+++ /dev/null
@@ -1,38 +0,0 @@
-# -*- coding: utf-8 -*-
-
-
-from brainpy.errors import ConnectorError
-from brainpy.tools import size2num
-
-__all__ = [
-  'get_pre_num', 'get_post_num',
-  'get_pre_size', 'get_post_size',
-]
-
-
-def get_pre_size(obj, pre_size=None):
-  if pre_size is None:
-    if obj.pre_size is None:
-      raise ConnectorError('Please provide "pre_size" and "post_size"')
-    else:
-      return obj.pre_size
-  else:
-    return (pre_size,) if isinstance(pre_size, int) else pre_size
-
-
-def get_pre_num(obj, pre_size=None):
-  return size2num(get_pre_size(obj, pre_size))
-
-
-def get_post_size(obj, post_size=None):
-  if post_size is None:
-    if obj.post_size is None:
-      raise ConnectorError('Please provide "pre_size" and "post_size"')
-    else:
-      return obj.post_size
-  else:
-    return (post_size,) if isinstance(post_size, int) else post_size
-
-
-def get_post_num(obj, post_size=None):
-  return size2num(get_post_size(obj, post_size))
diff --git a/brainpy/math/autograd.py b/brainpy/math/autograd.py
index 02c028da7..200514e72 100644
--- a/brainpy/math/autograd.py
+++ b/brainpy/math/autograd.py
@@ -16,7 +16,9 @@
 from jax.util import safe_map

 from brainpy import errors
-from brainpy.math.jaxarray import JaxArray
+from brainpy.base.naming import get_unique_name
+from brainpy.math.jaxarray import JaxArray, add_context, del_context
+

 __all__ = [
   'grad',  # gradient of scalar function
@@ -28,20 +30,26 @@

 def _make_cls_call_func(grad_func, grad_tree, grad_vars, dyn_vars, argnums, return_value, has_aux):
+  name = get_unique_name('_brainpy_object_oriented_grad_')
+
   # outputs
   def call_func(*args, **kwargs):
     old_grad_vs = [v.value for v in grad_vars]
     old_dyn_vs = [v.value for v in dyn_vars]
     try:
+      add_context(name)
       grads, (outputs, new_grad_vs, new_dyn_vs) = grad_func(old_grad_vs, old_dyn_vs, *args, **kwargs)
+      del_context(name)
     except UnexpectedTracerError as e:
+      del_context(name)
       for v, d in zip(grad_vars, old_grad_vs): v._value = d
       for v, d in zip(dyn_vars, old_dyn_vs): v._value = d
       raise errors.JaxTracerError(variables=dyn_vars + grad_vars) from e
     except Exception as e:
+      del_context(name)
       for v, d in zip(grad_vars, old_grad_vs): v._value = d
       for v, d in zip(dyn_vars, old_dyn_vs): v._value = d
       raise e
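The same discipline repeats across autograd, controls, and jit below: push a context name, run the traced function, pop the context, and roll variable values back on failure. A distilled sketch of that pattern (`run_in_context` is an illustrative helper, not BrainPy API):

    from brainpy.math.jaxarray import add_context, del_context

    def run_in_context(name, fn, variables, *args):
      old = [v.value for v in variables]   # snapshot for rollback
      try:
        add_context(name)
        out = fn(*args)
        del_context(name)
        return out
      except Exception:
        del_context(name)                  # always pop the context
        for v, d in zip(variables, old):   # restore values on failure
          v._value = d
        raise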
diff --git a/brainpy/math/controls.py b/brainpy/math/controls.py
index ab1d1b923..efbee4b5d 100644
--- a/brainpy/math/controls.py
+++ b/brainpy/math/controls.py
@@ -13,9 +13,10 @@
 from jax.core import UnexpectedTracerError

 from brainpy import errors
+from brainpy.base.naming import get_unique_name
 from brainpy.math.jaxarray import (JaxArray, Variable,
-                                   turn_on_global_jit,
-                                   turn_off_global_jit)
+                                   add_context,
+                                   del_context)
 from brainpy.math.numpy_ops import as_device_array

 __all__ = [
@@ -158,17 +159,19 @@ def make_loop(body_fun, dyn_vars, out_vars=None, has_return=False):
                        out_vars=out_vars,
                        has_return=has_return)

+  name = get_unique_name('_brainpy_object_oriented_make_loop_')
+
   # functions
   if has_return:
     def call(xs=None, length=None):
       init_values = [v.value for v in dyn_vars]
       try:
-        turn_on_global_jit()
+        add_context(name)
         dyn_values, (out_values, results) = lax.scan(
           f=fun2scan, init=init_values, xs=xs, length=length)
-        turn_off_global_jit()
+        del_context(name)
       except UnexpectedTracerError as e:
-        turn_off_global_jit()
+        del_context(name)
         for v, d in zip(dyn_vars, init_values): v._value = d
         raise errors.JaxTracerError(variables=dyn_vars) from e
       for v, d in zip(dyn_vars, dyn_values): v._value = d
@@ -178,15 +181,15 @@ def call(xs=None, length=None):
     def call(xs):
       init_values = [v.value for v in dyn_vars]
       try:
-        turn_on_global_jit()
+        add_context(name)
         dyn_values, out_values = lax.scan(f=fun2scan, init=init_values, xs=xs)
-        turn_off_global_jit()
+        del_context(name)
       except UnexpectedTracerError as e:
-        turn_off_global_jit()
+        del_context(name)
         for v, d in zip(dyn_vars, init_values): v._value = d
         raise errors.JaxTracerError(variables=dyn_vars) from e
       except Exception as e:
-        turn_off_global_jit()
+        del_context(name)
         for v, d in zip(dyn_vars, init_values): v._value = d
         raise e
       for v, d in zip(dyn_vars, dyn_values): v._value = d
@@ -255,20 +258,22 @@ def _cond_fun(op):
     for v, d in zip(dyn_vars, dyn_values): v._value = d
     return as_device_array(cond_fun(static_values))

+  name = get_unique_name('_brainpy_object_oriented_make_while_')
+
   def call(x=None):
     dyn_init = [v.value for v in dyn_vars]
     try:
-      turn_on_global_jit()
+      add_context(name)
       dyn_values, _ = lax.while_loop(cond_fun=_cond_fun,
                                      body_fun=_body_fun,
                                      init_val=(dyn_init, x))
-      turn_off_global_jit()
+      del_context(name)
     except UnexpectedTracerError as e:
-      turn_off_global_jit()
+      del_context(name)
       for v, d in zip(dyn_vars, dyn_init): v._value = d
       raise errors.JaxTracerError(variables=dyn_vars) from e
     except Exception as e:
-      turn_off_global_jit()
+      del_context(name)
       for v, d in zip(dyn_vars, dyn_init): v._value = d
       raise e
     for v, d in zip(dyn_vars, dyn_values): v._value = d
@@ -330,6 +335,8 @@ def make_cond(true_fun, false_fun, dyn_vars=None):
       if not isinstance(v, JaxArray):
         raise ValueError(f'Only support {JaxArray.__name__}, but got {type(v)}')

+  name = get_unique_name('_brainpy_object_oriented_make_cond_')
+
   if len(dyn_vars) > 0:
     def _true_fun(op):
       dyn_vals, static_vals = op
@@ -348,15 +355,15 @@ def _false_fun(op):
     def call(pred, x=None):
       old_values = [v.value for v in dyn_vars]
       try:
-        turn_on_global_jit()
+        add_context(name)
         dyn_values, res = lax.cond(pred, _true_fun, _false_fun, (old_values, x))
-        turn_off_global_jit()
+        del_context(name)
       except UnexpectedTracerError as e:
-        turn_off_global_jit()
+        del_context(name)
         for v, d in zip(dyn_vars, old_values): v._value = d
         raise errors.JaxTracerError(variables=dyn_vars) from e
       except Exception as e:
-        turn_off_global_jit()
+        del_context(name)
         for v, d in zip(dyn_vars, old_values): v._value = d
         raise e
       for v, d in zip(dyn_vars, dyn_values): v._value = d
@@ -364,9 +371,9 @@ def call(pred, x=None):

   else:
     def call(pred, x=None):
-      turn_on_global_jit()
+      add_context(name)
       res = lax.cond(pred, true_fun, false_fun, x)
-      turn_off_global_jit()
+      del_context(name)
       return res

   return call
@@ -445,6 +452,8 @@ def cond(
     if not isinstance(v, Variable):
       raise ValueError(f'Only support {Variable.__name__}, but got {type(v)}')

+  name = get_unique_name('_brainpy_object_oriented_cond_')
+
   # calling the model
   if len(dyn_vars) > 0:
     def _true_fun(op):
@@ -463,25 +472,25 @@ def _false_fun(op):

     old_values = [v.value for v in dyn_vars]
     try:
-      turn_on_global_jit()
+      add_context(name)
       dyn_values, res = lax.cond(pred=pred,
                                  true_fun=_true_fun,
                                  false_fun=_false_fun,
                                  operand=(old_values, operands))
-      turn_off_global_jit()
+      del_context(name)
     except UnexpectedTracerError as e:
-      turn_off_global_jit()
+      del_context(name)
       for v, d in zip(dyn_vars, old_values): v._value = d
       raise errors.JaxTracerError(variables=dyn_vars) from e
     except Exception as e:
-      turn_off_global_jit()
+      del_context(name)
       for v, d in zip(dyn_vars, old_values): v._value = d
       raise e
     for v, d in zip(dyn_vars, dyn_values): v._value = d
   else:
-    turn_on_global_jit()
+    add_context(name)
     res = lax.cond(pred, true_fun, false_fun, operands)
-    turn_off_global_jit()
+    del_context(name)
   return res
@@ -591,7 +600,11 @@ def ifelse(
   if show_code: print(codes)
   exec(compile(codes.strip(), '', 'exec'), code_scope)
   f = code_scope['f']
-  return f(operands)
+  name = get_unique_name('_brainpy_object_oriented_ifelse_')
+  add_context(name)
+  r = f(operands)
+  del_context(name)
+  return r


 def for_loop(body_fun: Callable,
@@ -694,22 +707,24 @@ def fun2scan(dyn_vals, x):
     results = body_fun(*x)
     return [v.value for v in dyn_vars], results

+  name = get_unique_name('_brainpy_object_oriented_for_loop_')
+
   # functions
   init_vals = [v.value for v in dyn_vars]
   try:
-    turn_on_global_jit()
+    add_context(name)
     dyn_vals, out_vals = lax.scan(f=fun2scan,
                                   init=init_vals,
                                   xs=operands,
                                   reverse=reverse,
                                   unroll=unroll)
-    turn_off_global_jit()
+    del_context(name)
   except UnexpectedTracerError as e:
-    turn_off_global_jit()
+    del_context(name)
     for v, d in zip(dyn_vars, init_vals): v._value = d
     raise errors.JaxTracerError(variables=dyn_vars) from e
   except Exception as e:
-    turn_off_global_jit()
+    del_context(name)
     for v, d in zip(dyn_vars, init_vals): v._value = d
     raise e
   for v, d in zip(dyn_vars, dyn_vals): v._value = d
@@ -797,19 +812,20 @@ def _cond_fun(op):
     r = cond_fun(*static_vals)
     return r if isinstance(r, JaxArray) else r

+  name = get_unique_name('_brainpy_object_oriented_while_loop_')
   dyn_init = [v.value for v in dyn_vars]
   try:
-    turn_on_global_jit()
+    add_context(name)
     dyn_values, out = lax.while_loop(cond_fun=_cond_fun,
                                      body_fun=_body_fun,
                                      init_val=(dyn_init, operands))
-    turn_off_global_jit()
+    del_context(name)
   except UnexpectedTracerError as e:
-    turn_off_global_jit()
+    del_context(name)
     for v, d in zip(dyn_vars, dyn_init): v._value = d
     raise errors.JaxTracerError(variables=dyn_vars) from e
   except Exception as e:
-    turn_off_global_jit()
+    del_context(name)
     for v, d in zip(dyn_vars, dyn_init): v._value = d
     raise e
   for v, d in zip(dyn_vars, dyn_values): v._value = d
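A minimal usage sketch of the rewritten `for_loop` with a tracked `Variable` (toy computation; the `dyn_vars`/`operands` keywords follow the hunk above, but exact names can differ across versions):

    import brainpy.math as bm

    counter = bm.Variable(bm.zeros(1))

    def body(x):
      counter.value += x          # declared in dyn_vars, so updates survive the scan
      return counter.value

    outs = bm.for_loop(body, dyn_vars=[counter], operands=(bm.arange(4.),))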
diff --git a/brainpy/math/jaxarray.py b/brainpy/math/jaxarray.py
index 02b79d381..3ab28adcd 100644
--- a/brainpy/math/jaxarray.py
+++ b/brainpy/math/jaxarray.py
@@ -33,20 +33,54 @@
 msg = ('JaxArray created outside of the jit function '
        'cannot be updated in JIT mode. You should '
        'mark it as brainpy.math.Variable instead.')

-_global_jit_mode = False
+_jax_transformation_context_ = []


-def turn_on_global_jit():
-  """Turn on the global JIT mode to declare
-  all instantiated JaxArray cannot be updated."""
-  global _global_jit_mode
-  _global_jit_mode = True
+def add_context(name):
+  _jax_transformation_context_.append(name)


-def turn_off_global_jit():
-  """Turn off the global JIT mode."""
-  global _global_jit_mode
-  _global_jit_mode = False
+def del_context(name=None):
+  try:
+    context = _jax_transformation_context_.pop(-1)
+    if name is not None:
+      if context != name:
+        raise MathError('Transformation context is different!')
+  except IndexError:
+    raise MathError('No transformation context!')
+
+
+def get_context():
+  if len(_jax_transformation_context_) > 0:
+    return _jax_transformation_context_[-1]
+  else:
+    return None
+
+
+def check_context(arr_context):
+  if arr_context is None:
+    if len(_jax_transformation_context_) > 0:
+      raise MathError(f'JaxArray created outside of the transformation functions '
+                      f'({_jax_transformation_context_[-1]}) cannot be updated. '
+                      f'You should mark it as a brainpy.math.Variable instead.')
+      return True
+    else:
+      return False
+  else:
+    if len(_jax_transformation_context_) > 0:
+      if arr_context != _jax_transformation_context_[-1]:
+        raise MathError(f'JaxArray context "{arr_context}" differs from the JAX '
+                        f'transformation context "{_jax_transformation_context_[-1]}"'
+                        '\n\n'
+                        'JaxArray created in one transformation function '
+                        'cannot be updated in another transformation function. '
+                        'You should mark it as a brainpy.math.Variable instead.')
+      return True
+    else:
+      return False


 def _check_input_array(array):
@@ -61,7 +95,7 @@ def _check_input_array(array):
 class JaxArray(object):
   """Multiple-dimensional array in JAX backend."""

-  __slots__ = ("_value", "_outside_global_jit")
+  __slots__ = ("_value", "_transform_context")

   def __init__(self, value, dtype=None):
     # array value
@@ -73,7 +107,7 @@ def __init__(self, value, dtype=None):
       value = jnp.asarray(value, dtype=dtype)
     self._value = value
     # jit mode
-    self._outside_global_jit = False if _global_jit_mode else True
+    self._transform_context = get_context()

   @property
   def value(self):
@@ -86,7 +120,7 @@ def value(self, value):

   def update(self, value):
     """Update the value of this JaxArray."""
-    if self._outside_global_jit and _global_jit_mode:
+    if check_context(self._transform_context):
       raise MathError(msg)
     if isinstance(value, JaxArray):
       value = value.value
@@ -189,7 +223,7 @@ def __getitem__(self, index):
     return self.value[index]

   def __setitem__(self, index, value):
-    if self._outside_global_jit and _global_jit_mode:
+    if check_context(self._transform_context):
       raise MathError(msg)

     # value is JaxArray
@@ -260,7 +294,7 @@ def __radd__(self, oc):

   def __iadd__(self, oc):
     # a += b
-    if self._outside_global_jit and _global_jit_mode:
+    if check_context(self._transform_context):
       raise MathError(msg)
     self._value += _check_input_array(oc)
     return self
@@ -273,7 +307,7 @@ def __rsub__(self, oc):

   def __isub__(self, oc):
     # a -= b
-    if self._outside_global_jit and _global_jit_mode:
+    if check_context(self._transform_context):
       raise MathError(msg)
     self._value = self._value - _check_input_array(oc)
     return self
@@ -286,7 +320,7 @@ def __rmul__(self, oc):

   def __imul__(self, oc):
     # a *= b
-    if self._outside_global_jit and _global_jit_mode:
+    if check_context(self._transform_context):
       raise MathError(msg)
     self._value = self._value * _check_input_array(oc)
     return self
@@ -302,7 +336,7 @@ def __rtruediv__(self, oc):

   def __itruediv__(self, oc):
     # a /= b
-    if self._outside_global_jit and _global_jit_mode:
+    if check_context(self._transform_context):
       raise MathError(msg)
     self._value = self._value / _check_input_array(oc)
     return self
@@ -315,7 +349,7 @@ def __rfloordiv__(self, oc):

   def __ifloordiv__(self, oc):
     # a //= b
-    if self._outside_global_jit and _global_jit_mode:
+    if check_context(self._transform_context):
       raise MathError(msg)
     self._value = self._value // _check_input_array(oc)
     return self
@@ -334,7 +368,7 @@ def __rmod__(self, oc):

   def __imod__(self, oc):
     # a %= b
-    if self._outside_global_jit and _global_jit_mode:
+    if check_context(self._transform_context):
       raise MathError(msg)
     self._value = self._value % _check_input_array(oc)
     return self
@@ -347,7 +381,7 @@ def __rpow__(self, oc):

   def __ipow__(self, oc):
     # a **= b
-    if self._outside_global_jit and _global_jit_mode:
+    if check_context(self._transform_context):
       raise MathError(msg)
     self._value = self._value ** _check_input_array(oc)
     return self
@@ -360,7 +394,7 @@ def __rmatmul__(self, oc):

   def __imatmul__(self, oc):
     # a @= b
-    if self._outside_global_jit and _global_jit_mode:
+    if check_context(self._transform_context):
       raise MathError(msg)
     self._value = self._value @ _check_input_array(oc)
     return self
@@ -373,7 +407,7 @@ def __rand__(self, oc):

   def __iand__(self, oc):
     # a &= b
-    if self._outside_global_jit and _global_jit_mode:
+    if check_context(self._transform_context):
       raise MathError(msg)
     self._value = self._value & _check_input_array(oc)
     return self
@@ -386,7 +420,7 @@ def __ror__(self, oc):

   def __ior__(self, oc):
     # a |= b
-    if self._outside_global_jit and _global_jit_mode:
+    if check_context(self._transform_context):
       raise MathError(msg)
     self._value = self._value | _check_input_array(oc)
     return self
@@ -399,7 +433,7 @@ def __rxor__(self, oc):

   def __ixor__(self, oc):
     # a ^= b
-    if self._outside_global_jit and _global_jit_mode:
+    if check_context(self._transform_context):
       raise MathError(msg)
     self._value = self._value ^ _check_input_array(oc)
     return self
@@ -412,7 +446,7 @@ def __rlshift__(self, oc):

   def __ilshift__(self, oc):
     # a <<= b
-    if self._outside_global_jit and _global_jit_mode:
+    if check_context(self._transform_context):
       raise MathError(msg)
     self._value = self._value << _check_input_array(oc)
     return self
@@ -425,7 +459,7 @@ def __rrshift__(self, oc):

   def __irshift__(self, oc):
     # a >>= b
-    if self._outside_global_jit and _global_jit_mode:
+    if check_context(self._transform_context):
       raise MathError(msg)
     self._value = self._value >> _check_input_array(oc)
     return self
@@ -547,7 +581,7 @@ def dot(self, b):

   def fill(self, value):
     """Fill the array with a scalar value."""
-    if self._outside_global_jit and _global_jit_mode:
+    if check_context(self._transform_context):
       raise MathError(msg)
     self._value = jnp.ones_like(self.value) * value
@@ -675,7 +709,7 @@ def sort(self, axis=-1, kind='quicksort', order=None):
       but unspecified fields will still be used, in the order in which
       they come up in the dtype, to break ties.
     """
-    if self._outside_global_jit and _global_jit_mode:
+    if check_context(self._transform_context):
       raise MathError(msg)
     self._value = self.value.sort(axis=axis, kind=kind, order=order)
@@ -1513,23 +1547,6 @@ def __init__(self, value_or_size, dtype=None, batch_axis: int = None):
     super(Parameter, self).__init__(value_or_size, dtype=dtype, batch_axis=batch_axis)


-register_pytree_node(JaxArray,
-                     lambda t: ((t.value,), None),
-                     lambda aux_data, flat_contents: JaxArray(*flat_contents))
-
-register_pytree_node(Variable,
-                     lambda t: ((t.value,), None),
-                     lambda aux_data, flat_contents: Variable(*flat_contents))
-
-register_pytree_node(TrainVar,
-                     lambda t: ((t.value,), None),
-                     lambda aux_data, flat_contents: TrainVar(*flat_contents))
-
-register_pytree_node(Parameter,
-                     lambda t: ((t.value,), None),
-                     lambda aux_data, flat_contents: Parameter(*flat_contents))
-
-
 class VariableView(Variable):
   """A view of a Variable instance.
@@ -1559,6 +1576,7 @@ class VariableView(Variable):

   Moreover, it's worth noting that ``VariableView`` is not a PyTree.
   """
+
   def __init__(self, value: Variable, index):
     self.index = index
     if not isinstance(value, Variable):
@@ -1700,3 +1718,25 @@ def value(self, value):
                       f"while we got {value.dtype}.")
     self._value[self.index] = value.value if isinstance(value, JaxArray) else value

+
+def _jaxarray_unflatten(aux_data, flat_contents):
+  r = JaxArray(*flat_contents)
+  r._transform_context = aux_data[0]
+  return r
+
+
+register_pytree_node(JaxArray,
+                     lambda t: ((t.value,), (t._transform_context,)),
+                     _jaxarray_unflatten)
+
+register_pytree_node(Variable,
+                     lambda t: ((t.value,), None),
+                     lambda aux_data, flat_contents: Variable(*flat_contents))
+
+register_pytree_node(TrainVar,
+                     lambda t: ((t.value,), None),
+                     lambda aux_data, flat_contents: TrainVar(*flat_contents))
+
+register_pytree_node(Parameter,
+                     lambda t: ((t.value,), None),
+                     lambda aux_data, flat_contents: Parameter(*flat_contents))
diff --git a/brainpy/math/jit.py b/brainpy/math/jit.py
index 9e22d7dd0..01836de2c 100644
--- a/brainpy/math/jit.py
+++ b/brainpy/math/jit.py
@@ -15,12 +15,13 @@
 try:
   from jax.errors import UnexpectedTracerError, ConcretizationTypeError
 except ImportError:
-  from jax.core import UnexpectedTracerError
+  from jax.core import UnexpectedTracerError, ConcretizationTypeError

 from brainpy import errors
 from brainpy.base.base import Base
+from brainpy.base.naming import get_unique_name
 from brainpy.base.collector import TensorCollector
-from brainpy.math.jaxarray import JaxArray, turn_on_global_jit, turn_off_global_jit
+from brainpy.math.jaxarray import JaxArray, add_context, del_context
 from brainpy.tools.codes import change_func_name

 __all__ = [
@@ -38,22 +39,24 @@ def jitted_func(variable_data, *args, **kwargs):
     changes = vars.dict()
     return out, changes

+  name = get_unique_name('_brainpy_object_oriented_jit_')
+
   def call(*args, **kwargs):
     variable_data = vars.dict()
     try:
-      turn_on_global_jit()
+      add_context(name)
       out, changes = jitted_func(variable_data, *args, **kwargs)
-      turn_off_global_jit()
+      del_context(name)
     except UnexpectedTracerError as e:
-      turn_off_global_jit()
+      del_context(name)
       for key, v in vars.items(): v._value = variable_data[key]
       raise errors.JaxTracerError(variables=vars) from e
     except ConcretizationTypeError as e:
-      turn_off_global_jit()
+      del_context(name)
       for key, v in vars.items(): v._value = variable_data[key]
       raise errors.ConcretizationTypeError() from e
     except Exception as e:
-      turn_off_global_jit()
+      del_context(name)
       for key, v in vars.items(): v._value = variable_data[key]
       raise e
     for key, v in vars.items(): v._value = changes[key]
@@ -64,11 +67,12 @@ def call(*args, **kwargs):

 def _make_jit_without_vars(func, static_argnames=None, device=None, f_name=None):
   jit_f = jax.jit(func, static_argnames=static_argnames, device=device)
+  name = get_unique_name('_jax_functional_jit_')

   def call(*args, **kwargs):
-    turn_on_global_jit()
+    add_context(name)
     r = jit_f(*args, **kwargs)
-    turn_off_global_jit()
+    del_context(name)
     return r

   return change_func_name(name=f_name, f=call) if f_name else call
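Each wrapper asks `get_unique_name` for a fresh label, so every jitted callable owns a distinguishable context, and `check_context` can tell nested or sibling transformations apart. The naming helper behaves like a per-prefix counter:

    from brainpy.base.naming import get_unique_name

    n1 = get_unique_name('_brainpy_object_oriented_jit_')
    n2 = get_unique_name('_brainpy_object_oriented_jit_')
    assert n1 != n2   # distinct labels for distinct wrappers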
diff --git a/brainpy/math/operators/op_register.py b/brainpy/math/operators/op_register.py
index 1ac3b2eb0..838772b93 100644
--- a/brainpy/math/operators/op_register.py
+++ b/brainpy/math/operators/op_register.py
@@ -128,9 +128,9 @@ def register_op(
                 out_shapes=eval_shape,
                 apply_cpu_func_to_gpu=apply_cpu_func_to_gpu)

-  def fixed_op(*inputs):
+  def fixed_op(*inputs, **info):
     inputs = tuple([i.value if isinstance(i, JaxArray) else i for i in inputs])
-    res = f.bind(*inputs)
+    res = f.bind(*inputs, **info)
     return res[0] if len(res) == 1 else res

   return fixed_op
diff --git a/brainpy/math/tests/test_transformation_context.py b/brainpy/math/tests/test_transformation_context.py
new file mode 100644
index 000000000..2732afa83
--- /dev/null
+++ b/brainpy/math/tests/test_transformation_context.py
@@ -0,0 +1,56 @@
+# -*- coding: utf-8 -*-
+
+
+import unittest
+
+import brainpy as bp
+import brainpy.math as bm
+
+
+class TestJIT(unittest.TestCase):
+  def test1(self):
+    @bm.jit
+    def f1(a):
+      a[:] = 1.
+      return a
+
+    a = bm.zeros(10)
+    with self.assertRaises(bp.errors.MathError):
+      print(f1(a))
+
+  def test2(self):
+    @bm.jit
+    def f1(a):
+      b = a + 1
+
+      @bm.jit
+      def f2(x):
+        x.value = 1.
+        return x
+
+      return f2(b)
+
+    with self.assertRaises(bp.errors.MathError):
+      print(f1(bm.ones(2)))
+
+  def test3(self):
+    @bm.jit
+    def f1(a):
+      return a + 1
+
+    @bm.jit
+    def f2(b):
+      b[:] = 1.
+      return b
+
+    with self.assertRaises(bp.errors.MathError):
+      print(f2(f1(bm.ones(2))))
+
+  def test4(self):
+    @bm.jit
+    def f2(a):
+      b = bm.ones(1)
+      b += 10
+      return a + b
+
+    print(f2(bm.ones(1)))
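`register_op`'s wrapper above now forwards keyword `**info` to `Primitive.bind`, and the CPU translation rule in `extensions/brainpylib/custom_op/cpu.py` further below accepts only plain Python scalars there, embedding them as constant operands. A hedged sketch of the calling convention (`my_custom_op` is illustrative, not a real op):

    # y = my_custom_op(x, n_iter=10, tau=2.0)   # ok: int and float scalars
    # y = my_custom_op(x, taus=np.ones(3))      # TypeError: arrays are not static info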
diff --git a/brainpy/tools/others/numba_util.py b/brainpy/tools/others/numba_util.py
index db01c27f2..ce2fe370c 100644
--- a/brainpy/tools/others/numba_util.py
+++ b/brainpy/tools/others/numba_util.py
@@ -1,6 +1,5 @@
 # -*- coding: utf-8 -*-
-
-
+try:
+  import numba
+except (ImportError, ModuleNotFoundError):
+  numba = None
 import numpy as np
 try:
   from numba import njit
@@ -11,6 +10,7 @@
 __all__ = [
   'numba_jit',
   'numba_seed',
+  'numba_range',
   'SUPPORT_NUMBA',
 ]

@@ -38,3 +38,4 @@ def numba_seed(seed):
     _seed(seed)

+numba_range = numba.prange if SUPPORT_NUMBA else range
diff --git a/docs/others/citing.rst b/docs/others/publications.rst
similarity index 79%
rename from docs/others/citing.rst
rename to docs/others/publications.rst
index 34edadd9e..510a313e1 100644
--- a/docs/others/citing.rst
+++ b/docs/others/publications.rst
@@ -1,10 +1,14 @@
 If BrainPy has been significant in your research, and you would like to acknowledge
-the project in your academic publication, we suggest citing the following paper:
+the project in your academic publication, we suggest citing the following papers:
+
+
+- Chaoming Wang, Xiaoyu Chen, Tianqiu Zhang, Si Wu. *BrainPy: a flexible, integrative, efficient, and extensible framework towards general-purpose brain dynamics programming*. In submission.
+

 - Wang, C., Jiang, Y., Liu, X., Lin, X., Zou, X., Ji, Z., & Wu, S. (2021, December). *A Just-In-Time Compilation Approach for Neural Dynamics Simulation*. In International Conference on Neural Information Processing (pp. 15-26). Springer, Cham.

-In BibTeX format:

 .. code-block::
diff --git a/extensions/brainpylib/custom_op/cpu.py b/extensions/brainpylib/custom_op/cpu.py
index 7a002499d..a32f1b0a4 100644
--- a/extensions/brainpylib/custom_op/cpu.py
+++ b/extensions/brainpylib/custom_op/cpu.py
@@ -2,10 +2,10 @@

 import ctypes

-import numba
+from jax import dtypes
 from jax.abstract_arrays import ShapedArray
 from jax.lib import xla_client
-from numba import types
+from numba import types, carray, cfunc

 _lambda_no = 0
 ctypes.pythonapi.PyCapsule_New.argtypes = [
@@ -16,23 +16,29 @@
 ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object


-def _compile_cpu_signature(func, input_dtypes, input_shapes,
-                           output_dtypes, output_shapes):
+def _compile_cpu_signature(
+    func,
+    input_dtypes,
+    input_shapes,
+    output_dtypes,
+    output_shapes,
+    debug=False
+):
   code_scope = dict(
     func_to_call=func,
     input_shapes=input_shapes,
     input_dtypes=input_dtypes,
     output_shapes=output_shapes,
     output_dtypes=output_dtypes,
-    carray=numba.carray,
+    carray=carray,
   )

   args_in = [
-    f'carray(input_ptrs[{i}], input_shapes[{i}], dtype=input_dtypes[{i}])'
+    f'carray(input_ptrs[{i}], input_shapes[{i}], dtype=input_dtypes[{i}]),'
     for i in range(len(input_shapes))
   ]
   args_out = [
-    f'carray(output_ptrs[{i}], output_shapes[{i}], dtype=output_dtypes[{i}])'
+    f'carray(output_ptrs[{i}], output_shapes[{i}], dtype=output_dtypes[{i}]),'
     for i in range(len(output_shapes))
   ]

@@ -45,15 +51,14 @@ def xla_cpu_custom_call_target(output_ptrs, input_ptrs):
     {args_in}
   )
   func_to_call(args_out, args_in)
-  '''.format(args_in=",\n    ".join(args_in),
-             args_out=",\n    ".join(args_out))
-  # print(code_string)
+  '''.format(args_in="\n    ".join(args_in),
+             args_out="\n    ".join(args_out))
+  if debug: print(code_string)

   exec(compile(code_string.strip(), '', 'exec'), code_scope)
   new_f = code_scope['xla_cpu_custom_call_target']
-  wrapper = numba.cfunc(types.void(types.CPointer(types.voidptr),
-                                   types.CPointer(types.voidptr)))
-  xla_c_rule = wrapper(new_f)
+  xla_c_rule = cfunc(types.void(types.CPointer(types.voidptr),
+                                types.CPointer(types.voidptr)))(new_f)
   target_name = xla_c_rule.native_name.encode("ascii")
   capsule = ctypes.pythonapi.PyCapsule_New(
     xla_c_rule.address,  # A CFFI pointer to a function
@@ -64,12 +69,18 @@ def xla_cpu_custom_call_target(output_ptrs, input_ptrs):
   return target_name


-def func_cpu_translation(func, abs_eval_fn, c, *inputs):
+def func_cpu_translation(func, abs_eval_fn, c, *inputs, **info):
   input_shapes = [c.get_shape(arg) for arg in inputs]
+  for v in info.values():
+    if not isinstance(v, (int, float)):
+      raise TypeError
+    input_shapes.append(xla_client.Shape.array_shape(dtypes.canonicalize_dtype(type(v)), (), ()))
+  input_shapes = tuple(input_shapes)
   input_dtypes = tuple(shape.element_type() for shape in input_shapes)
   input_dimensions = tuple(shape.dimensions() for shape in input_shapes)
   output_abstract_arrays = abs_eval_fn(*tuple(ShapedArray(shape.dimensions(), shape.element_type())
-                                              for shape in input_shapes))
+                                              for shape in input_shapes[:len(inputs)]),
+                                       **info)
   output_shapes = tuple(array.shape for array in output_abstract_arrays)
   output_dtypes = tuple(array.dtype for array in output_abstract_arrays)
   output_layouts = map(lambda shape: range(len(shape) - 1, -1, -1), output_shapes)
@@ -77,13 +88,15 @@ def func_cpu_translation(func, abs_eval_fn, c, *inputs):
                        for arg in zip(output_dtypes, output_shapes, output_layouts)]
   xla_output_shape = xla_client.Shape.tuple_shape(xla_output_shapes)
   target_name = _compile_cpu_signature(func,
-                                       input_dtypes, input_dimensions,
-                                       output_dtypes, output_shapes)
+                                       input_dtypes,
+                                       input_dimensions,
+                                       output_dtypes,
+                                       output_shapes)
   return xla_client.ops.CustomCallWithLayout(
     c,
     target_name,
-    operands=inputs,
+    operands=inputs + tuple(xla_client.ops.ConstantLiteral(c, i) for i in info.values()),
     operand_shapes_with_layout=input_shapes,
     shape_with_layout=xla_output_shape,
   )
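`_compile_cpu_signature` generates a `void(void**, void**)` trampoline and hands its address to XLA through a PyCapsule. A minimal standalone sketch of such a cfunc target with one fixed-shape input and output (no code generation; shapes and names are illustrative):

    import numpy as np
    from numba import cfunc, carray, types

    @cfunc(types.void(types.CPointer(types.voidptr),
                      types.CPointer(types.voidptr)))
    def double_it(output_ptrs, input_ptrs):
      x = carray(input_ptrs[0], (4,), dtype=np.float64)
      out = carray(output_ptrs[0], (4,), dtype=np.float64)
      for i in range(4):
        out[i] = 2.0 * x[i]

    # double_it.address is what would be wrapped in the PyCapsule for CustomCall.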
diff --git a/extensions/brainpylib/custom_op/cuda.py b/extensions/brainpylib/custom_op/cuda.py
deleted file mode 100644
index 4b66349aa..000000000
--- a/extensions/brainpylib/custom_op/cuda.py
+++ /dev/null
@@ -1,84 +0,0 @@
-# -*- coding: utf-8 -*-
-
-import ctypes
-import ctypes.util
-import sys
-
-from cffi import FFI
-from numba import cuda
-from numba import types
-
-
-class Dl_info(ctypes.Structure):
-  """
-  Structure of the Dl_info returned by the CFFI of dl.dladdr
-  """
-
-  _fields_ = (
-    ("dli_fname", ctypes.c_char_p),
-    ("dli_fbase", ctypes.c_void_p),
-    ("dli_sname", ctypes.c_char_p),
-    ("dli_saddr", ctypes.c_void_p),
-  )
-
-
-# Find the dynamic linker library path. Only works on unix-like os
-libdl_path = ctypes.util.find_library("dl")
-if libdl_path:
-  # Load the dynamic linker dynamically
-  libdl = ctypes.CDLL(libdl_path)
-
-  # Define dladdr to get the pointer to a symbol in a shared
-  # library already loaded.
-  # https://man7.org/linux/man-pages/man3/dladdr.3.html
-  libdl.dladdr.argtypes = (ctypes.c_void_p, ctypes.POINTER(Dl_info))
-  # restype is None as it returns by reference
-else:
-  # On Windows it is nontrivial to have libdl, so we disable everything about
-  # it and use other ways to find paths of libraries
-  libdl = None
-
-
-def find_path_of_symbol_in_library(symbol):
-  if libdl is None:
-    raise ValueError("libdl not found.")
-
-  info = Dl_info()
-  result = libdl.dladdr(symbol, ctypes.byref(info))
-  if result and info.dli_fname:
-    return info.dli_fname.decode(sys.getfilesystemencoding())
-  else:
-    raise ValueError("Cannot determine path of Library.")
-
-
-try:
-  _libcuda = cuda.driver.find_driver()
-  if sys.platform == "win32":
-    libcuda_path = ctypes.util.find_library(_libcuda._name)
-  else:
-    libcuda_path = find_path_of_symbol_in_library(_libcuda.cuMemcpy)
-  numba_cffi_loaded = True
-except Exception:
-  numba_cffi_loaded = False
-
-
-if numba_cffi_loaded:
-  # functions needed
-  ffi = FFI()
-  ffi.cdef("int cuMemcpy(void* dst, void* src, unsigned int len, int type);")
-  ffi.cdef("int cuMemcpyAsync(void* dst, void* src, unsigned int len, int type, void* stream);")
-  ffi.cdef("int cuStreamSynchronize(void* stream);")
-  ffi.cdef("int cudaMallocHost(void** ptr, size_t size);")
-  ffi.cdef("int cudaFreeHost(void* ptr);")
-
-  # load the library; could use cuda.driver.find_library()
-  libcuda = ffi.dlopen(libcuda_path)
-  cuMemcpy = libcuda.cuMemcpy
-  cuMemcpyAsync = libcuda.cuMemcpyAsync
-  cuStreamSynchronize = libcuda.cuStreamSynchronize
-
-  memcpyHostToHost = types.int32(0)
-  memcpyHostToDevice = types.int32(1)
-  memcpyDeviceToHost = types.int32(2)
-  memcpyDeviceToDevice = types.int32(3)
diff --git a/extensions/brainpylib/custom_op/cuda.py b/extensions/brainpylib/custom_op/cuda.py
deleted file mode 100644
index 4b66349aa..000000000
--- a/extensions/brainpylib/custom_op/cuda.py
+++ /dev/null
@@ -1,84 +0,0 @@
-# -*- coding: utf-8 -*-
-
-import ctypes
-import ctypes.util
-import sys
-
-from cffi import FFI
-from numba import cuda
-from numba import types
-
-
-class Dl_info(ctypes.Structure):
-  """
-  Structure of the Dl_info returned by the CFFI of dl.dladdr
-  """
-
-  _fields_ = (
-    ("dli_fname", ctypes.c_char_p),
-    ("dli_fbase", ctypes.c_void_p),
-    ("dli_sname", ctypes.c_char_p),
-    ("dli_saddr", ctypes.c_void_p),
-  )
-
-
-# Find the dynamic linker library path. Only works on unix-like os
-libdl_path = ctypes.util.find_library("dl")
-if libdl_path:
-  # Load the dynamic linker dynamically
-  libdl = ctypes.CDLL(libdl_path)
-
-  # Define dladdr to get the pointer to a symbol in a shared
-  # library already loaded.
-  # https://man7.org/linux/man-pages/man3/dladdr.3.html
-  libdl.dladdr.argtypes = (ctypes.c_void_p, ctypes.POINTER(Dl_info))
-  # restype is None as it returns by reference
-else:
-  # On Windows it is nontrivial to have libdl, so we disable everything about
-  # it and use other ways to find paths of libraries
-  libdl = None
-
-
-def find_path_of_symbol_in_library(symbol):
-  if libdl is None:
-    raise ValueError("libdl not found.")
-
-  info = Dl_info()
-  result = libdl.dladdr(symbol, ctypes.byref(info))
-  if result and info.dli_fname:
-    return info.dli_fname.decode(sys.getfilesystemencoding())
-  else:
-    raise ValueError("Cannot determine path of Library.")
-
-
-try:
-  _libcuda = cuda.driver.find_driver()
-  if sys.platform == "win32":
-    libcuda_path = ctypes.util.find_library(_libcuda._name)
-  else:
-    libcuda_path = find_path_of_symbol_in_library(_libcuda.cuMemcpy)
-  numba_cffi_loaded = True
-except Exception:
-  numba_cffi_loaded = False
-
-
-if numba_cffi_loaded:
-  # functions needed
-  ffi = FFI()
-  ffi.cdef("int cuMemcpy(void* dst, void* src, unsigned int len, int type);")
-  ffi.cdef("int cuMemcpyAsync(void* dst, void* src, unsigned int len, int type, void* stream);")
-  ffi.cdef("int cuStreamSynchronize(void* stream);")
-  ffi.cdef("int cudaMallocHost(void** ptr, size_t size);")
-  ffi.cdef("int cudaFreeHost(void* ptr);")
-
-  # load libraray
-  # could ncuda.driver.find_library()
-  libcuda = ffi.dlopen(libcuda_path)
-  cuMemcpy = libcuda.cuMemcpy
-  cuMemcpyAsync = libcuda.cuMemcpyAsync
-  cuStreamSynchronize = libcuda.cuStreamSynchronize
-
-  memcpyHostToHost = types.int32(0)
-  memcpyHostToDevice = types.int32(1)
-  memcpyDeviceToHost = types.int32(2)
-  memcpyDeviceToDevice = types.int32(3)
diff --git a/extensions/brainpylib/custom_op/gpu.py b/extensions/brainpylib/custom_op/gpu.py
index fd833d985..83839ead9 100644
--- a/extensions/brainpylib/custom_op/gpu.py
+++ b/extensions/brainpylib/custom_op/gpu.py
@@ -1,11 +1,89 @@
 # -*- coding: utf-8 -*-

-import numba
+import ctypes
+import ctypes.util
+import sys
+
 import numpy as np
+from cffi import FFI
 from jax.abstract_arrays import ShapedArray
 from jax.lib import xla_client
+from jax import dtypes
+from numba import cuda, cfunc, types
+
+
+class Dl_info(ctypes.Structure):
+  """
+  Structure of the Dl_info returned by the CFFI of dl.dladdr
+  """
+
+  _fields_ = (
+    ("dli_fname", ctypes.c_char_p),
+    ("dli_fbase", ctypes.c_void_p),
+    ("dli_sname", ctypes.c_char_p),
+    ("dli_saddr", ctypes.c_void_p),
+  )
+
+
+# Find the dynamic linker library path. Only works on unix-like os
+libdl_path = ctypes.util.find_library("dl")
+if libdl_path:
+  # Load the dynamic linker dynamically
+  libdl = ctypes.CDLL(libdl_path)
+
+  # Define dladdr to get the pointer to a symbol in a shared
+  # library already loaded.
+  # https://man7.org/linux/man-pages/man3/dladdr.3.html
+  libdl.dladdr.argtypes = (ctypes.c_void_p, ctypes.POINTER(Dl_info))
+  # restype is None as it returns by reference
+else:
+  # On Windows it is nontrivial to have libdl, so we disable everything about
+  # it and use other ways to find paths of libraries
+  libdl = None
+
+
+def find_path_of_symbol_in_library(symbol):
+  if libdl is None:
+    raise ValueError("libdl not found.")
+
+  info = Dl_info()
+  result = libdl.dladdr(symbol, ctypes.byref(info))
+  if result and info.dli_fname:
+    return info.dli_fname.decode(sys.getfilesystemencoding())
+  else:
+    raise ValueError("Cannot determine path of Library.")
+
+
+try:
+  _libcuda = cuda.driver.find_driver()
+  if sys.platform == "win32":
+    libcuda_path = ctypes.util.find_library(_libcuda._name)
+  else:
+    libcuda_path = find_path_of_symbol_in_library(_libcuda.cuMemcpy)
+  numba_cffi_loaded = True
+except Exception:
+  numba_cffi_loaded = False
+
+if numba_cffi_loaded:
+  # functions needed
+  ffi = FFI()
+  ffi.cdef("int cuMemcpy(void* dst, void* src, unsigned int len, int type);")
+  ffi.cdef("int cuMemcpyAsync(void* dst, void* src, unsigned int len, int type, void* stream);")
+  ffi.cdef("int cuStreamSynchronize(void* stream);")
+  ffi.cdef("int cudaMallocHost(void** ptr, size_t size);")
+  ffi.cdef("int cudaFreeHost(void* ptr);")
+
+  # load library
+  # could use cuda.driver.find_library()
+  libcuda = ffi.dlopen(libcuda_path)
+  cuMemcpy = libcuda.cuMemcpy
+  cuMemcpyAsync = libcuda.cuMemcpyAsync
+  cuStreamSynchronize = libcuda.cuStreamSynchronize
-from .cuda import *
+
+  memcpyHostToHost = types.int32(0)
+  memcpyHostToDevice = types.int32(1)
+  memcpyDeviceToHost = types.int32(2)
+  memcpyDeviceToDevice = types.int32(3)

 _lambda_no = 0
 ctypes.pythonapi.PyCapsule_New.argtypes = [
@@ -16,8 +94,14 @@
 ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object


-def _compile_gpu_signature(func, input_dtypes, input_shapes,
-                           output_dtypes, output_shapes):
+def _compile_gpu_signature(
+    func,
+    input_dtypes,
+    input_shapes,
+    output_dtypes,
+    output_shapes,
+    debug=False
+):
   input_byte_size = tuple(
     np.prod(shape) * dtype.itemsize
     for (shape, dtype) in zip(input_shapes, input_dtypes)
@@ -46,7 +130,7 @@ def _compile_gpu_signature(func, input_dtypes, input_shapes,
   )

   args_in = [
-    f'empty(input_shapes[{i}], dtype=input_dtypes[{i}])'
+    f'empty(input_shapes[{i}], dtype=input_dtypes[{i}]),'
     for i in range(len(input_shapes))
   ]
   cuMemcpyAsync_in = [
@@ -54,7 +138,7 @@ def _compile_gpu_signature(func, input_dtypes, input_shapes,
     for i in range(len(input_shapes))
   ]
   args_out = [
-    f'empty(output_shapes[{i}], dtype=output_dtypes[{i}])'
+    f'empty(output_shapes[{i}], dtype=output_dtypes[{i}]),'
     for i in range(len(output_shapes))
   ]
   cuMemcpyAsync_out = [
@@ -75,15 +159,15 @@ def xla_gpu_custom_call_target(stream, inout_gpu_ptrs, opaque, opaque_len):
       cuStreamSynchronize(stream)
       func_to_call(args_out, args_in)
       {cuMemcpyAsync_out}
-  '''.format(args_in=",\n    ".join(args_in),
-             args_out=",\n    ".join(args_out),
+  '''.format(args_in="\n    ".join(args_in),
+             args_out="\n    ".join(args_out),
              cuMemcpyAsync_in="\n    ".join(cuMemcpyAsync_in),
              cuMemcpyAsync_out="\n    ".join(cuMemcpyAsync_out))
-  # print(code_string)
+  if debug: print(code_string)
   exec(compile(code_string.strip(), '', 'exec'), code_scope)

   new_f = code_scope['xla_gpu_custom_call_target']
-  wrapper = numba.cfunc(types.void(
+  wrapper = cfunc(types.void(
     types.voidptr, types.CPointer(types.voidptr),
     types.voidptr, types.uint64))
@@ -98,15 +182,18 @@ def xla_gpu_custom_call_target(stream, inout_gpu_ptrs, opaque, opaque_len):
   return target_name


-def func_gpu_translation(func, abs_eval_fn, c, *inputs):
+def func_gpu_translation(func, abs_eval_fn, c, *inputs, **info):
   if not numba_cffi_loaded:
     raise RuntimeError("Numba cffi could not be loaded.")
   input_shapes = [c.get_shape(arg) for arg in inputs]
+  for v in info.values():
+    input_shapes.append(xla_client.Shape.array_shape(dtypes.canonicalize_dtype(type(v)), (), ()))
   input_dtypes = tuple(shape.element_type() for shape in input_shapes)
   input_dimensions = tuple(shape.dimensions() for shape in input_shapes)
   output_abstract_arrays = abs_eval_fn(*tuple(ShapedArray(shape.dimensions(), shape.element_type())
-                                              for shape in input_shapes))
+                                              for shape in input_shapes[:len(inputs)]),
+                                       **info)
   output_shapes = tuple(array.shape for array in output_abstract_arrays)
   output_dtypes = tuple(array.dtype for array in output_abstract_arrays)
   output_layouts = map(lambda shape: range(len(shape) - 1, -1, -1), output_shapes)
@@ -114,13 +201,15 @@ def func_gpu_translation(func, abs_eval_fn, c, *inputs):
                        for arg in zip(output_dtypes, output_shapes, output_layouts)]
   xla_output_shape = xla_client.Shape.tuple_shape(xla_output_shapes)
   target_name = _compile_gpu_signature(func,
-                                       input_dtypes, input_dimensions,
-                                       output_dtypes, output_shapes)
+                                       input_dtypes,
+                                       input_dimensions,
+                                       output_dtypes,
+                                       output_shapes)
   return xla_client.ops.CustomCallWithLayout(
     c,
     target_name,
-    operands=inputs,
+    operands=inputs + tuple(xla_client.ops.ConstantLiteral(c, i) for i in info.values()),
     operand_shapes_with_layout=input_shapes,
     shape_with_layout=xla_output_shape,
   )
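The dladdr-based driver lookup moved here from the deleted cuda.py so that gpu.py is self-contained. A hypothetical standalone probe of that lookup, assuming a unix-like OS with a working CUDA driver (the import path is inferred from this file's location):

  from numba import cuda
  from brainpylib.custom_op.gpu import find_path_of_symbol_in_library

  drv = cuda.driver.find_driver()
  # resolves the shared library that exports cuMemcpy, e.g. '.../libcuda.so.1'
  print(find_path_of_symbol_in_library(drv.cuMemcpy))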
diff --git a/extensions/brainpylib/custom_op/regis_op.py b/extensions/brainpylib/custom_op/regis_op.py
index cfc09ca6e..683923d1a 100644
--- a/extensions/brainpylib/custom_op/regis_op.py
+++ b/extensions/brainpylib/custom_op/regis_op.py
@@ -71,9 +71,9 @@ def register_op(
     cpu_func = numba.jit(fastmath=True, nopython=True)(cpu_func)

   # output shape evaluation function
-  def abs_eval_rule(*input_shapes):
+  def abs_eval_rule(*input_shapes, **info):
     if callable(out_shapes):
-      shapes = out_shapes(*input_shapes)
+      shapes = out_shapes(*input_shapes, **info)
     elif isinstance(out_shapes, ShapedArray):
       shapes = [out_shapes]
     elif isinstance(out_shapes, (tuple, list)):
@@ -95,17 +95,18 @@ def abs_eval_rule(*input_shapes):
     return shapes

   # output evaluation function
-  def eval_rule(*inputs):
+  def eval_rule(*inputs, **info):
     # compute the output shapes
-    output_shapes = abs_eval_rule(*inputs)
+    output_shapes = abs_eval_rule(*inputs, **info)
     # Preallocate the outputs
     outputs = tuple(np.zeros(shape.shape, dtype=shape.dtype) for shape in output_shapes)
     # convert inputs to a tuple
     inputs = tuple(np.asarray(arg) for arg in inputs)
+    inputs += tuple(np.asarray(i) for i in info.values())
     # call the kernel
     cpu_func(outputs, inputs)
     # Return the outputs
-    return tuple(outputs)
+    return outputs[0] if len(outputs) == 1 else tuple(outputs)

   # cpu function
   prim.def_abstract_eval(abs_eval_rule)
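With `**info` threaded through `abs_eval_rule` and `eval_rule`, a callable `out_shapes` can size its outputs from static keyword arguments instead of input shapes. A small illustrative signature (names assumed), which the new test scripts below exercise end to end:

  import numpy as np
  from jax.abstract_arrays import ShapedArray

  def out_shapes(events, indices, indptr, *, weight, post_num):
    # the output size comes from the static kwarg, not from any input shape
    return ShapedArray((post_num,), np.float32)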
diff --git a/extensions/brainpylib/custom_op/tests/a.py b/extensions/brainpylib/custom_op/tests/a.py
new file mode 100644
index 000000000..77a66bb97
--- /dev/null
+++ b/extensions/brainpylib/custom_op/tests/a.py
@@ -0,0 +1,56 @@
+
+import brainpy.math as bm
+import brainpy as bp
+from jax.abstract_arrays import ShapedArray
+import numba
+
+bm.set_platform('cpu')
+
+
+def try1():
+  def abs_eval(events, indices, indptr, *, weight, post_num):
+    return ShapedArray((post_num,), bm.float64)
+
+  def con_compute(outs, ins):
+    post_val, = outs
+    post_val.fill(0)
+    events, indices, indptr, weight, _ = ins
+    # weight = weight[()]
+    weight = weight
+    print(weight)
+    for i in range(events.size):
+      if events[i]:
+        for j in numba.prange(indptr[i], indptr[i + 1]):
+          post_val[indices[j]] += weight
+
+  event_sum = bm.XLACustomOp(eval_shape=abs_eval, con_compute=con_compute)
+
+  events = bm.random.RandomState(123).rand(10) < 0.2
+  indices, indptr = bp.conn.FixedProb(0.1, seed=123)(10, 20).require('pre2post')
+  # print(bm.jit(, static_argnames=('weight', 'post_num'))(events, indices, indptr, weight=1., post_num=20))
+  print(event_sum(events, indices, indptr, weight=1., post_num=20))
+
+
+def try2():
+  def abs_eval(events, indices, indptr, post_val, weight):
+    return post_val
+
+  def con_compute(outs, ins):
+    post_val, = outs
+    events, indices, indptr, _, weight = ins
+    weight = weight[()]
+    for i in range(events.size):
+      if events[i]:
+        for j in range(indptr[i], indptr[i + 1]):
+          index = indices[j]
+          post_val[index] += weight
+
+  event_sum = bm.XLACustomOp(eval_shape=abs_eval, con_compute=con_compute)
+
+  events = bm.random.RandomState(123).rand(10) < 0.2
+  indices, indptr = bp.conn.FixedProb(0.1, seed=123)(10, 20).require('pre2post')
+  print(bm.jit(event_sum)(events, indices, indptr, bm.zeros(20), 1.))
+
+
+if __name__ == '__main__':
+  try1()
+  # try2()
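The commented-out call in `try1` above is missing its function argument. If the intent was to jit the op with the keywords held static, the call would presumably read as follows (assuming `bm.jit` forwards `static_argnames` to `jax.jit`, which the comment hints at but the patch does not confirm):

  jitted = bm.jit(event_sum, static_argnames=('weight', 'post_num'))
  print(jitted(events, indices, indptr, weight=1., post_num=20))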
diff --git a/extensions/brainpylib/custom_op/tests/ei_net.py b/extensions/brainpylib/custom_op/tests/ei_net.py
new file mode 100644
index 000000000..2a913d69a
--- /dev/null
+++ b/extensions/brainpylib/custom_op/tests/ei_net.py
@@ -0,0 +1,80 @@
+import brainpy.math as bm
+import brainpy as bp
+from jax.abstract_arrays import ShapedArray
+
+bm.set_platform('cpu')
+
+
+def abs_eval(events, indices, indptr, *, weight, post_num):
+  return ShapedArray((post_num,), bm.float32)
+
+
+def con_compute(outs, ins):
+  post_val, = outs
+  post_val.fill(0)
+  events, indices, indptr, weight, _ = ins
+  weight = weight[()]
+  for i in range(events.size):
+    if events[i]:
+      for j in range(indptr[i], indptr[i + 1]):
+        index = indices[j]
+        post_val[index] += weight
+
+
+event_sum = bm.XLACustomOp(eval_shape=abs_eval, cpu_func=con_compute, apply_cpu_func_to_gpu=True)
+
+
+class ExponentialV2(bp.dyn.TwoEndConn):
+  """Exponential synapse model using a customized XLA operator."""
+
+  def __init__(self, pre, post, conn, g_max=1., delay=0., tau=8.0, E=0.):
+    super(ExponentialV2, self).__init__(pre=pre, post=post, conn=conn)
+    self.check_pre_attrs('spike')
+    self.check_post_attrs('input', 'V')
+
+    # parameters
+    self.E = E
+    self.tau = tau
+    self.delay = delay
+    self.g_max = g_max
+    self.pre2post = self.conn.require('pre2post')
+
+    # variables
+    self.g = bm.Variable(bm.zeros(self.post.num))
+
+    # function
+    self.integral = bp.odeint(lambda g, t: -g / self.tau, method='exp_auto')
+
+  def update(self, tdi):
+    self.g.value = self.integral(self.g, tdi.t, tdi.dt)
+    self.g += event_sum(self.pre.spike,
+                        self.pre2post[0],
+                        self.pre2post[1],
+                        weight=self.g_max,
+                        post_num=self.post.num)
+    self.post.input += self.g * (self.E - self.post.V)
+
+
+class EINet(bp.dyn.Network):
+  def __init__(self, scale):
+    # neurons
+    pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
+                V_initializer=bp.init.Normal(-55., 2.))
+    E = bp.neurons.LIF(int(3200 * scale), **pars, method='exp_auto')
+    I = bp.neurons.LIF(int(800 * scale), **pars, method='exp_auto')
+
+    # synapses
+    E2E = ExponentialV2(E, E, bp.conn.FixedProb(prob=0.02), E=0., g_max=0.6 / scale, tau=5.)
+    E2I = ExponentialV2(E, I, bp.conn.FixedProb(prob=0.02), E=0., g_max=0.6 / scale, tau=5.)
+    I2E = ExponentialV2(I, E, bp.conn.FixedProb(prob=0.02), E=-80., g_max=6.7 / scale, tau=10.)
+    I2I = ExponentialV2(I, I, bp.conn.FixedProb(prob=0.02), E=-80., g_max=6.7 / scale, tau=10.)
+
+    super(EINet, self).__init__(E2E, E2I, I2E, I2I, E=E, I=I)
+
+
+net2 = EINet(scale=10.)
+runner2 = bp.dyn.DSRunner(net2, inputs=[('E.input', 20.), ('I.input', 20.)])
+t, _ = runner2.predict(100., eval_time=True)
+print(t)
diff --git a/extensions/setup_mac.py b/extensions/setup_mac.py
index 1450ee46a..8bfc9c0a3 100644
--- a/extensions/setup_mac.py
+++ b/extensions/setup_mac.py
@@ -3,7 +3,9 @@
 import os
 import re
 import glob
+import sys

+import pybind11
 from pybind11.setup_helpers import Pybind11Extension
 from setuptools import find_packages, setup
 from setuptools.command.build_ext import build_ext
@@ -19,10 +21,10 @@
 # extension modules
 ext_modules = [
   Pybind11Extension("brainpylib/cpu_ops",
-                    sources=["lib/cpu_ops.cc"] + glob.glob("lib/*_cpu.cc"),
+                    sources=glob.glob("lib/cpu_*.cc"),
                     cxx_std=11,
-                    # extra_link_args=["-rpath", "/Users/ztqakita/miniforge3/lib"],  # m1
-                    extra_link_args=["-rpath", "/Users/ztqakita/opt/miniconda3/lib"],  # intel
+                    # extra_link_args=["-rpath", os.environ["CONDA_PREFIX"] + "/lib"],
+                    extra_link_args=["-rpath", re.sub('/lib/.*', '/lib', sys.path[1])],
                     define_macros=[('VERSION_INFO', __version__)]),
 ]
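On the setup_mac.py change: the hard-coded per-user path is replaced by deriving the active environment's lib directory from the interpreter itself. A quick illustration of the substitution (the example path is hypothetical):

  import re, sys

  # sys.path[1] usually points inside the active environment, e.g.
  # '/Users/me/miniconda3/lib/python3.9'; trimming everything after
  # '/lib' recovers the directory to pass to the linker as -rpath.
  print(re.sub('/lib/.*', '/lib', sys.path[1]))  # -> '/Users/me/miniconda3/lib'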