-
Notifications
You must be signed in to change notification settings - Fork 2.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implementation of np.ldexp and np.frexp #1529
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the contribution! It looks great, but I'd suggest some minor improvements for extreme values.
tests/lax_numpy_test.py
Outdated
@@ -138,6 +138,7 @@ def op_record(name, nargs, dtypes, shapes, rng, diff_modes, test_name=None, | |||
op_record("expm1", 1, number_dtypes, all_shapes, jtu.rand_small_positive(), []), | |||
op_record("fix", 1, float_dtypes, all_shapes, jtu.rand_default(), []), | |||
op_record("floor_divide", 2, number_dtypes, all_shapes, jtu.rand_nonzero(), ["rev"]), | |||
op_record("frexp", 1, default_dtypes, all_shapes, jtu.rand_default(), ["rev"]), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you please:
(a) test a larger range of values? I believe rand_default()
returns normally-distributed random numbers with a mean of zero and a standard deviation of 1; for this function in particular I would expect it to be useful to check a range of exponent values.
(b) test inf and nan values? I believe the current implementation does the wrong thing for inf.
(You might need to extend or refactor one of the random generators in test_util.py
.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry for the late response. Sure. Will do that.
By the way, I wonder how to enable float64 for a single test case. I could enable it via config.update("jax_enable_x64", True)
but it seems that it will affect all testcases. In the case of ldexp
, because the max value for float32 and float64 is different, for very large exponent, one will return inf, and the other will return a real value. Therefore, I'd like to test some values with float64.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
JAX's unit tests are run twice on Travis-CI: with and without float64 enabled. You can simply check if FLAGS.jax_enable_x64
is true to known if float64 is enabled (and if some test specifically needs float64 enabled or not, feel free to skip it otherwise)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you! Will update it soon!
jax/numpy/lax_numpy.py
Outdated
@@ -504,6 +504,24 @@ def exp2(x): | |||
return lax.exp(lax.mul(lax.log(_constant_like(x, 2)), x)) | |||
|
|||
|
|||
@_wraps(onp.ldexp) | |||
def ldexp(x1, x2): | |||
x1, x2 = _promote_shapes("ldexp", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Normally both of these functions are implemented by bit-tricks. It's possible the bit-trick implementations are faster. This implementation is however fine, it's just an observation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The bit trick is indeed the underlying implementation of C/C++, but I wonder how to access bit representation of floating point in JAX. Seems that there is currently no equivalent function of np.ndarray.view to reinterpret the data type.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi, I have implemented it with bitwise operations. Mainly based on golang's implementation.
https://github.com/golang/go/blob/master/src/math/ldexp.go
https://github.com/golang/go/blob/master/src/math/frexp.go
Hi @hawkinsp, could you take a look at the PR again? Thank you! |
97a8d1a
to
300b85f
Compare
I updated this PR for some changes to Github head and made few small cleanups. |
Thanks for the PR! Sorry it took us so long to merge it. |
Co-authored-by: Peter Hawkins <phawkins@google.com>
Co-authored-by: Peter Hawkins <phawkins@google.com>
As per #70, add
ldexp
and its inverse functionfrexp
.Related links:
https://docs.scipy.org/doc/numpy/reference/generated/numpy.ldexp.html#numpy.ldexp
https://docs.scipy.org/doc/numpy/reference/generated/numpy.frexp.html#numpy.frexp
https://en.cppreference.com/w/cpp/numeric/math/frexp