# Testing PySpark Workflows

Learn strategies for validating PySpark logic with assertions and pytest-style tests while working with the shared demo dataset.


## Why Test PySpark Code?

- Catch logical regressions in complex transformations early.
- Document expected data contracts for collaborators.
- Build confidence before shipping notebooks into production pipelines.


## Load the Shared Dataset and SparkSession

We reuse `notebooks/data/orders_demo.csv` to keep examples consistent with other tutorials.


In [None]:
from pathlib import Path
from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.appName('TestingPySparkTutorial').getOrCreate()

repo_root = Path.cwd()
if (repo_root / 'notebooks').exists():
    data_path = repo_root / 'notebooks' / 'data' / 'orders_demo.csv'
else:
    data_path = Path('..') / 'data' / 'orders_demo.csv'

orders_df = (
    spark.read
    .option('header', True)
    .option('inferSchema', True)
    .csv(str(data_path))
    .withColumn('order_date', F.to_date('order_date'))
)
orders_df.show()


## Function Under Test

Wrap the transformation in a helper function so its output can be asserted on directly. `tag_demand_level` labels each row `high` when `orders` is at or above the threshold and `steady` otherwise.


In [None]:
def tag_demand_level(df, threshold=14):
    """Classify orders by demand level using the provided threshold."""
    return df.withColumn(
        'demand_level',
        F.when(F.col('orders') >= threshold, 'high').otherwise('steady')
    )

tagged_df = tag_demand_level(orders_df)
tagged_df.orderBy('order_date', 'region').show()


## Inline Assertions Inside Notebooks

For quick experiments, use Python's built-in `assert` statements to validate expectations directly in the notebook.


In [None]:
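# Ensure every region has at least one row labelled steady
# (compare against the regions actually present instead of a hard-coded count)
all_regions = {row.region for row in tagged_df.select('region').distinct().collect()}
steady_regions = {
    row.region
    for row in (
        tagged_df
          .filter(F.col('demand_level') == 'steady')
          .select('region')
          .distinct()
          .collect()
    )
}
missing = all_regions - steady_regions
assert not missing, f'Regions without steady rows: {missing}'
print('All regions include steady demand rows.')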


## Comparing DataFrames Deterministically

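Spark does not guarantee the order of collected rows, so sort with `orderBy` on a key that is unique within the compared rows, collect the result as tuples, and compare it against a list of expected tuples. This keeps comparisons deterministic and needs no helper library, which makes it handy in unit tests.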


In [None]:
expected = [
    ('2024-01-01', 'north', 'steady'),
    ('2024-01-02', 'north', 'high'),
    ('2024-01-03', 'north', 'steady'),
]

actual = [
    tuple(row)
    for row in (
        tagged_df
          .filter(F.col('region') == 'north')
          .orderBy('order_date')
          # Format dates in Spark so the tuples hold plain strings; this also avoids
          # the RDD API, which Spark Connect does not support.
          .select(F.date_format('order_date', 'yyyy-MM-dd').alias('order_date'), 'region', 'demand_level')
          .collect()
    )
]
assert actual == expected, f"Unexpected demand levels: {actual}"
print('North region demand levels match the expectation.')

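Sorting works when a unique sort key exists. When only the set of rows matters, comparing the two collections as multisets avoids picking a sort key at all and still catches missing or duplicated rows. The sketch below uses `collections.Counter` on plain tuples; in the notebook the input would be `[tuple(row) for row in df.collect()]`, since PySpark `Row` objects are tuples. The helper name `row_differences` is hypothetical, and the sample rows simply mirror the expected list above.

```python
from collections import Counter


def row_differences(actual_rows, expected_rows):
    """Return (missing, unexpected) rows, treating both inputs as multisets."""
    actual_counts = Counter(map(tuple, actual_rows))
    expected_counts = Counter(map(tuple, expected_rows))
    # Counter subtraction keeps only positive counts, i.e. the surplus on each side.
    missing = dict(expected_counts - actual_counts)
    unexpected = dict(actual_counts - expected_counts)
    return missing, unexpected


expected = [
    ('2024-01-01', 'north', 'steady'),
    ('2024-01-02', 'north', 'high'),
]
# Same rows in a different order, as Spark might return them without orderBy.
actual = [
    ('2024-01-02', 'north', 'high'),
    ('2024-01-01', 'north', 'steady'),
]
missing, unexpected = row_differences(actual, expected)
assert not missing and not unexpected, f'Missing: {missing}; unexpected: {unexpected}'
```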

## Structuring Tests with pytest

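To automate these checks, move the assertions into pytest test modules. A session-scoped fixture in `conftest.py` creates one SparkSession that every test reuses and stops it when the run finishes; each test then calls the transformation helper it covers. Example structure: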

```python
# conftest.py
import pytest
from pyspark.sql import SparkSession

@pytest.fixture(scope='session')
def spark():
    session = (
        SparkSession.builder
        .appName('pytest-pyspark')
        .master('local[2]')
        .getOrCreate()
    )
    yield session
    session.stop()
```

```python
# tests/test_tag_demand.py
from pathlib import Path
from pyspark.sql import functions as F
from myproject.transforms import tag_demand_level  # adjust to wherever your helper lives

# Resolve the CSV relative to this file so the test passes from any working directory.
DATA_PATH = Path(__file__).resolve().parents[1] / 'notebooks' / 'data' / 'orders_demo.csv'

def load_orders(spark):
    return (
        spark.read
        .option('header', True)
        .option('inferSchema', True)
        .csv(str(DATA_PATH))
    )

def test_tag_demand_level_high_classification(spark):
    df = load_orders(spark)
    result = tag_demand_level(df, threshold=14)
    # The demo data is expected to contain exactly one high-demand row for the south region.
    high_rows = result.filter((F.col('region') == 'south') & (F.col('demand_level') == 'high'))
    assert high_rows.count() == 1

Run the suite from the repository root with `pytest tests/` to execute all checks.


## Snapshot Testing DataFrames

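For larger transformations, store a known-good output as a snapshot and compare each new result against it. Row order, column order, null placement, and column types can all differ without the data actually being wrong, so write helper utilities that normalize them before asserting equality.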


## Clean Up

Stop the SparkSession at the end of notebook execution.


In [None]:
spark.stop()


## Exercises

- Add an assertion that verifies no region has negative order counts after running `tag_demand_level`.
- Refactor the inline assertions into pytest-style `test_` functions and run them from the command line.
- Introduce an intentional bug in `tag_demand_level` and observe how the tests fail, then fix the bug.
