From 18a2e9c9e68b831045d7e5ca8fb3b65df7d48ac0 Mon Sep 17 00:00:00 2001
From: Herminio Vazquez
Date: Fri, 22 Mar 2024 21:38:25 +0100
Subject: [PATCH] Added spark connect implementation (#185)

---
 README.md                               | 12 +++++++++++-
 cuallee/__init__.py                     | 12 +++++++++++-
 cuallee/pyspark_validation.py           | 22 ++++++++++++++++++++--
 pyproject.toml                          |  5 ++++-
 setup.cfg                               |  2 +-
 test/unit/class_control/test_methods.py |  4 +++-
 6 files changed, 50 insertions(+), 7 deletions(-)

diff --git a/README.md b/README.md
index 59fd21b1..164fbcd3 100644
--- a/README.md
+++ b/README.md
@@ -32,7 +32,7 @@ When benchmarking against pydeequ, `cuallee` uses circa <3k java classes underne
 Provider | API | Versions
 ------- | ----------- | ------
 ![snowflake](logos/snowflake.svg?raw=true "Snowpark DataFrame API")| `snowpark` | `1.11.1`, `1.4.0`
-![databricks](logos/databricks.svg?raw=true "PySpark DataFrame API")| `pyspark` | `3.4.0`, `3.3.x`, `3.2.x`
+![databricks](logos/databricks.svg?raw=true "PySpark DataFrame API")| `pyspark` & `spark-connect` | `3.5.x`, `3.4.0`, `3.3.x`, `3.2.x`
 ![bigquery](logos/bigquery.png?raw=true "BigQuery Client API")| `bigquery` | `3.4.1`
 ![pandas](logos/pandas.svg?raw=true "Pandas DataFrame API")| `pandas`| `2.0.2`, `1.5.x`, `1.4.x`
 ![duckdb](logos/duckdb.png?raw=true "DuckDB API")|`duckdb` | `0.9.2`,~~`0.8.0`~~, ~~`0.7.1`~~
@@ -284,6 +284,14 @@ In order to establish a connection to your SnowFlake account `cuallee` relies in
 - `SF_DATABASE`
 - `SF_SCHEMA`
 
+## Spark Connect
+Just set the environment variable `SPARK_REMOTE` to point at your remote session; `cuallee` will then connect using
+```python
+spark_connect = SparkSession.builder.remote(os.getenv("SPARK_REMOTE")).getOrCreate()
+```
+and convert all checks to `select` statements instead of `Observation` API compute instructions.
+
+
 ## Databricks Connection
 By default `cuallee` will search for a SparkSession available in the `globals` so there is literally no need to ~~`SparkSession.builder`~~. When working in a local environment it will automatically search for an available session, or start one.
 
@@ -308,6 +316,7 @@ check.validate(conn)
 `100%` data frame agnostic implementation of data quality checks.
 Define once, `run everywhere`
 
+- ~~[x] PySpark 3.5.0~~
 - ~~[x] PySpark 3.4.0~~
 - ~~[x] PySpark 3.3.0~~
 - ~~[x] PySpark 3.2.x~~
@@ -317,6 +326,7 @@ Define once, `run everywhere`
 - ~~[x] BigQuery Client~~
 - ~~[x] Polars DataFrame~~
 - ~~[*] Dagster Integration~~
+- ~~[x] Spark Connect~~
 - [-] PDF Report
 - [ ] Metadata check
 - [ ] Help us in a discussion?
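As a usage sketch of the Spark Connect section added to the README above: the snippet below assumes `cuallee` is installed with the new optional extra from this patch (`pip install "cuallee[pyspark_connect]"`) and that a Spark Connect server is reachable at the address stored in `SPARK_REMOTE`. The `sc://localhost:15002` address, the toy data, and the rule choices are illustrative only, not part of the patch.

```python
import os

from pyspark.sql import SparkSession

from cuallee import Check, CheckLevel

# Illustrative address; cuallee reads SPARK_REMOTE itself when building its own session.
os.environ.setdefault("SPARK_REMOTE", "sc://localhost:15002")

# Same call the README documents: a Spark Connect session bound to SPARK_REMOTE.
spark = SparkSession.builder.remote(os.getenv("SPARK_REMOTE")).getOrCreate()

# Toy data frame; schema and values are made up for the example.
df = spark.createDataFrame([(1, 10.0), (2, 25.5), (3, 7.2)], schema="id int, fare double")

# Define rules once; on a Connect session they are computed via select(),
# not the Observation API, as the README note explains.
check = Check(CheckLevel.WARNING, "SparkConnectDemo")
check.is_complete("id").is_greater_than("fare", 0)

check.validate(df).show()
```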
diff --git a/cuallee/__init__.py b/cuallee/__init__.py
index 37424f7f..20bac255 100644
--- a/cuallee/__init__.py
+++ b/cuallee/__init__.py
@@ -12,7 +12,7 @@
 from toolz.curried import map as map_curried
 
 logger = logging.getLogger("cuallee")
-__version__ = "0.9.0"
+__version__ = "0.9.1"
 # Verify Libraries Available
 # ==========================
 try:
@@ -30,6 +30,11 @@
 except (ModuleNotFoundError, ImportError):
     logger.debug("KO: PySpark")
 
+try:
+    from pyspark.sql.connect.dataframe import DataFrame as pyspark_connect_dataframe
+except (ModuleNotFoundError, ImportError):
+    logger.debug("KO: PySpark Connect")
+
 try:
     from snowflake.snowpark import DataFrame as snowpark_dataframe  # type: ignore
 except (ModuleNotFoundError, ImportError):
@@ -674,6 +679,11 @@ def validate(self, dataframe: Any):
         ):
             self.compute_engine = importlib.import_module("cuallee.pyspark_validation")
 
+        elif "pyspark_connect_dataframe" in globals() and isinstance(
+            dataframe, pyspark_connect_dataframe
+        ):
+            self.compute_engine = importlib.import_module("cuallee.pyspark_validation")
+
         # When dataframe is Pandas DataFrame API
         elif "pandas_dataframe" in globals() and isinstance(
             dataframe, pandas_dataframe
diff --git a/cuallee/pyspark_validation.py b/cuallee/pyspark_validation.py
index 763a82e2..3711a816 100644
--- a/cuallee/pyspark_validation.py
+++ b/cuallee/pyspark_validation.py
@@ -15,6 +15,18 @@
 import os
 
+try:
+    from pyspark.sql.connect.dataframe import DataFrame as SparkConnectDataFrame
+    from pyspark.sql.connect.session import SparkSession as SparkConnectSession
+
+    global spark_connect
+    if "SPARK_REMOTE" in os.environ:
+        spark_connect = SparkConnectSession.builder.remote(
+            os.getenv("SPARK_REMOTE")
+        ).getOrCreate()
+except (ModuleNotFoundError, ImportError):
+    pass
+
 
 
 class ComputeMethod(enum.Enum):
     OBSERVE = "OBSERVE"
@@ -731,8 +743,12 @@ def summary(check: Check, dataframe: DataFrame) -> DataFrame:
     """Compute all rules in this check for specific data frame"""
     from pyspark.sql.session import SparkSession
 
+    if "spark_connect" in globals():
+        spark = globals()["spark_connect"]
     # Check SparkSession is available in environment through globals
-    if spark_in_session := valfilter(lambda x: isinstance(x, SparkSession), globals()):
+    elif spark_in_session := valfilter(
+        lambda x: isinstance(x, SparkSession), globals()
+    ):
         # Obtain the first spark session available in the globals
         spark = first(spark_in_session.values())
     else:
@@ -741,7 +757,9 @@
 
     # Compute the expression
     computed_expressions = compute(check._rule)
-    if int(spark.version.replace(".", "")[:3]) < 330:
+    if (int(spark.version.replace(".", "")[:3]) < 330) or (
+        "connect" in str(type(spark))
+    ):
         computed_expressions = _replace_observe_compute(computed_expressions)
 
     rows, observation_result = _compute_observe_method(computed_expressions, dataframe)
diff --git a/pyproject.toml b/pyproject.toml
index ce33d799..05c3a612 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "cuallee"
-version = "0.9.0"
+version = "0.9.1"
 authors = [
     { name="Herminio Vazquez", email="canimus@gmail.com"},
     { name="Virginie Grosboillot", email="vestalisvirginis@gmail.com" }
@@ -30,6 +30,9 @@ dev = [
 pyspark = [
     "pyspark>=3.4.0"
 ]
+pyspark_connect = [
+    "pyspark[connect]"
+]
 snowpark = [
     "snowflake-snowpark-python==1.11.1",
     "pyarrow >= 14.0.2"
diff --git a/setup.cfg b/setup.cfg
index 50850697..f96cfde6 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,5 +1,5 @@
 [metadata]
 name = cuallee
-version = 0.9.0
+version = 0.9.1
 [options]
 packages = find:
\ No newline at end of file
diff --git a/test/unit/class_control/test_methods.py b/test/unit/class_control/test_methods.py
index b1252830..c59a2010 100644
--- a/test/unit/class_control/test_methods.py
+++ b/test/unit/class_control/test_methods.py
@@ -9,6 +9,7 @@ def test_has_completeness():
 def test_has_information():
     assert hasattr(Control, "information")
 
+
 def test_has_information():
     assert hasattr(Control, "intelligence")
 
@@ -27,6 +28,7 @@ def test_emptyness(spark):
     df = spark.range(10)
     assert Control.percentage_empty(df) == 0
 
+
 def test_intelligence_result(spark):
-    df = spark.createDataFrame([("0",),("1",),("2",)], schema="id string")
+    df = spark.createDataFrame([("0",), ("1",), ("2",)], schema="id string")
     assert Control.intelligence(df) == ["id"]
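A note on the guard added to `summary()` in `cuallee/pyspark_validation.py` above: the Observation-based compute path is kept only for a classic `SparkSession` on Spark 3.3.0 or newer; older versions and Spark Connect sessions are downgraded to plain `select` aggregations via `_replace_observe_compute`. The helper below is not part of the patch, just a small sketch that mirrors that decision so it can be inspected in isolation; the function name is made up.

```python
from pyspark.sql import SparkSession


def uses_observation_api(spark) -> bool:
    """Mirror the patch's guard: keep the Observation API only on a classic
    SparkSession of version 3.3.0 or newer; anything Connect-flavoured or
    older falls back to select()-based computation."""
    too_old = int(spark.version.replace(".", "")[:3]) < 330  # e.g. "3.5.1" -> 351
    is_connect = "connect" in str(type(spark))  # Connect sessions live in pyspark.sql.connect
    return not (too_old or is_connect)


if __name__ == "__main__":
    spark = SparkSession.builder.getOrCreate()
    print(uses_observation_api(spark))  # True on a local classic Spark >= 3.3.0 session
```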