New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Avoid re-flattening in jit() and pmap() when no donate_argnums are present. #3945
Conversation
Following the same special-casing of static_argnums, this should provide a speedup specially when the number of arguments provided is large.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
BTW if you were seeing this overhead when using Haiku you should find google-deepmind/dm-haiku@81f6f7a significantly improves the performance for you (tree_flatten
on params/state is now O(1)
).
Thank you! |
@tomhennigan This PR fails some internal tests. Perhaps you can take a look and see what the problem is. |
The failure is in BufferDonationTest.test_pmap_donate_argnums_invalidates_input (only on TPU and GPU):
|
…nt. (google#3945)" This reverts commit 4e873f4. See comments in google#3945 about the failure.
I see, I hadn't noticed that donate_argnums can be an integer too (which is
the case in that test), so the condition can't be using the implicit
boolean value.
…On Tue, 4 Aug 2020, 12:10 George Necula, ***@***.***> wrote:
The failure is in
BufferDonationTest.test_pmap_donate_argnums_invalidates_input:
File "/build/work/978904e12079446cd55d2263ccdfe9103bb3/google3/runfiles/google3/third_party/py/jax/tests/api_test.py", line 3391, in test_pmap_donate_argnums_invalidates_input
self.assertDeleted(x)
File "/build/work/978904e12079446cd55d2263ccdfe9103bb3/google3/runfiles/google3/third_party/py/jax/tests/api_test.py", line 3404, in <lambda>
assertDeleted = lambda self, x: self._assertDeleted(x, True)
File "/build/work/978904e12079446cd55d2263ccdfe9103bb3/google3/runfiles/google3/third_party/py/jax/tests/api_test.py", line 3412, in _assertDeleted
self.assertEqual(buffer.is_deleted(), deleted)
File "<embedded stdlib>/unittest/case.py", line 829, in assertEqual
assertion_func(first, second, msg=msg)
File "<embedded stdlib>/unittest/case.py", line 822, in _baseAssertEqual
raise self.failureException(msg)
AssertionError: False != True
—
You are receiving this because you authored the thread.
Reply to this email directly, view it on GitHub
<#3945 (comment)>, or
unsubscribe
<https://github.com/notifications/unsubscribe-auth/AAINVQYFJZLCC4LVRNLHLWDR67NAHANCNFSM4PTSJQXA>
.
|
PR resent #3955. |
Following the same special-casing of static_argnums, this should provide a speedup specially when the number of arguments provided is large.