Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix type mismatch in jet rule for abs #3807

Merged
merged 1 commit into from
Jul 21, 2020
Merged

Conversation

jakevdp
Copy link
Collaborator

@jakevdp jakevdp commented Jul 20, 2020

Currently, in x64 mode, jet fails on lax.abs for float32 input. This PR fixes the issue.

I opted not to add a test for the moment, because doing it well would require a larger restructure of jet_test.py, and this is the only tested jet rule for which this is an issue.

Repro:

from jax.config import config
config.update("jax_enable_x64", True)

import numpy as np
from jax.experimental.jet import jet
from jax import lax

x = np.zeros((2, 3), dtype='float32')
jet(lax.abs, (x,), ([x],))
Traceback (most recent call last):
  File "tmp.py", line 9, in <module>
    jet(lax.abs, (x,), ([x],))
  File "/Users/vanderplas/github/google/jax/jax/experimental/jet.py", line 56, in jet
    out_primals, out_terms = jet_fun(jet_subtrace(f), order).call_wrapped(primals, series)
  File "/Users/vanderplas/github/google/jax/jax/linear_util.py", line 150, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/vanderplas/github/google/jax/jax/lax/lax.py", line 255, in abs
    return abs_p.bind(x)
  File "/Users/vanderplas/github/google/jax/jax/core.py", line 276, in bind
    out_tracer = top_trace.process_primitive(self, tracers, kwargs)
  File "/Users/vanderplas/github/google/jax/jax/experimental/jet.py", line 126, in process_primitive
    primal_out, terms_out = rule(primals_in, series_in, **params)
  File "/Users/vanderplas/github/google/jax/jax/experimental/jet.py", line 506, in _abs_taylor_rule
    negs = lax.select(lax.lt(x, 0.0), lax.full_like(x, -1), lax.full_like(x, 1.0))
  File "/Users/vanderplas/github/google/jax/jax/lax/lax.py", line 366, in lt
    return lt_p.bind(x, y)
  File "/Users/vanderplas/github/google/jax/jax/core.py", line 273, in bind
    return self.impl(*args, **kwargs)
  File "/Users/vanderplas/github/google/jax/jax/interpreters/xla.py", line 224, in apply_primitive
    compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), **params)
  File "/Users/vanderplas/github/google/jax/jax/interpreters/xla.py", line 240, in xla_primitive_callable
    aval_out = prim.abstract_eval(*avals, **params)
  File "/Users/vanderplas/github/google/jax/jax/lax/lax.py", line 1816, in standard_abstract_eval
    return ShapedArray(shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs))
  File "/Users/vanderplas/github/google/jax/jax/lax/lax.py", line 1857, in naryop_dtype_rule
    _check_same_dtypes(name, False, *aval_dtypes)
  File "/Users/vanderplas/github/google/jax/jax/lax/lax.py", line 5587, in _check_same_dtypes
    raise TypeError(msg.format(name, ", ".join(map(str, types))))
TypeError: lt requires arguments to have the same dtypes, got float32, float64.

@jakevdp
Copy link
Collaborator Author

jakevdp commented Jul 20, 2020

(side note to @mattjj: the weakly-typed constants we talked about this afternoon would be a better fix for this kind of issue in the long run)

Copy link
Member

@mattjj mattjj left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@jakevdp jakevdp merged commit 71f80a5 into google:master Jul 21, 2020
@jekbradbury
Copy link
Contributor

I'm a little confused; 0.0 should already be treated as weakly typed (coerced to float32 when interacting with float32 JAX values and to float64 when interacting with float64 JAX values). Why is that not enough here?

NeilGirdhar pushed a commit to NeilGirdhar/jax that referenced this pull request Jul 21, 2020
NeilGirdhar pushed a commit to NeilGirdhar/jax that referenced this pull request Jul 24, 2020
NeilGirdhar pushed a commit to NeilGirdhar/jax that referenced this pull request Jul 24, 2020
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants