Skip to content

Commit

Permalink
revise xla.device_put device logic (#2907)
Browse files Browse the repository at this point in the history
* revise xla.device_put device logic, fixes #2905

* remove test of behavior we don't want

Previously, we were testing that for a DeviceArray x, writing
jax.device_put(x) would evaluate to a DeviceArray *on the default
device*. Instead, we should be happy with just returning the same
DeviceArray without any movement.
  • Loading branch information
mattjj committed May 1, 2020
1 parent a4deae3 commit e06bde8
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 8 deletions.
17 changes: 12 additions & 5 deletions jax/interpreters/xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from collections import defaultdict
import itertools as it
import operator as op
from typing import Any, Callable, Dict, Sequence, Type
from typing import Any, Callable, Dict, Sequence, Type, Optional

from absl import logging
import numpy as onp
Expand Down Expand Up @@ -920,16 +920,23 @@ def _device_put_device_array(x, device):
return _force(x).device_buffer
device_put_handlers[DeviceArray] = _device_put_device_array

def _copy_device_array_to_device(x, device):
if is_device_constant(x):
def _copy_device_array_to_device(x: DeviceArray, device: Optional[xc.Device]):
if device is None:
# no copying to be done because there's no target specified
return x
elif is_device_constant(x):
# create a new DeviceArray with the same lazy expr, no copying
return DeviceArray(x.aval, device, x._lazy_expr, DeviceConstant(device))
elif xb.get_device_backend(device).platform == x.device_buffer.platform():
if device is None or x.device_buffer.device() == device:
# source and target platforms are the same
if x.device_buffer.device() == device:
# no copying to be done because source equals target
return x
else:
# move the buffer with a device-to-device copy
moved_buf = x.device_buffer.copy_to_device(device)
else:
# Buffers from different XLA backends are passed through the host.
# buffers from different XLA backends are passed through the host.
backend = xb.get_device_backend(device)
moved_buf = backend.buffer_from_pyval(x.device_buffer.to_py(), device)
return DeviceArray(x.aval, device, x._lazy_expr, moved_buf)
Expand Down
3 changes: 0 additions & 3 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,9 +323,6 @@ def test_device_put_across_platforms(self):
x = api.device_put(val, device=cpu_device)
self.assertEqual(x.device_buffer.device(), cpu_device)

y = api.device_put(x)
self.assertEqual(y.device_buffer.device(), default_device)

def test_jit_on_all_devices(self):
# Verifies we can run the same computation on every device present, even
# if they are, for example, different models of GPU.
Expand Down
18 changes: 18 additions & 0 deletions tests/multibackend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,24 @@ def my_sin(x): return np.sin(x)
result4 = api.jit(my_sin, backend="cpu")(2)
self.assertEqual(result4.device_buffer.device(), cpus[0])

@jtu.skip_on_devices("cpu") # test only makes sense on non-cpu backends
def test_indexing(self):
# https://github.com/google/jax/issues/2905
cpus = api.devices("cpu")

x = api.device_put(onp.ones(2), cpus[0])
y = x[0]
self.assertEqual(y.device_buffer.device(), cpus[0])

@jtu.skip_on_devices("cpu") # test only makes sense on non-cpu backends
def test_sum(self):
# https://github.com/google/jax/issues/2905
cpus = api.devices("cpu")

x = api.device_put(onp.ones(2), cpus[0])
y = x.sum()
self.assertEqual(y.device_buffer.device(), cpus[0])


if __name__ == "__main__":
absltest.main()

0 comments on commit e06bde8

Please sign in to comment.