Skip to content

Commit

Permalink
[tf.data] Explicitly colocate prefetch dataset op with its input as t…
Browse files Browse the repository at this point in the history
…his collocation only happens automatically in graph mode.

PiperOrigin-RevId: 313867950
Change-Id: I88962b96f208b6d9019e0a117715f74efc8fdc67
  • Loading branch information
jsimsa authored and tensorflower-gardener committed May 29, 2020
1 parent 02dc6f8 commit 8be4d61
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,18 @@ def testPrefetchToDeviceGpuWithReInit(self):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element)

@combinations.generate(test_base.eager_only_combinations())
def testPrefetchToDevicePlacement(self):
if not test_util.is_gpu_available():
self.skipTest("No GPU available")

host_dataset = dataset_ops.Dataset.range(10)
device_dataset = host_dataset.apply(
prefetching_ops.prefetch_to_device("/gpu:0"))

self.assertEqual(device_dataset._variant_tensor.device,
"/job:localhost/replica:0/task:0/device:GPU:0")


if __name__ == "__main__":
test.main()
16 changes: 10 additions & 6 deletions tensorflow/python/data/ops/dataset_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -4347,14 +4347,18 @@ def __init__(self, input_dataset, buffer_size, slack_period=None):
"""
self._input_dataset = input_dataset
if buffer_size is None:
buffer_size = -1 # This is the sentinel for auto-tuning.
buffer_size = AUTOTUNE
self._buffer_size = ops.convert_to_tensor(
buffer_size, dtype=dtypes.int64, name="buffer_size")
variant_tensor = gen_dataset_ops.prefetch_dataset(
input_dataset._variant_tensor, # pylint: disable=protected-access
buffer_size=self._buffer_size,
slack_period=slack_period,
**self._flat_structure)
# pylint: disable=protected-access
# We colocate the prefetch dataset with its input as this collocation only
# happens automatically in graph mode.
with ops.device(input_dataset._variant_tensor.device):
variant_tensor = gen_dataset_ops.prefetch_dataset(
input_dataset._variant_tensor,
buffer_size=self._buffer_size,
slack_period=slack_period,
**self._flat_structure)
super(PrefetchDataset, self).__init__(input_dataset, variant_tensor)


Expand Down

0 comments on commit 8be4d61

Please sign in to comment.