From ddc2c17f12820ea4145c2adf35f84bb80af3d430 Mon Sep 17 00:00:00 2001 From: chaoming Date: Thu, 5 Jan 2023 22:39:05 +0800 Subject: [PATCH 01/13] updates --- brainpy/__init__.py | 46 +++++++++++++++++++++++++---- brainpy/_src/dyn/synapses/compat.py | 16 +++++----- brainpy/integrators/__init__.py | 1 - 3 files changed, 49 insertions(+), 14 deletions(-) diff --git a/brainpy/__init__.py b/brainpy/__init__.py index 7387552ed..f25604e70 100644 --- a/brainpy/__init__.py +++ b/brainpy/__init__.py @@ -40,6 +40,7 @@ # numerical integrators from brainpy import integrators from brainpy.integrators import ode, sde, fde +from brainpy._src.integrators.base import (Integrator as Integrator) from brainpy._src.integrators.joint_eq import (JointEq as JointEq) from brainpy._src.integrators.runner import (IntegratorRunner as IntegratorRunner) from brainpy._src.integrators.ode.generic import (odeint as odeint) @@ -107,6 +108,17 @@ # ---------------------- # +integrators.__dict__['Integrator'] = Integrator +integrators.__dict__['odeint'] = odeint +integrators.__dict__['sdeint'] = sdeint +integrators.__dict__['fdeint'] = fdeint +integrators.__dict__['IntegratorRunner'] = IntegratorRunner +integrators.__dict__['JointEq'] = JointEq +ode.__dict__['odeint'] = odeint +sde.__dict__['sdeint'] = sdeint +fde.__dict__['fdeint'] = fdeint + + # deprecated from brainpy._src.math.object_transform.base import (Base as Base, ArrayCollector as ArrayCollector, @@ -194,11 +206,35 @@ dyn.__dict__['NoSharedArg'] = NoSharedArg dyn.__dict__['LoopOverTime'] = LoopOverTime dyn.__dict__['DSRunner'] = DSRunner -integrators.__dict__['odeint'] = odeint -integrators.__dict__['sdeint'] = sdeint -integrators.__dict__['fdeint'] = fdeint -integrators.__dict__['IntegratorRunner'] = IntegratorRunner -integrators.__dict__['JointEq'] = JointEq + +dyn.__dict__['HH'] = neurons.HH +dyn.__dict__['MorrisLecar'] = neurons.MorrisLecar +dyn.__dict__['PinskyRinzelModel'] = neurons.PinskyRinzelModel +dyn.__dict__['FractionalFHR'] = neurons.FractionalFHR +dyn.__dict__['FractionalIzhikevich'] = neurons.FractionalIzhikevich +dyn.__dict__['LIF'] = neurons.LIF +dyn.__dict__['ExpIF'] = neurons.ExpIF +dyn.__dict__['AdExIF'] = neurons.AdExIF +dyn.__dict__['QuaIF'] = neurons.QuaIF +dyn.__dict__['AdQuaIF'] = neurons.AdQuaIF +dyn.__dict__['GIF'] = neurons.GIF +dyn.__dict__['Izhikevich'] = neurons.Izhikevich +dyn.__dict__['HindmarshRose'] = neurons.HindmarshRose +dyn.__dict__['FHN'] = neurons.FHN +dyn.__dict__['SpikeTimeGroup'] = neurons.SpikeTimeGroup +dyn.__dict__['PoissonGroup'] = neurons.PoissonGroup +dyn.__dict__['OUProcess'] = neurons.OUProcess + +from brainpy._src.dyn.synapses import compat +dyn.__dict__['DeltaSynapse'] = compat.DeltaSynapse +dyn.__dict__['ExpCUBA'] = compat.ExpCUBA +dyn.__dict__['ExpCOBA'] = compat.ExpCOBA +dyn.__dict__['DualExpCUBA'] = compat.DualExpCUBA +dyn.__dict__['DualExpCOBA'] = compat.DualExpCOBA +dyn.__dict__['AlphaCUBA'] = compat.AlphaCUBA +dyn.__dict__['AlphaCOBA'] = compat.AlphaCOBA +dyn.__dict__['NMDA'] = compat.NMDA +del compat import brainpy._src.math.arraycompatible as bm diff --git a/brainpy/_src/dyn/synapses/compat.py b/brainpy/_src/dyn/synapses/compat.py index d0c7b10de..1fb343ad5 100644 --- a/brainpy/_src/dyn/synapses/compat.py +++ b/brainpy/_src/dyn/synapses/compat.py @@ -26,7 +26,7 @@ class DeltaSynapse(Delta): """Delta synapse. .. deprecated:: 2.1.13 - Please use "brainpy.dyn.synapses.Delta" instead. + Please use "brainpy.synapses.Delta" instead. 
""" @@ -42,7 +42,7 @@ def __init__( post_has_ref: bool = False, name: str = None, ): - warnings.warn('Please use "brainpy.dyn.synapses.Delta" instead.', DeprecationWarning) + warnings.warn('Please use "brainpy.synapses.Delta" instead.', DeprecationWarning) super(DeltaSynapse, self).__init__(pre=pre, post=post, conn=conn, @@ -58,7 +58,7 @@ class ExpCUBA(Exponential): r"""Current-based exponential decay synapse model. .. deprecated:: 2.1.13 - Please use "brainpy.dyn.synapses.Exponential" instead. + Please use "brainpy.synapses.Exponential" instead. """ @@ -90,7 +90,7 @@ class ExpCOBA(Exponential): """Conductance-based exponential decay synapse model. .. deprecated:: 2.1.13 - Please use "brainpy.dyn.synapses.Exponential" instead. + Please use "brainpy.synapses.Exponential" instead. """ def __init__( @@ -127,7 +127,7 @@ class DualExpCUBA(DualExponential): r"""Current-based dual exponential synapse model. .. deprecated:: 2.1.13 - Please use "brainpy.dyn.synapses.DualExponential" instead. + Please use "brainpy.synapses.DualExponential" instead. """ @@ -162,7 +162,7 @@ class DualExpCOBA(DualExponential): .. deprecated:: 2.1.13 - Please use "brainpy.dyn.synapses.DualExponential" instead. + Please use "brainpy.synapses.DualExponential" instead. """ @@ -197,7 +197,7 @@ class AlphaCUBA(DualExpCUBA): r"""Current-based alpha synapse model. .. deprecated:: 2.1.13 - Please use "brainpy.dyn.synapses.Alpha" instead. + Please use "brainpy.synapses.Alpha" instead. """ @@ -229,7 +229,7 @@ class AlphaCOBA(DualExpCOBA): """Conductance-based alpha synapse model. .. deprecated:: 2.1.13 - Please use "brainpy.dyn.synapses.Alpha" instead. + Please use "brainpy.synapses.Alpha" instead. """ diff --git a/brainpy/integrators/__init__.py b/brainpy/integrators/__init__.py index 36ba2dbb2..176a71aec 100644 --- a/brainpy/integrators/__init__.py +++ b/brainpy/integrators/__init__.py @@ -3,5 +3,4 @@ from . import ode from . import sde from . import fde -from brainpy._src.integrators.base import (Integrator as Integrator) from brainpy._src.integrators.constants import * From f462dcf70b3c9ac5f31a2fb581ee5789f0b2befe Mon Sep 17 00:00:00 2001 From: chaoming Date: Sat, 14 Jan 2023 21:05:49 +0800 Subject: [PATCH 02/13] fix `brainpy.math.flatten()` --- brainpy/_src/math/arrayoperation.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/brainpy/_src/math/arrayoperation.py b/brainpy/_src/math/arrayoperation.py index 91d605d31..12f1b5d8f 100644 --- a/brainpy/_src/math/arrayoperation.py +++ b/brainpy/_src/math/arrayoperation.py @@ -50,10 +50,13 @@ def flatten(input: Union[jax.Array, Array], ndim = 1 if start_dim is None: start_dim = 0 + elif start_dim < 0: + start_dim = ndim + start_dim if end_dim is None: - end_dim = ndim - else: - end_dim += 1 + end_dim = ndim - 1 + elif end_dim < 0: + end_dim = ndim + end_dim + end_dim += 1 if start_dim < 0 or start_dim > ndim: raise ValueError(f'start_dim {start_dim} is out of size.') if end_dim < 0 or end_dim > ndim: @@ -62,15 +65,18 @@ def flatten(input: Union[jax.Array, Array], return jnp.reshape(input, new_shape) -def fill_diagonal(a, val): +def fill_diagonal(a, val, inplace=True): if a.ndim < 2: raise ValueError(f'Only support tensor has dimension >= 2, but got {a.shape}') + if not isinstance(a, Array) and inplace: + raise ValueError('``fill_diagonal()`` is used in in-place updating, therefore ' + 'it requires a brainpy Array. 
If you want to disable ' + 'inplace updating, use ``fill_diagonal(inplace=False)``.') val = val.value if isinstance(val, Array) else val i, j = jnp.diag_indices(min(a.shape[-2:])) r = as_jax(a).at[..., i, j].set(val) - if isinstance(a, Array): + if inplace: a.value = r - return a else: return r From 02a0761c834c396b9ce17226e5f19ac93ac820d8 Mon Sep 17 00:00:00 2001 From: chaoming Date: Sat, 14 Jan 2023 21:06:50 +0800 Subject: [PATCH 03/13] generalize online/offline algorithms so that they can be used in other cases --- brainpy/_src/dyn/base.py | 14 ++-- brainpy/_src/train/offline.py | 26 ++++---- brainpy/algorithms/offline.py | 119 +++++++++++++++++----------------- brainpy/algorithms/online.py | 64 ++++++++++++------ 4 files changed, 122 insertions(+), 101 deletions(-) diff --git a/brainpy/_src/dyn/base.py b/brainpy/_src/dyn/base.py index 5d862776e..cd79a38f3 100644 --- a/brainpy/_src/dyn/base.py +++ b/brainpy/_src/dyn/base.py @@ -344,10 +344,6 @@ def __del__(self): def online_init(self): raise NoImplementationError('Subclass must implement online_init() function when using OnlineTrainer.') - @tools.not_customized - def offline_init(self): - raise NoImplementationError('Subclass must implement offline_init() function when using OfflineTrainer.') - @tools.not_customized def online_fit(self, target: ArrayType, @@ -403,8 +399,8 @@ def update(self, *args, **kwargs): def clear_input(self): """Function for clearing input in the wrapped children dynamical system.""" - if isinstance(self.target, DynamicalSystem): - self.target.clear_input() + for child in self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique().values(): + child.clear_input() def __repr__(self): name = self.__class__.__name__ @@ -1252,8 +1248,8 @@ def __init__( self.C = C self.A = A self.V_th = V_th - self._V_initializer = V_initializer self.noise = init_noise(noise, self.varshape, num_vars=3) + self._V_initializer = V_initializer # variables self.V = variable(V_initializer, self.mode, self.varshape) @@ -1277,6 +1273,8 @@ def reset_state(self, batch_size=None): self.V.value = variable(self._V_initializer, batch_size, self.varshape) self.spike.value = variable(lambda s: jnp.zeros(s, dtype=bool), batch_size, self.varshape) self.input.value = variable(jnp.zeros, batch_size, self.varshape) + for channel in self.nodes(level=1, include_self=False).subset(Channel).unique().values(): + channel.reset_state(self.V.value, batch_size=batch_size) def update(self, tdi, *args, **kwargs): V = self.integral(self.V.value, tdi['t'], tdi['dt']) @@ -1330,7 +1328,7 @@ def update(self, tdi, V): def current(self, V): raise NotImplementedError('Must be implemented by the subclass.') - def reset_state(self, batch_size=None): + def reset_state(self, V, batch_size=None): raise NotImplementedError('Must be implemented by the subclass.') diff --git a/brainpy/_src/train/offline.py b/brainpy/_src/train/offline.py index 62f4f5d0e..a640407a0 100644 --- a/brainpy/_src/train/offline.py +++ b/brainpy/_src/train/offline.py @@ -254,19 +254,19 @@ def _check_interface(self): interface "offline_fit()" function. ''' ) - if hasattr(node.offline_init, 'not_customized'): - if node.offline_init.not_customized: - raise NoImplementationError( - f''' - The node - - {node} - - is set to be computing mode of {bm.training_mode} with {self.__class__.__name__}. - However, it does not implement the required training - interface "offline_init()" function. 
- ''' - ) + # if hasattr(node.offline_init, 'not_customized'): + # if node.offline_init.not_customized: + # raise NoImplementationError( + # f''' + # The node + # + # {node} + # + # is set to be computing mode of {bm.training_mode} with {self.__class__.__name__}. + # However, it does not implement the required training + # interface "offline_init()" function. + # ''' + # ) class RidgeTrainer(OfflineTrainer): diff --git a/brainpy/algorithms/offline.py b/brainpy/algorithms/offline.py index 7081c23ec..e12b8ce13 100644 --- a/brainpy/algorithms/offline.py +++ b/brainpy/algorithms/offline.py @@ -3,22 +3,27 @@ import warnings import numpy as np +import jax.numpy as jnp from jax.lax import while_loop import brainpy.math as bm from brainpy._src.math.object_transform.base import BrainPyObject from brainpy.types import ArrayType from .utils import (Sigmoid, - Regularization, L1Regularization, L1L2Regularization, L2Regularization, - polynomial_features, normalize) + Regularization, + L1Regularization, + L1L2Regularization, + L2Regularization, + polynomial_features, + normalize) __all__ = [ # brainpy_object class for offline training algorithm 'OfflineAlgorithm', # training methods - 'LinearRegression', - 'RidgeRegression', + 'LinearRegression', 'linear_regression', + 'RidgeRegression', 'ridge_regression', 'LassoRegression', 'LogisticRegression', 'PolynomialRegression', @@ -39,18 +44,16 @@ class OfflineAlgorithm(BrainPyObject): def __init__(self, name=None): super(OfflineAlgorithm, self).__init__(name=name) - def __call__(self, identifier, target, input, output): + def __call__(self, targets, inputs, outputs=None): """The training procedure. Parameters ---------- - identifier: str - The variable name. - target: ArrayType + targets: ArrayType The 2d target data with the shape of `(num_batch, num_output)`. - input: ArrayType + inputs: ArrayType The 2d input data with the shape of `(num_batch, num_input)`. - output: ArrayType + outputs: ArrayType The 2d output data with the shape of `(num_batch, num_output)`. Returns @@ -58,16 +61,13 @@ def __call__(self, identifier, target, input, output): weight: ArrayType The weights after fit. """ - return self.call(identifier, target, input, output) + return self.call(targets, inputs, outputs) - def call(self, identifier, targets, inputs, outputs) -> ArrayType: + def call(self, targets, inputs, outputs=None) -> ArrayType: """The training procedure. Parameters ---------- - identifier: str - The identifier. - inputs: ArrayType The 3d input data with the shape of `(num_batch, num_time, num_input)`, or, the 2d input data with the shape of `(num_time, num_input)`. @@ -90,15 +90,12 @@ def call(self, identifier, targets, inputs, outputs) -> ArrayType: def __repr__(self): return self.__class__.__name__ - def initialize(self, identifier, *args, **kwargs): - pass - def _check_data_2d_atls(x): if x.ndim < 2: raise ValueError(f'Data must be a 2d tensor. 
But we got {x.ndim}d: {x.shape}.') if x.ndim != 2: - return x.reshape((-1, x.shape[-1])) + return bm.flatten(x, end_dim=-2) else: return x @@ -127,7 +124,7 @@ def __init__( self.learning_rate = learning_rate self.regularizer = regularizer - def initialize(self, identifier, *args, **kwargs): + def initialize(self, *args, **kwargs): pass def init_weights(self, n_features, n_out): @@ -137,22 +134,22 @@ def init_weights(self, n_features, n_out): def gradient_descent_solve(self, targets, inputs, outputs=None): # checking - inputs = _check_data_2d_atls(bm.asarray(inputs)) - targets = _check_data_2d_atls(bm.asarray(targets)) + inputs = _check_data_2d_atls(bm.as_jax(inputs)) + targets = _check_data_2d_atls(bm.as_jax(targets)) # initialize weights w = self.init_weights(inputs.shape[1], targets.shape[1]) def cond_fun(a): i, par_old, par_new = a - return bm.logical_and(bm.logical_not(bm.allclose(par_old, par_new)), - i < self.max_iter).value + return jnp.logical_and(jnp.logical_not(jnp.allclose(par_old, par_new)), + i < self.max_iter).value def body_fun(a): i, _, par_new = a # Gradient of regularization loss w.r.t w y_pred = inputs.dot(par_new) - grad_w = bm.dot(inputs.T, -(targets - y_pred)) + self.regularizer.grad(par_new) + grad_w = jnp.dot(inputs.T, -(targets - y_pred)) + self.regularizer.grad(par_new) # Update the weights par_new2 = par_new - self.learning_rate * grad_w return i + 1, par_new, par_new2 @@ -162,7 +159,7 @@ def body_fun(a): return r[-1] def predict(self, W, X): - return bm.dot(X, W) + return jnp.dot(X, W) class LinearRegression(RegressionAlgorithm): @@ -189,19 +186,21 @@ def __init__( regularizer=Regularization(0.)) self.gradient_descent = gradient_descent - def call(self, identifier, targets, inputs, outputs=None): + def call(self, targets, inputs, outputs=None): # checking - inputs = _check_data_2d_atls(bm.asarray(inputs)) - targets = _check_data_2d_atls(bm.asarray(targets)) + inputs = _check_data_2d_atls(bm.as_jax(inputs)) + targets = _check_data_2d_atls(bm.as_jax(targets)) # solving if self.gradient_descent: return self.gradient_descent_solve(targets, inputs) else: - weights = bm.linalg.lstsq(inputs, targets) + weights = jnp.linalg.lstsq(inputs, targets) return weights[0] +linear_regression = LinearRegression() + name2func['linear'] = LinearRegression name2func['lstsq'] = LinearRegression @@ -248,10 +247,10 @@ def __init__( regularizer=L2Regularization(alpha=alpha)) self.gradient_descent = gradient_descent - def call(self, identifier, targets, inputs, outputs=None): + def call(self, targets, inputs, outputs=None): # checking - inputs = _check_data_2d_atls(bm.asarray(inputs)) - targets = _check_data_2d_atls(bm.asarray(targets)) + inputs = _check_data_2d_atls(bm.as_jax(inputs)) + targets = _check_data_2d_atls(bm.as_jax(targets)) # solving if self.gradient_descent: @@ -259,14 +258,16 @@ def call(self, identifier, targets, inputs, outputs=None): else: temp = inputs.T @ inputs if self.regularizer.alpha > 0.: - temp += self.regularizer.alpha * bm.eye(inputs.shape[-1]) - weights = bm.linalg.pinv(temp) @ (inputs.T @ targets) + temp += self.regularizer.alpha * jnp.eye(inputs.shape[-1]) + weights = jnp.linalg.pinv(temp) @ (inputs.T @ targets) return weights def __repr__(self): return f'{self.__class__.__name__}(beta={self.regularizer.alpha})' +ridge_regression = RidgeRegression() + name2func['ridge'] = RidgeRegression @@ -307,17 +308,17 @@ def __init__( assert gradient_descent self.degree = degree - def call(self, identifier, targets, inputs, outputs=None): + def call(self, targets, 
inputs, outputs=None): # checking - inputs = _check_data_2d_atls(bm.asarray(inputs)) - targets = _check_data_2d_atls(bm.asarray(targets)) + inputs = _check_data_2d_atls(bm.as_jax(inputs)) + targets = _check_data_2d_atls(bm.as_jax(targets)) # solving inputs = normalize(polynomial_features(inputs, degree=self.degree, add_bias=self.add_bias)) return super(LassoRegression, self).gradient_descent_solve(targets, inputs) def predict(self, W, X): - X = _check_data_2d_atls(bm.asarray(X)) + X = _check_data_2d_atls(bm.as_jax(X)) X = normalize(polynomial_features(X, degree=self.degree, add_bias=self.add_bias)) return super(LassoRegression, self).predict(W, X) @@ -355,10 +356,10 @@ def __init__( self.gradient_descent = gradient_descent self.sigmoid = Sigmoid() - def call(self, identifier, targets, inputs, outputs=None) -> ArrayType: + def call(self, targets, inputs, outputs=None) -> ArrayType: # prepare data - inputs = _check_data_2d_atls(bm.asarray(inputs)) - targets = _check_data_2d_atls(bm.asarray(targets)) + inputs = _check_data_2d_atls(bm.as_jax(inputs)) + targets = _check_data_2d_atls(bm.as_jax(targets)) if targets.shape[-1] != 1: raise ValueError(f'Target must be a scalar, but got multiple variables: {targets.shape}. ') targets = targets.flatten() @@ -368,7 +369,7 @@ def call(self, identifier, targets, inputs, outputs=None) -> ArrayType: def cond_fun(a): i, par_old, par_new = a - return bm.logical_and(bm.logical_not(bm.allclose(par_old, par_new)), + return jnp.logical_and(jnp.logical_not(jnp.allclose(par_old, par_new)), i < self.max_iter).value def body_fun(a): @@ -384,12 +385,12 @@ def body_fun(a): diag_grad = bm.zeros((gradient.size, gradient.size)) diag = bm.arange(gradient.size) diag_grad[diag, diag] = gradient - par_new2 = bm.linalg.pinv(inputs.T.dot(diag_grad).dot(inputs)).dot(inputs.T).dot( + par_new2 = jnp.linalg.pinv(inputs.T.dot(diag_grad).dot(inputs)).dot(inputs.T).dot( diag_grad.dot(inputs).dot(par_new) + targets - y_pred) return i + 1, par_new, par_new2 # Tune parameters for n iterations - r = while_loop(cond_fun, body_fun, (0, param+1., param)) + r = while_loop(cond_fun, body_fun, (0, param + 1., param)) return r[-1] def predict(self, W, X): @@ -418,14 +419,14 @@ def __init__( self.degree = degree self.add_bias = add_bias - def call(self, identifier, targets, inputs, outputs=None): - inputs = _check_data_2d_atls(bm.asarray(inputs)) - targets = _check_data_2d_atls(bm.asarray(targets)) + def call(self, targets, inputs, outputs=None): + inputs = _check_data_2d_atls(bm.as_jax(inputs)) + targets = _check_data_2d_atls(bm.as_jax(targets)) inputs = polynomial_features(inputs, degree=self.degree, add_bias=self.add_bias) - return super(PolynomialRegression, self).call(identifier, targets, inputs) + return super(PolynomialRegression, self).call(targets, inputs) def predict(self, W, X): - X = _check_data_2d_atls(bm.asarray(X)) + X = _check_data_2d_atls(bm.as_jax(X)) X = polynomial_features(X, degree=self.degree, add_bias=self.add_bias) return super(PolynomialRegression, self).predict(W, X) @@ -454,15 +455,15 @@ def __init__( self.degree = degree self.add_bias = add_bias - def call(self, identifier, targets, inputs, outputs=None): + def call(self, targets, inputs, outputs=None): # checking - inputs = _check_data_2d_atls(bm.asarray(inputs)) - targets = _check_data_2d_atls(bm.asarray(targets)) + inputs = _check_data_2d_atls(bm.as_jax(inputs)) + targets = _check_data_2d_atls(bm.as_jax(targets)) inputs = polynomial_features(inputs, degree=self.degree, add_bias=self.add_bias) - return 
super(PolynomialRidgeRegression, self).call(identifier, targets, inputs) + return super(PolynomialRidgeRegression, self).call(targets, inputs) def predict(self, W, X): - X = _check_data_2d_atls(bm.asarray(X)) + X = _check_data_2d_atls(bm.as_jax(X)) X = polynomial_features(X, degree=self.degree, add_bias=self.add_bias) return super(PolynomialRidgeRegression, self).predict(W, X) @@ -512,16 +513,16 @@ def __init__( self.gradient_descent = gradient_descent assert gradient_descent - def call(self, identifier, targets, inputs, outputs=None): + def call(self, targets, inputs, outputs=None): # checking - inputs = _check_data_2d_atls(bm.asarray(inputs)) - targets = _check_data_2d_atls(bm.asarray(targets)) + inputs = _check_data_2d_atls(bm.as_jax(inputs)) + targets = _check_data_2d_atls(bm.as_jax(targets)) # solving inputs = normalize(polynomial_features(inputs, degree=self.degree)) return super(ElasticNetRegression, self).gradient_descent_solve(targets, inputs) def predict(self, W, X): - X = _check_data_2d_atls(bm.asarray(X)) + X = _check_data_2d_atls(bm.as_jax(X)) X = normalize(polynomial_features(X, degree=self.degree, add_bias=self.add_bias)) return super(ElasticNetRegression, self).predict(W, X) diff --git a/brainpy/algorithms/online.py b/brainpy/algorithms/online.py index a0a5b726a..f138ddbfb 100644 --- a/brainpy/algorithms/online.py +++ b/brainpy/algorithms/online.py @@ -1,11 +1,11 @@ # -*- coding: utf-8 -*- +import jax +import jax.numpy as jnp +from jax import vmap import brainpy.math as bm -from jax import vmap -import jax.numpy as jnp from brainpy._src.math.object_transform.base import BrainPyObject - __all__ = [ # brainpy_object class 'OnlineAlgorithm', @@ -28,7 +28,7 @@ class OnlineAlgorithm(BrainPyObject): def __init__(self, name=None): super(OnlineAlgorithm, self).__init__(name=name) - def __call__(self, identifier, target, input, output): + def __call__(self, *args, **kwargs): """The training procedure. Parameters @@ -47,12 +47,12 @@ def __call__(self, identifier, target, input, output): weight: ArrayType The weights after fit. """ - return self.call(identifier, target, input, output) + return self.call(*args, **kwargs) - def initialize(self, identifier, *args, **kwargs): + def register_target(self, *args, **kwargs): pass - def call(self, identifier, target, input, output): + def call(self, target, input, output, identifier: str=''): """The training procedure. 
Parameters @@ -105,21 +105,38 @@ def __init__(self, alpha=0.1, name=None): super(RLS, self).__init__(name=name) self.alpha = alpha - def initialize(self, identifier, feature_in, feature_out=None): + def register_target( + self, + feature_in: int, + identifier: str = '', + ): identifier = identifier + self.postfix - self.implicit_vars[identifier] = bm.Variable(bm.eye(feature_in) * self.alpha) - - def call(self, identifier, target, input, output): + self.implicit_vars[identifier] = bm.Variable(jnp.eye(feature_in) * self.alpha) + + def call( + self, + target: jax.Array, + input: jax.Array, + output: jax.Array, + identifier: str = '', + ): identifier = identifier + self.postfix P = self.implicit_vars[identifier] - # update the inverse correlation matrix - k = bm.dot(P, input.T) # (num_input, num_batch) - hPh = bm.dot(input, k) # (num_batch, num_batch) - c = bm.sum(1.0 / (1.0 + hPh)) # () - P -= c * bm.dot(k, k.T) # (num_input, num_input) - # update weights + input = bm.as_jax(input) + output = bm.as_jax(output) + target = bm.as_jax(target) + if input.ndim == 1: input = jnp.expand_dims(input, 0) + if target.ndim == 1: target = jnp.expand_dims(target, 0) + if output.ndim == 1: output = jnp.expand_dims(output, 0) + assert input.ndim == 2, f'should be a 2D array with shape of (batch, feature). Got {input.shape}' + assert target.ndim == 2, f'should be a 2D array with shape of (batch, feature). Got {target.shape}' + assert output.ndim == 2, f'should be a 2D array with shape of (batch, feature). Got {output.shape}' + k = jnp.dot(P.value, input.T) # (num_input, num_batch) + hPh = jnp.dot(input, k) # (num_batch, num_batch) + c = jnp.sum(1.0 / (1.0 + hPh)) # () + P -= c * jnp.dot(k, k.T) # (num_input, num_input) e = output - target # (num_batch, num_output) - dw = -c * bm.dot(k, e) # (num_input, num_output) + dw = -c * jnp.dot(k, e) # (num_input, num_output) return dw @@ -148,11 +165,16 @@ def __init__(self, alpha=0.1, name=None): super(LMS, self).__init__(name=name) self.alpha = alpha - def call(self, identifier, target, input, output): - assert target.shape[0] == input.shape[0] == output.shape[0], 'Batch size should be consistent.' + def call(self, target, input, output, identifier: str=''): + if input.ndim == 1: input = jnp.expand_dims(input, 0) + if target.ndim == 1: target = jnp.expand_dims(target, 0) + if output.ndim == 1: output = jnp.expand_dims(output, 0) + assert input.ndim == 2, f'should be a 2D array with shape of (batch, feature). Got {input.shape}' + assert target.ndim == 2, f'should be a 2D array with shape of (batch, feature). Got {target.shape}' + assert output.ndim == 2, f'should be a 2D array with shape of (batch, feature). 
Got {output.shape}' error = bm.as_jax(output - target) input = bm.as_jax(input) - return -self.alpha * bm.sum(vmap(jnp.outer)(input, error), axis=0) + return -self.alpha * jnp.sum(vmap(jnp.outer)(input, error), axis=0) name2func['lms'] = LMS From 124414e26e4861d47919bba3316520b5fc91809e Mon Sep 17 00:00:00 2001 From: chaoming Date: Sat, 14 Jan 2023 21:07:15 +0800 Subject: [PATCH 04/13] fix `MultiStepLR` scheduler bug --- brainpy/_src/optimizers/optimizer.py | 17 ++++++++++------- brainpy/_src/optimizers/scheduler.py | 2 +- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/brainpy/_src/optimizers/optimizer.py b/brainpy/_src/optimizers/optimizer.py index bb33362ef..c251621ae 100644 --- a/brainpy/_src/optimizers/optimizer.py +++ b/brainpy/_src/optimizers/optimizer.py @@ -116,6 +116,9 @@ def __init__( weight_decay=weight_decay, name=name) + def __repr__(self): + return f'{self.__class__.__name__}(lr={self.lr})' + def register_train_vars(self, train_vars: Optional[Dict[str, bm.Variable]] = None): train_vars = dict() if train_vars is None else train_vars if not isinstance(train_vars, dict): @@ -129,7 +132,7 @@ def update(self, grads: dict): if self.weight_decay is None: p.value -= lr * grads[key] else: - p.value = (1 - self.weight_decay) * p + lr * grads[key] + p.value = (1 - self.weight_decay) * p - lr * grads[key] self.lr.step_call() @@ -178,6 +181,9 @@ def __init__( self.momentum = momentum + def __repr__(self): + return f'{self.__class__.__name__}(lr={self.lr}, momentum={self.momentum})' + def register_train_vars(self, train_vars: Optional[Dict[str, bm.Variable]] = None): train_vars = dict() if train_vars is None else train_vars if not isinstance(train_vars, dict): @@ -200,9 +206,6 @@ def update(self, grads: dict): p.value = (1 - self.weight_decay) * p + v self.lr.step_call() - def __repr__(self): - return f"{self.__class__.__name__}(lr={self.lr}, momentum={self.momentum})" - class MomentumNesterov(CommonOpt): r"""Nesterov accelerated gradient optimizer [2]_. @@ -242,6 +245,9 @@ def __init__( self.momentum = momentum + def __repr__(self): + return f'{self.__class__.__name__}(lr={self.lr}, momentum={self.momentum})' + def register_train_vars(self, train_vars: Optional[Dict[str, bm.Variable]] = None): train_vars = dict() if train_vars is None else train_vars if not isinstance(train_vars, dict): @@ -264,9 +270,6 @@ def update(self, grads: dict): p.value = (1 - self.weight_decay) * p + v self.lr.step_call() - def __repr__(self): - return f"{self.__class__.__name__}(lr={self.lr}, momentum={self.momentum})" - class Adagrad(CommonOpt): r"""Optimizer that implements the Adagrad algorithm. 
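# A minimal sketch (not part of the patch) of the corrected plain-SGD update shown
# above: when `weight_decay` is set, the parameter is first decayed and the gradient
# step is now subtracted rather than added. `p`, `g`, `lr` and `wd` are illustrative
# names only, not identifiers from the diff.
def sgd_step_with_decay(p: float, g: float, lr: float, wd: float) -> float:
    # decay the parameter, then move against the gradient
    return (1.0 - wd) * p - lr * g

# Example: sgd_step_with_decay(1.0, 0.5, 0.1, 0.01) -> 0.94; the pre-fix code used
# `+ lr * grads[key]`, which moved the parameter up the gradient whenever decay was on.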
diff --git a/brainpy/_src/optimizers/scheduler.py b/brainpy/_src/optimizers/scheduler.py index 481545e35..47446f272 100644 --- a/brainpy/_src/optimizers/scheduler.py +++ b/brainpy/_src/optimizers/scheduler.py @@ -142,7 +142,7 @@ def __init__( def __call__(self, i=None): i = (self.last_epoch.value + 1) if i is None else i p = bm.ifelse([i < m for m in self.milestones], - list(range(1, len(self.milestones) + 1)) + [len(self.milestones) + 1]) + list(range(0, len(self.milestones))) + [len(self.milestones)]) return self.lr * self.gamma ** p def __repr__(self): From 8255b7c4fb7c002b3f575ee69e571d2b32db4203 Mon Sep 17 00:00:00 2001 From: chaoming Date: Sat, 14 Jan 2023 21:07:40 +0800 Subject: [PATCH 05/13] fix `PoissonEncoder` --- brainpy/_src/encoding/stateless_encoding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/brainpy/_src/encoding/stateless_encoding.py b/brainpy/_src/encoding/stateless_encoding.py index dbe50fe3b..c32ecdd52 100644 --- a/brainpy/_src/encoding/stateless_encoding.py +++ b/brainpy/_src/encoding/stateless_encoding.py @@ -66,5 +66,5 @@ def __call__(self, x: ArrayType, num_step: int = None): if not (self.min_val is None or self.max_val is None): x = (x - self.min_val) / (self.max_val - self.min_val) shape = x.shape if (num_step is None) else ((num_step,) + x.shape) - d = self.rng.rand(*shape).value < x + d = bm.as_jax(self.rng.rand(*shape)) < x return d.astype(x.dtype) From bc54908eb19d6e1e16522d476d49e409acb6480a Mon Sep 17 00:00:00 2001 From: chaoming Date: Sat, 14 Jan 2023 21:09:02 +0800 Subject: [PATCH 06/13] generalize numpy functions in `brainpy.math` --- brainpy/_src/math/arraycompatible.py | 2611 ++++--------------------- brainpy/_src/math/arrayinterporate.py | 1 - brainpy/_src/math/ndarray.py | 28 + brainpy/_src/math/random.py | 208 +- brainpy/math/arraycompatible.py | 359 ++++ brainpy/math/fft.py | 21 + brainpy/math/linalg.py | 22 + brainpy/math/others.py | 3 + 8 files changed, 957 insertions(+), 2296 deletions(-) create mode 100644 brainpy/math/arraycompatible.py diff --git a/brainpy/_src/math/arraycompatible.py b/brainpy/_src/math/arraycompatible.py index accad7670..598112f06 100644 --- a/brainpy/_src/math/arraycompatible.py +++ b/brainpy/_src/math/arraycompatible.py @@ -1,1948 +1,406 @@ # -*- coding: utf-8 -*- -from typing import Optional - -import jax.numpy as jnp -import numpy as np -from jax.tree_util import tree_map, tree_flatten, tree_unflatten - -from ._utils import wraps -from .arraycreation import * -from .arrayinterporate import * -from .ndarray import Array - -__all__ = [ - 'full', 'full_like', 'eye', 'identity', 'diag', 'tri', 'tril', 'triu', - - # math funcs - 'real', 'imag', 'conj', 'conjugate', 'ndim', 'isreal', 'isscalar', - 'add', 'reciprocal', 'negative', 'positive', 'multiply', 'divide', - 'power', 'subtract', 'true_divide', 'floor_divide', 'float_power', - 'fmod', 'mod', 'modf', 'divmod', 'remainder', 'abs', 'exp', 'exp2', - 'expm1', 'log', 'log10', 'log1p', 'log2', 'logaddexp', 'logaddexp2', - 'lcm', 'gcd', 'arccos', 'arccosh', 'arcsin', 'arcsinh', 'arctan', - 'arctan2', 'arctanh', 'cos', 'cosh', 'sin', 'sinc', 'sinh', 'tan', - 'tanh', 'deg2rad', 'hypot', 'rad2deg', 'degrees', 'radians', 'round', - 'around', 'round_', 'rint', 'floor', 'ceil', 'trunc', 'fix', 'prod', - 'sum', 'diff', 'median', 'nancumprod', 'nancumsum', 'nanprod', 'nansum', - 'cumprod', 'cumsum', 'ediff1d', 'cross', 'trapz', 'isfinite', 'isinf', - 'isnan', 'signbit', 'copysign', 'nextafter', 'ldexp', 'frexp', 'convolve', - 'sqrt', 'cbrt', 'square', 
'absolute', 'fabs', 'sign', 'heaviside', - 'maximum', 'minimum', 'fmax', 'fmin', 'interp', 'clip', 'angle', - - # Elementwise bit operations - 'bitwise_and', 'bitwise_not', 'bitwise_or', 'bitwise_xor', - 'invert', 'left_shift', 'right_shift', - - # logic funcs - 'equal', 'not_equal', 'greater', 'greater_equal', 'less', 'less_equal', - 'array_equal', 'isclose', 'allclose', 'logical_and', 'logical_not', - 'logical_or', 'logical_xor', 'all', 'any', "alltrue", 'sometrue', - - # array manipulation - 'shape', 'size', 'reshape', 'ravel', 'moveaxis', 'transpose', 'swapaxes', - 'concatenate', 'stack', 'vstack', 'hstack', 'dstack', 'column_stack', - 'split', 'dsplit', 'hsplit', 'vsplit', 'tile', 'repeat', 'unique', - 'append', 'flip', 'fliplr', 'flipud', 'roll', 'atleast_1d', 'atleast_2d', - 'atleast_3d', 'expand_dims', 'squeeze', 'sort', 'argsort', 'argmax', 'argmin', - 'argwhere', 'nonzero', 'flatnonzero', 'where', 'searchsorted', 'extract', - 'count_nonzero', 'max', 'min', 'amax', 'amin', - - # array creation - 'array_split', 'meshgrid', 'vander', - - # indexing funcs - 'nonzero', 'where', 'tril_indices', 'tril_indices_from', 'triu_indices', - 'triu_indices_from', 'take', 'select', - - # statistic funcs - 'nanmin', 'nanmax', 'ptp', 'percentile', 'nanpercentile', 'quantile', 'nanquantile', - 'median', 'average', 'mean', 'std', 'var', 'nanmedian', 'nanmean', 'nanstd', 'nanvar', - 'corrcoef', 'correlate', 'cov', 'histogram', 'bincount', 'digitize', - - # window funcs - 'bartlett', 'blackman', 'hamming', 'hanning', 'kaiser', - - # constants - 'e', 'pi', 'inf', - - # linear algebra - 'dot', 'vdot', 'inner', 'outer', 'kron', 'matmul', 'trace', - - # data types - 'dtype', 'finfo', 'iinfo', 'uint8', 'uint16', 'uint32', 'uint64', - 'int8', 'int16', 'int32', 'int64', 'float16', 'float32', - 'float64', 'complex64', 'complex128', - - # more - 'product', 'row_stack', 'apply_over_axes', 'apply_along_axis', 'array_equiv', - 'array_repr', 'array_str', 'block', 'broadcast_arrays', 'broadcast_shapes', - 'broadcast_to', 'compress', 'cumproduct', 'diag_indices', 'diag_indices_from', - 'diagflat', 'diagonal', 'einsum', 'einsum_path', 'geomspace', 'gradient', - 'histogram2d', 'histogram_bin_edges', 'histogramdd', 'i0', 'in1d', 'indices', - 'insert', 'intersect1d', 'iscomplex', 'isin', 'ix_', 'lexsort', 'load', - 'save', 'savez', 'mask_indices', 'msort', 'nan_to_num', 'nanargmax', 'setdiff1d', - 'nanargmin', 'pad', 'poly', 'polyadd', 'polyder', 'polyfit', 'polyint', - 'polymul', 'polysub', 'polyval', 'resize', 'rollaxis', 'roots', 'rot90', - 'setxor1d', 'tensordot', 'trim_zeros', 'union1d', 'unravel_index', 'unwrap', - 'take_along_axis', 'can_cast', 'choose', 'copy', 'frombuffer', 'fromfile', - 'fromfunction', 'fromiter', 'fromstring', 'get_printoptions', 'iscomplexobj', - 'isneginf', 'isposinf', 'isrealobj', 'issubdtype', 'issubsctype', 'iterable', - 'packbits', 'piecewise', 'printoptions', 'set_printoptions', 'promote_types', - 'ravel_multi_index', 'result_type', 'sort_complex', 'unpackbits', 'delete', - - # unique - 'add_docstring', 'add_newdoc', 'add_newdoc_ufunc', 'array2string', 'asanyarray', - 'ascontiguousarray', 'asfarray', 'asscalar', 'common_type', 'disp', 'genfromtxt', - 'loadtxt', 'info', 'issubclass_', 'place', 'polydiv', 'put', 'putmask', 'safe_eval', - 'savetxt', 'savez_compressed', 'show_config', 'typename', 'copyto', 'matrix', 'asmatrix', 'mat', -] - -_min = min -_max = max - - -# others -# ------ - - -def _as_jax_array_(obj): - return obj.value if isinstance(obj, Array) else obj - - -@wraps(jnp.full) 
-def full(shape, fill_value, dtype=None): - return Array(jnp.full(shape, fill_value, dtype=dtype)) - - -@wraps(jnp.full_like) -def full_like(a, fill_value, dtype=None, shape=None): - a = _as_jax_array_(a) - return Array(jnp.full_like(a, fill_value, dtype=dtype, shape=shape)) - - -@wraps(jnp.eye) -def eye(N, M=None, k=0, dtype=None): - return Array(jnp.eye(N, M=M, k=k, dtype=dtype)) - - -@wraps(jnp.identity) -def identity(n, dtype=None): - return Array(jnp.identity(n, dtype=dtype)) - - -@wraps(jnp.diag) -def diag(a, k=0): - a = _as_jax_array_(a) - return Array(jnp.diag(a, k)) - - -@wraps(jnp.tri) -def tri(N, M=None, k=0, dtype=None): - return Array(jnp.tri(N, M=M, k=k, dtype=dtype)) - - -@wraps(jnp.tril) -def tril(a, k=0): - a = _as_jax_array_(a) - return Array(jnp.tril(a, k)) - - -@wraps(jnp.triu) -def triu(a, k=0): - a = _as_jax_array_(a) - return Array(jnp.triu(a, k)) - - -@wraps(jnp.delete) -def delete(arr, obj, axis=None): - arr = _as_jax_array_(arr) - obj = _as_jax_array_(obj) - return Array(jnp.delete(arr, obj, axis=axis)) - - -@wraps(jnp.take_along_axis) -def take_along_axis(a, indices, axis, mode=None): - a = _as_jax_array_(a) - indices = _as_jax_array_(indices) - return Array(jnp.take_along_axis(a, indices, axis, mode)) - - -@wraps(jnp.block) -def block(arrays): - leaves, tree = tree_flatten(arrays, is_leaf=lambda a: isinstance(a, Array)) - leaves = [_as_jax_array_(l) for l in leaves] - arrays = tree_unflatten(tree, leaves) - return Array(jnp.block(arrays)) - - -@wraps(jnp.broadcast_arrays) -def broadcast_arrays(*args): - args = [(_as_jax_array_(a)) for a in args] - return jnp.broadcast_arrays(args) - - -broadcast_shapes = wraps(jnp.broadcast_shapes)(jnp.broadcast_shapes) - - -@wraps(jnp.broadcast_to) -def broadcast_to(arr, shape): - arr = _as_jax_array_(arr) - return Array(jnp.broadcast_to(arr, shape)) - - -@wraps(jnp.compress) -def compress(condition, a, axis=None, out=None): - condition = _as_jax_array_(condition) - a = _as_jax_array_(a) - return Array(jnp.compress(condition, a, axis, out)) - - -@wraps(jnp.diag_indices) -def diag_indices(n, ndim=2): - res = jnp.diag_indices(n, ndim) - if isinstance(res, tuple): - return tuple(Array(r) for r in res) - else: - return Array(res) - - -@wraps(jnp.diag_indices_from) -def diag_indices_from(arr): - arr = _as_jax_array_(arr) - res = jnp.diag_indices_from(arr) - if isinstance(res, tuple): - return tuple(Array(r) for r in res) - else: - return Array(res) - - -@wraps(jnp.diagflat) -def diagflat(v, k=0): - v = _as_jax_array_(v) - return Array(jnp.diagflat(v, k)) - - -@wraps(jnp.diagonal) -def diagonal(a, offset=0, axis1: int = 0, axis2: int = 1): - a = _as_jax_array_(a) - return Array(jnp.diagonal(a, offset, axis1, axis2)) - - -@wraps(jnp.einsum) -def einsum(*operands, out=None, optimize='optimal', precision=None, _use_xeinsum=False): - operands = tuple((_as_jax_array_(a)) for a in operands) - return Array(jnp.einsum(*operands, out=out, optimize=optimize, precision=precision, _use_xeinsum=_use_xeinsum)) - - -@wraps(jnp.einsum_path) -def einsum_path(subscripts, *operands, optimize='greedy'): - operands = tuple((_as_jax_array_(a)) for a in operands) - return jnp.einsum_path(subscripts, *operands, optimize=optimize) - - -@wraps(jnp.geomspace) -def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis: int = 0): - return Array(jnp.geomspace(start, stop, num, endpoint, dtype, axis)) - - -@wraps(jnp.gradient) -def gradient(f, *varargs, axis=None, edge_order=None): - f = _as_jax_array_(f) - res = jnp.gradient(f, *varargs, axis=axis, 
edge_order=edge_order) - if isinstance(res, (list, tuple)): - return list(Array(r) for r in res) - else: - return Array(res) - - -@wraps(jnp.histogram2d) -def histogram2d(x, y, bins=10, range=None, weights=None, density=None): - x = _as_jax_array_(x) - y = _as_jax_array_(y) - H, xedges, yedges = jnp.histogram2d(x, y, bins, range, weights, density) - return Array(H), Array(xedges), Array(yedges) - - -@wraps(jnp.histogram_bin_edges) -def histogram_bin_edges(a, bins=10, range=None, weights=None): - a = _as_jax_array_(a) - return Array(jnp.histogram_bin_edges(a, bins, range, weights)) - - -@wraps(jnp.histogramdd) -def histogramdd(sample, bins=10, range=None, weights=None, density=None): - sample = _as_jax_array_(sample) - r = jnp.histogramdd(sample, bins, range, weights, density) - return Array(r[0]), r[1] - - -@wraps(jnp.i0) -def i0(x): - x = _as_jax_array_(x) - return Array(jnp.i0(x)) - - -@wraps(jnp.in1d) -def in1d(ar1, ar2, assume_unique=False, invert=False): - ar1 = _as_jax_array_(ar1) - ar2 = _as_jax_array_(ar2) - return Array(jnp.in1d(ar1, ar2, assume_unique, invert)) - - -@wraps(jnp.indices) -def indices(dimensions, dtype=None, sparse=False): - dtype = jnp.int32 if dtype is None else dtype - res = jnp.indices(dimensions, dtype, sparse) - if isinstance(res, (tuple, list)): - return tuple(Array(r) for r in res) - else: - return Array(res) - - -@wraps(jnp.insert) -def insert(arr, obj, values, axis=None): - arr = _as_jax_array_(arr) - values = _as_jax_array_(values) - return Array(jnp.insert(arr, obj, values, axis)) - - -@wraps(jnp.intersect1d) -def intersect1d(ar1, ar2, assume_unique=False, return_indices=False): - ar1 = _as_jax_array_(ar1) - ar2 = _as_jax_array_(ar2) - res = jnp.intersect1d(ar1, ar2, assume_unique, return_indices) - if return_indices: - return tuple([Array(r) for r in res]) - else: - return Array(res) - - -@wraps(jnp.iscomplex) -def iscomplex(x): - x = _as_jax_array_(x) - return jnp.iscomplex(x) - - -@wraps(jnp.isin) -def isin(element, test_elements, assume_unique=False, invert=False): - element = _as_jax_array_(element) - test_elements = _as_jax_array_(test_elements) - return Array(jnp.isin(element, test_elements, assume_unique, invert)) - - -@wraps(jnp.ix_) -def ix_(*args): - args = [_as_jax_array_(a) for a in args] - return jnp.ix_(*args) - - -@wraps(jnp.lexsort) -def lexsort(keys, axis=-1): - leaves, tree = tree_flatten(keys, is_leaf=lambda x: isinstance(x, Array)) - leaves = [_as_jax_array_(l) for l in leaves] - keys = tree_unflatten(tree, leaves) - return Array(jnp.lexsort(keys, axis)) - - - -def load(file, mmap_mode=None, allow_pickle=False, fix_imports=True, - encoding='ASCII'): - return np.load(file, - mmap_mode=mmap_mode, - allow_pickle=allow_pickle, - fix_imports=fix_imports, - encoding=encoding) - - -@wraps(np.save) -def save(file, arr, allow_pickle=True, fix_imports=True): - arr = _as_jax_array_(arr) - np.save(file, arr, allow_pickle, fix_imports) - - -@wraps(np.savez) -def savez(file, *args, **kwds): - args = [_as_jax_array_(a) for a in args] - kwds = {k: _as_jax_array_(v) for k, v in kwds.items()} - np.savez(file, *args, **kwds) - - -mask_indices = wraps(jnp.mask_indices)(jnp.mask_indices) - - -def msort(a): - return Array(jnp.sort(_as_jax_array_(a), axis=0)) - - -@wraps(jnp.nan_to_num) -def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None): - x = _as_jax_array_(x) - return Array(jnp.nan_to_num(x, copy, nan=nan, posinf=posinf, neginf=neginf)) - - -@wraps(jnp.nanargmax) -def nanargmax(a, axis=None, out=None, keepdims=None): - return 
Array(jnp.nanargmax(_as_jax_array_(a), axis=axis, out=out, keepdims=keepdims)) - - -@wraps(jnp.nanargmin) -def nanargmin(a, axis=None, out=None, keepdims=None): - return Array(jnp.nanargmin(_as_jax_array_(a), axis=axis, out=out, keepdims=keepdims)) - - -@wraps(jnp.pad) -def pad(array, pad_width, mode="constant", **kwargs): - array = _as_jax_array_(array) - pad_width = _as_jax_array_(pad_width) - kwargs = {k: _as_jax_array_(v) for k, v in kwargs.items()} - return Array(jnp.pad(array, pad_width, mode, **kwargs)) - - -@wraps(jnp.poly) -def poly(seq_of_zeros): - seq_of_zeros = _as_jax_array_(seq_of_zeros) - return Array(jnp.poly(seq_of_zeros)) - - -@wraps(jnp.polyadd) -def polyadd(a1, a2): - a1 = _as_jax_array_(a1) - a2 = _as_jax_array_(a2) - return Array(jnp.polyadd(a1, a2)) - - -@wraps(jnp.polyder) -def polyder(p, m=1): - p = _as_jax_array_(p) - return Array(jnp.polyder(p, m)) - - -@wraps(jnp.polyfit) -def polyfit(x, y, deg, rcond=None, full=False, w=None, cov=False): - x = _as_jax_array_(x) - y = _as_jax_array_(y) - res = jnp.polyfit(x, y, deg, rcond=rcond, full=full, w=w, cov=cov) - if isinstance(res, (tuple, list)): - return tuple(Array(r) for r in res) - else: - return Array(res) - - -@wraps(jnp.polyint) -def polyint(p, m=1, k=None): - p = _as_jax_array_(p) - return Array(jnp.polyint(p, m, k)) - - -@wraps(jnp.polymul) -def polymul(a1, a2, **kwargs): - a1 = _as_jax_array_(a1) - a2 = _as_jax_array_(a2) - return Array(jnp.polymul(a1, a2, **kwargs)) - - -@wraps(jnp.polysub) -def polysub(a1, a2): - a1 = _as_jax_array_(a1) - a2 = _as_jax_array_(a2) - return Array(jnp.polysub(a1, a2)) - - -@wraps(jnp.polyval) -def polyval(p, x): - p = _as_jax_array_(p) - x = _as_jax_array_(x) - return Array(jnp.polyval(p, x)) - - -@wraps(jnp.resize) -def resize(a, new_shape): - a = _as_jax_array_(a) - return Array(jnp.resize(a, new_shape)) - - -@wraps(jnp.rollaxis) -def rollaxis(a, axis: int, start=0): - a = _as_jax_array_(a) - return Array(jnp.rollaxis(a, axis, start)) - - -@wraps(jnp.roots) -def roots(p): - p = _as_jax_array_(p) - return Array(jnp.roots(p)) - - -@wraps(jnp.rot90) -def rot90(m, k=1, axes=(0, 1)): - m = _as_jax_array_(m) - return Array(jnp.rot90(m, k, axes)) - - -@wraps(jnp.setdiff1d) -def setdiff1d(ar1, ar2, assume_unique=False, **kwargs): - return Array(jnp.setdiff1d(_as_jax_array_(ar1), - _as_jax_array_(ar2), - assume_unique=assume_unique, - **kwargs)) - - -@wraps(jnp.setxor1d) -def setxor1d(ar1, ar2, assume_unique=False): - return Array(jnp.setxor1d(_as_jax_array_(ar1), - _as_jax_array_(ar2), - assume_unique=assume_unique)) - - -@wraps(jnp.tensordot) -def tensordot(a, b, axes=2, **kwargs): - a = _as_jax_array_(a) - b = _as_jax_array_(b) - return Array(jnp.tensordot(a, b, axes, **kwargs)) - - -@wraps(jnp.trim_zeros) -def trim_zeros(filt, trim='fb'): - return Array(jnp.trim_zeros(_as_jax_array_(filt), trim)) - - -@wraps(jnp.union1d) -def union1d(ar1, ar2, **kwargs): - ar1 = _as_jax_array_(ar1) - ar2 = _as_jax_array_(ar2) - return Array(jnp.union1d(ar1, ar2, **kwargs)) - - -@wraps(jnp.unravel_index) -def unravel_index(indices, shape): - indices = _as_jax_array_(indices) - shape = _as_jax_array_(shape) - return jnp.unravel_index(indices, shape) - - -@wraps(jnp.unwrap) -def unwrap(p, discont=jnp.pi, axis: int = -1, period: float = 2 * jnp.pi): - p = _as_jax_array_(p) - return Array(jnp.unwrap(p, discont, axis, period)) - - -# math funcs -# ---------- - -# 1. 
Basics -@wraps(jnp.isreal) -def isreal(x): - x = _as_jax_array_(x) - return jnp.isreal(x) - - -@wraps(jnp.isscalar) -def isscalar(x): - x = _as_jax_array_(x) - return jnp.isscalar(x) - - -@wraps(jnp.real) -def real(x): - return jnp.real(_as_jax_array_(x)) - - -@wraps(jnp.imag) -def imag(x): - return jnp.imag(_as_jax_array_(x)) - - -@wraps(jnp.conj) -def conj(x): - return jnp.conj(_as_jax_array_(x)) - - -@wraps(jnp.conjugate) -def conjugate(x): - return jnp.conjugate(_as_jax_array_(x)) - - -@wraps(jnp.ndim) -def ndim(x): - return jnp.ndim(_as_jax_array_(x)) - - -# 2. Arithmetic operations -@wraps(jnp.add) -def add(x, y): - return x + y - - -@wraps(jnp.reciprocal) -def reciprocal(x): - x = _as_jax_array_(x) - return Array(jnp.reciprocal(x)) - - -@wraps(jnp.negative) -def negative(x): - x = _as_jax_array_(x) - return Array(jnp.negative(x)) - - -@wraps(jnp.positive) -def positive(x): - x = _as_jax_array_(x) - return Array(jnp.positive(x)) - - -@wraps(jnp.multiply) -def multiply(x1, x2): - x1 = _as_jax_array_(x1) - x2 = _as_jax_array_(x2) - return Array(jnp.multiply(x1, x2)) - - -@wraps(jnp.divide) -def divide(x1, x2): - x1 = _as_jax_array_(x1) - x2 = _as_jax_array_(x2) - return Array(jnp.divide(x1, x2)) - - -@wraps(jnp.power) -def power(x1, x2): - x1 = _as_jax_array_(x1) - x2 = _as_jax_array_(x2) - return Array(jnp.power(x1, x2)) - - -@wraps(jnp.subtract) -def subtract(x1, x2): - x1 = _as_jax_array_(x1) - x2 = _as_jax_array_(x2) - return Array(jnp.subtract(x1, x2)) - - -@wraps(jnp.true_divide) -def true_divide(x1, x2): - x1 = _as_jax_array_(x1) - x2 = _as_jax_array_(x2) - return Array(jnp.true_divide(x1, x2)) - - -@wraps(jnp.floor_divide) -def floor_divide(x1, x2): - x1 = _as_jax_array_(x1) - x2 = _as_jax_array_(x2) - return Array(jnp.floor_divide(x1, x2)) - - -@wraps(jnp.float_power) -def float_power(x1, x2): - x1 = _as_jax_array_(x1) - x2 = _as_jax_array_(x2) - return Array(jnp.float_power(x1, x2)) - - -@wraps(jnp.fmod) -def fmod(x1, x2): - x1 = _as_jax_array_(x1) - x2 = _as_jax_array_(x2) - return Array(jnp.fmod(x1, x2)) - - -@wraps(jnp.mod) -def mod(x1, x2): - if isinstance(x1, Array): x1 = x1.value - x2 = _as_jax_array_(x2) - return Array(jnp.mod(x1, x2)) - - -@wraps(jnp.divmod) -def divmod(x1, x2): - x1 = _as_jax_array_(x1) - x2 = _as_jax_array_(x2) - r = jnp.divmod(x1, x2) - return Array(r[0]), Array(r[1]) - - -@wraps(jnp.remainder) -def remainder(x1, x2): - x1 = _as_jax_array_(x1) - x2 = _as_jax_array_(x2) - return Array(jnp.remainder(x1, x2)) - - -@wraps(jnp.modf) -def modf(x): - x = _as_jax_array_(x) - r = jnp.modf(x) - return Array(r[0]), Array(r[1]) - - -@wraps(jnp.abs) -def abs(x): - x = _as_jax_array_(x) - return Array(jnp.absolute(x)) - - -@wraps(jnp.absolute) -def absolute(x): - x = _as_jax_array_(x) - return Array(jnp.absolute(x)) - - -# 3. 
Exponents and logarithms -@wraps(jnp.exp) -def exp(x): - x = _as_jax_array_(x) - return Array(jnp.exp(x)) - - -@wraps(jnp.exp2) -def exp2(x): - x = _as_jax_array_(x) - return Array(jnp.exp2(x)) - - -@wraps(jnp.expm1) -def expm1(x): - x = _as_jax_array_(x) - return Array(jnp.expm1(x)) - - -@wraps(jnp.log) -def log(x): - x = _as_jax_array_(x) - return Array(jnp.log(x)) - - -@wraps(jnp.log10) -def log10(x): - x = _as_jax_array_(x) - return Array(jnp.log10(x)) - - -@wraps(jnp.log1p) -def log1p(x): - x = _as_jax_array_(x) - return Array(jnp.log1p(x)) - - -@wraps(jnp.log2) -def log2(x): - x = _as_jax_array_(x) - return Array(jnp.log2(x)) - - -@wraps(jnp.logaddexp) -def logaddexp(x1, x2): - x1 = _as_jax_array_(x1) - x2 = _as_jax_array_(x2) - return Array(jnp.logaddexp(x1, x2)) - - -@wraps(jnp.logaddexp2) -def logaddexp2(x1, x2): - x1 = _as_jax_array_(x1) - x2 = _as_jax_array_(x2) - return Array(jnp.logaddexp2(x1, x2)) - - -# 4. Rational routines -@wraps(jnp.lcm) -def lcm(x1, x2): - x1 = _as_jax_array_(x1) - x2 = _as_jax_array_(x2) - return Array(jnp.lcm(x1, x2)) - - -@wraps(jnp.gcd) -def gcd(x1, x2): - x1 = _as_jax_array_(x1) - x2 = _as_jax_array_(x2) - return Array(jnp.gcd(x1, x2)) - - -# 5. trigonometric functions -@wraps(jnp.arccos) -def arccos(x): - x = _as_jax_array_(x) - return Array(jnp.arccos(x)) - - -@wraps(jnp.arccosh) -def arccosh(x): - x = _as_jax_array_(x) - return Array(jnp.arccosh(x)) - - -@wraps(jnp.arcsin) -def arcsin(x): - x = _as_jax_array_(x) - return Array(jnp.arcsin(x)) - - -@wraps(jnp.arcsinh) -def arcsinh(x): - x = _as_jax_array_(x) - return Array(jnp.arcsinh(x)) - - -@wraps(jnp.arctan) -def arctan(x): - x = _as_jax_array_(x) - return Array(jnp.arctan(x)) - - -@wraps(jnp.arctan2) -def arctan2(x, y): - x = _as_jax_array_(x) - y = _as_jax_array_(y) - return Array(jnp.arctan2(x, y)) - - -@wraps(jnp.arctanh) -def arctanh(x): - x = _as_jax_array_(x) - return Array(jnp.arctanh(x)) - - -@wraps(jnp.cos) -def cos(x): - x = _as_jax_array_(x) - return Array(jnp.cos(x)) - - -@wraps(jnp.cosh) -def cosh(x): - x = _as_jax_array_(x) - return Array(jnp.cosh(x)) - - -@wraps(jnp.sin) -def sin(x): - x = _as_jax_array_(x) - return Array(jnp.sin(x)) - - -@wraps(jnp.sinc) -def sinc(x): - x = _as_jax_array_(x) - return Array(jnp.sinc(x)) - - -@wraps(jnp.sinh) -def sinh(x): - x = _as_jax_array_(x) - return Array(jnp.sinh(x)) - - -@wraps(jnp.tan) -def tan(x): - x = _as_jax_array_(x) - return Array(jnp.tan(x)) - - -@wraps(jnp.tanh) -def tanh(x): - x = _as_jax_array_(x) - return Array(jnp.tanh(x)) - - -@wraps(jnp.deg2rad) -def deg2rad(x): - x = _as_jax_array_(x) - return Array(jnp.deg2rad(x)) - - -@wraps(jnp.rad2deg) -def rad2deg(x): - x = _as_jax_array_(x) - return Array(jnp.rad2deg(x)) - - -@wraps(jnp.degrees) -def degrees(x): - x = _as_jax_array_(x) - return Array(jnp.degrees(x)) - - -@wraps(jnp.radians) -def radians(x): - x = _as_jax_array_(x) - return Array(jnp.radians(x)) - - -@wraps(jnp.hypot) -def hypot(x1, x2): - x1 = _as_jax_array_(x1) - x2 = _as_jax_array_(x2) - return Array(jnp.hypot(x1, x2)) - - -# 6. 
Rounding -@wraps(jnp.round) -def round(a, decimals=0): - a = _as_jax_array_(a) - return Array(jnp.round(a, decimals=decimals)) - - -around = round -round_ = round - - -@wraps(jnp.rint) -def rint(x): - x = _as_jax_array_(x) - return Array(jnp.rint(x)) - - -@wraps(jnp.floor) -def floor(x): - x = _as_jax_array_(x) - return Array(jnp.floor(x)) - - -@wraps(jnp.ceil) -def ceil(x): - x = _as_jax_array_(x) - return Array(jnp.ceil(x)) - - -@wraps(jnp.trunc) -def trunc(x): - x = _as_jax_array_(x) - return Array(jnp.trunc(x)) - - -@wraps(jnp.fix) -def fix(x): - x = _as_jax_array_(x) - return Array(jnp.fix(x)) - - -# 7. Sums, products, differences, Reductions - - -@wraps(jnp.prod) -def prod(a, axis=None, dtype=None, keepdims=None, initial=None, where=None, **kwargs): - a = _as_jax_array_(a) - r = jnp.prod(a, axis=axis, dtype=dtype, keepdims=keepdims, initial=initial, where=where, **kwargs) - return r if axis is None else Array(r) - - -product = prod - - -@wraps(jnp.sum) -def sum(a, axis=None, dtype=None, keepdims=None, initial=None, where=None, **kwargs): - a = _as_jax_array_(a) - r = jnp.sum(a, axis=axis, dtype=dtype, keepdims=keepdims, initial=initial, where=where, **kwargs) - return r if axis is None else Array(r) - - -@wraps(jnp.diff) -def diff(a, n=1, axis: int = -1, prepend=None, append=None): - a = _as_jax_array_(a) - return Array(jnp.diff(a, n=n, axis=axis, prepend=prepend, append=append)) - - -@wraps(jnp.median) -def median(a, axis=None, keepdims=False, **kwargs): - a = _as_jax_array_(a) - r = jnp.median(a, axis=axis, keepdims=keepdims, **kwargs) - return r if axis is None else Array(r) - - -@wraps(jnp.nancumprod) -def nancumprod(a, axis=None, dtype=None): - a = _as_jax_array_(a) - return Array(jnp.nancumprod(a=a, axis=axis, dtype=dtype)) - - -@wraps(jnp.nancumsum) -def nancumsum(a, axis=None, dtype=None): - a = _as_jax_array_(a) - return Array(jnp.nancumsum(a=a, axis=axis, dtype=dtype)) - - -@wraps(jnp.cumprod) -def cumprod(a, axis=None, dtype=None): - a = _as_jax_array_(a) - return Array(jnp.cumprod(a=a, axis=axis, dtype=dtype)) - - -cumproduct = cumprod - - -@wraps(jnp.cumsum) -def cumsum(a, axis=None, dtype=None): - a = _as_jax_array_(a) - return Array(jnp.cumsum(a=a, axis=axis, dtype=dtype)) - - -@wraps(jnp.nanprod) -def nanprod(a, axis=None, dtype=None, keepdims=None, **kwargs): - a = _as_jax_array_(a) - r = jnp.nanprod(a=a, axis=axis, dtype=dtype, keepdims=keepdims, **kwargs) - return r if axis is None else Array(r) - - -@wraps(jnp.nansum) -def nansum(a, axis=None, dtype=None, keepdims=None, **kwargs): - a = _as_jax_array_(a) - r = jnp.nansum(a=a, axis=axis, dtype=dtype, keepdims=keepdims, **kwargs) - return r if axis is None else Array(r) - - -@wraps(jnp.ediff1d) -def ediff1d(a, to_end=None, to_begin=None): - a = _as_jax_array_(a) - to_end = _as_jax_array_(to_end) - to_begin = _as_jax_array_(to_begin) - return Array(jnp.ediff1d(a, to_end=to_end, to_begin=to_begin)) - - -@wraps(jnp.cross) -def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None): - a = _as_jax_array_(a) - b = _as_jax_array_(b) - return Array(jnp.cross(a, b, axisa=axisa, axisb=axisb, axisc=axisc, axis=axis)) - - -@wraps(jnp.trapz) -def trapz(y, x=None, dx=1.0, axis: int = -1): - y = _as_jax_array_(y) - x = _as_jax_array_(x) - return jnp.trapz(y, x=x, dx=dx, axis=axis) - - -# 8. 
floating_functions -@wraps(jnp.isfinite) -def isfinite(x): - x = _as_jax_array_(x) - return Array(jnp.isfinite(x)) - - -@wraps(jnp.isinf) -def isinf(x): - x = _as_jax_array_(x) - return Array(jnp.isinf(x)) - - -@wraps(jnp.isnan) -def isnan(x): - x = _as_jax_array_(x) - return Array(jnp.isnan(x)) - - -@wraps(jnp.signbit) -def signbit(x): - x = _as_jax_array_(x) - return Array(jnp.signbit(x)) - - -@wraps(jnp.nextafter) -def nextafter(x1, x2): - x1 = _as_jax_array_(x1) - x2 = _as_jax_array_(x2) - return Array(jnp.nextafter(x1, x2)) - - -@wraps(jnp.copysign) -def copysign(x1, x2): - x1 = _as_jax_array_(x1) - x2 = _as_jax_array_(x2) - return Array(jnp.copysign(x1, x2)) - - -@wraps(jnp.ldexp) -def ldexp(x1, x2): - x1 = _as_jax_array_(x1) - x2 = _as_jax_array_(x2) - return Array(jnp.ldexp(x1, x2)) - - -@wraps(jnp.frexp) -def frexp(x): - x = _as_jax_array_(x) - mantissa, exponent = jnp.frexp(x) - return Array(mantissa), Array(exponent) - - -# 9. Miscellaneous -@wraps(jnp.convolve) -def convolve(a, v, mode='full', **kwargs): - a = _as_jax_array_(a) - v = _as_jax_array_(v) - return Array(jnp.convolve(a, v, mode, **kwargs)) - - -@wraps(jnp.sqrt) -def sqrt(x): - x = _as_jax_array_(x) - return Array(jnp.sqrt(x)) - - -@wraps(jnp.cbrt) -def cbrt(x): - x = _as_jax_array_(x) - return Array(jnp.cbrt(x)) - - -@wraps(jnp.square) -def square(x): - x = _as_jax_array_(x) - return Array(jnp.square(x)) - - -@wraps(jnp.fabs) -def fabs(x): - x = _as_jax_array_(x) - return Array(jnp.fabs(x)) - - -@wraps(jnp.sign) -def sign(x): - x = _as_jax_array_(x) - return Array(jnp.sign(x)) - - -@wraps(jnp.heaviside) -def heaviside(x1, x2): - x1 = _as_jax_array_(x1) - x2 = _as_jax_array_(x2) - return Array(jnp.heaviside(x1, x2)) - - -@wraps(jnp.maximum) -def maximum(x1, x2): - x1 = _as_jax_array_(x1) - x2 = _as_jax_array_(x2) - return Array(jnp.maximum(x1, x2)) - - -@wraps(jnp.minimum) -def minimum(x1, x2): - x1 = _as_jax_array_(x1) - x2 = _as_jax_array_(x2) - return Array(jnp.minimum(x1, x2)) - - -@wraps(jnp.fmax) -def fmax(x1, x2): - x1 = _as_jax_array_(x1) - x2 = _as_jax_array_(x2) - return Array(jnp.fmax(x1, x2)) - - -@wraps(jnp.fmin) -def fmin(x1, x2): - x1 = _as_jax_array_(x1) - x2 = _as_jax_array_(x2) - return Array(jnp.fmin(x1, x2)) - - -@wraps(jnp.interp) -def interp(x, xp, fp, left=None, right=None, period=None): - x = _as_jax_array_(x) - xp = _as_jax_array_(xp) - fp = _as_jax_array_(fp) - return Array(jnp.interp(x, xp, fp, left=left, right=right, period=period)) - - -@wraps(jnp.clip) -def clip(a, a_min=None, a_max=None): - a = _as_jax_array_(a) - a_min = _as_jax_array_(a_min) - a_max = _as_jax_array_(a_max) - return Array(jnp.clip(a, a_min, a_max)) - - -@wraps(jnp.angle) -def angle(z, deg=False): - z = _as_jax_array_(z) - a = jnp.angle(z) - if deg: - a *= 180 / pi - return Array(a) - - -# binary funcs -# ------------- - - -@wraps(jnp.bitwise_not) -def bitwise_not(x): - x = _as_jax_array_(x) - return Array(jnp.bitwise_not(x)) - - -@wraps(jnp.invert) -def invert(x): - x = _as_jax_array_(x) - return Array(jnp.invert(x)) - - -@wraps(jnp.bitwise_and) -def bitwise_and(x1, x2): - x1 = _as_jax_array_(x1) - x2 = _as_jax_array_(x2) - return Array(jnp.bitwise_and(x1, x2)) - - -@wraps(jnp.bitwise_or) -def bitwise_or(x1, x2): - x1 = _as_jax_array_(x1) - x2 = _as_jax_array_(x2) - return Array(jnp.bitwise_or(x1, x2)) - - -@wraps(jnp.bitwise_xor) -def bitwise_xor(x1, x2): - x1 = _as_jax_array_(x1) - x2 = _as_jax_array_(x2) - return Array(jnp.bitwise_xor(x1, x2)) - - -@wraps(jnp.left_shift) -def left_shift(x1, x2): - x1 = 
_as_jax_array_(x1) - x2 = _as_jax_array_(x2) - return Array(jnp.left_shift(x1, x2)) - - -@wraps(jnp.right_shift) -def right_shift(x1, x2): - x1 = _as_jax_array_(x1) - x2 = _as_jax_array_(x2) - return Array(jnp.right_shift(x1, x2)) - - -# logic funcs -# ----------- - -# 1. Comparison -@wraps(jnp.equal) -def equal(x1, x2): - x1 = _as_jax_array_(x1) - x2 = _as_jax_array_(x2) - return Array(jnp.equal(x1, x2)) - - -@wraps(jnp.not_equal) -def not_equal(x1, x2): - x1 = _as_jax_array_(x1) - x2 = _as_jax_array_(x2) - return Array(jnp.not_equal(x1, x2)) - - -@wraps(jnp.greater) -def greater(x1, x2): - x1 = _as_jax_array_(x1) - x2 = _as_jax_array_(x2) - return Array(jnp.greater(x1, x2)) - - -@wraps(jnp.greater_equal) -def greater_equal(x1, x2): - x1 = _as_jax_array_(x1) - x2 = _as_jax_array_(x2) - return Array(jnp.greater_equal(x1, x2)) - - -@wraps(jnp.less) -def less(x1, x2): - x1 = _as_jax_array_(x1) - x2 = _as_jax_array_(x2) - return Array(jnp.less(x1, x2)) - - -@wraps(jnp.less_equal) -def less_equal(x1, x2): - x1 = _as_jax_array_(x1) - x2 = _as_jax_array_(x2) - return Array(jnp.less_equal(x1, x2)) - - -@wraps(jnp.array_equal) -def array_equal(a, b, equal_nan=False): - a = _as_jax_array_(a) - b = _as_jax_array_(b) - return jnp.array_equal(a, b, equal_nan=equal_nan) - - -@wraps(jnp.isclose) -def isclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False): - a = _as_jax_array_(a) - b = _as_jax_array_(b) - return Array(jnp.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)) - - -@wraps(jnp.allclose) -def allclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False): - a = _as_jax_array_(a) - b = _as_jax_array_(b) - return jnp.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) - - -# 2. Logical operations -@wraps(jnp.logical_not) -def logical_not(x): - x = _as_jax_array_(x) - return Array(jnp.logical_not(x)) - - -@wraps(jnp.logical_and) -def logical_and(x1, x2): - x1 = _as_jax_array_(x1) - x2 = _as_jax_array_(x2) - return Array(jnp.logical_and(x1, x2)) - - -@wraps(jnp.logical_or) -def logical_or(x1, x2): - x1 = _as_jax_array_(x1) - x2 = _as_jax_array_(x2) - return Array(jnp.logical_or(x1, x2)) - - -@wraps(jnp.logical_xor) -def logical_xor(x1, x2): - x1 = _as_jax_array_(x1) - x2 = _as_jax_array_(x2) - return Array(jnp.logical_xor(x1, x2)) - - -# 3. 
Truth value testing -@wraps(jnp.all) -def all(a, axis=None, keepdims=None, where=None): - a = _as_jax_array_(a) - r = jnp.all(a=a, axis=axis, keepdims=keepdims, where=where) - return r if axis is None else Array(r) - - -@wraps(jnp.any) -def any(a, axis=None, keepdims=None, where=None): - a = _as_jax_array_(a) - r = jnp.any(a=a, axis=axis, keepdims=keepdims, where=where) - return r if axis is None else Array(r) - - -alltrue = all -sometrue = any - - -# array manipulation -# ------------------ - - -@wraps(jnp.shape) -def shape(x): - x = _as_jax_array_(x) - return jnp.shape(x) - - -@wraps(jnp.size) -def size(x, axis=None): - x = _as_jax_array_(x) - r = jnp.size(x, axis=axis) - return r if axis is None else Array(r) - - -@wraps(jnp.reshape) -def reshape(x, newshape, order="C"): - x = _as_jax_array_(x) - return Array(jnp.reshape(x, newshape, order=order)) - - -@wraps(jnp.ravel) -def ravel(x, order="C"): - x = _as_jax_array_(x) - return Array(jnp.ravel(x, order=order)) - - -@wraps(jnp.moveaxis) -def moveaxis(x, source, destination): - x = _as_jax_array_(x) - return Array(jnp.moveaxis(x, source, destination)) - - -@wraps(jnp.transpose) -def transpose(x, axis=None): - x = _as_jax_array_(x) - return Array(jnp.transpose(x, axes=axis)) - - -@wraps(jnp.swapaxes) -def swapaxes(x, axis1, axis2): - x = _as_jax_array_(x) - return Array(jnp.swapaxes(x, axis1, axis2)) - - -@wraps(jnp.concatenate) -def concatenate(arrays, axis: int = 0): - arrays = [_as_jax_array_(a) for a in arrays] - return Array(jnp.concatenate(arrays, axis)) - - -@wraps(jnp.stack) -def stack(arrays, axis: int = 0): - arrays = [_as_jax_array_(a) for a in arrays] - return Array(jnp.stack(arrays, axis)) - - -@wraps(jnp.vstack) -def vstack(arrays): - arrays = [_as_jax_array_(a) for a in arrays] - return Array(jnp.vstack(arrays)) - - -row_stack = vstack - - -@wraps(jnp.hstack) -def hstack(arrays): - arrays = [_as_jax_array_(a) for a in arrays] - return Array(jnp.hstack(arrays)) - - -@wraps(jnp.dstack) -def dstack(arrays): - arrays = [_as_jax_array_(a) for a in arrays] - return Array(jnp.dstack(arrays)) - - -@wraps(jnp.column_stack) -def column_stack(arrays): - arrays = [_as_jax_array_(a) for a in arrays] - return Array(jnp.column_stack(arrays)) - - -@wraps(jnp.split) -def split(ary, indices_or_sections, axis=0): - if isinstance(ary, Array): ary = ary.value - if isinstance(indices_or_sections, Array): indices_or_sections = indices_or_sections.value - return [Array(a) for a in jnp.split(ary, indices_or_sections, axis=axis)] - - -@wraps(jnp.dsplit) -def dsplit(ary, indices_or_sections): - return split(ary, indices_or_sections, axis=2) - - -@wraps(jnp.hsplit) -def hsplit(ary, indices_or_sections): - return split(ary, indices_or_sections, axis=1) - - -@wraps(jnp.vsplit) -def vsplit(ary, indices_or_sections): - return split(ary, indices_or_sections, axis=0) - - -@wraps(jnp.tile) -def tile(A, reps): - A = _as_jax_array_(A) - return Array(jnp.tile(A, reps)) - - -@wraps(jnp.repeat) -def repeat(x, repeats, axis=None, **kwargs): - x = _as_jax_array_(x) - return Array(jnp.repeat(x, repeats=repeats, axis=axis, **kwargs)) - - -@wraps(jnp.unique) -def unique(x, return_index=False, return_inverse=False, - return_counts=False, axis=None, **kwargs): - x = _as_jax_array_(x) - res = jnp.unique(x, - return_index=return_index, - return_inverse=return_inverse, - return_counts=return_counts, - axis=axis, - **kwargs) - if isinstance(res, tuple): - return tuple(Array(r) for r in res) - else: - return Array(res) - - -@wraps(jnp.append) -def append(arr, values, 
axis=None): - arr = _as_jax_array_(arr) - values = _as_jax_array_(values) - return Array(jnp.append(arr, values, axis=axis)) - - -@wraps(jnp.flip) -def flip(x, axis=None): - x = _as_jax_array_(x) - return Array(jnp.flip(x, axis=axis)) - - -@wraps(jnp.fliplr) -def fliplr(x): - x = _as_jax_array_(x) - return Array(jnp.fliplr(x)) - - -@wraps(jnp.flipud) -def flipud(x): - x = _as_jax_array_(x) - return Array(jnp.flipud(x)) - - -@wraps(jnp.roll) -def roll(x, shift, axis=None): - x = _as_jax_array_(x) - return Array(jnp.roll(x, shift, axis=axis)) - - -@wraps(jnp.atleast_1d) -def atleast_1d(*arys): - return jnp.atleast_1d(*[_as_jax_array_(a) for a in arys]) - - -@wraps(jnp.atleast_2d) -def atleast_2d(*arys): - return jnp.atleast_2d(*[_as_jax_array_(a) for a in arys]) - - -@wraps(jnp.atleast_3d) -def atleast_3d(*arys): - return jnp.atleast_3d(*[_as_jax_array_(a) for a in arys]) - - -@wraps(jnp.expand_dims) -def expand_dims(x, axis): - x = _as_jax_array_(x) - return Array(jnp.expand_dims(x, axis=axis)) - - -@wraps(jnp.squeeze) -def squeeze(x, axis=None): - x = _as_jax_array_(x) - return Array(jnp.squeeze(x, axis=axis)) +import jax.numpy as jnp +import numpy as np +from jax.tree_util import tree_map +from ._utils import _compatible_with_brainpy_array +from .arraycreation import * +from .arrayinterporate import * +from .ndarray import Array -@wraps(jnp.sort) -def sort(x, axis=-1, kind='quicksort', order=None): - x = _as_jax_array_(x) - return Array(jnp.sort(x, axis=axis, kind=kind, order=order)) +__all__ = [ + 'full', 'full_like', 'eye', 'identity', 'diag', 'tri', 'tril', 'triu', + # math funcs + 'real', 'imag', 'conj', 'conjugate', 'ndim', 'isreal', 'isscalar', + 'add', 'reciprocal', 'negative', 'positive', 'multiply', 'divide', + 'power', 'subtract', 'true_divide', 'floor_divide', 'float_power', + 'fmod', 'mod', 'modf', 'divmod', 'remainder', 'abs', 'exp', 'exp2', + 'expm1', 'log', 'log10', 'log1p', 'log2', 'logaddexp', 'logaddexp2', + 'lcm', 'gcd', 'arccos', 'arccosh', 'arcsin', 'arcsinh', 'arctan', + 'arctan2', 'arctanh', 'cos', 'cosh', 'sin', 'sinc', 'sinh', 'tan', + 'tanh', 'deg2rad', 'hypot', 'rad2deg', 'degrees', 'radians', 'round', + 'around', 'round_', 'rint', 'floor', 'ceil', 'trunc', 'fix', 'prod', + 'sum', 'diff', 'median', 'nancumprod', 'nancumsum', 'nanprod', 'nansum', + 'cumprod', 'cumsum', 'ediff1d', 'cross', 'trapz', 'isfinite', 'isinf', + 'isnan', 'signbit', 'copysign', 'nextafter', 'ldexp', 'frexp', 'convolve', + 'sqrt', 'cbrt', 'square', 'absolute', 'fabs', 'sign', 'heaviside', + 'maximum', 'minimum', 'fmax', 'fmin', 'interp', 'clip', 'angle', -@wraps(jnp.argsort) -def argsort(x, axis=-1, kind='stable', order=None): - x = _as_jax_array_(x) - return Array(jnp.argsort(x, axis=axis, kind=kind, order=order)) + # Elementwise bit operations + 'bitwise_and', 'bitwise_not', 'bitwise_or', 'bitwise_xor', + 'invert', 'left_shift', 'right_shift', + # logic funcs + 'equal', 'not_equal', 'greater', 'greater_equal', 'less', 'less_equal', + 'array_equal', 'isclose', 'allclose', 'logical_and', 'logical_not', + 'logical_or', 'logical_xor', 'all', 'any', "alltrue", 'sometrue', -@wraps(jnp.argmax) -def argmax(x, axis=None, **kwargs): - x = _as_jax_array_(x) - r = jnp.argmax(x, axis=axis, **kwargs) - return r if axis is None else Array(r) + # array manipulation + 'shape', 'size', 'reshape', 'ravel', 'moveaxis', 'transpose', 'swapaxes', + 'concatenate', 'stack', 'vstack', 'hstack', 'dstack', 'column_stack', + 'split', 'dsplit', 'hsplit', 'vsplit', 'tile', 'repeat', 'unique', + 'append', 'flip', 
'fliplr', 'flipud', 'roll', 'atleast_1d', 'atleast_2d', + 'atleast_3d', 'expand_dims', 'squeeze', 'sort', 'argsort', 'argmax', 'argmin', + 'argwhere', 'nonzero', 'flatnonzero', 'where', 'searchsorted', 'extract', + 'count_nonzero', 'max', 'min', 'amax', 'amin', + # array creation + 'array_split', 'meshgrid', 'vander', -@wraps(jnp.argmin) -def argmin(x, axis=None, **kwargs): - x = _as_jax_array_(x) - r = jnp.argmin(x, axis=axis, **kwargs) - return r if axis is None else Array(r) + # indexing funcs + 'nonzero', 'where', 'tril_indices', 'tril_indices_from', 'triu_indices', + 'triu_indices_from', 'take', 'select', + # statistic funcs + 'nanmin', 'nanmax', 'ptp', 'percentile', 'nanpercentile', 'quantile', 'nanquantile', + 'median', 'average', 'mean', 'std', 'var', 'nanmedian', 'nanmean', 'nanstd', 'nanvar', + 'corrcoef', 'correlate', 'cov', 'histogram', 'bincount', 'digitize', -@wraps(jnp.argwhere) -def argwhere(x, **kwargs): - x = _as_jax_array_(x) - return Array(jnp.argwhere(x, **kwargs)) + # window funcs + 'bartlett', 'blackman', 'hamming', 'hanning', 'kaiser', + # constants + 'e', 'pi', 'inf', -@wraps(jnp.nonzero) -def nonzero(x, **kwargs): - x = _as_jax_array_(x) - res = jnp.nonzero(x, **kwargs) - return tuple([Array(r) for r in res]) if isinstance(res, tuple) else Array(res) + # linear algebra + 'dot', 'vdot', 'inner', 'outer', 'kron', 'matmul', 'trace', + # data types + 'dtype', 'finfo', 'iinfo', 'uint8', 'uint16', 'uint32', 'uint64', + 'int8', 'int16', 'int32', 'int64', 'float16', 'float32', + 'float64', 'complex64', 'complex128', -@wraps(jnp.flatnonzero) -def flatnonzero(x, **kwargs): - x = _as_jax_array_(x) - return Array(jnp.flatnonzero(x, **kwargs)) + # more + 'product', 'row_stack', 'apply_over_axes', 'apply_along_axis', 'array_equiv', + 'array_repr', 'array_str', 'block', 'broadcast_arrays', 'broadcast_shapes', + 'broadcast_to', 'compress', 'cumproduct', 'diag_indices', 'diag_indices_from', + 'diagflat', 'diagonal', 'einsum', 'einsum_path', 'geomspace', 'gradient', + 'histogram2d', 'histogram_bin_edges', 'histogramdd', 'i0', 'in1d', 'indices', + 'insert', 'intersect1d', 'iscomplex', 'isin', 'ix_', 'lexsort', 'load', + 'save', 'savez', 'mask_indices', 'msort', 'nan_to_num', 'nanargmax', 'setdiff1d', + 'nanargmin', 'pad', 'poly', 'polyadd', 'polyder', 'polyfit', 'polyint', + 'polymul', 'polysub', 'polyval', 'resize', 'rollaxis', 'roots', 'rot90', + 'setxor1d', 'tensordot', 'trim_zeros', 'union1d', 'unravel_index', 'unwrap', + 'take_along_axis', 'can_cast', 'choose', 'copy', 'frombuffer', 'fromfile', + 'fromfunction', 'fromiter', 'fromstring', 'get_printoptions', 'iscomplexobj', + 'isneginf', 'isposinf', 'isrealobj', 'issubdtype', 'issubsctype', 'iterable', + 'packbits', 'piecewise', 'printoptions', 'set_printoptions', 'promote_types', + 'ravel_multi_index', 'result_type', 'sort_complex', 'unpackbits', 'delete', + # unique + 'add_docstring', 'add_newdoc', 'add_newdoc_ufunc', 'array2string', 'asanyarray', + 'ascontiguousarray', 'asfarray', 'asscalar', 'common_type', 'disp', 'genfromtxt', + 'loadtxt', 'info', 'issubclass_', 'place', 'polydiv', 'put', 'putmask', 'safe_eval', + 'savetxt', 'savez_compressed', 'show_config', 'typename', 'copyto', 'matrix', 'asmatrix', 'mat', -@wraps(jnp.where) -def where(condition, x=None, y=None, **kwargs): - condition = _as_jax_array_(condition) - x = _as_jax_array_(x) - y = _as_jax_array_(y) - res = jnp.where(condition, x=x, y=y, **kwargs) - if isinstance(res, tuple): - return tuple(Array(r) for r in res) - else: - return Array(res) +] +_min = min +_max 
= max -@wraps(jnp.searchsorted) -def searchsorted(a, v, side='left', sorter=None): - a = _as_jax_array_(a) - v = _as_jax_array_(v) - return Array(jnp.searchsorted(a, v, side=side, sorter=sorter)) +def asanyarray(a, dtype=None, order=None): + return asarray(a, dtype=dtype, order=order) -@wraps(jnp.extract) -def extract(condition, arr): - condition = _as_jax_array_(condition) - arr = _as_jax_array_(arr) - return Array(jnp.extract(condition, arr)) +def ascontiguousarray(a, dtype=None, order=None): + return asarray(a, dtype=dtype, order=order) -@wraps(jnp.count_nonzero) -def count_nonzero(a, axis=None, keepdims=False): - a = _as_jax_array_(a) - return jnp.count_nonzero(a, axis=axis, keepdims=keepdims) +def asfarray(a, dtype=np.float_): + if not np.issubdtype(dtype, np.inexact): + dtype = np.float_ + return asarray(a, dtype=dtype) -@wraps(jnp.max) -def max(a, axis=None, out=None, keepdims=None, initial=None, where=None): - a = _as_jax_array_(a) - r = jnp.max(a, axis=axis, out=out, keepdims=keepdims, initial=initial, where=where) - return r if axis is None else Array(r) +# Others +# ------ +meshgrid = _compatible_with_brainpy_array(jnp.meshgrid) +vander = _compatible_with_brainpy_array(jnp.vander) +full = _compatible_with_brainpy_array(jnp.full) +full_like = _compatible_with_brainpy_array(jnp.full_like) +eye = _compatible_with_brainpy_array(jnp.eye) +identity = _compatible_with_brainpy_array(jnp.identity) +diag = _compatible_with_brainpy_array(jnp.diag) +tri = _compatible_with_brainpy_array(jnp.tri) +tril = _compatible_with_brainpy_array(jnp.tril) +triu = _compatible_with_brainpy_array(jnp.triu) +delete = _compatible_with_brainpy_array(jnp.delete) +take_along_axis = _compatible_with_brainpy_array(jnp.take_along_axis) +block = _compatible_with_brainpy_array(jnp.block) +broadcast_arrays = _compatible_with_brainpy_array(jnp.broadcast_arrays) +broadcast_shapes = _compatible_with_brainpy_array(jnp.broadcast_shapes) +broadcast_to = _compatible_with_brainpy_array(jnp.broadcast_to) +compress = _compatible_with_brainpy_array(jnp.compress) +diag_indices = _compatible_with_brainpy_array(jnp.diag_indices) +diag_indices_from = _compatible_with_brainpy_array(jnp.diag_indices_from) +diagflat = _compatible_with_brainpy_array(jnp.diagflat) +diagonal = _compatible_with_brainpy_array(jnp.diagonal) +einsum = _compatible_with_brainpy_array(jnp.einsum) +einsum_path = _compatible_with_brainpy_array(jnp.einsum_path) +geomspace = _compatible_with_brainpy_array(jnp.geomspace) +gradient = _compatible_with_brainpy_array(jnp.gradient) +histogram2d = _compatible_with_brainpy_array(jnp.histogram2d) +histogram_bin_edges = _compatible_with_brainpy_array(jnp.histogram_bin_edges) +histogramdd = _compatible_with_brainpy_array(jnp.histogramdd) +i0 = _compatible_with_brainpy_array(jnp.i0) +in1d = _compatible_with_brainpy_array(jnp.in1d) +indices = _compatible_with_brainpy_array(jnp.indices) +insert = _compatible_with_brainpy_array(jnp.insert) +intersect1d = _compatible_with_brainpy_array(jnp.intersect1d) +iscomplex = _compatible_with_brainpy_array(jnp.iscomplex) +isin = _compatible_with_brainpy_array(jnp.isin) +ix_ = _compatible_with_brainpy_array(jnp.ix_) +lexsort = _compatible_with_brainpy_array(jnp.lexsort) +load = _compatible_with_brainpy_array(jnp.load) +save = _compatible_with_brainpy_array(jnp.save) +savez = _compatible_with_brainpy_array(jnp.savez) +mask_indices = _compatible_with_brainpy_array(jnp.mask_indices) +msort = _compatible_with_brainpy_array(jnp.msort) +nan_to_num = _compatible_with_brainpy_array(jnp.nan_to_num) 
+nanargmax = _compatible_with_brainpy_array(jnp.nanargmax) +nanargmin = _compatible_with_brainpy_array(jnp.nanargmin) +pad = _compatible_with_brainpy_array(jnp.pad) +poly = _compatible_with_brainpy_array(jnp.poly) +polyadd = _compatible_with_brainpy_array(jnp.polyadd) +polyder = _compatible_with_brainpy_array(jnp.polyder) +polyfit = _compatible_with_brainpy_array(jnp.polyfit) +polyint = _compatible_with_brainpy_array(jnp.polyint) +polymul = _compatible_with_brainpy_array(jnp.polymul) +polysub = _compatible_with_brainpy_array(jnp.polysub) +polyval = _compatible_with_brainpy_array(jnp.polyval) +resize = _compatible_with_brainpy_array(jnp.resize) +rollaxis = _compatible_with_brainpy_array(jnp.rollaxis) +roots = _compatible_with_brainpy_array(jnp.roots) +rot90 = _compatible_with_brainpy_array(jnp.rot90) +setdiff1d = _compatible_with_brainpy_array(jnp.setdiff1d) +setxor1d = _compatible_with_brainpy_array(jnp.setxor1d) +tensordot = _compatible_with_brainpy_array(jnp.tensordot) +trim_zeros = _compatible_with_brainpy_array(jnp.trim_zeros) +union1d = _compatible_with_brainpy_array(jnp.union1d) +unravel_index = _compatible_with_brainpy_array(jnp.unravel_index) +unwrap = _compatible_with_brainpy_array(jnp.unwrap) -@wraps(jnp.min) -def min(a, axis=None, out=None, keepdims=None, initial=None, where=None): - a = _as_jax_array_(a) - r = jnp.min(a, axis=axis, out=out, keepdims=keepdims, initial=initial, where=where) - return r if axis is None else Array(r) +# math funcs +# ---------- +isreal = _compatible_with_brainpy_array(jnp.isreal) +isscalar = _compatible_with_brainpy_array(jnp.isscalar) +real = _compatible_with_brainpy_array(jnp.real) +imag = _compatible_with_brainpy_array(jnp.imag) +conj = _compatible_with_brainpy_array(jnp.conj) +conjugate = _compatible_with_brainpy_array(jnp.conjugate) +ndim = _compatible_with_brainpy_array(jnp.ndim) +add = _compatible_with_brainpy_array(jnp.add) +reciprocal = _compatible_with_brainpy_array(jnp.reciprocal) +negative = _compatible_with_brainpy_array(jnp.negative) +positive = _compatible_with_brainpy_array(jnp.positive) +multiply = _compatible_with_brainpy_array(jnp.multiply) +divide = _compatible_with_brainpy_array(jnp.divide) +power = _compatible_with_brainpy_array(jnp.power) +subtract = _compatible_with_brainpy_array(jnp.subtract) +true_divide = _compatible_with_brainpy_array(jnp.true_divide) +floor_divide = _compatible_with_brainpy_array(jnp.floor_divide) +float_power = _compatible_with_brainpy_array(jnp.float_power) +fmod = _compatible_with_brainpy_array(jnp.fmod) +mod = _compatible_with_brainpy_array(jnp.mod) +divmod = _compatible_with_brainpy_array(jnp.divmod) +remainder = _compatible_with_brainpy_array(jnp.remainder) +modf = _compatible_with_brainpy_array(jnp.modf) +abs = _compatible_with_brainpy_array(jnp.abs) +absolute = _compatible_with_brainpy_array(jnp.absolute) +exp = _compatible_with_brainpy_array(jnp.exp) +exp2 = _compatible_with_brainpy_array(jnp.exp2) +expm1 = _compatible_with_brainpy_array(jnp.expm1) +log = _compatible_with_brainpy_array(jnp.log) +log10 = _compatible_with_brainpy_array(jnp.log10) +log1p = _compatible_with_brainpy_array(jnp.log1p) +log2 = _compatible_with_brainpy_array(jnp.log2) +logaddexp = _compatible_with_brainpy_array(jnp.logaddexp) +logaddexp2 = _compatible_with_brainpy_array(jnp.logaddexp2) +lcm = _compatible_with_brainpy_array(jnp.lcm) +gcd = _compatible_with_brainpy_array(jnp.gcd) +arccos = _compatible_with_brainpy_array(jnp.arccos) +arccosh = _compatible_with_brainpy_array(jnp.arccosh) +arcsin = 
_compatible_with_brainpy_array(jnp.arcsin) +arcsinh = _compatible_with_brainpy_array(jnp.arcsinh) +arctan = _compatible_with_brainpy_array(jnp.arctan) +arctan2 = _compatible_with_brainpy_array(jnp.arctan2) +arctanh = _compatible_with_brainpy_array(jnp.arctanh) +cos = _compatible_with_brainpy_array(jnp.cos) +cosh = _compatible_with_brainpy_array(jnp.cosh) +sin = _compatible_with_brainpy_array(jnp.sin) +sinc = _compatible_with_brainpy_array(jnp.sinc) +sinh = _compatible_with_brainpy_array(jnp.sinh) +tan = _compatible_with_brainpy_array(jnp.tan) +tanh = _compatible_with_brainpy_array(jnp.tanh) +deg2rad = _compatible_with_brainpy_array(jnp.deg2rad) +rad2deg = _compatible_with_brainpy_array(jnp.rad2deg) +degrees = _compatible_with_brainpy_array(jnp.degrees) +radians = _compatible_with_brainpy_array(jnp.radians) +hypot = _compatible_with_brainpy_array(jnp.hypot) +round = _compatible_with_brainpy_array(jnp.round) +around = round +round_ = round +rint = _compatible_with_brainpy_array(jnp.rint) +floor = _compatible_with_brainpy_array(jnp.floor) +ceil = _compatible_with_brainpy_array(jnp.ceil) +trunc = _compatible_with_brainpy_array(jnp.trunc) +fix = _compatible_with_brainpy_array(jnp.fix) +prod = _compatible_with_brainpy_array(jnp.prod) +sum = _compatible_with_brainpy_array(jnp.sum) +diff = _compatible_with_brainpy_array(jnp.diff) +median = _compatible_with_brainpy_array(jnp.median) +nancumprod = _compatible_with_brainpy_array(jnp.nancumprod) +nancumsum = _compatible_with_brainpy_array(jnp.nancumsum) +cumprod = _compatible_with_brainpy_array(jnp.cumprod) +cumproduct = cumprod +cumsum = _compatible_with_brainpy_array(jnp.cumsum) +nanprod = _compatible_with_brainpy_array(jnp.nanprod) +nansum = _compatible_with_brainpy_array(jnp.nansum) +ediff1d = _compatible_with_brainpy_array(jnp.ediff1d) +cross = _compatible_with_brainpy_array(jnp.cross) +trapz = _compatible_with_brainpy_array(jnp.trapz) +isfinite = _compatible_with_brainpy_array(jnp.isfinite) +isinf = _compatible_with_brainpy_array(jnp.isinf) +isnan = _compatible_with_brainpy_array(jnp.isnan) +signbit = _compatible_with_brainpy_array(jnp.signbit) +nextafter = _compatible_with_brainpy_array(jnp.nextafter) +copysign = _compatible_with_brainpy_array(jnp.copysign) +ldexp = _compatible_with_brainpy_array(jnp.ldexp) +frexp = _compatible_with_brainpy_array(jnp.frexp) +convolve = _compatible_with_brainpy_array(jnp.convolve) +sqrt = _compatible_with_brainpy_array(jnp.sqrt) +cbrt = _compatible_with_brainpy_array(jnp.cbrt) +square = _compatible_with_brainpy_array(jnp.square) +fabs = _compatible_with_brainpy_array(jnp.fabs) +sign = _compatible_with_brainpy_array(jnp.sign) +heaviside = _compatible_with_brainpy_array(jnp.heaviside) +maximum = _compatible_with_brainpy_array(jnp.maximum) +minimum = _compatible_with_brainpy_array(jnp.minimum) +fmax = _compatible_with_brainpy_array(jnp.fmax) +fmin = _compatible_with_brainpy_array(jnp.fmin) +interp = _compatible_with_brainpy_array(jnp.interp) +clip = _compatible_with_brainpy_array(jnp.clip) +angle = _compatible_with_brainpy_array(jnp.angle) +bitwise_not = _compatible_with_brainpy_array(jnp.bitwise_not) +invert = _compatible_with_brainpy_array(jnp.invert) +bitwise_and = _compatible_with_brainpy_array(jnp.bitwise_and) +bitwise_or = _compatible_with_brainpy_array(jnp.bitwise_or) +bitwise_xor = _compatible_with_brainpy_array(jnp.bitwise_xor) +left_shift = _compatible_with_brainpy_array(jnp.left_shift) +right_shift = _compatible_with_brainpy_array(jnp.right_shift) +equal = _compatible_with_brainpy_array(jnp.equal) 
+not_equal = _compatible_with_brainpy_array(jnp.not_equal) +greater = _compatible_with_brainpy_array(jnp.greater) +greater_equal = _compatible_with_brainpy_array(jnp.greater_equal) +less = _compatible_with_brainpy_array(jnp.less) +less_equal = _compatible_with_brainpy_array(jnp.less_equal) +array_equal = _compatible_with_brainpy_array(jnp.array_equal) +isclose = _compatible_with_brainpy_array(jnp.isclose) +allclose = _compatible_with_brainpy_array(jnp.allclose) +logical_not = _compatible_with_brainpy_array(jnp.logical_not) +logical_and = _compatible_with_brainpy_array(jnp.logical_and) +logical_or = _compatible_with_brainpy_array(jnp.logical_or) +logical_xor = _compatible_with_brainpy_array(jnp.logical_xor) +all = _compatible_with_brainpy_array(jnp.all) +any = _compatible_with_brainpy_array(jnp.any) +alltrue = all +sometrue = any +# array manipulation +# ------------------ +shape = _compatible_with_brainpy_array(jnp.shape) +size = _compatible_with_brainpy_array(jnp.size) +reshape = _compatible_with_brainpy_array(jnp.reshape) +ravel = _compatible_with_brainpy_array(jnp.ravel) +moveaxis = _compatible_with_brainpy_array(jnp.moveaxis) +transpose = _compatible_with_brainpy_array(jnp.transpose) +swapaxes = _compatible_with_brainpy_array(jnp.swapaxes) +concatenate = _compatible_with_brainpy_array(jnp.concatenate) +stack = _compatible_with_brainpy_array(jnp.stack) +vstack = _compatible_with_brainpy_array(jnp.vstack) +product = prod +row_stack = vstack +hstack = _compatible_with_brainpy_array(jnp.hstack) +dstack = _compatible_with_brainpy_array(jnp.dstack) +column_stack = _compatible_with_brainpy_array(jnp.column_stack) +split = _compatible_with_brainpy_array(jnp.split) +dsplit = _compatible_with_brainpy_array(jnp.dsplit) +hsplit = _compatible_with_brainpy_array(jnp.hsplit) +vsplit = _compatible_with_brainpy_array(jnp.vsplit) +tile = _compatible_with_brainpy_array(jnp.tile) +repeat = _compatible_with_brainpy_array(jnp.repeat) +unique = _compatible_with_brainpy_array(jnp.unique) +append = _compatible_with_brainpy_array(jnp.append) +flip = _compatible_with_brainpy_array(jnp.flip) +fliplr = _compatible_with_brainpy_array(jnp.fliplr) +flipud = _compatible_with_brainpy_array(jnp.flipud) +roll = _compatible_with_brainpy_array(jnp.roll) +atleast_1d = _compatible_with_brainpy_array(jnp.atleast_1d) +atleast_2d = _compatible_with_brainpy_array(jnp.atleast_2d) +atleast_3d = _compatible_with_brainpy_array(jnp.atleast_3d) +expand_dims = _compatible_with_brainpy_array(jnp.expand_dims) +squeeze = _compatible_with_brainpy_array(jnp.squeeze) +sort = _compatible_with_brainpy_array(jnp.sort) +argsort = _compatible_with_brainpy_array(jnp.argsort) +argmax = _compatible_with_brainpy_array(jnp.argmax) +argmin = _compatible_with_brainpy_array(jnp.argmin) +argwhere = _compatible_with_brainpy_array(jnp.argwhere) +nonzero = _compatible_with_brainpy_array(jnp.nonzero) +flatnonzero = _compatible_with_brainpy_array(jnp.flatnonzero) +where = _compatible_with_brainpy_array(jnp.where) +searchsorted = _compatible_with_brainpy_array(jnp.searchsorted) +extract = _compatible_with_brainpy_array(jnp.extract) +count_nonzero = _compatible_with_brainpy_array(jnp.count_nonzero) +max = _compatible_with_brainpy_array(jnp.max) +min = _compatible_with_brainpy_array(jnp.min) amax = max amin = min - - -@wraps(jnp.apply_along_axis) -def apply_along_axis(func1d, axis: int, arr, *args, **kwargs): - arr = _as_jax_array_(arr) - return jnp.apply_along_axis(func1d, axis, arr, *args, **kwargs) - - -@wraps(jnp.apply_over_axes) -def apply_over_axes(func, a, 
axes): - a = _as_jax_array_(a) - return jnp.apply_over_axes(func, a, axes) - - -@wraps(jnp.array_equiv) -def array_equiv(a1, a2): - try: - a1, a2 = asarray(a1), asarray(a2) - except Exception: - return False - try: - eq = equal(a1, a2) - except ValueError: - # shapes are not broadcastable - return False - return all(eq) - - -@wraps(jnp.array_repr) -def array_repr(arr, max_line_width=None, precision=None, suppress_small=None): - arr = _as_jax_array_(arr) - return jnp.array_repr(arr, max_line_width=max_line_width, precision=precision, suppress_small=suppress_small) - - -@wraps(jnp.array_str) -def array_str(a, max_line_width=None, precision=None, suppress_small=None): - a = _as_jax_array_(a) - return jnp.array_str(a, max_line_width=max_line_width, precision=precision, suppress_small=suppress_small) - - -@wraps(jnp.array_split) -def array_split(ary, indices_or_sections, axis: int = 0): - ary = _as_jax_array_(ary) - if isinstance(indices_or_sections, Array): - indices_or_sections = indices_or_sections.value - elif isinstance(indices_or_sections, (tuple, list)): - indices_or_sections = [_as_jax_array_(i) for i in indices_or_sections] - return tuple([Array(a) for a in jnp.array_split(ary, indices_or_sections, axis)]) - - -@wraps(jnp.meshgrid) -def meshgrid(*xi, copy=True, sparse=False, indexing='xy'): - xi = [_as_jax_array_(x) for x in xi] - rr = jnp.meshgrid(*xi, copy=copy, sparse=sparse, indexing=indexing) - return list(Array(r) for r in rr) - - -@wraps(jnp.vander) -def vander(x, N=None, increasing=False): - x = _as_jax_array_(x) - return Array(jnp.vander(x, N=N, increasing=increasing)) - - +apply_along_axis = _compatible_with_brainpy_array(jnp.apply_along_axis) +apply_over_axes = _compatible_with_brainpy_array(jnp.apply_over_axes) +array_equiv = _compatible_with_brainpy_array(jnp.array_equiv) +array_repr = _compatible_with_brainpy_array(jnp.array_repr) +array_str = _compatible_with_brainpy_array(jnp.array_str) +array_split = _compatible_with_brainpy_array(jnp.array_split) # indexing funcs # -------------- tril_indices = jnp.tril_indices triu_indices = jnp.triu_indices - - -@wraps(jnp.tril_indices_from) -def tril_indices_from(x, k=0): - x = _as_jax_array_(x) - res = jnp.tril_indices_from(x, k=k) - if isinstance(res, tuple): - return tuple(Array(r) for r in res) - else: - return Array(res) - - -@wraps(jnp.triu_indices_from) -def triu_indices_from(x, k=0): - x = _as_jax_array_(x) - res = jnp.triu_indices_from(x, k=k) - if isinstance(res, tuple): - return tuple(Array(r) for r in res) - else: - return Array(res) - - -@wraps(jnp.take) -def take(x, indices, axis=None, mode=None): - x = _as_jax_array_(x) - indices = _as_jax_array_(indices) - return Array(jnp.take(x, indices=indices, axis=axis, mode=mode)) - - -@wraps(jnp.select) -def select(condlist, choicelist, default=0): - condlist = [_as_jax_array_(c) for c in condlist] - choicelist = [_as_jax_array_(c) for c in choicelist] - return Array(jnp.select(condlist, choicelist, default=default)) - - -# statistic funcs -# --------------- -@wraps(jnp.nanmin) -def nanmin(x, axis=None, keepdims=None, **kwargs): - x = _as_jax_array_(x) - r = jnp.nanmin(x, axis=axis, keepdims=keepdims, **kwargs) - return r if axis is None else Array(r) - - -@wraps(jnp.nanmax) -def nanmax(x, axis=None, keepdims=None, **kwargs): - x = _as_jax_array_(x) - r = jnp.nanmax(x, axis=axis, keepdims=keepdims, **kwargs) - return r if axis is None else Array(r) - - -@wraps(jnp.ptp) -def ptp(x, axis=None, keepdims=None): - x = _as_jax_array_(x) - r = jnp.ptp(x, axis=axis, 
keepdims=keepdims) - return r if axis is None else Array(r) - - -@wraps(jnp.percentile) -def percentile(a, - q, - axis=None, - out=None, - overwrite_input: bool = False, - method: str = "linear", - keepdims: bool = False, - interpolation=None): - a = _as_jax_array_(a) - q = _as_jax_array_(q) - r = jnp.percentile(a=a, - q=q, - axis=axis, - out=out, - overwrite_input=overwrite_input, - method=method, - keepdims=keepdims, - interpolation=interpolation) - return r if axis is None else Array(r) - - -@wraps(jnp.nanpercentile) -def nanpercentile(a, - q, - axis=None, - out=None, - overwrite_input: bool = False, - method: str = "linear", - keepdims: bool = False, - interpolation=None): - a = _as_jax_array_(a) - q = _as_jax_array_(q) - r = jnp.nanpercentile(a=a, - q=q, - axis=axis, - out=out, - overwrite_input=overwrite_input, - method=method, - keepdims=keepdims, - interpolation=interpolation) - return r if axis is None else Array(r) - - -@wraps(jnp.quantile) -def quantile(a, q, axis=None, out=None, overwrite_input: bool = False, method: str = "linear", - keepdims: bool = False, - interpolation=None): - a = _as_jax_array_(a) - q = _as_jax_array_(q) - r = jnp.quantile(a=a, q=q, axis=axis, out=out, overwrite_input=overwrite_input, method=method, keepdims=keepdims, - interpolation=interpolation) - return r if axis is None else Array(r) - - -@wraps(jnp.nanquantile) -def nanquantile(a, q, axis=None, out=None, overwrite_input: bool = False, method: str = "linear", - keepdims: bool = False, - interpolation=None): - a = _as_jax_array_(a) - q = _as_jax_array_(q) - r = jnp.nanquantile(a=a, q=q, axis=axis, out=out, overwrite_input=overwrite_input, method=method, keepdims=keepdims, - interpolation=interpolation) - return r if axis is None else Array(r) - - -@wraps(jnp.average) -def average(a, axis=None, weights=None, returned=False): - a = _as_jax_array_(a) - weights = _as_jax_array_(weights) - r = jnp.average(a, axis=axis, weights=weights, returned=returned) - if axis is None: - return r - elif isinstance(r, tuple): - return tuple(Array(_r) for _r in r) - else: - return Array(r) - - -@wraps(jnp.mean) -def mean(a, axis=None, dtype=None, keepdims=None, where=None): - a = _as_jax_array_(a) - r = jnp.mean(a, axis=axis, dtype=dtype, keepdims=keepdims, where=where) - return r if axis is None else Array(r) - - -@wraps(jnp.std) -def std(a, axis=None, dtype=None, ddof=0, keepdims=None, where=None): - a = _as_jax_array_(a) - r = jnp.std(a=a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, where=where) - return r if axis is None else Array(r) - - -@wraps(jnp.var) -def var(a, axis=None, dtype=None, ddof=0, keepdims=None, where=None): - a = _as_jax_array_(a) - r = jnp.var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, where=where) - return r if axis is None else Array(r) - - -@wraps(jnp.nanmedian) -def nanmedian(a, axis=None, keepdims=False): - return nanquantile(a, 0.5, axis=axis, keepdims=keepdims, interpolation='midpoint') - - -@wraps(jnp.nanmean) -def nanmean(a, axis=None, dtype=None, keepdims=None, **kwargs): - a = _as_jax_array_(a) - r = jnp.nanmean(a, axis=axis, dtype=dtype, keepdims=keepdims, **kwargs) - return r if axis is None else Array(r) - - -@wraps(jnp.nanstd) -def nanstd(a, axis=None, dtype=None, ddof=0, keepdims=None, where=None): - a = _as_jax_array_(a) - r = jnp.nanstd(a=a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, where=where) - return r if axis is None else Array(r) - - -@wraps(jnp.nanvar) -def nanvar(a, axis=None, dtype=None, ddof=0, keepdims=None, where=None): - a = 
_as_jax_array_(a) - r = jnp.nanvar(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, where=where) - return r if axis is None else Array(r) - - -@wraps(jnp.corrcoef) -def corrcoef(x, y=None, rowvar=True): - x = _as_jax_array_(x) - y = _as_jax_array_(y) - return Array(jnp.corrcoef(x, y, rowvar)) - - -@wraps(jnp.correlate) -def correlate(a, v, mode='valid', **kwargs): - a = _as_jax_array_(a) - v = _as_jax_array_(v) - return Array(jnp.correlate(a, v, mode, **kwargs)) - - -@wraps(jnp.cov) -def cov(m, y=None, rowvar=True, bias=False, ddof=None, fweights=None, aweights=None): - m = _as_jax_array_(m) - y = _as_jax_array_(y) - fweights = _as_jax_array_(fweights) - aweights = _as_jax_array_(aweights) - return Array(jnp.cov(m, y=y, rowvar=rowvar, bias=bias, ddof=ddof, - fweights=fweights, aweights=aweights)) - - -@wraps(jnp.histogram) -def histogram(a, bins=10, range=None, weights=None, density=None): - a = _as_jax_array_(a) - weights = _as_jax_array_(weights) - hist, bin_edges = jnp.histogram(a=a, bins=bins, range=range, weights=weights, density=density) - return Array(hist), Array(bin_edges) - - -@wraps(jnp.bincount) -def bincount(x, weights=None, minlength=0, length=None, **kwargs): - x = _as_jax_array_(x) - weights = _as_jax_array_(weights) - res = jnp.bincount(x, weights=weights, minlength=minlength, length=length, **kwargs) - return Array(res) - - -@wraps(jnp.digitize) -def digitize(x, bins, right=False): - x = _as_jax_array_(x) - bins = _as_jax_array_(bins) - return Array(jnp.digitize(x, bins=bins, right=right)) - - -@wraps(jnp.bartlett) -def bartlett(M): - return Array(jnp.bartlett(M)) - - -@wraps(jnp.blackman) -def blackman(M): - return Array(jnp.blackman(M)) - - -@wraps(jnp.hamming) -def hamming(M): - return Array(jnp.hamming(M)) - - -@wraps(jnp.hanning) -def hanning(M): - return Array(jnp.hanning(M)) - - -@wraps(jnp.kaiser) -def kaiser(M, beta): - return Array(jnp.kaiser(M, beta)) - +tril_indices_from = _compatible_with_brainpy_array(jnp.tril_indices_from) +triu_indices_from = _compatible_with_brainpy_array(jnp.triu_indices_from) +take = _compatible_with_brainpy_array(jnp.take) +select = _compatible_with_brainpy_array(jnp.select) +nanmin = _compatible_with_brainpy_array(jnp.nanmin) +nanmax = _compatible_with_brainpy_array(jnp.nanmax) +ptp = _compatible_with_brainpy_array(jnp.ptp) +percentile = _compatible_with_brainpy_array(jnp.percentile) +nanpercentile = _compatible_with_brainpy_array(jnp.nanpercentile) +quantile = _compatible_with_brainpy_array(jnp.quantile) +nanquantile = _compatible_with_brainpy_array(jnp.nanquantile) +average = _compatible_with_brainpy_array(jnp.average) +mean = _compatible_with_brainpy_array(jnp.mean) +std = _compatible_with_brainpy_array(jnp.std) +var = _compatible_with_brainpy_array(jnp.var) +nanmedian = _compatible_with_brainpy_array(jnp.nanmedian) +nanmean = _compatible_with_brainpy_array(jnp.nanmean) +nanstd = _compatible_with_brainpy_array(jnp.nanstd) +nanvar = _compatible_with_brainpy_array(jnp.nanvar) +corrcoef = _compatible_with_brainpy_array(jnp.corrcoef) +correlate = _compatible_with_brainpy_array(jnp.correlate) +cov = _compatible_with_brainpy_array(jnp.cov) +histogram = _compatible_with_brainpy_array(jnp.histogram) +bincount = _compatible_with_brainpy_array(jnp.bincount) +digitize = _compatible_with_brainpy_array(jnp.digitize) +bartlett = _compatible_with_brainpy_array(jnp.bartlett) +blackman = _compatible_with_brainpy_array(jnp.blackman) +hamming = _compatible_with_brainpy_array(jnp.hamming) +hanning = _compatible_with_brainpy_array(jnp.hanning) 
+kaiser = _compatible_with_brainpy_array(jnp.kaiser) # constants # --------- @@ -1951,58 +409,16 @@ def kaiser(M, beta): pi = jnp.pi inf = jnp.inf - # linear algebra # -------------- - -@wraps(jnp.dot) -def dot(x1, x2, **kwargs): - x1 = _as_jax_array_(x1) - x2 = _as_jax_array_(x2) - return Array(jnp.dot(x1, x2, **kwargs)) - - -@wraps(jnp.vdot) -def vdot(x1, x2, **kwargs): - x1 = _as_jax_array_(x1) - x2 = _as_jax_array_(x2) - return Array(jnp.vdot(x1, x2, **kwargs)) - - -@wraps(jnp.inner) -def inner(x1, x2, **kwargs): - x1 = _as_jax_array_(x1) - x2 = _as_jax_array_(x2) - return Array(jnp.inner(x1, x2, **kwargs)) - - -@wraps(jnp.outer) -def outer(x1, x2): - x1 = _as_jax_array_(x1) - x2 = _as_jax_array_(x2) - return Array(jnp.outer(x1, x2)) - - -@wraps(jnp.kron) -def kron(x1, x2): - x1 = _as_jax_array_(x1) - x2 = _as_jax_array_(x2) - return Array(jnp.kron(x1, x2)) - - -@wraps(jnp.matmul) -def matmul(x1, x2, **kwargs): - x1 = _as_jax_array_(x1) - x2 = _as_jax_array_(x2) - return Array(jnp.matmul(x1, x2, **kwargs)) - - -@wraps(jnp.trace) -def trace(x, offset=0, axis1=0, axis2=1, dtype=None): - x = _as_jax_array_(x) - return Array(jnp.trace(x, offset=offset, axis1=axis1, axis2=axis2, dtype=dtype)) - +dot = _compatible_with_brainpy_array(jnp.dot) +vdot = _compatible_with_brainpy_array(jnp.vdot) +inner = _compatible_with_brainpy_array(jnp.inner) +outer = _compatible_with_brainpy_array(jnp.outer) +kron = _compatible_with_brainpy_array(jnp.kron) +matmul = _compatible_with_brainpy_array(jnp.matmul) +trace = _compatible_with_brainpy_array(jnp.trace) # data types # ---------- @@ -2025,150 +441,31 @@ def trace(x, offset=0, axis1=0, axis2=1, dtype=None): complex64 = jnp.complex64 complex128 = jnp.complex128 - -# -@wraps(jnp.can_cast) -def can_cast(from_, to, casting=None): - """ can_cast(from_, to, casting='safe') - - Returns True if cast between data types can occur according to the - casting rule. If from is a scalar or array scalar, also returns - True if the scalar value can be cast without overflow or truncation - to an integer. - - Parameters - ---------- - from_ : dtype, dtype specifier, scalar, or array - Data type, scalar, or array to cast from. - to : dtype or dtype specifier - Data type to cast to. - casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional - Controls what kind of data casting may occur. - - * 'no' means the data types should not be cast at all. - * 'equiv' means only byte-order changes are allowed. - * 'safe' means only casts which can preserve values are allowed. - * 'same_kind' means only safe casts or casts within a kind, - like float64 to float32, are allowed. - * 'unsafe' means any data conversions may be done. - - Returns - ------- - out : bool - True if cast can occur according to the casting rule. 
- - """ - from_ = _as_jax_array_(from_) - to = _as_jax_array_(to) - return jnp.can_cast(from_, to, casting=casting) - - -@wraps(jnp.choose) -def choose(a, choices, mode='raise'): - a = _as_jax_array_(a) - choices = [_as_jax_array_(c) for c in choices] - return jnp.choose(a, choices, mode=mode) - - -def copy(a, order=None): - return array(a, copy=True, order=order) - - -def frombuffer(buffer, dtype=float, count=-1, offset=0): - return asarray(np.frombuffer(buffer=buffer, dtype=dtype, count=count, offset=offset)) - - -def fromfile(file, dtype=None, count=-1, sep='', offset=0, *args, **kwargs): - return asarray(np.fromfile(file, dtype=dtype, count=count, sep=sep, offset=offset, *args, **kwargs)) - - -@wraps(jnp.fromfunction) -def fromfunction(function, shape, dtype=float, **kwargs): - return jnp.fromfunction(function, shape, dtype=dtype, **kwargs) - - -def fromiter(iterable, dtype, count=-1, *args, **kwargs): - iterable = _as_jax_array_(iterable) - return asarray(np.fromiter(iterable, dtype=dtype, count=count, *args, **kwargs)) - - -def fromstring(string, dtype=float, count=-1, *, sep): - return asarray(np.fromstring(string=string, dtype=dtype, count=count, sep=sep)) - - +can_cast = _compatible_with_brainpy_array(jnp.can_cast) +choose = _compatible_with_brainpy_array(jnp.choose) +copy = _compatible_with_brainpy_array(jnp.copy) +frombuffer = _compatible_with_brainpy_array(jnp.frombuffer) +fromfile = _compatible_with_brainpy_array(jnp.fromfile) +fromfunction = _compatible_with_brainpy_array(jnp.fromfunction) +fromiter = _compatible_with_brainpy_array(jnp.fromiter) +fromstring = _compatible_with_brainpy_array(jnp.fromstring) get_printoptions = np.get_printoptions - - -def iscomplexobj(x): - return np.iscomplexobj(_as_jax_array_(x)) - - -@wraps(jnp.isneginf) -def isneginf(x): - return Array(jnp.isneginf(_as_jax_array_(x))) - - -@wraps(jnp.isposinf) -def isposinf(x): - return Array(jnp.isposinf(_as_jax_array_(x))) - - -def isrealobj(x): - return not iscomplexobj(x) - - +iscomplexobj = _compatible_with_brainpy_array(jnp.iscomplexobj) +isneginf = _compatible_with_brainpy_array(jnp.isneginf) +isposinf = _compatible_with_brainpy_array(jnp.isposinf) +isrealobj = _compatible_with_brainpy_array(jnp.isrealobj) issubdtype = jnp.issubdtype issubsctype = jnp.issubsctype - - -def iterable(x): - return np.iterable(_as_jax_array_(x)) - - -@wraps(jnp.packbits) -def packbits(a, axis: Optional[int] = None, bitorder='big'): - return Array(jnp.packbits(_as_jax_array_(a), axis=axis, bitorder=bitorder)) - - -@wraps(jnp.piecewise) -def piecewise(x, condlist, funclist, *args, **kw): - condlist = asarray(condlist, dtype=bool) - return Array(jnp.piecewise(_as_jax_array_(x), condlist.value, funclist, *args, **kw)) - - +iterable = _compatible_with_brainpy_array(jnp.iterable) +packbits = _compatible_with_brainpy_array(jnp.packbits) +piecewise = _compatible_with_brainpy_array(jnp.piecewise) printoptions = np.printoptions set_printoptions = np.set_printoptions - - -@wraps(jnp.promote_types) -def promote_types(a, b): - a = _as_jax_array_(a) - b = _as_jax_array_(b) - return jnp.promote_types(a, b) - - -@wraps(jnp.ravel_multi_index) -def ravel_multi_index(multi_index, dims, mode='raise', order='C'): - multi_index = [_as_jax_array_(i) for i in multi_index] - return Array(jnp.ravel_multi_index(multi_index, dims, mode=mode, order=order)) - - -@wraps(jnp.result_type) -def result_type(*args): - args = [_as_jax_array_(a) for a in args] - return jnp.result_type(*args) - - -@wraps(jnp.sort_complex) -def sort_complex(a): - return 
Array(jnp.sort_complex(_as_jax_array_(a))) - - -@wraps(jnp.unpackbits) -def unpackbits(a, axis: Optional[int] = None, count=None, bitorder='big'): - a = _as_jax_array_(a) - return Array(jnp.unpackbits(a, axis, count=count, bitorder=bitorder)) - +promote_types = _compatible_with_brainpy_array(jnp.promote_types) +ravel_multi_index = _compatible_with_brainpy_array(jnp.ravel_multi_index) +result_type = _compatible_with_brainpy_array(jnp.result_type) +sort_complex = _compatible_with_brainpy_array(jnp.sort_complex) +unpackbits = _compatible_with_brainpy_array(jnp.unpackbits) # Unique APIs # ----------- @@ -2178,7 +475,6 @@ def unpackbits(a, axis: Optional[int] = None, count=None, bitorder='big'): add_newdoc_ufunc = np.add_newdoc_ufunc -@wraps(np.array2string) def array2string(a, max_line_width=None, precision=None, suppress_small=None, separator=' ', prefix="", style=np._NoValue, formatter=None, threshold=None, @@ -2192,23 +488,6 @@ def array2string(a, max_line_width=None, precision=None, legacy=legacy) -@wraps(np.asanyarray) -def asanyarray(a, dtype=None, order=None): - return asarray(a, dtype=dtype, order=order) - - -@wraps(np.ascontiguousarray) -def ascontiguousarray(a, dtype=None, order=None): - return asarray(a, dtype=dtype, order=order) - - -@wraps(np.asfarray) -def asfarray(a, dtype=np.float_): - if not np.issubdtype(dtype, np.inexact): - dtype = np.float_ - return asarray(a, dtype=dtype) - - def asscalar(a): return a.item() @@ -2224,7 +503,6 @@ def asscalar(a): np.clongdouble: 3} -@wraps(np.common_type) def common_type(*arrays): is_complex = False precision = 0 @@ -2254,116 +532,33 @@ def common_type(*arrays): issubclass_ = np.issubclass_ -@wraps(np.place) def place(arr, mask, vals): if not isinstance(arr, Array): - raise ValueError(f'Must be an instance of {Array.__name__}, but we got {type(arr)}') + raise ValueError(f'Must be an instance of brainpy Array, but we got {type(arr)}') arr[mask] = vals -@wraps(jnp.polydiv) -def polydiv(u, v, **kwargs): - u = _as_jax_array_(u) - v = _as_jax_array_(v) - res = jnp.polydiv(u, v, **kwargs) - if isinstance(res, tuple): - return tuple(Array(r) for r in res) - else: - return Array(res) - - -# @wraps(np.polydiv) -# def polydiv(u, v, **kwargs): -# """ -# Returns the quotient and remainder of polynomial division. -# -# .. note:: -# This forms part of the old polynomial API. Since version 1.4, the -# new polynomial API defined in `numpy.polynomial` is preferred. -# A summary of the differences can be found in the -# :doc:`transition guide `. -# -# The input arrays are the coefficients (including any coefficients -# equal to zero) of the "numerator" (dividend) and "denominator" -# (divisor) polynomials, respectively. -# -# Parameters -# ---------- -# u : array_like -# Dividend polynomial's coefficients. -# -# v : array_like -# Divisor polynomial's coefficients. -# -# Returns -# ------- -# q : ArrayType -# Coefficients, including those equal to zero, of the quotient. -# r : ArrayType -# Coefficients, including those equal to zero, of the remainder. -# -# See Also -# -------- -# poly, polyadd, polyder, polydiv, polyfit, polyint, polymul, polysub -# polyval -# -# Notes -# ----- -# Both `u` and `v` must be 0-d or 1-d (ndim = 0 or 1), but `u.ndim` need -# not equal `v.ndim`. In other words, all four possible combinations - -# ``u.ndim = v.ndim = 0``, ``u.ndim = v.ndim = 1``, -# ``u.ndim = 1, v.ndim = 0``, and ``u.ndim = 0, v.ndim = 1`` - work. -# -# Examples -# -------- -# .. 
math:: \\frac{3x^2 + 5x + 2}{2x + 1} = 1.5x + 1.75, remainder 0.25 -# -# >>> x = bm.array([3.0, 5.0, 2.0]) -# >>> y = bm.array([2.0, 1.0]) -# >>> bm.polydiv(x, y) -# (ArrayType([1.5 , 1.75]), ArrayType([0.25])) -# -# """ -# u = atleast_1d(u) + 0.0 -# v = atleast_1d(v) + 0.0 -# # w has the common type -# w = u[0] + v[0] -# m = len(u) - 1 -# n = len(v) - 1 -# scale = 1. / v[0] -# q = zeros((max(m - n + 1, 1),), w.dtype) -# r = u.astype(w.dtype) -# for k in range(0, m - n + 1): -# d = scale * r[k] -# q[k] = d -# r[k:k + n + 1] -= d * v -# while allclose(r[0], 0, rtol=1e-14) and (r.shape[-1] > 1): -# r = r[1:] -# return ArrayType(q), ArrayType(r) - - -@wraps(np.put) +polydiv = _compatible_with_brainpy_array(jnp.polydiv) + + def put(a, ind, v): if not isinstance(a, Array): - raise ValueError(f'Must be an instance of {Array.__name__}, but we got {type(a)}') + raise ValueError(f'Must be an instance of brainpy Array, but we got {type(a)}') a[ind] = v -@wraps(np.putmask) def putmask(a, mask, values): if not isinstance(a, Array): - raise ValueError(f'Must be an instance of {Array.__name__}, but we got {type(a)}') + raise ValueError(f'Must be an instance of brainpy Array, but we got {type(a)}') if a.shape != values.shape: raise ValueError('Only support the shapes of "a" and "values" are consistent.') a[mask] = values -@wraps(np.safe_eval) def safe_eval(source): return tree_map(Array, np.safe_eval(source)) -@wraps(np.savetxt) def savetxt(fname, X, fmt='%.18e', delimiter=' ', newline='\n', header='', footer='', comments='# ', encoding=None): X = as_numpy(X) @@ -2371,10 +566,10 @@ def savetxt(fname, X, fmt='%.18e', delimiter=' ', newline='\n', header='', footer=footer, comments=comments, encoding=encoding) -@wraps(np.savez_compressed) def savez_compressed(file, *args, **kwds): - args = tuple([as_numpy(a) for a in args]) - kwds = {k: as_numpy(v) for k, v in kwds.items()} + args = tuple([(as_numpy(a) if isinstance(a, (jnp.ndarray, Array)) else a) for a in args]) + kwds = {k: (as_numpy(v) if isinstance(v, (jnp.ndarray, Array)) else v) + for k, v in kwds.items()} np.savez_compressed(file, *args, **kwds) @@ -2382,14 +577,12 @@ def savez_compressed(file, *args, **kwds): typename = np.typename -@wraps(np.copyto) def copyto(dst, src): if not isinstance(dst, Array): raise ValueError('dst must be an instance of ArrayType.') dst[:] = src -@wraps(np.matrix) def matrix(data, dtype=None): data = array(data, copy=True, dtype=dtype) if data.ndim > 2: @@ -2400,7 +593,6 @@ def matrix(data, dtype=None): return data -@wraps(np.asmatrix) def asmatrix(data, dtype=None): data = array(data, dtype=dtype) if data.ndim > 2: @@ -2411,6 +603,5 @@ def asmatrix(data, dtype=None): return data -@wraps(np.mat) def mat(data, dtype=None): return asmatrix(data, dtype=dtype) diff --git a/brainpy/_src/math/arrayinterporate.py b/brainpy/_src/math/arrayinterporate.py index 74146309d..257ed9b8d 100644 --- a/brainpy/_src/math/arrayinterporate.py +++ b/brainpy/_src/math/arrayinterporate.py @@ -14,7 +14,6 @@ def _as_jax_array_(obj): return obj.value if isinstance(obj, Array) else obj - def as_device_array(tensor, dtype=None): """Convert the input to a ``jax.numpy.DeviceArray``. 
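[Note on the refactor above: the hunks in arrayoperation.py replace each hand-written `@wraps(jnp.*)` wrapper with a single call to `_compatible_with_brainpy_array(jnp.*)`. That helper lives in `brainpy/_src/math/_utils.py` and is not part of this patch, so the snippet below is only a rough sketch of what such a wrapper plausibly does; the stand-in `Array` class and the exact unwrap/rewrap behaviour are assumptions, not the library's actual implementation.]

    # Illustrative sketch only -- not the real brainpy helper.
    import functools
    import jax.numpy as jnp
    from jax.tree_util import tree_map

    class Array:  # stand-in for brainpy's Array wrapper
        def __init__(self, value):
            self.value = value

    def _compatible_with_brainpy_array(jnp_func):
        @functools.wraps(jnp_func)
        def wrapper(*args, **kwargs):
            # unwrap brainpy Arrays into raw jax arrays before calling jax.numpy
            unwrap = lambda x: x.value if isinstance(x, Array) else x
            args = tree_map(unwrap, args, is_leaf=lambda x: isinstance(x, Array))
            kwargs = tree_map(unwrap, kwargs, is_leaf=lambda x: isinstance(x, Array))
            out = jnp_func(*args, **kwargs)
            # rewrap jax array outputs into brainpy Arrays
            return tree_map(lambda x: Array(x) if isinstance(x, jnp.ndarray) else x, out)
        return wrapper

    # usage: one line per jax.numpy function instead of a dozen-line wrapper
    clip = _compatible_with_brainpy_array(jnp.clip)

If the real helper behaves roughly like this, each exported function becomes a one-line assignment, which is why the hunks above shrink arrayoperation.py so dramatically while keeping the Array-in/Array-out convention.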
diff --git a/brainpy/_src/math/ndarray.py b/brainpy/_src/math/ndarray.py index 6486cf27a..2c1ca960a 100644 --- a/brainpy/_src/math/ndarray.py +++ b/brainpy/_src/math/ndarray.py @@ -2,6 +2,7 @@ import warnings +import operator from typing import Optional, Tuple as TupleType import numpy as np @@ -937,6 +938,33 @@ def as_variable(self): """As an instance of Variable.""" return Variable(self) + def __format__(self, specification): + return self.value.__format__(specification) + + def __float__(self): + return self.value.__float__() + + def __int__(self): + return self.value.__int__() + + def __complex__(self): + return self.value.__complex__() + + def __hex__(self): + assert self.ndim == 0, 'hex only works on scalar values' + return hex(self._value) # type: ignore + + def __oct__(self): + assert self.ndim == 0, 'oct only works on scalar values' + return oct(self._value) # type: ignore + + def __index__(self): + return operator.index(self._value) + + def __dlpack__(self): + from jax.dlpack import to_dlpack # pylint: disable=g-import-not-at-top + return to_dlpack(self.value) + JaxArray = Array ndarray = Array diff --git a/brainpy/_src/math/random.py b/brainpy/_src/math/random.py index 8a976d643..9608ec951 100644 --- a/brainpy/_src/math/random.py +++ b/brainpy/_src/math/random.py @@ -14,7 +14,7 @@ from jax.tree_util import register_pytree_node from brainpy.check import jit_error_checking -from ._utils import wraps +from ._utils import _return from .ndarray import Array, Variable __all__ = [ @@ -517,7 +517,8 @@ def split_keys(self, n): def rand(self, *dn, key=None): key = self.split_key() if key is None else _formalize_key(key) - return jr.uniform(key, shape=dn, minval=0., maxval=1.) + r = jr.uniform(key, shape=dn, minval=0., maxval=1.) + return _return(r) def randint(self, low, high=None, size=None, dtype=jnp.int_, key=None): low = _as_jax_array(low) @@ -531,9 +532,10 @@ def randint(self, low, high=None, size=None, dtype=jnp.int_, key=None): size = lax.broadcast_shapes(jnp.shape(low), jnp.shape(high)) key = self.split_key() if key is None else _formalize_key(key) - return jr.randint(key, - shape=_size2shape(size), - minval=low, maxval=high, dtype=dtype) + r = jr.randint(key, + shape=_size2shape(size), + minval=low, maxval=high, dtype=dtype) + return _return(r) def random_integers(self, low, high=None, size=None, key=None): low = _as_jax_array(low) @@ -547,27 +549,33 @@ def random_integers(self, low, high=None, size=None, key=None): if size is None: size = lax.broadcast_shapes(jnp.shape(low), jnp.shape(high)) key = self.split_key() if key is None else _formalize_key(key) - return jr.randint(key, - shape=_size2shape(size), - minval=low, - maxval=high) + r = jr.randint(key, + shape=_size2shape(size), + minval=low, + maxval=high) + return _return(r) def randn(self, *dn, key=None): key = self.split_key() if key is None else _formalize_key(key) - return jr.normal(key, shape=dn) + r = jr.normal(key, shape=dn) + return _return(r) def random(self, size=None, key=None): key = self.split_key() if key is None else _formalize_key(key) - return jr.uniform(key, shape=_size2shape(size), minval=0., maxval=1.) + r = jr.uniform(key, shape=_size2shape(size), minval=0., maxval=1.) 
+ return _return(r) def random_sample(self, size=None, key=None): - return self.random(size=size, key=key) + r = self.random(size=size, key=key) + return _return(r) def ranf(self, size=None, key=None): - return self.random(size=size, key=key) + r = self.random(size=size, key=key) + return _return(r) def sample(self, size=None, key=None): - return self.random(size=size, key=key) + r = self.random(size=size, key=key) + return _return(r) def choice(self, a, size=None, replace=True, p=None, key=None): a = _as_jax_array(a) @@ -575,14 +583,15 @@ def choice(self, a, size=None, replace=True, p=None, key=None): a = _check_py_seq(a) p = _check_py_seq(p) key = self.split_key() if key is None else _formalize_key(key) - return jr.choice(key, a=a, shape=_size2shape(size), - replace=replace, p=p) + r = jr.choice(key, a=a, shape=_size2shape(size), replace=replace, p=p) + return _return(r) def permutation(self, x, axis: int = 0, independent: bool = False, key=None): x = x.value if isinstance(x, Array) else x x = _check_py_seq(x) key = self.split_key() if key is None else _formalize_key(key) - return jr.permutation(key, x, axis=axis, independent=independent) + r = jr.permutation(key, x, axis=axis, independent=independent) + return _return(r) def shuffle(self, x, axis=0, key=None): if not isinstance(x, Array): @@ -599,7 +608,8 @@ def beta(self, a, b, size=None, key=None): if size is None: size = lax.broadcast_shapes(jnp.shape(a), jnp.shape(b)) key = self.split_key() if key is None else _formalize_key(key) - return jr.beta(key, a=a, b=b, shape=_size2shape(size)) + r = jr.beta(key, a=a, b=b, shape=_size2shape(size)) + return _return(r) def exponential(self, scale=None, size=None, key=None): scale = _as_jax_array(scale) @@ -608,10 +618,9 @@ def exponential(self, scale=None, size=None, key=None): size = jnp.shape(scale) key = self.split_key() if key is None else _formalize_key(key) r = jr.exponential(key, shape=_size2shape(size)) - if scale is None: - return r - else: - return r / scale + if scale is not None: + r = r / scale + return _return(r) def gamma(self, shape, scale=None, size=None, key=None): shape = _as_jax_array(shape) @@ -622,10 +631,9 @@ def gamma(self, shape, scale=None, size=None, key=None): size = lax.broadcast_shapes(jnp.shape(shape), jnp.shape(scale)) key = self.split_key() if key is None else _formalize_key(key) r = jr.gamma(key, a=shape, shape=_size2shape(size)) - if scale is None: - return r - else: - return r * scale + if scale is not None: + r = r * scale + return _return(r) def gumbel(self, loc=None, scale=None, size=None, key=None): loc = _as_jax_array(loc) @@ -635,7 +643,8 @@ def gumbel(self, loc=None, scale=None, size=None, key=None): if size is None: size = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale)) key = self.split_key() if key is None else _formalize_key(key) - return _loc_scale(loc, scale, jr.gumbel(key, shape=_size2shape(size))) + r = _loc_scale(loc, scale, jr.gumbel(key, shape=_size2shape(size))) + return _return(r) def laplace(self, loc=None, scale=None, size=None, key=None): loc = _as_jax_array(loc) @@ -645,7 +654,8 @@ def laplace(self, loc=None, scale=None, size=None, key=None): if size is None: size = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale)) key = self.split_key() if key is None else _formalize_key(key) - return _loc_scale(loc, scale, jr.laplace(key, shape=_size2shape(size))) + r = _loc_scale(loc, scale, jr.laplace(key, shape=_size2shape(size))) + return _return(r) def logistic(self, loc=None, scale=None, size=None, key=None): loc = 
_as_jax_array(loc) @@ -655,7 +665,8 @@ def logistic(self, loc=None, scale=None, size=None, key=None): if size is None: size = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale)) key = self.split_key() if key is None else _formalize_key(key) - return _loc_scale(loc, scale, jr.logistic(key, shape=_size2shape(size))) + r = _loc_scale(loc, scale, jr.logistic(key, shape=_size2shape(size))) + return _return(r) def normal(self, loc=None, scale=None, size=None, key=None): loc = _as_jax_array(loc) @@ -665,7 +676,8 @@ def normal(self, loc=None, scale=None, size=None, key=None): if size is None: size = lax.broadcast_shapes(jnp.shape(scale), jnp.shape(loc)) key = self.split_key() if key is None else _formalize_key(key) - return _loc_scale(loc, scale, jr.normal(key, shape=_size2shape(size))) + r = _loc_scale(loc, scale, jr.normal(key, shape=_size2shape(size))) + return _return(r) def pareto(self, a, size=None, key=None): a = _as_jax_array(a) @@ -673,22 +685,26 @@ def pareto(self, a, size=None, key=None): if size is None: size = jnp.shape(a) key = self.split_key() if key is None else _formalize_key(key) - return jr.pareto(key, b=a, shape=_size2shape(size)) + r = jr.pareto(key, b=a, shape=_size2shape(size)) + return _return(r) def poisson(self, lam=1.0, size=None, key=None): lam = _check_py_seq(_as_jax_array(lam)) if size is None: size = jnp.shape(lam) key = self.split_key() if key is None else _formalize_key(key) - return jr.poisson(key, lam=lam, shape=_size2shape(size)) + r = jr.poisson(key, lam=lam, shape=_size2shape(size)) + return _return(r) def standard_cauchy(self, size=None, key=None): key = self.split_key() if key is None else _formalize_key(key) - return jr.cauchy(key, shape=_size2shape(size)) + r = jr.cauchy(key, shape=_size2shape(size)) + return _return(r) def standard_exponential(self, size=None, key=None): key = self.split_key() if key is None else _formalize_key(key) - return jr.exponential(key, shape=_size2shape(size)) + r = jr.exponential(key, shape=_size2shape(size)) + return _return(r) def standard_gamma(self, shape, size=None, key=None): shape = _as_jax_array(shape) @@ -696,11 +712,13 @@ def standard_gamma(self, shape, size=None, key=None): if size is None: size = jnp.shape(shape) key = self.split_key() if key is None else _formalize_key(key) - return jr.gamma(key, a=shape, shape=_size2shape(size)) + r = jr.gamma(key, a=shape, shape=_size2shape(size)) + return _return(r) def standard_normal(self, size=None, key=None): key = self.split_key() if key is None else _formalize_key(key) - return jr.normal(key, shape=_size2shape(size)) + r = jr.normal(key, shape=_size2shape(size)) + return _return(r) def standard_t(self, df, size=None, key=None): df = _as_jax_array(df) @@ -708,7 +726,8 @@ def standard_t(self, df, size=None, key=None): if size is None: size = jnp.shape(size) key = self.split_key() if key is None else _formalize_key(key) - return jr.t(key, df=df, shape=_size2shape(size)) + r = jr.t(key, df=df, shape=_size2shape(size)) + return _return(r) def uniform(self, low=0.0, high=1.0, size=None, key=None): low = _as_jax_array(low) @@ -718,7 +737,8 @@ def uniform(self, low=0.0, high=1.0, size=None, key=None): if size is None: size = lax.broadcast_shapes(jnp.shape(low), jnp.shape(high)) key = self.split_key() if key is None else _formalize_key(key) - return jr.uniform(key, shape=_size2shape(size), minval=low, maxval=high) + r = jr.uniform(key, shape=_size2shape(size), minval=low, maxval=high) + return _return(r) def truncated_normal(self, lower, upper, size, scale=None, key=None): lower 
= _as_jax_array(lower) @@ -736,10 +756,9 @@ def truncated_normal(self, lower, upper, size, scale=None, key=None): lower=lower, upper=upper, shape=_size2shape(size)) - if scale is None: - return rands - else: - return rands * scale + if scale is not None: + rands = rands * scale + return _return(rands) def _check_p(self, p): raise ValueError(f'Parameter p should be within [0, 1], but we got {p}') @@ -750,7 +769,8 @@ def bernoulli(self, p, size=None, key=None): if size is None: size = jnp.shape(p) key = self.split_key() if key is None else _formalize_key(key) - return jr.bernoulli(key, p=p, shape=_size2shape(size)) + r = jr.bernoulli(key, p=p, shape=_size2shape(size)) + return _return(r) def lognormal(self, mean=None, sigma=None, size=None, key=None): mean = _check_py_seq(_as_jax_array(mean)) @@ -762,7 +782,7 @@ def lognormal(self, mean=None, sigma=None, size=None, key=None): samples = jr.normal(key, shape=_size2shape(size)) samples = _loc_scale(mean, sigma, samples) samples = jnp.exp(samples) - return samples + return _return(samples) def binomial(self, n, p, size=None, key=None): n = _check_py_seq(n.value if isinstance(n, Array) else n) @@ -771,7 +791,8 @@ def binomial(self, n, p, size=None, key=None): if size is None: size = jnp.broadcast_shapes(jnp.shape(n), jnp.shape(p)) key = self.split_key() if key is None else _formalize_key(key) - return _binomial(key, p, n, shape=_size2shape(size)) + r = _binomial(key, p, n, shape=_size2shape(size)) + return _return(r) def chisquare(self, df, size=None, key=None): df = _check_py_seq(_as_jax_array(df)) @@ -785,12 +806,13 @@ def chisquare(self, df, size=None, key=None): else: dist = jr.normal(key, (df,) + _size2shape(size)) ** 2 dist = dist.sum(axis=0) - return dist + return _return(dist) def dirichlet(self, alpha, size=None, key=None): key = self.split_key() if key is None else _formalize_key(key) alpha = _check_py_seq(_as_jax_array(alpha)) - return jr.dirichlet(key, alpha=alpha, shape=_size2shape(size)) + r = jr.dirichlet(key, alpha=alpha, shape=_size2shape(size)) + return _return(r) def geometric(self, p, size=None, key=None): p = _as_jax_array(p) @@ -800,7 +822,7 @@ def geometric(self, p, size=None, key=None): key = self.split_key() if key is None else _formalize_key(key) u = jr.uniform(key, size) r = jnp.floor(jnp.log1p(-u) / jnp.log1p(-p)) - return r + return _return(r) def _check_p2(self, p): raise ValueError(f'We require `sum(pvals[:-1]) <= 1`. But we got {p}') @@ -815,7 +837,8 @@ def multinomial(self, n, pvals, size=None, key=None): size = _size2shape(size) n_max = int(np.max(jax.device_get(n))) batch_shape = lax.broadcast_shapes(jnp.shape(pvals)[:-1], jnp.shape(n)) - return _multinomial(key, pvals, n, n_max, batch_shape + size) + r = _multinomial(key, pvals, n, n_max, batch_shape + size) + return _return(r) def multivariate_normal(self, mean, cov, size=None, method: str = 'cholesky', key=None): if method not in {'svd', 'eigh', 'cholesky'}: @@ -848,7 +871,7 @@ def multivariate_normal(self, mean, cov, size=None, method: str = 'cholesky', ke factor = jnp.linalg.cholesky(cov) normal_samples = jr.normal(key, size + mean.shape[-1:]) r = mean + jnp.einsum('...ij,...j->...i', factor, normal_samples) - return r + return _return(r) def rayleigh(self, scale=1.0, size=None, key=None): scale = _check_py_seq(_as_jax_array(scale)) @@ -856,12 +879,14 @@ def rayleigh(self, scale=1.0, size=None, key=None): size = jnp.shape(scale) key = self.split_key() if key is None else _formalize_key(key) x = jnp.sqrt(-2. 
* jnp.log(jr.uniform(key, shape=_size2shape(size), minval=0, maxval=1))) - return x * scale + r = x * scale + return _return(r) def triangular(self, size=None, key=None): key = self.split_key() if key is None else _formalize_key(key) bernoulli_samples = jr.bernoulli(key, p=0.5, shape=_size2shape(size)) - return 2 * bernoulli_samples - 1 + r = 2 * bernoulli_samples - 1 + return _return(r) def vonmises(self, mu, kappa, size=None, key=None): key = self.split_key() if key is None else _formalize_key(key) @@ -873,7 +898,7 @@ def vonmises(self, mu, kappa, size=None, key=None): samples = _von_mises_centered(key, kappa, size) samples = samples + mu samples = (samples + jnp.pi) % (2.0 * jnp.pi) - jnp.pi - return samples + return _return(samples) def weibull(self, a, size=None, key=None): key = self.split_key() if key is None else _formalize_key(key) @@ -886,7 +911,7 @@ def weibull(self, a, size=None, key=None): size = _size2shape(size) random_uniform = jr.uniform(key=key, shape=size, minval=0, maxval=1) r = jnp.power(-jnp.log1p(-random_uniform), 1.0 / a) - return r + return _return(r) def weibull_min(self, a, scale=None, size=None, key=None): """Sample from a Weibull minimum distribution. @@ -918,13 +943,14 @@ def weibull_min(self, a, scale=None, size=None, key=None): r = jnp.power(-jnp.log1p(-random_uniform), 1.0 / a) if scale is not None: r /= scale - return r + return _return(r) def maxwell(self, size=None, key=None): key = self.split_key() if key is None else _formalize_key(key) shape = core.canonicalize_shape(_size2shape(size)) + (3,) norm_rvs = jr.normal(key=key, shape=shape) - return jnp.linalg.norm(norm_rvs, axis=-1) + r = jnp.linalg.norm(norm_rvs, axis=-1) + return _return(r) def negative_binomial(self, n, p, size=None, key=None): n = _check_py_seq(_as_jax_array(n)) @@ -938,7 +964,8 @@ def negative_binomial(self, n, p, size=None, key=None): else: keys = jr.split(_formalize_key(key), 2) rate = self.gamma(shape=n, scale=jnp.exp(-logits), size=size, key=keys[0]) - return self.poisson(lam=rate, key=keys[1]) + r = self.poisson(lam=rate, key=keys[1]) + return _return(r) def wald(self, mean, scale, size=None, key=None): key = self.split_key() if key is None else _formalize_key(key) @@ -978,7 +1005,7 @@ def wald(self, mean, scale, size=None, key=None): res = jnp.where(sampled_uniform <= mean / (mean + sampled), sampled, jnp.square(mean) / sampled) - return res + return _return(res) def t(self, df, size=None, key=None): df = _check_py_seq(_as_jax_array(df)) @@ -995,7 +1022,8 @@ def t(self, df, size=None, key=None): two = _const(n, 2) half_df = lax.div(df, two) g = jr.gamma(keys[1], half_df, size) - return n * jnp.sqrt(half_df / g) + r = n * jnp.sqrt(half_df / g) + return _return(r) def orthogonal(self, n: int, size=None, key=None): key = self.split_key() if key is None else _formalize_key(key) @@ -1005,7 +1033,8 @@ def orthogonal(self, n: int, size=None, key=None): z = jr.normal(key, size + (n, n)) q, r = jnp.linalg.qr(z) d = jnp.diagonal(r, 0, -2, -1) - return q * jnp.expand_dims(d / abs(d), -2) + r = q * jnp.expand_dims(d / abs(d), -2) + return _return(r) def noncentral_chisquare(self, df, nonc, size=None, key=None): df = _check_py_seq(_as_jax_array(df)) @@ -1022,14 +1051,16 @@ def noncentral_chisquare(self, df, nonc, size=None, key=None): cond = jnp.greater(df, 1.0) df2 = jnp.where(cond, df - 1.0, df + 2.0 * i) chi2 = 2.0 * jr.gamma(keys[2], 0.5 * df2, shape=size) - return jnp.where(cond, chi2 + n * n, chi2) + r = jnp.where(cond, chi2 + n * n, chi2) + return _return(r) def loggamma(self, a, 
size=None, key=None): key = self.split_key() if key is None else _formalize_key(key) a = _check_py_seq(_as_jax_array(a)) if size is None: size = jnp.shape(a) - return jr.loggamma(key, a, shape=_size2shape(size)) + r = jr.loggamma(key, a, shape=_size2shape(size)) + return _return(r) def categorical(self, logits, axis: int = -1, size=None, key=None): key = self.split_key() if key is None else _formalize_key(key) @@ -1037,23 +1068,26 @@ def categorical(self, logits, axis: int = -1, size=None, key=None): if size is None: size = list(jnp.shape(logits)) size.pop(axis) - return jr.categorical(key, logits, axis=axis, shape=_size2shape(size)) + r = jr.categorical(key, logits, axis=axis, shape=_size2shape(size)) + return _return(r) def zipf(self, a, size=None, key=None): a = _check_py_seq(_as_jax_array(a)) if size is None: size = jnp.shape(a) - return call(lambda x: np.random.zipf(x, size), - a, - result_shape=jax.ShapeDtypeStruct(size, jnp.int_)) + r = call(lambda x: np.random.zipf(x, size), + a, + result_shape=jax.ShapeDtypeStruct(size, jnp.int_)) + return _return(r) def power(self, a, size=None, key=None): a = _check_py_seq(_as_jax_array(a)) if size is None: size = jnp.shape(a) size = _size2shape(size) - return call(lambda a: np.random.power(a=a, size=size), - a, result_shape=jax.ShapeDtypeStruct(size, jnp.float_)) + r = call(lambda a: np.random.power(a=a, size=size), + a, result_shape=jax.ShapeDtypeStruct(size, jnp.float_)) + return _return(r) def f(self, dfnum, dfden, size=None, key=None): dfnum = _as_jax_array(dfnum) @@ -1064,11 +1098,12 @@ def f(self, dfnum, dfden, size=None, key=None): size = jnp.broadcast_shapes(jnp.shape(dfnum), jnp.shape(dfden)) size = _size2shape(size) d = {'dfnum': dfnum, 'dfden': dfden} - return call(lambda x: np.random.f(dfnum=x['dfnum'], - dfden=x['dfden'], - size=size), - d, - result_shape=jax.ShapeDtypeStruct(size, jnp.float_)) + r = call(lambda x: np.random.f(dfnum=x['dfnum'], + dfden=x['dfden'], + size=size), + d, + result_shape=jax.ShapeDtypeStruct(size, jnp.float_)) + return _return(r) def hypergeometric(self, ngood, nbad, nsample, size=None, key=None): ngood = _check_py_seq(_as_jax_array(ngood)) @@ -1081,19 +1116,21 @@ def hypergeometric(self, ngood, nbad, nsample, size=None, key=None): jnp.shape(nsample)) size = _size2shape(size) d = {'ngood': ngood, 'nbad': nbad, 'nsample': nsample} - return call(lambda x: np.random.hypergeometric(ngood=x['ngood'], - nbad=x['nbad'], - nsample=x['nsample'], - size=size), - d, result_shape=jax.ShapeDtypeStruct(size, jnp.int_)) + r = call(lambda x: np.random.hypergeometric(ngood=x['ngood'], + nbad=x['nbad'], + nsample=x['nsample'], + size=size), + d, result_shape=jax.ShapeDtypeStruct(size, jnp.int_)) + return _return(r) def logseries(self, p, size=None, key=None): p = _check_py_seq(_as_jax_array(p)) if size is None: size = jnp.shape(p) size = _size2shape(size) - return call(lambda p: np.random.logseries(p=p, size=size), - p, result_shape=jax.ShapeDtypeStruct(size, jnp.int_)) + r = call(lambda p: np.random.logseries(p=p, size=size), + p, result_shape=jax.ShapeDtypeStruct(size, jnp.int_)) + return _return(r) def noncentral_f(self, dfnum, dfden, nonc, size=None, key=None): dfnum = _check_py_seq(_as_jax_array(dfnum)) @@ -1105,11 +1142,12 @@ def noncentral_f(self, dfnum, dfden, nonc, size=None, key=None): jnp.shape(nonc)) size = _size2shape(size) d = {'dfnum': dfnum, 'dfden': dfden, 'nonc': nonc} - return call(lambda x: np.random.noncentral_f(dfnum=x['dfnum'], - dfden=x['dfden'], - nonc=x['nonc'], - size=size), - d, 
result_shape=jax.ShapeDtypeStruct(size, jnp.float_)) + r = call(lambda x: np.random.noncentral_f(dfnum=x['dfnum'], + dfden=x['dfden'], + nonc=x['nonc'], + size=size), + d, result_shape=jax.ShapeDtypeStruct(size, jnp.float_)) + return _return(r) # alias diff --git a/brainpy/math/arraycompatible.py b/brainpy/math/arraycompatible.py new file mode 100644 index 000000000..4326672a8 --- /dev/null +++ b/brainpy/math/arraycompatible.py @@ -0,0 +1,359 @@ +# -*- coding: utf-8 -*- + + +from brainpy._src.math.arraycompatible import ( + full as full, + full_like as full_like, + eye as eye, + identity as identity, + diag as diag, + tri as tri, + tril as tril, + triu as triu, + real as real, + imag as imag, + conj as conj, + conjugate as conjugate, + ndim as ndim, + isreal as isreal, + isscalar as isscalar, + add as add, + reciprocal as reciprocal, + negative as negative, + positive as positive, + multiply as multiply, + divide as divide, + power as power, + subtract as subtract, + true_divide as true_divide, + floor_divide as floor_divide, + float_power as float_power, + fmod as fmod, + mod as mod, + modf as modf, + divmod as divmod, + remainder as remainder, + abs as abs, + exp as exp, + exp2 as exp2, + expm1 as expm1, + log as log, + log10 as log10, + log1p as log1p, + log2 as log2, + logaddexp as logaddexp, + logaddexp2 as logaddexp2, + lcm as lcm, + gcd as gcd, + arccos as arccos, + arccosh as arccosh, + arcsin as arcsin, + arcsinh as arcsinh, + arctan as arctan, + arctan2 as arctan2, + arctanh as arctanh, + cos as cos, + cosh as cosh, + sin as sin, + sinc as sinc, + sinh as sinh, + tan as tan, + tanh as tanh, + deg2rad as deg2rad, + hypot as hypot, + rad2deg as rad2deg, + degrees as degrees, + radians as radians, + round as round, + around as around, + round_ as round_, + rint as rint, + floor as floor, + ceil as ceil, + trunc as trunc, + fix as fix, + prod as prod, + sum as sum, + diff as diff, + median as median, + nancumprod as nancumprod, + nancumsum as nancumsum, + nanprod as nanprod, + nansum as nansum, + cumprod as cumprod, + cumsum as cumsum, + ediff1d as ediff1d, + cross as cross, + trapz as trapz, + isfinite as isfinite, + isinf as isinf, + isnan as isnan, + signbit as signbit, + copysign as copysign, + nextafter as nextafter, + ldexp as ldexp, + frexp as frexp, + convolve as convolve, + sqrt as sqrt, + cbrt as cbrt, + square as square, + absolute as absolute, + fabs as fabs, + sign as sign, + heaviside as heaviside, + maximum as maximum, + minimum as minimum, + fmax as fmax, + fmin as fmin, + interp as interp, + clip as clip, + angle as angle, + bitwise_and as bitwise_and, + bitwise_not as bitwise_not, + bitwise_or as bitwise_or, + bitwise_xor as bitwise_xor, + invert as invert, + left_shift as left_shift, + right_shift as right_shift, + equal as equal, + not_equal as not_equal, + greater as greater, + greater_equal as greater_equal, + less as less, + less_equal as less_equal, + array_equal as array_equal, + isclose as isclose, + allclose as allclose, + logical_and as logical_and, + logical_not as logical_not, + logical_or as logical_or, + logical_xor as logical_xor, + all as all, + any as any, + alltrue as alltrue, + sometrue as sometrue, + shape as shape, + size as size, + reshape as reshape, + ravel as ravel, + moveaxis as moveaxis, + transpose as transpose, + swapaxes as swapaxes, + concatenate as concatenate, + stack as stack, + vstack as vstack, + hstack as hstack, + dstack as dstack, + column_stack as column_stack, + split as split, + dsplit as dsplit, + hsplit as hsplit, + vsplit 
as vsplit, + tile as tile, + repeat as repeat, + unique as unique, + append as append, + flip as flip, + fliplr as fliplr, + flipud as flipud, + roll as roll, + atleast_1d as atleast_1d, + atleast_2d as atleast_2d, + atleast_3d as atleast_3d, + expand_dims as expand_dims, + squeeze as squeeze, + sort as sort, + argsort as argsort, + argmax as argmax, + argmin as argmin, + argwhere as argwhere, + nonzero as nonzero, + flatnonzero as flatnonzero, + where as where, + searchsorted as searchsorted, + extract as extract, + count_nonzero as count_nonzero, + max as max, + min as min, + amax as amax, + amin as amin, + array_split as array_split, + meshgrid as meshgrid, + vander as vander, + nonzero as nonzero, + where as where, + tril_indices as tril_indices, + tril_indices_from as tril_indices_from, + triu_indices as triu_indices, + triu_indices_from as triu_indices_from, + take as take, + select as select, + nanmin as nanmin, + nanmax as nanmax, + ptp as ptp, + percentile as percentile, + nanpercentile as nanpercentile, + quantile as quantile, + nanquantile as nanquantile, + median as median, + average as average, + mean as mean, + std as std, + var as var, + nanmedian as nanmedian, + nanmean as nanmean, + nanstd as nanstd, + nanvar as nanvar, + corrcoef as corrcoef, + correlate as correlate, + cov as cov, + histogram as histogram, + bincount as bincount, + digitize as digitize, + bartlett as bartlett, + blackman as blackman, + hamming as hamming, + hanning as hanning, + kaiser as kaiser, + e as e, + pi as pi, + inf as inf, + dot as dot, + vdot as vdot, + inner as inner, + outer as outer, + kron as kron, + matmul as matmul, + trace as trace, + dtype as dtype, + finfo as finfo, + iinfo as iinfo, + uint8 as uint8, + uint16 as uint16, + uint32 as uint32, + uint64 as uint64, + int8 as int8, + int16 as int16, + int32 as int32, + int64 as int64, + float16 as float16, + float32 as float32, + float64 as float64, + complex64 as complex64, + complex128 as complex128, + product as product, + row_stack as row_stack, + apply_over_axes as apply_over_axes, + apply_along_axis as apply_along_axis, + array_equiv as array_equiv, + array_repr as array_repr, + array_str as array_str, + block as block, + broadcast_arrays as broadcast_arrays, + broadcast_shapes as broadcast_shapes, + broadcast_to as broadcast_to, + compress as compress, + cumproduct as cumproduct, + diag_indices as diag_indices, + diag_indices_from as diag_indices_from, + diagflat as diagflat, + diagonal as diagonal, + einsum as einsum, + einsum_path as einsum_path, + geomspace as geomspace, + gradient as gradient, + histogram2d as histogram2d, + histogram_bin_edges as histogram_bin_edges, + histogramdd as histogramdd, + i0 as i0, + in1d as in1d, + indices as indices, + insert as insert, + intersect1d as intersect1d, + iscomplex as iscomplex, + isin as isin, + ix_ as ix_, + lexsort as lexsort, + load as load, + save as save, + savez as savez, + mask_indices as mask_indices, + msort as msort, + nan_to_num as nan_to_num, + nanargmax as nanargmax, + setdiff1d as setdiff1d, + nanargmin as nanargmin, + pad as pad, + poly as poly, + polyadd as polyadd, + polyder as polyder, + polyfit as polyfit, + polyint as polyint, + polymul as polymul, + polysub as polysub, + polyval as polyval, + resize as resize, + rollaxis as rollaxis, + roots as roots, + rot90 as rot90, + setxor1d as setxor1d, + tensordot as tensordot, + trim_zeros as trim_zeros, + union1d as union1d, + unravel_index as unravel_index, + unwrap as unwrap, + take_along_axis as take_along_axis, + 
can_cast as can_cast, + choose as choose, + copy as copy, + frombuffer as frombuffer, + fromfile as fromfile, + fromfunction as fromfunction, + fromiter as fromiter, + fromstring as fromstring, + get_printoptions as get_printoptions, + iscomplexobj as iscomplexobj, + isneginf as isneginf, + isposinf as isposinf, + isrealobj as isrealobj, + issubdtype as issubdtype, + issubsctype as issubsctype, + iterable as iterable, + packbits as packbits, + piecewise as piecewise, + printoptions as printoptions, + set_printoptions as set_printoptions, + promote_types as promote_types, + ravel_multi_index as ravel_multi_index, + result_type as result_type, + sort_complex as sort_complex, + unpackbits as unpackbits, + delete as delete, + add_docstring as add_docstring, + add_newdoc as add_newdoc, + add_newdoc_ufunc as add_newdoc_ufunc, + array2string as array2string, + asanyarray as asanyarray, + ascontiguousarray as ascontiguousarray, + asfarray as asfarray, + asscalar as asscalar, + common_type as common_type, + disp as disp, + genfromtxt as genfromtxt, + loadtxt as loadtxt, + info as info, + issubclass_ as issubclass_, + place as place, + polydiv as polydiv, + put as put, + putmask as putmask, + safe_eval as safe_eval, + savetxt as savetxt, + savez_compressed as savez_compressed, + show_config as show_config, + typename as typename, + copyto as copyto, + matrix as matrix, + asmatrix as asmatrix, + mat as mat, +) diff --git a/brainpy/math/fft.py b/brainpy/math/fft.py index 633f86615..18f79fc19 100644 --- a/brainpy/math/fft.py +++ b/brainpy/math/fft.py @@ -1,2 +1,23 @@ # -*- coding: utf-8 -*- +from brainpy._src.math.fft import ( + fft as fft, + fft2 as fft2, + fftfreq as fftfreq, + fftn as fftn, + fftshift as fftshift, + hfft as hfft, + ifft as ifft, + ifft2 as ifft2, + ifftn as ifftn, + ifftshift as ifftshift, + ihfft as ihfft, + irfft as irfft, + irfft2 as irfft2, + irfftn as irfftn, + rfft as rfft, + rfft2 as rfft2, + rfftfreq as rfftfreq, + rfftn as rfftn, +) + diff --git a/brainpy/math/linalg.py b/brainpy/math/linalg.py index 633f86615..7cdf5dcc6 100644 --- a/brainpy/math/linalg.py +++ b/brainpy/math/linalg.py @@ -1,2 +1,24 @@ # -*- coding: utf-8 -*- +from brainpy._src.math.linalg import ( + cholesky as cholesky, + cond as cond, + det as det, + eig as eig, + eigh as eigh, + eigvals as eigvals, + eigvalsh as eigvalsh, + inv as inv, + svd as svd, + lstsq as lstsq, + matrix_power as matrix_power, + matrix_rank as matrix_rank, + norm as norm, + pinv as pinv, + qr as qr, + solve as solve, + slogdet as slogdet, + tensorinv as tensorinv, + tensorsolve as tensorsolve, + multi_dot as multi_dot, +) diff --git a/brainpy/math/others.py b/brainpy/math/others.py index 7d0f262ee..93e2f1a46 100644 --- a/brainpy/math/others.py +++ b/brainpy/math/others.py @@ -3,3 +3,6 @@ from brainpy._src.math.others import ( shared_args_over_time as shared_args_over_time, ) +from brainpy._src.math._utils import ( + npfun_returns_bparray as npfun_returns_bparray +) From c6df4c32f9c2fee435039010f15b0d8fd2b6d6d8 Mon Sep 17 00:00:00 2001 From: chaoming Date: Sat, 14 Jan 2023 21:10:01 +0800 Subject: [PATCH 07/13] numpy functions in `brainpy.math` return JAX Array or brainpy Array by `npfun_returns_bparray()` --- brainpy/_src/math/_utils.py | 33 +++++++++++++++++++++++++-------- 1 file changed, 25 insertions(+), 8 deletions(-) diff --git a/brainpy/_src/math/_utils.py b/brainpy/_src/math/_utils.py index 0d1cbbce8..d6d8e22d1 100644 --- a/brainpy/_src/math/_utils.py +++ b/brainpy/_src/math/_utils.py @@ -9,6 +9,27 @@ from .ndarray import 
Array +__all__ = [ + 'npfun_returns_bparray' +] + + +def _as_jax_array_(obj): + return obj.value if isinstance(obj, Array) else obj + + +def _return(x): + return Array(x) if _return_bp_array else x + + +_return_bp_array = True + + +def npfun_returns_bparray(mode: bool): + global _return_bp_array + assert isinstance(mode, bool) + _return_bp_array = mode + def wraps(fun: Callable): """Specialized version of functools.wraps for wrapping numpy functions. @@ -29,10 +50,6 @@ def wrap(op): return wrap -def _as_jax_array(a): - return a.value if isinstance(a, Array) else a - - def _as_brainpy_array(a): return Array(a) if isinstance(a, (np.ndarray, jax.Array)) else a @@ -41,14 +58,14 @@ def _is_leaf(a): return isinstance(a, Array) -def _compatible_with_brainpy_array(fun: Callable, return_brainpy_array: bool = False): +def _compatible_with_brainpy_array(fun: Callable): @functools.wraps(fun) def new_fun(*args, **kwargs): - args = tree_map(_as_jax_array, args, is_leaf=_is_leaf) + args = tree_map(_as_jax_array_, args, is_leaf=_is_leaf) if len(kwargs): - kwargs = tree_map(_as_jax_array, kwargs, is_leaf=_is_leaf) + kwargs = tree_map(_as_jax_array_, kwargs, is_leaf=_is_leaf) r = fun(*args, **kwargs) - return tree_map(_as_brainpy_array, r) if return_brainpy_array else r + return tree_map(_as_brainpy_array, r) if _return_bp_array else r new_fun.__doc__ = getattr(fun, "__doc__", None) From 061de39067044f7bc17c65a92ad57a334898edf6 Mon Sep 17 00:00:00 2001 From: chaoming Date: Sat, 14 Jan 2023 21:10:17 +0800 Subject: [PATCH 08/13] fix `DiffusiveCoupling` bug --- brainpy/_src/dyn/synapses/delay_couplings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/brainpy/_src/dyn/synapses/delay_couplings.py b/brainpy/_src/dyn/synapses/delay_couplings.py index ef3d5f5ce..2dad86758 100644 --- a/brainpy/_src/dyn/synapses/delay_couplings.py +++ b/brainpy/_src/dyn/synapses/delay_couplings.py @@ -206,7 +206,7 @@ def update(self, tdi): indices = (jnp.arange(self.coupling_var1.size),) f = vmap(lambda steps: delay_var(steps, *indices), in_axes=1) # (..., pre.num) delays = f(self.delay_steps) # (..., post.num, pre.num) - diffusive = (jnp.moveaxis(delays.value, axis - 1, axis) - + diffusive = (jnp.moveaxis(bm.as_jax(delays), axis - 1, axis) - jnp.expand_dims(self.coupling_var2.value, axis=axis - 1)) # (..., pre.num, post.num) diffusive = (self.conn_mat * diffusive).sum(axis=axis - 1) elif self.delay_type == 'int': From adfb0b1ff8e0cf0763ceb040d8b30bc20d051959 Mon Sep 17 00:00:00 2001 From: chaoming Date: Sat, 14 Jan 2023 21:11:00 +0800 Subject: [PATCH 09/13] support `static_argnums` in `brainpy.math.jit` --- brainpy/_src/math/object_transform/jit.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/brainpy/_src/math/object_transform/jit.py b/brainpy/_src/math/object_transform/jit.py index e4b64741f..4b9f8b8d3 100644 --- a/brainpy/_src/math/object_transform/jit.py +++ b/brainpy/_src/math/object_transform/jit.py @@ -7,20 +7,15 @@ """ -from typing import Callable, Union, Optional, Sequence, Dict, Any +from typing import Callable, Union, Optional, Sequence, Dict, Any, Iterable import jax - -try: - from jax.errors import UnexpectedTracerError, ConcretizationTypeError -except ImportError: - from jax.core import UnexpectedTracerError, ConcretizationTypeError +from jax.errors import UnexpectedTracerError, ConcretizationTypeError from brainpy import errors, tools, check from brainpy._src.math.ndarray import Variable, add_context, del_context from .abstract import ObjectTransform from 
.base import BrainPyObject -from ._utils import infer_dyn_vars __all__ = [ 'jit', @@ -35,7 +30,8 @@ def __init__( target: callable, dyn_vars: Dict[str, Variable], child_objs: Dict[str, BrainPyObject], - static_argnames: Optional[Any] = None, + static_argnums: Union[int, Iterable[int], None] = None, + static_argnames: Union[str, Iterable[str], None] = None, device: Optional[Any] = None, name: Optional[str] = None, inline: bool = False, @@ -46,12 +42,14 @@ def __init__( self.register_implicit_vars(dyn_vars) self.register_implicit_nodes(child_objs) - + if hasattr(target, '__self__') and isinstance(getattr(target, '__self__'), BrainPyObject): + self.register_implicit_nodes(getattr(target, '__self__')) self.target = target self._all_vars = self.vars().unique() # transformation self._f = jax.jit(self._transform_function, + static_argnums=jax.tree_util.tree_map(lambda a: a + 1, static_argnums), static_argnames=static_argnames, device=device, inline=inline, @@ -100,7 +98,8 @@ def jit( func: Callable, dyn_vars: Optional[Union[Variable, Sequence[Variable], Dict[str, Variable]]] = None, child_objs: Optional[Union[BrainPyObject, Sequence[BrainPyObject], Dict[str, BrainPyObject]]] = None, - static_argnames: Optional[Union[str, Any]] = None, + static_argnums: Union[int, Iterable[int], None] = None, + static_argnames: Union[str, Iterable[str], None] = None, device: Optional[Any] = None, inline: bool = False, keep_unused: bool = False, @@ -230,6 +229,7 @@ def jit( return JITTransform(target=func, dyn_vars=dyn_vars, child_objs=child_objs, + static_argnums=static_argnums, static_argnames=static_argnames, device=device, inline=inline, From d36273d6f244f54bd69435bf9fbe7eb92934f81e Mon Sep 17 00:00:00 2001 From: chaoming Date: Sat, 14 Jan 2023 21:11:37 +0800 Subject: [PATCH 10/13] update channel models --- brainpy/_src/dyn/channels/Ca.py | 126 +++++++++++++++---------------- brainpy/_src/dyn/channels/K.py | 49 ++++++------ brainpy/_src/dyn/channels/KCa.py | 4 +- brainpy/_src/dyn/channels/Na.py | 17 ++--- brainpy/_src/dyn/runners.py | 4 +- 5 files changed, 97 insertions(+), 103 deletions(-) diff --git a/brainpy/_src/dyn/channels/Ca.py b/brainpy/_src/dyn/channels/Ca.py index 3583ae134..f0d14145c 100644 --- a/brainpy/_src/dyn/channels/Ca.py +++ b/brainpy/_src/dyn/channels/Ca.py @@ -7,8 +7,6 @@ from typing import Union, Callable -import jax.numpy as jnp - import brainpy.math as bm from brainpy._src.dyn.base import Channel from brainpy._src.initialize import OneInit, Initializer, parameter, variable @@ -23,7 +21,7 @@ 'CalciumDetailed', 'CalciumFirstOrder', - 'ICa_p2q_ss', 'ICa_p2q_markov', + '_ICa_p2q_ss', '_ICa_p2q_markov', 'ICaN_IS2008', @@ -141,12 +139,12 @@ def reset_state(self, V, C_Ca=None, E_Ca=None, batch_size=None): def update(self, tdi, V): for node in self.nodes(level=1, include_self=False).unique().subset(Channel).values(): - node.update(tdi, V, self.C, self.E) + node.update(tdi, V, self.C.value, self.E.value) self.C.value = self.integral(self.C.value, tdi['t'], V, tdi['dt']) - self.E.value = self._reversal_potential(self.C) + self.E.value = self._reversal_potential(self.C.value) def _reversal_potential(self, C): - return self._constant * jnp.log(self.C0 / C) + return self._constant * bm.log(self.C0 / C) class CalciumDetailed(CalciumDyna): @@ -292,7 +290,7 @@ def __init__( def derivative(self, C, t, V): ICa = self.current(V, C, self.E) - drive = jnp.maximum(- ICa / (2 * self.F * self.d), 0.) + drive = bm.maximum(- ICa / (2 * self.F * self.d), 0.) 
return drive + (self.C_rest - C) / self.tau @@ -335,14 +333,14 @@ def __init__( def derivative(self, C, t, V): ICa = self.current(V, C, self.E) - drive = jnp.maximum(- self.alpha * ICa, 0.) + drive = bm.maximum(- self.alpha * ICa, 0.) return drive - self.beta * C # ------------------------- -class ICa_p2q_ss(CalciumChannel): +class _ICa_p2q_ss(CalciumChannel): r"""The calcium current model of :math:`p^2q` current which described with steady-state format. The dynamics of this generalized calcium current model is given by: @@ -386,10 +384,10 @@ def __init__( mode: bm.Mode = None, name: str = None ): - super(ICa_p2q_ss, self).__init__(size, - keep_size=keep_size, - name=name, - mode=mode, ) + super(_ICa_p2q_ss, self).__init__(size, + keep_size=keep_size, + name=name, + mode=mode, ) # parameters self.phi_p = parameter(phi_p, self.varshape, allow_none=False) @@ -397,8 +395,8 @@ def __init__( self.g_max = parameter(g_max, self.varshape, allow_none=False) # variables - self.p = variable(jnp.zeros, self.mode, self.varshape) - self.q = variable(jnp.zeros, self.mode, self.varshape) + self.p = variable(bm.zeros, self.mode, self.varshape) + self.q = variable(bm.zeros, self.mode, self.varshape) # functions self.integral = odeint(JointEq([self.dp, self.dq]), method=method) @@ -435,7 +433,7 @@ def f_q_tau(self, V): raise NotImplementedError -class ICa_p2q_markov(CalciumChannel): +class _ICa_p2q_markov(CalciumChannel): r"""The calcium current model of :math:`p^2q` current which described with first-order Markov chain. The dynamics of this generalized calcium current model is given by: @@ -479,10 +477,10 @@ def __init__( name: str = None, mode: bm.Mode = None, ): - super(ICa_p2q_markov, self).__init__(size, - keep_size=keep_size, - name=name, - mode=mode) + super(_ICa_p2q_markov, self).__init__(size, + keep_size=keep_size, + name=name, + mode=mode) # parameters self.phi_p = parameter(phi_p, self.varshape, allow_none=False) @@ -490,8 +488,8 @@ def __init__( self.g_max = parameter(g_max, self.varshape, allow_none=False) # variables - self.p = variable(jnp.zeros, self.mode, self.varshape) - self.q = variable(jnp.zeros, self.mode, self.varshape) + self.p = variable(bm.zeros, self.mode, self.varshape) + self.q = variable(bm.zeros, self.mode, self.varshape) # functions self.integral = odeint(JointEq([self.dp, self.dq]), method=method) @@ -592,18 +590,18 @@ def __init__( self.phi = parameter(phi, self.varshape, allow_none=False) # variables - self.p = variable(jnp.zeros, self.mode, self.varshape) + self.p = variable(bm.zeros, self.mode, self.varshape) # function self.integral = odeint(self.derivative, method=method) def derivative(self, p, t, V): - phi_p = 1.0 / (1 + jnp.exp(-(V + 43.) / 5.2)) - p_inf = 2.7 / (jnp.exp(-(V + 55.) / 15.) + jnp.exp((V + 55.) / 15.)) + 1.6 + phi_p = 1.0 / (1 + bm.exp(-(V + 43.) / 5.2)) + p_inf = 2.7 / (bm.exp(-(V + 55.) / 15.) + bm.exp((V + 55.) / 15.)) + 1.6 return self.phi * (phi_p - p) / p_inf def update(self, tdi, V, C_Ca, E_Ca): - self.p.value = self.integral(self.p, tdi['t'], V, tdi['dt']) + self.p.value = self.integral(self.p.value, tdi['t'], V, tdi['dt']) def current(self, V, C_Ca, E_Ca): M = C_Ca / (C_Ca + 0.2) @@ -611,12 +609,12 @@ def current(self, V, C_Ca, E_Ca): return g * (self.E - V) def reset_state(self, V, C_Ca, E_Ca, batch_size=None): - self.p.value = 1.0 / (1 + jnp.exp(-(V + 43.) / 5.2)) + self.p.value = 1.0 / (1 + bm.exp(-(V + 43.) 
/ 5.2)) if batch_size is not None: assert self.p.shape[0] == batch_size -class ICaT_HM1992(ICa_p2q_ss): +class ICaT_HM1992(_ICa_p2q_ss): r"""The low-threshold T-type calcium current model proposed by (Huguenard & McCormick, 1992) [1]_. The dynamics of the low-threshold T-type calcium current model [1]_ is given by: @@ -697,22 +695,22 @@ def __init__( self.V_sh = parameter(V_sh, self.varshape, allow_none=False) def f_p_inf(self, V): - return 1. / (1 + jnp.exp(-(V + 59. - self.V_sh) / 6.2)) + return 1. / (1 + bm.exp(-(V + 59. - self.V_sh) / 6.2)) def f_p_tau(self, V): - return 1. / (jnp.exp(-(V + 132. - self.V_sh) / 16.7) + - jnp.exp((V + 16.8 - self.V_sh) / 18.2)) + 0.612 + return 1. / (bm.exp(-(V + 132. - self.V_sh) / 16.7) + + bm.exp((V + 16.8 - self.V_sh) / 18.2)) + 0.612 def f_q_inf(self, V): - return 1. / (1. + jnp.exp((V + 83. - self.V_sh) / 4.0)) + return 1. / (1. + bm.exp((V + 83. - self.V_sh) / 4.0)) def f_q_tau(self, V): - return jnp.where(V >= (-80. + self.V_sh), - jnp.exp(-(V + 22. - self.V_sh) / 10.5) + 28., - jnp.exp((V + 467. - self.V_sh) / 66.6)) + return bm.where(V >= (-80. + self.V_sh), + bm.exp(-(V + 22. - self.V_sh) / 10.5) + 28., + bm.exp((V + 467. - self.V_sh) / 66.6)) -class ICaT_HP1992(ICa_p2q_ss): +class ICaT_HP1992(_ICa_p2q_ss): r"""The low-threshold T-type calcium current model for thalamic reticular nucleus proposed by (Huguenard & Prince, 1992) [1]_. @@ -795,21 +793,21 @@ def __init__( self.V_sh = parameter(V_sh, self.varshape, allow_none=False) def f_p_inf(self, V): - return 1. / (1. + jnp.exp(-(V + 52. - self.V_sh) / 7.4)) + return 1. / (1. + bm.exp(-(V + 52. - self.V_sh) / 7.4)) def f_p_tau(self, V): - return 3. + 1. / (jnp.exp((V + 27. - self.V_sh) / 10.) + - jnp.exp(-(V + 102. - self.V_sh) / 15.)) + return 3. + 1. / (bm.exp((V + 27. - self.V_sh) / 10.) + + bm.exp(-(V + 102. - self.V_sh) / 15.)) def f_q_inf(self, V): - return 1. / (1. + jnp.exp((V + 80. - self.V_sh) / 5.)) + return 1. / (1. + bm.exp((V + 80. - self.V_sh) / 5.)) def f_q_tau(self, V): - return 85. + 1. / (jnp.exp((V + 48. - self.V_sh) / 4.) + - jnp.exp(-(V + 407. - self.V_sh) / 50.)) + return 85. + 1. / (bm.exp((V + 48. - self.V_sh) / 4.) + + bm.exp(-(V + 407. - self.V_sh) / 50.)) -class ICaHT_HM1992(ICa_p2q_ss): +class ICaHT_HM1992(_ICa_p2q_ss): r"""The high-threshold T-type calcium current model proposed by (Huguenard & McCormick, 1992) [1]_. The high-threshold T-type calcium current model is adopted from [1]_. @@ -886,29 +884,29 @@ def __init__( self.V_sh = parameter(V_sh, self.varshape, allow_none=False) # variables - self.p = variable(jnp.zeros, self.mode, self.varshape) - self.q = variable(jnp.zeros, self.mode, self.varshape) + self.p = variable(bm.zeros, self.mode, self.varshape) + self.q = variable(bm.zeros, self.mode, self.varshape) # function self.integral = odeint(JointEq([self.dp, self.dq]), method=method) def f_p_inf(self, V): - return 1. / (1. + jnp.exp(-(V + 59. - self.V_sh) / 6.2)) + return 1. / (1. + bm.exp(-(V + 59. - self.V_sh) / 6.2)) def f_p_tau(self, V): - return 1. / (jnp.exp(-(V + 132. - self.V_sh) / 16.7) + - jnp.exp((V + 16.8 - self.V_sh) / 18.2)) + 0.612 + return 1. / (bm.exp(-(V + 132. - self.V_sh) / 16.7) + + bm.exp((V + 16.8 - self.V_sh) / 18.2)) + 0.612 def f_q_inf(self, V): - return 1. / (1. + jnp.exp((V + 83. - self.V_sh) / 4.)) + return 1. / (1. + bm.exp((V + 83. - self.V_sh) / 4.)) def f_q_tau(self, V): - return jnp.where(V >= (-80. + self.V_sh), - jnp.exp(-(V + 22. - self.V_sh) / 10.5) + 28., - jnp.exp((V + 467. 
- self.V_sh) / 66.6)) + return bm.where(V >= (-80. + self.V_sh), + bm.exp(-(V + 22. - self.V_sh) / 10.5) + 28., + bm.exp((V + 467. - self.V_sh) / 66.6)) -class ICaHT_Re1993(ICa_p2q_markov): +class ICaHT_Re1993(_ICa_p2q_markov): r"""The high-threshold T-type calcium current model proposed by (Reuveni, et al., 1993) [1]_. HVA Calcium current was described for neocortical neurons by Sayer et al. (1990). @@ -994,19 +992,19 @@ def __init__( def f_p_alpha(self, V): temp = -27 - V + self.V_sh - return 0.055 * temp / (jnp.exp(temp / 3.8) - 1) + return 0.055 * temp / (bm.exp(temp / 3.8) - 1) def f_p_beta(self, V): - return 0.94 * jnp.exp((-75. - V + self.V_sh) / 17.) + return 0.94 * bm.exp((-75. - V + self.V_sh) / 17.) def f_q_alpha(self, V): - return 0.000457 * jnp.exp((-13. - V + self.V_sh) / 50.) + return 0.000457 * bm.exp((-13. - V + self.V_sh) / 50.) def f_q_beta(self, V): - return 0.0065 / (jnp.exp((-15. - V + self.V_sh) / 28.) + 1.) + return 0.0065 / (bm.exp((-15. - V + self.V_sh) / 28.) + 1.) -class ICaL_IS2008(ICa_p2q_ss): +class ICaL_IS2008(_ICa_p2q_ss): r"""The L-type calcium channel model proposed by (Inoue & Strowbridge, 2008) [1]_. The L-type calcium channel model is adopted from (Inoue, et, al., 2008) [1]_. @@ -1080,15 +1078,15 @@ def __init__( self.V_sh = parameter(V_sh, self.varshape, allow_none=False) def f_p_inf(self, V): - return 1. / (1 + jnp.exp(-(V + 10. - self.V_sh) / 4.)) + return 1. / (1 + bm.exp(-(V + 10. - self.V_sh) / 4.)) def f_p_tau(self, V): - return 0.4 + .7 / (jnp.exp(-(V + 5. - self.V_sh) / 15.) + - jnp.exp((V + 5. - self.V_sh) / 15.)) + return 0.4 + .7 / (bm.exp(-(V + 5. - self.V_sh) / 15.) + + bm.exp((V + 5. - self.V_sh) / 15.)) def f_q_inf(self, V): - return 1. / (1. + jnp.exp((V + 25. - self.V_sh) / 2.)) + return 1. / (1. + bm.exp((V + 25. - self.V_sh) / 2.)) def f_q_tau(self, V): - return 300. + 100. / (jnp.exp((V + 40 - self.V_sh) / 9.5) + - jnp.exp(-(V + 40 - self.V_sh) / 9.5)) + return 300. + 100. / (bm.exp((V + 40 - self.V_sh) / 9.5) + + bm.exp(-(V + 40 - self.V_sh) / 9.5)) diff --git a/brainpy/_src/dyn/channels/K.py b/brainpy/_src/dyn/channels/K.py index 25161912f..49e43695f 100644 --- a/brainpy/_src/dyn/channels/K.py +++ b/brainpy/_src/dyn/channels/K.py @@ -16,16 +16,13 @@ from .base import PotassiumChannel __all__ = [ - 'IK_p4_markov', 'IKDR_Ba2002', 'IK_TM1991', 'IK_HH1952', - 'IKA_p4q_ss', 'IKA1_HM1992', 'IKA2_HM1992', - 'IKK2_pq_ss', 'IKK2A_HM1992', 'IKK2B_HM1992', @@ -33,7 +30,7 @@ ] -class IK_p4_markov(PotassiumChannel): +class _IK_p4_markov(PotassiumChannel): r"""The delayed rectifier potassium channel of :math:`p^4` current which described with first-order Markov chain. @@ -79,10 +76,10 @@ def __init__( name: str = None, mode: bm.Mode = None, ): - super(IK_p4_markov, self).__init__(size, - keep_size=keep_size, - name=name, - mode=mode) + super(_IK_p4_markov, self).__init__(size, + keep_size=keep_size, + name=name, + mode=mode) self.E = parameter(E, self.varshape, allow_none=False) self.g_max = parameter(g_max, self.varshape, allow_none=False) @@ -98,7 +95,7 @@ def derivative(self, p, t, V): return self.phi * (self.f_p_alpha(V) * (1. 
- p) - self.f_p_beta(V) * p) def update(self, tdi, V): - self.p.value = self.integral(self.p, tdi['t'], V, tdi['dt']) + self.p.value = self.integral(self.p.value, tdi['t'], V, tdi['dt']) def current(self, V): return self.g_max * self.p ** 4 * (self.E - V) @@ -117,7 +114,7 @@ def f_p_beta(self, V): raise NotImplementedError -class IKDR_Ba2002(IK_p4_markov): +class IKDR_Ba2002(_IK_p4_markov): r"""The delayed rectifier potassium channel current. The potassium current model is adopted from (Bazhenov, et, al. 2002) [1]_. @@ -201,7 +198,7 @@ def f_p_beta(self, V): return 0.5 * jnp.exp(-(V - self.V_sh - 10.) / 40.) -class IK_TM1991(IK_p4_markov): +class IK_TM1991(_IK_p4_markov): r"""The potassium channel described by (Traub and Miles, 1991) [1]_. The dynamics of this channel is given by: @@ -271,7 +268,7 @@ def f_p_beta(self, V): return 0.5 * jnp.exp((10 - V + self.V_sh) / 40) -class IK_HH1952(IK_p4_markov): +class IK_HH1952(_IK_p4_markov): r"""The potassium channel described by Hodgkin–Huxley model [1]_. The dynamics of this channel is given by: @@ -342,7 +339,7 @@ def f_p_beta(self, V): return 0.125 * jnp.exp(-(V - self.V_sh + 20) / 80) -class IKA_p4q_ss(PotassiumChannel): +class _IKA_p4q_ss(PotassiumChannel): r"""The rapidly inactivating Potassium channel of :math:`p^4q` current which described with steady-state format. @@ -396,10 +393,10 @@ def __init__( name: str = None, mode: bm.Mode = None, ): - super(IKA_p4q_ss, self).__init__(size, - keep_size=keep_size, - name=name, - mode=mode) + super(_IKA_p4q_ss, self).__init__(size, + keep_size=keep_size, + name=name, + mode=mode) # parameters self.E = parameter(E, self.varshape, allow_none=False) @@ -447,7 +444,7 @@ def f_q_tau(self, V): raise NotImplementedError -class IKA1_HM1992(IKA_p4q_ss): +class IKA1_HM1992(_IKA_p4q_ss): r"""The rapidly inactivating Potassium channel (IA1) model proposed by (Huguenard & McCormick, 1992) [2]_. This model is developed according to the average behavior of @@ -542,7 +539,7 @@ def f_q_tau(self, V): 19.) -class IKA2_HM1992(IKA_p4q_ss): +class IKA2_HM1992(_IKA_p4q_ss): r"""The rapidly inactivating Potassium channel (IA2) model proposed by (Huguenard & McCormick, 1992) [2]_. This model is developed according to the average behavior of @@ -637,7 +634,7 @@ def f_q_tau(self, V): 19.) -class IKK2_pq_ss(PotassiumChannel): +class _IKK2_pq_ss(PotassiumChannel): r"""The slowly inactivating Potassium channel of :math:`pq` current which described with steady-state format. @@ -691,10 +688,10 @@ def __init__( name: str = None, mode: bm.Mode = None, ): - super(IKK2_pq_ss, self).__init__(size, - keep_size=keep_size, - name=name, - mode=mode) + super(_IKK2_pq_ss, self).__init__(size, + keep_size=keep_size, + name=name, + mode=mode) # parameters self.E = parameter(E, self.varshape, allow_none=False) @@ -742,7 +739,7 @@ def f_q_tau(self, V): raise NotImplementedError -class IKK2A_HM1992(IKK2_pq_ss): +class IKK2A_HM1992(_IKK2_pq_ss): r"""The slowly inactivating Potassium channel (IK2a) model proposed by (Huguenard & McCormick, 1992) [2]_. The dynamics of the model is given as [2]_ [3]_. @@ -831,7 +828,7 @@ def f_q_tau(self, V): jnp.exp(-(V - self.V_sh + 130.) / 7.1)) -class IKK2B_HM1992(IKK2_pq_ss): +class IKK2B_HM1992(_IKK2_pq_ss): r"""The slowly inactivating Potassium channel (IK2b) model proposed by (Huguenard & McCormick, 1992) [2]_. The dynamics of the model is given as [2]_ [3]_. 
diff --git a/brainpy/_src/dyn/channels/KCa.py b/brainpy/_src/dyn/channels/KCa.py index 114b7f067..103561d97 100644 --- a/brainpy/_src/dyn/channels/KCa.py +++ b/brainpy/_src/dyn/channels/KCa.py @@ -107,13 +107,13 @@ def __init__( self.integral = odeint(self.dp, method=method) def dp(self, p, t, C_Ca): - C2 = self.alpha * jnp.power(C_Ca, self.n) + C2 = self.alpha * jnp.power(bm.as_jax(C_Ca), self.n) C3 = C2 + self.beta return self.phi * (C2 / C3 - p) * C3 def update(self, tdi, V, C_Ca, E_Ca): t, dt = tdi['t'], tdi['dt'] - self.p.value = self.integral(self.p, t, C_Ca=C_Ca, dt=dt) + self.p.value = self.integral(self.p.value, t, C_Ca=C_Ca, dt=dt) def current(self, V, C_Ca, E_Ca): return self.g_max * self.p * self.p * (self.E - V) diff --git a/brainpy/_src/dyn/channels/Na.py b/brainpy/_src/dyn/channels/Na.py index 5045591a1..d867d5334 100644 --- a/brainpy/_src/dyn/channels/Na.py +++ b/brainpy/_src/dyn/channels/Na.py @@ -16,14 +16,13 @@ from .base import SodiumChannel __all__ = [ - 'INa_p3q_markov', 'INa_Ba2002', 'INa_TM1991', 'INa_HH1952', ] -class INa_p3q_markov(SodiumChannel): +class _INa_p3q_markov(SodiumChannel): r"""The sodium current model of :math:`p^3q` current which described with first-order Markov chain. The general model can be used to model the dynamics with: @@ -64,10 +63,10 @@ def __init__( name: str = None, mode: bm.Mode = None, ): - super(INa_p3q_markov, self).__init__(size=size, - keep_size=keep_size, - name=name, - mode=mode) + super(_INa_p3q_markov, self).__init__(size=size, + keep_size=keep_size, + name=name, + mode=mode) # parameters self.E = parameter(E, self.varshape, allow_none=False) @@ -119,7 +118,7 @@ def f_q_beta(self, V): raise NotImplementedError -class INa_Ba2002(INa_p3q_markov): +class INa_Ba2002(_INa_p3q_markov): r"""The sodium current model. The sodium current model is adopted from (Bazhenov, et, al. 2002) [1]_. @@ -200,7 +199,7 @@ def f_q_beta(self, V): return 4. / (1. + jnp.exp(-(V - self.V_sh - 40.) / 5.)) -class INa_TM1991(INa_p3q_markov): +class INa_TM1991(_INa_p3q_markov): r"""The sodium current model described by (Traub and Miles, 1991) [1]_. The dynamics of this sodium current model is given by: @@ -286,7 +285,7 @@ def f_q_beta(self, V): return 4. / (1 + jnp.exp(-(V - self.V_sh - 40) / 5)) -class INa_HH1952(INa_p3q_markov): +class INa_HH1952(_INa_p3q_markov): r"""The sodium current model described by Hodgkin–Huxley model [1]_. 
The dynamics of this sodium current model is given by: diff --git a/brainpy/_src/dyn/runners.py b/brainpy/_src/dyn/runners.py index aa5eabbc7..4a567de22 100644 --- a/brainpy/_src/dyn/runners.py +++ b/brainpy/_src/dyn/runners.py @@ -178,7 +178,7 @@ def check_and_format_inputs(host, inputs): # input data if type_ == 'iter': if isinstance(value, (bm.ndarray, np.ndarray, jnp.ndarray)): - array_inputs[op].append([variable, jnp.asarray(value)]) + array_inputs[op].append([variable, bm.as_jax(value)]) else: next_inputs[op].append([variable, iter(value)]) elif type_ == 'func': @@ -547,7 +547,7 @@ def _step_func_monitor(self, shared): if idx is None: res[key] = variable.value else: - res[key] = variable[jnp.asarray(idx)] + res[key] = variable[bm.as_jax(idx)] return res def _step_func_input(self, shared): From 75f6fceaa279c4d01b0d70f434c45dd0cb988f10 Mon Sep 17 00:00:00 2001 From: chaoming Date: Sat, 14 Jan 2023 21:11:45 +0800 Subject: [PATCH 11/13] updates --- brainpy/__init__.py | 712 +++++++++--------- brainpy/_src/connect/random_conn.py | 2 +- brainpy/_src/dyn/layers/linear.py | 17 +- brainpy/_src/dyn/layers/reservoir.py | 49 +- brainpy/_src/initialize/random_inits.py | 10 +- .../_src/optimizers/tests/test_scheduler.py | 130 ++-- brainpy/math/__init__.py | 45 +- .../dynamics_analysis/2d_mean_field_QIF.py | 4 +- .../2d_wilson_cowan_model.py | 4 +- .../dynamics_analysis/highdim_RNN_Analysis.py | 2 +- ...2017_unified_thalamus_oscillation_model.py | 17 +- .../Sanda_2021_hippo-tha-cortex-model.py | 2 +- examples/dynamics_simulation/hh_model.py | 10 +- .../dynamics_simulation/multi_scale_COBAHH.py | 6 +- ...Bellec_2020_eprop_evidence_accumulation.py | 6 +- .../dynamics_training/Song_2016_EI_RNN.py | 2 +- .../Sussillo_Abbott_2009_FORCE_Learning.py | 8 +- .../dynamics_training/echo_state_network.py | 15 +- examples/dynamics_training/reservoir-mnist.py | 129 ++++ .../SurrogateGrad_lif-ANN-style.py | 8 +- .../training_snn_models/SurrogateGrad_lif.py | 15 +- .../spikebased_bp_for_cifar10.py | 3 +- 22 files changed, 647 insertions(+), 549 deletions(-) create mode 100644 examples/dynamics_training/reservoir-mnist.py diff --git a/brainpy/__init__.py b/brainpy/__init__.py index f25604e70..a058532e5 100644 --- a/brainpy/__init__.py +++ b/brainpy/__init__.py @@ -237,359 +237,359 @@ del compat -import brainpy._src.math.arraycompatible as bm -math.__dict__['full'] = bm.full -math.__dict__['full_like'] = bm.full_like -math.__dict__['eye'] = bm.eye -math.__dict__['identity'] = bm.identity -math.__dict__['diag'] = bm.diag -math.__dict__['tri'] = bm.tri -math.__dict__['tril'] = bm.tril -math.__dict__['triu'] = bm.triu -math.__dict__['real'] = bm.real -math.__dict__['imag'] = bm.imag -math.__dict__['conj'] = bm.conj -math.__dict__['conjugate'] = bm.conjugate -math.__dict__['ndim'] = bm.ndim -math.__dict__['isreal'] = bm.isreal -math.__dict__['isscalar'] = bm.isscalar -math.__dict__['add'] = bm.add -math.__dict__['reciprocal'] = bm.reciprocal -math.__dict__['negative'] = bm.negative -math.__dict__['positive'] = bm.positive -math.__dict__['multiply'] = bm.multiply -math.__dict__['divide'] = bm.divide -math.__dict__['power'] = bm.power -math.__dict__['subtract'] = bm.subtract -math.__dict__['true_divide'] = bm.true_divide -math.__dict__['floor_divide'] = bm.floor_divide -math.__dict__['float_power'] = bm.float_power -math.__dict__['fmod'] = bm.fmod -math.__dict__['mod'] = bm.mod -math.__dict__['modf'] = bm.modf -math.__dict__['divmod'] = bm.divmod -math.__dict__['remainder'] = bm.remainder -math.__dict__['abs'] = bm.abs 
-math.__dict__['exp'] = bm.exp -math.__dict__['exp2'] = bm.exp2 -math.__dict__['expm1'] = bm.expm1 -math.__dict__['log'] = bm.log -math.__dict__['log10'] = bm.log10 -math.__dict__['log1p'] = bm.log1p -math.__dict__['log2'] = bm.log2 -math.__dict__['logaddexp'] = bm.logaddexp -math.__dict__['logaddexp2'] = bm.logaddexp2 -math.__dict__['lcm'] = bm.lcm -math.__dict__['gcd'] = bm.gcd -math.__dict__['arccos'] = bm.arccos -math.__dict__['arccosh'] = bm.arccosh -math.__dict__['arcsin'] = bm.arcsin -math.__dict__['arcsinh'] = bm.arcsinh -math.__dict__['arctan'] = bm.arctan -math.__dict__['arctan2'] = bm.arctan2 -math.__dict__['arctanh'] = bm.arctanh -math.__dict__['cos'] = bm.cos -math.__dict__['cosh'] = bm.cosh -math.__dict__['sin'] = bm.sin -math.__dict__['sinc'] = bm.sinc -math.__dict__['sinh'] = bm.sinh -math.__dict__['tan'] = bm.tan -math.__dict__['tanh'] = bm.tanh -math.__dict__['deg2rad'] = bm.deg2rad -math.__dict__['hypot'] = bm.hypot -math.__dict__['rad2deg'] = bm.rad2deg -math.__dict__['degrees'] = bm.degrees -math.__dict__['radians'] = bm.radians -math.__dict__['round'] = bm.round -math.__dict__['around'] = bm.around -math.__dict__['round_'] = bm.round_ -math.__dict__['rint'] = bm.rint -math.__dict__['floor'] = bm.floor -math.__dict__['ceil'] = bm.ceil -math.__dict__['trunc'] = bm.trunc -math.__dict__['fix'] = bm.fix -math.__dict__['prod'] = bm.prod -math.__dict__['sum'] = bm.sum -math.__dict__['diff'] = bm.diff -math.__dict__['median'] = bm.median -math.__dict__['nancumprod'] = bm.nancumprod -math.__dict__['nancumsum'] = bm.nancumsum -math.__dict__['nanprod'] = bm.nanprod -math.__dict__['nansum'] = bm.nansum -math.__dict__['cumprod'] = bm.cumprod -math.__dict__['cumsum'] = bm.cumsum -math.__dict__['ediff1d'] = bm.ediff1d -math.__dict__['cross'] = bm.cross -math.__dict__['trapz'] = bm.trapz -math.__dict__['isfinite'] = bm.isfinite -math.__dict__['isinf'] = bm.isinf -math.__dict__['isnan'] = bm.isnan -math.__dict__['signbit'] = bm.signbit -math.__dict__['copysign'] = bm.copysign -math.__dict__['nextafter'] = bm.nextafter -math.__dict__['ldexp'] = bm.ldexp -math.__dict__['frexp'] = bm.frexp -math.__dict__['convolve'] = bm.convolve -math.__dict__['sqrt'] = bm.sqrt -math.__dict__['cbrt'] = bm.cbrt -math.__dict__['square'] = bm.square -math.__dict__['absolute'] = bm.absolute -math.__dict__['fabs'] = bm.fabs -math.__dict__['sign'] = bm.sign -math.__dict__['heaviside'] = bm.heaviside -math.__dict__['maximum'] = bm.maximum -math.__dict__['minimum'] = bm.minimum -math.__dict__['fmax'] = bm.fmax -math.__dict__['fmin'] = bm.fmin -math.__dict__['interp'] = bm.interp -math.__dict__['clip'] = bm.clip -math.__dict__['angle'] = bm.angle -math.__dict__['bitwise_and'] = bm.bitwise_and -math.__dict__['bitwise_not'] = bm.bitwise_not -math.__dict__['bitwise_or'] = bm.bitwise_or -math.__dict__['bitwise_xor'] = bm.bitwise_xor -math.__dict__['invert'] = bm.invert -math.__dict__['left_shift'] = bm.left_shift -math.__dict__['right_shift'] = bm.right_shift -math.__dict__['equal'] = bm.equal -math.__dict__['not_equal'] = bm.not_equal -math.__dict__['greater'] = bm.greater -math.__dict__['greater_equal'] = bm.greater_equal -math.__dict__['less'] = bm.less -math.__dict__['less_equal'] = bm.less_equal -math.__dict__['array_equal'] = bm.array_equal -math.__dict__['isclose'] = bm.isclose -math.__dict__['allclose'] = bm.allclose -math.__dict__['logical_and'] = bm.logical_and -math.__dict__['logical_not'] = bm.logical_not -math.__dict__['logical_or'] = bm.logical_or -math.__dict__['logical_xor'] = bm.logical_xor 
-math.__dict__['all'] = bm.all -math.__dict__['any'] = bm.any -math.__dict__['alltrue'] = bm.alltrue -math.__dict__['sometrue'] = bm.sometrue -math.__dict__['shape'] = bm.shape -math.__dict__['size'] = bm.size -math.__dict__['reshape'] = bm.reshape -math.__dict__['ravel'] = bm.ravel -math.__dict__['moveaxis'] = bm.moveaxis -math.__dict__['transpose'] = bm.transpose -math.__dict__['swapaxes'] = bm.swapaxes -math.__dict__['concatenate'] = bm.concatenate -math.__dict__['stack'] = bm.stack -math.__dict__['vstack'] = bm.vstack -math.__dict__['hstack'] = bm.hstack -math.__dict__['dstack'] = bm.dstack -math.__dict__['column_stack'] = bm.column_stack -math.__dict__['split'] = bm.split -math.__dict__['dsplit'] = bm.dsplit -math.__dict__['hsplit'] = bm.hsplit -math.__dict__['vsplit'] = bm.vsplit -math.__dict__['tile'] = bm.tile -math.__dict__['repeat'] = bm.repeat -math.__dict__['unique'] = bm.unique -math.__dict__['append'] = bm.append -math.__dict__['flip'] = bm.flip -math.__dict__['fliplr'] = bm.fliplr -math.__dict__['flipud'] = bm.flipud -math.__dict__['roll'] = bm.roll -math.__dict__['atleast_1d'] = bm.atleast_1d -math.__dict__['atleast_2d'] = bm.atleast_2d -math.__dict__['atleast_3d'] = bm.atleast_3d -math.__dict__['expand_dims'] = bm.expand_dims -math.__dict__['squeeze'] = bm.squeeze -math.__dict__['sort'] = bm.sort -math.__dict__['argsort'] = bm.argsort -math.__dict__['argmax'] = bm.argmax -math.__dict__['argmin'] = bm.argmin -math.__dict__['argwhere'] = bm.argwhere -math.__dict__['nonzero'] = bm.nonzero -math.__dict__['flatnonzero'] = bm.flatnonzero -math.__dict__['where'] = bm.where -math.__dict__['searchsorted'] = bm.searchsorted -math.__dict__['extract'] = bm.extract -math.__dict__['count_nonzero'] = bm.count_nonzero -math.__dict__['max'] = bm.max -math.__dict__['min'] = bm.min -math.__dict__['amax'] = bm.amax -math.__dict__['amin'] = bm.amin -math.__dict__['array_split'] = bm.array_split -math.__dict__['meshgrid'] = bm.meshgrid -math.__dict__['vander'] = bm.vander -math.__dict__['nonzero'] = bm.nonzero -math.__dict__['where'] = bm.where -math.__dict__['tril_indices'] = bm.tril_indices -math.__dict__['tril_indices_from'] = bm.tril_indices_from -math.__dict__['triu_indices'] = bm.triu_indices -math.__dict__['triu_indices_from'] = bm.triu_indices_from -math.__dict__['take'] = bm.take -math.__dict__['select'] = bm.select -math.__dict__['nanmin'] = bm.nanmin -math.__dict__['nanmax'] = bm.nanmax -math.__dict__['ptp'] = bm.ptp -math.__dict__['percentile'] = bm.percentile -math.__dict__['nanpercentile'] = bm.nanpercentile -math.__dict__['quantile'] = bm.quantile -math.__dict__['nanquantile'] = bm.nanquantile -math.__dict__['median'] = bm.median -math.__dict__['average'] = bm.average -math.__dict__['mean'] = bm.mean -math.__dict__['std'] = bm.std -math.__dict__['var'] = bm.var -math.__dict__['nanmedian'] = bm.nanmedian -math.__dict__['nanmean'] = bm.nanmean -math.__dict__['nanstd'] = bm.nanstd -math.__dict__['nanvar'] = bm.nanvar -math.__dict__['corrcoef'] = bm.corrcoef -math.__dict__['correlate'] = bm.correlate -math.__dict__['cov'] = bm.cov -math.__dict__['histogram'] = bm.histogram -math.__dict__['bincount'] = bm.bincount -math.__dict__['digitize'] = bm.digitize -math.__dict__['bartlett'] = bm.bartlett -math.__dict__['blackman'] = bm.blackman -math.__dict__['hamming'] = bm.hamming -math.__dict__['hanning'] = bm.hanning -math.__dict__['kaiser'] = bm.kaiser -math.__dict__['e'] = bm.e -math.__dict__['pi'] = bm.pi -math.__dict__['inf'] = bm.inf -math.__dict__['dot'] = bm.dot 
-math.__dict__['vdot'] = bm.vdot -math.__dict__['inner'] = bm.inner -math.__dict__['outer'] = bm.outer -math.__dict__['kron'] = bm.kron -math.__dict__['matmul'] = bm.matmul -math.__dict__['trace'] = bm.trace -math.__dict__['dtype'] = bm.dtype -math.__dict__['finfo'] = bm.finfo -math.__dict__['iinfo'] = bm.iinfo -math.__dict__['uint8'] = bm.uint8 -math.__dict__['uint16'] = bm.uint16 -math.__dict__['uint32'] = bm.uint32 -math.__dict__['uint64'] = bm.uint64 -math.__dict__['int8'] = bm.int8 -math.__dict__['int16'] = bm.int16 -math.__dict__['int32'] = bm.int32 -math.__dict__['int64'] = bm.int64 -math.__dict__['float16'] = bm.float16 -math.__dict__['float32'] = bm.float32 -math.__dict__['float64'] = bm.float64 -math.__dict__['complex64'] = bm.complex64 -math.__dict__['complex128'] = bm.complex128 -math.__dict__['product'] = bm.product -math.__dict__['row_stack'] = bm.row_stack -math.__dict__['apply_over_axes'] = bm.apply_over_axes -math.__dict__['apply_along_axis'] = bm.apply_along_axis -math.__dict__['array_equiv'] = bm.array_equiv -math.__dict__['array_repr'] = bm.array_repr -math.__dict__['array_str'] = bm.array_str -math.__dict__['block'] = bm.block -math.__dict__['broadcast_arrays'] = bm.broadcast_arrays -math.__dict__['broadcast_shapes'] = bm.broadcast_shapes -math.__dict__['broadcast_to'] = bm.broadcast_to -math.__dict__['compress'] = bm.compress -math.__dict__['cumproduct'] = bm.cumproduct -math.__dict__['diag_indices'] = bm.diag_indices -math.__dict__['diag_indices_from'] = bm.diag_indices_from -math.__dict__['diagflat'] = bm.diagflat -math.__dict__['diagonal'] = bm.diagonal -math.__dict__['einsum'] = bm.einsum -math.__dict__['einsum_path'] = bm.einsum_path -math.__dict__['geomspace'] = bm.geomspace -math.__dict__['gradient'] = bm.gradient -math.__dict__['histogram2d'] = bm.histogram2d -math.__dict__['histogram_bin_edges'] = bm.histogram_bin_edges -math.__dict__['histogramdd'] = bm.histogramdd -math.__dict__['i0'] = bm.i0 -math.__dict__['in1d'] = bm.in1d -math.__dict__['indices'] = bm.indices -math.__dict__['insert'] = bm.insert -math.__dict__['intersect1d'] = bm.intersect1d -math.__dict__['iscomplex'] = bm.iscomplex -math.__dict__['isin'] = bm.isin -math.__dict__['ix_'] = bm.ix_ -math.__dict__['lexsort'] = bm.lexsort -math.__dict__['load'] = bm.load -math.__dict__['save'] = bm.save -math.__dict__['savez'] = bm.savez -math.__dict__['mask_indices'] = bm.mask_indices -math.__dict__['msort'] = bm.msort -math.__dict__['nan_to_num'] = bm.nan_to_num -math.__dict__['nanargmax'] = bm.nanargmax -math.__dict__['setdiff1d'] = bm.setdiff1d -math.__dict__['nanargmin'] = bm.nanargmin -math.__dict__['pad'] = bm.pad -math.__dict__['poly'] = bm.poly -math.__dict__['polyadd'] = bm.polyadd -math.__dict__['polyder'] = bm.polyder -math.__dict__['polyfit'] = bm.polyfit -math.__dict__['polyint'] = bm.polyint -math.__dict__['polymul'] = bm.polymul -math.__dict__['polysub'] = bm.polysub -math.__dict__['polyval'] = bm.polyval -math.__dict__['resize'] = bm.resize -math.__dict__['rollaxis'] = bm.rollaxis -math.__dict__['roots'] = bm.roots -math.__dict__['rot90'] = bm.rot90 -math.__dict__['setxor1d'] = bm.setxor1d -math.__dict__['tensordot'] = bm.tensordot -math.__dict__['trim_zeros'] = bm.trim_zeros -math.__dict__['union1d'] = bm.union1d -math.__dict__['unravel_index'] = bm.unravel_index -math.__dict__['unwrap'] = bm.unwrap -math.__dict__['take_along_axis'] = bm.take_along_axis -math.__dict__['can_cast'] = bm.can_cast -math.__dict__['choose'] = bm.choose -math.__dict__['copy'] = bm.copy 
-math.__dict__['frombuffer'] = bm.frombuffer -math.__dict__['fromfile'] = bm.fromfile -math.__dict__['fromfunction'] = bm.fromfunction -math.__dict__['fromiter'] = bm.fromiter -math.__dict__['fromstring'] = bm.fromstring -math.__dict__['get_printoptions'] = bm.get_printoptions -math.__dict__['iscomplexobj'] = bm.iscomplexobj -math.__dict__['isneginf'] = bm.isneginf -math.__dict__['isposinf'] = bm.isposinf -math.__dict__['isrealobj'] = bm.isrealobj -math.__dict__['issubdtype'] = bm.issubdtype -math.__dict__['issubsctype'] = bm.issubsctype -math.__dict__['iterable'] = bm.iterable -math.__dict__['packbits'] = bm.packbits -math.__dict__['piecewise'] = bm.piecewise -math.__dict__['printoptions'] = bm.printoptions -math.__dict__['set_printoptions'] = bm.set_printoptions -math.__dict__['promote_types'] = bm.promote_types -math.__dict__['ravel_multi_index'] = bm.ravel_multi_index -math.__dict__['result_type'] = bm.result_type -math.__dict__['sort_complex'] = bm.sort_complex -math.__dict__['unpackbits'] = bm.unpackbits -math.__dict__['delete'] = bm.delete -math.__dict__['add_docstring'] = bm.add_docstring -math.__dict__['add_newdoc'] = bm.add_newdoc -math.__dict__['add_newdoc_ufunc'] = bm.add_newdoc_ufunc -math.__dict__['array2string'] = bm.array2string -math.__dict__['asanyarray'] = bm.asanyarray -math.__dict__['ascontiguousarray'] = bm.ascontiguousarray -math.__dict__['asfarray'] = bm.asfarray -math.__dict__['asscalar'] = bm.asscalar -math.__dict__['common_type'] = bm.common_type -math.__dict__['disp'] = bm.disp -math.__dict__['genfromtxt'] = bm.genfromtxt -math.__dict__['loadtxt'] = bm.loadtxt -math.__dict__['info'] = bm.info -math.__dict__['issubclass_'] = bm.issubclass_ -math.__dict__['place'] = bm.place -math.__dict__['polydiv'] = bm.polydiv -math.__dict__['put'] = bm.put -math.__dict__['putmask'] = bm.putmask -math.__dict__['safe_eval'] = bm.safe_eval -math.__dict__['savetxt'] = bm.savetxt -math.__dict__['savez_compressed'] = bm.savez_compressed -math.__dict__['show_config'] = bm.show_config -math.__dict__['typename'] = bm.typename -math.__dict__['copyto'] = bm.copyto -math.__dict__['matrix'] = bm.matrix -math.__dict__['asmatrix'] = bm.asmatrix -math.__dict__['mat'] = bm.mat -del bm +# import brainpy._src.math.arraycompatible as bm +# math.__dict__['full'] = bm.full +# math.__dict__['full_like'] = bm.full_like +# math.__dict__['eye'] = bm.eye +# math.__dict__['identity'] = bm.identity +# math.__dict__['diag'] = bm.diag +# math.__dict__['tri'] = bm.tri +# math.__dict__['tril'] = bm.tril +# math.__dict__['triu'] = bm.triu +# math.__dict__['real'] = bm.real +# math.__dict__['imag'] = bm.imag +# math.__dict__['conj'] = bm.conj +# math.__dict__['conjugate'] = bm.conjugate +# math.__dict__['ndim'] = bm.ndim +# math.__dict__['isreal'] = bm.isreal +# math.__dict__['isscalar'] = bm.isscalar +# math.__dict__['add'] = bm.add +# math.__dict__['reciprocal'] = bm.reciprocal +# math.__dict__['negative'] = bm.negative +# math.__dict__['positive'] = bm.positive +# math.__dict__['multiply'] = bm.multiply +# math.__dict__['divide'] = bm.divide +# math.__dict__['power'] = bm.power +# math.__dict__['subtract'] = bm.subtract +# math.__dict__['true_divide'] = bm.true_divide +# math.__dict__['floor_divide'] = bm.floor_divide +# math.__dict__['float_power'] = bm.float_power +# math.__dict__['fmod'] = bm.fmod +# math.__dict__['mod'] = bm.mod +# math.__dict__['modf'] = bm.modf +# math.__dict__['divmod'] = bm.divmod +# math.__dict__['remainder'] = bm.remainder +# math.__dict__['abs'] = bm.abs +# math.__dict__['exp'] = 
bm.exp +# math.__dict__['exp2'] = bm.exp2 +# math.__dict__['expm1'] = bm.expm1 +# math.__dict__['log'] = bm.log +# math.__dict__['log10'] = bm.log10 +# math.__dict__['log1p'] = bm.log1p +# math.__dict__['log2'] = bm.log2 +# math.__dict__['logaddexp'] = bm.logaddexp +# math.__dict__['logaddexp2'] = bm.logaddexp2 +# math.__dict__['lcm'] = bm.lcm +# math.__dict__['gcd'] = bm.gcd +# math.__dict__['arccos'] = bm.arccos +# math.__dict__['arccosh'] = bm.arccosh +# math.__dict__['arcsin'] = bm.arcsin +# math.__dict__['arcsinh'] = bm.arcsinh +# math.__dict__['arctan'] = bm.arctan +# math.__dict__['arctan2'] = bm.arctan2 +# math.__dict__['arctanh'] = bm.arctanh +# math.__dict__['cos'] = bm.cos +# math.__dict__['cosh'] = bm.cosh +# math.__dict__['sin'] = bm.sin +# math.__dict__['sinc'] = bm.sinc +# math.__dict__['sinh'] = bm.sinh +# math.__dict__['tan'] = bm.tan +# math.__dict__['tanh'] = bm.tanh +# math.__dict__['deg2rad'] = bm.deg2rad +# math.__dict__['hypot'] = bm.hypot +# math.__dict__['rad2deg'] = bm.rad2deg +# math.__dict__['degrees'] = bm.degrees +# math.__dict__['radians'] = bm.radians +# math.__dict__['round'] = bm.round +# math.__dict__['around'] = bm.around +# math.__dict__['round_'] = bm.round_ +# math.__dict__['rint'] = bm.rint +# math.__dict__['floor'] = bm.floor +# math.__dict__['ceil'] = bm.ceil +# math.__dict__['trunc'] = bm.trunc +# math.__dict__['fix'] = bm.fix +# math.__dict__['prod'] = bm.prod +# math.__dict__['sum'] = bm.sum +# math.__dict__['diff'] = bm.diff +# math.__dict__['median'] = bm.median +# math.__dict__['nancumprod'] = bm.nancumprod +# math.__dict__['nancumsum'] = bm.nancumsum +# math.__dict__['nanprod'] = bm.nanprod +# math.__dict__['nansum'] = bm.nansum +# math.__dict__['cumprod'] = bm.cumprod +# math.__dict__['cumsum'] = bm.cumsum +# math.__dict__['ediff1d'] = bm.ediff1d +# math.__dict__['cross'] = bm.cross +# math.__dict__['trapz'] = bm.trapz +# math.__dict__['isfinite'] = bm.isfinite +# math.__dict__['isinf'] = bm.isinf +# math.__dict__['isnan'] = bm.isnan +# math.__dict__['signbit'] = bm.signbit +# math.__dict__['copysign'] = bm.copysign +# math.__dict__['nextafter'] = bm.nextafter +# math.__dict__['ldexp'] = bm.ldexp +# math.__dict__['frexp'] = bm.frexp +# math.__dict__['convolve'] = bm.convolve +# math.__dict__['sqrt'] = bm.sqrt +# math.__dict__['cbrt'] = bm.cbrt +# math.__dict__['square'] = bm.square +# math.__dict__['absolute'] = bm.absolute +# math.__dict__['fabs'] = bm.fabs +# math.__dict__['sign'] = bm.sign +# math.__dict__['heaviside'] = bm.heaviside +# math.__dict__['maximum'] = bm.maximum +# math.__dict__['minimum'] = bm.minimum +# math.__dict__['fmax'] = bm.fmax +# math.__dict__['fmin'] = bm.fmin +# math.__dict__['interp'] = bm.interp +# math.__dict__['clip'] = bm.clip +# math.__dict__['angle'] = bm.angle +# math.__dict__['bitwise_and'] = bm.bitwise_and +# math.__dict__['bitwise_not'] = bm.bitwise_not +# math.__dict__['bitwise_or'] = bm.bitwise_or +# math.__dict__['bitwise_xor'] = bm.bitwise_xor +# math.__dict__['invert'] = bm.invert +# math.__dict__['left_shift'] = bm.left_shift +# math.__dict__['right_shift'] = bm.right_shift +# math.__dict__['equal'] = bm.equal +# math.__dict__['not_equal'] = bm.not_equal +# math.__dict__['greater'] = bm.greater +# math.__dict__['greater_equal'] = bm.greater_equal +# math.__dict__['less'] = bm.less +# math.__dict__['less_equal'] = bm.less_equal +# math.__dict__['array_equal'] = bm.array_equal +# math.__dict__['isclose'] = bm.isclose +# math.__dict__['allclose'] = bm.allclose +# math.__dict__['logical_and'] = 
bm.logical_and +# math.__dict__['logical_not'] = bm.logical_not +# math.__dict__['logical_or'] = bm.logical_or +# math.__dict__['logical_xor'] = bm.logical_xor +# math.__dict__['all'] = bm.all +# math.__dict__['any'] = bm.any +# math.__dict__['alltrue'] = bm.alltrue +# math.__dict__['sometrue'] = bm.sometrue +# math.__dict__['shape'] = bm.shape +# math.__dict__['size'] = bm.size +# math.__dict__['reshape'] = bm.reshape +# math.__dict__['ravel'] = bm.ravel +# math.__dict__['moveaxis'] = bm.moveaxis +# math.__dict__['transpose'] = bm.transpose +# math.__dict__['swapaxes'] = bm.swapaxes +# math.__dict__['concatenate'] = bm.concatenate +# math.__dict__['stack'] = bm.stack +# math.__dict__['vstack'] = bm.vstack +# math.__dict__['hstack'] = bm.hstack +# math.__dict__['dstack'] = bm.dstack +# math.__dict__['column_stack'] = bm.column_stack +# math.__dict__['split'] = bm.split +# math.__dict__['dsplit'] = bm.dsplit +# math.__dict__['hsplit'] = bm.hsplit +# math.__dict__['vsplit'] = bm.vsplit +# math.__dict__['tile'] = bm.tile +# math.__dict__['repeat'] = bm.repeat +# math.__dict__['unique'] = bm.unique +# math.__dict__['append'] = bm.append +# math.__dict__['flip'] = bm.flip +# math.__dict__['fliplr'] = bm.fliplr +# math.__dict__['flipud'] = bm.flipud +# math.__dict__['roll'] = bm.roll +# math.__dict__['atleast_1d'] = bm.atleast_1d +# math.__dict__['atleast_2d'] = bm.atleast_2d +# math.__dict__['atleast_3d'] = bm.atleast_3d +# math.__dict__['expand_dims'] = bm.expand_dims +# math.__dict__['squeeze'] = bm.squeeze +# math.__dict__['sort'] = bm.sort +# math.__dict__['argsort'] = bm.argsort +# math.__dict__['argmax'] = bm.argmax +# math.__dict__['argmin'] = bm.argmin +# math.__dict__['argwhere'] = bm.argwhere +# math.__dict__['nonzero'] = bm.nonzero +# math.__dict__['flatnonzero'] = bm.flatnonzero +# math.__dict__['where'] = bm.where +# math.__dict__['searchsorted'] = bm.searchsorted +# math.__dict__['extract'] = bm.extract +# math.__dict__['count_nonzero'] = bm.count_nonzero +# math.__dict__['max'] = bm.max +# math.__dict__['min'] = bm.min +# math.__dict__['amax'] = bm.amax +# math.__dict__['amin'] = bm.amin +# math.__dict__['array_split'] = bm.array_split +# math.__dict__['meshgrid'] = bm.meshgrid +# math.__dict__['vander'] = bm.vander +# math.__dict__['nonzero'] = bm.nonzero +# math.__dict__['where'] = bm.where +# math.__dict__['tril_indices'] = bm.tril_indices +# math.__dict__['tril_indices_from'] = bm.tril_indices_from +# math.__dict__['triu_indices'] = bm.triu_indices +# math.__dict__['triu_indices_from'] = bm.triu_indices_from +# math.__dict__['take'] = bm.take +# math.__dict__['select'] = bm.select +# math.__dict__['nanmin'] = bm.nanmin +# math.__dict__['nanmax'] = bm.nanmax +# math.__dict__['ptp'] = bm.ptp +# math.__dict__['percentile'] = bm.percentile +# math.__dict__['nanpercentile'] = bm.nanpercentile +# math.__dict__['quantile'] = bm.quantile +# math.__dict__['nanquantile'] = bm.nanquantile +# math.__dict__['median'] = bm.median +# math.__dict__['average'] = bm.average +# math.__dict__['mean'] = bm.mean +# math.__dict__['std'] = bm.std +# math.__dict__['var'] = bm.var +# math.__dict__['nanmedian'] = bm.nanmedian +# math.__dict__['nanmean'] = bm.nanmean +# math.__dict__['nanstd'] = bm.nanstd +# math.__dict__['nanvar'] = bm.nanvar +# math.__dict__['corrcoef'] = bm.corrcoef +# math.__dict__['correlate'] = bm.correlate +# math.__dict__['cov'] = bm.cov +# math.__dict__['histogram'] = bm.histogram +# math.__dict__['bincount'] = bm.bincount +# math.__dict__['digitize'] = bm.digitize +# 
math.__dict__['bartlett'] = bm.bartlett +# math.__dict__['blackman'] = bm.blackman +# math.__dict__['hamming'] = bm.hamming +# math.__dict__['hanning'] = bm.hanning +# math.__dict__['kaiser'] = bm.kaiser +# math.__dict__['e'] = bm.e +# math.__dict__['pi'] = bm.pi +# math.__dict__['inf'] = bm.inf +# math.__dict__['dot'] = bm.dot +# math.__dict__['vdot'] = bm.vdot +# math.__dict__['inner'] = bm.inner +# math.__dict__['outer'] = bm.outer +# math.__dict__['kron'] = bm.kron +# math.__dict__['matmul'] = bm.matmul +# math.__dict__['trace'] = bm.trace +# math.__dict__['dtype'] = bm.dtype +# math.__dict__['finfo'] = bm.finfo +# math.__dict__['iinfo'] = bm.iinfo +# math.__dict__['uint8'] = bm.uint8 +# math.__dict__['uint16'] = bm.uint16 +# math.__dict__['uint32'] = bm.uint32 +# math.__dict__['uint64'] = bm.uint64 +# math.__dict__['int8'] = bm.int8 +# math.__dict__['int16'] = bm.int16 +# math.__dict__['int32'] = bm.int32 +# math.__dict__['int64'] = bm.int64 +# math.__dict__['float16'] = bm.float16 +# math.__dict__['float32'] = bm.float32 +# math.__dict__['float64'] = bm.float64 +# math.__dict__['complex64'] = bm.complex64 +# math.__dict__['complex128'] = bm.complex128 +# math.__dict__['product'] = bm.product +# math.__dict__['row_stack'] = bm.row_stack +# math.__dict__['apply_over_axes'] = bm.apply_over_axes +# math.__dict__['apply_along_axis'] = bm.apply_along_axis +# math.__dict__['array_equiv'] = bm.array_equiv +# math.__dict__['array_repr'] = bm.array_repr +# math.__dict__['array_str'] = bm.array_str +# math.__dict__['block'] = bm.block +# math.__dict__['broadcast_arrays'] = bm.broadcast_arrays +# math.__dict__['broadcast_shapes'] = bm.broadcast_shapes +# math.__dict__['broadcast_to'] = bm.broadcast_to +# math.__dict__['compress'] = bm.compress +# math.__dict__['cumproduct'] = bm.cumproduct +# math.__dict__['diag_indices'] = bm.diag_indices +# math.__dict__['diag_indices_from'] = bm.diag_indices_from +# math.__dict__['diagflat'] = bm.diagflat +# math.__dict__['diagonal'] = bm.diagonal +# math.__dict__['einsum'] = bm.einsum +# math.__dict__['einsum_path'] = bm.einsum_path +# math.__dict__['geomspace'] = bm.geomspace +# math.__dict__['gradient'] = bm.gradient +# math.__dict__['histogram2d'] = bm.histogram2d +# math.__dict__['histogram_bin_edges'] = bm.histogram_bin_edges +# math.__dict__['histogramdd'] = bm.histogramdd +# math.__dict__['i0'] = bm.i0 +# math.__dict__['in1d'] = bm.in1d +# math.__dict__['indices'] = bm.indices +# math.__dict__['insert'] = bm.insert +# math.__dict__['intersect1d'] = bm.intersect1d +# math.__dict__['iscomplex'] = bm.iscomplex +# math.__dict__['isin'] = bm.isin +# math.__dict__['ix_'] = bm.ix_ +# math.__dict__['lexsort'] = bm.lexsort +# math.__dict__['load'] = bm.load +# math.__dict__['save'] = bm.save +# math.__dict__['savez'] = bm.savez +# math.__dict__['mask_indices'] = bm.mask_indices +# math.__dict__['msort'] = bm.msort +# math.__dict__['nan_to_num'] = bm.nan_to_num +# math.__dict__['nanargmax'] = bm.nanargmax +# math.__dict__['setdiff1d'] = bm.setdiff1d +# math.__dict__['nanargmin'] = bm.nanargmin +# math.__dict__['pad'] = bm.pad +# math.__dict__['poly'] = bm.poly +# math.__dict__['polyadd'] = bm.polyadd +# math.__dict__['polyder'] = bm.polyder +# math.__dict__['polyfit'] = bm.polyfit +# math.__dict__['polyint'] = bm.polyint +# math.__dict__['polymul'] = bm.polymul +# math.__dict__['polysub'] = bm.polysub +# math.__dict__['polyval'] = bm.polyval +# math.__dict__['resize'] = bm.resize +# math.__dict__['rollaxis'] = bm.rollaxis +# math.__dict__['roots'] = bm.roots +# 
math.__dict__['rot90'] = bm.rot90 +# math.__dict__['setxor1d'] = bm.setxor1d +# math.__dict__['tensordot'] = bm.tensordot +# math.__dict__['trim_zeros'] = bm.trim_zeros +# math.__dict__['union1d'] = bm.union1d +# math.__dict__['unravel_index'] = bm.unravel_index +# math.__dict__['unwrap'] = bm.unwrap +# math.__dict__['take_along_axis'] = bm.take_along_axis +# math.__dict__['can_cast'] = bm.can_cast +# math.__dict__['choose'] = bm.choose +# math.__dict__['copy'] = bm.copy +# math.__dict__['frombuffer'] = bm.frombuffer +# math.__dict__['fromfile'] = bm.fromfile +# math.__dict__['fromfunction'] = bm.fromfunction +# math.__dict__['fromiter'] = bm.fromiter +# math.__dict__['fromstring'] = bm.fromstring +# math.__dict__['get_printoptions'] = bm.get_printoptions +# math.__dict__['iscomplexobj'] = bm.iscomplexobj +# math.__dict__['isneginf'] = bm.isneginf +# math.__dict__['isposinf'] = bm.isposinf +# math.__dict__['isrealobj'] = bm.isrealobj +# math.__dict__['issubdtype'] = bm.issubdtype +# math.__dict__['issubsctype'] = bm.issubsctype +# math.__dict__['iterable'] = bm.iterable +# math.__dict__['packbits'] = bm.packbits +# math.__dict__['piecewise'] = bm.piecewise +# math.__dict__['printoptions'] = bm.printoptions +# math.__dict__['set_printoptions'] = bm.set_printoptions +# math.__dict__['promote_types'] = bm.promote_types +# math.__dict__['ravel_multi_index'] = bm.ravel_multi_index +# math.__dict__['result_type'] = bm.result_type +# math.__dict__['sort_complex'] = bm.sort_complex +# math.__dict__['unpackbits'] = bm.unpackbits +# math.__dict__['delete'] = bm.delete +# math.__dict__['add_docstring'] = bm.add_docstring +# math.__dict__['add_newdoc'] = bm.add_newdoc +# math.__dict__['add_newdoc_ufunc'] = bm.add_newdoc_ufunc +# math.__dict__['array2string'] = bm.array2string +# math.__dict__['asanyarray'] = bm.asanyarray +# math.__dict__['ascontiguousarray'] = bm.ascontiguousarray +# math.__dict__['asfarray'] = bm.asfarray +# math.__dict__['asscalar'] = bm.asscalar +# math.__dict__['common_type'] = bm.common_type +# math.__dict__['disp'] = bm.disp +# math.__dict__['genfromtxt'] = bm.genfromtxt +# math.__dict__['loadtxt'] = bm.loadtxt +# math.__dict__['info'] = bm.info +# math.__dict__['issubclass_'] = bm.issubclass_ +# math.__dict__['place'] = bm.place +# math.__dict__['polydiv'] = bm.polydiv +# math.__dict__['put'] = bm.put +# math.__dict__['putmask'] = bm.putmask +# math.__dict__['safe_eval'] = bm.safe_eval +# math.__dict__['savetxt'] = bm.savetxt +# math.__dict__['savez_compressed'] = bm.savez_compressed +# math.__dict__['show_config'] = bm.show_config +# math.__dict__['typename'] = bm.typename +# math.__dict__['copyto'] = bm.copyto +# math.__dict__['matrix'] = bm.matrix +# math.__dict__['asmatrix'] = bm.asmatrix +# math.__dict__['mat'] = bm.mat +# del bm diff --git a/brainpy/_src/connect/random_conn.py b/brainpy/_src/connect/random_conn.py index c8a37cd78..6cd43b177 100644 --- a/brainpy/_src/connect/random_conn.py +++ b/brainpy/_src/connect/random_conn.py @@ -129,7 +129,7 @@ def build_csr(self): def build_mat(self): pre_state = self._jaxrand.uniform(size=(self.pre_num, 1)) < self.pre_ratio mat = (self._jaxrand.uniform(size=(self.pre_num, self.post_num)) < self.prob) * pre_state - mat = jnp.asarray(mat) + mat = bm.asarray(mat) if not self.include_self: mat = bm.fill_diagonal(mat, False) return mat.astype(MAT_DTYPE) diff --git a/brainpy/_src/dyn/layers/linear.py b/brainpy/_src/dyn/layers/linear.py index d28f4fb7b..9e29836b7 100644 --- a/brainpy/_src/dyn/layers/linear.py +++ 
b/brainpy/_src/dyn/layers/linear.py @@ -81,7 +81,11 @@ def __repr__(self): f'num_out={self.num_out}, ' f'mode={self.mode})') - def update(self, sha, x): + def update(self, *args): + if len(args) == 1: + sha, x = dict(), bm.as_jax(args[0]) + else: + sha, x = args[0], bm.as_jax(args[1]) res = x @ self.W if self.b is not None: res += self.b @@ -102,7 +106,7 @@ def online_init(self): num_input = self.num_in else: num_input = self.num_in + 1 - self.online_fit_by.initialize(feature_in=num_input, feature_out=self.num_out, identifier=self.name) + self.online_fit_by.register_target(feature_in=num_input, identifier=self.name) def online_fit(self, target: ArrayType, @@ -139,13 +143,6 @@ def online_fit(self, self.b += db[0] self.W += dW - def offline_init(self): - if self.b is None: - num_input = self.num_in + 1 - else: - num_input = self.num_in - self.offline_fit_by.initialize(feature_in=num_input, feature_out=self.num_out, identifier=self.name) - def offline_fit(self, target: ArrayType, fit_record: Dict[str, ArrayType]): @@ -176,7 +173,7 @@ def offline_fit(self, xs = jnp.concatenate([jnp.ones(xs.shape[:2] + (1,)), xs], axis=-1) # (..., 1 + num_ff_input) # solve weights by offline training methods - weights = self.offline_fit_by(self.name, target, xs, ys) + weights = self.offline_fit_by(target, xs, ys) # assign trained weights if self.b is None: diff --git a/brainpy/_src/dyn/layers/reservoir.py b/brainpy/_src/dyn/layers/reservoir.py index 7cd61f05a..18675672c 100644 --- a/brainpy/_src/dyn/layers/reservoir.py +++ b/brainpy/_src/dyn/layers/reservoir.py @@ -6,7 +6,7 @@ import brainpy.math as bm from brainpy._src.initialize import Normal, ZeroInit, Initializer, parameter, variable -from brainpy.check import is_float, is_initializer, is_string +from brainpy import check from brainpy.tools import to_size from brainpy.types import ArrayType from .base import Layer @@ -36,8 +36,9 @@ class Reservoir(Layer): A float between 0 and 1. activation : str, callable, optional Reservoir activation function. + - If a str, should be a :py:mod:`brainpy.math.activations` function name. - - If a callable, should be an element-wise operator on tensor. + - If a callable, should be an element-wise operator. activation_type : str - If "internal" (default), then leaky integration happens on states transformed by the activation function: @@ -66,9 +67,12 @@ class Reservoir(Layer): neurons connected to other reservoir neurons, including themselves. Must be in [0, 1], by default 0.1 comp_type: str - The connectivity type, can be "dense" or "sparse". + The connectivity type, can be "dense" or "sparse", "jit". + + - ``"dense"`` means the connectivity matrix is a dense matrix. + - ``"sparse"`` means the connectivity matrix is a CSR sparse matrix. spectral_radius : float, optional - Spectral radius of recurrent weight matrix, by default None + Spectral radius of recurrent weight matrix, by default None. noise_rec : float, optional Gain of noise applied to reservoir internal states, by default 0.0 noise_in : float, optional @@ -118,37 +122,38 @@ def __init__( self.num_unit = num_out assert num_out > 0, f'Must be a positive integer, but we got {num_out}' self.leaky_rate = leaky_rate - is_float(leaky_rate, 'leaky_rate', 0., 1.) - self.activation = getattr(bm.activations, activation) + check.is_float(leaky_rate, 'leaky_rate', 0., 1.) 
+ self.activation = getattr(bm.activations, activation) if isinstance(activation, str) else activation + check.is_callable(self.activation, allow_none=False) self.activation_type = activation_type - is_string(activation_type, 'activation_type', ['internal', 'external']) + check.is_string(activation_type, 'activation_type', ['internal', 'external']) self.rng = bm.random.default_rng(seed) - is_float(spectral_radius, 'spectral_radius', allow_none=True) + check.is_float(spectral_radius, 'spectral_radius', allow_none=True) self.spectral_radius = spectral_radius # initializations - is_initializer(Win_initializer, 'ff_initializer', allow_none=False) - is_initializer(Wrec_initializer, 'rec_initializer', allow_none=False) - is_initializer(b_initializer, 'bias_initializer', allow_none=True) + check.is_initializer(Win_initializer, 'ff_initializer', allow_none=False) + check.is_initializer(Wrec_initializer, 'rec_initializer', allow_none=False) + check.is_initializer(b_initializer, 'bias_initializer', allow_none=True) self._Win_initializer = Win_initializer self._Wrec_initializer = Wrec_initializer self._b_initializer = b_initializer # connectivity - is_float(in_connectivity, 'ff_connectivity', 0., 1.) - is_float(rec_connectivity, 'rec_connectivity', 0., 1.) + check.is_float(in_connectivity, 'ff_connectivity', 0., 1.) + check.is_float(rec_connectivity, 'rec_connectivity', 0., 1.) self.ff_connectivity = in_connectivity self.rec_connectivity = rec_connectivity - is_string(comp_type, 'conn_type', ['dense', 'sparse']) + check.is_string(comp_type, 'conn_type', ['dense', 'sparse', 'jit']) self.comp_type = comp_type # noises - is_float(noise_in, 'noise_ff') - is_float(noise_rec, 'noise_rec') + check.is_float(noise_in, 'noise_ff') + check.is_float(noise_rec, 'noise_rec') self.noise_ff = noise_in self.noise_rec = noise_rec self.noise_type = noise_type - is_string(noise_type, 'noise_type', ['normal', 'uniform']) + check.is_string(noise_type, 'noise_type', ['normal', 'uniform']) # initialize feedforward weights weight_shape = (input_shape[-1], self.num_unit) @@ -170,7 +175,7 @@ def __init__( conn_mat = self.rng.random(recurrent_shape) > self.rec_connectivity self.Wrec[conn_mat] = 0. 
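The recurrent-weight setup that follows (prune `Wrec` to the requested density, then rescale it so the largest eigenvalue magnitude equals `spectral_radius`) is the standard echo-state recipe. A minimal standalone sketch of the same two steps in plain NumPy, with all sizes and values chosen only for illustration:

import numpy as np

num = 200
rec_connectivity, spectral_radius = 0.9, 1.3   # illustrative values
W = np.random.normal(scale=0.1, size=(num, num))
W[np.random.rand(num, num) > rec_connectivity] = 0.          # keep ~90% of the recurrent connections
W *= spectral_radius / np.max(np.abs(np.linalg.eigvals(W)))  # rescale the spectral radius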
if self.spectral_radius is not None: - current_sr = max(abs(jnp.linalg.eig(self.Wrec)[0])) + current_sr = max(abs(jnp.linalg.eig(bm.as_jax(self.Wrec))[0])) self.Wrec *= self.spectral_radius / current_sr if self.comp_type == 'sparse' and self.rec_connectivity < 1.: self.rec_pres, self.rec_posts = jnp.where(jnp.logical_not(bm.as_jax(conn_mat))) @@ -186,11 +191,13 @@ def __init__( def reset_state(self, batch_size=None): self.state.value = variable(jnp.zeros, batch_size, self.output_shape) - def update(self, sha, x): + def update(self, *args): """Feedforward output.""" # inputs - x = jnp.concatenate(x, axis=-1) - if self.noise_ff > 0: x += self.noise_ff * self.rng.uniform(-1, 1, x.shape) + x = args[0] if len(args) == 1 else args[1] + x = bm.as_jax(x) + if self.noise_ff > 0: + x += self.noise_ff * self.rng.uniform(-1, 1, x.shape) if self.comp_type == 'sparse' and self.ff_connectivity < 1.: sparse = {'data': self.Win, 'index': (self.ff_pres, self.ff_posts), diff --git a/brainpy/_src/initialize/random_inits.py b/brainpy/_src/initialize/random_inits.py index fdba28026..a51a0fbed 100644 --- a/brainpy/_src/initialize/random_inits.py +++ b/brainpy/_src/initialize/random_inits.py @@ -114,7 +114,7 @@ def __init__(self, mean=0., scale=1., seed=None): def __call__(self, *shape, dtype=None): shape = _format_shape(shape) weights = self.rng.normal(size=shape, loc=self.mean, scale=self.scale) - return jnp.asarray(weights, dtype=dtype) + return bm.as_jax(weights, dtype=dtype) def __repr__(self): return f'{self.__class__.__name__}(scale={self.scale}, rng={self.rng})' @@ -140,7 +140,7 @@ def __init__(self, min_val: float = 0., max_val: float = 1., seed=None): def __call__(self, shape, dtype=None): shape = _format_shape(shape) r = self.rng.uniform(low=self.min_val, high=self.max_val, size=shape) - return jnp.asarray(r, dtype=dtype) + return bm.as_jax(r, dtype=dtype) def __repr__(self): return (f'{self.__class__.__name__}(min_val={self.min_val}, ' @@ -180,14 +180,14 @@ def __call__(self, shape, dtype=None): variance = (self.scale / denominator).astype(dtype) if self.distribution == "truncated_normal": stddev = (jnp.sqrt(variance) / .87962566103423978).astype(dtype) - return self.rng.truncated_normal(-2, 2, shape, dtype) * stddev + res = self.rng.truncated_normal(-2, 2, shape, dtype) * stddev elif self.distribution == "normal": res = self.rng.randn(*shape) * jnp.sqrt(variance).astype(dtype) elif self.distribution == "uniform": res = self.rng.uniform(low=-1, high=1, size=shape) * jnp.sqrt(3 * variance).astype(dtype) else: raise ValueError("invalid distribution for variance scaling initializer") - return jnp.asarray(res, dtype=dtype) + return bm.as_jax(res, dtype=dtype) def __repr__(self): name = self.__class__.__name__ @@ -336,7 +336,7 @@ def __call__(self, shape, dtype=None): q_mat = q_mat.T q_mat = jnp.reshape(q_mat, (n_rows,) + tuple(np.delete(shape, self.axis))) q_mat = jnp.moveaxis(q_mat, 0, self.axis) - return self.scale * jnp.asarray(q_mat, dtype=dtype) + return self.scale * bm.as_jax(q_mat, dtype=dtype) def __repr__(self): return f'{self.__class__.__name__}(scale={self.scale}, axis={self.axis}, rng={self.rng})' diff --git a/brainpy/_src/optimizers/tests/test_scheduler.py b/brainpy/_src/optimizers/tests/test_scheduler.py index 9afb0c798..f3ef9a16c 100644 --- a/brainpy/_src/optimizers/tests/test_scheduler.py +++ b/brainpy/_src/optimizers/tests/test_scheduler.py @@ -30,70 +30,70 @@ def test2(self, last_epoch): self.assertTrue(lr1 == lr2) -class TestStepLR(parameterized.TestCase): - - 
@parameterized.named_parameters( - {'testcase_name': f'last_epoch={last_epoch}', - 'last_epoch': last_epoch} - for last_epoch in [-1, 0, 5, 10] - ) - def test1(self, last_epoch): - scheduler1 = scheduler.StepLR(0.1, 10, gamma=0.1, last_epoch=last_epoch) - scheduler2 = scheduler.StepLR(0.1, 10, gamma=0.1, last_epoch=last_epoch) - - for i in range(1, 25): - lr1 = scheduler1(i + last_epoch) - lr2 = scheduler2() - scheduler2.step_epoch() - print(f'{scheduler2.last_epoch}, {lr1:.4f}, {lr2:.4f}') - self.assertTrue(lr1 == lr2) - - -class TestCosineAnnealingLR(unittest.TestCase): - def test1(self): - max_epoch = 50 - iters = 200 - sch = scheduler.CosineAnnealingLR(0.1, T_max=5, eta_min=0, last_epoch=-1) - all_lr1 = [[], []] - all_lr2 = [[], []] - for epoch in range(max_epoch): - for batch in range(iters): - all_lr1[0].append(epoch + batch / iters) - all_lr1[1].append(sch()) - sch.step_epoch() - all_lr2[0].append(epoch) - all_lr2[1].append(sch()) - sch.step_epoch() - plt.subplot(211) - plt.plot(jax.numpy.asarray(all_lr1[0]), jax.numpy.asarray(all_lr1[1])) - plt.subplot(212) - plt.plot(jax.numpy.asarray(all_lr2[0]), jax.numpy.asarray(all_lr2[1])) - plt.show() - plt.close() - - -class TestCosineAnnealingWarmRestarts(unittest.TestCase): - def test1(self): - max_epoch = 50 - iters = 200 - sch = scheduler.CosineAnnealingWarmRestarts(0.1, - iters, - T_0=5, - T_mult=1, - last_call=-1) - all_lr1 = [] - all_lr2 = [] - for epoch in range(max_epoch): - for batch in range(iters): - all_lr1.append(sch()) - sch.step_call() - all_lr2.append(sch()) - sch.step_epoch() - plt.subplot(211) - plt.plot(jax.numpy.asarray(all_lr1)) - plt.subplot(212) - plt.plot(jax.numpy.asarray(all_lr2)) - plt.show() - plt.close() +# class TestStepLR(parameterized.TestCase): +# +# @parameterized.named_parameters( +# {'testcase_name': f'last_epoch={last_epoch}', +# 'last_epoch': last_epoch} +# for last_epoch in [-1, 0, 5, 10] +# ) +# def test1(self, last_epoch): +# scheduler1 = scheduler.StepLR(0.1, 10, gamma=0.1, last_epoch=last_epoch) +# scheduler2 = scheduler.StepLR(0.1, 10, gamma=0.1, last_epoch=last_epoch) +# +# for i in range(1, 25): +# lr1 = scheduler1(i + last_epoch) +# lr2 = scheduler2() +# scheduler2.step_epoch() +# print(f'{scheduler2.last_epoch}, {lr1:.4f}, {lr2:.4f}') +# self.assertTrue(lr1 == lr2) +# +# +# class TestCosineAnnealingLR(unittest.TestCase): +# def test1(self): +# max_epoch = 50 +# iters = 200 +# sch = scheduler.CosineAnnealingLR(0.1, T_max=5, eta_min=0, last_epoch=-1) +# all_lr1 = [[], []] +# all_lr2 = [[], []] +# for epoch in range(max_epoch): +# for batch in range(iters): +# all_lr1[0].append(epoch + batch / iters) +# all_lr1[1].append(sch()) +# sch.step_epoch() +# all_lr2[0].append(epoch) +# all_lr2[1].append(sch()) +# sch.step_epoch() +# plt.subplot(211) +# plt.plot(jax.numpy.asarray(all_lr1[0]), jax.numpy.asarray(all_lr1[1])) +# plt.subplot(212) +# plt.plot(jax.numpy.asarray(all_lr2[0]), jax.numpy.asarray(all_lr2[1])) +# plt.show() +# plt.close() +# +# +# class TestCosineAnnealingWarmRestarts(unittest.TestCase): +# def test1(self): +# max_epoch = 50 +# iters = 200 +# sch = scheduler.CosineAnnealingWarmRestarts(0.1, +# iters, +# T_0=5, +# T_mult=1, +# last_call=-1) +# all_lr1 = [] +# all_lr2 = [] +# for epoch in range(max_epoch): +# for batch in range(iters): +# all_lr1.append(sch()) +# sch.step_call() +# all_lr2.append(sch()) +# sch.step_epoch() +# plt.subplot(211) +# plt.plot(jax.numpy.asarray(all_lr1)) +# plt.subplot(212) +# plt.plot(jax.numpy.asarray(all_lr2)) +# plt.show() +# plt.close() diff --git 
a/brainpy/math/__init__.py b/brainpy/math/__init__.py index a1e8240ff..83ccdcb56 100644 --- a/brainpy/math/__init__.py +++ b/brainpy/math/__init__.py @@ -5,6 +5,7 @@ from .ndarray import * from .delayvars import * from .arrayoperation import * +from .arraycompatible import * # functions from .activations import * @@ -59,47 +60,3 @@ spike2_with_sigmoid_grad as spike2_with_sigmoid_grad, spike2_with_linear_grad as spike2_with_linear_grad, ) - -import brainpy._src.math.fft as bm_fft -fft.__dict__['fft'] = bm_fft.fft -fft.__dict__['fft2'] = bm_fft.fft2 -fft.__dict__['fftfreq'] = bm_fft.fftfreq -fft.__dict__['fftn'] = bm_fft.fftn -fft.__dict__['fftshift'] = bm_fft.fftshift -fft.__dict__['hfft'] = bm_fft.hfft -fft.__dict__['ifft'] = bm_fft.ifft -fft.__dict__['ifft2'] = bm_fft.ifft2 -fft.__dict__['ifftn'] = bm_fft.ifftn -fft.__dict__['ifftshift'] = bm_fft.ifftshift -fft.__dict__['ihfft'] = bm_fft.ihfft -fft.__dict__['irfft'] = bm_fft.irfft -fft.__dict__['irfft2'] = bm_fft.irfft2 -fft.__dict__['irfftn'] = bm_fft.irfftn -fft.__dict__['rfft'] = bm_fft.rfft -fft.__dict__['rfft2'] = bm_fft.rfft2 -fft.__dict__['rfftfreq'] = bm_fft.rfftfreq -fft.__dict__['rfftn'] = bm_fft.rfftn -del bm_fft - -import brainpy._src.math.linalg as bm_linalg -linalg.__dict__['cholesky'] = bm_linalg.cholesky -linalg.__dict__['cond'] = bm_linalg.cond -linalg.__dict__['det'] = bm_linalg.det -linalg.__dict__['eig'] = bm_linalg.eig -linalg.__dict__['eigh'] = bm_linalg.eigh -linalg.__dict__['eigvals'] = bm_linalg.eigvals -linalg.__dict__['eigvalsh'] = bm_linalg.eigvalsh -linalg.__dict__['inv'] = bm_linalg.inv -linalg.__dict__['svd'] = bm_linalg.svd -linalg.__dict__['lstsq'] = bm_linalg.lstsq -linalg.__dict__['matrix_power'] = bm_linalg.matrix_power -linalg.__dict__['matrix_rank'] = bm_linalg.matrix_rank -linalg.__dict__['norm'] = bm_linalg.norm -linalg.__dict__['pinv'] = bm_linalg.pinv -linalg.__dict__['qr'] = bm_linalg.qr -linalg.__dict__['solve'] = bm_linalg.solve -linalg.__dict__['slogdet'] = bm_linalg.slogdet -linalg.__dict__['tensorinv'] = bm_linalg.tensorinv -linalg.__dict__['tensorsolve'] = bm_linalg.tensorsolve -linalg.__dict__['multi_dot'] = bm_linalg.multi_dot -del bm_linalg diff --git a/examples/dynamics_analysis/2d_mean_field_QIF.py b/examples/dynamics_analysis/2d_mean_field_QIF.py index 0d3c17798..467bc6118 100644 --- a/examples/dynamics_analysis/2d_mean_field_QIF.py +++ b/examples/dynamics_analysis/2d_mean_field_QIF.py @@ -3,7 +3,7 @@ bp.math.enable_x64() -class MeanFieldQIF(bp.dyn.DynamicalSystem): +class MeanFieldQIF(bp.DynamicalSystem): """A mean-field model of a quadratic integrate-and-fire neuron population. References @@ -49,7 +49,7 @@ def update(self, tdi): # simulation -runner = bp.dyn.DSRunner(qif, inputs=['Iext', 1.], monitors=['r', 'v']) +runner = bp.DSRunner(qif, inputs=['Iext', 1.], monitors=['r', 'v']) runner.run(100.) bp.visualize.line_plot(runner.mon.ts, runner.mon.r, legend='r') bp.visualize.line_plot(runner.mon.ts, runner.mon.v, legend='v', show=True) diff --git a/examples/dynamics_analysis/2d_wilson_cowan_model.py b/examples/dynamics_analysis/2d_wilson_cowan_model.py index 6248f5940..4298feabb 100644 --- a/examples/dynamics_analysis/2d_wilson_cowan_model.py +++ b/examples/dynamics_analysis/2d_wilson_cowan_model.py @@ -4,7 +4,7 @@ bp.math.enable_x64() -class WilsonCowanModel(bp.dyn.DynamicalSystem): +class WilsonCowanModel(bp.DynamicalSystem): def __init__(self, num, method='exp_auto'): super(WilsonCowanModel, self).__init__() @@ -59,7 +59,7 @@ def update(self, tdi): model.i[:] = [0.0, 1.] 
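Throughout the example scripts in this patch, `bp.dyn.DynamicalSystem`, `bp.dyn.Network`, `bp.dyn.DSRunner`, and `bp.train.DSTrainer` are replaced by their flattened top-level counterparts. A minimal sketch of the updated calling pattern, assuming the built-in `bp.neurons.LIF` group and its usual `'V'` and `'input'` variable names:

import brainpy as bp

neu = bp.neurons.LIF(10)
runner = bp.DSRunner(neu, monitors=['V'], inputs=['input', 5.])
runner.run(100.)
bp.visualize.line_plot(runner.mon.ts, runner.mon.V, show=True)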
# simulation -runner = bp.dyn.DSRunner(model, monitors=['e', 'i']) +runner = bp.DSRunner(model, monitors=['e', 'i']) runner.run(100) fig, gs = bp.visualize.get_figure(2, 1, 3, 8) diff --git a/examples/dynamics_analysis/highdim_RNN_Analysis.py b/examples/dynamics_analysis/highdim_RNN_Analysis.py index 614c522bc..3bf851b03 100644 --- a/examples/dynamics_analysis/highdim_RNN_Analysis.py +++ b/examples/dynamics_analysis/highdim_RNN_Analysis.py @@ -134,7 +134,7 @@ def data_generation(): # Visualize neural activity for in sample trials # --- # We will run the network for 100 sample trials, then visual the neural activity trajectories in a PCA space. -runner = bp.train.DSTrainer(net, monitors={'r': net.h}, progress_bar=False) +runner = bp.DSTrainer(net, monitors={'r': net.h}, progress_bar=False) env.reset(no_step=True) num_trial = 100 diff --git a/examples/dynamics_simulation/Li_2017_unified_thalamus_oscillation_model.py b/examples/dynamics_simulation/Li_2017_unified_thalamus_oscillation_model.py index 88625740d..c1a5f99f9 100644 --- a/examples/dynamics_simulation/Li_2017_unified_thalamus_oscillation_model.py +++ b/examples/dynamics_simulation/Li_2017_unified_thalamus_oscillation_model.py @@ -90,12 +90,15 @@ def __init__(self, size, gKL=0.01, V_initializer=bp.init.OneInit(-70.), ): IAHP = bp.channels.IAHP_De1994(size, g_max=0.2, E=-90.) ICaN = bp.channels.ICaN_IS2008(size, g_max=0.2) ICaT = bp.channels.ICaT_HP1992(size, g_max=1.3) - Ca = bp.channels.CalciumDetailed(size, C_rest=5e-5, tau=100., d=0.5, + Ca = bp.channels.CalciumDetailed(size, + C_rest=5e-5, tau=100., d=0.5, IAHP=IAHP, ICaN=ICaN, ICaT=ICaT) - super(TRN, self).__init__(size, A=1.43e-4, - V_initializer=V_initializer, V_th=20., - IL=IL, IKL=IKL, INa=INa, IDR=IDR, Ca=Ca) + super(TRN, self).__init__(size, + A=1.43e-4, V_th=20., + V_initializer=V_initializer, + IL=IL, IKL=IKL, INa=INa, IDR=IDR, Ca=Ca + ) class MgBlock(bp.SynOut): @@ -287,12 +290,14 @@ def try_trn_neuron(): @bm.to_dynsys(child_objs=trn) def update(s, inp): trn.input += inp - trn.update(s, ) + trn.update(s) + return trn.input.value runner = bp.DSRunner(update, monitors={'V': trn.V}) - runner.run(inputs=inputs) + I = runner.run(inputs=inputs) bp.visualize.line_plot(runner.mon.ts, runner.mon.V, show=True) + bp.visualize.line_plot(runner.mon.ts, I, show=True) def try_network(): diff --git a/examples/dynamics_simulation/Sanda_2021_hippo-tha-cortex-model.py b/examples/dynamics_simulation/Sanda_2021_hippo-tha-cortex-model.py index 05791151c..05bf911e6 100644 --- a/examples/dynamics_simulation/Sanda_2021_hippo-tha-cortex-model.py +++ b/examples/dynamics_simulation/Sanda_2021_hippo-tha-cortex-model.py @@ -3,7 +3,7 @@ import brainpy as bp -class HippoThaCortexModel(bp.dyn.Network): +class HippoThaCortexModel(bp.Network): def __init__(self, ): super(HippoThaCortexModel, self).__init__() diff --git a/examples/dynamics_simulation/hh_model.py b/examples/dynamics_simulation/hh_model.py index 5040e2370..dc54c1c9a 100644 --- a/examples/dynamics_simulation/hh_model.py +++ b/examples/dynamics_simulation/hh_model.py @@ -1,17 +1,15 @@ # -*- coding: utf-8 -*- import brainpy as bp -from brainpy import dyn -from brainpy.dyn import channels -class HH(dyn.CondNeuGroup): +class HH(bp.CondNeuGroup): def __init__(self, size): super(HH, self).__init__(size) - self.INa = channels.INa_HH1952(size, ) - self.IK = channels.IK_HH1952(size, ) - self.IL = channels.IL(size, E=-54.387, g_max=0.03) + self.INa = bp.channels.INa_HH1952(size, ) + self.IK = bp.channels.IK_HH1952(size, ) + self.IL = 
bp.channels.IL(size, E=-54.387, g_max=0.03) hh = HH(1) diff --git a/examples/dynamics_simulation/multi_scale_COBAHH.py b/examples/dynamics_simulation/multi_scale_COBAHH.py index b23340d9a..ddb070139 100644 --- a/examples/dynamics_simulation/multi_scale_COBAHH.py +++ b/examples/dynamics_simulation/multi_scale_COBAHH.py @@ -7,9 +7,9 @@ import brainpy as bp import brainpy.math as bm -from brainpy.dyn.channels import INa_TM1991, IL -from brainpy.dyn.synapses import Exponential -from brainpy.dyn.synouts import COBA +from brainpy.channels import INa_TM1991, IL +from brainpy.synapses import Exponential +from brainpy.synouts import COBA from brainpy.connect import FixedProb from jax import vmap import seaborn as sns diff --git a/examples/dynamics_training/Bellec_2020_eprop_evidence_accumulation.py b/examples/dynamics_training/Bellec_2020_eprop_evidence_accumulation.py index 8bcf31241..4f1e38227 100644 --- a/examples/dynamics_training/Bellec_2020_eprop_evidence_accumulation.py +++ b/examples/dynamics_training/Bellec_2020_eprop_evidence_accumulation.py @@ -36,7 +36,7 @@ regularization_f0 = reg_rate / 1000. # mean target network firing frequency -class EligSNN(bp.dyn.Network): +class EligSNN(bp.Network): def __init__(self, num_in, num_rec, num_out, eprop=True, tau_a=2e3, tau_v=2e1): super(EligSNN, self).__init__() @@ -170,7 +170,7 @@ def loss_fun(predicts, targets): # Training -trainer = bp.train.BPTT( +trainer = bp.BPTT( net, loss_fun, loss_has_aux=True, optimizer=bp.optimizers.Adam(lr=0.01), @@ -182,7 +182,7 @@ def loss_fun(predicts, targets): # visualization dataset, _ = next(get_data(20, n_in=net.num_in, t_interval=t_cue_spacing, f0=input_f0)()) -runner = bp.train.DSTrainer(net, monitors={'spike': net.r.spike}) +runner = bp.DSTrainer(net, monitors={'spike': net.r.spike}) outs = runner.predict(dataset, reset_state=True) for i in range(10): diff --git a/examples/dynamics_training/Song_2016_EI_RNN.py b/examples/dynamics_training/Song_2016_EI_RNN.py index 97c3381f8..ff693b84c 100644 --- a/examples/dynamics_training/Song_2016_EI_RNN.py +++ b/examples/dynamics_training/Song_2016_EI_RNN.py @@ -71,7 +71,7 @@ # Here we define a E-I recurrent network, in particular, no self-connections are allowed. # %% -class RNN(bp.dyn.DynamicalSystem): +class RNN(bp.DynamicalSystem): r"""E-I RNN. 
The RNNs are described by the equations diff --git a/examples/dynamics_training/Sussillo_Abbott_2009_FORCE_Learning.py b/examples/dynamics_training/Sussillo_Abbott_2009_FORCE_Learning.py index 6862bb1e7..310c144d4 100644 --- a/examples/dynamics_training/Sussillo_Abbott_2009_FORCE_Learning.py +++ b/examples/dynamics_training/Sussillo_Abbott_2009_FORCE_Learning.py @@ -132,18 +132,18 @@ def plot_params(net): plt.figure(figsize=(16, 10)) plt.subplot(221) - plt.imshow((net.w_rr + net.w_ro @ net.w_or).numpy(), interpolation=None) + plt.imshow(bm.as_numpy(net.w_rr + net.w_ro @ net.w_or), interpolation=None) plt.colorbar() plt.title('Effective matrix - W_rr + W_ro * W_or') plt.subplot(222) - plt.imshow(net.w_ro.numpy(), interpolation=None) + plt.imshow(bm.as_numpy(net.w_ro), interpolation=None) plt.colorbar() plt.title('Readout weights - W_ro') x_circ = np.linspace(-1, 1, 1000) y_circ = np.sqrt(1 - x_circ ** 2) - evals, _ = np.linalg.eig(net.w_rr.numpy()) + evals, _ = np.linalg.eig(bm.as_numpy(net.w_rr)) plt.subplot(223) plt.plot(np.real(evals), np.imag(evals), 'o') plt.plot(x_circ, y_circ, 'k') @@ -151,7 +151,7 @@ def plot_params(net): plt.axis('equal') plt.title('Eigenvalues of W_rr') - evals, _ = np.linalg.eig((net.w_rr + net.w_ro @ net.w_or).numpy()) + evals, _ = np.linalg.eig(bm.as_numpy((net.w_rr + net.w_ro @ net.w_or))) plt.subplot(224) plt.plot(np.real(evals), np.imag(evals), 'o', color='orange') plt.plot(x_circ, y_circ, 'k') diff --git a/examples/dynamics_training/echo_state_network.py b/examples/dynamics_training/echo_state_network.py index 387daa127..bdb8c6ca7 100644 --- a/examples/dynamics_training/echo_state_network.py +++ b/examples/dynamics_training/echo_state_network.py @@ -9,17 +9,20 @@ class ESN(bp.DynamicalSystem): def __init__(self, num_in, num_hidden, num_out): super(ESN, self).__init__() - self.r = bp.layers.Reservoir(num_in, num_hidden, + self.r = bp.layers.Reservoir(num_in, + num_hidden, Win_initializer=bp.init.Uniform(-0.1, 0.1), Wrec_initializer=bp.init.Normal(scale=0.1), in_connectivity=0.02, rec_connectivity=0.02, comp_type='dense') - self.o = bp.layers.Dense(num_hidden, num_out, W_initializer=bp.init.Normal(), + self.o = bp.layers.Dense(num_hidden, + num_out, + W_initializer=bp.init.Normal(), mode=bm.training_mode) - def update(self, sha, x): - return self.o(sha, self.r(sha, x)) + def update(self, s, x): + return self.o(s, self.r(s, x)) class NGRC(bp.DynamicalSystem): @@ -31,8 +34,8 @@ def __init__(self, num_in, num_out): W_initializer=bp.init.Normal(0.1), mode=bm.training_mode) - def update(self, shared_args, x): - return self.o(shared_args, self.r(shared_args, x)) + def update(self, s, x): + return self.o(s, self.r(s, x)) def train_esn_with_ridge(num_in=100, num_out=30): diff --git a/examples/dynamics_training/reservoir-mnist.py b/examples/dynamics_training/reservoir-mnist.py new file mode 100644 index 000000000..63cc289f4 --- /dev/null +++ b/examples/dynamics_training/reservoir-mnist.py @@ -0,0 +1,129 @@ +# -*- coding: utf-8 -*- + + +import brainpy_datasets as bd +import jax.numpy as jnp +from tqdm import tqdm + +import brainpy as bp +import brainpy.math as bm + +traindata = bd.vision.MNIST(root='D:/data', split='train') +testdata = bd.vision.MNIST(root='D:/data', split='test') + + +def offline_train(num_hidden=2000, num_in=28, num_out=10): + # training + x_train = jnp.asarray(traindata.data / 255, dtype=bm.float_) + x_train = x_train.reshape(-1, x_train.shape[-1]) + y_train = bm.one_hot(jnp.repeat(traindata.targets, x_train.shape[1]), 10, dtype=bm.float_) + + 
reservoir = bp.layers.Reservoir( + num_in, + num_hidden, + Win_initializer=bp.init.Uniform(-0.6, 0.6), + Wrec_initializer=bp.init.Normal(scale=0.1), + in_connectivity=0.1, + rec_connectivity=0.9, + spectral_radius=1.3, + leaky_rate=0.2, + comp_type='dense', + mode=bm.batching_mode + ) + reservoir.reset_state(1) + outs = bm.for_loop(bm.Partial(reservoir, {}), x_train) + weight = bp.algorithms.RidgeRegression(alpha=1e-8)(y_train, outs) + + # predicting + reservoir.reset_state(1) + esn = bp.Sequential( + reservoir, + bp.layers.Dense(num_hidden, + num_out, + W_initializer=weight, + b_initializer=None, + mode=bm.training_mode) + ) + + preds = bm.for_loop(lambda x: jnp.argmax(esn({}, x), axis=-1), + x_train, + child_objs=esn) + accuracy = jnp.mean(preds == jnp.repeat(traindata.targets, x_train.shape[1])) + print(accuracy) + + +def force_online_train(num_hidden=2000, num_in=28, num_out=10, train_stage='final_step'): + assert train_stage in ['final_step', 'all_steps'] + + x_train = jnp.asarray(traindata.data / 255, dtype=bm.float_) + x_test = jnp.asarray(testdata.data / 255, dtype=bm.float_) + y_train = bm.one_hot(traindata.targets, 10, dtype=bm.float_) + + reservoir = bp.layers.Reservoir( + num_in, + num_hidden, + Win_initializer=bp.init.Uniform(-0.6, 0.6), + Wrec_initializer=bp.init.Normal(scale=1.3 / jnp.sqrt(num_hidden * 0.9)), + in_connectivity=0.1, + rec_connectivity=0.9, + comp_type='dense', + mode=bm.batching_mode + ) + readout = bp.layers.Dense(num_hidden, num_out, b_initializer=None, mode=bm.training_mode) + rls = bp.algorithms.RLS() + rls.register_target(num_hidden) + + @bm.jit + @bm.to_object(child_objs=(reservoir, readout, rls)) + def train_step(xs, y): + reservoir.reset_state(xs.shape[0]) + if train_stage == 'final_step': + for x in xs.transpose(1, 0, 2): + o = reservoir(x) + pred = readout(o) + dw = rls(y, o, pred) + readout.W += dw + elif train_stage == 'all_steps': + for x in xs.transpose(1, 0, 2): + o = reservoir(x) + pred = readout(o) + dw = rls(y, o, pred) + readout.W += dw + else: + raise ValueError + + @bm.jit + @bm.to_object(child_objs=(reservoir, readout)) + def predict(xs): + reservoir.reset_state(xs.shape[0]) + for x in xs.transpose(1, 0, 2): + o = reservoir(x) + y = readout(o) + return jnp.argmax(y, axis=1) + + # training + batch_size = 1 + for i in tqdm(range(0, x_train.shape[0], batch_size), desc='Training'): + train_step(x_train[i: i + batch_size], y_train[i: i + batch_size]) + + # verifying + preds = [] + batch_size = 500 + for i in tqdm(range(0, x_train.shape[0], batch_size), desc='Verifying'): + preds.append(predict(x_train[i: i + batch_size])) + preds = jnp.concatenate(preds) + acc = jnp.mean(preds == jnp.asarray(traindata.targets, dtype=bm.int_)) + print('Train accuracy', acc) + + # prediction + preds = [] + for i in tqdm(range(0, x_test.shape[0], batch_size), desc='Predicting'): + preds.append(predict(x_test[i: i + batch_size])) + preds = jnp.concatenate(preds) + acc = jnp.mean(preds == jnp.asarray(testdata.targets, dtype=bm.int_)) + print('Test accuracy', acc) + + +if __name__ == '__main__': + # offline_train() + force_online_train(num_hidden=2000) diff --git a/examples/training_snn_models/SurrogateGrad_lif-ANN-style.py b/examples/training_snn_models/SurrogateGrad_lif-ANN-style.py index d3c2d47b1..fbffdb529 100644 --- a/examples/training_snn_models/SurrogateGrad_lif-ANN-style.py +++ b/examples/training_snn_models/SurrogateGrad_lif-ANN-style.py @@ -68,9 +68,9 @@ def plot_voltage_traces(mem, spk=None, dim=(3, 5), spike_height=5): def 
print_classification_accuracy(output, target): """ Dirty little helper function to compute classification accuracy. """ - m = jnp.max(output, axis=1) # max over time - am = jnp.argmax(m, axis=1) # argmax over output units - acc = jnp.mean(target == am) # compare to labels + m = bm.max(output, axis=1) # max over time + am = bm.argmax(m, axis=1) # argmax over output units + acc = bm.mean(target == am) # compare to labels print("Accuracy %.3f" % acc) @@ -83,7 +83,7 @@ def print_classification_accuracy(output, target): mask = bm.random.rand(num_sample, num_step, net.num_in) x_data = bm.zeros((num_sample, num_step, net.num_in)) x_data[mask < freq * bm.get_dt() / 1000.] = 1.0 -y_data = jnp.asarray(bm.random.rand(num_sample) < 0.5, dtype=bm.float_) +y_data = bm.asarray(bm.random.rand(num_sample) < 0.5, dtype=bm.float_) rng = bm.random.RandomState() # Before training diff --git a/examples/training_snn_models/SurrogateGrad_lif.py b/examples/training_snn_models/SurrogateGrad_lif.py index 1f65f260d..c4af45567 100644 --- a/examples/training_snn_models/SurrogateGrad_lif.py +++ b/examples/training_snn_models/SurrogateGrad_lif.py @@ -70,9 +70,9 @@ def plot_voltage_traces(mem, spk=None, dim=(3, 5), spike_height=5): def print_classification_accuracy(output, target): """ Dirty little helper function to compute classification accuracy. """ - m = jnp.max(output, axis=1) # max over time - am = jnp.argmax(m, axis=1) # argmax over output units - acc = jnp.mean(target == am) # compare to labels + m = bm.max(output, axis=1) # max over time + am = bm.argmax(m, axis=1) # argmax over output units + acc = bm.mean(target == am) # compare to labels print("Accuracy %.3f" % acc) @@ -85,7 +85,7 @@ def print_classification_accuracy(output, target): mask = bm.random.rand(num_sample, num_step, net.num_in) x_data = bm.zeros((num_sample, num_step, net.num_in)) x_data[mask < freq * bm.get_dt() / 1000.] 
= 1.0 -y_data = jnp.asarray(bm.random.rand(num_sample) < 0.5, dtype=bm.float_) +y_data = bm.asarray(bm.random.rand(num_sample) < 0.5, dtype=bm.float_) rng = bm.random.RandomState(123) @@ -122,10 +122,11 @@ def train(_): # train the network net.reset_state(num_sample) train_losses = [] -for i in range(0, 3000, 400): +b = 100 +for i in range(0, 3000, b): t0 = time.time() - ls = bm.for_loop(train, operands=bm.arange(i, i + 400, 1)) - print(f'Train {i + 400} epoch, loss = {jnp.mean(ls):.4f}, used time {time.time() - t0:.4f} s') + ls = bm.for_loop(train, operands=bm.arange(i, i + b, 1)) + print(f'Train {i + b} epoch, loss = {jnp.mean(ls):.4f}, used time {time.time() - t0:.4f} s') train_losses.append(ls) # visualize the training losses diff --git a/examples/training_snn_models/spikebased_bp_for_cifar10.py b/examples/training_snn_models/spikebased_bp_for_cifar10.py index 2f99c37a3..0c279641f 100644 --- a/examples/training_snn_models/spikebased_bp_for_cifar10.py +++ b/examples/training_snn_models/spikebased_bp_for_cifar10.py @@ -17,6 +17,7 @@ sys.path.append('../../') os.environ['CUDA_VISIBLE_DEVICES'] = '1' + import tqdm import argparse import time @@ -35,7 +36,7 @@ # bm.set_platform('gpu') parser = argparse.ArgumentParser(description='CIFAR10 Training') -parser.add_argument('-data', default='./data', type=str, help='path to dataset') +parser.add_argument('-data', default='D:/data', type=str, help='path to dataset') parser.add_argument('-b', default=64, type=int, metavar='N') parser.add_argument('-T', default=100, type=int, help='Simulation timesteps') parser.add_argument('-lr', default=0.0025, type=float, help='initial learning rate') From bf06081ad935e0ed6aed04d804eb86e400e65577 Mon Sep 17 00:00:00 2001 From: chaoming Date: Sun, 15 Jan 2023 10:56:32 +0800 Subject: [PATCH 12/13] fix bugs --- brainpy/_src/dyn/layers/pooling.py | 6 + brainpy/_src/dyn/neurons/reduced_models.py | 2 +- brainpy/_src/initialize/random_inits.py | 2 +- brainpy/_src/losses/comparison.py | 50 +- brainpy/_src/math/__init__.py | 4 +- brainpy/_src/math/_utils.py | 11 +- .../_src/math/object_transform/autograd.py | 20 +- brainpy/_src/train/offline.py | 4 - examples/training_snn_models/OTTT-SNN.py | 678 ------------------ .../SurrogateGrad_lif-ANN-style.py | 139 ---- .../training_snn_models/SurrogateGrad_lif.py | 143 ---- .../SurrogateGrad_lif_fashion_mnist.py | 244 ------- .../fashion_mnist_conv_lif.py | 261 ------- .../training_snn_models/mnist_lif_readout.py | 155 ---- examples/training_snn_models/readme.md | 3 + 15 files changed, 59 insertions(+), 1663 deletions(-) delete mode 100644 examples/training_snn_models/OTTT-SNN.py delete mode 100644 examples/training_snn_models/SurrogateGrad_lif-ANN-style.py delete mode 100644 examples/training_snn_models/SurrogateGrad_lif.py delete mode 100644 examples/training_snn_models/SurrogateGrad_lif_fashion_mnist.py delete mode 100644 examples/training_snn_models/fashion_mnist_conv_lif.py delete mode 100644 examples/training_snn_models/mnist_lif_readout.py create mode 100644 examples/training_snn_models/readme.md diff --git a/brainpy/_src/dyn/layers/pooling.py b/brainpy/_src/dyn/layers/pooling.py index 08490a40a..0a1754544 100644 --- a/brainpy/_src/dyn/layers/pooling.py +++ b/brainpy/_src/dyn/layers/pooling.py @@ -82,6 +82,7 @@ def __init__( def update(self, *args): x = args[0] if len(args) == 1 else args[1] + x = bm.as_jax(x) window_shape = self._infer_shape(x.ndim, self.kernel_size) stride = self._infer_shape(x.ndim, self.stride) padding = (self.padding @@ -258,6 +259,7 @@ def __init__( 
def update(self, *args): x = args[0] if len(args) == 1 else args[1] + x = bm.as_jax(x) window_shape = self._infer_shape(x.ndim, self.kernel_size) strides = self._infer_shape(x.ndim, self.stride) padding = (self.padding if isinstance(self.padding, str) else @@ -356,6 +358,7 @@ def __init__( def update(self, *args): x = args[0] if len(args) == 1 else args[1] + x = bm.as_jax(x) x_dim = self.pool_dim + (0 if self.channel_axis is None else 1) if x.ndim < x_dim: raise ValueError(f'Excepted input with >= {x_dim} dimensions, but got {x.ndim}.') @@ -521,6 +524,7 @@ def __init__( class _AvgPoolNd(_MaxPoolNd): def update(self, *args): x = args[0] if len(args) == 1 else args[1] + x = bm.as_jax(x) x_dim = self.pool_dim + (0 if self.channel_axis is None else 1) if x.ndim < x_dim: raise ValueError(f'Excepted input with >= {x_dim} dimensions, but got {x.ndim}.') @@ -694,6 +698,7 @@ def _adaptive_pool1d(x, target_size: int, operation: Callable): Returns: A JAX array of shape `(target_size, )`. """ + x = bm.as_jax(x) size = jnp.size(x) num_head_arrays = size % target_size num_block = size // target_size @@ -767,6 +772,7 @@ def update(self, *args): or `(..., dim_1, dim_2)`. """ x = args[0] if len(args) == 1 else args[1] + x = bm.as_jax(x) # channel axis channel_axis = self.channel_axis diff --git a/brainpy/_src/dyn/neurons/reduced_models.py b/brainpy/_src/dyn/neurons/reduced_models.py index 9ed8ecb0f..35cec630c 100644 --- a/brainpy/_src/dyn/neurons/reduced_models.py +++ b/brainpy/_src/dyn/neurons/reduced_models.py @@ -580,7 +580,7 @@ def __init__( b: Union[float, ArrayType, Initializer, Callable] = 1., tau: Union[float, ArrayType, Initializer, Callable] = 10., tau_w: Union[float, ArrayType, Initializer, Callable] = 30., - tau_ref: Optional[Union[float, ArrayType, Initializer, Callable]] = 30., + tau_ref: Optional[Union[float, ArrayType, Initializer, Callable]] = None, R: Union[float, ArrayType, Initializer, Callable] = 1., V_initializer: Union[Initializer, Callable, ArrayType] = ZeroInit(), w_initializer: Union[Initializer, Callable, ArrayType] = ZeroInit(), diff --git a/brainpy/_src/initialize/random_inits.py b/brainpy/_src/initialize/random_inits.py index a51a0fbed..99419eaa6 100644 --- a/brainpy/_src/initialize/random_inits.py +++ b/brainpy/_src/initialize/random_inits.py @@ -329,7 +329,7 @@ def __call__(self, shape, dtype=None): n_cols = np.prod(shape) // n_rows matrix_shape = (n_rows, n_cols) if n_rows > n_cols else (n_cols, n_rows) norm_dst = self.rng.normal(size=matrix_shape) - q_mat, r_mat = jnp.linalg.qr(norm_dst) + q_mat, r_mat = jnp.linalg.qr(bm.as_jax(norm_dst)) # Enforce Q is uniformly distributed q_mat *= jnp.sign(jnp.diag(r_mat)) if n_rows < n_cols: diff --git a/brainpy/_src/losses/comparison.py b/brainpy/_src/losses/comparison.py index 40bc6bcb7..8a2a66c91 100644 --- a/brainpy/_src/losses/comparison.py +++ b/brainpy/_src/losses/comparison.py @@ -11,9 +11,9 @@ from typing import Tuple import jax.numpy as jnp +from jax.lax import scan from jax.scipy.special import logsumexp from jax.tree_util import tree_map -from jax.lax import scan import brainpy.math as bm from brainpy.types import ArrayType @@ -106,7 +106,7 @@ def _cel(_pred, _tar): loss = logsumexp(bm.as_jax(_pred), axis=-1) - (_pred * _tar).sum(axis=-1) return _reduce(outputs=loss, reduction=reduction) - r = tree_map(_cel, predicts, targets, is_leaf=lambda x: isinstance(x, bm.Array)) + r = tree_map(_cel, predicts, targets, is_leaf=_is_leaf) return _multi_return(r) @@ -128,7 +128,7 @@ def crs(_prd, _tar): logits = jnp.take_along_axis(_prd, 
_tar, -1).squeeze(-1) return logsumexp(bm.as_jax(_prd), axis=-1) - logits - r = tree_map(crs, predicts, targets, is_leaf=lambda x: isinstance(x, bm.Array)) + r = tree_map(crs, predicts, targets, is_leaf=_is_leaf) return _multi_return(r) @@ -142,9 +142,14 @@ def cross_entropy_sigmoid(predicts, targets): Returns: (batch, ...) tensor of the cross-entropies for each entry. """ - r = tree_map(lambda pred, tar: jnp.maximum(pred, 0) - pred * tar + jnp.log(1 + jnp.exp(-jnp.abs(pred))), - predicts, - targets) + r = tree_map( + lambda pred, tar: bm.as_jax( + bm.maximum(pred, 0) - pred * tar + bm.log(1 + bm.exp(-bm.abs(pred))) + ), + predicts, + targets, + is_leaf=_is_leaf + ) return _multi_return(r) @@ -201,7 +206,7 @@ def loss(pred, tar): norm = jnp.linalg.norm(bm.as_jax(diff), ord=1, axis=1, keepdims=False) return _reduce(outputs=norm, reduction=reduction) - r = tree_map(loss, logits, targets, is_leaf=lambda x: isinstance(x, bm.Array)) + r = tree_map(loss, logits, targets, is_leaf=_is_leaf) return _multi_return(r) @@ -228,7 +233,9 @@ def l2_loss(predicts, targets): ---------- .. [1] Bishop, Christopher M. 2006. Pattern Recognition and Machine Learning. """ - r = tree_map(lambda pred, tar: 0.5 * (pred - tar) ** 2, predicts, targets) + r = tree_map(lambda pred, tar: 0.5 * (pred - tar) ** 2, + predicts, + targets) return _multi_return(r) @@ -243,7 +250,10 @@ def mean_absolute_error(x, y, axis=None, reduction: str = 'mean'): Returns: tensor of shape (d_i, ..., for i in keep_axis) containing the mean absolute error. """ - r = tree_map(lambda a, b: _reduce(jnp.abs(a - b), reduction=reduction, axis=axis), x, y) + r = tree_map(lambda a, b: _reduce(bm.abs(a - b), reduction=reduction, axis=axis), + x, + y, + is_leaf=_is_leaf) return _multi_return(r) @@ -260,7 +270,8 @@ def mean_squared_error(predicts, targets, axis=None, reduction: str = 'mean'): """ r = tree_map(lambda a, b: _reduce((a - b) ** 2, reduction, axis=axis), predicts, - targets) + targets, + is_leaf=_is_leaf) return _multi_return(r) @@ -276,7 +287,9 @@ def mean_squared_log_error(predicts, targets, axis=None, reduction: str = 'mean' tensor of shape (d_i, ..., for i in keep_axis) containing the mean squared error. """ r = tree_map(lambda a, b: _reduce((jnp.log1p(a) - jnp.log1p(b)) ** 2, reduction, axis=axis), - predicts, targets, is_leaf=_is_leaf) + predicts, + targets, + is_leaf=_is_leaf) return _multi_return(r) @@ -309,12 +322,13 @@ def huber_loss(predicts, targets, delta: float = 1.0): def _loss(y_predict, y_target): # 0.5 * err^2 if |err| <= d # 0.5 * d^2 + d * (|err| - d) if |err| > d - diff = jnp.abs(y_predict - y_target) - return jnp.where(diff > delta, - delta * (diff - .5 * delta), - 0.5 * diff ** 2) + diff = bm.abs(y_predict - y_target) + r = bm.where(diff > delta, + delta * (diff - .5 * delta), + 0.5 * diff ** 2) + return bm.as_jax(r) - r = tree_map(_loss, targets, predicts) + r = tree_map(_loss, targets, predicts, is_leaf=_is_leaf) return _multi_return(r) @@ -382,7 +396,7 @@ def loss(pred, tar): log_not_p = bm.log_sigmoid(-pred) return -tar * log_p - (1. 
- tar) * log_not_p - r = tree_map(loss, logits, labels, is_leaf=lambda x: isinstance(x, bm.Array)) + r = tree_map(loss, logits, labels, is_leaf=_is_leaf) return _multi_return(r) @@ -433,7 +447,7 @@ def loss(pred, tar): errors = bm.as_jax(pred - tar) return jnp.logaddexp(errors, -errors) - jnp.log(2.0).astype(errors.dtype) - r = tree_map(loss, predicts, targets, is_leaf=lambda x: isinstance(x, bm.Array)) + r = tree_map(loss, predicts, targets, is_leaf=_is_leaf) return _multi_return(r) diff --git a/brainpy/_src/math/__init__.py b/brainpy/_src/math/__init__.py index 6a9a8eb53..dbe499714 100644 --- a/brainpy/_src/math/__init__.py +++ b/brainpy/_src/math/__init__.py @@ -41,9 +41,9 @@ # high-level numpy operations from .arraycreation import * from .arrayinterporate import * -# from .arraycompatible import * +from .arraycompatible import * from .others import * -from . import random +from . import random, linalg, fft # operators from .operators import * diff --git a/brainpy/_src/math/_utils.py b/brainpy/_src/math/_utils.py index d6d8e22d1..7a4950a97 100644 --- a/brainpy/_src/math/_utils.py +++ b/brainpy/_src/math/_utils.py @@ -4,7 +4,6 @@ from typing import Callable import jax -import numpy as np from jax.tree_util import tree_map from .ndarray import Array @@ -18,8 +17,8 @@ def _as_jax_array_(obj): return obj.value if isinstance(obj, Array) else obj -def _return(x): - return Array(x) if _return_bp_array else x +def _return(a): + return Array(a) if isinstance(a, jax.Array) and a.ndim > 1 else a _return_bp_array = True @@ -50,10 +49,6 @@ def wrap(op): return wrap -def _as_brainpy_array(a): - return Array(a) if isinstance(a, (np.ndarray, jax.Array)) else a - - def _is_leaf(a): return isinstance(a, Array) @@ -65,7 +60,7 @@ def new_fun(*args, **kwargs): if len(kwargs): kwargs = tree_map(_as_jax_array_, kwargs, is_leaf=_is_leaf) r = fun(*args, **kwargs) - return tree_map(_as_brainpy_array, r) if _return_bp_array else r + return tree_map(_return, r) if _return_bp_array else r new_fun.__doc__ = getattr(fun, "__doc__", None) diff --git a/brainpy/_src/math/object_transform/autograd.py b/brainpy/_src/math/object_transform/autograd.py index 364e45948..911f20cdb 100644 --- a/brainpy/_src/math/object_transform/autograd.py +++ b/brainpy/_src/math/object_transform/autograd.py @@ -18,9 +18,9 @@ from jax.util import safe_map from brainpy import errors, tools, check -from brainpy._src.math.object_transform.base import BrainPyObject -from brainpy._src.math.object_transform.abstract import ObjectTransform from brainpy._src.math.ndarray import Array, Variable, add_context, del_context +from brainpy._src.math.object_transform.abstract import ObjectTransform +from brainpy._src.math.object_transform.base import BrainPyObject __all__ = [ 'grad', # gradient of scalar function @@ -75,7 +75,7 @@ def __init__( _argnums = tuple(a + 2 for a in _argnums) if len(self._grad_vars) > 0: _argnums = (0,) + _argnums - self.nonvar_argnums = argnums + self._nonvar_argnums = argnums self.return_value = return_value self.has_aux = has_aux @@ -134,10 +134,12 @@ def __call__(self, *args, **kwargs): # old_dyn_vs = [v.value for v in self._dyn_vars] try: add_context(self.name) - grads, (outputs, new_grad_vs, new_dyn_vs) = self._call([v.value for v in self._grad_vars], - [v.value for v in self._dyn_vars], - *args, - **kwargs) + grads, (outputs, new_grad_vs, new_dyn_vs) = self._call( + [v.value for v in self._grad_vars], + [v.value for v in self._dyn_vars], + *args, + **kwargs + ) del_context(self._name) except UnexpectedTracerError as e: 
del_context(self._name) @@ -155,11 +157,11 @@ def __call__(self, *args, **kwargs): # check returned grads if len(self._grad_vars) > 0: - if self.nonvar_argnums is None: + if self._nonvar_argnums is None: grads = self._grad_tree.unflatten(grads) else: var_grads = self._grad_tree.unflatten(grads[0]) - arg_grads = grads[1] if isinstance(self.nonvar_argnums, int) else grads[1:] + arg_grads = grads[1] if isinstance(self._nonvar_argnums, int) else grads[1:] grads = (var_grads, arg_grads) # check returned value diff --git a/brainpy/_src/train/offline.py b/brainpy/_src/train/offline.py index a640407a0..994eae584 100644 --- a/brainpy/_src/train/offline.py +++ b/brainpy/_src/train/offline.py @@ -84,10 +84,6 @@ def __init__( for node in self.train_nodes: node.offline_fit_by = fit_method - # initialize the fitting method - for node in self.train_nodes: - node.offline_init() - def __repr__(self): name = self.__class__.__name__ prefix = ' ' * len(name) diff --git a/examples/training_snn_models/OTTT-SNN.py b/examples/training_snn_models/OTTT-SNN.py deleted file mode 100644 index c8a021574..000000000 --- a/examples/training_snn_models/OTTT-SNN.py +++ /dev/null @@ -1,678 +0,0 @@ -# -*- coding: utf-8 -*- - - -# python OTTT-SNN.py -data_dir ./data -dataset cifar10 -out_dir ./log -gpu-id 0 -online_update - -import argparse -import functools -import os -import sys -import time - -sys.path.append('../../') - -import jax -import jax.numpy as jnp -import numpy as np -import torch.utils.data as data -import torchvision.datasets as datasets -import torchvision.transforms as transforms -import tqdm -from torchtoolbox.transform import Cutout - -import brainpy as bp -import brainpy.math as bm - -bm.set_environment(bm.TrainingMode()) -conv_init = bp.init.KaimingNormal(mode='fan_out', scale=jnp.sqrt(2), in_axis=0) -dense_init = bp.init.Normal(0, 0.01) - - -@jax.custom_vjp -def replace(spike, rate): - return rate - - -def replace_fwd(spike, rate): - return replace(spike, rate), () - - -def replace_bwd(res, g): - return g, g - - -replace.defvjp(replace_fwd, replace_bwd) - - -class ScaledWSConv2d(bp.layers.Conv2d): - def __init__(self, - in_channels, - out_channels, - kernel_size, - stride=1, - padding=0, - groups=1, - b_initializer=bp.init.ZeroInit(), - gain=True, - eps=1e-4): - super(ScaledWSConv2d, self).__init__(in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - groups=groups, - w_initializer=conv_init, - b_initializer=b_initializer) - bp.check.is_subclass(self.mode, bm.TrainingMode) - if gain: - self.gain = bm.TrainVar(jnp.ones([1, 1, 1, self.out_channels])) - else: - self.gain = None - self.eps = eps - - def update(self, *args): - assert self.mask is None - x = args[0] if len(args) == 1 else args[1] - self._check_input_dim(x) - w = self.w.value - fan_in = np.prod(w.shape[:-1]) - mean = jnp.mean(w, axis=[0, 1, 2], keepdims=True) - var = jnp.var(w, axis=[0,1, 2], keepdims=True) - w = (w - mean) / ((var * fan_in + self.eps) ** 0.5) - if self.gain is not None: - w = w * self.gain - y = jax.lax.conv_general_dilated(lhs=bm.as_jax(x), - rhs=bm.as_jax(w), - window_strides=self.stride, - padding=self.padding, - lhs_dilation=self.lhs_dilation, - rhs_dilation=self.rhs_dilation, - feature_group_count=self.groups, - dimension_numbers=self.dimension_numbers) - return y if self.b is None else (y + self.b.value) - - -class ScaledWSLinear(bp.layers.Dense): - def __init__(self, - in_features, - out_features, - b_initializer=bp.init.ZeroInit(), - gain=True, - 
eps=1e-4): - super(ScaledWSLinear, self).__init__(num_in=in_features, - num_out=out_features, - W_initializer=dense_init, - b_initializer=b_initializer) - bp.check.is_subclass(self.mode, bm.TrainingMode) - if gain: - self.gain = bm.TrainVar(jnp.ones(1, self.num_out)) - else: - self.gain = None - self.eps = eps - - def update(self, s, x): - fan_in = self.W.shape[0] - mean = jnp.mean(self.W.value, axis=0, keepdims=True) - var = jnp.var(self.W.value, axis=0, keepdims=True) - weight = (self.W.value - mean) / ((var * fan_in + self.eps) ** 0.5) - if self.gain is not None: - weight = weight * self.gain - if self.b is not None: - return x @ weight + self.b - else: - return x @ weight - - -class Scale(bp.layers.Layer): - def __init__(self, scale: float): - super(Scale, self).__init__() - self.scale = scale - - def update(self, s, x): - return x * self.scale - - -class WrappedSNNOp(bp.layers.Layer): - def __init__(self, op): - super(WrappedSNNOp, self).__init__() - self.op = op - - def update(self, s, x): - if s['require_wrap']: - spike, rate = jnp.split(x, 2, axis=0) - out = jax.lax.stop_gradient(self.op(s, spike)) - in_for_grad = replace(spike, rate) - out_for_grad = self.op(s, in_for_grad) - output = replace(out_for_grad, out) - return output - else: - return self.op(s, x) - - -class OnlineSpikingVGG(bp.DynamicalSystem): - cfg = [64, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512] - - def __init__( - self, - neuron_model, - weight_standardization=True, - num_classes=1000, - neuron_pars: dict = None, - light_classifier=True, - batch_norm=False, - grad_with_rate: bool = False, - fc_hw: int = 3, - c_in: int = 3 - ): - super(OnlineSpikingVGG, self).__init__() - - if neuron_pars is None: - neuron_pars = dict() - self.neuron_pars = neuron_pars - self.neuron_model = neuron_model - self.grad_with_rate = grad_with_rate - self.fc_hw = fc_hw - - neuron_sizes = [(32, 32, 64), - (32, 32, 128), - (16, 16, 256), - (16, 16, 256), - (8, 8, 512), - (8, 8, 512), - (4, 4, 512), - (4, 4, 512), ] - neuron_i = 0 - layers = [] - first_conv = True - in_channels = c_in - for v in self.cfg: - if v == 'M': - layers.append(bp.layers.AvgPool2d(kernel_size=2, stride=2)) - else: - if weight_standardization: - conv2d = ScaledWSConv2d(in_channels, v, kernel_size=3, padding=1, stride=1) - if first_conv: - first_conv = False - elif self.grad_with_rate: - conv2d = WrappedSNNOp(conv2d) - layers += [conv2d, - self.neuron_model(neuron_sizes[neuron_i], **self.neuron_pars), - Scale(2.74)] - else: - conv2d = bp.layers.Conv2d(in_channels, v, kernel_size=3, padding=1, stride=1, w_initializer=conv_init, ) - if first_conv: - first_conv = False - elif self.grad_with_rate: - conv2d = WrappedSNNOp(conv2d) - if batch_norm: - layers += [conv2d, - bp.layers.BatchNorm2d(v, momentum=0.9), - self.neuron_model(neuron_sizes[neuron_i], **self.neuron_pars)] - else: - layers += [conv2d, - self.neuron_model(neuron_sizes[neuron_i], **self.neuron_pars), - Scale(2.74)] - neuron_i += 1 - in_channels = v - self.features = bp.Sequential(*layers) - - if light_classifier: - self.avgpool = bp.layers.AdaptiveAvgPool2d((self.fc_hw, self.fc_hw)) - if self.grad_with_rate: - self.classifier = WrappedSNNOp(bp.layers.Dense(512 * self.fc_hw * self.fc_hw, - num_classes, - W_initializer=dense_init)) - else: - self.classifier = bp.layers.Dense(512 * self.fc_hw * self.fc_hw, - num_classes, - W_initializer=dense_init) - else: - self.avgpool = bp.layers.AdaptiveAvgPool2d((7, 7)) - if self.grad_with_rate: - self.classifier = bp.Sequential( - WrappedSNNOp(ScaledWSLinear(512 * 7 * 
7, 4096)), - neuron_model((4096,), **self.neuron_pars, neuron_dropout=0.0), - Scale(2.74), - bp.layers.Dropout(0.5), - WrappedSNNOp(ScaledWSLinear(4096, 4096)), - neuron_model((4096,), **self.neuron_pars, neuron_dropout=0.0), - Scale(2.74), - bp.layers.Dropout(0.5), - WrappedSNNOp(bp.layers.Dense(4096, num_classes, W_initializer=dense_init)), - ) - else: - self.classifier = bp.Sequential( - ScaledWSLinear(512 * 7 * 7, 4096), - neuron_model((4096,), **self.neuron_pars, neuron_dropout=0.0), - Scale(2.74), - bp.layers.Dropout(0.5), - ScaledWSLinear(4096, 4096), - neuron_model((4096,), **self.neuron_pars, neuron_dropout=0.0), - Scale(2.74), - bp.layers.Dropout(0.5), - bp.layers.Dense(4096, num_classes, W_initializer=dense_init), - ) - - def update(self, s, x): - if self.grad_with_rate and s['fit']: - s['require_wrap'] = True - s['output_type'] = 'spike_rate' - x = self.features(s, x) - x = self.avgpool(s, x) - x = bm.flatten(x, 1) - x = self.classifier(s, x) - else: - s['require_wrap'] = False - s['output_type'] = 'spike' - x = self.features(s, x) - x = self.avgpool(s, x) - x = bm.flatten(x, 1) - x = self.classifier(s, x) - return x - - -class OnlineIFNode(bp.DynamicalSystem): - def __init__( - self, - size, - v_threshold: float = 1., - v_reset: float = None, - f_surrogate=bm.surrogate.sigmoid, - detach_reset: bool = True, - track_rate: bool = True, - neuron_dropout: float = 0.0, - name: str = None, - mode: bm.Mode = None - ): - super().__init__(name=name, mode=mode) - bp.check.is_subclass(self.mode, bm.TrainingMode) - - self.size = bp.check.is_sequence(size, elem_type=int) - self.f_surrogate = bp.check.is_callable(f_surrogate) - self.detach_reset = detach_reset - self.v_reset = v_reset - self.v_threshold = v_threshold - self.track_rate = track_rate - self.dropout = neuron_dropout - - if self.dropout > 0.0: - self.rng = bm.random.default_rng() - self.reset_state(1) - - def reset_state(self, batch_size=1): - self.v = bp.init.variable_(bm.zeros, self.size, batch_size) - self.spike = bp.init.variable_(bm.zeros, self.size, batch_size) - if self.track_rate: - self.rate_tracking = bp.init.variable_(bm.zeros, self.size, batch_size) - - def update(self, s, x): - # neuron charge - self.v.value = jax.lax.stop_gradient(self.v.value) + x - # neuron fire - spike = self.f_surrogate(self.v.value - self.v_threshold) - # spike reset - spike_d = jax.lax.stop_gradient(spike) if self.detach_reset else spike - if self.v_reset is None: - self.v -= spike_d * self.v_threshold - else: - self.v.value = (1. 
- spike_d) * self.v + spike_d * self.v_reset - # dropout - if self.dropout > 0.0 and s['fit']: - mask = self.rng.bernoulli(1 - self.dropout, self.v.shape) / (1 - self.dropout) - spike = mask * spike - self.spike.value = spike - # spike track - if self.track_rate: - self.rate_tracking += jax.lax.stop_gradient(spike) - # output - if s['output_type'] == 'spike_rate': - assert self.track_rate - return jnp.concatenate([spike, self.rate_tracking.value]) - else: - return spike - - -class OnlineLIFNode(bp.DynamicalSystem): - def __init__( - self, - size, - tau: float = 2., - decay_input: bool = False, - v_threshold: float = 1., - v_reset: float = None, - f_surrogate=bm.surrogate.sigmoid, - detach_reset: bool = True, - track_rate: bool = True, - neuron_dropout: float = 0.0, - name: str = None, - mode: bm.Mode = None - ): - super().__init__(name=name, mode=mode) - bp.check.is_subclass(self.mode, bm.TrainingMode) - - self.size = bp.check.is_sequence(size, elem_type=int) - self.tau = tau - self.decay_input = decay_input - self.v_threshold = v_threshold - self.v_reset = v_reset - self.f_surrogate = f_surrogate - self.detach_reset = detach_reset - self.track_rate = track_rate - self.dropout = neuron_dropout - - if self.dropout > 0.0: - self.rng = bm.random.default_rng() - self.reset_state(1) - - def reset_state(self, batch_size=1): - self.v = bp.init.variable_(bm.zeros, self.size, batch_size) - self.spike = bp.init.variable_(bm.zeros, self.size, batch_size) - if self.track_rate: - self.rate_tracking = bp.init.variable_(bm.zeros, self.size, batch_size) - - def update(self, s, x): - # neuron charge - if self.decay_input: - x = x / self.tau - if self.v_reset is None or self.v_reset == 0: - self.v = jax.lax.stop_gradient(self.v.value) * (1 - 1. / self.tau) + x - else: - self.v = jax.lax.stop_gradient(self.v.value) * (1 - 1. / self.tau) + self.v_reset / self.tau + x - # neuron fire - spike = self.f_surrogate(self.v - self.v_threshold) - # neuron reset - spike_d = jax.lax.stop_gradient(spike) if self.detach_reset else spike - if self.v_reset is None: - self.v -= spike_d * self.v_threshold - else: - self.v = (1. - spike_d) * self.v + spike_d * self.v_reset - # dropout - if self.dropout > 0.0 and s['fit']: - mask = self.rng.bernoulli(1 - self.dropout, spike.shape) / (1 - self.dropout) - spike = mask * spike - self.spike.value = spike - # spike - if self.track_rate: - self.rate_tracking.value = jax.lax.stop_gradient(self.rate_tracking * (1 - 1. 
/ self.tau) + spike) - if s['output_type'] == 'spike_rate': - assert self.track_rate - return jnp.concatenate((spike, self.rate_tracking.value)) - else: - return spike - - -class AverageMeter(object): - def __init__(self): - self.reset() - - def reset(self): - self.val = 0 - self.avg = 0 - self.sum = 0 - self.count = 0 - - def update(self, val, n=1): - self.val = val - self.sum += val * n - self.count += n - self.avg = self.sum / self.count - - -@functools.partial(jax.jit, static_argnums=2) -def accuracy(output, target, topk=(1,)): - """Computes the precision@k for the specified values of k""" - maxk = max(topk) - _, pred = jax.vmap(jax.lax.top_k, in_axes=(0, None))(output, maxk) - pred = pred.T - correct = (pred == target.reshape(1, -1)).astype(bm.float_) - res = [] - for k in topk: - correct_k = correct[:k].reshape(-1).sum(0) - res.append(correct_k * 100.0 / target.size) - return res - - -def classify_cifar(): - parser = argparse.ArgumentParser(description='Classify CIFAR') - parser.add_argument('-T', default=6, type=int, help='simulating time-steps') - parser.add_argument('-tau', default=2., type=float) - parser.add_argument('-b', default=128, type=int, help='batch size') - parser.add_argument('-epochs', default=300, type=int, help='number of total epochs to run') - parser.add_argument('-j', default=4, type=int, help='number of data loading workers (default: 4)') - parser.add_argument('-data_dir', type=str, default=r'D:/data') - parser.add_argument('-dataset', default='cifar10', type=str) - parser.add_argument('-out_dir', default='./logs', type=str, help='root dir for saving logs and checkpoint') - parser.add_argument('-resume', type=str, help='resume from the checkpoint path') - parser.add_argument('-opt', type=str, help='use which optimizer. SGD or Adam', default='SGD') - parser.add_argument('-lr', default=0.1, type=float, help='learning rate') - parser.add_argument('-momentum', default=0.9, type=float, help='momentum for SGD') - parser.add_argument('-lr_scheduler', default='CosALR', type=str, help='use which schedule. 
StepLR or CosALR') - parser.add_argument('-step_size', default=100, type=float, help='step_size for StepLR') - parser.add_argument('-gamma', default=0.1, type=float, help='gamma for StepLR') - parser.add_argument('-T_max', default=300, type=int, help='T_max for CosineAnnealingLR') - parser.add_argument('-drop_rate', type=float, default=0.0) - parser.add_argument('-weight_decay', type=float, default=0.0) - parser.add_argument('-loss_lambda', type=float, default=0.05) - parser.add_argument('-online_update', action='store_true') - parser.add_argument('-gpu-id', default='0', type=str, help='gpu id') - args = parser.parse_args() - os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id - - # datasets - transform_train = transforms.Compose([ - transforms.RandomCrop(32, padding=4), - Cutout(), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), - ]) - transform_test = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), - ]) - if args.dataset == 'cifar10': - dataloader = datasets.CIFAR10 - num_classes = 10 - else: - dataloader = datasets.CIFAR100 - num_classes = 100 - trainset = dataloader(root=args.data_dir, train=True, download=True, transform=transform_train) - train_data_loader = data.DataLoader(trainset, batch_size=args.b, shuffle=True, num_workers=args.j) - testset = dataloader(root=args.data_dir, train=False, download=False, transform=transform_test) - test_data_loader = data.DataLoader(testset, batch_size=args.b, shuffle=False, num_workers=args.j) - - # network - net = OnlineSpikingVGG(neuron_model=OnlineLIFNode, - neuron_pars=dict(tau=args.tau, - neuron_dropout=args.drop_rate, - f_surrogate=bm.surrogate.sigmoid, - track_rate=True, - v_reset=None), - weight_standardization=True, - num_classes=num_classes, - grad_with_rate=True, - fc_hw=1, - c_in=3) - print('Total Parameters: %.2fM' % ( - sum(p.size for p in net.vars().subset(bm.TrainVar).unique().values()) / 1000000.0)) - - # path - out_dir = os.path.join(args.out_dir, f'{args.dataset}_T_{args.T}_{args.opt}_lr_{args.lr}_') - if args.lr_scheduler == 'CosALR': - out_dir += f'CosALR_{args.T_max}' - elif args.lr_scheduler == 'StepLR': - out_dir += f'StepLR_{args.step_size}_{args.gamma}' - else: - raise NotImplementedError(args.lr_scheduler) - if args.online_update: - out_dir += '_online' - os.makedirs(out_dir, exist_ok=True) - with open(os.path.join(out_dir, 'args.txt'), 'w', encoding='utf-8') as args_txt: - args_txt.write(str(args)) - - t_step = args.T - - @bm.to_object(child_objs=net) - def single_step(x, y, fit=True): - out = net({'fit': fit}, x) - if args.loss_lambda > 0.0: - y = bm.one_hot(y, 10, dtype=bm.float_) - l = bp.losses.mean_squared_error(out, y) * args.loss_lambda - l += (1 - args.loss_lambda) * bp.losses.cross_entropy_loss(out, y) - l /= t_step - else: - l = bp.losses.cross_entropy_loss(out, y) / t_step - return l, out - - @bm.jit - @bm.to_object(child_objs=net) - def inference_fun(x, y): - l, out = bm.for_loop(lambda _: single_step(x, y, False), - jnp.arange(t_step), - child_objs=net) - out = out.sum(0) - n = jnp.sum(jnp.argmax(out, axis=1) == y) - return l.sum(), n, out - - grad_fun = bm.grad(single_step, grad_vars=net.train_vars().unique(), return_value=True, has_aux=True) - - if args.lr_scheduler == 'StepLR': - lr = bp.optim.StepLR(args.lr, step_size=args.step_size, gamma=args.gamma) - elif args.lr_scheduler == 'CosALR': - lr = bp.optim.CosineAnnealingLR(args.lr, 
T_max=args.T_max) - else: - raise NotImplementedError(args.lr_scheduler) - - if args.opt == 'SGD': - optimizer = bp.optim.Momentum(lr, net.train_vars().unique(), momentum=args.momentum, weight_decay=args.weight_decay) - elif args.opt == 'Adam': - optimizer = bp.optim.AdamW(lr, net.train_vars().unique(), weight_decay=args.weight_decay) - else: - raise NotImplementedError(args.opt) - - @bm.jit - @bm.to_object(child_objs=(optimizer, grad_fun)) - def train_fun(x, y): - if args.online_update: - final_loss, final_out = 0., 0. - for _ in range(t_step): - grads, l, out = grad_fun(x, y) - optimizer.update(grads) - final_loss += l - final_out += out - else: - final_grads, final_loss, final_out = grad_fun(x, y) - for _ in range(t_step - 1): - grads, l, out = grad_fun(x, y) - final_grads = jax.tree_util.tree_map(lambda a, b: a + b, final_grads, grads) - final_loss += l - final_out += out - optimizer.update(final_grads) - n = jnp.sum(jnp.argmax(final_out, axis=1) == y) - return final_loss, n, final_out - - start_epoch = 0 - max_test_acc = 0 - if args.resume: - checkpoint = bp.checkpoints.load(args.resume) - net.load_state_dict(checkpoint['net']) - optimizer.load_state_dict(checkpoint['optimizer']) - start_epoch = checkpoint['epoch'] + 1 - max_test_acc = checkpoint['max_test_acc'] - - train_samples = len(train_data_loader) - test_samples = len(test_data_loader) - for epoch in range(start_epoch, args.epochs): - start_time = time.time() - - batch_time = AverageMeter() - losses = AverageMeter() - top1 = AverageMeter() - top5 = AverageMeter() - end = time.time() - - train_loss = 0 - train_acc = 0 - pbar = tqdm.tqdm(total=train_samples) - for frame, label in train_data_loader: - frame = jnp.asarray(frame).transpose(0, 2, 3, 1) - label = jnp.asarray(label) - net.reset_state(frame.shape[0]) - batch_loss, n, total_fr = train_fun(frame, label) - prec1, prec5 = accuracy(total_fr, label, topk=(1, 5)) - train_loss += batch_loss * label.size - train_acc += n - losses.update(batch_loss, frame.shape[0]) - top1.update(prec1.item(), frame.shape[0]) - top5.update(prec5.item(), frame.shape[0]) - - # measure elapsed time - batch_time.update(time.time() - end) - end = time.time() - - # plot progress - pbar.update(1) - pbar.set_description( - 'Batch: {bt:.3f}s | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format( - bt=batch_time.avg, loss=losses.avg, top1=top1.avg, top5=top5.avg, - ) - ) - pbar.close() - - train_loss /= train_samples - train_acc /= train_samples - optimizer.lr.step_epoch() - - batch_time = AverageMeter() - losses = AverageMeter() - top1 = AverageMeter() - top5 = AverageMeter() - end = time.time() - - test_loss = 0 - test_acc = 0 - pbar = tqdm.tqdm(total=test_samples) - for frame, label in test_data_loader: - frame = jnp.asarray(frame).transpose(0, 2, 3, 1) - label = jnp.asarray(label) - net.reset_state(frame.shape[0]) - total_loss, n, out = inference_fun(frame, label) - test_loss += total_loss * label.size - test_acc += n - prec1, prec5 = accuracy(out, label, topk=(1, 5)) - losses.update(total_loss, frame.shape[0]) - top1.update(prec1.item(), frame.shape[0]) - top5.update(prec5.item(), frame.shape[0]) - batch_time.update(time.time() - end) - end = time.time() - - # plot progress - pbar.update(1) - pbar.set_description( - 'Batch: {bt:.3f}s | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format( - bt=batch_time.avg, loss=losses.avg, top1=top1.avg, top5=top5.avg, - ) - ) - pbar.close() - - test_loss /= test_samples - test_acc /= test_samples - - if test_acc > max_test_acc: - 
max_test_acc = test_acc - checkpoint = { - 'net': net.state_dict(), - 'optimizer': optimizer.state_dict(), - 'epoch': epoch, - 'max_test_acc': max_test_acc - } - bp.checkpoints.save(out_dir, checkpoint, max_test_acc) - - total_time = time.time() - start_time - print(f'epoch={epoch}, train_loss={train_loss}, train_acc={train_acc}, ' - f'test_loss={test_loss}, test_acc={test_acc}, max_test_acc={max_test_acc}, ' - f'total_time={total_time}') - - -if __name__ == '__main__': - classify_cifar() diff --git a/examples/training_snn_models/SurrogateGrad_lif-ANN-style.py b/examples/training_snn_models/SurrogateGrad_lif-ANN-style.py deleted file mode 100644 index fbffdb529..000000000 --- a/examples/training_snn_models/SurrogateGrad_lif-ANN-style.py +++ /dev/null @@ -1,139 +0,0 @@ -# -*- coding: utf-8 -*- - - -""" -Reproduce the results of the``spytorch`` tutorial 1: - -- https://github.com/surrogate-gradient-learning/spytorch/blob/master/notebooks/SpyTorchTutorial1.ipynb - -""" - -import time - -import jax.numpy as jnp -import matplotlib.pyplot as plt -import numpy as np -from matplotlib.gridspec import GridSpec - -import brainpy as bp -import brainpy.math as bm - - -class SNN(bp.Network): - def __init__(self, num_in, num_rec, num_out): - super(SNN, self).__init__() - - # parameters - self.num_in = num_in - self.num_rec = num_rec - self.num_out = num_out - - # neuron groups - self.i = bp.neurons.InputGroup(num_in) - self.r = bp.neurons.LIF(num_rec, tau=10, V_reset=0, V_rest=0, V_th=1.) - self.o = bp.neurons.LeakyIntegrator(num_out, tau=5) - - # synapse: i->r - self.i2r = bp.synapses.Exponential(self.i, self.r, bp.conn.All2All(), tau=10., - output=bp.synouts.CUBA(target_var=None), - g_max=bp.init.KaimingNormal(scale=20.)) - # synapse: r->o - self.r2o = bp.synapses.Exponential(self.r, self.o, bp.conn.All2All(), tau=10., - output=bp.synouts.CUBA(target_var=None), - g_max=bp.init.KaimingNormal(scale=20.)) - - # whole model - self.model = bp.Sequential(self.i, self.i2r, self.r, self.r2o, self.o) - - def update(self, tdi, spike): - self.model(tdi, spike) - return self.o.V.value - - -def plot_voltage_traces(mem, spk=None, dim=(3, 5), spike_height=5): - gs = GridSpec(*dim) - mem = 1. * mem - if spk is not None: - mem[spk > 0.0] = spike_height - mem = bm.as_numpy(mem) - for i in range(np.prod(dim)): - if i == 0: - a0 = ax = plt.subplot(gs[i]) - else: - ax = plt.subplot(gs[i], sharey=a0) - ax.plot(mem[i]) - plt.tight_layout() - plt.show() - - -def print_classification_accuracy(output, target): - """ Dirty little helper function to compute classification accuracy. """ - m = bm.max(output, axis=1) # max over time - am = bm.argmax(m, axis=1) # argmax over output units - acc = bm.mean(target == am) # compare to labels - print("Accuracy %.3f" % acc) - - -with bm.environment(mode=bm.training_mode): - net = SNN(100, 4, 2) - -num_step = 2000 -num_sample = 256 -freq = 5 # Hz -mask = bm.random.rand(num_sample, num_step, net.num_in) -x_data = bm.zeros((num_sample, num_step, net.num_in)) -x_data[mask < freq * bm.get_dt() / 1000.] 
= 1.0 -y_data = bm.asarray(bm.random.rand(num_sample) < 0.5, dtype=bm.float_) -rng = bm.random.RandomState() - -# Before training -runner = bp.DSRunner(net, monitors={'r.spike': net.r.spike, 'r.membrane': net.r.V}) -out = runner.run(inputs=x_data, reset_state=True) -plot_voltage_traces(runner.mon.get('r.membrane'), runner.mon.get('r.spike')) -plot_voltage_traces(out) -print_classification_accuracy(out, y_data) - - -@bm.to_object(child_objs=net, dyn_vars=rng) # add nodes and vars used in this function -def loss(): - key = rng.split_key() - X = bm.random.permutation(x_data, key=key) - Y = bm.random.permutation(y_data, key=key) - looper = bp.DSRunner(net, numpy_mon_after_run=False, progress_bar=False) - predictions = looper.run(inputs=X, reset_state=True) - predictions = jnp.max(predictions, axis=1) - return bp.losses.cross_entropy_loss(predictions, Y) - - -grad = bm.grad(loss, grad_vars=loss.train_vars().unique(), return_value=True) -optimizer = bp.optim.Adam(lr=2e-3, train_vars=net.train_vars().unique()) - - -@bm.to_object(child_objs=(grad, optimizer)) # add nodes and vars used in this function -def train(_): - grads, l = grad() - optimizer.update(grads) - return l - - -# train the network -net.reset_state(num_sample) -train_losses = [] -for i in range(0, 3000, 100): - t0 = time.time() - ls = bm.for_loop(train, operands=bm.arange(i, i + 100, 1)) - print(f'Train {i + 100} epoch, loss = {jnp.mean(ls):.4f}, used time {time.time() - t0:.4f} s') - train_losses.append(ls) - -# visualize the training losses -plt.plot(bm.as_numpy(jnp.concatenate(train_losses))) -plt.xlabel("Epoch") -plt.ylabel("Training Loss") -plt.show() - -# predict the output according to the input data -runner = bp.DSRunner(net, monitors={'r.spike': net.r.spike, 'r.membrane': net.r.V}) -out = runner.run(inputs=x_data, reset_state=True) -plot_voltage_traces(runner.mon.get('r.membrane'), runner.mon.get('r.spike')) -plot_voltage_traces(out) -print_classification_accuracy(out, y_data) diff --git a/examples/training_snn_models/SurrogateGrad_lif.py b/examples/training_snn_models/SurrogateGrad_lif.py deleted file mode 100644 index c4af45567..000000000 --- a/examples/training_snn_models/SurrogateGrad_lif.py +++ /dev/null @@ -1,143 +0,0 @@ -# -*- coding: utf-8 -*- - - -""" -Reproduce the results of the``spytorch`` tutorial 1: - -- https://github.com/surrogate-gradient-learning/spytorch/blob/master/notebooks/SpyTorchTutorial1.ipynb - -""" - -import time - -import jax.numpy as jnp -import numpy as np -import matplotlib.pyplot as plt -from matplotlib.gridspec import GridSpec - -import brainpy as bp -import brainpy.math as bm - - -class SNN(bp.Network): - def __init__(self, num_in, num_rec, num_out): - super(SNN, self).__init__() - - # parameters - self.num_in = num_in - self.num_rec = num_rec - self.num_out = num_out - - # neuron groups - self.i = bp.neurons.InputGroup(num_in, mode=bm.training_mode) - self.r = bp.neurons.LIF(num_rec, tau=10, V_reset=0, V_rest=0, V_th=1., mode=bm.training_mode) - self.o = bp.neurons.LeakyIntegrator(num_out, tau=5, mode=bm.training_mode) - - # synapse: i->r - self.i2r = bp.synapses.Exponential(self.i, self.r, bp.conn.All2All(), - output=bp.synouts.CUBA(), tau=10., - g_max=bp.init.KaimingNormal(scale=20.), - mode=bm.training_mode) - # synapse: r->o - self.r2o = bp.synapses.Exponential(self.r, self.o, bp.conn.All2All(), - output=bp.synouts.CUBA(), tau=10., - g_max=bp.init.KaimingNormal(scale=20.), - mode=bm.training_mode) - - def update(self, tdi, spike): - self.i2r(tdi, spike) - self.r2o(tdi) - self.r(tdi) - 
self.o(tdi) - return self.o.V.value - - -def plot_voltage_traces(mem, spk=None, dim=(3, 5), spike_height=5): - gs = GridSpec(*dim) - mem = 1. * mem - if spk is not None: - mem[spk > 0.0] = spike_height - mem = bm.as_numpy(mem) - for i in range(np.prod(dim)): - if i == 0: - a0 = ax = plt.subplot(gs[i]) - else: - ax = plt.subplot(gs[i], sharey=a0) - ax.plot(mem[i]) - plt.tight_layout() - plt.show() - - -def print_classification_accuracy(output, target): - """ Dirty little helper function to compute classification accuracy. """ - m = bm.max(output, axis=1) # max over time - am = bm.argmax(m, axis=1) # argmax over output units - acc = bm.mean(target == am) # compare to labels - print("Accuracy %.3f" % acc) - - -with bm.environment(mode=bm.training_mode): - net = SNN(100, 4, 2) - -num_step = 2000 -num_sample = 256 -freq = 5 # Hz -mask = bm.random.rand(num_sample, num_step, net.num_in) -x_data = bm.zeros((num_sample, num_step, net.num_in)) -x_data[mask < freq * bm.get_dt() / 1000.] = 1.0 -y_data = bm.asarray(bm.random.rand(num_sample) < 0.5, dtype=bm.float_) -rng = bm.random.RandomState(123) - - -# Before training -runner = bp.DSRunner(net, monitors={'r.spike': net.r.spike, 'r.membrane': net.r.V}) -out = runner.run(inputs=x_data.value, reset_state=True) -plot_voltage_traces(runner.mon.get('r.membrane'), runner.mon.get('r.spike')) -plot_voltage_traces(out) -print_classification_accuracy(out, y_data) - - -@bm.to_object(child_objs=net, dyn_vars=rng) # add nodes and vars used here -def loss(): - key = rng.split_key() - X = bm.random.permutation(x_data, key=key) - Y = bm.random.permutation(y_data, key=key) - looper = bp.DSRunner(net, numpy_mon_after_run=False, progress_bar=False) - predictions = looper.run(inputs=X, reset_state=True) - predictions = jnp.max(predictions, axis=1) - return bp.losses.cross_entropy_loss(predictions, Y) - - -grad = bm.grad(loss, grad_vars=loss.train_vars().unique(), return_value=True) -optimizer = bp.optim.Adam(lr=2e-3, train_vars=net.train_vars().unique()) - - -@bm.to_object(child_objs=(grad, optimizer)) # add nodes and vars used here -def train(_): - grads, l = grad() - optimizer.update(grads) - return l - - -# train the network -net.reset_state(num_sample) -train_losses = [] -b = 100 -for i in range(0, 3000, b): - t0 = time.time() - ls = bm.for_loop(train, operands=bm.arange(i, i + b, 1)) - print(f'Train {i + b} epoch, loss = {jnp.mean(ls):.4f}, used time {time.time() - t0:.4f} s') - train_losses.append(ls) - -# visualize the training losses -plt.plot(bm.as_numpy(jnp.concatenate(train_losses))) -plt.xlabel("Epoch") -plt.ylabel("Training Loss") -plt.show() - -# predict the output according to the input data -runner = bp.DSRunner(net, monitors={'r.spike': net.r.spike, 'r.membrane': net.r.V}) -out = runner.run(inputs=x_data, reset_state=True) -plot_voltage_traces(runner.mon.get('r.membrane'), runner.mon.get('r.spike')) -plot_voltage_traces(out) -print_classification_accuracy(out, y_data) diff --git a/examples/training_snn_models/SurrogateGrad_lif_fashion_mnist.py b/examples/training_snn_models/SurrogateGrad_lif_fashion_mnist.py deleted file mode 100644 index 5237e903c..000000000 --- a/examples/training_snn_models/SurrogateGrad_lif_fashion_mnist.py +++ /dev/null @@ -1,244 +0,0 @@ -# -*- coding: utf-8 -*- - -""" -Reproduce the results of the``spytorch`` tutorial 2 & 3: - -- https://github.com/surrogate-gradient-learning/spytorch/blob/master/notebooks/SpyTorchTutorial2.ipynb -- 
https://github.com/surrogate-gradient-learning/spytorch/blob/master/notebooks/SpyTorchTutorial3.ipynb - -""" - -import brainpy_datasets as bd -import jax.numpy as jnp -import matplotlib.pyplot as plt -import numpy as np -from matplotlib.gridspec import GridSpec - -import brainpy as bp -import brainpy.math as bm - -bm.set_environment(bm.training_mode) - - -class SNN(bp.Network): - """ - This class implements a spiking neural network model with three layers: - - i >> r >> o - - Each two layers are connected through the exponential synapse model. - """ - - def __init__(self, num_in, num_rec, num_out): - super(SNN, self).__init__() - - # parameters - self.num_in = num_in - self.num_rec = num_rec - self.num_out = num_out - - # neuron groups - self.i = bp.neurons.InputGroup(num_in) - self.r = bp.neurons.LIF(num_rec, tau=10, V_reset=0, V_rest=0, V_th=1.) - self.o = bp.neurons.LeakyIntegrator(num_out, tau=5) - - # synapse: i->r - self.i2r = bp.synapses.Exponential(self.i, self.r, bp.conn.All2All(), - output=bp.synouts.CUBA(target_var=None), tau=10., - g_max=bp.init.KaimingNormal(scale=2.)) - # synapse: r->o - self.r2o = bp.synapses.Exponential(self.r, self.o, bp.conn.All2All(), - output=bp.synouts.CUBA(target_var=None), tau=10., - g_max=bp.init.KaimingNormal(scale=2.)) - - self.model = bp.Sequential( - self.i, self.i2r, self.r, self.r2o, self.o - ) - - def update(self, shared, spike): - self.model(shared, spike) - return self.o.V.value - - -def plot_voltage_traces(mem, spk=None, dim=(3, 5), spike_height=5): - gs = GridSpec(*dim) - mem = 1. * mem - if spk is not None: - mem[spk > 0.0] = spike_height - mem = bm.as_numpy(mem) - for i in range(np.prod(dim)): - if i == 0: - a0 = ax = plt.subplot(gs[i]) - else: - ax = plt.subplot(gs[i], sharey=a0) - ax.plot(mem[i]) - ax.axis("off") - plt.tight_layout() - plt.show() - - -def print_classification_accuracy(output, target): - """ Dirty little helper function to compute classification accuracy. """ - m = jnp.max(output, axis=1) # max over time - am = jnp.argmax(m, axis=1) # argmax over output units - acc = jnp.mean(target == am) # compare to labels - print("Accuracy %.3f" % acc) - - -def current2firing_time(x, tau=20., thr=0.2, tmax=1.0, epsilon=1e-7): - """Computes first firing time latency for a current input x - assuming the charge time of a current based LIF neuron. - - Args: - x -- The "current" values - - Keyword args: - tau -- The membrane time constant of the LIF neuron to be charged - thr -- The firing threshold value - tmax -- The maximum time returned - epsilon -- A generic (small) epsilon > 0 - - Returns: - Time to first spike for each "current" x - """ - x = np.clip(x, thr + epsilon, 1e9) - T = tau * np.log(x / (x - thr)) - T = np.where(x < thr, tmax, T) - return T - - -def sparse_data_generator(X, y, batch_size, nb_steps, nb_units, shuffle=True): - """ This generator takes datasets in analog format and - generates spiking network input as sparse tensors. - - Args: - X: The data ( sample x event x 2 ) the last dim holds (time,neuron) tuples - y: The labels - """ - - labels_ = np.array(y, dtype=bm.int_) - sample_index = np.arange(len(X)) - - # compute discrete firing times - tau_eff = 2. 
/ bm.get_dt() - unit_numbers = np.arange(nb_units) - firing_times = np.array(current2firing_time(X, tau=tau_eff, tmax=nb_steps), dtype=bm.int_) - - if shuffle: - np.random.shuffle(sample_index) - - counter = 0 - number_of_batches = len(X) // batch_size - while counter < number_of_batches: - batch_index = sample_index[batch_size * counter:batch_size * (counter + 1)] - all_batch, all_times, all_units = [], [], [] - for bc, idx in enumerate(batch_index): - c = firing_times[idx] < nb_steps - times, units = firing_times[idx][c], unit_numbers[c] - batch = bc * np.ones(len(times), dtype=bm.int_) - all_batch.append(batch) - all_times.append(times) - all_units.append(units) - all_batch = np.concatenate(all_batch).flatten() - all_times = np.concatenate(all_times).flatten() - all_units = np.concatenate(all_units).flatten() - x_batch = bm.zeros((batch_size, nb_steps, nb_units)) - x_batch[all_batch, all_times, all_units] = 1. - y_batch = jnp.asarray(labels_[batch_index]) - yield x_batch, y_batch - counter += 1 - - -def train(model, x_data, y_data, lr=1e-3, nb_epochs=10, batch_size=128, nb_steps=128, nb_inputs=28 * 28): - def loss_fun(predicts, targets): - predicts, mon = predicts - # Here we set up our regularizer loss - # The strength paramters here are merely a guess and - # there should be ample room for improvement by - # tuning these paramters. - l1_loss = 1e-5 * jnp.sum(mon['r.spike']) # L1 loss on total number of spikes - l2_loss = 1e-5 * jnp.mean(jnp.sum(jnp.sum(mon['r.spike'], axis=0), axis=0) ** 2) # L2 loss on spikes per neuron - # predictions - predicts = jnp.max(predicts, axis=1) - loss = bp.losses.cross_entropy_loss(predicts, targets) - return loss + l2_loss + l1_loss - - trainer = bp.BPTT( - model, - loss_fun, - optimizer=bp.optim.Adam(lr=lr), - monitors={'r.spike': net.r.spike}, - ) - trainer.fit(lambda: sparse_data_generator(x_data, y_data, batch_size, nb_steps, nb_inputs), - num_epoch=nb_epochs) - return trainer.get_hist_metric('fit') - - -def compute_classification_accuracy(model, x_data, y_data, batch_size=128, nb_steps=100, nb_inputs=28 * 28): - """ Computes classification accuracy on supplied data in batches. 
""" - accs = [] - runner = bp.DSRunner(model, progress_bar=False) - for x_local, y_local in sparse_data_generator(x_data, y_data, batch_size, nb_steps, nb_inputs, shuffle=False): - output = runner.predict(inputs=x_local, reset_state=True) - m = jnp.max(output, 1) # max over time - am = jnp.argmax(m, 1) # argmax over output units - tmp = jnp.mean(y_local == am) # compare to labels - accs.append(tmp) - return jnp.mean(bm.asarray(accs)) - - -def get_mini_batch_results(model, x_data, y_data, batch_size=128, nb_steps=100, nb_inputs=28 * 28): - runner = bp.DSRunner(model, - monitors={'r.spike': model.r.spike}, - progress_bar=False) - data = sparse_data_generator(x_data, y_data, batch_size, nb_steps, nb_inputs, shuffle=False) - x_local, y_local = next(data) - output = runner.predict(inputs=x_local, reset_state=True) - return output, runner.mon.get('r.spike') - - -num_input = 28 * 28 -net = SNN(num_in=num_input, num_rec=100, num_out=10) - -# load the dataset -root = r"D:\data" -train_dataset = bd.vision.FashionMNIST(root, split='train', download=True) -test_dataset = bd.vision.FashionMNIST(root, split='test', download=True) - -# Standardize data -x_train = np.array(train_dataset.data, dtype=bm.float_) -x_train = x_train.reshape(x_train.shape[0], -1) / 255 -y_train = np.array(train_dataset.targets, dtype=bm.int_) -x_test = np.array(test_dataset.data, dtype=bm.float_) -x_test = x_test.reshape(x_test.shape[0], -1) / 255 -y_test = np.array(test_dataset.targets, dtype=bm.int_) - -# training -train_losses = train(net, x_train, y_train, lr=1e-3, nb_epochs=30, batch_size=256, nb_steps=100, nb_inputs=28 * 28) - -plt.figure(figsize=(3.3, 2), dpi=150) -plt.plot(train_losses) -plt.xlabel("Epoch") -plt.ylabel("Loss") -plt.show() - -print("Training accuracy: %.3f" % (compute_classification_accuracy(net, x_train, y_train, batch_size=512))) -print("Test accuracy: %.3f" % (compute_classification_accuracy(net, x_test, y_test, batch_size=512))) - -outs, spikes = get_mini_batch_results(net, x_train, y_train) -# Let's plot the hidden layer spiking activity for some input stimuli -fig = plt.figure(dpi=100) -plot_voltage_traces(outs) -plt.show() - -nb_plt = 4 -gs = GridSpec(1, nb_plt) -plt.figure(figsize=(7, 3), dpi=150) -for i in range(nb_plt): - plt.subplot(gs[i]) - plt.imshow(bm.as_numpy(spikes[i]).T, cmap=plt.cm.gray_r, origin="lower") - if i == 0: - plt.xlabel("Time") - plt.ylabel("Units") -plt.tight_layout() -plt.show() diff --git a/examples/training_snn_models/fashion_mnist_conv_lif.py b/examples/training_snn_models/fashion_mnist_conv_lif.py deleted file mode 100644 index 6b87ee834..000000000 --- a/examples/training_snn_models/fashion_mnist_conv_lif.py +++ /dev/null @@ -1,261 +0,0 @@ -# -*- coding: utf-8 -*- - -import argparse -import os -import sys -import time -from functools import partial - -import brainpy_datasets as bd -from jax import lax -import jax.numpy as jnp - -import brainpy as bp -import brainpy.math as bm -from brainpy.tools import DotDict - -bm.set_environment(mode=bm.training_mode, dt=1.) 
- - -class ConvLIF(bp.DynamicalSystem): - def __init__(self, n_time: int, n_channel: int, tau: float = 5.): - super().__init__() - self.n_time = n_time - - lif_par = dict(keep_size=True, V_rest=0., V_reset=0., V_th=1., - tau=tau, spike_fun=bm.surrogate.arctan) - - self.block1 = bp.Sequential( - bp.layers.Conv2d(1, n_channel, kernel_size=3, padding=(1, 1), b_initializer=None), - bp.layers.BatchNorm2d(n_channel, momentum=0.9), - bp.neurons.LIF((28, 28, n_channel), **lif_par) - ) - self.block2 = bp.Sequential( - bp.layers.MaxPool([2, 2], 2, channel_axis=-1), # 14 * 14 - bp.layers.Conv2d(n_channel, n_channel, kernel_size=3, padding=(1, 1), b_initializer=None), - bp.layers.BatchNorm2d(n_channel, momentum=0.9), - bp.neurons.LIF((14, 14, n_channel), **lif_par), - ) - self.block3 = bp.Sequential( - bp.layers.MaxPool([2, 2], 2, channel_axis=-1), # 7 * 7 - bp.layers.Flatten(), - bp.layers.Dense(n_channel * 7 * 7, n_channel * 4 * 4, b_initializer=None), - bp.neurons.LIF(4 * 4 * n_channel, **lif_par), - ) - self.block4 = bp.Sequential( - bp.layers.Dense(n_channel * 4 * 4, 10, b_initializer=None), - bp.neurons.LIF(10, **lif_par), - ) - - def update(self, sha, x): - self.block1(sha, x) # x.shape = [B, H, W, C] - self.block2(sha, self.block1[-1].spike.value) - self.block3(sha, self.block2[-1].spike.value) - self.block4(sha, self.block3[-1].spike.value) - return self.block4[-1].spike.value - - -class IFNode(bp.DynamicalSystem): - """The Integrate-and-Fire neuron. The voltage of the IF neuron will - not decay as that of the LIF neuron. The sub-threshold neural dynamics - of it is as followed: - - .. math:: - V[t] = V[t-1] + X[t] - """ - - def __init__(self, size: tuple, v_threshold: float = 1., v_reset: float = 0., - spike_fun=bm.surrogate.arctan, mode=None, reset_mode='soft'): - super().__init__(mode=mode) - bp.check.is_subclass(self.mode, bm.TrainingMode) - - self.size = bp.check.is_sequence(size, elem_type=int, allow_none=False) - self.reset_mode = bp.check.is_string(reset_mode, candidates=['hard', 'soft']) - self.v_threshold = bp.check.is_float(v_threshold) - self.v_reset = bp.check.is_float(v_reset) - self.spike_fun = bp.check.is_callable(spike_fun) - - # variables - self.V = bm.Variable(jnp.zeros((1,) + size, dtype=bm.float_), batch_axis=0) - - def reset_state(self, batch_size): - self.V.value = jnp.zeros((batch_size,) + self.size, dtype=bm.float_) - - def update(self, s, x): - self.V.value += x - spike = self.spike_fun(self.V - self.v_threshold) - # s = lax.stop_gradient(spike) - s = spike - if self.reset_mode == 'hard': - one = lax.convert_element_type(1., bm.float_) - self.V.value = self.v_reset * s + (one - s) * self.V - else: - self.V -= s * self.v_threshold - return spike - - -class ConvIF(bp.DynamicalSystem): - def __init__(self, n_time: int, n_channel: int): - super().__init__() - self.n_time = n_time - - self.block1 = bp.Sequential( - bp.layers.Conv2d(1, n_channel, kernel_size=3, padding=(1, 1), ), - bp.layers.BatchNorm2d(n_channel, momentum=0.9), - IFNode((28, 28, n_channel), spike_fun=bm.surrogate.arctan) - ) - self.block2 = bp.Sequential( - bp.layers.MaxPool([2, 2], 2, channel_axis=-1), # 14 * 14 - bp.layers.Conv2d(n_channel, n_channel, kernel_size=3, padding=(1, 1), ), - bp.layers.BatchNorm2d(n_channel, momentum=0.9), - IFNode((14, 14, n_channel), spike_fun=bm.surrogate.arctan), - ) - self.block3 = bp.Sequential( - bp.layers.MaxPool([2, 2], 2, channel_axis=-1), # 7 * 7 - bp.layers.Flatten(), - bp.layers.Dense(n_channel * 7 * 7, n_channel * 4 * 4,), - IFNode((4 * 4 * n_channel,), 
spike_fun=bm.surrogate.arctan), - ) - self.block4 = bp.Sequential( - bp.layers.Dense(n_channel * 4 * 4, 10, ), - IFNode((10,), spike_fun=bm.surrogate.arctan), - ) - - def update(self, sha, x): - x = self.block1(sha, x) # x.shape = [B, H, W, C] - x = self.block2(sha, x) - x = self.block3(sha, x) - x = self.block4(sha, x) - return x - - -def main(): - parser = argparse.ArgumentParser(description='Classify Fashion-MNIST') - parser.add_argument('-platform', default='cpu', help='platform') - parser.add_argument('-model', default='lif', help='Neuron model to use') - parser.add_argument('-n_time', default=4, type=int, help='simulating time-steps') - parser.add_argument('-tau', default=5., type=float, help='LIF time constant') - parser.add_argument('-batch', default=128, type=int, help='batch size') - parser.add_argument('-n_channel', default=128, type=int, help='channels of ConvLIF') - parser.add_argument('-n_epoch', default=64, type=int, metavar='N', help='number of total epochs to run') - parser.add_argument('-data-dir', default='./data', type=str, help='root dir of Fashion-MNIST dataset') - parser.add_argument('-out-dir', default='./logs', type=str, help='root dir for saving logs and checkpoint') - parser.add_argument('-lr', default=0.1, type=float, help='learning rate') - args = parser.parse_args() - print(args) - - bm.set_platform(args.platform) - - # net - if args.model == 'if': - net = ConvIF(n_time=args.n_time, n_channel=args.n_channel) - out_dir = os.path.join(args.out_dir, - f'{args.model}_T{args.n_time}_b{args.batch}_' - f'lr{args.lr}_c{args.n_channel}') - elif args.model == 'lif': - net = ConvLIF(n_time=args.n_time, n_channel=args.n_channel, tau=args.tau) - out_dir = os.path.join(args.out_dir, - f'{args.model}_T{args.n_time}_b{args.batch}_' - f'lr{args.lr}_c{args.n_channel}_tau{args.tau}') - else: - raise ValueError - - # prediction function - def inference_fun(X, fit=True): - net.reset_state(X.shape[0]) - return bm.for_loop(lambda sha: net(sha.update(dt=bm.dt, fit=fit), X), - DotDict(t=jnp.arange(args.n_time, dtype=bm.float_), - i=jnp.arange(args.n_time, dtype=bm.int_)), - child_objs=net) - - # loss function - @bm.to_object(child_objs=net) - def loss_fun(X, Y, fit=True): - fr = jnp.max(inference_fun(X, fit), axis=0) - ys_onehot = bm.one_hot(Y, 10, dtype=bm.float_) - l = bp.losses.mean_squared_error(fr, ys_onehot) - n = jnp.sum(fr.argmax(1) == Y) - return l, n - - predict_loss_fun = bm.jit(partial(loss_fun, fit=True), child_objs=loss_fun) - - grad_fun = bm.grad(loss_fun, grad_vars=net.train_vars().unique(), has_aux=True, return_value=True) - - # optimizer - optimizer = bp.optim.Adam(bp.optim.ExponentialDecay(0.2, 1, 0.9999), - train_vars=net.train_vars().unique()) - - @bm.jit - @bm.to_object(child_objs=(grad_fun, optimizer)) - def train_fun(X, Y): - grads, l, n = grad_fun(X, Y) - optimizer.update(grads) - return l, n - - # dataset - train_set = bd.vision.FashionMNIST(root=args.data_dir, split='train', download=True) - test_set = bd.vision.FashionMNIST(root=args.data_dir, split='test', download=True) - x_train = jnp.asarray(train_set.data / 255, dtype=bm.float_).reshape((-1, 28, 28, 1)) - y_train = jnp.asarray(train_set.targets, dtype=bm.int_) - x_test = jnp.asarray(test_set.data / 255, dtype=bm.float_).reshape((-1, 28, 28, 1)) - y_test = jnp.asarray(test_set.targets, dtype=bm.int_) - - os.makedirs(out_dir, exist_ok=True) - with open(os.path.join(out_dir, 'args.txt'), 'w', encoding='utf-8') as args_txt: - args_txt.write(str(args)) - args_txt.write('\n') - args_txt.write(' 
'.join(sys.argv)) - - max_test_acc = -1 - for epoch_i in range(0, args.n_epoch): - start_time = time.time() - loss, train_acc = [], 0. - for i in range(0, x_train.shape[0], args.batch): - xs = x_train[i: i + args.batch] - ys = y_train[i: i + args.batch] - l, n = train_fun(xs, ys) - loss.append(l) - train_acc += n - train_acc /= x_train.shape[0] - train_loss = jnp.mean(jnp.asarray(loss)) - optimizer.lr.step_epoch() - - loss, test_acc = [], 0. - for i in range(0, x_test.shape[0], args.batch): - xs = x_test[i: i + args.batch] - ys = y_test[i: i + args.batch] - l, n = predict_loss_fun(xs, ys) - loss.append(l) - test_acc += n - test_acc /= x_test.shape[0] - test_loss = jnp.mean(jnp.asarray(loss)) - - t = (time.time() - start_time) / 60 - print(f'epoch {epoch_i}, used {t:.3f} min, ' - f'train_loss = {train_loss:.4f}, train_acc = {train_acc:.4f}, ' - f'test_loss = {test_loss:.4f}, test_acc = {test_acc:.4f}') - - if max_test_acc < test_acc: - max_test_acc = test_acc - states = { - 'net': net.state_dict(), - 'optimizer': optimizer.state_dict(), - 'epoch_i': epoch_i, - 'train_acc': train_acc, - 'test_acc': test_acc, - } - bp.checkpoints.save(out_dir, states, epoch_i) - - # inference - state_dict = bp.checkpoints.load(out_dir) - net.load_state_dict(state_dict['net']) - correct_num = 0 - for i in range(0, x_test.shape[0], 512): - xs = x_test[i: i + 512] - ys = y_test[i: i + 512] - correct_num += predict_loss_fun(xs, ys)[1] - print('Max test accuracy: ', correct_num / x_test.shape[0]) - - -if __name__ == '__main__': - main() diff --git a/examples/training_snn_models/mnist_lif_readout.py b/examples/training_snn_models/mnist_lif_readout.py deleted file mode 100644 index 5a6580453..000000000 --- a/examples/training_snn_models/mnist_lif_readout.py +++ /dev/null @@ -1,155 +0,0 @@ -# -*- coding: utf-8 -*- - -import time -import argparse -import os.path -import sys - -import brainpy_datasets as bd - -import jax.numpy as jnp - -import brainpy as bp -import brainpy.math as bm - -parser = argparse.ArgumentParser(description='LIF MNIST Training') -parser.add_argument('-T', default=100, type=int, help='simulating time-steps') -parser.add_argument('-platform', default='cpu', help='device') -parser.add_argument('-batch', default=64, type=int, help='batch size') -parser.add_argument('-epochs', default=15, type=int, metavar='N', - help='number of total epochs to run') -parser.add_argument('-out-dir', type=str, default='./logs', help='root dir for saving logs and checkpoint') -parser.add_argument('-lr', default=1e-3, type=float, help='learning rate') -parser.add_argument('-tau', default=2.0, type=float, help='parameter tau of LIF neuron') - -args = parser.parse_args() -print(args) - -out_dir = os.path.join(args.out_dir, f'T{args.T}_b{args.batch}_lr{args.lr}') -if not os.path.exists(out_dir): - os.makedirs(out_dir) -with open(os.path.join(out_dir, 'args.txt'), 'w', encoding='utf-8') as args_txt: - args_txt.write(str(args)) - args_txt.write('\n') - args_txt.write(' '.join(sys.argv)) - -bm.set_platform(args.platform) -bm.set_environment(mode=bm.training_mode, dt=1.) 
- - -class SNN(bp.DynamicalSystem): - def __init__(self, tau): - super().__init__() - - self.layer = bp.Sequential( - bp.layers.Dense(28 * 28, 10, b_initializer=None), - bp.neurons.LIF(10, V_rest=0., V_reset=0., V_th=1., tau=tau, spike_fun=bm.surrogate.arctan), - ) - - def update(self, p, x): - self.layer(p, x) - return self.layer[-1].spike.value - - -net = SNN(args.tau) - -# data -train_data = bd.vision.MNIST(r'D:/data', split='train', download=True) -test_data = bd.vision.MNIST(r'D:/data', split='test', download=True) -x_train = bm.asarray(train_data.data / 255, dtype=bm.float_).reshape(-1, 28 * 28) -y_train = bm.asarray(train_data.targets, dtype=bm.int_) -x_test = bm.asarray(test_data.data / 255, dtype=bm.float_).reshape(-1, 28 * 28) -y_test = bm.asarray(test_data.targets, dtype=bm.int_) - -# loss -encoder = bp.encoding.PoissonEncoder(min_val=0., max_val=1.) - - -@bm.to_object(child_objs=(net, encoder)) -def loss_fun(xs, ys): - net.reset_state(batch_size=xs.shape[0]) - xs = encoder(xs, num_step=args.T) - # shared arguments for looping over time - shared = bm.shared_args_over_time(num_step=args.T) - outs = bm.for_loop(net, (shared, xs)) - out_fr = jnp.mean(outs, axis=0) - ys_onehot = bm.one_hot(ys, 10, dtype=bm.float_) - l = bp.losses.mean_squared_error(out_fr, ys_onehot) - n = jnp.sum(out_fr.argmax(1) == ys) - return l, n - - -# gradient -grad_fun = bm.grad(loss_fun, grad_vars=net.train_vars().unique(), has_aux=True, return_value=True) - -# optimizer -optimizer = bp.optim.Adam(lr=args.lr, train_vars=net.train_vars().unique()) - - -# train -@bm.jit -@bm.to_object(child_objs=(grad_fun, optimizer)) -def train(xs, ys): - grads, l, n = grad_fun(xs, ys) - optimizer.update(grads) - return l, n - - -max_test_acc = 0. - -# computing -for epoch_i in range(args.epochs): - bm.random.shuffle(x_train, key=123) - bm.random.shuffle(y_train, key=123) - - t0 = time.time() - loss, train_acc = [], 0. - for i in range(0, x_train.shape[0], args.batch): - X = x_train[i: i + args.batch] - Y = y_train[i: i + args.batch] - l, correct_num = train(X, Y) - loss.append(l) - train_acc += correct_num - train_acc /= x_train.shape[0] - train_loss = jnp.mean(jnp.asarray(loss)) - optimizer.lr.step_epoch() - - loss, test_acc = [], 0. 
- for i in range(0, x_test.shape[0], args.batch): - X = x_test[i: i + args.batch] - Y = y_test[i: i + args.batch] - l, correct_num = loss_fun(X, Y) - loss.append(l) - test_acc += correct_num - test_acc /= x_test.shape[0] - test_loss = jnp.mean(jnp.asarray(loss)) - - t = (time.time() - t0) / 60 - print(f'epoch {epoch_i}, used {t:.3f} min, ' - f'train_loss = {train_loss:.4f}, train_acc = {train_acc:.4f}, ' - f'test_loss = {test_loss:.4f}, test_acc = {test_acc:.4f}') - - if max_test_acc < test_acc: - max_test_acc = test_acc - states = { - 'net': net.state_dict(), - 'optimizer': optimizer.state_dict(), - 'epoch_i': epoch_i, - 'train_acc': train_acc, - 'test_acc': test_acc, - } - bp.checkpoints.save(out_dir, states, epoch_i) - -# inference -state_dict = bp.checkpoints.load(out_dir) -net.load_state_dict(state_dict['net']) - -runner = bp.DSRunner(net, data_first_axis='T') -correct_num = 0 -for i in range(0, x_test.shape[0], 512): - X = encoder(x_test[i: i + 512], num_step=args.T) - Y = y_test[i: i + 512] - out_fr = jnp.mean(runner.predict(inputs=X, reset_state=True), axis=0) - correct_num += jnp.sum(out_fr.argmax(1) == Y) - -print('Max test accuracy: ', correct_num / x_test.shape[0]) diff --git a/examples/training_snn_models/readme.md b/examples/training_snn_models/readme.md new file mode 100644 index 000000000..23f6487a6 --- /dev/null +++ b/examples/training_snn_models/readme.md @@ -0,0 +1,3 @@ + +See [brainpy-example](https://github.com/brainpy/examples/blob/main/brain_inspired_computing) for more examples for brain-inspired computing models. + From 3f84f3c95f25fb43bce886e02af46f4b01966c4c Mon Sep 17 00:00:00 2001 From: chaoming Date: Sun, 15 Jan 2023 11:11:20 +0800 Subject: [PATCH 13/13] update CI --- .github/workflows/Sync_branches.yml | 4 +- changes.md | 462 +--------------------------- 2 files changed, 17 insertions(+), 449 deletions(-) diff --git a/.github/workflows/Sync_branches.yml b/.github/workflows/Sync_branches.yml index 00ff74b68..4a4192425 100644 --- a/.github/workflows/Sync_branches.yml +++ b/.github/workflows/Sync_branches.yml @@ -9,10 +9,10 @@ jobs: steps: - uses: actions/checkout@master - - name: Merge master -> brainpy-2.2.x + - name: Merge master -> brainpy-2.3.x uses: devmasx/merge-branch@master with: type: now from_branch: master - target_branch: brainpy-2.2.x + target_branch: brainpy-2.3.x github_token: ${{ github.token }} \ No newline at end of file diff --git a/changes.md b/changes.md index 9c2b1211c..d072e7470 100644 --- a/changes.md +++ b/changes.md @@ -1,462 +1,30 @@ -# Change from Version 2.3.0 to Version 2.3.1 - - - -This release (under the release branch of ``brainpy=2.3.x``) continues to add supports for brain-inspired computation. - - - -```python -import brainpy as bp -import brainpy.math as bm -``` - - - -## Backwards Incompatible Changes - - - -#### 1. Error: module 'brainpy' has no attribute 'datasets' - -``brainpy.datasets`` module is now published as an independent package ``brainpy_datasets``. 
- -Please change your dataset access from - -```python -bp.datasets.xxxxx -``` - -to - -```python -import brainpy_datasets as bp_data - -bp_data.chaos.XXX -bp_data.vision.XXX -``` - -For a chaotic data series, - -```python -# old version -data = bp.datasets.double_scroll_series(t_warmup + t_train + t_test, dt=dt) -x_var = data['x'] -y_var = data['y'] -z_var = data['z'] - -# new version -data = bd.chaos.DoubleScrollEq(t_warmup + t_train + t_test, dt=dt) -x_var = data.xs -y_var = data.ys -z_var = data.zs -``` - -For a vision dataset, - -```python -# old version -dataset = bp.datasets.FashionMNIST(root, train=True, download=True) - -# new version -dataset = bd.vision.FashionMNIST(root, split='train', download=True) -``` - - - -#### 2. Error: DSTrainer must receive an instance with BatchingMode - -This error will happen when using ``brainpy.OnlineTrainer`` , ``brainpy.OfflineTrainer``, ``brainpy.BPTT`` , ``brainpy.BPFF``. - -From version 2.3.1, BrainPy explicitly consider the computing mode of each model. For trainers, all training target should be a model with ``BatchingMode`` or ``TrainingMode``. - -If you are training model with ``OnlineTrainer`` or ``OfflineTrainer``, - -```python -# old version -class NGRC(bp.DynamicalSystem): - def __init__(self, num_in): - super(NGRC, self).__init__() - self.r = bp.layers.NVAR(num_in, delay=2, order=3) - self.di = bp.layers.Dense(self.r.num_out, num_in) - - def update(self, sha, x): - di = self.di(sha, self.r(sha, x)) - return x + di - - -# new version -bm.set_enviroment(mode=bm.batching_mode) - -class NGRC(bp.DynamicalSystem): - def __init__(self, num_in): - super(NGRC, self).__init__() - self.r = bp.layers.NVAR(num_in, delay=2, order=3) - self.di = bp.layers.Dense(self.r.num_out, num_in, mode=bm.training_mode) - - def update(self, sha, x): - di = self.di(sha, self.r(sha, x)) - return x + di -``` - - If you are training models with ``BPTrainer``, adding the following line at the top of the script, - -```python -bm.set_enviroment(mode=bm.training_mode) -``` - - - -#### 3. Error: inputs_are_batching is no longer supported. - -This is because if the training target is in ``batching`` mode, this has already indicated that the inputs should be batching. - -Simple remove the ``inputs_are_batching`` from your functional call of ``.predict()`` will solve the issue. +# Change from Version 2.3.1 to Version 2.3.2 +This release (under the branch of ``brainpy=2.3.x``) continues to add supports for brain-inspired computation. ## New Features +### 1. New package structure for stable API release -### 1. ``brainpy.math`` module upgrade - -#### ``brainpy.math.surrogate`` module for surrogate gradient functions. 
- -Currently, we support - -- `brainpy.math.surrogate.arctan` -- `brainpy.math.surrogate.erf` -- `brainpy.math.surrogate.gaussian_grad` -- `brainpy.math.surrogate.inv_square_grad` -- `brainpy.math.surrogate.leaky_relu` -- `brainpy.math.surrogate.log_tailed_relu` -- `brainpy.math.surrogate.multi_gaussian_grad` -- `brainpy.math.surrogate.nonzero_sign_log` -- `brainpy.math.surrogate.one_input` -- `brainpy.math.surrogate.piecewise_exp` -- `brainpy.math.surrogate.piecewise_leaky_relu` -- `brainpy.math.surrogate.piecewise_quadratic` -- `brainpy.math.surrogate.q_pseudo_spike` -- `brainpy.math.surrogate.relu_grad` -- `brainpy.math.surrogate.s2nn` -- `brainpy.math.surrogate.sigmoid` -- `brainpy.math.surrogate.slayer_grad` -- `brainpy.math.surrogate.soft_sign` -- `brainpy.math.surrogate.squarewave_fourier_series` - - - -#### New transformation function ``brainpy.math.to_dynsys`` - -New transformation function ``brainpy.math.to_dynsys`` supports to transform a pure Python function into a ``DynamicalSystem``. This will be useful when running a `DynamicalSystem` with arbitrary customized inputs. - -```python -import brainpy.math as bm - -hh = bp.neurons.HH(1) - -@bm.to_dynsys(child_objs=hh) -def run_hh(tdi, x=None): - if x is not None: - hh.input += x - -runner = bp.DSRunner(run_hhh, monitors={'v': hh.V}) -runner.run(inputs=bm.random.uniform(3, 6, 1000)) -``` - - - -#### Default data types - -Default data types `brainpy.math.int_`, `brainpy.math.float_` and `brainpy.math.complex_` are initialized according to the default `x64` settings. Then, these data types can be set or get by `brainpy.math.set_*` or `brainpy.math.get_*` syntaxes. - -Take default integer type ``int_`` as an example, - -```python -# set the default integer type -bm.set_int_(jax.numpy.int64) - -# get the default integer type -a1 = bm.asarray([1], dtype=bm.int_) -a2 = bm.asarray([1], dtype=bm.get_int()) # equivalent -``` - -Default data types are changed according to the `x64` setting of JAX. For instance, - -```python -bm.enable_x64() -assert bm.int_ == jax.numpy.int64 -bm.disable_x64() -assert bm.int_ == jax.numpy.int32 -``` - -``brainpy.math.float_`` and ``brainpy.math.complex_`` behaves similarly with ``brainpy.math.int_``. - - - -#### Environment context manager - -This release introduces a new concept ``computing environment`` in BrainPy. Computing environment is a default setting for current computation jobs, including the default data type (``int_``, ``float_``, ``complex_``), the default numerical integration precision (``dt``), the default computing mode (``mode``). All models, arrays, and computations using the default setting will be carried out under the environment setting. 
- -Users can set a default environment through - -```python -brainpy.math.set_environment(mode, dt, x64) -``` - -However, ones can also construct models or perform computation through a temporal environment context manager, this can be implemented through: - -```python -# constructing a HH model with dt=0.1 and x64 precision -with bm.environment(mode, dt=0.1, x64=True): - hh1 = bp.neurons.HH(1) - -# constructing a HH model with dt=0.05 and x32 precision -with bm.environment(mode, dt=0.05, x64=False): - hh2 = bp.neuron.HH(1) -``` - -Usually, users construct models for either brain-inspired computing (``training mode``) or brain simulation (``nonbatching mode``), therefore, there are shortcut context manager for setting a training environment or batching environment: - -```python -with bm.training_environment(dt, x64): - pass - -with bm.batching_environment(dt, x64): - pass -``` - - - -### 2. ``brainpy.dyn`` module - - - -#### ``brainpy.dyn.transfom`` module for transforming a ``DynamicalSystem`` instance to a callable ``BrainPyObject``. - -Specifically, we provide - -- `LoopOverTime` for unrolling a dynamical system over time. -- `NoSharedArg` for removing the dependency of shared arguments. - - - - - -### 3. Running supports in BrainPy - - - -#### All ``brainpy.Runner`` now are subclasses of ``BrainPyObject`` - -This means that all ``brainpy.Runner`` can be used as a part of the high-level program or transformation. - - - -#### Enable the continuous running of a differential equation (ODE, SDE, FDE, DDE, etc.) with `IntegratorRunner`. - -For example, - -```python -import brainpy as bp - -# differential equation -a, b, tau = 0.7, 0.8, 12.5 -dV = lambda V, t, w, Iext: V - V * V * V / 3 - w + Iext -dw = lambda w, t, V: (V + a - b * w) / tau -fhn = bp.odeint(bp.JointEq([dV, dw]), method='rk4', dt=0.1) - -# differential integrator runner -runner = bp.IntegratorRunner(fhn, monitors=['V', 'w'], inits=[1., 1.]) - -# run 1 -Iext, duration = bp.inputs.section_input([0., 1., 0.5], [200, 200, 200], return_length=True) -runner.run(duration, dyn_args=dict(Iext=Iext)) -bp.visualize.line_plot(runner.mon.ts, runner.mon['V'], legend='V') - -# run 2 -Iext, duration = bp.inputs.section_input([0.5], [200], return_length=True) -runner.run(duration, dyn_args=dict(Iext=Iext)) -bp.visualize.line_plot(runner.mon.ts, runner.mon['V'], legend='V-run2', show=True) - -``` - - - -#### Enable call a customized function during fitting of ``brainpy.BPTrainer``. - -This customized function (provided through ``fun_after_report``) will be useful to save a checkpoint during the training. For instance, - -```python -class CheckPoint: - def __init__(self, path='path/to/directory/'): - self.max_acc = 0. - self.path = path - - def __call__(self, idx, metrics, phase): - if phase == 'test' and metrics['acc'] > self.max_acc: - self.max_acc = matrics['acc'] - bp.checkpoints.save(self.path, net.state_dict(), idx) - -trainer = bp.BPTT() -trainer.fit(..., fun_after_report=CheckPoint()) -``` - - - -#### Enable data with ``data_first_axis`` format when predicting or fitting in a ``brainpy.DSRunner`` and ``brainpy.DSTrainer``. - -Previous version of BrainPy only supports data with the batch dimension at the first axis. Currently, ``brainpy.DSRunner`` and ``brainpy.DSTrainer`` can support the data with the time dimension at the first axis. This can be set through ``data_first_axis='T'`` when initializing a runner or trainer. 
- -```python -runner = bp.DSRunner(..., data_first_axis='T') -trainer = bp.DSTrainer(..., data_first_axis='T') -``` - - - -### 4. Utility in BrainPy - - - -#### ``brainpy.encoding`` module for encoding rate values into spike trains - - Currently, we support - -- `brainpy.encoding.LatencyEncoder` -- `brainpy.encoding.PoissonEncoder` -- `brainpy.encoding.WeightedPhaseEncoder` - - - -#### ``brainpy.checkpoints`` module for model state serialization. - -This version of BrainPy supports to save a checkpoint of the model into the physical disk. Inspired from the Flax API, we provide the following checkpoint APIs: - -- ``brainpy.checkpoints.save()`` for saving a checkpoint of the model. -- ``brainpy.checkpoints.multiprocess_save()`` for saving a checkpoint of the model in multi-process environment. -- ``brainpy.checkpoints.load()`` for loading the last or best checkpoint from the given checkpoint path. -- ``brainpy.checkpoints.load_latest()`` for retrieval the path of the latest checkpoint in a directory. - - - - - -## Deprecations - - - -### 1. Deprecations in the running supports of BrainPy - -#### ``func_monitors`` is no longer supported in all ``brainpy.Runner`` subclasses. - -We will remove its supports since version 2.4.0. Instead, monitoring with a dict of callable functions can be set in ``monitors``. For example, - - - ```python - # old version - - runner = bp.DSRunner(model, - monitors={'sps': model.spike, 'vs': model.V}, - func_monitors={'sp10': model.spike[10]}) - ``` - - ```python - # new version - runner = bp.DSRunner(model, - monitors={'sps': model.spike, - 'vs': model.V, - 'sp10': model.spike[10]}) - ``` - - - -#### ``func_inputs`` is no longer supported in all ``brainpy.Runner`` subclasses. - - Instead, giving inputs with a callable function should be done with ``inputs``. - -```python -# old version - -net = EINet() - -def f_input(tdi): - net.E.input += 10. - -runner = bp.DSRunner(net, fun_inputs=f_input, inputs=('I.input', 10.)) -``` - -```python -# new version - -def f_input(tdi): - net.E.input += 10. - net.I.input += 10. -runner = bp.DSRunner(net, inputs=f_input) -``` - - - -#### ``inputs_are_batching`` is deprecated. - -``inputs_are_batching`` is deprecated in ``predict()``/``.run()`` of all ``brainpy.Runner`` subclasses. - - - -#### ``args`` and ``dyn_args`` are now deprecated in ``IntegratorRunner``. - -Instead, users should specify ``args`` and ``dyn_args`` when using ``IntegratorRunner.run()`` function. - -```python -dV = lambda V, t, w, I: V - V * V * V / 3 - w + I -dw = lambda w, t, V, a, b: (V + a - b * w) / 12.5 -integral = bp.odeint(bp.JointEq([dV, dw]), method='exp_auto') - -# old version -runner = bp.IntegratorRunner( - integral, - monitors=['V', 'w'], - inits={'V': bm.random.rand(10), 'w': bm.random.normal(size=10)}, - args={'a': 1., 'b': 1.}, # CHANGE - dyn_args={'I': bp.inputs.ramp_input(0, 4, 100)}, # CHANGE -) -runner.run(100.,) - -``` - -```python -# new version -runner = bp.IntegratorRunner( - integral, - monitors=['V', 'w'], - inits={'V': bm.random.rand(10), 'w': bm.random.normal(size=10)}, -) -runner.run(100., - args={'a': 1., 'b': 1.}, - dyn_args={'I': bp.inputs.ramp_input(0, 4, 100)}) -``` - - - -### 2. Deprecations in ``brainpy.math`` module - -#### `ditype()` and `dftype()` are deprecated. +Unstable APIs are all hosted in ``brainpy._src`` module. +Other APIs are stable, and will be maintained in a long time. -`brainpy.math.ditype()` and `brainpy.math.dftype()` are deprecated. Using `brainpy.math.int_` and `brainpy.math.float()` instead. +### 2. 
New schedulers

- `brainpy.optim.CosineAnnealingWarmRestarts`
- `brainpy.optim.CosineAnnealingLR`
- `brainpy.optim.ExponentialLR`
- `brainpy.optim.MultiStepLR`
- `brainpy.optim.StepLR`

A minimal usage sketch for these schedulers is given below.


### 3. Others

- support `static_argnums` in `brainpy.math.jit` (see the sketch below)
- fix bugs in `reset_state()` and `clear_input()` of `brainpy.channels`
- fix JIT error checking
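
Below is a minimal usage sketch for the new schedulers. The `MultiStepLR(lr, milestones, gamma)` signature is an assumption (it mirrors the PyTorch-style scheduler convention and may differ in the released API); passing the scheduler as the optimizer's `lr` and calling `optimizer.lr.step_epoch()` once per epoch follows the existing SNN training examples.

```python
import brainpy as bp
import brainpy.math as bm

bm.set_environment(mode=bm.training_mode)

# any trainable model works here; a single Dense layer keeps the sketch small
net = bp.layers.Dense(28 * 28, 10)

# hypothetical arguments: (lr, milestones, gamma) is assumed to follow the
# PyTorch-style convention and may differ in the released API
lr = bp.optim.MultiStepLR(0.1, milestones=[10, 20], gamma=0.5)
optimizer = bp.optim.Adam(lr=lr, train_vars=net.train_vars().unique())

for epoch_i in range(30):
    # ... compute gradients and call optimizer.update(grads) here ...
    optimizer.lr.step_epoch()  # advance the learning-rate schedule once per epoch
```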
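
The `static_argnums` argument of `brainpy.math.jit` is assumed to behave like the corresponding `jax.jit` argument, i.e. it marks positional arguments that are treated as compile-time constants. A small sketch:

```python
import brainpy.math as bm

def integrate(n_steps, x):
    # the Python loop length depends on n_steps, so it must be a
    # compile-time constant under JIT compilation
    for _ in range(n_steps):
        x = x + 0.1 * bm.exp(-x)
    return x

# mark the 0-th argument as static, mirroring jax.jit;
# the function is re-compiled whenever n_steps changes
jit_integrate = bm.jit(integrate, static_argnums=0)

print(jit_integrate(10, bm.ones(3)))
```

Marking loop lengths or structural flags as static avoids tracer errors inside the function, at the cost of re-compilation when their values change.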