[shard-map] replace jaxpr interpreters with final-style-xform-of-eval-jaxpr
mattjj committed Nov 30, 2023
1 parent e624610 commit b8f758e
Showing 2 changed files with 63 additions and 77 deletions.
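Note: the commit title contrasts two ways of processing a jaxpr. An initial-style interpreter (like the `_rep_rewrite` removed below) walks the jaxpr's equations itself, threading an environment from variables to values. A final-style transform (like the `_replication_rewrite_nomatch` added below) instead wraps `core.eval_jaxpr` in `linear_util` transformations and lets the ordinary tracing machinery do the walking. A minimal sketch of the contrast, using the same internal APIs this file uses; `eval_jaxpr_initial` and `double_all` are made-up illustrations, not part of the change:

from functools import partial
import jax
from jax._src import core
from jax._src import linear_util as lu
from jax.util import safe_map

closed = jax.make_jaxpr(lambda x: x * 2.0 + 1.0)(1.0)

# Initial style: interpret the jaxpr yourself, equation by equation.
def eval_jaxpr_initial(closed, *args):
  env = {}
  read = lambda a: a.val if isinstance(a, core.Literal) else env[a]
  safe_map(env.__setitem__, closed.jaxpr.constvars, closed.consts)
  safe_map(env.__setitem__, closed.jaxpr.invars, args)
  for eqn in closed.jaxpr.eqns:
    outs = eqn.primitive.bind(*safe_map(read, eqn.invars), **eqn.params)
    outs = outs if eqn.primitive.multiple_results else [outs]
    safe_map(env.__setitem__, eqn.outvars, outs)
  return safe_map(read, closed.jaxpr.outvars)

print(eval_jaxpr_initial(closed, 3.0))  # [7.0]

# Final style: treat the jaxpr as an opaque callable via eval_jaxpr, and layer
# the new behavior on as a reusable linear_util transformation around it.
@lu.transformation
def double_all(*args):
  outs = yield args, {}       # run the wrapped callable
  yield [2 * o for o in outs]

f = lu.wrap_init(partial(core.eval_jaxpr, closed.jaxpr, closed.consts))
f = double_all(f)
print(f.call_wrapped(3.0))  # [14.0]

The final style avoids writing a bespoke eval loop per analysis: each behavior change becomes a composable `lu` transformation, which is exactly the refactor this commit performs for the replication rewrite.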
134 changes: 57 additions & 77 deletions jax/experimental/shard_map.py
@@ -27,13 +27,14 @@
 import jax
 import jax.numpy as jnp
 from jax.sharding import NamedSharding, PartitionSpec, Mesh
-from jax._src import core
-from jax._src import dtypes
 from jax._src import ad_util
+from jax._src import array
 from jax._src import callback
+from jax._src import core
 from jax._src import custom_derivatives
 from jax._src import debugging
 from jax._src import dispatch
+from jax._src import dtypes
 from jax._src import linear_util as lu
 from jax._src import ops
 from jax._src import pjit
@@ -42,15 +42,15 @@
 from jax._src import source_info_util
 from jax._src import traceback_util
 from jax._src import util
-from jax._src import array
 from jax._src.core import Tracer
 from jax._src.api import _shared_code_pmap, _prepare_pmap
 from jax._src.lax import (lax, parallel as lax_parallel, slicing,
                           windowed_reductions, convolution, fft, linalg,
                           control_flow)
 from jax._src.util import (HashableFunction, HashablePartial, unzip2, unzip3,
                            as_hashable_function, memoize, partition_list,
-                           merge_lists, split_list, subs_list2)
+                           merge_lists, split_list, subs_list2,
+                           weakref_lru_cache)
 from jax.api_util import flatten_fun_nokwargs, shaped_abstractify
 from jax._src.interpreters import batching
 from jax._src.interpreters import mlir
@@ -430,73 +431,6 @@ def _shard_map_staging(
   return out_tracers
 pe.DynamicJaxprTrace.process_shard_map = _shard_map_staging
 
-
-Val = Any
-
-# TODO(mattjj): caching
-def _replication_rewrite_match(
-    mesh: Mesh, jaxpr: core.ClosedJaxpr, in_rep: Sequence[set[AxisName]],
-    out_rep_dst: Sequence[set[AxisName]],
-) -> core.ClosedJaxpr:
-  f = lu.wrap_init(partial(_rep_rewrite, mesh, jaxpr, in_rep))
-  f = _match_rep(f, mesh, out_rep_dst)
-  with core.extend_axis_env_nd(mesh.shape.items()):
-    jaxpr_, _, consts = pe.trace_to_jaxpr_dynamic(f, jaxpr.in_avals)
-  return core.ClosedJaxpr(jaxpr_, consts)
-
-@lu.transformation
-def _match_rep(mesh: Mesh, out_rep_dst: Sequence[set[AxisName]], *args):
-  out_vals, out_reps = yield args, {}
-  _check_reps2(mesh, out_rep_dst, out_reps)
-  out_vals = [pbroadcast(x, tuple(n for n in src if n not in dst)) if src - dst
-              else x for x, src, dst in zip(out_vals, out_reps, out_rep_dst)]
-  yield out_vals
-
-
-def _rep_rewrite(
-    mesh: Mesh, jaxpr_: core.ClosedJaxpr,
-    in_rep: Sequence[set[AxisName]], *args: Val,
-) -> tuple[tuple[Val], tuple[set[AxisName]]]:
-  jaxpr, consts = jaxpr_.jaxpr, jaxpr_.consts
-
-  env: dict[core.Var, tuple[Val, set[AxisName]]] = {}
-
-  def read(x: core.Atom) -> tuple[Val, set[AxisName]]:
-    return env[x] if isinstance(x, core.Var) else (x.val, set(mesh.axis_names))
-
-  def write(v: core.Var, val: Val, rep: set[AxisName]) -> None:
-    env[v] = (val, rep)
-
-  map(write, jaxpr.constvars, consts, [set(mesh.axis_names)] * len(consts))
-  map(write, jaxpr.invars, args, in_rep)
-  for e in jaxpr.eqns:
-    rule = _rewrite_rules.get(e.primitive, partial(_rule_missing, e.primitive))
-    in_vals, in_reps = unzip2(map(read, e.invars))
-    out_vals, out_reps = rule(mesh, in_reps, *in_vals, **e.params)
-    map(write, e.outvars, out_vals, out_reps)
-  out_vals, out_reps = unzip2(map(read, jaxpr.outvars))
-  return out_vals, out_reps
-
-def _rule_missing(prim: core.Primitive, *_, **__):
-  raise NotImplementedError(
-      f"No replication rule for {prim}. As a workaround, pass the "
-      "`check_rep=False` argument to `shard_map`. To get this fixed, open an "
-      "issue at https://github.com/google/jax/issues")
-
-def _replication_rewrite_nomatch(
-    mesh: Mesh, jaxpr: core.ClosedJaxpr, in_rep: Sequence[set[AxisName]],
-) -> tuple[core.ClosedJaxpr, list[set[AxisName]]]:
-  f = lu.wrap_init(partial(_rep_rewrite, mesh, jaxpr, in_rep))
-  f, out_rep = _grab_out_rep(f)
-  with core.extend_axis_env_nd(mesh.shape.items()):
-    jaxpr_, _, consts = pe.trace_to_jaxpr_dynamic(f, jaxpr.in_avals)
-  return core.ClosedJaxpr(jaxpr_, consts), list(out_rep())
-
-@lu.transformation_with_aux
-def _grab_out_rep(*args):
-  yield (yield args, {})
-
-
 def _check_shapedarray(aval: core.AbstractValue) -> core.ShapedArray:
   assert isinstance(aval, core.ShapedArray)
   return aval
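Note: both the removed `_match_rep` above and its replacement later in this diff reconcile each output's source replication set with the destination set by inserting `pbroadcast` over the surplus axes. A worked instance of the set arithmetic, with hypothetical axis names (on this machinery's typing discipline, `pbroadcast` is numerically the identity but drops the named axes from a value's replication set):

# Value known replicated over {'x', 'y'}; destination only requires {'y'}.
src, dst = {'x', 'y'}, {'y'}
if src - dst:                                    # {'x'}: surplus replication
  extra = tuple(n for n in src if n not in dst)  # ('x',)
  # pbroadcast(val, ('x',)) then re-types val as varying over 'x', making its
  # replication set {'y'} and matching the destination exactly.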
@@ -574,6 +508,12 @@ def write(v: core.Var, val: RepType) -> None:
 def _valid_repeats(mesh: Mesh, rep: RepType, dst: AxisNames) -> bool:
   return rep is None or set(_unmentioned(mesh, dst)).issubset(rep)
 
+def _rule_missing(prim: core.Primitive, *_, **__):
+  raise NotImplementedError(
+      f"No replication rule for {prim}. As a workaround, pass the "
+      "`check_rep=False` argument to `shard_map`. To get this fixed, open an "
+      "issue at https://github.com/google/jax/issues")
+
 # Lowering
 
 def _shard_map_lowering(ctx, *in_nodes, jaxpr, mesh, in_names, out_names,
@@ -1638,6 +1578,8 @@ def _get_devices(p, backend):
 
 ### Rewrite!
 
+Val = Any
+
 class RewriteTracer(core.Tracer):
   rep: set[AxisName]
   val: Val
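Note: a `RewriteTracer` pairs a concrete value with the set of mesh axis names over which that value is known to be replicated, and the rewrite rules propagate these sets through each primitive. As a sketch of the semantics (not the exact rule), most elementwise primitives can only guarantee the intersection of their inputs' replication sets:

# Over a mesh with axes ('x', 'y'):
rep_a = {'x', 'y'}       # a is identical across both axes (fully replicated)
rep_b = {'y'}            # b varies across 'x', replicated only over 'y'
rep_sum = rep_a & rep_b  # a + b is guaranteed replicated only over {'y'}
assert rep_sum == {'y'}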
@@ -1736,9 +1678,14 @@ def post_process_custom_vjp_call(self, out_tracers, _):
 
   # TODO process_axis_index
 
-@lu.transformation
-def _efficient_transpose_rewrite(mesh, in_names, out_names_thunk, *args):
+def _efficient_transpose_rewrite(fun, mesh, in_names, out_names_thunk):
   in_reps = map(partial(_in_names_to_rep, mesh), in_names)
+  out_reps_dst = lambda: [set(_unmentioned(mesh, n)) for n in out_names_thunk()]
+  fun, out_reps_src = _efficient_transpose_rewrite_nomatch(fun, mesh, in_reps)
+  return _match_rep(fun, mesh, out_reps_src, out_reps_dst)
+
+@lu.transformation_with_aux
+def _efficient_transpose_rewrite_nomatch(mesh, in_reps, *args):
   lvl = core.dynamic_level()
   with core.new_main(RewriteTrace, dynamic=True, mesh=mesh, dyna=lvl) as main:
     t = main.with_cur_sublevel()
@@ -1747,10 +1694,43 @@ def _efficient_transpose_rewrite(mesh, in_names, out_names_thunk, *args):
     out_tracers = map(t.full_raise, ans)
     out_vals, out_reps = unzip2((t.val, t.rep) for t in out_tracers)
     del main, t, in_tracers, out_tracers, ans
-  out_rep_dst = [frozenset(_unmentioned(mesh, n)) for n in out_names_thunk()]
-  out_vals = [pbroadcast(x, tuple(n for n in src if n not in dst)) if src - dst
-              else x for x, src, dst in zip(out_vals, out_reps, out_rep_dst)]
-  yield out_vals
+  yield out_vals, out_reps
+
+@lu.transformation
+def _match_rep(mesh, out_reps_src_, out_reps_dst_, *args):
+  outs = yield args, {}
+  out_reps_src = out_reps_src_() if callable(out_reps_src_) else out_reps_src_
+  out_reps_dst = out_reps_dst_() if callable(out_reps_dst_) else out_reps_dst_
+  _check_reps2(mesh, out_reps_dst, out_reps_src)
+  outs = [pbroadcast(x, tuple(n for n in src if n not in dst)) if src - dst
+          else x for x, src, dst in zip(outs, out_reps_src, out_reps_dst)]
+  yield outs
+
+# TODO(mattjj): caching
+def _replication_rewrite_match(
+    mesh: Mesh,
+    jaxpr: core.ClosedJaxpr,
+    in_rep: Sequence[set[AxisName]],
+    out_rep_dst: Sequence[set[AxisName]],
+) -> core.ClosedJaxpr:
+  f = lu.wrap_init(partial(core.eval_jaxpr, jaxpr.jaxpr, jaxpr.consts))
+  f, out_rep = _efficient_transpose_rewrite_nomatch(f, mesh, in_rep)
+  f = _match_rep(f, mesh, out_rep, out_rep_dst)
+  with core.extend_axis_env_nd(mesh.shape.items()):
+    jaxpr_, _, consts = pe.trace_to_jaxpr_dynamic(f, jaxpr.in_avals)
+  return core.ClosedJaxpr(jaxpr_, consts)
+
+# TODO(mattjj): caching
+def _replication_rewrite_nomatch(
+    mesh: Mesh,
+    jaxpr: core.ClosedJaxpr,
+    in_rep: Sequence[set[AxisName]],
+) -> tuple[core.ClosedJaxpr, list[set[AxisName]]]:
+  f = lu.wrap_init(partial(core.eval_jaxpr, jaxpr.jaxpr, jaxpr.consts))
+  f, out_rep = _efficient_transpose_rewrite_nomatch(f, mesh, in_rep)
+  with core.extend_axis_env_nd(mesh.shape.items()):
+    jaxpr_, _, consts = pe.trace_to_jaxpr_dynamic(f, jaxpr.in_avals)
+  return core.ClosedJaxpr(jaxpr_, consts), out_rep()
 
 @lu.transformation_with_aux
 def _rewrite_subtrace(main, in_reps, *in_vals):
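Note: the functions above use `linear_util`'s generator protocol: a `@lu.transformation` generator receives the call's arguments, runs the wrapped function with `outs = yield args, kwargs`, and finally yields the transformed outputs; `@lu.transformation_with_aux` additionally yields an auxiliary value, which the caller receives as a thunk that is valid only after the wrapped call has run. That is why `_replication_rewrite_nomatch` calls `out_rep()` only after `trace_to_jaxpr_dynamic`. A self-contained sketch with a made-up `count_outputs` transformation:

from jax._src import linear_util as lu

@lu.transformation_with_aux
def count_outputs(*args):
  outs = yield args, {}   # invoke the wrapped function
  yield outs, len(outs)   # second element is the aux value

f = lu.wrap_init(lambda x: [x, x + 1.0])
f, n_thunk = count_outputs(f)
outs = f.call_wrapped(3.0)  # [3.0, 4.0]
n = n_thunk()               # 2 -- readable only after the call, like out_rep()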
6 changes: 6 additions & 0 deletions tests/shard_map_test.py
@@ -1280,6 +1280,12 @@ def f(x, y):
     y = f(a, b)  # don't crash
     self.assertAllClose(y, a @ b, check_dtypes=False, atol=1e-2, rtol=1e-2)
 
+  def test_custom_jvp_inside_jit(self):
+    mesh = jtu.create_global_mesh((4,), ('batch',))
+    x = shard_map(jax.jit(jax.nn.relu),
+                  mesh=mesh, in_specs=P('batch'),
+                  out_specs=P('batch'))(jnp.arange(16.))  # don't crash
+
 
 class FunSpec(NamedTuple):
   name: str
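Note: outside the test harness, the new regression test corresponds roughly to the following standalone sketch (assumes a runtime with at least 4 devices, and swaps `jtu.create_global_mesh` for a plain `Mesh`):

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:4]), ('batch',))
# jax.nn.relu carries a custom JVP; jitting it makes the rewrite trace
# encounter custom_jvp_call under an inner pjit, the case the test exercises.
y = shard_map(jax.jit(jax.nn.relu),
              mesh=mesh, in_specs=P('batch'), out_specs=P('batch'))(
                  jnp.arange(16.))
print(y)  # same values as jax.nn.relu(jnp.arange(16.)); previously crashed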
