2 changes: 1 addition & 1 deletion brainpy/_src/analysis/highdim/slow_points.py
@@ -462,7 +462,7 @@ def filter_loss(self, tolerance: float = 1e-5):
else:
num_fps = self._fixed_points.shape[0]
ids = self._losses < tolerance
- keep_ids = bm.as_jax(jnp.where(ids)[0])
+ keep_ids = bm.as_jax(bm.where(ids)[0])
self._fixed_points = tree_map(lambda a: a[keep_ids], self._fixed_points)
self._losses = self._losses[keep_ids]
self._selected_ids = self._selected_ids[keep_ids]
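
The hunk above swaps jnp.where for bm.where, so the boolean mask produced by comparing the losses against the tolerance is consumed directly as a brainpy Array before being converted with bm.as_jax. A minimal sketch of the same pattern, using hypothetical loss values (the numbers and the printed result are illustrative, not taken from the PR):

import brainpy.math as bm

losses = bm.asarray([1e-7, 3e-2, 2e-6, 0.5])  # hypothetical per-fixed-point losses
tolerance = 1e-5
ids = losses < tolerance                      # boolean mask as a bm.Array
keep_ids = bm.as_jax(bm.where(ids)[0])        # indices of the points below tolerance
print(keep_ids)                               # -> [0 2]
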
41 changes: 20 additions & 21 deletions brainpy/_src/dyn/base.py
@@ -4,6 +4,7 @@
import gc
from typing import Union, Dict, Callable, Sequence, Optional, Tuple, Any

+ import jax
import jax.numpy as jnp
import numpy as np

@@ -18,8 +19,6 @@
from brainpy.errors import NoImplementationError, UnsupportedError
from brainpy.types import ArrayType, Shape



__all__ = [
# general class
'DynamicalSystem',
@@ -170,14 +169,14 @@ def register_delay(
raise ValueError(f'Unknown "delay_steps" type {type(delay_step)}, only support '
f'integer, array of integers, callable function, brainpy.init.Initializer.')
if delay_type == 'heter':
- if delay_step.dtype not in [jnp.int32, jnp.int64]:
+ if delay_step.dtype not in [bm.int32, bm.int64]:
raise ValueError('Only support delay steps of int32, int64. If you '
'provide a delay time length, please divide it by "dt" '
'and then provide the number of delay steps.')
if delay_target.shape[0] != delay_step.shape[0]:
raise ValueError(f'Shape is mismatched: {delay_target.shape[0]} != {delay_step.shape[0]}')
if delay_type != 'none':
- max_delay_step = int(jnp.max(delay_step))
+ max_delay_step = int(bm.max(delay_step))

# delay target
if delay_type != 'none':
@@ -207,8 +206,8 @@ def register_delay(
def get_delay_data(
self,
identifier: str,
- delay_step: Optional[Union[int, bm.Array, jnp.DeviceArray]],
- *indices: Union[int, slice, bm.Array, jnp.DeviceArray],
+ delay_step: Optional[Union[int, bm.Array, jax.Array]],
+ *indices: Union[int, slice, bm.Array, jax.Array],
):
"""Get delay data according to the provided delay steps.

Expand All @@ -230,19 +229,19 @@ def get_delay_data(
return self.global_delay_data[identifier][1].value

if identifier in self.global_delay_data:
- if jnp.ndim(delay_step) == 0:
+ if bm.ndim(delay_step) == 0:
return self.global_delay_data[identifier][0](delay_step, *indices)
else:
if len(indices) == 0:
- indices = (jnp.arange(delay_step.size),)
+ indices = (bm.arange(delay_step.size),)
return self.global_delay_data[identifier][0](delay_step, *indices)

elif identifier in self.local_delay_vars:
- if jnp.ndim(delay_step) == 0:
+ if bm.ndim(delay_step) == 0:
return self.local_delay_vars[identifier](delay_step)
else:
if len(indices) == 0:
- indices = (jnp.arange(delay_step.size),)
+ indices = (bm.arange(delay_step.size),)
return self.local_delay_vars[identifier](delay_step, *indices)

else:
@@ -878,7 +877,7 @@ def __init__(
# ------------
if isinstance(conn, TwoEndConnector):
self.conn = conn(pre.size, post.size)
- elif isinstance(conn, (bm.ndarray, np.ndarray, jnp.ndarray)):
+ elif isinstance(conn, (bm.ndarray, np.ndarray, jax.Array)):
if (pre.num, post.num) != conn.shape:
raise ValueError(f'"conn" is provided as a matrix, and it is expected '
f'to be an array with shape of (pre.num, post.num) = '
@@ -1157,11 +1156,11 @@ def _init_weights(
return weight, conn_mask

def _syn2post_with_all2all(self, syn_value, syn_weight):
- if jnp.ndim(syn_weight) == 0:
+ if bm.ndim(syn_weight) == 0:
if isinstance(self.mode, bm.BatchingMode):
- post_vs = jnp.sum(syn_value, keepdims=True, axis=tuple(range(syn_value.ndim))[1:])
+ post_vs = bm.sum(syn_value, keepdims=True, axis=tuple(range(syn_value.ndim))[1:])
else:
- post_vs = jnp.sum(syn_value)
+ post_vs = bm.sum(syn_value)
if not self.conn.include_self:
post_vs = post_vs - syn_value
post_vs = syn_weight * post_vs
@@ -1173,7 +1172,7 @@ def _syn2post_with_one2one(self, syn_value, syn_weight):
return syn_value * syn_weight

def _syn2post_with_dense(self, syn_value, syn_weight, conn_mat):
- if jnp.ndim(syn_weight) == 0:
+ if bm.ndim(syn_weight) == 0:
post_vs = (syn_weight * syn_value) @ conn_mat
else:
post_vs = syn_value @ (syn_weight * conn_mat)
@@ -1253,8 +1252,8 @@ def __init__(

# variables
self.V = variable(V_initializer, self.mode, self.varshape)
- self.input = variable(jnp.zeros, self.mode, self.varshape)
- self.spike = variable(lambda s: jnp.zeros(s, dtype=bool), self.mode, self.varshape)
+ self.input = variable(bm.zeros, self.mode, self.varshape)
+ self.spike = variable(lambda s: bm.zeros(s, dtype=bool), self.mode, self.varshape)

# function
if self.noise is None:
@@ -1271,8 +1270,8 @@ def derivative(self, V, t):

def reset_state(self, batch_size=None):
self.V.value = variable(self._V_initializer, batch_size, self.varshape)
- self.spike.value = variable(lambda s: jnp.zeros(s, dtype=bool), batch_size, self.varshape)
- self.input.value = variable(jnp.zeros, batch_size, self.varshape)
+ self.spike.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape)
+ self.input.value = variable(bm.zeros, batch_size, self.varshape)
for channel in self.nodes(level=1, include_self=False).subset(Channel).unique().values():
channel.reset_state(self.V.value, batch_size=batch_size)

@@ -1286,7 +1285,7 @@ def update(self, tdi, *args, **kwargs):
# update variables
for node in channels.values():
node.update(tdi, self.V.value)
- self.spike.value = jnp.logical_and(V >= self.V_th, self.V < self.V_th)
+ self.spike.value = bm.logical_and(V >= self.V_th, self.V < self.V_th)
self.V.value = V

def register_implicit_nodes(self, *channels, **named_channels):
@@ -1295,7 +1294,7 @@ def register_implicit_nodes(self, *channels, **named_channels):

def clear_input(self):
"""Useful for monitoring inputs. """
- self.input.value = jnp.zeros_like(self.input.value)
+ self.input.value = bm.zeros_like(self.input.value)


class Channel(DynamicalSystem):
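
The base.py changes follow two recurring patterns: type annotations and isinstance checks move from jnp.DeviceArray / jnp.ndarray to jax.Array (jnp.DeviceArray no longer exists in recent JAX releases), and array math such as ndim, sum, arange, zeros, and logical_and moves from jnp to brainpy.math so it also accepts bm.Array operands. A minimal sketch of both patterns, assuming a recent JAX release; syn_value, syn_weight, and conn are illustrative values, not taken from the PR:

import jax
import numpy as np
import brainpy.math as bm

syn_value = bm.random.rand(5)   # hypothetical synaptic values
syn_weight = 0.5                # scalar weight

# brainpy.math mirrors the numpy API but accepts bm.Array inputs directly.
if bm.ndim(syn_weight) == 0:
    post_vs = syn_weight * bm.sum(syn_value)

# isinstance checks target jax.Array, the public array type in recent JAX.
conn = np.ones((3, 4))
print(isinstance(conn, (bm.ndarray, np.ndarray, jax.Array)))  # True
print(post_vs)
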
21 changes: 10 additions & 11 deletions brainpy/_src/dyn/channels/IH.py
@@ -7,7 +7,6 @@

from typing import Union, Callable

- import jax.numpy as jnp
import brainpy.math as bm
from brainpy._src.initialize import Initializer, parameter, variable
from brainpy._src.integrators import odeint, JointEq
@@ -76,7 +75,7 @@ def __init__(
self.E = parameter(E, self.varshape, allow_none=False)

# variable
- self.p = variable(jnp.zeros, self.mode, self.varshape)
+ self.p = variable(bm.zeros, self.mode, self.varshape)

# function
self.integral = odeint(self.derivative, method=method)
@@ -96,10 +95,10 @@ def current(self, V):
return self.g_max * self.p * (self.E - V)

def f_p_inf(self, V):
- return 1. / (1. + jnp.exp((V + 75.) / 5.5))
+ return 1. / (1. + bm.exp((V + 75.) / 5.5))

def f_p_tau(self, V):
- return 1. / (jnp.exp(-0.086 * V - 14.59) + jnp.exp(0.0701 * V - 1.87))
+ return 1. / (bm.exp(-0.086 * V - 14.59) + bm.exp(0.0701 * V - 1.87))


class Ih_De1996(IhChannel, CalciumChannel):
Expand Down Expand Up @@ -200,9 +199,9 @@ def __init__(
self.g_inc = parameter(g_inc, self.varshape, allow_none=False)

# variable
- self.O = variable(jnp.zeros, self.mode, self.varshape)
- self.OL = variable(jnp.zeros, self.mode, self.varshape)
- self.P1 = variable(jnp.zeros, self.mode, self.varshape)
+ self.O = variable(bm.zeros, self.mode, self.varshape)
+ self.OL = variable(bm.zeros, self.mode, self.varshape)
+ self.P1 = variable(bm.zeros, self.mode, self.varshape)

# function
self.integral = odeint(JointEq(self.dO, self.dOL, self.dP1), method=method)
@@ -229,7 +228,7 @@ def current(self, V, C_Ca, E_Ca):

def reset_state(self, V, C_Ca, E_Ca, batch_size=None):
varshape = self.varshape if (batch_size is None) else ((batch_size,) + self.varshape)
- self.P1.value = jnp.broadcast_to(self.k1 * C_Ca ** 4 / (self.k1 * C_Ca ** 4 + self.k2), varshape)
+ self.P1.value = bm.broadcast_to(self.k1 * C_Ca ** 4 / (self.k1 * C_Ca ** 4 + self.k2), varshape)
inf = self.f_inf(V)
tau = self.f_tau(V)
alpha = inf / tau
@@ -242,8 +241,8 @@ def reset_state(self, V, C_Ca, E_Ca, batch_size=None):
assert self.OL.shape[0] == batch_size

def f_inf(self, V):
- return 1 / (1 + jnp.exp((V + 75 - self.V_sh) / 5.5))
+ return 1 / (1 + bm.exp((V + 75 - self.V_sh) / 5.5))

def f_tau(self, V):
- return (20. + 1000 / (jnp.exp((V + 71.5 - self.V_sh) / 14.2) +
-                        jnp.exp(-(V + 89 - self.V_sh) / 11.6))) / self.phi
+ return (20. + 1000 / (bm.exp((V + 71.5 - self.V_sh) / 14.2) +
+                        bm.exp(-(V + 89 - self.V_sh) / 11.6))) / self.phi
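
In IH.py the jax.numpy import is dropped entirely: state variables are now created with bm.zeros and the exponential gating functions use bm.exp. A minimal sketch of the converted gating functions for the Ih current (the constants match the hunks above; the test voltages and the comments on the output are illustrative, not taken from the PR):

import brainpy.math as bm

def f_p_inf(V):
    # steady-state activation of the Ih current
    return 1. / (1. + bm.exp((V + 75.) / 5.5))

def f_p_tau(V):
    # voltage-dependent activation time constant
    return 1. / (bm.exp(-0.086 * V - 14.59) + bm.exp(0.0701 * V - 1.87))

V = bm.asarray([-90., -75., -60.])  # hypothetical membrane potentials (mV)
print(f_p_inf(V))                   # activation grows toward 1 as V hyperpolarizes
print(f_p_tau(V))
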