Skip to content

Commit

Permalink
Cache the creation of ClosedJaxpr in pjit_transpose which if not cach…
Browse files Browse the repository at this point in the history
…ed breaks the compilation cache.

PiperOrigin-RevId: 504304311
  • Loading branch information
yashk2810 authored and jax authors committed Jan 24, 2023
1 parent bbccf55 commit b621373
Showing 1 changed file with 13 additions and 9 deletions.
22 changes: 13 additions & 9 deletions jax/_src/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1649,6 +1649,14 @@ def _pjit_partial_eval_custom_params_updater(
_pjit_partial_eval_custom_params_updater)


@lu.cache
def _pjit_transpose_trace(fun, in_avals, api_name):
transpose_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
fun, in_avals, debug_info=pe.debug_info_final(fun, api_name))
transpose_jaxpr = core.ClosedJaxpr(transpose_jaxpr, consts)
return transpose_jaxpr


def _pjit_transpose(reduce_axes, cts_in, *primals_in,
jaxpr, in_shardings, out_shardings,
resource_env, donated_invars, name, in_positional_semantics,
Expand All @@ -1669,19 +1677,15 @@ def prune_type(ty, xs, maybe_zeros):
*prune_type(ad.UndefinedPrimal, in_positional_semantics, primals_in),
*prune_type(ad.Zero, (out_positional_semantics,) * len(cts_in), cts_in)
)
global_cts_in_avals = [core.raise_to_shaped(core.get_aval(ct))
for ct in primals_and_nz_cts_in]
global_cts_in_avals = tuple(core.raise_to_shaped(core.get_aval(ct))
for ct in primals_and_nz_cts_in)
if not config.jax_array:
global_cts_in_avals = local_to_global(
global_cts_in_avals = tuple(local_to_global(
transpose_in_positional_semantics, global_cts_in_avals,
transpose_in_shardings, resource_env.physical_mesh)
transpose_in_shardings, resource_env.physical_mesh))

api_name = 'jit' if resource_env is None else 'pjit'
transpose_jaxpr, global_cts_out_avals, consts = pe.trace_to_jaxpr_dynamic(
body, global_cts_in_avals, debug_info=pe.debug_info_final(body, api_name))
# TODO(apaszke): Creating ClosedJaxpr by hand will break compilation cache!
transpose_jaxpr = core.ClosedJaxpr(transpose_jaxpr, consts)
del consts
transpose_jaxpr = _pjit_transpose_trace(body, global_cts_in_avals, api_name)
cts_out_treedef = cts_out_treedef_thunk()
transpose_out_shardings = prune_type(
ad.Zero,
Expand Down

0 comments on commit b621373

Please sign in to comment.