In [None]:
# =============================================================================
# Data Quality Profiling Notebook
# =============================================================================
# This notebook profiles silver and gold tables for data quality metrics:
# - Row counts
# - Null counts per column
# - Distinct value counts per column (NULL counts as one distinct value)
# - Min/Max values for numeric columns
# - Null percentages per column (the complement of data completeness)
#
# Results are appended to: {catalog}.data_quality.table_profiling

import argparse
from datetime import datetime, timezone

from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import StructType, StructField, StringType, LongType, DoubleType, TimestampType

def profile_table(spark, catalog_name, schema_name, table_name):
    """Profile every column of a table and return its quality metrics.

    The row count and all per-column statistics are computed in a single
    aggregation over the table, instead of two or three Spark jobs (each a
    full scan) per column.

    Args:
        spark: SparkSession used to read the table.
        catalog_name: Catalog that contains the table.
        schema_name: Schema that contains the table.
        table_name: Name of the table to profile.

    Returns:
        A list with one dict per column, with keys ``catalog_name``,
        ``schema_name``, ``table_name``, ``column_name``, ``data_type``,
        ``total_rows``, ``null_count``, ``null_percentage`` (float in 0-100,
        rounded to 2 decimals), ``distinct_count`` (NULL counts as one
        distinct value), ``min_value`` / ``max_value`` (stringified, numeric
        columns only, otherwise None), ``profiling_timestamp`` (timezone-aware
        UTC, shared by all columns of this table) and ``profiling_run_id``.
        Any error while reading or profiling the table is printed and an
        empty list is returned instead of raising.
    """
    # Spark simple-type names (as reported by DataFrame.dtypes) that get
    # min/max statistics. Decimals are reported with precision/scale, e.g.
    # "decimal(10,2)", so they are matched by prefix.
    numeric_types = {'tinyint', 'smallint', 'int', 'bigint', 'float', 'double'}

    def is_numeric(type_name):
        return type_name in numeric_types or type_name.startswith('decimal')

    def quoted(name):
        # Backtick-quote so names containing dots/spaces resolve to a single
        # top-level column instead of a nested field path.
        return "`" + name.replace("`", "``") + "`"

    full_table_name = f"{catalog_name}.{schema_name}.{table_name}"
    print(f"\nProfiling: {full_table_name}")

    # Taken once so every column row of this table carries the same
    # timestamp and run id (and the two never disagree with each other).
    profiled_at = datetime.now(timezone.utc)
    run_id = profiled_at.strftime('%Y%m%d_%H%M%S')

    try:
        df = spark.table(full_table_name)
        column_types = df.dtypes  # [(column_name, type_name), ...]

        # Build one aggregation covering the row count and every column stat.
        aggregations = [F.count(F.lit(1)).alias('row_count')]
        for idx, (col_name, col_type) in enumerate(column_types):
            col = F.col(quoted(col_name))
            aggregations.append(F.count(F.when(col.isNull(), 1)).alias(f'nulls_{idx}'))
            # countDistinct ignores NULLs; the NULL "value" is added back below
            # so results match select(col).distinct().count().
            aggregations.append(F.countDistinct(col).alias(f'distinct_{idx}'))
            if is_numeric(col_type):
                aggregations.append(F.min(col).alias(f'min_{idx}'))
                aggregations.append(F.max(col).alias(f'max_{idx}'))

        stats = df.agg(*aggregations).collect()[0]

        row_count = stats['row_count']
        print(f"  Rows: {row_count:,}")
        print(f"  Columns: {len(column_types)}")

        metrics = []
        for idx, (col_name, col_type) in enumerate(column_types):
            null_count = stats[f'nulls_{idx}']
            # Always a float so every row agrees with a DoubleType column.
            null_pct = (null_count / row_count * 100) if row_count > 0 else 0.0
            distinct_count = stats[f'distinct_{idx}'] + (1 if null_count > 0 else 0)

            min_val = None
            max_val = None
            if is_numeric(col_type):
                lowest = stats[f'min_{idx}']
                highest = stats[f'max_{idx}']
                min_val = str(lowest) if lowest is not None else None
                max_val = str(highest) if highest is not None else None

            metrics.append({
                'catalog_name': catalog_name,
                'schema_name': schema_name,
                'table_name': table_name,
                'column_name': col_name,
                'data_type': col_type,
                'total_rows': row_count,
                'null_count': null_count,
                'null_percentage': round(null_pct, 2),
                'distinct_count': distinct_count,
                'min_value': min_val,
                'max_value': max_val,
                'profiling_timestamp': profiled_at,
                'profiling_run_id': run_id,
            })

        print(f"  ✓ Profiled {len(metrics)} columns")
        return metrics

    except Exception as e:
        # Deliberate best-effort: one unreadable table must not abort the run.
        print(f"  ✗ Error profiling table: {str(e)}")
        return []

def main():
    """Profile the configured silver/gold tables and append the results.

    ``catalog_name`` is read from a Databricks widget when available,
    otherwise from the ``--catalog_name`` command-line argument. All column
    metrics returned by ``profile_table`` are appended to
    ``{catalog_name}.data_quality.table_profiling`` (the schema is created if
    missing) under a single run id, and a per-table summary is shown.

    Raises:
        ValueError: if ``catalog_name`` resolves to an empty value.
    """
    spark = SparkSession.getActiveSession() or SparkSession.builder.getOrCreate()

    # Prefer the Databricks widget; if dbutils is unavailable or the widget is
    # not defined, fall back to command-line arguments.
    try:
        from pyspark.dbutils import DBUtils
        dbutils = DBUtils(spark)
        catalog_name = dbutils.widgets.get("catalog_name")
    except Exception:
        parser = argparse.ArgumentParser()
        parser.add_argument("--catalog_name", type=str, required=True)
        args, _ = parser.parse_known_args()
        catalog_name = args.catalog_name

    # An empty widget value would make every table name invalid and the run
    # would end "successfully" with no metrics; fail loudly instead.
    if not catalog_name:
        raise ValueError("catalog_name must be provided and non-empty")

    # One run id shared by every metric row written by this invocation, so a
    # whole run can be selected/grouped by profiling_run_id.
    run_id = datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S')

    print("=" * 70)
    print("DATA QUALITY PROFILING")
    print("=" * 70)
    print(f"Catalog: {catalog_name}")
    print("Target Schema: data_quality")
    print("=" * 70)

    # (schema, table) pairs to profile within catalog_name
    tables_to_profile = [
        # Silver tables
        ('synthea', 'claims_silver'),
        ('synthea', 'medications_silver'),
        ('synthea', 'patient_encounters_silver'),
        ('synthea', 'procedures_silver'),
        # Gold tables
        ('synthea', 'member_monthly_snapshot_gold'),
    ]

    all_metrics = []
    profiled_tables = 0
    for schema_name, table_name in tables_to_profile:
        metrics = profile_table(spark, catalog_name, schema_name, table_name)
        if metrics:  # profile_table returns [] when the table failed
            profiled_tables += 1
        all_metrics.extend(metrics)

    if not all_metrics:
        print("\n⚠️  No metrics collected")
        return

    # Explicit schema: inference fails when a field is None in every row
    # (e.g. no numeric columns -> min/max all None) or mixes int and float.
    metrics_schema = StructType([
        StructField("catalog_name", StringType(), True),
        StructField("schema_name", StringType(), True),
        StructField("table_name", StringType(), True),
        StructField("column_name", StringType(), True),
        StructField("data_type", StringType(), True),
        StructField("total_rows", LongType(), True),
        StructField("null_count", LongType(), True),
        StructField("null_percentage", DoubleType(), True),
        StructField("distinct_count", LongType(), True),
        StructField("min_value", StringType(), True),
        StructField("max_value", StringType(), True),
        StructField("profiling_timestamp", TimestampType(), True),
        StructField("profiling_run_id", StringType(), True),
    ])

    # null_percentage may be an int 0 for empty tables; DoubleType only
    # accepts floats, so coerce it. Stamp the shared run id on every row.
    rows = [
        {**metric,
         "null_percentage": float(metric["null_percentage"]),
         "profiling_run_id": run_id}
        for metric in all_metrics
    ]
    metrics_df = spark.createDataFrame(rows, schema=metrics_schema)

    # Create data_quality schema if it doesn't exist
    spark.sql(f"CREATE SCHEMA IF NOT EXISTS {catalog_name}.data_quality")
    print(f"\n✓ Schema created/verified: {catalog_name}.data_quality")

    # Append (never overwrite) so the table accumulates history across runs.
    target_table = f"{catalog_name}.data_quality.table_profiling"

    metrics_df.write \
        .mode("append") \
        .format("delta") \
        .option("mergeSchema", "true") \
        .saveAsTable(target_table)

    print(f"\n{'=' * 70}")
    print("✓ Data quality profiling complete!")
    print(f"  Tables profiled: {profiled_tables} of {len(tables_to_profile)}")
    print(f"  Metrics collected: {len(all_metrics):,}")
    print(f"  Results stored in: {target_table}")
    print(f"  Run id: {run_id}")
    print(f"{'=' * 70}\n")

    # Per-table summary; total_rows is identical for every column of a table,
    # so first() is sufficient.
    print("Summary by table:")
    summary = metrics_df.groupBy('schema_name', 'table_name') \
        .agg(
            F.first('total_rows').alias('row_count'),
            F.count('column_name').alias('column_count'),
            F.avg('null_percentage').alias('avg_null_pct')
        ) \
        .orderBy('schema_name', 'table_name')

    summary.show(truncate=False)

if __name__ == "__main__":
    # Run the profiling job when this file/notebook is executed directly
    # rather than imported; catalog_name resolution happens inside main().
    main()

