diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py
index 8dfbc086def3..d552c528a872 100644
--- a/jax/experimental/shard_map.py
+++ b/jax/experimental/shard_map.py
@@ -640,21 +640,28 @@ def _shard_map_impl(trace, prim, fun, args, *, mesh, in_names, out_names_thunk,
   args = map(partial(_unmatch_spec, mesh), in_names, args)
   in_rep = map(partial(_in_names_to_rep, mesh), in_names)
   with core.new_base_main(ShardMapTrace, mesh=mesh, check=check_rep) as main:
+    fun, out_rep = _shmap_subtrace(fun, main, in_rep)
     with core.new_sublevel(), core.extend_axis_env_nd(mesh.shape.items(), main):
-      t = main.with_cur_sublevel()
-      in_tracers = map(partial(ShardMapTracer, t), in_rep, args)
-      ans = fun.call_wrapped(*in_tracers)
-      out_tracers = map(t.full_raise, ans)
-      outs_, out_rep = unzip2((t.val, t.rep) for t in out_tracers)
-      del main, t, in_tracers, ans, out_tracers
-  out_avals = [core.mapped_aval(x.shape[0], 0, core.get_aval(x)) for x in outs_]
+      outs = fun.call_wrapped(*args)
+    del main
+  out_avals = [core.mapped_aval(x.shape[0], 0, core.get_aval(x)) for x in outs]
   _check_names(out_names_thunk(), out_avals)  # pytype: disable=wrong-arg-types
   if check_rep:
-    _check_reps(mesh, out_names_thunk(), out_rep)
-  return map(partial(_match_spec, mesh, check_rep), out_rep, out_names_thunk(),
-             outs_)
+    _check_reps(mesh, out_names_thunk(), out_rep())
+  return map(partial(_match_spec, mesh, check_rep),
+             out_rep(), out_names_thunk(), outs)
 core.EvalTrace.process_shard_map = _shard_map_impl
 
+@lu.transformation_with_aux
+def _shmap_subtrace(main, in_rep, *in_vals):
+  t = main.with_cur_sublevel()
+  in_tracers = map(partial(ShardMapTracer, t), in_rep, in_vals)
+  ans = yield in_tracers, {}
+  out_tracers = map(t.full_raise, ans)
+  outs, out_rep = unzip2((t.val, t.rep) for t in out_tracers)
+  del t, in_tracers, ans, out_tracers
+  yield outs, out_rep
+
 def _names_to_pspec(names: AxisNames) -> PartitionSpec:
   ndmin = max(names) + 1 if names else 0
   return PartitionSpec(*(names.get(i) for i in range(ndmin)))
@@ -747,33 +754,35 @@ def process_map(self, map_primitive, fun, tracers, params):
         "a feature request at https://github.com/google/jax/issues !")
 
   def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros):
-    raise NotImplementedError(
-        "Eager evaluation of a `custom_jvp` inside a `shard_map` isn't yet "
-        "supported. "
-        "Put a `jax.jit` around the `shard_map`-decorated function, and open "
-        "a feature request at https://github.com/google/jax/issues !")
+    # Since ShardMapTrace is only used as a base main, we can drop the jvp.
+    if symbolic_zeros:
+      msg = "Please open an issue at https://github.com/google/jax/issues !"
+      raise NotImplementedError(msg)
+    del prim, jvp, symbolic_zeros
+    in_vals, in_rep = unzip2((t.val, t.rep) for t in tracers)
+    fun, out_rep = _shmap_subtrace(fun, self.main, in_rep)
+    with core.new_sublevel():
+      out_vals = fun.call_wrapped(*in_vals)
+    return map(partial(ShardMapTracer, self), out_rep(), out_vals)
 
   def post_process_custom_jvp_call(self, out_tracers, _):
-    raise NotImplementedError(
-        "Eager evaluation of a `custom_jvp` inside a `shard_map` isn't yet "
-        "supported. "
-        "Put a `jax.jit` around the `shard_map`-decorated function, and open "
-        "a feature request at https://github.com/google/jax/issues !")
+    assert False  # unreachable
 
   def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees,
                               symbolic_zeros):
-    raise NotImplementedError(
-        "Eager evaluation of a `custom_vjp` inside a `shard_map` isn't yet "
-        "supported. "
" - "Put a `jax.jit` around the `shard_map`-decorated function, and open " - "a feature request at https://github.com/google/jax/issues !") + # Since ShardMapTrace is only used as a base main, we can drop the jvp. + if symbolic_zeros: + msg = "Please open an issue at https://github.com/google/jax/issues !" + raise NotImplementedError(msg) + del prim, fwd, bwd, out_trees, symbolic_zeros + in_vals, in_rep = unzip2((t.val, t.rep) for t in tracers) + fun, out_rep = _shmap_subtrace(fun, self.main, in_rep) + with core.new_sublevel(): + out_vals = fun.call_wrapped(*in_vals) + return map(partial(ShardMapTracer, self), out_rep(), out_vals) def post_process_custom_vjp_call(self, out_tracers, _): - raise NotImplementedError( - "Eager evaluation of a `custom_vjp` inside a `shard_map` isn't yet " - "supported. " - "Put a `jax.jit` around the `shard_map`-decorated function, and open " - "a feature request at https://github.com/google/jax/issues !") + assert False # unreachable def process_axis_index(self, frame): with core.eval_context(), jax.disable_jit(False): diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 80ffa1f9a94a..9e9f10469422 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -667,7 +667,7 @@ def body(c, _): shard_map(g, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), out_specs=[P(None), P(None), P(('x', 'y'))])(x, x, x) - def test_eager_notimplemented_error_message_custom_jvp(self): + def test_eager_custom_jvp_basic(self): @jax.custom_jvp def foo(x): return 2. * x @@ -675,32 +675,32 @@ def foo(x): @foo.defjvp def foo_jvp(primals, tangents): (x,), (x_dot,) = primals, tangents - return foo(x), 2. * x_dot + return foo(x), 3. * x_dot mesh = jtu.create_global_mesh((4,), ('x',)) g = shard_map(foo, mesh, in_specs=(P('x'),), out_specs=P('x')) - x = jnp.arange(4.) - with self.assertRaisesRegex(NotImplementedError, 'custom_jvp'): - g(x) + y, x_bar = jax.value_and_grad(lambda x: g(x).sum())(jnp.arange(4.)) + self.assertAllClose(y, (2. * jnp.arange(4.)).sum()) + self.assertAllClose(x_bar, 3. * jnp.ones(4), check_dtypes=False) - def test_eager_notimplemented_error_message_custom_vjp(self): + def test_eager_custom_vjp_basic(self): @jax.custom_vjp def foo(x): return 2. * x def foo_fwd(x): - return x, None + return foo(x), None def foo_bwd(_, y_bar): - return 2. * y_bar, + return 3. * y_bar, foo.defvjp(foo_fwd, foo_bwd) mesh = jtu.create_global_mesh((4,), ('x',)) g = shard_map(foo, mesh, in_specs=(P('x'),), out_specs=P('x')) - x = jnp.arange(4.) - with self.assertRaisesRegex(NotImplementedError, 'custom_vjp'): - g(x) + y, x_bar = jax.value_and_grad(lambda x: g(x).sum())(jnp.arange(4.)) + self.assertAllClose(y, (2. * jnp.arange(4.)).sum()) + self.assertAllClose(x_bar, 3. * jnp.ones(4), check_dtypes=False) @parameterized.parameters([True, False]) def test_axis_index_basic(self, jit): @@ -930,7 +930,8 @@ def f(x): y = f(x) self.assertAllClose(y, 2 * x * x, check_dtypes=True) - def test_rewrite_process_custom_jvp_call(self): + @parameterized.parameters([True, False]) + def test_rewrite_process_custom_jvp_call(self, jit): @jax.custom_jvp def foo(x): return 2. * x @@ -943,16 +944,19 @@ def foo_jvp(primals, tangents): mesh = jtu.create_global_mesh((4,), ('x',)) g = shard_map(lambda x: foo(x) * x, mesh, in_specs=(P('x'),), out_specs=P('x')) - x = jnp.arange(4.) + if jit: + g = jax.jit(g) - y = jax.jit(g)(x) + x = jnp.arange(4.) 
+    y = g(x)
     self.assertAllClose(y, 2 * x * x, check_dtypes=True)
 
-    y2, y_dot = jax.jvp(jax.jit(g), (x,), (3 * x,))
+    y2, y_dot = jax.jvp(g, (x,), (3 * x,))
     self.assertAllClose(y2, 2 * x * x, check_dtypes=True)
     self.assertAllClose(y_dot, 2 * 2 * 3 * x * x, check_dtypes=True)
 
-  def test_rewrite_process_custom_vjp_call(self):
+  @parameterized.parameters([True, False])
+  def test_rewrite_process_custom_vjp_call(self, jit):
     @jax.custom_vjp
     def foo(x):
       return 2. * x
@@ -968,16 +972,19 @@ def foo_bwd(_, y_bar):
     mesh = jtu.create_global_mesh((4,), ('x',))
     g = shard_map(lambda x: foo(x) * x, mesh,
                   in_specs=(P('x'),), out_specs=P('x'))
+    if jit:
+      g = jax.jit(g)
     x = jnp.arange(4.)
 
-    y = jax.jit(g)(x)
+    y = g(x)
     self.assertAllClose(y, 2 * x * x, check_dtypes=True)
 
-    y_, x_bar = jax.value_and_grad(lambda x: jax.jit(g)(x).sum())(x)
+    y_, x_bar = jax.value_and_grad(lambda x: g(x).sum())(x)
     self.assertAllClose(y_, (2 * x * x).sum(), check_dtypes=True)
     self.assertAllClose(x_bar, 2 * 2 * x, check_dtypes=True)
 
-  def test_rewrite_process_custom_vjp_call_match_more_replicated(self):
+  @parameterized.parameters([True, False])
+  def test_rewrite_process_custom_vjp_call_match_more_replicated(self, jit):
     @jax.custom_vjp
     def foo(x):
       return 2. * x
@@ -993,16 +1000,19 @@ def foo_bwd(_, y_bar):
     mesh = jtu.create_global_mesh((4,), ('x',))
     g = shard_map(lambda x: foo(x) * x, mesh,
                   in_specs=(P('x'),), out_specs=P('x'))
-    x = jnp.arange(4.)
+    if jit:
+      g = jax.jit(g)
 
-    y = jax.jit(g)(x)
+    x = jnp.arange(4.)
+    y = g(x)
     self.assertAllClose(y, 2 * x * x, check_dtypes=True)
 
-    y_, x_bar = jax.value_and_grad(lambda x: jax.jit(g)(x).sum())(x)
+    y_, x_bar = jax.value_and_grad(lambda x: g(x).sum())(x)
     self.assertAllClose(y_, (2 * x * x).sum(), check_dtypes=True)
     self.assertAllClose(x_bar, jnp.ones_like(x) + 2 * x, check_dtypes=True)
 
-  def test_rewrite_process_custom_vjp_call_match_less_replicated(self):
+  @parameterized.parameters([True, False])
+  def test_rewrite_process_custom_vjp_call_match_less_replicated(self, jit):
     @jax.custom_vjp
     def foo(x, y):
       del y
@@ -1019,18 +1029,22 @@ def foo_bwd(y, _):
     mesh = jtu.create_global_mesh((4,), ('x',))
     g = shard_map(lambda x, y: foo(x, y) * y, mesh,
                   in_specs=(P(), P('x')), out_specs=P('x'))
+    if jit:
+      g = jax.jit(g)
+
     x = jnp.arange(4.)
     y = jnp.arange(4 * 4.)
 
-    z = jax.jit(g)(x, y)
+    z = g(x, y)
     self.assertAllClose(z, 2 * jnp.tile(x, (4,)) * y, check_dtypes=False)
 
-    z_, x_bar = jax.value_and_grad(lambda x, y: jax.jit(g)(x, y).sum())(x, y)
+    z_, x_bar = jax.value_and_grad(lambda x, y: g(x, y).sum())(x, y)
     self.assertAllClose(z.sum(), z_, check_dtypes=False)
     self.assertAllClose(x_bar, jnp.arange(16).reshape(4, 4).sum(0),
                         check_dtypes=False)
 
-  def test_rewrite_custom_vjp_call_jaxpr(self):
+  @parameterized.parameters([True, False])
+  def test_rewrite_custom_vjp_call_jaxpr(self, jit):
     @jax.custom_vjp
     def foo(x):
       return 2. * x
@@ -1050,12 +1064,14 @@ def foo_scan(x):
     mesh = jtu.create_global_mesh((4,), ('x',))
     g = shard_map(lambda x: foo_scan(x) * x, mesh,
                   in_specs=(P('x'),), out_specs=P('x'))
+    if jit:
+      g = jax.jit(g)
     x = jnp.arange(4.)
 
-    y = jax.jit(g)(x)
+    y = g(x)
     self.assertAllClose(y, 2 * x * x, check_dtypes=True)
 
-    y_, x_bar = jax.value_and_grad(lambda x: jax.jit(g)(x).sum())(x)
+    y_, x_bar = jax.value_and_grad(lambda x: g(x).sum())(x)
     self.assertAllClose(y_, (2 * x * x).sum(), check_dtypes=True)
     self.assertAllClose(x_bar, 2 * 2 * x, check_dtypes=True)
 
@@ -1270,6 +1286,7 @@ class FunSpec(NamedTuple):
             lambda r1, r2: r1 & r2,
             lambda x1, x2: (x1.shape and x2.shape and
                             x1.shape[-1] == x2.shape[-2 if x2.ndim > 1 else 0])),
+    FunSpec('relu', 1, lambda x: jax.nn.relu(x + 1) - 1, lambda r: r),
 ]
 
 input_shapes = [
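
Note on the mechanism: the new _shmap_subtrace is written against JAX's internal lu.transformation_with_aux protocol, in which a generator first yields an (args, kwargs) pair used to call the wrapped function, then yields an (outputs, aux) pair; applying the transformation returns a new WrappedFun plus a thunk that produces the aux value after the call, which is why the diff reads out_rep() rather than out_rep. A minimal standalone sketch of that protocol follows; _count_outputs and f are hypothetical names, and linear_util is an internal module whose API may change:

from jax._src import linear_util as lu  # internal module, same alias as in shard_map.py

@lu.transformation_with_aux
def _count_outputs(*args):
  outs = yield args, {}  # first yield: (args, kwargs) to call the wrapped function with
  yield outs, len(outs)  # second yield: (outputs, auxiliary value)

def f(x, y):
  return x + y, x * y

wrapped = lu.wrap_init(f)
wrapped, num_outs = _count_outputs(wrapped)
outs = wrapped.call_wrapped(1., 2.)  # runs f through the transformation
assert num_outs() == 2               # aux is retrieved via the thunk, like out_rep() above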
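
End-to-end, here is a sketch of what this change enables, mirroring test_eager_custom_jvp_basic: eagerly calling and differentiating a shard_map-decorated custom_jvp function with no jax.jit anywhere. The mesh construction is an assumption for illustration; it needs at least four devices (e.g. XLA_FLAGS=--xla_force_host_platform_device_count=4 on CPU):

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()[:4]), ('x',))

@jax.custom_jvp
def foo(x):
  return 2. * x

@foo.defjvp
def foo_jvp(primals, tangents):
  (x,), (x_dot,) = primals, tangents
  return foo(x), 3. * x_dot  # 3., not 2., to make it visible that the custom rule ran

g = shard_map(foo, mesh, in_specs=(P('x'),), out_specs=P('x'))

# Before this change the eager call raised NotImplementedError; now it is
# handled by ShardMapTrace.process_custom_jvp_call.
y, x_bar = jax.value_and_grad(lambda x: g(x).sum())(jnp.arange(4.))
print(y)      # (2. * arange(4.)).sum() == 12.0
print(x_bar)  # 3. * ones(4): the custom tangent rule, transposed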