From d22c55f99d3d371dd9b39b636f471583740f63fd Mon Sep 17 00:00:00 2001 From: chaoming Date: Sun, 16 Oct 2022 13:32:21 +0800 Subject: [PATCH 1/4] speedup connections in `One2One`, `All2All`, `GridFour`, `GridEight`, `GridN`, `FixedProb`, `FixedPreNum`, `FixedPostNum`, --- brainpy/connect/base.py | 33 +-- brainpy/connect/custom_conn.py | 108 ++++--- brainpy/connect/random_conn.py | 238 ++++++--------- brainpy/connect/regular_conn.py | 320 +++++++++++++-------- brainpy/connect/tests/test_regular_conn.py | 118 +++++--- brainpy/connect/utils.py | 38 +++ 6 files changed, 499 insertions(+), 356 deletions(-) create mode 100644 brainpy/connect/utils.py diff --git a/brainpy/connect/base.py b/brainpy/connect/base.py index fbde670c7..7709e4d31 100644 --- a/brainpy/connect/base.py +++ b/brainpy/connect/base.py @@ -204,13 +204,15 @@ def _return_by_mat(self, structures, mat, all_data: dict): require_other_structs = len([s for s in structures if s != CONN_MAT]) > 0 if require_other_structs: - pre_ids, post_ids = onp.where(mat > 0) - pre_ids = onp.ascontiguousarray(pre_ids, dtype=IDX_DTYPE) - post_ids = onp.ascontiguousarray(post_ids, dtype=IDX_DTYPE) + np = onp if isinstance(mat, onp.ndarray) else bm + pre_ids, post_ids = np.where(mat > 0) + pre_ids = np.asarray(pre_ids, dtype=IDX_DTYPE) + post_ids = np.asarray(post_ids, dtype=IDX_DTYPE) self._return_by_ij(structures, ij=(pre_ids, post_ids), all_data=all_data) def _return_by_csr(self, structures, csr: tuple, all_data: dict): indices, indptr = csr + np = onp if isinstance(indices, onp.ndarray) else bm assert self.pre_num == indptr.size - 1 if (CONN_MAT in structures) and (CONN_MAT not in all_data): @@ -218,7 +220,7 @@ def _return_by_csr(self, structures, csr: tuple, all_data: dict): all_data[CONN_MAT] = bm.asarray(conn_mat, dtype=MAT_DTYPE) if (PRE_IDS in structures) and (PRE_IDS not in all_data): - pre_ids = onp.repeat(onp.arange(self.pre_num), onp.diff(indptr)) + pre_ids = np.repeat(np.arange(self.pre_num), np.diff(indptr)) all_data[PRE_IDS] = bm.asarray(pre_ids, dtype=IDX_DTYPE) if (POST_IDS in structures) and (POST_IDS not in all_data): @@ -234,12 +236,12 @@ def _return_by_csr(self, structures, csr: tuple, all_data: dict): bm.asarray(indptrc, dtype=IDX_DTYPE)) if (PRE2SYN in structures) and (PRE2SYN not in all_data): - syn_seq = onp.arange(indices.size, dtype=IDX_DTYPE) + syn_seq = np.arange(indices.size, dtype=IDX_DTYPE) all_data[PRE2SYN] = (bm.asarray(syn_seq, dtype=IDX_DTYPE), bm.asarray(indptr, dtype=IDX_DTYPE)) if (POST2SYN in structures) and (POST2SYN not in all_data): - syn_seq = onp.arange(indices.size, dtype=IDX_DTYPE) + syn_seq = np.arange(indices.size, dtype=IDX_DTYPE) _, indptrc, syn_seqc = csr2csc((indices, indptr), self.post_num, syn_seq) all_data[POST2SYN] = (bm.asarray(syn_seqc, dtype=IDX_DTYPE), bm.asarray(indptrc, dtype=IDX_DTYPE)) @@ -297,7 +299,7 @@ def make_returns(self, structures, conn_data, csr=None, mat=None, ij=None): # "mat" structure if mat is not None: - assert isinstance(mat, onp.ndarray) and onp.ndim(mat) == 2 + assert mat.ndim == 2 if (CONN_MAT in structures) and (CONN_MAT not in all_data): all_data[CONN_MAT] = bm.asarray(mat, dtype=MAT_DTYPE) self._return_by_mat(structures, mat=mat, all_data=all_data) @@ -316,6 +318,7 @@ def make_returns(self, structures, conn_data, csr=None, mat=None, ij=None): else: return tuple([all_data[n] for n in structures]) + @tools.not_customized def build_conn(self): """build connections with certain data type. @@ -326,7 +329,7 @@ def build_conn(self): Or a dict with three elements: csr, mat and ij. example: return dict(csr=(ind, indptr), mat=None, ij=None) """ - raise NotImplementedError + pass def require(self, *sizes_or_structures): sizes_or_structures = list(sizes_or_structures) @@ -355,8 +358,6 @@ def require(self, *sizes_or_structures): self.check(structures) if self.is_version2_style: - if (pre_size is None) or (post_size is None): - raise ConnectorError('Please provide both "pre_size" and "post_size".') if len(structures) == 1: if PRE2POST in structures: return self.build_csr(pre_size, post_size) @@ -371,10 +372,10 @@ def require(self, *sizes_or_structures): return self.build_coo(pre_size, post_size) conn_data = dict(csr=None, ij=None, mat=None) - if not hasattr(self.build_csr, 'not_customized'): - conn_data['csr'] = self.build_csr(pre_size, post_size) - elif not hasattr(self.build_coo, 'not_customized'): + if not hasattr(self.build_coo, 'not_customized'): conn_data['ij'] = self.build_coo(pre_size, post_size) + elif not hasattr(self.build_csr, 'not_customized'): + conn_data['csr'] = self.build_csr(pre_size, post_size) elif not hasattr(self.build_mat, 'not_customized'): conn_data['mat'] = self.build_mat(pre_size, post_size) else: @@ -385,15 +386,15 @@ def requires(self, *sizes_or_structures): return self.require(*sizes_or_structures) @tools.not_customized - def build_mat(self, pre_size, post_size): + def build_mat(self, pre_size=None, post_size=None): pass @tools.not_customized - def build_csr(self, pre_size, post_size): + def build_csr(self, pre_size=None, post_size=None): pass @tools.not_customized - def build_coo(self, pre_size, post_size): + def build_coo(self, pre_size=None, post_size=None): pass diff --git a/brainpy/connect/custom_conn.py b/brainpy/connect/custom_conn.py index 879cf60ea..0486c2e8e 100644 --- a/brainpy/connect/custom_conn.py +++ b/brainpy/connect/custom_conn.py @@ -3,14 +3,16 @@ import jax.numpy as jnp import numpy as np +from brainpy import math as bm from brainpy import tools from brainpy.errors import ConnectorError -from brainpy.math.jaxarray import JaxArray from .base import * +from .utils import * __all__ = [ 'MatConn', 'IJConn', + 'CSRConn', 'SparseMatConn' ] @@ -21,19 +23,23 @@ class MatConn(TwoEndConnector): def __init__(self, conn_mat): super(MatConn, self).__init__() - assert isinstance(conn_mat, (np.ndarray, JaxArray, jnp.ndarray)) and conn_mat.ndim == 2 + assert isinstance(conn_mat, (np.ndarray, bm.JaxArray, jnp.ndarray)) and conn_mat.ndim == 2 self.pre_num, self.post_num = conn_mat.shape self.pre_size, self.post_size = (self.pre_num,), (self.post_num,) - - self.conn_mat = np.asarray(conn_mat).astype(MAT_DTYPE) - + + self.conn_mat = bm.asarray(conn_mat).astype(MAT_DTYPE) + def __call__(self, pre_size, post_size): assert self.pre_num == tools.size2num(pre_size) assert self.post_num == tools.size2num(post_size) return self - def build_conn(self): - return 'mat', self.conn_mat + def build_mat(self, pre_size=None, post_size=None): + pre_num = get_pre_num(pre_size) + post_num = get_post_num(post_size) + assert self.conn_mat.shape[0] == pre_num + assert self.conn_mat.shape[1] == post_num + return self.conn_mat class IJConn(TwoEndConnector): @@ -42,37 +48,75 @@ class IJConn(TwoEndConnector): def __init__(self, i, j): super(IJConn, self).__init__() - assert isinstance(i, (np.ndarray, JaxArray, jnp.ndarray)) and i.ndim == 1 - assert isinstance(j, (np.ndarray, JaxArray, jnp.ndarray)) and j.ndim == 1 + assert isinstance(i, (np.ndarray, bm.JaxArray, jnp.ndarray)) and i.ndim == 1 + assert isinstance(j, (np.ndarray, bm.JaxArray, jnp.ndarray)) and j.ndim == 1 assert i.size == j.size # initialize the class via "pre_ids" and "post_ids" - self.pre_ids = np.asarray(i).astype(IDX_DTYPE) - self.post_ids = np.asarray(j).astype(IDX_DTYPE) + self.pre_ids = bm.asarray(i).astype(IDX_DTYPE) + self.post_ids = bm.asarray(j).astype(IDX_DTYPE) + self.max_pre = bm.max(self.pre_ids) + self.max_post = bm.max(self.post_ids) def __call__(self, pre_size, post_size): super(IJConn, self).__call__(pre_size, post_size) - - max_pre = np.max(self.pre_ids) - max_post = np.max(self.post_ids) - if max_pre >= self.pre_num: + if self.max_pre >= self.pre_num: raise ConnectorError(f'pre_num ({self.pre_num}) should be greater than ' - f'the maximum id ({max_pre}) of self.pre_ids.') - if max_post >= self.post_num: + f'the maximum id ({self.max_pre}) of self.pre_ids.') + if self.max_post >= self.post_num: + raise ConnectorError(f'post_num ({self.post_num}) should be greater than ' + f'the maximum id ({self.max_post}) of self.post_ids.') + return self + + def build_coo(self, pre_size=None, post_size=None): + pre_num = get_pre_num(pre_size) + post_num = get_post_num(post_size) + if pre_num <= self.max_pre: + raise ConnectorError(f'pre_num ({pre_num}) should be greater than ' + f'the maximum id ({self.max_pre}) of self.pre_ids.') + if post_num <= self.max_post: + raise ConnectorError(f'post_num ({post_num}) should be greater than ' + f'the maximum id ({self.max_post}) of self.post_ids.') + return self.pre_ids, self.post_ids + + +class CSRConn(TwoEndConnector): + """Connector built from the CSR sparse connection matrix.""" + + def __init__(self, indices, inptr): + super(CSRConn, self).__init__() + + self.indices = bm.asarray(indices).astype(IDX_DTYPE) + self.inptr = bm.asarray(inptr).astype(IDX_DTYPE) + self.pre_num = self.inptr.size - 1 + self.max_post = bm.max(self.indices) + + def __call__(self, pre_size, post_size): + if self.pre_num != tools.size2num(pre_size): + raise ConnectorError(f'(pre_size, post_size) is inconsistent with the shape of the sparse matrix.') + self.post_num = np.prod(post_size) + if self.post_num <= self.max_post: raise ConnectorError(f'post_num ({self.post_num}) should be greater than ' - f'the maximum id ({max_post}) of self.post_ids.') + f'the maximum id ({self.max_post}) of self.post_ids.') + assert self.post_num == tools.size2num(post_size) return self - def build_conn(self): - return 'ij', (self.pre_ids, self.post_ids) + def build_csr(self, pre_size=None, post_size=None): + pre_num = get_pre_num(pre_size) + post_num = get_post_num(post_size) + if pre_num != tools.size2num(pre_size): + raise ConnectorError(f'(pre_size, post_size) is inconsistent with ' + f'the shape of the sparse matrix.') + if post_num <= self.max_post: + raise ConnectorError(f'post_num ({post_num}) should be greater than ' + f'the maximum id ({self.max_post}) of self.post_ids.') + return self.indices, self.inptr -class SparseMatConn(TwoEndConnector): +class SparseMatConn(CSRConn): """Connector built from the sparse connection matrix""" def __init__(self, csr_mat): - super(SparseMatConn, self).__init__() - try: from scipy.sparse import csr_matrix except (ModuleNotFoundError, ImportError): @@ -80,20 +124,6 @@ def __init__(self, csr_mat): f'Please run "pip install scipy" to install scipy.') assert isinstance(csr_mat, csr_matrix) - csr_mat.data = np.asarray(csr_mat.data).astype(MAT_DTYPE) self.csr_mat = csr_mat - self.pre_num, self.post_num = csr_mat.shape - - def __call__(self, pre_size, post_size): - try: - assert self.pre_num == tools.size2num(pre_size) - assert self.post_num == tools.size2num(post_size) - except AssertionError: - raise ConnectorError(f'(pre_size, post_size) is inconsistent with the shape of the sparse matrix.') - - super(SparseMatConn, self).__call__(pre_size, post_size) - return self - - def build_conn(self): - ind, indptr = self.csr_mat.indices, self.csr_mat.indptr - return 'csr', (ind, indptr) + super(SparseMatConn, self).__init__(indices=bm.asarray(self.csr_mat.indices, dtype=IDX_DTYPE), + inptr=bm.asarray(self.csr_mat.indptr, dtype=IDX_DTYPE)) diff --git a/brainpy/connect/random_conn.py b/brainpy/connect/random_conn.py index 25841f3cb..d9ddae25f 100644 --- a/brainpy/connect/random_conn.py +++ b/brainpy/connect/random_conn.py @@ -1,11 +1,14 @@ # -*- coding: utf-8 -*- + import jax import numpy as np +from typing import Optional from brainpy import math as bm from brainpy.errors import ConnectorError from brainpy.tools.others import numba_seed, numba_jit, SUPPORT_NUMBA, format_seed from .base import * +from .utils import * __all__ = [ 'FixedProb', @@ -24,9 +27,11 @@ class FixedProb(TwoEndConnector): """Connect the post-synaptic neurons with fixed probability. + .. versionchanged:: 2.2.3.2 + Parameters ---------- - prob : float + prob: float The conn probability. pre_ratio: float The ratio of pre-synaptic neurons to connect. @@ -39,6 +44,7 @@ class FixedProb(TwoEndConnector): def __init__(self, prob, pre_ratio=1., include_self=True, seed=None): super(FixedProb, self).__init__() assert 0. <= prob <= 1. + assert 0. <= pre_ratio <= 1. self.prob = prob self.pre_ratio = pre_ratio self.include_self = include_self @@ -49,56 +55,28 @@ def __repr__(self): return (f'{self.__class__.__name__}(prob={self.prob}, pre_ratio={self.pre_ratio}, ' f'include_self={self.include_self}, seed={self.seed})') - def build_conn(self): - if SUPPORT_NUMBA: - numba_seed(self.seed) - rng = np.random - else: - rng = np.random.RandomState(self.seed) - - include_self = self.include_self - pre_ratio = self.pre_ratio - prob = self.prob - - @numba_jit - def f_connect(pre_i, num_post): - if rng.random() < pre_ratio: - p = rng.random(num_post) <= prob - if (not include_self) and pre_i < num_post: - p[pre_i] = False - return np.where(p)[0] - - # make connections - ind = [] - count = np.zeros(self.pre_num, dtype=IDX_DTYPE) - for i in range(self.pre_num): - posts = f_connect(pre_i=i, num_post=self.post_num) - if posts is not None: - ind.append(posts) - count[i] = len(posts) - ind = np.concatenate(ind) if len(ind) > 0 else np.asarray([], dtype=IDX_DTYPE) - indptr = np.concatenate(([0], count)).cumsum() - - return 'csr', (ind, indptr) - - def build_mat(self, pre_size, post_size): - pre_num = np.prod(pre_size) - post_num = np.prod(post_size) + def build_mat(self, pre_size=None, post_size=None): + pre_num = get_pre_num(self, pre_size) + post_num = get_post_num(self, post_size) pre_state = self.rng.rand(pre_num, 1) < self.pre_ratio mat = (self.rng.rand(pre_num, post_num) < self.prob) * pre_state if not self.include_self: bm.fill_diagonal(mat, False) return mat.astype(MAT_DTYPE) - def build_coo(self, pre_size, post_size): - pre_num = np.prod(pre_size) - post_num = np.prod(post_size) + def _iii(self, pre_size, post_size): + pre_num = get_pre_num(self, pre_size) + post_num = get_post_num(self, post_size) + if (not self.include_self) and (pre_num != post_num): + raise ConnectorError(f'We found pre_num != post_num ({pre_num} != {post_num}). ' + f'But `include_self` is set to True.') post_num_to_select = int(post_num * self.prob) post_ids = bm.arange(post_num) if self.pre_ratio < 1.: pre_num_to_select = int(pre_num * self.pre_ratio) pre_ids = self.rng.choice(pre_num, size=pre_num_to_select, replace=False) else: + pre_num_to_select = pre_num pre_ids = bm.arange(pre_num) @jax.vmap @@ -106,46 +84,21 @@ def f(i, key): posts = bm.delete(post_ids, i) if not self.include_self else post_ids return self.rng.permutation(posts, key=key)[:post_num_to_select] - selected_pre_ids = bm.repeat(pre_ids, post_num_to_select) selected_post_ids = f(pre_ids, self.rng.split_keys(pre_ids.size)).flatten() - return selected_pre_ids.astype(IDX_DTYPE), selected_post_ids.astype(IDX_DTYPE) - - def build_csr(self, pre_size, post_size): - pre_num = np.prod(pre_size) - post_num = np.prod(post_size) - post_num_to_select = int(post_num * self.prob) - post_ids = bm.arange(post_num) - if self.pre_ratio < 1.: - pre_num_to_select = int(pre_num * self.pre_ratio) - pre_ids = self.rng.choice(pre_num, size=pre_num_to_select, replace=False) - else: - pre_num_to_select = pre_num - pre_ids = bm.arange(pre_num) + return pre_num_to_select, post_num_to_select, selected_post_ids, pre_ids - @jax.vmap - def f(i, key): - posts = bm.delete(post_ids, i) if not self.include_self else post_ids - return self.rng.permutation(posts, key=key)[:post_num_to_select] + def build_coo(self, pre_size=None, post_size=None): + _, post_num_to_select, selected_post_ids, pre_ids = self._iii(pre_size, post_size) + selected_pre_ids = bm.repeat(pre_ids, post_num_to_select) + return selected_pre_ids.astype(IDX_DTYPE), selected_post_ids.astype(IDX_DTYPE) - selected_post_ids = f(pre_ids, self.rng.split_keys(pre_ids.size)).flatten() + def build_csr(self, pre_size=None, post_size=None): + pre_num_to_select, post_num_to_select, selected_post_ids, _ = self._iii(pre_size, post_size) selected_pre_inptr = bm.cumsum(bm.concatenate([bm.zeros(1), bm.ones(pre_num_to_select) * post_num_to_select])) return selected_post_ids.astype(IDX_DTYPE), selected_pre_inptr.astype(IDX_DTYPE) class FixedNum(TwoEndConnector): - """Connect with fixed number for each pre- or post-synaptic neuron. - - Parameters - ---------- - num : float, int - The conn probability (if "num" is float) or the fixed number of - connectivity (if "num" is int). - include_self : bool - Whether create (i, i) conn ? - seed : None, int - Seed the random generator. - """ - def __init__(self, num, include_self=True, seed=None): super(FixedNum, self).__init__() if isinstance(num, int): @@ -157,68 +110,51 @@ def __init__(self, num, include_self=True, seed=None): self.num = num self.seed = format_seed(seed) self.include_self = include_self - self.rng = np.random.RandomState(seed=self.seed) - rng = np.random if SUPPORT_NUMBA else self.rng - - def _fixed_num_prob(num_need, num_total, i=0): - prob = rng.random(num_total) - if not include_self and i <= num_total: - prob[i] = 1. - neu_idx = np.argsort(prob)[:num_need] - return np.asarray(neu_idx, dtype=IDX_DTYPE) - - self._connect = numba_jit(_fixed_num_prob) + self.rng = bm.random.RandomState(seed=self.seed) def __repr__(self): - return (f'{self.__class__.__name__}(num={self.num}, ' - f'include_self={self.include_self}, ' - f'seed={self.seed})') + return f'{self.__class__.__name__}(num={self.num}, include_self={self.include_self}, seed={self.seed})' class FixedPreNum(FixedNum): - """Connect the pre-synaptic neurons with fixed number for each post-synaptic neuron. + """Connect a fixed number pf pre-synaptic neurons for each post-synaptic neuron. Parameters ---------- num : float, int - The connection probability (if "num" is float) or the fixed number of - connectivity (if "num" is int). + The conn probability (if "num" is float) or the fixed number of + connectivity (if "num" is int). include_self : bool - Whether create (i, i) conn ? + Whether create (i, i) conn ? + seed : None, int + Seed the random generator. """ - def __repr__(self): - return (f'{self.__class__.__name__}(num={self.num}, ' - f'include_self={self.include_self}, ' - f'seed={self.seed})') - - def build_conn(self): - # check - if isinstance(self.num, int): - assert 0 <= self.num <= self.pre_num, f'"num" must be smaller than "self.pre_num", ' \ - f'but got {self.num} > {self.pre_num}' - num = self.num - else: - assert 0. <= self.num <= 1., f'"num" must be in [0., 1.), but got {self.num}' - num = int(self.pre_num * self.num) - - # seed - self.seed = self.rng.randint(1, int(1e7)) - numba_seed(self.seed) + def build_coo(self, pre_size=None, post_size=None): + pre_num = get_pre_num(self, pre_size) + post_num = get_post_num(self, post_size) + if isinstance(self.num, int) and self.num > pre_num: + raise ConnectorError(f'"num" must be smaller than "pre_num", ' + f'but got {self.num} > {pre_num}') + if (not self.include_self) and (pre_num != post_num): + raise ConnectorError(f'We found pre_num != post_num ({pre_num} != {post_num}). ' + f'But `include_self` is set to True.') + pre_num_to_select = int(pre_num * self.num) if isinstance(self.num, float) else self.num + pre_ids = bm.arange(pre_num) + post_ids = bm.arange(post_num) - # make connections - pre_ids = [] - for i in range(self.post_num): - pres = self._connect(num_need=num, num_total=self.pre_num, i=i) - pre_ids.append(pres) - pre_ids = np.concatenate(pre_ids) if len(pre_ids) > 0 else np.asarray([], dtype=IDX_DTYPE) - post_ids = np.repeat(np.arange(self.post_num), num) + @jax.vmap + def f(post_i, key): + pres = bm.delete(pre_ids, post_i) if not self.include_self else pre_ids + return self.rng.permutation(pres, key=key)[:pre_num_to_select] - return 'ij', (pre_ids, post_ids) + selected_pre_ids = f(post_ids, self.rng.split_keys(pre_num)).flatten() + selected_post_ids = bm.repeat(post_ids, pre_num_to_select) + return selected_pre_ids.astype(IDX_DTYPE), selected_post_ids.astype(IDX_DTYPE) class FixedPostNum(FixedNum): - """Connect the post-synaptic neurons with fixed number for each pre-synaptic neuron. + """Connect the fixed number of post-synaptic neurons for each pre-synaptic neuron. Parameters ---------- @@ -231,35 +167,36 @@ class FixedPostNum(FixedNum): Seed the random generator. """ - def __repr__(self): - return (f'{self.__class__.__name__}(num={self.num}, ' - f'include_self={self.include_self}, ' - f'seed={self.seed})') + def _ii(self, pre_size, post_size): + pre_num = get_pre_num(self, pre_size) + post_num = get_post_num(self, post_size) + if isinstance(self.num, int) and self.num > post_num: + raise ConnectorError(f'"num" must be smaller than "post_num", ' + f'but got {self.num} > {post_num}') + if (not self.include_self) and (pre_num != post_num): + raise ConnectorError(f'We found pre_num != post_num ({pre_num} != {post_num}). ' + f'But `include_self` is set to True.') + post_num_to_select = int(post_num * self.num) if isinstance(self.num, float) else self.num + pre_ids = bm.arange(pre_num) + post_ids = bm.arange(post_num) - def build_conn(self): - # check - if isinstance(self.num, int): - assert 0 <= self.num <= self.post_num, f'"num" must be smaller than "self.post_num", ' \ - f'but got {self.num} > {self.post_num}' - num = self.num - else: - assert 0. <= self.num <= 1., f'"num" must be in [0., 1.), but got {self.num}' - num = int(self.post_num * self.num) + @jax.vmap + def f(pre_i, key): + posts = bm.delete(post_ids, pre_i) if not self.include_self else post_ids + return self.rng.permutation(posts, key=key)[:post_num_to_select] - # seed - self.seed = self.rng.randint(1, int(1e7)) - numba_seed(self.seed) + selected_post_ids = f(pre_ids, self.rng.split_keys(pre_num)).flatten() + return pre_num, post_num_to_select, selected_post_ids - # make connections - post_ids = [] # i.e. post_ids - for i in range(self.pre_num): - posts = self._connect(num_need=num, num_total=self.post_num, i=i) - post_ids.append(posts) - post_ids = np.concatenate(post_ids) - count = np.ones(self.pre_num, dtype=IDX_DTYPE) * num - indptr = np.concatenate(([0], count)).cumsum() + def build_csr(self, pre_size=None, post_size=None): + pre_num, post_num, selected_post_ids = self._ii(pre_size, post_size) + selected_pre_inptr = bm.cumsum(bm.concatenate([bm.zeros(1), bm.ones(pre_num) * post_num])) + return selected_post_ids.astype(IDX_DTYPE), selected_pre_inptr.astype(IDX_DTYPE) - return 'csr', (post_ids, indptr) + def build_coo(self, pre_size=None, post_size=None): + pre_num, post_num, selected_post_ids = self._ii(pre_size, post_size) + selected_pre_ids = bm.repeat(bm.arange(pre_num), post_num) + return selected_pre_ids.astype(IDX_DTYPE), selected_post_ids.astype(IDX_DTYPE) class GaussianProb(OneEndConnector): @@ -295,7 +232,7 @@ class GaussianProb(OneEndConnector): normalize : bool Whether normalize the connection probability . include_self : bool - Whether create the conn at the same position. + Whether create the connection at the same position. seed : int The random seed. """ @@ -303,7 +240,7 @@ class GaussianProb(OneEndConnector): def __init__( self, sigma: float, - encoding_values=None, + encoding_values: Optional[np.ndarray] = None, normalize: bool = True, include_self: bool = True, periodic_boundary: bool = False, @@ -316,7 +253,7 @@ def __init__( self.include_self = include_self self.periodic_boundary = periodic_boundary self.seed = format_seed(seed) - self.rng = np.random.RandomState(self.seed) + self.rng = bm.random.RandomState(self.seed) def __repr__(self): return (f'{self.__class__.__name__}(sigma={self.sigma}, ' @@ -325,7 +262,7 @@ def __repr__(self): f'include_self={self.include_self}, ' f'seed={self.seed})') - def build_conn(self): + def build_mat(self, pre_size=None, post_size=None): # value range to encode if self.encoding_values is None: value_ranges = tuple([(0, s) for s in self.pre_size]) @@ -352,6 +289,7 @@ def build_conn(self): # values values = [np.linspace(vs[0], vs[1], n + 1)[:n] for vs, n in zip(value_ranges, self.pre_size)] + # post_values = np.stack([v.flatten() for v in np.meshgrid(*values, indexing='ij')]) post_values = np.stack([v.flatten() for v in np.meshgrid(*values)]) value_sizes = np.array([v[1] - v[0] for v in value_ranges]) if value_sizes.ndim < post_values.ndim: @@ -380,12 +318,10 @@ def build_conn(self): prob_mat /= prob_mat.max() # connectivity - conn_mat = prob_mat >= self.rng.random(prob_mat.shape) - + conn_mat = bm.asarray(prob_mat) >= self.rng.random(prob_mat.shape) if not self.include_self: - np.fill_diagonal(conn_mat, False) - - return 'mat', conn_mat + bm.fill_diagonal(conn_mat, False) + return conn_mat class SmallWorld(TwoEndConnector): @@ -985,7 +921,7 @@ def build_conn(self): post_size = np.asarray(self.post_size) connected_pres = [] connected_posts = [] - pre_ids = np.meshgrid(*(np.arange(p) for p in self.pre_size)) + pre_ids = np.meshgrid(*(np.arange(p) for p in self.pre_size), indexing='ij') pre_ids = tuple([(np.moveaxis(p, 0, 1).flatten()) if p.ndim > 1 else p.flatten() for p in pre_ids]) size = np.prod(pre_size) for i in range(size): diff --git a/brainpy/connect/regular_conn.py b/brainpy/connect/regular_conn.py index 0378d2d23..0d3286bb9 100644 --- a/brainpy/connect/regular_conn.py +++ b/brainpy/connect/regular_conn.py @@ -1,15 +1,14 @@ # -*- coding: utf-8 -*- +from typing import Union, Tuple, List -import logging - +import jax import numpy as np +from brainpy import math as bm from brainpy.errors import ConnectorError -from brainpy.tools.others import numba_jit - +from brainpy.tools import size2num from .base import * - -logger = logging.getLogger('brainpy.building.connect') +from .utils import * __all__ = [ 'One2One', 'one2one', @@ -37,11 +36,31 @@ def __call__(self, pre_size, post_size): f'same size, but {self.pre_num} != {self.post_num}.') return self - def build_conn(self): - ind = np.arange(self.pre_num) - indptr = np.arange(self.pre_num + 1) + def build_coo(self, pre_size=None, post_size=None): + pre_num = get_pre_num(self, pre_size) + post_num = get_post_num(self, post_size) + if pre_num != post_num: + raise ConnectorError(f'One2One connection must be defined in two groups with the ' + f'same size, but {pre_num} != {post_num}.') + return bm.arange(pre_num, dtype=IDX_DTYPE), bm.arange(post_num, dtype=IDX_DTYPE), - return dict(csr=(ind, indptr), mat=None, ij=None) + def build_csr(self, pre_size=None, post_size=None): + pre_num = get_pre_num(self, pre_size) + post_num = get_post_num(self, post_size) + if pre_num != post_num: + raise ConnectorError(f'One2One connection must be defined in two groups with the ' + f'same size, but {pre_num} != {post_num}.') + ind = bm.arange(pre_num) + indptr = np.arange(pre_num + 1) + return bm.asarray(ind, dtype=IDX_DTYPE), bm.arange(indptr, dtype=IDX_DTYPE), + + def build_mat(self, pre_size=None, post_size=None): + pre_num = get_pre_num(self, pre_size) + post_num = get_post_num(self, post_size) + if pre_num != post_num: + raise ConnectorError(f'One2One connection must be defined in two groups with the ' + f'same size, but {pre_num} != {post_num}.') + return bm.fill_diagonal(bm.zeros((pre_num, post_num), dtype=MAT_DTYPE), True) one2one = One2One() @@ -58,100 +77,163 @@ def __init__(self, include_self=True): super(All2All, self).__init__() def __repr__(self): - return (f'{self.__class__.__name__}(include_self={self.include_self})') + return f'{self.__class__.__name__}(include_self={self.include_self})' - def build_conn(self): - mat = np.ones((self.pre_num, self.post_num), dtype=MAT_DTYPE) + def build_mat(self, pre_size=None, post_size=None): + pre_num = get_pre_num(self, pre_size) + post_num = get_post_num(self, post_size) + mat = bm.ones((pre_num, post_num), dtype=MAT_DTYPE) if not self.include_self: - np.fill_diagonal(mat, False) - - return dict(csr=None, mat=mat, ij=None) + bm.fill_diagonal(mat, False) + return mat all2all = All2All(include_self=True) -@numba_jit -def _grid_four(height, width, row, include_self): - conn_i = [] - conn_j = [] - - for col in range(width): - i_index = (row * width) + col - if 0 <= row - 1 < height: - j_index = ((row - 1) * width) + col - conn_i.append(i_index) - conn_j.append(j_index) - if 0 <= row + 1 < height: - j_index = ((row + 1) * width) + col - conn_i.append(i_index) - conn_j.append(j_index) - if 0 <= col - 1 < width: - j_index = (row * width) + col - 1 - conn_i.append(i_index) - conn_j.append(j_index) - if 0 <= col + 1 < width: - j_index = (row * width) + col + 1 - conn_i.append(i_index) - conn_j.append(j_index) - if include_self: - conn_i.append(i_index) - conn_j.append(i_index) - return conn_i, conn_j - - -class GridFour(OneEndConnector): - """The nearest four neighbors conn method.""" - - def __init__(self, include_self=False): - super(GridFour, self).__init__() +def get_size_length(sizes: Union[Tuple, List]): + if not isinstance(sizes, (tuple, list)): + raise TypeError + lengths = [] + a = 1 + for s in reversed(sizes): + lengths.insert(0, a) + a *= s + return np.asarray(lengths) + + +class GridConn(OneEndConnector): + def __init__( + self, + strides, + include_self: bool = False, + periodic_boundary: bool = False, + ): + super(GridConn, self).__init__() + self.strides = strides self.include_self = include_self + self.periodic_boundary = periodic_boundary def __repr__(self): - return (f'{self.__class__.__name__}(include_self={self.include_self})') - - def build_conn(self): - # only the 1- or 2-D structure is supported - if len(self.pre_size) == 1: - height, width = self.pre_size[0], 1 - elif len(self.pre_size) == 2: - height, width = self.pre_size + return f'{self.__class__.__name__}(include_self={self.include_self}, periodic_boundary={self.periodic_boundary})' + + def _format(self, pre_size, post_size): + pre_size = get_pre_size(self, pre_size) + post_size = get_post_size(self, post_size) + pre_num = size2num(pre_size) + post_num = size2num(post_size) + dim = len(post_size) + if pre_num != post_num: + raise ConnectorError(f'{self.__class__.__name__} is used to for connection within ' + f'a same population. But we detect pre_num != post_num ' + f'({pre_num} != {post_num}).') + # point indices + indices = bm.meshgrid(*(bm.arange(size) for size in post_size), indexing='ij') + indices = bm.asarray(indices) + indices = indices.reshape(dim, post_num).T + lengths = bm.asarray(post_size) + return lengths, post_size, dim, indices + + def _get_strides(self, dim): + # increments + increments = np.asarray(np.meshgrid(*(self.strides for _ in range(dim)))).reshape(dim, -1).T + select_ids = self._select_stride(increments) + increments = bm.asarray(increments[select_ids]) + return increments + + def _select_stride(self, stride: np.ndarray) -> np.ndarray: + raise NotImplementedError + + def _select_dist(self, dist: bm.ndarray) -> bm.ndarray: + raise NotImplementedError + + def build_mat(self, pre_size=None, post_size=None): + sizes, post_size, _, indices = self._format(pre_size, post_size) + + @jax.vmap + def f_connect(pre_id): + # pre_id: R^(num_dim) + dist = bm.abs(pre_id - indices) + if self.periodic_boundary: + dist = bm.where(dist > sizes / 2, sizes - dist, dist) + return self._select_dist(dist) + + return bm.asarray(f_connect(indices), dtype=MAT_DTYPE) + + def build_coo(self, pre_size=None, post_size=None): + sizes, post_size, dim, indices = self._format(pre_size, post_size) + strides = self._get_strides(dim) + + @jax.vmap + def f_connect(pre_id): + # pre_id: R^(num_dim) + post_ids = pre_id + strides + if self.periodic_boundary: + post_ids = post_ids % sizes + else: + post_ids = bm.where(post_ids < sizes, post_ids, -1) + size = len(post_ids) + pre_ids = bm.repeat(pre_id, size).reshape(dim, size).T + return pre_ids, post_ids + + pres, posts = f_connect(indices) + pres = pres.reshape(-1, dim) + posts = posts.reshape(-1, dim) + idx = bm.nonzero(bm.all(posts >= 0, axis=1))[0] + pres = pres[idx] + posts = posts[idx] + if dim == 1: + pres = pres.flatten() + posts = posts.flatten() else: - raise ConnectorError(f'Currently, GridFour only supports the two-dimensional geometry.') + strides = bm.asarray(get_size_length(post_size)) + pres = bm.sum(pres * strides, axis=1) + posts = bm.sum(posts * strides, axis=1) + return bm.asarray(pres, dtype=IDX_DTYPE), bm.asarray(posts, dtype=IDX_DTYPE) + + +class GridFour(GridConn): + """The nearest four neighbors connection method. + + Parameters + ---------- + periodic_boundary : bool + Whether the neuron encode the value space with the periodic boundary. + .. versionadded:: 2.2.3.2 + + include_self : bool + Whether create connection at the same position. + """ + + def __init__( + self, + include_self: bool = False, + periodic_boundary: bool = False + ): + super(GridFour, self).__init__(strides=np.asarray([-1, 0, 1]), + include_self=include_self, + periodic_boundary=periodic_boundary) + self.include_self = include_self + self.periodic_boundary = periodic_boundary - conn_i = [] - conn_j = [] - for row in range(height): - a = _grid_four(height, width, row, include_self=self.include_self) - conn_i.extend(a[0]) - conn_j.extend(a[1]) - pre_ids = np.asarray(conn_i, dtype=IDX_DTYPE) - post_ids = np.asarray(conn_j, dtype=IDX_DTYPE) + def _select_stride(self, stride: np.ndarray) -> np.ndarray: + temp = abs(stride).sum(axis=1) + return (temp <= 1) if self.include_self else (temp == 1) - return 'ij', (pre_ids, post_ids) + def _select_dist(self, dist: bm.ndarray) -> bm.ndarray: + dist = bm.linalg.norm(dist, axis=1) + return dist <= 1 if self.include_self else dist == 1 + # dist = bm.abs(dist) + # if self.include_self: + # return bm.prod(dist <= 1, axis=1) + # else: + # return bm.prod(dist == 1, axis=1) grid_four = GridFour() -@numba_jit -def _grid_n(height, width, row, n, include_self): - conn_i = [] - conn_j = [] - for col in range(width): - i_index = (row * width) + col - for row_diff in range(-n, n + 1): - for col_diff in range(-n, n + 1): - if (not include_self) and (row_diff == col_diff == 0): - continue - if 0 <= row + row_diff < height and 0 <= col + col_diff < width: - j_index = ((row + row_diff) * width) + col + col_diff - conn_i.append(i_index) - conn_j.append(j_index) - return conn_i, conn_j - - -class GridN(OneEndConnector): +class GridN(GridConn): """The nearest (2*N+1) * (2*N+1) neighbors conn method. Parameters @@ -169,43 +251,55 @@ class GridN(OneEndConnector): [x x x x x] [x x x x x] include_self : bool - Whether create (i, i) conn ? + Whether create (i, i) conn ? + periodic_boundary: bool + Whether the neuron encode the value space with the periodic boundary. + .. versionadded:: 2.2.3.2 """ - def __init__(self, N=1, include_self=False): - super(GridN, self).__init__() + def __init__( + self, + N: int = 1, + include_self: bool = False, + periodic_boundary: bool = False + ): + super(GridN, self).__init__(strides=np.arange(-N, N + 1, 1), + include_self=include_self, + periodic_boundary=periodic_boundary) self.N = N - self.include_self = include_self def __repr__(self): - return (f'{self.__class__.__name__}(N={self.N}, include_self={self.include_self})') - - def build_conn(self): - if len(self.pre_size) == 1: - height, width = self.pre_size[0], 1 - elif len(self.pre_size) == 2: - height, width = self.pre_size + return (f'{self.__class__.__name__}(N={self.N}, ' + f'include_self={self.include_self}, ' + f'periodic_boundary={self.periodic_boundary})') + + def _select_stride(self, stride: np.ndarray) -> np.ndarray: + return (np.ones(len(stride), dtype=bool) + if self.include_self else + (np.sum(np.abs(stride), axis=1) > 0)) + + def _select_dist(self, dist: bm.ndarray) -> bm.ndarray: + if self.include_self: + return bm.all(dist <= self.N, axis=1) else: - raise ConnectorError(f'Currently, GridN only supports the two-dimensional geometry.') - - conn_i = [] - conn_j = [] - for row in range(height): - res = _grid_n(height=height, width=width, row=row, - n=self.N, include_self=self.include_self) - conn_i.extend(res[0]) - conn_j.extend(res[1]) - pre_ids = np.asarray(conn_i, dtype=IDX_DTYPE) - post_ids = np.asarray(conn_j, dtype=IDX_DTYPE) - - return 'ij', (pre_ids, post_ids) + return bm.logical_and(bm.all(dist <= self.N, axis=1), + bm.logical_not(bm.all(dist == 0, axis=1))) class GridEight(GridN): - """The nearest eight neighbors conn method.""" + """The nearest eight neighbors conn method. + + Parameters + ---------- + include_self : bool + Whether create (i, i) conn ? + periodic_boundary: bool + Whether the neurons encode the value space with the periodic boundary. + .. versionadded:: 2.2.3.2 + """ - def __init__(self, include_self=False): - super(GridEight, self).__init__(N=1, include_self=include_self) + def __init__(self, include_self=False, periodic_boundary: bool = False): + super(GridEight, self).__init__(N=1, include_self=include_self, periodic_boundary=periodic_boundary) grid_eight = GridEight() diff --git a/brainpy/connect/tests/test_regular_conn.py b/brainpy/connect/tests/test_regular_conn.py index f6d9e79a7..b2a4fd41e 100644 --- a/brainpy/connect/tests/test_regular_conn.py +++ b/brainpy/connect/tests/test_regular_conn.py @@ -1,52 +1,29 @@ # -*- coding: utf-8 -*- +import numpy as np import brainpy as bp from brainpy import connect +import unittest -def test_one2one(): - for size in [100, (3, 4), (4, 5, 6)]: - conn = connect.One2One()(pre_size=size, post_size=size) - conn_mat, pre_ids, post_ids, pre2post, pre2syn, post2pre, post2syn = \ - conn.require('conn_mat', 'pre_ids', 'post_ids', 'pre2post', 'pre2syn', 'post2pre', 'post2syn') - - num = bp.tools.size2num(size) - - actual_mat = bp.math.zeros((num, num), dtype=bp.math.bool_) - bp.math.fill_diagonal(actual_mat, True) - - assert bp.math.array_equal(actual_mat, conn_mat) - assert bp.math.array_equal(pre_ids, bp.math.arange(num)) - assert bp.math.array_equal(post_ids, bp.math.arange(num)) - - print() - print('conn_mat', conn_mat) - print('pre_ids', pre_ids) - print('post_ids', post_ids) - print('pre2post', pre2post) - print('post2pre', post2pre) - print('pre2syn', pre2syn) - print('post2syn', post2syn) - - -def test_all2all(): - for has_self in [True, False]: +class TestOne2One(unittest.TestCase): + def test_one2one(self): for size in [100, (3, 4), (4, 5, 6)]: - conn = connect.All2All(include_self=has_self)(pre_size=size, post_size=size) - mat = conn.require(connect.CONN_MAT) + conn = connect.One2One()(pre_size=size, post_size=size) + conn_mat, pre_ids, post_ids, pre2post, pre2syn, post2pre, post2syn = \ conn.require('conn_mat', 'pre_ids', 'post_ids', 'pre2post', 'pre2syn', 'post2pre', 'post2syn') + num = bp.tools.size2num(size) - print(mat) - actual_mat = bp.math.ones((num, num), dtype=bp.math.bool_) - if not has_self: - bp.math.fill_diagonal(actual_mat, False) + actual_mat = bp.math.zeros((num, num), dtype=bp.math.bool_) + bp.math.fill_diagonal(actual_mat, True) - assert bp.math.array_equal(actual_mat, mat) + assert bp.math.array_equal(actual_mat, conn_mat) + assert bp.math.array_equal(pre_ids, bp.math.arange(num)) + assert bp.math.array_equal(post_ids, bp.math.arange(num)) - print() print('conn_mat', conn_mat) print('pre_ids', pre_ids) print('post_ids', post_ids) @@ -56,5 +33,72 @@ def test_all2all(): print('post2syn', post2syn) -def test_grid_four(): - pass +class TestAll2All(unittest.TestCase): + def test_all2all(self): + for has_self in [True, False]: + for size in [100, (3, 4), (4, 5, 6)]: + conn = connect.All2All(include_self=has_self)(pre_size=size, post_size=size) + mat = conn.require(connect.CONN_MAT) + conn_mat, pre_ids, post_ids, pre2post, pre2syn, post2pre, post2syn = \ + conn.require('conn_mat', 'pre_ids', 'post_ids', 'pre2post', 'pre2syn', 'post2pre', 'post2syn') + num = bp.tools.size2num(size) + + print(mat) + actual_mat = bp.math.ones((num, num), dtype=bp.math.bool_) + if not has_self: + bp.math.fill_diagonal(actual_mat, False) + assert bp.math.array_equal(actual_mat, mat) + + print() + print('conn_mat', conn_mat) + print('pre_ids', pre_ids) + print('post_ids', post_ids) + print('pre2post', pre2post) + print('post2pre', post2pre) + print('pre2syn', pre2syn) + print('post2syn', post2syn) + + +class TestGridConn(unittest.TestCase): + def test_grid_four(self): + for periodic_boundary in [True, False]: + for include_self in [True, False]: + for size in (10, [10, 10], (4, 4, 5)): + conn = bp.conn.GridFour(include_self=include_self, + periodic_boundary=periodic_boundary) + mat = conn.build_mat(size, size) + pre_ids, post_ids = conn.build_coo(size, size) + new_mat = bp.math.zeros((np.prod(size), np.prod(size)), dtype=bool) + new_mat[pre_ids, post_ids] = True + + print(f'periodic_boundary = {periodic_boundary}, include_self = {include_self}, size = {size}') + self.assertTrue(bp.math.allclose(mat, new_mat)) + + def test_grid_eight(self): + for periodic_boundary in [True, False]: + for include_self in [True, False]: + for size in (10, [10, 10], (4, 4, 5)): + conn = bp.conn.GridEight(include_self=include_self, + periodic_boundary=periodic_boundary) + mat = conn.build_mat(size, size) + pre_ids, post_ids = conn.build_coo(size, size) + new_mat = bp.math.zeros((np.prod(size), np.prod(size)), dtype=bool) + new_mat[pre_ids, post_ids] = True + + print(f'periodic_boundary = {periodic_boundary}, include_self = {include_self}, size = {size}') + self.assertTrue(bp.math.allclose(mat, new_mat)) + + def test_grid_N(self): + for periodic_boundary in [True, False]: + for include_self in [True, False]: + for size in (10, [10, 10], (4, 4, 5)): + conn = bp.conn.GridN(include_self=include_self, + periodic_boundary=periodic_boundary, + N=2) + mat = conn.build_mat(size, size) + pre_ids, post_ids = conn.build_coo(size, size) + new_mat = bp.math.zeros((np.prod(size), np.prod(size)), dtype=bool) + new_mat[pre_ids, post_ids] = True + + print(f'periodic_boundary = {periodic_boundary}, include_self = {include_self}, size = {size}') + self.assertTrue(bp.math.allclose(mat, new_mat)) diff --git a/brainpy/connect/utils.py b/brainpy/connect/utils.py new file mode 100644 index 000000000..ca7a8e8e9 --- /dev/null +++ b/brainpy/connect/utils.py @@ -0,0 +1,38 @@ +# -*- coding: utf-8 -*- + + +from brainpy.errors import ConnectorError +from brainpy.tools import size2num + +__all__ = [ + 'get_pre_num', 'get_post_num', + 'get_pre_size', 'get_post_size', +] + + +def get_pre_size(obj, pre_size=None): + if pre_size is None: + if obj.pre_size is None: + raise ConnectorError('Please provide "pre_size" and "post_size"') + else: + return obj.pre_size + else: + return (pre_size, ) if isinstance(pre_size, int) else pre_size + + +def get_pre_num(obj, pre_size=None): + return size2num(get_pre_size(obj, pre_size)) + + +def get_post_size(obj, post_size=None): + if post_size is None: + if obj.post_size is None: + raise ConnectorError('Please provide "pre_size" and "post_size"') + else: + return obj.post_size + else: + return (post_size,) if isinstance(post_size, int) else post_size + + +def get_post_num(obj, post_size=None): + return size2num(get_post_size(obj, post_size)) From d2adfc8b7ccade1c941ac2ec843b753ce2402a84 Mon Sep 17 00:00:00 2001 From: chaoming Date: Sun, 16 Oct 2022 14:03:38 +0800 Subject: [PATCH 2/4] fix connection bugs --- brainpy/connect/base.py | 15 +++++++------ brainpy/connect/custom_conn.py | 26 ++++++++--------------- brainpy/connect/random_conn.py | 2 +- brainpy/connect/tests/test_random_conn.py | 4 ++-- 4 files changed, 21 insertions(+), 26 deletions(-) diff --git a/brainpy/connect/base.py b/brainpy/connect/base.py index 7709e4d31..904e96ff2 100644 --- a/brainpy/connect/base.py +++ b/brainpy/connect/base.py @@ -3,6 +3,7 @@ import abc from typing import Union, List, Tuple, Any +import jax.numpy as jnp import numpy as onp from brainpy import tools, math as bm @@ -359,16 +360,16 @@ def require(self, *sizes_or_structures): self.check(structures) if self.is_version2_style: if len(structures) == 1: - if PRE2POST in structures: + if PRE2POST in structures and not hasattr(self.build_csr, 'not_customized'): return self.build_csr(pre_size, post_size) - elif CONN_MAT in structures: + elif CONN_MAT in structures and not hasattr(self.build_mat, 'not_customized'): return self.build_mat(pre_size, post_size) - elif PRE_IDS in structures: + elif PRE_IDS in structures and not hasattr(self.build_coo, 'not_customized'): return self.build_coo(pre_size, post_size)[0] - elif POST_IDS in structures: + elif POST_IDS in structures and not hasattr(self.build_coo, 'not_customized'): return self.build_coo(pre_size, post_size)[1] elif len(structures) == 2: - if PRE_IDS in structures and POST_IDS in structures: + if PRE_IDS in structures and POST_IDS in structures and not hasattr(self.build_coo, 'not_customized'): return self.build_coo(pre_size, post_size) conn_data = dict(csr=None, ij=None, mat=None) @@ -443,6 +444,8 @@ def csr2csc(csr, post_num, data=None): pre_ids = np.repeat(np.arange(indptr.size - 1), np.diff(indptr)) sort_ids = np.argsort(indices, kind=kind) # to maintain the original order of the elements with the same value + if isinstance(sort_ids, bm.JaxArray): + sort_ids = sort_ids.value pre_ids_new = np.asarray(pre_ids[sort_ids], dtype=IDX_DTYPE) unique_post_ids, count = np.unique(indices, return_counts=True) @@ -504,7 +507,7 @@ def ij2csr(pre_ids, post_ids, num_pre): # sorting sort_ids = np.argsort(pre_ids, kind=kind) - post_ids = post_ids[sort_ids] + post_ids = post_ids[sort_ids.value if isinstance(sort_ids, bm.JaxArray) else sort_ids] indices = post_ids unique_pre_ids, pre_count = np.unique(pre_ids, return_counts=True) diff --git a/brainpy/connect/custom_conn.py b/brainpy/connect/custom_conn.py index 0486c2e8e..7cac4cd72 100644 --- a/brainpy/connect/custom_conn.py +++ b/brainpy/connect/custom_conn.py @@ -35,8 +35,8 @@ def __call__(self, pre_size, post_size): return self def build_mat(self, pre_size=None, post_size=None): - pre_num = get_pre_num(pre_size) - post_num = get_post_num(post_size) + pre_num = get_pre_num(self, pre_size) + post_num = get_post_num(self, post_size) assert self.conn_mat.shape[0] == pre_num assert self.conn_mat.shape[1] == post_num return self.conn_mat @@ -69,8 +69,8 @@ def __call__(self, pre_size, post_size): return self def build_coo(self, pre_size=None, post_size=None): - pre_num = get_pre_num(pre_size) - post_num = get_post_num(post_size) + pre_num = get_pre_num(self, pre_size) + post_num = get_post_num(self, post_size) if pre_num <= self.max_pre: raise ConnectorError(f'pre_num ({pre_num}) should be greater than ' f'the maximum id ({self.max_pre}) of self.pre_ids.') @@ -91,20 +91,12 @@ def __init__(self, indices, inptr): self.pre_num = self.inptr.size - 1 self.max_post = bm.max(self.indices) - def __call__(self, pre_size, post_size): - if self.pre_num != tools.size2num(pre_size): - raise ConnectorError(f'(pre_size, post_size) is inconsistent with the shape of the sparse matrix.') - self.post_num = np.prod(post_size) - if self.post_num <= self.max_post: - raise ConnectorError(f'post_num ({self.post_num}) should be greater than ' - f'the maximum id ({self.max_post}) of self.post_ids.') - assert self.post_num == tools.size2num(post_size) - return self - def build_csr(self, pre_size=None, post_size=None): - pre_num = get_pre_num(pre_size) - post_num = get_post_num(post_size) - if pre_num != tools.size2num(pre_size): + pre_size = get_pre_size(self, pre_size) + post_size = get_post_size(self, post_size) + pre_num = np.prod(pre_size) + post_num = np.prod(post_size) + if pre_num != self.pre_num: raise ConnectorError(f'(pre_size, post_size) is inconsistent with ' f'the shape of the sparse matrix.') if post_num <= self.max_post: diff --git a/brainpy/connect/random_conn.py b/brainpy/connect/random_conn.py index d9ddae25f..4daf894f8 100644 --- a/brainpy/connect/random_conn.py +++ b/brainpy/connect/random_conn.py @@ -148,7 +148,7 @@ def f(post_i, key): pres = bm.delete(pre_ids, post_i) if not self.include_self else pre_ids return self.rng.permutation(pres, key=key)[:pre_num_to_select] - selected_pre_ids = f(post_ids, self.rng.split_keys(pre_num)).flatten() + selected_pre_ids = f(post_ids, self.rng.split_keys(post_num)).flatten() selected_post_ids = bm.repeat(post_ids, pre_num_to_select) return selected_pre_ids.astype(IDX_DTYPE), selected_post_ids.astype(IDX_DTYPE) diff --git a/brainpy/connect/tests/test_random_conn.py b/brainpy/connect/tests/test_random_conn.py index ac76468a3..3df5185ea 100644 --- a/brainpy/connect/tests/test_random_conn.py +++ b/brainpy/connect/tests/test_random_conn.py @@ -49,7 +49,7 @@ def test_random_fix_pre2(): def test_random_fix_pre3(): - with pytest.raises(AssertionError): + with pytest.raises(bp.errors.ConnectorError): conn1 = bp.connect.FixedPreNum(num=6, seed=1234)(pre_size=3, post_size=4) conn1.require(bp.connect.CONN_MAT) @@ -77,7 +77,7 @@ def test_random_fix_post2(): def test_random_fix_post3(): - with pytest.raises(AssertionError): + with pytest.raises(bp.errors.ConnectorError): conn1 = bp.connect.FixedPostNum(num=6, seed=1234)(pre_size=3, post_size=4) conn1.require(bp.connect.CONN_MAT) From b725631cb80e3794190162ffaa054b943c68d146 Mon Sep 17 00:00:00 2001 From: chaoming Date: Sun, 16 Oct 2022 14:20:50 +0800 Subject: [PATCH 3/4] consistent brainpy ops with brainpylib --- brainpy/math/operators/pre_syn_post.py | 54 ++++--------------- brainpy/math/operators/utils.py | 4 +- brainpy/measure/tests/test_correlation.py | 10 ++-- extensions/brainpylib/atomic_prod.py | 4 +- extensions/brainpylib/atomic_sum.py | 4 +- extensions/brainpylib/event_prod.py | 4 +- extensions/brainpylib/event_sum.py | 4 +- .../brainpylib/tests/test_atomic_prod_cpu.py | 5 +- .../brainpylib/tests/test_atomic_prod_gpu.py | 4 +- .../brainpylib/tests/test_atomic_sum_cpu.py | 4 +- .../brainpylib/tests/test_atomic_sum_gpu.py | 4 +- .../tests/test_coo_event_sum_gpu.py | 4 +- .../tests/test_csr_event_prod_cpu.py | 4 +- .../tests/test_csr_event_sum_cpu.py | 12 ++--- .../tests/test_csr_event_sum_gpu.py | 12 ++--- extensions/setup.py | 2 +- 16 files changed, 52 insertions(+), 83 deletions(-) diff --git a/brainpy/math/operators/pre_syn_post.py b/brainpy/math/operators/pre_syn_post.py index 27411acdb..8b37321a5 100644 --- a/brainpy/math/operators/pre_syn_post.py +++ b/brainpy/math/operators/pre_syn_post.py @@ -104,7 +104,7 @@ def pre2post_csr_event_sum(events: Array, indices = as_device_array(indices) idnptr = as_device_array(idnptr) values = as_device_array(values) - return brainpylib.event_sum(events, (indices, idnptr), post_num, values) + return brainpylib.csr_event_sum(events, (indices, idnptr), post_num, values) pre2post_event_sum = pre2post_csr_event_sum @@ -124,7 +124,7 @@ def pre2post_coo_event_sum(events: Array, pre_ids: Array Pre-synaptic ids. post_ids: Array - Post-synaptic idsd. + Post-synaptic ids. post_num: int The number of post-synaptic group. values: float, Array @@ -140,7 +140,7 @@ def pre2post_coo_event_sum(events: Array, post_ids = as_device_array(post_ids) pre_ids = as_device_array(pre_ids) values = as_device_array(values) - return brainpylib.event_sum2(events, pre_ids, post_ids, post_num, values) + return brainpylib.coo_event_sum(events, pre_ids, post_ids, post_num, values) def pre2post_csr_event_prod(events, pre2post, post_num, values=1.): @@ -195,7 +195,7 @@ def pre2post_csr_event_prod(events, pre2post, post_num, values=1.): indices = as_device_array(indices) idnptr = as_device_array(idnptr) values = as_device_array(values) - return brainpylib.event_prod(events, (indices, idnptr), post_num, values) + return brainpylib.csr_event_prod(events, (indices, idnptr), post_num, values) pre2post_event_prod = pre2post_csr_event_prod @@ -385,40 +385,6 @@ def pre2post_mean(pre_values, post_num, post_ids, pre_ids=None): return syn2post_mean(pre_values, post_ids, post_num) -def pre2post_matmul(event, conn): - event = event.value if isinstance(event, JaxArray) else event - Cl = conn[0].value if isinstance(conn[0], JaxArray) else conn[0] - Cr = conn[1].value if isinstance(conn[1], JaxArray) else conn[1] - if jnp.ndim(event) != 1: - raise ValueError(f'"event" must be a one-dimensional vector. But we got {jnp.shape(event)}') - if jnp.ndim(Cl) != 2: - raise ValueError(f'"conn" must be a two-dimensional matrix. But we got {jnp.shape(Cl)}') - if jnp.ndim(Cr) != 2: - raise ValueError(f'"conn" must be a two-dimensional matrix. But we got {jnp.shape(Cr)}') - - f0 = vmap(lambda i, j: event[i] * (Cl[i] * Cr[:, j]).sum(), in_axes=(0, None)) - ii = jnp.arange(Cl.shape[0]) - f1 = vmap(lambda j: f0(ii, j).sum(), in_axes=(None, 0)) - return f1(jnp.arange(Cr.shape[1])) - - -def pre2post_matmul2(event, conn): - event = event.value if isinstance(event, JaxArray) else event - Cl = conn[0].value if isinstance(conn[0], JaxArray) else conn[0] - Cr = conn[1].value if isinstance(conn[1], JaxArray) else conn[1] - if jnp.ndim(event) != 1: - raise ValueError(f'"event" must be a one-dimensional vector. But we got {jnp.shape(event)}') - if jnp.ndim(Cl) != 2: - raise ValueError(f'"conn" must be a two-dimensional matrix. But we got {jnp.shape(Cl)}') - if jnp.ndim(Cr) != 2: - raise ValueError(f'"conn" must be a two-dimensional matrix. But we got {jnp.shape(Cr)}') - f1 = vmap(lambda j: (event * (Cl * Cr[:, j]).sum(1)).sum()) - return f1(jnp.arange(Cr.shape[1])) - - -_pre2syn = vmap(lambda pre_id, pre_vs: pre_vs[pre_id], in_axes=(0, None)) - - def pre2syn(pre_values, pre_ids): """The pre-to-syn computation. @@ -459,7 +425,7 @@ def pre2syn(pre_values, pre_ids): _jit_seg_min = jit(jops.segment_min, static_argnums=(2, 3)) -def syn2post_sum(syn_values, post_ids, post_num: int, indices_are_sorted=True): +def syn2post_sum(syn_values, post_ids, post_num: int, indices_are_sorted=False): """The syn-to-post summation computation. This function is equivalent to: @@ -495,7 +461,7 @@ def syn2post_sum(syn_values, post_ids, post_num: int, indices_are_sorted=True): syn2post = syn2post_sum -def syn2post_prod(syn_values, post_ids, post_num: int, indices_are_sorted=True): +def syn2post_prod(syn_values, post_ids, post_num: int, indices_are_sorted=False): """The syn-to-post product computation. This function is equivalent to: @@ -532,7 +498,7 @@ def syn2post_prod(syn_values, post_ids, post_num: int, indices_are_sorted=True): return _jit_seg_prod(syn_values, post_ids, post_num, indices_are_sorted) -def syn2post_max(syn_values, post_ids, post_num: int, indices_are_sorted=True): +def syn2post_max(syn_values, post_ids, post_num: int, indices_are_sorted=False): """The syn-to-post maximum computation. This function is equivalent to: @@ -569,7 +535,7 @@ def syn2post_max(syn_values, post_ids, post_num: int, indices_are_sorted=True): return _jit_seg_max(syn_values, post_ids, post_num, indices_are_sorted) -def syn2post_min(syn_values, post_ids, post_num: int, indices_are_sorted=True): +def syn2post_min(syn_values, post_ids, post_num: int, indices_are_sorted=False): """The syn-to-post minimization computation. This function is equivalent to: @@ -606,7 +572,7 @@ def syn2post_min(syn_values, post_ids, post_num: int, indices_are_sorted=True): return _jit_seg_min(syn_values, post_ids, post_num, indices_are_sorted) -def syn2post_mean(syn_values, post_ids, post_num: int, indices_are_sorted=True): +def syn2post_mean(syn_values, post_ids, post_num: int, indices_are_sorted=False): """The syn-to-post mean computation. Parameters @@ -636,7 +602,7 @@ def syn2post_mean(syn_values, post_ids, post_num: int, indices_are_sorted=True): return jnp.nan_to_num(nominator / denominator) -def syn2post_softmax(syn_values, post_ids, post_num: int, indices_are_sorted=True): +def syn2post_softmax(syn_values, post_ids, post_num: int, indices_are_sorted=False): """The syn-to-post softmax computation. Parameters diff --git a/brainpy/math/operators/utils.py b/brainpy/math/operators/utils.py index fb8c1cfa4..a6143a437 100644 --- a/brainpy/math/operators/utils.py +++ b/brainpy/math/operators/utils.py @@ -8,7 +8,7 @@ brainpylib = None -_BRAINPYLIB_MINIMAL_VERSION = '0.0.6' +_BRAINPYLIB_MINIMAL_VERSION = '0.0.7' def _check_brainpylib(ops_name): @@ -17,7 +17,7 @@ def _check_brainpylib(ops_name): raise PackageMissingError( f'"{ops_name}" operator need "brainpylib>={_BRAINPYLIB_MINIMAL_VERSION}". \n' f'Please install it through:\n\n' - f'>>> pip install brainpylib=={_BRAINPYLIB_MINIMAL_VERSION}\n' + f'>>> pip install brainpylib>={_BRAINPYLIB_MINIMAL_VERSION}\n' f'>>> # or \n' f'>>> pip install brainpylib -U' ) diff --git a/brainpy/measure/tests/test_correlation.py b/brainpy/measure/tests/test_correlation.py index aa963fe34..ac992a63b 100644 --- a/brainpy/measure/tests/test_correlation.py +++ b/brainpy/measure/tests/test_correlation.py @@ -54,19 +54,21 @@ def test_cc5(self): class TestVoltageFluctuation(unittest.TestCase): def test_vf1(self): rng = bp.math.random.RandomState(122) - voltages = rng.normal(0, 10, size=(1000, 100)) + voltages = rng.normal(0, 10, size=(1000, 100)).value print(bp.measure.voltage_fluctuation(voltages)) - voltages = bp.math.ones((1000, 100)) + bm.enable_x64() + voltages = bp.math.ones((1000, 100)).value r1 = bp.measure.voltage_fluctuation(voltages) jit_f = jit(partial(bp.measure.voltage_fluctuation, numpy=False)) + jit_f = jit(lambda a: bp.measure.voltage_fluctuation(a, numpy=False)) r2 = jit_f(voltages) - print(r1, r2) # TODO: JIT results are different? - # self.assertTrue(r1 == r2) + bm.disable_x64() + class TestFunctionalConnectivity(unittest.TestCase): def test_cf1(self): diff --git a/extensions/brainpylib/atomic_prod.py b/extensions/brainpylib/atomic_prod.py index 3b6973c00..1b7f9bd11 100644 --- a/extensions/brainpylib/atomic_prod.py +++ b/extensions/brainpylib/atomic_prod.py @@ -40,7 +40,8 @@ def coo_atomic_prod(values, post_ids, post_num, pre_ids=None): raise ValueError(f'The dtype of post_ids must be uint32 or uint64, while we got {post_ids.dtype}') # output value - values = jnp.asarray([values]) + if np.ndim(values) == 0: + values = jnp.asarray([values]) if values.dtype not in [jnp.float32, jnp.float64]: raise ValueError(f'The dtype of "values" must be float32 or float64, while we got {values.dtype}.') # if values.size not in [1, pre_ids.size]: @@ -49,7 +50,6 @@ def coo_atomic_prod(values, post_ids, post_num, pre_ids=None): if values.size != 1 and values.size <= pre_ids.max(): raise ValueError(f'The size of "values" must be 1 (a scalar) or longer than pre_size (a vector), ' f'while we got {values.size} != 1 <= {pre_ids.max()}') - values = values.flatten() # bind operator return coo_atomic_prod_p1.bind(values, pre_ids, post_ids, post_num=post_num) diff --git a/extensions/brainpylib/atomic_sum.py b/extensions/brainpylib/atomic_sum.py index 457831249..a6ca92277 100644 --- a/extensions/brainpylib/atomic_sum.py +++ b/extensions/brainpylib/atomic_sum.py @@ -40,7 +40,8 @@ def coo_atomic_sum(values, post_ids, post_num, pre_ids=None): raise ValueError(f'The dtype of post_ids must be uint32 or uint64, while we got {post_ids.dtype}') # output value - values = jnp.asarray([values]) + if np.ndim(values) == 0: + values = jnp.asarray([values]) if values.dtype not in [jnp.float32, jnp.float64]: raise ValueError(f'The dtype of "values" must be float32 or float64, while we got {values.dtype}.') # if values.size not in [1, pre_ids.size]: @@ -49,7 +50,6 @@ def coo_atomic_sum(values, post_ids, post_num, pre_ids=None): if values.size != 1 and values.size <= pre_ids.max(): raise ValueError(f'The size of "values" must be 1 (a scalar) or longer than pre_size (a vector), ' f'while we got {values.size} != 1 <= {pre_ids.max()}') - values = values.flatten() # bind operator return coo_atomic_sum_p1.bind(values, pre_ids, post_ids, post_num=post_num) diff --git a/extensions/brainpylib/event_prod.py b/extensions/brainpylib/event_prod.py index ad549a5bd..752002523 100644 --- a/extensions/brainpylib/event_prod.py +++ b/extensions/brainpylib/event_prod.py @@ -41,13 +41,13 @@ def csr_event_prod(events, pre2post, post_num, values): raise ValueError(f'The dtype of pre2post must be uint32 or uint64, while we got {indices.dtype}') # output value - values = jnp.asarray([values]) + if np.ndim(values) == 0: + values = jnp.asarray([values]) if values.dtype not in [jnp.float32, jnp.float64]: raise ValueError(f'The dtype of "values" must be float32 or float64, while we got {values.dtype}.') if values.size not in [1, indices.size]: raise ValueError(f'The size of "values" must be 1 (a scalar) or len(pre2post[0]) (a vector), ' f'while we got {values.size} != 1 != {indices.size}') - values = values.flatten() # bind operator return csr_event_prod_p1.bind(events, indices, indptr, values, post_num=post_num) diff --git a/extensions/brainpylib/event_sum.py b/extensions/brainpylib/event_sum.py index e7bebb656..8b8467c51 100644 --- a/extensions/brainpylib/event_sum.py +++ b/extensions/brainpylib/event_sum.py @@ -49,7 +49,7 @@ def csr_event_sum(events: jnp.ndarray, raise ValueError(f'The dtype of pre2post must be integer, while we got {indices.dtype}') # output value - if not isinstance(values, jnp.ndarray): + if np.ndim(values) == 0: values = jnp.asarray([values]) dtype = values.dtype if dtype not in [jnp.float32, jnp.float64]: @@ -178,7 +178,7 @@ def coo_event_sum(events, pre_ids, post_ids, post_num, values): f'while we got {pre_ids.dtype}') # output value - if not isinstance(values, jnp.ndarray): + if np.ndim(values) == 0: values = jnp.asarray([values]) if values.dtype not in [jnp.float32, jnp.float64]: raise ValueError(f'The dtype of "values" must be float32 or float64, while we got {values.dtype}.') diff --git a/extensions/brainpylib/tests/test_atomic_prod_cpu.py b/extensions/brainpylib/tests/test_atomic_prod_cpu.py index 5d83b300e..328508f57 100644 --- a/extensions/brainpylib/tests/test_atomic_prod_cpu.py +++ b/extensions/brainpylib/tests/test_atomic_prod_cpu.py @@ -7,6 +7,7 @@ from brainpylib import coo_atomic_prod import brainpy as bp +import brainpy.math as bm bp.math.set_platform('cpu') @@ -37,7 +38,7 @@ def test_homo_fixedpro(self): conn = bp.conn.FixedProb(prob=1, seed=123) conn(pre_size=size, post_size=size) post_ids = conn.require('post_ids') - a = coo_atomic_prod(value, post_ids.value, size) + a = coo_atomic_prod(value, bm.as_jax(post_ids), size) print(a) def test_heter_fixedpro(self): @@ -46,5 +47,5 @@ def test_heter_fixedpro(self): conn = bp.conn.FixedProb(prob=1, seed=123) conn(pre_size=size, post_size=size) pre_ids, post_ids = conn.require('pre_ids', 'post_ids') - a = coo_atomic_prod(value, post_ids.value, size, pre_ids.value) + a = coo_atomic_prod(value, bm.as_jax(post_ids), size, bm.as_jax(pre_ids)) print(a) diff --git a/extensions/brainpylib/tests/test_atomic_prod_gpu.py b/extensions/brainpylib/tests/test_atomic_prod_gpu.py index 4e7296e70..766bda8ae 100644 --- a/extensions/brainpylib/tests/test_atomic_prod_gpu.py +++ b/extensions/brainpylib/tests/test_atomic_prod_gpu.py @@ -37,7 +37,7 @@ def test_homo_fixedpro(self): conn = bp.conn.FixedProb(prob=1, seed=123) conn(pre_size=size, post_size=size) post_ids = conn.require('post_ids') - a = coo_atomic_prod(value, post_ids.value, size) + a = coo_atomic_prod(value, bp.math.as_jax(post_ids), size) print(a) def test_heter_fixedpro(self): @@ -46,5 +46,5 @@ def test_heter_fixedpro(self): conn = bp.conn.FixedProb(prob=1, seed=123) conn(pre_size=size, post_size=size) pre_ids, post_ids = conn.require('pre_ids', 'post_ids') - a = coo_atomic_prod(value, post_ids.value, size, pre_ids.value) + a = coo_atomic_prod(value, bp.math.as_jax(post_ids), size, bp.math.as_jax(pre_ids)) print(a) diff --git a/extensions/brainpylib/tests/test_atomic_sum_cpu.py b/extensions/brainpylib/tests/test_atomic_sum_cpu.py index 9e1b7db11..3bd46229d 100644 --- a/extensions/brainpylib/tests/test_atomic_sum_cpu.py +++ b/extensions/brainpylib/tests/test_atomic_sum_cpu.py @@ -37,7 +37,7 @@ def test_homo_fixedpro(self): conn = bp.conn.FixedProb(prob=1, seed=123) conn(pre_size=size, post_size=size) post_ids = conn.require('post_ids') - a = coo_atomic_sum(value, post_ids.value, size) + a = coo_atomic_sum(value, bp.math.as_jax(post_ids), size) print(a) def test_heter_fixedpro(self): @@ -46,5 +46,5 @@ def test_heter_fixedpro(self): conn = bp.conn.FixedProb(prob=1, seed=123) conn(pre_size=size, post_size=size) pre_ids, post_ids = conn.require('pre_ids', 'post_ids') - a = coo_atomic_sum(value, post_ids.value, size, pre_ids.value) + a = coo_atomic_sum(value, bp.math.as_jax(post_ids), size, bp.math.as_jax(pre_ids)) print(a) diff --git a/extensions/brainpylib/tests/test_atomic_sum_gpu.py b/extensions/brainpylib/tests/test_atomic_sum_gpu.py index 9e1b7db11..3bd46229d 100644 --- a/extensions/brainpylib/tests/test_atomic_sum_gpu.py +++ b/extensions/brainpylib/tests/test_atomic_sum_gpu.py @@ -37,7 +37,7 @@ def test_homo_fixedpro(self): conn = bp.conn.FixedProb(prob=1, seed=123) conn(pre_size=size, post_size=size) post_ids = conn.require('post_ids') - a = coo_atomic_sum(value, post_ids.value, size) + a = coo_atomic_sum(value, bp.math.as_jax(post_ids), size) print(a) def test_heter_fixedpro(self): @@ -46,5 +46,5 @@ def test_heter_fixedpro(self): conn = bp.conn.FixedProb(prob=1, seed=123) conn(pre_size=size, post_size=size) pre_ids, post_ids = conn.require('pre_ids', 'post_ids') - a = coo_atomic_sum(value, post_ids.value, size, pre_ids.value) + a = coo_atomic_sum(value, bp.math.as_jax(post_ids), size, bp.math.as_jax(pre_ids)) print(a) diff --git a/extensions/brainpylib/tests/test_coo_event_sum_gpu.py b/extensions/brainpylib/tests/test_coo_event_sum_gpu.py index 716ddd27f..e9ed47010 100644 --- a/extensions/brainpylib/tests/test_coo_event_sum_gpu.py +++ b/extensions/brainpylib/tests/test_coo_event_sum_gpu.py @@ -26,7 +26,7 @@ def test_homo_values(self): sps = bm.random.random(size).value < 0.5 # print(sps) value = 3.0233 - a = coo_event_sum(sps, pre_ids.value, post_ids.value, size, value) + a = coo_event_sum(sps, bp.math.as_jax(pre_ids), bp.math.as_jax(post_ids), size, value) print(a) def test_heter_value(self): @@ -40,7 +40,7 @@ def test_heter_value(self): sps = bm.random.random(size).value < 0.5 values = bm.random.rand(post_ids.size) # values = bm.ones(post_ids.size) - a = coo_event_sum(sps, pre_ids.value, post_ids.value , size, values.value) + a = coo_event_sum(sps, bp.math.as_jax(pre_ids), bp.math.as_jax(post_ids) , size, values.value) print(a) # diff --git a/extensions/brainpylib/tests/test_csr_event_prod_cpu.py b/extensions/brainpylib/tests/test_csr_event_prod_cpu.py index 4caeef2ca..7aa9e73ad 100644 --- a/extensions/brainpylib/tests/test_csr_event_prod_cpu.py +++ b/extensions/brainpylib/tests/test_csr_event_prod_cpu.py @@ -21,7 +21,7 @@ def test_homo_values(self): sps = bm.random.random(size).value < 0.5 # print(sps) value = 1.0233 - a = csr_event_prod(sps, (post_ids.value, indptr.value), size, value) + a = csr_event_prod(sps, (bp.math.as_jax(post_ids), bp.math.as_jax(indptr)), size, value) print(a) def test_heter_value(self): @@ -35,6 +35,6 @@ def test_heter_value(self): sps = bm.random.random(size).value < 0.5 values = bm.random.rand(post_ids.size) # values = bm.ones(post_ids.size) - a = csr_event_prod(sps, (post_ids.value, indptr.value), size, values.value) + a = csr_event_prod(sps, (bp.math.as_jax(post_ids), bp.math.as_jax(indptr)), size, values.value) print(a) diff --git a/extensions/brainpylib/tests/test_csr_event_sum_cpu.py b/extensions/brainpylib/tests/test_csr_event_sum_cpu.py index c6e120718..6f4ebca0a 100644 --- a/extensions/brainpylib/tests/test_csr_event_sum_cpu.py +++ b/extensions/brainpylib/tests/test_csr_event_sum_cpu.py @@ -27,7 +27,7 @@ def test_homo_values(self): sps = bm.random.random(size).value < 0.5 # print(sps) value = 3.0233 - a = csr_event_sum(sps, (post_ids.value, indptr.value), size, value) + a = csr_event_sum(sps, (bp.math.as_jax(post_ids), bp.math.as_jax(indptr)), size, value) print(a) def test_homo_values_batching(self): @@ -40,11 +40,11 @@ def test_homo_values_batching(self): sps = bm.random.random((10, size)).value < 0.5 value = 3.0233 f = vmap(csr_event_sum, in_axes=(0, None, None, None)) - a1 = f(sps, (post_ids.value, indptr.value), size, value) + a1 = f(sps, (bp.math.as_jax(post_ids), bp.math.as_jax(indptr)), size, value) print(a1) - f = vmap(lambda events: csr_event_sum(events, (post_ids.value, indptr.value), size, value)) + f = vmap(lambda events: csr_event_sum(events, (bp.math.as_jax(post_ids), bp.math.as_jax(indptr)), size, value)) a2 = f(sps) print(a2) @@ -61,7 +61,7 @@ def test_heter_value(self): sps = bm.random.random(size).value < 0.5 values = bm.random.rand(post_ids.size) # values = bm.ones(post_ids.size) - a = csr_event_sum(sps, (post_ids.value, indptr.value), size, values.value) + a = csr_event_sum(sps, (bp.math.as_jax(post_ids), bp.math.as_jax(indptr)), size, values.value) print(a) def test_heter_values_batching(self): @@ -74,9 +74,9 @@ def test_heter_values_batching(self): sps = bm.random.random((10, size)).value < 0.5 values = bm.random.rand(post_ids.size).value f = vmap(csr_event_sum, in_axes=(0, None, None, None)) - a1 = f(sps, (post_ids.value, indptr.value), size, values) + a1 = f(sps, (bp.math.as_jax(post_ids), bp.math.as_jax(indptr)), size, values) - f = vmap(lambda events: csr_event_sum(events, (post_ids.value, indptr.value), size, values)) + f = vmap(lambda events: csr_event_sum(events, (bp.math.as_jax(post_ids), bp.math.as_jax(indptr)), size, values)) a2 = f(sps) print(a1, a2) diff --git a/extensions/brainpylib/tests/test_csr_event_sum_gpu.py b/extensions/brainpylib/tests/test_csr_event_sum_gpu.py index 44a882476..226b7e75b 100644 --- a/extensions/brainpylib/tests/test_csr_event_sum_gpu.py +++ b/extensions/brainpylib/tests/test_csr_event_sum_gpu.py @@ -27,7 +27,7 @@ def test_homo_values(self): sps = bm.random.random(size).value < 0.5 # print(sps) value = 3.0233 - a = csr_event_sum(sps, (post_ids.value, indptr.value), size, value) + a = csr_event_sum(sps, (bp.math.as_jax(post_ids), bp.math.as_jax(indptr)), size, value) print(a) def test_homo_values_batching(self): @@ -40,11 +40,11 @@ def test_homo_values_batching(self): sps = bm.random.random((10, size)).value < 0.5 value = 3.0233 f = vmap(csr_event_sum, in_axes=(0, None, None, None)) - a1 = f(sps, (post_ids.value, indptr.value), size, value) + a1 = f(sps, (bp.math.as_jax(post_ids), bp.math.as_jax(indptr)), size, value) print(a1) - f = vmap(lambda events: csr_event_sum(events, (post_ids.value, indptr.value), size, value)) + f = vmap(lambda events: csr_event_sum(events, (bp.math.as_jax(post_ids), bp.math.as_jax(indptr)), size, value)) a2 = f(sps) print(a2) @@ -61,7 +61,7 @@ def test_heter_value(self): sps = bm.random.random(size).value < 0.5 values = bm.random.rand(post_ids.size) # values = bm.ones(post_ids.size) - a = csr_event_sum(sps, (post_ids.value, indptr.value), size, values.value) + a = csr_event_sum(sps, (bp.math.as_jax(post_ids), bp.math.as_jax(indptr)), size, values.value) print(a) def test_heter_values_batching(self): @@ -74,9 +74,9 @@ def test_heter_values_batching(self): sps = bm.random.random((10, size)).value < 0.5 values = bm.random.rand(post_ids.size).value f = vmap(csr_event_sum, in_axes=(0, None, None, None)) - a1 = f(sps, (post_ids.value, indptr.value), size, values) + a1 = f(sps, (bp.math.as_jax(post_ids), bp.math.as_jax(indptr)), size, values) - f = vmap(lambda events: csr_event_sum(events, (post_ids.value, indptr.value), size, values)) + f = vmap(lambda events: csr_event_sum(events, (bp.math.as_jax(post_ids), bp.math.as_jax(indptr)), size, values)) a2 = f(sps) print(a1, a2) diff --git a/extensions/setup.py b/extensions/setup.py index a5b770b75..4620dd1b0 100644 --- a/extensions/setup.py +++ b/extensions/setup.py @@ -19,7 +19,7 @@ # extension modules ext_modules = [ Pybind11Extension("brainpylib/cpu_ops", - sources=["lib/cpu_ops.cc"] + glob.glob("lib/*_cpu.cc"), + sources=["lib/cpu_ops.cc"] + glob.glob("lib/cpu_*.cc"), cxx_std=11, define_macros=[('VERSION_INFO', __version__)]), ] From fa40e8b75def3fb6fdbc1f1e1f1261791a8a99db Mon Sep 17 00:00:00 2001 From: chaoming Date: Sun, 16 Oct 2022 14:28:08 +0800 Subject: [PATCH 4/4] add compatible apis in brainpylib --- extensions/brainpylib/atomic_prod.py | 5 ++++- extensions/brainpylib/atomic_sum.py | 5 ++++- extensions/brainpylib/event_prod.py | 4 +++- extensions/brainpylib/event_sum.py | 4 +++- 4 files changed, 14 insertions(+), 4 deletions(-) diff --git a/extensions/brainpylib/atomic_prod.py b/extensions/brainpylib/atomic_prod.py index 1b7f9bd11..5eaeb52b1 100644 --- a/extensions/brainpylib/atomic_prod.py +++ b/extensions/brainpylib/atomic_prod.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- __all__ = [ - 'coo_atomic_prod', + 'coo_atomic_prod', 'atomic_prod' ] from functools import partial @@ -55,6 +55,9 @@ def coo_atomic_prod(values, post_ids, post_num, pre_ids=None): return coo_atomic_prod_p1.bind(values, pre_ids, post_ids, post_num=post_num) +atomic_prod = coo_atomic_prod + + def _atomic_prod_abstract(values, pre_ids, post_ids, *, post_num): return ShapedArray(shape=(post_num, ), dtype=values.dtype) diff --git a/extensions/brainpylib/atomic_sum.py b/extensions/brainpylib/atomic_sum.py index a6ca92277..528d5b000 100644 --- a/extensions/brainpylib/atomic_sum.py +++ b/extensions/brainpylib/atomic_sum.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- __all__ = [ - 'coo_atomic_sum', + 'coo_atomic_sum', 'atomic_sum' ] from functools import partial @@ -55,6 +55,9 @@ def coo_atomic_sum(values, post_ids, post_num, pre_ids=None): return coo_atomic_sum_p1.bind(values, pre_ids, post_ids, post_num=post_num) +atomic_sum = coo_atomic_sum + + def _atomic_sum_abstract(values, pre_ids, post_ids, *, post_num): return ShapedArray(dtype=values.dtype, shape=(post_num,)) diff --git a/extensions/brainpylib/event_prod.py b/extensions/brainpylib/event_prod.py index 752002523..cc8237a5b 100644 --- a/extensions/brainpylib/event_prod.py +++ b/extensions/brainpylib/event_prod.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- __all__ = [ - 'csr_event_prod', + 'csr_event_prod', 'event_prod', ] from functools import partial @@ -51,6 +51,8 @@ def csr_event_prod(events, pre2post, post_num, values): # bind operator return csr_event_prod_p1.bind(events, indices, indptr, values, post_num=post_num) +event_prod = csr_event_prod + def _event_prod_abstract(events, indices, indptr, values, *, post_num): return ShapedArray(dtype=values.dtype, shape=(post_num,)) diff --git a/extensions/brainpylib/event_sum.py b/extensions/brainpylib/event_sum.py index 8b8467c51..48fa5008b 100644 --- a/extensions/brainpylib/event_sum.py +++ b/extensions/brainpylib/event_sum.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- __all__ = [ - 'csr_event_sum', + 'csr_event_sum', 'event_sum', 'coo_event_sum', ] @@ -61,6 +61,8 @@ def csr_event_sum(events: jnp.ndarray, return csr_event_sum_p1.bind(events, indices, indptr, values, post_num=post_num) +event_sum = csr_event_sum + def _event_sum_abstract(events, indices, indptr, values, *, post_num): return ShapedArray(dtype=values.dtype, shape=(post_num,))