Skip to content

Commit

Permalink
[shard-map] improve error message when a custom_vjp bwd has extra psum
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Jan 2, 2024
1 parent e6c8901 commit 12e57de
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 10 deletions.
36 changes: 26 additions & 10 deletions jax/experimental/shard_map.py
Expand Up @@ -716,7 +716,9 @@ def process_map(self, map_primitive, fun, tracers, params):
def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros):
# Since ShardMapTrace is only used as a base main, we can drop the jvp.
if symbolic_zeros:
msg = "Please open an issue at https://github.com/google/jax/issues !"
msg = ("custom_jvp symbolic_zeros support with shard_map is not "
"implemented; please open an issue at "
"https://github.com/google/jax/issues")
raise NotImplementedError(msg)
del prim, jvp, symbolic_zeros
in_vals, in_rep = unzip2((t.val, t.rep) for t in tracers)
Expand All @@ -732,7 +734,9 @@ def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees,
symbolic_zeros):
# Since ShardMapTrace is only used as a base main, we can drop fwd and bwd.
if symbolic_zeros:
msg = "Please open an issue at https://github.com/google/jax/issues !"
msg = ("custom_vjp symbolic_zeros support with shard_map is not "
"implemented; please open an issue at "
"https://github.com/google/jax/issues")
raise NotImplementedError(msg)
del prim, fwd, bwd, out_trees, symbolic_zeros
in_vals, in_rep = unzip2((t.val, t.rep) for t in tracers)
Expand Down Expand Up @@ -897,7 +901,8 @@ def _standard_check(prim, mesh, *in_rep, **__):
if in_rep_ and not in_rep_[:-1] == in_rep_[1:]:
raise Exception(f"Primitive {prim} requires argument replication types "
f"to match, but got {in_rep}. Please open an issue at "
"https://github.com/google/jax/issues")
"https://github.com/google/jax/issues and as a temporary "
"workaround pass the check_rep=False argument to shard_map")
return in_rep_[0] if in_rep_ else None

def register_standard_collective(prim):
Expand All @@ -911,7 +916,8 @@ def _standard_collective_check(prim, mesh, x_rep, *, axis_name, **params):
raise Exception(f"Collective {prim} must be applied to a device-varying "
f"replication type, but got {x_rep} for collective acting "
f"over axis name {axis_name}. Please open an issue at "
"https://github.com/google/jax/issues")
"https://github.com/google/jax/issues and as a temporary "
"workaround pass the check_rep=False argument to shard_map")
return x_rep

def _standard_collective_rewrite(prim, mesh, in_rep, x, axis_name, **params):
Expand Down Expand Up @@ -965,7 +971,8 @@ def _psum2_check(mesh, *in_rep, axes, axis_index_groups):
raise Exception("Collective psum must be applied to a device-varying "
f"replication type, but got {in_rep} for collective acting "
f"over axis name {axes}. Please open an issue at "
"https://github.com/google/jax/issues")
"https://github.com/google/jax/issues, and as a temporary "
"workaround pass the check_rep=False argument to shard_map")
in_rep = tuple(set(mesh.axis_names) if r is None else r for r in in_rep)
return [r | set(axes) for r in in_rep]
register_norewrite(psum2_p)
Expand All @@ -979,7 +986,8 @@ def _pbroadcast_check(mesh, *in_rep, axes, axis_index_groups):
"non-device-varying "
f"replication type, but got {in_rep} for collective acting "
f"over axis name {axes}. Please open an issue at "
"https://github.com/google/jax/issues")
"https://github.com/google/jax/issues, and as a temporary "
"workaround pass the check_rep=False argument to shard_map")
in_rep = tuple(set(mesh.axis_names) if r is None else r for r in in_rep)
return [r - set(axes) for r in in_rep]
register_norewrite(pbroadcast_p)
Expand Down Expand Up @@ -1065,7 +1073,9 @@ def _scan_check(mesh, *in_rep, jaxpr, num_consts, num_carry, **_):
if not carry_rep_in == carry_rep_out:
raise Exception("Scan carry input and output got mismatched replication "
f"types {carry_rep_in} and {carry_rep_out}. Please open an "
"issue at https://github.com/google/jax/issues")
"issue at https://github.com/google/jax/issues, and as a "
"temporary workaround pass the check_rep=False argument to "
"shard_map")
return out_rep

@register_rewrite(control_flow.loops.scan_p)
Expand Down Expand Up @@ -1114,7 +1124,9 @@ def _custom_vjp_call_jaxpr_rewrite(
mesh, in_rep, *args, fun_jaxpr, fwd_jaxpr_thunk, bwd, num_consts, out_trees,
symbolic_zeros):
if symbolic_zeros:
msg = "Please open an issue at https://github.com/google/jax/issues !"
msg = ("Please open an issue at https://github.com/google/jax/issues and as"
" a temporary workaround pass the check_rep=False argument to "
"shard_map")
raise NotImplementedError(msg)

fun_jaxpr_, out_rep = _replication_rewrite_nomatch(mesh, fun_jaxpr, in_rep)
Expand Down Expand Up @@ -1677,7 +1689,9 @@ def post_process_call(self, call_primitive, out_tracers, params):

def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros):
if symbolic_zeros:
msg = "Please open an issue at https://github.com/google/jax/issues !"
msg = ("Please open an issue at https://github.com/google/jax/issues and "
"as a temporary workaround pass the check_rep=False argument to "
"shard_map")
raise NotImplementedError(msg)
in_vals, in_reps = unzip2((t.val, t.rep) for t in tracers)
fun, out_reps1 = _rewrite_subtrace(fun, self.main, in_reps)
Expand All @@ -1696,7 +1710,9 @@ def post_process_custom_jvp_call(self, out_tracers, jvp_was_run):
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees,
symbolic_zeros):
if symbolic_zeros:
msg = "Please open an issue at https://github.com/google/jax/issues !"
msg = ("Please open an issue at https://github.com/google/jax/issues and "
"as a temporary workaround pass the check_rep=False argument to "
"shard_map")
raise NotImplementedError(msg)
in_vals, in_reps = unzip2((t.val, t.rep) for t in tracers)
fun, out_reps1 = _rewrite_subtrace(fun, self.main, in_reps)
Expand Down
19 changes: 19 additions & 0 deletions tests/shard_map_test.py
Expand Up @@ -1333,6 +1333,25 @@ def f(*args):
with self.assertRaisesRegex(ValueError, "shard_map applied to the function 'f'"):
shard_f(jnp.ones((8, 8)), jnp.ones((8, 8)))

def test_custom_vjp_replication_error_message_hint(self):
  """Check that the replication error mentions the check_rep=False workaround.

  Builds a custom_vjp function whose primal and backward rules both apply
  psum over mesh axis 'i', composes it with itself inside a shard_map, and
  asserts that taking the gradient raises an exception whose message
  contains "check_rep=False".
  """
  devices = np.array(jax.devices()[:4])
  mesh = Mesh(devices, ('i',))

  @jax.custom_vjp
  def f(x):
    return jax.lax.psum(x, 'i')

  def fwd_rule(x):
    # No residuals are needed by the backward rule.
    return f(x), None

  def bwd_rule(_, cotangent):
    # The backward rule applies its own psum to the incoming cotangent.
    return (jax.lax.psum(cotangent, 'i'),)

  f.defvjp(fwd_rule, bwd_rule)

  @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P())
  def g(x):
    return f(f(x))

  def loss(x):
    return g(x).sum()

  with self.assertRaisesRegex(Exception, r"check_rep=False"):
    jax.grad(loss)(jnp.ones(4))


class FunSpec(NamedTuple):
name: str
Expand Down

0 comments on commit 12e57de

Please sign in to comment.