From ef7e7b5de50996c8924382256ffe56432331132e Mon Sep 17 00:00:00 2001 From: chaoming Date: Thu, 5 May 2022 11:43:16 +0800 Subject: [PATCH 1/7] updates --- brainpy/nn/base.py | 11 ++++++++--- brainpy/nn/nodes/base/dense.py | 7 +++++-- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/brainpy/nn/base.py b/brainpy/nn/base.py index 7a13430b4..4872624f0 100644 --- a/brainpy/nn/base.py +++ b/brainpy/nn/base.py @@ -123,7 +123,12 @@ def __init__( # parameters if input_shape is not None: - self._feedforward_shapes = {self.name: (None,) + tools.to_size(input_shape)} + if input_shape[0] is None: + input_shape = tools.to_size(input_shape) + else: + input_shape = (None,) + tools.to_size(input_shape) + self._feedforward_shapes = {self.name: input_shape} + self._init_ff_conn() def __repr__(self): return (f"{type(self).__name__}(name={self.name}, " @@ -424,12 +429,12 @@ def _init_fb_conn(self): self._is_fb_initialized = True @not_implemented - def init_fb_conn(self): + def init_fb_conn(self, fb_shapes): """Initialize the feedback connections. This function will be called only once.""" raise ValueError(f'This node \n\n{self} \n\ndoes not support feedback connection.') - def init_ff_conn(self): + def init_ff_conn(self, ff_shapes): """Initialize the feedforward connections. This function will be called only once.""" raise NotImplementedError('Please implement the feedforward initialization.') diff --git a/brainpy/nn/nodes/base/dense.py b/brainpy/nn/nodes/base/dense.py index 7f7646662..ee8ecc781 100644 --- a/brainpy/nn/nodes/base/dense.py +++ b/brainpy/nn/nodes/base/dense.py @@ -70,7 +70,10 @@ def __init__( self.bias = None self.Wfb = None - def init_ff_conn(self): + if self.feedforward_shapes is not None: + self.init_ff_conn(self.feedforward_shapes) + + def init_ff_conn(self, ff_shapes): # shapes other_size, free_shapes = check_shape_consistency(self.feedforward_shapes, -1, True) # set output size @@ -83,7 +86,7 @@ def init_ff_conn(self): self.Wff = bm.TrainVar(self.Wff) self.bias = None if (self.bias is None) else bm.TrainVar(self.bias) - def init_fb_conn(self): + def init_fb_conn(self, fb_shapes): other_size, free_shapes = check_shape_consistency(self.feedback_shapes, -1, True) # initialize feedback weights From 961ed60317accc49dc4213180e0116276f26de41 Mon Sep 17 00:00:00 2001 From: chaoming Date: Thu, 5 May 2022 11:45:35 +0800 Subject: [PATCH 2/7] Revert "updates" This reverts commit ef7e7b5de50996c8924382256ffe56432331132e. --- brainpy/nn/base.py | 11 +++-------- brainpy/nn/nodes/base/dense.py | 7 ++----- 2 files changed, 5 insertions(+), 13 deletions(-) diff --git a/brainpy/nn/base.py b/brainpy/nn/base.py index 058b16f4c..483780072 100644 --- a/brainpy/nn/base.py +++ b/brainpy/nn/base.py @@ -123,12 +123,7 @@ def __init__( # parameters if input_shape is not None: - if input_shape[0] is None: - input_shape = tools.to_size(input_shape) - else: - input_shape = (None,) + tools.to_size(input_shape) - self._feedforward_shapes = {self.name: input_shape} - self._init_ff_conn() + self._feedforward_shapes = {self.name: (None,) + tools.to_size(input_shape)} def __repr__(self): return (f"{type(self).__name__}(name={self.name}, " @@ -429,12 +424,12 @@ def _init_fb_conn(self): self._is_fb_initialized = True @not_implemented - def init_fb_conn(self, fb_shapes): + def init_fb_conn(self): """Initialize the feedback connections. 
This function will be called only once.""" raise ValueError(f'This node \n\n{self} \n\ndoes not support feedback connection.') - def init_ff_conn(self, ff_shapes): + def init_ff_conn(self): """Initialize the feedforward connections. This function will be called only once.""" raise NotImplementedError('Please implement the feedforward initialization.') diff --git a/brainpy/nn/nodes/base/dense.py b/brainpy/nn/nodes/base/dense.py index b73fa6b32..bf9cdbbaf 100644 --- a/brainpy/nn/nodes/base/dense.py +++ b/brainpy/nn/nodes/base/dense.py @@ -70,10 +70,7 @@ def __init__( self.bias = None self.Wfb = None - if self.feedforward_shapes is not None: - self.init_ff_conn(self.feedforward_shapes) - - def init_ff_conn(self, ff_shapes): + def init_ff_conn(self): # shapes other_size, free_shapes = check_shape_consistency(self.feedforward_shapes, -1, True) # set output size @@ -86,7 +83,7 @@ def init_ff_conn(self, ff_shapes): self.Wff = bm.TrainVar(self.Wff) self.bias = None if (self.bias is None) else bm.TrainVar(self.bias) - def init_fb_conn(self, fb_shapes): + def init_fb_conn(self): other_size, free_shapes = check_shape_consistency(self.feedback_shapes, -1, True) # initialize feedback weights From fffa9f7650c63e2a70ea1aca84e4fb4a2152425e Mon Sep 17 00:00:00 2001 From: chaoming Date: Mon, 17 Oct 2022 18:10:50 +0800 Subject: [PATCH 3/7] go back to connection establish to numba methods --- brainpy/connect/base.py | 69 +++--- brainpy/connect/custom_conn.py | 33 +-- brainpy/connect/random_conn.py | 242 ++++++++++++++------- brainpy/connect/regular_conn.py | 68 +++--- brainpy/connect/tests/test_random_conn.py | 10 +- brainpy/connect/tests/test_regular_conn.py | 18 +- brainpy/connect/utils.py | 38 ---- brainpy/tools/others/numba_util.py | 5 +- 8 files changed, 242 insertions(+), 241 deletions(-) delete mode 100644 brainpy/connect/utils.py diff --git a/brainpy/connect/base.py b/brainpy/connect/base.py index 904e96ff2..5ddff168a 100644 --- a/brainpy/connect/base.py +++ b/brainpy/connect/base.py @@ -116,13 +116,13 @@ def build_conn(self): import brainpy as bp class MyConnector(bp.conn.TwoEndConnector): - def build_mat(self, pre_size, post_size): + def build_mat(self, ): return conn_matrix - def build_csr(self, pre_size, post_size): + def build_csr(self, ): return post_ids, inptr - def build_coo(self, pre_size, post_size): + def build_coo(self, ): return pre_ids, post_ids """ @@ -196,8 +196,6 @@ def check(self, structures: Union[Tuple, List, str]): raise ConnectorError(f'Unknown synapse structure "{n}". ' f'Only {SUPPORTED_SYN_STRUCTURE} is supported.') - - def _return_by_mat(self, structures, mat, all_data: dict): assert mat.ndim == 2 if (CONN_MAT in structures) and (CONN_MAT not in all_data): @@ -332,70 +330,56 @@ def build_conn(self): """ pass - def require(self, *sizes_or_structures): - sizes_or_structures = list(sizes_or_structures) - pre_size = sizes_or_structures.pop(0) if len(sizes_or_structures) >= 1 else None - post_size = sizes_or_structures.pop(0) if len(sizes_or_structures) >= 1 else None - structures = sizes_or_structures - if isinstance(post_size, str): - structures.insert(0, post_size) - post_size = None - if isinstance(pre_size, str): - structures.insert(0, pre_size) - pre_size = None - - version2_style = (pre_size is not None) and (post_size is not None) - if not version2_style: - try: - assert self.pre_num is not None and self.post_num is not None - except AssertionError: - raise ConnectorError(f'self.pre_num or self.post_num is not defined. 
' - f'Please use self.__call__(pre_size, post_size) ' - f'before requiring connection data.') - if pre_size is None: - pre_size = self.pre_size - if post_size is None: - post_size = self.post_size + def require(self, *structures): + try: + assert self.pre_num is not None and self.post_num is not None + except AssertionError: + raise ConnectorError(f'self.pre_num or self.post_num is not defined. ' + f'Please use self.__call__() ' + f'before requiring connection data.') self.check(structures) if self.is_version2_style: if len(structures) == 1: if PRE2POST in structures and not hasattr(self.build_csr, 'not_customized'): - return self.build_csr(pre_size, post_size) + r = self.build_csr() + return bm.asarray(r[0], dtype=IDX_DTYPE), bm.asarray(r[1], dtype=IDX_DTYPE) elif CONN_MAT in structures and not hasattr(self.build_mat, 'not_customized'): - return self.build_mat(pre_size, post_size) + return bm.asarray(self.build_mat(), dtype=MAT_DTYPE) elif PRE_IDS in structures and not hasattr(self.build_coo, 'not_customized'): - return self.build_coo(pre_size, post_size)[0] + return bm.asarray(self.build_coo()[0], dtype=IDX_DTYPE) elif POST_IDS in structures and not hasattr(self.build_coo, 'not_customized'): - return self.build_coo(pre_size, post_size)[1] + return bm.asarray(self.build_coo()[1], dtype=IDX_DTYPE) elif len(structures) == 2: if PRE_IDS in structures and POST_IDS in structures and not hasattr(self.build_coo, 'not_customized'): - return self.build_coo(pre_size, post_size) + r = self.build_coo() + return bm.asarray(r[0], dtype=IDX_DTYPE), bm.asarray(r[1], dtype=IDX_DTYPE) conn_data = dict(csr=None, ij=None, mat=None) if not hasattr(self.build_coo, 'not_customized'): - conn_data['ij'] = self.build_coo(pre_size, post_size) + conn_data['ij'] = self.build_coo() elif not hasattr(self.build_csr, 'not_customized'): - conn_data['csr'] = self.build_csr(pre_size, post_size) + conn_data['csr'] = self.build_csr() elif not hasattr(self.build_mat, 'not_customized'): - conn_data['mat'] = self.build_mat(pre_size, post_size) + conn_data['mat'] = self.build_mat() + else: conn_data = self.build_conn() return self.make_returns(structures, conn_data) - def requires(self, *sizes_or_structures): - return self.require(*sizes_or_structures) + def requires(self, *structures): + return self.require(*structures) @tools.not_customized - def build_mat(self, pre_size=None, post_size=None): + def build_mat(self): pass @tools.not_customized - def build_csr(self, pre_size=None, post_size=None): + def build_csr(self): pass @tools.not_customized - def build_coo(self, pre_size=None, post_size=None): + def build_coo(self): pass @@ -425,7 +409,6 @@ def __call__(self, pre_size, post_size=None): else: post_size = tuple(post_size) self.pre_size, self.post_size = pre_size, post_size - self.pre_num = tools.size2num(self.pre_size) self.post_num = tools.size2num(self.post_size) return self diff --git a/brainpy/connect/custom_conn.py b/brainpy/connect/custom_conn.py index 7cac4cd72..e452061e5 100644 --- a/brainpy/connect/custom_conn.py +++ b/brainpy/connect/custom_conn.py @@ -7,7 +7,6 @@ from brainpy import tools from brainpy.errors import ConnectorError from .base import * -from .utils import * __all__ = [ 'MatConn', @@ -34,11 +33,9 @@ def __call__(self, pre_size, post_size): assert self.post_num == tools.size2num(post_size) return self - def build_mat(self, pre_size=None, post_size=None): - pre_num = get_pre_num(self, pre_size) - post_num = get_post_num(self, post_size) - assert self.conn_mat.shape[0] == pre_num - assert 
self.conn_mat.shape[1] == post_num + def build_mat(self): + assert self.conn_mat.shape[0] == self.pre_num + assert self.conn_mat.shape[1] == self.post_num return self.conn_mat @@ -68,14 +65,12 @@ def __call__(self, pre_size, post_size): f'the maximum id ({self.max_post}) of self.post_ids.') return self - def build_coo(self, pre_size=None, post_size=None): - pre_num = get_pre_num(self, pre_size) - post_num = get_post_num(self, post_size) - if pre_num <= self.max_pre: - raise ConnectorError(f'pre_num ({pre_num}) should be greater than ' + def build_coo(self): + if self.pre_num <= self.max_pre: + raise ConnectorError(f'pre_num ({self.pre_num}) should be greater than ' f'the maximum id ({self.max_pre}) of self.pre_ids.') - if post_num <= self.max_post: - raise ConnectorError(f'post_num ({post_num}) should be greater than ' + if self.post_num <= self.max_post: + raise ConnectorError(f'post_num ({self.post_num}) should be greater than ' f'the maximum id ({self.max_post}) of self.post_ids.') return self.pre_ids, self.post_ids @@ -91,16 +86,12 @@ def __init__(self, indices, inptr): self.pre_num = self.inptr.size - 1 self.max_post = bm.max(self.indices) - def build_csr(self, pre_size=None, post_size=None): - pre_size = get_pre_size(self, pre_size) - post_size = get_post_size(self, post_size) - pre_num = np.prod(pre_size) - post_num = np.prod(post_size) - if pre_num != self.pre_num: + def build_csr(self): + if self.pre_num != self.pre_num: raise ConnectorError(f'(pre_size, post_size) is inconsistent with ' f'the shape of the sparse matrix.') - if post_num <= self.max_post: - raise ConnectorError(f'post_num ({post_num}) should be greater than ' + if self.post_num <= self.max_post: + raise ConnectorError(f'post_num ({self.post_num}) should be greater than ' f'the maximum id ({self.max_post}) of self.post_ids.') return self.indices, self.inptr diff --git a/brainpy/connect/random_conn.py b/brainpy/connect/random_conn.py index 4daf894f8..bc0eafa33 100644 --- a/brainpy/connect/random_conn.py +++ b/brainpy/connect/random_conn.py @@ -1,14 +1,11 @@ # -*- coding: utf-8 -*- -import jax import numpy as np from typing import Optional -from brainpy import math as bm from brainpy.errors import ConnectorError -from brainpy.tools.others import numba_seed, numba_jit, SUPPORT_NUMBA, format_seed +from brainpy.tools.others import numba_seed, numba_jit, numba_range, SUPPORT_NUMBA, format_seed from .base import * -from .utils import * __all__ = [ 'FixedProb', @@ -27,7 +24,6 @@ class FixedProb(TwoEndConnector): """Connect the post-synaptic neurons with fixed probability. - .. versionchanged:: 2.2.3.2 Parameters ---------- @@ -37,11 +33,16 @@ class FixedProb(TwoEndConnector): The ratio of pre-synaptic neurons to connect. include_self : bool Whether create (i, i) conn? + allow_multi_conn: bool + Allow one pre-synaptic neuron connects to multiple post-synaptic neurons? + + .. versionadded:: 2.2.3.2 + seed : optional, int Seed the random generator. """ - def __init__(self, prob, pre_ratio=1., include_self=True, seed=None): + def __init__(self, prob, pre_ratio=1., include_self=True, allow_multi_conn=False, seed=None): super(FixedProb, self).__init__() assert 0. <= prob <= 1. assert 0. <= pre_ratio <= 1. 
@@ -49,57 +50,80 @@ def __init__(self, prob, pre_ratio=1., include_self=True, seed=None): self.pre_ratio = pre_ratio self.include_self = include_self self.seed = format_seed(seed) - self.rng = bm.random.RandomState(seed=self.seed) + self.rng = np.random.RandomState(seed=self.seed) + self.allow_multi_conn = allow_multi_conn def __repr__(self): return (f'{self.__class__.__name__}(prob={self.prob}, pre_ratio={self.pre_ratio}, ' - f'include_self={self.include_self}, seed={self.seed})') - - def build_mat(self, pre_size=None, post_size=None): - pre_num = get_pre_num(self, pre_size) - post_num = get_post_num(self, post_size) - pre_state = self.rng.rand(pre_num, 1) < self.pre_ratio - mat = (self.rng.rand(pre_num, post_num) < self.prob) * pre_state - if not self.include_self: - bm.fill_diagonal(mat, False) - return mat.astype(MAT_DTYPE) + f'include_self={self.include_self}, allow_multi_conn={self.allow_multi_conn}, ' + f'seed={self.seed})') - def _iii(self, pre_size, post_size): - pre_num = get_pre_num(self, pre_size) - post_num = get_post_num(self, post_size) - if (not self.include_self) and (pre_num != post_num): - raise ConnectorError(f'We found pre_num != post_num ({pre_num} != {post_num}). ' + def _iii(self): + if (not self.include_self) and (self.pre_num != self.post_num): + raise ConnectorError(f'We found pre_num != post_num ({self.pre_num} != {self.post_num}). ' f'But `include_self` is set to True.') - post_num_to_select = int(post_num * self.prob) - post_ids = bm.arange(post_num) + if self.pre_ratio < 1.: - pre_num_to_select = int(pre_num * self.pre_ratio) - pre_ids = self.rng.choice(pre_num, size=pre_num_to_select, replace=False) + pre_num_to_select = int(self.pre_num * self.pre_ratio) + pre_ids = self.rng.choice(self.pre_num, size=pre_num_to_select, replace=False) + else: + pre_num_to_select = self.pre_num + pre_ids = np.arange(self.pre_num) + + post_num_total = self.post_num + post_num_to_select = int(self.post_num * self.prob) + if SUPPORT_NUMBA: + rng = np.random + numba_seed(self.rng.randint(0, int(1e8))) else: - pre_num_to_select = pre_num - pre_ids = bm.arange(pre_num) + rng = self.rng - @jax.vmap - def f(i, key): - posts = bm.delete(post_ids, i) if not self.include_self else post_ids - return self.rng.permutation(posts, key=key)[:post_num_to_select] + if self.allow_multi_conn: + selected_post_ids = self.rng.randint(0, post_num_total, (pre_num_to_select, post_num_to_select)) - selected_post_ids = f(pre_ids, self.rng.split_keys(pre_ids.size)).flatten() + else: + @numba_jit#(parallel=True, nogil=True) + def single_conn(): + posts = np.zeros((pre_num_to_select, post_num_to_select), dtype=np.int32) + for i in numba_range(pre_num_to_select): + posts[i] = rng.choice(post_num_total, post_num_to_select, replace=False) + return posts + + selected_post_ids = single_conn() return pre_num_to_select, post_num_to_select, selected_post_ids, pre_ids - def build_coo(self, pre_size=None, post_size=None): - _, post_num_to_select, selected_post_ids, pre_ids = self._iii(pre_size, post_size) - selected_pre_ids = bm.repeat(pre_ids, post_num_to_select) + def build_coo(self): + _, post_num_to_select, selected_post_ids, pre_ids = self._iii() + selected_post_ids = selected_post_ids.flatten() + selected_pre_ids = np.repeat(pre_ids, post_num_to_select) + if not self.include_self: + true_ids = selected_pre_ids != selected_post_ids + selected_pre_ids = selected_pre_ids[true_ids] + selected_post_ids = selected_post_ids[true_ids] return selected_pre_ids.astype(IDX_DTYPE), selected_post_ids.astype(IDX_DTYPE) - def 
build_csr(self, pre_size=None, post_size=None): - pre_num_to_select, post_num_to_select, selected_post_ids, _ = self._iii(pre_size, post_size) - selected_pre_inptr = bm.cumsum(bm.concatenate([bm.zeros(1), bm.ones(pre_num_to_select) * post_num_to_select])) + def build_csr(self): + pre_num_to_select, post_num_to_select, selected_post_ids, pre_ids = self._iii() + pre_nums = np.ones(pre_num_to_select) * post_num_to_select + if not self.include_self: + true_ids = selected_post_ids == np.reshape(pre_ids, (-1, 1)) + pre_nums -= np.sum(true_ids, axis=1) + selected_post_ids = selected_post_ids.flatten()[np.logical_not(true_ids).flatten()] + else: + selected_post_ids = selected_post_ids.flatten() + selected_pre_inptr = np.cumsum(np.concatenate([np.zeros(1), pre_nums])) return selected_post_ids.astype(IDX_DTYPE), selected_pre_inptr.astype(IDX_DTYPE) + def build_mat(self): + pre_state = self.rng.rand(self.pre_num, 1) < self.pre_ratio + mat = (self.rng.rand(self.pre_num, self.post_num) < self.prob) * pre_state + if not self.include_self: + np.fill_diagonal(mat, False) + return mat.astype(MAT_DTYPE) + class FixedNum(TwoEndConnector): - def __init__(self, num, include_self=True, seed=None): + def __init__(self, num, include_self=True, allow_multi_conn=False, seed=None): super(FixedNum, self).__init__() if isinstance(num, int): assert num >= 0, '"num" must be a non-negative integer.' @@ -110,7 +134,8 @@ def __init__(self, num, include_self=True, seed=None): self.num = num self.seed = format_seed(seed) self.include_self = include_self - self.rng = bm.random.RandomState(seed=self.seed) + self.allow_multi_conn = allow_multi_conn + self.rng = np.random.RandomState(seed=self.seed) def __repr__(self): return f'{self.__class__.__name__}(num={self.num}, include_self={self.include_self}, seed={self.seed})' @@ -128,28 +153,49 @@ class FixedPreNum(FixedNum): Whether create (i, i) conn ? seed : None, int Seed the random generator. + allow_multi_conn: bool + Allow one pre-synaptic neuron connects to multiple post-synaptic neurons? + + .. versionadded:: 2.2.3.2 + """ - def build_coo(self, pre_size=None, post_size=None): - pre_num = get_pre_num(self, pre_size) - post_num = get_post_num(self, post_size) - if isinstance(self.num, int) and self.num > pre_num: + def build_coo(self): + if isinstance(self.num, int) and self.num > self.pre_num: raise ConnectorError(f'"num" must be smaller than "pre_num", ' - f'but got {self.num} > {pre_num}') - if (not self.include_self) and (pre_num != post_num): - raise ConnectorError(f'We found pre_num != post_num ({pre_num} != {post_num}). ' + f'but got {self.num} > {self.pre_num}') + if (not self.include_self) and (self.pre_num != self.post_num): + raise ConnectorError(f'We found pre_num != post_num ({self.pre_num} != {self.post_num}). 
' f'But `include_self` is set to True.') - pre_num_to_select = int(pre_num * self.num) if isinstance(self.num, float) else self.num - pre_ids = bm.arange(pre_num) - post_ids = bm.arange(post_num) + pre_num_to_select = int(self.pre_num * self.num) if isinstance(self.num, float) else self.num + pre_num_total = self.pre_num + post_num_total = self.post_num + seed = self.seed + if SUPPORT_NUMBA: + rng = np.random + numba_seed(self.rng.randint(0, int(1e8))) + else: + rng = self.rng + if self.allow_multi_conn: + selected_pre_ids = self.rng.randint(0, pre_num_total, (post_num_total, pre_num_to_select, )) + else: + @numba_jit#(parallel=True, nogil=True) + def single_conn(): + posts = np.zeros((post_num_total, pre_num_to_select), dtype=np.int32) + for i in numba_range(post_num_total): + posts[i] = rng.choice(pre_num_total, pre_num_to_select, replace=False) + return posts - @jax.vmap - def f(post_i, key): - pres = bm.delete(pre_ids, post_i) if not self.include_self else pre_ids - return self.rng.permutation(pres, key=key)[:pre_num_to_select] + selected_pre_ids = single_conn() - selected_pre_ids = f(post_ids, self.rng.split_keys(post_num)).flatten() - selected_post_ids = bm.repeat(post_ids, pre_num_to_select) + post_nums = np.ones((post_num_total,), dtype=np.int32) * pre_num_to_select + if not self.include_self: + true_ids = selected_pre_ids == np.reshape(np.arange(pre_num_total), (-1, 1)) + post_nums -= np.sum(true_ids, axis=1) + selected_pre_ids = selected_pre_ids.flatten()[np.logical_not(true_ids).flatten()] + else: + selected_pre_ids = selected_pre_ids.flatten() + selected_post_ids = np.repeat(np.arange(post_num_total), post_nums) return selected_pre_ids.astype(IDX_DTYPE), selected_post_ids.astype(IDX_DTYPE) @@ -165,39 +211,67 @@ class FixedPostNum(FixedNum): Whether create (i, i) conn ? seed : None, int Seed the random generator. + allow_multi_conn: bool + Allow one pre-synaptic neuron connects to multiple post-synaptic neurons? + + .. versionadded:: 2.2.3.2 + """ - def _ii(self, pre_size, post_size): - pre_num = get_pre_num(self, pre_size) - post_num = get_post_num(self, post_size) - if isinstance(self.num, int) and self.num > post_num: + def _ii(self): + if isinstance(self.num, int) and self.num > self.post_num: raise ConnectorError(f'"num" must be smaller than "post_num", ' - f'but got {self.num} > {post_num}') - if (not self.include_self) and (pre_num != post_num): - raise ConnectorError(f'We found pre_num != post_num ({pre_num} != {post_num}). ' + f'but got {self.num} > {self.post_num}') + if (not self.include_self) and (self.pre_num != self.post_num): + raise ConnectorError(f'We found pre_num != post_num ({self.pre_num} != {self.post_num}). 
' f'But `include_self` is set to True.') - post_num_to_select = int(post_num * self.num) if isinstance(self.num, float) else self.num - pre_ids = bm.arange(pre_num) - post_ids = bm.arange(post_num) - - @jax.vmap - def f(pre_i, key): - posts = bm.delete(post_ids, pre_i) if not self.include_self else post_ids - return self.rng.permutation(posts, key=key)[:post_num_to_select] + post_num_to_select = int(self.post_num * self.num) if isinstance(self.num, float) else self.num + pre_num_to_select = self.pre_num + pre_ids = np.arange(self.pre_num) + + post_num_total = self.post_num + if SUPPORT_NUMBA: + rng = np.random + numba_seed(self.rng.randint(0, int(1e8))) + else: + rng = self.rng - selected_post_ids = f(pre_ids, self.rng.split_keys(pre_num)).flatten() - return pre_num, post_num_to_select, selected_post_ids + if self.allow_multi_conn: + selected_post_ids = self.rng.randint(0, post_num_total, (pre_num_to_select, post_num_to_select)) - def build_csr(self, pre_size=None, post_size=None): - pre_num, post_num, selected_post_ids = self._ii(pre_size, post_size) - selected_pre_inptr = bm.cumsum(bm.concatenate([bm.zeros(1), bm.ones(pre_num) * post_num])) - return selected_post_ids.astype(IDX_DTYPE), selected_pre_inptr.astype(IDX_DTYPE) + else: + @numba_jit#(parallel=True, nogil=True) + def single_conn(): + posts = np.zeros((pre_num_to_select, post_num_to_select), dtype=np.int32) + for i in numba_range(pre_num_to_select): + posts[i] = rng.choice(post_num_total, post_num_to_select, replace=False) + return posts + + selected_post_ids = single_conn() + return pre_num_to_select, post_num_to_select, selected_post_ids, pre_ids - def build_coo(self, pre_size=None, post_size=None): - pre_num, post_num, selected_post_ids = self._ii(pre_size, post_size) - selected_pre_ids = bm.repeat(bm.arange(pre_num), post_num) + def build_coo(self): + _, post_num_to_select, selected_post_ids, pre_ids = self._ii() + selected_post_ids = selected_post_ids.flatten() + selected_pre_ids = np.repeat(pre_ids, post_num_to_select) + if not self.include_self: + true_ids = selected_pre_ids != selected_post_ids + selected_pre_ids = selected_pre_ids[true_ids] + selected_post_ids = selected_post_ids[true_ids] return selected_pre_ids.astype(IDX_DTYPE), selected_post_ids.astype(IDX_DTYPE) + def build_csr(self): + pre_num_to_select, post_num_to_select, selected_post_ids, pre_ids = self._ii() + pre_nums = np.ones(pre_num_to_select) * post_num_to_select + if not self.include_self: + true_ids = selected_post_ids == np.reshape(pre_ids, (-1, 1)) + pre_nums -= np.sum(true_ids, axis=1) + selected_post_ids = selected_post_ids.flatten()[np.logical_not(true_ids).flatten()] + else: + selected_post_ids = selected_post_ids.flatten() + selected_pre_inptr = np.cumsum(np.concatenate([np.zeros(1), pre_nums])) + return selected_post_ids.astype(IDX_DTYPE), selected_pre_inptr.astype(IDX_DTYPE) + class GaussianProb(OneEndConnector): r"""Builds a Gaussian connectivity pattern within a population of neurons, @@ -253,7 +327,7 @@ def __init__( self.include_self = include_self self.periodic_boundary = periodic_boundary self.seed = format_seed(seed) - self.rng = bm.random.RandomState(self.seed) + self.rng = np.random.RandomState(self.seed) def __repr__(self): return (f'{self.__class__.__name__}(sigma={self.sigma}, ' @@ -318,9 +392,9 @@ def build_mat(self, pre_size=None, post_size=None): prob_mat /= prob_mat.max() # connectivity - conn_mat = bm.asarray(prob_mat) >= self.rng.random(prob_mat.shape) + conn_mat = np.asarray(prob_mat) >= self.rng.random(prob_mat.shape) if not 
self.include_self: - bm.fill_diagonal(conn_mat, False) + np.fill_diagonal(conn_mat, False) return conn_mat diff --git a/brainpy/connect/regular_conn.py b/brainpy/connect/regular_conn.py index 0d3286bb9..c40758c08 100644 --- a/brainpy/connect/regular_conn.py +++ b/brainpy/connect/regular_conn.py @@ -6,9 +6,7 @@ from brainpy import math as bm from brainpy.errors import ConnectorError -from brainpy.tools import size2num from .base import * -from .utils import * __all__ = [ 'One2One', 'one2one', @@ -36,31 +34,25 @@ def __call__(self, pre_size, post_size): f'same size, but {self.pre_num} != {self.post_num}.') return self - def build_coo(self, pre_size=None, post_size=None): - pre_num = get_pre_num(self, pre_size) - post_num = get_post_num(self, post_size) - if pre_num != post_num: + def build_coo(self): + if self.pre_num != self.post_num: raise ConnectorError(f'One2One connection must be defined in two groups with the ' - f'same size, but {pre_num} != {post_num}.') - return bm.arange(pre_num, dtype=IDX_DTYPE), bm.arange(post_num, dtype=IDX_DTYPE), + f'same size, but {self.pre_num} != {self.post_num}.') + return bm.arange(self.pre_num, dtype=IDX_DTYPE), bm.arange(self.post_num, dtype=IDX_DTYPE), - def build_csr(self, pre_size=None, post_size=None): - pre_num = get_pre_num(self, pre_size) - post_num = get_post_num(self, post_size) - if pre_num != post_num: + def build_csr(self): + if self.pre_num != self.post_num: raise ConnectorError(f'One2One connection must be defined in two groups with the ' - f'same size, but {pre_num} != {post_num}.') - ind = bm.arange(pre_num) - indptr = np.arange(pre_num + 1) + f'same size, but {self.pre_num} != {self.post_num}.') + ind = bm.arange(self.pre_num) + indptr = np.arange(self.pre_num + 1) return bm.asarray(ind, dtype=IDX_DTYPE), bm.arange(indptr, dtype=IDX_DTYPE), def build_mat(self, pre_size=None, post_size=None): - pre_num = get_pre_num(self, pre_size) - post_num = get_post_num(self, post_size) - if pre_num != post_num: + if self.pre_num != self.post_num: raise ConnectorError(f'One2One connection must be defined in two groups with the ' - f'same size, but {pre_num} != {post_num}.') - return bm.fill_diagonal(bm.zeros((pre_num, post_num), dtype=MAT_DTYPE), True) + f'same size, but {self.pre_num} != {self.post_num}.') + return bm.fill_diagonal(bm.zeros((self.pre_num, self.post_num), dtype=MAT_DTYPE), True) one2one = One2One() @@ -79,10 +71,8 @@ def __init__(self, include_self=True): def __repr__(self): return f'{self.__class__.__name__}(include_self={self.include_self})' - def build_mat(self, pre_size=None, post_size=None): - pre_num = get_pre_num(self, pre_size) - post_num = get_post_num(self, post_size) - mat = bm.ones((pre_num, post_num), dtype=MAT_DTYPE) + def build_mat(self): + mat = bm.ones((self.pre_num, self.post_num), dtype=MAT_DTYPE) if not self.include_self: bm.fill_diagonal(mat, False) return mat @@ -117,22 +107,18 @@ def __init__( def __repr__(self): return f'{self.__class__.__name__}(include_self={self.include_self}, periodic_boundary={self.periodic_boundary})' - def _format(self, pre_size, post_size): - pre_size = get_pre_size(self, pre_size) - post_size = get_post_size(self, post_size) - pre_num = size2num(pre_size) - post_num = size2num(post_size) - dim = len(post_size) - if pre_num != post_num: + def _format(self): + dim = len(self.post_size) + if self.pre_num != self.post_num: raise ConnectorError(f'{self.__class__.__name__} is used to for connection within ' f'a same population. 
But we detect pre_num != post_num ' - f'({pre_num} != {post_num}).') + f'({self.pre_num} != {self.post_num}).') # point indices - indices = bm.meshgrid(*(bm.arange(size) for size in post_size), indexing='ij') + indices = bm.meshgrid(*(bm.arange(size) for size in self.post_size), indexing='ij') indices = bm.asarray(indices) - indices = indices.reshape(dim, post_num).T - lengths = bm.asarray(post_size) - return lengths, post_size, dim, indices + indices = indices.reshape(dim, self.post_num).T + lengths = bm.asarray(self.post_size) + return lengths, dim, indices def _get_strides(self, dim): # increments @@ -147,8 +133,8 @@ def _select_stride(self, stride: np.ndarray) -> np.ndarray: def _select_dist(self, dist: bm.ndarray) -> bm.ndarray: raise NotImplementedError - def build_mat(self, pre_size=None, post_size=None): - sizes, post_size, _, indices = self._format(pre_size, post_size) + def build_mat(self): + sizes, _, indices = self._format() @jax.vmap def f_connect(pre_id): @@ -160,8 +146,8 @@ def f_connect(pre_id): return bm.asarray(f_connect(indices), dtype=MAT_DTYPE) - def build_coo(self, pre_size=None, post_size=None): - sizes, post_size, dim, indices = self._format(pre_size, post_size) + def build_coo(self): + sizes, dim, indices = self._format() strides = self._get_strides(dim) @jax.vmap @@ -186,7 +172,7 @@ def f_connect(pre_id): pres = pres.flatten() posts = posts.flatten() else: - strides = bm.asarray(get_size_length(post_size)) + strides = bm.asarray(get_size_length(self.post_size)) pres = bm.sum(pres * strides, axis=1) posts = bm.sum(posts * strides, axis=1) return bm.asarray(pres, dtype=IDX_DTYPE), bm.asarray(posts, dtype=IDX_DTYPE) diff --git a/brainpy/connect/tests/test_random_conn.py b/brainpy/connect/tests/test_random_conn.py index 3df5185ea..8744be01b 100644 --- a/brainpy/connect/tests/test_random_conn.py +++ b/brainpy/connect/tests/test_random_conn.py @@ -18,12 +18,13 @@ def test_size_consistent(self): def test_require_method(self): conn2 = bp.connect.FixedProb(prob=0.1, seed=123) conn2(pre_size=(10, 20), post_size=(10, 20)) - mat = conn2.require(100, 1000, bp.connect.CONN_MAT) - self.assertTrue(mat.shape == (100, 1000)) - mat = conn2.require(bp.connect.CONN_MAT) self.assertTrue(mat.shape == (200, 200)) + mat = conn2(100, 1000).require(bp.connect.CONN_MAT) + self.assertTrue(mat.shape == (100, 1000)) + + def test_random_fix_pre1(): for num in [0.4, 20]: @@ -34,8 +35,11 @@ def test_random_fix_pre1(): mat2 = conn2.require(bp.connect.CONN_MAT) print() + print(f'num = {num}') print('conn_mat 1\n', mat1) + print(mat1.sum()) print('conn_mat 2\n', mat2) + print(mat2.sum()) assert bp.math.array_equal(mat1, mat2) diff --git a/brainpy/connect/tests/test_regular_conn.py b/brainpy/connect/tests/test_regular_conn.py index b2a4fd41e..4fe2ab85d 100644 --- a/brainpy/connect/tests/test_regular_conn.py +++ b/brainpy/connect/tests/test_regular_conn.py @@ -65,9 +65,9 @@ def test_grid_four(self): for include_self in [True, False]: for size in (10, [10, 10], (4, 4, 5)): conn = bp.conn.GridFour(include_self=include_self, - periodic_boundary=periodic_boundary) - mat = conn.build_mat(size, size) - pre_ids, post_ids = conn.build_coo(size, size) + periodic_boundary=periodic_boundary)(size, size) + mat = conn.build_mat() + pre_ids, post_ids = conn.build_coo() new_mat = bp.math.zeros((np.prod(size), np.prod(size)), dtype=bool) new_mat[pre_ids, post_ids] = True @@ -79,9 +79,9 @@ def test_grid_eight(self): for include_self in [True, False]: for size in (10, [10, 10], (4, 4, 5)): conn = 
bp.conn.GridEight(include_self=include_self, - periodic_boundary=periodic_boundary) - mat = conn.build_mat(size, size) - pre_ids, post_ids = conn.build_coo(size, size) + periodic_boundary=periodic_boundary)(size, size) + mat = conn.build_mat() + pre_ids, post_ids = conn.build_coo() new_mat = bp.math.zeros((np.prod(size), np.prod(size)), dtype=bool) new_mat[pre_ids, post_ids] = True @@ -94,9 +94,9 @@ def test_grid_N(self): for size in (10, [10, 10], (4, 4, 5)): conn = bp.conn.GridN(include_self=include_self, periodic_boundary=periodic_boundary, - N=2) - mat = conn.build_mat(size, size) - pre_ids, post_ids = conn.build_coo(size, size) + N=2)(size, size) + mat = conn.build_mat() + pre_ids, post_ids = conn.build_coo() new_mat = bp.math.zeros((np.prod(size), np.prod(size)), dtype=bool) new_mat[pre_ids, post_ids] = True diff --git a/brainpy/connect/utils.py b/brainpy/connect/utils.py deleted file mode 100644 index ca7a8e8e9..000000000 --- a/brainpy/connect/utils.py +++ /dev/null @@ -1,38 +0,0 @@ -# -*- coding: utf-8 -*- - - -from brainpy.errors import ConnectorError -from brainpy.tools import size2num - -__all__ = [ - 'get_pre_num', 'get_post_num', - 'get_pre_size', 'get_post_size', -] - - -def get_pre_size(obj, pre_size=None): - if pre_size is None: - if obj.pre_size is None: - raise ConnectorError('Please provide "pre_size" and "post_size"') - else: - return obj.pre_size - else: - return (pre_size, ) if isinstance(pre_size, int) else pre_size - - -def get_pre_num(obj, pre_size=None): - return size2num(get_pre_size(obj, pre_size)) - - -def get_post_size(obj, post_size=None): - if post_size is None: - if obj.post_size is None: - raise ConnectorError('Please provide "pre_size" and "post_size"') - else: - return obj.post_size - else: - return (post_size,) if isinstance(post_size, int) else post_size - - -def get_post_num(obj, post_size=None): - return size2num(get_post_size(obj, post_size)) diff --git a/brainpy/tools/others/numba_util.py b/brainpy/tools/others/numba_util.py index db01c27f2..ce2fe370c 100644 --- a/brainpy/tools/others/numba_util.py +++ b/brainpy/tools/others/numba_util.py @@ -1,6 +1,5 @@ # -*- coding: utf-8 -*- - - +import numba import numpy as np try: from numba import njit @@ -11,6 +10,7 @@ __all__ = [ 'numba_jit', 'numba_seed', + 'numba_range', 'SUPPORT_NUMBA', ] @@ -38,3 +38,4 @@ def numba_seed(seed): _seed(seed) +numba_range = numba.prange if SUPPORT_NUMBA else range From cf5050ecf08e87468070c7273f0623d95e955d56 Mon Sep 17 00:00:00 2001 From: chaoming Date: Mon, 17 Oct 2022 23:34:26 +0800 Subject: [PATCH 4/7] upgrade op register --- brainpy/math/operators/op_register.py | 4 +- extensions/brainpylib/custom_op/cpu.py | 28 +++--- extensions/brainpylib/custom_op/cuda.py | 84 ---------------- extensions/brainpylib/custom_op/gpu.py | 104 ++++++++++++++++++-- extensions/brainpylib/custom_op/regis_op.py | 11 ++- 5 files changed, 122 insertions(+), 109 deletions(-) delete mode 100644 extensions/brainpylib/custom_op/cuda.py diff --git a/brainpy/math/operators/op_register.py b/brainpy/math/operators/op_register.py index 1ac3b2eb0..838772b93 100644 --- a/brainpy/math/operators/op_register.py +++ b/brainpy/math/operators/op_register.py @@ -128,9 +128,9 @@ def register_op( out_shapes=eval_shape, apply_cpu_func_to_gpu=apply_cpu_func_to_gpu) - def fixed_op(*inputs): + def fixed_op(*inputs, **info): inputs = tuple([i.value if isinstance(i, JaxArray) else i for i in inputs]) - res = f.bind(*inputs) + res = f.bind(*inputs, **info) return res[0] if len(res) == 1 else res return fixed_op diff 
--git a/extensions/brainpylib/custom_op/cpu.py b/extensions/brainpylib/custom_op/cpu.py index 7a002499d..3eee81207 100644 --- a/extensions/brainpylib/custom_op/cpu.py +++ b/extensions/brainpylib/custom_op/cpu.py @@ -2,10 +2,11 @@ import ctypes -import numba +import numpy as np from jax.abstract_arrays import ShapedArray from jax.lib import xla_client -from numba import types +from jax import dtypes +from numba import types, carray, cfunc _lambda_no = 0 ctypes.pythonapi.PyCapsule_New.argtypes = [ @@ -17,14 +18,14 @@ def _compile_cpu_signature(func, input_dtypes, input_shapes, - output_dtypes, output_shapes): + output_dtypes, output_shapes, debug=True): code_scope = dict( func_to_call=func, input_shapes=input_shapes, input_dtypes=input_dtypes, output_shapes=output_shapes, output_dtypes=output_dtypes, - carray=numba.carray, + carray=carray, ) args_in = [ @@ -47,13 +48,12 @@ def xla_cpu_custom_call_target(output_ptrs, input_ptrs): func_to_call(args_out, args_in) '''.format(args_in=",\n ".join(args_in), args_out=",\n ".join(args_out)) - # print(code_string) + if debug: print(code_string) exec(compile(code_string.strip(), '', 'exec'), code_scope) new_f = code_scope['xla_cpu_custom_call_target'] - wrapper = numba.cfunc(types.void(types.CPointer(types.voidptr), - types.CPointer(types.voidptr))) - xla_c_rule = wrapper(new_f) + xla_c_rule = cfunc(types.void(types.CPointer(types.voidptr), + types.CPointer(types.voidptr)))(new_f) target_name = xla_c_rule.native_name.encode("ascii") capsule = ctypes.pythonapi.PyCapsule_New( xla_c_rule.address, # A CFFI pointer to a function @@ -64,12 +64,16 @@ def xla_cpu_custom_call_target(output_ptrs, input_ptrs): return target_name -def func_cpu_translation(func, abs_eval_fn, c, *inputs): +def func_cpu_translation(func, abs_eval_fn, c, *inputs, **info): input_shapes = [c.get_shape(arg) for arg in inputs] + for v in info.values(): + if not isinstance(v, (int, float)): + raise TypeError + input_shapes.append(xla_client.Shape.array_shape(dtypes.canonicalize_dtype(type(v)), (), ())) + input_shapes = tuple(input_shapes) input_dtypes = tuple(shape.element_type() for shape in input_shapes) input_dimensions = tuple(shape.dimensions() for shape in input_shapes) - output_abstract_arrays = abs_eval_fn(*tuple(ShapedArray(shape.dimensions(), shape.element_type()) - for shape in input_shapes)) + output_abstract_arrays = abs_eval_fn(*input_shapes[:len(inputs)], **info) output_shapes = tuple(array.shape for array in output_abstract_arrays) output_dtypes = tuple(array.dtype for array in output_abstract_arrays) output_layouts = map(lambda shape: range(len(shape) - 1, -1, -1), output_shapes) @@ -83,7 +87,7 @@ def func_cpu_translation(func, abs_eval_fn, c, *inputs): return xla_client.ops.CustomCallWithLayout( c, target_name, - operands=inputs, + operands=inputs + tuple(xla_client.ops.ConstantLiteral(c, i) for i in info.values()), operand_shapes_with_layout=input_shapes, shape_with_layout=xla_output_shape, ) diff --git a/extensions/brainpylib/custom_op/cuda.py b/extensions/brainpylib/custom_op/cuda.py deleted file mode 100644 index 4b66349aa..000000000 --- a/extensions/brainpylib/custom_op/cuda.py +++ /dev/null @@ -1,84 +0,0 @@ -# -*- coding: utf-8 -*- - -import ctypes -import ctypes.util -import sys - -from cffi import FFI -from numba import cuda -from numba import types - - -class Dl_info(ctypes.Structure): - """ - Structure of the Dl_info returned by the CFFI of dl.dladdr - """ - - _fields_ = ( - ("dli_fname", ctypes.c_char_p), - ("dli_fbase", ctypes.c_void_p), - ("dli_sname", 
ctypes.c_char_p), - ("dli_saddr", ctypes.c_void_p), - ) - - -# Find the dynamic linker library path. Only works on unix-like os -libdl_path = ctypes.util.find_library("dl") -if libdl_path: - # Load the dynamic linker dynamically - libdl = ctypes.CDLL(libdl_path) - - # Define dladdr to get the pointer to a symbol in a shared - # library already loaded. - # https://man7.org/linux/man-pages/man3/dladdr.3.html - libdl.dladdr.argtypes = (ctypes.c_void_p, ctypes.POINTER(Dl_info)) - # restype is None as it returns by reference -else: - # On Windows it is nontrivial to have libdl, so we disable everything about - # it and use other ways to find paths of libraries - libdl = None - - -def find_path_of_symbol_in_library(symbol): - if libdl is None: - raise ValueError("libdl not found.") - - info = Dl_info() - result = libdl.dladdr(symbol, ctypes.byref(info)) - if result and info.dli_fname: - return info.dli_fname.decode(sys.getfilesystemencoding()) - else: - raise ValueError("Cannot determine path of Library.") - - -try: - _libcuda = cuda.driver.find_driver() - if sys.platform == "win32": - libcuda_path = ctypes.util.find_library(_libcuda._name) - else: - libcuda_path = find_path_of_symbol_in_library(_libcuda.cuMemcpy) - numba_cffi_loaded = True -except Exception: - numba_cffi_loaded = False - - -if numba_cffi_loaded: - # functions needed - ffi = FFI() - ffi.cdef("int cuMemcpy(void* dst, void* src, unsigned int len, int type);") - ffi.cdef("int cuMemcpyAsync(void* dst, void* src, unsigned int len, int type, void* stream);") - ffi.cdef("int cuStreamSynchronize(void* stream);") - ffi.cdef("int cudaMallocHost(void** ptr, size_t size);") - ffi.cdef("int cudaFreeHost(void* ptr);") - - # load libraray - # could ncuda.driver.find_library() - libcuda = ffi.dlopen(libcuda_path) - cuMemcpy = libcuda.cuMemcpy - cuMemcpyAsync = libcuda.cuMemcpyAsync - cuStreamSynchronize = libcuda.cuStreamSynchronize - - memcpyHostToHost = types.int32(0) - memcpyHostToDevice = types.int32(1) - memcpyDeviceToHost = types.int32(2) - memcpyDeviceToDevice = types.int32(3) diff --git a/extensions/brainpylib/custom_op/gpu.py b/extensions/brainpylib/custom_op/gpu.py index fd833d985..150272dce 100644 --- a/extensions/brainpylib/custom_op/gpu.py +++ b/extensions/brainpylib/custom_op/gpu.py @@ -1,11 +1,91 @@ # -*- coding: utf-8 -*- -import numba +import ctypes +import ctypes.util +import sys + import numpy as np +from cffi import FFI from jax.abstract_arrays import ShapedArray from jax.lib import xla_client +from jax import dtypes +from numba import cuda, cfunc, types + + +class Dl_info(ctypes.Structure): + """ + Structure of the Dl_info returned by the CFFI of dl.dladdr + """ + + _fields_ = ( + ("dli_fname", ctypes.c_char_p), + ("dli_fbase", ctypes.c_void_p), + ("dli_sname", ctypes.c_char_p), + ("dli_saddr", ctypes.c_void_p), + ) + + +# Find the dynamic linker library path. Only works on unix-like os +libdl_path = ctypes.util.find_library("dl") +if libdl_path: + # Load the dynamic linker dynamically + libdl = ctypes.CDLL(libdl_path) + + # Define dladdr to get the pointer to a symbol in a shared + # library already loaded. 
+ # https://man7.org/linux/man-pages/man3/dladdr.3.html + libdl.dladdr.argtypes = (ctypes.c_void_p, ctypes.POINTER(Dl_info)) + # restype is None as it returns by reference +else: + # On Windows it is nontrivial to have libdl, so we disable everything about + # it and use other ways to find paths of libraries + libdl = None + + +def find_path_of_symbol_in_library(symbol): + if libdl is None: + raise ValueError("libdl not found.") + + info = Dl_info() + result = libdl.dladdr(symbol, ctypes.byref(info)) + if result and info.dli_fname: + return info.dli_fname.decode(sys.getfilesystemencoding()) + else: + raise ValueError("Cannot determine path of Library.") + + +try: + _libcuda = cuda.driver.find_driver() + if sys.platform == "win32": + libcuda_path = ctypes.util.find_library(_libcuda._name) + else: + libcuda_path = find_path_of_symbol_in_library(_libcuda.cuMemcpy) + numba_cffi_loaded = True +except Exception: + numba_cffi_loaded = False + + +if numba_cffi_loaded: + # functions needed + ffi = FFI() + ffi.cdef("int cuMemcpy(void* dst, void* src, unsigned int len, int type);") + ffi.cdef("int cuMemcpyAsync(void* dst, void* src, unsigned int len, int type, void* stream);") + ffi.cdef("int cuStreamSynchronize(void* stream);") + ffi.cdef("int cudaMallocHost(void** ptr, size_t size);") + ffi.cdef("int cudaFreeHost(void* ptr);") + + # load libraray + # could ncuda.driver.find_library() + libcuda = ffi.dlopen(libcuda_path) + cuMemcpy = libcuda.cuMemcpy + cuMemcpyAsync = libcuda.cuMemcpyAsync + cuStreamSynchronize = libcuda.cuStreamSynchronize + + memcpyHostToHost = types.int32(0) + memcpyHostToDevice = types.int32(1) + memcpyDeviceToHost = types.int32(2) + memcpyDeviceToDevice = types.int32(3) -from .cuda import * _lambda_no = 0 ctypes.pythonapi.PyCapsule_New.argtypes = [ @@ -83,7 +163,7 @@ def xla_gpu_custom_call_target(stream, inout_gpu_ptrs, opaque, opaque_len): exec(compile(code_string.strip(), '', 'exec'), code_scope) new_f = code_scope['xla_gpu_custom_call_target'] - wrapper = numba.cfunc(types.void( + wrapper = cfunc(types.void( types.voidptr, types.CPointer(types.voidptr), types.voidptr, types.uint64)) @@ -98,15 +178,24 @@ def xla_gpu_custom_call_target(stream, inout_gpu_ptrs, opaque, opaque_len): return target_name -def func_gpu_translation(func, abs_eval_fn, c, *inputs): +def func_gpu_translation(func, abs_eval_fn, c, *inputs, **info): if not numba_cffi_loaded: raise RuntimeError("Numba cffi could not be loaded.") input_shapes = [c.get_shape(arg) for arg in inputs] + for v in info.values(): + input_shapes.append(xla_client.Shape.array_shape(dtypes.canonicalize_dtype(type(v)), (), ())) + # if isinstance(v, int): + # input_shapes.append(xla_client.Shape.array_shape(np.dtype(np.int_), (), ())) + # elif isinstance(v, float): + # input_shapes.append(xla_client.Shape.array_shape(np.dtype(np.float_), (), ())) + # else: + # raise TypeError input_dtypes = tuple(shape.element_type() for shape in input_shapes) input_dimensions = tuple(shape.dimensions() for shape in input_shapes) output_abstract_arrays = abs_eval_fn(*tuple(ShapedArray(shape.dimensions(), shape.element_type()) - for shape in input_shapes)) + for shape in input_shapes[:len(inputs)]), + **info) output_shapes = tuple(array.shape for array in output_abstract_arrays) output_dtypes = tuple(array.dtype for array in output_abstract_arrays) output_layouts = map(lambda shape: range(len(shape) - 1, -1, -1), output_shapes) @@ -120,7 +209,10 @@ def func_gpu_translation(func, abs_eval_fn, c, *inputs): return xla_client.ops.CustomCallWithLayout( c, 
target_name, - operands=inputs, + operands=inputs + tuple(xla_client.ops.ConstantLiteral(c, i) for i in info.values()), operand_shapes_with_layout=input_shapes, shape_with_layout=xla_output_shape, ) + + + diff --git a/extensions/brainpylib/custom_op/regis_op.py b/extensions/brainpylib/custom_op/regis_op.py index cfc09ca6e..4acc5ce01 100644 --- a/extensions/brainpylib/custom_op/regis_op.py +++ b/extensions/brainpylib/custom_op/regis_op.py @@ -71,9 +71,9 @@ def register_op( cpu_func = numba.jit(fastmath=True, nopython=True)(cpu_func) # output shape evaluation function - def abs_eval_rule(*input_shapes): + def abs_eval_rule(*input_shapes, **info): if callable(out_shapes): - shapes = out_shapes(*input_shapes) + shapes = out_shapes(*input_shapes, **info) elif isinstance(out_shapes, ShapedArray): shapes = [out_shapes] elif isinstance(out_shapes, (tuple, list)): @@ -95,17 +95,18 @@ def abs_eval_rule(*input_shapes): return shapes # output evaluation function - def eval_rule(*inputs): + def eval_rule(*inputs, **info): # compute the output shapes - output_shapes = abs_eval_rule(*inputs) + output_shapes = abs_eval_rule(*inputs, **info) # Preallocate the outputs outputs = tuple(np.zeros(shape.shape, dtype=shape.dtype) for shape in output_shapes) # convert inputs to a tuple inputs = tuple(np.asarray(arg) for arg in inputs) + inputs += tuple(info.values()) # call the kernel cpu_func(outputs, inputs) # Return the outputs - return tuple(outputs) + return outputs[0] if len(outputs) == 1 else tuple(outputs) # cpu function prim.def_abstract_eval(abs_eval_rule) From 07533d9987977912b426069fb989bd448a645be3 Mon Sep 17 00:00:00 2001 From: chaoming Date: Mon, 17 Oct 2022 23:49:43 +0800 Subject: [PATCH 5/7] upgrade op register --- extensions/brainpylib/custom_op/cpu.py | 14 +++--- extensions/brainpylib/custom_op/gpu.py | 14 ++---- extensions/brainpylib/custom_op/tests/a.py | 51 ++++++++++++++++++++++ 3 files changed, 63 insertions(+), 16 deletions(-) create mode 100644 extensions/brainpylib/custom_op/tests/a.py diff --git a/extensions/brainpylib/custom_op/cpu.py b/extensions/brainpylib/custom_op/cpu.py index 3eee81207..0a397b8df 100644 --- a/extensions/brainpylib/custom_op/cpu.py +++ b/extensions/brainpylib/custom_op/cpu.py @@ -18,7 +18,7 @@ def _compile_cpu_signature(func, input_dtypes, input_shapes, - output_dtypes, output_shapes, debug=True): + output_dtypes, output_shapes, debug=False): code_scope = dict( func_to_call=func, input_shapes=input_shapes, @@ -29,11 +29,11 @@ def _compile_cpu_signature(func, input_dtypes, input_shapes, ) args_in = [ - f'carray(input_ptrs[{i}], input_shapes[{i}], dtype=input_dtypes[{i}])' + f'carray(input_ptrs[{i}], input_shapes[{i}], dtype=input_dtypes[{i}]),' for i in range(len(input_shapes)) ] args_out = [ - f'carray(output_ptrs[{i}], output_shapes[{i}], dtype=output_dtypes[{i}])' + f'carray(output_ptrs[{i}], output_shapes[{i}], dtype=output_dtypes[{i}]),' for i in range(len(output_shapes)) ] @@ -46,8 +46,8 @@ def xla_cpu_custom_call_target(output_ptrs, input_ptrs): {args_in} ) func_to_call(args_out, args_in) - '''.format(args_in=",\n ".join(args_in), - args_out=",\n ".join(args_out)) + '''.format(args_in="\n ".join(args_in), + args_out="\n ".join(args_out)) if debug: print(code_string) exec(compile(code_string.strip(), '', 'exec'), code_scope) @@ -73,7 +73,9 @@ def func_cpu_translation(func, abs_eval_fn, c, *inputs, **info): input_shapes = tuple(input_shapes) input_dtypes = tuple(shape.element_type() for shape in input_shapes) input_dimensions = tuple(shape.dimensions() for 
shape in input_shapes) - output_abstract_arrays = abs_eval_fn(*input_shapes[:len(inputs)], **info) + output_abstract_arrays = abs_eval_fn(*tuple(ShapedArray(shape.dimensions(), shape.element_type()) + for shape in input_shapes[:len(inputs)]), + **info) output_shapes = tuple(array.shape for array in output_abstract_arrays) output_dtypes = tuple(array.dtype for array in output_abstract_arrays) output_layouts = map(lambda shape: range(len(shape) - 1, -1, -1), output_shapes) diff --git a/extensions/brainpylib/custom_op/gpu.py b/extensions/brainpylib/custom_op/gpu.py index 150272dce..dabaff90e 100644 --- a/extensions/brainpylib/custom_op/gpu.py +++ b/extensions/brainpylib/custom_op/gpu.py @@ -126,7 +126,7 @@ def _compile_gpu_signature(func, input_dtypes, input_shapes, ) args_in = [ - f'empty(input_shapes[{i}], dtype=input_dtypes[{i}])' + f'empty(input_shapes[{i}], dtype=input_dtypes[{i}]),' for i in range(len(input_shapes)) ] cuMemcpyAsync_in = [ @@ -134,7 +134,7 @@ def _compile_gpu_signature(func, input_dtypes, input_shapes, for i in range(len(input_shapes)) ] args_out = [ - f'empty(output_shapes[{i}], dtype=output_dtypes[{i}])' + f'empty(output_shapes[{i}], dtype=output_dtypes[{i}]),' for i in range(len(output_shapes)) ] cuMemcpyAsync_out = [ @@ -155,8 +155,8 @@ def xla_gpu_custom_call_target(stream, inout_gpu_ptrs, opaque, opaque_len): cuStreamSynchronize(stream) func_to_call(args_out, args_in) {cuMemcpyAsync_out} - '''.format(args_in=",\n ".join(args_in), - args_out=",\n ".join(args_out), + '''.format(args_in="\n ".join(args_in), + args_out="\n ".join(args_out), cuMemcpyAsync_in="\n ".join(cuMemcpyAsync_in), cuMemcpyAsync_out="\n ".join(cuMemcpyAsync_out)) # print(code_string) @@ -185,12 +185,6 @@ def func_gpu_translation(func, abs_eval_fn, c, *inputs, **info): input_shapes = [c.get_shape(arg) for arg in inputs] for v in info.values(): input_shapes.append(xla_client.Shape.array_shape(dtypes.canonicalize_dtype(type(v)), (), ())) - # if isinstance(v, int): - # input_shapes.append(xla_client.Shape.array_shape(np.dtype(np.int_), (), ())) - # elif isinstance(v, float): - # input_shapes.append(xla_client.Shape.array_shape(np.dtype(np.float_), (), ())) - # else: - # raise TypeError input_dtypes = tuple(shape.element_type() for shape in input_shapes) input_dimensions = tuple(shape.dimensions() for shape in input_shapes) output_abstract_arrays = abs_eval_fn(*tuple(ShapedArray(shape.dimensions(), shape.element_type()) diff --git a/extensions/brainpylib/custom_op/tests/a.py b/extensions/brainpylib/custom_op/tests/a.py new file mode 100644 index 000000000..a0ac11da2 --- /dev/null +++ b/extensions/brainpylib/custom_op/tests/a.py @@ -0,0 +1,51 @@ +import brainpy.math as bm +import brainpy as bp +from jax.abstract_arrays import ShapedArray + + +def try1(): + def abs_eval(events, indices, indptr, *, weight, post_num): + return ShapedArray((post_num,), bm.float64) + + def con_compute(outs, ins): + post_val, = outs + post_val.fill(0) + events, indices, indptr, weight, _ = ins + weight = weight[()] + for i in range(events.size): + if events[i]: + for j in range(indptr[i], indptr[i + 1]): + index = indices[j] + post_val[index] += weight + + event_sum = bm.XLACustomOp(eval_shape=abs_eval, con_compute=con_compute) + + events = bm.random.rand(10) < 0.2 + indices, indptr = bp.conn.FixedProb(0.1)(10, 20).require('pre2post') + print(bm.jit(event_sum, static_argnames=('weight', 'post_num'))(events, indices, indptr, weight=1., post_num=20)) + + +def try2(): + def abs_eval(events, indices, indptr, post_val, weight): + 
return post_val + + def con_compute(outs, ins): + post_val, = outs + events, indices, indptr, _, weight = ins + weight = weight[()] + for i in range(events.size): + if events[i]: + for j in range(indptr[i], indptr[i + 1]): + index = indices[j] + post_val[index] += weight + + event_sum = bm.XLACustomOp(eval_shape=abs_eval, con_compute=con_compute) + + events = bm.random.rand(10) < 0.2 + indices, indptr = bp.conn.FixedProb(0.1)(10, 20).require('pre2post') + print(bm.jit(event_sum)(events, indices, indptr, bm.zeros(20), 1.)) + + +if __name__ == '__main__': + try1() + # try2() From a2c399ba91a9dd3f1efee0772ef18ad546bc8012 Mon Sep 17 00:00:00 2001 From: chaoming Date: Tue, 18 Oct 2022 13:46:15 +0800 Subject: [PATCH 6/7] customizing operator using numba and hand-written ops --- brainpy/connect/regular_conn.py | 12 +-- extensions/brainpylib/custom_op/cpu.py | 21 +++-- extensions/brainpylib/custom_op/gpu.py | 23 +++--- extensions/brainpylib/custom_op/tests/a.py | 23 +++--- .../brainpylib/custom_op/tests/ei_net.py | 80 +++++++++++++++++++ 5 files changed, 127 insertions(+), 32 deletions(-) create mode 100644 extensions/brainpylib/custom_op/tests/ei_net.py diff --git a/brainpy/connect/regular_conn.py b/brainpy/connect/regular_conn.py index c40758c08..253f00998 100644 --- a/brainpy/connect/regular_conn.py +++ b/brainpy/connect/regular_conn.py @@ -38,21 +38,21 @@ def build_coo(self): if self.pre_num != self.post_num: raise ConnectorError(f'One2One connection must be defined in two groups with the ' f'same size, but {self.pre_num} != {self.post_num}.') - return bm.arange(self.pre_num, dtype=IDX_DTYPE), bm.arange(self.post_num, dtype=IDX_DTYPE), + return np.arange(self.pre_num, dtype=IDX_DTYPE), np.arange(self.post_num, dtype=IDX_DTYPE), def build_csr(self): if self.pre_num != self.post_num: raise ConnectorError(f'One2One connection must be defined in two groups with the ' f'same size, but {self.pre_num} != {self.post_num}.') - ind = bm.arange(self.pre_num) + ind = np.arange(self.pre_num) indptr = np.arange(self.pre_num + 1) - return bm.asarray(ind, dtype=IDX_DTYPE), bm.arange(indptr, dtype=IDX_DTYPE), + return np.asarray(ind, dtype=IDX_DTYPE), np.arange(indptr, dtype=IDX_DTYPE), def build_mat(self, pre_size=None, post_size=None): if self.pre_num != self.post_num: raise ConnectorError(f'One2One connection must be defined in two groups with the ' f'same size, but {self.pre_num} != {self.post_num}.') - return bm.fill_diagonal(bm.zeros((self.pre_num, self.post_num), dtype=MAT_DTYPE), True) + return np.fill_diagonal(np.zeros((self.pre_num, self.post_num), dtype=MAT_DTYPE), True) one2one = One2One() @@ -72,9 +72,9 @@ def __repr__(self): return f'{self.__class__.__name__}(include_self={self.include_self})' def build_mat(self): - mat = bm.ones((self.pre_num, self.post_num), dtype=MAT_DTYPE) + mat = np.ones((self.pre_num, self.post_num), dtype=MAT_DTYPE) if not self.include_self: - bm.fill_diagonal(mat, False) + np.fill_diagonal(mat, False) return mat diff --git a/extensions/brainpylib/custom_op/cpu.py b/extensions/brainpylib/custom_op/cpu.py index 0a397b8df..a32f1b0a4 100644 --- a/extensions/brainpylib/custom_op/cpu.py +++ b/extensions/brainpylib/custom_op/cpu.py @@ -2,10 +2,9 @@ import ctypes -import numpy as np +from jax import dtypes from jax.abstract_arrays import ShapedArray from jax.lib import xla_client -from jax import dtypes from numba import types, carray, cfunc _lambda_no = 0 @@ -17,8 +16,14 @@ ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object -def _compile_cpu_signature(func, input_dtypes, 
input_shapes, - output_dtypes, output_shapes, debug=False): +def _compile_cpu_signature( + func, + input_dtypes, + input_shapes, + output_dtypes, + output_shapes, + debug=False +): code_scope = dict( func_to_call=func, input_shapes=input_shapes, @@ -53,7 +58,7 @@ def xla_cpu_custom_call_target(output_ptrs, input_ptrs): new_f = code_scope['xla_cpu_custom_call_target'] xla_c_rule = cfunc(types.void(types.CPointer(types.voidptr), - types.CPointer(types.voidptr)))(new_f) + types.CPointer(types.voidptr)))(new_f) target_name = xla_c_rule.native_name.encode("ascii") capsule = ctypes.pythonapi.PyCapsule_New( xla_c_rule.address, # A CFFI pointer to a function @@ -83,8 +88,10 @@ def func_cpu_translation(func, abs_eval_fn, c, *inputs, **info): for arg in zip(output_dtypes, output_shapes, output_layouts)] xla_output_shape = xla_client.Shape.tuple_shape(xla_output_shapes) target_name = _compile_cpu_signature(func, - input_dtypes, input_dimensions, - output_dtypes, output_shapes) + input_dtypes, + input_dimensions, + output_dtypes, + output_shapes) return xla_client.ops.CustomCallWithLayout( c, diff --git a/extensions/brainpylib/custom_op/gpu.py b/extensions/brainpylib/custom_op/gpu.py index dabaff90e..83839ead9 100644 --- a/extensions/brainpylib/custom_op/gpu.py +++ b/extensions/brainpylib/custom_op/gpu.py @@ -64,7 +64,6 @@ def find_path_of_symbol_in_library(symbol): except Exception: numba_cffi_loaded = False - if numba_cffi_loaded: # functions needed ffi = FFI() @@ -86,7 +85,6 @@ def find_path_of_symbol_in_library(symbol): memcpyDeviceToHost = types.int32(2) memcpyDeviceToDevice = types.int32(3) - _lambda_no = 0 ctypes.pythonapi.PyCapsule_New.argtypes = [ ctypes.c_void_p, # void* pointer @@ -96,8 +94,14 @@ def find_path_of_symbol_in_library(symbol): ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object -def _compile_gpu_signature(func, input_dtypes, input_shapes, - output_dtypes, output_shapes): +def _compile_gpu_signature( + func, + input_dtypes, + input_shapes, + output_dtypes, + output_shapes, + debug=False +): input_byte_size = tuple( np.prod(shape) * dtype.itemsize for (shape, dtype) in zip(input_shapes, input_dtypes) @@ -159,7 +163,7 @@ def xla_gpu_custom_call_target(stream, inout_gpu_ptrs, opaque, opaque_len): args_out="\n ".join(args_out), cuMemcpyAsync_in="\n ".join(cuMemcpyAsync_in), cuMemcpyAsync_out="\n ".join(cuMemcpyAsync_out)) - # print(code_string) + if debug: print(code_string) exec(compile(code_string.strip(), '', 'exec'), code_scope) new_f = code_scope['xla_gpu_custom_call_target'] @@ -197,8 +201,10 @@ def func_gpu_translation(func, abs_eval_fn, c, *inputs, **info): for arg in zip(output_dtypes, output_shapes, output_layouts)] xla_output_shape = xla_client.Shape.tuple_shape(xla_output_shapes) target_name = _compile_gpu_signature(func, - input_dtypes, input_dimensions, - output_dtypes, output_shapes) + input_dtypes, + input_dimensions, + output_dtypes, + output_shapes) return xla_client.ops.CustomCallWithLayout( c, @@ -207,6 +213,3 @@ def func_gpu_translation(func, abs_eval_fn, c, *inputs, **info): operand_shapes_with_layout=input_shapes, shape_with_layout=xla_output_shape, ) - - - diff --git a/extensions/brainpylib/custom_op/tests/a.py b/extensions/brainpylib/custom_op/tests/a.py index a0ac11da2..77a66bb97 100644 --- a/extensions/brainpylib/custom_op/tests/a.py +++ b/extensions/brainpylib/custom_op/tests/a.py @@ -1,7 +1,10 @@ + import brainpy.math as bm import brainpy as bp from jax.abstract_arrays import ShapedArray +import numba +bm.set_platform('cpu') def try1(): def 
abs_eval(events, indices, indptr, *, weight, post_num): @@ -11,18 +14,20 @@ def con_compute(outs, ins): post_val, = outs post_val.fill(0) events, indices, indptr, weight, _ = ins - weight = weight[()] + # weight = weight[()] + weight = weight + print(weight) for i in range(events.size): if events[i]: - for j in range(indptr[i], indptr[i + 1]): - index = indices[j] - post_val[index] += weight + for j in numba.prange(indptr[i], indptr[i + 1]): + post_val[indices[j]] += weight event_sum = bm.XLACustomOp(eval_shape=abs_eval, con_compute=con_compute) - events = bm.random.rand(10) < 0.2 - indices, indptr = bp.conn.FixedProb(0.1)(10, 20).require('pre2post') - print(bm.jit(event_sum, static_argnames=('weight', 'post_num'))(events, indices, indptr, weight=1., post_num=20)) + events = bm.random.RandomState(123).rand(10) < 0.2 + indices, indptr = bp.conn.FixedProb(0.1, seed=123)(10, 20).require('pre2post') + # print(bm.jit(, static_argnames=('weight', 'post_num'))(events, indices, indptr, weight=1., post_num=20)) + print(event_sum(events, indices, indptr, weight=1., post_num=20)) def try2(): @@ -41,8 +46,8 @@ def con_compute(outs, ins): event_sum = bm.XLACustomOp(eval_shape=abs_eval, con_compute=con_compute) - events = bm.random.rand(10) < 0.2 - indices, indptr = bp.conn.FixedProb(0.1)(10, 20).require('pre2post') + events = bm.random.RandomState(123).rand(10) < 0.2 + indices, indptr = bp.conn.FixedProb(0.1, seed=123)(10, 20).require('pre2post') print(bm.jit(event_sum)(events, indices, indptr, bm.zeros(20), 1.)) diff --git a/extensions/brainpylib/custom_op/tests/ei_net.py b/extensions/brainpylib/custom_op/tests/ei_net.py new file mode 100644 index 000000000..2a913d69a --- /dev/null +++ b/extensions/brainpylib/custom_op/tests/ei_net.py @@ -0,0 +1,80 @@ +import brainpy.math as bm +import brainpy as bp +from jax.abstract_arrays import ShapedArray + +bm.set_platform('cpu') + + +def abs_eval(events, indices, indptr, *, weight, post_num): + return ShapedArray((post_num,), bm.float32) + + +def con_compute(outs, ins): + post_val, = outs + post_val.fill(0) + events, indices, indptr, weight, _ = ins + weight = weight[()] + for i in range(events.size): + if events[i]: + for j in range(indptr[i], indptr[i + 1]): + index = indices[j] + post_val[index] += weight + + +event_sum = bm.XLACustomOp(eval_shape=abs_eval, cpu_func=con_compute, apply_cpu_func_to_gpu=True) + + +class ExponentialV2(bp.dyn.TwoEndConn): + """Exponential synapse model using customized operator written in C++.""" + + def __init__(self, pre, post, conn, g_max=1., delay=0., tau=8.0, E=0.): + super(ExponentialV2, self).__init__(pre=pre, post=post, conn=conn) + self.check_pre_attrs('spike') + self.check_post_attrs('input', 'V') + + # parameters + self.E = E + self.tau = tau + self.delay = delay + self.g_max = g_max + self.pre2post = self.conn.require('pre2post') + + # variables + self.g = bm.Variable(bm.zeros(self.post.num)) + + # function + self.integral = bp.odeint(lambda g, t: -g / self.tau, method='exp_auto') + + def update(self, tdi): + self.g.value = self.integral(self.g, tdi.t, tdi.dt) + self.g += event_sum(self.pre.spike, + self.pre2post[0], + self.pre2post[1], + weight=self.g_max, + post_num=self.post.num) + self.post.input += self.g * (self.E - self.post.V) + + +class EINet(bp.dyn.Network): + def __init__(self, scale): + # neurons + pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + E = bp.neurons.LIF(int(3200 * scale), **pars, method='exp_auto') + I = bp.neurons.LIF(int(800 * 
scale), **pars, method='exp_auto') + + # synapses + E2E = ExponentialV2(E, E, bp.conn.FixedProb(prob=0.02), E=0., g_max=0.6 / scale, tau=5.) + E2I = ExponentialV2(E, I, bp.conn.FixedProb(prob=0.02), E=0., g_max=0.6 / scale, tau=5.) + I2E = ExponentialV2(I, E, bp.conn.FixedProb(prob=0.02), E=-80., g_max=6.7 / scale, tau=10.) + I2I = ExponentialV2(I, I, bp.conn.FixedProb(prob=0.02), E=-80., g_max=6.7 / scale, tau=10.) + + super(EINet, self).__init__(E2E, E2I, I2E, I2I, E=E, I=I) + + +net2 = EINet(scale=10.) +runner2 = bp.dyn.DSRunner(net2, inputs=[('E.input', 20.), ('I.input', 20.)]) +t, _ = runner2.predict(100., eval_time=True) +print(t) + + From 71e13088643dd387c491f98d24e6f912e67d67f8 Mon Sep 17 00:00:00 2001 From: chaoming Date: Tue, 18 Oct 2022 20:21:40 +0800 Subject: [PATCH 7/7] fix the difference of constants in jit and nonjit mode --- docs/others/{citing.rst => publications.rst} | 8 ++++++-- extensions/brainpylib/custom_op/regis_op.py | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) rename docs/others/{citing.rst => publications.rst} (79%) diff --git a/docs/others/citing.rst b/docs/others/publications.rst similarity index 79% rename from docs/others/citing.rst rename to docs/others/publications.rst index 34edadd9e..510a313e1 100644 --- a/docs/others/citing.rst +++ b/docs/others/publications.rst @@ -1,10 +1,14 @@ If BrainPy has been significant in your research, and you would like to acknowledge -the project in your academic publication, we suggest citing the following paper: +the project in your academic publication, we suggest citing the following papers: + + + +- Chaoming Wang, Xiaoyu Chen, Tianqiu Zhang, Si Wu. *BrainPy: a flexible, integrative, efficient, and extensible framework towards general-purpose brain dynamics programming*. In submission. + - Wang, C., Jiang, Y., Liu, X., Lin, X., Zou, X., Ji, Z., & Wu, S. (2021, December). *A Just-In-Time Compilation Approach for Neural Dynamics Simulation*. In International Conference on Neural Information Processing (pp. 15-26). Springer, Cham. -In BibTeX format: .. code-block:: diff --git a/extensions/brainpylib/custom_op/regis_op.py b/extensions/brainpylib/custom_op/regis_op.py index 4acc5ce01..683923d1a 100644 --- a/extensions/brainpylib/custom_op/regis_op.py +++ b/extensions/brainpylib/custom_op/regis_op.py @@ -102,7 +102,7 @@ def eval_rule(*inputs, **info): outputs = tuple(np.zeros(shape.shape, dtype=shape.dtype) for shape in output_shapes) # convert inputs to a tuple inputs = tuple(np.asarray(arg) for arg in inputs) - inputs += tuple(info.values()) + inputs += tuple(np.asarray(i) for i in info.values()) # call the kernel cpu_func(outputs, inputs) # Return the outputs
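
The custom-operator work in this series converges on one pattern: an abstract-evaluation function that only reports the output shape and dtype, plus a Numba-compatible kernel that works on tuples of output and input buffers, registered through bm.XLACustomOp. The sketch below restates that pattern in one place, following the form used in extensions/brainpylib/custom_op/tests/ei_net.py. It is a minimal sketch, not the settled API: the registration keywords still differ between the tests (ei_net.py passes cpu_func= and apply_cpu_func_to_gpu=, while tests/a.py passes con_compute=), so the exact signature should be checked against regis_op.py before relying on it.

import brainpy as bp
import brainpy.math as bm
from jax.abstract_arrays import ShapedArray

bm.set_platform('cpu')


def abs_eval(events, indices, indptr, *, weight, post_num):
  # Abstract evaluation: only the output shape/dtype matters here.
  return ShapedArray((post_num,), bm.float32)


def event_sum_kernel(outs, ins):
  # Numba-compiled CPU kernel: outs/ins are tuples of array buffers.
  post_val, = outs
  post_val.fill(0)
  events, indices, indptr, weight, _ = ins
  weight = weight[()]  # keyword constants reach the kernel as 0-d arrays
  for i in range(events.size):
    if events[i]:
      for j in range(indptr[i], indptr[i + 1]):
        post_val[indices[j]] += weight


# Keyword names follow ei_net.py; tests/a.py uses con_compute= instead.
event_sum = bm.XLACustomOp(eval_shape=abs_eval, cpu_func=event_sum_kernel)

events = bm.random.RandomState(123).rand(10) < 0.2
indices, indptr = bp.conn.FixedProb(0.1, seed=123)(10, 20).require('pre2post')
print(event_sum(events, indices, indptr, weight=1., post_num=20))

The weight[()] read in the kernel also appears to be what the regis_op.py hunk in the last patch addresses: under jit the keyword constants arrive as 0-d arrays, and wrapping info.values() with np.asarray makes the non-jit eval_rule pass them the same way, so the kernel sees the same types in both modes.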