-
Notifications
You must be signed in to change notification settings - Fork 2.6k
/
prng.py
1447 lines (1169 loc) · 49.1 KB
/
prng.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import abc
from collections.abc import Iterator, Sequence
from functools import partial, reduce
import math
import operator as op
from typing import Any, Callable, NamedTuple
import numpy as np
import jax
from jax import lax
from jax import numpy as jnp
from jax import tree_util
from jax._src import api_util
from jax._src import api
from jax._src import basearray
from jax._src import config as config
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import pretty_printer as pp
from jax._src import sharding_specs
from jax._src import tree_util as tree_util_internal
from jax._src import typing
from jax._src.api import jit, vmap
from jax._src.dtypes import float0
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import pxla
from jax._src.interpreters import xla
from jax._src.lax import lax as lax_internal
from jax._src.lax import utils as lax_utils
from jax._src.lib.mlir import ir
from jax._src.lib import gpu_prng
from jax._src.lib import xla_client as xc
from jax._src.lib.mlir.dialects import hlo
from jax._src.numpy.array_methods import (
_array_operators, _set_array_base_attributes, _IndexUpdateHelper)
from jax._src.partition_spec import PartitionSpec
from jax._src.sharding_impls import (
NamedSharding, PmapSharding, GSPMDSharding, XLACompatibleSharding)
from jax._src.typing import Array
from jax._src.util import safe_map, safe_zip
# Rebind the builtin `map`/`zip` names to the `safe_*` variants from
# jax._src.util for the rest of this module; the builtins stay reachable as
# `unsafe_map` / `unsafe_zip`.
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip

Device = xc.Device
Shard = Any  # TODO(jakevdp): fix circular imports and import Shard
# Shape of an array (or of a single key's data) as a tuple of ints.
Shape = tuple[int, ...]

# Unsigned integer JAX dtype for each supported bit width.
UINT_DTYPES = {
    8: jnp.uint8, 16: jnp.uint16, 32: jnp.uint32, 64: jnp.uint64}  # type: ignore[has-type]
# -- PRNG implementation interface
class PRNGImpl(NamedTuple):
  """Describes a PRNG: the shape of its keys and the functions acting on them.

  Keys of an implementation are uint32 arrays whose shape is ``key_shape``.
  Writing ``K`` for that key type, the operations have these signatures::

    seed :: int[] -> K
    fold_in :: K -> int[] -> K
    split[shape] :: K -> K[*shape]
    random_bits[shape, bit_width] :: K -> uint<bit_width>[*shape]

  The ``PRNGKeyArray`` class adapts such keys into an array-like object of
  keys. Build one with the ``random_seed`` function rather than directly.
  """
  key_shape: Shape
  seed: Callable
  split: Callable
  random_bits: Callable
  fold_in: Callable
  name: str = '<unnamed>'
  tag: str = '?'

  def __hash__(self) -> int:
    # Only the tag participates. Tuples that compare equal share a tag, so
    # this stays consistent with NamedTuple's field-wise equality.
    tag = self.tag
    return hash(tag)

  def __str__(self) -> str:
    tag = self.tag
    return tag

  def pprint(self):
    """Return a pretty-printer document listing this implementation's fields."""
    header = pp.text(
        f"{self.__class__.__name__} [{self.tag}] {{{self.name}}}:")
    field_docs = [pp.text(f"{field} = {value}")
                  for field, value in self._asdict().items()]
    body = pp.group(pp.brk() + pp.join(pp.brk(), field_docs))
    return header + pp.nest(2, body)
# Registry of PRNG implementations, keyed by ``PRNGImpl.name``.
prngs: dict[str, PRNGImpl] = {}

def register_prng(impl: PRNGImpl):
  """Add ``impl`` to the ``prngs`` registry under ``impl.name``.

  Raises:
    ValueError: if an implementation with the same name is already registered.
  """
  key = impl.name
  if key not in prngs:
    prngs[key] = impl
    return
  raise ValueError(f'PRNG with name {impl.name} already registered: {impl}')
# -- PRNG key arrays
def _check_prng_key_data(impl, key_data: typing.Array):
ndim = len(impl.key_shape)
if not all(hasattr(key_data, attr) for attr in ['ndim', 'shape', 'dtype']):
raise TypeError("JAX encountered invalid PRNG key data: expected key_data "
f"to have ndim, shape, and dtype attributes. Got {key_data}")
if key_data.ndim < 1:
raise TypeError("JAX encountered invalid PRNG key data: expected "
f"key_data.ndim >= 1; got ndim={key_data.ndim}")
if key_data.shape[-ndim:] != impl.key_shape:
raise TypeError("JAX encountered invalid PRNG key data: expected key_data.shape to "
f"end with {impl.key_shape}; got shape={key_data.shape} for {impl=}")
if key_data.dtype not in [np.uint32, float0]:
raise TypeError("JAX encountered invalid PRNG key data: expected key_data.dtype = uint32; "
f"got dtype={key_data.dtype}")
class PRNGKeyArrayMeta(abc.ABCMeta):
  """Metaclass that customizes ``isinstance`` checks against ``PRNGKeyArray``.

  An object counts as an instance when its ``aval`` is a ``core.ShapedArray``
  whose dtype is exactly ``KeyTy``. Objects lacking the attributes needed
  for that test fall back to the ordinary ABC instance check.
  """

  def __instancecheck__(cls, instance):
    try:
      aval = instance.aval
      has_key_aval = (isinstance(aval, core.ShapedArray)
                      and type(aval.dtype) is KeyTy)
    except AttributeError:
      return super().__instancecheck__(instance)
    return has_key_aval
class PRNGKeyArray(jax.Array, metaclass=PRNGKeyArrayMeta):
  """An array whose elements are PRNG keys.

  Abstract interface: every member below is an ``abc.abstractmethod`` that a
  concrete key-array class (``PRNGKeyArrayImpl`` in this module) must
  provide. ``isinstance`` checks against this class are customized by
  ``PRNGKeyArrayMeta``.
  """

  # -- Buffer and host-transfer operations.

  @abc.abstractmethod
  def unsafe_buffer_pointer(self) -> int: ...

  @abc.abstractmethod
  def block_until_ready(self) -> PRNGKeyArray: ...

  @abc.abstractmethod
  def copy_to_host_async(self) -> None: ...

  # -- Array metadata. In PRNGKeyArrayImpl, shape/ndim/size describe the
  # array of keys and exclude the trailing per-key data dims.

  @property
  @abc.abstractmethod
  def shape(self) -> tuple[int, ...]: ...

  @property
  @abc.abstractmethod
  def ndim(self) -> int: ...

  @property
  @abc.abstractmethod
  def size(self) -> int: ...

  @property
  @abc.abstractmethod
  def dtype(self): ...

  @property
  @abc.abstractmethod
  def itemsize(self): ...

  @property
  @abc.abstractmethod
  def sharding(self): ...

  # -- Indexing and shape manipulation; each returns another key array.

  @property
  @abc.abstractmethod
  def at(self) -> _IndexUpdateHelper: ...  # type: ignore[override]

  @abc.abstractmethod
  def __len__(self) -> int: ...

  @abc.abstractmethod
  def __iter__(self) -> Iterator[PRNGKeyArray]: ...

  @abc.abstractmethod
  def reshape(self, *args, order='C') -> PRNGKeyArray: ...

  @property
  @abc.abstractmethod
  def T(self) -> PRNGKeyArray: ...

  @abc.abstractmethod
  def __getitem__(self, _) -> PRNGKeyArray: ...

  @abc.abstractmethod
  def ravel(self, *_, **__) -> PRNGKeyArray: ...

  @abc.abstractmethod
  def squeeze(self, *_, **__) -> PRNGKeyArray: ...

  @abc.abstractmethod
  def swapaxes(self, *_, **__) -> PRNGKeyArray: ...

  @abc.abstractmethod
  def take(self, *_, **__) -> PRNGKeyArray: ...

  @abc.abstractmethod
  def transpose(self, *_, **__) -> PRNGKeyArray: ...

  @abc.abstractmethod
  def flatten(self, *_, **__) -> PRNGKeyArray: ...

  # -- Placement, sharding and lifetime.

  @property
  @abc.abstractmethod
  def is_fully_addressable(self) -> bool: ...

  @property
  @abc.abstractmethod
  def is_fully_replicated(self) -> bool: ...

  @abc.abstractmethod
  def device(self) -> Device: ...

  @abc.abstractmethod
  def devices(self) -> set[Device]: ...

  @abc.abstractmethod
  def delete(self) -> None: ...

  @abc.abstractmethod
  def is_deleted(self) -> bool: ...

  @abc.abstractmethod
  def on_device_size_in_bytes(self) -> int: ...

  @property
  @abc.abstractmethod
  def addressable_shards(self) -> list[Shard]: ...

  @property
  @abc.abstractmethod
  def global_shards(self) -> list[Shard]: ...

  @abc.abstractmethod
  def addressable_data(self, index: int) -> PRNGKeyArray: ...

  # TODO(jakevdp): potentially add tolist(), tobytes(),
  #  device_buffer, device_buffers, __cuda_interface__()
class PRNGKeyArrayImpl(PRNGKeyArray):
  """An array of PRNG keys backed by an RNG implementation.

  This class lifts the definition of a PRNG, provided in the form of a
  ``PRNGImpl``, into an array-like pytree class. Instances of this
  class behave like an array whose base elements are keys, hiding the
  fact that keys are typically arrays (of ``uint32`` dtype) themselves.

  The raw key data lives in ``_base_array``. Its shape is this array's
  ``shape`` followed by ``_impl.key_shape``. Device, sharding and lifetime
  queries are forwarded to that base array.

  PRNG operations are provided as module-level functions in this file
  (``random_split``, ``random_fold_in``, ``random_bits``) rather than as
  methods. NOTE(review): earlier text said these arrays expose no
  arithmetic, yet the ``jax.Array`` operator dunders are installed on this
  class right after its definition. Whether they reject key dtypes is
  decided elsewhere — confirm.
  """

  _impl: PRNGImpl  # the implementation these keys belong to
  _base_array: typing.Array  # raw key data; trailing dims equal _impl.key_shape

  def __init__(self, impl, key_data: Any):
    # Only concrete key data may be wrapped here; tracers are rejected.
    assert not isinstance(key_data, core.Tracer)
    # Raises TypeError unless key_data has a uint32/float0 dtype and a shape
    # ending in impl.key_shape.
    _check_prng_key_data(impl, key_data)
    self._impl = impl
    self._base_array = key_data

  def block_until_ready(self):
    _ = self._base_array.block_until_ready()
    return self

  def copy_to_host_async(self):
    _ = self._base_array.copy_to_host_async()

  @property
  def aval(self):
    # A ShapedArray with this array's (key) shape and a KeyTy dtype.
    return keys_shaped_array(self._impl, self.shape)

  @property
  def shape(self):
    # The base array's shape without its trailing key_shape dims.
    return base_arr_shape_to_keys_shape(self._impl, self._base_array.shape)

  @property
  def size(self):
    return math.prod(self.shape)

  @property
  def ndim(self):
    return len(self.shape)

  @property
  def dtype(self):
    return KeyTy(self._impl)

  @property
  def itemsize(self):
    return self.dtype.itemsize

  # Attributes forwarded to the base array. Reading one yields the base
  # array's attribute (for methods, its bound method), so for example
  # ``keys.delete()`` calls ``keys._base_array.delete()``.
  _device = property(op.attrgetter('_base_array._device'))
  _committed = property(op.attrgetter('_base_array._committed'))
  device = property(op.attrgetter('_base_array.device'))  # type: ignore[assignment]
  devices = property(op.attrgetter('_base_array.devices'))  # type: ignore[assignment]
  is_fully_addressable = property(op.attrgetter('_base_array.is_fully_addressable'))  # type: ignore[assignment]
  is_fully_replicated = property(op.attrgetter('_base_array.is_fully_replicated'))  # type: ignore[assignment]
  delete = property(op.attrgetter('_base_array.delete'))  # type: ignore[assignment]
  is_deleted = property(op.attrgetter('_base_array.is_deleted'))  # type: ignore[assignment]
  on_device_size_in_bytes = property(op.attrgetter('_base_array.on_device_size_in_bytes'))  # type: ignore[assignment]
  unsafe_buffer_pointer = property(op.attrgetter('_base_array.unsafe_buffer_pointer'))  # type: ignore[assignment]

  def addressable_data(self, index: int) -> PRNGKeyArrayImpl:
    # Wrap the base array's addressable data at `index` as a key array.
    return PRNGKeyArrayImpl(self._impl, self._base_array.addressable_data(index))

  @property
  def addressable_shards(self) -> list[Shard]:
    # Rebuild each base-array shard (same shard type, device, sharding and
    # global shape) with its data wrapped as a key array.
    return [
        type(s)(
            device=s._device,
            sharding=s._sharding,
            global_shape=s._global_shape,
            data=PRNGKeyArrayImpl(self._impl, s._data),
        )
        for s in self._base_array.addressable_shards
    ]

  @property
  def global_shards(self) -> list[Shard]:
    # Same rewrapping as `addressable_shards`, over all global shards.
    return [
        type(s)(
            device=s._device,
            sharding=s._sharding,
            global_shape=s._global_shape,
            data=PRNGKeyArrayImpl(self._impl, s._data),
        )
        for s in self._base_array.global_shards
    ]

  @property
  def sharding(self):
    # Logical sharding over the key dims, derived from the base array's
    # physical sharding.
    phys_sharding = self._base_array.sharding
    return KeyTyRules.logical_op_sharding(self.aval, phys_sharding)

  def _is_scalar(self):
    # True when the base array has no dims beyond a single key's data.
    base_ndim = len(self._impl.key_shape)
    return self._base_array.ndim == base_ndim

  def __len__(self):
    if self._is_scalar():
      raise TypeError('len() of unsized object')
    return len(self._base_array)

  def __iter__(self) -> Iterator[PRNGKeyArrayImpl]:
    if self._is_scalar():
      raise TypeError('iteration over a 0-d key array')
    # TODO(frostig): we may want to avoid iteration by slicing because
    # a very common use of iteration is `k1, k2 = split(key)`, and
    # slicing/indexing may be trickier to track for linearity checking
    # purposes. Maybe we can:
    # * introduce an unpack primitive+traceable (also allow direct use)
    # * unpack upfront into shape[0] many keyarray slices
    # * return iter over these unpacked slices
    # Whatever we do, we'll want to do it by overriding
    # ShapedArray._iter when the element type is KeyTy...
    return (PRNGKeyArrayImpl(self._impl, k) for k in iter(self._base_array))

  def __repr__(self):
    return (f'Array({self.shape}, dtype={self.dtype.name}) overlaying:\n'
            f'{self._base_array}')

  def pprint(self):
    pp_keys = pp.text('shape = ') + pp.text(str(self.shape))
    pp_impl = pp.text('impl = ') + self._impl.pprint()
    return str(pp.group(
      pp.text('PRNGKeyArray:') +
      pp.nest(2, pp.brk() + pp_keys + pp.brk() + pp_impl)))

  def copy(self):
    # Copies the base array; the implementation object is shared.
    return self.__class__(self._impl, self._base_array.copy())

  # Key arrays are unhashable.
  __hash__ = None  # type: ignore[assignment]
  # NOTE(review): presumably makes NumPy binary ops defer to this class.
  __array_priority__ = 100

  # Overwritten immediately below: these stubs satisfy the abstract
  # interface and are replaced by the `_set_array_base_attributes` call
  # that follows this class.
  @property
  def at(self) -> _IndexUpdateHelper: assert False  # type: ignore[override]
  @property
  def T(self) -> PRNGKeyArray: assert False
  def __getitem__(self, _) -> PRNGKeyArray: assert False
  def flatten(self, *_, **__) -> PRNGKeyArray: assert False
  def ravel(self, *_, **__) -> PRNGKeyArray: assert False
  def reshape(self, *_, **__) -> PRNGKeyArray: assert False
  def squeeze(self, *_, **__) -> PRNGKeyArray: assert False
  def swapaxes(self, *_, **__) -> PRNGKeyArray: assert False
  def take(self, *_, **__) -> PRNGKeyArray: assert False
  def transpose(self, *_, **__) -> PRNGKeyArray: assert False
# Install the shared array operator dunders and the listed shape-manipulation
# members on PRNGKeyArrayImpl, replacing the placeholder stubs in its body.
_set_array_base_attributes(PRNGKeyArrayImpl, include=[
    *(f"__{op}__" for op in _array_operators),
    'at', 'flatten', 'ravel', 'reshape',
    'squeeze', 'swapaxes', 'take', 'transpose', 'T'])
# Register PRNGKeyArrayImpl with basearray.Array (ABC-style registration;
# presumably so isinstance checks against basearray.Array accept key arrays).
basearray.Array.register(PRNGKeyArrayImpl)
# Abstractify concrete key arrays through their `aval` property.
api_util._shaped_abstractify_handlers[PRNGKeyArrayImpl] = op.attrgetter('aval')
def prngkeyarrayimpl_flatten(x):
  """Pytree flatten: the base array is the only child; the impl is aux data."""
  children = (x._base_array,)
  aux_data = x._impl
  return children, aux_data
def prngkeyarrayimpl_unflatten(impl, children):
  """Pytree unflatten: rebuild a key array from ``impl`` and its one child."""
  (base_array,) = children
  return PRNGKeyArrayImpl(impl, base_array)

# Make PRNGKeyArrayImpl a pytree node using prngkeyarrayimpl_flatten and
# prngkeyarrayimpl_unflatten.
tree_util_internal.dispatch_registry.register_node(
    PRNGKeyArrayImpl,
    prngkeyarrayimpl_flatten,
    prngkeyarrayimpl_unflatten,
)
# TODO(frostig): remove, rerouting callers directly to random_seed
def seed_with_impl(impl: PRNGImpl, seed: int | typing.ArrayLike) -> PRNGKeyArrayImpl:
  """Legacy entry point: seed a key array for ``impl`` via ``random_seed``."""
  keys = random_seed(seeds=seed, impl=impl)
  return keys
def keys_shaped_array(impl, shape):
  """Return the abstract value of a key array of ``shape`` for ``impl``."""
  key_dtype = KeyTy(impl)
  return core.ShapedArray(shape, key_dtype)
def base_arr_shape_to_keys_shape(impl, base_arr_shape):
  """Drop the trailing per-key dims from a base (uint32) array shape.

  Args:
    impl: a PRNG implementation; only ``impl.key_shape`` is read.
    base_arr_shape: shape of the raw key data, expected to end in
      ``impl.key_shape``.

  Returns:
    The leading dims of ``base_arr_shape``, i.e. the shape of the key array.
  """
  base_ndim = len(impl.key_shape)
  # Use an explicit stop index: `shape[:-0]` is empty, which would wrongly
  # discard every dim when key_shape is `()`. Clamping at 0 keeps the old
  # result (an empty shape) when the shape is shorter than key_shape.
  stop = max(len(base_arr_shape) - base_ndim, 0)
  return base_arr_shape[:stop]
def make_key_array_phys_sharding(aval, sharding, is_sharding_from_xla):
  """Translate a key array's sharding into one for its physical base array.

  The physical (uint32) array has ``len(key_shape)`` extra trailing dims.
  When this function rewrites the sharding, those dims are left unsharded.

  Args:
    aval: abstract value of the key array (its dtype is a ``KeyTy``).
    sharding: the sharding over the key array's dims.
    is_sharding_from_xla: if true, a sharding that is not single-device,
      ``PmapSharding`` or ``NamedSharding`` is returned unchanged.
  """
  if dispatch.is_single_device_sharding(sharding):
    return sharding

  if isinstance(sharding, PmapSharding):
    num_key_dims = len(aval.dtype._impl.key_shape)
    spec = sharding.sharding_spec
    phys_spec = sharding_specs.ShardingSpec(
        sharding=(*spec.sharding,
                  *[sharding_specs.NoSharding()] * num_key_dims),
        mesh_mapping=spec.mesh_mapping)
    return PmapSharding(devices=sharding.devices, sharding_spec=phys_spec)

  if isinstance(sharding, NamedSharding):
    num_key_dims = len(aval.dtype._impl.key_shape)
    return NamedSharding(
        sharding.mesh, PartitionSpec(*sharding.spec, *(None,) * num_key_dims))

  if is_sharding_from_xla:
    return sharding

  logical_hlo = sharding._to_xla_hlo_sharding(aval.ndim)
  return GSPMDSharding(
      sharding._device_assignment,
      KeyTyRules.physical_hlo_sharding(aval, logical_hlo))
class KeyTyRules:
  """Extended-dtype rules describing how ``KeyTy`` arrays map to physical data.

  A key array of shape ``S`` is represented physically by a uint32 array of
  shape ``(*S, *key_shape)``. These static methods translate fills,
  constants, shardings, result handlers and device transfers between the
  logical (key) view and that physical view.
  """

  @staticmethod
  def full(shape, fill_value, dtype):
    """Return a key array of ``shape`` with every element ``fill_value``.

    If ``fill_value`` has a key dtype, its key data is broadcast to the
    physical shape. Otherwise ``fill_value`` is used as uint32 key data.
    """
    physical_shape = (*shape, *dtype._impl.key_shape)
    if hasattr(fill_value, 'dtype') and jnp.issubdtype(fill_value.dtype, dtypes.prng_key):
      key_data = jnp.broadcast_to(random_unwrap(fill_value), physical_shape)
    else:
      key_data = lax.full(physical_shape, fill_value, dtype=np.dtype('uint32'))
    # TODO(frostig,mattjj,vanderplas,lenamartens): consider this consumed from
    # the outset.
    return random_wrap(key_data, impl=dtype._impl)

  @staticmethod
  def physical_element_aval(dtype) -> core.ShapedArray:
    """Aval of one key's physical data: uint32 with shape ``key_shape``."""
    return core.ShapedArray(dtype._impl.key_shape, jnp.dtype('uint32'))

  @staticmethod
  def physical_const(val) -> Array:
    """Physical constant for a concrete key array: its base array."""
    return val._base_array

  @staticmethod
  def physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding) -> xc.HloSharding:
    """Extend a logical HLO sharding with unsharded trailing key dims.

    If the sharding replicates on its last tile dim, that replication count
    is kept last, after the appended key dims.
    """
    key_shape = aval.dtype._impl.key_shape
    op_sharding_proto = hlo_sharding.to_proto()  # type: ignore
    new_op_sharding = op_sharding_proto.clone()
    tad = list(new_op_sharding.tile_assignment_dimensions)
    suffix = [tad.pop()] if op_sharding_proto.replicate_on_last_tile_dim else []
    tad.extend([1] * len(key_shape) + suffix)
    new_op_sharding.tile_assignment_dimensions = tad
    return xc.HloSharding.from_proto(new_op_sharding)

  @staticmethod
  def logical_op_sharding(aval, phys_sharding) -> XLACompatibleSharding:
    """Translate a base array's sharding back to one over the key dims.

    This is the inverse of the physical translation: the trailing
    ``len(key_shape)`` dims are dropped from the sharding.
    """
    if dispatch.is_single_device_sharding(phys_sharding):
      return phys_sharding
    elif isinstance(phys_sharding, PmapSharding):
      key_shape = aval.dtype._impl.key_shape
      logical_sharding_spec = sharding_specs.ShardingSpec(
          sharding=phys_sharding.sharding_spec.sharding[:-len(key_shape)],
          mesh_mapping=phys_sharding.sharding_spec.mesh_mapping)
      return PmapSharding(devices=phys_sharding.devices,
                          sharding_spec=logical_sharding_spec)
    elif isinstance(phys_sharding, NamedSharding):
      key_shape = aval.dtype._impl.key_shape
      return pxla.create_mesh_pspec_sharding(
          phys_sharding.mesh,
          PartitionSpec(*phys_sharding.spec[:-len(key_shape)]))
    else:
      key_shape = aval.dtype._impl.key_shape
      phys_op_sharding = phys_sharding._to_xla_hlo_sharding(
          aval.ndim + len(key_shape)).to_proto()
      logical_op_sharding = phys_op_sharding.clone()
      tad = list(logical_op_sharding.tile_assignment_dimensions)
      # Mirror `physical_hlo_sharding`: when replicated on the last tile dim,
      # the replication count sits *after* the key dims. Set it aside, drop
      # the trailing key dims, then re-append it. Without this, the
      # replication count would be truncated and a key dim kept instead.
      suffix = [tad.pop()] if phys_op_sharding.replicate_on_last_tile_dim else []
      tad = tad[:-len(key_shape)] + suffix
      logical_op_sharding.tile_assignment_dimensions = tad
      return GSPMDSharding(phys_sharding._device_assignment,
                           xc.HloSharding.from_proto(logical_op_sharding))

  @staticmethod
  def result_handler(sticky_device, aval):
    """Handler that wraps a single physical result buffer as a key array."""
    def handler(_, buf):
      # NOTE(review): resets buf.aval to a plain ShapedArray of the buffer's
      # own shape/dtype before wrapping; the reason is not visible here.
      buf.aval = core.ShapedArray(buf.shape, buf.dtype)
      return PRNGKeyArrayImpl(aval.dtype._impl, buf)
    return handler

  @staticmethod
  def local_sharded_result_handler(aval, sharding, indices):
    """Result handler for locally sharded outputs.

    Only ``PmapSharding`` and ``NamedSharding`` are supported. The physical
    handler gets a physical sharding and per-shard indices extended with
    full slices over the key dims. Its result is re-wrapped as a key array.
    """
    phys_aval = core.physical_aval(aval)
    key_shape = aval.dtype._impl.key_shape
    phys_handler_maker = pxla.local_result_handlers[core.ShapedArray]

    # set up a grounded sharding (with a grounded sharding spec)
    if isinstance(sharding, (PmapSharding, NamedSharding)):
      phys_sharding = make_key_array_phys_sharding(
          aval, sharding, is_sharding_from_xla=False)
    else:
      assert False, f'impossible sharding {sharding} in local sharded result handler'

    # set up grounded indices
    trailing_inds = [slice(None)] * len(key_shape)
    phys_indices = [(*inds, *trailing_inds) for inds in indices]

    # make a physical handler
    phys_handler = phys_handler_maker(phys_aval, phys_sharding, phys_indices)

    # set up a handler that calls the physical one and wraps back up
    def handler(bufs):
      return PRNGKeyArrayImpl(aval.dtype._impl, phys_handler(bufs))

    return handler

  @staticmethod
  def global_sharded_result_handler(aval, out_sharding, committed,
                                    is_out_sharding_from_xla):
    """Result handler for globally sharded outputs, wrapping the physical one."""
    phys_aval = core.physical_aval(aval)
    phys_handler_maker = pxla.global_result_handlers[core.ShapedArray]

    phys_sharding = make_key_array_phys_sharding(
        aval, out_sharding, is_out_sharding_from_xla)
    phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed,
                                      is_out_sharding_from_xla)
    def handler(bufs):
      return PRNGKeyArrayImpl(aval.dtype._impl, phys_handler(bufs))
    return handler

  @staticmethod
  def make_sharded_array(aval, sharding, arrays, committed):
    """Assemble per-shard key arrays into one sharded key array."""
    phys_aval = core.physical_aval(aval)
    phys_handler_maker = pxla.global_result_handlers[core.ShapedArray]
    phys_arrays = [random_unwrap(arr) for arr in arrays]

    phys_sharding = make_key_array_phys_sharding(aval, sharding, False)
    phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed, False)
    phys_result = phys_handler(phys_arrays)
    return PRNGKeyArrayImpl(aval.dtype._impl, phys_result)

  @staticmethod
  def device_put_sharded(vals, aval, sharding, devices):
    """Put one key value per device, producing a sharded key array."""
    physical_aval = core.physical_aval(aval)
    physical_buffers = tree_util.tree_map(random_unwrap, vals)
    physical_sharding = make_key_array_phys_sharding(aval, sharding, False)
    physical_result = pxla.batched_device_put(physical_aval, physical_sharding, physical_buffers, list(devices))
    return random_wrap(physical_result, impl=aval.dtype._impl)

  @staticmethod
  def device_put_replicated(val, aval, sharding, devices):
    """Put the same key value on every device in ``devices``."""
    physical_aval = core.physical_aval(aval)
    assert len(xla.aval_to_xla_shapes(physical_aval)) == 1
    physical_buf = random_unwrap(val)
    physical_sharding = make_key_array_phys_sharding(aval, sharding, False)
    physical_result = pxla.batched_device_put(physical_aval, physical_sharding, [physical_buf] * len(devices), devices)
    return random_wrap(physical_result, impl=aval.dtype._impl)

  @staticmethod
  def tangent_dtype(_):
    """Tangents of keys use the float0 dtype."""
    return dtypes.float0

  # TODO(mattjj,frostig): even though the key dtype shouldn't appear in
  # tangents, our ad.replace_float0s in custom_jvp/vjp means passing in zeros
  # like the primal to user rules
  @staticmethod
  def zero(_):
    """A scalar float0 zero."""
    return np.zeros((), dtypes.float0)

  @staticmethod
  def convert_from(key_dtype, other_dtype) -> bool:
    """Always False: no conversion from a key dtype is reported as allowed."""
    return False

  @staticmethod
  def convert_to(other_dtype, key_dtype) -> bool:
    """Always False: no conversion to a key dtype is reported as allowed."""
    return False
class KeyTy(dtypes.ExtendedDType):
  """Extended dtype whose elements are PRNG keys of one implementation.

  Two ``KeyTy`` instances are equal exactly when both are ``KeyTy`` (not a
  subclass) and wrap equal implementations.
  """
  _impl: PRNGImpl  # TODO(mattjj,frostig): protocol really
  _rules = KeyTyRules
  type = dtypes.prng_key

  def __init__(self, impl):
    self._impl = impl

  @property
  def name(self) -> str:
    impl_tag = self._impl.tag
    return f'key<{impl_tag}>'

  @property
  def itemsize(self) -> int:
    # Bytes per key: number of uint32 words per key times the uint32 size.
    words_per_key = math.prod(self._impl.key_shape)
    return words_per_key * np.dtype('uint32').itemsize

  def __repr__(self) -> str:
    return self.name

  def __eq__(self, other):
    if type(other) is not KeyTy:
      return False
    return self._impl == other._impl

  def __hash__(self) -> int:
    return hash((self.__class__, self._impl))
# Map concrete key arrays to their `aval` (a ShapedArray with KeyTy dtype)
# for both core and xla, and leave them unchanged under dtype
# canonicalization.
core.pytype_aval_mappings[PRNGKeyArrayImpl] = lambda x: x.aval
xla.pytype_aval_mappings[PRNGKeyArrayImpl] = lambda x: x.aval
xla.canonicalize_dtype_handlers[PRNGKeyArrayImpl] = lambda x: x
def key_array_shard_arg_handler(x: PRNGKeyArrayImpl, sharding):
  """Shard a key-array argument by sharding its physical base array."""
  base = x._base_array
  base_sharding = make_key_array_phys_sharding(
      x.aval, sharding, is_sharding_from_xla=False)
  shard_base = pxla.shard_arg_handlers[type(base)]
  return shard_base(base, base_sharding)

pxla.shard_arg_handlers[PRNGKeyArrayImpl] = key_array_shard_arg_handler
def key_array_constant_handler(x):
  """Lower a constant key array by lowering its physical base array."""
  base = x._base_array
  lower_base = mlir.get_constant_handler(type(base))
  return lower_base(base)

mlir.register_constant_handler(PRNGKeyArrayImpl, key_array_constant_handler)
# -- primitives
def iterated_vmap_unary(n, f):
  """Wrap ``f`` in ``n`` nested ``api.vmap`` calls (``f`` itself if n <= 0).

  Each layer maps over the leading axis, so the result applies ``f`` to
  every element indexed by the first ``n`` axes of its input.
  """
  return reduce(lambda wrapped, _: api.vmap(wrapped), range(n), f)
# TODO(frostig): Revise the following two functions? These basically
# undo the singleton dimensions added by `batching.defbroadcasting`.
# It works, but introduces some possibly-redundant squeezes. Can we
# borrow from other broadcasting primitives instead?
def squeeze_vmap(f, left):
  """Wrap binary ``f`` to map over one operand while broadcasting the other.

  The operand selected by ``left`` (``x`` if truthy, else ``y``) carries a
  leading size-1 axis. That axis is squeezed away and the operand is left
  unmapped, while the other operand is mapped along axis 0.
  """
  def squeeze_vmap_f(x, y):
    if not left:
      y = jnp.squeeze(y, axis=0)
      return api.vmap(f, in_axes=(0, None), out_axes=0)(x, y)
    x = jnp.squeeze(x, axis=0)
    return api.vmap(f, in_axes=(None, 0), out_axes=0)(x, y)
  return squeeze_vmap_f
def iterated_vmap_binary_bcast(shape1, shape2, f):
  """Lift binary ``f`` over operands of shapes ``shape1`` and ``shape2``.

  If one operand is a scalar, it is closed over and the other is mapped
  over all of its axes. Otherwise both shapes must have equal rank. Each
  axis is then mapped on both sides (equal sizes) or broadcast from a
  size-1 axis via ``squeeze_vmap``.
  """
  rank1, rank2 = len(shape1), len(shape2)
  if rank1 == 0 and rank2 == 0:
    return f
  if rank1 == 0:
    return lambda x, y: iterated_vmap_unary(rank2, lambda y: f(x, y))(y)
  if rank2 == 0:
    return lambda x, y: iterated_vmap_unary(rank1, lambda x: f(x, y))(x)
  assert len(shape1) == len(shape2)
  # Wrap from the last axis outward, so the final (outermost) wrapper maps
  # axis 0 of the full operands.
  for size1, size2 in reversed(zip(shape1, shape2)):
    if size1 != size2:
      assert size1 == 1 or size2 == 1, (size1, size2)
      f = squeeze_vmap(f, size1 == 1)
    else:
      f = api.vmap(f, out_axes=0)
  return f
def random_seed(seeds: int | typing.ArrayLike, impl: PRNGImpl) -> PRNGKeyArrayImpl:
  """Create keys for ``impl`` from integer seed(s).

  The key array has the same shape as ``seeds``. A nonzero
  ``config.random_seed_offset`` is added to the seeds first.
  """
  # Python ints go through int64 first to avoid an overflow error in X32
  # mode. This breaks JIT invariance for large ints, but supports the
  # common case of seeding with Python hashes in X32 mode.
  seeds_arr = jnp.asarray(np.int64(seeds) if isinstance(seeds, int) else seeds)
  offset = config.random_seed_offset.value
  if offset:
    seeds_arr += offset
  return random_seed_p.bind(seeds_arr, impl=impl)
# Primitive mapping an integer seed array to a key array of the same shape.
random_seed_p = core.Primitive('random_seed')
# Register a zero JVP for seeding.
ad.defjvp_zero(random_seed_p)
# Generic vectorized batching; the impl/lowering already vmap over all
# leading dims of the seeds.
batching.defvectorized(random_seed_p)
@random_seed_p.def_abstract_eval
def random_seed_abstract_eval(seeds_aval, *, impl):
  """Seeds of shape ``S`` produce a key array of shape ``S``."""
  out_shape = seeds_aval.shape
  return keys_shaped_array(impl, out_shape)
@random_seed_p.def_impl
def random_seed_impl(seeds, *, impl):
  """Eager rule: compute the raw key data and wrap it as a key array."""
  return PRNGKeyArrayImpl(impl, random_seed_impl_base(seeds, impl=impl))
def random_seed_impl_base(seeds, *, impl):
  """Apply ``impl.seed`` to every scalar seed, returning raw key data."""
  batched_seed = iterated_vmap_unary(np.ndim(seeds), impl.seed)
  key_data = batched_seed(seeds)
  return key_data
def random_seed_lowering(ctx, seeds, *, impl):
  """Lower ``random_seed_p`` by lowering the vmapped ``impl.seed``."""
  (seeds_aval,) = ctx.avals_in
  vmapped_seed = iterated_vmap_unary(seeds_aval.ndim, impl.seed)
  lowered = mlir.lower_fun(vmapped_seed, multiple_results=False)
  # Outputs are described to the delegate by their physical avals.
  phys_avals_out = map(core.physical_aval, ctx.avals_out)
  return mlir.delegate_lowering(ctx, lowered, seeds, avals_out=phys_avals_out)

mlir.register_lowering(random_seed_p, random_seed_lowering)
def random_split(keys, shape: Shape):
  """Split every key into new keys of ``shape`` (result: keys.shape + shape)."""
  split_keys = random_split_p.bind(keys, shape=shape)
  return split_keys
# Primitive splitting each key into an array of new keys of a given shape.
random_split_p = core.Primitive('random_split')
# Register a zero JVP for splitting.
ad.defjvp_zero(random_split_p)
# Generic vectorized batching; impl/lowering vmap over all key dims.
batching.defvectorized(random_split_p)
@random_split_p.def_abstract_eval
def random_split_abstract_eval(keys_aval, *, shape):
  """Splitting appends ``shape`` to the key shape; the impl is unchanged."""
  impl = keys_aval.dtype._impl
  out_shape = (*keys_aval.shape, *shape)
  return keys_shaped_array(impl, out_shape)
@random_split_p.def_impl
def random_split_impl(keys, *, shape):
  """Eager rule: split the raw key data, then re-wrap it as keys."""
  impl = keys._impl
  split_data = random_split_impl_base(
      impl, keys._base_array, keys.ndim, shape=shape)
  return PRNGKeyArrayImpl(impl, split_data)
def random_split_impl_base(impl, base_arr, keys_ndim, *, shape):
  """Apply ``impl.split(key, shape)`` to every key in the raw key data."""
  def split_one(key):
    return impl.split(key, shape)
  return iterated_vmap_unary(keys_ndim, split_one)(base_arr)
def random_split_lowering(ctx, keys, *, shape):
  """Lower ``random_split_p`` by lowering the vmapped ``impl.split``."""
  (keys_aval,) = ctx.avals_in
  impl = keys_aval.dtype._impl

  def split_one(key):
    return impl.split(key, shape)

  lowered = mlir.lower_fun(
      iterated_vmap_unary(keys_aval.ndim, split_one), multiple_results=False)
  # Inputs and outputs are described to the delegate by their physical avals.
  return mlir.delegate_lowering(
      ctx, lowered, keys,
      avals_in=[core.physical_aval(keys_aval)],
      avals_out=map(core.physical_aval, ctx.avals_out))

mlir.register_lowering(random_split_p, random_split_lowering)
def random_fold_in(keys, msgs):
  """Fold message(s) ``msgs`` into ``keys``, broadcasting their shapes."""
  msgs_arr = jnp.asarray(msgs)
  return random_fold_in_p.bind(keys, msgs_arr)
# Binary primitive folding messages into keys.
random_fold_in_p = core.Primitive('random_fold_in')
# Register a zero JVP for fold_in.
ad.defjvp_zero(random_fold_in_p)
# Broadcasting batching rule; the singleton dims it introduces are undone by
# squeeze_vmap / iterated_vmap_binary_bcast in the impl and lowering.
batching.defbroadcasting(random_fold_in_p)
@random_fold_in_p.def_abstract_eval
def random_fold_in_abstract_eval(keys_aval, msgs_aval):
  """Keep the key dtype; the shape is the broadcast of both operand shapes."""
  out_shape = lax_internal.broadcasting_shape_rule(
      'random_fold_in', keys_aval, msgs_aval)
  out_named_shape = lax_utils.standard_named_shape_rule(keys_aval, msgs_aval)
  return core.ShapedArray(out_shape, keys_aval.dtype,
                          named_shape=out_named_shape)
@random_fold_in_p.def_impl
def random_fold_in_impl(keys, msgs):
  """Eager rule: fold ``msgs`` into the raw key data and re-wrap it."""
  impl = keys._impl
  folded = random_fold_in_impl_base(impl, keys._base_array, msgs, keys.shape)
  return PRNGKeyArrayImpl(impl, folded)
def random_fold_in_impl_base(impl, base_arr, msgs, keys_shape):
  """Apply ``impl.fold_in`` across the broadcast key and message shapes."""
  msgs_shape = np.shape(msgs)
  batched_fold_in = iterated_vmap_binary_bcast(
      keys_shape, msgs_shape, impl.fold_in)
  return batched_fold_in(base_arr, msgs)
def random_fold_in_lowering(ctx, keys, msgs):
  """Lower ``random_fold_in_p`` via the broadcast-vmapped ``impl.fold_in``."""
  keys_aval, msgs_aval = ctx.avals_in
  fold_in = iterated_vmap_binary_bcast(
      keys_aval.shape, msgs_aval.shape, keys_aval.dtype._impl.fold_in)
  lowered = mlir.lower_fun(fold_in, multiple_results=False)
  # Keys are described by their physical aval; msgs_aval is passed through.
  phys_keys_aval = core.physical_aval(keys_aval)
  return mlir.delegate_lowering(
      ctx, lowered, keys, msgs,
      avals_in=[phys_keys_aval, msgs_aval],
      avals_out=map(core.physical_aval, ctx.avals_out))

mlir.register_lowering(random_fold_in_p, random_fold_in_lowering)
def random_bits(keys, bit_width, shape):
  """Draw unsigned ``bit_width``-bit random values of ``shape`` for each key.

  ``shape`` may include named axes. For each named axis, the keys are first
  folded with that axis' index. Only the positional part of ``shape`` is
  passed on to the primitive.

  Raises:
    ValueError: if a named axis' declared size differs from its actual size
      (as measured by ``lax.psum(1, name)``).
  """
  named = core.as_named_shape(shape)
  for axis_name, declared_size in named.named_items:
    # TODO(frostig,mattjj,apaszke): Is this real_size check necessary,
    # and is it meant to raise a user-facing ValueError? Should it be
    # an `assert` (or RuntimeError) instead? Why do we check it in
    # calls to `random_bits` instead of a more common parallelism path?
    actual_size = lax.psum(1, axis_name)
    if actual_size != declared_size:
      raise ValueError(f"The shape of axis {axis_name} was specified as {declared_size}, "
                       f"but it really is {actual_size}")
    keys = random_fold_in(keys, lax.axis_index(axis_name))
  return random_bits_p.bind(keys, bit_width=bit_width, shape=named.positional)
# Primitive drawing raw unsigned random bits from each key.
random_bits_p = core.Primitive('random_bits')
# Register a zero JVP for random bit generation.
ad.defjvp_zero(random_bits_p)
# Generic vectorized batching; impl/lowering vmap over all key dims.
batching.defvectorized(random_bits_p)
@random_bits_p.def_abstract_eval
def random_bits_abstract_eval(keys_aval, *, bit_width, shape):
  """Output dtype is ``uint{bit_width}`` with shape ``keys.shape + shape``."""
  return core.ShapedArray((*keys_aval.shape, *shape),
                          dtypes.dtype(f'uint{bit_width}'))
@random_bits_p.def_impl
def random_bits_impl(keys, *, bit_width, shape):
  """Eager rule: draw bits directly from the raw key data."""
  impl, key_data = keys._impl, keys._base_array
  return random_bits_impl_base(impl, key_data, keys.ndim,
                               bit_width=bit_width, shape=shape)
def random_bits_impl_base(impl, base_arr, keys_ndim, *, bit_width, shape):
  """Apply ``impl.random_bits(key, bit_width, shape)`` to every key."""
  def bits_for_key(key):
    return impl.random_bits(key, bit_width, shape)
  batched = iterated_vmap_unary(keys_ndim, bits_for_key)
  return batched(base_arr)
def random_bits_lowering(ctx, keys, *, bit_width, shape):
  """MLIR lowering rule for `random_bits_p`.

  Vmaps the key impl's `random_bits` over the key array's dimensions, lowers
  it with `mlir.lower_fun`, and invokes it with the key operand's aval
  replaced by its physical aval.

  This uses `mlir.delegate_lowering` in the same way as
  `random_fold_in_lowering`, rather than hand-inlining the equivalent
  replace-ctx / call / forward-tokens sequence.
  """
  aval, = ctx.avals_in
  impl = aval.dtype._impl
  bits = iterated_vmap_unary(
      aval.ndim, lambda k: impl.random_bits(k, bit_width, shape))
  bits_lowering = mlir.lower_fun(bits, multiple_results=False)
  return mlir.delegate_lowering(
      ctx, bits_lowering, keys,
      avals_in=[core.physical_aval(aval)])
# Register the MLIR lowering rule for the random-bits primitive.
mlir.register_lowering(random_bits_p, random_bits_lowering)
# The following wrap/unwrap primitives are at least a stopgap for
# backwards compatibility, namely when `config.jax_enable_custom_prng`
# is False. We need to convert key arrays to and from their underlying
# uint32 base arrays, and we may need to do so under a jit. For
# example, we want to support:
#
# keys = jax.jit(random.split)(key)
#
# where `key` and `keys` are both acceptably old-style uint32 arrays
# so long as enable_custom_prng is False. The way we handle this is
# that `random.split` adapts the input/output by converting to/from
# key arrays across its call to `random_split`. So we rely on these
# wrap/unwrap casting primitives to allow that conversion under jit.
#
# We may want to keep both around for testing and debugging escape
# hatches. We can rename them `unsafe` for emphasis, and/or issue a
# warning on entry to the traceable.
#
# TODO(frostig): Consider removal once we always enable_custom_prng.
def random_wrap(base_arr, *, impl):
  """Validate `base_arr` as key data for `impl`, then wrap it as a key array."""
  _check_prng_key_data(impl, base_arr)
  wrapped = random_wrap_p.bind(base_arr, impl=impl)
  return wrapped
# Primitive that casts a base array to a key array; its JVP rule is
# registered as zero.
random_wrap_p = core.Primitive('random_wrap')
ad.defjvp_zero(random_wrap_p)
@random_wrap_p.def_abstract_eval
def random_wrap_abstract_eval(base_arr_aval, *, impl):
  """Abstract eval: the key-array aval whose shape is derived from the base array's."""
  return keys_shaped_array(
      impl, base_arr_shape_to_keys_shape(impl, base_arr_aval.shape))
@random_wrap_p.def_impl
def random_wrap_impl(base_arr, *, impl):
  """Eager impl: construct a concrete key array around `base_arr`."""
  key_array = PRNGKeyArrayImpl(impl, base_arr)
  return key_array
def random_wrap_lowering(ctx, base_arr, *, impl):
  """Lowering is the identity: the lowered key array is its base array."""
  del ctx, impl  # unused
  return [base_arr]
def random_wrap_batch_rule(batched_args, batch_dims, *, impl):
  """Batching rule: move the batch dim to the front, then wrap.

  Returns the wrapped key array together with output batch dimension 0.
  """
  (base_arr,) = batched_args
  (bdim,) = batch_dims
  # NOTE(review): the trailing `1` is passed as `bdim_at_front`'s size
  # argument; presumably used only when the operand is unbatched — confirm.
  front_batched = batching.bdim_at_front(base_arr, bdim, 1)
  return random_wrap(front_batched, impl=impl), 0
# Register the lowering and batching rules for the wrap primitive.
mlir.register_lowering(random_wrap_p, random_wrap_lowering)
batching.primitive_batchers[random_wrap_p] = random_wrap_batch_rule
def random_unwrap(keys):
  """Return the base array underlying the key array `keys`.

  Raises:
    TypeError: if `keys` does not have a PRNG key dtype.
  """
  is_key_array = jnp.issubdtype(keys.dtype, dtypes.prng_key)
  if is_key_array:
    return random_unwrap_p.bind(keys)
  raise TypeError(f'random_unwrap takes key array operand, got {keys.dtype=}')
# Primitive that casts a key array back to its base array. Its JVP rule is
# registered as zero, and its batching rule via `batching.defvectorized`.
random_unwrap_p = core.Primitive('random_unwrap')
ad.defjvp_zero(random_unwrap_p)
batching.defvectorized(random_unwrap_p)
@random_unwrap_p.def_abstract_eval
def random_unwrap_abstract_eval(keys_aval):
  """Abstract eval: the physical (base array) aval of the key aval."""
  physical = core.physical_aval(keys_aval)
  return physical
@random_unwrap_p.def_impl
def random_unwrap_impl(keys):
  """Eager impl: expose the concrete key array's stored base array."""
  base = keys._base_array
  return base
def random_unwrap_lowering(ctx, keys):
  """Lowering is the identity: the lowered key array already is its base array."""
  del ctx  # unused
  return [keys]
# Register the MLIR lowering rule for the unwrap primitive.
mlir.register_lowering(random_unwrap_p, random_unwrap_lowering)
# -- threefry2x32 PRNG implementation
def _is_threefry_prng_key(key: typing.Array) -> bool:
try:
return key.shape == (2,) and key.dtype == np.uint32
except AttributeError:
return False
def threefry_seed(seed: typing.Array) -> typing.Array:
  """Build one raw threefry PRNG key from an integer seed.

  Args:
    seed: a scalar 64- or 32-bit integer that becomes the key's value.

  Returns:
    The raw key contents: a uint32 array of shape (2,). A 64-bit seed is
    effectively bit-cast into two uint32 words; a 32-bit seed is first
    zero-padded.
  """
  raw_key = _threefry_seed(seed)
  return raw_key
@partial(jit, inline=True)
def _threefry_seed(seed: typing.Array) -> typing.Array:
  """Traced implementation of `threefry_seed` (jitted with `inline=True`).

  Args:
    seed: a scalar integer array.

  Returns:
    A uint32 array of shape (2,): `seed` logically shifted right by 32 bits,
    followed by `seed` masked to its low 32 bits.

  Raises:
    TypeError: if `seed` is not a scalar, or does not have an integer dtype.
  """
  if seed.shape:
    raise TypeError(f"PRNG key seed must be a scalar; got {seed!r}.")
  if not np.issubdtype(seed.dtype, np.integer):
    raise TypeError(f"PRNG key seed must be an integer; got {seed!r}")
  # Cast to uint32 and add a leading length-1 axis so the two words can be
  # concatenated into a length-2 vector below.
  convert = lambda k: lax.expand_dims(lax.convert_element_type(k, np.uint32), [0])
  # High word. NOTE(review): for a 32-bit seed this shift amount equals the
  # bit width; presumably it yields 0, matching the "padding out with zeros"
  # described in `threefry_seed` — confirm against XLA shift semantics.
  k1 = convert(
      lax.shift_right_logical(seed, lax_internal._const(seed, 32)))
  # Low word. Presumably 'standard' promotion is needed so that combining
  # `seed`'s dtype with the uint32 mask is permitted — confirm.
  with config.numpy_dtype_promotion('standard'):
    # TODO(jakevdp): in X64 mode, this can generate 64-bit computations for 32-bit
    # inputs. We should avoid this.
    k2 = convert(jnp.bitwise_and(seed, np.uint32(0xFFFFFFFF)))
  return lax.concatenate([k1, k2], 0)
def _make_rotate_left(dtype):
if not jnp.issubdtype(dtype, np.integer):
raise TypeError("_rotate_left only accepts integer dtypes.")
nbits = np.array(jnp.iinfo(dtype).bits, dtype)
def _rotate_left(x, d):
if lax.dtype(d) != dtype:
d = lax.convert_element_type(d, dtype)
if lax.dtype(x) != dtype:
x = lax.convert_element_type(x, dtype)
return lax.shift_left(x, d) | lax.shift_right_logical(x, nbits - d)
return _rotate_left
### hash function and split
def _threefry2x32_abstract_eval(*args):
if any(a.dtype != jnp.uint32 for a in args):
raise TypeError("Arguments to threefry2x32 must have uint32 type, got {}"
.format(args))
if all(isinstance(arg, core.ShapedArray) for arg in args):
shape = lax_internal.broadcasting_shape_rule(*args)
named_shape = core.join_named_shapes(*(a.named_shape for a in args))
aval = core.ShapedArray(shape, jnp.dtype(jnp.uint32), named_shape=named_shape)
else:
aval = core.UnshapedArray(jnp.dtype(jnp.uint32))
return (aval,) * 2
# Rotate-left specialized to uint32, built once at module import time.
rotate_left = _make_rotate_left(np.uint32)