Skip to content

Commit

Permalink
Use XLA atan2 for complex atan
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 382831891
  • Loading branch information
majnemer authored and jax authors committed Jul 2, 2021
1 parent 6b51f5c commit 5f11bf5
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 3 deletions.
4 changes: 2 additions & 2 deletions jax/_src/lax/lax.py
Expand Up @@ -2432,7 +2432,7 @@ def acos_translation_rule(x):
ad.defjvp(acos_p, lambda g, x: mul(g, -rsqrt(_const(x, 1) - square(x))))

def atan_translation_rule(x):
if dtypes.issubdtype(_dtype(x), np.complexfloating):
if jax.lib._xla_extension_version < 26 and dtypes.issubdtype(_dtype(x), np.complexfloating):
return mul(_const(x, -1j), atanh(mul(_const(x, 1j), x)))
else:
return atan2(x, _const(x, 1))
Expand All @@ -2442,7 +2442,7 @@ def atan_translation_rule(x):
multiple_results=False))
ad.defjvp(atan_p, lambda g, x: div(g, _const(x, 1) + square(x)))

atan2_p = standard_naryop([_float, _float], 'atan2')
atan2_p = standard_naryop([_float | _complex, _float | _complex], 'atan2')
ad.defjvp(atan2_p,
lambda g, x, y: g * (y / (square(x) + square(y))),
lambda g, x, y: g * -x / (square(x) + square(y)))
Expand Down
16 changes: 15 additions & 1 deletion jax/experimental/jax2tf/jax2tf.py
Expand Up @@ -1106,7 +1106,21 @@ def _integer_pow(x, *, y: int, _in_avals: Sequence[core.AbstractValue],
tf_impl_with_avals[lax.atan_p] = _convert_jax_impl(lax.atan_translation_rule,
multiple_results=False)

tf_impl[lax.atan2_p] = tf.math.atan2
def _atan2(y, x, **kwargs):
if x.dtype.is_complex or y.dtype.is_complex:
complex_component_dtype = {
tf.complex64: tf.float32,
tf.complex128: tf.float64
}.get(y.dtype)
zero = tf.constant(0, complex_component_dtype)
one = tf.constant(1, complex_component_dtype)
i = tf.complex(zero, one)
return -i * tf.math.log((x + i * y)/tf.math.sqrt(x * x + y * y))
else:
return tf.math.atan2(y, x)


tf_impl[lax.atan2_p] = _atan2
tf_impl[lax.acosh_p] = tf.math.acosh
tf_impl[lax.atanh_p] = tf.math.atanh
tf_impl[lax.asinh_p] = tf.math.asinh
Expand Down

0 comments on commit 5f11bf5

Please sign in to comment.