From f1ea67117e8f102412657c513be13047d70491e9 Mon Sep 17 00:00:00 2001
From: Peter Hawkins
Date: Tue, 20 Feb 2024 07:16:38 -0800
Subject: [PATCH] Split name_stack out of mlir.ModuleContext.

A unique name_stack is built for every equation, which means that we are
constantly rebuilding ModuleContext objects even though almost everything
else in the context naturally has module-wide lifetime.

Split name_stack out into a value that is threaded separately, including
as part of mlir.LoweringRuleContext.

PiperOrigin-RevId: 608594374
---
 jax/_src/custom_derivatives.py            |  4 +-
 jax/_src/interpreters/mlir.py             | 87 ++++++++++++-----------
 jax/_src/interpreters/pxla.py             | 12 ++--
 jax/_src/lax/control_flow/conditionals.py |  8 +--
 jax/_src/lax/control_flow/loops.py        | 21 +++---
 jax/_src/lax/lax.py                       |  6 +-
 jax/_src/lax/slicing.py                   |  7 +-
 jax/_src/lax/windowed_reductions.py       |  4 +-
 jax/_src/maps.py                          | 20 +++---
 jax/_src/pjit.py                          |  6 +-
 jax/experimental/export/_export.py        |  4 +-
 jax/experimental/shard_map.py             |  6 +-
 jax/interpreters/mlir.py                  |  2 +-
 13 files changed, 95 insertions(+), 92 deletions(-)

diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py
index 8f60191941eb..2752604444dc 100644
--- a/jax/_src/custom_derivatives.py
+++ b/jax/_src/custom_derivatives.py
@@ -437,8 +437,8 @@ def _custom_jvp_call_mlir_translation(ctx, *args, call_jaxpr, jvp_jaxpr_thunk,
   args_ = map(mlir.wrap_singleton_ir_values, args)
   consts = mlir._ir_consts(call_jaxpr.consts)
   out, tokens = mlir.jaxpr_subcomp(ctx.module_context, call_jaxpr.jaxpr,
-                                   ctx.tokens_in, consts, *args_,
-                                   dim_var_values=ctx.dim_var_values)
+                                   ctx.name_stack, ctx.tokens_in, consts,
+                                   *args_, dim_var_values=ctx.dim_var_values)
   ctx.set_tokens_out(tokens)
   return out
 mlir.register_lowering(custom_jvp_call_p, _custom_jvp_call_mlir_translation)
diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py
index 92e421458b9d..af2629abaaa6 100644
--- a/jax/_src/interpreters/mlir.py
+++ b/jax/_src/interpreters/mlir.py
@@ -340,24 +340,23 @@ def _token_constant_handler(val):
 # Source locations
 
 def get_canonical_source_file(file_name: str, caches: TracebackCaches) -> str:
-  if file_name in caches.canonical_name_cache:
-    return caches.canonical_name_cache[file_name]
+  canonical_file_name = caches.canonical_name_cache.get(file_name, None)
+  if canonical_file_name is not None:
+    return canonical_file_name
 
-  source_file = file_name
   pattern = config.hlo_source_file_canonicalization_regex.value
   if pattern:
-    source_file = re.sub(pattern, '', source_file)
-
-  caches.canonical_name_cache[file_name] = source_file
-  return source_file
+    file_name = re.sub(pattern, '', file_name)
+  caches.canonical_name_cache[file_name] = file_name
+  return file_name
 
 def _is_user_file(ctx: ModuleContext, file_name: str) -> bool:
-  if file_name in ctx.traceback_caches.is_user_file_cache:
-    return ctx.traceback_caches.is_user_file_cache[file_name]
-
-  result = source_info_util.is_user_filename(file_name)
-  ctx.traceback_caches.is_user_file_cache[file_name] = result
-  return result
+  is_user = ctx.traceback_caches.is_user_file_cache.get(file_name, None)
+  if is_user is not None:
+    return is_user
+  out = source_info_util.is_user_filename(file_name)
+  ctx.traceback_caches.is_user_file_cache[file_name] = out
+  return out
 
 def _traceback_to_location(ctx: ModuleContext, tb: xc.Traceback) -> ir.Location:
   """Converts a full traceback to a callsite() MLIR location."""
@@ -386,12 +385,12 @@ def _traceback_to_location(ctx: ModuleContext, tb: xc.Traceback) -> ir.Location
     if len(frame_locs) >= frames_limit:
       break
 
-  if len(frame_locs) == 0:
+  n = len(frame_locs)
+  if n == 0:
     return ir.Location.unknown()
+  elif n == 1:
+    return frame_locs[0]
   else:
-    if len(frame_locs) == 1:
-      return frame_locs[0]
-
     return ir.Location.callsite(frame_locs[0], frame_locs[1:])
 
 def _source_info_to_location(
@@ -589,7 +588,6 @@ class ModuleContext:
   backend_or_name: str | xb.XlaBackend | None
   platforms: Sequence[str]
   axis_context: AxisContext
-  name_stack: source_info_util.NameStack
   keepalives: list[Any]
   channel_iterator: Iterator[int]
   host_callbacks: list[Any]
@@ -614,7 +612,6 @@ def __init__(
       backend_or_name: str | xb.XlaBackend | None,
       platforms: Sequence[str],
       axis_context: AxisContext,
-      name_stack: source_info_util.NameStack,
       keepalives: list[Any],
       channel_iterator: Iterator[int],
       host_callbacks: list[Any],
@@ -635,7 +632,6 @@ def __init__(
     self.backend_or_name = backend_or_name
     self.platforms = platforms
     self.axis_context = axis_context
-    self.name_stack = name_stack
     self.cached_primitive_lowerings = ({} if cached_primitive_lowerings is None
                                        else cached_primitive_lowerings)
     self.traceback_caches = (TracebackCaches() if traceback_caches is None
@@ -683,6 +679,7 @@ def replace(self, **kw): return dataclasses.replace(self, **kw)
 class LoweringRuleContext:
   """Per-rule context information for MLIR lowering."""
   module_context: ModuleContext
+  name_stack: source_info_util.NameStack
   primitive: core.Primitive | None
   avals_in: Sequence[core.AbstractValue]
   avals_out: Any  # Usually Sequence[core.AbstractValue], but sometimes None.
@@ -947,7 +944,6 @@ def lower_jaxpr_to_module(
   ctx = ModuleContext(backend_or_name=backend_or_name,
                       platforms=platforms, axis_context=axis_context,
-                      name_stack=name_stack,
                       keepalives=keepalives, channel_iterator=channel_iter,
                       host_callbacks=host_callbacks,
@@ -964,7 +960,9 @@ def lower_jaxpr_to_module(
     attrs["mhlo.num_partitions"] = i32_attr(num_partitions)
   replace_tokens_with_dummy = lowering_parameters.replace_tokens_with_dummy
   lower_jaxpr_to_fun(
-      ctx, "main", jaxpr, ordered_effects, public=True,
+      ctx, "main", jaxpr, ordered_effects,
+      name_stack=name_stack,
+      public=True,
       create_tokens=replace_tokens_with_dummy,
       replace_tokens_with_dummy=replace_tokens_with_dummy,
       num_output_tokens=0,
@@ -1105,6 +1103,7 @@ def lower_jaxpr_to_fun(
     name: str,
     jaxpr: core.ClosedJaxpr,
     effects: Sequence[core.Effect],
+    name_stack: source_info_util.NameStack,
     *,
     create_tokens: bool = False,
     public: bool = False,
@@ -1376,7 +1375,7 @@ def aval_to_types(aval):
     dim_var_values, _, _ = util.split_list(flat_args, [num_dim_vars, num_tokens])
     # A lowering context just for function body entry/exit code.
     entry_lowering_ctx = LoweringRuleContext(
-        module_context=ctx, primitive=None,
+        module_context=ctx, name_stack=name_stack, primitive=None,
         avals_in=[], avals_out=None,
         tokens_in=TokenSet.create([]), tokens_out=None,
         axis_size_env=None, dim_var_values=dim_var_values)
@@ -1403,10 +1402,10 @@ def aval_to_types(aval):
         args.append([hlo.create_token()])
       else:
         args.append(arg)
-    callee_name_stack = ctx.name_stack.extend(util.wrap_name(name, api_name))
+    callee_name_stack = name_stack.extend(util.wrap_name(name, api_name))
     consts = [ir_constants(xla.canonicalize_dtype(x)) for x in jaxpr.consts]
     out_vals, tokens_out = jaxpr_subcomp(
-        ctx.replace(name_stack=callee_name_stack), jaxpr.jaxpr, tokens_in,
+        ctx, jaxpr.jaxpr, callee_name_stack, tokens_in,
         consts, *args, dim_var_values=dim_var_values)
     outs = []
     if create_tokens:
@@ -1496,6 +1495,7 @@ def _emit_lowering_rule_as_fun(lowering_rule,
   return func_op
 
 def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
+                  name_stack: source_info_util.NameStack,
                   tokens: TokenSet,
                   consts: Sequence[Sequence[ir.Value]],
                   *args: Sequence[ir.Value],
@@ -1536,6 +1536,7 @@ def get_override_lowering_rule(primitive: core.Primitive) -> LoweringRule | None
 
   env: dict[core.Var, tuple[ir.Value, ...]] = {}
 
+  assert isinstance(name_stack, source_info_util.NameStack), type(name_stack)
   assert len(args) == len(jaxpr.invars), (jaxpr, args)
   assert len(consts) == len(jaxpr.constvars), (jaxpr, consts)
   assert all(isinstance(v, ir.Value) for vs in consts for v in vs), consts
@@ -1545,9 +1546,8 @@ def get_override_lowering_rule(primitive: core.Primitive) -> LoweringRule | None
   last_used = core.last_used(jaxpr)
   for eqn in jaxpr.eqns:
     in_nodes = map(read, eqn.invars)
-    assert isinstance(ctx.name_stack, source_info_util.NameStack), type(ctx.name_stack)
     source_info = eqn.source_info.replace(
-        name_stack=ctx.name_stack + eqn.source_info.name_stack)
+        name_stack=name_stack + eqn.source_info.name_stack)
     loc = _source_info_to_location(ctx, eqn.primitive, eqn.params, source_info)
     with source_info_util.user_context(eqn.source_info.traceback), loc:
       override_rule = get_override_lowering_rule(eqn.primitive)
@@ -1569,12 +1569,12 @@ def get_override_lowering_rule(primitive: core.Primitive) -> LoweringRule | None
       elif eqn.primitive in xla._translations:
         default_rule = xla_fallback_lowering(eqn.primitive)
 
-      eqn_ctx = ctx.replace(name_stack=source_info.name_stack)
       effects = list(effects_lib.ordered_effects.filter_in(eqn.effects))
       tokens_in = tokens.subset(effects)
       avals_in = map(aval, eqn.invars)
       rule_ctx = LoweringRuleContext(
-          module_context=eqn_ctx, primitive=eqn.primitive,
+          module_context=ctx, primitive=eqn.primitive,
+          name_stack=source_info.name_stack,
           avals_in=avals_in, avals_out=map(aval, eqn.outvars),
           tokens_in=tokens_in, tokens_out=None,
           dim_var_values=dim_var_values)
@@ -1781,15 +1781,16 @@ def f_lowered(ctx, *args, **params):
     # TODO(frostig,mattjj): check ctx.avals_out against jaxpr avals out?
     out, tokens = jaxpr_subcomp(
-        ctx.module_context, jaxpr, ctx.tokens_in, _ir_consts(consts),
-        *map(wrap_singleton_ir_values, args), dim_var_values=ctx.dim_var_values)
+        ctx.module_context, jaxpr, ctx.name_stack, ctx.tokens_in,
+        _ir_consts(consts), *map(wrap_singleton_ir_values, args),
+        dim_var_values=ctx.dim_var_values)
     ctx.set_tokens_out(tokens)
     return out
   return f_lowered
 
 
-def _lower_jaxpr_to_fun_cached(ctx, fn_name, call_jaxpr, effects,
+def _lower_jaxpr_to_fun_cached(ctx, fn_name, call_jaxpr, effects, name_stack,
                                arg_names=None, result_names=None):
   if not call_jaxpr.consts and arg_names is result_names is None:
     # Cacheable.
@@ -1798,12 +1799,12 @@ def _lower_jaxpr_to_fun_cached(ctx, fn_name, call_jaxpr, effects,
       func_op = ctx.cached_primitive_lowerings[key]
     except KeyError:
       func_op = lower_jaxpr_to_fun(
-          ctx, fn_name, call_jaxpr, effects, arg_names=arg_names,
+          ctx, fn_name, call_jaxpr, effects, name_stack, arg_names=arg_names,
           result_names=result_names)
       ctx.cached_primitive_lowerings[key] = func_op
   else:
     func_op = lower_jaxpr_to_fun(
-        ctx, fn_name, call_jaxpr, effects, arg_names=arg_names,
+        ctx, fn_name, call_jaxpr, effects, name_stack, arg_names=arg_names,
         result_names=result_names)
   return func_op
@@ -1825,12 +1826,12 @@ def check_backend_matches(inner_backend: str | None,
         f"inner-jit backend specification {inner_backend}.")
 
 
-def _call_lowering(fn_name, stack_name, call_jaxpr, backend,
-                   ctx: ModuleContext, avals_in,
-                   avals_out, tokens_in, *args,
-                   dim_var_values: Sequence[ir.Value],
-                   arg_names=None, result_names=None):
-  del stack_name, avals_in
+def call_lowering(fn_name, name_stack, call_jaxpr, backend,
+                  ctx: ModuleContext, avals_in,
+                  avals_out, tokens_in, *args,
+                  dim_var_values: Sequence[ir.Value],
+                  arg_names=None, result_names=None):
+  del avals_in
   if isinstance(call_jaxpr, core.Jaxpr):
     call_jaxpr = pe.close_jaxpr(call_jaxpr)
   check_backend_matches(backend, ctx.platforms)
@@ -1839,7 +1840,7 @@ def _call_lowering(fn_name, stack_name, call_jaxpr, backend,
   output_types = [token_type()] * len(effects) + output_types
   flat_output_types = util.flatten(output_types)
   symbol_name = _lower_jaxpr_to_fun_cached(
-      ctx, fn_name, call_jaxpr, effects, arg_names=arg_names,
+      ctx, fn_name, call_jaxpr, effects, name_stack, arg_names=arg_names,
       result_names=result_names).name.value
   tokens = [tokens_in.get(eff) for eff in effects]
   args = (*dim_var_values, *tokens, *args)
@@ -1853,8 +1854,8 @@ def core_call_lowering(ctx: LoweringRuleContext,
                        *args, name, backend=None, call_jaxpr):
-  out_nodes, tokens = _call_lowering(
-      name, name, call_jaxpr, backend, ctx.module_context,
+  out_nodes, tokens = call_lowering(
+      name, ctx.name_stack, call_jaxpr, backend, ctx.module_context,
       ctx.avals_in, ctx.avals_out, ctx.tokens_in, *args,
       dim_var_values=ctx.dim_var_values)
   ctx.set_tokens_out(tokens)
diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index 06a166fb1827..fff2276f55cc 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -1433,12 +1433,12 @@ def _pmap_lowering(ctx, *in_nodes, axis_name,
 
   with maybe_extend_axis_env(axis_name, global_axis_size, None):  # type: ignore
     sub_ctx = ctx.module_context.replace(
-        axis_context=sharding_impls.ReplicaAxisContext(new_env),
-        name_stack=ctx.module_context.name_stack.extend(
-            util.wrap_name(name, 'pmap')))
-    sharded_outs, _ = mlir.jaxpr_subcomp(sub_ctx, call_jaxpr, mlir.TokenSet(), (),
-                                         *in_nodes_sharded,
-                                         dim_var_values=ctx.dim_var_values)
+        axis_context=sharding_impls.ReplicaAxisContext(new_env))
+    sharded_outs, _ = mlir.jaxpr_subcomp(
+        sub_ctx, call_jaxpr,
+        ctx.name_stack.extend(util.wrap_name(name, 'pmap')),
+        mlir.TokenSet(), (), *in_nodes_sharded,
+        dim_var_values=ctx.dim_var_values)
   out_avals = [v.aval for v in call_jaxpr.outvars]
   outs = [_hlo_unshard(ctx, aval, new_env, out_axis, shard)
           for aval, out_axis, shard in zip(out_avals, out_axes, sharded_outs)]
diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py
index 651d7d5622e1..2ea62a0d1df1 100644
--- a/jax/_src/lax/control_flow/conditionals.py
+++ b/jax/_src/lax/control_flow/conditionals.py
@@ -847,16 +847,14 @@ def _cond_lowering(ctx, index, *args, branches, linear):
   # captures.
   case_op = hlo.CaseOp(flat_output_types, index=index, num_branches=len(branches))
-  name_stack = ctx.module_context.name_stack.extend('cond')
+  name_stack = ctx.name_stack.extend('cond')
   for i, jaxpr in enumerate(branches):
     branch = case_op.regions[i].blocks.append()
     with ir.InsertionPoint(branch):
-      sub_ctx = ctx.module_context.replace(
-          name_stack=name_stack.extend(f'branch_{i}_fun'))
       consts = [mlir.ir_constants(xla.canonicalize_dtype(x))
                 for x in jaxpr.consts]
       out_vals, tokens_out = mlir.jaxpr_subcomp(
-          sub_ctx, jaxpr.jaxpr, tokens_in,
-          consts, *map(mlir.wrap_singleton_ir_values, args),
+          ctx.module_context, jaxpr.jaxpr, name_stack.extend(f'branch_{i}_fun'),
+          tokens_in, consts, *map(mlir.wrap_singleton_ir_values, args),
           dim_var_values=ctx.dim_var_values)
       out_tokens = [tokens_out.get(eff) for eff in ordered_effects]
       out_vals = [*out_tokens, *out_vals]
diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py
index ed1bae549bce..18ad27be80e3 100644
--- a/jax/_src/lax/control_flow/loops.py
+++ b/jax/_src/lax/control_flow/loops.py
@@ -1662,7 +1662,7 @@ def fun(*args):
   # Loop condition
   cond_block = while_op.regions[0].blocks.append(*flat_loop_carry_types)
-  name_stack = ctx.module_context.name_stack.extend('while')
+  name_stack = ctx.name_stack.extend('while')
   with ir.InsertionPoint(cond_block):
     flat_cond_args = [
         cond_block.arguments[i] for i in range(len(flat_loop_carry_types))
     ]
@@ -1671,13 +1671,14 @@ def fun(*args):
     # Remove tokens from cond args
     cond_args = cond_args[num_tokens:]
     x, _, z = util.split_list(cond_args, [cond_nconsts, body_nconsts])
-    cond_ctx = ctx.module_context.replace(name_stack=name_stack.extend('cond'))
     cond_consts = [
         mlir.ir_constants(xla.canonicalize_dtype(x)) for x in cond_jaxpr.consts
     ]
+    cond_name_stack = name_stack.extend('cond')
     ((pred,),), _ = mlir.jaxpr_subcomp(
-        cond_ctx,
+        ctx.module_context,
         cond_jaxpr.jaxpr,
+        cond_name_stack,
         mlir.TokenSet(),
         cond_consts,
         *(x + z),
@@ -1686,6 +1687,7 @@ def fun(*args):
     if batched:
       pred_ctx = mlir.LoweringRuleContext(
           module_context=ctx.module_context,
+          name_stack=cond_name_stack,
           primitive=None,
           avals_in=[pred_aval],
           avals_out=[pred_aval.update(shape=())],
@@ -1710,20 +1712,21 @@ def fun(*args):
     token_args, body_args = util.split_list(body_args, [num_tokens])
     tokens_in = mlir.TokenSet(zip(body_effects, token_args))
     x, y, z = util.split_list(body_args, [cond_nconsts, body_nconsts])
-    body_ctx = ctx.module_context.replace(name_stack=name_stack.extend('body'))
+    body_name_stack = name_stack.extend('body')
     body_consts = [mlir.ir_constants(xla.canonicalize_dtype(x))
                    for x in body_jaxpr.consts]
-    new_z, tokens_out = mlir.jaxpr_subcomp(body_ctx, body_jaxpr.jaxpr,
+    new_z, tokens_out = mlir.jaxpr_subcomp(
+        ctx.module_context, body_jaxpr.jaxpr, body_name_stack,
         tokens_in, body_consts, *(y + z), dim_var_values=ctx.dim_var_values)
     out_tokens = [tokens_out.get(eff) for eff in body_effects]
     if batched:
-      body_pred_ctx = ctx.module_context.replace(
-          name_stack=name_stack.extend('body_pred'))
+      body_pred_name_stack = name_stack.extend('body_pred')
       cond_consts = [mlir.ir_constants(xla.canonicalize_dtype(x))
                      for x in cond_jaxpr.consts]
       ((body_pred,),), _ = mlir.jaxpr_subcomp(
-          body_pred_ctx, cond_jaxpr.jaxpr, mlir.TokenSet(),
-          cond_consts, *(x + z), dim_var_values=ctx.dim_var_values)
+          ctx.module_context, cond_jaxpr.jaxpr, body_pred_name_stack,
+          mlir.TokenSet(), cond_consts, *(x + z),
+          dim_var_values=ctx.dim_var_values)
       new_z = _map(
           partial(_pred_bcast_select_hlo, ctx, pred_aval, body_pred), new_z, z,
           body_jaxpr.out_avals)
diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py
index 057bc1a09bc1..6ab420f78922 100644
--- a/jax/_src/lax/lax.py
+++ b/jax/_src/lax/lax.py
@@ -3813,11 +3813,11 @@ def _reduce_lower(ctx, *values, computation, jaxpr, dimensions):
   ir_types = [mlir.aval_to_ir_type(aval) for aval in init_value_avals]
   reducer = op.regions[0].blocks.append(*(ir_types + ir_types))
   with ir.InsertionPoint(reducer):
-    reducer_ctx = ctx.module_context.replace(
-        name_stack=source_info_util.new_name_stack())
+    name_stack = source_info_util.new_name_stack()
     if jaxpr.effects:
       raise NotImplementedError('Cannot lower effectful `reduce`.')
-    out_nodes, _ = mlir.jaxpr_subcomp(reducer_ctx, jaxpr.jaxpr, mlir.TokenSet(),
+    out_nodes, _ = mlir.jaxpr_subcomp(ctx.module_context, jaxpr.jaxpr,
+                                      name_stack, mlir.TokenSet(),
                                       jaxpr.consts,
                                       *([a] for a in reducer.arguments),
                                       dim_var_values=ctx.dim_var_values)
diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py
index f775c2705652..48689bbb0795 100644
--- a/jax/_src/lax/slicing.py
+++ b/jax/_src/lax/slicing.py
@@ -2493,13 +2493,12 @@ def _scatter_lower(ctx, operand, indices, updates, *,
   scalar_type = mlir.aval_to_ir_type(core.ShapedArray((), aval_out.dtype))
   update = op.update_computation.blocks.append(scalar_type, scalar_type)
   with ir.InsertionPoint(update):
-    update_ctx = ctx.module_context.replace(
-        name_stack=source_info_util.new_name_stack())
+    name_stack = source_info_util.new_name_stack()
     if update_jaxpr.effects:
       raise NotImplementedError('Cannot lower effectful `scatter`.')
     out_nodes, _ = mlir.jaxpr_subcomp(
-        update_ctx, update_jaxpr, mlir.TokenSet(), update_consts,
-        (update.arguments[0],), (update.arguments[1],),
+        ctx.module_context, update_jaxpr, name_stack, mlir.TokenSet(),
+        update_consts, (update.arguments[0],), (update.arguments[1],),
         dim_var_values=ctx.dim_var_values)
     hlo.return_(util.flatten(out_nodes))
   return op.results
diff --git a/jax/_src/lax/windowed_reductions.py b/jax/_src/lax/windowed_reductions.py
index 50d0ce86a441..af85a04d7bb2 100644
--- a/jax/_src/lax/windowed_reductions.py
+++ b/jax/_src/lax/windowed_reductions.py
@@ -320,7 +320,7 @@ def _generic_reduce_window_lower(ctx, *args, jaxpr, consts,
   def reducer_body(reducer: ir.Block) -> Sequence[ir.Value]:
     if jaxpr.effects:
       raise NotImplementedError('Cannot lower effectful `reduce_window`.')
-    out_nodes, _ = mlir.jaxpr_subcomp(ctx.module_context, jaxpr,
+    out_nodes, _ = mlir.jaxpr_subcomp(ctx.module_context, jaxpr, ctx.name_stack,
                                       mlir.TokenSet(), consts, *([a] for a in reducer.arguments),
                                       dim_var_values=ctx.dim_var_values)
     return util.flatten(out_nodes)
@@ -529,6 +529,7 @@ def _select_and_scatter_lower(
     if select_jaxpr.effects:
       raise NotImplementedError('Cannot lower effectful `select`.')
     out_nodes, _ = mlir.jaxpr_subcomp(ctx.module_context, select_jaxpr,
+                                      ctx.name_stack,
                                       mlir.TokenSet(), select_consts,
                                       *([a] for a in select.arguments),
                                       dim_var_values=ctx.dim_var_values)
@@ -538,6 +539,7 @@ def _select_and_scatter_lower(
     if scatter_jaxpr.effects:
       raise NotImplementedError('Cannot lower effectful `scatter`.')
     out_nodes, _ = mlir.jaxpr_subcomp(ctx.module_context, scatter_jaxpr,
+                                      ctx.name_stack,
                                       mlir.TokenSet(), scatter_consts,
                                       *([a] for a in scatter.arguments),
                                       dim_var_values=ctx.dim_var_values)
diff --git a/jax/_src/maps.py b/jax/_src/maps.py
index 5a42cb2259fb..e0a6f1a27631 100644
--- a/jax/_src/maps.py
+++ b/jax/_src/maps.py
@@ -1360,12 +1360,12 @@ def _xmap_lowering_rule_replica(ctx, *in_nodes,
   # them!
   # We in-line here rather than generating a Call HLO as in the xla_call
   # translation rule just because the extra tuple stuff is a pain.
-  sub_ctx = ctx.module_context.replace(
-      name_stack=ctx.module_context.name_stack.extend(wrap_name(name, 'xmap')))
+  name_stack = ctx.name_stack.extend(wrap_name(name, 'xmap'))
   if any(effects.ordered_effects.contains(eff)
          for eff in vectorized_jaxpr.effects):
     raise NotImplementedError('Cannot lower `xmap` with ordered effects.')
-  tiled_outs, _ = mlir.jaxpr_subcomp(sub_ctx, vectorized_jaxpr, mlir.TokenSet(),
+  tiled_outs, _ = mlir.jaxpr_subcomp(ctx.module_context, vectorized_jaxpr,
+                                     name_stack, mlir.TokenSet(),
                                      const_nodes, *tiled_ins,
                                      dim_var_values=ctx.dim_var_values)
 
@@ -1429,14 +1429,13 @@ def add_spmd_axes(
 
   # We in-line here rather than generating a Call HLO as in the xla_call
   # translation rule just because the extra tuple stuff is a pain.
-  sub_ctx = ctx.module_context.replace(
-      name_stack=ctx.module_context.name_stack.extend(wrap_name(name, 'xmap')))
+  name_stack = ctx.name_stack.extend(wrap_name(name, 'xmap'))
   if any(effects.ordered_effects.contains(eff)
          for eff in vectorized_jaxpr.effects):
     raise NotImplementedError('Cannot lower `xmap` with ordered effects.')
-  global_out_nodes, _ = mlir.jaxpr_subcomp(sub_ctx, vectorized_jaxpr,
-      mlir.TokenSet(), const_nodes, *sharded_global_in_nodes,
-      dim_var_values=ctx.dim_var_values)
+  global_out_nodes, _ = mlir.jaxpr_subcomp(
+      ctx.module_context, vectorized_jaxpr, name_stack, mlir.TokenSet(),
+      const_nodes, *sharded_global_in_nodes, dim_var_values=ctx.dim_var_values)
 
   sharded_global_out_nodes = [
       mlir.wrap_with_sharding_op(
@@ -1484,13 +1483,14 @@ def _xmap_lowering_rule_spmd_manual(ctx, *global_in_nodes,
   # translation rule just because the extra tuple stuff is a pain.
   assert isinstance(ctx.module_context.axis_context,
                     sharding_impls.SPMDAxisContext)
+  name_stack = ctx.name_stack.extend(wrap_name(name, 'xmap'))
   sub_ctx = ctx.module_context.replace(
-      name_stack=ctx.module_context.name_stack.extend(wrap_name(name, 'xmap')),
       axis_context=ctx.module_context.axis_context.extend_manual(manual_mesh_axes))
   if any(effects.ordered_effects.contains(eff)
          for eff in vectorized_jaxpr.effects):
     raise NotImplementedError('Cannot lower `xmap` with ordered effects.')
-  global_out_nodes, _ = mlir.jaxpr_subcomp(sub_ctx, vectorized_jaxpr,
+  global_out_nodes, _ = mlir.jaxpr_subcomp(
+      sub_ctx, vectorized_jaxpr, name_stack,
       mlir.TokenSet(), const_nodes, *([n] for n in global_in_nodes),
       dim_var_values=ctx.dim_var_values)
diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py
index 0f45f5dd665d..01c6c2b4394a 100644
--- a/jax/_src/pjit.py
+++ b/jax/_src/pjit.py
@@ -1602,9 +1602,9 @@ def _pjit_cached_lower_jaxpr_to_fun(ctx, name, jaxpr, effects, in_shardings,
     # inputs or outputs because they are lost during MLIR->HLO conversion.
     # using_sharding_annotation=False means we add an identity operation instead.
     func = mlir.lower_jaxpr_to_fun(
-        mod_ctx, name, jaxpr, effects, arg_shardings=arg_shardings,
-        result_shardings=result_shardings, use_sharding_annotations=False,
-        api_name=api_name)
+        mod_ctx, name, jaxpr, effects, ctx.name_stack,
+        arg_shardings=arg_shardings, result_shardings=result_shardings,
+        use_sharding_annotations=False, api_name=api_name)
     mod_ctx.cached_primitive_lowerings[key] = func
   return func
diff --git a/jax/experimental/export/_export.py b/jax/experimental/export/_export.py
index cb27ac367594..18f7436c4bfd 100644
--- a/jax/experimental/export/_export.py
+++ b/jax/experimental/export/_export.py
@@ -676,14 +676,14 @@ def is_token(typ, attrs):
   module_context = mlir.ModuleContext(
       backend_or_name="cpu", platforms=["cpu"],
       axis_context=sharding_impls.ShardingContext(0),
-      name_stack=source_info_util.new_name_stack(),
       keepalives=[], channel_iterator=itertools.count(1),
       host_callbacks=[], module=wrapped_module, context=context,
       lowering_parameters=mlir.LoweringParameters(
          global_constant_computation=True
       ))
   ctx = mlir.LoweringRuleContext(
-      module_context=module_context, primitive=None,
+      module_context=module_context,
+      name_stack=source_info_util.new_name_stack(), primitive=None,
       avals_in=args_avals_flat, avals_out=None,
       tokens_in=mlir.TokenSet(), tokens_out=None)
   # We compute dim_values from the array arguments.
diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py
index db143a4d23f7..6241900fa492 100644
--- a/jax/experimental/shard_map.py
+++ b/jax/experimental/shard_map.py
@@ -548,9 +548,9 @@ def _shard_map_lowering(ctx, *in_nodes, jaxpr, mesh, in_names, out_names,
   )
   sub_ctx = ctx.module_context.replace(axis_context=new_axis_context)
   with core.extend_axis_env_nd(tuple(mesh.shape.items())):
-    out_nodes_, tokens_out = mlir._call_lowering(
-        "shmap_body", (), jaxpr, None, sub_ctx, in_avals_, out_avals_,
-        ctx.tokens_in, *in_nodes_, dim_var_values=ctx.dim_var_values,
+    out_nodes_, tokens_out = mlir.call_lowering(
+        "shmap_body", ctx.name_stack, jaxpr, None, sub_ctx, in_avals_,
+        out_avals_, ctx.tokens_in, *in_nodes_, dim_var_values=ctx.dim_var_values,
         arg_names=map(_pspec_mhlo_attrs, in_names, in_avals_),
         result_names=map(_pspec_mhlo_attrs, out_names, out_avals_))
   ctx.set_tokens_out(tokens_out)
diff --git a/jax/interpreters/mlir.py b/jax/interpreters/mlir.py
index 00cf8a905466..7a72a478b807 100644
--- a/jax/interpreters/mlir.py
+++ b/jax/interpreters/mlir.py
@@ -27,7 +27,7 @@
   Token as Token,
   TokenSet as TokenSet,
   Value as Value,
-  _call_lowering as _call_lowering,
+  call_lowering as _call_lowering,
   _lowerings as _lowerings,
   _platform_specific_lowerings as _platform_specific_lowerings,
   aval_to_ir_type as aval_to_ir_type,
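
For lowering rules outside this patch, the new jaxpr_subcomp calling
convention looks like the following sketch. The rule `_my_op_lowering` and
its `call_jaxpr` parameter are illustrative only (they are not part of this
change); `mlir.jaxpr_subcomp`, `ctx.name_stack`, and
`mlir.LoweringRuleContext` are the APIs this patch modifies:

  from jax._src.interpreters import mlir

  def _my_op_lowering(ctx: mlir.LoweringRuleContext, *args, call_jaxpr):
    # Before this patch, the name stack lived on ctx.module_context, so
    # extending it meant rebuilding the module context with
    # ctx.module_context.replace(name_stack=...). Now the stack is read off
    # the per-rule context and threaded positionally into jaxpr_subcomp.
    out, tokens = mlir.jaxpr_subcomp(
        ctx.module_context, call_jaxpr.jaxpr,
        ctx.name_stack.extend('my_op'),  # explicit name-stack argument
        ctx.tokens_in, mlir._ir_consts(call_jaxpr.consts),
        *map(mlir.wrap_singleton_ir_values, args),
        dim_var_values=ctx.dim_var_values)
    ctx.set_tokens_out(tokens)
    return out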