Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 26 additions & 22 deletions brainpy/connect/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import abc
from typing import Union, List, Tuple, Any

import jax.numpy as jnp
import numpy as onp

from brainpy import tools, math as bm
Expand Down Expand Up @@ -204,21 +205,23 @@ def _return_by_mat(self, structures, mat, all_data: dict):

require_other_structs = len([s for s in structures if s != CONN_MAT]) > 0
if require_other_structs:
pre_ids, post_ids = onp.where(mat > 0)
pre_ids = onp.ascontiguousarray(pre_ids, dtype=IDX_DTYPE)
post_ids = onp.ascontiguousarray(post_ids, dtype=IDX_DTYPE)
np = onp if isinstance(mat, onp.ndarray) else bm
pre_ids, post_ids = np.where(mat > 0)
pre_ids = np.asarray(pre_ids, dtype=IDX_DTYPE)
post_ids = np.asarray(post_ids, dtype=IDX_DTYPE)
self._return_by_ij(structures, ij=(pre_ids, post_ids), all_data=all_data)

def _return_by_csr(self, structures, csr: tuple, all_data: dict):
indices, indptr = csr
np = onp if isinstance(indices, onp.ndarray) else bm
assert self.pre_num == indptr.size - 1

if (CONN_MAT in structures) and (CONN_MAT not in all_data):
conn_mat = csr2mat((indices, indptr), self.pre_num, self.post_num)
all_data[CONN_MAT] = bm.asarray(conn_mat, dtype=MAT_DTYPE)

if (PRE_IDS in structures) and (PRE_IDS not in all_data):
pre_ids = onp.repeat(onp.arange(self.pre_num), onp.diff(indptr))
pre_ids = np.repeat(np.arange(self.pre_num), np.diff(indptr))
all_data[PRE_IDS] = bm.asarray(pre_ids, dtype=IDX_DTYPE)

if (POST_IDS in structures) and (POST_IDS not in all_data):
Expand All @@ -234,12 +237,12 @@ def _return_by_csr(self, structures, csr: tuple, all_data: dict):
bm.asarray(indptrc, dtype=IDX_DTYPE))

if (PRE2SYN in structures) and (PRE2SYN not in all_data):
syn_seq = onp.arange(indices.size, dtype=IDX_DTYPE)
syn_seq = np.arange(indices.size, dtype=IDX_DTYPE)
all_data[PRE2SYN] = (bm.asarray(syn_seq, dtype=IDX_DTYPE),
bm.asarray(indptr, dtype=IDX_DTYPE))

if (POST2SYN in structures) and (POST2SYN not in all_data):
syn_seq = onp.arange(indices.size, dtype=IDX_DTYPE)
syn_seq = np.arange(indices.size, dtype=IDX_DTYPE)
_, indptrc, syn_seqc = csr2csc((indices, indptr), self.post_num, syn_seq)
all_data[POST2SYN] = (bm.asarray(syn_seqc, dtype=IDX_DTYPE),
bm.asarray(indptrc, dtype=IDX_DTYPE))
Expand Down Expand Up @@ -297,7 +300,7 @@ def make_returns(self, structures, conn_data, csr=None, mat=None, ij=None):

# "mat" structure
if mat is not None:
assert isinstance(mat, onp.ndarray) and onp.ndim(mat) == 2
assert mat.ndim == 2
if (CONN_MAT in structures) and (CONN_MAT not in all_data):
all_data[CONN_MAT] = bm.asarray(mat, dtype=MAT_DTYPE)
self._return_by_mat(structures, mat=mat, all_data=all_data)
Expand All @@ -316,6 +319,7 @@ def make_returns(self, structures, conn_data, csr=None, mat=None, ij=None):
else:
return tuple([all_data[n] for n in structures])

@tools.not_customized
def build_conn(self):
"""build connections with certain data type.

Expand All @@ -326,7 +330,7 @@ def build_conn(self):
Or a dict with three elements: csr, mat and ij.
example: return dict(csr=(ind, indptr), mat=None, ij=None)
"""
raise NotImplementedError
pass

def require(self, *sizes_or_structures):
sizes_or_structures = list(sizes_or_structures)
Expand Down Expand Up @@ -355,26 +359,24 @@ def require(self, *sizes_or_structures):

self.check(structures)
if self.is_version2_style:
if (pre_size is None) or (post_size is None):
raise ConnectorError('Please provide both "pre_size" and "post_size".')
if len(structures) == 1:
if PRE2POST in structures:
if PRE2POST in structures and not hasattr(self.build_csr, 'not_customized'):
return self.build_csr(pre_size, post_size)
elif CONN_MAT in structures:
elif CONN_MAT in structures and not hasattr(self.build_mat, 'not_customized'):
return self.build_mat(pre_size, post_size)
elif PRE_IDS in structures:
elif PRE_IDS in structures and not hasattr(self.build_coo, 'not_customized'):
return self.build_coo(pre_size, post_size)[0]
elif POST_IDS in structures:
elif POST_IDS in structures and not hasattr(self.build_coo, 'not_customized'):
return self.build_coo(pre_size, post_size)[1]
elif len(structures) == 2:
if PRE_IDS in structures and POST_IDS in structures:
if PRE_IDS in structures and POST_IDS in structures and not hasattr(self.build_coo, 'not_customized'):
return self.build_coo(pre_size, post_size)

conn_data = dict(csr=None, ij=None, mat=None)
if not hasattr(self.build_csr, 'not_customized'):
conn_data['csr'] = self.build_csr(pre_size, post_size)
elif not hasattr(self.build_coo, 'not_customized'):
if not hasattr(self.build_coo, 'not_customized'):
conn_data['ij'] = self.build_coo(pre_size, post_size)
elif not hasattr(self.build_csr, 'not_customized'):
conn_data['csr'] = self.build_csr(pre_size, post_size)
elif not hasattr(self.build_mat, 'not_customized'):
conn_data['mat'] = self.build_mat(pre_size, post_size)
else:
Expand All @@ -385,15 +387,15 @@ def requires(self, *sizes_or_structures):
return self.require(*sizes_or_structures)

@tools.not_customized
def build_mat(self, pre_size, post_size):
def build_mat(self, pre_size=None, post_size=None):
pass

@tools.not_customized
def build_csr(self, pre_size, post_size):
def build_csr(self, pre_size=None, post_size=None):
pass

@tools.not_customized
def build_coo(self, pre_size, post_size):
def build_coo(self, pre_size=None, post_size=None):
pass


Expand Down Expand Up @@ -442,6 +444,8 @@ def csr2csc(csr, post_num, data=None):
pre_ids = np.repeat(np.arange(indptr.size - 1), np.diff(indptr))

sort_ids = np.argsort(indices, kind=kind) # to maintain the original order of the elements with the same value
if isinstance(sort_ids, bm.JaxArray):
sort_ids = sort_ids.value
pre_ids_new = np.asarray(pre_ids[sort_ids], dtype=IDX_DTYPE)

unique_post_ids, count = np.unique(indices, return_counts=True)
Expand Down Expand Up @@ -503,7 +507,7 @@ def ij2csr(pre_ids, post_ids, num_pre):

# sorting
sort_ids = np.argsort(pre_ids, kind=kind)
post_ids = post_ids[sort_ids]
post_ids = post_ids[sort_ids.value if isinstance(sort_ids, bm.JaxArray) else sort_ids]

indices = post_ids
unique_pre_ids, pre_count = np.unique(pre_ids, return_counts=True)
Expand Down
104 changes: 63 additions & 41 deletions brainpy/connect/custom_conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,16 @@
import jax.numpy as jnp
import numpy as np

from brainpy import math as bm
from brainpy import tools
from brainpy.errors import ConnectorError
from brainpy.math.jaxarray import JaxArray
from .base import *
from .utils import *

__all__ = [
'MatConn',
'IJConn',
'CSRConn',
'SparseMatConn'
]

Expand All @@ -21,19 +23,23 @@ class MatConn(TwoEndConnector):
def __init__(self, conn_mat):
super(MatConn, self).__init__()

assert isinstance(conn_mat, (np.ndarray, JaxArray, jnp.ndarray)) and conn_mat.ndim == 2
assert isinstance(conn_mat, (np.ndarray, bm.JaxArray, jnp.ndarray)) and conn_mat.ndim == 2
self.pre_num, self.post_num = conn_mat.shape
self.pre_size, self.post_size = (self.pre_num,), (self.post_num,)
self.conn_mat = np.asarray(conn_mat).astype(MAT_DTYPE)

self.conn_mat = bm.asarray(conn_mat).astype(MAT_DTYPE)

def __call__(self, pre_size, post_size):
assert self.pre_num == tools.size2num(pre_size)
assert self.post_num == tools.size2num(post_size)
return self

def build_conn(self):
return 'mat', self.conn_mat
def build_mat(self, pre_size=None, post_size=None):
pre_num = get_pre_num(self, pre_size)
post_num = get_post_num(self, post_size)
assert self.conn_mat.shape[0] == pre_num
assert self.conn_mat.shape[1] == post_num
return self.conn_mat


class IJConn(TwoEndConnector):
Expand All @@ -42,58 +48,74 @@ class IJConn(TwoEndConnector):
def __init__(self, i, j):
super(IJConn, self).__init__()

assert isinstance(i, (np.ndarray, JaxArray, jnp.ndarray)) and i.ndim == 1
assert isinstance(j, (np.ndarray, JaxArray, jnp.ndarray)) and j.ndim == 1
assert isinstance(i, (np.ndarray, bm.JaxArray, jnp.ndarray)) and i.ndim == 1
assert isinstance(j, (np.ndarray, bm.JaxArray, jnp.ndarray)) and j.ndim == 1
assert i.size == j.size

# initialize the class via "pre_ids" and "post_ids"
self.pre_ids = np.asarray(i).astype(IDX_DTYPE)
self.post_ids = np.asarray(j).astype(IDX_DTYPE)
self.pre_ids = bm.asarray(i).astype(IDX_DTYPE)
self.post_ids = bm.asarray(j).astype(IDX_DTYPE)
self.max_pre = bm.max(self.pre_ids)
self.max_post = bm.max(self.post_ids)

def __call__(self, pre_size, post_size):
super(IJConn, self).__call__(pre_size, post_size)

max_pre = np.max(self.pre_ids)
max_post = np.max(self.post_ids)
if max_pre >= self.pre_num:
if self.max_pre >= self.pre_num:
raise ConnectorError(f'pre_num ({self.pre_num}) should be greater than '
f'the maximum id ({max_pre}) of self.pre_ids.')
if max_post >= self.post_num:
f'the maximum id ({self.max_pre}) of self.pre_ids.')
if self.max_post >= self.post_num:
raise ConnectorError(f'post_num ({self.post_num}) should be greater than '
f'the maximum id ({max_post}) of self.post_ids.')
f'the maximum id ({self.max_post}) of self.post_ids.')
return self

def build_conn(self):
return 'ij', (self.pre_ids, self.post_ids)


class SparseMatConn(TwoEndConnector):
def build_coo(self, pre_size=None, post_size=None):
pre_num = get_pre_num(self, pre_size)
post_num = get_post_num(self, post_size)
if pre_num <= self.max_pre:
raise ConnectorError(f'pre_num ({pre_num}) should be greater than '
f'the maximum id ({self.max_pre}) of self.pre_ids.')
if post_num <= self.max_post:
raise ConnectorError(f'post_num ({post_num}) should be greater than '
f'the maximum id ({self.max_post}) of self.post_ids.')
return self.pre_ids, self.post_ids


class CSRConn(TwoEndConnector):
"""Connector built from the CSR sparse connection matrix."""

def __init__(self, indices, inptr):
super(CSRConn, self).__init__()

self.indices = bm.asarray(indices).astype(IDX_DTYPE)
self.inptr = bm.asarray(inptr).astype(IDX_DTYPE)
self.pre_num = self.inptr.size - 1
self.max_post = bm.max(self.indices)

def build_csr(self, pre_size=None, post_size=None):
pre_size = get_pre_size(self, pre_size)
post_size = get_post_size(self, post_size)
pre_num = np.prod(pre_size)
post_num = np.prod(post_size)
if pre_num != self.pre_num:
raise ConnectorError(f'(pre_size, post_size) is inconsistent with '
f'the shape of the sparse matrix.')
if post_num <= self.max_post:
raise ConnectorError(f'post_num ({post_num}) should be greater than '
f'the maximum id ({self.max_post}) of self.post_ids.')
return self.indices, self.inptr


class SparseMatConn(CSRConn):
"""Connector built from the sparse connection matrix"""

def __init__(self, csr_mat):
super(SparseMatConn, self).__init__()

try:
from scipy.sparse import csr_matrix
except (ModuleNotFoundError, ImportError):
raise ConnectorError(f'Using SparseMatConn requires the scipy package. '
f'Please run "pip install scipy" to install scipy.')

assert isinstance(csr_mat, csr_matrix)
csr_mat.data = np.asarray(csr_mat.data).astype(MAT_DTYPE)
self.csr_mat = csr_mat
self.pre_num, self.post_num = csr_mat.shape

def __call__(self, pre_size, post_size):
try:
assert self.pre_num == tools.size2num(pre_size)
assert self.post_num == tools.size2num(post_size)
except AssertionError:
raise ConnectorError(f'(pre_size, post_size) is inconsistent with the shape of the sparse matrix.')

super(SparseMatConn, self).__call__(pre_size, post_size)
return self

def build_conn(self):
ind, indptr = self.csr_mat.indices, self.csr_mat.indptr
return 'csr', (ind, indptr)
super(SparseMatConn, self).__init__(indices=bm.asarray(self.csr_mat.indices, dtype=IDX_DTYPE),
inptr=bm.asarray(self.csr_mat.indptr, dtype=IDX_DTYPE))
Loading