In [0]:
# Notebook parameter selecting which metadata-generation log to validate.
# The first choice ("pi") is the widget's default value.
_MODE_CHOICES = ["pi", "comment"]
dbutils.widgets.dropdown("mode", _MODE_CHOICES[0], _MODE_CHOICES)
mode = dbutils.widgets.get("mode")

In [0]:
from pyspark.sql.functions import col, row_number
from pyspark.sql.window import Window

def filter_to_most_recent(df):
    """Keep only the newest row for each ``(table, column_name)`` pair.

    Rows are ranked inside each ``(table, column_name)`` group by
    ``_created_at`` in descending order. The rank-1 row is kept, and the
    temporary ``row_num`` column is dropped before returning.

    Args:
        df: Spark DataFrame with ``table``, ``column_name`` and
            ``_created_at`` columns.

    Returns:
        A DataFrame with the same columns as ``df`` and at most one row per
        ``(table, column_name)`` group.
    """
    # NOTE(review): row_number() breaks ties on _created_at arbitrarily, so
    # two rows with the same timestamp may yield either one.
    newest_first = (
        Window.partitionBy("table", "column_name")
        .orderBy(col("_created_at").desc())
    )
    ranked = df.withColumn("row_num", row_number().over(newest_first))
    latest_only = ranked.where(col("row_num") == 1)
    return latest_only.drop("row_num")

def run_pi_tests():
    """Validate the most recent rows of the PI metadata-generation log.

    Reads ``dbxmetagen.metadata_results.pi_metadata_generation_log`` and keeps
    the newest row per ``(table, column_name)``. It displays the result, then
    asserts, in order:

    * exactly 3 distinct tables are present;
    * ``ddl_type`` contains both ``"column"`` and ``"table"``;
    * column-level ``classification`` values are only ``"None"`` or ``"pi"``;
    * table-level ``classification`` values are one of ``"pii"``, ``"pci"``,
      ``"phi"``, ``"medical_information"`` or ``"None"``.

    Finally it runs the manual-override check for mode ``"pi"``.

    Raises:
        AssertionError: On the first check that fails.
    """
    latest = filter_to_most_recent(
        spark.table("dbxmetagen.metadata_results.pi_metadata_generation_log")
    )
    display(latest)

    n_tables = latest.select("table").distinct().count()
    assert n_tables == 3, f"Expected 3 distinct tables, but found {n_tables}"

    seen_ddl_types = [r["ddl_type"] for r in latest.select("ddl_type").distinct().collect()]
    assert "column" in seen_ddl_types and "table" in seen_ddl_types, (
        "ddl_type column does not contain both 'column' and 'table' values"
    )

    def classifications_for(ddl_type):
        # Distinct classification values among rows with the given ddl_type.
        subset = latest.filter(col("ddl_type") == ddl_type)
        return [r["classification"] for r in subset.select("classification").distinct().collect()]

    # NOTE(review): "None" is compared as a string. If the log stores SQL
    # NULL, the value arrives as Python None and fails this check. Confirm.
    per_column = classifications_for("column")
    assert all(c in ["None", "pi"] for c in per_column), (
        f"Invalid classification values for columns: {per_column}"
    )

    per_table = classifications_for("table")
    allowed_for_tables = ["pii", "pci", "phi", "medical_information", "None"]
    assert all(c in allowed_for_tables for c in per_table), (
        f"Invalid classification values for tables: {per_table}"
    )

    test_manual_override("pi", latest)

def run_comment_tests():
    """Validate the most recent rows of the comment metadata-generation log.

    Reads ``dbxmetagen.metadata_results.comment_metadata_generation_log`` and
    keeps the newest row per ``(table, column_name)``. It displays the result,
    then asserts that exactly 3 distinct tables are present and that
    ``ddl_type`` contains both ``"column"`` and ``"table"``.

    Finally it runs the manual-override check for mode ``"comment"``.

    Raises:
        AssertionError: On the first check that fails.
    """
    source = spark.table("dbxmetagen.metadata_results.comment_metadata_generation_log")
    latest = filter_to_most_recent(source)
    display(latest)

    n_tables = latest.select("table").distinct().count()
    assert n_tables == 3, f"Expected 3 distinct tables, but found {n_tables}"

    seen_ddl_types = [r["ddl_type"] for r in latest.select("ddl_type").distinct().collect()]
    has_both_levels = "column" in seen_ddl_types and "table" in seen_ddl_types
    assert has_both_levels, "ddl_type column does not contain both 'column' and 'table' values"

    test_manual_override("comment", latest)

def test_manual_override(mode, df, allow_manual_override=True):
    """Assert that the manual override for ``healthcare_test.room_number`` was applied.

    Finds the ``room_number`` row of ``dbxmetagen.default.healthcare_test`` in
    ``df`` and checks the values expected from the manual-override fixture:

    * ``mode == "comment"``: ``column_content`` must equal
      ``"Hospital room number - TEST OVERRIDE"``.
    * ``mode == "pi"``: ``classification`` must be ``"pi"`` and ``type`` must
      be ``"pii"``.

    Args:
        mode: Which log ``df`` came from, either ``"comment"`` or ``"pi"``.
        df: Log DataFrame containing ``table`` and ``column_name`` columns.
            Callers in this notebook pass it already reduced by
            ``filter_to_most_recent``.
        allow_manual_override: When False the check is skipped. Defaults to
            True, which matches the previously hard-coded behaviour.

    Raises:
        ValueError: If ``mode`` is neither ``"comment"`` nor ``"pi"``.
        AssertionError: If the room_number row is missing or its values do
            not match the expected override.
    """
    if not allow_manual_override:
        return

    if mode not in ("comment", "pi"):
        raise ValueError("Invalid mode provided")

    target_table = "dbxmetagen.default.healthcare_test"
    # Both logs are keyed by "column_name" (filter_to_most_recent partitions
    # on it). The PI branch used to filter on "column", which nothing else
    # in this notebook references.
    matches = (
        df.filter((col("table") == target_table) & (col("column_name") == "room_number"))
        .limit(1)
        .collect()
    )
    # Fail with a descriptive message instead of an IndexError from [0].
    assert matches, f"No {mode} metadata row found for column 'room_number' in {target_table}"
    room_number_metadata = matches[0]

    if mode == "comment":
        room_number_comment = room_number_metadata["column_content"]
        assert room_number_comment == "Hospital room number - TEST OVERRIDE", (
            f"Expected comment 'Hospital room number - TEST OVERRIDE', but found '{room_number_comment}'"
        )
    else:
        assert room_number_metadata["classification"] == "pi", (
            f"Expected classification 'pi', but found '{room_number_metadata['classification']}'"
        )
        assert room_number_metadata["type"] == "pii", (
            f"Expected type 'pii', but found '{room_number_metadata['type']}'"
        )

# Dispatch to the test suite selected by the "mode" widget.
_TEST_SUITES = {"pi": run_pi_tests, "comment": run_comment_tests}
if mode not in _TEST_SUITES:
    # Without this guard an unrecognized mode would run no tests and still
    # print the success message below.
    raise ValueError(f"Unsupported mode {mode!r}; expected one of {sorted(_TEST_SUITES)}")
_TEST_SUITES[mode]()

print("All tests passed!")