Conversation

allisonwang-db (Contributor)

This PR:

  • Implements the partitions method in DataSourceReader, using the num_shards parameter of the IterableDataset to read a streaming dataset across multiple workers.
  • Supports non-streaming mode: load_dataset(..., streaming=False)
  • Supports dataset configurations: load_dataset(path, name=config_name, ...)
  • Changes the return type of the data source read method to an iterator of Arrow record batches (see the sketch below). Note this can only be tested against the Spark master branch build, not the Spark 4.0.0.dev2 release.
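
A minimal sketch of a read method that yields Arrow record batches (an illustration under assumptions, not this PR's exact code: Shard, self.path, and self.split are assumed reader attributes, and IterableDataset.iter / split_dataset_by_node are used as in recent datasets releases):

from typing import Iterator

import pyarrow as pa

def read(self, partition: Shard) -> Iterator[pa.RecordBatch]:
    from datasets import load_dataset
    from datasets.distributed import split_dataset_by_node
    dataset = load_dataset(self.path, split=self.split, streaming=True)
    # Keep only this partition's shard of the streaming dataset.
    shard = split_dataset_by_node(dataset, rank=partition.index, world_size=dataset.num_shards)
    for batch in shard.iter(batch_size=1024):
        # Each batch is a dict of columns; hand Spark an Arrow RecordBatch.
        yield pa.RecordBatch.from_pydict(batch)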

Example

df = (
    spark.read.format("huggingface")
    .option("split", "train")
    .option("config", "plain_text")
    .option("streaming", "true")
    .load("rotten_tomatoes")
)
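
The review hunk below shows the non-streaming branch of partitions; for context, a hedged sketch of the whole method, where the streaming branch creates one input partition per shard (Shard and the self.* option attributes mirror this PR but are assumptions here):

def partitions(self) -> Sequence[InputPartition]:
    from datasets import load_dataset
    if not self.streaming:
        # Non-streaming: the whole dataset is read as one partition for now.
        return [Shard(index=0)]
    dataset = load_dataset(self.path, self.config, split=self.split, streaming=True)
    # Streaming: one Spark input partition per shard of the IterableDataset.
    return [Shard(index=i) for i in range(dataset.num_shards)]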

def partitions(self) -> Sequence[InputPartition]:
    from datasets import load_dataset
    if not self.streaming:
        return [Shard(index=0)]
allisonwang-db (Contributor, Author)

@lhoestq I am not able to get num_shards for a non-streaming dataset. Do you know if this is supported?
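
For reference, num_shards is a property of the streaming IterableDataset, while the map-style Dataset does not expose one (a quick check, assuming a recent datasets release):

from datasets import load_dataset

ids = load_dataset("rotten_tomatoes", split="train", streaming=True)
print(ids.num_shards)  # IterableDataset exposes its shard count

ds = load_dataset("rotten_tomatoes", split="train")
print(hasattr(ds, "num_shards"))  # False: a Dataset is one memory-mapped table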

lhoestq (Member) · Nov 27, 2024

When the dataset is available locally, it is loaded as a Dataset backed by a memory-mapped Arrow Table that is the concatenation of all the shards. So in practice you don't really care about the shards themselves, since you can take whatever slice of the Table you want. The number of shards can be set to the maximum level of parallelism of the Spark setup, or we can decide to have as many shards as cached Arrow files, or as many shards as Arrow record batches, for example.
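
A minimal sketch of that slicing idea, using Dataset.shard to take contiguous slices of the memory-mapped Table (the dataset name and partition count are placeholders):

from datasets import load_dataset

ds = load_dataset("rotten_tomatoes", split="train")  # memory-mapped Arrow Table
num_partitions = 8  # e.g. the Spark setup's maximum parallelism
# Each shard is a contiguous slice of the underlying Table; nothing is re-downloaded.
parts = [
    ds.shard(num_shards=num_partitions, index=i, contiguous=True)
    for i in range(num_partitions)
]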

lhoestq (Member) · Nov 27, 2024

Here load_dataset(..., streaming=False) downloads the full dataset and prepares it as Arrow files locally, so it must be called only once. I understand that in the current implementation it would be called once, since the number of partitions is set to 1, so it works, but it doesn't leverage Spark's distributed execution.

There is an experimental feature that was added a while ago to let load_dataset run in parallel on Spark via joblibspark: huggingface/datasets#5924

from joblib import parallel_backend
from joblibspark import register_spark; register_spark()  # register the "spark" backend
with parallel_backend('spark') as backend:
    ds = load_dataset(..., streaming=False, num_proc=<number of spark jobs to spawn>)  # returns directly if the dataset is cached

It's also possible to get the Dataset from the downloaded and prepared Arrow dataset like this

from datasets import load_dataset_builder

builder = load_dataset_builder(...)
with parallel_backend('spark') as backend:
    builder.download_and_prepare(..., num_proc=...)  # returns directly if the dataset is cached
ds = builder.as_dataset(split)
# EDIT: it should be possible to get an IterableDataset as well but I need to double check
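
Regarding the EDIT above: recent datasets releases also have Dataset.to_iterable_dataset, which may cover the IterableDataset case (worth double-checking against the installed version):

ids = builder.as_dataset(split).to_iterable_dataset(num_shards=16)  # shard count is a placeholder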

If this doesn't fit the current implementation well, we can keep it for later and call the builder's internals manually in proper Spark code if needed.

allisonwang-db (Contributor, Author)

Sounds good. I can try it out using Spark cluster mode instead of local mode to see if streaming works better.

lhoestq (Member) left a comment

Looking good! We can see if we want to parallelize the non-streaming case later anyway :)

allisonwang-db merged commit 6f1e38d into main on Dec 4, 2024