diff --git a/.github/workflows/Sync_branches.yml b/.github/workflows/Sync_branches.yml
index 753301052..00ff74b68 100644
--- a/.github/workflows/Sync_branches.yml
+++ b/.github/workflows/Sync_branches.yml
@@ -9,10 +9,10 @@ jobs:
     steps:
     - uses: actions/checkout@master
 
-    - name: Merge master -> brainpy-2.x
+    - name: Merge master -> brainpy-2.2.x
       uses: devmasx/merge-branch@master
       with:
         type: now
         from_branch: master
-        target_branch: brainpy-2.x
+        target_branch: brainpy-2.2.x
         github_token: ${{ github.token }}
\ No newline at end of file
diff --git a/brainpy/base/base.py b/brainpy/base/base.py
index c7ca6f525..f55e25ec1 100644
--- a/brainpy/base/base.py
+++ b/brainpy/base/base.py
@@ -29,6 +29,8 @@ class Base(object):
 
   """
 
+  _excluded_vars = ()
+
   def __init__(self, name=None):
     # check whether the object has a unique name.
     self._name = None
@@ -120,8 +122,10 @@ def vars(self, method='absolute', level=-1, include_self=True):
     for node_path, node in nodes.items():
       for k in dir(node):
         v = getattr(node, k)
-        if isinstance(v, math.Variable) and not k.startswith('_') and not k.endswith('_'):
-          gather[f'{node_path}.{k}' if node_path else k] = v
+        if isinstance(v, math.Variable):
+          if k not in node._excluded_vars:
+            # if not k.startswith('_') and not k.endswith('_'):
+            gather[f'{node_path}.{k}' if node_path else k] = v
       gather.update({f'{node_path}.{k}': v for k, v in node.implicit_vars.items()})
     return gather
 
diff --git a/brainpy/dyn/neurons/fractional_models.py b/brainpy/dyn/neurons/fractional_models.py
index 3d0277bf3..9a643b563 100644
--- a/brainpy/dyn/neurons/fractional_models.py
+++ b/brainpy/dyn/neurons/fractional_models.py
@@ -226,7 +226,7 @@ def __init__(
       self,
       size: Shape,
       alpha: Union[float, Sequence[float]],
-      num_step: int,
+      num_memory: int,
       a: Union[float, Tensor, Initializer, Callable] = 0.02,
       b: Union[float, Tensor, Initializer, Callable] = 0.20,
       c: Union[float, Tensor, Initializer, Callable] = -65.,
@@ -272,10 +272,10 @@ def __init__(
     self.spike = bm.Variable(bm.zeros(self.varshape, dtype=bool))
 
     # functions
-    check_integer(num_step, 'num_step', allow_none=False)
+    check_integer(num_memory, 'num_memory', allow_none=False)
     self.integral = CaputoL1Schema(f=self.derivative,
                                    alpha=alpha,
-                                   num_memory=num_step,
+                                   num_memory=num_memory,
                                    inits=[self.V, self.u])
 
   def reset_state(self, batch_size=None):
diff --git a/brainpy/inputs/currents.py b/brainpy/inputs/currents.py
index 55031cd15..b6ada2712 100644
--- a/brainpy/inputs/currents.py
+++ b/brainpy/inputs/currents.py
@@ -113,12 +113,12 @@ def constant_input(I_and_duration, dt=None):
 
   # get the current
   start = 0
-  I_current = jnp.zeros((int(np.ceil(I_duration / dt)),) + I_shape)
+  I_current = bm.zeros((int(np.ceil(I_duration / dt)),) + I_shape)
   for c_size, duration in I_and_duration:
     length = int(duration / dt)
-    I_current = I_current.at[start: start + length].set(c_size)
+    I_current[start: start + length] = c_size
     start += length
-  return I_current, I_duration
+  return I_current.value, I_duration
 
 
 def constant_current(*args, **kwargs):
@@ -172,12 +172,12 @@ def spike_input(sp_times, sp_lens, sp_sizes, duration, dt=None):
   if isinstance(sp_sizes, (float, int)):
     sp_sizes = [sp_sizes] * len(sp_times)
 
-  current = jnp.zeros(int(np.ceil(duration / dt)))
+  current = bm.zeros(int(np.ceil(duration / dt)))
   for time, dur, size in zip(sp_times, sp_lens, sp_sizes):
     pp = int(time / dt)
     p_len = int(dur / dt)
-    current = current.at[pp: pp + p_len].set(size)
-  return current
+    current[pp: pp + p_len] = size
+  return current.value
 
 
 def spike_current(*args, **kwargs):
@@ -218,12 +218,12 @@ def ramp_input(c_start, c_end, duration, t_start=0, t_end=None, dt=None):
   dt = bm.get_dt() if dt is None else dt
   t_end = duration if t_end is None else t_end
 
-  current = jnp.zeros(int(np.ceil(duration / dt)))
+  current = bm.zeros(int(np.ceil(duration / dt)))
   p1 = int(np.ceil(t_start / dt))
   p2 = int(np.ceil(t_end / dt))
   cc = jnp.array(jnp.linspace(c_start, c_end, p2 - p1))
-  current = current.at[p1: p2].set(cc)
-  return current
+  current[p1: p2] = cc
+  return current.value
 
 
 def ramp_current(*args, **kwargs):
@@ -265,9 +265,9 @@ def wiener_process(duration, dt=None, n=1, t_start=0., t_end=None, seed=None):
   i_start = int(t_start / dt)
   i_end = int(t_end / dt)
   noises = rng.standard_normal((i_end - i_start, n)) * jnp.sqrt(dt)
-  currents = jnp.zeros((int(duration / dt), n))
-  currents = currents.at[i_start: i_end].set(bm.as_device_array(noises))
-  return currents
+  currents = bm.zeros((int(duration / dt), n))
+  currents[i_start: i_end] = noises
+  return currents.value
 
 
 def ou_process(mean, sigma, tau, duration, dt=None, n=1, t_start=0., t_end=None, seed=None):
diff --git a/brainpy/integrators/fde/Caputo.py b/brainpy/integrators/fde/Caputo.py
index 4e32a61d5..fd36e69a4 100644
--- a/brainpy/integrators/fde/Caputo.py
+++ b/brainpy/integrators/fde/Caputo.py
@@ -154,7 +154,7 @@ def __init__(
 
   def _check_step(self, args):
     dt, t = args
-    raise ValueError(f'The maximum number of step is {self.num_step}, '
+    raise ValueError(f'The maximum number of steps is {self.num_memory}, '
                      f'however, the current time {t} require a time '
                      f'step number {t / dt}.')
 
@@ -164,7 +164,7 @@ def _integral_func(self, *args, **kwargs):
     t = all_args['t']
     dt = all_args.pop(DT, self.dt)
     if check.is_checking():
-      check_error_in_jit(self.num_step * dt < t, self._check_step, (dt, t))
+      check_error_in_jit(self.num_memory * dt < t, self._check_step, (dt, t))
 
     # derivative values
     devs = self.f(**all_args)
@@ -185,11 +185,11 @@ def _integral_func(self, *args, **kwargs):
 
     # integral results
     integrals = []
-    idx = ((self.num_step - 1 - self.idx) + bm.arange(self.num_step)) % self.num_step
+    idx = ((self.num_memory - 1 - self.idx) + bm.arange(self.num_memory)) % self.num_memory
     for i, key in enumerate(self.variables):
       integral = self.inits[key] + self.coef[idx, i] @ self.f_states[key]
       integrals.append(integral * (dt ** self.alpha[i] / self.alpha[i]))
-    self.idx.value = (self.idx + 1) % self.num_step
+    self.idx.value = (self.idx + 1) % self.num_memory
 
     # return integrals
     if len(self.variables) == 1:
@@ -344,19 +344,19 @@ def __init__(
                                          dtype=self.inits[v].dtype))
                         for v in self.variables}
     self.register_implicit_vars(self.diff_states)
-    self.idx = bm.Variable(bm.asarray([self.num_step - 1]))
+    self.idx = bm.Variable(bm.asarray([self.num_memory - 1]))
 
     # integral function
     self.set_integral(self._integral_func)
 
   def reset(self, inits):
     """Reset function."""
-    self.idx.value = bm.asarray([self.num_step - 1])
+    self.idx.value = bm.asarray([self.num_memory - 1])
     inits = check_inits(inits, self.variables)
     for key, value in inits.items():
       self.inits[key].value = value
     for key, val in inits.items():
-      self.diff_states[key + "_diff"].value = bm.zeros((self.num_step,) + val.shape, dtype=val.dtype)
+      self.diff_states[key + "_diff"].value = bm.zeros((self.num_memory,) + val.shape, dtype=val.dtype)
 
   def hists(self, var=None, numpy=True):
     """Get the recorded history values."""
@@ -378,7 +378,7 @@ def hists(self, var=None, numpy=True):
 
   def _check_step(self, args):
     dt, t = args
-    raise ValueError(f'The maximum number of step is {self.num_step}, '
+    raise ValueError(f'The maximum number of steps is {self.num_memory}, '
                      f'however, the current time {t} require a time '
                      f'step number {t / dt}.')
 
@@ -388,7 +388,7 @@ def _integral_func(self, *args, **kwargs):
     t = all_args['t']
     dt = all_args.pop(DT, self.dt)
     if check.is_checking():
-      check_error_in_jit(self.num_step * dt < t, self._check_step, (dt, t))
+      check_error_in_jit(self.num_memory * dt < t, self._check_step, (dt, t))
 
     # derivative values
     devs = self.f(**all_args)
@@ -405,7 +405,7 @@ def _integral_func(self, *args, **kwargs):
 
     # integral results
     integrals = []
-    idx = ((self.num_step - 1 - self.idx) + bm.arange(self.num_step)) % self.num_step
+    idx = ((self.num_memory - 1 - self.idx) + bm.arange(self.num_memory)) % self.num_memory
     for i, key in enumerate(self.variables):
       self.diff_states[key + '_diff'][self.idx[0]] = all_args[key] - self.inits[key]
       self.inits[key].value = all_args[key]
@@ -413,7 +413,7 @@ def _integral_func(self, *args, **kwargs):
       memory_trace = self.coef[idx, i] @ self.diff_states[key + '_diff']
       integral = markov_term - memory_trace
       integrals.append(integral)
-    self.idx.value = (self.idx + 1) % self.num_step
+    self.idx.value = (self.idx + 1) % self.num_memory
 
     # return integrals
     if len(self.variables) == 1:
diff --git a/brainpy/integrators/fde/GL.py b/brainpy/integrators/fde/GL.py
index 3518f6243..714de23ec 100644
--- a/brainpy/integrators/fde/GL.py
+++ b/brainpy/integrators/fde/GL.py
@@ -4,7 +4,7 @@
 This module provides numerical solvers for Grünwald–Letnikov derivative FDEs.
 """
 
-from typing import Dict, Union, Callable
+from typing import Dict, Union, Callable, Any
 
 import jax.numpy as jnp
 
@@ -127,8 +127,8 @@ class GLShortMemory(FDEIntegrator):
   def __init__(
       self,
       f: Callable,
-      alpha,
-      inits,
+      alpha: Any,
+      inits: Any,
       num_memory: int,
       dt: float = None,
       name: str = None,
@@ -152,9 +152,9 @@ def __init__(
     # delays
     self.delays = {}
     for key, val in inits.items():
-      delay = bm.Variable(bm.zeros((self.num_step,) + val.shape, dtype=val.dtype))
+      delay = bm.Variable(bm.zeros((self.num_memory,) + val.shape, dtype=val.dtype))
       delay[0] = val
-      self.delays[key] = delay
+      self.delays[key + '_delay'] = delay
     self._idx = bm.Variable(bm.asarray([1]))
     self.register_implicit_vars(self.delays)
 
@@ -171,7 +171,7 @@ def reset(self, inits):
     self._idx.value = bm.asarray([1])
     inits = check_inits(inits, self.variables)
     for key, val in inits.items():
-      delay = bm.zeros((self.num_step,) + val.shape, dtype=val.dtype)
+      delay = bm.zeros((self.num_memory,) + val.shape, dtype=val.dtype)
       delay[0] = val
-      self.delays[key].value = delay
+      self.delays[key + '_delay'].value = delay
 
@@ -199,13 +199,14 @@ def _integral_func(self, *args, **kwargs):
 
     # integral results
     integrals = []
-    idx = (self._idx + bm.arange(self.num_step)) % self.num_step
+    idx = (self._idx + bm.arange(self.num_memory)) % self.num_memory
     for i, var in enumerate(self.variables):
-      summation = self._binomial_coef[:, i] @ self.delays[var][idx]
+      delay_var = var + '_delay'
+      summation = self._binomial_coef[:, i] @ self.delays[delay_var][idx]
       integral = (dt ** self.alpha[i]) * devs[var] - summation
-      self.delays[var][self._idx[0]] = integral
+      self.delays[delay_var][self._idx[0]] = integral
       integrals.append(integral)
-    self._idx.value = (self._idx + 1) % self.num_step
+    self._idx.value = (self._idx + 1) % self.num_memory
 
     # return integrals
     if len(self.variables) == 1:
diff --git a/brainpy/integrators/fde/base.py b/brainpy/integrators/fde/base.py
index 1820711ba..091f7de39 100644
--- a/brainpy/integrators/fde/base.py
+++ b/brainpy/integrators/fde/base.py
@@ -55,8 +55,8 @@ def __init__(
     arguments = parses[2]  # function arguments
 
     # memory length
-    check_integer(num_memory, 'num_step', allow_none=False, min_bound=1)
-    self.num_step = num_memory
+    check_integer(num_memory, 'num_memory', allow_none=False, min_bound=1)
+    self.num_memory = num_memory
 
     # super initialization
     super(FDEIntegrator, self).__init__(name=name,
diff --git a/brainpy/integrators/fde/tests/test_Caputo.py b/brainpy/integrators/fde/tests/test_Caputo.py
index 5599bbffe..4948fe770 100644
--- a/brainpy/integrators/fde/tests/test_Caputo.py
+++ b/brainpy/integrators/fde/tests/test_Caputo.py
@@ -25,7 +25,7 @@ def test1(self):
     intg.idx[0] = N - 1
     intg.diff_states['a_diff'][:N - 1] = bp.math.asarray(diff)
 
-    idx = ((intg.num_step - intg.idx) + np.arange(intg.num_step)) % intg.num_step
+    idx = ((intg.num_memory - intg.idx) + np.arange(intg.num_memory)) % intg.num_memory
     memory_trace2 = intg.coef[idx, 0] @ intg.diff_states['a_diff']
 
     print()
diff --git a/brainpy/integrators/runner.py b/brainpy/integrators/runner.py
index 6e8a95d96..ea79b2d65 100644
--- a/brainpy/integrators/runner.py
+++ b/brainpy/integrators/runner.py
@@ -134,10 +134,11 @@ def __init__(
     numpy_mon_after_run: bool
     """
 
-    # initialize variables
     if not isinstance(target, Integrator):
       raise TypeError(f'Target must be instance of {Integrator.__name__}, '
                       f'but we got {type(target)}')
+
+    # get maximum size and initial variables
     if inits is not None:
       if isinstance(inits, (list, tuple, bm.JaxArray, jnp.ndarray)):
         assert len(target.variables) == len(inits)
@@ -148,6 +149,8 @@ def __init__(
     else:
       max_size = 1
       inits = dict()
+
+    # initialize variables
     self.variables = TensorCollector({v: bm.Variable(bm.zeros(max_size))
                                       for v in target.variables})
     for k in inits.keys():
@@ -207,7 +210,6 @@ def __init__(
     self.dyn_vars.update(self.target.vars().unique())
 
     # Variables
-    self.dyn_vars.update(self.variables)
 
     if len(self._dyn_args) > 0:
       self.idx = bm.Variable(bm.zeros(1, dtype=jnp.int_))
@@ -240,11 +242,6 @@ def _loop_func(times):
       return out_vars, returns
     self.step_func = _loop_func
 
-  def _post(self, times, returns: dict):  # monitor
-    self.mon.ts = times + self.dt
-    for key in returns.keys():
-      self.mon[key] = bm.asarray(returns[key])
-
   def _step(self, t):
     # arguments
     kwargs = dict()
@@ -254,10 +251,12 @@ def _step(self, t):
     if len(self._dyn_args) > 0:
       kwargs.update({k: v[self.idx.value] for k, v in self._dyn_args.items()})
       self.idx += 1
+
     # return of function monitors
     returns = dict()
     for key, func in self.fun_monitors.items():
       returns[key] = func(t, self.dt)
+
     # call integrator function
     update_values = self.target(**kwargs)
     if len(self.target.variables) == 1:
@@ -265,6 +264,8 @@ def _step(self, t):
     else:
       for i, v in enumerate(self.target.variables):
         self.variables[v].update(update_values[i])
+
+    # progress bar
     if self.progress_bar:
       id_tap(lambda *args: self._pbar.update(), ())
     return returns
diff --git a/brainpy/integrators/sde/normal.py b/brainpy/integrators/sde/normal.py
index c9975fcc8..90197e248 100644
--- a/brainpy/integrators/sde/normal.py
+++ b/brainpy/integrators/sde/normal.py
@@ -114,9 +114,9 @@ def step(self, *args, **kwargs):
     # diffusion values
     diffusions = self.g(**all_args)
     if len(self.variables) == 1:
-      if not isinstance(diffusions, (bm.ndarray, jnp.ndarray)):
-        raise ValueError('Diffusion values must be a tensor when there '
-                         'is only one variable in the equation.')
+      # if not isinstance(diffusions, (bm.ndarray, jnp.ndarray)):
+      #   raise ValueError('Diffusion values must be a tensor when there '
+      #                    'is only one variable in the equation.')
       diffusions = {self.variables[0]: diffusions}
     else:
       if not isinstance(diffusions, (tuple, list)):
diff --git a/brainpy/math/numpy_ops.py b/brainpy/math/numpy_ops.py
index 21a267985..f933236b6 100644
--- a/brainpy/math/numpy_ops.py
+++ b/brainpy/math/numpy_ops.py
@@ -1840,11 +1840,15 @@ def asarray(a, dtype=None, order=None):
 
 @wraps(jnp.arange)
 def arange(*args, **kwargs):
+  args = [_remove_jaxarray(a) for a in args]
+  kwargs = {k: _remove_jaxarray(v) for k, v in kwargs.items()}
   return JaxArray(jnp.arange(*args, **kwargs))
 
 
 @wraps(jnp.linspace)
 def linspace(*args, **kwargs):
+  args = [_remove_jaxarray(a) for a in args]
+  kwargs = {k: _remove_jaxarray(v) for k, v in kwargs.items()}
   res = jnp.linspace(*args, **kwargs)
   if isinstance(res, tuple):
     return JaxArray(res[0]), res[1]
@@ -1854,6 +1858,8 @@ def linspace(*args, **kwargs):
 
 @wraps(jnp.logspace)
 def logspace(*args, **kwargs):
+  args = [_remove_jaxarray(a) for a in args]
+  kwargs = {k: _remove_jaxarray(v) for k, v in kwargs.items()}
   return JaxArray(jnp.logspace(*args, **kwargs))
 
 
diff --git a/docs/tutorial_toolbox/illustration_joint_equations.py b/docs/tutorial_toolbox/illustration_joint_equations.py
index f73f6cebb..4e5a1f863 100644
--- a/docs/tutorial_toolbox/illustration_joint_equations.py
+++ b/docs/tutorial_toolbox/illustration_joint_equations.py
@@ -24,8 +24,8 @@ def __init__(self, size):
 
     self.integral = bp.odeint(self.derivative, method='rk2')
 
-  def update(self, t, dt):
-    V, u = self.integral(self.V, self.u, t, self.input, dt=dt)
+  def update(self, tdi):
+    V, u = self.integral(self.V, self.u, tdi.t, self.input, tdi.dt)
     spike = V >= 0.
     self.V.value = bm.where(spike, -65., V)
     self.u.value = bm.where(spike, u + 8., u)
@@ -49,9 +49,9 @@ def __init__(self, size):
     self.int_V = bp.odeint(self.dV, method='rk2')
     self.int_u = bp.odeint(self.du, method='rk2')
 
-  def update(self, t, dt):
-    V = self.int_V(self.V, t, self.u, self.input, dt=dt)
-    u = self.int_u(self.u, t, self.V, dt=dt)
+  def update(self, tdi):
+    V = self.int_V(self.V, tdi.t, self.u, self.input, tdi.dt)
+    u = self.int_u(self.u, tdi.t, self.V, tdi.dt)
     spike = V >= 0.
     self.V.value = bm.where(spike, -65., V)
     self.u.value = bm.where(spike, u + 8., u)
@@ -59,11 +59,11 @@ def update(self, t, dt):
 
 neu1 = IzhiJoint(1)
-runner = bp.StructRunner(neu1, monitors=['V'], inputs=('input', 20.), dt=0.2)
+runner = bp.dyn.DSRunner(neu1, monitors=['V'], inputs=('input', 20.), dt=0.2)
 runner(800)
 bp.visualize.line_plot(runner.mon.ts, runner.mon.V, alpha=0.6, legend='V - joint', show=False)
 
 neu2 = IzhiSeparate(1)
-runner = bp.StructRunner(neu2, monitors=['V'], inputs=('input', 20.), dt=0.2)
+runner = bp.dyn.DSRunner(neu2, monitors=['V'], inputs=('input', 20.), dt=0.2)
 runner(800)
 bp.visualize.line_plot(runner.mon.ts, runner.mon.V, alpha=0.6, legend='V - separate', show=True)
diff --git a/examples/analysis/highdim_CANN.py b/examples/analysis/highdim_CANN.py
index ca3490e26..5a121e0ba 100644
--- a/examples/analysis/highdim_CANN.py
+++ b/examples/analysis/highdim_CANN.py
@@ -84,13 +84,15 @@ def find_fixed_points(pars=None, verbose=False, opt_method='gd', cand_method='ra
   finder = bp.analysis.SlowPointFinder(f_cell=cann,
                                        target_vars={'u': cann.u},
                                        dt=1.)
   if opt_method == 'gd':
     finder.find_fps_with_gd_method(
-      candidates={'u': candidates}, tolerance=tolerance, num_batch=200,
+      candidates={'u': candidates},
+      tolerance=tolerance,
+      num_batch=200,
       optimizer=bp.optim.Adam(lr=bp.optim.ExponentialDecay(0.2, 1, 0.999)),
     )
   elif opt_method == 'BFGS':
     finder.find_fps_with_opt_solver({'u': candidates})
   else:
-    raise ValueError
+    raise ValueError(f'Unknown opt_method: {opt_method}')
   finder.filter_loss(tolerance)
   finder.keep_unique(5e-3)
diff --git a/examples/simulation/Wu_2008_CANN_2D.py b/examples/simulation/Wu_2008_CANN_2D.py
index 885564be0..886812cca 100644
--- a/examples/simulation/Wu_2008_CANN_2D.py
+++ b/examples/simulation/Wu_2008_CANN_2D.py
@@ -53,11 +53,16 @@ def dist(self, d):
 
   def make_conn(self):
     x1, x2 = bm.meshgrid(self.x, self.x)
     value = bm.stack([x1.flatten(), x2.flatten()]).T
-    d = self.dist(bm.abs(value[0] - value))
-    d = bm.linalg.norm(d, axis=1)
-    d = d.reshape((self.length, self.length))
-    Jxx = self.J0 * bm.exp(-0.5 * bm.square(d / self.a)) / (bm.sqrt(2 * bm.pi) * self.a)
-    return Jxx
+
+    @jax.vmap
+    def get_J(v):
+      d = self.dist(bm.abs(v - value))
+      d = bm.linalg.norm(d, axis=1)
+      # d = d.reshape((self.length, self.length))
+      Jxx = self.J0 * bm.exp(-0.5 * bm.square(d / self.a)) / (bm.sqrt(2 * bm.pi) * self.a)
+      return Jxx
+
+    return get_J(value)
 
   def get_stimulus_by_pos(self, pos):
     assert bm.size(pos) == 2
@@ -72,20 +77,19 @@ def update(self, tdi):
     r1 = bm.square(self.u)
     r2 = 1.0 + self.k * bm.sum(r1)
     self.r.value = r1 / r2
-    r = bm.fft.fft2(self.r)
-    jjft = bm.fft.fft2(self.conn_mat)
-    interaction = bm.real(bm.fft.ifft2(r * jjft))
+    interaction = (self.r.flatten() @ self.conn_mat).reshape((self.length, self.length))
     self.u.value = self.u + (-self.u + self.input + interaction) / self.tau * tdi.dt
     self.input[:] = 0.
 
 
-cann = CANN2D(length=512, k=0.1)
+cann = CANN2D(length=100, k=0.1)
 cann.show_conn()
 
 # encoding
 Iext, length = bp.inputs.section_input(
   values=[cann.get_stimulus_by_pos([0., 0.]), 0.],
-  durations=[10., 20.], return_length=True
+  durations=[10., 20.],
+  return_length=True
 )
 runner = bp.dyn.DSRunner(cann,
                          inputs=['input', Iext, 'iter'],
@@ -93,7 +97,8 @@ def update(self, tdi):
                          monitors=['r'],
                          dyn_vars=cann.vars())
 runner.run(length)
-bp.visualize.animate_2D(values=runner.mon.r, net_size=(cann.length, cann.length))
+bp.visualize.animate_2D(values=runner.mon.r.reshape((-1, cann.num)),
+                        net_size=(cann.length, cann.length))
 
 # tracking
 length = 20
@@ -102,8 +107,8 @@ def update(self, tdi):
 Iext = jax.vmap(cann.get_stimulus_by_pos)(positions)
 runner = bp.dyn.DSRunner(cann,
                          inputs=['input', Iext, 'iter'],
-                         monitors=['r'],
-                         dyn_vars=cann.vars())
+                         monitors=['r'])
 runner.run(length)
-bp.visualize.animate_2D(values=runner.mon.r, net_size=(cann.length, cann.length))
+bp.visualize.animate_2D(values=runner.mon.r.reshape((-1, cann.num)),
+                        net_size=(cann.length, cann.length))
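
---
Supplementary examples (kept after the patch so it still applies cleanly):

The `_excluded_vars` hook added in brainpy/base/base.py lets a subclass opt named
Variable attributes out of `vars()` collection, replacing the old underscore-prefix
filter. A minimal sketch, assuming a hypothetical subclass `MyNode` with a `scratch`
variable:

    import brainpy as bp
    import brainpy.math as bm

    class MyNode(bp.Base):
      _excluded_vars = ('scratch',)  # attribute names that vars() should skip

      def __init__(self):
        super(MyNode, self).__init__()
        self.state = bm.Variable(bm.zeros(3))    # collected by vars()
        self.scratch = bm.Variable(bm.zeros(3))  # skipped by name

    node = MyNode()
    assert 'state' in node.vars()
    assert 'scratch' not in node.vars()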
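The brainpy/inputs/currents.py changes replace raw jax.numpy buffers with
brainpy.math arrays, which accept NumPy-style in-place assignment and expose the
underlying jax array through `.value` (the object the input builders now return).
A small sketch of that pattern:

    import brainpy.math as bm

    current = bm.zeros(10)
    current[2:5] = 1.5      # in-place update; raw jax.numpy needs .at[...].set(...)
    buffer = current.value  # underlying jax.numpy array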
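The `num_step` argument of the fractional-order machinery is renamed to
`num_memory` throughout (FDEIntegrator, CaputoL1Schema, GLShortMemory, and the
fractional neuron model in fractional_models.py), so callers must update the
keyword. A hedged sketch against the `CaputoL1Schema` signature shown above,
with a made-up one-variable system:

    import brainpy as bp
    import brainpy.math as bm

    a = bm.Variable(bm.ones(1))

    def f(a, t):
      return -a  # simple decay, for illustration only

    intg = bp.fde.CaputoL1Schema(f, alpha=0.9, num_memory=100, inits=[a])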