Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementation of np.ldexp and np.frexp #1529

Merged
merged 1 commit into from
Apr 1, 2020

Conversation

@WindQAQ WindQAQ changed the title Implementation np.ldexp and np.frexp Implementation of np.ldexp and np.frexp Oct 20, 2019
Copy link
Member

@hawkinsp hawkinsp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the contribution! It looks great, but I'd suggest some minor improvements for extreme values.

@@ -138,6 +138,7 @@ def op_record(name, nargs, dtypes, shapes, rng, diff_modes, test_name=None,
op_record("expm1", 1, number_dtypes, all_shapes, jtu.rand_small_positive(), []),
op_record("fix", 1, float_dtypes, all_shapes, jtu.rand_default(), []),
op_record("floor_divide", 2, number_dtypes, all_shapes, jtu.rand_nonzero(), ["rev"]),
op_record("frexp", 1, default_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please:
(a) test a larger range of values? I believe rand_default() returns normally-distributed random numbers with a mean of zero and a standard deviation of 1; for this function in particular I would expect it to be useful to check a range of exponent values.
(b) test inf and nan values? I believe the current implementation does the wrong thing for inf.

(You might need to extend or refactor one of the random generators in test_util.py.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the late response. Sure. Will do that.

By the way, I wonder how to enable float64 for a single test case. I could enable it via config.update("jax_enable_x64", True) but it seems that it will affect all testcases. In the case of ldexp, because the max value for float32 and float64 is different, for very large exponent, one will return inf, and the other will return a real value. Therefore, I'd like to test some values with float64.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

JAX's unit tests are run twice on Travis-CI: with and without float64 enabled. You can simply check if FLAGS.jax_enable_x64 is true to known if float64 is enabled (and if some test specifically needs float64 enabled or not, feel free to skip it otherwise)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you! Will update it soon!

@@ -504,6 +504,24 @@ def exp2(x):
return lax.exp(lax.mul(lax.log(_constant_like(x, 2)), x))


@_wraps(onp.ldexp)
def ldexp(x1, x2):
x1, x2 = _promote_shapes("ldexp",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Normally both of these functions are implemented by bit-tricks. It's possible the bit-trick implementations are faster. This implementation is however fine, it's just an observation.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The bit trick is indeed the underlying implementation of C/C++, but I wonder how to access bit representation of floating point in JAX. Seems that there is currently no equivalent function of np.ndarray.view to reinterpret the data type.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, I have implemented it with bitwise operations. Mainly based on golang's implementation.

https://github.com/golang/go/blob/master/src/math/ldexp.go
https://github.com/golang/go/blob/master/src/math/frexp.go

@WindQAQ
Copy link
Contributor Author

WindQAQ commented Jan 9, 2020

Hi @hawkinsp, could you take a look at the PR again? Thank you!

@hawkinsp
Copy link
Member

hawkinsp commented Apr 1, 2020

I updated this PR for some changes to Github head and made few small cleanups.

@hawkinsp
Copy link
Member

hawkinsp commented Apr 1, 2020

Thanks for the PR! Sorry it took us so long to merge it.

@hawkinsp hawkinsp merged commit 8c4a938 into google:master Apr 1, 2020
NeilGirdhar pushed a commit to NeilGirdhar/jax that referenced this pull request Apr 2, 2020
Co-authored-by: Peter Hawkins <phawkins@google.com>
NeilGirdhar pushed a commit to NeilGirdhar/jax that referenced this pull request Apr 13, 2020
Co-authored-by: Peter Hawkins <phawkins@google.com>
mattjj added a commit that referenced this pull request May 27, 2020
NeilGirdhar pushed a commit to NeilGirdhar/jax that referenced this pull request Jun 11, 2020
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants