diff --git a/jax/_src/core.py b/jax/_src/core.py index 4c9157059ffa..cf7078bb8798 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1787,12 +1787,9 @@ def str_short(self, short_dtypes=False) -> str: _complex = concretization_function_error(complex, True) def primal_dtype_to_tangent_dtype(primal_dtype): - # TODO(frostig,mattjj): determines that all extended dtypes have - # float0 tangent type, which works fine for all our current - # extended dtype applications. We may some day want to delegate - # this decision to the dtype rules. - if (dtypes.issubdtype(primal_dtype, dtypes.extended) or - not dtypes.issubdtype(primal_dtype, np.inexact)): + if dtypes.issubdtype(primal_dtype, dtypes.extended): + return primal_dtype._rules.tangent_dtype(primal_dtype) # type: ignore + elif not dtypes.issubdtype(primal_dtype, np.inexact): return dtypes.float0 else: return primal_dtype diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index c5a743931263..b9999097131d 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -208,7 +208,7 @@ def write_cotangent(prim, v, ct): # assert v.aval.strip_weak_type().strip_named_shape() == joined_aval, (prim, v.aval, ct_aval) def read_cotangent(v): - return ct_env.pop(v, Zero(v.aval)) + return ct_env.pop(v, Zero(v.aval.at_least_vspace())) def read_primal(v): if type(v) is Literal: @@ -588,19 +588,7 @@ def zero_jvp(primitive, primals, tangents, **params): deflinear2(add_jaxvals_p, lambda t, *args: (t, t)) def instantiate_zeros(tangent): - if type(tangent) is not Zero: - return tangent - return instantiate_zeros_aval(tangent.aval, tangent) - -# This function seems similar to instantiate_zeros, but it is sometimes used -# to instantiate zero abstract units with a different aval -def instantiate_zeros_aval(aval, tangent): - if type(tangent) is not Zero: - return tangent - assert tangent.aval == aval - if jax.dtypes.issubdtype(aval.dtype, jax.dtypes.extended): - return aval.dtype._rules.make_tangent(aval.shape, aval.dtype) - return zeros_like_aval(aval) + return zeros_like_aval(tangent.aval) if type(tangent) is Zero else tangent @lu.transformation_with_aux def traceable(in_tree, *primals_and_tangents): @@ -760,7 +748,7 @@ def _custom_lin_transpose(cts_out, *invals, num_res, bwd, out_avals, if symbolic_zeros: cts_out = map(replace_internal_symbolic_zeros, cts_out) else: - cts_out = map(instantiate_zeros_aval, out_avals, cts_out) + cts_out = map(instantiate_zeros, cts_out) cts_in = bwd(*res, *cts_out) cts_in = map(replace_rule_output_symbolic_zeros, cts_in) return [None] * num_res + list(cts_in) diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index 192ba2648665..ea3931fe9a91 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -706,7 +706,7 @@ def transposed(*args): cts_in = ad.backward_pass( jaxpr.jaxpr, reduce_axes, False, jaxpr.consts, primals, cts_out) _, cts_in = split_list(cts_in, [num_res]) - return map(ad.instantiate_zeros_aval, primal_avals, cts_in) + return map(ad.instantiate_zeros, cts_in) return _make_closed_jaxpr(transposed, res_avals + jaxpr.out_avals) @@ -729,7 +729,7 @@ def _cond_transpose(reduce_axes, cts, *args, branches, linear): for out_aval, lin_in_aval in zip(jaxpr.out_avals, lin_in_avals)) res = ops[:num_res] - cts = map(ad.instantiate_zeros_aval, branches[0].out_avals, cts) + cts = map(ad.instantiate_zeros, cts) linear_trans = (False,) * num_res + (True,) * len(cts) out = cond_p.bind( diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 04deab04ba4b..6431f7592232 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -740,7 +740,7 @@ def _scan_transpose(reduce_axes, cts, *args, reverse, length, num_consts, carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry]) ct_carry, ct_ys = split_list(cts, [num_carry]) - ct_carry = _map(ad.instantiate_zeros_aval, carry_avals, ct_carry) + ct_carry = _map(ad.instantiate_zeros, ct_carry) ct_ys_is_zeros = [type(ct_y) is ad.Zero for ct_y in ct_ys] ct_ys = [x for x in ct_ys if type(x) is not ad.Zero] @@ -797,9 +797,8 @@ def transposed(*res1_cbar_bbar_res2): cbar_abar = ad.backward_pass( jaxpr.jaxpr, reduce_axes, False, jaxpr.consts, primals, b_bar + ys_bar) _, new_c_bar, a_bar, _ = split_list(cbar_abar, [num_res1, num_c, num_a]) - a_bar = _map(ad.instantiate_zeros_aval, a_avals, a_bar) - c_bar = _map(ad.instantiate_zeros_aval, c_avals, - _map(ad.add_tangents, c_bar, new_c_bar)) + a_bar = _map(ad.instantiate_zeros, a_bar) + c_bar = _map(ad.instantiate_zeros, _map(ad.add_tangents, c_bar, new_c_bar)) return c_bar + a_bar return _make_closed_jaxpr(transposed, res1_avals + c_avals + b_carry_avals + b_ys_avals_stripped + res2_avals) diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 4158b4eb53ef..23759876fdee 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -472,23 +472,6 @@ def full(shape, fill_value, dtype): # the outset. return random_wrap(key_data, impl=dtype._impl) - @staticmethod - def make_tangent(shape, dtype): - physical_shape = (*shape, *dtype._impl.key_shape) - def not_implemented(name): - def func(*args): - raise NotImplementedError(f"Cannot call {name} on tangent of PRNG key.") - return func - impl = PRNGImpl( - key_shape=dtype._impl.key_shape, - seed=not_implemented('seed'), - split=not_implemented('split'), - random_bits=not_implemented('random_bits'), - fold_in=not_implemented('fold_in'), - name=f"{dtype._impl.name}_tangent", - tag=f"{dtype._impl.tag}_t") - return random_wrap(jnp.zeros(physical_shape, dtype='uint32'), impl=impl) - @staticmethod def physical_element_aval(dtype) -> core.ShapedArray: return core.ShapedArray(dtype._impl.key_shape, jnp.dtype('uint32')) @@ -610,19 +593,8 @@ def device_put_replicated(val, aval, sharding, devices): physical_result = pxla.batched_device_put(physical_aval, physical_sharding, [physical_buf] * len(devices), devices) return random_wrap(physical_result, impl=aval.dtype._impl) - -class KeyTangentTy(dtypes.ExtendedDType): - """A dtype to use for the tangent of a PRNGKey""" - _impl: PRNGImpl - type = dtypes.prng_key - - @property - def _rules(self): - raise ValueError("Cannot perform operations on the tangent of a PRNGKey.") - - @property - def name(self) -> str: - return f'key_tangent<{self._impl.tag}>' + def tangent_dtype(_): + return dtypes.float0 class KeyTy(dtypes.ExtendedDType): diff --git a/jax/experimental/host_callback.py b/jax/experimental/host_callback.py index 40ebe7eea303..8bb41eddaf85 100644 --- a/jax/experimental/host_callback.py +++ b/jax/experimental/host_callback.py @@ -1299,17 +1299,8 @@ def _aval_is_empty(aval) -> bool: return math.prod(aval.shape) == 0 def _instantiate_zeros(tan, arg): - """Turn special ad.zero tangents into arrays of 0s for sending to host. - Args: - tan: the tangent. - arg: the argument for which we need to instantiate the tangent - - Returns: tan if it is not ad.Zero, otherwise a 0 array of appropriate type - and shape - """ - if type(tan) is not ad.Zero: - return tan - return ad.instantiate_zeros_aval(tan.aval, tan) + del arg + return ad.instantiate_zeros(tan) def _outside_call_jvp_rule(primals, tangents, **params): assert "has_token" not in params diff --git a/jax/experimental/jax2tf/call_tf.py b/jax/experimental/jax2tf/call_tf.py index 6810f7bb5d31..9147bba0074d 100644 --- a/jax/experimental/jax2tf/call_tf.py +++ b/jax/experimental/jax2tf/call_tf.py @@ -338,6 +338,9 @@ def _arg_jax_to_tf(arg_jax): # The following avoids copies to the host on CPU, always for Array # and even for ndarray if they are sufficiently aligned. # TODO(necula): on TPU this copies to the host! + if getattr(arg_jax, 'dtype', None) == dtypes.float0: + return tf.zeros(shape=arg_jax.shape, + dtype=jax2tf_internal._tf_np_dtype_for_float0) return tf.constant(np.asarray(arg_jax)) args_tf_flat = tuple(map(_arg_jax_to_tf, args_jax_flat)) diff --git a/jax/interpreters/ad.py b/jax/interpreters/ad.py index 4c0ff4e2a412..f88e411fa22a 100644 --- a/jax/interpreters/ad.py +++ b/jax/interpreters/ad.py @@ -43,7 +43,6 @@ f_jvp_traceable as f_jvp_traceable, get_primitive_transpose as get_primitive_transpose, instantiate_zeros as instantiate_zeros, - instantiate_zeros_aval as instantiate_zeros_aval, is_undefined_primal as is_undefined_primal, jvp as jvp, jvp_jaxpr as jvp_jaxpr, diff --git a/tests/random_test.py b/tests/random_test.py index 88f604f2f9f2..cfa0ecb45f9b 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -1125,8 +1125,8 @@ def f(_, state): def _f_fwd(_, state): return state, None def _f_bwd(_, state_bar): - assert state_bar[1].dtype.name == "key" # key tangent type - return state_bar + assert state_bar[1].dtype == dtypes.float0 # key tangent type + return state_bar[0], state_bar f.defvjp(_f_fwd, _f_bwd) state = (8.0, jax.random.key(123)) result = jax.grad(lambda theta: f(theta, state)[0])(3.0) @@ -1139,9 +1139,9 @@ def f(_, state): def _f_fwd(_, state): return tree_util.tree_map(lambda x: x.value, state), None def _f_bwd(_, state_bar): - self.assertTrue(dtypes.issubdtype(state_bar[1].dtype, dtypes.prng_key)) + self.assertTrue(state_bar[1].dtype == dtypes.float0) self.assertIsInstance(state_bar[1], jax.custom_derivatives.SymbolicZero) - return state_bar + return state_bar[0], state_bar f.defvjp(_f_fwd, _f_bwd, symbolic_zeros=True) state = (8.0, jax.random.key(123)) result = jax.grad(lambda theta: f(theta, state)[0])(3.0)