pyproject.toml (2 changes: 1 addition & 1 deletion)
@@ -11,7 +11,7 @@ packages = [

 [tool.poetry.dependencies]
 python = "^3.9"
-datasets = "^3.1"
+datasets = "^3.2"
 
 [tool.poetry.group.dev.dependencies]
 pytest = "^8.0.0"

pyspark_huggingface/huggingface.py (107 changes: 73 additions & 34 deletions)
@@ -1,10 +1,14 @@
+import ast
 from dataclasses import dataclass
-from typing import Sequence
+from typing import TYPE_CHECKING, Optional, Sequence
 
 from pyspark.sql.datasource import DataSource, DataSourceReader, InputPartition
 from pyspark.sql.pandas.types import from_arrow_schema
 from pyspark.sql.types import StructType
 
+if TYPE_CHECKING:
+    from datasets import DatasetBuilder, IterableDataset
+
 class HuggingFaceDatasets(DataSource):
     """
     A DataSource for reading and writing HuggingFace Datasets in Spark.
@@ -51,32 +55,71 @@ class HuggingFaceDatasets(DataSource):
     |Worth the enterta...|    0|
     |...                 |  ...|
     +--------------------+-----+
 
+    Enable predicate pushdown for Parquet datasets.
+
+    >>> spark.read.format("huggingface") \
+    ...     .option("filters", '[("language_score", ">", 0.99)]') \
+    ...     .option("columns", '["text", "language_score"]') \
+    ...     .load("HuggingFaceFW/fineweb-edu") \
+    ...     .show()
+    +--------------------+------------------+
+    |                text|    language_score|
+    +--------------------+------------------+
+    |died Aug. 28, 181...|0.9901925325393677|
+    |Coyotes spend a g...|0.9902171492576599|
+    |...                 |               ...|
+    +--------------------+------------------+
     """
 
+    DEFAULT_SPLIT: str = "train"
+
     def __init__(self, options):
         super().__init__(options)
+        from datasets import load_dataset_builder
 
         if "path" not in options or not options["path"]:
             raise Exception("You must specify a dataset name.")
 
+        kwargs = dict(self.options)
Member Author: is this ok to pass all the remaining options as kwargs to the builder, or should I set an allow-list / disallow-list?

Contributor: This should be fine. Just note that all values in options are strings.
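Since every `.option()` value reaches the data source as a string, the coercion loop added just below turns booleans and Python literals back into real values before the remaining kwargs are forwarded to the builder. A minimal standalone sketch of that coercion (the option names and values are illustrative, not taken from the diff):

```python
import ast

# Spark hands every DataSource option to Python as a string.
raw_options = {
    "streaming": "false",                          # should become the bool False
    "filters": '[("language_score", ">", 0.99)]',  # should become a list of tuples
    "columns": '["text", "language_score"]',       # should become a list of strings
    "data_dir": "en/v1",                           # not a Python literal: stays a string
}

parsed = {}
for key, value in raw_options.items():
    if value.lower() == "true":
        parsed[key] = True
    elif value.lower() == "false":
        parsed[key] = False
    else:
        try:
            parsed[key] = ast.literal_eval(value)
        except ValueError:
            parsed[key] = value  # leave non-literal strings untouched

print(parsed["streaming"], parsed["filters"][0], parsed["data_dir"])
# False ('language_score', '>', 0.99) en/v1
```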

+        self.dataset_name = kwargs.pop("path")
+        self.config_name = kwargs.pop("config", None)
+        self.split = kwargs.pop("split", self.DEFAULT_SPLIT)
+        self.streaming = kwargs.pop("streaming", "true").lower() == "true"
+        for arg in kwargs:
+            if kwargs[arg].lower() == "true":
+                kwargs[arg] = True
+            elif kwargs[arg].lower() == "false":
+                kwargs[arg] = False
+            else:
+                try:
+                    kwargs[arg] = ast.literal_eval(kwargs[arg])
+                except ValueError:
+                    pass
+
+        self.builder = load_dataset_builder(self.dataset_name, self.config_name, **kwargs)
+        streaming_dataset = self.builder.as_streaming_dataset()
+        if self.split not in streaming_dataset:
+            raise Exception(f"Split {self.split} is invalid. Valid options are {list(streaming_dataset)}")
+
+        self.streaming_dataset = streaming_dataset[self.split]
Contributor: Just want to make sure: is streaming_dataset (or the dataset builder) pickleable?

Member Author: yes, everything is pickleable :)
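Context for the question: Spark pickles the reader and everything it references when it ships work to executors, so both the builder and the streaming dataset need to survive a round-trip. A quick standalone check, as a sketch (the dataset name is only an example):

```python
import pickle

from datasets import load_dataset_builder

# Any Hub dataset works here; "stanfordnlp/imdb" is only an example.
builder = load_dataset_builder("stanfordnlp/imdb")
streaming_dataset = builder.as_streaming_dataset()["train"]

# Both objects should survive a pickle round-trip, which is effectively
# what Spark does when it sends the reader to executor processes.
restored_builder = pickle.loads(pickle.dumps(builder))
restored_stream = pickle.loads(pickle.dumps(streaming_dataset))
print(type(restored_builder).__name__, type(restored_stream).__name__)
```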

+        if not self.streaming_dataset.features:
+            self.streaming_dataset = self.streaming_dataset._resolve_features()
Member Author: this is the ultimate way to be 100% sure we have the features, because for some data formats like JSON Lines we need to stream some rows to infer the features
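A hedged illustration of why this step exists: self-describing formats declare their features up front, but for formats like JSON Lines a streaming dataset may report no features until a few rows have been read. The file name below is hypothetical, and `_resolve_features` is the private `datasets` helper the PR relies on:

```python
from datasets import load_dataset

# "data.jsonl" is a hypothetical local file used only for illustration.
ds = load_dataset("json", data_files="data.jsonl", split="train", streaming=True)
print(ds.features)  # may be None: JSON Lines carries no declared schema

# Streams a handful of rows to infer the schema (private datasets API).
ds = ds._resolve_features()
print(ds.features.arrow_schema)  # now available for from_arrow_schema(...)
```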


     @classmethod
     def name(cls):
         return "huggingface"
 
     def schema(self):
-        from datasets import load_dataset_builder
-        dataset_name = self.options["path"]
-        config_name = self.options.get("config")
-        ds_builder = load_dataset_builder(dataset_name, config_name)
-        features = ds_builder.info.features
-        if features is None:
-            raise Exception(
-                "Unable to automatically determine the schema using the dataset features. "
-                "Please specify the schema manually using `.schema()`."
-            )
-        return from_arrow_schema(features.arrow_schema)
+        return from_arrow_schema(self.streaming_dataset.features.arrow_schema)
 
     def reader(self, schema: StructType) -> "DataSourceReader":
-        return HuggingFaceDatasetsReader(schema, self.options)
+        return HuggingFaceDatasetsReader(
+            schema,
+            builder=self.builder,
+            split=self.split,
+            streaming_dataset=self.streaming_dataset if self.streaming else None
+        )
 
 
 @dataclass
@@ -86,37 +129,33 @@ class Shard(InputPartition):


 class HuggingFaceDatasetsReader(DataSourceReader):
-    DEFAULT_SPLIT: str = "train"
-
-    def __init__(self, schema: StructType, options: dict):
-        from datasets import get_dataset_split_names, get_dataset_default_config_name
+    def __init__(self, schema: StructType, builder: "DatasetBuilder", split: str, streaming_dataset: Optional["IterableDataset"]):
         self.schema = schema
-        self.dataset_name = options["path"]
-        self.streaming = options.get("streaming", "true").lower() == "true"
-        self.config_name = options.get("config")
+        self.builder = builder
+        self.split = split
+        self.streaming_dataset = streaming_dataset
-        # Get and validate the split name
-        self.split = options.get("split", self.DEFAULT_SPLIT)
-        valid_splits = get_dataset_split_names(self.dataset_name, self.config_name)
-        if self.split not in valid_splits:
-            raise Exception(f"Split {self.split} is invalid. Valid options are {valid_splits}")
 
     def partitions(self) -> Sequence[InputPartition]:
-        from datasets import load_dataset
-        if not self.streaming:
-            return [Shard(index=0)]
+        if self.streaming_dataset:
+            return [Shard(index=i) for i in range(self.streaming_dataset.num_shards)]
         else:
-            dataset = load_dataset(self.dataset_name, name=self.config_name, split=self.split, streaming=True)
-            return [Shard(index=i) for i in range(dataset.num_shards)]
+            return [Shard(index=0)]
 
     def read(self, partition: Shard):
-        from datasets import load_dataset
         columns = [field.name for field in self.schema.fields]
-        dataset = load_dataset(self.dataset_name, name=self.config_name, split=self.split, streaming=self.streaming)
-        if self.streaming:
-            shard = dataset.shard(num_shards=dataset.num_shards, index=partition.index)
-            for _, pa_table in shard._ex_iterable.iter_arrow():
-                yield from pa_table.select(columns).to_batches()
+        if self.streaming_dataset:
+            shard = self.streaming_dataset.shard(num_shards=self.streaming_dataset.num_shards, index=partition.index)
+            if shard._ex_iterable.iter_arrow:
Member Author: minor edit: some streaming datasets don't have iter_arrow, like the WebDataset format, for which we stream python objects

+                for _, pa_table in shard._ex_iterable.iter_arrow():
+                    yield from pa_table.select(columns).to_batches()
+            else:
+                for _, example in shard:
+                    yield example
         else:
+            self.builder.download_and_prepare()
+            dataset = self.builder.as_dataset(self.split)
Member Author (on lines +157 to +158): minor edit: reuse the builder instead of calling load_dataset
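For reference, this non-streaming branch reuses the builder configured in `__init__` instead of resolving the dataset a second time through `load_dataset`. A standalone sketch of the same pattern (the dataset name is illustrative):

```python
from datasets import load_dataset_builder

# Configure the builder once; in the data source this happens in __init__.
builder = load_dataset_builder("stanfordnlp/imdb")

# Materialize the split locally as an Arrow-backed Dataset without
# going through load_dataset() again.
builder.download_and_prepare()
dataset = builder.as_dataset(split="train")
print(dataset.num_rows)
```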

             # Get the underlying arrow table of the dataset
             table = dataset._data
             yield from table.select(columns).to_batches()
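Putting the pieces together, a typical read with the options this PR wires up might look like the sketch below. The dataset name, config name, and session setup are illustrative, and it assumes the `huggingface` data source is available to the Spark session (e.g. registered by the installed package):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("hf-read-example").getOrCreate()

# "path" becomes the Hub dataset name; config/split/streaming are consumed by
# __init__, and any other options are forwarded to load_dataset_builder as kwargs.
df = (
    spark.read.format("huggingface")
    .option("config", "plain_text")   # optional builder config name
    .option("split", "test")          # defaults to "train"
    .option("streaming", "true")      # "false" uses download_and_prepare instead
    .load("stanfordnlp/imdb")
)
df.show(5)
```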