diff --git a/brainpy/__init__.py b/brainpy/__init__.py
index 2b3c621dc..9dc05d28e 100644
--- a/brainpy/__init__.py
+++ b/brainpy/__init__.py
@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-
 
-__version__ = "2.3.5"
+__version__ = "2.3.6"
 
 
 # fundamental supporting modules
@@ -75,8 +75,7 @@
                                      TwoEndConn as TwoEndConn,
                                      CondNeuGroup as CondNeuGroup,
                                      Channel as Channel)
-from brainpy._src.dyn.transform import (NoSharedArg as NoSharedArg,  # transformations
-                                        LoopOverTime as LoopOverTime,)
+from brainpy._src.dyn.transform import (LoopOverTime as LoopOverTime,)
 from brainpy._src.dyn.runners import (DSRunner as DSRunner)  # runner
 from brainpy._src.dyn.context import share, Delay
 
@@ -207,7 +206,6 @@
 dyn.__dict__['TwoEndConn'] = TwoEndConn
 dyn.__dict__['CondNeuGroup'] = CondNeuGroup
 dyn.__dict__['Channel'] = Channel
-dyn.__dict__['NoSharedArg'] = NoSharedArg
 dyn.__dict__['LoopOverTime'] = LoopOverTime
 dyn.__dict__['DSRunner'] = DSRunner
 
diff --git a/brainpy/_src/dyn/transform.py b/brainpy/_src/dyn/transform.py
index 0c4f95225..7b825762d 100644
--- a/brainpy/_src/dyn/transform.py
+++ b/brainpy/_src/dyn/transform.py
@@ -14,7 +14,6 @@
 
 __all__ = [
   'LoopOverTime',
-  'NoSharedArg',
 ]
 
 
@@ -207,12 +206,13 @@ def __call__(
     if isinstance(duration_or_xs, float):
       shared = tools.DotDict()
       if self.t0 is not None:
-        shared['t'] = jnp.arange(self.t0.value, duration_or_xs, self.dt)
+        shared['t'] = jnp.arange(0, duration_or_xs, self.dt) + self.t0.value
       if self.i0 is not None:
-        shared['i'] = jnp.arange(self.i0.value, shared['t'].shape[0])
+        shared['i'] = jnp.arange(0, shared['t'].shape[0]) + self.i0.value
       xs = None
       if self.no_state:
         raise ValueError('Under the `no_state=True` setting, input cannot be a duration.')
+      length = shared['t'].shape
 
     else:
       inp_err_msg = ('\n'
@@ -278,8 +278,8 @@ def __call__(
       else:
         shared = tools.DotDict()
-        shared['t'] = jnp.arange(self.t0.value, self.dt * length[0], self.dt)
-        shared['i'] = jnp.arange(self.i0.value, length[0])
+        shared['t'] = jnp.arange(0, self.dt * length[0], self.dt) + self.t0.value
+        shared['i'] = jnp.arange(0, length[0]) + self.i0.value
 
       assert not self.no_state
       results = bm.for_loop(functools.partial(self._run, self.shared_arg),
@@ -295,6 +295,10 @@ def __call__(
 
   def reset_state(self, batch_size=None):
     self.target.reset_state(batch_size)
+    if self.i0 is not None:
+      self.i0.value = jnp.asarray(0)
+    if self.t0 is not None:
+      self.t0.value = jnp.asarray(0.)
 
   def _run(self, static_sh, dyn_sh, x):
     share.save(**static_sh, **dyn_sh)
@@ -304,50 +308,3 @@ def _run(self, static_sh, dyn_sh, x):
       self.target.clear_input()
     return outs
 
-
-class NoSharedArg(DynSysToBPObj):
-  """Transform an instance of :py:class:`~.DynamicalSystem` into a callable
-  :py:class:`~.BrainPyObject` :math:`y=f(x)`.
-
-  .. note::
-
-     This object transforms a :py:class:`~.DynamicalSystem` into a :py:class:`~.BrainPyObject`.
-
-     If some children nodes need shared arguments, like :py:class:`~.Dropout` or
-     :py:class:`~.LIF` models, using ``NoSharedArg`` will cause errors.
-
-  Examples
-  --------
-
-  >>> import brainpy as bp
-  >>> import brainpy.math as bm
-  >>> l = bp.Sequential(bp.layers.Dense(100, 10),
-  >>>                   bm.relu,
-  >>>                   bp.layers.Dense(10, 2))
-  >>> l = bp.NoSharedArg(l)
-  >>> l(bm.random.random(256, 100))
-
-  Parameters
-  ----------
-  target: DynamicalSystem
-    The target to transform.
-  name: str
-    The transformed object name.
-  """
-
-  def __init__(self, target: DynamicalSystem, name: str = None):
-    super().__init__(target=target, name=name)
-    if isinstance(target, Sequential) and target.no_shared_arg:
-      raise ValueError(f'It is a {Sequential.__name__} object with `no_shared_arg=True`, '
-                       f'which has already able to be called with `f(x)`. ')
-
-  def __call__(self, *args, **kwargs):
-    return self.target(tools.DotDict(), *args, **kwargs)
-
-  def reset(self, batch_size=None):
-    """Reset function which reset the whole variables in the model.
-    """
-    self.target.reset(batch_size)
-
-  def reset_state(self, batch_size=None):
-    self.target.reset_state(batch_size)
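The `transform.py` hunks above change how `LoopOverTime` builds the shared time variables: `shared['t']` and `shared['i']` are now zero-based `jnp.arange` sequences shifted by the stored `t0`/`i0` offsets, and `reset_state` rewinds those offsets. A minimal usage sketch of the resulting behavior; the `LIF` target and its size are illustrative only, and it assumes the transform advances `t0`/`i0` across calls, as the offset arithmetic implies:

import brainpy as bp

runner = bp.LoopOverTime(bp.neurons.LIF(10))
out1 = runner(100.)   # simulate 100 ms of model time, starting from t0 = 0
out2 = runner(100.)   # continues from the advanced t0/i0 offsets
runner.reset_state()  # with this patch, also rewinds t0 and i0 to zero
out3 = runner(100.)   # runs from t = 0 again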
- """ - - def __init__(self, target: DynamicalSystem, name: str = None): - super().__init__(target=target, name=name) - if isinstance(target, Sequential) and target.no_shared_arg: - raise ValueError(f'It is a {Sequential.__name__} object with `no_shared_arg=True`, ' - f'which has already able to be called with `f(x)`. ') - - def __call__(self, *args, **kwargs): - return self.target(tools.DotDict(), *args, **kwargs) - - def reset(self, batch_size=None): - """Reset function which reset the whole variables in the model. - """ - self.target.reset(batch_size) - - def reset_state(self, batch_size=None): - self.target.reset_state(batch_size) diff --git a/brainpy/_src/math/environment.py b/brainpy/_src/math/environment.py index 16eac59a0..a51a5c35a 100644 --- a/brainpy/_src/math/environment.py +++ b/brainpy/_src/math/environment.py @@ -324,7 +324,7 @@ def clone(self): return self.__class__() -def set_environment( +def set( mode: modes.Mode = None, dt: float = None, x64: bool = None, @@ -381,6 +381,9 @@ def set_environment( set_complex(complex_) +set_environment = set + + class environment(_DecoratorContextManager): r"""Context-manager that sets a computing environment for brain dynamics computation. diff --git a/brainpy/_src/math/surrogate/one_input.py b/brainpy/_src/math/surrogate/one_input.py index b58844df1..9d5b85b78 100644 --- a/brainpy/_src/math/surrogate/one_input.py +++ b/brainpy/_src/math/surrogate/one_input.py @@ -34,6 +34,14 @@ ] +class Sigmoid: + def __init__(self, alpha=4., orgin=False): + self.alpha = alpha + self.orgin = orgin + + def __call__(self, x: Union[jax.Array, Array]): + return sigmoid(x, alpha=self.alpha, origin=self.origin) + @vjp_custom(['x'], dict(alpha=4., origin=False), dict(origin=[True, False])) def sigmoid( @@ -105,6 +113,15 @@ def grad(dz): return z, grad +class PiecewiseQuadratic: + def __init__(self, alpha=1., origin=False): + self.alpha = alpha + self.origin = origin + + def __call__(self, x: Union[jax.Array, Array]): + return piecewise_quadratic(x, alpha=self.alpha, origin=self.origin) + + @vjp_custom(['x'], dict(alpha=1., origin=False), dict(origin=[True, False])) def piecewise_quadratic( x: Union[jax.Array, Array], @@ -195,6 +212,15 @@ def grad(dz): return z, grad +class PiecewiseExp: + def __init__(self, alpha=1., origin=False): + self.alpha = alpha + self.origin = origin + + def __call__(self, x: Union[jax.Array, Array]): + return piecewise_exp(x, alpha=self.alpha, origin=self.origin) + + @vjp_custom(['x'], dict(alpha=1., origin=False), dict(origin=[True, False])) def piecewise_exp( x: Union[jax.Array, Array], @@ -271,6 +297,15 @@ def grad(dz): return z, grad +class SoftSign: + def __init__(self, alpha=1., origin=False): + self.alpha = alpha + self.origin = origin + + def __call__(self, x: Union[jax.Array, Array]): + return soft_sign(x, alpha=self.alpha, origin=self.origin) + + @vjp_custom(['x'], dict(alpha=1., origin=False), dict(origin=[True, False])) def soft_sign( x: Union[jax.Array, Array], @@ -342,6 +377,15 @@ def grad(dz): return z, grad +class Arctan: + def __init__(self, alpha=1., origin=False): + self.alpha = alpha + self.origin = origin + + def __call__(self, x: Union[jax.Array, Array]): + return arctan(x, alpha=self.alpha, origin=self.origin) + + @vjp_custom(['x'], dict(alpha=1., origin=False), dict(origin=[True, False])) def arctan( x: Union[jax.Array, Array], @@ -412,6 +456,15 @@ def grad(dz): return z, grad +class NonzeroSignLog: + def __init__(self, alpha=1., origin=False): + self.alpha = alpha + self.origin = origin + + def __call__(self, 
diff --git a/brainpy/_src/math/surrogate/one_input.py b/brainpy/_src/math/surrogate/one_input.py
index b58844df1..9d5b85b78 100644
--- a/brainpy/_src/math/surrogate/one_input.py
+++ b/brainpy/_src/math/surrogate/one_input.py
@@ -34,6 +34,14 @@
 ]
 
 
+class Sigmoid:
+  def __init__(self, alpha=4., origin=False):
+    self.alpha = alpha
+    self.origin = origin
+
+  def __call__(self, x: Union[jax.Array, Array]):
+    return sigmoid(x, alpha=self.alpha, origin=self.origin)
+
 
 @vjp_custom(['x'], dict(alpha=4., origin=False), dict(origin=[True, False]))
 def sigmoid(
@@ -105,6 +113,15 @@ def grad(dz):
   return z, grad
 
 
+class PiecewiseQuadratic:
+  def __init__(self, alpha=1., origin=False):
+    self.alpha = alpha
+    self.origin = origin
+
+  def __call__(self, x: Union[jax.Array, Array]):
+    return piecewise_quadratic(x, alpha=self.alpha, origin=self.origin)
+
+
 @vjp_custom(['x'], dict(alpha=1., origin=False), dict(origin=[True, False]))
 def piecewise_quadratic(
   x: Union[jax.Array, Array],
@@ -195,6 +212,15 @@ def grad(dz):
   return z, grad
 
 
+class PiecewiseExp:
+  def __init__(self, alpha=1., origin=False):
+    self.alpha = alpha
+    self.origin = origin
+
+  def __call__(self, x: Union[jax.Array, Array]):
+    return piecewise_exp(x, alpha=self.alpha, origin=self.origin)
+
+
 @vjp_custom(['x'], dict(alpha=1., origin=False), dict(origin=[True, False]))
 def piecewise_exp(
   x: Union[jax.Array, Array],
@@ -271,6 +297,15 @@ def grad(dz):
   return z, grad
 
 
+class SoftSign:
+  def __init__(self, alpha=1., origin=False):
+    self.alpha = alpha
+    self.origin = origin
+
+  def __call__(self, x: Union[jax.Array, Array]):
+    return soft_sign(x, alpha=self.alpha, origin=self.origin)
+
+
 @vjp_custom(['x'], dict(alpha=1., origin=False), dict(origin=[True, False]))
 def soft_sign(
   x: Union[jax.Array, Array],
@@ -342,6 +377,15 @@ def grad(dz):
   return z, grad
 
 
+class Arctan:
+  def __init__(self, alpha=1., origin=False):
+    self.alpha = alpha
+    self.origin = origin
+
+  def __call__(self, x: Union[jax.Array, Array]):
+    return arctan(x, alpha=self.alpha, origin=self.origin)
+
+
 @vjp_custom(['x'], dict(alpha=1., origin=False), dict(origin=[True, False]))
 def arctan(
   x: Union[jax.Array, Array],
@@ -412,6 +456,15 @@ def grad(dz):
   return z, grad
 
 
+class NonzeroSignLog:
+  def __init__(self, alpha=1., origin=False):
+    self.alpha = alpha
+    self.origin = origin
+
+  def __call__(self, x: Union[jax.Array, Array]):
+    return nonzero_sign_log(x, alpha=self.alpha, origin=self.origin)
+
+
 @vjp_custom(['x'], dict(alpha=1., origin=False), statics={'origin': [True, False]})
 def nonzero_sign_log(
   x: Union[jax.Array, Array],
@@ -495,6 +548,15 @@ def grad(dz):
   return z, grad
 
 
+class ERF:
+  def __init__(self, alpha=1., origin=False):
+    self.alpha = alpha
+    self.origin = origin
+
+  def __call__(self, x: Union[jax.Array, Array]):
+    return erf(x, alpha=self.alpha, origin=self.origin)
+
+
 @vjp_custom(['x'], dict(alpha=1., origin=False), statics={'origin': [True, False]})
 def erf(
   x: Union[jax.Array, Array],
@@ -569,12 +631,22 @@ def erf(
     z = jnp.asarray(x >= 0, dtype=x.dtype)
 
   def grad(dz):
-    dx = (alpha / math.sqrt(math.pi)) * jnp.exp(-math.pow(alpha, 2) * x * x)
+    dx = (alpha / jnp.sqrt(jnp.pi)) * jnp.exp(-jnp.power(alpha, 2) * x * x)
     return dx * as_jax(dz), None
 
   return z, grad
 
 
+class PiecewiseLeakyRelu:
+  def __init__(self, c=0.01, w=1., origin=False):
+    self.c = c
+    self.w = w
+    self.origin = origin
+
+  def __call__(self, x: Union[jax.Array, Array]):
+    return piecewise_leaky_relu(x, c=self.c, w=self.w, origin=self.origin)
+
+
 @vjp_custom(['x'], dict(c=0.01, w=1., origin=False), statics={'origin': [True, False]})
 def piecewise_leaky_relu(
   x: Union[jax.Array, Array],
@@ -673,6 +745,16 @@ def grad(dz):
   return z, grad
 
 
+class SquarewaveFourierSeries:
+  def __init__(self, n=2, t_period=8., origin=False):
+    self.n = n
+    self.t_period = t_period
+    self.origin = origin
+
+  def __call__(self, x: Union[jax.Array, Array]):
+    return squarewave_fourier_series(x, self.n, self.t_period, self.origin)
+
+
 @vjp_custom(['x'], dict(n=2, t_period=8., origin=False), statics={'origin': [True, False]})
 def squarewave_fourier_series(
   x: Union[jax.Array, Array],
@@ -732,13 +814,13 @@ def squarewave_fourier_series(
     The spiking state.
   """
-  w = math.pi * 2. / t_period
+  w = jnp.pi * 2. / t_period
   if origin:
     ret = jnp.sin(w * x)
     for i in range(2, n):
       c = (2 * i - 1.)
       ret += jnp.sin(c * w * x) / c
-    z = 0.5 + 2. / math.pi * ret
+    z = 0.5 + 2. / jnp.pi * ret
   else:
     z = jnp.asarray(x >= 0, dtype=x.dtype)
@@ -752,6 +834,17 @@ def grad(dz):
   return z, grad
 
 
+class S2NN:
+  def __init__(self, alpha=4., beta=1., epsilon=1e-8, origin=False):
+    self.alpha = alpha
+    self.beta = beta
+    self.epsilon = epsilon
+    self.origin = origin
+
+  def __call__(self, x: Union[jax.Array, Array]):
+    return s2nn(x, self.alpha, self.beta, self.epsilon, self.origin)
+
+
 @vjp_custom(['x'],
             defaults=dict(alpha=4., beta=1., epsilon=1e-8, origin=False),
             statics={'origin': [True, False]})
@@ -844,6 +937,15 @@ def grad(dz):
   return z, grad
 
 
+class QPseudoSpike:
+  def __init__(self, alpha=2., origin=False):
+    self.alpha = alpha
+    self.origin = origin
+
+  def __call__(self, x: Union[jax.Array, Array]):
+    return q_pseudo_spike(x, self.alpha, self.origin)
+
+
 @vjp_custom(['x'],
             dict(alpha=2., origin=False),
             statics={'origin': [True, False]})
@@ -925,6 +1027,16 @@ def grad(dz):
   return z, grad
 
 
+class LeakyRelu:
+  def __init__(self, alpha=0.1, beta=1., origin=False):
+    self.alpha = alpha
+    self.beta = beta
+    self.origin = origin
+
+  def __call__(self, x: Union[jax.Array, Array]):
+    return leaky_relu(x, self.alpha, self.beta, self.origin)
+
+
 @vjp_custom(['x'],
             dict(alpha=0.1, beta=1., origin=False),
             statics={'origin': [True, False]})
@@ -1006,6 +1118,15 @@ def grad(dz):
   return z, grad
 
 
+class LogTailedRelu:
+  def __init__(self, alpha=0., origin=False):
+    self.alpha = alpha
+    self.origin = origin
+
+  def __call__(self, x: Union[jax.Array, Array]):
+    return log_tailed_relu(x, self.alpha, self.origin)
+
+
 @vjp_custom(['x'],
             dict(alpha=0., origin=False),
             statics={'origin': [True, False]})
@@ -1098,6 +1219,15 @@ def grad(dz):
   return z, grad
 
 
+class ReluGrad:
+  def __init__(self, alpha=0.3, width=1.):
+    self.alpha = alpha
+    self.width = width
+
+  def __call__(self, x: Union[jax.Array, Array]):
+    return relu_grad(x, self.alpha, self.width)
+
+
 @vjp_custom(['x'], dict(alpha=0.3, width=1.))
 def relu_grad(
   x: Union[jax.Array, Array],
@@ -1163,6 +1293,15 @@ def grad(dz):
   return z, grad
 
 
+class GaussianGrad:
+  def __init__(self, sigma=0.5, alpha=0.5):
+    self.sigma = sigma
+    self.alpha = alpha
+
+  def __call__(self, x: Union[jax.Array, Array]):
+    return gaussian_grad(x, self.sigma, self.alpha)
+
+
 @vjp_custom(['x'], dict(sigma=0.5, alpha=0.5))
 def gaussian_grad(
   x: Union[jax.Array, Array],
@@ -1221,12 +1360,23 @@ def gaussian_grad(
   z = jnp.asarray(x >= 0, dtype=x.dtype)
 
   def grad(dz):
-    dx = jnp.exp(-(x ** 2) / 2 * math.pow(sigma, 2)) / (math.sqrt(2 * math.pi) * sigma)
+    dx = jnp.exp(-(x ** 2) / 2 * jnp.power(sigma, 2)) / (jnp.sqrt(2 * jnp.pi) * sigma)
     return alpha * dx * as_jax(dz), None, None
 
   return z, grad
 
 
+class MultiGaussianGrad:
+  def __init__(self, h=0.15, s=6.0, sigma=0.5, scale=0.5):
+    self.h = h
+    self.s = s
+    self.sigma = sigma
+    self.scale = scale
+
+  def __call__(self, x: Union[jax.Array, Array]):
+    return multi_gaussian_grad(x, self.h, self.s, self.sigma, self.scale)
+
+
 @vjp_custom(['x'], dict(h=0.15, s=6.0, sigma=0.5, scale=0.5))
 def multi_gaussian_grad(
   x: Union[jax.Array, Array],
@@ -1294,15 +1444,23 @@ def multi_gaussian_grad(
   z = jnp.asarray(x >= 0, dtype=x.dtype)
 
   def grad(dz):
-    g1 = jnp.exp(-x ** 2 / (2 * math.pow(sigma, 2))) / (math.sqrt(2 * math.pi) * sigma)
-    g2 = jnp.exp(-(x - sigma) ** 2 / (2 * math.pow(s * sigma, 2))) / (math.sqrt(2 * math.pi) * s * sigma)
-    g3 = jnp.exp(-(x + sigma) ** 2 / (2 * math.pow(s * sigma, 2))) / (math.sqrt(2 * math.pi) * s * sigma)
+    g1 = jnp.exp(-x ** 2 / (2 * jnp.power(sigma, 2))) / (jnp.sqrt(2 * jnp.pi) * sigma)
+    g2 = jnp.exp(-(x - sigma) ** 2 / (2 * jnp.power(s * sigma, 2))) / (jnp.sqrt(2 * jnp.pi) * s * sigma)
+    g3 = jnp.exp(-(x + sigma) ** 2 / (2 * jnp.power(s * sigma, 2))) / (jnp.sqrt(2 * jnp.pi) * s * sigma)
     dx = g1 * (1. + h) - g2 * h - g3 * h
     return scale * dx * as_jax(dz), None, None, None, None
 
   return z, grad
 
 
+class InvSquareGrad:
+  def __init__(self, alpha=100.):
+    self.alpha = alpha
+
+  def __call__(self, x: Union[jax.Array, Array]):
+    return inv_square_grad(x, self.alpha)
+
+
 @vjp_custom(['x'], dict(alpha=100.))
 def inv_square_grad(
   x: Union[jax.Array, Array],
@@ -1360,6 +1518,14 @@ def grad(dz):
   return z, grad
 
 
+class SlayerGrad:
+  def __init__(self, alpha=1.):
+    self.alpha = alpha
+
+  def __call__(self, x: Union[jax.Array, Array]):
+    return slayer_grad(x, self.alpha)
+
+
 @vjp_custom(['x'], dict(alpha=1.))
 def slayer_grad(
   x: Union[jax.Array, Array],
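Each functional surrogate above gains a class-based counterpart that stores its hyper-parameters and defers to the function in `__call__`, so a configured surrogate can be passed around as a single callable. A sketch of the intended usage, assuming the `brainpy.math.surrogate` re-exports added in the hunk below:

import brainpy.math as bm

fn = bm.surrogate.Sigmoid(alpha=4.)  # stores alpha; calls the functional form when invoked
x = bm.random.randn(8)
spikes = fn(x)                       # same result as bm.surrogate.sigmoid(x, alpha=4.)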
diff --git a/brainpy/math/environment.py b/brainpy/math/environment.py
index 3c0730d72..22cada96f 100644
--- a/brainpy/math/environment.py
+++ b/brainpy/math/environment.py
@@ -16,6 +16,7 @@
   environment as environment,
   batching_environment as batching_environment,
   training_environment as training_environment,
+  set as set,
   set_environment as set_environment,
   enable_x64 as enable_x64,
   disable_x64 as disable_x64,
diff --git a/brainpy/math/surrogate.py b/brainpy/math/surrogate.py
index 9892f3f90..88c9a368c 100644
--- a/brainpy/math/surrogate.py
+++ b/brainpy/math/surrogate.py
@@ -5,23 +5,58 @@
 #   vjp_custom as vjp_custom
 # )
 from brainpy._src.math.surrogate.one_input import (
+  Sigmoid,
   sigmoid as sigmoid,
+
+  PiecewiseQuadratic,
   piecewise_quadratic as piecewise_quadratic,
+
+  PiecewiseExp,
   piecewise_exp as piecewise_exp,
+
+  SoftSign,
   soft_sign as soft_sign,
+
+  Arctan,
   arctan as arctan,
+
+  NonzeroSignLog,
   nonzero_sign_log as nonzero_sign_log,
+
+  ERF,
   erf as erf,
+
+  PiecewiseLeakyRelu,
   piecewise_leaky_relu as piecewise_leaky_relu,
+
+  SquarewaveFourierSeries,
   squarewave_fourier_series as squarewave_fourier_series,
+
+  S2NN,
   s2nn as s2nn,
+
+  QPseudoSpike,
   q_pseudo_spike as q_pseudo_spike,
+
+  LeakyRelu,
   leaky_relu as leaky_relu,
+
+  LogTailedRelu,
   log_tailed_relu as log_tailed_relu,
+
+  ReluGrad,
   relu_grad as relu_grad,
+
+  GaussianGrad,
   gaussian_grad as gaussian_grad,
+
+  InvSquareGrad,
   inv_square_grad as inv_square_grad,
+
+  MultiGaussianGrad,
   multi_gaussian_grad as multi_gaussian_grad,
+
+  SlayerGrad,
   slayer_grad as slayer_grad,
 )
 from brainpy._src.math.surrogate.two_inputs import (