# 3.1 Unit Testing PySpark Code with Pytest in Databricks

This notebook demonstrates how to implement effective unit testing for PySpark code using Pytest, with patterns that work both locally and in Databricks environments.

## Learning Objectives
- Structure PySpark code for testability
- Create Pytest fixtures for SparkSession and test data
- Mock Databricks-specific utilities for testing
- Write comprehensive unit tests for transformation functions
- Organize tests for CI/CD integration

## Structuring Code for Testability

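The key to testing PySpark code is to separate pure transformation logic from I/O operations and Databricks-specific utilities. A pure transformation takes a DataFrame and returns a new DataFrame without reading or writing storage. A test can therefore build a small in-memory input from a few literal rows and check the output directly.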

In [None]:
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from pyspark.sql.types import (
    StructType, StructField, IntegerType, StringType, DoubleType, DateType
)
from unittest.mock import MagicMock, patch
import pytest

# Example: Testable PySpark Module Structure
class SalesDataProcessor:
    """
    Example of well-structured PySpark code for testing:
    - Pure transformation functions
    - Dependency injection for external dependencies
    - Clear separation of concerns
    """
    
    @staticmethod
    def clean_sales_data(df):
        """Pure transformation function - easily testable"""
        return (df
                .filter(F.col("amount") > 0)  # Remove invalid amounts
                .withColumn("amount", F.round(F.col("amount"), 2))  # Round to 2 decimals
                .withColumn("customer_name", F.trim(F.initcap(F.col("customer_name"))))  # Standardize names
                .withColumn("sale_date", F.to_date(F.col("sale_date"), "yyyy-MM-dd"))  # Parse dates
                .dropDuplicates(["transaction_id"])  # Remove duplicates
               )
    
    @staticmethod
    def add_derived_columns(df):
        """Pure transformation function for derived columns"""
        return (df
                .withColumn("sale_year", F.year(F.col("sale_date")))
                .withColumn("sale_month", F.month(F.col("sale_date")))
                .withColumn("revenue_category",
                           F.when(F.col("amount") < 100, "Low")
                            .when(F.col("amount") < 1000, "Medium")
                            .otherwise("High"))
               )
    
    @staticmethod
    def calculate_customer_metrics(df):
        """Pure aggregation function"""
        return (df
                .groupBy("customer_name")
                .agg(
                    F.sum("amount").alias("total_spent"),
                    F.avg("amount").alias("avg_purchase"),
                    F.count("*").alias("purchase_count"),
                    F.min("sale_date").alias("first_purchase"),
                    F.max("sale_date").alias("last_purchase")
                )
               )
    
    @classmethod
    def process_sales_pipeline(cls, df):
        """Complete transformation pipeline - composed of pure functions"""
        return (df
                .transform(cls.clean_sales_data)
                .transform(cls.add_derived_columns)
               )
    
    # Example of function with external dependencies (needs mocking for tests)
    def save_to_delta(self, df, path, dbutils=None):
        """Function with external dependency - requires mocking in tests"""
        if dbutils:
            # This would need to be mocked in tests
            dbutils.fs.rm(path, True)  # Clear existing data
        
        # Write to Delta (this is a side effect)
        df.write.format("delta").mode("overwrite").save(path)
        
        return f"Data saved to {path}"

print("SalesDataProcessor class defined with testable structure")

## Creating Pytest Fixtures

Fixtures provide reusable resources for tests, such as SparkSession instances and test data.

In [None]:
# Example Pytest fixtures (typically in conftest.py)

@pytest.fixture(scope="session")
def spark_session():
    """
    Fixture providing a local SparkSession for testing.
    Session-scoped for efficiency across all tests.
    """
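    spark = (SparkSession.builder
             .master("local[*]")
             .appName("PySparkUnitTests")
             .config("spark.sql.adaptive.enabled", "false")  # Disable for consistent tests
             .config("spark.sql.shuffle.partitions", "1")  # Tiny test data: avoid 200 near-empty shuffle tasks
             .config("spark.sql.ansi.enabled", "false")  # Pin null-on-bad-input semantics (Spark 4.0 defaults to ANSI on)
             .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
             .getOrCreate())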
    
    yield spark
    
    spark.stop()

@pytest.fixture
def sample_sales_data(spark_session):
    """
    Fixture providing test sales data
    """
    data = [
        (1, "john doe", 150.75, "2023-01-15"),
        (2, "jane smith", 89.99, "2023-01-16"), 
        (3, "bob johnson", 0.0, "2023-01-17"),  # Invalid amount
        (4, "alice brown", 1200.50, "2023-01-18"),
        (1, "john doe", 150.75, "2023-01-15"),  # Duplicate
        (5, " charlie wilson ", 75.25, "2023-01-19"),  # Needs trimming
    ]
    
    schema = StructType([
        StructField("transaction_id", IntegerType(), True),
        StructField("customer_name", StringType(), True),
        StructField("amount", DoubleType(), True),
        StructField("sale_date", StringType(), True)
    ])
    
    return spark_session.createDataFrame(data, schema)

from datetime import date

@pytest.fixture
def expected_cleaned_data(spark_session):
    """
    Fixture providing expected results after cleaning: invalid and duplicate
    rows removed, names title-cased and trimmed, sale_date parsed to DateType
    """
    data = [
        (1, "John Doe", 150.75, date(2023, 1, 15)),
        (2, "Jane Smith", 89.99, date(2023, 1, 16)),
        (4, "Alice Brown", 1200.50, date(2023, 1, 18)),
        (5, "Charlie Wilson", 75.25, date(2023, 1, 19)),
    ]
    
    schema = StructType([
        StructField("transaction_id", IntegerType(), True),
        StructField("customer_name", StringType(), True),
        StructField("amount", DoubleType(), True),
        StructField("sale_date", DateType(), True)  # clean_sales_data parses dates
    ])
    
    return spark_session.createDataFrame(data, schema)

@pytest.fixture
def mock_dbutils():
    """
    Fixture providing mocked dbutils for testing
    """
    mock_dbutils = MagicMock()
    mock_dbutils.fs.rm.return_value = True
    mock_dbutils.widgets.get.return_value = "test_value"
    return mock_dbutils

print("Pytest fixtures defined for testing")

## Unit Tests for Pure Functions

Let's write comprehensive unit tests for our pure transformation functions:

In [None]:
# Unit tests for pure transformation functions
# (This would typically be in a separate test_sales_processor.py file)

class TestSalesDataProcessor:
    """
    Unit tests for SalesDataProcessor transformation functions
    """
    
    def test_clean_sales_data_removes_invalid_amounts(self, spark_session, sample_sales_data):
        """
        Test that clean_sales_data removes records with amount <= 0
        """
        # Act
        result = SalesDataProcessor.clean_sales_data(sample_sales_data)
        
        # Assert
        result_amounts = [row["amount"] for row in result.collect()]
        assert all(amount > 0 for amount in result_amounts), "All amounts should be positive"
        assert 0.0 not in result_amounts, "Zero amounts should be filtered out"
    
    def test_clean_sales_data_removes_duplicates(self, spark_session, sample_sales_data):
        """
        Test that clean_sales_data removes duplicate transaction_ids
        """
        # Act
        result = SalesDataProcessor.clean_sales_data(sample_sales_data)
        
        # Assert
        total_count = result.count()
        distinct_count = result.select("transaction_id").distinct().count()
        assert total_count == distinct_count, "Should have no duplicate transaction_ids"
    
    def test_clean_sales_data_standardizes_names(self, spark_session, sample_sales_data):
        """
        Test that customer names are properly formatted
        """
        # Act
        result = SalesDataProcessor.clean_sales_data(sample_sales_data)
        
        # Assert
        names = [row["customer_name"] for row in result.collect()]
        
        # Check specific transformations
        assert "John Doe" in names, "Should convert 'john doe' to 'John Doe'"
        assert "Charlie Wilson" in names, "Should trim and title case ' charlie wilson '"
        assert all(name == name.strip() for name in names), "All names should be trimmed"
    
    def test_clean_sales_data_parses_dates(self, spark_session, sample_sales_data):
        """
        Test that dates are properly parsed to DateType
        """
        # Act
        result = SalesDataProcessor.clean_sales_data(sample_sales_data)
        
        # Assert
        date_field = next(field for field in result.schema.fields if field.name == "sale_date")
        assert isinstance(date_field.dataType, DateType), "sale_date should be DateType"
    
    def test_add_derived_columns_creates_correct_columns(self, spark_session, sample_sales_data):
        """
        Test that derived columns are correctly added
        """
        # Arrange
        cleaned_data = SalesDataProcessor.clean_sales_data(sample_sales_data)
        
        # Act
        result = SalesDataProcessor.add_derived_columns(cleaned_data)
        
        # Assert
        expected_columns = {"sale_year", "sale_month", "revenue_category"}
        actual_columns = set(result.columns)
        assert expected_columns.issubset(actual_columns), f"Missing columns: {expected_columns - actual_columns}"
        
        # Test specific values
        sample_row = result.filter(F.col("transaction_id") == 1).collect()[0]
        assert sample_row["sale_year"] == 2023, "Year should be extracted correctly"
        assert sample_row["sale_month"] == 1, "Month should be extracted correctly"
        assert sample_row["revenue_category"] == "Medium", "150.75 should be categorized as Medium"
    
    def test_revenue_categorization(self, spark_session):
        """
        Test revenue categorization logic with specific test cases
        """
        # Arrange
        test_data = [
            (1, "Test Customer", 50.0, "2023-01-01"),    # Low
            (2, "Test Customer", 500.0, "2023-01-01"),   # Medium
            (3, "Test Customer", 1500.0, "2023-01-01"),  # High
        ]
        
        schema = StructType([
            StructField("transaction_id", IntegerType(), True),
            StructField("customer_name", StringType(), True),
            StructField("amount", DoubleType(), True),
            StructField("sale_date", StringType(), True)
        ])
        
        df = spark_session.createDataFrame(test_data, schema)
        cleaned_df = SalesDataProcessor.clean_sales_data(df)
        
        # Act
        result = SalesDataProcessor.add_derived_columns(cleaned_df)
        
        # Assert
        categories = {row["transaction_id"]: row["revenue_category"] for row in result.collect()}
        assert categories[1] == "Low", "$50 should be Low category"
        assert categories[2] == "Medium", "$500 should be Medium category"
        assert categories[3] == "High", "$1500 should be High category"
    
    def test_calculate_customer_metrics(self, spark_session):
        """
        Test customer metrics calculation
        """
        # Arrange
        test_data = [
            (1, "John Doe", 100.0, "2023-01-01"),
            (2, "John Doe", 200.0, "2023-01-15"),
            (3, "Jane Smith", 150.0, "2023-01-10"),
        ]
        
        schema = StructType([
            StructField("transaction_id", IntegerType(), True),
            StructField("customer_name", StringType(), True),
            StructField("amount", DoubleType(), True),
            StructField("sale_date", StringType(), True)
        ])
        
        df = spark_session.createDataFrame(test_data, schema)
        processed_df = SalesDataProcessor.process_sales_pipeline(df)
        
        # Act
        result = SalesDataProcessor.calculate_customer_metrics(processed_df)
        
        # Assert
        john_metrics = result.filter(F.col("customer_name") == "John Doe").collect()[0]
        assert john_metrics["total_spent"] == 300.0, "John's total should be $300"
        assert john_metrics["avg_purchase"] == 150.0, "John's average should be $150"
        assert john_metrics["purchase_count"] == 2, "John should have 2 purchases"
        
        jane_metrics = result.filter(F.col("customer_name") == "Jane Smith").collect()[0]
        assert jane_metrics["purchase_count"] == 1, "Jane should have 1 purchase"
    
    def test_complete_pipeline_integration(self, spark_session, sample_sales_data):
        """
        Integration test for the complete transformation pipeline
        """
        # Act
        result = SalesDataProcessor.process_sales_pipeline(sample_sales_data)
        
        # Assert pipeline correctness
        assert result.count() == 4, "Should have 4 records after cleaning (removed invalid and duplicate)"
        
        # Check all expected columns exist
        expected_columns = {
            "transaction_id", "customer_name", "amount", "sale_date",
            "sale_year", "sale_month", "revenue_category"
        }
        actual_columns = set(result.columns)
        assert expected_columns.issubset(actual_columns), "Pipeline should create all expected columns"
        
        # Verify data quality
        amounts = [row["amount"] for row in result.collect()]
        assert all(amount > 0 for amount in amounts), "All amounts should be positive after cleaning"

print("Unit test class defined for SalesDataProcessor")

## Testing Functions with External Dependencies

When functions have external dependencies (like dbutils), we need to use mocking:
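Passing the mock in as an argument works when the function accepts `dbutils` as a parameter, as `save_to_delta` does. Notebook code often relies instead on the `dbutils` global that the Databricks runtime provides. In that case there is no module attribute for `patch('your_module.dbutils')` to replace unless `create=True` is passed. The following minimal sketch shows one alternative. It assumes the code under test lives in the same module as the test, and it temporarily adds a fake `dbutils` to that module's globals with `patch.dict`. `get_environment` is a hypothetical helper:

```python
from unittest.mock import MagicMock, patch


def get_environment() -> str:
    """Hypothetical notebook helper that relies on a runtime-provided `dbutils` global."""
    return dbutils.widgets.get("env")  # noqa: F821 (not defined until the runtime or a test adds it)


def run_get_environment_with_fake_dbutils() -> str:
    """In a test file, this body would be the test function."""
    fake_dbutils = MagicMock(name="dbutils")
    fake_dbutils.widgets.get.return_value = "production"

    # patch.dict adds `dbutils` to this module's globals and removes it again on exit
    with patch.dict(globals(), {"dbutils": fake_dbutils}):
        env = get_environment()

    fake_dbutils.widgets.get.assert_called_once_with("env")
    return env
```

Because `patch.dict` restores the dictionary on exit, the fake `dbutils` cannot leak into other tests.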

In [None]:
# Testing functions with external dependencies

class TestSalesDataProcessorWithMocking:
    """
    Tests for functions that have external dependencies
    """
    
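    def test_save_to_delta_with_mocked_dbutils(self, mock_dbutils):
        """
        Test save_to_delta with mocked dbutils and a mocked DataFrame writer,
        so no Delta Lake installation or real storage is needed
        """
        # Arrange
        processor = SalesDataProcessor()
        test_path = "/tmp/test_delta_table"
        mock_df = MagicMock()  # Stands in for a DataFrame and records df.write... calls
        
        # Act
        result = processor.save_to_delta(mock_df, test_path, mock_dbutils)
        
        # Assert
        assert result == f"Data saved to {test_path}"
        mock_dbutils.fs.rm.assert_called_once_with(test_path, True)
        mock_df.write.format.assert_called_once_with("delta")
        mock_df.write.format.return_value.mode.assert_called_once_with("overwrite")
        writer = mock_df.write.format.return_value.mode.return_value
        writer.save.assert_called_once_with(test_path)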
    
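    @pytest.mark.skip(reason="Illustrative only: point the patch target at the module that uses dbutils")
    @patch('your_module.dbutils', create=True)  # Replace with actual import path
    def test_function_that_uses_dbutils_directly(self, patched_dbutils, spark_session):
        """
        Example of patching dbutils where the code under test looks it up.
        The mock created by @patch is passed before any pytest fixtures, and
        create=True allows patching a name the module does not define itself.
        """
        # Configure mock
        patched_dbutils.widgets.get.return_value = "production"
        
        # Your test logic here
        # result = function_that_uses_dbutils()
        # assert result == "production"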
    
    def test_error_handling_in_transformation(self, spark_session):
        """
        Test how clean_sales_data handles a null customer name and an unparseable date
        """
        # Arrange: a null customer name and an invalid date string
        bad_data = [
            (1, None, 100.0, "invalid-date"),
        ]
        
        schema = StructType([
            StructField("transaction_id", IntegerType(), True),
            StructField("customer_name", StringType(), True),
            StructField("amount", DoubleType(), True),
            StructField("sale_date", StringType(), True)
        ])
        
        df = spark_session.createDataFrame(bad_data, schema)
        
        # Act
        rows = SalesDataProcessor.clean_sales_data(df).collect()
        
        # Assert: with ANSI mode off (the default before Spark 4.0), to_date returns
        # null for unparseable input and initcap/trim pass nulls through unchanged.
        # With ANSI mode on, collect() raises instead; if the pipeline should fail
        # fast, assert that with pytest.raises(Exception) in place of the checks below.
        assert len(rows) == 1, "The row has a valid amount and should be kept"
        assert rows[0]["customer_name"] is None, "Null names should stay null"
        assert rows[0]["sale_date"] is None, "Unparseable dates should become null"

print("Mocking tests defined")
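The `mock_dbutils` fixture returns `"test_value"` for every widget name, so a test cannot tell whether the code read the right widget. Setting a `side_effect` on `widgets.get` makes each name return its own value and makes an unknown name fail loudly. `make_fake_dbutils` is hypothetical. Raising `KeyError` for unknown widgets is also an assumption; the real dbutils reports a missing widget differently:

```python
from unittest.mock import MagicMock


def make_fake_dbutils(widget_values: dict) -> MagicMock:
    """Hypothetical helper: a dbutils mock whose widgets.get returns per-name values."""
    fake = MagicMock(name="dbutils")
    values = dict(widget_values)  # copy so later changes by the caller don't leak in

    def _get_widget(name: str) -> str:
        # KeyError for unknown names is an assumption, not real dbutils behavior
        return values[name]

    fake.widgets.get.side_effect = _get_widget
    fake.fs.rm.return_value = True
    return fake
```

A test can then build `make_fake_dbutils({"env": "dev", "run_date": "2023-01-15"})` and pass it wherever `mock_dbutils` was used. A typo in a widget name now raises instead of silently returning a value.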

## DataFrame Assertion Utilities

Let's create utilities for asserting DataFrame equality and schema validation:

In [None]:
# DataFrame assertion utilities for comprehensive testing

class DataFrameAssertions:
    """
    Utility class for DataFrame-specific assertions
    """
    
    @staticmethod
    def assert_dataframe_equal(actual_df, expected_df, ignore_order=True, ignore_nullable=True):
        """
        Assert that two DataFrames are equal
        """
        # Check schema equality
        if not ignore_nullable:
            assert actual_df.schema == expected_df.schema, "Schemas do not match"
        else:
            # Compare schemas ignoring nullable property
            actual_fields = {(f.name, f.dataType) for f in actual_df.schema.fields}
            expected_fields = {(f.name, f.dataType) for f in expected_df.schema.fields}
            assert actual_fields == expected_fields, "Schema fields do not match"
        
        # Check data equality
        actual_data = actual_df.collect()
        expected_data = expected_df.collect()
        
        if ignore_order:
            actual_data = sorted(actual_data, key=str)
            expected_data = sorted(expected_data, key=str)
        
        assert len(actual_data) == len(expected_data), f"Row count mismatch: {len(actual_data)} vs {len(expected_data)}"
        
        for i, (actual_row, expected_row) in enumerate(zip(actual_data, expected_data)):
            assert actual_row == expected_row, f"Row {i} mismatch: {actual_row} vs {expected_row}"
    
    @staticmethod
    def assert_schema_equal(actual_df, expected_schema):
        """
        Assert DataFrame schema matches expected schema
        """
        actual_schema = actual_df.schema
        
        assert len(actual_schema.fields) == len(expected_schema.fields), "Field count mismatch"
        
        for actual_field, expected_field in zip(actual_schema.fields, expected_schema.fields):
            assert actual_field.name == expected_field.name, f"Field name mismatch: {actual_field.name} vs {expected_field.name}"
            assert actual_field.dataType == expected_field.dataType, f"Data type mismatch for {actual_field.name}"
    
    @staticmethod
    def assert_columns_exist(df, expected_columns):
        """
        Assert that DataFrame contains expected columns
        """
        actual_columns = set(df.columns)
        expected_columns = set(expected_columns)
        missing_columns = expected_columns - actual_columns
        
        assert len(missing_columns) == 0, f"Missing columns: {missing_columns}"
    
    @staticmethod
    def assert_no_nulls_in_columns(df, columns):
        """
        Assert that specified columns contain no null values
        """
        for column in columns:
            null_count = df.filter(F.col(column).isNull()).count()
            assert null_count == 0, f"Column '{column}' contains {null_count} null values"
    
    @staticmethod
    def assert_column_values_in_range(df, column, min_val=None, max_val=None):
        """
        Assert that column values are within specified range
        """
        if min_val is not None:
            min_actual = df.agg(F.min(column)).collect()[0][0]
            assert min_actual is not None, f"Column '{column}' has no non-null values"
            assert min_actual >= min_val, f"Minimum value {min_actual} is below threshold {min_val}"
        
        if max_val is not None:
            max_actual = df.agg(F.max(column)).collect()[0][0]
            assert max_actual is not None, f"Column '{column}' has no non-null values"
            assert max_actual <= max_val, f"Maximum value {max_actual} is above threshold {max_val}"

# Example usage in tests
class TestWithDataFrameAssertions:
    """
    Examples of using DataFrame assertion utilities
    """
    
    def test_data_quality_assertions(self, spark_session, sample_sales_data):
        """
        Test using DataFrame assertion utilities
        """
        # Act
        result = SalesDataProcessor.clean_sales_data(sample_sales_data)
        
        # Assert using utilities
        DataFrameAssertions.assert_columns_exist(result, 
                                               ["transaction_id", "customer_name", "amount", "sale_date"])
        
        DataFrameAssertions.assert_no_nulls_in_columns(result, 
                                                      ["transaction_id", "customer_name"])
        
        DataFrameAssertions.assert_column_values_in_range(result, "amount", min_val=0.01)
    
    def test_schema_validation(self, spark_session, sample_sales_data):
        """
        Test schema validation
        """
        # Arrange
        expected_schema = StructType([
            StructField("transaction_id", IntegerType(), True),
            StructField("customer_name", StringType(), True),
            StructField("amount", DoubleType(), True),
            StructField("sale_date", DateType(), True),  # Note: DateType after cleaning
            StructField("sale_year", IntegerType(), True),
            StructField("sale_month", IntegerType(), True),
            StructField("revenue_category", StringType(), True)
        ])
        
        # Act
        result = SalesDataProcessor.process_sales_pipeline(sample_sales_data)
        
        # Assert
        DataFrameAssertions.assert_schema_equal(result, expected_schema)

print("DataFrame assertion utilities defined")

## Running Tests and Test Organization

Here's how to organize and run your PySpark tests effectively:

In [None]:
# Example test organization structure
# This demonstrates how to organize your test files

"""
Recommended project structure for testable PySpark code:

project/
├── src/
│   ├── __init__.py
│   ├── transformations/
│   │   ├── __init__.py
│   │   ├── sales_processor.py
│   │   └── data_quality.py
│   └── utils/
│       ├── __init__.py
│       └── spark_utils.py
├── tests/
│   ├── __init__.py
│   ├── conftest.py  # Pytest fixtures
│   ├── test_sales_processor.py
│   ├── test_data_quality.py
│   └── utils/
│       ├── __init__.py
│       └── test_assertions.py
├── pytest.ini
├── requirements.txt
└── README.md
"""

# Example pytest.ini configuration
# (pytest.ini uses a [pytest] section; [tool:pytest] is the setup.cfg form)
pytest_ini_content = """
[pytest]
testpaths = tests
python_files = test_*.py
python_classes = Test*
python_functions = test_*
addopts = 
    -v
    --tb=short
    --strict-markers
    --disable-warnings
    --cov=src
    --cov-report=html
    --cov-report=term-missing

markers =
    unit: Unit tests
    integration: Integration tests
    slow: Slow running tests
"""

# Example conftest.py for shared fixtures
conftest_content = """
import pytest
from pyspark.sql import SparkSession
from unittest.mock import MagicMock

@pytest.fixture(scope="session")
def spark_session():
    spark = (SparkSession.builder
             .master("local[*]")
             .appName("PySparkUnitTests")
             .config("spark.sql.adaptive.enabled", "false")
             .getOrCreate())
    yield spark
    spark.stop()

@pytest.fixture
def mock_dbutils():
    return MagicMock()

# Add other shared fixtures here
"""

print("Test organization structure and configuration examples provided")

# Example commands to run tests
print("\n=== Test Execution Commands ===")
print("Run all tests:                    pytest")
print("Run with coverage:                pytest --cov=src")
print("Run specific test file:           pytest tests/test_sales_processor.py")
print("Run specific test:                pytest tests/test_sales_processor.py::TestSalesDataProcessor::test_clean_sales_data_removes_duplicates")
print("Run tests with specific marker:   pytest -m unit")
print("Run tests in parallel:            pytest -n auto  (requires pytest-xdist)")
print("Run tests with output:            pytest -v -s")
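pytest only reads the `[pytest]` section of a `pytest.ini` file; `[tool:pytest]` is the `setup.cfg` spelling. With `--strict-markers`, every marker used in a test must also be registered. The sanity check below is built on the standard-library `configparser` and could run in CI before the test suite. `registered_markers` is a hypothetical helper, and the ini text is an abridged copy of the example above:

```python
import configparser

PYTEST_INI = """
[pytest]
testpaths = tests
markers =
    unit: Unit tests
    integration: Integration tests
    slow: Slow running tests
"""


def registered_markers(ini_text: str) -> list:
    """Return marker names from a pytest.ini string, failing if the header is wrong."""
    parser = configparser.ConfigParser()
    parser.read_string(ini_text)
    if not parser.has_section("pytest"):
        raise ValueError("pytest.ini must use a [pytest] section header")
    raw = parser.get("pytest", "markers", fallback="")
    # Each non-empty continuation line looks like "name: description"
    return [line.split(":", 1)[0].strip() for line in raw.splitlines() if line.strip()]
```

Running `registered_markers(PYTEST_INI)` returns the three marker names, while the same text with a `[tool:pytest]` header raises `ValueError`.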

## Integration with Databricks and CI/CD

Here's how to integrate your tests with the Databricks development workflow:

In [None]:
# Integration patterns for Databricks and CI/CD

# Example GitHub Actions workflow for PySpark tests
github_actions_yml = """
name: PySpark Tests

on:
  push:
    branches: [ main, develop ]
  pull_request:
    branches: [ main ]

jobs:
  test:
    runs-on: ubuntu-latest
    
    strategy:
      matrix:
        python-version: ["3.10", "3.11"]  # quoted so YAML does not read them as floats
    
    steps:
    - uses: actions/checkout@v4
    
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v5
      with:
        python-version: ${{ matrix.python-version }}
    
    - name: Set up Java 17
      uses: actions/setup-java@v4
      with:
        java-version: '17'  # Spark 3.3+ supports Java 17; PySpark 4.x requires it
        distribution: 'temurin'
    
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install -r requirements.txt
        pip install pytest pytest-cov pytest-xdist
    
    - name: Run tests
      run: |
        pytest --cov=src --cov-report=xml --cov-report=term-missing
    
    - name: Upload coverage to Codecov
      uses: codecov/codecov-action@v4
      with:
        files: ./coverage.xml
        token: ${{ secrets.CODECOV_TOKEN }}
"""

# Example Databricks notebook for running tests
databricks_test_notebook = """
# Databricks notebook source
# MAGIC %md
# MAGIC # PySpark Unit Tests Runner
# MAGIC 
# MAGIC This notebook runs unit tests in the Databricks environment

# COMMAND ----------

# MAGIC %pip install pytest

# COMMAND ----------

# Import your modules
from pyspark.sql.types import StructType, StructField, IntegerType, StringType, DoubleType
from src.transformations.sales_processor import SalesDataProcessor
from tests.utils.test_assertions import DataFrameAssertions

# COMMAND ----------

# Run a simple smoke test
def run_smoke_tests():
    # Create test data
    test_data = [(1, "Test User", 100.0, "2023-01-01")]
    schema = StructType([
        StructField("transaction_id", IntegerType(), True),
        StructField("customer_name", StringType(), True),
        StructField("amount", DoubleType(), True),
        StructField("sale_date", StringType(), True)
    ])
    
    df = spark.createDataFrame(test_data, schema)
    
    # Test transformations
    result = SalesDataProcessor.process_sales_pipeline(df)
    
    # Basic assertions
    assert result.count() == 1, "Should have 1 record"
    assert "sale_year" in result.columns, "Should have derived columns"
    
    print("✅ Smoke tests passed!")

run_smoke_tests()

# COMMAND ----------

# MAGIC %md
# MAGIC ## Integration with Databricks Asset Bundles
# MAGIC 
# MAGIC For full CI/CD integration, use Databricks Asset Bundles (DABs)
"""

# Example requirements.txt for PySpark testing
requirements_txt = """
pyspark>=3.3.0
pytest>=6.0.0
pytest-cov>=2.10.0
pytest-xdist>=2.0.0
pytest-mock>=3.0.0
delta-spark>=2.0.0
pandas>=1.3.0
pyarrow>=5.0.0
"""

print("CI/CD integration examples provided")
print("\n=== Key Integration Points ===")
print("1. Local testing with pytest before committing")
print("2. CI/CD pipeline runs tests on every PR")
print("3. Databricks notebooks can run smoke tests")
print("4. Use Databricks Connect for local development")
print("5. Databricks Asset Bundles for deployment automation")

## Summary

**Key Takeaways:**

1. **Code Structure for Testability**:
   - Extract pure transformation functions
   - Separate I/O and external dependencies
   - Use dependency injection for testability

2. **Pytest Fixtures**:
   - Session-scoped SparkSession for efficiency
   - Reusable test data fixtures
   - Mock external dependencies (dbutils)

3. **Comprehensive Testing**:
   - Unit tests for pure functions
   - DataFrame-specific assertions
   - Error handling and edge cases
   - Integration tests for complete pipelines

4. **CI/CD Integration**:
   - GitHub Actions for automated testing
   - Databricks notebooks for smoke testing
   - Databricks Asset Bundles for deployment

**Benefits of Test-First PySpark Development**:
- Faster feedback loops
- Higher code quality and reliability
- Easier refactoring and maintenance
- Better documentation through tests
- Confidence in production deployments

**Next Steps**: In the next notebook, we'll explore DataFrame and schema validation techniques to ensure data quality throughout your pipelines.

## Exercise

Practice test-driven development:

1. Create a new transformation function for your domain
2. Write tests first (TDD approach)
3. Implement the function to make tests pass
4. Add edge case tests
5. Create integration tests for a complete pipeline
6. Set up pytest fixtures for your test data

In [None]:
# Your exercise code here

class TestYourTransformation:
    """Practice TDD with your own transformation"""
    
    def test_your_function_basic_case(self, spark_session):
        """Write your test first"""
        # Arrange
        # Create test data
        
        # Act
        # Call your transformation function
        
        # Assert
        # Verify expected behavior
        pass
    
    def test_your_function_edge_cases(self, spark_session):
        """Test edge cases"""
        pass

def your_transformation_function(df):
    """Implement your function to make tests pass"""
    # Your implementation here
    pass

# Run your tests
# pytest your_test_file.py -v