Skip to content

Commit

Permalink
Add separate mechanism for threading name stacks to the lowering
Browse files Browse the repository at this point in the history
  • Loading branch information
sharadmv committed Feb 23, 2022
1 parent e96b91d commit 1b79caa
Show file tree
Hide file tree
Showing 24 changed files with 937 additions and 142 deletions.
2 changes: 1 addition & 1 deletion jax/_src/ad_checkpoint.py
Expand Up @@ -353,7 +353,7 @@ def transposed(*args):
primal_fun = lu.wrap_init(partial(core.eval_jaxpr, jaxpr, ()))
tangent_jaxpr, _, consts = pe.trace_to_jaxpr(primal_fun, in_pvals, False)
dummy_args = [ad.UndefinedPrimal(v.aval) for v in tangent_jaxpr.invars]
in_cts_ = ad.backward_pass(tangent_jaxpr, reduce_axes, consts, dummy_args,
in_cts_ = ad.backward_pass(tangent_jaxpr, reduce_axes, False, consts, dummy_args,
out_cts)
in_cts, cell.treedef = tree_flatten(in_cts_)
return in_cts
Expand Down
13 changes: 8 additions & 5 deletions jax/_src/api.py
Expand Up @@ -56,10 +56,11 @@
Partial, PyTreeDef, all_leaves)
from jax._src.tree_util import broadcast_prefix
from jax._src.util import (unzip2, curry, safe_map, safe_zip, prod, split_list,
extend_name_stack, wrap_name, cache, wraps,
extend_name_stack, new_name_stack, wrap_name, cache, wraps,
HashableFunction)
from jax._src import device_array
from jax._src import dispatch
from jax._src import source_info_util
from jax._src.lib import jax_jit
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
Expand Down Expand Up @@ -895,9 +896,8 @@ def computation_maker(*args, **kwargs):
should_tuple = tuple_args if tuple_args is not None else (len(avals) > 100)
xla_args, donated_invars = xla._xla_callable_args(
c, avals, should_tuple, partitions=in_parts_flat, donated_invars=donated_invars)
ctx = xla.TranslationContext(
c, backend, axis_env_,
extend_name_stack(wrap_name(fun_name, "xla_computation")))
name_stack = new_name_stack(wrap_name(fun_name, "xla_computation"))
ctx = xla.TranslationContext(c, backend, axis_env_, name_stack)
out_nodes = xla.jaxpr_subcomp(ctx, jaxpr, xla_consts, *xla_args)
build_out_tuple = partial(xc.ops.Tuple, c, out_nodes)
if out_parts is not None:
Expand Down Expand Up @@ -2615,7 +2615,7 @@ def transposed_fun(consts, out_cotangent):
dummies = [ad.UndefinedPrimal(a) for a in in_avals]
in_cotangents = map(
ad.instantiate_zeros,
ad.backward_pass(jaxpr, reduce_axes, consts, dummies, out_cotangents))
ad.backward_pass(jaxpr, reduce_axes, True, consts, dummies, out_cotangents))
return tree_unflatten(in_tree, in_cotangents)

# Ensure that transposed_fun is a PyTree
Expand Down Expand Up @@ -3197,6 +3197,9 @@ def named_call(

_, in_tree = tree_flatten(())

if config.jax_experimental_name_stack:
return source_info_util.extend_name_stack(name)(fun)

@functools.wraps(fun)
def named_call_f(*args, **kwargs):
lu_f = lu.wrap_init(lambda: fun(*args, **kwargs))
Expand Down
6 changes: 5 additions & 1 deletion jax/_src/config.py
Expand Up @@ -140,7 +140,6 @@ def config_with_absl(self):
for name, val in self.values.items():
flag_type, meta_args, meta_kwargs = self.meta[name]
absl_defs[flag_type](name, val, *meta_args, **meta_kwargs)

app.call_after_init(lambda: self.complete_absl_config(absl_flags))

def complete_absl_config(self, absl_flags):
Expand Down Expand Up @@ -688,6 +687,11 @@ def _update_disable_jit_thread_local(val):
help=('Enables experimental features for staging out computations with '
'dynamic shapes.'))

# Opt-in switch (off by default) for the context-manager-based name stack.
# When enabled, `jax._src.api.named_call` wraps the function with
# `source_info_util.extend_name_stack(name)`, and the `new_name_stack` /
# `extend_name_stack` helpers in `jax._src.util` build and extend
# `source_info_util.NameStack` objects instead of '/'-terminated strings.
config.define_bool_state(
name='jax_experimental_name_stack',
default=False,
help='Enable using the context manager-based name stack.')

# This flag is temporary during rollout of the remat barrier.
# TODO(parkers): Remove if there are no complaints.
config.define_bool_state(
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/custom_derivatives.py
Expand Up @@ -407,7 +407,7 @@ def _custom_jvp_call_jaxpr_transpose(reduce_axes, cts, *args, fun_jaxpr,
jvp_jaxpr_thunk, num_consts):
del jvp_jaxpr_thunk, num_consts
return ad.backward_pass(
fun_jaxpr.jaxpr, reduce_axes, fun_jaxpr.consts, args, cts)
fun_jaxpr.jaxpr, reduce_axes, False, fun_jaxpr.consts, args, cts)
ad.reducing_transposes[custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr_transpose

def custom_jvp_jaxpr_custom_partial_eval_rule(
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/dispatch.py
Expand Up @@ -250,7 +250,7 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
# pass long arg lists as tuple for TPU
tuple_args = len(abstract_args) > 100
axis_env = xla.AxisEnv(nreps, (), ())
name_stack = xla.extend_name_stack(xla.wrap_name(name, 'jit'))
name_stack = xla.new_name_stack(xla.wrap_name(name, 'jit'))
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
module: Union[str, xc.XlaComputation]
module_name = f"jit_{fun.__name__}"
Expand Down
26 changes: 16 additions & 10 deletions jax/_src/lax/control_flow.py
Expand Up @@ -347,8 +347,9 @@ def _while_loop_translation_rule(ctx, avals_in, avals_out, *args, cond_jaxpr,
cond_carry = xla.parameter(cond_c, 0, c.get_shape(init_carry))
cond_carry_elts = [xops.GetTupleElement(cond_carry, i) for i in range(len(args))]
x, _, z = split_list(cond_carry_elts, [cond_nconsts, body_nconsts])
name_stack = extend_name_stack(ctx.name_stack, 'while')
cond_ctx = ctx.replace(builder=cond_c,
name_stack=extend_name_stack(ctx.name_stack, 'cond'))
name_stack=extend_name_stack(name_stack, 'cond'))
pred, = xla.jaxpr_subcomp(
cond_ctx, cond_jaxpr.jaxpr,
_map(partial(xla.pyval_to_ir_constant, cond_c), cond_jaxpr.consts),
Expand All @@ -365,14 +366,14 @@ def _while_loop_translation_rule(ctx, avals_in, avals_out, *args, cond_jaxpr,
body_carry_elts = [xops.GetTupleElement(body_carry, i) for i in range(len(args))]
x, y, z = split_list(body_carry_elts, [cond_nconsts, body_nconsts])
body_ctx = ctx.replace(builder=body_c,
name_stack=extend_name_stack(ctx.name_stack, 'body'))
name_stack=extend_name_stack(name_stack, 'body'))
new_z = xla.jaxpr_subcomp(
body_ctx, body_jaxpr.jaxpr,
_map(partial(xla.pyval_to_ir_constant, body_c), body_jaxpr.consts),
*(y + z))
if batched:
body_pred_ctx = body_ctx.replace(
name_stack=extend_name_stack(ctx.name_stack, 'body_pred'))
name_stack=extend_name_stack(name_stack, 'body_pred'))
body_pred, = xla.jaxpr_subcomp(
body_pred_ctx, cond_jaxpr.jaxpr,
_map(partial(xla.pyval_to_ir_constant, body_c), cond_jaxpr.consts),
Expand Down Expand Up @@ -1201,9 +1202,11 @@ def _cond_partial_eval(trace, *tracers, branches, linear):

linear_2 = (False,) * num_res + linear
params = dict(branches=branches_2, linear=linear_2)
name_stack = source_info_util.current_name_stack()[len(trace.name_stack):]
source = source_info_util.current().replace(name_stack=name_stack)
eqn = pe.new_eqn_recipe(
[index_tracer] + res_tracers + ops_tracers, out_tracers, cond_p, params,
source_info_util.current())
source)
for t in out_tracers: t.recipe = eqn
return out_tracers

Expand Down Expand Up @@ -1297,7 +1300,7 @@ def transposed(*args):
res, cts_out = split_list(args, [num_res])
primals = res + [ad.UndefinedPrimal(aval) for aval in primal_avals]
cts_in = ad.backward_pass(
jaxpr.jaxpr, reduce_axes, jaxpr.consts, primals, cts_out)
jaxpr.jaxpr, reduce_axes, False, jaxpr.consts, primals, cts_out)
_, cts_in = split_list(cts_in, [num_res])
return _map(ad.instantiate_zeros_aval, primal_avals, cts_in)

Expand Down Expand Up @@ -1924,9 +1927,10 @@ def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry,
for uk, t in zip(unknowns[:num_consts], tracers[:num_consts])]
other_pvals = [pe.PartialVal.unknown(a) for a in jaxpr_1.in_avals[num_consts:]]
in_pvals_1 = invariant_pvals + other_pvals
jaxpr_1_opt, out_pvals_1, consts_1 = pe.trace_to_jaxpr(
lu.wrap_init(core.jaxpr_as_fun(jaxpr_1)), in_pvals_1,
instantiate=[True] * (num_carry + num_ys) + [False] * num_res)
with source_info_util.reset_name_stack():
jaxpr_1_opt, out_pvals_1, consts_1 = pe.trace_to_jaxpr(
lu.wrap_init(core.jaxpr_as_fun(jaxpr_1)), in_pvals_1,
instantiate=[True] * (num_carry + num_ys) + [False] * num_res)
jaxpr_1_opt = pe.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr_1_opt), ())
num_consts_1 = num_consts + len(consts_1)
# any now-known residuals are intensive, so we want to revise jaxpr_2 to take
Expand Down Expand Up @@ -1990,6 +1994,8 @@ def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry,
ext_res_tracers = _map(trace.new_instantiated_const, extensive_residuals)
out_tracers = [pe.JaxprTracer(trace, pe.PartialVal((pv, const)), None)
for pv, const in zip(out_pvs, out_consts)]
name_stack = source_info_util.current_name_stack()[len(trace.name_stack):]
source = source_info_util.current().replace(name_stack=name_stack)
linear_2 = ([False] * len(int_res_tracers) +
[lin or not uk for uk, lin in zip(unknowns, linear)] +
[False] * len(ext_res_tracers))
Expand All @@ -1999,7 +2005,7 @@ def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry,
num_consts=num_consts_2,
num_carry=num_carry, linear=tuple(linear_2),
unroll=unroll),
source_info_util.current())
source)
for t in out_tracers: t.recipe = eqn
return out_tracers

Expand Down Expand Up @@ -2068,7 +2074,7 @@ def transposed(*res1_cbar_bbar_res2):
res1_cbar_bbar_res2, [num_res1, num_c, num_b])
primals = (res1 + [ad.UndefinedPrimal(aval) for aval in c_avals] +
[ad.UndefinedPrimal(aval) for aval in a_avals] + res2)
cbar_abar = ad.backward_pass(jaxpr.jaxpr, reduce_axes, jaxpr.consts,
cbar_abar = ad.backward_pass(jaxpr.jaxpr, reduce_axes, False, jaxpr.consts,
primals, b_bar)
_, new_c_bar, a_bar, _ = split_list(cbar_abar, [num_res1, num_c, num_a])
a_bar = _map(ad.instantiate_zeros_aval, a_avals, a_bar)
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/lax/lax.py
Expand Up @@ -50,7 +50,7 @@
import jax._src.pretty_printer as pp
from jax._src import util
from jax._src.util import (cache, safe_zip, prod, safe_map, canonicalize_axis,
split_list)
split_list, new_name_stack)
from jax.tree_util import tree_map
import jax._src.lib
from jax._src.lib import pytree
Expand Down Expand Up @@ -3424,7 +3424,7 @@ def _reduction_computation(ctx, jaxpr, consts, init_values, singleton=True):
subc = xc.XlaBuilder("reduction_computation")
assert len(consts) == 0, "Reduction computations cannot have constants"
args = [xla.parameter(subc, i, shape) for i, shape in enumerate(shapes)]
ctx = xla.TranslationContext(subc, platform, axis_env, '')
ctx = xla.TranslationContext(subc, platform, axis_env, new_name_stack())
out_nodes = xla.jaxpr_subcomp(ctx, jaxpr, consts, *args)
if singleton:
return subc.build(out_nodes[0])
Expand Down
114 changes: 103 additions & 11 deletions jax/_src/source_info_util.py
Expand Up @@ -13,12 +13,13 @@
# limitations under the License.

import contextlib
import dataclasses
import functools
import itertools
import os.path
import threading
import types
from typing import Optional, Iterator, NamedTuple
from typing import Optional, Iterator, NamedTuple, Union, Tuple

import jax.version
from jax._src.lib import xla_client, xla_extension_version
Expand All @@ -40,15 +41,66 @@ class Frame(NamedTuple):
def register_exclusion(path):
_exclude_paths.append(path)

class Scope(NamedTuple):
  """Name-stack entry that contributes one path component."""
  name: str

  def wrap(self, stack: Tuple[str, ...]) -> Tuple[str, ...]:
    """Return `stack` with this scope's name prepended as a new component."""
    return (self.name, *stack)

class Transform(NamedTuple):
  """Name-stack entry that decorates the components that follow it."""
  name: str

  def wrap(self, stack: Tuple[str, ...]) -> Tuple[str, ...]:
    """Return `stack` with every component rewritten as `name(component)`."""
    return tuple(f'{self.name}({component})' for component in stack)

@dataclasses.dataclass(frozen=True)
class NameStack:
  """Immutable sequence of `Scope` / `Transform` entries.

  `str(name_stack)` renders the entries as a '/'-separated path. Entries are
  folded from the innermost (last) to the outermost (first), so a `Transform`
  wraps every component that comes after it, e.g. the stack
  `(Scope('a'), Transform('jvp'), Scope('b'))` renders as `'a/jvp(b)'`.
  """
  # Outermost entry first.
  stack: Tuple[Union[Scope, Transform], ...] = ()

  def extend(self, name: Union[Tuple[str, ...], str]) -> 'NameStack':
    """Return a new stack with one `Scope` appended per given name.

    Args:
      name: a single name, or a tuple of names appended in order.
    """
    names = name if isinstance(name, tuple) else (name,)
    return NameStack(self.stack + tuple(Scope(n) for n in names))

  def wrap_name(self, name: str) -> str:
    """Return `name` prefixed by the rendered stack (unchanged if empty)."""
    if not self.stack:
      return name
    return f'{str(self)}/{name}'

  def transform(self, transform_name: str) -> 'NameStack':
    """Return a new stack with a `Transform(transform_name)` appended."""
    return NameStack((*self.stack, Transform(transform_name)))

  def __getitem__(self, idx) -> 'NameStack':
    """Return the selected entries as a new `NameStack`.

    `idx` may be a slice or an integer. An integer selects a single entry,
    which is returned as a one-element `NameStack`.
    """
    selected = self.stack[idx]
    # Scope and Transform are themselves tuples, so the result type cannot be
    # used to tell "one entry" from "a tuple of entries"; look at the index.
    # Without this, `ns[0]` would store the entry's *fields* as the stack.
    if not isinstance(idx, slice):
      selected = (selected,)
    return NameStack(selected)

  def __len__(self):
    return len(self.stack)

  def __add__(self, other: 'NameStack') -> 'NameStack':
    """Concatenate two stacks: `self` entries first, then `other` entries."""
    if not isinstance(other, NameStack):
      return NotImplemented
    return NameStack(self.stack + other.stack)

  def __radd__(self, other: 'NameStack') -> 'NameStack':
    """Reflected concatenation: `other` entries first, then `self` entries."""
    if not isinstance(other, NameStack):
      return NotImplemented
    return NameStack(other.stack + self.stack)

  def __str__(self) -> str:
    components: Tuple[str, ...] = ()
    # Fold from the innermost entry outwards so each entry sees (and may
    # rewrite) the already-rendered components that follow it.
    for entry in reversed(self.stack):
      components = entry.wrap(components)
    return '/'.join(components)

class SourceInfo(NamedTuple):
traceback: Optional[Traceback]
name_stack: NameStack

def replace(self, *, traceback: Optional[Traceback] = None) -> 'SourceInfo':
def replace(self, *, traceback: Optional[Traceback] = None,
name_stack: Optional[NameStack] = None) -> 'SourceInfo':
traceback = traceback or self.traceback
return self._replace(traceback=traceback)
name_stack = self.name_stack if name_stack is None else name_stack
return self._replace(traceback=traceback, name_stack=name_stack)

def new_source_info() -> SourceInfo:
return SourceInfo(None)
return SourceInfo(None, NameStack())

def is_user_filename(filename: str) -> bool:
"""Heuristic that guesses the identity of the user's code in a stack trace."""
Expand Down Expand Up @@ -97,11 +149,10 @@ def __init__(self):
_source_info_context = _SourceInfoContext()

def current() -> SourceInfo:
context = _source_info_context.context
if not context.traceback:
return context.replace(traceback=xla_client.Traceback.get_traceback())
return context

source_info = _source_info_context.context
if not source_info.traceback:
source_info = source_info.replace(traceback=xla_client.Traceback.get_traceback())
return source_info

class JaxStackTraceBeforeTransformation(Exception): pass

Expand All @@ -118,9 +169,10 @@ def has_user_context(e):
return False

@contextlib.contextmanager
def user_context(c: Optional[Traceback]):
def user_context(c: Optional[Traceback], *, name_stack: Optional[NameStack] = None):
prev = _source_info_context.context
_source_info_context.context = _source_info_context.context.replace(traceback=c)
_source_info_context.context = _source_info_context.context.replace(
traceback=c, name_stack=name_stack)
filtered_tb = None
try:
yield
Expand All @@ -141,3 +193,43 @@ def user_context(c: Optional[Traceback]):
finally:
_source_info_context.context = prev
del filtered_tb

def current_name_stack() -> NameStack:
  """Return the name stack held by the active source-info context."""
  active_context = _source_info_context.context
  return active_context.name_stack

@contextlib.contextmanager
def extend_name_stack(name: str) -> Iterator[NameStack]:
  """Run the managed body with `name` appended to the current name stack.

  Swaps the active source-info context for a copy whose name stack has been
  extended with `name`, yields that extended stack, and puts the previous
  context back when the body finishes, whether or not it raised.
  """
  saved = _source_info_context.context
  extended = saved.name_stack.extend(name)
  _source_info_context.context = saved.replace(name_stack=extended)
  try:
    yield _source_info_context.context.name_stack
  finally:
    _source_info_context.context = saved

@contextlib.contextmanager
def set_name_stack(name_stack: NameStack) -> Iterator[None]:
  """Run the managed body with `name_stack` installed as the current stack.

  The previous source-info context is restored on exit, whether or not the
  body raised.
  """
  saved = _source_info_context.context
  _source_info_context.context = saved.replace(name_stack=name_stack)
  try:
    yield
  finally:
    _source_info_context.context = saved

@contextlib.contextmanager
def reset_name_stack() -> Iterator[None]:
  """Run the managed body with an empty name stack installed."""
  empty = NameStack()
  with set_name_stack(empty):
    yield

@contextlib.contextmanager
def transform_name_stack(name: str) -> Iterator[NameStack]:
  """Run the managed body with a transform entry pushed on the name stack.

  Swaps the active source-info context for a copy whose name stack is the
  result of `NameStack.transform(name)`, yields that stack, and puts the
  previous context back when the body finishes, whether or not it raised.
  """
  saved = _source_info_context.context
  transformed = saved.name_stack.transform(name)
  _source_info_context.context = saved.replace(name_stack=transformed)
  try:
    yield _source_info_context.context.name_stack
  finally:
    _source_info_context.context = saved
16 changes: 15 additions & 1 deletion jax/_src/util.py
Expand Up @@ -277,7 +277,21 @@ def get_module_functions(module):
def wrap_name(name, transform_name):
  """Return `name` wrapped in call syntax: `transform_name(name)`."""
  pieces = (transform_name, '(', name, ')')
  return ''.join(pieces)

def extend_name_stack(stack, name=''):
def new_name_stack(name: str = ''):
  """Create a fresh name stack, optionally seeded with `name`.

  With `jax_experimental_name_stack` disabled this is the string
  `name + '/'`. With it enabled this is a `source_info_util.NameStack`,
  extended with `name` when `name` is non-empty.
  """
  if not config.jax_experimental_name_stack:
    return name + '/'
  # Imported here rather than at module level, as in the original.
  from jax._src import source_info_util
  fresh = source_info_util.NameStack()
  return fresh.extend(name) if name else fresh

def extend_name_stack(stack, name: str):
  """Return `stack` extended with `name`.

  With `jax_experimental_name_stack` disabled, `stack` must be a string and
  the result is `stack + name + '/'`. With it enabled, `stack` must be a
  `source_info_util.NameStack` and the result is `stack.extend(name)`.
  """
  if not config.jax_experimental_name_stack:
    assert isinstance(stack, str)
    return stack + name + '/'
  # Imported here rather than at module level, as in the original.
  from jax._src import source_info_util
  assert isinstance(stack, source_info_util.NameStack), stack
  return stack.extend(name)

def canonicalize_axis(axis, num_dims) -> int:
Expand Down

0 comments on commit 1b79caa

Please sign in to comment.