From 289610eb025e6d8fc11867997e74eed1f22f578c Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Wed, 25 May 2022 12:02:35 -0700 Subject: [PATCH] Add a public facing `named_scope` function to allow adding to the name stack. --- jax/__init__.py | 1 + jax/_src/api.py | 57 +++++++++++++++- tests/name_stack_test.py | 136 +++++++++++++++++++-------------------- 3 files changed, 122 insertions(+), 72 deletions(-) diff --git a/jax/__init__.py b/jax/__init__.py index 0adfa57a7266..aed59906ec47 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -96,6 +96,7 @@ make_jaxpr as make_jaxpr, mask as mask, named_call as named_call, + named_scope as named_scope, pmap as pmap, process_count as process_count, process_index as process_index, diff --git a/jax/_src/api.py b/jax/_src/api.py index 4b63d513099f..968a60e0b28f 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -29,9 +29,9 @@ import threading import weakref import types -from typing import (Any, Callable, Iterable, NamedTuple, Mapping, Optional, - Sequence, Tuple, TypeVar, Union, overload, Dict, Hashable, - List) +from typing import (Any, Callable, Generator, Iterable, NamedTuple, Mapping, + Optional, Sequence, Tuple, TypeVar, Union, overload, Dict, + Hashable, List) from typing_extensions import Literal from warnings import warn @@ -3178,6 +3178,57 @@ def named_call_f(*args, **kwargs): return named_call_f +@contextmanager +def named_scope( + name: str, + ) -> Generator[None, None, None]: + """A context manager that adds a user specified name to the JAX name stack. + + When staging out computations for just-in-time compilation to XLA (or other + backends such as TensorFlow) JAX does not, by default, preserve the names + (or other source metadata) of Python functions it encounters. + This can make debugging the staged out (and/or compiled) representation of + your program complicated because there is limited context information for each + operation being executed. + + ``named_scope`` tells JAX to stage the given function with additional + annotations on the underlying operations. JAX internally keeps track of these + annotations in a name stack. When the staged out program is compiled with XLA + these annotations are preserved and show up in debugging utilities like the + TensorFlow Profiler in TensorBoard. Names are also preserved when staging out + JAX programs to TensorFlow using :func:`experimental.jax2tf.convert`. + + + Args: + name: The prefix to use to name all operations created within the name + scope. + Yields: + Yields ``None``, but enters a context in which `name` will be appended to + the active name stack. + + Examples: + ``named_scope`` can be used as a context manager inside compiled functions: + + >>> import jax + >>> + >>> @jax.jit + ... def layer(w, x): + ... with jax.named_scope("dot_product"): + ... logits = w.dot(x) + ... with jax.named_scope("activation"): + ... return jax.nn.relu(logits) + + It can also be used as a decorator: + + >>> @jax.jit + ... @jax.named_scope("layer") + ... def layer(w, x): + ... logits = w.dot(x) + ... return jax.nn.relu(logits) + """ + with source_info_util.extend_name_stack(name): + yield + def effects_barrier(): """Waits until existing functions have completed any side-effects.""" dispatch.runtime_tokens.block_until_ready() diff --git a/tests/name_stack_test.py b/tests/name_stack_test.py index 4fa080c39c65..8da7228469f2 100644 --- a/tests/name_stack_test.py +++ b/tests/name_stack_test.py @@ -21,11 +21,9 @@ from jax import linear_util as lu from jax.config import config from jax._src import test_util as jtu -from jax._src import source_info_util from jax._src.lib import xla_client config.parse_flags_with_absl() -extend_name_stack = source_info_util.extend_name_stack def _get_hlo(f): def wrapped(*args, **kwargs): @@ -66,7 +64,7 @@ def f(x): def test_manual_name_stack(self): - @extend_name_stack('foo') + @jax.named_scope('foo') def f(x): return x + 1 jaxpr = jax.make_jaxpr(f)(2).jaxpr @@ -75,9 +73,9 @@ def f(x): def test_nested_name_stack(self): - @extend_name_stack('foo') + @jax.named_scope('foo') def f(x): - with extend_name_stack('bar'): + with jax.named_scope('bar'): return x + 1 jaxpr = jax.make_jaxpr(f)(2).jaxpr for eqn in jaxpr.eqns: @@ -86,20 +84,20 @@ def f(x): def test_multiple_name_stack(self): def f(x): - with extend_name_stack('foo'): + with jax.named_scope('foo'): y = x + 1 - with extend_name_stack('bar'): - with extend_name_stack('baz'): + with jax.named_scope('bar'): + with jax.named_scope('baz'): return y + 1 jaxpr = jax.make_jaxpr(f)(2).jaxpr self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'foo') self.assertEqual(str(jaxpr.eqns[1].source_info.name_stack), 'bar/baz') def test_call_primitive_jaxpr_should_not_store_outer_name_stack(self): - @extend_name_stack('foo') + @jax.named_scope('foo') def f(x): @lu.wrap_init - @extend_name_stack('bar') + @jax.named_scope('bar') def _f(x): return [x + 1] return core.call(_f, x)[0] @@ -111,10 +109,10 @@ def _f(x): self.assertIn('foo/jit(core_call)/bar', hlo_text) def test_xla_call_primitive_jaxpr_should_not_store_outer_name_stack(self): - @extend_name_stack('foo') + @jax.named_scope('foo') def f(x): @jax.jit - @extend_name_stack('bar') + @jax.named_scope('bar') def _f(x): return x + 1 return _f(x) @@ -127,10 +125,10 @@ def _f(x): self.assertIn('foo/jit(_f)/bar', hlo_text) def test_pmap_call_primitive_jaxpr_should_not_store_outer_name_stack(self): - @extend_name_stack('foo') + @jax.named_scope('foo') @jax.pmap def f(x): - with extend_name_stack('bar'): + with jax.named_scope('bar'): return x + 1 jaxpr = jax.make_jaxpr(f)(jnp.ones(1)).jaxpr self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'foo') @@ -142,27 +140,27 @@ class NameStackTransformationTest(_EnableNameStackTestCase): def test_vmap_should_transform_name_stack(self): @jax.vmap def f(x): - with extend_name_stack('foo'): + with jax.named_scope('foo'): return x + 1 jaxpr = jax.make_jaxpr(f)(jnp.ones(2)).jaxpr self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'vmap(foo)') def test_vmap_should_transform_inner_name_stacks(self): - @extend_name_stack('foo') + @jax.named_scope('foo') @jax.vmap def f(x): - with extend_name_stack('bar'): - with extend_name_stack('baz'): + with jax.named_scope('bar'): + with jax.named_scope('baz'): return x + 1 jaxpr = jax.make_jaxpr(f)(jnp.ones(2)).jaxpr self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'foo/vmap(bar)/vmap(baz)') def test_vmap_should_apply_to_call_jaxpr(self): - @extend_name_stack('foo') + @jax.named_scope('foo') @jax.vmap def f(x): @jax.jit - @extend_name_stack('bar') + @jax.named_scope('bar') def _f(x): return x + 1 return _f(x) @@ -176,10 +174,10 @@ def _f(x): def test_jvp_should_transform_stacks(self): def f(x): - with extend_name_stack('bar'): - with extend_name_stack('baz'): + with jax.named_scope('bar'): + with jax.named_scope('baz'): return jnp.square(x) - g = extend_name_stack('foo')(lambda x, t: jax.jvp(f, (x,), (t,))) + g = jax.named_scope('foo')(lambda x, t: jax.jvp(f, (x,), (t,))) jaxpr = jax.make_jaxpr(g)(1., 1.).jaxpr self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'foo/jvp(bar)/jvp(baz)') @@ -187,10 +185,10 @@ def f(x): def test_jvp_should_apply_to_call_jaxpr(self): @jax.jit def f(x): - with extend_name_stack('bar'): - with extend_name_stack('baz'): + with jax.named_scope('bar'): + with jax.named_scope('baz'): return jnp.square(x) - g = extend_name_stack('foo')(lambda x, t: jax.jvp(f, (x,), (t,))) + g = jax.named_scope('foo')(lambda x, t: jax.jvp(f, (x,), (t,))) jaxpr = jax.make_jaxpr(g)(1., 1.).jaxpr self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'foo') self.assertEqual( @@ -203,7 +201,7 @@ def f(x): def test_grad_should_add_jvp_and_transpose_to_name_stack(self): @jax.value_and_grad def f(x): - with extend_name_stack('foo'): + with jax.named_scope('foo'): return 2 * jnp.sin(x) jaxpr = jax.make_jaxpr(f)(1.).jaxpr self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'jvp(foo)') @@ -219,10 +217,10 @@ def f(x): def test_grad_should_add_jvp_and_transpose_to_call_jaxpr(self): @jax.grad - @extend_name_stack('foo') + @jax.named_scope('foo') @jax.jit def f(x): - with extend_name_stack('bar'): + with jax.named_scope('bar'): return jnp.sin(x) jaxpr = jax.make_jaxpr(f)(1.).jaxpr self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'jvp(foo)') @@ -246,12 +244,12 @@ class NameStackControlFlowTest(_EnableNameStackTestCase): def test_while_loop_body_should_not_have_name_stack(self): - @extend_name_stack('foo') + @jax.named_scope('foo') def f(x): - @extend_name_stack('bar') + @jax.named_scope('bar') def body(x): return x + 1 - @extend_name_stack('bar_cond') + @jax.named_scope('bar_cond') def cond(x): return x < 5 return lax.while_loop(cond, body, x) @@ -271,12 +269,12 @@ def cond(x): def test_vmap_of_while_loop_should_transform_name_stack(self): @jax.vmap - @extend_name_stack('foo') + @jax.named_scope('foo') def f(x): - @extend_name_stack('bar') + @jax.named_scope('bar') def body(x): return x + 1 - @extend_name_stack('bar_cond') + @jax.named_scope('bar_cond') def cond(x): return x < 5 return lax.while_loop(cond, body, x) @@ -295,12 +293,12 @@ def cond(x): def test_jvp_of_while_loop_transforms_name_stack(self): - @extend_name_stack('foo') + @jax.named_scope('foo') def f(x): - @extend_name_stack('bar') + @jax.named_scope('bar') def body(x): return x + 1. - @extend_name_stack('bar_cond') + @jax.named_scope('bar_cond') def cond(x): return x < 5. return lax.while_loop(cond, body, x) @@ -320,12 +318,12 @@ def cond(x): def test_vmap_of_jvp_of_while_loop_transforms_name_stack(self): - @extend_name_stack('foo') + @jax.named_scope('foo') def f(x): - @extend_name_stack('bar') + @jax.named_scope('bar') def body(x): return x + 1. - @extend_name_stack('bar_cond') + @jax.named_scope('bar_cond') def cond(x): return x < 5. return lax.while_loop(cond, body, x) @@ -349,12 +347,12 @@ def cond(x): def test_cond_body_should_not_have_name_stack(self): - @extend_name_stack('foo') + @jax.named_scope('foo') def f(x, y): - @extend_name_stack('true') + @jax.named_scope('true') def true_fn(x): return x + 1 - @extend_name_stack('false') + @jax.named_scope('false') def false_fn(x): return x - 1 return lax.cond(y, true_fn, false_fn, x) @@ -375,13 +373,13 @@ def false_fn(x): def test_vmap_of_cond_should_transform_name_stack(self): - @extend_name_stack('foo') + @jax.named_scope('foo') @functools.partial(jax.vmap, in_axes=(0, None)) def f(x, y): - @extend_name_stack('true') + @jax.named_scope('true') def true_fn(x): return x + 1 - @extend_name_stack('false') + @jax.named_scope('false') def false_fn(x): return x - 1 return lax.cond(y, true_fn, false_fn, x) @@ -402,12 +400,12 @@ def false_fn(x): def test_jvp_of_cond_transforms_name_stack(self): - @extend_name_stack('foo') + @jax.named_scope('foo') def f(x, y): - @extend_name_stack('true') + @jax.named_scope('true') def true_fn(x): return x + 1 - @extend_name_stack('false') + @jax.named_scope('false') def false_fn(x): return x - 1 return lax.cond(y, true_fn, false_fn, x) @@ -429,12 +427,12 @@ def false_fn(x): def test_vmap_of_jvp_of_cond_transforms_name_stack(self): - @extend_name_stack('foo') + @jax.named_scope('foo') def f(x, y): - @extend_name_stack('true') + @jax.named_scope('true') def true_fn(x): return x + 1 - @extend_name_stack('false') + @jax.named_scope('false') def false_fn(x): return x - 1 return lax.cond(y, true_fn, false_fn, x) @@ -460,12 +458,12 @@ def false_fn(x): def test_grad_of_cond_transforms_name_stack(self): @jax.grad - @extend_name_stack('foo') + @jax.named_scope('foo') def f(x, y): - @extend_name_stack('true') + @jax.named_scope('true') def true_fn(x): return x * x * 2. - @extend_name_stack('false') + @jax.named_scope('false') def false_fn(x): return x / jnp.square(x) return lax.cond(y, true_fn, false_fn, x) @@ -492,12 +490,12 @@ def test_vmap_of_grad_of_cond_transforms_name_stack(self): @functools.partial(jax.vmap, in_axes=(0, None)) @jax.grad - @extend_name_stack('foo') + @jax.named_scope('foo') def f(x, y): - @extend_name_stack('true') + @jax.named_scope('true') def true_fn(x): return x * x * 2. - @extend_name_stack('false') + @jax.named_scope('false') def false_fn(x): return x / x / 2. return lax.cond(y, true_fn, false_fn, x) @@ -522,9 +520,9 @@ def false_fn(x): def test_scan_body_should_not_have_name_stack(self): - @extend_name_stack('foo') + @jax.named_scope('foo') def f(x): - @extend_name_stack('scan_body') + @jax.named_scope('scan_body') def body(carry, x): return carry + x, carry + x return lax.scan(body, x, jnp.arange(5.)) @@ -540,9 +538,9 @@ def body(carry, x): def test_vmap_of_scan_should_transform_stack(self): @jax.vmap - @extend_name_stack('foo') + @jax.named_scope('foo') def f(x): - @extend_name_stack('scan_body') + @jax.named_scope('scan_body') def body(carry, x): return carry + x, carry + x return lax.scan(body, x, jnp.arange(8.)) @@ -557,9 +555,9 @@ def body(carry, x): def test_jvp_of_scan_should_transform_stack(self): - @extend_name_stack('foo') + @jax.named_scope('foo') def f(x): - @extend_name_stack('scan_body') + @jax.named_scope('scan_body') def body(carry, x): return carry + x, carry + x return lax.scan(body, x, jnp.arange(8.)) @@ -576,9 +574,9 @@ def body(carry, x): def test_grad_of_scan_should_transform_stack(self): @jax.value_and_grad - @extend_name_stack('foo') + @jax.named_scope('foo') def f(x): - @extend_name_stack('scan_body') + @jax.named_scope('scan_body') def body(carry, x): return 2 * carry * x, carry + x return lax.scan(body, x, jnp.arange(8.))[0] @@ -598,9 +596,9 @@ def test_vmap_of_grad_of_scan_should_transform_stack(self): @jax.vmap @jax.value_and_grad - @extend_name_stack('foo') + @jax.named_scope('foo') def f(x): - @extend_name_stack('scan_body') + @jax.named_scope('scan_body') def body(carry, x): return carry * x, carry + x return lax.scan(body, x, jnp.arange(8.))[0]