diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md deleted file mode 100644 index 3e0bb71a2..000000000 --- a/.github/ISSUE_TEMPLATE/feature_request.md +++ /dev/null @@ -1,10 +0,0 @@ ---- -name: 'Feature Request' -about: 'Suggest a new idea or improvement for Brainpy' -labels: 'enhancement' ---- - -Please: - -- [ ] Check for duplicate requests. -- [ ] Describe your goal, and if possible provide a code snippet with a motivating example. \ No newline at end of file diff --git a/.github/workflows/Linux_CI.yml b/.github/workflows/Linux_CI.yml index 3feb46635..72cce8617 100644 --- a/.github/workflows/Linux_CI.yml +++ b/.github/workflows/Linux_CI.yml @@ -36,7 +36,7 @@ jobs: # stop the build if there are Python syntax errors or undefined names flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics +# flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics - name: Test with pytest run: | pytest brainpy/ diff --git a/brainpy/__init__.py b/brainpy/__init__.py index c90954f5c..28221aaae 100644 --- a/brainpy/__init__.py +++ b/brainpy/__init__.py @@ -29,7 +29,13 @@ # toolboxes -from . import connect, initialize, optimizers, measure, losses, datasets, inputs +from . import (connect, # synaptic connection + initialize, # weight initialization + optimizers, # gradient descent optimizers + losses, # loss functions + measure, # methods for data analysis + datasets, # methods for generating data + inputs) # methods for generating input currents # numerical integrators @@ -45,6 +51,14 @@ # dynamics simulation from . import dyn +from .dyn import (channels, # channel models + layers, # ANN layers + networks, # network models + neurons, # neuron groups + rates, # rate models + synapses, # synaptic dynamics + synouts, # synaptic output + synplast) # synaptic plasticity # dynamics training @@ -63,10 +77,6 @@ from .visualization import visualize -# compatible interface -from .compat import * # compat - - # convenient access conn = connect init = initialize diff --git a/brainpy/train/algorithms/__init__.py b/brainpy/algorithms/__init__.py similarity index 77% rename from brainpy/train/algorithms/__init__.py rename to brainpy/algorithms/__init__.py index 00215dc48..fd8341d6e 100644 --- a/brainpy/train/algorithms/__init__.py +++ b/brainpy/algorithms/__init__.py @@ -2,3 +2,4 @@ from .offline import * from .online import * +from . import utils diff --git a/brainpy/algorithms/offline.py b/brainpy/algorithms/offline.py new file mode 100644 index 000000000..d85c382b2 --- /dev/null +++ b/brainpy/algorithms/offline.py @@ -0,0 +1,536 @@ +# -*- coding: utf-8 -*- + +import warnings + +import numpy as np +from jax.lax import while_loop + +import brainpy.math as bm +from brainpy.base import Base +from brainpy.types import Tensor +from .utils import (Sigmoid, + Regularization, L1Regularization, L1L2Regularization, L2Regularization, + polynomial_features, normalize) + +__all__ = [ + # base class for offline training algorithm + 'OfflineAlgorithm', + + # training methods + 'LinearRegression', + 'RidgeRegression', + 'LassoRegression', + 'LogisticRegression', + 'PolynomialRegression', + 'PolynomialRidgeRegression', + 'ElasticNetRegression', + + # general supports + 'get_supported_offline_methods', + 'register_offline_method', +] + +name2func = dict() + + +class OfflineAlgorithm(Base): + """Base class for offline training algorithm.""" + + def __init__(self, name=None): + super(OfflineAlgorithm, self).__init__(name=name) + + def __call__(self, targets, inputs, outputs) -> Tensor: + """The training procedure. + + Parameters + ---------- + inputs: JaxArray, jax.numpy.ndarray, numpy.ndarray + The 3d input data with the shape of `(num_batch, num_time, num_input)`, + or, the 2d input data with the shape of `(num_time, num_input)`. + + targets: JaxArray, jax.numpy.ndarray, numpy.ndarray + The 3d target data with the shape of `(num_batch, num_time, num_output)`, + or the 2d target data with the shape of `(num_time, num_output)`. + + outputs: JaxArray, jax.numpy.ndarray, numpy.ndarray + The 3d output data with the shape of `(num_batch, num_time, num_output)`, + or the 2d output data with the shape of `(num_time, num_output)`. + + Returns + ------- + weight: JaxArray + The weights after fit. + """ + raise NotImplementedError('Must implement the __call__ function by the subclass itself.') + + def __repr__(self): + return self.__class__.__name__ + + def initialize(self, identifier, *args, **kwargs): + raise NotImplementedError('Must implement the initialize() ' + 'function by the subclass itself.') + + +def _check_data_2d_atls(x): + if x.ndim < 2: + raise ValueError(f'Data must be a 2d tensor. But we got {x.ndim}d: {x.shape}.') + if x.ndim != 2: + return x.reshape((-1, x.shape[-1])) + else: + return x + + +class RegressionAlgorithm(OfflineAlgorithm): + """ Base regression model. Models the relationship between a scalar dependent variable y and the independent + variables X. + + Parameters + ---------- + max_iter: int + The number of training iterations the algorithm will tune the weights for. + learning_rate: float + The step length that will be used when updating the weights. + """ + + def __init__( + self, + max_iter: int = None, + learning_rate: float = None, + regularizer: Regularization = None, + name: str = None + ): + super(RegressionAlgorithm, self).__init__(name=name) + self.max_iter = max_iter + self.learning_rate = learning_rate + self.regularizer = regularizer + + def initialize(self, identifier, *args, **kwargs): + pass + + def init_weights(self, n_features): + """ Initialize weights randomly [-1/N, 1/N] """ + limit = 1 / np.sqrt(n_features) + return bm.random.uniform(-limit, limit, (n_features,)) + + def gradient_descent_solve(self, targets, inputs, outputs=None): + # checking + inputs = _check_data_2d_atls(bm.asarray(inputs)) + targets = _check_data_2d_atls(bm.asarray(targets)) + + # initialize weights + w = self.init_weights(n_features=inputs.shape[1]) + + def cond_fun(a): + i, par_old, par_new = a + return bm.logical_and(bm.logical_not(bm.allclose(par_old, par_new)), + i < self.max_iter).value + + def body_fun(a): + i, par_old, par_new = a + # Gradient of regularization loss w.r.t w + y_pred = inputs.dot(w) + grad_w = -(targets - y_pred).dot(inputs) + self.regularizer.grad(par_new) + # Update the weights + par_new2 = par_new - self.learning_rate * grad_w + return i + 1, par_new, par_new2 + + # Tune parameters for n iterations + r = while_loop(cond_fun, body_fun, (0, w, w + 1.)) + return r[-1] + + def predict(self, W, X): + return X.dot(W) + + +class LinearRegression(RegressionAlgorithm): + """Training algorithm of least-square regression. + + Parameters + ---------- + name: str + The name of the algorithm. + """ + + def __init__( + self, + name: str = None, + + # parameters for using gradient descent + max_iter: int = 1000, + learning_rate: float = 0.001, + gradient_descent: bool = False, + ): + super(LinearRegression, self).__init__(name=name, + max_iter=max_iter, + learning_rate=learning_rate, + regularizer=Regularization(0.)) + self.gradient_descent = gradient_descent + + def __call__(self, targets, inputs, outputs=None): + # checking + inputs = _check_data_2d_atls(bm.asarray(inputs)) + targets = _check_data_2d_atls(bm.asarray(targets)) + + # solving + if self.gradient_descent: + return self.gradient_descent_solve(targets, inputs) + else: + weights = bm.linalg.lstsq(inputs, targets) + return weights[0] + + +name2func['linear'] = LinearRegression +name2func['lstsq'] = LinearRegression + + +class RidgeRegression(RegressionAlgorithm): + """Training algorithm of ridge regression. + + Parameters + ---------- + alpha: float + The regularization coefficient. + + .. versionadded:: 2.2.0 + + beta: float + The regularization coefficient. + + .. deprecated:: 2.2.0 + Please use `alpha` to set regularization factor. + + name: str + The name of the algorithm. + """ + + def __init__( + self, + alpha: float = 1e-7, + beta: float = None, + name: str = None, + + # parameters for using gradient descent + max_iter: int = 1000, + learning_rate: float = 0.001, + gradient_descent: bool = False, + ): + if beta is not None: + warnings.warn(f"Please use 'alpha' to set regularization factor. " + f"'beta' has been deprecated since version 2.2.0.", + UserWarning) + alpha = beta + super(RidgeRegression, self).__init__(name=name, + max_iter=max_iter, + learning_rate=learning_rate, + regularizer=L2Regularization(alpha=alpha)) + self.gradient_descent = gradient_descent + + def __call__(self, targets, inputs, outputs=None): + # checking + inputs = _check_data_2d_atls(bm.asarray(inputs)) + targets = _check_data_2d_atls(bm.asarray(targets)) + + # solving + if self.gradient_descent: + return self.gradient_descent_solve(targets, inputs) + else: + temp = inputs.T @ inputs + if self.regularizer.alpha > 0.: + temp += self.regularizer.alpha * bm.eye(inputs.shape[-1]) + weights = bm.linalg.pinv(temp) @ (inputs.T @ targets) + return weights + + def __repr__(self): + return f'{self.__class__.__name__}(beta={self.regularizer.alpha})' + + +name2func['ridge'] = RidgeRegression + + +class LassoRegression(RegressionAlgorithm): + """Lasso regression method for offline training. + + Parameters + ---------- + alpha: float + Constant that multiplies the L1 term. Defaults to 1.0. + `alpha = 0` is equivalent to an ordinary least square. + max_iter: int + The maximum number of iterations. + degree: int + The degree of the polynomial that the independent variable X will be transformed to. + name: str + The name of the algorithm. + """ + + def __init__( + self, + alpha: float = 1.0, + degree: int = 2, + add_bias: bool = False, + name: str = None, + + # parameters for using gradient descent + max_iter: int = 1000, + learning_rate: float = 0.001, + gradient_descent: bool = True, + ): + super(LassoRegression, self).__init__(name=name, + max_iter=max_iter, + learning_rate=learning_rate, + regularizer=L1Regularization(alpha=alpha)) + self.gradient_descent = gradient_descent + self.add_bias = add_bias + assert gradient_descent + self.degree = degree + + def __call__(self, targets, inputs, outputs=None): + # checking + inputs = _check_data_2d_atls(bm.asarray(inputs)) + targets = _check_data_2d_atls(bm.asarray(targets)) + + # solving + inputs = normalize(polynomial_features(inputs, degree=self.degree, add_bias=self.add_bias)) + super(LassoRegression, self).gradient_descent_solve(targets, inputs) + + def predict(self, W, X): + X = _check_data_2d_atls(bm.asarray(X)) + X = normalize(polynomial_features(X, degree=self.degree, add_bias=self.add_bias)) + return super(LassoRegression, self).predict(W, X) + + +name2func['lasso'] = LassoRegression + + +class LogisticRegression(RegressionAlgorithm): + """Logistic regression method for offline training. + + Parameters + ---------- + learning_rate: float + The step length that will be taken when following the negative gradient during + training. + gradient_descent: boolean + True or false depending on if gradient descent should be used when training. If + false then we use batch optimization by least squares. + max_iter: int + The number of iteration to optimize the parameters. + name: str + The name of the algorithm. + """ + + def __init__( + self, + learning_rate: float = .1, + gradient_descent: bool = True, + max_iter: int = 4000, + name: str = None, + ): + super(LogisticRegression, self).__init__(name=name, + max_iter=max_iter, + learning_rate=learning_rate) + self.gradient_descent = gradient_descent + self.sigmoid = Sigmoid() + + def __call__(self, targets, inputs, outputs=None) -> Tensor: + # prepare data + inputs = _check_data_2d_atls(bm.asarray(inputs)) + targets = _check_data_2d_atls(bm.asarray(targets)) + if targets.shape[-1] != 1: + raise ValueError(f'Target must be a scalar, but got multiple variables: {targets.shape}. ') + targets = targets.flatten() + + # initialize parameters + param = self.init_weights(inputs.shape[1]) + + def cond_fun(a): + i, par_old, par_new = a + return bm.logical_and(bm.logical_not(bm.allclose(par_old, par_new)), + i < self.max_iter).value + + def body_fun(a): + i, par_old, par_new = a + # Make a new prediction + y_pred = self.sigmoid(inputs.dot(par_new)) + if self.gradient_descent: + # Move against the gradient of the loss function with + # respect to the parameters to minimize the loss + par_new2 = par_new - self.learning_rate * (y_pred - targets).dot(inputs) + else: + gradient = self.sigmoid.grad(inputs.dot(par_new)) + diag_grad = bm.zeros((gradient.size, gradient.size)) + diag = bm.arange(gradient.size) + diag_grad[diag, diag] = gradient + par_new2 = bm.linalg.pinv(inputs.T.dot(diag_grad).dot(inputs)).dot(inputs.T).dot( + diag_grad.dot(inputs).dot(par_new) + targets - y_pred) + return i + 1, par_new, par_new2 + + # Tune parameters for n iterations + r = while_loop(cond_fun, body_fun, (0, param+1., param)) + return r[-1] + + def predict(self, W, X): + return self.sigmoid(X @ W) + + +name2func['logistic'] = LogisticRegression + + +class PolynomialRegression(LinearRegression): + def __init__( + self, + degree: int = 2, + name: str = None, + add_bias: bool = False, + + # parameters for using gradient descent + max_iter: int = 1000, + learning_rate: float = 0.001, + gradient_descent: bool = True, + ): + super(PolynomialRegression, self).__init__(name=name, + max_iter=max_iter, + learning_rate=learning_rate, + gradient_descent=gradient_descent) + self.degree = degree + self.add_bias = add_bias + + def __call__(self, targets, inputs, outputs=None): + inputs = _check_data_2d_atls(bm.asarray(inputs)) + targets = _check_data_2d_atls(bm.asarray(targets)) + inputs = polynomial_features(inputs, degree=self.degree, add_bias=self.add_bias) + return super(PolynomialRegression, self).__call__(targets, inputs) + + def predict(self, W, X): + X = _check_data_2d_atls(bm.asarray(X)) + X = polynomial_features(X, degree=self.degree, add_bias=self.add_bias) + return super(PolynomialRegression, self).predict(W, X) + + +name2func['polynomial'] = PolynomialRegression + + +class PolynomialRidgeRegression(RidgeRegression): + def __init__( + self, + alpha: float = 1.0, + degree: int = 2, + name: str = None, + add_bias: bool = False, + + # parameters for using gradient descent + max_iter: int = 1000, + learning_rate: float = 0.001, + gradient_descent: bool = True, + ): + super(PolynomialRidgeRegression, self).__init__(alpha=alpha, + name=name, + max_iter=max_iter, + learning_rate=learning_rate, + gradient_descent=gradient_descent) + self.degree = degree + self.add_bias = add_bias + + def __call__(self, targets, inputs, outputs=None): + # checking + inputs = _check_data_2d_atls(bm.asarray(inputs)) + targets = _check_data_2d_atls(bm.asarray(targets)) + inputs = polynomial_features(inputs, degree=self.degree, add_bias=self.add_bias) + return super(PolynomialRidgeRegression, self).__call__(targets, inputs) + + def predict(self, W, X): + X = _check_data_2d_atls(bm.asarray(X)) + X = polynomial_features(X, degree=self.degree, add_bias=self.add_bias) + return super(PolynomialRidgeRegression, self).predict(W, X) + + +name2func['polynomial_ridge'] = PolynomialRidgeRegression + + +class ElasticNetRegression(RegressionAlgorithm): + """ + + Parameters: + ----------- + degree: int + The degree of the polynomial that the independent variable X will be transformed to. + reg_factor: float + The factor that will determine the amount of regularization and feature + shrinkage. + l1_ration: float + Weighs the contribution of l1 and l2 regularization. + n_iterations: float + The number of training iterations the algorithm will tune the weights for. + learning_rate: float + The step length that will be used when updating the weights. + """ + + def __init__( + self, + alpha: float = 1.0, + degree: int = 2, + l1_ratio: float = 0.5, + name: str = None, + add_bias: bool = False, + + # parameters for using gradient descent + max_iter: int = 1000, + learning_rate: float = 0.001, + gradient_descent: bool = True, + ): + super(ElasticNetRegression, self).__init__( + name=name, + max_iter=max_iter, + learning_rate=learning_rate, + regularizer=L1L2Regularization(alpha=alpha, l1_ratio=l1_ratio) + ) + self.degree = degree + self.add_bias = add_bias + self.gradient_descent = gradient_descent + assert gradient_descent + + def __call__(self, targets, inputs, outputs=None): + # checking + inputs = _check_data_2d_atls(bm.asarray(inputs)) + targets = _check_data_2d_atls(bm.asarray(targets)) + # solving + inputs = normalize(polynomial_features(inputs, degree=self.degree)) + super(ElasticNetRegression, self).gradient_descent_solve(targets, inputs) + + def predict(self, W, X): + X = _check_data_2d_atls(bm.asarray(X)) + X = normalize(polynomial_features(X, degree=self.degree, add_bias=self.add_bias)) + return super(ElasticNetRegression, self).predict(W, X) + + +name2func['elastic_net'] = ElasticNetRegression + + +def get_supported_offline_methods(): + """Get all supported offline training methods.""" + return tuple(name2func.keys()) + + +def register_offline_method(name: str, method: OfflineAlgorithm): + """Register a new offline learning method. + + Parameters + ---------- + name: str + The method name. + method: OfflineAlgorithm + The function method. + """ + if name in name2func: + raise ValueError(f'"{name}" has been registered in offline training methods.') + if not isinstance(method, OfflineAlgorithm): + raise ValueError(f'"method" must be an instance {OfflineAlgorithm.__name__}, but we got {type(method)}') + name2func[name] = method + + +def get(name: str) -> OfflineAlgorithm: + """Get the training function according to the training method name.""" + if name not in name2func: + raise ValueError(f'All offline methods are: {get_supported_offline_methods()}.\n' + f'But we got {name}.') + return name2func[name] diff --git a/brainpy/train/algorithms/online.py b/brainpy/algorithms/online.py similarity index 91% rename from brainpy/train/algorithms/online.py rename to brainpy/algorithms/online.py index 011e60a2c..a2a34da57 100644 --- a/brainpy/train/algorithms/online.py +++ b/brainpy/algorithms/online.py @@ -135,7 +135,7 @@ def get_supported_online_methods(): return tuple(name2func.keys()) -def register_online_method(name, method): +def register_online_method(name: str, method: OnlineAlgorithm): """Register a new oneline learning method. Parameters @@ -146,14 +146,13 @@ def register_online_method(name, method): The function method. """ if name in name2func: - raise ValueError(f'"{name}" has been registered in offline training methods.') - if not callable(method): - raise ValueError(f'"method" must be an instance of callable ' - f'function, but we got {type(method)}') + raise ValueError(f'"{name}" has been registered in online training methods. Please change another name.') + if not isinstance(method, OnlineAlgorithm): + raise ValueError(f'"method" must be an instance of {OnlineAlgorithm.__name__}, but we got {type(method)}') name2func[name] = method -def get(name): +def get(name: str): """Get the training function according to the training method name.""" if name not in name2func: raise ValueError(f'All online methods are: {get_supported_online_methods()}.\n' diff --git a/brainpy/algorithms/utils.py b/brainpy/algorithms/utils.py new file mode 100644 index 000000000..2828db854 --- /dev/null +++ b/brainpy/algorithms/utils.py @@ -0,0 +1,111 @@ +# -*- coding: utf-8 -*- + +import brainpy.math as bm + +from itertools import combinations_with_replacement + +__all__ = [ + 'Sigmoid', + 'Regularization', + 'L1Regularization', + 'L2Regularization', + 'L1L2Regularization', + + 'polynomial_features', + 'normalize', +] + + +class Sigmoid(object): + def __call__(self, x): + return 1 / (1 + bm.exp(-x)) + + def grad(self, x): + exp = bm.exp(-x) + return exp / (1 + exp) ** 2 + + +class Regularization(object): + def __init__(self, alpha): + self.alpha = alpha + + def __call__(self, x): + return 0 + + def grad(self, x): + return 0 + + +class L1Regularization(Regularization): + """L1 Regularization.""" + + def __init__(self, alpha): + super(L1Regularization, self).__init__(alpha=alpha) + + def __call__(self, w): + return self.alpha * bm.linalg.norm(w) + + def grad(self, w): + return self.alpha * bm.sign(w) + + +class L2Regularization(Regularization): + """L2 Regularization.""" + + def __init__(self, alpha): + super(L2Regularization, self).__init__(alpha=alpha) + + def __call__(self, w): + return self.alpha * 0.5 * w.T.dot(w) + + def grad(self, w): + return self.alpha * w + + +class L1L2Regularization(Regularization): + """L1 and L2 Regularization.""" + + def __init__(self, alpha, l1_ratio=0.5): + super(L1L2Regularization, self).__init__(alpha=alpha) + self.l1_ratio = l1_ratio + + def __call__(self, w): + l1_contr = self.l1_ratio * bm.linalg.norm(w) + l2_contr = (1 - self.l1_ratio) * 0.5 * w.T.dot(w) + return self.alpha * (l1_contr + l2_contr) + + def grad(self, w): + l1_contr = self.l1_ratio * bm.sign(w) + l2_contr = (1 - self.l1_ratio) * w + return self.alpha * (l1_contr + l2_contr) + + +def index_combinations(n_features, degree): + combs = [combinations_with_replacement(range(n_features), i) for i in range(2, degree + 1)] + flat_combs = [item for sublist in combs for item in sublist] + return flat_combs + + +def polynomial_features(X, degree: int, add_bias: bool = True): + n_samples, n_features = X.shape + combinations = index_combinations(n_features, degree) + if len(combinations) == 0: + return bm.insert(X, 0, 1, axis=1) if add_bias else X + if add_bias: + n_features += 1 + X_new = bm.zeros((n_samples, 1 + n_features + len(combinations))) + if add_bias: + X_new[:, 0] = 1 + X_new[:, 1:n_features] = X + else: + X_new[:, :n_features] = X + for i, index_combs in enumerate(combinations): + X_new[:, n_features + i] = bm.prod(X[:, index_combs], axis=1) + return X_new + + +def normalize(X, axis=-1, order=2): + """ Normalize the dataset X """ + l2 = bm.atleast_1d(bm.linalg.norm(X, order, axis)) + l2 = bm.where(l2 == 0, 1, l2) + return X / bm.expand_dims(l2, axis) diff --git a/brainpy/analysis/highdim/__init__.py b/brainpy/analysis/highdim/__init__.py index 0d082af2c..07787bb60 100644 --- a/brainpy/analysis/highdim/__init__.py +++ b/brainpy/analysis/highdim/__init__.py @@ -1,3 +1,3 @@ # -*- coding: utf-8 -*- -from .slow_points import * \ No newline at end of file +from .slow_points import * diff --git a/brainpy/analysis/highdim/slow_points.py b/brainpy/analysis/highdim/slow_points.py index 50bf8098b..98a73c4f1 100644 --- a/brainpy/analysis/highdim/slow_points.py +++ b/brainpy/analysis/highdim/slow_points.py @@ -1,8 +1,9 @@ # -*- coding: utf-8 -*- +import math import time import warnings -from typing import Callable, Union, Dict, Optional, Sequence +from typing import Callable, Union, Dict, Optional, Sequence, Tuple import jax.numpy as jnp import numpy as np @@ -18,6 +19,8 @@ from brainpy.dyn.runners import build_inputs, check_and_format_inputs from brainpy.errors import AnalyzerError, UnsupportedError from brainpy.types import Tensor +from brainpy.tools.others.dicts import DotDict + __all__ = [ 'SlowPointFinder', @@ -65,31 +68,41 @@ class SlowPointFinder(base.BrainPyAnalyzer): .. versionadded:: 2.2.0 t: float - The time to evaluate the fixed points. + Parameter for `f_cell` is instance of :py:class:`~.DynamicalSystem`. + The time to evaluate the fixed points. Default is 0. + + .. versionadded:: 2.2.0 + + dt: float + Parameter for `f_cell` is instance of :py:class:`~.DynamicalSystem`. + The numerical integration step, which can be used when . + The default is given by `brainpy.math.get_dt()`. .. versionadded:: 2.2.0 inputs: sequence + Parameter for `f_cell` is instance of :py:class:`~.DynamicalSystem`. Same as ``inputs`` in :py:class:`~.DSRunner`. .. versionadded:: 2.2.0 - excluded_vars: sequence - The excluded variables (can be a sequence of `Variable` instances), - when ``f_cell`` is an instance of :py:class:`~.DynamicalSystem`. + excluded_vars: sequence, dict + Parameter for `f_cell` is instance of :py:class:`~.DynamicalSystem`. + The excluded variables (can be a sequence of `Variable` instances). These variables will not be included for optimization of fixed points. .. versionadded:: 2.2.0 included_vars: dict - The target variables (can be a dict of `Variable` instances), - when ``f_cell`` is an instance of :py:class:`~.DynamicalSystem`. + Parameter for `f_cell` is instance of :py:class:`~.DynamicalSystem`. + The target variables (can be a dict of `Variable` instances). These variables will be included for optimization of fixed points. The candidate points later provided should have same keys as in ``included_vars``. .. versionadded:: 2.2.0 f_loss_batch : callable, function + Parameter for `f_cell` is instance of :py:class:`~.DynamicalSystem`. The function to compute the loss. .. deprecated:: 2.2.0 @@ -102,15 +115,24 @@ def __init__( f_cell: Union[Callable, DynamicalSystem], f_type: str = None, f_loss: Callable = None, - inputs: Sequence = None, - t: float = 0., verbose: bool = True, - f_loss_batch: Callable = None, + args: Tuple = (), + + # parameters for `f_cell` is DynamicalSystem instance + inputs: Sequence = None, + t: float = None, + dt: float = None, included_vars: Dict[str, bm.Variable] = None, - excluded_vars: Sequence[bm.Variable] = (), + excluded_vars: Union[Sequence[bm.Variable], Dict[str, bm.Variable]] = None, + + # deprecated + f_loss_batch: Callable = None, ): super(SlowPointFinder, self).__init__() + # static arguments + self.args = args + # update function if included_vars is None: self.included_vars = TensorCollector() @@ -118,6 +140,9 @@ def __init__( if not isinstance(included_vars, dict): raise TypeError(f'"included_vars" must be a dict but we got {type(included_vars)}') self.included_vars = TensorCollector(included_vars) + excluded_vars = () if excluded_vars is None else excluded_vars + if isinstance(excluded_vars, dict): + excluded_vars = tuple(excluded_vars.values()) if not isinstance(excluded_vars, (tuple, list)): raise TypeError(f'"excluded_vars" must be a sequence but we got {type(excluded_vars)}') for v in excluded_vars: @@ -125,14 +150,14 @@ def __init__( raise TypeError(f'"excluded_vars" must be a sequence of Variable, ' f'but we got {type(v)}') self.excluded_vars = {f'_exclude_v{i}': v for i, v in enumerate(excluded_vars)} - self.target = f_cell - if len(self.included_vars) > 0 and len(self.excluded_vars) > 0: - raise ValueError + raise ValueError('"included_vars" and "excluded_vars" cannot be provided simultaneously.') + self.target = f_cell if isinstance(f_cell, DynamicalSystem): # included variables all_vars = f_cell.vars(method='relative', level=-1, include_self=True).unique() + # exclude variables if len(self.included_vars) > 0: _all_ids = [id(v) for v in self.included_vars.values()] @@ -146,20 +171,31 @@ def __init__( for key, val in tuple(self.included_vars.items()): if id(val) in excluded_vars: self.included_vars.pop(key) + # input function if inputs is not None: inputs = check_and_format_inputs(host=self.target, inputs=inputs) - _input_step, _i = build_inputs(inputs) - if _i is not None: + _input_step, _has_iter = build_inputs(inputs) + if _has_iter: raise UnsupportedError(f'Do not support iterable inputs when using fixed point finder.') else: _input_step = None + + # check included variables + for var in self.included_vars.values(): + if var.batch_axis is not None: + if var.shape[var.batch_axis] != 1: + raise ValueError(f'Batched variables should has only one batch. ' + f'But we got {var.shape[var.batch_axis]}. Maybe ' + f'you need to call ".reset_state(batch_size=1)" ' + f'for your system.') + # update function self.f_cell = self._generate_ds_cell_function(self.target, self.included_vars, self.excluded_vars, - t, - _input_step) + t, dt, _input_step) + # check function type if f_type is not None: if f_type != constants.DISCRETE: @@ -167,12 +203,30 @@ def __init__( f'is instance of {DynamicalSystem.__name__}') f_type = constants.DISCRETE + # original data + self.included_data = {k: v.value for k, v in self.included_vars.items()} + self.excluded_data = {k: v.value for k, v in self.excluded_vars.items()} + elif callable(f_cell): - self.f_cell = f_cell + if len(self.args) > 0: + self.f_cell = lambda x: f_cell(x, *self.args) + else: + self.f_cell = f_cell if inputs is not None: raise UnsupportedError('Do not support "inputs" when "f_cell" is not instance of ' f'{DynamicalSystem.__name__}') - + if t is not None: + raise UnsupportedError('Do not support "t" when "f_cell" is not instance of ' + f'{DynamicalSystem.__name__}') + if dt is not None: + raise UnsupportedError('Do not support "dt" when "f_cell" is not instance of ' + f'{DynamicalSystem.__name__}') + if included_vars is not None: + raise UnsupportedError('Do not support "included_vars" when "f_cell" is not instance of ' + f'{DynamicalSystem.__name__}') + if len(excluded_vars) > 0: + raise UnsupportedError('Do not support "excluded_vars" when "f_cell" is not instance of ' + f'{DynamicalSystem.__name__}') else: raise ValueError(f'Unknown type of "f_type": {type(f_cell)}') if f_type not in [constants.DISCRETE, constants.CONTINUOUS]: @@ -189,35 +243,15 @@ def __init__( f_loss = losses.mean_squared_error if f_type == constants.DISCRETE else losses.mean_square self.f_loss = f_loss - # functions - self._opt_functions = dict() - - # functions - if self.f_type == constants.DISCRETE: - # evaluate losses of a batch of inputs - self.f_eval_loss = bm.jit(lambda h: self.f_loss(h, vmap(self.f_cell)(h), axis=1)) - else: - # evaluate losses of a batch of inputs - self.f_eval_loss = bm.jit(lambda h: self.f_loss(vmap(self.f_cell)(h), axis=1)) - # evaluate Jacobian matrix of a batch of inputs - - # if f_type == constants.DISCRETE: - # # overall loss function for fixed points optimization - # self.f_loss = bm.jit(lambda h: f_loss(h, f_cell(h))) - # # evaluate losses of a batch of inputs - # self.f_loss_batch = bm.jit(lambda h: f_loss(h, vmap(f_cell)(h), axis=1)) - # elif f_type == constants.CONTINUOUS: - # # overall loss function for fixed points optimization - # self.f_loss = bm.jit(lambda h: f_loss(f_cell(h))) - # # evaluate losses of a batch of inputs - # self.f_loss_batch = bm.jit(lambda h: f_loss(vmap(f_cell)(h), axis=1)) - # essential variables self._losses = None self._fixed_points = None self._selected_ids = None self._opt_losses = None + # functions + self._opt_functions = dict() + @property def opt_losses(self) -> np.ndarray: """The optimization losses.""" @@ -261,7 +295,6 @@ def selected_ids(self) -> np.ndarray: def selected_ids(self, val): raise UnsupportedError('Do not support set "selected_ids" by users.') - def find_fps_with_gd_method( self, candidates: Union[Tensor, Dict[str, Tensor]], @@ -299,7 +332,6 @@ def find_fps_with_gd_method( .. versionadded:: 2.1.2 """ - # optimization settings if opt_setting is None: if optimizer is None: @@ -333,11 +365,11 @@ def find_fps_with_gd_method( raise ValueError('Candidates must be instance of JaxArray or dict of JaxArray.') leaves, tree = tree_flatten(candidates, is_leaf=lambda x: isinstance(x, bm.JaxArray)) fixed_points = tree_unflatten(tree, [bm.TrainVar(leaf) for leaf in leaves]) + f_eval_loss = self._get_f_eval_loss() def f_loss(): - return self.f_eval_loss(tree_map(lambda a: a.value, - fixed_points, - is_leaf=lambda x: isinstance(x, bm.JaxArray))).mean() + return f_eval_loss(tree_map(lambda a: a.value, fixed_points, + is_leaf=lambda x: isinstance(x, bm.JaxArray))).mean() grad_f = bm.grad(f_loss, grad_vars=fixed_points, return_value=True) optimizer.register_vars(fixed_points if isinstance(fixed_points, dict) else {'a': fixed_points}) @@ -381,13 +413,18 @@ def batch_train(start_i, n_batch): f'is below tolerance {tolerance:0.10f}.') self._opt_losses = bm.concatenate(opt_losses) - self._losses = self.f_eval_loss(tree_map(lambda a: a.value, - fixed_points, - is_leaf=lambda x: isinstance(x, bm.JaxArray))) + self._losses = f_eval_loss(tree_map(lambda a: a.value, fixed_points, + is_leaf=lambda x: isinstance(x, bm.JaxArray))) self._fixed_points = tree_map(lambda a: a.value, fixed_points, is_leaf=lambda x: isinstance(x, bm.JaxArray)) self._selected_ids = jnp.arange(num_candidate) + if isinstance(self.target, DynamicalSystem): + for k, v in self.excluded_vars.items(): + v.value = self.excluded_data[k] + for k, v in self.included_vars.items(): + v.value = self.included_data[k] + def find_fps_with_opt_solver( self, candidates: Union[Tensor, Dict[str, Tensor]], @@ -402,23 +439,20 @@ def find_fps_with_opt_solver( opt_solver: str The solver of the optimization. """ - # optimization function num_candidate = self._check_candidates(candidates) for var in self.included_vars.values(): if bm.ndim(var) != 1: raise ValueError('Cannot use opt solver.') if self._opt_functions.get(F_OPT_SOLVER, None) is None: - self._opt_functions[F_OPT_SOLVER] = self._get_f_for_opt_solver( - candidates, SUPPORTED_OPT_SOLVERS[opt_solver]) + self._opt_functions[F_OPT_SOLVER] = self._get_f_for_opt_solver(candidates, SUPPORTED_OPT_SOLVERS[opt_solver]) f_opt = self._opt_functions[F_OPT_SOLVER] if self.verbose: print(f"Optimizing with {opt_solver} to find fixed points:") # optimizing - res = f_opt(tree_map(lambda a: a.value, - candidates, + res = f_opt(tree_map(lambda a: a.value, candidates, is_leaf=lambda a: isinstance(a, bm.JaxArray))) # results @@ -524,31 +558,75 @@ def exclude_outliers(self, tolerance: float = 1e0): f"Kept {keep_ids.shape[0]}/{num_fps} fixed points " f"with within outlier tolerance {tolerance}.") - def compute_jacobians(self, points, stack_vars=True): - """Compute the jacobian matrices at the points. + def compute_jacobians( + self, + points: Union[Tensor, Dict[str, Tensor]], + stack_dict_var: bool = True, + plot: bool = False, + num_col: int = 4, + len_col: int = 3, + len_row: int = 2, + ): + """Compute the Jacobian matrices at the points. Parameters ---------- points: np.ndarray, bm.JaxArray, jax.ndarray The fixed points with the shape of (num_point, num_dim). - stack_vars: bool + stack_dict_var: bool + Stack dictionary variables to calculate Jacobian matrix? + plot: bool + Plot the decomposition results of the Jacobian matrix. + num_col: int + The number of the figure column. + len_col: int + The length of each column. + len_row: int + The length of each row. """ - ndim = np.unique([l.ndim for l in tree_flatten(points, is_leaf=lambda a: isinstance(a, bm.JaxArray))[0]]) - if len(ndim) != 1: - raise ValueError(f'Get multiple dimension of the evaluated points. {ndim}') + # check data + info = np.asarray([(l.ndim, l.shape[0]) + for l in tree_flatten(points, is_leaf=lambda a: isinstance(a, bm.JaxArray))[0]]) + ndim = np.unique(info[:, 0]) + if len(ndim) != 1: raise ValueError(f'Get multiple dimension of the evaluated points. {ndim}') if ndim[0] == 1: points = tree_map(lambda a: bm.asarray([a]), points) + num_point = 1 elif ndim[0] == 2: - pass + nsize = np.unique(info[:, 1]) + if len(nsize) != 1: raise ValueError(f'Number of the evaluated points are mis-matched. {nsize}') + num_point = nsize[0] else: raise ValueError('Only support points of 1D: (num_feature,) or 2D: (num_point, num_feature)') - - if isinstance(points, dict) and stack_vars: + if isinstance(points, dict) and stack_dict_var: points = bm.hstack(points.values()).value - return self._get_f_jocabian(stack_vars)(points) + + # get Jacobian matrix + jacobian = self._get_f_jocabian(stack_dict_var)(points) + + # visualization + if plot: + import matplotlib.pyplot as plt + from brainpy.visualization import visualize + jacobian = bm.as_numpy(jacobian) + + num_col = min(num_col, num_point) + num_row = int(math.ceil(num_point / num_col)) + fig, gs = visualize.get_figure(num_row, num_col, len_row, len_col) + for i in range(num_point): + eigval, eigvec = np.linalg.eig(np.asarray(jacobian[i])) + ax = fig.add_subplot(gs[i // num_col, i % num_col]) + ax.scatter(np.real(eigval), np.imag(eigval)) + ax.plot([1, 1] if self.f_type == constants.DISCRETE else [0, 0], [-1, 1], '--') + ax.set_xlabel('Real') + ax.set_ylabel('Imaginary') + ax.set_title(f'Point {i}') + plt.show() + + return jacobian @staticmethod - def decompose_eigenvalues(matrices, sort_by='magnitude', do_compute_lefts=True): + def decompose_eigenvalues(matrices, sort_by='magnitude', do_compute_lefts=False): """Compute the eigenvalues of the matrices. Parameters @@ -587,75 +665,94 @@ def decompose_eigenvalues(matrices, sort_by='magnitude', do_compute_lefts=True): 'L': L}) return decompositions - def _get_f_for_opt_solver(self, candidates, opt_method): - # update function - if isinstance(candidates, (bm.ndarray, jnp.ndarray, np.ndarray)): - f_cell = self.f_cell + def _get_f_eval_loss(self, ): + name = 'f_eval_loss' + if name not in self._opt_functions: + self._opt_functions[name] = self._generate_f_eval_loss() + return self._opt_functions[name] - elif isinstance(candidates, dict): - indices = [0] - for v in self.included_vars.values(): - indices.append(v.shape[0]) - indices = np.cumsum(indices) - keys = tuple(self.included_vars.keys()) + def _generate_f_eval_loss(self): + # functions + if self.f_type == constants.DISCRETE: + # evaluate losses of a batch of inputs + f_eval_loss = bm.jit(lambda h: self.f_loss(h, vmap(self.f_cell)(h), axis=1)) + else: + # evaluate losses of a batch of inputs + f_eval_loss = bm.jit(lambda h: self.f_loss(vmap(self.f_cell)(h), axis=1)) - def f_cell(x): - x = {keys[i]: x[indices[i]: indices[i + 1]] for i in range(len(keys))} - r = self.f_cell(x) + if isinstance(self.target, DynamicalSystem): + def loss_func(h): + r = f_eval_loss(h) + for k, v in self.excluded_vars.items(): + v.value = self.excluded_data[k] + for k, v in self.included_vars.items(): + v.value = self.included_data[k] return r + return loss_func else: - raise ValueError(f'Only supports tensor or a dict of tensors. But we got {type(candidates)}') + return f_eval_loss + def _get_f_for_opt_solver(self, candidates, opt_method): # loss function if self.f_type == constants.DISCRETE: # overall loss function for fixed points optimization if isinstance(candidates, dict): + keys = tuple(self.included_vars.keys()) + indices = [0] + for v in self.included_vars.values(): + indices.append(v.shape[0]) + indices = np.cumsum(indices) + def f_loss(h): - return bm.as_device_array( - self.f_loss({key: h[indices[i]: indices[i + 1]] for i, key in enumerate(self.included_vars.keys())}, - {k: v for k, v in f_cell(h).items() if k in self.included_vars}) - ) + h = {key: h[indices[i]: indices[i + 1]] for i, key in enumerate(keys)} + return bm.as_device_array(self.f_loss(h, self.f_cell(h))) else: def f_loss(h): - return bm.as_device_array(self.f_loss(h, f_cell(h))) + return bm.as_device_array(self.f_loss(h, self.f_cell(h))) else: # overall loss function for fixed points optimization def f_loss(h): - return self.f_loss(f_cell(h)) - - excluded_data = {k: v.value for k, v in self.excluded_vars.items()} + return self.f_loss(self.f_cell(h)) @bm.jit @vmap def f_opt(x0): for k, v in self.included_vars.items(): - v.value = x0[k] + v.value = x0[k] if v.batch_axis is None else bm.expand_dims(x0[k], axis=v.batch_axis) for k, v in self.excluded_vars.items(): - v.value = excluded_data[k] + v.value = self.excluded_data[k] if isinstance(x0, dict): x0 = bm.concatenate(tuple(x0.values())).value return opt_method(f_loss, x0) - return f_opt + def call_opt(x): + r = f_opt(x) + for k, v in self.excluded_vars.items(): + v.value = self.excluded_data[k] + for k, v in self.included_vars.items(): + v.value = self.included_data[k] + return r - def _generate_ds_cell_function(self, - ds_instance, - included_vars: Dict, - excluded_vars: Dict, - t=0., - f_input=None): + return call_opt if isinstance(self.target, DynamicalSystem) else f_opt - excluded_data = {k: v.value for k, v in excluded_vars.items()} + def _generate_ds_cell_function(self, ds_instance, included_vars: Dict, excluded_vars: Dict, + t: float = None, dt: float = None, f_input: Callable = None): + if dt is None: dt = bm.get_dt() + if t is None: t = 0. + shared = DotDict(t=t, dt=dt, i=0) def f_cell(h: Dict): for k, v in included_vars.items(): - v.value = bm.asarray(h[k], dtype=v.dtype) + v.value = (bm.asarray(h[k], dtype=v.dtype) + if v.batch_axis is None else + bm.asarray(bm.expand_dims(h[k], axis=v.batch_axis), dtype=v.dtype)) for k, v in excluded_vars.items(): - v.value = excluded_data[k] + v.value = self.excluded_data[k] if f_input is not None: - f_input(t, bm.get_dt()) - ds_instance.update(t, bm.get_dt()) + f_input(shared) + args = (shared, ) + self.args + ds_instance.update(*args) return {k: v.value for k, v in included_vars.items()} return f_cell @@ -680,7 +777,20 @@ def jacob(x0): else: jacob = self.f_cell - return bm.jit(vmap(bm.jacobian(jacob))) + f_jac = bm.jit(vmap(bm.jacobian(jacob))) + + if isinstance(self.target, DynamicalSystem): + def jacobian_func(x): + r = f_jac(x) + for k, v in self.excluded_vars.items(): + v.value = self.excluded_data[k] + for k, v in self.included_vars.items(): + v.value = self.included_data[k] + return r + + return jacobian_func + else: + return f_jac def _check_candidates(self, candidates): if isinstance(self.target, DynamicalSystem): @@ -700,7 +810,13 @@ def _check_candidates(self, candidates): raise KeyError(f'"{key}" is defined in required variables ' f'for fixed point optimization of {self.target}. ' f'Please provide its initial values.') - + for key, value in candidates.items(): + if value.ndim != self.included_vars[key].ndim + 1: + raise ValueError(f'"{key}" is defined in the required variables for fixed ' + f'point optimization of {self.target}. \n' + f'We expect the provided candidate has a batch size, ' + f'but we got {value.shape} for variable with shape of ' + f'{self.included_vars[key].shape}') if isinstance(candidates, dict): num_candidate = np.unique([leaf.shape[0] for leaf in candidates.values()]) if len(num_candidate) != 1: @@ -709,4 +825,4 @@ def _check_candidates(self, candidates): num_candidate = num_candidate[0] else: num_candidate = candidates.shape[0] - return num_candidate \ No newline at end of file + return num_candidate diff --git a/brainpy/analysis/highdim/tests/test_slow_points.py b/brainpy/analysis/highdim/tests/test_slow_points.py index 6a9961607..1ecc7f323 100644 --- a/brainpy/analysis/highdim/tests/test_slow_points.py +++ b/brainpy/analysis/highdim/tests/test_slow_points.py @@ -59,7 +59,8 @@ def dV(self, V, t, m, h, n, Iext): dVdt = (- I_Na - I_K - I_leak + Iext) / self.C return dVdt - def update(self, t, dt): + def update(self, tdi): + t, dt = tdi.t, tdi.dt m = self.int_m(self.m, t, self.V, dt=dt) h = self.int_h(self.h, t, self.V, dt=dt) n = self.int_n(self.n, t, self.V, dt=dt) diff --git a/brainpy/analysis/lowdim/lowdim_analyzer.py b/brainpy/analysis/lowdim/lowdim_analyzer.py index 401f68a38..85b04040f 100644 --- a/brainpy/analysis/lowdim/lowdim_analyzer.py +++ b/brainpy/analysis/lowdim/lowdim_analyzer.py @@ -157,7 +157,8 @@ def __init__( elif isinstance(resolutions, float): warnings.warn('The `resolutions` is specified to all parameters and variables. ' 'Analysis computation may occupy too much memory if `resolutions` is small. ' - 'Please specify `resolutions` by dict, such as resolutions={"V": 0.1}.', + 'Please specify `resolutions` for each parameter and variable by dict, ' + 'such as resolutions={"V": 0.1}.', category=UserWarning) for key, lim in self.target_vars.items(): self.resolutions[key] = bm.arange(*lim, resolutions) @@ -258,9 +259,9 @@ def F_fx(self): >>> self.F_fx(v1, v2, p1, p2) """ if C.F_fx not in self.analyzed_results: - _, arguments = utils.get_args(self.model.F[self.x_var]) + _, arguments = utils.get_args(self.model.f_derivatives[self.x_var]) wrapper = utils.std_derivative(arguments, self.target_var_names, self.target_par_names) - f = wrapper(self.model.F[self.x_var]) + f = wrapper(self.model.f_derivatives[self.x_var]) f = partial(f, **(self.pars_update + self.fixed_vars)) f = utils.f_without_jaxarray_return(f) f = utils.remove_return_shape(f) @@ -419,9 +420,9 @@ def F_fy(self): >>> self.F_fy(v1, v2, p1, p2) """ if C.F_fy not in self.analyzed_results: - variables, arguments = utils.get_args(self.model.F[self.y_var]) + variables, arguments = utils.get_args(self.model.f_derivatives[self.y_var]) wrapper = utils.std_derivative(arguments, self.target_var_names, self.target_par_names) - f = wrapper(self.model.F[self.y_var]) + f = wrapper(self.model.f_derivatives[self.y_var]) f = partial(f, **(self.pars_update + self.fixed_vars)) f = utils.f_without_jaxarray_return(f) f = utils.remove_return_shape(f) @@ -431,18 +432,18 @@ def F_fy(self): @property def F_int_x(self): if C.F_int_x not in self.analyzed_results: - wrap_x = utils.std_derivative(utils.get_args(self.model.F[self.x_var])[1], + wrap_x = utils.std_derivative(utils.get_args(self.model.f_derivatives[self.x_var])[1], self.target_var_names, self.target_par_names) - init_x = partial(wrap_x(self.model.INTG[0]), **(self.pars_update + self.fixed_vars)) + init_x = partial(wrap_x(self.model.f_integrals[0]), **(self.pars_update + self.fixed_vars)) self.analyzed_results[C.F_int_x] = init_x return self.analyzed_results[C.F_int_x] @property def F_int_y(self): if C.F_int_y not in self.analyzed_results: - wrap_x = utils.std_derivative(utils.get_args(self.model.F[self.y_var])[1], + wrap_x = utils.std_derivative(utils.get_args(self.model.f_derivatives[self.y_var])[1], self.target_var_names, self.target_par_names) - init_x = partial(wrap_x(self.model.INTG[1]), **(self.pars_update + self.fixed_vars)) + init_x = partial(wrap_x(self.model.f_integrals[1]), **(self.pars_update + self.fixed_vars)) self.analyzed_results[C.F_int_y] = init_x return self.analyzed_results[C.F_int_y] @@ -1028,9 +1029,9 @@ def __init__(self, *args, **kwargs): def F_fz(self): """The function to evaluate :math:`f_y(*\mathrm{vars}, *\mathrm{pars})`.""" if C.F_fz not in self.analyzed_results: - variables, arguments = utils.get_args(self.model.F[self.z_var]) + variables, arguments = utils.get_args(self.model.f_derivatives[self.z_var]) wrapper = utils.std_derivative(arguments, self.target_var_names, self.target_par_names) - f = wrapper(self.model.F[self.z_var]) + f = wrapper(self.model.f_derivatives[self.z_var]) f = partial(f, **(self.pars_update + self.fixed_vars)) self.analyzed_results[C.F_fz] = bm.jit(f, device=self.jit_device) return self.analyzed_results[C.F_fz] diff --git a/brainpy/analysis/lowdim/tests/test_phase_plane.py b/brainpy/analysis/lowdim/tests/test_phase_plane.py index 735029623..f93c0bc4d 100644 --- a/brainpy/analysis/lowdim/tests/test_phase_plane.py +++ b/brainpy/analysis/lowdim/tests/test_phase_plane.py @@ -26,6 +26,7 @@ def int_x(x, t, Iext): analyzer.plot_vector_field() analyzer.plot_fixed_point() plt.show(block=block) + plt.close() bp.math.disable_x64() def test_2d_decision_making_model(self): @@ -74,4 +75,5 @@ def int_s2(s2, t, s1): analyzer.plot_nullcline(coords=dict(s2='s2-s1')) analyzer.plot_fixed_point() plt.show(block=block) + plt.close() bp.math.disable_x64() diff --git a/brainpy/analysis/utils/model.py b/brainpy/analysis/utils/model.py index 9410fc011..a877a2e52 100644 --- a/brainpy/analysis/utils/model.py +++ b/brainpy/analysis/utils/model.py @@ -4,11 +4,12 @@ import jax.numpy as jnp import brainpy.math as bm -from brainpy import errors from brainpy.dyn.base import DynamicalSystem from brainpy.dyn.runners import DSRunner +from brainpy.errors import AnalyzerError, UnsupportedError +from brainpy.integrators.base import Integrator from brainpy.integrators.joint_eq import JointEq -from brainpy.integrators.ode.base import ODEIntegrator +from brainpy.integrators.ode import ODEIntegrator, odeint __all__ = [ 'model_transform', @@ -17,63 +18,69 @@ ] +def _check_model(model): + if isinstance(model, Integrator): + if not isinstance(model, ODEIntegrator): + raise AnalyzerError(f'Must be the instance of {ODEIntegrator.__name__}, but got {model}.') + elif callable(model): + model = odeint(model) + else: + raise ValueError(f'Please provide derivative function or integral function. But we got {model}') + if isinstance(model.f, JointEq): + return [type(model)(eq, var_type=model.var_type, dt=model.dt) for eq in model.f.eqs] + else: + return [model] + + def model_transform(model): - # check integrals - if isinstance(model, NumDSWrapper): + # check model + if isinstance(model, DynamicalSystem): + model = tuple(model.nodes(level=-1).subset(ODEIntegrator).unique().values()) + elif isinstance(model, NumDSWrapper): return model elif isinstance(model, ODEIntegrator): # model = [model] - - # check model types + elif callable(model): + model = [model] + all_models = [] if isinstance(model, (list, tuple)): if len(model) == 0: - raise errors.AnalyzerError(f'Found no integrators: {model}') - model = tuple(model) - for intg in model: - if not isinstance(intg, ODEIntegrator): - raise errors.AnalyzerError(f'Must be the instance of {ODEIntegrator}, but got {intg}.') + raise AnalyzerError(f'Found no derivative/integral functions: {model}') + for fun in tuple(model): + all_models.extend(_check_model(fun)) elif isinstance(model, dict): if len(model) == 0: - raise errors.AnalyzerError(f'Found no integrators: {model}') - model = tuple(model.values()) - for intg in model: - if not isinstance(intg, ODEIntegrator): - raise errors.AnalyzerError(f'Must be the instance of {ODEIntegrator}, but got {intg}') - elif isinstance(model, DynamicalSystem): - model = tuple(model.nodes(level=-1).subset(ODEIntegrator).unique().values()) + raise AnalyzerError(f'Found no derivative/integral functions: {model}') + for fun in tuple(model.values()): + all_models.extend(_check_model(fun)) else: - raise errors.UnsupportedError(f'Dynamics analysis by symbolic approach only supports ' - f'list/tuple/dict of {ODEIntegrator} or {DynamicalSystem}, ' - f'but we got: {type(model)}: {str(model)}') - - new_model = [] - for intg in model: - if isinstance(intg.f, JointEq): - new_model.extend([type(intg)(eq, var_type=intg.var_type, dt=intg.dt) for eq in intg.f.eqs]) - else: - new_model.append(intg) + raise UnsupportedError(f'Dynamics analysis by symbolic approach only supports ' + f'derivative/integral functions or {DynamicalSystem.__name__}, ' + f'but we got: {type(model)}: {str(model)}') # pars to update pars_update = set() - for intg in new_model: - pars_update.update(intg.parameters[1:]) + for fun in all_models: + pars_update.update(fun.parameters[1:]) + # variables and parameters all_variables = set() all_parameters = set() - for integral in new_model: + for integral in all_models: + # variable if len(integral.variables) != 1: - raise errors.AnalyzerError(f'Only supports one {ODEIntegrator.__name__} one variable, ' - f'but we got {len(integral.variables)} variables in {integral}.') + raise AnalyzerError(f'Only supports one {ODEIntegrator.__name__} one variable, ' + f'but we got {len(integral.variables)} variables in {integral}.') var = integral.variables[0] if var in all_variables: - raise errors.AnalyzerError(f'Variable name {var} has been defined before. ' - f'Please change another name.') + raise AnalyzerError(f'Variable name {var} has been defined before. ' + f'Please change another name.') all_variables.add(var) - # parameters + # parameter all_parameters.update(integral.parameters[1:]) # form a dynamic model - return NumDSWrapper(integrals=new_model, + return NumDSWrapper(integrals=all_models, variables=list(all_variables), parameters=list(all_parameters), pars_update=pars_update) @@ -87,14 +94,17 @@ def __init__(self, variables, parameters, pars_update=None): - self.INTG = integrals # all integrators - self.F = {intg.variables[0]: intg.f for intg in integrals} # all integrators + self.f_integrals = integrals # all integrators + self.f_derivatives = {intg.variables[0]: intg.f for intg in integrals} # all integrators self.variables = variables # all variables self.parameters = parameters # all parameters self.pars_update = pars_update # the parameters to update self.name2integral = {intg.variables[0]: intg for intg in integrals} self.name2derivative = {intg.variables[0]: intg.f for intg in integrals} + def __repr__(self): + return f'{self.__class__.__name__}(variables={self.variables}, parameters={self.parameters})' + class TrajectModel(DynamicalSystem): def __init__(self, integrals: dict, initial_vars: dict, pars=None, dt=None): diff --git a/brainpy/analysis/utils/others.py b/brainpy/analysis/utils/others.py index eb4fe9028..ef0ccffab 100644 --- a/brainpy/analysis/utils/others.py +++ b/brainpy/analysis/utils/others.py @@ -46,7 +46,7 @@ def check_initials(initials, target_var_names): assert isinstance(initials, dict) for p in target_var_names: assert p in initials - initials = {p: bm.asarray(initials[p], dtype=bm.get_dfloat()) for p in target_var_names} + initials = {p: bm.asarray(initials[p], dtype=bm.dftype()) for p in target_var_names} len_of_init = [] for v in initials.values(): assert isinstance(v, (tuple, list, np.ndarray, jnp.ndarray, bm.ndarray)) diff --git a/brainpy/base/base.py b/brainpy/base/base.py index 25f41a102..ea82463e6 100644 --- a/brainpy/base/base.py +++ b/brainpy/base/base.py @@ -150,6 +150,13 @@ def _find_nodes(self, method='absolute', level=-1, include_self=True, _lid=0, _p if _paths is None: _paths = set() gather = Collector() + if include_self: + if method == 'absolute': + gather[self.name] = self + elif method == 'relative': + gather[''] = self + else: + raise ValueError(f'No support for the method of "{method}".') if (level > -1) and (_lid >= level): return gather if method == 'absolute': @@ -168,13 +175,14 @@ def _find_nodes(self, method='absolute', level=-1, include_self=True, _lid=0, _p gather[node.name] = node nodes.append(node) for v in nodes: - gather.update(v._find_nodes(method=method, level=level, _lid=_lid + 1, _paths=_paths, + gather.update(v._find_nodes(method=method, + level=level, + _lid=_lid + 1, + _paths=_paths, include_self=include_self)) - if include_self: gather[self.name] = self elif method == 'relative': nodes = [] - if include_self: gather[''] = self for k, v in self.__dict__.items(): if isinstance(v, Base): path = (id(self), id(v)) @@ -189,8 +197,11 @@ def _find_nodes(self, method='absolute', level=-1, include_self=True, _lid=0, _p gather[key] = node nodes.append((key, node)) for k1, v1 in nodes: - for k2, v2 in v1._find_nodes(method=method, _paths=_paths, _lid=_lid + 1, - level=level, include_self=include_self).items(): + for k2, v2 in v1._find_nodes(method=method, + _paths=_paths, + _lid=_lid + 1, + level=level, + include_self=include_self).items(): if k2: gather[f'{k1}.{k2}'] = v2 else: diff --git a/brainpy/base/collector.py b/brainpy/base/collector.py index ea425ad43..e0a6095ba 100644 --- a/brainpy/base/collector.py +++ b/brainpy/base/collector.py @@ -27,7 +27,6 @@ def replace(self, key, new_value): """Replace the original key with the new value.""" self.pop(key) self[key] = new_value - # dict.__setitem__(self, key, new_value) def update(self, other, **kwargs): assert isinstance(other, dict) diff --git a/brainpy/base/tests/test_base.py b/brainpy/base/tests/test_base.py index 6c127c72a..5599d8336 100644 --- a/brainpy/base/tests/test_base.py +++ b/brainpy/base/tests/test_base.py @@ -28,7 +28,7 @@ def __init__(self): net = bp.dyn.Network(a1=A(), a2=A()) print(net.nodes(level=2)) - self.assertTrue(len(net.nodes(level=0)) == 0) + self.assertTrue(len(net.nodes(level=0)) == 1) self.assertTrue(len(net.nodes(level=0, include_self=False)) == 0) self.assertTrue(len(net.nodes(level=1)) == (1 + 2)) self.assertTrue(len(net.nodes(level=1, include_self=False)) == 2) diff --git a/brainpy/compat/__init__.py b/brainpy/compat/__init__.py deleted file mode 100644 index c77b01528..000000000 --- a/brainpy/compat/__init__.py +++ /dev/null @@ -1,23 +0,0 @@ -# -*- coding: utf-8 -*- - - -__all__ = [ - # modules - 'brainobjects', 'layers', 'nn', - - # brainobjects - 'DynamicalSystem', 'Container', 'Network', - 'ConstantDelay', 'NeuGroup', 'TwoEndConn', - - # integrators - 'set_default_odeint', 'set_default_sdeint', - 'get_default_odeint', 'get_default_sdeint', - - # runners - 'IntegratorRunner', 'DSRunner', 'StructRunner', 'ReportRunner' -] - -from . import brainobjects, layers, nn -from .brainobjects import * -from .integrators import * -from .runners import * diff --git a/brainpy/compat/brainobjects.py b/brainpy/compat/brainobjects.py deleted file mode 100644 index 075464574..000000000 --- a/brainpy/compat/brainobjects.py +++ /dev/null @@ -1,198 +0,0 @@ -# -*- coding: utf-8 -*- - -import math as pm -import warnings - -import brainpy.math as bm -from brainpy import dyn -from brainpy import tools -from brainpy.errors import ModelBuildError - -__all__ = [ - 'DynamicalSystem', - 'Container', - 'Network', - 'ConstantDelay', - 'NeuGroup', - 'TwoEndConn', -] - - -class DynamicalSystem(dyn.DynamicalSystem): - """Dynamical System. - - .. deprecated:: 2.1.0 - Please use "brainpy.dyn.DynamicalSystem" instead. - """ - def __init__(self, *args, **kwargs): - warnings.warn('Please use "brainpy.dyn.DynamicalSystem" instead. ' - '"brainpy.DynamicalSystem" is deprecated since ' - 'version 2.0.3', DeprecationWarning) - super(DynamicalSystem, self).__init__(*args, **kwargs) - - -class Container(dyn.Container): - """Container. - - .. deprecated:: 2.1.0 - Please use "brainpy.dyn.Container" instead. - """ - def __init__(self, *args, **kwargs): - warnings.warn('Please use "brainpy.dyn.Container" instead. ' - '"brainpy.Container" is deprecated since ' - 'version 2.0.3', DeprecationWarning) - super(Container, self).__init__(*args, **kwargs) - - -class Network(dyn.Network): - """Network. - - .. deprecated:: 2.1.0 - Please use "brainpy.dyn.Network" instead. - """ - def __init__(self, *args, **kwargs): - warnings.warn('Please use "brainpy.dyn.Network" instead. ' - '"brainpy.Network" is deprecated since ' - 'version 2.0.3', DeprecationWarning) - super(Network, self).__init__(*args, **kwargs) - - -class ConstantDelay(dyn.DynamicalSystem): - """Class used to model constant delay variables. - - This class automatically supports batch size on the last axis. For example, if - you run batch with the size of (10, 100), where `100` are batch size, then this - class can automatically support your batched data. - For examples, - - >>> import brainpy as bp - >>> bp.dyn.ConstantDelay(size=(10, 100), delay=10.) - - This class also support nonuniform delays. - - >>> bp.dyn.ConstantDelay(size=100, delay=bp.math.random.random(100) * 4 + 10) - - .. deprecated:: 2.1.0 - Please use "brainpy.dyn.ConstantDelay" instead. - - Parameters - ---------- - size : int, list of int, tuple of int - The delay data size. - delay : int, float, function, ndarray - The delay time. With the unit of `dt`. - dt: float, optional - The time precision. - name : optional, str - The name of the dynamic system. - """ - - def __init__(self, size, delay, dtype=None, dt=None, **kwargs): - warnings.warn('Please use "brainpy.dyn.ConstantDelay" instead. ' - '"brainpy.ConstantDelay" is deprecated since ' - 'version 2.0.3', DeprecationWarning) - - # dt - self.dt = bm.get_dt() if dt is None else dt - self.dtype = dtype - - # data size - if isinstance(size, int): size = (size,) - if not isinstance(size, (tuple, list)): - raise ModelBuildError(f'"size" must a tuple/list of int, but we got {type(size)}: {size}') - self.size = tuple(size) - - # delay time length - self.delay = delay - - # data and operations - if isinstance(delay, (int, float)): # uniform delay - self.uniform_delay = True - self.num_step = int(pm.ceil(delay / self.dt)) + 1 - self.out_idx = bm.Variable(bm.array([0], dtype=bm.uint32)) - self.in_idx = bm.Variable(bm.array([self.num_step - 1], dtype=bm.uint32)) - self.data = bm.Variable(bm.zeros((self.num_step,) + self.size, dtype=dtype)) - self.num = 1 - - else: # non-uniform delay - self.uniform_delay = False - if not len(self.size) == 1: - raise NotImplementedError(f'Currently, BrainPy only supports 1D heterogeneous ' - f'delays, while we got the heterogeneous delay with ' - f'{len(self.size)}-dimensions.') - self.num = tools.size2num(size) - if bm.ndim(delay) != 1: - raise ModelBuildError(f'Only support a 1D non-uniform delay. ' - f'But we got {delay.ndim}D: {delay}') - if delay.shape[0] != self.size[0]: - raise ModelBuildError(f"The first shape of the delay time size must " - f"be the same with the delay data size. But " - f"we got {delay.shape[0]} != {self.size[0]}") - delay = bm.around(delay / self.dt) - self.diag = bm.array(bm.arange(self.num)) - self.num_step = bm.array(delay, dtype=bm.uint32) + 1 - self.in_idx = bm.Variable(self.num_step - 1) - self.out_idx = bm.Variable(bm.zeros(self.num, dtype=bm.uint32)) - self.data = bm.Variable(bm.zeros((self.num_step.max(),) + size, dtype=dtype)) - - super(ConstantDelay, self).__init__(**kwargs) - - def reset(self): - """Reset the variables.""" - self.in_idx[:] = self.num_step - 1 - self.out_idx[:] = 0 - self.data[:] = 0 - - @property - def oldest(self): - return self.pull() - - @property - def latest(self): - if self.uniform_delay: - return self.data[self.in_idx[0]] - else: - return self.data[self.in_idx, self.diag] - - def pull(self): - if self.uniform_delay: - return self.data[self.out_idx[0]] - else: - return self.data[self.out_idx, self.diag] - - def push(self, value): - if self.uniform_delay: - self.data[self.in_idx[0]] = value - else: - self.data[self.in_idx, self.diag] = value - - def update(self, t=None, dt=None, **kwargs): - """Update the delay index.""" - self.in_idx[:] = (self.in_idx + 1) % self.num_step - self.out_idx[:] = (self.out_idx + 1) % self.num_step - - -class NeuGroup(dyn.NeuGroup): - """Neuron group. - - .. deprecated:: 2.1.0 - Please use "brainpy.dyn.NeuGroup" instead. - """ - def __init__(self, *args, **kwargs): - warnings.warn('Please use "brainpy.dyn.NeuGroup" instead. ' - '"brainpy.NeuGroup" is deprecated since ' - 'version 2.0.3', DeprecationWarning) - super(NeuGroup, self).__init__(*args, **kwargs) - - -class TwoEndConn(dyn.TwoEndConn): - """Two-end synaptic connection. - - .. deprecated:: 2.1.0 - Please use "brainpy.dyn.TwoEndConn" instead. - """ - def __init__(self, *args, **kwargs): - warnings.warn('Please use "brainpy.dyn.TwoEndConn" instead. ' - '"brainpy.TwoEndConn" is deprecated since ' - 'version 2.0.3', DeprecationWarning) - super(TwoEndConn, self).__init__(*args, **kwargs) diff --git a/brainpy/compat/integrators.py b/brainpy/compat/integrators.py deleted file mode 100644 index 3980ad446..000000000 --- a/brainpy/compat/integrators.py +++ /dev/null @@ -1,60 +0,0 @@ -# -*- coding: utf-8 -*- - -import warnings - -from brainpy.integrators import ode, sde - -__all__ = [ - 'set_default_odeint', - 'set_default_sdeint', - 'get_default_odeint', - 'get_default_sdeint', -] - - -def set_default_odeint(method): - """Set default ode integrator. - - .. deprecated:: 2.1.0 - Please use "brainpy.ode.set_default_odeint" instead. - """ - warnings.warn('Please use "brainpy.ode.set_default_odeint" instead. ' - '"brainpy.set_default_odeint" is deprecated since ' - 'version 2.1.0', DeprecationWarning) - ode.set_default_odeint(method) - - -def get_default_odeint(): - """Get default ode integrator. - - .. deprecated:: 2.1.0 - Please use "brainpy.ode.get_default_odeint" instead. - """ - warnings.warn('Please use "brainpy.ode.get_default_odeint" instead. ' - '"brainpy.get_default_odeint" is deprecated since ' - 'version 2.1.0', DeprecationWarning) - ode.get_default_odeint() - - -def set_default_sdeint(method): - """Set default sde integrator. - - .. deprecated:: 2.1.0 - Please use "brainpy.ode.set_default_sdeint" instead. - """ - warnings.warn('Please use "brainpy.sde.set_default_sdeint" instead. ' - '"brainpy.set_default_sdeint" is deprecated since ' - 'version 2.1.0', DeprecationWarning) - sde.set_default_sdeint(method) - - -def get_default_sdeint(): - """Get default sde integrator. - - .. deprecated:: 2.1.0 - Please use "brainpy.ode.get_default_sdeint" instead. - """ - warnings.warn('Please use "brainpy.sde.get_default_sdeint" instead. ' - '"brainpy.get_default_sdeint" is deprecated since ' - 'version 2.1.0', DeprecationWarning) - sde.get_default_sdeint() diff --git a/brainpy/compat/layers.py b/brainpy/compat/layers.py deleted file mode 100644 index 23a17727e..000000000 --- a/brainpy/compat/layers.py +++ /dev/null @@ -1,61 +0,0 @@ -# -*- coding: utf-8 -*- - -import warnings - -import jax.numpy as jnp -import numpy as onp - -import brainpy.math as bm -from brainpy.base.base import Base - -__all__ = [ - 'Module', -] - - -def _check_args(args): - if args is None: - return tuple() - elif isinstance(args, tuple): - return args - else: - return (args,) - - -class Module(Base): - """Basic module class. - - .. deprecated:: 2.1.0 - """ - - @staticmethod - def get_param(param, size): - return bm.TrainVar(Module.init_param(param, size)) - - @staticmethod - def init_param(param, size): - if param is None: - return None - if callable(param): - param = param(size) - elif isinstance(param, onp.ndarray): - param = bm.asarray(param) - elif isinstance(param, (bm.JaxArray, jnp.ndarray)): - pass - else: - raise ValueError(f'Unknown param type {type(param)}: {param}') - assert param.shape == size, f'"param.shape" is not the required size {size}' - return param - - def __init__(self, name=None): # initialize parameters - warnings.warn('Please use "brainpy.rnns.Module" instead. ' - '"brainpy.layers.Module" is deprecated since ' - 'version 2.1.0.', DeprecationWarning) - super(Module, self).__init__(name=name) - - def __call__(self, *args, **kwargs): # initialize variables - return self.call(*args, **kwargs) - - def call(self, *args, **kwargs): - raise NotImplementedError - diff --git a/brainpy/compat/nn/__init__.py b/brainpy/compat/nn/__init__.py deleted file mode 100644 index 0eb39fdf2..000000000 --- a/brainpy/compat/nn/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Neural Networks (nn)""" - -from .base import * -from .datatypes import * -from .graph_flow import * -from .nodes import * -from .graph_flow import * -from .operations import * -from .utils import * -from .runners import * -from . import algorithms - diff --git a/brainpy/compat/nn/algorithms/__init__.py b/brainpy/compat/nn/algorithms/__init__.py deleted file mode 100644 index 00215dc48..000000000 --- a/brainpy/compat/nn/algorithms/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# -*- coding: utf-8 -*- - -from .offline import * -from .online import * diff --git a/brainpy/compat/nn/algorithms/offline.py b/brainpy/compat/nn/algorithms/offline.py deleted file mode 100644 index bd07d4184..000000000 --- a/brainpy/compat/nn/algorithms/offline.py +++ /dev/null @@ -1,184 +0,0 @@ -# -*- coding: utf-8 -*- - -import brainpy.math as bm -from brainpy.base import Base - -__all__ = [ - # base class for offline training algorithm - 'OfflineAlgorithm', - - # training methods - 'RidgeRegression', - 'LinearRegression', - - # general supports - 'get_supported_offline_methods', - 'register_offline_method', -] - -name2func = dict() - - -class OfflineAlgorithm(Base): - """Base class for offline training algorithm.""" - - def __init__(self, name=None): - super(OfflineAlgorithm, self).__init__(name=name) - - def __call__(self, targets, inputs, outputs): - """The training procedure. - - Parameters - ---------- - inputs: JaxArray, jax.numpy.ndarray, numpy.ndarray - The 3d input data with the shape of `(num_batch, num_time, num_input)`, - or, the 2d input data with the shape of `(num_time, num_input)`. - - targets: JaxArray, jax.numpy.ndarray, numpy.ndarray - The 3d target data with the shape of `(num_batch, num_time, num_output)`, - or the 2d target data with the shape of `(num_time, num_output)`. - - outputs: JaxArray, jax.numpy.ndarray, numpy.ndarray - The 3d output data with the shape of `(num_batch, num_time, num_output)`, - or the 2d output data with the shape of `(num_time, num_output)`. - - Returns - ------- - weight: JaxArray - The weights after fit. - """ - raise NotImplementedError('Must implement the __call__ function by the subclass itself.') - - def __repr__(self): - return self.__class__.__name__ - - -class RidgeRegression(OfflineAlgorithm): - """Training algorithm of ridge regression. - - Parameters - ---------- - beta: float - The regularization coefficient. - """ - - def __init__(self, beta=1e-7, name=None): - super(RidgeRegression, self).__init__(name=name) - self.beta = beta - - def __call__(self, targets, inputs, outputs=None): - # checking - inputs = bm.asarray(inputs).reshape((-1, inputs.shape[2])) - targets = bm.asarray(targets).reshape((-1, targets.shape[2])) - # solving - temp = inputs.T @ inputs - if self.beta > 0.: - temp += self.beta * bm.eye(inputs.shape[-1]) - weights = bm.linalg.pinv(temp) @ (inputs.T @ targets) - return weights - - def __repr__(self): - return f'{self.__class__.__name__}(beta={self.beta})' - - -name2func['ridge'] = RidgeRegression - - -class LinearRegression(OfflineAlgorithm): - """Training algorithm of least-square regression.""" - - def __init__(self, name=None): - super(LinearRegression, self).__init__(name=name) - - def __call__(self, targets, inputs, outputs=None): - inputs = bm.asarray(inputs).reshape((-1, inputs.shape[2])) - targets = bm.asarray(targets).reshape((-1, targets.shape[2])) - weights = bm.linalg.lstsq(inputs, targets) - return weights[0] - - -name2func['linear'] = LinearRegression -name2func['lstsq'] = LinearRegression - - -class LassoRegression(OfflineAlgorithm): - """Lasso regression method for offline training. - - Parameters - ---------- - alpha: float - Constant that multiplies the L1 term. Defaults to 1.0. - `alpha = 0` is equivalent to an ordinary least square. - max_iter: int - The maximum number of iterations. - """ - - def __init__(self, alpha=1.0, max_iter=1000, name=None): - super(LassoRegression, self).__init__(name=name) - self.alpha = alpha - self.max_iter = max_iter - - def __call__(self, *args, **kwargs): - pass - - -# name2func['lasso'] = LassoRegression - - -def elastic_net_regression(x, y, train_pars): - pass - - -# name2func['elastic_net'] = elastic_net_regression - - -def logistic_regression(x, y, train_pars): - pass - - -# name2func['logistic'] = logistic_regression - - -def polynomial_regression(x, y, train_pars): - pass - - -# name2func['polynomial'] = polynomial_regression - - -def stepwise_regression(x, y, train_pars): - pass - - -# name2func['stepwise'] = stepwise_regression - - -def get_supported_offline_methods(): - """Get all supported offline training methods.""" - return tuple(name2func.keys()) - - -def register_offline_method(name, method): - """Register a new offline learning method. - - Parameters - ---------- - name: str - The method name. - method: callable - The function method. - """ - if name in name2func: - raise ValueError(f'"{name}" has been registered in offline training methods.') - if not callable(method): - raise ValueError(f'"method" must be an instance of callable ' - f'function, but we got {type(method)}') - name2func[name] = method - - -def get(name): - """Get the training function according to the training method name.""" - if name not in name2func: - raise ValueError(f'All offline methods are: {get_supported_offline_methods()}.\n' - f'But we got {name}.') - return name2func[name] diff --git a/brainpy/compat/nn/algorithms/online.py b/brainpy/compat/nn/algorithms/online.py deleted file mode 100644 index 0793345b7..000000000 --- a/brainpy/compat/nn/algorithms/online.py +++ /dev/null @@ -1,161 +0,0 @@ -# -*- coding: utf-8 -*- - -import brainpy.math as bm -from brainpy.base import Base - -__all__ = [ - # base class - 'OnlineAlgorithm', - - # online learning algorithms - 'ForceLearning', - 'RLS', - 'LMS', - - # generic methods - 'get_supported_online_methods', - 'register_online_method', -] - -name2func = dict() - - -class OnlineAlgorithm(Base): - """Base class for online training algorithm.""" - - def __init__(self, name=None): - super(OnlineAlgorithm, self).__init__(name=name) - - def __call__(self, name, target, input, output): - """The training procedure. - - Parameters - ---------- - name: str - The variable name. - target: JaxArray, ndarray - The 2d target data with the shape of `(num_batch, num_output)`. - input: JaxArray, ndarray - The 2d input data with the shape of `(num_batch, num_input)`. - output: JaxArray, ndarray - The 2d output data with the shape of `(num_batch, num_output)`. - - Returns - ------- - weight: JaxArray - The weights after fit. - """ - return self.call(name, target, input, output) - - def initialize(self, name, *args, **kwargs): - raise NotImplementedError('Must implement the initialize() ' - 'function by the subclass itself.') - - def call(self, name, target, input, output): - """The training procedure. - - Parameters - ---------- - name: str - The variable name. - target: JaxArray, ndarray - The 2d target data with the shape of `(num_batch, num_output)`. - input: JaxArray, ndarray - The 2d input data with the shape of `(num_batch, num_input)`. - output: JaxArray, ndarray - The 2d output data with the shape of `(num_batch, num_output)`. - - Returns - ------- - weight: JaxArray - The weights after fit. - """ - raise NotImplementedError('Must implement the call() function by the subclass itself.') - - def __repr__(self): - return self.__class__.__name__ - - -class RLS(OnlineAlgorithm): - """The recursive least squares (RLS).""" - - postfix = '.rls.P' - - def __init__(self, alpha=0.1, name=None): - super(RLS, self).__init__(name=name) - self.alpha = alpha - - def initialize(self, name, feature_in, feature_out=None): - name = name + self.postfix - self.implicit_vars[name] = bm.Variable(bm.eye(feature_in) * self.alpha) - - def call(self, name, target, input, output): - name = name + self.postfix - P = self.implicit_vars[name] - # update the inverse correlation matrix - k = bm.dot(P, input.T) # (num_input, num_batch) - hPh = bm.dot(input, k) # (num_batch, num_batch) - c = bm.sum(1.0 / (1.0 + hPh)) # () - P -= c * bm.dot(k, k.T) # (num_input, num_input) - # update weights - e = output - target # (num_batch, num_output) - dw = -c * bm.dot(k, e) # (num_input, num_output) - return dw - - -name2func['rls'] = RLS - - -class ForceLearning(RLS): - postfix = '.force.P' - - -name2func['force'] = ForceLearning - - -class LMS(OnlineAlgorithm): - """The least mean squares (LMS). """ - - def __init__(self, alpha=0.1, name=None): - super(LMS, self).__init__(name=name) - self.alpha = alpha - - def initialize(self, name, *args, **kwargs): - pass - - def call(self, name, target, input, output): - return -self.alpha * bm.dot(output - target, output) - - -name2func['lms'] = LMS - - -def get_supported_online_methods(): - """Get all supported online training methods.""" - return tuple(name2func.keys()) - - -def register_online_method(name, method): - """Register a new oneline learning method. - - Parameters - ---------- - name: str - The method name. - method: callable - The function method. - """ - if name in name2func: - raise ValueError(f'"{name}" has been registered in offline training methods.') - if not callable(method): - raise ValueError(f'"method" must be an instance of callable ' - f'function, but we got {type(method)}') - name2func[name] = method - - -def get(name): - """Get the training function according to the training method name.""" - if name not in name2func: - raise ValueError(f'All online methods are: {get_supported_online_methods()}.\n' - f'But we got {name}.') - return name2func[name] diff --git a/brainpy/compat/nn/base.py b/brainpy/compat/nn/base.py deleted file mode 100644 index c3038c1a2..000000000 --- a/brainpy/compat/nn/base.py +++ /dev/null @@ -1,1600 +0,0 @@ -# -*- coding: utf-8 -*- - - -""" -This module provide basic Node class for whole ``brainpy.nn`` system. - -- ``brainpy.nn.Node``: The fundamental class representing the node or the element. -- ``brainpy.nn.RecurrentNode``: The recurrent node which has a self-connection. -- ``brainpy.nn.Network``: The network model which is composed of multiple node elements. - Once the Network instance receives a node operation, the wrapped elements, the new - elements, and their connection edges will be formed as another Network instance. - This means ``brainpy.nn.Network`` is only used to pack element nodes. It will be - never be an element node. -- ``brainpy.nn.FrozenNetwork``: The whole network which can be represented as a basic - elementary node when composing a larger network (TODO). -""" - -from copy import copy, deepcopy -from typing import (Dict, Sequence, Tuple, Union, Optional, Any, Callable) - -import jax.numpy as jnp - -from brainpy import tools, math as bm -from brainpy.base import Base, Collector -from brainpy.errors import (UnsupportedError, - PackageMissingError, - ModelBuildError, - MathError) -from brainpy.compat.nn.algorithms.offline import OfflineAlgorithm -from brainpy.compat.nn.algorithms.online import OnlineAlgorithm -from brainpy.compat.nn.datatypes import (DataType, SingleData, MultipleData) -from brainpy.compat.nn.graph_flow import (find_senders_and_receivers, - find_entries_and_exits, - detect_cycle, - detect_path) -from brainpy.tools.checking import (check_dict_data, - check_shape_except_batch, - check_integer) -from brainpy.types import Tensor - -operations = None - -__all__ = [ - 'Node', 'Network', - 'RecurrentNode', # a marker for recurrent node - 'FrozenNetwork', # a marker for frozen network -] - -NODE_STATES = ['inputs', 'feedbacks', 'state', 'output'] - -SUPPORTED_LAYOUTS = ['shell_layout', - 'multipartite_layout', - 'spring_layout', - 'spiral_layout', - 'spectral_layout', - 'random_layout', - 'planar_layout', - 'kamada_kawai_layout', - 'circular_layout'] - - -def not_implemented(fun: Callable) -> Callable: - """Marks the given module method is not implemented. - - Methods wrapped in @not_implemented can define submodules directly within the method. - - For instance:: - - @not_implemented - init_fb(self): - ... - - @not_implemented - def feedback(self): - ... - """ - fun.not_implemented = True - return fun - - -class Node(Base): - """Basic Node class for neural network building in BrainPy.""" - - '''Support multiple types of data pass, including "PassOnlyOne" (by default), - "PassSequence", "PassNameDict", etc. and user-customized type which inherits - from basic "SingleData" or "MultipleData". - - This setting will change the feedforward/feedback input data which pass into - the "call()" function and the sizes of the feedforward/feedback input data.''' - data_pass = SingleData() - - '''Offline fitting method.''' - offline_fit_by: Union[Callable, OfflineAlgorithm] - - '''Online fitting method.''' - online_fit_by: OnlineAlgorithm - - def __init__( - self, - name: Optional[str] = None, - input_shape: Optional[Union[Sequence[int], int]] = None, - trainable: bool = True - ): - - # initialize parameters - self._feedforward_shapes = None # input shapes - self._output_shape = None # output size - self._feedback_shapes = None # feedback shapes - self._is_ff_initialized = False - self._is_fb_initialized = False - self._is_state_initialized = False - self._is_fb_state_initialized = False - self._trainable = trainable - self._state = None # the state of the current node - self._fb_output = None # the feedback output of the current node - # data pass - if not isinstance(self.data_pass, DataType): - raise ValueError(f'Unsupported data pass type {type(self.data_pass)}. ' - f'Only support {DataType.__class__}') - - # super initialization - super(Node, self).__init__(name=name) - - # parameters - if input_shape is not None: - self._feedforward_shapes = {self.name: (None,) + tools.to_size(input_shape)} - - def __repr__(self): - return (f"{type(self).__name__}(name={self.name}, " - f"forwards={self.feedforward_shapes}, " - f"feedbacks={self.feedback_shapes}, " - f"output={self.output_shape})") - - def __call__(self, *args, **kwargs) -> Tensor: - """The main computation function of a Node. - - Parameters - ---------- - ff: dict, sequence, JaxArray, ndarray - The feedforward inputs. - fb: optional, dict, sequence, JaxArray, ndarray - The feedback inputs. - forced_states: optional, dict - The fixed state for the nodes in the network. - forced_feedbacks: optional, dict - The fixed feedback for the nodes in the network. - monitors: optional, sequence - Can be used to monitor the state or the attribute of a node in the network. - **kwargs - Other parameters which will be parsed into every node. - - Returns - ------- - Tensor - A output tensor value, or a dict of output tensors. - """ - return self._call(*args, **kwargs) - - def __rshift__(self, other): # "self >> other" - global operations - if operations is None: from . import operations - return operations.ff_connect(self, other) - - def __rrshift__(self, other): # "other >> self" - global operations - if operations is None: from . import operations - return operations.ff_connect(other, self) - - def __irshift__(self, other): # "self >>= other" - raise ValueError('Only Network objects support inplace feedforward connection.') - - def __lshift__(self, other): # "self << other" - global operations - if operations is None: from . import operations - return operations.fb_connect(other, self) - - def __rlshift__(self, other): # "other << self" - global operations - if operations is None: from . import operations - return operations.fb_connect(self, other) - - def __ilshift__(self, other): # "self <<= other" - raise ValueError('Only Network objects support inplace feedback connection.') - - def __and__(self, other): # "self & other" - global operations - if operations is None: from . import operations - return operations.merge(self, other) - - def __rand__(self, other): # "other & self" - global operations - if operations is None: from . import operations - return operations.merge(other, self) - - def __iand__(self, other): - raise ValueError('Only Network objects support inplace merging.') - - def __getitem__(self, item): # like "[:10]" - if isinstance(item, str): - raise ValueError('Node only supports slice, not retrieve by the name.') - else: - global operations - if operations is None: from . import operations - return operations.select(self, item) - - @property - def state(self) -> Optional[Tensor]: - """Node current internal state.""" - if self._is_ff_initialized: - return self._state - return None - - @state.setter - def state(self, value: Tensor): - raise NotImplementedError('Please use "set_state()" to reset the node state, ' - 'or use "self.state.value" to change the state content.') - - def set_state(self, state): - """ - Safely set the state of the node. - - This method allows the maximum flexibility to change the - node state. It can set a new data (same shape, same dtype) - to the state. It can also set a new data with the different - shape. We highly recommend the user to use this function. - instead of using ``self.state.value``. - """ - if self.state is None: - if self.output_shape is not None: - check_shape_except_batch(self.output_shape, state.shape) - self._state = bm.Variable(state) if not isinstance(state, bm.Variable) else state - else: - check_shape_except_batch(self.state.shape, state.shape) - if self.state.dtype != state.dtype: - raise MathError('Cannot set the state, because the dtype is not consistent: ' - f'{self.state.dtype} != {state.dtype}') - self.state._value = bm.as_device_array(state) - - @property - def fb_output(self) -> Optional[Tensor]: - return self._fb_output - - @fb_output.setter - def fb_output(self, value: Tensor): - raise NotImplementedError('Please use "set_fb_output()" to reset the node feedback state, ' - 'or use "self.fb_output.value" to change the state content.') - - def set_fb_output(self, state: Tensor): - """ - Safely set the feedback state of the node. - - This method allows the maximum flexibility to change the - node state. It can set a new data (same shape, same dtype) - to the state. It can also set a new data with the different - shape. We highly recommend the user to use this function. - instead of using ``self.fb_output.value``. - """ - if self.fb_output is None: - if self.output_shape is not None: - check_shape_except_batch(self.output_shape, state.shape) - self._fb_output = bm.Variable(state) if not isinstance(state, bm.Variable) else state - else: - check_shape_except_batch(self.fb_output.shape, state.shape) - if self.fb_output.dtype != state.dtype: - raise MathError('Cannot set the feedback state, because the dtype is ' - f'not consistent: {self.fb_output.dtype} != {state.dtype}') - self.fb_output._value = bm.as_device_array(state) - - @property - def trainable(self) -> bool: - """Returns if the Node can be trained.""" - return self._trainable - - @property - def is_initialized(self) -> bool: - if self._is_ff_initialized and self._is_state_initialized: - if self.feedback_shapes is not None: - if self._is_fb_initialized and self._is_fb_state_initialized: - return True - else: - return False - else: - return True - else: - return False - - @trainable.setter - def trainable(self, value: bool): - """Freeze or unfreeze the Node. If set to False, - learning is stopped.""" - assert isinstance(value, bool), 'Must be a boolean.' - self._trainable = value - - @property - def feedforward_shapes(self): - """Input data size.""" - return self.data_pass.filter(self._feedforward_shapes) - - @feedforward_shapes.setter - def feedforward_shapes(self, size): - self.set_feedforward_shapes(size) - - def set_feedforward_shapes(self, feedforward_shapes: Dict): - if not self._is_ff_initialized: - check_dict_data(feedforward_shapes, - key_type=(Node, str), - val_type=(list, tuple), - name='feedforward_shapes') - self._feedforward_shapes = feedforward_shapes - else: - if self.feedforward_shapes is not None: - sizes1 = sorted(list(self._feedforward_shapes.values())) - sizes2 = sorted(list(feedforward_shapes.values())) - if sizes1 != sizes2: - raise ValueError(f"Impossible to reset the input shapes of {self.name}. " - f"Because this Node has the input shapes {sizes1}. " - f"While we got input shapes of {sizes2}") - self._feedforward_shapes = feedforward_shapes - - @property - def feedback_shapes(self): - """Output data size.""" - return self.data_pass.filter(self._feedback_shapes) - - @feedback_shapes.setter - def feedback_shapes(self, size): - self.set_feedback_shapes(size) - - def set_feedback_shapes(self, fb_shapes: Dict): - if not self._is_fb_initialized: - check_dict_data(fb_shapes, - key_type=(Node, str), - val_type=(tuple, list), - name='fb_shapes') - self._feedback_shapes = fb_shapes - else: - if self.feedback_shapes is not None: - sizes1 = sorted(list(self._feedback_shapes.values())) - sizes2 = sorted(list(fb_shapes.values())) - if sizes1 != sizes2: - raise ValueError(f"Impossible to reset the feedback shapes of {self.name}. " - f"Because this Node has the feedback shapes {sizes1}. " - f"While we got feedback shapes of {sizes2}") - self._feedback_shapes = fb_shapes - - @property - def output_shape(self) -> Optional[Tuple[int]]: - """Output data size.""" - return self._output_shape - - @output_shape.setter - def output_shape(self, size): - self.set_output_shape(size) - - @property - def is_feedback_input_supported(self): - if hasattr(self.init_fb_conn, 'not_implemented'): - if self.init_fb_conn.not_implemented: - return False - return True - - @property - def is_feedback_supported(self): - if self.fb_output is None: - return False - else: - return True - - def set_output_shape(self, shape: Sequence[int]): - if not self._is_ff_initialized: - if not isinstance(shape, (tuple, list)): - raise ValueError(f'Must be a sequence of int, but got {shape}') - self._output_shape = tuple(shape) - else: - check_shape_except_batch(shape, self.output_shape) - - def nodes(self, method='absolute', level=1, include_self=True): - return super(Node, self).nodes(method=method, level=level, include_self=include_self) - - def vars(self, method='absolute', level=1, include_self=True): - return super(Node, self).vars(method=method, level=level, include_self=include_self) - - def train_vars(self, method='absolute', level=1, include_self=True): - return super(Node, self).train_vars(method=method, level=level, include_self=include_self) - - def copy(self, - name: str = None, - shallow: bool = False): - """Returns a copy of the Node. - - Parameters - ---------- - name : str - Name of the Node copy. - shallow : bool, default to False - If False, performs a deep copy of the Node. - - Returns - ------- - Node - A copy of the Node. - """ - if shallow: - new_obj = copy(self) - else: - new_obj = deepcopy(self) - new_obj.name = self.unique_name(name or (self.name + '_copy')) - return new_obj - - def _init_ff_conn(self): - if not self._is_ff_initialized: - self.init_ff_conn() - if self.output_shape is None: - raise ValueError(f'Please set the output shape when implementing ' - f'"init_ff_conn()" of the node {self.name}') - self._is_ff_initialized = True - - def _init_fb_conn(self): - if not self._is_fb_initialized: - try: - self.init_fb_conn() - except Exception as e: - raise ModelBuildError(f"{self.name} initialization failed.") from e - self._is_fb_initialized = True - - @not_implemented - def init_fb_conn(self): - """Initialize the feedback connections. - This function will be called only once.""" - raise ValueError(f'This node \n\n{self} \n\ndoes not support feedback connection.') - - def init_ff_conn(self): - """Initialize the feedforward connections. - This function will be called only once.""" - raise NotImplementedError('Please implement the feedforward initialization.') - - def _init_state(self, num_batch=1): - state = self.init_state(num_batch) - if state is not None: - self.set_state(state) - self._is_state_initialized = True - - def _init_fb_output(self, num_batch=1): - output = self.init_fb_output(num_batch) - if output is not None: - self.set_fb_output(output) - self._is_fb_state_initialized = True - - def init_state(self, num_batch=1) -> Optional[Tensor]: - """Set the initial node state. - - This function can be called multiple times.""" - pass - - def init_fb_output(self, num_batch=1) -> Optional[Tensor]: - """Set the initial node feedback state. - - This function can be called multiple times. However, - it is only triggered when the node has feedback connections. - """ - return bm.zeros((num_batch,) + self.output_shape[1:], dtype=bm.float_) - - def initialize(self, num_batch: int = 1): - """ - Initialize the node. This function must be called before applying JIT. - - This function is useful, because it is independent of the __call__ function. - We can use this function before we apply JIT to __call__ function. - """ - - # feedforward initialization - if self.feedforward_shapes is None: - raise ValueError('Cannot initialize this node, because we detect ' - 'both "feedforward_shapes" is None. ' - 'Two ways can solve this problem:\n\n' - '1. Connecting an instance of "brainpy.nn.Input()" to this node. \n' - '2. Providing the "input_shape" when initialize the node.') - check_integer(num_batch, 'num_batch', min_bound=0, allow_none=False) - self._init_ff_conn() - - # initialize state - self._init_state(num_batch) - - if self.feedback_shapes is not None: - # feedback initialization - self._init_fb_conn() - # initialize feedback state - self._init_fb_output(num_batch) - - def _check_inputs(self, ff, fb=None): - # check feedforward inputs - if isinstance(ff, (bm.ndarray, jnp.ndarray)): - ff = {self.name: ff} - if not isinstance(ff, dict): - raise ValueError(f'"ff" must be a dict or a tensor, got {type(ff)}: {ff}') - if self.name not in ff: - raise ValueError(f'Cannot find input for this node {self} when given "ff" {ff}') - for k, size in self._feedforward_shapes.items(): - if k not in ff: - raise ValueError(f"The required key {k} is not provided in feedforward inputs.") - check_shape_except_batch(size, ff[k].shape) - if self.state is not None: - for inp in ff.values(): - if self.state.shape[0] != inp.shape[0]: - raise ValueError(f'The batch size of the input data {inp.shape[0]} is not ' - f'equal to the batch size of the node state {self.state.shape[0]}. ' - f'Maybe you need to reinitialize the data with the desired ' - f'batch size by ".initialize(num_batch)", or change the data ' - f'consistent with the data batch size {self.state.shape[0]}.') - - # check feedback inputs - if fb is not None: - if not isinstance(fb, dict): - raise ValueError(f'"fb" must be a dict, got {type(fb)}: {fb}') - # check feedback consistency - for k, size in self._feedback_shapes.items(): - if k not in fb: - raise ValueError(f"The required key {k} is not provided in feedback inputs.") - check_shape_except_batch(size, fb[k].shape) - if self.state is not None: - for inp in fb.values(): - if self.state.shape[0] != inp.shape[0]: - raise ValueError(f'The batch size of the feedback data {inp.shape[0]} is not ' - f'equal to the batch size of the node state {self.state.shape[0]}. ' - f'Maybe you need to reinitialize the data with the desired ' - f'batch size by ".initialize(num_batch)", or change the data ' - f'consistent with the data batch size {self.state.shape[0]}.') - # data - ff = self.data_pass.filter(ff) - fb = self.data_pass.filter(fb) - return ff, fb - - def _call(self, - ff: Union[Tensor, Dict[Any, Tensor]], - fb: Optional[Union[Tensor, Dict[Any, Tensor]]] = None, - forced_states: Dict[str, Tensor] = None, - forced_feedbacks: Dict[str, Tensor] = None, - monitors=None, - **kwargs) -> Union[Tensor, Tuple[Tensor, Dict]]: - if not self.is_initialized: - raise ValueError('Please initialize the Node first by calling "initialize()" function.') - - # initialize the forced data - if forced_states is None: - forced_states = dict() - if isinstance(forced_states, (bm.ndarray, jnp.ndarray)): - forced_states = {self.name: forced_states} - check_dict_data(forced_states, key_type=str, val_type=(bm.ndarray, jnp.ndarray)) - if forced_feedbacks is not None: - if len(forced_feedbacks) != 0: - raise ValueError('Single instance of brainpy.nn.Node do ' - 'not support "forced_feedbacks"') - # monitors - need_return_monitor = True - if monitors is None: - monitors = tuple() - need_return_monitor = False - attr_monitors: Dict[str, Tensor] = {} - state_monitors: Dict[str, Tensor] = {} - for key in monitors: - splits = key.split('.') - if len(splits) != 2: - raise ValueError(f'Every term in "monitors" must be (node.item), ' - f'while we got {key}') - if splits[0] not in self.implicit_nodes: - raise ValueError(f'Cannot found the node {splits[0]}, this network ' - f'only has {list(self.implicit_nodes.keys())}.') - - if splits[1] not in NODE_STATES: # attribute monitor - if not hasattr(self.implicit_nodes[splits[0]], splits[1]): - raise UnsupportedError(f'Each node can monitor its states (including {NODE_STATES}), ' - f'or its attribute. While {splits[1]} is neither the state nor ' - f'the attribute of node {splits[0]}.') - else: - attr_monitors[key] = getattr(self.implicit_nodes[splits[0]], splits[1]) - else: # state monitor - if splits[1] == 'state': - assert self.implicit_nodes[splits[0]].state is not None, (f'{splits[0]} has no state, while ' - f'the user try to monitor it.') - state_monitors[key] = None - - if not isinstance(key, str): - raise ValueError(f'"extra_returns" must be a sequence of string, ' - f'while we got {type(key)}') - splits = key.split('.') - if len(splits) != 2: - raise ValueError(f'Every term in "monitors" must be (node.item), ' - f'while we got {key}') - if splits[0] != self.name: - raise ValueError(f"Cannot found the node {splits[0]}, this name of " - f"this node is {self.name}.") - if splits[1] not in NODE_STATES: # monitor attributes - if not hasattr(self, key): - raise UnsupportedError(f'Each node can monitor its states (including {NODE_STATES}), ' - f'or its attribute. While {key} is neither the state nor ' - f'the attribute of node \n\n{self}.') - else: - attr_monitors[key] = getattr(self, key) - else: # monitor states - if splits[1] == 'state': - if self.state is None: - raise ValueError(f'{self} \n\nhas no state, while ' - f'the user try to monitor its state.') - state_monitors[key] = None - - # checking - ff, fb = self._check_inputs(ff, fb=fb) - - # monitoring - if f'{self.name}.inputs' in state_monitors: - state_monitors[f'{self.name}.inputs'] = ff - if f'{self.name}.feedbacks' in state_monitors: - state_monitors[f'{self.name}.feedbacks'] = fb - - # forward pass - output = self.forward(ff, fb, **kwargs) - - # monitoring - if f'{self.name}.output' in state_monitors: - state_monitors[f'{self.name}.output'] = output - if f'{self.name}.state' in state_monitors: - state_monitors[f'{self.name}.state'] = self.state - attr_monitors.update(state_monitors) - - # outputs - if need_return_monitor: - return output, attr_monitors - else: - return output - - def forward(self, ff, fb=None, **shared_kwargs): - """The feedforward computation function of a node. - - Parameters - ---------- - ff: tensor, dict, sequence - The feedforward inputs. - fb: optional, tensor, dict, sequence - The feedback inputs. - **shared_kwargs - Other parameters. - - Returns - ------- - Tensor - A output tensor value. - """ - raise NotImplementedError - - def feedback(self, ff_output, **shared_kwargs): - """The feedback computation function of a node. - - Parameters - ---------- - ff_output: JaxArray - The feedforward output when calling ``forward()`` function. - **shared_kwargs - Other global parameters. - - Returns - ------- - Tensor - A feedback output tensor value. - """ - return ff_output - - @not_implemented - def offline_fit(self, targets, ffs, fbs=None): - """Offline training interface.""" - raise ValueError(f'This node \n\n{self} \n\ndoes not support offline training.') - - @not_implemented - def online_init(self): - """Online training initialization interface.""" - raise ValueError(f'This node \n\n{self} \n\ndoes not support online training.') - - @not_implemented - def online_fit(self, target, ff, fb=None): - """Online training fitting interface.""" - raise ValueError(f'This node \n\n{self} \n\ndoes not support online training.') - - -class RecurrentNode(Node): - """ - Basic class for recurrent node. - - The supports for the recurrent node are: - - - Self-connection when using ``plot_node_graph()`` function - - Set trainable state with ``state_trainable=True``. - - Parameters - ---------- - name: str - The name of the node. - input_shape: int, sequence of int - The shape of the input data. - state_trainable: bool - Whether train the model state or not. Default is False. - trainable: bool - Whether train the model or not. Default is True. - - .. versionchanged:: 2.1.8.1 - The faultvalue of ``trainable`` changed from False to True in version 2.1.8.1. - - """ - - def __init__( - self, - name: Optional[str] = None, - input_shape: Optional[Union[Sequence[int], int]] = None, - trainable: bool = True, - state_trainable: bool = False - ): - self._state_trainable = state_trainable - self._train_state = None - super(RecurrentNode, self).__init__(name=name, - input_shape=input_shape, - trainable=trainable) - - @property - def state_trainable(self) -> bool: - """Returns if the Node can be trained.""" - return self._state_trainable - - @property - def train_state(self): - return self._train_state - - def set_state(self, state): - """Safely set the state of the node. - - This method allows the maximum flexibility to change the - node state. It can set a new data (same shape, same dtype) - to the state. It can also set the data with another batch size. - - We highly recommend the user to use this function. - """ - if self.state is None: - if self.output_shape is not None: - check_shape_except_batch(self.output_shape, state.shape) - self._state = bm.Variable(state) if not isinstance(state, bm.Variable) else state - if self.state_trainable: - self._train_state = bm.TrainVar(self._state[0]) # get the first elements as the initial state - self._state[:] = self._train_state # set all batch states the same - else: - check_shape_except_batch(self.state.shape, state.shape) - if self.state.dtype != state.dtype: - raise MathError('Cannot set the state, because the dtype is not consistent: ' - f'{self.state.dtype} != {state.dtype}') - if self.state_trainable: - # get the batch size information - state = bm.repeat(bm.expand_dims(self.train_state, axis=0), state.shape[0], axis=0) - # set the state - self.state._value = bm.as_device_array(state) - else: - self.state._value = bm.as_device_array(state) - - -class Network(Node): - """Basic Network class for neural network building in BrainPy.""" - - data_pass = MultipleData('sequence') - - def __init__(self, - nodes: Optional[Sequence[Node]] = None, - ff_edges: Optional[Sequence[Tuple[Node]]] = None, - fb_edges: Optional[Sequence[Tuple[Node]]] = None, - **kwargs): - super(Network, self).__init__(**kwargs) - # nodes (with tuple/list format) - if nodes is None: - self._nodes = tuple() - else: - self._nodes = tuple(nodes) - # feedforward edges - if ff_edges is None: - self._ff_edges = tuple() - else: - self._ff_edges = tuple(ff_edges) - # feedback edges - if fb_edges is None: - self._fb_edges = tuple() - else: - self._fb_edges = tuple(fb_edges) - # initialize network - self._network_init() - - def _network_init(self): - # detect input and output nodes - self._entry_nodes, self._exit_nodes = find_entries_and_exits(self._nodes, self._ff_edges) - # build feedforward connection graph - self._ff_senders, self._ff_receivers = find_senders_and_receivers(self._ff_edges) - # build feedback connection graph - self._fb_senders, self._fb_receivers = find_senders_and_receivers(self._fb_edges) - # register nodes for brainpy.Base object - self.implicit_nodes = Collector({n.name: n for n in self._nodes}) - # set initialization states - self._is_initialized = False - self._is_fb_initialized = False - - def __repr__(self): - return f"{type(self).__name__}({', '.join([n.name for n in self._nodes])})" - - def __irshift__(self, other): # "self >>= other" - global operations - if operations is None: from . import operations - return operations.ff_connect(self, other, inplace=True) - - def __ilshift__(self, other): # "self <<= other" - global operations - if operations is None: from . import operations - return operations.fb_connect(self, other, inplace=True) - - def __iand__(self, other): - global operations - if operations is None: from . import operations - return operations.merge(self, other, inplace=True) - - def __getitem__(self, item): - if isinstance(item, str): - return self.get_node(item) - else: - global operations - if operations is None: from . import operations - return operations.select(self, item) - - def get_node(self, name): - if name in self.implicit_nodes: - return self.implicit_nodes[name] - else: - raise KeyError(f"No node named '{name}' found in model {self.name}.") - - def nodes(self, method='absolute', level=1, include_self=False): - return super(Node, self).nodes(method=method, level=level, include_self=include_self) - - @property - def trainable(self) -> bool: - """Returns True if at least one Node in the Model is trainable.""" - return any([n.trainable for n in self.lnodes]) - - @trainable.setter - def trainable(self, value: bool): - """Freeze or unfreeze trainable Nodes in the Model.""" - for node in [n for n in self.lnodes]: - node.trainable = value - - @property - def lnodes(self) -> Tuple[Node]: - return self._nodes - - @property - def ff_edges(self) -> Sequence[Tuple[Node]]: - return self._ff_edges - - @property - def fb_edges(self) -> Sequence[Tuple[Node]]: - return self._fb_edges - - @property - def entry_nodes(self) -> Sequence[Node]: - """First Nodes in the graph held by the Model.""" - return self._entry_nodes - - @property - def exit_nodes(self) -> Sequence[Node]: - """Last Nodes in the graph held by the Model.""" - return self._exit_nodes - - @property - def feedback_nodes(self) -> Sequence[Node]: - """Nodes which project feedback connections.""" - return tuple(self._fb_receivers.keys()) - - @property - def nodes_has_feedback(self) -> Sequence[Node]: - """Nodes which receive feedback connections.""" - return tuple(self._fb_senders.keys()) - - @property - def ff_senders(self) -> Dict: - """Nodes which project feedforward connections.""" - return self._ff_senders - - @property - def ff_receivers(self) -> Dict: - """Nodes which receive feedforward connections.""" - return self._ff_receivers - - @property - def fb_senders(self) -> Dict: - """Nodes which project feedback connections.""" - return self._fb_senders - - @property - def fb_receivers(self) -> Dict: - """Nodes which receive feedback connections.""" - return self._fb_receivers - - def update_graph(self, - new_nodes: Sequence[Node], - new_ff_edges: Sequence[Tuple[Node, Node]], - new_fb_edges: Sequence[Tuple[Node, Node]] = None) -> "Network": - """Update current Model's with new nodes and edges, inplace (a copy - is not performed). - - Parameters - ---------- - new_nodes : list of Node - New nodes. - new_ff_edges : list of (Node, Node) - New feedforward edges between nodes. - new_fb_edges : list of (Node, Node) - New feedback edges between nodes. - - Returns - ------- - Network - The updated network. - """ - if new_fb_edges is None: new_fb_edges = tuple() - self._nodes = tuple(set(new_nodes) | set(self.lnodes)) - self._ff_edges = tuple(set(new_ff_edges) | set(self.ff_edges)) - self._fb_edges = tuple(set(new_fb_edges) | set(self.fb_edges)) - # detect cycles in the graph flow - if detect_cycle(self._nodes, self._ff_edges): - raise ValueError('We detect cycles in feedforward connections. ' - 'Maybe you should replace some connection with ' - 'as feedback ones.') - if detect_cycle(self._nodes, self._fb_edges): - raise ValueError('We detect cycles in feedback connections. ') - self._network_init() - return self - - def replace_graph(self, - nodes: Sequence[Node], - ff_edges: Sequence[Tuple[Node, ...]], - fb_edges: Sequence[Tuple[Node, ...]] = None) -> "Network": - if fb_edges is None: fb_edges = tuple() - - # assign nodes and edges - self._nodes = tuple(nodes) - self._ff_edges = tuple(ff_edges) - self._fb_edges = tuple(fb_edges) - self._network_init() - return self - - def set_output_shape(self, shape: Dict[str, Sequence[int]]): - # check shape - if not isinstance(shape, dict): - raise ValueError(f'Must be a dict of , but got {type(shape)}: {shape}') - for key, val in shape.items(): - if not isinstance(val, (tuple, list)): - raise ValueError(f'Must be a sequence of int, but got {val} for key "{key}"') - # for s in val: - # if not (isinstance(s, int) or (s is None)): - # raise ValueError(f'Must be a sequence of int, but got {val}') - - if not self._is_ff_initialized: - if len(self.exit_nodes) == 1: - self._output_shape = tuple(shape.values())[0] - else: - self._output_shape = shape - else: - for val in shape.values(): - check_shape_except_batch(val, self.output_shape) - - def init_ff_conn(self): - """Initialize the feedforward connections of the network. - This function will be called only once.""" - # input shapes of entry nodes - for node in self.entry_nodes: - # set ff shapes - if node.feedforward_shapes is None: - if self.feedforward_shapes is None: - raise ValueError('Cannot find the input size. ' - 'Cannot initialize the network.') - else: - node.set_feedforward_shapes({node.name: self._feedforward_shapes[node.name]}) - # set fb shapes - if node in self.fb_senders: - fb_shapes = {node: node.output_shape for node in self.fb_senders.get(node, [])} - if None not in fb_shapes.values(): - node.set_feedback_shapes(fb_shapes) - # init ff conn - node._init_ff_conn() - - # initialize the data - children_queue = [] - ff_senders, _ = find_senders_and_receivers(self.ff_edges) - - # init shapes of other nodes - for node in self._entry_nodes: - for child in self.ff_receivers.get(node, []): - ff_senders[child].remove(node) - if len(ff_senders.get(child, [])) == 0: - children_queue.append(child) - while len(children_queue): - node = children_queue.pop(0) - # set ff shapes - parent_sizes = {p: p.output_shape for p in self.ff_senders.get(node, [])} - node.set_feedforward_shapes(parent_sizes) - if node in self.fb_senders: - # set fb shapes - fb_shapes = {node: node.output_shape for node in self.fb_senders.get(node, [])} - if None not in fb_shapes.values(): - node.set_feedback_shapes(fb_shapes) - # init ff conn - node._init_ff_conn() - # append children - for child in self.ff_receivers.get(node, []): - ff_senders[child].remove(node) - if len(ff_senders.get(child, [])) == 0: - children_queue.append(child) - - # set output shape - out_sizes = {node: node.output_shape for node in self.exit_nodes} - self.set_output_shape(out_sizes) - - def init_fb_conn(self): - """Initialize the feedback connections of the network. - This function will be called only once.""" - for receiver, senders in self.fb_senders.items(): - fb_sizes = {node: node.output_shape for node in senders} - if None in fb_sizes.values(): - none_size_nodes = [repr(n) for n, v in fb_sizes.items() if v is None] - none_size_nodes = "\n".join(none_size_nodes) - raise ValueError(f'Output shapes of nodes \n\n' - f'{none_size_nodes}\n\n' - f'have not been initialized, ' - f'leading us cannot initialize the ' - f'feedback connection of node \n\n' - f'{receiver}') - receiver.set_feedback_shapes(fb_sizes) - receiver._init_fb_conn() - - def _init_state(self, num_batch=1): - """Initialize the states of all children nodes. - This function can be called multiple times.""" - for node in self.lnodes: - node._init_state(num_batch) - self._is_state_initialized = True - - def _init_fb_output(self, num_batch=1): - """Initialize the node feedback state. - - This function can be called multiple times. However, - it is only triggered when the node has feedback connections. - """ - for node in self.feedback_nodes: - node._init_fb_output(num_batch) - self._is_fb_state_initialized = True - - def initialize(self, num_batch: int = 1): - """ - Initialize the whole network. This function must be called before applying JIT. - - This function is useful, because it is independent of the __call__ function. - We can use this function before we apply JIT to __call__ function. - """ - - # set feedforward shapes - if not self._is_ff_initialized: - # check input and output nodes - if len(self.entry_nodes) <= 0: - raise ValueError(f"We found this network \n\n" - f"{self} " - f"\n\nhas no input nodes.") - if len(self.exit_nodes) <= 0: - raise ValueError(f"We found this network \n\n" - f"{self} " - f"\n\nhas no output nodes.") - - # check whether it has a feedforward path for each feedback pair - ff_edges = [(a.name, b.name) for a, b in self.ff_edges] - for node, receiver in self.fb_edges: - if not detect_path(receiver.name, node.name, ff_edges): - raise ValueError(f'Cannot build a feedback connection from ' - f'\n\n{node} \n\n' - f'to ' - f'\n\n{receiver} \n\n' - f'because there is no feedforward path between them. \n' - f'Maybe you should use "ff_connect" first to establish a ' - f'feedforward connection between them. ') - - # feedforward checking - in_sizes = dict() - for node in self.entry_nodes: - if node.feedforward_shapes is None: - raise ValueError('Cannot initialize this node, because we detect ' - '"feedforward_shapes" is None. ' - 'Maybe you need a brainpy.nn.Input instance ' - 'to instruct the input size.') - in_sizes.update(node._feedforward_shapes) - self.set_feedforward_shapes(in_sizes) - - # feedforward initialization - if self.feedforward_shapes is None: - raise ValueError('Cannot initialize this node, because we detect ' - 'both "feedforward_shapes" is None. ') - check_integer(num_batch, 'num_batch', min_bound=1, allow_none=False) - self._init_ff_conn() - - # initialize state - self._init_state(num_batch) - - # set feedback shapes - if not self._is_fb_initialized: - if len(self.fb_senders) > 0: - fb_sizes = dict() - for sender in self.fb_senders.keys(): - fb_sizes[sender] = sender.output_shape - self.set_feedback_shapes(fb_sizes) - - # feedback initialization - if self.feedback_shapes is not None: - self._init_fb_conn() - - # initialize feedback state - self._init_fb_output(num_batch) - - def _check_inputs(self, ff, fb=None): - # feedforward inputs - if isinstance(ff, (bm.ndarray, jnp.ndarray)): - ff = {self.entry_nodes[0].name: ff} - if not isinstance(ff, dict): - raise ValueError(f'ff must be a dict or a tensor, got {type(ff)}: {ff}') - if len(self.entry_nodes) != len(ff): - raise ValueError(f'This network has {len(self.entry_nodes)} ' - f'entry nodes. While only {len(ff)} input ' - f'data are given.') - for n in self.entry_nodes: - if n.name not in ff: - raise ValueError(f'Cannot find the input of the node: \n{n}') - for k, size in self._feedforward_shapes.items(): - if k not in ff: - raise ValueError(f"The required key {k} is not provided in feedforward inputs.") - if not check_shape_except_batch(size, ff[k].shape, mode='bool'): - raise ValueError(f'Input size {ff[k].shape} is not consistent with ' - f'the input size {size}') - - # feedback inputs - if fb is not None: - if isinstance(fb, (bm.ndarray, jnp.ndarray)): - fb = {self.feedback_nodes[0]: fb} - if not isinstance(fb, dict): - raise ValueError(f'fb must be a dict or a tensor, ' - f'got {type(fb)}: {fb}') - if len(self.feedback_nodes) != len(fb): - raise ValueError(f'This network has {len(self.feedback_nodes)} ' - f'feedback nodes. While only {len(ff)} ' - f'feedback data are given.') - for n in self.feedback_nodes: - if n.name not in fb: - raise ValueError(f'Cannot find the feedback data from the node {n}') - # check feedback consistency - for k, size in self._feedback_shapes.items(): - if k not in fb: - raise ValueError(f"The required key {k} is not provided in feedback inputs.") - check_shape_except_batch(size, fb[k].shape) - - # data transformation - ff = self.data_pass.filter(ff) - fb = self.data_pass.filter(fb) - return ff, fb - - def _call(self, - ff: Union[Tensor, Dict[Any, Tensor]], - fb: Optional[Union[Tensor, Dict[Any, Tensor]]] = None, - forced_states: Optional[Dict[str, Tensor]] = None, - forced_feedbacks: Optional[Dict[str, Tensor]] = None, - monitors: Optional[Sequence[str]] = None, - **kwargs): - # initialization - if not self.is_initialized: - raise ValueError('Please initialize the Network first by calling "initialize()" function.') - - # initialize the forced data - if forced_feedbacks is None: forced_feedbacks = dict() - check_dict_data(forced_feedbacks, key_type=str, val_type=(bm.ndarray, jnp.ndarray)) - if forced_states is None: forced_states = dict() - check_dict_data(forced_states, key_type=str, val_type=(bm.ndarray, jnp.ndarray)) - # initialize the monitors - need_return_monitor = True - if monitors is None: - monitors = tuple() - need_return_monitor = False - attr_monitors: Dict[str, Tensor] = {} - state_monitors: Dict[str, Tensor] = {} - for key in monitors: - if not isinstance(key, str): - raise ValueError(f'"extra_returns" must be a sequence of string, ' - f'while we got {type(key)}') - splits = key.split('.') - if len(splits) != 2: - raise ValueError(f'Every term in "extra_returns" must be (node.item), ' - f'while we got {key}') - if splits[0] not in self.implicit_nodes: - raise ValueError(f'Cannot found the node {splits[0]}, this network ' - f'only has {list(self.implicit_nodes.keys())}.') - - if splits[1] not in NODE_STATES: # attribute monitor - if not hasattr(self.implicit_nodes[splits[0]], splits[1]): - raise UnsupportedError(f'Each node can monitor its states (including {NODE_STATES}), ' - f'or its attribute. While {splits[1]} is neither the state nor ' - f'the attribute of node {splits[0]}.') - else: - attr_monitors[key] = getattr(self.implicit_nodes[splits[0]], splits[1]) - else: # state monitor - if splits[1] == 'state': - assert self.implicit_nodes[splits[0]].state is not None, (f'{splits[0]} has no state, while ' - f'the user try to monitor it.') - state_monitors[key] = None - # calling the computation core - ff, fb = self._check_inputs(ff, fb=fb) - output, state_monitors = self.forward(ff, fb, forced_states, forced_feedbacks, state_monitors, **kwargs) - if need_return_monitor: - attr_monitors.update(state_monitors) - return output, attr_monitors - else: - return output - - def _call_a_node(self, node, ff, fb, monitors, forced_states, - parent_outputs, children_queue, ff_senders, - **shared_kwargs): - ff = node.data_pass.filter(ff) - if f'{node.name}.inputs' in monitors: - monitors[f'{node.name}.inputs'] = ff - # get the output results - if len(fb): - fb = node.data_pass.filter(fb) - if f'{node.name}.feedbacks' in monitors: - monitors[f'{node.name}.feedbacks'] = fb - parent_outputs[node] = node.forward(ff, fb, **shared_kwargs) - else: - parent_outputs[node] = node.forward(ff, **shared_kwargs) - # get the feedback state - if node in self.fb_receivers: - node.set_fb_output(node.feedback(parent_outputs[node], **shared_kwargs)) - # forced state - if node.name in forced_states: - node.state.value = forced_states[node.name] - # monitor the values - if f'{node.name}.state' in monitors: - monitors[f'{node.name}.state'] = node.state.value - if f'{node.name}.output' in monitors: - monitors[f'{node.name}.output'] = parent_outputs[node] - # append children nodes - for child in self.ff_receivers.get(node, []): - ff_senders[child].remove(node) - if len(ff_senders.get(child, [])) == 0: - children_queue.append(child) - - def forward(self, - ff, - fb=None, - forced_states: Dict[str, Tensor] = None, - forced_feedbacks: Dict[str, Tensor] = None, - monitors: Dict = None, - **shared_kwargs): - """The main computation function of a network. - - Parameters - ---------- - ff: dict, sequence - The feedforward inputs. - fb: optional, dict, sequence - The feedback inputs. - forced_states: optional, dict - The fixed state for the nodes in the network. - forced_feedbacks: optional, dict - The fixed feedback for the nodes in the network. - monitors: optional, sequence - Can be used to monitor the state or the attribute of a node in the network. - **shared_kwargs - Other parameters which will be parsed into every node. - - Returns - ------- - Tensor - A output tensor value, or a dict of output tensors. - """ - all_nodes = set([n.name for n in self.lnodes]) - runned_nodes = set() - output_nodes = set([n.name for n in self.exit_nodes]) - - # initialize the feedback - if forced_feedbacks is None: forced_feedbacks = dict() - if monitors is None: monitors = dict() - - # initialize the data - children_queue = [] - ff_senders, _ = find_senders_and_receivers(self.ff_edges) - - # initialize the parent output data - parent_outputs = {} - for i, node in enumerate(self._entry_nodes): - ff_ = {node.name: ff[i]} - fb_ = {p: (forced_feedbacks[p.name] if (p.name in forced_feedbacks) else p.fb_output) - for p in self.fb_senders.get(node, [])} - self._call_a_node(node, ff_, fb_, monitors, forced_states, - parent_outputs, children_queue, ff_senders, - **shared_kwargs) - runned_nodes.add(node.name) - - # run the model - while len(children_queue): - node = children_queue.pop(0) - # get feedforward and feedback inputs - ff = {p: parent_outputs[p] for p in self.ff_senders.get(node, [])} - fb = {p: (forced_feedbacks[p.name] if (p.name in forced_feedbacks) else p.fb_output) - for p in self.fb_senders.get(node, [])} - # call the node - self._call_a_node(node, ff, fb, monitors, forced_states, - parent_outputs, children_queue, ff_senders, - **shared_kwargs) - - # - remove unnecessary parent outputs - # - needed_parents = [] - runned_nodes.add(node.name) - for child in (all_nodes - runned_nodes): - for parent in self.ff_senders[self.implicit_nodes[child]]: - needed_parents.append(parent.name) - for parent in list(parent_outputs.keys()): - _name = parent.name - if _name not in needed_parents and _name not in output_nodes: - parent_outputs.pop(parent) - - # returns - if len(self.exit_nodes) > 1: - state = {n.name: parent_outputs[n] for n in self.exit_nodes} - else: - state = parent_outputs[self.exit_nodes[0]] - return state, monitors - - def plot_node_graph(self, - fig_size: tuple = (10, 10), - node_size: int = 1000, - arrow_size: int = 20, - layout='shell_layout', - show=True, - legends=None, - ax=None): - """Plot the node graph based on NetworkX package - - Parameters - ---------- - fig_size: tuple, default to (10, 10) - The size of the figure - - .. deprecated:: 2.1.9 - Please use ``ax`` variable. - - node_size: int - The size of the node. default to 1000 - arrow_size:int, default to 20 - The size of the arrow - layout: str - The graph layout. The supported layouts are: - - - "shell_layout" - - "multipartite_layout" - - "spring_layout" - - "spiral_layout" - - "spectral_layout" - - "random_layout" - - "planar_layout" - - "kamada_kawai_layout" - - "circular_layout" - """ - try: - import networkx as nx - except (ModuleNotFoundError, ImportError): - raise PackageMissingError('The node graph plotting currently need package "networkx". ' - 'But it can not be imported. ') - try: - import matplotlib.pyplot as plt - from matplotlib.lines import Line2D - except (ModuleNotFoundError, ImportError): - raise PackageMissingError('The node graph plotting currently need package "matplotlib". ' - 'But it can not be imported. ') - - nodes_trainable = [] - nodes_untrainable = [] - for node in self.lnodes: - if node.trainable: - nodes_trainable.append(node.name) - else: - nodes_untrainable.append(node.name) - - ff_edges = [] - fb_edges = [] - rec_edges = [] - for edge in self.ff_edges: - ff_edges.append((edge[0].name, edge[1].name)) - for edge in self.fb_edges: - fb_edges.append((edge[0].name, edge[1].name)) - for node in self.lnodes: - if isinstance(node, RecurrentNode): - rec_edges.append((node.name, node.name)) - - trainable_color = 'orange' - untrainable_color = 'skyblue' - ff_color = 'green' - fb_color = 'red' - rec_color = 'purple' - G = nx.DiGraph() - mid_nodes = list(set(self.lnodes) - set(self.entry_nodes) - set(self.exit_nodes)) - mid_nodes.sort(key=lambda x: x.name) - index = 0 - for node in list(self.entry_nodes) + mid_nodes + list(self.exit_nodes): - index = index + 1 - G.add_node(node.name, subset=index) - G.add_edges_from(ff_edges) - G.add_edges_from(fb_edges) - G.add_edges_from(rec_edges) - - if layout not in SUPPORTED_LAYOUTS: - raise UnsupportedError(f'Only support layouts: {SUPPORTED_LAYOUTS}') - layout = getattr(nx, layout)(G) - - if ax is None: - from brainpy.visualization.figures import get_figure - fig, gs = get_figure(1, 1, fig_size[1], fig_size[0]) - ax = fig.add_subplot(gs[0, 0]) - nx.draw_networkx_nodes(G, pos=layout, - nodelist=nodes_trainable, - node_color=trainable_color, - node_size=node_size, - ax=ax) - nx.draw_networkx_nodes(G, pos=layout, - nodelist=nodes_untrainable, - node_color=untrainable_color, - node_size=node_size) - - ff_conn_style = "arc3,rad=0." - nx.draw_networkx_edges(G, pos=layout, - edgelist=ff_edges, - edge_color=ff_color, - connectionstyle=ff_conn_style, - arrowsize=arrow_size, - node_size=node_size) - fb_conn_style = "arc3,rad=0.3" - nx.draw_networkx_edges(G, pos=layout, - edgelist=fb_edges, - edge_color=fb_color, - connectionstyle=fb_conn_style, - arrowsize=arrow_size, - node_size=node_size) - rec_conn_style = "arc3,rad=-0.3" - nx.draw_networkx_edges(G, pos=layout, - edgelist=rec_edges, - edge_color=rec_color, - arrowsize=arrow_size, - connectionstyle=rec_conn_style, - node_size=node_size, - node_shape='s') - - nx.draw_networkx_labels(G, pos=layout) - proxie = [] - labels = [] - if len(nodes_trainable): - proxie.append(Line2D([], [], color='white', marker='o', markerfacecolor=trainable_color)) - labels.append('Trainable') - if len(nodes_untrainable): - proxie.append(Line2D([], [], color='white', marker='o', markerfacecolor=untrainable_color)) - labels.append('Nontrainable') - if len(ff_edges): - proxie.append(Line2D([], [], color=ff_color, linewidth=2)) - labels.append('Feedforward') - if len(fb_edges): - proxie.append(Line2D([], [], color=fb_color, linewidth=2)) - labels.append('Feedback') - if len(rec_edges): - proxie.append(Line2D([], [], color=rec_color, linewidth=2)) - labels.append('Recurrent') - - legends = dict() if legends is None else legends - ax.legend(proxie, labels, scatterpoints=1, markerscale=2, loc='best', **legends) - if show: - plt.show() - - -class FrozenNetwork(Network): - """A FrozenNetwork is a Network that can not be linked to other nodes or networks.""" - - def update_graph(self, new_nodes, new_ff_edges, new_fb_edges=None): - raise TypeError(f"Cannot update FrozenModel {self}: " - f"model is frozen and cannot be modified.") - - def replace_graph(self, nodes, ff_edges, fb_edges=None): - raise TypeError(f"Cannot update FrozenModel {self}: " - f"model is frozen and cannot be modified.") - - -class Sequential(Network): - pass - -# def _process_params(G, center, dim): -# # Some boilerplate code. -# import numpy as np -# -# if not isinstance(G, nx.Graph): -# empty_graph = nx.Graph() -# empty_graph.add_nodes_from(G) -# G = empty_graph -# -# if center is None: -# center = np.zeros(dim) -# else: -# center = np.asarray(center) -# -# if len(center) != dim: -# msg = "length of center coordinates must match dimension of layout" -# raise ValueError(msg) -# -# return G, center -# -# -# def multipartite_layout(G, subset_key="subset", align="vertical", scale=1, center=None): -# import numpy as np -# -# if align not in ("vertical", "horizontal"): -# msg = "align must be either vertical or horizontal." -# raise ValueError(msg) -# -# G, center = _process_params(G, center=center, dim=2) -# if len(G) == 0: -# return {} -# -# layers = {} -# for v, data in G.nodes(data=True): -# try: -# layer = data[subset_key] -# except KeyError: -# msg = "all nodes must have subset_key (default='subset') as data" -# raise ValueError(msg) -# layers[layer] = [v] + layers.get(layer, []) -# -# pos = None -# nodes = [] -# -# width = len(layers) -# for i, layer in layers.items(): -# height = len(layer) -# xs = np.repeat(i, height) -# ys = np.arange(0, height, dtype=float) -# offset = ((width - 1) / 2, (height - 1) / 2) -# layer_pos = np.column_stack([xs, ys]) - offset -# if pos is None: -# pos = layer_pos -# else: -# pos = np.concatenate([pos, layer_pos]) -# nodes.extend(layer) -# pos = rescale_layout(pos, scale=scale) + center -# if align == "horizontal": -# pos = np.flip(pos, 1) -# pos = dict(zip(nodes, pos)) -# return pos -# -# -# def rescale_layout(pos, scale=1): -# """Returns scaled position array to (-scale, scale) in all axes. -# -# The function acts on NumPy arrays which hold position information. -# Each position is one row of the array. The dimension of the space -# equals the number of columns. Each coordinate in one column. -# -# To rescale, the mean (center) is subtracted from each axis separately. -# Then all values are scaled so that the largest magnitude value -# from all axes equals `scale` (thus, the aspect ratio is preserved). -# The resulting NumPy Array is returned (order of rows unchanged). -# -# Parameters -# ---------- -# pos : numpy array -# positions to be scaled. Each row is a position. -# -# scale : number (default: 1) -# The size of the resulting extent in all directions. -# -# Returns -# ------- -# pos : numpy array -# scaled positions. Each row is a position. -# -# See Also -# -------- -# rescale_layout_dict -# """ -# # Find max length over all dimensions -# lim = 0 # max coordinate for all axes -# for i in range(pos.shape[1]): -# pos[:, i] -= pos[:, i].mean() -# lim = max(abs(pos[:, i]).max(), lim) -# # rescale to (-scale, scale) in all directions, preserves aspect -# if lim > 0: -# for i in range(pos.shape[1]): -# pos[:, i] *= scale / lim -# return pos diff --git a/brainpy/compat/nn/datatypes.py b/brainpy/compat/nn/datatypes.py deleted file mode 100644 index ad025061f..000000000 --- a/brainpy/compat/nn/datatypes.py +++ /dev/null @@ -1,97 +0,0 @@ -# -*- coding: utf-8 -*- - - -__all__ = [ - # data types - 'DataType', - - # pass rules - 'SingleData', - 'MultipleData', -] - - -class DataType(object): - """Base class for data type.""" - - def filter(self, data): - raise NotImplementedError - - def __repr__(self): - return self.__class__.__name__ - - -class SingleData(DataType): - """Pass the only one data into the node. - If there are multiple data, an error will be raised. """ - - def filter(self, data): - if data is None: - return None - if len(data) > 1: - raise ValueError(f'{self.__class__.__name__} only support one ' - f'feedforward/feedback input. But we got {len(data)}.') - return tuple(data.values())[0] - - def __repr__(self): - return self.__class__.__name__ - - -class MultipleData(DataType): - """Pass a list/tuple of data into the node.""" - - def __init__(self, return_type: str = 'sequence'): - if return_type not in ['sequence', 'name_dict', 'type_dict', 'node_dict']: - raise ValueError(f"Only support return type of 'sequence', 'name_dict', " - f"'type_dict' and 'node_dict'. But we got {return_type}") - self.return_type = return_type - - from brainpy.compat.nn.base import Node - - if return_type == 'sequence': - f = lambda data: tuple(data.values()) - - elif return_type == 'name_dict': - # Pass a dict with into the node. - - def f(data): - _res = dict() - for node, val in data.items(): - if isinstance(node, str): - _res[node] = val - elif isinstance(node, Node): - _res[node.name] = val - else: - raise ValueError(f'Unknown type {type(node)}: node') - return _res - - elif return_type == 'type_dict': - # Pass a dict with into the node. - - def f(data): - _res = dict() - for node, val in data.items(): - if isinstance(node, str): - _res[str] = val - elif isinstance(node, Node): - _res[type(node.name)] = val - else: - raise ValueError(f'Unknown type {type(node)}: node') - return _res - - elif return_type == 'node_dict': - # Pass a dict with into the node. - f = lambda data: data - - else: - raise ValueError - self.return_func = f - - def __repr__(self): - return f'{self.__class__.__name__}(return_type={self.return_type})' - - def filter(self, data): - if data is None: - return None - else: - return self.return_func(data) diff --git a/brainpy/compat/nn/graph_flow.py b/brainpy/compat/nn/graph_flow.py deleted file mode 100644 index bd94a26ff..000000000 --- a/brainpy/compat/nn/graph_flow.py +++ /dev/null @@ -1,185 +0,0 @@ -# -*- coding: utf-8 -*- - - -""" -This module provides basic tool for graphs, including - -- detect the senders and receivers in the network graph, -- find input and output nodes in a given graph, -- detect the cycle in the graph, -- detect the path between two nodes. - -""" - - -from collections import deque, defaultdict - -__all__ = [ - 'find_senders_and_receivers', - 'find_entries_and_exits', - 'detect_cycle', - 'detect_path', -] - - -def find_senders_and_receivers(edges): - """Find all senders and receivers in the given graph.""" - senders = dict() # find parents according to the child - receivers = dict() # find children according to the parent - for edge in edges: - sender, receiver = edge - if receiver not in senders: - senders[receiver] = [sender] - else: - senders[receiver].append(sender) - if sender not in receivers: - receivers[sender] = [receiver] - else: - receivers[sender].append(receiver) - return senders, receivers - - -def find_entries_and_exits(nodes, ff_edges, fb_edges=()): - """Find input nodes and output nodes.""" - nodes = set(nodes) - ff_senders = set([n for n, _ in ff_edges]) - ff_receivers = set([n for _, n in ff_edges]) - fb_senders = set([n for n, _ in fb_edges]) - fb_receivers = set([n for _, n in fb_edges]) - - # # check lonely feedback nodes - # fb_receivers_without_ff = fb_receivers - ff_receivers - ff_senders - # if len(fb_receivers_without_ff) > 0: - # raise ValueError(f'Found feedback nodes do not define feedforward connections: \n\n' - # f'{fb_receivers_without_ff}') - - # check lonely nodes - lonely = nodes - ff_senders - ff_receivers - fb_senders - fb_receivers - # if len(lonely): - # _str_nodes = '\n'.join([str(node) for node in lonely]) - # raise ValueError(f"Found lonely nodes \n\n{_str_nodes} \n\n" - # f"which do not connect with any other.") - - # get input and output nodes - entry_points = (ff_senders | fb_senders) - ff_receivers - lonely - end_points = ff_receivers - ff_senders - lonely - return list(entry_points), list(end_points) - - -def topological_sort(nodes, ff_edges, inputs=None): - if inputs is None: - inputs, _ = find_entries_and_exits(nodes, ff_edges) - parents, children = find_senders_and_receivers(ff_edges) - # using Kahn's algorithm - ordered_nodes = [] - ff_edges = set(ff_edges) - inputs = deque(inputs) - while len(inputs) > 0: - n = inputs.pop() - ordered_nodes.append(n) - for m in children.get(n, ()): - ff_edges.remove((n, m)) - parents[m].remove(n) - if parents.get(m) is None or len(parents[m]) < 1: - inputs.append(m) - if len(ff_edges) > 0: - raise RuntimeError("Model has a cycle: impossible " - "to automatically determine operation " - "order in the model.") - else: - return ordered_nodes - - -def _detect_cycle(v, visited, stacks, graph): - # visited数组元素为true,标记该元素被isCyclicUtil递归调用链处理中,或处理过 - # recStack数组元素为true,表示该元素还在递归函数isCyclicUtil的函数栈中 - visited[v] = True - stacks[v] = True - # 深度遍历所有节点。 - for neighbour in graph[v]: - if not visited[neighbour]: # 如果该节点没有被处理过,那么继续调用递归 - if _detect_cycle(neighbour, visited, stacks, graph): # 如果邻接点neighbour的递归发现了环 - return True # 那么返回真 - elif stacks[neighbour]: # 如果neighbour被处理中(这里强调了不是处理过),且还在递归栈中,说明发现了环 - return True - stacks[v] = False # 函数开始时,V节点进栈。所以函数结束时,V节点出栈。 - return False # v的所有邻接点的递归都没有发现环,则返回假 - - -def detect_cycle(nodes, edges): - """Detect whether a cycle exists in the defined graph. - """ - node2id = {node: i for i, node in enumerate(nodes)} - graph = defaultdict(list) - for s, r in edges: - graph[node2id[s]].append(node2id[r]) - num = len(nodes) - - visited = [False] * num - stacks = [False] * num - for i in range(num): # 分别以每个节点作为起点,然后开始深度遍历 - if not visited[i]: # 这里为真,说明之前的深度遍历已经遍历过该节点了,且那次遍历没有发现环 - if _detect_cycle(i, visited, stacks, graph): # 如果发现环,直接返回 - return True - return False # 如果分别以每个节点作为起点的深度遍历都没有发现环,那肯定是整个图没有环 - - -def _has_path_by_dfs(from_node, to_node, graph): - # queue本质上是堆栈,用来存放需要进行遍历的数据 - # order里面存放的是具体的访问路径 - queue, order = [], [] - # 首先将初始遍历的节点放到queue中,表示将要从这个点开始遍历 - queue.append(from_node) - while len(queue): - # 从queue中pop出点v,然后从v点开始遍历了,所以可以将这个点pop出,然后将其放入order中 - # 这里才是最有用的地方,pop()表示弹出栈顶,由于下面的for循环不断的访问子节点,并将子节点压入堆栈, - # 也就保证了每次的栈顶弹出的顺序是下面的节点 - v = queue.pop() - order.append(v) - # 这里开始遍历v的子节点 - for w in graph[v]: - # w既不属于queue也不属于order,意味着这个点没被访问过,所以讲起放到queue中,然后后续进行访问 - if w not in order and w not in queue: - if w == to_node: - return True - else: - queue.append(w) - return False - - -def _has_path_by_bfs(from_node, to_node, graph): - # queue本质上是堆栈,用来存放需要进行遍历的数据 - # order里面存放的是具体的访问路径 - queue, order = [], [] - # 首先将初始遍历的节点放到queue中,表示将要从这个点开始遍历 - # 由于是广度优先,也就是先访问初始节点的所有的子节点,所以可以 - queue.append(from_node) - order.append(from_node) - while len(queue): - # queue.pop(0)意味着是队列的方式出元素,就是先进先出,而下面的for循环将节点v的所有子节点 - # 放到queue中,所以queue.pop(0)就实现了每次访问都是先将元素的子节点访问完毕,而不是优先叶子节点 - v = queue.pop(0) - for w in graph[v]: - if w not in order: - if w == to_node: - return True - else: - # 这里可以直接order.append(w) 因为广度优先就是先访问节点的所有下级子节点,所以可以 - # 将self.sequense[v]的值直接全部先给到order - order.append(w) - queue.append(w) - return False - - -def detect_path(from_node, to_node, edges, method='dfs'): - """Detect whether there is a path exist in the defined graph - from ``from_node`` to ``to_node``. """ - graph = defaultdict(list) - for s, r in edges: - graph[s].append(r) - if method == 'dfs': - return _has_path_by_dfs(from_node, to_node, graph) - elif method == 'bfs': - return _has_path_by_bfs(from_node, to_node, graph) - else: - raise ValueError(f'Unknown method {method}') diff --git a/brainpy/compat/nn/nodes/ANN/__init__.py b/brainpy/compat/nn/nodes/ANN/__init__.py deleted file mode 100644 index 389ca2d16..000000000 --- a/brainpy/compat/nn/nodes/ANN/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Artificial neural network (ANN) nodes""" - -from .conv import * -from .dropout import * -from .rnn_cells import * -from .pooling import * -from .normalization import * \ No newline at end of file diff --git a/brainpy/compat/nn/nodes/ANN/conv.py b/brainpy/compat/nn/nodes/ANN/conv.py deleted file mode 100644 index de170d0a4..000000000 --- a/brainpy/compat/nn/nodes/ANN/conv.py +++ /dev/null @@ -1,204 +0,0 @@ -# -*- coding: utf-8 -*- - - -import jax.lax -import brainpy.math as bm -from brainpy.initialize import XavierNormal, ZeroInit, init_param -from brainpy.compat.nn.base import Node - -__all__ = [ - 'GeneralConv', - 'Conv1D', - 'Conv2D', - 'Conv3D' -] - - -def _check_tuple(v): - if isinstance(v, (tuple, list)): - return tuple(v) - elif isinstance(v, int): - return (v, v) - else: - raise ValueError - - -def _conv_dimension_numbers(input_shape): - """Computes the dimension numbers based on the input shape.""" - ndim = len(input_shape) - lhs_spec = (0, ndim - 1) + tuple(range(1, ndim - 1)) - rhs_spec = (ndim - 1, ndim - 2) + tuple(range(0, ndim - 2)) - out_spec = lhs_spec - return jax.lax.ConvDimensionNumbers(lhs_spec, rhs_spec, out_spec) - - -class GeneralConv(Node): - """Applies a convolution to the inputs. - - Args: - out_channels: integer - number of output channels. - kernel_size: sequence[int] - shape of the convolutional kernel. For 1D convolution, - the kernel size can be passed as an integer. For all other cases, it must - be a sequence of integers. - strides: sequence[int] - an integer or a sequence of `n` integers, representing the inter-window strides (default: 1). - padding: str, sequence[int] - either the string `'SAME'`, the string `'VALID'`, the string - `'CIRCULAR'` (periodic boundary conditions), or a sequence of `n` `(low, - high)` integer pairs that give the padding to apply before and after each - spatial dimension. A single int is interpeted as applying the same padding - in all dims and passign a single int in a sequence causes the same padding - to be used on both sides. - input_dilation: integer, sequence[int] - an integer or a sequence of `n` integers, giving the - dilation factor to apply in each spatial dimension of `inputs` - (default: 1). Convolution with input dilation `d` is equivalent to - transposed convolution with stride `d`. - kernel_dilation: integer, sequence[int] - an integer or a sequence of `n` integers, giving the - dilation factor to apply in each spatial dimension of the convolution - kernel (default: 1). Convolution with kernel dilation - is also known as 'atrous convolution'. - groups: integer, default 1. - If specified divides the input - features into groups. - kernel_init: brainpy.init.Initializer - initializer for the convolutional kernel. - bias_init: brainpy.init.Initializer - initializer for the bias. - """ - - def __init__(self, out_channels, kernel_size, strides=None, padding='SAME', - input_dilation=None, kernel_dilation=None, groups=1, - w_init=XavierNormal(), b_init=ZeroInit(), **kwargs): - super(GeneralConv, self).__init__(**kwargs) - - self.out_channels = out_channels - self.kernel_size = kernel_size - self.strides = strides - self.padding = padding - self.input_dilation = input_dilation - self.kernel_dilation = kernel_dilation - self.groups = groups - self.w_init = w_init - self.b_init = b_init - self.dimension_numbers = None - self.trainable = True - - if isinstance(padding, str): - assert padding in ['SAME', 'VALID'] - elif isinstance(padding, tuple): - for k in padding: - assert isinstance(k, int) - else: - raise ValueError - - assert out_channels % self.groups == 0, '"nout" should be divisible by groups' - - def _check_input_dim(self): - pass - - def init_ff_conn(self): - input_shapes = self.feedforward_shapes - in_channels = int(input_shapes[-1]) - assert in_channels % self.groups == 0, '"nin" should be divisible by groups' - kernel_shape = _check_tuple(self.kernel_size) + (in_channels // self.groups, self.out_channels) - self.w = init_param(self.w_init, kernel_shape) - self.b = init_param(self.b_init, (1,) * len(self.kernel_size) + (self.out_channels,)) - if self.trainable: - self.w = bm.TrainVar(self.w) - self.b = bm.TrainVar(self.b) - - if self.strides is None: - self.strides = (1,) * (len(input_shapes) - 2) - - output_shapes = jax.lax.conv_transpose_shape_tuple( - input_shapes, kernel_shape, self.strides, self.padding, dimension_numbers=self.dimension_numbers) - self.set_output_shape(output_shapes) - - def init_fb_conn(self): - fb_input_shapes = self.feedback_shapes - ff_input_shapes = self.feedforward_shapes - ff_spatial_axes = ff_input_shapes[1:-1] # only first (batch) and last (channel) dimension are not spatial dims - fb_spatial_axes = fb_input_shapes[1:-1] - assert ff_spatial_axes == fb_spatial_axes, f"Feedback input spatial dimensions {fb_spatial_axes} are not aligned " \ - f"with feedforward input spatial dimensions {ff_spatial_axes}. " - - in_channels = int(ff_input_shapes[-1] + fb_input_shapes[-1]) - assert in_channels % self.groups == 0, '"nin" should be divisible by groups' - kernel_shape = _check_tuple(self.kernel_size) + (in_channels // self.groups, self.out_channels) - self.w = init_param(self.w_init, kernel_shape) - self.b = init_param(self.b_init, (1,) * len(self.kernel_size) + (self.out_channels,)) - if self.trainable: - self.w = bm.TrainVar(self.w) - self.b = bm.TrainVar(self.b) - - if self.strides is None: - self.strides = (1,) * (len(ff_input_shapes) - 2) - - def forward(self, ff, fb=None, **shared_kwargs): - if fb is not None: - data = bm.concatenate((ff, fb), axis=-1) - else: - data = ff - y = jax.lax.conv_general_dilated(lhs=data.value if isinstance(data, bm.JaxArray) else ff, - rhs=self.w.value, - window_strides=self.strides, - padding=self.padding, - lhs_dilation=self.input_dilation, - rhs_dilation=self.kernel_dilation, - feature_group_count=self.groups, - dimension_numbers=self.dimension_numbers) - if self.b is None: - return y - return y + self.b.value - - -class Conv1D(GeneralConv): - def __init__(self, out_channels, kernel_size, **kwargs): - super(Conv1D, self).__init__(out_channels, kernel_size, **kwargs) - - self.dimension_numbers = ('NWC', 'WIO', 'NWC') - - def _check_input_dim(self): - ndim = len(self.feedforward_shapes) - if ndim != 3: - raise ValueError( - "expected 3D input (got {}D input)".format(ndim) - ) - - assert len(self.kernel_size) == 1, "expected 1D kernel size (got {}D input)".format(self.kernel_size) - - -class Conv2D(GeneralConv): - def __init__(self, out_channels, kernel_size, **kwargs): - super(Conv2D, self).__init__(out_channels, kernel_size, **kwargs) - - self.dimension_numbers = ('NHWC', 'HWIO', 'NHWC') - - def _check_input_dim(self): - ndim = len(self.feedforward_shapes) - if ndim != 4: - raise ValueError( - "expected 4D input (got {}D input)".format(ndim) - ) - - assert len(self.kernel_size) == 2, "expected 2D kernel size (got {}D input)".format(self.kernel_size) - - -class Conv3D(GeneralConv): - def __init__(self, out_channels, kernel_size, **kwargs): - super(Conv3D, self).__init__(out_channels, kernel_size, **kwargs) - - self.dimension_numbers = ('NHWDC', 'HWDIO', 'NHWDC') - - def _check_input_dim(self): - ndim = len(self.feedforward_shapes) - if ndim != 5: - raise ValueError( - "expected 5D input (got {}D input)".format(ndim) - ) - - assert len(self.kernel_size) == 3, "expected 3D kernel size (got {}D input)".format(self.kernel_size) diff --git a/brainpy/compat/nn/nodes/ANN/dropout.py b/brainpy/compat/nn/nodes/ANN/dropout.py deleted file mode 100644 index 65eabe658..000000000 --- a/brainpy/compat/nn/nodes/ANN/dropout.py +++ /dev/null @@ -1,52 +0,0 @@ -# -*- coding: utf-8 -*- - -import brainpy.math as bm -from brainpy.compat.nn.base import Node - -__all__ = [ - 'Dropout' -] - - -class Dropout(Node): - """A layer that stochastically ignores a subset of inputs each training step. - - In training, to compensate for the fraction of input values dropped (`rate`), - all surviving values are multiplied by `1 / (1 - rate)`. - - The parameter `shared_axes` allows to specify a list of axes on which - the mask will be shared: we will use size 1 on those axes for dropout mask - and broadcast it. Sharing reduces randomness, but can save memory. - - This layer is active only during training (`mode='train'`). In other - circumstances it is a no-op. - - Parameters - ---------- - prob : float - Probability to keep element of the tensor. - seed : optional, int - The random sampling seed. - name : str, optional - The name of the dynamic system. - - References - ---------- - .. [1] Srivastava, Nitish, et al. "Dropout: a simple way to prevent - neural networks from overfitting." The journal of machine learning - research 15.1 (2014): 1929-1958. - """ - def __init__(self, prob, seed=None, **kwargs): - super(Dropout, self).__init__(**kwargs) - self.prob = prob - self.rng = bm.random.RandomState(seed=seed) - - def init_ff_conn(self): - self.set_output_shape(self.feedforward_shapes) - - def forward(self, ff, **shared_kwargs): - if shared_kwargs.get('train', True): - keep_mask = self.rng.bernoulli(self.prob, ff.shape) - return bm.where(keep_mask, ff / self.prob, 0.) - else: - return ff diff --git a/brainpy/compat/nn/nodes/ANN/normalization.py b/brainpy/compat/nn/nodes/ANN/normalization.py deleted file mode 100644 index 9ea25095d..000000000 --- a/brainpy/compat/nn/nodes/ANN/normalization.py +++ /dev/null @@ -1,382 +0,0 @@ -# -*- coding: utf-8 -*- - -from typing import Union - -import jax.numpy as jnp - -import brainpy.math as bm -from brainpy.compat.nn.base import Node -from brainpy.initialize import ZeroInit, OneInit, Initializer - -__all__ = [ - 'BatchNorm', - 'BatchNorm1d', - 'BatchNorm2d', - 'BatchNorm3d', - 'GroupNorm', - 'LayerNorm', - 'InstanceNorm', -] - - -class BatchNorm(Node): - """Batch Normalization node. - This layer aims to reduce the internal covariant shift of data. It - normalizes a batch of data by fixing the mean and variance of inputs - on each feature (channel). Most commonly, the first axis of the data - is the batch, and the last is the channel. However, users can specify - the axes to be normalized. - - adapted from jax.example_libraries.stax.BatchNorm - https://jax.readthedocs.io/en/latest/_modules/jax/example_libraries/stax.html#BatchNorm - - Parameters - ---------- - axis: int, tuple, list - axes where the data will be normalized. The feature (channel) axis should be excluded. - epsilon: float - a value added to the denominator for numerical stability. Default: 1e-5 - use_bias: bool - whether to translate data in refactoring. Default: True - use_scale: bool - whether to scale data in refactoring. Default: True - beta_init: brainpy.init.Initializer - an initializer generating the original translation matrix - gamma_init: brainpy.init.Initializer - an initializer generating the original scaling matrix - """ - - def __init__(self, - axis: Union[int, tuple, list], - epsilon: float = 1e-5, - use_bias: bool = True, - use_scale: bool = True, - beta_init: Initializer = ZeroInit(), - gamma_init: Initializer = OneInit(), - **kwargs): - super(BatchNorm, self).__init__(**kwargs) - self.epsilon = epsilon - self.bias = use_bias - self.scale = use_scale - self.beta_init = beta_init if use_bias else () - self.gamma_init = gamma_init if use_scale else () - self.axis = (axis,) if jnp.isscalar(axis) else axis - - def _check_input_dim(self): - pass - - def init_ff_conn(self): - self._check_input_dim() - - input_shape = tuple(d for i, d in enumerate(self.feedforward_shapes) if i not in self.axis) - self.beta = bm.TrainVar(self.beta_init(input_shape)) if self.bias else None - self.gamma = bm.TrainVar(self.gamma_init(input_shape)) if self.scale else None - self.set_output_shape(self.feedforward_shapes) - - def forward(self, ff, **shared_kwargs): - ed = tuple(None if i in self.axis else slice(None) for i in range(jnp.ndim(ff))) - output = bm.normalize(ff, self.axis, epsilon=self.epsilon) - if self.bias and self.scale: return self.gamma[ed] * output + self.beta[ed] - if self.bias: return output + self.beta[ed] - if self.scale: return self.gamma[ed] * output - return output - - -class BatchNorm1d(BatchNorm): - """1-D batch normalization. - The data should be of `(b, l, c)`, where `b` is the batch dimension, - `l` is the layer dimension, and `c` is the channel dimension, or of - '(b, c)'. - - Parameters - ---------- - axis: int, tuple, list - axes where the data will be normalized. The feature (channel) axis should be excluded. - epsilon: float - a value added to the denominator for numerical stability. Default: 1e-5 - use_bias: bool - whether to translate data in refactoring. Default: True - use_scale: bool - whether to scale data in refactoring. Default: True - beta_init: brainpy.init.Initializer - an initializer generating the original translation matrix - gamma_init: brainpy.init.Initializer - an initializer generating the original scaling matrix - """ - def __init__(self, axis=(0, 1), **kwargs): - super(BatchNorm1d, self).__init__(axis=axis, **kwargs) - - def _check_input_dim(self): - ndim = len(self.feedforward_shapes) - if ndim != 2 and ndim != 3: - raise ValueError( - "expected 2D or 3D input (got {}D input)".format(ndim) - ) - if ndim == 2 and len(self.axis) == 2: - self.axis = (0,) - - -class BatchNorm2d(BatchNorm): - """2-D batch normalization. - The data should be of `(b, h, w, c)`, where `b` is the batch dimension, - `h` is the height dimension, `w` is the width dimension, and `c` is the - channel dimension. - - Parameters - ---------- - axis: int, tuple, list - axes where the data will be normalized. The feature (channel) axis should be excluded. - epsilon: float - a value added to the denominator for numerical stability. Default: 1e-5 - use_bias: bool - whether to translate data in refactoring. Default: True - use_scale: bool - whether to scale data in refactoring. Default: True - beta_init: brainpy.init.Initializer - an initializer generating the original translation matrix - gamma_init: brainpy.init.Initializer - an initializer generating the original scaling matrix - """ - def __init__(self, axis=(0, 1, 2), **kwargs): - super(BatchNorm2d, self).__init__(axis=axis, **kwargs) - - def _check_input_dim(self): - ndim = len(self.feedforward_shapes) - if ndim != 4: - raise ValueError( - "expected 4D input (got {}D input)".format(ndim) - ) - - -class BatchNorm3d(BatchNorm): - """3-D batch normalization. - The data should be of `(b, h, w, d, c)`, where `b` is the batch dimension, - `h` is the height dimension, `w` is the width dimension, `d` is the depth - dimension, and `c` is the channel dimension. - - Parameters - ---------- - axis: int, tuple, list - axes where the data will be normalized. The feature (channel) axis should be excluded. - epsilon: float - a value added to the denominator for numerical stability. Default: 1e-5 - use_bias: bool - whether to translate data in refactoring. Default: True - use_scale: bool - whether to scale data in refactoring. Default: True - beta_init: brainpy.init.Initializer - an initializer generating the original translation matrix - gamma_init: brainpy.init.Initializer - an initializer generating the original scaling matrix - """ - def __init__(self, axis=(0, 1, 2, 3), **kwargs): - super(BatchNorm3d, self).__init__(axis=axis, **kwargs) - - def _check_input_dim(self): - ndim = len(self.feedforward_shapes) - if ndim != 5: - raise ValueError( - "expected 5D input (got {}D input)".format(ndim) - ) - - -class LayerNorm(Node): - """Layer normalization (https://arxiv.org/abs/1607.06450). - - This layer normalizes data on each example, independently of the batch. More - specifically, it normalizes data of shape (b, d1, d2, ..., c) on the axes of - the data dimensions and the channel (d1, d2, ..., c). Different from batch - normalization, gamma and beta are assigned to each position (elementwise - operation) instead of the whole channel. If users want to assign a single - gamma and beta to a whole example/whole channel, please use GroupNorm/ - InstanceNorm. - - Parameters - ---------- - epsilon: float - a value added to the denominator for numerical stability. Default: 1e-5 - use_bias: bool - whether to translate data in refactoring. Default: True - use_scale: bool - whether to scale data in refactoring. Default: True - beta_init: brainpy.init.Initializer - an initializer generating the original translation matrix - gamma_init: brainpy.init.Initializer - an initializer generating the original scaling matrix - axis: int, tuple, list - axes where the data will be normalized. The batch axis should be excluded. - """ - def __init__(self, - epsilon: float = 1e-5, - use_bias: bool = True, - use_scale: bool = True, - beta_init: Initializer = ZeroInit(), - gamma_init: Initializer = OneInit(), - axis: Union[int, tuple] = None, - **kwargs): - super(LayerNorm, self).__init__(**kwargs) - self.epsilon = epsilon - self.bias = use_bias - self.scale = use_scale - self.beta_init = beta_init if use_bias else () - self.gamma_init = gamma_init if use_scale else () - self.axis = (axis,) if jnp.isscalar(axis) else axis - - def default_axis(self): - # default: the first axis (batch dim) is excluded - return tuple(i for i in range(1, len(self.feedforward_shapes))) - - def init_ff_conn(self): - if self.axis is None: - self.axis = self.default_axis() - # todo: what if elementwise_affine = False? - input_shape = tuple(d for i, d in enumerate(self.feedforward_shapes) if i in self.axis) - self.beta = bm.TrainVar(self.beta_init(input_shape)) if self.bias else None - self.gamma = bm.TrainVar(self.gamma_init(input_shape)) if self.scale else None - self.set_output_shape(self.feedforward_shapes) - - def forward(self, ff, **shared_kwargs): - ed = tuple(None if i not in self.axis else slice(None) for i in range(jnp.ndim(ff))) - output = bm.normalize(ff, self.axis, epsilon=self.epsilon) - if self.bias and self.scale: return self.gamma[ed] * output + self.beta[ed] - if self.bias: return output + self.beta[ed] - if self.scale: return self.gamma[ed] * output - return output - - -class GroupNorm(Node): - """Group normalization layer. - - This layer divides channels into groups and normalizes the features within each - group. Its computation is also independent of the batch size. The feature size - must be multiple of the group size. - - The shape of the data should be (b, d1, d2, ..., c), where `d` denotes the batch - size and `c` denotes the feature (channel) size. The `d` and `c` axis should be - excluded in parameter `axis`. - - Parameters - ---------- - num_groups: int - the number of groups. It should be a factor of the number of features. - group_size: int - the group size. It should equal to int(num_features / num_groups). - Either `num_groups` or `group_size` should be specified. - epsilon: float - a value added to the denominator for numerical stability. Default: 1e-5 - use_bias: bool - whether to translate data in refactoring. Default: True - use_scale: bool - whether to scale data in refactoring. Default: True - beta_init: brainpy.init.Initializer - an initializer generating the original translation matrix - gamma_init: brainpy.init.Initializer - an initializer generating the original scaling matrix - axis: int, tuple, list - axes where the data will be normalized. Besides the batch axis, the channel - axis should be also excluded, since it will be automatically added to `axis`. - """ - def __init__(self, - num_groups: int = None, - group_size: int = None, - epsilon: float = 1e-5, - use_bias: bool = True, - use_scale: bool = True, - beta_init: Initializer = ZeroInit(), - gamma_init: Initializer = OneInit(), - axis: Union[int, tuple] = None, - **kwargs): - super(GroupNorm, self).__init__(**kwargs) - self.num_groups = num_groups - self.group_size = group_size - self.epsilon = epsilon - self.bias = use_bias - self.scale = use_scale - self.beta_init = beta_init if use_bias else () - self.gamma_init = gamma_init if use_scale else () - self.norm_axis = (axis,) if jnp.isscalar(axis) else axis - - def init_ff_conn(self): - num_channels = self.feedforward_shapes[-1] - self.ndim = len(self.feedforward_shapes) - - # compute num_groups and group_size - if ((self.num_groups is None and self.group_size is None) or - (self.num_groups is not None and self.group_size is not None)): - raise ValueError('Either `num_groups` or `group_size` should be specified. ' - 'Once one is specified, the other will be automatically ' - 'computed.') - - if self.num_groups is None: - assert self.group_size > 0, '`group_size` should be a positive integer.' - if num_channels % self.group_size != 0: - raise ValueError('The number of channels ({}) is not multiple of the ' - 'group size ({}).'.format(num_channels, self.group_size)) - else: - self.num_groups = num_channels // self.group_size - else: # self.num_groups is not None: - assert self.num_groups > 0, '`num_groups` should be a positive integer.' - if num_channels % self.num_groups != 0: - raise ValueError('The number of channels ({}) is not multiple of the ' - 'number of groups ({}).'.format(num_channels, self.num_groups)) - else: - self.group_size = num_channels // self.num_groups - - # axes for normalization - if self.norm_axis is None: - # default: the first axis (batch dim) and the second-last axis (num_group dim) are excluded - self.norm_axis = tuple(i for i in range(1, len(self.feedforward_shapes) - 1)) + (self.ndim,) - - group_shape = self.feedforward_shapes[:-1] + (self.num_groups, self.group_size) - input_shape = tuple(d for i, d in enumerate(group_shape) if i in self.norm_axis) - self.beta = bm.TrainVar(self.beta_init(input_shape)) if self.bias else None - self.gamma = bm.TrainVar(self.gamma_init(input_shape)) if self.scale else None - self.set_output_shape(self.feedforward_shapes) - - def forward(self, ff, **shared_kwargs): - group_shape = ff.shape[:-1] + (self.num_groups, self.group_size) - ff_reshape = ff.reshape(group_shape) - ed = tuple(None if i not in self.norm_axis else slice(None) for i in range(jnp.ndim(ff_reshape))) - output = bm.normalize(ff_reshape, self.norm_axis, epsilon=self.epsilon) - if self.bias and self.scale: - output = self.gamma[ed] * output + self.beta[ed] - elif self.bias: - output = output + self.beta[ed] - elif self.scale: - output = self.gamma[ed] * output - return output.reshape(ff.shape) - - -class InstanceNorm(GroupNorm): - """Instance normalization layer. - - This layer normalizes the data within each feature. It can be regarded as - a group normalization layer in which `group_size` equals to 1. - - Parameters - ---------- - epsilon: float - a value added to the denominator for numerical stability. Default: 1e-5 - use_bias: bool - whether to translate data in refactoring. Default: True - use_scale: bool - whether to scale data in refactoring. Default: True - beta_init: brainpy.init.Initializer - an initializer generating the original translation matrix - gamma_init: brainpy.init.Initializer - an initializer generating the original scaling matrix - axis: int, tuple, list - axes where the data will be normalized. The batch and channel axes - should be excluded. - """ - def __init__(self, - epsilon: float = 1e-5, - use_bias: bool = True, - use_scale: bool = True, - beta_init: Initializer = ZeroInit(), - gamma_init: Initializer = OneInit(), - axis: Union[int, tuple] = None, - **kwargs): - super(InstanceNorm, self).__init__(group_size=1, epsilon=epsilon, use_bias=use_bias, - use_scale=use_scale, beta_init=beta_init, - gamma_init=gamma_init, axis=axis, **kwargs) diff --git a/brainpy/compat/nn/nodes/ANN/pooling.py b/brainpy/compat/nn/nodes/ANN/pooling.py deleted file mode 100644 index aa4391944..000000000 --- a/brainpy/compat/nn/nodes/ANN/pooling.py +++ /dev/null @@ -1,157 +0,0 @@ -# -*- coding: utf-8 -*- - - -import jax.lax -import brainpy.math as bm -from brainpy.compat.nn.base import Node - -__all__ = [ - 'Pool', - 'MaxPool', - 'AvgPool', - 'MinPool' -] - - -class Pool(Node): - def __init__(self, init_v, reduce_fn, window_shape, strides, padding, **kwargs): - """Pooling functions are implemented using the ReduceWindow XLA op. - - Args: - init_v: scalar - the initial value for the reduction - reduce_fn: callable - a reduce function of the form `(T, T) -> T`. - window_shape: tuple - a shape tuple defining the window to reduce over. - strides: sequence[int] - a sequence of `n` integers, representing the inter-window strides. - padding: str, sequence[int] - either the string `'SAME'`, the string `'VALID'`, or a sequence - of `n` `(low, high)` integer pairs that give the padding to apply before - and after each spatial dimension. - - Returns: - The output of the reduction for each window slice. - """ - super(Pool, self).__init__(**kwargs) - self.init_v = init_v - self.reduce_fn = reduce_fn - self.window_shape = window_shape - self.strides = strides or (1,) * len(window_shape) - assert len(self.window_shape) == len(self.strides), ( - f"len({self.window_shape}) must equal len({self.strides})") - self.strides = (1,) + self.strides + (1,) - self.dims = (1,) + window_shape + (1,) - self.is_single_input = False - - if not isinstance(padding, str): - padding = tuple(map(tuple, padding)) - assert len(padding) == len(window_shape), ( - f"padding {padding} must specify pads for same number of dims as " - f"window_shape {window_shape}") - assert all([len(x) == 2 for x in padding]), ( - f"each entry in padding {padding} must be length 2") - padding = ((0, 0),) + padding + ((0, 0),) - self.padding = padding - - def init_ff_conn(self): - input_shapes = tuple((0,)) + tuple(d for d in self.feedforward_shapes if d is not None) - assert len(input_shapes) == len(self.dims), f"len({len(input_shapes)}) != len({self.dims})" - - padding_vals = jax.lax.padtype_to_pads(input_shapes, self.dims, self.strides, self.padding) - ones = (1,) * len(self.dims) - out_shapes = jax.lax.reduce_window_shape_tuple( - input_shapes, self.dims, self.strides, padding_vals, ones, ones) - - out_shapes = tuple((None,)) + tuple(d for i, d in enumerate(out_shapes) if i != 0) - self.set_output_shape(out_shapes) - - def forward(self, ff, fb=None, **shared_kwargs): - y = jax.lax.reduce_window(ff, self.init_v, self.reduce_fn, self.dims, self.strides, self.padding) - - return y - - -class AvgPool(Pool): - """Pools the input by taking the average over a window. - - Args: - window_shape: tuple - a shape tuple defining the window to reduce over. - strides: sequence[int] - a sequence of `n` integers, representing the inter-window strides (default: `(1, ..., 1)`). - padding: str, sequence[int] - either the string `'SAME'`, the string `'VALID'`, or a sequence - of `n` `(low, high)` integer pairs that give the padding to apply before - and after each spatial dimension (default: `'VALID'`). - - Returns: - The average for each window slice. - """ - - def __init__(self, window_shape, strides=None, padding="VALID"): - super(AvgPool, self).__init__( - init_v=0., - reduce_fn=jax.lax.add, - window_shape=window_shape, - strides=strides, - padding=padding - ) - - def forward(self, ff, fb=None, **shared_kwargs): - y = jax.lax.reduce_window(ff, self.init_v, self.reduce_fn, self.dims, self.strides, self.padding) - y = y / bm.prod(bm.asarray(self.window_shape)) - return y - - -class MaxPool(Pool): - """Pools the input by taking the maximum over a window. - - Args: - window_shape: tuple - a shape tuple defining the window to reduce over. - strides: sequence[int] - a sequence of `n` integers, representing the inter-window strides (default: `(1, ..., 1)`). - padding: str, sequence[int] - either the string `'SAME'`, the string `'VALID'`, or a sequence - of `n` `(low, high)` integer pairs that give the padding to apply before - and after each spatial dimension (default: `'VALID'`). - - Returns: - The maximum for each window slice. - """ - def __init__(self, window_shape, strides=None, padding="VALID"): - super(MaxPool, self).__init__( - init_v=-bm.inf, - reduce_fn=jax.lax.max, - window_shape=window_shape, - strides=strides, - padding=padding - ) - - -class MinPool(Pool): - """Pools the input by taking the minimum over a window. - - Args: - window_shape: tuple - a shape tuple defining the window to reduce over. - strides: sequence[int] - a sequence of `n` integers, representing the inter-window strides (default: `(1, ..., 1)`). - padding: str, sequence[int] - either the string `'SAME'`, the string `'VALID'`, or a sequence - of `n` `(low, high)` integer pairs that give the padding to apply before - and after each spatial dimension (default: `'VALID'`). - - Returns: - The minimum for each window slice. - """ - def __init__(self, window_shape, strides=None, padding="VALID"): - super(MinPool, self).__init__( - init_v=bm.inf, - reduce_fn=jax.lax.min, - window_shape=window_shape, - strides=strides, - padding=padding - ) \ No newline at end of file diff --git a/brainpy/compat/nn/nodes/ANN/rnn_cells.py b/brainpy/compat/nn/nodes/ANN/rnn_cells.py deleted file mode 100644 index 79ab50c7b..000000000 --- a/brainpy/compat/nn/nodes/ANN/rnn_cells.py +++ /dev/null @@ -1,410 +0,0 @@ -# -*- coding: utf-8 -*- - - -from typing import Union, Callable - -import brainpy.math as bm -from brainpy.initialize import (XavierNormal, - ZeroInit, - Uniform, - Orthogonal, - init_param, - Initializer) -from brainpy.compat.nn.base import RecurrentNode -from brainpy.compat.nn.datatypes import MultipleData -from brainpy.tools.checking import (check_integer, - check_initializer, - check_shape_consistency) -from brainpy.types import Tensor - -__all__ = [ - 'VanillaRNN', - 'GRU', - 'LSTM', -] - - -class VanillaRNN(RecurrentNode): - r"""Basic fully-connected RNN core. - - Given :math:`x_t` and the previous hidden state :math:`h_{t-1}` the - core computes - - .. math:: - - h_t = \mathrm{ReLU}(w_i x_t + b_i + w_h h_{t-1} + b_h) - - The output is equal to the new state, :math:`h_t`. - - - Parameters - ---------- - num_unit: int - The number of hidden unit in the node. - state_initializer: callable, Initializer, bm.ndarray, jax.numpy.ndarray - The state initializer. - wi_initializer: callable, Initializer, bm.ndarray, jax.numpy.ndarray - The input weight initializer. - wh_initializer: callable, Initializer, bm.ndarray, jax.numpy.ndarray - The hidden weight initializer. - bias_initializer: optional, callable, Initializer, bm.ndarray, jax.numpy.ndarray - The bias weight initializer. - activation: str, callable - The activation function. It can be a string or a callable function. - See ``brainpy.math.activations`` for more details. - trainable: bool - Whether set the node is trainable. - - """ - data_pass = MultipleData('sequence') - - def __init__( - self, - num_unit: int, - state_initializer: Union[Tensor, Callable, Initializer] = Uniform(), - wi_initializer: Union[Tensor, Callable, Initializer] = XavierNormal(), - wh_initializer: Union[Tensor, Callable, Initializer] = XavierNormal(), - bias_initializer: Union[Tensor, Callable, Initializer] = ZeroInit(), - activation: str = 'relu', - **kwargs - ): - super(VanillaRNN, self).__init__(**kwargs) - - self.num_unit = num_unit - check_integer(num_unit, 'num_unit', min_bound=1, allow_none=False) - self.set_output_shape((None, self.num_unit)) - - # initializers - self._state_initializer = state_initializer - self._wi_initializer = wi_initializer - self._wh_initializer = wh_initializer - self._bias_initializer = bias_initializer - check_initializer(wi_initializer, 'wi_initializer', allow_none=False) - check_initializer(wh_initializer, 'wh_initializer', allow_none=False) - check_initializer(state_initializer, 'state_initializer', allow_none=False) - check_initializer(bias_initializer, 'bias_initializer', allow_none=True) - - # activation function - self.activation = bm.activations.get(activation) - - def init_ff_conn(self): - unique_size, free_sizes = check_shape_consistency(self.feedforward_shapes, -1, True) - assert len(unique_size) == 1, 'Only support data with or without batch size.' - # weights - num_input = sum(free_sizes) - self.Wff = init_param(self._wi_initializer, (num_input, self.num_unit)) - self.Wrec = init_param(self._wh_initializer, (self.num_unit, self.num_unit)) - self.bias = init_param(self._bias_initializer, (self.num_unit,)) - if self.trainable: - self.Wff = bm.TrainVar(self.Wff) - self.Wrec = bm.TrainVar(self.Wrec) - self.bias = None if (self.bias is None) else bm.TrainVar(self.bias) - - def init_fb_conn(self): - unique_size, free_sizes = check_shape_consistency(self.feedback_shapes, -1, True) - assert len(unique_size) == 1, 'Only support data with or without batch size.' - num_feedback = sum(free_sizes) - # weights - self.Wfb = init_param(self._wi_initializer, (num_feedback, self.num_unit)) - if self.trainable: - self.Wfb = bm.TrainVar(self.Wfb) - - def init_state(self, num_batch=1): - return init_param(self._state_initializer, (num_batch, self.num_unit)) - - def forward(self, ff, fb=None, **shared_kwargs): - ff = bm.concatenate(ff, axis=-1) - h = ff @ self.Wff - h += self.state.value @ self.Wrec - if self.bias is not None: - h += self.bias - if fb is not None: - fb = bm.concatenate(fb, axis=-1) - h += fb @ self.Wfb - self.state.value = self.activation(h) - return self.state.value - - -class GRU(RecurrentNode): - r"""Gated Recurrent Unit. - - The implementation is based on (Chung, et al., 2014) [1]_ with biases. - - Given :math:`x_t` and the previous state :math:`h_{t-1}` the core computes - - .. math:: - - \begin{array}{ll} - z_t &= \sigma(W_{iz} x_t + W_{hz} h_{t-1} + b_z) \\ - r_t &= \sigma(W_{ir} x_t + W_{hr} h_{t-1} + b_r) \\ - a_t &= \tanh(W_{ia} x_t + W_{ha} (r_t \bigodot h_{t-1}) + b_a) \\ - h_t &= (1 - z_t) \bigodot h_{t-1} + z_t \bigodot a_t - \end{array} - - where :math:`z_t` and :math:`r_t` are reset and update gates. - - The output is equal to the new hidden state, :math:`h_t`. - - Warning: Backwards compatibility of GRU weights is currently unsupported. - - Parameters - ---------- - num_unit: int - The number of hidden unit in the node. - state_initializer: callable, Initializer, bm.ndarray, jax.numpy.ndarray - The state initializer. - wi_initializer: callable, Initializer, bm.ndarray, jax.numpy.ndarray - The input weight initializer. - wh_initializer: callable, Initializer, bm.ndarray, jax.numpy.ndarray - The hidden weight initializer. - bias_initializer: optional, callable, Initializer, bm.ndarray, jax.numpy.ndarray - The bias weight initializer. - activation: str, callable - The activation function. It can be a string or a callable function. - See ``brainpy.math.activations`` for more details. - trainable: bool - Whether set the node is trainable. - - References - ---------- - .. [1] Chung, J., Gulcehre, C., Cho, K. and Bengio, Y., 2014. Empirical - evaluation of gated recurrent neural networks on sequence modeling. - arXiv preprint arXiv:1412.3555. - """ - data_pass = MultipleData('sequence') - - def __init__( - self, - num_unit: int, - wi_initializer: Union[Tensor, Callable, Initializer] = Orthogonal(), - wh_initializer: Union[Tensor, Callable, Initializer] = Orthogonal(), - bias_initializer: Union[Tensor, Callable, Initializer] = ZeroInit(), - state_initializer: Union[Tensor, Callable, Initializer] = ZeroInit(), - **kwargs - ): - super(GRU, self).__init__(**kwargs) - - self.num_unit = num_unit - check_integer(num_unit, 'num_unit', min_bound=1, allow_none=False) - self.set_output_shape((None, self.num_unit)) - - self._wi_initializer = wi_initializer - self._wh_initializer = wh_initializer - self._bias_initializer = bias_initializer - self._state_initializer = state_initializer - check_initializer(wi_initializer, 'wi_initializer', allow_none=False) - check_initializer(wh_initializer, 'wh_initializer', allow_none=False) - check_initializer(state_initializer, 'state_initializer', allow_none=False) - check_initializer(bias_initializer, 'bias_initializer', allow_none=True) - - def init_ff_conn(self): - # data shape - unique_size, free_sizes = check_shape_consistency(self.feedforward_shapes, -1, True) - assert len(unique_size) == 1, 'Only support data with or without batch size.' - - # weights - num_input = sum(free_sizes) - self.Wi_ff = init_param(self._wi_initializer, (num_input, self.num_unit * 3)) - self.Wh = init_param(self._wh_initializer, (self.num_unit, self.num_unit * 3)) - self.bias = init_param(self._bias_initializer, (self.num_unit * 3,)) - if self.trainable: - self.Wi_ff = bm.TrainVar(self.Wi_ff) - self.Wh = bm.TrainVar(self.Wh) - self.bias = bm.TrainVar(self.bias) if (self.bias is not None) else None - - def init_fb_conn(self): - unique_size, free_sizes = check_shape_consistency(self.feedback_shapes, -1, True) - assert len(unique_size) == 1, 'Only support data with or without batch size.' - num_feedback = sum(free_sizes) - # weights - self.Wi_fb = init_param(self._wi_initializer, (num_feedback, self.num_unit * 3)) - if self.trainable: - self.Wi_fb = bm.TrainVar(self.Wi_fb) - - def init_state(self, num_batch=1): - return init_param(self._state_initializer, (num_batch, self.num_unit)) - - def forward(self, ff, fb=None, **shared_kwargs): - gates_x = bm.matmul(bm.concatenate(ff, axis=-1), self.Wi_ff) - if fb is not None: - gates_x += bm.matmul(bm.concatenate(fb, axis=-1), self.Wi_fb) - zr_x, a_x = bm.split(gates_x, indices_or_sections=[2 * self.num_unit], axis=-1) - w_h_z, w_h_a = bm.split(self.Wh, indices_or_sections=[2 * self.num_unit], axis=-1) - zr_h = bm.matmul(self.state, w_h_z) - zr = zr_x + zr_h - has_bias = (self.bias is not None) - if has_bias: - b_z, b_a = bm.split(self.bias, indices_or_sections=[2 * self.num_unit], axis=0) - zr += bm.broadcast_to(b_z, zr_h.shape) - z, r = bm.split(bm.sigmoid(zr), indices_or_sections=2, axis=-1) - a_h = bm.matmul(r * self.state, w_h_a) - if has_bias: - a = bm.tanh(a_x + a_h + bm.broadcast_to(b_a, a_h.shape)) - else: - a = bm.tanh(a_x + a_h) - next_state = (1 - z) * self.state + z * a - self.state.value = next_state - return next_state - - -class LSTM(RecurrentNode): - r"""Long short-term memory (LSTM) RNN core. - - The implementation is based on (zaremba, et al., 2014) [1]_. Given - :math:`x_t` and the previous state :math:`(h_{t-1}, c_{t-1})` the core - computes - - .. math:: - - \begin{array}{ll} - i_t = \sigma(W_{ii} x_t + W_{hi} h_{t-1} + b_i) \\ - f_t = \sigma(W_{if} x_t + W_{hf} h_{t-1} + b_f) \\ - g_t = \tanh(W_{ig} x_t + W_{hg} h_{t-1} + b_g) \\ - o_t = \sigma(W_{io} x_t + W_{ho} h_{t-1} + b_o) \\ - c_t = f_t c_{t-1} + i_t g_t \\ - h_t = o_t \tanh(c_t) - \end{array} - - where :math:`i_t`, :math:`f_t`, :math:`o_t` are input, forget and - output gate activations, and :math:`g_t` is a vector of cell updates. - - The output is equal to the new hidden, :math:`h_t`. - - Notes - ----- - - Forget gate initialization: Following (Jozefowicz, et al., 2015) [2]_ we add 1.0 - to :math:`b_f` after initialization in order to reduce the scale of forgetting in - the beginning of the training. - - - Parameters - ---------- - num_unit: int - The number of hidden unit in the node. - state_initializer: callable, Initializer, bm.ndarray, jax.numpy.ndarray - The state initializer. - wi_initializer: callable, Initializer, bm.ndarray, jax.numpy.ndarray - The input weight initializer. - wh_initializer: callable, Initializer, bm.ndarray, jax.numpy.ndarray - The hidden weight initializer. - bias_initializer: optional, callable, Initializer, bm.ndarray, jax.numpy.ndarray - The bias weight initializer. - activation: str, callable - The activation function. It can be a string or a callable function. - See ``brainpy.math.activations`` for more details. - trainable: bool - Whether set the node is trainable. - - References - ---------- - - .. [1] Zaremba, Wojciech, Ilya Sutskever, and Oriol Vinyals. "Recurrent neural - network regularization." arXiv preprint arXiv:1409.2329 (2014). - .. [2] Jozefowicz, Rafal, Wojciech Zaremba, and Ilya Sutskever. "An empirical - exploration of recurrent network architectures." In International conference - on machine learning, pp. 2342-2350. PMLR, 2015. - """ - data_pass = MultipleData('sequence') - - def __init__( - self, - num_unit: int, - wi_initializer: Union[Tensor, Callable, Initializer] = XavierNormal(), - wh_initializer: Union[Tensor, Callable, Initializer] = XavierNormal(), - bias_initializer: Union[Tensor, Callable, Initializer] = ZeroInit(), - state_initializer: Union[Tensor, Callable, Initializer] = ZeroInit(), - **kwargs - ): - super(LSTM, self).__init__(**kwargs) - - self.num_unit = num_unit - check_integer(num_unit, 'num_unit', min_bound=1, allow_none=False) - self.set_output_shape((None, self.num_unit,)) - - self._state_initializer = state_initializer - self._wi_initializer = wi_initializer - self._wh_initializer = wh_initializer - self._bias_initializer = bias_initializer - check_initializer(wi_initializer, 'wi_initializer', allow_none=False) - check_initializer(wh_initializer, 'wh_initializer', allow_none=False) - check_initializer(bias_initializer, 'bias_initializer', allow_none=True) - check_initializer(state_initializer, 'state_initializer', allow_none=False) - - def init_ff_conn(self): - # data shape - unique_size, free_sizes = check_shape_consistency(self.feedforward_shapes, -1, True) - assert len(unique_size) == 1, 'Only support data with or without batch size.' - # weights - num_input = sum(free_sizes) - self.Wi_ff = init_param(self._wi_initializer, (num_input, self.num_unit * 4)) - self.Wh = init_param(self._wh_initializer, (self.num_unit, self.num_unit * 4)) - self.bias = init_param(self._bias_initializer, (self.num_unit * 4,)) - if self.trainable: - self.Wi_ff = bm.TrainVar(self.Wi_ff) - self.Wh = bm.TrainVar(self.Wh) - self.bias = None if (self.bias is None) else bm.TrainVar(self.bias) - - def init_fb_conn(self): - unique_size, free_sizes = check_shape_consistency(self.feedback_shapes, -1, True) - assert len(unique_size) == 1, 'Only support data with or without batch size.' - num_feedback = sum(free_sizes) - # weights - self.Wi_fb = init_param(self._wi_initializer, (num_feedback, self.num_unit * 4)) - if self.trainable: - self.Wi_fb = bm.TrainVar(self.Wi_fb) - - def init_state(self, num_batch=1): - return init_param(self._state_initializer, (num_batch * 2, self.num_unit)) - - def forward(self, ff, fb=None, **shared_kwargs): - h, c = bm.split(self.state, 2) - gated = bm.concatenate(ff, axis=-1) @ self.Wi_ff - if fb is not None: - gated += bm.concatenate(fb, axis=-1) @ self.Wi_fb - if self.bias is not None: - gated += self.bias - gated += h @ self.Wh - i, g, f, o = bm.split(gated, indices_or_sections=4, axis=-1) - c = bm.sigmoid(f + 1.) * c + bm.sigmoid(i) * bm.tanh(g) - h = bm.sigmoid(o) * bm.tanh(c) - self.state.value = bm.vstack([h, c]) - return h - - @property - def h(self): - """Hidden state.""" - return bm.split(self.state, 2)[0] - - @h.setter - def h(self, value): - if self.state is None: - raise ValueError('Cannot set "h" state. Because the state is not initialized.') - self.state[:self.state.shape[0] // 2, :] = value - - @property - def c(self): - """Memory cell.""" - return bm.split(self.state, 2)[1] - - @c.setter - def c(self, value): - if self.state is None: - raise ValueError('Cannot set "c" state. Because the state is not initialized.') - self.state[self.state.shape[0] // 2:, :] = value - - -class ConvNDLSTM(RecurrentNode): - data_pass = MultipleData('sequence') - - -class Conv1DLSTM(ConvNDLSTM): - data_pass = MultipleData('sequence') - - -class Conv2DLSTM(ConvNDLSTM): - data_pass = MultipleData('sequence') - - -class Conv3DLSTM(ConvNDLSTM): - data_pass = MultipleData('sequence') diff --git a/brainpy/compat/nn/nodes/RC/__init__.py b/brainpy/compat/nn/nodes/RC/__init__.py deleted file mode 100644 index e28d3d4c4..000000000 --- a/brainpy/compat/nn/nodes/RC/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -# -*- coding: utf-8 -*- - - -"""Reservoir computing (RC) nodes""" - -from .linear_readout import * -from .nvar import * -from .reservoir import * - diff --git a/brainpy/compat/nn/nodes/RC/linear_readout.py b/brainpy/compat/nn/nodes/RC/linear_readout.py deleted file mode 100644 index 3153e0a5c..000000000 --- a/brainpy/compat/nn/nodes/RC/linear_readout.py +++ /dev/null @@ -1,109 +0,0 @@ -# -*- coding: utf-8 -*- - -import jax.numpy as jnp - -import brainpy.math as bm -from brainpy.errors import MathError -from brainpy.initialize import Initializer -from brainpy.compat.nn.datatypes import MultipleData -from brainpy.compat.nn.nodes.base.dense import Dense -from brainpy.tools.checking import check_shape_consistency - -__all__ = [ - 'LinearReadout', -] - - -class LinearReadout(Dense): - """Linear readout node. Different from ``Dense``, this node has its own state. - - Parameters - ---------- - num_unit: int - The number of output features. A positive integer. - weight_initializer: Initializer - The weight initializer. - bias_initializer: Optional, Initializer - The bias initializer. - trainable: bool - Default is true. - """ - data_pass = MultipleData('sequence') - - def __init__(self, num_unit: int, **kwargs): - super(LinearReadout, self).__init__(num_unit=num_unit, **kwargs) - - def init_state(self, num_batch=1): - return bm.zeros((num_batch,) + self.output_shape[1:]) - - def forward(self, ff, fb=None, **shared_kwargs): - h = super(LinearReadout, self).forward(ff, fb=fb, **shared_kwargs) - self.state.value = h - return h - - def online_init(self): - _, free_shapes = check_shape_consistency(self.feedforward_shapes, -1, True) - num_input = sum(free_shapes) - if self.bias is not None: - num_input += 1 - if self.feedback_shapes is not None: - _, free_shapes = check_shape_consistency(self.feedback_shapes, -1, True) - num_input += sum(free_shapes) - self.online_fit_by.initialize(feature_in=num_input, - feature_out=self.num_unit, - name=self.name) - - def online_fit(self, target, ff, fb=None): - if not isinstance(target, (bm.ndarray, jnp.ndarray)): - raise MathError(f'"target" must be a tensor, but got {type(target)}') - ff = bm.concatenate(ff, axis=-1) - if ff.ndim != 2: - raise ValueError(f'"ff" must be a 2D tensor with shape of (num_sample, ' - f'num_feature), but we got {ff.shape}') - if target.ndim != 2: - raise ValueError(f'"target" must be a 2D tensor with shape of (num_sample, ' - f'num_feature), but we got {target.shape}') - if ff.shape[0] != target.shape[0]: - raise ValueError(f'Batch size of the input and target data should be ' - f'the same, while we got {ff.shape[0]} != {target.shape[0]}.') - if target.shape[1] != self.state.shape[1]: - raise MathError(f'The output dimension of output and target data should be ' - f'the same, while we got {target.shape[1]} != {self.state.shape[1]}') - if fb is not None: - fb = bm.concatenate(fb, axis=-1) - if fb.ndim != 2: - raise ValueError(f'"fb" must be a 2D tensor with shape of (num_sample, ' - f'num_feature), but we got {fb.shape}') - if ff.shape[0] != fb.shape[0]: - raise ValueError(f'Batch size of the feedforward and the feedback inputs should be ' - f'the same, while we got {ff.shape[0]} != {fb.shape[0]}.') - - # data - inputs = ff - num_ff_input = ff.shape[1] - if fb is not None: - inputs = bm.concatenate([inputs, fb], axis=-1) - if self.bias is not None: - inputs = bm.concatenate([bm.ones((inputs.shape[0], 1)), inputs], axis=-1) - - # fitting - dW = self.online_fit_by.call(target=target, input=inputs, output=self.state, name=self.name) - - # assign trained weights - if self.bias is None: - if fb is None: - self.Wff += dW - else: - dWff, dWfb = bm.split(dW, [num_ff_input]) - self.Wff += dWff - self.Wfb += dWfb - else: - if fb is None: - db, dWff = bm.split(dW, [1]) - self.bias += db[0] - self.Wff += dWff - else: - db, dWff, dWfb = bm.split(dW, [1, 1 + num_ff_input]) - self.bias += db[0] - self.Wff += dWff - self.Wfb += dWfb diff --git a/brainpy/compat/nn/nodes/RC/nvar.py b/brainpy/compat/nn/nodes/RC/nvar.py deleted file mode 100644 index bb60a2160..000000000 --- a/brainpy/compat/nn/nodes/RC/nvar.py +++ /dev/null @@ -1,205 +0,0 @@ -# -*- coding: utf-8 -*- - -from itertools import combinations_with_replacement -from typing import Union, Sequence - -import numpy as np -import jax.numpy as jnp - -import brainpy.math as bm -from brainpy.compat.nn.base import RecurrentNode -from brainpy.compat.nn.datatypes import MultipleData -from brainpy.tools.checking import (check_shape_consistency, - check_integer, - check_sequence) - -__all__ = [ - 'NVAR' -] - - -def _comb(N, k): - r"""The number of combinations of N things taken k at a time. - - .. math:: - - \frac{N!}{(N-k)! k!} - - """ - if N > k: - val = 1 - for j in range(min(k, N - k)): - val = (val * (N - j)) // (j + 1) - return val - elif N == k: - return 1 - else: - return 0 - - -class NVAR(RecurrentNode): - """Nonlinear vector auto-regression (NVAR) node. - - This class has the following features: - - - it supports batch size, - - it supports multiple orders, - - Parameters - ---------- - delay: int - The number of delay step. - order: int, sequence of int - The nonlinear order. - stride: int - The stride to sample linear part vector in the delays. - constant: optional, float - The constant value. - - References - ---------- - .. [1] Gauthier, D.J., Bollt, E., Griffith, A. et al. Next generation - reservoir computing. Nat Commun 12, 5564 (2021). - https://doi.org/10.1038/s41467-021-25801-2 - - """ - data_pass = MultipleData('sequence') - - def __init__( - self, - delay: int, - order: Union[int, Sequence[int]] = None, - stride: int = 1, - constant: bool = False, - trainable: bool = False, - **kwargs - ): - super(NVAR, self).__init__(trainable=trainable, **kwargs) - - # parameters - order = tuple() if order is None else order - if not isinstance(order, (tuple, list)): - order = (order,) - self.order = tuple(order) - check_sequence(order, 'order', allow_none=False) - for o in order: check_integer(o, 'delay', allow_none=False, min_bound=2) - check_integer(delay, 'delay', allow_none=False, min_bound=1) - check_integer(stride, 'stride', allow_none=False, min_bound=1) - assert isinstance(constant, bool), f'Must be an instance of boolean, but got {constant}.' - self.delay = delay - self.stride = stride - self.constant = constant - self.num_delay = 1 + (self.delay - 1) * self.stride - - # attributes - self.comb_ids = [] - self.feature_names = [] - self.input_dim = None - self.output_dim = None - self.linear_dim = None - self.nonlinear_dim = None - - # delay variables - self.idx = bm.Variable(jnp.asarray([0])) - self.store = None - - def init_ff_conn(self): - """Initialize feedforward connections.""" - # input dimension - batch_size, free_size = check_shape_consistency(self.feedforward_shapes, -1, True) - self.input_dim = sum(free_size) - assert batch_size == (None,), f'batch_size must be None, but got {batch_size}' - # linear dimension - self.linear_dim = self.delay * self.input_dim - # For each monomial created in the non-linear part, indices - # of the n components involved, n being the order of the - # monomials. Precompute them to improve efficiency. - for order in self.order: - assert order >= 2, f'"order" must be a integer >= 2, while we got {order}.' - idx = np.array(list(combinations_with_replacement(np.arange(self.linear_dim), order))) - self.comb_ids.append(jnp.asarray(idx)) - # number of non-linear components is (d + n - 1)! / (d - 1)! n! - # i.e. number of all unique monomials of order n made from the - # linear components. - self.nonlinear_dim = sum([len(ids) for ids in self.comb_ids]) - # output dimension - self.output_dim = int(self.linear_dim + self.nonlinear_dim) - if self.constant: - self.output_dim += 1 - self.set_output_shape((None, self.output_dim)) - - def init_state(self, num_batch=1): - """Initialize the node state which depends on batch size.""" - # To store the last inputs. - # Note, the batch axis is not in the first dimension, so we - # manually handle the state of NVAR, rather return it. - state = jnp.zeros((self.num_delay, num_batch, self.input_dim)) - if self.store is None: - self.store = bm.Variable(state) - else: - self.store._value = state - - def forward(self, ff, fb=None, **shared_kwargs): - all_parts = [] - # 1. Store the current input - ff = bm.concatenate(ff, axis=-1) - self.store[self.idx[0]] = ff - # 2. Linear part: - # select all previous inputs, including the current, with strides - select_ids = (self.idx[0] - jnp.arange(0, self.num_delay, self.stride)) % self.num_delay - linear_parts = jnp.moveaxis(self.store[select_ids], 0, 1) # (num_batch, num_time, num_feature) - linear_parts = jnp.reshape(linear_parts, (linear_parts.shape[0], -1)) - # 3. constant - if self.constant: - constant = jnp.ones((linear_parts.shape[0], 1), dtype=ff.dtype) - all_parts.append(constant) - all_parts.append(linear_parts) - # 3. Nonlinear part: - # select monomial terms and compute them - for ids in self.comb_ids: - all_parts.append(jnp.prod(linear_parts[:, ids], axis=2)) - # 4. Finally - self.idx.value = (self.idx + 1) % self.num_delay - return jnp.concatenate(all_parts, axis=-1) - - def get_feature_names(self, for_plot=False): - """Get output feature names for transformation. - - Returns - ------- - feature_names_out : list of str - Transformed feature names. - """ - if not self.is_initialized: - raise ValueError('Please initialize the node first.') - linear_names = [f'x{i}(t)' for i in range(self.input_dim)] - for di in range(1, self.delay): - linear_names.extend([((f'x{i}_' + r'{t-%d}' % (di * self.stride)) - if for_plot else - f'x{i}(t-{di * self.stride})') - for i in range(self.input_dim)]) - nonlinear_names = [] - for ids in self.comb_ids: - for id_ in np.asarray(ids): - uniques, counts = np.unique(id_, return_counts=True) - nonlinear_names.append(" ".join( - "%s^%d" % (linear_names[ind], exp) if (exp != 1) else linear_names[ind] - for ind, exp in zip(uniques, counts) - )) - if for_plot: - all_names = [f'${n}$' for n in linear_names] + [f'${n}$' for n in nonlinear_names] - else: - all_names = linear_names + nonlinear_names - if self.constant: - all_names = ['1'] + all_names - return all_names - - def get_feature_names_for_plot(self): - """Get output feature names for matplotlib plotting. - - Returns - ------- - feature_names_out : list of str - Transformed feature names. - """ - return self.get_feature_names(for_plot=True) diff --git a/brainpy/compat/nn/nodes/RC/reservoir.py b/brainpy/compat/nn/nodes/RC/reservoir.py deleted file mode 100644 index 976b8d606..000000000 --- a/brainpy/compat/nn/nodes/RC/reservoir.py +++ /dev/null @@ -1,255 +0,0 @@ -# -*- coding: utf-8 -*- - -from typing import Optional, Union, Callable - -import brainpy.math as bm -from brainpy.initialize import Normal, ZeroInit, Initializer, init_param -from brainpy.compat.nn.base import RecurrentNode -from brainpy.compat.nn.datatypes import MultipleData -from brainpy.tools.checking import (check_shape_consistency, - check_float, - check_initializer, - check_string) -from brainpy.types import Tensor - -__all__ = [ - 'Reservoir', -] - - -class Reservoir(RecurrentNode): - r"""Reservoir node, a pool of leaky-integrator neurons - with random recurrent connections [1]_. - - Parameters - ---------- - num_unit: int - The number of reservoir nodes. - ff_initializer: Initializer - The initialization method for the feedforward connections. - rec_initializer: Initializer - The initialization method for the recurrent connections. - fb_initializer: optional, Tensor, Initializer - The initialization method for the feedback connections. - bias_initializer: optional, Tensor, Initializer - The initialization method for the bias. - leaky_rate: float - A float between 0 and 1. - activation : str, callable, optional - Reservoir activation function. - - If a str, should be a :py:mod:`brainpy.math.activations` function name. - - If a callable, should be an element-wise operator on tensor. - activation_type : str - - If "internal" (default), then leaky integration happens on states transformed - by the activation function: - - .. math:: - - r[n+1] = (1 - \alpha) \cdot r[t] + - \alpha \cdot f(W_{ff} \cdot u[n] + W_{fb} \cdot b[n] + W_{rec} \cdot r[t]) - - - If "external", then leaky integration happens on internal states of - each neuron, stored in an ``internal_state`` parameter (:math:`x` in - the equation below). - A neuron internal state is the value of its state before applying - the activation function :math:`f`: - - .. math:: - - x[n+1] &= (1 - \alpha) \cdot x[t] + - \alpha \cdot f(W_{ff} \cdot u[n] + W_{rec} \cdot r[t] + W_{fb} \cdot b[n]) \\ - r[n+1] &= f(x[n+1]) - ff_connectivity : float, optional - Connectivity of input neurons, i.e. ratio of input neurons connected - to reservoir neurons. Must be in [0, 1], by default 0.1 - rec_connectivity : float, optional - Connectivity of recurrent weights matrix, i.e. ratio of reservoir - neurons connected to other reservoir neurons, including themselves. - Must be in [0, 1], by default 0.1 - fb_connectivity : float, optional - Connectivity of feedback neurons, i.e. ratio of feedabck neurons - connected to reservoir neurons. Must be in [0, 1], by default 0.1 - conn_type: str - The connectivity type, can be "dense" or "sparse". - spectral_radius : float, optional - Spectral radius of recurrent weight matrix, by default None - noise_rec : float, optional - Gain of noise applied to reservoir internal states, by default 0.0 - noise_in : float, optional - Gain of noise applied to feedforward signals, by default 0.0 - noise_fb : float, optional - Gain of noise applied to feedback signals, by default 0.0 - noise_type : optional, str, callable - Distribution of noise. Must be a random variable generator - distribution (see :py:class:`brainpy.math.random.RandomState`), - by default "normal". - seed: optional, int - The seed for random sampling in this node. - - References - ---------- - .. [1] Lukoševičius, Mantas. "A practical guide to applying echo state networks." - Neural networks: Tricks of the trade. Springer, Berlin, Heidelberg, 2012. 659-686. - """ - data_pass = MultipleData('sequence') - - def __init__( - self, - num_unit: int, - leaky_rate: float = 0.3, - activation: Union[str, Callable] = 'tanh', - activation_type: str = 'internal', - ff_initializer: Union[Initializer, Callable, Tensor] = Normal(scale=0.1), - rec_initializer: Union[Initializer, Callable, Tensor] = Normal(scale=0.1), - fb_initializer: Optional[Union[Initializer, Callable, Tensor]] = Normal(scale=0.1), - bias_initializer: Optional[Union[Initializer, Callable, Tensor]] = ZeroInit(), - ff_connectivity: float = 0.1, - rec_connectivity: float = 0.1, - fb_connectivity: float = 0.1, - conn_type='dense', - spectral_radius: Optional[float] = None, - noise_ff: float = 0., - noise_rec: float = 0., - noise_fb: float = 0., - noise_type: str = 'normal', - seed: Optional[int] = None, - trainable: bool = False, - **kwargs - ): - super(Reservoir, self).__init__(trainable=trainable, **kwargs) - - # parameters - self.num_unit = num_unit - assert num_unit > 0, f'Must be a positive integer, but we got {num_unit}' - self.leaky_rate = leaky_rate - check_float(leaky_rate, 'leaky_rate', 0., 1.) - self.activation = bm.activations.get(activation) - self.activation_type = activation_type - check_string(activation_type, 'activation_type', ['internal', 'external']) - self.rng = bm.random.RandomState(seed) - check_float(spectral_radius, 'spectral_radius', allow_none=True) - self.spectral_radius = spectral_radius - - # initializations - check_initializer(ff_initializer, 'ff_initializer', allow_none=False) - check_initializer(rec_initializer, 'rec_initializer', allow_none=False) - check_initializer(fb_initializer, 'fb_initializer', allow_none=True) - check_initializer(bias_initializer, 'bias_initializer', allow_none=True) - self.ff_initializer = ff_initializer - self.fb_initializer = fb_initializer - self.rec_initializer = rec_initializer - self.bias_initializer = bias_initializer - - # connectivity - check_float(ff_connectivity, 'ff_connectivity', 0., 1.) - check_float(rec_connectivity, 'rec_connectivity', 0., 1.) - check_float(fb_connectivity, 'fb_connectivity', 0., 1.) - self.ff_connectivity = ff_connectivity - self.rec_connectivity = rec_connectivity - self.fb_connectivity = fb_connectivity - check_string(conn_type, 'conn_type', ['dense', 'sparse']) - self.conn_type = conn_type - - # noises - check_float(noise_ff, 'noise_ff') - check_float(noise_fb, 'noise_fb') - check_float(noise_rec, 'noise_rec') - self.noise_ff = noise_ff - self.noise_fb = noise_fb - self.noise_rec = noise_rec - self.noise_type = noise_type - check_string(noise_type, 'noise_type', ['normal', 'uniform']) - - def init_ff_conn(self): - """Initialize feedforward connections, weights, and variables.""" - unique_shape, free_shapes = check_shape_consistency(self.feedforward_shapes, -1, True) - self.set_output_shape(unique_shape + (self.num_unit,)) - - # initialize feedforward weights - weight_shape = (sum(free_shapes), self.num_unit) - self.Wff_shape = weight_shape - self.Wff = init_param(self.ff_initializer, weight_shape) - if self.ff_connectivity < 1.: - conn_mat = self.rng.random(weight_shape) > self.ff_connectivity - self.Wff[conn_mat] = 0. - if self.conn_type == 'sparse' and self.ff_connectivity < 1.: - self.ff_pres, self.ff_posts = bm.where(bm.logical_not(conn_mat)) - self.Wff = self.Wff[self.ff_pres, self.ff_posts] - if self.trainable: - self.Wff = bm.TrainVar(self.Wff) - - # initialize recurrent weights - recurrent_shape = (self.num_unit, self.num_unit) - self.Wrec = init_param(self.rec_initializer, recurrent_shape) - if self.rec_connectivity < 1.: - conn_mat = self.rng.random(recurrent_shape) > self.rec_connectivity - self.Wrec[conn_mat] = 0. - if self.spectral_radius is not None: - current_sr = max(abs(bm.linalg.eig(self.Wrec)[0])) - self.Wrec *= self.spectral_radius / current_sr - if self.conn_type == 'sparse' and self.rec_connectivity < 1.: - self.rec_pres, self.rec_posts = bm.where(bm.logical_not(conn_mat)) - self.Wrec = self.Wrec[self.rec_pres, self.rec_posts] - self.bias = init_param(self.bias_initializer, (self.num_unit,)) - if self.trainable: - self.Wrec = bm.TrainVar(self.Wrec) - self.bias = None if (self.bias is None) else bm.TrainVar(self.bias) - - # initialize feedback weights - self.Wfb = None - - def init_state(self, num_batch=1): - # initialize internal state - return bm.zeros((num_batch, self.num_unit)) - - def init_fb_conn(self): - """Initialize feedback connections, weights, and variables.""" - if self.feedback_shapes is not None: - unique_shape, free_shapes = check_shape_consistency(self.feedback_shapes, -1, True) - fb_shape = (sum(free_shapes), self.num_unit) - self.Wfb_shape = fb_shape - self.Wfb = init_param(self.fb_initializer, fb_shape) - if self.fb_connectivity < 1.: - conn_mat = self.rng.random(fb_shape) > self.fb_connectivity - self.Wfb[conn_mat] = 0. - if self.conn_type == 'sparse' and self.fb_connectivity < 1.: - self.fb_pres, self.fb_posts = bm.where(bm.logical_not(conn_mat)) - self.Wfb = self.Wfb[self.fb_pres, self.fb_posts] - if self.trainable: - self.Wfb = bm.TrainVar(self.Wfb) - - def forward(self, ff, fb=None, **shared_kwargs): - """Feedforward output.""" - # inputs - x = bm.concatenate(ff, axis=-1) - if self.noise_ff > 0: x += self.noise_ff * self.rng.uniform(-1, 1, x.shape) - if self.conn_type == 'sparse' and self.ff_connectivity < 1.: - sparse = {'data': self.Wff, 'index': (self.ff_pres, self.ff_posts), 'shape': self.Wff_shape} - hidden = bm.sparse_matmul(x, sparse) - else: - hidden = bm.dot(x, self.Wff) - # feedback - if self.Wfb is not None: - assert fb is not None, 'Should provide feedback signals, but we got None.' - fb = bm.concatenate(fb, axis=-1) - if self.noise_fb: fb += self.noise_fb * self.rng.uniform(-1, 1, fb.shape) - if self.conn_type == 'sparse' and self.fb_connectivity < 1.: - sparse = {'data': self.Wfb, 'index': (self.fb_pres, self.fb_posts), 'shape': self.Wfb_shape} - hidden += bm.sparse_matmul(fb, sparse) - else: - hidden += bm.dot(fb, self.Wfb) - # recurrent - if self.conn_type == 'sparse' and self.rec_connectivity < 1.: - sparse = {'data': self.Wrec, 'index': (self.rec_pres, self.rec_posts), 'shape': (self.num_unit, self.num_unit)} - hidden += bm.sparse_matmul(self.state, sparse) - else: - hidden += bm.dot(self.state, self.Wrec) - if self.activation_type == 'internal': - hidden = self.activation(hidden) - if self.noise_rec > 0.: hidden += self.noise_rec * self.rng.uniform(-1, -1, self.state.shape) - # new state/output - state = (1 - self.leaky_rate) * self.state + self.leaky_rate * hidden - if self.activation_type == 'external': - state = self.activation(state) - self.state.value = state - return state diff --git a/brainpy/compat/nn/nodes/__init__.py b/brainpy/compat/nn/nodes/__init__.py deleted file mode 100644 index 162464095..000000000 --- a/brainpy/compat/nn/nodes/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# -*- coding: utf-8 -*- - -from .ANN import * -from .base import * -from .RC import * diff --git a/brainpy/compat/nn/nodes/base/__init__.py b/brainpy/compat/nn/nodes/base/__init__.py deleted file mode 100644 index 9ff36c2b6..000000000 --- a/brainpy/compat/nn/nodes/base/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# -*- coding: utf-8 -*- - -from .activation import * -from .dense import * -from .io import * -from .ops import * diff --git a/brainpy/compat/nn/nodes/base/activation.py b/brainpy/compat/nn/nodes/base/activation.py deleted file mode 100644 index 9f67776bc..000000000 --- a/brainpy/compat/nn/nodes/base/activation.py +++ /dev/null @@ -1,42 +0,0 @@ -# -*- coding: utf-8 -*- - -from typing import Dict, Optional, Any - -from brainpy.math import activations -from brainpy.compat.nn.base import Node - -__all__ = [ - 'Activation' -] - - -class Activation(Node): - """Activation node. - - Parameters - ---------- - activation : str - The name of the activation function. - fun_setting : optional, dict - The settings for the activation function. - """ - - def __init__(self, - activation: str = 'relu', - fun_setting: Optional[Dict[str, Any]] = None, - trainable: bool = False, - name: str = None, - **kwargs): - if name is None: - name = self.unique_name(type_=f'{activation}_activation') - super(Activation, self).__init__(name=name, trainable=trainable, **kwargs) - - self._activation = activations.get(activation) - self._fun_setting = dict() if (fun_setting is None) else fun_setting - assert isinstance(self._fun_setting, dict), '"fun_setting" must be a dict.' - - def init_ff_conn(self): - self.set_output_shape(self.feedforward_shapes) - - def forward(self, ff, **shared_kwargs): - return self._activation(ff, **self._fun_setting) diff --git a/brainpy/compat/nn/nodes/base/dense.py b/brainpy/compat/nn/nodes/base/dense.py deleted file mode 100644 index 6a9dd47d3..000000000 --- a/brainpy/compat/nn/nodes/base/dense.py +++ /dev/null @@ -1,223 +0,0 @@ -# -*- coding: utf-8 -*- - - -from typing import Sequence, Optional, Callable, Union - -import jax.numpy as jnp - -from brainpy import math as bm -from brainpy.errors import MathError -from brainpy.initialize import XavierNormal, ZeroInit, Initializer, init_param -from brainpy.compat.nn.base import Node -from brainpy.compat.nn.datatypes import MultipleData -from brainpy.tools.checking import (check_shape_consistency, - check_initializer) -from brainpy.types import Tensor - -__all__ = [ - 'DenseMD', - 'Dense', -] - - -class DenseMD(Node): - r"""A linear transformation applied over the last dimension of the input. - - Mathematically, this node can be defined as: - - .. math:: - - y = x \cdot W + b - - Parameters - ---------- - num_unit: int - The number of the output features. A positive integer. - weight_initializer: optional, Initializer - The weight initialization. - bias_initializer: optional, Initializer - The bias initialization. - trainable: bool - Enable training this node or not. (default True) - """ - - data_pass = MultipleData('sequence') - - def __init__( - self, - num_unit: int, - weight_initializer: Union[Initializer, Callable, Tensor] = XavierNormal(), - bias_initializer: Optional[Union[Initializer, Callable, Tensor]] = ZeroInit(), - trainable: bool = True, - **kwargs - ): - super(DenseMD, self).__init__(trainable=trainable, **kwargs) - - # shape - self.num_unit = num_unit - if num_unit < 0: - raise ValueError(f'Received an invalid value for `num_unit`, expected ' - f'a positive integer. Received: num_unit={num_unit}') - - # weight initializer - self.weight_initializer = weight_initializer - self.bias_initializer = bias_initializer - check_initializer(weight_initializer, 'weight_initializer') - check_initializer(bias_initializer, 'bias_initializer', allow_none=True) - - # weights - self.Wff = None - self.bias = None - self.Wfb = None - - def init_ff_conn(self): - # shapes - other_size, free_shapes = check_shape_consistency(self.feedforward_shapes, -1, True) - # set output size - self.set_output_shape(other_size + (self.num_unit,)) - - # initialize feedforward weights - self.Wff = init_param(self.weight_initializer, (sum(free_shapes), self.num_unit)) - self.bias = init_param(self.bias_initializer, (self.num_unit,)) - if self.trainable: - self.Wff = bm.TrainVar(self.Wff) - self.bias = None if (self.bias is None) else bm.TrainVar(self.bias) - - def init_fb_conn(self): - other_size, free_shapes = check_shape_consistency(self.feedback_shapes, -1, True) - - # initialize feedback weights - weight_shapes = (sum(free_shapes), self.num_unit) - if self.trainable: - self.Wfb = bm.TrainVar(init_param(self.weight_initializer, weight_shapes)) - else: - self.Wfb = init_param(self.weight_initializer, weight_shapes) - - def forward(self, ff: Sequence[Tensor], fb=None, **shared_kwargs): - ff = bm.concatenate(ff, axis=-1) - res = ff @ self.Wff - if fb is not None: - fb = bm.concatenate(fb, axis=-1) - res += fb @ self.Wfb - if self.bias is not None: - res += self.bias - return res - - -class Dense(DenseMD): - r"""A linear transformation. - - Different from :py:class:`GeneralDense`, this class only supports 2D input data. - - Mathematically, this node can be defined as: - - .. math:: - - y = x \cdot W+ b - - Parameters - ---------- - num_unit: int - The number of the output features. A positive integer. - weight_initializer: optional, Initializer - The weight initialization. - bias_initializer: optional, Initializer - The bias initialization. - trainable: bool - Enable training this node or not. (default True) - """ - data_pass = MultipleData('sequence') - - def __init__( - self, - num_unit: int, - weight_initializer: Union[Initializer, Callable, Tensor] = XavierNormal(), - bias_initializer: Optional[Union[Initializer, Callable, Tensor]] = ZeroInit(), - **kwargs - ): - super(Dense, self).__init__(num_unit=num_unit, - weight_initializer=weight_initializer, - bias_initializer=bias_initializer, - **kwargs) - # set output shape - self.set_output_shape((None, self.num_unit)) - - def init_ff_conn(self): - # shapes - other_size, free_shapes = check_shape_consistency(self.feedforward_shapes, -1, True) - if other_size != (None,): - raise ValueError(f'{self.__class__.__name__} only support 2D inputs, while ' - f'we got {len(other_size) + 1}-D shapes. For >2D inputs, ' - f'you should use brainpy.nn.{DenseMD.__name__} instead. ') - super(Dense, self).init_ff_conn() - - def init_fb_conn(self): - other_size, free_shapes = check_shape_consistency(self.feedback_shapes, -1, True) - if other_size != (None,): - raise ValueError(f'{self.__class__.__name__} only support 2D inputs, while ' - f'we got {len(other_size) + 1}-D shapes. For >2D inputs, ' - f'you should use brainpy.nn.{DenseMD.__name__} instead. ') - super(Dense, self).init_fb_conn() - - def offline_fit( - self, - targets: Tensor, - ffs: Sequence[Tensor], - fbs: Optional[Sequence[Tensor]] = None, - ): - """The offline training interface for the Dense node.""" - # data checking - ffs = bm.concatenate(ffs, axis=-1) - if not isinstance(targets, (bm.ndarray, jnp.ndarray)): - raise MathError(f'"targets" must be a tensor, but got {type(targets)}') - if ffs.ndim != 3: - raise ValueError(f'"ffs" must be a 3D tensor with shape of (num_sample, num_time, ' - f'num_feature), but we got {ffs.shape}') - if targets.ndim != 3: - raise ValueError(f'"targets" must be a 3D tensor with shape of (num_sample, num_time, ' - f'num_feature), but we got {targets.shape}') - if ffs.shape[0] != targets.shape[0]: - raise ValueError(f'Batch size of the input and target data should be ' - f'the same, while we got {ffs.shape[0]} != {targets.shape[0]}.') - if ffs.shape[1] != targets.shape[1]: - raise MathError(f'The time dimension of input and target data should be ' - f'the same, while we got {ffs.shape[1]} != {targets.shape[1]}') - if fbs is not None: - fbs = bm.concatenate(fbs, axis=-1) - if fbs.ndim != 3: - raise ValueError(f'"fbs" must be a 3D tensor with shape of (num_sample, num_time, ' - f'num_feature), but we got {fbs.shape}') - if ffs.shape[0] != fbs.shape[0]: - raise ValueError(f'Batch size of the feedforward and the feedback inputs should be ' - f'the same, while we got {ffs.shape[0]} != {fbs.shape[0]}.') - if ffs.shape[1] != fbs.shape[1]: - raise MathError(f'The time dimension of feedforward and feedback inputs should be ' - f'the same, while we got {ffs.shape[1]} != {fbs.shape[1]}') - - # get input and target training data - inputs = ffs - num_ff_input = inputs.shape[2] - if self.bias is not None: - inputs = bm.concatenate([bm.ones(ffs.shape[:2] + (1,)), inputs], axis=-1) # (..., 1 + num_ff_input) - if fbs is not None: - inputs = bm.concatenate([inputs, fbs], axis=-1) # (..., 1 + num_ff_input + num_fb_input) - - # solve weights by offline training methods - weights = self.offline_fit_by(targets, inputs) - - # assign trained weights - if self.bias is None: - if fbs is None: - self.Wff.value = weights - else: - self.Wff.value, self.Wfb.value = bm.split(weights, [num_ff_input]) - else: - if fbs is None: - bias, Wff = bm.split(weights, [1]) - self.bias.value = bias[0] - self.Wff.value = Wff - else: - bias, Wff, Wfb = bm.split(weights, [1, 1 + num_ff_input]) - self.bias.value = bias[0] - self.Wff.value = Wff - self.Wfb.value = Wfb diff --git a/brainpy/compat/nn/nodes/base/io.py b/brainpy/compat/nn/nodes/base/io.py deleted file mode 100644 index 1fcc0f74b..000000000 --- a/brainpy/compat/nn/nodes/base/io.py +++ /dev/null @@ -1,30 +0,0 @@ -# -*- coding: utf-8 -*- - -from typing import Tuple, Union - -from brainpy.compat.nn.base import Node -from brainpy.tools.others import to_size - -__all__ = [ - 'Input', -] - - -class Input(Node): - """The input node.""" - - def __init__( - self, - input_shape: Union[Tuple[int, ...], int], - trainable: bool = False, - name: str = None, - ): - super(Input, self).__init__(name=name, trainable=trainable, input_shape=input_shape) - self.set_feedforward_shapes({self.name: (None,) + to_size(input_shape)}) - self._init_ff_conn() - - def init_ff_conn(self): - self.set_output_shape(self.feedforward_shapes) - - def forward(self, ff, **shared_kwargs): - return ff diff --git a/brainpy/compat/nn/nodes/base/ops.py b/brainpy/compat/nn/nodes/base/ops.py deleted file mode 100644 index bb091f672..000000000 --- a/brainpy/compat/nn/nodes/base/ops.py +++ /dev/null @@ -1,113 +0,0 @@ -# -*- coding: utf-8 -*- - - -import numpy as np - -from brainpy import math as bm, tools -from brainpy.compat.nn.base import Node -from brainpy.compat.nn.datatypes import MultipleData -from brainpy.tools.checking import check_shape_consistency - -__all__ = [ - 'Concat', 'Select', 'Reshape', 'Summation', -] - - -class Concat(Node): - """ - Concatenate multiple inputs into one. - - Parameters - ---------- - axis : int - The axis of concatenation to perform. - """ - - data_pass = MultipleData('sequence') - - def __init__(self, axis=-1, trainable=False, **kwargs): - super(Concat, self).__init__(trainable=trainable, **kwargs) - self.axis = axis - - def init_ff_conn(self): - unique_shape, free_shapes = check_shape_consistency(self.feedforward_shapes, self.axis) - out_size = list(unique_shape) - out_size.insert(self.axis, sum(free_shapes)) - self.set_output_shape(out_size) - - def forward(self, ff, **shared_kwargs): - return bm.concatenate(ff, axis=self.axis) - - -class Select(Node): - """ - Select a subset of the given input. - """ - - def __init__(self, index, trainable=False, **kwargs): - super(Select, self).__init__(trainable=trainable, **kwargs) - if isinstance(index, int): - self.index = bm.asarray([index]).value - - def init_ff_conn(self): - out_size = bm.zeros(self.feedforward_shapes[1:])[self.index].shape - self.set_output_shape((None,) + out_size) - - def forward(self, ff, **shared_kwargs): - return ff[..., self.index] - - -class Reshape(Node): - """ - Reshape the input tensor to another tensor. - - Parameters - ---------- - shape: int, sequence of int - The reshaped size. This shape does not contain the batch size. - """ - - def __init__(self, shape, trainable=False, **kwargs): - super(Reshape, self).__init__(trainable=trainable, **kwargs) - self.shape = tools.to_size(shape) - assert (None not in self.shape), 'Batch size can not be defined in the reshaped size.' - - def init_ff_conn(self): - in_size = self.feedforward_shapes[1:] - if -1 in self.shape: - assert self.shape.count(-1) == 1, f'Cannot set shape with multiple -1. But got {self.shape}' - length = np.prod(in_size) - out_size = list(self.shape) - m1_idx = out_size.index(-1) - other_shape = out_size[:m1_idx] + out_size[m1_idx + 1:] - m1_length = int(length / np.prod(other_shape)) - out_size[m1_idx] = m1_length - else: - assert np.prod(in_size) == np.prod(self.shape) - out_size = self.shape - self.set_output_shape((None,) + tuple(out_size)) - - def forward(self, ff, **shared_kwargs): - return bm.reshape(ff, self.shape) - - -class Summation(Node): - """ - Sum all input tensors into one. - - All inputs should be broadcast compatible. - """ - data_pass = MultipleData('sequence') - - def __init__(self, trainable=False, **kwargs): - super(Summation, self).__init__(trainable=trainable, **kwargs) - - def init_ff_conn(self): - unique_shape, _ = check_shape_consistency(self.feedforward_shapes, None, True) - self.set_output_shape(list(unique_shape)) - - def forward(self, ff, **shared_kwargs): - res = ff[0] - for v in ff[1:]: - res = res + v - return res diff --git a/brainpy/compat/nn/operations.py b/brainpy/compat/nn/operations.py deleted file mode 100644 index fb0f13cab..000000000 --- a/brainpy/compat/nn/operations.py +++ /dev/null @@ -1,527 +0,0 @@ -# -*- coding: utf-8 -*- - -"""This module provides basic operations for constructing node graphs. - -It supports the following operations: - -1. feedforward connection: ">>", ">>=" -2. feedback connection: "<<", "<<=" -3. merge two nodes: "&", "&=" -4. select subsets of one node: "[:]" -5. concatenate a sequence of nodes: "[node1, node2, ...]", "(node1, node2, ...)" -6. wrap a set of nodes: "{node1, node2, ...}" - -However, all operations should satisfy the following assumptions: - -1. Feedback connection of `(node1, node2)` should have a feedforward path from `node2` to `node1`. -2. Feedforward or feedback connections cannot generate a cycle. -3. Cannot concatenate multiple receiver nodes, e.g., `a >> [b, c]` is forbidden, but `a >> {b, c}` - is allowed. - -""" - -from itertools import product -from typing import Union, Sequence, Set - -from brainpy.compat.nn import graph_flow -from brainpy.compat.nn.base import Node, Network, FrozenNetwork -from brainpy.compat.nn.datatypes import SingleData -from brainpy.compat.nn.nodes.base import Select, Concat -from brainpy.types import Tensor - -__all__ = [ - 'ff_connect', 'fb_connect', 'merge', 'select', 'concatenate', -] - - -def _retrieve_nodes_and_edges(senders: Union[Node, Sequence[Node]], - receivers: Union[Node, Sequence[Node]]): - # check senders - if isinstance(senders, (tuple, list)): - senders = [concatenate(senders)] - elif isinstance(senders, set): - senders = list(senders) - elif isinstance(senders, Node): - senders = [senders] - else: - raise TypeError(f"Impossible to send connection from {senders}: it is not " - f"a Node or a Network instance.") - - # check receivers - if isinstance(receivers, (tuple, list)): - raise TypeError('Cannot concatenate a list/tuple of receivers. ' - 'Please use set to wrap multiple receivers instead.') - elif isinstance(receivers, set): - receivers = list(receivers) - elif isinstance(receivers, Node): - receivers = [receivers] - else: - raise TypeError(f"Impossible to send connection to {receivers}: it is not " - f"a Node or a Network instance.") - - # fetch all nodes in two subgraphs - all_nodes = set() - for node in senders + receivers: - if isinstance(node, FrozenNetwork): - raise TypeError(f"Cannot connect {FrozenNetwork.__name__} to other Nodes.") - if isinstance(node, Network): - all_nodes.update(set(node.lnodes)) - elif isinstance(node, Node): - all_nodes.add(node) - else: - raise TypeError(f"Impossible to link nodes: object {node} is neither a " - f"'brainpy.rnn.Node' nor a 'brainpy.rnn.Network'.") - - # fetch all feedforward edges in two subgraphs - all_ff_edges = set() - for node in senders + receivers: - if isinstance(node, FrozenNetwork): - raise TypeError(f"Cannot connect {FrozenNetwork.__name__} to other Nodes.") - if isinstance(node, Network): - all_ff_edges.update(set(node.ff_edges)) - - # fetch all feedback edges in two subgraphs - all_fb_edges = set() - for node in senders + receivers: - if isinstance(node, FrozenNetwork): - raise TypeError(f"Cannot connect {FrozenNetwork.__name__} to other Nodes.") - if isinstance(node, Network): - all_fb_edges.update(set(node.fb_edges)) - - # create edges between output nodes of the - # subgraph 1 and input nodes of the subgraph 2. - all_senders = set() - for node in senders: - if isinstance(node, Network) and not isinstance(node, FrozenNetwork): - all_senders.update(node.exit_nodes) - else: - all_senders.add(node) - all_receivers = set() - for node in receivers: - if isinstance(node, Network) and not isinstance(node, FrozenNetwork): - all_receivers.update(node.entry_nodes) - else: - all_receivers.add(node) - - return all_nodes, all_ff_edges, all_fb_edges, all_senders, all_receivers - - -def _reorganize_many2one(ff_edges, fb_edges): - """Reorganize the many-to-one connections. - - If some node whose "data_type" is :py:class:`brainpy.nn.datatypes.SingleData` receives - multiple feedforward or feedback connections, we should concatenate all feedforward - inputs (or feedback inputs) into one instance of :py:class:`brainpy.nn.Concat`, then - the new Concat instance feeds into this node. - - """ - from brainpy.compat.nn.nodes.base import Concat - - new_nodes = [] - - # find parents according to the child - ff_senders = dict() - for edge in ff_edges: - sender, receiver = edge - if receiver not in ff_senders: - ff_senders[receiver] = [sender] - else: - ff_senders[receiver].append(sender) - for receiver, senders in ff_senders.items(): - if isinstance(receiver.data_pass, SingleData): - if len(senders) > 1: - concat_nodes = [node for node in senders if isinstance(node, Concat)] - if len(concat_nodes) == 1: - concat = concat_nodes[0] - for sender in senders: - if sender != concat: - ff_edges.remove((sender, receiver)) - ff_edges.add((sender, concat)) - else: - concat = Concat() - for sender in senders: - ff_edges.remove((sender, receiver)) - ff_edges.add((sender, concat)) - ff_edges.add((concat, receiver)) - new_nodes.append(concat) - - # find parents according to the child - fb_senders = dict() - for edge in fb_edges: - sender, receiver = edge - if receiver not in fb_senders: - fb_senders[receiver] = [sender] - else: - fb_senders[receiver].append(sender) - for receiver, senders in fb_senders.items(): - if isinstance(receiver.data_pass, SingleData): - if len(senders) > 1: - concat_nodes = [node for node in senders if isinstance(node, Concat)] - if len(concat_nodes) == 1: - concat = concat_nodes[0] - for sender in senders: - if sender != concat: - fb_edges.remove((sender, receiver)) - ff_edges.add((sender, concat)) - else: - concat = Concat() - for sender in senders: - fb_edges.remove((sender, receiver)) - ff_edges.add((sender, concat)) - fb_edges.add((concat, receiver)) - new_nodes.append(concat) - - return new_nodes, ff_edges, fb_edges - - -def merge( - node: Node, - *other_nodes: Node, - inplace: bool = False, - name: str = None, - need_detect_cycle=True -) -> Network: - """Merge different :py:class:`~.Node` or :py:class:`brainpy.nn.base.Network` - instances into a single :py:class:`brainpy.nn.base.Network` instance. - - :py:class:`~.Node` instances contained in the network to merge will be - gathered in a single network, along with all previously defined connections - between them, if they exists. - - You can also perform this operation using the ``&`` operator:: - - network = (node1 >> node2) & (node1 >> node3)) - - This is equivalent to:: - - network = merge((node1 >> node2), (node1 >> node3)) - - The inplace operator can also be used:: - - network &= other_network - - Parameters - ---------- - node: Network, Node - First node or network to merge. - *other_nodes : Network, Node - All nodes to merge. - inplace: bool, default to False - If `True`, then will update `node` inplace. If `node` is not a Network - instance, this parameter will causes the function to raise an error. - name: str, optional - Name of the resulting Network. - need_detect_cycle: bool - Whether need to detect the cycle defined in the graph. - - Returns - ------- - Network - A new :py:class:`brainpy.nn.base.Network` instance. - """ - # checking - for n in other_nodes + (node,): - if not isinstance(n, Node): - raise TypeError(f"Impossible to merge nodes: object {type(n)} is not a Node instance.") - - # get all node and edges - all_nodes = set() - all_ff_edges = set() - all_fb_edges = set() - for n in other_nodes + (node,): - if isinstance(n, FrozenNetwork): - raise TypeError(f'{FrozenNetwork.__name__} cannot merge with other nodes.') - # fuse models nodes and edges (right side argument) - if isinstance(n, Network): - all_nodes |= set(n.lnodes) - all_ff_edges |= set(n.ff_edges) - all_fb_edges |= set(n.fb_edges) - elif isinstance(n, Node): - all_nodes.add(n) - - # reorganize - new_nodes, all_ff_edges, all_fb_edges = _reorganize_many2one(all_ff_edges, all_fb_edges) - all_nodes.update(new_nodes) - - # detect cycles in the graph flow - all_nodes = tuple(all_nodes) - all_ff_edges = tuple(all_ff_edges) - all_fb_edges = tuple(all_fb_edges) - if need_detect_cycle: - if graph_flow.detect_cycle(all_nodes, all_ff_edges): - raise ValueError('We detect cycles in feedforward connections. ' - 'Maybe you should replace some connection with ' - 'as feedback ones.') - if graph_flow.detect_cycle(all_nodes, all_fb_edges): - raise ValueError('We detect cycles in feedback connections. ') - - if inplace: - if not isinstance(node, Network) or isinstance(node, FrozenNetwork): - raise ValueError(f"Impossible to merge nodes inplace: " - f"{node} is not a {Network.__name__} instance.") - return node.replace_graph(nodes=all_nodes, - ff_edges=all_ff_edges, - fb_edges=all_fb_edges) - - else: - return Network(nodes=all_nodes, - ff_edges=all_ff_edges, - fb_edges=all_fb_edges, - name=name) - - -def ff_connect( - senders: Union[Node, Sequence[Node], Set[Node]], - receivers: Union[Node, Set[Node]], - inplace: bool = False, - name: str = None, - need_detect_cycle=True -) -> Network: - """Connect two sequences of :py:class:`~.Node` instances to form - a :py:class:`brainpy.nn.base.Network` instance. `senders` output will be used as - input for `receivers` in the created network. This is similar to a - function composition operation: - - .. math:: - - network(x) = (sender \\circ receiver)(x) = receiver(sender(x)) - - You can also perform this operation using the ``>>`` operator:: - - network = sender >> receiver - - Or using this function:: - - network = ff_connect(sender, receiver) - - - `sender` and `receiver` can also be :py:class:`brainpy.nn.base.Network` instances. In this - case, the new :py:class:`brainpy.nn.base.Network` created will contain all nodes previously - contained in all the networks, and link all `node1` outputs to all `node2` - inputs. This allows to chain the ``>>`` operator:: - - step1 = node0 >> node1 # this is a network - step2 = step1 >> node2 # this is another - - - `node1` can finally be lists or tuples of nodes. In this - case, all `node1` outputs will be linked to a :py:class:`~.Concat` node to - concatenate them, and the :py:class:`~.Concat` node will be linked to all - `node2` inputs:: - - # many-concat-to-one - network = [node1, node2, ..., node] >> node_out - - - If you do not want to concatenate all input nodes, you can use `set` to - wrap all input nodes at once. Then, `node2` will receive multiple inputs - defined in `node1`:: - - # many-to-one - network = {node1, node2, ..., node_N} >> node_out - - - In the case of "one-to-many" feedforward connection, `node2` only support - a set of node. Using list or tuple to wrap multiple receivers will concatenate - all nodes in the receiver end. This will cause errors:: - - # wrong operation of one-to-many - network = node_in >> {node1, node2, ..., node_N} - - # correct operation of one-to-many - network = node_in >> {node1, node2, ..., node_N} - - - "many-to-many" connection is also allowed. - - You can still use the ``>>`` operator in this situation, - except for many-to-many nodes connections:: - - # many-to-many - {node1, node2, ..., node} >> {node1, node2, ..., node} - - Parameters - ---------- - senders, receivers : Node, sequence of Node - Nodes or sequence of nodes to connect feedforward connections. - inplace: bool - Whether inplace update the node. - name: str, optional - Name for the chaining Network. - need_detect_cycle: bool - Whether we need to detect cycles exit in the final network. - - Returns - ------- - Network - A :py:class:`brainpy.nn.base.Network` instance chaining the nodes. - - Notes - ----- - - Be careful to how you link the different nodes: `reservoirpy` does not - allow to have circular dependencies between them:: - - network = node1 >> node2 # fine - network = node1 >> node2 >> node1 # raises! data would flow in - # circles forever... - """ - - all_nodes, all_ff_edges, all_fb_edges, ff_senders, ff_receivers = _retrieve_nodes_and_edges(senders, receivers) - new_ff_edges = set(product(ff_senders, ff_receivers)) - - # all outputs from subgraph 1 are connected to - # all inputs from subgraph 2. - all_ff_edges |= new_ff_edges - - # reorganize - new_nodes, all_ff_edges, all_fb_edges = _reorganize_many2one(all_ff_edges, all_fb_edges) - all_nodes.update(new_nodes) - - # detect cycles in the graph flow - all_nodes = tuple(all_nodes) - all_ff_edges = tuple(all_ff_edges) - all_fb_edges = tuple(all_fb_edges) - if need_detect_cycle: - if graph_flow.detect_cycle(all_nodes, all_ff_edges): - raise ValueError('We detect cycles in feedforward connections. ' - 'Maybe you should replace some connection with ' - 'as feedback ones.') - if graph_flow.detect_cycle(all_nodes, all_fb_edges): - raise ValueError('We detect cycles in feedback connections. ') - - # feedforward - if inplace: - if not isinstance(receivers, Network): - raise TypeError(f'Cannot inplace update the feedback connection of a Node instance: {receivers}') - if name is not None: - raise ValueError('Cannot set name when inplace=True.') - receivers.replace_graph(nodes=all_nodes, - ff_edges=all_ff_edges, - fb_edges=all_fb_edges) - return receivers - else: - return Network(nodes=all_nodes, - ff_edges=all_ff_edges, - fb_edges=all_fb_edges, - name=name) - - -def fb_connect( - senders: Union[Node, Sequence[Node], Set[Node]], - receivers: Union[Node, Set[Node]], - inplace: bool = False, - name: str = None, - need_detect_cycle=True -) -> Node: - """Create a feedback connection from ``sender`` node to ``receiver`` node. - Feedbacks nodes will be called at runtime using data from the previous call. - - You can also perform this operation using the ``<<`` operator. - - Which means that a feedback connection is now created between `node1` and - `node2`. In other words, the forward function of `node1` depends on the - previous output of `node2`: - - .. math:: - \\mathrm{node1}(x_t) = \\mathrm{node1}(x_t, \\mathrm{node2}(x_{t - 1})) - - You can also use this function to define feedback:: - - node1 = fb_connect(node1, node2) - # without copy (node1 is the same object throughout) - node1 = fb_connect(node1, node2, inplace=True, name="n1_copy") - - Parameters - ---------- - receivers : Node - Node receiving feedback. - senders : GenericNode - Node or Network sending feedback - inplace : bool, defaults to False - If `True`, then the function returns a copy of `node`. - name : str, optional - Name of the copy of `node` if `inplace` is `True`. - need_detect_cycle: bool - Whether we need to detect cycles in the defined network. - - Returns - ------- - Network - A network with feedback connections. - """ - - all_nodes, all_ff_edges, all_fb_edges, fb_senders, fb_receivers = _retrieve_nodes_and_edges(senders, receivers) - - # detect whether the node implement its own "init_fb_conn()" function - for node in fb_receivers: - if not node.is_feedback_input_supported: - raise ValueError(f'Establish a feedback connection to \n' - f'{node}\n' - f'is not allowed. Because this node does not ' - f'support feedback connections.') - - # detect feedforward cycle - if need_detect_cycle: - all_nodes1 = list(all_nodes) - all_ff_edges1 = tuple(all_ff_edges) - if graph_flow.detect_cycle(all_nodes1, all_ff_edges1): - raise ValueError('We detect cycles in feedforward connections. ' - 'Maybe you should replace some connection with ' - 'as feedback ones.') - # establish feedback connections - new_fb_edges = set(product(fb_senders, fb_receivers)) - - # all outputs from subgraph 1 are connected to - # all inputs from subgraph 2. - all_fb_edges |= new_fb_edges - - # reorganize - new_nodes, all_ff_edges, all_fb_edges = _reorganize_many2one(all_ff_edges, all_fb_edges) - all_nodes.update(new_nodes) - - # detect cycles in the graph flow - all_nodes = tuple(all_nodes) - all_ff_edges = tuple(all_ff_edges) - all_fb_edges = tuple(all_fb_edges) - if need_detect_cycle: - if graph_flow.detect_cycle(all_nodes, all_fb_edges): - raise ValueError('We detect cycles in feedback connections. ') - - # feedback - if inplace: - if not isinstance(receivers, Network): - raise TypeError(f'Cannot inplace update the feedback connection of a Node instance: {receivers}') - if name is not None: - raise ValueError('Cannot set name when inplace=True.') - receivers.replace_graph(nodes=all_nodes, - ff_edges=all_ff_edges, - fb_edges=all_fb_edges) - return receivers - else: - return Network(nodes=all_nodes, - ff_edges=all_ff_edges, - fb_edges=all_fb_edges, - name=name) - - -def select( - node: Node, - index: Union[int, Sequence[int], Tensor, slice], - name: str = None -): - if isinstance(node, Network) and len(node.exit_nodes) != 1: - raise ValueError(f'Cannot select subsets of states when Network instance ' - f'"{node}" has multiple output nodes.') - return ff_connect(node, Select(index=index), name=name, need_detect_cycle=False) - - -def concatenate(nodes: Sequence[Node], axis=-1, name=None): - right = Concat(axis=axis) - model = Network(name=name) - for node in nodes: - if isinstance(node, FrozenNetwork): - raise ValueError('Cannot concat a Frozen network.') - if isinstance(node, Network) and len(node.exit_nodes) > 1: - raise ValueError(f'Cannot concatenate network which has {len(node.exit_nodes)} ' - f'output nodes with other nodes.') - model = merge(model, - ff_connect(node, right, need_detect_cycle=False), - inplace=True, - need_detect_cycle=False) - return model diff --git a/brainpy/compat/nn/runners/__init__.py b/brainpy/compat/nn/runners/__init__.py deleted file mode 100644 index 4ede9194f..000000000 --- a/brainpy/compat/nn/runners/__init__.py +++ /dev/null @@ -1,29 +0,0 @@ -# -*- coding: utf-8 -*- - - -""" -This module provides various running and training algorithms -for various neural networks. - -The supported training algorithms include - -- offline training methods, like ridge regression, linear regression, etc. -- online training methods, like recursive least squares (RLS, or Force Learning), - least mean squares (LMS), etc. -- back-propagation learning method -- and others - -The supported neural networks include - -- reservoir computing networks, -- artificial recurrent neural networks, -- and others. -""" - - -from .rnn_runner import * -from .rnn_trainer import * -from .online_trainer import * -from .offline_trainer import * -from .back_propagation import * - diff --git a/brainpy/compat/nn/runners/back_propagation.py b/brainpy/compat/nn/runners/back_propagation.py deleted file mode 100644 index 85757e63e..000000000 --- a/brainpy/compat/nn/runners/back_propagation.py +++ /dev/null @@ -1,762 +0,0 @@ -# -*- coding: utf-8 -*- - -import time -from typing import Union, Dict, Callable, Sequence - -import jax.numpy as jnp -import numpy as np -from jax import jit, random as jr -from jax.tree_util import tree_map - -import brainpy.losses as losses -import brainpy.math as bm -import brainpy.optimizers as optim -from brainpy.errors import UnsupportedError -from brainpy.compat.nn.base import Node, Network -from brainpy.compat.nn.utils import check_data_batch_size, serialize_kwargs -from brainpy.tools.checking import check_dict_data, check_float -from brainpy.types import Tensor -from .rnn_trainer import RNNTrainer - -__all__ = [ - 'BPTT', - 'BPFF', -] - - -class BPTT(RNNTrainer): - """ - The trainer implementing back propagation through time (BPTT) - algorithm for recurrent neural networks. - - """ - - def __init__( - self, - target: Node, - - # arguments for BPTT trainer - loss: Union[str, Callable], # loss function - optimizer: optim.Optimizer = None, # optimizer - max_grad_norm=None, - shuffle_data: bool = True, - jit: bool = True, - - # common arguments for RNNTrainer - **kwargs - ): - super(BPTT, self).__init__(target=target, **kwargs) - - # jit settings - if isinstance(jit, bool): - self.jit = {'fit': jit, 'predict': jit, 'loss': jit} - elif isinstance(jit, dict): - jit = {key: val for key, val in jit.items()} - self.jit = {'fit': jit.pop('fit', True), - 'predict': jit.pop('predict', True), - 'loss': jit.pop('loss', True)} - if len(jit): - raise ValueError(f'Unknown jit setting for {jit.keys()}') - else: - raise ValueError(f'Unknown "jit" setting: {jit}') - - # optimizer - if optimizer is None: - lr = optim.ExponentialDecay(lr=0.025, decay_steps=1, decay_rate=0.99975) - optimizer = optim.Adam(lr=lr) - self.optimizer = optimizer - - # loss - if isinstance(loss, str): - loss = getattr(losses, loss) - elif callable(loss): - loss = loss - else: - raise UnsupportedError(f'Do not support {type(loss)} to specify the loss function. ' - f'We only support str and callable function.') - self.loss_fun = loss - self._train_losses = None - self._test_losses = None - self._f_shuffle = None - - # target/output mapping types - self._mapping_type = None - - # functions - self._f_loss = dict() - self._f_train = dict() - self._f_grad = dict() - - # training parameters - self.max_grad_norm = max_grad_norm # gradient clipping - self.shuffle_data = shuffle_data - - # initialize the optimizer - if not self.target.is_initialized: - raise ValueError('Please initialize the target model first by calling "initialize()" function.') - self.optimizer.register_vars(self.target.vars().subset(bm.TrainVar).unique()) - - def __repr__(self): - name = self.__class__.__name__ - prefix = ' ' * len(name) - return (f'{name}(target={self.target}, \n\t' - f'{prefix}jit={self.jit}, \n\t' - f'{prefix}loss={self.loss_fun}, \n\t' - f'{prefix}optimizer={self.optimizer})') - - def predict( - self, - xs: Union[Tensor, Dict[str, Tensor]], - forced_states: Dict[str, Tensor] = None, - forced_feedbacks: Dict[str, Tensor] = None, - initial_states: Union[Tensor, Dict[str, Tensor]] = None, - initial_feedbacks: Dict[str, Tensor] = None, - reset: bool = True, - shared_kwargs: Dict = None, - **kwargs - ): - """Predict a series of input data with the given target model. - - This function use the JIT compilation to accelerate the model simulation. - Moreover, it can automatically monitor the node variables, states, inputs, - feedbacks and its output, if users want. - - Parameters - ---------- - xs: Tensor, dict - The feedforward input data. It must be a 3-dimensional data - which has the shape of `(num_sample, num_time, num_feature)`. - shared_kwargs: dict - Shared keyword arguments for the given target model. - reset: bool - Whether reset the model states. Default True. - - forced_states: dict - The fixed node states. Similar with ``xs``, each tensor in - ``forced_states`` must be a tensor with the shape of - `(num_sample, num_time, num_feature)`. Default None. - - .. versionadded:: 2.1.4 - - forced_feedbacks: dict - The fixed feedback states. Similar with ``xs``, each tensor in - ``forced_states`` must be a tensor with the shape of - `(num_sample, num_time, num_feature)`. Default None. - - .. versionadded:: 2.1.4 - - initial_states: JaxArray, ndarray, dict - The initial states. Each tensor in ``initial_states`` must be a - tensor with the shape of `(num_sample, num_feature)`. - - .. versionadded:: 2.1.4 - - initial_feedbacks: dict - The initial feedbacks for the node in the network model. - Each tensor in ``initial_feedbacks`` must be a - tensor with the shape of `(num_sample, num_feature)`. - - .. versionadded:: 2.1.4 - - Returns - ------- - output: Tensor, dict - The model output. - """ - # check forced states/feedbacks - return super(BPTT, self).predict(xs=xs, - forced_states=forced_states, - forced_feedbacks=forced_feedbacks, - initial_states=initial_states, - initial_feedbacks=initial_feedbacks, - reset=reset, - shared_kwargs=shared_kwargs) - - def fit( - self, - train_data: Union[Callable, Sequence], - test_data: Union[Callable, Sequence] = None, - num_batch: int = 32, - num_train: int = 100, - num_report: int = 100, - reset: bool = True, - shared_kwargs: Dict = None, - forced_states: Dict[str, Tensor] = None, - forced_feedbacks: Dict[str, Tensor] = None, - initial_states: Union[Tensor, Dict[str, Tensor]] = None, - initial_feedbacks: Dict[str, Tensor] = None, - ): - """ - Fit the target model according to the given training and testing data. - - Parameters - ---------- - train_data: callable, sequence of data - It can be a callable function, or a tuple/list representing `(X, Y)` data. - - Callable. This function should return a pair of `(X, Y)` data - - Sequence. It should be a pair of `(X, Y)` train set. - - ``X``: should be a tensor or a dict of tensors with the shape of - `(num_sample, num_time, num_feature)`, where `num_sample` is - the number of samples, `num_time` is the number of the time step, - and `num_feature` is the number of features. - - ``Y``: Target values. A tensor or a dict of tensors. - - If the shape of each tensor is `(num_sample, num_feature)`, - then we will only fit the model with the only last output. - - If the shape of each tensor is `(num_sample, num_time, num_feature)`, - then the fitting happens on the whole data series. - test_data: callable, sequence of data - Same as the ``train_data``. It can be a callable function, - or a tuple/list representing `(X, Y)` data. - num_batch: int - The batch size. Default 32. This setting is used when users provide - the ``train_data`` and ``test_data`` as a pair of `(X, Y)` data, rather - than a function. - num_train: int - The number of training epoch. Default 100. - num_report: int - The number of step to report the progress. Default 100 training steps. - reset: bool - Whether reset the initial states of the target model. - shared_kwargs: dict - The shared keyword arguments for the target models. - forced_states: dict - The fixed node states. Similar with ``xs``, each tensor in - ``forced_states`` must be a tensor with the shape of - `(num_sample, num_time, num_feature)`. - - .. versionadded:: 2.1.4 - - forced_feedbacks: dict - The fixed feedback states. Similar with ``xs``, each tensor in - ``forced_states`` must be a tensor with the shape of - `(num_sample, num_time, num_feature)`. - - .. versionadded:: 2.1.4 - - initial_states: JaxArray, ndarray, dict - The initial states. Each tensor in ``initial_states`` must be a - tensor with the shape of `(num_sample, num_feature)`. - - .. versionadded:: 2.1.4 - - initial_feedbacks: dict - The initial feedbacks for the node in the network model. - Each tensor in ``initial_feedbacks`` must be a - tensor with the shape of `(num_sample, num_feature)`. - - .. versionadded:: 2.1.4 - - """ - # training the model - all_train_losses = [] - all_test_losses = [] - train_i = 0 - t0 = time.time() - for _ in range(num_train): - train_data_ = self._get_train_data(train_data, num_batch) - - # training set - for x, y in train_data_: - self._set_initial_states(initial_states) - self._set_initial_feedbacks(initial_feedbacks) - batch_size = check_data_batch_size(x) - if reset: - self.target.initialize(batch_size) - loss = self.f_train(shared_kwargs)(x, y, - forced_states=forced_states, - forced_feedbacks=forced_feedbacks) - all_train_losses.append(loss) - train_i += 1 - if train_i % num_report == 0: - t1 = time.time() - print(f'Train {train_i} steps, use {t1 - t0:.4f} s, train loss {round(float(loss), 5)}') - t0 = t1 - - # testing set - test_data_ = self._get_test_data(test_data, num_batch) - if test_data_ is not None: - for x, y in test_data_: - batch_size = check_data_batch_size(x) - if reset: - self.target.initialize(batch_size) - loss = self.f_loss(shared_kwargs)(x, y, - forced_states=forced_states, - forced_feedbacks=forced_feedbacks) - all_test_losses.append(loss) - - self._train_losses = bm.asarray(all_train_losses) - self._test_losses = bm.asarray(all_test_losses) - - def f_grad(self, shared_kwargs=None) -> Callable: - """Get gradient function.""" - shared_kwargs_str = serialize_kwargs(shared_kwargs) - if shared_kwargs_str not in self._f_grad: - self._f_grad[shared_kwargs_str] = self._make_f_grad(shared_kwargs) - return self._f_grad[shared_kwargs_str] - - def f_loss(self, shared_kwargs=None) -> Callable: - """Get loss function.""" - shared_kwargs_str = serialize_kwargs(shared_kwargs) - if shared_kwargs_str not in self._f_loss: - self._f_loss[shared_kwargs_str] = self._make_f_loss(shared_kwargs) - if self.jit['loss']: - dyn_vars = self.target.vars() - dyn_vars.update(self.dyn_vars) - self._f_loss[shared_kwargs_str] = bm.jit(self._f_loss[shared_kwargs_str], - dyn_vars=dyn_vars) - return self._f_loss[shared_kwargs_str] - - def f_train(self, shared_kwargs=None) -> Callable: - """Get training function.""" - shared_kwargs_str = serialize_kwargs(shared_kwargs) - if shared_kwargs_str not in self._f_train: - self._f_train[shared_kwargs_str] = self._make_f_train(shared_kwargs) - return self._f_train[shared_kwargs_str] - - @property - def train_losses(self): - """Training loss.""" - return self._train_losses - - @property - def mapping_type(self): - """Mapping type for the output and the target.""" - return self._mapping_type - - def _make_f_loss(self, shared_kwargs: Dict = None): - if shared_kwargs is None: shared_kwargs = dict() - if not isinstance(shared_kwargs, dict): - raise ValueError(f'Only supports dict for "shared_kwargs". ' - f'But got {type(shared_kwargs)}: {shared_kwargs}') - - def loss_fun(inputs, targets, forced_states=None, forced_feedbacks=None): - inputs = self._format_xs(inputs) - targets = self._format_ys(targets) - num_batch, num_step = list(inputs.values())[0].shape[:2] - forced_states = self._check_forced_states(forced_states, num_batch, num_step) - forced_feedbacks = self._check_forced_feedbacks(forced_feedbacks, num_batch, num_step) - inputs = {k: bm.moveaxis(v, 0, 1) for k, v in inputs.items()} - outputs, _ = self._predict(xs=inputs, - shared_kwargs=shared_kwargs, - forced_states=forced_states, - forced_feedbacks=forced_feedbacks) - outputs = self._format_ys(outputs) - loss = 0. - for key, output in outputs.items(): - loss += self.loss_fun(output, targets[key]) - return loss - - return loss_fun - - def _make_f_grad(self, shared_kwargs: Dict = None): - _f_loss_internal = self._make_f_loss(shared_kwargs) - dyn_vars = self.target.vars() - dyn_vars.update(self.dyn_vars) - tran_vars = dyn_vars.subset(bm.TrainVar) - return bm.grad(_f_loss_internal, - dyn_vars=dyn_vars.unique(), - grad_vars=tran_vars.unique(), - return_value=True) - - def _make_f_train(self, shared_kwargs: Dict = None): - if shared_kwargs is None: - shared_kwargs = dict() - elif not isinstance(shared_kwargs, dict): - raise ValueError(f'Only supports dict for "shared_kwargs". ' - f'But got {type(shared_kwargs)}: {shared_kwargs}') - - def train_func(inputs, targets, forced_states=None, forced_feedbacks=None): - inputs = self._format_xs(inputs) - targets = self._format_ys(targets) - grads, loss = self.f_grad(shared_kwargs)(inputs, - targets, - forced_states=forced_states, - forced_feedbacks=forced_feedbacks) - if self.max_grad_norm is not None: - check_float(self.max_grad_norm, 'max_grad_norm', min_bound=0.) - grads = bm.clip_by_norm(grads, self.max_grad_norm) - self.optimizer.update(grads) - return loss - - if self.jit['fit']: - dyn_vars = self.target.vars() - dyn_vars.update(self.dyn_vars) - dyn_vars.update(self.optimizer.vars()) - train_func = bm.jit(train_func, dyn_vars=dyn_vars.unique()) - return train_func - - def _format_ys(self, ys): - if isinstance(ys, (bm.ndarray, jnp.ndarray)): - if isinstance(self.target, Network): - if len(self.target.exit_nodes) != 1: - raise ValueError(f'The network {self.target} has ' - f'{len(self.target.exit_nodes)} ' - f'output nodes, while we only got ' - f'one output data.') - ys = {self.target.exit_nodes[0].name: ys} - else: - ys = {self.target.name: ys} - else: - exit_nodes = self.target.exit_nodes if isinstance(self.target, Network) else [self.target] - for node in exit_nodes: - if node.name not in ys: - raise ValueError(f'The network has output node {node.name}, ' - f'however, we did not get the corresponding ' - f'output targets.') - check_dict_data(ys, key_type=str, val_type=(bm.ndarray, jnp.ndarray)) - return ys - - def _get_train_data(self, train_data, num_batch): - # training dataset - if callable(train_data): - train_data = self._get_data_by_method1(train_data, num_batch) - elif isinstance(train_data, (tuple, list)): - if len(train_data) != 2: - raise ValueError(f"Must be (X, Y) pair, but got a sequence with " - f"length {len(train_data)}") - train_data = self._get_data_by_method2(train_data, - num_batch=num_batch, - shuffle=self.shuffle_data) - else: - raise ValueError(f'Train data does not support {type(train_data)}. ') - return train_data - - def _get_test_data(self, test_data, num_batch): - # testing dataset - if test_data is None: - test_data = None - elif callable(test_data): - test_data = self._get_data_by_method1(test_data, num_batch) - elif isinstance(test_data, (tuple, list)): - assert len(test_data) == 2, f"Must be (X, Y) pair, but got a sequence with length {len(test_data)}" - test_data = self._get_data_by_method2(test_data, - num_batch=num_batch, - shuffle=False) - else: - raise ValueError(f'Test data does not support {type(test_data)}. ') - return test_data - - def _get_data_by_method1(self, dataset, num_batch): - for xs, ys in dataset(): - xs = self._format_xs(xs) - ys = self._format_ys(ys) - yield xs, ys - - def _shuffle(self, xs, ys): - key = jr.PRNGKey(seed=np.random.randint(0, 100000)) - if self._f_shuffle is None: - def shuffle(xs, ys, key): - xs = tree_map(lambda x: jr.permutation(key, x, axis=0), xs) - ys = tree_map(lambda y: jr.permutation(key, y, axis=0), ys) - return xs, ys - - self._f_shuffle = jit(shuffle) - return self._f_shuffle(xs, ys, key) - - def _get_data_by_method2(self, dataset, num_batch, shuffle=False, ): - assert isinstance(dataset, (tuple, list)) and len(dataset) == 2 - xs, ys = dataset - xs = self._format_xs(xs) - num_sample = self._get_xs_info(xs) - ys = self._format_ys(ys) - if shuffle: - xs, ys = self._shuffle(xs, ys) - - for data_idx in range(0, num_sample, num_batch): - if (data_idx + num_batch) > num_sample: - inputs = {k: v[data_idx:] for k, v in xs.items()} - targets = {k: v[data_idx:] for k, v in ys.items()} - else: - inputs = {k: v[data_idx: data_idx + num_batch] for k, v in xs.items()} - targets = {k: v[data_idx: data_idx + num_batch] for k, v in ys.items()} - yield inputs, targets - - def _get_xs_info(self, xs): - input_shapes = {} - if isinstance(self.target, Network): - for node in self.target.entry_nodes: - name = self.target.entry_nodes[0].name - input_shapes[name] = node._feedforward_shapes[name] - else: - name = self.target.name - input_shapes[name] = self.target._feedforward_shapes[name] - num_batch_sizes = [] - for key, val in xs.items(): - if key not in input_shapes: - raise ValueError(f'Cannot find {key} in the required inputs. Please check!') - shape = input_shapes[key] - if bm.ndim(val) != len(shape) + 1: - raise ValueError(f'Each tensor in "xs" must be a tensor of shape ' - f'(num_sample, num_time, {str(shape[1:])[1:-1]}). ' - f'But we got {val.shape}.') - num_batch_sizes.append(val.shape[0]) - if len(set(num_batch_sizes)) != 1: - raise ValueError(f'Number of batch size is different across tensors in ' - f'the provided "xs". We got {set(num_batch_sizes)}.') - return num_batch_sizes[0] - - -class BPFF(BPTT): - """ - The trainer implementing back propagation algorithm - for feedforward neural networks. - - """ - - def __init__( - self, target: Node, **kwargs - ): - super(BPFF, self).__init__(target=target, **kwargs) - - def predict( - self, - xs: Union[Tensor, Dict[str, Tensor]], - initial_states: Union[Tensor, Dict[str, Tensor]] = None, - initial_feedbacks: Dict[str, Tensor] = None, - reset: bool = True, - shared_kwargs: Dict = None, - forced_states: Dict[str, Tensor] = None, - forced_feedbacks: Dict[str, Tensor] = None, - **kwargs - ): - """Predict a series of input data with the given target model. - - This function use the JIT compilation to accelerate the model simulation. - Moreover, it can automatically monitor the node variables, states, inputs, - feedbacks and its output. - - Parameters - ---------- - xs: Tensor, dict - The feedforward input data. It must be a 3-dimensional data - which has the shape of `(num_sample, num_time, num_feature)`. - forced_states: None - The fixed node states. - forced_feedbacks: None - The fixed feedback states. - initial_states: JaxArray, ndarray, dict - The initial states. Each tensor in ``initial_states`` must be a - tensor with the shape of `(num_sample, num_feature)`. - initial_feedbacks: dict - The initial feedbacks for the node in the network model. - Each tensor in ``initial_feedbacks`` must be a - tensor with the shape of `(num_sample, num_feature)`. - reset: bool - Whether reset the model states. - shared_kwargs: optional, dict - The shared arguments across different layers. - - Returns - ------- - output: Tensor, dict - The model output. - """ - # format input data - xs = self._format_ys(xs) - num_batch = self._get_xs_info(xs) - # get forced data - forced_states = self._check_forced_states(forced_states, num_batch) - forced_feedbacks = self._check_forced_feedbacks(forced_feedbacks, num_batch) - # set initial states - self._set_initial_states(initial_states) - self._set_initial_feedbacks(initial_feedbacks) - # reset the model states - if reset: - self.target.initialize(num_batch) - # init monitor - for key in self.mon.var_names: - self.mon[key] = [] # reshape the monitor items - # prediction - outputs, hists = self._predict(xs=xs, - forced_states=forced_states, - forced_feedbacks=forced_feedbacks, - shared_kwargs=shared_kwargs) - # post-running for monitors - for key in hists.keys(): - self.mon[key] = hists[key] - if self.numpy_mon_after_run: - self.mon.ts = np.asarray(self.mon.ts) - for key in hists.keys(): - self.mon[key] = np.asarray(self.mon[key]) - return outputs - - def _check_forced_states(self, forced_states, num_batch): - iter_forced_states = dict() - if forced_states is not None: - if isinstance(self.target, Network): - nodes = [node.name for node in self.target.lnodes] - if not isinstance(forced_states, dict): - raise ValueError('"forced_states" must be a dict of (str, Tensor)') - for key, tensor in forced_states.items(): - if not isinstance(key, str): - raise ValueError(f'"forced_states" must be a dict of (str, tensor). ' - f'But got a dict of ({type(key)}, {type(tensor)})') - if key not in nodes: - raise ValueError(f'Node "{key}" is not defined in the target model. ' - f'We only detect: \n{self.target.lnodes}') - if not isinstance(tensor, (bm.ndarray, jnp.ndarray)): - raise ValueError(f'"forced_states" must a dict of (str, tensor), ' - f'while we got ({type(key)}, {type(tensor)})') - if bm.ndim(tensor) != self.target[key].state.ndim: - raise ValueError(f'Must be a tensor with shape of (num_batch, ' - f'{str(self.target[key].state.shape)[1:-1]}), ' - f'but we got {tensor.shape}') - if tensor.shape[0] != num_batch: - raise ValueError(f'The number of the batch size ({tensor.shape[0]}) ' - f'of the forced state of {key} does not ' - f'match with the batch size in inputs {num_batch}.') - if self.target[key].output_shape[1:] != tensor.shape[2:]: - raise UnsupportedError(f'The forced state of {key} has the shape of ' - f'{tensor.shape}, which is not consistent with ' - f'its output shape {self.target[key].output_shape}. ' - f'Each tensor in forced state should have the shape ' - f'of (num_sample, num_time, num_feature) or ' - f'(num_sample, num_feature).') - iter_forced_states[key] = bm.moveaxis(tensor, 0, 1) # shape of (num_time, num_sample, num_feature) - else: - raise UnsupportedError('We do not support forced feedback state ' - 'for a single brainpy.nn.Node instance') - return iter_forced_states - - def _check_forced_feedbacks(self, forced_feedbacks, num_batch): - iter_forced_feedbacks = dict() - if forced_feedbacks is not None: - if isinstance(self.target, Network): - if not isinstance(forced_feedbacks, dict): - raise ValueError('"forced_feedbacks" must be a dict of (str, Tensor)') - feedback_node_names = [node.name for node in self.target.feedback_nodes] - for key, tensor in forced_feedbacks.items(): - if not isinstance(key, str): - raise ValueError(f'"forced_feedbacks" must be a dict of (str, tensor). ' - f'But got a dict of ({type(key)}, {type(tensor)})') - if key not in feedback_node_names: - raise ValueError(f'{self.target} has no feedback node {key}, ' - f'it only has {feedback_node_names}') - if not isinstance(tensor, (bm.ndarray, jnp.ndarray)): - raise ValueError('"forced_feedbacks" must a dict of (str, tensor), ' - 'while we got ({type(key)}, {type(tensor)})') - if bm.ndim(tensor) != self.target[key].fb_output.ndim: - raise ValueError(f'Must be a tensor with shape of (num_batch, ' - f'{str(self.target[key].fb_output.shape)[1:-1]}), ' - f'but we got {tensor.shape}') - if tensor.shape[0] != num_batch: - raise ValueError(f'The number of the batch size ({tensor.shape[0]}) ' - f'of the forced feedback of {key} does not ' - f'match with the batch size in inputs {num_batch}.') - if self.target[key].output_shape[1:] != tensor.shape[2:]: - raise UnsupportedError(f'The forced feedback of {key} has the shape of ' - f'{tensor.shape}, which is not consistent with ' - f'its output shape {self.target[key].output_shape}. ' - f'Each tensor in forced feedback should have the shape ' - f'of (num_sample, num_time, num_feature) or ' - f'(num_sample, num_feature).') - iter_forced_feedbacks[key] = bm.moveaxis(tensor, 0, 1) # shape of (num_time, num_sample, num_feature) - else: - raise UnsupportedError('We do not support forced states for ' - 'a single brainpy.nn.Node instance') - return iter_forced_feedbacks - - def _predict( - self, - xs: Dict[str, Tensor], - shared_kwargs: Dict = None, - forced_states: Dict[str, Tensor] = None, - forced_feedbacks: Dict[str, Tensor] = None, - ): - """Predict the output according to the inputs. - - Parameters - ---------- - xs: dict - Each tensor should have the shape of `(num_time, num_batch, num_feature)`. - forced_states: dict - The forced state values. - forced_feedbacks: dict - The forced feedback output values. - shared_kwargs: optional, dict - The shared keyword arguments. - - Returns - ------- - outputs, hists - A tuple of pair of (outputs, hists). - """ - _predict_func = self._get_predict_func(shared_kwargs) - # rune the model - forced_states = dict() if forced_states is None else forced_states - forced_feedbacks = dict() if forced_feedbacks is None else forced_feedbacks - return _predict_func(xs, forced_states, forced_feedbacks) - - def _make_f_loss(self, shared_kwargs: Dict = None): - if shared_kwargs is None: shared_kwargs = dict() - if not isinstance(shared_kwargs, dict): - raise ValueError(f'Only supports dict for "shared_kwargs". ' - f'But got {type(shared_kwargs)}: {shared_kwargs}') - - def loss_fun(inputs, targets, forced_states=None, forced_feedbacks=None): - inputs = self._format_xs(inputs) - targets = self._format_ys(targets) - num_batch, num_step = list(inputs.values())[0].shape[:2] - forced_states = self._check_forced_states(forced_states, num_batch) - forced_feedbacks = self._check_forced_feedbacks(forced_feedbacks, num_batch) - outputs, _ = self._predict(xs=inputs, - shared_kwargs=shared_kwargs, - forced_states=forced_states, - forced_feedbacks=forced_feedbacks) - outputs = self._format_ys(outputs) - loss = 0. - for key, output in outputs.items(): - loss += self.loss_fun(output, targets[key]) - return loss - - return loss_fun - - def _get_predict_func(self, shared_kwargs: Dict = None): - if shared_kwargs is None: shared_kwargs = dict() - shared_kwargs_str = serialize_kwargs(shared_kwargs) - if shared_kwargs_str not in self._predict_func: - self._predict_func[shared_kwargs_str] = self._make_predict_func(shared_kwargs) - return self._predict_func[shared_kwargs_str] - - def _make_predict_func(self, shared_kwargs: Dict): - if not isinstance(shared_kwargs, dict): - raise ValueError(f'"shared_kwargs" must be a dict, ' - f'but got {type(shared_kwargs)}') - - def run_func(xs, forced_states, forced_feedbacks): - monitors = self.mon.var_names - return self.target(xs, - forced_states=forced_states, - forced_feedbacks=forced_feedbacks, - monitors=monitors, - **shared_kwargs) - - if self.jit['predict']: - dyn_vars = self.target.vars() - dyn_vars.update(self.dyn_vars) - run_func = bm.jit(run_func, dyn_vars=dyn_vars.unique()) - return run_func - - def _get_xs_info(self, xs): - input_shapes = {} - if isinstance(self.target, Network): - for node in self.target.entry_nodes: - name = self.target.entry_nodes[0].name - input_shapes[name] = node._feedforward_shapes[name] - else: - name = self.target.name - input_shapes[name] = self.target._feedforward_shapes[name] - num_batch_sizes = [] - for key, val in xs.items(): - if key not in input_shapes: - raise ValueError(f'Cannot find {key} in the required inputs. Please check!') - shape = input_shapes[key] - if bm.ndim(val) != len(shape): - raise ValueError(f'Each tensor in "xs" must be a tensor of shape ' - f'(num_sample, {str(shape[1:])[1:-1]}). ' - f'But we got {val.shape}.') - num_batch_sizes.append(val.shape[0]) - if len(set(num_batch_sizes)) != 1: - raise ValueError(f'Number of batch size is different across tensors in ' - f'the provided "xs". We got {set(num_batch_sizes)}.') - return num_batch_sizes[0] diff --git a/brainpy/compat/nn/runners/offline_trainer.py b/brainpy/compat/nn/runners/offline_trainer.py deleted file mode 100644 index f333deffb..000000000 --- a/brainpy/compat/nn/runners/offline_trainer.py +++ /dev/null @@ -1,297 +0,0 @@ -# -*- coding: utf-8 -*- - -from typing import Dict, Sequence, Union, Callable - -import tqdm.auto -from jax.experimental.host_callback import id_tap - -import numpy as np -from brainpy.base import Base -import brainpy.math as bm -from brainpy.errors import NoImplementationError -from brainpy.compat.nn.algorithms.offline import get, RidgeRegression, OfflineAlgorithm -from brainpy.compat.nn.base import Node, Network -from brainpy.compat.nn.utils import serialize_kwargs -from brainpy.types import Tensor -from .rnn_trainer import RNNTrainer - -__all__ = [ - 'OfflineTrainer', - 'RidgeTrainer', -] - - -class OfflineTrainer(RNNTrainer): - """Offline trainer for models with recurrent dynamics. - - Parameters - ---------- - target: Node - The target model to train. - fit_method: OfflineAlgorithm, Callable, dict, str - The fitting method applied to the target model. - - It can be a string, which specify the shortcut name of the training algorithm. - Like, ``fit_method='ridge'`` means using the Ridge regression method. - All supported fitting methods can be obtained through - :py:func:`brainpy.nn.runners.get_supported_offline_methods` - - It can be a dict, whose "name" item specifies the name of the training algorithm, - and the others parameters specify the initialization parameters of the algorithm. - For example, ``fit_method={'name': 'ridge', 'beta': 1e-4}``. - - It can be an instance of :py:class:`brainpy.nn.runners.OfflineAlgorithm`. - For example, ``fit_meth=bp.nn.runners.RidgeRegression(beta=1e-5)``. - - It can also be a callable function, which receives three arguments "targets", "x" and "y". - For example, ``fit_method=lambda targets, x, y: numpy.linalg.lstsq(x, targets)[0]``. - **kwargs - The other general parameters for RNN running initialization. - """ - - def __init__( - self, - target: Node, - fit_method: Union[OfflineAlgorithm, Callable, Dict, str] = None, - **kwargs - ): - self.true_numpy_mon_after_run = kwargs.get('numpy_mon_after_run', True) - kwargs['numpy_mon_after_run'] = False - super(OfflineTrainer, self).__init__(target=target, **kwargs) - - # training method - if fit_method is None: - fit_method = RidgeRegression(beta=1e-7) - elif isinstance(fit_method, str): - fit_method = get(fit_method)() - elif isinstance(fit_method, dict): - name = fit_method.pop('name') - fit_method = get(name)(**fit_method) - if not callable(fit_method): - raise ValueError(f'"train_method" must be an instance of callable function, ' - f'but we got {type(fit_method)}.') - self.fit_method = fit_method - # check the required interface in the trainable nodes - self._check_interface() - - # set the training method - for node in self.train_nodes: - node.offline_fit_by = fit_method - - # update dynamical variables - if isinstance(self.fit_method, Base): - self.dyn_vars.update(self.fit_method.vars().unique()) - - # add the monitor items which are needed for the training process - self._added_items = self._add_monitor_items() - - # training function - self._f_train = dict() - - def __repr__(self): - name = self.__class__.__name__ - prefix = ' ' * len(name) - return (f'{name}(target={self.target}, \n\t' - f'{prefix}jit={self.jit}, \n\t' - f'{prefix}fit_method={self.fit_method})') - - def fit( - self, - train_data: Sequence, - test_data=None, - reset: bool = False, - shared_kwargs: Dict = None, - forced_states: Dict[str, Tensor] = None, - forced_feedbacks: Dict[str, Tensor] = None, - initial_states: Union[Tensor, Dict[str, Tensor]] = None, - initial_feedbacks: Dict[str, Tensor] = None, - ): - """ - Fit the target model according to the given training and testing data. - - Parameters - ---------- - train_data: sequence of data - It should be a pair of `(X, Y)` train set. - - ``X``: should be a tensor or a dict of tensors with the shape of - `(num_sample, num_time, num_feature)`, where `num_sample` is - the number of samples, `num_time` is the number of the time step, - and `num_feature` is the number of features. - - ``Y``: Target values. A tensor or a dict of tensors. - - If the shape of each tensor is `(num_sample, num_feature)`, - then we will only fit the model with the only last output. - - If the shape of each tensor is `(num_sample, num_time, num_feature)`, - then the fitting happens on the whole data series. - test_data: callable, sequence of data - Same as the ``train_data``. It can be a callable function, - or a tuple/list representing `(X, Y)` data. But this argument - is supported in offline trainers. - reset: bool - Whether reset the initial states of the target model. - shared_kwargs: dict - The shared keyword arguments for the target models. - forced_states: dict - The fixed node states. Similar with ``xs``, each tensor in - ``forced_states`` must be a tensor with the shape of - `(num_sample, num_time, num_feature)`. - - .. versionadded:: 2.1.4 - - forced_feedbacks: dict - The fixed feedback states. Similar with ``xs``, each tensor in - ``forced_states`` must be a tensor with the shape of - `(num_sample, num_time, num_feature)`. - - .. versionadded:: 2.1.4 - - initial_states: JaxArray, ndarray, dict - The initial states. Each tensor in ``initial_states`` must be a - tensor with the shape of `(num_sample, num_feature)`. - - .. versionadded:: 2.1.4 - - initial_feedbacks: dict - The initial feedbacks for the node in the network model. - Each tensor in ``initial_feedbacks`` must be a - tensor with the shape of `(num_sample, num_feature)`. - - .. versionadded:: 2.1.4 - - """ - # checking training and testing data - if not isinstance(train_data, (list, tuple)): - raise ValueError(f"{self.__class__.__name__} only support " - f"training data with the format of (X, Y) pair, " - f"but we got a {type(train_data)}.") - if len(train_data) != 2: - raise ValueError(f"{self.__class__.__name__} only support " - f"training data with the format of (X, Y) pair, " - f"but we got a sequence with length {len(train_data)}") - if test_data is not None: - raise ValueError(f'{self.__class__.__name__} does not support testing data.') - xs, ys = train_data - - # set initial states - self._set_initial_states(initial_states) - self._set_initial_feedbacks(initial_feedbacks) - - # prediction, get all needed data - _ = self.predict(xs=xs, - reset=reset, - forced_states=forced_states, - forced_feedbacks=forced_feedbacks) - - # get all input data - xs, num_step, num_batch = self._check_xs(xs, move_axis=False) - if isinstance(self.target, Network): - for node in self.target.entry_nodes: - if node in self.train_nodes: - inputs = node.data_pass_func({node.name: xs[node.name]}) - self.mon[f'{node.name}.inputs'] = inputs - self._added_items.add(f'{node.name}.inputs') - elif isinstance(self.target, Node): - if self.target in self.train_nodes: - inputs = self.target.data_pass_func({self.target.name: xs[self.target.name]}) - self.mon[f'{self.target.name}.inputs'] = inputs - self._added_items.add(f'{self.target.name}.inputs') - - # format target data - ys = self._check_ys(ys, num_batch=num_batch, num_step=num_step, move_axis=False) - - # init progress bar - if self.progress_bar: - self._pbar = tqdm.auto.tqdm(total=len(self.train_nodes)) - self._pbar.set_description(f"Train {len(self.train_nodes)} nodes: ", refresh=True) - - # training - monitor_data = dict() - for node in self.train_nodes: - monitor_data[f'{node.name}.inputs'] = self.mon.get(f'{node.name}.inputs', None) - monitor_data[f'{node.name}.feedbacks'] = self.mon.get(f'{node.name}.feedbacks', None) - self.f_train(shared_kwargs)(monitor_data, ys) - - # close the progress bar - if self.progress_bar: - self._pbar.close() - - # final things - for key in self._added_items: - self.mon.pop(key) - if self.true_numpy_mon_after_run: - for key in self.mon.keys(): - if key != 'var_names': - self.mon[key] = np.asarray(self.mon[key]) - - def f_train(self, shared_kwargs: Dict = None) -> Callable: - """Get training function.""" - shared_kwargs_str = serialize_kwargs(shared_kwargs) - if shared_kwargs_str not in self._f_train: - self._f_train[shared_kwargs_str] = self._make_fit_func(shared_kwargs) - return self._f_train[shared_kwargs_str] - - def _make_fit_func(self, shared_kwargs): - shared_kwargs = dict() if shared_kwargs is None else shared_kwargs - - def train_func(monitor_data: Dict[str, Tensor], target_data: Dict[str, Tensor]): - for node in self.train_nodes: - ff = monitor_data[f'{node.name}.inputs'] - fb = monitor_data.get(f'{node.name}.feedbacks', None) - targets = target_data[node.name] - if fb is None: - node.offline_fit(targets, ff, **shared_kwargs) - else: - node.offline_fit(targets, ff, fb, **shared_kwargs) - if self.progress_bar: - id_tap(lambda *args: self._pbar.update(), ()) - - if self.jit['fit']: - dyn_vars = self.target.vars() - dyn_vars.update(self.dyn_vars) - train_func = bm.jit(train_func, dyn_vars=dyn_vars.unique()) - return train_func - - def _add_monitor_items(self): - added_items = set() - if isinstance(self.target, Network): - for node in self.train_nodes: - if node not in self.target.entry_nodes: - if f'{node.name}.inputs' not in self.mon.var_names: - self.mon.var_names += (f'{node.name}.inputs', ) - self.mon[f'{node.name}.inputs'] = [] - added_items.add(f'{node.name}.inputs') - if node in self.target.fb_senders: - if f'{node.name}.feedbacks' not in self.mon.var_names: - self.mon.var_names += (f'{node.name}.feedbacks',) - self.mon[f'{node.name}.feedbacks'] = [] - added_items.add(f'{node.name}.feedbacks') - else: - # brainpy.nn.Node instance does not need to monitor its inputs - pass - return added_items - - def _check_interface(self): - for node in self.train_nodes: - if hasattr(node.offline_fit, 'not_implemented'): - if node.offline_fit.not_implemented: - raise NoImplementationError( - f'The node \n\n{node}\n\n' - f'is set to be trainable with {self.__class__.__name__} method. ' - f'However, it does not implement the required training ' - f'interface "offline_fit()" function. ' - ) - - -class RidgeTrainer(OfflineTrainer): - """ - Trainer of ridge regression, also known as regression with Tikhonov regularization. - - Parameters - ---------- - target: Node - The target model. - beta: float - The regularization coefficient. - **kwarg - Other common parameters for :py:class:`brainpy.nn.RNNTrainer``. - """ - - def __init__(self, target, beta=1e-7, **kwargs): - super(RidgeTrainer, self).__init__(target=target, - fit_method=dict(name='ridge', beta=beta), - **kwargs) diff --git a/brainpy/compat/nn/runners/online_trainer.py b/brainpy/compat/nn/runners/online_trainer.py deleted file mode 100644 index 279052fbb..000000000 --- a/brainpy/compat/nn/runners/online_trainer.py +++ /dev/null @@ -1,299 +0,0 @@ -# -*- coding: utf-8 -*- - -from typing import Dict, Sequence, Union, Callable - -import tqdm.auto -from jax.experimental.host_callback import id_tap -from jax.tree_util import tree_map - -import numpy as np -from brainpy.base import Base -import brainpy.math as bm -from brainpy.errors import NoImplementationError -from brainpy.compat.nn.algorithms.online import get, OnlineAlgorithm, RLS -from brainpy.compat.nn.base import Node -from brainpy.compat.nn.utils import (serialize_kwargs, - check_data_batch_size, - check_rnn_data_time_step) -from brainpy.types import Tensor -from .rnn_trainer import RNNTrainer - -__all__ = [ - 'OnlineTrainer', - 'ForceTrainer', -] - - -class OnlineTrainer(RNNTrainer): - """Online trainer for models with recurrent dynamics. - - Parameters - ---------- - target: Node - The target model to train. - fit_method: OnlineAlgorithm, Callable, dict, str - The fitting method applied to the target model. - - It can be a string, which specify the shortcut name of the training algorithm. - Like, ``fit_method='ridge'`` means using the RLS method. - All supported fitting methods can be obtained through - :py:func:`brainpy.nn.runners.get_supported_online_methods` - - It can be a dict, whose "name" item specifies the name of the training algorithm, - and the others parameters specify the initialization parameters of the algorithm. - For example, ``fit_method={'name': 'ridge', 'beta': 1e-4}``. - - It can be an instance of :py:class:`brainpy.nn.runners.OnlineAlgorithm`. - For example, ``fit_meth=bp.nn.runners.RLS(alpha=1e-5)``. - - It can also be a callable function. - **kwargs - The other general parameters for RNN running initialization. - """ - - def __init__( - self, - target: Node, - fit_method: Union[OnlineAlgorithm, Callable, Dict, str] = None, - **kwargs - ): - super(OnlineTrainer, self).__init__(target=target, **kwargs) - - # training method - if fit_method is None: - fit_method = RLS(alpha=1e-7) - elif isinstance(fit_method, str): - fit_method = get(fit_method)() - elif isinstance(fit_method, dict): - name = fit_method.pop('name') - fit_method = get(name)(**fit_method) - self.fit_method = fit_method - if not callable(fit_method): - raise ValueError(f'"train_method" must be an instance of callable function, ' - f'but we got {type(fit_method)}.') - - # check the required interface in the trainable nodes - self._check_interface() - - # set the training method - for node in self.train_nodes: - node.online_fit_by = fit_method - - # initialize the fitting method - for node in self.train_nodes: - node.online_init() - - # update dynamical variables - if isinstance(self.fit_method, Base): - self.dyn_vars.update(self.fit_method.vars().unique()) - - # training function - self._f_train = dict() - - def __repr__(self): - name = self.__class__.__name__ - prefix = ' ' * len(name) - return (f'{name}(target={self.target}, \n\t' - f'{prefix}jit={self.jit}, \n\t' - f'{prefix}fit_method={self.fit_method})') - - def fit( - self, - train_data: Sequence, - test_data=None, - reset: bool = False, - shared_kwargs: Dict = None, - forced_states: Dict[str, Tensor] = None, - forced_feedbacks: Dict[str, Tensor] = None, - initial_states: Dict[str, Tensor] = None, - initial_feedbacks: Dict[str, Tensor] = None, - ): - # checking training and testing data - if not isinstance(train_data, (list, tuple)): - raise ValueError(f"{self.__class__.__name__} only support " - f"training data with the format of (X, Y) pair, " - f"but we got a {type(train_data)}.") - if len(train_data) != 2: - raise ValueError(f"{self.__class__.__name__} only support " - f"training data with the format of (X, Y) pair, " - f"but we got a sequence with length {len(train_data)}") - if test_data is not None: - raise ValueError(f'{self.__class__.__name__} does not support testing data.') - xs, ys = train_data - - # format input data - xs, num_step, num_batch = self._check_xs(xs, move_axis=True) - - # format target data - ys = self._check_ys(ys, num_batch=num_batch, num_step=num_step, move_axis=True) - - # set initial states - self._set_initial_states(initial_states) - self._set_initial_feedbacks(initial_feedbacks) - - # get forced data - forced_states = self._check_forced_states(forced_states, num_batch, num_step) - forced_feedbacks = self._check_forced_feedbacks(forced_feedbacks, num_batch, num_step) - - # reset the model states - if reset: - self.target.initialize(num_batch) - - # init monitor - for key in self.mon.var_names: - self.mon[key] = [] # reshape the monitor items - - # init progress bar - if self.progress_bar: - if num_step is None: - num_step = check_rnn_data_time_step(xs) - self._pbar = tqdm.auto.tqdm(total=num_step) - self._pbar.set_description(f"Train {num_step} steps: ", refresh=True) - - # prediction - hists = self._fit(xs=xs, - ys=ys, - forced_states=forced_states, - forced_feedbacks=forced_feedbacks, - shared_kwargs=shared_kwargs) - - # close the progress bar - if self.progress_bar: - self._pbar.close() - - # post-running for monitors - for key in hists.keys(): - self.mon[key] = hists[key] - if self.numpy_mon_after_run: - for key in hists.keys(): - self.mon[key] = np.asarray(self.mon[key]) - - def _fit( - self, - xs: Dict[str, Tensor], - ys: Dict[str, Tensor], - shared_kwargs: Dict = None, - forced_states: Dict[str, Tensor] = None, - forced_feedbacks: Dict[str, Tensor] = None, - ): - """Predict the output according to the inputs. - - Parameters - ---------- - xs: dict - Each tensor should have the shape of `(num_time, num_batch, num_feature)`. - ys: dict - Each tensor should have the shape of `(num_time, num_batch, num_feature)`. - forced_states: dict - The forced state values. - forced_feedbacks: dict - The forced feedback output values. - shared_kwargs: optional, dict - The shared keyword arguments. - - Returns - ------- - outputs, hists - A tuple of pair of (outputs, hists). - """ - _predict_func = self._get_fit_func(shared_kwargs) - # rune the model - forced_states = dict() if forced_states is None else forced_states - forced_feedbacks = dict() if forced_feedbacks is None else forced_feedbacks - hists = _predict_func([xs, ys, forced_states, forced_feedbacks]) - f1 = lambda x: bm.moveaxis(x, 0, 1) - f2 = lambda x: isinstance(x, bm.JaxArray) - hists = tree_map(f1, hists, is_leaf=f2) - return hists - - def _get_fit_func(self, shared_kwargs: Dict = None): - if shared_kwargs is None: shared_kwargs = dict() - shared_kwargs_str = serialize_kwargs(shared_kwargs) - if shared_kwargs_str not in self._f_train: - self._f_train[shared_kwargs_str] = self._make_fit_func(shared_kwargs) - return self._f_train[shared_kwargs_str] - - def _make_fit_func(self, shared_kwargs: Dict): - if not isinstance(shared_kwargs, dict): - raise ValueError(f'"shared_kwargs" must be a dict, ' - f'but got {type(shared_kwargs)}') - add_monitors = self._add_monitor_items() - - def _step_func(all_inputs): - xs, ys, forced_states, forced_feedbacks = all_inputs - monitors = tuple(self.mon.var_names) - - _, outs = self.target(xs, - forced_states=forced_states, - forced_feedbacks=forced_feedbacks, - monitors=monitors + add_monitors, - **shared_kwargs) - for node in self.train_nodes: - ff = outs[f'{node.name}.inputs'] - fb = outs[f'{node.name}.feedbacks'] - target = ys[node.name] - node.online_fit(target, ff, fb=fb) - for key in add_monitors: - outs.pop(key) - - if self.progress_bar and (self._pbar is not None): - id_tap(lambda *args: self._pbar.update(), ()) - return outs - - if self.jit['fit']: - dyn_vars = self.target.vars() - dyn_vars.update(self.dyn_vars) - f = bm.make_loop(_step_func, dyn_vars=dyn_vars.unique(), has_return=True) - return lambda all_inputs: f(all_inputs)[1] - - else: - def run_func(all_inputs): - xs, ys, forced_states, forced_feedbacks = all_inputs - monitors = {key: [] for key in self.mon.var_names} - num_step = check_data_batch_size(xs) - for i in range(num_step): - one_xs = {key: tensor[i] for key, tensor in xs.items()} - one_ys = {key: tensor[i] for key, tensor in ys.items()} - one_forced_states = {key: tensor[i] for key, tensor in forced_states.items()} - one_forced_feedbacks = {key: tensor[i] for key, tensor in forced_feedbacks.items()} - mon = _step_func([one_xs, one_ys, one_forced_states, one_forced_feedbacks]) - for key, value in mon.items(): - monitors[key].append(value) - for key, value in monitors.items(): - monitors[key] = bm.asarray(value) - return monitors - return run_func - - def _add_monitor_items(self): - added_items = set() - for node in self.train_nodes: - if f'{node.name}.inputs' not in self.mon.var_names: - added_items.add(f'{node.name}.inputs') - if f'{node.name}.feedbacks' not in self.mon.var_names: - added_items.add(f'{node.name}.feedbacks') - return tuple(added_items) - - def _check_interface(self): - for node in self.train_nodes: - if hasattr(node.online_fit, 'not_implemented'): - if node.online_fit.not_implemented: - raise NoImplementationError( - f'The node \n\n{node}\n\n' - f'is set to be trainable with {self.__class__.__name__} method. ' - f'However, it does not implement the required training ' - f'interface "online_fit()" function. ' - ) - if hasattr(node.online_init, 'not_implemented'): - if node.online_init.not_implemented: - raise NoImplementationError( - f'The node \n\n{node}\n\n' - f'is set to be trainable with {self.__class__.__name__} method. ' - f'However, it does not implement the required training ' - f'interface "online_init()" function. ' - ) - - -class ForceTrainer(OnlineTrainer): - """Force learning.""" - - def __init__(self, target, alpha=1., **kwargs): - fit_method = RLS(alpha=alpha) - super(ForceTrainer, self).__init__(target=target, - fit_method=fit_method, - **kwargs) diff --git a/brainpy/compat/nn/runners/rnn_runner.py b/brainpy/compat/nn/runners/rnn_runner.py deleted file mode 100644 index 846f272cf..000000000 --- a/brainpy/compat/nn/runners/rnn_runner.py +++ /dev/null @@ -1,456 +0,0 @@ -# -*- coding: utf-8 -*- - -from typing import Dict, Union - -import jax.numpy as jnp -import numpy as np -import tqdm.auto -from jax.experimental.host_callback import id_tap -from jax.tree_util import tree_map - -from brainpy import math as bm -from brainpy.errors import UnsupportedError -from brainpy.compat.nn.base import Node, Network -from brainpy.compat.nn.utils import (check_rnn_data_time_step, - check_data_batch_size, - serialize_kwargs) -from brainpy.running.runner import Runner -from brainpy.tools.checking import check_dict_data -from brainpy.types import Tensor - -__all__ = [ - 'RNNRunner', -] - - -class RNNRunner(Runner): - """Structural Runner for Recurrent Neural Networks. - - Parameters - ---------- - target: Node - The target model for simulation. - monitors: None, list of str, tuple of str, Monitor - Variables to monitor. - jit: bool - Whether we use JIT compilation to accelerate the model simulation. - progress_bar: bool - Whether we use progress bar to report the simulation progress. - dyn_vars: Optional, dict - The dynamically changed variables. - numpy_mon_after_run : bool - Change the monitored iterm into NumPy arrays. - """ - - target: Node - - def __init__(self, target: Node, jit=True, **kwargs): - super(RNNRunner, self).__init__(target=target, **kwargs) - assert isinstance(self.target, Node), '"target" must be an instance of brainpy.nn.Node.' - - # jit settings - if isinstance(jit, bool): - self.jit = {'fit': jit, 'predict': jit} - elif isinstance(jit, dict): - jit = {key: val for key, val in jit.items()} - self.jit = {'fit': jit.pop('fit', True), - 'predict': jit.pop('predict', True)} - if len(jit): - raise ValueError(f'Unknown jit setting for {jit.keys()}') - else: - raise ValueError(f'Unknown "jit" setting: {jit}') - - # function for prediction - self._predict_func = dict() - - def __repr__(self): - name = self.__class__.__name__ - prefix = ' ' * len(name) - return (f'{name}(target={self.target}, \n\t' - f'{prefix}jit={self.jit})') - - def predict( - self, - xs: Union[Tensor, Dict[str, Tensor]], - forced_states: Dict[str, Tensor] = None, - forced_feedbacks: Dict[str, Tensor] = None, - initial_states: Union[Tensor, Dict[str, Tensor]] = None, - initial_feedbacks: Dict[str, Tensor] = None, - reset: bool = False, - shared_kwargs: Dict = None, - progress_bar: bool = True, - ): - """Predict a series of input data with the given target model. - - This function use the JIT compilation to accelerate the model simulation. - Moreover, it can automatically monitor the node variables, states, inputs, - feedbacks and its output. - - Parameters - ---------- - xs: Tensor, dict - The feedforward input data. It must be a 3-dimensional data - which has the shape of `(num_sample, num_time, num_feature)`. - forced_states: dict - The fixed node states. Similar with ``xs``, each tensor in - ``forced_states`` must be a tensor with the shape of - `(num_sample, num_time, num_feature)`. - forced_feedbacks: dict - The fixed feedback states. Similar with ``xs``, each tensor in - ``forced_states`` must be a tensor with the shape of - `(num_sample, num_time, num_feature)`. - initial_states: JaxArray, ndarray, dict - The initial states. Each tensor in ``initial_states`` must be a - tensor with the shape of `(num_sample, num_feature)`. - initial_feedbacks: dict - The initial feedbacks for the node in the network model. - Each tensor in ``initial_feedbacks`` must be a - tensor with the shape of `(num_sample, num_feature)`. - reset: bool - Whether reset the model states. - shared_kwargs: optional, dict - The shared arguments across different layers. - progress_bar: bool - Whether report the progress of the simulation using progress bar. - - Returns - ------- - output: Tensor, dict - The model output. - """ - # format input data - xs, num_step, num_batch = self._check_xs(xs) - # set initial states - self._set_initial_states(initial_states) - self._set_initial_feedbacks(initial_feedbacks) - # get forced data - forced_states = self._check_forced_states(forced_states, num_batch, num_step) - forced_feedbacks = self._check_forced_feedbacks(forced_feedbacks, num_batch, num_step) - # reset the model states - if reset: - self.target.initialize(num_batch) - # init monitor - for key in self.mon.var_names: - self.mon[key] = [] # reshape the monitor items - # init progress bar - if self.progress_bar and progress_bar: - if num_step is None: - num_step = check_rnn_data_time_step(xs) - self._pbar = tqdm.auto.tqdm(total=num_step) - self._pbar.set_description(f"Predict {num_step} steps: ", refresh=True) - # prediction - outputs, hists = self._predict(xs=xs, - forced_states=forced_states, - forced_feedbacks=forced_feedbacks, - shared_kwargs=shared_kwargs) - # close the progress bar - if self.progress_bar and progress_bar: - self._pbar.close() - # post-running for monitors - for key in hists.keys(): - self.mon[key] = hists[key] - if self.numpy_mon_after_run: - for key in hists.keys(): - self.mon[key] = np.asarray(self.mon[key]) - return outputs - - def _predict( - self, - xs: Dict[str, Tensor], - shared_kwargs: Dict = None, - forced_states: Dict[str, Tensor] = None, - forced_feedbacks: Dict[str, Tensor] = None, - ): - """Predict the output according to the inputs. - - Parameters - ---------- - xs: dict - Each tensor should have the shape of `(num_time, num_batch, num_feature)`. - forced_states: dict - The forced state values. - forced_feedbacks: dict - The forced feedback output values. - shared_kwargs: optional, dict - The shared keyword arguments. - - Returns - ------- - outputs, hists - A tuple of pair of (outputs, hists). - """ - _predict_func = self._get_predict_func(shared_kwargs) - # rune the model - forced_states = dict() if forced_states is None else forced_states - forced_feedbacks = dict() if forced_feedbacks is None else forced_feedbacks - outputs, hists = _predict_func([xs, forced_states, forced_feedbacks]) - f1 = lambda x: bm.moveaxis(x, 0, 1) - f2 = lambda x: isinstance(x, bm.JaxArray) - outputs = tree_map(f1, outputs, is_leaf=f2) - hists = tree_map(f1, hists, is_leaf=f2) - return outputs, hists - - def _get_predict_func(self, shared_kwargs: Dict = None): - if shared_kwargs is None: shared_kwargs = dict() - shared_kwargs_str = serialize_kwargs(shared_kwargs) - if shared_kwargs_str not in self._predict_func: - self._predict_func[shared_kwargs_str] = self._make_predict_func(shared_kwargs) - return self._predict_func[shared_kwargs_str] - - def _make_predict_func(self, shared_kwargs: Dict): - if not isinstance(shared_kwargs, dict): - raise ValueError(f'"shared_kwargs" must be a dict, ' - f'but got {type(shared_kwargs)}') - - def _step_func(a_input): - xs, forced_states, forced_feedbacks = a_input - monitors = self.mon.var_names - outs = self.target(xs, - forced_states=forced_states, - forced_feedbacks=forced_feedbacks, - monitors=monitors, - **shared_kwargs) - if self.progress_bar and (self._pbar is not None): - id_tap(lambda *args: self._pbar.update(), ()) - return outs - - if self.jit['predict']: - dyn_vars = self.target.vars() - dyn_vars.update(self.dyn_vars) - f = bm.make_loop(_step_func, dyn_vars=dyn_vars.unique(), has_return=True) - return lambda all_inputs: f(all_inputs)[1] - - else: - def run_func(all_inputs): - xs, forced_states, forced_feedbacks = all_inputs - if isinstance(self.target, Network) and len(self.target.exit_nodes) > 1: - outputs = {node.name: [] for node in self.target.exit_nodes} - output_type = 'network' - else: - outputs = [] - output_type = 'node' - monitors = {key: [] for key in self.mon.var_names} - num_step = check_data_batch_size(xs) - for i in range(num_step): - one_xs = {key: tensor[i] for key, tensor in xs.items()} - one_forced_states = {key: tensor[i] for key, tensor in forced_states.items()} - one_forced_feedbacks = {key: tensor[i] for key, tensor in forced_feedbacks.items()} - output, mon = _step_func([one_xs, one_forced_states, one_forced_feedbacks]) - for key, value in mon.items(): - monitors[key].append(value) - if output_type == 'node': - outputs.append(output) - else: - for key, out in output.items(): - outputs[key].append(out) - if output_type == 'node': - outputs = bm.asarray(outputs) - else: - for key, out in outputs.items(): - outputs[key] = bm.asarray(out) - for key, value in monitors.items(): - monitors[key] = bm.asarray(value) - return outputs, monitors - return run_func - - def _init_target(self, xs): # deprecated - # we need to initialize the node or the network - x = dict() - for key, tensor in xs.items(): - if not isinstance(key, str): - raise ValueError('"xs" must a dict of (str, tensor), while we got ' - f'({type(key)}, {type(tensor)})') - if not isinstance(tensor, (bm.ndarray, jnp.ndarray)): - raise ValueError('"xs" must a dict of (str, tensor), while we got ' - f'({type(key)}, {type(tensor)})') - x[key] = tensor[0] - self.target.initialize(x) - - def _set_initial_states(self, initial_states): - # initial states - if initial_states is not None: - if isinstance(self.target, Network): - if not isinstance(initial_states, dict): - raise ValueError(f'"initial_states" must be a dict when the ' - f'target model is a brainpy.nn.Network instance. ' - f'But we got {type(initial_states)}') - nodes = [node.name for node in self.target.lnodes] - for key, tensor in initial_states.items(): - if not isinstance(key, str): - raise ValueError(f'"initial_states" must be a dict of (str, tensor). ' - f'But got a dict of ({type(key)}, {type(tensor)})') - if key not in nodes: - raise ValueError(f'Node "{key}" is not defined in the target model. ' - f'We only detect: \n{self.target.lnodes}') - if self.target[key].state is None: - raise ValueError(f'The target model {key} has no state. ' - f'We cannot set its initial state.') - self.target[key].state.value = tensor - elif isinstance(self.target, Node): - if self.target.state is None: - raise ValueError(f'The target model {self.target.name} has no state. ' - f'We cannot set its initial state.') - if not isinstance(initial_states, (jnp.ndarray, bm.ndarray)): - raise ValueError('"initial_states" must be a tensor, ' - f'but we got a {type(initial_states)}') - self.target.state.value = initial_states - - def _set_initial_feedbacks(self, initial_feedbacks): - # initial feedback states - if initial_feedbacks is not None: - if isinstance(self.target, Network): - if not isinstance(initial_feedbacks, dict): - raise ValueError('"initial_feedbacks" must be a dict when the ' - 'target model is a brainpy.nn.Network instance. ' - f'But we got {type(initial_feedbacks)}') - nodes = [node.name for node in self.target.lnodes] - for key, tensor in initial_feedbacks.items(): - if not isinstance(key, str): - raise ValueError(f'"initial_feedbacks" must be a dict of (str, tensor). ' - f'But got a dict of ({type(key)}, {type(tensor)})') - if key not in nodes: - raise ValueError(f'Node "{key}" is not defined in the target model. ' - f'We only detect: \n{self.target.lnodes}') - if self.target[key].fb_output is None: - raise ValueError(f'The target model {key} has no feedback connections. ' - f'We cannot set its initial feedback output.') - self.target[key].fb_output.value = tensor - elif isinstance(self.target, Node): - raise UnsupportedError('Do not support feedback in a single instance of brainpy.nn.Node.') - - def _check_forced_states(self, forced_states, num_batch, num_step=None): - iter_forced_states = dict() - if forced_states is not None: - if isinstance(self.target, Network): - nodes = [node.name for node in self.target.lnodes] - if not isinstance(forced_states, dict): - raise ValueError('"forced_states" must be a dict of (str, Tensor)') - for key, tensor in forced_states.items(): - if not isinstance(key, str): - raise ValueError(f'"forced_states" must be a dict of (str, tensor). ' - f'But got a dict of ({type(key)}, {type(tensor)})') - if key not in nodes: - raise ValueError(f'Node "{key}" is not defined in the target model. ' - f'We only detect: \n{self.target.lnodes}') - if not isinstance(tensor, (bm.ndarray, jnp.ndarray)): - raise ValueError(f'"forced_states" must a dict of (str, tensor), ' - f'while we got ({type(key)}, {type(tensor)})') - if bm.ndim(tensor) != self.target[key].state.ndim + 1: - raise ValueError(f'Must be a tensor with shape of (num_batch, num_time, ' - f'{str(self.target[key].state.shape)[1:-1]}), ' - f'but we got {tensor.shape}') - if tensor.shape[0] != num_batch: - raise ValueError(f'The number of the batch size ({tensor.shape[0]}) ' - f'of the forced state of {key} does not ' - f'match with the batch size in inputs {num_batch}.') - if (num_step is not None) and (tensor.shape[1] != num_step): - raise ValueError(f'The number of the time step ({tensor.shape[1]}) ' - f'of the forced state of {key} does not ' - f'match with the time step in inputs {num_step}.') - if self.target[key].output_shape[1:] != tensor.shape[2:]: - raise UnsupportedError(f'The forced state of {key} has the shape of ' - f'{tensor.shape}, which is not consistent with ' - f'its output shape {self.target[key].output_shape}. ' - f'Each tensor in forced state should have the shape ' - f'of (num_sample, num_time, num_feature) or ' - f'(num_sample, num_feature).') - iter_forced_states[key] = bm.moveaxis(tensor, 0, 1) # shape of (num_time, num_sample, num_feature) - else: - raise UnsupportedError('We do not support forced feedback state ' - 'for a single brainpy.nn.Node instance') - return iter_forced_states - - def _check_forced_feedbacks(self, forced_feedbacks, num_batch, num_step): - iter_forced_feedbacks = dict() - if forced_feedbacks is not None: - if isinstance(self.target, Network): - if not isinstance(forced_feedbacks, dict): - raise ValueError('"forced_feedbacks" must be a dict of (str, Tensor)') - feedback_node_names = [node.name for node in self.target.feedback_nodes] - for key, tensor in forced_feedbacks.items(): - if not isinstance(key, str): - raise ValueError(f'"forced_feedbacks" must be a dict of (str, tensor). ' - f'But got a dict of ({type(key)}, {type(tensor)})') - if key not in feedback_node_names: - raise ValueError(f'{self.target} has no feedback node {key}, ' - f'it only has {feedback_node_names}') - if not isinstance(tensor, (bm.ndarray, jnp.ndarray)): - raise ValueError('"forced_feedbacks" must a dict of (str, tensor), ' - 'while we got ({type(key)}, {type(tensor)})') - if bm.ndim(tensor) != self.target[key].fb_output.ndim + 1: - raise ValueError(f'Must be a tensor with shape of (num_batch, num_time, ' - f'{str(self.target[key].fb_output.shape)[1:-1]}), ' - f'but we got {tensor.shape}') - if tensor.shape[0] != num_batch: - raise ValueError(f'The number of the batch size ({tensor.shape[0]}) ' - f'of the forced feedback of {key} does not ' - f'match with the batch size in inputs {num_batch}.') - if tensor.shape[1] != num_step: - raise ValueError(f'The number of the time step ({tensor.shape[1]}) ' - f'of the forced feedback of {key} does not ' - f'match with the time step in inputs {num_step}.') - if self.target[key].output_shape[1:] != tensor.shape[2:]: - raise UnsupportedError(f'The forced feedback of {key} has the shape of ' - f'{tensor.shape}, which is not consistent with ' - f'its output shape {self.target[key].output_shape}. ' - f'Each tensor in forced feedback should have the shape ' - f'of (num_sample, num_time, num_feature) or ' - f'(num_sample, num_feature).') - iter_forced_feedbacks[key] = bm.moveaxis(tensor, 0, 1) # shape of (num_time, num_sample, num_feature) - else: - raise UnsupportedError('We do not support forced states for ' - 'a single brainpy.nn.Node instance') - return iter_forced_feedbacks - - def _format_xs(self, xs): - if isinstance(xs, (bm.ndarray, jnp.ndarray)): - if isinstance(self.target, Network): - if len(self.target.entry_nodes) != 1: - raise ValueError(f'The network {self.target} has {len(self.target.entry_nodes)} ' - f'input nodes, while we only got one input data.') - xs = {self.target.entry_nodes[0].name: xs} - else: - xs = {self.target.name: xs} - if not isinstance(xs, dict): - raise UnsupportedError(f'Unknown data type {type(xs)}, we only support ' - f'tensor or dict with ') - if len(xs) == 0: - raise ValueError('We got no input data.') - check_dict_data(xs, key_type=str, val_type=(bm.ndarray, jnp.ndarray)) - return xs - - def _check_xs(self, xs: Union[Dict, Tensor], move_axis=True): - input_shapes = {} - if isinstance(self.target, Network): - for node in self.target.entry_nodes: - name = self.target.entry_nodes[0].name - input_shapes[name] = node._feedforward_shapes[name] - else: - name = self.target.name - input_shapes[name] = self.target._feedforward_shapes[name] - - xs = self._format_xs(xs) - num_times, num_batch_sizes = [], [] - for key, val in xs.items(): - if key not in input_shapes: - raise ValueError(f'Cannot find {key} in the required inputs. Please check!') - shape = input_shapes[key] - if bm.ndim(val) != len(shape) + 1: - raise ValueError(f'Each tensor in "xs" must be a tensor of shape ' - f'(num_sample, num_time, {str(shape[1:])[1:-1]}). ' - f'But we got {val.shape}.') - num_times.append(val.shape[1]) - num_batch_sizes.append(val.shape[0]) - if len(set(num_times)) != 1: - raise ValueError(f'Number of time step is different across tensors in ' - f'the provided "xs". We got {set(num_times)}.') - if len(set(num_batch_sizes)) != 1: - raise ValueError(f'Number of batch size is different across tensors in ' - f'the provided "xs". We got {set(num_batch_sizes)}.') - num_step = num_times[0] - num_batch = num_batch_sizes[0] - if move_axis: - # change shape to (num_time, num_sample, num_feature) - xs = {k: bm.moveaxis(v, 0, 1) for k, v in xs.items()} - return xs, num_step, num_batch - diff --git a/brainpy/compat/nn/runners/rnn_trainer.py b/brainpy/compat/nn/runners/rnn_trainer.py deleted file mode 100644 index 904e5397b..000000000 --- a/brainpy/compat/nn/runners/rnn_trainer.py +++ /dev/null @@ -1,84 +0,0 @@ -# -*- coding: utf-8 -*- - -from typing import Dict, Sequence, Any, Union - -import jax.numpy as jnp - -import brainpy.math as bm -from brainpy.errors import UnsupportedError -from brainpy.compat.nn.base import Node, Network -from brainpy.types import Tensor -from brainpy.tools.checking import check_dict_data -from .rnn_runner import RNNRunner - -__all__ = [ - 'RNNTrainer', -] - - -class RNNTrainer(RNNRunner): - """Structural Trainer for Models with Recurrent Dynamics.""" - - train_nodes: Sequence[Node] # need to be initialized by subclass - train_pars: Dict[str, Any] # need to be initialized by subclass - - def __init__(self, target, **kwargs): - super(RNNTrainer, self).__init__(target=target, **kwargs) - - # get all trainable nodes - self.train_nodes = self._get_trainable_nodes() - - def fit( - self, - train_data: Any, - test_data: Any, - forced_states: Dict[str, Tensor] = None, - forced_feedbacks: Dict[str, Tensor] = None, - initial_states: Union[Tensor, Dict[str, Tensor]] = None, - initial_feedbacks: Dict[str, Tensor] = None, - reset: bool = False, - shared_kwargs: Dict = None - ): # need to be implemented by subclass - raise NotImplementedError('Must implement the fit function. ') - - def _get_trainable_nodes(self): - # check trainable nodes - if isinstance(self.target, Network): - train_nodes = [node for node in self.target.lnodes if node.trainable] - elif isinstance(self.target, Node): - train_nodes = [self.target] if self.target.trainable else [] - else: - raise UnsupportedError('Must be a brainpy.nn.Node instance, ' - f'while we got {type(self.target)}: {self.target}') - return train_nodes - - def _check_ys(self, ys, num_batch, num_step, move_axis=False): - # output_shapes = {} - # for node in self.train_nodes: - # name = self.target.entry_nodes[0].name - # output_shapes[name] = node.output_shape - - if isinstance(ys, (bm.ndarray, jnp.ndarray)): - if len(self.train_nodes) == 1: - ys = {self.train_nodes[0].name: ys} - else: - raise ValueError(f'The network {self.target} has {len(self.train_nodes)} ' - f'training nodes, while we only got one target data.') - check_dict_data(ys, key_type=str, val_type=(bm.ndarray, jnp.ndarray)) - for key, val in ys.items(): - if val.ndim < 3: - raise ValueError("Targets must be a tensor with shape of " - "(num_sample, num_time, feature_dim, ...), " - f"but we got {val.shape}") - if val.shape[0] != num_batch: - raise ValueError(f'Batch size of the target {key} does not match ' - f'with the input data {val.shape[0]} != {num_batch}') - if val.shape[1] != num_step: - raise ValueError(f'The time step of the target {key} does not match ' - f'with the input data {val.shape[1]} != {num_step})') - if move_axis: - # change shape to (num_time, num_sample, num_feature) - ys = {k: bm.moveaxis(v, 0, 1) for k, v in ys.items()} - return ys - - diff --git a/brainpy/compat/nn/tests/test_graph_flow.py b/brainpy/compat/nn/tests/test_graph_flow.py deleted file mode 100644 index 3b9898c01..000000000 --- a/brainpy/compat/nn/tests/test_graph_flow.py +++ /dev/null @@ -1,71 +0,0 @@ -# -*- coding: utf-8 -*- - - -import unittest -from brainpy.compat.nn.graph_flow import find_entries_and_exits -from brainpy.compat.nn.graph_flow import detect_cycle - - -class TestGraphFlow(unittest.TestCase): - def test_ff1(self): - nodes = (1, 2, 3, 4, 5) - ff_edges = ((1, 2), (2, 3), (3, 4), (4, 5)) - inputs, outputs = find_entries_and_exits(nodes, ff_edges) - print() - print(inputs, outputs) - - ff_edges = ((1, 2), (2, 3), (3, 4)) - inputs, outputs = find_entries_and_exits(nodes, ff_edges) - print(inputs, outputs) - - def test_fb1(self): - nodes = (1, 2, 3, 4, 5) - ff_edges = ((1, 2), (2, 3), (3, 4), (4, 5)) - fb_edges = ((5, 2), (4, 2)) - inputs, outputs = find_entries_and_exits(nodes, ff_edges, fb_edges) - print() - print(inputs, outputs) - - def test_fb2(self): - nodes = (1, 2, 3, 4, 5) - ff_edges = ((1, 2), (2, 3), (3, 4)) - fb_edges = ((3, 2), (4, 5)) - # with self.assertRaises(ValueError): - find_entries_and_exits(nodes, ff_edges, fb_edges) - - def test_fb3(self): - nodes = (1, 2, 3, 4, 5) - ff_edges = ((1, 2), (2, 3), (3, 4)) - fb_edges = ((5, 2), ) - inputs, outputs = find_entries_and_exits(nodes, ff_edges, fb_edges) - print() - print(inputs, outputs) - - def test_fb4(self): - # 1 -> 2 -> 3 -> 4 -> 5 -> 6 - # ^ |^ | - # ∟------------- ∟---- - nodes = (1, 2, 3, 4, 5, 6) - ff_edges = ((1, 2), (2, 3), (3, 4), (4, 5), (5, 6)) - fb_edges = ((5, 2), (6, 5)) - inputs, outputs = find_entries_and_exits(nodes, ff_edges, fb_edges) - print() - print(inputs, outputs) - - def test_fb5(self): - # 1 -> 2 -> 3 -> 4 -> 5 -> 6 - # ^ |^ | - # ∟------------------- ∟---- - nodes = (1, 2, 3, 4, 5, 6) - ff_edges = ((1, 2), (2, 3), (3, 4), (4, 5), (5, 6)) - fb_edges = ((5, 1), (6, 5)) - inputs, outputs = find_entries_and_exits(nodes, ff_edges, fb_edges) - print() - print(inputs, outputs) - - -class TestDetectCycle(unittest.TestCase): - def test1(self): - nodes = [0, 1, 2, 3] - edges = [(0, 1), (0, 2), (1, 2), (2, 0), (2, 3), (3, 3)] - print(detect_cycle(nodes, edges)) diff --git a/brainpy/compat/nn/tests/test_operations.py b/brainpy/compat/nn/tests/test_operations.py deleted file mode 100644 index 9f40c9b4a..000000000 --- a/brainpy/compat/nn/tests/test_operations.py +++ /dev/null @@ -1,186 +0,0 @@ -# -*- coding: utf-8 -*- - -from unittest import TestCase - -import brainpy as bp - - -class TestFF(TestCase): - def test_one2one(self): - i = bp.nn.Input(1) - r = bp.nn.Reservoir(10) - model = i >> r - print(model.lnodes) - self.assertTrue(model.ff_senders[r][0] == i) - self.assertTrue(model.ff_receivers[i][0] == r) - - def test_many2one1(self): - i1 = bp.nn.Input(1) - i2 = bp.nn.Input(2) - i3 = bp.nn.Input(3) - r = bp.nn.Reservoir(10) - model = [i1, i2, i3] >> r - self.assertTrue(isinstance(model.ff_receivers[i1][0], bp.nn.Concat)) - self.assertTrue(isinstance(model.ff_receivers[i2][0], bp.nn.Concat)) - self.assertTrue(isinstance(model.ff_receivers[i3][0], bp.nn.Concat)) - - def test_many2one2(self): - i1 = bp.nn.Input(1) - i2 = bp.nn.Input(2) - i3 = bp.nn.Input(3) - r = bp.nn.Reservoir(10) - model = (i1, i2, i3) >> r - self.assertTrue(isinstance(model.ff_receivers[i1][0], bp.nn.Concat)) - self.assertTrue(isinstance(model.ff_receivers[i2][0], bp.nn.Concat)) - self.assertTrue(isinstance(model.ff_receivers[i3][0], bp.nn.Concat)) - - def test_many2one3(self): - i1 = bp.nn.Input(1) - i2 = bp.nn.Input(2) - i3 = bp.nn.Input(3) - r = bp.nn.Reservoir(10) - model = {i1, i2, i3} >> r - self.assertTrue(model.ff_receivers[i1][0] == r) - self.assertTrue(model.ff_receivers[i2][0] == r) - self.assertTrue(model.ff_receivers[i3][0] == r) - - def test_one2many1(self): - i = bp.nn.Input(1) - o1 = bp.nn.Dense(3) - o2 = bp.nn.Dense(4) - o3 = bp.nn.Dense(5) - with self.assertRaises(TypeError): - model = i >> [o1, o2, o3] - - def test_one2many2(self): - i = bp.nn.Input(1) - o1 = bp.nn.Dense(3) - o2 = bp.nn.Dense(4) - o3 = bp.nn.Dense(5) - with self.assertRaises(TypeError): - model = i >> (o1, o2, o3) - - def test_one2many3(self): - i = bp.nn.Input(1) - o1 = bp.nn.Dense(3) - o2 = bp.nn.Dense(4) - o3 = bp.nn.Dense(5) - model = i >> {o1, o2, o3} - # model.plot_node_graph() - self.assertTrue(model.ff_senders[o1][0] == i) - self.assertTrue(model.ff_senders[o2][0] == i) - self.assertTrue(model.ff_senders[o3][0] == i) - - def test_many2many1(self): - i1 = bp.nn.Input(1) - i2 = bp.nn.Input(2) - i3 = bp.nn.Input(3) - - o1 = bp.nn.Dense(3) - o2 = bp.nn.Dense(4) - o3 = bp.nn.Dense(5) - - model = bp.nn.ff_connect([i1, i2, i3], {o1, o2, o3}) - - self.assertTrue(isinstance(model.ff_receivers[i1][0], bp.nn.Concat)) - self.assertTrue(isinstance(model.ff_receivers[i2][0], bp.nn.Concat)) - self.assertTrue(isinstance(model.ff_receivers[i3][0], bp.nn.Concat)) - - self.assertTrue(isinstance(model.ff_senders[o1][0], bp.nn.Concat)) - self.assertTrue(isinstance(model.ff_senders[o2][0], bp.nn.Concat)) - self.assertTrue(isinstance(model.ff_senders[o3][0], bp.nn.Concat)) - - def test_many2many2(self): - i1 = bp.nn.Input(1) - i2 = bp.nn.Input(2) - i3 = bp.nn.Input(3) - - o1 = bp.nn.Dense(3) - o2 = bp.nn.Dense(4) - o3 = bp.nn.Dense(5) - - model = bp.nn.ff_connect((i1, i2, i3), {o1, o2, o3}) - - self.assertTrue(isinstance(model.ff_receivers[i1][0], bp.nn.Concat)) - self.assertTrue(isinstance(model.ff_receivers[i2][0], bp.nn.Concat)) - self.assertTrue(isinstance(model.ff_receivers[i3][0], bp.nn.Concat)) - - self.assertTrue(isinstance(model.ff_senders[o1][0], bp.nn.Concat)) - self.assertTrue(isinstance(model.ff_senders[o2][0], bp.nn.Concat)) - self.assertTrue(isinstance(model.ff_senders[o3][0], bp.nn.Concat)) - - def test_many2many3(self): - i1 = bp.nn.Input(1) - i2 = bp.nn.Input(2) - i3 = bp.nn.Input(3) - - o1 = bp.nn.Dense(3) - o2 = bp.nn.Dense(4) - o3 = bp.nn.Dense(5) - - model = bp.nn.ff_connect({i1, i2, i3}, {o1, o2, o3}) - model.plot_node_graph() - - self.assertTrue(len(model.ff_receivers[i1]) == 3) - self.assertTrue(len(model.ff_receivers[i2]) == 3) - self.assertTrue(len(model.ff_receivers[i3]) == 3) - - self.assertTrue(len(model.ff_senders[o1]) == 3) - self.assertTrue(len(model.ff_senders[o2]) == 3) - self.assertTrue(len(model.ff_senders[o3]) == 3) - - def test_many2one4(self): - i1 = bp.nn.Input(1) - i2 = bp.nn.Input(2) - i3 = bp.nn.Input(3) - - ii = bp.nn.Input(3) - - model = {i1, i2, i3} >> ii - model.plot_node_graph() - - self.assertTrue(isinstance(model.ff_receivers[i1][0], bp.nn.Concat)) - self.assertTrue(isinstance(model.ff_receivers[i2][0], bp.nn.Concat)) - self.assertTrue(isinstance(model.ff_receivers[i3][0], bp.nn.Concat)) - - def test_many2one5(self): - i1 = bp.nn.Input(1) - i2 = bp.nn.Input(2) - i3 = bp.nn.Input(3) - ii = bp.nn.Input(3) - - model = (i1 >> ii) & (i2 >> ii) - # model.plot_node_graph() - self.assertTrue(isinstance(model.ff_receivers[i1][0], bp.nn.Concat)) - self.assertTrue(isinstance(model.ff_receivers[i2][0], bp.nn.Concat)) - self.assertTrue(len(model.ff_senders[ii]) == 1) - self.assertTrue(isinstance(model.ff_senders[ii][0], bp.nn.Concat)) - - model = model & (i3 >> ii) - # model.plot_node_graph() - self.assertTrue(isinstance(model.ff_receivers[i1][0], bp.nn.Concat)) - self.assertTrue(isinstance(model.ff_receivers[i2][0], bp.nn.Concat)) - self.assertTrue(isinstance(model.ff_receivers[i3][0], bp.nn.Concat)) - self.assertTrue(len(model.ff_senders[ii]) == 1) - self.assertTrue(isinstance(model.ff_senders[ii][0], bp.nn.Concat)) - - -class TestFB(TestCase): - def test_many2one(self): - class FBNode(bp.nn.Node): - def init_fb_conn(self): - pass - - i1 = FBNode() - i2 = FBNode() - i3 = FBNode() - i4 = FBNode() - - model = (i1 >> i2 >> i3) & (i1 << i2) & (i1 << i3) - model.plot_node_graph() - - model = model & (i3 >> i4) & (i1 << i4) - model.plot_node_graph() - - - diff --git a/brainpy/compat/nn/tests/test_vis.py b/brainpy/compat/nn/tests/test_vis.py deleted file mode 100644 index 09d0d503a..000000000 --- a/brainpy/compat/nn/tests/test_vis.py +++ /dev/null @@ -1,34 +0,0 @@ -# -*- coding: utf-8 -*- - - -import unittest -import brainpy as bp - - -class TestVisualize(unittest.TestCase): - def test(self): - model = ( - bp.nn.Input(3) - >> - bp.nn.Reservoir(100, name='I') - >> - bp.nn.Reservoir(100) - >> - bp.nn.Reservoir(100, name='l1') - >> - bp.nn.LinearReadout(3, weight_initializer=bp.init.Normal()) - >> - bp.nn.Reservoir(100) - >> - bp.nn.Reservoir(100) - >> - bp.nn.LinearReadout(3, weight_initializer=bp.init.Normal(), name='output') - ) - model &= (model['l1'] << model['output']) - model &= (model['I'] << model['output']) - - # model = - # print(model.trainable) - print() - - model.plot_node_graph(fig_size=(10, 5), node_size=100) diff --git a/brainpy/compat/nn/utils.py b/brainpy/compat/nn/utils.py deleted file mode 100644 index 039bddeb2..000000000 --- a/brainpy/compat/nn/utils.py +++ /dev/null @@ -1,101 +0,0 @@ -# -*- coding: utf-8 -*- - -import warnings -from typing import Union, Sequence, Dict, Any, Callable, Optional - -import jax.numpy as jnp - -import brainpy.math as bm -from brainpy.initialize import Initializer, init_param as true_init_param -from brainpy.tools.checking import check_dict_data -from brainpy.types import Tensor, Shape - -__all__ = [ - 'tensor_sum', - 'init_param', - 'check_data_batch_size', - 'check_rnn_data_time_step', - 'serialize_kwargs', -] - - -def tensor_sum(values: Union[Sequence[Tensor], Dict[Any, Tensor], Tensor]): - if isinstance(values, (bm.ndarray, jnp.ndarray)): - return values - if isinstance(values, dict): - values = list(values.values()) - elif isinstance(values, (tuple, list)): - values = list(values) - else: - raise ValueError('Unknown types of tensors.') - res = values[0] - for v in values[1:]: - res = res + v - return res - - -def init_param(param: Union[Callable, Initializer, bm.ndarray, jnp.ndarray], - size: Shape): - """Initialize parameters. - - .. deprecated:: 2.1.2 - Please use "brainpy.init.init_param" instead. - - Parameters - ---------- - param: callable, Initializer, bm.ndarray, jnp.ndarray - The initialization of the parameter. - - If it is None, the created parameter will be None. - - If it is a callable function :math:`f`, the ``f(size)`` will be returned. - - If it is an instance of :py:class:`brainpy.init.Initializer``, the ``f(size)`` will be returned. - - If it is a tensor, then this function check whether ``tensor.shape`` is equal to the given ``size``. - size: int, sequence of int - The shape of the parameter. - """ - warnings.warn('Please use "brainpy.init.init_param" instead. ' - '"brainpy.nn.init_param" is deprecated since version 2.1.2. ', - DeprecationWarning) - return true_init_param(param, size) - - -def check_data_batch_size(data: Dict, num_batch=None): - if len(data) == 1: - batch_size = list(data.values())[0].shape[0] - else: - batches = [] - for key, val in data.items(): - batches.append(val.shape[0]) - if len(set(batches)) != 1: - raise ValueError('Batch sizes are not consistent among the given data. ' - f'Got {set(batches)}. We expect only one batch size.') - batch_size = batches[0] - if (num_batch is not None) and batch_size != num_batch: - raise ValueError(f'Batch size is not consistent with the expected {batch_size} != {num_batch}') - return batch_size - - -def check_rnn_data_time_step(data: Dict, num_step=None): - if len(data) == 1: - time_step = list(data.values())[0].shape[1] - else: - steps = [] - for key, val in data.items(): - steps.append(val.shape[1]) - if len(set(steps)) != 1: - raise ValueError('Time steps are not consistent among the given data. ' - f'Got {set(steps)}. We expect only one time step.') - time_step = steps[0] - if (num_step is not None) and time_step != num_step: - raise ValueError(f'Time step is not consistent with the expected {time_step} != {num_step}') - return time_step - - -def serialize_kwargs(shared_kwargs: Optional[Dict]): - """Serialize kwargs.""" - shared_kwargs = dict() if shared_kwargs is None else shared_kwargs - check_dict_data(shared_kwargs, - key_type=str, - val_type=(bool, float, int, complex), - name='shared_kwargs') - shared_kwargs = {key: shared_kwargs[key] for key in sorted(shared_kwargs.keys())} - return str(shared_kwargs) diff --git a/brainpy/compat/runners.py b/brainpy/compat/runners.py deleted file mode 100644 index 6246cc649..000000000 --- a/brainpy/compat/runners.py +++ /dev/null @@ -1,65 +0,0 @@ -# -*- coding: utf-8 -*- - -import warnings - -from brainpy.dyn import runners as dyn_runner -from brainpy.integrators import runner as intg_runner - -__all__ = [ - 'IntegratorRunner', - 'DSRunner', - 'StructRunner', - 'ReportRunner' -] - - -class IntegratorRunner(intg_runner.IntegratorRunner): - """Integrator runner class. - - .. deprecated:: 2.1.0 - Please use "brainpy.integrators.IntegratorRunner" instead. - """ - def __init__(self, *args, **kwargs): - warnings.warn('Please use "brainpy.integrators.IntegratorRunner" instead. ' - '"brainpy.IntegratorRunner" is deprecated since ' - 'version 2.1.0', DeprecationWarning) - super(IntegratorRunner, self).__init__(*args, **kwargs) - - -class DSRunner(dyn_runner.DSRunner): - """Dynamical system runner class. - - .. deprecated:: 2.1.0 - Please use "brainpy.dyn.DSRunner" instead. - """ - def __init__(self, *args, **kwargs): - warnings.warn('Please use "brainpy.dyn.DSRunner" instead. ' - '"brainpy.DSRunner" is deprecated since ' - 'version 2.1.0', DeprecationWarning) - super(DSRunner, self).__init__(*args, **kwargs) - - -class StructRunner(dyn_runner.DSRunner): - """Dynamical system runner class. - - .. deprecated:: 2.1.0 - Please use "brainpy.dyn.StructRunner" instead. - """ - def __init__(self, *args, **kwargs): - warnings.warn('Please use "brainpy.dyn.StructRunner" instead. ' - '"brainpy.StructRunner" is deprecated since ' - 'version 2.1.0', DeprecationWarning) - super(StructRunner, self).__init__(*args, **kwargs) - - -class ReportRunner(dyn_runner.ReportRunner): - """Dynamical system runner class. - - .. deprecated:: 2.1.0 - Please use "brainpy.dyn.ReportRunner" instead. - """ - def __init__(self, *args, **kwargs): - warnings.warn('Please use "brainpy.dyn.ReportRunner" instead. ' - '"brainpy.ReportRunner" is deprecated since ' - 'version 2.1.0', DeprecationWarning) - super(ReportRunner, self).__init__(*args, **kwargs) diff --git a/brainpy/compat/tests/test_integrator_rnn.py b/brainpy/compat/tests/test_integrator_rnn.py deleted file mode 100644 index 83dcca39b..000000000 --- a/brainpy/compat/tests/test_integrator_rnn.py +++ /dev/null @@ -1,82 +0,0 @@ -# -*- coding: utf-8 -*- - -from functools import partial - -import matplotlib.pyplot as plt - -import brainpy as bp -import brainpy.math as bm - - -block = False -dt = 0.04 -num_step = int(1.0 / dt) -num_batch = 128 - - -@partial(bm.jit, - dyn_vars=bp.TensorCollector({'a': bm.random.DEFAULT}), - static_argnames=['batch_size']) -def build_inputs_and_targets(mean=0.025, scale=0.01, batch_size=10): - # Create the white noise input - sample = bm.random.normal(size=(batch_size, 1, 1)) - bias = mean * 2.0 * (sample - 0.5) - samples = bm.random.normal(size=(batch_size, num_step, 1)) - noise_t = scale / dt ** 0.5 * samples - inputs = bias + noise_t - targets = bm.cumsum(inputs, axis=1) - return inputs, targets - - -def train_data(): - for _ in range(10): - yield build_inputs_and_targets(batch_size=num_batch) - - -def test_rnn_training(): - model = ( - bp.nn.Input(1) - >> - bp.nn.VanillaRNN(100, state_trainable=True) - >> - bp.nn.Dense(1) - ) - model.initialize(num_batch=num_batch) - - - # define loss function - def loss(predictions, targets, l2_reg=2e-4): - mse = bp.losses.mean_squared_error(predictions, targets) - l2 = l2_reg * bp.losses.l2_norm(model.train_vars().unique().dict()) ** 2 - return mse + l2 - - - # define optimizer - lr = bp.optim.ExponentialDecay(lr=0.025, decay_steps=1, decay_rate=0.99975) - opt = bp.optim.Adam(lr=lr, eps=1e-1) - - # create a trainer - trainer = bp.nn.BPTT(model, - loss=loss, - optimizer=opt, - max_grad_norm=5.0) - trainer.fit(train_data, - num_batch=num_batch, - num_train=5, - num_report=10) - - plt.plot(trainer.train_losses.numpy()) - plt.show(block=block) - - model.initialize(1) - x, y = build_inputs_and_targets(batch_size=1) - predicts = trainer.predict(x) - - plt.figure(figsize=(8, 2)) - plt.plot(bm.as_numpy(y[0]).flatten(), label='Ground Truth') - plt.plot(bm.as_numpy(predicts[0]).flatten(), label='Prediction') - plt.legend() - plt.show(block=block) - plt.close() - - diff --git a/brainpy/compat/tests/test_ngrc_double_scroll.py b/brainpy/compat/tests/test_ngrc_double_scroll.py deleted file mode 100644 index 4f4495561..000000000 --- a/brainpy/compat/tests/test_ngrc_double_scroll.py +++ /dev/null @@ -1,151 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Implementation of the paper: - -- Gauthier, D.J., Bollt, E., Griffith, A. et al. Next generation reservoir - computing. Nat Commun 12, 5564 (2021). https://doi.org/10.1038/s41467-021-25801-2 - -The main task is forecasting the double-scroll system. -""" - - -import matplotlib.pyplot as plt -import numpy as np - -import brainpy as bp -import brainpy.math as bm - - -block = False - - -def get_subset(data, start, end): - res = {'x': data['x'][start: end], - 'y': data['y'][start: end], - 'z': data['z'][start: end]} - res = bm.hstack([res['x'], res['y'], res['z']]) - return res.reshape((1,) + res.shape) - - -def plot_weights(Wout, coefs, bias=None): - Wout = np.asarray(Wout) - if bias is not None: - bias = np.asarray(bias) - Wout = np.concatenate([bias.reshape((1, 3)), Wout], axis=0) - coefs.insert(0, 'bias') - x_Wout, y_Wout, z_Wout = Wout[:, 0], Wout[:, 1], Wout[:, 2] - - fig = plt.figure(figsize=(10, 10)) - ax = fig.add_subplot(131) - ax.grid(axis="y") - ax.set_xlabel("$[W_{out}]_x$") - ax.set_ylabel("Features") - ax.set_yticks(np.arange(len(coefs))) - ax.set_yticklabels(coefs) - ax.barh(np.arange(x_Wout.size), x_Wout) - - ax1 = fig.add_subplot(132) - ax1.grid(axis="y") - ax1.set_yticks(np.arange(len(coefs))) - ax1.set_xlabel("$[W_{out}]_y$") - ax1.barh(np.arange(y_Wout.size), y_Wout) - - ax2 = fig.add_subplot(133) - ax2.grid(axis="y") - ax2.set_yticks(np.arange(len(coefs))) - ax2.set_xlabel("$[W_{out}]_z$") - ax2.barh(np.arange(z_Wout.size), z_Wout) - - plt.show(block=block) - - -def plot_double_scroll(ground_truth, predictions): - fig = plt.figure(figsize=(15, 10)) - ax = fig.add_subplot(121, projection='3d') - ax.set_title("Generated attractor") - ax.set_xlabel("$x$") - ax.set_ylabel("$y$") - ax.set_zlabel("$z$") - ax.grid(False) - ax.plot(predictions[:, 0], predictions[:, 1], predictions[:, 2]) - - ax2 = fig.add_subplot(122, projection='3d') - ax2.set_title("Real attractor") - ax2.grid(False) - ax2.plot(ground_truth[:, 0], ground_truth[:, 1], ground_truth[:, 2]) - plt.show(block=block) - - -dt = 0.02 -t_warmup = 10. # ms -t_train = 100. # ms -t_test = 800. # ms -num_warmup = int(t_warmup / dt) # warm up NVAR -num_train = int(t_train / dt) -num_test = int(t_test / dt) - - -def test_ngrc_double_scroll(): - bm.enable_x64() - - # Datasets # - # -------- # - data_series = bp.datasets.double_scroll_series(t_warmup + t_train + t_test, dt=dt) - - X_warmup = get_subset(data_series, 0, num_warmup - 1) - Y_warmup = get_subset(data_series, 1, num_warmup) - X_train = get_subset(data_series, num_warmup - 1, num_warmup + num_train - 1) - # Target: Lorenz[t] - Lorenz[t - 1] - dX_train = get_subset(data_series, num_warmup, num_warmup + num_train) - X_train - X_test = get_subset(data_series, - num_warmup + num_train - 1, - num_warmup + num_train + num_test - 1) - Y_test = get_subset(data_series, - num_warmup + num_train, - num_warmup + num_train + num_test) - - # Model # - # ----- # - - i = bp.nn.Input(3) - r = bp.nn.NVAR(delay=2, order=3) - di = bp.nn.LinearReadout(3, trainable=True, name='readout') - o = bp.nn.Summation() - # - # Cannot express the model as - # - # [i >> r >> di, i] >> o - # (i >> r >> di, i) >> o - # because it will concatenate the outputs of "i" and "di", - # then feed into the node "o". This is not the connection - # we want. - model = {i >> r >> di, i} >> o - # model = (i >> r >> di >> o) & (i >> o) - model.plot_node_graph() - model.initialize(num_batch=1) - - # Training # - # -------- # - - # warm-up - trainer = bp.nn.RidgeTrainer(model, beta=1e-5, jit=True) - - # training - outputs = trainer.predict(X_warmup) - print('Warmup NMS: ', bp.losses.mean_squared_error(outputs, Y_warmup)) - trainer.fit([X_train, {'readout': dX_train}]) - plot_weights(di.Wff, r.get_feature_names_for_plot(), di.bias) - - # prediction - model = bm.jit(model) - outputs = [model(X_test[:, 0])] - for i in range(1, X_test.shape[1]): - outputs.append(model(outputs[i - 1])) - outputs = bm.asarray(outputs).squeeze() - print('Prediction NMS: ', bp.losses.mean_squared_error(outputs, Y_test)) - plot_double_scroll(Y_test.numpy().squeeze(), outputs.numpy()) - plt.close() - - bm.disable_x64() - bp.base.clear_name_cache(True) - diff --git a/brainpy/compat/tests/test_ngrc_lorenz.py b/brainpy/compat/tests/test_ngrc_lorenz.py deleted file mode 100644 index 4475b9825..000000000 --- a/brainpy/compat/tests/test_ngrc_lorenz.py +++ /dev/null @@ -1,151 +0,0 @@ -# # -*- coding: utf-8 -*- -# -# """Implementation of the paper: -# -# - Gauthier, D.J., Bollt, E., Griffith, A. et al. Next generation reservoir -# computing. Nat Commun 12, 5564 (2021). https://doi.org/10.1038/s41467-021-25801-2 -# -# The main task is forecasting the Lorenz63 strange attractor. -# """ -# -# import matplotlib.pyplot as plt -# import numpy as np -# -# import brainpy as bp -# import brainpy.math as bm -# -# block = False -# -# -# def get_subset(data, start, end): -# res = {'x': data['x'][start: end], -# 'y': data['y'][start: end], -# 'z': data['z'][start: end]} -# res = bm.hstack([res['x'], res['y'], res['z']]) -# return res.reshape((1,) + res.shape) -# -# -# def plot_weights(Wout, coefs, bias=None): -# Wout = np.asarray(Wout) -# if bias is not None: -# bias = np.asarray(bias) -# Wout = np.concatenate([bias.reshape((1, 3)), Wout], axis=0) -# coefs.insert(0, 'bias') -# x_Wout, y_Wout, z_Wout = Wout[:, 0], Wout[:, 1], Wout[:, 2] -# -# fig = plt.figure(figsize=(10, 10)) -# ax = fig.add_subplot(131) -# ax.grid(axis="y") -# ax.set_xlabel("$[W_{out}]_x$") -# ax.set_ylabel("Features") -# ax.set_yticks(np.arange(len(coefs))) -# ax.set_yticklabels(coefs) -# ax.barh(np.arange(x_Wout.size), x_Wout) -# -# ax1 = fig.add_subplot(132) -# ax1.grid(axis="y") -# ax1.set_yticks(np.arange(len(coefs))) -# ax1.set_xlabel("$[W_{out}]_y$") -# ax1.barh(np.arange(y_Wout.size), y_Wout) -# -# ax2 = fig.add_subplot(133) -# ax2.grid(axis="y") -# ax2.set_yticks(np.arange(len(coefs))) -# ax2.set_xlabel("$[W_{out}]_z$") -# ax2.barh(np.arange(z_Wout.size), z_Wout) -# -# plt.show(block=block) -# -# -# def plot_lorenz(ground_truth, predictions): -# fig = plt.figure(figsize=(15, 10)) -# ax = fig.add_subplot(121, projection='3d') -# ax.set_title("Generated attractor") -# ax.set_xlabel("$x$") -# ax.set_ylabel("$y$") -# ax.set_zlabel("$z$") -# ax.grid(False) -# ax.plot(predictions[:, 0], predictions[:, 1], predictions[:, 2]) -# -# ax2 = fig.add_subplot(122, projection='3d') -# ax2.set_title("Real attractor") -# ax2.grid(False) -# ax2.plot(ground_truth[:, 0], ground_truth[:, 1], ground_truth[:, 2]) -# plt.show(block=block) -# -# -# dt = 0.01 -# t_warmup = 5. # ms -# t_train = 10. # ms -# t_test = 120. # ms -# num_warmup = int(t_warmup / dt) # warm up NVAR -# num_train = int(t_train / dt) -# num_test = int(t_test / dt) -# -# -# def test_ngrc_lorenz(): -# bm.enable_x64() -# -# # Datasets # -# # -------- # -# lorenz_series = bp.datasets.lorenz_series(t_warmup + t_train + t_test, -# dt=dt, -# inits={'x': 17.67715816276679, -# 'y': 12.931379185960404, -# 'z': 43.91404334248268}) -# -# X_warmup = get_subset(lorenz_series, 0, num_warmup - 1) -# Y_warmup = get_subset(lorenz_series, 1, num_warmup) -# X_train = get_subset(lorenz_series, num_warmup - 1, num_warmup + num_train - 1) -# # Target: Lorenz[t] - Lorenz[t - 1] -# dX_train = get_subset(lorenz_series, num_warmup, num_warmup + num_train) - X_train -# X_test = get_subset(lorenz_series, -# num_warmup + num_train - 1, -# num_warmup + num_train + num_test - 1) -# Y_test = get_subset(lorenz_series, -# num_warmup + num_train, -# num_warmup + num_train + num_test) -# -# # Model # -# # ----- # -# -# i = bp.nn.Input(3) -# r = bp.nn.NVAR(delay=2, order=2, constant=True) -# di = bp.nn.LinearReadout(3, bias_initializer=None, trainable=True, name='readout') -# o = bp.nn.Summation() -# # -# # Cannot express the model as -# # -# # [i >> r >> di, i] >> o -# # because it will concatenate the outputs of "i" and "di", -# # then feed into the node "o". This is not the connection -# # we want. -# model = (i >> r >> di >> o) & (i >> o) -# # model.plot_node_graph() -# model.initialize(num_batch=1) -# -# print(r.get_feature_names()) -# -# # Training # -# # -------- # -# -# # warm-up -# trainer = bp.nn.RidgeTrainer(model, beta=2.5e-6) -# -# # training -# outputs = trainer.predict(X_warmup) -# print('Warmup NMS: ', bp.losses.mean_squared_error(outputs, Y_warmup)) -# trainer.fit([X_train, {'readout': dX_train}]) -# plot_weights(di.Wff, r.get_feature_names_for_plot(), di.bias) -# -# # prediction -# model = bm.jit(model) -# outputs = [model(X_test[:, 0])] -# for i in range(1, X_test.shape[1]): -# outputs.append(model(outputs[i - 1])) -# outputs = bm.asarray(outputs) -# print('Prediction NMS: ', bp.losses.mean_squared_error(outputs, Y_test)) -# plot_lorenz(Y_test.numpy().squeeze(), outputs.numpy().squeeze()) -# plt.close() -# bm.disable_x64() -# bp.base.clear_name_cache(True) diff --git a/brainpy/compat/tests/test_ngrc_lorenz_inference.py b/brainpy/compat/tests/test_ngrc_lorenz_inference.py deleted file mode 100644 index d27f2415f..000000000 --- a/brainpy/compat/tests/test_ngrc_lorenz_inference.py +++ /dev/null @@ -1,176 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Implementation of the paper: - -- Gauthier, D.J., Bollt, E., Griffith, A. et al. Next generation reservoir - computing. Nat Commun 12, 5564 (2021). https://doi.org/10.1038/s41467-021-25801-2 - -The main task is forecasting the Lorenz63 strange attractor. -""" - -import matplotlib.pyplot as plt -import numpy as np - -import brainpy as bp -import brainpy.math as bm - -block = False - - -def get_subset(data, start, end): - res = {'x': data['x'][start: end], - 'y': data['y'][start: end], - 'z': data['z'][start: end]} - X = bm.hstack([res['x'], res['y']]) - X = X.reshape((1,) + X.shape) - Y = res['z'] - Y = Y.reshape((1,) + Y.shape) - return X, Y - - -def plot_lorenz(x, y, true_z, predict_z, linewidth=.8): - fig1 = plt.figure() - fig1.set_figheight(8) - fig1.set_figwidth(12) - - t_all = t_warmup + t_train + t_test - ts = np.arange(0, t_all, dt) - - h = 240 - w = 2 - - # top left of grid is 0,0 - axs1 = plt.subplot2grid(shape=(h, w), loc=(0, 0), colspan=2, rowspan=30) - axs2 = plt.subplot2grid(shape=(h, w), loc=(36, 0), colspan=2, rowspan=30) - axs3 = plt.subplot2grid(shape=(h, w), loc=(72, 0), colspan=2, rowspan=30) - axs4 = plt.subplot2grid(shape=(h, w), loc=(132, 0), colspan=2, rowspan=30) - axs5 = plt.subplot2grid(shape=(h, w), loc=(168, 0), colspan=2, rowspan=30) - axs6 = plt.subplot2grid(shape=(h, w), loc=(204, 0), colspan=2, rowspan=30) - - # training phase x - axs1.set_title('training phase') - axs1.plot(ts[num_warmup:num_warmup + num_train], - x[num_warmup:num_warmup + num_train], - color='b', linewidth=linewidth) - axs1.set_ylabel('x') - axs1.axes.xaxis.set_ticklabels([]) - axs1.axes.set_xbound(t_warmup - .08, t_warmup + t_train + .05) - axs1.axes.set_ybound(-21., 21.) - axs1.text(-.14, .9, 'a)', ha='left', va='bottom', transform=axs1.transAxes) - - # training phase y - axs2.plot(ts[num_warmup:num_warmup + num_train], - y[num_warmup:num_warmup + num_train], - color='b', linewidth=linewidth) - axs2.set_ylabel('y') - axs2.axes.xaxis.set_ticklabels([]) - axs2.axes.set_xbound(t_warmup - .08, t_warmup + t_train + .05) - axs2.axes.set_ybound(-26., 26.) - axs2.text(-.14, .9, 'b)', ha='left', va='bottom', transform=axs2.transAxes) - - # training phase z - axs3.plot(ts[num_warmup:num_warmup + num_train], - true_z[num_warmup:num_warmup + num_train], - color='b', linewidth=linewidth) - axs3.plot(ts[num_warmup:num_warmup + num_train], - predict_z[num_warmup:num_warmup + num_train], - color='r', linewidth=linewidth) - axs3.set_ylabel('z') - axs3.set_xlabel('time') - axs3.axes.set_xbound(t_warmup - .08, t_warmup + t_train + .05) - axs3.axes.set_ybound(3., 48.) - axs3.text(-.14, .9, 'c)', ha='left', va='bottom', transform=axs3.transAxes) - - # testing phase x - axs4.set_title('testing phase') - axs4.plot(ts[num_warmup + num_train:num_warmup + num_train + num_test], - x[num_warmup + num_train:num_warmup + num_train + num_test], - color='b', linewidth=linewidth) - axs4.set_ylabel('x') - axs4.axes.xaxis.set_ticklabels([]) - axs4.axes.set_ybound(-21., 21.) - axs4.axes.set_xbound(t_warmup + t_train - .5, t_all + .5) - axs4.text(-.14, .9, 'd)', ha='left', va='bottom', transform=axs4.transAxes) - - # testing phase y - axs5.plot(ts[num_warmup + num_train:num_warmup + num_train + num_test], - y[num_warmup + num_train:num_warmup + num_train + num_test], - color='b', linewidth=linewidth) - axs5.set_ylabel('y') - axs5.axes.xaxis.set_ticklabels([]) - axs5.axes.set_ybound(-26., 26.) - axs5.axes.set_xbound(t_warmup + t_train - .5, t_all + .5) - axs5.text(-.14, .9, 'e)', ha='left', va='bottom', transform=axs5.transAxes) - - # testing phose z - axs6.plot(ts[num_warmup + num_train:num_warmup + num_train + num_test], - true_z[num_warmup + num_train:num_warmup + num_train + num_test], - color='b', linewidth=linewidth) - axs6.plot(ts[num_warmup + num_train:num_warmup + num_train + num_test], - predict_z[num_warmup + num_train:num_warmup + num_train + num_test], - color='r', linewidth=linewidth) - axs6.set_ylabel('z') - axs6.set_xlabel('time') - axs6.axes.set_ybound(3., 48.) - axs6.axes.set_xbound(t_warmup + t_train - .5, t_all + .5) - axs6.text(-.14, .9, 'f)', ha='left', va='bottom', transform=axs6.transAxes) - - plt.show(block=block) - - -dt = 0.02 -t_warmup = 10. # ms -t_train = 20. # ms -t_test = 50. # ms -num_warmup = int(t_warmup / dt) # warm up NVAR -num_train = int(t_train / dt) -num_test = int(t_test / dt) - - -def test_ngrc_lorenz_inference(): - bm.enable_x64() - # Datasets # - # -------- # - lorenz_series = bp.datasets.lorenz_series(t_warmup + t_train + t_test, - dt=dt, - inits={'x': 17.67715816276679, - 'y': 12.931379185960404, - 'z': 43.91404334248268}) - - X_warmup, Y_warmup = get_subset(lorenz_series, 0, num_warmup) - X_train, Y_train = get_subset(lorenz_series, num_warmup, num_warmup + num_train) - X_test, Y_test = get_subset(lorenz_series, 0, num_warmup + num_train + num_test) - - # Model # - # ----- # - - i = bp.nn.Input(2) - r = bp.nn.NVAR(delay=4, order=2, stride=5) - o = bp.nn.LinearReadout(1, trainable=True) - model = i >> r >> o - model.plot_node_graph() - model.initialize(num_batch=1) - - # Training # - # -------- # - - trainer = bp.nn.RidgeTrainer(model, beta=0.05) - - # warm-up - outputs = trainer.predict(X_warmup) - print('Warmup NMS: ', bp.losses.mean_squared_error(outputs, Y_warmup)) - - # training - trainer.fit([X_train, Y_train]) - - # prediction - outputs = trainer.predict(X_test, reset=True) - print('Prediction NMS: ', bp.losses.mean_squared_error(outputs, Y_test)) - - plot_lorenz(x=lorenz_series['x'].flatten().numpy(), - y=lorenz_series['y'].flatten().numpy(), - true_z=lorenz_series['z'].flatten().numpy(), - predict_z=outputs.to_numpy().flatten()) - plt.close() - bm.disable_x64() - bp.base.clear_name_cache(True) diff --git a/brainpy/datasets/__init__.py b/brainpy/datasets/__init__.py index f707afee5..0e6ab84bd 100644 --- a/brainpy/datasets/__init__.py +++ b/brainpy/datasets/__init__.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- from .chaotic_systems import * +from .vision import * diff --git a/brainpy/datasets/_internally_replaced_utils.py b/brainpy/datasets/_internally_replaced_utils.py new file mode 100644 index 000000000..32da97bc7 --- /dev/null +++ b/brainpy/datasets/_internally_replaced_utils.py @@ -0,0 +1,239 @@ +# -*- coding: utf-8 -*- + +import ctypes +import errno +import hashlib +import importlib.machinery +import os +import re +import shutil +import sys +import tempfile +import warnings +import zipfile +from urllib.parse import urlparse +from urllib.request import urlopen, Request +from brainpy import math as bm + +from tqdm import tqdm + +ENV_TORCH_HOME = 'BRAINPY_HOME' +ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME' +DEFAULT_CACHE_DIR = '~/.cache' + + +def _get_torch_home(): + torch_home = os.path.expanduser( + os.getenv(ENV_TORCH_HOME, os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'brainpy'))) + return torch_home + + +# matches bfd8deac from resnet18-bfd8deac.pth +HASH_REGEX = re.compile(r'-([a-f0-9]*)\.') +_HOME = os.path.join(_get_torch_home(), "datasets", "vision") +_USE_SHARDED_DATASETS = False + + +def _download_file_from_remote_location(fpath: str, url: str) -> None: + pass + + +def _is_remote_location_available() -> bool: + return False + + +def get_dir(): + r""" + Get the Torch Hub cache directory used for storing downloaded models & weights. + + If :func:`~torch.hub.set_dir` is not called, default path is ``$TORCH_HOME/hub`` where + environment variable ``$TORCH_HOME`` defaults to ``$XDG_CACHE_HOME/torch``. + ``$XDG_CACHE_HOME`` follows the X Design Group specification of the Linux + filesystem layout, with a default value ``~/.cache`` if the environment + variable is not set. + """ + # Issue warning to move data if old env is set + return os.path.join(_get_torch_home(), 'hub') + + +def load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True, check_hash=False, file_name=None): + r"""Loads the Torch serialized object at the given URL. + + If downloaded file is a zip file, it will be automatically + decompressed. + + If the object is already present in `model_dir`, it's deserialized and + returned. + The default value of ``model_dir`` is ``/checkpoints`` where + ``hub_dir`` is the directory returned by :func:`~torch.hub.get_dir`. + + Args: + url (string): URL of the object to download + model_dir (string, optional): directory in which to save the object + map_location (optional): a function or a dict specifying how to remap storage locations (see torch.load) + progress (bool, optional): whether or not to display a progress bar to stderr. + Default: True + check_hash(bool, optional): If True, the filename part of the URL should follow the naming convention + ``filename-.ext`` where ```` is the first eight or more + digits of the SHA256 hash of the contents of the file. The hash is used to + ensure unique names and to verify the contents of the file. + Default: False + file_name (string, optional): name for the downloaded file. Filename from ``url`` will be used if not set. + + Example: + >>> state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth') + + """ + # Issue warning to move data if old env is set + if os.getenv('TORCH_MODEL_ZOO'): + warnings.warn('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead') + + if model_dir is None: + hub_dir = get_dir() + model_dir = os.path.join(hub_dir, 'checkpoints') + + try: + os.makedirs(model_dir) + except OSError as e: + if e.errno == errno.EEXIST: + # Directory already exists, ignore. + pass + else: + # Unexpected OSError, re-raise. + raise + + parts = urlparse(url) + filename = os.path.basename(parts.path) + if file_name is not None: + filename = file_name + cached_file = os.path.join(model_dir, filename) + if not os.path.exists(cached_file): + sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) + hash_prefix = None + if check_hash: + r = HASH_REGEX.search(filename) # r is Optional[Match[str]] + hash_prefix = r.group(1) if r else None + download_url_to_file(url, cached_file, hash_prefix, progress=progress) + + if _is_legacy_zip_format(cached_file): + return _legacy_zip_load(cached_file, model_dir, map_location) + return bm.load(cached_file, map_location=map_location) + + +def _legacy_zip_load(filename, model_dir, map_location): + warnings.warn('Falling back to the old format < 1.6. This support will be ' + 'deprecated in favor of default zipfile format introduced in 1.6. ' + 'Please redo torch.save() to save it in the new zipfile format.') + # Note: extractall() defaults to overwrite file if exists. No need to clean up beforehand. + # We deliberately don't handle tarfile here since our legacy serialization format was in tar. + # E.g. resnet18-5c106cde.pth which is widely used. + with zipfile.ZipFile(filename) as f: + members = f.infolist() + if len(members) != 1: + raise RuntimeError('Only one file(not dir) is allowed in the zipfile') + f.extractall(model_dir) + extraced_name = members[0].filename + extracted_file = os.path.join(model_dir, extraced_name) + return bm.load(extracted_file, map_location=map_location) + + +# Hub used to support automatically extracts from zipfile manually compressed by users. +# The legacy zip format expects only one file from torch.save() < 1.6 in the zip. +# We should remove this support since zipfile is now default zipfile format for torch.save(). +def _is_legacy_zip_format(filename): + if zipfile.is_zipfile(filename): + infolist = zipfile.ZipFile(filename).infolist() + return len(infolist) == 1 and not infolist[0].is_dir() + return False + + +def download_url_to_file(url, dst, hash_prefix=None, progress=True): + r"""Download object at the given URL to a local path. + + Args: + url (string): URL of the object to download + dst (string): Full path where object will be saved, e.g. ``/tmp/temporary_file`` + hash_prefix (string, optional): If not None, the SHA256 downloaded file should start with ``hash_prefix``. + Default: None + progress (bool, optional): whether or not to display a progress bar to stderr + Default: True + + Example: + >>> torch.hub.download_url_to_file('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', '/tmp/temporary_file') + + """ + file_size = None + req = Request(url, headers={"User-Agent": "torch.hub"}) + u = urlopen(req) + meta = u.info() + if hasattr(meta, 'getheaders'): + content_length = meta.getheaders("Content-Length") + else: + content_length = meta.get_all("Content-Length") + if content_length is not None and len(content_length) > 0: + file_size = int(content_length[0]) + + # We deliberately save it in a temp file and move it after + # download is complete. This prevents a local working checkpoint + # being overridden by a broken download. + dst = os.path.expanduser(dst) + dst_dir = os.path.dirname(dst) + f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir) + + try: + if hash_prefix is not None: + sha256 = hashlib.sha256() + with tqdm(total=file_size, disable=not progress, + unit='B', unit_scale=True, unit_divisor=1024) as pbar: + while True: + buffer = u.read(8192) + if len(buffer) == 0: + break + f.write(buffer) + if hash_prefix is not None: + sha256.update(buffer) + pbar.update(len(buffer)) + + f.close() + if hash_prefix is not None: + digest = sha256.hexdigest() + if digest[:len(hash_prefix)] != hash_prefix: + raise RuntimeError('invalid hash value (expected "{}", got "{}")' + .format(hash_prefix, digest)) + shutil.move(f.name, dst) + finally: + f.close() + if os.path.exists(f.name): + os.remove(f.name) + + +def _get_extension_path(lib_name): + lib_dir = os.path.dirname(__file__) + if os.name == "nt": + # Register the main torchvision library location on the default DLL path + kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True) + with_load_library_flags = hasattr(kernel32, "AddDllDirectory") + prev_error_mode = kernel32.SetErrorMode(0x0001) + + if with_load_library_flags: + kernel32.AddDllDirectory.restype = ctypes.c_void_p + + if sys.version_info >= (3, 8): + os.add_dll_directory(lib_dir) + elif with_load_library_flags: + res = kernel32.AddDllDirectory(lib_dir) + if res is None: + err = ctypes.WinError(ctypes.get_last_error()) + err.strerror += f' Error adding "{lib_dir}" to the DLL directories.' + raise err + + kernel32.SetErrorMode(prev_error_mode) + + loader_details = (importlib.machinery.ExtensionFileLoader, importlib.machinery.EXTENSION_SUFFIXES) + + extfinder = importlib.machinery.FileFinder(lib_dir, loader_details) + ext_specs = extfinder.find_spec(lib_name) + if ext_specs is None: + raise ImportError + + return ext_specs.origin diff --git a/brainpy/datasets/base.py b/brainpy/datasets/base.py new file mode 100644 index 000000000..95f275b25 --- /dev/null +++ b/brainpy/datasets/base.py @@ -0,0 +1,264 @@ +# -*- coding: utf-8 -*- + + + +import bisect +import warnings +from typing import Any +from typing import Callable, Generic, Iterable, Iterator, List, Optional, Tuple, TypeVar + +T_co = TypeVar('T_co', covariant=True) +T = TypeVar('T') + + +__all__ = [ + 'Dataset', + 'IterableDataset', + 'ChainDataset', + 'StandardTransform' +] + +class Dataset(Generic[T_co]): + r"""An abstract class representing a :class:`Dataset`. + + All datasets that represent a map from keys to data samples should subclass + it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a + data sample for a given key. Subclasses could also optionally overwrite + :meth:`__len__`, which is expected to return the size of the dataset by many + :class:`~.Sampler` implementations and the default options + of :class:`~.DataLoader`. + + .. note:: + :class:`~.DataLoader` by default constructs a index + sampler that yields integral indices. To make it work with a map-style + dataset with non-integral indices/keys, a custom sampler must be provided. + """ + + def __getitem__(self, index) -> T_co: + raise NotImplementedError + + def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]': + return ConcatDataset([self, other]) + + # No `def __len__(self)` default? + # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ] + # in pytorch/torch/utils/data/sampler.py + + +class IterableDataset(Dataset[T_co]): + r"""An iterable Dataset. + + All datasets that represent an iterable of data samples should subclass it. + Such form of datasets is particularly useful when data come from a stream. + + All subclasses should overwrite :meth:`__iter__`, which would return an + iterator of samples in this dataset. + + When a subclass is used with :class:`~.DataLoader`, each + item in the dataset will be yielded from the :class:`~.DataLoader` + iterator. When :attr:`num_workers > 0`, each worker process will have a + different copy of the dataset object, so it is often desired to configure + each copy independently to avoid having duplicate data returned from the + workers. :func:`~.get_worker_info`, when called in a worker + process, returns information about the worker. It can be used in either the + dataset's :meth:`__iter__` method or the :class:`~.DataLoader` 's + :attr:`worker_init_fn` option to modify each copy's behavior. + + Example 1: splitting workload across all workers in :meth:`__iter__`:: + + >>> class MyIterableDataset(.IterableDataset): + ... def __init__(self, start, end): + ... super(MyIterableDataset).__init__() + ... assert end > start, "this example code only works with end >= start" + ... self.start = start + ... self.end = end + ... + ... def __iter__(self): + ... worker_info = .get_worker_info() + ... if worker_info is None: # single-process data loading, return the full iterator + ... iter_start = self.start + ... iter_end = self.end + ... else: # in a worker process + ... # split workload + ... per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers))) + ... worker_id = worker_info.id + ... iter_start = self.start + worker_id * per_worker + ... iter_end = min(iter_start + per_worker, self.end) + ... return iter(range(iter_start, iter_end)) + ... + >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6]. + >>> ds = MyIterableDataset(start=3, end=7) + + >>> # Single-process loading + >>> print(list(.DataLoader(ds, num_workers=0))) + [3, 4, 5, 6] + + >>> # Mult-process loading with two worker processes + >>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6]. + >>> print(list(.DataLoader(ds, num_workers=2))) + [3, 5, 4, 6] + + >>> # With even more workers + >>> print(list(.DataLoader(ds, num_workers=20))) + [3, 4, 5, 6] + + Example 2: splitting workload across all workers using :attr:`worker_init_fn`:: + + >>> class MyIterableDataset(.IterableDataset): + ... def __init__(self, start, end): + ... super(MyIterableDataset).__init__() + ... assert end > start, "this example code only works with end >= start" + ... self.start = start + ... self.end = end + ... + ... def __iter__(self): + ... return iter(range(self.start, self.end)) + ... + >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6]. + >>> ds = MyIterableDataset(start=3, end=7) + + >>> # Single-process loading + >>> print(list(.DataLoader(ds, num_workers=0))) + [3, 4, 5, 6] + >>> + >>> # Directly doing multi-process loading yields duplicate data + >>> print(list(.DataLoader(ds, num_workers=2))) + [3, 3, 4, 4, 5, 5, 6, 6] + + >>> # Define a `worker_init_fn` that configures each dataset copy differently + >>> def worker_init_fn(worker_id): + ... worker_info = .get_worker_info() + ... dataset = worker_info.dataset # the dataset copy in this worker process + ... overall_start = dataset.start + ... overall_end = dataset.end + ... # configure the dataset to only process the split workload + ... per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers))) + ... worker_id = worker_info.id + ... dataset.start = overall_start + worker_id * per_worker + ... dataset.end = min(dataset.start + per_worker, overall_end) + ... + + >>> # Mult-process loading with the custom `worker_init_fn` + >>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6]. + >>> print(list(.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn))) + [3, 5, 4, 6] + + >>> # With even more workers + >>> print(list(.DataLoader(ds, num_workers=20, worker_init_fn=worker_init_fn))) + [3, 4, 5, 6] + """ + + def __iter__(self) -> Iterator[T_co]: + raise NotImplementedError + + def __add__(self, other: Dataset[T_co]): + return ChainDataset([self, other]) + + # No `def __len__(self)` default? Subclasses raise `TypeError` when needed. + # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ] + + +class ChainDataset(IterableDataset): + r"""Dataset for chaining multiple :class:`IterableDataset` s. + + This class is useful to assemble different existing dataset streams. The + chaining operation is done on-the-fly, so concatenating large-scale + datasets with this class will be efficient. + + Args: + datasets (iterable of IterableDataset): datasets to be chained together + """ + + def __init__(self, datasets: Iterable[Dataset]) -> None: + super(ChainDataset, self).__init__() + self.datasets = datasets + + def __iter__(self): + for d in self.datasets: + assert isinstance(d, IterableDataset), "ChainDataset only supports IterableDataset" + for x in d: + yield x + + def __len__(self): + total = 0 + for d in self.datasets: + assert isinstance(d, IterableDataset), "ChainDataset only supports IterableDataset" + total += len(d) + return total + + +class ConcatDataset(Dataset[T_co]): + r"""Dataset as a concatenation of multiple datasets. + + This class is useful to assemble different existing datasets. + + Args: + datasets (sequence): List of datasets to be concatenated + """ + datasets: List[Dataset[T_co]] + cumulative_sizes: List[int] + + @staticmethod + def cumsum(sequence): + r, s = [], 0 + for e in sequence: + l = len(e) + r.append(l + s) + s += l + return r + + def __init__(self, datasets: Iterable[Dataset]) -> None: + super(ConcatDataset, self).__init__() + self.datasets = list(datasets) + assert len(self.datasets) > 0, 'datasets should not be an empty iterable' # type: ignore[arg-type] + for d in self.datasets: + assert not isinstance(d, IterableDataset), "ConcatDataset does not support IterableDataset" + self.cumulative_sizes = self.cumsum(self.datasets) + + def __len__(self): + return self.cumulative_sizes[-1] + + def __getitem__(self, idx): + if idx < 0: + if -idx > len(self): + raise ValueError("absolute value of index should not exceed dataset length") + idx = len(self) + idx + dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) + if dataset_idx == 0: + sample_idx = idx + else: + sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] + return self.datasets[dataset_idx][sample_idx] + + @property + def cummulative_sizes(self): + warnings.warn("cummulative_sizes attribute is renamed to " + "cumulative_sizes", DeprecationWarning, stacklevel=2) + return self.cumulative_sizes + + +class StandardTransform: + def __init__(self, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None) -> None: + self.transform = transform + self.target_transform = target_transform + + def __call__(self, input: Any, target: Any) -> Tuple[Any, Any]: + if self.transform is not None: + input = self.transform(input) + if self.target_transform is not None: + target = self.target_transform(target) + return input, target + + def _format_transform_repr(self, transform: Callable, head: str) -> List[str]: + lines = transform.__repr__().splitlines() + return [f"{head}{lines[0]}"] + ["{}{}".format(" " * len(head), line) for line in lines[1:]] + + def __repr__(self) -> str: + body = [self.__class__.__name__] + if self.transform is not None: + body += self._format_transform_repr(self.transform, "Transform: ") + if self.target_transform is not None: + body += self._format_transform_repr(self.target_transform, "Target transform: ") + + return "\n".join(body) + diff --git a/brainpy/datasets/chaotic_systems.py b/brainpy/datasets/chaotic_systems.py index 71836a8da..e08b9dbd5 100644 --- a/brainpy/datasets/chaotic_systems.py +++ b/brainpy/datasets/chaotic_systems.py @@ -164,7 +164,7 @@ def mackey_glass_series(duration, dt=0.1, beta=2., gamma=1., tau=2., n=9.65, if inits is None: inits = bm.ones(1) * 1.2 elif isinstance(inits, (float, int)): - inits = bm.asarray([inits], dtype=bm.get_dfloat()) + inits = bm.asarray([inits], dtype=bm.dftype()) else: assert isinstance(inits, (bm.ndarray, jnp.ndarray)) diff --git a/brainpy/datasets/vision/__init__.py b/brainpy/datasets/vision/__init__.py new file mode 100644 index 000000000..410909f87 --- /dev/null +++ b/brainpy/datasets/vision/__init__.py @@ -0,0 +1,4 @@ +# -*- coding: utf-8 -*- + +from .mnist import * + diff --git a/brainpy/datasets/vision/base.py b/brainpy/datasets/vision/base.py new file mode 100644 index 000000000..46d8d7cca --- /dev/null +++ b/brainpy/datasets/vision/base.py @@ -0,0 +1,90 @@ +# -*- coding: utf-8 -*- + +import os +import os.path +from typing import Any +from typing import Callable, List, Optional + +from ..base import Dataset, StandardTransform + +__all__ = [ + 'VisionDataset' +] + + +class VisionDataset(Dataset): + """ + Base Class For making datasets which are compatible with torchvision. + It is necessary to override the ``__getitem__`` and ``__len__`` method. + + Args: + root (string): Root directory of dataset. + transforms (callable, optional): A function/transforms that takes in + an image and a label and returns the transformed versions of both. + transform (callable, optional): A function/transform that takes in an PIL image + and returns a transformed version. E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + + .. note:: + + :attr:`transforms` and the combination of :attr:`transform` and :attr:`target_transform` are mutually exclusive. + """ + + _repr_indent = 4 + + def __init__( + self, + root: str, + transforms: Optional[Callable] = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + ) -> None: + if isinstance(root, (str, bytes)): + root = os.path.expanduser(root) + self.root = root + + has_transforms = transforms is not None + has_separate_transform = transform is not None or target_transform is not None + if has_transforms and has_separate_transform: + raise ValueError("Only transforms or transform/target_transform can be passed as argument") + + # for backwards-compatibility + self.transform = transform + self.target_transform = target_transform + + if has_separate_transform: + transforms = StandardTransform(transform, target_transform) + self.transforms = transforms + + def __getitem__(self, index: int) -> Any: + """ + Args: + index (int): Index + + Returns: + (Any): Sample and meta data, optionally transformed by the respective transforms. + """ + raise NotImplementedError + + def __len__(self) -> int: + raise NotImplementedError + + def __repr__(self) -> str: + head = "Dataset " + self.__class__.__name__ + body = [f"Number of datapoints: {self.__len__()}"] + if self.root is not None: + body.append(f"Root location: {self.root}") + body += self.extra_repr().splitlines() + if hasattr(self, "transforms") and self.transforms is not None: + body += [repr(self.transforms)] + lines = [head] + [" " * self._repr_indent + line for line in body] + return "\n".join(lines) + + def _format_transform_repr(self, transform: Callable, head: str) -> List[str]: + lines = transform.__repr__().splitlines() + return [f"{head}{lines[0]}"] + ["{}{}".format(" " * len(head), line) for line in lines[1:]] + + def extra_repr(self) -> str: + return "" + diff --git a/brainpy/datasets/vision/mnist.py b/brainpy/datasets/vision/mnist.py new file mode 100644 index 000000000..a72a22660 --- /dev/null +++ b/brainpy/datasets/vision/mnist.py @@ -0,0 +1,561 @@ +# -*- coding: utf-8 -*- + +import codecs +import os +import os.path +import shutil +import string +import sys +import warnings +from typing import Any +from typing import Callable, Dict, List, Optional, Tuple +from urllib.error import URLError +from brainpy.errors import PackageMissingError + +import jax.numpy as jnp +import numpy as np +try: + from PIL import Image +except ImportError: + Image = None + +import brainpy.math as bm +from .base import VisionDataset +from .utils import download_and_extract_archive, extract_archive, verify_str_arg, check_integrity + +__all__ = [ + 'MNIST', + 'FashionMNIST', + 'KMNIST', + 'EMNIST', + 'QMNIST', +] + + +class MNIST(VisionDataset): + """`MNIST `_ Dataset. + + Args: + root (string): Root directory of dataset where ``MNIST/raw/train-images-idx3-ubyte`` + and ``MNIST/raw/t10k-images-idx3-ubyte`` exist. + train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``, + otherwise from ``t10k-images-idx3-ubyte``. + download (bool, optional): If True, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. + transform (callable, optional): A function/transform that takes in an PIL image + and returns a transformed version. E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + """ + + mirrors = [ + "http://yann.lecun.com/exdb/mnist/", + "https://ossci-datasets.s3.amazonaws.com/mnist/", + ] + + resources = [ + ("train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"), + ("train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"), + ("t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"), + ("t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c"), + ] + + training_file = "training.pt" + test_file = "test.pt" + classes = [ + "0 - zero", + "1 - one", + "2 - two", + "3 - three", + "4 - four", + "5 - five", + "6 - six", + "7 - seven", + "8 - eight", + "9 - nine", + ] + + @property + def train_labels(self): + warnings.warn("train_labels has been renamed targets") + return self.targets + + @property + def test_labels(self): + warnings.warn("test_labels has been renamed targets") + return self.targets + + @property + def train_data(self): + warnings.warn("train_data has been renamed data") + return self.data + + @property + def test_data(self): + warnings.warn("test_data has been renamed data") + return self.data + + def __init__( + self, + root: str, + train: bool = True, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + download: bool = False, + ) -> None: + super().__init__(root, transform=transform, target_transform=target_transform) + self.train = train # training set or test set + + if self._check_legacy_exist(): + self.data, self.targets = self._load_legacy_data() + return + + if download: + self.download() + + if not self._check_exists(): + raise RuntimeError("Dataset not found. You can use download=True to download it") + + self.data, self.targets = self._load_data() + + def _check_legacy_exist(self): + processed_folder_exists = os.path.exists(self.processed_folder) + if not processed_folder_exists: + return False + + return all( + check_integrity(os.path.join(self.processed_folder, file)) for file in (self.training_file, self.test_file) + ) + + def _load_legacy_data(self): + # This is for BC only. We no longer cache the data in a custom binary, but simply read from the raw data + # directly. + data_file = self.training_file if self.train else self.test_file + return jnp.load(os.path.join(self.processed_folder, data_file)) + + def _load_data(self): + image_file = f"{'train' if self.train else 't10k'}-images-idx3-ubyte" + data = read_image_file(os.path.join(self.raw_folder, image_file)) + + label_file = f"{'train' if self.train else 't10k'}-labels-idx1-ubyte" + targets = read_label_file(os.path.join(self.raw_folder, label_file)) + + return data, targets + + def __getitem__(self, index: int) -> Tuple[Any, Any]: + """ + Args: + index (int): Index + + Returns: + tuple: (image, target) where target is index of the target class. + """ + img, target = self.data[index], int(self.targets[index]) + + # doing this so that it is consistent with all other datasets + # to return a PIL Image + if Image is None: + raise PackageMissingError('Need pillow to read the image, pleas install pillow first.') + img = Image.fromarray(img.numpy(), mode="L") + + if self.transform is not None: + img = self.transform(img) + + if self.target_transform is not None: + target = self.target_transform(target) + + return img, target + + def __len__(self) -> int: + return len(self.data) + + @property + def raw_folder(self) -> str: + return os.path.join(self.root, self.__class__.__name__, "raw") + + @property + def processed_folder(self) -> str: + return os.path.join(self.root, self.__class__.__name__, "processed") + + @property + def class_to_idx(self) -> Dict[str, int]: + return {_class: i for i, _class in enumerate(self.classes)} + + def _check_exists(self) -> bool: + return all( + check_integrity(os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0])) + for url, _ in self.resources + ) + + def download(self) -> None: + """Download the MNIST data if it doesn't exist already.""" + + if self._check_exists(): + return + + os.makedirs(self.raw_folder, exist_ok=True) + + # download files + for filename, md5 in self.resources: + for mirror in self.mirrors: + url = f"{mirror}{filename}" + try: + print(f"Downloading {url}") + download_and_extract_archive(url, download_root=self.raw_folder, filename=filename, md5=md5) + except URLError as error: + print(f"Failed to download (trying next):\n{error}") + continue + finally: + print() + break + else: + raise RuntimeError(f"Error downloading {filename}") + + def extra_repr(self) -> str: + split = "Train" if self.train is True else "Test" + return f"Split: {split}" + + +class FashionMNIST(MNIST): + """`Fashion-MNIST `_ Dataset. + + Args: + root (string): Root directory of dataset where ``FashionMNIST/raw/train-images-idx3-ubyte`` + and ``FashionMNIST/raw/t10k-images-idx3-ubyte`` exist. + train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``, + otherwise from ``t10k-images-idx3-ubyte``. + download (bool, optional): If True, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. + transform (callable, optional): A function/transform that takes in an PIL image + and returns a transformed version. E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + """ + + mirrors = ["http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/"] + + resources = [ + ("train-images-idx3-ubyte.gz", "8d4fb7e6c68d591d4c3dfef9ec88bf0d"), + ("train-labels-idx1-ubyte.gz", "25c81989df183df01b3e8a0aad5dffbe"), + ("t10k-images-idx3-ubyte.gz", "bef4ecab320f06d8554ea6380940ec79"), + ("t10k-labels-idx1-ubyte.gz", "bb300cfdad3c16e7a12a480ee83cd310"), + ] + classes = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"] + + +class KMNIST(MNIST): + """`Kuzushiji-MNIST `_ Dataset. + + Args: + root (string): Root directory of dataset where ``KMNIST/raw/train-images-idx3-ubyte`` + and ``KMNIST/raw/t10k-images-idx3-ubyte`` exist. + train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``, + otherwise from ``t10k-images-idx3-ubyte``. + download (bool, optional): If True, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. + transform (callable, optional): A function/transform that takes in an PIL image + and returns a transformed version. E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + """ + + mirrors = ["http://codh.rois.ac.jp/kmnist/dataset/kmnist/"] + + resources = [ + ("train-images-idx3-ubyte.gz", "bdb82020997e1d708af4cf47b453dcf7"), + ("train-labels-idx1-ubyte.gz", "e144d726b3acfaa3e44228e80efcd344"), + ("t10k-images-idx3-ubyte.gz", "5c965bf0a639b31b8f53240b1b52f4d7"), + ("t10k-labels-idx1-ubyte.gz", "7320c461ea6c1c855c0b718fb2a4b134"), + ] + classes = ["o", "ki", "su", "tsu", "na", "ha", "ma", "ya", "re", "wo"] + + +class EMNIST(MNIST): + """`EMNIST `_ Dataset. + + Args: + root (string): Root directory of dataset where ``EMNIST/raw/train-images-idx3-ubyte`` + and ``EMNIST/raw/t10k-images-idx3-ubyte`` exist. + split (string): The dataset has 6 different splits: ``byclass``, ``bymerge``, + ``balanced``, ``letters``, ``digits`` and ``mnist``. This argument specifies + which one to use. + train (bool, optional): If True, creates dataset from ``training.pt``, + otherwise from ``test.pt``. + download (bool, optional): If True, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. + transform (callable, optional): A function/transform that takes in an PIL image + and returns a transformed version. E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + """ + + url = "https://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/gzip.zip" + md5 = "58c8d27c78d21e728a6bc7b3cc06412e" + splits = ("byclass", "bymerge", "balanced", "letters", "digits", "mnist") + # Merged Classes assumes Same structure for both uppercase and lowercase version + _merged_classes = {"c", "i", "j", "k", "l", "m", "o", "p", "s", "u", "v", "w", "x", "y", "z"} + _all_classes = set(string.digits + string.ascii_letters) + classes_split_dict = { + "byclass": sorted(list(_all_classes)), + "bymerge": sorted(list(_all_classes - _merged_classes)), + "balanced": sorted(list(_all_classes - _merged_classes)), + "letters": ["N/A"] + list(string.ascii_lowercase), + "digits": list(string.digits), + "mnist": list(string.digits), + } + + def __init__(self, root: str, split: str, **kwargs: Any) -> None: + self.split = verify_str_arg(split, "split", self.splits) + self.training_file = self._training_file(split) + self.test_file = self._test_file(split) + super().__init__(root, **kwargs) + self.classes = self.classes_split_dict[self.split] + + @staticmethod + def _training_file(split) -> str: + return f"training_{split}.pt" + + @staticmethod + def _test_file(split) -> str: + return f"test_{split}.pt" + + @property + def _file_prefix(self) -> str: + return f"emnist-{self.split}-{'train' if self.train else 'test'}" + + @property + def images_file(self) -> str: + return os.path.join(self.raw_folder, f"{self._file_prefix}-images-idx3-ubyte") + + @property + def labels_file(self) -> str: + return os.path.join(self.raw_folder, f"{self._file_prefix}-labels-idx1-ubyte") + + def _load_data(self): + return read_image_file(self.images_file), read_label_file(self.labels_file) + + def _check_exists(self) -> bool: + return all(check_integrity(file) for file in (self.images_file, self.labels_file)) + + def download(self) -> None: + """Download the EMNIST data if it doesn't exist already.""" + + if self._check_exists(): + return + + os.makedirs(self.raw_folder, exist_ok=True) + + download_and_extract_archive(self.url, download_root=self.raw_folder, md5=self.md5) + gzip_folder = os.path.join(self.raw_folder, "gzip") + for gzip_file in os.listdir(gzip_folder): + if gzip_file.endswith(".gz"): + extract_archive(os.path.join(gzip_folder, gzip_file), self.raw_folder) + shutil.rmtree(gzip_folder) + + +class QMNIST(MNIST): + """`QMNIST `_ Dataset. + + Args: + root (string): Root directory of dataset whose ``raw`` + subdir contains binary files of the datasets. + what (string,optional): Can be 'train', 'test', 'test10k', + 'test50k', or 'nist' for respectively the mnist compatible + training set, the 60k qmnist testing set, the 10k qmnist + examples that match the mnist testing set, the 50k + remaining qmnist testing examples, or all the nist + digits. The default is to select 'train' or 'test' + according to the compatibility argument 'train'. + compat (bool,optional): A boolean that says whether the target + for each example is class number (for compatibility with + the MNIST dataloader) or a torch vector containing the + full qmnist information. Default=True. + download (bool, optional): If True, downloads the dataset from + the internet and puts it in root directory. If dataset is + already downloaded, it is not downloaded again. + transform (callable, optional): A function/transform that + takes in an PIL image and returns a transformed + version. E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform + that takes in the target and transforms it. + train (bool,optional,compatibility): When argument 'what' is + not specified, this boolean decides whether to load the + training set ot the testing set. Default: True. + """ + + subsets = {"train": "train", "test": "test", "test10k": "test", "test50k": "test", "nist": "nist"} + resources: Dict[str, List[Tuple[str, str]]] = { # type: ignore[assignment] + "train": [ + ( + "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-images-idx3-ubyte.gz", + "ed72d4157d28c017586c42bc6afe6370", + ), + ( + "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-labels-idx2-int.gz", + "0058f8dd561b90ffdd0f734c6a30e5e4", + ), + ], + "test": [ + ( + "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-images-idx3-ubyte.gz", + "1394631089c404de565df7b7aeaf9412", + ), + ( + "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-labels-idx2-int.gz", + "5b5b05890a5e13444e108efe57b788aa", + ), + ], + "nist": [ + ( + "https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-images-idx3-ubyte.xz", + "7f124b3b8ab81486c9d8c2749c17f834", + ), + ( + "https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-labels-idx2-int.xz", + "5ed0e788978e45d4a8bd4b7caec3d79d", + ), + ], + } + classes = [ + "0 - zero", + "1 - one", + "2 - two", + "3 - three", + "4 - four", + "5 - five", + "6 - six", + "7 - seven", + "8 - eight", + "9 - nine", + ] + + def __init__( + self, root: str, what: Optional[str] = None, compat: bool = True, train: bool = True, **kwargs: Any + ) -> None: + if what is None: + what = "train" if train else "test" + self.what = verify_str_arg(what, "what", tuple(self.subsets.keys())) + self.compat = compat + self.data_file = what + ".pt" + self.training_file = self.data_file + self.test_file = self.data_file + super().__init__(root, train, **kwargs) + + @property + def images_file(self) -> str: + (url, _), _ = self.resources[self.subsets[self.what]] + return os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0]) + + @property + def labels_file(self) -> str: + _, (url, _) = self.resources[self.subsets[self.what]] + return os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0]) + + def _check_exists(self) -> bool: + return all(check_integrity(file) for file in (self.images_file, self.labels_file)) + + def _load_data(self): + data = read_sn3_pascalvincent_tensor(self.images_file) + assert data.dtype == jnp.uint8 + assert data.ndim == 3 + + targets = read_sn3_pascalvincent_tensor(self.labels_file).long() + assert targets.ndimension() == 2 + + if self.what == "test10k": + data = data[0:10000, :, :].clone() + targets = targets[0:10000, :].clone() + elif self.what == "test50k": + data = data[10000:, :, :].clone() + targets = targets[10000:, :].clone() + + return data, targets + + def download(self) -> None: + """Download the QMNIST data if it doesn't exist already. + Note that we only download what has been asked for (argument 'what'). + """ + if self._check_exists(): + return + + os.makedirs(self.raw_folder, exist_ok=True) + split = self.resources[self.subsets[self.what]] + + for url, md5 in split: + download_and_extract_archive(url, self.raw_folder, md5=md5) + + def __getitem__(self, index: int) -> Tuple[Any, Any]: + # redefined to handle the compat flag + img, target = self.data[index], self.targets[index] + if Image is None: + raise PackageMissingError('Need pillow to read the image, pleas install pillow first.') + img = Image.fromarray(img.numpy(), mode="L") + if self.transform is not None: + img = self.transform(img) + if self.compat: + target = int(target[0]) + if self.target_transform is not None: + target = self.target_transform(target) + return img, target + + def extra_repr(self) -> str: + return f"Split: {self.what}" + + +def get_int(b: bytes) -> int: + return int(codecs.encode(b, "hex"), 16) + + +SN3_PASCALVINCENT_TYPEMAP = { + 8: jnp.uint8, + 9: jnp.int8, + 11: jnp.int16, + 12: jnp.int32, + 13: jnp.float32, + 14: jnp.float64, +} + + +def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> jnp.ndarray: + """Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh'). + Argument may be a filename, compressed filename, or file object. + """ + # read + with open(path, "rb") as f: + data = f.read() + # parse + magic = get_int(data[0:4]) + nd = magic % 256 + ty = magic // 256 + assert 1 <= nd <= 3 + assert 8 <= ty <= 14 + dtype = SN3_PASCALVINCENT_TYPEMAP[ty] + s = [get_int(data[4 * (i + 1): 4 * (i + 2)]) for i in range(nd)] + + num_bytes_per_value = jnp.iinfo(dtype).bits // 8 + # The MNIST format uses the big endian byte order. If the system uses little endian byte order by default, + # we need to reverse the bytes before we can read them with .frombuffer(). + needs_byte_reversal = sys.byteorder == "little" and num_bytes_per_value > 1 + parsed = jnp.frombuffer(bytearray(data), dtype=dtype, offset=(4 * (nd + 1))) + if needs_byte_reversal: + parsed = jnp.flip(parsed, 0) + assert parsed.shape[0] == np.prod(s) or not strict + return parsed.reshape(*s) + + +def read_label_file(path: str) -> jnp.ndarray: + x = read_sn3_pascalvincent_tensor(path, strict=False) + assert x.dtype == jnp.uint8 + assert x.ndim == 1 + return x.astype(bm.dftype()) + + +def read_image_file(path: str) -> jnp.ndarray: + x = read_sn3_pascalvincent_tensor(path, strict=False) + assert x.dtype == jnp.uint8 + assert x.ndim == 3 + return x diff --git a/brainpy/datasets/vision/utils.py b/brainpy/datasets/vision/utils.py new file mode 100644 index 000000000..04aed62ce --- /dev/null +++ b/brainpy/datasets/vision/utils.py @@ -0,0 +1,479 @@ +# -*- coding: utf-8 -*- + +import bz2 +import gzip +import hashlib +import itertools +import lzma +import os +import os.path +import pathlib +import re +import tarfile +import urllib +import urllib.error +import urllib.request +import zipfile +from typing import Any, Callable, List, Iterable, Optional, TypeVar, Dict, IO, Tuple, Iterator +from urllib.parse import urlparse + +from brainpy.errors import PackageMissingError + +try: + import requests +except ImportError: + requests = None +from tqdm import tqdm + +from .._internally_replaced_utils import ( + _download_file_from_remote_location, + _is_remote_location_available, +) + +# import torch +# from torch.utils.model_zoo import tqdm + +USER_AGENT = "pytorch/vision" + + +def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None: + with open(filename, "wb") as fh: + with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response: + with tqdm(total=response.length) as pbar: + for chunk in iter(lambda: response.read(chunk_size), ""): + if not chunk: + break + pbar.update(chunk_size) + fh.write(chunk) + + +def gen_bar_updater() -> Callable[[int, int, int], None]: + pbar = tqdm(total=None) + + def bar_update(count, block_size, total_size): + if pbar.total is None and total_size: + pbar.total = total_size + progress_bytes = count * block_size + pbar.update(progress_bytes - pbar.n) + + return bar_update + + +def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024) -> str: + md5 = hashlib.md5() + with open(fpath, "rb") as f: + for chunk in iter(lambda: f.read(chunk_size), b""): + md5.update(chunk) + return md5.hexdigest() + + +def check_md5(fpath: str, md5: str, **kwargs: Any) -> bool: + return md5 == calculate_md5(fpath, **kwargs) + + +def check_integrity(fpath: str, md5: Optional[str] = None) -> bool: + if not os.path.isfile(fpath): + return False + if md5 is None: + return True + return check_md5(fpath, md5) + + +def _get_redirect_url(url: str, max_hops: int = 3) -> str: + initial_url = url + headers = {"Method": "HEAD", "User-Agent": USER_AGENT} + + for _ in range(max_hops + 1): + with urllib.request.urlopen(urllib.request.Request(url, headers=headers)) as response: + if response.url == url or response.url is None: + return url + + url = response.url + else: + raise RecursionError( + f"Request to {initial_url} exceeded {max_hops} redirects. The last redirect points to {url}." + ) + + +def _get_google_drive_file_id(url: str) -> Optional[str]: + parts = urlparse(url) + + if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None: + return None + + match = re.match(r"/file/d/(?P[^/]*)", parts.path) + if match is None: + return None + + return match.group("id") + + +def download_url( + url: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None, max_redirect_hops: int = 3 +) -> None: + """Download a file from a url and place it in root. + + Args: + url (str): URL to download file from + root (str): Directory to place downloaded file in + filename (str, optional): Name to save the file under. If None, use the basename of the URL + md5 (str, optional): MD5 checksum of the download. If None, do not check + max_redirect_hops (int, optional): Maximum number of redirect hops allowed + """ + root = os.path.expanduser(root) + if not filename: + filename = os.path.basename(url) + fpath = os.path.join(root, filename) + + os.makedirs(root, exist_ok=True) + + # check if file is already present locally + if check_integrity(fpath, md5): + print("Using downloaded and verified file: " + fpath) + return + + if _is_remote_location_available(): + _download_file_from_remote_location(fpath, url) + else: + # expand redirect chain if needed + url = _get_redirect_url(url, max_hops=max_redirect_hops) + + # check if file is located on Google Drive + file_id = _get_google_drive_file_id(url) + if file_id is not None: + return download_file_from_google_drive(file_id, root, filename, md5) + + # download the file + try: + print("Downloading " + url + " to " + fpath) + _urlretrieve(url, fpath) + except (urllib.error.URLError, OSError) as e: # type: ignore[attr-defined] + if url[:5] == "https": + url = url.replace("https:", "http:") + print("Failed download. Trying https -> http instead. Downloading " + url + " to " + fpath) + _urlretrieve(url, fpath) + else: + raise e + + # check integrity of downloaded file + if not check_integrity(fpath, md5): + raise RuntimeError("File not found or corrupted.") + + +def list_dir(root: str, prefix: bool = False) -> List[str]: + """List all directories at a given root + + Args: + root (str): Path to directory whose folders need to be listed + prefix (bool, optional): If true, prepends the path to each result, otherwise + only returns the name of the directories found + """ + root = os.path.expanduser(root) + directories = [p for p in os.listdir(root) if os.path.isdir(os.path.join(root, p))] + if prefix is True: + directories = [os.path.join(root, d) for d in directories] + return directories + + +def list_files(root: str, suffix: str, prefix: bool = False) -> List[str]: + """List all files ending with a suffix at a given root + + Args: + root (str): Path to directory whose folders need to be listed + suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png'). + It uses the Python "str.endswith" method and is passed directly + prefix (bool, optional): If true, prepends the path to each result, otherwise + only returns the name of the files found + """ + root = os.path.expanduser(root) + files = [p for p in os.listdir(root) if os.path.isfile(os.path.join(root, p)) and p.endswith(suffix)] + if prefix is True: + files = [os.path.join(root, d) for d in files] + return files + + +def _quota_exceeded(first_chunk: bytes) -> bool: + try: + return "Google Drive - Quota exceeded" in first_chunk.decode() + except UnicodeDecodeError: + return False + + +def download_file_from_google_drive(file_id: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None): + """Download a Google Drive file from and place it in root. + + Args: + file_id (str): id of file to be downloaded + root (str): Directory to place downloaded file in + filename (str, optional): Name to save the file under. If None, use the id of the file. + md5 (str, optional): MD5 checksum of the download. If None, do not check + """ + # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url + + url = "https://docs.google.com/uc?export=download" + + root = os.path.expanduser(root) + if not filename: + filename = file_id + fpath = os.path.join(root, filename) + + os.makedirs(root, exist_ok=True) + + if os.path.isfile(fpath) and check_integrity(fpath, md5): + print("Using downloaded and verified file: " + fpath) + else: + if requests is None: + raise PackageMissingError('Need "requests" package, please install it.') + session = requests.Session() + + response = session.get(url, params={"id": file_id}, stream=True) + token = _get_confirm_token(response) + + if token: + params = {"id": file_id, "confirm": token} + response = session.get(url, params=params, stream=True) + + # Ideally, one would use response.status_code to check for quota limits, but google drive is not consistent + # with their own API, refer https://github.com/pytorch/vision/issues/2992#issuecomment-730614517. + # Should this be fixed at some place in future, one could refactor the following to no longer rely on decoding + # the first_chunk of the payload + response_content_generator = response.iter_content(32768) + first_chunk = None + while not first_chunk: # filter out keep-alive new chunks + first_chunk = next(response_content_generator) + + if _quota_exceeded(first_chunk): + msg = ( + f"The daily quota of the file {filename} is exceeded and it " + f"can't be downloaded. This is a limitation of Google Drive " + f"and can only be overcome by trying again later." + ) + raise RuntimeError(msg) + + _save_response_content(itertools.chain((first_chunk,), response_content_generator), fpath) + response.close() + + +def _get_confirm_token(response) -> Optional[str]: + for key, value in response.cookies.items(): + if key.startswith("download_warning"): + return value + + return None + + +def _save_response_content( + response_gen: Iterator[bytes], + destination: str, +) -> None: + with open(destination, "wb") as f: + pbar = tqdm(total=None) + progress = 0 + + for chunk in response_gen: + if chunk: # filter out keep-alive new chunks + f.write(chunk) + progress += len(chunk) + pbar.update(progress - pbar.n) + pbar.close() + + +def _extract_tar(from_path: str, to_path: str, compression: Optional[str]) -> None: + with tarfile.open(from_path, f"r:{compression[1:]}" if compression else "r") as tar: + tar.extractall(to_path) + + +_ZIP_COMPRESSION_MAP: Dict[str, int] = { + ".bz2": zipfile.ZIP_BZIP2, + ".xz": zipfile.ZIP_LZMA, +} + + +def _extract_zip(from_path: str, to_path: str, compression: Optional[str]) -> None: + with zipfile.ZipFile( + from_path, "r", compression=_ZIP_COMPRESSION_MAP[compression] if compression else zipfile.ZIP_STORED + ) as zip: + zip.extractall(to_path) + + +_ARCHIVE_EXTRACTORS: Dict[str, Callable[[str, str, Optional[str]], None]] = { + ".tar": _extract_tar, + ".zip": _extract_zip, +} +_COMPRESSED_FILE_OPENERS: Dict[str, Callable[..., IO]] = { + ".bz2": bz2.open, + ".gz": gzip.open, + ".xz": lzma.open, +} +_FILE_TYPE_ALIASES: Dict[str, Tuple[Optional[str], Optional[str]]] = { + ".tbz": (".tar", ".bz2"), + ".tbz2": (".tar", ".bz2"), + ".tgz": (".tar", ".gz"), +} + + +def _detect_file_type(file: str) -> Tuple[str, Optional[str], Optional[str]]: + """Detect the archive type and/or compression of a file. + + Args: + file (str): the filename + + Returns: + (tuple): tuple of suffix, archive type, and compression + + Raises: + RuntimeError: if file has no suffix or suffix is not supported + """ + suffixes = pathlib.Path(file).suffixes + if not suffixes: + raise RuntimeError( + f"File '{file}' has no suffixes that could be used to detect the archive type and compression." + ) + suffix = suffixes[-1] + + # check if the suffix is a known alias + if suffix in _FILE_TYPE_ALIASES: + return (suffix, *_FILE_TYPE_ALIASES[suffix]) + + # check if the suffix is an archive type + if suffix in _ARCHIVE_EXTRACTORS: + return suffix, suffix, None + + # check if the suffix is a compression + if suffix in _COMPRESSED_FILE_OPENERS: + # check for suffix hierarchy + if len(suffixes) > 1: + suffix2 = suffixes[-2] + + # check if the suffix2 is an archive type + if suffix2 in _ARCHIVE_EXTRACTORS: + return suffix2 + suffix, suffix2, suffix + + return suffix, None, suffix + + valid_suffixes = sorted(set(_FILE_TYPE_ALIASES) | set(_ARCHIVE_EXTRACTORS) | set(_COMPRESSED_FILE_OPENERS)) + raise RuntimeError(f"Unknown compression or archive type: '{suffix}'.\nKnown suffixes are: '{valid_suffixes}'.") + + +def _decompress(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str: + r"""Decompress a file. + + The compression is automatically detected from the file name. + + Args: + from_path (str): Path to the file to be decompressed. + to_path (str): Path to the decompressed file. If omitted, ``from_path`` without compression extension is used. + remove_finished (bool): If ``True``, remove the file after the extraction. + + Returns: + (str): Path to the decompressed file. + """ + suffix, archive_type, compression = _detect_file_type(from_path) + if not compression: + raise RuntimeError(f"Couldn't detect a compression from suffix {suffix}.") + + if to_path is None: + to_path = from_path.replace(suffix, archive_type if archive_type is not None else "") + + # We don't need to check for a missing key here, since this was already done in _detect_file_type() + compressed_file_opener = _COMPRESSED_FILE_OPENERS[compression] + + with compressed_file_opener(from_path, "rb") as rfh, open(to_path, "wb") as wfh: + wfh.write(rfh.read()) + + if remove_finished: + os.remove(from_path) + + return to_path + + +def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str: + """Extract an archive. + + The archive type and a possible compression is automatically detected from the file name. If the file is compressed + but not an archive the call is dispatched to :func:`decompress`. + + Args: + from_path (str): Path to the file to be extracted. + to_path (str): Path to the directory the file will be extracted to. If omitted, the directory of the file is + used. + remove_finished (bool): If ``True``, remove the file after the extraction. + + Returns: + (str): Path to the directory the file was extracted to. + """ + if to_path is None: + to_path = os.path.dirname(from_path) + + suffix, archive_type, compression = _detect_file_type(from_path) + if not archive_type: + return _decompress( + from_path, + os.path.join(to_path, os.path.basename(from_path).replace(suffix, "")), + remove_finished=remove_finished, + ) + + # We don't need to check for a missing key here, since this was already done in _detect_file_type() + extractor = _ARCHIVE_EXTRACTORS[archive_type] + + extractor(from_path, to_path, compression) + if remove_finished: + os.remove(from_path) + + return to_path + + +def download_and_extract_archive( + url: str, + download_root: str, + extract_root: Optional[str] = None, + filename: Optional[str] = None, + md5: Optional[str] = None, + remove_finished: bool = False, +) -> None: + download_root = os.path.expanduser(download_root) + if extract_root is None: + extract_root = download_root + if not filename: + filename = os.path.basename(url) + + download_url(url, download_root, filename, md5) + + archive = os.path.join(download_root, filename) + print(f"Extracting {archive} to {extract_root}") + extract_archive(archive, extract_root, remove_finished) + + +def iterable_to_str(iterable: Iterable) -> str: + return "'" + "', '".join([str(item) for item in iterable]) + "'" + + +T = TypeVar("T", str, bytes) + + +def verify_str_arg( + value: T, + arg: Optional[str] = None, + valid_values: Iterable[T] = None, + custom_msg: Optional[str] = None, +) -> T: + if not isinstance(value, (str, bytes)): + if arg is None: + msg = "Expected type str, but got type {type}." + else: + msg = "Expected type str for argument {arg}, but got type {type}." + msg = msg.format(type=type(value), arg=arg) + raise ValueError(msg) + + if valid_values is None: + return value + + if value not in valid_values: + if custom_msg is not None: + msg = custom_msg + else: + msg = "Unknown value '{value}' for argument {arg}. Valid values are {{{valid_values}}}." + msg = msg.format(value=value, arg=arg, valid_values=iterable_to_str(valid_values)) + raise ValueError(msg) + + return value diff --git a/brainpy/dyn/__init__.py b/brainpy/dyn/__init__.py index 74dd39430..9e959af67 100644 --- a/brainpy/dyn/__init__.py +++ b/brainpy/dyn/__init__.py @@ -4,14 +4,15 @@ Dynamics simulation module. """ - from .base import * +from .training import * from .neurons.compat import * from .synapses.compat import * -from .utils import * from .runners import * -from . import (channels, neurons, rates, - synapses, synouts, synplast, +from . import (channels, neurons, rates, # neuron related + synapses, synouts, synplast, # synapse related networks, - utils, runners) + layers, # ANN related + runners) + diff --git a/brainpy/dyn/base.py b/brainpy/dyn/base.py index 1527121dc..86a047b07 100644 --- a/brainpy/dyn/base.py +++ b/brainpy/dyn/base.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- -import warnings -from typing import Union, Dict, Callable, Sequence, Optional, List, Tuple +import gc +from typing import Union, Dict, Callable, Sequence, Optional, Tuple import jax.numpy as jnp import numpy as np @@ -10,11 +10,10 @@ from brainpy import tools from brainpy.base.base import Base from brainpy.base.collector import Collector -from brainpy.connect import TwoEndConnector, MatConn, IJConn -from brainpy.dyn.utils import init_noise +from brainpy.connect import TwoEndConnector, MatConn, IJConn, One2One, All2All from brainpy.errors import ModelBuildError -from brainpy.initialize import Initializer, init_param, Uniform -from brainpy.integrators import Integrator, odeint, sdeint +from brainpy.initialize import Initializer, parameter, variable, Uniform, noise as init_noise +from brainpy.integrators import odeint, sdeint from brainpy.tools.others import to_size, size2num from brainpy.types import Tensor, Shape @@ -32,8 +31,7 @@ 'NeuGroup', 'CondNeuGroup', # synapse models - 'SynConn', 'SynapseOutput', 'SynapsePlasticity', 'TwoEndConn', - + 'SynConn', 'SynOutput', 'SynSTP', 'SynLTP', 'TwoEndConn', ] @@ -50,25 +48,37 @@ class DynamicalSystem(Base): The name of the dynamic system. """ + """Global delay data, which stores the delay variables and corresponding delay targets. + + This variable is useful when the same target variable is used in multiple mappings, + because it can reduce the duplicate delay variable registration.""" + global_delay_data: Dict[str, Tuple[Union[bm.LengthDelay, None], bm.Variable]] = dict() - """Global delay variables. Useful when the same target - variable is used in multiple mappings.""" - global_delay_vars: Dict[str, bm.LengthDelay] = Collector() - global_delay_targets: Dict[str, bm.Variable] = Collector() - - def __init__(self, name=None): + def __init__( + self, + name: str = None, + trainable: bool = False, + ): super(DynamicalSystem, self).__init__(name=name) # local delay variables self.local_delay_vars: Dict[str, bm.LengthDelay] = Collector() - def __repr__(self): - return f'{self.__class__.__name__}(name={self.name})' + # trainable setting + self._trainable = trainable @property - def steps(self): - warnings.warn('.steps has been deprecated since version 2.0.3.', DeprecationWarning) - return {} + def trainable(self): + return self._trainable + + @trainable.setter + def trainable(self, value): + if not isinstance(value, bool): + raise ValueError(f'Must be a bool value. But we got {type(value)}: {value}') + self._trainable = value + + def __repr__(self): + return f'{self.__class__.__name__}(name={self.name}, trainable={self.trainable})' def __call__(self, *args, **kwargs): """The shortcut to call ``update`` methods.""" @@ -76,7 +86,7 @@ def __call__(self, *args, **kwargs): def register_delay( self, - name: str, + identifier: str, delay_step: Optional[Union[int, Tensor, Callable, Initializer]], delay_target: bm.Variable, initial_delay_data: Union[Initializer, Callable, Tensor, float, int, bool] = None, @@ -85,7 +95,7 @@ def register_delay( Parameters ---------- - name: str + identifier: str The delay variable name. delay_step: Optional, int, JaxArray, ndarray, callable, Initializer The number of the steps of the delay. @@ -111,7 +121,7 @@ def register_delay( delay_type = 'heter' delay_step = bm.asarray(delay_step) elif callable(delay_step): - delay_step = init_param(delay_step, delay_target.shape, allow_none=False) + delay_step = parameter(delay_step, delay_target.shape, allow_none=False) delay_type = 'heter' else: raise ValueError(f'Unknown "delay_steps" type {type(delay_step)}, only support ' @@ -127,36 +137,39 @@ def register_delay( max_delay_step = int(bm.max(delay_step)) # delay target - if not isinstance(delay_target, bm.Variable): - raise ValueError(f'"delay_target" must be an instance of Variable, but we got {type(delay_target)}') + if delay_type != 'none': + if not isinstance(delay_target, bm.Variable): + raise ValueError(f'"delay_target" must be an instance of Variable, but we got {type(delay_target)}') # delay variable - self.global_delay_targets[name] = delay_target if delay_type != 'none': - if name not in self.global_delay_vars: - self.global_delay_vars[name] = bm.LengthDelay(delay_target, max_delay_step, initial_delay_data) - self.local_delay_vars[name] = self.global_delay_vars[name] + if identifier not in self.global_delay_data: + delay = bm.LengthDelay(delay_target, max_delay_step, initial_delay_data) + self.global_delay_data[identifier] = (delay, delay_target) + self.local_delay_vars[identifier] = delay else: - if self.global_delay_vars[name].num_delay_step - 1 < max_delay_step: - self.global_delay_vars[name].reset(delay_target, max_delay_step, initial_delay_data) + if self.global_delay_data[identifier][0].num_delay_step - 1 < max_delay_step: + self.global_delay_data[identifier][0].reset(delay_target, max_delay_step, initial_delay_data) + else: + self.global_delay_data[identifier] = (None, delay_target) self.register_implicit_nodes(self.local_delay_vars) return delay_step def get_delay_data( self, - name: str, + identifier: str, delay_step: Optional[Union[int, bm.JaxArray, jnp.DeviceArray]], - *indices: Union[int, bm.JaxArray, jnp.DeviceArray], + *indices: Union[int, slice, bm.JaxArray, jnp.DeviceArray], ): """Get delay data according to the provided delay steps. Parameters ---------- - name: str + identifier: str The delay variable name. delay_step: Optional, int, JaxArray, ndarray The delay length. - indices: optional, int, JaxArray, ndarray + indices: optional, int, slice, JaxArray, ndarray The indices of the delay. Returns @@ -165,90 +178,100 @@ def get_delay_data( The delay data at the given time. """ if delay_step is None: - return self.global_delay_targets[name] + return self.global_delay_data[identifier][1].value - if name in self.global_delay_vars: + if identifier in self.global_delay_data: if isinstance(delay_step, (int, np.integer)): - return self.global_delay_vars[name](delay_step, *indices) + return self.global_delay_data[identifier][0](delay_step, *indices) else: if len(indices) == 0: indices = (jnp.arange(delay_step.size),) - return self.global_delay_vars[name](delay_step, *indices) + return self.global_delay_data[identifier][0](delay_step, *indices) - elif name in self.local_delay_vars: + elif identifier in self.local_delay_vars: if isinstance(delay_step, (int, np.integer)): - return self.local_delay_vars[name](delay_step) + return self.local_delay_vars[identifier](delay_step) else: if len(indices) == 0: indices = (jnp.arange(delay_step.size),) - return self.local_delay_vars[name](delay_step, *indices) + return self.local_delay_vars[identifier](delay_step, *indices) else: - raise ValueError(f'{name} is not defined in delay variables.') - - def update_delay( - self, - name: str, - delay_data: Union[float, bm.JaxArray, jnp.ndarray] - ): - """Update the delay according to the delay data. + raise ValueError(f'{identifier} is not defined in delay variables.') - Parameters - ---------- - name: str - The name of the delay. - delay_data: float, JaxArray, ndarray - The delay data to update at the current time. - """ - warnings.warn('All registered delays by "register_delay()" will be ' - 'automatically updated in the network model since 2.1.13. ' - 'Explicitly call "update_delay()" has no effect.', - DeprecationWarning) - # if name in self.local_delay_vars: - # return self.local_delay_vars[name].update(delay_data) - # else: - # if name not in self.global_delay_vars: - # raise ValueError(f'{name} is not defined in delay variables.') - - def reset_delay( - self, - name: str, - delay_target: Union[bm.JaxArray, jnp.DeviceArray] - ): - """Reset the delay variable.""" - warnings.warn('All registered delays by "register_delay()" will be ' - 'automatically reset in the network model since 2.1.13. ' - 'Explicitly call "reset_delay()" has no effect.', - DeprecationWarning) - # if name in self.local_delay_vars: - # return self.local_delay_vars[name].reset(delay_target) - # else: - # if name not in self.global_delay_vars: - # raise ValueError(f'{name} is not defined in delay variables.') - - def update(self, t, dt): + def update(self, *args, **kwargs): """The function to specify the updating rule. - Assume any dynamical system depends on the time variable ``t`` and - the time step ``dt``. + + Assume any dynamical system depends on the shared variables (`sha`), + like time variable ``t``, the step precision ``dt``, and the time step `i`. """ raise NotImplementedError('Must implement "update" function by subclass self.') - def reset(self): + def reset(self, batch_size=None): """Reset function which reset the whole variables in the model. """ - raise NotImplementedError('Must implement "reset" function by subclass self.') + self.reset_state(batch_size) - def update_local_delays(self): + def reset_state(self, batch_size=None): + """Reset function which reset the states in the model. + """ + raise NotImplementedError('Must implement "reset_state" function by subclass self.') + + def update_local_delays(self, nodes: Union[Sequence, Dict] = None): + """Update local delay variables. + + Parameters + ---------- + nodes: sequence, dict + The nodes to update their delay variables. + """ # update delays - for node in self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique().values(): + if nodes is None: + nodes = self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique().values() + elif isinstance(nodes, dict): + nodes = nodes.values() + for node in nodes: for name in node.local_delay_vars.keys(): - self.global_delay_vars[name].update(self.global_delay_targets[name].value) + delay = self.global_delay_data[name][0] + target = self.global_delay_data[name][1] + delay.update(target.value) + + def reset_local_delays(self, nodes: Union[Sequence, Dict] = None): + """Reset local delay variables. - def reset_local_delays(self): + Parameters + ---------- + nodes: sequence, dict + The nodes to Reset their delay variables. + """ # reset delays - for node in self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique().values(): + if nodes is None: + nodes = self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique().values() + elif isinstance(nodes, dict): + nodes = nodes.values() + for node in nodes: for name in node.local_delay_vars.keys(): - self.global_delay_vars[name].reset(self.global_delay_targets[name]) + delay = self.global_delay_data[name][0] + target = self.global_delay_data[name][1] + delay.reset(target.value) + + def __del__(self): + """Function for handling `del` behavior. + + This function is used to pop out the variables which registered in global delay data. + """ + for key in tuple(self.local_delay_vars.keys()): + val = self.global_delay_data.pop(key) + del val + val = self.local_delay_vars.pop(key) + del val + for key in tuple(self.implicit_nodes.keys()): + del self.implicit_nodes[key] + for key in tuple(self.implicit_vars.keys()): + del self.implicit_vars[key] + for key in tuple(self.__dict__.keys()): + del self.__dict__[key] + gc.collect() class Container(DynamicalSystem): @@ -268,8 +291,8 @@ class Container(DynamicalSystem): The instance of DynamicalSystem with the format of "key=dynamic_system". """ - def __init__(self, *ds_tuple, name=None, **ds_dict): - super(Container, self).__init__(name=name) + def __init__(self, *ds_tuple, name=None, trainable=False, **ds_dict): + super(Container, self).__init__(name=name, trainable=trainable) # children dynamical systems self.implicit_nodes = Collector() @@ -294,7 +317,7 @@ def __repr__(self): children = [f'{key}={str(val)}' for key, val in self.implicit_nodes.items()] return f'{cls_name}({split.join(children)})' - def update(self, t, dt): + def update(self, tdi, *args, **kwargs): """Update function of a container. In this update function, the update functions in children systems are @@ -302,7 +325,7 @@ def update(self, t, dt): """ nodes = self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique() for node in nodes.values(): - node.update(t, dt) + node.update(tdi) def __getitem__(self, item): """Wrap the slice access (self['']). """ @@ -338,10 +361,19 @@ class Network(Container): A dict container of dynamical system. """ - def __init__(self, *ds_tuple, name=None, **ds_dict): - super(Network, self).__init__(*ds_tuple, name=name, **ds_dict) + def __init__( + self, + *ds_tuple, + name: str = None, + trainable: bool = False, + **ds_dict + ): + super(Network, self).__init__(*ds_tuple, + name=name, + trainable=trainable, + **ds_dict) - def update(self, t, dt): + def update(self, *args, **kwargs): """Step function of a network. In this update function, the update functions in children systems are @@ -354,44 +386,43 @@ def update(self, t, dt): synapse_groups = nodes.subset(SynConn) other_nodes = nodes - neuron_groups - synapse_groups + # shared arguments + shared = args[0] + # update synapse nodes for node in synapse_groups.values(): - node.update(t, dt) + node.update(shared) # update neuron nodes for node in neuron_groups.values(): - node.update(t, dt) + node.update(shared) # update other types of nodes for node in other_nodes.values(): - node.update(t, dt) + node.update(shared) # update delays - for node in nodes.values(): - for name in node.local_delay_vars.keys(): - self.global_delay_vars[name].update(self.global_delay_targets[name].value) + self.update_local_delays(nodes) - def reset(self): + def reset_state(self, batch_size=None): nodes = self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique() neuron_groups = nodes.subset(NeuGroup) synapse_groups = nodes.subset(SynConn) # reset neuron nodes for node in neuron_groups.values(): - node.reset() + node.reset_state(batch_size) # reset synapse nodes for node in synapse_groups.values(): - node.reset() + node.reset_state(batch_size) # reset other types of nodes for node in (nodes - neuron_groups - synapse_groups).values(): - node.reset() + node.reset_state(batch_size) # reset delays - for node in nodes.values(): - for name in node.local_delay_vars.keys(): - self.global_delay_vars[name].reset(self.global_delay_targets[name]) + self.reset_local_delays(nodes) class NeuGroup(DynamicalSystem): @@ -421,7 +452,8 @@ def __init__( self, size: Shape, name: str = None, - keep_size: bool = False + keep_size: bool = False, + trainable: bool = False, ): # size if isinstance(size, (list, tuple)): @@ -443,21 +475,27 @@ def __init__( self.num = tools.size2num(size) # initialize - super(NeuGroup, self).__init__(name=name) + super(NeuGroup, self).__init__(name=name, trainable=trainable) @property - def var_shape(self): - return self.size if self.keep_size else self.num + def varshape(self): + return self.size if self.keep_size else (self.num,) + + def get_batch_shape(self, batch_size=None): + if batch_size is None: + return self.varshape + else: + return (batch_size,) + self.varshape - def update(self, t, dt): + def update(self, tdi, x=None): """The function to specify the updating rule. Parameters ---------- - t : float - The current time. - dt : float - The time step. + tdi : DotDict + The shared arguments, especially time `t`, step `dt`, and iteration `i`. + x: Any + The input for a neuron group. """ raise NotImplementedError(f'Subclass of {self.__class__.__name__} must ' f'implement "update" function.') @@ -484,7 +522,10 @@ def __init__( post: NeuGroup, conn: Union[TwoEndConnector, Tensor, Dict[str, Tensor]] = None, name: str = None, + trainable: bool = False, ): + super(SynConn, self).__init__(name=name, trainable=trainable) + # pre or post neuron group # ------------------------ if not isinstance(pre, NeuGroup): @@ -517,10 +558,6 @@ def __init__( else: raise ModelBuildError(f'Unknown "conn" type: {conn}') - # initialize - # ---------- - super(SynConn, self).__init__(name=name) - def check_pre_attrs(self, *attrs): """Check whether pre group satisfies the requirement.""" if not hasattr(self, 'pre'): @@ -541,63 +578,58 @@ def check_post_attrs(self, *attrs): if not hasattr(self.post, attr): raise ModelBuildError(f'{self} need "pre" neuron group has attribute "{attr}".') + def update(self, tdi, pre_spike=None): + """The function to specify the updating rule. + + Assume any dynamical system depends on the shared variables (`sha`), + like time variable ``t``, the step precision ``dt``, and the time step `i`. + """ + raise NotImplementedError('Must implement "update" function by subclass self.') + -class SynapseComponent(DynamicalSystem): +class SynComponent(DynamicalSystem): master: SynConn + def reset_state(self, batch_size=None): + pass + def filter(self, g): - raise NotImplementedError + return g + + def __call__(self, *args, **kwargs): + return self.filter(*args, **kwargs) def register_master(self, master: SynConn): if not isinstance(master, SynConn): raise TypeError(f'master must be instance of {SynConn.__name__}, but we got {type(master)}') + if hasattr(self, 'master') and self.master != master: + raise ValueError(f'master has been registered, but we got another master going to be registered.') self.master = master def __repr__(self): - if hasattr(self, 'master'): - return f'{self.__class__.__name__}(master={self.master})' - else: - return self.__class__.__name__ + return self.__class__.__name__ -class SynapseOutput(SynapseComponent): +class SynOutput(SynComponent): """Base class for synaptic current output.""" - def reset(self): + def update(self, tdi): pass - def update(self, t, dt): - pass +class SynSTP(SynComponent): + """Base class for synaptic short-term plasticity.""" -class _NullSynOut(SynapseOutput): - def update(self, t, dt): + def update(self, tdi, pre_spike): pass - def reset(self): - pass - def filter(self, g): - return g +class SynLTP(SynComponent): + """Base class for synaptic long-term plasticity.""" - -class SynapsePlasticity(SynapseComponent): - """Base class for synaptic plasticity.""" - - def update(self, t, dt, pre_spikes, post_spikes): - raise NotImplementedError - - -class _NullSynPlast(SynapsePlasticity): - def update(self, t, dt, pre_spikes, post_spikes): + def update(self, tdi, pre_spike): pass - def reset(self): - pass - - def filter(self, g): - return g - class TwoEndConn(SynConn): """Base class to model synaptic connections. @@ -610,13 +642,13 @@ class TwoEndConn(SynConn): Post-synaptic neuron group. conn : optional, ndarray, JaxArray, dict, TwoEndConnector The connection method between pre- and post-synaptic groups. - output: SynapseOutput + output: SynOutput The output for the synaptic current. .. versionadded:: 2.1.13 The output component for a two-end connection model. - plasticity: SynapsePlasticity + stp: SynSTP The plasticity model for the synaptic variables. .. versionadded:: 2.1.13 @@ -631,27 +663,95 @@ def __init__( pre: NeuGroup, post: NeuGroup, conn: Union[TwoEndConnector, Tensor, Dict[str, Tensor]] = None, - output: Optional[SynapseOutput] = None, - plasticity: Optional[SynapsePlasticity] = None, + output: Optional[SynOutput] = None, + stp: Optional[SynSTP] = None, name: str = None, + trainable: bool = False, ): - super(TwoEndConn, self).__init__(pre=pre, post=post, conn=conn, name=name) + super(TwoEndConn, self).__init__(pre=pre, + post=post, + conn=conn, + name=name, + trainable=trainable) # synaptic output - if output is None: - output = _NullSynOut() - if not isinstance(output, SynapseOutput): - raise TypeError(f'output must be instance of {SynapseOutput.__name__}, but we got {type(output)}') - self.output: SynapseOutput = output + if output is None: output = SynOutput() + if not isinstance(output, SynOutput): + raise TypeError(f'output must be instance of {SynOutput.__name__}, but we got {type(output)}') + self.output: SynOutput = output self.output.register_master(master=self) # synaptic plasticity - if plasticity is None: - plasticity = _NullSynPlast() - if not isinstance(plasticity, SynapsePlasticity): - raise TypeError(f'plasticity must be instance of {SynapsePlasticity.__name__}, but we got {type(plasticity)}') - self.plasticity: SynapsePlasticity = plasticity - self.plasticity.register_master(master=self) + if stp is None: stp = SynSTP() + if not isinstance(stp, SynSTP): + raise TypeError(f'plasticity must be instance of {SynSTP.__name__}, but we got {type(stp)}') + self.stp: SynSTP = stp + self.stp.register_master(master=self) + + def init_weights( + self, + weight: Union[float, Tensor, Initializer, Callable], + comp_method: str, + sparse_data: str = 'csr' + ) -> Union[float, Tensor]: + if comp_method not in ['sparse', 'dense']: + raise ValueError(f'"comp_method" must be in "sparse" and "dense", but we got {comp_method}') + if sparse_data not in ['csr', 'ij']: + raise ValueError(f'"sparse_data" must be in "csr" and "ij", but we got {sparse_data}') + if self.conn is None: + raise ValueError(f'Must provide "conn" when initialize the model {self.name}') + + # connections and weights + if isinstance(self.conn, One2One): + weight = parameter(weight, (self.pre.num,), allow_none=False) + conn_mask = None + + elif isinstance(self.conn, All2All): + weight = parameter(weight, (self.pre.num, self.post.num), allow_none=False) + conn_mask = None + + else: + if comp_method == 'sparse': + if sparse_data == 'csr': + conn_mask = self.conn.require('pre2post') + elif sparse_data == 'ij': + conn_mask = self.conn.require('post_ids', 'pre_ids') + else: + ValueError(f'Unknown sparse data type: {sparse_data}') + weight = parameter(weight, conn_mask[1].shape, allow_none=False) + elif comp_method == 'dense': + weight = parameter(weight, (self.pre.num, self.post.num), allow_none=False) + conn_mask = self.conn.require('conn_mat') + else: + raise ValueError(f'Unknown connection type: {comp_method}') + + # training weights + if self.trainable: + weight = bm.TrainVar(weight) + return weight, conn_mask + + def syn2post_with_all2all(self, syn_value, syn_weight): + if bm.size(syn_weight) == 1: + if self.trainable: + post_vs = bm.sum(syn_value, keepdims=True, axis=tuple(range(syn_value.ndim))[1:]) + else: + post_vs = bm.sum(syn_value) + if not self.conn.include_self: + post_vs = post_vs - syn_value + post_vs = syn_weight * post_vs + else: + post_vs = syn_value @ syn_weight + return post_vs + + def syn2post_with_one2one(self, syn_value, syn_weight): + return syn_value * syn_weight + + def syn2post_with_dense(self, syn_value, syn_weight, conn_mat): + if bm.size(syn_weight) == 1: + post_vs = (syn_weight * syn_value) @ conn_mat + else: + post_vs = syn_value @ (syn_weight * conn_mat) + return post_vs class CondNeuGroup(NeuGroup, Container): @@ -712,22 +812,24 @@ def __init__( noise: Union[float, Tensor, Initializer, Callable] = None, method: str = 'exp_auto', name: str = None, + trainable: bool = False, **channels ): - NeuGroup.__init__(self, size, keep_size=keep_size) - Container.__init__(self, **channels, name=name) + NeuGroup.__init__(self, size, keep_size=keep_size, trainable=trainable) + Container.__init__(self, **channels, name=name, trainable=trainable) # parameters for neurons self.C = C self.A = A self.V_th = V_th self._V_initializer = V_initializer - self.noise = init_noise(noise, self.var_shape, num_vars=3) + self.noise = init_noise(noise, self.varshape, num_vars=3) # variables - self.V = bm.Variable(init_param(V_initializer, self.var_shape, allow_none=False)) - self.input = bm.Variable(bm.zeros(self.var_shape)) - self.spike = bm.Variable(bm.zeros(self.var_shape, dtype=bool)) + self.V = variable(V_initializer, trainable, self.varshape) + self.input = variable(bm.zeros, trainable, self.varshape) + sp_type = bm.dftype() if self.trainable else bool + self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), trainable, self.varshape) # function if self.noise is None: @@ -742,16 +844,17 @@ def derivative(self, V, t): Iext = Iext + ch.current(V) return Iext / self.C - def reset(self): - self.V.value = init_param(self._V_initializer, self.var_shape, allow_none=False) - self.spike[:] = False - self.input[:] = 0 + def reset_state(self, batch_size=None): + self.V.value = variable(self._V_initializer, batch_size, self.varshape) + sp_type = bm.dftype() if self.trainable else bool + self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape) + self.input.value = variable(bm.zeros, batch_size, self.varshape) - def update(self, t, dt): - V = self.integral(self.V.value, t, dt) + def update(self, tdi, *args, **kwargs): + V = self.integral(self.V.value, tdi['t'], tdi['dt']) channels = self.nodes(level=1, include_self=False).subset(Channel).unique() for node in channels.values(): - node.update(t, dt, self.V.value) + node.update(tdi, self.V.value) self.spike.value = bm.logical_and(V >= self.V_th, self.V < self.V_th) self.input[:] = 0. self.V.value = V @@ -771,8 +874,9 @@ def __init__( size: Union[int, Sequence[int]], name: str = None, keep_size: bool = False, + trainable: bool = False, ): - super(Channel, self).__init__(name=name) + super(Channel, self).__init__(name=name, trainable=trainable) # the geometry size self.size = to_size(size) # the number of elements @@ -781,16 +885,16 @@ def __init__( self.keep_size = keep_size @property - def var_shape(self): + def varshape(self): return self.size if self.keep_size else self.num - def update(self, t, dt, V): + def update(self, tdi, V): raise NotImplementedError('Must be implemented by the subclass.') def current(self, V): raise NotImplementedError('Must be implemented by the subclass.') - def reset(self, V): + def reset_state(self, batch_size=None): raise NotImplementedError('Must be implemented by the subclass.') diff --git a/brainpy/dyn/channels/Ca.py b/brainpy/dyn/channels/Ca.py index d027d2ed3..d4dc9316c 100644 --- a/brainpy/dyn/channels/Ca.py +++ b/brainpy/dyn/channels/Ca.py @@ -5,12 +5,11 @@ """ - from typing import Union, Callable import brainpy.math as bm from brainpy.dyn.base import Channel -from brainpy.initialize import OneInit, Initializer, init_param +from brainpy.initialize import OneInit, Initializer, parameter, variable from brainpy.integrators.joint_eq import JointEq from brainpy.integrators.ode import odeint from brainpy.types import Shape, Tensor @@ -50,25 +49,27 @@ def __init__( C: Union[float, Tensor, Initializer, Callable] = 2.4e-4, method: str = 'exp_auto', name: str = None, + trainable: bool = False, **channels ): super(CalciumFixed, self).__init__(size, keep_size=keep_size, method=method, name=name, + trainable=trainable, **channels) - self.E = init_param(E, self.var_shape, allow_none=False) - self.C = init_param(C, self.var_shape, allow_none=False) + self.E = parameter(E, self.varshape, allow_none=False) + self.C = parameter(C, self.varshape, allow_none=False) - def update(self, t, dt, V): + def update(self, tdi, V): for node in self.implicit_nodes.values(): - node.update(t, dt, V, self.C, self.E) + node.update(tdi, V, self.C, self.E) - def reset(self, V, C_Ca=None, E_Ca=None): + def reset_state(self, V, C_Ca=None, E_Ca=None, batch_size=None): C_Ca = self.C if C_Ca is None else C_Ca E_Ca = self.E if E_Ca is None else E_Ca for node in self.nodes(level=1, include_self=False).unique().subset(Channel).values(): - node.reset(V, C_Ca, E_Ca) + node.reset_state(V, C_Ca, E_Ca, batch_size=batch_size) class CalciumDyna(Calcium): @@ -103,23 +104,26 @@ def __init__( C_initializer: Union[Initializer, Callable, Tensor] = OneInit(2.4e-4), method: str = 'exp_auto', name: str = None, + trainable: bool = False, **channels ): super(CalciumDyna, self).__init__(size, keep_size=keep_size, method=method, name=name, + trainable=trainable, **channels) # parameters - self.C0 = init_param(C0, self.var_shape, allow_none=False) - self.T = init_param(T, self.var_shape, allow_none=False) # temperature + self.C0 = parameter(C0, self.varshape, allow_none=False) + self.T = parameter(T, self.varshape, allow_none=False) # temperature self._C_initializer = C_initializer self._constant = self.R / (2 * self.F) * (273.15 + self.T) # variables - self.C = bm.Variable(init_param(C_initializer, self.var_shape)) # Calcium concentration - self.E = bm.Variable(self._reversal_potential(self.C)) # Reversal potential + self.C = variable(C_initializer, trainable, self.varshape) # Calcium concentration + self.E = bm.Variable(self._reversal_potential(self.C), + batch_axis=0 if trainable else None) # Reversal potential # function self.integral = odeint(self.derivative, method=method) @@ -127,16 +131,16 @@ def __init__( def derivative(self, C, t, V): raise NotImplementedError - def reset(self, V, C_Ca=None, E_Ca=None): - self.C[:] = init_param(self._C_initializer, self.var_shape) if (C_Ca is None) else C_Ca + def reset_state(self, V, C_Ca=None, E_Ca=None, batch_size=None): + self.C.value = variable(self._C_initializer, batch_size, self.varshape) if (C_Ca is None) else C_Ca self.E.value = self._reversal_potential(self.C) for node in self.nodes(level=1, include_self=False).unique().subset(Channel).values(): - node.reset(V, self.C, self.E) + node.reset(V, self.C, self.E, batch_size=batch_size) - def update(self, t, dt, V): + def update(self, tdi, V): for node in self.nodes(level=1, include_self=False).unique().subset(Channel).values(): - node.update(t, dt, V, self.C, self.E) - self.C.value = self.integral(self.C.value, t, V, dt) + node.update(tdi, V, self.C, self.E) + self.C.value = self.integral(self.C.value, tdi['t'], V, tdi['dt']) self.E.value = self._reversal_potential(self.C) def _reversal_potential(self, C): @@ -266,6 +270,7 @@ def __init__( C_initializer: Union[Initializer, Callable, Tensor] = OneInit(2.4e-4), method: str = 'exp_auto', name: str = None, + trainable: bool = False, **channels ): super(CalciumDetailed, self).__init__(size, @@ -275,12 +280,13 @@ def __init__( T=T, C0=C0, C_initializer=C_initializer, + trainable=trainable, **channels) # parameters - self.d = init_param(d, self.var_shape, allow_none=False) - self.tau = init_param(tau, self.var_shape, allow_none=False) - self.C_rest = init_param(C_rest, self.var_shape, allow_none=False) + self.d = parameter(d, self.varshape, allow_none=False) + self.tau = parameter(tau, self.varshape, allow_none=False) + self.C_rest = parameter(C_rest, self.varshape, allow_none=False) def derivative(self, C, t, V): ICa = self.current(V, C, self.E) @@ -308,6 +314,7 @@ def __init__( C_initializer: Union[Initializer, Callable, Tensor] = OneInit(2.4e-4), method: str = 'exp_auto', name: str = None, + trainable: bool = False, **channels ): super(CalciumFirstOrder, self).__init__(size, @@ -317,11 +324,12 @@ def __init__( T=T, C0=C0, C_initializer=C_initializer, + trainable=trainable, **channels) # parameters - self.alpha = init_param(alpha, self.var_shape, allow_none=False) - self.beta = init_param(beta, self.var_shape, allow_none=False) + self.alpha = parameter(alpha, self.varshape, allow_none=False) + self.beta = parameter(beta, self.varshape, allow_none=False) def derivative(self, C, t, V): ICa = self.current(V, C, self.E) @@ -373,18 +381,22 @@ def __init__( phi_q: Union[float, Tensor, Initializer, Callable] = 3., g_max: Union[float, Tensor, Initializer, Callable] = 2., method: str = 'exp_auto', + trainable: bool = False, name: str = None ): - super(ICa_p2q_ss, self).__init__(size, keep_size=keep_size, name=name) + super(ICa_p2q_ss, self).__init__(size, + keep_size=keep_size, + name=name, + trainable=trainable, ) # parameters - self.phi_p = init_param(phi_p, self.var_shape, allow_none=False) - self.phi_q = init_param(phi_q, self.var_shape, allow_none=False) - self.g_max = init_param(g_max, self.var_shape, allow_none=False) + self.phi_p = parameter(phi_p, self.varshape, allow_none=False) + self.phi_q = parameter(phi_q, self.varshape, allow_none=False) + self.g_max = parameter(g_max, self.varshape, allow_none=False) # variables - self.p = bm.Variable(bm.zeros(self.var_shape)) - self.q = bm.Variable(bm.zeros(self.var_shape)) + self.p = variable(bm.zeros, trainable, self.varshape) + self.q = variable(bm.zeros, trainable, self.varshape) # functions self.integral = odeint(JointEq([self.dp, self.dq]), method=method) @@ -395,15 +407,18 @@ def dp(self, p, t, V): def dq(self, q, t, V): return self.phi_q * (self.f_q_inf(V) - q) / self.f_q_tau(V) - def update(self, t, dt, V, C_Ca, E_Ca): - self.p.value, self.q.value = self.integral(self.p, self.q, t, V, dt) + def update(self, tdi, V, C_Ca, E_Ca): + self.p.value, self.q.value = self.integral(self.p, self.q, tdi['t'], V, tdi['dt']) def current(self, V, C_Ca, E_Ca): return self.g_max * self.p * self.p * self.q * (E_Ca - V) - def reset(self, V, C_Ca, E_Ca): + def reset_state(self, V, C_Ca, E_Ca, batch_size=None): self.p.value = self.f_p_inf(V) self.q.value = self.f_q_inf(V) + if batch_size is not None: + assert self.p.shape[0] == batch_size + assert self.q.shape[0] == batch_size def f_p_inf(self, V): raise NotImplementedError @@ -459,18 +474,22 @@ def __init__( phi_q: Union[float, Tensor, Initializer, Callable] = 3., g_max: Union[float, Tensor, Initializer, Callable] = 2., method: str = 'exp_auto', - name: str = None + name: str = None, + trainable: bool = False, ): - super(ICa_p2q_markov, self).__init__(size, keep_size=keep_size, name=name) + super(ICa_p2q_markov, self).__init__(size, + keep_size=keep_size, + name=name, + trainable=trainable) # parameters - self.phi_p = init_param(phi_p, self.var_shape, allow_none=False) - self.phi_q = init_param(phi_q, self.var_shape, allow_none=False) - self.g_max = init_param(g_max, self.var_shape, allow_none=False) + self.phi_p = parameter(phi_p, self.varshape, allow_none=False) + self.phi_q = parameter(phi_q, self.varshape, allow_none=False) + self.g_max = parameter(g_max, self.varshape, allow_none=False) # variables - self.p = bm.Variable(bm.zeros(self.var_shape)) - self.q = bm.Variable(bm.zeros(self.var_shape)) + self.p = variable(bm.zeros, trainable, self.varshape) + self.q = variable(bm.zeros, trainable, self.varshape) # functions self.integral = odeint(JointEq([self.dp, self.dq]), method=method) @@ -481,17 +500,20 @@ def dp(self, p, t, V): def dq(self, q, t, V): return self.phi_q * (self.f_q_alpha(V) * (1 - q) - self.f_q_beta(V) * q) - def update(self, t, dt, V, C_Ca, E_Ca): - self.p.value, self.q.value = self.integral(self.p, self.q, t, V, dt) + def update(self, tdi, V, C_Ca, E_Ca): + self.p.value, self.q.value = self.integral(self.p, self.q, tdi['t'], V, tdi['dt']) def current(self, V, C_Ca, E_Ca): return self.g_max * self.p * self.p * self.q * (E_Ca - V) - def reset(self, V, C_Ca, E_Ca): + def reset_state(self, V, C_Ca, E_Ca, batch_size=None): alpha, beta = self.f_p_alpha(V), self.f_p_beta(V) self.p.value = alpha / (alpha + beta) alpha, beta = self.f_q_alpha(V), self.f_q_beta(V) self.q.value = alpha / (alpha + beta) + if batch_size is not None: + assert self.p.shape[0] == batch_size + assert self.q.shape[0] == batch_size def f_p_alpha(self, V): raise NotImplementedError @@ -554,17 +576,21 @@ def __init__( g_max: Union[float, Tensor, Initializer, Callable] = 1., phi: Union[float, Tensor, Initializer, Callable] = 1., method: str = 'exp_auto', - name: str = None + name: str = None, + trainable: bool = False, ): - super(ICaN_IS2008, self).__init__(size, keep_size=keep_size, name=name) + super(ICaN_IS2008, self).__init__(size, + keep_size=keep_size, + name=name, + trainable=trainable) # parameters - self.E = init_param(E, self.var_shape, allow_none=False) - self.g_max = init_param(g_max, self.var_shape, allow_none=False) - self.phi = init_param(phi, self.var_shape, allow_none=False) + self.E = parameter(E, self.varshape, allow_none=False) + self.g_max = parameter(g_max, self.varshape, allow_none=False) + self.phi = parameter(phi, self.varshape, allow_none=False) # variables - self.p = bm.Variable(bm.zeros(self.var_shape)) + self.p = variable(bm.zeros, trainable, self.varshape) # function self.integral = odeint(self.derivative, method=method) @@ -574,16 +600,18 @@ def derivative(self, p, t, V): p_inf = 2.7 / (bm.exp(-(V + 55.) / 15.) + bm.exp((V + 55.) / 15.)) + 1.6 return self.phi * (phi_p - p) / p_inf - def update(self, t, dt, V, C_Ca, E_Ca): - self.p.value = self.integral(self.p, t, V, dt) + def update(self, tdi, V, C_Ca, E_Ca): + self.p.value = self.integral(self.p, tdi['t'], V, tdi['dt']) def current(self, V, C_Ca, E_Ca): M = C_Ca / (C_Ca + 0.2) g = self.g_max * M * self.p return g * (self.E - V) - def reset(self, V, C_Ca, E_Ca): + def reset_state(self, V, C_Ca, E_Ca, batch_size=None): self.p.value = 1.0 / (1 + bm.exp(-(V + 43.) / 5.2)) + if batch_size is not None: + assert self.p.shape[0] == batch_size class ICaT_HM1992(ICa_p2q_ss): @@ -646,7 +674,8 @@ def __init__( phi_p: Union[float, Tensor, Initializer, Callable] = None, phi_q: Union[float, Tensor, Initializer, Callable] = None, method: str = 'exp_auto', - name: str = None + name: str = None, + trainable: bool = False, ): phi_p = T_base_p ** ((T - 24) / 10) if phi_p is None else phi_p phi_q = T_base_q ** ((T - 24) / 10) if phi_q is None else phi_q @@ -656,13 +685,14 @@ def __init__( method=method, g_max=g_max, phi_p=phi_p, - phi_q=phi_q) + phi_q=phi_q, + trainable=trainable) # parameters - self.T = init_param(T, self.var_shape, allow_none=False) - self.T_base_p = init_param(T_base_p, self.var_shape, allow_none=False) - self.T_base_q = init_param(T_base_q, self.var_shape, allow_none=False) - self.V_sh = init_param(V_sh, self.var_shape, allow_none=False) + self.T = parameter(T, self.varshape, allow_none=False) + self.T_base_p = parameter(T_base_p, self.varshape, allow_none=False) + self.T_base_q = parameter(T_base_q, self.varshape, allow_none=False) + self.V_sh = parameter(V_sh, self.varshape, allow_none=False) def f_p_inf(self, V): return 1. / (1 + bm.exp(-(V + 59. - self.V_sh) / 6.2)) @@ -741,8 +771,9 @@ def __init__( V_sh: Union[float, Tensor, Initializer, Callable] = -3., phi_p: Union[float, Tensor, Initializer, Callable] = None, phi_q: Union[float, Tensor, Initializer, Callable] = None, - method='exp_auto', - name=None + method: str = 'exp_auto', + name: str = None, + trainable: bool = False, ): phi_p = T_base_p ** ((T - 24) / 10) if phi_p is None else phi_p phi_q = T_base_q ** ((T - 24) / 10) if phi_q is None else phi_q @@ -752,13 +783,14 @@ def __init__( method=method, g_max=g_max, phi_p=phi_p, - phi_q=phi_q) + phi_q=phi_q, + trainable=trainable) # parameters - self.T = init_param(T, self.var_shape, allow_none=False) - self.T_base_p = init_param(T_base_p, self.var_shape, allow_none=False) - self.T_base_q = init_param(T_base_q, self.var_shape, allow_none=False) - self.V_sh = init_param(V_sh, self.var_shape, allow_none=False) + self.T = parameter(T, self.varshape, allow_none=False) + self.T_base_p = parameter(T_base_p, self.varshape, allow_none=False) + self.T_base_q = parameter(T_base_q, self.varshape, allow_none=False) + self.V_sh = parameter(V_sh, self.varshape, allow_none=False) def f_p_inf(self, V): return 1. / (1. + bm.exp(-(V + 52. - self.V_sh) / 7.4)) @@ -833,7 +865,8 @@ def __init__( g_max: Union[float, Tensor, Initializer, Callable] = 2., V_sh: Union[float, Tensor, Initializer, Callable] = 25., method: str = 'exp_auto', - name: str = None + name: str = None, + trainable: bool = False, ): super(ICaHT_HM1992, self).__init__(size, keep_size=keep_size, @@ -841,17 +874,18 @@ def __init__( method=method, g_max=g_max, phi_p=T_base_p ** ((T - 24) / 10), - phi_q=T_base_q ** ((T - 24) / 10)) + phi_q=T_base_q ** ((T - 24) / 10), + trainable=trainable) # parameters - self.T = init_param(T, self.var_shape, allow_none=False) - self.T_base_p = init_param(T_base_p, self.var_shape, allow_none=False) - self.T_base_q = init_param(T_base_q, self.var_shape, allow_none=False) - self.V_sh = init_param(V_sh, self.var_shape, allow_none=False) + self.T = parameter(T, self.varshape, allow_none=False) + self.T_base_p = parameter(T_base_p, self.varshape, allow_none=False) + self.T_base_q = parameter(T_base_q, self.varshape, allow_none=False) + self.V_sh = parameter(V_sh, self.varshape, allow_none=False) # variables - self.p = bm.Variable(bm.zeros(self.var_shape)) - self.q = bm.Variable(bm.zeros(self.var_shape)) + self.p = variable(bm.zeros, trainable, self.varshape) + self.q = variable(bm.zeros, trainable, self.varshape) # function self.integral = odeint(JointEq([self.dp, self.dq]), method=method) @@ -938,7 +972,8 @@ def __init__( g_max: Union[float, Tensor, Initializer, Callable] = 1., V_sh: Union[float, Tensor, Initializer, Callable] = 0., method: str = 'exp_auto', - name: str = None + name: str = None, + trainable: bool = False, ): phi_p = T_base_p ** ((T - 23.) / 10.) if phi_p is None else phi_p phi_q = T_base_q ** ((T - 23.) / 10.) if phi_q is None else phi_q @@ -948,11 +983,12 @@ def __init__( method=method, g_max=g_max, phi_p=phi_p, - phi_q=phi_q) - self.T = init_param(T, self.var_shape, allow_none=False) - self.T_base_p = init_param(T_base_p, self.var_shape, allow_none=False) - self.T_base_q = init_param(T_base_q, self.var_shape, allow_none=False) - self.V_sh = init_param(V_sh, self.var_shape, allow_none=False) + phi_q=phi_q, + trainable=trainable) + self.T = parameter(T, self.varshape, allow_none=False) + self.T_base_p = parameter(T_base_p, self.varshape, allow_none=False) + self.T_base_q = parameter(T_base_q, self.varshape, allow_none=False) + self.V_sh = parameter(V_sh, self.varshape, allow_none=False) def f_p_alpha(self, V): temp = -27 - V + self.V_sh @@ -1023,7 +1059,8 @@ def __init__( g_max: Union[float, Tensor, Initializer, Callable] = 1., V_sh: Union[float, Tensor, Initializer, Callable] = 0., method: str = 'exp_auto', - name: str = None + name: str = None, + trainable: bool = False, ): super(ICaL_IS2008, self).__init__(size, keep_size=keep_size, @@ -1031,13 +1068,14 @@ def __init__( method=method, g_max=g_max, phi_p=T_base_p ** ((T - 24) / 10), - phi_q=T_base_q ** ((T - 24) / 10)) + phi_q=T_base_q ** ((T - 24) / 10), + trainable=trainable) # parameters - self.T = init_param(T, self.var_shape, allow_none=False) - self.T_base_p = init_param(T_base_p, self.var_shape, allow_none=False) - self.T_base_q = init_param(T_base_q, self.var_shape, allow_none=False) - self.V_sh = init_param(V_sh, self.var_shape, allow_none=False) + self.T = parameter(T, self.varshape, allow_none=False) + self.T_base_p = parameter(T_base_p, self.varshape, allow_none=False) + self.T_base_q = parameter(T_base_q, self.varshape, allow_none=False) + self.V_sh = parameter(V_sh, self.varshape, allow_none=False) def f_p_inf(self, V): return 1. / (1 + bm.exp(-(V + 10. - self.V_sh) / 4.)) diff --git a/brainpy/dyn/channels/IH.py b/brainpy/dyn/channels/IH.py index 9e4d60936..459e69cf1 100644 --- a/brainpy/dyn/channels/IH.py +++ b/brainpy/dyn/channels/IH.py @@ -9,7 +9,7 @@ from typing import Union, Callable import brainpy.math as bm -from brainpy.initialize import Initializer, init_param +from brainpy.initialize import Initializer, parameter, variable from brainpy.integrators import odeint, JointEq from brainpy.types import Shape, Tensor from .base import IhChannel, CalciumChannel, Calcium @@ -62,17 +62,21 @@ def __init__( E: Union[float, Tensor, Initializer, Callable] = 43., phi: Union[float, Tensor, Initializer, Callable] = 1., method: str = 'exp_auto', - name: str = None + name: str = None, + trainable: bool = False, ): - super(Ih_HM1992, self).__init__(size, keep_size=keep_size, name=name) + super(Ih_HM1992, self).__init__(size, + keep_size=keep_size, + name=name, + trainable=trainable) # parameters - self.phi = init_param(phi, self.var_shape, allow_none=False) - self.g_max = init_param(g_max, self.var_shape, allow_none=False) - self.E = init_param(E, self.var_shape, allow_none=False) + self.phi = parameter(phi, self.varshape, allow_none=False) + self.g_max = parameter(g_max, self.varshape, allow_none=False) + self.E = parameter(E, self.varshape, allow_none=False) # variable - self.p = bm.Variable(bm.zeros(self.var_shape)) + self.p = variable(bm.zeros, trainable, self.varshape) # function self.integral = odeint(self.derivative, method=method) @@ -80,11 +84,13 @@ def __init__( def derivative(self, p, t, V): return self.phi * (self.f_p_inf(V) - p) / self.f_p_tau(V) - def reset(self, V): + def reset_state(self, V, batch_size=None): self.p.value = self.f_p_inf(V) + if batch_size is not None: + assert self.p.shape[0] == batch_size - def update(self, t, dt, V): - self.p.value = self.integral(self.p.value, t, V, dt) + def update(self, tdi, V): + self.p.value = self.integral(self.p.value, tdi['t'], V, tdi['dt']) def current(self, V): return self.g_max * self.p * (self.E - V) @@ -166,32 +172,37 @@ def __init__( T_base: Union[float, Tensor] = 3., phi: Union[float, Tensor, Initializer, Callable] = None, method: str = 'exp_auto', - name: str = None + name: str = None, + trainable: bool = False, ): # IhChannel.__init__(self, size, name=name, keep_size=keep_size) - CalciumChannel.__init__(self, size, keep_size=keep_size, name=name) + CalciumChannel.__init__(self, + size, + keep_size=keep_size, + name=name, + trainable=trainable) # parameters - self.T = init_param(T, self.var_shape, allow_none=False) - self.T_base = init_param(T_base, self.var_shape, allow_none=False) + self.T = parameter(T, self.varshape, allow_none=False) + self.T_base = parameter(T_base, self.varshape, allow_none=False) if phi is None: self.phi = self.T_base ** ((self.T - 24.) / 10) else: - self.phi = init_param(phi, self.var_shape, allow_none=False) - self.E = init_param(E, self.var_shape, allow_none=False) - self.k2 = init_param(k2, self.var_shape, allow_none=False) - self.Ca_half = init_param(Ca_half, self.var_shape, allow_none=False) + self.phi = parameter(phi, self.varshape, allow_none=False) + self.E = parameter(E, self.varshape, allow_none=False) + self.k2 = parameter(k2, self.varshape, allow_none=False) + self.Ca_half = parameter(Ca_half, self.varshape, allow_none=False) self.k1 = self.k2 / self.Ca_half ** 4 - self.k4 = init_param(k4, self.var_shape, allow_none=False) + self.k4 = parameter(k4, self.varshape, allow_none=False) self.k3 = self.k4 / 0.01 - self.V_sh = init_param(V_sh, self.var_shape, allow_none=False) - self.g_max = init_param(g_max, self.var_shape, allow_none=False) - self.g_inc = init_param(g_inc, self.var_shape, allow_none=False) + self.V_sh = parameter(V_sh, self.varshape, allow_none=False) + self.g_max = parameter(g_max, self.varshape, allow_none=False) + self.g_inc = parameter(g_inc, self.varshape, allow_none=False) # variable - self.O = bm.Variable(bm.zeros(self.var_shape)) - self.OL = bm.Variable(bm.zeros(self.var_shape)) - self.P1 = bm.Variable(bm.zeros(self.var_shape)) + self.O = variable(bm.zeros, trainable, self.varshape) + self.OL = variable(bm.zeros, trainable, self.varshape) + self.P1 = variable(bm.zeros, trainable, self.varshape) # function self.integral = odeint(JointEq(self.dO, self.dOL, self.dP1), method=method) @@ -209,20 +220,26 @@ def dOL(self, OL, t, O, P1): def dP1(self, P1, t, C_Ca): return self.k1 * C_Ca ** 4 * (1 - P1) - self.k2 * P1 - def update(self, t, dt, V, C_Ca, E_Ca): - self.O.value = self.integral(self.O.value, self.OL.value, self.P1.value, t, V=V, C_Ca=C_Ca, dt=dt) + def update(self, tdi, V, C_Ca, E_Ca): + self.O.value = self.integral(self.O.value, self.OL.value, self.P1.value, + tdi['t'], V=V, C_Ca=C_Ca, dt=tdi['dt']) def current(self, V, C_Ca, E_Ca): return self.g_max * (self.O + self.g_inc * self.OL) * (self.E - V) - def reset(self, V, C_Ca, E_Ca): - self.P1[:] = self.k1 * C_Ca ** 4 / (self.k1 * C_Ca ** 4 + self.k2) + def reset_state(self, V, C_Ca, E_Ca, batch_size=None): + varshape = self.varshape if (batch_size is None) else ((batch_size,) + self.varshape) + self.P1.value = bm.broadcast_to(self.k1 * C_Ca ** 4 / (self.k1 * C_Ca ** 4 + self.k2), varshape) inf = self.f_inf(V) tau = self.f_tau(V) alpha = inf / tau beta = (1 - inf) / tau self.O.value = alpha / (alpha + alpha * self.k3 * self.P1 / self.k4 + beta) self.OL.value = self.k3 * self.P1 * self.O / self.k4 + if batch_size is not None: + assert self.P1.shape[0] == batch_size + assert self.O.shape[0] == batch_size + assert self.OL.shape[0] == batch_size def f_inf(self, V): return 1 / (1 + bm.exp((V + 75 - self.V_sh) / 5.5)) diff --git a/brainpy/dyn/channels/K.py b/brainpy/dyn/channels/K.py index 9c66e6c41..f42e21f4d 100644 --- a/brainpy/dyn/channels/K.py +++ b/brainpy/dyn/channels/K.py @@ -8,7 +8,7 @@ from typing import Union, Callable, Optional import brainpy.math as bm -from brainpy.initialize import Initializer, init_param +from brainpy.initialize import Initializer, parameter, variable from brainpy.integrators import odeint, JointEq from brainpy.types import Shape, Tensor from .base import PotassiumChannel @@ -74,16 +74,20 @@ def __init__( g_max: Union[float, Tensor, Initializer, Callable] = 10., phi: Union[float, Tensor, Initializer, Callable] = 1., method: str = 'exp_auto', - name: str = None + name: str = None, + trainable: bool = False, ): - super(IK_p4_markov, self).__init__(size, keep_size=keep_size, name=name) + super(IK_p4_markov, self).__init__(size, + keep_size=keep_size, + name=name, + trainable=trainable) - self.E = init_param(E, self.var_shape, allow_none=False) - self.g_max = init_param(g_max, self.var_shape, allow_none=False) - self.phi = init_param(phi, self.var_shape, allow_none=False) + self.E = parameter(E, self.varshape, allow_none=False) + self.g_max = parameter(g_max, self.varshape, allow_none=False) + self.phi = parameter(phi, self.varshape, allow_none=False) # variables - self.p = bm.Variable(bm.zeros(self.var_shape)) + self.p = variable(bm.zeros, trainable, self.varshape) # function self.integral = odeint(self.derivative, method=method) @@ -91,16 +95,18 @@ def __init__( def derivative(self, p, t, V): return self.phi * (self.f_p_alpha(V) * (1. - p) - self.f_p_beta(V) * p) - def update(self, t, dt, V): - self.p.value = self.integral(self.p, t, V, dt=dt) + def update(self, tdi, V): + self.p.value = self.integral(self.p, tdi['t'], V, tdi['dt']) def current(self, V): return self.g_max * self.p ** 4 * (self.E - V) - def reset(self, V): + def reset_state(self, V, batch_size=None): alpha = self.f_p_alpha(V) beta = self.f_p_beta(V) self.p.value = alpha / (alpha + beta) + if batch_size is not None: + assert self.p.shape[0] == batch_size def f_p_alpha(self, V): raise NotImplementedError @@ -167,7 +173,8 @@ def __init__( T: Union[float, Tensor] = 36., phi: Optional[Union[float, Tensor, Initializer, Callable]] = None, method: str = 'exp_auto', - name: str = None + name: str = None, + trainable: bool = False, ): phi = T_base ** ((T - 36) / 10) if phi is None else phi super(IKDR_Ba2002, self).__init__(size, @@ -176,12 +183,13 @@ def __init__( method=method, g_max=g_max, phi=phi, - E=E) + E=E, + trainable=trainable) # parameters - self.T = init_param(T, self.var_shape, allow_none=False) - self.T_base = init_param(T_base, self.var_shape, allow_none=False) - self.V_sh = init_param(V_sh, self.var_shape, allow_none=False) + self.T = parameter(T, self.varshape, allow_none=False) + self.T_base = parameter(T_base, self.varshape, allow_none=False) + self.V_sh = parameter(V_sh, self.varshape, allow_none=False) def f_p_alpha(self, V): tmp = V - self.V_sh - 15. @@ -240,7 +248,8 @@ def __init__( phi: Union[float, Tensor, Initializer, Callable] = 1., V_sh: Union[int, float, Tensor, Initializer, Callable] = -60., method: str = 'exp_auto', - name: str = None + name: str = None, + trainable: bool = False, ): super(IK_TM1991, self).__init__(size, keep_size=keep_size, @@ -248,8 +257,9 @@ def __init__( method=method, phi=phi, E=E, - g_max=g_max) - self.V_sh = init_param(V_sh, self.var_shape, allow_none=False) + g_max=g_max, + trainable=trainable) + self.V_sh = parameter(V_sh, self.varshape, allow_none=False) def f_p_alpha(self, V): c = 15 - V + self.V_sh @@ -309,7 +319,8 @@ def __init__( phi: Union[float, Tensor, Initializer, Callable] = 1., V_sh: Union[int, float, Tensor, Initializer, Callable] = -45., method: str = 'exp_auto', - name: str = None + name: str = None, + trainable: bool = False, ): super(IK_HH, self).__init__(size, keep_size=keep_size, @@ -317,8 +328,9 @@ def __init__( method=method, phi=phi, E=E, - g_max=g_max) - self.V_sh = init_param(V_sh, self.var_shape, allow_none=False) + g_max=g_max, + trainable=trainable) + self.V_sh = parameter(V_sh, self.varshape, allow_none=False) def f_p_alpha(self, V): temp = V - self.V_sh + 10 @@ -379,19 +391,23 @@ def __init__( phi_p: Union[float, Tensor, Initializer, Callable] = 1., phi_q: Union[float, Tensor, Initializer, Callable] = 1., method: str = 'exp_auto', - name: str = None + name: str = None, + trainable: bool = False, ): - super(IKA_p4q_ss, self).__init__(size, keep_size=keep_size, name=name) + super(IKA_p4q_ss, self).__init__(size, + keep_size=keep_size, + name=name, + trainable=trainable) # parameters - self.E = init_param(E, self.var_shape, allow_none=False) - self.g_max = init_param(g_max, self.var_shape, allow_none=False) - self.phi_p = init_param(phi_p, self.var_shape, allow_none=False) - self.phi_q = init_param(phi_q, self.var_shape, allow_none=False) + self.E = parameter(E, self.varshape, allow_none=False) + self.g_max = parameter(g_max, self.varshape, allow_none=False) + self.phi_p = parameter(phi_p, self.varshape, allow_none=False) + self.phi_q = parameter(phi_q, self.varshape, allow_none=False) # variables - self.p = bm.Variable(bm.zeros(self.var_shape)) - self.q = bm.Variable(bm.zeros(self.var_shape)) + self.p = variable(bm.zeros, trainable, self.varshape) + self.q = variable(bm.zeros, trainable, self.varshape) # function self.integral = odeint(JointEq(self.dp, self.dq), method=method) @@ -402,15 +418,19 @@ def dp(self, p, t, V): def dq(self, q, t, V): return self.phi_q * (self.f_q_inf(V) - q) / self.f_q_tau(V) - def update(self, t, dt, V): + def update(self, tdi, V): + t, dt = tdi['t'], tdi['dt'] self.p.value, self.q.value = self.integral(self.p.value, self.q.value, t, V, dt) def current(self, V): return self.g_max * self.p ** 4 * self.q * (self.E - V) - def reset(self, V): + def reset_state(self, V, batch_size=None): self.p.value = self.f_p_inf(V) self.q.value = self.f_q_inf(V) + if batch_size is not None: + assert self.p.shape[0] == batch_size + assert self.q.shape[0] == batch_size def f_p_inf(self, V): raise NotImplementedError @@ -487,7 +507,8 @@ def __init__( phi_p: Union[float, Tensor, Initializer, Callable] = 1., phi_q: Union[float, Tensor, Initializer, Callable] = 1., method: str = 'exp_auto', - name: str = None + name: str = None, + trainable: bool = False, ): super(IKA1_HM1992, self).__init__(size, keep_size=keep_size, @@ -496,10 +517,11 @@ def __init__( E=E, g_max=g_max, phi_p=phi_p, - phi_q=phi_q) + phi_q=phi_q, + trainable=trainable) # parameters - self.V_sh = init_param(V_sh, self.var_shape, allow_none=False) + self.V_sh = parameter(V_sh, self.varshape, allow_none=False) def f_p_inf(self, V): return 1. / (1. + bm.exp(-(V - self.V_sh + 60.) / 8.5)) @@ -580,7 +602,8 @@ def __init__( phi_p: Union[float, Tensor, Initializer, Callable] = 1., phi_q: Union[float, Tensor, Initializer, Callable] = 1., method: str = 'exp_auto', - name: str = None + name: str = None, + trainable: bool = False, ): super(IKA2_HM1992, self).__init__(size, keep_size=keep_size, @@ -589,10 +612,11 @@ def __init__( E=E, g_max=g_max, phi_q=phi_q, - phi_p=phi_p) + phi_p=phi_p, + trainable=trainable) # parameters - self.V_sh = init_param(V_sh, self.var_shape, allow_none=False) + self.V_sh = parameter(V_sh, self.varshape, allow_none=False) def f_p_inf(self, V): return 1. / (1. + bm.exp(-(V - self.V_sh + 36.) / 20.)) @@ -662,19 +686,23 @@ def __init__( phi_p: Union[float, Tensor, Initializer, Callable] = 1., phi_q: Union[float, Tensor, Initializer, Callable] = 1., method: str = 'exp_auto', - name: str = None + name: str = None, + trainable: bool = False, ): - super(IKK2_pq_ss, self).__init__(size, keep_size=keep_size, name=name) + super(IKK2_pq_ss, self).__init__(size, + keep_size=keep_size, + name=name, + trainable=trainable) # parameters - self.E = init_param(E, self.var_shape, allow_none=False) - self.g_max = init_param(g_max, self.var_shape, allow_none=False) - self.phi_p = init_param(phi_p, self.var_shape, allow_none=False) - self.phi_q = init_param(phi_q, self.var_shape, allow_none=False) + self.E = parameter(E, self.varshape, allow_none=False) + self.g_max = parameter(g_max, self.varshape, allow_none=False) + self.phi_p = parameter(phi_p, self.varshape, allow_none=False) + self.phi_q = parameter(phi_q, self.varshape, allow_none=False) # variables - self.p = bm.Variable(bm.zeros(self.var_shape)) - self.q = bm.Variable(bm.zeros(self.var_shape)) + self.p = variable(bm.zeros, trainable, self.varshape) + self.q = variable(bm.zeros, trainable, self.varshape) # function self.integral = odeint(JointEq(self.dp, self.dq), method=method) @@ -685,15 +713,19 @@ def dp(self, p, t, V): def dq(self, q, t, V): return self.phi_q * (self.f_q_inf(V) - q) / self.f_q_tau(V) - def update(self, t, dt, V): + def update(self, tdi, V): + t, dt = tdi['t'], tdi['dt'] self.p.value, self.q.value = self.integral(self.p.value, self.q.value, t, V, dt) def current(self, V): return self.g_max * self.p * self.q * (self.E - V) - def reset(self, V): + def reset_state(self, V, batch_size=None): self.p.value = self.f_p_inf(V) self.q.value = self.f_q_inf(V) + if batch_size is not None: + assert self.p.shape[0] == batch_size + assert self.q.shape[0] == batch_size def f_p_inf(self, V): raise NotImplementedError @@ -766,7 +798,8 @@ def __init__( phi_p: Union[float, Tensor, Initializer, Callable] = 1., phi_q: Union[float, Tensor, Initializer, Callable] = 1., method: str = 'exp_auto', - name: str = None + name: str = None, + trainable: bool = False, ): super(IKK2A_HM1992, self).__init__(size, keep_size=keep_size, @@ -775,10 +808,11 @@ def __init__( phi_p=phi_p, phi_q=phi_q, g_max=g_max, - E=E) + E=E, + trainable=trainable) # parameters - self.V_sh = init_param(V_sh, self.var_shape, allow_none=False) + self.V_sh = parameter(V_sh, self.varshape, allow_none=False) def f_p_inf(self, V): raise 1. / (1. + bm.exp(-(V - self.V_sh + 43.) / 17.)) @@ -855,7 +889,8 @@ def __init__( phi_p: Union[float, Tensor, Initializer, Callable] = 1., phi_q: Union[float, Tensor, Initializer, Callable] = 1., method: str = 'exp_auto', - name: str = None + name: str = None, + trainable: bool = False, ): super(IKK2B_HM1992, self).__init__(size, keep_size=keep_size, @@ -864,10 +899,11 @@ def __init__( phi_p=phi_p, phi_q=phi_q, g_max=g_max, - E=E) + E=E, + trainable=trainable) # parameters - self.V_sh = init_param(V_sh, self.var_shape, allow_none=False) + self.V_sh = parameter(V_sh, self.varshape, allow_none=False) def f_p_inf(self, V): raise 1. / (1. + bm.exp(-(V - self.V_sh + 43.) / 17.)) @@ -939,20 +975,24 @@ def __init__( tau_max: Union[float, Tensor, Initializer, Callable] = 4e3, V_sh: Union[float, Tensor, Initializer, Callable] = 0., method: str = 'exp_auto', - name: str = None + name: str = None, + trainable: bool = False, ): - super(IKNI_Ya1989, self).__init__(size, keep_size=keep_size, name=name) + super(IKNI_Ya1989, self).__init__(size, + keep_size=keep_size, + name=name, + trainable=trainable) # parameters - self.E = init_param(E, self.var_shape, allow_none=False) - self.g_max = init_param(g_max, self.var_shape, allow_none=False) - self.tau_max = init_param(tau_max, self.var_shape, allow_none=False) - self.V_sh = init_param(V_sh, self.var_shape, allow_none=False) - self.phi_p = init_param(phi_p, self.var_shape, allow_none=False) - self.phi_q = init_param(phi_q, self.var_shape, allow_none=False) + self.E = parameter(E, self.varshape, allow_none=False) + self.g_max = parameter(g_max, self.varshape, allow_none=False) + self.tau_max = parameter(tau_max, self.varshape, allow_none=False) + self.V_sh = parameter(V_sh, self.varshape, allow_none=False) + self.phi_p = parameter(phi_p, self.varshape, allow_none=False) + self.phi_q = parameter(phi_q, self.varshape, allow_none=False) # variables - self.p = bm.Variable(bm.zeros(self.var_shape)) + self.p = variable(bm.zeros, trainable, self.varshape) # function self.integral = odeint(self.dp, method=method) @@ -960,14 +1000,18 @@ def __init__( def dp(self, p, t, V): return self.phi_p * (self.f_p_inf(V) - p) / self.f_p_tau(V) - def update(self, t, dt, V): + def update(self, tdi, V): + t, dt = tdi['t'], tdi['dt'] self.p.value = self.integral(self.p.value, t, V, dt) def current(self, V): return self.g_max * self.p * (self.E - V) - def reset(self, V): + def reset_state(self, V, batch_size=None): self.p.value = self.f_p_inf(V) + if batch_size is not None: + assert self.p.shape[0] == batch_size + def f_p_inf(self, V): raise 1. / (1. + bm.exp(-(V - self.V_sh + 35.) / 10.)) diff --git a/brainpy/dyn/channels/KCa.py b/brainpy/dyn/channels/KCa.py index 2aa9a745f..a94a5b4b1 100644 --- a/brainpy/dyn/channels/KCa.py +++ b/brainpy/dyn/channels/KCa.py @@ -9,7 +9,7 @@ from typing import Union, Callable import brainpy.math as bm -from brainpy.initialize import Initializer, init_param +from brainpy.initialize import Initializer, parameter, variable from brainpy.integrators.ode import odeint from brainpy.types import Shape, Tensor from .base import Calcium, CalciumChannel, PotassiumChannel @@ -81,20 +81,25 @@ def __init__( beta: Union[float, Tensor, Initializer, Callable] = 0.09, phi: Union[float, Tensor, Initializer, Callable] = 1., method: str = 'exp_auto', - name: str = None + name: str = None, + trainable: bool = False, ): - CalciumChannel.__init__(self, size, keep_size=keep_size, name=name) + CalciumChannel.__init__(self, + size=size, + keep_size=keep_size, + name=name, + trainable=trainable) # parameters - self.E = init_param(E, self.var_shape, allow_none=False) - self.g_max = init_param(g_max, self.var_shape, allow_none=False) - self.n = init_param(n, self.var_shape, allow_none=False) - self.alpha = init_param(alpha, self.var_shape, allow_none=False) - self.beta = init_param(beta, self.var_shape, allow_none=False) - self.phi = init_param(phi, self.var_shape, allow_none=False) + self.E = parameter(E, self.varshape, allow_none=False) + self.g_max = parameter(g_max, self.varshape, allow_none=False) + self.n = parameter(n, self.varshape, allow_none=False) + self.alpha = parameter(alpha, self.varshape, allow_none=False) + self.beta = parameter(beta, self.varshape, allow_none=False) + self.phi = parameter(phi, self.varshape, allow_none=False) # variables - self.p = bm.Variable(bm.zeros(self.var_shape)) + self.p = variable(bm.zeros, trainable, self.varshape) # function self.integral = odeint(self.dp, method=method) @@ -104,13 +109,18 @@ def dp(self, p, t, C_Ca): C3 = C2 + self.beta return self.phi * (C2 / C3 - p) * C3 - def update(self, t, dt, V, C_Ca, E_Ca): + def update(self, tdi, V, C_Ca, E_Ca): + t, dt = tdi['t'], tdi['dt'] self.p.value = self.integral(self.p, t, C_Ca=C_Ca, dt=dt) def current(self, V, C_Ca, E_Ca): return self.g_max * self.p * self.p * (self.E - V) - def reset(self, V, C_Ca, E_Ca): + def reset_state(self, V, C_Ca, E_Ca, batch_size=None): C2 = self.alpha * bm.power(C_Ca, self.n) C3 = C2 + self.beta - self.p.value = C2 / C3 + if batch_size is None: + self.p.value = bm.broadcast_to(C2 / C3, self.varshape) + else: + self.p.value = bm.broadcast_to(C2 / C3, (batch_size,) + self.varshape) + assert self.p.shape[0] == batch_size diff --git a/brainpy/dyn/channels/Na.py b/brainpy/dyn/channels/Na.py index dec58de2c..20e8fa877 100644 --- a/brainpy/dyn/channels/Na.py +++ b/brainpy/dyn/channels/Na.py @@ -8,7 +8,7 @@ from typing import Union, Callable import brainpy.math as bm -from brainpy.initialize import Initializer, init_param +from brainpy.initialize import Initializer, parameter, variable from brainpy.integrators import odeint, JointEq from brainpy.types import Tensor, Shape from .base import SodiumChannel @@ -60,28 +60,35 @@ def __init__( phi: Union[int, float, Tensor, Initializer, Callable] = 1., method: str = 'exp_auto', name: str = None, + trainable: bool = False, ): - super(INa_p3q_markov, self).__init__(size, keep_size=keep_size, name=name) + super(INa_p3q_markov, self).__init__(size=size, + keep_size=keep_size, + name=name, + trainable=trainable) # parameters - self.E = init_param(E, self.var_shape, allow_none=False) - self.phi = init_param(phi, self.var_shape, allow_none=False) - self.g_max = init_param(g_max, self.var_shape, allow_none=False) + self.E = parameter(E, self.varshape, allow_none=False) + self.phi = parameter(phi, self.varshape, allow_none=False) + self.g_max = parameter(g_max, self.varshape, allow_none=False) # variables - self.p = bm.Variable(bm.zeros(self.var_shape)) - self.q = bm.Variable(bm.zeros(self.var_shape)) + self.p = variable(bm.zeros, trainable, self.varshape) + self.q = variable(bm.zeros, trainable, self.varshape) # function self.integral = odeint(JointEq([self.dp, self.dq]), method=method) - def reset(self, V): + def reset_state(self, V, batch_size=None): alpha = self.f_p_alpha(V) beta = self.f_p_beta(V) self.p.value = alpha / (alpha + beta) alpha = self.f_q_alpha(V) beta = self.f_q_beta(V) self.q.value = alpha / (alpha + beta) + if batch_size is not None: + assert self.p.shape[0] == batch_size + assert self.q.shape[0] == batch_size def dp(self, p, t, V): return self.phi * (self.f_p_alpha(V) * (1. - p) - self.f_p_beta(V) * p) @@ -89,7 +96,8 @@ def dp(self, p, t, V): def dq(self, q, t, V): return self.phi * (self.f_q_alpha(V) * (1. - q) - self.f_q_beta(V) * q) - def update(self, t, dt, V): + def update(self, tdi, V): + t, dt = tdi['t'], tdi['dt'] p, q = self.integral(self.p, self.q, t, V, dt) self.p.value, self.q.value = p, q @@ -161,7 +169,8 @@ def __init__( g_max: Union[int, float, Tensor, Initializer, Callable] = 90., V_sh: Union[int, float, Tensor, Initializer, Callable] = -50., method: str = 'exp_auto', - name: str = None + name: str = None, + trainable: bool = False, ): super(INa_Ba2002, self).__init__(size, keep_size=keep_size, @@ -169,9 +178,10 @@ def __init__( method=method, phi=3 ** ((T - 36) / 10), g_max=g_max, - E=E) - self.T = init_param(T, self.var_shape, allow_none=False) - self.V_sh = init_param(V_sh, self.var_shape, allow_none=False) + E=E, + trainable=trainable) + self.T = parameter(T, self.varshape, allow_none=False) + self.V_sh = parameter(V_sh, self.varshape, allow_none=False) def f_p_alpha(self, V): temp = V - self.V_sh - 13. @@ -246,7 +256,8 @@ def __init__( phi: Union[int, float, Tensor, Initializer, Callable] = 1., V_sh: Union[int, float, Tensor, Initializer, Callable] = -63., method: str = 'exp_auto', - name: str = None + name: str = None, + trainable: bool = False, ): super(INa_TM1991, self).__init__(size, keep_size=keep_size, @@ -254,8 +265,9 @@ def __init__( method=method, E=E, phi=phi, - g_max=g_max) - self.V_sh = init_param(V_sh, self.var_shape, allow_none=False) + g_max=g_max, + trainable=trainable) + self.V_sh = parameter(V_sh, self.varshape, allow_none=False) def f_p_alpha(self, V): temp = 13 - V + self.V_sh @@ -331,7 +343,8 @@ def __init__( phi: Union[int, float, Tensor, Initializer, Callable] = 1., V_sh: Union[int, float, Tensor, Initializer, Callable] = -45., method: str = 'exp_auto', - name: str = None + name: str = None, + trainable: bool = False, ): super(INa_HH, self).__init__(size, keep_size=keep_size, @@ -339,8 +352,9 @@ def __init__( method=method, E=E, phi=phi, - g_max=g_max) - self.V_sh = init_param(V_sh, self.var_shape, allow_none=False) + g_max=g_max, + trainable=trainable) + self.V_sh = parameter(V_sh, self.varshape, allow_none=False) def f_p_alpha(self, V): temp = V - self.V_sh - 5 diff --git a/brainpy/dyn/channels/base.py b/brainpy/dyn/channels/base.py index be673fc96..c13614db1 100644 --- a/brainpy/dyn/channels/base.py +++ b/brainpy/dyn/channels/base.py @@ -23,10 +23,13 @@ class Ion(Channel): '''The type of the master object.''' master_type = CondNeuGroup - def update(self, t, dt, V): + def update(self, tdi, V): raise NotImplementedError('Must be implemented by the subclass.') - def reset(self, V): + def reset(self, V, batch_size=None): + self.reset_state(V, batch_size) + + def reset_state(self, V, batch_size=None): raise NotImplementedError('Must be implemented by the subclass.') def current(self, V): @@ -42,13 +45,16 @@ class IonChannel(Channel): '''The type of the master object.''' master_type = CondNeuGroup - def update(self, t, dt, V): + def update(self, tdi, V): raise NotImplementedError('Must be implemented by the subclass.') def current(self, V): raise NotImplementedError('Must be implemented by the subclass.') - def reset(self, V): + def reset(self, V, batch_size=None): + self.reset_state(V, batch_size) + + def reset_state(self, V, batch_size=None): raise NotImplementedError('Must be implemented by the subclass.') def __repr__(self): @@ -85,20 +91,24 @@ def __init__( keep_size: bool = False, method: str = 'exp_auto', name: str = None, + trainable: bool = False, **channels ): - Ion.__init__(self, size, keep_size=keep_size) - Container.__init__(self, name=name, **channels) + Ion.__init__(self, size, keep_size=keep_size, trainable=trainable) + Container.__init__(self, name=name, trainable=trainable, **channels) self.method = method def current(self, V, C_Ca=None, E_Ca=None): C_Ca = self.C if (C_Ca is None) else C_Ca E_Ca = self.E if (E_Ca is None) else E_Ca nodes = list(self.nodes(level=1, include_self=False).unique().subset(Channel).values()) - current = nodes[0].current(V, C_Ca, E_Ca) - for node in nodes[1:]: - current += node.current(V, C_Ca, E_Ca) - return current + if len(nodes) == 0: + return 0. + else: + current = nodes[0].current(V, C_Ca, E_Ca) + for node in nodes[1:]: + current += node.current(V, C_Ca, E_Ca) + return current def register_implicit_nodes(self, *channels, **named_channels): check_master(type(self), *channels, **named_channels) @@ -111,14 +121,17 @@ class CalciumChannel(IonChannel): '''The type of the master object.''' master_type = Calcium - def update(self, t, dt, V, C_Ca, E_Ca): + def update(self, tdi, V, C_Ca, E_Ca): raise NotImplementedError def current(self, V, C_Ca, E_Ca): raise NotImplementedError - def reset(self, V, C_Ca, E_Ca): - raise NotImplementedError + def reset(self, V, C_Ca, E_Ca, batch_size=None): + self.reset_state(V, C_Ca, E_Ca, batch_size) + + def reset_state(self, V, C_Ca, E_Ca, batch_size=None): + raise NotImplementedError('Must be implemented by the subclass.') class IhChannel(IonChannel): diff --git a/brainpy/dyn/channels/leaky.py b/brainpy/dyn/channels/leaky.py index e009e7b4e..7c979903c 100644 --- a/brainpy/dyn/channels/leaky.py +++ b/brainpy/dyn/channels/leaky.py @@ -7,7 +7,7 @@ from typing import Union, Callable -from brainpy.initialize import Initializer, init_param +from brainpy.initialize import Initializer, parameter from brainpy.types import Tensor, Shape from .base import LeakyChannel @@ -36,17 +36,21 @@ def __init__( E: Union[int, float, Tensor, Initializer, Callable] = -70., method: str = None, name: str = None, + trainable: bool = False, ): - super(IL, self).__init__(size, keep_size=keep_size, name=name) + super(IL, self).__init__(size, + keep_size=keep_size, + name=name, + trainable=trainable) - self.E = init_param(E, self.var_shape, allow_none=False) - self.g_max = init_param(g_max, self.var_shape, allow_none=False) + self.E = parameter(E, self.varshape, allow_none=False) + self.g_max = parameter(g_max, self.varshape, allow_none=False) self.method = method - def reset(self, V): + def reset_state(self, V, batch_size=None): pass - def update(self, t, dt, V): + def update(self, tdi, V): pass def current(self, V): @@ -71,7 +75,14 @@ def __init__( keep_size: bool = False, g_max: Union[int, float, Tensor, Initializer, Callable] = 0.005, E: Union[int, float, Tensor, Initializer, Callable] = -90., - method=None, - name=None, + method: str = None, + name: str = None, + trainable: bool = False, ): - super(IKL, self).__init__(size=size, keep_size=keep_size, g_max=g_max, E=E, method=method, name=name) + super(IKL, self).__init__(size=size, + keep_size=keep_size, + g_max=g_max, + E=E, + method=method, + name=name, + trainable=trainable) diff --git a/brainpy/train/layers/__init__.py b/brainpy/dyn/layers/__init__.py similarity index 84% rename from brainpy/train/layers/__init__.py rename to brainpy/dyn/layers/__init__.py index 9bdc431b8..40e970878 100644 --- a/brainpy/train/layers/__init__.py +++ b/brainpy/dyn/layers/__init__.py @@ -4,7 +4,7 @@ from .linear import * from .nvar import * from .reservoir import * -from .recurrents import * +from .rnncells import * from .conv import * diff --git a/brainpy/train/layers/conv.py b/brainpy/dyn/layers/conv.py similarity index 62% rename from brainpy/train/layers/conv.py rename to brainpy/dyn/layers/conv.py index 31e52f145..97d05d1ea 100644 --- a/brainpy/train/layers/conv.py +++ b/brainpy/dyn/layers/conv.py @@ -2,9 +2,10 @@ import jax.lax + import brainpy.math as bm -from brainpy.initialize import XavierNormal, ZeroInit, init_param -from brainpy.train.base import TrainingSystem +from brainpy.dyn.training import TrainingSystem +from brainpy.initialize import XavierNormal, ZeroInit, parameter __all__ = [ 'GeneralConv', @@ -35,47 +36,59 @@ def _conv_dimension_numbers(input_shape): class GeneralConv(TrainingSystem): """Applies a convolution to the inputs. - Args: - in_channels: integer - number of input channels. - out_channels: integer - number of output channels. - kernel_size: sequence[int] - shape of the convolutional kernel. For 1D convolution, - the kernel size can be passed as an integer. For all other cases, it must - be a sequence of integers. - strides: sequence[int] - an integer or a sequence of `n` integers, representing the inter-window strides (default: 1). - padding: str, sequence[int] - either the string `'SAME'`, the string `'VALID'`, the string - `'CIRCULAR'` (periodic boundary conditions), or a sequence of `n` `(low, - high)` integer pairs that give the padding to apply before and after each - spatial dimension. A single int is interpeted as applying the same padding - in all dims and passign a single int in a sequence causes the same padding - to be used on both sides. - input_dilation: integer, sequence[int] - an integer or a sequence of `n` integers, giving the - dilation factor to apply in each spatial dimension of `inputs` - (default: 1). Convolution with input dilation `d` is equivalent to - transposed convolution with stride `d`. - kernel_dilation: integer, sequence[int] - an integer or a sequence of `n` integers, giving the - dilation factor to apply in each spatial dimension of the convolution - kernel (default: 1). Convolution with kernel dilation - is also known as 'atrous convolution'. - groups: integer, default 1. - If specified divides the input - features into groups. - kernel_init: brainpy.init.Initializer - initializer for the convolutional kernel. - bias_init: brainpy.init.Initializer - initializer for the bias. + Parameters + ---------- + in_channels: integer + number of input channels. + out_channels: integer + number of output channels. + kernel_size: sequence[int] + shape of the convolutional kernel. For 1D convolution, + the kernel size can be passed as an integer. For all other cases, it must + be a sequence of integers. + strides: sequence[int] + an integer or a sequence of `n` integers, representing the inter-window strides (default: 1). + padding: str, sequence[int] + either the string `'SAME'`, the string `'VALID'`, the string + `'CIRCULAR'` (periodic boundary conditions), or a sequence of `n` `(low, + high)` integer pairs that give the padding to apply before and after each + spatial dimension. A single int is interpeted as applying the same padding + in all dims and passign a single int in a sequence causes the same padding + to be used on both sides. + input_dilation: integer, sequence[int] + an integer or a sequence of `n` integers, giving the + dilation factor to apply in each spatial dimension of `inputs` + (default: 1). Convolution with input dilation `d` is equivalent to + transposed convolution with stride `d`. + kernel_dilation: integer, sequence[int] + an integer or a sequence of `n` integers, giving the + dilation factor to apply in each spatial dimension of the convolution + kernel (default: 1). Convolution with kernel dilation + is also known as 'atrous convolution'. + groups: integer, default 1. + If specified divides the input + features into groups. + w_init: brainpy.init.Initializer + initializer for the convolutional kernel. + b_init: brainpy.init.Initializer + initializer for the bias. """ - def __init__(self, in_channels, out_channels, kernel_size, strides=None, padding='SAME', - input_dilation=None, kernel_dilation=None, groups=1, - w_init=XavierNormal(), b_init=ZeroInit(), - trainable=True, name=None): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + strides=None, + padding='SAME', + input_dilation=None, + kernel_dilation=None, + groups=1, + w_init=XavierNormal(), + b_init=ZeroInit(), + trainable: bool = True, + name: str = None, + ): super(GeneralConv, self).__init__(name=name, trainable=trainable) self.in_channels = in_channels self.out_channels = out_channels @@ -101,8 +114,8 @@ def __init__(self, in_channels, out_channels, kernel_size, strides=None, padding assert self.in_channels % self.groups == 0, '"nin" should be divisible by groups' kernel_shape = _check_tuple(self.kernel_size) + (self.in_channels // self.groups, self.out_channels) - self.w = init_param(self.w_init, kernel_shape) - self.b = init_param(self.b_init, (1,) * len(self.kernel_size) + (self.out_channels,)) + self.w = parameter(self.w_init, kernel_shape) + self.b = parameter(self.b_init, (1,) * len(self.kernel_size) + (self.out_channels,)) if self.trainable: self.w = bm.TrainVar(self.w) self.b = bm.TrainVar(self.b) @@ -110,7 +123,7 @@ def __init__(self, in_channels, out_channels, kernel_size, strides=None, padding def _check_input_dim(self, x): pass - def forward(self, x, **shared_kwargs): + def update(self, sha, x): self._check_input_dim(x) if self.strides is None: self.strides = (1,) * (len(x.shape) - 2) @@ -128,7 +141,13 @@ def forward(self, x, **shared_kwargs): class Conv1D(GeneralConv): - def __init__(self, in_channels, out_channels, kernel_size, **kwargs): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + **kwargs + ): super(Conv1D, self).__init__(in_channels, out_channels, kernel_size, **kwargs) self.dimension_numbers = ('NWC', 'WIO', 'NWC') @@ -147,7 +166,13 @@ def _check_input_dim(self, x): class Conv2D(GeneralConv): - def __init__(self, in_channels, out_channels, kernel_size, **kwargs): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + **kwargs + ): super(Conv2D, self).__init__(in_channels, out_channels, kernel_size, **kwargs) self.dimension_numbers = ('NHWC', 'HWIO', 'NHWC') @@ -166,7 +191,13 @@ def _check_input_dim(self, x): class Conv3D(GeneralConv): - def __init__(self, in_channels, out_channels, kernel_size, **kwargs): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + **kwargs + ): super(Conv3D, self).__init__(in_channels, out_channels, kernel_size, **kwargs) self.dimension_numbers = ('NHWDC', 'HWDIO', 'NHWDC') diff --git a/brainpy/train/layers/dropout.py b/brainpy/dyn/layers/dropout.py similarity index 88% rename from brainpy/train/layers/dropout.py rename to brainpy/dyn/layers/dropout.py index e5d9598ae..07e3a46c1 100644 --- a/brainpy/train/layers/dropout.py +++ b/brainpy/dyn/layers/dropout.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- import brainpy.math as bm -from brainpy.train.base import TrainingSystem +from brainpy.dyn.training import TrainingSystem __all__ = [ 'Dropout' @@ -41,9 +41,8 @@ def __init__(self, prob, seed=None, trainable=False, name=None): self.prob = prob self.rng = bm.random.RandomState(seed=seed) - def forward(self, x, shared_args=None): - shared_args = dict() if shared_args is None else shared_args - if shared_args.get('train', True): + def update(self, sha, x): + if sha.get('fit', True): keep_mask = self.rng.bernoulli(self.prob, x.shape) return bm.where(keep_mask, x / self.prob, 0.) else: diff --git a/brainpy/train/layers/linear.py b/brainpy/dyn/layers/linear.py similarity index 90% rename from brainpy/train/layers/linear.py rename to brainpy/dyn/layers/linear.py index 977531fdd..855288811 100644 --- a/brainpy/train/layers/linear.py +++ b/brainpy/dyn/layers/linear.py @@ -7,9 +7,9 @@ from brainpy import math as bm from brainpy.errors import MathError -from brainpy.initialize import XavierNormal, ZeroInit, Initializer, init_param +from brainpy.initialize import XavierNormal, ZeroInit, Initializer, parameter from brainpy.tools.checking import check_initializer -from brainpy.train.base import TrainingSystem +from brainpy.dyn.training import TrainingSystem from brainpy.types import Tensor __all__ = [ @@ -68,17 +68,24 @@ def __init__( check_initializer(b_initializer, 'bias_initializer', allow_none=True) # parameter initialization - self.W = init_param(self.weight_initializer, (num_in, self.num_out)) - self.b = init_param(self.bias_initializer, (self.num_out,)) + self.W = parameter(self.weight_initializer, (num_in, self.num_out)) + self.b = parameter(self.bias_initializer, (self.num_out,)) if self.trainable: self.W = bm.TrainVar(self.W) self.b = None if (self.b is None) else bm.TrainVar(self.b) - def forward(self, x, shared_args=None): + def update(self, sha, x): res = x @ self.W if self.b is not None: res += self.b - if self.online_fit_by is not None or self.offline_fit_by is not None: + + # online fitting data + if sha.get('fit', False) and self.online_fit_by is not None: + self.fit_record['input'] = x + self.fit_record['output'] = res + + # offline fitting data + if sha.get('fit', False) and self.offline_fit_by is not None: self.fit_record['input'] = x self.fit_record['output'] = res return res @@ -98,8 +105,7 @@ def online_init(self): def online_fit(self, target: Tensor, - fit_record: Dict[str, Tensor], - shared_args: Dict = None): + fit_record: Dict[str, Tensor]): if not isinstance(target, (bm.ndarray, jnp.ndarray)): raise MathError(f'"target" must be a tensor, but got {type(target)}') x = fit_record['input'] @@ -141,8 +147,7 @@ def offline_init(self): def offline_fit(self, target: Tensor, - fit_record: Dict[str, Tensor], - shared_args: Dict = None): + fit_record: Dict[str, Tensor]): """The offline training interface for the Dense node.""" # data checking if not isinstance(target, (bm.ndarray, jnp.ndarray)): diff --git a/brainpy/train/layers/nvar.py b/brainpy/dyn/layers/nvar.py similarity index 98% rename from brainpy/train/layers/nvar.py rename to brainpy/dyn/layers/nvar.py index c0cbd1414..114c9516d 100644 --- a/brainpy/train/layers/nvar.py +++ b/brainpy/dyn/layers/nvar.py @@ -8,7 +8,7 @@ import brainpy.math as bm from brainpy.tools.checking import (check_integer, check_sequence) -from brainpy.train.base import TrainingSystem +from brainpy.dyn.training import TrainingSystem __all__ = [ @@ -126,7 +126,7 @@ def reset_state(self, batch_size=1): # manually handle the state of NVAR, rather return it. self.store._value = jnp.zeros((self.num_delay, batch_size, self.num_in)) - def forward(self, x, shared_args=None): + def update(self, sha, x): all_parts = [] # 1. Store the current input self.store[self.idx[0]] = x diff --git a/brainpy/train/layers/reservoir.py b/brainpy/dyn/layers/reservoir.py similarity index 82% rename from brainpy/train/layers/reservoir.py rename to brainpy/dyn/layers/reservoir.py index 8d4300ff3..26232c31e 100644 --- a/brainpy/train/layers/reservoir.py +++ b/brainpy/dyn/layers/reservoir.py @@ -3,10 +3,10 @@ from typing import Optional, Union, Callable, Tuple import brainpy.math as bm -from brainpy.initialize import Normal, ZeroInit, Initializer, init_param +from brainpy.initialize import Normal, ZeroInit, Initializer, parameter from brainpy.tools.checking import check_float, check_initializer, check_string from brainpy.tools.others import to_size -from brainpy.train.base import TrainingSystem +from brainpy.dyn.training import TrainingSystem from brainpy.types import Tensor __all__ = [ @@ -28,8 +28,6 @@ class Reservoir(TrainingSystem): The initialization method for the feedforward connections. Wrec_initializer: Initializer The initialization method for the recurrent connections. - fb_initializer: optional, Tensor, Initializer - The initialization method for the feedback connections. b_initializer: optional, Tensor, Initializer The initialization method for the bias. leaky_rate: float @@ -58,16 +56,13 @@ class Reservoir(TrainingSystem): x[n+1] &= (1 - \alpha) \cdot x[t] + \alpha \cdot f(W_{ff} \cdot u[n] + W_{rec} \cdot r[t] + W_{fb} \cdot b[n]) \\ r[n+1] &= f(x[n+1]) - ff_connectivity : float, optional + in_connectivity : float, optional Connectivity of input neurons, i.e. ratio of input neurons connected to reservoir neurons. Must be in [0, 1], by default 0.1 rec_connectivity : float, optional Connectivity of recurrent weights matrix, i.e. ratio of reservoir neurons connected to other reservoir neurons, including themselves. Must be in [0, 1], by default 0.1 - fb_connectivity : float, optional - Connectivity of feedback neurons, i.e. ratio of feedabck neurons - connected to reservoir neurons. Must be in [0, 1], by default 0.1 conn_type: str The connectivity type, can be "dense" or "sparse". spectral_radius : float, optional @@ -76,8 +71,6 @@ class Reservoir(TrainingSystem): Gain of noise applied to reservoir internal states, by default 0.0 noise_in : float, optional Gain of noise applied to feedforward signals, by default 0.0 - noise_fb : float, optional - Gain of noise applied to feedback signals, by default 0.0 noise_type : optional, str, callable Distribution of noise. Must be a random variable generator distribution (see :py:class:`brainpy.math.random.RandomState`), @@ -101,28 +94,25 @@ def __init__( Win_initializer: Union[Initializer, Callable, Tensor] = Normal(scale=0.1), Wrec_initializer: Union[Initializer, Callable, Tensor] = Normal(scale=0.1), b_initializer: Optional[Union[Initializer, Callable, Tensor]] = ZeroInit(), - ff_connectivity: float = 0.1, + in_connectivity: float = 0.1, rec_connectivity: float = 0.1, - fb_connectivity: float = 0.1, conn_type='dense', spectral_radius: Optional[float] = None, - noise_ff: float = 0., + noise_in: float = 0., noise_rec: float = 0., - noise_fb: float = 0., noise_type: str = 'normal', seed: Optional[int] = None, trainable: bool = False, - name: str=None + name: str = None ): super(Reservoir, self).__init__(trainable=trainable, name=name) - # parameters input_shape = to_size(input_shape) if input_shape[0] is None: input_shape = input_shape[1:] self.input_shape = input_shape - self.output_shape = input_shape[:-1] + (num_out, ) + self.output_shape = input_shape[:-1] + (num_out,) self.num_unit = num_out assert num_out > 0, f'Must be a positive integer, but we got {num_out}' self.leaky_rate = leaky_rate @@ -143,21 +133,17 @@ def __init__( self._b_initializer = b_initializer # connectivity - check_float(ff_connectivity, 'ff_connectivity', 0., 1.) + check_float(in_connectivity, 'ff_connectivity', 0., 1.) check_float(rec_connectivity, 'rec_connectivity', 0., 1.) - check_float(fb_connectivity, 'fb_connectivity', 0., 1.) - self.ff_connectivity = ff_connectivity + self.ff_connectivity = in_connectivity self.rec_connectivity = rec_connectivity - self.fb_connectivity = fb_connectivity check_string(conn_type, 'conn_type', ['dense', 'sparse']) self.conn_type = conn_type # noises - check_float(noise_ff, 'noise_ff') - check_float(noise_fb, 'noise_fb') + check_float(noise_in, 'noise_ff') check_float(noise_rec, 'noise_rec') - self.noise_ff = noise_ff - self.noise_fb = noise_fb + self.noise_ff = noise_in self.noise_rec = noise_rec self.noise_type = noise_type check_string(noise_type, 'noise_type', ['normal', 'uniform']) @@ -165,7 +151,7 @@ def __init__( # initialize feedforward weights weight_shape = (input_shape[-1], self.num_unit) self.Wff_shape = weight_shape - self.Win = init_param(self._Win_initializer, weight_shape) + self.Win = parameter(self._Win_initializer, weight_shape) if self.ff_connectivity < 1.: conn_mat = self.rng.random(weight_shape) > self.ff_connectivity self.Win[conn_mat] = 0. @@ -177,7 +163,7 @@ def __init__( # initialize recurrent weights recurrent_shape = (self.num_unit, self.num_unit) - self.Wrec = init_param(self._Wrec_initializer, recurrent_shape) + self.Wrec = parameter(self._Wrec_initializer, recurrent_shape) if self.rec_connectivity < 1.: conn_mat = self.rng.random(recurrent_shape) > self.rec_connectivity self.Wrec[conn_mat] = 0. @@ -187,7 +173,7 @@ def __init__( if self.conn_type == 'sparse' and self.rec_connectivity < 1.: self.rec_pres, self.rec_posts = bm.where(bm.logical_not(conn_mat)) self.Wrec = self.Wrec[self.rec_pres, self.rec_posts] - self.bias = init_param(self._b_initializer, (self.num_unit,)) + self.bias = parameter(self._b_initializer, (self.num_unit,)) if self.trainable: self.Wrec = bm.TrainVar(self.Wrec) self.bias = None if (self.bias is None) else bm.TrainVar(self.bias) @@ -202,25 +188,30 @@ def reset(self, batch_size=1): def reset_state(self, batch_size=1): pass - def forward(self, x, shared_args=None): + def update(self, sha, x): """Feedforward output.""" # inputs x = bm.concatenate(x, axis=-1) if self.noise_ff > 0: x += self.noise_ff * self.rng.uniform(-1, 1, x.shape) if self.conn_type == 'sparse' and self.ff_connectivity < 1.: - sparse = {'data': self.Win, 'index': (self.ff_pres, self.ff_posts), 'shape': self.Wff_shape} + sparse = {'data': self.Win, + 'index': (self.ff_pres, self.ff_posts), + 'shape': self.Wff_shape} hidden = bm.sparse_matmul(x, sparse) else: hidden = bm.dot(x, self.Win) # recurrent if self.conn_type == 'sparse' and self.rec_connectivity < 1.: - sparse = {'data': self.Wrec, 'index': (self.rec_pres, self.rec_posts), 'shape': (self.num_unit, self.num_unit)} + sparse = {'data': self.Wrec, + 'index': (self.rec_pres, self.rec_posts), + 'shape': (self.num_unit, self.num_unit)} hidden += bm.sparse_matmul(self.state, sparse) else: hidden += bm.dot(self.state, self.Wrec) if self.activation_type == 'internal': hidden = self.activation(hidden) - if self.noise_rec > 0.: hidden += self.noise_rec * self.rng.uniform(-1, -1, self.state.shape) + if self.noise_rec > 0.: + hidden += self.noise_rec * self.rng.uniform(-1, -1, self.state.shape) # new state/output state = (1 - self.leaky_rate) * self.state + self.leaky_rate * hidden if self.activation_type == 'external': diff --git a/brainpy/train/layers/recurrents.py b/brainpy/dyn/layers/rnncells.py similarity index 91% rename from brainpy/train/layers/recurrents.py rename to brainpy/dyn/layers/rnncells.py index d48d7a3de..5a44c853f 100644 --- a/brainpy/train/layers/recurrents.py +++ b/brainpy/dyn/layers/rnncells.py @@ -7,10 +7,10 @@ from brainpy.initialize import (XavierNormal, ZeroInit, Orthogonal, - init_param, + parameter, Initializer) from brainpy.tools.checking import (check_integer, check_initializer) -from brainpy.train.base import TrainingSystem +from brainpy.dyn.training import TrainingSystem from brainpy.types import Tensor __all__ = [ @@ -39,16 +39,16 @@ def __init__(self, # state self.state = bm.Variable(bm.zeros((1, self.num_out))) if train_state and self.trainable: - self.state2train = bm.TrainVar(init_param(state_initializer, (self.num_out,), allow_none=False)) + self.state2train = bm.TrainVar(parameter(state_initializer, (self.num_out,), allow_none=False)) self.state[:] = self.state2train def reset(self, batch_size=1): self.reset_state(batch_size) def reset_state(self, batch_size=1): - self.state._value = init_param(self._state_initializer, (batch_size, self.num_out), allow_none=False) + self.state._value = parameter(self._state_initializer, (batch_size, self.num_out), allow_none=False) if self.train_state: - self.state2train.value = init_param(self._state_initializer, self.num_out, allow_none=False) + self.state2train.value = parameter(self._state_initializer, self.num_out, allow_none=False) self.state[:] = self.state2train @@ -120,15 +120,15 @@ def __init__( self.activation = bm.activations.get(activation) # weights - self.Wi = init_param(self._Wi_initializer, (num_in, self.num_out)) - self.Wh = init_param(self._Wh_initializer, (self.num_out, self.num_out)) - self.b = init_param(self._b_initializer, (self.num_out,)) + self.Wi = parameter(self._Wi_initializer, (num_in, self.num_out)) + self.Wh = parameter(self._Wh_initializer, (self.num_out, self.num_out)) + self.b = parameter(self._b_initializer, (self.num_out,)) if self.trainable: self.Wi = bm.TrainVar(self.Wi) self.Wh = bm.TrainVar(self.Wh) self.b = None if (self.b is None) else bm.TrainVar(self.b) - def forward(self, x, shared_args=None): + def update(self, sha, x): h = x @ self.Wi h += self.state.value @ self.Wh if self.b is not None: @@ -218,15 +218,15 @@ def __init__( self.activation = bm.activations.get(activation) # weights - self.Wi = init_param(self._Wi_initializer, (num_in, self.num_out * 3)) - self.Wh = init_param(self._Wh_initializer, (self.num_out, self.num_out * 3)) - self.b = init_param(self._b_initializer, (self.num_out * 3,)) + self.Wi = parameter(self._Wi_initializer, (num_in, self.num_out * 3)) + self.Wh = parameter(self._Wh_initializer, (self.num_out, self.num_out * 3)) + self.b = parameter(self._b_initializer, (self.num_out * 3,)) if self.trainable: self.Wi = bm.TrainVar(self.Wi) self.Wh = bm.TrainVar(self.Wh) self.b = bm.TrainVar(self.b) if (self.b is not None) else None - def forward(self, x, shared_args=None): + def update(self, sha, x): gates_x = bm.matmul(x, self.Wi) zr_x, a_x = bm.split(gates_x, indices_or_sections=[2 * self.num_out], axis=-1) w_h_z, w_h_a = bm.split(self.Wh, indices_or_sections=[2 * self.num_out], axis=-1) @@ -342,15 +342,15 @@ def __init__( self.activation = bm.activations.get(activation) # weights - self.Wi = init_param(self._Wi_initializer, (num_in, self.num_out * 4)) - self.Wh = init_param(self._Wh_initializer, (self.num_out, self.num_out * 4)) - self.b = init_param(self._b_initializer, (self.num_out * 4,)) + self.Wi = parameter(self._Wi_initializer, (num_in, self.num_out * 4)) + self.Wh = parameter(self._Wh_initializer, (self.num_out, self.num_out * 4)) + self.b = parameter(self._b_initializer, (self.num_out * 4,)) if self.trainable: self.Wi = bm.TrainVar(self.Wi) self.Wh = bm.TrainVar(self.Wh) self.b = None if (self.b is None) else bm.TrainVar(self.b) - def forward(self, x, shared_args=None): + def update(self, sha, x): h, c = bm.split(self.state, 2) gated = x @ self.Wi if self.b is not None: diff --git a/brainpy/train/layers/tests/test_conv.py b/brainpy/dyn/layers/tests/test_conv.py similarity index 61% rename from brainpy/train/layers/tests/test_conv.py rename to brainpy/dyn/layers/tests/test_conv.py index 2405b023d..fa01ece80 100644 --- a/brainpy/train/layers/tests/test_conv.py +++ b/brainpy/dyn/layers/tests/test_conv.py @@ -11,14 +11,14 @@ class TestConv(TestCase): def test_Conv2D_img(self): - class Convnet(bp.train.TrainingSystem): + class Convnet(bp.dyn.TrainingSystem): def __init__(self): super(Convnet, self).__init__() - self.conv = bp.train.layers.Conv2D(in_channels=4, out_channels=32, kernel_size=(3, 3), - strides=(1, 1), padding='SAME', groups=1) + self.conv = bp.layers.Conv2D(in_channels=4, out_channels=32, kernel_size=(3, 3), + strides=(1, 1), padding='SAME', groups=1) - def forward(self, x, shared_args=None): - x = self.conv(x) + def update(self, shared, x): + x = self.conv(shared, x) return x img = jnp.zeros((2, 200, 198, 4)) @@ -29,7 +29,7 @@ def forward(self, x, shared_args=None): img = img.at[1, x:x + 20, y:y + 20, k].set(3.0) net = Convnet() - out = net(img) + out = net(None, img) print("out shape: ", out.shape) # print("First output channel:") # plt.figure(figsize=(10, 10)) @@ -37,19 +37,19 @@ def forward(self, x, shared_args=None): # plt.show() def test_conv1D(self): - class Convnet(bp.train.TrainingSystem): + class Convnet(bp.dyn.TrainingSystem): def __init__(self): super(Convnet, self).__init__() - self.conv = bp.train.layers.Conv1D(in_channels=3, out_channels=32, kernel_size=(3,)) + self.conv = bp.layers.Conv1D(in_channels=3, out_channels=32, kernel_size=(3,)) - def forward(self, x, shared_args=None): - x = self.conv(x) + def update(self, shared, x): + x = self.conv(shared, x) return x model = Convnet() input = bp.math.ones((2, 5, 3)) - out = model(input) + out = model(None, input) print("out shape: ", out.shape) # print("First output channel:") # plt.figure(figsize=(10, 10)) @@ -57,20 +57,20 @@ def forward(self, x, shared_args=None): # plt.show() def test_conv2D(self): - class Convnet(bp.train.TrainingSystem): + class Convnet(bp.dyn.TrainingSystem): def __init__(self): super(Convnet, self).__init__() - self.conv = bp.train.layers.Conv2D(in_channels=3, out_channels=32, kernel_size=(3, 3)) + self.conv = bp.layers.Conv2D(in_channels=3, out_channels=32, kernel_size=(3, 3)) - def forward(self, x, shared_args=None): - x = self.conv(x) + def update(self, shared, x): + x = self.conv(shared, x) return x model = Convnet() input = bp.math.ones((2, 5, 5, 3)) - out = model(input) + out = model(None, input) print("out shape: ", out.shape) # print("First output channel:") # plt.figure(figsize=(10, 10)) @@ -78,18 +78,18 @@ def forward(self, x, shared_args=None): # plt.show() def test_conv3D(self): - class Convnet(bp.train.TrainingSystem): + class Convnet(bp.dyn.TrainingSystem): def __init__(self): super(Convnet, self).__init__() - self.conv = bp.train.layers.Conv3D(in_channels=3, out_channels=32, kernel_size=(3, 3, 3)) + self.conv = bp.layers.Conv3D(in_channels=3, out_channels=32, kernel_size=(3, 3, 3)) - def forward(self, x, shared_args=None): - x = self.conv(x) + def update(self, shared, x): + x = self.conv(shared, x) return x model = Convnet() input = bp.math.ones((2, 5, 5, 5, 3)) - out = model(input) + out = model(None, input) print("out shape: ", out.shape) diff --git a/brainpy/dyn/neurons/biological_models.py b/brainpy/dyn/neurons/biological_models.py index baaba564a..97a60e88c 100644 --- a/brainpy/dyn/neurons/biological_models.py +++ b/brainpy/dyn/neurons/biological_models.py @@ -4,8 +4,7 @@ import brainpy.math as bm from brainpy.dyn.base import NeuGroup -from brainpy.dyn.utils import init_noise -from brainpy.initialize import OneInit, Uniform, Initializer, init_param +from brainpy.initialize import OneInit, Uniform, Initializer, parameter, noise as init_noise, variable from brainpy.integrators.joint_eq import JointEq from brainpy.integrators.ode import odeint from brainpy.integrators.sde import sdeint @@ -209,21 +208,27 @@ def __init__( n_initializer: Union[Initializer, Callable, Tensor] = OneInit(0.32), noise: Union[float, Tensor, Initializer, Callable] = None, method: str = 'exp_auto', - name: str = None + name: str = None, + + # training parameter + trainable: bool = False, ): # initialization - super(HH, self).__init__(size=size, keep_size=keep_size, name=name) + super(HH, self).__init__(size=size, + keep_size=keep_size, + name=name, + trainable=trainable) # parameters - self.ENa = init_param(ENa, self.var_shape, allow_none=False) - self.EK = init_param(EK, self.var_shape, allow_none=False) - self.EL = init_param(EL, self.var_shape, allow_none=False) - self.gNa = init_param(gNa, self.var_shape, allow_none=False) - self.gK = init_param(gK, self.var_shape, allow_none=False) - self.gL = init_param(gL, self.var_shape, allow_none=False) - self.C = init_param(C, self.var_shape, allow_none=False) - self.V_th = init_param(V_th, self.var_shape, allow_none=False) - self.noise = init_noise(noise, self.var_shape, num_vars=4) + self.ENa = parameter(ENa, self.varshape, allow_none=False) + self.EK = parameter(EK, self.varshape, allow_none=False) + self.EL = parameter(EL, self.varshape, allow_none=False) + self.gNa = parameter(gNa, self.varshape, allow_none=False) + self.gK = parameter(gK, self.varshape, allow_none=False) + self.gL = parameter(gL, self.varshape, allow_none=False) + self.C = parameter(C, self.varshape, allow_none=False) + self.V_th = parameter(V_th, self.varshape, allow_none=False) + self.noise = init_noise(noise, self.varshape, num_vars=4) # initializers check_initializer(m_initializer, 'm_initializer', allow_none=False) @@ -236,12 +241,13 @@ def __init__( self._V_initializer = V_initializer # variables - self.m = bm.Variable(init_param(self._m_initializer, self.var_shape)) - self.h = bm.Variable(init_param(self._h_initializer, self.var_shape)) - self.n = bm.Variable(init_param(self._n_initializer, self.var_shape)) - self.V = bm.Variable(init_param(self._V_initializer, self.var_shape)) - self.input = bm.Variable(bm.zeros(self.var_shape)) - self.spike = bm.Variable(bm.zeros(self.var_shape, dtype=bool)) + self.m = variable(self._m_initializer, trainable, self.varshape) + self.h = variable(self._h_initializer, trainable, self.varshape) + self.n = variable(self._n_initializer, trainable, self.varshape) + self.V = variable(self._V_initializer, trainable, self.varshape) + self.input = variable(bm.zeros, trainable, self.varshape) + # sp_type = bm.dftype() if trainable else bool + self.spike = variable(lambda s: bm.zeros(s, dtype=bool), trainable, self.varshape) # integral if self.noise is None: @@ -249,13 +255,14 @@ def __init__( else: self.integral = sdeint(method=method, f=self.derivative, g=self.noise) - def reset(self): - self.m.value = init_param(self._m_initializer, self.var_shape) - self.h.value = init_param(self._h_initializer, self.var_shape) - self.n.value = init_param(self._n_initializer, self.var_shape) - self.V.value = init_param(self._V_initializer, self.var_shape) - self.input[:] = 0 - self.spike[:] = False + def reset_state(self, batch_size=None): + self.m.value = variable(self._m_initializer, batch_size, self.varshape) + self.h.value = variable(self._h_initializer, batch_size, self.varshape) + self.n.value = variable(self._n_initializer, batch_size, self.varshape) + self.V.value = variable(self._V_initializer, batch_size, self.varshape) + self.input.value = variable(bm.zeros, batch_size, self.varshape) + # sp_type = bm.dftype() if self.trainable else bool + self.spike.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape) def dm(self, m, t, V): alpha = 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10)) @@ -286,8 +293,10 @@ def dV(self, V, t, m, h, n, I_ext): def derivative(self): return JointEq([self.dV, self.dm, self.dh, self.dn]) - def update(self, t, dt): - V, m, h, n = self.integral(self.V, self.m, self.h, self.n, t, self.input, dt=dt) + def update(self, tdi, x=None): + t, dt = tdi['t'], tdi['dt'] + if x is not None: self.input += x + V, m, h, n = self.integral(self.V, self.m, self.h, self.n, t, self.input, dt) self.spike.value = bm.logical_and(self.V < self.V_th, V >= self.V_th) self.V.value = V self.m.value = m @@ -394,26 +403,32 @@ def __init__( V_initializer: Union[Callable, Initializer, Tensor] = Uniform(-70., -60.), noise: Union[float, Tensor, Initializer, Callable] = None, method: str = 'exp_auto', - name: str = None + name: str = None, + + # training parameter + trainable: bool = False, ): # initialization - super(MorrisLecar, self).__init__(size=size, keep_size=keep_size, name=name) + super(MorrisLecar, self).__init__(size=size, + keep_size=keep_size, + name=name, + trainable=trainable) # params - self.V_Ca = init_param(V_Ca, self.var_shape, allow_none=False) - self.g_Ca = init_param(g_Ca, self.var_shape, allow_none=False) - self.V_K = init_param(V_K, self.var_shape, allow_none=False) - self.g_K = init_param(g_K, self.var_shape, allow_none=False) - self.V_leak = init_param(V_leak, self.var_shape, allow_none=False) - self.g_leak = init_param(g_leak, self.var_shape, allow_none=False) - self.C = init_param(C, self.var_shape, allow_none=False) - self.V1 = init_param(V1, self.var_shape, allow_none=False) - self.V2 = init_param(V2, self.var_shape, allow_none=False) - self.V3 = init_param(V3, self.var_shape, allow_none=False) - self.V4 = init_param(V4, self.var_shape, allow_none=False) - self.phi = init_param(phi, self.var_shape, allow_none=False) - self.V_th = init_param(V_th, self.var_shape, allow_none=False) - self.noise = init_noise(noise, self.var_shape, num_vars=2) + self.V_Ca = parameter(V_Ca, self.varshape, allow_none=False) + self.g_Ca = parameter(g_Ca, self.varshape, allow_none=False) + self.V_K = parameter(V_K, self.varshape, allow_none=False) + self.g_K = parameter(g_K, self.varshape, allow_none=False) + self.V_leak = parameter(V_leak, self.varshape, allow_none=False) + self.g_leak = parameter(g_leak, self.varshape, allow_none=False) + self.C = parameter(C, self.varshape, allow_none=False) + self.V1 = parameter(V1, self.varshape, allow_none=False) + self.V2 = parameter(V2, self.varshape, allow_none=False) + self.V3 = parameter(V3, self.varshape, allow_none=False) + self.V4 = parameter(V4, self.varshape, allow_none=False) + self.phi = parameter(phi, self.varshape, allow_none=False) + self.V_th = parameter(V_th, self.varshape, allow_none=False) + self.noise = init_noise(noise, self.varshape, num_vars=2) # initializers check_initializer(V_initializer, 'V_initializer', allow_none=False) @@ -422,10 +437,11 @@ def __init__( self._V_initializer = V_initializer # variables - self.W = bm.Variable(init_param(W_initializer, self.var_shape)) - self.V = bm.Variable(init_param(V_initializer, self.var_shape)) - self.input = bm.Variable(bm.zeros(self.var_shape)) - self.spike = bm.Variable(bm.zeros(self.var_shape, dtype=bool)) + self.W = variable(self._W_initializer, trainable, self.varshape) + self.V = variable(self._V_initializer, trainable, self.varshape) + self.input = variable(bm.zeros, trainable, self.varshape) + # sp_type = bm.dftype() if trainable else bool + self.spike = variable(lambda s: bm.zeros(s, dtype=bool), trainable, self.varshape) # integral if self.noise is None: @@ -433,11 +449,12 @@ def __init__( else: self.integral = sdeint(method=method, f=self.derivative, g=self.noise) - def reset(self): - self.W.value = init_param(self._W_initializer, self.var_shape) - self.V.value = init_param(self._V_initializer, self.var_shape) - self.input.value = bm.zeros(self.var_shape) - self.spike.value = bm.zeros(self.var_shape, dtype=bool) + def reset_state(self, batch_size=None): + self.W.value = variable(self._W_initializer, batch_size, self.varshape) + self.V.value = variable(self._V_initializer, batch_size, self.varshape) + self.input.value = variable(bm.zeros, batch_size, self.varshape) + # sp_type = bm.dftype() if self.trainable else bool + self.spike.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape) def dV(self, V, t, W, I_ext): M_inf = (1 / 2) * (1 + bm.tanh((V - self.V1) / self.V2)) @@ -457,8 +474,10 @@ def dW(self, W, t, V): def derivative(self): return JointEq([self.dV, self.dW]) - def update(self, t, dt): - V, self.W.value = self.integral(self.V, self.W, t, self.input, dt=dt) + def update(self, tdi, x=None): + t, dt = tdi['t'], tdi['dt'] + if x is not None: self.input += x + V, self.W.value = self.integral(self.V, self.W, t, self.input, dt) spike = bm.logical_and(self.V < self.V_th, V >= self.V_th) self.V.value = V self.spike.value = spike @@ -644,31 +663,35 @@ def __init__( noise: Union[float, Tensor, Initializer, Callable] = None, method: str = 'exp_auto', name: str = None, + trainable: bool = False, ): # initialization - super(PinskyRinzelModel, self).__init__(size=size, keep_size=keep_size, name=name) + super(PinskyRinzelModel, self).__init__(size=size, + keep_size=keep_size, + name=name, + trainable=trainable) # conductance parameters - self.gAHP = init_param(gAHP, self.var_shape, allow_none=False) - self.gCa = init_param(gCa, self.var_shape, allow_none=False) - self.gNa = init_param(gNa, self.var_shape, allow_none=False) - self.gK = init_param(gK, self.var_shape, allow_none=False) - self.gL = init_param(gL, self.var_shape, allow_none=False) - self.gC = init_param(gC, self.var_shape, allow_none=False) + self.gAHP = parameter(gAHP, self.varshape, allow_none=False) + self.gCa = parameter(gCa, self.varshape, allow_none=False) + self.gNa = parameter(gNa, self.varshape, allow_none=False) + self.gK = parameter(gK, self.varshape, allow_none=False) + self.gL = parameter(gL, self.varshape, allow_none=False) + self.gC = parameter(gC, self.varshape, allow_none=False) # reversal potential parameters - self.ENa = init_param(ENa, self.var_shape, allow_none=False) - self.ECa = init_param(ECa, self.var_shape, allow_none=False) - self.EK = init_param(EK, self.var_shape, allow_none=False) - self.EL = init_param(EL, self.var_shape, allow_none=False) + self.ENa = parameter(ENa, self.varshape, allow_none=False) + self.ECa = parameter(ECa, self.varshape, allow_none=False) + self.EK = parameter(EK, self.varshape, allow_none=False) + self.EL = parameter(EL, self.varshape, allow_none=False) # other neuronal parameters - self.V_th = init_param(V_th, self.var_shape, allow_none=False) - self.Cm = init_param(Cm, self.var_shape, allow_none=False) - self.gc = init_param(gc, self.var_shape, allow_none=False) - self.p = init_param(p, self.var_shape, allow_none=False) - self.A = init_param(A, self.var_shape, allow_none=False) - self.noise = init_noise(noise, self.var_shape, num_vars=8) + self.V_th = parameter(V_th, self.varshape, allow_none=False) + self.Cm = parameter(Cm, self.varshape, allow_none=False) + self.gc = parameter(gc, self.varshape, allow_none=False) + self.p = parameter(p, self.varshape, allow_none=False) + self.A = parameter(A, self.varshape, allow_none=False) + self.noise = init_noise(noise, self.varshape, num_vars=8) # initializers check_initializer(Vs_initializer, 'Vs_initializer', allow_none=False) @@ -679,17 +702,17 @@ def __init__( self._Ca_initializer = Ca_initializer # variables - self.Vs = bm.Variable(init_param(self._Vs_initializer, self.var_shape)) - self.Vd = bm.Variable(init_param(self._Vd_initializer, self.var_shape)) - self.Ca = bm.Variable(init_param(self._Ca_initializer, self.var_shape)) - self.h = bm.Variable(self.inf_h(self.Vs)) - self.n = bm.Variable(self.inf_n(self.Vs)) - self.s = bm.Variable(self.inf_s(self.Vd)) - self.c = bm.Variable(self.inf_c(self.Vd)) - self.q = bm.Variable(self.inf_q(self.Ca)) - self.Id = bm.Variable(bm.zeros(self.var_shape)) # input to soma - self.Is = bm.Variable(bm.zeros(self.var_shape)) # input to dendrite - # self.spike = bm.Variable(bm.zeros(self.var_shape, dtype=bool)) + self.Vs = variable(self._Vs_initializer, trainable, self.varshape) + self.Vd = variable(self._Vd_initializer, trainable, self.varshape) + self.Ca = variable(self._Ca_initializer, trainable, self.varshape) + self.h = bm.Variable(self.inf_h(self.Vs), batch_axis=0 if trainable else None) + self.n = bm.Variable(self.inf_n(self.Vs), batch_axis=0 if trainable else None) + self.s = bm.Variable(self.inf_s(self.Vd), batch_axis=0 if trainable else None) + self.c = bm.Variable(self.inf_c(self.Vd), batch_axis=0 if trainable else None) + self.q = bm.Variable(self.inf_q(self.Ca), batch_axis=0 if trainable else None) + self.Id = variable(bm.zeros, trainable, self.varshape) # input to soma + self.Is = variable(bm.zeros, trainable, self.varshape) # input to dendrite + # self.spike = bm.Variable(bm.zeros(self.varshape, dtype=bool)) # integral if self.noise is None: @@ -697,32 +720,38 @@ def __init__( else: self.integral = sdeint(method=method, f=self.derivative, g=self.noise) - def reset(self): - self.Vd.value = init_param(self._Vd_initializer, self.var_shape) - self.Vs.value = init_param(self._Vs_initializer, self.var_shape) - self.Ca.value = init_param(self._Ca_initializer, self.var_shape) - self.h.value = self.inf_h(self.Vs) - self.n.value = self.inf_n(self.Vs) - self.s.value = self.inf_s(self.Vd) - self.c.value = self.inf_c(self.Vd) - self.q.value = self.inf_q(self.Ca) - self.Id[:] = 0 - self.Is[:] = 0 + def reset_state(self, batch_size=None): + self.Vd.value = variable(self._Vd_initializer, batch_size, self.varshape) + self.Vs.value = variable(self._Vs_initializer, batch_size, self.varshape) + self.Ca.value = variable(self._Ca_initializer, batch_size, self.varshape) + batch_axis = 0 if self.trainable else None + self.h.value = bm.Variable(self.inf_h(self.Vs), batch_axis=batch_axis) + self.n.value = bm.Variable(self.inf_n(self.Vs), batch_axis=batch_axis) + self.s.value = bm.Variable(self.inf_s(self.Vd), batch_axis=batch_axis) + self.c.value = bm.Variable(self.inf_c(self.Vd), batch_axis=batch_axis) + self.q.value = bm.Variable(self.inf_q(self.Ca), batch_axis=batch_axis) + self.Id.value = variable(bm.zeros, batch_size, self.varshape) + self.Is.value = variable(bm.zeros, batch_size, self.varshape) # self.spike[:] = False def dCa(self, Ca, t, s, Vd): ICa = self.gCa * s * s * (Vd - self.ECa) return -0.13 * ICa - 0.075 * Ca - def dh(self, h, t, Vs): return self.alpha_h(Vs) * (1 - h) - self.beta_h(Vs) * h + def dh(self, h, t, Vs): + return self.alpha_h(Vs) * (1 - h) - self.beta_h(Vs) * h - def dn(self, n, t, Vs): return self.alpha_n(Vs) * (1 - n) - self.beta_n(Vs) * n + def dn(self, n, t, Vs): + return self.alpha_n(Vs) * (1 - n) - self.beta_n(Vs) * n - def ds(self, s, t, Vd): return self.alpha_s(Vd) * (1 - s) - self.beta_s(Vd) * s + def ds(self, s, t, Vd): + return self.alpha_s(Vd) * (1 - s) - self.beta_s(Vd) * s - def dc(self, c, t, Vd): return self.alpha_c(Vd) * (1 - c) - self.beta_c(Vd) * c + def dc(self, c, t, Vd): + return self.alpha_c(Vd) * (1 - c) - self.beta_c(Vd) * c - def dq(self, q, t, Ca): return self.alpha_q(Ca) * (1 - q) - self.beta_q(Ca) * q + def dq(self, q, t, Ca): + return self.alpha_q(Ca) * (1 - q) - self.beta_q(Ca) * q def dVs(self, Vs, t, h, n, Vd): I_Na = (self.gNa * self.inf_m(Vs) ** 2 * h) * (Vs - self.ENa) @@ -746,7 +775,8 @@ def dVd(self, Vd, t, s, q, c, Ca, Vs): def derivative(self): return JointEq([self.dVs, self.dVd, self.dCa, self.dh, self.dn, self.ds, self.dc, self.dq]) - def update(self, t, dt): + def update(self, tdi, x=None): + assert x is None Vs, Vd, Ca, h, n, s, c, q = self.integral(Vs=self.Vs.value, Vd=self.Vd.value, Ca=self.Ca.value, @@ -755,8 +785,8 @@ def update(self, t, dt): s=self.s.value, c=self.c.value, q=self.q.value, - t=t, - dt=dt) + t=tdi['t'], + dt=tdi['dt']) self.Vs.value = Vs self.Vd.value = Vd self.Ca.value = Ca @@ -768,36 +798,44 @@ def update(self, t, dt): self.Id[:] = 0. self.Is[:] = 0. - def alpha_m(self, Vs): return 0.32 * (13.1 - (Vs + 60.)) / (bm.exp((13.1 - (Vs + 60.)) / 4.) - 1.) + def alpha_m(self, Vs): + return 0.32 * (13.1 - (Vs + 60.)) / (bm.exp((13.1 - (Vs + 60.)) / 4.) - 1.) - def beta_m(self, Vs): return 0.28 * ((Vs + 60.) - 40.1) / (bm.exp(((Vs + 60.) - 40.1) / 5.) - 1.) + def beta_m(self, Vs): + return 0.28 * ((Vs + 60.) - 40.1) / (bm.exp(((Vs + 60.) - 40.1) / 5.) - 1.) def inf_m(self, Vs): alpha = self.alpha_m(Vs) beta = self.beta_m(Vs) return alpha / (alpha + beta) - def alpha_n(self, Vs): return 0.016 * (35.1 - (Vs + 60.)) / (bm.exp((35.1 - (Vs + 60.)) / 5) - 1) + def alpha_n(self, Vs): + return 0.016 * (35.1 - (Vs + 60.)) / (bm.exp((35.1 - (Vs + 60.)) / 5) - 1) - def beta_n(self, Vs): return 0.25 * bm.exp(0.5 - 0.025 * (Vs + 60.)) + def beta_n(self, Vs): + return 0.25 * bm.exp(0.5 - 0.025 * (Vs + 60.)) def inf_n(self, Vs): alpha = self.alpha_n(Vs) beta = self.beta_n(Vs) return alpha / (alpha + beta) - def alpha_h(self, Vs): return 0.128 * bm.exp((17. - (Vs + 60.)) / 18.) + def alpha_h(self, Vs): + return 0.128 * bm.exp((17. - (Vs + 60.)) / 18.) - def beta_h(self, Vs): return 4. / (1 + bm.exp((40. - (Vs + 60.)) / 5)) + def beta_h(self, Vs): + return 4. / (1 + bm.exp((40. - (Vs + 60.)) / 5)) def inf_h(self, Vs): alpha = self.alpha_h(Vs) beta = self.beta_h(Vs) return alpha / (alpha + beta) - def alpha_s(self, Vd): return 1.6 / (1 + bm.exp(-0.072 * ((Vd + 60.) - 65.))) + def alpha_s(self, Vd): + return 1.6 / (1 + bm.exp(-0.072 * ((Vd + 60.) - 65.))) - def beta_s(self, Vd): return 0.02 * ((Vd + 60.) - 51.1) / (bm.exp(((Vd + 60.) - 51.1) / 5.) - 1.) + def beta_s(self, Vd): + return 0.02 * ((Vd + 60.) - 51.1) / (bm.exp(((Vd + 60.) - 51.1) / 5.) - 1.) def inf_s(self, Vd): alpha = self.alpha_s(Vd) @@ -818,9 +856,11 @@ def inf_c(self, Vd): beta_c = self.beta_c(Vd) return alpha_c / (alpha_c + beta_c) - def alpha_q(self, Ca): return bm.minimum(2e-5 * Ca, 1e-2) + def alpha_q(self, Ca): + return bm.minimum(2e-5 * Ca, 1e-2) - def beta_q(self, Ca): return 1e-3 + def beta_q(self, Ca): + return 1e-3 def inf_q(self, Ca): alpha = self.alpha_q(Ca) @@ -931,22 +971,23 @@ def __init__( n_initializer: Union[Initializer, Callable, Tensor] = OneInit(0.32), noise: Union[float, Tensor, Initializer, Callable] = None, method: str = 'exp_auto', - name: str = None + name: str = None, + trainable: bool = False, ): # initialization super(WangBuzsakiModel, self).__init__(size=size, keep_size=keep_size, name=name) # parameters - self.ENa = init_param(ENa, self.var_shape, allow_none=False) - self.EK = init_param(EK, self.var_shape, allow_none=False) - self.EL = init_param(EL, self.var_shape, allow_none=False) - self.gNa = init_param(gNa, self.var_shape, allow_none=False) - self.gK = init_param(gK, self.var_shape, allow_none=False) - self.gL = init_param(gL, self.var_shape, allow_none=False) - self.C = init_param(C, self.var_shape, allow_none=False) - self.phi = init_param(phi, self.var_shape, allow_none=False) - self.V_th = init_param(V_th, self.var_shape, allow_none=False) - self.noise = init_noise(noise, self.var_shape, num_vars=3) + self.ENa = parameter(ENa, self.varshape, allow_none=False) + self.EK = parameter(EK, self.varshape, allow_none=False) + self.EL = parameter(EL, self.varshape, allow_none=False) + self.gNa = parameter(gNa, self.varshape, allow_none=False) + self.gK = parameter(gK, self.varshape, allow_none=False) + self.gL = parameter(gL, self.varshape, allow_none=False) + self.C = parameter(C, self.varshape, allow_none=False) + self.phi = parameter(phi, self.varshape, allow_none=False) + self.V_th = parameter(V_th, self.varshape, allow_none=False) + self.noise = init_noise(noise, self.varshape, num_vars=3) # initializers check_initializer(h_initializer, 'h_initializer', allow_none=False) @@ -957,11 +998,11 @@ def __init__( self._V_initializer = V_initializer # variables - self.h = bm.Variable(init_param(self._h_initializer, self.var_shape)) - self.n = bm.Variable(init_param(self._n_initializer, self.var_shape)) - self.V = bm.Variable(init_param(self._V_initializer, self.var_shape)) - self.input = bm.Variable(bm.zeros(self.var_shape)) - self.spike = bm.Variable(bm.zeros(self.var_shape, dtype=bool)) + self.h = variable(self._h_initializer, trainable, self.varshape) + self.n = variable(self._n_initializer, trainable, self.varshape) + self.V = variable(self._V_initializer, trainable, self.varshape) + self.input = variable(bm.zeros, trainable, self.varshape) + self.spike = variable(lambda s: bm.zeros(s, dtype=bool), trainable, self.varshape) # integral if self.noise is None: @@ -969,12 +1010,12 @@ def __init__( else: self.integral = sdeint(method=method, f=self.derivative, g=self.noise) - def reset(self): - self.h.value = init_param(self._h_initializer, self.var_shape) - self.n.value = init_param(self._n_initializer, self.var_shape) - self.V.value = init_param(self._V_initializer, self.var_shape) - self.input[:] = 0 - self.spike[:] = False + def reset_state(self, batch_size=None): + self.h.value = variable(self._h_initializer, batch_size, self.varshape) + self.n.value = variable(self._n_initializer, batch_size, self.varshape) + self.V.value = variable(self._V_initializer, batch_size, self.varshape) + self.input.value = variable(bm.zeros, batch_size, self.varshape) + self.spike.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape) def m_inf(self, V): alpha = -0.1 * (V + 35) / (bm.exp(-0.1 * (V + 35)) - 1) @@ -1004,8 +1045,10 @@ def dV(self, V, t, h, n, I_ext): def derivative(self): return JointEq([self.dV, self.dh, self.dn]) - def update(self, t, dt): - V, h, n = self.integral(self.V, self.h, self.n, t, self.input, dt=dt) + def update(self, tdi, x=None): + t, dt = tdi['t'], tdi['dt'] + if x is not None: self.input += x + V, h, n = self.integral(self.V, self.h, self.n, t, self.input, dt) self.spike.value = bm.logical_and(self.V < self.V_th, V >= self.V_th) self.V.value = V self.h.value = h diff --git a/brainpy/dyn/neurons/fractional_models.py b/brainpy/dyn/neurons/fractional_models.py index 7530a84e5..8947e85e4 100644 --- a/brainpy/dyn/neurons/fractional_models.py +++ b/brainpy/dyn/neurons/fractional_models.py @@ -4,7 +4,7 @@ import brainpy.math as bm from brainpy.dyn.base import NeuGroup -from brainpy.initialize import ZeroInit, OneInit, Initializer, init_param +from brainpy.initialize import ZeroInit, OneInit, Initializer, parameter from brainpy.integrators.fde import CaputoL1Schema from brainpy.integrators.fde import GLShortMemory from brainpy.integrators.joint_eq import JointEq @@ -103,13 +103,13 @@ def __init__( check_integer(num_memory, 'num_memory', allow_none=False) # parameters - self.a = init_param(a, self.var_shape, allow_none=False) - self.b = init_param(b, self.var_shape, allow_none=False) - self.c = init_param(c, self.var_shape, allow_none=False) - self.d = init_param(d, self.var_shape, allow_none=False) - self.mu = init_param(mu, self.var_shape, allow_none=False) - self.Vth = init_param(Vth, self.var_shape, allow_none=False) - self.delta = init_param(delta, self.var_shape, allow_none=False) + self.a = parameter(a, self.varshape, allow_none=False) + self.b = parameter(b, self.varshape, allow_none=False) + self.c = parameter(c, self.varshape, allow_none=False) + self.d = parameter(d, self.varshape, allow_none=False) + self.mu = parameter(mu, self.varshape, allow_none=False) + self.Vth = parameter(Vth, self.varshape, allow_none=False) + self.delta = parameter(delta, self.varshape, allow_none=False) # initializers check_initializer(V_initializer, 'V_initializer', allow_none=False) @@ -120,22 +120,23 @@ def __init__( self._y_initializer = y_initializer # variables - self.V = bm.Variable(init_param(V_initializer, self.var_shape)) - self.w = bm.Variable(init_param(w_initializer, self.var_shape)) - self.y = bm.Variable(init_param(y_initializer, self.var_shape)) - self.input = bm.Variable(bm.zeros(self.var_shape)) - self.spike = bm.Variable(bm.zeros(self.var_shape, dtype=bool)) + self.V = bm.Variable(parameter(V_initializer, self.varshape)) + self.w = bm.Variable(parameter(w_initializer, self.varshape)) + self.y = bm.Variable(parameter(y_initializer, self.varshape)) + self.input = bm.Variable(bm.zeros(self.varshape)) + self.spike = bm.Variable(bm.zeros(self.varshape, dtype=bool)) # integral function self.integral = GLShortMemory(self.derivative, alpha=alpha, - num_step=num_memory, + num_memory=num_memory, inits=[self.V, self.w, self.y]) - def reset(self): - self.V.value = init_param(self._V_initializer, self.var_shape) - self.w.value = init_param(self._w_initializer, self.var_shape) - self.y.value = init_param(self._y_initializer, self.var_shape) + def reset_state(self, batch_size=None): + assert batch_size is None + self.V.value = parameter(self._V_initializer, self.varshape) + self.w.value = parameter(self._w_initializer, self.varshape) + self.y.value = parameter(self._y_initializer, self.varshape) self.input[:] = 0 self.spike[:] = False # integral function reset @@ -154,7 +155,9 @@ def dy(self, y, t, V): def derivative(self): return JointEq([self.dV, self.dw, self.dy]) - def update(self, t, dt): + def update(self, tdi, x=None): + t, dt = tdi['t'], tdi['dt'] + if x is not None: self.input += x V, w, y = self.integral(self.V, self.w, self.y, t, dt) self.spike.value = bm.logical_and(V >= self.Vth, self.V < self.Vth) self.V.value = V @@ -243,16 +246,16 @@ def __init__( # params self.alpha = alpha check_float(alpha, 'alpha', min_bound=0., max_bound=1., allow_none=False, allow_int=True) - self.a = init_param(a, self.var_shape, allow_none=False) - self.b = init_param(b, self.var_shape, allow_none=False) - self.c = init_param(c, self.var_shape, allow_none=False) - self.d = init_param(d, self.var_shape, allow_none=False) - self.f = init_param(f, self.var_shape, allow_none=False) - self.g = init_param(g, self.var_shape, allow_none=False) - self.h = init_param(h, self.var_shape, allow_none=False) - self.tau = init_param(tau, self.var_shape, allow_none=False) - self.R = init_param(R, self.var_shape, allow_none=False) - self.V_th = init_param(V_th, self.var_shape, allow_none=False) + self.a = parameter(a, self.varshape, allow_none=False) + self.b = parameter(b, self.varshape, allow_none=False) + self.c = parameter(c, self.varshape, allow_none=False) + self.d = parameter(d, self.varshape, allow_none=False) + self.f = parameter(f, self.varshape, allow_none=False) + self.g = parameter(g, self.varshape, allow_none=False) + self.h = parameter(h, self.varshape, allow_none=False) + self.tau = parameter(tau, self.varshape, allow_none=False) + self.R = parameter(R, self.varshape, allow_none=False) + self.V_th = parameter(V_th, self.varshape, allow_none=False) # initializers check_initializer(V_initializer, 'V_initializer', allow_none=False) @@ -261,10 +264,10 @@ def __init__( self._u_initializer = u_initializer # variables - self.V = bm.Variable(init_param(V_initializer, self.var_shape)) - self.u = bm.Variable(init_param(u_initializer, self.var_shape)) - self.input = bm.Variable(bm.zeros(self.var_shape)) - self.spike = bm.Variable(bm.zeros(self.var_shape, dtype=bool)) + self.V = bm.Variable(parameter(V_initializer, self.varshape)) + self.u = bm.Variable(parameter(u_initializer, self.varshape)) + self.input = bm.Variable(bm.zeros(self.varshape)) + self.spike = bm.Variable(bm.zeros(self.varshape, dtype=bool)) # functions check_integer(num_step, 'num_step', allow_none=False) @@ -273,9 +276,9 @@ def __init__( num_memory=num_step, inits=[self.V, self.u]) - def reset(self): - self.V.value = init_param(self._V_initializer, self.var_shape) - self.u.value = init_param(self._u_initializer, self.var_shape) + def reset_state(self, batch_size=None): + self.V.value = parameter(self._V_initializer, self.varshape) + self.u.value = parameter(self._u_initializer, self.varshape) self.input[:] = 0 self.spike[:] = False # integral function reset @@ -293,7 +296,9 @@ def du(self, u, t, V): def derivative(self): return JointEq([self.dV, self.du]) - def update(self, t, dt): + def update(self, tdi, x=None): + t, dt = tdi['t'], tdi['dt'] + if x is not None: self.input += x V, u = self.integral(self.V, self.u, t=t, I_ext=self.input, dt=dt) spikes = V >= self.V_th self.V.value = bm.where(spikes, self.c, V) diff --git a/brainpy/dyn/neurons/input_groups.py b/brainpy/dyn/neurons/input_groups.py index 7082d0be7..788debf16 100644 --- a/brainpy/dyn/neurons/input_groups.py +++ b/brainpy/dyn/neurons/input_groups.py @@ -7,18 +7,37 @@ import brainpy.math as bm from brainpy.dyn.base import NeuGroup from brainpy.errors import ModelBuildError -from brainpy.initialize import Initializer, init_param +from brainpy.initialize import Initializer, parameter, variable from brainpy.types import Shape, Tensor __all__ = [ - 'SpikeTimeInput', + 'InputGroup', 'SpikeTimeGroup', - - 'PoissonInput', 'PoissonGroup', ] +class InputGroup(NeuGroup): + def __init__( + self, + size: Shape, + keep_size: bool = False, + trainable: bool = False, + name: str = None, + ): + super(InputGroup, self).__init__(name=name, + size=size, + keep_size=keep_size, + trainable=trainable) + self.spike = None + + def update(self, tdi, x=None): + pass + + def reset_state(self, batch_size=None): + pass + + class SpikeTimeGroup(NeuGroup): """The input neuron group characterized by spikes emitting at given times. @@ -54,12 +73,15 @@ def __init__( indices: Union[Sequence, Tensor], need_sort: bool = True, keep_size: bool = False, + trainable: bool = False, name: str = None ): - super(SpikeTimeGroup, self).__init__(size=size, name=name) + super(SpikeTimeGroup, self).__init__(size=size, + name=name, + keep_size=keep_size, + trainable=trainable) # parameters - self.keep_size = keep_size if keep_size: raise NotImplementedError(f'Do not support keep_size=True in {self.__class__.__name__}') if len(indices) != len(times): @@ -73,7 +95,7 @@ def __init__( # variables self.i = bm.Variable(bm.zeros(1)) - self.spike = bm.Variable(bm.zeros(self.num, dtype=bool)) + self.spike = variable(lambda s: bm.zeros(s, dtype=bool), trainable, self.varshape) if need_sort: sort_idx = bm.argsort(self.times) self.indices.value = self.indices[sort_idx] @@ -85,36 +107,22 @@ def cond_fun(t): return bm.logical_and(i < self.num_times, t >= self.times[i]) def body_fun(t): - self.spike[self.indices[self.i[0]]] = True + i = self.i[0] + if self.trainable: + self.spike[:, self.indices[i]] = True + else: + self.spike[self.indices[i]] = True self.i += 1 self._run = bm.make_while(cond_fun, body_fun, dyn_vars=self.vars()) - def reset(self): + def reset_state(self, batch_size=None): self.i[0] = 1 - self.spike[:] = False + self.spike.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape) - def update(self, t, dt): + def update(self, tdi, x=None): self.spike[:] = False - self._run(t) - - -class SpikeTimeInput(SpikeTimeGroup): - """Spike Time Input. - - .. deprecated:: 2.1.0 - Please use "brainpy.dyn.SpikeTimeGroup" instead. - - Returns - ------- - group: NeuGroup - The neural group. - """ - - def __init__(self, *args, **kwargs): - raise ValueError('Please use "brainpy.dyn.SpikeTimeGroup" instead. ' - '"brainpy.dyn.SpikeTimeInput" is deprecated since ' - 'version 2.1.5') + self._run(tdi['t']) class PoissonGroup(NeuGroup): @@ -124,43 +132,33 @@ class PoissonGroup(NeuGroup): def __init__( self, size: Shape, - freqs: Union[float, jnp.ndarray, bm.JaxArray, Initializer], + freqs: Union[int, float, jnp.ndarray, bm.JaxArray, Initializer], seed: int = None, keep_size: bool = False, + trainable: bool = False, name: str = None ): - super(PoissonGroup, self).__init__(size=size, name=name) + super(PoissonGroup, self).__init__(size=size, + name=name, + keep_size=keep_size, + trainable=trainable) # parameters self.keep_size = keep_size self.seed = seed - self.freqs = init_param(freqs, self.num, allow_none=False) + self.freqs = parameter(freqs, self.num, allow_none=False) # variables - self.spike = bm.Variable(bm.zeros(self.size if keep_size else self.num, dtype=bool)) + self.spike = variable(lambda s: bm.zeros(s, dtype=bool), trainable, self.varshape) self.rng = bm.random.RandomState(seed=seed) - def update(self, t, dt): - self.spike.update(self.rng.random(self.var_shape) <= (self.freqs * dt / 1000.)) + def update(self, tdi, x=None): + shape = (self.spike.shape[:1] + self.varshape) if self.trainable else self.varshape + self.spike.update(self.rng.random(shape) <= (self.freqs * tdi['dt'] / 1000.)) - def reset(self): - self.spike[:] = False + def reset(self, batch_size=None): self.rng.seed(self.seed) + self.reset_state(batch_size) - -class PoissonInput(PoissonGroup): - """Poisson Group Input. - - .. deprecated:: 2.1.0 - Please use "brainpy.dyn.PoissonGroup" instead. - - Returns - ------- - poisson_group: NeuGroup - The poisson neural group. - """ - - def __init__(self, *args, **kwargs): - raise ValueError('Please use "brainpy.dyn.PoissonGroup" instead. ' - '"brainpy.dyn.PoissonInput" is deprecated since ' - 'version 2.1.5') + def reset_state(self, batch_size=None): + self.spike.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape) diff --git a/brainpy/dyn/neurons/noise_groups.py b/brainpy/dyn/neurons/noise_groups.py index d61d54817..03836c2b0 100644 --- a/brainpy/dyn/neurons/noise_groups.py +++ b/brainpy/dyn/neurons/noise_groups.py @@ -2,9 +2,9 @@ from typing import Union, Callable -import brainpy.math as bm +from brainpy import math as bm, initialize as init from brainpy.dyn.base import NeuGroup -from brainpy.initialize import init_param, Initializer +from brainpy.initialize import Initializer from brainpy.integrators.sde import sdeint from brainpy.types import Tensor, Shape @@ -49,30 +49,31 @@ def __init__( sigma: Union[float, Tensor, Initializer, Callable] = 1., tau: Union[float, Tensor, Initializer, Callable] = 10., method: str = 'euler', - name: str = None + keep_size: bool = False, + trainable: bool = False, + name: str = None, ): - super(OUProcess, self).__init__(size=size, name=name) + super(OUProcess, self).__init__(size=size, name=name, keep_size=keep_size) # parameters - self.mean = init_param(mean, self.num, allow_none=False) - self.sigma = init_param(sigma, self.num, allow_none=False) - self.tau = init_param(tau, self.num, allow_none=False) + self.mean = init.parameter(mean, self.varshape, allow_none=False) + self.sigma = init.parameter(sigma, self.varshape, allow_none=False) + self.tau = init.parameter(tau, self.varshape, allow_none=False) # variables - self.x = bm.Variable(bm.ones(self.num) * mean) + self.x = init.variable(lambda s: bm.ones(s) * self.mean, trainable, self.varshape) # integral functions self.integral = sdeint(f=self.df, g=self.dg, method=method) - def reset(self): - self.x[:] = self.mean + def reset_state(self, batch_size=None): + self.x.value = init.variable(lambda s: bm.ones(s) * self.mean, batch_size, self.varshape) def df(self, x, t): - f_x_ou = (self.mean - x) / self.tau - return f_x_ou + return (self.mean - x) / self.tau def dg(self, x, t): return self.sigma - def update(self, t, dt): - self.x.value = self.integral(self.x, t, dt) + def update(self, tdi, x=None): + self.x.value = self.integral(self.x, tdi['t'], tdi['dt']) diff --git a/brainpy/dyn/neurons/reduced_models.py b/brainpy/dyn/neurons/reduced_models.py index 82ee5ec12..b3c8cca8c 100644 --- a/brainpy/dyn/neurons/reduced_models.py +++ b/brainpy/dyn/neurons/reduced_models.py @@ -2,27 +2,125 @@ from typing import Union, Callable +from jax.lax import stop_gradient + import brainpy.math as bm from brainpy.dyn.base import NeuGroup -from brainpy.initialize import ZeroInit, OneInit, Initializer, init_param +from brainpy.initialize import (ZeroInit, OneInit, Initializer, + parameter, variable, noise as init_noise) from brainpy.integrators import sdeint, odeint, JointEq -from brainpy.tools.checking import check_initializer +from brainpy.tools.checking import check_initializer, check_callable from brainpy.types import Shape, Tensor -from brainpy.dyn.utils import init_noise __all__ = [ + 'LeakyIntegrator', 'LIF', 'ExpIF', 'AdExIF', 'QuaIF', 'AdQuaIF', 'GIF', + 'ALIFBellec2020', 'Izhikevich', 'HindmarshRose', 'FHN', ] +class LeakyIntegrator(NeuGroup): + r"""Leaky Integrator Model. + + **Model Descriptions** + + This class implements a leaky integrator model, in which its dynamics is + given by: + + .. math:: + + \tau \frac{dV}{dt} = - (V(t) - V_{rest}) + RI(t) + + where :math:`V` is the membrane potential, :math:`V_{rest}` is the resting + membrane potential, :math:`\tau` is the time constant, and :math:`R` is the + resistance. + + Parameters + ---------- + size: sequence of int, int + The size of the neuron group. + V_rest: float, JaxArray, ndarray, Initializer, callable + Resting membrane potential. + R: float, JaxArray, ndarray, Initializer, callable + Membrane resistance. + tau: float, JaxArray, ndarray, Initializer, callable + Membrane time constant. + V_initializer: JaxArray, ndarray, Initializer, callable + The initializer of membrane potential. + noise: JaxArray, ndarray, Initializer, callable + The noise added onto the membrane potential + method: str + The numerical integration method. + name: str + The group name. + """ + + def __init__( + self, + # neuron group size + size: Shape, + keep_size: bool = False, + + # neuron parameters + V_rest: Union[float, Tensor, Initializer, Callable] = 0., + R: Union[float, Tensor, Initializer, Callable] = 1., + tau: Union[float, Tensor, Initializer, Callable] = 10., + V_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(), + noise: Union[float, Tensor, Initializer, Callable] = None, + + # training parameters + trainable: bool = False, + + # other parameter + name: str = None, + method: str = 'exp_auto', + ): + super(LeakyIntegrator, self).__init__(size=size, + trainable=trainable, + keep_size=keep_size, + name=name) + + # parameters + self.V_rest = parameter(V_rest, self.varshape, allow_none=False) + self.tau = parameter(tau, self.varshape, allow_none=False) + self.R = parameter(R, self.varshape, allow_none=False) + self.noise = init_noise(noise, self.varshape) + + # initializers + check_initializer(V_initializer, 'V_initializer') + self._V_initializer = V_initializer + + # variables + self.V = variable(self._V_initializer, trainable, self.varshape) + self.input = variable(bm.zeros, trainable, self.varshape) + + # integral + if self.noise is None: + self.integral = odeint(method=method, f=self.derivative) + else: + self.integral = sdeint(method=method, f=self.derivative, g=self.noise) + + def derivative(self, V, t, I_ext): + return (-V + self.V_rest + self.R * I_ext) / self.tau + + def reset_state(self, batch_size=None): + self.V.value = variable(self._V_initializer, batch_size, self.varshape) + self.input.value = variable(bm.zeros, batch_size, self.varshape) + + def update(self, tdi, x=None): + if x is not None: self.input += x + self.V.value = self.integral(self.V.value, tdi.t, self.input.value, tdi.dt) + self.input[:] = 0. + + class LIF(NeuGroup): r"""Leaky integrate-and-fire neuron model. @@ -82,41 +180,52 @@ class LIF(NeuGroup): def __init__( self, size: Shape, + keep_size: bool = False, + + # other parameter V_rest: Union[float, Tensor, Initializer, Callable] = 0., V_reset: Union[float, Tensor, Initializer, Callable] = -5., V_th: Union[float, Tensor, Initializer, Callable] = 20., R: Union[float, Tensor, Initializer, Callable] = 1., tau: Union[float, Tensor, Initializer, Callable] = 10., - tau_ref: Union[float, Tensor, Initializer, Callable] = 1., + tau_ref: Union[float, Tensor, Initializer, Callable] = None, V_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(), noise: Union[float, Tensor, Initializer, Callable] = None, - keep_size: bool = False, method: str = 'exp_auto', - name: str = None + name: str = None, + + # training parameter + trainable: bool = False, + spike_fun: Callable = bm.spike_with_sigmoid_grad, ): # initialization - super(LIF, self).__init__(size=size, keep_size=keep_size, name=name) + super(LIF, self).__init__(size=size, + name=name, + keep_size=keep_size, + trainable=trainable) # parameters - self.keep_size = keep_size - self.V_rest = init_param(V_rest, self.var_shape, allow_none=False) - self.V_reset = init_param(V_reset, self.var_shape, allow_none=False) - self.V_th = init_param(V_th, self.var_shape, allow_none=False) - self.tau = init_param(tau, self.var_shape, allow_none=False) - self.tau_ref = init_param(tau_ref, self.var_shape, allow_none=False) - self.R = init_param(R, self.var_shape, allow_none=False) - self.noise = init_noise(noise, self.var_shape) + self.V_rest = parameter(V_rest, self.varshape, allow_none=False) + self.V_reset = parameter(V_reset, self.varshape, allow_none=False) + self.V_th = parameter(V_th, self.varshape, allow_none=False) + self.tau = parameter(tau, self.varshape, allow_none=False) + self.R = parameter(R, self.varshape, allow_none=False) + self.tau_ref = parameter(tau_ref, self.varshape, allow_none=True) + self.noise = init_noise(noise, self.varshape) + self.spike_fun = check_callable(spike_fun, 'spike_fun') # initializers check_initializer(V_initializer, 'V_initializer') self._V_initializer = V_initializer # variables - self.V = bm.Variable(init_param(V_initializer, self.var_shape)) - self.input = bm.Variable(bm.zeros(self.var_shape)) - self.spike = bm.Variable(bm.zeros(self.var_shape, dtype=bool)) - self.t_last_spike = bm.Variable(bm.ones(self.var_shape) * -1e7) - self.refractory = bm.Variable(bm.zeros(self.var_shape, dtype=bool)) + self.V = variable(self._V_initializer, trainable, self.varshape) + self.input = variable(bm.zeros, trainable, self.varshape) + sp_type = bm.dftype() if trainable else bool # the gradient of spike is a float + self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), trainable, self.varshape) + if self.tau_ref is not None: + self.t_last_spike = variable(lambda s: bm.ones(s) * -1e7, trainable, self.varshape) + self.refractory = variable(lambda s: bm.zeros(s, dtype=bool), trainable, self.varshape) # integral if self.noise is None: @@ -127,22 +236,61 @@ def __init__( def derivative(self, V, t, I_ext): return (-V + self.V_rest + self.R * I_ext) / self.tau - def reset(self): - self.V.value = init_param(self._V_initializer, self.var_shape) - self.input[:] = 0 - self.spike[:] = False - self.t_last_spike[:] = -1e7 - self.refractory[:] = False - - def update(self, t, dt): - refractory = (t - self.t_last_spike) <= self.tau_ref - V = self.integral(self.V, t, self.input, dt=dt) - V = bm.where(refractory, self.V, V) - spike = V >= self.V_th - self.t_last_spike.value = bm.where(spike, t, self.t_last_spike) - self.V.value = bm.where(spike, self.V_reset, V) - self.refractory.value = bm.logical_or(refractory, spike) - self.spike.value = spike + def reset_state(self, batch_size=None): + self.V.value = variable(self._V_initializer, batch_size, self.varshape) + self.input.value = variable(bm.zeros, batch_size, self.varshape) + sp_type = bm.dftype() if self.trainable else bool + self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape) + if self.tau_ref is not None: + self.t_last_spike.value = variable(lambda s: bm.ones(s) * -1e7, batch_size, self.varshape) + self.refractory.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape) + + def update(self, tdi, x=None): + t, dt = tdi.t, tdi.dt + if x is not None: self.input += x + + # integrate membrane potential + V = self.integral(self.V.value, t, self.input.value, dt) + + if self.tau_ref is not None: + # refractory + refractory = (t - self.t_last_spike) <= self.tau_ref + if self.trainable: + refractory = stop_gradient(refractory) + V = bm.where(refractory, self.V, V) + + # spike, refractory, spiking time, and membrane potential reset + if self.trainable: + spike = self.spike_fun(V - self.V_th) + spike_no_grad = stop_gradient(spike) + V += (self.V_reset - V) * spike_no_grad + spike_ = spike_no_grad > 0. + # will be used in other place, like Delta Synapse, so stop its gradient + refractory = stop_gradient(bm.logical_or(refractory, spike_).value) + t_last_spike = stop_gradient(bm.where(spike_, t, self.t_last_spike).value) + else: + spike = V >= self.V_th + V = bm.where(spike, self.V_reset, V) + refractory = bm.logical_or(refractory, spike) + t_last_spike = bm.where(spike, t, self.t_last_spike) + self.V.value = V + self.spike.value = spike + self.refractory.value = refractory + self.t_last_spike.value = t_last_spike + + else: + # spike, spiking time, and membrane potential reset + if self.trainable: + spike = self.spike_fun(V - self.V_th) + spike_no_grad = stop_gradient(spike) + V += (self.V_reset - V) * spike_no_grad + else: + spike = V >= self.V_th + V = bm.where(spike, self.V_reset, V) + self.V.value = V + self.spike.value = spike + + # reset input self.input[:] = 0. @@ -254,37 +402,43 @@ def __init__( delta_T: Union[float, Tensor, Initializer, Callable] = 3.48, R: Union[float, Tensor, Initializer, Callable] = 1., tau: Union[float, Tensor, Initializer, Callable] = 10., - tau_ref: Union[float, Tensor, Initializer, Callable] = 1.7, + tau_ref: Union[float, Tensor, Initializer, Callable] = None, V_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(), noise: Union[float, Tensor, Initializer, Callable] = None, keep_size: bool = False, + trainable: bool = False, method: str = 'exp_auto', name: str = None ): # initialize - super(ExpIF, self).__init__(size=size, keep_size=keep_size, name=name) + super(ExpIF, self).__init__(size=size, + name=name, + trainable=trainable, + keep_size=keep_size, ) # parameters - self.V_rest = init_param(V_rest, self.var_shape, allow_none=False) - self.V_reset = init_param(V_reset, self.var_shape, allow_none=False) - self.V_th = init_param(V_th, self.var_shape, allow_none=False) - self.V_T = init_param(V_T, self.var_shape, allow_none=False) - self.delta_T = init_param(delta_T, self.var_shape, allow_none=False) - self.tau_ref = init_param(tau_ref, self.var_shape, allow_none=False) - self.tau = init_param(tau, self.var_shape, allow_none=False) - self.R = init_param(R, self.var_shape, allow_none=False) - self.noise = init_noise(noise, self.var_shape) + self.V_rest = parameter(V_rest, self.varshape, allow_none=False) + self.V_reset = parameter(V_reset, self.varshape, allow_none=False) + self.V_th = parameter(V_th, self.varshape, allow_none=False) + self.V_T = parameter(V_T, self.varshape, allow_none=False) + self.delta_T = parameter(delta_T, self.varshape, allow_none=False) + self.tau_ref = parameter(tau_ref, self.varshape, allow_none=True) + self.tau = parameter(tau, self.varshape, allow_none=False) + self.R = parameter(R, self.varshape, allow_none=False) + self.noise = init_noise(noise, self.varshape) # initializers check_initializer(V_initializer, 'V_initializer') self._V_initializer = V_initializer # variables - self.V = bm.Variable(init_param(V_initializer, self.var_shape)) - self.input = bm.Variable(bm.zeros(self.var_shape)) - self.spike = bm.Variable(bm.zeros(self.var_shape, dtype=bool)) - self.refractory = bm.Variable(bm.zeros(self.var_shape, dtype=bool)) - self.t_last_spike = bm.Variable(bm.ones(self.var_shape) * -1e7) + self.V = variable(V_initializer, trainable, self.varshape) + self.input = variable(bm.zeros, trainable, self.varshape) + sp_type = bm.dftype() if trainable else bool + self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), trainable, self.varshape) + self.t_last_spike = variable(lambda s: bm.ones(s) * -1e7, trainable, self.varshape) + if self.tau_ref is not None: + self.refractory = variable(lambda s: bm.zeros(s, dtype=bool), trainable, self.varshape) # integral if self.noise is None: @@ -292,27 +446,40 @@ def __init__( else: self.integral = sdeint(method=method, f=self.derivative, g=self.noise) - def reset(self): - self.V.value = init_param(self._V_initializer, self.var_shape) - self.input[:] = 0 - self.spike[:] = False - self.t_last_spike[:] = -1e7 - self.refractory[:] = False + def reset_state(self, batch_size=None): + self.V.value = variable(self._V_initializer, batch_size, self.varshape) + self.input.value = variable(bm.zeros, batch_size, self.varshape) + sp_type = bm.dftype() if self.trainable else bool + self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape) + self.t_last_spike.value = variable(lambda s: bm.ones(s) * -1e7, batch_size, self.varshape) + if self.tau_ref is not None: + self.refractory.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape) def derivative(self, V, t, I_ext): exp_v = self.delta_T * bm.exp((V - self.V_T) / self.delta_T) dvdt = (- (V - self.V_rest) + exp_v + self.R * I_ext) / self.tau return dvdt - def update(self, t, dt): - refractory = (t - self.t_last_spike) <= self.tau_ref - V = self.integral(self.V, t, self.input, dt=dt) - V = bm.where(refractory, self.V, V) - spike = self.V_th <= V - self.t_last_spike.value = bm.where(spike, t, self.t_last_spike) - self.V.value = bm.where(spike, self.V_reset, V) - self.refractory.value = bm.logical_or(refractory, spike) + def update(self, tdi, x=None): + t, dt = tdi.t, tdi.dt + if x is not None: self.input += x + V = self.integral(self.V.value, t, self.input.value, dt) + + if self.tau_ref is not None: + refractory = (t - self.t_last_spike) <= self.tau_ref + V = bm.where(refractory, self.V, V) + spike = self.V_th <= V + t_last_spike = bm.where(spike, t, self.t_last_spike) + V = bm.where(spike, self.V_reset, V) + self.refractory.value = bm.logical_or(refractory, spike) + else: + spike = self.V_th <= V + t_last_spike = bm.where(spike, t, self.t_last_spike) + V = bm.where(spike, self.V_reset, V) + + self.V.value = V self.spike.value = spike + self.t_last_spike.value = t_last_spike self.input[:] = 0. @@ -407,22 +574,26 @@ def __init__( noise: Union[float, Tensor, Initializer, Callable] = None, method: str = 'exp_auto', keep_size: bool = False, + trainable: bool = False, name: str = None ): - super(AdExIF, self).__init__(size=size, keep_size=keep_size, name=name) + super(AdExIF, self).__init__(size=size, + keep_size=keep_size, + name=name, + trainable=trainable, ) # parameters - self.V_rest = init_param(V_rest, self.var_shape, allow_none=False) - self.V_reset = init_param(V_reset, self.var_shape, allow_none=False) - self.V_th = init_param(V_th, self.var_shape, allow_none=False) - self.V_T = init_param(V_T, self.var_shape, allow_none=False) - self.delta_T = init_param(delta_T, self.var_shape, allow_none=False) - self.a = init_param(a, self.var_shape, allow_none=False) - self.b = init_param(b, self.var_shape, allow_none=False) - self.tau = init_param(tau, self.var_shape, allow_none=False) - self.tau_w = init_param(tau_w, self.var_shape, allow_none=False) - self.R = init_param(R, self.var_shape, allow_none=False) - self.noise = init_noise(noise, self.var_shape, num_vars=2) + self.V_rest = parameter(V_rest, self.varshape, allow_none=False) + self.V_reset = parameter(V_reset, self.varshape, allow_none=False) + self.V_th = parameter(V_th, self.varshape, allow_none=False) + self.V_T = parameter(V_T, self.varshape, allow_none=False) + self.delta_T = parameter(delta_T, self.varshape, allow_none=False) + self.a = parameter(a, self.varshape, allow_none=False) + self.b = parameter(b, self.varshape, allow_none=False) + self.tau = parameter(tau, self.varshape, allow_none=False) + self.tau_w = parameter(tau_w, self.varshape, allow_none=False) + self.R = parameter(R, self.varshape, allow_none=False) + self.noise = init_noise(noise, self.varshape, num_vars=2) # initializers check_initializer(V_initializer, 'V_initializer') @@ -431,11 +602,11 @@ def __init__( self._w_initializer = w_initializer # variables - self.V = bm.Variable(init_param(V_initializer, self.var_shape)) - self.w = bm.Variable(init_param(w_initializer, self.var_shape)) - self.refractory = bm.Variable(bm.zeros(self.var_shape, dtype=bool)) - self.input = bm.Variable(bm.zeros(self.var_shape)) - self.spike = bm.Variable(bm.zeros(self.var_shape, dtype=bool)) + self.V = variable(V_initializer, trainable, self.varshape) + self.w = variable(w_initializer, trainable, self.varshape) + self.input = variable(bm.zeros, trainable, self.varshape) + sp_type = bm.dftype() if trainable else bool + self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), trainable, self.varshape) # functions if self.noise is None: @@ -443,16 +614,16 @@ def __init__( else: self.integral = sdeint(method=method, f=self.derivative, g=self.noise) - def reset(self): - self.V.value = init_param(self._V_initializer, self.var_shape) - self.w.value = init_param(self._w_initializer, self.var_shape) - self.input[:] = 0 - self.spike[:] = False - self.refractory[:] = False + def reset_state(self, batch_size=None): + self.V.value = variable(self._V_initializer, batch_size, self.varshape) + self.w.value = variable(self._w_initializer, batch_size, self.varshape) + self.input.value = variable(bm.zeros, batch_size, self.varshape) + sp_type = bm.dftype() if self.trainable else bool + self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape) def dV(self, V, t, w, I_ext): - dVdt = (- V + self.V_rest + self.delta_T * bm.exp((V - self.V_T) / self.delta_T) - - self.R * w + self.R * I_ext) / self.tau + exp = self.delta_T * bm.exp((V - self.V_T) / self.delta_T) + dVdt = (- V + self.V_rest + exp - self.R * w + self.R * I_ext) / self.tau return dVdt def dw(self, w, t, V): @@ -463,8 +634,10 @@ def dw(self, w, t, V): def derivative(self): return JointEq([self.dV, self.dw]) - def update(self, t, dt): - V, w = self.integral(self.V, self.w, t, self.input, dt=dt) + def update(self, tdi, x=None): + t, dt = tdi.t, tdi.dt + if x is not None: self.input += x + V, w = self.integral(self.V.value, self.w.value, t, self.input.value, dt) spike = V >= self.V_th self.V.value = bm.where(spike, self.V_reset, V) self.w.value = bm.where(spike, w + self.b, w) @@ -549,37 +722,43 @@ def __init__( c: Union[float, Tensor, Initializer, Callable] = .07, R: Union[float, Tensor, Initializer, Callable] = 1., tau: Union[float, Tensor, Initializer, Callable] = 10., - tau_ref: Union[float, Tensor, Initializer, Callable] = 0., + tau_ref: Union[float, Tensor, Initializer, Callable] = None, V_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(), noise: Union[float, Tensor, Initializer, Callable] = None, keep_size: bool = False, + trainable: bool = False, method: str = 'exp_auto', name: str = None ): # initialization - super(QuaIF, self).__init__(size=size, keep_size=keep_size, name=name) + super(QuaIF, self).__init__(size=size, + keep_size=keep_size, + name=name, + trainable=trainable) # parameters - self.V_rest = init_param(V_rest, self.var_shape, allow_none=False) - self.V_reset = init_param(V_reset, self.var_shape, allow_none=False) - self.V_th = init_param(V_th, self.var_shape, allow_none=False) - self.V_c = init_param(V_c, self.var_shape, allow_none=False) - self.c = init_param(c, self.var_shape, allow_none=False) - self.R = init_param(R, self.var_shape, allow_none=False) - self.tau = init_param(tau, self.var_shape, allow_none=False) - self.tau_ref = init_param(tau_ref, self.var_shape, allow_none=False) - self.noise = init_noise(noise, self.var_shape, num_vars=1) + self.V_rest = parameter(V_rest, self.varshape, allow_none=False) + self.V_reset = parameter(V_reset, self.varshape, allow_none=False) + self.V_th = parameter(V_th, self.varshape, allow_none=False) + self.V_c = parameter(V_c, self.varshape, allow_none=False) + self.c = parameter(c, self.varshape, allow_none=False) + self.R = parameter(R, self.varshape, allow_none=False) + self.tau = parameter(tau, self.varshape, allow_none=False) + self.tau_ref = parameter(tau_ref, self.varshape, allow_none=True) + self.noise = init_noise(noise, self.varshape, num_vars=1) # initializers check_initializer(V_initializer, '_V_initializer', allow_none=False) self._V_initializer = V_initializer # variables - self.V = bm.Variable(init_param(V_initializer, self.var_shape)) - self.input = bm.Variable(bm.zeros(self.var_shape)) - self.spike = bm.Variable(bm.zeros(self.var_shape, dtype=bool)) - self.refractory = bm.Variable(bm.zeros(self.var_shape, dtype=bool)) - self.t_last_spike = bm.Variable(bm.ones(self.var_shape) * -1e7) + self.V = variable(V_initializer, trainable, self.varshape) + self.input = variable(bm.zeros, trainable, self.varshape) + sp_type = bm.dftype() if self.trainable else bool + self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), trainable, self.varshape) + self.t_last_spike = variable(lambda s: bm.ones(s) * -1e7, trainable, self.varshape) + if self.tau_ref is not None: + self.refractory = variable(lambda s: bm.zeros(s, dtype=bool), trainable, self.varshape) # integral if self.noise is None: @@ -587,26 +766,37 @@ def __init__( else: self.integral = sdeint(method=method, f=self.derivative, g=self.noise) - def reset(self): - self.V.value = init_param(self._V_initializer, self.var_shape) - self.input[:] = 0 - self.spike[:] = False - self.t_last_spike[:] = -1e7 - self.refractory[:] = False + def reset_state(self, batch_size=None): + self.V.value = variable(self._V_initializer, batch_size, self.varshape) + self.input.value = variable(bm.zeros, batch_size, self.varshape) + sp_type = bm.dftype() if self.trainable else bool + self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape) + self.t_last_spike.value = variable(lambda s: bm.ones(s) * -1e7, batch_size, self.varshape) + if self.tau_ref is not None: + self.refractory.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape) def derivative(self, V, t, I_ext): dVdt = (self.c * (V - self.V_rest) * (V - self.V_c) + self.R * I_ext) / self.tau return dVdt - def update(self, t, dt, **kwargs): - refractory = (t - self.t_last_spike) <= self.tau_ref - V = self.integral(self.V, t, self.input, dt=dt) - V = bm.where(refractory, self.V, V) - spike = self.V_th <= V - self.t_last_spike.value = bm.where(spike, t, self.t_last_spike) - self.V.value = bm.where(spike, self.V_reset, V) - self.refractory.value = bm.logical_or(refractory, spike) + def update(self, tdi, x=None): + t, dt = tdi.t, tdi.dt + if x is not None: self.input += x + V = self.integral(self.V.value, t, self.input.value, dt) + if self.tau_ref is not None: + refractory = (t - self.t_last_spike) <= self.tau_ref + V = bm.where(refractory, self.V, V) + spike = self.V_th <= V + t_last_spike = bm.where(spike, t, self.t_last_spike) + V = bm.where(spike, self.V_reset, V) + self.refractory.value = bm.logical_or(refractory, spike) + else: + spike = self.V_th <= V + t_last_spike = bm.where(spike, t, self.t_last_spike) + V = bm.where(spike, self.V_reset, V) + self.V.value = V self.spike.value = spike + self.t_last_spike.value = t_last_spike self.input[:] = 0. @@ -704,21 +894,25 @@ def __init__( noise: Union[float, Tensor, Initializer, Callable] = None, method: str = 'exp_auto', keep_size: bool = False, + trainable: bool = False, name: str = None ): - super(AdQuaIF, self).__init__(size=size, keep_size=keep_size, name=name) + super(AdQuaIF, self).__init__(size=size, + keep_size=keep_size, + name=name, + trainable=trainable, ) # parameters - self.V_rest = init_param(V_rest, self.var_shape, allow_none=False) - self.V_reset = init_param(V_reset, self.var_shape, allow_none=False) - self.V_th = init_param(V_th, self.var_shape, allow_none=False) - self.V_c = init_param(V_c, self.var_shape, allow_none=False) - self.c = init_param(c, self.var_shape, allow_none=False) - self.a = init_param(a, self.var_shape, allow_none=False) - self.b = init_param(b, self.var_shape, allow_none=False) - self.tau = init_param(tau, self.var_shape, allow_none=False) - self.tau_w = init_param(tau_w, self.var_shape, allow_none=False) - self.noise = init_noise(noise, self.var_shape, num_vars=2) + self.V_rest = parameter(V_rest, self.varshape, allow_none=False) + self.V_reset = parameter(V_reset, self.varshape, allow_none=False) + self.V_th = parameter(V_th, self.varshape, allow_none=False) + self.V_c = parameter(V_c, self.varshape, allow_none=False) + self.c = parameter(c, self.varshape, allow_none=False) + self.a = parameter(a, self.varshape, allow_none=False) + self.b = parameter(b, self.varshape, allow_none=False) + self.tau = parameter(tau, self.varshape, allow_none=False) + self.tau_w = parameter(tau_w, self.varshape, allow_none=False) + self.noise = init_noise(noise, self.varshape, num_vars=2) # initializers check_initializer(V_initializer, 'V_initializer', allow_none=False) @@ -727,11 +921,12 @@ def __init__( self._w_initializer = w_initializer # variables - self.V = bm.Variable(init_param(V_initializer, self.var_shape)) - self.w = bm.Variable(init_param(w_initializer, self.var_shape)) - self.input = bm.Variable(bm.zeros(self.var_shape)) - self.spike = bm.Variable(bm.zeros(self.var_shape, dtype=bool)) - self.refractory = bm.Variable(bm.zeros(self.var_shape, dtype=bool)) + self.V = variable(V_initializer, trainable, self.varshape) + self.w = variable(w_initializer, trainable, self.varshape) + self.input = variable(bm.zeros, trainable, self.varshape) + sp_type = bm.dftype() if self.trainable else bool + self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), trainable, self.varshape) + self.refractory = variable(lambda s: bm.zeros(s, dtype=bool), trainable, self.varshape) # integral if self.noise is None: @@ -739,12 +934,13 @@ def __init__( else: self.integral = sdeint(method=method, f=self.derivative, g=self.noise) - def reset(self): - self.V.value = init_param(self._V_initializer, self.var_shape) - self.w.value = init_param(self._w_initializer, self.var_shape) - self.input[:] = 0 - self.spike[:] = False - self.refractory[:] = False + def reset_state(self, batch_size=None): + self.V.value = variable(self._V_initializer, batch_size, self.varshape) + self.w.value = variable(self._w_initializer, batch_size, self.varshape) + self.input.value = variable(bm.zeros, batch_size, self.varshape) + sp_type = bm.dftype() if self.trainable else bool + self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape) + self.refractory.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape) def dV(self, V, t, w, I_ext): dVdt = (self.c * (V - self.V_rest) * (V - self.V_c) - w + I_ext) / self.tau @@ -758,8 +954,10 @@ def dw(self, w, t, V): def derivative(self): return JointEq([self.dV, self.dw]) - def update(self, t, dt): - V, w = self.integral(self.V, self.w, t, self.input, dt=dt) + def update(self, tdi, x=None): + t, dt = tdi.t, tdi.dt + if x is not None: self.input += x + V, w = self.integral(self.V.value, self.w.value, t, self.input.value, dt) spike = self.V_th <= V self.V.value = bm.where(spike, self.V_reset, V) self.w.value = bm.where(spike, w + self.b, w) @@ -873,27 +1071,35 @@ def __init__( noise: Union[float, Tensor, Initializer, Callable] = None, method: str = 'exp_auto', keep_size: bool = False, - name: str = None + name: str = None, + + # parameter for training + trainable: bool = False, + spike_fun: Callable = bm.spike_with_sigmoid_grad, ): # initialization - super(GIF, self).__init__(size=size, keep_size=keep_size, name=name) + super(GIF, self).__init__(size=size, + keep_size=keep_size, + name=name, + trainable=trainable) # params - self.V_rest = init_param(V_rest, self.var_shape, allow_none=False) - self.V_reset = init_param(V_reset, self.var_shape, allow_none=False) - self.V_th_inf = init_param(V_th_inf, self.var_shape, allow_none=False) - self.V_th_reset = init_param(V_th_reset, self.var_shape, allow_none=False) - self.R = init_param(R, self.var_shape, allow_none=False) - self.tau = init_param(tau, self.var_shape, allow_none=False) - self.a = init_param(a, self.var_shape, allow_none=False) - self.b = init_param(b, self.var_shape, allow_none=False) - self.k1 = init_param(k1, self.var_shape, allow_none=False) - self.k2 = init_param(k2, self.var_shape, allow_none=False) - self.R1 = init_param(R1, self.var_shape, allow_none=False) - self.R2 = init_param(R2, self.var_shape, allow_none=False) - self.A1 = init_param(A1, self.var_shape, allow_none=False) - self.A2 = init_param(A2, self.var_shape, allow_none=False) - self.noise = init_noise(noise, self.var_shape, num_vars=4) + self.V_rest = parameter(V_rest, self.varshape, allow_none=False) + self.V_reset = parameter(V_reset, self.varshape, allow_none=False) + self.V_th_inf = parameter(V_th_inf, self.varshape, allow_none=False) + self.V_th_reset = parameter(V_th_reset, self.varshape, allow_none=False) + self.R = parameter(R, self.varshape, allow_none=False) + self.tau = parameter(tau, self.varshape, allow_none=False) + self.a = parameter(a, self.varshape, allow_none=False) + self.b = parameter(b, self.varshape, allow_none=False) + self.k1 = parameter(k1, self.varshape, allow_none=False) + self.k2 = parameter(k2, self.varshape, allow_none=False) + self.R1 = parameter(R1, self.varshape, allow_none=False) + self.R2 = parameter(R2, self.varshape, allow_none=False) + self.A1 = parameter(A1, self.varshape, allow_none=False) + self.A2 = parameter(A2, self.varshape, allow_none=False) + self.noise = init_noise(noise, self.varshape, num_vars=4) + self.spike_fun = check_callable(spike_fun, 'spike_fun') # initializers check_initializer(V_initializer, 'V_initializer') @@ -906,12 +1112,13 @@ def __init__( self._Vth_initializer = Vth_initializer # variables - self.I1 = bm.Variable(init_param(I1_initializer, self.var_shape)) - self.I2 = bm.Variable(init_param(I2_initializer, self.var_shape)) - self.V = bm.Variable(init_param(V_initializer, self.var_shape)) - self.V_th = bm.Variable(init_param(Vth_initializer, self.var_shape)) - self.input = bm.Variable(bm.zeros(self.var_shape)) - self.spike = bm.Variable(bm.zeros(self.var_shape, dtype=bool)) + self.I1 = variable(I1_initializer, trainable, self.varshape) + self.I2 = variable(I2_initializer, trainable, self.varshape) + self.V_th = variable(Vth_initializer, trainable, self.varshape) + self.V = variable(V_initializer, trainable, self.varshape) + self.input = variable(bm.zeros, trainable, self.varshape) + sp_type = bm.dftype() if self.trainable else bool + self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), trainable, self.varshape) # integral if self.noise is None: @@ -919,13 +1126,14 @@ def __init__( else: self.integral = sdeint(method=method, f=self.derivative, g=self.noise) - def reset(self): - self.V.value = init_param(self._V_initializer, self.var_shape) - self.I1.value = init_param(self._I1_initializer, self.var_shape) - self.I2.value = init_param(self._I2_initializer, self.var_shape) - self.V_th.value = init_param(self._Vth_initializer, self.var_shape) - self.input[:] = 0 - self.spike[:] = False + def reset_state(self, batch_size=None): + self.I1.value = variable(self._I1_initializer, batch_size, self.varshape) + self.I2.value = variable(self._I2_initializer, batch_size, self.varshape) + self.V_th.value = variable(self._Vth_initializer, batch_size, self.varshape) + self.V.value = variable(self._V_initializer, batch_size, self.varshape) + self.input.value = variable(bm.zeros, batch_size, self.varshape) + sp_type = bm.dftype() if self.trainable else bool + self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape) def dI1(self, I1, t): return - self.k1 * I1 @@ -943,19 +1151,198 @@ def dV(self, V, t, I1, I2, I_ext): def derivative(self): return JointEq([self.dI1, self.dI2, self.dVth, self.dV]) - def update(self, t, dt): + def update(self, tdi, x=None): + t, dt = tdi.t, tdi.dt + + # integral + if x is not None: self.input += x I1, I2, V_th, V = self.integral(self.I1, self.I2, self.V_th, self.V, t, self.input, dt=dt) - spike = self.V_th <= V - V = bm.where(spike, self.V_reset, V) - I1 = bm.where(spike, self.R1 * I1 + self.A1, I1) - I2 = bm.where(spike, self.R2 * I2 + self.A2, I2) - reset_th = bm.logical_and(V_th < self.V_th_reset, spike) - V_th = bm.where(reset_th, self.V_th_reset, V_th) + + # spike and resets + if self.trainable: + spike = self.spike_fun(V - self.V_th) + V += (self.V_reset - V) * spike + I1 += spike * (self.R1 * I1 + self.A1 - I1) + I2 += spike * (self.R2 * I2 + self.A2 - I2) + reset_th = self.spike_fun(self.V_th_reset - V_th) * spike + V_th += reset_th * (self.V_th_reset - V_th) + else: + spike = self.V_th <= V + V = bm.where(spike, self.V_reset, V) + I1 = bm.where(spike, self.R1 * I1 + self.A1, I1) + I2 = bm.where(spike, self.R2 * I2 + self.A2, I2) + reset_th = bm.logical_and(V_th < self.V_th_reset, spike) + V_th = bm.where(reset_th, self.V_th_reset, V_th) self.spike.value = spike self.I1.value = I1 self.I2.value = I2 self.V_th.value = V_th self.V.value = V + + # reset input + self.input[:] = 0. + + +class ALIFBellec2020(NeuGroup): + r"""Leaky Integrate-and-Fire model with SFA [1]_. + + This model is similar to the GLIF2 model in the Technical White Paper + on generalized LIF (GLIF) models from AllenInstitute [2]_. + + Formally, this model is given by: + + .. math:: + + \tau \dot{V} = -(V - V_{\mathrm{rest}}) + R*I \\ + \tau_a \dot{a} = -a + + Once a spike is induced by :math:`V(t) > V_{\mathrm{th}} + \beta a`, then + + .. math:: + + V \gets V - V_{\mathrm{th}} \\ + a \gets a + 1 + + + References + ---------- + .. [1] Bellec, Guillaume, et al. "A solution to the learning dilemma for + recurrent networks of spiking neurons." + Nature communications 11.1 (2020): 1-15. + .. [2] Allen Institute: Cell Types Database. © 2018 Allen Institute for + Brain Science. Allen Cell Types Database, cell feature search. + Available from: celltypes.brain-map.org/data (2018). + """ + def __init__( + self, + size: Shape, + keep_size: bool = False, + + # model parameters + V_rest: Union[float, Tensor, Initializer, Callable] = -70., + V_th: Union[float, Tensor, Initializer, Callable] = -60., + R: Union[float, Tensor, Initializer, Callable] = 1., + beta: Union[float, Tensor, Initializer, Callable] = 1.6, + tau: Union[float, Tensor, Initializer, Callable] = 20., + tau_a: Union[float, Tensor, Initializer, Callable] = 2000., + tau_ref: Union[float, Tensor, Initializer, Callable] = None, + noise: Union[float, Tensor, Initializer, Callable] = None, + + # initializers + V_initializer: Union[Initializer, Callable, Tensor] = OneInit(-70.), + a_initializer: Union[Initializer, Callable, Tensor] = OneInit(-50.), + + # parameter for training + trainable: bool = False, + spike_fun: Callable = bm.spike_with_relu_grad, + + # other parameters + method: str = 'exp_auto', + name: str = None, + ): + super(ALIFBellec2020, self).__init__(name=name, + size=size, + keep_size=keep_size, + trainable=trainable) + + # parameters + self.V_rest = parameter(V_rest, self.varshape, allow_none=False) + self.V_th_reset = parameter(V_th, self.varshape, allow_none=False) + self.R = parameter(R, self.varshape, allow_none=False) + self.beta = parameter(beta, self.varshape, allow_none=False) + self.tau = parameter(tau, self.varshape, allow_none=False) + self.tau_a = parameter(tau_a, self.varshape, allow_none=False) + self.tau_ref = parameter(tau_ref, self.varshape, allow_none=True) + self.noise = init_noise(noise, self.varshape, num_vars=2) + self.spike_fun = check_callable(spike_fun, 'spike_fun') + + # initializers + check_initializer(V_initializer, 'V_initializer') + check_initializer(a_initializer, 'a_initializer') + self._V_initializer = V_initializer + self._a_initializer = a_initializer + + # variables + self.a = variable(a_initializer, trainable, self.varshape) + self.V = variable(V_initializer, trainable, self.varshape) + self.input = variable(bm.zeros, trainable, self.varshape) + sp_type = bm.dftype() if self.trainable else bool + self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), trainable, self.varshape) + if self.tau_ref is not None: + self.t_last_spike = variable(lambda s: bm.ones(s) * -1e7, trainable, self.varshape) + self.refractory = variable(lambda s: bm.zeros(s, dtype=bool), trainable, self.varshape) + + # integral + if self.noise is None: + self.integral = odeint(method=method, f=self.derivative) + else: + self.integral = sdeint(method=method, f=self.derivative, g=self.noise) + + def dVth(self, a, t): + return -a / self.tau_a + + def dV(self, V, t, I_ext): + return (- (V - self.V_rest) + self.R * I_ext) / self.tau + + @property + def derivative(self): + return JointEq([self.dV, self.dVth]) + + def reset_state(self, batch_size=None): + self.a.value = variable(self._a_initializer, batch_size, self.varshape) + self.V.value = variable(self._V_initializer, batch_size, self.varshape) + self.input.value = variable(bm.zeros, batch_size, self.varshape) + sp_type = bm.dftype() if self.trainable else bool + self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape) + if self.tau_ref is not None: + self.t_last_spike.value = variable(lambda s: bm.ones(s) * -1e7, batch_size, self.varshape) + self.refractory.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape) + + def update(self, tdi, x=None): + t, dt = tdi.t, tdi.dt + + # integral + if x is not None: self.input += x + V, a = self.integral(self.V, self.a, t, self.input, dt) + + if self.tau_ref is not None: + # refractory + refractory = (t - self.t_last_spike) <= self.tau_ref + if self.trainable: + refractory = stop_gradient(refractory) + V = bm.where(refractory, self.V, V) + # spike and reset + if self.trainable: + spike = self.spike_fun(V - self.V_th_reset - self.beta * self.a) + spike_no_grad = stop_gradient(spike) + V -= self.V_th_reset * spike_no_grad + spike_ = spike_no_grad > 0. + # will be used in other place, like Delta Synapse, so stop its gradient + refractory = stop_gradient(bm.logical_or(refractory, spike_).value) + t_last_spike = stop_gradient(bm.where(spike_, t, self.t_last_spike).value) + else: + spike = V >= (self.V_th_reset + self.beta * self.a) + refractory = bm.logical_or(refractory, spike) + t_last_spike = bm.where(spike, t, self.t_last_spike) + V -= self.V_th_reset * spike + a += spike + self.refractory.value = refractory + self.t_last_spike.value = t_last_spike + + else: + # spike and reset + if self.trainable: + spike = self.spike_fun(V - self.V_th_reset - self.beta * self.a) + V -= self.V_th_reset * stop_gradient(spike) + else: + spike = V >= (self.V_th_reset + self.beta * self.a) + V -= self.V_th_reset * spike + a += spike + self.spike.value = spike + self.V.value = V + self.a.value = a + + # reset input self.input[:] = 0. @@ -1035,25 +1422,31 @@ def __init__( c: Union[float, Tensor, Initializer, Callable] = -65., d: Union[float, Tensor, Initializer, Callable] = 8., V_th: Union[float, Tensor, Initializer, Callable] = 30., - tau_ref: Union[float, Tensor, Initializer, Callable] = 0., + tau_ref: Union[float, Tensor, Initializer, Callable] = None, V_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(), u_initializer: Union[Initializer, Callable, Tensor] = OneInit(), noise: Union[float, Tensor, Initializer, Callable] = None, method: str = 'exp_auto', + trainable: bool = False, + spike_fun: Callable = bm.spike_with_sigmoid_grad, keep_size: bool = False, name: str = None ): # initialization - super(Izhikevich, self).__init__(size=size, keep_size=keep_size, name=name) + super(Izhikevich, self).__init__(size=size, + keep_size=keep_size, + name=name, + trainable=trainable) # params - self.a = init_param(a, self.var_shape, allow_none=False) - self.b = init_param(b, self.var_shape, allow_none=False) - self.c = init_param(c, self.var_shape, allow_none=False) - self.d = init_param(d, self.var_shape, allow_none=False) - self.V_th = init_param(V_th, self.var_shape, allow_none=False) - self.tau_ref = init_param(tau_ref, self.var_shape, allow_none=False) - self.noise = init_noise(noise, self.var_shape, num_vars=2) + self.a = parameter(a, self.varshape, allow_none=False) + self.b = parameter(b, self.varshape, allow_none=False) + self.c = parameter(c, self.varshape, allow_none=False) + self.d = parameter(d, self.varshape, allow_none=False) + self.V_th = parameter(V_th, self.varshape, allow_none=False) + self.tau_ref = parameter(tau_ref, self.varshape, allow_none=True) + self.noise = init_noise(noise, self.varshape, num_vars=2) + self.spike_fun = check_callable(spike_fun, 'spike_fun') # initializers check_initializer(V_initializer, 'V_initializer', allow_none=False) @@ -1062,12 +1455,14 @@ def __init__( self._u_initializer = u_initializer # variables - self.u = bm.Variable(init_param(u_initializer, self.var_shape)) - self.V = bm.Variable(init_param(V_initializer, self.var_shape)) - self.input = bm.Variable(bm.zeros(self.var_shape)) - self.spike = bm.Variable(bm.zeros(self.var_shape, dtype=bool)) - self.refractory = bm.Variable(bm.zeros(self.var_shape, dtype=bool)) - self.t_last_spike = bm.Variable(bm.ones(self.var_shape) * -1e7) + self.u = variable(u_initializer, trainable, self.varshape) + self.V = variable(V_initializer, trainable, self.varshape) + self.input = variable(bm.zeros, trainable, self.varshape) + sp_type = bm.dftype() if self.trainable else bool + self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), trainable, self.varshape) + if self.tau_ref is not None: + self.t_last_spike = variable(lambda s: bm.ones(s) * -1e7, trainable, self.varshape) + self.refractory = variable(lambda s: bm.zeros(s, dtype=bool), trainable, self.varshape) # functions if self.noise is None: @@ -1075,13 +1470,15 @@ def __init__( else: self.integral = sdeint(method=method, f=JointEq([self.dV, self.du]), g=self.noise) - def reset(self): - self.V.value = init_param(self._V_initializer, self.var_shape) - self.u.value = init_param(self._u_initializer, self.var_shape) - self.input[:] = 0 - self.spike[:] = False - self.refractory[:] = False - self.t_last_spike[:] = -1e7 + def reset_state(self, batch_size=None): + self.V.value = variable(self._V_initializer, batch_size, self.varshape) + self.u.value = variable(self._u_initializer, batch_size, self.varshape) + self.input.value = variable(bm.zeros, batch_size, self.varshape) + sp_type = bm.dftype() if self.trainable else bool + self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape) + if self.tau_ref is not None: + self.t_last_spike.value = variable(lambda s: bm.ones(s) * -1e7, batch_size, self.varshape) + self.refractory.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape) def dV(self, V, t, u, I_ext): dVdt = 0.04 * V * V + 5 * V + 140 - u + I_ext @@ -1091,15 +1488,52 @@ def du(self, u, t, V): dudt = self.a * (self.b * V - u) return dudt - def update(self, t, dt): - V, u = self.integral(self.V, self.u, t, self.input, dt=dt) - refractory = (t - self.t_last_spike) <= self.tau_ref - V = bm.where(refractory, self.V, V) - spike = self.V_th <= V - self.t_last_spike.value = bm.where(spike, t, self.t_last_spike) - self.V.value = bm.where(spike, self.c, V) - self.u.value = bm.where(spike, u + self.d, u) - self.refractory.value = bm.logical_or(refractory, spike) + def update(self, tdi, x=None): + t, dt = tdi.t, tdi.dt + + # integrate membrane potential + if x is not None: self.input += x + V, u = self.integral(self.V, self.u, t, self.input, dt) + + if self.tau_ref is not None: + refractory = (t - self.t_last_spike) <= self.tau_ref + if self.trainable: + refractory = stop_gradient(refractory) + V = bm.where(refractory, self.V, V) + + # spike, refractory, and reset membrane potential + if self.trainable: + spike = self.spike_fun(V - self.V_th) + spike_no_grad = stop_gradient(spike) + V += spike_no_grad * (self.c - self.V_th) + u += spike_no_grad * self.d + spike_ = spike_no_grad > 0. + refractory = stop_gradient(bm.logical_or(refractory, spike_).value) + t_last_spike = stop_gradient(bm.where(spike_, t, self.t_last_spike).value) + else: + spike = self.V_th <= V + V = bm.where(spike, self.c, V) + u = bm.where(spike, u + self.d, u) + refractory = bm.logical_or(refractory, spike) + t_last_spike = bm.where(spike, t, self.t_last_spike) + self.refractory.value = refractory + self.t_last_spike.value = t_last_spike + + else: + # spike, refractory, and reset membrane potential + if self.trainable: + spike = self.spike_fun(V - self.V_th) + spike_no_grad = stop_gradient(spike) + V += spike_no_grad * (self.c - self.V_th) + u += spike_no_grad * self.d + else: + spike = self.V_th <= V + V = bm.where(spike, self.c, V) + u = bm.where(spike, u + self.d, u) + + # finally + self.V.value = V + self.u.value = u self.spike.value = spike self.input[:] = 0. @@ -1219,21 +1653,29 @@ def __init__( noise: Union[float, Tensor, Initializer, Callable] = None, method: str = 'exp_auto', keep_size: bool = False, - name: str = None + name: str = None, + + # parameters for training + trainable: bool = False, + spike_fun: Callable = bm.spike2_with_sigmoid_grad, ): # initialization - super(HindmarshRose, self).__init__(size=size, keep_size=keep_size, name=name) + super(HindmarshRose, self).__init__(size=size, + keep_size=keep_size, + name=name, + trainable=trainable) # parameters - self.a = init_param(a, self.var_shape, allow_none=False) - self.b = init_param(b, self.var_shape, allow_none=False) - self.c = init_param(c, self.var_shape, allow_none=False) - self.d = init_param(d, self.var_shape, allow_none=False) - self.r = init_param(r, self.var_shape, allow_none=False) - self.s = init_param(s, self.var_shape, allow_none=False) - self.V_th = init_param(V_th, self.var_shape, allow_none=False) - self.V_rest = init_param(V_rest, self.var_shape, allow_none=False) - self.noise = init_noise(noise, self.var_shape, num_vars=3) + self.a = parameter(a, self.varshape, allow_none=False) + self.b = parameter(b, self.varshape, allow_none=False) + self.c = parameter(c, self.varshape, allow_none=False) + self.d = parameter(d, self.varshape, allow_none=False) + self.r = parameter(r, self.varshape, allow_none=False) + self.s = parameter(s, self.varshape, allow_none=False) + self.V_th = parameter(V_th, self.varshape, allow_none=False) + self.V_rest = parameter(V_rest, self.varshape, allow_none=False) + self.noise = init_noise(noise, self.varshape, num_vars=3) + self.spike_fun = check_callable(spike_fun, 'spike_fun') # variables check_initializer(V_initializer, 'V_initializer', allow_none=False) @@ -1244,11 +1686,12 @@ def __init__( self._z_initializer = z_initializer # variables - self.z = bm.Variable(init_param(V_initializer, self.var_shape)) - self.y = bm.Variable(init_param(y_initializer, self.var_shape)) - self.V = bm.Variable(init_param(z_initializer, self.var_shape)) - self.input = bm.Variable(bm.zeros(self.var_shape)) - self.spike = bm.Variable(bm.zeros(self.var_shape, dtype=bool)) + self.V = variable(self._V_initializer, trainable, self.varshape) + self.y = variable(self._y_initializer, trainable, self.varshape) + self.z = variable(self._z_initializer, trainable, self.varshape) + self.input = variable(bm.zeros, trainable, self.varshape) + sp_type = bm.dftype() if self.trainable else bool + self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), trainable, self.varshape) # integral if self.noise is None: @@ -1256,12 +1699,13 @@ def __init__( else: self.integral = sdeint(method=method, f=self.derivative, g=self.noise) - def reset(self): - self.V.value = init_param(self._V_initializer, self.var_shape) - self.y.value = init_param(self._y_initializer, self.var_shape) - self.z.value = init_param(self._z_initializer, self.var_shape) - self.input[:] = 0 - self.spike[:] = False + def reset_state(self, batch_size=None): + self.V.value = variable(self._V_initializer, batch_size, self.varshape) + self.y.value = variable(self._y_initializer, batch_size, self.varshape) + self.z.value = variable(self._z_initializer, batch_size, self.varshape) + self.input.value = variable(bm.zeros, batch_size, self.varshape) + sp_type = bm.dftype() if self.trainable else bool + self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape) def dV(self, V, t, y, z, I_ext): return y - self.a * V * V * V + self.b * V * V - z + I_ext @@ -1276,9 +1720,14 @@ def dz(self, z, t, V): def derivative(self): return JointEq([self.dV, self.dy, self.dz]) - def update(self, t, dt): + def update(self, tdi, x=None): + t, dt = tdi.t, tdi.dt + if x is not None: self.input += x V, y, z = self.integral(self.V, self.y, self.z, t, self.input, dt=dt) - self.spike.value = bm.logical_and(V >= self.V_th, self.V < self.V_th) + if self.trainable: + self.spike.value = self.spike_fun(V - self.V_th, self.V - self.V_th) + else: + self.spike.value = bm.logical_and(V >= self.V_th, self.V < self.V_th) self.V.value = V self.y.value = y self.z.value = z @@ -1380,17 +1829,25 @@ def __init__( noise: Union[float, Tensor, Initializer, Callable] = None, method: str = 'exp_auto', keep_size: bool = False, - name: str = None + name: str = None, + + # parameters for training + trainable: bool = False, + spike_fun: Callable = bm.spike2_with_sigmoid_grad, ): # initialization - super(FHN, self).__init__(size=size, keep_size=keep_size, name=name) + super(FHN, self).__init__(size=size, + keep_size=keep_size, + name=name, + trainable=trainable) # parameters - self.a = init_param(a, self.var_shape, allow_none=False) - self.b = init_param(b, self.var_shape, allow_none=False) - self.tau = init_param(tau, self.var_shape, allow_none=False) - self.Vth = init_param(Vth, self.var_shape, allow_none=False) - self.noise = init_noise(noise, self.var_shape, num_vars=2) + self.a = parameter(a, self.varshape, allow_none=False) + self.b = parameter(b, self.varshape, allow_none=False) + self.tau = parameter(tau, self.varshape, allow_none=False) + self.Vth = parameter(Vth, self.varshape, allow_none=False) + self.noise = init_noise(noise, self.varshape, num_vars=2) + self.spike_fun = check_callable(spike_fun, 'spike_fun') # initializers check_initializer(V_initializer, 'V_initializer') @@ -1399,10 +1856,11 @@ def __init__( self._w_initializer = w_initializer # variables - self.w = bm.Variable(init_param(w_initializer, self.var_shape)) - self.V = bm.Variable(init_param(V_initializer, self.var_shape)) - self.input = bm.Variable(bm.zeros(self.var_shape)) - self.spike = bm.Variable(bm.zeros(self.var_shape, dtype=bool)) + self.V = variable(self._V_initializer, trainable, self.varshape) + self.w = variable(self._w_initializer, trainable, self.varshape) + self.input = variable(bm.zeros, trainable, self.varshape) + sp_type = bm.dftype() if self.trainable else bool + self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), trainable, self.varshape) # integral if self.noise is None: @@ -1410,11 +1868,12 @@ def __init__( else: self.integral = sdeint(method=method, f=self.derivative, g=self.noise) - def reset(self): - self.V.value = init_param(self._V_initializer, self.var_shape) - self.w.value = init_param(self._w_initializer, self.var_shape) - self.input[:] = 0 - self.spike[:] = False + def reset_state(self, batch_size=None): + self.V.value = variable(self._V_initializer, batch_size, self.varshape) + self.w.value = variable(self._w_initializer, batch_size, self.varshape) + self.input.value = variable(bm.zeros, batch_size, self.varshape) + sp_type = bm.dftype() if self.trainable else bool + self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape) def dV(self, V, t, w, I_ext): return V - V * V * V / 3 - w + I_ext @@ -1426,9 +1885,14 @@ def dw(self, w, t, V): def derivative(self): return JointEq([self.dV, self.dw]) - def update(self, t, dt): - V, w = self.integral(self.V, self.w, t, self.input, dt=dt) - self.spike.value = bm.logical_and(V >= self.Vth, self.V < self.Vth) + def update(self, tdi, x=None): + t, dt = tdi.t, tdi.dt + if x is not None: self.input += x + V, w = self.integral(self.V.value, self.w.value, t, self.input.value, dt=dt) + if self.trainable: + self.spike.value = self.spike_fun(V - self.Vth, self.V - self.Vth) + else: + self.spike.value = bm.logical_and(V >= self.Vth, self.V < self.Vth) self.V.value = V self.w.value = w self.input[:] = 0. diff --git a/brainpy/dyn/neurons/tests/test_reduced_models.py b/brainpy/dyn/neurons/tests/test_reduced_models.py index bd2bcba56..b420165c9 100644 --- a/brainpy/dyn/neurons/tests/test_reduced_models.py +++ b/brainpy/dyn/neurons/tests/test_reduced_models.py @@ -1,7 +1,6 @@ # -*- coding: utf-8 -*- -import unittest import brainpy as bp from absl.testing import parameterized from brainpy.dyn.neurons import reduced_models diff --git a/brainpy/dyn/rates/__init__.py b/brainpy/dyn/rates/__init__.py index f860bceee..0dec414f6 100644 --- a/brainpy/dyn/rates/__init__.py +++ b/brainpy/dyn/rates/__init__.py @@ -1,5 +1,4 @@ # -*- coding: utf-8 -*- from .populations import * -from .couplings import * diff --git a/brainpy/dyn/rates/populations.py b/brainpy/dyn/rates/populations.py index f509d52c8..e53b303fa 100644 --- a/brainpy/dyn/rates/populations.py +++ b/brainpy/dyn/rates/populations.py @@ -2,11 +2,10 @@ from typing import Union, Callable -import brainpy.math as bm -from brainpy import check +from brainpy import check, math as bm from brainpy.dyn.base import NeuGroup from brainpy.dyn.neurons.noise_groups import OUProcess -from brainpy.initialize import Initializer, Uniform, init_param, ZeroInit +from brainpy.initialize import Initializer, Uniform, parameter, variable, ZeroInit from brainpy.integrators.joint_eq import JointEq from brainpy.integrators.ode import odeint from brainpy.tools.checking import check_float, check_initializer @@ -14,7 +13,7 @@ from brainpy.types import Shape, Tensor __all__ = [ - 'Population', + 'RateModel', 'FHN', 'FeedbackFHN', 'QIF', @@ -24,12 +23,11 @@ ] -class Population(NeuGroup): - def update(self, t, dt): - raise NotImplementedError +class RateModel(NeuGroup): + pass -class FHN(NeuGroup): +class FHN(RateModel): r"""FitzHugh-Nagumo system used in [1]_. .. math:: @@ -90,26 +88,32 @@ def __init__( y_initializer: Union[Initializer, Callable, Tensor] = Uniform(0, 0.05), method: str = 'exp_auto', name: str = None, + + # parameter for training + trainable: bool = False, ): - super(FHN, self).__init__(size=size, name=name, keep_size=keep_size) + super(FHN, self).__init__(size=size, + name=name, + keep_size=keep_size, + trainable=trainable) # model parameters - self.alpha = init_param(alpha, self.var_shape, allow_none=False) - self.beta = init_param(beta, self.var_shape, allow_none=False) - self.gamma = init_param(gamma, self.var_shape, allow_none=False) - self.delta = init_param(delta, self.var_shape, allow_none=False) - self.epsilon = init_param(epsilon, self.var_shape, allow_none=False) - self.tau = init_param(tau, self.var_shape, allow_none=False) + self.alpha = parameter(alpha, self.varshape, allow_none=False) + self.beta = parameter(beta, self.varshape, allow_none=False) + self.gamma = parameter(gamma, self.varshape, allow_none=False) + self.delta = parameter(delta, self.varshape, allow_none=False) + self.epsilon = parameter(epsilon, self.varshape, allow_none=False) + self.tau = parameter(tau, self.varshape, allow_none=False) # noise parameters - self.x_ou_mean = init_param(x_ou_mean, self.var_shape, allow_none=False) # mV/ms, OU process - self.y_ou_mean = init_param(y_ou_mean, self.var_shape, allow_none=False) # mV/ms, OU process - self.x_ou_sigma = init_param(x_ou_sigma, self.var_shape, allow_none=False) # mV/ms/sqrt(ms), noise intensity - self.y_ou_sigma = init_param(y_ou_sigma, self.var_shape, allow_none=False) # mV/ms/sqrt(ms), noise intensity - self.x_ou_tau = init_param(x_ou_tau, self.var_shape, - allow_none=False) # ms, timescale of the Ornstein-Uhlenbeck noise process - self.y_ou_tau = init_param(y_ou_tau, self.var_shape, - allow_none=False) # ms, timescale of the Ornstein-Uhlenbeck noise process + self.x_ou_mean = parameter(x_ou_mean, self.varshape, allow_none=False) # mV/ms, OU process + self.y_ou_mean = parameter(y_ou_mean, self.varshape, allow_none=False) # mV/ms, OU process + self.x_ou_sigma = parameter(x_ou_sigma, self.varshape, allow_none=False) # mV/ms/sqrt(ms), noise intensity + self.y_ou_sigma = parameter(y_ou_sigma, self.varshape, allow_none=False) # mV/ms/sqrt(ms), noise intensity + self.x_ou_tau = parameter(x_ou_tau, self.varshape, + allow_none=False) # ms, timescale of the Ornstein-Uhlenbeck noise process + self.y_ou_tau = parameter(y_ou_tau, self.varshape, + allow_none=False) # ms, timescale of the Ornstein-Uhlenbeck noise process # initializers check_initializer(x_initializer, 'x_initializer') @@ -118,32 +122,38 @@ def __init__( self._y_initializer = y_initializer # variables - self.x = bm.Variable(init_param(x_initializer, self.var_shape)) - self.y = bm.Variable(init_param(y_initializer, self.var_shape)) - self.input = bm.Variable(bm.zeros(self.var_shape)) + self.x = variable(x_initializer, trainable, self.varshape) + self.y = variable(y_initializer, trainable, self.varshape) + self.input = variable(bm.zeros, trainable, self.varshape) + self.input_y = variable(bm.zeros, trainable, self.varshape) # noise variables self.x_ou = self.y_ou = None if bm.any(self.x_ou_mean > 0.) or bm.any(self.x_ou_sigma > 0.): - self.x_ou = OUProcess(self.var_shape, - self.x_ou_mean, self.x_ou_sigma, self.x_ou_tau, + self.x_ou = OUProcess(self.varshape, + self.x_ou_mean, + self.x_ou_sigma, + self.x_ou_tau, method=method) if bm.any(self.y_ou_mean > 0.) or bm.any(self.y_ou_sigma > 0.): - self.y_ou = OUProcess(self.var_shape, - self.y_ou_mean, self.y_ou_sigma, self.y_ou_tau, + self.y_ou = OUProcess(self.varshape, + self.y_ou_mean, + self.y_ou_sigma, + self.y_ou_tau, method=method) # integral functions self.integral = odeint(f=JointEq([self.dx, self.dy]), method=method) - def reset(self): - self.x.value = init_param(self._x_initializer, self.var_shape) - self.y.value = init_param(self._y_initializer, self.var_shape) - self.input[:] = 0 + def reset_state(self, batch_size=None): + self.x.value = variable(self._x_initializer, batch_size, self.varshape) + self.y.value = variable(self._y_initializer, batch_size, self.varshape) + self.input.value = variable(bm.zeros, batch_size, self.varshape) + self.input_y.value = variable(bm.zeros, batch_size, self.varshape) if self.x_ou is not None: - self.x_ou.reset() + self.x_ou.reset_state(batch_size) if self.y_ou is not None: - self.y_ou.reset() + self.y_ou.reset_state(batch_size) def dx(self, x, t, y, x_ext): return - self.alpha * x ** 3 + self.beta * x ** 2 + self.gamma * x - y + x_ext @@ -151,21 +161,27 @@ def dx(self, x, t, y, x_ext): def dy(self, y, t, x, y_ext=0.): return (x - self.delta - self.epsilon * y) / self.tau + y_ext - def update(self, t, dt): + def update(self, tdi, x=None): + t, dt = tdi['t'], tdi['dt'] + + # input + if x is not None: self.input += x if self.x_ou is not None: self.input += self.x_ou.x - self.x_ou.update(t, dt) - y_ext = 0. + self.x_ou.update(tdi) if self.y_ou is not None: - y_ext = self.y_ou.x - self.y_ou.update(t, dt) - x, y = self.integral(self.x, self.y, t, x_ext=self.input, y_ext=y_ext, dt=dt) + self.input_y += self.y_ou.x + self.y_ou.update(tdi) + + # integral + x, y = self.integral(self.x, self.y, t, x_ext=self.input, y_ext=self.input_y, dt=dt) self.x.value = x self.y.value = y self.input[:] = 0. + self.input_y[:] = 0. -class FeedbackFHN(NeuGroup): +class FeedbackFHN(RateModel): r"""FitzHugh-Nagumo model with recurrent neural feedback. The equation of the feedback FitzHugh-Nagumo model [4]_ is given by @@ -218,8 +234,6 @@ class FeedbackFHN(NeuGroup): y_ou_tau: Parameter The timescale of the Ornstein-Uhlenbeck noise process of :math:`y` variable, [ms]. - - References ---------- .. [4] Plant, Richard E. (1981). *A FitzHugh Differential-Difference @@ -252,32 +266,37 @@ def __init__( # other parameters x_initializer: Union[Initializer, Callable, Tensor] = Uniform(0, 0.05), y_initializer: Union[Initializer, Callable, Tensor] = Uniform(0, 0.05), - method: str = 'rk4', - sde_method: str = None, + method: str = 'exp_auto', name: str = None, - dt: float = None + dt: float = None, + + # parameter for training + training: bool = False, ): - super(FeedbackFHN, self).__init__(size=size, name=name, keep_size=keep_size) + super(FeedbackFHN, self).__init__(size=size, + name=name, + keep_size=keep_size, + trainable=training) # dt self.dt = bm.get_dt() if dt is None else dt check_float(self.dt, 'dt', allow_none=False, min_bound=0., allow_int=False) # parameters - self.a = init_param(a, self.var_shape, allow_none=False) - self.b = init_param(b, self.var_shape, allow_none=False) - self.delay = init_param(delay, self.var_shape, allow_none=False) - self.tau = init_param(tau, self.var_shape, allow_none=False) - self.mu = init_param(mu, self.var_shape, allow_none=False) # feedback strength - self.v0 = init_param(v0, self.var_shape, allow_none=False) # resting potential + self.a = parameter(a, self.varshape, allow_none=False) + self.b = parameter(b, self.varshape, allow_none=False) + self.delay = parameter(delay, self.varshape, allow_none=False) + self.tau = parameter(tau, self.varshape, allow_none=False) + self.mu = parameter(mu, self.varshape, allow_none=False) # feedback strength + self.v0 = parameter(v0, self.varshape, allow_none=False) # resting potential # noise parameters - self.x_ou_mean = init_param(x_ou_mean, self.var_shape, allow_none=False) - self.y_ou_mean = init_param(y_ou_mean, self.var_shape, allow_none=False) - self.x_ou_sigma = init_param(x_ou_sigma, self.var_shape, allow_none=False) - self.y_ou_sigma = init_param(y_ou_sigma, self.var_shape, allow_none=False) - self.x_ou_tau = init_param(x_ou_tau, self.var_shape, allow_none=False) - self.y_ou_tau = init_param(y_ou_tau, self.var_shape, allow_none=False) + self.x_ou_mean = parameter(x_ou_mean, self.varshape, allow_none=False) + self.y_ou_mean = parameter(y_ou_mean, self.varshape, allow_none=False) + self.x_ou_sigma = parameter(x_ou_sigma, self.varshape, allow_none=False) + self.y_ou_sigma = parameter(y_ou_sigma, self.varshape, allow_none=False) + self.x_ou_tau = parameter(x_ou_tau, self.varshape, allow_none=False) + self.y_ou_tau = parameter(y_ou_tau, self.varshape, allow_none=False) # initializers check_initializer(x_initializer, 'x_initializer') @@ -286,36 +305,42 @@ def __init__( self._y_initializer = y_initializer # variables - self.x = bm.Variable(init_param(x_initializer, self.var_shape)) - self.y = bm.Variable(init_param(y_initializer, self.var_shape)) + self.x = variable(x_initializer, training, self.varshape) + self.y = variable(y_initializer, training, self.varshape) self.x_delay = bm.TimeDelay(self.x, self.delay, dt=self.dt, interp_method='round') - self.input = bm.Variable(bm.zeros(self.var_shape)) + self.input = variable(bm.zeros, training, self.varshape) + self.input_y = variable(bm.zeros, training, self.varshape) # noise variables self.x_ou = self.y_ou = None if bm.any(self.x_ou_mean > 0.) or bm.any(self.x_ou_sigma > 0.): - self.x_ou = OUProcess(self.var_shape, - self.x_ou_mean, self.x_ou_sigma, self.x_ou_tau, - method=sde_method) + self.x_ou = OUProcess(self.varshape, + self.x_ou_mean, + self.x_ou_sigma, + self.x_ou_tau, + method=method) if bm.any(self.y_ou_mean > 0.) or bm.any(self.y_ou_sigma > 0.): - self.y_ou = OUProcess(self.var_shape, - self.y_ou_mean, self.y_ou_sigma, self.y_ou_tau, - method=sde_method) + self.y_ou = OUProcess(self.varshape, + self.y_ou_mean, + self.y_ou_sigma, + self.y_ou_tau, + method=method) # integral self.integral = odeint(method=method, f=JointEq([self.dx, self.dy]), state_delays={'V': self.x_delay}) - def reset(self): - self.x.value = init_param(self._x_initializer, self.var_shape) - self.y.value = init_param(self._y_initializer, self.var_shape) + def reset_state(self, batch_size=None): + self.x.value = variable(self._x_initializer, batch_size, self.varshape) + self.y.value = variable(self._y_initializer, batch_size, self.varshape) self.x_delay.reset(self.x, self.delay) - self.input[:] = 0 + self.input = variable(bm.zeros, batch_size, self.varshape) + self.input_y = variable(bm.zeros, batch_size, self.varshape) if self.x_ou is not None: - self.x_ou.reset() + self.x_ou.reset_state(batch_size) if self.y_ou is not None: - self.y_ou.reset() + self.y_ou.reset_state(batch_size) def dx(self, x, t, y, x_ext): return x - x * x * x / 3 - y + x_ext + self.mu * (self.x_delay(t - self.delay) - self.v0) @@ -328,23 +353,28 @@ def _check_dt(self, dt): f'not consistent with the "dt" {self.dt} ' f'used in model definition.') - def update(self, t, dt): + def update(self, tdi, x=None): + t = tdi['t'] + dt = tdi['dt'] if check.is_checking(): check_error_in_jit(not bm.isclose(dt, self.dt), self._check_dt, dt) + + if x is not None: self.input += x if self.x_ou is not None: self.input += self.x_ou.x - self.x_ou.update(t, dt) - y_ext = 0. + self.x_ou.update(tdi) if self.y_ou is not None: - y_ext = self.y_ou.x - self.y_ou.update(t, dt) - x, y = self.integral(self.x, self.y, t, x_ext=self.input, y_ext=y_ext, dt=dt) + self.input_y += self.y_ou.x + self.y_ou.update(tdi) + + x, y = self.integral(self.x, self.y, t, x_ext=self.input, y_ext=self.input_y, dt=dt) self.x.value = x self.y.value = y self.input[:] = 0. + self.input_y[:] = 0. -class QIF(NeuGroup): +class QIF(RateModel): r"""A mean-field model of a quadratic integrate-and-fire neuron population. **Model Descriptions** @@ -434,26 +464,31 @@ def __init__( y_initializer: Union[Initializer, Callable, Tensor] = Uniform(0, 0.05), method: str = 'exp_auto', name: str = None, - sde_method: str = None, + + # parameter for training + trainable: bool = False, ): - super(QIF, self).__init__(size=size, name=name, keep_size=keep_size) + super(QIF, self).__init__(size=size, + name=name, + keep_size=keep_size, + trainable=trainable) # parameters - self.tau = init_param(tau, self.var_shape, allow_none=False) + self.tau = parameter(tau, self.varshape, allow_none=False) # the mean of a Lorenzian distribution over the neural excitability in the population - self.eta = init_param(eta, self.var_shape, allow_none=False) + self.eta = parameter(eta, self.varshape, allow_none=False) # the half-width at half maximum of the Lorenzian distribution over the neural excitability - self.delta = init_param(delta, self.var_shape, allow_none=False) + self.delta = parameter(delta, self.varshape, allow_none=False) # the strength of the recurrent coupling inside the population - self.J = init_param(J, self.var_shape, allow_none=False) + self.J = parameter(J, self.varshape, allow_none=False) # noise parameters - self.x_ou_mean = init_param(x_ou_mean, self.var_shape, allow_none=False) - self.y_ou_mean = init_param(y_ou_mean, self.var_shape, allow_none=False) - self.x_ou_sigma = init_param(x_ou_sigma, self.var_shape, allow_none=False) - self.y_ou_sigma = init_param(y_ou_sigma, self.var_shape, allow_none=False) - self.x_ou_tau = init_param(x_ou_tau, self.var_shape, allow_none=False) - self.y_ou_tau = init_param(y_ou_tau, self.var_shape, allow_none=False) + self.x_ou_mean = parameter(x_ou_mean, self.varshape, allow_none=False) + self.y_ou_mean = parameter(y_ou_mean, self.varshape, allow_none=False) + self.x_ou_sigma = parameter(x_ou_sigma, self.varshape, allow_none=False) + self.y_ou_sigma = parameter(y_ou_sigma, self.varshape, allow_none=False) + self.x_ou_tau = parameter(x_ou_tau, self.varshape, allow_none=False) + self.y_ou_tau = parameter(y_ou_tau, self.varshape, allow_none=False) # initializers check_initializer(x_initializer, 'x_initializer') @@ -462,32 +497,38 @@ def __init__( self._y_initializer = y_initializer # variables - self.x = bm.Variable(init_param(x_initializer, self.var_shape)) - self.y = bm.Variable(init_param(y_initializer, self.var_shape)) - self.input = bm.Variable(bm.zeros(self.var_shape)) + self.x = variable(x_initializer, trainable, self.varshape) + self.y = variable(y_initializer, trainable, self.varshape) + self.input = variable(bm.zeros, trainable, self.varshape) + self.input_y = variable(bm.zeros, trainable, self.varshape) # noise variables self.x_ou = self.y_ou = None if bm.any(self.x_ou_mean > 0.) or bm.any(self.x_ou_sigma > 0.): - self.x_ou = OUProcess(self.var_shape, - self.x_ou_mean, self.x_ou_sigma, self.x_ou_tau, - method=sde_method) + self.x_ou = OUProcess(self.varshape, + self.x_ou_mean, + self.x_ou_sigma, + self.x_ou_tau, + method=method) if bm.any(self.y_ou_mean > 0.) or bm.any(self.y_ou_sigma > 0.): - self.y_ou = OUProcess(self.var_shape, - self.y_ou_mean, self.y_ou_sigma, self.y_ou_tau, - method=sde_method) + self.y_ou = OUProcess(self.varshape, + self.y_ou_mean, + self.y_ou_sigma, + self.y_ou_tau, + method=method) # functions self.integral = odeint(JointEq([self.dx, self.dy]), method=method) - def reset(self): - self.x.value = init_param(self._x_initializer, self.var_shape) - self.y.value = init_param(self._y_initializer, self.var_shape) - self.input[:] = 0 + def reset_state(self, batch_size=None): + self.x.value = variable(self._x_initializer, batch_size, self.varshape) + self.y.value = variable(self._y_initializer, batch_size, self.varshape) + self.input.value = variable(bm.zeros, batch_size, self.varshape) + self.input_y.value = variable(bm.zeros, batch_size, self.varshape) if self.x_ou is not None: - self.x_ou.reset() + self.x_ou.reset_state(batch_size) if self.y_ou is not None: - self.y_ou.reset() + self.y_ou.reset_state(batch_size) def dy(self, y, t, x, y_ext): return (self.delta / (bm.pi * self.tau) + 2. * x * y + y_ext) / self.tau @@ -496,21 +537,25 @@ def dx(self, x, t, y, x_ext): return (x ** 2 + self.eta + x_ext + self.J * y * self.tau - (bm.pi * y * self.tau) ** 2) / self.tau - def update(self, t, dt): + def update(self, tdi, x=None): + t, dt = tdi['t'], tdi['dt'] + + if x is not None: self.input += x if self.x_ou is not None: self.input += self.x_ou.x - self.x_ou.update(t, dt) - y_ext = 0. + self.x_ou.update(tdi) if self.y_ou is not None: - y_ext = self.y_ou.x - self.y_ou.update(t, dt) - x, y = self.integral(self.x, self.y, t=t, x_ext=self.input, y_ext=y_ext, dt=dt) + self.input_y += self.y_ou.x + self.y_ou.update(tdi) + + x, y = self.integral(self.x, self.y, t=t, x_ext=self.input, y_ext=self.input_y, dt=dt) self.x.value = x self.y.value = y self.input[:] = 0. + self.input_y[:] = 0. -class StuartLandauOscillator(Population): +class StuartLandauOscillator(RateModel): r""" Stuart-Landau model with Hopf bifurcation. @@ -557,24 +602,27 @@ def __init__( x_initializer: Union[Initializer, Callable, Tensor] = Uniform(0, 0.5), y_initializer: Union[Initializer, Callable, Tensor] = Uniform(0, 0.5), method: str = 'exp_auto', - sde_method: str = None, name: str = None, + + # parameter for training + trainable: bool = False, ): super(StuartLandauOscillator, self).__init__(size=size, name=name, - keep_size=keep_size) + keep_size=keep_size, + trainable=trainable) # model parameters - self.a = init_param(a, self.var_shape, allow_none=False) - self.w = init_param(w, self.var_shape, allow_none=False) + self.a = parameter(a, self.varshape, allow_none=False) + self.w = parameter(w, self.varshape, allow_none=False) # noise parameters - self.x_ou_mean = init_param(x_ou_mean, self.var_shape, allow_none=False) - self.y_ou_mean = init_param(y_ou_mean, self.var_shape, allow_none=False) - self.x_ou_sigma = init_param(x_ou_sigma, self.var_shape, allow_none=False) - self.y_ou_sigma = init_param(y_ou_sigma, self.var_shape, allow_none=False) - self.x_ou_tau = init_param(x_ou_tau, self.var_shape, allow_none=False) - self.y_ou_tau = init_param(y_ou_tau, self.var_shape, allow_none=False) + self.x_ou_mean = parameter(x_ou_mean, self.varshape, allow_none=False) + self.y_ou_mean = parameter(y_ou_mean, self.varshape, allow_none=False) + self.x_ou_sigma = parameter(x_ou_sigma, self.varshape, allow_none=False) + self.y_ou_sigma = parameter(y_ou_sigma, self.varshape, allow_none=False) + self.x_ou_tau = parameter(x_ou_tau, self.varshape, allow_none=False) + self.y_ou_tau = parameter(y_ou_tau, self.varshape, allow_none=False) # initializers check_initializer(x_initializer, 'x_initializer') @@ -583,32 +631,38 @@ def __init__( self._y_initializer = y_initializer # variables - self.x = bm.Variable(init_param(x_initializer, self.var_shape)) - self.y = bm.Variable(init_param(y_initializer, self.var_shape)) - self.input = bm.Variable(bm.zeros(self.var_shape)) + self.x = variable(x_initializer, trainable, self.varshape) + self.y = variable(y_initializer, trainable, self.varshape) + self.input = variable(bm.zeros, trainable, self.varshape) + self.input_y = variable(bm.zeros, trainable, self.varshape) # noise variables self.x_ou = self.y_ou = None if bm.any(self.x_ou_mean > 0.) or bm.any(self.x_ou_sigma > 0.): - self.x_ou = OUProcess(self.var_shape, - self.x_ou_mean, self.x_ou_sigma, self.x_ou_tau, - method=sde_method) + self.x_ou = OUProcess(self.varshape, + self.x_ou_mean, + self.x_ou_sigma, + self.x_ou_tau, + method=method) if bm.any(self.y_ou_mean > 0.) or bm.any(self.y_ou_sigma > 0.): - self.y_ou = OUProcess(self.var_shape, - self.y_ou_mean, self.y_ou_sigma, self.y_ou_tau, - method=sde_method) + self.y_ou = OUProcess(self.varshape, + self.y_ou_mean, + self.y_ou_sigma, + self.y_ou_tau, + method=method) # integral functions self.integral = odeint(f=JointEq([self.dx, self.dy]), method=method) - def reset(self): - self.x.value = init_param(self._x_initializer, self.var_shape) - self.y.value = init_param(self._y_initializer, self.var_shape) - self.input[:] = 0 + def reset_state(self, batch_size=None): + self.x.value = variable(self._x_initializer, batch_size, self.varshape) + self.y.value = variable(self._y_initializer, batch_size, self.varshape) + self.input.value = variable(bm.zeros, batch_size, self.varshape) + self.input_y.value = variable(bm.zeros, batch_size, self.varshape) if self.x_ou is not None: - self.x_ou.reset() + self.x_ou.reset_state(batch_size) if self.y_ou is not None: - self.y_ou.reset() + self.y_ou.reset_state(batch_size) def dx(self, x, t, y, x_ext, a, w): return (a - x * x - y * y) * x - w * y + x_ext @@ -616,22 +670,32 @@ def dx(self, x, t, y, x_ext, a, w): def dy(self, y, t, x, y_ext, a, w): return (a - x * x - y * y) * y - w * y + y_ext - def update(self, t, dt): + def update(self, tdi, x=None): + t, dt = tdi['t'], tdi['dt'] + + if x is not None: self.input += x if self.x_ou is not None: self.input += self.x_ou.x - self.x_ou.update(t, dt) - y_ext = 0. + self.x_ou.update(tdi) if self.y_ou is not None: - y_ext = self.y_ou.x - self.y_ou.update(t, dt) - x, y = self.integral(self.x, self.y, t, x_ext=self.input, - y_ext=y_ext, a=self.a, w=self.w, dt=dt) + self.input_y += self.y_ou.x + self.y_ou.update(tdi) + + x, y = self.integral(self.x, + self.y, + t=t, + x_ext=self.input, + y_ext=self.input_y, + a=self.a, + w=self.w, + dt=dt) self.x.value = x self.y.value = y self.input[:] = 0. + self.input_y[:] = 0. -class WilsonCowanModel(Population): +class WilsonCowanModel(RateModel): """Wilson-Cowan population model. @@ -690,32 +754,34 @@ def __init__( y_initializer: Union[Initializer, Callable, Tensor] = Uniform(max_val=0.05), # other parameters - sde_method: str = None, method: str = 'exp_euler_auto', name: str = None, + + # parameter for training + trainable: bool = False, ): super(WilsonCowanModel, self).__init__(size=size, name=name, keep_size=keep_size) # model parameters - self.E_a = init_param(E_a, self.var_shape, allow_none=False) - self.I_a = init_param(I_a, self.var_shape, allow_none=False) - self.E_tau = init_param(E_tau, self.var_shape, allow_none=False) - self.I_tau = init_param(I_tau, self.var_shape, allow_none=False) - self.E_theta = init_param(E_theta, self.var_shape, allow_none=False) - self.I_theta = init_param(I_theta, self.var_shape, allow_none=False) - self.wEE = init_param(wEE, self.var_shape, allow_none=False) - self.wIE = init_param(wIE, self.var_shape, allow_none=False) - self.wEI = init_param(wEI, self.var_shape, allow_none=False) - self.wII = init_param(wII, self.var_shape, allow_none=False) - self.r = init_param(r, self.var_shape, allow_none=False) + self.E_a = parameter(E_a, self.varshape, allow_none=False) + self.I_a = parameter(I_a, self.varshape, allow_none=False) + self.E_tau = parameter(E_tau, self.varshape, allow_none=False) + self.I_tau = parameter(I_tau, self.varshape, allow_none=False) + self.E_theta = parameter(E_theta, self.varshape, allow_none=False) + self.I_theta = parameter(I_theta, self.varshape, allow_none=False) + self.wEE = parameter(wEE, self.varshape, allow_none=False) + self.wIE = parameter(wIE, self.varshape, allow_none=False) + self.wEI = parameter(wEI, self.varshape, allow_none=False) + self.wII = parameter(wII, self.varshape, allow_none=False) + self.r = parameter(r, self.varshape, allow_none=False) # noise parameters - self.x_ou_mean = init_param(x_ou_mean, self.var_shape, allow_none=False) - self.y_ou_mean = init_param(y_ou_mean, self.var_shape, allow_none=False) - self.x_ou_sigma = init_param(x_ou_sigma, self.var_shape, allow_none=False) - self.y_ou_sigma = init_param(y_ou_sigma, self.var_shape, allow_none=False) - self.x_ou_tau = init_param(x_ou_tau, self.var_shape, allow_none=False) - self.y_ou_tau = init_param(y_ou_tau, self.var_shape, allow_none=False) + self.x_ou_mean = parameter(x_ou_mean, self.varshape, allow_none=False) + self.y_ou_mean = parameter(y_ou_mean, self.varshape, allow_none=False) + self.x_ou_sigma = parameter(x_ou_sigma, self.varshape, allow_none=False) + self.y_ou_sigma = parameter(y_ou_sigma, self.varshape, allow_none=False) + self.x_ou_tau = parameter(x_ou_tau, self.varshape, allow_none=False) + self.y_ou_tau = parameter(y_ou_tau, self.varshape, allow_none=False) # initializers check_initializer(x_initializer, 'x_initializer') @@ -724,32 +790,38 @@ def __init__( self._y_initializer = y_initializer # variables - self.x = bm.Variable(init_param(x_initializer, self.var_shape)) - self.y = bm.Variable(init_param(y_initializer, self.var_shape)) - self.input = bm.Variable(bm.zeros(self.var_shape)) + self.x = variable(x_initializer, trainable, self.varshape) + self.y = variable(y_initializer, trainable, self.varshape) + self.input = variable(bm.zeros, trainable, self.varshape) + self.input_y = variable(bm.zeros, trainable, self.varshape) # noise variables self.x_ou = self.y_ou = None if bm.any(self.x_ou_mean > 0.) or bm.any(self.x_ou_sigma > 0.): - self.x_ou = OUProcess(self.var_shape, - self.x_ou_mean, self.x_ou_sigma, self.x_ou_tau, - method=sde_method) + self.x_ou = OUProcess(self.varshape, + self.x_ou_mean, + self.x_ou_sigma, + self.x_ou_tau, + method=method) if bm.any(self.y_ou_mean > 0.) or bm.any(self.y_ou_sigma > 0.): - self.y_ou = OUProcess(self.var_shape, - self.y_ou_mean, self.y_ou_sigma, self.y_ou_tau, - method=sde_method) + self.y_ou = OUProcess(self.varshape, + self.y_ou_mean, + self.y_ou_sigma, + self.y_ou_tau, + method=method) # functions self.integral = odeint(f=JointEq([self.dx, self.dy]), method=method) - def reset(self): - self.x.value = init_param(self._x_initializer, self.var_shape) - self.y.value = init_param(self._y_initializer, self.var_shape) - self.input[:] = 0 + def reset_state(self, batch_size=None): + self.x.value = variable(self._x_initializer, batch_size, self.varshape) + self.y.value = variable(self._y_initializer, batch_size, self.varshape) + self.input.value = variable(bm.zeros, batch_size, self.varshape) + self.input_y.value = variable(bm.zeros, batch_size, self.varshape) if self.x_ou is not None: - self.x_ou.reset() + self.x_ou.reset_state(batch_size) if self.y_ou is not None: - self.y_ou.reset() + self.y_ou.reset_state(batch_size) def F(self, x, a, theta): return 1 / (1 + bm.exp(-a * (x - theta))) - 1 / (1 + bm.exp(a * theta)) @@ -762,41 +834,43 @@ def dy(self, y, t, x, y_ext): x = self.wEI * x - self.wII * y + y_ext return (-y + (1 - self.r * y) * self.F(x, self.I_a, self.I_theta)) / self.I_tau - def update(self, t, dt): + def update(self, tdi, x=None): + t, dt = tdi['t'], tdi['dt'] + if x is not None: self.input += x if self.x_ou is not None: self.input += self.x_ou.x - self.x_ou.update(t, dt) - y_ext = 0. + self.x_ou.update(tdi) if self.y_ou is not None: - y_ext = self.y_ou.x - self.y_ou.update(t, dt) - x, y = self.integral(self.x, self.y, t, x_ext=self.input, y_ext=y_ext, dt=dt) + self.input_y += self.y_ou.x + self.y_ou.update(tdi) + x, y = self.integral(self.x, self.y, t, x_ext=self.input, y_ext=self.input_y, dt=dt) self.x.value = x self.y.value = y self.input[:] = 0. + self.input_y[:] = 0. -class JansenRitModel(Population): +class JansenRitModel(RateModel): pass -class KuramotoOscillator(Population): +class KuramotoOscillator(RateModel): pass -class ThetaNeuron(Population): +class ThetaNeuron(RateModel): pass -class RateQIFWithSFA(Population): +class RateQIFWithSFA(RateModel): pass -class VanDerPolOscillator(Population): +class VanDerPolOscillator(RateModel): pass -class ThresholdLinearModel(Population): +class ThresholdLinearModel(RateModel): r"""A threshold linear rate model. The threshold linear rate model is given by [1]_ @@ -835,46 +909,60 @@ def __init__( i_initializer: Union[Tensor, Callable, Initializer] = ZeroInit(), seed: int = None, keep_size: bool = False, - name: str = None + name: str = None, + + # parameter for training + trainable: bool = False, ): - super(ThresholdLinearModel, self).__init__(size, name=name) + super(ThresholdLinearModel, self).__init__(size, + name=name, + keep_size=keep_size, + trainable=trainable) # parameters self.seed = seed - self.tau_e = init_param(tau_e, self.var_shape, False) - self.tau_i = init_param(tau_i, self.var_shape, False) - self.beta_e = init_param(beta_e, self.var_shape, False) - self.beta_i = init_param(beta_i, self.var_shape, False) - self.noise_e = init_param(noise_e, self.var_shape, False) - self.noise_i = init_param(noise_i, self.var_shape, False) + self.tau_e = parameter(tau_e, self.varshape, False) + self.tau_i = parameter(tau_i, self.varshape, False) + self.beta_e = parameter(beta_e, self.varshape, False) + self.beta_i = parameter(beta_i, self.varshape, False) + self.noise_e = parameter(noise_e, self.varshape, False) + self.noise_i = parameter(noise_i, self.varshape, False) self._e_initializer = e_initializer self._i_initializer = i_initializer # variables - self.e = bm.Variable(init_param(e_initializer, self.var_shape)) # Firing rate of excitatory population - self.i = bm.Variable(init_param(i_initializer, self.var_shape)) # Firing rate of inhibitory population - self.Ie = bm.Variable(bm.zeros(self.var_shape)) # Input of excitaory population - self.Ii = bm.Variable(bm.zeros(self.var_shape)) # Input of inhibitory population + self.e = variable(e_initializer, trainable, self.varshape) # Firing rate of excitatory population + self.i = variable(i_initializer, trainable, self.varshape) # Firing rate of inhibitory population + self.Ie = variable(bm.zeros, trainable, self.varshape) # Input of excitaory population + self.Ii = variable(bm.zeros, trainable, self.varshape) # Input of inhibitory population if bm.any(self.noise_e != 0) or bm.any(self.noise_i != 0): self.rng = bm.random.RandomState(self.seed) - def reset(self): + def reset(self, batch_size=None): self.rng.seed(self.seed) - self.e.value = init_param(self._e_initializer, self.var_shape) - self.i.value = init_param(self._i_initializer, self.var_shape) - self.Ie[:] = 0. - self.Ii[:] = 0. + self.reset_state(batch_size) - def update(self, t, dt): + def reset_state(self, batch_size=None): + self.e.value = variable(self._e_initializer, batch_size, self.varshape) + self.i.value = variable(self._i_initializer, batch_size, self.varshape) + self.Ie.value = variable(bm.zeros, batch_size, self.varshape) + self.Ii.value = variable(bm.zeros, batch_size, self.varshape) + + def update(self, tdi, x=None): + t, dt = tdi['t'], tdi['dt'] + + if x is not None: self.Ie += x de = -self.e + self.beta_e * bm.maximum(self.Ie, 0.) if bm.any(self.noise_e != 0.): - de += self.rng.randn(self.var_shape) * self.noise_e + de += self.rng.randn(self.varshape) * self.noise_e de = de / self.tau_e self.e.value = bm.maximum(self.e + de * dt, 0.) + di = -self.i + self.beta_i * bm.maximum(self.Ii, 0.) if bm.any(self.noise_i != 0.): - di += self.rng.randn(self.var_shape) * self.noise_i + di += self.rng.randn(self.varshape) * self.noise_i di = di / self.tau_i self.i.value = bm.maximum(self.i + di * dt, 0.) + self.Ie[:] = 0. self.Ii[:] = 0. diff --git a/brainpy/dyn/runners.py b/brainpy/dyn/runners.py index e6d7a21ed..0bc01702a 100644 --- a/brainpy/dyn/runners.py +++ b/brainpy/dyn/runners.py @@ -12,15 +12,15 @@ from jax.tree_util import tree_map, tree_flatten from brainpy import math as bm -from brainpy.base.collector import TensorCollector from brainpy.dyn.base import DynamicalSystem from brainpy.errors import RunningError from brainpy.running.runner import Runner -from brainpy.types import Tensor -from .utils import serialize_kwargs, check_data_batch_size +from brainpy.tools.checking import check_float, serialize_kwargs +from brainpy.tools.others.dicts import DotDict +from brainpy.types import Tensor, Output, Monitor __all__ = [ - 'DSRunner', 'ReportRunner', + 'DSRunner', ] SUPPORTED_INPUT_OPS = ['-', '+', '*', '/', '='] @@ -100,7 +100,7 @@ def check_and_format_inputs(host, inputs): else: raise RunningError(f'For each input, input[0] must be a string to ' f'specify variable of the target, but we got {key}.') - inputs_which_found_target.append((real_target, ) + tuple(one_input[1:])) + inputs_which_found_target.append((real_target,) + tuple(one_input[1:])) # checking 2: relative access # Check whether the input target node is accessible @@ -116,7 +116,7 @@ def check_and_format_inputs(host, inputs): if not hasattr(real_target, key): raise RunningError(f'Input target key "{key}" is not defined in {real_target}.') real_target = getattr(real_target, key) - inputs_which_found_target.append((real_target, ) + tuple(one_input[1:])) + inputs_which_found_target.append((real_target,) + tuple(one_input[1:])) # 3. format inputs # --------- @@ -162,66 +162,60 @@ def check_and_format_inputs(host, inputs): def build_inputs(inputs): - fix_inputs = {'=': [], '+': [], '-': [], '*': [], '/': []} - next_inputs = {'=': [], '+': [], '-': [], '*': [], '/': []} - func_inputs = {'=': [], '+': [], '-': [], '*': [], '/': []} - array_inputs = {'=': [], '+': [], '-': [], '*': [], '/': []} - - _has_iter_array = False - for variable, value, type_, op in inputs: - # variable - if not isinstance(variable, bm.Variable): - raise RunningError(f'{variable}\n is not a dynamically changed Variable, ' - f'its value will not change, we think there is no need to ' - f'give its input.') - - # input data - if type_ == 'iter': - if isinstance(value, (bm.ndarray, np.ndarray, jnp.ndarray)): - array_inputs[op].append([variable, bm.asarray(value)]) - _has_iter_array = True - else: - next_inputs[op].append([variable, iter(value)]) - elif type_ == 'func': - func_inputs[op].append([variable, value]) + fix_inputs = {'=': [], '+': [], '-': [], '*': [], '/': []} + next_inputs = {'=': [], '+': [], '-': [], '*': [], '/': []} + func_inputs = {'=': [], '+': [], '-': [], '*': [], '/': []} + array_inputs = {'=': [], '+': [], '-': [], '*': [], '/': []} + + _has_iter_array = False + for variable, value, type_, op in inputs: + # variable + if not isinstance(variable, bm.Variable): + raise RunningError(f'{variable}\n is not a dynamically changed Variable, ' + f'its value will not change, we think there is no need to ' + f'give its input.') + + # input data + if type_ == 'iter': + if isinstance(value, (bm.ndarray, np.ndarray, jnp.ndarray)): + array_inputs[op].append([variable, bm.asarray(value)]) + _has_iter_array = True else: - fix_inputs[op].append([variable, value]) - - index = None - if _has_iter_array: - index = bm.Variable(bm.zeros(1, dtype=int)) - - def _f_ops(ops, var, data): - if ops == '=': - var[:] = data - elif ops == '+': - var += data - elif ops == '-': - var -= data - elif ops == '*': - var *= data - elif ops == '/': - var /= data - else: - raise ValueError - - def func(_t, _dt): - for ops, values in fix_inputs.items(): - for var, data in values: - _f_ops(ops, var, data) - for ops, values in array_inputs.items(): - for var, data in values: - _f_ops(ops, var, data[index[0]]) - for ops, values in func_inputs.items(): - for var, data in values: - _f_ops(ops, var, data(_t, _dt)) - for ops, values in next_inputs.items(): - for var, data in values: - _f_ops(ops, var, next(data)) - if _has_iter_array: - index[0] += 1 - - return func, index + next_inputs[op].append([variable, iter(value)]) + elif type_ == 'func': + func_inputs[op].append([variable, value]) + else: + fix_inputs[op].append([variable, value]) + + def _f_ops(ops, var, data): + if ops == '=': + var[:] = data + elif ops == '+': + var += data + elif ops == '-': + var -= data + elif ops == '*': + var *= data + elif ops == '/': + var /= data + else: + raise ValueError(f'Unknown input operation: {ops}') + + def func(tdi): + for ops, values in fix_inputs.items(): + for var, data in values: + _f_ops(ops, var, data) + for ops, values in array_inputs.items(): + for var, data in values: + _f_ops(ops, var, data[tdi['i']]) + for ops, values in func_inputs.items(): + for var, data in values: + _f_ops(ops, var, data(tdi['t'], tdi['dt'])) + for ops, values in next_inputs.items(): + for var, data in values: + _f_ops(ops, var, next(data)) + + return func, _has_iter_array class DSRunner(Runner): @@ -254,6 +248,7 @@ def __init__( target: DynamicalSystem, inputs: Sequence = (), dt: float = None, + t0: Union[float, int] = 0., **kwargs ): if not isinstance(target, DynamicalSystem): @@ -261,6 +256,11 @@ def __init__( f'but we got {type(target)}: {target}') super(DSRunner, self).__init__(target=target, **kwargs) + # t0 and i0 + self._t0 = t0 + self.i0 = 0 + self.t0 = check_float(t0, 't0', allow_none=False, allow_int=True) + # parameters dt = bm.get_dt() if dt is None else dt if not isinstance(dt, (int, float)): @@ -268,47 +268,39 @@ def __init__( self.dt = dt # Build the monitor function - self._monitor_step = self.build_monitors(*self.format_monitors()) + self._mon_info = self.format_monitors() + # self._monitor_step = self.build_monitors(*self.format_monitors()) # Build input function inputs = check_and_format_inputs(host=target, inputs=inputs) - self._input_step, self._i = build_inputs(inputs) - - # start simulation time - self._start_t = None - - # JAX does not support iterator in fori_loop, scan, etc. - # https://github.com/google/jax/issues/3567 - # We use Variable i to index the current input data. - if self._i is not None: # must behind of "self.build_input()" - self.dyn_vars.update({'_i': self._i}) + self._input_step, _ = build_inputs(inputs) # run function - self._predict_func = dict() - - def build_monitors(self, return_without_idx, return_with_idx, flatten=False): - if flatten: - def func(_t, _dt): - res = {k: (v.flatten() if bm.ndim(v) > 1 else v.value) for k, v in return_without_idx.items()} - res.update({k: (v.flatten()[idx] if bm.ndim(v) > 1 else v[idx]) for k, (v, idx) in return_with_idx.items()}) - res.update({k: f(_t, _dt) for k, f in self.fun_monitors.items()}) - return res - else: - def func(_t, _dt): - res = {k: v.value for k, v in return_without_idx.items()} - res.update({k: v[idx] for k, (v, idx) in return_with_idx.items()}) - res.update({k: f(_t, _dt) for k, f in self.fun_monitors.items()}) - return res + self._f_predict_compiled = dict() + + def build_monitors(self, return_without_idx, return_with_idx, shared_args: dict): + def func(tdi): + res = {k: v.value for k, v in return_without_idx.items()} + res.update({k: v[idx] for k, (v, idx) in return_with_idx.items()}) + res.update({k: f(tdi) for k, f in self.fun_monitors.items()}) + return res return func + def reset_state(self): + self.i0 = 0 + self.t0 = check_float(self._t0, 't0', allow_none=False, allow_int=True) + def predict( self, - xs: Union[Tensor, Dict[str, Tensor]], + duration: Union[float, int] = None, + inputs: Union[Tensor, Sequence[Tensor], Dict[str, Tensor]] = None, + inputs_are_batching: bool = False, reset_state: bool = False, shared_args: Dict = None, progress_bar: bool = True, - ): + eval_time: bool = False + ) -> Output: """Predict a series of input data with the given target model. This function use the JIT compilation to accelerate the model simulation. @@ -317,62 +309,92 @@ def predict( Parameters ---------- - xs: Tensor, dict - The feedforward input data. It must be a 3-dimensional data - which has the shape of `(num_sample, num_time, num_feature)`. + duration: int, float + The simulation time length. + inputs: Tensor, dict of Tensor, sequence of Tensor + The input data. If ``inputs_are_batching=True``, ``inputs`` must be a + PyTree of data with two dimensions: `(num_sample, num_time, ...)`. + Otherwise, the ``inputs`` should be a PyTree of data with one dimension: + `(num_time, ...)`. + inputs_are_batching: bool + Whether the ``inputs`` are batching. If `True`, the batching axis is the + first dimension. reset_state: bool Whether reset the model states. shared_args: optional, dict The shared arguments across different layers. progress_bar: bool Whether report the progress of the simulation using progress bar. + eval_time: bool + Whether ro evaluate the running time. Returns ------- - output: Tensor, dict + output: Tensor, dict, sequence The model output. """ - # format input data - xs, num_step, num_batch = self._check_xs(xs) - times = jax.device_put(jnp.linspace(0., self.dt * (num_step - 1), num_step)) - xs = (times, xs,) - # reset the model states + + # shared arguments + if shared_args is None: shared_args = dict() + shared_args['fit'] = shared_args.get('fit', False) + + # times and inputs + times, indices, xs, num_step, num_batch, duration, description = self._format_xs( + duration, inputs, inputs_are_batching) + + # reset the states of the model and the runner if reset_state: self.target.reset_state(num_batch) - # init monitor + self.reset_state() + indices += self.i0 + times += self.t0 + + # build monitor for key in self.mon.var_names: self.mon[key] = [] # reshape the monitor items + # init progress bar if self.progress_bar and progress_bar: self._pbar = tqdm.auto.tqdm(total=num_step) - self._pbar.set_description(f"Predict {num_step} steps: ", refresh=True) - # prediction - outputs, hists = self._predict(xs=xs, shared_args=shared_args) - outputs = tree_map(lambda x: bm.moveaxis(x, 0, 1), outputs, is_leaf=lambda x: isinstance(x, bm.JaxArray)) - hists = tree_map(lambda x: bm.moveaxis(x, 0, 1), hists, is_leaf=lambda x: isinstance(x, bm.JaxArray)) + self._pbar.set_description(description, refresh=True) + + # running + if eval_time: t0 = time.time() + outputs, hists = self._predict(xs=(times, indices, xs), shared_args=shared_args) + if eval_time: running_time = time.time() - t0 + + # format + if inputs_are_batching: + outputs = tree_map(lambda x: bm.moveaxis(x, 0, 1), outputs, is_leaf=lambda x: isinstance(x, bm.JaxArray)) + hists = tree_map(lambda x: bm.moveaxis(x, 0, 1), hists, is_leaf=lambda x: isinstance(x, bm.JaxArray)) + # close the progress bar if self.progress_bar and progress_bar: self._pbar.close() + # post-running for monitors - for key, val in hists.items(): - self.mon[key] = val + hists['ts'] = times + self.dt if self.numpy_mon_after_run: - self.mon['ts'] = np.asarray(self.mon['ts']) - for key in hists.keys(): - self.mon[key] = np.asarray(self.mon[key]) - return outputs + hists = tree_map(lambda a: np.asarray(a), hists, is_leaf=lambda a: isinstance(a, bm.JaxArray)) + for key in hists.keys(): + self.mon[key] = hists[key] + self.i0 += times.shape[0] + self.t0 += duration + return outputs if not eval_time else (running_time, outputs) def _predict( self, - xs: Sequence[Tensor], + xs: Sequence, shared_args: Dict = None, - ): + ) -> Union[Output, Monitor]: """Predict the output according to the inputs. Parameters ---------- xs: sequence - Each tensor should have the shape of `(num_time, num_batch, num_feature)`. + Must be a tuple/list of data, including `(times, indices, inputs)`. + If `inputs` is not None, it should be a tensor with the shape of + :math:`(num_time, ...)`. shared_args: optional, dict The shared keyword arguments. @@ -381,177 +403,182 @@ def _predict( outputs, hists A tuple of pair of (outputs, hists). """ - _predict_func = self._get_predict_func(shared_args) + _predict_func = self.f_predict(shared_args) outputs, hists = _predict_func(xs) return outputs, hists - def run(self, duration, start_t=None, shared_args: Dict = None, eval_time=False): - """The running function. + def run( + self, + duration: Union[float, int] = None, + inputs: Union[Tensor, Sequence[Tensor], Dict[str, Tensor]] = None, + inputs_are_batching: bool = False, + reset_state: bool = False, + shared_args: Dict = None, + progress_bar: bool = True, + eval_time: bool = False + ) -> Output: + """Predict a series of input data with the given target model. + + This function use the JIT compilation to accelerate the model simulation. + Moreover, it can automatically monitor the node variables, states, inputs, + feedbacks and its output. Parameters ---------- - duration : float, int, tuple, list - The running duration. - start_t : float, optional - The start time. - shared_args: dict - The shared arguments across nodes. + duration: int, float + The simulation time length. + inputs: Tensor, dict of Tensor, sequence of Tensor + The input data. If ``inputs_are_batching=True``, ``inputs`` must be a + PyTree of data with two dimensions: `(num_sample, num_time, ...)`. + Otherwise, the ``inputs`` should be a PyTree of data with one dimension: + `(num_time, ...)`. + inputs_are_batching: bool + Whether the ``inputs`` are batching. If `True`, the batching axis is the + first dimension. + reset_state: bool + Whether reset the model states. + shared_args: optional, dict + The shared arguments across different layers. + progress_bar: bool + Whether report the progress of the simulation using progress bar. eval_time: bool - Whether we record the running time? + Whether ro evaluate the running time. + + Returns + ------- + output: Tensor, dict, sequence + The model output. """ - # time step - if start_t is None: - if self._start_t is None: - start_t = 0. - else: - start_t = float(self._start_t) - end_t = float(start_t + duration) - # times - times = jax.device_put(jnp.arange(start_t, end_t, self.dt)) - # build monitor - for key in self.mon.var_names: - self.mon[key] = [] # reshape the monitor items - # running - if self.progress_bar: - self._pbar = tqdm.auto.tqdm(total=times.size) - self._pbar.set_description(f"Running a duration of {round(float(duration), 3)} ({times.size} steps)", - refresh=True) - if eval_time: - t0 = time.time() - outputs, hists = self._predict((times, None), shared_args=shared_args) - if eval_time: - running_time = time.time() - t0 - if self.progress_bar: - self._pbar.close() - # post-running - self.mon.ts = times + self.dt - for key in hists.keys(): - self.mon[key] = bm.asarray(hists[key]) - self._start_t = end_t - if self.numpy_mon_after_run: - self.mon['ts'] = np.asarray(self.mon['ts']) - for key in hists.keys(): - self.mon[key] = np.asarray(self.mon[key]) - if eval_time: - return running_time, outputs + return self.predict(duration=duration, + inputs=inputs, + inputs_are_batching=inputs_are_batching, + reset_state=reset_state, + shared_args=shared_args, + progress_bar=progress_bar, + eval_time=eval_time) + + def __call__(self, *args, **kwargs) -> Output: + return self.predict(*args, **kwargs) + + def _format_xs(self, duration, inputs, inputs_are_batching=True, move_axis=True): + if duration is None: + if inputs is None: + raise ValueError('"duration" and "inputs" can not both be None.') + xs, num_step, num_batch = self._check_xs(inputs, + move_axis=move_axis, + inputs_are_batching=inputs_are_batching) + indices = jax.device_put(jnp.arange(num_step)) + times = jax.device_put(indices * self.dt) + description = f'Predict {num_step} steps: ' + duration = num_step * self.dt else: - return outputs - - def _check_xs(self, xs, move_axis=True): + times = jax.device_put(jnp.arange(0, duration, self.dt)) + num_step = times.shape[0] + indices = jax.device_put(jnp.arange(num_step)) + description = f'Running a duration of {round(float(duration), 3)} ({times.shape[0]} steps)' + if inputs is None: + xs, num_batch = None, None + else: + xs, num_step_, num_batch = self._check_xs(inputs, + move_axis=move_axis, + inputs_are_batching=inputs_are_batching) + if num_step != num_step: + raise ValueError('The step numbers of "time" and "inputs" ' + f'do not match: {num_step_} != {num_step}.') + return times, indices, xs, num_step, num_batch, duration, description + + def _check_xs(self, xs, move_axis=True, inputs_are_batching=True): leaves, tree = tree_flatten(xs, is_leaf=lambda x: isinstance(x, bm.JaxArray)) + # get information of time step and batch size - num_times, num_batch_sizes = [], [] - for val in leaves: - num_batch_sizes.append(val.shape[0]) - num_times.append(val.shape[1]) + if inputs_are_batching: + num_times, num_batch_sizes = [], [] + for val in leaves: + num_batch_sizes.append(val.shape[0]) + num_times.append(val.shape[1]) + else: + num_times = [val.shape[0] for val in leaves] if len(set(num_times)) != 1: raise ValueError(f'Number of time step is different across tensors in ' f'the provided "xs". We got {set(num_times)}.') - if len(set(num_batch_sizes)) != 1: - raise ValueError(f'Number of batch size is different across tensors in ' - f'the provided "xs". We got {set(num_batch_sizes)}.') num_step = num_times[0] - num_batch = num_batch_sizes[0] + if inputs_are_batching: + if len(set(num_batch_sizes)) != 1: + raise ValueError(f'Number of batch size is different across tensors in ' + f'the provided "xs". We got {set(num_batch_sizes)}.') + num_batch = num_batch_sizes[0] + else: + num_batch = None + # change shape to (num_time, num_sample, num_feature) - if move_axis: - xs = tree_map(lambda x: bm.moveaxis(x, 0, 1), xs) + if move_axis and inputs_are_batching: + xs = tree_map(lambda x: bm.moveaxis(x, 0, 1), xs, + is_leaf=lambda x: isinstance(x, bm.JaxArray)) return xs, num_step, num_batch - def _get_predict_func(self, shared_args: Dict = None): + def f_predict(self, shared_args: Dict = None): if shared_args is None: shared_args = dict() - shared_kwargs_str = serialize_kwargs(shared_args) - if shared_kwargs_str not in self._predict_func: - self._predict_func[shared_kwargs_str] = self._make_predict_func(shared_args) - return self._predict_func[shared_kwargs_str] - - def _make_predict_func(self, shared_args: Dict): - if not isinstance(shared_args, dict): - raise ValueError(f'"shared_kwargs" must be a dict, but got {type(shared_args)}') - - def _step_func(inputs): - t, x = inputs - self._input_step(t, self.dt) - if x is None: - args = (t, self.dt) - else: - args = (t, self.dt, x) - kwargs = dict() - if len(shared_args): - kwargs['shared_args'] = shared_args - out = self.target.update(*args, **kwargs) - if self.progress_bar: - id_tap(lambda *arg: self._pbar.update(), ()) - return out, self._monitor_step(t, self.dt) - - if self.jit['predict']: - dyn_vars = self.target.vars() - dyn_vars.update(self.dyn_vars) - f = bm.make_loop(_step_func, dyn_vars=dyn_vars.unique(), has_return=True) - return lambda all_inputs: f(all_inputs)[1] - - else: - def run_func(xs): - outputs = [] - monitors = {key: [] for key in set(self.mon.item_contents.keys()) | set(self.fun_monitors.keys())} - for i in range(check_data_batch_size(xs)): - x = tree_map(lambda x: x[i], xs) - output, mon = _step_func(x) - outputs.append(output) - for key, value in mon.items(): - monitors[key].append(value) - if outputs[0] is None: - outputs = None - else: - outputs = bm.asarray(outputs) - for key, value in monitors.items(): - monitors[key] = bm.asarray(value) - return outputs, monitors - return run_func - - -class ReportRunner(DSRunner): - """The runner provides convenient interface for debugging. - It is also able to report the running progress. - - .. deprecated:: 2.0.3 - Prefer the use of :py:class:`brainpy.dyn.DSRunner` for dynamical system running. - This runner is deprecated since 2.0.3. - Parameters - ---------- - target : DynamicalSystem - The target model to run. - monitors : None, list of str, tuple of str, Monitor - Variables to monitor. - inputs : list, tuple - The input settings. - """ + shared_kwargs_str = serialize_kwargs(shared_args) + if shared_kwargs_str not in self._f_predict_compiled: + + monitor_func = self.build_monitors(self._mon_info[0], self._mon_info[1], shared_args) + + def _step_func(inputs): + t, i, x = inputs + # input step + shared = DotDict(t=t, i=t, dt=self.dt) + self._input_step(shared) + # dynamics update step + shared.update(shared_args) + args = (shared,) if x is None else (shared, x) + out = self.target(*args) + # monitor step + mon = monitor_func(shared) + # finally + if self.progress_bar: + id_tap(lambda *arg: self._pbar.update(), ()) + return out, mon + + if self.jit['predict']: + dyn_vars = self.target.vars() + dyn_vars.update(self.dyn_vars) + f = bm.make_loop(_step_func, dyn_vars=dyn_vars.unique(), has_return=True) + run_func = lambda all_inputs: f(all_inputs)[1] - def __init__(self, target, inputs=(), jit=False, dt=None, **kwargs): - super(ReportRunner, self).__init__(target=target, inputs=inputs, dt=dt, jit=False, **kwargs) + else: + def run_func(xs): + # total data + times, indices, xs = xs + + outputs = [] + monitors = {key: [] for key in set(self.mon.var_names) | set(self.fun_monitors.keys())} + for i in range(times.shape[0]): + # data at time i + x = tree_map(lambda x: x[i], xs, is_leaf=lambda x: isinstance(x, bm.JaxArray)) + + # step at the i + output, mon = _step_func((times[i], indices[i], x)) + + # append output and monitor + outputs.append(output) + for key, value in mon.items(): + monitors[key].append(value) + + # final work + if outputs[0] is None: + outputs = None + else: + outputs = bm.asarray(outputs) + for key, value in monitors.items(): + monitors[key] = bm.asarray(value) + return outputs, monitors + self._f_predict_compiled[shared_kwargs_str] = run_func + return self._f_predict_compiled[shared_kwargs_str] + + def __del__(self): + if hasattr(self, '_predict_func'): + for key in tuple(self._f_predict_compiled.keys()): + del self._f_predict_compiled[key] + super(DSRunner, self).__del__() - # Build the update function - if jit: - dyn_vars = TensorCollector() - dyn_vars.update(self.dyn_vars) - dyn_vars.update(self.target.vars().unique()) - self._update_step = bm.jit(self.target.update, dyn_vars=dyn_vars) - else: - self._update_step = self.target.update - - def _run_one_step(self, _t): - self._input_step(_t, self.dt) - self._update_step(_t, self.dt) - if self.progress_bar: - self._pbar.update() - return self._monitor_step(_t, self.dt) - - def build_run_function(self): - def f_run(all_t): - for i in range(all_t.shape[0]): - mon = self._run_one_step(all_t[i]) - for k, v in mon.items(): - self.mon.item_contents[k].append(v) - return None, {} - - return f_run diff --git a/brainpy/dyn/synapses/__init__.py b/brainpy/dyn/synapses/__init__.py index 8f9da3ccb..763f8f22a 100644 --- a/brainpy/dyn/synapses/__init__.py +++ b/brainpy/dyn/synapses/__init__.py @@ -4,7 +4,8 @@ from .biological_models import * from .learning_rules import * from .gap_junction import * -from .others import * +from .couplings import * +# compatible interface from . import compat diff --git a/brainpy/dyn/synapses/abstract_models.py b/brainpy/dyn/synapses/abstract_models.py index dd289ccbf..f015c973a 100644 --- a/brainpy/dyn/synapses/abstract_models.py +++ b/brainpy/dyn/synapses/abstract_models.py @@ -1,11 +1,14 @@ # -*- coding: utf-8 -*- + import warnings from typing import Union, Dict, Callable, Optional +from jax import vmap +from jax.lax import stop_gradient import brainpy.math as bm from brainpy.connect import TwoEndConnector, All2All, One2One -from brainpy.dyn.base import NeuGroup, SynapseOutput, SynapsePlasticity, TwoEndConn -from brainpy.initialize import Initializer, init_param +from brainpy.dyn.base import NeuGroup, SynOutput, SynSTP, TwoEndConn +from brainpy.initialize import Initializer, variable from brainpy.integrators import odeint, JointEq from brainpy.types import Tensor from ..synouts import CUBA, MgBlock @@ -26,12 +29,14 @@ class Delta(TwoEndConn): .. math:: - I_{syn} (t) = \sum_{j\in C} w \delta(t-t_j-D) + I_{syn} (t) = \sum_{j\in C} g_{\mathrm{max}} * \mathrm{STP} * \delta(t-t_j-D) - where :math:`w` denotes the chemical synaptic strength, :math:`t_j` the spiking - moment of the presynaptic neuron :math:`j`, :math:`C` the set of neurons connected - to the post-synaptic neuron, and :math:`D` the transmission delay of chemical - synapses. For simplicity, the rise and decay phases of post-synaptic currents are + where :math:`g_{\mathrm{max}}` denotes the chemical synaptic strength, + :math:`t_j` the spiking moment of the presynaptic neuron :math:`j`, + :math:`C` the set of neurons connected to the post-synaptic neuron, + :math:`D` the transmission delay of chemical synapses, + and :math:`\mathrm{STP}` the short-term plasticity effect. + For simplicity, the rise and decay phases of post-synaptic currents are omitted in this model. **Model Examples** @@ -66,17 +71,17 @@ class Delta(TwoEndConn): The post-synaptic neuron group. conn: optional, ndarray, JaxArray, dict of (str, ndarray), TwoEndConnector The synaptic connections. - conn_type: str + comp_method: str The connection type used for model speed optimization. It can be `sparse` and `dense`. The default is `sparse`. delay_step: int, ndarray, JaxArray, Initializer, Callable The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`. g_max: float, ndarray, JaxArray, Initializer, Callable The synaptic strength. Default is 1. - post_key: str + post_input_key: str The key of the post variable. It should be a string. The key should be the attribute of the post-synaptic neuron group. - post_has_ref: bool + post_ref_key: str Whether the post-synaptic group has refractory period. """ @@ -85,108 +90,86 @@ def __init__( pre: NeuGroup, post: NeuGroup, conn: Union[TwoEndConnector, Tensor, Dict[str, Tensor]], - output: Optional[SynapseOutput] = None, - plasticity: Optional[SynapsePlasticity] = None, - conn_type: str = 'sparse', + output: Optional[SynOutput] = None, + stp: Optional[SynSTP] = None, + comp_method: str = 'sparse', g_max: Union[float, Tensor, Initializer, Callable] = 1., delay_step: Union[float, Tensor, Initializer, Callable] = None, - post_key: str = 'V', - post_has_ref: bool = False, + post_input_key: str = 'V', + post_ref_key: str = None, name: str = None, + + # training parameters + trainable: bool = False, + stop_spike_gradient: bool = False, ): - super(Delta, self).__init__(pre=pre, + super(Delta, self).__init__(name=name, + pre=pre, post=post, conn=conn, output=CUBA() if output is None else output, - plasticity=plasticity, - name=name) - self.check_pre_attrs('spike') + stp=stp, + trainable=trainable) # parameters - self.post_key = post_key - self.check_post_attrs(post_key) - self.post_has_ref = post_has_ref - if post_has_ref: - self.check_post_attrs('refractory') + self.stop_spike_gradient = stop_spike_gradient + self.post_input_key = post_input_key + self.check_post_attrs(post_input_key) + self.post_ref_key = post_ref_key + if post_ref_key: + self.check_post_attrs(post_ref_key) + self.comp_method = comp_method # connections and weights - self.conn_type = conn_type - if conn_type not in ['sparse', 'dense']: - raise ValueError(f'"conn_type" must be in "sparse" and "dense", but we got {conn_type}') - if self.conn is None: - raise ValueError(f'Must provide "conn" when initialize the model {self.name}') - if isinstance(self.conn, One2One): - self.g_max = init_param(g_max, (self.pre.num,), allow_none=False) - self.weight_type = 'heter' if bm.size(self.g_max) != 1 else 'homo' - elif isinstance(self.conn, All2All): - self.g_max = init_param(g_max, (self.pre.num, self.post.num), allow_none=False) - if bm.size(self.g_max) != 1: - self.weight_type = 'heter' - bm.fill_diagonal(self.g_max, 0.) - else: - self.weight_type = 'homo' - else: - if conn_type == 'sparse': - self.pre2post = self.conn.require('pre2post') - self.g_max = init_param(g_max, self.pre2post[1].shape, allow_none=False) - self.weight_type = 'heter' if bm.size(self.g_max) != 1 else 'homo' - elif conn_type == 'dense': - self.g_max = init_param(g_max, (self.pre.num, self.post.num), allow_none=False) - self.weight_type = 'heter' if bm.size(self.g_max) != 1 else 'homo' - if self.weight_type == 'homo': - self.conn_mat = self.conn.require('conn_mat') - else: - raise ValueError(f'Unknown connection type: {conn_type}') + self.g_max, self.conn_mask = self.init_weights(g_max, comp_method=comp_method, sparse_data='csr') - # variables - self.delay_step = self.register_delay(f"{self.pre.name}.spike", - delay_step=delay_step, - delay_target=self.pre.spike) + # register delay + self.delay_step = self.register_delay(f"{self.pre.name}.spike", delay_step, self.pre.spike) - def reset(self): - self.output.reset() - self.plasticity.reset() + def reset_state(self, batch_size=None): + self.output.reset_state(batch_size) + self.stp.reset_state(batch_size) - def update(self, t, dt): - # delayed pre-synaptic spikes - pre_spike = self.get_delay_data(f"{self.pre.name}.spike", delay_step=self.delay_step) + def update(self, tdi, pre_spike=None): + # pre-synaptic spikes + if pre_spike is None: + pre_spike = self.get_delay_data(f"{self.pre.name}.spike", delay_step=self.delay_step) + if self.stop_spike_gradient: + pre_spike = pre_spike.value if isinstance(pre_spike, bm.JaxArray) else pre_spike + pre_spike = stop_gradient(pre_spike) # update sub-components - self.output.update(t, dt) - self.plasticity.update(t, dt, pre_spike, self.post.spike) - - # post values - pre_spike = self.plasticity.filter(pre_spike.astype(bm.get_dfloat())) + self.output.update(tdi) + self.stp.update(tdi, pre_spike) - assert self.weight_type in ['homo', 'heter'] - assert self.conn_type in ['sparse', 'dense'] + # synaptic values onto the post if isinstance(self.conn, All2All): - if self.weight_type == 'homo': - post_vs = bm.sum(pre_spike) - if not self.conn.include_self: - post_vs = post_vs - pre_spike - post_vs *= self.g_max - else: - post_vs = pre_spike @ self.g_max + syn_value = self.stp(bm.asarray(pre_spike, dtype=bm.dftype())) + post_vs = self.syn2post_with_all2all(syn_value, self.g_max) elif isinstance(self.conn, One2One): - post_vs = pre_spike * self.g_max + syn_value = self.stp(bm.asarray(pre_spike, dtype=bm.dftype())) + post_vs = self.syn2post_with_one2one(syn_value, self.g_max) else: - if self.conn_type == 'sparse': - post_vs = bm.pre2post_event_sum(pre_spike, - self.pre2post, - self.post.num, - self.g_max) + if self.comp_method == 'sparse': + f = lambda s: bm.pre2post_event_sum(s, self.conn_mask, self.post.num, self.g_max) + if self.trainable: f = vmap(f) + post_vs = f(pre_spike) + # if not isinstance(self.stp, _NullSynSTP): + # raise NotImplementedError() + # stp_value = self.stp(1.) + # f2 = lambda s: bm.pre2post_sum(s, self.post.num, *self.conn_mask) + # if self.trainable: f2 = vmap(f2) + # post_vs *= f2(stp_value) else: - if self.weight_type == 'homo': - post_vs = self.g_max * (pre_spike @ self.conn_mat) - else: - post_vs = pre_spike @ self.g_max + syn_value = self.stp(bm.asarray(pre_spike, dtype=bm.dftype())) + post_vs = self.syn2post_with_dense(syn_value, self.g_max, self.conn_mask) + if self.post_ref_key: + post_vs = post_vs * (1. - getattr(self.post, self.post_ref_key)) + post_vs = self.output(post_vs) # update outputs - target = getattr(self.post, self.post_key) - if self.post_has_ref: - post_vs = post_vs * bm.logical_not(self.post.refractory) - target += self.output.filter(post_vs) + target = getattr(self.post, self.post_input_key) + target += post_vs class Exponential(TwoEndConn): @@ -212,10 +195,13 @@ class Exponential(TwoEndConn): .. math:: \begin{aligned} - & g_{\mathrm{syn}}(t) = g_{max} g \\ + & g_{\mathrm{syn}}(t) = g_{max} g * \mathrm{STP} \\ & \frac{d g}{d t} = -\frac{g}{\tau_{decay}}+\sum_{k} \delta(t-t_{j}^{k}). \end{aligned} - + + where :math:`\mathrm{STP}` is used to model the short-term plasticity effect. + + **Model Examples** - `(Brunel & Hakim, 1999) Fast Global Oscillation `_ @@ -258,7 +244,7 @@ class Exponential(TwoEndConn): The post-synaptic neuron group. conn: optional, ndarray, JaxArray, dict of (str, ndarray), TwoEndConnector The synaptic connections. - conn_type: str + comp_method: str The connection type used for model speed optimization. It can be `sparse` and `dense`. The default is `sparse`. delay_step: int, ndarray, JaxArray, Initializer, Callable @@ -286,113 +272,85 @@ def __init__( pre: NeuGroup, post: NeuGroup, conn: Union[TwoEndConnector, Tensor, Dict[str, Tensor]], - output: SynapseOutput = None, - plasticity: Optional[SynapsePlasticity] = None, - conn_type: str = 'sparse', + output: SynOutput = None, + stp: Optional[SynSTP] = None, + comp_method: str = 'sparse', g_max: Union[float, Tensor, Initializer, Callable] = 1., delay_step: Union[int, Tensor, Initializer, Callable] = None, tau: Union[float, Tensor] = 8.0, name: str = None, method: str = 'exp_auto', + + # training parameters + trainable: bool = False, + stop_spike_gradient: bool = False, ): super(Exponential, self).__init__(pre=pre, post=post, conn=conn, output=CUBA() if output is None else output, - plasticity=plasticity, - name=name) - self.check_pre_attrs('spike') - self.check_post_attrs('input', 'V') - + stp=stp, + name=name, + trainable=trainable) # parameters + self.stop_spike_gradient = stop_spike_gradient + self.comp_method = comp_method self.tau = tau if bm.size(self.tau) != 1: - raise ValueError(f'"tau" must be a scalar or a tensor with size of 1. ' - f'But we got {self.tau}') + raise ValueError(f'"tau" must be a scalar or a tensor with size of 1. But we got {self.tau}') # connections and weights - self.conn_type = conn_type - if conn_type not in ['sparse', 'dense']: - raise ValueError(f'"conn_type" must be in "sparse" and "dense", but we got {conn_type}') - if self.conn is None: - raise ValueError(f'Must provide "conn" when initialize the model {self.name}') - if isinstance(self.conn, One2One): - self.g_max = init_param(g_max, (self.pre.num,), allow_none=False) - self.weight_type = 'heter' if bm.size(self.g_max) != 1 else 'homo' - elif isinstance(self.conn, All2All): - self.g_max = init_param(g_max, (self.pre.num, self.post.num), allow_none=False) - if bm.size(self.g_max) != 1: - self.weight_type = 'heter' - bm.fill_diagonal(self.g_max, 0.) - else: - self.weight_type = 'homo' - else: - if conn_type == 'sparse': - self.pre2post = self.conn.require('pre2post') - self.g_max = init_param(g_max, self.pre2post[1].shape, allow_none=False) - self.weight_type = 'heter' if bm.size(self.g_max) != 1 else 'homo' - elif conn_type == 'dense': - self.g_max = init_param(g_max, (self.pre.num, self.post.num), allow_none=False) - self.weight_type = 'heter' if bm.size(self.g_max) != 1 else 'homo' - if self.weight_type == 'homo': - self.conn_mat = self.conn.require('conn_mat') - else: - raise ValueError(f'Unknown connection type: {conn_type}') + self.g_max, self.conn_mask = self.init_weights(g_max, comp_method, sparse_data='csr') # variables - self.g = bm.Variable(bm.zeros(self.post.num)) - self.delay_step = self.register_delay(f"{self.pre.name}.spike", - delay_step, - self.pre.spike) + self.g = variable(bm.zeros, trainable, self.post.num) + self.delay_step = self.register_delay(f"{self.pre.name}.spike", delay_step, self.pre.spike) # function self.integral = odeint(lambda g, t: -g / self.tau, method=method) - def reset(self): - self.g.value = bm.zeros(self.post.num) - self.output.reset() + def reset_state(self, batch_size=None): + self.g.value = variable(bm.zeros, batch_size, self.post.num) + self.output.reset_state(batch_size) + self.stp.reset_state(batch_size) + + def update(self, tdi, pre_spike=None): + t, dt = tdi['t'], tdi['dt'] - def update(self, t, dt): # delays - pre_spike = self.get_delay_data(f"{self.pre.name}.spike", self.delay_step) + if pre_spike is None: + pre_spike = self.get_delay_data(f"{self.pre.name}.spike", self.delay_step) + if self.stop_spike_gradient: + pre_spike = pre_spike.value if isinstance(pre_spike, bm.JaxArray) else pre_spike + pre_spike = stop_gradient(pre_spike) # update sub-components - self.output.update(t, dt) - self.plasticity.update(t, dt, pre_spike, self.post.spike) + self.output.update(tdi) + self.stp.update(tdi, pre_spike) # post values - assert self.weight_type in ['homo', 'heter'] - assert self.conn_type in ['sparse', 'dense'] if isinstance(self.conn, All2All): - pre_spike = pre_spike.astype(bm.get_dfloat()) - if self.weight_type == 'homo': - post_vs = bm.sum(pre_spike) - if not self.conn.include_self: - post_vs = post_vs - pre_spike - post_vs = self.g_max * post_vs - else: - post_vs = pre_spike @ self.g_max + syn_value = self.stp(bm.asarray(pre_spike, dtype=bm.dftype())) + post_vs = self.syn2post_with_all2all(syn_value, self.g_max) elif isinstance(self.conn, One2One): - pre_spike = pre_spike.astype(bm.get_dfloat()) - post_vs = pre_spike * self.g_max + syn_value = self.stp(bm.asarray(pre_spike, dtype=bm.dftype())) + post_vs = self.syn2post_with_one2one(syn_value, self.g_max) else: - if self.conn_type == 'sparse': - post_vs = bm.pre2post_event_sum(pre_spike, - self.pre2post, - self.post.num, - self.g_max) + if self.comp_method == 'sparse': + f = lambda s: bm.pre2post_event_sum(s, self.conn_mask, self.post.num, self.g_max) + if self.trainable: f = vmap(f) + post_vs = f(pre_spike) + # if not isinstance(self.stp, _NullSynSTP): + # raise NotImplementedError() else: - pre_spike = pre_spike.astype(bm.get_dfloat()) - if self.weight_type == 'homo': - post_vs = self.g_max * (pre_spike @ self.conn_mat) - else: - post_vs = pre_spike @ self.g_max - + syn_value = self.stp(bm.asarray(pre_spike, dtype=bm.dftype())) + post_vs = self.syn2post_with_dense(syn_value, self.g_max, self.conn_mask) # updates self.g.value = self.integral(self.g.value, t, dt) + post_vs + g_out = self.output(self.g) # output - self.post.input += self.output.filter(self.g) + self.post.input += g_out class DualExponential(TwoEndConn): @@ -419,13 +377,14 @@ class DualExponential(TwoEndConn): .. math:: \begin{aligned} - &g_{\mathrm{syn}}(t)=g_{\mathrm{max}} g \\ + &g_{\mathrm{syn}}(t)=g_{\mathrm{max}} g * \mathrm{STP} \\ &\frac{d g}{d t}=-\frac{g}{\tau_{\mathrm{decay}}}+h \\ &\frac{d h}{d t}=-\frac{h}{\tau_{\text {rise }}}+ \delta\left(t_{0}-t\right), \end{aligned} - **Model Examples** + where :math:`\mathrm{STP}` is used to model the short-term plasticity effect of synapses. + **Model Examples** .. plot:: :include-source: True @@ -462,7 +421,7 @@ class DualExponential(TwoEndConn): The post-synaptic neuron group. conn: optional, ndarray, JaxArray, dict of (str, ndarray), TwoEndConnector The synaptic connections. - conn_type: str + comp_method: str The connection type used for model speed optimization. It can be `sparse` and `dense`. The default is `sparse`. delay_step: int, ndarray, JaxArray, Initializer, Callable @@ -494,26 +453,32 @@ def __init__( pre: NeuGroup, post: NeuGroup, conn: Union[TwoEndConnector, Tensor, Dict[str, Tensor]], - plasticity: Optional[SynapsePlasticity] = None, - output: SynapseOutput = None, - conn_type: str = 'dense', + stp: Optional[SynSTP] = None, + output: SynOutput = None, + comp_method: str = 'dense', g_max: Union[float, Tensor, Initializer, Callable] = 1., tau_decay: Union[float, Tensor] = 10.0, tau_rise: Union[float, Tensor] = 1., delay_step: Union[int, Tensor, Initializer, Callable] = None, method: str = 'exp_auto', - name: str = None + name: str = None, + + # training parameters + trainable: bool = False, + stop_spike_gradient: bool = False, ): super(DualExponential, self).__init__(pre=pre, post=post, conn=conn, output=CUBA() if output is None else output, - plasticity=plasticity, - name=name) - self.check_pre_attrs('spike') - self.check_post_attrs('input') - + stp=stp, + name=name, + trainable=trainable) # parameters + # self.check_pre_attrs('spike') + self.check_post_attrs('input') + self.stop_spike_gradient = stop_spike_gradient + self.comp_method = comp_method self.tau_rise = tau_rise self.tau_decay = tau_decay if bm.size(self.tau_rise) != 1: @@ -524,46 +489,21 @@ def __init__( f'But we got {self.tau_decay}') # connections - self.conn_type = conn_type - if conn_type not in ['sparse', 'dense']: - raise ValueError(f'"conn_type" must be in "sparse" and "dense", but we got {conn_type}') - if self.conn is None: - raise ValueError(f'Must provide "conn" when initialize the model {self.name}') - if isinstance(self.conn, One2One): - self.g_max = init_param(g_max, (self.pre.num,), allow_none=False) - self.weight_type = 'heter' if bm.size(self.g_max) != 1 else 'homo' - elif isinstance(self.conn, All2All): - self.g_max = init_param(g_max, (self.pre.num, self.post.num), allow_none=False) - if bm.size(self.g_max) != 1: - self.weight_type = 'heter' - bm.fill_diagonal(self.g_max, 0.) - else: - self.weight_type = 'homo' - else: - if conn_type == 'sparse': - self.pre_ids, self.post_ids = self.conn.require('pre_ids', 'post_ids') - self.g_max = init_param(g_max, self.post_ids.shape, allow_none=False) - self.weight_type = 'heter' if bm.size(self.g_max) != 1 else 'homo' - elif conn_type == 'dense': - self.g_max = init_param(g_max, (self.pre.num, self.post.num), allow_none=False) - self.weight_type = 'heter' if bm.size(self.g_max) != 1 else 'homo' - if self.weight_type == 'homo': - self.conn_mat = self.conn.require('conn_mat') - else: - raise ValueError(f'Unknown connection type: {conn_type}') + self.g_max, self.conn_mask = self.init_weights(g_max, comp_method, sparse_data='ij') # variables - self.h = bm.Variable(bm.zeros(self.pre.num)) - self.g = bm.Variable(bm.zeros(self.pre.num)) + self.h = variable(bm.zeros, trainable, self.pre.num) + self.g = variable(bm.zeros, trainable, self.pre.num) self.delay_step = self.register_delay(f"{self.pre.name}.spike", delay_step, self.pre.spike) # integral self.integral = odeint(method=method, f=JointEq([self.dg, self.dh])) - def reset(self): - self.h.value = bm.zeros(self.pre.num) - self.g.value = bm.zeros(self.pre.num) - self.output.reset() + def reset_state(self, batch_size=None): + self.h.value = variable(bm.zeros, batch_size, self.pre.num) + self.g.value = variable(bm.zeros, batch_size, self.pre.num) + self.output.reset_state(batch_size) + self.stp.reset_state(batch_size) def dh(self, h, t): return -h / self.tau_rise @@ -571,42 +511,41 @@ def dh(self, h, t): def dg(self, g, t, h): return -g / self.tau_decay + h - def update(self, t, dt): - # delays - pre_spike = self.get_delay_data(f"{self.pre.name}.spike", self.delay_step) + def update(self, tdi, pre_spike=None): + t, dt = tdi['t'], tdi['dt'] + + # pre-synaptic spikes + if pre_spike is None: + pre_spike = self.get_delay_data(f"{self.pre.name}.spike", self.delay_step) + if self.stop_spike_gradient: + pre_spike = pre_spike.value if isinstance(pre_spike, bm.JaxArray) else pre_spike + pre_spike = stop_gradient(pre_spike) # update sub-components - self.output.update(t, dt) - self.plasticity.update(t, dt, pre_spike, self.post.spike) + self.output.update(tdi) + self.stp.update(tdi, pre_spike) # update synaptic variables self.g.value, self.h.value = self.integral(self.g, self.h, t, dt) self.h += pre_spike - # post-synaptic values - syn_value = self.plasticity.filter(self.g) - + # post values + syn_value = self.stp(self.g) if isinstance(self.conn, All2All): - if self.weight_type == 'homo': - post_vs = bm.sum(syn_value) - if not self.conn.include_self: - post_vs = post_vs - syn_value - post_vs = self.g_max * post_vs - else: - post_vs = syn_value @ self.g_max + post_vs = self.syn2post_with_all2all(syn_value, self.g_max) elif isinstance(self.conn, One2One): - post_vs = self.g_max * syn_value + post_vs = self.syn2post_with_one2one(syn_value, self.g_max) else: - if self.conn_type == 'sparse': - post_vs = bm.pre2post_sum(syn_value, self.post.num, self.post_ids, self.pre_ids) + if self.comp_method == 'sparse': + f = lambda s: bm.pre2post_sum(s, self.post.num, *self.conn_mask) + if self.trainable: f = vmap(f) + post_vs = f(syn_value) else: - if self.weight_type == 'homo': - post_vs = (self.g_max * syn_value) @ self.conn_mat - else: - post_vs = syn_value @ self.g_max + post_vs = self.syn2post_with_dense(syn_value, self.g_max, self.conn_mask) + post_vs = self.output(post_vs) # output - self.post.input += self.output.filter(post_vs) + self.post.input += post_vs class Alpha(DualExponential): @@ -667,7 +606,7 @@ class Alpha(DualExponential): The post-synaptic neuron group. conn: optional, ndarray, JaxArray, dict of (str, ndarray), TwoEndConnector The synaptic connections. - conn_type: str + comp_method: str The connection type used for model speed optimization. It can be `sparse` and `dense`. The default is `sparse`. delay_step: int, ndarray, JaxArray, Initializer, Callable @@ -694,27 +633,33 @@ def __init__( pre: NeuGroup, post: NeuGroup, conn: Union[TwoEndConnector, Tensor, Dict[str, Tensor]], - output: SynapseOutput = None, - plasticity: Optional[SynapsePlasticity] = None, - conn_type: str = 'dense', + output: SynOutput = None, + stp: Optional[SynSTP] = None, + comp_method: str = 'dense', g_max: Union[float, Tensor, Initializer, Callable] = 1., delay_step: Union[int, Tensor, Initializer, Callable] = None, tau_decay: Union[float, Tensor] = 10.0, method: str = 'exp_auto', - name: str = None + name: str = None, + + # training parameters + trainable: bool = False, + stop_spike_gradient: bool = False, ): super(Alpha, self).__init__(pre=pre, post=post, conn=conn, - conn_type=conn_type, + comp_method=comp_method, delay_step=delay_step, g_max=g_max, tau_decay=tau_decay, tau_rise=tau_decay, method=method, output=CUBA() if output is None else output, - plasticity=plasticity, - name=name) + stp=stp, + name=name, + trainable=trainable, + stop_spike_gradient=stop_spike_gradient) class NMDA(TwoEndConn): @@ -810,7 +755,7 @@ class NMDA(TwoEndConn): The post-synaptic neuron group. conn: optional, ndarray, JaxArray, dict of (str, ndarray), TwoEndConnector The synaptic connections. - conn_type: str + comp_method: str The connection type used for model speed optimization. It can be `sparse` and `dense`. The default is `dense`. delay_step: int, ndarray, JaxArray, Initializer, Callable @@ -872,9 +817,9 @@ def __init__( pre: NeuGroup, post: NeuGroup, conn: Union[TwoEndConnector, Tensor, Dict[str, Tensor]], - output: Optional[SynapseOutput] = None, - plasticity: Optional[SynapsePlasticity] = None, - conn_type: str = 'dense', + output: Optional[SynOutput] = None, + stp: Optional[SynSTP] = None, + comp_method: str = 'dense', g_max: Union[float, Tensor, Initializer, Callable] = 0.15, delay_step: Union[int, Tensor, Initializer, Callable] = None, tau_decay: Union[float, Tensor] = 100., @@ -883,6 +828,10 @@ def __init__( method: str = 'exp_auto', name: str = None, + # training parameters + trainable: bool = False, + stop_spike_gradient: bool = False, + # deprecated alpha=None, beta=None, @@ -925,12 +874,11 @@ def __init__( post=post, conn=conn, output=output, - plasticity=plasticity, - name=name) - self.check_pre_attrs('spike') - self.check_post_attrs('input', 'V') - + stp=stp, + name=name, + trainable=trainable) # parameters + # self.check_post_attrs('input', 'V') self.tau_decay = tau_decay self.tau_rise = tau_rise self.a = a @@ -940,39 +888,15 @@ def __init__( raise ValueError(f'"tau_decay" must be a scalar or a tensor with size of 1. But we got {tau_decay}') if bm.size(tau_rise) != 1: raise ValueError(f'"tau_rise" must be a scalar or a tensor with size of 1. But we got {tau_rise}') + self.comp_method = comp_method + self.stop_spike_gradient = stop_spike_gradient # connections and weights - self.conn_type = conn_type - if conn_type not in ['sparse', 'dense']: - raise ValueError(f'"conn_type" must be in "sparse" and "dense", but we got {conn_type}') - if self.conn is None: - raise ValueError(f'Must provide "conn" when initialize the model {self.name}') - if isinstance(self.conn, One2One): - self.g_max = init_param(g_max, (self.pre.num,), allow_none=False) - self.weight_type = 'heter' if bm.size(self.g_max) != 1 else 'homo' - elif isinstance(self.conn, All2All): - self.g_max = init_param(g_max, (self.pre.num, self.post.num), allow_none=False) - if bm.size(self.g_max) != 1: - self.weight_type = 'heter' - bm.fill_diagonal(self.g_max, 0.) - else: - self.weight_type = 'homo' - else: - if conn_type == 'sparse': - self.pre_ids, self.post_ids = self.conn.require('pre_ids', 'post_ids') - self.g_max = init_param(g_max, self.post_ids.shape, allow_none=False) - self.weight_type = 'heter' if bm.size(self.g_max) != 1 else 'homo' - elif conn_type == 'dense': - self.g_max = init_param(g_max, (self.pre.num, self.post.num), allow_none=False) - self.weight_type = 'heter' if bm.size(self.g_max) != 1 else 'homo' - if self.weight_type == 'homo': - self.conn_mat = self.conn.require('conn_mat') - else: - raise ValueError(f'Unknown connection type: {conn_type}') + self.g_max, self.conn_mask = self.init_weights(g_max, comp_method, sparse_data='ij') # variables - self.g = bm.Variable(bm.zeros(self.pre.num)) - self.x = bm.Variable(bm.zeros(self.pre.num)) + self.g = variable(bm.zeros, trainable, self.pre.num) + self.x = variable(bm.zeros, trainable, self.pre.num) self.delay_step = self.register_delay(f"{self.pre.name}.spike", delay_step, self.pre.spike) # integral @@ -984,46 +908,43 @@ def dg(self, g, t, x): def dx(self, x, t): return -x / self.tau_rise - def reset(self): - self.g[:] = 0 - self.x[:] = 0 - self.output.reset() - self.plasticity.reset() + def reset_state(self, batch_size=None): + self.g.value = variable(bm.zeros, batch_size, self.pre.num) + self.x.value = variable(bm.zeros, batch_size, self.pre.num) + self.output.reset_state(batch_size) + self.stp.reset_state(batch_size) - def update(self, t, dt): + def update(self, tdi, pre_spike=None): + t, dt = tdi['t'], tdi['dt'] # delays - pre_spike = self.get_delay_data(f"{self.pre.name}.spike", self.delay_step) + if pre_spike is None: + pre_spike = self.get_delay_data(f"{self.pre.name}.spike", self.delay_step) + if self.stop_spike_gradient: + pre_spike = pre_spike.value if isinstance(pre_spike, bm.JaxArray) else pre_spike + pre_spike = stop_gradient(pre_spike) # update sub-components - if self.plasticity is not None: - self.plasticity.update(t, dt, pre_spike, self.post.spike) - self.output.update(t, dt) + self.stp.update(tdi, pre_spike) + self.output.update(tdi) # update synapse variables self.g.value, self.x.value = self.integral(self.g, self.x, t, dt=dt) self.x += pre_spike # post-synaptic value - syn_value = self.plasticity.filter(self.g) # x * g, u * x * g - + syn_value = self.stp(self.g) if isinstance(self.conn, All2All): - if self.weight_type == 'homo': - post_g = bm.sum(syn_value) - if not self.conn.include_self: - post_g = post_g - syn_value - post_g = post_g * self.g_max - else: - post_g = syn_value @ self.g_max + post_vs = self.syn2post_with_all2all(syn_value, self.g_max) elif isinstance(self.conn, One2One): - post_g = self.g_max * syn_value + post_vs = self.syn2post_with_one2one(syn_value, self.g_max) else: - if self.conn_type == 'sparse': - post_g = bm.pre2post_sum(syn_value, self.post.num, self.post_ids, self.pre_ids) + if self.comp_method == 'sparse': + f = lambda s: bm.pre2post_sum(s, self.post.num, *self.conn_mask) + if self.trainable: f = vmap(f) + post_vs = f(syn_value) else: - if self.weight_type == 'homo': - post_g = (self.g_max * syn_value) @ self.conn_mat - else: - post_g = syn_value @ self.g_max + post_vs = self.syn2post_with_dense(syn_value, self.g_max, self.conn_mask) + post_vs = self.output(post_vs) # output - self.post.input += self.output.filter(post_g) + self.post.input += post_vs diff --git a/brainpy/dyn/synapses/biological_models.py b/brainpy/dyn/synapses/biological_models.py index 897c1b2f4..cfa6ef26c 100644 --- a/brainpy/dyn/synapses/biological_models.py +++ b/brainpy/dyn/synapses/biological_models.py @@ -1,14 +1,18 @@ # -*- coding: utf-8 -*- + import warnings from typing import Union, Dict, Callable, Optional +from jax import vmap +from jax.lax import stop_gradient + import brainpy.math as bm from brainpy.connect import TwoEndConnector, All2All, One2One -from brainpy.dyn.base import NeuGroup, TwoEndConn, SynapsePlasticity, SynapseOutput -from brainpy.initialize import Initializer, init_param +from brainpy.dyn.base import NeuGroup, TwoEndConn, SynSTP, SynOutput +from brainpy.dyn.synouts import COBA, MgBlock +from brainpy.initialize import Initializer, variable from brainpy.integrators import odeint, JointEq from brainpy.types import Tensor -from ..synouts import COBA, MgBlock __all__ = [ 'AMPA', @@ -94,7 +98,7 @@ class AMPA(TwoEndConn): The post-synaptic neuron group. conn: optional, ndarray, JaxArray, dict of (str, ndarray), TwoEndConnector The synaptic connections. - conn_type: str + comp_method: str The connection type used for model speed optimization. It can be `sparse` and `dense`. The default is `dense`. delay_step: int, ndarray, JaxArray, Initializer, Callable @@ -135,20 +139,24 @@ def __init__( pre: NeuGroup, post: NeuGroup, conn: Union[TwoEndConnector, Tensor, Dict[str, Tensor]], - output: SynapseOutput = None, - plasticity: Optional[SynapsePlasticity] = None, - conn_type: str = 'dense', + output: SynOutput = None, + stp: Optional[SynSTP] = None, + comp_method: str = 'dense', g_max: Union[float, Tensor, Initializer, Callable] = 0.42, delay_step: Union[int, Tensor, Initializer, Callable] = None, - alpha: Union[float, Tensor] = 0.98, - beta: Union[float, Tensor] = 0.18, - T: Union[float, Tensor] = 0.5, - T_duration: Union[float, Tensor] = 0.5, + alpha: float = 0.98, + beta: float = 0.18, + T: float = 0.5, + T_duration: float = 0.5, method: str = 'exp_auto', name: str = None, + # training parameters + trainable: bool = False, + stop_spike_gradient: bool = False, + # deprecated - E: Union[float, Tensor] = None, + E: float = None, ): _E = 0. if E is not None: @@ -159,12 +167,13 @@ def __init__( post=post, conn=conn, output=COBA(E=_E) if output is None else output, - plasticity=plasticity, - name=name) - self.check_pre_attrs('spike') - self.check_post_attrs('input', 'V') + stp=stp, + name=name, + trainable=trainable) # parameters + self.stop_spike_gradient = stop_spike_gradient + self.comp_method = comp_method self.alpha = alpha self.beta = beta self.T = T @@ -179,87 +188,64 @@ def __init__( raise ValueError(f'"T_duration" must be a scalar or a tensor with size of 1. But we got {T_duration}') # connection - self.conn_type = conn_type - if conn_type not in ['sparse', 'dense']: - raise ValueError(f'"conn_type" must be in "sparse" and "dense", but we got {conn_type}') - if self.conn is None: - raise ValueError(f'Must provide "conn" when initialize the model {self.name}') - if isinstance(self.conn, One2One): - self.g_max = init_param(g_max, (self.pre.num,), allow_none=False) - self.weight_type = 'heter' if bm.size(self.g_max) != 1 else 'homo' - elif isinstance(self.conn, All2All): - self.g_max = init_param(g_max, (self.pre.num, self.post.num), allow_none=False) - if bm.size(self.g_max) != 1: - self.weight_type = 'heter' - bm.fill_diagonal(self.g_max, 0.) - else: - self.weight_type = 'homo' - else: - if conn_type == 'sparse': - self.pre_ids, self.post_ids = self.conn.require('pre_ids', 'post_ids') - self.g_max = init_param(g_max, self.post_ids.shape, allow_none=False) - self.weight_type = 'heter' if bm.size(self.g_max) != 1 else 'homo' - elif conn_type == 'dense': - self.g_max = init_param(g_max, (self.pre.num, self.post.num), allow_none=False) - self.weight_type = 'heter' if bm.size(self.g_max) != 1 else 'homo' - if self.weight_type == 'homo': - self.conn_mat = self.conn.require('conn_mat') - else: - raise ValueError(f'Unknown connection type: {conn_type}') + self.g_max, self.conn_mask = self.init_weights(g_max, comp_method, sparse_data='ij') # variables - self.g = bm.Variable(bm.zeros(self.pre.num)) - self.spike_arrival_time = bm.Variable(bm.ones(self.pre.num) * -1e7) - self.delay_step = self.register_delay(f"{self.pre.name}.spike", - delay_step=delay_step, - delay_target=self.pre.spike) + self.g = variable(bm.zeros, trainable, self.pre.num) + self.spike_arrival_time = variable(lambda s: bm.ones(s) * -1e7, trainable, self.pre.num) + self.delay_step = self.register_delay(f"{self.pre.name}.spike", delay_step, self.pre.spike) # functions self.integral = odeint(method=method, f=self.dg) - def reset(self): - self.g[:] = 0 - self.output.reset() - self.plasticity.reset() + def reset_state(self, batch_size=None): + self.g = variable(bm.zeros, batch_size, self.pre.num) + self.spike_arrival_time = variable(lambda s: bm.ones(s) * -1e7, batch_size, self.pre.num) + self.output.reset_state(batch_size) + self.stp.reset_state(batch_size) def dg(self, g, t, TT): dg = self.alpha * TT * (1 - g) - self.beta * g return dg - def update(self, t, dt): + def update(self, tdi, pre_spike=None): + t, dt = tdi['t'], tdi['dt'] + # delays - pre_spike = self.get_delay_data(f"{self.pre.name}.spike", self.delay_step) - self.output.update(t, dt) - self.plasticity.update(t, dt, pre_spike, self.post.spike) + if pre_spike is None: + pre_spike = self.get_delay_data(f"{self.pre.name}.spike", self.delay_step) + if self.stop_spike_gradient: + pre_spike = pre_spike.value if isinstance(pre_spike, bm.JaxArray) else pre_spike + pre_spike = stop_gradient(pre_spike) + + # update sub-components + self.output.update(tdi) + self.stp.update(tdi, pre_spike) # update synaptic variables self.spike_arrival_time.value = bm.where(pre_spike, t, self.spike_arrival_time) + if self.trainable: + self.spike_arrival_time.value = stop_gradient(self.spike_arrival_time.value) TT = ((t - self.spike_arrival_time) < self.T_duration) * self.T - self.g.value = self.integral(self.g, t, TT, dt=dt) + self.g.value = self.integral(self.g, t, TT, dt) # post-synaptic values - syn_value = self.plasticity.filter(self.g) - if isinstance(self.conn, One2One): - post_g = self.g_max * syn_value - elif isinstance(self.conn, All2All): - if self.weight_type == 'homo': - post_g = bm.sum(syn_value) - if not self.conn.include_self: - post_g = post_g - syn_value - post_g = post_g * self.g_max - else: - post_g = syn_value @ self.g_max + syn_value = self.stp(self.g) + if isinstance(self.conn, All2All): + post_vs = self.syn2post_with_all2all(syn_value, self.g_max) + elif isinstance(self.conn, One2One): + post_vs = self.syn2post_with_one2one(syn_value, self.g_max) else: - if self.conn_type == 'sparse': - post_g = bm.pre2post_sum(syn_value, self.post.num, self.post_ids, self.pre_ids) + if self.comp_method == 'sparse': + f = lambda s: bm.pre2post_sum(s, self.post.num, *self.conn_mask) + if self.trainable: f = vmap(f) + post_vs = f(syn_value) else: - if self.weight_type == 'homo': - post_g = (self.g_max * syn_value) @ self.conn_mat - else: - post_g = syn_value @ self.g_max + post_vs = self.syn2post_with_dense(syn_value, self.g_max, self.conn_mask) + post_vs = self.output(post_vs) # output - self.post.input += self.output.filter(post_g) + self.post.input += post_vs class GABAa(AMPA): @@ -295,7 +281,7 @@ class GABAa(AMPA): The post-synaptic neuron group. conn: optional, ndarray, JaxArray, dict of (str, ndarray), TwoEndConnector The synaptic connections. - conn_type: str + comp_method: str The connection type used for model speed optimization. It can be `sparse` and `dense`. The default is `dense`. delay_step: int, ndarray, JaxArray, Initializer, Callable @@ -335,9 +321,9 @@ def __init__( pre: NeuGroup, post: NeuGroup, conn: Union[TwoEndConnector, Tensor, Dict[str, Tensor]], - output: SynapseOutput = None, - plasticity: Optional[SynapsePlasticity] = None, - conn_type: str = 'dense', + output: SynOutput = None, + stp: Optional[SynSTP] = None, + comp_method: str = 'dense', g_max: Union[float, Tensor, Initializer, Callable] = 0.04, delay_step: Union[int, Tensor, Initializer, Callable] = None, alpha: Union[float, Tensor] = 0.53, @@ -347,6 +333,10 @@ def __init__( method: str = 'exp_auto', name: str = None, + # training parameters + trainable: bool = False, + stop_spike_gradient: bool = False, + # deprecated E: Union[float, Tensor] = None, ): @@ -359,8 +349,8 @@ def __init__( post=post, conn=conn, output=COBA(E=_E) if output is None else output, - plasticity=plasticity, - conn_type=conn_type, + stp=stp, + comp_method=comp_method, delay_step=delay_step, g_max=g_max, alpha=alpha, @@ -368,7 +358,9 @@ def __init__( T=T, T_duration=T_duration, method=method, - name=name) + name=name, + trainable=trainable, + stop_spike_gradient=stop_spike_gradient,) class BioNMDA(TwoEndConn): @@ -458,7 +450,7 @@ class BioNMDA(TwoEndConn): The post-synaptic neuron group. conn: optional, ndarray, JaxArray, dict of (str, ndarray), TwoEndConnector The synaptic connections. - conn_type: str + comp_method: str The connection type used for model speed optimization. It can be `sparse` and `dense`. The default is `dense`. delay_step: int, ndarray, JaxArray, Initializer, Callable @@ -473,7 +465,6 @@ class BioNMDA(TwoEndConn): The conversion rate of x from inactive to active. Default 1 ms^-1. beta2: float, JaxArray, ndarray The conversion rate of x from active to inactive. Default 0.5 ms^-1. - name: str The name of this synaptic projection. method: str @@ -498,9 +489,9 @@ def __init__( pre: NeuGroup, post: NeuGroup, conn: Union[TwoEndConnector, Tensor, Dict[str, Tensor]], - output: SynapseOutput = None, - plasticity: Optional[SynapsePlasticity] = None, - conn_type: str = 'dense', + output: Optional[SynOutput] = None, + stp: Optional[SynSTP] = None, + comp_method: str = 'dense', g_max: Union[float, Tensor, Initializer, Callable] = 0.15, delay_step: Union[int, Tensor, Initializer, Callable] = None, alpha1: Union[float, Tensor] = 2., @@ -511,15 +502,18 @@ def __init__( T_dur: Union[float, Tensor] = 0.5, method: str = 'exp_auto', name: str = None, + + # training parameters + trainable: bool = False, + stop_spike_gradient: bool = False, ): super(BioNMDA, self).__init__(pre=pre, post=post, conn=conn, output=MgBlock(E=0.) if output is None else output, - plasticity=plasticity, - name=name) - self.check_pre_attrs('spike') - self.check_post_attrs('input', 'V') + stp=stp, + name=name, + trainable=trainable) # parameters self.beta1 = beta1 @@ -540,50 +534,27 @@ def __init__( raise ValueError(f'"T_0" must be a scalar or a tensor with size of 1. But we got {T_0}') if bm.size(T_dur) != 1: raise ValueError(f'"T_dur" must be a scalar or a tensor with size of 1. But we got {T_dur}') + self.comp_method = comp_method + self.stop_spike_gradient = stop_spike_gradient # connections and weights - self.conn_type = conn_type - if conn_type not in ['sparse', 'dense']: - raise ValueError(f'"conn_type" must be in "sparse" and "dense", but we got {conn_type}') - if self.conn is None: - raise ValueError(f'Must provide "conn" when initialize the model {self.name}') - if isinstance(self.conn, One2One): - self.g_max = init_param(g_max, (self.pre.num,), allow_none=False) - self.weight_type = 'heter' if bm.size(self.g_max) != 1 else 'homo' - elif isinstance(self.conn, All2All): - self.g_max = init_param(g_max, (self.pre.num, self.post.num), allow_none=False) - if bm.size(self.g_max) != 1: - self.weight_type = 'heter' - bm.fill_diagonal(self.g_max, 0.) - else: - self.weight_type = 'homo' - else: - if conn_type == 'sparse': - self.pre_ids, self.post_ids = self.conn.require('pre_ids', 'post_ids') - self.g_max = init_param(g_max, self.post_ids.shape, allow_none=False) - self.weight_type = 'heter' if bm.size(self.g_max) != 1 else 'homo' - elif conn_type == 'dense': - self.g_max = init_param(g_max, (self.pre.num, self.post.num), allow_none=False) - self.weight_type = 'heter' if bm.size(self.g_max) != 1 else 'homo' - if self.weight_type == 'homo': - self.conn_mat = self.conn.require('conn_mat') - else: - raise ValueError(f'Unknown connection type: {conn_type}') + self.g_max, self.conn_mask = self.init_weights(g_max, comp_method, sparse_data='ij') # variables - self.g = bm.Variable(bm.zeros(self.pre.num)) - self.x = bm.Variable(bm.zeros(self.pre.num)) - self.spike_arrival_time = bm.Variable(bm.ones(self.pre.num) * -1e7) + self.g = variable(bm.zeros, trainable, self.pre.num) + self.x = variable(bm.zeros, trainable, self.pre.num) + self.spike_arrival_time = variable(lambda s: bm.ones(s) * -1e7, trainable, self.pre.num) self.delay_step = self.register_delay(f"{self.pre.name}.spike", delay_step, self.pre.spike) # integral self.integral = odeint(method=method, f=JointEq([self.dg, self.dx])) - def reset(self): - self.g.value = bm.zeros(self.pre.num) - self.x.value = bm.zeros(self.pre.num) - self.spike_arrival_time.value = bm.ones(self.pre.num) * -1e7 - self.plasticity.reset() + def reset_state(self, batch_size=None): + self.g = variable(bm.zeros, batch_size, self.pre.num) + self.x = variable(bm.zeros, batch_size, self.pre.num) + self.spike_arrival_time = variable(lambda s: bm.ones(s) * -1e7, batch_size, self.pre.num) + self.stp.reset_state(batch_size) + self.output.reset_state(batch_size) def dg(self, g, t, x): return self.alpha1 * x * (1 - g) - self.beta1 * g @@ -591,37 +562,41 @@ def dg(self, g, t, x): def dx(self, x, t, T): return self.alpha2 * T * (1 - x) - self.beta2 * x - def update(self, t, dt): - # delays - pre_spike = self.get_delay_data(f"{self.pre.name}.spike", self.delay_step) + def update(self, tdi, pre_spike=None): + t, dt = tdi['t'], tdi['dt'] - self.plasticity.update(t, dt, pre_spike, self.post.spike) + # pre-synaptic spikes + if pre_spike is None: + pre_spike = self.get_delay_data(f"{self.pre.name}.spike", self.delay_step) + if self.stop_spike_gradient: + pre_spike = pre_spike.value if isinstance(pre_spike, bm.JaxArray) else pre_spike + pre_spike = stop_gradient(pre_spike) + + # update sub-components + self.output.update(tdi) + self.stp.update(tdi, pre_spike) # update synapse variables self.spike_arrival_time.value = bm.where(pre_spike, t, self.spike_arrival_time) + if self.trainable: + self.spike_arrival_time.value = stop_gradient(self.spike_arrival_time.value) T = ((t - self.spike_arrival_time) < self.T_dur) * self.T_0 - self.g.value, self.x.value = self.integral(self.g, self.x, t, T, dt=dt) + self.g.value, self.x.value = self.integral(self.g, self.x, t, T, dt) # post-synaptic value - syn_value = self.plasticity.filter(self.g.value) + syn_value = self.stp(self.g) if isinstance(self.conn, All2All): - if self.weight_type == 'homo': - post_g = bm.sum(syn_value) - if not self.conn.include_self: - post_g = post_g - syn_value - post_g = post_g * self.g_max - else: - post_g = syn_value @ self.g_max + post_vs = self.syn2post_with_all2all(syn_value, self.g_max) elif isinstance(self.conn, One2One): - post_g = self.g_max * syn_value + post_vs = self.syn2post_with_one2one(syn_value, self.g_max) else: - if self.conn_type == 'sparse': - post_g = bm.pre2post_sum(syn_value, self.post.num, self.post_ids, self.pre_ids) + if self.comp_method == 'sparse': + f = lambda s: bm.pre2post_sum(s, self.post.num, *self.conn_mask) + if self.trainable: f = vmap(f) + post_vs = f(syn_value) else: - if self.weight_type == 'homo': - post_g = (self.g_max * syn_value) @ self.conn_mat - else: - post_g = syn_value @ self.g_max + post_vs = self.syn2post_with_dense(syn_value, self.g_max, self.conn_mask) + post_vs = self.output(post_vs) # output - self.post.input += self.output.filter(post_g) + self.post.input += post_vs diff --git a/brainpy/dyn/synapses/compat.py b/brainpy/dyn/synapses/compat.py index b76a6ee71..2aa5f5fed 100644 --- a/brainpy/dyn/synapses/compat.py +++ b/brainpy/dyn/synapses/compat.py @@ -36,7 +36,7 @@ def __init__( conn_type: str = 'sparse', weights: Union[float, Tensor, Initializer, Callable] = 1., delay_step: Union[float, Tensor, Initializer, Callable] = None, - post_key: str = 'V', + post_input_key: str = 'V', post_has_ref: bool = False, name: str = None, ): @@ -46,11 +46,11 @@ def __init__( conn=conn, output=CUBA(), name=name, - conn_type=conn_type, + comp_method=conn_type, g_max=weights, delay_step=delay_step, - post_key=post_key, - post_has_ref=post_has_ref) + post_input_key=post_input_key, + post_ref_key='refractory' if post_has_ref else None) class ExpCUBA(Exponential): @@ -77,7 +77,7 @@ def __init__( post=post, conn=conn, name=name, - conn_type=conn_type, + comp_method=conn_type, g_max=g_max, delay_step=delay_step, tau=tau, @@ -113,7 +113,7 @@ def __init__( super(ExpCOBA, self).__init__(pre=pre, post=post, conn=conn, - conn_type=conn_type, + comp_method=conn_type, g_max=g_max, delay_step=delay_step, tau=tau, @@ -146,7 +146,7 @@ def __init__( super(DualExpCUBA, self).__init__(pre=pre, post=post, conn=conn, - conn_type=conn_type, + comp_method=conn_type, g_max=g_max, tau_decay=tau_decay, tau_rise=tau_rise, @@ -182,7 +182,7 @@ def __init__( super(DualExpCOBA, self).__init__(pre=pre, post=post, conn=conn, - conn_type=conn_type, + comp_method=conn_type, g_max=g_max, tau_decay=tau_decay, tau_rise=tau_rise, diff --git a/brainpy/dyn/rates/couplings.py b/brainpy/dyn/synapses/couplings.py similarity index 96% rename from brainpy/dyn/rates/couplings.py rename to brainpy/dyn/synapses/couplings.py index 11e2ac5d9..d0a83f50e 100644 --- a/brainpy/dyn/rates/couplings.py +++ b/brainpy/dyn/synapses/couplings.py @@ -180,19 +180,19 @@ def __init__( self.coupling_var1 = coupling_var1 self.coupling_var2 = coupling_var2 - def update(self, t, dt): + def update(self, tdi): # delays if self.delay_steps is None: diffusive = bm.expand_dims(self.coupling_var1, axis=1) - self.coupling_var2 diffusive = (self.conn_mat * diffusive).sum(axis=0) elif self.delay_type == 'array': - delay_var: bm.LengthDelay = self.global_delay_vars[f'delay_{id(self.delay_var)}'] + delay_var: bm.LengthDelay = self.global_delay_data[f'delay_{id(self.delay_var)}'][0] f = vmap(lambda i: delay_var(self.delay_steps[i], bm.arange(self.coupling_var2.size))) # (post.num,) delays = f(bm.arange(self.coupling_var1.size).value) # (pre.num, post.num) diffusive = delays - self.coupling_var2 # (pre.num, post.num) diffusive = (self.conn_mat * diffusive).sum(axis=0) elif self.delay_type == 'int': - delay_var: bm.LengthDelay = self.global_delay_vars[f'delay_{id(self.delay_var)}'] + delay_var: bm.LengthDelay = self.global_delay_data[f'delay_{id(self.delay_var)}'][0] delayed_var = delay_var(self.delay_steps) diffusive = bm.expand_dims(delayed_var, axis=1) - self.coupling_var2 diffusive = (self.conn_mat * diffusive).sum(axis=0) @@ -261,12 +261,12 @@ def update(self, t, dt): if self.delay_steps is None: additive = self.coupling_var @ self.conn_mat elif self.delay_type == 'array': - delay_var: bm.LengthDelay = self.global_delay_vars[f'delay_{id(self.delay_var)}'] + delay_var: bm.LengthDelay = self.global_delay_data[f'delay_{id(self.delay_var)}'][0] f = vmap(lambda i: delay_var(self.delay_steps[i], bm.arange(self.coupling_var.size))) # (pre.num,) delays = f(bm.arange(self.coupling_var.size).value) # (post.num, pre.num) additive = (self.conn_mat * delays.T).sum(axis=0) elif self.delay_type == 'int': - delay_var: bm.LengthDelay = self.global_delay_vars[f'delay_{id(self.delay_var)}'] + delay_var: bm.LengthDelay = self.global_delay_data[f'delay_{id(self.delay_var)}'][0] delayed_var = delay_var(self.delay_steps) additive = (self.conn_mat * delayed_var).sum(axis=0) else: diff --git a/brainpy/dyn/synapses/gap_junction.py b/brainpy/dyn/synapses/gap_junction.py index b30d9d4c4..1875a47d8 100644 --- a/brainpy/dyn/synapses/gap_junction.py +++ b/brainpy/dyn/synapses/gap_junction.py @@ -4,8 +4,8 @@ import brainpy.math as bm from brainpy.connect import TwoEndConnector -from brainpy.dyn.base import NeuGroup, SynapseOutput, SynapsePlasticity, TwoEndConn -from brainpy.initialize import Initializer, init_param +from brainpy.dyn.base import NeuGroup, SynOutput, SynSTP, TwoEndConn +from brainpy.initialize import Initializer, parameter from brainpy.types import Tensor from ..synouts import CUBA @@ -20,48 +20,44 @@ def __init__( pre: NeuGroup, post: NeuGroup, conn: Union[TwoEndConnector, Tensor, Dict[str, Tensor]], - conn_type: str = 'dense', - output: SynapseOutput = None, - plasticity: Optional[SynapsePlasticity] = None, + comp_method: str = 'dense', g_max: Union[float, Tensor, Initializer, Callable] = 1., name: str = None, ): super(GapJunction, self).__init__(pre=pre, post=post, conn=conn, - output=CUBA() if output is None else output, - plasticity=plasticity, name=name) # checking self.check_pre_attrs('V', 'spike') self.check_post_attrs('V', 'input', 'spike') + # assert isinstance(self.output, _NullSynOut) + # assert isinstance(self.stp, _NullSynSTP) + # connections - self.conn_type = conn_type - if conn_type == 'dense': + self.comp_method = comp_method + if comp_method == 'dense': self.conn_mat = self.conn.require('conn_mat') - self.weights = init_param(g_max, (pre.num, post.num), allow_none=False) - elif conn_type == 'sparse': + self.weights = parameter(g_max, (pre.num, post.num), allow_none=False) + elif comp_method == 'sparse': self.pre_ids, self.post_ids = self.conn.require('pre_ids', 'post_ids') - self.weights = init_param(g_max, self.pre_ids.shape, allow_none=False) + self.weights = parameter(g_max, self.pre_ids.shape, allow_none=False) else: raise ValueError - def update(self, t, dt): - self.output.update(t, dt) - self.plasticity.update(t, dt, self.pre.spike, self.post.spike) - if self.conn_type == 'dense': + def update(self, tdi): + if self.comp_method == 'dense': # pre -> post diff = (self.pre.V.reshape((-1, 1)) - self.post.V) * self.conn_mat * self.weights - self.post.input += self.output.filter(bm.einsum('ij->j', diff)) + self.post.input += bm.einsum('ij->j', diff) # post -> pre - self.pre.input += self.output.filter(bm.einsum('ij->i', -diff)) + self.pre.input += bm.einsum('ij->i', -diff) else: diff = (self.pre.V[self.pre_ids] - self.post.V[self.post_ids]) * self.weights - self.post.input += self.output.filter(bm.syn2post_sum(diff, self.post_ids, self.post.num)) - self.pre.input += self.output.filter(bm.syn2post_sum(-diff, self.pre_ids, self.pre.num)) + self.post.input += bm.syn2post_sum(diff, self.post_ids, self.post.num) + self.pre.input += bm.syn2post_sum(-diff, self.pre_ids, self.pre.num) - def reset(self): - self.output.reset() - self.plasticity.reset() + def reset_state(self, batch_size=None): + pass diff --git a/brainpy/dyn/synapses/learning_rules.py b/brainpy/dyn/synapses/learning_rules.py index 8e10328e4..cb49ecc40 100644 --- a/brainpy/dyn/synapses/learning_rules.py +++ b/brainpy/dyn/synapses/learning_rules.py @@ -5,8 +5,7 @@ import brainpy.math as bm from brainpy.connect import TwoEndConnector from brainpy.dyn.base import NeuGroup, TwoEndConn -from brainpy.initialize import Initializer -from brainpy.dyn.utils import init_delay +from brainpy.initialize import Initializer, delay as init_delay from brainpy.integrators import odeint, JointEq from brainpy.types import Tensor, Parameter @@ -189,7 +188,6 @@ def __init__( ): super(STP, self).__init__(pre=pre, post=post, conn=conn, name=name) self.check_post_attrs('input') - self.check_pre_attrs('spike') # parameters self.tau_d = tau_d diff --git a/brainpy/dyn/synapses/others.py b/brainpy/dyn/synapses/others.py deleted file mode 100644 index 0911a5652..000000000 --- a/brainpy/dyn/synapses/others.py +++ /dev/null @@ -1,35 +0,0 @@ -# -*- coding: utf-8 -*- - -from typing import Union, Dict, Callable, Optional - -import brainpy.math as bm -from brainpy.connect import TwoEndConnector, All2All, One2One -from brainpy.dyn.base import NeuGroup, SynapseOutput, SynapsePlasticity, TwoEndConn -from brainpy.initialize import Initializer, init_param -from brainpy.integrators import odeint, JointEq -from brainpy.types import Tensor -from ..synouts import CUBA, MgBlock - -__all__ = [ - 'WeightedSum', -] - - -class WeightedSum(TwoEndConn): - def __init__( - self, - pre: NeuGroup, - post: NeuGroup, - conn: Union[TwoEndConnector, Tensor, Dict[str, Tensor]], - output: Optional[SynapseOutput] = None, - plasticity: Optional[SynapsePlasticity] = None, - conn_type: str = 'sparse', - weights: Union[float, Tensor, Initializer, Callable] = 1., - delay_step: Union[float, Tensor, Initializer, Callable] = None, - post_key: str = 'V', - post_has_ref: bool = False, - name: str = None, - ): - super(WeightedSum, self).__init__(pre, post, conn, name=name) - - diff --git a/brainpy/dyn/synouts/conductances.py b/brainpy/dyn/synouts/conductances.py index 3c06a133e..8e429ecb7 100644 --- a/brainpy/dyn/synouts/conductances.py +++ b/brainpy/dyn/synouts/conductances.py @@ -2,8 +2,8 @@ from typing import Union, Callable -from brainpy.dyn.base import SynapseOutput -from brainpy.initialize import init_param, Initializer +from brainpy.dyn.base import SynOutput +from brainpy.initialize import parameter, Initializer from brainpy.types import Tensor __all__ = [ @@ -12,7 +12,7 @@ ] -class CUBA(SynapseOutput): +class CUBA(SynOutput): r"""Current-based synaptic output. Given the conductance, this model outputs the post-synaptic current with a identity function: @@ -35,11 +35,14 @@ class CUBA(SynapseOutput): def __init__(self, name: str = None): super(CUBA, self).__init__(name=name) + def update(self, tdi): + pass + def filter(self, g): return g -class COBA(SynapseOutput): +class COBA(SynOutput): r"""Conductance-based synaptic output. Given the synaptic conductance, the model output the post-synaptic current with @@ -70,7 +73,10 @@ def __init__( def register_master(self, master): super(COBA, self).register_master(master) - self.E = init_param(self._E, self.master.post.num, allow_none=False) + self.E = parameter(self._E, self.master.post.num, allow_none=False) def filter(self, g): return g * (self.E - self.master.post.V) + + def update(self, tdi): + pass diff --git a/brainpy/dyn/synouts/ions.py b/brainpy/dyn/synouts/ions.py index 8b749d4d3..2a0ad5d80 100644 --- a/brainpy/dyn/synouts/ions.py +++ b/brainpy/dyn/synouts/ions.py @@ -3,8 +3,8 @@ from typing import Union, Callable import brainpy.math as bm -from brainpy.dyn.base import SynapseOutput -from brainpy.initialize import init_param, Initializer +from brainpy.dyn.base import SynOutput +from brainpy.initialize import parameter, Initializer from brainpy.types import Tensor @@ -13,7 +13,7 @@ ] -class MgBlock(SynapseOutput): +class MgBlock(SynOutput): r"""Synaptic output based on Magnesium blocking. Given the synaptic conductance, the model output the post-synaptic current with @@ -60,12 +60,15 @@ def __init__( def register_master(self, master): super(MgBlock, self).register_master(master) - self.E = init_param(self.E, self.master.post.num, allow_none=False) - self.cc_Mg = init_param(self.cc_Mg, self.master.post.num, allow_none=False) - self.alpha = init_param(self.alpha, self.master.post.num, allow_none=False) - self.beta = init_param(self.beta, self.master.post.num, allow_none=False) + self.E = parameter(self.E, self.master.post.num, allow_none=False) + self.cc_Mg = parameter(self.cc_Mg, self.master.post.num, allow_none=False) + self.alpha = parameter(self.alpha, self.master.post.num, allow_none=False) + self.beta = parameter(self.beta, self.master.post.num, allow_none=False) def filter(self, g): V = self.master.post.V.value return g * (self.E - V) / (1 + self.cc_Mg / self.beta * bm.exp(-self.alpha * V)) + def update(self, tdi): + pass + diff --git a/brainpy/dyn/synplast/long_term_plasticity.py b/brainpy/dyn/synplast/long_term_plasticity.py new file mode 100644 index 000000000..40a96afc6 --- /dev/null +++ b/brainpy/dyn/synplast/long_term_plasticity.py @@ -0,0 +1 @@ +# -*- coding: utf-8 -*- diff --git a/brainpy/dyn/synplast/short_term_plasticity.py b/brainpy/dyn/synplast/short_term_plasticity.py index 4a3bc7bda..5bd68f2d6 100644 --- a/brainpy/dyn/synplast/short_term_plasticity.py +++ b/brainpy/dyn/synplast/short_term_plasticity.py @@ -3,10 +3,11 @@ from typing import Union import brainpy.math as bm -from brainpy.dyn.base import SynapsePlasticity +from brainpy.dyn.base import SynSTP from brainpy.integrators import odeint, JointEq from brainpy.tools.checking import check_float from brainpy.types import Tensor +from brainpy.initialize import variable __all__ = [ 'STD', @@ -14,7 +15,7 @@ ] -class STD(SynapsePlasticity): +class STD(SynSTP): r"""Synaptic output with short-term depression. This model filters the synaptic current by the following equation: @@ -52,9 +53,10 @@ def __init__( self, tau: float = 200., U: float = 0.07, - method: str = 'exp_auto' + method: str = 'exp_auto', + name: str = None ): - super(STD, self).__init__() + super(STD, self).__init__(name=name) # parameters check_float(tau, 'tau', min_bound=0, ) @@ -70,20 +72,22 @@ def register_master(self, master): super(STD, self).register_master(master) # variables - self.x = bm.Variable(bm.ones(self.master.pre.num)) + self.x = variable(bm.ones, self.master.trainable, self.master.pre.num) - def reset(self): - self.x[:] = 1. + def reset_state(self, batch_size=None): + self.x.value = variable(bm.ones, batch_size, self.master.pre.num) - def update(self, t, dt, pre_spike=None, post_spike=None): - x = self.integral(self.x.value, t, dt) + def update(self, tdi, pre_spike): + x = self.integral(self.x.value, tdi['t'], tdi['dt']) self.x.value = bm.where(pre_spike, x - self.U * self.x, x) def filter(self, g): + if bm.shape(g) != self.x.shape: + raise ValueError('Shape does not match.') return g * self.x -class STP(SynapsePlasticity): +class STP(SynSTP): r"""Synaptic output with short-term plasticity. This model filters the synaptic currents according to two variables: :math:`u` and :math:`x`. @@ -152,8 +156,12 @@ def register_master(self, master): super(STP, self).register_master(master) # variables - self.x = bm.Variable(bm.ones(self.master.pre.num)) - self.u = bm.Variable(bm.ones(self.master.pre.num) * self.U) + self.x = variable(bm.ones, self.master.trainable, self.master.pre.num) + self.u = variable(lambda s: bm.ones(s) * self.U, self.master.trainable, self.master.pre.num) + + def reset_state(self, batch_size=None): + self.x.value = variable(bm.ones, batch_size, self.master.pre.num) + self.u.value = variable(lambda s: bm.ones(s) * self.U, batch_size, self.master.pre.num) @property def derivative(self): @@ -161,16 +169,14 @@ def derivative(self): dx = lambda x, t: (1 - x) / self.tau_d return JointEq([du, dx]) - def reset(self): - self.x[:] = 1. - self.u[:] = self.U - - def update(self, t, dt, pre_spike=None, post_spike=None): - u, x = self.integral(self.u.value, self.x.value, t, dt) + def update(self, tdi, pre_spike): + u, x = self.integral(self.u.value, self.x.value, tdi['t'], tdi['dt']) u = bm.where(pre_spike, u + self.U * (1 - self.u), u) x = bm.where(pre_spike, x - u * self.x, x) self.x.value = x self.u.value = u def filter(self, g): + if bm.shape(g) != self.x.shape: + raise ValueError('Shape does not match.') return g * self.x * self.u diff --git a/brainpy/dyn/tests/test_dyn_runner.py b/brainpy/dyn/tests/test_dyn_runner.py index 911e619e8..a191ad6ad 100644 --- a/brainpy/dyn/tests/test_dyn_runner.py +++ b/brainpy/dyn/tests/test_dyn_runner.py @@ -13,7 +13,7 @@ def __init__(self): super(ExampleDS, self).__init__() self.i = bm.Variable(bm.zeros(1)) - def update(self, t, dt): + def update(self, tdi): self.i += 1 ds = ExampleDS() @@ -26,8 +26,8 @@ def __init__(self): super(ExampleDS, self).__init__() self.i = bm.Variable(bm.zeros(1)) - def update(self, t, dt): - self.i += 1 * dt + def update(self, tdi): + self.i += 1 * tdi.dt runner = bp.dyn.DSRunner(ExampleDS(), dt=1., monitors=['i'], progress_bar=False) runner.run(100.) diff --git a/brainpy/dyn/training.py b/brainpy/dyn/training.py new file mode 100644 index 000000000..f1f1f4043 --- /dev/null +++ b/brainpy/dyn/training.py @@ -0,0 +1,216 @@ +# -*- coding: utf-8 -*- + +import inspect +from typing import Union, Callable, Optional, Dict, Any + +from brainpy.dyn.base import DynamicalSystem +from brainpy.errors import NoImplementationError +from brainpy.algorithms import OfflineAlgorithm, OnlineAlgorithm +from brainpy.types import Tensor + +__all__ = [ + 'TrainingSystem', 'Sequential', +] + + +def not_customized(fun: Callable) -> Callable: + """Marks the given module method is not implemented. + + Methods wrapped in @not_customized can define submodules directly within the method. + + For instance:: + + @not_customized + init_fb(self): + ... + + @not_customized + def feedback(self): + ... + """ + fun.not_implemented = True + return fun + + +class TrainingSystem(DynamicalSystem): + """Base class for training system in BrainPy. + """ + + '''Online fitting method.''' + online_fit_by: Optional[OnlineAlgorithm] + + '''Offline fitting method.''' + offline_fit_by: Optional[OfflineAlgorithm] + + def __init__(self, name: str = None, trainable: bool = False): + super(TrainingSystem, self).__init__(name=name, trainable=trainable) + + self.online_fit_by = None + self.offline_fit_by = None + self.fit_record = dict() + + def reset(self, batch_size=1): + for node in self.nodes(level=1, include_self=False).unique().subset(TrainingSystem).values(): + node.reset(batch_size=batch_size) + + def reset_state(self, batch_size=1): + for node in self.nodes(level=1, include_self=False).unique().subset(TrainingSystem).values(): + node.reset_state(batch_size=batch_size) + + def __repr__(self): + return f"{type(self).__name__}(name={self.name}, trainable={self.trainable})" + + def __call__(self, *args, **kwargs) -> Tensor: + """The main computation function of a Node. + + Returns + ------- + Tensor + A output tensor value, or a dict of output tensors. + """ + return self.update(*args, **kwargs) + + @not_customized + def update(self, sha: dict, x) -> Tensor: + """Update function of a training system. + + Parameters + ---------- + sha: dict + The shared arguments (ShA) across multiple layers. + x: Any + The input information. + + Returns + ------- + y: Tensor + The output tensor. + """ + raise NotImplementedError('Subclass should implement "update()" function ' + 'when "update()" function is not customized.') + + @not_customized + def online_init(self): + raise NoImplementationError('Subclass must implement online_init() function when using OnlineTrainer.') + + @not_customized + def offline_init(self): + raise NoImplementationError('Subclass must implement offline_init() function when using OfflineTrainer.') + + @not_customized + def online_fit(self, + target: Tensor, + fit_record: Dict[str, Tensor]): + raise NoImplementationError('Subclass must implement online_fit() function when using OnlineTrainer.') + + @not_customized + def offline_fit(self, + target: Tensor, + fit_record: Dict[str, Tensor]): + raise NoImplementationError('Subclass must implement offline_fit() function when using OfflineTrainer.') + + +class Sequential(TrainingSystem): + def __init__(self, *modules, name: str = None, **kw_modules): + super(Sequential, self).__init__(name=name, trainable=False) + + # add sub-components + for module in modules: + if isinstance(module, TrainingSystem): + self.implicit_nodes[module.name] = module + elif isinstance(module, (list, tuple)): + for m in module: + if not isinstance(m, TrainingSystem): + raise ValueError(f'Should be instance of {TrainingSystem.__name__}. ' + f'But we got {type(m)}') + self.implicit_nodes[m.name] = module + elif isinstance(module, dict): + for k, v in module.items(): + if not isinstance(v, TrainingSystem): + raise ValueError(f'Should be instance of {TrainingSystem.__name__}. ' + f'But we got {type(v)}') + self.implicit_nodes[k] = v + else: + raise ValueError(f'Cannot parse sub-systems. They should be {TrainingSystem.__name__} ' + f'or a list/tuple/dict of {TrainingSystem.__name__}.') + for k, v in kw_modules.items(): + if not isinstance(v, TrainingSystem): + raise ValueError(f'Should be instance of {TrainingSystem.__name__}. ' + f'But we got {type(v)}') + self.implicit_nodes[k] = v + + def __getattr__(self, item): + """Wrap the dot access ('self.'). """ + child_ds = super(Sequential, self).__getattribute__('implicit_nodes') + if item in child_ds: + return child_ds[item] + else: + return super(Sequential, self).__getattribute__(item) + + def __getitem__(self, key: Union[int, slice]): + if isinstance(key, str): + if key not in self.implicit_nodes: + raise KeyError(f'Does not find a component named {key} in\n {str(self)}') + return self.implicit_nodes[key] + elif isinstance(key, slice): + keys = tuple(self.implicit_nodes.keys())[key] + components = tuple(self.implicit_nodes.values())[key] + return Sequential(dict(zip(keys, components))) + elif isinstance(key, int): + return self.implicit_nodes.values()[key] + elif isinstance(key, (tuple, list)): + all_keys = tuple(self.implicit_nodes.keys()) + all_vals = tuple(self.implicit_nodes.values()) + keys, vals = [], [] + for i in key: + if isinstance(i, int): + raise KeyError(f'We excepted a tuple/list of int, but we got {type(i)}') + keys.append(all_keys[i]) + vals.append(all_vals[i]) + return Sequential(dict(zip(keys, vals))) + else: + raise KeyError(f'Unknown type of key: {type(key)}') + + def __repr__(self): + def f(x): + if not isinstance(x, TrainingSystem) and callable(x): + signature = inspect.signature(x) + args = [f'{k}={v.default}' for k, v in signature.parameters.items() + if v.default is not inspect.Parameter.empty] + args = ', '.join(args) + while not hasattr(x, '__name__'): + if not hasattr(x, 'func'): + break + x = x.func # Handle functools.partial + if not hasattr(x, '__name__') and hasattr(x, '__class__'): + return x.__class__.__name__ + if args: + return f'{x.__name__}(*, {args})' + return x.__name__ + else: + x = repr(x).split('\n') + x = [x[0]] + [' ' + y for y in x[1:]] + return '\n'.join(x) + + entries = '\n'.join(f' [{i}] {f(x)}' for i, x in enumerate(self)) + return f'{self.__class__.__name__}(\n{entries}\n)' + + def update(self, sha: dict, x: Any) -> Tensor: + """Update function of a training system. + + Parameters + ---------- + sha: dict + The shared arguments (ShA) across multiple layers. + x: Any + The input information. + + Returns + ------- + y: Tensor + The output tensor. + """ + for node in self.implicit_nodes.values(): + x = node(sha, x) + return x + diff --git a/brainpy/dyn/utils.py b/brainpy/dyn/utils.py deleted file mode 100644 index 98a4dc32f..000000000 --- a/brainpy/dyn/utils.py +++ /dev/null @@ -1,117 +0,0 @@ -# -*- coding: utf-8 -*- - - -from typing import Union, Callable, Optional, Dict - -import jax.numpy as jnp -import numpy as np -from jax.tree_util import tree_flatten - -from brainpy import math as bm -from brainpy.initialize import init_param, Initializer -from brainpy.types import Shape -from brainpy.tools.checking import check_dict_data - -__all__ = [ - 'init_noise', - 'init_noise', -] - - -def serialize_kwargs(shared_kwargs: Optional[Dict]): - """Serialize kwargs.""" - shared_kwargs = dict() if shared_kwargs is None else shared_kwargs - check_dict_data(shared_kwargs, - key_type=str, - val_type=(bool, float, int, complex), - name='shared_kwargs') - shared_kwargs = {key: shared_kwargs[key] for key in sorted(shared_kwargs.keys())} - return str(shared_kwargs) - -def check_data_batch_size(data, num_batch=None, batch_idx=0): - leaves, tree = tree_flatten(data, is_leaf=lambda x: isinstance(x, bm.JaxArray)) - batches = [leaf.shape[batch_idx] for leaf in leaves] - if len(set(batches)) != 1: - raise ValueError('Batch sizes are not consistent among the given data. ' - f'Got {set(batches)}. We expect only one batch size.') - batch_size = batches[0] - if (num_batch is not None) and batch_size != num_batch: - raise ValueError(f'Batch size is not consistent with the expected {batch_size} != {num_batch}') - return batch_size - - -def init_noise( - noise: Optional[Union[int, bm.ndarray, jnp.ndarray, Initializer, Callable]], - size: Shape, - num_vars: int = 1, - noise_idx: int = 0, -) -> Optional[Callable]: - if callable(noise): - return noise - elif noise is None: - return None - else: - noise = init_param(noise, size, allow_none=False) - if num_vars > 1: - noises = [None] * num_vars - noises[noise_idx] = noise - noise = tuple(noises) - return lambda *args, **kwargs: noise - - -def init_delay( - delay_step: Union[int, bm.ndarray, jnp.ndarray, Callable, Initializer], - delay_target: Union[bm.ndarray, jnp.ndarray], - delay_data: Union[bm.ndarray, jnp.ndarray] = None -): - """Initialize delay variable. - - Parameters - ---------- - delay_step: int, ndarray, JaxArray - The number of delay steps. It can an integer of an array of integers. - delay_target: ndarray, JaxArray - The target variable to delay. - delay_data: optional, ndarray, JaxArray - The initial delay data. - - Returns - ------- - info: tuple - The triple of delay type, delay steps, and delay variable. - """ - # check delay type - if delay_step is None: - delay_type = 'none' - elif isinstance(delay_step, int): - delay_type = 'homo' - elif isinstance(delay_step, (bm.ndarray, jnp.ndarray, np.ndarray)): - delay_type = 'heter' - delay_step = bm.asarray(delay_step) - elif callable(delay_step): - delay_step = init_param(delay_step, delay_target.shape, allow_none=False) - delay_type = 'heter' - else: - raise ValueError(f'Unknown "delay_steps" type {type(delay_step)}, only support ' - f'integer, array of integers, callable function, brainpy.init.Initializer.') - if delay_type == 'heter': - if delay_step.dtype not in [bm.int32, bm.int64]: - raise ValueError('Only support delay steps of int32, int64. If your ' - 'provide delay time length, please divide the "dt" ' - 'then provide us the number of delay steps.') - if delay_target.shape[0] != delay_step.shape[0]: - raise ValueError(f'Shape is mismatched: {delay_target.shape[0]} != {delay_step.shape[0]}') - - # init delay data - if delay_type == 'homo': - delays = bm.LengthDelay(delay_target, delay_step, initial_delay_data=delay_data) - elif delay_type == 'heter': - if delay_step.size != delay_target.size: - raise ValueError('Heterogeneous delay must have a length ' - f'of the delay target {delay_target.shape}, ' - f'while we got {delay_step.shape}') - delays = bm.LengthDelay(delay_target, int(delay_step.max())) - else: - delays = None - - return delay_type, delay_step, delays diff --git a/brainpy/initialize/generic.py b/brainpy/initialize/generic.py index 69758089d..6a2d8b7fd 100644 --- a/brainpy/initialize/generic.py +++ b/brainpy/initialize/generic.py @@ -1,22 +1,34 @@ # -*- coding: utf-8 -*- -from typing import Union, Callable +from typing import Union, Callable, Optional import jax.numpy as jnp -import numpy as onp +import numpy as np import brainpy.math as bm from brainpy.tools.others import to_size -from brainpy.types import Shape +from brainpy.types import Shape, Tensor from .base import Initializer + __all__ = [ + 'parameter', + 'variable', + 'noise', + 'delay', + + # deprecated 'init_param', ] -def init_param( - param: Union[Callable, Initializer, bm.ndarray, jnp.ndarray, float, int, bool], +def parameter( + param: Union[Callable, + Initializer, + bm.ndarray, + np.ndarray, + jnp.ndarray, + float, int, bool], size: Shape, allow_none: bool = True, ): @@ -24,7 +36,7 @@ def init_param( Parameters ---------- - param: callable, Initializer, bm.ndarray, jnp.ndarray, float, int, bool + param: callable, Initializer, bm.ndarray, jnp.ndarray, onp.ndarray, float, int, bool The initialization of the parameter. - If it is None, the created parameter will be None. - If it is a callable function :math:`f`, the ``f(size)`` will be returned. @@ -40,18 +52,18 @@ def init_param( param: JaxArray, float, None The initialized parameter. """ - size = to_size(size) if param is None: if allow_none: return None else: raise ValueError(f'Expect a parameter with type of float, JaxArray, Initializer, or ' f'Callable function, but we got None. ') - elif isinstance(param, (float, int, bool)): + size = to_size(size) + if isinstance(param, (float, int, bool)): return param elif callable(param): param = bm.asarray(param(size)) - elif isinstance(param, (onp.ndarray, jnp.ndarray)): + elif isinstance(param, (np.ndarray, jnp.ndarray)): param = bm.asarray(param) elif isinstance(param, bm.Variable): param = param @@ -63,3 +75,113 @@ def init_param( raise ValueError(f'The shape of the parameters should be (), (1,) ' f'or {size}, but we got {param.shape}') return param + + +def init_param( + param: Union[Callable, Initializer, bm.ndarray, jnp.ndarray, float, int, bool], + size: Shape, + allow_none: bool = True, +): + return parameter(param, size, allow_none) + + +def variable( + data: Union[Callable, Tensor], + batch_size: Optional[Union[int, bool]] = None, + var_shape: Shape = None +): + var_shape = to_size(var_shape) + if callable(data): + if var_shape is None: + raise ValueError('"varshape" cannot be None when data is a callable function.') + if batch_size in (None, False): + return bm.Variable(data(var_shape)) + else: + return bm.Variable(data((int(batch_size),) + var_shape), batch_axis=0) + else: + if var_shape is not None: + if bm.shape(data) != var_shape: + raise ValueError(f'The shape of "data" {bm.shape(data)} does not match with "var_shape" {var_shape}') + if batch_size in (None, False): + return bm.Variable(data) + else: + data = bm.expand_dims(data, axis=0) + return bm.Variable(bm.repeat(data, int(batch_size), axis=0), batch_axis=0) + + +def noise( + noises: Optional[Union[int, bm.ndarray, jnp.ndarray, Initializer, Callable]], + size: Shape, + num_vars: int = 1, + noise_idx: int = 0, +) -> Optional[Callable]: + if callable(noises): + return noises + elif noises is None: + return None + else: + noises = parameter(noises, size, allow_none=False) + if num_vars > 1: + noises_ = [None] * num_vars + noises_[noise_idx] = noises + noises = tuple(noises_) + return lambda *args, **kwargs: noises + + +def delay( + delay_step: Union[int, bm.ndarray, jnp.ndarray, Callable, Initializer], + delay_target: Union[bm.ndarray, jnp.ndarray], + delay_data: Union[bm.ndarray, jnp.ndarray] = None +): + """Initialize delay variable. + + Parameters + ---------- + delay_step: int, ndarray, JaxArray + The number of delay steps. It can an integer of an array of integers. + delay_target: ndarray, JaxArray + The target variable to delay. + delay_data: optional, ndarray, JaxArray + The initial delay data. + + Returns + ------- + info: tuple + The triple of delay type, delay steps, and delay variable. + """ + # check delay type + if delay_step is None: + delay_type = 'none' + elif isinstance(delay_step, int): + delay_type = 'homo' + elif isinstance(delay_step, (bm.ndarray, jnp.ndarray, np.ndarray)): + delay_type = 'heter' + delay_step = bm.asarray(delay_step) + elif callable(delay_step): + delay_step = parameter(delay_step, delay_target.shape, allow_none=False) + delay_type = 'heter' + else: + raise ValueError(f'Unknown "delay_steps" type {type(delay_step)}, only support ' + f'integer, array of integers, callable function, brainpy.init.Initializer.') + if delay_type == 'heter': + if delay_step.dtype not in [bm.int32, bm.int64]: + raise ValueError('Only support delay steps of int32, int64. If your ' + 'provide delay time length, please divide the "dt" ' + 'then provide us the number of delay steps.') + if delay_target.shape[0] != delay_step.shape[0]: + raise ValueError(f'Shape is mismatched: {delay_target.shape[0]} != {delay_step.shape[0]}') + + # init delay data + if delay_type == 'homo': + delays = bm.LengthDelay(delay_target, delay_step, initial_delay_data=delay_data) + elif delay_type == 'heter': + if delay_step.size != delay_target.size: + raise ValueError('Heterogeneous delay must have a length ' + f'of the delay target {delay_target.shape}, ' + f'while we got {delay_step.shape}') + delays = bm.LengthDelay(delay_target, int(delay_step.max())) + else: + delays = None + + return delay_type, delay_step, delays + diff --git a/brainpy/initialize/random_inits.py b/brainpy/initialize/random_inits.py index 1d2f55a9f..379ba451e 100644 --- a/brainpy/initialize/random_inits.py +++ b/brainpy/initialize/random_inits.py @@ -46,7 +46,7 @@ def __init__(self, mean=0., scale=1., seed=None): def __call__(self, shape, dtype=None): shape = [tools.size2num(d) for d in shape] - weights = self.rng.normal(size=shape, loc=self.mean, scale=self.scale) + weights = self.rng.normal(size=shape, loc=self.mean, scale=self.scale) return bm.asarray(weights, dtype=dtype) def __repr__(self): @@ -125,62 +125,109 @@ def __repr__(self): class KaimingUniform(VarianceScaling): - def __init__(self, scale=2.0, mode="fan_in", - distribution="uniform", - in_axis=-2, out_axis=-1, - seed=None): - super(KaimingUniform, self).__init__(scale, mode, distribution, - in_axis=in_axis, out_axis=out_axis, + def __init__( + self, + scale=2.0, mode="fan_in", + distribution="uniform", + in_axis=-2, + out_axis=-1, + seed=None + ): + super(KaimingUniform, self).__init__(scale, + mode, + distribution, + in_axis=in_axis, + out_axis=out_axis, seed=seed) class KaimingNormal(VarianceScaling): - def __init__(self, scale=2.0, mode="fan_in", - distribution="truncated_normal", - in_axis=-2, out_axis=-1, - seed=None): - super(KaimingNormal, self).__init__(scale, mode, distribution, - in_axis=in_axis, out_axis=out_axis, + def __init__( + self, + scale=2.0, + mode="fan_in", + distribution="truncated_normal", + in_axis=-2, + out_axis=-1, + seed=None + ): + super(KaimingNormal, self).__init__(scale, + mode, + distribution, + in_axis=in_axis, + out_axis=out_axis, seed=seed) class XavierUniform(VarianceScaling): - def __init__(self, scale=1.0, mode="fan_avg", - distribution="uniform", - in_axis=-2, out_axis=-1, - seed=None): - super(XavierUniform, self).__init__(scale, mode, distribution, - in_axis=in_axis, out_axis=out_axis, + def __init__( + self, + scale=1.0, + mode="fan_avg", + distribution="uniform", + in_axis=-2, + out_axis=-1, + seed=None + ): + super(XavierUniform, self).__init__(scale, + mode, + distribution, + in_axis=in_axis, + out_axis=out_axis, seed=seed) class XavierNormal(VarianceScaling): - def __init__(self, scale=1.0, mode="fan_avg", - distribution="truncated_normal", - in_axis=-2, out_axis=-1, - seed=None): - super(XavierNormal, self).__init__(scale, mode, distribution, - in_axis=in_axis, out_axis=out_axis, + def __init__( + self, + scale=1.0, + mode="fan_avg", + distribution="truncated_normal", + in_axis=-2, + out_axis=-1, + seed=None + ): + super(XavierNormal, self).__init__(scale, + mode, + distribution, + in_axis=in_axis, + out_axis=out_axis, seed=seed) class LecunUniform(VarianceScaling): - def __init__(self, scale=1.0, mode="fan_in", - distribution="uniform", - in_axis=-2, out_axis=-1, - seed=None): - super(LecunUniform, self).__init__(scale, mode, distribution, - in_axis=in_axis, out_axis=out_axis, + def __init__( + self, + scale=1.0, + mode="fan_in", + distribution="uniform", + in_axis=-2, + out_axis=-1, + seed=None + ): + super(LecunUniform, self).__init__(scale, + mode, + distribution, + in_axis=in_axis, + out_axis=out_axis, seed=seed) class LecunNormal(VarianceScaling): - def __init__(self, scale=1.0, mode="fan_in", - distribution="truncated_normal", - in_axis=-2, out_axis=-1, - seed=None): - super(LecunNormal, self).__init__(scale, mode, distribution, - in_axis=in_axis, out_axis=out_axis, + def __init__( + self, + scale=1.0, + mode="fan_in", + distribution="truncated_normal", + in_axis=-2, + out_axis=-1, + seed=None + ): + super(LecunNormal, self).__init__(scale, + mode, + distribution, + in_axis=in_axis, + out_axis=out_axis, seed=seed) diff --git a/brainpy/integrators/fde/Caputo.py b/brainpy/integrators/fde/Caputo.py index cd616fadd..4e32a61d5 100644 --- a/brainpy/integrators/fde/Caputo.py +++ b/brainpy/integrators/fde/Caputo.py @@ -5,7 +5,7 @@ """ -from typing import Union, Dict +from typing import Union, Dict, Sequence, Callable import jax.numpy as jnp @@ -17,6 +17,7 @@ from brainpy.tools.errors import check_error_in_jit from .base import FDEIntegrator from .generic import register_fde_integrator, get_supported_methods +from brainpy.types import Tensor __all__ = [ 'CaputoEuler', @@ -113,12 +114,12 @@ class CaputoEuler(FDEIntegrator): def __init__( self, - f, - alpha, - num_memory, - inits, - dt=None, - name=None, + f: Callable, + alpha: Union[float, Sequence[float], Tensor], + num_memory: int, + inits: Union[Tensor, Sequence[Tensor], Dict[str, Tensor]], + dt: float = None, + name: str = None, state_delays: Dict[str, Union[bm.LengthDelay, bm.TimeDelay]] = None, ): super(CaputoEuler, self).__init__(f=f, @@ -306,12 +307,12 @@ class CaputoL1Schema(FDEIntegrator): def __init__( self, - f, - alpha, - num_memory, - inits, - dt=None, - name=None, + f: Callable, + alpha: Union[float, Sequence[float], Tensor], + num_memory: int, + inits: Union[Tensor, Sequence[Tensor], Dict[str, Tensor]], + dt: float = None, + name: str = None, state_delays: Dict[str, Union[bm.LengthDelay, bm.TimeDelay]] = None, ): super(CaputoL1Schema, self).__init__(f=f, diff --git a/brainpy/integrators/ode/tests/test_ode_method_exp_euler.py b/brainpy/integrators/ode/tests/test_ode_method_exp_euler.py index d10aba348..542625171 100644 --- a/brainpy/integrators/ode/tests/test_ode_method_exp_euler.py +++ b/brainpy/integrators/ode/tests/test_ode_method_exp_euler.py @@ -94,7 +94,8 @@ def dV(self, V, t, h, n, Iext): return dVdt - def update(self, t, dt): + def update(self, tdi): + t, dt = tdi.t, tdi.dt V, h, n = self.integral(self.V, self.h, self.n, t, self.input, dt=dt) self.spike.value = bm.logical_and(self.V < self.V_th, V >= self.V_th) self.V.value = V diff --git a/brainpy/integrators/tests/test_integ_runner.py b/brainpy/integrators/tests/test_integ_runner.py index dca80a7ae..e1a8bc4e7 100644 --- a/brainpy/integrators/tests/test_integ_runner.py +++ b/brainpy/integrators/tests/test_integ_runner.py @@ -23,15 +23,15 @@ def lorenz(x, y, z, t): dz = x * y - beta * z return dx, dy, dz - runner = bp.IntegratorRunner(lorenz, monitors=['x', 'y', 'z'], inits=[1., 1., 1.]) + runner = bp.integrators.IntegratorRunner(lorenz, monitors=['x', 'y', 'z'], inits=[1., 1., 1.]) runner.run(100.) fig = plt.figure() fig.add_subplot(111, projection='3d') plt.plot(runner.mon.x[:, 0], runner.mon.y[:, 0], runner.mon.z[:, 0], ) plt.show() - runner = bp.IntegratorRunner(lorenz, monitors=['x', 'y', 'z'], - inits=[1., (1., 0.), (1., 0.)]) + runner = bp.integrators.IntegratorRunner(lorenz, monitors=['x', 'y', 'z'], + inits=[1., (1., 0.), (1., 0.)]) runner.run(100.) for i in range(2): fig = plt.figure() @@ -45,7 +45,7 @@ def test_ode2(self): dw = lambda w, t, V: (V + a - b * w) / tau fhn = bp.odeint(bp.JointEq([dV, dw]), method='rk4', dt=0.1) - runner = bp.IntegratorRunner(fhn, monitors=['V', 'w'], inits=[1., 1.], args=dict(Iext=1.5)) + runner = bp.integrators.IntegratorRunner(fhn, monitors=['V', 'w'], inits=[1., 1.], args=dict(Iext=1.5)) runner.run(100.) bp.visualize.line_plot(runner.mon.ts, runner.mon['V'], legend='V') bp.visualize.line_plot(runner.mon.ts, runner.mon['w'], legend='w', show=True) @@ -57,9 +57,9 @@ def test_ode3(self): fhn = bp.odeint(bp.JointEq([dV, dw]), method='rk4', dt=0.1) Iext, duration = bp.inputs.section_input([0., 1., 0.5], [200, 500, 200], return_length=True) - runner = bp.IntegratorRunner(fhn, - monitors=['V', 'w'], inits=[1., 1.], - dyn_args=dict(Iext=Iext)) + runner = bp.integrators.IntegratorRunner(fhn, + monitors=['V', 'w'], inits=[1., 1.], + dyn_args=dict(Iext=Iext)) runner.run(duration) bp.visualize.line_plot(runner.mon.ts, runner.mon['V'], legend='V') bp.visualize.line_plot(runner.mon.ts, runner.mon['w'], legend='w', show=True) diff --git a/brainpy/integrators/utils.py b/brainpy/integrators/utils.py index d7341434b..c0959889c 100644 --- a/brainpy/integrators/utils.py +++ b/brainpy/integrators/utils.py @@ -127,7 +127,7 @@ def check_inits(inits, variables): raise ValueError(f'"{key}" is not defined in variables: {variables}') val = inits[key] if isinstance(val, (float, int)): - inits[key] = bm.asarray([val], dtype=bm.get_dfloat()) + inits[key] = bm.asarray([val], dtype=bm.dftype()) return inits diff --git a/brainpy/losses/comparison.py b/brainpy/losses/comparison.py index 12d6e4254..0626e7102 100644 --- a/brainpy/losses/comparison.py +++ b/brainpy/losses/comparison.py @@ -7,11 +7,16 @@ # - https://github.com/deepmind/optax/blob/master/optax/_src/loss.py # - https://github.com/google/jaxopt/blob/main/jaxopt/_src/loss.py + +from typing import Tuple + import jax.numpy as jnp -import jax.scipy +from jax.scipy.special import logsumexp from jax.tree_util import tree_map +from jax.lax import scan import brainpy.math as bm +from brainpy.types import Tensor from .utils import _return, _multi_return, _is_leaf __all__ = [ @@ -26,10 +31,11 @@ 'mean_squared_log_error', 'binary_logistic_loss', 'multiclass_logistic_loss', - 'smooth_labels', 'sigmoid_binary_cross_entropy', 'softmax_cross_entropy', 'log_cosh_loss', + 'ctc_loss_with_forward_probs', + 'ctc_loss', ] @@ -66,11 +72,11 @@ def cross_entropy_loss(predicts, targets, weight=None, reduction='mean'): Parameters ---------- - predicts : jmath.JaxArray + predicts : Tensor :math:`(N, C)` where `C = number of classes`, or :math:`(d_1, d_2, ..., d_K, N, C)` with :math:`K \geq 1` in the case of `K`-dimensional loss. - targets : jmath.JaxArray + targets : JaxArray :math:`(N, C)` or :math:`(N)` where each value is :math:`0 \leq \text{targets}[i] \leq C-1`, or :math:`(d_1, d_2, ..., d_K, N, C)` or :math:`(d_1, d_2, ..., d_K, N)` @@ -90,26 +96,18 @@ def cross_entropy_loss(predicts, targets, weight=None, reduction='mean'): :math:`(N)`, or :math:`(d_1, d_2, ..., d_K, N)` with :math:`K \geq 1` in the case of K-dimensional loss. """ - targets = bm.as_device_array(targets) - predicts = bm.as_device_array(predicts) - - # loss - if bm.ndim(targets) + 1 == bm.ndim(predicts): - # targets_old = targets.reshape((-1,)) - # length = targets_old.shape[0] - # rows = jn.arange(length) - # targets = ops.zeros((length, logits.shape[-1])) - # targets[rows, targets_old] = 1. - # targets = targets.reshape(logits.shape).value - targets = bm.activations.one_hot(targets, predicts.shape[-1]) - loss = jax.scipy.special.logsumexp(predicts, axis=-1) - (predicts * targets).sum(axis=-1) - # weighted loss if weight: - loss *= weight[targets] raise NotImplementedError - return _return(outputs=loss, reduction=reduction) + def _cel(_pred, _tar): + if bm.ndim(_tar) + 1 == bm.ndim(_pred): + _tar = bm.activations.one_hot(_tar, _pred.shape[-1]) + loss = logsumexp(bm.as_device_array(_pred), axis=-1) - (_pred * _tar).sum(axis=-1) + return _return(outputs=loss, reduction=reduction) + + r = tree_map(_cel, predicts, targets, is_leaf=lambda x: isinstance(x, bm.JaxArray)) + return _multi_return(r) def cross_entropy_sparse(predicts, targets): @@ -122,14 +120,16 @@ def cross_entropy_sparse(predicts, targets): Returns: (batch, ...) tensor of the cross-entropy for each entry. """ - predicts = bm.as_device_array(predicts) - targets = bm.as_device_array(targets) - if isinstance(targets, int): - labeled_logits = predicts[..., targets] - else: - labeled_logits = jnp.take_along_axis(predicts, targets, -1).squeeze(-1) - loss = jax.scipy.special.logsumexp(predicts, axis=-1) - labeled_logits - return loss + + def crs(_prd, _tar): + if isinstance(_tar, int): + logits = _prd[..., _tar] + else: + logits = bm.take_along_axis(_prd, _tar, -1).squeeze(-1) + return logsumexp(bm.as_device_array(_prd), axis=-1) - logits + + r = tree_map(crs, predicts, targets, is_leaf=lambda x: isinstance(x, bm.JaxArray)) + return _multi_return(r) def cross_entropy_sigmoid(predicts, targets): @@ -142,9 +142,9 @@ def cross_entropy_sigmoid(predicts, targets): Returns: (batch, ...) tensor of the cross-entropies for each entry. """ - predicts = bm.as_device_array(predicts) - targets = bm.as_device_array(targets) - return jnp.maximum(predicts, 0) - predicts * targets + jnp.log(1 + jnp.exp(-jnp.abs(predicts))) + r = tree_map(lambda pred, tar: bm.maximum(pred, 0) - pred * tar + bm.log(1 + bm.exp(-bm.abs(pred))), + predicts, targets, is_leaf=lambda x: isinstance(x, bm.JaxArray)) + return _multi_return(r) def l1_loos(logits, targets, reduction='sum'): @@ -178,9 +178,9 @@ def l1_loos(logits, targets, reduction='sum'): Parameters ---------- - logits : jmath.JaxArray + logits : JaxArray :math:`(N, *)` where :math:`*` means, any number of additional dimensions. - targets : jmath.JaxArray + targets : JaxArray :math:`(N, *)`, same shape as the input. reduction : str Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. @@ -194,9 +194,14 @@ def l1_loos(logits, targets, reduction='sum'): output : scalar. If :attr:`reduction` is ``'none'``, then :math:`(N, *)`, same shape as the input. """ - diff = (logits - targets).reshape((logits.shape[0], -1)) - norm = jnp.linalg.norm(bm.as_device_array(diff), ord=1, axis=1, keepdims=False) - return _return(outputs=norm, reduction=reduction) + + def loss(pred, tar): + diff = (pred - tar).reshape((pred.shape[0], -1)) + norm = jnp.linalg.norm(bm.as_device_array(diff), ord=1, axis=1, keepdims=False) + return _return(outputs=norm, reduction=reduction) + + r = tree_map(loss, logits, targets, is_leaf=lambda x: isinstance(x, bm.JaxArray)) + return _multi_return(r) def l2_loss(predicts, targets): @@ -222,7 +227,9 @@ def l2_loss(predicts, targets): ---------- .. [1] Bishop, Christopher M. 2006. Pattern Recognition and Machine Learning. """ - return bm.as_device_array(0.5 * (predicts - targets) ** 2) + r = tree_map(lambda pred, tar: 0.5 * (pred - tar) ** 2, predicts, targets, + is_leaf=lambda a: isinstance(a, bm.JaxArray)) + return _multi_return(r) def mean_absolute_error(x, y, axis=None): @@ -246,7 +253,7 @@ def mean_squared_error(predicts, targets, axis=None): Args: predicts: a tensor of shape (d0, .. dN-1). targets: a tensor of shape (d0, .. dN-1). - keep_axis: a sequence of the dimensions to keep, use `None` to return a scalar value. + axis: a sequence of the dimensions to keep, use `None` to return a scalar value. Returns: tensor of shape (d_i, ..., for i in keep_axis) containing the mean squared error. @@ -296,6 +303,7 @@ def huber_loss(predicts, targets, delta: float = 1.0): ---------- .. [1] https://en.wikipedia.org/wiki/Huber_loss """ + def _loss(y_predict, y_target): # 0.5 * err^2 if |err| <= d # 0.5 * d^2 + d * (|err| - d) if |err| > d @@ -304,7 +312,8 @@ def _loss(y_predict, y_target): delta * (diff - .5 * delta), 0.5 * diff ** 2) - return tree_map(_loss, targets, predicts, is_leaf=_is_leaf) + r = tree_map(_loss, targets, predicts, is_leaf=_is_leaf) + return _multi_return(r) def binary_logistic_loss(predicts: float, targets: int, ) -> float: @@ -319,7 +328,9 @@ def binary_logistic_loss(predicts: float, targets: int, ) -> float: # Softplus is the Fenchel conjugate of the Fermi-Dirac negentropy on [0, 1]. # softplus = proba * logit - xlogx(proba) - xlogx(1 - proba), # where xlogx(proba) = proba * log(proba). - return bm.as_device_array(bm.activations.softplus(predicts) - targets * predicts) + r = tree_map(lambda a, b: bm.activations.softplus(a) - b * a, + predicts, targets, is_leaf=lambda x: isinstance(x, bm.JaxArray)) + return _multi_return(r) def multiclass_logistic_loss(label: int, logits: jnp.ndarray) -> float: @@ -331,30 +342,14 @@ def multiclass_logistic_loss(label: int, logits: jnp.ndarray) -> float: Returns: loss value """ - logits = bm.as_device_array(logits) - n_classes = logits.shape[0] - one_hot = bm.one_hot(label, n_classes) - # Logsumexp is the Fenchel conjugate of the Shannon negentropy on the simplex. - # logsumexp = jnp.dot(proba, logits) - jnp.dot(proba, jnp.log(proba)) - return jax.scipy.special.logsumexp(logits) - bm.dot(logits, one_hot) - - -def smooth_labels(labels, alpha: float) -> jnp.ndarray: - r"""Apply label smoothing. - Label smoothing is often used in combination with a cross-entropy loss. - Smoothed labels favour small logit gaps, and it has been shown that this can - provide better model calibration by preventing overconfident predictions. - References: - [Müller et al, 2019](https://arxiv.org/pdf/1906.02629.pdf) - Args: - labels: one hot labels to be smoothed. - alpha: the smoothing factor, the greedy category with be assigned - probability `(1-alpha) + alpha / num_categories` - Returns: - a smoothed version of the one hot input labels. - """ - num_categories = labels.shape[-1] - return (1.0 - alpha) * labels + alpha / num_categories + + def loss(pred, tar): + pred = bm.as_device_array(pred) + one_hot = bm.one_hot(tar, pred.shape[0]) + return logsumexp(pred) - bm.dot(pred, one_hot) + + r = tree_map(loss, logits, label, is_leaf=lambda x: isinstance(x, bm.JaxArray)) + return _multi_return(r) def sigmoid_binary_cross_entropy(logits, labels): @@ -371,10 +366,15 @@ def sigmoid_binary_cross_entropy(logits, labels): Returns: a sigmoid cross entropy loss. """ - log_p = bm.log_sigmoid(logits) - # log(1 - sigmoid(x)) = log_sigmoid(-x), the latter more numerically stable - log_not_p = bm.log_sigmoid(-logits) - return -labels * log_p - (1. - labels) * log_not_p + + def loss(pred, tar): + log_p = bm.log_sigmoid(pred) + # log(1 - sigmoid(x)) = log_sigmoid(-x), the latter more numerically stable + log_not_p = bm.log_sigmoid(-pred) + return -tar * log_p - (1. - tar) * log_not_p + + r = tree_map(loss, logits, labels, is_leaf=lambda x: isinstance(x, bm.JaxArray)) + return _multi_return(r) def softmax_cross_entropy(logits, labels): @@ -392,9 +392,9 @@ def softmax_cross_entropy(logits, labels): Returns: the cross entropy loss. """ - logits = bm.as_device_array(logits) - labels = bm.as_device_array(labels) - return -jnp.sum(labels * jax.nn.log_softmax(logits, axis=-1), axis=-1) + r = tree_map(lambda pred, tar: -bm.sum(tar * bm.log_softmax(pred, axis=-1), axis=-1), + logits, labels, is_leaf=lambda x: isinstance(x, bm.JaxArray)) + return _multi_return(r) def log_cosh_loss(predicts, targets): @@ -411,5 +411,170 @@ def log_cosh_loss(predicts, targets): Returns: the log-cosh loss. """ - errors = bm.as_device_array(predicts - targets) - return jnp.logaddexp(errors, -errors) - jnp.log(2.0).astype(errors.dtype) + + def loss(pred, tar): + errors = bm.as_device_array(pred - tar) + return jnp.logaddexp(errors, -errors) - jnp.log(2.0).astype(errors.dtype) + + r = tree_map(loss, predicts, targets, is_leaf=lambda x: isinstance(x, bm.JaxArray)) + return _multi_return(r) + + +def ctc_loss_with_forward_probs( + logits: Tensor, + logit_paddings: Tensor, + labels: Tensor, + label_paddings: Tensor, + blank_id: int = 0, + log_epsilon: float = -1e5 +) -> Tuple[Tensor, Tensor, Tensor]: + r"""Computes CTC loss and CTC forward-probabilities. + The CTC loss is a loss function based on log-likelihoods of the model that + introduces a special blank symbol :math:`\phi` to represent variable-length + output sequences. + Forward probabilities returned by this function, as auxiliary results, are + grouped into two part: blank alpha-probability and non-blank alpha + probability. Those are defined as follows: + .. math:: + \alpha_{\mathrm{BLANK}}(t, n) = + \sum_{\pi_{1:t-1}} p(\pi_t = \phi | \pi_{1:t-1}, y_{1:n-1}, \cdots), \\ + \alpha_{\mathrm{LABEL}}(t, n) = + \sum_{\pi_{1:t-1}} p(\pi_t = y_n | \pi_{1:t-1}, y_{1:n-1}, \cdots). + Here, :math:`\pi` denotes the alignment sequence in the reference + [Graves et al, 2006] that is blank-inserted representations of ``labels``. + The return values are the logarithms of the above probabilities. + References: + [Graves et al, 2006](https://dl.acm.org/doi/abs/10.1145/1143844.1143891) + Args: + logits: (B, T, K)-array containing logits of each class where B denotes + the batch size, T denotes the max time frames in ``logits``, and K + denotes the number of classes including a class for blanks. + logit_paddings: (B, T)-array. Padding indicators for ``logits``. Each + element must be either 1.0 or 0.0, and ``logitpaddings[b, t] == 1.0`` + denotes that ``logits[b, t, :]`` are padded values. + labels: (B, N)-array containing reference integer labels where N denotes + the max time frames in the label sequence. + label_paddings: (B, N)-array. Padding indicators for ``labels``. Each + element must be either 1.0 or 0.0, and ``labelpaddings[b, n] == 1.0`` + denotes that ``labels[b, n]`` is a padded label. In the current + implementation, ``labels`` must be right-padded, i.e. each row + ``labelpaddings[b, :]`` must be repetition of zeroes, followed by + repetition of ones. + blank_id: Id for blank token. ``logits[b, :, blank_id]`` are used as + probabilities of blank symbols. + log_epsilon: Numerically-stable approximation of log(+0). + Returns: + A tuple ``(loss_value, logalpha_blank, logalpha_nonblank)``. Here, + ``loss_value`` is a (B,)-array containing the loss values for each sequence + in the batch, ``logalpha_blank`` and ``logalpha_nonblank`` are + (T, B, N+1)-arrays where the (t, b, n)-th element denotes + \log \alpha_B(t, n) and \log \alpha_L(t, n), respectively, for ``b``-th + sequence in the batch. + """ + assert logits.ndim == 3 + assert labels.ndim == 2 + batchsize, unused_maxinputlen, num_classes = logits.shape + batchsize_of_labels, maxlabellen = labels.shape + assert (batchsize == batchsize_of_labels) + assert (labels.shape == label_paddings.shape) + assert (logits.shape[:2] == logit_paddings.shape) + + logits = logits.value if isinstance(logits, bm.JaxArray) else logits + logit_paddings = logit_paddings.value if isinstance(logit_paddings, bm.JaxArray) else logit_paddings + labels = labels.value if isinstance(labels, bm.JaxArray) else labels + label_paddings = label_paddings.value if isinstance(label_paddings, bm.JaxArray) else label_paddings + + logprobs = bm.log_softmax(logits).value + labellens = maxlabellen - jnp.sum(label_paddings, axis=1).astype(jnp.int32) + + # repeat[b, n] == 1.0 when label[b, n] == label[b, n+1]. + repeat = (labels[:, :-1] == labels[:, 1:]).astype(jnp.float32) + repeat = jnp.pad(repeat, ((0, 0), (0, 1))) + + logprobs_phi = logprobs[:, :, blank_id:blank_id + 1] # [B, T, 1] + logprobs_phi = jnp.transpose(logprobs_phi, (1, 0, 2)) # [T, B, 1] + + one_hot = bm.one_hot(labels, num_classes=num_classes) # [B, N, K] + logprobs_emit = jnp.einsum('btk,bnk->btn', logprobs, one_hot) + logprobs_emit = jnp.transpose(logprobs_emit, (1, 0, 2)) # [T, B, N] + + logalpha_phi_init = jnp.ones( + (batchsize, maxlabellen + 1)) * log_epsilon # [B, N] + logalpha_phi_init = logalpha_phi_init.at[:, 0].set(0.0) + logalpha_emit_init = jnp.ones((batchsize, maxlabellen)) * log_epsilon + + def update_phi_score(phi, added_score): + # Update `phi[:, 1:]`` with adding `added_score` in log space. + return jnp.concatenate( + [phi[:, :1], jnp.logaddexp(phi[:, 1:], added_score)], axis=-1) + + def loop_body(prev, x): + prev_phi, prev_emit = prev + # emit-to-phi epsilon transition, except if the next label is repetition + prev_phi_orig = prev_phi + prev_phi = update_phi_score(prev_phi, prev_emit + log_epsilon * repeat) + + logprob_emit, logprob_phi, pad = x + + # phi-to-emit transition + next_emit = jnp.logaddexp(prev_phi[:, :-1] + logprob_emit, + prev_emit + logprob_emit) + # self-loop transition + next_phi = prev_phi + logprob_phi + # emit-to-phi blank transition only when the next label is repetition + next_phi = update_phi_score( + next_phi, prev_emit + logprob_phi + log_epsilon * (1.0 - repeat)) + + pad = pad.reshape((batchsize, 1)) + next_emit = pad * prev_emit + (1.0 - pad) * next_emit + next_phi = pad * prev_phi_orig + (1.0 - pad) * next_phi + + return (next_phi, next_emit), (next_phi, next_emit) + + xs = (logprobs_emit, logprobs_phi, logit_paddings.transpose((1, 0))) + _, (logalpha_phi, logalpha_emit) = scan(loop_body, (logalpha_phi_init, logalpha_emit_init), xs) + + # last row needs to be updated with the last epsilon transition + logalpha_phi_last = update_phi_score(logalpha_phi[-1], logalpha_emit[-1]) + logalpha_phi = logalpha_phi.at[-1].set(logalpha_phi_last) + + # extract per_seq_loss + one_hot = bm.one_hot(labellens, num_classes=maxlabellen + 1).value # [B, N+1] + per_seq_loss = -jnp.einsum('bn,bn->b', logalpha_phi_last, one_hot) + + return per_seq_loss, logalpha_phi, logalpha_emit + + +def ctc_loss(logits: Tensor, + logit_paddings: Tensor, + labels: Tensor, + label_paddings: Tensor, + blank_id: int = 0, + log_epsilon: float = -1e5) -> Tensor: + """Computes CTC loss. + See docstring for ``ctc_loss_with_forward_probs`` for details. + Args: + logits: (B, T, K)-array containing logits of each class where B denotes + the batch size, T denotes the max time frames in ``logits``, and K + denotes the number of classes including a class for blanks. + logit_paddings: (B, T)-array. Padding indicators for ``logits``. Each + element must be either 1.0 or 0.0, and ``logitpaddings[b, t] == 1.0`` + denotes that ``logits[b, t, :]`` are padded values. + labels: (B, N)-array containing reference integer labels where N denotes + the max time frames in the label sequence. + label_paddings: (B, N)-array. Padding indicators for ``labels``. Each + element must be either 1.0 or 0.0, and ``labelpaddings[b, n] == 1.0`` + denotes that ``labels[b, n]`` is a padded label. In the current + implementation, ``labels`` must be right-padded, i.e. each row + ``labelpaddings[b, :]`` must be repetition of zeroes, followed by + repetition of ones. + blank_id: Id for blank token. ``logits[b, :, blank_id]`` are used as + probabilities of blank symbols. + log_epsilon: Numerically-stable approximation of log(+0). + Returns: + (B,)-array containing loss values for each sequence in the batch. + """ + per_seq_loss, _, _ = ctc_loss_with_forward_probs( + logits, logit_paddings, labels, label_paddings, + blank_id=blank_id, log_epsilon=log_epsilon) + return per_seq_loss diff --git a/brainpy/losses/regularization.py b/brainpy/losses/regularization.py index 51811063c..9fc7f664e 100644 --- a/brainpy/losses/regularization.py +++ b/brainpy/losses/regularization.py @@ -11,6 +11,7 @@ 'mean_absolute', 'mean_square', 'log_cosh', + 'smooth_labels', ] @@ -57,3 +58,23 @@ def log_cosh(errors): r = tree_map(lambda a: bm.logaddexp(a, -a) - bm.log(2.0).astype(a.dtype), errors, is_leaf=_is_leaf) return _multi_return(r) + + +def smooth_labels(labels, alpha: float) -> jnp.ndarray: + r"""Apply label smoothing. + Label smoothing is often used in combination with a cross-entropy loss. + Smoothed labels favour small logit gaps, and it has been shown that this can + provide better model calibration by preventing overconfident predictions. + References: + [Müller et al, 2019](https://arxiv.org/pdf/1906.02629.pdf) + Args: + labels: one hot labels to be smoothed. + alpha: the smoothing factor, the greedy category with be assigned + probability `(1-alpha) + alpha / num_categories` + Returns: + a smoothed version of the one hot input labels. + """ + r = tree_map(lambda tar: (1.0 - alpha) * tar + alpha / tar.shape[-1], + labels, is_leaf=lambda x: isinstance(x, bm.JaxArray)) + return _multi_return(r) + diff --git a/brainpy/math/__init__.py b/brainpy/math/__init__.py index 12d693184..36f8c33ce 100644 --- a/brainpy/math/__init__.py +++ b/brainpy/math/__init__.py @@ -38,7 +38,6 @@ # functions from .activations import * from . import activations -from .compat import * # high-level numpy operations from .numpy_ops import * @@ -59,40 +58,3 @@ from . import setting from .setting import * from .function import * - - -def get_dint(): - """Get default int type.""" - return int_ - - -def get_dfloat(): - """Get default float type.""" - return float_ - - -def get_dcomplex(): - """Get default complex type.""" - return complex_ - - -def set_dint(int_type): - """Set default int type.""" - global int_ - assert isinstance(int_type, type) - int_ = int_type - - -def set_dfloat(float_type): - """Set default float type.""" - global float_ - assert isinstance(float_type, type) - float_ = float_type - - -def set_dcomplex(complex_type): - """Set default complex type.""" - global complex_ - assert isinstance(complex_type, type) - complex_ = complex_type - diff --git a/brainpy/math/activations.py b/brainpy/math/activations.py index e99c82aa9..fc0793a74 100644 --- a/brainpy/math/activations.py +++ b/brainpy/math/activations.py @@ -355,7 +355,7 @@ def one_hot(x, num_classes, *, dtype=None, axis=-1): num_classes = jax.core.concrete_or_error( int, num_classes, "The error arose in jax.nn.one_hot argument `num_classes`.") dtype = jax.dtypes.canonicalize_dtype(jnp.float64 if dtype is None else dtype) - x = jnp.asarray(x) + x = jnp.asarray(x.value if isinstance(x, JaxArray) else x) try: output_pos_axis = _canonicalize_axis(axis, x.ndim + 1) except TypeError: diff --git a/brainpy/math/autograd.py b/brainpy/math/autograd.py index 718157591..02c028da7 100644 --- a/brainpy/math/autograd.py +++ b/brainpy/math/autograd.py @@ -240,7 +240,7 @@ def grad(func, grad_vars=None, dyn_vars=None, argnums=None, holomorphic=False, Parameters ---------- - func : function, Base + func : callable, function, Base Function to be differentiated. Its arguments at positions specified by ``argnums`` should be arrays, scalars, or standard Python containers. Argument arrays in the positions specified by ``argnums`` must be of diff --git a/brainpy/math/compat/__init__.py b/brainpy/math/compat/__init__.py deleted file mode 100644 index a547f80eb..000000000 --- a/brainpy/math/compat/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -# -*- coding: utf-8 -*- - -__all__ = [ - 'optimizers', 'losses', - 'FixedLenDelay', -] - -from . import optimizers, losses -from .delayvars import * - diff --git a/brainpy/math/compat/delayvars.py b/brainpy/math/compat/delayvars.py deleted file mode 100644 index 1207ff757..000000000 --- a/brainpy/math/compat/delayvars.py +++ /dev/null @@ -1,43 +0,0 @@ -# -*- coding: utf-8 -*- - -import warnings -from typing import Union, Callable - -import jax.numpy as jnp - -from brainpy.math.delayvars import TimeDelay - - -__all__ = [ - 'FixedLenDelay' -] - - -def FixedLenDelay(shape, - delay_len: Union[float, int], - before_t0: Union[Callable, jnp.ndarray, float, int] = None, - t0: Union[float, int] = 0., - dt: Union[float, int] = None, - name: str = None, - interp_method='linear_interp', ): - """Delay variable which has a fixed delay length. - - .. deprecated:: 2.1.2 - Please use "brainpy.math.TimeDelay" instead. - - See Also - -------- - TimeDelay - - """ - warnings.warn('Please use "brainpy.math.TimeDelay" instead. ' - '"brainpy.math.FixedLenDelay" is deprecated since version 2.1.2. ', - DeprecationWarning) - return TimeDelay(delay_target=jnp.zeros(shape), - delay_len=delay_len, - before_t0=before_t0, - t0=t0, - dt=dt, - name=name, - interp_method=interp_method) - diff --git a/brainpy/math/compat/losses.py b/brainpy/math/compat/losses.py deleted file mode 100644 index f2de660be..000000000 --- a/brainpy/math/compat/losses.py +++ /dev/null @@ -1,112 +0,0 @@ -# -*- coding: utf-8 -*- - -import warnings - -from brainpy import losses - -__all__ = [ - 'cross_entropy_loss', - 'l1_loos', - 'l2_loss', - 'l2_norm', - 'huber_loss', - 'mean_absolute_error', - 'mean_squared_error', - 'mean_squared_log_error', -] - - -def cross_entropy_loss(*args, **kwargs): - """Cross entropy loss. - - .. deprecated:: 2.1.0 - Please use "brainpy.losses.cross_entropy_loss" instead. - """ - warnings.warn('Please use "brainpy.losses.XXX" instead. ' - '"brainpy.math.losses.XXX" is deprecated since version 2.0.3. ', - DeprecationWarning) - return losses.cross_entropy_loss(*args, **kwargs) - - -def l1_loos(*args, **kwargs): - """L1 loss. - - .. deprecated:: 2.1.0 - Please use "brainpy.losses.l1_loss" instead. - """ - warnings.warn('Please use "brainpy.losses.XXX" instead. ' - '"brainpy.math.losses.XXX" is deprecated since version 2.0.3. ', - DeprecationWarning) - return losses.l1_loos(*args, **kwargs) - - -def l2_loss(*args, **kwargs): - """L2 loss. - - .. deprecated:: 2.1.0 - Please use "brainpy.losses.l2_loss" instead. - """ - warnings.warn('Please use "brainpy.losses.XXX" instead. ' - '"brainpy.math.losses.XXX" is deprecated since version 2.0.3. ', - DeprecationWarning) - return losses.l2_loss(*args, **kwargs) - - -def l2_norm(*args, **kwargs): - """L2 normal. - - .. deprecated:: 2.1.0 - Please use "brainpy.losses.l2_norm" instead. - """ - warnings.warn('Please use "brainpy.losses.XXX" instead. ' - '"brainpy.math.losses.XXX" is deprecated since version 2.0.3. ', - DeprecationWarning) - return losses.l2_norm(*args, **kwargs) - - -def huber_loss(*args, **kwargs): - """Huber loss. - - .. deprecated:: 2.1.0 - Please use "brainpy.losses.huber_loss" instead. - """ - warnings.warn('Please use "brainpy.losses.XXX" instead. ' - '"brainpy.math.losses.XXX" is deprecated since version 2.0.3. ', - DeprecationWarning) - return losses.huber_loss(*args, **kwargs) - - -def mean_absolute_error(*args, **kwargs): - """mean absolute error loss. - - .. deprecated:: 2.1.0 - Please use "brainpy.losses.mean_absolute_error" instead. - """ - warnings.warn('Please use "brainpy.losses.XXX" instead. ' - '"brainpy.math.losses.XXX" is deprecated since version 2.0.3. ', - DeprecationWarning) - return losses.mean_absolute_error(*args, **kwargs) - - -def mean_squared_error(*args, **kwargs): - """Mean squared error loss. - - .. deprecated:: 2.1.0 - Please use "brainpy.losses.mean_squared_error" instead. - """ - warnings.warn('Please use "brainpy.losses.XXX" instead. ' - '"brainpy.math.losses.XXX" is deprecated since version 2.0.3. ', - DeprecationWarning) - return losses.mean_squared_error(*args, **kwargs) - - -def mean_squared_log_error(*args, **kwargs): - """Mean squared log error loss. - - .. deprecated:: 2.1.0 - Please use "brainpy.losses.mean_squared_log_error" instead. - """ - warnings.warn('Please use "brainpy.losses.XXX" instead. ' - '"brainpy.math.losses.XXX" is deprecated since version 2.0.3. ', - DeprecationWarning) - return losses.mean_squared_log_error(*args, **kwargs) diff --git a/brainpy/math/compat/optimizers.py b/brainpy/math/compat/optimizers.py deleted file mode 100644 index d12d29fe8..000000000 --- a/brainpy/math/compat/optimizers.py +++ /dev/null @@ -1,177 +0,0 @@ -# -*- coding: utf-8 -*- - -import warnings - -from brainpy import optimizers - -__all__ = [ - 'SGD', - 'Momentum', - 'MomentumNesterov', - 'Adagrad', - 'Adadelta', - 'RMSProp', - 'Adam', - - 'Constant', - 'ExponentialDecay', - 'InverseTimeDecay', - 'PolynomialDecay', - 'PiecewiseConstant', -] - - -def SGD(*args, **kwargs): - """SGD optimizer. - - .. deprecated:: 2.1.0 - Please use "brainpy.optim.SGD" instead. - """ - warnings.warn('Please use "brainpy.optim.SGD" instead. ' - '"brainpy.math.optimizers.SGD" is ' - 'deprecated since version 2.0.3. ', - DeprecationWarning) - return optimizers.SGD(*args, **kwargs) - - -def Momentum(*args, **kwargs): - """Momentum optimizer. - - .. deprecated:: 2.1.0 - Please use "brainpy.optim.Momentum" instead. - """ - warnings.warn('Please use "brainpy.optim.Momentum" instead. ' - '"brainpy.math.optimizers.Momentum" is ' - 'deprecated since version 2.0.3. ', - DeprecationWarning) - return optimizers.Momentum(*args, **kwargs) - - -def MomentumNesterov(*args, **kwargs): - """MomentumNesterov optimizer. - - .. deprecated:: 2.1.0 - Please use "brainpy.optim.MomentumNesterov" instead. - """ - warnings.warn('Please use "brainpy.optim.MomentumNesterov" instead. ' - '"brainpy.math.optimizers.MomentumNesterov" is ' - 'deprecated since version 2.0.3. ', - DeprecationWarning) - return optimizers.MomentumNesterov(*args, **kwargs) - - -def Adagrad(*args, **kwargs): - """Adagrad optimizer. - - .. deprecated:: 2.1.0 - Please use "brainpy.optim.Adagrad" instead. - """ - warnings.warn('Please use "brainpy.optim.Adagrad" instead. ' - '"brainpy.math.optimizers.Adagrad" is ' - 'deprecated since version 2.0.3. ', - DeprecationWarning) - return optimizers.Adagrad(*args, **kwargs) - - -def Adadelta(*args, **kwargs): - """Adadelta optimizer. - - .. deprecated:: 2.1.0 - Please use "brainpy.optim.Adadelta" instead. - """ - warnings.warn('Please use "brainpy.optim.Adadelta" instead. ' - '"brainpy.math.optimizers.Adadelta" is ' - 'deprecated since version 2.0.3. ', - DeprecationWarning) - return optimizers.Adadelta(*args, **kwargs) - - -def RMSProp(*args, **kwargs): - """RMSProp optimizer. - - .. deprecated:: 2.1.0 - Please use "brainpy.optim.RMSProp" instead. - """ - warnings.warn('Please use "brainpy.optim.RMSProp" instead. ' - '"brainpy.math.optimizers.RMSProp" is ' - 'deprecated since version 2.0.3. ', - DeprecationWarning) - return optimizers.RMSProp(*args, **kwargs) - - -def Adam(*args, **kwargs): - """Adam optimizer. - - .. deprecated:: 2.1.0 - Please use "brainpy.optim.Adam" instead. - """ - warnings.warn('Please use "brainpy.optim.Adam" instead. ' - '"brainpy.math.optimizers.Adam" is ' - 'deprecated since version 2.0.3. ', - DeprecationWarning) - return optimizers.Adam(*args, **kwargs) - - -def Constant(*args, **kwargs): - """Constant scheduler. - - .. deprecated:: 2.1.0 - Please use "brainpy.optim.Constant" instead. - """ - warnings.warn('Please use "brainpy.optim.Constant" instead. ' - '"brainpy.math.optimizers.Constant" is ' - 'deprecated since version 2.0.3. ', - DeprecationWarning) - return optimizers.Constant(*args, **kwargs) - - -def ExponentialDecay(*args, **kwargs): - """ExponentialDecay scheduler. - - .. deprecated:: 2.1.0 - Please use "brainpy.optim.ExponentialDecay" instead. - """ - warnings.warn('Please use "brainpy.optim.ExponentialDecay" instead. ' - '"brainpy.math.optimizers.ExponentialDecay" is ' - 'deprecated since version 2.0.3. ', - DeprecationWarning) - return optimizers.ExponentialDecay(*args, **kwargs) - - -def InverseTimeDecay(*args, **kwargs): - """InverseTimeDecay scheduler. - - .. deprecated:: 2.1.0 - Please use "brainpy.optim.InverseTimeDecay" instead. - """ - warnings.warn('Please use "brainpy.optim.InverseTimeDecay" instead. ' - '"brainpy.math.optimizers.InverseTimeDecay" is ' - 'deprecated since version 2.0.3. ', - DeprecationWarning) - return optimizers.InverseTimeDecay(*args, **kwargs) - - -def PolynomialDecay(*args, **kwargs): - """PolynomialDecay scheduler. - - .. deprecated:: 2.1.0 - Please use "brainpy.optim.PolynomialDecay" instead. - """ - warnings.warn('Please use "brainpy.optim.PolynomialDecay" instead. ' - '"brainpy.math.optimizers.PolynomialDecay" is ' - 'deprecated since version 2.0.3. ', - DeprecationWarning) - return optimizers.PolynomialDecay(*args, **kwargs) - - -def PiecewiseConstant(*args, **kwargs): - """PiecewiseConstant scheduler. - - .. deprecated:: 2.1.0 - Please use "brainpy.optim.PiecewiseConstant" instead. - """ - warnings.warn('Please use "brainpy.optim.PiecewiseConstant" instead. ' - '"brainpy.math.optimizers.PiecewiseConstant" is ' - 'deprecated since version 2.0.3. ', - DeprecationWarning) - return optimizers.PiecewiseConstant(*args, **kwargs) diff --git a/brainpy/math/delayvars.py b/brainpy/math/delayvars.py index 29939c594..3f5ef4120 100644 --- a/brainpy/math/delayvars.py +++ b/brainpy/math/delayvars.py @@ -122,7 +122,6 @@ def __init__( # shape if not isinstance(delay_target, (jnp.ndarray, JaxArray)): raise ValueError(f'Must be an instance of JaxArray or jax.numpy.ndarray. But we got {type(delay_target)}') - self.shape = delay_target.shape # delay_len self.t0 = t0 @@ -143,11 +142,15 @@ def __init__( self.current_time = Variable(jnp.asarray([t0])) # delay data - self.data = Variable(jnp.zeros((self.num_delay_step,) + self.shape, dtype=delay_target.dtype)) + batch_axis = None + if hasattr(delay_target, 'batch_axis') and (delay_target.batch_axis is not None): + batch_axis = delay_target.batch_axis + 1 + self.data = Variable(jnp.zeros((self.num_delay_step,) + delay_target.shape, dtype=delay_target.dtype), + batch_axis=batch_axis) if before_t0 is None: self._before_type = _DATA_BEFORE elif callable(before_t0): - self._before_t0 = lambda t: bm.asarray(bm.broadcast_to(before_t0(t), self.shape), + self._before_t0 = lambda t: bm.asarray(bm.broadcast_to(before_t0(t), delay_target.shape), dtype=delay_target.dtype).value self._before_type = _FUNC_BEFORE elif isinstance(before_t0, (ndarray, jnp.ndarray, float, int)): @@ -160,7 +163,7 @@ def __init__( # interpolation function self._interp_fun = jnp.interp - for dim in range(1, len(self.shape) + 1, 1): + for dim in range(1, delay_target.ndim + 1, 1): self._interp_fun = vmap(self._interp_fun, in_axes=(None, None, dim), out_axes=dim - 1) def reset(self, @@ -183,8 +186,7 @@ def reset(self, """ self.delay_len = delay_len self.num_delay_step = int(jnp.ceil(self.delay_len / self.dt)) + 1 - self.data.value = jnp.zeros((self.num_delay_step,) + self.shape, - dtype=delay_target.dtype) + self.data.value = jnp.zeros((self.num_delay_step,) + delay_target.shape, dtype=delay_target.dtype) self.data[-1] = delay_target self.idx = Variable(jnp.asarray([0])) self.current_time = Variable(jnp.asarray([t0])) @@ -290,7 +292,6 @@ def __init__( # attributes and variables self.num_delay_step: int = None - self.shape: Tuple[int] = None self.idx: Variable = None self.data: Variable = None @@ -306,7 +307,6 @@ def reset( if not isinstance(delay_target, (ndarray, jnp.ndarray)): raise ValueError(f'Must be an instance of brainpy.math.ndarray ' f'or jax.numpy.ndarray. But we got {type(delay_target)}') - self.shape = delay_target.shape # delay_len check_integer(delay_len, 'delay_len', allow_none=True, min_bound=0) @@ -324,16 +324,20 @@ def reset( # delay data if self.data is None: - self.data = Variable(jnp.zeros((self.num_delay_step,) + self.shape, dtype=delay_target.dtype)) + batch_axis = None + if hasattr(delay_target, 'batch_axis') and (delay_target.batch_axis is not None): + batch_axis = delay_target.batch_axis + 1 + self.data = Variable(jnp.zeros((self.num_delay_step,) + delay_target.shape, dtype=delay_target.dtype), + batch_axis=batch_axis) else: - self.data._value = jnp.zeros((self.num_delay_step,) + self.shape, dtype=delay_target.dtype) + self.data._value = jnp.zeros((self.num_delay_step,) + delay_target.shape, dtype=delay_target.dtype) self.data[-1] = delay_target if initial_delay_data is None: pass elif isinstance(initial_delay_data, (ndarray, jnp.ndarray, float, int, bool)): self.data[:-1] = initial_delay_data elif callable(initial_delay_data): - self.data[:-1] = initial_delay_data((delay_len,) + self.shape, dtype=delay_target.dtype) + self.data[:-1] = initial_delay_data((delay_len,) + delay_target.shape, dtype=delay_target.dtype) else: raise ValueError(f'"delay_data" does not support {type(initial_delay_data)}') @@ -354,8 +358,6 @@ def __call__(self, delay_len, *indices): return self.data[indices] def update(self, value: Union[float, JaxArray, jnp.DeviceArray]): - if jnp.shape(value) != self.shape: - raise ValueError(f'value shape should be {self.shape}, but we got {jnp.shape(value)}') self.data[self.idx[0]] = value self.idx.value = (self.idx + 1) % self.num_delay_step diff --git a/brainpy/math/jaxarray.py b/brainpy/math/jaxarray.py index 99fd98422..2c68e5556 100644 --- a/brainpy/math/jaxarray.py +++ b/brainpy/math/jaxarray.py @@ -50,10 +50,14 @@ class JaxArray(object): """ __slots__ = ("_value", "_outside_global_jit") - def __init__(self, value): + def __init__(self, value, dtype=None): # array value - if isinstance(value, (list, tuple)): value = jnp.asarray(value) - if isinstance(value, JaxArray): value = value._value + if isinstance(value, (list, tuple)): + value = jnp.asarray(value) + if isinstance(value, JaxArray): + value = value._value + if dtype is not None: + value = jnp.asarray(value, dtype=dtype) self._value = value # jit mode self._outside_global_jit = False if _global_jit_mode else True @@ -545,7 +549,7 @@ def diagonal(self, offset=0, axis1=0, axis2=1): def dot(self, b): """Dot product of two arrays.""" - return JaxArray(self.value.dot(b)) + return JaxArray(self.value.dot(b.value if isinstance(b, JaxArray) else b)) def fill(self, value): """Fill the array with a scalar value.""" @@ -902,17 +906,58 @@ def __jax_array__(self): class Variable(JaxArray): """The pointer to specify the dynamical variable. """ - __slots__ = ('_value',) + __slots__ = ('_value', '_batch_axis') + + def __init__(self, value, dtype=None, batch_axis: int = None): + super(Variable, self).__init__(value, dtype=dtype) + + # check batch axis + if isinstance(value, Variable): + if value.batch_axis is not None and batch_axis is not None: + if batch_axis != value.batch_axis: + raise ValueError(f'"batch_axis" is not consistent. Got batch_axis in the given value ' + f'is {value.batch_axis}, but the specified batch_axis is {batch_axis}') + batch_axis = value.batch_axis + + # assign batch axis + self._batch_axis = batch_axis + if batch_axis is not None: + if batch_axis >= self.ndim: + raise MathError(f'This variables has {self.ndim} dimension, ' + f'but the batch axis is set to be {batch_axis}.') + + @property + def batch_axis(self): + return self._batch_axis - def __init__(self, value): - super(Variable, self).__init__(value) + @batch_axis.setter + def batch_axis(self, val): + raise ValueError(f'Cannot set "batch_axis" after creating a {self.__class__.__name__} instance.') + + @property + def batch_size(self): + return self.ndim[self._batch_axis] + + @batch_size.setter + def batch_size(self, val): + raise ValueError(f'Cannot set "batch_size" manually.') def update(self, value): """Update the value of this JaxArray. """ - if value.shape != self._value.shape: - raise MathError(f"The shape of the original data is {self._value.shape}, " - f"while we got {value.shape}.") + if self._batch_axis is None: + ext_shape = value.shape + int_shape = self._value.shape + else: + ext_shape = value.shape[:self._batch_axis] + value.shape[self._batch_axis + 1:] + int_shape = self._value.shape[:self._batch_axis] + self._value.shape[self._batch_axis + 1:] + if ext_shape != int_shape: + error = f"The shape of the original data is {self._value.shape}, while we got {value.shape}" + if self._batch_axis is None: + error += '. Do you forget to set "batch_axis" when initialize this variable?' + else: + error += f' with batch_axis={self._batch_axis}.' + raise MathError(error) if value.dtype != self._value.dtype: raise MathError(f"The dtype of the original data is {self._value.dtype}, " f"while we got {value.dtype}.") @@ -1007,23 +1052,508 @@ def sort(self, axis=-1, kind=None, order=None): """Sort an array in-place.""" self._value = self.value.sort(axis=axis, kind=kind, order=order) + # ---------- # + # operations # + # ---------- # + + def __bool__(self) -> bool: + return self._value.__bool__() + + def __len__(self) -> int: + return len(self._value) + + def __neg__(self): + return self._value.__neg__() + + def __pos__(self): + return self._value.__pos__() + + def __abs__(self): + return self._value.__abs__() + + def __invert__(self): + return self._value.__invert__() + + def __eq__(self, oc): + return self._value == (oc._value if isinstance(oc, JaxArray) else oc) + + def __ne__(self, oc): + return self._value != (oc._value if isinstance(oc, JaxArray) else oc) + + def __lt__(self, oc): + return self._value < (oc._value if isinstance(oc, JaxArray) else oc) + + def __le__(self, oc): + return self._value <= (oc._value if isinstance(oc, JaxArray) else oc) + + def __gt__(self, oc): + return self._value > (oc._value if isinstance(oc, JaxArray) else oc) + + def __ge__(self, oc): + return self._value >= (oc._value if isinstance(oc, JaxArray) else oc) + + def __add__(self, oc): + return self._value + (oc._value if isinstance(oc, JaxArray) else oc) + + def __radd__(self, oc): + return self._value + (oc._value if isinstance(oc, JaxArray) else oc) + + def __sub__(self, oc): + return self._value - (oc._value if isinstance(oc, JaxArray) else oc) + + def __rsub__(self, oc): + return (oc._value if isinstance(oc, JaxArray) else oc) - self._value + + def __mul__(self, oc): + return self._value * (oc._value if isinstance(oc, JaxArray) else oc) + + def __rmul__(self, oc): + return (oc._value if isinstance(oc, JaxArray) else oc) * self._value + + def __rdiv__(self, oc): + return (oc._value if isinstance(oc, JaxArray) else oc) / self._value + + def __truediv__(self, oc): + return self._value / (oc._value if isinstance(oc, JaxArray) else oc) + + def __rtruediv__(self, oc): + return (oc._value if isinstance(oc, JaxArray) else oc) / self._value + + def __floordiv__(self, oc): + return self._value // (oc._value if isinstance(oc, JaxArray) else oc) + + def __rfloordiv__(self, oc): + return (oc._value if isinstance(oc, JaxArray) else oc) // self._value + + def __divmod__(self, oc): + return self._value.__divmod__(oc._value if isinstance(oc, JaxArray) else oc) + + def __rdivmod__(self, oc): + return self._value.__rdivmod__(oc._value if isinstance(oc, JaxArray) else oc) + + def __mod__(self, oc): + return self._value % (oc._value if isinstance(oc, JaxArray) else oc) + + def __rmod__(self, oc): + return (oc._value if isinstance(oc, JaxArray) else oc) % self._value + + def __pow__(self, oc): + return self._value ** (oc._value if isinstance(oc, JaxArray) else oc) + + def __rpow__(self, oc): + return (oc._value if isinstance(oc, JaxArray) else oc) ** self._value + + def __matmul__(self, oc): + return self._value @ (oc._value if isinstance(oc, JaxArray) else oc) + + def __rmatmul__(self, oc): + return (oc._value if isinstance(oc, JaxArray) else oc) @ self._value + + def __and__(self, oc): + return self._value & (oc._value if isinstance(oc, JaxArray) else oc) + + def __rand__(self, oc): + return (oc._value if isinstance(oc, JaxArray) else oc) & self._value + + def __or__(self, oc): + return self._value | (oc._value if isinstance(oc, JaxArray) else oc) + + def __ror__(self, oc): + return (oc._value if isinstance(oc, JaxArray) else oc) | self._value + + def __xor__(self, oc): + return self._value ^ (oc._value if isinstance(oc, JaxArray) else oc) + + def __rxor__(self, oc): + return (oc._value if isinstance(oc, JaxArray) else oc) ^ self._value + + def __lshift__(self, oc): + return self._value << (oc._value if isinstance(oc, JaxArray) else oc) + + def __rlshift__(self, oc): + return (oc._value if isinstance(oc, JaxArray) else oc) << self._value + + def __rshift__(self, oc): + return self._value >> (oc._value if isinstance(oc, JaxArray) else oc) + + def __rrshift__(self, oc): + return (oc._value if isinstance(oc, JaxArray) else oc) >> self._value + + def __round__(self, ndigits=None): + return self._value.__round__(ndigits) + + # ----------------------- # + # NumPy methods # + # ----------------------- # + + def all(self, axis=None, keepdims=False): + """Returns True if all elements evaluate to True.""" + return self.value.all(axis=axis, keepdims=keepdims) + + def any(self, axis=None, keepdims=False): + """Returns True if any of the elements of a evaluate to True.""" + return self.value.any(axis=axis, keepdims=keepdims) + + def argmax(self, axis=None): + """Return indices of the maximum values along the given axis.""" + return self.value.argmax(axis=axis) + + def argmin(self, axis=None): + """Return indices of the minimum values along the given axis.""" + return self.value.argmin(axis=axis) + + def argpartition(self, kth, axis=-1, kind='introselect', order=None): + """Returns the indices that would partition this array.""" + return self.value.argpartition(kth=kth, axis=axis, kind=kind, order=order) + + def argsort(self, axis=-1, kind=None, order=None): + """Returns the indices that would sort this array.""" + return self.value.argsort(axis=axis, kind=kind, order=order) + + def astype(self, dtype): + """Copy of the array, cast to a specified type. + + Parameters + ---------- + dtype: str, dtype + Typecode or data-type to which the array is cast. + """ + return self.value.astype(dtype=dtype) + + def byteswap(self, inplace=False): + """Swap the bytes of the array elements + + Toggle between low-endian and big-endian data representation by + returning a byteswapped array, optionally swapped in-place. + Arrays of byte-strings are not swapped. The real and imaginary + parts of a complex number are swapped individually.""" + return self.value.byteswap(inplace=inplace) + + def choose(self, choices, mode='raise'): + """Use an index array to construct a new array from a set of choices.""" + choices = choices.value if isinstance(choices, JaxArray) else choices + return self.value.choose(choices=choices, mode=mode) + + def clip(self, min=None, max=None): + """Return an array whose values are limited to [min, max]. One of max or min must be given.""" + return self.value.clip(min=min, max=max) + + def compress(self, condition, axis=None): + """Return selected slices of this array along given axis.""" + condition = condition.value if isinstance(condition, JaxArray) else condition + return self.value.compress(condition=condition, axis=axis) + + def conj(self): + """Complex-conjugate all elements.""" + return self.value.conj() + + def conjugate(self): + """Return the complex conjugate, element-wise.""" + return self.value.conjugate() + + def copy(self): + """Return a copy of the array.""" + return self.value.copy() + + def cumprod(self, axis=None, dtype=None): + """Return the cumulative product of the elements along the given axis.""" + return self.value.cumprod(axis=axis, dtype=dtype) + + def cumsum(self, axis=None, dtype=None): + """Return the cumulative sum of the elements along the given axis.""" + return self.value.cumsum(axis=axis, dtype=dtype) + + def diagonal(self, offset=0, axis1=0, axis2=1): + """Return specified diagonals.""" + return self.value.diagonal(offset=offset, axis1=axis1, axis2=axis2) + + def dot(self, b): + """Dot product of two arrays.""" + return self.value.dot(b.value if isinstance(b, JaxArray) else b) + + def flatten(self, order='C'): + return self.value.flatten(order=order) + + def item(self, *args): + """Copy an element of an array to a standard Python scalar and return it.""" + return self.value.item(*args) + + def max(self, axis=None, keepdims=False, *args, **kwargs): + """Return the maximum along a given axis.""" + return self.value.max(axis=axis, keepdims=keepdims, *args, **kwargs) + + def mean(self, axis=None, dtype=None, keepdims=False, *args, **kwargs): + """Returns the average of the array elements along given axis.""" + return self.value.mean(axis=axis, dtype=dtype, keepdims=keepdims, *args, **kwargs) + + def min(self, axis=None, keepdims=False, *args, **kwargs): + """Return the minimum along a given axis.""" + return self.value.min(axis=axis, keepdims=keepdims, *args, **kwargs) + + def nonzero(self): + """Return the indices of the elements that are non-zero.""" + return self.value.nonzero() + + def prod(self, axis=None, dtype=None, keepdims=False, initial=1, where=True): + """Return the product of the array elements over the given axis.""" + return self.value.prod(axis=axis, dtype=dtype, keepdims=keepdims, initial=initial, where=where) + + def ptp(self, axis=None, keepdims=False): + """Peak to peak (maximum - minimum) value along a given axis.""" + return self.value.ptp(axis=axis, keepdims=keepdims) + + def ravel(self, order=None): + """Return a flattened array.""" + return self.value.ravel(order=order) + + def repeat(self, repeats, axis=None): + """Repeat elements of an array.""" + return self.value.repeat(repeats=repeats, axis=axis) + + def reshape(self, *shape, order='C'): + """Returns an array containing the same data with a new shape.""" + return self.value.reshape(*shape, order=order) + + def round(self, decimals=0): + """Return ``a`` with each element rounded to the given number of decimals.""" + return self.value.round(decimals=decimals) + + def searchsorted(self, v, side='left', sorter=None): + """Find indices where elements should be inserted to maintain order. + + Find the indices into a sorted array `a` such that, if the + corresponding elements in `v` were inserted before the indices, the + order of `a` would be preserved. + + Assuming that `a` is sorted: + + ====== ============================ + `side` returned index `i` satisfies + ====== ============================ + left ``a[i-1] < v <= a[i]`` + right ``a[i-1] <= v < a[i]`` + ====== ============================ + + Parameters + ---------- + v : array_like + Values to insert into `a`. + side : {'left', 'right'}, optional + If 'left', the index of the first suitable location found is given. + If 'right', return the last such index. If there is no suitable + index, return either 0 or N (where N is the length of `a`). + sorter : 1-D array_like, optional + Optional array of integer indices that sort array a into ascending + order. They are typically the result of argsort. + + Returns + ------- + indices : array of ints + Array of insertion points with the same shape as `v`. + """ + v = v.value if isinstance(v, JaxArray) else v + return self.value.searchsorted(v=v, side=side, sorter=sorter) + + def squeeze(self, axis=None): + """Remove axes of length one from ``a``.""" + return self.value.squeeze(axis=axis) + + def std(self, axis=None, dtype=None, ddof=0, keepdims=False): + """Compute the standard deviation along the specified axis. + + Returns the standard deviation, a measure of the spread of a distribution, + of the array elements. The standard deviation is computed for the + flattened array by default, otherwise over the specified axis. + + Parameters + ---------- + axis : None or int or tuple of ints, optional + Axis or axes along which the standard deviation is computed. The + default is to compute the standard deviation of the flattened array. + If this is a tuple of ints, a standard deviation is performed over + multiple axes, instead of a single axis or all the axes as before. + dtype : dtype, optional + Type to use in computing the standard deviation. For arrays of + integer type the default is float64, for arrays of float types it is + the same as the array type. + ddof : int, optional + Means Delta Degrees of Freedom. The divisor used in calculations + is ``N - ddof``, where ``N`` represents the number of elements. + By default `ddof` is zero. + keepdims : bool, optional + If this is set to True, the axes which are reduced are left + in the result as dimensions with size one. With this option, + the result will broadcast correctly against the input array. + + If the default value is passed, then `keepdims` will not be + passed through to the `std` method of sub-classes of + `ndarray`, however any non-default value will be. If the + sub-class' method does not implement `keepdims` any + exceptions will be raised. + + Returns + ------- + standard_deviation : ndarray, see dtype parameter above. + If `out` is None, return a new array containing the standard deviation, + otherwise return a reference to the output array. + """ + return self.value.std(axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims) + + def sum(self, axis=None, dtype=None, keepdims=False, initial=0, where=True): + """Return the sum of the array elements over the given axis.""" + return self.value.sum(axis=axis, dtype=dtype, keepdims=keepdims, initial=initial, where=where) + + def swapaxes(self, axis1, axis2): + """Return a view of the array with `axis1` and `axis2` interchanged.""" + return self.value.swapaxes(axis1, axis2) + + def split(self, indices_or_sections, axis=0): + """Split an array into multiple sub-arrays as views into ``ary``. + + Parameters + ---------- + indices_or_sections : int, 1-D array + If `indices_or_sections` is an integer, N, the array will be divided + into N equal arrays along `axis`. If such a split is not possible, + an error is raised. + + If `indices_or_sections` is a 1-D array of sorted integers, the entries + indicate where along `axis` the array is split. For example, + ``[2, 3]`` would, for ``axis=0``, result in + + - ary[:2] + - ary[2:3] + - ary[3:] + + If an index exceeds the dimension of the array along `axis`, + an empty sub-array is returned correspondingly. + axis : int, optional + The axis along which to split, default is 0. + + Returns + ------- + sub-arrays : list of ndarrays + A list of sub-arrays as views into `ary`. + """ + return [JaxArray(a) for a in self.value.split(indices_or_sections, axis=axis)] + + def take(self, indices, axis=None, mode=None): + """Return an array formed from the elements of a at the given indices.""" + indices = indices.value if isinstance(indices, JaxArray) else indices + return self.value.take(indices=indices, axis=axis, mode=mode) + + def tobytes(self, order='C'): + """Construct Python bytes containing the raw data bytes in the array. + + Constructs Python bytes showing a copy of the raw contents of data memory. + The bytes object is produced in C-order by default. This behavior is + controlled by the ``order`` parameter.""" + return self.value.tobytes(order=order) + + def tolist(self): + """Return the array as an ``a.ndim``-levels deep nested list of Python scalars. + + Return a copy of the array data as a (nested) Python list. + Data items are converted to the nearest compatible builtin Python type, via + the `~numpy.ndarray.item` function. + + If ``a.ndim`` is 0, then since the depth of the nested list is 0, it will + not be a list at all, but a simple Python scalar. + """ + return self.value.tolist() + + def trace(self, offset=0, axis1=0, axis2=1, dtype=None): + """Return the sum along diagonals of the array.""" + return self.value.trace(offset=offset, axis1=axis1, axis2=axis2, dtype=dtype) + + def transpose(self, *axes): + """Returns a view of the array with axes transposed. + + For a 1-D array this has no effect, as a transposed vector is simply the + same vector. To convert a 1-D array into a 2D column vector, an additional + dimension must be added. `np.atleast2d(a).T` achieves this, as does + `a[:, np.newaxis]`. + For a 2-D array, this is a standard matrix transpose. + For an n-D array, if axes are given, their order indicates how the + axes are permuted (see Examples). If axes are not provided and + ``a.shape = (i[0], i[1], ... i[n-2], i[n-1])``, then + ``a.transpose().shape = (i[n-1], i[n-2], ... i[1], i[0])``. + + Parameters + ---------- + axes : None, tuple of ints, or `n` ints + + * None or no argument: reverses the order of the axes. + + * tuple of ints: `i` in the `j`-th place in the tuple means `a`'s + `i`-th axis becomes `a.transpose()`'s `j`-th axis. + + * `n` ints: same as an n-tuple of the same ints (this form is + intended simply as a "convenience" alternative to the tuple form) + + Returns + ------- + out : ndarray + View of `a`, with axes suitably permuted. + """ + return self.value.transpose(*axes) + + def tile(self, reps): + """Construct an array by repeating A the number of times given by reps. + + If `reps` has length ``d``, the result will have dimension of + ``max(d, A.ndim)``. + + If ``A.ndim < d``, `A` is promoted to be d-dimensional by prepending new + axes. So a shape (3,) array is promoted to (1, 3) for 2-D replication, + or shape (1, 1, 3) for 3-D replication. If this is not the desired + behavior, promote `A` to d-dimensions manually before calling this + function. + + If ``A.ndim > d``, `reps` is promoted to `A`.ndim by pre-pending 1's to it. + Thus for an `A` of shape (2, 3, 4, 5), a `reps` of (2, 2) is treated as + (1, 1, 2, 2). + + Note : Although tile may be used for broadcasting, it is strongly + recommended to use numpy's broadcasting operations and functions. + + Parameters + ---------- + reps : array_like + The number of repetitions of `A` along each axis. + + Returns + ------- + c : ndarray + The tiled output array. + """ + return self.value.tile(reps.value if isinstance(reps, JaxArray) else reps) + + def var(self, axis=None, dtype=None, ddof=0, keepdims=False): + """Returns the variance of the array elements, along given axis.""" + return self.value.var(axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims) + + def view(self, dtype=None, *args, **kwargs): + """New view of array with the same data.""" + return self.value.view(dtype=dtype, *args, **kwargs) + class TrainVar(Variable): """The pointer to specify the trainable variable. """ - __slots__ = ('_value',) + __slots__ = ('_value', '_batch_axis') - def __init__(self, value): - super(TrainVar, self).__init__(value) + def __init__(self, value, dtype=None, batch_axis: int = None): + super(TrainVar, self).__init__(value, dtype=dtype, batch_axis=batch_axis) class Parameter(Variable): """The pointer to specify the parameter. """ - __slots__ = ('_value',) + __slots__ = ('_value', '_batch_axis') - def __init__(self, value): - super(Parameter, self).__init__(value) + def __init__(self, value, dtype=None, batch_axis: int = None): + super(Parameter, self).__init__(value, dtype=dtype, batch_axis=batch_axis) register_pytree_node(JaxArray, diff --git a/brainpy/math/numpy_ops.py b/brainpy/math/numpy_ops.py index ab4fbd938..0cb773d3b 100644 --- a/brainpy/math/numpy_ops.py +++ b/brainpy/math/numpy_ops.py @@ -406,7 +406,7 @@ def lexsort(keys, axis=-1): return JaxArray(jnp.lexsort(keys, axis)) -load = wraps(jnp.histogram_bin_edges)(jnp.load) +load = wraps(jnp.load)(jnp.load) @wraps(np.save) @@ -1788,7 +1788,7 @@ def identity(n, dtype=None): @wraps(jnp.array) -def array(a, dtype=None, copy=True, order="K", ndmin=0): +def array(a, dtype=None, copy=True, order="K", ndmin=0) -> JaxArray: a = _remove_jaxarray(a) try: res = jnp.array(a, dtype=dtype, copy=copy, order=order, ndmin=ndmin) @@ -1815,7 +1815,7 @@ def asarray(a, dtype=None, order=None): Returns ------- - out : ndarray + out : JaxArray Array interpretation of `a`. No copy is performed if the input is already an ndarray with matching dtype. """ diff --git a/brainpy/math/operators.py b/brainpy/math/operators.py deleted file mode 100644 index cda999979..000000000 --- a/brainpy/math/operators.py +++ /dev/null @@ -1,874 +0,0 @@ -# -*- coding: utf-8 -*- - -from typing import Union, Sequence, Callable, Optional, Dict - -import jax.numpy as jnp -from jax import jit, vmap, lax -from jax import ops as jops -from jax.abstract_arrays import ShapedArray - -from brainpy.errors import PackageMissingError, MathError -from brainpy.math import setting -from brainpy.math.jaxarray import JaxArray -from brainpy.math.numpy_ops import as_device_array, _remove_jaxarray - -try: - import brainpylib -except ModuleNotFoundError: - brainpylib = None - -__all__ = [ - # pre-to-post - 'pre2post_sum', - 'pre2post_prod', - 'pre2post_max', - 'pre2post_min', - 'pre2post_mean', - - # pre-to-syn - 'pre2syn', - - # syn-to-post - 'syn2post_sum', 'syn2post', - 'syn2post_prod', - 'syn2post_max', - 'syn2post_min', - 'syn2post_mean', - 'syn2post_softmax', - - # pre-to-post event operator - 'pre2post_event_sum', - 'pre2post_event_prod', - - # others - 'sparse_matmul', - 'segment_sum', - 'segment_prod', - 'segment_max', - 'segment_min', - - # numba operators - 'register_op' -] - -_BRAINPYLIB_MINIMAL_VERSION = '0.0.5' - -_pre2post = vmap(lambda pre_ids, pre_vs: pre_vs[pre_ids].sum(), in_axes=(0, None)) -_pre2syn = vmap(lambda pre_id, pre_vs: pre_vs[pre_id], in_axes=(0, None)) -_jit_seg_sum = jit(jops.segment_sum, static_argnums=(2, 3)) -_jit_seg_prod = jit(jops.segment_prod, static_argnums=(2, 3)) -_jit_seg_max = jit(jops.segment_max, static_argnums=(2, 3)) -_jit_seg_min = jit(jops.segment_min, static_argnums=(2, 3)) - - -def _check_brainpylib(ops_name): - if brainpylib is not None: - if brainpylib.__version__ < _BRAINPYLIB_MINIMAL_VERSION: - raise PackageMissingError( - f'"{ops_name}" operator need "brainpylib>={_BRAINPYLIB_MINIMAL_VERSION}". \n' - f'Please install it through:\n\n' - f'>>> pip install brainpylib>={_BRAINPYLIB_MINIMAL_VERSION} -U' - ) - else: - raise PackageMissingError( - f'"brainpylib" must be installed when the user ' - f'wants to use "{ops_name}" operator. \n' - f'Please install "brainpylib>={_BRAINPYLIB_MINIMAL_VERSION}" through:\n\n' - f'>>> pip install brainpylib>={_BRAINPYLIB_MINIMAL_VERSION}' - ) - - -def register_op( - op_name: str, - cpu_func: Callable, - gpu_func: Callable = None, - out_shapes: Union[Callable, ShapedArray, Sequence[ShapedArray]] = None, - apply_cpu_func_to_gpu: bool = False -): - """ - Converting the numba-jitted function in a Jax/XLA compatible primitive. - - Parameters - ---------- - op_name: str - Name of the operators. - cpu_func: Callble - A callable numba-jitted function or pure function (can be lambda function) running on CPU. - gpu_func: Callable, default = None - A callable cuda-jitted kernel running on GPU. - out_shapes: Callable, ShapedArray, Sequence[ShapedArray], default = None - Outputs shapes of target function. `out_shapes` can be a `ShapedArray` or - a sequence of `ShapedArray`. If it is a function, it takes as input the argument - shapes and dtypes and should return correct output shapes of `ShapedArray`. - apply_cpu_func_to_gpu: bool, default = False - True when gpu_func is implemented on CPU and other logics(data transfer) is implemented on GPU. - - Returns - ------- - A jitable JAX function. - """ - _check_brainpylib(register_op.__name__) - f = brainpylib.register_op(op_name, cpu_func, gpu_func, out_shapes, apply_cpu_func_to_gpu) - - def fixed_op(*inputs): - inputs = tuple([i.value if isinstance(i, JaxArray) else i for i in inputs]) - return f(*inputs) - - return fixed_op - - -def pre2post_event_sum(events, pre2post, post_num, values=1.): - """The pre-to-post synaptic computation with event-driven summation. - - When ``values`` is a scalar, this function is equivalent to - - .. highlight:: python - .. code-block:: python - - post_val = np.zeros(post_num) - post_ids, idnptr = pre2post - for i in range(pre_num): - if events[i]: - for j in range(idnptr[i], idnptr[i+1]): - post_val[post_ids[i]] += values - - When ``values`` is a vector (with the length of ``len(post_ids)``), - this function is equivalent to - - .. highlight:: python - .. code-block:: python - - post_val = np.zeros(post_num) - - post_ids, idnptr = pre2post - for i in range(pre_num): - if events[i]: - for j in range(idnptr[i], idnptr[i+1]): - post_val[post_ids[i]] += values[j] - - - Parameters - ---------- - events: JaxArray, jax.numpy.ndarray, Variable - The events, must be bool. - pre2post: tuple of JaxArray, tuple of jax.numpy.ndarray - A tuple contains the connection information of pre-to-post. - post_num: int - The number of post-synaptic group. - values: float, JaxArray, jax.numpy.ndarray - The value to make summation. - - Returns - ------- - out: JaxArray, jax.numpy.ndarray - A tensor with the shape of ``post_num``. - """ - _check_brainpylib(pre2post_event_sum.__name__) - indices, idnptr = pre2post - events = as_device_array(events) - indices = as_device_array(indices) - idnptr = as_device_array(idnptr) - values = as_device_array(values) - return brainpylib.event_sum(events, (indices, idnptr), post_num, values) - - -def pre2post_event_prod(events, pre2post, post_num, values=1.): - """The pre-to-post synaptic computation with event-driven production. - - When ``values`` is a scalar, this function is equivalent to - - .. highlight:: python - .. code-block:: python - - post_val = np.ones(post_num) - post_ids, idnptr = pre2post - for i in range(pre_num): - if events[i]: - for j in range(idnptr[i], idnptr[i+1]): - post_val[post_ids[i]] *= values - - When ``values`` is a vector (with the length of ``len(post_ids)``), - this function is equivalent to - - .. highlight:: python - .. code-block:: python - - post_val = np.ones(post_num) - - post_ids, idnptr = pre2post - for i in range(pre_num): - if events[i]: - for j in range(idnptr[i], idnptr[i+1]): - post_val[post_ids[i]] *= values[j] - - - Parameters - ---------- - events: JaxArray, jax.numpy.ndarray, Variable - The events, must be bool. - pre2post: tuple of JaxArray, tuple of jax.numpy.ndarray - A tuple contains the connection information of pre-to-post. - post_num: int - The number of post-synaptic group. - values: float, JaxArray, jax.numpy.ndarray - The value to make summation. - - Returns - ------- - out: JaxArray, jax.numpy.ndarray - A tensor with the shape of ``post_num``. - """ - _check_brainpylib(pre2post_event_prod.__name__) - indices, idnptr = pre2post - events = as_device_array(events) - indices = as_device_array(indices) - idnptr = as_device_array(idnptr) - values = as_device_array(values) - return brainpylib.event_prod(events, (indices, idnptr), post_num, values) - - -def _raise_pre_ids_is_none(pre_ids): - if pre_ids is None: - raise MathError(f'pre2post synaptic computation needs "pre_ids" ' - f'when providing heterogeneous "pre_values" ' - f'(brainpy.math.ndim(pre_values) != 0).') - - -def pre2post_sum(pre_values, post_num, post_ids, pre_ids=None): - """The pre-to-post synaptic summation. - - This function is equivalent to: - - .. highlight:: python - .. code-block:: python - - post_val = np.zeros(post_num) - for i, j in zip(pre_ids, post_ids): - post_val[j] += pre_values[pre_ids[i]] - - Parameters - ---------- - pre_values: float, jax.numpy.ndarray, JaxArray, Variable - The pre-synaptic values. - post_ids: jax.numpy.ndarray, JaxArray - The connected post-synaptic neuron ids. - post_num: int - Output dimension. The number of post-synaptic neurons. - pre_ids: optional, jax.numpy.ndarray, JaxArray - The connected pre-synaptic neuron ids. - - Returns - ------- - post_val: jax.numpy.ndarray, JaxArray - The value with the size of post-synaptic neurons. - """ - out = jnp.zeros(post_num) - pre_values = as_device_array(pre_values) - post_ids = as_device_array(post_ids) - if jnp.ndim(pre_values) != 0: - _raise_pre_ids_is_none(pre_ids) - pre_ids = as_device_array(pre_ids) - pre_values = pre_values[pre_ids] - return out.at[post_ids].add(pre_values) - - -def pre2post_prod(pre_values, post_num, post_ids, pre_ids=None): - """The pre-to-post synaptic production. - - This function is equivalent to: - - .. highlight:: python - .. code-block:: python - - post_val = np.zeros(post_num) - for i, j in zip(pre_ids, post_ids): - post_val[j] *= pre_values[pre_ids[i]] - - Parameters - ---------- - pre_values: float, jax.numpy.ndarray, JaxArray, Variable - The pre-synaptic values. - pre_ids: jax.numpy.ndarray, JaxArray - The connected pre-synaptic neuron ids. - post_ids: jax.numpy.ndarray, JaxArray - The connected post-synaptic neuron ids. - post_num: int - Output dimension. The number of post-synaptic neurons. - - Returns - ------- - post_val: jax.numpy.ndarray, JaxArray - The value with the size of post-synaptic neurons. - """ - out = jnp.zeros(post_num) - pre_values = as_device_array(pre_values) - post_ids = as_device_array(post_ids) - if jnp.ndim(pre_values) != 0: - _raise_pre_ids_is_none(pre_ids) - pre_ids = as_device_array(pre_ids) - pre_values = pre_values[pre_ids] - return out.at[post_ids].multiply(pre_values) - - -def pre2post_min(pre_values, post_num, post_ids, pre_ids=None): - """The pre-to-post synaptic minimization. - - This function is equivalent to: - - .. highlight:: python - .. code-block:: python - - post_val = np.zeros(post_num) - for i, j in zip(pre_ids, post_ids): - post_val[j] = np.minimum(post_val[j], pre_values[pre_ids[i]]) - - Parameters - ---------- - pre_values: float, jax.numpy.ndarray, JaxArray - The pre-synaptic values. - pre_ids: jax.numpy.ndarray, JaxArray - The connected pre-synaptic neuron ids. - post_ids: jax.numpy.ndarray, JaxArray - The connected post-synaptic neuron ids. - post_num: int - Output dimension. The number of post-synaptic neurons. - - Returns - ------- - post_val: jax.numpy.ndarray, JaxArray - The value with the size of post-synaptic neurons. - """ - out = jnp.zeros(post_num) - pre_values = as_device_array(pre_values) - post_ids = as_device_array(post_ids) - if jnp.ndim(pre_values) != 0: - _raise_pre_ids_is_none(pre_ids) - pre_ids = as_device_array(pre_ids) - pre_values = pre_values[pre_ids] - return out.at[post_ids].min(pre_values) - - -def pre2post_max(pre_values, post_num, post_ids, pre_ids=None): - """The pre-to-post synaptic maximization. - - This function is equivalent to: - - .. highlight:: python - .. code-block:: python - - post_val = np.zeros(post_num) - for i, j in zip(pre_ids, post_ids): - post_val[j] = np.maximum(post_val[j], pre_values[pre_ids[i]]) - - Parameters - ---------- - pre_values: float, jax.numpy.ndarray, JaxArray, Variable - The pre-synaptic values. - pre_ids: jax.numpy.ndarray, JaxArray - The connected pre-synaptic neuron ids. - post_ids: jax.numpy.ndarray, JaxArray - The connected post-synaptic neuron ids. - post_num: int - Output dimension. The number of post-synaptic neurons. - - Returns - ------- - post_val: jax.numpy.ndarray, JaxArray - The value with the size of post-synaptic neurons. - """ - out = jnp.zeros(post_num) - pre_values = as_device_array(pre_values) - post_ids = as_device_array(post_ids) - if jnp.ndim(pre_values) != 0: - _raise_pre_ids_is_none(pre_ids) - pre_ids = as_device_array(pre_ids) - pre_values = pre_values[pre_ids] - return out.at[post_ids].max(pre_values) - - -def pre2post_mean(pre_values, post_num, post_ids, pre_ids=None): - """The pre-to-post synaptic mean computation. - - Parameters - ---------- - pre_values: float, jax.numpy.ndarray, JaxArray, Variable - The pre-synaptic values. - pre_ids: jax.numpy.ndarray, JaxArray - The connected pre-synaptic neuron ids. - post_ids: jax.numpy.ndarray, JaxArray - The connected post-synaptic neuron ids. - post_num: int - Output dimension. The number of post-synaptic neurons. - - Returns - ------- - post_val: jax.numpy.ndarray, JaxArray - The value with the size of post-synaptic neurons. - """ - out = jnp.zeros(post_num) - pre_values = as_device_array(pre_values) - post_ids = as_device_array(post_ids) - if jnp.ndim(pre_values) == 0: - # return out.at[post_ids].set(pre_values) - return out.at[jnp.unique(post_ids)].set(pre_values) - else: - _raise_pre_ids_is_none(pre_ids) - pre_ids = as_device_array(pre_ids) - pre_values = pre2syn(pre_values, pre_ids) - return syn2post_mean(pre_values, post_ids, post_num) - - -def pre2syn(pre_values, pre_ids): - """The pre-to-syn computation. - - Change the pre-synaptic data to the data with the dimension of synapses. - - This function is equivalent to: - - .. highlight:: python - .. code-block:: python - - syn_val = np.zeros(len(pre_ids)) - for syn_i, pre_i in enumerate(pre_ids): - syn_val[i] = pre_values[pre_i] - - Parameters - ---------- - pre_values: float, jax.numpy.ndarray, JaxArray, Variable - The pre-synaptic value. - pre_ids: jax.numpy.ndarray, JaxArray - The pre-synaptic neuron index. - - Returns - ------- - syn_val: jax.numpy.ndarray, JaxArray - The synaptic value. - """ - pre_values = as_device_array(pre_values) - pre_ids = as_device_array(pre_ids) - if jnp.ndim(pre_values) == 0: - return jnp.ones(len(pre_ids), dtype=pre_values.dtype) * pre_values - else: - return _pre2syn(pre_ids, pre_values) - - -def syn2post_sum(syn_values, post_ids, post_num: int, indices_are_sorted=True): - """The syn-to-post summation computation. - - This function is equivalent to: - - .. highlight:: python - .. code-block:: python - - post_val = np.zeros(post_num) - for syn_i, post_i in enumerate(post_ids): - post_val[post_i] += syn_values[syn_i] - - Parameters - ---------- - syn_values: jax.numpy.ndarray, JaxArray, Variable - The synaptic values. - post_ids: jax.numpy.ndarray, JaxArray - The post-synaptic neuron ids. - post_num: int - The number of the post-synaptic neurons. - - Returns - ------- - post_val: jax.numpy.ndarray, JaxArray - The post-synaptic value. - """ - post_ids = as_device_array(post_ids) - syn_values = as_device_array(syn_values) - if syn_values.dtype == jnp.bool_: - syn_values = jnp.asarray(syn_values, dtype=jnp.int32) - return _jit_seg_sum(syn_values, post_ids, post_num, indices_are_sorted) - - -syn2post = syn2post_sum - - -def syn2post_prod(syn_values, post_ids, post_num: int, indices_are_sorted=True): - """The syn-to-post product computation. - - This function is equivalent to: - - .. highlight:: python - .. code-block:: python - - post_val = np.zeros(post_num) - for syn_i, post_i in enumerate(post_ids): - post_val[post_i] *= syn_values[syn_i] - - Parameters - ---------- - syn_values: jax.numpy.ndarray, JaxArray, Variable - The synaptic values. - post_ids: jax.numpy.ndarray, JaxArray - The post-synaptic neuron ids. If ``post_ids`` is generated by - ``brainpy.conn.TwoEndConnector``, then it has sorted indices. - Otherwise, this function cannot guarantee indices are sorted. - You's better set ``indices_are_sorted=False``. - post_num: int - The number of the post-synaptic neurons. - indices_are_sorted: whether ``post_ids`` is known to be sorted. - - Returns - ------- - post_val: jax.numpy.ndarray, JaxArray - The post-synaptic value. - """ - post_ids = as_device_array(post_ids) - syn_values = as_device_array(syn_values) - if syn_values.dtype == jnp.bool_: - syn_values = jnp.asarray(syn_values, dtype=jnp.int32) - return _jit_seg_prod(syn_values, post_ids, post_num, indices_are_sorted) - - -def syn2post_max(syn_values, post_ids, post_num: int, indices_are_sorted=True): - """The syn-to-post maximum computation. - - This function is equivalent to: - - .. highlight:: python - .. code-block:: python - - post_val = np.zeros(post_num) - for syn_i, post_i in enumerate(post_ids): - post_val[post_i] = np.maximum(post_val[post_i], syn_values[syn_i]) - - Parameters - ---------- - syn_values: jax.numpy.ndarray, JaxArray, Variable - The synaptic values. - post_ids: jax.numpy.ndarray, JaxArray - The post-synaptic neuron ids. If ``post_ids`` is generated by - ``brainpy.conn.TwoEndConnector``, then it has sorted indices. - Otherwise, this function cannot guarantee indices are sorted. - You's better set ``indices_are_sorted=False``. - post_num: int - The number of the post-synaptic neurons. - indices_are_sorted: whether ``post_ids`` is known to be sorted. - - Returns - ------- - post_val: jax.numpy.ndarray, JaxArray - The post-synaptic value. - """ - post_ids = as_device_array(post_ids) - syn_values = as_device_array(syn_values) - if syn_values.dtype == jnp.bool_: - syn_values = jnp.asarray(syn_values, dtype=jnp.int32) - return _jit_seg_max(syn_values, post_ids, post_num, indices_are_sorted) - - -def syn2post_min(syn_values, post_ids, post_num: int, indices_are_sorted=True): - """The syn-to-post minimization computation. - - This function is equivalent to: - - .. highlight:: python - .. code-block:: python - - post_val = np.zeros(post_num) - for syn_i, post_i in enumerate(post_ids): - post_val[post_i] = np.minimum(post_val[post_i], syn_values[syn_i]) - - Parameters - ---------- - syn_values: jax.numpy.ndarray, JaxArray, Variable - The synaptic values. - post_ids: jax.numpy.ndarray, JaxArray - The post-synaptic neuron ids. If ``post_ids`` is generated by - ``brainpy.conn.TwoEndConnector``, then it has sorted indices. - Otherwise, this function cannot guarantee indices are sorted. - You's better set ``indices_are_sorted=False``. - post_num: int - The number of the post-synaptic neurons. - indices_are_sorted: whether ``post_ids`` is known to be sorted. - - Returns - ------- - post_val: jax.numpy.ndarray, JaxArray - The post-synaptic value. - """ - post_ids = as_device_array(post_ids) - syn_values = as_device_array(syn_values) - if syn_values.dtype == jnp.bool_: - syn_values = jnp.asarray(syn_values, dtype=jnp.int32) - return _jit_seg_min(syn_values, post_ids, post_num, indices_are_sorted) - - -def syn2post_mean(syn_values, post_ids, post_num: int, indices_are_sorted=True): - """The syn-to-post mean computation. - - Parameters - ---------- - syn_values: jax.numpy.ndarray, JaxArray, Variable - The synaptic values. - post_ids: jax.numpy.ndarray, JaxArray - The post-synaptic neuron ids. If ``post_ids`` is generated by - ``brainpy.conn.TwoEndConnector``, then it has sorted indices. - Otherwise, this function cannot guarantee indices are sorted. - You's better set ``indices_are_sorted=False``. - post_num: int - The number of the post-synaptic neurons. - indices_are_sorted: whether ``post_ids`` is known to be sorted. - - Returns - ------- - post_val: jax.numpy.ndarray, JaxArray - The post-synaptic value. - """ - post_ids = as_device_array(post_ids) - syn_values = as_device_array(syn_values) - if syn_values.dtype == jnp.bool_: - syn_values = jnp.asarray(syn_values, dtype=jnp.int32) - nominator = _jit_seg_sum(syn_values, post_ids, post_num, indices_are_sorted) - denominator = _jit_seg_sum(jnp.ones_like(syn_values), post_ids, post_num, indices_are_sorted) - return jnp.nan_to_num(nominator / denominator) - - -def syn2post_softmax(syn_values, post_ids, post_num: int, indices_are_sorted=True): - """The syn-to-post softmax computation. - - Parameters - ---------- - syn_values: jax.numpy.ndarray, JaxArray, Variable - The synaptic values. - post_ids: jax.numpy.ndarray, JaxArray - The post-synaptic neuron ids. If ``post_ids`` is generated by - ``brainpy.conn.TwoEndConnector``, then it has sorted indices. - Otherwise, this function cannot guarantee indices are sorted. - You's better set ``indices_are_sorted=False``. - post_num: int - The number of the post-synaptic neurons. - indices_are_sorted: whether ``post_ids`` is known to be sorted. - - Returns - ------- - post_val: jax.numpy.ndarray, JaxArray - The post-synaptic value. - """ - post_ids = as_device_array(post_ids) - syn_values = as_device_array(syn_values) - if syn_values.dtype == jnp.bool_: - syn_values = jnp.asarray(syn_values, dtype=jnp.int32) - syn_maxs = _jit_seg_max(syn_values, post_ids, post_num, indices_are_sorted) - syn_values = syn_values - syn_maxs[post_ids] - syn_values = jnp.exp(syn_values) - normalizers = _jit_seg_sum(syn_values, post_ids, post_num, indices_are_sorted) - softmax = syn_values / normalizers[post_ids] - return jnp.nan_to_num(softmax) - - -def _matmul_with_left_sparse( - sparse: Dict, - dense: Union[JaxArray, jnp.ndarray] -): - r"""Matrix multiplication with sparse matrix on the left. - - .. math:: - - Y = M_{\mathrm{sparse}} @ M_{\mathrm{dense}} - - Parameters - ---------- - sparse: dict - The sparse matrix with shape of :math:`(N, M)`. - dense: JaxArray, jnp.ndarray - The dense matrix with the shape of :math:`(M, K)`. - - Returns - ------- - matrix - A tensor the the shape of :math:`(N, K)`. - """ - assert dense.ndim in [1, 2], 'Dense matrix must be a one- or two-dimensional matrix.' - values = sparse['data'] - rows, cols = sparse['index'] - shape = sparse['shape'] - if len(shape) != 2: - raise ValueError(f'Sparse matrix must be a two-dimensional matrix. But we got {shape}') - values = _remove_jaxarray(values) - rows = _remove_jaxarray(rows) - cols = _remove_jaxarray(cols) - dense = _remove_jaxarray(dense) - B = dense.take(cols, axis=0) - if B.ndim == 2: - prod = B * jnp.reshape(values, (-1, 1)) - else: - prod = B * values - return jops.segment_sum(prod, rows, shape[0]) - - -def _matmul_with_right_sparse( - dense: Union[JaxArray, jnp.ndarray], - sparse: Dict -): - r"""Matrix multiplication with sparse matrix on the left. - - .. math:: - - Y = M_{\mathrm{dense}} @ M_{\mathrm{sparse}} - - Parameters - ---------- - dense: JaxArray, jnp.ndarray - The dense matrix with the shape of :math:`(N, M)`. - sparse: dict - The sparse matrix with shape of :math:`(M, K)`. - - Returns - ------- - matrix - A tensor the the shape of :math:`(N, K)`. - """ - assert dense.ndim in [1, 2], 'Dense matrix must be a one- or two-dimensional matrix.' - values = sparse['data'] - rows, cols = sparse['index'] - shape = sparse['shape'] - if len(shape) != 2: - raise ValueError(f'Sparse matrix must be a two-dimensional matrix. But we got {shape}') - values = _remove_jaxarray(values) - rows = _remove_jaxarray(rows) - cols = _remove_jaxarray(cols) - dense = _remove_jaxarray(dense) - if dense.ndim == 2: - A = dense[:, rows] - prod = (A * values).T - res = jops.segment_sum(prod, cols, shape[1]).T - else: - prod = dense[rows] * values - res = jops.segment_sum(prod, cols, shape[1]) - return res - - -def sparse_matmul(A, B): - r"""Sparse matrix multiplication. - - .. math:: - - y = A @ B - - where :math:`A` or :math:`B` is a sparse matrix. - :math:`A` and :math:`B` cannot be both sparse. - - Examples - -------- - - >>> import brainpy.math as bm - - 1. when the left matrix :math:`A` is a sparse matrix with the shape of :math:`(N, M)`, - - >>> # A is a sparse matrix (3, 4): - >>> # [[0, 2, 0, 4], - >>> # [1, 0, 0, 0], - >>> # [0, 3, 0, 2]] - >>> values = bm.asarray([2, 4, 1, 3, 2]) - >>> rows = bm.asarray([0, 0, 1, 2, 2]) - >>> cols = bm.asarray([1, 3, 0, 1, 3]) - >>> sparse = {'data': values, 'index': (rows, cols), 'shape': (3, 4)} - >>> B = bm.arange(4) - >>> bm.sparse_matmul(sparse, B) - JaxArray([14, 0, 9], dtype=int32) - >>> B = bm.random.rand(4, 3) - >>> bm.sparse_matmul(sparse, B) - JaxArray([[3.8331761 , 1.3708692 , 4.510223 ], - [0.9960836 , 0.37550318, 0.7370341 ], - [2.3700516 , 0.7574289 , 4.1124535 ]], dtype=float32) - - 2. when the right matrix :math:`B` is a sparse matrix with the shape of :math:`(M, K)`, - - >>> A = bm.arange(3) - >>> bm.sparse_matmul(A, sparse) - JaxArray([1, 6, 0, 4], dtype=int32) - >>> A = bm.random.rand(2, 3) - >>> bm.sparse_matmul(A, sparse) - JaxArray([[0.438388 , 1.4346815 , 0. , 2.361964 ], - [0.9171978 , 1.1214957 , 0. , 0.90534496]], dtype=float32) - - Parameters - ---------- - A: tensor, sequence - The dense or sparse matrix with the shape of :math:`(N, M)`. - B: tensor, sequence - The dense or sparse matrix with the shape of :math:`(M, K)`. - - Returns - ------- - results: JaxArray, jnp.ndarray - The tensor with the shape of :math:`(N, K)`. - """ - if isinstance(A, dict): - if not isinstance(B, (JaxArray, jnp.ndarray)): - raise ValueError('A and B cannot be both sparse. \n' - f'A:\n{A}\n' - f'B:\n{B}') - return _matmul_with_left_sparse(A, B) - else: - if not isinstance(B, dict): - raise ValueError('A and B cannot be both dense. \n' - f'A:\n{A}\n' - f'B:\n{B}') - return _matmul_with_right_sparse(A, B) - - -def segment_sum(data: Union[JaxArray, jnp.ndarray], - segment_ids: Union[JaxArray, jnp.ndarray], - num_segments: Optional[int] = None, - indices_are_sorted: bool = False, - unique_indices: bool = False, - bucket_size: Optional[int] = None, - mode: Optional[lax.GatherScatterMode] = None) -> JaxArray: - return JaxArray(jops.segment_sum(data.value if isinstance(data, JaxArray) else data, - segment_ids.value if isinstance(segment_ids, JaxArray) else segment_ids, - num_segments, - indices_are_sorted, - unique_indices, - bucket_size, mode)) - - -def segment_prod(data: Union[JaxArray, jnp.ndarray], - segment_ids: Union[JaxArray, jnp.ndarray], - num_segments: Optional[int] = None, - indices_are_sorted: bool = False, - unique_indices: bool = False, - bucket_size: Optional[int] = None, - mode: Optional[lax.GatherScatterMode] = None) -> JaxArray: - return JaxArray(jops.segment_prod(data.value if isinstance(data, JaxArray) else data, - segment_ids.value if isinstance(segment_ids, JaxArray) else segment_ids, - num_segments, - indices_are_sorted, - unique_indices, - bucket_size, mode)) - - -def segment_max(data: Union[JaxArray, jnp.ndarray], - segment_ids: Union[JaxArray, jnp.ndarray], - num_segments: Optional[int] = None, - indices_are_sorted: bool = False, - unique_indices: bool = False, - bucket_size: Optional[int] = None, - mode: Optional[lax.GatherScatterMode] = None) -> JaxArray: - return JaxArray(jops.segment_max(data.value if isinstance(data, JaxArray) else data, - segment_ids.value if isinstance(segment_ids, JaxArray) else segment_ids, - num_segments, - indices_are_sorted, - unique_indices, - bucket_size, mode)) - - -def segment_min(data: Union[JaxArray, jnp.ndarray], - segment_ids: Union[JaxArray, jnp.ndarray], - num_segments: Optional[int] = None, - indices_are_sorted: bool = False, - unique_indices: bool = False, - bucket_size: Optional[int] = None, - mode: Optional[lax.GatherScatterMode] = None) -> JaxArray: - return JaxArray(jops.segment_min(data.value if isinstance(data, JaxArray) else data, - segment_ids.value if isinstance(segment_ids, JaxArray) else segment_ids, - num_segments, - indices_are_sorted, - unique_indices, - bucket_size, mode)) diff --git a/brainpy/math/operators/__init__.py b/brainpy/math/operators/__init__.py new file mode 100644 index 000000000..517a0bc95 --- /dev/null +++ b/brainpy/math/operators/__init__.py @@ -0,0 +1,23 @@ +# -*- coding: utf-8 -*- + + +from . import multiplication +from . import op_register +from . import pre2syn as pre2syn_module +from . import pre2post as pre2post_module +from . import syn2post as syn2post_module +from . import wrap_jax +from . import spikegrad + +__all__ = multiplication.__all__ + op_register.__all__ +__all__ += pre2syn_module.__all__ + pre2post_module.__all__ + syn2post_module.__all__ +__all__ += wrap_jax.__all__ + spikegrad.__all__ + + +from .multiplication import * +from .op_register import * +from .pre2syn import * +from .pre2post import * +from .syn2post import * +from .wrap_jax import * +from .spikegrad import * diff --git a/brainpy/math/operators/multiplication.py b/brainpy/math/operators/multiplication.py new file mode 100644 index 000000000..af8dc9cf0 --- /dev/null +++ b/brainpy/math/operators/multiplication.py @@ -0,0 +1,166 @@ +# -*- coding: utf-8 -*- + + +from typing import Union, Dict + +import jax.numpy as jnp +from jax import ops as jops + +from brainpy.math.jaxarray import JaxArray +from brainpy.math.numpy_ops import _remove_jaxarray + +__all__ = [ + 'sparse_matmul' +] + + +def _matmul_with_left_sparse( + sparse: Dict, + dense: Union[JaxArray, jnp.ndarray] +): + r"""Matrix multiplication with sparse matrix on the left. + + .. math:: + + Y = M_{\mathrm{sparse}} @ M_{\mathrm{dense}} + + Parameters + ---------- + sparse: dict + The sparse matrix with shape of :math:`(N, M)`. + dense: JaxArray, jnp.ndarray + The dense matrix with the shape of :math:`(M, K)`. + + Returns + ------- + matrix + A tensor the the shape of :math:`(N, K)`. + """ + assert dense.ndim in [1, 2], 'Dense matrix must be a one- or two-dimensional matrix.' + values = sparse['data'] + rows, cols = sparse['index'] + shape = sparse['shape'] + if len(shape) != 2: + raise ValueError(f'Sparse matrix must be a two-dimensional matrix. But we got {shape}') + values = _remove_jaxarray(values) + rows = _remove_jaxarray(rows) + cols = _remove_jaxarray(cols) + dense = _remove_jaxarray(dense) + B = dense.take(cols, axis=0) + if B.ndim == 2: + prod = B * jnp.reshape(values, (-1, 1)) + else: + prod = B * values + return jops.segment_sum(prod, rows, shape[0]) + + +def _matmul_with_right_sparse( + dense: Union[JaxArray, jnp.ndarray], + sparse: Dict +): + r"""Matrix multiplication with sparse matrix on the left. + + .. math:: + + Y = M_{\mathrm{dense}} @ M_{\mathrm{sparse}} + + Parameters + ---------- + dense: JaxArray, jnp.ndarray + The dense matrix with the shape of :math:`(N, M)`. + sparse: dict + The sparse matrix with shape of :math:`(M, K)`. + + Returns + ------- + matrix + A tensor the the shape of :math:`(N, K)`. + """ + assert dense.ndim in [1, 2], 'Dense matrix must be a one- or two-dimensional matrix.' + values = sparse['data'] + rows, cols = sparse['index'] + shape = sparse['shape'] + if len(shape) != 2: + raise ValueError(f'Sparse matrix must be a two-dimensional matrix. But we got {shape}') + values = _remove_jaxarray(values) + rows = _remove_jaxarray(rows) + cols = _remove_jaxarray(cols) + dense = _remove_jaxarray(dense) + if dense.ndim == 2: + A = dense[:, rows] + prod = (A * values).T + res = jops.segment_sum(prod, cols, shape[1]).T + else: + prod = dense[rows] * values + res = jops.segment_sum(prod, cols, shape[1]) + return res + + +def sparse_matmul(A, B): + r"""Sparse matrix multiplication. + + .. math:: + + y = A @ B + + where :math:`A` or :math:`B` is a sparse matrix. + :math:`A` and :math:`B` cannot be both sparse. + + Examples + -------- + + >>> import brainpy.math as bm + + 1. when the left matrix :math:`A` is a sparse matrix with the shape of :math:`(N, M)`, + + >>> # A is a sparse matrix (3, 4): + >>> # [[0, 2, 0, 4], + >>> # [1, 0, 0, 0], + >>> # [0, 3, 0, 2]] + >>> values = bm.asarray([2, 4, 1, 3, 2]) + >>> rows = bm.asarray([0, 0, 1, 2, 2]) + >>> cols = bm.asarray([1, 3, 0, 1, 3]) + >>> sparse = {'data': values, 'index': (rows, cols), 'shape': (3, 4)} + >>> B = bm.arange(4) + >>> bm.sparse_matmul(sparse, B) + JaxArray([14, 0, 9], dtype=int32) + >>> B = bm.random.rand(4, 3) + >>> bm.sparse_matmul(sparse, B) + JaxArray([[3.8331761 , 1.3708692 , 4.510223 ], + [0.9960836 , 0.37550318, 0.7370341 ], + [2.3700516 , 0.7574289 , 4.1124535 ]], dtype=float32) + + 2. when the right matrix :math:`B` is a sparse matrix with the shape of :math:`(M, K)`, + + >>> A = bm.arange(3) + >>> bm.sparse_matmul(A, sparse) + JaxArray([1, 6, 0, 4], dtype=int32) + >>> A = bm.random.rand(2, 3) + >>> bm.sparse_matmul(A, sparse) + JaxArray([[0.438388 , 1.4346815 , 0. , 2.361964 ], + [0.9171978 , 1.1214957 , 0. , 0.90534496]], dtype=float32) + + Parameters + ---------- + A: tensor, sequence + The dense or sparse matrix with the shape of :math:`(N, M)`. + B: tensor, sequence + The dense or sparse matrix with the shape of :math:`(M, K)`. + + Returns + ------- + results: JaxArray, jnp.ndarray + The tensor with the shape of :math:`(N, K)`. + """ + if isinstance(A, dict): + if not isinstance(B, (JaxArray, jnp.ndarray)): + raise ValueError('A and B cannot be both sparse. \n' + f'A:\n{A}\n' + f'B:\n{B}') + return _matmul_with_left_sparse(A, B) + else: + if not isinstance(B, dict): + raise ValueError('A and B cannot be both dense. \n' + f'A:\n{A}\n' + f'B:\n{B}') + return _matmul_with_right_sparse(A, B) diff --git a/brainpy/math/operators/op_register.py b/brainpy/math/operators/op_register.py new file mode 100644 index 000000000..3132004cf --- /dev/null +++ b/brainpy/math/operators/op_register.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- + +from typing import Union, Sequence, Callable + +from jax.abstract_arrays import ShapedArray + +from brainpy.math.jaxarray import JaxArray +from .utils import _check_brainpylib + +try: + import brainpylib +except ModuleNotFoundError: + brainpylib = None + +__all__ = [ + 'register_op' +] + + +def register_op( + op_name: str, + cpu_func: Callable, + gpu_func: Callable = None, + out_shapes: Union[Callable, ShapedArray, Sequence[ShapedArray]] = None, + apply_cpu_func_to_gpu: bool = False +): + """ + Converting the numba-jitted function in a Jax/XLA compatible primitive. + + Parameters + ---------- + op_name: str + Name of the operators. + cpu_func: Callble + A callable numba-jitted function or pure function (can be lambda function) running on CPU. + gpu_func: Callable, default = None + A callable cuda-jitted kernel running on GPU. + out_shapes: Callable, ShapedArray, Sequence[ShapedArray], default = None + Outputs shapes of target function. `out_shapes` can be a `ShapedArray` or + a sequence of `ShapedArray`. If it is a function, it takes as input the argument + shapes and dtypes and should return correct output shapes of `ShapedArray`. + apply_cpu_func_to_gpu: bool, default = False + True when gpu_func is implemented on CPU and other logics(data transfer) is implemented on GPU. + + Returns + ------- + A jitable JAX function. + """ + _check_brainpylib(register_op.__name__) + f = brainpylib.register_op(op_name, cpu_func, gpu_func, out_shapes, apply_cpu_func_to_gpu) + + def fixed_op(*inputs): + inputs = tuple([i.value if isinstance(i, JaxArray) else i for i in inputs]) + return f(*inputs) + + return fixed_op diff --git a/brainpy/math/operators/pre2post.py b/brainpy/math/operators/pre2post.py new file mode 100644 index 000000000..66dc8f06d --- /dev/null +++ b/brainpy/math/operators/pre2post.py @@ -0,0 +1,492 @@ +# -*- coding: utf-8 -*- + +import jax.numpy as jnp +from typing import Union, Tuple +from jax import vmap, jit +from jax.lax import cond, scan, fori_loop +from functools import partial + +from brainpy.errors import MathError +from brainpy.math.numpy_ops import as_device_array +from brainpy.math.jaxarray import JaxArray +from .utils import _check_brainpylib +from .pre2syn import pre2syn +from .syn2post import syn2post_mean +from brainpy.types import Tensor + +try: + import brainpylib +except ModuleNotFoundError: + brainpylib = None + +__all__ = [ + # pre-to-post + 'pre2post_sum', + 'pre2post_prod', + 'pre2post_max', + 'pre2post_min', + 'pre2post_mean', + + # pre-to-post event operator + 'pre2post_event_sum', + 'pre2post_event_prod', + +] + + +def _raise_pre_ids_is_none(pre_ids): + if pre_ids is None: + raise MathError(f'pre2post synaptic computation needs "pre_ids" ' + f'when providing heterogeneous "pre_values" ' + f'(brainpy.math.ndim(pre_values) != 0).') + + +def pre2post_event_sum(events: Tensor, + pre2post: Tuple[Tensor, Tensor], + post_num: int, + values: Union[float, Tensor] = 1.): + """The pre-to-post synaptic computation with event-driven summation. + + When ``values`` is a scalar, this function is equivalent to + + .. highlight:: python + .. code-block:: python + + post_val = np.zeros(post_num) + post_ids, idnptr = pre2post + for i in range(pre_num): + if events[i]: + for j in range(idnptr[i], idnptr[i+1]): + post_val[post_ids[i]] += values + + When ``values`` is a vector (with the length of ``len(post_ids)``), + this function is equivalent to + + .. highlight:: python + .. code-block:: python + + post_val = np.zeros(post_num) + + post_ids, idnptr = pre2post + for i in range(pre_num): + if events[i]: + for j in range(idnptr[i], idnptr[i+1]): + post_val[post_ids[i]] += values[j] + + + Parameters + ---------- + events: Tensor + The events, must be bool. + pre2post: tuple of Tensor, tuple of Tensor + A tuple contains the connection information of pre-to-post. + post_num: int + The number of post-synaptic group. + values: float, Tensor + The value to make summation. + + Returns + ------- + out: JaxArray, jax.numpy.ndarray + A tensor with the shape of ``post_num``. + """ + _check_brainpylib(pre2post_event_sum.__name__) + indices, idnptr = pre2post + events = as_device_array(events) + indices = as_device_array(indices) + idnptr = as_device_array(idnptr) + values = as_device_array(values) + return brainpylib.event_sum(events, (indices, idnptr), post_num, values) + + +def pre2post_event_sum2(events: Tensor, + pre2post: Tuple[Tensor, Tensor], + post_num: int, + values: Union[float, Tensor] = 1.): + """The pre-to-post synaptic computation with event-driven summation. + + When ``values`` is a scalar, this function is equivalent to + + .. highlight:: python + .. code-block:: python + + post_val = np.zeros(post_num) + post_ids, idnptr = pre2post + for i in range(pre_num): + if events[i]: + for j in range(idnptr[i], idnptr[i+1]): + post_val[post_ids[i]] += values + + When ``values`` is a vector (with the length of ``len(post_ids)``), + this function is equivalent to + + .. highlight:: python + .. code-block:: python + + post_val = np.zeros(post_num) + + post_ids, idnptr = pre2post + for i in range(pre_num): + if events[i]: + for j in range(idnptr[i], idnptr[i+1]): + post_val[post_ids[i]] += values[j] + + + Parameters + ---------- + events: Tensor + The events, must be bool. + pre2post: tuple of Tensor, tuple of Tensor + A tuple contains the connection information of pre-to-post. + post_num: int + The number of post-synaptic group. + values: float, Tensor + The value to make summation. + + Returns + ------- + out: JaxArray, jax.numpy.ndarray + A tensor with the shape of ``post_num``. + """ + _check_brainpylib(pre2post_event_sum.__name__) + indices, idnptr = pre2post + events = as_device_array(events) + indices = as_device_array(indices) + idnptr = as_device_array(idnptr) + values = as_device_array(values) + return brainpylib.event_sum2(events, (indices, idnptr), post_num, values) + + +def pre2post_event_prod(events, pre2post, post_num, values=1.): + """The pre-to-post synaptic computation with event-driven production. + + When ``values`` is a scalar, this function is equivalent to + + .. highlight:: python + .. code-block:: python + + post_val = np.ones(post_num) + post_ids, idnptr = pre2post + for i in range(pre_num): + if events[i]: + for j in range(idnptr[i], idnptr[i+1]): + post_val[post_ids[i]] *= values + + When ``values`` is a vector (with the length of ``len(post_ids)``), + this function is equivalent to + + .. highlight:: python + .. code-block:: python + + post_val = np.ones(post_num) + + post_ids, idnptr = pre2post + for i in range(pre_num): + if events[i]: + for j in range(idnptr[i], idnptr[i+1]): + post_val[post_ids[i]] *= values[j] + + + Parameters + ---------- + events: JaxArray, jax.numpy.ndarray, Variable + The events, must be bool. + pre2post: tuple of JaxArray, tuple of jax.numpy.ndarray + A tuple contains the connection information of pre-to-post. + post_num: int + The number of post-synaptic group. + values: float, JaxArray, jax.numpy.ndarray + The value to make summation. + + Returns + ------- + out: JaxArray, jax.numpy.ndarray + A tensor with the shape of ``post_num``. + """ + _check_brainpylib(pre2post_event_prod.__name__) + indices, idnptr = pre2post + events = as_device_array(events) + indices = as_device_array(indices) + idnptr = as_device_array(idnptr) + values = as_device_array(values) + return brainpylib.event_prod(events, (indices, idnptr), post_num, values) + + +def pre2post_sum(pre_values, post_num, post_ids, pre_ids=None): + """The pre-to-post synaptic summation. + + This function is equivalent to: + + .. highlight:: python + .. code-block:: python + + post_val = np.zeros(post_num) + for i, j in zip(pre_ids, post_ids): + post_val[j] += pre_values[pre_ids[i]] + + Parameters + ---------- + pre_values: float, jax.numpy.ndarray, JaxArray, Variable + The pre-synaptic values. + post_ids: jax.numpy.ndarray, JaxArray + The connected post-synaptic neuron ids. + post_num: int + Output dimension. The number of post-synaptic neurons. + pre_ids: optional, jax.numpy.ndarray, JaxArray + The connected pre-synaptic neuron ids. + + Returns + ------- + post_val: jax.numpy.ndarray, JaxArray + The value with the size of post-synaptic neurons. + """ + out = jnp.zeros(post_num) + pre_values = as_device_array(pre_values) + post_ids = as_device_array(post_ids) + if jnp.ndim(pre_values) != 0: + _raise_pre_ids_is_none(pre_ids) + pre_ids = as_device_array(pre_ids) + pre_values = pre_values[pre_ids] + return out.at[post_ids].add(pre_values) + + +def pre2post_prod(pre_values, post_num, post_ids, pre_ids=None): + """The pre-to-post synaptic production. + + This function is equivalent to: + + .. highlight:: python + .. code-block:: python + + post_val = np.zeros(post_num) + for i, j in zip(pre_ids, post_ids): + post_val[j] *= pre_values[pre_ids[i]] + + Parameters + ---------- + pre_values: float, jax.numpy.ndarray, JaxArray, Variable + The pre-synaptic values. + pre_ids: jax.numpy.ndarray, JaxArray + The connected pre-synaptic neuron ids. + post_ids: jax.numpy.ndarray, JaxArray + The connected post-synaptic neuron ids. + post_num: int + Output dimension. The number of post-synaptic neurons. + + Returns + ------- + post_val: jax.numpy.ndarray, JaxArray + The value with the size of post-synaptic neurons. + """ + out = jnp.zeros(post_num) + pre_values = as_device_array(pre_values) + post_ids = as_device_array(post_ids) + if jnp.ndim(pre_values) != 0: + _raise_pre_ids_is_none(pre_ids) + pre_ids = as_device_array(pre_ids) + pre_values = pre_values[pre_ids] + return out.at[post_ids].multiply(pre_values) + + +def pre2post_min(pre_values, post_num, post_ids, pre_ids=None): + """The pre-to-post synaptic minimization. + + This function is equivalent to: + + .. highlight:: python + .. code-block:: python + + post_val = np.zeros(post_num) + for i, j in zip(pre_ids, post_ids): + post_val[j] = np.minimum(post_val[j], pre_values[pre_ids[i]]) + + Parameters + ---------- + pre_values: float, jax.numpy.ndarray, JaxArray + The pre-synaptic values. + pre_ids: jax.numpy.ndarray, JaxArray + The connected pre-synaptic neuron ids. + post_ids: jax.numpy.ndarray, JaxArray + The connected post-synaptic neuron ids. + post_num: int + Output dimension. The number of post-synaptic neurons. + + Returns + ------- + post_val: jax.numpy.ndarray, JaxArray + The value with the size of post-synaptic neurons. + """ + out = jnp.zeros(post_num) + pre_values = as_device_array(pre_values) + post_ids = as_device_array(post_ids) + if jnp.ndim(pre_values) != 0: + _raise_pre_ids_is_none(pre_ids) + pre_ids = as_device_array(pre_ids) + pre_values = pre_values[pre_ids] + return out.at[post_ids].min(pre_values) + + +def pre2post_max(pre_values, post_num, post_ids, pre_ids=None): + """The pre-to-post synaptic maximization. + + This function is equivalent to: + + .. highlight:: python + .. code-block:: python + + post_val = np.zeros(post_num) + for i, j in zip(pre_ids, post_ids): + post_val[j] = np.maximum(post_val[j], pre_values[pre_ids[i]]) + + Parameters + ---------- + pre_values: float, jax.numpy.ndarray, JaxArray, Variable + The pre-synaptic values. + pre_ids: jax.numpy.ndarray, JaxArray + The connected pre-synaptic neuron ids. + post_ids: jax.numpy.ndarray, JaxArray + The connected post-synaptic neuron ids. + post_num: int + Output dimension. The number of post-synaptic neurons. + + Returns + ------- + post_val: jax.numpy.ndarray, JaxArray + The value with the size of post-synaptic neurons. + """ + out = jnp.zeros(post_num) + pre_values = as_device_array(pre_values) + post_ids = as_device_array(post_ids) + if jnp.ndim(pre_values) != 0: + _raise_pre_ids_is_none(pre_ids) + pre_ids = as_device_array(pre_ids) + pre_values = pre_values[pre_ids] + return out.at[post_ids].max(pre_values) + + +def pre2post_mean(pre_values, post_num, post_ids, pre_ids=None): + """The pre-to-post synaptic mean computation. + + Parameters + ---------- + pre_values: float, jax.numpy.ndarray, JaxArray, Variable + The pre-synaptic values. + pre_ids: jax.numpy.ndarray, JaxArray + The connected pre-synaptic neuron ids. + post_ids: jax.numpy.ndarray, JaxArray + The connected post-synaptic neuron ids. + post_num: int + Output dimension. The number of post-synaptic neurons. + + Returns + ------- + post_val: jax.numpy.ndarray, JaxArray + The value with the size of post-synaptic neurons. + """ + out = jnp.zeros(post_num) + pre_values = as_device_array(pre_values) + post_ids = as_device_array(post_ids) + if jnp.ndim(pre_values) == 0: + # return out.at[post_ids].set(pre_values) + return out.at[jnp.unique(post_ids)].set(pre_values) + else: + _raise_pre_ids_is_none(pre_ids) + pre_ids = as_device_array(pre_ids) + pre_values = pre2syn(pre_values, pre_ids) + return syn2post_mean(pre_values, post_ids, post_num) + + +def pre2post_matmul(event, conn): + event = event.value if isinstance(event, JaxArray) else event + Cl = conn[0].value if isinstance(conn[0], JaxArray) else conn[0] + Cr = conn[1].value if isinstance(conn[1], JaxArray) else conn[1] + if jnp.ndim(event) != 1: + raise ValueError(f'"event" must be a one-dimensional vector. But we got {jnp.shape(event)}') + if jnp.ndim(Cl) != 2: + raise ValueError(f'"conn" must be a two-dimensional matrix. But we got {jnp.shape(Cl)}') + if jnp.ndim(Cr) != 2: + raise ValueError(f'"conn" must be a two-dimensional matrix. But we got {jnp.shape(Cr)}') + + f0 = vmap(lambda i, j: event[i] * (Cl[i] * Cr[:, j]).sum(), in_axes=(0, None)) + ii = jnp.arange(Cl.shape[0]) + f1 = vmap(lambda j: f0(ii, j).sum(), in_axes=(None, 0)) + return f1(jnp.arange(Cr.shape[1])) + + +def pre2post_matmul2(event, conn): + event = event.value if isinstance(event, JaxArray) else event + Cl = conn[0].value if isinstance(conn[0], JaxArray) else conn[0] + Cr = conn[1].value if isinstance(conn[1], JaxArray) else conn[1] + if jnp.ndim(event) != 1: + raise ValueError(f'"event" must be a one-dimensional vector. But we got {jnp.shape(event)}') + if jnp.ndim(Cl) != 2: + raise ValueError(f'"conn" must be a two-dimensional matrix. But we got {jnp.shape(Cl)}') + if jnp.ndim(Cr) != 2: + raise ValueError(f'"conn" must be a two-dimensional matrix. But we got {jnp.shape(Cr)}') + f1 = vmap(lambda j: (event * (Cl * Cr[:, j]).sum(1)).sum()) + return f1(jnp.arange(Cr.shape[1])) + + +def pre2post_matmul_mask(event, conn, mask): + event = event.value if isinstance(event, JaxArray) else event + Cl = conn[0].value if isinstance(conn[0], JaxArray) else conn[0] + Cr = conn[1].value if isinstance(conn[1], JaxArray) else conn[1] + Ml = mask[0].value if isinstance(mask[0], JaxArray) else mask[0] + Mr = mask[1].value if isinstance(mask[1], JaxArray) else mask[1] + if jnp.ndim(event) != 1: + raise ValueError(f'"event" must be a one-dimensional vector. But we got {jnp.shape(event)}') + if jnp.ndim(Cl) != 2: + raise ValueError(f'"conn" must be a two-dimensional matrix. But we got {jnp.shape(Cl)}') + if jnp.ndim(Cr) != 2: + raise ValueError(f'"conn" must be a two-dimensional matrix. But we got {jnp.shape(Cr)}') + if jnp.ndim(Mr) != 2: + raise ValueError(f'"mask" must be a two-dimensional matrix. But we got {jnp.shape(Mr)}') + if jnp.ndim(Ml) != 2: + raise ValueError(f'"mask" must be a two-dimensional matrix. But we got {jnp.shape(Ml)}') + + f0 = vmap(lambda i, j: event[i] * (Cl[i] * Cr[:, j]).sum() * (Ml[i] * Mr[:, j]).sum(), in_axes=(0, None)) + f1 = jit(vmap(lambda ii, j: f0(ii, j).sum(), in_axes=(None, 0))) + return f1(jnp.arange(Cl.shape[0]), jnp.arange(Cr.shape[1])) + + +def pre2post_matmul_mask2(event, conn, mask): + event = event.value if isinstance(event, JaxArray) else event + Cl = conn[0].value if isinstance(conn[0], JaxArray) else conn[0] + Cr = conn[1].value if isinstance(conn[1], JaxArray) else conn[1] + Ml = mask[0].value if isinstance(mask[0], JaxArray) else mask[0] + Mr = mask[1].value if isinstance(mask[1], JaxArray) else mask[1] + if jnp.ndim(event) != 1: + raise ValueError(f'"event" must be a one-dimensional vector. But we got {jnp.shape(event)}') + if jnp.ndim(Cl) != 2: + raise ValueError(f'"conn" must be a two-dimensional matrix. But we got {jnp.shape(Cl)}') + if jnp.ndim(Cr) != 2: + raise ValueError(f'"conn" must be a two-dimensional matrix. But we got {jnp.shape(Cr)}') + if jnp.ndim(Mr) != 2: + raise ValueError(f'"mask" must be a two-dimensional matrix. But we got {jnp.shape(Mr)}') + if jnp.ndim(Ml) != 2: + raise ValueError(f'"mask" must be a two-dimensional matrix. But we got {jnp.shape(Ml)}') + + # f0 = vmap(lambda i, j: event[i] * (Cl[i] * Cr[:, j]).sum() * (Ml[i] * Mr[:, j]).sum(), in_axes=(0, None)) + @partial(vmap, in_axes=(0, None)) + def f0(i, j): + return cond(event[i] > 0., + lambda _: event[i] * jnp.sum(Cl[i] * Cr[:, j]) * jnp.sum(Ml[i] * Mr[:, j]), + lambda _: 0., + None) + # fori_loop(0, + # Cr.shape[1], + # lambda x: f0(x[0], x[1]), + # ) + + ii = jnp.arange(Cl.shape[0]) + jj = jnp.arange(Cr.shape[1]) + + def body(_, j): + r = f0(ii, j).sum() + return 0, r + + _, out = scan(body, 0, jj) + + # f1 = jit(vmap(lambda ii, j: f0(ii, j).sum(), in_axes=(None, 0))) + return out + diff --git a/brainpy/math/operators/pre2syn.py b/brainpy/math/operators/pre2syn.py new file mode 100644 index 000000000..b60551d5b --- /dev/null +++ b/brainpy/math/operators/pre2syn.py @@ -0,0 +1,47 @@ +# -*- coding: utf-8 -*- + +import jax.numpy as jnp +from jax import vmap + +from brainpy.math.numpy_ops import as_device_array + +__all__ = [ + 'pre2syn' +] + + +_pre2syn = vmap(lambda pre_id, pre_vs: pre_vs[pre_id], in_axes=(0, None)) + + +def pre2syn(pre_values, pre_ids): + """The pre-to-syn computation. + + Change the pre-synaptic data to the data with the dimension of synapses. + + This function is equivalent to: + + .. highlight:: python + .. code-block:: python + + syn_val = np.zeros(len(pre_ids)) + for syn_i, pre_i in enumerate(pre_ids): + syn_val[i] = pre_values[pre_i] + + Parameters + ---------- + pre_values: float, jax.numpy.ndarray, JaxArray, Variable + The pre-synaptic value. + pre_ids: jax.numpy.ndarray, JaxArray + The pre-synaptic neuron index. + + Returns + ------- + syn_val: jax.numpy.ndarray, JaxArray + The synaptic value. + """ + pre_values = as_device_array(pre_values) + pre_ids = as_device_array(pre_ids) + if jnp.ndim(pre_values) == 0: + return jnp.ones(len(pre_ids), dtype=pre_values.dtype) * pre_values + else: + return _pre2syn(pre_ids, pre_values) diff --git a/brainpy/math/operators/spikegrad.py b/brainpy/math/operators/spikegrad.py new file mode 100644 index 000000000..923473920 --- /dev/null +++ b/brainpy/math/operators/spikegrad.py @@ -0,0 +1,174 @@ +# -*- coding: utf-8 -*- + + +from jax import custom_gradient, custom_jvp + +from brainpy.math import numpy_ops as bm +from brainpy.math.jaxarray import JaxArray +from brainpy.types import Tensor + +from brainpy.math.setting import dftype + +__all__ = [ + 'spike_with_sigmoid_grad', + 'spike2_with_sigmoid_grad', + 'spike_with_relu_grad', + 'spike2_with_relu_grad', + 'step_pwl' +] + + +def _consistent_type(target, compare): + return target.value if not isinstance(compare, JaxArray) else target + + +@custom_gradient +def spike_with_sigmoid_grad(x: Tensor, scale: float = None): + """Spike function with the sigmoid surrogate gradient. + + Parameters + ---------- + x: Tensor + The input data. + scale: float + The scaling factor. + """ + z = bm.asarray(x >= 0, dtype=dftype()) + + def grad(dE_dz): + _scale = scale + if scale is None: + _scale = 100. + dE_dx = dE_dz / (_scale * bm.abs(x) + 1.0) ** 2 + if scale is None: + return (_consistent_type(dE_dx, x),) + else: + dscale = bm.zeros_like(_scale) + return (_consistent_type(dE_dx, x), + _consistent_type(dscale, scale)) + + return z, grad + + +@custom_gradient +def spike2_with_sigmoid_grad(x_new: Tensor, x_old: Tensor, scale: float = None): + """Spike function with the sigmoid surrogate gradient. + + Parameters + ---------- + x_new: Tensor + The input data. + x_old: Tensor + The input data. + scale: optional, float + The scaling factor. + """ + x_new_comp = x_new >= 0 + x_old_comp = x_old < 0 + z = bm.asarray(bm.logical_and(x_new_comp, x_old_comp), dtype=dftype()) + + def grad(dE_dz): + _scale = scale + if scale is None: + _scale = 100. + dx_new = (dE_dz / (_scale * bm.abs(x_new) + 1.0) ** 2) * bm.asarray(x_old_comp, dtype=dftype()) + dx_old = -(dE_dz / (_scale * bm.abs(x_old) + 1.0) ** 2) * bm.asarray(x_new_comp, dtype=dftype()) + if scale is None: + return (_consistent_type(dx_new, x_new), + _consistent_type(dx_old, x_old)) + else: + dscale = bm.zeros_like(_scale) + return (_consistent_type(dx_new, x_new), + _consistent_type(dx_old, x_old), + _consistent_type(dscale, scale)) + + return z, grad + + +@custom_gradient +def spike_with_relu_grad(x: Tensor, scale: float = None): + """Spike function with the relu surrogate gradient. + + Parameters + ---------- + x: Tensor + The input data. + scale: float + The scaling factor. + """ + z = bm.asarray(x >= 0., dtype=dftype()) + + def grad(dE_dz): + _scale = scale + if scale is None: _scale = 0.3 + dE_dx = dE_dz * bm.maximum(1 - bm.abs(x), 0) * _scale + if scale is None: + return (_consistent_type(dE_dx, x),) + else: + dscale = bm.zeros_like(_scale) + return (_consistent_type(dE_dx, x), + _consistent_type(dscale, _scale)) + + return z, grad + + +@custom_gradient +def spike2_with_relu_grad(x_new: Tensor, x_old: Tensor, scale: float = 10.): + """Spike function with the relu surrogate gradient. + + Parameters + ---------- + x_new: Tensor + The input data. + x_old: Tensor + The input data. + scale: float + The scaling factor. + """ + x_new_comp = x_new >= 0 + x_old_comp = x_old < 0 + z = bm.asarray(bm.logical_and(x_new_comp, x_old_comp), dtype=dftype()) + + def grad(dE_dz): + _scale = scale + if scale is None: + _scale = 0.3 + dx_new = (dE_dz * bm.maximum(1 - bm.abs(x_new), 0) * _scale) * bm.asarray(x_old_comp, dtype=dftype()) + dx_old = -(dE_dz * bm.maximum(1 - bm.abs(x_old), 0) * _scale) * bm.asarray(x_new_comp, dtype=dftype()) + if scale is None: + return (_consistent_type(dx_new, x_new), + _consistent_type(dx_old, x_old)) + else: + dscale = bm.zeros_like(_scale) + return (_consistent_type(dx_new, x_new), + _consistent_type(dx_old, x_old), + _consistent_type(dscale, scale)) + + return z, grad + + +@custom_jvp +def step_pwl(x, threshold, window=0.5, max_spikes_per_dt: int = bm.inf): + """ + Heaviside step function with piece-wise linear derivative to use as spike-generation surrogate + + Args: + x (float): Input value + threshold (float): Firing threshold + window (float): Learning window around threshold. Default: 0.5 + max_spikes_per_dt (int): Maximum number of spikes that may be produced each dt. Default: ``np.inf``, do not clamp spikes + + Returns: + float: Number of output events for each input value + """ + spikes = (x >= threshold) * bm.floor(x / threshold) + return bm.clip(spikes, 0.0, max_spikes_per_dt) + + +@step_pwl.defjvp +def step_pwl_jvp(primals, tangents): + x, threshold, window, max_spikes_per_dt = primals + x_dot, threshold_dot, window_dot, max_spikes_per_dt_dot = tangents + primal_out = step_pwl(*primals) + tangent_out = (x >= (threshold - window)) * (x_dot / threshold - threshold_dot * x / (threshold ** 2)) + return primal_out, tangent_out diff --git a/brainpy/math/operators/syn2post.py b/brainpy/math/operators/syn2post.py new file mode 100644 index 000000000..d022c14a1 --- /dev/null +++ b/brainpy/math/operators/syn2post.py @@ -0,0 +1,235 @@ +# -*- coding: utf-8 -*- + +import jax.numpy as jnp +from jax import jit, vmap +from jax import ops as jops + +from brainpy.math.numpy_ops import as_device_array + + +_jit_seg_sum = jit(jops.segment_sum, static_argnums=(2, 3)) +_jit_seg_prod = jit(jops.segment_prod, static_argnums=(2, 3)) +_jit_seg_max = jit(jops.segment_max, static_argnums=(2, 3)) +_jit_seg_min = jit(jops.segment_min, static_argnums=(2, 3)) + + +__all__ = [ + 'syn2post_sum', 'syn2post', + 'syn2post_prod', + 'syn2post_max', + 'syn2post_min', + 'syn2post_mean', + 'syn2post_softmax', + +] + + +def syn2post_sum(syn_values, post_ids, post_num: int, indices_are_sorted=True): + """The syn-to-post summation computation. + + This function is equivalent to: + + .. highlight:: python + .. code-block:: python + + post_val = np.zeros(post_num) + for syn_i, post_i in enumerate(post_ids): + post_val[post_i] += syn_values[syn_i] + + Parameters + ---------- + syn_values: jax.numpy.ndarray, JaxArray, Variable + The synaptic values. + post_ids: jax.numpy.ndarray, JaxArray + The post-synaptic neuron ids. + post_num: int + The number of the post-synaptic neurons. + + Returns + ------- + post_val: jax.numpy.ndarray, JaxArray + The post-synaptic value. + """ + post_ids = as_device_array(post_ids) + syn_values = as_device_array(syn_values) + if syn_values.dtype == jnp.bool_: + syn_values = jnp.asarray(syn_values, dtype=jnp.int32) + return _jit_seg_sum(syn_values, post_ids, post_num, indices_are_sorted) + + +syn2post = syn2post_sum + + +def syn2post_prod(syn_values, post_ids, post_num: int, indices_are_sorted=True): + """The syn-to-post product computation. + + This function is equivalent to: + + .. highlight:: python + .. code-block:: python + + post_val = np.zeros(post_num) + for syn_i, post_i in enumerate(post_ids): + post_val[post_i] *= syn_values[syn_i] + + Parameters + ---------- + syn_values: jax.numpy.ndarray, JaxArray, Variable + The synaptic values. + post_ids: jax.numpy.ndarray, JaxArray + The post-synaptic neuron ids. If ``post_ids`` is generated by + ``brainpy.conn.TwoEndConnector``, then it has sorted indices. + Otherwise, this function cannot guarantee indices are sorted. + You's better set ``indices_are_sorted=False``. + post_num: int + The number of the post-synaptic neurons. + indices_are_sorted: whether ``post_ids`` is known to be sorted. + + Returns + ------- + post_val: jax.numpy.ndarray, JaxArray + The post-synaptic value. + """ + post_ids = as_device_array(post_ids) + syn_values = as_device_array(syn_values) + if syn_values.dtype == jnp.bool_: + syn_values = jnp.asarray(syn_values, dtype=jnp.int32) + return _jit_seg_prod(syn_values, post_ids, post_num, indices_are_sorted) + + +def syn2post_max(syn_values, post_ids, post_num: int, indices_are_sorted=True): + """The syn-to-post maximum computation. + + This function is equivalent to: + + .. highlight:: python + .. code-block:: python + + post_val = np.zeros(post_num) + for syn_i, post_i in enumerate(post_ids): + post_val[post_i] = np.maximum(post_val[post_i], syn_values[syn_i]) + + Parameters + ---------- + syn_values: jax.numpy.ndarray, JaxArray, Variable + The synaptic values. + post_ids: jax.numpy.ndarray, JaxArray + The post-synaptic neuron ids. If ``post_ids`` is generated by + ``brainpy.conn.TwoEndConnector``, then it has sorted indices. + Otherwise, this function cannot guarantee indices are sorted. + You's better set ``indices_are_sorted=False``. + post_num: int + The number of the post-synaptic neurons. + indices_are_sorted: whether ``post_ids`` is known to be sorted. + + Returns + ------- + post_val: jax.numpy.ndarray, JaxArray + The post-synaptic value. + """ + post_ids = as_device_array(post_ids) + syn_values = as_device_array(syn_values) + if syn_values.dtype == jnp.bool_: + syn_values = jnp.asarray(syn_values, dtype=jnp.int32) + return _jit_seg_max(syn_values, post_ids, post_num, indices_are_sorted) + + +def syn2post_min(syn_values, post_ids, post_num: int, indices_are_sorted=True): + """The syn-to-post minimization computation. + + This function is equivalent to: + + .. highlight:: python + .. code-block:: python + + post_val = np.zeros(post_num) + for syn_i, post_i in enumerate(post_ids): + post_val[post_i] = np.minimum(post_val[post_i], syn_values[syn_i]) + + Parameters + ---------- + syn_values: jax.numpy.ndarray, JaxArray, Variable + The synaptic values. + post_ids: jax.numpy.ndarray, JaxArray + The post-synaptic neuron ids. If ``post_ids`` is generated by + ``brainpy.conn.TwoEndConnector``, then it has sorted indices. + Otherwise, this function cannot guarantee indices are sorted. + You's better set ``indices_are_sorted=False``. + post_num: int + The number of the post-synaptic neurons. + indices_are_sorted: whether ``post_ids`` is known to be sorted. + + Returns + ------- + post_val: jax.numpy.ndarray, JaxArray + The post-synaptic value. + """ + post_ids = as_device_array(post_ids) + syn_values = as_device_array(syn_values) + if syn_values.dtype == jnp.bool_: + syn_values = jnp.asarray(syn_values, dtype=jnp.int32) + return _jit_seg_min(syn_values, post_ids, post_num, indices_are_sorted) + + +def syn2post_mean(syn_values, post_ids, post_num: int, indices_are_sorted=True): + """The syn-to-post mean computation. + + Parameters + ---------- + syn_values: jax.numpy.ndarray, JaxArray, Variable + The synaptic values. + post_ids: jax.numpy.ndarray, JaxArray + The post-synaptic neuron ids. If ``post_ids`` is generated by + ``brainpy.conn.TwoEndConnector``, then it has sorted indices. + Otherwise, this function cannot guarantee indices are sorted. + You's better set ``indices_are_sorted=False``. + post_num: int + The number of the post-synaptic neurons. + indices_are_sorted: whether ``post_ids`` is known to be sorted. + + Returns + ------- + post_val: jax.numpy.ndarray, JaxArray + The post-synaptic value. + """ + post_ids = as_device_array(post_ids) + syn_values = as_device_array(syn_values) + if syn_values.dtype == jnp.bool_: + syn_values = jnp.asarray(syn_values, dtype=jnp.int32) + nominator = _jit_seg_sum(syn_values, post_ids, post_num, indices_are_sorted) + denominator = _jit_seg_sum(jnp.ones_like(syn_values), post_ids, post_num, indices_are_sorted) + return jnp.nan_to_num(nominator / denominator) + + +def syn2post_softmax(syn_values, post_ids, post_num: int, indices_are_sorted=True): + """The syn-to-post softmax computation. + + Parameters + ---------- + syn_values: jax.numpy.ndarray, JaxArray, Variable + The synaptic values. + post_ids: jax.numpy.ndarray, JaxArray + The post-synaptic neuron ids. If ``post_ids`` is generated by + ``brainpy.conn.TwoEndConnector``, then it has sorted indices. + Otherwise, this function cannot guarantee indices are sorted. + You's better set ``indices_are_sorted=False``. + post_num: int + The number of the post-synaptic neurons. + indices_are_sorted: whether ``post_ids`` is known to be sorted. + + Returns + ------- + post_val: jax.numpy.ndarray, JaxArray + The post-synaptic value. + """ + post_ids = as_device_array(post_ids) + syn_values = as_device_array(syn_values) + if syn_values.dtype == jnp.bool_: + syn_values = jnp.asarray(syn_values, dtype=jnp.int32) + syn_maxs = _jit_seg_max(syn_values, post_ids, post_num, indices_are_sorted) + syn_values = syn_values - syn_maxs[post_ids] + syn_values = jnp.exp(syn_values) + normalizers = _jit_seg_sum(syn_values, post_ids, post_num, indices_are_sorted) + softmax = syn_values / normalizers[post_ids] + return jnp.nan_to_num(softmax) + diff --git a/brainpy/math/operators/tests/test_differential_spike.py b/brainpy/math/operators/tests/test_differential_spike.py new file mode 100644 index 000000000..a4c6bd737 --- /dev/null +++ b/brainpy/math/operators/tests/test_differential_spike.py @@ -0,0 +1,62 @@ +# -*- coding: utf-8 -*- + + +import brainpy.math as bm + +from functools import partial + +import unittest + + +def test_sp_sigmoid_grad(): + f_grad = bm.vector_grad(lambda a: bm.spike_with_sigmoid_grad(a, 1.)) + x = bm.random.random(10) - 0.5 + print(f_grad(x)) + + +class TestSpike2SigmoidGrad(unittest.TestCase): + def __init__(self, *args, **kwargs): + super(TestSpike2SigmoidGrad, self).__init__(*args, **kwargs) + + @partial(bm.vector_grad, return_value=True) + def f4(a, b): + return b + bm.spike_with_sigmoid_grad(a + 0.1, 100.) * bm.spike_with_sigmoid_grad(-a, 100.) + + @partial(bm.vector_grad, return_value=True) + def f5(a, b): + return b + bm.spike2_with_sigmoid_grad(a + 0.1, a, 100.) + + self.f4 = f4 + self.f5 = f5 + + def test_sp_sigmoid_grad2(self): + a = bm.ones(10) * 2 + b = bm.ones(10) + grad1, val1 = self.f4(a, b) + grad2, val2 = self.f5(a, b) + self.assertTrue(bm.array_equal(grad1, grad2)) + self.assertTrue(bm.array_equal(val1, val2)) + + def test_sp_sigmoid_grad1(self): + a = bm.zeros(10) + b = bm.ones(10) + grad1, val1 = self.f4(a, b) + grad2, val2 = self.f5(a, b) + print(grad2) + print(grad1) + + self.assertTrue(~bm.array_equal(grad1, grad2)) + self.assertTrue(~bm.array_equal(val1, val2)) + + def test_sp_sigmoid_grad3(self): + a = bm.ones(10) * -2 + b = bm.ones(10) + grad1, val1 = self.f4(a, b) + grad2, val2 = self.f5(a, b) + self.assertTrue(bm.array_equal(grad1, grad2)) + self.assertTrue(bm.array_equal(val1, val2)) + + + + + diff --git a/brainpy/math/operators/tests/test_op_register.py b/brainpy/math/operators/tests/test_op_register.py new file mode 100644 index 000000000..95089c1ea --- /dev/null +++ b/brainpy/math/operators/tests/test_op_register.py @@ -0,0 +1,137 @@ +# -*- coding: utf-8 -*- + +import unittest +import brainpy as bp +import brainpy.math as bm +import matplotlib.pyplot as plt + +bm.set_platform('cpu') + + +def abs_eval(events, indices, indptr, post_val, values): + return post_val + + +def event_sum_op(outs, ins): + events, indices, indptr, post, values = ins + v = values[()] + outs.fill(0) + for i in range(len(events)): + if events[i]: + for j in range(indptr[i], indptr[i + 1]): + index = indices[j] + outs[index] += v + + +event_sum = bm.register_op(op_name='event_sum', cpu_func=event_sum_op, out_shapes=abs_eval) +event_sum = bm.jit(event_sum) + + +class ExponentialSyn(bp.dyn.TwoEndConn): + def __init__(self, pre, post, conn, g_max=1., delay=0., tau=8.0, E=0., + method='exp_auto'): + super(ExponentialSyn, self).__init__(pre=pre, post=post, conn=conn) + self.check_pre_attrs('spike') + self.check_post_attrs('input', 'V') + + # parameters + self.E = E + self.tau = tau + self.delay = delay + self.g_max = g_max + self.pre2post = self.conn.require('pre2post') + + # variables + self.g = bm.Variable(bm.zeros(self.post.num)) + + # function + self.integral = bp.odeint(lambda g, t: -g / self.tau, method=method) + + def update(self, tdi): + self.g.value = self.integral(self.g, tdi['t'], dt=tdi['dt']) + self.g += bm.pre2post_event_sum(self.pre.spike, self.pre2post, self.post.num, self.g_max) + self.post.input += self.g * (self.E - self.post.V) + + +class ExponentialSyn2(bp.dyn.TwoEndConn): + def __init__(self, pre, post, conn, g_max=1., delay=0., tau=8.0, E=0., + method='exp_auto'): + super(ExponentialSyn2, self).__init__(pre=pre, post=post, conn=conn) + self.check_pre_attrs('spike') + self.check_post_attrs('input', 'V') + + # parameters + self.E = E + self.tau = tau + self.delay = delay + self.g_max = g_max + self.pre2post = self.conn.require('pre2post') + + # variables + self.g = bm.Variable(bm.zeros(self.post.num)) + + # function + self.integral = bp.odeint(lambda g, t: -g / self.tau, method=method) + + def update(self, tdi): + self.g.value = self.integral(self.g, tdi['t'], tdi['dt']) + # Customized operator + # ------------------------------------------------------------------------------------------------------------ + post_val = bm.zeros(self.post.num) + self.g += event_sum(self.pre.spike, self.pre2post[0], self.pre2post[1], post_val, self.g_max) + # ------------------------------------------------------------------------------------------------------------ + self.post.input += self.g * (self.E - self.post.V) + + +class EINet(bp.dyn.Network): + def __init__(self, syn_class, scale=1.0, method='exp_auto', ): + super(EINet, self).__init__() + + # network size + num_exc = int(3200 * scale) + num_inh = int(800 * scale) + + # neurons + pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.) + self.E = bp.neurons.LIF(num_exc, **pars, method=method) + self.I = bp.neurons.LIF(num_inh, **pars, method=method) + self.E.V[:] = bp.math.random.randn(num_exc) * 2 - 55. + self.I.V[:] = bp.math.random.randn(num_inh) * 2 - 55. + + # synapses + we = 0.6 / scale # excitatory synaptic weight (voltage) + wi = 6.7 / scale # inhibitory synaptic weight + self.E2E = syn_class(self.E, self.E, bp.conn.FixedProb(0.02), E=0., g_max=we, tau=5., method=method) + self.E2I = syn_class(self.E, self.I, bp.conn.FixedProb(0.02), E=0., g_max=we, tau=5., method=method) + self.I2E = syn_class(self.I, self.E, bp.conn.FixedProb(0.02), E=-80., g_max=wi, tau=10., method=method) + self.I2I = syn_class(self.I, self.I, bp.conn.FixedProb(0.02), E=-80., g_max=wi, tau=10., method=method) + + + +class TestOpRegister(unittest.TestCase): + def test_op(self): + + fig, gs = bp.visualize.get_figure(1, 2, 4, 5) + + net = EINet(ExponentialSyn, scale=1., method='euler') + runner = bp.dyn.DSRunner( + net, + inputs=[(net.E.input, 20.), (net.I.input, 20.)], + monitors={'E.spike': net.E.spike}, + ) + t, _ = runner.run(100., eval_time=True) + print(t) + ax = fig.add_subplot(gs[0, 0]) + bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], ax=ax) + + net2 = EINet(ExponentialSyn2, scale=1., method='euler') + runner2 = bp.dyn.DSRunner( + net2, + inputs=[(net2.E.input, 20.), (net2.I.input, 20.)], + monitors={'E.spike': net2.E.spike}, + ) + t, _ = runner2.run(100., eval_time=True) + print(t) + ax = fig.add_subplot(gs[0, 1]) + bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], ax=ax, show=True) + plt.close() diff --git a/brainpy/math/tests/test_oprators.py b/brainpy/math/operators/tests/test_oprators.py similarity index 100% rename from brainpy/math/tests/test_oprators.py rename to brainpy/math/operators/tests/test_oprators.py diff --git a/brainpy/math/operators/utils.py b/brainpy/math/operators/utils.py new file mode 100644 index 000000000..548620f81 --- /dev/null +++ b/brainpy/math/operators/utils.py @@ -0,0 +1,28 @@ +# -*- coding: utf-8 -*- + +from brainpy.errors import PackageMissingError + +try: + import brainpylib +except ModuleNotFoundError: + brainpylib = None + + +_BRAINPYLIB_MINIMAL_VERSION = '0.0.5' + + +def _check_brainpylib(ops_name): + if brainpylib is not None: + if brainpylib.__version__ < _BRAINPYLIB_MINIMAL_VERSION: + raise PackageMissingError( + f'"{ops_name}" operator need "brainpylib>={_BRAINPYLIB_MINIMAL_VERSION}". \n' + f'Please install it through:\n\n' + f'>>> pip install brainpylib>={_BRAINPYLIB_MINIMAL_VERSION} -U' + ) + else: + raise PackageMissingError( + f'"brainpylib" must be installed when the user ' + f'wants to use "{ops_name}" operator. \n' + f'Please install "brainpylib>={_BRAINPYLIB_MINIMAL_VERSION}" through:\n\n' + f'>>> pip install brainpylib>={_BRAINPYLIB_MINIMAL_VERSION}' + ) \ No newline at end of file diff --git a/brainpy/math/operators/wrap_jax.py b/brainpy/math/operators/wrap_jax.py new file mode 100644 index 000000000..432bcc8cd --- /dev/null +++ b/brainpy/math/operators/wrap_jax.py @@ -0,0 +1,77 @@ +# -*- coding: utf-8 -*- + + +from typing import Union, Optional + +import jax.numpy as jnp +from jax import lax +from jax import ops as jops + +from brainpy.math.jaxarray import JaxArray + +__all__ = [ + 'segment_sum', + 'segment_prod', + 'segment_max', + 'segment_min', +] + + +def segment_sum(data: Union[JaxArray, jnp.ndarray], + segment_ids: Union[JaxArray, jnp.ndarray], + num_segments: Optional[int] = None, + indices_are_sorted: bool = False, + unique_indices: bool = False, + bucket_size: Optional[int] = None, + mode: Optional[lax.GatherScatterMode] = None) -> JaxArray: + return JaxArray(jops.segment_sum(data.value if isinstance(data, JaxArray) else data, + segment_ids.value if isinstance(segment_ids, JaxArray) else segment_ids, + num_segments, + indices_are_sorted, + unique_indices, + bucket_size, mode)) + + +def segment_prod(data: Union[JaxArray, jnp.ndarray], + segment_ids: Union[JaxArray, jnp.ndarray], + num_segments: Optional[int] = None, + indices_are_sorted: bool = False, + unique_indices: bool = False, + bucket_size: Optional[int] = None, + mode: Optional[lax.GatherScatterMode] = None) -> JaxArray: + return JaxArray(jops.segment_prod(data.value if isinstance(data, JaxArray) else data, + segment_ids.value if isinstance(segment_ids, JaxArray) else segment_ids, + num_segments, + indices_are_sorted, + unique_indices, + bucket_size, mode)) + + +def segment_max(data: Union[JaxArray, jnp.ndarray], + segment_ids: Union[JaxArray, jnp.ndarray], + num_segments: Optional[int] = None, + indices_are_sorted: bool = False, + unique_indices: bool = False, + bucket_size: Optional[int] = None, + mode: Optional[lax.GatherScatterMode] = None) -> JaxArray: + return JaxArray(jops.segment_max(data.value if isinstance(data, JaxArray) else data, + segment_ids.value if isinstance(segment_ids, JaxArray) else segment_ids, + num_segments, + indices_are_sorted, + unique_indices, + bucket_size, mode)) + + +def segment_min(data: Union[JaxArray, jnp.ndarray], + segment_ids: Union[JaxArray, jnp.ndarray], + num_segments: Optional[int] = None, + indices_are_sorted: bool = False, + unique_indices: bool = False, + bucket_size: Optional[int] = None, + mode: Optional[lax.GatherScatterMode] = None) -> JaxArray: + return JaxArray(jops.segment_min(data.value if isinstance(data, JaxArray) else data, + segment_ids.value if isinstance(segment_ids, JaxArray) else segment_ids, + num_segments, + indices_are_sorted, + unique_indices, + bucket_size, mode)) diff --git a/brainpy/math/random.py b/brainpy/math/random.py index 7e5f3f78e..178605560 100644 --- a/brainpy/math/random.py +++ b/brainpy/math/random.py @@ -446,10 +446,11 @@ def split_keys(self, n): # random functions # # ---------------- # - def rand(self, *dn): - return JaxArray(jr.uniform(self.split_key(), shape=dn, minval=0., maxval=1.)) + def rand(self, *dn, key=None): + key = self.split_key() if key is None else key + return JaxArray(jr.uniform(key, shape=dn, minval=0., maxval=1.)) - def randint(self, low, high=None, size=None, dtype=jnp.int_): + def randint(self, low, high=None, size=None, dtype=jnp.int_, key=None): low = _remove_jax_array(low) high = _remove_jax_array(high) if high is None: @@ -460,11 +461,12 @@ def randint(self, low, high=None, size=None, dtype=jnp.int_): if size is None: size = lax.broadcast_shapes(jnp.shape(low), jnp.shape(high)) - return JaxArray(jr.randint(self.split_key(), + key = self.split_key() if key is None else key + return JaxArray(jr.randint(key, shape=_size2shape(size), minval=low, maxval=high, dtype=dtype)) - def random_integers(self, low, high=None, size=None): + def random_integers(self, low, high=None, size=None, key=None): low = _remove_jax_array(low) high = _remove_jax_array(high) low = _check_py_seq(low) @@ -475,161 +477,182 @@ def random_integers(self, low, high=None, size=None): high += 1 if size is None: size = lax.broadcast_shapes(jnp.shape(low), jnp.shape(high)) - return JaxArray(jr.randint(self.split_key(), + key = self.split_key() if key is None else key + return JaxArray(jr.randint(key, shape=_size2shape(size), minval=low, maxval=high)) - def randn(self, *dn): - return JaxArray(jr.normal(self.split_key(), shape=dn)) + def randn(self, *dn, key=None): + key = self.split_key() if key is None else key + return JaxArray(jr.normal(key, shape=dn)) - def random(self, size=None): - return JaxArray(jr.uniform(self.split_key(), shape=_size2shape(size), minval=0., maxval=1.)) + def random(self, size=None, key=None): + key = self.split_key() if key is None else key + return JaxArray(jr.uniform(key, shape=_size2shape(size), minval=0., maxval=1.)) - def random_sample(self, size=None): - return self.random(size=size) + def random_sample(self, size=None, key=None): + return self.random(size=size, key=key) - def ranf(self, size=None): - return self.random(size=size) + def ranf(self, size=None, key=None): + return self.random(size=size, key=key) - def sample(self, size=None): - return self.random(size=size) + def sample(self, size=None, key=None): + return self.random(size=size, key=key) - def choice(self, a, size=None, replace=True, p=None): + def choice(self, a, size=None, replace=True, p=None, key=None): a = _remove_jax_array(a) p = _remove_jax_array(p) a = _check_py_seq(a) p = _check_py_seq(p) - return JaxArray(jr.choice(self.split_key(), a=a, shape=_size2shape(size), + key = self.split_key() if key is None else key + return JaxArray(jr.choice(key, a=a, shape=_size2shape(size), replace=replace, p=p)) - def permutation(self, x): + def permutation(self, x, key=None): x = x.value if isinstance(x, JaxArray) else x x = _check_py_seq(x) - return JaxArray(jr.permutation(self.split_key(), x)) + key = self.split_key() if key is None else key + return JaxArray(jr.permutation(key, x)) - def shuffle(self, x, axis=0): + def shuffle(self, x, axis=0, key=None): assert isinstance(x, JaxArray), f'Must be a JaxArray, but got {type(x)}' - x.value = jr.permutation(self.split_key(), x.value, axis=axis) + key = self.split_key() if key is None else key + x.value = jr.permutation(key, x.value, axis=axis) - def beta(self, a, b, size=None): + def beta(self, a, b, size=None, key=None): a = a.value if isinstance(a, JaxArray) else a b = b.value if isinstance(b, JaxArray) else b a = _check_py_seq(a) b = _check_py_seq(b) if size is None: size = lax.broadcast_shapes(jnp.shape(a), jnp.shape(b)) - return JaxArray(jr.beta(self.split_key(), a=a, b=b, shape=_size2shape(size))) + key = self.split_key() if key is None else key + return JaxArray(jr.beta(key, a=a, b=b, shape=_size2shape(size))) - def exponential(self, scale=None, size=None): + def exponential(self, scale=None, size=None, key=None): scale = _remove_jax_array(scale) scale = _check_py_seq(scale) if size is None: size = jnp.shape(scale) - r = jr.exponential(self.split_key(), shape=_size2shape(size)) + key = self.split_key() if key is None else key + r = jr.exponential(key, shape=_size2shape(size)) if scale is None: return JaxArray(r) else: return JaxArray(r / scale) - def gamma(self, shape, scale=None, size=None): + def gamma(self, shape, scale=None, size=None, key=None): shape = _remove_jax_array(shape) scale = _remove_jax_array(scale) shape = _check_py_seq(shape) scale = _check_py_seq(scale) if size is None: size = lax.broadcast_shapes(jnp.shape(shape), jnp.shape(scale)) - r = jr.gamma(self.split_key(), a=shape, shape=_size2shape(size)) + key = self.split_key() if key is None else key + r = jr.gamma(key, a=shape, shape=_size2shape(size)) if scale is None: return JaxArray(r) else: return JaxArray(r * scale) - def gumbel(self, loc=None, scale=None, size=None): + def gumbel(self, loc=None, scale=None, size=None, key=None): loc = _remove_jax_array(loc) scale = _remove_jax_array(scale) loc = _check_py_seq(loc) scale = _check_py_seq(scale) if size is None: size = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale)) - return _loc_scale(loc, scale, jr.gumbel(self.split_key(), shape=_size2shape(size))) + key = self.split_key() if key is None else key + return _loc_scale(loc, scale, jr.gumbel(key, shape=_size2shape(size))) - def laplace(self, loc=None, scale=None, size=None): + def laplace(self, loc=None, scale=None, size=None, key=None): loc = _remove_jax_array(loc) scale = _remove_jax_array(scale) loc = _check_py_seq(loc) scale = _check_py_seq(scale) if size is None: size = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale)) - return _loc_scale(loc, scale, jr.laplace(self.split_key(), shape=_size2shape(size))) + key = self.split_key() if key is None else key + return _loc_scale(loc, scale, jr.laplace(key, shape=_size2shape(size))) - def logistic(self, loc=None, scale=None, size=None): + def logistic(self, loc=None, scale=None, size=None, key=None): loc = _remove_jax_array(loc) scale = _remove_jax_array(scale) loc = _check_py_seq(loc) scale = _check_py_seq(scale) if size is None: size = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale)) - return _loc_scale(loc, scale, jr.logistic(self.split_key(), shape=_size2shape(size))) + key = self.split_key() if key is None else key + return _loc_scale(loc, scale, jr.logistic(key, shape=_size2shape(size))) - def normal(self, loc=None, scale=None, size=None): + def normal(self, loc=None, scale=None, size=None, key=None): loc = _remove_jax_array(loc) scale = _remove_jax_array(scale) loc = _check_py_seq(loc) scale = _check_py_seq(scale) if size is None: size = lax.broadcast_shapes(jnp.shape(scale), jnp.shape(loc)) - return _loc_scale(loc, scale, jr.normal(self.split_key(), shape=_size2shape(size))) + key = self.split_key() if key is None else key + return _loc_scale(loc, scale, jr.normal(key, shape=_size2shape(size))) - def pareto(self, a, size=None): + def pareto(self, a, size=None, key=None): a = _remove_jax_array(a) a = _check_py_seq(a) if size is None: size = jnp.shape(a) - return JaxArray(jr.pareto(self.split_key(), b=a, shape=_size2shape(size))) + key = self.split_key() if key is None else key + return JaxArray(jr.pareto(key, b=a, shape=_size2shape(size))) - def poisson(self, lam=1.0, size=None): + def poisson(self, lam=1.0, size=None, key=None): lam = _check_py_seq(_remove_jax_array(lam)) if size is None: size = jnp.shape(lam) - return JaxArray(jr.poisson(self.split_key(), lam=lam, shape=_size2shape(size))) + key = self.split_key() if key is None else key + return JaxArray(jr.poisson(key, lam=lam, shape=_size2shape(size))) - def standard_cauchy(self, size=None): - return JaxArray(jr.cauchy(self.split_key(), shape=_size2shape(size))) + def standard_cauchy(self, size=None, key=None): + key = self.split_key() if key is None else key + return JaxArray(jr.cauchy(key, shape=_size2shape(size))) - def standard_exponential(self, size=None): - return JaxArray(jr.exponential(self.split_key(), shape=_size2shape(size))) + def standard_exponential(self, size=None, key=None): + key = self.split_key() if key is None else key + return JaxArray(jr.exponential(key, shape=_size2shape(size))) - def standard_gamma(self, shape, size=None): + def standard_gamma(self, shape, size=None, key=None): shape = _remove_jax_array(shape) shape = _check_py_seq(shape) if size is None: size = jnp.shape(shape) - return JaxArray(jr.gamma(self.split_key(), a=shape, shape=_size2shape(size))) + key = self.split_key() if key is None else key + return JaxArray(jr.gamma(key, a=shape, shape=_size2shape(size))) - def standard_normal(self, size=None): - return JaxArray(jr.normal(self.split_key(), shape=_size2shape(size))) + def standard_normal(self, size=None, key=None): + key = self.split_key() if key is None else key + return JaxArray(jr.normal(key, shape=_size2shape(size))) - def standard_t(self, df, size=None): + def standard_t(self, df, size=None, key=None): df = _remove_jax_array(df) df = _check_py_seq(df) if size is None: size = jnp.shape(size) - return JaxArray(jr.t(self.split_key(), df=df, shape=_size2shape(size))) + key = self.split_key() if key is None else key + return JaxArray(jr.t(key, df=df, shape=_size2shape(size))) - def uniform(self, low=0.0, high=1.0, size=None): + def uniform(self, low=0.0, high=1.0, size=None, key=None): low = _remove_jax_array(low) high = _remove_jax_array(high) low = _check_py_seq(low) high = _check_py_seq(high) if size is None: size = lax.broadcast_shapes(jnp.shape(low), jnp.shape(high)) - return JaxArray(jr.uniform(self.split_key(), + key = self.split_key() if key is None else key + return JaxArray(jr.uniform(key, shape=_size2shape(size), minval=low, maxval=high)) - def truncated_normal(self, lower, upper, size, scale=None): + def truncated_normal(self, lower, upper, size, scale=None, key=None): lower = _remove_jax_array(lower) lower = _check_py_seq(lower) upper = _remove_jax_array(upper) @@ -640,7 +663,8 @@ def truncated_normal(self, lower, upper, size, scale=None): size = lax.broadcast_shapes(jnp.shape(lower), jnp.shape(upper), jnp.shape(scale)) - rands = jr.truncated_normal(self.split_key(), + key = self.split_key() if key is None else key + rands = jr.truncated_normal(key, lower=lower, upper=upper, shape=_size2shape(size)) @@ -652,62 +676,69 @@ def truncated_normal(self, lower, upper, size, scale=None): def _check_p(self, p): raise ValueError(f'Parameter p should be within [0, 1], but we got {p}') - def bernoulli(self, p, size=None): + def bernoulli(self, p, size=None, key=None): p = _check_py_seq(_remove_jax_array(p)) check_error_in_jit(jnp.any(jnp.logical_and(p < 0, p > 1)), self._check_p, p) if size is None: size = jnp.shape(p) - return JaxArray(jr.bernoulli(self.split_key(), p=p, shape=_size2shape(size))) + key = self.split_key() if key is None else key + return JaxArray(jr.bernoulli(key, p=p, shape=_size2shape(size))) - def lognormal(self, mean=None, sigma=None, size=None): + def lognormal(self, mean=None, sigma=None, size=None, key=None): mean = _check_py_seq(_remove_jax_array(mean)) sigma = _check_py_seq(_remove_jax_array(sigma)) if size is None: size = jnp.broadcast_shapes(jnp.shape(mean), jnp.shape(sigma)) - samples = jr.normal(self.split_key(), shape=_size2shape(size)) + key = self.split_key() if key is None else key + samples = jr.normal(key, shape=_size2shape(size)) samples = _loc_scale(mean, sigma, samples) samples = jnp.exp(samples.value) return JaxArray(samples) - def binomial(self, n, p, size=None): + def binomial(self, n, p, size=None, key=None): n = _check_py_seq(n.value if isinstance(n, JaxArray) else n) p = _check_py_seq(p.value if isinstance(p, JaxArray) else p) check_error_in_jit(jnp.any(jnp.logical_and(p < 0, p > 1)), self._check_p, p) if size is None: size = jnp.broadcast_shapes(jnp.shape(n), jnp.shape(p)) - return JaxArray(_binomial(self.split_key(), p, n, shape=_size2shape(size))) + key = self.split_key() if key is None else key + return JaxArray(_binomial(key, p, n, shape=_size2shape(size))) - def chisquare(self, df, size=None): + def chisquare(self, df, size=None, key=None): df = _check_py_seq(_remove_jax_array(df)) + key = self.split_key() if key is None else key if size is None: if jnp.ndim(df) == 0: - dist = jr.normal(self.split_key(), (df,)) ** 2 + dist = jr.normal(key, (df,)) ** 2 dist = dist.sum() else: raise NotImplementedError('Do not support non-scale "df" when "size" is None') else: - dist = jr.normal(self.split_key(), (df,) + _size2shape(size)) ** 2 + dist = jr.normal(key, (df,) + _size2shape(size)) ** 2 dist = dist.sum(axis=0) return JaxArray(dist) - def dirichlet(self, alpha, size=None): + def dirichlet(self, alpha, size=None, key=None): + key = self.split_key() if key is None else key alpha = _check_py_seq(_remove_jax_array(alpha)) - return JaxArray(jr.dirichlet(self.split_key(), alpha=alpha, shape=_size2shape(size))) + return JaxArray(jr.dirichlet(key, alpha=alpha, shape=_size2shape(size))) - def geometric(self, p, size=None): + def geometric(self, p, size=None, key=None): p = _remove_jax_array(p) p = _check_py_seq(p) if size is None: size = jnp.shape(p) - u = jr.uniform(self.split_key(), size) + key = self.split_key() if key is None else key + u = jr.uniform(key, size) r = jnp.floor(jnp.log1p(-u) / jnp.log1p(-p)) return JaxArray(r) def _check_p2(self, p): raise ValueError(f'We require `sum(pvals[:-1]) <= 1`. But we got {p}') - def multinomial(self, n, pvals, size=None): + def multinomial(self, n, pvals, size=None, key=None): + key = self.split_key() if key is None else key n = _check_py_seq(_remove_jax_array(n)) pvals = _check_py_seq(_remove_jax_array(pvals)) check_error_in_jit(jnp.sum(pvals[:-1]) > 1., self._check_p2, pvals) @@ -716,13 +747,14 @@ def multinomial(self, n, pvals, size=None): size = _size2shape(size) n_max = int(np.max(jax.device_get(n))) batch_shape = lax.broadcast_shapes(jnp.shape(pvals)[:-1], jnp.shape(n)) - return JaxArray(_multinomial(self.split_key(), pvals, n, n_max, batch_shape + size)) + return JaxArray(_multinomial(key, pvals, n, n_max, batch_shape + size)) - def multivariate_normal(self, mean, cov, size=None, method: str = 'cholesky'): + def multivariate_normal(self, mean, cov, size=None, method: str = 'cholesky', key=None): if method not in {'svd', 'eigh', 'cholesky'}: raise ValueError("method must be one of {'svd', 'eigh', 'cholesky'}") mean = _check_py_seq(_remove_jax_array(mean)) cov = _check_py_seq(_remove_jax_array(cov)) + key = self.split_key() if key is None else key if not jnp.ndim(mean) >= 1: raise ValueError(f"multivariate_normal requires mean.ndim >= 1, got mean.ndim == {jnp.ndim(mean)}") @@ -746,33 +778,37 @@ def multivariate_normal(self, mean, cov, size=None, method: str = 'cholesky'): factor = v * jnp.sqrt(w[..., None, :]) else: # 'cholesky' factor = jnp.linalg.cholesky(cov) - normal_samples = jr.normal(self.split_key(), size + mean.shape[-1:]) + normal_samples = jr.normal(key, size + mean.shape[-1:]) r = mean + jnp.einsum('...ij,...j->...i', factor, normal_samples) return JaxArray(r) - def rayleigh(self, scale=1.0, size=None): + def rayleigh(self, scale=1.0, size=None, key=None): scale = _check_py_seq(_remove_jax_array(scale)) if size is None: size = jnp.shape(scale) - x = jnp.sqrt(-2. * jnp.log(jr.uniform(self.split_key(), shape=_size2shape(size), minval=0, maxval=1))) + key = self.split_key() if key is None else key + x = jnp.sqrt(-2. * jnp.log(jr.uniform(key, shape=_size2shape(size), minval=0, maxval=1))) return JaxArray(x * scale) - def triangular(self, size=None): - bernoulli_samples = jr.bernoulli(self.split_key(), p=0.5, shape=_size2shape(size)) + def triangular(self, size=None, key=None): + key = self.split_key() if key is None else key + bernoulli_samples = jr.bernoulli(key, p=0.5, shape=_size2shape(size)) return JaxArray(2 * bernoulli_samples - 1) - def vonmises(self, mu, kappa, size=None): + def vonmises(self, mu, kappa, size=None, key=None): + key = self.split_key() if key is None else key mu = _check_py_seq(_remove_jax_array(mu)) kappa = _check_py_seq(_remove_jax_array(kappa)) if size is None: size = lax.broadcast_shapes(jnp.shape(mu), jnp.shape(kappa)) size = _size2shape(size) - samples = _von_mises_centered(self.split_key(), kappa, size) + samples = _von_mises_centered(key, kappa, size) samples = samples + mu samples = (samples + jnp.pi) % (2.0 * jnp.pi) - jnp.pi return JaxArray(samples) - def weibull(self, a, size=None): + def weibull(self, a, size=None, key=None): + key = self.split_key() if key is None else key a = _check_py_seq(_remove_jax_array(a)) if size is None: size = jnp.shape(a) @@ -780,11 +816,11 @@ def weibull(self, a, size=None): if jnp.size(a) > 1: raise ValueError(f'"a" should be a scalar when "size" is provided. But we got {a}') size = _size2shape(size) - random_uniform = jr.uniform(key=self.split_key(), shape=size, minval=0, maxval=1) + random_uniform = jr.uniform(key=key, shape=size, minval=0, maxval=1) r = jnp.power(-jnp.log1p(-random_uniform), 1.0 / a) return JaxArray(r) - def weibull_min(self, a, scale=None, size=None): + def weibull_min(self, a, scale=None, size=None, key=None): """Sample from a Weibull minimum distribution. Parameters @@ -801,6 +837,7 @@ def weibull_min(self, a, scale=None, size=None): out: array_like The sampling results. """ + key = self.split_key() if key is None else key a = _check_py_seq(_remove_jax_array(a)) scale = _check_py_seq(_remove_jax_array(scale)) if size is None: @@ -809,35 +846,40 @@ def weibull_min(self, a, scale=None, size=None): if jnp.size(a) > 1: raise ValueError(f'"a" should be a scalar when "size" is provided. But we got {a}') size = _size2shape(size) - random_uniform = jr.uniform(key=self.split_key(), shape=size, minval=0, maxval=1) + random_uniform = jr.uniform(key=key, shape=size, minval=0, maxval=1) r = jnp.power(-jnp.log1p(-random_uniform), 1.0 / a) if scale is not None: r /= scale return JaxArray(r) - def maxwell(self, size=None): + def maxwell(self, size=None, key=None): + key = self.split_key() if key is None else key shape = core.canonicalize_shape(_size2shape(size)) + (3,) - norm_rvs = jr.normal(key=self.split_key(), shape=shape) + norm_rvs = jr.normal(key=key, shape=shape) return JaxArray(jnp.linalg.norm(norm_rvs, axis=-1)) - def negative_binomial(self, n, p, size=None): + def negative_binomial(self, n, p, size=None, key=None): n = _check_py_seq(_remove_jax_array(n)) p = _check_py_seq(_remove_jax_array(p)) if size is None: size = lax.broadcast_shapes(jnp.shape(n), jnp.shape(p)) size = _size2shape(size) logits = jnp.log(p) - jnp.log1p(-p) - rate = self.gamma(shape=n, scale=jnp.exp(-logits), size=size) - return JaxArray(self.poisson(lam=rate)) + if key is None: + keys = self.split_keys(2) + else: + keys = jr.split(key, 2) + rate = self.gamma(shape=n, scale=jnp.exp(-logits), size=size, key=keys[0]) + return JaxArray(self.poisson(lam=rate, key=keys[1])) - def wald(self, mean, scale, size=None): + def wald(self, mean, scale, size=None, key=None): mean = _check_py_seq(_remove_jax_array(mean)) scale = _check_py_seq(_remove_jax_array(scale)) if size is None: size = lax.broadcast_shapes(jnp.shape(mean), jnp.shape(scale)) size = _size2shape(size) sampled_chi2 = jnp.square(self.randn(*size).value) - sampled_uniform = self.uniform(size=size).value + sampled_uniform = self.uniform(size=size, key=key).value # Wikipedia defines an intermediate x with the formula # x = loc + loc ** 2 * y / (2 * conc) - loc / (2 * conc) * sqrt(4 * loc * conc * y + loc ** 2 * y ** 2) # where y ~ N(0, 1)**2 (sampled_chi2 above) and conc is the concentration. @@ -869,43 +911,66 @@ def wald(self, mean, scale, size=None): jnp.square(mean) / sampled) return JaxArray(res) - def t(self, df, size=None): + def t(self, df, size=None, key=None): df = _check_py_seq(_remove_jax_array(df)) if size is None: size = np.shape(df) else: size = _size2shape(size) _check_shape("t", size, np.shape(df)) - keys = self.split_keys(2) + if key is None: + keys = self.split_keys(2) + else: + keys = jr.split(key, 2) n = jr.normal(keys[0], size) two = _const(n, 2) half_df = lax.div(df, two) g = jr.gamma(keys[1], half_df, size) return JaxArray(n * jnp.sqrt(half_df / g)) - def orthogonal(self, n: int, size=None): + def orthogonal(self, n: int, size=None, key=None): + key = self.split_key() if key is None else key size = _size2shape(size) _check_shape("orthogonal", size) n = core.concrete_or_error(index, n, "The error occurred in jax.random.orthogonal()") - z = jr.normal(self.split_key(), size + (n, n)) + z = jr.normal(key, size + (n, n)) q, r = jnp.linalg.qr(z) d = jnp.diagonal(r, 0, -2, -1) return JaxArray(q * jnp.expand_dims(d / abs(d), -2)) - def noncentral_chisquare(self, df, nonc, size=None): + def noncentral_chisquare(self, df, nonc, size=None, key=None): df = _check_py_seq(_remove_jax_array(df)) nonc = _check_py_seq(_remove_jax_array(nonc)) if size is None: size = lax.broadcast_shapes(jnp.shape(df), jnp.shape(nonc)) size = _size2shape(size) - i = jr.poisson(self.split_key(), 0.5 * nonc, shape=size) - n = jr.normal(self.split_key(), shape=size) + jnp.sqrt(nonc) + if key is None: + keys = self.split_keys(3) + else: + keys = jr.split(key, 3) + i = jr.poisson(keys[0], 0.5 * nonc, shape=size) + n = jr.normal(keys[1], shape=size) + jnp.sqrt(nonc) cond = jnp.greater(df, 1.0) df2 = jnp.where(cond, df - 1.0, df + 2.0 * i) - chi2 = 2.0 * jr.gamma(self.split_key(), 0.5 * df2, shape=size) + chi2 = 2.0 * jr.gamma(keys[2], 0.5 * df2, shape=size) return JaxArray(jnp.where(cond, chi2 + n * n, chi2)) - def zipf(self, a, size=None): + def loggamma(self, a, size=None, key=None): + key = self.split_key() if key is None else key + a = _check_py_seq(_remove_jax_array(a)) + if size is None: + size = jnp.shape(a) + return JaxArray(jr.loggamma(key, a, shape=_size2shape(size))) + + def categorical(self, logits, axis: int = -1, size=None, key=None): + key = self.split_key() if key is None else key + logits = _check_py_seq(_remove_jax_array(logits)) + if size is None: + size = list(jnp.shape(logits)) + size.pop(axis) + return JaxArray(jr.categorical(key, logits, axis=axis, shape=_size2shape(size))) + + def zipf(self, a, size=None, key=None): a = _check_py_seq(_remove_jax_array(a)) if size is None: size = jnp.shape(a) @@ -913,7 +978,7 @@ def zipf(self, a, size=None): a, result_shape=jax.ShapeDtypeStruct(size, jnp.int_))) - def power(self, a, size=None): + def power(self, a, size=None, key=None): a = _check_py_seq(_remove_jax_array(a)) if size is None: size = jnp.shape(a) @@ -921,7 +986,7 @@ def power(self, a, size=None): return JaxArray(call(lambda a: np.random.power(a=a, size=size), a, result_shape=jax.ShapeDtypeStruct(size, jnp.float_))) - def f(self, dfnum, dfden, size=None): + def f(self, dfnum, dfden, size=None, key=None): dfnum = _remove_jax_array(dfnum) dfden = _remove_jax_array(dfden) dfnum = _check_py_seq(dfnum) @@ -936,7 +1001,7 @@ def f(self, dfnum, dfden, size=None): d, result_shape=jax.ShapeDtypeStruct(size, jnp.float_))) - def hypergeometric(self, ngood, nbad, nsample, size=None): + def hypergeometric(self, ngood, nbad, nsample, size=None, key=None): ngood = _check_py_seq(_remove_jax_array(ngood)) nbad = _check_py_seq(_remove_jax_array(nbad)) nsample = _check_py_seq(_remove_jax_array(nsample)) @@ -953,7 +1018,7 @@ def hypergeometric(self, ngood, nbad, nsample, size=None): size=size), d, result_shape=jax.ShapeDtypeStruct(size, jnp.int_))) - def logseries(self, p, size=None): + def logseries(self, p, size=None, key=None): p = _check_py_seq(_remove_jax_array(p)) if size is None: size = jnp.shape(p) @@ -961,7 +1026,7 @@ def logseries(self, p, size=None): return JaxArray(call(lambda p: np.random.logseries(p=p, size=size), p, result_shape=jax.ShapeDtypeStruct(size, jnp.int_))) - def noncentral_f(self, dfnum, dfden, nonc, size=None): + def noncentral_f(self, dfnum, dfden, nonc, size=None, key=None): dfnum = _check_py_seq(_remove_jax_array(dfnum)) dfden = _check_py_seq(_remove_jax_array(dfden)) nonc = _check_py_seq(_remove_jax_array(nonc)) @@ -977,19 +1042,6 @@ def noncentral_f(self, dfnum, dfden, nonc, size=None): size=size), d, result_shape=jax.ShapeDtypeStruct(size, jnp.float_))) - def loggamma(self, a, size=None): - a = _check_py_seq(_remove_jax_array(a)) - if size is None: - size = jnp.shape(a) - return JaxArray(jr.loggamma(self.split_key(), a, shape=_size2shape(size))) - - def categorical(self, logits, axis: int = -1, size=None): - logits = _check_py_seq(_remove_jax_array(logits)) - if size is None: - size = list(jnp.shape(logits)) - size.pop(axis) - return JaxArray(jr.categorical(self.split_key(), logits, axis=axis, shape=_size2shape(size))) - # alias Generator = RandomState @@ -1016,138 +1068,138 @@ def seed(seed=None): @wraps(np.random.rand) -def rand(*dn): - return DEFAULT.rand(*dn) +def rand(*dn, key=None): + return DEFAULT.rand(*dn, key=key) @wraps(np.random.randint) -def randint(low, high=None, size=None, dtype=jnp.int_): - return DEFAULT.randint(low, high=high, size=size, dtype=dtype) +def randint(low, high=None, size=None, dtype=jnp.int_, key=None): + return DEFAULT.randint(low, high=high, size=size, dtype=dtype, key=key) @wraps(np.random.random_integers) -def random_integers(low, high=None, size=None): - return DEFAULT.random_integers(low, high=high, size=size) +def random_integers(low, high=None, size=None, key=None): + return DEFAULT.random_integers(low, high=high, size=size, key=key) @wraps(np.random.randn) -def randn(*dn): - return DEFAULT.randn(*dn) +def randn(*dn, key=None): + return DEFAULT.randn(*dn, key=key) @wraps(np.random.random) -def random(size=None): - return DEFAULT.random(size) +def random(size=None, key=None): + return DEFAULT.random(size, key=key) @wraps(np.random.random_sample) -def random_sample(size=None): - return DEFAULT.random_sample(size) +def random_sample(size=None, key=None): + return DEFAULT.random_sample(size, key=key) @wraps(np.random.ranf) -def ranf(size=None): - return DEFAULT.ranf(size) +def ranf(size=None, key=None): + return DEFAULT.ranf(size, key=key) @wraps(np.random.sample) -def sample(size=None): - return DEFAULT.sample(size) +def sample(size=None, key=None): + return DEFAULT.sample(size, key=key) @wraps(np.random.choice) -def choice(a, size=None, replace=True, p=None): +def choice(a, size=None, replace=True, p=None, key=None): a = _remove_jax_array(a) - return DEFAULT.choice(a=a, size=size, replace=replace, p=p) + return DEFAULT.choice(a=a, size=size, replace=replace, p=p, key=key) @wraps(np.random.permutation) -def permutation(x): - return DEFAULT.permutation(x) +def permutation(x, key=None): + return DEFAULT.permutation(x, key=key) @wraps(np.random.shuffle) -def shuffle(x, axis=0): - DEFAULT.shuffle(x, axis) +def shuffle(x, axis=0, key=None): + DEFAULT.shuffle(x, axis, key=key) @wraps(np.random.beta) -def beta(a, b, size=None): - return DEFAULT.beta(a, b, size=size) +def beta(a, b, size=None, key=None): + return DEFAULT.beta(a, b, size=size, key=key) @wraps(np.random.exponential) -def exponential(scale=None, size=None): - return DEFAULT.exponential(scale, size) +def exponential(scale=None, size=None, key=None): + return DEFAULT.exponential(scale, size, key=key) @wraps(np.random.gamma) -def gamma(shape, scale=None, size=None): - return DEFAULT.gamma(shape, scale, size=size) +def gamma(shape, scale=None, size=None, key=None): + return DEFAULT.gamma(shape, scale, size=size, key=key) @wraps(np.random.gumbel) -def gumbel(loc=None, scale=None, size=None): - return DEFAULT.gumbel(loc, scale, size=size) +def gumbel(loc=None, scale=None, size=None, key=None): + return DEFAULT.gumbel(loc, scale, size=size, key=key) @wraps(np.random.laplace) -def laplace(loc=None, scale=None, size=None): - return DEFAULT.laplace(loc, scale, size) +def laplace(loc=None, scale=None, size=None, key=None): + return DEFAULT.laplace(loc, scale, size, key=key) @wraps(np.random.logistic) -def logistic(loc=None, scale=None, size=None): - return DEFAULT.logistic(loc, scale, size) +def logistic(loc=None, scale=None, size=None, key=None): + return DEFAULT.logistic(loc, scale, size, key=key) @wraps(np.random.normal) -def normal(loc=None, scale=None, size=None): - return DEFAULT.normal(loc, scale, size) +def normal(loc=None, scale=None, size=None, key=None): + return DEFAULT.normal(loc, scale, size, key=key) @wraps(np.random.pareto) -def pareto(a, size=None): - return DEFAULT.pareto(a, size) +def pareto(a, size=None, key=None): + return DEFAULT.pareto(a, size, key=key) @wraps(np.random.poisson) -def poisson(lam=1.0, size=None): - return DEFAULT.poisson(lam, size) +def poisson(lam=1.0, size=None, key=None): + return DEFAULT.poisson(lam, size, key=key) @wraps(np.random.standard_cauchy) -def standard_cauchy(size=None): - return DEFAULT.standard_cauchy(size) +def standard_cauchy(size=None, key=None): + return DEFAULT.standard_cauchy(size, key=key) @wraps(np.random.standard_exponential) -def standard_exponential(size=None): - return DEFAULT.standard_exponential(size) +def standard_exponential(size=None, key=None): + return DEFAULT.standard_exponential(size, key=key) @wraps(np.random.standard_gamma) -def standard_gamma(shape, size=None): - return DEFAULT.standard_gamma(shape, size) +def standard_gamma(shape, size=None, key=None): + return DEFAULT.standard_gamma(shape, size, key=key) @wraps(np.random.standard_normal) -def standard_normal(size=None): - return DEFAULT.standard_normal(size) +def standard_normal(size=None, key=None): + return DEFAULT.standard_normal(size, key=key) @wraps(np.random.standard_t) -def standard_t(df, size=None): - return DEFAULT.standard_t(df, size) +def standard_t(df, size=None, key=None): + return DEFAULT.standard_t(df, size, key=key) @wraps(np.random.uniform) -def uniform(low=0.0, high=1.0, size=None): - return DEFAULT.uniform(low, high, size) +def uniform(low=0.0, high=1.0, size=None, key=None): + return DEFAULT.uniform(low, high, size, key=key) @wraps(jr.truncated_normal) -def truncated_normal(lower, upper, size=None, scale=None): +def truncated_normal(lower, upper, size=None, scale=None, key=None): """Sample truncated standard normal random values with given shape and dtype. Parameters @@ -1174,11 +1226,11 @@ def truncated_normal(lower, upper, size=None, scale=None): ``shape`` is not None, or else by broadcasting ``lower`` and ``upper``. Returns values in the open interval ``(lower, upper)``. """ - return DEFAULT.truncated_normal(lower, upper, size, scale) + return DEFAULT.truncated_normal(lower, upper, size, scale, key=key) @wraps(jr.bernoulli) -def bernoulli(p=0.5, size=None): +def bernoulli(p=0.5, size=None, key=None): """Sample Bernoulli random values with given shape and mean. Parameters @@ -1198,120 +1250,120 @@ def bernoulli(p=0.5, size=None): A random array with boolean dtype and shape given by ``shape`` if ``shape`` is not None, or else ``p.shape``. """ - return DEFAULT.bernoulli(p, size) + return DEFAULT.bernoulli(p, size, key=key) @wraps(np.random.lognormal) -def lognormal(mean=None, sigma=None, size=None): - return DEFAULT.lognormal(mean, sigma, size) +def lognormal(mean=None, sigma=None, size=None, key=None): + return DEFAULT.lognormal(mean, sigma, size, key=key) @wraps(np.random.binomial) -def binomial(n, p, size=None): - return DEFAULT.binomial(n, p, size) +def binomial(n, p, size=None, key=None): + return DEFAULT.binomial(n, p, size, key=key) @wraps(np.random.chisquare) -def chisquare(df, size=None): - return DEFAULT.chisquare(df, size) +def chisquare(df, size=None, key=None): + return DEFAULT.chisquare(df, size, key=key) @wraps(np.random.dirichlet) -def dirichlet(alpha, size=None): - return DEFAULT.dirichlet(alpha, size) +def dirichlet(alpha, size=None, key=None): + return DEFAULT.dirichlet(alpha, size, key=key) @wraps(np.random.geometric) -def geometric(p, size=None): - return DEFAULT.geometric(p, size) +def geometric(p, size=None, key=None): + return DEFAULT.geometric(p, size, key=key) @wraps(np.random.f) -def f(dfnum, dfden, size=None): - return DEFAULT.f(dfnum, dfden, size) +def f(dfnum, dfden, size=None, key=None): + return DEFAULT.f(dfnum, dfden, size, key=key) @wraps(np.random.hypergeometric) -def hypergeometric(ngood, nbad, nsample, size=None): - return DEFAULT.hypergeometric(ngood, nbad, nsample, size) +def hypergeometric(ngood, nbad, nsample, size=None, key=None): + return DEFAULT.hypergeometric(ngood, nbad, nsample, size, key=key) @wraps(np.random.logseries) -def logseries(p, size=None): - return DEFAULT.logseries(p, size) +def logseries(p, size=None, key=None): + return DEFAULT.logseries(p, size, key=key) @wraps(np.random.multinomial) -def multinomial(n, pvals, size=None): - return DEFAULT.multinomial(n, pvals, size) +def multinomial(n, pvals, size=None, key=None): + return DEFAULT.multinomial(n, pvals, size, key=key) @wraps(np.random.multivariate_normal) -def multivariate_normal(mean, cov, size=None, method: str = 'cholesky'): - return DEFAULT.multivariate_normal(mean, cov, size, method) +def multivariate_normal(mean, cov, size=None, method: str = 'cholesky', key=None): + return DEFAULT.multivariate_normal(mean, cov, size, method, key=key) @wraps(np.random.negative_binomial) -def negative_binomial(n, p, size=None): - return DEFAULT.negative_binomial(n, p, size) +def negative_binomial(n, p, size=None, key=None): + return DEFAULT.negative_binomial(n, p, size, key=key) @wraps(np.random.noncentral_chisquare) -def noncentral_chisquare(df, nonc, size=None): - return DEFAULT.noncentral_chisquare(df, nonc, size) +def noncentral_chisquare(df, nonc, size=None, key=None): + return DEFAULT.noncentral_chisquare(df, nonc, size, key=key) @wraps(np.random.noncentral_f) -def noncentral_f(dfnum, dfden, nonc, size=None): - return DEFAULT.noncentral_f(dfnum, dfden, nonc, size) +def noncentral_f(dfnum, dfden, nonc, size=None, key=None): + return DEFAULT.noncentral_f(dfnum, dfden, nonc, size, key=key) @wraps(np.random.power) -def power(a, size=None): - return DEFAULT.power(a, size) +def power(a, size=None, key=None): + return DEFAULT.power(a, size, key=key) @wraps(np.random.rayleigh) -def rayleigh(scale=1.0, size=None): - return DEFAULT.rayleigh(scale, size) +def rayleigh(scale=1.0, size=None, key=None): + return DEFAULT.rayleigh(scale, size, key=key) @wraps(np.random.triangular) -def triangular(size=None): - return DEFAULT.triangular(size) +def triangular(size=None, key=None): + return DEFAULT.triangular(size, key=key) @wraps(np.random.vonmises) -def vonmises(mu, kappa, size=None): - return DEFAULT.vonmises(mu, kappa, size) +def vonmises(mu, kappa, size=None, key=None): + return DEFAULT.vonmises(mu, kappa, size, key=key) @wraps(np.random.wald) -def wald(mean, scale, size=None): - return DEFAULT.wald(mean, scale, size) +def wald(mean, scale, size=None, key=None): + return DEFAULT.wald(mean, scale, size, key=key) @wraps(np.random.weibull) -def weibull(a, size=None): - return DEFAULT.weibull(a, size) +def weibull(a, size=None, key=None): + return DEFAULT.weibull(a, size, key=key) @wraps(jr.weibull_min) -def weibull_min(a, scale=None, size=None): - return DEFAULT.weibull_min(a, scale, size) +def weibull_min(a, scale=None, size=None, key=None): + return DEFAULT.weibull_min(a, scale, size, key=key) @wraps(np.random.zipf) -def zipf(a, size=None): - return DEFAULT.zipf(a, size) +def zipf(a, size=None, key=None): + return DEFAULT.zipf(a, size, key=key) @wraps(jr.maxwell) -def maxwell(size=None): - return DEFAULT.maxwell(size) +def maxwell(size=None, key=None): + return DEFAULT.maxwell(size, key=key) -def t(df, size=None): +def t(df, size=None, key=None): """Sample Student’s t random values. Parameters @@ -1327,10 +1379,10 @@ def t(df, size=None): out: array_like The sampled value. """ - return DEFAULT.t(df, size) + return DEFAULT.t(df, size, key=key) -def orthogonal(n: int, size=None): +def orthogonal(n: int, size=None, key=None): """Sample uniformly from the orthogonal group `O(n)`. Parameters @@ -1345,10 +1397,10 @@ def orthogonal(n: int, size=None): out: JaxArray The sampled results. """ - return DEFAULT.orthogonal(n, size) + return DEFAULT.orthogonal(n, size, key=key) -def loggamma(a, size=None): +def loggamma(a, size=None, key=None): """Sample log-gamma random values. Parameters @@ -1368,5 +1420,5 @@ def loggamma(a, size=None): @wraps(jr.categorical) -def categorical(logits, axis: int = -1, size=None): - return DEFAULT.categorical(logits, axis, size) +def categorical(logits, axis: int = -1, size=None, key=None): + return DEFAULT.categorical(logits, axis, size, key=key) diff --git a/brainpy/math/setting.py b/brainpy/math/setting.py index 7d72ecee3..a28c0f21a 100644 --- a/brainpy/math/setting.py +++ b/brainpy/math/setting.py @@ -3,9 +3,8 @@ import os import re -from jax import dtypes -import jax.config -import jax.numpy as jnp +from jax import dtypes, config, numpy as jnp +from jax.lib import xla_bridge __all__ = [ 'enable_x64', @@ -13,15 +12,20 @@ 'set_platform', 'set_host_device_count', - # data types + # device memory + 'clear_live_buffers', + 'disable_gpu_memory_preallocation', + 'enable_gpu_memory_preallocation', + + # default data types 'bool_', 'int_', 'float_', 'complex_', - 'get_dint', - 'get_dfloat', + 'ditype', + 'dftype', - # change default data types + # default numerical integration step 'set_dt', 'get_dt', ] @@ -35,12 +39,14 @@ complex_ = jnp.complex_ -def get_dint(): - return jnp.int64 if jax.config.read('jax_enable_x64') else jnp.int32 +def ditype(): + """Default int type.""" + return jnp.int64 if config.read('jax_enable_x64') else jnp.int32 -def get_dfloat(): - return jnp.float64 if jax.config.read('jax_enable_x64') else jnp.float32 +def dftype(): + """Default float type.""" + return jnp.float64 if config.read('jax_enable_x64') else jnp.float32 # numerical precision @@ -79,11 +85,11 @@ def get_dt(): def enable_x64(mode=True): assert mode in [True, False] - jax.config.update("jax_enable_x64", mode) + config.update("jax_enable_x64", mode) def disable_x64(): - jax.config.update("jax_enable_x64", False) + config.update("jax_enable_x64", False) def set_platform(platform): @@ -92,7 +98,7 @@ def set_platform(platform): effect at the beginning of your program. """ assert platform in ['cpu', 'gpu', 'tpu'] - jax.config.update("jax_platform_name", platform) + config.update("jax_platform_name", platform) def set_host_device_count(n): @@ -117,3 +123,28 @@ def set_host_device_count(n): xla_flags = os.getenv("XLA_FLAGS", "") xla_flags = re.sub(r"--xla_force_host_platform_device_count=\S+", "", xla_flags).split() os.environ["XLA_FLAGS"] = " ".join(["--xla_force_host_platform_device_count={}".format(n)] + xla_flags) + + +def clear_live_buffers(): + """Clear all on-device buffers. + + .. warning:: + + This operation may cause errors when you use a deleted buffer. + Therefore, regenerate data always. + """ + for buf in xla_bridge.get_backend().live_buffers(): + buf.delete() + + +def disable_gpu_memory_preallocation(): + """Disable pre-allocating the GPU memory.""" + os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false' + os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform' + + +def enable_gpu_memory_preallocation(): + """Disable pre-allocating the GPU memory.""" + os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'true' + os.environ.pop('XLA_PYTHON_CLIENT_ALLOCATOR') + diff --git a/brainpy/optimizers/optimizer.py b/brainpy/optimizers/optimizer.py index 6b15c282d..1e085e00e 100644 --- a/brainpy/optimizers/optimizer.py +++ b/brainpy/optimizers/optimizer.py @@ -51,6 +51,9 @@ def check_grads(self, grads): def __repr__(self): return f"{self.__class__.__name__}(lr={self.lr})" + def update(self, grads: dict): + raise NotImplementedError + class SGD(Optimizer): r"""Stochastic gradient descent optimizer. diff --git a/brainpy/running/__init__.py b/brainpy/running/__init__.py index b6d7d1e23..e6441aaea 100644 --- a/brainpy/running/__init__.py +++ b/brainpy/running/__init__.py @@ -5,5 +5,5 @@ This module provides APIs for brain simulations. """ -from .parallel import * +from .multiprocess import * from .runner import * diff --git a/brainpy/running/parallel.py b/brainpy/running/multiprocess.py similarity index 100% rename from brainpy/running/parallel.py rename to brainpy/running/multiprocess.py diff --git a/brainpy/running/runner.py b/brainpy/running/runner.py index f2e3d75f6..319061d18 100644 --- a/brainpy/running/runner.py +++ b/brainpy/running/runner.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- - +import gc import types from typing import Callable, Dict, Sequence, Union @@ -206,5 +206,12 @@ def _find_monitor_targets(self, _monitors): monitors[key] = (getattr(master, splits[-1]), index) return monitors - def build_monitors(self, return_without_idx, return_with_idx) -> Callable: + def build_monitors(self, return_without_idx, return_with_idx, shared_args) -> Callable: raise NotImplementedError + + def __del__(self): + for key in tuple(self.mon.keys()): + del self.mon[key] + for key in tuple(self.__dict__.keys()): + del self.__dict__[key] + gc.collect() diff --git a/brainpy/tools/checking.py b/brainpy/tools/checking.py index af75ba3a5..a29001c2e 100644 --- a/brainpy/tools/checking.py +++ b/brainpy/tools/checking.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -from typing import Union, Sequence, Dict, Callable, Tuple, Type +from typing import Union, Sequence, Dict, Callable, Tuple, Type, Optional import jax.numpy as jnp import numpy as np @@ -16,12 +16,15 @@ 'check_shape_except_batch', 'check_shape', 'check_dict_data', + 'check_callable', 'check_initializer', 'check_connector', 'check_float', 'check_integer', 'check_string', 'check_sequence', + + 'serialize_kwargs', ] @@ -167,9 +170,23 @@ def check_dict_data(a_dict: Dict, f'while we got ({type(key)}, {type(value)})') +def check_callable(fun: Callable, + name: str = None, + allow_none: bool = False): + name = '' if name is None else name + if fun is None: + if allow_none: + return None + else: + raise ValueError(f'{name} must be a callable function, but we got None.') + if not callable(fun): + raise ValueError(f'{name} should be a callable function. While we got {type(fun)}') + return fun + + def check_initializer(initializer: Union[Callable, init.Initializer, Tensor], name: str = None, - allow_none=False): + allow_none: bool = False): """Check the initializer. """ import brainpy.math as bm @@ -181,11 +198,11 @@ def check_initializer(initializer: Union[Callable, init.Initializer, Tensor], else: raise ValueError(f'{name} must be an initializer, but we got None.') if isinstance(initializer, init.Initializer): - return + return initializer elif isinstance(initializer, (bm.ndarray, jnp.ndarray)): - return + return initializer elif callable(initializer): - return + return initializer else: raise ValueError(f'{name} should be an instance of brainpy.init.Initializer, ' f'tensor or callable function. While we got {type(initializer)}') @@ -233,8 +250,14 @@ def check_sequence(value: Sequence, f'but we got {type(elem_type)}: {v}') -def check_float(value: float, name=None, min_bound=None, max_bound=None, - allow_none=False, allow_int=True): +def check_float( + value: float, + name: str = None, + min_bound: float = None, + max_bound: float = None, + allow_none: bool = False, + allow_int: bool = True +) -> float: """Check float type. Parameters @@ -253,7 +276,7 @@ def check_float(value: float, name=None, min_bound=None, max_bound=None, if name is None: name = '' if value is None: if allow_none: - return + return None else: raise ValueError(f'{name} must be a float, but got None') if allow_int: @@ -270,6 +293,7 @@ def check_float(value: float, name=None, min_bound=None, max_bound=None, if value > max_bound: raise ValueError(f"{name} must be a float smaller than {max_bound}, " f"while we got {value}") + return value def check_integer(value: int, name=None, min_bound=None, max_bound=None, allow_none=False): @@ -321,3 +345,14 @@ def check_string(value: str, name: str = None, candidates: Sequence[str] = None, if value not in candidates: raise ValueError(f'{name} must be a str in {candidates}, ' f'but we got {value}') + + +def serialize_kwargs(shared_kwargs: Optional[Dict]): + """Serialize kwargs.""" + shared_kwargs = dict() if shared_kwargs is None else shared_kwargs + check_dict_data(shared_kwargs, + key_type=str, + val_type=(bool, float, int, complex, str), + name='shared_kwargs') + shared_kwargs = {key: shared_kwargs[key] for key in sorted(shared_kwargs.keys())} + return str(shared_kwargs) diff --git a/brainpy/tools/others/dicts.py b/brainpy/tools/others/dicts.py index 9f877e3b4..7cacf87da 100644 --- a/brainpy/tools/others/dicts.py +++ b/brainpy/tools/others/dicts.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- -import copy +from jax.tree_util import register_pytree_node +from jax.util import safe_zip __all__ = [ 'DotDict', @@ -25,115 +26,12 @@ class DotDict(dict): """ def __init__(self, *args, **kwargs): - object.__setattr__(self, '__parent', kwargs.pop('__parent', None)) - object.__setattr__(self, '__key', kwargs.pop('__key', None)) - for arg in args: - if not arg: - continue - elif isinstance(arg, dict): - for key, val in arg.items(): - self[key] = self._hook(val) - elif isinstance(arg, tuple) and (not isinstance(arg[0], tuple)): - self[arg[0]] = self._hook(arg[1]) - else: - for key, val in iter(arg): - self[key] = self._hook(val) + super().__init__(*args, **kwargs) + self.__dict__ = self - for key, val in kwargs.items(): - self[key] = self._hook(val) - def __setattr__(self, name, value): - if hasattr(self.__class__, name): - raise AttributeError(f"Attribute '{name}' is read-only in '{type(self)}' object.") - else: - self[name] = value - - def __setitem__(self, name, value): - super(DotDict, self).__setitem__(name, value) - try: - p = object.__getattribute__(self, '__parent') - key = object.__getattribute__(self, '__key') - except AttributeError: - p = None - key = None - if p is not None: - p[key] = self - object.__delattr__(self, '__parent') - object.__delattr__(self, '__key') - - def __add__(self, other): - if not self.keys(): - return other - else: - self_type = type(self).__name__ - other_type = type(other).__name__ - msg = "Unsupported operand type(s) for +: '{}' and '{}'" - raise TypeError(msg.format(self_type, other_type)) - - @classmethod - def _hook(cls, item): - if isinstance(item, dict): - return cls(item) - elif isinstance(item, (list, tuple)): - return type(item)(cls._hook(elem) for elem in item) - return item - - def __getattr__(self, item): - return self.__getitem__(item) - - def __delattr__(self, name): - del self[name] - - def copy(self): - return copy.copy(self) - - def deepcopy(self): - return copy.deepcopy(self) - - def __deepcopy__(self, memo): - other = self.__class__() - memo[id(self)] = other - for key, value in self.items(): - other[copy.deepcopy(key, memo)] = copy.deepcopy(value, memo) - return other - - def to_dict(self): - base = {} - for key, value in self.items(): - if isinstance(value, type(self)): - base[key] = value.to_dict() - elif isinstance(value, (list, tuple)): - base[key] = type(value)(item.to_dict() if isinstance(item, type(self)) else item - for item in value) - else: - base[key] = value - return base - - def update(self, *args, **kwargs): - other = {} - if args: - if len(args) > 1: - raise TypeError() - other.update(args[0]) - other.update(kwargs) - for k, v in other.items(): - if (k not in self) or (not isinstance(self[k], dict)) or (not isinstance(v, dict)): - self[k] = v - else: - self[k].update(v) - - def __getnewargs__(self): - return tuple(self.items()) - - def __getstate__(self): - return self - - def __setstate__(self, state): - self.update(state) - - def setdefault(self, key, default=None): - if key in self: - return self[key] - else: - self[key] = default - return default +register_pytree_node( + DotDict, + lambda x: (tuple(x.values()), tuple(x.keys())), + lambda keys, values: DotDict(safe_zip(keys, values)) +) diff --git a/brainpy/train/__init__.py b/brainpy/train/__init__.py index eb9029012..88bbf73ce 100644 --- a/brainpy/train/__init__.py +++ b/brainpy/train/__init__.py @@ -1,7 +1,29 @@ # -*- coding: utf-8 -*- + +""" +This module provides various running and training algorithms +for various neural networks. + +The supported training algorithms include + +- offline training methods, like ridge regression, linear regression, etc. +- online training methods, like recursive least squares (RLS, or Force Learning), + least mean squares (LMS), etc. +- back-propagation learning method +- and others + +The supported neural networks include + +- reservoir computing networks, +- artificial recurrent neural networks, +- spiking neural networks, +- and others. +""" + + from .base import * -from .layers import * -from .runners import * -from .algorithms import * +from .back_propagation import * +from .online_trainer import * +from .offline_trainer import * diff --git a/brainpy/train/algorithms/offline.py b/brainpy/train/algorithms/offline.py deleted file mode 100644 index 5ad296432..000000000 --- a/brainpy/train/algorithms/offline.py +++ /dev/null @@ -1,197 +0,0 @@ -# -*- coding: utf-8 -*- - -import brainpy.math as bm -from brainpy.base import Base - -__all__ = [ - # base class for offline training algorithm - 'OfflineAlgorithm', - - # training methods - 'RidgeRegression', - 'LinearRegression', - - # general supports - 'get_supported_offline_methods', - 'register_offline_method', -] - -name2func = dict() - - -class OfflineAlgorithm(Base): - """Base class for offline training algorithm.""" - - def __init__(self, name=None): - super(OfflineAlgorithm, self).__init__(name=name) - - def __call__(self, targets, inputs, outputs): - """The training procedure. - - Parameters - ---------- - inputs: JaxArray, jax.numpy.ndarray, numpy.ndarray - The 3d input data with the shape of `(num_batch, num_time, num_input)`, - or, the 2d input data with the shape of `(num_time, num_input)`. - - targets: JaxArray, jax.numpy.ndarray, numpy.ndarray - The 3d target data with the shape of `(num_batch, num_time, num_output)`, - or the 2d target data with the shape of `(num_time, num_output)`. - - outputs: JaxArray, jax.numpy.ndarray, numpy.ndarray - The 3d output data with the shape of `(num_batch, num_time, num_output)`, - or the 2d output data with the shape of `(num_time, num_output)`. - - Returns - ------- - weight: JaxArray - The weights after fit. - """ - raise NotImplementedError('Must implement the __call__ function by the subclass itself.') - - def __repr__(self): - return self.__class__.__name__ - - def initialize(self, identifier, *args, **kwargs): - raise NotImplementedError('Must implement the initialize() ' - 'function by the subclass itself.') - - -class RidgeRegression(OfflineAlgorithm): - """Training algorithm of ridge regression. - - Parameters - ---------- - beta: float - The regularization coefficient. - """ - - def __init__(self, beta=1e-7, name=None): - super(RidgeRegression, self).__init__(name=name) - self.beta = beta - - def __call__(self, targets, inputs, outputs=None): - # checking - inputs = bm.asarray(inputs).reshape((-1, inputs.shape[2])) - targets = bm.asarray(targets).reshape((-1, targets.shape[2])) - # solving - temp = inputs.T @ inputs - if self.beta > 0.: - temp += self.beta * bm.eye(inputs.shape[-1]) - weights = bm.linalg.pinv(temp) @ (inputs.T @ targets) - return weights - - def __repr__(self): - return f'{self.__class__.__name__}(beta={self.beta})' - - def initialize(self, identifier, *args, **kwargs): - pass - - -name2func['ridge'] = RidgeRegression - - -class LinearRegression(OfflineAlgorithm): - """Training algorithm of least-square regression.""" - - def __init__(self, name=None): - super(LinearRegression, self).__init__(name=name) - - def __call__(self, targets, inputs, outputs=None): - inputs = bm.asarray(inputs).reshape((-1, inputs.shape[2])) - targets = bm.asarray(targets).reshape((-1, targets.shape[2])) - weights = bm.linalg.lstsq(inputs, targets) - return weights[0] - - def initialize(self, identifier, *args, **kwargs): - pass - - -name2func['linear'] = LinearRegression -name2func['lstsq'] = LinearRegression - - -class LassoRegression(OfflineAlgorithm): - """Lasso regression method for offline training. - - Parameters - ---------- - alpha: float - Constant that multiplies the L1 term. Defaults to 1.0. - `alpha = 0` is equivalent to an ordinary least square. - max_iter: int - The maximum number of iterations. - """ - - def __init__(self, alpha=1.0, max_iter=1000, name=None): - super(LassoRegression, self).__init__(name=name) - self.alpha = alpha - self.max_iter = max_iter - - def __call__(self, *args, **kwargs): - pass - - def initialize(self, identifier, *args, **kwargs): - pass - - -# name2func['lasso'] = LassoRegression - - -def elastic_net_regression(x, y, train_pars): - pass - - -# name2func['elastic_net'] = elastic_net_regression - - -def logistic_regression(x, y, train_pars): - pass - - -# name2func['logistic'] = logistic_regression - - -def polynomial_regression(x, y, train_pars): - pass - - -# name2func['polynomial'] = polynomial_regression - - -def stepwise_regression(x, y, train_pars): - pass - - -# name2func['stepwise'] = stepwise_regression - - -def get_supported_offline_methods(): - """Get all supported offline training methods.""" - return tuple(name2func.keys()) - - -def register_offline_method(name, method): - """Register a new offline learning method. - - Parameters - ---------- - name: str - The method name. - method: callable - The function method. - """ - if name in name2func: - raise ValueError(f'"{name}" has been registered in offline training methods.') - if not callable(method): - raise ValueError(f'"method" must be an instance of callable ' - f'function, but we got {type(method)}') - name2func[name] = method - - -def get(name): - """Get the training function according to the training method name.""" - if name not in name2func: - raise ValueError(f'All offline methods are: {get_supported_offline_methods()}.\n' - f'But we got {name}.') - return name2func[name] diff --git a/brainpy/train/back_propagation.py b/brainpy/train/back_propagation.py new file mode 100644 index 000000000..497c3c1fc --- /dev/null +++ b/brainpy/train/back_propagation.py @@ -0,0 +1,599 @@ +# -*- coding: utf-8 -*- + +import time +from typing import Union, Dict, Callable, Sequence + +import numpy as np +from jax import numpy as jnp +from jax.tree_util import tree_map, tree_flatten + +import brainpy.losses as losses +import brainpy.math as bm +import brainpy.optimizers as optim +from brainpy.dyn.base import DynamicalSystem +from brainpy.errors import UnsupportedError +from brainpy.tools.checking import serialize_kwargs +from brainpy.tools.others import DotDict +from brainpy.types import Tensor, Output +from . import constants as c +from .base import DSTrainer + +__all__ = [ + 'BPTT', + 'BPFF', + 'OnlineBPTT', +] + + +def _is_jax_array(s): + return isinstance(s, bm.JaxArray) + + +class BPTrainer(DSTrainer): + """Trainer implementing back-propagation algorithm. + + Parameters + ---------- + target: DynamicalSystem, TrainingSystem + The target model to train. + loss_fun: str, callable + The loss function. If it is a string, it should be the + function chosen from ``brainpy.losses`` module. Otherwise, + a callable function which receives argument of `(predicts, targets)` + should be provided. + optimizer: optim.Optimizer + The optimizer used for training. + shuffle_data: bool + seed: int + numpy_mon_after_run: bool + """ + + def __init__( + self, + target: DynamicalSystem, + loss_fun: Union[str, Callable], # loss function + optimizer: optim.Optimizer = None, # optimizer + loss_has_aux: bool = False, + shuffle_data: bool = True, # shuffle data + seed: int = None, # random seed for data shuffling + numpy_mon_after_run: bool = False, + **kwargs, + ): + super(BPTrainer, self).__init__(target=target, + numpy_mon_after_run=numpy_mon_after_run, + **kwargs) + + self.shuffle_data = shuffle_data + self.rng = bm.random.RandomState(seed=seed) + + # jit settings + self.jit[c.PREDICT_PHASE] = self.jit.get(c.PREDICT_PHASE, True) + self.jit[c.LOSS_PHASE] = self.jit.get(c.LOSS_PHASE, True) + self.jit[c.FIT_PHASE] = self.jit.get(c.FIT_PHASE, True) + + # optimizer + if optimizer is None: + lr = optim.ExponentialDecay(lr=0.025, decay_steps=1, decay_rate=0.99975) + optimizer = optim.Adam(lr=lr) + self.optimizer: optim.Optimizer = optimizer + self.optimizer.register_vars(self.target.vars(level=-1, include_self=True).subset(bm.TrainVar).unique()) + + # loss + self.loss_has_aux = loss_has_aux + if isinstance(loss_fun, str): + loss_fun = getattr(losses, loss_fun) + elif callable(loss_fun): + loss_fun = loss_fun + else: + raise UnsupportedError(f'Do not support {type(loss_fun)} to specify the loss function. ' + f'We only support str and callable function.') + self._loss_func = loss_fun + self._train_losses = None + self._train_loss_aux = None + self._test_losses = None + self._f_shuffle = None + + # functions + self._f_loss_compiled = dict() + self._f_train_compiled = dict() + self._f_grad_compiled = dict() + + def __repr__(self): + name = self.__class__.__name__ + prefix = ' ' * len(name) + return (f'{name}(target={self.target}, \n\t' + f'{prefix}jit={self.jit}, \n\t' + f'{prefix}loss={self._loss_func}, \n\t' + f'{prefix}optimizer={self.optimizer})') + + @property + def train_losses(self): + """Training loss.""" + return self._train_losses + + @property + def train_loss_aux(self): + return self._train_loss_aux + + def predict( + self, + inputs: Union[Tensor, Sequence[Tensor], Dict[str, Tensor]], + reset_state: bool = True, + shared_args: Dict = None, + eval_time: bool = False + ) -> Output: + """Predict a series of input data with the given target model. + + This function use the JIT compilation to accelerate the model simulation. + Moreover, it can automatically monitor the node variables, states, inputs, + feedbacks and its output, if users want. + + Parameters + ---------- + inputs: Tensor, sequence, dict + The feedforward input data. It must be a 3-dimensional data + which has the shape of `(num_sample, num_time, num_feature)`. + shared_args: dict + Shared keyword arguments for the given target model. + reset_state: bool + Whether reset the model states. Default True. + eval_time: bool + Whether evaluate the running time or not. Default False. + """ + return super(BPTrainer, self).predict(inputs=inputs, + reset_state=reset_state, + shared_args=shared_args, + eval_time=eval_time) + + def fit( + self, + train_data: Union[Callable, Sequence], + # test_data: Union[Callable, Sequence] = None, + batch_size: int = None, + num_epoch: int = 100, + num_report: int = 100, + reset_state: bool = True, + shared_args: Dict = None, + ): + """ + Fit the target model according to the given training and testing data. + + Parameters + ---------- + train_data: callable, sequence of data + It can be a callable function, or a tuple/list representing `(X, Y)` data. + - Callable. This function should return a pair of `(X, Y)` data + - Sequence. It should be a pair of `(X, Y)` train set. + - ``X``: should be a tensor or a dict of tensors with the shape of + `(num_sample, num_time, num_feature)`, where `num_sample` is + the number of samples, `num_time` is the number of the time step, + and `num_feature` is the number of features. + - ``Y``: Target values. A tensor or a dict of tensors. + - If the shape of each tensor is `(num_sample, num_feature)`, + then we will only fit the model with the only last output. + - If the shape of each tensor is `(num_sample, num_time, num_feature)`, + then the fitting happens on the whole data series. + test_data: callable, sequence of data + Same as the ``train_data``. It can be a callable function, + or a tuple/list representing `(X, Y)` data. + batch_size: int + The batch size. Default 32. This setting is used when users provide + the ``train_data`` and ``test_data`` as a pair of `(X, Y)` data, rather + than a function. + num_epoch: int + The number of training epoch. Default 100. + num_report: int + The number of step to report the progress. Default 100 training steps. + reset_state: bool + Whether reset the initial states of the target model. + shared_args: dict + The shared keyword arguments for the target models. + + """ + true_progress_bar = self.progress_bar + self.progress_bar = False + + # training the model + all_train_losses = [] + all_train_loss_aux = None + # all_test_losses = [] + + train_i = 0 + t0 = time.time() + for _ in range(num_epoch): + # training set + train_data_ = self._get_batchable_data(train_data, batch_size, self.shuffle_data) + for x, y in train_data_: + if reset_state: + self.target.reset_state(self._get_batch_size(x)) + self.reset_state() + + # training + res = self.f_train(shared_args)(x, y) + + # loss + loss = res[0] + all_train_losses.append(loss) + if self.loss_has_aux: + if all_train_loss_aux is None: + all_train_loss_aux = {k: [] for k in res[1].keys()} + for k, v in res[1].items(): + all_train_loss_aux[k].append(v) + + # report + train_i += 1 + if train_i % num_report == 0: + t1 = time.time() + msg = (f'Train {train_i} steps, use {t1 - t0:.4f} s, ' + f'train loss {round(float(loss), 5)}') + if self.loss_has_aux: + if isinstance(res[1], dict): + msg += ', {}'.format(", ".join([f"{k} {v}" for k, v in res[1].items()])) + print(msg) + t0 = t1 + + # # testing set + # if test_data is not None: + # test_data_ = self._get_batchable_data(test_data, batch_size, False) + # for x, y in test_data_: + # if reset_state: + # self.target.reset_state(self._get_batch_size(x)) + # self.reset_state() + # loss = self.f_loss(shared_args)(x, y) + # all_test_losses.append(loss) + + # finally + self._train_losses = bm.asarray(all_train_losses) + self._train_loss_aux = {k: bm.asarray(v) for k, v in all_train_loss_aux.items()} + # self._test_losses = bm.asarray(all_test_losses) + self.progress_bar = true_progress_bar + + def _get_batchable_data(self, data, num_batch, shuffle=False): + if callable(data): + data = self._get_data_by_callable(data, num_batch) + elif isinstance(data, (tuple, list)): + if len(data) != 2: + raise ValueError(f"Must be (X, Y) pair, but got a sequence with " + f"length {len(data)}") + data = self._get_data_by_tensor(data, num_batch=num_batch, shuffle=shuffle) + else: + raise ValueError(f'Train data does not support {type(data)}. ') + return data + + def _get_batch_size(self, xs, batch_axis=0): + if isinstance(xs, (bm.JaxArray, jnp.ndarray)): + return xs.shape[batch_axis] + else: + num_batch_sizes = [leaf.shape[batch_axis] for leaf in tree_flatten(xs, is_leaf=_is_jax_array)[0]] + if len(set(num_batch_sizes)) != 1: + raise ValueError(f'Number of batch size is different across tensors in ' + f'the provided "xs". We got {set(num_batch_sizes)}.') + return num_batch_sizes[0] + + def _get_data_by_callable(self, dataset, num_batch): + raise NotImplementedError + + def _get_data_by_tensor(self, dataset, num_batch=None, shuffle=False): + raise NotImplementedError + + def f_train(self, shared_args=None) -> Callable: + raise NotImplementedError + + def f_loss(self, shared_args=None) -> Callable: + raise NotImplementedError + + +class BPTT(BPTrainer): + """ + The trainer implementing back propagation through time (BPTT) + algorithm for recurrent neural networks. + """ + + def f_loss(self, shared_args=None, jit=True) -> Callable: + """Get loss function.""" + if shared_args is None: shared_args = dict() + + shared_args2 = {k: v for k, v in shared_args.items()} + shared_args2['_local_jit_'] = jit + shared_args_str = serialize_kwargs(shared_args2) + if shared_args_str not in self._f_loss_compiled: + + def loss_fun(inputs, targets): + times, indices, inputs, _, _, _, _ = self._format_xs( + None, inputs, inputs_are_batching=True, move_axis=True) + inputs = (times, indices, inputs) + outputs, mon = self._predict(xs=inputs, shared_args=shared_args) + outputs = bm.moveaxis(outputs, 0, 1) + predicts = (outputs, mon) if len(mon) > 0 else outputs + return self._loss_func(predicts, targets) + + self._f_loss_compiled[shared_args_str] = loss_fun + if self.jit[c.LOSS_PHASE] and jit: + dyn_vars = self.target.vars() + dyn_vars.update(self.dyn_vars) + self._f_loss_compiled[shared_args_str] = bm.jit(self._f_loss_compiled[shared_args_str], + dyn_vars=dyn_vars) + return self._f_loss_compiled[shared_args_str] + + def f_grad(self, shared_args=None) -> Callable: + """Get gradient function.""" + shared_args_str = serialize_kwargs(shared_args) + if shared_args_str not in self._f_grad_compiled: + _f_loss_internal = self.f_loss(shared_args, jit=False) + dyn_vars = self.target.vars() + dyn_vars.update(self.dyn_vars) + tran_vars = dyn_vars.subset(bm.TrainVar) + grad_f = bm.grad(_f_loss_internal, + dyn_vars=dyn_vars.unique(), + grad_vars=tran_vars.unique(), + return_value=True, + has_aux=self.loss_has_aux) + self._f_grad_compiled[shared_args_str] = grad_f + return self._f_grad_compiled[shared_args_str] + + def f_train(self, shared_args=None) -> Callable: + """Get training function.""" + if shared_args is None: shared_args = dict() + if not isinstance(shared_args, dict): + raise ValueError(f'Only supports dict for "shared_args". ' + f'But got {type(shared_args)}: {shared_args}') + + shared_args_str = serialize_kwargs(shared_args) + if shared_args_str not in self._f_train_compiled: + + def train_func(inputs, targets): + res = self.f_grad(shared_args)(inputs, targets) + self.optimizer.update(res[0]) + return res[1:] + + if self.jit[c.FIT_PHASE]: + dyn_vars = self.target.vars() + dyn_vars.update(self.dyn_vars) + dyn_vars.update(self.optimizer.vars()) + self._f_train_compiled[shared_args_str] = bm.jit(train_func, dyn_vars=dyn_vars.unique()) + else: + self._f_train_compiled[shared_args_str] = train_func + return self._f_train_compiled[shared_args_str] + + def _get_data_by_callable(self, dataset: Callable, num_batch=None): + for xs, ys in dataset(): + yield xs, ys + + def _get_data_by_tensor(self, dataset, num_batch=None, shuffle=False): + if num_batch is None: + raise ValueError('Must provide "num_batch" when dataset is not a callable function.') + assert isinstance(dataset, (tuple, list)) and len(dataset) == 2 + xs, ys = dataset + num_sample = self._get_batch_size(xs) + if shuffle: + xs, ys = self._shuffle(xs, ys) + for data_idx in range(0, num_sample, num_batch): + if (data_idx + num_batch) > num_sample: + inputs = tree_map(lambda v: v[data_idx:], xs, is_leaf=_is_jax_array) + targets = tree_map(lambda v: v[data_idx:], ys, is_leaf=_is_jax_array) + else: + inputs = tree_map(lambda v: v[data_idx: data_idx + num_batch], xs, is_leaf=_is_jax_array) + targets = tree_map(lambda v: v[data_idx: data_idx + num_batch], ys, is_leaf=_is_jax_array) + yield inputs, targets + + def _shuffle(self, xs, ys): + key = self.rng.split_key() + + if self._f_shuffle is None: + def shuffle(xs, ys, key): + xs = tree_map(lambda x: self.rng.permutation(x, key=key), xs) + ys = tree_map(lambda y: self.rng.permutation(y, key=key), ys) + return xs, ys + + self._f_shuffle = bm.jit(shuffle) + return self._f_shuffle(xs, ys, key) + + +class BPFF(BPTT): + """ + The trainer implementing back propagation algorithm + for feedforward neural networks. + + """ + + def predict( + self, + inputs: Union[Tensor, Sequence[Tensor], Dict[str, Tensor]], + reset_state: bool = True, + shared_args: Dict = None, + eval_time: bool = False + ) -> Output: + """Predict a series of input data with the given target model. + + This function use the JIT compilation to accelerate the model simulation. + Moreover, it can automatically monitor the node variables, states, inputs, + feedbacks and its output. + + Parameters + ---------- + inputs: Tensor, dict + The feedforward input data. It must be a 3-dimensional data + which has the shape of `(num_sample, num_time, num_feature)`. + reset_state: bool + Whether reset the model states. + shared_args: optional, dict + The shared arguments across different layers. + eval_time: bool + Evaluate the time used for running. + + Returns + ------- + output: Tensor, dict + The model output. + """ + # format input data + num_batch = self._get_batch_size(inputs) + # reset the model states + if reset_state: + self.target.reset_state(num_batch) + self.reset_state() + # init monitor + for key in self.mon.var_names: + self.mon[key] = [] # reshape the monitor items + # prediction + outputs, hists = self._predict(xs=inputs, shared_args=shared_args) + # post-running for monitors + for key in hists.keys(): + self.mon[key] = bm.asarray(hists[key]) + if self.numpy_mon_after_run: + self.mon.ts = np.asarray(self.mon.ts) + for key in hists.keys(): + self.mon[key] = np.asarray(self.mon[key]) + return outputs + + def f_loss(self, shared_args=None, jit=True) -> Callable: + """Get loss function.""" + if shared_args is None: shared_args = dict() + + shared_args2 = {k: v for k, v in shared_args.items()} + shared_args2['_local_jit_'] = jit + shared_args_str = serialize_kwargs(shared_args2) + if shared_args_str not in self._f_loss_compiled: + + def loss_fun(inputs, targets): + outputs, mon = self.f_predict(shared_args)(inputs) + outs = (outputs, mon) if len(mon) > 0 else outputs + loss = self._loss_func(outs, targets) + return loss + + if self.jit[c.LOSS_PHASE] and jit: + dyn_vars = self.target.vars() + dyn_vars.update(self.dyn_vars) + self._f_loss_compiled[shared_args_str] = bm.jit(self._f_loss_compiled[shared_args_str], + dyn_vars=dyn_vars) + else: + self._f_loss_compiled[shared_args_str] = loss_fun + return self._f_loss_compiled[shared_args_str] + + def f_predict(self, shared_args: Dict = None, jit: bool = True): + if shared_args is None: shared_args = DotDict() + if not isinstance(shared_args, dict): + raise ValueError(f'"shared_args" must be a dict, ' + f'but got {type(shared_args)}') + + shared_args2 = {k: v for k, v in shared_args.items()} + shared_args2['_local_jit_'] = jit + shared_args_str = serialize_kwargs(shared_args) + if shared_args_str not in self._f_predict_compiled: + + monitor_func = self.build_monitors(self._mon_info[0], self._mon_info[1], shared_args) + + def run_func(xs): + outs = self.target(shared_args, xs) + hist = monitor_func(shared_args) + return outs, hist + + if self.jit[c.PREDICT_PHASE] and jit: + dyn_vars = self.target.vars() + dyn_vars.update(self.dyn_vars) + self._f_predict_compiled[shared_args_str] = bm.jit(run_func, dyn_vars=dyn_vars.unique()) + else: + self._f_predict_compiled[shared_args_str] = run_func + return self._f_predict_compiled[shared_args_str] + + +class OnlineBPTT(BPTT): + + def f_loss(self, shared_args=None, jit=True) -> Callable: + """Get loss function.""" + if shared_args is None: shared_args = dict() + + shared_args2 = {k: v for k, v in shared_args.items()} + shared_args2['_local_jit_'] = jit + shared_args_str = serialize_kwargs(shared_args2) + if shared_args_str not in self._f_loss_compiled: + + def loss_fun(t, i, input_, target_): + outputs, mon = self.f_predict_one_step(shared_args)(t, i, input_) + predicts = (outputs, mon) if len(mon) > 0 else outputs + return self._loss_func(predicts, target_) + + if self.jit[c.LOSS_PHASE] and jit: + dyn_vars = self.target.vars() + dyn_vars.update(self.dyn_vars) + self._f_loss_compiled[shared_args_str] = bm.jit(self._f_loss_compiled[shared_args_str], + dyn_vars=dyn_vars) + else: + self._f_loss_compiled[shared_args_str] = loss_fun + return self._f_loss_compiled[shared_args_str] + + def f_train(self, shared_args=None) -> Callable: + """Get training function.""" + if shared_args is None: shared_args = dict() + if not isinstance(shared_args, dict): + raise ValueError(f'Only supports dict for "shared_args". ' + f'But got {type(shared_args)}: {shared_args}') + shared_args_str = serialize_kwargs(shared_args) + if shared_args_str not in self._f_train_compiled: + + def train_step(x): + # t, i, input_, target_ = x + res = self.f_grad(shared_args)(*x) + self.optimizer.update(res[0]) + return res[1:] + + if self.jit[c.FIT_PHASE]: + dyn_vars = self.target.vars() + dyn_vars.update(self.dyn_vars) + f = bm.make_loop(train_step, dyn_vars=dyn_vars.unique(), has_return=True) + run_func = lambda all_inputs: f(all_inputs)[1] + + else: + def run_func(xs): + times, indices, inputs, targets = xs + losses = [] + for i in range(times.shape[0]): + # data at time i + x = tree_map(lambda x: x[i], inputs, is_leaf=_is_jax_array) + y = tree_map(lambda x: x[i], targets, is_leaf=_is_jax_array) + # step at the i + loss = train_step((times[i], indices[i], x, y)) + # append output and monitor + losses.append(loss) + return bm.asarray(losses) + + def train_fun(inputs, targets): + times, indices, inputs, num_step, _, duration, _ = self._format_xs( + None, inputs, inputs_are_batching=True, move_axis=True) + targets = tree_map(lambda x: bm.moveaxis(x, 0, 1), targets, is_leaf=_is_jax_array) + ls = run_func([times, indices, inputs, targets]) + self.i0 += num_step + self.t0 += duration + return ls + + self._f_train_compiled[shared_args_str] = train_fun + return self._f_train_compiled[shared_args_str] + + def f_predict_one_step(self, shared_args: Dict = None, jit: bool = False): + if shared_args is None: shared_args = DotDict() + if not isinstance(shared_args, dict): + raise ValueError(f'"shared_args" must be a dict, ' + f'but got {type(shared_args)}') + + shared_args2 = {k: v for k, v in shared_args.items()} + shared_args2['_local_jit_'] = jit + shared_args2['_one_step_'] = True + shared_args_str = serialize_kwargs(shared_args) + if shared_args_str not in self._f_predict_compiled: + + monitor_func = self.build_monitors(self._mon_info[0], self._mon_info[1], shared_args) + + def run_func(t, i, x): + shared = DotDict(t=t, i=i, dt=self.dt) + shared.update(shared_args) + outs = self.target(shared, x) + hist = monitor_func(shared) + return outs, hist + + if self.jit[c.FIT_PHASE] and jit: + dyn_vars = self.target.vars() + dyn_vars.update(self.dyn_vars) + self._f_predict_compiled[shared_args_str] = bm.jit(run_func, dyn_vars=dyn_vars.unique()) + else: + self._f_predict_compiled[shared_args_str] = run_func + return self._f_predict_compiled[shared_args_str] diff --git a/brainpy/train/base.py b/brainpy/train/base.py index 70eddff79..b364ff088 100644 --- a/brainpy/train/base.py +++ b/brainpy/train/base.py @@ -1,201 +1,132 @@ # -*- coding: utf-8 -*- -import inspect -from typing import Union, Callable, Optional, Dict +from typing import Dict, Sequence, Any, Union, Tuple +import jax.numpy as jnp + +import brainpy.math as bm +from brainpy.dyn.runners import DSRunner from brainpy.dyn.base import DynamicalSystem -from brainpy.train.algorithms import OfflineAlgorithm, OnlineAlgorithm -from brainpy.types import Tensor +from brainpy.tools.checking import check_dict_data +from brainpy.dyn.training import TrainingSystem + +from brainpy.types import Tensor, Output +from . import constants as c __all__ = [ - 'TrainingSystem', 'Sequential', + 'DSTrainer', 'DSRunner', ] -def not_implemented(fun: Callable) -> Callable: - """Marks the given module method is not implemented. - - Methods wrapped in @not_implemented can define submodules directly within the method. - - For instance:: - - @not_implemented - init_fb(self): - ... - - @not_implemented - def feedback(self): - ... - """ - fun.not_implemented = True - return fun - - -class TrainingSystem(DynamicalSystem): - """Base class for training system in BrainPy. - - """ - - '''Online fitting method.''' - online_fit_by: Optional[OnlineAlgorithm] - - '''Offline fitting method.''' - offline_fit_by: Optional[OfflineAlgorithm] - - def __init__(self, name: str = None, trainable: bool = False): - super(TrainingSystem, self).__init__(name=name) - self._trainable = trainable - self.online_fit_by = None - self.offline_fit_by = None - self.fit_record = dict() - - @property - def trainable(self): - return self._trainable - - @trainable.setter - def trainable(self, value): - self._trainable = value - - def __repr__(self): - return f"{type(self).__name__}(name={self.name}, trainable={self.trainable})" - - def __call__(self, *args, **kwargs) -> Tensor: - """The main computation function of a Node. +class DSTrainer(DSRunner): + """Structural Trainer for Dynamical Systems.""" + + target: Union[DynamicalSystem, TrainingSystem] + train_nodes: Sequence[DynamicalSystem] # need to be initialized by subclass + + def __init__( + self, + target: Union[DynamicalSystem, TrainingSystem], + **kwargs + ): + if not isinstance(target, (DynamicalSystem, TrainingSystem)): + raise TypeError(f'"target" must be an instance of ' + f'{DynamicalSystem.__name__} or {TrainingSystem.__name__}, ' + f'but we got {type(target)}: {target}') + super(DSTrainer, self).__init__(target=target, **kwargs) + + # jit + self.jit[c.PREDICT_PHASE] = self.jit.get(c.PREDICT_PHASE, True) + self.jit[c.FIT_PHASE] = self.jit.get(c.FIT_PHASE, True) + + def predict( + self, + inputs: Union[Tensor, Sequence[Tensor], Dict[str, Tensor]], + reset_state: bool = False, + shared_args: Dict = None, + eval_time: bool = False + ) -> Output: + """Prediction function. + + What's different from `predict()` function in :py:class:`~.DynamicalSystem` is that + the `inputs_are_batching` is default `True`. + + Parameters + ---------- + inputs: Tensor, sequence of Tensor, dict of Tensor + The input values. + reset_state: bool + Reset the target state before running. + shared_args: dict + The shared arguments across nodes. + eval_time: bool + Whether we evaluate the running time or not? Returns ------- - Tensor - A output tensor value, or a dict of output tensors. + output: Tensor, sequence of Tensor, dict of Tensor + The running output. """ - return self.forward(*args, **kwargs) - - @not_implemented - def update(self, t, dt, x, shared_args=None) -> Tensor: - return self.forward(x, shared_args) - - def forward(self, x, shared_args=None) -> Tensor: - raise NotImplementedError('Subclass should implement "forward()" function ' - 'when "update()" function is not customized.') - - def reset(self, batch_size=1): - for node in self.nodes(level=1, include_self=False).unique().subset(TrainingSystem).values(): - node.reset(batch_size=batch_size) - - def reset_state(self, batch_size=1): - for node in self.nodes(level=1, include_self=False).unique().subset(TrainingSystem).values(): - node.reset_state(batch_size=batch_size) - - @not_implemented - def online_init(self): - raise NotImplementedError('Subclass must implement online_init() function when using ' - 'OnlineTrainer.') - - @not_implemented - def offline_init(self): - raise NotImplementedError('Subclass must implement offline_init() function when using ' - 'OfflineTrainer.') - - @not_implemented - def online_fit(self, - target: Tensor, - fit_record: Dict[str, Tensor], - shared_args: Dict = None): - raise NotImplementedError('Subclass must implement online_fit() function when using ' - 'OnlineTrainer.') - - @not_implemented - def offline_fit(self, - target: Tensor, - fit_record: Dict[str, Tensor], - shared_args: Dict = None): - raise NotImplementedError('Subclass must implement offline_fit() function when using ' - 'OfflineTrainer.') - - -class Sequential(TrainingSystem): - def __init__(self, *modules, name: str = None, **kw_modules): - super(Sequential, self).__init__(name=name, trainable=False) - - # add sub-components - for module in modules: - if isinstance(module, TrainingSystem): - self.implicit_nodes[module.name] = module - elif isinstance(module, (list, tuple)): - for m in module: - if not isinstance(m, TrainingSystem): - raise ValueError(f'Should be instance of {TrainingSystem.__name__}. ' - f'But we got {type(m)}') - self.implicit_nodes[m.name] = module - elif isinstance(module, dict): - for k, v in module.items(): - if not isinstance(v, TrainingSystem): - raise ValueError(f'Should be instance of {TrainingSystem.__name__}. ' - f'But we got {type(v)}') - self.implicit_nodes[k] = v + return super(DSTrainer, self).predict(duration=None, + inputs=inputs, + inputs_are_batching=True, + reset_state=reset_state, + shared_args=shared_args, + eval_time=eval_time) + + def fit( + self, + train_data: Any, + reset_state: bool = False, + shared_args: Dict = None + ) -> Output: # need to be implemented by subclass + raise NotImplementedError('Must implement the fit function. ') + + def _get_trainable_nodes(self) -> Tuple[TrainingSystem, ...]: + # check trainable nodes + nodes = self.target.nodes(level=-1, include_self=True).subset(TrainingSystem).unique() + return tuple([node for node in nodes.values() if node.trainable]) + + def _check_ys(self, ys, num_batch, num_step, move_axis=False): + if isinstance(ys, (bm.ndarray, jnp.ndarray)): + if len(self.train_nodes) == 1: + ys = {self.train_nodes[0].name: ys} else: - raise ValueError(f'Cannot parse sub-systems. They should be {TrainingSystem.__name__} ' - f'or a list/tuple/dict of {TrainingSystem.__name__}.') - for k, v in kw_modules.items(): - if not isinstance(v, TrainingSystem): - raise ValueError(f'Should be instance of {TrainingSystem.__name__}. ' - f'But we got {type(v)}') - self.implicit_nodes[k] = v - - def __getattr__(self, item): - """Wrap the dot access ('self.'). """ - child_ds = super(Sequential, self).__getattribute__('implicit_nodes') - if item in child_ds: - return child_ds[item] - else: - return super(Sequential, self).__getattribute__(item) - - def __getitem__(self, key: Union[int, slice]): - if isinstance(key, str): - if key not in self.implicit_nodes: - raise KeyError(f'Does not find a component named {key} in\n {str(self)}') - return self.implicit_nodes[key] - elif isinstance(key, slice): - keys = tuple(self.implicit_nodes.keys())[key] - components = tuple(self.implicit_nodes.values())[key] - return Sequential(dict(zip(keys, components))) - elif isinstance(key, int): - return self.implicit_nodes.values()[key] - elif isinstance(key, (tuple, list)): - for i in key: - if isinstance(i, int): - raise KeyError(f'We excepted a tuple/list of int, but we got {type(i)}') - keys = tuple(self.implicit_nodes.keys())[key] - components = tuple(self.implicit_nodes.values())[key] - return Sequential(dict(zip(keys, components))) - else: - raise KeyError(f'Unknown type of key: {type(key)}') - - def __repr__(self): - def f(x): - if not isinstance(x, TrainingSystem) and callable(x): - signature = inspect.signature(x) - args = [f'{k}={v.default}' for k, v in signature.parameters.items() - if v.default is not inspect.Parameter.empty] - args = ', '.join(args) - while not hasattr(x, '__name__'): - if not hasattr(x, 'func'): - break - x = x.func # Handle functools.partial - if not hasattr(x, '__name__') and hasattr(x, '__class__'): - return x.__class__.__name__ - if args: - return f'{x.__name__}(*, {args})' - return x.__name__ + raise ValueError(f'The network\n {self.target} \nhas {len(self.train_nodes)} ' + f'training nodes, while we only got one target data.') + check_dict_data(ys, key_type=str, val_type=(bm.ndarray, jnp.ndarray)) + + # check data path + abs_node_names = [node.name for node in self.train_nodes] + formatted_ys = {} + ys_not_included = {} + for k, v in ys.items(): + if k in abs_node_names: + formatted_ys[k] = v else: - x = repr(x).split('\n') - x = [x[0]] + [' ' + y for y in x[1:]] - return '\n'.join(x) - - entries = '\n'.join(f' [{i}] {f(x)}' for i, x in enumerate(self)) - return f'{self.__class__.__name__}(\n{entries}\n)' - - def forward(self, x, shared_args=None) -> Tensor: - for node in self.implicit_nodes.values(): - x = node.forward(x, shared_args=shared_args) - return x + ys_not_included[k] = v + if len(ys_not_included): + rel_nodes = self.target.nodes('relative', level=-1, include_self=True).subset(DynamicalSystem).unique() + for k, v in ys_not_included.items(): + if k in rel_nodes: + formatted_ys[rel_nodes[k].name] = v + else: + raise ValueError(f'Unknown target "{k}" for fitting.') + + # check data shape + for key, val in formatted_ys.items(): + if val.ndim < 3: + raise ValueError("Targets must be a tensor with shape of " + "(num_sample, num_time, feature_dim, ...), " + f"but we got {val.shape}") + if val.shape[0] != num_batch: + raise ValueError(f'Batch size of the target {key} does not match ' + f'with the input data {val.shape[0]} != {num_batch}') + if val.shape[1] != num_step: + raise ValueError(f'The time step of the target {key} does not match ' + f'with the input data {val.shape[1]} != {num_step})') + + if move_axis: + # change shape to (num_time, num_sample, num_feature) + formatted_ys = {k: bm.moveaxis(v, 0, 1) for k, v in formatted_ys.items()} + return formatted_ys diff --git a/brainpy/train/constants.py b/brainpy/train/constants.py new file mode 100644 index 000000000..6c26c36ad --- /dev/null +++ b/brainpy/train/constants.py @@ -0,0 +1,9 @@ +# -*- coding: utf-8 -*- + +TRAIN_PHASE = 'fit' +FIT_PHASE = 'fit' +PREDICT_PHASE = 'predict' +RUN_PHASE = 'predict' +LOSS_PHASE = 'loss' + + diff --git a/brainpy/train/runners/offline_trainer.py b/brainpy/train/offline_trainer.py similarity index 72% rename from brainpy/train/runners/offline_trainer.py rename to brainpy/train/offline_trainer.py index d46e67941..97ea8e1ad 100644 --- a/brainpy/train/runners/offline_trainer.py +++ b/brainpy/train/offline_trainer.py @@ -7,13 +7,13 @@ from jax.experimental.host_callback import id_tap import brainpy.math as bm +from brainpy.algorithms.offline import get, RidgeRegression, OfflineAlgorithm from brainpy.base import Base +from brainpy.dyn.base import DynamicalSystem from brainpy.errors import NoImplementationError -from brainpy.train.algorithms.offline import get, RidgeRegression, OfflineAlgorithm -from brainpy.train.base import TrainingSystem -from brainpy.train.utils import serialize_kwargs -from brainpy.types import Tensor -from .base_runner import DSTrainer +from brainpy.tools.checking import serialize_kwargs +from brainpy.types import Tensor, Output +from .base import DSTrainer __all__ = [ 'OfflineTrainer', @@ -26,7 +26,7 @@ class OfflineTrainer(DSTrainer): Parameters ---------- - target: Node + target: DynamicalSystem The target model to train. fit_method: OfflineAlgorithm, Callable, dict, str The fitting method applied to the target model. @@ -47,7 +47,7 @@ class OfflineTrainer(DSTrainer): def __init__( self, - target: TrainingSystem, + target: DynamicalSystem, fit_method: Union[OfflineAlgorithm, Callable, Dict, str] = None, **kwargs ): @@ -94,12 +94,48 @@ def __repr__(self): return (f'{name}(target={self.target}, \n\t' f'{prefix}fit_method={self.fit_method})') + def predict( + self, + inputs: Union[Tensor, Sequence[Tensor], Dict[str, Tensor]], + reset_state: bool = False, + shared_args: Dict = None, + eval_time: bool = False + ) -> Output: + """Prediction function. + + What's different from `predict()` function in :py:class:`~.DynamicalSystem` is that + the `inputs_are_batching` is default `True`. + + Parameters + ---------- + inputs: Tensor, sequence of Tensor, dict of Tensor + The input values. + reset_state: bool + Reset the target state before running. + shared_args: dict + The shared arguments across nodes. + eval_time: bool + Whether we evaluate the running time or not? + + Returns + ------- + output: Tensor, sequence of Tensor, dict of Tensor + The running output. + """ + outs = super(OfflineTrainer, self).predict(inputs=inputs, + reset_state=reset_state, + shared_args=shared_args, + eval_time=eval_time) + for node in self.train_nodes: + node.fit_record.clear() + return outs + def fit( self, train_data: Sequence, reset_state: bool = False, - shared_kwargs: Dict = None, - ): + shared_args: Dict = None, + ) -> Output: """Fit the target model according to the given training and testing data. Parameters @@ -117,9 +153,12 @@ def fit( then the fitting happens on the whole data series. reset_state: bool Whether reset the initial states of the target model. - shared_kwargs: dict + shared_args: dict The shared keyword arguments for the target models. """ + if shared_args is None: shared_args = dict() + shared_args['fit'] = shared_args.get('fit', True) + # checking training and testing data if not isinstance(train_data, (list, tuple)): raise ValueError(f"{self.__class__.__name__} only support " @@ -132,7 +171,7 @@ def fit( xs, ys = train_data # prediction, get all needed data - _ = self.predict(xs=xs, reset_state=reset_state) + outs = self.predict(inputs=xs, reset_state=reset_state, shared_args=shared_args) # get all input data xs, num_step, num_batch = self._check_xs(xs, move_axis=False) @@ -150,7 +189,7 @@ def fit( for node in self.train_nodes: key = f'{node.name}-fit_record' monitor_data[key] = self.mon.get(key) - self.f_train(shared_kwargs)(monitor_data, ys) + self.f_train(shared_args)(monitor_data, ys) del monitor_data # close the progress bar @@ -160,26 +199,29 @@ def fit( # final things for node in self.train_nodes: self.mon.pop(f'{node.name}-fit_record') + node.fit_record.clear() # clear fit records if self.true_numpy_mon_after_run: for key in self.mon.keys(): if key != 'var_names': self.mon[key] = np.asarray(self.mon[key]) - def f_train(self, shared_kwargs: Dict = None) -> Callable: + return outs + + def f_train(self, shared_args: Dict = None) -> Callable: """Get training function.""" - shared_kwargs_str = serialize_kwargs(shared_kwargs) + shared_kwargs_str = serialize_kwargs(shared_args) if shared_kwargs_str not in self._f_train: - self._f_train[shared_kwargs_str] = self._make_fit_func(shared_kwargs) + self._f_train[shared_kwargs_str] = self._make_fit_func(shared_args) return self._f_train[shared_kwargs_str] - def _make_fit_func(self, shared_kwargs): - shared_kwargs = dict() if shared_kwargs is None else shared_kwargs + def _make_fit_func(self, shared_args): + shared_args = dict() if shared_args is None else shared_args def train_func(monitor_data: Dict[str, Tensor], target_data: Dict[str, Tensor]): for node in self.train_nodes: fit_record = monitor_data[f'{node.name}-fit_record'] targets = target_data[node.name] - node.offline_fit(targets, fit_record, shared_kwargs) + node.offline_fit(targets, fit_record) if self.progress_bar: id_tap(lambda *args: self._pbar.update(), ()) @@ -189,30 +231,35 @@ def train_func(monitor_data: Dict[str, Tensor], target_data: Dict[str, Tensor]): train_func = bm.jit(train_func, dyn_vars=dyn_vars.unique()) return train_func - def build_monitors(self, return_without_idx, return_with_idx, flatten=False): - def func(_t, _dt): + def build_monitors(self, return_without_idx, return_with_idx, shared_args: dict): + if shared_args.get('fit', False): + def func(tdi): + res = {k: v.value for k, v in return_without_idx.items()} + res.update({k: v[idx] for k, (v, idx) in return_with_idx.items()}) + res.update({k: f(tdi) for k, f in self.fun_monitors.items()}) + res.update({f'{node.name}-fit_record': node.fit_record for node in self.train_nodes}) + return res + else: + def func(tdi): res = {k: v.value for k, v in return_without_idx.items()} res.update({k: v[idx] for k, (v, idx) in return_with_idx.items()}) - res.update({k: f(_t, _dt) for k, f in self.fun_monitors.items()}) - res.update({f'{node.name}-fit_record': {k: node.fit_record.pop(k) - for k in node.fit_record.keys()} - for node in self.train_nodes}) + res.update({k: f(tdi) for k, f in self.fun_monitors.items()}) return res return func def _check_interface(self): for node in self.train_nodes: - if hasattr(node.offline_fit, 'not_implemented'): - if node.offline_fit.not_implemented: + if hasattr(node.offline_fit, 'not_customized'): + if node.offline_fit.not_customized: raise NoImplementationError( f'The node \n\n{node}\n\n' f'is set to be trainable with {self.__class__.__name__} method. ' f'However, it does not implement the required training ' f'interface "offline_fit()" function. ' ) - if hasattr(node.offline_init, 'not_implemented'): - if node.offline_init.not_implemented: + if hasattr(node.offline_init, 'not_customized'): + if node.offline_init.not_customized: raise NoImplementationError( f'The node \n\n{node}\n\n' f'is set to be trainable with {self.__class__.__name__} method. ' @@ -227,7 +274,7 @@ class RidgeTrainer(OfflineTrainer): Parameters ---------- - target: TrainingSystem + target: TrainingSystem, DynamicalSystem The target model. beta: float The regularization coefficient. @@ -235,7 +282,7 @@ class RidgeTrainer(OfflineTrainer): Other common parameters for :py:class:`brainpy.nn.RNNTrainer``. """ - def __init__(self, target, beta=1e-7, **kwargs): + def __init__(self, target, alpha=1e-7, **kwargs): super(RidgeTrainer, self).__init__(target=target, - fit_method=dict(name='ridge', beta=beta), + fit_method=dict(name='ridge', alpha=alpha), **kwargs) diff --git a/brainpy/train/runners/online_trainer.py b/brainpy/train/online_trainer.py similarity index 63% rename from brainpy/train/runners/online_trainer.py rename to brainpy/train/online_trainer.py index 9f044c195..730fed1be 100644 --- a/brainpy/train/runners/online_trainer.py +++ b/brainpy/train/online_trainer.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -from typing import Dict, Sequence, Union, Callable +from typing import Dict, Sequence, Union, Callable, Tuple import numpy as np import tqdm.auto @@ -8,13 +8,15 @@ from jax.tree_util import tree_map import brainpy.math as bm +from brainpy.algorithms.online import get, OnlineAlgorithm, RLS from brainpy.base import Base +from brainpy.dyn.base import DynamicalSystem +from brainpy.dyn.training import TrainingSystem from brainpy.errors import NoImplementationError -from brainpy.train.algorithms.online import get, OnlineAlgorithm, RLS -from brainpy.train.base import TrainingSystem -from brainpy.train.utils import (serialize_kwargs, check_data_batch_size) -from brainpy.types import Tensor -from .base_runner import DSTrainer +from brainpy.tools.checking import serialize_kwargs +from brainpy.tools.others.dicts import DotDict +from brainpy.types import Tensor, Output +from .base import DSTrainer __all__ = [ 'OnlineTrainer', @@ -27,7 +29,7 @@ class OnlineTrainer(DSTrainer): Parameters ---------- - target: Node + target: DynamicalSystem, TrainingSystem The target model to train. fit_method: OnlineAlgorithm, Callable, dict, str The fitting method applied to the target model. @@ -47,7 +49,7 @@ class OnlineTrainer(DSTrainer): def __init__( self, - target: TrainingSystem, + target: DynamicalSystem, fit_method: Union[OnlineAlgorithm, Callable, Dict, str] = None, **kwargs ): @@ -78,6 +80,7 @@ def __init__( # initialize the fitting method for node in self.train_nodes: + assert isinstance(node, TrainingSystem) node.online_init() # update dynamical variables @@ -94,12 +97,51 @@ def __repr__(self): f'{prefix}jit={self.jit}, \n\t' f'{prefix}fit_method={self.fit_method})') + def predict( + self, + inputs: Union[Tensor, Sequence[Tensor], Dict[str, Tensor]], + reset_state: bool = False, + shared_args: Dict = None, + eval_time: bool = False + ) -> Output: + """Prediction function. + + What's different from `predict()` function in :py:class:`~.DynamicalSystem` is that + the `inputs_are_batching` is default `True`. + + Parameters + ---------- + inputs: Tensor, sequence of Tensor, dict of Tensor + The input values. + reset_state: bool + Reset the target state before running. + shared_args: dict + The shared arguments across nodes. + eval_time: bool + Whether we evaluate the running time or not? + + Returns + ------- + output: Tensor, sequence of Tensor, dict of Tensor + The running output. + """ + outs = super(OnlineTrainer, self).predict(inputs=inputs, + reset_state=reset_state, + shared_args=shared_args, + eval_time=eval_time) + for node in self.train_nodes: + node.fit_record.clear() + return outs + def fit( self, train_data: Sequence, reset_state: bool = False, shared_args: Dict = None, - ): + ) -> Output: + if shared_args is None: shared_args = dict() + shared_args['fit'] = shared_args.get('fit', True) + # checking training and testing data if not isinstance(train_data, (list, tuple)): raise ValueError(f"{self.__class__.__name__} only support " @@ -112,7 +154,8 @@ def fit( xs, ys = train_data # format input data - xs, num_step, num_batch = self._check_xs(xs, move_axis=True) + times, indices, xs, num_step, num_batch, duration, _ = self._format_xs( + None, inputs=xs, inputs_are_batching=True) # format target data ys = self._check_ys(ys, num_batch=num_batch, num_step=num_step, move_axis=True) @@ -120,6 +163,7 @@ def fit( # reset the model states if reset_state: self.target.reset_state(num_batch) + self.reset_state() # init monitor for key in self.mon.var_names: @@ -131,33 +175,35 @@ def fit( self._pbar.set_description(f"Train {num_step} steps: ", refresh=True) # prediction - hists = self._fit(xs=xs, ys=ys, shared_args=shared_args) + outs, hists = self._fit(xs=(times, indices, xs), ys=ys, shared_args=shared_args) # close the progress bar if self.progress_bar: self._pbar.close() # post-running for monitors + hists['ts'] = times + self.dt + if self.numpy_mon_after_run: + hists = tree_map(lambda a: np.asarray(a), hists, is_leaf=lambda a: isinstance(a, bm.JaxArray)) for key in hists.keys(): self.mon[key] = hists[key] - if self.numpy_mon_after_run: - self.mon.ts = np.asarray(self.mon.ts) - for key in hists.keys(): - self.mon[key] = np.asarray(self.mon[key]) + self.i0 += times.shape[0] + self.t0 += duration + return outs def _fit( self, - xs: Dict[str, Tensor], - ys: Dict[str, Tensor], + xs: Tuple, + ys: Union[Tensor, Sequence[Tensor], Dict[str, Tensor]], shared_args: Dict = None, ): """Predict the output according to the inputs. Parameters ---------- - xs: dict + xs: tuple Each tensor should have the shape of `(num_time, num_batch, num_feature)`. - ys: dict + ys: Tensor, sequence of Tensor, dict of Tensor Each tensor should have the shape of `(num_time, num_batch, num_feature)`. shared_args: optional, dict The shared keyword arguments. @@ -167,40 +213,45 @@ def _fit( outputs, hists A tuple of pair of (outputs, hists). """ - _predict_func = self._get_fit_func(shared_args) - hists = _predict_func([xs, ys]) + _fit_func = self._get_fit_func(shared_args) + hists = _fit_func(xs + (ys, )) hists = tree_map(lambda x: bm.moveaxis(x, 0, 1), hists, is_leaf=lambda x: isinstance(x, bm.JaxArray)) return hists - def _get_fit_func(self, shared_kwargs: Dict = None): - if shared_kwargs is None: shared_kwargs = dict() - shared_kwargs_str = serialize_kwargs(shared_kwargs) + def _get_fit_func(self, shared_args: Dict = None): + if shared_args is None: shared_args = dict() + shared_kwargs_str = serialize_kwargs(shared_args) if shared_kwargs_str not in self._f_train: - self._f_train[shared_kwargs_str] = self._make_fit_func(shared_kwargs) + self._f_train[shared_kwargs_str] = self._make_fit_func(shared_args) return self._f_train[shared_kwargs_str] def _make_fit_func(self, shared_args: Dict): if not isinstance(shared_args, dict): raise ValueError(f'"shared_kwargs" must be a dict, but got {type(shared_args)}') + monitor_func = self.build_monitors(self._mon_info[0], self._mon_info[1], shared_args) + def _step_func(all_inputs): - xs, ys = all_inputs - t = 0. - self._input_step(t, self.dt) - if xs is None: - args = (t, self.dt) - else: - args = (t, self.dt, xs) - kwargs = dict() - if len(shared_args): - kwargs['shared_args'] = shared_args - out = self.target.update(*args, **kwargs) - monitors = self._monitor_step(t, self.dt) + t, i, x, ys = all_inputs + shared = DotDict(t=t, dt=self.dt, i=i) + + # input step + self._input_step(shared) + + # update step + shared.update(shared_args) + args = (shared, ) if x is None else (shared, x) + out = self.target(*args) + + # monitor step + monitors = monitor_func(shared) for node in self.train_nodes: fit_record = monitors.pop(f'{node.name}-fit_record') target = ys[node.name] - node.online_fit(target, fit_record, shared_args) + node.online_fit(target, fit_record) + + # finally if self.progress_bar: id_tap(lambda *arg: self._pbar.update(), ()) return out, monitors @@ -213,15 +264,13 @@ def _step_func(all_inputs): else: def run_func(all_inputs): - xs, ys = all_inputs + times, indices, xs, ys = all_inputs outputs = [] - monitors = {key: [] for key in - set(self.mon.item_contents.keys()) | - set(self.fun_monitors.keys())} - for i in range(check_data_batch_size(xs)): + monitors = {key: [] for key in (set(self.mon.var_names) | set(self.fun_monitors.keys()))} + for i in range(times.shape[0]): x = tree_map(lambda x: x[i], xs) y = tree_map(lambda x: x[i], ys) - output, mon = _step_func((x, y)) + output, mon = _step_func((times[i], indices[i], x, y)) outputs.append(output) for key, value in mon.items(): monitors[key].append(value) @@ -236,16 +285,16 @@ def run_func(all_inputs): def _check_interface(self): for node in self.train_nodes: - if hasattr(node.online_fit, 'not_implemented'): - if node.online_fit.not_implemented: + if hasattr(node.online_fit, 'not_customized'): + if node.online_fit.not_customized: raise NoImplementationError( f'The node \n\n{node}\n\n' f'is set to be trainable with {self.__class__.__name__} method. ' f'However, it does not implement the required training ' f'interface "online_fit()" function. ' ) - if hasattr(node.online_init, 'not_implemented'): - if node.online_init.not_implemented: + if hasattr(node.online_init, 'not_customized'): + if node.online_init.not_customized: raise NoImplementationError( f'The node \n\n{node}\n\n' f'is set to be trainable with {self.__class__.__name__} method. ' @@ -253,15 +302,20 @@ def _check_interface(self): f'interface "online_init()" function. ' ) - def build_monitors(self, return_without_idx, return_with_idx, flatten=False): - def func(t, dt): - res = {k: v.value for k, v in return_without_idx.items()} - res.update({k: v[idx] for k, (v, idx) in return_with_idx.items()}) - res.update({k: f(t, dt) for k, f in self.fun_monitors.items()}) - res.update({f'{node.name}-fit_record': {k: node.fit_record.pop(k) - for k in node.fit_record.keys()} - for node in self.train_nodes}) - return res + def build_monitors(self, return_without_idx, return_with_idx, shared_args: dict): + if shared_args.get('fit', False): + def func(tdi): + res = {k: v.value for k, v in return_without_idx.items()} + res.update({k: v[idx] for k, (v, idx) in return_with_idx.items()}) + res.update({k: f(tdi) for k, f in self.fun_monitors.items()}) + res.update({f'{node.name}-fit_record': node.fit_record for node in self.train_nodes}) + return res + else: + def func(tdi): + res = {k: v.value for k, v in return_without_idx.items()} + res.update({k: v[idx] for k, (v, idx) in return_with_idx.items()}) + res.update({k: f(tdi) for k, f in self.fun_monitors.items()}) + return res return func diff --git a/brainpy/train/runners/__init__.py b/brainpy/train/runners/__init__.py deleted file mode 100644 index 5472736aa..000000000 --- a/brainpy/train/runners/__init__.py +++ /dev/null @@ -1,29 +0,0 @@ -# -*- coding: utf-8 -*- - - -""" -This module provides various running and training algorithms -for various neural networks. - -The supported training algorithms include - -- offline training methods, like ridge regression, linear regression, etc. -- online training methods, like recursive least squares (RLS, or Force Learning), - least mean squares (LMS), etc. -- back-propagation learning method -- and others - -The supported neural networks include - -- reservoir computing networks, -- artificial recurrent neural networks, -- spiking neural networks, -- and others. -""" - - -from .base_runner import * -from .online_trainer import * -from .offline_trainer import * -from .back_propagation import * - diff --git a/brainpy/train/runners/back_propagation.py b/brainpy/train/runners/back_propagation.py deleted file mode 100644 index 85c1263ce..000000000 --- a/brainpy/train/runners/back_propagation.py +++ /dev/null @@ -1,485 +0,0 @@ -# -*- coding: utf-8 -*- - -import time -from typing import Union, Dict, Callable, Sequence - -import jax.numpy as jnp -import numpy as np -from jax import jit, random as jr -from jax.tree_util import tree_map - -import brainpy.losses as losses -import brainpy.math as bm -import brainpy.optimizers as optim -from brainpy.errors import UnsupportedError -from brainpy.tools.checking import check_float -from brainpy.train.base import TrainingSystem -from brainpy.train.utils import check_data_batch_size, serialize_kwargs -from brainpy.types import Tensor -from .base_runner import DSTrainer - -__all__ = [ - 'BPTT', - 'BPFF', -] - - -class BPTT(DSTrainer): - """ - The trainer implementing back propagation through time (BPTT) - algorithm for recurrent neural networks. - - """ - - def __init__( - self, - target: TrainingSystem, - - # arguments for BPTT trainer - loss: Union[str, Callable], # loss function - optimizer: optim.Optimizer = None, # optimizer - max_grad_norm: float = None, - shuffle_data: bool = True, - - # common arguments for RNNTrainer - **kwargs - ): - super(BPTT, self).__init__(target=target, **kwargs) - - # jit settings - self.jit['predict'] = self.jit.get('predict', True) - self.jit['loss'] = self.jit.get('loss', True) - self.jit['fit'] = self.jit.get('fit', True) - - # optimizer - if optimizer is None: - lr = optim.ExponentialDecay(lr=0.025, decay_steps=1, decay_rate=0.99975) - optimizer = optim.Adam(lr=lr) - self.optimizer = optimizer - - # loss - if isinstance(loss, str): - loss = getattr(losses, loss) - elif callable(loss): - loss = loss - else: - raise UnsupportedError(f'Do not support {type(loss)} to specify the loss function. ' - f'We only support str and callable function.') - self.loss_fun = loss - self._train_losses = None - self._test_losses = None - self._f_shuffle = None - - # functions - self._f_loss = dict() - self._f_train = dict() - self._f_grad = dict() - - # training parameters - self.max_grad_norm = max_grad_norm # gradient clipping - self.shuffle_data = shuffle_data - - # initialize the optimizer - self.optimizer.register_vars(self.target.vars().subset(bm.TrainVar).unique()) - - - def __repr__(self): - name = self.__class__.__name__ - prefix = ' ' * len(name) - return (f'{name}(target={self.target}, \n\t' - f'{prefix}jit={self.jit}, \n\t' - f'{prefix}loss={self.loss_fun}, \n\t' - f'{prefix}optimizer={self.optimizer})') - - def predict( - self, - xs: Union[Tensor, Dict[str, Tensor]], - reset_state: bool = True, - shared_args: Dict = None, - **kwargs - ): - """Predict a series of input data with the given target model. - - This function use the JIT compilation to accelerate the model simulation. - Moreover, it can automatically monitor the node variables, states, inputs, - feedbacks and its output, if users want. - - Parameters - ---------- - xs: Tensor, dict - The feedforward input data. It must be a 3-dimensional data - which has the shape of `(num_sample, num_time, num_feature)`. - shared_args: dict - Shared keyword arguments for the given target model. - reset_state: bool - Whether reset the model states. Default True. - - Returns - ------- - output: Tensor, dict - The model output. - """ - # check forced states/feedbacks - return super(BPTT, self).predict(xs=xs, reset_state=reset_state, shared_args=shared_args, **kwargs) - - def fit( - self, - train_data: Union[Callable, Sequence], - test_data: Union[Callable, Sequence] = None, - num_batch: int = None, - num_epoch: int = 100, - num_report: int = 100, - reset_state: bool = True, - shared_args: Dict = None, - ): - """ - Fit the target model according to the given training and testing data. - - Parameters - ---------- - train_data: callable, sequence of data - It can be a callable function, or a tuple/list representing `(X, Y)` data. - - Callable. This function should return a pair of `(X, Y)` data - - Sequence. It should be a pair of `(X, Y)` train set. - - ``X``: should be a tensor or a dict of tensors with the shape of - `(num_sample, num_time, num_feature)`, where `num_sample` is - the number of samples, `num_time` is the number of the time step, - and `num_feature` is the number of features. - - ``Y``: Target values. A tensor or a dict of tensors. - - If the shape of each tensor is `(num_sample, num_feature)`, - then we will only fit the model with the only last output. - - If the shape of each tensor is `(num_sample, num_time, num_feature)`, - then the fitting happens on the whole data series. - test_data: callable, sequence of data - Same as the ``train_data``. It can be a callable function, - or a tuple/list representing `(X, Y)` data. - num_batch: int - The batch size. Default 32. This setting is used when users provide - the ``train_data`` and ``test_data`` as a pair of `(X, Y)` data, rather - than a function. - num_epoch: int - The number of training epoch. Default 100. - num_report: int - The number of step to report the progress. Default 100 training steps. - reset_state: bool - Whether reset the initial states of the target model. - shared_args: dict - The shared keyword arguments for the target models. - - """ - true_progress_bar = self.progress_bar - self.progress_bar = False - # training the model - all_train_losses = [] - all_test_losses = [] - train_i = 0 - t0 = time.time() - for _ in range(num_epoch): - train_data_ = self._get_train_data(train_data, num_batch) - - # training set - for x, y in train_data_: - if reset_state: - self.target.reset_state(check_data_batch_size(x)) - loss = self.f_train(shared_args)(x, y) - all_train_losses.append(loss) - train_i += 1 - if train_i % num_report == 0: - t1 = time.time() - print(f'Train {train_i} steps, use {t1 - t0:.4f} s, train loss {round(float(loss), 5)}') - t0 = t1 - - # testing set - test_data_ = self._get_test_data(test_data, num_batch) - if test_data_ is not None: - for x, y in test_data_: - if reset_state: - self.target.reset_state(check_data_batch_size(x)) - loss = self.f_loss(shared_args)(x, y) - all_test_losses.append(loss) - - self._train_losses = bm.asarray(all_train_losses) - self._test_losses = bm.asarray(all_test_losses) - self.progress_bar = true_progress_bar - - def f_grad(self, shared_kwargs=None) -> Callable: - """Get gradient function.""" - shared_kwargs_str = serialize_kwargs(shared_kwargs) - if shared_kwargs_str not in self._f_grad: - self._f_grad[shared_kwargs_str] = self._make_f_grad(shared_kwargs) - return self._f_grad[shared_kwargs_str] - - def f_loss(self, shared_kwargs=None) -> Callable: - """Get loss function.""" - shared_kwargs_str = serialize_kwargs(shared_kwargs) - if shared_kwargs_str not in self._f_loss: - self._f_loss[shared_kwargs_str] = self._make_f_loss(shared_kwargs) - if self.jit['loss']: - dyn_vars = self.target.vars() - dyn_vars.update(self.dyn_vars) - self._f_loss[shared_kwargs_str] = bm.jit(self._f_loss[shared_kwargs_str], - dyn_vars=dyn_vars) - return self._f_loss[shared_kwargs_str] - - def f_train(self, shared_kwargs=None) -> Callable: - """Get training function.""" - shared_kwargs_str = serialize_kwargs(shared_kwargs) - if shared_kwargs_str not in self._f_train: - self._f_train[shared_kwargs_str] = self._make_f_train(shared_kwargs) - return self._f_train[shared_kwargs_str] - - @property - def train_losses(self): - """Training loss.""" - return self._train_losses - - def _make_f_loss(self, shared_kwargs: Dict = None): - if shared_kwargs is None: shared_kwargs = dict() - if not isinstance(shared_kwargs, dict): - raise ValueError(f'Only supports dict for "shared_kwargs". ' - f'But got {type(shared_kwargs)}: {shared_kwargs}') - - def loss_fun(inputs, targets): - inputs, num_step, num_batch = self._check_xs(inputs, move_axis=True) - times = jnp.linspace(0., self.dt * (num_step - 1), num_step) - inputs = (times, inputs) - outputs, _ = self._predict(xs=inputs, shared_args=shared_kwargs) - loss = self.loss_fun(bm.moveaxis(outputs, 0, 1), targets) - return loss - - return loss_fun - - def _make_f_grad(self, shared_kwargs: Dict = None): - _f_loss_internal = self._make_f_loss(shared_kwargs) - dyn_vars = self.target.vars() - dyn_vars.update(self.dyn_vars) - tran_vars = dyn_vars.subset(bm.TrainVar) - return bm.grad(_f_loss_internal, - dyn_vars=dyn_vars.unique(), - grad_vars=tran_vars.unique(), - return_value=True) - - def _make_f_train(self, shared_kwargs: Dict = None): - if shared_kwargs is None: - shared_kwargs = dict() - elif not isinstance(shared_kwargs, dict): - raise ValueError(f'Only supports dict for "shared_kwargs". ' - f'But got {type(shared_kwargs)}: {shared_kwargs}') - - def train_func(inputs, targets): - grads, loss = self.f_grad(shared_kwargs)(inputs, targets) - if self.max_grad_norm is not None: - check_float(self.max_grad_norm, 'max_grad_norm', min_bound=0.) - grads = bm.clip_by_norm(grads, self.max_grad_norm) - self.optimizer.update(grads) - return loss - - if self.jit['fit']: - dyn_vars = self.target.vars() - dyn_vars.update(self.dyn_vars) - dyn_vars.update(self.optimizer.vars()) - train_func = bm.jit(train_func, dyn_vars=dyn_vars.unique()) - return train_func - - def _get_train_data(self, train_data, num_batch): - # training dataset - if callable(train_data): - train_data = self._get_data_by_method1(train_data, num_batch) - elif isinstance(train_data, (tuple, list)): - if len(train_data) != 2: - raise ValueError(f"Must be (X, Y) pair, but got a sequence with " - f"length {len(train_data)}") - train_data = self._get_data_by_method2(train_data, - num_batch=num_batch, - shuffle=self.shuffle_data) - else: - raise ValueError(f'Train data does not support {type(train_data)}. ') - return train_data - - def _get_test_data(self, test_data, num_batch): - # testing dataset - if test_data is None: - test_data = None - elif callable(test_data): - test_data = self._get_data_by_method1(test_data, num_batch) - elif isinstance(test_data, (tuple, list)): - assert len(test_data) == 2, f"Must be (X, Y) pair, but got a sequence with length {len(test_data)}" - test_data = self._get_data_by_method2(test_data, - num_batch=num_batch, - shuffle=False) - else: - raise ValueError(f'Test data does not support {type(test_data)}. ') - return test_data - - def _get_data_by_method1(self, dataset, num_batch): - for xs, ys in dataset(): - yield xs, ys - - def _get_data_by_method2(self, dataset, num_batch=None, shuffle=False): - if num_batch is None: - raise ValueError('Must provide "num_batch" when dataset is not a callable function.') - assert isinstance(dataset, (tuple, list)) and len(dataset) == 2 - xs, ys = dataset - num_sample = self._get_xs_batch_size(xs) - if shuffle: - xs, ys = self._shuffle(xs, ys) - for data_idx in range(0, num_sample, num_batch): - if (data_idx + num_batch) > num_sample: - inputs = {k: v[data_idx:] for k, v in xs.items()} - targets = {k: v[data_idx:] for k, v in ys.items()} - else: - inputs = {k: v[data_idx: data_idx + num_batch] for k, v in xs.items()} - targets = {k: v[data_idx: data_idx + num_batch] for k, v in ys.items()} - yield inputs, targets - - def _shuffle(self, xs, ys): - key = jr.PRNGKey(seed=np.random.randint(0, 100000)) - if self._f_shuffle is None: - def shuffle(xs, ys, key): - xs = tree_map(lambda x: jr.permutation(key, x, axis=0), xs) - ys = tree_map(lambda y: jr.permutation(key, y, axis=0), ys) - return xs, ys - - self._f_shuffle = jit(shuffle) - return self._f_shuffle(xs, ys, key) - - def _get_xs_batch_size(self, xs): - num_batch_sizes = [] - for key, val in xs.items(): - num_batch_sizes.append(val.shape[0]) - if len(set(num_batch_sizes)) != 1: - raise ValueError(f'Number of batch size is different across tensors in ' - f'the provided "xs". We got {set(num_batch_sizes)}.') - return num_batch_sizes[0] - - -class BPFF(BPTT): - """ - The trainer implementing back propagation algorithm - for feedforward neural networks. - - """ - - def __init__( - self, - target: TrainingSystem, - **kwargs - ): - super(BPFF, self).__init__(target=target, **kwargs) - - def predict( - self, - xs: Union[Tensor, Dict[str, Tensor]], - reset_state: bool = True, - shared_args: Dict = None, - **kwargs - ): - """Predict a series of input data with the given target model. - - This function use the JIT compilation to accelerate the model simulation. - Moreover, it can automatically monitor the node variables, states, inputs, - feedbacks and its output. - - Parameters - ---------- - xs: Tensor, dict - The feedforward input data. It must be a 3-dimensional data - which has the shape of `(num_sample, num_time, num_feature)`. - reset_state: bool - Whether reset the model states. - shared_args: optional, dict - The shared arguments across different layers. - - Returns - ------- - output: Tensor, dict - The model output. - """ - # format input data - num_batch = self._get_xs_batch_size(xs) - # reset the model states - if reset_state: - self.target.reset_state(num_batch) - # init monitor - for key in self.mon.var_names: - self.mon[key] = [] # reshape the monitor items - # prediction - outputs, hists = self._predict(xs=xs, shared_args=shared_args) - # post-running for monitors - for key in hists.keys(): - self.mon[key] = bm.asarray(hists[key]) - if self.numpy_mon_after_run: - self.mon.ts = np.asarray(self.mon.ts) - for key in hists.keys(): - self.mon[key] = np.asarray(self.mon[key]) - return outputs - - def _predict( - self, - xs: Dict[str, Tensor], - shared_args: Dict = None, - forced_states: Dict[str, Tensor] = None, - forced_feedbacks: Dict[str, Tensor] = None, - ): - """Predict the output according to the inputs. - - Parameters - ---------- - xs: dict - Each tensor should have the shape of `(num_time, num_batch, num_feature)`. - forced_states: dict - The forced state values. - forced_feedbacks: dict - The forced feedback output values. - shared_args: optional, dict - The shared keyword arguments. - - Returns - ------- - outputs, hists - A tuple of pair of (outputs, hists). - """ - return self._get_predict_func(shared_args)(xs) - - def _make_f_loss(self, shared_kwargs: Dict = None): - if shared_kwargs is None: shared_kwargs = dict() - if not isinstance(shared_kwargs, dict): - raise ValueError(f'Only supports dict for "shared_kwargs". ' - f'But got {type(shared_kwargs)}: {shared_kwargs}') - - def loss_fun(inputs, targets): - outputs, _ = self._predict(xs=inputs, shared_args=shared_kwargs) - loss = self.loss_fun(outputs, targets) - return loss - - return loss_fun - - def _get_predict_func(self, shared_args: Dict = None): - if shared_args is None: shared_args = dict() - shared_kwargs_str = serialize_kwargs(shared_args) - if shared_kwargs_str not in self._predict_func: - self._predict_func[shared_kwargs_str] = self._make_predict_func(shared_args) - return self._predict_func[shared_kwargs_str] - - def _make_predict_func(self, shared_args: Dict): - if not isinstance(shared_args, dict): - raise ValueError(f'"shared_kwargs" must be a dict, ' - f'but got {type(shared_args)}') - - def run_func(xs): - return self.target(xs, shared_args) - - if self.jit['predict']: - dyn_vars = self.target.vars() - dyn_vars.update(self.dyn_vars) - run_func = bm.jit(run_func, dyn_vars=dyn_vars.unique()) - return run_func - - def _get_xs_batch_size(self, xs): - num_batch_sizes = [] - for key, val in xs.items(): - num_batch_sizes.append(val.shape[0]) - if len(set(num_batch_sizes)) != 1: - raise ValueError(f'Number of batch size is different across tensors in ' - f'the provided "xs". We got {set(num_batch_sizes)}.') - return num_batch_sizes[0] diff --git a/brainpy/train/runners/base_runner.py b/brainpy/train/runners/base_runner.py deleted file mode 100644 index 180862b11..000000000 --- a/brainpy/train/runners/base_runner.py +++ /dev/null @@ -1,91 +0,0 @@ -# -*- coding: utf-8 -*- - -from typing import Dict, Sequence, Any - -import jax.numpy as jnp - -import brainpy.math as bm -from brainpy.dyn.runners import DSRunner -from brainpy.tools.checking import check_dict_data -from brainpy.train.base import TrainingSystem - -__all__ = [ - 'DSTrainer', 'DSRunner', -] - - -class DSTrainer(DSRunner): - """Structural Trainer for Models with Recurrent Dynamics.""" - - target: TrainingSystem - train_nodes: Sequence[TrainingSystem] # need to be initialized by subclass - - def __init__( - self, - target: TrainingSystem, - **kwargs - ): - if not isinstance(target, TrainingSystem): - raise TypeError(f'"target" must be an instance of {TrainingSystem.__name__}, ' - f'but we got {type(target)}: {target}') - super(DSTrainer, self).__init__(target=target, **kwargs) - - # jit - self.jit['predict'] = self.jit.get('predict', True) - self.jit['fit'] = self.jit.get('fit', True) - - def fit( - self, - train_data: Any, - reset: bool = False, - shared_kwargs: Dict = None - ): # need to be implemented by subclass - raise NotImplementedError('Must implement the fit function. ') - - def _get_trainable_nodes(self): - # check trainable nodes - nodes = self.target.nodes(level=-1, include_self=True).subset(TrainingSystem).unique() - return tuple([node for node in nodes.values() if node.trainable]) - - def _check_ys(self, ys, num_batch, num_step, move_axis=False): - if isinstance(ys, (bm.ndarray, jnp.ndarray)): - if len(self.train_nodes) == 1: - ys = {self.train_nodes[0].name: ys} - else: - raise ValueError(f'The network\n {self.target} \nhas {len(self.train_nodes)} ' - f'training nodes, while we only got one target data.') - check_dict_data(ys, key_type=str, val_type=(bm.ndarray, jnp.ndarray)) - - # check data path - abs_node_names = [node.name for node in self.train_nodes] - formatted_ys = {} - ys_not_included = {} - for k, v in ys.items(): - if k in abs_node_names: - formatted_ys[k] = v - else: - ys_not_included[k] = v - if len(ys_not_included): - rel_nodes = self.target.nodes('relative', level=-1, include_self=True).subset(TrainingSystem).unique() - for k, v in ys_not_included.items(): - if k in rel_nodes: - formatted_ys[rel_nodes[k].name] = v - else: - raise ValueError(f'Unknown target "{k}" for fitting.') - - # check data shape - for key, val in formatted_ys.items(): - if val.ndim < 3: - raise ValueError("Targets must be a tensor with shape of " - "(num_sample, num_time, feature_dim, ...), " - f"but we got {val.shape}") - if val.shape[0] != num_batch: - raise ValueError(f'Batch size of the target {key} does not match ' - f'with the input data {val.shape[0]} != {num_batch}') - if val.shape[1] != num_step: - raise ValueError(f'The time step of the target {key} does not match ' - f'with the input data {val.shape[1]} != {num_step})') - if move_axis: - # change shape to (num_time, num_sample, num_feature) - formatted_ys = {k: bm.moveaxis(v, 0, 1) for k, v in formatted_ys.items()} - return formatted_ys diff --git a/brainpy/train/utils.py b/brainpy/train/utils.py deleted file mode 100644 index 6f8b5a732..000000000 --- a/brainpy/train/utils.py +++ /dev/null @@ -1,52 +0,0 @@ -# -*- coding: utf-8 -*- - -from typing import Dict, Optional -from brainpy import math as bm - -from brainpy.tools.checking import check_dict_data -from jax.tree_util import tree_flatten - -__all__ = [ - 'serialize_kwargs', - # 'check_rnn_data_time_step', - # 'check_data_batch_size', -] - - -def serialize_kwargs(shared_kwargs: Optional[Dict]): - """Serialize kwargs.""" - shared_kwargs = dict() if shared_kwargs is None else shared_kwargs - check_dict_data(shared_kwargs, - key_type=str, - val_type=(bool, float, int, complex), - name='shared_kwargs') - shared_kwargs = {key: shared_kwargs[key] for key in sorted(shared_kwargs.keys())} - return str(shared_kwargs) - - -def check_rnn_data_time_step(data: Dict, num_step=None): - if len(data) == 1: - time_step = list(data.values())[0].shape[1] - else: - steps = [] - for key, val in data.items(): - steps.append(val.shape[1]) - if len(set(steps)) != 1: - raise ValueError('Time steps are not consistent among the given data. ' - f'Got {set(steps)}. We expect only one time step.') - time_step = steps[0] - if (num_step is not None) and time_step != num_step: - raise ValueError(f'Time step is not consistent with the expected {time_step} != {num_step}') - return time_step - - -def check_data_batch_size(data, num_batch=None, batch_idx=0): - leaves, tree = tree_flatten(data, is_leaf=lambda x: isinstance(x, bm.JaxArray)) - batches = [leaf.shape[batch_idx] for leaf in leaves] - if len(set(batches)) != 1: - raise ValueError('Batch sizes are not consistent among the given data. ' - f'Got {set(batches)}. We expect only one batch size.') - batch_size = batches[0] - if (num_batch is not None) and batch_size != num_batch: - raise ValueError(f'Batch size is not consistent with the expected {batch_size} != {num_batch}') - return batch_size diff --git a/brainpy/types.py b/brainpy/types.py index 794c91f85..01141d681 100644 --- a/brainpy/types.py +++ b/brainpy/types.py @@ -2,16 +2,23 @@ from typing import TypeVar, Tuple +import numpy as np import jax.numpy as jnp -import brainpy.math as bm __all__ = [ - 'Tensor', - 'Parameter', + 'Tensor', 'Parameter', + 'Shape', + + 'Output', 'Monitor' ] -Tensor = TypeVar('Tensor', bm.JaxArray, jnp.ndarray) -Parameter = TypeVar('Parameter', float, int, jnp.ndarray, bm.JaxArray, bm.Variable) +Parameter = TypeVar('Parameter', float, int, jnp.ndarray, 'JaxArray', 'Variable') # noqa +Tensor = TypeVar('Tensor', 'JaxArray', 'Variable', 'TrainVar', jnp.ndarray, np.ndarray) # noqa + Shape = TypeVar('Shape', int, Tuple[int, ...]) + +Output = TypeVar('Output') +Monitor = TypeVar('Monitor') + diff --git a/docs/conf.py b/docs/conf.py index 89960806e..5bebd4968 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -36,7 +36,7 @@ auto_generater.generate_datasets_docs() auto_generater.generate_tools_docs() auto_generater.generate_compact_docs() -auto_generater.generate_math_compact_docs() +# auto_generater.generate_math_compact_docs() import shutil diff --git a/docs/quickstart/simulation.ipynb b/docs/quickstart/simulation.ipynb index 127f6c590..957ed814b 100644 --- a/docs/quickstart/simulation.ipynb +++ b/docs/quickstart/simulation.ipynb @@ -3,7 +3,11 @@ { "cell_type": "markdown", "id": "2e1966cc", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, "source": [ "# Simulating a Spiking Neural Network" ] @@ -11,7 +15,11 @@ { "cell_type": "markdown", "id": "724ccd02", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, "source": [ "@[Xiaoyu Chen](mailto:c-xy17@tsinghua.org.cn)" ] @@ -19,7 +27,11 @@ { "cell_type": "markdown", "id": "66f9a769", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, "source": [ "One of the most important approaches of studying brain dynamics is building a dynamic model and doing simulation. Generally, there are two ways to construct a dynamic model. The first one is called spiking models, which attempt to finely simulate the activity of each neuron in the target population. They are named spiking models because the simulation process records the precise timing of spiking of every neuron. The second is called rate models, which regard a population of neurons with similar properties as a single firing unit and examine the firing rate of this population. In this section, we will illustrate how to build and simulate a spiking neural network, e.i. SNN.\n", "\n", @@ -30,7 +42,11 @@ "cell_type": "code", "execution_count": 9, "id": "c4fbe84d", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [], "source": [ "import brainpy as bp\n", @@ -42,7 +58,11 @@ { "cell_type": "markdown", "id": "dd03123d", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, "source": [ "## Simulating an E-I balanced network" ] @@ -50,7 +70,11 @@ { "cell_type": "markdown", "id": "5e88fc7f", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, "source": [ "### Building an E-I balanced network" ] @@ -58,7 +82,11 @@ { "cell_type": "markdown", "id": "63354c42", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, "source": [ "Firstly, let's try to build an E-I balanced network. It was proposed to interpret the irregular firing of neurons in the cortical area \\[1\\]. Since the structure of an E-I balanced network is relatively simple, it is a good practice that helps users to learn the basic paradigm of brain dynamic simulation in BrainPy. The structure of a E-I balanced network is as follows:\n", "\n", @@ -69,7 +97,11 @@ { "cell_type": "markdown", "id": "62d35f9b", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, "source": [ "A E-I balanced network is composed of two neuron groups and the synaptic connections between them. Specifically, they include:\n", "1. a group of excitatory neurons, $\\mathrm{E}$,\n", @@ -81,7 +113,11 @@ { "cell_type": "markdown", "id": "c367fbf1", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, "source": [ "To construct the network, we need to define these components one by one. BrainPy provides plenty of handy built-in models for brain dynamic simulation. They are contained in ``brainpy.dyn``. Let's choose the simplest yet the most canonical neuron model, the Leaky Integrate-and-Fire (LIF) model, to build the excitatory and inhibitory neuron groups:" ] @@ -90,7 +126,11 @@ "cell_type": "code", "execution_count": 3, "id": "69556409", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [], "source": [ "E = bp.dyn.LIF(3200, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., method='exp_auto')\n", @@ -103,7 +143,11 @@ { "cell_type": "markdown", "id": "931a0a84", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, "source": [ "When defining the LIF neuron group, the parameters can be tuned according to users' need. The first parameter denotes the number of neurons. Here the ratio of excitatory and inhibitory neurons is set 4:1. ``V_rest`` denotes the resting potential, ``V_th`` denotes the firing threshold, ``V_reset`` denotes the reset value after firing, ``tau`` is the time constant, and ``tau_ref`` is the duration of the refractory period. ``method`` refers to the numerical integration method to be used in simulation. " ] @@ -111,7 +155,11 @@ { "cell_type": "markdown", "id": "abe09b1b", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, "source": [ "Then the synaptic connections between these two groups can be defined as follows:" ] @@ -120,7 +168,11 @@ "cell_type": "code", "execution_count": 4, "id": "8be1733f", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [], "source": [ "E2E = bp.dyn.ExpCOBA(E, E, bp.conn.FixedProb(prob=0.02), E=0., g_max=0.6, tau=5., method='exp_auto')\n", @@ -132,7 +184,11 @@ { "cell_type": "markdown", "id": "13b3c3a9", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, "source": [ "Here we use the Expnential synapse model (``ExpCOBA``) to simulate synaptic connections. Among the parameters of the model, the first two denotes the pre- and post-synaptic neuron groups, respectively. The third one refers to the connection types. In this example, we use ``bp.conn.FixedProb``, which connects the presynaptic neurons to postsynaptic neurons with a given probability (detailed information is available in [Synaptic Connection](../tutorial_toolbox/synaptic_connections.ipynb)). The following three parameters describes the dynamic properties of the synapse, and the last one is the numerical integration method as that in the LIF model." ] @@ -140,7 +196,11 @@ { "cell_type": "markdown", "id": "572fa775", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, "source": [ "After defining all the components, they can be combined to form a network:" ] @@ -149,7 +209,11 @@ "cell_type": "code", "execution_count": 5, "id": "f8a6c731", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [], "source": [ "net = bp.dyn.Network(E2E, E2I, I2E, I2I, E=E, I=I)" @@ -158,7 +222,11 @@ { "cell_type": "markdown", "id": "0412deb5", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, "source": [ "In the definition, neurons and synapses are given to the network. The excitatory and inhibitory neuron groups (`E` and `I`) are passed with a name, for they will be specifically operated in the simulation (here they will be given with input currents).\n", "\n", @@ -168,7 +236,11 @@ { "cell_type": "markdown", "id": "e3bcad34", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, "source": [ "### Running a simulation" ] @@ -176,7 +248,11 @@ { "cell_type": "markdown", "id": "43ec39f4", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, "source": [ "After build a SNN, we can use it for dynamic simulation. To run a simulation, we need first wrap the network model into a **runner**. Currently BrainPy provides ``DSRunner`` and ``ReportRunner`` in ``brainpy.dyn``, which will be expanded in the [Runners](../tutorial_simulation/runner.ipynb) tutorial. Here we use ``DSRunner`` as an example:" ] @@ -185,7 +261,11 @@ "cell_type": "code", "execution_count": 6, "id": "8e16cd97", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [], "source": [ "runner = bp.dyn.DSRunner(net,\n", @@ -197,7 +277,11 @@ { "cell_type": "markdown", "id": "11473917", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, "source": [ "To make dynamic simulation more applicable and powerful, users can [**monitor**](../tutorial_toolbox/monitors.ipynb) variable trajectories and give [**inputs**](../tutorial_toolbox/inputs.ipynb) to target neuron groups. Here we monitor the ``spike`` variable in the ``E`` and ``I`` LIF model, which refers to the spking status of the neuron group, and give a constant input to both neuron groups. The time interval of numerical integration ``dt`` (with the default value of 0.1) can also be specified.\n", "\n", @@ -208,7 +292,11 @@ "cell_type": "code", "execution_count": 7, "id": "a2a602d2", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [ { "data": { @@ -242,7 +330,11 @@ { "cell_type": "markdown", "id": "8452dec3", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, "source": [ "where the calling function receives the simulation time (usually in milliseconds) as the input and returns the time (seconds) spent on simulation. BrainPy achieves an extraordinary simulation speed with the assistance of just-in-time (JIT) compilation. Please refer to [Just-In-Time Compilation](../tutorial_math/compilation.ipynb) for more details.\n", "\n", @@ -253,7 +345,11 @@ "cell_type": "code", "execution_count": 8, "id": "f3aab08c", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [ { "data": { @@ -282,7 +378,11 @@ { "cell_type": "markdown", "id": "3f78546b", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, "source": [ "In the code above, ``brianpy.visualize`` contains some useful functions to visualize simulation results based on the ``matplotlib`` package. Since the simulation results are stored as NumPy arrays, users can directly use ``matplotlib`` for visualization." ] @@ -290,7 +390,11 @@ { "cell_type": "markdown", "id": "8ce65bd2", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, "source": [ "## Simulating a decision-making network" ] @@ -298,7 +402,11 @@ { "cell_type": "markdown", "id": "d403c2f5", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, "source": [ "### Building a decision-making network" ] @@ -306,7 +414,11 @@ { "cell_type": "markdown", "id": "9d9bf6b8", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, "source": [ "After learning how to build a E-I balanced network, we can try to handle a more complex model. In 2002, Wang proposed a decision-making model that could choose between two conflict inputs by accumulating evidence over time \\[2\\]. \n", "\n", @@ -319,7 +431,11 @@ { "cell_type": "markdown", "id": "81c432b0", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, "source": [ "To construct a decision-making network, we should build all neuron groups:\n", "\n", @@ -339,7 +455,11 @@ { "cell_type": "markdown", "id": "0a3345af", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, "source": [ "Now let's build these neuron groups and connections.\n", "\n", @@ -350,7 +470,11 @@ "cell_type": "code", "execution_count": 11, "id": "217204d5", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [], "source": [ "pre_stimulus_period = 100. # time before the external simuli are given\n", @@ -362,7 +486,11 @@ { "cell_type": "markdown", "id": "e559ece9", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, "source": [ "To build $\\mathrm{I_A}$ and $\\mathrm{I_B}$, we shall define a class of neuron groups that can generate stochastic Possion stimulu. Two define neuron groups, they should inherit `brainpy.dyn.NeuGroup`." ] @@ -371,7 +499,11 @@ "cell_type": "code", "execution_count": 12, "id": "b76c3965", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [], "source": [ "class PoissonStim(bp.dyn.NeuGroup):\n", @@ -402,7 +534,11 @@ { "cell_type": "markdown", "id": "0dbe7213", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, "source": [ "Because there are too many neuron groups and connections, it will be much clearer if we define a new network class inheriting `brainpy.dyn.Network` to accommodate all these neurons and synapses:" ] @@ -411,7 +547,11 @@ "cell_type": "code", "execution_count": 13, "id": "ca22fe03", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [], "source": [ "class DecisionMaking(bp.dyn.Network):\n", @@ -518,7 +658,11 @@ { "cell_type": "markdown", "id": "833eb50a", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, "source": [ "Though the code seems longer than the E-I balanced network, the basic building paradigm is the same: building neuron groups and the connections among them." ] @@ -526,7 +670,11 @@ { "cell_type": "markdown", "id": "54efdc44", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, "source": [ "### Running a simulation" ] @@ -534,7 +682,11 @@ { "cell_type": "markdown", "id": "60f10858", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, "source": [ "After building it, the simulation process will be much the same as running a E-I balanced network. First we should wrap the network into a runner:" ] @@ -543,7 +695,11 @@ "cell_type": "code", "execution_count": 14, "id": "47ebe27c", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [], "source": [ "net = DecisionMaking(scale=1., coherence=25.6, mu0=40.)\n", @@ -553,7 +709,11 @@ { "cell_type": "markdown", "id": "8beac6d6", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, "source": [ "Then we call the runner to run the simulation:" ] @@ -562,7 +722,11 @@ "cell_type": "code", "execution_count": 15, "id": "96e97756", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [ { "data": { @@ -596,7 +760,11 @@ { "cell_type": "markdown", "id": "0d27aac5", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, "source": [ "Finally, we visualize the simulation result by using `matplotlib`:" ] @@ -606,7 +774,10 @@ "execution_count": 16, "id": "0d57a44d", "metadata": { - "scrolled": false + "scrolled": false, + "pycharm": { + "name": "#%%\n" + } }, "outputs": [ { @@ -669,7 +840,11 @@ { "cell_type": "markdown", "id": "5a8dd84e", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, "source": [ "For more information about brain dynamic simulation, please refer to [Dynamics Simulation](../tutorial_simulation/index.rst) in the BDP tutorial." ] @@ -677,7 +852,11 @@ { "cell_type": "markdown", "id": "42c6d43f", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, "source": [ "## References\n", "\n", @@ -690,7 +869,11 @@ "cell_type": "code", "execution_count": null, "id": "e645a3df", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [], "source": [] } @@ -752,4 +935,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} +} \ No newline at end of file diff --git a/docs/tutorial_training/node_customization.ipynb b/docs/tutorial_training/node_customization.ipynb index 9896feb2f..b79e9dd09 100644 --- a/docs/tutorial_training/node_customization.ipynb +++ b/docs/tutorial_training/node_customization.ipynb @@ -126,7 +126,7 @@ " \n", " # 2. Initialize the weight W\n", " weight_shape = (sum(free_sizes), self.num_unit)\n", - " self.W = bp.nn.init_param(self.W_initializer, weight_shape)\n", + " self.W = bp.nn.parameter(self.W_initializer, weight_shape)\n", " # If the user want to train this node, we need mark the \n", " # weight as a \"brainpy.math.TrainVar\"\n", " if self.trainable:\n", @@ -327,8 +327,8 @@ " def init_ff(self):\n", " unique_size, free_sizes = check_shape_consistency(self.feedforward_shapes, -1, True)\n", " num_input = sum(free_sizes)\n", - " self.wi = bp.nn.init_param(self.wi_initializer, (num_input, self.num_unit))\n", - " self.wr = bp.nn.init_param(self.wr_initializer, (self.num_unit, self.num_unit))\n", + " self.wi = bp.nn.parameter(self.wi_initializer, (num_input, self.num_unit))\n", + " self.wr = bp.nn.parameter(self.wr_initializer, (self.num_unit, self.num_unit))\n", " if self.trainable:\n", " self.wi = bm.TrainVar(self.wi)\n", " self.wr = bm.TrainVar(self.wr)\n", @@ -407,7 +407,7 @@ " unique_size, free_sizes = check_shape_consistency(self.feedforward_shapes, -1, True)\n", " # 2. Initialize the feedforward weight Wff\n", " weight_shape = (sum(free_sizes), self.num_unit)\n", - " self.Wff = bp.nn.init_param(self.W_initializer, weight_shape)\n", + " self.Wff = bp.nn.parameter(self.W_initializer, weight_shape)\n", " if self.trainable:\n", " self.Wff = bm.TrainVar(self.Wff)\n", " # 3. Set the output shape \n", @@ -418,7 +418,7 @@ " unique_size, free_sizes = check_shape_consistency(self.feedback_shapes, -1, True)\n", " # 2. Initialize the feedback weight Wfb\n", " weight_shape = (sum(free_sizes), self.num_unit)\n", - " self.Wfb = bp.nn.init_param(self.W_initializer, weight_shape)\n", + " self.Wfb = bp.nn.parameter(self.W_initializer, weight_shape)\n", " if self.trainable:\n", " self.Wfb = bm.TrainVar(self.Wfb)\n", " \n", diff --git a/examples/analysis/1d_qif.py b/examples/analysis/1d_qif.py index 2372afc25..e59d4b359 100644 --- a/examples/analysis/1d_qif.py +++ b/examples/analysis/1d_qif.py @@ -6,10 +6,8 @@ bp.math.enable_x64() # important! -@bp.odeint def qif(V, t, c=.07, R=1., tau=10., Iext=0., V_rest=-65., V_c=-50.0, ): - dVdt = (c * (V - V_rest) * (V - V_c) + R * Iext) / tau - return dVdt + return (c * (V - V_rest) * (V - V_c) + R * Iext) / tau pp = bp.analysis.PhasePlane1D( diff --git a/examples/analysis/1d_system.py b/examples/analysis/1d_system.py index ec25e8602..270181cf7 100644 --- a/examples/analysis/1d_system.py +++ b/examples/analysis/1d_system.py @@ -41,8 +41,7 @@ def cubic_system1(): def cubic_system_2(): @bp.odeint def int_x(x, t, Iext): - dx = x ** 3 - x + Iext - return dx + return x ** 3 - x + Iext analyzer = bp.analysis.PhasePlane1D(model=int_x, target_vars={'x': [-2, 2]}, diff --git a/examples/analysis/2d_fitzhugh_nagumo_model.py b/examples/analysis/2d_fitzhugh_nagumo_model.py index 12f692f08..ba30f1c9e 100644 --- a/examples/analysis/2d_fitzhugh_nagumo_model.py +++ b/examples/analysis/2d_fitzhugh_nagumo_model.py @@ -33,7 +33,8 @@ def dw(w, t, V, a=0.7, b=0.8): self.int_V = bp.odeint(dV, method=method) self.int_w = bp.odeint(dw, method=method) - def update(self, t, dt): + def update(self, tdi): + t, dt = tdi['t'], tdi['dt'] self.V.value = self.int_V(self.V, t, self.w, self.Iext, dt) self.w.value = self.int_w(self.w, t, self.V, self.a, self.b, dt) self.Iext[:] = 0. diff --git a/examples/analysis/2d_mean_field_QIF.py b/examples/analysis/2d_mean_field_QIF.py index be9cc70b9..c5f48f101 100644 --- a/examples/analysis/2d_mean_field_QIF.py +++ b/examples/analysis/2d_mean_field_QIF.py @@ -38,7 +38,8 @@ def dv(v, t, r, Iext=0., eta=-5.0): self.int_r = bp.odeint(dr, method=method) self.int_v = bp.odeint(dv, method=method) - def update(self, t, dt): + def update(self, tdi): + t, dt = tdi['t'], tdi['dt'] self.r.value = self.int_r(self.r, t, self.v, self.delta, dt) self.v.value = self.int_v(self.v, t, self.r, self.Iext, self.eta, dt) self.Iext[:] = 0. diff --git a/examples/analysis/2d_wilson_cowan_model.py b/examples/analysis/2d_wilson_cowan_model.py index 2132acda9..6248f5940 100644 --- a/examples/analysis/2d_wilson_cowan_model.py +++ b/examples/analysis/2d_wilson_cowan_model.py @@ -47,7 +47,8 @@ def di(i, t, e): self.int_e = bp.odeint(de, method=method) self.int_i = bp.odeint(di, method=method) - def update(self, t, dt): + def update(self, tdi): + t, dt = tdi['t'], tdi['dt'] self.e.value = self.int_e(self.e, t, self.i, self.Iext, dt) self.i.value = self.int_i(self.i, t, self.e, dt) self.Iext[:] = 0. diff --git a/examples/analysis/3d_hindmarsh_rose_model.py b/examples/analysis/3d_hindmarsh_rose_model.py index 6511fcff0..03d496305 100644 --- a/examples/analysis/3d_hindmarsh_rose_model.py +++ b/examples/analysis/3d_hindmarsh_rose_model.py @@ -8,52 +8,12 @@ bp.math.enable_x64() -class HindmarshRose(bp.dyn.DynamicalSystem): - def __init__(self, method='exp_auto'): - super(HindmarshRose, self).__init__() - - # parameters - self.a = 1. - self.b = 2.5 - self.c = 1. - self.d = 5. - self.s = 4. - self.x_r = -1.6 - self.r = 0.001 - - # variables - self.x = bp.math.Variable(bp.math.ones(1)) - self.y = bp.math.Variable(bp.math.ones(1)) - self.z = bp.math.Variable(bp.math.ones(1)) - self.I = bp.math.Variable(bp.math.zeros(1)) - - # integral functions - def dx(x, t, y, z, Isyn): - return y - self.a * x ** 3 + self.b * x * x - z + Isyn - - def dy(y, t, x): - return self.c - self.d * x * x - y - - def dz(z, t, x): - return self.r * (self.s * (x - self.x_r) - z) - - self.int_x = bp.odeint(f=dx, method=method) - self.int_y = bp.odeint(f=dy, method=method) - self.int_z = bp.odeint(f=dz, method=method) - - def update(self, t, dt): - self.x.value = self.int_x(self.x, t, self.y, self.z, self.I, dt) - self.y.value = self.int_y(self.y, t, self.x, dt) - self.z.value = self.int_z(self.z, t, self.x, dt) - self.I[:] = 0. - - def simulation(): - model = HindmarshRose() - # model.b = 2.5 + model = bp.dyn.neurons.HindmarshRose(1) runner = bp.dyn.DSRunner( - model, monitors=['x', 'y', 'z'], - inputs=['I', 1.5], + model, + monitors=['x', 'y', 'z'], + inputs=[model.input, 1.5], ) runner.run(2000.) bp.visualize.line_plot(runner.mon.ts, runner.mon.x, legend='x') @@ -63,45 +23,45 @@ def simulation(): def bifurcation_analysis(): - model = HindmarshRose() - + model = bp.dyn.neurons.HindmarshRose(1) analyzer = bp.analysis.FastSlow2D( - [model.int_x, model.int_y, model.int_z], - fast_vars={'x': [-3, 2], 'y': [-20., 3.]}, + model, + fast_vars={'V': [-3, 2], 'y': [-20., 3.]}, slow_vars={'z': [-0.5, 3.]}, - pars_update={'Isyn': 1.5}, + pars_update={'I_ext': 1.5}, resolutions={'z': 0.01}, # options={bp.analysis.C.y_by_x_in_fy: lambda x: model.c - model.d * x * x} ) analyzer.plot_bifurcation(num_rank=20) - analyzer.plot_trajectory({'x': [1.], 'y': [1.], 'z': [1.]}, + analyzer.plot_trajectory({'V': [1.], 'y': [1.], 'z': [1.]}, duration=1700, plot_durations=[360, 1680]) analyzer.show_figure() def phase_plane_analysis(): - model = HindmarshRose() - + model = bp.dyn.neurons.HindmarshRose(1) for z in np.arange(0., 2.5, 0.3): analyzer = bp.analysis.PhasePlane2D( - [model.int_x, model.int_y], - target_vars={'x': [-3, 2], 'y': [-20., 3.]}, - pars_update={'Isyn': 1.5, 'z': z}, - resolutions={'x': 0.01, 'y': 0.01}, + model, + target_vars={'V': [-3, 2], 'y': [-20., 3.]}, + pars_update={'I_ext': 1.5, 'z': z}, + resolutions={'V': 0.01, 'y': 0.01}, ) analyzer.plot_nullcline() analyzer.plot_vector_field() fps = analyzer.plot_fixed_point(with_return=True) - analyzer.plot_trajectory({'x': [fps[-1, 0] + 0.1], 'y': [fps[-1, 0] + 0.1]}, - duration=500, plot_durations=[400, 500]) + analyzer.plot_trajectory({'V': [fps[-1, 0] + 0.1], + 'y': [fps[-1, 0] + 0.1]}, + duration=500, + plot_durations=[400, 500]) plt.title(f'z={z:.2f}') + # plt.show() plt.savefig(f'data/z={z:.2f}.png') plt.close() - # analyzer.show_figure() if __name__ == '__main__': - # simulation() + simulation() bifurcation_analysis() - # phase_plane_analysis() + phase_plane_analysis() diff --git a/examples/analysis/3d_reduced_trn_model.py b/examples/analysis/3d_reduced_trn_model.py index 28b92ff8b..ce3d0e8c0 100644 --- a/examples/analysis/3d_reduced_trn_model.py +++ b/examples/analysis/3d_reduced_trn_model.py @@ -191,13 +191,14 @@ def derivative(self, V, y, z, t, Isyn): dzdt = self.fz(z, t, V) return dvdt, dydt, dzdt - def update(self, t, dt): + def update(self, tdi): + t, dt = tdi['t'], tdi['dt'] if isinstance(self.int_V, bp.ode.ExponentialEuler): - V = self.int_V(self.V, t, self.y, self.z, self.input, dt=dt) - self.y.value = self.int_y(self.y, t, self.V, dt=dt) - self.z.value = self.int_z(self.z, t, self.V, dt=dt) + V = self.int_V(self.V, t, self.y, self.z, self.input, dt) + self.y.value = self.int_y(self.y, t, self.V, dt) + self.z.value = self.int_z(self.z, t, self.V, dt) else: - V, self.y.value, self.z.value = self.integral(self.V, self.y, self.z, t, self.input, dt=dt) + V, self.y.value, self.z.value = self.integral(self.V, self.y, self.z, t, self.input, dt) self.spike.value = bm.logical_and((self.V < self.Vth), (V >= self.Vth)) self.V.value = V self.input[:] = 0. diff --git a/examples/analysis/4d_HH_model.py b/examples/analysis/4d_HH_model.py index dcceee557..4042d325c 100644 --- a/examples/analysis/4d_HH_model.py +++ b/examples/analysis/4d_HH_model.py @@ -1,120 +1,44 @@ # -*- coding: utf-8 -*- -import matplotlib.pyplot as plt -import numpy as np - import brainpy as bp import brainpy.math as bm - -class HH(bp.dyn.NeuGroup): - def __init__(self, size, ENa=50., gNa=120., EK=-77., gK=36., EL=-54.387, gL=0.03, - V_th=20., C=1.0, name=None): - super(HH, self).__init__(size=size, name=name) - - # parameters - self.ENa = ENa - self.EK = EK - self.EL = EL - self.C = C - self.gNa = gNa - self.gK = gK - self.gL = gL - self.V_th = V_th - - # variables - self.V = bm.Variable(bm.ones(self.num) * -65.) - self.m = bm.Variable(0.5 * bm.ones(self.num)) - self.h = bm.Variable(0.6 * bm.ones(self.num)) - self.n = bm.Variable(0.32 * bm.ones(self.num)) - self.spike = bm.Variable(bm.zeros(size, dtype=bool)) - self.input = bm.Variable(bm.zeros(size)) - - # integral functions - self.int_h = bp.ode.ExponentialEuler(self.dh) - self.int_n = bp.ode.ExponentialEuler(self.dn) - self.int_m = bp.ode.ExponentialEuler(self.dm) - self.int_V = bp.ode.ExponentialEuler(self.dV) - - def dh(self, h, t, V): - alpha = 0.07 * bm.exp(-(V + 65) / 20.) - beta = 1 / (1 + bm.exp(-(V + 35) / 10)) - dhdt = alpha * (1 - h) - beta * h - return dhdt - - def dn(self, n, t, V): - alpha = 0.01 * (V + 55) / (1 - bm.exp(-(V + 55) / 10)) - beta = 0.125 * bm.exp(-(V + 65) / 80) - dndt = alpha * (1 - n) - beta * n - return dndt - - def dm(self, m, t, V): - alpha = 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10)) - beta = 4.0 * bm.exp(-(V + 65) / 18) - dmdt = alpha * (1 - m) - beta * m - return dmdt - - def dV(self, V, t, m, h, n, Iext): - I_Na = (self.gNa * m ** 3.0 * h) * (V - self.ENa) - I_K = (self.gK * n ** 4.0) * (V - self.EK) - I_leak = self.gL * (V - self.EL) - dVdt = (- I_Na - I_K - I_leak + Iext) / self.C - return dVdt - - def step(self, h, Iext): - V, m, h, n = bm.split(h, 4) - dV = self.dV(V, 0., m, h, n, Iext) - dm = self.dm(m, 0., V) - dh = self.dh(h, 0., V) - dn = self.dn(n, 0., V) - return bm.concatenate([dV, dm, dh, dn]) - - def update(self, t, dt): - m = self.int_m(self.m, t, self.V, dt=dt) - h = self.int_h(self.h, t, self.V, dt=dt) - n = self.int_n(self.n, t, self.V, dt=dt) - V = self.int_V(self.V, t, self.m, self.h, self.n, self.input, dt=dt) - self.spike.value = bm.logical_and(self.V < self.V_th, V >= self.V_th) - self.V.value = V - self.h.value = h - self.n.value = n - self.m.value = m - self.input[:] = 0. - - I = 5. -model = HH(1) +model = bp.dyn.neurons.HH(1) runner = bp.dyn.DSRunner(model, inputs=('input', I), monitors=['V']) runner.run(100) bp.visualize.line_plot(runner.mon.ts, runner.mon.V, legend='V', show=True) # analysis -finder = bp.analysis.SlowPointFinder(model, inputs=(model.input, I), excluded_vars=[model.input]) -V = bm.random.normal(0., 5., (1000, model.num)) - 50. -mhn = bm.random.random((1000, model.num * 3)) -finder.find_fps_with_opt_solver(candidates=bm.hstack([V, mhn])) +model = bp.dyn.neurons.HH(1, method='euler') +finder = bp.analysis.SlowPointFinder( + model, + inputs=(model.input, I), + included_vars={'V': model.V, + 'm': model.m, + 'h': model.h, + 'n': model.n}, + dt=1. +) +candidates = {'V': bm.random.normal(0., 5., (1000, model.num)) - 50., + 'm': bm.random.random((1000, model.num)), + 'h': bm.random.random((1000, model.num)), + 'n': bm.random.random((1000, model.num))} +finder.find_fps_with_opt_solver(candidates=candidates) finder.filter_loss(1e-7) -finder.keep_unique() +finder.keep_unique(tolerance=1e-1) print('fixed_points: ', finder.fixed_points) print('losses:', finder.losses) -if len(finder.fixed_points): - jac = finder.compute_jacobians(finder.fixed_points) - for i in range(len(finder.fixed_points)): - eigval, eigvec = np.linalg.eig(np.asarray(jac[i])) - plt.figure() - plt.scatter(np.real(eigval), np.imag(eigval)) - plt.plot([0, 0], [-1, 1], '--') - plt.xlabel('Real') - plt.ylabel('Imaginary') - plt.title(f'FP {i}') - plt.show() +if finder.num_fps > 0: + jac = finder.compute_jacobians(finder.fixed_points, plot=True) # verify -for i, fp in enumerate(finder.fixed_points): - model.V[:] = fp[0] - model.m[:] = fp[1] - model.h[:] = fp[2] - model.n[:] = fp[3] - runner = bp.dyn.DSRunner(model, inputs=('input', I), monitors=['V']) +for i in range(finder.num_fps): + model = bp.dyn.neurons.HH(1) + model.V[:] = finder._fixed_points['V'][i] + model.m[:] = finder._fixed_points['m'][i] + model.h[:] = finder._fixed_points['h'][i] + model.n[:] = finder._fixed_points['n'][i] + runner = bp.dyn.DSRunner(model, inputs=(model.input, I), monitors=['V']) runner.run(100) bp.visualize.line_plot(runner.mon.ts, runner.mon.V, legend='V', title=f'FP {i}', show=True) diff --git a/examples/analysis/highdim_CANN.py b/examples/analysis/highdim_CANN.py index 229e26694..1b09cc8b7 100644 --- a/examples/analysis/highdim_CANN.py +++ b/examples/analysis/highdim_CANN.py @@ -3,7 +3,9 @@ import matplotlib.pyplot as plt import numpy as np from sklearn.decomposition import PCA +import sys +# sys.path.append('/mnt/d/codes/Projects/brainpy-chaoming0625') import brainpy as bp import brainpy.math as bm @@ -64,53 +66,49 @@ def make_conn(self, x): def get_stimulus_by_pos(self, pos): return self.A * bm.exp(-0.25 * bm.square(self.dist(self.x - pos) / self.a)) - def update(self, t, dt): - self.u[:] = self.integral(self.u, t, self.input) + def update(self, tdi): + t, dt = tdi.get('t'), tdi.get('dt') + self.u.value = self.integral(self.u, t, self.input, dt) self.input[:] = 0. - def cell(self, u): - return self.derivative(u, 0., 0.) - - -k = 0.1 -a = 0.5 -A = 10 -fps_output_fn = f'data/fps,k={k},a={a},A={A},f32,BFGS,randominit.npy' - - -def find_fixed_points(): - cann = CANN1D(num=512, k=k, A=A, a=a) - - candidates = cann.get_stimulus_by_pos(bm.arange(-bm.pi, bm.pi, 0.01).reshape((-1, 1))) - candidates += bm.random.normal(0., 0.01, candidates.shape) - - # candidates = bm.random.uniform(0, 20., (1000, cann.num)) - - finder = bp.analysis.SlowPointFinder(f_cell=cann, included_vars={'u': cann.u}) - # finder.find_fps_with_gd_method( - # candidates=candidates, - # tolerance=1e-6, - # optimizer = bp.optim.Adam(lr=bp.optim.ExponentialDecay(0.1, , 0.999)), - # num_batch=200 - # ) - finder.find_fps_with_opt_solver({'u': candidates}) - finder.filter_loss(1e-5) - finder.keep_unique() - # finder.exclude_outliers() - - np.save(fps_output_fn, finder.fixed_points) - - print(finder.fixed_points) - print(finder.losses) - # print(finder.selected_ids) - - -def visualize_fixed_points(): - fixed_points = np.load(fps_output_fn) +def find_fixed_points(pars=None, verbose=False, opt_method='gd', cand_method='random', tolerance=1e-6): + if pars is None: pars = dict() + cann = CANN1D(num=512, **pars) + + if cand_method == 'random': + candidates = bm.random.uniform(0, 20., (1000, cann.num)) + elif cand_method == 'bump': + candidates = cann.get_stimulus_by_pos(bm.arange(-bm.pi, bm.pi, 0.01).reshape((-1, 1))) + candidates += bm.random.normal(0., 0.01, candidates.shape) + else: + raise ValueError + + finder = bp.analysis.SlowPointFinder(f_cell=cann, included_vars={'u': cann.u}, dt=1.) + if opt_method == 'gd': + finder.find_fps_with_gd_method( + candidates={'u': candidates}, tolerance=tolerance, num_batch=200, + optimizer=bp.optim.Adam(lr=bp.optim.ExponentialDecay(0.2, 1, 0.999)), + ) + elif opt_method == 'BFGS': + finder.find_fps_with_opt_solver({'u': candidates}) + else: + raise ValueError + finder.filter_loss(tolerance) + finder.keep_unique(5e-3) + + if verbose: + print(finder.fixed_points) + print(finder.losses) + print(finder.selected_ids) + + return finder.fixed_points, finder + + +def visualize_fixed_points(fixed_points): bp.visualize.animate_1D( - dynamical_vars={'ys': fixed_points, - 'xs': bm.linspace(-bm.pi, bm.pi, fixed_points.shape[1]), + dynamical_vars={'ys': fixed_points['u'], + 'xs': bm.linspace(-bm.pi, bm.pi, fixed_points['u'].shape[1]), 'legend': 'fixed point'}, frame_step=1, frame_delay=100, @@ -119,13 +117,12 @@ def visualize_fixed_points(): ) -def verify_fixed_points_through_simulation(num=3): - fixed_points = np.load(fps_output_fn) - - cann = CANN1D(num=512, k=k, a=a, A=A) +def verify_fixed_points_through_simulation(fixed_points, pars=None, num=3): + if pars is None: pars = dict() + cann = CANN1D(num=512, **pars) for i in range(num): - cann.u[:] = fixed_points[i] + cann.u[:] = fixed_points['u'][i] runner = bp.dyn.DSRunner(cann, monitors=['u'], dyn_vars=cann.vars()) @@ -135,30 +132,10 @@ def verify_fixed_points_through_simulation(num=3): plt.show() -def verify_fixed_point_stability(num=3): - fixed_points = np.load(fps_output_fn) - - cann = CANN1D(num=512, k=k, a=a, A=A) - finder = bp.analysis.SlowPointFinder(f_cell=cann.cell, - f_type=bp.analysis.CONTINUOUS) - J = finder.compute_jacobians(fixed_points[:num]) - - for i in range(num): - eigval, eigvec = np.linalg.eig(np.asarray(J[i])) - plt.figure() - plt.scatter(np.real(eigval), np.imag(eigval)) - plt.plot([0, 0], [-1, 1], '--') - plt.xlabel('Real') - plt.ylabel('Imaginary') - plt.show() - - -def pca_reduction(): - fixed_points = np.load(fps_output_fn) - +def pca_reduction(fixed_points): pca = PCA(2) - pca.fit(fixed_points) - fixedpoints_pc = pca.transform(fixed_points) + pca.fit(fixed_points['u']) + fixedpoints_pc = pca.transform(fixed_points['u']) plt.plot(fixedpoints_pc[:, 0], fixedpoints_pc[:, 1], 'x', label='fixed points') plt.xlabel('PC 1') @@ -168,8 +145,12 @@ def pca_reduction(): if __name__ == '__main__': - find_fixed_points() - visualize_fixed_points() - verify_fixed_points_through_simulation() - verify_fixed_point_stability(num=6) - pca_reduction() + params = dict(k=0.1, a=0.5, A=20) + fps, finder = find_fixed_points(params, cand_method='bump', tolerance=1e-7) + # fps, finder = find_fixed_points(params, cand_method='random', opt_method='gd', tolerance=1e-7) + # fps, finder = find_fixed_points(params, cand_method='random', opt_method='BFGS', tolerance=1e-5) + visualize_fixed_points(fps) + verify_fixed_points_through_simulation(fps, params) + finder.compute_jacobians(fps['u'][:6], plot=True) + pca_reduction(fps) + diff --git a/examples/analysis/highdim_RNN_Analysis.py b/examples/analysis/highdim_RNN_Analysis.py index 80fa55d65..d542c61d2 100644 --- a/examples/analysis/highdim_RNN_Analysis.py +++ b/examples/analysis/highdim_RNN_Analysis.py @@ -89,15 +89,15 @@ def __init__(self, num_input, num_hidden, num_output, num_batch, dt=None, seed=N self.rng = bm.random.RandomState(seed=seed) # input weight - self.w_ir = bm.TrainVar(bp.init.init_param(w_ir, (num_input, num_hidden))) + self.w_ir = bm.TrainVar(bp.init.parameter(w_ir, (num_input, num_hidden))) # recurrent weight bound = 1 / num_hidden ** 0.5 - self.w_rr = bm.TrainVar(bp.init.init_param(w_rr, (num_hidden, num_hidden))) + self.w_rr = bm.TrainVar(bp.init.parameter(w_rr, (num_hidden, num_hidden))) self.b_rr = bm.TrainVar(self.rng.uniform(-bound, bound, num_hidden)) # readout weight - self.w_ro = bm.TrainVar(bp.init.init_param(w_ro, (num_hidden, num_output))) + self.w_ro = bm.TrainVar(bp.init.parameter(w_ro, (num_hidden, num_output))) self.b_ro = bm.TrainVar(self.rng.uniform(-bound, bound, num_output)) # variables diff --git a/examples/analysis/highdim_gj_coupled_fhn.py b/examples/analysis/highdim_gj_coupled_fhn.py index 3b39228dc..bbeccdc59 100644 --- a/examples/analysis/highdim_gj_coupled_fhn.py +++ b/examples/analysis/highdim_gj_coupled_fhn.py @@ -39,7 +39,8 @@ def dw(self, w, t, V): dw = (V + self.a - self.b * w) / self.tau return dw - def update(self, t, dt): + def update(self, tdi): + t, dt = tdi.get('t'), tdi.get('dt') self.V.value = self.int_V(self.V, t, self.w, self.Iext, dt) self.w.value = self.int_w(self.w, t, self.V, dt) self.Iext[:] = 0. diff --git a/examples/simulation/Brette_2007_COBA.py b/examples/simulation/Brette_2007_COBA.py index 2065fdfcb..d47a13ca4 100644 --- a/examples/simulation/Brette_2007_COBA.py +++ b/examples/simulation/Brette_2007_COBA.py @@ -1,44 +1,48 @@ # -*- coding: utf-8 -*- import brainpy as bp +import brainpy.math as bm bp.math.set_platform('cpu') class EINet(bp.dyn.Network): def __init__(self, scale=1.0, method='exp_auto'): + super(EINet, self).__init__() + # network size num_exc = int(3200 * scale) num_inh = int(800 * scale) # neurons pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.) - E = bp.dyn.LIF(num_exc, **pars, method=method) - I = bp.dyn.LIF(num_inh, **pars, method=method) - E.V[:] = bp.math.random.randn(num_exc) * 2 - 55. - I.V[:] = bp.math.random.randn(num_inh) * 2 - 55. + self.E = bp.dyn.LIF(num_exc, **pars, method=method) + self.I = bp.dyn.LIF(num_inh, **pars, method=method) + self.E.V[:] = bm.random.randn(num_exc) * 2 - 55. + self.I.V[:] = bm.random.randn(num_inh) * 2 - 55. # synapses - we = 0.6 / scale # excitatory synaptic weight (voltage) - wi = 6.7 / scale # inhibitory synaptic weight - E2E = bp.dyn.ExpCOBA(E, E, bp.conn.FixedProb(0.02), - E=0., g_max=we, tau=5., method=method) - E2I = bp.dyn.ExpCOBA(E, I, bp.conn.FixedProb(0.02), - E=0., g_max=we, tau=5., method=method) - I2E = bp.dyn.ExpCOBA(I, E, bp.conn.FixedProb(0.02), - E=-80., g_max=wi, tau=10., method=method) - I2I = bp.dyn.ExpCOBA(I, I, bp.conn.FixedProb(0.02), - E=-80., g_max=wi, tau=10., method=method) - - super(EINet, self).__init__(E2E, E2I, I2E, I2I, E=E, I=I) + prob = 0.1 + we = 0.6 / scale / (prob / 0.02)**2 # excitatory synaptic weight (voltage) + wi = 6.7 / scale / (prob / 0.02)**2 # inhibitory synaptic weight + self.E2E = bp.dyn.ExpCOBA(self.E, self.E, bp.conn.FixedProb(prob), + E=0., g_max=we, tau=5., method=method) + self.E2I = bp.dyn.ExpCOBA(self.E, self.I, bp.conn.FixedProb(prob), + E=0., g_max=we, tau=5., method=method) + self.I2E = bp.dyn.ExpCOBA(self.I, self.E, bp.conn.FixedProb(prob), + E=-80., g_max=wi, tau=10., method=method) + self.I2I = bp.dyn.ExpCOBA(self.I, self.I, bp.conn.FixedProb(prob), + E=-80., g_max=wi, tau=10., method=method) net = EINet(scale=1., method='exp_auto') # simulation -runner = bp.dyn.DSRunner(net, - monitors=['E.spike'], - inputs=[('E.input', 20.), ('I.input', 20.)]) -runner.run(100.) +runner = bp.dyn.DSRunner( + net, + monitors={'E.spike': net.E.spike}, + inputs=[(net.E.input, 20.), (net.I.input, 20.)] +) +runner.run(1000.) # visualization bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], show=True) diff --git a/examples/simulation/Brette_2007_COBAHH.py b/examples/simulation/Brette_2007_COBAHH.py index ce4768dae..4e824c328 100644 --- a/examples/simulation/Brette_2007_COBAHH.py +++ b/examples/simulation/Brette_2007_COBAHH.py @@ -19,21 +19,22 @@ def __init__(self, scale=1.): super(EINet, self).__init__() self.E = HH(int(3200 * scale)) self.I = HH(int(800 * scale)) - self.E2E = synapses.Exponential(self.E, self.E, bp.conn.FixedProb(prob=0.02), + prob = 0.02 + self.E2E = synapses.Exponential(self.E, self.E, bp.conn.FixedProb(prob), g_max=0.03 / scale, tau=5, output=synouts.COBA(E=0.)) - self.E2I = synapses.Exponential(self.E, self.I, bp.conn.FixedProb(prob=0.02), + self.E2I = synapses.Exponential(self.E, self.I, bp.conn.FixedProb(prob), g_max=0.03 / scale, tau=5., output=synouts.COBA(E=0.)) - self.I2E = synapses.Exponential(self.I, self.E, bp.conn.FixedProb(prob=0.02), + self.I2E = synapses.Exponential(self.I, self.E, bp.conn.FixedProb(prob), g_max=0.335 / scale, tau=10., output=synouts.COBA(E=-80)) - self.I2I = synapses.Exponential(self.I, self.I, bp.conn.FixedProb(prob=0.02), + self.I2I = synapses.Exponential(self.I, self.I, bp.conn.FixedProb(prob), g_max=0.335 / scale, tau=10., output=synouts.COBA(E=-80.)) net = EINet(scale=1) -runner = bp.dyn.DSRunner(net, monitors=['E.spike']) +runner = bp.dyn.DSRunner(net, monitors={'E.spike': net.E.spike}) runner.run(100.) bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], show=True) diff --git a/examples/simulation/JR_1995_jansen_rit_model.py b/examples/simulation/JR_1995_jansen_rit_model.py index 9f677dd7e..a80660cc1 100644 --- a/examples/simulation/JR_1995_jansen_rit_model.py +++ b/examples/simulation/JR_1995_jansen_rit_model.py @@ -100,7 +100,8 @@ def dy4(self, y4, t, y0, y1, p): def dy5(self, y5, t, y0, y2): return (self.B * self.C4 * self.sigmoid(self.C3 * y0) - 2 * y5 - y2 / self.tau_i) / self.tau_i - def update(self, t, dt): + def update(self, tdi): + t, dt = tdi['t'], tdi['dt'] self.y0.value, self.y1.value, self.y2.value, self.y3.value, self.y4.value, self.y5.value = \ self.integral(self.y0, self.y1, self.y2, self.y3, self.y4, self.y5, t, p=self.p, dt=dt) diff --git a/examples/simulation/Li_2017_unified_thalamus_oscillation_model.py b/examples/simulation/Li_2017_unified_thalamus_oscillation_model.py index 0d93457a7..01cf44193 100644 --- a/examples/simulation/Li_2017_unified_thalamus_oscillation_model.py +++ b/examples/simulation/Li_2017_unified_thalamus_oscillation_model.py @@ -99,7 +99,7 @@ def __init__(self, size, gKL=0.01, V_initializer=bp.init.OneInit(-70.), ): IL=IL, IKL=IKL, INa=INa, IDR=IDR, Ca=Ca) -class MgBlock(bp.dyn.SynapseOutput): +class MgBlock(bp.dyn.SynOutput): def __init__(self, E=0.): super(MgBlock, self).__init__() self.E = E @@ -111,9 +111,13 @@ def filter(self, g): class Thalamus(bp.dyn.Network): def __init__( - self, g_input: Dict[str, float], g_KL: Dict[str, float], - HTC_V_init=bp.init.OneInit(-65.), RTC_V_init=bp.init.OneInit(-65.), - IN_V_init=bp.init.OneInit(-70.), RE_V_init=bp.init.OneInit(-70.), + self, + g_input: Dict[str, float], + g_KL: Dict[str, float], + HTC_V_init=bp.init.OneInit(-65.), + RTC_V_init=bp.init.OneInit(-65.), + IN_V_init=bp.init.OneInit(-70.), + RE_V_init=bp.init.OneInit(-70.), ): super(Thalamus, self).__init__() @@ -144,19 +148,19 @@ def __init__( # HTC cells were connected with gap junctions self.gj_HTC = synapses.GapJunction(self.HTC, self.HTC, bp.conn.ProbDist(dist=2., prob=0.3, ), - conn_type='sparse', + comp_method='sparse', g_max=1e-2) # HTC provides feedforward excitation to INs self.HTC2IN_ampa = synapses.AMPA(self.HTC, self.IN, bp.conn.FixedProb(0.3), delay_step=int(2 / bm.get_dt()), - plasticity=synplast.STD(tau=700, U=0.07), + stp=synplast.STD(tau=700, U=0.07), alpha=0.94, beta=0.18, g_max=6e-3) self.HTC2IN_nmda = synapses.AMPA(self.HTC, self.IN, bp.conn.FixedProb(0.3), delay_step=int(2 / bm.get_dt()), - plasticity=synplast.STD(tau=700, U=0.07), + stp=synplast.STD(tau=700, U=0.07), output=MgBlock(), alpha=1., beta=0.0067, @@ -165,7 +169,7 @@ def __init__( # INs delivered feedforward inhibition to RTC cells self.IN2RTC = synapses.GABAa(self.IN, self.RTC, bp.conn.FixedProb(0.3), delay_step=int(2 / bm.get_dt()), - plasticity=synplast.STD(tau=700, U=0.07), + stp=synplast.STD(tau=700, U=0.07), output=synouts.COBA(E=-80), alpha=10.5, beta=0.166, @@ -174,47 +178,47 @@ def __init__( # 20% RTC cells electrically connected with HTC cells self.gj_RTC2HTC = synapses.GapJunction(self.RTC, self.HTC, bp.conn.ProbDist(dist=2., prob=0.3, pre_ratio=0.2), - conn_type='sparse', + comp_method='sparse', g_max=1 / 300) # Both HTC and RTC cells sent glutamatergic synapses to RE neurons, while # receiving GABAergic feedback inhibition from the RE population self.HTC2RE_ampa = synapses.AMPA(self.HTC, self.RE, bp.conn.FixedProb(0.2), delay_step=int(2 / bm.get_dt()), - plasticity=synplast.STD(tau=700, U=0.07), + stp=synplast.STD(tau=700, U=0.07), alpha=0.94, beta=0.18, g_max=4e-3) self.RTC2RE_ampa = synapses.AMPA(self.RTC, self.RE, bp.conn.FixedProb(0.2), delay_step=int(2 / bm.get_dt()), - plasticity=synplast.STD(tau=700, U=0.07), + stp=synplast.STD(tau=700, U=0.07), alpha=0.94, beta=0.18, g_max=4e-3) self.HTC2RE_nmda = synapses.AMPA(self.HTC, self.RE, bp.conn.FixedProb(0.2), delay_step=int(2 / bm.get_dt()), - plasticity=synplast.STD(tau=700, U=0.07), + stp=synplast.STD(tau=700, U=0.07), output=MgBlock(), alpha=1., beta=0.0067, g_max=2e-3) self.RTC2RE_nmda = synapses.AMPA(self.RTC, self.RE, bp.conn.FixedProb(0.2), delay_step=int(2 / bm.get_dt()), - plasticity=synplast.STD(tau=700, U=0.07), + stp=synplast.STD(tau=700, U=0.07), output=MgBlock(), alpha=1., beta=0.0067, g_max=2e-3) self.RE2HTC = synapses.GABAa(self.RE, self.HTC, bp.conn.FixedProb(0.2), delay_step=int(2 / bm.get_dt()), - plasticity=synplast.STD(tau=700, U=0.07), + stp=synplast.STD(tau=700, U=0.07), output=synouts.COBA(E=-80), alpha=10.5, beta=0.166, g_max=3e-3) self.RE2RTC = synapses.GABAa(self.RE, self.RTC, bp.conn.FixedProb(0.2), delay_step=int(2 / bm.get_dt()), - plasticity=synplast.STD(tau=700, U=0.07), + stp=synplast.STD(tau=700, U=0.07), output=synouts.COBA(E=-80), alpha=10.5, beta=0.166, @@ -223,11 +227,11 @@ def __init__( # RE neurons were connected with both gap junctions and GABAergic synapses self.gj_RE = synapses.GapJunction(self.RE, self.RE, bp.conn.ProbDist(dist=2., prob=0.3, pre_ratio=0.2), - conn_type='sparse', + comp_method='sparse', g_max=1 / 300) self.RE2RE = synapses.GABAa(self.RE, self.RE, bp.conn.FixedProb(0.2), delay_step=int(2 / bm.get_dt()), - plasticity=synplast.STD(tau=700, U=0.07), + stp=synplast.STD(tau=700, U=0.07), output=synouts.COBA(E=-70), alpha=10.5, beta=0.166, g_max=1e-3) @@ -236,7 +240,7 @@ def __init__( # probability (0.05) was used for the RE->IN synapses according to experimental data self.RE2IN = synapses.GABAa(self.RE, self.IN, bp.conn.FixedProb(0.05, pre_ratio=0.1), delay_step=int(2 / bm.get_dt()), - plasticity=synplast.STD(tau=700, U=0.07), + stp=synplast.STD(tau=700, U=0.07), output=synouts.COBA(E=-80), alpha=10.5, beta=0.166, g_max=1e-3, ) diff --git a/examples/simulation/Vreeswijk_1996_EI_net.py b/examples/simulation/Vreeswijk_1996_EI_net.py new file mode 100644 index 000000000..168eb1e6c --- /dev/null +++ b/examples/simulation/Vreeswijk_1996_EI_net.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +import brainpy as bp +import brainpy.math as bm + +bm.set_platform('cpu') + + +class EINet(bp.dyn.Network): + def __init__(self, num_exc, num_inh, prob, JE, JI): + # neurons + pars = dict(V_rest=-52., V_th=-50., V_reset=-60., tau=10., tau_ref=0.) + E = bp.neurons.LIF(num_exc, **pars) + I = bp.neurons.LIF(num_inh, **pars) + E.V[:] = bm.random.random(num_exc) * (E.V_th - E.V_rest) + E.V_rest + I.V[:] = bm.random.random(num_inh) * (E.V_th - E.V_rest) + E.V_rest + + # synapses + E2E = bp.synapses.Exponential(E, E, bp.conn.FixedProb(prob), g_max=JE, tau=2., + output=bp.synouts.CUBA()) + E2I = bp.synapses.Exponential(E, I, bp.conn.FixedProb(prob), g_max=JE, tau=2., + output=bp.synouts.CUBA()) + I2E = bp.synapses.Exponential(I, E, bp.conn.FixedProb(prob), g_max=JI, tau=2., + output=bp.synouts.CUBA()) + I2I = bp.synapses.Exponential(I, I, bp.conn.FixedProb(prob), g_max=JI, tau=2., + output=bp.synouts.CUBA()) + + super(EINet, self).__init__(E2E, E2I, I2E, I2I, E=E, I=I) + + +num_exc = 500 +num_inh = 500 +prob = 0.5 + +Ib = 3. +JE = 1 / bp.math.sqrt(prob * num_exc) +JI = -1 / bp.math.sqrt(prob * num_inh) + +net = EINet(num_exc, num_inh, prob=prob, JE=JE, JI=JI) + +runner = bp.dyn.DSRunner(net, + monitors=['E.spike'], + inputs=[('E.input', Ib), ('I.input', Ib)]) +t = runner.run(1000.) + +import matplotlib.pyplot as plt + +fig, gs = bp.visualize.get_figure(4, 1, 2, 10) + +fig.add_subplot(gs[:3, 0]) +bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], xlim=(50, 950)) + +fig.add_subplot(gs[3, 0]) +rates = bp.measure.firing_rate(runner.mon['E.spike'], 5.) +plt.plot(runner.mon.ts, rates) +plt.xlim(50, 950) +plt.show() diff --git a/examples/simulation/Wang_2002_decision_making_spiking.py b/examples/simulation/Wang_2002_decision_making_spiking.py index 85dee408f..5b6436f69 100644 --- a/examples/simulation/Wang_2002_decision_making_spiking.py +++ b/examples/simulation/Wang_2002_decision_making_spiking.py @@ -1,51 +1,17 @@ # -*- coding: utf-8 -*- +import matplotlib.pyplot as plt + import brainpy as bp import brainpy.math as bm +from brainpy.dyn import synapses, synouts bm.set_platform('cpu') -import matplotlib.pyplot as plt - - -class LIF(bp.dyn.NeuGroup): - def __init__(self, size, V_L=-70., V_reset=-55., V_th=-50., - Cm=0.5, gL=0.025, t_refractory=2.): - super(LIF, self).__init__(size=size) - - # parameters - self.V_L = V_L - self.V_reset = V_reset - self.V_th = V_th - self.Cm = Cm - self.gL = gL - self.t_refractory = t_refractory - - # variables - self.V = bm.Variable(bm.ones(self.num) * V_L) - self.input = bm.Variable(bm.zeros(self.num)) - self.spike = bm.Variable(bm.zeros(self.num, dtype=bool)) - self.refractory = bm.Variable(bm.zeros(self.num, dtype=bool)) - self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7) - self.I = bm.Variable(bm.zeros(self.num)) - - # functions - self.integral = bp.odeint(lambda V, t: (- self.gL * (V - self.V_L) + self.input) / self.Cm) - - def update(self, t, dt): - ref = (t - self.t_last_spike) <= self.t_refractory - V = self.integral(self.V, t, dt) - V = bm.where(ref, self.V, V) - spike = (V >= self.V_th) - self.V.value = bm.where(spike, self.V_reset, V) - self.spike.value = spike - self.t_last_spike.value = bm.where(spike, t, self.t_last_spike) - self.refractory.value = bm.logical_or(spike, ref) - self.input[:] = 0. class PoissonStim(bp.dyn.NeuGroup): - def __init__(self, size, freq_mean, freq_var, t_interval, **kwargs): - super(PoissonStim, self).__init__(size=size, **kwargs) + def __init__(self, size, freq_mean, freq_var, t_interval, trainable=False): + super(PoissonStim, self).__init__(size=size, trainable=trainable) # parameters self.freq_mean = freq_mean @@ -54,22 +20,30 @@ def __init__(self, size, freq_mean, freq_var, t_interval, **kwargs): self.dt = bm.get_dt() / 1000. # variables - self.freq = bm.Variable(bm.zeros(1)) - self.freq_t_last_change = bm.Variable(bm.ones(1) * -1e7) - self.spike = bm.Variable(bm.zeros(self.num, dtype=bool)) + self.freq = bp.init.variable(bm.zeros, trainable, 1) + self.freq_t_last_change = bp.init.variable(lambda s: bm.ones(s) * -1e7, trainable, 1) + self.spike = bp.init.variable(lambda s: bm.zeros(s, dtype=bool), trainable, self.varshape) self.rng = bm.random.RandomState() - def update(self, t, dt): + def reset_state(self, batch_size=None): + self.freq.value = bp.init.variable(bm.zeros, batch_size, 1) + self.freq_t_last_change.value = bp.init.variable(lambda s: bm.ones(s) * -1e7, batch_size, 1) + self.spike.value = bp.init.variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape) + + def update(self, tdi): + t, dt = tdi['t'], tdi['dt'] in_interval = bm.logical_and(pre_stimulus_period < t, t < pre_stimulus_period + stimulus_period) - prev_freq = bm.where(in_interval, self.freq[0], 0.) - in_interval = bm.logical_and(in_interval, (t - self.freq_t_last_change[0]) >= self.t_interval) - self.freq[0] = bm.where(in_interval, self.rng.normal(self.freq_mean, self.freq_var), prev_freq) - self.freq_t_last_change[0] = bm.where(in_interval, t, self.freq_t_last_change[0]) - self.spike.value = self.rng.random(self.num) < self.freq[0] * self.dt + in_interval = bm.ones_like(self.freq, dtype=bool) * in_interval + prev_freq = bm.where(in_interval, self.freq, 0.) + in_interval = bm.logical_and(in_interval, (t - self.freq_t_last_change) >= self.t_interval) + self.freq.value = bm.where(in_interval, self.rng.normal(self.freq_mean, self.freq_var, self.freq.shape), prev_freq) + self.freq_t_last_change.value = bm.where(in_interval, t, self.freq_t_last_change) + shape = (self.spike.shape[:1] + self.varshape) if self.trainable else self.varshape + self.spike.value = self.rng.random(shape) < self.freq * self.dt class DecisionMaking(bp.dyn.Network): - def __init__(self, scale=1., mu0=40., coherence=25.6, f=0.15): + def __init__(self, scale=1., mu0=40., coherence=25.6, f=0.15, batching=False): super(DecisionMaking, self).__init__() num_exc = int(1600 * scale) @@ -91,72 +65,136 @@ def __init__(self, scale=1., mu0=40., coherence=25.6, f=0.15): g_I2E_GABAa = 1.3 / scale # nS g_I2I_GABAa = 1.0 / scale # nS - ampa_par = dict(delay_step=int(0.5 / bm.get_dt()), E=0., tau=2.0) - gaba_par = dict(delay_step=int(0.5 / bm.get_dt()), E=-70., tau=5.0) - nmda_par = dict(delay_step=int(0.5 / bm.get_dt()), tau_decay=100, tau_rise=2., E=0., cc_Mg=1., a=0.5) + ampa_par = dict(delay_step=int(0.5 / bm.get_dt()), tau=2.0) + gaba_par = dict(delay_step=int(0.5 / bm.get_dt()), tau=5.0) + nmda_par = dict(delay_step=int(0.5 / bm.get_dt()), tau_decay=100, tau_rise=2., a=0.5) # E neurons/pyramid neurons - A = LIF(num_A, Cm=500., gL=25., t_refractory=2.) - B = LIF(num_B, Cm=500., gL=25., t_refractory=2.) - N = LIF(num_N, Cm=500., gL=25., t_refractory=2.) + A = bp.dyn.LIF(num_A, V_rest=-70., V_reset=-55., V_th=-50., tau=20., R=0.04, + tau_ref=2., V_initializer=bp.init.OneInit(-70.), trainable=batching) + B = bp.dyn.LIF(num_B, V_rest=-70., V_reset=-55., V_th=-50., tau=20., R=0.04, + tau_ref=2., V_initializer=bp.init.OneInit(-70.), trainable=batching) + N = bp.dyn.LIF(num_N, V_rest=-70., V_reset=-55., V_th=-50., tau=20., R=0.04, + tau_ref=2., V_initializer=bp.init.OneInit(-70.), trainable=batching) # I neurons/interneurons - I = LIF(num_inh, Cm=200., gL=20., t_refractory=1.) + I = bp.dyn.LIF(num_inh, V_rest=-70., V_reset=-55., V_th=-50., tau=10., R=0.05, + tau_ref=1., V_initializer=bp.init.OneInit(-70.), trainable=batching) # poisson stimulus - IA = PoissonStim(num_A, freq_var=10., t_interval=50., freq_mean=mu0 + mu0 / 100. * coherence) - IB = PoissonStim(num_B, freq_var=10., t_interval=50., freq_mean=mu0 - mu0 / 100. * coherence) + IA = PoissonStim(num_A, freq_var=10., t_interval=50., freq_mean=mu0 + mu0 / 100. * coherence, trainable=batching) + IB = PoissonStim(num_B, freq_var=10., t_interval=50., freq_mean=mu0 - mu0 / 100. * coherence, trainable=batching) # noise neurons - self.noise_B = bp.dyn.PoissonGroup(num_B, freqs=poisson_freq) - self.noise_A = bp.dyn.PoissonGroup(num_A, freqs=poisson_freq) - self.noise_N = bp.dyn.PoissonGroup(num_N, freqs=poisson_freq) - self.noise_I = bp.dyn.PoissonGroup(num_inh, freqs=poisson_freq) + self.noise_B = bp.dyn.PoissonGroup(num_B, freqs=poisson_freq, trainable=batching) + self.noise_A = bp.dyn.PoissonGroup(num_A, freqs=poisson_freq, trainable=batching) + self.noise_N = bp.dyn.PoissonGroup(num_N, freqs=poisson_freq, trainable=batching) + self.noise_I = bp.dyn.PoissonGroup(num_inh, freqs=poisson_freq, trainable=batching) # define external inputs - self.IA2A = bp.dyn.ExpCOBA(IA, A, bp.conn.One2One(), g_max=g_ext2E_AMPA, **ampa_par) - self.IB2B = bp.dyn.ExpCOBA(IB, B, bp.conn.One2One(), g_max=g_ext2E_AMPA, **ampa_par) + self.IA2A = synapses.Exponential(IA, A, bp.conn.One2One(), g_max=g_ext2E_AMPA, + trainable=batching, output=synouts.COBA(E=0.), + **ampa_par) + self.IB2B = synapses.Exponential(IB, B, bp.conn.One2One(), g_max=g_ext2E_AMPA, + trainable=batching, output=synouts.COBA(E=0.), + **ampa_par) # define E->E/I conn - self.N2B_AMPA = bp.dyn.ExpCOBA(N, B, bp.conn.All2All(), g_max=g_E2E_AMPA * w_neg, **ampa_par) - self.N2A_AMPA = bp.dyn.ExpCOBA(N, A, bp.conn.All2All(), g_max=g_E2E_AMPA * w_neg, **ampa_par) - self.N2N_AMPA = bp.dyn.ExpCOBA(N, N, bp.conn.All2All(), g_max=g_E2E_AMPA, **ampa_par) - self.N2I_AMPA = bp.dyn.ExpCOBA(N, I, bp.conn.All2All(), g_max=g_E2I_AMPA, **ampa_par) - self.N2B_NMDA = bp.dyn.NMDA(N, B, bp.conn.All2All(), g_max=g_E2E_NMDA * w_neg, **nmda_par) - self.N2A_NMDA = bp.dyn.NMDA(N, A, bp.conn.All2All(), g_max=g_E2E_NMDA * w_neg, **nmda_par) - self.N2N_NMDA = bp.dyn.NMDA(N, N, bp.conn.All2All(), g_max=g_E2E_NMDA, **nmda_par) - self.N2I_NMDA = bp.dyn.NMDA(N, I, bp.conn.All2All(), g_max=g_E2I_NMDA, **nmda_par) - - self.B2B_AMPA = bp.dyn.ExpCOBA(B, B, bp.conn.All2All(), g_max=g_E2E_AMPA * w_pos, **ampa_par) - self.B2A_AMPA = bp.dyn.ExpCOBA(B, A, bp.conn.All2All(), g_max=g_E2E_AMPA * w_neg, **ampa_par) - self.B2N_AMPA = bp.dyn.ExpCOBA(B, N, bp.conn.All2All(), g_max=g_E2E_AMPA, **ampa_par) - self.B2I_AMPA = bp.dyn.ExpCOBA(B, I, bp.conn.All2All(), g_max=g_E2I_AMPA, **ampa_par) - self.B2B_NMDA = bp.dyn.NMDA(B, B, bp.conn.All2All(), g_max=g_E2E_NMDA * w_pos, **nmda_par) - self.B2A_NMDA = bp.dyn.NMDA(B, A, bp.conn.All2All(), g_max=g_E2E_NMDA * w_neg, **nmda_par) - self.B2N_NMDA = bp.dyn.NMDA(B, N, bp.conn.All2All(), g_max=g_E2E_NMDA, **nmda_par) - self.B2I_NMDA = bp.dyn.NMDA(B, I, bp.conn.All2All(), g_max=g_E2I_NMDA, **nmda_par) - - self.A2B_AMPA = bp.dyn.ExpCOBA(A, B, bp.conn.All2All(), g_max=g_E2E_AMPA * w_neg, **ampa_par) - self.A2A_AMPA = bp.dyn.ExpCOBA(A, A, bp.conn.All2All(), g_max=g_E2E_AMPA * w_pos, **ampa_par) - self.A2N_AMPA = bp.dyn.ExpCOBA(A, N, bp.conn.All2All(), g_max=g_E2E_AMPA, **ampa_par) - self.A2I_AMPA = bp.dyn.ExpCOBA(A, I, bp.conn.All2All(), g_max=g_E2I_AMPA, **ampa_par) - self.A2B_NMDA = bp.dyn.NMDA(A, B, bp.conn.All2All(), g_max=g_E2E_NMDA * w_neg, **nmda_par) - self.A2A_NMDA = bp.dyn.NMDA(A, A, bp.conn.All2All(), g_max=g_E2E_NMDA * w_pos, **nmda_par) - self.A2N_NMDA = bp.dyn.NMDA(A, N, bp.conn.All2All(), g_max=g_E2E_NMDA, **nmda_par) - self.A2I_NMDA = bp.dyn.NMDA(A, I, bp.conn.All2All(), g_max=g_E2I_NMDA, **nmda_par) + self.N2B_AMPA = synapses.Exponential(N, B, bp.conn.All2All(), g_max=g_E2E_AMPA * w_neg, + output=synouts.COBA(E=0.), trainable=batching, **ampa_par) + self.N2A_AMPA = synapses.Exponential(N, A, bp.conn.All2All(), g_max=g_E2E_AMPA * w_neg, + output=synouts.COBA(E=0.), trainable=batching, **ampa_par) + self.N2N_AMPA = synapses.Exponential(N, N, bp.conn.All2All(), g_max=g_E2E_AMPA, + output=synouts.COBA(E=0.), trainable=batching, **ampa_par) + self.N2I_AMPA = synapses.Exponential(N, I, bp.conn.All2All(), g_max=g_E2I_AMPA, + output=synouts.COBA(E=0.), trainable=batching, **ampa_par) + self.N2B_NMDA = bp.dyn.NMDA(N, B, bp.conn.All2All(), g_max=g_E2E_NMDA * w_neg, + output=synouts.MgBlock(E=0., cc_Mg=1.), + trainable=batching, **nmda_par) + self.N2A_NMDA = bp.dyn.NMDA(N, A, bp.conn.All2All(), g_max=g_E2E_NMDA * w_neg, + output=synouts.MgBlock(E=0., cc_Mg=1.), + trainable=batching, **nmda_par) + self.N2N_NMDA = bp.dyn.NMDA(N, N, bp.conn.All2All(), g_max=g_E2E_NMDA, + output=synouts.MgBlock(E=0., cc_Mg=1.), + trainable=batching, **nmda_par) + self.N2I_NMDA = bp.dyn.NMDA(N, I, bp.conn.All2All(), g_max=g_E2I_NMDA, + output=synouts.MgBlock(E=0., cc_Mg=1.), + trainable=batching, **nmda_par) + + self.B2B_AMPA = synapses.Exponential(B, B, bp.conn.All2All(), g_max=g_E2E_AMPA * w_pos, + output=synouts.COBA(E=0.), + trainable=batching, **ampa_par) + self.B2A_AMPA = synapses.Exponential(B, A, bp.conn.All2All(), g_max=g_E2E_AMPA * w_neg, + output=synouts.COBA(E=0.), + trainable=batching, **ampa_par) + self.B2N_AMPA = synapses.Exponential(B, N, bp.conn.All2All(), g_max=g_E2E_AMPA, + output=synouts.COBA(E=0.), + trainable=batching, **ampa_par) + self.B2I_AMPA = synapses.Exponential(B, I, bp.conn.All2All(), g_max=g_E2I_AMPA, + output=synouts.COBA(E=0.), + trainable=batching, **ampa_par) + self.B2B_NMDA = synapses.NMDA(B, B, bp.conn.All2All(), g_max=g_E2E_NMDA * w_pos, + output=synouts.MgBlock(E=0., cc_Mg=1.), + trainable=batching, **nmda_par) + self.B2A_NMDA = synapses.NMDA(B, A, bp.conn.All2All(), g_max=g_E2E_NMDA * w_neg, + output=synouts.MgBlock(E=0., cc_Mg=1.), + trainable=batching, **nmda_par) + self.B2N_NMDA = synapses.NMDA(B, N, bp.conn.All2All(), g_max=g_E2E_NMDA, + output=synouts.MgBlock(E=0., cc_Mg=1.), + trainable=batching, **nmda_par) + self.B2I_NMDA = synapses.NMDA(B, I, bp.conn.All2All(), g_max=g_E2I_NMDA, + output=synouts.MgBlock(E=0., cc_Mg=1.), + trainable=batching, **nmda_par) + + self.A2B_AMPA = synapses.Exponential(A, B, bp.conn.All2All(), g_max=g_E2E_AMPA * w_neg, + output=synouts.COBA(E=0.), + trainable=batching, **ampa_par) + self.A2A_AMPA = synapses.Exponential(A, A, bp.conn.All2All(), g_max=g_E2E_AMPA * w_pos, + output=synouts.COBA(E=0.), + trainable=batching, **ampa_par) + self.A2N_AMPA = synapses.Exponential(A, N, bp.conn.All2All(), g_max=g_E2E_AMPA, + output=synouts.COBA(E=0.), + trainable=batching, **ampa_par) + self.A2I_AMPA = synapses.Exponential(A, I, bp.conn.All2All(), g_max=g_E2I_AMPA, + output=synouts.COBA(E=0.), + trainable=batching, **ampa_par) + self.A2B_NMDA = synapses.NMDA(A, B, bp.conn.All2All(), g_max=g_E2E_NMDA * w_neg, + output=synouts.MgBlock(E=0., cc_Mg=1.), + trainable=batching, **nmda_par) + self.A2A_NMDA = synapses.NMDA(A, A, bp.conn.All2All(), g_max=g_E2E_NMDA * w_pos, + output=synouts.MgBlock(E=0., cc_Mg=1.), + trainable=batching, **nmda_par) + self.A2N_NMDA = synapses.NMDA(A, N, bp.conn.All2All(), g_max=g_E2E_NMDA, + output=synouts.MgBlock(E=0., cc_Mg=1.), + trainable=batching, **nmda_par) + self.A2I_NMDA = synapses.NMDA(A, I, bp.conn.All2All(), g_max=g_E2I_NMDA, + output=synouts.MgBlock(E=0., cc_Mg=1.), + trainable=batching, **nmda_par) # define I->E/I conn - self.I2B = bp.dyn.ExpCOBA(I, B, bp.conn.All2All(), g_max=g_I2E_GABAa, **gaba_par) - self.I2A = bp.dyn.ExpCOBA(I, A, bp.conn.All2All(), g_max=g_I2E_GABAa, **gaba_par) - self.I2N = bp.dyn.ExpCOBA(I, N, bp.conn.All2All(), g_max=g_I2E_GABAa, **gaba_par) - self.I2I = bp.dyn.ExpCOBA(I, I, bp.conn.All2All(), g_max=g_I2I_GABAa, **gaba_par) + self.I2B = synapses.Exponential(I, B, bp.conn.All2All(), g_max=g_I2E_GABAa, + output=synouts.COBA(E=-70.), + trainable=batching, **gaba_par) + self.I2A = synapses.Exponential(I, A, bp.conn.All2All(), g_max=g_I2E_GABAa, + output=synouts.COBA(E=-70.), + trainable=batching, **gaba_par) + self.I2N = synapses.Exponential(I, N, bp.conn.All2All(), g_max=g_I2E_GABAa, + output=synouts.COBA(E=-70.), + trainable=batching, **gaba_par) + self.I2I = synapses.Exponential(I, I, bp.conn.All2All(), g_max=g_I2I_GABAa, + output=synouts.COBA(E=-70.), + trainable=batching, **gaba_par) # define external projections - self.noise2B = bp.dyn.ExpCOBA(self.noise_B, B, bp.conn.One2One(), g_max=g_ext2E_AMPA, **ampa_par) - self.noise2A = bp.dyn.ExpCOBA(self.noise_A, A, bp.conn.One2One(), g_max=g_ext2E_AMPA, **ampa_par) - self.noise2N = bp.dyn.ExpCOBA(self.noise_N, N, bp.conn.One2One(), g_max=g_ext2E_AMPA, **ampa_par) - self.noise2I = bp.dyn.ExpCOBA(self.noise_I, I, bp.conn.One2One(), g_max=g_ext2I_AMPA, **ampa_par) + self.noise2B = synapses.Exponential(self.noise_B, B, bp.conn.One2One(), g_max=g_ext2E_AMPA, + output=synouts.COBA(E=0.), trainable=batching, **ampa_par) + self.noise2A = synapses.Exponential(self.noise_A, A, bp.conn.One2One(), g_max=g_ext2E_AMPA, + output=synouts.COBA(E=0.), trainable=batching, **ampa_par) + self.noise2N = synapses.Exponential(self.noise_N, N, bp.conn.One2One(), g_max=g_ext2E_AMPA, + output=synouts.COBA(E=0.), trainable=batching, **ampa_par) + self.noise2I = synapses.Exponential(self.noise_I, I, bp.conn.One2One(), g_max=g_ext2I_AMPA, + output=synouts.COBA(E=0.), trainable=batching, **ampa_par) # nodes self.B = B @@ -167,62 +205,110 @@ def __init__(self, scale=1., mu0=40., coherence=25.6, f=0.15): self.IB = IB -net = DecisionMaking(scale=1., coherence=0., mu0=50.) +def visualize_raster(ax, mon, t_start=0., title=None): + bp.visualize.raster_plot(mon['ts'], mon['A.spike'], markersize=1, ax=ax, color='', label="Group A") + bp.visualize.raster_plot(mon['ts'], mon['B.spike'], markersize=1, ax=ax, color='', label="Group B") + if title: + ax.set_title(title) + ax.set_ylabel("Neuron Index") + ax.set_xlim(t_start, total_period + 1) + ax.axvline(pre_stimulus_period, linestyle='dashed') + ax.axvline(pre_stimulus_period + stimulus_period, linestyle='dashed') + ax.axvline(pre_stimulus_period + stimulus_period + delay_period, linestyle='dashed') + ax.legend() + + +def visualize_results(axes, mon, t_start=0., title=None): + ax = axes[0] + bp.visualize.raster_plot(mon['ts'], mon['A.spike'], markersize=1, ax=ax) + if title: + ax.set_title(title) + ax.set_ylabel("Group A") + ax.set_xlim(t_start, total_period + 1) + ax.axvline(pre_stimulus_period, linestyle='dashed') + ax.axvline(pre_stimulus_period + stimulus_period, linestyle='dashed') + ax.axvline(pre_stimulus_period + stimulus_period + delay_period, linestyle='dashed') + + ax = axes[1] + bp.visualize.raster_plot(mon['ts'], mon['B.spike'], markersize=1, ax=ax) + ax.set_ylabel("Group B") + ax.set_xlim(t_start, total_period + 1) + ax.axvline(pre_stimulus_period, linestyle='dashed') + ax.axvline(pre_stimulus_period + stimulus_period, linestyle='dashed') + ax.axvline(pre_stimulus_period + stimulus_period + delay_period, linestyle='dashed') + + ax = axes[2] + rateA = bp.measure.firing_rate(mon['A.spike'], width=10.) + rateB = bp.measure.firing_rate(mon['B.spike'], width=10.) + ax.plot(mon['ts'], rateA, label="Group A") + ax.plot(mon['ts'], rateB, label="Group B") + ax.set_ylabel('Population activity [Hz]') + ax.set_xlim(t_start, total_period + 1) + ax.axvline(pre_stimulus_period, linestyle='dashed') + ax.axvline(pre_stimulus_period + stimulus_period, linestyle='dashed') + ax.axvline(pre_stimulus_period + stimulus_period + delay_period, linestyle='dashed') + ax.legend() + + ax = axes[3] + ax.plot(mon['ts'], mon['IA.freq'], label="group A") + ax.plot(mon['ts'], mon['IB.freq'], label="group B") + ax.set_ylabel("Input activity [Hz]") + ax.set_xlim(t_start, total_period + 1) + ax.axvline(pre_stimulus_period, linestyle='dashed') + ax.axvline(pre_stimulus_period + stimulus_period, linestyle='dashed') + ax.axvline(pre_stimulus_period + stimulus_period + delay_period, linestyle='dashed') + ax.legend() + ax.set_xlabel("Time [ms]") + -runner = bp.dyn.DSRunner( - net, monitors=['A.spike', 'B.spike', 'IA.freq', 'IB.freq'] -) pre_stimulus_period = 100. stimulus_period = 1000. delay_period = 500. total_period = pre_stimulus_period + stimulus_period + delay_period -t = runner(total_period) -print(f'Used time: {t} s') - -fig, gs = bp.visualize.get_figure(4, 1, 3, 10) - -t_start = 0. -fig.add_subplot(gs[0, 0]) -bp.visualize.raster_plot(runner.mon.ts, runner.mon['A.spike'], markersize=1) -plt.title("Spiking activity of group A") -plt.ylabel("Neuron Index") -plt.xlim(t_start, total_period + 1) -plt.axvline(pre_stimulus_period, linestyle='dashed') -plt.axvline(pre_stimulus_period + stimulus_period, linestyle='dashed') -plt.axvline(pre_stimulus_period + stimulus_period + delay_period, linestyle='dashed') - -fig.add_subplot(gs[1, 0]) -bp.visualize.raster_plot(runner.mon.ts, runner.mon['B.spike'], markersize=1) -plt.title("Spiking activity of group B") -plt.ylabel("Neuron Index") -plt.xlim(t_start, total_period + 1) -plt.axvline(pre_stimulus_period, linestyle='dashed') -plt.axvline(pre_stimulus_period + stimulus_period, linestyle='dashed') -plt.axvline(pre_stimulus_period + stimulus_period + delay_period, linestyle='dashed') - -fig.add_subplot(gs[2, 0]) -rateA = bp.measure.firing_rate(runner.mon['A.spike'], width=10.) -rateB = bp.measure.firing_rate(runner.mon['B.spike'], width=10.) -plt.plot(runner.mon.ts, rateA, label="Group A") -plt.plot(runner.mon.ts, rateB, label="Group B") -plt.ylabel('Firing rate [Hz]') -plt.title("Population activity") -plt.xlim(t_start, total_period + 1) -plt.axvline(pre_stimulus_period, linestyle='dashed') -plt.axvline(pre_stimulus_period + stimulus_period, linestyle='dashed') -plt.axvline(pre_stimulus_period + stimulus_period + delay_period, linestyle='dashed') -plt.legend() - -fig.add_subplot(gs[3, 0]) -plt.plot(runner.mon.ts, runner.mon['IA.freq'], label="group A") -plt.plot(runner.mon.ts, runner.mon['IB.freq'], label="group B") -plt.title("Input activity") -plt.ylabel("Firing rate [Hz]") -plt.xlim(t_start, total_period + 1) -plt.axvline(pre_stimulus_period, linestyle='dashed') -plt.axvline(pre_stimulus_period + stimulus_period, linestyle='dashed') -plt.axvline(pre_stimulus_period + stimulus_period + delay_period, linestyle='dashed') -plt.legend() - -plt.xlabel("Time [ms]") -plt.show() + + +def single_run(): + net = DecisionMaking(scale=1., coherence=-80., mu0=50.) + + runner = bp.dyn.DSRunner( + net, monitors=['A.spike', 'B.spike', 'IA.freq', 'IB.freq'] + ) + runner.run(total_period) + + fig, gs = bp.visualize.get_figure(4, 1, 3, 10) + axes = [fig.add_subplot(gs[i, 0]) for i in range(4)] + visualize_results(axes, mon=runner.mon) + plt.show() + + +def batching_run(): + num_row, num_col = 3, 4 + num_batch = 12 + coherence = bm.expand_dims(bm.linspace(-100, 100., num_batch), 1) + net = DecisionMaking(scale=1., coherence=coherence, mu0=20., batching=True) + net.reset_state(batch_size=num_batch) + + runner = bp.dyn.DSRunner( + net, monitors=['A.spike', 'B.spike', 'IA.freq', 'IB.freq'] + ) + runner.run(total_period) + + coherence = coherence.to_numpy() + fig, gs = bp.visualize.get_figure(num_row, num_col, 3, 4) + for i in range(num_row): + for j in range(num_col): + idx = i * num_col + j + if idx < num_batch: + mon = {'A.spike': runner.mon['A.spike'][:, idx], + 'B.spike': runner.mon['B.spike'][:, idx], + 'IA.freq': runner.mon['IA.freq'][:, idx], + 'IB.freq': runner.mon['IB.freq'][:, idx], + 'ts': runner.mon['ts']} + ax = fig.add_subplot(gs[i, j]) + visualize_raster(ax, mon=mon, title=f'coherence={coherence[idx, 0]}%') + plt.show() + + +if __name__ == '__main__': + # single_run() + batching_run() diff --git a/examples/simulation/hh_model.py b/examples/simulation/hh_model.py index e43b75079..1469df818 100644 --- a/examples/simulation/hh_model.py +++ b/examples/simulation/hh_model.py @@ -18,9 +18,11 @@ def __init__(self, size): I, length = bp.inputs.section_input(values=[0, 5, 0], durations=[100, 500, 100], return_length=True) -runner = bp.dyn.DSRunner(hh, - monitors=['V', 'INa.p', 'INa.q', 'IK.p'], - inputs=['input', I, 'iter']) +runner = bp.dyn.DSRunner( + hh, + monitors=['V', 'INa.p', 'INa.q', 'IK.p'], + inputs=[hh.input, I, 'iter'], +) runner.run(length) bp.visualize.line_plot(runner.mon.ts, runner.mon.V, show=True) diff --git a/examples/simulation/multi_scale_COBAHH.py b/examples/simulation/multi_scale_COBAHH.py new file mode 100644 index 000000000..5168319dd --- /dev/null +++ b/examples/simulation/multi_scale_COBAHH.py @@ -0,0 +1,143 @@ +# -*- coding: utf-8 -*- + +import numpy as np + +import brainpy as bp +import brainpy.math as bm +from brainpy.dyn.channels import INa_TM1991, IL +from brainpy.dyn.synapses import Exponential +from brainpy.dyn.synouts import COBA + + +class IK2(bp.dyn.channels.IK_p4_markov): + def __init__(self, size, E=-90., g_max=10., phi=1., V_sh=-50.): + super(IK2, self).__init__(size, g_max=g_max, phi=phi, E=E) + self.V_sh = V_sh + + def f_p_alpha(self, V): + tmp = V - self.V_sh - 15. + return 0.032 * tmp / (1. - bm.exp(-tmp / 5.)) + + def f_p_beta(self, V): + return 0.5 * bm.exp(-(V - self.V_sh - 10.) / 40.) + + +class IK(bp.dyn.Channel): + def __init__(self, size, E=-90., g_max=10., phi=1., V_sh=-50.): + super(IK, self).__init__(size) + self.g_max, self.E, self.V_sh, self.phi = g_max, E, V_sh, phi + self.p = bm.Variable(bm.zeros(size)) + self.integral = bp.odeint(self.dp, method='exp_euler') + + def dp(self, p, t, V): + tmp = V - self.V_sh - 15. + alpha = 0.032 * tmp / (1. - bm.exp(-tmp / 5.)) + beta = 0.5 * bm.exp(-(V - self.V_sh - 10.) / 40.) + return self.phi * (alpha * (1. - p) - beta * p) + + def update(self, tdi, V): + self.p.value = self.integral(self.p, tdi.t, V, dt=tdi.dt) + + def current(self, V): + return self.g_max * self.p ** 4 * (self.E - V) + + +class HH(bp.dyn.CondNeuGroup): + def __init__(self, size): + super(HH, self).__init__(size, ) + self.INa = INa_TM1991(size, g_max=100., V_sh=-63.) + self.IK = IK(size, g_max=30., V_sh=-63.) + self.IL = IL(size, E=-60., g_max=0.05) + + +class Network(bp.dyn.Network): + def __init__(self, num_E, num_I, ): + super(Network, self).__init__() + self.E = HH(num_E) + self.I = HH(num_I) + self.E2E = Exponential(self.E, self.E, bp.conn.FixedProb(0.02), + g_max=0.03, tau=5, output=COBA(E=0.)) + self.E2I = Exponential(self.E, self.I, bp.conn.FixedProb(0.02), + g_max=0.03, tau=5., output=COBA(E=0.)) + self.I2E = Exponential(self.I, self.E, bp.conn.FixedProb(0.02), + g_max=0.335, tau=10., output=COBA(E=-80)) + self.I2I = Exponential(self.I, self.I, bp.conn.FixedProb(0.02), + g_max=0.335, tau=10., output=COBA(E=-80.)) + + +class Projection(bp.dyn.DynamicalSystem): + def __init__(self, pre, post, delay, conn, g_max=0.03, tau=5.): + super(Projection, self).__init__() + self.pre = pre + self.post = post + + g_max = conn * g_max + self.E2E = Exponential(pre.E, post.E, bp.conn.FixedProb(0.02), + delay_step=delay, g_max=g_max, tau=tau, + output=COBA(0.)) + self.E2I = Exponential(pre.E, post.I, bp.conn.FixedProb(0.02), + delay_step=delay, g_max=g_max, tau=tau, + output=COBA(0.)) + + def update(self, tdi): + self.E2E.update(tdi) + self.E2I.update(tdi) + + +class Circuit(bp.dyn.Network): + def __init__(self, conn, delay): + super(Circuit, self).__init__() + + num_area = conn.shape[0] + self.areas = [Network(3200, 800) for _ in range(num_area)] + self.projections = [] + for i in range(num_area): + for j in range(num_area): + if i != j: + proj = Projection(self.areas[j], self.areas[i], + delay=delay[i, j], conn=conn[i, j]) + self.projections.append(proj) + self.register_implicit_nodes(self.projections, self.areas) + + +bp.math.random.seed(1234) + +data = np.load('./data/visual_conn.npz') +conn_data = data['conn'] +delay_data = (data['delay'] / bm.get_dt()).astype(int) + +circuit = Circuit(conn_data, delay_data) +f1 = lambda tdi: bm.concatenate([area.E.spike for area in circuit.areas]) +f2 = lambda tdi: bm.concatenate([area.I.spike for area in circuit.areas]) +I, duration = bp.inputs.section_input([0, 0.8, 0.], [50., 50., 100.], return_length=True) +runner = bp.dyn.DSRunner( + circuit, + monitors={'K.p': circuit.areas[0].E.IK.p, + 'A0.V': (circuit.areas[0].E.V,), + 'A0.spike': circuit.areas[0].E.spike}, + fun_monitors={'exc.spike': f1, 'inh.spike': f2}, + # inputs=[circuit.areas[0].E.input, I, 'iter'] +) +runner.run(duration) + +fig, gs = bp.visualize.get_figure(2, 1, 4, 10) +fig.add_subplot(gs[0, 0]) +bp.visualize.raster_plot(runner.mon['ts'], runner.mon.get('exc.spike')) +fig.add_subplot(gs[1, 0]) +bp.visualize.raster_plot(runner.mon['ts'], runner.mon.get('inh.spike'), show=True) + +import seaborn as sns + +sns.set_theme(font_scale=1.5) + +fig, gs = bp.visualize.get_figure(1, 1, 4.5, 6) +fig.add_subplot(gs[0, 0]) +bp.visualize.line_plot(runner.mon['ts'], runner.mon['K.p'], show=True, plot_ids=(4, 5, 1)) + +fig, gs = bp.visualize.get_figure(1, 1, 4.5, 6) +fig.add_subplot(gs[0, 0]) +bp.visualize.line_plot(runner.mon['ts'], runner.mon['A0.V'], show=True, plot_ids=(4, 5, 1)) + +fig, gs = bp.visualize.get_figure(1, 1, 4.5, 6) +fig.add_subplot(gs[0, 0]) +bp.visualize.raster_plot(runner.mon['ts'], runner.mon['A0.spike'], show=True) diff --git a/examples/simulation/whole_brain_simulation_with_sl_oscillator.py b/examples/simulation/whole_brain_simulation_with_sl_oscillator.py index 2f1269e19..1fa946a49 100644 --- a/examples/simulation/whole_brain_simulation_with_sl_oscillator.py +++ b/examples/simulation/whole_brain_simulation_with_sl_oscillator.py @@ -36,9 +36,9 @@ def __init__(self): bm.fill_diagonal(conn_mat, 0) gc = 0.6 # global coupling strength - self.sl = rates.StuartLandauOscillator(80, x_ou_sigma=0.14, y_ou_sigma=0.14, name='sl') - self.coupling = rates.DiffusiveCoupling(self.sl.x, self.sl.x, self.sl.input, - conn_mat=conn_mat * gc) + self.sl = bp.rates.StuartLandauOscillator(80, x_ou_sigma=0.14, y_ou_sigma=0.14, name='sl') + self.coupling = bp.synapses.DiffusiveCoupling(self.sl.x, self.sl.x, self.sl.input, + conn_mat=conn_mat * gc) def simulation(): diff --git a/examples/training/Bellec_2020_eprop_evidence_accumulation.py b/examples/training/Bellec_2020_eprop_evidence_accumulation.py new file mode 100644 index 000000000..231726f19 --- /dev/null +++ b/examples/training/Bellec_2020_eprop_evidence_accumulation.py @@ -0,0 +1,177 @@ +# -*- coding: utf-8 -*- + +""" +Implementation of the paper: + +- Bellec, G., Scherr, F., Subramoney, A., Hajek, E., Salaj, D., Legenstein, R., + & Maass, W. (2020). A solution to the learning dilemma for recurrent networks + of spiking neurons. Nature communications, 11(1), 1-15. + +""" +import matplotlib.pyplot as plt +import numpy as np +import brainpy as bp +import brainpy.math as bm +from jax.lax import stop_gradient + +bm.set_dt(1.) + + +class EligSNN(bp.dyn.Network): + def __init__(self, num_in, num_rec, num_out, neuron_model='lif'): + super(EligSNN, self).__init__() + + # parameters + self.num_in = num_in + self.num_rec = num_rec + self.num_out = num_out + + # neurons + self.i = bp.neurons.InputGroup(num_in, trainable=True) + self.o = bp.neurons.LeakyIntegrator(num_out, tau=20, trainable=True) + tau_a = 2e3 + tau_v = 2e1 + n_regular = 50 + n_adaptive = num_rec - n_regular + beta_a1 = bm.exp(- bm.get_dt() / tau_a) + beta_a2 = 1.7 * (1 - beta_a1) / (1 - bm.exp(-1 / tau_v)) + self.r = bp.neurons.ALIFBellec2020( + n_regular + n_adaptive, trainable=True, + V_rest=0., tau_ref=5., V_th=0.6, tau_a=tau_a, tau=tau_v, + beta=bm.concatenate([bm.ones(n_regular), bm.ones(n_adaptive) * beta_a2]), + ) + + # synapses + self.i2r = bp.layers.Dense(num_in, num_rec, W_initializer=bp.init.KaimingNormal()) + self.r2r = bp.layers.Dense(num_rec, num_rec, W_initializer=bp.init.KaimingNormal()) + self.r2o = bp.synapses.Exponential(self.r, self.o, bp.conn.All2All(), + output=bp.synouts.CUBA(), tau=10., + g_max=bp.init.KaimingNormal(), + trainable=True) + + def update(self, shared, x): + self.i2r(shared, x) + self.r(shared, x=self.r2r(shared, stop_gradient(self.r.spike.value))) + self.r2o(shared, ) + self.o(shared, ) + return self.o.V.value + + +@bp.tools.numba_jit +def generate_click_task_data(batch_size, seq_len, n_neuron, recall_duration, prob, f0=0.5, + n_cues=7, t_cue=100, t_interval=150, n_input_symbols=4): + n_channel = n_neuron // n_input_symbols + + # assign input spike probabilities + probs = np.where(np.random.random((batch_size, 1)) < 0.5, prob, 1 - prob) + + # for each example in batch, draw which cues are going to be active (left or right) + cue_assignments = np.asarray(np.random.random(n_cues) > probs, dtype=np.int_) + + # generate input nums - 0: left, 1: right, 2:recall, 3:background noise + input_nums = 3 * np.ones((batch_size, seq_len), dtype=np.int_) + input_nums[:, :n_cues] = cue_assignments + input_nums[:, -1] = 2 + + # generate input spikes + input_spike_prob = np.zeros((batch_size, seq_len, n_neuron)) + d_silence = t_interval - t_cue + for b in range(batch_size): + for k in range(n_cues): + # input channels only fire when they are selected (left or right) + c = cue_assignments[b, k] + # reverse order of cues + i_seq = d_silence + k * t_interval + i_neu = c * n_channel + input_spike_prob[b, i_seq:i_seq + t_cue, i_neu:i_neu + n_channel] = f0 + # recall cue + input_spike_prob[:, -recall_duration:, 2 * n_channel:3 * n_channel] = f0 + # background noise + input_spike_prob[:, :, 3 * n_channel:] = f0 / 4. + input_spikes = input_spike_prob > np.random.rand(*input_spike_prob.shape) + + # generate targets + target_mask = np.zeros((batch_size, seq_len), dtype=np.bool_) + target_mask[:, -1] = True + target_nums = (np.sum(cue_assignments, axis=1) > n_cues / 2).astype(np.int_) + return input_spikes, input_nums, target_nums, target_mask + + +def get_data(batch_size, n_in, t_interval, f0): + # used for obtaining a new randomly generated batch of examples + def generate_data(): + for _ in range(100): + seq_len = int(t_interval * 7 + 1200) + spk_data, _, target_data, _ = generate_click_task_data( + batch_size=batch_size, seq_len=seq_len, n_neuron=n_in, recall_duration=150, + prob=0.3, t_cue=100, n_cues=7, t_interval=t_interval, f0=f0, n_input_symbols=4) + yield spk_data, target_data + + return generate_data + + +# experiment parameters +reg_f = 1. # regularization coefficient for firing rate +reg_rate = 10 # target firing rate for regularization [Hz] +t_cue_spacing = 150 # distance between two consecutive cues in ms + +# frequency +input_f0 = 40. / 1000. # poisson firing rate of input neurons in khz +regularization_f0 = reg_rate / 1000. # mean target network firing frequency + +# model +net = EligSNN(num_in=40, num_rec=100, num_out=2, neuron_model='alif') + + +def loss_fun(predicts, targets): + predicts, mon = predicts + + # we only use network output at the end for classification + output_logits = predicts[:, -t_cue_spacing:] + + # Define the accuracy + y_predict = bm.argmax(bm.mean(output_logits, axis=1), axis=1) + accuracy = bm.equal(targets, y_predict).astype(bm.dftype()).mean() + + # loss function + tiled_targets = bm.tile(bm.expand_dims(targets, 1), (1, t_cue_spacing)) + loss_cls = bm.mean(bp.losses.cross_entropy_loss(output_logits, tiled_targets)) + + # Firing rate regularization: + # For historical reason we often use this regularization, + # but the other one is easier to implement in an "online" fashion by a single agent. + av = bm.mean(mon['r.spike'], axis=(0, 1)) / bm.get_dt() + loss_reg_f = bm.sum(bm.square(av - regularization_f0) * reg_f) + + # Aggregate the losses # + loss = loss_reg_f + loss_cls + + loss_res = {'loss': loss, 'loss reg': loss_reg_f, 'accuracy': accuracy} + return loss, loss_res + + +# Training +trainer = bp.train.BPTT(net, + loss_fun, + loss_has_aux=True, + optimizer=bp.optimizers.Adam(lr=1e-2), + monitors={'r.spike': net.r.spike}, ) +trainer.fit(get_data(64, n_in=net.num_in, t_interval=t_cue_spacing, f0=input_f0), + num_epoch=2, num_report=10) + + +fig, gs = bp.visualize.get_figure(2, 2, 4, 5) + +fig.add_subplot(gs[0, 0]) +plt.plot(bm.as_numpy(trainer.train_losses)) +plt.ylabel('Overall Loss') +fig.add_subplot(gs[0, 1]) +plt.plot(bm.as_numpy(trainer.train_loss_aux['loss'])) +plt.ylabel('Accuracy Loss') +fig.add_subplot(gs[1, 0]) +plt.plot(bm.as_numpy(trainer.train_loss_aux['loss reg'])) +plt.ylabel('Regularization Loss') +fig.add_subplot(gs[1, 1]) +plt.plot(bm.as_numpy(trainer.train_loss_aux['accuracy'])) +plt.ylabel('Accuracy') +plt.show() diff --git a/examples/training/Gauthier_2021_ngrc_double_scroll.py b/examples/training/Gauthier_2021_ngrc_double_scroll.py index f82d1e8b3..5547ac254 100644 --- a/examples/training/Gauthier_2021_ngrc_double_scroll.py +++ b/examples/training/Gauthier_2021_ngrc_double_scroll.py @@ -14,6 +14,7 @@ import brainpy as bp import brainpy.math as bm + bm.enable_x64() @@ -102,14 +103,14 @@ def plot_double_scroll(ground_truth, predictions): # ----- # -class NGRC(bp.train.TrainingSystem): +class NGRC(bp.dyn.TrainingSystem): def __init__(self, num_in): super(NGRC, self).__init__() - self.r = bp.train.NVAR(num_in, delay=2, order=3) - self.di = bp.train.Dense(self.r.num_out, num_in, trainable=True) + self.r = bp.layers.NVAR(num_in, delay=2, order=3) + self.di = bp.layers.Dense(self.r.num_out, num_in, trainable=True) - def forward(self, x, shared_args=None): - di = self.di(self.r(x, shared_args), shared_args) + def update(self, shared, x): + di = self.di(shared, self.r(shared, x)) return x + di @@ -119,7 +120,7 @@ def forward(self, x, shared_args=None): # -------- # # warm-up -trainer = bp.train.RidgeTrainer(model, beta=1e-5, jit=True) +trainer = bp.train.RidgeTrainer(model, alpha=1e-5, jit=True) outputs = trainer.predict(X_warmup) print('Warmup NMS: ', bp.losses.mean_squared_error(outputs, Y_warmup)) @@ -128,10 +129,11 @@ def forward(self, x, shared_args=None): plot_weights(model.di.W, model.r.get_feature_names(for_plot=True), model.di.b) # prediction +shared = dict() model_jit = bm.jit(model) -outputs = [model_jit(X_test[:, 0])] +outputs = [model_jit(shared, X_test[:, 0])] for i in range(1, X_test.shape[1]): - outputs.append(model_jit(outputs[i - 1])) + outputs.append(model_jit(shared, outputs[i - 1])) outputs = bm.asarray(outputs).squeeze() print('Prediction NMS: ', bp.losses.mean_squared_error(outputs, Y_test)) -plot_double_scroll(Y_test.numpy().squeeze(), outputs.numpy()) +plot_double_scroll(bm.as_numpy(Y_test).squeeze(), bm.as_numpy(outputs)) diff --git a/examples/training/Gauthier_2021_ngrc_lorenz.py b/examples/training/Gauthier_2021_ngrc_lorenz.py index 2d89fa47a..dbaf38941 100644 --- a/examples/training/Gauthier_2021_ngrc_lorenz.py +++ b/examples/training/Gauthier_2021_ngrc_lorenz.py @@ -105,14 +105,14 @@ def plot_lorenz(ground_truth, predictions): # Model # # ----- # -class NGRC(bp.train.TrainingSystem): +class NGRC(bp.dyn.TrainingSystem): def __init__(self, num_in): super(NGRC, self).__init__() - self.r = bp.train.NVAR(num_in, delay=2, order=2, constant=True) - self.di = bp.train.Dense(self.r.num_out, num_in, b_initializer=None) + self.r = bp.layers.NVAR(num_in, delay=2, order=2, constant=True) + self.di = bp.layers.Dense(self.r.num_out, num_in, b_initializer=None) - def forward(self, x, shared_args=None): - dx = self.di(self.r(x, shared_args), shared_args) + def update(self, sha, x): + dx = self.di(sha, self.r(sha, x)) return x + dx @@ -129,14 +129,15 @@ def forward(self, x, shared_args=None): print('Warmup NMS: ', bp.losses.mean_squared_error(outputs, Y_warmup)) # training -trainer.fit([X_train, {'di': dX_train}]) +trainer.fit([X_train, dX_train]) plot_weights(model.di.W, model.r.get_feature_names(for_plot=True), model.di.b) # prediction +shared = dict() model_jit = bm.jit(model) -outputs = [model_jit(X_test[:, 0])] +outputs = [model_jit(shared, X_test[:, 0])] for i in range(1, X_test.shape[1]): - outputs.append(model_jit(outputs[i - 1])) + outputs.append(model_jit(shared, outputs[i - 1])) outputs = bm.asarray(outputs) print('Prediction NMS: ', bp.losses.mean_squared_error(outputs, Y_test)) -plot_lorenz(Y_test.numpy().squeeze(), outputs.numpy().squeeze()) +plot_lorenz(bm.as_numpy(Y_test).squeeze(), bm.as_numpy(outputs).squeeze()) diff --git a/examples/training/Gauthier_2021_ngrc_lorenz_inference.py b/examples/training/Gauthier_2021_ngrc_lorenz_inference.py index 174b8e767..93d1e35e3 100644 --- a/examples/training/Gauthier_2021_ngrc_lorenz_inference.py +++ b/examples/training/Gauthier_2021_ngrc_lorenz_inference.py @@ -142,14 +142,14 @@ def plot_lorenz(x, y, true_z, predict_z, linewidth=.8): # Model # # ----- # -class NGRC(bp.train.TrainingSystem): +class NGRC(bp.dyn.TrainingSystem): def __init__(self, num_in): super(NGRC, self).__init__() - self.r = bp.train.NVAR(num_in, delay=4, order=2, stride=5) - self.o = bp.train.Dense(self.r.num_out, 1, trainable=True) + self.r = bp.layers.NVAR(num_in, delay=4, order=2, stride=5) + self.o = bp.layers.Dense(self.r.num_out, 1, trainable=True) - def forward(self, x, shared_args=None): - return self.o(self.r(x, shared_args), shared_args) + def update(self, sha, x): + return self.o(sha, self.r(sha, x)) model = NGRC(2) @@ -157,7 +157,8 @@ def forward(self, x, shared_args=None): # Training # # -------- # -trainer = bp.train.RidgeTrainer(model, beta=0.05) +trainer = bp.train.RidgeTrainer(model, alpha=0.05) +# trainer = bp.train.ForceTrainer(model, ) # warm-up outputs = trainer.predict(X_warmup) @@ -170,7 +171,7 @@ def forward(self, x, shared_args=None): outputs = trainer.predict(X_test, reset_state=True) print('Prediction NMS: ', bp.losses.mean_squared_error(outputs, Y_test)) -plot_lorenz(x=lorenz_series['x'].flatten().to_numpy(), - y=lorenz_series['y'].flatten().to_numpy(), - true_z=lorenz_series['z'].flatten().to_numpy(), - predict_z=outputs.to_numpy().flatten()) +plot_lorenz(x=bm.as_numpy(lorenz_series['x']).flatten(), + y=bm.as_numpy(lorenz_series['y']).flatten(), + true_z=bm.as_numpy(lorenz_series['z']).flatten(), + predict_z=bm.as_numpy(outputs).flatten()) diff --git a/examples/training/Song_2016_EI_RNN.py b/examples/training/Song_2016_EI_RNN.py index f366fed4d..ff9f8ddab 100644 --- a/examples/training/Song_2016_EI_RNN.py +++ b/examples/training/Song_2016_EI_RNN.py @@ -125,20 +125,20 @@ def __init__(self, num_input, num_hidden, num_output, num_batch, # hidden mask mask = np.tile([1] * self.e_size + [-1] * self.i_size, (num_hidden, 1)) np.fill_diagonal(mask, 0) - self.mask = bm.asarray(mask, dtype=bm.get_dfloat()) + self.mask = bm.asarray(mask, dtype=bm.dftype()) # input weight - self.w_ir = bm.TrainVar(bp.init.init_param(w_ir, (num_input, num_hidden))) + self.w_ir = bm.TrainVar(bp.init.parameter(w_ir, (num_input, num_hidden))) # recurrent weight bound = 1 / num_hidden ** 0.5 - self.w_rr = bm.TrainVar(bp.init.init_param(w_rr, (num_hidden, num_hidden))) + self.w_rr = bm.TrainVar(bp.init.parameter(w_rr, (num_hidden, num_hidden))) self.w_rr[:, :self.e_size] /= (self.e_size / self.i_size) self.b_rr = bm.TrainVar(self.rng.uniform(-bound, bound, num_hidden)) # readout weight bound = 1 / self.e_size ** 0.5 - self.w_ro = bm.TrainVar(bp.init.init_param(w_ro, (self.e_size, num_output))) + self.w_ro = bm.TrainVar(bp.init.parameter(w_ro, (self.e_size, num_output))) self.b_ro = bm.TrainVar(self.rng.uniform(-bound, bound, num_output)) # variables diff --git a/examples/training/SurrogateGrad_lif.py b/examples/training/SurrogateGrad_lif.py new file mode 100644 index 000000000..836aaeb85 --- /dev/null +++ b/examples/training/SurrogateGrad_lif.py @@ -0,0 +1,149 @@ +# -*- coding: utf-8 -*- + + +""" +Reproduce the results of the``spytorch`` tutorial 1: + +- https://github.com/surrogate-gradient-learning/spytorch/blob/master/notebooks/SpyTorchTutorial1.ipynb + +""" + +import time + +import numpy as np +import matplotlib.pyplot as plt +from matplotlib.gridspec import GridSpec + +import brainpy as bp +import brainpy.math as bm + + +class SNN(bp.dyn.Network): + def __init__(self, num_in, num_rec, num_out): + super(SNN, self).__init__() + + # parameters + self.num_in = num_in + self.num_rec = num_rec + self.num_out = num_out + + # neuron groups + self.i = bp.neurons.InputGroup(num_in, trainable=True) + self.r = bp.neurons.LIF(num_rec, tau=10, V_reset=0, V_rest=0, V_th=1., trainable=True) + self.o = bp.neurons.LeakyIntegrator(num_out, tau=5, trainable=True) + + # synapse: i->r + self.i2r = bp.synapses.Exponential(self.i, self.r, bp.conn.All2All(), + output=bp.synouts.CUBA(), tau=10., + g_max=bp.init.KaimingNormal(scale=10.), + trainable=True) + # synapse: r->o + self.r2o = bp.synapses.Exponential(self.r, self.o, bp.conn.All2All(), + output=bp.synouts.CUBA(), tau=10., + g_max=bp.init.KaimingNormal(scale=10.), + trainable=True) + + def update(self, tdi, spike): + self.i2r(tdi, spike) + self.r2o(tdi) + self.r(tdi) + self.o(tdi) + return self.o.V.value + + +def plot_voltage_traces(mem, spk=None, dim=(3, 5), spike_height=5): + gs = GridSpec(*dim) + mem = 1. * mem + if spk is not None: + mem[spk > 0.0] = spike_height + mem = bm.as_numpy(mem) + for i in range(np.prod(dim)): + if i == 0: + a0 = ax = plt.subplot(gs[i]) + else: + ax = plt.subplot(gs[i], sharey=a0) + ax.plot(mem[i]) + ax.axis("off") + plt.tight_layout() + plt.show() + + +def print_classification_accuracy(output, target): + """ Dirty little helper function to compute classification accuracy. """ + m = bm.max(output, axis=1) # max over time + am = bm.argmax(m, axis=1) # argmax over output units + acc = bm.mean(target == am) # compare to labels + print("Accuracy %.3f" % acc) + + +net = SNN(100, 4, 2) + +num_step = 2000 +num_sample = 256 +freq = 5 # Hz +mask = bm.random.rand(num_sample, num_step, net.num_in) +x_data = bm.zeros((num_sample, num_step, net.num_in)) +x_data[mask < freq * bm.get_dt() / 1000.] = 1.0 +y_data = bm.asarray(bm.random.rand(num_sample) < 0.5, dtype=bm.dftype()) +rng = bm.random.RandomState() + + +# Before training +runner = bp.train.DSRunner(net, monitors={'r.spike': net.r.spike, 'r.membrane': net.r.V}) +out = runner.run(inputs=x_data, inputs_are_batching=True, reset_state=True) +plot_voltage_traces(runner.mon.get('r.membrane'), runner.mon.get('r.spike')) +plot_voltage_traces(out) +print_classification_accuracy(out, y_data) + + +def loss(): + key = rng.split_key() + X = rng.permutation(x_data, key=key) + Y = rng.permutation(y_data, key=key) + looper = bp.train.DSRunner(net, numpy_mon_after_run=False, progress_bar=False) + predictions = looper.run(inputs=X, inputs_are_batching=True, reset_state=True) + predictions = bm.max(predictions, axis=1) + return bp.losses.cross_entropy_loss(predictions, Y) + + +f_grad = bm.grad(loss, + grad_vars=net.train_vars().unique(), + dyn_vars=net.vars().unique() + {'rng': rng}, + return_value=True) +f_opt = bp.optim.Adam(lr=2e-3, train_vars=net.train_vars().unique()) + + +def train(_): + grads, l = f_grad() + f_opt.update(grads) + return l + + +f_train = bm.make_loop(train, + dyn_vars=f_opt.vars() + net.vars() + {'rng': rng}, + has_return=True) + +# train the network +net.reset_state(num_sample) +train_losses = [] +for i in range(0, 1000, 100): + t0 = time.time() + _, ls = f_train(bm.arange(i, i + 100, 1)) + print(f'Train {i + 100} epoch, loss = {bm.mean(ls):.4f}, used time {time.time() - t0:.4f} s') + train_losses.append(ls) + + +# visualize the training losses +plt.plot(bm.as_numpy(bm.concatenate(train_losses))) +plt.xlabel("Epoch") +plt.ylabel("Training Loss") +plt.show() + + +# predict the output according to the input data +runner = bp.dyn.DSRunner(net, monitors={'r.spike': net.r.spike, 'r.membrane': net.r.V}) +out = runner.run(inputs=x_data, inputs_are_batching=True, reset_state=True) +plot_voltage_traces(runner.mon.get('r.membrane'), runner.mon.get('r.spike')) +plot_voltage_traces(out) +print_classification_accuracy(out, y_data) + diff --git a/examples/training/SurrogateGrad_lif_fashion_mnist.py b/examples/training/SurrogateGrad_lif_fashion_mnist.py new file mode 100644 index 000000000..dc6008e68 --- /dev/null +++ b/examples/training/SurrogateGrad_lif_fashion_mnist.py @@ -0,0 +1,250 @@ +# -*- coding: utf-8 -*- + +""" +Reproduce the results of the``spytorch`` tutorial 2 & 3: + +- https://github.com/surrogate-gradient-learning/spytorch/blob/master/notebooks/SpyTorchTutorial2.ipynb +- https://github.com/surrogate-gradient-learning/spytorch/blob/master/notebooks/SpyTorchTutorial3.ipynb + +""" + +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.gridspec import GridSpec + +import brainpy as bp +import brainpy.math as bm + + +class SNN(bp.dyn.Network): + """ + This class implements a spiking neural network model with three layers: + + i >> r >> o + + Each two layers are connected through the exponential synapse model. + """ + + def __init__(self, num_in, num_rec, num_out): + super(SNN, self).__init__() + + # parameters + self.num_in = num_in + self.num_rec = num_rec + self.num_out = num_out + + # neuron groups + self.i = bp.neurons.InputGroup(num_in, trainable=True) + self.r = bp.neurons.LIF(num_rec, tau=10, V_reset=0, V_rest=0, V_th=1., trainable=True) + self.o = bp.neurons.LeakyIntegrator(num_out, tau=5, trainable=True) + + # synapse: i->r + self.i2r = bp.synapses.Exponential(self.i, self.r, bp.conn.All2All(), + output=bp.synouts.CUBA(), tau=10., + g_max=bp.init.KaimingNormal(scale=2.), + trainable=True) + # synapse: r->o + self.r2o = bp.synapses.Exponential(self.r, self.o, bp.conn.All2All(), + output=bp.synouts.CUBA(), tau=10., + g_max=bp.init.KaimingNormal(scale=2.), + trainable=True) + + def update(self, shared, spike): + self.i2r(shared, spike) + self.r2o(shared) + self.r(shared) + self.o(shared) + return self.o.V.value + + +def plot_voltage_traces(mem, spk=None, dim=(3, 5), spike_height=5): + gs = GridSpec(*dim) + mem = 1. * mem + if spk is not None: + mem[spk > 0.0] = spike_height + mem = bm.as_numpy(mem) + for i in range(np.prod(dim)): + if i == 0: + a0 = ax = plt.subplot(gs[i]) + else: + ax = plt.subplot(gs[i], sharey=a0) + ax.plot(mem[i]) + ax.axis("off") + plt.tight_layout() + plt.show() + + +def print_classification_accuracy(output, target): + """ Dirty little helper function to compute classification accuracy. """ + m = bm.max(output, axis=1) # max over time + am = bm.argmax(m, axis=1) # argmax over output units + acc = bm.mean(target == am) # compare to labels + print("Accuracy %.3f" % acc) + + +def current2firing_time(x, tau=20., thr=0.2, tmax=1.0, epsilon=1e-7): + """Computes first firing time latency for a current input x + assuming the charge time of a current based LIF neuron. + + Args: + x -- The "current" values + + Keyword args: + tau -- The membrane time constant of the LIF neuron to be charged + thr -- The firing threshold value + tmax -- The maximum time returned + epsilon -- A generic (small) epsilon > 0 + + Returns: + Time to first spike for each "current" x + """ + x = np.clip(x, thr + epsilon, 1e9) + T = tau * np.log(x / (x - thr)) + T = np.where(x < thr, tmax, T) + return T + + +def sparse_data_generator(X, y, batch_size, nb_steps, nb_units, shuffle=True): + """ This generator takes datasets in analog format and + generates spiking network input as sparse tensors. + + Args: + X: The data ( sample x event x 2 ) the last dim holds (time,neuron) tuples + y: The labels + """ + + labels_ = np.array(y, dtype=bm.ditype()) + sample_index = np.arange(len(X)) + + # compute discrete firing times + tau_eff = 2. / bm.get_dt() + unit_numbers = np.arange(nb_units) + firing_times = np.array(current2firing_time(X, tau=tau_eff, tmax=nb_steps), dtype=bm.ditype()) + + if shuffle: + np.random.shuffle(sample_index) + + counter = 0 + number_of_batches = len(X) // batch_size + while counter < number_of_batches: + batch_index = sample_index[batch_size * counter:batch_size * (counter + 1)] + all_batch, all_times, all_units = [], [], [] + for bc, idx in enumerate(batch_index): + c = firing_times[idx] < nb_steps + times, units = firing_times[idx][c], unit_numbers[c] + batch = bc * np.ones(len(times), dtype=bm.ditype()) + all_batch.append(batch) + all_times.append(times) + all_units.append(units) + all_batch = np.concatenate(all_batch).flatten() + all_times = np.concatenate(all_times).flatten() + all_units = np.concatenate(all_units).flatten() + x_batch = bm.zeros((batch_size, nb_steps, nb_units)) + x_batch[all_batch, all_times, all_units] = 1. + y_batch = bm.asarray(labels_[batch_index]) + yield x_batch, y_batch + counter += 1 + + +def train(model, x_data, y_data, lr=1e-3, nb_epochs=10, batch_size=128, nb_steps=128, nb_inputs=28 * 28): + def loss_fun(predicts, targets): + predicts, mon = predicts + # Here we set up our regularizer loss + # The strength paramters here are merely a guess and + # there should be ample room for improvement by + # tuning these paramters. + l1_loss = 1e-5 * bm.sum(mon['r.spike']) # L1 loss on total number of spikes + l2_loss = 1e-5 * bm.mean(bm.sum(bm.sum(mon['r.spike'], axis=0), axis=0) ** 2) # L2 loss on spikes per neuron + # predictions + predicts = bm.max(predicts, axis=1) + loss = bp.losses.cross_entropy_loss(predicts, targets) + return loss + l2_loss + l1_loss + + f_opt = bp.optim.Adam(lr=lr) + trainer = bp.train.BPTT(model, loss_fun, f_opt, + monitors={'r.spike': net.r.spike}, + dyn_vars={'rand': bm.random.DEFAULT}) + trainer.fit(lambda: sparse_data_generator(x_data, y_data, batch_size, nb_steps, nb_inputs), + num_epoch=nb_epochs) + return trainer.train_losses + + +def compute_classification_accuracy(model, x_data, y_data, batch_size=128, nb_steps=100, nb_inputs=28 * 28): + """ Computes classification accuracy on supplied data in batches. """ + accs = [] + runner = bp.dyn.DSRunner(model, dyn_vars={'rand': bm.random.DEFAULT}, progress_bar=False) + for x_local, y_local in sparse_data_generator(x_data, y_data, batch_size, nb_steps, nb_inputs, shuffle=False): + output = runner.predict(inputs=x_local, inputs_are_batching=True, reset_state=True) + m = bm.max(output, 1) # max over time + am = bm.argmax(m, 1) # argmax over output units + tmp = bm.mean(y_local == am) # compare to labels + accs.append(tmp) + return bm.mean(bm.asarray(accs)) + + +def get_mini_batch_results(model, x_data, y_data, batch_size=128, nb_steps=100, nb_inputs=28 * 28): + runner = bp.dyn.DSRunner(model, + dyn_vars={'rand': bm.random.DEFAULT}, + monitors={'r.spike': model.r.spike}, + progress_bar=False) + data = sparse_data_generator(x_data, y_data, batch_size, nb_steps, nb_inputs, shuffle=False) + x_local, y_local = next(data) + output = runner.predict(inputs=x_local, inputs_are_batching=True, reset_state=True) + return output, runner.mon.get('r.spike') + + +num_input = 28 * 28 +net = SNN(num_in=num_input, num_rec=100, num_out=10) + +# load the dataset +root = "some_path/" +train_dataset = bp.datasets.FashionMNIST(root, + train=True, + transform=None, + target_transform=None, + download=True) +test_dataset = bp.datasets.FashionMNIST(root, + train=False, + transform=None, + target_transform=None, + download=True) + +# Standardize data +x_train = np.array(train_dataset.data, dtype=bm.dftype()) +x_train = x_train.reshape(x_train.shape[0], -1) / 255 +y_train = np.array(train_dataset.targets, dtype=bm.ditype()) +x_test = np.array(test_dataset.data, dtype=bm.dftype()) +x_test = x_test.reshape(x_test.shape[0], -1) / 255 +y_test = np.array(test_dataset.targets, dtype=bm.ditype()) + +# training +train_losses = train(net, x_train, y_train, lr=1e-3, nb_epochs=30, batch_size=256, nb_steps=100, nb_inputs=28 * 28) + +plt.figure(figsize=(3.3, 2), dpi=150) +plt.plot(train_losses) +plt.xlabel("Epoch") +plt.ylabel("Loss") +plt.show() + +print("Training accuracy: %.3f" % (compute_classification_accuracy(net, x_train, y_train, batch_size=512))) +print("Test accuracy: %.3f" % (compute_classification_accuracy(net, x_test, y_test, batch_size=512))) + + +outs, spikes = get_mini_batch_results(net, x_train, y_train) +# Let's plot the hidden layer spiking activity for some input stimuli +fig = plt.figure(dpi=100) +plot_voltage_traces(outs) +plt.show() + +nb_plt = 4 +gs = GridSpec(1, nb_plt) +fig = plt.figure(figsize=(7, 3), dpi=150) +for i in range(nb_plt): + plt.subplot(gs[i]) + plt.imshow(bm.as_numpy(spikes[i]).T, cmap=plt.cm.gray_r, origin="lower") + if i == 0: + plt.xlabel("Time") + plt.ylabel("Units") +plt.tight_layout() +plt.show() + diff --git a/examples/training/echo_state_network.py b/examples/training/echo_state_network.py index 9f8275000..e05a2a172 100644 --- a/examples/training/echo_state_network.py +++ b/examples/training/echo_state_network.py @@ -4,30 +4,29 @@ import brainpy.math as bm -class ESN(bp.train.TrainingSystem): +class ESN(bp.dyn.TrainingSystem): def __init__(self, num_in, num_hidden, num_out): super(ESN, self).__init__() - self.r = bp.train.Reservoir(num_in, num_hidden, - Win_initializer=bp.init.Uniform(-0.1, 0.1), - Wrec_initializer=bp.init.Normal(scale=0.1), - ff_connectivity=0.02, - fb_connectivity=0.02, - rec_connectivity=0.02, - conn_type='dense') - self.o = bp.train.Dense(num_hidden, num_out, W_initializer=bp.init.Normal()) + self.r = bp.layers.Reservoir(num_in, num_hidden, + Win_initializer=bp.init.Uniform(-0.1, 0.1), + Wrec_initializer=bp.init.Normal(scale=0.1), + in_connectivity=0.02, + rec_connectivity=0.02, + conn_type='dense') + self.o = bp.layers.Dense(num_hidden, num_out, W_initializer=bp.init.Normal()) def forward(self, x, shared_args=None): return self.o(self.r(x, shared_args), shared_args) -class NGRC(bp.train.TrainingSystem): +class NGRC(bp.dyn.TrainingSystem): def __init__(self, num_in, num_out): super(NGRC, self).__init__() - self.r = bp.train.NVAR(num_in, delay=2, order=2) - self.o = bp.train.Dense(self.r.num_out, num_out, - W_initializer=bp.init.Normal(0.1), - trainable=True) + self.r = bp.layers.NVAR(num_in, delay=2, order=2) + self.o = bp.layers.Dense(self.r.num_out, num_out, + W_initializer=bp.init.Normal(0.1), + trainable=True) def forward(self, x, shared_args=None): return self.o(self.r(x, shared_args), shared_args) @@ -93,7 +92,7 @@ def ngrc(num_in=10, num_out=30): X = bm.random.random((1, 200, num_in)) # (num_batch, num_time, num_feature) Y = bm.random.random((1, 200, num_out)) - trainer = bp.train.RidgeTrainer(model, beta=1e-6) + trainer = bp.train.RidgeTrainer(model, alpha=1e-6) outputs = trainer.predict(X) print(outputs.shape) print(bp.losses.mean_absolute_error(outputs, Y)) diff --git a/examples/training/integrator_rnn.py b/examples/training/integrator_rnn.py index 1bcae3aa5..010be2bc9 100644 --- a/examples/training/integrator_rnn.py +++ b/examples/training/integrator_rnn.py @@ -31,14 +31,14 @@ def train_data(): yield build_inputs_and_targets(batch_size=num_batch) -class RNN(bp.train.TrainingSystem): +class RNN(bp.dyn.TrainingSystem): def __init__(self, num_in, num_hidden): super(RNN, self).__init__() - self.rnn = bp.train.VanillaRNN(num_in, num_hidden, train_state=True) - self.out = bp.train.Dense(num_hidden, 1) + self.rnn = bp.layers.VanillaRNN(num_in, num_hidden, train_state=True) + self.out = bp.layers.Dense(num_hidden, 1) - def forward(self, x, shared_args=None): - return self.out(self.rnn(x, shared_args), shared_args) + def update(self, sha, x): + return self.out(sha, self.rnn(sha, x)) model = RNN(1, 100) @@ -56,12 +56,9 @@ def loss(predictions, targets, l2_reg=2e-4): opt = bp.optim.Adam(lr=lr, eps=1e-1) # create a trainer -trainer = bp.train.BPTT(model, - loss=loss, - optimizer=opt, - max_grad_norm=5.0) +trainer = bp.train.BPTT(model, loss_fun=loss, optimizer=opt) trainer.fit(train_data, - num_batch=num_batch, + batch_size=num_batch, num_epoch=30, num_report=200) diff --git a/extensions/CMakeLists.txt b/extensions/CMakeLists.txt index e17a99e60..85a048270 100644 --- a/extensions/CMakeLists.txt +++ b/extensions/CMakeLists.txt @@ -5,7 +5,7 @@ message(STATUS "Using CMake version " ${CMAKE_VERSION}) find_package(CUDA REQUIRED) find_package(Python COMPONENTS Interpreter Development REQUIRED) set(CUDA_TOOLKIT_ROOT_DIR ${CUDA_TOOLKIT_ROOT_DIR} "/usr/local/cuda") -set(CMAKE_PREFIX_PATH ${CMAKE_PREFIX_PATH} "/home/brainpy/miniconda3/lib/python3.9/site-packages/pybind11/share/cmake/") +set(CMAKE_PREFIX_PATH ${CMAKE_PREFIX_PATH} ) find_package(pybind11 REQUIRED) include_directories( @@ -23,12 +23,15 @@ pybind11_add_module( gpu_ops ${CMAKE_CURRENT_LIST_DIR}/lib/gpu_ops.cc ${CMAKE_CURRENT_LIST_DIR}/lib/event_sum_gpu.cu + ${CMAKE_CURRENT_LIST_DIR}/lib/atomic_prod_gpu.cu ${CMAKE_CURRENT_LIST_DIR}/lib/atomic_sum_gpu.cu) install(TARGETS gpu_ops DESTINATION brainpylib) pybind11_add_module( cpu_ops ${CMAKE_CURRENT_LIST_DIR}/lib/cpu_ops.cc + ${CMAKE_CURRENT_LIST_DIR}/lib/event_prod_cpu.cc ${CMAKE_CURRENT_LIST_DIR}/lib/event_sum_cpu.cc + ${CMAKE_CURRENT_LIST_DIR}/lib/atomic_prod_cpu.cc ${CMAKE_CURRENT_LIST_DIR}/lib/atomic_sum_cpu.cc ) install(TARGETS cpu_ops DESTINATION brainpylib) diff --git a/extensions/brainpylib/__init__.py b/extensions/brainpylib/__init__.py index f53a4916d..2328558f4 100644 --- a/extensions/brainpylib/__init__.py +++ b/extensions/brainpylib/__init__.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -__version__ = "0.0.5" +__version__ = "0.0.6" # IMPORTANT, must import first from . import register_custom_calls diff --git a/extensions/brainpylib/custom_op/cuda.py b/extensions/brainpylib/custom_op/cuda.py index 400e9e0b6..4b66349aa 100644 --- a/extensions/brainpylib/custom_op/cuda.py +++ b/extensions/brainpylib/custom_op/cuda.py @@ -1,3 +1,5 @@ +# -*- coding: utf-8 -*- + import ctypes import ctypes.util import sys diff --git a/extensions/brainpylib/custom_op/regis_op.py b/extensions/brainpylib/custom_op/regis_op.py index 5c5df2598..9e65fb556 100644 --- a/extensions/brainpylib/custom_op/regis_op.py +++ b/extensions/brainpylib/custom_op/regis_op.py @@ -4,7 +4,6 @@ from functools import partial from typing import Callable, Union, Sequence -import jax.numpy as jnp import numba import numpy as np from jax import core @@ -41,12 +40,14 @@ def register_op( Outputs shapes of target function. `out_shapes` can be a `ShapedArray` or a sequence of `ShapedArray`. If it is a function, it takes as input the argument shapes and dtypes and should return correct output shapes of `ShapedArray`. - apply_cpu_func_to_gpu: bool, default = True + apply_cpu_func_to_gpu: bool, True when gpu_func is implemented on CPU and other logics(data transfer) is implemented on GPU. + Default is True. Returns ------- - A jitable JAX function. + op: callable + A jitable JAX function. """ if gpu_func is not None: raise RuntimeError('Currently cuda.jit function is not supported to convert into a Jax/XLA compatible primitive.' diff --git a/extensions/brainpylib/event_sum.py b/extensions/brainpylib/event_sum.py index 02f57f99f..fbbd985f9 100644 --- a/extensions/brainpylib/event_sum.py +++ b/extensions/brainpylib/event_sum.py @@ -2,6 +2,7 @@ __all__ = [ 'event_sum', + 'event_sum2', ] from functools import partial @@ -86,12 +87,19 @@ def _event_sum_translation(c, events, indices, indptr, values, out, *, platform= if platform == "cpu": v_type = b'_event_sum_homo' if values_dim[0] == 1 else b'_event_sum_heter' return x_ops.CustomCallWithLayout( - c, platform.encode() + v_type + f_type + i_type, + c, + platform.encode() + v_type + f_type + i_type, operands=(x_ops.ConstantLiteral(c, pre_size), x_ops.ConstantLiteral(c, post_size), - events, indices, indptr, values), - operand_shapes_with_layout=(_pre_shape, _post_shape, c.get_shape(events), - c.get_shape(indices), c.get_shape(indptr), + events, + indices, + indptr, + values), + operand_shapes_with_layout=(_pre_shape, + _post_shape, + c.get_shape(events), + c.get_shape(indices), + c.get_shape(indptr), c.get_shape(values)), shape_with_layout=c.get_shape(out), ) @@ -101,10 +109,16 @@ def _event_sum_translation(c, events, indices, indptr, values, out, *, platform= v_type = b'_event_sum_homo' if values_dim[0] == 1 else b'_event_sum_heter' opaque = gpu_ops.build_event_sum_descriptor(pre_size, post_size) return x_ops.CustomCallWithLayout( - c, platform.encode() + v_type + f_type + i_type, - operands=(events, indices, indptr, values), - operand_shapes_with_layout=(c.get_shape(events), c.get_shape(indices), - c.get_shape(indptr), c.get_shape(values)), + c, + platform.encode() + v_type + f_type + i_type, + operands=(events, + indices, + indptr, + values), + operand_shapes_with_layout=(c.get_shape(events), + c.get_shape(indices), + c.get_shape(indptr), + c.get_shape(values)), shape_with_layout=c.get_shape(out), opaque=opaque, ) @@ -116,107 +130,120 @@ def _event_sum_translation(c, events, indices, indptr, values, out, *, platform= xla.backend_specific_translations["cpu"][_event_sum_prim] = partial(_event_sum_translation, platform="cpu") xla.backend_specific_translations["gpu"][_event_sum_prim] = partial(_event_sum_translation, platform="gpu") -# # --------------------------- -# # event sum kernel 2 -# # --------------------------- -# -# -# _event_sum2_prim = core.Primitive("event_sum2") -# -# -# def event_sum2(events, pre_ids, post_ids, post_num, values): -# # events -# if events.dtype != jnp.bool_: -# raise ValueError(f'"events" must be a vector of bool, while we got {events.dtype}') -# -# # connections -# if len(pre_ids) != len(post_ids): -# raise ValueError(f'The length of "pre_ids" must be equal to "post_ids", ' -# f'while we get: {len(pre_ids)} != {len(post_ids)}') -# if pre_ids.dtype != post_ids.dtype: -# raise ValueError(f'The dtype of "pre_ids" must be equal to that of "post_ids", ' -# f'while we got {(pre_ids.dtype, post_ids.dtype)}') -# if pre_ids.dtype not in [jnp.uint32, jnp.uint64]: -# raise ValueError(f'The dtype of "post_ids/pre_ids" must be uint32 or uint64, ' -# f'while we got {pre_ids.dtype}') -# -# # output value -# values = jnp.asarray([values]) -# if values.dtype not in [jnp.float32, jnp.float64]: -# raise ValueError(f'The dtype of "values" must be float32 or float64, while we got {values.dtype}.') -# if values.size not in [1, pre_ids.size]: -# raise ValueError(f'The size of "values" must be 1 (a scalar) or len(pre_ids) (a vector), ' -# f'while we got {values.size} != 1 != {pre_ids.size}') -# out = jnp.zeros(post_num, dtype=values.dtype) -# values = values.flatten() -# -# # bind operator -# return _event_sum2_prim.bind(events, pre_ids, post_ids, values, out) -# -# -# def _event_sum2_abstract(events, pre_ids, post_ids, value, out): -# return out -# -# -# _event_sum2_prim.def_abstract_eval(_event_sum2_abstract) -# _event_sum2_prim.def_impl(partial(xla.apply_primitive, _event_sum2_prim)) -# -# -# def _event_sum2_translation(c, events, pre_ids, post_ids, values, out, *, platform="cpu"): -# # The conn/post shape -# conn_size = np.array(c.get_shape(pre_ids).dimensions()[0], dtype=np.uint32) -# post_size = np.array(c.get_shape(out).dimensions()[0], dtype=np.uint32) -# _pre_shape = x_shape(np.dtype(np.uint32), (), ()) -# _post_shape = x_shape(np.dtype(np.uint32), (), ()) -# -# # The pre_ids shape -# pre_ids_shape = c.get_shape(pre_ids) -# Itype = pre_ids_shape.element_type() -# assert Itype in [np.uint32, np.uint64] -# -# # The value shape -# values_shape = c.get_shape(values) -# Ftype = values_shape.element_type() -# assert Ftype in [np.float32, np.float64] -# values_dim = values_shape.dimensions() -# -# # We dispatch a different call depending on the dtype -# f_type = b'_f32' if Ftype == np.float32 else b'_f64' -# i_type = b'_i32' if Itype == np.uint32 else b'_i64' -# -# # And then the following is what changes between the GPU and CPU -# if platform == "cpu": -# v_type = b'_event_sum2_homo' if values_dim[0] == 1 else b'_event_sum2_heter' -# return x_ops.CustomCallWithLayout( -# c, platform.encode() + v_type + f_type + i_type, -# operands=(x_ops.ConstantLiteral(c, conn_size), -# x_ops.ConstantLiteral(c, post_size), -# events, pre_ids, post_ids, values), -# operand_shapes_with_layout=(_pre_shape, _post_shape, c.get_shape(events), -# c.get_shape(pre_ids), c.get_shape(post_ids), -# c.get_shape(values)), -# shape_with_layout=c.get_shape(out), -# ) -# elif platform == 'gpu': -# if gpu_ops is None: -# raise ValueError('Cannot find compiled gpu wheels.') -# v_type = b'_event_sum2_homo' if values_dim[0] == 1 else b'_event_sum2_heter' -# opaque = gpu_ops.build_event_sum2_descriptor(conn_size, post_size) -# return x_ops.CustomCallWithLayout( -# c, platform.encode() + v_type + f_type + i_type, -# operands=(events, pre_ids, post_ids, values), -# operand_shapes_with_layout=(c.get_shape(events), c.get_shape(pre_ids), -# c.get_shape(post_ids), c.get_shape(values)), -# shape_with_layout=c.get_shape(out), -# opaque=opaque, -# ) -# raise ValueError("Unsupported platform; this must be either 'cpu' or 'gpu'") -# -# -# xla.backend_specific_translations["cpu"][_event_sum2_prim] = partial(_event_sum2_translation, platform="cpu") -# xla.backend_specific_translations["gpu"][_event_sum2_prim] = partial(_event_sum2_translation, platform="gpu") -# -# +# --------------------------- +# event sum kernel 2 +# --------------------------- + + +_event_sum2_prim = core.Primitive("event_sum2") + + +def event_sum2(events, pre_ids, post_ids, post_num, values): + # events + if events.dtype != jnp.bool_: + raise ValueError(f'"events" must be a vector of bool, while we got {events.dtype}') + + # connections + if len(pre_ids) != len(post_ids): + raise ValueError(f'The length of "pre_ids" must be equal to "post_ids", ' + f'while we get: {len(pre_ids)} != {len(post_ids)}') + if pre_ids.dtype != post_ids.dtype: + raise ValueError(f'The dtype of "pre_ids" must be equal to that of "post_ids", ' + f'while we got {(pre_ids.dtype, post_ids.dtype)}') + if pre_ids.dtype not in [jnp.uint32, jnp.uint64]: + raise ValueError(f'The dtype of "post_ids/pre_ids" must be uint32 or uint64, ' + f'while we got {pre_ids.dtype}') + + # output value + values = jnp.asarray([values]) + if values.dtype not in [jnp.float32, jnp.float64]: + raise ValueError(f'The dtype of "values" must be float32 or float64, while we got {values.dtype}.') + if values.size not in [1, pre_ids.size]: + raise ValueError(f'The size of "values" must be 1 (a scalar) or len(pre_ids) (a vector), ' + f'while we got {values.size} != 1 != {pre_ids.size}') + out = jnp.zeros(post_num, dtype=values.dtype) + values = values.flatten() + + # bind operator + return _event_sum2_prim.bind(events, pre_ids, post_ids, values, out) + + +def _event_sum2_abstract(events, pre_ids, post_ids, value, out): + return out + + +_event_sum2_prim.def_abstract_eval(_event_sum2_abstract) +_event_sum2_prim.def_impl(partial(xla.apply_primitive, _event_sum2_prim)) + + +def _event_sum2_translation(c, events, pre_ids, post_ids, values, out, *, platform="cpu"): + # The conn/post shape + conn_size = np.array(c.get_shape(pre_ids).dimensions()[0], dtype=np.uint32) + post_size = np.array(c.get_shape(out).dimensions()[0], dtype=np.uint32) + _pre_shape = x_shape(np.dtype(np.uint32), (), ()) + _post_shape = x_shape(np.dtype(np.uint32), (), ()) + + # The pre_ids shape + pre_ids_shape = c.get_shape(pre_ids) + Itype = pre_ids_shape.element_type() + assert Itype in [np.uint32, np.uint64] + + # The value shape + values_shape = c.get_shape(values) + Ftype = values_shape.element_type() + assert Ftype in [np.float32, np.float64] + values_dim = values_shape.dimensions() + + # We dispatch a different call depending on the dtype + f_type = b'_f32' if Ftype == np.float32 else b'_f64' + i_type = b'_i32' if Itype == np.uint32 else b'_i64' + + # And then the following is what changes between the GPU and CPU + if platform == "cpu": + v_type = b'_event_sum2_homo' if values_dim[0] == 1 else b'_event_sum2_heter' + return x_ops.CustomCallWithLayout( + c, + platform.encode() + v_type + f_type + i_type, + operands=(x_ops.ConstantLiteral(c, conn_size), + x_ops.ConstantLiteral(c, post_size), + events, + pre_ids, + post_ids, + values), + operand_shapes_with_layout=(_pre_shape, + _post_shape, + c.get_shape(events), + c.get_shape(pre_ids), + c.get_shape(post_ids), + c.get_shape(values)), + shape_with_layout=c.get_shape(out), + ) + elif platform == 'gpu': + if gpu_ops is None: + raise ValueError('Cannot find compiled gpu wheels.') + v_type = b'_event_sum2_homo' if values_dim[0] == 1 else b'_event_sum2_heter' + opaque = gpu_ops.build_event_sum2_descriptor(conn_size, post_size) + return x_ops.CustomCallWithLayout( + c, + platform.encode() + v_type + f_type + i_type, + operands=(events, + pre_ids, + post_ids, + values), + operand_shapes_with_layout=(c.get_shape(events), + c.get_shape(pre_ids), + c.get_shape(post_ids), + c.get_shape(values)), + shape_with_layout=c.get_shape(out), + opaque=opaque, + ) + raise ValueError("Unsupported platform; this must be either 'cpu' or 'gpu'") + + +xla.backend_specific_translations["cpu"][_event_sum2_prim] = partial(_event_sum2_translation, platform="cpu") +xla.backend_specific_translations["gpu"][_event_sum2_prim] = partial(_event_sum2_translation, platform="gpu") + + # _event_sum3_prim = core.Primitive("event_sum3") # # diff --git a/extensions/brainpylib/tests/test_atomic_prod.py b/extensions/brainpylib/tests/test_atomic_prod.py index 681803070..14c8ecb96 100644 --- a/extensions/brainpylib/tests/test_atomic_prod.py +++ b/extensions/brainpylib/tests/test_atomic_prod.py @@ -18,7 +18,7 @@ def test_heter_values1(self): post_ids = jnp.arange(size, dtype=jnp.uint32) pre_ids = jnp.arange(size, dtype=jnp.uint32) sps = bp.math.asarray(bp.math.random.randint(0, 2, size), - dtype=bp.math.get_dfloat()) + dtype=bp.math.dftype()) a = atomic_prod(sps.value, post_ids, size, pre_ids) print(a) self.assertTrue(jnp.array_equal(a, sps.value)) diff --git a/extensions/brainpylib/tests/test_atomic_sum.py b/extensions/brainpylib/tests/test_atomic_sum.py index 4d69f0021..761492ce0 100644 --- a/extensions/brainpylib/tests/test_atomic_sum.py +++ b/extensions/brainpylib/tests/test_atomic_sum.py @@ -18,7 +18,7 @@ def test_heter_values1(self): post_ids = jnp.arange(size, dtype=jnp.uint32) pre_ids = jnp.arange(size, dtype=jnp.uint32) sps = bp.math.asarray(bp.math.random.randint(0, 2, size), - dtype=bp.math.get_dfloat()) + dtype=bp.math.dftype()) a = atomic_sum(sps.value, post_ids, size, pre_ids) print(a) self.assertTrue(jnp.array_equal(a, sps.value)) diff --git a/extensions/changelog.rst b/extensions/changelog.rst index e06b32368..809eb07e9 100644 --- a/extensions/changelog.rst +++ b/extensions/changelog.rst @@ -1,6 +1,9 @@ Release notes (brainpylib) ########################## +Version 0.0.6 +============= + Version 0.0.5 ============= diff --git a/extensions/lib/event_sum_gpu.cu b/extensions/lib/event_sum_gpu.cu index 5f12c5d28..e0dee75be 100644 --- a/extensions/lib/event_sum_gpu.cu +++ b/extensions/lib/event_sum_gpu.cu @@ -588,55 +588,6 @@ namespace brainpy_lib { } - template - __global__ void event_sum5_heter_kernel(const std::uint32_t max_post_conn, - const std::uint32_t pre_size, - const bool *events, - const I *indices, - const I *indptr, - const F *values, - F *result) { - __shared__ bool shared_event; - __shared__ I shPreStartID[32]; - __shared__ I shPreEndID[32]; - - if (threadIdx.x == 0) { - if (threadIdx.y == 0){ - shared_event = events[0]; - } - } - __syncthreads(); - - const I id = blockIdx.x * 32 + threadIdx.x; - if (id < max_post_conn) { - const unsigned int num_iter = (pre_size + 32 - 1) / 32; - for (unsigned int r = 0; r < num_iter; r++) { - const unsigned int num_event = (r == num_iter - 1) ? ((pre_size - 1) % 32) + 1 : 32; - // assume "max_post_conn" >= num_event - // TODO: fix the bug - if (threadIdx.x < num_event) { - const unsigned int pre_i = (r * 32) + threadIdx.x; - shared_events[threadIdx.x] = events[pre_i]; - if (shared_events[threadIdx.x]) - { - shPreStartID[threadIdx.x] = indptr[pre_i]; - shRowLength[threadIdx.x] = indptr[pre_i + 1] - shPreStartID[threadIdx.x]; - } - } - __syncthreads(); - for (unsigned int j = 0; j < num_event; j++) { - if (shared_events[j]) { - if (id < shRowLength[j]) { - const I syn_i = shPreStartID[j] + id; - const I post_i = indices[syn_i]; - atomicAdd(&result[post_i], values[syn_i]); - } - } - } - } - } - } - } // namespace diff --git a/extensions/lib/gpu_ops.cc b/extensions/lib/gpu_ops.cc index 1c67b5b9b..6894816c9 100644 --- a/extensions/lib/gpu_ops.cc +++ b/extensions/lib/gpu_ops.cc @@ -26,6 +26,17 @@ namespace { dict["gpu_event_sum_heter_f64_i32"] = EncapsulateFunction(gpu_event_sum_heter_f64_i32); dict["gpu_event_sum_heter_f64_i64"] = EncapsulateFunction(gpu_event_sum_heter_f64_i64); + // homogeneous event_sum2 + dict["gpu_event_sum2_homo_f32_i32"] = EncapsulateFunction(gpu_event_sum2_homo_f32_i32); + dict["gpu_event_sum2_homo_f32_i64"] = EncapsulateFunction(gpu_event_sum2_homo_f32_i64); + dict["gpu_event_sum2_homo_f64_i32"] = EncapsulateFunction(gpu_event_sum2_homo_f64_i32); + dict["gpu_event_sum2_homo_f64_i64"] = EncapsulateFunction(gpu_event_sum2_homo_f64_i64); + // heterogeneous event_sum2 + dict["gpu_event_sum2_heter_f32_i32"] = EncapsulateFunction(gpu_event_sum2_heter_f32_i32); + dict["gpu_event_sum2_heter_f32_i64"] = EncapsulateFunction(gpu_event_sum2_heter_f32_i64); + dict["gpu_event_sum2_heter_f64_i32"] = EncapsulateFunction(gpu_event_sum2_heter_f64_i32); + dict["gpu_event_sum2_heter_f64_i64"] = EncapsulateFunction(gpu_event_sum2_heter_f64_i64); + // homogeneous atomic_sum dict["gpu_atomic_sum_homo_f32_i32"] = EncapsulateFunction(gpu_atomic_sum_homo_f32_i32); dict["gpu_atomic_sum_homo_f32_i64"] = EncapsulateFunction(gpu_atomic_sum_homo_f32_i64); @@ -55,6 +66,7 @@ namespace { ) { m.def("registrations", &Registrations); m.def("build_event_sum_descriptor", &build_event_sum_descriptor); + m.def("build_event_sum2_descriptor", &build_event_sum2_descriptor); m.def("build_atomic_sum_descriptor", &build_atomic_sum_descriptor); m.def("build_atomic_prod_descriptor", &build_atomic_prod_descriptor); } diff --git a/extensions/setup_cuda.py b/extensions/setup_cuda.py index 30f944c0a..9750a1454 100644 --- a/extensions/setup_cuda.py +++ b/extensions/setup_cuda.py @@ -1,10 +1,11 @@ import distutils.sysconfig as sysconfig +import glob import os import platform import re import subprocess import sys -import glob + import pybind11 from setuptools import find_packages, setup, Extension from setuptools.command.build_ext import build_ext @@ -46,7 +47,7 @@ def build_extensions(self): #"-DPython_LIBRARIES={}".format(cmake_python_library), #"-DPython_INCLUDE_DIRS={}".format(cmake_python_include_dir), # "-DCMAKE_BUILD_TYPE={}".format("Debug" if self.debug else "Release"), - # "-DCMAKE_PREFIX_PATH={}".format(pybind11.get_cmake_dir()), + "-DCMAKE_PREFIX_PATH={}".format(os.path.dirname(pybind11.get_cmake_dir())), # "-DCMAKE_CUDA_FLAGS={}".format('"-arch=sm_61"') ] if os.environ.get("BRAINPY_CUDA", "no").lower() == "yes": @@ -77,7 +78,7 @@ def build_extension(self, ext): init_py = f.read() __version__ = re.search('__version__ = "(.*)"', init_py).groups()[0] -cuda_version = os.environ.get("JAX_CUDA_VERSION") +cuda_version = os.environ.get("CUDA_VERSION") if cuda_version: __version__ += "+cuda" + cuda_version.replace(".", "") diff --git a/requirements-dev.txt b/requirements-dev.txt index 1d9812faa..41d92e5f8 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -4,9 +4,10 @@ numba matplotlib>=3.4 jaxlib>=0.3.0 scipy>=1.1.0 -networkx brainpylib>=0.0.5 h5py +requests +pillow # test requirements pytest diff --git a/requirements-doc.txt b/requirements-doc.txt index 0660fc5db..68f8f318e 100644 --- a/requirements-doc.txt +++ b/requirements-doc.txt @@ -5,6 +5,8 @@ jaxlib>=0.3.0 scipy>=1.1.0 brainpylib>=0.0.5 numba +requests +pillow # document requirements pandoc diff --git a/requirements-win.txt b/requirements-win.txt index fb0b01270..0d38ed128 100644 --- a/requirements-win.txt +++ b/requirements-win.txt @@ -4,9 +4,10 @@ numba h5py matplotlib>=3.4 scipy>=1.1.0 -networkx brainpylib>=0.0.5 jaxlib>=0.3.0 +pillow +requests # test requirements pytest