diff --git a/brainpy/dyn/layers/__init__.py b/brainpy/dyn/layers/__init__.py
index 4f95bda2f..5234225c2 100644
--- a/brainpy/dyn/layers/__init__.py
+++ b/brainpy/dyn/layers/__init__.py
@@ -6,3 +6,5 @@
 from .reservoir import *
 from .rnncells import *
 from .conv import *
+from .normalization import *
+from .pooling import *
\ No newline at end of file
diff --git a/brainpy/dyn/layers/normalization.py b/brainpy/dyn/layers/normalization.py
new file mode 100644
index 000000000..8c7444b23
--- /dev/null
+++ b/brainpy/dyn/layers/normalization.py
@@ -0,0 +1,403 @@
+# -*- coding: utf-8 -*-
+
+from typing import Union
+
+import jax.nn
+import jax.numpy as jnp
+import jax.lax
+
+import brainpy.math as bm
+from brainpy.initialize import ZeroInit, OneInit, Initializer, parameter
+from brainpy.dyn.base import DynamicalSystem
+from brainpy.modes import Mode, TrainingMode, NormalMode, training, check
+
+__all__ = [
+  'BatchNorm',
+  'BatchNorm1d',
+  'BatchNorm2d',
+  'BatchNorm3d',
+  'GroupNorm',
+  'LayerNorm',
+  'InstanceNorm',
+]
+
+
+class BatchNorm(DynamicalSystem):
+  """Batch Normalization node.
+  This layer aims to reduce the internal covariate shift of data. It
+  normalizes a batch of data by fixing the mean and variance of inputs
+  on each feature (channel). Most commonly, the first axis of the data
+  is the batch, and the last is the channel. However, users can specify
+  the axes to be normalized.
+
+  adapted from jax.example_libraries.stax.BatchNorm
+  https://jax.readthedocs.io/en/latest/_modules/jax/example_libraries/stax.html#BatchNorm
+
+  Parameters
+  ----------
+  axis: int, tuple, list
+    axes where the data will be normalized. The feature (channel) axis should be excluded.
+  epsilon: float
+    a value added to the denominator for numerical stability. Default: 1e-5
+  use_bias: bool
+    whether to translate data in refactoring. Default: True
+  use_scale: bool
+    whether to scale data in refactoring. Default: True
+  beta_init: brainpy.init.Initializer
+    an initializer generating the original translation matrix
+  gamma_init: brainpy.init.Initializer
+    an initializer generating the original scaling matrix
+  """
+
+  def __init__(self,
+               axis: Union[int, tuple, list],
+               epsilon: float = 1e-5,
+               use_bias: bool = True,
+               use_scale: bool = True,
+               beta_init: Initializer = ZeroInit(),
+               gamma_init: Initializer = OneInit(),
+               mode: Mode = training,
+               name: str = None,
+               **kwargs):
+    super(BatchNorm, self).__init__(name=name, mode=mode)
+    self.epsilon = epsilon
+    self.bias = use_bias
+    self.scale = use_scale
+    self.beta_init = beta_init if use_bias else ()
+    self.gamma_init = gamma_init if use_scale else ()
+    self.axis = (axis,) if jnp.isscalar(axis) else axis
+
+  def _check_input_dim(self, x):
+    pass
+
+  def update(self, sha, x):
+    self._check_input_dim(x)
+
+    input_shape = tuple(d for i, d in enumerate(x.shape) if i not in self.axis)
+    self.beta = parameter(self.beta_init, input_shape) if self.bias else None
+    self.gamma = parameter(self.gamma_init, input_shape) if self.scale else None
+    if isinstance(self.mode, TrainingMode):
+      self.beta = bm.TrainVar(self.beta)
+      self.gamma = bm.TrainVar(self.gamma)
+
+    ed = tuple(None if i in self.axis else slice(None) for i in range(jnp.ndim(x)))
+    # output = bm.normalize(x, self.axis, epsilon=self.epsilon)
+    output = jax.nn.standardize(x.value, self.axis, epsilon=self.epsilon)
+    if self.bias and self.scale: return self.gamma[ed] * output + self.beta[ed]
+    if self.bias: return output + self.beta[ed]
+    if self.scale: return self.gamma[ed] * output
+    return output
+
+  def reset_state(self, batch_size=None):
+    pass
+
+
+class BatchNorm1d(BatchNorm):
+  """1-D batch normalization.
+  The data should be of `(b, l, c)`, where `b` is the batch dimension,
+  `l` is the layer dimension, and `c` is the channel dimension, or of
+  `(b, c)`.
+
+  Parameters
+  ----------
+  axis: int, tuple, list
+    axes where the data will be normalized. The feature (channel) axis should be excluded.
+  epsilon: float
+    a value added to the denominator for numerical stability. Default: 1e-5
+  use_bias: bool
+    whether to translate data in refactoring. Default: True
+  use_scale: bool
+    whether to scale data in refactoring. Default: True
+  beta_init: brainpy.init.Initializer
+    an initializer generating the original translation matrix
+  gamma_init: brainpy.init.Initializer
+    an initializer generating the original scaling matrix
+  """
+  def __init__(self, axis=(0, 1), **kwargs):
+    super(BatchNorm1d, self).__init__(axis=axis, **kwargs)
+
+  def _check_input_dim(self, x):
+    ndim = len(x.shape)
+    if ndim != 2 and ndim != 3:
+      raise ValueError(
+        "expected 2D or 3D input (got {}D input)".format(ndim)
+      )
+    if ndim == 2 and len(self.axis) == 2:
+      self.axis = (0,)
+
+
+class BatchNorm2d(BatchNorm):
+  """2-D batch normalization.
+  The data should be of `(b, h, w, c)`, where `b` is the batch dimension,
+  `h` is the height dimension, `w` is the width dimension, and `c` is the
+  channel dimension.
+
+  Parameters
+  ----------
+  axis: int, tuple, list
+    axes where the data will be normalized. The feature (channel) axis should be excluded.
+  epsilon: float
+    a value added to the denominator for numerical stability. Default: 1e-5
+  use_bias: bool
+    whether to translate data in refactoring. Default: True
+  use_scale: bool
+    whether to scale data in refactoring.
Default: True + beta_init: brainpy.init.Initializer + an initializer generating the original translation matrix + gamma_init: brainpy.init.Initializer + an initializer generating the original scaling matrix + """ + def __init__(self, axis=(0, 1, 2), **kwargs): + super(BatchNorm2d, self).__init__(axis=axis, **kwargs) + + def _check_input_dim(self, x): + ndim = len(x.shape) + if ndim != 4: + raise ValueError( + "expected 4D input (got {}D input)".format(ndim) + ) + + +class BatchNorm3d(BatchNorm): + """3-D batch normalization. + The data should be of `(b, h, w, d, c)`, where `b` is the batch dimension, + `h` is the height dimension, `w` is the width dimension, `d` is the depth + dimension, and `c` is the channel dimension. + + Parameters + ---------- + axis: int, tuple, list + axes where the data will be normalized. The feature (channel) axis should be excluded. + epsilon: float + a value added to the denominator for numerical stability. Default: 1e-5 + use_bias: bool + whether to translate data in refactoring. Default: True + use_scale: bool + whether to scale data in refactoring. Default: True + beta_init: brainpy.init.Initializer + an initializer generating the original translation matrix + gamma_init: brainpy.init.Initializer + an initializer generating the original scaling matrix + """ + def __init__(self, axis=(0, 1, 2, 3), **kwargs): + super(BatchNorm3d, self).__init__(axis=axis, **kwargs) + + def _check_input_dim(self, x): + ndim = len(x.shape) + if ndim != 5: + raise ValueError( + "expected 5D input (got {}D input)".format(ndim) + ) + + +class LayerNorm(DynamicalSystem): + """Layer normalization (https://arxiv.org/abs/1607.06450). + + This layer normalizes data on each example, independently of the batch. More + specifically, it normalizes data of shape (b, d1, d2, ..., c) on the axes of + the data dimensions and the channel (d1, d2, ..., c). Different from batch + normalization, gamma and beta are assigned to each position (elementwise + operation) instead of the whole channel. If users want to assign a single + gamma and beta to a whole example/whole channel, please use GroupNorm/ + InstanceNorm. + + Parameters + ---------- + epsilon: float + a value added to the denominator for numerical stability. Default: 1e-5 + use_bias: bool + whether to translate data in refactoring. Default: True + use_scale: bool + whether to scale data in refactoring. Default: True + beta_init: brainpy.init.Initializer + an initializer generating the original translation matrix + gamma_init: brainpy.init.Initializer + an initializer generating the original scaling matrix + axis: int, tuple, list + axes where the data will be normalized. The batch axis should be excluded. + """ + def __init__(self, + epsilon: float = 1e-5, + use_bias: bool = True, + use_scale: bool = True, + beta_init: Initializer = ZeroInit(), + gamma_init: Initializer = OneInit(), + axis: Union[int, tuple] = None, + mode: Mode = training, + name: str = None, + **kwargs): + super(LayerNorm, self).__init__(name=name, mode=mode) + self.epsilon = epsilon + self.bias = use_bias + self.scale = use_scale + self.beta_init = beta_init if use_bias else () + self.gamma_init = gamma_init if use_scale else () + self.axis = (axis,) if jnp.isscalar(axis) else axis + + def default_axis(self, x): + # default: the first axis (batch dim) is excluded + return tuple(i for i in range(1, len(x.shape))) + + def update(self, sha, x): + if self.axis is None: + self.axis = self.default_axis(x) + # todo: what if elementwise_affine = False? 
+    input_shape = tuple(d for i, d in enumerate(x.shape) if i in self.axis)
+    self.beta = parameter(self.beta_init, input_shape) if self.bias else None
+    self.gamma = parameter(self.gamma_init, input_shape) if self.scale else None
+    if isinstance(self.mode, TrainingMode):
+      self.beta = bm.TrainVar(self.beta)
+      self.gamma = bm.TrainVar(self.gamma)
+
+    ed = tuple(None if i not in self.axis else slice(None) for i in range(jnp.ndim(x)))
+    output = bm.normalize(x, self.axis, epsilon=self.epsilon)
+    if self.bias and self.scale: return self.gamma[ed] * output + self.beta[ed]
+    if self.bias: return output + self.beta[ed]
+    if self.scale: return self.gamma[ed] * output
+    return output
+
+  def reset_state(self, batch_size=None):
+    pass
+
+
+class GroupNorm(DynamicalSystem):
+  """Group normalization layer.
+
+  This layer divides channels into groups and normalizes the features within each
+  group. Its computation is also independent of the batch size. The feature size
+  must be a multiple of the group size.
+
+  The shape of the data should be (b, d1, d2, ..., c), where `b` denotes the batch
+  size and `c` denotes the feature (channel) size. The `b` and `c` axes should be
+  excluded from parameter `axis`.
+
+  Parameters
+  ----------
+  num_groups: int
+    the number of groups. It should be a factor of the number of features.
+  group_size: int
+    the group size. It should be equal to `num_features // num_groups`.
+    Either `num_groups` or `group_size` should be specified.
+  epsilon: float
+    a value added to the denominator for numerical stability. Default: 1e-5
+  use_bias: bool
+    whether to translate data in refactoring. Default: True
+  use_scale: bool
+    whether to scale data in refactoring. Default: True
+  beta_init: brainpy.init.Initializer
+    an initializer generating the original translation matrix
+  gamma_init: brainpy.init.Initializer
+    an initializer generating the original scaling matrix
+  axis: int, tuple, list
+    axes where the data will be normalized. Besides the batch axis, the channel
+    axis should also be excluded, since it will be automatically added to `axis`.
+  """
+  def __init__(self,
+               num_groups: int = None,
+               group_size: int = None,
+               epsilon: float = 1e-5,
+               use_bias: bool = True,
+               use_scale: bool = True,
+               beta_init: Initializer = ZeroInit(),
+               gamma_init: Initializer = OneInit(),
+               axis: Union[int, tuple] = None,
+               mode: Mode = training,
+               name: str = None,
+               **kwargs):
+    super(GroupNorm, self).__init__(name=name, mode=mode)
+    self.num_groups = num_groups
+    self.group_size = group_size
+    self.epsilon = epsilon
+    self.bias = use_bias
+    self.scale = use_scale
+    self.beta_init = beta_init if use_bias else ()
+    self.gamma_init = gamma_init if use_scale else ()
+    self.norm_axis = (axis,) if jnp.isscalar(axis) else axis
+
+  def update(self, sha, x):
+    num_channels = x.shape[-1]
+    self.ndim = len(x.shape)
+
+    # compute num_groups and group_size
+    if ((self.num_groups is None and self.group_size is None) or
+        (self.num_groups is not None and self.group_size is not None)):
+      raise ValueError('Exactly one of `num_groups` and `group_size` should be '
+                       'specified; the other one will be automatically computed.')
+
+    if self.num_groups is None:
+      assert self.group_size > 0, '`group_size` should be a positive integer.'
+ if num_channels % self.group_size != 0: + raise ValueError('The number of channels ({}) is not multiple of the ' + 'group size ({}).'.format(num_channels, self.group_size)) + else: + self.num_groups = num_channels // self.group_size + else: # self.num_groups is not None: + assert self.num_groups > 0, '`num_groups` should be a positive integer.' + if num_channels % self.num_groups != 0: + raise ValueError('The number of channels ({}) is not multiple of the ' + 'number of groups ({}).'.format(num_channels, self.num_groups)) + else: + self.group_size = num_channels // self.num_groups + + # axes for normalization + if self.norm_axis is None: + # default: the first axis (batch dim) and the second-last axis (num_group dim) are excluded + self.norm_axis = tuple(i for i in range(1, len(x.shape) - 1)) + (self.ndim,) + + group_shape = x.shape[:-1] + (self.num_groups, self.group_size) + input_shape = tuple(d for i, d in enumerate(group_shape) if i in self.norm_axis) + self.beta = parameter(self.beta_init, input_shape) if self.bias else None + self.gamma = parameter(self.gamma_init, input_shape) if self.scale else None + if isinstance(self.mode, TrainingMode): + self.beta = bm.TrainVar(self.beta) + self.gamma = bm.TrainVar(self.gamma) + + group_shape = x.shape[:-1] + (self.num_groups, self.group_size) + ff_reshape = x.reshape(group_shape) + ed = tuple(None if i not in self.norm_axis else slice(None) for i in range(jnp.ndim(ff_reshape))) + output = bm.normalize(ff_reshape, self.norm_axis, epsilon=self.epsilon) + if self.bias and self.scale: + output = self.gamma[ed] * output + self.beta[ed] + elif self.bias: + output = output + self.beta[ed] + elif self.scale: + output = self.gamma[ed] * output + return output.reshape(x.shape) + + +class InstanceNorm(GroupNorm): + """Instance normalization layer. + + This layer normalizes the data within each feature. It can be regarded as + a group normalization layer in which `group_size` equals to 1. + + Parameters + ---------- + epsilon: float + a value added to the denominator for numerical stability. Default: 1e-5 + use_bias: bool + whether to translate data in refactoring. Default: True + use_scale: bool + whether to scale data in refactoring. Default: True + beta_init: brainpy.init.Initializer + an initializer generating the original translation matrix + gamma_init: brainpy.init.Initializer + an initializer generating the original scaling matrix + axis: int, tuple, list + axes where the data will be normalized. The batch and channel axes + should be excluded. 
+ """ + def __init__(self, + epsilon: float = 1e-5, + use_bias: bool = True, + use_scale: bool = True, + beta_init: Initializer = ZeroInit(), + gamma_init: Initializer = OneInit(), + axis: Union[int, tuple] = None, + **kwargs): + super(InstanceNorm, self).__init__(group_size=1, epsilon=epsilon, use_bias=use_bias, + use_scale=use_scale, beta_init=beta_init, + gamma_init=gamma_init, axis=axis, **kwargs) \ No newline at end of file diff --git a/brainpy/dyn/layers/pooling.py b/brainpy/dyn/layers/pooling.py new file mode 100644 index 000000000..21d12fa67 --- /dev/null +++ b/brainpy/dyn/layers/pooling.py @@ -0,0 +1,159 @@ +# -*- coding: utf-8 -*- + + +import jax.lax +import brainpy.math as bm +from brainpy.dyn.base import DynamicalSystem +from brainpy.modes import Mode, TrainingMode, NormalMode, training, check + +__all__ = [ + 'Pool', + 'MaxPool', + 'AvgPool', + 'MinPool' +] + + +class Pool(DynamicalSystem): + def __init__(self, init_v, reduce_fn, window_shape, strides, padding, + mode: Mode = training, + name: str = None, + **kwargs): + """Pooling functions are implemented using the ReduceWindow XLA op. + + Args: + init_v: scalar + the initial value for the reduction + reduce_fn: callable + a reduce function of the form `(T, T) -> T`. + window_shape: tuple + a shape tuple defining the window to reduce over. + strides: sequence[int] + a sequence of `n` integers, representing the inter-window strides. + padding: str, sequence[int] + either the string `'SAME'`, the string `'VALID'`, or a sequence + of `n` `(low, high)` integer pairs that give the padding to apply before + and after each spatial dimension. + + Returns: + The output of the reduction for each window slice. + """ + super(Pool, self).__init__(name=name, mode=mode) + self.init_v = init_v + self.reduce_fn = reduce_fn + self.window_shape = window_shape + self.strides = strides or (1,) * len(window_shape) + assert len(self.window_shape) == len(self.strides), ( + f"len({self.window_shape}) must equal len({self.strides})") + self.strides = (1,) + self.strides + (1,) + self.dims = (1,) + window_shape + (1,) + self.is_single_input = False + + if not isinstance(padding, str): + padding = tuple(map(tuple, padding)) + assert len(padding) == len(window_shape), ( + f"padding {padding} must specify pads for same number of dims as " + f"window_shape {window_shape}") + assert all([len(x) == 2 for x in padding]), ( + f"each entry in padding {padding} must be length 2") + padding = ((0, 0),) + padding + ((0, 0),) + self.padding = padding + + def update(self, sha, x): + input_shapes = tuple(d for d in x.shape if d is not None) + assert len(input_shapes) == len(self.dims), f"len({len(input_shapes)}) != len({self.dims})" + + # padding_vals = jax.lax.padtype_to_pads(input_shapes, self.dims, self.strides, self.padding) + # ones = (1,) * len(self.dims) + # out_shapes = jax.lax.reduce_window_shape_tuple( + # input_shapes, self.dims, self.strides, padding_vals, ones, ones) + # + # out_shapes = tuple((None,)) + tuple(d for i, d in enumerate(out_shapes) if i != 0) + + y = jax.lax.reduce_window(x, self.init_v, self.reduce_fn, self.dims, self.strides, self.padding) + + return y + + +class AvgPool(Pool): + """Pools the input by taking the average over a window. + + Args: + window_shape: tuple + a shape tuple defining the window to reduce over. + strides: sequence[int] + a sequence of `n` integers, representing the inter-window strides (default: `(1, ..., 1)`). 
+ padding: str, sequence[int] + either the string `'SAME'`, the string `'VALID'`, or a sequence + of `n` `(low, high)` integer pairs that give the padding to apply before + and after each spatial dimension (default: `'VALID'`). + + Returns: + The average for each window slice. + """ + + def __init__(self, window_shape, strides=None, padding="VALID"): + super(AvgPool, self).__init__( + init_v=0., + reduce_fn=jax.lax.add, + window_shape=window_shape, + strides=strides, + padding=padding + ) + + def update(self, sha, x): + y = jax.lax.reduce_window(x, self.init_v, self.reduce_fn, self.dims, self.strides, self.padding) + y = y / bm.prod(bm.asarray(self.window_shape)) + return y + + +class MaxPool(Pool): + """Pools the input by taking the maximum over a window. + + Args: + window_shape: tuple + a shape tuple defining the window to reduce over. + strides: sequence[int] + a sequence of `n` integers, representing the inter-window strides (default: `(1, ..., 1)`). + padding: str, sequence[int] + either the string `'SAME'`, the string `'VALID'`, or a sequence + of `n` `(low, high)` integer pairs that give the padding to apply before + and after each spatial dimension (default: `'VALID'`). + + Returns: + The maximum for each window slice. + """ + def __init__(self, window_shape, strides=None, padding="VALID"): + super(MaxPool, self).__init__( + init_v=-bm.inf, + reduce_fn=jax.lax.max, + window_shape=window_shape, + strides=strides, + padding=padding + ) + + +class MinPool(Pool): + """Pools the input by taking the minimum over a window. + + Args: + window_shape: tuple + a shape tuple defining the window to reduce over. + strides: sequence[int] + a sequence of `n` integers, representing the inter-window strides (default: `(1, ..., 1)`). + padding: str, sequence[int] + either the string `'SAME'`, the string `'VALID'`, or a sequence + of `n` `(low, high)` integer pairs that give the padding to apply before + and after each spatial dimension (default: `'VALID'`). + + Returns: + The minimum for each window slice. + """ + def __init__(self, window_shape, strides=None, padding="VALID"): + super(MinPool, self).__init__( + init_v=bm.inf, + reduce_fn=jax.lax.min, + window_shape=window_shape, + strides=strides, + padding=padding + ) \ No newline at end of file diff --git a/brainpy/dyn/layers/tests/test_normalization.py b/brainpy/dyn/layers/tests/test_normalization.py new file mode 100644 index 000000000..56bc2cca2 --- /dev/null +++ b/brainpy/dyn/layers/tests/test_normalization.py @@ -0,0 +1,201 @@ +# -*- coding: utf-8 -*- + + +from unittest import TestCase + +import brainpy as bp + + +class TestBatchNorm1d(TestCase): + def test_batchnorm1d1(self): + class BatchNormNet(bp.dyn.DynamicalSystem): + def __init__(self): + super(BatchNormNet, self).__init__() + self.norm = bp.dyn.layers.BatchNorm1d(axis=(0, 1, 2)) + + def update(self, shared, x): + x = self.norm(shared, x) + return x + + inputs = bp.math.ones((2, 3, 4)) + inputs[0, 0, :] = 2. + inputs[0, 1, 0] = 5. + print(inputs) + model = BatchNormNet() + shared = {'fit': False} + print(model(shared, inputs)) + + def test_batchnorm1d2(self): + class BatchNormNet(bp.dyn.DynamicalSystem): + def __init__(self): + super(BatchNormNet, self).__init__() + self.norm = bp.dyn.layers.BatchNorm1d() + self.dense = bp.dyn.layers.Dense(num_in=4, num_out=4) + + def update(self, shared, x): + x = self.norm(shared, x) + x = self.dense(shared, x) + return x + + inputs = bp.math.ones((2, 4)) + inputs[0, :] = 2. 
+ print(inputs) + model = BatchNormNet() + shared = {'fit': False} + print(model(shared, inputs)) + + +class TestBatchNorm2d(TestCase): + def test_batchnorm2d(self): + class BatchNormNet(bp.dyn.DynamicalSystem): + def __init__(self): + super(BatchNormNet, self).__init__() + self.norm = bp.dyn.layers.BatchNorm2d() + + def update(self, shared, x): + x = self.norm(shared, x) + return x + + inputs = bp.math.ones((10, 32, 32, 3)) + inputs[0, 1, :, :] = 2. + print(inputs) + model = BatchNormNet() + shared = {'fit': False} + print(model(shared, inputs)) + + +class TestBatchNorm3d(TestCase): + def test_batchnorm3d(self): + class BatchNormNet(bp.dyn.DynamicalSystem): + def __init__(self): + super(BatchNormNet, self).__init__() + self.norm = bp.dyn.layers.BatchNorm3d() + + def update(self, shared, x): + x = self.norm(shared, x) + return x + + inputs = bp.math.ones((10, 32, 32, 16, 3)) + print(inputs) + model = BatchNormNet() + shared = {'fit': False} + print(model(shared, inputs)) + + +class TestBatchNorm(TestCase): + def test_batchnorm1(self): + class BatchNormNet(bp.dyn.DynamicalSystem): + def __init__(self): + super(BatchNormNet, self).__init__() + self.norm = bp.dyn.layers.BatchNorm(axis=(0, 2), use_bias=False) # channel axis: 1 + + def update(self, shared, x): + x = self.norm(shared, x) + return x + + inputs = bp.math.ones((2, 3, 4)) + inputs[0, 0, :] = 2. + inputs[0, 1, 0] = 5. + print(inputs) + model = BatchNormNet() + shared = {'fit': False} + print(model(shared, inputs)) + + def test_batchnorm2(self): + class BatchNormNet(bp.dyn.DynamicalSystem): + def __init__(self): + super(BatchNormNet, self).__init__() + self.norm = bp.dyn.layers.BatchNorm(axis=(0, 2)) # channel axis: 1 + self.dense = bp.dyn.layers.Dense(num_in=12, num_out=2) + + def update(self, shared, x): + x = self.norm(shared, x) + x = x.reshape(-1, 12) + x = self.dense(shared, x) + return x + + inputs = bp.math.ones((2, 3, 4)) + inputs[0, 0, :] = 2. + inputs[0, 1, 0] = 5. + # print(inputs) + model = BatchNormNet() + shared = {'fit': False} + print(model(shared, inputs)) + + +class TestLayerNorm(TestCase): + def test_layernorm1(self): + class LayerNormNet(bp.dyn.DynamicalSystem): + def __init__(self): + super(LayerNormNet, self).__init__() + self.norm = bp.dyn.layers.LayerNorm() + + def update(self, shared, x): + x = self.norm(shared, x) + return x + + inputs = bp.math.ones((2, 3, 4)) + inputs[0, 0, :] = 2. + inputs[0, 1, 0] = 5. + print(inputs) + model = LayerNormNet() + shared = {'fit': False} + print(model(shared, inputs)) + + def test_layernorm2(self): + class LayerNormNet(bp.dyn.DynamicalSystem): + def __init__(self): + super(LayerNormNet, self).__init__() + self.norm = bp.dyn.layers.LayerNorm(axis=2) + + def update(self, shared, x): + x = self.norm(shared, x) + return x + + inputs = bp.math.ones((2, 3, 4)) + inputs[0, 0, :] = 2. + inputs[0, 1, 0] = 5. + print(inputs) + model = LayerNormNet() + shared = {'fit': False} + print(model(shared, inputs)) + + +class TestInstanceNorm(TestCase): + def test_instancenorm(self): + class InstanceNormNet(bp.dyn.DynamicalSystem): + def __init__(self): + super(InstanceNormNet, self).__init__() + self.norm = bp.dyn.layers.InstanceNorm() + + def update(self, shared, x): + x = self.norm(shared, x) + return x + + inputs = bp.math.ones((2, 3, 4)) + inputs[0, 0, :] = 2. + inputs[0, 1, 0] = 5. 
+ print(inputs) + model = InstanceNormNet() + shared = {'fit': False} + print(model(shared, inputs)) + + +class TestGroupNorm(TestCase): + def test_groupnorm1(self): + class GroupNormNet(bp.dyn.DynamicalSystem): + def __init__(self): + super(GroupNormNet, self).__init__() + self.norm = bp.dyn.layers.GroupNorm(num_groups=2) + + def update(self, shared, x): + x = self.norm(shared, x) + return x + + inputs = bp.math.ones((2, 3, 4)) + inputs[0, 0, :] = 2. + inputs[0, 1, 0] = 5. + print(inputs) + model = GroupNormNet() + shared = {'fit': False} + print(model(shared, inputs)) \ No newline at end of file diff --git a/brainpy/dyn/layers/tests/test_pooling.py b/brainpy/dyn/layers/tests/test_pooling.py new file mode 100644 index 000000000..45502cfdc --- /dev/null +++ b/brainpy/dyn/layers/tests/test_pooling.py @@ -0,0 +1,70 @@ +# -*- coding: utf-8 -*- +import random + +import pytest +from unittest import TestCase +import brainpy as bp +import jax.numpy as jnp +import jax +import numpy as np + + +class TestPool(TestCase): + def test_maxpool(self): + class MaxPoolNet(bp.dyn.DynamicalSystem): + def __init__(self): + super(MaxPoolNet, self).__init__() + self.maxpool = bp.dyn.layers.MaxPool((2, 2)) + + def update(self, sha, x): + x = self.maxpool(sha, x) + return x + + x = jnp.arange(9).reshape((1, 3, 3, 1)).astype(jnp.float32) + shared = {'fit': False} + net = MaxPoolNet() + y = net(shared, x) + print("out shape: ", y.shape) + expected_y = jnp.array([ + [4., 5.], + [7., 8.], + ]).reshape((1, 2, 2, 1)) + np.testing.assert_allclose(y, expected_y) + + def test_minpool(self): + class MinPoolNet(bp.dyn.DynamicalSystem): + def __init__(self): + super(MinPoolNet, self).__init__() + self.maxpool = bp.dyn.layers.MinPool((2, 2)) + + def update(self, sha, x): + x = self.maxpool(sha, x) + return x + + x = jnp.arange(9).reshape((1, 3, 3, 1)).astype(jnp.float32) + shared = {'fit': False} + net = MinPoolNet() + y = net(shared, x) + print("out shape: ", y.shape) + expected_y = jnp.array([ + [0., 1.], + [3., 4.], + ]).reshape((1, 2, 2, 1)) + np.testing.assert_allclose(y, expected_y) + + def test_avgpool(self): + class AvgPoolNet(bp.dyn.DynamicalSystem): + def __init__(self): + super(AvgPoolNet, self).__init__() + self.maxpool = bp.dyn.layers.AvgPool((2, 2)) + + def update(self, sha, x): + x = self.maxpool(sha, x) + return x + + x = jnp.full((1, 3, 3, 1), 2.) 
+ shared = {'fit': False} + net = AvgPoolNet() + y = net(shared, x) + print("out shape: ", y.shape) + np.testing.assert_allclose(y, np.full((1, 2, 2, 1), 2.)) \ No newline at end of file diff --git a/brainpy/math/operators/op_register.py b/brainpy/math/operators/op_register.py index 12846e0e0..1ac3b2eb0 100644 --- a/brainpy/math/operators/op_register.py +++ b/brainpy/math/operators/op_register.py @@ -77,12 +77,11 @@ def __init__( gpu_func = None # register OP - _, self.op = brainpylib.register_op(self.name, - cpu_func=cpu_func, - gpu_func=gpu_func, - out_shapes=eval_shape, - apply_cpu_func_to_gpu=apply_cpu_func_to_gpu, - return_primitive=True) + self.op = brainpylib.register_op(self.name, + cpu_func=cpu_func, + gpu_func=gpu_func, + out_shapes=eval_shape, + apply_cpu_func_to_gpu=apply_cpu_func_to_gpu) def __call__(self, *args, **kwargs): args = tree_map(lambda a: a.value if isinstance(a, JaxArray) else a, @@ -131,6 +130,7 @@ def register_op( def fixed_op(*inputs): inputs = tuple([i.value if isinstance(i, JaxArray) else i for i in inputs]) - return f(*inputs) + res = f.bind(*inputs) + return res[0] if len(res) == 1 else res return fixed_op diff --git a/brainpy/math/operators/tests/test_op_register.py b/brainpy/math/operators/tests/test_op_register.py index d253cc0fe..f9f99aea6 100644 --- a/brainpy/math/operators/tests/test_op_register.py +++ b/brainpy/math/operators/tests/test_op_register.py @@ -24,6 +24,7 @@ def event_sum_op(outs, ins): event_sum = bm.register_op(name='event_sum', cpu_func=event_sum_op, eval_shape=abs_eval) +event_sum2 = bm.XLACustomOp(name='event_sum', cpu_func=event_sum_op, eval_shape=abs_eval) event_sum = bm.jit(event_sum) @@ -83,6 +84,36 @@ def update(self, tdi): self.post.input += self.g * (self.E - self.post.V) +class ExponentialSyn3(bp.dyn.TwoEndConn): + def __init__(self, pre, post, conn, g_max=1., delay=0., tau=8.0, E=0., + method='exp_auto'): + super(ExponentialSyn3, self).__init__(pre=pre, post=post, conn=conn) + self.check_pre_attrs('spike') + self.check_post_attrs('input', 'V') + + # parameters + self.E = E + self.tau = tau + self.delay = delay + self.g_max = g_max + self.pre2post = self.conn.require('pre2post') + + # variables + self.g = bm.Variable(bm.zeros(self.post.num)) + + # function + self.integral = bp.odeint(lambda g, t: -g / self.tau, method=method) + + def update(self, tdi): + self.g.value = self.integral(self.g, tdi['t'], tdi['dt']) + # Customized operator + # ------------------------------------------------------------------------------------------------------------ + post_val = bm.zeros(self.post.num) + self.g += event_sum2(self.pre.spike, self.pre2post[0], self.pre2post[1], post_val, self.g_max) + # ------------------------------------------------------------------------------------------------------------ + self.post.input += self.g * (self.E - self.post.V) + + class EINet(bp.dyn.Network): def __init__(self, syn_class, scale=1.0, method='exp_auto', ): super(EINet, self).__init__() @@ -111,7 +142,7 @@ def __init__(self, syn_class, scale=1.0, method='exp_auto', ): class TestOpRegister(unittest.TestCase): def test_op(self): - fig, gs = bp.visualize.get_figure(1, 2, 4, 5) + fig, gs = bp.visualize.get_figure(1, 3, 4, 5) net = EINet(ExponentialSyn, scale=1., method='euler') runner = bp.dyn.DSRunner( @@ -133,5 +164,16 @@ def test_op(self): t, _ = runner2.run(100., eval_time=True) print(t) ax = fig.add_subplot(gs[0, 1]) - bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], ax=ax, show=True) + bp.visualize.raster_plot(runner2.mon.ts, 
runner2.mon['E.spike'], ax=ax)
+
+    net3 = EINet(ExponentialSyn3, scale=1., method='euler')
+    runner3 = bp.dyn.DSRunner(
+      net3,
+      inputs=[(net3.E.input, 20.), (net3.I.input, 20.)],
+      monitors={'E.spike': net3.E.spike},
+    )
+    t, _ = runner3.run(100., eval_time=True)
+    print(t)
+    ax = fig.add_subplot(gs[0, 2])
+    bp.visualize.raster_plot(runner3.mon.ts, runner3.mon['E.spike'], ax=ax, show=True)
     plt.close()
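
Reviewer note: the snippet below is a minimal usage sketch of the layers added in this patch; it is not part of the diff. It follows the channels-last layout and the `shared`-dict calling convention used in the new tests, and the input shapes are illustrative only.

# -*- coding: utf-8 -*-
import brainpy as bp
import brainpy.math as bm
import jax.numpy as jnp

shared = {'fit': False}

# Batch normalization over (batch, height, width), leaving the channel axis alone.
norm = bp.dyn.layers.BatchNorm2d()
x = bm.ones((8, 32, 32, 3))     # (batch, height, width, channel)
y = norm(shared, x)
print(y.shape)                  # (8, 32, 32, 3)

# 2x2 max pooling with unit strides and the default 'VALID' padding.
pool = bp.dyn.layers.MaxPool((2, 2))
z = pool(shared, jnp.ones((8, 32, 32, 3)))
print(z.shape)                  # (8, 31, 31, 3): 32 - 2 + 1 along height and width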