Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Jax backend #68

Merged
merged 17 commits into from
May 22, 2019
Merged

Jax backend #68

merged 17 commits into from
May 22, 2019

Conversation

chaserileyroberts
Copy link
Contributor

No description provided.

This was referenced May 21, 2019
Copy link
Contributor

@stavros11 stavros11 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a JAX expert but looks good to me. Some small observations:

Because of the bug you mentioned, it requires the latest version of jax to work, but that's fine since this is already on their pip.

Also, there are still some rtol issues in the tests. Locally, sometimes they pass and sometimes I get some failures probably because we have random numbers without fixed seed. This is a jax issue though, since there are some numerical differences between jax and numpy dot for large matrices. I don't think this will be significant in any application.

@chaserileyroberts
Copy link
Contributor Author

chaserileyroberts commented May 22, 2019

Yeah that's a good observation. I think we should change those tests to be deterministic. Will do that in the next PR.

@chaserileyroberts chaserileyroberts merged commit 6bb90a3 into master May 22, 2019
@chaserileyroberts chaserileyroberts deleted the jax_backend branch May 22, 2019 18:43
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants