Skip to content

Commit

Permalink
fix bugs in truncated_normal; add TruncatedNormal init. (#575)
Browse files Browse the repository at this point in the history
* fix bugs in  truncated_normal; add TruncatedNormal init.

* fix line delimiter bug.

* Add TruncatedNormal initializer to initialize.rst
  • Loading branch information
charlielam0615 committed Jan 2, 2024
1 parent 586cb0c commit d0988a0
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 0 deletions.
45 changes: 45 additions & 0 deletions brainpy/_src/initialize/random_inits.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

__all__ = [
'Normal',
'TruncatedNormal',
'Uniform',
'VarianceScaling',
'KaimingUniform',
Expand Down Expand Up @@ -122,6 +123,50 @@ def __repr__(self):
return f'{self.__class__.__name__}(scale={self.scale}, rng={self.rng})'


class TruncatedNormal(_InterLayerInitializer):
"""Initialize weights with truncated normal distribution.
Parameters
----------
loc : float, ndarray
Mean ("centre") of the distribution before truncating. Note that
the mean of the truncated distribution will not be exactly equal
to ``loc``.
scale : float
The standard deviation of the normal distribution before truncating.
lower : float, ndarray
A float or array of floats representing the lower bound for
truncation. Must be broadcast-compatible with ``upper``.
upper : float, ndarray
A float or array of floats representing the upper bound for
truncation. Must be broadcast-compatible with ``lower``.
"""

def __init__(self, loc=0., scale=1., lower=None, upper=None, seed=None):
super(TruncatedNormal, self).__init__()
assert scale > 0, '`scale` must be positive.'
self.scale = scale
self.loc = loc
self.lower = lower
self.upper = upper
self.rng = bm.random.default_rng(seed, clone=False)

def __call__(self, shape, dtype=None):
shape = _format_shape(shape)
weights = self.rng.truncated_normal(
size=shape,
scale=self.scale,
lower=self.lower,
upper=self.upper,
loc=self.loc
)
return bm.asarray(weights, dtype=dtype)

def __repr__(self):
return f'{self.__class__.__name__}(loc={self.loc}, scale={self.scale}, lower={self.lower}, upper={self.upper}, rng={self.rng})'


class Gamma(_InterLayerInitializer):
"""Initialize weights with Gamma distribution.
Expand Down
4 changes: 4 additions & 0 deletions brainpy/_src/math/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -2858,6 +2858,10 @@ def truncated_normal(lower, upper, size=None, loc=0., scale=1., dtype=float, key
upper : float, ndarray
A float or array of floats representing the upper bound for
truncation. Must be broadcast-compatible with ``lower``.
loc : float, ndarray
Mean ("centre") of the distribution before truncating. Note that
the mean of the truncated distribution will not be exactly equal
to ``loc``.
size : optional, list of int, tuple of int
A tuple of nonnegative integers specifying the result
shape. Must be broadcast-compatible with ``lower`` and ``upper``. The
Expand Down
1 change: 1 addition & 0 deletions docs/apis/initialize.rst
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ Random Initializers

Normal
Uniform
TruncatedNormal
VarianceScaling
KaimingUniform
KaimingNormal
Expand Down

0 comments on commit d0988a0

Please sign in to comment.