# Advanced Testing with pyspark.testing

Use the built-in `pyspark.testing` helpers to assert DataFrame equality, schemas, and structured transformations with minimal boilerplate.


## Prerequisites

- Spark environment capable of running PySpark notebooks.
- Shared dataset `notebooks/data/orders_demo.csv` available.
- PySpark 3.5 or later (the release that added `assertDataFrameEqual` and `assertSchemaEqual` to `pyspark.testing.utils`).
- `setuptools` installed (provides the `distutils` module required by PySpark 3.5 on Python 3.12).
- NumPy 2.0+ users should alias `np.NaN` to `np.nan` (see helper cell below).
- Optional: pytest for running automated suites.


In [None]:
# Ensure distutils is available (Python 3.12 requires setuptools to provide it)
try:
    import distutils  # type: ignore
except ModuleNotFoundError as exc:
    raise ModuleNotFoundError("The built-in pyspark.testing helpers require distutils. Install setuptools with `pip install setuptools`.") from exc


## Session and Data Setup

We reuse the orders demo dataset and create a view for convenience.


In [None]:
from pathlib import Path
from pyspark.sql import SparkSession, functions as F, types as T

spark = SparkSession.builder.appName('AdvancedPySparkTesting').getOrCreate()

repo_root = Path.cwd()
if (repo_root / 'notebooks').exists():
    data_path = repo_root / 'notebooks' / 'data' / 'orders_demo.csv'
else:
    data_path = Path('..') / 'data' / 'orders_demo.csv'

orders_df = (
    spark.read
    .option('header', True)
    .option('inferSchema', True)
    .csv(str(data_path))
    .withColumn('order_date', F.to_date('order_date'))
)
orders_df.createOrReplaceTempView('orders')
orders_df.orderBy('order_date', 'region').show()


## Transformation Under Test

Here we implement a reusable helper that aggregates daily orders per region with both sums and averages. The function will be the subject of our assertions.


In [None]:
def summarize_orders(df):
    """Aggregate total and average orders per region."""
    return (
        df.groupBy('region')
          .agg(
              F.sum('orders').alias('total_orders'),
              F.avg('orders').alias('avg_orders'),
          )
          .orderBy('region')
    )

summary_df = summarize_orders(orders_df)
summary_df.show()


In [None]:
# Ensure NumPy compatibility: NumPy 2.0 removed np.NaN alias used by pyspark.testing
import numpy as np
if not hasattr(np, 'NaN'):
    np.NaN = np.nan


## Asserting DataFrame Equality

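`pyspark.testing.utils.assertDataFrameEqual` compares both the schema and the rows of two DataFrames. By default it ignores row order (`checkRowOrder=False`) and compares floating-point values approximately (`rtol`, `atol`), so you can compare DataFrames directly without hand-written sorting or rounding. The schema check covers column names and data types; in PySpark 3.5 it does not compare nullability, so an expected schema only needs matching names and types.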


In [None]:
from pyspark.testing.utils import assertDataFrameEqual

expected_rows = [
    ('east', 30, 10.0),
    ('north', 34, 34 / 3),
    ('south', 39, 13.0),
]
expected_df = spark.createDataFrame(expected_rows, schema=summary_df.schema).orderBy('region')

assertDataFrameEqual(summary_df, expected_df)
print('DataFrame contents match expected totals and averages.')


## Comparing Schemas Explicitly

Schema comparisons prevent accidental type drift. Use `assertSchemaEqual` from `pyspark.testing.utils` to ensure fields stay aligned.


In [None]:
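from pyspark.testing.utils import assertSchemaEqual

# Hand-written expectation: expected_df reuses summary_df.schema, so comparing those two could never fail.
fields = [('region', T.StringType()), ('total_orders', T.LongType()), ('avg_orders', T.DoubleType())]
expected_schema = T.StructType([T.StructField(name, dtype) for name, dtype in fields])
assertSchemaEqual(summary_df.schema, expected_schema)
print('Schemas match the expected structure.')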


## Testing with pytest + pyspark.testing

Combine these assertions with pytest fixtures for automated suites. Example layout:

```python
# tests/conftest.py
import pytest
from pyspark.sql import SparkSession

@pytest.fixture(scope='session')
def spark():
    session = (
        SparkSession.builder
        .appName('pytest-pyspark')
        .master('local[2]')
        .getOrCreate()
    )
    yield session
    session.stop()
```

```python
# tests/test_summarize_orders.py
from pyspark.testing.utils import assertDataFrameEqual
from myproject.transforms import summarize_orders

def load_orders(spark):
    return (
        spark.read
        .option('header', True)
        .option('inferSchema', True)
        .csv('notebooks/data/orders_demo.csv')
    )

def test_summarize_orders_totals(spark):
    df = load_orders(spark)
    result = summarize_orders(df)

    expected = spark.createDataFrame([
        ('east', 30, 10.0),
        ('north', 34, 34 / 3),
        ('south', 39, 13.0),
    ], ['region', 'total_orders', 'avg_orders']).orderBy('region')

    # Row order is ignored by default (checkRowOrder=False), so no explicit sort is needed.
    assertDataFrameEqual(result, expected)
```

`pyspark.testing.utils` handles row ordering, floating-point tolerance, and schema comparison for you, keeping your tests concise.


## Snapshot Testing Wide DataFrames

For larger comparisons, serialize sorted results to JSON files checked into version control, then use `assertDataFrameEqual` against the reloaded snapshot. Update snapshots intentionally when business logic changes.


## Clean Up

Stop the SparkSession when you finish working in the notebook.


In [None]:
spark.stop()


## Exercises

- Write a new aggregation helper (for example, daily averages) and test it with `assertDataFrameEqual`.
- Use `assertSchemaEqual` to ensure a schema change is detected when you add a nullable column.
- Combine `assertDataFrameEqual` with `assertSchemaEqual` (from the same module) to verify both the structure and values of a transformation.
