4 changes: 0 additions & 4 deletions .github/PULL_REQUEST_TEMPLATE.md
@@ -14,10 +14,6 @@
<!--- For example, markdown files should pass markdownlint locally according to the rules -->
<!--- See how your change affects other areas of the code, etc. -->

- ## Screenshots(optional)
- <!--- If Screenshots is not necessary or not available in this pull request, you can delete this section -->
- <!--- Changes including html and css are required to have screenshots -->
-
## Types of changes
<!--- What types of changes does your code introduce? -->
<!--- Only left the line that best describes this pull request -->
2 changes: 1 addition & 1 deletion .github/workflows/MacOS_CI.yml
@@ -36,7 +36,7 @@ jobs:
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
- flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
+ # flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Test with pytest
run: |
pytest brainpy/
2 changes: 1 addition & 1 deletion .github/workflows/Windows_CI.yml
@@ -39,7 +39,7 @@ jobs:
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
# flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Test with pytest
run: |
pytest brainpy/
5 changes: 0 additions & 5 deletions brainpy/__init__.py
@@ -62,16 +62,11 @@
synouts, # synaptic output
synplast, # synaptic plasticity
)
from .dyn.base import *
from .dyn.runners import *


# dynamics training
from . import train
- from .train.base import *
- from .train.online import *
- from .train.offline import *
- from .train.back_propagation import *


# automatic dynamics analysis
47 changes: 35 additions & 12 deletions brainpy/algorithms/offline.py
@@ -39,11 +39,35 @@ class OfflineAlgorithm(Base):
def __init__(self, name=None):
super(OfflineAlgorithm, self).__init__(name=name)

- def __call__(self, targets, inputs, outputs) -> Tensor:
+ def __call__(self, identifier, target, input, output):
"""The training procedure.

Parameters
----------
+ identifier: str
+ The variable name.
+ target: JaxArray, ndarray
+ The 2d target data with the shape of `(num_batch, num_output)`.
+ input: JaxArray, ndarray
+ The 2d input data with the shape of `(num_batch, num_input)`.
+ output: JaxArray, ndarray
+ The 2d output data with the shape of `(num_batch, num_output)`.
+
+ Returns
+ -------
+ weight: JaxArray
+ The weights after fit.
+ """
+ return self.call(identifier, target, input, output)
+
+ def call(self, identifier, targets, inputs, outputs) -> Tensor:
+ """The training procedure.
+
+ Parameters
+ ----------
+ identifier: str
+ The identifier.
+
inputs: JaxArray, jax.numpy.ndarray, numpy.ndarray
The 3d input data with the shape of `(num_batch, num_time, num_input)`,
or, the 2d input data with the shape of `(num_time, num_input)`.
@@ -67,8 +91,7 @@ def __repr__(self):
return self.__class__.__name__

def initialize(self, identifier, *args, **kwargs):
- raise NotImplementedError('Must implement the initialize() '
- 'function by the subclass itself.')
+ pass


def _check_data_2d_atls(x):
@@ -166,7 +189,7 @@ def __init__(
regularizer=Regularization(0.))
self.gradient_descent = gradient_descent

- def __call__(self, targets, inputs, outputs=None):
+ def call(self, identifier, targets, inputs, outputs=None):
# checking
inputs = _check_data_2d_atls(bm.asarray(inputs))
targets = _check_data_2d_atls(bm.asarray(targets))
@@ -225,7 +248,7 @@ def __init__(
regularizer=L2Regularization(alpha=alpha))
self.gradient_descent = gradient_descent

- def __call__(self, targets, inputs, outputs=None):
+ def call(self, identifier, targets, inputs, outputs=None):
# checking
inputs = _check_data_2d_atls(bm.asarray(inputs))
targets = _check_data_2d_atls(bm.asarray(targets))
@@ -284,7 +307,7 @@ def __init__(
assert gradient_descent
self.degree = degree

- def __call__(self, targets, inputs, outputs=None):
+ def call(self, identifier, targets, inputs, outputs=None):
# checking
inputs = _check_data_2d_atls(bm.asarray(inputs))
targets = _check_data_2d_atls(bm.asarray(targets))
@@ -332,7 +355,7 @@ def __init__(
self.gradient_descent = gradient_descent
self.sigmoid = Sigmoid()

- def __call__(self, targets, inputs, outputs=None) -> Tensor:
+ def call(self, identifier, targets, inputs, outputs=None) -> Tensor:
# prepare data
inputs = _check_data_2d_atls(bm.asarray(inputs))
targets = _check_data_2d_atls(bm.asarray(targets))
@@ -395,11 +418,11 @@ def __init__(
self.degree = degree
self.add_bias = add_bias

- def __call__(self, targets, inputs, outputs=None):
+ def call(self, identifier, targets, inputs, outputs=None):
inputs = _check_data_2d_atls(bm.asarray(inputs))
targets = _check_data_2d_atls(bm.asarray(targets))
inputs = polynomial_features(inputs, degree=self.degree, add_bias=self.add_bias)
- return super(PolynomialRegression, self).__call__(targets, inputs)
+ return super(PolynomialRegression, self).call(identifier, targets, inputs)

def predict(self, W, X):
X = _check_data_2d_atls(bm.asarray(X))
@@ -431,12 +454,12 @@ def __init__(
self.degree = degree
self.add_bias = add_bias

- def __call__(self, targets, inputs, outputs=None):
+ def call(self, identifier, targets, inputs, outputs=None):
# checking
inputs = _check_data_2d_atls(bm.asarray(inputs))
targets = _check_data_2d_atls(bm.asarray(targets))
inputs = polynomial_features(inputs, degree=self.degree, add_bias=self.add_bias)
- return super(PolynomialRidgeRegression, self).__call__(targets, inputs)
+ return super(PolynomialRidgeRegression, self).call(identifier, targets, inputs)

def predict(self, W, X):
X = _check_data_2d_atls(bm.asarray(X))
@@ -489,7 +512,7 @@ def __init__(
self.gradient_descent = gradient_descent
assert gradient_descent

- def __call__(self, targets, inputs, outputs=None):
+ def call(self, identifier, targets, inputs, outputs=None):
# checking
inputs = _check_data_2d_atls(bm.asarray(inputs))
targets = _check_data_2d_atls(bm.asarray(targets))
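The offline refactor above changes the training entry point: `__call__` now takes an `identifier` plus 2d `target`/`input`/`output` arrays and simply forwards to a new `call()` method, which subclasses override instead of `__call__`. A minimal sketch of a subclass under the new interface; the `MeanTarget` class, its data, and all shapes are invented for illustration:

```python
# Hypothetical subclass of the refactored OfflineAlgorithm: override call(),
# not __call__; the base __call__ forwards (identifier, target, input, output).
import brainpy.math as bm
from brainpy.algorithms.offline import OfflineAlgorithm

class MeanTarget(OfflineAlgorithm):
  """Toy 'fit': ignore the inputs and return the per-output mean of the targets."""

  def call(self, identifier, targets, inputs, outputs=None):
    targets = bm.asarray(targets)    # (num_batch, num_output)
    return bm.mean(targets, axis=0)  # one 'weight' per output dimension

algo = MeanTarget()
w = algo('readout.W', bm.ones((8, 3)), bm.zeros((8, 5)), None)  # goes through __call__
print(w.shape)  # (3,)
```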
19 changes: 10 additions & 9 deletions brainpy/algorithms/online.py
@@ -2,6 +2,8 @@

import brainpy.math as bm
from brainpy.base import Base
+ from jax import vmap
+ import jax.numpy as jnp

__all__ = [
# base class
@@ -25,12 +27,12 @@ class OnlineAlgorithm(Base):
def __init__(self, name=None):
super(OnlineAlgorithm, self).__init__(name=name)

- def __call__(self, name, target, input, output):
+ def __call__(self, identifier, target, input, output):
"""The training procedure.

Parameters
----------
- name: str
+ identifier: str
The variable name.
target: JaxArray, ndarray
The 2d target data with the shape of `(num_batch, num_output)`.
@@ -44,11 +46,10 @@ def __call__(self, name, target, input, output):
weight: JaxArray
The weights after fit.
"""
- return self.call(name, target, input, output)
+ return self.call(identifier, target, input, output)

def initialize(self, identifier, *args, **kwargs):
- raise NotImplementedError('Must implement the initialize() '
- 'function by the subclass itself.')
+ pass

def call(self, identifier, target, input, output):
"""The training procedure.
@@ -146,11 +147,11 @@ def __init__(self, alpha=0.1, name=None):
super(LMS, self).__init__(name=name)
self.alpha = alpha

- def initialize(self, identifier, *args, **kwargs):
- pass
-
def call(self, identifier, target, input, output):
- return -self.alpha * bm.dot(output - target, output)
+ assert target.shape[0] == input.shape[0] == output.shape[0], 'Batch size should be consistent.'
+ error = bm.as_jax(output - target)
+ input = bm.as_jax(input)
+ return -self.alpha * bm.sum(vmap(jnp.outer)(input, error), axis=0)


name2func['lms'] = LMS
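The LMS change above is the substantive one in this file: the old `-alpha * bm.dot(output - target, output)` expression is replaced by a true batched least-mean-squares update, one `outer(input, error)` product per sample summed over the batch (hence the new `vmap`/`jnp.outer` imports and the batch-size assertion). A standalone sketch of that contraction, with invented shapes:

```python
# For inputs (B, n_in) and errors (B, n_out), vmap(jnp.outer) produces the
# per-sample outer products (B, n_in, n_out); summing over the batch axis
# gives the (n_in, n_out) weight update used by the new LMS.call().
import jax.numpy as jnp
from jax import vmap

alpha = 0.1
inputs = jnp.ones((4, 3))        # (B, n_in)
errors = jnp.full((4, 2), 0.5)   # (B, n_out), i.e. output - target

dW = -alpha * jnp.sum(vmap(jnp.outer)(inputs, errors), axis=0)
dW_ref = -alpha * jnp.einsum('bi,bo->io', inputs, errors)  # same contraction
assert jnp.allclose(dW, dW_ref)
print(dW.shape)  # (3, 2)
```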
21 changes: 17 additions & 4 deletions brainpy/analysis/highdim/slow_points.py
@@ -340,7 +340,7 @@ def find_fps_with_gd_method(
f_eval_loss = self._get_f_eval_loss()

def f_loss():
- return f_eval_loss(tree_map(lambda a: a.value,
+ return f_eval_loss(tree_map(lambda a: bm.as_device_array(a),
fixed_points,
is_leaf=lambda x: isinstance(x, bm.JaxArray))).mean()

@@ -386,9 +386,11 @@ def batch_train(start_i, n_batch):
f'is below tolerance {tolerance:0.10f}.')

self._opt_losses = bm.concatenate(opt_losses)
- self._losses = f_eval_loss(tree_map(lambda a: a.value, fixed_points,
+ self._losses = f_eval_loss(tree_map(lambda a: bm.as_device_array(a),
+ fixed_points,
is_leaf=lambda x: isinstance(x, bm.JaxArray)))
- self._fixed_points = tree_map(lambda a: a.value, fixed_points,
+ self._fixed_points = tree_map(lambda a: bm.as_device_array(a),
+ fixed_points,
is_leaf=lambda x: isinstance(x, bm.JaxArray))
self._selected_ids = jnp.arange(num_candidate)

@@ -425,7 +427,7 @@ def find_fps_with_opt_solver(
print(f"Optimizing with {opt_solver} to find fixed points:")

# optimizing
- res = f_opt(tree_map(lambda a: a.value,
+ res = f_opt(tree_map(lambda a: bm.as_device_array(a),
candidates,
is_leaf=lambda a: isinstance(a, bm.JaxArray)))

@@ -720,16 +722,27 @@ def _generate_ds_cell_function(
shared = DotDict(t=t, dt=dt, i=0)

def f_cell(h: Dict):
+ target.clear_input()

+ # update target variables
for k, v in self.target_vars.items():
v.value = (bm.asarray(h[k], dtype=v.dtype)
if v.batch_axis is None else
bm.asarray(bm.expand_dims(h[k], axis=v.batch_axis), dtype=v.dtype))

+ # update excluded variables
+ for k, v in self.excluded_vars.items():
+ v.value = self.excluded_data[k]

+ # add inputs
if f_input is not None:
f_input(shared)

+ # call update functions
args = (shared,) + self.args
target.update(*args)

+ # get new states
new_h = {k: (v.value if v.batch_axis is None else jnp.squeeze(v.value, axis=v.batch_axis))
for k, v in self.target_vars.items()}
return new_h
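Two kinds of changes land in `slow_points.py`: pytree leaves are now unwrapped with `bm.as_device_array(a)` instead of `a.value` before entering compiled functions, and `f_cell` gains explicit steps for clearing inputs and restoring excluded variables. Assuming `as_device_array` returns the raw JAX buffer held by a `JaxArray` (which is what the replaced `.value` access did), the leaf conversion looks like this in isolation; the pytree contents are made up:

```python
# Strip brainpy JaxArray wrappers from a pytree of candidate states before
# handing it to a jit-compiled loss. Assumes bm.as_device_array(x) returns
# the raw jax array stored inside a JaxArray (the old code read x.value).
import brainpy.math as bm
from jax.tree_util import tree_map

fixed_points = {'h': bm.random.rand(10, 4)}  # made-up candidate states
raw = tree_map(lambda a: bm.as_device_array(a),
               fixed_points,
               is_leaf=lambda x: isinstance(x, bm.JaxArray))
print(type(raw['h']))  # a plain jax array, safe to pass into jitted code
```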
2 changes: 1 addition & 1 deletion brainpy/base/base.py
@@ -120,7 +120,7 @@ def vars(self, method='absolute', level=-1, include_self=True):
for node_path, node in nodes.items():
for k in dir(node):
v = getattr(node, k)
- if isinstance(v, math.Variable):
+ if isinstance(v, math.Variable) and not k.startswith('_') and not k.endswith('_'):
gather[f'{node_path}.{k}' if node_path else k] = v
gather.update({f'{node_path}.{k}': v for k, v in node.implicit_vars.items()})
return gather
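The one-line change to `vars()` above filters out `Variable`s bound to attribute names with a leading or trailing underscore, treating them as private. A small sketch of the effect; the `Holder` class is made up for this example:

```python
# After this change, Variables on underscore-prefixed or underscore-suffixed
# attributes are no longer collected by vars().
import brainpy as bp
import brainpy.math as bm

class Holder(bp.Base):
  def __init__(self):
    super(Holder, self).__init__()
    self.w = bm.Variable(bm.zeros(3))     # collected
    self._tmp = bm.Variable(bm.zeros(3))  # skipped: leading underscore
    self.buf_ = bm.Variable(bm.zeros(3))  # skipped: trailing underscore

print(list(Holder().vars().keys()))  # only the '...w' entry remains
```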
13 changes: 10 additions & 3 deletions brainpy/base/collector.py
@@ -29,9 +29,16 @@ def replace(self, key, new_value):
self[key] = new_value

def update(self, other, **kwargs):
- assert isinstance(other, dict)
- for key, value in other.items():
-   self[key] = value
+ assert isinstance(other, (dict, list, tuple))
+ if isinstance(other, dict):
+   for key, value in other.items():
+     self[key] = value
+ elif isinstance(other, (tuple, list)):
+   num = len(self)
+   for i, value in enumerate(other):
+     self[f'_var{i+num}'] = value
+ else:
+   raise ValueError(f'Only supports dict/list/tuple, but we got {type(other)}')
for key, value in kwargs.items():
self[key] = value

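`Collector.update()` now accepts lists and tuples in addition to dicts; sequence items are stored under generated `_var{i}` keys whose index continues from the collector's current length. A usage sketch with invented values:

```python
# dict input keeps its keys; list/tuple input gets auto-generated names.
from brainpy.base.collector import Collector
import brainpy.math as bm

c = Collector()
c.update({'a': bm.zeros(2)})        # stored under 'a'
c.update([bm.ones(2), bm.ones(3)])  # stored under '_var1', '_var2'
print(sorted(c.keys()))             # ['_var1', '_var2', 'a']
```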