185 changes: 144 additions & 41 deletions brainpy/connect/base.py
@@ -1,9 +1,9 @@
# -*- coding: utf-8 -*-

import abc
from typing import Union, List, Tuple
from typing import Union, List, Tuple, Any

import numpy as np
import numpy as onp

from brainpy import tools, math as bm
from brainpy.errors import ConnectorError
@@ -42,8 +42,8 @@
PRE2SYN, POST2SYN,
PRE_SLICE, POST_SLICE]

MAT_DTYPE = np.bool_
IDX_DTYPE = np.uint32
MAT_DTYPE = onp.bool_
IDX_DTYPE = onp.uint32


def set_default_dtype(mat_dtype=None, idx_dtype=None):
@@ -92,14 +92,49 @@ class Connector(abc.ABC):


class TwoEndConnector(Connector):
"""Synaptic connector to build synapse connections between two neuron groups."""
"""Synaptic connector to build connections between two neuron groups.

If users want to customize their `Connector`, there are two ways:

1. Implementing the ``build_conn(self)`` function, which returns a dict holding one of
the connection data formats: ``csr`` (CSR sparse data, a tuple of <post_ids, indptr>),
``ij`` (COO sparse data, a tuple of <pre_ids, post_ids>), or ``mat``
(a binary connection matrix). For instance,

.. code-block:: python

  import brainpy as bp

  class MyConnector(bp.conn.TwoEndConnector):
    def build_conn(self):
      # return exactly one of the three formats, e.g. a CSR pair:
      return dict(csr=(post_ids, indptr))

2. Implementing functions ``build_mat()``, ``build_csr()``, and
``build_coo()``. Users can provide all three functions, or one of them.

.. code-block:: python

  import brainpy as bp

  class MyConnector(bp.conn.TwoEndConnector):
    def build_mat(self, pre_size, post_size):
      return conn_matrix

    def build_csr(self, pre_size, post_size):
      return post_ids, indptr

    def build_coo(self, pre_size, post_size):
      return pre_ids, post_ids

"""

def __init__(self, ):
self.pre_size = None
self.post_size = None
self.pre_num = None
self.post_num = None

def __repr__(self):
return self.__class__.__name__

def __call__(self, pre_size, post_size):
"""Create the concrete connections between two end objects.

@@ -140,15 +175,16 @@ def _reset_conn(self, pre_size, post_size):
"""
self.__call__(pre_size, post_size)

def check(self, structures: Union[Tuple, List, str]):
# check "pre_num" and "post_num"
try:
assert self.pre_num is not None and self.post_num is not None
except AssertionError:
raise ConnectorError(f'self.pre_num or self.post_num is not defined. '
f'Please use self.__call__(pre_size, post_size) '
f'before requiring properties.')
@property
def is_version2_style(self):
if ((hasattr(self.build_coo, 'not_customized') and self.build_coo.not_customized) and
(hasattr(self.build_csr, 'not_customized') and self.build_csr.not_customized) and
(hasattr(self.build_mat, 'not_customized') and self.build_mat.not_customized)):
return False
else:
return True

def check(self, structures: Union[Tuple, List, str]):
# check synaptic structures
if isinstance(structures, str):
structures = [structures]
@@ -159,30 +195,30 @@ def check(self, structures: Union[Tuple, List, str]):
raise ConnectorError(f'Unknown synapse structure "{n}". '
f'Only {SUPPORTED_SYN_STRUCTURE} is supported.')



def _return_by_mat(self, structures, mat, all_data: dict):
assert isinstance(mat, np.ndarray) and np.ndim(mat) == 2
assert mat.ndim == 2
if (CONN_MAT in structures) and (CONN_MAT not in all_data):
all_data[CONN_MAT] = bm.asarray(mat, dtype=MAT_DTYPE)

require_other_structs = len([s for s in structures if s != CONN_MAT]) > 0
if require_other_structs:
pre_ids, post_ids = np.where(mat > 0)
pre_ids = np.ascontiguousarray(pre_ids, dtype=IDX_DTYPE)
post_ids = np.ascontiguousarray(post_ids, dtype=IDX_DTYPE)
pre_ids, post_ids = onp.where(mat > 0)
pre_ids = onp.ascontiguousarray(pre_ids, dtype=IDX_DTYPE)
post_ids = onp.ascontiguousarray(post_ids, dtype=IDX_DTYPE)
self._return_by_ij(structures, ij=(pre_ids, post_ids), all_data=all_data)

def _return_by_csr(self, structures, csr: tuple, all_data: dict):
indices, indptr = csr
assert isinstance(indices, np.ndarray)
assert isinstance(indptr, np.ndarray)
assert self.pre_num == indptr.size - 1

if (CONN_MAT in structures) and (CONN_MAT not in all_data):
conn_mat = csr2mat((indices, indptr), self.pre_num, self.post_num)
all_data[CONN_MAT] = bm.asarray(conn_mat, dtype=MAT_DTYPE)

if (PRE_IDS in structures) and (PRE_IDS not in all_data):
pre_ids = np.repeat(np.arange(self.pre_num), np.diff(indptr))
pre_ids = onp.repeat(onp.arange(self.pre_num), onp.diff(indptr))
all_data[PRE_IDS] = bm.asarray(pre_ids, dtype=IDX_DTYPE)

if (POST_IDS in structures) and (POST_IDS not in all_data):
@@ -198,20 +234,18 @@ def _return_by_csr(self, structures, csr: tuple, all_data: dict):
bm.asarray(indptrc, dtype=IDX_DTYPE))

if (PRE2SYN in structures) and (PRE2SYN not in all_data):
syn_seq = np.arange(indices.size, dtype=IDX_DTYPE)
syn_seq = onp.arange(indices.size, dtype=IDX_DTYPE)
all_data[PRE2SYN] = (bm.asarray(syn_seq, dtype=IDX_DTYPE),
bm.asarray(indptr, dtype=IDX_DTYPE))

if (POST2SYN in structures) and (POST2SYN not in all_data):
syn_seq = np.arange(indices.size, dtype=IDX_DTYPE)
syn_seq = onp.arange(indices.size, dtype=IDX_DTYPE)
_, indptrc, syn_seqc = csr2csc((indices, indptr), self.post_num, syn_seq)
all_data[POST2SYN] = (bm.asarray(syn_seqc, dtype=IDX_DTYPE),
bm.asarray(indptrc, dtype=IDX_DTYPE))

def _return_by_ij(self, structures, ij: tuple, all_data: dict):
pre_ids, post_ids = ij
assert isinstance(pre_ids, np.ndarray)
assert isinstance(post_ids, np.ndarray)

if (CONN_MAT in structures) and (CONN_MAT not in all_data):
all_data[CONN_MAT] = bm.asarray(ij2mat(ij, self.pre_num, self.post_num), dtype=MAT_DTYPE)
@@ -232,9 +266,9 @@ def make_returns(self, structures, conn_data, csr=None, mat=None, ij=None):
"""Make the desired synaptic structures and return them.
"""
if isinstance(conn_data, dict):
csr = conn_data['csr']
mat = conn_data['mat']
ij = conn_data['ij']
csr = conn_data.get('csr', None)
mat = conn_data.get('mat', None)
ij = conn_data.get('ij', None)
elif isinstance(conn_data, tuple):
if conn_data[0] == 'csr':
csr = conn_data[1]
@@ -244,6 +278,8 @@ def make_returns(self, structures, conn_data, csr=None, mat=None, ij=None):
ij = conn_data[1]
else:
raise ConnectorError(f'Must provide one of "csr", "mat" or "ij". Got "{conn_data[0]}" instead.')
else:
raise ConnectorError

# checking
all_data = dict()
@@ -254,22 +290,20 @@ def make_returns(self, structures, conn_data, csr=None, mat=None, ij=None):

# "csr" structure
if csr is not None:
assert isinstance(csr[0], np.ndarray)
assert isinstance(csr[1], np.ndarray)
if (PRE2POST in structures) and (PRE2POST not in all_data):
all_data[PRE2POST] = (bm.asarray(csr[0], dtype=IDX_DTYPE),
bm.asarray(csr[1], dtype=IDX_DTYPE))
self._return_by_csr(structures, csr=csr, all_data=all_data)

# "mat" structure
if mat is not None:
assert isinstance(mat, np.ndarray) and np.ndim(mat) == 2
assert isinstance(mat, onp.ndarray) and onp.ndim(mat) == 2
if (CONN_MAT in structures) and (CONN_MAT not in all_data):
all_data[CONN_MAT] = bm.asarray(mat, dtype=MAT_DTYPE)
self._return_by_mat(structures, mat=mat, all_data=all_data)

# "ij" structure
if ij is not None:
assert isinstance(ij[0], np.ndarray)
assert isinstance(ij[1], np.ndarray)
if (PRE_IDS in structures) and (PRE_IDS not in all_data):
all_data[PRE_IDS] = bm.asarray(ij[0], dtype=IDX_DTYPE)
if (POST_IDS in structures) and (POST_IDS not in all_data):
@@ -294,13 +328,73 @@ def build_conn(self):
"""
raise NotImplementedError

def require(self, *structures):
def require(self, *sizes_or_structures):
sizes_or_structures = list(sizes_or_structures)
pre_size = sizes_or_structures.pop(0) if len(sizes_or_structures) >= 1 else None
post_size = sizes_or_structures.pop(0) if len(sizes_or_structures) >= 1 else None
structures = sizes_or_structures
if isinstance(post_size, str):
structures.insert(0, post_size)
post_size = None
if isinstance(pre_size, str):
structures.insert(0, pre_size)
pre_size = None

version2_style = (pre_size is not None) and (post_size is not None)
if not version2_style:
try:
assert self.pre_num is not None and self.post_num is not None
except AssertionError:
raise ConnectorError(f'self.pre_num or self.post_num is not defined. '
f'Please use self.__call__(pre_size, post_size) '
f'before requiring connection data.')
if pre_size is None:
pre_size = self.pre_size
if post_size is None:
post_size = self.post_size

self.check(structures)
conn_data = self.build_conn()
if self.is_version2_style:
if (pre_size is None) or (post_size is None):
raise ConnectorError('Please provide both "pre_size" and "post_size".')
if len(structures) == 1:
if PRE2POST in structures:
return self.build_csr(pre_size, post_size)
elif CONN_MAT in structures:
return self.build_mat(pre_size, post_size)
elif PRE_IDS in structures:
return self.build_coo(pre_size, post_size)[0]
elif POST_IDS in structures:
return self.build_coo(pre_size, post_size)[1]
elif len(structures) == 2:
if PRE_IDS in structures and POST_IDS in structures:
return self.build_coo(pre_size, post_size)

conn_data = dict(csr=None, ij=None, mat=None)
if not hasattr(self.build_csr, 'not_customized'):
conn_data['csr'] = self.build_csr(pre_size, post_size)
elif not hasattr(self.build_coo, 'not_customized'):
conn_data['ij'] = self.build_coo(pre_size, post_size)
elif not hasattr(self.build_mat, 'not_customized'):
conn_data['mat'] = self.build_mat(pre_size, post_size)
else:
conn_data = self.build_conn()
return self.make_returns(structures, conn_data)

def requires(self, *structures):
return self.require(*structures)
def requires(self, *sizes_or_structures):
return self.require(*sizes_or_structures)

@tools.not_customized
def build_mat(self, pre_size, post_size):
pass

@tools.not_customized
def build_csr(self, pre_size, post_size):
pass

@tools.not_customized
def build_coo(self, pre_size, post_size):
pass
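
Taken together, the builder hooks above and the reworked ``require`` let a connector be defined once and then queried either with or without a prior ``__call__``. A minimal sketch under that reading of the code (the ``AllToAll2`` class, the sizes, and the ``'conn_mat'`` structure name are illustrative, not part of this diff):

  import numpy as onp
  import brainpy as bp

  class AllToAll2(bp.conn.TwoEndConnector):
    """Dense all-to-all connectivity, customized through ``build_mat`` only."""

    def build_mat(self, pre_size, post_size):
      # a full binary connection matrix of shape (num_pre, num_post)
      return onp.ones((onp.prod(pre_size), onp.prod(post_size)), dtype=bool)

  conn = AllToAll2()

  # new style: pass the sizes together with the requested structure
  mat = conn.require(3, 4, 'conn_mat')

  # classic style: bind the sizes first, then request the data
  conn(pre_size=3, post_size=4)
  mat = conn.require('conn_mat')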


class OneEndConnector(TwoEndConnector):
@@ -336,16 +430,18 @@ def __call__(self, pre_size, post_size=None):

def _reset_conn(self, pre_size, post_size=None):
self.__init__()

self.__call__(pre_size, post_size)


def csr2csc(csr, post_num, data=None):
"""Convert csr to csc."""
indices, indptr = csr
np = onp if isinstance(indices, onp.ndarray) else bm
kind = 'quicksort' if isinstance(indices, onp.ndarray) else 'stable'

pre_ids = np.repeat(np.arange(indptr.size - 1), np.diff(indptr))

sort_ids = np.argsort(indices, kind='mergesort') # to maintain the original order of the elements with the same value
sort_ids = np.argsort(indices, kind=kind) # to maintain the original order of the elements with the same value
pre_ids_new = np.asarray(pre_ids[sort_ids], dtype=IDX_DTYPE)

unique_post_ids, count = np.unique(indices, return_counts=True)
@@ -365,8 +461,8 @@ def csr2csc(csr, post_num, data=None):

def mat2csr(dense):
"""convert a dense matrix to (indices, indptr)."""
if isinstance(dense, bm.ndarray):
dense = np.asarray(dense)
np = onp if isinstance(dense, onp.ndarray) else bm

pre_ids, post_ids = np.where(dense > 0)
pre_num = dense.shape[0]

@@ -382,6 +478,8 @@ def mat2csr(dense):
def csr2mat(csr, num_pre, num_post):
"""convert (indices, indptr) to a dense matrix."""
indices, indptr = csr
np = onp if isinstance(indices, onp.ndarray) else bm

d = np.zeros((num_pre, num_post), dtype=MAT_DTYPE) # num_pre, num_post
pre_ids = np.repeat(np.arange(indptr.size - 1), np.diff(indptr))
d[pre_ids, indices] = True
@@ -391,15 +489,20 @@ def csr2mat(csr, num_pre, num_post):
def ij2mat(ij, num_pre, num_post):
"""convert (indices, indptr) to a dense matrix."""
pre_ids, post_ids = ij
np = onp if isinstance(pre_ids, onp.ndarray) else bm

d = np.zeros((num_pre, num_post), dtype=MAT_DTYPE) # num_pre, num_post
d[pre_ids, post_ids] = True
return d
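
Because these converters (``mat2csr``, ``csr2mat``, ``ij2mat``) are plain module-level functions, they can be exercised directly with NumPy inputs. A small round-trip sketch, assuming ``mat2csr`` returns the ``(indices, indptr)`` pair described in its docstring and that both helpers are importable from ``brainpy.connect.base``:

  import numpy as np
  from brainpy.connect.base import mat2csr, csr2mat

  dense = np.array([[0, 1, 1],
                    [1, 0, 0]], dtype=bool)

  indices, indptr = mat2csr(dense)                 # CSR: post ids plus row pointer
  rebuilt = csr2mat((indices, indptr), num_pre=2, num_post=3)

  assert np.array_equal(np.asarray(rebuilt), dense)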


def ij2csr(pre_ids, post_ids, num_pre):
"""convert pre_ids, post_ids to (indices, indptr)."""
np = onp if isinstance(pre_ids, onp.ndarray) else bm
kind = 'quicksort' if isinstance(pre_ids, onp.ndarray) else 'stable'

# sorting
sort_ids = np.argsort(pre_ids, kind='mergesort')
sort_ids = np.argsort(pre_ids, kind=kind)
post_ids = post_ids[sort_ids]

indices = post_ids