From 1b79caa6bd87fcb5f1c0058427522c6ec22f6469 Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Thu, 28 Oct 2021 11:06:58 -0700 Subject: [PATCH] Add separate mechanism for threading name stacks to the lowering --- jax/_src/ad_checkpoint.py | 2 +- jax/_src/api.py | 13 +- jax/_src/config.py | 6 +- jax/_src/custom_derivatives.py | 2 +- jax/_src/dispatch.py | 2 +- jax/_src/lax/control_flow.py | 26 +- jax/_src/lax/lax.py | 4 +- jax/_src/source_info_util.py | 114 +++++- jax/_src/util.py | 16 +- jax/core.py | 50 ++- jax/experimental/djax.py | 5 +- jax/experimental/jax2tf/jax2tf.py | 2 +- jax/experimental/loops.py | 6 +- jax/experimental/maps.py | 2 +- jax/experimental/pjit.py | 2 +- jax/interpreters/ad.py | 89 +++-- jax/interpreters/batching.py | 10 +- jax/interpreters/mlir.py | 11 +- jax/interpreters/partial_eval.py | 47 ++- jax/interpreters/pxla.py | 8 +- jax/interpreters/sharded_jit.py | 8 +- jax/interpreters/xla.py | 34 +- tests/api_test.py | 8 +- tests/name_stack_test.py | 612 ++++++++++++++++++++++++++++++ 24 files changed, 937 insertions(+), 142 deletions(-) create mode 100644 tests/name_stack_test.py diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index 5a58352da2e1..b3f19afeefd5 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -353,7 +353,7 @@ def transposed(*args): primal_fun = lu.wrap_init(partial(core.eval_jaxpr, jaxpr, ())) tangent_jaxpr, _, consts = pe.trace_to_jaxpr(primal_fun, in_pvals, False) dummy_args = [ad.UndefinedPrimal(v.aval) for v in tangent_jaxpr.invars] - in_cts_ = ad.backward_pass(tangent_jaxpr, reduce_axes, consts, dummy_args, + in_cts_ = ad.backward_pass(tangent_jaxpr, reduce_axes, False, consts, dummy_args, out_cts) in_cts, cell.treedef = tree_flatten(in_cts_) return in_cts diff --git a/jax/_src/api.py b/jax/_src/api.py index a582902c543f..2427ab37b12d 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -56,10 +56,11 @@ Partial, PyTreeDef, all_leaves) from jax._src.tree_util import broadcast_prefix from jax._src.util import (unzip2, curry, safe_map, safe_zip, prod, split_list, - extend_name_stack, wrap_name, cache, wraps, + extend_name_stack, new_name_stack, wrap_name, cache, wraps, HashableFunction) from jax._src import device_array from jax._src import dispatch +from jax._src import source_info_util from jax._src.lib import jax_jit from jax._src.lib import xla_bridge as xb from jax._src.lib import xla_client as xc @@ -895,9 +896,8 @@ def computation_maker(*args, **kwargs): should_tuple = tuple_args if tuple_args is not None else (len(avals) > 100) xla_args, donated_invars = xla._xla_callable_args( c, avals, should_tuple, partitions=in_parts_flat, donated_invars=donated_invars) - ctx = xla.TranslationContext( - c, backend, axis_env_, - extend_name_stack(wrap_name(fun_name, "xla_computation"))) + name_stack = new_name_stack(wrap_name(fun_name, "xla_computation")) + ctx = xla.TranslationContext(c, backend, axis_env_, name_stack) out_nodes = xla.jaxpr_subcomp(ctx, jaxpr, xla_consts, *xla_args) build_out_tuple = partial(xc.ops.Tuple, c, out_nodes) if out_parts is not None: @@ -2615,7 +2615,7 @@ def transposed_fun(consts, out_cotangent): dummies = [ad.UndefinedPrimal(a) for a in in_avals] in_cotangents = map( ad.instantiate_zeros, - ad.backward_pass(jaxpr, reduce_axes, consts, dummies, out_cotangents)) + ad.backward_pass(jaxpr, reduce_axes, True, consts, dummies, out_cotangents)) return tree_unflatten(in_tree, in_cotangents) # Ensure that transposed_fun is a PyTree @@ -3197,6 +3197,9 @@ def named_call( _, in_tree = tree_flatten(()) + if config.jax_experimental_name_stack: + return source_info_util.extend_name_stack(name)(fun) + @functools.wraps(fun) def named_call_f(*args, **kwargs): lu_f = lu.wrap_init(lambda: fun(*args, **kwargs)) diff --git a/jax/_src/config.py b/jax/_src/config.py index ecfeae9f24e1..e4a189b1abfb 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -140,7 +140,6 @@ def config_with_absl(self): for name, val in self.values.items(): flag_type, meta_args, meta_kwargs = self.meta[name] absl_defs[flag_type](name, val, *meta_args, **meta_kwargs) - app.call_after_init(lambda: self.complete_absl_config(absl_flags)) def complete_absl_config(self, absl_flags): @@ -688,6 +687,11 @@ def _update_disable_jit_thread_local(val): help=('Enables experimental features for staging out computations with ' 'dynamic shapes.')) +config.define_bool_state( + name='jax_experimental_name_stack', + default=False, + help='Enable using the context manager-based name stack.') + # This flag is temporary during rollout of the remat barrier. # TODO(parkers): Remove if there are no complaints. config.define_bool_state( diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index e60110e66e8a..18ed2abca077 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -407,7 +407,7 @@ def _custom_jvp_call_jaxpr_transpose(reduce_axes, cts, *args, fun_jaxpr, jvp_jaxpr_thunk, num_consts): del jvp_jaxpr_thunk, num_consts return ad.backward_pass( - fun_jaxpr.jaxpr, reduce_axes, fun_jaxpr.consts, args, cts) + fun_jaxpr.jaxpr, reduce_axes, False, fun_jaxpr.consts, args, cts) ad.reducing_transposes[custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr_transpose def custom_jvp_jaxpr_custom_partial_eval_rule( diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 7ec760c32cfa..2cb1892a72e9 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -250,7 +250,7 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name, # pass long arg lists as tuple for TPU tuple_args = len(abstract_args) > 100 axis_env = xla.AxisEnv(nreps, (), ()) - name_stack = xla.extend_name_stack(xla.wrap_name(name, 'jit')) + name_stack = xla.new_name_stack(xla.wrap_name(name, 'jit')) closed_jaxpr = core.ClosedJaxpr(jaxpr, consts) module: Union[str, xc.XlaComputation] module_name = f"jit_{fun.__name__}" diff --git a/jax/_src/lax/control_flow.py b/jax/_src/lax/control_flow.py index a0006f8cc6be..50c89e6042ed 100644 --- a/jax/_src/lax/control_flow.py +++ b/jax/_src/lax/control_flow.py @@ -347,8 +347,9 @@ def _while_loop_translation_rule(ctx, avals_in, avals_out, *args, cond_jaxpr, cond_carry = xla.parameter(cond_c, 0, c.get_shape(init_carry)) cond_carry_elts = [xops.GetTupleElement(cond_carry, i) for i in range(len(args))] x, _, z = split_list(cond_carry_elts, [cond_nconsts, body_nconsts]) + name_stack = extend_name_stack(ctx.name_stack, 'while') cond_ctx = ctx.replace(builder=cond_c, - name_stack=extend_name_stack(ctx.name_stack, 'cond')) + name_stack=extend_name_stack(name_stack, 'cond')) pred, = xla.jaxpr_subcomp( cond_ctx, cond_jaxpr.jaxpr, _map(partial(xla.pyval_to_ir_constant, cond_c), cond_jaxpr.consts), @@ -365,14 +366,14 @@ def _while_loop_translation_rule(ctx, avals_in, avals_out, *args, cond_jaxpr, body_carry_elts = [xops.GetTupleElement(body_carry, i) for i in range(len(args))] x, y, z = split_list(body_carry_elts, [cond_nconsts, body_nconsts]) body_ctx = ctx.replace(builder=body_c, - name_stack=extend_name_stack(ctx.name_stack, 'body')) + name_stack=extend_name_stack(name_stack, 'body')) new_z = xla.jaxpr_subcomp( body_ctx, body_jaxpr.jaxpr, _map(partial(xla.pyval_to_ir_constant, body_c), body_jaxpr.consts), *(y + z)) if batched: body_pred_ctx = body_ctx.replace( - name_stack=extend_name_stack(ctx.name_stack, 'body_pred')) + name_stack=extend_name_stack(name_stack, 'body_pred')) body_pred, = xla.jaxpr_subcomp( body_pred_ctx, cond_jaxpr.jaxpr, _map(partial(xla.pyval_to_ir_constant, body_c), cond_jaxpr.consts), @@ -1201,9 +1202,11 @@ def _cond_partial_eval(trace, *tracers, branches, linear): linear_2 = (False,) * num_res + linear params = dict(branches=branches_2, linear=linear_2) + name_stack = source_info_util.current_name_stack()[len(trace.name_stack):] + source = source_info_util.current().replace(name_stack=name_stack) eqn = pe.new_eqn_recipe( [index_tracer] + res_tracers + ops_tracers, out_tracers, cond_p, params, - source_info_util.current()) + source) for t in out_tracers: t.recipe = eqn return out_tracers @@ -1297,7 +1300,7 @@ def transposed(*args): res, cts_out = split_list(args, [num_res]) primals = res + [ad.UndefinedPrimal(aval) for aval in primal_avals] cts_in = ad.backward_pass( - jaxpr.jaxpr, reduce_axes, jaxpr.consts, primals, cts_out) + jaxpr.jaxpr, reduce_axes, False, jaxpr.consts, primals, cts_out) _, cts_in = split_list(cts_in, [num_res]) return _map(ad.instantiate_zeros_aval, primal_avals, cts_in) @@ -1924,9 +1927,10 @@ def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry, for uk, t in zip(unknowns[:num_consts], tracers[:num_consts])] other_pvals = [pe.PartialVal.unknown(a) for a in jaxpr_1.in_avals[num_consts:]] in_pvals_1 = invariant_pvals + other_pvals - jaxpr_1_opt, out_pvals_1, consts_1 = pe.trace_to_jaxpr( - lu.wrap_init(core.jaxpr_as_fun(jaxpr_1)), in_pvals_1, - instantiate=[True] * (num_carry + num_ys) + [False] * num_res) + with source_info_util.reset_name_stack(): + jaxpr_1_opt, out_pvals_1, consts_1 = pe.trace_to_jaxpr( + lu.wrap_init(core.jaxpr_as_fun(jaxpr_1)), in_pvals_1, + instantiate=[True] * (num_carry + num_ys) + [False] * num_res) jaxpr_1_opt = pe.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr_1_opt), ()) num_consts_1 = num_consts + len(consts_1) # any now-known residuals are intensive, so we want to revise jaxpr_2 to take @@ -1990,6 +1994,8 @@ def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry, ext_res_tracers = _map(trace.new_instantiated_const, extensive_residuals) out_tracers = [pe.JaxprTracer(trace, pe.PartialVal((pv, const)), None) for pv, const in zip(out_pvs, out_consts)] + name_stack = source_info_util.current_name_stack()[len(trace.name_stack):] + source = source_info_util.current().replace(name_stack=name_stack) linear_2 = ([False] * len(int_res_tracers) + [lin or not uk for uk, lin in zip(unknowns, linear)] + [False] * len(ext_res_tracers)) @@ -1999,7 +2005,7 @@ def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry, num_consts=num_consts_2, num_carry=num_carry, linear=tuple(linear_2), unroll=unroll), - source_info_util.current()) + source) for t in out_tracers: t.recipe = eqn return out_tracers @@ -2068,7 +2074,7 @@ def transposed(*res1_cbar_bbar_res2): res1_cbar_bbar_res2, [num_res1, num_c, num_b]) primals = (res1 + [ad.UndefinedPrimal(aval) for aval in c_avals] + [ad.UndefinedPrimal(aval) for aval in a_avals] + res2) - cbar_abar = ad.backward_pass(jaxpr.jaxpr, reduce_axes, jaxpr.consts, + cbar_abar = ad.backward_pass(jaxpr.jaxpr, reduce_axes, False, jaxpr.consts, primals, b_bar) _, new_c_bar, a_bar, _ = split_list(cbar_abar, [num_res1, num_c, num_a]) a_bar = _map(ad.instantiate_zeros_aval, a_avals, a_bar) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index c93eecf9594d..4ad857fee0a5 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -50,7 +50,7 @@ import jax._src.pretty_printer as pp from jax._src import util from jax._src.util import (cache, safe_zip, prod, safe_map, canonicalize_axis, - split_list) + split_list, new_name_stack) from jax.tree_util import tree_map import jax._src.lib from jax._src.lib import pytree @@ -3424,7 +3424,7 @@ def _reduction_computation(ctx, jaxpr, consts, init_values, singleton=True): subc = xc.XlaBuilder("reduction_computation") assert len(consts) == 0, "Reduction computations cannot have constants" args = [xla.parameter(subc, i, shape) for i, shape in enumerate(shapes)] - ctx = xla.TranslationContext(subc, platform, axis_env, '') + ctx = xla.TranslationContext(subc, platform, axis_env, new_name_stack()) out_nodes = xla.jaxpr_subcomp(ctx, jaxpr, consts, *args) if singleton: return subc.build(out_nodes[0]) diff --git a/jax/_src/source_info_util.py b/jax/_src/source_info_util.py index 48b4e9883593..b18e2af745d8 100644 --- a/jax/_src/source_info_util.py +++ b/jax/_src/source_info_util.py @@ -13,12 +13,13 @@ # limitations under the License. import contextlib +import dataclasses import functools import itertools import os.path import threading import types -from typing import Optional, Iterator, NamedTuple +from typing import Optional, Iterator, NamedTuple, Union, Tuple import jax.version from jax._src.lib import xla_client, xla_extension_version @@ -40,15 +41,66 @@ class Frame(NamedTuple): def register_exclusion(path): _exclude_paths.append(path) +class Scope(NamedTuple): + name: str + + def wrap(self, stack: Tuple[str, ...]) -> Tuple[str, ...]: + return (self.name, *stack) + +class Transform(NamedTuple): + name: str + + def wrap(self, stack: Tuple[str, ...]) -> Tuple[str, ...]: + return tuple(map(lambda x: f'{self.name}({x})', stack)) + +@dataclasses.dataclass(frozen=True) +class NameStack: + stack: Tuple[Union[Scope, Transform], ...] = () + + def extend(self, name: Union[Tuple[str, ...], str]) -> 'NameStack': + if not isinstance(name, tuple): + name = (name,) + scopes = tuple(map(Scope, name)) + return NameStack(self.stack + scopes) + + def wrap_name(self, name: str) -> str: + if not self.stack: + return name + return f'{str(self)}/{name}' + + def transform(self, transform_name: str) -> 'NameStack': + return NameStack((*self.stack, Transform(transform_name))) + + def __getitem__(self, idx) -> 'NameStack': + return NameStack(self.stack[idx]) + + def __len__(self): + return len(self.stack) + + def __add__(self, other: 'NameStack') -> 'NameStack': + return NameStack(self.stack + other.stack) + + def __radd__(self, other: 'NameStack') -> 'NameStack': + return NameStack(other.stack + self.stack) + + def __str__(self) -> str: + scope: Tuple[str, ...] = () + for elem in self.stack[::-1]: + scope = elem.wrap(scope) + return '/'.join(scope) + class SourceInfo(NamedTuple): traceback: Optional[Traceback] + name_stack: NameStack - def replace(self, *, traceback: Optional[Traceback] = None) -> 'SourceInfo': + def replace(self, *, traceback: Optional[Traceback] = None, + name_stack: Optional[NameStack] = None) -> 'SourceInfo': traceback = traceback or self.traceback - return self._replace(traceback=traceback) + name_stack = self.name_stack if name_stack is None else name_stack + return self._replace(traceback=traceback, name_stack=name_stack) def new_source_info() -> SourceInfo: - return SourceInfo(None) + return SourceInfo(None, NameStack()) def is_user_filename(filename: str) -> bool: """Heuristic that guesses the identity of the user's code in a stack trace.""" @@ -97,11 +149,10 @@ def __init__(self): _source_info_context = _SourceInfoContext() def current() -> SourceInfo: - context = _source_info_context.context - if not context.traceback: - return context.replace(traceback=xla_client.Traceback.get_traceback()) - return context - + source_info = _source_info_context.context + if not source_info.traceback: + source_info = source_info.replace(traceback=xla_client.Traceback.get_traceback()) + return source_info class JaxStackTraceBeforeTransformation(Exception): pass @@ -118,9 +169,10 @@ def has_user_context(e): return False @contextlib.contextmanager -def user_context(c: Optional[Traceback]): +def user_context(c: Optional[Traceback], *, name_stack: Optional[NameStack] = None): prev = _source_info_context.context - _source_info_context.context = _source_info_context.context.replace(traceback=c) + _source_info_context.context = _source_info_context.context.replace( + traceback=c, name_stack=name_stack) filtered_tb = None try: yield @@ -141,3 +193,43 @@ def user_context(c: Optional[Traceback]): finally: _source_info_context.context = prev del filtered_tb + +def current_name_stack() -> NameStack: + return _source_info_context.context.name_stack + +@contextlib.contextmanager +def extend_name_stack(name: str) -> Iterator[NameStack]: + prev_context = _source_info_context.context + curr_name_stack = prev_context.name_stack + new_context = prev_context.replace(name_stack=curr_name_stack.extend(name)) + _source_info_context.context = new_context + try: + yield _source_info_context.context.name_stack + finally: + _source_info_context.context = prev_context + +@contextlib.contextmanager +def set_name_stack(name_stack: NameStack) -> Iterator[None]: + prev_context = _source_info_context.context + new_context = prev_context.replace(name_stack=name_stack) + _source_info_context.context = new_context + try: + yield + finally: + _source_info_context.context = prev_context + +@contextlib.contextmanager +def reset_name_stack() -> Iterator[None]: + with set_name_stack(NameStack()): + yield + +@contextlib.contextmanager +def transform_name_stack(name: str) -> Iterator[NameStack]: + prev_context = _source_info_context.context + curr_name_stack = prev_context.name_stack + new_context = prev_context.replace(name_stack=curr_name_stack.transform(name)) + _source_info_context.context = new_context + try: + yield _source_info_context.context.name_stack + finally: + _source_info_context.context = prev_context diff --git a/jax/_src/util.py b/jax/_src/util.py index ef54dd767103..7ea747a3ecd4 100644 --- a/jax/_src/util.py +++ b/jax/_src/util.py @@ -277,7 +277,21 @@ def get_module_functions(module): def wrap_name(name, transform_name): return transform_name + '(' + name + ')' -def extend_name_stack(stack, name=''): +def new_name_stack(name: str = ''): + if config.jax_experimental_name_stack: + from jax._src import source_info_util + name_stack = source_info_util.NameStack() + if name: + name_stack = name_stack.extend(name) + return name_stack + return name + '/' + +def extend_name_stack(stack, name: str): + if config.jax_experimental_name_stack: + from jax._src import source_info_util + assert isinstance(stack, source_info_util.NameStack), stack + return stack.extend(name) + assert isinstance(stack, str) return stack + name + '/' def canonicalize_axis(axis, num_dims) -> int: diff --git a/jax/core.py b/jax/core.py index 625a97cb4ddd..cf835458e43f 100644 --- a/jax/core.py +++ b/jax/core.py @@ -81,10 +81,11 @@ def __str__(self): __repr__ = __str__ def pretty_print(self, *, source_info=False, print_shapes=True, - custom_pp_eqn_rules=True, **kw): + custom_pp_eqn_rules=True, name_stack=False, **kw): doc = pp_jaxpr(self, JaxprPpContext(), source_info=source_info, print_shapes=print_shapes, - custom_pp_eqn_rules=custom_pp_eqn_rules) + custom_pp_eqn_rules=custom_pp_eqn_rules, + name_stack=name_stack) return doc.format(**kw) def _repr_pretty_(self, p, cycle): @@ -141,9 +142,10 @@ def map_jaxpr(self, f): def __str__(self): return str(self.jaxpr) def __repr__(self): return repr(self.jaxpr) - def pretty_print(self, *, source_info=False, print_shapes=True, **kw): + def pretty_print(self, *, source_info=False, print_shapes=True, + name_stack=False, **kw): return pp_jaxpr(self.jaxpr, JaxprPpContext(), source_info=source_info, - print_shapes=print_shapes).format(**kw) + print_shapes=print_shapes, name_stack=name_stack).format(**kw) def _repr_pretty_(self, p, cycle): @@ -333,7 +335,8 @@ def write(v, val): map(write, jaxpr.invars, args) for eqn in jaxpr.eqns: subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params) - with source_info_util.user_context(eqn.source_info.traceback): + name_stack = source_info_util.current_name_stack() + eqn.source_info.name_stack + with source_info_util.user_context(eqn.source_info.traceback, name_stack=name_stack): ans = eqn.primitive.bind(*subfuns, *map(read, eqn.invars), **bind_params) if eqn.primitive.multiple_results: map(write, eqn.outvars, ans) @@ -2272,50 +2275,52 @@ def pp_vars(vs: Sequence[Any], context: JaxprPpContext, [pp.text(pp_var(v, context)) for v in vs]) )) -def pp_kv_pair(k:str, v: Any, context: JaxprPpContext) -> pp.Doc: +def pp_kv_pair(k:str, v: Any, context: JaxprPpContext, name_stack: bool = False) -> pp.Doc: if type(v) is tuple and all(isinstance(j, (Jaxpr, ClosedJaxpr)) for j in v): - pp_v = pp_jaxprs(v, context) + pp_v = pp_jaxprs(v, context, name_stack=name_stack) elif isinstance(v, Jaxpr): - pp_v = pp_jaxpr(v, context) + pp_v = pp_jaxpr(v, context, name_stack=name_stack) elif isinstance(v, ClosedJaxpr): - pp_v = pp_jaxpr(v.jaxpr, context) + pp_v = pp_jaxpr(v.jaxpr, context, name_stack=name_stack) else: pp_v = pp.text(str(v)) return pp.text(f'{k}=') + pp_v -def pp_kv_pairs(kv_pairs, context: JaxprPpContext) -> pp.Doc: +def pp_kv_pairs(kv_pairs, context: JaxprPpContext, name_stack: bool = False) -> pp.Doc: if not kv_pairs: return pp.nil() return pp.group( pp.nest(2, pp.concat([ pp.text("["), pp.brk(""), - pp.join(pp.brk(), [pp_kv_pair(k, v, context) for k, v in kv_pairs]) + pp.join(pp.brk(), [pp_kv_pair(k, v, context, name_stack=name_stack) for k, v in kv_pairs]) ])) + pp.brk("") + pp.text("]") ) def pp_eqn(eqn, context: JaxprPpContext, *, print_shapes=True, - source_info=False, custom_pp_eqn_rules=True) -> pp.Doc: + source_info=False, custom_pp_eqn_rules=True, name_stack=False) -> pp.Doc: lhs = pp_vars(eqn.outvars, context, print_shapes=print_shapes) annotation = (source_info_util.summarize(eqn.source_info) if source_info else None) rule = pp_eqn_rules.get(eqn.primitive) + name_stack_annotation = f'[{eqn.source_info.name_stack}]' if name_stack else None if rule and custom_pp_eqn_rules: rhs = rule(eqn, context) else: - rhs = [pp.text(eqn.primitive.name), - pp_kv_pairs(sorted(eqn.params.items()), context), + rhs = [pp.text(eqn.primitive.name, annotation=name_stack_annotation), + pp_kv_pairs(sorted(eqn.params.items()), context, name_stack=name_stack), pp.text(" ") + pp_vars(eqn.invars, context)] return pp.concat([lhs, pp.text(" = ", annotation=annotation), *rhs]) CustomPpEqnRule = Callable[[JaxprEqn, JaxprPpContext], Sequence[pp.Doc]] pp_eqn_rules: Dict[Primitive, CustomPpEqnRule] = {} def pp_eqns(eqns, context: JaxprPpContext, *, print_shapes=True, - source_info=False, custom_pp_eqn_rules=True + source_info=False, custom_pp_eqn_rules=True, name_stack=False, ) -> pp.Doc: return pp.join( pp.brk("; "), [pp_eqn(e, context, print_shapes=print_shapes, source_info=source_info, + name_stack=name_stack, custom_pp_eqn_rules=custom_pp_eqn_rules) for e in eqns]) def _compact_eqn_should_include(k: str, v: Any) -> bool: @@ -2349,23 +2354,25 @@ def pp_jaxpr_skeleton(jaxpr, eqns_fn, context: JaxprPpContext, *, def pp_jaxpr(jaxpr, context: JaxprPpContext, *, print_shapes=True, - source_info=False, custom_pp_eqn_rules=True) -> pp.Doc: + source_info=False, custom_pp_eqn_rules=True, name_stack=False) -> pp.Doc: eqns_fn = lambda: pp_eqns(jaxpr.eqns, context, print_shapes=print_shapes, source_info=source_info, - custom_pp_eqn_rules=custom_pp_eqn_rules) + custom_pp_eqn_rules=custom_pp_eqn_rules, + name_stack=name_stack) return pp_jaxpr_skeleton(jaxpr, eqns_fn, context, print_shapes=print_shapes) -def pp_jaxprs(jaxprs, context: JaxprPpContext) -> pp.Doc: +def pp_jaxprs(jaxprs, context: JaxprPpContext, name_stack: bool = False) -> pp.Doc: jaxprs = [j.jaxpr if isinstance(j, ClosedJaxpr) else j for j in jaxprs] return pp.group(pp.nest(2, pp.concat([ pp.text('('), pp.brk(""), - pp.join(pp.brk(), map(lambda x: pp_jaxpr(x, context), jaxprs))] + pp.join(pp.brk(), map(lambda x: pp_jaxpr(x, context, name_stack=name_stack), jaxprs))] )) + pp.brk("") + pp.text(')') ) def pp_jaxpr_eqn_range(jaxpr: Jaxpr, lo: int, hi: int, context: JaxprPpContext, - print_shapes=True, source_info: bool = False) -> pp.Doc: + print_shapes=True, source_info: bool = False, + name_stack: bool = False) -> pp.Doc: lo = max(lo, 0) hi = max(lo, min(hi, len(jaxpr.eqns))) eqns = jaxpr.eqns[lo:hi] @@ -2377,7 +2384,8 @@ def eqns_fn(): if lo != 0: pps.append(pp.text('...')) pps.extend(map((lambda e: pp_eqn(e, context, print_shapes=print_shapes, - source_info=source_info)), eqns)) + source_info=source_info, + name_stack=name_stack)), eqns)) if hi != len(jaxpr.eqns): pps.append(pp.text('...')) return pp.join(pp.brk("; "), pps) diff --git a/jax/experimental/djax.py b/jax/experimental/djax.py index fe88991d808d..9561567b087b 100644 --- a/jax/experimental/djax.py +++ b/jax/experimental/djax.py @@ -24,7 +24,7 @@ from jax.core import Var, Literal, Atom, Tracer from jax._src import util from jax._src.util import (safe_zip, safe_map, curry, unzip2, split_list, - tuple_delete) + tuple_delete, new_name_stack) import jax._src.pretty_printer as pp map = safe_map @@ -806,7 +806,8 @@ def fun(*args): operands_ = it.chain.from_iterable([*dims.values(), *operands]) platform = "cpu" # TODO: don't hardwire in the CPU translation. - ctx = xla.TranslationContext(c, platform, xla.AxisEnv(1, (), ()), '') + ctx = xla.TranslationContext(c, platform, xla.AxisEnv(1, (), ()), + new_name_stack()) outs = xla.jaxpr_subcomp(ctx, jaxpr, xla._xla_consts(c, consts), *operands_) return util.unflatten(outs, [aval_to_num_buffers(aval) for aval in out_avals]) diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index e98b86cf1c9d..e1485f71fe78 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -261,7 +261,7 @@ def convert(fun: Callable, """ api._check_callable(fun) fun_name = getattr(fun, "__name__", "unknown") - name_stack = util.extend_name_stack(util.wrap_name(fun_name, "jax2tf")) + name_stack = util.wrap_name(fun_name, "jax2tf") + "/" def converted_fun(*args: TfVal, **kwargs: TfVal) -> TfVal: # TODO: is there a better way to check if we are inside a transformation? if not core.trace_state_clean() and not _thread_local_state.inside_call_tf: diff --git a/jax/experimental/loops.py b/jax/experimental/loops.py index e6aa11c5a2e3..d8941942d9c4 100644 --- a/jax/experimental/loops.py +++ b/jax/experimental/loops.py @@ -117,6 +117,7 @@ def loop_body(i, acc_arr): from jax import tree_util from jax.errors import UnexpectedTracerError from jax.interpreters import partial_eval as pe +from jax._src import source_info_util from jax._src.util import safe_map @@ -291,10 +292,11 @@ def start_subtrace(self): """Starts a nested trace, returns the Trace object.""" # TODO: This follows the __enter__ part of core.new_main. level = core.thread_local_state.trace_state.trace_stack.next_level() - main = core.MainTrace(level, pe.JaxprTrace) + name_stack = source_info_util.current_name_stack() + main = core.MainTrace(level, pe.JaxprTrace, name_stack=name_stack) core.thread_local_state.trace_state.trace_stack.push(main) self._count_subtraces += 1 - return pe.JaxprTrace(main, core.cur_sublevel()) + return pe.JaxprTrace(main, core.cur_sublevel(), name_stack=name_stack) def end_subtrace(self): # TODO: This follows the __exit__ part of core.new_main diff --git a/jax/experimental/maps.py b/jax/experimental/maps.py index 0f2049841251..68bbda577f09 100644 --- a/jax/experimental/maps.py +++ b/jax/experimental/maps.py @@ -910,7 +910,7 @@ def _xmap_transpose(params, call_jaxpr, args, cts_in, cts_in_avals, reduce_axes) all_args, in_tree_def = tree_flatten(((), args, cts_in)) # empty consts fun = lu.hashable_partial( lu.wrap_init(ad.backward_pass), - call_jaxpr, reduce_axes + tuple(params['global_axis_sizes'].keys())) + call_jaxpr, reduce_axes + tuple(params['global_axis_sizes'].keys()), False) fun, nz_arg_cts = ad.nonzero_outputs(fun) fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def) # Preserve axis for primal arguments, skip tangents (represented as undefined primals). diff --git a/jax/experimental/pjit.py b/jax/experimental/pjit.py index 01e96ec6e73b..385418db8975 100644 --- a/jax/experimental/pjit.py +++ b/jax/experimental/pjit.py @@ -846,7 +846,7 @@ def prune_type(ty, xs, maybe_zeros): return tuple(x for x, mz in zip(xs, maybe_zeros) if not type(mz) is ty) body = lu.wrap_init(ad.closed_backward_pass) - body = lu.hashable_partial(body, jaxpr, reduce_axes) + body = lu.hashable_partial(body, jaxpr, reduce_axes, False) primals_and_nz_cts_in, in_treedef = tree_flatten((primals_in, cts_in)) body, cts_out_treedef_thunk = flatten_fun_nokwargs(body, in_treedef) diff --git a/jax/interpreters/ad.py b/jax/interpreters/ad.py index 385707a47671..f0f68fef0dc5 100644 --- a/jax/interpreters/ad.py +++ b/jax/interpreters/ad.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. - +import contextlib import functools from functools import partial import itertools as it @@ -40,19 +40,22 @@ def identity(x): return x -def jvp(fun: lu.WrappedFun, has_aux=False, instantiate=True) -> Any: +def jvp(fun: lu.WrappedFun, has_aux=False, instantiate=True, + transform_stack=True) -> Any: if not has_aux: - return jvpfun(jvp_subtrace(fun), instantiate) + return jvpfun(jvp_subtrace(fun), instantiate, transform_stack) else: fun, aux = jvp_subtrace_aux(fun) - return jvpfun(fun, instantiate), aux + return jvpfun(fun, instantiate, transform_stack), aux @lu.transformation -def jvpfun(instantiate, primals, tangents): +def jvpfun(instantiate, transform_stack, primals, tangents): tangents = [Zero.from_value(t) if not isinstance(t, Zero) and dtype(t) is float0 else t for t in tangents] - with core.new_main(JVPTrace) as main: + ctx = (source_info_util.transform_name_stack('jvp') if transform_stack + else contextlib.nullcontext()) + with core.new_main(JVPTrace) as main, ctx: out_primals, out_tangents = yield (main, primals, tangents), {} del main if type(instantiate) is bool: @@ -120,7 +123,7 @@ def vjp(traceable, primals, has_aux=False, reduce_axes=()): def unbound_vjp(pvals, jaxpr, consts, *cts): cts = tuple(map(ignore_consts, cts, pvals)) dummy_args = [UndefinedPrimal(v.aval) for v in jaxpr.invars] - arg_cts = backward_pass(jaxpr, reduce_axes, consts, dummy_args, cts) + arg_cts = backward_pass(jaxpr, reduce_axes, True, consts, dummy_args, cts) return map(instantiate_zeros, arg_cts) # Ensure that vjp_ is a PyTree so that we can pass it from the forward to the backward @@ -162,7 +165,7 @@ def recast_to_float0(primal, tangent): return tangent # NOTE: The FIXMEs below are caused by primal/tangent mixups (type errors if you will) -def backward_pass(jaxpr: core.Jaxpr, reduce_axes, consts, primals_in, cotangents_in): +def backward_pass(jaxpr: core.Jaxpr, reduce_axes, transform_stack, consts, primals_in, cotangents_in): if all(type(ct) is Zero for ct in cotangents_in): return map(lambda v: Zero(v.aval), jaxpr.invars) @@ -207,36 +210,40 @@ def write_primal(v, val): map(write_primal, jaxpr.invars, primals_in) ct_env: Dict[Any, Any] = {} - map(partial(write_cotangent, 'outvars'), jaxpr.outvars, cotangents_in) - for eqn in jaxpr.eqns[::-1]: - # FIXME: Some invars correspond to tangents - invals = map(read_primal, eqn.invars) - if eqn.primitive.multiple_results: - cts_in = map(read_cotangent, eqn.outvars) - else: - cts_in, = map(read_cotangent, eqn.outvars) - with source_info_util.user_context(eqn.source_info.traceback): - if eqn.primitive.call_primitive or eqn.primitive.map_primitive: - cts_in_avals = [v.aval for v in eqn.outvars] - params = dict(eqn.params) - call_jaxpr = params.pop('call_jaxpr') - cts_out = get_primitive_transpose(eqn.primitive)( - params, call_jaxpr, invals, cts_in, cts_in_avals, reduce_axes) - elif eqn.primitive in reducing_transposes: - cts_out = reducing_transposes[eqn.primitive]( - reduce_axes, cts_in, *invals, **eqn.params) + ctx = (source_info_util.transform_name_stack('transpose') if transform_stack + else contextlib.nullcontext()) + with ctx: + map(partial(write_cotangent, 'outvars'), jaxpr.outvars, cotangents_in) + for eqn in jaxpr.eqns[::-1]: + # FIXME: Some invars correspond to tangents + invals = map(read_primal, eqn.invars) + if eqn.primitive.multiple_results: + cts_in = map(read_cotangent, eqn.outvars) else: - cts_out = get_primitive_transpose(eqn.primitive)(cts_in, *invals, - **eqn.params) - cts_out = [Zero(v.aval) for v in eqn.invars] if cts_out is Zero else cts_out - # FIXME: Some invars correspond to primals! - map(partial(write_cotangent, eqn.primitive), eqn.invars, cts_out) + cts_in, = map(read_cotangent, eqn.outvars) + name_stack = source_info_util.current_name_stack() + eqn.source_info.name_stack + with source_info_util.user_context(eqn.source_info.traceback, name_stack=name_stack): + if eqn.primitive.call_primitive or eqn.primitive.map_primitive: + cts_in_avals = [v.aval for v in eqn.outvars] + params = dict(eqn.params) + call_jaxpr = params.pop('call_jaxpr') + cts_out = get_primitive_transpose(eqn.primitive)( + params, call_jaxpr, invals, cts_in, cts_in_avals, reduce_axes) + elif eqn.primitive in reducing_transposes: + cts_out = reducing_transposes[eqn.primitive]( + reduce_axes, cts_in, *invals, **eqn.params) + else: + cts_out = get_primitive_transpose(eqn.primitive)( + cts_in, *invals, **eqn.params) + cts_out = [Zero(v.aval) for v in eqn.invars] if cts_out is Zero else cts_out + # FIXME: Some invars correspond to primals! + map(partial(write_cotangent, eqn.primitive), eqn.invars, cts_out) cotangents_out = map(read_cotangent, jaxpr.invars) return cotangents_out -def closed_backward_pass(jaxpr: core.ClosedJaxpr, reduce_axes, primals_in, cotangents_in): - return backward_pass(jaxpr.jaxpr, reduce_axes, jaxpr.consts, primals_in, cotangents_in) +def closed_backward_pass(jaxpr: core.ClosedJaxpr, reduce_axes, transform_stack, primals_in, cotangents_in): + return backward_pass(jaxpr.jaxpr, reduce_axes, transform_stack, jaxpr.consts, primals_in, cotangents_in) class UndefinedPrimal: @@ -297,7 +304,7 @@ def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params): primals, tangents = unzip2((t.primal, t.tangent) for t in tracers) nonzero_tangents, tangent_tree_def = tree_flatten(tangents) nz_tangents = [type(t) is not Zero for t in tangents] - if 'name' in params: + if 'name' in params and not config.jax_experimental_name_stack: params = dict(params, name=wrap_name(params['name'], 'jvp')) f_jvp = jvp_subtrace(f, self.main) f_jvp, nz_tangents_out = nonzero_tangent_outputs(f_jvp) @@ -547,9 +554,12 @@ def traceable(num_primals, in_tree_def, *primals_and_tangents): def call_transpose(primitive, params, call_jaxpr, args, ct, _, reduce_axes): all_args, in_tree_def = tree_flatten(((), args, ct)) # empty consts - fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr, reduce_axes) + fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr, reduce_axes, False) fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def) - new_params = dict(params, name=wrap_name(params['name'], 'transpose')) + if config.jax_experimental_name_stack: + new_params = params + else: + new_params = dict(params, name=wrap_name(params['name'], 'transpose')) update_params = call_transpose_param_updaters.get(primitive) if update_params: new_params = update_params(new_params, map(is_undefined_primal, args), @@ -575,7 +585,7 @@ def do_transpose(primals_in, cotangents_in): residuals = core.jaxpr_as_fun(primal_jaxpr)(*primals_in)[len(cotangents_in):] # Now that we have a purely linear jaxpr, we can transpose it cotangents_out = backward_pass( - tangent_jaxpr.jaxpr, reduce_axes, (), primals_in + residuals, cotangents_in) + tangent_jaxpr.jaxpr, reduce_axes, False, (), primals_in + residuals, cotangents_in) # backward_pass will return cotangents computed for all invars, but some of them # are residuals appended by partial eval, so we need to skip those before we return. return cotangents_out[:len(primals_in)] @@ -594,7 +604,7 @@ def nonzero_outputs(*args, **kwargs): def map_transpose(primitive, params, call_jaxpr, args, ct, _, reduce_axes): all_args, in_tree_def = tree_flatten(((), args, ct)) # empty consts - fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr, reduce_axes) + fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr, reduce_axes, False) fun, nz_arg_cts = nonzero_outputs(fun) fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def) # Preserve axis for primal arguments, skip tangents (represented as undefined primals). @@ -642,7 +652,8 @@ def jvp_jaxpr(jaxpr, nonzeros, instantiate): def _jvp_jaxpr(jaxpr, nonzeros, instantiate): assert len(jaxpr.in_avals) == len(nonzeros) f = lu.wrap_init(core.jaxpr_as_fun(jaxpr)) - f_jvp, out_nonzeros = f_jvp_traceable(jvp(f, instantiate=instantiate), nonzeros) + f_jvp, out_nonzeros = f_jvp_traceable(jvp(f, instantiate=instantiate, transform_stack=False), + nonzeros) tangent_avals = [aval for aval, nz in zip(jaxpr.in_avals, nonzeros) if nz] avals_in = list(it.chain(jaxpr.in_avals, tangent_avals)) jaxpr_out, avals_out, literals_out = pe.trace_to_jaxpr_dynamic(f_jvp, avals_in) diff --git a/jax/interpreters/batching.py b/jax/interpreters/batching.py index b30bd4206686..4f1c54e16f71 100644 --- a/jax/interpreters/batching.py +++ b/jax/interpreters/batching.py @@ -22,6 +22,7 @@ from jax.config import config from jax import core from jax.core import raise_to_shaped, Trace, Tracer +from jax._src import source_info_util from jax._src.tree_util import tree_unflatten, tree_flatten from jax._src.ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval, zeros_like_p, Zero) @@ -29,7 +30,6 @@ from jax._src.util import (unzip2, unzip3, safe_map, safe_zip, wrap_name, split_list, canonicalize_axis, moveaxis, as_hashable_function, curry, memoize, cache) -from jax._src import source_info_util from jax.interpreters import partial_eval as pe map = safe_map @@ -205,7 +205,10 @@ def process_primitive(self, primitive, tracers, params): def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params): assert call_primitive.multiple_results - params = dict(params, name=wrap_name(params.get('name', f.__name__), 'vmap')) + if config.jax_experimental_name_stack: + params = dict(params, name=params.get('name', f.__name__)) + else: + params = dict(params, name=wrap_name(params.get('name', f.__name__), 'vmap')) vals, dims = unzip2((t.val, t.batch_dim) for t in tracers) if all(bdim is not_mapped for bdim in dims): return call_primitive.bind(f, *vals, **params) @@ -372,7 +375,8 @@ def batch(fun: lu.WrappedFun, axis_name: core.AxisName, axis_size, def _batch_outer(axis_name, axis_size, in_dims, main_type, *in_vals): with core.new_main(main_type, axis_name=axis_name) as main: with core.extend_axis_env(axis_name, axis_size, main): - outs = yield (main, in_dims, *in_vals), {} + with source_info_util.transform_name_stack('vmap'): + outs = yield (main, in_dims, *in_vals), {} del main yield outs diff --git a/jax/interpreters/mlir.py b/jax/interpreters/mlir.py index 4c3b0768281e..7022cba505e4 100644 --- a/jax/interpreters/mlir.py +++ b/jax/interpreters/mlir.py @@ -266,8 +266,8 @@ def _device_array_constant_handler(val, canonicalize_types): def _source_info_to_location( primitive: core.Primitive, params: Dict, source_info: source_info_util.SourceInfo, - name_stack: str = "") -> ir.Location: - eqn_str = name_stack + core.str_eqn_compact(primitive.name, params) + name_stack: Union[str, source_info_util.NameStack] = "") -> ir.Location: + eqn_str = str(name_stack) + core.str_eqn_compact(primitive.name, params) frame = source_info_util.user_frame(source_info) if frame is None: loc = ir.Location.unknown() @@ -280,6 +280,7 @@ def _source_info_to_location( # Translation rules +NameStack = Union[str, source_info_util.NameStack] def make_ir_context() -> ir.Context: """Creates an MLIR context suitable for JAX IR.""" @@ -334,7 +335,7 @@ class ModuleContext: symbol_table: ir.SymbolTable platform: str axis_context: AxisContext - name_stack: str + name_stack: NameStack # Cached primitive lowerings. cached_primitive_lowerings: Dict[Any, builtin.FuncOp] @@ -344,7 +345,7 @@ def axis_env(self) -> xla.AxisEnv: return self.axis_context.axis_env def __init__( - self, platform: str, axis_context: AxisContext, name_stack: str, + self, platform: str, axis_context: AxisContext, name_stack: NameStack, context: Optional[ir.Context] = None, module: Optional[ir.Module] = None, ip: Optional[ir.InsertionPoint] = None, @@ -412,7 +413,7 @@ def flatten_lowering_ir_args( def lower_jaxpr_to_module( module_name: str, jaxpr: core.ClosedJaxpr, platform: str, axis_context: AxisContext, - name_stack: str, donated_args: Sequence[bool], + name_stack: NameStack, donated_args: Sequence[bool], replicated_args: Optional[Sequence[bool]] = None, arg_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None, result_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None diff --git a/jax/interpreters/partial_eval.py b/jax/interpreters/partial_eval.py index 55697615ec00..e60f7406ab3f 100644 --- a/jax/interpreters/partial_eval.py +++ b/jax/interpreters/partial_eval.py @@ -97,6 +97,11 @@ def merge_with_known(self, val: core.Value) -> core.Value: class JaxprTrace(Trace): + + def __init__(self, *args, name_stack: source_info_util.NameStack): + super().__init__(*args) + self.name_stack = name_stack + def pure(self, val) -> 'JaxprTracer': return self.new_const(val) @@ -163,7 +168,8 @@ def default_process_primitive(self, primitive, tracers, params): tracers = map(self.instantiate_const, tracers) avals = [t.aval for t in tracers] out_aval = primitive.abstract_eval(*avals, **params) - source = source_info_util.current() + name_stack = self._current_truncated_name_stack() + source = source_info_util.current().replace(name_stack=name_stack) if primitive.multiple_results: out_tracers = [JaxprTracer(self, PartialVal.unknown(aval), None) for aval in out_aval] @@ -213,11 +219,11 @@ def process_call(self, primitive, f: lu.WrappedFun, tracers, params): # The outputs of the staged-out call are Tracers with the new eqn as recipe. out_tracers = [JaxprTracer(self, PartialVal.unknown(a), None) for a in out_avals] + name_stack = self._current_truncated_name_stack() + source = source_info_util.current().replace(name_stack=name_stack) eqn = new_eqn_recipe((*const_tracers, *env_tracers, *unknown_arg_tracers), - out_tracers, primitive, staged_params, - source_info_util.current()) + out_tracers, primitive, staged_params, source) for t in out_tracers: t.recipe = eqn - return merge_lists(out_knowns, out_tracers, out_consts) def process_map(self, primitive, f: lu.WrappedFun, tracers, params): @@ -305,8 +311,9 @@ def todo(out): update_params = call_param_updaters.get(primitive) or (lambda p, _, __: p) new_params = update_params(params, [], len(in_tracers)) new_params = dict(new_params, call_jaxpr=convert_constvars_jaxpr(jaxpr)) - eqn = new_eqn_recipe(in_tracers, out_tracers, primitive, new_params, - source_info_util.current()) + name_stack = self._current_truncated_name_stack() + source = source_info_util.current().replace(name_stack=name_stack) + eqn = new_eqn_recipe(in_tracers, out_tracers, primitive, new_params, source) for t in out_tracers: t.recipe = eqn return merge_lists(out_knowns, out_tracers, out_consts) @@ -342,8 +349,10 @@ def todo(out): for d, a in zip(staged_out_axes, out_avals_mapped)] out_tracers = [JaxprTracer(trace, PartialVal.unknown(a), None) for a in out_avals] + name_stack = self._current_truncated_name_stack() + source = source_info_util.current().replace(name_stack=name_stack) eqn = new_eqn_recipe((*const_tracers, *env_tracers), out_tracers, - primitive, staged_params, source_info_util.current()) + primitive, staged_params, source) for t in out_tracers: t.recipe = eqn return merge_lists(out_knowns, out_tracers, out_consts) @@ -355,6 +364,9 @@ def out_axes_transform(out_axes): return out, (todo, out_axes_transform) + def _current_truncated_name_stack(self): + return source_info_util.current_name_stack()[len(self.name_stack):] + def partial_eval(self, f: lu.WrappedFun, pvals: Sequence[PartialVal], app: Callable[[lu.WrappedFun, Tuple[core.Value, ...]], Tuple[core.Value]], instantiate: bool): @@ -394,11 +406,13 @@ def jvp_jaxpr_thunk(): converted_jaxpr = convert_envvars_to_constvars(jaxpr, len(env)) return converted_jaxpr, (*consts, *env) + name_stack = self._current_truncated_name_stack() + source = source_info_util.current().replace(name_stack=name_stack) eqn = new_eqn_recipe(in_tracers, out_tracers, prim.initial_style, dict(fun_jaxpr=closed_jaxpr, jvp_jaxpr_thunk=jvp_jaxpr_thunk, num_consts=len(consts) + len(env)), - source_info_util.current()) + source) for t in out_tracers: t.recipe = eqn return out_tracers @@ -434,12 +448,14 @@ def fwd_jaxpr_thunk(): converted_jaxpr = convert_envvars_to_constvars(jaxpr, len(env)) return converted_jaxpr, (*consts, *env) + name_stack = self._current_truncated_name_stack() + source = source_info_util.current().replace(name_stack=name_stack) eqn = new_eqn_recipe(in_tracers, out_tracers, prim.initial_style, dict(fun_jaxpr=closed_jaxpr, fwd_jaxpr_thunk=fwd_jaxpr_thunk, num_consts=len(consts) + len(env), bwd=bwd, out_trees=out_trees), - source_info_util.current()) + source) for t in out_tracers: t.recipe = eqn return out_tracers @@ -551,7 +567,8 @@ def trace_to_jaxpr( returned jaxpr takes as inputs the known residual values followed by values of the originally unknown inputs. """ - with core.new_main(JaxprTrace) as main: + current_name_stack = source_info_util.current_name_stack() + with core.new_main(JaxprTrace, name_stack=current_name_stack) as main: fun = trace_to_subjaxpr(fun, main, instantiate) jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals) assert not env @@ -565,7 +582,7 @@ def trace_to_subjaxpr_nounits( main: core.MainTrace, instantiate: Union[bool, Sequence[bool]], in_pvals: Sequence[PartialVal]): assert all([isinstance(pv, PartialVal) for pv in in_pvals]), in_pvals - trace = JaxprTrace(main, core.cur_sublevel()) + trace = main.with_cur_sublevel() in_knowns = [pval.is_known() for pval in in_pvals] in_consts = [pval.get_known() for pval in in_pvals if pval.is_known()] in_tracers = [trace.new_arg(pval) for pval in in_pvals if not pval.is_known()] @@ -1464,11 +1481,13 @@ def process_call(self, call_primitive, f, tracers, params): dim_tracers = _get_tracers_only_in_shapes(tracers) in_avals = _tracers_to_avals(dim_tracers + tracers) keep_inputs = [False] * len(dim_tracers) + [True] * len(tracers) + name_stack = source_info_util.current_name_stack() with core.new_sublevel(): jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic( f, self.main, in_avals, keep_inputs=keep_inputs) if params.get('inline', False): - return core.eval_jaxpr(jaxpr, consts, *dim_tracers, *tracers) + with source_info_util.set_name_stack(name_stack): + return core.eval_jaxpr(jaxpr, consts, *dim_tracers, *tracers) source_info = source_info_util.current() env = {v: t for v, t in zip((*jaxpr.constvars, *jaxpr.invars), (*consts, *dim_tracers, *tracers)) @@ -1695,7 +1714,7 @@ def trace_to_subjaxpr_dynamic(fun: lu.WrappedFun, main: core.MainTrace, keep_inputs = [True] * len(in_avals) if keep_inputs is None else keep_inputs frame = JaxprStackFrame() - with extend_jaxpr_stack(main, frame): + with extend_jaxpr_stack(main, frame), source_info_util.reset_name_stack(): trace = DynamicJaxprTrace(main, core.cur_sublevel()) in_tracers = _avals_to_tracers(trace, in_avals) in_tracers_ = [t for t, keep in zip(in_tracers, keep_inputs) if keep] @@ -1835,7 +1854,7 @@ def _get_tracers_in_shapes(seen: Set[TracerId], in_tracers: Sequence[Tracer] def trace_to_subjaxpr(main: core.MainTrace, instantiate: Union[bool, Sequence[bool]], pvals: Sequence[PartialVal]): assert all([isinstance(pv, PartialVal) for pv in pvals]), pvals - trace = JaxprTrace(main, core.cur_sublevel()) + trace = main.with_cur_sublevel() in_tracers = map(trace.new_arg, pvals) ans = yield in_tracers, {} assert isinstance(ans, (list, tuple)), ( diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py index 7a33fde392de..ceb11f70e461 100644 --- a/jax/interpreters/pxla.py +++ b/jax/interpreters/pxla.py @@ -54,7 +54,7 @@ from jax._src import source_info_util from jax._src import util from jax._src.util import (unzip3, prod, safe_map, safe_zip, - extend_name_stack, wrap_name, assert_unreachable, + extend_name_stack, new_name_stack, wrap_name, assert_unreachable, tuple_insert, tuple_delete, distributed_debug_log) from jax.errors import JAXTypeError from jax._src import dispatch @@ -1038,7 +1038,7 @@ def lower_parallel_callable( axis_env = xla.AxisEnv( replicas.num_global_replicas, (axis_name,), (global_axis_size,)) - name_stack = extend_name_stack(wrap_name(name, 'pmap')) + name_stack = new_name_stack(wrap_name(name, 'pmap')) closed_jaxpr = core.ClosedJaxpr(jaxpr, consts) replicated_args = [axis is None for axis in in_axes] module: Union[str, xc.XlaComputation] @@ -2145,7 +2145,7 @@ def lower_mesh_computation( in_is_gda: Sequence[bool]): assert not mesh.empty backend = xb.get_device_backend(mesh.devices.flat[0]) - name_stack = extend_name_stack(wrap_name(fun_name, api_name)) + name_stack = new_name_stack(wrap_name(fun_name, api_name)) global_axis_sizes = mesh.shape @@ -2236,7 +2236,7 @@ def lower_mesh_computation( partitions_are_protos=partitions_proto) return MeshComputation( - name_stack, module, donated_invars, mesh=mesh, global_in_avals=global_in_avals, + str(name_stack), module, donated_invars, mesh=mesh, global_in_avals=global_in_avals, global_out_avals=global_out_avals, in_axes=in_axes, out_axes=out_axes, spmd_lowering=spmd_lowering, tuple_args=tuple_args, in_is_gda=in_is_gda) diff --git a/jax/interpreters/sharded_jit.py b/jax/interpreters/sharded_jit.py index be9eaa2a5514..5a3d2da59e9c 100644 --- a/jax/interpreters/sharded_jit.py +++ b/jax/interpreters/sharded_jit.py @@ -35,7 +35,7 @@ _ensure_index_tuple) import jax._src.util as util from jax.tree_util import tree_flatten, tree_unflatten -from jax._src.util import (extend_name_stack, wrap_name, wraps, safe_map, +from jax._src.util import (new_name_stack, wrap_name, wraps, safe_map, safe_zip, HashableFunction) from jax._src.config import config @@ -149,7 +149,7 @@ def _sharded_callable( xla_args = _xla_sharded_args(c, global_abstract_args, in_parts) axis_env = xla.AxisEnv(nrep, (), ()) ctx = xla.TranslationContext( - c, platform, axis_env, extend_name_stack(wrap_name(name, "sharded_jit"))) + c, platform, axis_env, new_name_stack(wrap_name(name, "sharded_jit"))) out_nodes = xla.jaxpr_subcomp(ctx, jaxpr, xla_consts, *xla_args) out_tuple = xla.with_sharding(c, out_parts, xops.Tuple, c, out_nodes) built = c.Build(out_tuple) @@ -202,7 +202,7 @@ def _sharded_jit_translation_rule(ctx, avals_in, avals_out, *in_nodes, sub_ctx = ctx.replace( builder=subc, - name_stack=extend_name_stack(wrap_name(name, "sharded_jit"))) + name_stack=new_name_stack(wrap_name(name, "sharded_jit"))) out_nodes = xla.jaxpr_subcomp(sub_ctx, call_jaxpr, (), *args) out_parts = out_parts_thunk() assert len(out_parts) == len(out_nodes) @@ -234,7 +234,7 @@ def _sharded_jit_lowering(ctx, *in_nodes, args.append(ns) sub_ctx = ctx.module_context.replace( - name_stack=extend_name_stack(wrap_name(name, "sharded_jit"))) + name_stack=new_name_stack(wrap_name(name, "sharded_jit"))) fn = mlir.lower_jaxpr_to_fun(sub_ctx, f"sharded_jit_{name}", core.ClosedJaxpr(call_jaxpr, ())) diff --git a/jax/interpreters/xla.py b/jax/interpreters/xla.py index 28b3c5f4013f..0d980511ad7a 100644 --- a/jax/interpreters/xla.py +++ b/jax/interpreters/xla.py @@ -42,7 +42,7 @@ Literal, str_eqn_compact, abstract_token) import jax._src.pretty_printer as pp from jax._src import util -from jax._src.util import (prod, extend_name_stack, wrap_name, +from jax._src.util import (prod, extend_name_stack, new_name_stack, wrap_name, safe_zip, safe_map, partition_list) from jax._src.lib import xla_client as xc from jax.interpreters import partial_eval as pe @@ -101,11 +101,15 @@ def _get_canonical_source_file(frame: source_info_util.Frame): def make_op_metadata(primitive: core.Primitive, params: Dict, *, source_info: source_info_util.SourceInfo, - name_stack: str = "", + name_stack: Union[str, source_info_util.NameStack] = "", ) -> xc.OpMetadata: - eqn_str = name_stack + str_eqn_compact(primitive.name, params) + if config.jax_experimental_name_stack: + eqn_str = str(source_info.name_stack) + '/' + str_eqn_compact(primitive.name, params) + else: + assert isinstance(name_stack, str) + eqn_str = name_stack + str_eqn_compact(primitive.name, params) tracebacks[eqn_str] = source_info.traceback - frame = source_info_util.user_frame(source_info) if source_info else None + frame = source_info_util.user_frame(source_info) return xc.OpMetadata( op_type=primitive.name, op_name=eqn_str, @@ -438,7 +442,7 @@ def primitive_subcomputation(platform: str, axis_env: 'AxisEnv', xla_args, _ = _xla_callable_args(c, avals, tuple_args=False, filter_tokens=False) ctx = TranslationContext(builder=c, platform=platform, axis_env=axis_env, - name_stack="") + name_stack=new_name_stack()) ans = f(ctx.replace(builder=c), avals, None, *xla_args, **params) if prim.multiple_results: ans = xops.Tuple(c, ans) @@ -551,7 +555,7 @@ class TranslationContext: # with a specific platform in mind. platform: Optional[str] axis_env: AxisEnv - name_stack: str + name_stack: Union[str, source_info_util.NameStack] def replace(self, **kw): return dataclasses.replace(self, **kw) @@ -581,9 +585,15 @@ def write(v, node): _partitionmap(write, jaxpr.constvars, consts) _partitionmap(write, jaxpr.invars, args) for eqn in jaxpr.eqns: + if config.jax_experimental_name_stack: + assert isinstance(ctx.name_stack, source_info_util.NameStack), type(ctx.name_stack) + source_info = eqn.source_info.replace( + name_stack=ctx.name_stack + eqn.source_info.name_stack) + else: + source_info = eqn.source_info op_metadata = make_op_metadata( eqn.primitive, eqn.params, name_stack=ctx.name_stack, - source_info=eqn.source_info) + source_info=source_info) ctx.builder.set_op_metadata(op_metadata) in_nodes = _flatmap(read, eqn.invars) if (ctx.platform is not None and @@ -596,7 +606,9 @@ def write(v, node): f"XLA translation rule for primitive '{eqn.primitive.name}' not found") with source_info_util.user_context(eqn.source_info.traceback): - ans = rule(ctx, map(aval, eqn.invars), map(aval, eqn.outvars), + eqn_ctx = (ctx.replace(name_stack=source_info.name_stack) if + config.jax_experimental_name_stack else ctx) + ans = rule(eqn_ctx, map(aval, eqn.invars), map(aval, eqn.outvars), *in_nodes, **eqn.params) assert isinstance(ans, collections.abc.Sequence), (ans, eqn) @@ -755,8 +767,8 @@ def set_up_aliases(c, xla_args, out_shape: XlaShape, donated_args, tuple_args): @profiler.annotate_function def lower_jaxpr_to_xla_module( fn_name: str, jaxpr: core.ClosedJaxpr, platform: str, axis_env: AxisEnv, - name_stack: str, tuple_args: bool, donated_invars: Sequence[bool], - replicated_args: Optional[Sequence[bool]], + name_stack: Union[source_info_util.NameStack, str], tuple_args: bool, + donated_invars: Sequence[bool], replicated_args: Optional[Sequence[bool]], arg_partitions: Optional[Any], out_partitions: Optional[Any], partitions_are_protos: bool = False @@ -1042,7 +1054,7 @@ def f_with_avals(c, avals, xla_args, params): wrapped_fun = _tuple_output(wrapped_fun) with core.extend_axis_env_nd(zip(axis_env.names, axis_env.sizes)): jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, avals) - ctx = TranslationContext(c, backend, axis_env, '') + ctx = TranslationContext(c, backend, axis_env, new_name_stack()) outs = jaxpr_subcomp(ctx, jaxpr, _xla_consts(c, consts), *xla_args) if (multiple_results or any(len(aval_to_xla_shapes(v.aval)) > 1 for v in jaxpr.outvars)): diff --git a/tests/api_test.py b/tests/api_test.py index 749fa96071ee..81f1cbf9ce70 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -7311,7 +7311,13 @@ def f(x): return my_test_function(x) c = jax.xla_computation(f)(2) - self.assertIn("my_test_function", c.as_hlo_text()) + if config.jax_experimental_name_stack: + print_opts = xla_client._xla.HloPrintOptions.short_parsable() + print_opts.print_metadata = True + hlo_text = c.as_hlo_module().to_string(print_opts) + else: + hlo_text = c.as_hlo_text() + self.assertIn("my_test_function", hlo_text) def test_non_jaxtype_arg(self): # For the test to fail without the invalid JaxType filter we need to pass diff --git a/tests/name_stack_test.py b/tests/name_stack_test.py new file mode 100644 index 000000000000..d10d2aeff54a --- /dev/null +++ b/tests/name_stack_test.py @@ -0,0 +1,612 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import absltest +import jax +import jax.numpy as jnp +from jax import core +from jax import lax +from jax import linear_util as lu +from jax.config import config +from jax._src import test_util as jtu +from jax._src import source_info_util +from jax._src.lib import xla_client + +config.parse_flags_with_absl() +extend_name_stack = source_info_util.extend_name_stack + +def _get_hlo(f): + def wrapped(*args, **kwargs): + c = jax.xla_computation(f)(*args, **kwargs) + print_opts = xla_client._xla.HloPrintOptions.short_parsable() + print_opts.print_metadata = True + return c.as_hlo_module().to_string(print_opts) + return wrapped + +class _EnableNameStackTestCase(jtu.JaxTestCase): + + def setUp(self): + self.cfg = config._read("jax_experimental_name_stack") + config.update("jax_experimental_name_stack", True) + + def tearDown(self): + config.update("jax_experimental_name_stack", self.cfg) + + +class NameStackTest(_EnableNameStackTestCase): + + def test_trivial_name_stack(self): + + def f(x): + return x + 1 + jaxpr = jax.make_jaxpr(f)(2).jaxpr + for eqn in jaxpr.eqns: + self.assertEqual(str(eqn.source_info.name_stack), '') + + def test_name_call_name_stack(self): + + @jax.named_call + def f(x): + return x + 1 + jaxpr = jax.make_jaxpr(f)(2).jaxpr + for eqn in jaxpr.eqns: + self.assertEqual(str(eqn.source_info.name_stack), 'f') + + def test_manual_name_stack(self): + + @extend_name_stack('foo') + def f(x): + return x + 1 + jaxpr = jax.make_jaxpr(f)(2).jaxpr + for eqn in jaxpr.eqns: + self.assertEqual(str(eqn.source_info.name_stack), 'foo') + + def test_nested_name_stack(self): + + @extend_name_stack('foo') + def f(x): + with extend_name_stack('bar'): + return x + 1 + jaxpr = jax.make_jaxpr(f)(2).jaxpr + for eqn in jaxpr.eqns: + self.assertEqual(str(eqn.source_info.name_stack), 'foo/bar') + + def test_multiple_name_stack(self): + + def f(x): + with extend_name_stack('foo'): + y = x + 1 + with extend_name_stack('bar'): + with extend_name_stack('baz'): + return y + 1 + jaxpr = jax.make_jaxpr(f)(2).jaxpr + self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'foo') + self.assertEqual(str(jaxpr.eqns[1].source_info.name_stack), 'bar/baz') + + def test_call_primitive_jaxpr_should_not_store_outer_name_stack(self): + @extend_name_stack('foo') + def f(x): + @lu.wrap_init + @extend_name_stack('bar') + def _f(x): + return [x + 1] + return core.call(_f, x)[0] + + jaxpr = jax.make_jaxpr(f)(2).jaxpr + self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'foo') + self.assertEqual(str(jaxpr.eqns[0].params['call_jaxpr'].eqns[0].source_info.name_stack), 'bar') + + hlo_text = _get_hlo(f)(2) + self.assertIn('foo/core_call/bar', hlo_text) + + def test_xla_call_primitive_jaxpr_should_not_store_outer_name_stack(self): + @extend_name_stack('foo') + def f(x): + @jax.jit + @extend_name_stack('bar') + def _f(x): + return x + 1 + return _f(x) + + jaxpr = jax.make_jaxpr(f)(2).jaxpr + self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'foo') + self.assertEqual(str(jaxpr.eqns[0].params['call_jaxpr'].eqns[0].source_info.name_stack), 'bar') + + hlo_text = _get_hlo(f)(2) + self.assertIn('foo/jit(_f)/bar', hlo_text) + + def test_pmap_call_primitive_jaxpr_should_not_store_outer_name_stack(self): + @extend_name_stack('foo') + @jax.pmap + def f(x): + with extend_name_stack('bar'): + return x + 1 + jaxpr = jax.make_jaxpr(f)(jnp.ones(1)).jaxpr + self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'foo') + self.assertEqual(str(jaxpr.eqns[0].params['call_jaxpr'].eqns[0].source_info.name_stack), 'bar') + + +class NameStackTransformationTest(_EnableNameStackTestCase): + + def test_vmap_should_transform_name_stack(self): + @jax.vmap + def f(x): + with extend_name_stack('foo'): + return x + 1 + jaxpr = jax.make_jaxpr(f)(jnp.ones(2)).jaxpr + self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'vmap(foo)') + + def test_vmap_should_transform_inner_name_stacks(self): + @extend_name_stack('foo') + @jax.vmap + def f(x): + with extend_name_stack('bar'): + with extend_name_stack('baz'): + return x + 1 + jaxpr = jax.make_jaxpr(f)(jnp.ones(2)).jaxpr + self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'foo/vmap(bar)/vmap(baz)') + + def test_vmap_should_apply_to_call_jaxpr(self): + @extend_name_stack('foo') + @jax.vmap + def f(x): + @jax.jit + @extend_name_stack('bar') + def _f(x): + return x + 1 + return _f(x) + + jaxpr = jax.make_jaxpr(f)(jnp.ones(2)).jaxpr + self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'foo') + self.assertEqual(str(jaxpr.eqns[0].params['call_jaxpr'].eqns[0].source_info.name_stack), 'bar') + + hlo_text = _get_hlo(f)(jnp.ones(2)) + self.assertIn('foo/vmap(jit(_f))/vmap(bar)', hlo_text) + + def test_jvp_should_transform_stacks(self): + def f(x): + with extend_name_stack('bar'): + with extend_name_stack('baz'): + return jnp.square(x) + g = extend_name_stack('foo')(lambda x, t: jax.jvp(f, (x,), (t,))) + jaxpr = jax.make_jaxpr(g)(1., 1.).jaxpr + self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), + 'foo/jvp(bar)/jvp(baz)') + + def test_jvp_should_apply_to_call_jaxpr(self): + @jax.jit + def f(x): + with extend_name_stack('bar'): + with extend_name_stack('baz'): + return jnp.square(x) + g = extend_name_stack('foo')(lambda x, t: jax.jvp(f, (x,), (t,))) + jaxpr = jax.make_jaxpr(g)(1., 1.).jaxpr + self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'foo') + self.assertEqual( + str(jaxpr.eqns[0].params['call_jaxpr'].eqns[0].source_info.name_stack), + 'bar/baz') + + hlo_text = _get_hlo(g)(1., 1.) + self.assertIn('foo/jvp(jit(f))/jvp(bar)', hlo_text) + + def test_grad_should_add_jvp_and_transpose_to_name_stack(self): + @jax.grad + def f(x): + with extend_name_stack('foo'): + return jnp.sin(x) + jaxpr = jax.make_jaxpr(f)(1.).jaxpr + self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'jvp(foo)') + self.assertEqual(str(jaxpr.eqns[1].source_info.name_stack), 'jvp(foo)') + self.assertEqual(str(jaxpr.eqns[2].source_info.name_stack), + 'transpose(jvp(foo))') + + hlo_text = _get_hlo(f)(1.) + self.assertIn('jvp(foo)/sin', hlo_text) + self.assertIn('jvp(foo)/cos', hlo_text) + self.assertIn('transpose(jvp(foo))/mul', hlo_text) + + def test_grad_should_add_jvp_and_transpose_to_call_jaxpr(self): + @jax.grad + @extend_name_stack('foo') + @jax.jit + def f(x): + with extend_name_stack('bar'): + return jnp.sin(x) + jaxpr = jax.make_jaxpr(f)(1.).jaxpr + self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'jvp(foo)') + self.assertEqual(str(jaxpr.eqns[1].source_info.name_stack), 'transpose(jvp(foo))') + self.assertEqual(str( + jaxpr.eqns[0].params['call_jaxpr'].eqns[0].source_info.name_stack), 'bar') + self.assertEqual(str( + jaxpr.eqns[0].params['call_jaxpr'].eqns[1].source_info.name_stack), 'bar') + self.assertEqual(str( + jaxpr.eqns[1].params['call_jaxpr'].eqns[0].source_info.name_stack), 'bar') + + hlo_text = _get_hlo(f)(1.) + self.assertIn('jvp(foo)/jvp(jit(f))/jvp(bar)/sin', hlo_text) + self.assertIn('jvp(foo)/jvp(jit(f))/jvp(bar)/cos', hlo_text) + self.assertIn( + 'transpose(jvp(foo))/transpose(jvp(jit(f)))/transpose(jvp(bar))/mul', + hlo_text) + + +class NameStackControlFlowTest(_EnableNameStackTestCase): + + def test_while_loop_body_should_not_have_name_stack(self): + + @extend_name_stack('foo') + def f(x): + @extend_name_stack('bar') + def body(x): + return x + 1 + @extend_name_stack('bar_cond') + def cond(x): + return x < 5 + return lax.while_loop(cond, body, x) + jaxpr = jax.make_jaxpr(f)(0) + self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'foo') + self.assertEqual(str( + jaxpr.eqns[0].params['body_jaxpr'].eqns[0].source_info.name_stack), + 'bar') + self.assertEqual(str( + jaxpr.eqns[0].params['cond_jaxpr'].eqns[0].source_info.name_stack), + 'bar_cond') + + hlo_text = _get_hlo(f)(1.) + self.assertIn('foo/while/body/bar', hlo_text) + self.assertIn('foo/while/cond/bar_cond', hlo_text) + + def test_vmap_of_while_loop_should_transform_name_stack(self): + + @jax.vmap + @extend_name_stack('foo') + def f(x): + @extend_name_stack('bar') + def body(x): + return x + 1 + @extend_name_stack('bar_cond') + def cond(x): + return x < 5 + return lax.while_loop(cond, body, x) + jaxpr = jax.make_jaxpr(f)(jnp.arange(2)) + self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'vmap(foo)') + self.assertEqual(str( + jaxpr.eqns[0].params['body_jaxpr'].eqns[0].source_info.name_stack), + 'bar') + self.assertEqual(str( + jaxpr.eqns[0].params['cond_jaxpr'].eqns[0].source_info.name_stack), + 'bar_cond') + + hlo_text = _get_hlo(f)(jnp.arange(2.)) + self.assertIn('vmap(foo)/vmap(while)/vmap(body)/vmap(bar)', hlo_text) + self.assertIn('vmap(foo)/vmap(while)/vmap(cond)/vmap(bar_cond)', hlo_text) + + def test_jvp_of_while_loop_transforms_name_stack(self): + + @extend_name_stack('foo') + def f(x): + @extend_name_stack('bar') + def body(x): + return x + 1. + @extend_name_stack('bar_cond') + def cond(x): + return x < 5. + return lax.while_loop(cond, body, x) + g = lambda x, t: jax.jvp(f, (x,), (t,)) + jaxpr = jax.make_jaxpr(g)(1., 1.) + self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'jvp(foo)') + self.assertEqual(str( + jaxpr.eqns[0].params['body_jaxpr'].eqns[0].source_info.name_stack), + 'bar') + self.assertEqual(str( + jaxpr.eqns[0].params['cond_jaxpr'].eqns[0].source_info.name_stack), + 'bar_cond') + + hlo_text = _get_hlo(g)(1., 1.) + self.assertIn('jvp(foo)/jvp(while)/jvp(body)/jvp(bar)', hlo_text) + self.assertIn('jvp(foo)/jvp(while)/jvp(cond)/jvp(bar_cond)', hlo_text) + + def test_vmap_of_jvp_of_while_loop_transforms_name_stack(self): + + @extend_name_stack('foo') + def f(x): + @extend_name_stack('bar') + def body(x): + return x + 1. + @extend_name_stack('bar_cond') + def cond(x): + return x < 5. + return lax.while_loop(cond, body, x) + g = jax.vmap(lambda x, t: jax.jvp(f, (x,), (t,))) + jaxpr = jax.make_jaxpr(g)(jnp.arange(2.), jnp.ones(2)) + self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'vmap(jvp(foo))') + self.assertEqual(str( + jaxpr.eqns[0].params['body_jaxpr'].eqns[0].source_info.name_stack), + 'bar') + self.assertEqual(str( + jaxpr.eqns[0].params['cond_jaxpr'].eqns[0].source_info.name_stack), + 'bar_cond') + + hlo_text = _get_hlo(g)(jnp.arange(2.), jnp.ones(2)) + self.assertIn( + 'vmap(jvp(foo))/vmap(jvp(while))/vmap(jvp(body))/vmap(jvp(bar))', + hlo_text) + self.assertIn( + 'vmap(jvp(foo))/vmap(jvp(while))/vmap(jvp(cond))/vmap(jvp(bar_cond))', + hlo_text) + + def test_cond_body_should_not_have_name_stack(self): + + @extend_name_stack('foo') + def f(x): + @extend_name_stack('true') + def true_fn(x): + return x + 1 + @extend_name_stack('false') + def false_fn(x): + return x - 1 + return lax.cond(True, true_fn, false_fn, x) + jaxpr = jax.make_jaxpr(f)(0) + self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'foo') + self.assertEqual(str( + jaxpr.eqns[0].params['branches'][0].eqns[0].source_info.name_stack), + 'false') + self.assertEqual(str( + jaxpr.eqns[0].params['branches'][1].eqns[0].source_info.name_stack), + 'true') + + hlo_text = _get_hlo(f)(1.) + self.assertIn('foo/cond/branch_0_fun/false/sub', hlo_text) + self.assertIn('foo/cond/branch_1_fun/true/add', hlo_text) + + def test_vmap_of_cond_should_transform_name_stack(self): + + @extend_name_stack('foo') + @jax.vmap + def f(x): + @extend_name_stack('true') + def true_fn(x): + return x + 1 + @extend_name_stack('false') + def false_fn(x): + return x - 1 + return lax.cond(True, true_fn, false_fn, x) + jaxpr = jax.make_jaxpr(f)(jnp.arange(2)) + self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'foo') + self.assertEqual(str( + jaxpr.eqns[0].params['branches'][0].eqns[0].source_info.name_stack), + 'false') + self.assertEqual(str( + jaxpr.eqns[0].params['branches'][1].eqns[0].source_info.name_stack), + 'true') + + hlo_text = _get_hlo(f)(jnp.arange(2.)) + self.assertIn('foo/vmap(cond)/vmap(branch_0_fun)/vmap(false)/sub', hlo_text) + self.assertIn('foo/vmap(cond)/vmap(branch_1_fun)/vmap(true)/add', hlo_text) + + def test_jvp_of_cond_transforms_name_stack(self): + + @extend_name_stack('foo') + def f(x): + @extend_name_stack('true') + def true_fn(x): + return x + 1 + @extend_name_stack('false') + def false_fn(x): + return x - 1 + return lax.cond(True, true_fn, false_fn, x) + g = lambda x, t: jax.jvp(f, (x,), (t,)) + jaxpr = jax.make_jaxpr(g)(jnp.arange(2.), jnp.ones(2)) + self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'jvp(foo)') + self.assertEqual(str( + jaxpr.eqns[0].params['branches'][0].eqns[0].source_info.name_stack), + 'false') + self.assertEqual(str( + jaxpr.eqns[0].params['branches'][1].eqns[0].source_info.name_stack), + 'true') + + hlo_text = _get_hlo(g)(jnp.arange(2.), jnp.ones(2)) + self.assertIn('jvp(foo)/jvp(cond)/jvp(branch_0_fun)/jvp(false)/sub', hlo_text) + self.assertIn('jvp(foo)/jvp(cond)/jvp(branch_1_fun)/jvp(true)/add', hlo_text) + + def test_vmap_of_jvp_of_cond_transforms_name_stack(self): + + @extend_name_stack('foo') + def f(x): + @extend_name_stack('true') + def true_fn(x): + return x + 1 + @extend_name_stack('false') + def false_fn(x): + return x - 1 + return lax.cond(True, true_fn, false_fn, x) + g = jax.vmap(lambda x, t: jax.jvp(f, (x,), (t,))) + jaxpr = jax.make_jaxpr(g)(jnp.arange(2.), jnp.ones(2)) + self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'vmap(jvp(foo))') + self.assertEqual(str( + jaxpr.eqns[0].params['branches'][0].eqns[0].source_info.name_stack), + 'false') + self.assertEqual(str( + jaxpr.eqns[0].params['branches'][1].eqns[0].source_info.name_stack), + 'true') + + hlo_text = _get_hlo(g)(jnp.arange(2.), jnp.ones(2)) + self.assertIn( + 'vmap(jvp(foo))/vmap(jvp(cond))/vmap(jvp(branch_0_fun))/vmap(jvp(false))/sub', + hlo_text) + self.assertIn( + 'vmap(jvp(foo))/vmap(jvp(cond))/vmap(jvp(branch_1_fun))/vmap(jvp(true))/add', + hlo_text) + + def test_grad_of_cond_transforms_name_stack(self): + + @jax.grad + @extend_name_stack('foo') + def f(x): + @extend_name_stack('true') + def true_fn(x): + return x * 2. + @extend_name_stack('false') + def false_fn(x): + return x / 2. + return lax.cond(True, true_fn, false_fn, x) + jaxpr = jax.make_jaxpr(f)(1.) + self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'jvp(foo)') + self.assertEqual(str(jaxpr.eqns[1].source_info.name_stack), + 'transpose(jvp(foo))') + + hlo_text = _get_hlo(f)(1.) + self.assertIn( + 'jvp(foo)/jvp(cond)/jvp(branch_0_fun)/jvp(false)/div', + hlo_text) + self.assertIn( + 'jvp(foo)/jvp(cond)/jvp(branch_1_fun)/jvp(true)/mul', + hlo_text) + self.assertIn( + 'transpose(jvp(foo))/transpose(jvp(cond))/transpose(jvp(branch_0_fun))/transpose(jvp(false))/div', + hlo_text) + self.assertIn( + 'transpose(jvp(foo))/transpose(jvp(cond))/transpose(jvp(branch_1_fun))/transpose(jvp(true))/mul', + hlo_text) + + def test_vmap_of_grad_of_cond_transforms_name_stack(self): + + @jax.vmap + @jax.grad + @extend_name_stack('foo') + def f(x): + @extend_name_stack('true') + def true_fn(x): + return x * 2. + @extend_name_stack('false') + def false_fn(x): + return x / 2. + return lax.cond(True, true_fn, false_fn, x) + jaxpr = jax.make_jaxpr(f)(jnp.arange(2.)) + self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'vmap(jvp(foo))') + self.assertEqual(str(jaxpr.eqns[1].source_info.name_stack), + 'vmap(transpose(jvp(foo)))') + + hlo_text = _get_hlo(f)(jnp.arange(2.)) + self.assertIn( + 'vmap(jvp(foo))/vmap(jvp(cond))/vmap(jvp(branch_0_fun))/vmap(jvp(false))/div', + hlo_text) + self.assertIn( + 'vmap(jvp(foo))/vmap(jvp(cond))/vmap(jvp(branch_1_fun))/vmap(jvp(true))/mul', + hlo_text) + self.assertIn( + 'vmap(transpose(jvp(foo)))/vmap(transpose(jvp(cond)))/vmap(transpose(jvp(branch_0_fun)))/vmap(transpose(jvp(false)))/div', + hlo_text) + self.assertIn( + 'vmap(transpose(jvp(foo)))/vmap(transpose(jvp(cond)))/vmap(transpose(jvp(branch_1_fun)))/vmap(transpose(jvp(true)))/mul', + hlo_text) + + def test_scan_body_should_not_have_name_stack(self): + + @extend_name_stack('foo') + def f(x): + @extend_name_stack('scan_body') + def body(carry, x): + return carry + x, carry + x + return lax.scan(body, x, jnp.arange(5.)) + jaxpr = jax.make_jaxpr(f)(1.) + self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'foo') + self.assertEqual(str( + jaxpr.eqns[1].params['jaxpr'].eqns[0].source_info.name_stack), + 'scan_body') + + hlo_text = _get_hlo(f)(1.) + self.assertIn('foo/while/body/scan_body', hlo_text) + + def test_vmap_of_scan_should_transform_stack(self): + + @jax.vmap + @extend_name_stack('foo') + def f(x): + @extend_name_stack('scan_body') + def body(carry, x): + return carry + x, carry + x + return lax.scan(body, x, jnp.arange(8.)) + jaxpr = jax.make_jaxpr(f)(jnp.arange(2.)) + self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'vmap(foo)') + self.assertEqual(str( + jaxpr.eqns[1].params['jaxpr'].eqns[0].source_info.name_stack), + 'scan_body') + + hlo_text = _get_hlo(f)(jnp.arange(2.)) + self.assertIn('vmap(foo)/vmap(while)/vmap(body)/vmap(scan_body)/add', hlo_text) + + def test_jvp_of_scan_should_transform_stack(self): + + @extend_name_stack('foo') + def f(x): + @extend_name_stack('scan_body') + def body(carry, x): + return carry + x, carry + x + return lax.scan(body, x, jnp.arange(8.)) + g = lambda x, t: jax.jvp(f, (x,), (t,)) + jaxpr = jax.make_jaxpr(g)(1., 1.) + self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'jvp(foo)') + self.assertEqual(str( + jaxpr.eqns[1].params['jaxpr'].eqns[0].source_info.name_stack), + 'scan_body') + + hlo_text = _get_hlo(g)(1., 1.) + self.assertIn('jvp(foo)/jvp(while)/jvp(body)/jvp(scan_body)/add', hlo_text) + + def test_grad_of_scan_should_transform_stack(self): + + @jax.grad + @extend_name_stack('foo') + def f(x): + @extend_name_stack('scan_body') + def body(carry, x): + return carry * x, carry + x + return lax.scan(body, x, jnp.arange(8.))[0] + jaxpr = jax.make_jaxpr(f)(1.) + self.assertEqual(str(jaxpr.eqns[1].source_info.name_stack), 'jvp(foo)') + self.assertEqual(str(jaxpr.eqns[3].source_info.name_stack), + 'transpose(jvp(foo))') + self.assertEqual(str( + jaxpr.eqns[1].params['jaxpr'].eqns[0].source_info.name_stack), + 'scan_body') + + hlo_text = _get_hlo(f)(1.) + self.assertIn('jvp(foo)/jvp(while)/jvp(body)/jvp(scan_body)/mul', hlo_text) + self.assertIn('transpose(jvp(foo))/transpose(jvp(while))/transpose(jvp(body))/transpose(jvp(scan_body))/mul', hlo_text) + + def test_vmap_of_grad_of_scan_should_transform_stack(self): + + @jax.vmap + @jax.grad + @extend_name_stack('foo') + def f(x): + @extend_name_stack('scan_body') + def body(carry, x): + return carry * x, carry + x + return lax.scan(body, x, jnp.arange(8.))[0] + jaxpr = jax.make_jaxpr(f)(jnp.arange(2.)) + self.assertEqual(str(jaxpr.eqns[1].source_info.name_stack), 'vmap(jvp(foo))') + self.assertEqual(str(jaxpr.eqns[3].source_info.name_stack), + 'vmap(transpose(jvp(foo)))') + self.assertEqual(str( + jaxpr.eqns[1].params['jaxpr'].eqns[0].source_info.name_stack), + 'scan_body') + + hlo_text = _get_hlo(f)(jnp.arange(2.)) + self.assertIn('vmap(jvp(foo))/vmap(jvp(while))/vmap(jvp(body))/vmap(jvp(scan_body))/mul', hlo_text) + self.assertIn('vmap(transpose(jvp(foo)))/vmap(transpose(jvp(while)))/vmap(transpose(jvp(body)))/vmap(transpose(jvp(scan_body)))/mul', hlo_text) + +if __name__ == '__main__': + absltest.main(testLoader=jtu.JaxTestLoader())