From 0b95e2affa142b715eb1b62838abda507f8598b3 Mon Sep 17 00:00:00 2001 From: chaoming Date: Wed, 5 Oct 2022 12:30:35 +0800 Subject: [PATCH 1/3] Add `PoissonInput` --- brainpy/__init__.py | 15 +- brainpy/dyn/layers/conv.py | 2 +- brainpy/dyn/layers/normalization.py | 2 +- brainpy/dyn/layers/nvar.py | 4 +- brainpy/dyn/layers/pooling.py | 2 +- brainpy/dyn/neurons/biological_models.py | 10 +- brainpy/dyn/neurons/input_groups.py | 2 + brainpy/dyn/neurons/reduced_models.py | 277 ++++++++++++----------- brainpy/dyn/synapses/abstract_models.py | 72 +++++- brainpy/modes.py | 6 +- 10 files changed, 242 insertions(+), 150 deletions(-) diff --git a/brainpy/__init__.py b/brainpy/__init__.py index 7a365dc86..63321dd03 100644 --- a/brainpy/__init__.py +++ b/brainpy/__init__.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -__version__ = "2.2.3" +__version__ = "2.2.3.1" try: @@ -8,10 +8,15 @@ del jaxlib except ModuleNotFoundError: raise ModuleNotFoundError( - 'Please install jaxlib. See ' - 'https://brainpy.readthedocs.io/en/latest/quickstart/installation.html#dependency-2-jax ' - 'for installation instructions.' - ) from None + ''' + +Please install jaxlib. See + +https://brainpy.readthedocs.io/en/latest/quickstart/installation.html#dependency-2-jax + +for installation instructions. 
+ + ''') from None # fundamental modules diff --git a/brainpy/dyn/layers/conv.py b/brainpy/dyn/layers/conv.py index 513077b22..26760e49c 100644 --- a/brainpy/dyn/layers/conv.py +++ b/brainpy/dyn/layers/conv.py @@ -6,7 +6,7 @@ import brainpy.math as bm from brainpy.dyn.base import DynamicalSystem from brainpy.initialize import XavierNormal, ZeroInit, parameter -from brainpy.modes import Mode, TrainingMode, NormalMode, training, check +from brainpy.modes import Mode, TrainingMode, training __all__ = [ 'GeneralConv', diff --git a/brainpy/dyn/layers/normalization.py b/brainpy/dyn/layers/normalization.py index 8c7444b23..a996f5ba9 100644 --- a/brainpy/dyn/layers/normalization.py +++ b/brainpy/dyn/layers/normalization.py @@ -9,7 +9,7 @@ import brainpy.math as bm from brainpy.initialize import ZeroInit, OneInit, Initializer, parameter from brainpy.dyn.base import DynamicalSystem -from brainpy.modes import Mode, TrainingMode, NormalMode, training, check +from brainpy.modes import Mode, TrainingMode, NormalMode, training, check_mode __all__ = [ 'BatchNorm', diff --git a/brainpy/dyn/layers/nvar.py b/brainpy/dyn/layers/nvar.py index 553dbaab1..43bd3b4e1 100644 --- a/brainpy/dyn/layers/nvar.py +++ b/brainpy/dyn/layers/nvar.py @@ -8,7 +8,7 @@ import brainpy.math as bm from brainpy.dyn.base import DynamicalSystem -from brainpy.modes import Mode, NormalMode, BatchingMode, batching, check +from brainpy.modes import Mode, NormalMode, BatchingMode, batching, check_mode from brainpy.tools.checking import (check_integer, check_sequence) __all__ = [ @@ -73,7 +73,7 @@ def __init__( name: str = None, ): super(NVAR, self).__init__(mode=mode, name=name) - check(self.mode, (BatchingMode, NormalMode), self.__class__.__name__) + check_mode(self.mode, (BatchingMode, NormalMode), self.__class__.__name__) # parameters order = tuple() if order is None else order diff --git a/brainpy/dyn/layers/pooling.py b/brainpy/dyn/layers/pooling.py index 21d12fa67..4eb85985d 100644 --- 
a/brainpy/dyn/layers/pooling.py +++ b/brainpy/dyn/layers/pooling.py @@ -4,7 +4,7 @@ import jax.lax import brainpy.math as bm from brainpy.dyn.base import DynamicalSystem -from brainpy.modes import Mode, TrainingMode, NormalMode, training, check +from brainpy.modes import Mode, training __all__ = [ 'Pool', diff --git a/brainpy/dyn/neurons/biological_models.py b/brainpy/dyn/neurons/biological_models.py index a691e54f7..fdd7e6a27 100644 --- a/brainpy/dyn/neurons/biological_models.py +++ b/brainpy/dyn/neurons/biological_models.py @@ -8,7 +8,7 @@ from brainpy.integrators.joint_eq import JointEq from brainpy.integrators.ode import odeint from brainpy.integrators.sde import sdeint -from brainpy.modes import Mode, BatchingMode, TrainingMode, NormalMode, normal, check +from brainpy.modes import Mode, BatchingMode, TrainingMode, NormalMode, normal, check_mode from brainpy.tools.checking import check_initializer from brainpy.types import Shape, Array @@ -219,7 +219,7 @@ def __init__( keep_size=keep_size, name=name, mode=mode) - check(self.mode, (BatchingMode, NormalMode), self.__class__.__name__) + check_mode(self.mode, (BatchingMode, NormalMode), self.__class__.__name__) # parameters self.ENa = parameter(ENa, self.varshape, allow_none=False) @@ -427,7 +427,7 @@ def __init__( keep_size=keep_size, name=name, mode=mode) - check(self.mode, (BatchingMode, NormalMode), self.__class__) + check_mode(self.mode, (BatchingMode, NormalMode), self.__class__) # params self.V_Ca = parameter(V_Ca, self.varshape, allow_none=False) @@ -685,7 +685,7 @@ def __init__( keep_size=keep_size, name=name, mode=mode) - check(self.mode, (NormalMode, BatchingMode), self.__class__) + check_mode(self.mode, (NormalMode, BatchingMode), self.__class__) # conductance parameters self.gAHP = parameter(gAHP, self.varshape, allow_none=False) @@ -994,7 +994,7 @@ def __init__( ): # initialization super(WangBuzsakiModel, self).__init__(size=size, keep_size=keep_size, name=name, mode=mode) - check(self.mode, 
(BatchingMode, NormalMode), self.__class__) + check_mode(self.mode, (BatchingMode, NormalMode), self.__class__) # parameters self.ENa = parameter(ENa, self.varshape, allow_none=False) diff --git a/brainpy/dyn/neurons/input_groups.py b/brainpy/dyn/neurons/input_groups.py index 413ac0597..7eba286ff 100644 --- a/brainpy/dyn/neurons/input_groups.py +++ b/brainpy/dyn/neurons/input_groups.py @@ -11,6 +11,7 @@ from brainpy.modes import Mode, BatchingMode, normal from brainpy.types import Shape, Array + __all__ = [ 'InputGroup', 'OutputGroup', @@ -205,3 +206,4 @@ def reset(self, batch_size=None): def reset_state(self, batch_size=None): self.spike.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape) + diff --git a/brainpy/dyn/neurons/reduced_models.py b/brainpy/dyn/neurons/reduced_models.py index 517191d3c..f21dffcc7 100644 --- a/brainpy/dyn/neurons/reduced_models.py +++ b/brainpy/dyn/neurons/reduced_models.py @@ -1,15 +1,16 @@ # -*- coding: utf-8 -*- -from typing import Union, Callable +from typing import Union, Callable, Optional +from functools import partial from jax.lax import stop_gradient import brainpy.math as bm from brainpy.dyn.base import NeuGroup from brainpy.initialize import (ZeroInit, OneInit, Initializer, - parameter, variable, noise as init_noise) + parameter, variable, variable2, noise as init_noise) from brainpy.integrators import sdeint, odeint, JointEq -from brainpy.modes import Mode, NormalMode, BatchingMode, TrainingMode, normal, check +from brainpy.modes import Mode, NormalMode, BatchingMode, TrainingMode, normal, check_mode from brainpy.tools.checking import check_initializer, check_callable from brainpy.types import Shape, Array @@ -87,7 +88,7 @@ def __init__( mode=mode, keep_size=keep_size, name=name) - check(self.mode, (TrainingMode, NormalMode), self.__class__) + check_mode(self.mode, (TrainingMode, NormalMode), self.__class__) # parameters self.V_rest = parameter(V_rest, self.varshape, allow_none=False) @@ -100,8 
+101,8 @@ def __init__( self._V_initializer = V_initializer # variables - self.V = variable(self._V_initializer, mode, self.varshape) - self.input = variable(bm.zeros, mode, self.varshape) + self.V = variable2(self._V_initializer, self.varshape, mode) + self.input = variable2(bm.zeros, self.varshape, mode) # integral if self.noise is None: @@ -113,8 +114,8 @@ def derivative(self, V, t, I_ext): return (-V + self.V_rest + self.R * I_ext) / self.tau def reset_state(self, batch_size=None): - self.V.value = variable(self._V_initializer, batch_size, self.varshape) - self.input.value = variable(bm.zeros, batch_size, self.varshape) + self.V.value = variable2(self._V_initializer, self.varshape, batch_size) + self.input.value = variable2(bm.zeros, self.varshape, batch_size) def update(self, tdi, x=None): if x is not None: self.input += x @@ -191,11 +192,11 @@ def __init__( V_th: Union[float, Array, Initializer, Callable] = 20., R: Union[float, Array, Initializer, Callable] = 1., tau: Union[float, Array, Initializer, Callable] = 10., - tau_ref: Union[float, Array, Initializer, Callable] = None, + tau_ref: Optional[Union[float, Array, Initializer, Callable]] = None, V_initializer: Union[Initializer, Callable, Array] = ZeroInit(), - noise: Union[float, Array, Initializer, Callable] = None, + noise: Optional[Union[float, Array, Initializer, Callable]] = None, method: str = 'exp_auto', - name: str = None, + name: Optional[str] = None, # training parameter mode: Mode = normal, @@ -206,7 +207,7 @@ def __init__( name=name, keep_size=keep_size, mode=mode) - check(self.mode, (TrainingMode, NormalMode), self.__class__) + check_mode(self.mode, (TrainingMode, NormalMode), self.__class__) # parameters self.V_rest = parameter(V_rest, self.varshape, allow_none=False) @@ -223,13 +224,13 @@ def __init__( self._V_initializer = V_initializer # variables - self.V = variable(self._V_initializer, mode, self.varshape) - self.input = variable(bm.zeros, mode, self.varshape) + self.V = 
variable2(self._V_initializer, self.varshape, mode) + self.input = variable2(bm.zeros, self.varshape, mode) sp_type = bm.dftype() if isinstance(mode, TrainingMode) else bool # the gradient of spike is a float - self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), mode, self.varshape) + self.spike = variable2(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, mode) if self.tau_ref is not None: - self.t_last_spike = variable(lambda s: bm.ones(s) * -1e7, mode, self.varshape) - self.refractory = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape) + self.t_last_spike = variable2(lambda s: bm.ones(s) * -1e7, self.varshape, mode) + self.refractory = variable2(lambda s: bm.zeros(s, dtype=bool), self.varshape, mode) # integral if self.noise is None: @@ -241,13 +242,13 @@ def derivative(self, V, t, I_ext): return (-V + self.V_rest + self.R * I_ext) / self.tau def reset_state(self, batch_size=None): - self.V.value = variable(self._V_initializer, batch_size, self.varshape) - self.input.value = variable(bm.zeros, batch_size, self.varshape) + self.V.value = variable2(self._V_initializer, self.varshape, batch_size) + self.input.value = variable2(bm.zeros, self.varshape, batch_size) sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool - self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape) + self.spike.value = variable2(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size) if self.tau_ref is not None: - self.t_last_spike.value = variable(lambda s: bm.ones(s) * -1e7, batch_size, self.varshape) - self.refractory.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape) + self.t_last_spike.value = variable2(lambda s: bm.ones(s) * -1e7, self.varshape, batch_size) + self.refractory.value = variable2(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size) def update(self, tdi, x=None): t, dt = tdi.t, tdi.dt @@ -419,7 +420,7 @@ def __init__( name=name, mode=mode, 
keep_size=keep_size, ) - check(self.mode, (TrainingMode, NormalMode), self.__class__) + check_mode(self.mode, (TrainingMode, NormalMode), self.__class__) # parameters self.V_rest = parameter(V_rest, self.varshape, allow_none=False) @@ -437,13 +438,13 @@ def __init__( self._V_initializer = V_initializer # variables - self.V = variable(V_initializer, mode, self.varshape) - self.input = variable(bm.zeros, mode, self.varshape) + self.V = variable2(V_initializer, self.varshape, mode) + self.input = variable2(bm.zeros, self.varshape, mode) sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool - self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), mode, self.varshape) - self.t_last_spike = variable(lambda s: bm.ones(s) * -1e7, mode, self.varshape) + self.spike = variable2(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, mode) + self.t_last_spike = variable2(lambda s: bm.ones(s) * -1e7, self.varshape, mode) if self.tau_ref is not None: - self.refractory = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape) + self.refractory = variable2(lambda s: bm.zeros(s, dtype=bool), self.varshape, mode) # integral if self.noise is None: @@ -452,13 +453,13 @@ def __init__( self.integral = sdeint(method=method, f=self.derivative, g=self.noise) def reset_state(self, batch_size=None): - self.V.value = variable(self._V_initializer, batch_size, self.varshape) - self.input.value = variable(bm.zeros, batch_size, self.varshape) + self.V.value = variable2(self._V_initializer, self.varshape, batch_size) + self.input.value = variable2(bm.zeros, self.varshape, batch_size) sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool - self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape) - self.t_last_spike.value = variable(lambda s: bm.ones(s) * -1e7, batch_size, self.varshape) + self.spike.value = variable2(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size) + self.t_last_spike.value = 
variable2(lambda s: bm.ones(s) * -1e7, self.varshape, batch_size) if self.tau_ref is not None: - self.refractory.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape) + self.refractory.value = variable2(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size) def derivative(self, V, t, I_ext): exp_v = self.delta_T * bm.exp((V - self.V_T) / self.delta_T) @@ -541,6 +542,7 @@ class AdExIF(NeuGroup): R 1 \ Membrane resistance. tau 10 ms Membrane time constant. Compute by R * C. tau_w 30 ms Time constant of the adaptation current. + tau_ref 0. ms Refractory time. ============= ============== ======== ======================================================================================================================== **Model Variables** @@ -552,6 +554,7 @@ class AdExIF(NeuGroup): w 0 Adaptation current. input 0 External and synaptic input current. spike False Flag to mark whether the neuron is spiking. + refractory False Flag to mark whether the neuron is in refractory period. t_last_spike -1e7 Last spike time stamp. 
================== ================= ========================================================= @@ -575,32 +578,34 @@ def __init__( b: Union[float, Array, Initializer, Callable] = 1., tau: Union[float, Array, Initializer, Callable] = 10., tau_w: Union[float, Array, Initializer, Callable] = 30., + tau_ref: Optional[Union[float, Array, Initializer, Callable]] = 30., R: Union[float, Array, Initializer, Callable] = 1., V_initializer: Union[Initializer, Callable, Array] = ZeroInit(), w_initializer: Union[Initializer, Callable, Array] = ZeroInit(), - noise: Union[float, Array, Initializer, Callable] = None, + noise: Optional[Union[float, Array, Initializer, Callable]] = None, method: str = 'exp_auto', keep_size: bool = False, mode: Mode = normal, - name: str = None + name: Optional[str] = None ): super(AdExIF, self).__init__(size=size, keep_size=keep_size, name=name, mode=mode, ) - check(self.mode, (TrainingMode, NormalMode), self.__class__) + check_mode(self.mode, (TrainingMode, NormalMode), self.__class__) # parameters self.V_rest = parameter(V_rest, self.varshape, allow_none=False) self.V_reset = parameter(V_reset, self.varshape, allow_none=False) self.V_th = parameter(V_th, self.varshape, allow_none=False) self.V_T = parameter(V_T, self.varshape, allow_none=False) - self.delta_T = parameter(delta_T, self.varshape, allow_none=False) self.a = parameter(a, self.varshape, allow_none=False) self.b = parameter(b, self.varshape, allow_none=False) + self.R = parameter(R, self.varshape, allow_none=False) self.tau = parameter(tau, self.varshape, allow_none=False) self.tau_w = parameter(tau_w, self.varshape, allow_none=False) - self.R = parameter(R, self.varshape, allow_none=False) + self.tau_ref = parameter(tau_ref, self.varshape, allow_none=True) + self.delta_T = parameter(delta_T, self.varshape, allow_none=False) self.noise = init_noise(noise, self.varshape, num_vars=2) # initializers @@ -610,12 +615,15 @@ def __init__( self._w_initializer = w_initializer # variables - self.V 
= variable(V_initializer, mode, self.varshape) - self.w = variable(w_initializer, mode, self.varshape) - self.input = variable(bm.zeros, mode, self.varshape) + self.V = variable2(V_initializer, self.varshape, mode) + self.w = variable2(w_initializer, self.varshape, mode) + self.input = variable2(bm.zeros, self.varshape, mode) sp_type = bm.dftype() if isinstance(mode, BatchingMode) else bool - self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), mode, self.varshape) - + self.spike = variable2(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, mode) + if self.tau_ref is not None: + self.refractory = variable2(partial(bm.zeros, dtype=bool), self.varshape, mode) + self.t_last_spike = variable2(lambda s: bm.ones(s) * -1e8, self.varshape, mode) + # functions if self.noise is None: self.integral = odeint(method=method, f=self.derivative) @@ -623,11 +631,16 @@ def __init__( self.integral = sdeint(method=method, f=self.derivative, g=self.noise) def reset_state(self, batch_size=None): - self.V.value = variable(self._V_initializer, batch_size, self.varshape) - self.w.value = variable(self._w_initializer, batch_size, self.varshape) - self.input.value = variable(bm.zeros, batch_size, self.varshape) - sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool - self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape) + self.V.value = variable2(self._V_initializer, self.varshape, batch_size) + self.w.value = variable2(self._w_initializer, self.varshape, batch_size) + self.input.value = variable2(bm.zeros, self.varshape, batch_size) + self.spike.value = variable2(lambda s: bm.zeros(s, dtype=(bm.dftype() + if isinstance(self.mode, TrainingMode) + else bool)), + self.varshape, batch_size) + if self.tau_ref is not None: + self.refractory.value = variable2(partial(bm.zeros, dtype=bool), self.varshape, batch_size) + self.t_last_spike.value = variable2(lambda s: bm.ones(s) * -1e8, self.varshape, batch_size) def dV(self, V, t, w, 
I_ext): exp = self.delta_T * bm.exp((V - self.V_T) / self.delta_T) @@ -646,10 +659,16 @@ def update(self, tdi, x=None): t, dt = tdi.t, tdi.dt if x is not None: self.input += x V, w = self.integral(self.V.value, self.w.value, t, self.input.value, dt) + if self.tau_ref is not None: + refractory = (t - self.t_last_spike) <= self.tau_ref + V = bm.where(refractory, self.V.value, V) spike = V >= self.V_th self.V.value = bm.where(spike, self.V_reset, V) self.w.value = bm.where(spike, w + self.b, w) self.spike.value = spike + if self.tau_ref is not None: + self.refractory.value = bm.logical_or(refractory, spike) + self.t_last_spike.value = bm.where(spike, t, self.t_last_spike) def clear_input(self): self.input[:] = 0. @@ -745,7 +764,7 @@ def __init__( keep_size=keep_size, name=name, mode=mode) - check(self.mode, (TrainingMode, NormalMode), self.__class__) + check_mode(self.mode, (TrainingMode, NormalMode), self.__class__) # parameters self.V_rest = parameter(V_rest, self.varshape, allow_none=False) @@ -763,13 +782,13 @@ def __init__( self._V_initializer = V_initializer # variables - self.V = variable(V_initializer, mode, self.varshape) - self.input = variable(bm.zeros, mode, self.varshape) + self.V = variable2(V_initializer, self.varshape, mode) + self.input = variable2(bm.zeros, self.varshape, mode) sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool - self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), mode, self.varshape) - self.t_last_spike = variable(lambda s: bm.ones(s) * -1e7, mode, self.varshape) + self.spike = variable2(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, mode) + self.t_last_spike = variable2(lambda s: bm.ones(s) * -1e7, self.varshape, mode) if self.tau_ref is not None: - self.refractory = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape) + self.refractory = variable2(lambda s: bm.zeros(s, dtype=bool), self.varshape, mode) # integral if self.noise is None: @@ -778,13 +797,13 @@ def __init__( self.integral 
= sdeint(method=method, f=self.derivative, g=self.noise) def reset_state(self, batch_size=None): - self.V.value = variable(self._V_initializer, batch_size, self.varshape) - self.input.value = variable(bm.zeros, batch_size, self.varshape) + self.V.value = variable2(self._V_initializer, self.varshape, batch_size) + self.input.value = variable2(bm.zeros, self.varshape, batch_size) sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool - self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape) - self.t_last_spike.value = variable(lambda s: bm.ones(s) * -1e7, batch_size, self.varshape) + self.spike.value = variable2(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size) + self.t_last_spike.value = variable2(lambda s: bm.ones(s) * -1e7, self.varshape, batch_size) if self.tau_ref is not None: - self.refractory.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape) + self.refractory.value = variable2(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size) def derivative(self, V, t, I_ext): dVdt = (self.c * (V - self.V_rest) * (V - self.V_c) + self.R * I_ext) / self.tau @@ -914,7 +933,7 @@ def __init__( keep_size=keep_size, name=name, mode=mode, ) - check(self.mode, (TrainingMode, NormalMode), self.__class__) + check_mode(self.mode, (TrainingMode, NormalMode), self.__class__) # parameters self.V_rest = parameter(V_rest, self.varshape, allow_none=False) @@ -935,12 +954,12 @@ def __init__( self._w_initializer = w_initializer # variables - self.V = variable(V_initializer, mode, self.varshape) - self.w = variable(w_initializer, mode, self.varshape) - self.input = variable(bm.zeros, mode, self.varshape) + self.V = variable2(V_initializer, self.varshape, mode) + self.w = variable2(w_initializer, self.varshape, mode) + self.input = variable2(bm.zeros, self.varshape, mode) sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool - self.spike = variable(lambda s: bm.zeros(s, 
dtype=sp_type), mode, self.varshape) - self.refractory = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape) + self.spike = variable2(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, mode) + self.refractory = variable2(lambda s: bm.zeros(s, dtype=bool), self.varshape, mode) # integral if self.noise is None: @@ -949,12 +968,12 @@ def __init__( self.integral = sdeint(method=method, f=self.derivative, g=self.noise) def reset_state(self, batch_size=None): - self.V.value = variable(self._V_initializer, batch_size, self.varshape) - self.w.value = variable(self._w_initializer, batch_size, self.varshape) - self.input.value = variable(bm.zeros, batch_size, self.varshape) + self.V.value = variable2(self._V_initializer, self.varshape, batch_size) + self.w.value = variable2(self._w_initializer, self.varshape, batch_size) + self.input.value = variable2(bm.zeros, self.varshape, batch_size) sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool - self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape) - self.refractory.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape) + self.spike.value = variable2(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size) + self.refractory.value = variable2(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size) def dV(self, V, t, w, I_ext): dVdt = (self.c * (V - self.V_rest) * (V - self.V_c) - w + I_ext) / self.tau @@ -1098,7 +1117,7 @@ def __init__( keep_size=keep_size, name=name, mode=mode) - check(self.mode, (TrainingMode, NormalMode), self.__class__) + check_mode(self.mode, (TrainingMode, NormalMode), self.__class__) # params self.V_rest = parameter(V_rest, self.varshape, allow_none=False) @@ -1129,13 +1148,13 @@ def __init__( self._Vth_initializer = Vth_initializer # variables - self.I1 = variable(I1_initializer, mode, self.varshape) - self.I2 = variable(I2_initializer, mode, self.varshape) - self.V_th = variable(Vth_initializer, 
mode, self.varshape) - self.V = variable(V_initializer, mode, self.varshape) - self.input = variable(bm.zeros, mode, self.varshape) + self.I1 = variable2(I1_initializer, self.varshape, mode) + self.I2 = variable2(I2_initializer, self.varshape, mode) + self.V_th = variable2(Vth_initializer, self.varshape, mode) + self.V = variable2(V_initializer, self.varshape, mode) + self.input = variable2(bm.zeros, self.varshape, mode) sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool - self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), mode, self.varshape) + self.spike = variable2(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, mode) # integral if self.noise is None: @@ -1144,13 +1163,13 @@ def __init__( self.integral = sdeint(method=method, f=self.derivative, g=self.noise) def reset_state(self, batch_size=None): - self.I1.value = variable(self._I1_initializer, batch_size, self.varshape) - self.I2.value = variable(self._I2_initializer, batch_size, self.varshape) - self.V_th.value = variable(self._Vth_initializer, batch_size, self.varshape) - self.V.value = variable(self._V_initializer, batch_size, self.varshape) - self.input.value = variable(bm.zeros, batch_size, self.varshape) + self.I1.value = variable2(self._I1_initializer, self.varshape, batch_size) + self.I2.value = variable2(self._I2_initializer, self.varshape, batch_size) + self.V_th.value = variable2(self._Vth_initializer, self.varshape, batch_size) + self.V.value = variable2(self._V_initializer, self.varshape, batch_size) + self.input.value = variable2(bm.zeros, self.varshape, batch_size) sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool - self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape) + self.spike.value = variable2(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size) def dI1(self, I1, t): return - self.k1 * I1 @@ -1263,7 +1282,7 @@ def __init__( size=size, keep_size=keep_size, mode=mode) - check(self.mode, 
(TrainingMode, NormalMode), self.__class__) + check_mode(self.mode, (TrainingMode, NormalMode), self.__class__) # parameters self.V_rest = parameter(V_rest, self.varshape, allow_none=False) @@ -1284,14 +1303,14 @@ def __init__( self._a_initializer = a_initializer # variables - self.a = variable(a_initializer, mode, self.varshape) - self.V = variable(V_initializer, mode, self.varshape) - self.input = variable(bm.zeros, mode, self.varshape) + self.a = variable2(a_initializer, self.varshape, mode) + self.V = variable2(V_initializer, self.varshape, mode) + self.input = variable2(bm.zeros, self.varshape, mode) sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool - self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), mode, self.varshape) + self.spike = variable2(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, mode) if self.tau_ref is not None: - self.t_last_spike = variable(lambda s: bm.ones(s) * -1e7, mode, self.varshape) - self.refractory = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape) + self.t_last_spike = variable2(lambda s: bm.ones(s) * -1e7, self.varshape, mode) + self.refractory = variable2(lambda s: bm.zeros(s, dtype=bool), self.varshape, mode) # integral if self.noise is None: @@ -1310,14 +1329,14 @@ def derivative(self): return JointEq([self.dV, self.da]) def reset_state(self, batch_size=None): - self.a.value = variable(self._a_initializer, batch_size, self.varshape) - self.V.value = variable(self._V_initializer, batch_size, self.varshape) - self.input.value = variable(bm.zeros, batch_size, self.varshape) + self.a.value = variable2(self._a_initializer, self.varshape, batch_size) + self.V.value = variable2(self._V_initializer, self.varshape, batch_size) + self.input.value = variable2(bm.zeros, self.varshape, batch_size) sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool - self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape) + self.spike.value = 
variable2(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size) if self.tau_ref is not None: - self.t_last_spike.value = variable(lambda s: bm.ones(s) * -1e7, batch_size, self.varshape) - self.refractory.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape) + self.t_last_spike.value = variable2(lambda s: bm.ones(s) * -1e7, self.varshape, batch_size) + self.refractory.value = variable2(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size) def update(self, tdi, x=None): t, dt = tdi.t, tdi.dt @@ -1455,7 +1474,7 @@ def __init__( keep_size=keep_size, name=name, mode=mode) - check(self.mode, (TrainingMode, NormalMode), self.__class__) + check_mode(self.mode, (TrainingMode, NormalMode), self.__class__) # params self.a = parameter(a, self.varshape, allow_none=False) @@ -1474,14 +1493,14 @@ def __init__( self._u_initializer = u_initializer # variables - self.u = variable(u_initializer, mode, self.varshape) - self.V = variable(V_initializer, mode, self.varshape) - self.input = variable(bm.zeros, mode, self.varshape) + self.u = variable2(u_initializer, self.varshape, mode) + self.V = variable2(V_initializer, self.varshape, mode) + self.input = variable2(bm.zeros, self.varshape, mode) sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool - self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), mode, self.varshape) + self.spike = variable2(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, mode) if self.tau_ref is not None: - self.t_last_spike = variable(lambda s: bm.ones(s) * -1e7, mode, self.varshape) - self.refractory = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape) + self.t_last_spike = variable2(lambda s: bm.ones(s) * -1e7, self.varshape, mode) + self.refractory = variable2(lambda s: bm.zeros(s, dtype=bool), self.varshape, mode) # functions if self.noise is None: @@ -1490,14 +1509,14 @@ def __init__( self.integral = sdeint(method=method, f=JointEq([self.dV, self.du]), g=self.noise) def 
reset_state(self, batch_size=None): - self.V.value = variable(self._V_initializer, batch_size, self.varshape) - self.u.value = variable(self._u_initializer, batch_size, self.varshape) - self.input.value = variable(bm.zeros, batch_size, self.varshape) + self.V.value = variable2(self._V_initializer, self.varshape, batch_size) + self.u.value = variable2(self._u_initializer, self.varshape, batch_size) + self.input.value = variable2(bm.zeros, self.varshape, batch_size) sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool - self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape) + self.spike.value = variable2(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size) if self.tau_ref is not None: - self.t_last_spike.value = variable(lambda s: bm.ones(s) * -1e7, batch_size, self.varshape) - self.refractory.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape) + self.t_last_spike.value = variable2(lambda s: bm.ones(s) * -1e7, self.varshape, batch_size) + self.refractory.value = variable2(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size) def dV(self, V, t, u, I_ext): dVdt = 0.04 * V * V + 5 * V + 140 - u + I_ext @@ -1685,7 +1704,7 @@ def __init__( keep_size=keep_size, name=name, mode=mode) - check(self.mode, (TrainingMode, NormalMode), self.__class__) + check_mode(self.mode, (TrainingMode, NormalMode), self.__class__) # parameters self.a = parameter(a, self.varshape, allow_none=False) @@ -1708,12 +1727,12 @@ def __init__( self._z_initializer = z_initializer # variables - self.V = variable(self._V_initializer, mode, self.varshape) - self.y = variable(self._y_initializer, mode, self.varshape) - self.z = variable(self._z_initializer, mode, self.varshape) - self.input = variable(bm.zeros, mode, self.varshape) + self.V = variable2(self._V_initializer, self.varshape, mode) + self.y = variable2(self._y_initializer, self.varshape, mode) + self.z = variable2(self._z_initializer, 
self.varshape, mode) + self.input = variable2(bm.zeros, self.varshape, mode) sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool - self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), mode, self.varshape) + self.spike = variable2(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, mode) # integral if self.noise is None: @@ -1722,12 +1741,12 @@ def __init__( self.integral = sdeint(method=method, f=self.derivative, g=self.noise) def reset_state(self, batch_size=None): - self.V.value = variable(self._V_initializer, batch_size, self.varshape) - self.y.value = variable(self._y_initializer, batch_size, self.varshape) - self.z.value = variable(self._z_initializer, batch_size, self.varshape) - self.input.value = variable(bm.zeros, batch_size, self.varshape) + self.V.value = variable2(self._V_initializer, self.varshape, batch_size) + self.y.value = variable2(self._y_initializer, self.varshape, batch_size) + self.z.value = variable2(self._z_initializer, self.varshape, batch_size) + self.input.value = variable2(bm.zeros, self.varshape, batch_size) sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool - self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape) + self.spike.value = variable2(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size) def dV(self, V, t, y, z, I_ext): return y - self.a * V * V * V + self.b * V * V - z + I_ext @@ -1864,7 +1883,7 @@ def __init__( keep_size=keep_size, name=name, mode=mode) - check(self.mode, (TrainingMode, NormalMode), self.__class__) + check_mode(self.mode, (TrainingMode, NormalMode), self.__class__) # parameters self.a = parameter(a, self.varshape, allow_none=False) @@ -1881,11 +1900,11 @@ def __init__( self._w_initializer = w_initializer # variables - self.V = variable(self._V_initializer, mode, self.varshape) - self.w = variable(self._w_initializer, mode, self.varshape) - self.input = variable(bm.zeros, mode, self.varshape) + self.V = 
variable2(self._V_initializer, self.varshape, mode) + self.w = variable2(self._w_initializer, self.varshape, mode) + self.input = variable2(bm.zeros, self.varshape, mode) sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool - self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), mode, self.varshape) + self.spike = variable2(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, mode) # integral if self.noise is None: @@ -1894,11 +1913,11 @@ def __init__( self.integral = sdeint(method=method, f=self.derivative, g=self.noise) def reset_state(self, batch_size=None): - self.V.value = variable(self._V_initializer, batch_size, self.varshape) - self.w.value = variable(self._w_initializer, batch_size, self.varshape) - self.input.value = variable(bm.zeros, batch_size, self.varshape) + self.V.value = variable2(self._V_initializer, self.varshape, batch_size) + self.w.value = variable2(self._w_initializer, self.varshape, batch_size) + self.input.value = variable2(bm.zeros, self.varshape, batch_size) sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool - self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape) + self.spike.value = variable2(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size) def dV(self, V, t, w, I_ext): return V - V * V * V / 3 - w + I_ext diff --git a/brainpy/dyn/synapses/abstract_models.py b/brainpy/dyn/synapses/abstract_models.py index 711d05be9..17c7b84aa 100644 --- a/brainpy/dyn/synapses/abstract_models.py +++ b/brainpy/dyn/synapses/abstract_models.py @@ -3,14 +3,15 @@ from typing import Union, Dict, Callable, Optional from jax import vmap -from jax.lax import stop_gradient +from jax.lax import stop_gradient, cond import brainpy.math as bm from brainpy.connect import TwoEndConnector, All2All, One2One -from brainpy.dyn.base import NeuGroup, SynOut, SynSTP, TwoEndConn +from brainpy.dyn.base import NeuGroup, SynOut, SynSTP, TwoEndConn, SynConn from brainpy.initialize 
import Initializer, variable from brainpy.integrators import odeint, JointEq -from brainpy.modes import Mode, BatchingMode, normal +from brainpy.tools.checking import check_integer, check_float +from brainpy.modes import Mode, BatchingMode, normal, NormalMode, check_mode from brainpy.types import Array from ..synouts import CUBA, MgBlock @@ -20,6 +21,7 @@ 'DualExponential', 'Alpha', 'NMDA', + 'PoissonInput', ] @@ -882,3 +884,67 @@ def update(self, tdi, pre_spike=None): # output return self.output(post_vs) + + +class PoissonInput(SynConn): + """Poisson Input to the given `Variable`. + + Adds independent Poisson input to a target variable. For large + numbers of inputs, this is much more efficient than creating a + `PoissonGroup`. The synaptic events are generated randomly during the + simulation and are not preloaded and stored in memory. All the inputs must + target the same variable, have the same frequency and same synaptic weight. + All neurons in the target variable receive independent realizations of + Poisson spike trains. + + Parameters + ---------- + target_var: Variable + The variable that is targeted by this input. + num_input: int + The number of inputs. + freq: float + The frequency of each of the inputs. Must be a scalar. + weight: float + The synaptic weight. Must be a scalar. + """ + + def __init__( + self, + target_var: bm.Variable, + num_input: int, + freq: Union[int, float], + weight: Union[int, float], + seed: Optional[int] = None, + mode: Mode = normal, + name: str = None + ): + from ..neurons.input_groups import InputGroup, OutputGroup + super(PoissonInput, self).__init__(InputGroup(1), OutputGroup(1), name=name, mode=mode) + + # check data + if not isinstance(target_var, bm.Variable): + raise TypeError(f'"target_var" must be an instance of Variable. 
' + f'But we got {type(target_var)}: {target_var}') + check_integer(num_input, 'num_input', min_bound=1) + check_float(freq, 'freq', min_bound=0., allow_int=True) + check_float(weight, 'weight', allow_int=True) + check_mode(mode, NormalMode, name=self.__class__.__name__) + + # parameters + self.target_var = target_var + self.num_input = num_input + self.freq = freq + self.weight = weight + self.seed = seed + self.rng = bm.random.RandomState(self.seed) + + def update(self, tdi): + p = self.freq * tdi.dt / 1e3 + a = self.num_input * p + b = self.num_input * (1 - p) + inp = bm.cond((a > 5) * (b > 5), + lambda _: self.rng.normal(a, b * p, self.target_var.shape), + lambda _: self.rng.binomial(self.num_input, p, self.target_var.shape), + None) + self.target_var += inp diff --git a/brainpy/modes.py b/brainpy/modes.py index ceee2740c..d672cff1f 100644 --- a/brainpy/modes.py +++ b/brainpy/modes.py @@ -13,7 +13,7 @@ 'batching', 'training', - 'check', + 'check_mode', ] @@ -42,14 +42,14 @@ class TrainingMode(BatchingMode): training = TrainingMode() -def check(mode, supported_modes, name=''): +def check_mode(mode, supported_modes, name=''): """Check whether the used mode is in the list of the supported models. Parameters ---------- mode: Mode The mode used. - supported_modes: list of type, tuple of type + supported_modes: type, list of type, tuple of type The list of all types to support. name: Any The name. 
From 4417069e57673f502de50379b26deb5623c4dc20 Mon Sep 17 00:00:00 2001 From: chaoming Date: Wed, 5 Oct 2022 12:30:49 +0800 Subject: [PATCH 2/3] update installation setup --- setup.py | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/setup.py b/setup.py index ab529cf28..5dcc2b7a2 100644 --- a/setup.py +++ b/setup.py @@ -38,6 +38,7 @@ README = f.read() # require users to install jaxlib before installing brainpy on Windows platform +requirements = ['numpy>=1.15', 'jax>=0.3.0', 'tqdm'] if sys.platform.startswith('win32') or sys.platform.startswith('cygwin'): try: import jaxlib @@ -52,8 +53,10 @@ ---------------------------------------------------------------------- ''') from None +else: + requirements.append('jaxlib>=0.3.0') - +# installation packages packages = find_packages() if 'docs' in packages: packages.remove('docs') @@ -71,25 +74,20 @@ author_email='chao.brain@qq.com', packages=packages, python_requires='>=3.7', - install_requires=[ - 'numpy>=1.15', - 'jax>=0.3.0', - 'jaxlib>=0.3.0', - 'tqdm', - ], + install_requires=requirements, url='https://github.com/PKU-NIP-Lab/BrainPy', project_urls={ "Bug Tracker": "https://github.com/PKU-NIP-Lab/BrainPy/issues", "Documentation": "https://brainpy.readthedocs.io/", "Source Code": "https://github.com/PKU-NIP-Lab/BrainPy", }, - keywords='computational neuroscience, ' - 'brain-inspired computation, ' - 'dynamical systems, ' - 'differential equations, ' - 'brain modeling, ' - 'brain dynamics modeling, ' - 'brain dynamics programming', + keywords=('computational neuroscience, ' + 'brain-inspired computation, ' + 'dynamical systems, ' + 'differential equations, ' + 'brain modeling, ' + 'brain dynamics modeling, ' + 'brain dynamics programming'), classifiers=[ 'Natural Language :: English', 'Operating System :: OS Independent', From 9e5fe038f28660ab4970f4fd0ca5c15936f66010 Mon Sep 17 00:00:00 2001 From: chaoming Date: Wed, 5 Oct 2022 13:26:09 +0800 Subject: [PATCH 3/3] update 
installation setup --- brainpy/__init__.py | 24 ++++++++++++--- docs/quickstart/installation.rst | 50 ++++++++++++++++++++------------ setup.py | 21 +------------- 3 files changed, 53 insertions(+), 42 deletions(-) diff --git a/brainpy/__init__.py b/brainpy/__init__.py index 63321dd03..f2ed909ce 100644 --- a/brainpy/__init__.py +++ b/brainpy/__init__.py @@ -10,11 +10,27 @@ raise ModuleNotFoundError( ''' -Please install jaxlib. See - -https://brainpy.readthedocs.io/en/latest/quickstart/installation.html#dependency-2-jax +BrainPy needs jaxlib, please install jaxlib. -for installation instructions. +1. If you are using Windows system, install jaxlib through + + >>> pip install jaxlib -f https://whls.blob.core.windows.net/unstable/index.html + +2. If you are using macOS platform, install jaxlib through + + >>> pip install jaxlib -f https://storage.googleapis.com/jax-releases/jax_releases.html + +3. If you are using Linux platform, install jaxlib through + + >>> pip install jaxlib -f https://storage.googleapis.com/jax-releases/jax_releases.html + +4. If you are using Linux + CUDA platform, install jaxlib through + + >>> pip install jaxlib -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html + +Note that the versions of "jax" and "jaxlib" should be consistent, like "jax=0.3.14", "jaxlib=0.3.14". + +More detail installation instruction, please see https://brainpy.readthedocs.io/en/latest/quickstart/installation.html#dependency-2-jax ''') from None diff --git a/docs/quickstart/installation.rst b/docs/quickstart/installation.rst index d0aeebcfa..fb220b731 100644 --- a/docs/quickstart/installation.rst +++ b/docs/quickstart/installation.rst @@ -89,18 +89,18 @@ Linux & MacOS ^^^^^^^^^^^^^ Currently, JAX supports **Linux** (Ubuntu 16.04 or later) and **macOS** (10.12 or -later) platforms. The provided binary releases of JAX for Linux and macOS +later) platforms. 
The provided binary releases of `jax` and `jaxlib` for Linux and macOS systems are available at - for CPU: https://storage.googleapis.com/jax-releases/jax_releases.html - for GPU: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html -To install a CPU-only version of JAX, you can run +If you want to install a CPU-only version of `jax` and `jaxlib`, you can run .. code-block:: bash - pip install --upgrade "jax[cpu]" + pip install --upgrade "jax[cpu]" -f https://storage.googleapis.com/jax-releases/jax_releases.html If you want to install JAX with both CPU and NVidia GPU support, you must first install `CUDA`_ and `CuDNN`_, if they have not already been installed. Next, run @@ -109,7 +109,9 @@ If you want to install JAX with both CPU and NVidia GPU support, you must first pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html -Alternatively, you can download the preferred release ".whl" file for jaxlib, and install it via ``pip``: + +Alternatively, you can download the preferred release ".whl" file for jaxlib +from the above release links, and install it via ``pip``: .. code-block:: bash @@ -117,20 +119,33 @@ Alternatively, you can download the preferred release ".whl" file for jaxlib, an pip install jax==0.3.14 -Note that the versions of `jaxlib` and `jax` should be consistent. +.. note:: + + Note that the versions of `jaxlib` and `jax` should be consistent. + + For example, if you are using `jax==0.3.14`, you should also install `jaxlib==0.3.14`. + Windows ^^^^^^^ -For **Windows** users, JAX can be installed by the following methods: +For **Windows** users, `jax` and `jaxlib` can be installed with community support. +Specifically, you can install `jax` and `jaxlib` through: + +.. code-block:: bash + + pip install "jax[cpu]" -f https://whls.blob.core.windows.net/unstable/index.html + +If you are using GPU, you can install GPU-versioned wheels through: + +.. 
code-block:: bash -- **Method 1**: There are several communities support JAX for Windows, please refer - to the github link for more details: https://github.com/cloudhan/jax-windows-builder . - Simply speaking, the provided binary releases of JAX for Windows - are available at https://whls.blob.core.windows.net/unstable/index.html . + pip install "jax[cuda111]" -f https://whls.blob.core.windows.net/unstable/index.html - You can download the preferred release ".whl" file, and install it via ``pip``: +Alternatively, you can manually install you favourite version of `jax` and `jaxlib` by +downloading binary releases of JAX for Windows from https://whls.blob.core.windows.net/unstable/index.html . +Then install it via ``pip``: .. code-block:: bash @@ -138,13 +153,13 @@ For **Windows** users, JAX can be installed by the following methods: pip install jax==0.3.14 -- **Method 2**: For Windows 10+ system, you can use `Windows Subsystem for Linux (WSL)`_. - The installation guide can be found in `WSL Installation Guide for Windows 10`_. - Then, you can install JAX in WSL just like the installation step in Linux/MacOs. - - -- **Method 3**: You can also `build JAX from source`_. +WSL +^^^ +Moreover, for Windows 10+ system, we recommend using `Windows Subsystem for Linux (WSL)`_. +The installation guide can be found in +`WSL Installation Guide for Windows 10/11 `_. +Then, you can install JAX in WSL just like the installation step in Linux/MacOs. Dependency 3: brainpylib @@ -194,7 +209,6 @@ packages: .. _Matplotlib: https://matplotlib.org/ .. _JAX: https://github.com/google/jax .. _Windows Subsystem for Linux (WSL): https://docs.microsoft.com/en-us/windows/wsl/about -.. _WSL Installation Guide for Windows 10: https://docs.microsoft.com/en-us/windows/wsl/install-win10 .. _build JAX from source: https://jax.readthedocs.io/en/latest/developer.html .. _SymPy: https://github.com/sympy/sympy .. 
_Numba: https://numba.pydata.org/ diff --git a/setup.py b/setup.py index 5dcc2b7a2..a84c87566 100644 --- a/setup.py +++ b/setup.py @@ -37,25 +37,6 @@ with io.open(os.path.join(here, 'README.md'), 'r', encoding='utf-8') as f: README = f.read() -# require users to install jaxlib before installing brainpy on Windows platform -requirements = ['numpy>=1.15', 'jax>=0.3.0', 'tqdm'] -if sys.platform.startswith('win32') or sys.platform.startswith('cygwin'): - try: - import jaxlib - except ModuleNotFoundError: - raise ModuleNotFoundError(''' - ----------------------------------------------------------------------- - We detect that your are using Windows platform. - Please manually install "jaxlib" before installing "brainpy". - See https://whls.blob.core.windows.net/unstable/index.html - for jaxlib's Windows wheels. ----------------------------------------------------------------------- - -''') from None -else: - requirements.append('jaxlib>=0.3.0') - # installation packages packages = find_packages() if 'docs' in packages: @@ -74,7 +55,7 @@ author_email='chao.brain@qq.com', packages=packages, python_requires='>=3.7', - install_requires=requirements, + install_requires=['numpy>=1.15', 'jax>=0.3.0', 'tqdm'], url='https://github.com/PKU-NIP-Lab/BrainPy', project_urls={ "Bug Tracker": "https://github.com/PKU-NIP-Lab/BrainPy/issues",