Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
311 changes: 188 additions & 123 deletions brainpy/connect/base.py

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions brainpy/connect/custom_conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ class CSRConn(TwoEndConnector):
def __init__(self, indices, inptr):
super(CSRConn, self).__init__()

self.indices = bm.asarray(indices).astype(IDX_DTYPE)
self.inptr = bm.asarray(inptr).astype(IDX_DTYPE)
self.indices = bm.asarray(indices, dtype=IDX_DTYPE)
self.inptr = bm.asarray(inptr, dtype=IDX_DTYPE)
self.pre_num = self.inptr.size - 1
self.max_post = bm.max(self.indices)

Expand Down Expand Up @@ -110,3 +110,5 @@ def __init__(self, csr_mat):
self.csr_mat = csr_mat
super(SparseMatConn, self).__init__(indices=bm.asarray(self.csr_mat.indices, dtype=IDX_DTYPE),
inptr=bm.asarray(self.csr_mat.indptr, dtype=IDX_DTYPE))
self.pre_num = csr_mat.shape[0]
self.post_num = csr_mat.shape[1]
2 changes: 1 addition & 1 deletion brainpy/connect/regular_conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def build_csr(self):
f'same size, but {self.pre_num} != {self.post_num}.')
ind = np.arange(self.pre_num)
indptr = np.arange(self.pre_num + 1)
return np.asarray(ind, dtype=IDX_DTYPE), np.arange(indptr, dtype=IDX_DTYPE),
return (np.asarray(ind, dtype=IDX_DTYPE), np.asarray(indptr, dtype=IDX_DTYPE))

def build_mat(self, pre_size=None, post_size=None):
if self.pre_num != self.post_num:
Expand Down
14 changes: 7 additions & 7 deletions brainpy/dyn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -988,15 +988,15 @@ def __init__(
ltp.register_master(master=self)
self.ltp: SynLTP = ltp

def init_weights(
def _init_weights(
self,
weight: Union[float, Array, Initializer, Callable],
comp_method: str,
sparse_data: str = 'csr'
) -> Union[float, Array]:
if comp_method not in ['sparse', 'dense']:
raise ValueError(f'"comp_method" must be in "sparse" and "dense", but we got {comp_method}')
if sparse_data not in ['csr', 'ij']:
if sparse_data not in ['csr', 'ij', 'coo']:
raise ValueError(f'"sparse_data" must be in "csr" and "ij", but we got {sparse_data}')
if self.conn is None:
raise ValueError(f'Must provide "conn" when initialize the model {self.name}')
Expand All @@ -1014,11 +1014,11 @@ def init_weights(
if comp_method == 'sparse':
if sparse_data == 'csr':
conn_mask = self.conn.require('pre2post')
elif sparse_data == 'ij':
elif sparse_data in ['ij', 'coo']:
conn_mask = self.conn.require('post_ids', 'pre_ids')
else:
ValueError(f'Unknown sparse data type: {sparse_data}')
weight = parameter(weight, conn_mask[1].shape, allow_none=False)
weight = parameter(weight, conn_mask[0].shape, allow_none=False)
elif comp_method == 'dense':
weight = parameter(weight, (self.pre.num, self.post.num), allow_none=False)
conn_mask = self.conn.require('conn_mat')
Expand All @@ -1030,7 +1030,7 @@ def init_weights(
weight = bm.TrainVar(weight)
return weight, conn_mask

def syn2post_with_all2all(self, syn_value, syn_weight):
def _syn2post_with_all2all(self, syn_value, syn_weight):
if bm.ndim(syn_weight) == 0:
if isinstance(self.mode, BatchingMode):
post_vs = bm.sum(syn_value, keepdims=True, axis=tuple(range(syn_value.ndim))[1:])
Expand All @@ -1043,10 +1043,10 @@ def syn2post_with_all2all(self, syn_value, syn_weight):
post_vs = syn_value @ syn_weight
return post_vs

def syn2post_with_one2one(self, syn_value, syn_weight):
def _syn2post_with_one2one(self, syn_value, syn_weight):
return syn_value * syn_weight

def syn2post_with_dense(self, syn_value, syn_weight, conn_mat):
def _syn2post_with_dense(self, syn_value, syn_weight, conn_mat):
if bm.ndim(syn_weight) == 0:
post_vs = (syn_weight * syn_value) @ conn_mat
else:
Expand Down
24 changes: 13 additions & 11 deletions brainpy/dyn/layers/activate.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,30 @@
from brainpy.dyn.base import DynamicalSystem
from typing import Optional
from brainpy.modes import Mode
from typing import Callable
from typing import Optional

from brainpy.dyn.base import DynamicalSystem
from brainpy.modes import Mode, training


class Activation(DynamicalSystem):
r"""Applies a activation to the inputs
r"""Applies an activation function to the inputs

Parameters:
----------
activate_fun: Callable
activate_fun: Callable, function
The function of Activation
name: str, Optional
The name of the object
mode: Mode
Enable training this node or not. (default True).
"""

def __init__(self,
activate_fun: Callable,
name: Optional[str] = None,
mode: Optional[Mode] = None,
**kwargs,
):
def __init__(
self,
activate_fun: Callable,
name: Optional[str] = None,
mode: Mode = training,
**kwargs,
):
super().__init__(name, mode)
self.activate_fun = activate_fun
self.kwargs = kwargs
Expand Down
18 changes: 10 additions & 8 deletions brainpy/dyn/layers/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from brainpy.dyn.base import DynamicalSystem
from brainpy.errors import MathError
from brainpy.initialize import XavierNormal, ZeroInit, Initializer, parameter
from brainpy.modes import Mode, TrainingMode, BatchingMode, training, batching
from brainpy.modes import Mode, TrainingMode, BatchingMode, training, batching
from brainpy.tools.checking import check_initializer
from brainpy.types import Array

Expand Down Expand Up @@ -201,17 +201,19 @@ class Flatten(DynamicalSystem):
mode: Mode
Enable training this node or not. (default True)
"""
def __init__(self,
name: Optional[str] = None,
mode: Optional[Mode] = batching,
):

def __init__(
self,
name: Optional[str] = None,
mode: Optional[Mode] = batching,
):
super().__init__(name, mode)

def update(self, shr, x):
if isinstance(self.mode, BatchingMode):
return x.reshape((x.shape[0], -1))
else:
return x.flatten()

def reset_state(self, batch_size=None):
pass
pass
32 changes: 16 additions & 16 deletions brainpy/dyn/synapses/abstract_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def __init__(
self.comp_method = comp_method

# connections and weights
self.g_max, self.conn_mask = self.init_weights(g_max, comp_method=comp_method, sparse_data='csr')
self.g_max, self.conn_mask = self._init_weights(g_max, comp_method=comp_method, sparse_data='csr')

# register delay
self.delay_step = self.register_delay(f"{self.pre.name}.spike", delay_step, self.pre.spike)
Expand All @@ -143,10 +143,10 @@ def update(self, tdi, pre_spike=None):
# synaptic values onto the post
if isinstance(self.conn, All2All):
syn_value = self.stp(bm.asarray(pre_spike, dtype=bm.dftype()))
post_vs = self.syn2post_with_all2all(syn_value, self.g_max)
post_vs = self._syn2post_with_all2all(syn_value, self.g_max)
elif isinstance(self.conn, One2One):
syn_value = self.stp(bm.asarray(pre_spike, dtype=bm.dftype()))
post_vs = self.syn2post_with_one2one(syn_value, self.g_max)
post_vs = self._syn2post_with_one2one(syn_value, self.g_max)
else:
if self.comp_method == 'sparse':
f = lambda s: bm.pre2post_event_sum(s, self.conn_mask, self.post.num, self.g_max)
Expand All @@ -160,7 +160,7 @@ def update(self, tdi, pre_spike=None):
# post_vs *= f2(stp_value)
else:
syn_value = self.stp(bm.asarray(pre_spike, dtype=bm.dftype()))
post_vs = self.syn2post_with_dense(syn_value, self.g_max, self.conn_mask)
post_vs = self._syn2post_with_dense(syn_value, self.g_max, self.conn_mask)
if self.post_ref_key:
post_vs = post_vs * (1. - getattr(self.post, self.post_ref_key))

Expand Down Expand Up @@ -296,7 +296,7 @@ def __init__(
raise ValueError(f'"tau" must be a scalar or a tensor with size of 1. But we got {self.tau}')

# connections and weights
self.g_max, self.conn_mask = self.init_weights(g_max, comp_method, sparse_data='csr')
self.g_max, self.conn_mask = self._init_weights(g_max, comp_method, sparse_data='csr')

# variables
self.g = variable_(bm.zeros, self.post.num, mode)
Expand Down Expand Up @@ -328,11 +328,11 @@ def update(self, tdi, pre_spike=None):
if isinstance(self.conn, All2All):
syn_value = bm.asarray(pre_spike, dtype=bm.dftype())
if self.stp is not None: syn_value = self.stp(syn_value)
post_vs = self.syn2post_with_all2all(syn_value, self.g_max)
post_vs = self._syn2post_with_all2all(syn_value, self.g_max)
elif isinstance(self.conn, One2One):
syn_value = bm.asarray(pre_spike, dtype=bm.dftype())
if self.stp is not None: syn_value = self.stp(syn_value)
post_vs = self.syn2post_with_one2one(syn_value, self.g_max)
post_vs = self._syn2post_with_one2one(syn_value, self.g_max)
else:
if self.comp_method == 'sparse':
f = lambda s: bm.pre2post_event_sum(s, self.conn_mask, self.post.num, self.g_max)
Expand All @@ -343,7 +343,7 @@ def update(self, tdi, pre_spike=None):
else:
syn_value = bm.asarray(pre_spike, dtype=bm.dftype())
if self.stp is not None: syn_value = self.stp(syn_value)
post_vs = self.syn2post_with_dense(syn_value, self.g_max, self.conn_mask)
post_vs = self._syn2post_with_dense(syn_value, self.g_max, self.conn_mask)
# updates
self.g.value = self.integral(self.g.value, t, dt) + post_vs

Expand Down Expand Up @@ -487,7 +487,7 @@ def __init__(
f'But we got {self.tau_decay}')

# connections
self.g_max, self.conn_mask = self.init_weights(g_max, comp_method, sparse_data='ij')
self.g_max, self.conn_mask = self._init_weights(g_max, comp_method, sparse_data='ij')

# variables
self.h = variable_(bm.zeros, self.pre.num, mode)
Expand Down Expand Up @@ -531,16 +531,16 @@ def update(self, tdi, pre_spike=None):
syn_value = self.g.value
if self.stp is not None: syn_value = self.stp(syn_value)
if isinstance(self.conn, All2All):
post_vs = self.syn2post_with_all2all(syn_value, self.g_max)
post_vs = self._syn2post_with_all2all(syn_value, self.g_max)
elif isinstance(self.conn, One2One):
post_vs = self.syn2post_with_one2one(syn_value, self.g_max)
post_vs = self._syn2post_with_one2one(syn_value, self.g_max)
else:
if self.comp_method == 'sparse':
f = lambda s: bm.pre2post_sum(s, self.post.num, *self.conn_mask)
if isinstance(self.mode, BatchingMode): f = vmap(f)
post_vs = f(syn_value)
else:
post_vs = self.syn2post_with_dense(syn_value, self.g_max, self.conn_mask)
post_vs = self._syn2post_with_dense(syn_value, self.g_max, self.conn_mask)

# output
return self.output(post_vs)
Expand Down Expand Up @@ -829,7 +829,7 @@ def __init__(
self.stop_spike_gradient = stop_spike_gradient

# connections and weights
self.g_max, self.conn_mask = self.init_weights(g_max, comp_method, sparse_data='ij')
self.g_max, self.conn_mask = self._init_weights(g_max, comp_method, sparse_data='ij')

# variables
self.g = variable_(bm.zeros, self.pre.num, mode)
Expand Down Expand Up @@ -872,16 +872,16 @@ def update(self, tdi, pre_spike=None):
syn_value = self.g.value
if self.stp is not None: syn_value = self.stp(syn_value)
if isinstance(self.conn, All2All):
post_vs = self.syn2post_with_all2all(syn_value, self.g_max)
post_vs = self._syn2post_with_all2all(syn_value, self.g_max)
elif isinstance(self.conn, One2One):
post_vs = self.syn2post_with_one2one(syn_value, self.g_max)
post_vs = self._syn2post_with_one2one(syn_value, self.g_max)
else:
if self.comp_method == 'sparse':
f = lambda s: bm.pre2post_sum(s, self.post.num, *self.conn_mask)
if isinstance(self.mode, BatchingMode): f = vmap(f)
post_vs = f(syn_value)
else:
post_vs = self.syn2post_with_dense(syn_value, self.g_max, self.conn_mask)
post_vs = self._syn2post_with_dense(syn_value, self.g_max, self.conn_mask)

# output
return self.output(post_vs)
Expand Down
16 changes: 8 additions & 8 deletions brainpy/dyn/synapses/biological_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def __init__(
raise ValueError(f'"T_duration" must be a scalar or a tensor with size of 1. But we got {T_duration}')

# connection
self.g_max, self.conn_mask = self.init_weights(g_max, comp_method, sparse_data='ij')
self.g_max, self.conn_mask = self._init_weights(g_max, comp_method, sparse_data='ij')

# variables
self.g = variable(bm.zeros, mode, self.pre.num)
Expand Down Expand Up @@ -226,16 +226,16 @@ def update(self, tdi, pre_spike=None):
syn_value = self.g.value
if self.stp is not None: syn_value = self.stp(syn_value)
if isinstance(self.conn, All2All):
post_vs = self.syn2post_with_all2all(syn_value, self.g_max)
post_vs = self._syn2post_with_all2all(syn_value, self.g_max)
elif isinstance(self.conn, One2One):
post_vs = self.syn2post_with_one2one(syn_value, self.g_max)
post_vs = self._syn2post_with_one2one(syn_value, self.g_max)
else:
if self.comp_method == 'sparse':
f = lambda s: bm.pre2post_sum(s, self.post.num, *self.conn_mask)
if isinstance(self.mode, BatchingMode): f = vmap(f)
post_vs = f(syn_value)
else:
post_vs = self.syn2post_with_dense(syn_value, self.g_max, self.conn_mask)
post_vs = self._syn2post_with_dense(syn_value, self.g_max, self.conn_mask)

# output
return self.output(post_vs)
Expand Down Expand Up @@ -526,7 +526,7 @@ def __init__(
self.stop_spike_gradient = stop_spike_gradient

# connections and weights
self.g_max, self.conn_mask = self.init_weights(g_max, comp_method, sparse_data='ij')
self.g_max, self.conn_mask = self._init_weights(g_max, comp_method, sparse_data='ij')

# variables
self.g = variable(bm.zeros, mode, self.pre.num)
Expand Down Expand Up @@ -575,16 +575,16 @@ def update(self, tdi, pre_spike=None):
syn_value = self.g.value
if self.stp is not None: syn_value = self.stp(syn_value)
if isinstance(self.conn, All2All):
post_vs = self.syn2post_with_all2all(syn_value, self.g_max)
post_vs = self._syn2post_with_all2all(syn_value, self.g_max)
elif isinstance(self.conn, One2One):
post_vs = self.syn2post_with_one2one(syn_value, self.g_max)
post_vs = self._syn2post_with_one2one(syn_value, self.g_max)
else:
if self.comp_method == 'sparse':
f = lambda s: bm.pre2post_sum(s, self.post.num, *self.conn_mask)
if isinstance(self.mode, BatchingMode): f = vmap(f)
post_vs = f(syn_value)
else:
post_vs = self.syn2post_with_dense(syn_value, self.g_max, self.conn_mask)
post_vs = self._syn2post_with_dense(syn_value, self.g_max, self.conn_mask)

# output
return self.output(post_vs)
16 changes: 14 additions & 2 deletions brainpy/math/setting.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
import os
import re

from jax import dtypes, config, numpy as jnp
from jax import dtypes, config, numpy as jnp, devices
from jax.lib import xla_bridge

__all__ = [
'enable_x64',
'disable_x64',
'set_platform',
'get_platform',
'set_host_device_count',

# device memory
Expand Down Expand Up @@ -92,7 +93,7 @@ def disable_x64():
config.update("jax_enable_x64", False)


def set_platform(platform):
def set_platform(platform: str):
"""
Changes platform to CPU, GPU, or TPU. This utility only takes
effect at the beginning of your program.
Expand All @@ -101,6 +102,17 @@ def set_platform(platform):
config.update("jax_platform_name", platform)


def get_platform() -> str:
"""Get the computing platform.

Returns
-------
platform: str
Either 'cpu', 'gpu' or 'tpu'.
"""
return devices()[0].platform


def set_host_device_count(n):
"""
By default, XLA considers all CPU cores as one device. This utility tells XLA
Expand Down
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
'sphinx.ext.autosummary',
'sphinx.ext.intersphinx',
'sphinx.ext.mathjax',
'sphinx-mathjax-offline',
# 'sphinx-mathjax-offline',
'sphinx.ext.napoleon',
'sphinx.ext.viewcode',
'sphinx_autodoc_typehints',
Expand Down
Loading