# PySpark: Zero to Hero
## Module 35: Unit Testing with Pytest

Unit tests are critical for production-grade data engineering. They verify that your transformation logic works as expected and catch regressions when code changes.

In this module, we will use the **`pytest`** framework to test PySpark transformations.

### Agenda:
1.  **Setup:** Install `pytest`.
2.  **Project Structure:** Create a standard testing directory structure.
3.  **Fixtures (`conftest.py`):** Create a reusable `SparkSession` for tests.
4.  **Code Module (`common.py`):** Define the PySpark transformations to test.
5.  **Test Cases:** Write positive and negative test assertions.
6.  **Execution:** Run the tests and analyze the results.

In [None]:
# Install pytest if not already installed
%pip install pytest

In [None]:
import os

# Create a directory named 'tests' to hold our code and tests
os.makedirs("tests", exist_ok=True)
print("Created 'tests' directory.")

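We use `conftest.py` to define **Fixtures**: functions that pytest runs to prepare (and optionally tear down) resources for the tests that request them. pytest discovers `conftest.py` automatically, so its fixtures need no import.
Here, we define a `spark_session` fixture with `scope="session"`, so the SparkSession is created once per test run and shared by all tests instead of being restarted for each one.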

In [None]:
%%writefile tests/conftest.py
import pytest
from pyspark.sql import SparkSession

@pytest.fixture(scope="session")
def spark_session():
    """
    Fixture to create a SparkSession for testing.
    The scope='session' ensures it is created once per test run.
    """
    spark = SparkSession.builder \
        .master("local[1]") \
        .appName("PySpark Unit Test") \
        .getOrCreate()
    
    yield spark
    
    # Teardown (optional, but good practice)
    spark.stop()

We create a file `tests/common.py` that contains the transformation functions we want to test.
1.  `remove_extra_spaces`: Collapses each run of whitespace in a column into a single space.
2.  `filter_senior_citizen`: Filters rows where age >= 60.

In [None]:
%%writefile tests/common.py
from pyspark.sql.functions import col, regexp_replace

def remove_extra_spaces(df, column_name):
    """
    Collapses each run of whitespace in the specified column into a single space.
    """
    # "\\s+" matches one or more whitespace characters (spaces, tabs, newlines),
    # so a lone tab or newline is also replaced by a space.
    # Leading and trailing whitespace is reduced to one space, not removed.
    df_transformed = df.withColumn(column_name, regexp_replace(col(column_name), "\\s+", " "))
    return df_transformed

def filter_senior_citizen(df, column_name):
    """
    Filters dataframe for rows where column_name >= 60.
    """
    df_filtered = df.filter(col(column_name) >= 60)
    return df_filtered
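
The pattern `\s+` used in `remove_extra_spaces` matches more than repeated spaces. The sketch below uses Python's `re` module to show the difference from a spaces-only pattern. Spark evaluates `regexp_replace` with Java's regex engine, which is assumed here to behave the same way for these simple ASCII inputs. The helper names `collapse_whitespace` and `collapse_spaces_only` are illustrative and not part of `common.py`.

```python
import re


def collapse_whitespace(text):
    # Same pattern as remove_extra_spaces: any run of whitespace becomes one space
    return re.sub(r"\s+", " ", text)


def collapse_spaces_only(text):
    # Stricter alternative: only runs of two or more literal spaces are touched
    return re.sub(r" {2,}", " ", text)


print(collapse_whitespace("John  D."))                 # John D.
print(repr(collapse_whitespace("Tab\tSeparated")))     # 'Tab Separated'
print(repr(collapse_spaces_only("Tab\tSeparated")))    # 'Tab\tSeparated'
print(repr(collapse_whitespace("  padded  ")))         # ' padded '
```

Neither pattern trims the value: leading and trailing whitespace is reduced to a single space, so a separate trim step is needed if that matters for your data.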

Now we write the actual tests. By default, pytest collects files named `test_*.py` or `*_test.py`, and within them functions whose names start with `test`.
pytest injects the `spark_session` fixture into any test function that lists it as a parameter, matching by name.

In [None]:
%%writefile tests/test_app.py
import pytest
from common import remove_extra_spaces, filter_senior_citizen
# Note: this flat import works because 'tests/' has no __init__.py, so pytest's
# default import mode adds that directory to sys.path.
# In a packaged project, import from the package instead.

def test_single_space(spark_session):
    """Test Case 1: Verify extra spaces are removed."""
    
    # 1. Prepare Sample Data
    data = [("John  D.", 30), ("Alice   G.", 25), ("Bob T.", 35)]
    columns = ["name", "age"]
    original_df = spark_session.createDataFrame(data, columns)
    
    # 2. Apply Transformation
    transformed_df = remove_extra_spaces(original_df, "name")
    
    # 3. Prepare Expected Data
    expected_data = [("John D.", 30), ("Alice G.", 25), ("Bob T.", 35)]
    expected_df = spark_session.createDataFrame(expected_data, columns)
    
    # 4. Assert
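    # Compare the schema first, then the data
    assert transformed_df.schema == expected_df.schema
    # collect() returns lists of Row objects, so this comparison is order-sensitive.
    # It is stable here because withColumn does not reorder this small local DataFrame.
    assert transformed_df.collect() == expected_df.collect()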

def test_row_count(spark_session):
    """Test Case 2: Verify row count remains same after cleaning spaces."""
    data = [("John  D.", 30), ("Alice   G.", 25)]
    df = spark_session.createDataFrame(data, ["name", "age"])
    
    transformed_df = remove_extra_spaces(df, "name")
    
    assert transformed_df.count() == df.count()

def test_senior_citizen_count(spark_session):
    """Test Case 3: Verify filtering logic."""
    data = [("A", 60), ("B", 65), ("C", 55), ("D", 70)]
    df = spark_session.createDataFrame(data, ["name", "age"])
    
    filtered_df = filter_senior_citizen(df, "age")
    
    # Expected: A(60), B(65), D(70) -> 3 records
    assert filtered_df.count() == 3
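
Comparing `.collect()` results directly is order-sensitive, which can become fragile once a transformation shuffles data. One possible order-insensitive alternative is sketched below. `assert_same_rows` is a hypothetical helper, not part of pytest or PySpark. It relies on PySpark `Row` objects being tuple subclasses, and it assumes all column values are hashable (so no array or map columns). Plain tuples stand in for rows so the example runs without Spark.

```python
from collections import Counter


def assert_same_rows(actual_rows, expected_rows):
    """Assert two row collections hold the same rows, ignoring order.

    Counter keeps duplicate rows distinct, so [a, a] does not equal [a].
    """
    actual = Counter(tuple(row) for row in actual_rows)
    expected = Counter(tuple(row) for row in expected_rows)
    assert actual == expected, (
        f"missing rows: {dict(expected - actual)}, "
        f"unexpected rows: {dict(actual - expected)}"
    )


# Same rows in a different order: passes silently
assert_same_rows([("B", 65), ("A", 60)], [("A", 60), ("B", 65)])
```

In a test you would call it as `assert_same_rows(transformed_df.collect(), expected_df.collect())`. PySpark 3.5 and later also provide `pyspark.testing.assertDataFrameEqual`, which may be preferable when it is available in your environment.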

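We deliberately write a test with a wrong expected value to see how pytest reports a failure. This is for demonstration only: the assertion is incorrect, not the code under test.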

In [None]:
%%writefile tests/test_negative.py
import pytest
from common import filter_senior_citizen

def test_senior_citizen_count_negative(spark_session):
    """
    Test Case 4: Negative Scenario.
    We assert an incorrect count to force a failure demonstration.
    """
    data = [("A", 60), ("B", 65), ("C", 55), ("D", 20)]
    df = spark_session.createDataFrame(data, ["name", "age"])
    
    filtered_df = filter_senior_citizen(df, "age")
    
    # Actual count is 2 (60, 65). 
    # We assert 3 to make it fail.
    expected_count = 3 
    
    assert filtered_df.count() == expected_count

In [None]:
# We run pytest from the notebook using the '!' shell escape.
# -v : Verbose mode (shows each test name and result)
# We point it to the 'tests/' folder.
# Expect 3 passed and 1 failed: test_negative.py fails on purpose.

!python -m pytest tests/ -v

## Summary

1.  **Pytest Integration:** PySpark transformations are plain Python functions over DataFrames, so pytest can test them without extra plugins.
2.  **Fixtures:** Use `conftest.py` to manage the `SparkSession` lifecycle. Use `scope="session"` to avoid restarting Spark for every single test function.
3.  **Assertions:** Use standard Python `assert` statements to compare DataFrames (via `.collect()`, `.count()`, or `.schema`). Keep in mind that comparing `.collect()` results is order-sensitive.
4.  **Best Practices:** 
    *   Separate test logic from business logic.
    *   Test both data values and row counts.
    *   Integrate these tests into CI/CD pipelines (Jenkins, GitHub Actions) to run automatically on code commits.