-
Notifications
You must be signed in to change notification settings - Fork 2.6k
/
random_test.py
1459 lines (1219 loc) · 57.1 KB
/
random_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import copy
import enum
from functools import partial
import math
from unittest import skipIf
from typing import Any, NamedTuple
import zlib
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
import jax
from jax import lax
from jax import numpy as jnp
from jax import random
from jax import tree_util
from jax._src import config
from jax._src import core
from jax._src import dtypes
from jax._src import test_util as jtu
from jax import vmap
from jax.interpreters import xla
from jax._src import random as jax_random
from jax._src import prng as prng_internal
# Parse command-line flags through absl before any test runs (presumably so
# flags can override JAX config values — confirm in jax._src.config).
config.parse_flags_with_absl()

# All registered PRNG implementations as (name, impl) pairs, captured once
# at import time for use in parameterized tests below.
PRNG_IMPLS = [*prng_internal.prngs.items()]
class OnX64(enum.Enum):
  """How a RandomValuesCase is meant to interact with jax_enable_x64."""
  ALSO = 1  # run regardless of the x64 setting
  SKIP = 2  # golden values differ under x64; skip when x64 is enabled
  ONLY = 3  # case is only meaningful when x64 is enabled
class RandomValuesCase(NamedTuple):
  """One golden-value case for a ``jax.random`` sampling function."""
  # Name of the jax.random function to call, e.g. "normal".
  name: str
  # Registered PRNG implementation name, e.g. "threefry2x32" or "rbg".
  prng_impl: str
  # Value passed as ``shape=`` to the sampling function.
  shape: tuple[int, ...]
  # Value passed as ``dtype=``; a falsy value means no dtype is passed.
  dtype: Any
  # Extra keyword arguments for the sampling function.
  params: dict
  # Golden output the sampler is expected to reproduce.
  expected: np.ndarray
  # Intended behavior of the case with respect to jax_enable_x64.
  on_x64: OnX64 = OnX64.ALSO
  # Optional tolerances for the closeness comparison.
  atol: float | None = None
  rtol: float | None = None

  def _testname(self):
    """Return a descriptive test-name suffix built from this case's fields."""
    if self.dtype is not None:
      shape_dtype = jtu.format_shape_dtype_string(self.shape, self.dtype)
    else:
      shape_dtype = str(self.shape)
    pieces = [f"_{self.name}_{self.prng_impl}_{shape_dtype}"]
    if self.params:
      def compact(value):
        # Drop whitespace so array-valued params render as one token.
        return str(value).replace(' ', '').replace('\n', '')
      pieces.append("_".join(
          f"{param}={compact(value)}" for param, value in self.params.items()))
    return "_".join(pieces)

  def _seed(self):
    """Return a deterministic 32-bit seed derived from name and PRNG impl."""
    return zlib.adler32(f"{self.name}{self.prng_impl}".encode())
# Golden values consumed by PrngTest.testRandomDistributionValues. Each case
# calls random.<name> with a key built from RandomValuesCase._seed() under the
# given PRNG implementation and compares the result against `expected`.
# Any change to these values signals a change in a sampler's output; follow
# the procedure described in that test's docstring before updating them.
_RANDOM_VALUES_CASES = [
    # TODO(jakevdp) add coverage for other distributions.
    RandomValuesCase("bernoulli", "threefry2x32", (5,), None, {'p': 0.5},
        np.array([False, True, True, True, False]), on_x64=OnX64.SKIP),
    RandomValuesCase("bernoulli", "rbg", (5,), None, {'p': 0.5},
        np.array([True, True, True, True, True]), on_x64=OnX64.SKIP),
    RandomValuesCase("beta", "threefry2x32", (5,), np.float32, {'a': 0.8, 'b': 0.9},
        np.array([0.13259 , 0.824893, 0.948363, 0.964155, 0.235448], dtype='float32')),
    RandomValuesCase("beta", "rbg", (5,), np.float32, {'a': 0.8, 'b': 0.9},
        np.array([0.93215 , 0.833959, 0.121902, 0.270003, 0.429541], dtype='float32')),
    # TODO(frostig,jakevdp) add coverage for non-threefry bits
    RandomValuesCase("bits", "threefry2x32", (5,), np.uint8, {},
        np.array([10, 158, 82, 54, 158], dtype='uint8')),
    RandomValuesCase("bits", "threefry2x32", (5,), np.uint16, {},
        np.array([6738, 38161, 50695, 57337, 61600], dtype='uint16')),
    RandomValuesCase("bits", "threefry2x32", (5,), np.uint32, {},
        np.array([1978747883, 4134381225, 3628107870, 689687174, 2788938207], dtype='uint32')),
    # uint64 output requires jax_enable_x64, hence OnX64.ONLY.
    RandomValuesCase("bits", "threefry2x32", (5,), np.uint64, {},
        np.array([17649965731882839947, 1415307058040849897, 8282622628079774249,
        14024425113645909402, 2012979996110532418], dtype='uint64'),
        on_x64=OnX64.ONLY),
    RandomValuesCase("cauchy", "threefry2x32", (5,), np.float32, {},
        np.array([ -0.088416, -10.169713, 3.49677, -1.18056, 0.34556], dtype='float32'), rtol=1E-5),
    RandomValuesCase("cauchy", "rbg", (5,), np.float32, {},
        np.array([0.008389, 0.108793, -0.031826, -0.01876, 0.963218], dtype='float32')),
    RandomValuesCase("dirichlet", "threefry2x32", (2,), np.float32, {'alpha': np.array([0.5, 0.6, 0.7], dtype='float32')},
        np.array([[0.003128, 0.009694, 0.987178], [0.025938, 0.479091, 0.494971]], dtype='float32')),
    RandomValuesCase("dirichlet", "rbg", (2,), np.float32, {'alpha': np.array([0.5, 0.6, 0.7], dtype='float32')},
        np.array([[0.080742, 0.525493, 0.393765], [0.006837, 0.804796, 0.188366]], dtype='float32')),
    RandomValuesCase("double_sided_maxwell", "threefry2x32", (5,), np.float32, {"loc": 1, "scale": 2},
        np.array([-2.408914, -3.370437, 3.235352, -0.907734, -1.708732], dtype='float32'), on_x64=OnX64.SKIP),
    RandomValuesCase("double_sided_maxwell", "rbg", (5,), np.float32, {"loc": 1, "scale": 2},
        np.array([4.957495, 3.003086, 5.33935, 2.942878, -1.203524], dtype='float32'), on_x64=OnX64.SKIP),
    RandomValuesCase("exponential", "threefry2x32", (5,), np.float32, {},
        np.array([0.526067, 0.043046, 0.039932, 0.46427 , 0.123886], dtype='float32')),
    RandomValuesCase("exponential", "rbg", (5,), np.float32, {},
        np.array([0.231303, 0.684814, 0.017181, 0.089552, 0.345087], dtype='float32')),
    RandomValuesCase("gamma", "threefry2x32", (5,), np.float32, {'a': 0.8},
        np.array([0.824221, 1.724476, 0.502882, 5.386132, 0.685543], dtype='float32')),
    RandomValuesCase("gamma", "rbg", (5,), np.float32, {'a': 0.8},
        np.array([0.994946, 0.519941, 1.754347, 0.479223, 1.16932 ], dtype='float32')),
    RandomValuesCase("gumbel", "threefry2x32", (5,), np.float32, {},
        np.array([2.06701, 0.911726, 0.145736, 0.185427, -0.00711], dtype='float32')),
    RandomValuesCase("gumbel", "rbg", (5,), np.float32, {},
        np.array([-0.099308, -1.123809, 1.007618, -0.077968, 3.421349], dtype='float32')),
    RandomValuesCase("laplace", "threefry2x32", (5,), np.float32, {},
        np.array([0.578939, -0.204902, 0.555733, 0.911053, -0.96456], dtype='float32')),
    RandomValuesCase("laplace", "rbg", (5,), np.float32, {},
        np.array([-2.970422, 1.925082, -0.757887, -4.444797, 0.561983], dtype='float32')),
    RandomValuesCase("loggamma", "threefry2x32", (5,), np.float32, {'a': 0.8},
        np.array([ 0.240559, -3.575443, -0.450946, -2.161372, -2.943277], dtype='float32')),
    RandomValuesCase("loggamma", "rbg", (5,), np.float32, {'a': 0.8},
        np.array([-0.107021, -0.809968, -0.25546 , -1.212273, -1.946579], dtype='float32')),
    RandomValuesCase("logistic", "threefry2x32", (5,), np.float32, {},
        np.array([0.19611, -1.709053, -0.274093, -0.208322, -1.675489], dtype='float32')),
    RandomValuesCase("logistic", "rbg", (5,), np.float32, {},
        np.array([-0.234923, -0.545184, 0.700992, -0.708609, -1.474884], dtype='float32')),
    RandomValuesCase("maxwell", "threefry2x32", (5,), np.float32, {},
        np.array([3.070779, 0.908479, 1.521317, 0.875551, 1.306137], dtype='float32')),
    RandomValuesCase("maxwell", "rbg", (5,), np.float32, {},
        np.array([2.048746, 0.470027, 1.053105, 1.01969, 2.710645], dtype='float32')),
    RandomValuesCase("multivariate_normal", "threefry2x32", (2,), np.float32, {"mean": np.ones((1, 3)), "cov": np.eye(3)},
        np.array([[ 1.067826, 1.215599, 0.234166], [-0.237534, 1.32591, 1.413987]], dtype='float32'), on_x64=OnX64.SKIP),
    RandomValuesCase("multivariate_normal", "rbg", (2,), np.float32, {"mean": np.ones((1, 3)), "cov": np.eye(3)},
        np.array([[-0.036897, 0.770969, 0.756959], [1.755091, 2.350553, 0.627142]], dtype='float32'), on_x64=OnX64.SKIP),
    RandomValuesCase("normal", "threefry2x32", (5,), np.float32, {},
        np.array([-1.173234, -1.511662, 0.070593, -0.099764, 1.052845], dtype='float32')),
    RandomValuesCase("normal", "rbg", (5,), np.float32, {},
        np.array([-0.479658, 0.565747, -1.065106, 0.997962, -1.478002], dtype='float32')),
    RandomValuesCase("pareto", "threefry2x32", (5,), np.float32, {"b": 0.5},
        np.array([2.751398, 1.281863, 87.85448, 1.254542, 2.824487], dtype='float32')),
    RandomValuesCase("pareto", "rbg", (5,), np.float32, {"b": 0.5},
        np.array([1.241914, 1.521864, 5.615384, 1911.502, 1.816702], dtype='float32')),
    RandomValuesCase("poisson", "threefry2x32", (5,), np.int32, {"lam": 5},
        np.array([7, 3, 6, 11, 6], dtype='int32')),
    # Note: poisson not implemented for rbg sampler.
    RandomValuesCase("rademacher", "threefry2x32", (5,), np.int32, {},
        np.array([-1, -1, -1, -1, 1], dtype='int32'), on_x64=OnX64.SKIP),
    RandomValuesCase("rademacher", "rbg", (5,), np.int32, {},
        np.array([1, 1, 1, -1, -1], dtype='int32'), on_x64=OnX64.SKIP),
    RandomValuesCase("randint", "threefry2x32", (5,), np.int32, {"minval": 0, "maxval": 10},
        np.array([0, 5, 7, 7, 5], dtype='int32')),
    RandomValuesCase("randint", "rbg", (5,), np.int32, {"minval": 0, "maxval": 10},
        np.array([7, 1, 8, 5, 8], dtype='int32')),
    RandomValuesCase("truncated_normal", "threefry2x32", (5,), np.float32, {"lower": 0, "upper": 2},
        np.array([0.582807, 1.709771, 0.159513, 0.861376, 0.36148], dtype='float32')),
    RandomValuesCase("truncated_normal", "rbg", (5,), np.float32, {"lower": 0, "upper": 2},
        np.array([0.770068, 1.516464, 0.710406, 0.762801, 1.305324], dtype='float32')),
    RandomValuesCase("uniform", "threefry2x32", (5,), np.float32, {},
        np.array([0.298671, 0.073213, 0.873356, 0.260549, 0.412797], dtype='float32')),
    RandomValuesCase("uniform", "rbg", (5,), np.float32, {},
        np.array([0.477161, 0.706508, 0.656261, 0.432547, 0.057772], dtype='float32')),
    RandomValuesCase("weibull_min", "threefry2x32", (5,), np.float32, {"scale": 1, "concentration": 1},
        np.array([1.605863, 0.841809, 0.224218, 0.4826 , 0.027901], dtype='float32')),
    RandomValuesCase("weibull_min", "rbg", (5,), np.float32, {"scale": 1, "concentration": 1},
        np.array([1.370903, 0.086532, 0.061688, 3.407599, 0.215077], dtype='float32')),
]
# Both key constructors are exercised by parameterized tests: random.key
# (typed key arrays) and random.PRNGKey (legacy uint32 key data unless the
# custom-PRNG upgrade flag is enabled).
KEY_CTORS = [random.key, random.PRNGKey]
@jtu.with_config(jax_legacy_prng_key='allow')
class PrngTest(jtu.JaxTestCase):
  """PRNG implementation, key construction, and golden-value tests.

  Runs with ``jax_legacy_prng_key='allow'`` so that legacy uint32 keys (e.g.
  from ``random.PRNGKey``) can be passed to ``jax.random`` functions alongside
  typed keys from ``random.key``.
  """

  def check_key_has_impl(self, key, impl):
    """Assert that ``key`` was produced by the PRNG implementation ``impl``.

    Typed keys record their implementation, so it is compared by identity.
    Legacy raw keys do not; for those only the uint32 dtype and the
    implementation's ``key_shape`` can be checked.
    """
    if jnp.issubdtype(key.dtype, dtypes.prng_key):
      self.assertIs(key._impl, impl)
    else:
      self.assertEqual(key.dtype, jnp.dtype('uint32'))
      self.assertEqual(key.shape, impl.key_shape)

  def test_config_prngs_registered(self):
    # TODO(frostig): pull these string values somehow from the
    # jax_default_prng_impl config enum state definition directly,
    # rather than copying manually here?
    self.assertIn('threefry2x32', prng_internal.prngs)
    self.assertIn('rbg', prng_internal.prngs)
    self.assertIn('unsafe_rbg', prng_internal.prngs)

  def testThreefry2x32(self):
    # We test the hash by comparing to known values provided in the test code of
    # the original reference implementation of Threefry. For the values, see
    # https://github.com/DEShawResearch/Random123-Boost/blob/65e3d874b67aa7b3e02d5ad8306462f52d2079c0/libs/random/test/test_threefry.cpp#L30-L32
    def result_to_hex(result):
      return tuple(hex(x.copy()).rstrip("L") for x in result)

    expected = ("0x6b200159", "0x99ba4efe")
    result = prng_internal.threefry_2x32(np.uint32([0, 0]), np.uint32([0, 0]))
    self.assertEqual(expected, result_to_hex(result))

    expected = ("0x1cb996fc", "0xbb002be7")
    u32_max = np.iinfo(np.uint32).max
    result = prng_internal.threefry_2x32(np.uint32([u32_max, u32_max]), np.uint32([u32_max, u32_max]))
    self.assertEqual(expected, result_to_hex(result))

    expected = ("0xc4923a9c", "0x483df7a0")
    result = prng_internal.threefry_2x32(
        np.uint32([0x13198a2e, 0x03707344]),
        np.uint32([0x243f6a88, 0x85a308d3]))
    self.assertEqual(expected, result_to_hex(result))

  def testThreefry2x32Large(self):
    # Same reference vector as above, repeated n times, to exercise large inputs.
    n = 10000000
    result = prng_internal.threefry_2x32(
        (np.uint32(0x13198a2e), np.uint32(0x03707344)),
        jnp.concatenate([
            jnp.full((n,), 0x243f6a88, jnp.uint32),
            jnp.full((n,), 0x85a308d3, jnp.uint32)
        ]))
    np.testing.assert_equal(result[:n], np.full((n,), 0xc4923a9c, dtype=np.uint32))
    np.testing.assert_equal(result[n:], np.full((n,), 0x483df7a0, dtype=np.uint32))

  def testThreefry2x32Empty(self):
    # Regression test for an op-by-op crash for empty arrays in CUDA mode.
    with jax.disable_jit():
      result = prng_internal.threefry_2x32(
          (np.uint32(0x13198a2e), np.uint32(0x03707344)),
          jnp.ones((10, 0,), jnp.uint32))
    np.testing.assert_equal(result, np.zeros((10, 0,), dtype=np.uint32))

  def testNoOpByOpUnderHash(self):
    # threefry_2x32 must not go through xla.apply_primitive (op-by-op
    # dispatch); temporarily replace it with a function that fails if called.
    def fail(*args, **kwargs):
      # Raise explicitly: a bare `assert False` is stripped under `python -O`.
      raise AssertionError("xla.apply_primitive called during threefry_2x32")
    apply_primitive, xla.apply_primitive = xla.apply_primitive, fail
    try:
      _ = prng_internal.threefry_2x32(np.zeros(2, np.uint32), np.arange(10, dtype=np.uint32))
    finally:
      xla.apply_primitive = apply_primitive

  @skipIf(config.threefry_partitionable.value, 'changed random bit values')
  @parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
  def testRngRandomBits(self, make_key):
    # Test specific outputs to ensure consistent random values between JAX versions.

    def random_bits(key, width, shape):
      # TODO(frostig): Use random.bits, as in:
      #
      # def random_bits(key, width, shape):
      #   dtype = jnp.dtype(f'uint{width}')
      #   return jax.random.bits(key, shape, dtype)
      #
      # Doing so doesn't work in width 64 at present due to
      # normalization in random.bits.
      key, _ = jax_random._check_prng_key(key)
      return jax_random._random_bits(key, width, shape)

    key = make_key(1701)

    bits8 = random_bits(key, 8, (3,))
    expected8 = np.array([216, 115, 43], dtype=np.uint8)
    self.assertArraysEqual(bits8, expected8)

    bits16 = random_bits(key, 16, (3,))
    expected16 = np.array([41682, 1300, 55017], dtype=np.uint16)
    self.assertArraysEqual(bits16, expected16)

    bits32 = random_bits(key, 32, (3,))
    expected32 = np.array([56197195, 4200222568, 961309823], dtype=np.uint32)
    self.assertArraysEqual(bits32, expected32)

    with jtu.ignore_warning(category=UserWarning, message="Explicitly requested dtype.*"):
      bits64 = random_bits(key, 64, (3,))
    if config.enable_x64.value:
      expected64 = np.array([3982329540505020460, 16822122385914693683,
                             7882654074788531506], dtype=np.uint64)
    else:
      expected64 = np.array([676898860, 3164047411, 4010691890], dtype=np.uint32)
    self.assertArraysEqual(bits64, expected64)

  @jtu.sample_product(prng_name=[name for name, _ in PRNG_IMPLS],
                      make_key=KEY_CTORS)
  def testRngRandomBitsShapeDtype(self, prng_name, make_key):
    # Like testRngRandomBits, but only meant to exercise random_bits
    # on every PRNG implementation. Instead of values, only checks
    # that shapes/dtypes are as expected.

    def random_bits(key, width, shape):
      dtype = jnp.dtype(f'uint{width}')
      return jax.random.bits(key, shape, dtype)

    with jax.default_prng_impl(prng_name):
      key = make_key(1701)

      bits8 = random_bits(key, 8, (3,))
      self.assertEqual(bits8.shape, (3,))
      self.assertEqual(bits8.dtype, np.dtype('uint8'))

      bits16 = random_bits(key, 16, (3,))
      self.assertEqual(bits16.shape, (3,))
      self.assertEqual(bits16.dtype, np.dtype('uint16'))

      bits32 = random_bits(key, 32, (3,))
      self.assertEqual(bits32.shape, (3,))
      self.assertEqual(bits32.dtype, np.dtype('uint32'))

      with jtu.ignore_warning(category=UserWarning, message="Explicitly requested dtype.*"):
        bits64 = random_bits(key, 64, (3,))
      expected_dtype = np.dtype('uint64' if config.enable_x64.value else 'uint32')
      self.assertEqual(bits64.shape, (3,))
      self.assertEqual(bits64.dtype, expected_dtype)

  @skipIf(config.threefry_partitionable.value, 'changed random bit values')
  @parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
  def testRngRandomBitsViewProperty(self, make_key):
    # TODO: add 64-bit if it ever supports this property.
    # TODO: will this property hold across endian-ness?

    def random_bits(key, width, shape):
      dtype = jnp.dtype(f'uint{width}')
      return jax.random.bits(key, shape, dtype)

    N = 10
    key = make_key(1701)
    nbits = [8, 16, 32]
    # Draw the same number of total bits at each width and check that the
    # raw bytes agree when reinterpreted as uint32.
    rand_bits = [random_bits(key, n, (N * 64 // n,)) for n in nbits]
    rand_bits_32 = np.array([np.array(r).view(np.uint32) for r in rand_bits])
    assert np.all(rand_bits_32 == rand_bits_32[0])

  @jtu.sample_product(case=_RANDOM_VALUES_CASES, make_key=KEY_CTORS)
  @skipIf(config.threefry_partitionable.value, 'changed random bit values')
  @jtu.skip_on_devices("tpu")  # TPU precision causes issues.
  def testRandomDistributionValues(self, case, make_key):
    """
    Tests values output by various distributions. This will catch any
    unintentional changes to the implementations that could result in
    different random sequences.

    Any refactoring of random distributions that leads to non-trivial
    differences in this test should follow the procedure outlined at
    https://jax.readthedocs.io/en/latest/api_compatibility.html#numerics-and-randomness

    This includes:
    * Announcing the change in the CHANGELOG.md
    * Considering adding a flag that reverts the new behavior, made
      available for a deprecation window's amount of time.
    """
    # Skip only the cases whose golden values depend on the x64 setting.
    # (Unconditional checks here would skip every case in every config.)
    if config.enable_x64.value and case.on_x64 is OnX64.SKIP:
      self.skipTest("test produces different values when jax_enable_x64=True")
    if not config.enable_x64.value and case.on_x64 is OnX64.ONLY:
      self.skipTest("test only valid when jax_enable_x64=True")
    with jax.default_prng_impl(case.prng_impl):
      func = getattr(random, case.name)
      key = make_key(case._seed())
      if case.dtype:
        actual = func(key, **case.params, shape=case.shape, dtype=case.dtype)
      else:
        actual = func(key, **case.params, shape=case.shape)
      self.assertAllClose(actual, case.expected, atol=case.atol, rtol=case.rtol)

  @skipIf(config.threefry_partitionable.value, 'changed random bit values')
  @parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
  def testPRNGValues(self, make_key):
    # Test to ensure consistent random values between JAX versions
    k = make_key(0)

    self.assertEqual(random.randint(k, (3, 3), 0, 8).dtype,
                     dtypes.canonicalize_dtype(jnp.int_))
    if config.enable_x64.value:
      self.assertAllClose(
          random.randint(k, (3, 3), 0, 8, dtype='int64'),
          np.array([[7, 2, 6],
                    [2, 1, 0],
                    [6, 7, 7]], dtype='int64'))
    self.assertAllClose(
        random.randint(k, (3, 3), 0, 8, dtype='int32'),
        np.array([[2, 1, 3],
                  [6, 1, 5],
                  [6, 3, 4]], dtype='int32'))

    self.assertAllClose(
        random.key_data(random.split(k, 4)),
        np.array([[2285895361, 1501764800],
                  [1518642379, 4090693311],
                  [ 433833334, 4221794875],
                  [ 839183663, 3740430601]], dtype='uint32'))

    self.assertAllClose(
        random.key_data(random.fold_in(k, 4)),
        np.array([2285895361, 433833334], dtype='uint32'))

  @parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
  def test_random_seed_offset(self, make_key):
    # A non-zero config.random_seed_offset must change the key for a seed.
    k1 = make_key(17)
    with config.random_seed_offset(3):
      k2 = make_key(17)
    eq = k1 == k2 if k2.ndim == 0 else all(k1 == k2)
    self.assertFalse(eq)

  @parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
  def test_random_bits_error(self, make_key):
    msg = 'dtype argument .* must be an unsigned int dtype'
    with self.assertRaisesRegex(ValueError, msg):
      random.bits(make_key(0), (3, 4), np.dtype('int8'))
    with self.assertRaisesRegex(ValueError, msg):
      random.bits(make_key(0), (3, 4), np.dtype('float16'))

  @skipIf(not config.threefry_partitionable.value, 'enable after upgrade')
  @parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
  def test_threefry_split_fold_in_symmetry(self, make_key):
    with jax.default_prng_impl('threefry2x32'):
      key = make_key(72)
      f1, f2, f3 = (random.fold_in(key, i) for i in range(3))
      s1, s2, s3 = random.split(key, 3)
      f1, f2, f3 = map(random.key_data, [f1, f2, f3])
      s1, s2, s3 = map(random.key_data, [s1, s2, s3])
      self.assertArraysEqual(f1, s1)
      self.assertArraysEqual(f2, s2)
      self.assertArraysEqual(f3, s3)

  @skipIf(not config.threefry_partitionable.value, 'enable after upgrade')
  @parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
  def test_threefry_split_vmapped_fold_in_symmetry(self, make_key):
    # See https://github.com/google/jax/issues/7708
    with jax.default_prng_impl('threefry2x32'):
      key = make_key(72)
      f1, f2, f3 = vmap(lambda k, _: random.fold_in(k, lax.axis_index('batch')),
                        in_axes=(None, 0), axis_name='batch')(key, jnp.ones(3))
      s1, s2, s3 = random.split(key, 3)
      f1, f2, f3 = map(random.key_data, [f1, f2, f3])
      s1, s2, s3 = map(random.key_data, [s1, s2, s3])
      self.assertArraysEqual(f1, s1)
      self.assertArraysEqual(f2, s2)
      self.assertArraysEqual(f3, s3)

  @skipIf(config.threefry_partitionable.value, 'changed random bit values')
  def test_loggamma_nan_corner_case(self):
    # regression test for https://github.com/google/jax/issues/17922
    # This particular key previously led to NaN output.
    # If the underlying implementation ever changes, this test will no longer
    # exercise this corner case, so we compare to a particular output value
    # rather than just checking for lack of NaNs.
    expected = jnp.float32(-4.595436)
    key = random.wrap_key_data(
        jnp.array([3200590325, 713258242], dtype='uint32'))
    actual = random.loggamma(key, 0.0, dtype='float32')
    rtol = 1E-4 if jtu.test_device_matches(["tpu"]) else 1E-6
    self.assertAllClose(expected, actual, rtol=rtol)

  @parameterized.parameters([params
      for d in [
          {"seed": 0, "typ": int, "jit": True, "key": [0, 0]},
          {"seed": 0, "typ": int, "jit": False, "key": [0, 0]},
          {"seed": 1, "typ": np.int32, "jit": True, "key": [0, 1]},
          {"seed": 1, "typ": np.int32, "jit": False, "key": [0, 1]},
          {"seed": 2, "typ": np.uint32, "jit": True, "key": [0, 2]},
          {"seed": 2, "typ": np.uint32, "jit": False, "key": [0, 2]},
          {"seed": 3, "typ": np.int64, "jit": True, "key": [0, 3]},
          {"seed": 3, "typ": np.int64, "jit": False, "key": [0, 3]},
          {"seed": -1, "typ": int, "jit": True, "key": [4294967295, 4294967295] if config.enable_x64.value else [0, 4294967295]},
          {"seed": -1, "typ": int, "jit": False, "key": [4294967295, 4294967295] if config.enable_x64.value else [0, 4294967295]},
          {"seed": -2, "typ": np.int32, "jit": True, "key": [0, 4294967294]},
          {"seed": -2, "typ": np.int32, "jit": False, "key": [0, 4294967294]},
          {"seed": -3, "typ": np.int64, "jit": True, "key": [4294967295, 4294967293] if config.enable_x64.value else [0, 4294967293]},
          {"seed": -3, "typ": np.int64, "jit": False, "key": [4294967295, 4294967293] if config.enable_x64.value else [0, 4294967293]},
          {"seed": np.iinfo(np.int32).max + 100, "typ": int, "jit": True, "key": [0, 2147483747]},
          {"seed": np.iinfo(np.int32).max + 100, "typ": int, "jit": False, "key": [0, 2147483747]},
          {"seed": np.iinfo(np.int32).max + 101, "typ": np.uint32, "jit": True, "key": [0, 2147483748]},
          {"seed": np.iinfo(np.int32).max + 101, "typ": np.uint32, "jit": False, "key": [0, 2147483748]},
          {"seed": np.iinfo(np.int32).min - 100, "typ": int, "jit": True, "key": [4294967295, 2147483548] if config.enable_x64.value else [0, 2147483548]},
          {"seed": np.iinfo(np.int32).min - 100, "typ": int, "jit": False, "key": [4294967295, 2147483548] if config.enable_x64.value else [0, 2147483548]},
          {"seed": np.iinfo(np.int32).min - 101, "typ": np.int64, "jit": True, "key": [4294967295, 2147483547] if config.enable_x64.value else [0, 2147483547]},
          {"seed": np.iinfo(np.int32).min - 101, "typ": np.int64, "jit": False, "key": [4294967295, 2147483547] if config.enable_x64.value else [0, 2147483547]},
      ]
      for params in [dict(**d, make_key=ctor) for ctor in KEY_CTORS]
  ])
  def test_prng_seeds_and_keys(self, seed, typ, jit, key, make_key):
    seed = typ(seed)
    if jit:
      maker = lambda k: random.key_data(jax.jit(make_key)(k))
    else:
      maker = lambda k: random.key_data(make_key(k))
    if (jit and typ is int and not config.enable_x64.value and
        (seed < np.iinfo('int32').min or seed > np.iinfo('int32').max)):
      # We expect an error to be raised.
      # NOTE: we check 'if jit' because some people rely on builtin int seeds
      # (e.g. from PRNGKey(hash("altair is best plotting library"))) outside jit

      # First check with no cache entry (note lambda above).
      with self.assertRaises(OverflowError):
        maker(seed)

      # Then populate a cache entry.
      maker(typ(0)).block_until_ready()

      # Then check now that we have a cache entry.
      with self.assertRaises(OverflowError):
        maker(seed)

    else:
      # Otherwise we expect no error.
      actual = maker(seed)
      expected = jnp.array(key, dtype=jnp.uint32)
      self.assertArraysEqual(actual, expected)

  @parameterized.parameters([
      {'make_key': ctor, 'name': name, 'impl': impl}
      for ctor in KEY_CTORS
      for name, impl in PRNG_IMPLS])
  def test_default_prng_selection(self, make_key, name, impl):
    with jax.default_prng_impl(name):
      self.assertIs(jax_random.default_prng_impl(), impl)
      key = make_key(42)
      self.check_key_has_impl(key, impl)
      k1, k2 = random.split(key, 2)
      self.check_key_has_impl(k1, impl)
      self.check_key_has_impl(k2, impl)

  @parameterized.parameters([{'make_key': ctor, 'name': name, 'impl': impl}
                             for ctor in KEY_CTORS
                             for name, impl in PRNG_IMPLS])
  def test_key_construction_with_explicit_impl_name(self, make_key, name, impl):
    key = make_key(42, impl=name)
    self.check_key_has_impl(key, impl)

  @parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
  def test_isinstance(self, make_key):
    key = make_key(0)
    self.assertIsInstance(key, jax.Array)

  @parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
  def test_key_output_vjp(self, make_key):
    # See https://github.com/google/jax/issues/14856
    def f(seed): return make_key(seed)
    jax.vjp(f, 1)  # doesn't crash

  def test_legacy_prng_key_flag(self):
    raw_key = jnp.zeros(2, dtype='uint32')
    invalid_key = jnp.zeros(1, dtype='float32')
    msg = "Legacy uint32 key array passed as key to jax.random function."

    with jax.legacy_prng_key('allow'):
      # TODO(jakevdp): remove when enable_custom_prng no longer issues warnings
      with jax.enable_custom_prng(False):
        with self.assertNoWarnings():
          random.uniform(raw_key)

    with jax.legacy_prng_key('warn'):
      with self.assertWarnsRegex(UserWarning, msg):
        random.uniform(raw_key)

    with jax.legacy_prng_key('error'):
      with self.assertRaisesRegex(ValueError, msg):
        random.uniform(raw_key)

      # Invalid key error should take precedence.
      with self.assertRaisesRegex(TypeError, "JAX encountered invalid PRNG key data"):
        random.uniform(invalid_key)
class ThreefryPrngTest(jtu.JaxTestCase):
  """Tests specific to the threefry2x32 PRNG implementation."""

  @parameterized.parameters([
      {'make_key': partial(ctor, impl='threefry2x32')}
      for ctor in (random.PRNGKey, random.key)
  ])
  def test_seed_no_implicit_transfers(self, make_key):
    # Regression test for https://github.com/google/jax/issues/15613:
    # building a key from a device-resident seed should not raise under
    # transfer_guard('disallow').
    with jax.transfer_guard('disallow'):
      device_seed = jax.device_put(42)
      make_key(device_seed)  # must not raise
class KeyArrayTest(jtu.JaxTestCase):
# Key arrays involve:
# * a Python key array type, backed by an underlying uint32 "base" array,
# * an abstract shaped array with key element type,
# * primitives that return or operate on such shaped arrays,
# * compiler lowerings,
# * a device-side data representation...
# Test it all!
#
# A handful of these tests follow CustomElementTypesTest in
# lax_test.py as an example. If you add a test here (e.g. testing
# lowering of a key-dtyped shaped array), consider whether it
# might also be a more general test of opaque element types. If
# so, add a corresponding test to CustomElementTypesTest as well.
def assertKeysEqual(self, key1, key2):
  """Assert that two key arrays share a dtype and hold identical key data."""
  self.assertEqual(key1.dtype, key2.dtype)
  data1 = random.key_data(key1)
  data2 = random.key_data(key2)
  self.assertArraysEqual(data1, data2)
def test_construction(self):
  # random.key should produce a typed PRNGKeyArray.
  typed_key = random.key(42)
  self.assertIsInstance(typed_key, prng_internal.PRNGKeyArray)
def test_issubdtype(self):
  # A key dtype sits under prng_key / extended / np.generic in the dtype
  # hierarchy, but is not a numeric type.
  key = random.key(42)
  key_dtype = key.dtype
  for parent in (key_dtype, dtypes.prng_key, dtypes.extended, np.generic):
    self.assertTrue(jnp.issubdtype(key_dtype, parent))
  for parent in (np.integer, np.number):
    self.assertFalse(jnp.issubdtype(key_dtype, parent))
  # Passing the key array itself (rather than its dtype) is rejected.
  with self.assertRaisesRegex(TypeError, "Cannot interpret"):
    jnp.issubdtype(key, dtypes.prng_key)
@skipIf(not config.enable_custom_prng.value, 'relies on typed key upgrade flag')
def test_construction_upgrade_flag(self):
  # With jax_enable_custom_prng set, the legacy PRNGKey constructor also
  # returns a typed PRNGKeyArray.
  legacy_ctor_key = random.PRNGKey(42)
  self.assertIsInstance(legacy_ctor_key, prng_internal.PRNGKeyArray)
def make_keys(self, *shape, seed=28):
  """Build a key array of the given shape from consecutive seeds.

  Seeds run ``seed, seed + 1, ...`` and are laid out with ``reshape``;
  an empty ``shape`` yields a scalar key.
  """
  count = math.prod(shape)
  seeds = jnp.arange(count, dtype=jnp.uint32) + seed
  keys = jax.vmap(random.key)(seeds)
  return keys.reshape(shape)
def test_key_as_seed(self):
  # A key array is not an acceptable seed for either constructor.
  key = self.make_keys()
  expectations = [
      (random.PRNGKey, "PRNGKey accepts a scalar seed"),
      (random.key, "key accepts a scalar seed"),
  ]
  for ctor, pattern in expectations:
    with self.assertRaisesRegex(TypeError, pattern):
      ctor(key)
def test_non_scalar_seed(self):
  # Array-valued seeds are rejected by both constructors.
  seed_arr = np.arange(4)
  expectations = [
      (random.PRNGKey, "PRNGKey accepts a scalar seed"),
      (random.key, "key accepts a scalar seed"),
  ]
  for ctor, pattern in expectations:
    with self.assertRaisesRegex(TypeError, pattern):
      ctor(seed_arr)
def test_non_integer_seed(self):
  # Floating-point seeds are rejected by both constructors.
  seed = np.pi
  for ctor in (random.PRNGKey, random.key):
    with self.assertRaisesRegex(TypeError, "PRNG key seed must be an integer"):
      ctor(seed)
def test_dtype_property(self):
  # Keys from the default impl share one dtype, including keys produced by
  # split and the keys observed while tracing under jit.
  k1 = self.make_keys()
  k2 = self.make_keys()
  self.assertEqual(k1.dtype, k2.dtype)

  k3, k4 = random.split(k1, 2)
  self.assertEqual(k1.dtype, k3.dtype)
  self.assertEqual(k3.dtype, k4.dtype)

  traced_dtypes = []

  def record_and_split(k):
    traced_dtypes.append(k.dtype)
    return random.split(k)

  jax.jit(record_and_split)(k1)
  self.assertEqual(traced_dtypes[0], k1.dtype)
  self.assertEqual(traced_dtypes[0], k2.dtype)
def test_key_dtype_attributes(self):
  # The key dtype's name starts with "key", and the key array occupies the
  # same number of bytes as its underlying raw data.
  key = self.make_keys()
  raw = random.key_data(key)
  self.assertStartsWith(key.dtype.name, "key")
  key_nbytes = key.size * key.dtype.itemsize
  raw_nbytes = raw.size * raw.dtype.itemsize
  self.assertEqual(key_nbytes, raw_nbytes)
def test_key_attributes(self):
  # Array-like attributes on a key array are consistent with its dtype/shape.
  keys = self.make_keys()
  shape = keys.shape
  self.assertEqual(keys.itemsize, keys.dtype.itemsize)
  self.assertEqual(keys.size, math.prod(shape))
  self.assertEqual(keys.ndim, len(shape))
def test_key_copy(self):
  # Every way of copying a key array (method, copy module, under jit)
  # must yield an equal key array.
  original = self.make_keys()
  copiers = (
      lambda k: k.copy(),
      copy.copy,
      copy.deepcopy,
      jax.jit(lambda k: k.copy()),
  )
  for copier in copiers:
    self.assertKeysEqual(original, copier(original))
def test_isinstance(self):
  # Key arrays are PRNGKeyArray instances both as jit inputs/outputs and
  # as traced values inside the jitted function.
  def passthrough(k):
    # Executed while jit traces `passthrough`, so `k` is the traced value.
    self.assertIsInstance(k, prng_internal.PRNGKeyArray)
    return k

  keys_in = self.make_keys()
  keys_out = jax.jit(passthrough)(keys_in)
  for k in (keys_in, keys_out):
    self.assertIsInstance(k, prng_internal.PRNGKeyArray)
def test_cpp_dispatch_normal(self):
  # Calling a jitted function that takes a key array twice should miss the
  # C++ pjit cache exactly once, i.e. stay on the C++ dispatch path.
  sample_normal = jax.jit(lambda key: jax.random.normal(key))
  key = self.make_keys()
  with jtu.count_pjit_cpp_cache_miss() as misses:
    for _ in range(2):
      sample_normal(key).block_until_ready()
  self.assertEqual(misses[0], 1)
def test_cpp_dispatch_split(self):
  # Same as the normal-sampling case, but key arrays are both the input and
  # the output of the jitted function.
  split_key = jax.jit(lambda key: jax.random.split(key))
  key = self.make_keys()
  with jtu.count_pjit_cpp_cache_miss() as misses:
    for _ in range(2):
      split_key(key).block_until_ready()
  self.assertEqual(misses[0], 1)
def test_cpp_dispatch_aot_normal(self):
  # An ahead-of-time compiled function taking a key array should miss the
  # C++ AOT cache only on its first call.
  key = self.make_keys()
  lowered = jax.jit(lambda key: jax.random.normal(key)).lower(key)
  compiled = lowered.compile()
  with jtu.count_aot_jit_cpp_cache_miss() as misses:
    for _ in range(2):
      compiled(key).block_until_ready()
  self.assertEqual(misses[0], 1)
def test_cpp_dispatch_aot_split(self):
  # AOT variant with key arrays flowing both in and out of the compiled
  # function; only the first call should miss the C++ AOT cache.
  key = self.make_keys()
  lowered = jax.jit(lambda key: jax.random.split(key)).lower(key)
  compiled = lowered.compile()
  with jtu.count_aot_jit_cpp_cache_miss() as misses:
    for _ in range(2):
      compiled(key).block_until_ready()
  self.assertEqual(misses[0], 1)
# -- prng primitives
def test_random_wrap_vmap(self):
  # Wrapping raw uint32 data into threefry keys under vmap yields a key array
  # of the mapped length, whichever input axis is mapped.
  wrap = partial(prng_internal.random_wrap, impl=prng_internal.threefry_prng_impl)
  raw = jnp.arange(6, dtype=jnp.uint32).reshape(3, 2)
  for axis, operand in ((0, raw), (1, raw.T)):
    wrapped = jax.vmap(wrap, in_axes=axis)(operand)
    self.assertIsInstance(wrapped, prng_internal.PRNGKeyArray)
    self.assertEqual(wrapped.shape, (3,))
@jtu.sample_product(use_internal=[False, True])
def test_random_unwrap(self, use_internal):
  # Unwrapping a (3, 4) key array gives uint32 data whose leading dims are
  # (3, 4): eagerly, under jit, under vmap, and under vmap-of-jit. The
  # internal primitive must additionally reject non-key operands.
  if use_internal:
    unwrap = prng_internal.random_unwrap
  else:
    unwrap = random.key_data

  def f(k):
    return unwrap(k)

  def eager(fn):
    return fn

  def vmap_of_jit(fn):
    return jax.vmap(jax.jit(fn))

  keys = self.make_keys(3, 4)
  for transform in (eager, jax.jit, jax.vmap, vmap_of_jit):
    out = transform(f)(keys)
    self.assertEqual(out.dtype, np.dtype('uint32'))
    self.assertEqual(out.shape[:2], (3, 4))

  if use_internal:
    plain_ints = jnp.arange(12, dtype=np.dtype('uint32')).reshape(3, 4)
    for transform in (eager, jax.jit, jax.vmap):
      with self.assertRaisesRegex(
          TypeError, 'random_unwrap takes key array operand, got .*'):
        transform(f)(plain_ints)
def test_eval_shape_keys_in(self):
  # eval_shape of random_bits on a scalar key reports shape (5,) and an
  # unsigned dtype matching the requested bit width.
  for bit_width, dtype_name in ((32, 'uint32'), (16, 'uint16')):
    draw_bits = partial(prng_internal.random_bits, bit_width=bit_width, shape=(5,))
    out = jax.eval_shape(draw_bits, self.make_keys())
    self.assertEqual(out.shape, (5,))
    self.assertEqual(out.dtype, np.dtype(dtype_name))
def test_eval_shape_keys_out(self):
  # eval_shape of a function that builds a key from a seed reports a scalar
  # shape and a key dtype.
  def f(seed):
    return self.make_keys(seed=seed)

  out = jax.eval_shape(f, 28)
  self.assertEqual(out.shape, ())
  # Key dtypes are now available on abstract values (cf. test_scan_jaxpr),
  # so check the dtype as well instead of leaving it as a TODO.
  self.assertIs(type(out.dtype), prng_internal.KeyTy)
  self.assertEqual(out.dtype, self.make_keys().dtype)
def test_eval_shape_keys_in_out(self):
  # eval_shape of `random.split` on a scalar key reports two keys of the same
  # key dtype as the input.
  def f(key):
    return random.split(key)

  key = self.make_keys()
  out = jax.eval_shape(f, key)
  self.assertEqual(out.shape, (2,))
  # Resolve the former TODO: key dtypes are inspectable on abstract values.
  self.assertIs(type(out.dtype), prng_internal.KeyTy)
  self.assertEqual(out.dtype, key.dtype)
def test_vmap(self):
  # vmap of a jitted per-element transpose maps over the leading key axis.
  keys = self.make_keys(3, 4, 5)
  transpose_each = jax.vmap(jax.jit(lambda k: k.T))
  self.assertEqual(transpose_each(keys).shape, (3, 5, 4))
# -- dtype-polymorphic operation (esp. lowerings)
def test_scan_jaxpr(self):
  # Tracing a scan over key arrays keeps the key dtype on both the jaxpr
  # input and the scan output. Expected jaxpr:
  # { lambda ; a:key<fry>[3,4,5]. let
  #     b:key<fry>[3,5,4] = scan[
  #       jaxpr={ lambda ; c:key<fry>[4,5]. let
  #           d:key<fry>[5,4] = transpose[permutation=(1, 0)] c
  #         in (d,) }
  #     ] a
  #   in (b,) }
  keys = self.make_keys(3, 4, 5)
  scan_transpose = lambda ks: jax.lax.scan(lambda _, k: (None, k.T), None, ks)
  jaxpr = jax.make_jaxpr(scan_transpose)(keys).jaxpr

  def check_key_var(var, expected_shape):
    # A jaxpr variable holding keys: shaped aval, given shape, key dtype.
    self.assertIsInstance(var.aval, core.ShapedArray)
    self.assertEqual(var.aval.shape, expected_shape)
    self.assertIs(type(var.aval.dtype), prng_internal.KeyTy)

  self.assertLen(jaxpr.invars, 1)
  (in_var,) = jaxpr.invars
  check_key_var(in_var, (3, 4, 5))

  self.assertLen(jaxpr.eqns, 1)
  (scan_eqn,) = jaxpr.eqns
  self.assertLen(scan_eqn.outvars, 1)
  (out_var,) = scan_eqn.outvars
  check_key_var(out_var, (3, 5, 4))
def test_scan_lowering(self):
  # A jitted scan that emits transposed keys lowers and returns a key array.
  keys = self.make_keys(3, 4)

  def scan_transpose(ks):
    return jax.lax.scan(lambda carry, k: (None, k.T), None, ks)

  _, stacked = jax.jit(scan_transpose)(keys)  # doesn't crash
  self.assertIsInstance(stacked, prng_internal.PRNGKeyArray)
  self.assertEqual(stacked.shape, (3, 4))
def test_slice(self):
  # Static slicing along axis 0 under jit preserves the key type.
  keys = self.make_keys(3, 4)
  take_rows_1_to_3 = jax.jit(lambda x: lax.slice_in_dim(x, 1, 3))
  sliced = take_rows_1_to_3(keys)
  self.assertIsInstance(sliced, prng_internal.PRNGKeyArray)
  self.assertEqual(sliced.shape, (2, 4))
def test_dynamic_slice(self):
  # Dynamic slicing of keys under jit preserves the key type.
  keys = self.make_keys(3, 4)
  start = np.int16(1)  # non-default int type to catch type errors.
  slice_two = jax.jit(partial(lax.dynamic_slice_in_dim, slice_size=2))
  sliced = slice_two(keys, start)
  self.assertIsInstance(sliced, prng_internal.PRNGKeyArray)
  self.assertEqual(sliced.shape, (2, 4))
def test_dynamic_update_slice(self):
  # Writing a row of keys into a key array under jit keeps the key type.
  target = self.make_keys(3, 4)
  update = self.make_keys(1, 4)
  start = np.int16(1)  # non-default int type to catch type errors.
  update_rows = jax.jit(partial(lax.dynamic_update_slice_in_dim, axis=0))
  updated = update_rows(target, update, start)
  self.assertIsInstance(updated, prng_internal.PRNGKeyArray)
  self.assertEqual(updated.shape, (3, 4))
def test_transpose(self):
  # `.T` on a key array under jit swaps its axes and keeps the key type.
  keys = self.make_keys(3, 4)
  transposed = jax.jit(lambda x: x.T)(keys)
  self.assertIsInstance(transposed, prng_internal.PRNGKeyArray)
  self.assertEqual(transposed.shape, (4, 3))
def test_gather(self):
  # Integer and slice indexing of key arrays under jit keeps the key type
  # and produces the expected result shape.
  keys_2d = self.make_keys(3, 4)
  keys_3d = self.make_keys(3, 4, 5)
  cases = [
      (keys_2d, lambda x: x[1], (4,)),
      (keys_3d, lambda x: x[1], (4, 5)),
      (keys_3d, lambda x: x[1, 2:4], (2, 5)),
      (keys_3d, lambda x: x[1, 2:4, 3], (2,)),
      (keys_3d, lambda x: x[:, 2:4, 3:4], (3, 2, 1)),
  ]
  for source, indexer, expected_shape in cases:
    gathered = jax.jit(indexer)(source)
    self.assertIsInstance(gathered, prng_internal.PRNGKeyArray)
    self.assertEqual(gathered.shape, expected_shape)
def test_select(self):
  # Elementwise `lax.select` between key arrays under jit keeps the key type.
  keys = self.make_keys(3, 2)
  mask = jnp.array([[True, False],
                    [False, True],
                    [False, True]])
  chosen = jax.jit(lax.select)(mask, keys, keys)
  self.assertIsInstance(chosen, prng_internal.PRNGKeyArray)
  self.assertEqual(chosen.shape, (3, 2))
def test_select_scalar_cond(self):
  # regression test for https://github.com/google/jax/issues/16422
  # A Python scalar predicate with key-array branches must still work.
  keys = self.make_keys(3)
  chosen = lax.select(True, keys, keys)
  self.assertIsInstance(chosen, prng_internal.PRNGKeyArray)
  self.assertEqual(chosen.shape, (3,))
def test_vmap_of_cond(self):
  # See https://github.com/google/jax/issues/15869
  # Build one key per predicate entry and select between key arrays with an
  # array-valued predicate.
  def f(x):
    keys = self.make_keys(*x.shape)
    return lax.select(x, keys, keys)

  x = jnp.array([True, False, False])
  out = f(x)  # doesn't crash
  # Beyond not crashing, the result must still be a key array of the
  # predicate's shape (as in test_select / test_select_scalar_cond).
  self.assertIsInstance(out, prng_internal.PRNGKeyArray)
  self.assertEqual(out.shape, (3,))
def test_device_put(self):
  # Moving a key array to a specific device leaves its contents unchanged.
  target_device = jax.devices()[0]
  keys = self.make_keys(4)
  moved = jax.device_put(keys, target_device)
  self.assertKeysEqual(keys, moved)
def test_device_put_sharded(self):
  # Sharding one key per device and reassembling round-trips the key array.
  all_devices = jax.devices()
  keys = self.make_keys(len(all_devices))
  per_device_keys = list(keys)
  sharded = jax.device_put_sharded(per_device_keys, all_devices)
  self.assertKeysEqual(keys, sharded)
def test_device_put_replicated(self):
  # Replicating a scalar key across devices equals broadcasting it to the
  # replicated array's shape.
  all_devices = jax.devices()
  key = self.make_keys()
  replicated = jax.device_put_replicated(key, all_devices)
  expected = jnp.broadcast_to(key, replicated.shape)
  self.assertKeysEqual(expected, replicated)
def test_make_array_from_callback(self):
  # Building a sharded key array shard-by-shard from a callback matches
  # building it in one go from the same seeds.
  num_devices = len(jax.devices())
  mesh = jtu.create_global_mesh((num_devices,), ('x',))
  sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x'))

  def keys_for_shard(index):
    # index[0] selects this shard's portion of the single sharded axis.
    shard_seeds = jnp.arange(num_devices)[index[0]]
    return jax.vmap(random.key)(shard_seeds)

  result = jax.make_array_from_callback((num_devices,), sharding, keys_for_shard)
  expected = jax.vmap(random.key)(jnp.arange(num_devices))
  self.assertKeysEqual(result, expected)
def test_make_array_from_single_device_arrays(self):
  # Assembling a sharded key array from one single-key array per device
  # reproduces the original key array.
  all_devices = jax.devices()
  num_devices = len(all_devices)
  mesh = jtu.create_global_mesh((num_devices,), ('x',))
  sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x'))
  keys = random.split(random.key(0), num_devices)
  shards = [
      jax.device_put(keys[pos:pos + 1], dev)
      for pos, dev in enumerate(all_devices)
  ]
  assembled = jax.make_array_from_single_device_arrays(
      (num_devices,), sharding, shards)
  self.assertKeysEqual(assembled, keys)
def test_key_array_custom_jvp(self):
  # Differentiating a custom_jvp function that takes a key argument gives the
  # same gradient as differentiating the plain function, and the tangent the
  # custom rule receives for the key is itself a key array.
  def f_raw(x, key):
    return x * random.normal(key, ())
  f = jax.custom_jvp(f_raw)
  @f.defjvp
  def f_jvp(primals, tangents):
    # Rebinds the enclosing `key_dot` so the test can inspect the key tangent
    # after jax.grad has run this rule.
    nonlocal key_dot
    x, key = primals
    x_dot, key_dot = tangents
    rand = random.normal(key, ())
    tangent_out = x_dot * rand
    primal_out = x * rand
    return primal_out, tangent_out
  # Placeholder; overwritten by f_jvp when jax.grad(f) runs below.
  key_dot = None
  key = self.make_keys()
  default_result = jax.grad(f_raw)(0.0, key)
  custom_result = jax.grad(f)(0.0, key)
  self.assertAllClose(default_result, custom_result)
  # The key tangent must be a key array whose raw data compares equal to
  # uint32 zero.
  self.assertIsInstance(key_dot, prng_internal.PRNGKeyArray)
  self.assertArraysEqual(random.key_data(key_dot), np.uint32(0))