forked from google/jax
-
Notifications
You must be signed in to change notification settings - Fork 2
/
metadata_test.py
125 lines (103 loc) · 4.16 KB
/
metadata_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import unittest
from absl.testing import absltest
from jax._src import test_util as jtu
import jax
from jax._src import config as jax_config
from jax._src.lib.mlir import ir
from jax import numpy as jnp
from jax import config
# Runs at import time, before any test executes. Presumably this hooks JAX's
# configuration options into absl flag parsing so they can be set from the
# test's command line — confirm against `jax.config` documentation.
config.parse_flags_with_absl()
def module_to_string(module: ir.Module) -> str:
  """Return the textual form of an MLIR module, with debug info included.

  The module's top-level operation is printed in its custom (non-generic)
  op form with `enable_debug_info=True`, so the text carries the location
  metadata that callers inspect.

  Args:
    module: The MLIR module to render.

  Returns:
    The printed module as a string.
  """
  with io.StringIO() as buf:
    module.operation.print(
        file=buf,
        print_generic_op_form=False,
        enable_debug_info=True,
    )
    text = buf.getvalue()  # must be read before the buffer is closed
  return text
class MetadataTest(jtu.JaxTestCase):
  """Checks the op-name / source-location metadata JAX attaches to lowered IR.

  The tests lower small functions with `jax.jit(...).lower(...)`, print the
  resulting module with debug info via `module_to_string`, and check the
  printed text with regular expressions.
  """

  def test_jit_metadata(self):
    # Op locations are prefixed with the jitted function's name followed by
    # the inner `jit(main)` scope.
    hlo = module_to_string(jax.jit(jnp.sin).lower(1.).compiler_ir())
    self.assertRegex(hlo, r'loc\("jit\(sin\)/jit\(main\)/sin"')

    def foo(x):
      return jnp.sin(x)

    hlo = module_to_string(jax.jit(foo).lower(1.).compiler_ir())
    self.assertRegex(hlo, r'loc\("jit\(foo\)/jit\(main\)/sin"')

  @unittest.skip("TODO")  # TODO(jekbradbury)
  def test_nested_jit_metadata(self):
    # NOTE(review): this test reads `self.op_types` / `self.op_names`, which
    # this class never defines; it would fail with AttributeError if it were
    # re-enabled as written. It also expects an `xla_call` op name.
    @jax.jit
    def foo(x):
      return jnp.sin(x)

    def bar(x):
      return jnp.cos(foo(x))

    _ = bar(1.)
    assert self.op_types[-2] == 'sin'
    assert self.op_names[-2] == 'jit(foo)/sin'
    assert self.op_types[-1] == 'cos'
    assert self.op_names[-1] == 'cos'
    _ = jax.jit(bar)(1.)
    assert self.op_types[-3] == 'xla_call'
    assert self.op_names[-3] == 'jit(bar)/xla_call[ backend=None\n' \
        ' device=None\n' \
        ' name=foo ]'
    assert self.op_types[-2] == 'sin'
    assert self.op_names[-2] == 'jit(bar)/jit(foo)/sin'
    assert self.op_types[-1] == 'cos'
    assert self.op_names[-1] == 'jit(bar)/cos'

  def test_grad_jit_metadata(self):
    @jax.jit
    def foo(x):
      return jnp.sin(x)

    # Differentiation wraps the name scope: the forward derivative (cos) sits
    # under `jvp(jit(foo))`, and the backward multiply under
    # `transpose(jvp(jit(foo)))`.
    hlo = module_to_string(jax.jit(jax.grad(foo)).lower(1.).compiler_ir())
    self.assertRegex(hlo, r'loc\(".*jvp\(jit\(foo\)\)/cos"')
    self.assertRegex(hlo, r'loc\(".*transpose\(jvp\(jit\(foo\)\)\)/mul"')

  def test_cond_metadata(self):
    def true_fun(x):
      return jnp.sin(x)

    def false_fun(x):
      return jnp.cos(x)

    def f(which, x):
      # Per-branch operand form: cond(pred, true_operand, true_fun,
      # false_operand, false_fun). Two operands are passed, matching the
      # two-entry `linear` tuple asserted below.
      return jax.lax.cond(which, x, true_fun, x, false_fun)

    hlo = module_to_string(jax.jit(f).lower(True, 1.).compiler_ir())
    self.assertRegex(hlo, r'loc\(".*cond\[linear=\(False, False\)\]"')
    # Branch 0 is the false branch (cos) and branch 1 the true branch (sin).
    self.assertRegex(hlo, r'loc\(".*cond/branch_0_fun/cos"')
    self.assertRegex(hlo, r'loc\(".*cond/branch_1_fun/sin"')

  def test_argmax(self):
    def f(x):
      return jnp.argmax(x)

    # Guards against Python object reprs such as `<function f at 0x...>`
    # leaking into the printed metadata.
    hlo = module_to_string(jax.jit(f).lower(jnp.arange(8.0)).compiler_ir())
    self.assertNotRegex(hlo, r'<.* at 0x[0-9a-fA-F]+>')

  def test_source_file_prefix_removal(self):
    def make_hlo():
      return module_to_string(
          jax.jit(jnp.sin).lower(jnp.arange(8.0)).compiler_ir()
      )

    # The `.` before `py` is escaped in every pattern below so that it matches
    # a literal dot rather than any character.

    # Sanity check: without canonicalization the `tests/` path is present.
    self.assertRegex(make_hlo(), r"[/\\]+tests[/\\]+metadata_test\.py")

    # Strip everything up to and including the `tests/` directory.
    with jax_config.hlo_source_file_canonicalization_regex(r".*[\\/]+tests[/\\]+"):
      hlo = make_hlo()
      self.assertIn("metadata_test.py", hlo)
      self.assertNotRegex(hlo, r"tests[/\\]+")
      self.assertNotRegex(hlo, r"[/\\]+metadata_test\.py")

    # A pattern that matches nothing leaves the path untouched.
    with jax_config.hlo_source_file_canonicalization_regex("no_match_xxx"):
      hlo = make_hlo()
      self.assertRegex(hlo, r"[/\\]+tests[/\\]+metadata_test\.py")

    # A pattern matching the whole path removes the file name entirely.
    with jax_config.hlo_source_file_canonicalization_regex(".*"):
      hlo = make_hlo()
      self.assertNotIn("test.py", hlo)

    # Both occurrences of "test" in ".../tests/metadata_test.py" are removed,
    # leaving ".../s/metadata_.py".
    with jax_config.hlo_source_file_canonicalization_regex("test"):
      hlo = make_hlo()
      self.assertRegex(hlo, r"[/\\]+s[/\\]+metadata_\.py")
if __name__ == "__main__":
  # Run the tests under absl's test runner, using JAX's test loader.
  absltest.main(testLoader=jtu.JaxTestLoader())