Skip to content

Commit

Permalink
[Data] Apply limit to Dataset.take() and related methods (ray-proje…
Browse files Browse the repository at this point in the history
…ct#38677)

To improve the efficiency of common Dataset access methods such as Dataset.take(), .take_batch(), .show(), etc., we can apply a Limit before taking rows for these methods to avoid materializing more rows than requested by the user.

Signed-off-by: Scott Lee <sjl@anyscale.com>
Signed-off-by: e428265 <arvind.chandramouli@lmco.com>
  • Loading branch information
scottjlee authored and arvind-chandra committed Aug 31, 2023
1 parent 4338539 commit 8dd6649
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 11 deletions.
14 changes: 12 additions & 2 deletions python/ray/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2322,10 +2322,12 @@ def take_batch(
``ValueError``: if the dataset is empty.
"""
batch_format = _apply_strict_mode_batch_format(batch_format)
limited_ds = self.limit(batch_size)

try:
res = next(
iter(
self.iter_batches(
limited_ds.iter_batches(
batch_size=batch_size,
prefetch_batches=0,
batch_format=batch_format,
Expand All @@ -2335,6 +2337,9 @@ def take_batch(
except StopIteration:
raise ValueError("The dataset is empty.")
self._synchronize_progress_bar()

# Save the computed stats to the original dataset.
self._plan._snapshot_stats = limited_ds._plan.stats()
return res

@ConsumptionAPI
Expand Down Expand Up @@ -2375,11 +2380,16 @@ def take(self, limit: int = 20) -> List[Dict[str, Any]]:
"records in pandas or numpy batch format."
)
output = []
for row in self.iter_rows():

limited_ds = self.limit(limit)
for row in limited_ds.iter_rows():
output.append(row)
if len(output) >= limit:
break
self._synchronize_progress_bar()

# Save the computed stats to the original dataset.
self._plan._snapshot_stats = limited_ds._plan.stats()
return output

@ConsumptionAPI
Expand Down
48 changes: 39 additions & 9 deletions python/ray/data/dataset_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -1050,9 +1050,22 @@ def write_datasource(
)

def take(self, limit: int = 20) -> List[Dict[str, Any]]:
"""Call :py:meth:`Dataset.take <ray.data.Dataset.take>` over the stream of
output batches from the pipeline"""
return Dataset.take(self, limit)
"""Replicates the logic of :py:meth:`Dataset.take <ray.data.Dataset.take>`
over the stream of output batches from the pipeline, excluding logic
of applying a `Limit[batch_size]` before taking rows."""
if ray.util.log_once("dataset_take"):
logger.info(
"Tip: Use `take_batch()` instead of `take() / show()` to return "
"records in pandas or numpy batch format."
)

output = []
for row in self.iter_rows():
output.append(row)
if len(output) >= limit:
break
self._synchronize_progress_bar()
return output

def take_all(self, limit: Optional[int] = None) -> List[Dict[str, Any]]:
"""Call :py:meth:`Dataset.take_all <ray.data.Dataset.take_all>` over the stream
Expand All @@ -1062,14 +1075,31 @@ def take_all(self, limit: Optional[int] = None) -> List[Dict[str, Any]]:
def take_batch(
self, batch_size: int = 20, *, batch_format: Optional[str] = "default"
) -> DataBatch:
"""Call :py:meth:`Dataset.take_batch <ray.data.Dataset.take_batch>`
over the stream of output batches from the pipeline"""
return Dataset.take_batch(self, batch_size, batch_format=batch_format)
"""Replicates the logic of :py:meth:`Dataset.take_batch <ray.data.Dataset.take_batch>`
over the stream of output batches from the pipeline, excluding logic
of applying a `Limit[batch_size]` before taking rows."""
batch_format = _apply_strict_mode_batch_format(batch_format)
try:
res = next(
iter(
self.iter_batches(
batch_size=batch_size,
prefetch_batches=0,
batch_format=batch_format,
)
)
)
except StopIteration:
raise ValueError("The dataset is empty.")
self._synchronize_progress_bar()
return res

def show(self, limit: int = 20) -> None:
"""Call :py:meth:`Dataset.show <ray.data.Dataset.show>` over the stream of
output batches from the pipeline"""
return Dataset.show(self, limit)
"""Replicates the logic of :py:meth:`Dataset.show <ray.data.Dataset.show>`
over the stream of output batches from the pipeline, excluding logic
of applying a `Limit[batch_size]` before taking rows."""
for row in self.take(limit):
print(row)

def iter_tf_batches(
self,
Expand Down

0 comments on commit 8dd6649

Please sign in to comment.