
[host_callback] Add support for pmap and for passing the device to tap #5182

Merged · 1 commit · Dec 15, 2020

Conversation

@gnecula (Collaborator) commented Dec 13, 2020

  • Adds support for jit of pmap and pmap of pmap.
  • Also adds a tap_with_device optional argument to id_print and
    id_tap, to have the tap function invoked with a device keyword argument.

In the presence of pmap there will be multiple devices sending data to the host. Each such
tap is processed separately (and taps from different devices may interleave). To make
the output more understandable, we added an option to pass the device to the tap function.
For backwards compatibility, the device is passed as a keyword argument to the tap function
only if the tap_with_device optional flag is passed to id_tap or id_print.
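
A minimal usage sketch (hedged: the exact callback signature is inferred from this description and from the examples below, where the tap function receives the value and the transforms tuple positionally, plus a device keyword argument when tap_with_device=True):

import jax
import jax.numpy as jnp
from jax.experimental import host_callback

def my_tap(arg, transforms, *, device=None):
  # `device` identifies which device sent the data; taps from different
  # devices may arrive interleaved.
  print(f"[{device}] {arg}")

def f(x):
  x = host_callback.id_tap(my_tap, x, tap_with_device=True)
  return x ** 2

jax.pmap(f)(jnp.arange(jax.local_device_count()))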

Issue: #5134
Fixes: #5169

@google-cla google-cla bot added the cla: yes label Dec 13, 2020
@gnecula gnecula requested a review from shoyer December 13, 2020 08:49
@gnecula gnecula self-assigned this Dec 13, 2020
@gnecula gnecula added the pull ready Ready for copybara import and testing label Dec 13, 2020
@@ -590,7 +590,7 @@ def _hashable_index(idx):
 # TODO(skye): is there a simpler way to rewrite this using sharding_spec?
 def _shard_sharded_device_array_slow_path(x, devices, indices):
   candidates = defaultdict(list)
-  for buf, idx in zip(x.device_buffers, x.indices):
+  for buf, idx in safe_zip(x.device_buffers, x.indices):
gnecula (Collaborator, Author) commented on this diff:

I changed zip to safe_zip because, while debugging my changes, zip was failing silently.
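
For illustration, a hedged sketch of the difference (safe_zip here refers to JAX's internal jax.util.safe_zip, which checks that its arguments have equal lengths):

from jax.util import safe_zip

# The builtin zip silently truncates to the shortest input, hiding bugs:
list(zip([1, 2, 3], ['a', 'b']))   # [(1, 'a'), (2, 'b')] -- no error
# safe_zip checks the lengths and raises instead:
safe_zip([1, 2, 3], ['a', 'b'])    # raises an error on the mismatch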

@gnecula gnecula force-pushed the print_pmap branch 2 times, most recently from b6bdb92 to 27da9a6 on December 13, 2020 at 11:02
@shoyer (Member) commented Dec 14, 2020

What do you think about treating pmap as just another transformation from the perspective of host_callback? E.g., adding an entry to the sequence of transformations, like ('pmap', {'device': ...})?

From a user perspective, I guess this might be a little more expected, and it would avoid needing to change the signature of the host callback. On the other hand, the device is part of how a computation is run (you can't nest pmap with different devices).

@gnecula (Collaborator, Author) commented Dec 14, 2020

I thought about adding pmap to the list of transforms, but it is not easy at all, because pmap is not really a transformation; it is more like jit than like vmap. One of the test cases I added is a jvp of pmap of vmap, specifically to explore this. By the time I can intervene, the code has already been transformed (with jvp of vmap). (The same would happen if we had jvp + jit + vmap.)

Perhaps you are suggesting adding a pmap entry at the end of transforms simply as a way to carry the device. I thought about that too, but the device on which a computation runs is significant even without pmap. That is why I think it makes sense to add the device to the host callback API.
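
For instance (a hedged sketch; jax.jit accepted a device= argument at the time of this PR):

import jax

# Even without pmap, a jitted computation is pinned to a particular
# device, which a tap could usefully report.
double_on_dev0 = jax.jit(lambda x: x * 2, device=jax.devices()[0])
double_on_dev0(3.0)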

@sbodenstein commented Dec 14, 2020

@gnecula: for the logging use case, we would like the data in the form @shoyer mentioned, but there is not enough information to get it into this form with the current setup, even with the device information. Consider:

@jax.pmap
@jax.vmap
def f(x):
  x = host_callback.id_tap(print, x)
  return x**2

If we swap the vmap and pmap, the output of the function is the same, but in one case the prints are

[0 0] (('batch', {'batch_dims': (0,)}),)
[1 1] (('batch', {'batch_dims': (0,)}),)

and for the other it is

[0 1] (('batch', {'batch_dims': (0,)}),)
[0 1] (('batch', {'batch_dims': (0,)}),)

I do realize that this is not a common way of using pmap and vmap, and perhaps it's not worth worrying about. Outside this sort of scenario, does one have enough information to reconstruct the log as though it ran on a single device?

@gnecula (Collaborator, Author) commented Dec 14, 2020

Unfortunately, I cannot surface precisely in the printing that the user had a vmap(pmap(f)), because that information is lost by the time we get to execute the code. What is printed is what actually executes on each device. Probably most users would not realize that vmap(pmap(f)) is identical to pmap(vmap(f), in_axes=1); in other words, pmap is always moved to the top level of the transformation stack, with in_axes adjusted accordingly. The new printing change can be used to experiment with and understand this. In fact, I have added test_pmap_vmap and test_vmap_pmap tests.
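
A hedged sketch of such an experiment on a 2-device setup, reusing the callback signature assumed above; per the explanation, the second call is internally rearranged so pmap sits on top:

import jax
import jax.numpy as jnp
from jax.experimental import host_callback

def show(arg, transforms, *, device=None):
  print(device, arg, transforms)

def f(x):
  x = host_callback.id_tap(show, x, tap_with_device=True)
  return x ** 2

x = jnp.arange(4.0).reshape((2, 2))
jax.pmap(jax.vmap(f))(x)  # each device taps one row of x
jax.vmap(jax.pmap(f))(x)  # rearranged internally; each device taps one column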

@shoyer (Member) commented Dec 14, 2020

Right, pmap is really just sugar for launching the same program on multiple devices at the same time. I guess this distinction is lost by the time host_callback runs?

return assertMultiLineStrippedEqual(tst, expected, what)


# Run all tests with 8 CPU devices.
shoyer (Member) commented on this diff:

is this a typo? It looks like it's always 2 CPU devices?

gnecula (Collaborator, Author):

A typo indeed, fixed. Thanks!

* Adds support for jit of pmap and pmap of pmap.
* Also adds a `tap_with_device` optional argument to `id_print` and
  `id_tap`, to have the tap function invoked with a device keyword argument.
* Added multiple tests involving pmap

Issue: google#5134
Fixes: google#5169
Labels: cla: yes · pull ready (Ready for copybara import and testing)

Successfully merging this pull request may close these issues.

jax.pmap causes multiple id_tap callbacks to run
3 participants