Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions .github/workflows/Linux_CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,14 @@ jobs:
run: |
python -m pip install --upgrade pip
python -m pip install flake8 pytest
# python -m pip install https://github.com/google/jax/archive/refs/tags/jax-v0.3.14.tar.gz
if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi
python setup.py install
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
# flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Test with pytest
run: |
pytest brainpy/
4 changes: 1 addition & 3 deletions .github/workflows/MacOS_CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,14 @@ jobs:
run: |
python -m pip install --upgrade pip
python -m pip install flake8 pytest
python -m pip install jax==0.3.14
python -m pip install jaxlib==0.3.14
if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi
python setup.py install
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
# flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Test with pytest
run: |
pytest brainpy/
2 changes: 1 addition & 1 deletion .github/workflows/Windows_CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ jobs:
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
# flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Test with pytest
run: |
pytest brainpy/
69 changes: 26 additions & 43 deletions brainpy/connect/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,13 @@ def build_conn(self):

import brainpy as bp
class MyConnector(bp.conn.TwoEndConnector):
def build_mat(self, pre_size, post_size):
def build_mat(self, ):
return conn_matrix

def build_csr(self, pre_size, post_size):
def build_csr(self, ):
return post_ids, inptr

def build_coo(self, pre_size, post_size):
def build_coo(self, ):
return pre_ids, post_ids

"""
Expand Down Expand Up @@ -196,8 +196,6 @@ def check(self, structures: Union[Tuple, List, str]):
raise ConnectorError(f'Unknown synapse structure "{n}". '
f'Only {SUPPORTED_SYN_STRUCTURE} is supported.')



def _return_by_mat(self, structures, mat, all_data: dict):
assert mat.ndim == 2
if (CONN_MAT in structures) and (CONN_MAT not in all_data):
Expand Down Expand Up @@ -332,70 +330,56 @@ def build_conn(self):
"""
pass

def require(self, *sizes_or_structures):
sizes_or_structures = list(sizes_or_structures)
pre_size = sizes_or_structures.pop(0) if len(sizes_or_structures) >= 1 else None
post_size = sizes_or_structures.pop(0) if len(sizes_or_structures) >= 1 else None
structures = sizes_or_structures
if isinstance(post_size, str):
structures.insert(0, post_size)
post_size = None
if isinstance(pre_size, str):
structures.insert(0, pre_size)
pre_size = None

version2_style = (pre_size is not None) and (post_size is not None)
if not version2_style:
try:
assert self.pre_num is not None and self.post_num is not None
except AssertionError:
raise ConnectorError(f'self.pre_num or self.post_num is not defined. '
f'Please use self.__call__(pre_size, post_size) '
f'before requiring connection data.')
if pre_size is None:
pre_size = self.pre_size
if post_size is None:
post_size = self.post_size
def require(self, *structures):
try:
assert self.pre_num is not None and self.post_num is not None
except AssertionError:
raise ConnectorError(f'self.pre_num or self.post_num is not defined. '
f'Please use self.__call__() '
f'before requiring connection data.')

self.check(structures)
if self.is_version2_style:
if len(structures) == 1:
if PRE2POST in structures and not hasattr(self.build_csr, 'not_customized'):
return self.build_csr(pre_size, post_size)
r = self.build_csr()
return bm.asarray(r[0], dtype=IDX_DTYPE), bm.asarray(r[1], dtype=IDX_DTYPE)
elif CONN_MAT in structures and not hasattr(self.build_mat, 'not_customized'):
return self.build_mat(pre_size, post_size)
return bm.asarray(self.build_mat(), dtype=MAT_DTYPE)
elif PRE_IDS in structures and not hasattr(self.build_coo, 'not_customized'):
return self.build_coo(pre_size, post_size)[0]
return bm.asarray(self.build_coo()[0], dtype=IDX_DTYPE)
elif POST_IDS in structures and not hasattr(self.build_coo, 'not_customized'):
return self.build_coo(pre_size, post_size)[1]
return bm.asarray(self.build_coo()[1], dtype=IDX_DTYPE)
elif len(structures) == 2:
if PRE_IDS in structures and POST_IDS in structures and not hasattr(self.build_coo, 'not_customized'):
return self.build_coo(pre_size, post_size)
r = self.build_coo()
return bm.asarray(r[0], dtype=IDX_DTYPE), bm.asarray(r[1], dtype=IDX_DTYPE)

conn_data = dict(csr=None, ij=None, mat=None)
if not hasattr(self.build_coo, 'not_customized'):
conn_data['ij'] = self.build_coo(pre_size, post_size)
conn_data['ij'] = self.build_coo()
elif not hasattr(self.build_csr, 'not_customized'):
conn_data['csr'] = self.build_csr(pre_size, post_size)
conn_data['csr'] = self.build_csr()
elif not hasattr(self.build_mat, 'not_customized'):
conn_data['mat'] = self.build_mat(pre_size, post_size)
conn_data['mat'] = self.build_mat()

else:
conn_data = self.build_conn()
return self.make_returns(structures, conn_data)

def requires(self, *sizes_or_structures):
return self.require(*sizes_or_structures)
def requires(self, *structures):
return self.require(*structures)

@tools.not_customized
def build_mat(self, pre_size=None, post_size=None):
def build_mat(self):
pass

@tools.not_customized
def build_csr(self, pre_size=None, post_size=None):
def build_csr(self):
pass

@tools.not_customized
def build_coo(self, pre_size=None, post_size=None):
def build_coo(self):
pass


Expand Down Expand Up @@ -425,7 +409,6 @@ def __call__(self, pre_size, post_size=None):
else:
post_size = tuple(post_size)
self.pre_size, self.post_size = pre_size, post_size

self.pre_num = tools.size2num(self.pre_size)
self.post_num = tools.size2num(self.post_size)
return self
Expand Down
33 changes: 12 additions & 21 deletions brainpy/connect/custom_conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from brainpy import tools
from brainpy.errors import ConnectorError
from .base import *
from .utils import *

__all__ = [
'MatConn',
Expand All @@ -34,11 +33,9 @@ def __call__(self, pre_size, post_size):
assert self.post_num == tools.size2num(post_size)
return self

def build_mat(self, pre_size=None, post_size=None):
pre_num = get_pre_num(self, pre_size)
post_num = get_post_num(self, post_size)
assert self.conn_mat.shape[0] == pre_num
assert self.conn_mat.shape[1] == post_num
def build_mat(self):
assert self.conn_mat.shape[0] == self.pre_num
assert self.conn_mat.shape[1] == self.post_num
return self.conn_mat


Expand Down Expand Up @@ -68,14 +65,12 @@ def __call__(self, pre_size, post_size):
f'the maximum id ({self.max_post}) of self.post_ids.')
return self

def build_coo(self, pre_size=None, post_size=None):
pre_num = get_pre_num(self, pre_size)
post_num = get_post_num(self, post_size)
if pre_num <= self.max_pre:
raise ConnectorError(f'pre_num ({pre_num}) should be greater than '
def build_coo(self):
if self.pre_num <= self.max_pre:
raise ConnectorError(f'pre_num ({self.pre_num}) should be greater than '
f'the maximum id ({self.max_pre}) of self.pre_ids.')
if post_num <= self.max_post:
raise ConnectorError(f'post_num ({post_num}) should be greater than '
if self.post_num <= self.max_post:
raise ConnectorError(f'post_num ({self.post_num}) should be greater than '
f'the maximum id ({self.max_post}) of self.post_ids.')
return self.pre_ids, self.post_ids

Expand All @@ -91,16 +86,12 @@ def __init__(self, indices, inptr):
self.pre_num = self.inptr.size - 1
self.max_post = bm.max(self.indices)

def build_csr(self, pre_size=None, post_size=None):
pre_size = get_pre_size(self, pre_size)
post_size = get_post_size(self, post_size)
pre_num = np.prod(pre_size)
post_num = np.prod(post_size)
if pre_num != self.pre_num:
def build_csr(self):
if self.pre_num != self.pre_num:
raise ConnectorError(f'(pre_size, post_size) is inconsistent with '
f'the shape of the sparse matrix.')
if post_num <= self.max_post:
raise ConnectorError(f'post_num ({post_num}) should be greater than '
if self.post_num <= self.max_post:
raise ConnectorError(f'post_num ({self.post_num}) should be greater than '
f'the maximum id ({self.max_post}) of self.post_ids.')
return self.indices, self.inptr

Expand Down
Loading