132 changes: 86 additions & 46 deletions pyspark_huggingface/huggingface_sink.py
@@ -1,7 +1,7 @@
import ast
import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING, Iterator, List, Optional
from typing import TYPE_CHECKING, Iterator, List, Optional, Union

from pyspark.sql.datasource import (
DataSource,
@@ -11,7 +11,12 @@
from pyspark.sql.types import StructType

if TYPE_CHECKING:
from huggingface_hub import CommitOperationAdd, CommitOperationDelete
from huggingface_hub import (
CommitOperation,
CommitOperationAdd,
HfApi,
)
from huggingface_hub.hf_api import RepoFile, RepoFolder
from pyarrow import RecordBatch

logger = logging.getLogger(__name__)
@@ -27,8 +32,8 @@ class HuggingFaceSink(DataSource):
Data Source Options:
- token (str, required): HuggingFace API token for authentication.
- path (str, required): HuggingFace repository ID, e.g. `{username}/{dataset}`.
- path_in_repo (str): Path within the repository to write the data. Defaults to the root.
- split (str): Split name to write the data to. Defaults to `train`. Only `train`, `test`, and `validation` are supported.
- path_in_repo (str): Path within the repository to write the data. Defaults to "data".
- split (str): Split name to write the data to. Defaults to `train`.
- revision (str): Branch, tag, or commit to write to. Defaults to the main branch.
- endpoint (str): Custom HuggingFace API endpoint URL.
- max_bytes_per_file (int): Maximum size of each Parquet file.
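
Example usage (an illustrative sketch: the repository id and token are placeholders, and it assumes the data source is registered under the format name `huggingface` as in the tests below):

    df.write.format("huggingface") \
        .option("token", "hf_xxx") \
        .option("path_in_repo", "data") \
        .option("split", "train") \
        .mode("append") \
        .save("username/my-dataset")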
@@ -125,7 +130,9 @@ def __init__(
import uuid

self.repo_id = repo_id
self.path_in_repo = (path_in_repo or "").strip("/")
self.path_in_repo = (
path_in_repo.strip("/") if path_in_repo is not None else "data"
)
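# e.g. "/data/subdir/" is normalized to "data/subdir"; when the option is missing it falls back to "data"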
self.split = split or "train"
self.revision = revision
self.schema = schema
@@ -140,26 +147,7 @@ def __init__(
# Use a unique filename prefix to avoid conflicts with existing files
self.uuid = uuid.uuid4()

self.validate()

def validate(self):
if self.split not in ["train", "test", "validation"]:
"""
TODO: Add support for custom splits.

For custom split names to be recognized, the files must have paths in the format:
`data/{split}-{iiiii}-of-{nnnnn}.parquet`
where `iiiii` is the part number and `nnnnn` is the total number of parts, both padded to 5 digits.
Example: `data/custom-00000-of-00002.parquet`

Therefore the current usage of UUID to avoid naming conflicts won't work for custom split names.
To fix this we can rename the files in the commit phase to satisfy the naming convention.
"""
raise NotImplementedError(
f"Only 'train', 'test', and 'validation' splits are supported. Got '{self.split}'."
)

def get_api(self):
def _get_api(self):
from huggingface_hub import HfApi

return HfApi(token=self.token, endpoint=self.endpoint)
@@ -168,16 +156,11 @@ def get_api(self):
def prefix(self) -> str:
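# e.g. path_in_repo="data" and split="train" give "data/train"; the strip("/") keeps the
# prefix relative when path_in_repo is empty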
return f"{self.path_in_repo}/{self.split}".strip("/")

def get_delete_operations(self) -> Iterator["CommitOperationDelete"]:
def _list_split(self, api: "HfApi") -> Iterator[Union["RepoFile", "RepoFolder"]]:
"""
Get the commit operations to delete all existing Parquet files.
This is used when `overwrite=True` to clear the target directory.
Get all existing files of the current split.
"""
from huggingface_hub import CommitOperationDelete
from huggingface_hub.errors import EntryNotFoundError
from huggingface_hub.hf_api import RepoFolder

api = self.get_api()

try:
objects = api.list_repo_tree(
@@ -190,11 +173,9 @@ def get_delete_operations(self) -> Iterator["CommitOperationDelete"]:
)
for obj in objects:
if obj.path.startswith(self.prefix):
yield CommitOperationDelete(
path_in_repo=obj.path, is_folder=isinstance(obj, RepoFolder)
)
except EntryNotFoundError as e:
logger.info(f"Writing to a new path: {e}")
yield obj
except EntryNotFoundError:
pass

def write(self, iterator: Iterator["RecordBatch"]) -> HuggingFaceCommitMessage:
import io
@@ -208,7 +189,7 @@ def write(self, iterator: Iterator["RecordBatch"]) -> HuggingFaceCommitMessage:
context = TaskContext.get()
partition_id = context.partitionId() if context else 0

api = self.get_api()
api = self._get_api()

schema = to_arrow_schema(self.schema)
num_files = 0
@@ -265,25 +246,84 @@ def flush(writer: pq.ParquetWriter):
return HuggingFaceCommitMessage(additions=additions)

def commit(self, messages: List[HuggingFaceCommitMessage]) -> None: # type: ignore[override]
import math
"""
Commit the pre-uploaded Parquet files to the HuggingFace Hub, renaming them to match the expected format:
`{split}-{current:05d}-of-{total:05d}.parquet`.
Also delete or rename existing files of the split, depending on the mode.
"""

from huggingface_hub import CommitOperationCopy, CommitOperationDelete
from huggingface_hub.hf_api import RepoFile, RepoFolder

api = self.get_api()
operations = [
addition for message in messages for addition in message.additions
]
if self.overwrite: # Delete existing files if overwrite is enabled
operations.extend(self.get_delete_operations())
api = self._get_api()

additions = [addition for message in messages for addition in message.additions]
operations = {}
count_new = len(additions)
count_existing = 0

def format_path(i):
return f"{self.prefix}-{i:05d}-of-{count_new + count_existing:05d}.parquet"

def rename(old_path, new_path):
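# The Hub commit API has no rename operation, so a rename is expressed as a copy to
# the new path followed by a delete of the old path.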
if old_path != new_path:
yield CommitOperationCopy(
src_path_in_repo=old_path, path_in_repo=new_path
)
yield CommitOperationDelete(path_in_repo=old_path)

# In overwrite mode, delete existing files
if self.overwrite:
for obj in self._list_split(api):
# Delete old file
operations[obj.path] = CommitOperationDelete(
path_in_repo=obj.path, is_folder=isinstance(obj, RepoFolder)
)
# In append mode, rename existing files to have the correct total number of parts
else:
rename_operations = []
existing = list(
obj for obj in self._list_split(api) if isinstance(obj, RepoFile)
)
count_existing = len(existing)
for i, obj in enumerate(existing):
new_path = format_path(i)
rename_operations.extend(rename(obj.path, new_path))
# Rename files in a separate commit to prevent them from being overwritten by new files of the same name
self._create_commits(
api,
operations=rename_operations,
message="Rename existing files before uploading new files using PySpark",
)

# Rename additions, putting them after existing files if any
for i, addition in enumerate(additions):
addition.path_in_repo = format_path(i + count_existing)
# Overwrite the deletion operation if the file already exists
operations[addition.path_in_repo] = addition

# Upload the new files
self._create_commits(
api,
operations=list(operations.values()),
message="Upload using PySpark",
)

def _create_commits(
self, api: "HfApi", operations: List["CommitOperation"], message: str
) -> None:
"""
Split the commit into multiple parts if necessary.
The HuggingFace API may time out if there are too many operations in a single commit.
"""
import math

num_commits = math.ceil(len(operations) / self.max_operations_per_commit)
for i in range(num_commits):
begin = i * self.max_operations_per_commit
end = (i + 1) * self.max_operations_per_commit
part = operations[begin:end]
commit_message = "Upload using PySpark" + (
commit_message = message + (
f" (part {i:05d}-of-{num_commits:05d})" if num_commits > 1 else ""
)
api.create_commit(
50 changes: 27 additions & 23 deletions tests/test_huggingface_writer.py
@@ -6,7 +6,6 @@
from pyspark.testing import assertDataFrameEqual
from pytest_mock import MockerFixture


# ============== Fixtures & Helpers ==============

@pytest.fixture(scope="session")
@@ -22,8 +21,10 @@ def token():
return os.environ["HF_TOKEN"]


def reader(spark):
return spark.read.format("huggingface").option("token", token())
def load(repo, split):
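# Read the data back with the datasets library rather than the Spark reader, so the
# tests check what a downstream consumer of the Hub dataset would see.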
from datasets import load_dataset

return load_dataset(repo, token=token(), split=split).to_pandas()


def writer(df: DataFrame):
@@ -34,7 +35,7 @@ def writer(df: DataFrame):
def random_df(spark: SparkSession):
from pyspark.sql.functions import rand

return lambda n: spark.range(n).select((rand()).alias("value"))
return lambda n: spark.range(n, numPartitions=2).select((rand()).alias("value"))


@pytest.fixture(scope="session")
@@ -59,41 +60,44 @@ def repo(api, username):

# ============== Tests ==============

def test_basic(spark, repo, random_df):

def test_basic(repo, random_df):
df = random_df(10)
writer(df).mode("append").save(repo)
actual = reader(spark).load(repo)
assertDataFrameEqual(df, actual)
actual = load(repo, "train")
assertDataFrameEqual(actual, df.toPandas())


def test_append(spark, repo, random_df):
@pytest.mark.parametrize("split", ["train", "custom"])
def test_append(repo, random_df, split):
df1 = random_df(10)
df2 = random_df(10)
writer(df1).mode("append").save(repo)
writer(df2).mode("append").save(repo)
actual = reader(spark).load(repo)
writer(df1).options(split=split).mode("append").save(repo)
writer(df2).options(split=split).mode("append").save(repo)
actual = load(repo, split)
expected = df1.union(df2)
assertDataFrameEqual(actual, expected)
assertDataFrameEqual(actual, expected.toPandas())


def test_overwrite(spark, repo, random_df):
@pytest.mark.parametrize("split", ["train", "custom"])
def test_overwrite(repo, random_df, split):
df1 = random_df(10)
df2 = random_df(10)
writer(df1).mode("append").save(repo)
writer(df2).mode("overwrite").save(repo)
actual = reader(spark).load(repo)
assertDataFrameEqual(actual, df2)
writer(df1).options(split=split).mode("append").save(repo)
writer(df2).options(split=split).mode("overwrite").save(repo)
actual = load(repo, split)
assertDataFrameEqual(actual, df2.toPandas())


def test_split(spark, repo, random_df):
def test_split(repo, random_df):
df1 = random_df(10)
df2 = random_df(10)
writer(df1).mode("append").save(repo)
writer(df2).mode("append").options(split="test").save(repo)
actual1 = reader(spark).options(split="train").load(repo)
actual2 = reader(spark).options(split="test").load(repo)
assertDataFrameEqual(actual1, df1)
assertDataFrameEqual(actual2, df2)
writer(df2).mode("append").options(split="custom").save(repo)
actual1 = load(repo, "train")
actual2 = load(repo, "custom")
assertDataFrameEqual(actual1, df1.toPandas())
assertDataFrameEqual(actual2, df2.toPandas())


def test_revision(repo, random_df, api):