-
Notifications
You must be signed in to change notification settings - Fork 2.6k
/
random.py
250 lines (211 loc) · 9.07 KB
/
random.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for pseudo-random number generation.
The :mod:`jax.random` package provides a number of routines for deterministic
generation of sequences of pseudorandom numbers.

Basic usage
-----------

>>> seed = 1701
>>> num_steps = 100
>>> key = jax.random.PRNGKey(seed)
>>> for i in range(num_steps):
...   key, subkey = jax.random.split(key)
...   params = compiled_update(subkey, params, next(batches))  # doctest: +SKIP

PRNG Keys
---------

Unlike the *stateful* pseudorandom number generators (PRNGs) that users of NumPy and
SciPy may be accustomed to, JAX random functions all require an explicit PRNG state to
be passed as a first argument.

The random state is described by two unsigned 32-bit integers that we call a **key**,
usually generated by the :py:func:`jax.random.PRNGKey` function::

    >>> from jax import random
    >>> key = random.PRNGKey(0)
    >>> key
    Array([0, 0], dtype=uint32)

This key can then be used in any of JAX's random number generation routines::

    >>> random.uniform(key)
    Array(0.41845703, dtype=float32)

Note that using a key does not modify it, so reusing the same key will lead to the same result::

    >>> random.uniform(key)
    Array(0.41845703, dtype=float32)

If you need a new random number, you can use :func:`jax.random.split` to generate new subkeys::

    >>> key, subkey = random.split(key)
    >>> random.uniform(subkey)
    Array(0.10536897, dtype=float32)

Advanced
--------

Design and Context
==================

**TLDR**: JAX PRNG = `Threefry counter PRNG <http://www.thesalmons.org/john/random123/papers/random123sc11.pdf>`_
+ a functional array-oriented `splitting model <https://dl.acm.org/citation.cfm?id=2503784>`_

See `docs/jep/263-prng.md <https://github.com/google/jax/blob/main/docs/jep/263-prng.md>`_
for more details.

To summarize, among other requirements, the JAX PRNG aims to:

1. ensure reproducibility,
2. parallelize well, both in terms of vectorization (generating array values)
   and multi-replica, multi-core computation. In particular it should not use
   sequencing constraints between random function calls.

Advanced RNG configuration
==========================

JAX provides several PRNG implementations (controlled by the
``jax_default_prng_impl`` flag).

- **default**
  `A counter-based PRNG built around the Threefry hash function <http://www.thesalmons.org/john/random123/papers/random123sc11.pdf>`_.
- *experimental* A PRNG that thinly wraps the XLA Random Bit Generator (RBG) algorithm. See
  `TF doc <https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator>`_.

  - "rbg" uses ThreeFry for splitting, and XLA RBG for data generation.
  - "unsafe_rbg" exists only for demonstration purposes, using RBG both for
    splitting (using an untested, made-up algorithm) and generating.

  The random streams generated by these experimental implementations haven't
  been subject to any empirical randomness testing (e.g. Big Crush). The
  random bits generated may change between JAX versions.

The possible reasons not to use the default RNG are:

1. it may be slow to compile (specifically for Google Cloud TPUs)
2. it's slower to execute on TPUs
3. it doesn't support efficient automatic sharding / partitioning

Here is a short summary:

.. table::
   :widths: auto

   =================================  ========  =========  ===  ==========  =====  ============
   Property                           Threefry  Threefry*  rbg  unsafe_rbg  rbg**  unsafe_rbg**
   =================================  ========  =========  ===  ==========  =====  ============
   Fastest on TPU                                          ✅   ✅          ✅     ✅
   efficiently shardable (w/ pjit)              ✅                          ✅     ✅
   identical across shardings         ✅        ✅         ✅   ✅
   identical across CPU/GPU/TPU       ✅        ✅
   identical across JAX/XLA versions  ✅        ✅
   =================================  ========  =========  ===  ==========  =====  ============

(*): with ``jax_threefry_partitionable=1`` set

(**): with ``XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1`` set

The difference between "rbg" and "unsafe_rbg" is that while "rbg" uses a less
robust/studied hash function for random value generation (but not for
``jax.random.split`` or ``jax.random.fold_in``), "unsafe_rbg" additionally uses less
robust hash functions for ``jax.random.split`` and ``jax.random.fold_in``. It is
therefore less safe in the sense that the quality of random streams it generates from
different keys is less well understood.

For more about ``jax_threefry_partitionable``, see
https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#generating-random-numbers
"""
# Note: import <name> as <name> is required for names to be exported.
# See PEP 484 & https://github.com/google/jax/issues/7570
from jax._src.random import (
PRNGKey as PRNGKey,
ball as ball,
bernoulli as bernoulli,
beta as beta,
bits as bits,
categorical as categorical,
cauchy as cauchy,
chisquare as chisquare,
choice as choice,
default_prng_impl as _deprecated_default_prng_impl,
dirichlet as dirichlet,
double_sided_maxwell as double_sided_maxwell,
exponential as exponential,
f as f,
fold_in as fold_in,
gamma as gamma,
generalized_normal as generalized_normal,
geometric as geometric,
gumbel as gumbel,
key as key,
key_data as key_data,
key_impl as key_impl,
laplace as laplace,
logistic as logistic,
loggamma as loggamma,
lognormal as lognormal,
maxwell as maxwell,
multivariate_normal as multivariate_normal,
normal as normal,
orthogonal as orthogonal,
pareto as pareto,
permutation as permutation,
poisson as poisson,
rademacher as rademacher,
randint as randint,
random_gamma_p as random_gamma_p,
rayleigh as rayleigh,
rbg_key as _deprecated_rbg_key,
shuffle as shuffle,
split as split,
t as t,
threefry2x32_key as _deprecated_threefry2x32_key,
triangular as triangular,
truncated_normal as truncated_normal,
uniform as uniform,
unsafe_rbg_key as _deprecated_unsafe_rbg_key,
wald as wald,
weibull_min as weibull_min,
wrap_key_data as wrap_key_data,
)
from jax._src.prng import (
threefry_2x32 as _deprecated_threefry_2x32,
threefry2x32_p as _deprecated_threefry2x32_p,
)
# Deprecations
from jax._src.prng import PRNGKeyArray as _PRNGKeyArray
# Table of deprecated ``jax.random`` attributes.
#
# Each entry maps an attribute name to a ``(message, replacement_object)``
# pair. The table is passed to ``deprecation_getattr`` near the end of this
# module. That call installs a module-level ``__getattr__`` that serves these
# names on access, presumably emitting ``message`` as a deprecation warning
# (see ``jax._src.deprecations`` to confirm the exact warning behavior).
_deprecations = {
  # Added September 13, 2023:
  "PRNGKeyArray": (
    "jax.random.PRNGKeyArray is deprecated. Use jax.Array for annotations, and "
    "jax.dtypes.issubdtype(arr.dtype, jax.dtypes.prng_key) for runtime detection of "
    "typed prng keys.", _PRNGKeyArray
  ),
  "KeyArray": (
    "jax.random.KeyArray is deprecated. Use jax.Array for annotations, and "
    "jax.dtypes.issubdtype(arr.dtype, jax.dtypes.prng_key) for runtime detection of "
    "typed prng keys.", _PRNGKeyArray
  ),
  # Added September 21, 2023
  # The suggested replacements spell out ``impl=``. ``jax.random.PRNGKey``
  # takes ``impl`` as a keyword-only argument, so a positional second argument
  # would raise a TypeError for users following this advice.
  "threefry2x32_key": (
    "jax.random.threefry2x32_key(seed) is deprecated. "
    "Use jax.random.PRNGKey(seed, impl='threefry2x32')",
    _deprecated_threefry2x32_key),
  "rbg_key": (
    "jax.random.rbg_key(seed) is deprecated. "
    "Use jax.random.PRNGKey(seed, impl='rbg')",
    _deprecated_rbg_key),
  "unsafe_rbg_key": (
    "jax.random.unsafe_rbg_key(seed) is deprecated. "
    "Use jax.random.PRNGKey(seed, impl='unsafe_rbg')",
    _deprecated_unsafe_rbg_key),
  # Added October 18, 2023
  "threefry_2x32": (  # Note: this has been raising a FutureWarning since 2021
    "jax.random.threefry_2x32 is deprecated. Use jax.extend.random.threefry_2x32.",
    _deprecated_threefry_2x32,
  ),
  "threefry2x32_p": (
    "jax.random.threefry2x32_p is deprecated. Use jax.extend.random.threefry2x32_p.",
    _deprecated_threefry2x32_p,
  ),
  # Added October 19, 2023
  "default_prng_impl": (
    "jax.random.default_prng_impl is deprecated. Typical uses can be replaced by "
    "jax.random.key_impl(key), jax.eval_shape(jax.random.key, 0).dtype, or similar.",
    _deprecated_default_prng_impl,
  ),
}
import typing
# The deprecated names listed in ``_deprecations`` get two separate bindings:
#  - For static type checkers (``typing.TYPE_CHECKING`` is True), they are
#    plain module attributes, so references to them still type-check.
#  - At runtime (``typing.TYPE_CHECKING`` is False), they are NOT bound as
#    module attributes. A module-level ``__getattr__`` (PEP 562) built from
#    ``_deprecations`` resolves them on attribute access instead.
if typing.TYPE_CHECKING:
  # Checkers see the deprecated key-array aliases as ``Any``. At runtime they
  # resolve through ``_deprecations`` to ``_PRNGKeyArray``.
  PRNGKeyArray = typing.Any
  KeyArray = typing.Any
  default_prng_impl = _deprecated_default_prng_impl
  threefry_2x32 = _deprecated_threefry_2x32
  threefry2x32_p = _deprecated_threefry2x32_p
  threefry2x32_key = _deprecated_threefry2x32_key
  rbg_key = _deprecated_rbg_key
  unsafe_rbg_key = _deprecated_unsafe_rbg_key
else:
  from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
  # PEP 562: Python calls a module's ``__getattr__`` only for names that
  # ordinary module lookup does not find. That matches the deprecated names,
  # because this branch never binds them.
  __getattr__ = _deprecation_getattr(__name__, _deprecations)
  # The helper is only needed to build ``__getattr__``; drop the name so it
  # does not linger as an attribute of this module.
  del _deprecation_getattr
# Likewise drop ``typing`` so it is not left behind as a module attribute.
del typing