Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions brainpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@
losses, # loss functions
measure, # methods for data analysis
datasets, # methods for generating data
inputs) # methods for generating input currents
inputs, # methods for generating input currents
algorithms, # online or offline training algorithms
)


# numerical integrators
Expand All @@ -58,7 +60,8 @@
rates, # rate models
synapses, # synaptic dynamics
synouts, # synaptic output
synplast) # synaptic plasticity
synplast, # synaptic plasticity
)


# dynamics training
Expand Down
3 changes: 2 additions & 1 deletion brainpy/analysis/highdim/slow_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def __init__(

# parameters for `f_cell` is DynamicalSystem instance
inputs: Sequence = None,
fun_inputs: Callable = None,
t: float = None,
dt: float = None,
included_vars: Dict[str, bm.Variable] = None,
Expand Down Expand Up @@ -175,7 +176,7 @@ def __init__(
# input function
if inputs is not None:
inputs = check_and_format_inputs(host=self.target, inputs=inputs)
_input_step, _has_iter = build_inputs(inputs)
_input_step, _has_iter = build_inputs(inputs, fun_inputs)
if _has_iter:
raise UnsupportedError(f'Do not support iterable inputs when using fixed point finder.')
else:
Expand Down
25 changes: 14 additions & 11 deletions brainpy/dyn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def update_local_delays(self, nodes: Union[Sequence, Dict] = None):
elif isinstance(nodes, dict):
nodes = nodes.values()
for node in nodes:
for name in node.local_delay_vars.keys():
for name in node.local_delay_vars:
delay = self.global_delay_data[name][0]
target = self.global_delay_data[name][1]
delay.update(target.value)
Expand All @@ -250,7 +250,7 @@ def reset_local_delays(self, nodes: Union[Sequence, Dict] = None):
elif isinstance(nodes, dict):
nodes = nodes.values()
for node in nodes:
for name in node.local_delay_vars.keys():
for name in node.local_delay_vars:
delay = self.global_delay_data[name][0]
target = self.global_delay_data[name][1]
delay.reset(target.value)
Expand All @@ -260,15 +260,18 @@ def __del__(self):

This function is used to pop out the variables which registered in global delay data.
"""
for key in tuple(self.local_delay_vars.keys()):
val = self.global_delay_data.pop(key)
del val
val = self.local_delay_vars.pop(key)
del val
for key in tuple(self.implicit_nodes.keys()):
del self.implicit_nodes[key]
for key in tuple(self.implicit_vars.keys()):
del self.implicit_vars[key]
if hasattr(self, 'local_delay_vars'):
for key in tuple(self.local_delay_vars.keys()):
val = self.global_delay_data.pop(key)
del val
val = self.local_delay_vars.pop(key)
del val
if hasattr(self, 'implicit_nodes'):
for key in tuple(self.implicit_nodes.keys()):
del self.implicit_nodes[key]
if hasattr(self, 'implicit_vars'):
for key in tuple(self.implicit_vars.keys()):
del self.implicit_vars[key]
for key in tuple(self.__dict__.keys()):
del self.__dict__[key]
gc.collect()
Expand Down
71 changes: 63 additions & 8 deletions brainpy/dyn/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import time
from collections.abc import Iterable
from typing import Dict, Union, Sequence
from typing import Dict, Union, Sequence, Callable

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -33,9 +33,9 @@ def check_and_format_inputs(host, inputs):
Parameters
----------
host : DynamicalSystem
The host which contains all data.
The host which contains all data.
inputs : tuple, list
The inputs of the population.
The inputs of the population.

Returns
-------
Expand Down Expand Up @@ -161,12 +161,30 @@ def check_and_format_inputs(host, inputs):
return formatted_inputs


def build_inputs(inputs):
def build_inputs(inputs, fun_inputs):
"""Build input function.

Parameters
----------
inputs : tuple, list
The inputs of the population.
fun_inputs: optional, callable
The input function customized by users.

Returns
-------
func: callable
The input function.
"""

fix_inputs = {'=': [], '+': [], '-': [], '*': [], '/': []}
next_inputs = {'=': [], '+': [], '-': [], '*': [], '/': []}
func_inputs = {'=': [], '+': [], '-': [], '*': [], '/': []}
array_inputs = {'=': [], '+': [], '-': [], '*': [], '/': []}

if not (fun_inputs is None or callable(fun_inputs)):
raise ValueError

_has_iter_array = False
for variable, value, type_, op in inputs:
# variable
Expand Down Expand Up @@ -202,6 +220,8 @@ def _f_ops(ops, var, data):
raise ValueError(f'Unknown input operation: {ops}')

def func(tdi):
if fun_inputs is not None:
fun_inputs(tdi)
for ops, values in fix_inputs.items():
for var, data in values:
_f_ops(ops, var, data)
Expand All @@ -225,6 +245,7 @@ class DSRunner(Runner):
----------
target : DynamicalSystem
The target model to run.

inputs : list, tuple
The inputs for the target DynamicalSystem. It should be the format
of `[(target, value, [type, operation])]`, where `target` is the
Expand All @@ -239,14 +260,50 @@ class DSRunner(Runner):
- ``operation``: should be a string, support `+`, `-`, `*`, `/`, `=`.
- Also, if you want to specify multiple inputs, just give multiple ``(target, value, [type, operation])``,
for example ``[(target1, value1), (target2, value2)]``.

fun_inputs: callable
The functional inputs. Manually specify the inputs for the target variables.
This input function should receive one argument `shared` which contains the shared arguments like
time `t`, time step `dt`, and index `i`.

monitors: None, sequence of str, dict, Monitor
Variables to monitor.

- A list of string. Like `monitors=['a', 'b', 'c']`
- A list of string with index specification. Like `monitors=[('a', 1), ('b', [1,3,5]), 'c']`
- A dict with the explicit monitor target, like: `monitors={'a': model.spike, 'b': model.V}`
- A dict with the index specification, like: `monitors={'a': (model.spike, 0), 'b': (model.V, [1,2])}`

fun_monitors: dict
Monitoring variables by callable functions. Should be a dict.
The `key` should be a string for the later retrieval by `runner.mon[key]`.
The `value` should be a callable function which receives two arguments: `t` and `dt`.

jit: bool, dict
The JIT settings.

progress_bar: bool
Use progress bar to report the running progress or not?

dyn_vars: Optional, dict
The dynamically changed variables. Instance of :py:class:`~.Variable`.

numpy_mon_after_run : bool
When finishing the network running, transform the JAX arrays into numpy ndarray or not?

"""

target: DynamicalSystem

def __init__(
self,
target: DynamicalSystem,

# inputs for target variables
inputs: Sequence = (),
fun_inputs: Callable = None,

# extra info
dt: float = None,
t0: Union[float, int] = 0.,
**kwargs
Expand All @@ -269,11 +326,10 @@ def __init__(

# Build the monitor function
self._mon_info = self.format_monitors()
# self._monitor_step = self.build_monitors(*self.format_monitors())

# Build input function
inputs = check_and_format_inputs(host=target, inputs=inputs)
self._input_step, _ = build_inputs(inputs)
self._input_step, _ = build_inputs(check_and_format_inputs(host=target, inputs=inputs),
fun_inputs=fun_inputs)

# run function
self._f_predict_compiled = dict()
Expand Down Expand Up @@ -581,4 +637,3 @@ def __del__(self):
for key in tuple(self._f_predict_compiled.keys()):
del self._f_predict_compiled[key]
super(DSRunner, self).__del__()

49 changes: 18 additions & 31 deletions brainpy/math/index_tricks.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,3 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import abc

from jax import core
Expand Down Expand Up @@ -61,17 +47,18 @@ class _Mgrid(_IndexGrid):
Examples:
Pass ``[start:stop:step]`` to generate values similar to :func:`jax.numpy.arange`:

>>> jnp.mgrid[0:4:1]
>>> import brainpy.math as bm
>>> bm.mgrid[0:4:1]
DeviceArray([0, 1, 2, 3], dtype=int32)

Passing an imaginary step generates values similar to :func:`jax.numpy.linspace`:

>>> jnp.mgrid[0:1:4j]
>>> bm.mgrid[0:1:4j]
DeviceArray([0. , 0.33333334, 0.6666667 , 1. ], dtype=float32)

Multiple slices can be used to create broadcasted grids of indices:

>>> jnp.mgrid[:2, :3]
>>> bm.mgrid[:2, :3]
DeviceArray([[[0, 0, 0],
[1, 1, 1]],
[[0, 1, 2],
Expand All @@ -96,17 +83,17 @@ class _Ogrid(_IndexGrid):
Examples:
Pass ``[start:stop:step]`` to generate values similar to :func:`jax.numpy.arange`:

>>> jnp.ogrid[0:4:1]
>>> bm.ogrid[0:4:1]
DeviceArray([0, 1, 2, 3], dtype=int32)

Passing an imaginary step generates values similar to :func:`jax.numpy.linspace`:

>>> jnp.ogrid[0:1:4j]
>>> bm.ogrid[0:1:4j]
DeviceArray([0. , 0.33333334, 0.6666667 , 1. ], dtype=float32)

Multiple slices can be used to create sparse grids of indices:

>>> jnp.ogrid[:2, :3]
>>> bm.ogrid[:2, :3]
[DeviceArray([[0],
[1]], dtype=int32),
DeviceArray([[0, 1, 2]], dtype=int32)]
Expand Down Expand Up @@ -200,13 +187,13 @@ class RClass(_AxisConcat):
Examples:
Passing slices in the form ``[start:stop:step]`` generates ``jnp.arange`` objects:

>>> jnp.r_[-1:5:1, 0, 0, jnp.array([1,2,3])]
>>> bm.r_[-1:5:1, 0, 0, bm.array([1,2,3])]
DeviceArray([-1, 0, 1, 2, 3, 4, 0, 0, 1, 2, 3], dtype=int32)

An imaginary value for ``step`` will create a ``jnp.linspace`` object instead,
which includes the right endpoint:

>>> jnp.r_[-1:1:6j, 0, jnp.array([1,2,3])]
>>> bm.r_[-1:1:6j, 0, bm.array([1,2,3])]
DeviceArray([-1. , -0.6 , -0.20000002, 0.20000005,
0.6 , 1. , 0. , 1. ,
2. , 3. ], dtype=float32)
Expand All @@ -215,11 +202,11 @@ class RClass(_AxisConcat):
specify concatenation axis, minimum number of dimensions, and the position of the
upgraded array's original dimensions in the resulting array's shape tuple:

>>> jnp.r_['0,2', [1,2,3], [4,5,6]] # concatenate along first axis, 2D output
>>> bm.r_['0,2', [1,2,3], [4,5,6]] # concatenate along first axis, 2D output
DeviceArray([[1, 2, 3],
[4, 5, 6]], dtype=int32)

>>> jnp.r_['0,2,0', [1,2,3], [4,5,6]] # push last input axis to the front
>>> bm.r_['0,2,0', [1,2,3], [4,5,6]] # push last input axis to the front
DeviceArray([[1],
[2],
[3],
Expand All @@ -230,7 +217,7 @@ class RClass(_AxisConcat):
Negative values for ``trans1d`` offset the last axis towards the start
of the shape tuple:

>>> jnp.r_['0,2,-2', [1,2,3], [4,5,6]]
>>> bm.r_['0,2,-2', [1,2,3], [4,5,6]]
DeviceArray([[1],
[2],
[3],
Expand All @@ -241,10 +228,10 @@ class RClass(_AxisConcat):
Use the special directives ``"r"`` or ``"c"`` as the first argument on flat inputs
to create an array with an extra row or column axis, respectively:

>>> jnp.r_['r',[1,2,3], [4,5,6]]
>>> bm.r_['r',[1,2,3], [4,5,6]]
DeviceArray([[1, 2, 3, 4, 5, 6]], dtype=int32)

>>> jnp.r_['c',[1,2,3], [4,5,6]]
>>> bm.r_['c',[1,2,3], [4,5,6]]
DeviceArray([[1],
[2],
[3],
Expand Down Expand Up @@ -274,24 +261,24 @@ class CClass(_AxisConcat):

Examples:

>>> a = jnp.arange(6).reshape((2,3))
>>> jnp.c_[a,a]
>>> a = bm.arange(6).reshape((2,3))
>>> bm.c_[a,a]
DeviceArray([[0, 1, 2, 0, 1, 2],
[3, 4, 5, 3, 4, 5]], dtype=int32)

Use a string directive of the form ``"axis:dims:trans1d"`` as the first argument to specify
concatenation axis, minimum number of dimensions, and the position of the upgraded array's
original dimensions in the resulting array's shape tuple:

>>> jnp.c_['0,2', [1,2,3], [4,5,6]]
>>> bm.c_['0,2', [1,2,3], [4,5,6]]
DeviceArray([[1],
[2],
[3],
[4],
[5],
[6]], dtype=int32)

>>> jnp.c_['0,2,-1', [1,2,3], [4,5,6]]
>>> bm.c_['0,2,-1', [1,2,3], [4,5,6]]
DeviceArray([[1, 2, 3],
[4, 5, 6]], dtype=int32)

Expand Down
Loading