Skip to content

Commit

Permalink
fix bugs in truncated_normal (#585)
Browse files Browse the repository at this point in the history
* fix bugs in `truncated_normal`

* Update random.py

use `lax.nextafter` to get small values.

* Revert "Update random.py"

This reverts commit a59be4b.

* Update random.py

use `lax.nextafter` for `minval` and `maxval`
  • Loading branch information
charlielam0615 committed Jan 4, 2024
1 parent 43b933e commit 5871c9c
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion brainpy/_src/math/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,7 +828,9 @@ def truncated_normal(self, lower, upper, size=None, loc=0., scale=1., dtype=floa
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
key = self.split_key() if key is None else _formalize_key(key)
out = jr.uniform(key, size, dtype, minval=2 * l - 1, maxval=2 * u - 1)
out = jr.uniform(key, size, dtype,
minval=lax.nextafter(2 * l - 1, np.array(np.inf, dtype=dtype)),
maxval=lax.nextafter(2 * u - 1, np.array(-np.inf, dtype=dtype)))

# Use inverse cdf transform for normal distribution to get truncated
# standard normal
Expand Down

0 comments on commit 5871c9c

Please sign in to comment.