In [0]:
import unittest
import pandas as pd
import datetime
import io
from pyspark.sql import functions as F
from pyspark.sql.window import Window
from pyspark.sql.types import *
from pyspark.testing import assertDataFrameEqual, assertSchemaEqual

# --- 1. THE TRANSFORMATION FUNCTION ---
def transform_stores_silver(df_bronze):
    """
    Clean, validate and deduplicate raw store records.

    Steps, in order:
      1. Rename every column to lower case with spaces replaced by underscores.
      2. Keep only rows whose ``store_id`` is not null, whose ``latitude`` is
         within [-90, 90] and whose ``longitude`` is within [-180, 180].
         Rows with a null latitude/longitude are dropped too, because the
         range predicate evaluates to null and ``filter`` keeps only true rows.
      3. Parse ``load_dt`` to a timestamp and keep the most recent row per
         ``store_id``.
      4. Upper-case ``state``, cast ``postal_code`` to long and project the
         final column list.

    Args:
        df_bronze: Spark DataFrame that, after the renaming in step 1, exposes
            the columns ``store_id``, ``store_name``, ``latitude``,
            ``longitude``, ``state``, ``postal_code`` and ``load_dt``.

    Returns:
        Spark DataFrame with the columns ``store_id``, ``store_name``,
        ``latitude``, ``longitude``, ``state``, ``postal_code`` (long) and
        ``load_dt`` (timestamp), holding at most one row per ``store_id``.
    """
    # Standardize column headers to snake_case
    standardized_cols = [name.lower().replace(" ", "_") for name in df_bronze.columns]
    df_standardized = df_bronze.toDF(*standardized_cols)

    # Quality gates: coordinates must be in range and the key must be present.
    geo_valid = F.col("latitude").between(-90, 90) & F.col("longitude").between(-180, 180)
    id_valid = F.col("store_id").isNotNull()
    df_clean = df_standardized.filter(geo_valid & id_valid)

    # Parse load_dt BEFORE ranking. Ordering the raw string compares values
    # lexicographically, which only matches chronological order for
    # zero-padded ISO-style strings; ordering the parsed timestamp is correct
    # for every value Spark can parse.
    df_typed = df_clean.withColumn("load_dt", F.to_timestamp(F.col("load_dt")))

    # Newest record first; rows whose load_dt is null (missing or unparseable)
    # sort last so they only survive when a store has no dated row at all.
    # NOTE(review): rows of the same store sharing an identical load_dt are
    # ranked in an unspecified order - add a tie-breaker column if that matters.
    window_spec = Window.partitionBy("store_id").orderBy(
        F.col("load_dt").desc_nulls_last()
    )

    df_silver_final = (
        df_typed.withColumn("row_rank", F.row_number().over(window_spec))
        .filter(F.col("row_rank") == 1)
        .drop("row_rank")
        .withColumn("state", F.upper(F.col("state")))
        .withColumn("postal_code", F.col("postal_code").cast("long"))
        .select(
            "store_id",
            "store_name",
            "latitude",
            "longitude",
            "state",
            "postal_code",
            "load_dt",
        )
    )

    return df_silver_final

# --- 2. THE TEST SUITE ---
class TestStoresSilver(unittest.TestCase):
    """Unit tests for ``transform_stores_silver`` run against the ambient ``spark`` session."""

    def setUp(self):
        """Pin the session time zone and build the input schema for mock data."""
        # to_timestamp() parses timestamp strings in the Spark session time
        # zone, while the expected value below is a fixed UTC instant. Pin the
        # session to UTC so the comparison does not depend on where the test
        # runs, and restore the previous setting once the test finishes.
        previous_tz = spark.conf.get("spark.sql.session.timeZone")
        spark.conf.set("spark.sql.session.timeZone", "UTC")
        self.addCleanup(spark.conf.set, "spark.sql.session.timeZone", previous_tz)

        self.input_schema = StructType([
            StructField("store_id", LongType(), True),
            StructField("store_name", StringType(), True),
            StructField("latitude", DoubleType(), True),
            StructField("longitude", DoubleType(), True),
            StructField("state", StringType(), True),
            StructField("postal_code", StringType(), True),
            StructField("load_dt", StringType(), True),
        ])

    def test_geo_validation_and_deduplication(self):
        """Invalid rows are dropped and only the newest row per store survives."""
        # 1. MOCK INPUT DATA
        # Row 1 & 2: Duplicates (Row 2 is newer)
        # Row 3: Invalid Latitude (should be filtered)
        # Row 4: Missing store_id (should be filtered)
        data = [
            (501, "Store A", 45.0, -120.0, "ny", "10001", "2026-01-01 10:00:00"),
            (501, "Store A", 45.0, -120.0, "ny", "10001", "2026-01-02 10:00:00"),
            (502, "Store B", 95.0, -120.0, "ca", "90210", "2026-01-02 10:00:00"),
            (None, "Store C", 45.0, -120.0, "wa", "98101", "2026-01-02 10:00:00"),
        ]
        df_input = spark.createDataFrame(data, self.input_schema)

        # 2. DEFINE EXPECTED OUTPUT
        expected_schema = StructType([
            StructField("store_id", LongType(), True),
            StructField("store_name", StringType(), True),
            StructField("latitude", DoubleType(), True),
            StructField("longitude", DoubleType(), True),
            StructField("state", StringType(), True),
            StructField("postal_code", LongType(), True),  # Cast to Long
            StructField("load_dt", TimestampType(), True),
        ])

        # Time-zone-aware UTC instant; it lines up with the parsed input
        # because setUp pins the session time zone to UTC.
        expected_ts = pd.Timestamp("2026-01-02 10:00:00", tz="UTC")

        expected_data = [
            (501, "Store A", 45.0, -120.0, "NY", 10001, expected_ts),
        ]
        df_expected = spark.createDataFrame(expected_data, expected_schema)

        # 3. RUN TRANSFORMATION
        df_actual = transform_stores_silver(df_input)

        # 4. ASSERTIONS
        assertSchemaEqual(df_actual.schema, expected_schema)
        assertDataFrameEqual(df_actual, df_expected)

# --- 3. EXECUTION AND REPORT GENERATION ---
# Collect the runner's verbose output in memory so it can be embedded in the
# report printed below instead of going straight to stderr.
stream = io.StringIO()
suite = unittest.TestLoader().loadTestsFromTestCase(TestStoresSilver)
runner = unittest.TextTestRunner(stream=stream, verbosity=2)
result = runner.run(suite)

# Resolve every report field up front; the template below only interpolates.
report_timestamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
report_status = 'SUCCESS' if result.wasSuccessful() else 'FAILED'
failure_count = len(result.failures)
error_count = len(result.errors)
runner_output = stream.getvalue()

print(f"""
=========================================
UNIT TEST REPORT - {report_timestamp}
=========================================
Status: {report_status}
Tests Run: {result.testsRun}
Failures: {failure_count}
Errors: {error_count}
-----------------------------------------
Details:
{runner_output}
=========================================
""")