diff --git a/brainpy/__init__.py b/brainpy/__init__.py index e2ac8336b..e0d2f0444 100644 --- a/brainpy/__init__.py +++ b/brainpy/__init__.py @@ -36,11 +36,9 @@ from . import integrators from .integrators import ode from .integrators import sde -from .integrators import dde from .integrators import fde from .integrators.ode import odeint from .integrators.sde import sdeint -from .integrators.dde import ddeint from .integrators.fde import fdeint from .integrators.joint_eq import JointEq diff --git a/brainpy/dyn/rates/populations.py b/brainpy/dyn/rates/populations.py index dc051148e..de4f4dddc 100644 --- a/brainpy/dyn/rates/populations.py +++ b/brainpy/dyn/rates/populations.py @@ -2,19 +2,16 @@ from typing import Union, Callable -import numpy as np -from jax.experimental.host_callback import id_tap - import brainpy.math as bm from brainpy import check from brainpy.dyn.base import NeuGroup +from brainpy.dyn.others.noises import OUProcess from brainpy.initialize import Initializer, Uniform, init_param, ZeroInit -from brainpy.integrators.dde import ddeint from brainpy.integrators.joint_eq import JointEq from brainpy.integrators.ode import odeint from brainpy.tools.checking import check_float, check_initializer +from brainpy.tools.errors import check_error_in_jit from brainpy.types import Shape, Tensor -from brainpy.dyn.others.noises import OUProcess __all__ = [ 'Population', @@ -307,7 +304,7 @@ def __init__( method=sde_method) # integral - self.integral = ddeint(method=method, + self.integral = odeint(method=method, f=JointEq([self.dx, self.dy]), state_delays={'V': self.x_delay}) @@ -327,15 +324,14 @@ def dx(self, x, t, y, x_ext): def dy(self, y, t, x, y_ext): return (x + self.a - self.b * y + y_ext) / self.tau - def _check_dt(self, dt, *args): - if np.absolute(dt - self.dt) > 1e-6: - raise ValueError(f'The "dt" {dt} used in model running is ' - f'not consistent with the "dt" {self.dt} ' - f'used in model definition.') + def _check_dt(self, dt): + raise ValueError(f'The "dt" {dt} used in model running is ' + f'not consistent with the "dt" {self.dt} ' + f'used in model definition.') def update(self, t, dt): if check.is_checking(): - id_tap(self._check_dt, dt) + check_error_in_jit(not bm.isclose(dt, self.dt), self._check_dt, dt) if self.x_ou is not None: self.input += self.x_ou.x self.x_ou.update(t, dt) @@ -882,5 +878,3 @@ def update(self, t, dt): self.i.value = bm.maximum(self.i + di * dt, 0.) self.Ie[:] = 0. self.Ii[:] = 0. - - diff --git a/brainpy/integrators/__init__.py b/brainpy/integrators/__init__.py index e8f3436a3..a8ecacef8 100644 --- a/brainpy/integrators/__init__.py +++ b/brainpy/integrators/__init__.py @@ -42,10 +42,5 @@ set_default_fdeint, register_fde_integrator) -# DDE tools -from . import dde -from .dde import ddeint - - # PDE tools from . import pde diff --git a/brainpy/integrators/dde/__init__.py b/brainpy/integrators/dde/__init__.py deleted file mode 100644 index 52dc2cb8d..000000000 --- a/brainpy/integrators/dde/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -# -*- coding: utf-8 -*- - -""" -Numerical methods for delay differential equations (DDEs). -""" - -from .base import * -from .generic import * -from .explicit_rk import * diff --git a/brainpy/integrators/dde/base.py b/brainpy/integrators/dde/base.py deleted file mode 100644 index 3b5e8eb6c..000000000 --- a/brainpy/integrators/dde/base.py +++ /dev/null @@ -1,137 +0,0 @@ -# -*- coding: utf-8 -*- - -from typing import Union, Callable, Dict - -import brainpy.math as bm -from brainpy.errors import DiffEqError -from brainpy.integrators.base import Integrator -from brainpy.integrators.constants import F, DT, unique_name -from brainpy.integrators.utils import get_args -from brainpy.tools.checking import check_dict_data - -__all__ = [ - 'DDEIntegrator', -] - - -class DDEIntegrator(Integrator): - """Basic numerical integrator for delay differential equations (DDEs). - """ - - def __init__( - self, - f: Callable, - var_type: str = None, - dt: Union[float, int] = None, - name: str = None, - show_code: bool = False, - state_delays: Dict[str, bm.TimeDelay] = None, - neutral_delays: Dict[str, bm.NeuTimeDelay] = None, - ): - dt = bm.get_dt() if dt is None else dt - parses = get_args(f) - variables = parses[0] # variable names, (before 't') - parameters = parses[1] # parameter names, (after 't') - arguments = parses[2] # function arguments - - # super initialization - super(DDEIntegrator, self).__init__(name=name, - variables=variables, - parameters=parameters, - arguments=arguments, - dt=dt, - state_delays=state_delays) - - # other settings - self.var_type = var_type - self.show_code = show_code - - # derivative function - self.derivative = {F: f} - self.f = f - - # code scope - self.code_scope = {F: f} - - # code lines - self.func_name = _f_names(f) - self.code_lines = [f'def {self.func_name}({", ".join(self.arguments)}):'] - - # delays - self._neutral_delays = dict() - if neutral_delays is not None: - check_dict_data(neutral_delays, - key_type=str, - val_type=bm.NeuTimeDelay) - for key, delay in neutral_delays.items(): - if key not in self.variables: - raise DiffEqError(f'"{key}" is not defined in the variables: {self.variables}') - self._neutral_delays[key] = delay - self.register_implicit_nodes(self._neutral_delays) - if (len(self.neutral_delays) + len(self.state_delays)) == 0: - raise DiffEqError('There is no delay variable, it should not be ' - 'a delay differential equation, please use "brainpy.odeint()". ' - 'Or, if you forget add delay variables, please set them with ' - '"state_delays" and "neutral_delays" arguments.') - - @property - def neutral_delays(self): - return self._neutral_delays - - @neutral_delays.setter - def neutral_delays(self, value): - raise ValueError('Cannot set "neutral_delays" by users.') - - def __call__(self, *args, **kwargs): - assert self.integral is not None, 'Please build the integrator first.' - # check arguments - for i, arg in enumerate(args): - kwargs[self.arg_names[i]] = arg - - # integral - new_vars = self.integral(**kwargs) - if len(self.variables) == 1: - dict_vars = {self.variables[0]: new_vars} - else: - dict_vars = {k: new_vars[i] for i, k in enumerate(self.variables)} - - dt = kwargs.pop(DT, self.dt) - # update neutral delay variables - if len(self.neutral_delays): - kwargs.update(dict_vars) - new_dvars = self.f(**kwargs) - if len(self.variables) == 1: - new_dvars = {self.variables[0]: new_dvars} - else: - new_dvars = {k: new_dvars[i] for i, k in enumerate(self.variables)} - for key, delay in self.neutral_delays.items(): - if isinstance(delay, bm.LengthDelay): - delay.update(new_dvars[key]) - elif isinstance(delay, bm.TimeDelay): - delay.update(kwargs['t'] + dt, new_dvars[key]) - else: - raise ValueError('Unknown delay variable. We only supports ' - 'brainpy.math.LengthDelay, brainpy.math.TimeDelay, ' - 'brainpy.math.NeutralDelay. ' - f'While we got {delay}') - - # update state delay variables - for key, delay in self.state_delays.items(): - if isinstance(delay, bm.LengthDelay): - delay.update(dict_vars[key]) - elif isinstance(delay, bm.TimeDelay): - delay.update(kwargs['t'] + dt, dict_vars[key]) - else: - raise ValueError('Unknown delay variable. We only supports ' - 'brainpy.math.LengthDelay, brainpy.math.TimeDelay, ' - 'brainpy.math.NeutralDelay. ' - f'While we got {delay}') - - return new_vars - - -def _f_names(f): - func_name = unique_name('dde') - if f.__name__.isidentifier(): - func_name += '_' + f.__name__ - return func_name diff --git a/brainpy/integrators/dde/explicit_rk.py b/brainpy/integrators/dde/explicit_rk.py deleted file mode 100644 index d6c22b36a..000000000 --- a/brainpy/integrators/dde/explicit_rk.py +++ /dev/null @@ -1,176 +0,0 @@ -# -*- coding: utf-8 -*- - -from typing import Dict -import brainpy.math as bm - -from brainpy.integrators.constants import F, DT -from brainpy.integrators.dde.base import DDEIntegrator -from brainpy.integrators.ode import common -from brainpy.integrators.utils import compile_code, check_kws -from brainpy.integrators.dde.generic import register_dde_integrator - -__all__ = [ - 'ExplicitRKIntegrator', - 'Euler', - 'MidPoint', - 'Heun2', - 'Ralston2', - 'RK2', - 'RK3', - 'Heun3', - 'Ralston3', - 'SSPRK3', - 'RK4', - 'Ralston4', - 'RK4Rule38', -] - - -class ExplicitRKIntegrator(DDEIntegrator): - A = [] # The A matrix in the Butcher tableau. - B = [] # The B vector in the Butcher tableau. - C = [] # The C vector in the Butcher tableau. - - def __init__(self, f, **kwargs): - super(ExplicitRKIntegrator, self).__init__(f=f, **kwargs) - - # integrator keywords - keywords = { - F: 'the derivative function', - # DT: 'the precision of numerical integration' - } - for v in self.variables: - keywords[f'{v}_new'] = 'the intermediate value' - for i in range(1, len(self.A) + 1): - keywords[f'd{v}_k{i}'] = 'the intermediate value' - for i in range(2, len(self.A) + 1): - keywords[f'k{i}_{v}_arg'] = 'the intermediate value' - keywords[f'k{i}_t_arg'] = 'the intermediate value' - check_kws(self.arg_names, keywords) - self.build() - - def build(self): - # step stage - common.step(self.variables, DT, self.A, self.C, self.code_lines, self.parameters) - # variable update - return_args = common.update(self.variables, DT, self.B, self.code_lines) - # returns - self.code_lines.append(f' return {", ".join(return_args)}') - # compile - self.integral = compile_code(code_scope={k: v for k, v in self.code_scope.items()}, - code_lines=self.code_lines, - show_code=self.show_code, - func_name=self.func_name) - - -class Euler(ExplicitRKIntegrator): - A = [(), ] - B = [1] - C = [0] - - -register_dde_integrator('euler', Euler) - - -class MidPoint(ExplicitRKIntegrator): - A = [(), (0.5,)] - B = [0, 1] - C = [0, 0.5] - - -register_dde_integrator('midpoint', MidPoint) - - -class Heun2(ExplicitRKIntegrator): - A = [(), (1,)] - B = [0.5, 0.5] - C = [0, 1] - - -register_dde_integrator('heun2', Heun2) - - -class Ralston2(ExplicitRKIntegrator): - A = [(), ('2/3',)] - B = [0.25, 0.75] - C = [0, '2/3'] - - -register_dde_integrator('ralston2', Ralston2) - - -class RK2(ExplicitRKIntegrator): - def __init__(self, f, beta=2 / 3, var_type=None, dt=None, name=None, - state_delays: Dict[str, bm.TimeDelay] = None, - neutral_delays: Dict[str, bm.NeuTimeDelay] = None): - self.A = [(), (beta,)] - self.B = [1 - 1 / (2 * beta), 1 / (2 * beta)] - self.C = [0, beta] - super(RK2, self).__init__(f=f, var_type=var_type, dt=dt, name=name, - state_delays=state_delays, neutral_delays=neutral_delays) - - -register_dde_integrator('rk2', RK2) - - -class RK3(ExplicitRKIntegrator): - A = [(), (0.5,), (-1, 2)] - B = ['1/6', '2/3', '1/6'] - C = [0, 0.5, 1] - - -register_dde_integrator('rk3', RK3) - - -class Heun3(ExplicitRKIntegrator): - A = [(), ('1/3',), (0, '2/3')] - B = [0.25, 0, 0.75] - C = [0, '1/3', '2/3'] - - -register_dde_integrator('heun3', Heun3) - - -class Ralston3(ExplicitRKIntegrator): - A = [(), (0.5,), (0, 0.75)] - B = ['2/9', '1/3', '4/9'] - C = [0, 0.5, 0.75] - - -register_dde_integrator('ralston3', Ralston3) - - -class SSPRK3(ExplicitRKIntegrator): - A = [(), (1,), (0.25, 0.25)] - B = ['1/6', '1/6', '2/3'] - C = [0, 1, 0.5] - - -register_dde_integrator('ssprk3', SSPRK3) - - -class RK4(ExplicitRKIntegrator): - A = [(), (0.5,), (0., 0.5), (0., 0., 1)] - B = ['1/6', '1/3', '1/3', '1/6'] - C = [0, 0.5, 0.5, 1] - - -register_dde_integrator('rk4', RK4) - - -class Ralston4(ExplicitRKIntegrator): - A = [(), (.4,), (.29697761, .15875964), (.21810040, -3.05096516, 3.83286476)] - B = [.17476028, -.55148066, 1.20553560, .17118478] - C = [0, .4, .45573725, 1] - - -register_dde_integrator('ralston4', Ralston4) - - -class RK4Rule38(ExplicitRKIntegrator): - A = [(), ('1/3',), ('-1/3', '1'), (1, -1, 1)] - B = [0.125, 0.375, 0.375, 0.125] - C = [0, '1/3', '2/3', 1] - - -register_dde_integrator('rk4_38rule', RK4Rule38) diff --git a/brainpy/integrators/dde/generic.py b/brainpy/integrators/dde/generic.py deleted file mode 100644 index 693c5b99a..000000000 --- a/brainpy/integrators/dde/generic.py +++ /dev/null @@ -1,135 +0,0 @@ -# -*- coding: utf-8 -*- - - -import warnings -from typing import Union, Dict - -import brainpy.math as bm -from .base import DDEIntegrator - -__all__ = [ - 'ddeint', - 'set_default_ddeint', - 'get_default_ddeint', - 'register_dde_integrator', - 'get_supported_methods', -] - -name2method = {} - -_DEFAULT_DDE_METHOD = 'euler' - - -def ddeint( - f=None, - method='euler', - var_type: str = None, - dt: Union[float, int] = None, - name: str = None, - show_code: bool = False, - state_delays: Dict[str, bm.TimeDelay] = None, - neutral_delays: Dict[str, bm.NeuTimeDelay] = None, - **kwargs -): - """Numerical integration for ODEs. - - .. deprecated:: 2.1.11 - Please use :py:func:`~.odeint` instead. This module will be removed since version 2.2.0. - - Parameters - ---------- - f : callable, function - The derivative function. - method : str - The shortcut name of the numerical integrator. - var_type: str - Variable type in the defined function. - dt: float, int - The time precision for integration. - name: str - The name. - show_code: bool - Whether show the formatted codes. - state_delays: dict - The state delay variables. - neutral_delays: dict - The neutral delay variable. - - Returns - ------- - integral : DDEIntegrator - The numerical solver of `f`. - """ - warnings.warn('Please use "brainpy.dde.ddeint" instead. ' - '"brainpy.dde.ddeint" is deprecated since ' - 'version 2.1.11. ', DeprecationWarning) - - method = _DEFAULT_DDE_METHOD if method is None else method - if method not in name2method: - raise ValueError(f'Unknown DDE numerical method "{method}". Currently ' - f'BrainPy only support: {list(name2method.keys())}') - - if f is None: - return lambda f: name2method[method](f, - var_type=var_type, - dt=dt, - name=name, - state_delays=state_delays, - neutral_delays=neutral_delays, - **kwargs) - else: - return name2method[method](f, - var_type=var_type, - dt=dt, - name=name, - state_delays=state_delays, - neutral_delays=neutral_delays, - **kwargs) - - -def set_default_ddeint(method): - """Set the default ODE numerical integrator method for differential equations. - - Parameters - ---------- - method : str, callable - Numerical integrator method. - """ - if not isinstance(method, str): - raise ValueError(f'Only support string, not {type(method)}.') - if method not in name2method: - raise ValueError(f'Unsupported ODE_INT numerical method: {method}.') - - global _DEFAULT_DDE_METHOD - _DEFAULT_DDE_METHOD = method - - -def get_default_ddeint(): - """Get the default ODE numerical integrator method. - - Returns - ------- - method : str - The default numerical integrator method. - """ - return _DEFAULT_DDE_METHOD - - -def register_dde_integrator(name, integrator): - """Register a new DDE integrator. - - Parameters - ---------- - name: ste - integrator: type - """ - if name in name2method: - raise ValueError(f'"{name}" has been registered in DDE integrators.') - if not issubclass(integrator, DDEIntegrator): - raise ValueError(f'"integrator" must be an instance of {DDEIntegrator.__name__}') - name2method[name] = integrator - - -def get_supported_methods(): - """Get all supported numerical methods for DDEs.""" - return list(name2method.keys()) diff --git a/brainpy/integrators/dde/tests/test_explicit_rk.py b/brainpy/integrators/dde/tests/test_explicit_rk.py deleted file mode 100644 index 082acf116..000000000 --- a/brainpy/integrators/dde/tests/test_explicit_rk.py +++ /dev/null @@ -1,155 +0,0 @@ -# -*- coding: utf-8 -*- - - -import unittest - -import brainpy as bp -import brainpy.math as bm - - -class TestExplicitRKStateDelay(unittest.TestCase): - def test_euler(self): - xdelay = bm.TimeDelay(bm.zeros((1,)), 1., before_t0=-1., interp_method='round') - - @bp.ddeint(method='euler', state_delays={'x': xdelay}) - def equation(x, t, ): - return -xdelay(t - 1) - - runner = bp.integrators.IntegratorRunner(equation, monitors=['x']) - runner.run(20.) - - bp.visualize.line_plot(runner.mon.ts, runner.mon['x'], show=True) - - def test_midpoint(self): - xdelay = bm.TimeDelay(bm.zeros((1,)), 1., before_t0=-1., interp_method='round') - - @bp.ddeint(method='midpoint', state_delays={'x': xdelay}) - def equation(x, t, ): - return -xdelay(t - 1) - - runner = bp.integrators.IntegratorRunner(equation, monitors=['x']) - runner.run(20.) - - bp.visualize.line_plot(runner.mon.ts, runner.mon['x'], show=True) - - def test_heun2(self): - xdelay = bm.TimeDelay(bm.zeros((1,)), 1., before_t0=-1., interp_method='round') - - @bp.ddeint(method='heun2', state_delays={'x': xdelay}) - def equation(x, t, ): - return -xdelay(t - 1) - - runner = bp.integrators.IntegratorRunner(equation, monitors=['x']) - runner.run(20.) - - bp.visualize.line_plot(runner.mon.ts, runner.mon['x'], show=True) - - def test_ralston2(self): - xdelay = bm.TimeDelay(bm.zeros((1,)), 1., before_t0=-1., interp_method='round') - - @bp.ddeint(method='ralston2', state_delays={'x': xdelay}) - def equation(x, t, ): - return -xdelay(t - 1) - - runner = bp.integrators.IntegratorRunner(equation, monitors=['x']) - runner.run(20.) - - bp.visualize.line_plot(runner.mon.ts, runner.mon['x'], show=True) - - def test_rk2(self): - xdelay = bm.TimeDelay(bm.zeros((1,)), 1., before_t0=-1., interp_method='round') - - @bp.ddeint(method='rk2', - state_delays={'x': xdelay}) - def equation(x, t, ): - return -xdelay(t - 1) - - runner = bp.integrators.IntegratorRunner(equation, monitors=['x']) - runner.run(20.) - - bp.visualize.line_plot(runner.mon.ts, runner.mon['x'], show=True) - - def test_rk3(self): - xdelay = bm.TimeDelay(bm.zeros((1,)), 1., before_t0=-1., interp_method='round') - - @bp.ddeint(method='rk3', state_delays={'x': xdelay}) - def equation(x, t, ): - return -xdelay(t - 1) - - runner = bp.integrators.IntegratorRunner(equation, monitors=['x']) - runner.run(20.) - - bp.visualize.line_plot(runner.mon.ts, runner.mon['x'], show=True) - - def test_heun3(self): - xdelay = bm.TimeDelay(bm.zeros((1,)), 1., before_t0=-1., interp_method='round') - - @bp.ddeint(method='heun3', state_delays={'x': xdelay}) - def equation(x, t, ): - return -xdelay(t - 1) - - runner = bp.integrators.IntegratorRunner(equation, monitors=['x']) - runner.run(20.) - - bp.visualize.line_plot(runner.mon.ts, runner.mon['x'], show=True) - - def test_ralston3(self): - xdelay = bm.TimeDelay(bm.zeros((1,)), 1., before_t0=-1., interp_method='round') - - @bp.ddeint(method='ralston3', state_delays={'x': xdelay}) - def equation(x, t, ): - return -xdelay(t - 1) - - runner = bp.integrators.IntegratorRunner(equation, monitors=['x']) - runner.run(20.) - - bp.visualize.line_plot(runner.mon.ts, runner.mon['x'], show=True) - - def test_ssprk3(self): - xdelay = bm.TimeDelay(bm.zeros((1,)), 1., before_t0=-1., interp_method='round') - - @bp.ddeint(method='ssprk3', state_delays={'x': xdelay}) - def equation(x, t, ): - return -xdelay(t - 1) - - runner = bp.integrators.IntegratorRunner(equation, monitors=['x']) - runner.run(20.) - - bp.visualize.line_plot(runner.mon.ts, runner.mon['x'], show=True) - - def test_rk4(self): - xdelay = bm.TimeDelay(bm.zeros((1,)), 1., before_t0=-1., interp_method='round') - - @bp.ddeint(method='rk4', state_delays={'x': xdelay}) - def equation(x, t, ): - return -xdelay(t - 1) - - runner = bp.integrators.IntegratorRunner(equation, monitors=['x']) - runner.run(20.) - - bp.visualize.line_plot(runner.mon.ts, runner.mon['x'], show=True) - - def test_ralston4(self): - xdelay = bm.TimeDelay(bm.zeros((1,)), 1., before_t0=-1., interp_method='round') - - @bp.ddeint(method='ralston4', state_delays={'x': xdelay}) - def equation(x, t, ): - return -xdelay(t - 1) - - runner = bp.integrators.IntegratorRunner(equation, monitors=['x']) - runner.run(20.) - - bp.visualize.line_plot(runner.mon.ts, runner.mon['x'], show=True) - - def test_rk4_38rule(self): - xdelay = bm.TimeDelay(bm.zeros((1,)), 1., before_t0=-1., interp_method='round') - - @bp.ddeint(method='rk4_38rule', state_delays={'x': xdelay}) - def equation(x, t, ): - return -xdelay(t - 1) - - runner = bp.integrators.IntegratorRunner(equation, monitors=['x']) - runner.run(20.) - - bp.visualize.line_plot(runner.mon.ts, runner.mon['x'], show=True) - diff --git a/brainpy/integrators/fde/Caputo.py b/brainpy/integrators/fde/Caputo.py index 3476a418c..68bf5bfd0 100644 --- a/brainpy/integrators/fde/Caputo.py +++ b/brainpy/integrators/fde/Caputo.py @@ -8,13 +8,13 @@ from typing import Union, Dict import jax.numpy as jnp -from jax.experimental.host_callback import id_tap import brainpy.math as bm from brainpy import check from brainpy.errors import UnsupportedError from brainpy.integrators.constants import DT from brainpy.integrators.utils import check_inits, format_args +from brainpy.tools.errors import check_error_in_jit from .base import FDEIntegrator from .generic import register_fde_integrator @@ -151,19 +151,19 @@ def __init__( self.set_integral(self._integral_func) - def _check_step(self, args, transform): + def _check_step(self, args): dt, t = args - if self.num_step * dt < t: - raise ValueError(f'The maximum number of step is {self.num_step}, ' - f'however, the current time {t} require a time ' - f'step number {t / dt}.') + raise ValueError(f'The maximum number of step is {self.num_step}, ' + f'however, the current time {t} require a time ' + f'step number {t / dt}.') def _integral_func(self, *args, **kwargs): # format arguments all_args = format_args(args, kwargs, self.arg_names) + t = all_args['t'] dt = all_args.pop(DT, self.dt) if check.is_checking(): - id_tap(self._check_step, (dt, all_args['t'])) + check_error_in_jit(self.num_step * dt < t, self._check_step, (dt, t)) # derivative values devs = self.f(**all_args) @@ -375,19 +375,19 @@ def hists(self, var=None, numpy=True): hists_ = hists_.numpy() return hists_ - def _check_step(self, args, transform): + def _check_step(self, args): dt, t = args - if self.num_step * dt < t: - raise ValueError(f'The maximum number of step is {self.num_step}, ' - f'however, the current time {t} require a time ' - f'step number {t / dt}.') + raise ValueError(f'The maximum number of step is {self.num_step}, ' + f'however, the current time {t} require a time ' + f'step number {t / dt}.') def _integral_func(self, *args, **kwargs): # format arguments all_args = format_args(args, kwargs, self.arg_names) + t = all_args['t'] dt = all_args.pop(DT, self.dt) if check.is_checking(): - id_tap(self._check_step, (dt, all_args['t'])) + check_error_in_jit(self.num_step * dt < t, self._check_step, (dt, t)) # derivative values devs = self.f(**all_args) diff --git a/brainpy/math/delayvars.py b/brainpy/math/delayvars.py index fcb6592cf..70166868e 100644 --- a/brainpy/math/delayvars.py +++ b/brainpy/math/delayvars.py @@ -3,9 +3,7 @@ from typing import Union, Callable, Tuple import jax.numpy as jnp -import numpy as np from jax import vmap -from jax.experimental.host_callback import id_tap from jax.lax import cond from brainpy import check @@ -15,6 +13,7 @@ from brainpy.math.jaxarray import ndarray, Variable, JaxArray from brainpy.math.setting import get_dt from brainpy.tools.checking import check_float, check_integer +from brainpy.tools.errors import check_error_in_jit __all__ = [ 'AbstractDelay', @@ -195,26 +194,24 @@ def reset(self, self.data[:-1] = before_t0 self._before_type = _DATA_BEFORE - def _check_time(self, times, transforms): + def _check_time1(self, times): prev_time, current_time = times - current_time = current_time[0] - if prev_time > current_time + 1e-6: - raise ValueError(f'\n' - f'!!! Error in {self.__class__.__name__}: \n' - f'The request time should be less than the ' - f'current time {current_time}. But we ' - f'got {prev_time} > {current_time}') - lower_time = current_time - self.delay_len - if prev_time < lower_time - self.dt: - raise ValueError(f'\n' - f'!!! Error in {self.__class__.__name__}: \n' - f'The request time of the variable should be in ' - f'[{lower_time}, {current_time}], but we got {prev_time}') + raise ValueError(f'The request time should be less than the ' + f'current time {current_time}. But we ' + f'got {prev_time} > {current_time}') + + def _check_time2(self, times): + prev_time, current_time = times + raise ValueError(f'The request time of the variable should be in ' + f'[{current_time - self.delay_len}, {current_time}], ' + f'but we got {prev_time}') def __call__(self, time, indices=None): # check if check.is_checking(): - id_tap(self._check_time, (time, self.current_time)) + current_time = self.current_time[0] + check_error_in_jit(time > current_time + 1e-6, self._check_time1, (time, current_time)) + check_error_in_jit(time < current_time - self.delay_len - self.dt, self._check_time2, (time, current_time)) if self._before_type == _FUNC_BEFORE: res = cond(time < self.t0, self._before_t0, @@ -338,20 +335,14 @@ def reset( else: raise ValueError(f'"delay_data" does not support {type(initial_delay_data)}') - def _check_delay(self, delay_len, transforms): - if isinstance(delay_len, ndarray): - delay_len = delay_len.value - if np.any(delay_len >= self.num_delay_step): - raise ValueError(f'\n' - f'!!! Error in {self.__class__.__name__}: \n' - f'The request delay length should be less than the ' - f'maximum delay {self.num_delay_step}. But we ' - f'got {delay_len}') + def _check_delay(self, delay_len): + raise ValueError(f'The request delay length should be less than the ' + f'maximum delay {self.num_delay_step}. But we got {delay_len}') def __call__(self, delay_len, *indices): # check if check.is_checking(): - id_tap(self._check_delay, delay_len) + check_error_in_jit(bm.any(delay_len >= self.num_delay_step), self._check_delay, delay_len) # the delay length delay_idx = (self.idx[0] - delay_len - 1) % self.num_delay_step if not jnp.issubdtype(delay_idx.dtype, jnp.integer): diff --git a/brainpy/math/random.py b/brainpy/math/random.py index b5ac71338..49a462925 100644 --- a/brainpy/math/random.py +++ b/brainpy/math/random.py @@ -2,17 +2,17 @@ from collections import namedtuple from functools import partial - from operator import index + import jax import numpy as np from jax import lax, jit, vmap, numpy as jnp, random as jr, core from jax._src import dtypes from jax.experimental.host_callback import call -from jax.experimental import checkify from jax.tree_util import register_pytree_node from brainpy.math.jaxarray import JaxArray, Variable +from brainpy.tools.errors import check_error_in_jit from .utils import wraps __all__ = [ @@ -648,10 +648,12 @@ def truncated_normal(self, lower, upper, size, scale=None): else: return JaxArray(rands * scale) + def _check_p(self, p): + raise ValueError(f'Parameter p should be within [0, 1], but we got {p}') + def bernoulli(self, p, size=None): - p = _remove_jax_array(p) - p = _check_py_seq(p) - checkify.check(jnp.all(jnp.logical_and(p >= 0, p <= 1)), 'Bernoulli parameter p should be within [0, 1]') + p = _check_py_seq(_remove_jax_array(p)) + check_error_in_jit(jnp.any(jnp.logical_and(p < 0, p > 1)), self._check_p, p) if size is None: size = jnp.shape(p) return JaxArray(jr.bernoulli(self.split_key(), p=p, shape=_size2shape(size))) @@ -668,11 +670,9 @@ def lognormal(self, mean=None, sigma=None, size=None): return JaxArray(samples) def binomial(self, n, p, size=None): - n = n.value if isinstance(n, JaxArray) else n - p = p.value if isinstance(p, JaxArray) else p - n = _check_py_seq(n) - p = _check_py_seq(p) - checkify.check(jnp.all(jnp.logical_and(p >= 0, p <= 1)), '"p" must be in [0, 1].') + n = _check_py_seq(n.value if isinstance(n, JaxArray) else n) + p = _check_py_seq(p.value if isinstance(p, JaxArray) else p) + check_error_in_jit(jnp.any(jnp.logical_and(p < 0, p > 1)), self._check_p, p) if size is None: size = jnp.broadcast_shapes(jnp.shape(n), jnp.shape(p)) return JaxArray(_binomial(self.split_key(), p, n, shape=_size2shape(size))) @@ -703,10 +703,13 @@ def geometric(self, p, size=None): r = jnp.floor(jnp.log1p(-u) / jnp.log1p(-p)) return JaxArray(r) + def _check_p2(self, p): + raise ValueError(f'We require `sum(pvals[:-1]) <= 1`. But we got {p}') + def multinomial(self, n, pvals, size=None): n = _check_py_seq(_remove_jax_array(n)) pvals = _check_py_seq(_remove_jax_array(pvals)) - checkify.check(jnp.sum(pvals[:-1]) <= 1., 'We require `sum(pvals[:-1]) <= 1`.') + check_error_in_jit(jnp.sum(pvals[:-1]) > 1., self._check_p2, pvals) if isinstance(n, jax.core.Tracer): raise ValueError("The total count parameter `n` should not be a jax abstract array.") size = _size2shape(size) diff --git a/brainpy/math/tests/test_numpy_indexing.py b/brainpy/math/tests/test_numpy_indexing.py index e9c8dedd4..05071745c 100644 --- a/brainpy/math/tests/test_numpy_indexing.py +++ b/brainpy/math/tests/test_numpy_indexing.py @@ -404,20 +404,19 @@ def check_grads(f, args, order, atol=None, rtol=None, eps=None): MODES = ["clip", "drop", "promise_in_bounds"] - class IndexingTest(jtu.JaxTestCase): """Tests for Numpy indexing translation rules.""" @parameterized.named_parameters( - jtu.cases_from_list({"testcase_name": - "{}_inshape={}_indexer={}".format(name, jtu.format_shape_dtype_string(shape, dtype), - indexer), - "shape": shape, - "dtype": dtype, - "indexer": indexer} - for name, index_specs in STATIC_INDEXING_TESTS - for shape, indexer, _ in index_specs - for dtype in all_dtypes)) + jtu.cases_from_list( + {"testcase_name": "{}_inshape={}_indexer={}".format(name, jtu.format_shape_dtype_string(shape, dtype), indexer), + "shape": shape, + "dtype": dtype, + "indexer": indexer} + for name, index_specs in STATIC_INDEXING_TESTS + for shape, indexer, _ in index_specs + for dtype in all_dtypes) + ) def testStaticIndexing(self, shape, dtype, indexer): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype)] @@ -430,11 +429,15 @@ def testStaticIndexing(self, shape, dtype, indexer): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) - @parameterized.named_parameters(jtu.cases_from_list({ - "testcase_name": f"_{funcname}", "funcname": funcname} - for funcname in - ["negative", "sin", "cos", "square", "sqrt", "log", "exp"])) + @parameterized.named_parameters( + jtu.cases_from_list({"testcase_name": f"_{funcname}", "funcname": funcname} + for funcname in + ["negative", "sin", "cos", "square", "sqrt", "log", "exp"]) + ) def testIndexApply(self, funcname, size=10, dtype='float32'): + if not hasattr(jnp.zeros(1).at[0], 'apply'): + self.skipTest('Has not apply() function') + rng = jtu.rand_default(self.rng()) idx_rng = jtu.rand_int(self.rng(), -size, size) np_func = getattr(np, funcname) diff --git a/brainpy/math/tests/test_random.py b/brainpy/math/tests/test_random.py index 485449e2a..0ed341e58 100644 --- a/brainpy/math/tests/test_random.py +++ b/brainpy/math/tests/test_random.py @@ -319,13 +319,6 @@ def test_multinominal2(self): self.assertTupleEqual(a.shape, (3,)) self.assertTrue(a.sum() == 100) - def test_multinominal3(self): - with self.assertRaises(ValueError): - a = bm.random.multinomial(100, (0.5, 0.6, 0.3)) - with self.assertRaises(ValueError): - f = jax.jit(bm.random.multinomial, static_argnums=2) - a = f(100, (0.5, 0.6, 0.3), 2) - def test_multivariate_normal1(self): # self.skipTest('Windows jaxlib error') a = np.random.multivariate_normal([1, 2], [[1, 0], [0, 1]], size=3) diff --git a/brainpy/tools/__init__.py b/brainpy/tools/__init__.py index eb0f9b809..628663331 100644 --- a/brainpy/tools/__init__.py +++ b/brainpy/tools/__init__.py @@ -2,3 +2,4 @@ from .codes import * from .others import * +from .errors import * diff --git a/brainpy/tools/errors.py b/brainpy/tools/errors.py new file mode 100644 index 000000000..b23189d2b --- /dev/null +++ b/brainpy/tools/errors.py @@ -0,0 +1,35 @@ +# -*- coding: utf-8 -*- + + +from jax.lax import cond +from jax.experimental.host_callback import id_tap + +__all__ = [ + 'check_error_in_jit' +] + + +def _make_err_func(f): + f2 = lambda arg, transforms: f(arg) + + def err_f(x): + id_tap(f2, x) + return + return err_f + + +def check_error_in_jit(pred, err_f, err_arg=None): + """Check errors in a jit function. + + Parameters + ---------- + pred: bool + The boolean prediction. + err_f: callable + The error function, which raise errors. + err_arg: any + The arguments which passed into `err_f`. + """ + cond(pred, _make_err_func(err_f), lambda _: None, err_arg) + + diff --git a/docs/apis/integrators.rst b/docs/apis/integrators.rst index cf6fb02bd..dc6904410 100644 --- a/docs/apis/integrators.rst +++ b/docs/apis/integrators.rst @@ -11,6 +11,5 @@ auto/integrators/joint_eq integrators/ODE integrators/SDE - integrators/DDE integrators/FDE diff --git a/docs/apis/integrators/DDE.rst b/docs/apis/integrators/DDE.rst deleted file mode 100644 index 50a9811e1..000000000 --- a/docs/apis/integrators/DDE.rst +++ /dev/null @@ -1,14 +0,0 @@ -Numerical Methods for DDEs -========================== - -.. currentmodule:: brainpy.integrators.dde -.. automodule:: brainpy.integrators.dde - - -.. toctree:: - :maxdepth: 2 - - ../auto/integrators/dde_base - ../auto/integrators/dde_generic - ../auto/integrators/dde_explicit_rk - diff --git a/docs/apis/tools.rst b/docs/apis/tools.rst index cd3d44094..20c232eef 100644 --- a/docs/apis/tools.rst +++ b/docs/apis/tools.rst @@ -10,4 +10,5 @@ auto/tools/checking auto/tools/codes + auto/tools/errors auto/tools/others diff --git a/docs/auto_generater.py b/docs/auto_generater.py index c1d6090e1..4ad8ea76c 100644 --- a/docs/auto_generater.py +++ b/docs/auto_generater.py @@ -358,17 +358,6 @@ def generate_integrators_doc(path='apis/auto/integrators/'): filename=os.path.join(path, 'sde_srk_scalar.rst'), header='SRK methods for scalar Wiener process') - # DDE - write_module(module_name='brainpy.integrators.dde.base', - filename=os.path.join(path, 'dde_base.rst'), - header='Base Integrator') - write_module(module_name='brainpy.integrators.dde.generic', - filename=os.path.join(path, 'dde_generic.rst'), - header='Generic Functions') - write_module(module_name='brainpy.integrators.dde.explicit_rk', - filename=os.path.join(path, 'dde_explicit_rk.rst'), - header='Explicit Runge-Kutta Methods') - # FDE write_module(module_name='brainpy.integrators.fde.base', filename=os.path.join(path, 'fde_base.rst'), @@ -571,6 +560,9 @@ def generate_tools_docs(path='apis/auto/tools/'): write_module(module_name='brainpy.tools.others', filename=os.path.join(path, 'others.rst'), header='Other Tools') + write_module(module_name='brainpy.tools.errors', + filename=os.path.join(path, 'errors.rst'), + header='Error Tools') def generate_compact_docs(path='apis/auto/compat/'): diff --git a/examples/simulation/Wang_2002_decision_making_spiking.py b/examples/simulation/Wang_2002_decision_making_spiking.py index ecc0cde81..1d2e60f68 100644 --- a/examples/simulation/Wang_2002_decision_making_spiking.py +++ b/examples/simulation/Wang_2002_decision_making_spiking.py @@ -3,7 +3,6 @@ import brainpy as bp import brainpy.math as bm -bp.check.turn_off() bm.set_platform('cpu') import matplotlib.pyplot as plt