diff --git a/pyspark_huggingface/huggingface_sink.py b/pyspark_huggingface/huggingface_sink.py
index 7e6689c..8733fee 100644
--- a/pyspark_huggingface/huggingface_sink.py
+++ b/pyspark_huggingface/huggingface_sink.py
@@ -1,7 +1,7 @@
 import ast
 import logging
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Iterator, List, Optional
+from typing import TYPE_CHECKING, Iterator, List, Optional, Union
 
 from pyspark.sql.datasource import (
     DataSource,
@@ -11,7 +11,12 @@
 from pyspark.sql.types import StructType
 
 if TYPE_CHECKING:
-    from huggingface_hub import CommitOperationAdd, CommitOperationDelete
+    from huggingface_hub import (
+        CommitOperation,
+        CommitOperationAdd,
+        HfApi,
+    )
+    from huggingface_hub.hf_api import RepoFile, RepoFolder
     from pyarrow import RecordBatch
 
 logger = logging.getLogger(__name__)
@@ -27,8 +32,8 @@ class HuggingFaceSink(DataSource):
     Data Source Options:
     - token (str, required): HuggingFace API token for authentication.
     - path (str, required): HuggingFace repository ID, e.g. `{username}/{dataset}`.
-    - path_in_repo (str): Path within the repository to write the data. Defaults to the root.
-    - split (str): Split name to write the data to. Defaults to `train`. Only `train`, `test`, and `validation` are supported.
+    - path_in_repo (str): Path within the repository to write the data. Defaults to "data".
+    - split (str): Split name to write the data to. Defaults to `train`.
     - revision (str): Branch, tag, or commit to write to. Defaults to the main branch.
     - endpoint (str): Custom HuggingFace API endpoint URL.
     - max_bytes_per_file (int): Maximum size of each Parquet file.
@@ -125,7 +130,9 @@ def __init__(
         import uuid
 
         self.repo_id = repo_id
-        self.path_in_repo = (path_in_repo or "").strip("/")
+        self.path_in_repo = (
+            path_in_repo.strip("/") if path_in_repo is not None else "data"
+        )
         self.split = split or "train"
         self.revision = revision
         self.schema = schema
@@ -140,26 +147,7 @@ def __init__(
         # Use a unique filename prefix to avoid conflicts with existing files
         self.uuid = uuid.uuid4()
 
-        self.validate()
-
-    def validate(self):
-        if self.split not in ["train", "test", "validation"]:
-            """
-            TODO: Add support for custom splits.
-
-            For custom split names to be recognized, the files must have path with format:
-            `data/{split}-{iiiii}-of-{nnnnn}.parquet`
-            where `iiiii` is the part number and `nnnnn` is the total number of parts, both padded to 5 digits.
-            Example: `data/custom-00000-of-00002.parquet`
-
-            Therefore the current usage of UUID to avoid naming conflicts won't work for custom split names.
-            To fix this we can rename the files in the commit phase to satisfy the naming convention.
-            """
-            raise NotImplementedError(
-                f"Only 'train', 'test', and 'validation' splits are supported. Got '{self.split}'."
-            )
-
-    def get_api(self):
+    def _get_api(self):
         from huggingface_hub import HfApi
 
         return HfApi(token=self.token, endpoint=self.endpoint)
@@ -168,16 +156,11 @@ def get_api(self):
     @property
     def prefix(self) -> str:
         return f"{self.path_in_repo}/{self.split}".strip("/")
 
-    def get_delete_operations(self) -> Iterator["CommitOperationDelete"]:
+    def _list_split(self, api: "HfApi") -> Iterator[Union["RepoFile", "RepoFolder"]]:
         """
-        Get the commit operations to delete all existing Parquet files.
-        This is used when `overwrite=True` to clear the target directory.
+        Get all existing files of the current split.
         """
-        from huggingface_hub import CommitOperationDelete
         from huggingface_hub.errors import EntryNotFoundError
-        from huggingface_hub.hf_api import RepoFolder
-
-        api = self.get_api()
         try:
             objects = api.list_repo_tree(
@@ -190,11 +173,9 @@ def get_delete_operations(self) -> Iterator["CommitOperationDelete"]:
             )
             for obj in objects:
                 if obj.path.startswith(self.prefix):
-                    yield CommitOperationDelete(
-                        path_in_repo=obj.path, is_folder=isinstance(obj, RepoFolder)
-                    )
-        except EntryNotFoundError as e:
-            logger.info(f"Writing to a new path: {e}")
+                    yield obj
+        except EntryNotFoundError:
+            pass
 
     def write(self, iterator: Iterator["RecordBatch"]) -> HuggingFaceCommitMessage:
         import io
@@ -208,7 +189,7 @@ def write(self, iterator: Iterator["RecordBatch"]) -> HuggingFaceCommitMessage:
         context = TaskContext.get()
         partition_id = context.partitionId() if context else 0
 
-        api = self.get_api()
+        api = self._get_api()
 
         schema = to_arrow_schema(self.schema)
         num_files = 0
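For reference, the commit phase below renames uploaded files to the scheme the Hub expects for split files, `{split}-{iiiii}-of-{nnnnn}.parquet`, as described in the removed TODO docstring. A minimal sketch of the renumbering, with illustrative counts rather than values from a real repo:

# Appending 2 new files to a split that already holds 1 file renumbers
# everything to a shared total, mirroring format_path in commit() below.
# "data/custom" stands in for f"{path_in_repo}/{split}".
prefix = "data/custom"
count_existing, count_new = 1, 2
total = count_existing + count_new
paths = [f"{prefix}-{i:05d}-of-{total:05d}.parquet" for i in range(total)]
assert paths == [
    "data/custom-00000-of-00003.parquet",
    "data/custom-00001-of-00003.parquet",
    "data/custom-00002-of-00003.parquet",
]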
""" - from huggingface_hub import CommitOperationDelete from huggingface_hub.errors import EntryNotFoundError - from huggingface_hub.hf_api import RepoFolder - - api = self.get_api() try: objects = api.list_repo_tree( @@ -190,11 +173,9 @@ def get_delete_operations(self) -> Iterator["CommitOperationDelete"]: ) for obj in objects: if obj.path.startswith(self.prefix): - yield CommitOperationDelete( - path_in_repo=obj.path, is_folder=isinstance(obj, RepoFolder) - ) - except EntryNotFoundError as e: - logger.info(f"Writing to a new path: {e}") + yield obj + except EntryNotFoundError: + pass def write(self, iterator: Iterator["RecordBatch"]) -> HuggingFaceCommitMessage: import io @@ -208,7 +189,7 @@ def write(self, iterator: Iterator["RecordBatch"]) -> HuggingFaceCommitMessage: context = TaskContext.get() partition_id = context.partitionId() if context else 0 - api = self.get_api() + api = self._get_api() schema = to_arrow_schema(self.schema) num_files = 0 @@ -265,25 +246,84 @@ def flush(writer: pq.ParquetWriter): return HuggingFaceCommitMessage(additions=additions) def commit(self, messages: List[HuggingFaceCommitMessage]) -> None: # type: ignore[override] - import math + """ + Commit the pre-uploaded Parquet files to the HuggingFace Hub, renaming them to match the expected format: + `{split}-{current:05d}-of-{total:05d}.parquet`. + Also delete or rename existing files of the split, depending on the mode. + """ + + from huggingface_hub import CommitOperationCopy, CommitOperationDelete + from huggingface_hub.hf_api import RepoFile, RepoFolder - api = self.get_api() - operations = [ - addition for message in messages for addition in message.additions - ] - if self.overwrite: # Delete existing files if overwrite is enabled - operations.extend(self.get_delete_operations()) + api = self._get_api() + additions = [addition for message in messages for addition in message.additions] + operations = {} + count_new = len(additions) + count_existing = 0 + + def format_path(i): + return f"{self.prefix}-{i:05d}-of-{count_new + count_existing:05d}.parquet" + + def rename(old_path, new_path): + if old_path != new_path: + yield CommitOperationCopy( + src_path_in_repo=old_path, path_in_repo=new_path + ) + yield CommitOperationDelete(path_in_repo=old_path) + + # In overwrite mode, delete existing files + if self.overwrite: + for obj in self._list_split(api): + # Delete old file + operations[obj.path] = CommitOperationDelete( + path_in_repo=obj.path, is_folder=isinstance(obj, RepoFolder) + ) + # In append mode, rename existing files to have the correct total number of parts + else: + rename_operations = [] + existing = list( + obj for obj in self._list_split(api) if isinstance(obj, RepoFile) + ) + count_existing = len(existing) + for i, obj in enumerate(existing): + new_path = format_path(i) + rename_operations.extend(rename(obj.path, new_path)) + # Rename files in a separate commit to prevent them from being overwritten by new files of the same name + self._create_commits( + api, + operations=rename_operations, + message="Rename existing files before uploading new files using PySpark", + ) + + # Rename additions, putting them after existing files if any + for i, addition in enumerate(additions): + addition.path_in_repo = format_path(i + count_existing) + # Overwrite the deletion operation if the file already exists + operations[addition.path_in_repo] = addition + + # Upload the new files + self._create_commits( + api, + operations=list(operations.values()), + message="Upload using PySpark", + ) + + def 
diff --git a/tests/test_huggingface_writer.py b/tests/test_huggingface_writer.py
index 734621b..e9ea6ae 100644
--- a/tests/test_huggingface_writer.py
+++ b/tests/test_huggingface_writer.py
@@ -6,7 +6,6 @@
 from pyspark.testing import assertDataFrameEqual
 from pytest_mock import MockerFixture
 
-
 # ============== Fixtures & Helpers ==============
 
 @pytest.fixture(scope="session")
@@ -22,8 +21,10 @@ def token():
     return os.environ["HF_TOKEN"]
 
 
-def reader(spark):
-    return spark.read.format("huggingface").option("token", token())
+def load(repo, split):
+    from datasets import load_dataset
+
+    return load_dataset(repo, token=token(), split=split).to_pandas()
 
 
 def writer(df: DataFrame):
@@ -34,7 +35,7 @@ def writer(df: DataFrame):
 def random_df(spark: SparkSession):
     from pyspark.sql.functions import rand
 
-    return lambda n: spark.range(n).select((rand()).alias("value"))
+    return lambda n: spark.range(n, numPartitions=2).select((rand()).alias("value"))
 
 
 @pytest.fixture(scope="session")
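The updated tests read written splits back through the datasets library instead of the Spark reader. A hedged end-to-end sketch of the flow they exercise; the repo id, token, and spark session here are placeholders, not values from this patch:

# Write a DataFrame to a custom split, then load it back via datasets.
df = spark.range(10)
(
    df.write.format("huggingface")
    .option("token", "hf_xxx")
    .option("split", "custom")
    .mode("append")
    .save("username/dataset")
)

from datasets import load_dataset

rows = load_dataset("username/dataset", token="hf_xxx", split="custom").to_pandas()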
@@ -59,41 +60,44 @@ def repo(api, username):
 
 # ============== Tests ==============
 
-def test_basic(spark, repo, random_df):
+
+def test_basic(repo, random_df):
     df = random_df(10)
     writer(df).mode("append").save(repo)
-    actual = reader(spark).load(repo)
-    assertDataFrameEqual(df, actual)
+    actual = load(repo, "train")
+    assertDataFrameEqual(actual, df.toPandas())
 
 
-def test_append(spark, repo, random_df):
+@pytest.mark.parametrize("split", ["train", "custom"])
+def test_append(repo, random_df, split):
     df1 = random_df(10)
     df2 = random_df(10)
-    writer(df1).mode("append").save(repo)
-    writer(df2).mode("append").save(repo)
-    actual = reader(spark).load(repo)
+    writer(df1).options(split=split).mode("append").save(repo)
+    writer(df2).options(split=split).mode("append").save(repo)
+    actual = load(repo, split)
     expected = df1.union(df2)
-    assertDataFrameEqual(actual, expected)
+    assertDataFrameEqual(actual, expected.toPandas())
 
 
-def test_overwrite(spark, repo, random_df):
+@pytest.mark.parametrize("split", ["train", "custom"])
+def test_overwrite(repo, random_df, split):
     df1 = random_df(10)
     df2 = random_df(10)
-    writer(df1).mode("append").save(repo)
-    writer(df2).mode("overwrite").save(repo)
-    actual = reader(spark).load(repo)
-    assertDataFrameEqual(actual, df2)
+    writer(df1).options(split=split).mode("append").save(repo)
+    writer(df2).options(split=split).mode("overwrite").save(repo)
+    actual = load(repo, split)
+    assertDataFrameEqual(actual, df2.toPandas())
 
 
-def test_split(spark, repo, random_df):
+def test_split(repo, random_df):
     df1 = random_df(10)
     df2 = random_df(10)
     writer(df1).mode("append").save(repo)
-    writer(df2).mode("append").options(split="test").save(repo)
-    actual1 = reader(spark).options(split="train").load(repo)
-    actual2 = reader(spark).options(split="test").load(repo)
-    assertDataFrameEqual(actual1, df1)
-    assertDataFrameEqual(actual2, df2)
+    writer(df2).mode("append").options(split="custom").save(repo)
+    actual1 = load(repo, "train")
+    actual2 = load(repo, "custom")
+    assertDataFrameEqual(actual1, df1.toPandas())
+    assertDataFrameEqual(actual2, df2.toPandas())
 
 
 def test_revision(repo, random_df, api):