Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[jax2tf] First draft of testing the QR conversion. #3775

Merged
merged 6 commits into from
Jul 21, 2020

Conversation

bchetioui
Copy link
Collaborator

QR decomposition is off by over 1e-6 in some instances, requiring custom
atol and rtol values in testing code. There is an odd problem in which
experimental compilation fails for complex types, although they are in
principle supported.

@hawkinsp
Copy link
Member

QR decomposition is off by over 1e-6 in some instances, requiring custom
atol and rtol values in testing code. There is an odd problem in which
experimental compilation fails for complex types, although they are in
principle supported.

Complex QR would not work if compiled with XLA in TensorFlow because there is no XLA-level implementation. It works in JAX on CPU and GPU because we call LAPACK and Cusolver via custom-calls. It does not work in JAX on TPU for the same reason (#1274): we don't have an implementation yet.

Also note that even for real QR decompositions you will call the HLO implementation when compiling with TF, which is probably much slower than the LAPACK/Cusolver implementations. Arguably the right fix for that is to upstream our linear algebra implementations to XLA so TF/XLA can call them also.

self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
atol=1e-5, rtol=1e-5)
else:
# TODO: see https://github.com/google/jax/pull/3775#issuecomment-659407824.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a good strategy to keep track of TODOs that are discussed in comments.

jax/experimental/jax2tf/tests/primitives_test.py Outdated Show resolved Hide resolved
QR decomposition is off by over 1e-6 in some instances, requiring custom
atol and rtol values in testing code. There is an odd problem in which
experimental compilation fails for complex types, although they are in
principle supported.
@gnecula
Copy link
Collaborator

gnecula commented Jul 21, 2020

I tested this internally before merging.

@gnecula gnecula merged commit 9773432 into google:master Jul 21, 2020
NeilGirdhar pushed a commit to NeilGirdhar/jax that referenced this pull request Jul 24, 2020
* [jax2tf] First draft of testing the QR conversion.

QR decomposition is off by over 1e-6 in some instances, requiring custom
atol and rtol values in testing code. There is an odd problem in which
experimental compilation fails for complex types, although they are in
principle supported.
NeilGirdhar pushed a commit to NeilGirdhar/jax that referenced this pull request Jul 24, 2020
* [jax2tf] First draft of testing the QR conversion.

QR decomposition is off by over 1e-6 in some instances, requiring custom
atol and rtol values in testing code. There is an odd problem in which
experimental compilation fails for complex types, although they are in
principle supported.
NeilGirdhar pushed a commit to NeilGirdhar/jax that referenced this pull request Jul 24, 2020
* [jax2tf] First draft of testing the QR conversion.

QR decomposition is off by over 1e-6 in some instances, requiring custom
atol and rtol values in testing code. There is an odd problem in which
experimental compilation fails for complex types, although they are in
principle supported.
@bchetioui bchetioui deleted the test_qr branch August 31, 2020 07:58
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants