New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[host_callback] Add support for pmap and for passing the device to tap #5182
Conversation
@@ -590,7 +590,7 @@ def _hashable_index(idx): | |||
# TODO(skye): is there a simpler way to rewrite this using sharding_spec? | |||
def _shard_sharded_device_array_slow_path(x, devices, indices): | |||
candidates = defaultdict(list) | |||
for buf, idx in zip(x.device_buffers, x.indices): | |||
for buf, idx in safe_zip(x.device_buffers, x.indices): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I made these changes of zip
-> safe_zip
because while debugging my changes zip
was failing silently.
b6bdb92
to
27da9a6
Compare
What do you think about treating From a user perspective, I guess this might be a little more expected and would avoid needing to change the signature of host callback. On the other hand, |
I thought about adding pmap to the list of transforms but it is not easy at all, because pmap is not really a transformation. It is more like jit than like vmap. One of the test cases I have added is a jvp of pmap of vmap, specifically to explore this. By the time I can intervene the code is transformed (with jvp of vmap). (The same would happen if we had jvp + jit + vmap). Perhaps you are suggesting to just add a |
@gnecula: for the logging usecase, we would like the data in the form @shoyer mentioned. But there is not enough info to get it into this form with the current setup, even with the device information. Consider:
If we switch the
and for the other it is
I do realize that this is not a common way of using |
Unfortunately, I cannot surface precisely in the printing that the user had a |
Right, |
tests/host_callback_test.py
Outdated
return assertMultiLineStrippedEqual(tst, expected, what) | ||
|
||
|
||
# Run all tests with 8 CPU devices. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is this a typo? It looks like it's always 2 CPU devices?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A typo indeed, fixed. Thanks!
* Adds support for jit of pmap and pmap of pmap. * Also adds a `tap_with_device` optional argument to `id_print` and `id_tap`, to have the tap function invoked with a device keyword argument. * Added multiple tests involving pmap Issue: google#5134 Fixes: google#5169
tap_with_device
optional argument toid_print
andid_tap
, to have the tap function invoked with a device keyword argument.In presence of
pmap
there will be multiple devices sending data to the host. Each suchtap will be processed separately (and there may be interleaving between devices). To make
the output more understandable, we added an option to pass the device to the tap function.
For backwards-compatibility, the device is passed as a keyword argument to
id_tap
and only ifthe
tap_with_device
optional flag is passed toid_tap
orid_print
.Issue: #5134
Fixes: #5169