-
Notifications
You must be signed in to change notification settings - Fork 0
/
random.py
117 lines (86 loc) · 2.89 KB
/
random.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
r"""Extended utilities for random number generation.

Provides :class:`PRNG`, a stateful wrapper around :mod:`jax.random` that splits
keys automatically, together with :func:`set_rng` / :func:`get_rng` to bind a
PRNG to a context.
"""

# Public names exported by `from <module> import *`.
__all__ = ['PRNG', 'set_rng', 'get_rng']
import numbers
from contextlib import contextmanager
from typing import *

import jax
from jax.random import KeyArray

from .debug import same_trace
from .tree_util import Namespace
class PRNG(Namespace):
    r"""Creates a pseudo-random number generator (PRNG).

    This class is a thin wrapper around the :mod:`jax.random` module, and allows to
    generate new PRNG keys or sample from distributions without having to split keys
    with :func:`jax.random.split` by hand.

    Any public attribute that is not found on the instance is looked up in
    :mod:`jax.random`. Callable attributes are wrapped so that a freshly split key
    is passed as their first argument, e.g. :py:`rng.normal((5,))` is equivalent to
    :py:`jax.random.normal(rng.split(), (5,))`.

    Arguments:
        seed: An integer seed (Python or NumPy integer) or PRNG key.
        kwargs: Keyword arguments passed to :func:`jax.random.PRNGKey`. They are
            only used when :py:`seed` is an integer.

    Example:
        >>> rng = PRNG(42)
        >>> rng.split()  # generates a key
        Array([2465931498, 3679230171], dtype=uint32)
        >>> rng.split(3)  # generates a vector of 3 keys
        Array([[ 956272045, 3465119146],
               [1903583750,  988321301],
               [3226638877, 2833683589]], dtype=uint32)
        >>> rng.normal((5,))
        Array([ 0.5694761 , -1.4582146 ,  0.2309113 , -0.03029377,  0.11095619], dtype=float32)
    """

    def __init__(self, seed: Union[int, KeyArray], **kwargs):
        # `numbers.Integral` also matches NumPy integer scalars (e.g. `np.int64(42)`),
        # which are not `int` subclasses and would otherwise be mistaken for a key.
        if isinstance(seed, numbers.Integral):
            self.state = jax.random.PRNGKey(seed, **kwargs)
        else:
            self.state = seed

    def __getattr__(self, name: str) -> Any:
        r"""Forwards unknown attributes to :mod:`jax.random`.

        Callables are wrapped to receive a new key (from :meth:`split`) as their
        first positional argument; other attributes are returned unchanged.
        """
        # Only reached when normal lookup fails. Dunder names are protocol probes
        # (copy, pickle, inspect, ...) and must not resolve to attributes of the
        # `jax.random` module object itself.
        if name.startswith('__') and name.endswith('__'):
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")

        attr = getattr(jax.random, name)

        if callable(attr):
            return lambda *args, **kwargs: attr(self.split(), *args, **kwargs)

        return attr

    def split(self, num: Optional[int] = None) -> KeyArray:
        r"""Splits the internal state into new key(s) and a new internal state.

        Arguments:
            num: The number of keys to generate.

        Returns:
            A new key if :py:`num=None` and a vector of keys otherwise.
        """
        # One extra key is always drawn; it becomes the next internal state.
        keys = jax.random.split(self.state, num=2 if num is None else num + 1)

        assert same_trace(self.state, keys), "the PRNG was initialized and used within different JIT traces."

        if num is None:
            key, self.state = keys
        else:
            key, self.state = keys[:-1], keys[-1]

        return key
# PRNG bound by the innermost active `set_rng` context; None when no PRNG is set.
INOX_RNG: Optional[PRNG] = None
@contextmanager
def set_rng(rng: PRNG):
    r"""Sets the PRNG within a context.

    The previously bound PRNG (possibly :py:`None`) is restored when the context
    exits, including on exceptions, so contexts can be nested.

    See also:
        :class:`PRNG` and :func:`get_rng`

    Arguments:
        rng: A PRNG instance.

    Yields:
        The PRNG instance :py:`rng`, so it can also be bound with :py:`as`.

    Example:
        >>> with set_rng(PRNG(0)):
        ...     a = get_rng().split()
        ...     b = get_rng().normal((2, 3))
    """
    # NOTE: the binding is a module-level global, so it is shared by every
    # thread/task using this module.
    global INOX_RNG

    old, INOX_RNG = INOX_RNG, rng

    try:
        yield rng
    finally:
        INOX_RNG = old
def get_rng() -> PRNG:
    r"""Returns the context-bound PRNG.

    See also:
        :class:`PRNG` and :func:`set_rng`

    Raises:
        AssertionError: If no PRNG is set in this context, e.g. when called
            outside of any :func:`set_rng` context.
    """
    # An explicit raise (instead of `assert`) keeps the check active under
    # `python -O`; the AssertionError type is kept for backward compatibility.
    if INOX_RNG is None:
        raise AssertionError("no PRNG is set in this context.")

    return INOX_RNG