-
Notifications
You must be signed in to change notification settings - Fork 46
/
test_rng.py
89 lines (71 loc) · 2.81 KB
/
test_rng.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
""" Test random number generator.
Notes:
- Compilation is slow (about 1 minute).
- There were several other compilation issues related to dtype handling,
  because dtype is currently passed as a string. Myia may need better
  support for strings.
- Exceptions are not correctly handled by the @myia decorator.
  Many "illegal" errors are raised, including illegal primitives,
  illegal registered types, and others. I can fix those, but ultimately
  we will still need proper string support (so that exception messages
  can be printed correctly in the backend).
"""
import numpy as np
from myia import myia
from myia.rng_mrg import (
MyiaRandomState,
myia_increment_state,
myia_random_state,
myia_uniform,
)
def pyth_random():
    # Draw ten consecutive samples with ``myia_uniform`` from a fresh
    # 'float64' random state, feeding each returned state into the next
    # call; the final state is discarded.
    # NOTE: ``myia_random`` compiles this function with Myia, so the body is
    # kept as plain straight-line code (and documented with comments rather
    # than a docstring) to stay within the Python subset Myia parses.
    rng = myia_random_state('float64')
    sample_1, rng = myia_uniform(rng)
    sample_2, rng = myia_uniform(rng)
    sample_3, rng = myia_uniform(rng)
    sample_4, rng = myia_uniform(rng)
    sample_5, rng = myia_uniform(rng)
    sample_6, rng = myia_uniform(rng)
    sample_7, rng = myia_uniform(rng)
    sample_8, rng = myia_uniform(rng)
    sample_9, rng = myia_uniform(rng)
    sample_10, _ = myia_uniform(rng)
    return (sample_1, sample_2, sample_3, sample_4, sample_5,
            sample_6, sample_7, sample_8, sample_9, sample_10)
@myia
def myia_random():
    # Myia-compiled entry point that runs ``pyth_random``; its result is
    # compared against the plain-Python call in test_myia_random_generation.
    return pyth_random()
@myia
def run_myia_random_state():
    # Create a new random state for dtype 'float32' from Myia-compiled code.
    return myia_random_state('float32')
@myia
def run_myia_increment_state(s):
    # Advance random state ``s`` by one step from Myia-compiled code and
    # return the new state.
    return myia_increment_state(s)
def test_myia_random_state():
    """A fresh state is a MyiaRandomState holding six np.int32(12345)."""
    fresh = run_myia_random_state()
    assert isinstance(fresh, MyiaRandomState)
    internal = fresh.rstate
    assert isinstance(internal, tuple)
    initial_rstate = (np.int32(12345),) * 6
    assert internal == initial_rstate
def test_myia_increment_state():
    """Three successive increments reproduce known ``rstate`` values."""
    # Expected internal state after the 1st, 2nd and 3rd increment.
    expected_rstates = (
        (336690377, 597094797, 1245771585,
         85196284, 523477687, 2094976052),
        (502033783, 1322587635, 1964121530,
         1949818481, 1607232546, 1462898381),
        (739421137, 1475938232, 730262207,
         1630192198, 324551134, 795289868),
    )
    # Produce every state first, then check them, so the compiled functions
    # are invoked in the same order regardless of assertion outcomes.
    chain = [run_myia_random_state()]
    for _ in range(len(expected_rstates)):
        chain.append(run_myia_increment_state(chain[-1]))
    for produced, raw in zip(chain[1:], expected_rstates):
        assert produced.rstate == tuple(np.int32(val) for val in raw)
def test_myia_random_generation():
    """Python and Myia-compiled generation yield the same known samples."""
    expected = tuple(np.float64(val) for val in (
        0.7353244530968368, 0.6142074400559068, 0.11007806099951267,
        0.6487741703167558, 0.36619443260133266, 0.10882294131442904,
        0.5330547927878797, 0.9783797566778958, 0.9151237849146128,
        0.8509745532646775))
    # Hide runtime warnings about overflow in integer operations. Using
    # ``np.errstate`` as a context manager restores numpy's previous error
    # settings even if generation raises; the former manual
    # ``np.seterr``/restore pair leaked the 'ignore' setting into later
    # tests whenever anything before the restore failed.
    with np.errstate(all='ignore'):
        pyth_res = pyth_random()
        myia_res = myia_random()
    assert pyth_res == myia_res
    assert pyth_res == expected