Added spark connect implementation (#185)
canimus committed Mar 22, 2024
1 parent eb32c98 commit 18a2e9c
Showing 6 changed files with 50 additions and 7 deletions.
12 changes: 11 additions & 1 deletion README.md
@@ -32,7 +32,7 @@ When benchmarking against pydeequ, `cuallee` uses circa <3k java classes underne
Provider | API | Versions
------- | ----------- | ------
![snowflake](logos/snowflake.svg?raw=true "Snowpark DataFrame API")| `snowpark` | `1.11.1`, `1.4.0`
![databricks](logos/databricks.svg?raw=true "PySpark DataFrame API")| `pyspark` | `3.4.0`, `3.3.x`, `3.2.x`
![databricks](logos/databricks.svg?raw=true "PySpark DataFrame API")| `pyspark` & `spark-connect` |`3.5.x`, `3.4.0`, `3.3.x`, `3.2.x`
![bigquery](logos/bigquery.png?raw=true "BigQuery Client API")| `bigquery` | `3.4.1`
![pandas](logos/pandas.svg?raw=true "Pandas DataFrame API")| `pandas`| `2.0.2`, `1.5.x`, `1.4.x`
![duckdb](logos/duckdb.png?raw=true "DuckDB API")|`duckdb` | `0.9.2`,~~`0.8.0`~~, ~~`0.7.1`~~
@@ -284,6 +284,14 @@ In order to establish a connection to your SnowFlake account `cuallee` relies in
- `SF_DATABASE`
- `SF_SCHEMA`

## Spark Connect
Set the `SPARK_REMOTE` environment variable to point at your remote session, and `cuallee` will connect using
```python
spark_connect = SparkSession.builder.remote(os.getenv("SPARK_REMOTE")).getOrCreate()
```
and will compute all checks through `select` statements instead of the `Observation` API.
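
For context, a minimal end-to-end sketch of that flow (the endpoint URL and the example check are illustrative assumptions, not part of this change):
```python
# Illustrative sketch: assumes a Spark Connect server is reachable at the
# given URL and that cuallee plus pyspark[connect] are installed.
import os
os.environ["SPARK_REMOTE"] = "sc://localhost:15002"  # hypothetical endpoint

from pyspark.sql import SparkSession
from cuallee import Check, CheckLevel

spark = SparkSession.builder.remote(os.getenv("SPARK_REMOTE")).getOrCreate()
df = spark.range(10)  # toy dataframe with a single `id` column

# Checks are defined exactly as with a classic SparkSession; with a connect
# session they are computed via `select` rather than the Observation API.
check = Check(CheckLevel.WARNING, "spark_connect_demo")
check.is_complete("id")
check.validate(df).show()
```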


## Databricks Connection
By default `cuallee` will search for a SparkSession available in the `globals` so there is literally no need to ~~`SparkSession.builder`~~. When working in a local environment it will automatically search for an available session, or start one.
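
A short sketch of what that looks like in practice (check name and dataframe are illustrative; `spark` is assumed to be provided by the runtime, as on Databricks or in a notebook):
```python
# No session wiring needed: cuallee picks up the SparkSession from globals.
from cuallee import Check, CheckLevel

df = spark.range(5)  # `spark` already exists in the environment
Check(CheckLevel.WARNING, "demo").is_complete("id").validate(df).show()
```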

@@ -308,6 +316,7 @@ check.validate(conn)

`100%` data frame agnostic implementation of data quality checks.
Define once, `run everywhere`
- ~~[x] PySpark 3.5.0~~
- ~~[x] PySpark 3.4.0~~
- ~~[x] PySpark 3.3.0~~
- ~~[x] PySpark 3.2.x~~
@@ -317,6 +326,7 @@ Define once, `run everywhere`
- ~~[x] BigQuery Client~~
- ~~[x] Polars DataFrame~~
- ~~[*] Dagster Integration~~
- ~~[x] Spark Connect~~
- [-] PDF Report
- [ ] Metadata check
- [ ] Help us in a discussion?
12 changes: 11 additions & 1 deletion cuallee/__init__.py
@@ -12,7 +12,7 @@
from toolz.curried import map as map_curried

logger = logging.getLogger("cuallee")
__version__ = "0.9.0"
__version__ = "0.9.1"
# Verify Libraries Available
# ==========================
try:
@@ -30,6 +30,11 @@
except (ModuleNotFoundError, ImportError):
logger.debug("KO: PySpark")

try:
from pyspark.sql.connect.dataframe import DataFrame as pyspark_connect_dataframe
except (ModuleNotFoundError, ImportError):
logger.debug("KO: PySpark Connect")

try:
from snowflake.snowpark import DataFrame as snowpark_dataframe # type: ignore
except (ModuleNotFoundError, ImportError):
@@ -674,6 +679,11 @@ def validate(self, dataframe: Any):
):
self.compute_engine = importlib.import_module("cuallee.pyspark_validation")

elif "pyspark_connect_dataframe" in globals() and isinstance(
dataframe, pyspark_connect_dataframe
):
self.compute_engine = importlib.import_module("cuallee.pyspark_validation")

# When dataframe is Pandas DataFrame API
elif "pandas_dataframe" in globals() and isinstance(
dataframe, pandas_dataframe
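
For readers less familiar with this dispatch style: the optional import at the top of `cuallee/__init__.py` either binds `pyspark_connect_dataframe` in the module globals or is silently skipped, and `validate` only attempts the `isinstance` check when the name is present. A stripped-down sketch of the same pattern (the function name and fallback error are illustrative, not cuallee's API):
```python
import importlib
import logging

logger = logging.getLogger("example")

try:
    # Optional dependency: the name is bound only when pyspark[connect] is installed.
    from pyspark.sql.connect.dataframe import DataFrame as pyspark_connect_dataframe
except (ModuleNotFoundError, ImportError):
    logger.debug("KO: PySpark Connect")


def pick_engine(dataframe):
    """Route a dataframe to its compute engine (illustrative only)."""
    if "pyspark_connect_dataframe" in globals() and isinstance(
        dataframe, pyspark_connect_dataframe
    ):
        # Spark Connect frames reuse the regular PySpark validation engine.
        return importlib.import_module("cuallee.pyspark_validation")
    raise NotImplementedError("No compute engine available for this dataframe type")
```
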
22 changes: 20 additions & 2 deletions cuallee/pyspark_validation.py
@@ -15,6 +15,18 @@

import os

try:
from pyspark.sql.connect.dataframe import DataFrame as SparkConnectDataFrame
from pyspark.sql.connect.session import SparkSession as SparkConnectSession

global spark_connect
if "SPARK_REMOTE" in os.environ:
spark_connect = SparkConnectSession.builder.remote(
os.getenv("SPARK_REMOTE")
).getOrCreate()
except (ModuleNotFoundError, ImportError):
pass


class ComputeMethod(enum.Enum):
OBSERVE = "OBSERVE"
@@ -731,8 +743,12 @@ def summary(check: Check, dataframe: DataFrame) -> DataFrame:
"""Compute all rules in this check for specific data frame"""
from pyspark.sql.session import SparkSession

if "spark_connect" in globals():
spark = globals()["spark_connect"]
# Check SparkSession is available in environment through globals
if spark_in_session := valfilter(lambda x: isinstance(x, SparkSession), globals()):
elif spark_in_session := valfilter(
lambda x: isinstance(x, SparkSession), globals()
):
# Obtain the first spark session available in the globals
spark = first(spark_in_session.values())
else:
@@ -741,7 +757,9 @@ def summary(check: Check, dataframe: DataFrame) -> DataFrame:

# Compute the expression
computed_expressions = compute(check._rule)
if int(spark.version.replace(".", "")[:3]) < 330:
if (int(spark.version.replace(".", "")[:3]) < 330) or (
"connect" in str(type(spark))
):
computed_expressions = _replace_observe_compute(computed_expressions)

rows, observation_result = _compute_observe_method(computed_expressions, dataframe)
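
The gate above switches from the `Observation` API to plain `select` computation in two cases: Spark versions before 3.3.0, where `Observation` is not available, and Spark Connect sessions, whose type name contains `connect`. A small helper restating that condition (illustrative only, not part of the commit):
```python
def needs_select_fallback(spark) -> bool:
    """True when checks must be computed via `select` instead of Observation."""
    # "3.2.4" -> "324" -> 324 < 330: Observation only exists from PySpark 3.3.0.
    too_old = int(spark.version.replace(".", "")[:3]) < 330
    # A connect session reports a type like pyspark.sql.connect.session.SparkSession.
    is_connect = "connect" in str(type(spark))
    return too_old or is_connect
```
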
5 changes: 4 additions & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "cuallee"
version = "0.9.0"
version = "0.9.1"
authors = [
{ name="Herminio Vazquez", email="canimus@gmail.com"},
{ name="Virginie Grosboillot", email="vestalisvirginis@gmail.com" }
@@ -30,6 +30,9 @@ dev = [
pyspark = [
"pyspark>=3.4.0"
]
pyspark_connect = [
"pyspark[connect]"
]
snowpark = [
"snowflake-snowpark-python==1.11.1",
"pyarrow >= 14.0.2"
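
With this optional dependency group in place, the Spark Connect flavour can presumably be installed as an extra, e.g. `pip install "cuallee[pyspark_connect]"` (extra name taken from the block above).
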
2 changes: 1 addition & 1 deletion setup.cfg
@@ -1,5 +1,5 @@
[metadata]
name = cuallee
version = 0.9.0
version = 0.9.1
[options]
packages = find:
4 changes: 3 additions & 1 deletion test/unit/class_control/test_methods.py
@@ -9,6 +9,7 @@ def test_has_completeness():
def test_has_information():
assert hasattr(Control, "information")


def test_has_information():
assert hasattr(Control, "intelligence")

@@ -27,6 +28,7 @@ def test_emptyness(spark):
df = spark.range(10)
assert Control.percentage_empty(df) == 0


def test_intelligence_result(spark):
df = spark.createDataFrame([("0",),("1",),("2",)], schema="id string")
df = spark.createDataFrame([("0",), ("1",), ("2",)], schema="id string")
assert Control.intelligence(df) == ["id"]
