diff --git a/kedro-datasets/kedro_datasets/spark/README.md b/kedro-datasets/kedro_datasets/spark/README.md
new file mode 100644
index 000000000..7400c3c47
--- /dev/null
+++ b/kedro-datasets/kedro_datasets/spark/README.md
@@ -0,0 +1,44 @@
+# Spark Streaming
+
+``SparkStreamingDataSet`` loads and saves data using Spark Structured Streaming DataFrames.
+See [Spark Structured Streaming](https://spark.apache.org/docs/latest/structured-streaming-programming-guide.html) for details.
+
+To work with multiple streaming nodes, two hooks are required:
+
+- one for integrating PySpark, see [Build a Kedro pipeline with PySpark](https://docs.kedro.org/en/stable/integrations/pyspark_integration.html) for details
+- one for keeping the streaming queries running after the last node, until they terminate or raise an exception
+
+#### Supported file formats
+
+Supported file formats are:
+
+- Text
+- CSV
+- JSON
+- ORC
+- Parquet
+
+#### Example SparkStreamsHook:
+
+```python
+from kedro.framework.hooks import hook_impl
+from pyspark.sql import SparkSession
+
+class SparkStreamsHook:
+    @hook_impl
+    def after_pipeline_run(self) -> None:
+        """Starts a Spark streaming await session
+        once the pipeline reaches the last node.
+        """
+        spark = SparkSession.builder.getOrCreate()
+        spark.streams.awaitAnyTermination()
+```
+
+To make the application work with the Kafka format, the respective Spark configuration needs to be added to ``conf/base/spark.yml``.
+
+#### Example spark.yml:
+
+```yaml
+spark.driver.maxResultSize: 3g
+spark.scheduler.mode: FAIR
+```
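+
+#### Example catalog entry and hook registration:
+
+A minimal catalog entry for this dataset could look as follows. This is only an illustrative sketch, adapted from the example in the ``SparkStreamingDataSet`` docstring; the dataset name, paths and schema file are placeholders to replace with your own.
+
+```yaml
+raw.new_inventory:
+  type: spark.SparkStreamingDataSet
+  filepath: data/01_raw/stream/inventory/
+  file_format: json
+  save_args:
+    output_mode: append
+    checkpoint: data/04_checkpoint/raw_new_inventory
+  load_args:
+    schema:
+      filepath: data/01_raw/schema/inventory_schema.json
+```
+
+Both hooks then need to be registered in the project's ``settings.py``. The sketch below assumes the PySpark hook from the linked guide is called ``SparkHooks``, that both hooks live in the project's ``hooks.py``, and that the package is named ``my_streaming_project``:
+
+```python
+# src/my_streaming_project/settings.py  -- package name is a placeholder
+from my_streaming_project.hooks import SparkHooks, SparkStreamsHook
+
+# Register the PySpark session hook and the streaming await hook.
+HOOKS = (SparkHooks(), SparkStreamsHook())
+```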
diff --git a/kedro-datasets/kedro_datasets/spark/__init__.py b/kedro-datasets/kedro_datasets/spark/__init__.py
index 3dede09aa..bd649f5c7 100644
--- a/kedro-datasets/kedro_datasets/spark/__init__.py
+++ b/kedro-datasets/kedro_datasets/spark/__init__.py
@@ -1,6 +1,12 @@
 """Provides I/O modules for Apache Spark."""
-__all__ = ["SparkDataSet", "SparkHiveDataSet", "SparkJDBCDataSet", "DeltaTableDataSet"]
+__all__ = [
+    "SparkDataSet",
+    "SparkHiveDataSet",
+    "SparkJDBCDataSet",
+    "DeltaTableDataSet",
+    "SparkStreamingDataSet",
+]
 
 from contextlib import suppress
 
@@ -12,3 +18,5 @@ from .spark_jdbc_dataset import SparkJDBCDataSet
 
 with suppress(ImportError):
     from .deltatable_dataset import DeltaTableDataSet
+with suppress(ImportError):
+    from .spark_streaming_dataset import SparkStreamingDataSet
diff --git a/kedro-datasets/kedro_datasets/spark/spark_streaming_dataset.py b/kedro-datasets/kedro_datasets/spark/spark_streaming_dataset.py
new file mode 100644
index 000000000..2f7743e65
--- /dev/null
+++ b/kedro-datasets/kedro_datasets/spark/spark_streaming_dataset.py
@@ -0,0 +1,155 @@
+"""SparkStreamingDataSet to load and save a PySpark Streaming DataFrame."""
+from copy import deepcopy
+from pathlib import PurePosixPath
+from typing import Any, Dict
+
+from kedro.io.core import AbstractDataSet
+from pyspark.sql import DataFrame, SparkSession
+from pyspark.sql.utils import AnalysisException
+
+from kedro_datasets.spark.spark_dataset import (
+    SparkDataSet,
+    _split_filepath,
+    _strip_dbfs_prefix,
+)
+
+
+class SparkStreamingDataSet(AbstractDataSet):
+    """``SparkStreamingDataSet`` loads data into Spark Streaming DataFrame objects.
+
+    Example usage for the
+    `YAML API <https://kedro.readthedocs.io/en/stable/data/data_catalog.html#use-the-data-catalog-with-the-yaml-api>`_:
+
+    .. code-block:: yaml
+
+        raw.new_inventory:
+          type: spark.SparkStreamingDataSet
+          filepath: data/01_raw/stream/inventory/
+          file_format: json
+          save_args:
+            output_mode: append
+            checkpoint: data/04_checkpoint/raw_new_inventory
+            header: True
+          load_args:
+            schema:
+              filepath: data/01_raw/schema/inventory_schema.json
+    """
+
+    DEFAULT_LOAD_ARGS = {}  # type: Dict[str, Any]
+    DEFAULT_SAVE_ARGS = {}  # type: Dict[str, Any]
+
+    def __init__(
+        self,
+        filepath: str = "",
+        file_format: str = "",
+        save_args: Dict[str, Any] = None,
+        load_args: Dict[str, Any] = None,
+    ) -> None:
+        """Creates a new instance of SparkStreamingDataSet.
+
+        Args:
+            filepath: Filepath in POSIX format to a Spark dataframe. When using Databricks
+                specify ``filepath``s starting with ``/dbfs/``. For message brokers such as
+                Kafka, ``filepath`` is not required.
+            file_format: File format used during load and save
+                operations. These are the formats supported by the running
+                SparkContext and include parquet, csv and delta. For a list of supported
+                formats please refer to the Apache Spark documentation at
+                https://spark.apache.org/docs/latest/structured-streaming-programming-guide.html
+            load_args: Load args passed to the Spark ``DataStreamReader``.
+                It is dependent on the selected file format. You can find
+                a list of read options for each supported format
+                in the Spark DataFrame read documentation:
+                https://spark.apache.org/docs/latest/structured-streaming-programming-guide.html
+                Please note that a schema is mandatory for a streaming DataFrame
+                if ``schemaInference`` is not True.
+            save_args: Save args passed to the Spark ``DataStreamWriter`` options.
+                Similar to ``load_args``, this is dependent on the selected file
+                format. You can pass ``mode`` and ``partitionBy`` to specify
+                your overwrite mode and partitioning respectively. You can find
+                a list of options for each format in the Spark DataFrame
+                write documentation:
+                https://spark.apache.org/docs/latest/structured-streaming-programming-guide.html
+        """
+        self._file_format = file_format
+
+        fs_prefix, filepath = _split_filepath(filepath)
+
+        self._fs_prefix = fs_prefix
+        self._filepath = PurePosixPath(filepath)
+
+        # Handle default load and save arguments
+        self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS)
+        if load_args is not None:
+            self._load_args.update(load_args)
+        self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS)
+        if save_args is not None:
+            self._save_args.update(save_args)
+
+        # Handle schema load argument
+        self._schema = self._load_args.pop("schema", None)
+        if self._schema is not None:
+            if isinstance(self._schema, dict):
+                self._schema = SparkDataSet._load_schema_from_file(self._schema)
+
+    def _describe(self) -> Dict[str, Any]:
+        """Returns a dict that describes attributes of the dataset."""
+        return {
+            "filepath": self._fs_prefix + str(self._filepath),
+            "file_format": self._file_format,
+            "load_args": self._load_args,
+            "save_args": self._save_args,
+        }
+
+    @staticmethod
+    def _get_spark():
+        return SparkSession.builder.getOrCreate()
+ """ + load_path = _strip_dbfs_prefix(self._fs_prefix + str(self._filepath)) + data_stream_reader = ( + self._get_spark() + .readStream.schema(self._schema) + .format(self._file_format) + .options(**self._load_args) + ) + return data_stream_reader.load(load_path) + + def _save(self, data: DataFrame) -> None: + """Saves pyspark dataframe. + Args: + data: PySpark streaming dataframe for saving + """ + save_path = _strip_dbfs_prefix(self._fs_prefix + str(self._filepath)) + output_constructor = data.writeStream.format(self._file_format) + + ( + output_constructor.option( + "checkpointLocation", self._save_args.pop("checkpoint") + ) + .option("path", save_path) + .outputMode(self._save_args.pop("output_mode")) + .options(**self._save_args) + .start() + ) + + def _exists(self) -> bool: + load_path = _strip_dbfs_prefix(self._fs_prefix + str(self._filepath)) + + try: + self._get_spark().readStream.schema(self._schema).load( + load_path, self._file_format + ) + except AnalysisException as exception: + if ( + exception.desc.startswith("Path does not exist:") + or "is not a Streaming data" in exception.desc + ): + return False + raise + return True diff --git a/kedro-datasets/setup.py b/kedro-datasets/setup.py index e69de8fa9..210eb6884 100644 --- a/kedro-datasets/setup.py +++ b/kedro-datasets/setup.py @@ -50,10 +50,15 @@ def _collect_requirements(requires): "plotly.PlotlyDataSet": [PANDAS, "plotly>=4.8.0, <6.0"], "plotly.JSONDataSet": ["plotly>=4.8.0, <6.0"], } -polars_require = {"polars.CSVDataSet": [POLARS],} +polars_require = { + "polars.CSVDataSet": [POLARS] +} redis_require = {"redis.PickleDataSet": ["redis~=4.1"]} snowflake_require = { - "snowflake.SnowparkTableDataSet": ["snowflake-snowpark-python~=1.0.0", "pyarrow~=8.0"] + "snowflake.SnowparkTableDataSet": [ + "snowflake-snowpark-python~=1.0.0", + "pyarrow~=8.0", + ] } spark_require = { "spark.SparkDataSet": [SPARK, HDFS, S3FS], @@ -71,9 +76,7 @@ def _collect_requirements(requires): "tensorflow-macos~=2.0; platform_system == 'Darwin' and platform_machine == 'arm64'", ] } -video_require = { - "video.VideoDataSet": ["opencv-python~=4.5.5.64"] -} +video_require = {"video.VideoDataSet": ["opencv-python~=4.5.5.64"]} yaml_require = {"yaml.YAMLDataSet": [PANDAS, "PyYAML>=4.2, <7.0"]} extras_require = { diff --git a/kedro-datasets/tests/spark/test_spark_streaming_dataset.py b/kedro-datasets/tests/spark/test_spark_streaming_dataset.py new file mode 100644 index 000000000..c4fb6c005 --- /dev/null +++ b/kedro-datasets/tests/spark/test_spark_streaming_dataset.py @@ -0,0 +1,178 @@ +import json + +import boto3 +import pytest +from kedro.io.core import DataSetError +from moto import mock_s3 +from pyspark.sql import SparkSession +from pyspark.sql.types import IntegerType, StringType, StructField, StructType +from pyspark.sql.utils import AnalysisException + +from kedro_datasets.spark.spark_dataset import SparkDataSet +from kedro_datasets.spark.spark_streaming_dataset import SparkStreamingDataSet + +SCHEMA_FILE_NAME = "schema.json" +BUCKET_NAME = "test_bucket" +AWS_CREDENTIALS = {"key": "FAKE_ACCESS_KEY", "secret": "FAKE_SECRET_KEY"} + + +def sample_schema(schema_path): + """read the schema file from json path""" + with open(schema_path, encoding="utf-8") as f: + try: + return StructType.fromJson(json.loads(f.read())) + except Exception as exc: + raise DataSetError( + f"Contents of 'schema.filepath' ({schema_path}) are invalid. " + f"Schema is required for streaming data load, Please provide a valid schema_path." 
diff --git a/kedro-datasets/tests/spark/test_spark_streaming_dataset.py b/kedro-datasets/tests/spark/test_spark_streaming_dataset.py
new file mode 100644
index 000000000..c4fb6c005
--- /dev/null
+++ b/kedro-datasets/tests/spark/test_spark_streaming_dataset.py
@@ -0,0 +1,178 @@
+import json
+
+import boto3
+import pytest
+from kedro.io.core import DataSetError
+from moto import mock_s3
+from pyspark.sql import SparkSession
+from pyspark.sql.types import IntegerType, StringType, StructField, StructType
+from pyspark.sql.utils import AnalysisException
+
+from kedro_datasets.spark.spark_dataset import SparkDataSet
+from kedro_datasets.spark.spark_streaming_dataset import SparkStreamingDataSet
+
+SCHEMA_FILE_NAME = "schema.json"
+BUCKET_NAME = "test_bucket"
+AWS_CREDENTIALS = {"key": "FAKE_ACCESS_KEY", "secret": "FAKE_SECRET_KEY"}
+
+
+def sample_schema(schema_path):
+    """Read the schema file from a JSON path."""
+    with open(schema_path, encoding="utf-8") as f:
+        try:
+            return StructType.fromJson(json.loads(f.read()))
+        except Exception as exc:
+            raise DataSetError(
+                f"Contents of 'schema.filepath' ({schema_path}) are invalid. "
+                "Schema is required for streaming data load, please provide a valid schema_path."
+            ) from exc
+
+
+@pytest.fixture
+def sample_spark_df_schema() -> StructType:
+    """Spark DataFrame schema."""
+    return StructType(
+        [
+            StructField("sku", StringType(), True),
+            StructField("new_stock", IntegerType(), True),
+        ]
+    )
+
+
+@pytest.fixture
+def sample_spark_streaming_df(tmp_path, sample_spark_df_schema):
+    """Create a sample dataframe for streaming."""
+    data = [("0001", 2), ("0001", 7), ("0002", 4)]
+    schema_path = (tmp_path / SCHEMA_FILE_NAME).as_posix()
+    with open(schema_path, "w", encoding="utf-8") as f:
+        json.dump(sample_spark_df_schema.jsonValue(), f)
+    return SparkSession.builder.getOrCreate().createDataFrame(
+        data, sample_spark_df_schema
+    )
+
+
+@pytest.fixture
+def mocked_s3_bucket():
+    """Create a bucket for testing using moto."""
+    with mock_s3():
+        conn = boto3.client(
+            "s3",
+            aws_access_key_id="fake_access_key",
+            aws_secret_access_key="fake_secret_key",
+        )
+        conn.create_bucket(Bucket=BUCKET_NAME)
+        yield conn
+
+
+@pytest.fixture
+def s3_bucket():
+    with mock_s3():
+        s3 = boto3.resource("s3", region_name="us-east-1")
+        bucket_name = "test-bucket"
+        s3.create_bucket(Bucket=bucket_name)
+        yield bucket_name
+
+
+@pytest.fixture
+def mocked_s3_schema(tmp_path, mocked_s3_bucket, sample_spark_df_schema: StructType):
+    """Creates schema file and adds it to mocked S3 bucket."""
+    temporary_path = tmp_path / SCHEMA_FILE_NAME
+    temporary_path.write_text(sample_spark_df_schema.json(), encoding="utf-8")
+
+    mocked_s3_bucket.put_object(
+        Bucket=BUCKET_NAME, Key=SCHEMA_FILE_NAME, Body=temporary_path.read_bytes()
+    )
+    return mocked_s3_bucket
+
+
+class TestSparkStreamingDataSet:
+    def test_load(self, tmp_path, sample_spark_streaming_df):
+        filepath = (tmp_path / "test_streams").as_posix()
+        schema_path = (tmp_path / SCHEMA_FILE_NAME).as_posix()
+
+        spark_json_ds = SparkDataSet(
+            filepath=filepath, file_format="json", save_args={"mode": "overwrite"}
+        )
+        spark_json_ds.save(sample_spark_streaming_df)
+
+        streaming_ds = SparkStreamingDataSet(
+            filepath=filepath,
+            file_format="json",
+            load_args={"schema": {"filepath": schema_path}},
+        ).load()
+        assert streaming_ds.isStreaming
+        schema = sample_schema(schema_path)
+        assert streaming_ds.schema == schema
+
+    @pytest.mark.usefixtures("mocked_s3_schema")
+    def test_load_options_schema_path_with_credentials(
+        self, tmp_path, sample_spark_streaming_df
+    ):
+        filepath = (tmp_path / "test_streams").as_posix()
+        schema_path = (tmp_path / SCHEMA_FILE_NAME).as_posix()
+
+        spark_json_ds = SparkDataSet(
+            filepath=filepath, file_format="json", save_args={"mode": "overwrite"}
+        )
+        spark_json_ds.save(sample_spark_streaming_df)
+
+        streaming_ds = SparkStreamingDataSet(
+            filepath=filepath,
+            file_format="json",
+            load_args={
+                "schema": {
+                    "filepath": f"s3://{BUCKET_NAME}/{SCHEMA_FILE_NAME}",
+                    "credentials": AWS_CREDENTIALS,
+                }
+            },
+        ).load()
+
+        assert streaming_ds.isStreaming
+        schema = sample_schema(schema_path)
+        assert streaming_ds.schema == schema
+
+    def test_save(self, tmp_path, sample_spark_streaming_df):
+        filepath_json = (tmp_path / "test_streams").as_posix()
+        filepath_output = (tmp_path / "test_streams_output").as_posix()
+        schema_path = (tmp_path / SCHEMA_FILE_NAME).as_posix()
+        checkpoint_path = (tmp_path / "checkpoint").as_posix()
+
+        # Save the sample json file to tmp_path for creating the dataframe
+        spark_json_ds = SparkDataSet(
+            filepath=filepath_json,
+            file_format="json",
+            save_args={"mode": "overwrite"},
+        )
+        spark_json_ds.save(sample_spark_streaming_df)
+
+        # Load the json file as the streaming dataframe
+        loaded_with_streaming = SparkStreamingDataSet(
+            filepath=filepath_json,
+            file_format="json",
+            load_args={"schema": {"filepath": schema_path}},
+        ).load()
+
+        # Append json streams to filepath_output with specified schema path
+        streaming_ds = SparkStreamingDataSet(
+            filepath=filepath_output,
+            file_format="json",
+            load_args={"schema": {"filepath": schema_path}},
+            save_args={"checkpoint": checkpoint_path, "output_mode": "append"},
+        )
+        assert not streaming_ds.exists()
+
+        streaming_ds.save(loaded_with_streaming)
+        assert streaming_ds.exists()
+
+    def test_exists_raises_error(self, mocker):
+        # exists should raise all errors except for
+        # AnalysisExceptions clearly indicating a missing file
+        spark_data_set = SparkStreamingDataSet(filepath="")
+        mocker.patch.object(
+            spark_data_set,
+            "_get_spark",
+            side_effect=AnalysisException("Other Exception", []),
+        )
+
+        with pytest.raises(DataSetError, match="Other Exception"):
+            spark_data_set.exists()