Skip to content

Commit

Permalink
Remove use of jax_experimental_name_stack flag in preparation for i…
Browse files Browse the repository at this point in the history
…ts deletion

PiperOrigin-RevId: 487581301
  • Loading branch information
sharadmv authored and Copybara-Service committed Nov 10, 2022
1 parent 32781b9 commit dbc0b1f
Show file tree
Hide file tree
Showing 4 changed files with 3 additions and 34 deletions.
3 changes: 1 addition & 2 deletions haiku/_src/jaxpr_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,7 @@ def make_module(*args, **kwargs):
seen=set(),
module=module)

if jax.config.jax_experimental_name_stack:
_name_scopes_to_modules(module)
_name_scopes_to_modules(module)

if include_module_info:
# Add haiku param and state counts for all haiku modules.
Expand Down
4 changes: 0 additions & 4 deletions haiku/_src/jaxpr_info_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,10 @@ class JaxprInfoTest(absltest.TestCase):
def setUp(self):
super().setUp()
self.prev_profiler_name_scopes = config.profiler_name_scopes(enabled=True)
self.prev_experimental_namestack = jax.config.jax_experimental_name_stack
jax.config.update("jax_experimental_name_stack", True)

def tearDown(self):
super().tearDown()
config.profiler_name_scopes(self.prev_profiler_name_scopes)
jax.config.update("jax_experimental_name_stack",
self.prev_experimental_namestack)

def test_simple_expression(self):

Expand Down
10 changes: 1 addition & 9 deletions haiku/_src/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from haiku._src import base
from haiku._src import config
from haiku._src import data_structures
from haiku._src import stateful
from haiku._src import utils
import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -418,18 +417,11 @@ def wrapped(self, *args, **kwargs):
module_name = getattr(self, "module_name", None)
f = functools.partial(unbound_method, self)
f = functools.partial(run_interceptors, f, method_name, self)
if jax.config.jax_experimental_name_stack and module_name:
if module_name:
local_module_name = module_name.split("/")[-1]
f = jax.named_call(f, name=local_module_name)
if method_name != "__call__":
f = jax.named_call(f, name=method_name)
elif module_name:
# TODO(lenamartens): remove this branch once jax_experimental_name_stack
# flag is removed.
cfg = config.get_config()
if cfg.profiler_name_scopes and method_name == "__call__":
local_module_name = module_name.split("/")[-1]
f = stateful.named_call(f, name=local_module_name)

out = f(*args, **kwargs)

Expand Down
20 changes: 1 addition & 19 deletions haiku/_src/stateful.py
Original file line number Diff line number Diff line change
Expand Up @@ -929,25 +929,7 @@ def named_call_hidden_outputs(*args, **kwargs):

@functools.wraps(fun)
def wrapper(*args, **kwargs):
if jax.config.jax_experimental_name_stack:
return jax.named_call(fun, name=name)(*args, **kwargs)

side_channel = {"non_jaxtypes": [], "treedef": None}
wrapped_fun = hide_non_jaxtype_outputs(fun, side_channel)
if base.inside_transform():
wrapped_fun = thread_hk_state_in_kwargs(jax.named_call)(wrapped_fun,
name=name)
else:
wrapped_fun = jax.named_call(wrapped_fun, name=name)

jax_types = wrapped_fun(*args, **kwargs)

non_jaxtypes = side_channel["non_jaxtypes"]
out_leaves = [y if x is None else x
for x, y in zip(jax_types, non_jaxtypes)]
out = jax.tree_util.tree_unflatten(side_channel["treedef"], out_leaves)

return out
return jax.named_call(fun, name=name)(*args, **kwargs)
return wrapper


Expand Down

0 comments on commit dbc0b1f

Please sign in to comment.