Skip to content

Commit

Permalink
fix bugs in brainpy.math.random.truncated_normal (#574)
Browse files Browse the repository at this point in the history
* [math] fix bugs in `brainpy.math.random.truncated_normal`

* fix requirements

* fix

* fix init bug

* fix test

* update conv doc

* [random] change the algorithm of `truncated_normal` sampling method
  • Loading branch information
chaoming0625 committed Jan 2, 2024
1 parent 8320edc commit 256cb27
Show file tree
Hide file tree
Showing 9 changed files with 2,790 additions and 250 deletions.
2 changes: 0 additions & 2 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ jobs:
run: |
python -m pip install --upgrade pip
python -m pip install flake8 pytest
pip install taichi-nightly -i https://pypi.taichi.graphics/simple/
if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi
pip uninstall brainpy -y
python setup.py install
Expand Down Expand Up @@ -103,7 +102,6 @@ jobs:
run: |
python -m pip install --upgrade pip
python -m pip install flake8 pytest
pip install taichi-nightly -i https://pypi.taichi.graphics/simple/
if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi
pip uninstall brainpy -y
python setup.py install
Expand Down
2 changes: 1 addition & 1 deletion brainpy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

__version__ = "2.4.6.post4"
__version__ = "2.4.6.post5"

# fundamental supporting modules
from brainpy import errors, check, tools
Expand Down
39 changes: 14 additions & 25 deletions brainpy/_src/dnn/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@

from jax import lax

from brainpy import math as bm, tools, check
from brainpy import math as bm, tools
from brainpy._src.dnn.base import Layer
from brainpy._src.initialize import Initializer, XavierNormal, ZeroInit, parameter
from brainpy.types import ArrayType
from brainpy._src.dnn.base import Layer

__all__ = [
'Conv1d', 'Conv2d', 'Conv3d',
Expand Down Expand Up @@ -488,9 +488,7 @@ def __init__(
mode: bm.Mode = None,
name: str = None,
):
super(_GeneralConvTranspose, self).__init__(name=name, mode=mode)

assert self.mode.is_parent_of(bm.TrainingMode, bm.BatchingMode, bm.NonBatchingMode)
super().__init__(name=name, mode=mode)

self.num_spatial_dims = num_spatial_dims
self.in_channels = in_channels
Expand Down Expand Up @@ -586,22 +584,17 @@ def __init__(
"""Initializes the module.
Args:
output_channels: Number of output channels.
kernel_shape: The shape of the kernel. Either an integer or a sequence of
in_channels: Number of input channels.
out_channels: Number of output channels.
kernel_size: The shape of the kernel. Either an integer or a sequence of
length 1.
stride: Optional stride for the kernel. Either an integer or a sequence of
length 1. Defaults to 1.
output_shape: Output shape of the spatial dimensions of a transpose
convolution. Can be either an integer or an iterable of integers. If a
`None` value is given, a default shape is automatically calculated.
padding: Optional padding algorithm. Either ``VALID`` or ``SAME``.
Defaults to ``SAME``. See:
https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
with_bias: Whether to add a bias. By default, true.
w_init: Optional weight initialization. By default, truncated normal.
b_init: Optional bias initialization. By default, zeros.
data_format: The data format of the input. Either ``NWC`` or ``NCW``. By
default, ``NWC``.
w_initializer: Optional weight initialization. By default, truncated normal.
b_initializer: Optional bias initialization. By default, zeros.
mask: Optional mask of the weights.
name: The name of the module.
"""
Expand Down Expand Up @@ -648,6 +641,7 @@ def __init__(
"""Initializes the module.
Args:
in_channels: Number of input channels.
out_channels: Number of output channels.
kernel_size: The shape of the kernel. Either an integer or a sequence of
length 2.
Expand Down Expand Up @@ -704,22 +698,17 @@ def __init__(
"""Initializes the module.
Args:
output_channels: Number of output channels.
kernel_shape: The shape of the kernel. Either an integer or a sequence of
in_channels: Number of input channels.
out_channels: Number of output channels.
kernel_size: The shape of the kernel. Either an integer or a sequence of
length 3.
stride: Optional stride for the kernel. Either an integer or a sequence of
length 3. Defaults to 1.
output_shape: Output shape of the spatial dimensions of a transpose
convolution. Can be either an integer or an iterable of integers. If a
`None` value is given, a default shape is automatically calculated.
padding: Optional padding algorithm. Either ``VALID`` or ``SAME``.
Defaults to ``SAME``. See:
https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
with_bias: Whether to add a bias. By default, true.
w_init: Optional weight initialization. By default, truncated normal.
b_init: Optional bias initialization. By default, zeros.
data_format: The data format of the input. Either ``NDHWC`` or ``NCDHW``.
By default, ``NDHWC``.
w_initializer: Optional weight initialization. By default, truncated normal.
b_initializer: Optional bias initialization. By default, zeros.
mask: Optional mask of the weights.
name: The name of the module.
"""
Expand Down
2 changes: 1 addition & 1 deletion brainpy/_src/initialize/random_inits.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def __call__(self, shape, dtype=None):
variance = (self.scale / denominator).astype(dtype)
if self.distribution == "truncated_normal":
stddev = (jnp.sqrt(variance) / .87962566103423978).astype(dtype)
res = self.rng.truncated_normal(-2, 2, shape, dtype) * stddev
res = self.rng.truncated_normal(-2, 2, shape).astype(dtype) * stddev
elif self.distribution == "normal":
res = self.rng.randn(*shape) * jnp.sqrt(variance).astype(dtype)
elif self.distribution == "uniform":
Expand Down

0 comments on commit 256cb27

Please sign in to comment.