Skip to content

Commit

Permalink
Split name_stack out of mlir.ModuleContext.
Browse files (browse the repository at this point in the history)
A unique name_stack is built for every equation, which means that we're constantly rebuilding ModuleContext objects, even though the lifetime of almost everything else in ModuleContext is (naturally) the module scope. Split name_stack out into a separate object that is threaded explicitly through lowering, including as a field of mlir.LoweringRuleContext.

PiperOrigin-RevId: 608594374
  • Loading branch information
hawkinsp authored and jax authors committed Feb 20, 2024
1 parent 2165611 commit f1ea671
Show file tree
Hide file tree
Showing 13 changed files with 95 additions and 92 deletions.
4 changes: 2 additions & 2 deletions jax/_src/custom_derivatives.py
Expand Up @@ -437,8 +437,8 @@ def _custom_jvp_call_mlir_translation(ctx, *args, call_jaxpr, jvp_jaxpr_thunk,
args_ = map(mlir.wrap_singleton_ir_values, args)
consts = mlir._ir_consts(call_jaxpr.consts)
out, tokens = mlir.jaxpr_subcomp(ctx.module_context, call_jaxpr.jaxpr,
ctx.tokens_in, consts, *args_,
dim_var_values=ctx.dim_var_values)
ctx.name_stack, ctx.tokens_in, consts,
*args_, dim_var_values=ctx.dim_var_values)
ctx.set_tokens_out(tokens)
return out
mlir.register_lowering(custom_jvp_call_p, _custom_jvp_call_mlir_translation)
Expand Down
87 changes: 44 additions & 43 deletions jax/_src/interpreters/mlir.py
Expand Up @@ -340,24 +340,23 @@ def _token_constant_handler(val):
# Source locations

def get_canonical_source_file(file_name: str, caches: TracebackCaches) -> str:
if file_name in caches.canonical_name_cache:
return caches.canonical_name_cache[file_name]
canonical_file_name = caches.canonical_name_cache.get(file_name, None)
if canonical_file_name is not None:
return canonical_file_name

source_file = file_name
pattern = config.hlo_source_file_canonicalization_regex.value
if pattern:
source_file = re.sub(pattern, '', source_file)

caches.canonical_name_cache[file_name] = source_file
return source_file
file_name = re.sub(pattern, '', file_name)
caches.canonical_name_cache[file_name] = file_name
return file_name

def _is_user_file(ctx: ModuleContext, file_name: str) -> bool:
if file_name in ctx.traceback_caches.is_user_file_cache:
return ctx.traceback_caches.is_user_file_cache[file_name]

result = source_info_util.is_user_filename(file_name)
ctx.traceback_caches.is_user_file_cache[file_name] = result
return result
is_user = ctx.traceback_caches.is_user_file_cache.get(file_name, None)
if is_user is not None:
return is_user
out = source_info_util.is_user_filename(file_name)
ctx.traceback_caches.is_user_file_cache[file_name] = out
return out

def _traceback_to_location(ctx: ModuleContext, tb: xc.Traceback) -> ir.Location:
"""Converts a full traceback to a callsite() MLIR location."""
Expand Down Expand Up @@ -386,12 +385,12 @@ def _traceback_to_location(ctx: ModuleContext, tb: xc.Traceback) -> ir.Location:
if len(frame_locs) >= frames_limit:
break

if len(frame_locs) == 0:
n = len(frame_locs)
if n == 0:
return ir.Location.unknown()
elif n == 1:
return frame_locs[0]
else:
if len(frame_locs) == 1:
return frame_locs[0]

return ir.Location.callsite(frame_locs[0], frame_locs[1:])

def _source_info_to_location(
Expand Down Expand Up @@ -589,7 +588,6 @@ class ModuleContext:
backend_or_name: str | xb.XlaBackend | None
platforms: Sequence[str]
axis_context: AxisContext
name_stack: source_info_util.NameStack
keepalives: list[Any]
channel_iterator: Iterator[int]
host_callbacks: list[Any]
Expand All @@ -614,7 +612,6 @@ def __init__(
backend_or_name: str | xb.XlaBackend | None,
platforms: Sequence[str],
axis_context: AxisContext,
name_stack: source_info_util.NameStack,
keepalives: list[Any],
channel_iterator: Iterator[int],
host_callbacks: list[Any],
Expand All @@ -635,7 +632,6 @@ def __init__(
self.backend_or_name = backend_or_name
self.platforms = platforms
self.axis_context = axis_context
self.name_stack = name_stack
self.cached_primitive_lowerings = ({} if cached_primitive_lowerings is None
else cached_primitive_lowerings)
self.traceback_caches = (TracebackCaches() if traceback_caches is None
Expand Down Expand Up @@ -683,6 +679,7 @@ def replace(self, **kw): return dataclasses.replace(self, **kw)
class LoweringRuleContext:
"""Per-rule context information for MLIR lowering."""
module_context: ModuleContext
name_stack: source_info_util.NameStack
primitive: core.Primitive | None
avals_in: Sequence[core.AbstractValue]
avals_out: Any # Usually Sequence[core.AbstractValue], but sometimes None.
Expand Down Expand Up @@ -947,7 +944,6 @@ def lower_jaxpr_to_module(

ctx = ModuleContext(backend_or_name=backend_or_name,
platforms=platforms, axis_context=axis_context,
name_stack=name_stack,
keepalives=keepalives,
channel_iterator=channel_iter,
host_callbacks=host_callbacks,
Expand All @@ -964,7 +960,9 @@ def lower_jaxpr_to_module(
attrs["mhlo.num_partitions"] = i32_attr(num_partitions)
replace_tokens_with_dummy = lowering_parameters.replace_tokens_with_dummy
lower_jaxpr_to_fun(
ctx, "main", jaxpr, ordered_effects, public=True,
ctx, "main", jaxpr, ordered_effects,
name_stack=name_stack,
public=True,
create_tokens=replace_tokens_with_dummy,
replace_tokens_with_dummy=replace_tokens_with_dummy,
num_output_tokens=0,
Expand Down Expand Up @@ -1105,6 +1103,7 @@ def lower_jaxpr_to_fun(
name: str,
jaxpr: core.ClosedJaxpr,
effects: Sequence[core.Effect],
name_stack: source_info_util.NameStack,
*,
create_tokens: bool = False,
public: bool = False,
Expand Down Expand Up @@ -1376,7 +1375,7 @@ def aval_to_types(aval):
dim_var_values, _, _ = util.split_list(flat_args, [num_dim_vars, num_tokens])
# A lowering context just for function body entry/exit code.
entry_lowering_ctx = LoweringRuleContext(
module_context=ctx, primitive=None,
module_context=ctx, name_stack=name_stack, primitive=None,
avals_in=[], avals_out=None,
tokens_in=TokenSet.create([]), tokens_out=None,
axis_size_env=None, dim_var_values=dim_var_values)
Expand All @@ -1403,10 +1402,10 @@ def aval_to_types(aval):
args.append([hlo.create_token()])
else:
args.append(arg)
callee_name_stack = ctx.name_stack.extend(util.wrap_name(name, api_name))
callee_name_stack = name_stack.extend(util.wrap_name(name, api_name))
consts = [ir_constants(xla.canonicalize_dtype(x)) for x in jaxpr.consts]
out_vals, tokens_out = jaxpr_subcomp(
ctx.replace(name_stack=callee_name_stack), jaxpr.jaxpr, tokens_in,
ctx, jaxpr.jaxpr, callee_name_stack, tokens_in,
consts, *args, dim_var_values=dim_var_values)
outs = []
if create_tokens:
Expand Down Expand Up @@ -1496,6 +1495,7 @@ def _emit_lowering_rule_as_fun(lowering_rule,
return func_op

def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
name_stack: source_info_util.NameStack,
tokens: TokenSet,
consts: Sequence[Sequence[ir.Value]],
*args: Sequence[ir.Value],
Expand Down Expand Up @@ -1536,6 +1536,7 @@ def get_override_lowering_rule(primitive: core.Primitive) -> LoweringRule | None

env: dict[core.Var, tuple[ir.Value, ...]] = {}

assert isinstance(name_stack, source_info_util.NameStack), type(name_stack)
assert len(args) == len(jaxpr.invars), (jaxpr, args)
assert len(consts) == len(jaxpr.constvars), (jaxpr, consts)
assert all(isinstance(v, ir.Value) for vs in consts for v in vs), consts
Expand All @@ -1545,9 +1546,8 @@ def get_override_lowering_rule(primitive: core.Primitive) -> LoweringRule | None
last_used = core.last_used(jaxpr)
for eqn in jaxpr.eqns:
in_nodes = map(read, eqn.invars)
assert isinstance(ctx.name_stack, source_info_util.NameStack), type(ctx.name_stack)
source_info = eqn.source_info.replace(
name_stack=ctx.name_stack + eqn.source_info.name_stack)
name_stack=name_stack + eqn.source_info.name_stack)
loc = _source_info_to_location(ctx, eqn.primitive, eqn.params, source_info)
with source_info_util.user_context(eqn.source_info.traceback), loc:
override_rule = get_override_lowering_rule(eqn.primitive)
Expand All @@ -1569,12 +1569,12 @@ def get_override_lowering_rule(primitive: core.Primitive) -> LoweringRule | None
elif eqn.primitive in xla._translations:
default_rule = xla_fallback_lowering(eqn.primitive)

eqn_ctx = ctx.replace(name_stack=source_info.name_stack)
effects = list(effects_lib.ordered_effects.filter_in(eqn.effects))
tokens_in = tokens.subset(effects)
avals_in = map(aval, eqn.invars)
rule_ctx = LoweringRuleContext(
module_context=eqn_ctx, primitive=eqn.primitive,
module_context=ctx, primitive=eqn.primitive,
name_stack=source_info.name_stack,
avals_in=avals_in,
avals_out=map(aval, eqn.outvars), tokens_in=tokens_in,
tokens_out=None, dim_var_values=dim_var_values)
Expand Down Expand Up @@ -1781,15 +1781,16 @@ def f_lowered(ctx, *args, **params):
# TODO(frostig,mattjj): check ctx.avals_out against jaxpr avals out?

out, tokens = jaxpr_subcomp(
ctx.module_context, jaxpr, ctx.tokens_in, _ir_consts(consts),
*map(wrap_singleton_ir_values, args), dim_var_values=ctx.dim_var_values)
ctx.module_context, jaxpr, ctx.name_stack, ctx.tokens_in,
_ir_consts(consts), *map(wrap_singleton_ir_values, args),
dim_var_values=ctx.dim_var_values)
ctx.set_tokens_out(tokens)
return out

return f_lowered


def _lower_jaxpr_to_fun_cached(ctx, fn_name, call_jaxpr, effects,
def _lower_jaxpr_to_fun_cached(ctx, fn_name, call_jaxpr, effects, name_stack,
arg_names=None, result_names=None):
if not call_jaxpr.consts and arg_names is result_names is None:
# Cacheable.
Expand All @@ -1798,12 +1799,12 @@ def _lower_jaxpr_to_fun_cached(ctx, fn_name, call_jaxpr, effects,
func_op = ctx.cached_primitive_lowerings[key]
except KeyError:
func_op = lower_jaxpr_to_fun(
ctx, fn_name, call_jaxpr, effects, arg_names=arg_names,
ctx, fn_name, call_jaxpr, effects, name_stack, arg_names=arg_names,
result_names=result_names)
ctx.cached_primitive_lowerings[key] = func_op
else:
func_op = lower_jaxpr_to_fun(
ctx, fn_name, call_jaxpr, effects, arg_names=arg_names,
ctx, fn_name, call_jaxpr, effects, name_stack, arg_names=arg_names,
result_names=result_names)
return func_op

Expand All @@ -1825,12 +1826,12 @@ def check_backend_matches(inner_backend: str | None,
f"inner-jit backend specification {inner_backend}.")


def _call_lowering(fn_name, stack_name, call_jaxpr, backend,
ctx: ModuleContext, avals_in,
avals_out, tokens_in, *args,
dim_var_values: Sequence[ir.Value],
arg_names=None, result_names=None):
del stack_name, avals_in
def call_lowering(fn_name, name_stack, call_jaxpr, backend,
ctx: ModuleContext, avals_in,
avals_out, tokens_in, *args,
dim_var_values: Sequence[ir.Value],
arg_names=None, result_names=None):
del avals_in
if isinstance(call_jaxpr, core.Jaxpr):
call_jaxpr = pe.close_jaxpr(call_jaxpr)
check_backend_matches(backend, ctx.platforms)
Expand All @@ -1839,7 +1840,7 @@ def _call_lowering(fn_name, stack_name, call_jaxpr, backend,
output_types = [token_type()] * len(effects) + output_types
flat_output_types = util.flatten(output_types)
symbol_name = _lower_jaxpr_to_fun_cached(
ctx, fn_name, call_jaxpr, effects, arg_names=arg_names,
ctx, fn_name, call_jaxpr, effects, name_stack, arg_names=arg_names,
result_names=result_names).name.value
tokens = [tokens_in.get(eff) for eff in effects]
args = (*dim_var_values, *tokens, *args)
Expand All @@ -1853,8 +1854,8 @@ def _call_lowering(fn_name, stack_name, call_jaxpr, backend,

def core_call_lowering(ctx: LoweringRuleContext,
*args, name, backend=None, call_jaxpr):
out_nodes, tokens = _call_lowering(
name, name, call_jaxpr, backend, ctx.module_context,
out_nodes, tokens = call_lowering(
name, ctx.name_stack, call_jaxpr, backend, ctx.module_context,
ctx.avals_in, ctx.avals_out, ctx.tokens_in, *args,
dim_var_values=ctx.dim_var_values)
ctx.set_tokens_out(tokens)
Expand Down
12 changes: 6 additions & 6 deletions jax/_src/interpreters/pxla.py
Expand Up @@ -1433,12 +1433,12 @@ def _pmap_lowering(ctx, *in_nodes, axis_name,

with maybe_extend_axis_env(axis_name, global_axis_size, None): # type: ignore
sub_ctx = ctx.module_context.replace(
axis_context=sharding_impls.ReplicaAxisContext(new_env),
name_stack=ctx.module_context.name_stack.extend(
util.wrap_name(name, 'pmap')))
sharded_outs, _ = mlir.jaxpr_subcomp(sub_ctx, call_jaxpr, mlir.TokenSet(), (),
*in_nodes_sharded,
dim_var_values=ctx.dim_var_values)
axis_context=sharding_impls.ReplicaAxisContext(new_env))
sharded_outs, _ = mlir.jaxpr_subcomp(
sub_ctx, call_jaxpr,
ctx.name_stack.extend(util.wrap_name(name, 'pmap')),
mlir.TokenSet(), (), *in_nodes_sharded,
dim_var_values=ctx.dim_var_values)
out_avals = [v.aval for v in call_jaxpr.outvars]
outs = [_hlo_unshard(ctx, aval, new_env, out_axis, shard)
for aval, out_axis, shard in zip(out_avals, out_axes, sharded_outs)]
Expand Down
8 changes: 3 additions & 5 deletions jax/_src/lax/control_flow/conditionals.py
Expand Up @@ -847,16 +847,14 @@ def _cond_lowering(ctx, index, *args, branches, linear):
# captures.
case_op = hlo.CaseOp(flat_output_types, index=index,
num_branches=len(branches))
name_stack = ctx.module_context.name_stack.extend('cond')
name_stack = ctx.name_stack.extend('cond')
for i, jaxpr in enumerate(branches):
branch = case_op.regions[i].blocks.append()
with ir.InsertionPoint(branch):
sub_ctx = ctx.module_context.replace(
name_stack=name_stack.extend(f'branch_{i}_fun'))
consts = [mlir.ir_constants(xla.canonicalize_dtype(x)) for x in jaxpr.consts]
out_vals, tokens_out = mlir.jaxpr_subcomp(
sub_ctx, jaxpr.jaxpr, tokens_in,
consts, *map(mlir.wrap_singleton_ir_values, args),
ctx.module_context, jaxpr.jaxpr, name_stack.extend(f'branch_{i}_fun'),
tokens_in, consts, *map(mlir.wrap_singleton_ir_values, args),
dim_var_values=ctx.dim_var_values)
out_tokens = [tokens_out.get(eff) for eff in ordered_effects]
out_vals = [*out_tokens, *out_vals]
Expand Down
21 changes: 12 additions & 9 deletions jax/_src/lax/control_flow/loops.py
Expand Up @@ -1662,7 +1662,7 @@ def fun(*args):

# Loop condition
cond_block = while_op.regions[0].blocks.append(*flat_loop_carry_types)
name_stack = ctx.module_context.name_stack.extend('while')
name_stack = ctx.name_stack.extend('while')
with ir.InsertionPoint(cond_block):
flat_cond_args = [
cond_block.arguments[i] for i in range(len(flat_loop_carry_types))
Expand All @@ -1671,13 +1671,14 @@ def fun(*args):
# Remove tokens from cond args
cond_args = cond_args[num_tokens:]
x, _, z = util.split_list(cond_args, [cond_nconsts, body_nconsts])
cond_ctx = ctx.module_context.replace(name_stack=name_stack.extend('cond'))
cond_consts = [
mlir.ir_constants(xla.canonicalize_dtype(x)) for x in cond_jaxpr.consts
]
cond_name_stack = name_stack.extend('cond')
((pred,),), _ = mlir.jaxpr_subcomp(
cond_ctx,
ctx.module_context,
cond_jaxpr.jaxpr,
cond_name_stack,
mlir.TokenSet(),
cond_consts,
*(x + z),
Expand All @@ -1686,6 +1687,7 @@ def fun(*args):
if batched:
pred_ctx = mlir.LoweringRuleContext(
module_context=ctx.module_context,
name_stack=cond_name_stack,
primitive=None,
avals_in=[pred_aval],
avals_out=[pred_aval.update(shape=())],
Expand All @@ -1710,20 +1712,21 @@ def fun(*args):
token_args, body_args = util.split_list(body_args, [num_tokens])
tokens_in = mlir.TokenSet(zip(body_effects, token_args))
x, y, z = util.split_list(body_args, [cond_nconsts, body_nconsts])
body_ctx = ctx.module_context.replace(name_stack=name_stack.extend('body'))
body_name_stack = name_stack.extend('body')
body_consts = [mlir.ir_constants(xla.canonicalize_dtype(x))
for x in body_jaxpr.consts]
new_z, tokens_out = mlir.jaxpr_subcomp(body_ctx, body_jaxpr.jaxpr,
new_z, tokens_out = mlir.jaxpr_subcomp(
ctx.module_context, body_jaxpr.jaxpr, body_name_stack,
tokens_in, body_consts, *(y + z), dim_var_values=ctx.dim_var_values)
out_tokens = [tokens_out.get(eff) for eff in body_effects]
if batched:
body_pred_ctx = ctx.module_context.replace(
name_stack=name_stack.extend('body_pred'))
body_pred_name_stack = name_stack.extend('body_pred')
cond_consts = [mlir.ir_constants(xla.canonicalize_dtype(x))
for x in cond_jaxpr.consts]
((body_pred,),), _ = mlir.jaxpr_subcomp(
body_pred_ctx, cond_jaxpr.jaxpr, mlir.TokenSet(),
cond_consts, *(x + z), dim_var_values=ctx.dim_var_values)
ctx.module_context, cond_jaxpr.jaxpr, body_pred_name_stack,
mlir.TokenSet(), cond_consts, *(x + z),
dim_var_values=ctx.dim_var_values)
new_z = _map(
partial(_pred_bcast_select_hlo, ctx, pred_aval, body_pred), new_z, z,
body_jaxpr.out_avals)
Expand Down
6 changes: 3 additions & 3 deletions jax/_src/lax/lax.py
Expand Up @@ -3813,11 +3813,11 @@ def _reduce_lower(ctx, *values, computation, jaxpr, dimensions):
ir_types = [mlir.aval_to_ir_type(aval) for aval in init_value_avals]
reducer = op.regions[0].blocks.append(*(ir_types + ir_types))
with ir.InsertionPoint(reducer):
reducer_ctx = ctx.module_context.replace(
name_stack=source_info_util.new_name_stack())
name_stack = source_info_util.new_name_stack()
if jaxpr.effects:
raise NotImplementedError('Cannot lower effectful `reduce`.')
out_nodes, _ = mlir.jaxpr_subcomp(reducer_ctx, jaxpr.jaxpr, mlir.TokenSet(),
out_nodes, _ = mlir.jaxpr_subcomp(ctx.module_context, jaxpr.jaxpr,
name_stack, mlir.TokenSet(),
jaxpr.consts,
*([a] for a in reducer.arguments),
dim_var_values=ctx.dim_var_values)
Expand Down
7 changes: 3 additions & 4 deletions jax/_src/lax/slicing.py
Expand Up @@ -2493,13 +2493,12 @@ def _scatter_lower(ctx, operand, indices, updates, *,
scalar_type = mlir.aval_to_ir_type(core.ShapedArray((), aval_out.dtype))
update = op.update_computation.blocks.append(scalar_type, scalar_type)
with ir.InsertionPoint(update):
update_ctx = ctx.module_context.replace(
name_stack=source_info_util.new_name_stack())
name_stack = source_info_util.new_name_stack()
if update_jaxpr.effects:
raise NotImplementedError('Cannot lower effectful `scatter`.')
out_nodes, _ = mlir.jaxpr_subcomp(
update_ctx, update_jaxpr, mlir.TokenSet(), update_consts,
(update.arguments[0],), (update.arguments[1],),
ctx.module_context, update_jaxpr, name_stack, mlir.TokenSet(),
update_consts, (update.arguments[0],), (update.arguments[1],),
dim_var_values=ctx.dim_var_values)
hlo.return_(util.flatten(out_nodes))
return op.results
Expand Down

0 comments on commit f1ea671

Please sign in to comment.