6 changes: 3 additions & 3 deletions brainpy/__init__.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

__version__ = "2.2.3.1"
__version__ = "2.2.3.2"


try:
@@ -28,9 +28,9 @@

>>> pip install jaxlib -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

- Note that the versions of "jax" and "jaxlib" should be consistent, like "jax=0.3.14", "jaxlib=0.3.14".
+ Note that the versions of "jax" and "jaxlib" should be consistent, like "jax=0.3.14" and "jaxlib=0.3.14".

- More detail installation instruction, please see https://brainpy.readthedocs.io/en/latest/quickstart/installation.html#dependency-2-jax
+ For more detail installation instructions, please see https://brainpy.readthedocs.io/en/latest/quickstart/installation.html#dependency-2-jax

''') from None

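The docstring hunk above asks for matching "jax" and "jaxlib" versions. As an illustrative example of satisfying that note (0.3.14 is simply the version quoted in the docstring, not a recommendation), both packages can be pinned in one command:

>>> pip install jax==0.3.14 jaxlib==0.3.14 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html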
84 changes: 42 additions & 42 deletions brainpy/dyn/neurons/biological_models.py
@@ -4,11 +4,11 @@

import brainpy.math as bm
from brainpy.dyn.base import NeuGroup
- from brainpy.initialize import OneInit, Uniform, Initializer, parameter, noise as init_noise, variable
+ from brainpy.initialize import OneInit, Uniform, Initializer, parameter, noise as init_noise, variable_
from brainpy.integrators.joint_eq import JointEq
from brainpy.integrators.ode import odeint
from brainpy.integrators.sde import sdeint
- from brainpy.modes import Mode, BatchingMode, TrainingMode, NormalMode, normal, check_mode
+ from brainpy.modes import Mode, BatchingMode, NormalMode, normal, check_mode
from brainpy.tools.checking import check_initializer
from brainpy.types import Shape, Array

@@ -243,18 +243,18 @@ def __init__(
self._V_initializer = V_initializer

# variables
- self.V = variable(self._V_initializer, mode, self.varshape)
+ self.V = variable_(self._V_initializer, self.varshape, mode)
self.m = (bm.Variable(self.m_inf(self.V.value))
if m_initializer is None else
- variable(self._m_initializer, mode, self.varshape))
+ variable_(self._m_initializer, self.varshape, mode))
self.h = (bm.Variable(self.h_inf(self.V.value))
if h_initializer is None else
- variable(self._h_initializer, mode, self.varshape))
+ variable_(self._h_initializer, self.varshape, mode))
self.n = (bm.Variable(self.n_inf(self.V.value))
if n_initializer is None else
- variable(self._n_initializer, mode, self.varshape))
- self.spike = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape)
- self.input = variable(bm.zeros, mode, self.varshape)
+ variable_(self._n_initializer, self.varshape, mode))
+ self.spike = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, mode)
+ self.input = variable_(bm.zeros, self.varshape, mode)

# integral
if self.noise is None:
@@ -281,21 +281,21 @@ def __init__(
dn = lambda self, n, t, V: self.n_alpha(V) * (1 - n) - self.n_beta(V) * n

def reset_state(self, batch_size=None):
- self.V.value = variable(self._V_initializer, batch_size, self.varshape)
+ self.V.value = variable_(self._V_initializer, self.varshape, batch_size)
if self._m_initializer is None:
self.m.value = self.m_inf(self.V.value)
else:
- self.m.value = variable(self._m_initializer, batch_size, self.varshape)
+ self.m.value = variable_(self._m_initializer, self.varshape, batch_size)
if self._h_initializer is None:
self.h.value = self.h_inf(self.V.value)
else:
- self.h.value = variable(self._h_initializer, batch_size, self.varshape)
+ self.h.value = variable_(self._h_initializer, self.varshape, batch_size)
if self._n_initializer is None:
self.n.value = self.n_inf(self.V.value)
else:
- self.n.value = variable(self._n_initializer, batch_size, self.varshape)
- self.input.value = variable(bm.zeros, batch_size, self.varshape)
- self.spike.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape)
+ self.n.value = variable_(self._n_initializer, self.varshape, batch_size)
+ self.input.value = variable_(bm.zeros, self.varshape, batch_size)
+ self.spike.value = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size)

def dV(self, V, t, m, h, n, I_ext):
I_Na = (self.gNa * m ** 3.0 * h) * (V - self.ENa)
@@ -452,10 +452,10 @@ def __init__(
self._V_initializer = V_initializer

# variables
- self.W = variable(self._W_initializer, mode, self.varshape)
- self.V = variable(self._V_initializer, mode, self.varshape)
- self.input = variable(bm.zeros, mode, self.varshape)
- self.spike = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape)
+ self.W = variable_(self._W_initializer, self.varshape, mode)
+ self.V = variable_(self._V_initializer, self.varshape, mode)
+ self.input = variable_(bm.zeros, self.varshape, mode)
+ self.spike = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, mode)

# integral
if self.noise is None:
@@ -464,10 +464,10 @@ def __init__(
self.integral = sdeint(method=method, f=self.derivative, g=self.noise)

def reset_state(self, batch_size=None):
- self.W.value = variable(self._W_initializer, batch_size, self.varshape)
- self.V.value = variable(self._V_initializer, batch_size, self.varshape)
- self.input.value = variable(bm.zeros, batch_size, self.varshape)
- self.spike.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape)
+ self.W.value = variable_(self._W_initializer, self.varshape, batch_size)
+ self.V.value = variable_(self._V_initializer, self.varshape, batch_size)
+ self.input.value = variable_(bm.zeros, self.varshape, batch_size)
+ self.spike.value = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size)

def dV(self, V, t, W, I_ext):
M_inf = (1 / 2) * (1 + bm.tanh((V - self.V1) / self.V2))
@@ -718,16 +718,16 @@ def __init__(
self._Ca_initializer = Ca_initializer

# variables
- self.Vs = variable(self._Vs_initializer, mode, self.varshape)
- self.Vd = variable(self._Vd_initializer, mode, self.varshape)
- self.Ca = variable(self._Ca_initializer, mode, self.varshape)
+ self.Vs = variable_(self._Vs_initializer, self.varshape, mode)
+ self.Vd = variable_(self._Vd_initializer, self.varshape, mode)
+ self.Ca = variable_(self._Ca_initializer, self.varshape, mode)
self.h = bm.Variable(self.inf_h(self.Vs), batch_axis=0 if isinstance(mode, BatchingMode) else None)
self.n = bm.Variable(self.inf_n(self.Vs), batch_axis=0 if isinstance(mode, BatchingMode) else None)
self.s = bm.Variable(self.inf_s(self.Vd), batch_axis=0 if isinstance(mode, BatchingMode) else None)
self.c = bm.Variable(self.inf_c(self.Vd), batch_axis=0 if isinstance(mode, BatchingMode) else None)
self.q = bm.Variable(self.inf_q(self.Ca), batch_axis=0 if isinstance(mode, BatchingMode) else None)
- self.Id = variable(bm.zeros, mode, self.varshape) # input to soma
- self.Is = variable(bm.zeros, mode, self.varshape) # input to dendrite
+ self.Id = variable_(bm.zeros, self.varshape, mode) # input to soma
+ self.Is = variable_(bm.zeros, self.varshape, mode) # input to dendrite
# self.spike = bm.Variable(bm.zeros(self.varshape, dtype=bool))

# integral
@@ -737,17 +737,17 @@ def __init__(
self.integral = sdeint(method=method, f=self.derivative, g=self.noise)

def reset_state(self, batch_size=None):
- self.Vd.value = variable(self._Vd_initializer, batch_size, self.varshape)
- self.Vs.value = variable(self._Vs_initializer, batch_size, self.varshape)
- self.Ca.value = variable(self._Ca_initializer, batch_size, self.varshape)
+ self.Vd.value = variable_(self._Vd_initializer, self.varshape, batch_size)
+ self.Vs.value = variable_(self._Vs_initializer, self.varshape, batch_size)
+ self.Ca.value = variable_(self._Ca_initializer, self.varshape, batch_size)
batch_axis = 0 if isinstance(self.mode, BatchingMode) else None
self.h.value = bm.Variable(self.inf_h(self.Vs), batch_axis=batch_axis)
self.n.value = bm.Variable(self.inf_n(self.Vs), batch_axis=batch_axis)
self.s.value = bm.Variable(self.inf_s(self.Vd), batch_axis=batch_axis)
self.c.value = bm.Variable(self.inf_c(self.Vd), batch_axis=batch_axis)
self.q.value = bm.Variable(self.inf_q(self.Ca), batch_axis=batch_axis)
- self.Id.value = variable(bm.zeros, batch_size, self.varshape)
- self.Is.value = variable(bm.zeros, batch_size, self.varshape)
+ self.Id.value = variable_(bm.zeros, self.varshape, batch_size)
+ self.Is.value = variable_(bm.zeros, self.varshape, batch_size)
# self.spike[:] = False

def dCa(self, Ca, t, s, Vd):
@@ -1017,11 +1017,11 @@ def __init__(
self._V_initializer = V_initializer

# variables
- self.h = variable(self._h_initializer, mode, self.varshape)
- self.n = variable(self._n_initializer, mode, self.varshape)
- self.V = variable(self._V_initializer, mode, self.varshape)
- self.input = variable(bm.zeros, mode, self.varshape)
- self.spike = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape)
+ self.h = variable_(self._h_initializer, self.varshape, mode)
+ self.n = variable_(self._n_initializer, self.varshape, mode)
+ self.V = variable_(self._V_initializer, self.varshape, mode)
+ self.input = variable_(bm.zeros, self.varshape, mode)
+ self.spike = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, mode)

# integral
if self.noise is None:
@@ -1030,11 +1030,11 @@ def __init__(
self.integral = sdeint(method=method, f=self.derivative, g=self.noise)

def reset_state(self, batch_size=None):
- self.h.value = variable(self._h_initializer, batch_size, self.varshape)
- self.n.value = variable(self._n_initializer, batch_size, self.varshape)
- self.V.value = variable(self._V_initializer, batch_size, self.varshape)
- self.input.value = variable(bm.zeros, batch_size, self.varshape)
- self.spike.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape)
+ self.h.value = variable_(self._h_initializer, self.varshape, batch_size)
+ self.n.value = variable_(self._n_initializer, self.varshape, batch_size)
+ self.V.value = variable_(self._V_initializer, self.varshape, batch_size)
+ self.input.value = variable_(bm.zeros, self.varshape, batch_size)
+ self.spike.value = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size)

def m_inf(self, V):
alpha = -0.1 * (V + 35) / (bm.exp(-0.1 * (V + 35)) - 1)
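Read as a whole, the change in this file is mechanical: every call of the old form variable(init, mode, self.varshape) (or variable(init, batch_size, self.varshape) inside reset_state) becomes a call to the renamed helper variable_ with the shape moved ahead of the mode or batch size. A minimal, self-contained sketch of the new call order, inferred only from the positional usage visible in the diff (the shape (10,), the batch size 8, and the local names are illustrative, not from the PR):

import brainpy.math as bm
from brainpy.initialize import variable_
from brainpy.modes import normal

varshape = (10,)  # hypothetical neuron-group shape
V = variable_(bm.zeros, varshape, normal)  # old order was variable(bm.zeros, normal, varshape)
spike = variable_(lambda s: bm.zeros(s, dtype=bool), varshape, normal)  # bool spike buffer, as in the diff
V_batched = variable_(bm.zeros, varshape, 8)  # in reset_state, a batch size is passed in place of a mode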
10 changes: 5 additions & 5 deletions brainpy/dyn/neurons/input_groups.py
@@ -7,7 +7,7 @@
import brainpy.math as bm
from brainpy.dyn.base import NeuGroup
from brainpy.errors import ModelBuildError
- from brainpy.initialize import Initializer, parameter, variable
+ from brainpy.initialize import Initializer, parameter, variable_
from brainpy.modes import Mode, BatchingMode, normal
from brainpy.types import Shape, Array

@@ -139,7 +139,7 @@ def __init__(

# variables
self.i = bm.Variable(bm.zeros(1, dtype=bm.ditype()))
- self.spike = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape)
+ self.spike = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, mode)
if need_sort:
sort_idx = bm.argsort(self.times)
self.indices.value = self.indices[sort_idx]
@@ -162,7 +162,7 @@ def body_fun(t):

def reset_state(self, batch_size=None):
self.i[0] = 1
- self.spike.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape)
+ self.spike.value = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size)

def update(self, tdi, x=None):
self.spike[:] = False
@@ -193,7 +193,7 @@ def __init__(
self.freqs = parameter(freqs, self.num, allow_none=False)

# variables
- self.spike = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape)
+ self.spike = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, mode)
self.rng = bm.random.RandomState(seed=seed)

def update(self, tdi, x=None):
@@ -205,5 +205,5 @@ def reset(self, batch_size=None):
self.reset_state(batch_size)

def reset_state(self, batch_size=None):
- self.spike.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape)
+ self.spike.value = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size)

4 changes: 2 additions & 2 deletions brainpy/dyn/neurons/noise_groups.py
@@ -63,13 +63,13 @@ def __init__(
self.tau = init.parameter(tau, self.varshape, allow_none=False)

# variables
- self.x = init.variable(lambda s: bm.ones(s) * self.mean, mode, self.varshape)
+ self.x = init.variable_(lambda s: bm.ones(s) * self.mean, self.varshape, mode)

# integral functions
self.integral = sdeint(f=self.df, g=self.dg, method=method)

def reset_state(self, batch_size=None):
- self.x.value = init.variable(lambda s: bm.ones(s) * self.mean, batch_size, self.varshape)
+ self.x.value = init.variable_(lambda s: bm.ones(s) * self.mean, self.varshape, batch_size)

def df(self, x, t):
return (self.mean - x) / self.tau