Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid re-flattening in jit() and pmap() when no donate_argnums are present. #3945

Merged
merged 1 commit into from Aug 3, 2020

Conversation

adria-p
Copy link

@adria-p adria-p commented Aug 3, 2020

Following the same special-casing of static_argnums, this should provide a speedup specially when the number of arguments provided is large.

Following the same special-casing of static_argnums, this should provide a speedup specially when the number of arguments provided is large.
@google-cla google-cla bot added the cla: yes label Aug 3, 2020
@tomhennigan tomhennigan self-requested a review August 3, 2020 21:54
Copy link
Member

@tomhennigan tomhennigan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BTW if you were seeing this overhead when using Haiku you should find google-deepmind/dm-haiku@81f6f7a significantly improves the performance for you (tree_flatten on params/state is now O(1)).

@mattjj mattjj merged commit 4e873f4 into google:master Aug 3, 2020
@mattjj
Copy link
Member

mattjj commented Aug 3, 2020

Thank you!

@gnecula
Copy link
Collaborator

gnecula commented Aug 4, 2020

@tomhennigan This PR fails some internal tests. Perhaps you can take a look and see what the problem is.
https://test.corp.google.com/ui#id=OCL:324772011:BASE:324772057:1596532059027:eaa9fd87

@gnecula
Copy link
Collaborator

gnecula commented Aug 4, 2020

The failure is in BufferDonationTest.test_pmap_donate_argnums_invalidates_input (only on TPU and GPU):

  File "/build/work/978904e12079446cd55d2263ccdfe9103bb3/google3/runfiles/google3/third_party/py/jax/tests/api_test.py", line 3391, in test_pmap_donate_argnums_invalidates_input
    self.assertDeleted(x)
  File "/build/work/978904e12079446cd55d2263ccdfe9103bb3/google3/runfiles/google3/third_party/py/jax/tests/api_test.py", line 3404, in <lambda>
    assertDeleted = lambda self, x: self._assertDeleted(x, True)
  File "/build/work/978904e12079446cd55d2263ccdfe9103bb3/google3/runfiles/google3/third_party/py/jax/tests/api_test.py", line 3412, in _assertDeleted
    self.assertEqual(buffer.is_deleted(), deleted)
  File "<embedded stdlib>/unittest/case.py", line 829, in assertEqual
    assertion_func(first, second, msg=msg)
  File "<embedded stdlib>/unittest/case.py", line 822, in _baseAssertEqual
    raise self.failureException(msg)
AssertionError: False != True

gnecula added a commit to gnecula/jax that referenced this pull request Aug 4, 2020
…nt. (google#3945)"

This reverts commit 4e873f4.

See comments in google#3945 about the failure.
gnecula added a commit that referenced this pull request Aug 4, 2020
…nt. (#3945)" (#3953)

This reverts commit 4e873f4.

See comments in #3945 about the failure.
@adria-p
Copy link
Author

adria-p commented Aug 4, 2020 via email

@adria-p
Copy link
Author

adria-p commented Aug 4, 2020

PR resent #3955.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants