-
Notifications
You must be signed in to change notification settings - Fork 2.8k
/
random.py
2643 lines (2197 loc) · 98 KB
/
random.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections.abc import Sequence
from functools import partial
import math
from operator import index
import typing
from typing import Union
import warnings
import numpy as np
import jax.numpy as jnp
from jax import lax
from jax.numpy.linalg import cholesky, svd, eigh
from jax._src import config
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import prng
from jax._src import xla_bridge
from jax._src.api import jit, vmap
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.lax import lax as lax_internal
from jax._src.numpy.lax_numpy import _convert_and_clip_integer
from jax._src.numpy.util import _arraylike, check_arraylike, promote_dtypes_inexact
from jax._src.typing import Array, ArrayLike, DTypeLike
from jax._src.util import canonicalize_axis
# Array-valued parameter aliases. These are loose placeholders: both currently
# accept anything array-like.
RealArray = ArrayLike
IntegerArray = ArrayLike
Shape = Sequence[int]

# TODO: Import or define these to match
# https://github.com/numpy/numpy/blob/main/numpy/typing/_dtype_like.py.
DTypeLikeInt = DTypeLike
DTypeLikeUInt = DTypeLike
DTypeLikeFloat = DTypeLike

# PRNG-related aliases re-exported from the internal ``prng`` module.
# UINT_DTYPES is indexed by bit width (e.g. ``UINT_DTYPES[nbits]`` below).
PRNGImpl = prng.PRNGImpl
UINT_DTYPES = prng.UINT_DTYPES
KeyArray = Array
KeyArrayLike = ArrayLike

### utilities
_lax_const = lax_internal._const
def _isnan(x: ArrayLike) -> Array:
  """Elementwise NaN test: a value that compares unequal to itself is NaN."""
  not_self_equal = lax.ne(x, x)
  return not_self_equal
def _check_prng_key(name: str, key: KeyArrayLike, *,
                    allow_batched: bool = False) -> tuple[KeyArray, bool]:
  """Validate ``key`` and return it as a typed PRNG key array.

  Args:
    name: name of the calling API function; used only in error messages.
    key: either a typed key array (an ``Array`` whose dtype is a subdtype of
      ``dtypes.prng_key``) or raw array-like key data, which is wrapped using
      the current default PRNG implementation.
    allow_batched: if False, reject key arrays with ``ndim > 0``.

  Returns:
    A pair ``(wrapped_key, wrapped)``. ``wrapped_key`` is the typed key
    array. ``wrapped`` is True iff ``key`` was raw data that this function
    wrapped; callers such as ``fold_in`` and ``split`` pass it to
    ``_return_prng_keys`` to decide whether to unwrap their outputs.

  Raises:
    ValueError: if raw key data is passed while ``jax_legacy_prng_key`` is
      ``'error'``, or if the key is batched and ``allow_batched`` is False.
    TypeError: if ``key`` is neither a typed key array nor array-like.
  """
  if isinstance(key, Array) and dtypes.issubdtype(key.dtype, dtypes.prng_key):
    # Already a typed key: use as-is.
    wrapped_key = key
    wrapped = False
  elif _arraylike(key):
    # Call random_wrap here to surface errors for invalid keys.
    wrapped_key = prng.random_wrap(key, impl=default_prng_impl())
    wrapped = True
    # The jax_legacy_prng_key setting decides how raw keys are treated:
    # 'error' rejects them, 'warn' warns, and otherwise a FutureWarning is
    # issued only when custom PRNG mode is enabled.
    if config.legacy_prng_key.value == 'error':
      raise ValueError(
        'Legacy uint32 key array passed as key to jax.random function. '
        'Please create keys using jax.random.key(). If use of a raw key array '
        'was intended, set jax_legacy_prng_key="allow".')
    elif config.legacy_prng_key.value == 'warn':
      # stacklevel=2 attributes the warning to the frame that called this
      # helper rather than to this function itself.
      warnings.warn(
        'Legacy uint32 key array passed as key to jax.random function. '
        'Please create keys using jax.random.key(). If use of a raw key array '
        'was intended, set jax_legacy_prng_key="allow".', stacklevel=2)
    elif config.enable_custom_prng.value:
      # TODO(jakevdp): possibly remove this warning condition.
      warnings.warn(
        'Raw arrays as random keys to jax.random functions are deprecated. '
        'Assuming valid threefry2x32 key for now.',
        FutureWarning)
  else:
    raise TypeError(f'unexpected PRNG key type {type(key)}')
  # The error reports np.shape of the argument as passed; for raw key data
  # this presumably includes the trailing key-data dimension(s) -- confirm.
  if (not allow_batched) and wrapped_key.ndim:
    raise ValueError(f"{name} accepts a single key, but was given a key array of"
                     f" shape {np.shape(key)} != (). Use jax.vmap for batching.")
  return wrapped_key, wrapped
def _return_prng_keys(was_wrapped, key):
  """Convert an internally produced typed key back to the caller's form.

  If the caller originally supplied raw key data (``was_wrapped``) and custom
  PRNG mode is disabled, the typed key is unwrapped to its raw data.
  Otherwise, the typed key is returned unchanged.
  """
  # TODO(frostig): remove once we always enable_custom_prng
  assert jnp.issubdtype(key.dtype, dtypes.prng_key)
  needs_unwrap = was_wrapped and not config.enable_custom_prng.value
  return prng.random_unwrap(key) if needs_unwrap else key
def _random_bits(key: KeyArray, bit_width: int, shape: Shape) -> Array:
  """Draw raw unsigned random bits of width ``bit_width`` from a typed key."""
  assert jnp.issubdtype(key.dtype, dtypes.prng_key)
  raw = prng.random_bits(key, bit_width=bit_width, shape=shape)
  return raw
# TODO(frostig,vanderplas): remove from public API altogether, or at
# least change to return after asserting presence in `prng.prngs`
def default_prng_impl():
  """Return the PRNG implementation selected by configuration.

  ``config.jax_default_prng_impl`` holds the implementation's name; this
  function looks that name up in the registry of known implementations.
  """
  registry = prng.prngs
  chosen = config.default_prng_impl.value
  assert chosen in registry, chosen
  return registry[chosen]
### key operations
# Wrapper around prng.PRNGImpl meant to hide its attributes from the
# public API.
# TODO(frostig,vanderplas): consider hiding all the attributes of
# PRNGImpl and directly returning it.
class PRNGSpec:
  """Specifies a PRNG key implementation.

  A thin handle around a ``PRNGImpl``. Equality and hashing are delegated to
  the wrapped implementation, so two specs wrapping equal implementations
  compare equal.
  """
  __slots__ = ['_impl']
  _impl: PRNGImpl

  def __init__(self, impl):
    self._impl = impl

  def __eq__(self, other) -> bool:
    if not isinstance(other, PRNGSpec):
      return False
    return self._impl == other._impl

  def __hash__(self) -> int:
    return hash(self._impl)

  def __repr__(self) -> str:
    impl_name = self._impl.name
    return f"PRNGSpec({impl_name!r})"

  def __str__(self) -> str:
    return str(self._impl)
# TODO(frostig,vanderplas): remove PRNGImpl from this union when it's
# no longer in the public API because `default_prng_impl` is gone
# Accepted ways of naming a PRNG implementation: a registered implementation
# name (str), a PRNGSpec, or a raw PRNGImpl. resolve_prng_impl converts any
# of these (or None) into a PRNGImpl.
PRNGSpecDesc = Union[str, PRNGSpec, PRNGImpl]
def resolve_prng_impl(impl_spec: PRNGSpecDesc | None) -> PRNGImpl:
  """Turn any accepted PRNG implementation descriptor into a ``PRNGImpl``.

  Args:
    impl_spec: ``None`` (use the configured default), a ``PRNGImpl``, a
      ``PRNGSpec``, or the registered name of an implementation. Types are
      matched exactly, so subclasses are not accepted.

  Returns:
    The corresponding ``PRNGImpl``.

  Raises:
    ValueError: if ``impl_spec`` is a string that names no registered
      implementation.
    TypeError: if ``impl_spec`` is of any other type.
  """
  if impl_spec is None:
    return default_prng_impl()

  spec_type = type(impl_spec)

  # TODO(frostig,vanderplas): drop the PRNGImpl case once default_prng_impl
  # (and thus PRNGImpl) is removed from the public API and PRNGImpl from jex;
  # the input type annotation above can then be narrowed as well.
  if spec_type is PRNGImpl:
    return impl_spec
  if spec_type is PRNGSpec:
    return impl_spec._impl
  if spec_type is not str:
    raise TypeError(f'unrecognized type {spec_type} for specifying PRNG implementation.')

  registry = prng.prngs
  if impl_spec in registry:
    return registry[impl_spec]
  known = ', '.join(f'"{impl_name}"' for impl_name in registry.keys())
  raise ValueError(f'unrecognized PRNG implementation "{impl_spec}". '
                   f'Did you mean one of: {known}?')
def _key(ctor_name: str, seed: int | ArrayLike,
         impl_spec: PRNGSpecDesc | None) -> KeyArray:
  """Shared seed-to-typed-key constructor used by ``key`` and ``PRNGKey``.

  Args:
    ctor_name: public constructor name, used in error messages.
    seed: scalar integer seed.
    impl_spec: PRNG implementation descriptor (see ``resolve_prng_impl``).

  Raises:
    TypeError: if ``seed`` is itself a PRNG key or is not a scalar.
  """
  impl = resolve_prng_impl(impl_spec)
  seed_is_key = (hasattr(seed, 'dtype')
                 and jnp.issubdtype(seed.dtype, dtypes.prng_key))
  if seed_is_key:
    raise TypeError(
        f"{ctor_name} accepts a scalar seed, but was given a PRNG key.")
  if np.ndim(seed):
    raise TypeError(
        f"{ctor_name} accepts a scalar seed, but was given an array of "
        f"shape {np.shape(seed)} != (). Use jax.vmap for batching")
  return prng.random_seed(seed, impl=impl)
def key(seed: int | ArrayLike, *,
        impl: PRNGSpecDesc | None = None) -> KeyArray:
  """Create a typed pseudo-random number generator (PRNG) key from a seed.

  The result is a scalar key array whose dtype records the PRNG
  implementation. This is the one named by ``impl`` if given; otherwise it is
  the one selected by the ``jax_default_prng_impl`` config flag at the time of
  the call.

  Args:
    seed: a 64- or 32-bit integer used as the value of the key.
    impl: optional string specifying the PRNG implementation (e.g.
      ``'threefry2x32'``)

  Returns:
    A scalar PRNG key array, consumable by random functions as well as
    ``split`` and ``fold_in``.
  """
  ctor_name = 'key'
  return _key(ctor_name, seed, impl_spec=impl)
def PRNGKey(seed: int | ArrayLike, *,
            impl: PRNGSpecDesc | None = None) -> KeyArray:
  """Create a legacy (raw ``uint32``) PRNG key from an integer seed.

  Unlike :func:`jax.random.key`, the result is a plain ``uint32`` array that
  does not record which PRNG implementation produced it. For background, see
  the note in the `PRNG keys
  <https://jax.readthedocs.io/en/latest/jax.random.html#prng-keys>`_
  section. Prefer :func:`jax.random.key` where possible.

  The key is built with the implementation named by ``impl``, or otherwise
  with the one selected by the ``jax_default_prng_impl`` config flag. Callers
  must make sure the same implementation is the default whenever this key is
  later passed to other functions (such as ``jax.random.split`` and
  ``jax.random.normal``).

  Args:
    seed: a 64- or 32-bit integer used as the value of the key.
    impl: optional string specifying the PRNG implementation (e.g.
      ``'threefry2x32'``)

  Returns:
    A PRNG key, consumable by random functions as well as ``split``
    and ``fold_in``.
  """
  typed = _key('PRNGKey', seed, impl_spec=impl)
  # Treat the result as if the caller had passed raw data, so it is unwrapped
  # unless custom PRNG mode is on.
  return _return_prng_keys(True, typed)
def fold_in(key: KeyArrayLike, data: IntegerArray) -> KeyArray:
  """Folds in data to a PRNG key to form a new PRNG key.

  Args:
    key: a PRNG key (from ``key``, ``split``, ``fold_in``).
    data: a 32-bit integer representing data to be folded into the key.

  Returns:
    A new PRNG key that is a deterministic function of the inputs and is
    statistically safe for producing a stream of new pseudo-random values.
    If ``key`` was raw key data, the result is returned in the same raw form
    unless custom PRNG mode is enabled.

  Raises:
    TypeError: if ``data`` is not a scalar.
  """
  key, wrapped = _check_prng_key("fold_in", key)
  if np.ndim(data):
    raise TypeError("fold_in accepts a scalar, but was given an array of "
                    f"shape {np.shape(data)} != (). Use jax.vmap for batching.")
  # The data is converted to uint32 before being folded into the key.
  key_out = prng.random_fold_in(key, jnp.uint32(data))
  return _return_prng_keys(wrapped, key_out)
def _split(key: KeyArray, num: int | tuple[int, ...] = 2) -> KeyArray:
  """Split a single typed key into a key array of shape ``num``.

  This is the internal counterpart of ``split`` used inside samplers. It
  expects an already-typed key and returns typed keys.

  Raises:
    TypeError: if ``key`` is batched (``key.ndim > 0``).
  """
  # TODO(frostig): remove and use split(); we no longer need to wait
  # to always enable_custom_prng
  assert jnp.issubdtype(key.dtype, dtypes.prng_key)
  if key.ndim:
    raise TypeError("split accepts a single key, but was given a key array of "
                    f"shape {key.shape} != (). Use jax.vmap for batching.")
  if isinstance(num, Sequence):
    out_shape = tuple(num)
  else:
    out_shape = (num,)
  return prng.random_split(key, shape=out_shape)
def split(key: KeyArrayLike, num: int | tuple[int, ...] = 2) -> KeyArray:
  """Split a PRNG key into ``num`` new keys, stacked along a new leading axis.

  Args:
    key: a PRNG key (from ``key``, ``split``, ``fold_in``).
    num: optional, a positive integer (or tuple of integers) giving the
      number (or shape) of keys to produce. Defaults to 2.

  Returns:
    An array-like object of ``num`` new PRNG keys, in the same typed/raw form
    as ``key`` (see ``_return_prng_keys``).
  """
  typed_key, wrapped = _check_prng_key("split", key)
  new_keys = _split(typed_key, num)
  return _return_prng_keys(wrapped, new_keys)
def _key_impl(keys: KeyArray) -> PRNGImpl:
  """Return the PRNG implementation recorded in a typed key array's dtype."""
  assert jnp.issubdtype(keys.dtype, dtypes.prng_key)
  return typing.cast(prng.KeyTy, keys.dtype)._impl
def key_impl(keys: KeyArrayLike) -> PRNGSpec:
  """Describe the PRNG implementation of ``keys`` as a ``PRNGSpec``.

  Batched key arrays are accepted.
  """
  checked_keys, _unused_wrapped = _check_prng_key(
      "key_impl", keys, allow_batched=True)
  impl = _key_impl(checked_keys)
  return PRNGSpec(impl)
def _key_data(keys: KeyArray) -> Array:
  """Unwrap a typed key array into its underlying raw key-data array."""
  assert jnp.issubdtype(keys.dtype, dtypes.prng_key)
  raw = prng.random_unwrap(keys)
  return raw
def key_data(keys: KeyArrayLike) -> Array:
  """Recover the bits of key data underlying a PRNG key array.

  Batched key arrays are accepted.
  """
  typed_keys, _ = _check_prng_key("key_data", keys, allow_batched=True)
  return _key_data(typed_keys)
def wrap_key_data(key_bits_array: Array, *,
                  impl: PRNGSpecDesc | None = None):
  """Wrap an array of raw key-data bits into a typed PRNG key array.

  Args:
    key_bits_array: a ``uint32`` array whose trailing shape matches the key
      shape of the PRNG implementation selected by ``impl``.
    impl: optional PRNG implementation descriptor, as in ``random.key``.

  Returns:
    A PRNG key array whose dtype is a subdtype of ``jax.dtypes.prng_key``
    corresponding to ``impl``. Its shape is the leading part of
    ``key_bits_array.shape``, before the key-data dimensions.
  """
  return prng.random_wrap(key_bits_array, impl=resolve_prng_impl(impl))
### random samplers
def _check_shape(name: str, shape: Shape, *param_shapes) -> None:
  """Check that ``shape`` together with ``param_shapes`` broadcasts to ``shape``.

  Nothing is checked when no ``param_shapes`` are given.

  Raises:
    ValueError: if broadcasting ``shape`` with ``param_shapes`` yields
      something other than ``shape``.
  """
  if not param_shapes:
    return
  broadcast = lax.broadcast_shapes(shape, *param_shapes)  # type: ignore
  if broadcast == shape:
    return
  msg = ("{} parameter shapes must be broadcast-compatible with shape "
         "argument, and the result of broadcasting the shapes must equal "
         "the shape argument, but got result {} for shape argument {}.")
  raise ValueError(msg.format(name, broadcast, shape))
def bits(key: KeyArrayLike,
         shape: Shape = (),
         dtype: DTypeLikeUInt | None = None) -> Array:
  """Sample uniformly random bits, returned as unsigned integers.

  Args:
    key: a PRNG key used as the random key.
    shape: optional, a tuple of nonnegative integers giving the result shape.
      Default ``()``.
    dtype: optional, an unsigned integer dtype for the result (default
      ``uint64`` if ``jax_enable_x64`` is true, otherwise ``uint32``).

  Returns:
    A random array with the specified shape and dtype.

  Raises:
    ValueError: if ``dtype`` is not an unsigned integer dtype.
  """
  typed_key, _ = _check_prng_key("bits", key)
  if dtype is not None:
    dtypes.check_user_dtype_supported(dtype)
    requested = dtype
  else:
    requested = dtypes.canonicalize_dtype(jnp.uint)
  if not dtypes.issubdtype(requested, np.unsignedinteger):
    raise ValueError("dtype argument to `bits` must be an unsigned int dtype, "
                     f"got {requested}")
  out_dtype = dtypes.canonicalize_dtype(requested)
  out_shape = core.canonicalize_shape(shape)
  # One random bit per bit of the output dtype.
  return _random_bits(typed_key, out_dtype.itemsize * 8, out_shape)
def uniform(key: KeyArrayLike,
            shape: Shape = (),
            dtype: DTypeLikeFloat = float,
            minval: RealArray = 0.,
            maxval: RealArray = 1.) -> Array:
  """Sample uniform random values in ``[minval, maxval)``.

  Args:
    key: a PRNG key used as the random key.
    shape: optional, a tuple of nonnegative integers giving the result shape.
      Default ().
    dtype: optional, a float dtype for the result (default float64 if
      jax_enable_x64 is true, otherwise float32).
    minval: optional, inclusive lower bound, broadcast-compatible with
      ``shape`` (default 0).
    maxval: optional, exclusive upper bound, broadcast-compatible with
      ``shape`` (default 1).

  Returns:
    A random array with the specified shape and dtype.

  Raises:
    ValueError: if ``dtype`` is not a floating dtype.
  """
  typed_key, _ = _check_prng_key("uniform", key)
  dtypes.check_user_dtype_supported(dtype)
  out_shape = core.canonicalize_shape(shape)
  if not dtypes.issubdtype(dtype, np.floating):
    raise ValueError(f"dtype argument to `uniform` must be a float dtype, "
                     f"got {dtype}")
  return _uniform(typed_key, out_shape, dtypes.canonicalize_dtype(dtype),
                  minval, maxval)
@partial(jit, static_argnums=(1, 2))
def _uniform(key, shape, dtype, minval, maxval) -> Array:
  """Jitted core of ``uniform``; ``shape`` and ``dtype`` are static.

  Returns values of ``dtype`` in ``[minval, maxval)``, clamped from below at
  ``minval``.

  Raises:
    TypeError: if ``dtype`` is not a floating dtype of 8, 16, 32 or 64 bits.
  """
  _check_shape("uniform", shape)
  if not jnp.issubdtype(dtype, np.floating):
    raise TypeError("uniform only accepts floating point dtypes.")
  minval = lax.convert_element_type(minval, dtype)
  maxval = lax.convert_element_type(maxval, dtype)
  minval = lax.broadcast_to_rank(minval, len(shape))
  maxval = lax.broadcast_to_rank(maxval, len(shape))
  finfo = jnp.finfo(dtype)
  nbits, nmant = finfo.bits, finfo.nmant
  if nbits not in (8, 16, 32, 64):
    raise TypeError(
        f"uniform only accepts 8-, 16-, 32-, or 64-bit dtypes, got {dtype}."
    )
  rng_bits = nbits
  if nmant < 8:
    # Types with fewer than 8 mantissa bits draw only 8 random bits. If that
    # is narrower than the float itself, the bits are widened to the matching
    # unsigned dtype below.
    rng_bits = 8
  bits = _random_bits(key, rng_bits, shape)
  uint_dtype = UINT_DTYPES[nbits]
  if rng_bits != nbits:
    bits = lax.convert_element_type(bits, uint_dtype)
  # The strategy here is to randomize only the mantissa bits with an exponent of
  # 1 (after applying the bias), then shift and scale to the desired range. The
  # bit-level transformation we use relies on Numpy and XLA having bit-for-bit
  # equivalent float representations, which might not be true on all platforms.
  float_bits = lax.bitwise_or(
      lax.shift_right_logical(bits, np.array(rng_bits - nmant, uint_dtype)),
      np.array(1.0, dtype).view(uint_dtype),
  )
  # float_bits encodes a value in [1, 2); subtracting 1 maps it to [0, 1).
  floats = lax.bitcast_convert_type(float_bits, dtype) - np.array(1., dtype)
  # Clamp at minval so rounding in the affine map cannot fall below it.
  return lax.max(
      minval,
      lax.reshape(floats * (maxval - minval) + minval, shape))
def randint(key: KeyArrayLike,
            shape: Shape,
            minval: IntegerArray,
            maxval: IntegerArray,
            dtype: DTypeLikeInt = int) -> Array:
  """Sample uniform random integers in ``[minval, maxval)``.

  Args:
    key: a PRNG key used as the random key.
    shape: a tuple of nonnegative integers giving the result shape.
    minval: int or array of ints broadcast-compatible with ``shape``; the
      inclusive lower bound.
    maxval: int or array of ints broadcast-compatible with ``shape``; the
      exclusive upper bound.
    dtype: optional, an int dtype for the result (default int64 if
      jax_enable_x64 is true, otherwise int32).

  Returns:
    A random array with the specified shape and dtype.
  """
  typed_key, _ = _check_prng_key("randint", key)
  dtypes.check_user_dtype_supported(dtype)
  out_dtype = dtypes.canonicalize_dtype(dtype)
  out_shape = core.canonicalize_shape(shape)
  return _randint(typed_key, out_shape, minval, maxval, out_dtype)
@partial(jit, static_argnums=(1, 4))
def _randint(key, shape, minval, maxval, dtype) -> Array:
  """Jitted core of ``randint``; ``shape`` and ``dtype`` are static.

  Returns integers of ``dtype`` in ``[minval, maxval)``. Where
  ``maxval <= minval``, ``minval`` is returned. The result combines two
  independent draws of ``nbits`` random bits each to reduce modulo bias.

  Raises:
    TypeError: if ``dtype`` is not an integer dtype, or is not 8, 16, 32 or
      64 bits wide.
  """
  _check_shape("randint", shape, np.shape(minval), np.shape(maxval))
  if not jnp.issubdtype(dtype, np.integer):
    raise TypeError(f"randint only accepts integer dtypes, got {dtype}")
  check_arraylike("randint", minval, maxval)
  minval = jnp.asarray(minval)
  maxval = jnp.asarray(maxval)
  # Non-integer bounds are cast to the default integer type via astype(int).
  if not jnp.issubdtype(minval.dtype, np.integer):
    minval = minval.astype(int)
  if not jnp.issubdtype(maxval.dtype, np.integer):
    maxval = maxval.astype(int)
  # Flag where maxval is greater than the maximum value of dtype
  # in order to handle cases like randint(key, shape, 0, 256, 'uint8')
  maxval_out_of_range = lax.gt(
    maxval, _convert_and_clip_integer(jnp.array(jnp.iinfo(dtype).max, dtype), maxval.dtype))
  minval = _convert_and_clip_integer(minval, dtype)
  maxval = _convert_and_clip_integer(maxval, dtype)
  minval = lax.broadcast_to_rank(minval, len(shape))
  maxval = lax.broadcast_to_rank(maxval, len(shape))
  nbits = jnp.iinfo(dtype).bits
  if nbits not in (8, 16, 32, 64):
    raise TypeError(f"randint only accepts 8-, 16-, 32-, or 64-bit dtypes, got {dtype}")
  # This algorithm is biased whenever (maxval - minval) is not a power of 2.
  # We generate double the number of random bits required by the dtype so as to
  # reduce that bias.
  k1, k2 = _split(key)
  rbits = lambda key: _random_bits(key, nbits, shape)
  higher_bits, lower_bits = rbits(k1), rbits(k2)
  unsigned_dtype = UINT_DTYPES[nbits]
  # NOTE(review): the subtraction happens in `dtype` and is then converted to
  # the unsigned dtype. Presumably this lets spans wider than dtype's
  # positive range be represented (two's-complement wrap) -- confirm.
  span = lax.convert_element_type(maxval - minval, unsigned_dtype)
  # Ensure that span=1 when maxval <= minval, so minval is always returned;
  # https://github.com/jax-ml/jax/issues/222
  span = lax.select(maxval <= minval, lax.full_like(span, 1), span)
  # When maxval is out of range, the span has to be one larger.
  # If span is already the maximum representable value, this will wrap to zero,
  # causing remainders below to have no effect, which is the correct semantics.
  span = lax.select(
    maxval_out_of_range & (maxval > minval),
    lax.add(span, _lax_const(span, 1)),
    span)
  # To compute a remainder operation on an integer that might have twice as many
  # bits as we can represent in the native unsigned dtype, we compute a
  # multiplier equal to 2**nbits % span. To avoid overflow, we use the identity:
  # (a * b) % N = [(a % N) * (b % N)] % N
  # NOTE(review): the lax.mul products below are still computed in the
  # nbits-wide unsigned dtype and may wrap for large spans. That would change
  # the bias, but the final lax.rem still keeps the offset in [0, span).
  multiplier = lax.rem(_lax_const(span, 2 ** (nbits // 2)), span)
  multiplier = lax.rem(lax.mul(multiplier, multiplier), span)
  # Combine the two draws as (higher_bits * 2**nbits + lower_bits) mod span.
  random_offset = lax.add(lax.mul(lax.rem(higher_bits, span), multiplier),
                          lax.rem(lower_bits, span))
  random_offset = lax.rem(random_offset, span)
  return lax.add(minval, lax.convert_element_type(random_offset, dtype))
def permutation(key: KeyArrayLike,
                x: int | ArrayLike,
                axis: int = 0,
                independent: bool = False) -> Array:
  """Return a random permutation of an array, or of ``arange(x)``.

  Args:
    key: a PRNG key used as the random key.
    x: int or array. An integer ``x`` means "shuffle ``np.arange(x)``"; an
      array is shuffled directly.
    axis: int, optional. Axis along which ``x`` is shuffled. Default 0.
    independent: bool, optional. If True, every vector along ``axis`` is
      shuffled independently. Default False.

  Returns:
    A shuffled version of ``x`` or of the integer range.

  Raises:
    TypeError: if ``x`` is a scalar that is not an integer.
  """
  typed_key, _ = _check_prng_key("permutation", key)
  check_arraylike("permutation", x)
  x_ndim = np.ndim(x)
  axis = canonicalize_axis(axis, x_ndim or 1)

  if not x_ndim:
    # Scalar input: it must be a concrete integer n; shuffle arange(n).
    if not np.issubdtype(lax.dtype(x), np.integer):
      raise TypeError("x must be an integer or at least 1-dimensional")
    n = core.concrete_or_error(int, x, 'argument x of jax.random.permutation()')
    return _shuffle(typed_key, jnp.arange(n), axis)

  if independent or x_ndim == 1:
    return _shuffle(typed_key, x, axis)

  # Otherwise, permute whole slices along `axis` using one shared ordering.
  order = _shuffle(typed_key, jnp.arange(x.shape[axis]), 0)  # type: ignore[union-attr]
  return jnp.take(x, order, axis, unique_indices=True)
@partial(jit, static_argnums=(2,))
def _shuffle(key, x, axis) -> Array:
  """Shuffle ``x`` along ``axis`` by repeatedly sorting on random 32-bit keys.

  On parallel hardware, a few sorts are cheaper than Fisher-Yates. Each round
  sorts ``x`` by freshly drawn 32-bit keys. Keys can collide, but once every
  pair of elements has received distinct keys in some round (equivalently,
  the concatenated per-element key strings are all distinct), the result is
  a perfect sample, provided the sort is stable or any bias it has is not
  value-dependent. Checking uniqueness at runtime would be costly, so a
  static round count from tjablin@'s heuristic is used instead. See
  tensorflow/compiler/tf2xla/random_ops.cc for the original implementation
  and Section 2 of http://people.csail.mit.edu/costis/6896sp11/lec5s.pdf for
  a related analysis (with keys generated one bit at a time).

  Raises:
    NotImplementedError: if ``x`` has a non-constant (symbolic) size.
  """
  if not core.is_constant_dim(x.size):
    raise NotImplementedError(
        "shape polymorphism for `permutation` or `shuffle`"
        f" for arrays of non-constant size: {x.size}")
  exponent = 3  # see tjablin@'s analysis for explanation of this parameter
  uint32max = jnp.iinfo(np.uint32).max
  rounds = int(np.ceil(exponent * np.log(max(1, x.size)) / np.log(uint32max)))
  for _round in range(rounds):
    key, round_key = _split(key)
    sort_keys = _random_bits(round_key, 32, x.shape)
    _, x = lax.sort_key_val(sort_keys, x, axis)
  return x
def choice(key: KeyArrayLike,
           a: int | ArrayLike,
           shape: Shape = (),
           replace: bool = True,
           p: RealArray | None = None,
           axis: int = 0) -> Array:
  """Draw a random sample from a given array (or from ``arange(a)``).

  .. warning::
    If ``p`` has fewer non-zero elements than the number of samples requested
    via ``shape``, and ``replace=False``, the output of this function is
    ill-defined. Please make sure to use appropriate inputs.

  Args:
    key: a PRNG key used as the random key.
    a : array or int. An array is sampled from along ``axis``; an int ``a``
      means sampling from ``arange(a)``.
    shape : tuple of ints, optional. Output shape; e.g. ``(m, n)`` draws
      ``m * n`` samples. Default is (), which returns a single value.
    replace : boolean. Whether to sample with replacement. Default True.
    p : 1-D array-like of probabilities for each entry of ``a``. If omitted,
      all entries are equally likely.
    axis: int, optional. Axis along which the selection is performed.
      The default, 0, selects by row.

  Returns:
    An array of shape ``shape`` containing samples from ``a``.

  Raises:
    TypeError: if ``shape`` is not a sequence.
    ValueError: if ``a`` is empty while samples are requested, if more samples
      than population entries are requested with ``replace=False``, or if
      ``p`` has the wrong shape.
  """
  key, _ = _check_prng_key("choice", key)
  if not isinstance(shape, Sequence):
    raise TypeError("shape argument of jax.random.choice must be a sequence, "
                    f"got {shape}")
  check_arraylike("choice", a)
  arr = jnp.asarray(a)
  scalar_input = arr.ndim == 0
  if scalar_input:
    n_inputs = core.concrete_or_error(int, a, "The error occurred in jax.random.choice()")
  else:
    axis = canonicalize_axis(axis, arr.ndim)
    n_inputs = arr.shape[axis]

  n_draws = math.prod(shape)
  if n_draws == 0:
    return jnp.zeros(shape, dtype=arr.dtype)
  if n_inputs <= 0:
    raise ValueError("a must be greater than 0 unless no samples are taken")
  if not replace and n_draws > n_inputs:
    raise ValueError(
        f"Cannot take a larger sample (size {n_draws}) than "
        f"population (size {n_inputs}) when 'replace=False'")

  if p is None:
    if not replace:
      # Shuffle, then keep the first n_draws entries along `axis`.
      leading = (slice(None),) * axis + (slice(n_draws),)
      population = n_inputs if scalar_input else arr
      result = permutation(key, population, axis)[leading]
    else:
      idx = randint(key, shape, 0, n_inputs)
      result = idx if scalar_input else jnp.take(arr, idx, axis)
  else:
    check_arraylike("choice", p)
    p_arr, = promote_dtypes_inexact(p)
    if p_arr.shape != (n_inputs,):
      raise ValueError(
          "p must be None or a 1D vector with the same size as a.shape[axis]. "
          f"p has shape {p_arr.shape} and a.shape[axis] is {n_inputs}.")
    if not replace:
      # Gumbel top-k trick: https://timvieira.github.io/blog/post/2019/09/16/algorithms-for-sampling-without-replacement/
      perturbed = -gumbel(key, (n_inputs,), dtype=p_arr.dtype) - jnp.log(p_arr)
      idx = jnp.argsort(perturbed)[:n_draws]
    else:
      # Inverse-CDF sampling: scale (1 - u) by the total mass and locate it in
      # the cumulative sum of p.
      p_cuml = jnp.cumsum(p_arr)
      r = p_cuml[-1] * (1 - uniform(key, shape, dtype=p_cuml.dtype))
      idx = jnp.searchsorted(p_cuml, r).astype(int)
    result = idx if scalar_input else jnp.take(arr, idx, axis)

  if scalar_input:
    out_shape = shape
  else:
    out_shape = arr.shape[0:axis] + tuple(shape) + arr.shape[axis+1:]
  return result.reshape(out_shape)
def normal(key: KeyArrayLike,
           shape: Shape = (),
           dtype: DTypeLikeFloat = float) -> Array:
  r"""Sample standard normal random values with the given shape and dtype.

  Values follow the probability density function

  .. math::
     f(x) = \frac{1}{\sqrt{2\pi}}e^{-x^2/2}

  on the domain :math:`-\infty < x < \infty`.

  Args:
    key: a PRNG key used as the random key.
    shape: optional, a tuple of nonnegative integers giving the result shape.
      Default ().
    dtype: optional, a float or complex dtype for the result (default float64
      if jax_enable_x64 is true, otherwise float32).

  Returns:
    A random array with the specified shape and dtype.

  Raises:
    ValueError: if ``dtype`` is neither a float nor a complex dtype.
  """
  typed_key, _ = _check_prng_key("normal", key)
  out_shape = core.canonicalize_shape(shape)
  dtypes.check_user_dtype_supported(dtype)
  if not dtypes.issubdtype(dtype, np.inexact):
    raise ValueError(f"dtype argument to `normal` must be a float or complex dtype, "
                     f"got {dtype}")
  return _normal(typed_key, out_shape, dtypes.canonicalize_dtype(dtype))
@partial(jit, static_argnums=(1, 2))
def _normal(key, shape, dtype) -> Array:
  """Jitted core of ``normal``, handling both real and complex dtypes."""
  if not dtypes.issubdtype(dtype, np.complexfloating):
    return _normal_real(key, shape, dtype)
  # Complex case: independent real and imaginary standard normals, divided
  # by sqrt(2) so the result has unit expected squared magnitude.
  key_re, key_im = _split(key)
  component_dtype = np.array(0, dtype).real.dtype
  re_part = _normal_real(key_re, shape, component_dtype).astype(dtype)
  im_part = _normal_real(key_im, shape, component_dtype).astype(dtype)
  return (re_part + 1j * im_part) / np.array(np.sqrt(2), dtype)
@partial(jit, static_argnums=(1, 2))
def _normal_real(key, shape, dtype) -> Array:
  """Sample real standard normals as ``sqrt(2) * erf_inv(u)``.

  ``u`` is drawn uniformly from ``[lo, 1)``, where ``lo`` is the next
  representable value above -1 in ``dtype``. This keeps ``erf_inv(u)`` away
  from its pole at -1.
  """
  _check_shape("normal", shape)
  upper = np.array(1., dtype)
  lower = np.nextafter(np.array(-1., dtype), np.array(0., dtype), dtype=dtype)
  u = uniform(key, shape, dtype, lower, upper)
  return lax.mul(np.array(np.sqrt(2), dtype), lax.erf_inv(u))
def multivariate_normal(key: KeyArrayLike,
                        mean: RealArray,
                        cov: RealArray,
                        shape: Shape | None = None,
                        dtype: DTypeLikeFloat | None = None,
                        method: str = 'cholesky') -> Array:
  r"""Sample multivariate normal random values with given mean and covariance.

  The values are returned according to the probability density function:

  .. math::
     f(x;\mu, \Sigma) = (2\pi)^{-k/2} \det(\Sigma)^{-1/2}e^{-\frac{1}{2}(x - \mu)^T \Sigma^{-1} (x - \mu)}

  where :math:`k` is the dimension, :math:`\mu` is the mean (given by ``mean``) and
  :math:`\Sigma` is the covariance matrix (given by ``cov``).

  Args:
    key: a PRNG key used as the random key.
    mean: a mean vector of shape ``(..., n)``.
    cov: a positive definite covariance matrix of shape ``(..., n, n)``. The
      batch shape ``...`` must be broadcast-compatible with that of ``mean``.
    shape: optional, a tuple of nonnegative integers specifying the result
      batch shape; that is, the prefix of the result shape excluding the last
      axis. Must be broadcast-compatible with ``mean.shape[:-1]`` and
      ``cov.shape[:-2]``. The default (None) produces a result batch shape by
      broadcasting together the batch shapes of ``mean`` and ``cov``.
    dtype: optional, a float dtype for the returned values. Defaults to the
      inexact dtype that ``mean`` and ``cov`` are promoted to.
    method: optional, a method to compute the factor of ``cov``.
      Must be one of 'svd', 'eigh', and 'cholesky'. Default 'cholesky'. For
      singular covariance matrices, use 'svd' or 'eigh'.

  Returns:
    A random array with the specified dtype and shape given by
    ``shape + mean.shape[-1:]`` if ``shape`` is not None, or else
    ``broadcast_shapes(mean.shape[:-1], cov.shape[:-2]) + mean.shape[-1:]``.

  Raises:
    ValueError: if ``method`` is not one of 'svd', 'eigh', 'cholesky', or if
      ``dtype`` is not a floating dtype.
  """
  key, _ = _check_prng_key("multivariate_normal", key)
  dtypes.check_user_dtype_supported(dtype)
  # Bring mean and cov to a common inexact dtype before validating dtype.
  mean, cov = promote_dtypes_inexact(mean, cov)
  if method not in {'svd', 'eigh', 'cholesky'}:
    raise ValueError("method must be one of {'svd', 'eigh', 'cholesky'}")
  if dtype is None:
    dtype = mean.dtype
  if not dtypes.issubdtype(dtype, np.floating):
    raise ValueError(f"dtype argument to `multivariate_normal` must be a float "
                     f"dtype, got {dtype}")
  if shape is not None:
    shape = core.canonicalize_shape(shape)
  return _multivariate_normal(key, mean, cov, shape, dtype, method)
@partial(jit, static_argnums=(3, 4, 5))
def _multivariate_normal(key, mean, cov, shape, dtype, method) -> Array:
  """Jitted implementation of ``multivariate_normal``.

  Args:
    key: PRNG key (validated by the public wrapper).
    mean: array of shape ``(..., n)``; must have at least one dimension.
    cov: array of shape ``(..., n, n)``; must have at least two dimensions.
    shape: static result batch shape, or None to broadcast the batch shapes
      ``mean.shape[:-1]`` and ``cov.shape[:-2]``.
    dtype: static float dtype of the standard-normal draws and of the result.
    method: static name of the factorization of ``cov``: 'svd', 'eigh', or
      anything else, which is treated as 'cholesky'.

  Returns:
    ``mean + factor @ z``, where ``factor @ factor^T == cov`` and ``z`` is a
    standard-normal draw of shape ``shape + (n,)``, cast to ``dtype``.

  Raises:
    ValueError: if ``mean`` or ``cov`` has too few dimensions, or if the
      trailing dimensions of ``cov`` are not ``(n, n)``. ``_check_shape`` may
      also raise when ``shape`` is incompatible with the batch shapes.
  """
  if np.ndim(mean) < 1:
    msg = "multivariate_normal requires mean.ndim >= 1, got mean.ndim == {}"
    raise ValueError(msg.format(np.ndim(mean)))
  if np.ndim(cov) < 2:
    msg = "multivariate_normal requires cov.ndim >= 2, got cov.ndim == {}"
    raise ValueError(msg.format(np.ndim(cov)))
  n = mean.shape[-1]
  if np.shape(cov)[-2:] != (n, n):
    msg = ("multivariate_normal requires cov.shape == (..., n, n) for n={n}, "
           "but got cov.shape == {shape}.")
    raise ValueError(msg.format(n=n, shape=np.shape(cov)))
  if shape is None:
    shape = lax.broadcast_shapes(mean.shape[:-1], cov.shape[:-2])
  else:
    # Label the check with this sampler's own name, like the other samplers.
    _check_shape("multivariate_normal", shape, mean.shape[:-1], cov.shape[:-2])

  # Build a matrix square root ``factor`` of cov (factor @ factor^T == cov).
  if method == 'svd':
    (u, s, _) = svd(cov)
    factor = u * jnp.sqrt(s[..., None, :])
  elif method == 'eigh':
    (w, v) = eigh(cov)
    factor = v * jnp.sqrt(w[..., None, :])
  else:  # 'cholesky'
    factor = cholesky(cov)
  normal_samples = normal(key, shape + mean.shape[-1:], dtype)
  # Batch dims of mean/factor/samples may differ in rank; broadcast them.
  with config.numpy_rank_promotion('allow'):
    result = mean + jnp.einsum('...ij,...j->...i', factor, normal_samples)
  # mean/cov can be wider than the requested dtype (e.g. float64 mean with
  # dtype=float32), and type promotion would otherwise leak that wider dtype.
  return lax.convert_element_type(result, dtype)
def truncated_normal(key: KeyArrayLike,
                     lower: RealArray,
                     upper: RealArray,
                     shape: Shape | None = None,
                     dtype: DTypeLikeFloat = float) -> Array:
  r"""Draw standard normal samples truncated to an interval.

  Samples follow the density

  .. math::
     f(x) \propto e^{-x^2/2}

  restricted to the domain :math:`\rm{lower} < x < \rm{upper}`.

  Args:
    key: PRNG key used to generate the samples.
    lower: float or float array giving the lower truncation bound; must
      broadcast against ``upper``.
    upper: float or float array giving the upper truncation bound; must
      broadcast against ``lower``.
    shape: optional tuple of nonnegative integers for the output shape. It
      must be broadcast-compatible with ``lower`` and ``upper``. When omitted
      (None), the output shape is the broadcast of ``lower`` and ``upper``.
    dtype: optional float dtype of the output (float64 if jax_enable_x64 is
      true, otherwise float32).

  Returns:
    An array of ``dtype`` whose shape is ``shape`` when given, otherwise the
    broadcast shape of ``lower`` and ``upper``. Every value lies in the open
    interval ``(lower, upper)``.

  Raises:
    ValueError: if ``dtype`` is not a floating-point dtype.
  """
  out_shape = None if shape is None else core.canonicalize_shape(shape)
  checked_key, _ = _check_prng_key("truncated_normal", key)
  dtypes.check_user_dtype_supported(dtype)
  if not dtypes.issubdtype(dtype, np.floating):
    raise ValueError(f"dtype argument to `truncated_normal` must be a float "
                     f"dtype, got {dtype}")
  return _truncated_normal(checked_key, lower, upper, out_shape,
                           dtypes.canonicalize_dtype(dtype))
@partial(jit, static_argnums=(3, 4))
def _truncated_normal(key, lower, upper, shape, dtype) -> Array:
  """Jitted implementation of ``truncated_normal`` via inverse-CDF sampling.

  Draws ``u`` uniformly between ``erf(lower / sqrt(2))`` and
  ``erf(upper / sqrt(2))`` and maps it back with ``sqrt(2) * erf_inv(u)``.
  The normal CDF is an affine function of ``erf``, so this yields a standard
  normal restricted to ``(lower, upper)``.

  Args:
    key: PRNG key.
    lower: lower truncation bound (scalar or array).
    upper: upper truncation bound (scalar or array).
    shape: static output shape, or None to broadcast ``lower`` and ``upper``.
    dtype: static floating-point dtype used for the computation and output.

  Returns:
    Samples of shape ``shape`` and dtype ``dtype``, clamped to the open
    interval ``(lower, upper)``.

  Raises:
    TypeError: if ``dtype`` is not a floating-point dtype.
  """
  # Validate the dtype before using it. Otherwise an unsupported dtype (e.g.
  # complex, which lax.erf rejects) fails inside the lax ops below rather
  # than with this message.
  if not jnp.issubdtype(dtype, np.floating):
    raise TypeError("truncated_normal only accepts floating point dtypes.")
  if shape is None:
    shape = lax.broadcast_shapes(np.shape(lower), np.shape(upper))
  else:
    _check_shape("truncated_normal", shape, np.shape(lower), np.shape(upper))
  sqrt2 = np.array(np.sqrt(2), dtype)
  lower = lax.convert_element_type(lower, dtype)
  upper = lax.convert_element_type(upper, dtype)
  a = lax.erf(lower / sqrt2)
  b = lax.erf(upper / sqrt2)
  u = uniform(key, shape, dtype, minval=a, maxval=b)
  out = sqrt2 * lax.erf_inv(u)
  # Clamp the value to the open interval (lower, upper) to make sure that
  # rounding (or if we chose `a` for `u`) doesn't push us outside of the range.
  # stop_gradient keeps the clamp bounds from contributing gradients.
  return jnp.clip(
      out,
      lax.nextafter(lax.stop_gradient(lower), np.array(np.inf, dtype=dtype)),
      lax.nextafter(lax.stop_gradient(upper), np.array(-np.inf, dtype=dtype)))
def bernoulli(key: KeyArrayLike,
              p: RealArray = np.float32(0.5),
              shape: Shape | None = None) -> Array:
  r"""Draw boolean Bernoulli samples with success probability ``p``.

  The samples follow the probability mass function

  .. math::
     f(k; p) = p^k(1 - p)^{1 - k}

  for :math:`k \in \{0, 1\}` and :math:`0 \le p \le 1`.

  Args:
    key: PRNG key used to generate the samples.
    p: optional float or float array giving the mean of each variable (the
      probability of ``True``); must broadcast against ``shape``. Default 0.5.
    shape: optional tuple of nonnegative integers for the output shape; must
      be broadcast-compatible with ``p.shape``. When omitted (None), the output
      has shape ``p.shape``.

  Returns:
    A boolean array of shape ``shape`` if given, otherwise ``p.shape``.

  Raises:
    TypeError: if ``p`` does not have a floating-point dtype.
  """
  out_shape = None if shape is None else core.canonicalize_shape(shape)
  checked_key, _ = _check_prng_key("bernoulli", key)
  p_dtype = dtypes.canonicalize_dtype(lax.dtype(p))
  if not jnp.issubdtype(p_dtype, np.floating):
    raise TypeError(
        f"bernoulli probability `p` must have a floating dtype, got {p_dtype}.")
  p_cast = lax.convert_element_type(p, p_dtype)
  return _bernoulli(checked_key, p_cast, out_shape)
@partial(jit, static_argnums=(2,))
def _bernoulli(key, p, shape) -> Array:
  """Jitted implementation of ``bernoulli``.

  Draws ``uniform`` samples in ``p``'s dtype and returns the boolean mask of
  draws that fall below ``p``. This assumes ``uniform`` defaults to the unit
  interval (not verified from here). ``shape`` is static; when None, the
  output takes ``p``'s shape.
  """
  if shape is not None:
    _check_shape("bernoulli", shape, np.shape(p))
  else:
    # TODO: Use the named part of `p` as well
    shape = np.shape(p)
  draws = uniform(key, shape, lax.dtype(p))
  return draws < p
def beta(key: KeyArrayLike,
         a: RealArray,
         b: RealArray,
         shape: Shape | None = None,
         dtype: DTypeLikeFloat = float) -> Array:
  r"""Draw Beta-distributed samples with the given shape and float dtype.

  The samples follow the density

  .. math::
     f(x;a,b) \propto x^{a - 1}(1 - x)^{b - 1}

  on the domain :math:`0 \le x \le 1`.

  Args:
    key: PRNG key used to generate the samples.
    a: float or float array for the first shape parameter "alpha"; must
      broadcast against ``shape``.
    b: float or float array for the second shape parameter "beta"; must
      broadcast against ``shape``.
    shape: optional tuple of nonnegative integers for the output shape; must
      be broadcast-compatible with ``a`` and ``b``. When omitted (None), the
      output shape is the broadcast of ``a`` and ``b``.
    dtype: optional float dtype of the output (float64 if jax_enable_x64 is
      true, otherwise float32).

  Returns:
    An array of ``dtype`` whose shape is ``shape`` when given, otherwise the
    broadcast shape of ``a`` and ``b``.

  Raises:
    ValueError: if ``dtype`` is not a floating-point dtype.
  """
  checked_key, _ = _check_prng_key("beta", key)
  dtypes.check_user_dtype_supported(dtype)
  if not dtypes.issubdtype(dtype, np.floating):
    raise ValueError(f"dtype argument to `beta` must be a float "
                     f"dtype, got {dtype}")
  out_dtype = dtypes.canonicalize_dtype(dtype)
  out_shape = None if shape is None else core.canonicalize_shape(shape)
  return _beta(checked_key, a, b, out_shape, out_dtype)
def _beta(key, a, b, shape, dtype) -> Array:
  """Implementation of ``beta`` as a ratio of two Gamma draws.

  Draws ``log Ga`` and ``log Gb`` with ``loggamma`` (shape parameters ``a``
  and ``b``) and returns ``Ga / (Ga + Gb)``, evaluated from the log values.
  ``shape`` None means "broadcast ``a`` and ``b``".
  """
  if shape is not None:
    _check_shape("beta", shape, np.shape(a), np.shape(b))
  else:
    shape = lax.broadcast_shapes(np.shape(a), np.shape(b))
  alpha = lax.convert_element_type(a, dtype)
  beta_param = lax.convert_element_type(b, dtype)
  key_a, key_b = _split(key)
  alpha = jnp.broadcast_to(alpha, shape)
  beta_param = jnp.broadcast_to(beta_param, shape)
  log_ga = loggamma(key_a, alpha, shape, dtype)
  log_gb = loggamma(key_b, beta_param, shape, dtype)
  # Subtract the elementwise max of the two logs before exponentiating, so the
  # larger term becomes exp(0) == 1 and the ratio keeps its precision.
  shift = lax.max(log_ga, log_gb)
  ga = jnp.exp(log_ga - shift)
  gb = jnp.exp(log_gb - shift)
  return ga / (ga + gb)
def cauchy(key: KeyArrayLike,
shape: Shape = (),
dtype: DTypeLikeFloat = float) -> Array:
r"""Sample Cauchy random values with given shape and float dtype.
The values are distributed according to the probability density function:
.. math::
f(x) \propto \frac{1}{x^2 + 1}
on the domain :math:`-\infty < x < \infty`
Args:
key: a PRNG key used as the random key.