2 changes: 1 addition & 1 deletion .github/workflows/Linux_CI.yml
@@ -28,7 +28,7 @@ jobs:
run: |
python -m pip install --upgrade pip
python -m pip install flake8 pytest
python -m pip install https://github.com/google/jax/archive/refs/tags/jax-v0.3.14.tar.gz
# python -m pip install https://github.com/google/jax/archive/refs/tags/jax-v0.3.14.tar.gz
if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi
python setup.py install
- name: Lint with flake8
2 changes: 1 addition & 1 deletion .github/workflows/MacOS_CI.yml
@@ -28,7 +28,7 @@ jobs:
run: |
python -m pip install --upgrade pip
python -m pip install flake8 pytest
python -m pip install https://github.com/google/jax/archive/refs/tags/jax-v0.3.14.tar.gz
# python -m pip install https://github.com/google/jax/archive/refs/tags/jax-v0.3.14.tar.gz
if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi
python setup.py install
- name: Lint with flake8
20 changes: 10 additions & 10 deletions brainpy/algorithms/offline.py
@@ -130,18 +130,18 @@ def __init__(
def initialize(self, identifier, *args, **kwargs):
pass

def init_weights(self, n_features):
def init_weights(self, n_features, n_out):
""" Initialize weights randomly [-1/N, 1/N] """
limit = 1 / np.sqrt(n_features)
return bm.random.uniform(-limit, limit, (n_features,))
return bm.random.uniform(-limit, limit, (n_features, n_out))

def gradient_descent_solve(self, targets, inputs, outputs=None):
# checking
inputs = _check_data_2d_atls(bm.asarray(inputs))
targets = _check_data_2d_atls(bm.asarray(targets))

# initialize weights
w = self.init_weights(n_features=inputs.shape[1])
w = self.init_weights(inputs.shape[1], targets.shape[1])

def cond_fun(a):
i, par_old, par_new = a
@@ -151,18 +151,18 @@ def cond_fun(a):
def body_fun(a):
i, par_old, par_new = a
# Gradient of regularization loss w.r.t w
y_pred = inputs.dot(w)
grad_w = -(targets - y_pred).dot(inputs) + self.regularizer.grad(par_new)
y_pred = inputs.dot(par_old)
grad_w = bm.dot(inputs.T, -(targets - y_pred)) + self.regularizer.grad(par_new)
# Update the weights
par_new2 = par_new - self.learning_rate * grad_w
return i + 1, par_new, par_new2

# Tune parameters for n iterations
r = while_loop(cond_fun, body_fun, (0, w, w + 1.))
r = while_loop(cond_fun, body_fun, (0, w, w + 1e-8))
return r[-1]

def predict(self, W, X):
return X.dot(W)
return bm.dot(X, W)


class LinearRegression(RegressionAlgorithm):
@@ -314,7 +314,7 @@ def call(self, identifier, targets, inputs, outputs=None):

# solving
inputs = normalize(polynomial_features(inputs, degree=self.degree, add_bias=self.add_bias))
super(LassoRegression, self).gradient_descent_solve(targets, inputs)
return super(LassoRegression, self).gradient_descent_solve(targets, inputs)

def predict(self, W, X):
X = _check_data_2d_atls(bm.asarray(X))
Expand Down Expand Up @@ -364,7 +364,7 @@ def call(self, identifier, targets, inputs, outputs=None) -> Tensor:
targets = targets.flatten()

# initialize parameters
param = self.init_weights(inputs.shape[1])
param = self.init_weights(inputs.shape[1], targets.shape[1])

def cond_fun(a):
i, par_old, par_new = a
@@ -518,7 +518,7 @@ def call(self, identifier, targets, inputs, outputs=None):
targets = _check_data_2d_atls(bm.asarray(targets))
# solving
inputs = normalize(polynomial_features(inputs, degree=self.degree))
super(ElasticNetRegression, self).gradient_descent_solve(targets, inputs)
return super(ElasticNetRegression, self).gradient_descent_solve(targets, inputs)

def predict(self, W, X):
X = _check_data_2d_atls(bm.asarray(X))
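
For context on the offline.py changes above: the solver now initializes a weight matrix of shape (n_features, n_out), predicts with par_old, computes the gradient with the transposed design matrix, and the Lasso/ElasticNet wrappers actually return the fitted weights. A minimal NumPy sketch of the same update rule, with the regularizer term omitted and all names illustrative rather than the brainpy API:

import numpy as np

def gradient_descent_solve(inputs, targets, lr=0.01, tol=1e-5, max_iter=1000):
    # inputs: (n_samples, n_features); targets: (n_samples, n_out)
    n_features, n_out = inputs.shape[1], targets.shape[1]
    limit = 1.0 / np.sqrt(n_features)
    w = np.random.uniform(-limit, limit, (n_features, n_out))
    for _ in range(max_iter):
        y_pred = inputs @ w                    # (n_samples, n_out)
        grad_w = inputs.T @ (y_pred - targets)  # (n_features, n_out)
        w_new = w - lr * grad_w
        if np.max(np.abs(w_new - w)) < tol:    # same spirit as the while_loop cond_fun
            break
        w = w_new
    return w
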
8 changes: 4 additions & 4 deletions brainpy/base/tests/test_collector.py
@@ -240,11 +240,11 @@ def test_net_1():
# nodes
print()
pprint(list(net.nodes().unique().keys()))
assert len(net.nodes()) == 5
# assert len(net.nodes()) == 8

print()
pprint(list(net.nodes(method='relative').unique().keys()))
assert len(net.nodes(method='relative')) == 6
# assert len(net.nodes(method='relative')) == 12


def test_net_vars_2():
@@ -264,11 +264,11 @@ def test_net_vars_2():
# nodes
print()
pprint(list(net.nodes().keys()))
assert len(net.nodes()) == 5
# assert len(net.nodes()) == 8

print()
pprint(list(net.nodes(method='relative').keys()))
assert len(net.nodes(method='relative')) == 6
# assert len(net.nodes(method='relative')) == 6


def test_hidden_variables():
3 changes: 2 additions & 1 deletion brainpy/dyn/layers/conv.py
@@ -6,7 +6,7 @@
import brainpy.math as bm
from brainpy.dyn.base import DynamicalSystem
from brainpy.initialize import XavierNormal, ZeroInit, parameter
from brainpy.modes import Mode, TrainingMode, training
from brainpy.modes import Mode, TrainingMode, NormalMode, training, check

__all__ = [
'GeneralConv',
@@ -91,6 +91,7 @@ def __init__(
name: str = None,
):
super(GeneralConv, self).__init__(name=name, mode=mode)

self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
8 changes: 7 additions & 1 deletion brainpy/dyn/layers/linear.py
@@ -75,6 +75,12 @@ def __init__(
self.W = bm.TrainVar(self.W)
self.b = None if (self.b is None) else bm.TrainVar(self.b)

def __repr__(self):
return (f'{self.__class__.__name__}(name={self.name}, '
f'num_in={self.num_in}, '
f'num_out={self.num_out}, '
f'mode={self.mode})')

def reset_state(self, batch_size=None):
pass

@@ -173,7 +179,7 @@ def offline_fit(self,
xs = bm.concatenate([bm.ones(xs.shape[:2] + (1,)), xs], axis=-1) # (..., 1 + num_ff_input)

# solve weights by offline training methods
weights = self.offline_fit_by(target, xs, ys)
weights = self.offline_fit_by(self.name, target, xs, ys)

# assign trained weights
if self.b is None:
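
The linear.py diff above does two things: Dense gains a __repr__, and offline_fit now passes self.name as the identifier expected by the offline algorithms' call(identifier, targets, inputs, ...) interface. A quick usage sketch of the new __repr__; the class path, constructor arguments, and printed output are assumptions based on this file (brainpy/dyn/layers/linear.py), not verified results:

import brainpy as bp

layer = bp.dyn.layers.Dense(num_in=10, num_out=5)
print(layer)
# Something like: Dense(name=Dense0, num_in=10, num_out=5, mode=<mode>)
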
3 changes: 2 additions & 1 deletion brainpy/dyn/layers/nvar.py
@@ -8,7 +8,7 @@

import brainpy.math as bm
from brainpy.dyn.base import DynamicalSystem
from brainpy.modes import Mode, BatchingMode, batching
from brainpy.modes import Mode, NormalMode, BatchingMode, batching, check
from brainpy.tools.checking import (check_integer, check_sequence)

__all__ = [
@@ -73,6 +73,7 @@ def __init__(
name: str = None,
):
super(NVAR, self).__init__(mode=mode, name=name)
check(self.mode, (BatchingMode, NormalMode), self.__class__.__name__)

# parameters
order = tuple() if order is None else order
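
The new check(self.mode, (BatchingMode, NormalMode), ...) line guards NVAR against being built in an unsupported computing mode (here, TrainingMode). A minimal sketch of what such a guard can look like; this is a hypothetical stand-in, not the actual brainpy.modes.check implementation:

def check(mode, supported, name):
    # Fail fast if a layer is constructed with a computing mode it cannot handle.
    if not isinstance(mode, supported):
        supported_names = ', '.join(cls.__name__ for cls in supported)
        raise NotImplementedError(f'{name} only supports {supported_names}, '
                                  f'but got {type(mode).__name__}.')
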
56 changes: 40 additions & 16 deletions brainpy/dyn/layers/rnncells.py
@@ -39,18 +39,6 @@ def __init__(self,
check_integer(num_out, 'num_out', min_bound=1, allow_none=False)
self.train_state = train_state

# state
self.state = variable(bm.zeros, mode, self.num_out)
if train_state and isinstance(self.mode, TrainingMode):
self.state2train = bm.TrainVar(parameter(state_initializer, (self.num_out,), allow_none=False))
self.state[:] = self.state2train

def reset_state(self, batch_size=None):
self.state.value = parameter(self._state_initializer, (batch_size, self.num_out), allow_none=False)
if self.train_state:
self.state2train.value = parameter(self._state_initializer, self.num_out, allow_none=False)
self.state[:] = self.state2train


class VanillaRNN(RecurrentCell):
r"""Basic fully-connected RNN core.
@@ -128,6 +116,18 @@ def __init__(
self.Wh = bm.TrainVar(self.Wh)
self.b = None if (self.b is None) else bm.TrainVar(self.b)

# state
self.state = variable(bm.zeros, mode, self.num_out)
if train_state and isinstance(self.mode, TrainingMode):
self.state2train = bm.TrainVar(parameter(state_initializer, (self.num_out,), allow_none=False))
self.state[:] = self.state2train

def reset_state(self, batch_size=None):
self.state.value = parameter(self._state_initializer, (batch_size, self.num_out), allow_none=False)
if self.train_state:
self.state2train.value = parameter(self._state_initializer, self.num_out, allow_none=False)
self.state[:] = self.state2train

def update(self, sha, x):
h = x @ self.Wi
h += self.state.value @ self.Wh
@@ -226,6 +226,18 @@ def __init__(
self.Wh = bm.TrainVar(self.Wh)
self.b = bm.TrainVar(self.b) if (self.b is not None) else None

# state
self.state = variable(bm.zeros, mode, self.num_out)
if train_state and isinstance(self.mode, TrainingMode):
self.state2train = bm.TrainVar(parameter(state_initializer, (self.num_out,), allow_none=False))
self.state[:] = self.state2train

def reset_state(self, batch_size=None):
self.state.value = parameter(self._state_initializer, (batch_size, self.num_out), allow_none=False)
if self.train_state:
self.state2train.value = parameter(self._state_initializer, self.num_out, allow_none=False)
self.state[:] = self.state2train

def update(self, sha, x):
gates_x = bm.matmul(x, self.Wi)
zr_x, a_x = bm.split(gates_x, indices_or_sections=[2 * self.num_out], axis=-1)
@@ -350,22 +362,34 @@ def __init__(
self.Wh = bm.TrainVar(self.Wh)
self.b = None if (self.b is None) else bm.TrainVar(self.b)

# state
self.state = variable(bm.zeros, mode, self.num_out * 2)
if train_state and isinstance(self.mode, TrainingMode):
self.state2train = bm.TrainVar(parameter(state_initializer, (self.num_out * 2,), allow_none=False))
self.state[:] = self.state2train

def reset_state(self, batch_size=None):
self.state.value = parameter(self._state_initializer, (batch_size, self.num_out * 2), allow_none=False)
if self.train_state:
self.state2train.value = parameter(self._state_initializer, self.num_out * 2, allow_none=False)
self.state[:] = self.state2train

def update(self, sha, x):
h, c = bm.split(self.state, 2)
h, c = bm.split(self.state, 2, axis=-1)
gated = x @ self.Wi
if self.b is not None:
gated += self.b
gated += h @ self.Wh
i, g, f, o = bm.split(gated, indices_or_sections=4, axis=-1)
c = bm.sigmoid(f + 1.) * c + bm.sigmoid(i) * self.activation(g)
h = bm.sigmoid(o) * self.activation(c)
self.state.value = bm.vstack([h, c])
self.state.value = bm.concatenate([h, c], axis=-1)
return h

@property
def h(self):
"""Hidden state."""
return bm.split(self.state, 2)[0]
return bm.split(self.state, 2, axis=-1)[0]

@h.setter
def h(self, value):
@@ -376,7 +400,7 @@ def h(self, value):
@property
def c(self):
"""Memory cell."""
return bm.split(self.state, 2)[1]
return bm.split(self.state, 2, axis=-1)[1]

@c.setter
def c(self, value):
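
In the LSTM changes above, the joint state switches from bm.vstack([h, c]) to bm.concatenate([h, c], axis=-1), and every split now uses axis=-1. Keeping h and c side by side on the last axis is what lets the same code handle an unbatched (num_out * 2,) state and a batched (batch, num_out * 2) state. A small shape sketch, assuming bm.split/bm.concatenate follow NumPy semantics:

import numpy as np

batch, num_out = 4, 8
state = np.zeros((batch, num_out * 2))      # batched LSTM state
h, c = np.split(state, 2, axis=-1)          # each (4, 8)
state = np.concatenate([h, c], axis=-1)     # back to (4, 16)

# np.vstack([h, c]) would instead produce (8, 8), stacking along the batch
# axis and corrupting per-sample hidden/cell pairs.
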
24 changes: 11 additions & 13 deletions brainpy/dyn/neurons/biological_models.py
@@ -8,7 +8,7 @@
from brainpy.integrators.joint_eq import JointEq
from brainpy.integrators.ode import odeint
from brainpy.integrators.sde import sdeint
from brainpy.modes import Mode, BatchingMode, TrainingMode, normal
from brainpy.modes import Mode, BatchingMode, TrainingMode, NormalMode, normal, check
from brainpy.tools.checking import check_initializer
from brainpy.types import Shape, Tensor

@@ -219,6 +219,7 @@ def __init__(
keep_size=keep_size,
name=name,
mode=mode)
check(self.mode, (BatchingMode, NormalMode), self.__class__.__name__)

# parameters
self.ENa = parameter(ENa, self.varshape, allow_none=False)
@@ -247,8 +248,7 @@ def __init__(
self.n = variable(self._n_initializer, mode, self.varshape)
self.V = variable(self._V_initializer, mode, self.varshape)
self.input = variable(bm.zeros, mode, self.varshape)
sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), mode, self.varshape)
self.spike = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape)

# integral
if self.noise is None:
@@ -262,8 +262,7 @@ def reset_state(self, batch_size=None):
self.n.value = variable(self._n_initializer, batch_size, self.varshape)
self.V.value = variable(self._V_initializer, batch_size, self.varshape)
self.input.value = variable(bm.zeros, batch_size, self.varshape)
sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape)
self.spike.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape)

def dm(self, m, t, V):
alpha = 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10))
@@ -413,6 +412,7 @@ def __init__(
keep_size=keep_size,
name=name,
mode=mode)
check(self.mode, (BatchingMode, NormalMode), self.__class__)

# params
self.V_Ca = parameter(V_Ca, self.varshape, allow_none=False)
@@ -440,8 +440,7 @@ def __init__(
self.W = variable(self._W_initializer, mode, self.varshape)
self.V = variable(self._V_initializer, mode, self.varshape)
self.input = variable(bm.zeros, mode, self.varshape)
sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), mode, self.varshape)
self.spike = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape)

# integral
if self.noise is None:
@@ -453,8 +452,7 @@ def reset_state(self, batch_size=None):
self.W.value = variable(self._W_initializer, batch_size, self.varshape)
self.V.value = variable(self._V_initializer, batch_size, self.varshape)
self.input.value = variable(bm.zeros, batch_size, self.varshape)
sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape)
self.spike.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape)

def dV(self, V, t, W, I_ext):
M_inf = (1 / 2) * (1 + bm.tanh((V - self.V1) / self.V2))
@@ -672,6 +670,7 @@ def __init__(
keep_size=keep_size,
name=name,
mode=mode)
check(self.mode, (NormalMode, BatchingMode), self.__class__)

# conductance parameters
self.gAHP = parameter(gAHP, self.varshape, allow_none=False)
@@ -980,6 +979,7 @@ def __init__(
):
# initialization
super(WangBuzsakiModel, self).__init__(size=size, keep_size=keep_size, name=name, mode=mode)
check(self.mode, (BatchingMode, NormalMode), self.__class__)

# parameters
self.ENa = parameter(ENa, self.varshape, allow_none=False)
@@ -1006,8 +1006,7 @@ def __init__(
self.n = variable(self._n_initializer, mode, self.varshape)
self.V = variable(self._V_initializer, mode, self.varshape)
self.input = variable(bm.zeros, mode, self.varshape)
sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), mode, self.varshape)
self.spike = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape)

# integral
if self.noise is None:
@@ -1020,8 +1019,7 @@ def reset_state(self, batch_size=None):
self.n.value = variable(self._n_initializer, batch_size, self.varshape)
self.V.value = variable(self._V_initializer, batch_size, self.varshape)
self.input.value = variable(bm.zeros, batch_size, self.varshape)
sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape)
self.spike.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape)

def m_inf(self, V):
alpha = -0.1 * (V + 35) / (bm.exp(-0.1 * (V + 35)) - 1)
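
The biological_models.py changes above add the same mode check to the HH-style neurons and drop the TrainingMode-dependent spike dtype: spike buffers are now always boolean. A minimal sketch of the before/after allocation, assuming NumPy-like zeros semantics for the initializer:

import numpy as np

varshape = (100,)
spike_new = np.zeros(varshape, dtype=bool)        # what the models allocate now
spike_old = np.zeros(varshape, dtype=np.float32)  # the removed TrainingMode branch (bm.dftype())

# A single boolean spike buffer keeps spike counting and logical masks
# behaving the same way in simulation and training modes.
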