diff --git a/.gitignore b/.gitignore
index 82f9275..7b6caf3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -159,4 +159,4 @@ cython_debug/
 # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
-#.idea/
+.idea/
diff --git a/demo.ipynb b/demo.ipynb
new file mode 100644
index 0000000..48234e3
--- /dev/null
+++ b/demo.ipynb
@@ -0,0 +1,297 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "19b1960e-9e0a-401f-be15-d343902eaa21",
+   "metadata": {},
+   "source": [
+    "# Spark HuggingFace Connector Demo"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "c9a7bf1d-c208-4873-9e06-5db981f8eeaa",
+   "metadata": {},
+   "source": [
+    "## Create a Spark Session"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "id": "620d3ecb-b9cb-480c-b300-69198cce7a9c",
+   "metadata": {},
+   "source": [
+    "from pyspark.sql import SparkSession\n",
+    "\n",
+    "spark = SparkSession.builder.getOrCreate()"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "markdown",
+   "id": "6f876028-2af5-4e63-8e9d-59afc0959267",
+   "metadata": {},
+   "source": [
+    "## Load a dataset as a Spark DataFrame"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "b8580bde-3f64-4c71-a087-8b3f71099aee",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2024-11-26T08:54:32.132099Z",
+     "start_time": "2024-11-26T08:54:28.903653Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "df = spark.read.format(\"huggingface\").load(\"rotten_tomatoes\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "3bbf61d1-4c2c-40e7-9790-2722637aac9d",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "root\n",
+      " |-- text: string (nullable = true)\n",
+      " |-- label: long (nullable = true)\n",
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "df.printSchema()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "id": "7f7b9a2b-8733-499a-af56-3c51196d060f",
+   "metadata": {},
+   "source": [
+    "# Cache the dataframe to avoid re-downloading data\n",
+    "df.cache()"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "id": "df121dba-2e1e-4206-b2bf-db156c298ee1",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "8530"
+      ]
+     },
+     "execution_count": 12,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# Trigger the cache computation\n",
+    "df.count()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "id": "8866bdfb-0782-4430-8b1e-09c65e699f41",
+   "metadata": {
+    "editable": true,
+    "slideshow": {
+     "slide_type": ""
+    },
+    "tags": []
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "Row(text='the rock is destined to be the 21st century\\'s new \" conan \" and that he\\'s going to make a splash even greater than arnold schwarzenegger , jean-claud van damme or steven segal .', label=1)"
+      ]
+     },
+     "execution_count": 13,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "df.head()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "id": "0d9d3112-d19b-4fa8-a6fc-ba40816d1d11",
+   "metadata": {},
+   "source": [
+    "df.show(n=5)"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 21,
+   "id": "225bbbef-4164-424d-a701-c6c74494ef81",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "4265"
+      ]
+     },
+     "execution_count": 21,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# Then you can operate on this dataframe\n",
+    "df.filter(df.label == 0).count()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "3932f1fd-a324-4f15-86e1-bbe1064d707a",
+   "metadata": {},
+   "source": [
+    "## Load a different split\n",
+    "You can specify the `split` data source option:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "id": "a16e9270-eb02-4568-8739-db4dc715c274",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "test_df = (\n",
+    "    spark.read.format(\"huggingface\")\n",
+    "    .option(\"split\", \"test\")\n",
+    "    .load(\"rotten_tomatoes\")\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "id": "3aec5719-c3a1-4d18-92c8-2b0c2f4bb939",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "DataFrame[text: string, label: bigint]"
+      ]
+     },
+     "execution_count": 15,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "test_df.cache()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "id": "d605289d-361d-4a6c-9b70-f7ccdff3aa9d",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      " "
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "1066"
+      ]
+     },
+     "execution_count": 16,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "test_df.count()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 18,
+   "id": "df1ad003-1476-4557-811b-31c3888c0030",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "+--------------------+-----+\n",
+      "|                text|label|\n",
+      "+--------------------+-----+\n",
+      "|lovingly photogra...|    1|\n",
+      "|consistently clev...|    1|\n",
+      "|it's like a \" big...|    1|\n",
+      "|the story gives a...|    1|\n",
+      "|red dragon \" neve...|    1|\n",
+      "+--------------------+-----+\n",
+      "only showing top 5 rows\n",
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "test_df.show(n=5)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "a7f14b91-059e-4894-83d2-4ed74e0adaf9",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "pyspark_huggingface",
+   "language": "python",
+   "name": "pyspark_huggingface"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.10"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/pyspark_huggingface/__init__.py b/pyspark_huggingface/__init__.py
new file mode 100644
index 0000000..ac09d17
--- /dev/null
+++ b/pyspark_huggingface/__init__.py
@@ -0,0 +1 @@
+from pyspark_huggingface.huggingface import HuggingFaceDatasets as DefaultSource
diff --git a/pyspark_huggingface/huggingface.py b/pyspark_huggingface/huggingface.py
new file mode 100644
index 0000000..42680c8
--- /dev/null
+++ b/pyspark_huggingface/huggingface.py
@@ -0,0 +1,96 @@
+from pyspark.sql.datasource import DataSource, DataSourceReader
+from pyspark.sql.pandas.types import from_arrow_schema
+from pyspark.sql.types import StructType
+
+class HuggingFaceDatasets(DataSource):
+    """
+    A DataSource for reading HuggingFace Datasets in Spark.
+
+    This data source allows reading public datasets from the HuggingFace Hub directly into Spark
+    DataFrames. The schema is automatically inferred from the dataset features. The split can be
+    specified using the `split` option. The default split is `train`.
+
+    Name: `huggingface`
+
+    Notes
+    -----
+    - The HuggingFace `datasets` library is required to use this data source. Make sure it is installed.
+    - If the schema is automatically inferred, it will use string type for all fields.
+    - Currently it can only be used with public datasets. Private or gated ones are not supported.
+
+    Examples
+    --------
+
+    Load a public dataset from the HuggingFace Hub.
+
+    >>> df = spark.read.format("huggingface").load("imdb")
+    >>> df
+    DataFrame[text: string, label: bigint]
+
+    >>> df.show()
+    +--------------------+-----+
+    |                text|label|
+    +--------------------+-----+
+    |I rented I AM CUR...|    0|
+    |"I Am Curious: Ye...|    0|
+    |                 ...|  ...|
+    +--------------------+-----+
+
+    Load a specific split of a public dataset from the HuggingFace Hub.
+
+    >>> spark.read.format("huggingface").option("split", "test").load("imdb").show()
+    +--------------------+-----+
+    |                text|label|
+    +--------------------+-----+
+    |I love sci-fi and...|    0|
+    |Worth the enterta...|    0|
+    |                 ...|  ...|
+    +--------------------+-----+
+    """
+
+    def __init__(self, options):
+        super().__init__(options)
+        if "path" not in options or not options["path"]:
+            raise Exception("You must specify a dataset name.")
+
+    @classmethod
+    def name(cls):
+        return "huggingface"
+
+    def schema(self):
+        from datasets import load_dataset_builder
+        dataset_name = self.options["path"]
+        ds_builder = load_dataset_builder(dataset_name)
+        features = ds_builder.info.features
+        if features is None:
+            raise Exception(
+                "Unable to automatically determine the schema using the dataset features. "
+                "Please specify the schema manually using `.schema()`."
+            )
+        return from_arrow_schema(features.arrow_schema)
+
+    def reader(self, schema: StructType) -> "DataSourceReader":
+        return HuggingFaceDatasetsReader(schema, self.options)
+
+
+class HuggingFaceDatasetsReader(DataSourceReader):
+    DEFAULT_SPLIT: str = "train"
+
+    def __init__(self, schema: StructType, options: dict):
+        self.schema = schema
+        self.dataset_name = options["path"]
+        # Get and validate the split name
+        self.split = options.get("split", self.DEFAULT_SPLIT)
+        from datasets import get_dataset_split_names
+        valid_splits = get_dataset_split_names(self.dataset_name)
+        if self.split not in valid_splits:
+            raise Exception(f"Split {self.split} is invalid. Valid options are {valid_splits}")
+
+    def read(self, partition):
+        from datasets import load_dataset
+        columns = [field.name for field in self.schema.fields]
+        # TODO: add config
+        iter_dataset = load_dataset(self.dataset_name, split=self.split, streaming=True)
+        for data in iter_dataset:
+            # TODO: the next Spark 4.0.0 dev release will support yielding an iterator of pa.RecordBatch
+            yield tuple(data.get(column) for column in columns)
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..bd26d90
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,3 @@
+datasets==3.1.0
+pyspark[connect]==4.0.0.dev2
+pytest
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/test_huggingface.py b/tests/test_huggingface.py
new file mode 100644
index 0000000..8c6b9b8
--- /dev/null
+++ b/tests/test_huggingface.py
@@ -0,0 +1,16 @@
+import pytest
+from pyspark.sql import SparkSession
+from pyspark_huggingface.huggingface import HuggingFaceDatasets
+
+
+@pytest.fixture
+def spark():
+    spark = SparkSession.builder.getOrCreate()
+    yield spark
+    spark.stop()
+
+
+def test_basic_load(spark):
+    spark.dataSource.register(HuggingFaceDatasets)
+    df = spark.read.format("huggingface").load("rotten_tomatoes")
+    assert df.count() == 8530  # length of the training dataset
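
For reference, a minimal end-to-end usage sketch that mirrors tests/test_huggingface.py and the demo notebook above; it assumes the dependencies pinned in requirements.txt are installed and registers the data source explicitly rather than relying on package discovery:

from pyspark.sql import SparkSession

from pyspark_huggingface.huggingface import HuggingFaceDatasets

spark = SparkSession.builder.getOrCreate()

# Register the Python data source class under its short name, "huggingface",
# exactly as the test does.
spark.dataSource.register(HuggingFaceDatasets)

# Read the "test" split of a public dataset; the schema is inferred from the
# dataset features by HuggingFaceDatasets.schema().
test_df = (
    spark.read.format("huggingface")
    .option("split", "test")
    .load("rotten_tomatoes")
)
test_df.show(5)

spark.stop()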