Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/Sync_branches.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ jobs:
steps:
- uses: actions/checkout@master

- name: Merge master -> brainpy-2.x
- name: Merge master -> brainpy-2.2.x
uses: devmasx/merge-branch@master
with:
type: now
from_branch: master
target_branch: brainpy-2.x
target_branch: brainpy-2.2.x
github_token: ${{ github.token }}
8 changes: 6 additions & 2 deletions brainpy/base/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ class Base(object):

"""

_excluded_vars = ()

def __init__(self, name=None):
# check whether the object has a unique name.
self._name = None
Expand Down Expand Up @@ -120,8 +122,10 @@ def vars(self, method='absolute', level=-1, include_self=True):
for node_path, node in nodes.items():
for k in dir(node):
v = getattr(node, k)
if isinstance(v, math.Variable) and not k.startswith('_') and not k.endswith('_'):
gather[f'{node_path}.{k}' if node_path else k] = v
if isinstance(v, math.Variable):
if k not in node._excluded_vars:
# if not k.startswith('_') and not k.endswith('_'):
gather[f'{node_path}.{k}' if node_path else k] = v
gather.update({f'{node_path}.{k}': v for k, v in node.implicit_vars.items()})
return gather

Expand Down
6 changes: 3 additions & 3 deletions brainpy/dyn/neurons/fractional_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def __init__(
self,
size: Shape,
alpha: Union[float, Sequence[float]],
num_step: int,
num_memory: int,
a: Union[float, Tensor, Initializer, Callable] = 0.02,
b: Union[float, Tensor, Initializer, Callable] = 0.20,
c: Union[float, Tensor, Initializer, Callable] = -65.,
Expand Down Expand Up @@ -272,10 +272,10 @@ def __init__(
self.spike = bm.Variable(bm.zeros(self.varshape, dtype=bool))

# functions
check_integer(num_step, 'num_step', allow_none=False)
check_integer(num_memory, 'num_step', allow_none=False)
self.integral = CaputoL1Schema(f=self.derivative,
alpha=alpha,
num_memory=num_step,
num_memory=num_memory,
inits=[self.V, self.u])

def reset_state(self, batch_size=None):
Expand Down
24 changes: 12 additions & 12 deletions brainpy/inputs/currents.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,12 @@ def constant_input(I_and_duration, dt=None):

# get the current
start = 0
I_current = jnp.zeros((int(np.ceil(I_duration / dt)),) + I_shape)
I_current = bm.zeros((int(np.ceil(I_duration / dt)),) + I_shape)
for c_size, duration in I_and_duration:
length = int(duration / dt)
I_current = I_current.at[start: start + length].set(c_size)
I_current[start: start + length] = c_size
start += length
return I_current, I_duration
return I_current.value, I_duration


def constant_current(*args, **kwargs):
Expand Down Expand Up @@ -172,12 +172,12 @@ def spike_input(sp_times, sp_lens, sp_sizes, duration, dt=None):
if isinstance(sp_sizes, (float, int)):
sp_sizes = [sp_sizes] * len(sp_times)

current = jnp.zeros(int(np.ceil(duration / dt)))
current = bm.zeros(int(np.ceil(duration / dt)))
for time, dur, size in zip(sp_times, sp_lens, sp_sizes):
pp = int(time / dt)
p_len = int(dur / dt)
current = current.at[pp: pp + p_len].set(size)
return current
current[pp: pp + p_len] = size
return current.value


def spike_current(*args, **kwargs):
Expand Down Expand Up @@ -218,12 +218,12 @@ def ramp_input(c_start, c_end, duration, t_start=0, t_end=None, dt=None):
dt = bm.get_dt() if dt is None else dt
t_end = duration if t_end is None else t_end

current = jnp.zeros(int(np.ceil(duration / dt)))
current = bm.zeros(int(np.ceil(duration / dt)))
p1 = int(np.ceil(t_start / dt))
p2 = int(np.ceil(t_end / dt))
cc = jnp.array(jnp.linspace(c_start, c_end, p2 - p1))
current = current.at[p1: p2].set(cc)
return current
current[p1: p2] = cc
return current.value


def ramp_current(*args, **kwargs):
Expand Down Expand Up @@ -265,9 +265,9 @@ def wiener_process(duration, dt=None, n=1, t_start=0., t_end=None, seed=None):
i_start = int(t_start / dt)
i_end = int(t_end / dt)
noises = rng.standard_normal((i_end - i_start, n)) * jnp.sqrt(dt)
currents = jnp.zeros((int(duration / dt), n))
currents = currents.at[i_start: i_end].set(bm.as_device_array(noises))
return currents
currents = bm.zeros((int(duration / dt), n))
currents[i_start: i_end] = noises
return currents.value


def ou_process(mean, sigma, tau, duration, dt=None, n=1, t_start=0., t_end=None, seed=None):
Expand Down
22 changes: 11 additions & 11 deletions brainpy/integrators/fde/Caputo.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def __init__(

def _check_step(self, args):
dt, t = args
raise ValueError(f'The maximum number of step is {self.num_step}, '
raise ValueError(f'The maximum number of step is {self.num_memory}, '
f'however, the current time {t} require a time '
f'step number {t / dt}.')

Expand All @@ -164,7 +164,7 @@ def _integral_func(self, *args, **kwargs):
t = all_args['t']
dt = all_args.pop(DT, self.dt)
if check.is_checking():
check_error_in_jit(self.num_step * dt < t, self._check_step, (dt, t))
check_error_in_jit(self.num_memory * dt < t, self._check_step, (dt, t))

# derivative values
devs = self.f(**all_args)
Expand All @@ -185,11 +185,11 @@ def _integral_func(self, *args, **kwargs):

# integral results
integrals = []
idx = ((self.num_step - 1 - self.idx) + bm.arange(self.num_step)) % self.num_step
idx = ((self.num_memory - 1 - self.idx) + bm.arange(self.num_memory)) % self.num_memory
for i, key in enumerate(self.variables):
integral = self.inits[key] + self.coef[idx, i] @ self.f_states[key]
integrals.append(integral * (dt ** self.alpha[i] / self.alpha[i]))
self.idx.value = (self.idx + 1) % self.num_step
self.idx.value = (self.idx + 1) % self.num_memory

# return integrals
if len(self.variables) == 1:
Expand Down Expand Up @@ -344,19 +344,19 @@ def __init__(
dtype=self.inits[v].dtype))
for v in self.variables}
self.register_implicit_vars(self.diff_states)
self.idx = bm.Variable(bm.asarray([self.num_step - 1]))
self.idx = bm.Variable(bm.asarray([self.num_memory - 1]))

# integral function
self.set_integral(self._integral_func)

def reset(self, inits):
"""Reset function."""
self.idx.value = bm.asarray([self.num_step - 1])
self.idx.value = bm.asarray([self.num_memory - 1])
inits = check_inits(inits, self.variables)
for key, value in inits.items():
self.inits[key].value = value
for key, val in inits.items():
self.diff_states[key + "_diff"].value = bm.zeros((self.num_step,) + val.shape, dtype=val.dtype)
self.diff_states[key + "_diff"].value = bm.zeros((self.num_memory,) + val.shape, dtype=val.dtype)

def hists(self, var=None, numpy=True):
"""Get the recorded history values."""
Expand All @@ -378,7 +378,7 @@ def hists(self, var=None, numpy=True):

def _check_step(self, args):
dt, t = args
raise ValueError(f'The maximum number of step is {self.num_step}, '
raise ValueError(f'The maximum number of step is {self.num_memory}, '
f'however, the current time {t} require a time '
f'step number {t / dt}.')

Expand All @@ -388,7 +388,7 @@ def _integral_func(self, *args, **kwargs):
t = all_args['t']
dt = all_args.pop(DT, self.dt)
if check.is_checking():
check_error_in_jit(self.num_step * dt < t, self._check_step, (dt, t))
check_error_in_jit(self.num_memory * dt < t, self._check_step, (dt, t))

# derivative values
devs = self.f(**all_args)
Expand All @@ -405,15 +405,15 @@ def _integral_func(self, *args, **kwargs):

# integral results
integrals = []
idx = ((self.num_step - 1 - self.idx) + bm.arange(self.num_step)) % self.num_step
idx = ((self.num_memory - 1 - self.idx) + bm.arange(self.num_memory)) % self.num_memory
for i, key in enumerate(self.variables):
self.diff_states[key + '_diff'][self.idx[0]] = all_args[key] - self.inits[key]
self.inits[key].value = all_args[key]
markov_term = dt ** self.alpha[i] * self.gamma_alpha[i] * devs[key] + all_args[key]
memory_trace = self.coef[idx, i] @ self.diff_states[key + '_diff']
integral = markov_term - memory_trace
integrals.append(integral)
self.idx.value = (self.idx + 1) % self.num_step
self.idx.value = (self.idx + 1) % self.num_memory

# return integrals
if len(self.variables) == 1:
Expand Down
21 changes: 11 additions & 10 deletions brainpy/integrators/fde/GL.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
This module provides numerical solvers for Grünwald–Letnikov derivative FDEs.
"""

from typing import Dict, Union, Callable
from typing import Dict, Union, Callable, Any

import jax.numpy as jnp

Expand Down Expand Up @@ -127,8 +127,8 @@ class GLShortMemory(FDEIntegrator):
def __init__(
self,
f: Callable,
alpha,
inits,
alpha: Any,
inits: Any,
num_memory: int,
dt: float = None,
name: str = None,
Expand All @@ -152,9 +152,9 @@ def __init__(
# delays
self.delays = {}
for key, val in inits.items():
delay = bm.Variable(bm.zeros((self.num_step,) + val.shape, dtype=val.dtype))
delay = bm.Variable(bm.zeros((self.num_memory,) + val.shape, dtype=val.dtype))
delay[0] = val
self.delays[key] = delay
self.delays[key+'_delay'] = delay
self._idx = bm.Variable(bm.asarray([1]))
self.register_implicit_vars(self.delays)

Expand All @@ -171,7 +171,7 @@ def reset(self, inits):
self._idx.value = bm.asarray([1])
inits = check_inits(inits, self.variables)
for key, val in inits.items():
delay = bm.zeros((self.num_step,) + val.shape, dtype=val.dtype)
delay = bm.zeros((self.num_memory,) + val.shape, dtype=val.dtype)
delay[0] = val
self.delays[key].value = delay

Expand Down Expand Up @@ -199,13 +199,14 @@ def _integral_func(self, *args, **kwargs):

# integral results
integrals = []
idx = (self._idx + bm.arange(self.num_step)) % self.num_step
idx = (self._idx + bm.arange(self.num_memory)) % self.num_memory
for i, var in enumerate(self.variables):
summation = self._binomial_coef[:, i] @ self.delays[var][idx]
delay_var = var + '_delay'
summation = self._binomial_coef[:, i] @ self.delays[delay_var][idx]
integral = (dt ** self.alpha[i]) * devs[var] - summation
self.delays[var][self._idx[0]] = integral
self.delays[delay_var][self._idx[0]] = integral
integrals.append(integral)
self._idx.value = (self._idx + 1) % self.num_step
self._idx.value = (self._idx + 1) % self.num_memory

# return integrals
if len(self.variables) == 1:
Expand Down
4 changes: 2 additions & 2 deletions brainpy/integrators/fde/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ def __init__(
arguments = parses[2] # function arguments

# memory length
check_integer(num_memory, 'num_step', allow_none=False, min_bound=1)
self.num_step = num_memory
check_integer(num_memory, 'num_memory', allow_none=False, min_bound=1)
self.num_memory = num_memory

# super initialization
super(FDEIntegrator, self).__init__(name=name,
Expand Down
2 changes: 1 addition & 1 deletion brainpy/integrators/fde/tests/test_Caputo.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test1(self):

intg.idx[0] = N - 1
intg.diff_states['a_diff'][:N - 1] = bp.math.asarray(diff)
idx = ((intg.num_step - intg.idx) + np.arange(intg.num_step)) % intg.num_step
idx = ((intg.num_memory - intg.idx) + np.arange(intg.num_memory)) % intg.num_memory
memory_trace2 = intg.coef[idx, 0] @ intg.diff_states['a_diff']

print()
Expand Down
15 changes: 8 additions & 7 deletions brainpy/integrators/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,11 @@ def __init__(
numpy_mon_after_run: bool
"""

# initialize variables
if not isinstance(target, Integrator):
raise TypeError(f'Target must be instance of {Integrator.__name__}, '
f'but we got {type(target)}')

# get maximum size and initial variables
if inits is not None:
if isinstance(inits, (list, tuple, bm.JaxArray, jnp.ndarray)):
assert len(target.variables) == len(inits)
Expand All @@ -148,6 +149,8 @@ def __init__(
else:
max_size = 1
inits = dict()

# initialize variables
self.variables = TensorCollector({v: bm.Variable(bm.zeros(max_size))
for v in target.variables})
for k in inits.keys():
Expand Down Expand Up @@ -207,7 +210,6 @@ def __init__(
self.dyn_vars.update(self.target.vars().unique())

# Variables

self.dyn_vars.update(self.variables)
if len(self._dyn_args) > 0:
self.idx = bm.Variable(bm.zeros(1, dtype=jnp.int_))
Expand Down Expand Up @@ -240,11 +242,6 @@ def _loop_func(times):
return out_vars, returns
self.step_func = _loop_func

def _post(self, times, returns: dict): # monitor
self.mon.ts = times + self.dt
for key in returns.keys():
self.mon[key] = bm.asarray(returns[key])

def _step(self, t):
# arguments
kwargs = dict()
Expand All @@ -254,17 +251,21 @@ def _step(self, t):
if len(self._dyn_args) > 0:
kwargs.update({k: v[self.idx.value] for k, v in self._dyn_args.items()})
self.idx += 1

# return of function monitors
returns = dict()
for key, func in self.fun_monitors.items():
returns[key] = func(t, self.dt)

# call integrator function
update_values = self.target(**kwargs)
if len(self.target.variables) == 1:
self.variables[self.target.variables[0]].update(update_values)
else:
for i, v in enumerate(self.target.variables):
self.variables[v].update(update_values[i])

# progress bar
if self.progress_bar:
id_tap(lambda *args: self._pbar.update(), ())
return returns
Expand Down
6 changes: 3 additions & 3 deletions brainpy/integrators/sde/normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,9 @@ def step(self, *args, **kwargs):
# diffusion values
diffusions = self.g(**all_args)
if len(self.variables) == 1:
if not isinstance(diffusions, (bm.ndarray, jnp.ndarray)):
raise ValueError('Diffusion values must be a tensor when there '
'is only one variable in the equation.')
# if not isinstance(diffusions, (bm.ndarray, jnp.ndarray)):
# raise ValueError('Diffusion values must be a tensor when there '
# 'is only one variable in the equation.')
diffusions = {self.variables[0]: diffusions}
else:
if not isinstance(diffusions, (tuple, list)):
Expand Down
6 changes: 6 additions & 0 deletions brainpy/math/numpy_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1840,11 +1840,15 @@ def asarray(a, dtype=None, order=None):

@wraps(jnp.arange)
def arange(*args, **kwargs):
args = [_remove_jaxarray(a) for a in args]
kwargs = {k: _remove_jaxarray(v) for k, v in kwargs.items()}
return JaxArray(jnp.arange(*args, **kwargs))


@wraps(jnp.linspace)
def linspace(*args, **kwargs):
args = [_remove_jaxarray(a) for a in args]
kwargs = {k: _remove_jaxarray(v) for k, v in kwargs.items()}
res = jnp.linspace(*args, **kwargs)
if isinstance(res, tuple):
return JaxArray(res[0]), res[1]
Expand All @@ -1854,6 +1858,8 @@ def linspace(*args, **kwargs):

@wraps(jnp.logspace)
def logspace(*args, **kwargs):
args = [_remove_jaxarray(a) for a in args]
kwargs = {k: _remove_jaxarray(v) for k, v in kwargs.items()}
return JaxArray(jnp.logspace(*args, **kwargs))


Expand Down
Loading