Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions brainpy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

__version__ = "2.3.6"
__version__ = "2.3.7"


# fundamental supporting modules
Expand Down Expand Up @@ -61,20 +61,23 @@
experimental,
)
from brainpy._src.dyn.base import not_pass_shared
from brainpy._src.dyn.base import (DynamicalSystem,
DynamicalSystemNS,
from brainpy._src.dyn.base import (DynamicalSystem as DynamicalSystem,
Container as Container,
Sequential as Sequential,
Network as Network,
NeuGroup as NeuGroup,
NeuGroupNS as NeuGroupNS,
SynConn as SynConn,
SynOut as SynOut,
SynSTP as SynSTP,
SynLTP as SynLTP,
TwoEndConn as TwoEndConn,
CondNeuGroup as CondNeuGroup,
Channel as Channel)
from brainpy._src.dyn.base import (DynamicalSystemNS as DynamicalSystemNS,
NeuGroupNS as NeuGroupNS)
from brainpy._src.dyn.synapses_v2.base import (SynOutNS as SynOutNS,
SynSTPNS as SynSTPNS,
SynConnNS as SynConnNS, )
from brainpy._src.dyn.transform import (LoopOverTime as LoopOverTime,)
from brainpy._src.dyn.runners import (DSRunner as DSRunner) # runner
from brainpy._src.dyn.context import share, Delay
Expand Down
14 changes: 14 additions & 0 deletions brainpy/_src/dyn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,20 @@ def __del__(self):
def clear_input(self):
pass

def __rrshift__(self, other):
"""Support using right shift operator to call modules.

Examples
--------

>>> import brainpy as bp
>>> x = bp.math.random.rand((10, 10))
>>> l = bp.layers.Activation('tanh')
>>> y = x >> l

"""
return self.__call__(other)


class DynamicalSystemNS(DynamicalSystem):
"""Dynamical system without the need of shared parameters passing into ``update()`` function."""
Expand Down
28 changes: 16 additions & 12 deletions brainpy/_src/dyn/layers/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,16 +130,18 @@ def update(self, x):
x = bm.as_jax(x)

if share.load('fit'):
mean = jnp.mean(x, self.axis)
mean_of_square = jnp.mean(_square(x), self.axis)
if self.axis_name is not None:
mean, mean_of_square = jnp.split(lax.pmean(jnp.concatenate([mean, mean_of_square]),
axis_name=self.axis_name,
axis_index_groups=self.axis_index_groups),
2)
var = jnp.maximum(0., mean_of_square - _square(mean))
self.running_mean.value = (self.momentum * self.running_mean + (1 - self.momentum) * mean)
self.running_var.value = (self.momentum * self.running_var + (1 - self.momentum) * var)
mean = jnp.mean(x, self.axis)
mean_of_square = jnp.mean(_square(x), self.axis)
if self.axis_name is not None:
mean, mean_of_square = jnp.split(
lax.pmean(jnp.concatenate([mean, mean_of_square]),
axis_name=self.axis_name,
axis_index_groups=self.axis_index_groups),
2
)
var = jnp.maximum(0., mean_of_square - _square(mean))
self.running_mean.value = (self.momentum * self.running_mean + (1 - self.momentum) * mean)
self.running_var.value = (self.momentum * self.running_var + (1 - self.momentum) * var)
else:
mean = self.running_mean.value
var = self.running_var.value
Expand Down Expand Up @@ -488,7 +490,7 @@ def __init__(
self.bias = bm.TrainVar(parameter(self.bias_initializer, self.normalized_shape))
self.scale = bm.TrainVar(parameter(self.scale_initializer, self.normalized_shape))

def update(self,x):
def update(self, x):
if x.shape[-len(self.normalized_shape):] != self.normalized_shape:
raise ValueError(f'Expect the input shape should be (..., {", ".join(self.normalized_shape)}), '
f'but we got {x.shape}')
Expand Down Expand Up @@ -629,6 +631,8 @@ def __init__(
scale_initializer=scale_initializer,
mode=mode,
name=name)


BatchNorm1D = BatchNorm1d
BatchNorm2D = BatchNorm2d
BatchNorm3D = BatchNorm3d
BatchNorm3D = BatchNorm3d
7 changes: 6 additions & 1 deletion brainpy/_src/dyn/neurons/biological_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@
from brainpy import check
from brainpy._src.dyn.base import NeuGroupNS
from brainpy._src.dyn.context import share
from brainpy._src.initialize import OneInit, Uniform, Initializer, parameter, noise as init_noise, variable_
from brainpy._src.initialize import (OneInit,
Uniform,
Initializer,
parameter,
noise as init_noise,
variable_)
from brainpy._src.integrators.joint_eq import JointEq
from brainpy._src.integrators.ode.generic import odeint
from brainpy._src.integrators.sde.generic import sdeint
Expand Down
2 changes: 1 addition & 1 deletion brainpy/_src/dyn/neurons/input_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import jax.numpy as jnp
from brainpy._src.dyn.context import share
import brainpy.math as bm
from brainpy._src.dyn.base import NeuGroupNS, not_pass_shared
from brainpy._src.dyn.base import NeuGroupNS
from brainpy._src.initialize import Initializer, parameter, variable_
from brainpy.types import Shape, ArrayType

Expand Down
2 changes: 1 addition & 1 deletion brainpy/_src/dyn/neurons/noise_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import jax.numpy as jnp
from brainpy._src.dyn.context import share
from brainpy import math as bm, initialize as init
from brainpy._src.dyn.base import NeuGroupNS as NeuGroup, not_pass_shared
from brainpy._src.dyn.base import NeuGroupNS as NeuGroup
from brainpy._src.initialize import Initializer
from brainpy._src.integrators.sde.generic import sdeint
from brainpy.types import ArrayType, Shape
Expand Down
Loading