-
Notifications
You must be signed in to change notification settings - Fork 5
Enable predicate pushdown #4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,14 @@ | ||
import ast | ||
from dataclasses import dataclass | ||
from typing import Sequence | ||
from typing import TYPE_CHECKING, Optional, Sequence | ||
|
||
from pyspark.sql.datasource import DataSource, DataSourceReader, InputPartition | ||
from pyspark.sql.pandas.types import from_arrow_schema | ||
from pyspark.sql.types import StructType | ||
|
||
if TYPE_CHECKING: | ||
from datasets import DatasetBuilder, IterableDataset | ||
|
||
class HuggingFaceDatasets(DataSource): | ||
""" | ||
A DataSource for reading and writing HuggingFace Datasets in Spark. | ||
|
@@ -51,32 +55,71 @@ class HuggingFaceDatasets(DataSource): | |
|Worth the enterta...| 0| | ||
|... | ...| | ||
+--------------------+-----+ | ||
|
||
Enable predicate pushdown for Parquet datasets. | ||
|
||
>>> spark.read.format("huggingface") \ | ||
... .option("filters", '[("language_score", ">", 0.99)]') \ | ||
... .option("columns", '["text", "language_score"]') \ | ||
... .load("HuggingFaceFW/fineweb-edu") \ | ||
... .show() | ||
+--------------------+------------------+ | ||
| text| language_score| | ||
+--------------------+------------------+ | ||
|died Aug. 28, 181...|0.9901925325393677| | ||
|Coyotes spend a g...|0.9902171492576599| | ||
|... | ...| | ||
+--------------------+------------------+ | ||
""" | ||
|
||
DEFAULT_SPLIT: str = "train" | ||
|
||
def __init__(self, options): | ||
super().__init__(options) | ||
from datasets import load_dataset_builder | ||
|
||
if "path" not in options or not options["path"]: | ||
raise Exception("You must specify a dataset name.") | ||
|
||
kwargs = dict(self.options) | ||
self.dataset_name = kwargs.pop("path") | ||
self.config_name = kwargs.pop("config", None) | ||
self.split = kwargs.pop("split", self.DEFAULT_SPLIT) | ||
self.streaming = kwargs.pop("streaming", "true").lower() == "true" | ||
for arg in kwargs: | ||
if kwargs[arg].lower() == "true": | ||
kwargs[arg] = True | ||
elif kwargs[arg].lower() == "false": | ||
kwargs[arg] = False | ||
else: | ||
try: | ||
kwargs[arg] = ast.literal_eval(kwargs[arg]) | ||
except ValueError: | ||
pass | ||
|
||
self.builder = load_dataset_builder(self.dataset_name, self.config_name, **kwargs) | ||
streaming_dataset = self.builder.as_streaming_dataset() | ||
if self.split not in streaming_dataset: | ||
raise Exception(f"Split {self.split} is invalid. Valid options are {list(streaming_dataset)}") | ||
|
||
self.streaming_dataset = streaming_dataset[self.split] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just want to make sure: is streaming_dataset (or the dataset builder) pickleable? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes everything is pickleable :) |
||
if not self.streaming_dataset.features: | ||
self.streaming_dataset = self.streaming_dataset._resolve_features() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is the ultimate way to be 100% sure we have the |
||
|
||
@classmethod | ||
def name(cls): | ||
return "huggingface" | ||
|
||
def schema(self): | ||
from datasets import load_dataset_builder | ||
dataset_name = self.options["path"] | ||
config_name = self.options.get("config") | ||
ds_builder = load_dataset_builder(dataset_name, config_name) | ||
features = ds_builder.info.features | ||
if features is None: | ||
raise Exception( | ||
"Unable to automatically determine the schema using the dataset features. " | ||
"Please specify the schema manually using `.schema()`." | ||
) | ||
return from_arrow_schema(features.arrow_schema) | ||
return from_arrow_schema(self.streaming_dataset.features.arrow_schema) | ||
|
||
def reader(self, schema: StructType) -> "DataSourceReader": | ||
return HuggingFaceDatasetsReader(schema, self.options) | ||
return HuggingFaceDatasetsReader( | ||
schema, | ||
builder=self.builder, | ||
split=self.split, | ||
streaming_dataset=self.streaming_dataset if self.streaming else None | ||
) | ||
|
||
|
||
@dataclass | ||
|
@@ -86,37 +129,33 @@ class Shard(InputPartition): | |
|
||
|
||
class HuggingFaceDatasetsReader(DataSourceReader): | ||
DEFAULT_SPLIT: str = "train" | ||
|
||
def __init__(self, schema: StructType, options: dict): | ||
from datasets import get_dataset_split_names, get_dataset_default_config_name | ||
def __init__(self, schema: StructType, builder: "DatasetBuilder", split: str, streaming_dataset: Optional["IterableDataset"]): | ||
self.schema = schema | ||
self.dataset_name = options["path"] | ||
self.streaming = options.get("streaming", "true").lower() == "true" | ||
self.config_name = options.get("config") | ||
self.builder = builder | ||
self.split = split | ||
self.streaming_dataset = streaming_dataset | ||
# Get and validate the split name | ||
self.split = options.get("split", self.DEFAULT_SPLIT) | ||
valid_splits = get_dataset_split_names(self.dataset_name, self.config_name) | ||
if self.split not in valid_splits: | ||
raise Exception(f"Split {self.split} is invalid. Valid options are {valid_splits}") | ||
|
||
def partitions(self) -> Sequence[InputPartition]: | ||
from datasets import load_dataset | ||
if not self.streaming: | ||
return [Shard(index=0)] | ||
if self.streaming_dataset: | ||
return [Shard(index=i) for i in range(self.streaming_dataset.num_shards)] | ||
else: | ||
dataset = load_dataset(self.dataset_name, name=self.config_name, split=self.split, streaming=True) | ||
return [Shard(index=i) for i in range(dataset.num_shards)] | ||
return [Shard(index=0)] | ||
|
||
def read(self, partition: Shard): | ||
from datasets import load_dataset | ||
columns = [field.name for field in self.schema.fields] | ||
dataset = load_dataset(self.dataset_name, name=self.config_name, split=self.split, streaming=self.streaming) | ||
if self.streaming: | ||
shard = dataset.shard(num_shards=dataset.num_shards, index=partition.index) | ||
for _, pa_table in shard._ex_iterable.iter_arrow(): | ||
yield from pa_table.select(columns).to_batches() | ||
if self.streaming_dataset: | ||
shard = self.streaming_dataset.shard(num_shards=self.streaming_dataset.num_shards, index=partition.index) | ||
if shard._ex_iterable.iter_arrow: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. minor edit: some streaming datasets don't have |
||
for _, pa_table in shard._ex_iterable.iter_arrow(): | ||
yield from pa_table.select(columns).to_batches() | ||
else: | ||
for _, example in shard: | ||
yield example | ||
else: | ||
self.builder.download_and_prepare() | ||
dataset = self.builder.as_dataset(self.split) | ||
Comment on lines
+157
to
+158
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. minor edit: reuse the builder instead of calling load_dataset |
||
# Get the underlying arrow table of the dataset | ||
table = dataset._data | ||
yield from table.select(columns).to_batches() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is this ok to pass all the remaining
options
as kwargs to the builder ? or should I set an allow-list / disallow-list ?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be fine. Just note all values in
options
are string.