In [0]:
import unittest
import pandas as pd
import datetime
import io
from pyspark.sql import functions as F
from pyspark.sql.window import Window
from pyspark.sql.types import *
from pyspark.testing import assertDataFrameEqual, assertSchemaEqual

# --- 1. THE TRANSFORMATION FUNCTION ---
def transform_vouchers_silver(df_bronze):
    """
    Build the silver vouchers DataFrame from a bronze DataFrame.

    Steps, in order:
      1. Rename every column to lower case with spaces replaced by underscores.
      2. Keep only rows whose ``voucher_id`` is not null and whose
         ``discount_amount`` is greater than zero.
      3. Parse ``load_dt`` with ``F.to_timestamp``.
      4. Keep a single row per ``voucher_id``: the one with the latest ``load_dt``.
      5. Trim and upper-case ``voucher_code``.

    Args:
        df_bronze: Spark DataFrame whose columns, once renamed, include
            ``voucher_id``, ``voucher_code``, ``discount_amount``,
            ``expiry_date`` and ``load_dt``.

    Returns:
        A DataFrame with exactly the columns ``voucher_id``, ``voucher_code``,
        ``discount_amount``, ``expiry_date`` and ``load_dt`` (as a timestamp).
    """
    # Standardize column headers to snake_case
    snake_case_names = [name.lower().replace(" ", "_") for name in df_bronze.columns]
    df_renamed = df_bronze.toDF(*snake_case_names)

    # Quality gates: voucher_id must exist and discount_amount must be positive
    has_id = F.col("voucher_id").isNotNull()
    has_positive_discount = F.col("discount_amount") > 0
    df_clean = df_renamed.filter(has_id & has_positive_discount)

    # Parse load_dt BEFORE ranking. Ordering the raw column would compare
    # strings lexicographically when bronze delivers load_dt as text, so a
    # non-padded or unparseable value could win the "latest" slot. After
    # parsing, unparseable values are NULL and sort last under desc().
    df_typed = df_clean.withColumn("load_dt", F.to_timestamp(F.col("load_dt")))

    # Deduplication: newest load_dt first within each voucher_id.
    # NOTE: rows tied on load_dt are not ordered any further, so which of them
    # gets row number 1 is not deterministic.
    latest_first = Window.partitionBy("voucher_id").orderBy(F.col("load_dt").desc())

    df_silver_final = (
        df_typed.withColumn("row_rank", F.row_number().over(latest_first))
        .filter(F.col("row_rank") == 1)
        .drop("row_rank")
        # Normalization: codes are compared case- and whitespace-insensitively
        .withColumn("voucher_code", F.upper(F.trim(F.col("voucher_code"))))
        .select("voucher_id", "voucher_code", "discount_amount", "expiry_date", "load_dt")
    )

    return df_silver_final

# --- 2. THE TEST SUITE ---
class TestVouchersSilver(unittest.TestCase):
    """Unit tests for ``transform_vouchers_silver`` using the ambient ``spark`` session."""

    def setUp(self):
        """Pin the session time zone to UTC and build the schema for mock input."""
        # F.to_timestamp parses strings in the session time zone, while the
        # expected value below is an explicit UTC instant. Pin the zone so the
        # comparison does not depend on where the test runs, and restore the
        # previous setting once the test finishes.
        tz_key = "spark.sql.session.timeZone"
        previous_tz = spark.conf.get(tz_key)
        spark.conf.set(tz_key, "UTC")
        self.addCleanup(spark.conf.set, tz_key, previous_tz)

        self.input_schema = StructType([
            StructField("voucher_id", LongType(), True),
            StructField("voucher_code", StringType(), True),
            StructField("discount_amount", DoubleType(), True),
            StructField("expiry_date", StringType(), True),
            StructField("load_dt", StringType(), True)
        ])

    def test_voucher_normalization_and_deduplication(self):
        """Only the newest valid row per voucher survives, with a normalized code."""
        # 1. MOCK INPUT DATA
        # Row 1 & 2: Duplicates (Row 2 is newer and has messy code casing)
        # Row 3: Negative discount (should be filtered)
        # Row 4: Missing voucher_id (should be filtered)
        data = [
            (301, "save10", 10.0, "2026-12-31", "2026-01-01 10:00:00"),
            (301, " SAVE10 ", 10.0, "2026-12-31", "2026-01-02 10:00:00"),
            (302, "bad_deal", -5.0, "2026-12-31", "2026-01-02 10:00:00"),
            (None, "no_id", 15.0, "2026-12-31", "2026-01-02 10:00:00")
        ]
        df_input = spark.createDataFrame(data, self.input_schema)

        # 2. DEFINE EXPECTED OUTPUT
        # Same columns as the input, except load_dt is now a timestamp.
        expected_schema = StructType([
            StructField("voucher_id", LongType(), True),
            StructField("voucher_code", StringType(), True),
            StructField("discount_amount", DoubleType(), True),
            StructField("expiry_date", StringType(), True),
            StructField("load_dt", TimestampType(), True)
        ])

        # Localize to UTC to prevent TypeError in Databricks Connect
        expected_ts = pd.Timestamp("2026-01-02 10:00:00").tz_localize('UTC')

        expected_data = [
            (301, "SAVE10", 10.0, "2026-12-31", expected_ts)
        ]
        df_expected = spark.createDataFrame(expected_data, expected_schema)

        # 3. RUN TRANSFORMATION
        df_actual = transform_vouchers_silver(df_input)

        # 4. ASSERTIONS
        assertSchemaEqual(df_actual.schema, expected_schema)
        assertDataFrameEqual(df_actual, df_expected)

# --- 3. EXECUTION AND REPORT GENERATION ---
# Collect the runner's verbose output in memory so it can be embedded in the
# report below instead of being written straight to stderr.
stream = io.StringIO()
suite = unittest.TestLoader().loadTestsFromTestCase(TestVouchersSilver)
runner = unittest.TextTestRunner(stream=stream, verbosity=2)
result = runner.run(suite)

# Assemble the report line by line; the leading and trailing empty entries
# produce the blank line before and the newline after the framed block.
if result.wasSuccessful():
    _status_label = 'SUCCESS'
else:
    _status_label = 'FAILED'

_generated_at = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
_heavy_rule = "========================================="
_light_rule = "-----------------------------------------"

_report_lines = [
    "",
    _heavy_rule,
    f"UNIT TEST REPORT - {_generated_at}",
    _heavy_rule,
    f"Status: {_status_label}",
    f"Tests Run: {result.testsRun}",
    f"Failures: {len(result.failures)}",
    f"Errors: {len(result.errors)}",
    _light_rule,
    "Details:",
    stream.getvalue(),
    _heavy_rule,
    "",
]
print("\n".join(_report_lines))