From 12e57dea3f3fb3a29ab5041901b435c3d82ecfc0 Mon Sep 17 00:00:00 2001
From: Matthew Johnson
Date: Tue, 2 Jan 2024 13:26:40 -0800
Subject: [PATCH] [shard-map] improve error message when a custom_vjp bwd has
 extra psum

---
 jax/experimental/shard_map.py | 36 ++++++++++++++++++++++++----------
 tests/shard_map_test.py       | 19 +++++++++++++++++++
 2 files changed, 45 insertions(+), 10 deletions(-)

diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py
index 0ad5de4a3540..6f4a37844b3b 100644
--- a/jax/experimental/shard_map.py
+++ b/jax/experimental/shard_map.py
@@ -716,7 +716,9 @@ def process_map(self, map_primitive, fun, tracers, params):
   def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros):
     # Since ShardMapTrace is only used as a base main, we can drop the jvp.
     if symbolic_zeros:
-      msg = "Please open an issue at https://github.com/google/jax/issues !"
+      msg = ("custom_jvp symbolic_zeros support with shard_map is not "
+             "implemented; please open an issue at "
+             "https://github.com/google/jax/issues")
       raise NotImplementedError(msg)
     del prim, jvp, symbolic_zeros
     in_vals, in_rep = unzip2((t.val, t.rep) for t in tracers)
@@ -732,7 +734,9 @@ def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees,
                               symbolic_zeros):
     # Since ShardMapTrace is only used as a base main, we can drop the jvp.
     if symbolic_zeros:
-      msg = "Please open an issue at https://github.com/google/jax/issues !"
+      msg = ("custom_vjp symbolic_zeros support with shard_map is not "
+             "implemented; please open an issue at "
+             "https://github.com/google/jax/issues")
       raise NotImplementedError(msg)
     del prim, fwd, bwd, out_trees, symbolic_zeros
     in_vals, in_rep = unzip2((t.val, t.rep) for t in tracers)
@@ -897,7 +901,8 @@ def _standard_check(prim, mesh, *in_rep, **__):
   if in_rep_ and not in_rep_[:-1] == in_rep_[1:]:
     raise Exception(f"Primitive {prim} requires argument replication types "
                     f"to match, but got {in_rep}. Please open an issue at "
-                    "https://github.com/google/jax/issues")
+                    "https://github.com/google/jax/issues and as a temporary "
+                    "workaround pass the check_rep=False argument to shard_map")
   return in_rep_[0] if in_rep_ else None
 
 def register_standard_collective(prim):
@@ -911,7 +916,8 @@ def _standard_collective_check(prim, mesh, x_rep, *, axis_name, **params):
     raise Exception(f"Collective {prim} must be applied to a device-varying "
                     f"replication type, but got {x_rep} for collective acting "
                     f"over axis name {axis_name}. Please open an issue at "
-                    "https://github.com/google/jax/issues")
+                    "https://github.com/google/jax/issues and as a temporary "
+                    "workaround pass the check_rep=False argument to shard_map")
   return x_rep
 
 def _standard_collective_rewrite(prim, mesh, in_rep, x, axis_name, **params):
@@ -965,7 +971,8 @@ def _psum2_check(mesh, *in_rep, axes, axis_index_groups):
     raise Exception("Collective psum must be applied to a device-varying "
                     f"replication type, but got {in_rep} for collective acting "
                     f"over axis name {axes}. Please open an issue at "
-                    "https://github.com/google/jax/issues")
+                    "https://github.com/google/jax/issues, and as a temporary "
+                    "workaround pass the check_rep=False argument to shard_map")
   in_rep = tuple(set(mesh.axis_names) if r is None else r for r in in_rep)
   return [r | set(axes) for r in in_rep]
 register_norewrite(psum2_p)
@@ -979,7 +986,8 @@ def _pbroadcast_check(mesh, *in_rep, axes, axis_index_groups):
                     "non-device-varying "
                     f"replication type, but got {in_rep} for collective acting "
                     f"over axis name {axes}. Please open an issue at "
-                    "https://github.com/google/jax/issues")
+                    "https://github.com/google/jax/issues, and as a temporary "
+                    "workaround pass the check_rep=False argument to shard_map")
   in_rep = tuple(set(mesh.axis_names) if r is None else r for r in in_rep)
   return [r - set(axes) for r in in_rep]
 register_norewrite(pbroadcast_p)
@@ -1065,7 +1073,9 @@ def _scan_check(mesh, *in_rep, jaxpr, num_consts, num_carry, **_):
   if not carry_rep_in == carry_rep_out:
     raise Exception("Scan carry input and output got mismatched replication "
                     f"types {carry_rep_in} and {carry_rep_out}. Please open an "
-                    "issue at https://github.com/google/jax/issues")
+                    "issue at https://github.com/google/jax/issues, and as a "
+                    "temporary workaround pass the check_rep=False argument to "
+                    "shard_map")
   return out_rep
 
 @register_rewrite(control_flow.loops.scan_p)
@@ -1114,7 +1124,9 @@ def _custom_vjp_call_jaxpr_rewrite(
     mesh, in_rep, *args, fun_jaxpr, fwd_jaxpr_thunk, bwd, num_consts,
     out_trees, symbolic_zeros):
   if symbolic_zeros:
-    msg = "Please open an issue at https://github.com/google/jax/issues !"
+    msg = ("Please open an issue at https://github.com/google/jax/issues and as"
+           " a temporary workaround pass the check_rep=False argument to "
+           "shard_map")
     raise NotImplementedError(msg)
 
   fun_jaxpr_, out_rep = _replication_rewrite_nomatch(mesh, fun_jaxpr, in_rep)
@@ -1677,7 +1689,9 @@ def post_process_call(self, call_primitive, out_tracers, params):
 
   def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros):
     if symbolic_zeros:
-      msg = "Please open an issue at https://github.com/google/jax/issues !"
+      msg = ("Please open an issue at https://github.com/google/jax/issues and "
+             "as a temporary workaround pass the check_rep=False argument to "
+             "shard_map")
       raise NotImplementedError(msg)
     in_vals, in_reps = unzip2((t.val, t.rep) for t in tracers)
     fun, out_reps1 = _rewrite_subtrace(fun, self.main, in_reps)
@@ -1696,7 +1710,9 @@ def post_process_custom_jvp_call(self, out_tracers, jvp_was_run):
   def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees,
                               symbolic_zeros):
     if symbolic_zeros:
-      msg = "Please open an issue at https://github.com/google/jax/issues !"
+      msg = ("Please open an issue at https://github.com/google/jax/issues and "
+             "as a temporary workaround pass the check_rep=False argument to "
+             "shard_map")
       raise NotImplementedError(msg)
     in_vals, in_reps = unzip2((t.val, t.rep) for t in tracers)
     fun, out_reps1 = _rewrite_subtrace(fun, self.main, in_reps)
diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py
index 1cac76ffd20c..00d16306386d 100644
--- a/tests/shard_map_test.py
+++ b/tests/shard_map_test.py
@@ -1333,6 +1333,25 @@ def f(*args):
     with self.assertRaisesRegex(ValueError, "shard_map applied to the function 'f'"):
       shard_f(jnp.ones((8, 8)), jnp.ones((8, 8)))
 
+  def test_custom_vjp_replication_error_message_hint(self):
+    mesh = Mesh(np.array(jax.devices()[:4]), ('i',))
+
+    @jax.custom_vjp
+    def f(x):
+      return jax.lax.psum(x, 'i')
+    def f_fwd(x):
+      return f(x), None
+    def f_bwd(_, g):
+      return jax.lax.psum(g, 'i'),
+    f.defvjp(f_fwd, f_bwd)
+
+    @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P())
+    def g(x):
+      return f(f(x))
+
+    with self.assertRaisesRegex(Exception, r"check_rep=False"):
+      jax.grad(lambda x: g(x).sum())(jnp.ones(4))
+
 
 class FunSpec(NamedTuple):
   name: str
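
Note (not part of the patch): the snippet below is a minimal sketch of the workaround
that the improved error messages point to, namely passing check_rep=False to shard_map
so the replication-type check is skipped. It mirrors the new test case above, assumes
at least four devices are available (e.g. via
XLA_FLAGS=--xla_force_host_platform_device_count=4), and whether the gradient it then
produces is the one you want, given the extra psum in f_bwd, is a separate question.

    from functools import partial

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.experimental.shard_map import shard_map
    from jax.sharding import Mesh, PartitionSpec as P

    mesh = Mesh(np.array(jax.devices()[:4]), ('i',))

    @jax.custom_vjp
    def f(x):
      return jax.lax.psum(x, 'i')

    def f_fwd(x):
      return f(x), None

    def f_bwd(_, g):
      # The extra psum on the cotangent is what the replication check rejects.
      return jax.lax.psum(g, 'i'),

    f.defvjp(f_fwd, f_bwd)

    # check_rep=False disables the replication-type check, as the new error
    # message suggests, so jax.grad no longer raises here.
    @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P(), check_rep=False)
    def g(x):
      return f(f(x))

    print(jax.grad(lambda x: g(x).sum())(jnp.ones(4)))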