Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ publishment.md
.vscode
io_test_tmp*

brainpy/base/tests/io_test_tmp*
brainpy/math/brainpy_object/tests/io_test_tmp*

development

Expand Down Expand Up @@ -217,3 +217,7 @@ cython_debug/
/docs/apis/simulation/generated/
!/brainpy/dyn/tests/data/
/examples/dynamics_simulation/data/
/examples/training_snn_models/logs/T100_b64_lr0.001/
/examples/training_snn_models/logs/
/examples/training_snn_models/data/
/docs/tutorial_advanced/data/
32 changes: 14 additions & 18 deletions brainpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,8 @@
# fundamental modules
from . import errors, check, tools

# "base" module
from . import base
from .base import (
# base class
Base,
BrainPyObject,

# collector
Collector,
ArrayCollector,
TensorCollector,
)

# math foundation
from . import math
from . import modes

# toolboxes
from . import (
Expand Down Expand Up @@ -69,7 +55,7 @@
synouts, # synaptic output
synplast, # synaptic plasticity

# base classes
# brainpy_object classes
DynamicalSystem,
Container,
Sequential,
Expand Down Expand Up @@ -113,9 +99,7 @@

# running
from . import running
from .running import (
Runner
)
from .running import (Runner)

# "visualization" module, will be removed soon
from .visualization import visualize
Expand All @@ -124,3 +108,15 @@
conn = connect
init = initialize
optim = optimizers

from . import experimental


# deprecated
from . import base
# use ``brainpy.math.*`` instead
from brainpy.math.object_transform.base_object import (Base, BrainPyObject,)
# use ``brainpy.math.*`` instead
from brainpy.math.object_transform.collector import (Collector, ArrayCollector, TensorCollector,)
# use ``brainpy.math.*`` instead
from . import modes
5 changes: 2 additions & 3 deletions brainpy/algorithms/offline.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,13 @@
from jax.lax import while_loop

import brainpy.math as bm
from brainpy.base import BrainPyObject
from brainpy.types import ArrayType
from .utils import (Sigmoid,
Regularization, L1Regularization, L1L2Regularization, L2Regularization,
polynomial_features, normalize)

__all__ = [
# base class for offline training algorithm
# brainpy_object class for offline training algorithm
'OfflineAlgorithm',

# training methods
Expand All @@ -33,7 +32,7 @@
name2func = dict()


class OfflineAlgorithm(BrainPyObject):
class OfflineAlgorithm(bm.BrainPyObject):
"""Base class for offline training algorithm."""

def __init__(self, name=None):
Expand Down
5 changes: 2 additions & 3 deletions brainpy/algorithms/online.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
# -*- coding: utf-8 -*-

import brainpy.math as bm
from brainpy.base import BrainPyObject
from jax import vmap
import jax.numpy as jnp

__all__ = [
# base class
# brainpy_object class
'OnlineAlgorithm',

# online learning algorithms
Expand All @@ -21,7 +20,7 @@
name2func = dict()


class OnlineAlgorithm(BrainPyObject):
class OnlineAlgorithm(bm.BrainPyObject):
"""Base class for online training algorithm."""

def __init__(self, name=None):
Expand Down
5 changes: 2 additions & 3 deletions brainpy/analysis/highdim/slow_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import brainpy.math as bm
from brainpy import optimizers as optim, losses
from brainpy.analysis import utils, base, constants
from brainpy.base import ArrayCollector
from brainpy.dyn.base import DynamicalSystem
from brainpy.dyn.runners import check_and_format_inputs, _f_ops
from brainpy.errors import AnalyzerError, UnsupportedError
Expand Down Expand Up @@ -133,11 +132,11 @@ def __init__(

# update function
if target_vars is None:
self.target_vars = ArrayCollector()
self.target_vars = bm.ArrayCollector()
else:
if not isinstance(target_vars, dict):
raise TypeError(f'"target_vars" must be a dict but we got {type(target_vars)}')
self.target_vars = ArrayCollector(target_vars)
self.target_vars = bm.ArrayCollector(target_vars)
excluded_vars = () if excluded_vars is None else excluded_vars
if isinstance(excluded_vars, dict):
excluded_vars = tuple(excluded_vars.values())
Expand Down
9 changes: 4 additions & 5 deletions brainpy/analysis/lowdim/lowdim_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from brainpy import errors, tools
from brainpy.analysis import constants as C, utils
from brainpy.analysis.base import DSAnalyzer
from brainpy.base.collector import Collector

pyplot = None

Expand Down Expand Up @@ -91,7 +90,7 @@ def __init__(
if not isinstance(target_vars, dict):
raise errors.AnalyzerError('"target_vars" must be a dict, with the format of '
'{"var1": (var1_min, var1_max)}.')
self.target_vars = Collector(target_vars)
self.target_vars = bm.Collector(target_vars)
self.target_var_names = list(self.target_vars.keys()) # list of target vars
for key in self.target_vars.keys():
if key not in self.model.variables:
Expand All @@ -110,7 +109,7 @@ def __init__(
for key in fixed_vars.keys():
if key not in self.model.variables:
raise ValueError(f'{key} is not a dynamical variable in {self.model}.')
self.fixed_vars = Collector(fixed_vars)
self.fixed_vars = bm.Collector(fixed_vars)

# check duplicate
for key in self.fixed_vars.keys():
Expand All @@ -125,7 +124,7 @@ def __init__(
if not isinstance(pars_update, dict):
raise errors.AnalyzerError('"pars_update" must be a dict with the format '
'of {"par1": val1, "par2": val2}.')
pars_update = Collector(pars_update)
pars_update = bm.Collector(pars_update)
for key in pars_update.keys():
if key not in self.model.parameters:
raise errors.AnalyzerError(f'"{key}" is not a valid parameter in "{self.model}" model.')
Expand All @@ -144,7 +143,7 @@ def __init__(
raise errors.AnalyzerError(
f'The range of parameter {key} is reversed, which means {value[0]} should be smaller than {value[1]}.')

self.target_pars = Collector(target_pars)
self.target_pars = bm.Collector(target_pars)
self.target_par_names = list(self.target_pars.keys()) # list of target_pars

# check duplicate
Expand Down
30 changes: 8 additions & 22 deletions brainpy/base/__init__.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,14 @@
# -*- coding: utf-8 -*-

"""
The ``base`` module for whole BrainPy ecosystem.

- This module provides the most fundamental class ``BrainPyObject``,
and its associated helper class ``Collector`` and ``ArrayCollector``.
- For each instance of "BrainPyObject" class, users can retrieve all
the variables (or trainable variables), integrators, and nodes.
- This module also provides a ``FunAsObject`` class to wrap user-defined
functions. In each function, maybe several nodes are used, and
users can initialize a ``FunAsObject`` by providing the nodes used
in the function. Unfortunately, ``FunAsObject`` class does not have
the ability to gather nodes automatically.
- This module provides ``io`` helper functions to help users save/load
model states, or share user's customized model with others.
- This module provides ``naming`` tools to guarantee the unique nameing
for each BrainPyObject object.

Details please see the following.
This module is deprecated since version 2.3.1.
Please use ``brainpy.math.*`` instead.
"""

from brainpy.base.base import *
from brainpy.base.collector import *
from brainpy.base.function import *
from brainpy.base.io import *
from brainpy.base.naming import *

from .base import *
from .collector import *
from .function import *
from .io import *
from .naming import *

Loading