In [0]:
import unittest
import pandas as pd
import datetime
import io
from pyspark.sql import functions as F
from pyspark.sql.window import Window
from pyspark.sql.types import *
from pyspark.testing import assertDataFrameEqual, assertSchemaEqual

# --- 1. THE TRANSFORMATION FUNCTION ---
def transform_transaction_items_silver(df_bronze):
    """
    Build the silver transaction-items DataFrame from a bronze DataFrame.

    Steps, in order:
      1. Rename every column to snake_case (lower-cased, spaces replaced
         by underscores).
      2. Apply the quality gates: transaction_id and item_id must be
         non-null, quantity and price must be > 0. Rows where quantity or
         price is null are dropped too, because the comparison does not
         evaluate to true for them.
      3. Parse load_dt into a timestamp.
      4. Deduplicate on (transaction_id, item_id), keeping the row with the
         latest load_dt. If several rows share the same latest load_dt,
         the ordering does not define which of them is kept.
      5. Add total_amount = round(quantity * price, 2).

    Args:
        df_bronze: Spark DataFrame whose column names, once standardized,
            include transaction_id, item_id, quantity, price and load_dt.

    Returns:
        Spark DataFrame with the columns transaction_id, item_id, quantity,
        price, total_amount and load_dt (as a timestamp), in that order.
    """
    # Standardize column headers to snake_case
    snake_case_names = [name.lower().replace(" ", "_") for name in df_bronze.columns]
    df_renamed = df_bronze.toDF(*snake_case_names)

    # Quality Gates: transaction_id/item_id cannot be null; quantity/price must be > 0
    passes_quality_gates = (
        F.col("transaction_id").isNotNull()
        & F.col("item_id").isNotNull()
        & (F.col("quantity") > 0)
        & (F.col("price") > 0)
    )

    # Parse load_dt BEFORE deduplicating. Ordering the raw string column would
    # compare values lexicographically, so e.g. "2026-01-02 9:00:00" would be
    # treated as later than "2026-01-02 10:00:00".
    df_clean = (
        df_renamed
        .filter(passes_quality_gates)
        .withColumn("load_dt", F.to_timestamp(F.col("load_dt")))
    )

    # Deduplication: keep the latest record per (transaction_id, item_id).
    # Rows whose load_dt is null after parsing rank last, so they only survive
    # when the group has no row with a usable load_dt.
    latest_first = (
        Window
        .partitionBy("transaction_id", "item_id")
        .orderBy(F.col("load_dt").desc_nulls_last())
    )

    df_silver_final = (
        df_clean
        .withColumn("row_rank", F.row_number().over(latest_first))
        .filter(F.col("row_rank") == 1)
        .drop("row_rank")
        .withColumn("total_amount", F.round(F.col("quantity") * F.col("price"), 2))
        .select("transaction_id", "item_id", "quantity", "price", "total_amount", "load_dt")
    )

    return df_silver_final

# --- 2. THE TEST SUITE ---
class TestTransactionItemsSilver(unittest.TestCase):
    """Unit tests for transform_transaction_items_silver.

    Relies on a `spark` session being available as a global (as it is in a
    notebook environment).
    """

    def setUp(self):
        """Setup input schema for mock data and pin the session time zone."""
        self.input_schema = StructType([
            StructField("transaction_id", LongType(), True),
            StructField("item_id", LongType(), True),
            StructField("quantity", IntegerType(), True),
            StructField("price", DoubleType(), True),
            StructField("load_dt", StringType(), True)
        ])

        # The transformation parses load_dt strings in the session time zone,
        # while the expected rows below use UTC-localized timestamps. Pin the
        # session to UTC for the duration of each test so the comparison does
        # not depend on the environment, then restore the previous setting.
        tz_conf_key = "spark.sql.session.timeZone"
        previous_tz = spark.conf.get(tz_conf_key)
        spark.conf.set(tz_conf_key, "UTC")
        self.addCleanup(spark.conf.set, tz_conf_key, previous_tz)

    def test_calculation_and_deduplication(self):
        """Invalid rows are dropped, duplicates collapse to the newest row,
        and total_amount is quantity * price."""
        # 1. MOCK INPUT DATA
        # Row 1 & 2: Duplicates for same transaction item (Row 2 is newer)
        # Row 3: Invalid quantity (should be filtered)
        # Row 4: Null transaction_id (should be filtered)
        data = [
            (1001, 55, 2, 10.0, "2026-01-01 10:00:00"),
            (1001, 55, 3, 10.0, "2026-01-02 10:00:00"),
            (1002, 56, -1, 5.0, "2026-01-02 10:00:00"),
            (None, 57, 1, 20.0, "2026-01-02 10:00:00")
        ]
        df_input = spark.createDataFrame(data, self.input_schema)

        # 2. DEFINE EXPECTED OUTPUT
        expected_schema = StructType([
            StructField("transaction_id", LongType(), True),
            StructField("item_id", LongType(), True),
            StructField("quantity", IntegerType(), True),
            StructField("price", DoubleType(), True),
            StructField("total_amount", DoubleType(), True),
            StructField("load_dt", TimestampType(), True)
        ])

        # Localize to UTC to prevent TypeError; this also matches the
        # UTC session time zone pinned in setUp.
        expected_ts = pd.Timestamp("2026-01-02 10:00:00").tz_localize('UTC')

        # Only the newer duplicate of (1001, 55) survives: 3 * 10.0 = 30.0
        expected_data = [
            (1001, 55, 3, 10.0, 30.0, expected_ts)
        ]
        df_expected = spark.createDataFrame(expected_data, expected_schema)

        # 3. RUN TRANSFORMATION
        df_actual = transform_transaction_items_silver(df_input)

        # 4. ASSERTIONS
        assertSchemaEqual(df_actual.schema, expected_schema)
        assertDataFrameEqual(df_actual, df_expected)

# --- 3. EXECUTION AND REPORT GENERATION ---
# Load the test case into a suite, then run it with a runner that writes its
# verbose log into an in-memory buffer so the log can be embedded in the report.
suite = unittest.TestLoader().loadTestsFromTestCase(TestTransactionItemsSilver)
stream = io.StringIO()
runner = unittest.TextTestRunner(stream=stream, verbosity=2)
result = runner.run(suite)

# Gather the report fields up front so the template below stays readable.
report_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
if result.wasSuccessful():
    status_label = 'SUCCESS'
else:
    status_label = 'FAILED'
tests_run = result.testsRun
failure_count = len(result.failures)
error_count = len(result.errors)
runner_log = stream.getvalue()

print(f"""
=========================================
UNIT TEST REPORT - {report_time}
=========================================
Status: {status_label}
Tests Run: {tests_run}
Failures: {failure_count}
Errors: {error_count}
-----------------------------------------
Details:
{runner_log}
=========================================
""")