Skip to content

Commit

Permalink
Add a public facing named_scope function to allow adding to the nam…
Browse files Browse the repository at this point in the history
…e stack.
  • Loading branch information
sharadmv committed Jun 9, 2022
1 parent 3e699dd commit 289610e
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 72 deletions.
1 change: 1 addition & 0 deletions jax/__init__.py
Expand Up @@ -96,6 +96,7 @@
make_jaxpr as make_jaxpr,
mask as mask,
named_call as named_call,
named_scope as named_scope,
pmap as pmap,
process_count as process_count,
process_index as process_index,
Expand Down
57 changes: 54 additions & 3 deletions jax/_src/api.py
Expand Up @@ -29,9 +29,9 @@
import threading
import weakref
import types
from typing import (Any, Callable, Iterable, NamedTuple, Mapping, Optional,
Sequence, Tuple, TypeVar, Union, overload, Dict, Hashable,
List)
from typing import (Any, Callable, Generator, Iterable, NamedTuple, Mapping,
Optional, Sequence, Tuple, TypeVar, Union, overload, Dict,
Hashable, List)
from typing_extensions import Literal
from warnings import warn

Expand Down Expand Up @@ -3178,6 +3178,57 @@ def named_call_f(*args, **kwargs):

return named_call_f

@contextmanager
def named_scope(
name: str,
) -> Generator[None, None, None]:
"""A context manager that adds a user specified name to the JAX name stack.
When staging out computations for just-in-time compilation to XLA (or other
backends such as TensorFlow) JAX does not, by default, preserve the names
(or other source metadata) of Python functions it encounters.
This can make debugging the staged out (and/or compiled) representation of
your program complicated because there is limited context information for each
operation being executed.
``named_scope`` tells JAX to stage the given function with additional
annotations on the underlying operations. JAX internally keeps track of these
annotations in a name stack. When the staged out program is compiled with XLA
these annotations are preserved and show up in debugging utilities like the
TensorFlow Profiler in TensorBoard. Names are also preserved when staging out
JAX programs to TensorFlow using :func:`experimental.jax2tf.convert`.
Args:
name: The prefix to use to name all operations created within the name
scope.
Yields:
Yields ``None``, but enters a context in which `name` will be appended to
the active name stack.
Examples:
``named_scope`` can be used as a context manager inside compiled functions:
>>> import jax
>>>
>>> @jax.jit
... def layer(w, x):
... with jax.named_scope("dot_product"):
... logits = w.dot(x)
... with jax.named_scope("activation"):
... return jax.nn.relu(logits)
It can also be used as a decorator:
>>> @jax.jit
... @jax.named_scope("layer")
... def layer(w, x):
... logits = w.dot(x)
... return jax.nn.relu(logits)
"""
with source_info_util.extend_name_stack(name):
yield

def effects_barrier():
"""Waits until existing functions have completed any side-effects."""
dispatch.runtime_tokens.block_until_ready()
Expand Down

0 comments on commit 289610e

Please sign in to comment.