[shard_map] implement eager custom_jvp / custom_vjp
mattjj committed Nov 29, 2023
1 parent 2ccdfa6 commit 7589c2b
Showing 2 changed files with 84 additions and 58 deletions.
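As the updated tests below show, this change lets a `custom_jvp`- or `custom_vjp`-decorated function run eagerly inside `shard_map`, with no `jax.jit` wrapped around the call. A minimal sketch of the newly supported `custom_jvp` usage, adapted from `test_eager_custom_jvp_basic` in this diff (the 4-device mesh and the `'x'` axis name mirror the tests and assume at least four devices are available):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

@jax.custom_jvp
def foo(x):
  return 2. * x

@foo.defjvp
def foo_jvp(primals, tangents):
  (x,), (x_dot,) = primals, tangents
  return foo(x), 3. * x_dot  # tangent rule deliberately distinct from 2. * x_dot

mesh = Mesh(np.array(jax.devices()[:4]), ('x',))  # assumes >= 4 devices
g = shard_map(foo, mesh, in_specs=(P('x'),), out_specs=P('x'))

# No jax.jit around g: the shard_map runs eagerly, and differentiation
# dispatches to the custom_jvp rule.
y, x_bar = jax.value_and_grad(lambda x: g(x).sum())(jnp.arange(4.))
# y == (2. * jnp.arange(4.)).sum(); x_bar == 3. * jnp.ones(4)
```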
69 changes: 39 additions & 30 deletions jax/experimental/shard_map.py
@@ -640,21 +640,28 @@ def _shard_map_impl(trace, prim, fun, args, *, mesh, in_names, out_names_thunk,
args = map(partial(_unmatch_spec, mesh), in_names, args)
in_rep = map(partial(_in_names_to_rep, mesh), in_names)
with core.new_base_main(ShardMapTrace, mesh=mesh, check=check_rep) as main:
fun, out_rep = _shmap_subtrace(fun, main, in_rep)
with core.new_sublevel(), core.extend_axis_env_nd(mesh.shape.items(), main):
t = main.with_cur_sublevel()
in_tracers = map(partial(ShardMapTracer, t), in_rep, args)
ans = fun.call_wrapped(*in_tracers)
out_tracers = map(t.full_raise, ans)
outs_, out_rep = unzip2((t.val, t.rep) for t in out_tracers)
del main, t, in_tracers, ans, out_tracers
out_avals = [core.mapped_aval(x.shape[0], 0, core.get_aval(x)) for x in outs_]
outs = fun.call_wrapped(*args)
del main
out_avals = [core.mapped_aval(x.shape[0], 0, core.get_aval(x)) for x in outs]
_check_names(out_names_thunk(), out_avals) # pytype: disable=wrong-arg-types
if check_rep:
_check_reps(mesh, out_names_thunk(), out_rep)
return map(partial(_match_spec, mesh, check_rep), out_rep, out_names_thunk(),
outs_)
_check_reps(mesh, out_names_thunk(), out_rep())
return map(partial(_match_spec, mesh, check_rep),
out_rep(), out_names_thunk(), outs)
core.EvalTrace.process_shard_map = _shard_map_impl

@lu.transformation_with_aux
def _shmap_subtrace(main, in_rep, *in_vals):
t = main.with_cur_sublevel()
in_tracers = map(partial(ShardMapTracer, t), in_rep, in_vals)
ans = yield in_tracers, {}
out_tracers = map(t.full_raise, ans)
outs, out_rep = unzip2((t.val, t.rep) for t in out_tracers)
del t, in_tracers, ans, out_tracers
yield outs, out_rep

def _names_to_pspec(names: AxisNames) -> PartitionSpec:
ndmin = max(names) + 1 if names else 0
return PartitionSpec(*(names.get(i) for i in range(ndmin)))
@@ -747,33 +754,35 @@ def process_map(self, map_primitive, fun, tracers, params):
"a feature request at https://github.com/google/jax/issues !")

def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros):
raise NotImplementedError(
"Eager evaluation of a `custom_jvp` inside a `shard_map` isn't yet "
"supported. "
"Put a `jax.jit` around the `shard_map`-decorated function, and open "
"a feature request at https://github.com/google/jax/issues !")
# Since ShardMapTrace is only used as a base main, we can drop the jvp.
if symbolic_zeros:
msg = "Please open an issue at https://github.com/google/jax/issues !"
raise NotImplementedError(msg)
del prim, jvp, symbolic_zeros
in_vals, in_rep = unzip2((t.val, t.rep) for t in tracers)
fun, out_rep = _shmap_subtrace(fun, self.main, in_rep)
with core.new_sublevel():
out_vals = fun.call_wrapped(*in_vals)
return map(partial(ShardMapTracer, self), out_rep(), out_vals)

def post_process_custom_jvp_call(self, out_tracers, _):
raise NotImplementedError(
"Eager evaluation of a `custom_jvp` inside a `shard_map` isn't yet "
"supported. "
"Put a `jax.jit` around the `shard_map`-decorated function, and open "
"a feature request at https://github.com/google/jax/issues !")
assert False # unreachable

def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees,
symbolic_zeros):
raise NotImplementedError(
"Eager evaluation of a `custom_vjp` inside a `shard_map` isn't yet "
"supported. "
"Put a `jax.jit` around the `shard_map`-decorated function, and open "
"a feature request at https://github.com/google/jax/issues !")
# Since ShardMapTrace is only used as a base main, we can drop the fwd/bwd.
if symbolic_zeros:
msg = "Please open an issue at https://github.com/google/jax/issues !"
raise NotImplementedError(msg)
del prim, fwd, bwd, out_trees, symbolic_zeros
in_vals, in_rep = unzip2((t.val, t.rep) for t in tracers)
fun, out_rep = _shmap_subtrace(fun, self.main, in_rep)
with core.new_sublevel():
out_vals = fun.call_wrapped(*in_vals)
return map(partial(ShardMapTracer, self), out_rep(), out_vals)

def post_process_custom_vjp_call(self, out_tracers, _):
raise NotImplementedError(
"Eager evaluation of a `custom_vjp` inside a `shard_map` isn't yet "
"supported. "
"Put a `jax.jit` around the `shard_map`-decorated function, and open "
"a feature request at https://github.com/google/jax/issues !")
assert False # unreachable

def process_axis_index(self, frame):
with core.eval_context(), jax.disable_jit(False):
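The new `_shmap_subtrace` helper above follows the `lu.transformation_with_aux` generator protocol: the first `yield` hands the transformed arguments to the wrapped function and receives its outputs back, and the second `yield` returns the transformed outputs together with auxiliary data that the caller later reads through a thunk (`out_rep()` in the code above). A minimal standalone sketch of that protocol, independent of `shard_map`; note that `linear_util` is a JAX-internal module whose import path and signatures can change between releases:

```python
from jax._src import linear_util as lu  # internal API; location varies by version

@lu.transformation_with_aux
def count_outputs(*in_vals):
  # First yield: pass (args, kwargs) down to the wrapped function; its
  # outputs come back as the value of the yield expression.
  outs = yield in_vals, {}
  # Second yield: (outputs to return, auxiliary data for the caller's thunk).
  yield outs, len(outs)

def f(x, y):
  return x + y, x * y

wrapped = lu.wrap_init(f)
wrapped, num_outs = count_outputs(wrapped)  # -> (transformed WrappedFun, aux thunk)
outs = wrapped.call_wrapped(1., 2.)
print(outs, num_outs())  # (3.0, 2.0) 2
```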
73 changes: 45 additions & 28 deletions tests/shard_map_test.py
@@ -667,40 +667,40 @@ def body(c, _):
shard_map(g, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))),
out_specs=[P(None), P(None), P(('x', 'y'))])(x, x, x)

def test_eager_notimplemented_error_message_custom_jvp(self):
def test_eager_custom_jvp_basic(self):
@jax.custom_jvp
def foo(x):
return 2. * x

@foo.defjvp
def foo_jvp(primals, tangents):
(x,), (x_dot,) = primals, tangents
return foo(x), 2. * x_dot
return foo(x), 3. * x_dot

mesh = jtu.create_global_mesh((4,), ('x',))
g = shard_map(foo, mesh, in_specs=(P('x'),), out_specs=P('x'))
x = jnp.arange(4.)
with self.assertRaisesRegex(NotImplementedError, 'custom_jvp'):
g(x)
y, x_bar = jax.value_and_grad(lambda x: g(x).sum())(jnp.arange(4.))
self.assertAllClose(y, (2. * jnp.arange(4.)).sum())
self.assertAllClose(x_bar, 3. * jnp.ones(4), check_dtypes=False)

def test_eager_notimplemented_error_message_custom_vjp(self):
def test_eager_custom_vjp_basic(self):
@jax.custom_vjp
def foo(x):
return 2. * x

def foo_fwd(x):
return x, None
return foo(x), None

def foo_bwd(_, y_bar):
return 2. * y_bar,
return 3. * y_bar,

foo.defvjp(foo_fwd, foo_bwd)

mesh = jtu.create_global_mesh((4,), ('x',))
g = shard_map(foo, mesh, in_specs=(P('x'),), out_specs=P('x'))
x = jnp.arange(4.)
with self.assertRaisesRegex(NotImplementedError, 'custom_vjp'):
g(x)
y, x_bar = jax.value_and_grad(lambda x: g(x).sum())(jnp.arange(4.))
self.assertAllClose(y, (2. * jnp.arange(4.)).sum())
self.assertAllClose(x_bar, 3. * jnp.ones(4), check_dtypes=False)

@parameterized.parameters([True, False])
def test_axis_index_basic(self, jit):
@@ -930,7 +930,8 @@ def f(x):
y = f(x)
self.assertAllClose(y, 2 * x * x, check_dtypes=True)

def test_rewrite_process_custom_jvp_call(self):
@parameterized.parameters([True, False])
def test_rewrite_process_custom_jvp_call(self, jit):
@jax.custom_jvp
def foo(x):
return 2. * x
@@ -943,16 +944,19 @@ def foo_jvp(primals, tangents):
mesh = jtu.create_global_mesh((4,), ('x',))
g = shard_map(lambda x: foo(x) * x, mesh,
in_specs=(P('x'),), out_specs=P('x'))
x = jnp.arange(4.)
if jit:
g = jax.jit(g)

y = jax.jit(g)(x)
x = jnp.arange(4.)
y = g(x)
self.assertAllClose(y, 2 * x * x, check_dtypes=True)

y2, y_dot = jax.jvp(jax.jit(g), (x,), (3 * x,))
y2, y_dot = jax.jvp(g, (x,), (3 * x,))
self.assertAllClose(y2, 2 * x * x, check_dtypes=True)
self.assertAllClose(y_dot, 2 * 2 * 3 * x * x, check_dtypes=True)

def test_rewrite_process_custom_vjp_call(self):
@parameterized.parameters([True, False])
def test_rewrite_process_custom_vjp_call(self, jit):
@jax.custom_vjp
def foo(x):
return 2. * x
@@ -968,16 +972,19 @@ def foo_bwd(_, y_bar):
mesh = jtu.create_global_mesh((4,), ('x',))
g = shard_map(lambda x: foo(x) * x, mesh,
in_specs=(P('x'),), out_specs=P('x'))
if jit:
g = jax.jit(g)

x = jnp.arange(4.)
y = jax.jit(g)(x)
y = g(x)
self.assertAllClose(y, 2 * x * x, check_dtypes=True)

y_, x_bar = jax.value_and_grad(lambda x: jax.jit(g)(x).sum())(x)
y_, x_bar = jax.value_and_grad(lambda x: g(x).sum())(x)
self.assertAllClose(y_, (2 * x * x).sum(), check_dtypes=True)
self.assertAllClose(x_bar, 2 * 2 * x, check_dtypes=True)

def test_rewrite_process_custom_vjp_call_match_more_replicated(self):
@parameterized.parameters([True, False])
def test_rewrite_process_custom_vjp_call_match_more_replicated(self, jit):
@jax.custom_vjp
def foo(x):
return 2. * x
@@ -993,16 +1000,19 @@ def foo_bwd(_, y_bar):
mesh = jtu.create_global_mesh((4,), ('x',))
g = shard_map(lambda x: foo(x) * x, mesh,
in_specs=(P('x'),), out_specs=P('x'))
x = jnp.arange(4.)
if jit:
g = jax.jit(g)

y = jax.jit(g)(x)
x = jnp.arange(4.)
y = g(x)
self.assertAllClose(y, 2 * x * x, check_dtypes=True)

y_, x_bar = jax.value_and_grad(lambda x: jax.jit(g)(x).sum())(x)
y_, x_bar = jax.value_and_grad(lambda x: g(x).sum())(x)
self.assertAllClose(y_, (2 * x * x).sum(), check_dtypes=True)
self.assertAllClose(x_bar, jnp.ones_like(x) + 2 * x, check_dtypes=True)

def test_rewrite_process_custom_vjp_call_match_less_replicated(self):
@parameterized.parameters([True, False])
def test_rewrite_process_custom_vjp_call_match_less_replicated(self, jit):
@jax.custom_vjp
def foo(x, y):
del y
@@ -1019,18 +1029,22 @@ def foo_bwd(y, _):
mesh = jtu.create_global_mesh((4,), ('x',))
g = shard_map(lambda x, y: foo(x, y) * y, mesh,
in_specs=(P(), P('x')), out_specs=P('x'))
if jit:
g = jax.jit(g)

x = jnp.arange(4.)
y = jnp.arange(4 * 4.)

z = jax.jit(g)(x, y)
z = g(x, y)
self.assertAllClose(z, 2 * jnp.tile(x, (4,)) * y, check_dtypes=False)

z_, x_bar = jax.value_and_grad(lambda x, y: jax.jit(g)(x, y).sum())(x, y)
z_, x_bar = jax.value_and_grad(lambda x, y: g(x, y).sum())(x, y)
self.assertAllClose(z.sum(), z_, check_dtypes=False)
self.assertAllClose(x_bar, jnp.arange(16).reshape(4, 4).sum(0),
check_dtypes=False)

def test_rewrite_custom_vjp_call_jaxpr(self):
@parameterized.parameters([True, False])
def test_rewrite_custom_vjp_call_jaxpr(self, jit):
@jax.custom_vjp
def foo(x):
return 2. * x
@@ -1050,12 +1064,14 @@ def foo_scan(x):
mesh = jtu.create_global_mesh((4,), ('x',))
g = shard_map(lambda x: foo_scan(x) * x, mesh,
in_specs=(P('x'),), out_specs=P('x'))
if jit:
g = jax.jit(g)

x = jnp.arange(4.)
y = jax.jit(g)(x)
y = g(x)
self.assertAllClose(y, 2 * x * x, check_dtypes=True)

y_, x_bar = jax.value_and_grad(lambda x: jax.jit(g)(x).sum())(x)
y_, x_bar = jax.value_and_grad(lambda x: g(x).sum())(x)
self.assertAllClose(y_, (2 * x * x).sum(), check_dtypes=True)
self.assertAllClose(x_bar, 2 * 2 * x, check_dtypes=True)

@@ -1270,6 +1286,7 @@ class FunSpec(NamedTuple):
lambda r1, r2: r1 & r2,
lambda x1, x2: (x1.shape and x2.shape and
x1.shape[-1] == x2.shape[-2 if x2.ndim > 1 else 0])),
FunSpec('relu', 1, lambda x: jax.nn.relu(x + 1) - 1, lambda r: r),
]

input_shapes = [

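For completeness, the eager `custom_vjp` path exercised by `test_eager_custom_vjp_basic` above admits the same kind of use; a minimal self-contained sketch under the same 4-device assumption:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

@jax.custom_vjp
def bar(x):
  return 2. * x

def bar_fwd(x):
  return bar(x), None   # no residuals needed for this toy rule

def bar_bwd(_, y_bar):
  return (3. * y_bar,)  # cotangent rule deliberately distinct from 2. * y_bar

bar.defvjp(bar_fwd, bar_bwd)

mesh = Mesh(np.array(jax.devices()[:4]), ('x',))  # assumes >= 4 devices
h = shard_map(bar, mesh, in_specs=(P('x'),), out_specs=P('x'))

# Un-jitted value_and_grad now reaches the custom_vjp rule inside shard_map.
y, x_bar = jax.value_and_grad(lambda x: h(x).sum())(jnp.arange(4.))
# y == (2. * jnp.arange(4.)).sum(); x_bar == 3. * jnp.ones(4)
```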