diff --git a/pyproject.toml b/pyproject.toml
index c305bbc..79b722e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -11,7 +11,7 @@ packages = [
 
 [tool.poetry.dependencies]
 python = "^3.9"
-datasets = "^3.1"
+datasets = "^3.2"
 
 [tool.poetry.group.dev.dependencies]
 pytest = "^8.0.0"
diff --git a/pyspark_huggingface/huggingface.py b/pyspark_huggingface/huggingface.py
index e77bc35..fe20a24 100644
--- a/pyspark_huggingface/huggingface.py
+++ b/pyspark_huggingface/huggingface.py
@@ -1,10 +1,14 @@
+import ast
 from dataclasses import dataclass
-from typing import Sequence
+from typing import TYPE_CHECKING, Optional, Sequence
 from pyspark.sql.datasource import DataSource, DataSourceReader, InputPartition
 from pyspark.sql.pandas.types import from_arrow_schema
 from pyspark.sql.types import StructType
 
+if TYPE_CHECKING:
+    from datasets import DatasetBuilder, IterableDataset
+
 
 class HuggingFaceDatasets(DataSource):
     """
     A DataSource for reading and writing HuggingFace Datasets in Spark.
@@ -51,32 +55,71 @@ class HuggingFaceDatasets(DataSource):
     |Worth the enterta...|    0|
     |...                 |  ...|
     +--------------------+-----+
+
+    Enable predicate pushdown for Parquet datasets.
+
+    >>> spark.read.format("huggingface") \
+    ...     .option("filters", '[("language_score", ">", 0.99)]') \
+    ...     .option("columns", '["text", "language_score"]') \
+    ...     .load("HuggingFaceFW/fineweb-edu") \
+    ...     .show()
+    +--------------------+------------------+
+    |                text|    language_score|
+    +--------------------+------------------+
+    |died Aug. 28, 181...|0.9901925325393677|
+    |Coyotes spend a g...|0.9902171492576599|
+    |...                 |               ...|
+    +--------------------+------------------+
     """
 
+    DEFAULT_SPLIT: str = "train"
+
     def __init__(self, options):
         super().__init__(options)
+        from datasets import load_dataset_builder
+
         if "path" not in options or not options["path"]:
             raise Exception("You must specify a dataset name.")
 
+        kwargs = dict(self.options)
+        self.dataset_name = kwargs.pop("path")
+        self.config_name = kwargs.pop("config", None)
+        self.split = kwargs.pop("split", self.DEFAULT_SPLIT)
+        self.streaming = kwargs.pop("streaming", "true").lower() == "true"
+        for arg in kwargs:
+            if kwargs[arg].lower() == "true":
+                kwargs[arg] = True
+            elif kwargs[arg].lower() == "false":
+                kwargs[arg] = False
+            else:
+                try:
+                    kwargs[arg] = ast.literal_eval(kwargs[arg])
+                except ValueError:
+                    pass
+
+        self.builder = load_dataset_builder(self.dataset_name, self.config_name, **kwargs)
+        streaming_dataset = self.builder.as_streaming_dataset()
+        if self.split not in streaming_dataset:
+            raise Exception(f"Split {self.split} is invalid. Valid options are {list(streaming_dataset)}")
+
+        self.streaming_dataset = streaming_dataset[self.split]
+        if not self.streaming_dataset.features:
+            self.streaming_dataset = self.streaming_dataset._resolve_features()
+
     @classmethod
     def name(cls):
         return "huggingface"
 
     def schema(self):
-        from datasets import load_dataset_builder
-        dataset_name = self.options["path"]
-        config_name = self.options.get("config")
-        ds_builder = load_dataset_builder(dataset_name, config_name)
-        features = ds_builder.info.features
-        if features is None:
-            raise Exception(
-                "Unable to automatically determine the schema using the dataset features. "
-                "Please specify the schema manually using `.schema()`."
-            )
-        return from_arrow_schema(features.arrow_schema)
+        return from_arrow_schema(self.streaming_dataset.features.arrow_schema)
 
     def reader(self, schema: StructType) -> "DataSourceReader":
-        return HuggingFaceDatasetsReader(schema, self.options)
+        return HuggingFaceDatasetsReader(
+            schema,
+            builder=self.builder,
+            split=self.split,
+            streaming_dataset=self.streaming_dataset if self.streaming else None
+        )
 
 
 @dataclass
@@ -86,37 +129,33 @@ class Shard(InputPartition):
 
 
 class HuggingFaceDatasetsReader(DataSourceReader):
-    DEFAULT_SPLIT: str = "train"
 
-    def __init__(self, schema: StructType, options: dict):
-        from datasets import get_dataset_split_names, get_dataset_default_config_name
+    def __init__(self, schema: StructType, builder: "DatasetBuilder", split: str, streaming_dataset: Optional["IterableDataset"]):
         self.schema = schema
-        self.dataset_name = options["path"]
-        self.streaming = options.get("streaming", "true").lower() == "true"
-        self.config_name = options.get("config")
+        self.builder = builder
+        self.split = split
+        self.streaming_dataset = streaming_dataset
         # Get and validate the split name
-        self.split = options.get("split", self.DEFAULT_SPLIT)
-        valid_splits = get_dataset_split_names(self.dataset_name, self.config_name)
-        if self.split not in valid_splits:
-            raise Exception(f"Split {self.split} is invalid. Valid options are {valid_splits}")
 
     def partitions(self) -> Sequence[InputPartition]:
-        from datasets import load_dataset
-        if not self.streaming:
-            return [Shard(index=0)]
+        if self.streaming_dataset:
+            return [Shard(index=i) for i in range(self.streaming_dataset.num_shards)]
         else:
-            dataset = load_dataset(self.dataset_name, name=self.config_name, split=self.split, streaming=True)
-            return [Shard(index=i) for i in range(dataset.num_shards)]
+            return [Shard(index=0)]
 
     def read(self, partition: Shard):
-        from datasets import load_dataset
         columns = [field.name for field in self.schema.fields]
-        dataset = load_dataset(self.dataset_name, name=self.config_name, split=self.split, streaming=self.streaming)
-        if self.streaming:
-            shard = dataset.shard(num_shards=dataset.num_shards, index=partition.index)
-            for _, pa_table in shard._ex_iterable.iter_arrow():
-                yield from pa_table.select(columns).to_batches()
+        if self.streaming_dataset:
+            shard = self.streaming_dataset.shard(num_shards=self.streaming_dataset.num_shards, index=partition.index)
+            if shard._ex_iterable.iter_arrow:
+                for _, pa_table in shard._ex_iterable.iter_arrow():
+                    yield from pa_table.select(columns).to_batches()
+            else:
+                for _, example in shard._ex_iterable:
+                    yield example
         else:
+            self.builder.download_and_prepare()
+            dataset = self.builder.as_dataset(self.split)
             # Get the underlying arrow table of the dataset
             table = dataset._data
             yield from table.select(columns).to_batches()
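For reviewers who want to try the branch, here is a rough usage sketch of how the reworked options flow through the data source. It assumes PySpark 4.0+ with the Python Data Source API; the dataset and config names (`HuggingFaceFW/fineweb-edu`, `sample-10BT`) are only illustrative and are not part of this change.

```python
from pyspark.sql import SparkSession

from pyspark_huggingface.huggingface import HuggingFaceDatasets

spark = SparkSession.builder.appName("hf-datasource-demo").getOrCreate()
# Register the data source under its short name "huggingface".
spark.dataSource.register(HuggingFaceDatasets)

# Streaming read (the default): the builder is created on the driver, the split is
# validated against builder.as_streaming_dataset(), and each shard becomes a partition.
df = (
    spark.read.format("huggingface")
    .option("config", "sample-10BT")   # forwarded as config_name to load_dataset_builder
    .option("split", "train")
    .load("HuggingFaceFW/fineweb-edu")
)
df.printSchema()

# Non-streaming read: download_and_prepare() materializes the split, which is then
# read back from the underlying Arrow table as a single partition.
df_local = (
    spark.read.format("huggingface")
    .option("streaming", "false")
    .load("HuggingFaceFW/fineweb-edu")
)
```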