Skip to content

Commit

Permalink
revise logic for tangent types of extended dtypes
Browse files Browse the repository at this point in the history
* remove the dead code KeyTangentTy
* replace TyRules.make_tangent with TyRules.zero
* remove ad.instantiate_zeros_aval, which has been redundant with ad.instantiate_zeros ever since (1) we removed units and (2) we made Zero carry an aval
* fix a bug in backward_pass where we instantiated a Zero at the primal type rather than the corresponding tangent type
* fix _f_bwd in test_keyarray_custom_vjp, which had the wrong return type: it needs to return cotangents for all inputs, but we were returning a (float_tangent, key_tangent) pair instead of a (float_tangent, (float_tangent, key_tangent)) nested tuple. See #19009 for a check which catches this and hence includes the same test change.

We probably also need a TyRules.add for any extended dtypes that can occur as tangent dtypes, but we currently don't have any tests that exercise that (because all extended dtype tangent types are currently float0). I have some follow-up work to add such a case though!
  • Loading branch information
mattjj committed Dec 20, 2023
1 parent 35b8fdc commit ec7d28c
Show file tree
Hide file tree
Showing 9 changed files with 22 additions and 73 deletions.
9 changes: 3 additions & 6 deletions jax/_src/core.py
Expand Up @@ -1787,12 +1787,9 @@ def str_short(self, short_dtypes=False) -> str:
_complex = concretization_function_error(complex, True)

def primal_dtype_to_tangent_dtype(primal_dtype):
# TODO(frostig,mattjj): determines that all extended dtypes have
# float0 tangent type, which works fine for all our current
# extended dtype applications. We may some day want to delegate
# this decision to the dtype rules.
if (dtypes.issubdtype(primal_dtype, dtypes.extended) or
not dtypes.issubdtype(primal_dtype, np.inexact)):
if dtypes.issubdtype(primal_dtype, dtypes.extended):
return primal_dtype._rules.tangent_dtype(primal_dtype) # type: ignore
elif not dtypes.issubdtype(primal_dtype, np.inexact):
return dtypes.float0
else:
return primal_dtype
Expand Down
18 changes: 3 additions & 15 deletions jax/_src/interpreters/ad.py
Expand Up @@ -208,7 +208,7 @@ def write_cotangent(prim, v, ct):
# assert v.aval.strip_weak_type().strip_named_shape() == joined_aval, (prim, v.aval, ct_aval)

def read_cotangent(v):
return ct_env.pop(v, Zero(v.aval))
return ct_env.pop(v, Zero(v.aval.at_least_vspace()))

def read_primal(v):
if type(v) is Literal:
Expand Down Expand Up @@ -588,19 +588,7 @@ def zero_jvp(primitive, primals, tangents, **params):
deflinear2(add_jaxvals_p, lambda t, *args: (t, t))

def instantiate_zeros(tangent):
if type(tangent) is not Zero:
return tangent
return instantiate_zeros_aval(tangent.aval, tangent)

# This function seems similar to instantiate_zeros, but it is sometimes used
# to instantiate zero abstract units with a different aval
def instantiate_zeros_aval(aval, tangent):
if type(tangent) is not Zero:
return tangent
assert tangent.aval == aval
if jax.dtypes.issubdtype(aval.dtype, jax.dtypes.extended):
return aval.dtype._rules.make_tangent(aval.shape, aval.dtype)
return zeros_like_aval(aval)
return zeros_like_aval(tangent.aval) if type(tangent) is Zero else tangent

@lu.transformation_with_aux
def traceable(in_tree, *primals_and_tangents):
Expand Down Expand Up @@ -760,7 +748,7 @@ def _custom_lin_transpose(cts_out, *invals, num_res, bwd, out_avals,
if symbolic_zeros:
cts_out = map(replace_internal_symbolic_zeros, cts_out)
else:
cts_out = map(instantiate_zeros_aval, out_avals, cts_out)
cts_out = map(instantiate_zeros, cts_out)
cts_in = bwd(*res, *cts_out)
cts_in = map(replace_rule_output_symbolic_zeros, cts_in)
return [None] * num_res + list(cts_in)
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/lax/control_flow/conditionals.py
Expand Up @@ -706,7 +706,7 @@ def transposed(*args):
cts_in = ad.backward_pass(
jaxpr.jaxpr, reduce_axes, False, jaxpr.consts, primals, cts_out)
_, cts_in = split_list(cts_in, [num_res])
return map(ad.instantiate_zeros_aval, primal_avals, cts_in)
return map(ad.instantiate_zeros, cts_in)

return _make_closed_jaxpr(transposed, res_avals + jaxpr.out_avals)

Expand All @@ -729,7 +729,7 @@ def _cond_transpose(reduce_axes, cts, *args, branches, linear):
for out_aval, lin_in_aval in zip(jaxpr.out_avals, lin_in_avals))

res = ops[:num_res]
cts = map(ad.instantiate_zeros_aval, branches[0].out_avals, cts)
cts = map(ad.instantiate_zeros, cts)
linear_trans = (False,) * num_res + (True,) * len(cts)

out = cond_p.bind(
Expand Down
7 changes: 3 additions & 4 deletions jax/_src/lax/control_flow/loops.py
Expand Up @@ -740,7 +740,7 @@ def _scan_transpose(reduce_axes, cts, *args, reverse, length, num_consts,

carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry])
ct_carry, ct_ys = split_list(cts, [num_carry])
ct_carry = _map(ad.instantiate_zeros_aval, carry_avals, ct_carry)
ct_carry = _map(ad.instantiate_zeros, ct_carry)
ct_ys_is_zeros = [type(ct_y) is ad.Zero for ct_y in ct_ys]
ct_ys = [x for x in ct_ys if type(x) is not ad.Zero]

Expand Down Expand Up @@ -797,9 +797,8 @@ def transposed(*res1_cbar_bbar_res2):
cbar_abar = ad.backward_pass(
jaxpr.jaxpr, reduce_axes, False, jaxpr.consts, primals, b_bar + ys_bar)
_, new_c_bar, a_bar, _ = split_list(cbar_abar, [num_res1, num_c, num_a])
a_bar = _map(ad.instantiate_zeros_aval, a_avals, a_bar)
c_bar = _map(ad.instantiate_zeros_aval, c_avals,
_map(ad.add_tangents, c_bar, new_c_bar))
a_bar = _map(ad.instantiate_zeros, a_bar)
c_bar = _map(ad.instantiate_zeros, _map(ad.add_tangents, c_bar, new_c_bar))
return c_bar + a_bar
return _make_closed_jaxpr(transposed,
res1_avals + c_avals + b_carry_avals + b_ys_avals_stripped + res2_avals)
Expand Down
32 changes: 2 additions & 30 deletions jax/_src/prng.py
Expand Up @@ -472,23 +472,6 @@ def full(shape, fill_value, dtype):
# the outset.
return random_wrap(key_data, impl=dtype._impl)

@staticmethod
def make_tangent(shape, dtype):
physical_shape = (*shape, *dtype._impl.key_shape)
def not_implemented(name):
def func(*args):
raise NotImplementedError(f"Cannot call {name} on tangent of PRNG key.")
return func
impl = PRNGImpl(
key_shape=dtype._impl.key_shape,
seed=not_implemented('seed'),
split=not_implemented('split'),
random_bits=not_implemented('random_bits'),
fold_in=not_implemented('fold_in'),
name=f"{dtype._impl.name}_tangent",
tag=f"{dtype._impl.tag}_t")
return random_wrap(jnp.zeros(physical_shape, dtype='uint32'), impl=impl)

@staticmethod
def physical_element_aval(dtype) -> core.ShapedArray:
return core.ShapedArray(dtype._impl.key_shape, jnp.dtype('uint32'))
Expand Down Expand Up @@ -610,19 +593,8 @@ def device_put_replicated(val, aval, sharding, devices):
physical_result = pxla.batched_device_put(physical_aval, physical_sharding, [physical_buf] * len(devices), devices)
return random_wrap(physical_result, impl=aval.dtype._impl)


class KeyTangentTy(dtypes.ExtendedDType):
"""A dtype to use for the tangent of a PRNGKey"""
_impl: PRNGImpl
type = dtypes.prng_key

@property
def _rules(self):
raise ValueError("Cannot perform operations on the tangent of a PRNGKey.")

@property
def name(self) -> str:
return f'key_tangent<{self._impl.tag}>'
def tangent_dtype(_):
return dtypes.float0


class KeyTy(dtypes.ExtendedDType):
Expand Down
13 changes: 2 additions & 11 deletions jax/experimental/host_callback.py
Expand Up @@ -1299,17 +1299,8 @@ def _aval_is_empty(aval) -> bool:
return math.prod(aval.shape) == 0

def _instantiate_zeros(tan, arg):
"""Turn special ad.zero tangents into arrays of 0s for sending to host.
Args:
tan: the tangent.
arg: the argument for which we need to instantiate the tangent
Returns: tan if it is not ad.Zero, otherwise a 0 array of appropriate type
and shape
"""
if type(tan) is not ad.Zero:
return tan
return ad.instantiate_zeros_aval(tan.aval, tan)
del arg
return ad.instantiate_zeros(tan)

def _outside_call_jvp_rule(primals, tangents, **params):
assert "has_token" not in params
Expand Down
3 changes: 3 additions & 0 deletions jax/experimental/jax2tf/call_tf.py
Expand Up @@ -338,6 +338,9 @@ def _arg_jax_to_tf(arg_jax):
# The following avoids copies to the host on CPU, always for Array
# and even for ndarray if they are sufficiently aligned.
# TODO(necula): on TPU this copies to the host!
if getattr(arg_jax, 'dtype', None) == dtypes.float0:
return tf.zeros(shape=arg_jax.shape,
dtype=jax2tf_internal._tf_np_dtype_for_float0)
return tf.constant(np.asarray(arg_jax))

args_tf_flat = tuple(map(_arg_jax_to_tf, args_jax_flat))
Expand Down
1 change: 0 additions & 1 deletion jax/interpreters/ad.py
Expand Up @@ -43,7 +43,6 @@
f_jvp_traceable as f_jvp_traceable,
get_primitive_transpose as get_primitive_transpose,
instantiate_zeros as instantiate_zeros,
instantiate_zeros_aval as instantiate_zeros_aval,
is_undefined_primal as is_undefined_primal,
jvp as jvp,
jvp_jaxpr as jvp_jaxpr,
Expand Down
8 changes: 4 additions & 4 deletions tests/random_test.py
Expand Up @@ -1125,8 +1125,8 @@ def f(_, state):
def _f_fwd(_, state):
return state, None
def _f_bwd(_, state_bar):
assert state_bar[1].dtype.name == "key<fry_t>" # key tangent type
return state_bar
assert state_bar[1].dtype == dtypes.float0 # key tangent type
return state_bar[0], state_bar
f.defvjp(_f_fwd, _f_bwd)
state = (8.0, jax.random.key(123))
result = jax.grad(lambda theta: f(theta, state)[0])(3.0)
Expand All @@ -1139,9 +1139,9 @@ def f(_, state):
def _f_fwd(_, state):
return tree_util.tree_map(lambda x: x.value, state), None
def _f_bwd(_, state_bar):
self.assertTrue(dtypes.issubdtype(state_bar[1].dtype, dtypes.prng_key))
self.assertTrue(state_bar[1].dtype == dtypes.float0)
self.assertIsInstance(state_bar[1], jax.custom_derivatives.SymbolicZero)
return state_bar
return state_bar[0], state_bar
f.defvjp(_f_fwd, _f_bwd, symbolic_zeros=True)
state = (8.0, jax.random.key(123))
result = jax.grad(lambda theta: f(theta, state)[0])(3.0)
Expand Down

0 comments on commit ec7d28c

Please sign in to comment.