# OMOP ETL Data Quality Analysis and Cleaning Pipeline

## Overview

This notebook demonstrates a comprehensive data quality analysis and cleaning pipeline for OMOP Common Data Model (CDM) tables. The pipeline processes raw OMOP data, identifies quality issues, and produces a clean dataset suitable for research and analysis.

**Key Features:**
- Centralized schema definitions (primary keys, foreign keys, column types)
- Comprehensive data quality metrics calculation
- Statistical anomaly detection
- Foreign key validation
- Date consistency checks
- Duplicate detection and removal
- Detailed quality reporting

**OMOP Tables Processed:**
- Person
- Concept
- Condition Occurrence
- Drug Exposure
- Procedure Occurrence
- Visit Occurrence
- Measurement
- Observation Period

---

## 1. Import Necessary Libraries

We begin by importing all required PySpark and Python libraries for data processing and analysis.

In [None]:
from pyspark.sql import SparkSession, DataFrame
from pyspark.sql.functions import (
    avg, col, min as spark_min, max as spark_max, count, when, to_date, to_timestamp, year, month,
    concat_ws, lit, lag, collect_list, first, last, explode, sum as spark_sum,
    date_add, datediff, coalesce, floor
)
from pyspark.sql.window import Window
from typing import Dict
import logging
from datetime import datetime
import builtins  # To use built-in functions like round
from functools import reduce

## 2. Define SchemaManager Class

The `SchemaManager` class contains the schema definitions for all OMOP CDM tables, including:
- Primary keys
- Foreign key relationships
- Date/timestamp columns
- Data type specifications (integer, float, string)

This provides a single source of truth for table structure and enables automated validation.

In [None]:
class SchemaManager:
    @staticmethod 
    def get_schema_info() -> dict:
        """Returns schema definitions for OMOP CDM tables."""
        return {
            "person": {
                "key_cols": ["PERSON_ID"],
                "foreign_keys": {},
                "date_cols": ["DEATH_DATETIME"],
                "timestamp_cols": ["DT_CREATED", "DT_MODIFIED"],
                "int_cols": [
                    "PERSON_ID", "YEAR_OF_BIRTH", "MONTH_OF_BIRTH", "DAY_OF_BIRTH",
                    "RACE_CONCEPT_ID", "ETHNICITY_CONCEPT_ID", "LOCATION_ID", 
                    "PROVIDER_ID", "CARE_SITE_ID"
                ],
                "float_cols": [],
                "string_cols": [
                    "GENDER_CONCEPT_ID", "PERSON_SOURCE_VALUE", "GENDER_SOURCE_VALUE",
                    "RACE_SOURCE_VALUE", "ETHNICITY_SOURCE_VALUE", "CREATED_BY",
                    "MODIFIED_BY", "HASH_COLUMN"
                ]
            },
            "concept": {
                "key_cols": ["CONCEPT_ID"],
                "foreign_keys": {},
                "date_cols": ["VALID_START_DATE", "VALID_END_DATE"],
                "timestamp_cols": ["DT_CREATED", "DT_MODIFIED"],
                "int_cols": ["CONCEPT_ID"],
                "float_cols": [],
                "string_cols": [
                    "CONCEPT_NAME", "DOMAIN_ID", "VOCABULARY_ID", "CONCEPT_CLASS_ID",
                    "STANDARD_CONCEPT", "CONCEPT_CODE", "INVALID_REASON",
                    "CREATED_BY", "MODIFIED_BY", "HASH_COLUMN"
                ]
            },
            "condition_occurrence": {
                "key_cols": ["CONDITION_OCCURRENCE_ID"],
                "foreign_keys": {
                    "PERSON_ID": "person",
                    "VISIT_OCCURRENCE_ID": "visit_occurrence"
                },
                "date_cols": ["CONDITION_START_DATE", "CONDITION_END_DATE"],
                "timestamp_cols": ["CONDITION_START_DATETIME", "CONDITION_END_DATETIME", "DT_CREATED", "DT_MODIFIED"],
                "int_cols": [
                    "CONDITION_OCCURRENCE_ID", "PERSON_ID", "VISIT_OCCURRENCE_ID",
                    "CONDITION_TYPE_CONCEPT_ID", "PROVIDER_ID",
                    "VISIT_DETAIL_ID", "CONDITION_STATUS_CONCEPT_ID"
                ],
                "float_cols": [],
                "string_cols": [
                    "CONDITION_CONCEPT_ID", "STOP_REASON", "CONDITION_SOURCE_CONCEPT_ID",
                    "CONDITION_SOURCE_VALUE", "CONDITION_STATUS_SOURCE_VALUE",
                    "CREATED_BY", "MODIFIED_BY", "HASH_COLUMN"
                ]
            },
            "drug_exposure": {
                "key_cols": ["DRUG_EXPOSURE_ID"],
                "foreign_keys": {
                    "PERSON_ID": "person",
                    "VISIT_OCCURRENCE_ID": "visit_occurrence"
                },
                "date_cols": [
                    "DRUG_EXPOSURE_START_DATE", "DRUG_EXPOSURE_END_DATE", "VERBATIM_END_DATE"
                ],
                "timestamp_cols": [
                    "DRUG_EXPOSURE_START_DATETIME", "DRUG_EXPOSURE_END_DATETIME",
                    "DT_CREATED", "DT_MODIFIED"
                ],
                "int_cols": [
                    "DRUG_EXPOSURE_ID", "PERSON_ID", "VISIT_OCCURRENCE_ID",
                    "DRUG_TYPE_CONCEPT_ID", "PROVIDER_ID", "VISIT_DETAIL_ID",
                    "REFILLS", "DAYS_SUPPLY", "ROUTE_CONCEPT_ID"
                ],
                "float_cols": [
                    "QUANTITY"
                ],
                "string_cols": [
                    "DRUG_CONCEPT_ID", "STOP_REASON", "LOT_NUMBER",
                    "SIG", "ROUTE_SOURCE_VALUE", "DOSE_UNIT_SOURCE_VALUE",
                    "CREATED_BY", "MODIFIED_BY", "HASH_COLUMN"
                ]
            },
            "procedure_occurrence": {
                "key_cols": ["PROCEDURE_OCCURRENCE_ID"],
                "foreign_keys": {
                    "PERSON_ID": "person",
                    "VISIT_OCCURRENCE_ID": "visit_occurrence"
                },
                "date_cols": ["PROCEDURE_DATE"],
                "timestamp_cols": ["PROCEDURE_DATETIME", "DT_CREATED", "DT_MODIFIED"],
                "int_cols": [
                    "PROCEDURE_OCCURRENCE_ID", "PERSON_ID", "VISIT_OCCURRENCE_ID",
                    "PROCEDURE_TYPE_CONCEPT_ID", "MODIFIER_CONCEPT_ID",
                    "QUANTITY", "PROVIDER_ID", "VISIT_DETAIL_ID"
                ],
                "float_cols": [],
                "string_cols": [
                    "PROCEDURE_CONCEPT_ID", "PROCEDURE_SOURCE_CONCEPT_ID",
                    "PROCEDURE_SOURCE_VALUE", "MODIFIER_SOURCE_VALUE",
                    "CREATED_BY", "MODIFIED_BY", "HASH_COLUMN"
                ]
            },
            "visit_occurrence": {
                "key_cols": ["VISIT_OCCURRENCE_ID"],
                "foreign_keys": {
                    "PERSON_ID": "person"
                },
                "date_cols": ["VISIT_START_DATE", "VISIT_END_DATE"],
                "timestamp_cols": ["VISIT_START_DATETIME", "VISIT_END_DATETIME", "DT_CREATED", "DT_MODIFIED"],
                "int_cols": [
                    "VISIT_OCCURRENCE_ID", "PERSON_ID", "VISIT_TYPE_CONCEPT_ID",
                    "PROVIDER_ID", "CARE_SITE_ID", "ADMITTED_FROM_CONCEPT_ID",
                    "DISCHARGED_TO_CONCEPT_ID", "PRECEDING_VISIT_OCCURRENCE_ID"
                ],
                "float_cols": [],
                "string_cols": [
                    "VISIT_CONCEPT_ID", "VISIT_SOURCE_VALUE", "VISIT_SOURCE_CONCEPT_ID",
                    "ADMITTED_FROM_SOURCE_VALUE", "DISCHARGED_TO_SOURCE_VALUE",
                    "CREATED_BY", "MODIFIED_BY", "HASH_COLUMN"
                ]
            },
            "measurement": {
                "key_cols": ["MEASUREMENT_ID"],
                "foreign_keys": {
                    "PERSON_ID": "person",
                    "VISIT_OCCURRENCE_ID": "visit_occurrence"
                },
                "date_cols": ["MEASUREMENT_DATE"],
                "timestamp_cols": ["MEASUREMENT_DATETIME", "DT_CREATED", "DT_MODIFIED"],
                "int_cols": [
                    "MEASUREMENT_ID", "PERSON_ID", "VISIT_OCCURRENCE_ID",
                    "MEASUREMENT_TYPE_CONCEPT_ID", "OPERATOR_CONCEPT_ID",
                    "VALUE_AS_CONCEPT_ID", "UNIT_CONCEPT_ID", "PROVIDER_ID",
                    "VISIT_DETAIL_ID", "MEASUREMENT_EVENT_ID"
                ],
                "float_cols": [
                    "VALUE_AS_NUMBER", "RANGE_LOW", "RANGE_HIGH"
                ],
                "string_cols": [
                    "MEASUREMENT_CONCEPT_ID", "MEASUREMENT_SOURCE_VALUE",
                    "MEASUREMENT_SOURCE_CONCEPT_ID", "UNIT_SOURCE_VALUE",
                    "VALUE_SOURCE_VALUE", "MEAS_EVENT_FIELD_CONCEPT_ID",
                    "CREATED_BY", "MODIFIED_BY", "HASH_COLUMN"
                ]
            },
            "observation_period": {
                "key_cols": ["OBSERVATION_PERIOD_ID"],
                "foreign_keys": {
                    "PERSON_ID": "person"
                },
                "date_cols": ["OBSERVATION_PERIOD_START_DATE", "OBSERVATION_PERIOD_END_DATE"],
                "timestamp_cols": ["DT_CREATED", "DT_MODIFIED"],
                "int_cols": [
                    "OBSERVATION_PERIOD_ID", "PERSON_ID", "PERIOD_TYPE_CONCEPT_ID"
                ],
                "float_cols": [],
                "string_cols": [
                    "CREATED_BY", "MODIFIED_BY", "HASH_COLUMN"
                ]
            }
        }
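Because `get_schema_info()` is maintained by hand, a column can easily end up in the wrong type list. The minimal sanity check sketched below assumes only the dictionary keys shown above; it flags columns listed under more than one type, key columns not declared as integers, and foreign keys that point at tables missing from the schema. `check_schema_consistency` is a hypothetical helper, not part of the pipeline, and the toy input deliberately declares its key as a string.

```python
from typing import Dict, List

TYPE_LISTS = ("date_cols", "timestamp_cols", "int_cols", "float_cols", "string_cols")


def check_schema_consistency(schema_info: Dict[str, dict]) -> List[str]:
    """Return human-readable problems found in a SchemaManager-style dict."""
    problems = []
    for table, spec in schema_info.items():
        # Record every type list each column appears in
        seen: Dict[str, List[str]] = {}
        for type_list in TYPE_LISTS:
            for column in spec.get(type_list, []):
                seen.setdefault(column, []).append(type_list)
        for column, lists in seen.items():
            if len(lists) > 1:
                problems.append(f"{table}.{column} listed in {', '.join(lists)}")
        # OMOP surrogate keys are integers
        for key in spec.get("key_cols", []):
            if key not in spec.get("int_cols", []):
                problems.append(f"{table}.{key} is a key but not in int_cols")
        # Every foreign key must point at a table defined in the schema
        for fk_col, ref_table in spec.get("foreign_keys", {}).items():
            if ref_table not in schema_info:
                problems.append(f"{table}.{fk_col} references unknown table {ref_table}")
    return problems


example = {
    "concept": {
        "key_cols": ["CONCEPT_ID"],
        "foreign_keys": {},
        "int_cols": ["CONCEPT_CLASS_ID"],
        "string_cols": ["CONCEPT_ID", "CONCEPT_NAME"],
    }
}
for problem in check_schema_consistency(example):
    print(problem)  # concept.CONCEPT_ID is a key but not in int_cols
```

Run against `SchemaManager.get_schema_info()`, an empty list means none of these three kinds of problems were found.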

## 3. Define OMOPDataCleaner Class

The `OMOPDataCleaner` class is the core of our data quality pipeline. Its methods cover four areas:

### Key Functionality:

1. **Data Quality Metrics**: Calculate per-column quality statistics
   - Null rates
   - Uniqueness (distinct value counts)
   - Completeness
   - Basic statistics for numeric columns (mean, min, max)

2. **Anomaly Detection**: Identify outliers and data quality issues
   - Statistical outliers (±3 standard deviations)
   - Invalid date ranges (end date before start date)

3. **Validation**: Ensure data integrity
   - Primary key uniqueness
   - Foreign key referential integrity
   - Date consistency checks

4. **Cleaning Operations**: Remove problematic records
   - Duplicate removal
   - Removal of records with null primary keys
   - Removal of orphaned records (foreign keys with no matching reference)
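The per-column arithmetic behind the metrics and the ±3 standard deviation rule is simple. The plain-Python sketch below mirrors it on a small list so the formulas can be checked by hand; it treats `None` as a null and uses the population standard deviation (`statistics.pstdev`). The function names are illustrative only.

```python
import statistics
from typing import List, Optional


def null_rate(values: List[Optional[float]]) -> float:
    """Percentage of entries that are None (0 for an empty column)."""
    if not values:
        return 0.0
    return sum(v is None for v in values) / len(values) * 100


def completeness(values: List[Optional[float]]) -> float:
    """Complement of the null rate, in percent."""
    return 100 - null_rate(values)


def three_sigma_outliers(values: List[Optional[float]]) -> List[float]:
    """Non-null values more than 3 population standard deviations from the mean."""
    present = [v for v in values if v is not None]
    if len(present) < 2:
        return []
    mean = statistics.fmean(present)
    std_dev = statistics.pstdev(present)
    if std_dev == 0:
        return []  # constant column: nothing can be an outlier
    return [v for v in present if abs(v - mean) > 3 * std_dev]


# Toy column: twenty identical values, one extreme value, one null
column = [10.0] * 20 + [1000.0, None]
print(null_rate(column), completeness(column), three_sigma_outliers(column))
```

In the notebook the same quantities are computed with Spark aggregations over whole columns rather than Python lists.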

In [None]:
class OMOPDataCleaner:
    def __init__(self, spark: SparkSession, schema_info: dict):
        """Initialize the data cleaner with Spark session and schema information.
        
        Args:
            spark: Active SparkSession
            schema_info: Dictionary containing schema definitions for all tables
        """
        self.spark = spark
        self.schema_info = schema_info
        self.logger = self._setup_logger()
        
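    def _setup_logger(self) -> logging.Logger:
        """Set up a logger that prints INFO-level messages in the notebook."""
        logger = logging.getLogger('OMOPDataCleaner')
        logger.setLevel(logging.INFO)
        # If no handler is configured anywhere, INFO records fall through to
        # Python's last-resort handler, which only prints WARNING and above
        if not logger.handlers:
            handler = logging.StreamHandler()
            handler.setFormatter(logging.Formatter('%(levelname)s %(name)s: %(message)s'))
            logger.addHandler(handler)
            logger.propagate = False  # avoid duplicate lines if the root logger also has a handler
        return logger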
    
    def calculate_quality_metrics(self, df: DataFrame, table_name: str) -> dict:
        """Calculate comprehensive data quality metrics for a given table.
        
        Metrics calculated:
        - Total record count
        - Null rates per column
        - Unique value counts
        - Statistical measures (for numeric columns)
        - Completeness scores
        
        Args:
            df: Input DataFrame
            table_name: Name of the OMOP table
            
        Returns:
            Dictionary containing quality metrics
        """
        total_rows = df.count()
        metrics = {
            'table_name': table_name,
            'total_rows': total_rows,
            'column_metrics': {}
        }
        
        for column in df.columns:
            # Calculate null rate
            null_count = df.filter(col(column).isNull()).count()
            null_rate = (null_count / total_rows * 100) if total_rows > 0 else 0
            
            # Calculate unique values
            unique_count = df.select(column).distinct().count()
            
            column_metrics = {
                'null_count': null_count,
                'null_rate': builtins.round(null_rate, 2),
                'unique_count': unique_count,
                'completeness': builtins.round((1 - null_rate/100) * 100, 2)
            }
            
            # Add statistics for numeric columns
            if column in self.schema_info[table_name].get('int_cols', []) or \
               column in self.schema_info[table_name].get('float_cols', []):
                stats = df.select(
                    avg(col(column)).alias('mean'),
                    spark_min(col(column)).alias('min'),
                    spark_max(col(column)).alias('max')
                ).collect()[0]
                
                column_metrics.update({
                    'mean': builtins.round(float(stats['mean']), 2) if stats['mean'] is not None else None,
                    'min': stats['min'],
                    'max': stats['max']
                })
            
            metrics['column_metrics'][column] = column_metrics
        
        return metrics
    
    def detect_anomalies(self, df: DataFrame, table_name: str) -> dict:
        """Detect anomalies and data quality issues in the dataset.
        
        Checks performed:
        - Statistical outliers (values beyond 3 standard deviations) in
          numeric, non-identifier columns
        - Invalid date ranges (end date before start date)
        - Future dates (date values after the current date)
        
        Args:
            df: Input DataFrame
            table_name: Name of the OMOP table
            
        Returns:
            Dictionary containing anomaly counts and details
        """
        # Local import keeps this method self-contained
        from pyspark.sql.functions import current_date, stddev_pop
        
        table_schema = self.schema_info[table_name]
        anomalies = {
            'table_name': table_name,
            'statistical_outliers': {},
            'date_inconsistencies': {},
            'future_dates': {}
        }
        
        # Statistical outliers: skip *_ID columns, since the mean and standard
        # deviation of surrogate keys and concept IDs carry no meaning
        numeric_cols = table_schema.get('int_cols', []) + table_schema.get('float_cols', [])
        
        for column in numeric_cols:
            if column not in df.columns or column.endswith('_ID'):
                continue
            # Spark does not allow an aggregate inside another aggregate
            # (e.g. avg of an expression containing avg), so use stddev_pop
            stats = df.select(
                avg(col(column)).alias('mean'),
                stddev_pop(col(column)).alias('std_dev')
            ).collect()[0]
            
            mean_val, std_dev = stats['mean'], stats['std_dev']
            if mean_val is None or not std_dev:
                continue  # all-null or constant column: nothing to flag
            
            # Count outliers (beyond 3 standard deviations)
            outlier_count = df.filter(
                (col(column) < mean_val - 3 * std_dev) |
                (col(column) > mean_val + 3 * std_dev)
            ).count()
            if outlier_count > 0:
                anomalies['statistical_outliers'][column] = outlier_count
        
        # Date inconsistencies: end date before start date
        date_pairs = [
            ('CONDITION_START_DATE', 'CONDITION_END_DATE'),
            ('VISIT_START_DATE', 'VISIT_END_DATE'),
            ('DRUG_EXPOSURE_START_DATE', 'DRUG_EXPOSURE_END_DATE'),
            ('OBSERVATION_PERIOD_START_DATE', 'OBSERVATION_PERIOD_END_DATE')
        ]
        
        for start_col, end_col in date_pairs:
            if start_col in df.columns and end_col in df.columns:
                invalid_count = df.filter(
                    col(end_col).isNotNull() &
                    col(start_col).isNotNull() &
                    (col(end_col) < col(start_col))
                ).count()
                if invalid_count > 0:
                    anomalies['date_inconsistencies'][f'{start_col}_{end_col}'] = invalid_count
        
        # Future dates: values later than today's date
        for date_col in table_schema.get('date_cols', []):
            if date_col in df.columns:
                future_count = df.filter(col(date_col) > current_date()).count()
                if future_count > 0:
                    anomalies['future_dates'][date_col] = future_count
        
        return anomalies
    
    def validate_foreign_keys(self, df: DataFrame, table_name: str, 
                            reference_tables: dict) -> dict:
        """Validate foreign key relationships.
        
        Checks that all foreign key values exist in the referenced tables.
        
        Args:
            df: Input DataFrame
            table_name: Name of the current table
            reference_tables: Dictionary of reference table DataFrames
            
        Returns:
            Dictionary containing validation results and orphaned record counts
        """
        validation_results = {
            'table_name': table_name,
            'foreign_key_violations': {}
        }
        
        foreign_keys = self.schema_info[table_name].get('foreign_keys', {})
        
        for fk_col, ref_table in foreign_keys.items():
            if fk_col in df.columns and ref_table in reference_tables:
                ref_df = reference_tables[ref_table]
                ref_key = self.schema_info[ref_table]['key_cols'][0]
                
                # Find orphaned records (foreign keys with no matching reference)
                orphaned = df.join(
                    ref_df.select(col(ref_key).alias('ref_key')),
                    df[fk_col] == col('ref_key'),
                    'left_anti'
                ).filter(col(fk_col).isNotNull()).count()
                
                if orphaned > 0:
                    validation_results['foreign_key_violations'][fk_col] = {
                        'reference_table': ref_table,
                        'orphaned_count': orphaned
                    }
        
        return validation_results
    
    def remove_duplicates(self, df: DataFrame, table_name: str) -> DataFrame:
        """Remove duplicate records based on primary keys.
        
        Args:
            df: Input DataFrame
            table_name: Name of the OMOP table
            
        Returns:
            DataFrame with duplicates removed
        """
        key_cols = self.schema_info[table_name]['key_cols']
        
        initial_count = df.count()
        
        # Remove duplicates; dropDuplicates keeps an arbitrary row per key,
        # not necessarily the first one encountered
        df_cleaned = df.dropDuplicates(key_cols)
        
        final_count = df_cleaned.count()
        duplicates_removed = initial_count - final_count
        
        if duplicates_removed > 0:
            self.logger.info(
                f"{table_name}: Removed {duplicates_removed} duplicate records "
                f"({builtins.round(duplicates_removed/initial_count*100, 2)}%)"
            )
        
        return df_cleaned
    
    def clean_table(self, df: DataFrame, table_name: str, 
                   reference_tables: dict = None) -> DataFrame:
        """Execute the complete cleaning pipeline for a table.
        
        Steps:
        1. Remove records with null primary keys
        2. Remove duplicates (based on primary keys)
        3. Remove records whose non-null foreign keys have no match in the
           reference tables (only when reference tables are provided)
        
        Date inconsistencies are flagged by detect_anomalies() but not
        corrected here.
        
        Args:
            df: Input DataFrame
            table_name: Name of the OMOP table
            reference_tables: Dictionary of reference DataFrames for FK validation
            
        Returns:
            Cleaned DataFrame
        """
        self.logger.info(f"Starting cleaning process for {table_name}")
        initial_count = df.count()
        
        # Step 1: Remove records with null primary keys first, so that
        # dropDuplicates() does not collapse all null-key rows into one
        key_cols = self.schema_info[table_name]['key_cols']
        for key_col in key_cols:
            if key_col in df.columns:
                df = df.filter(col(key_col).isNotNull())
        
        # Step 2: Remove duplicates
        df = self.remove_duplicates(df, table_name)
        
        # Step 3: Validate and clean foreign keys
        if reference_tables:
            foreign_keys = self.schema_info[table_name].get('foreign_keys', {})
            for fk_col, ref_table in foreign_keys.items():
                if fk_col in df.columns and ref_table in reference_tables:
                    ref_df = reference_tables[ref_table]
                    ref_key = self.schema_info[ref_table]['key_cols'][0]
                    
                    # Keep only records with valid foreign keys or null foreign keys
                    valid_fk_df = df.join(
                        ref_df.select(col(ref_key).alias('ref_key')),
                        df[fk_col] == col('ref_key'),
                        'left_semi'
                    )
                    
                    null_fk_df = df.filter(col(fk_col).isNull())
                    df = valid_fk_df.union(null_fk_df)
        
        final_count = df.count()
        records_removed = initial_count - final_count
        
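        # Guard against division by zero for empty tables
        pct_removed = (records_removed / initial_count * 100) if initial_count > 0 else 0
        self.logger.info(
            f"{table_name}: Cleaning complete. "
            f"Records: {initial_count} → {final_count} "
            f"(removed {records_removed}, {builtins.round(pct_removed, 2)}%)"
        )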
        
        return df
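Step 3 of `clean_table` keeps a row when its foreign key is null or matches a key in the reference table (a left semi join, unioned with the null-FK rows). The plain-Python sketch below states the same rule on lists of dicts, which can help when reasoning about expected row counts; `filter_orphans` and the sample rows are hypothetical.

```python
from typing import Dict, Iterable, List


def filter_orphans(rows: List[Dict], fk_col: str, reference_keys: Iterable) -> List[Dict]:
    """Keep rows whose foreign key is None or present in reference_keys."""
    valid = set(reference_keys)
    return [row for row in rows if row.get(fk_col) is None or row[fk_col] in valid]


person_ids = [1, 2, 3]
visits = [
    {"VISIT_OCCURRENCE_ID": 10, "PERSON_ID": 1},
    {"VISIT_OCCURRENCE_ID": 11, "PERSON_ID": 99},    # orphan: person 99 does not exist
    {"VISIT_OCCURRENCE_ID": 12, "PERSON_ID": None},  # null foreign key is kept
]
kept = filter_orphans(visits, "PERSON_ID", person_ids)
print([row["VISIT_OCCURRENCE_ID"] for row in kept])  # [10, 12]
```

The orphan count that `validate_foreign_keys` reports for the same column corresponds to `len(visits) - len(kept)`: rows with a non-null key that has no match.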

## 4. Load Raw OMOP Data

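Load the raw OMOP CDM tables from your data source. Each table is expected as a Parquet dataset in its own subdirectory of a common base directory, named after the table (for example, `<raw_data_path>person`).

**Note**: Replace `raw_data_path` with your actual data directory path, keeping the trailing slash.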
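The loading cell builds each path as `f"{raw_data_path}{table_name}"`, which only works when the base path ends with a slash. A small hypothetical helper based on the standard library's `posixpath.join` removes that dependency; this sketch assumes POSIX-style local paths or URI-style paths such as `s3://…`.

```python
import posixpath


def table_path(base_path: str, table_name: str) -> str:
    """Join a base directory and a table name, adding '/' only if the base lacks one."""
    return posixpath.join(base_path, table_name)


print(table_path("/data/omop/raw", "person"))         # /data/omop/raw/person
print(table_path("s3://bucket/omop/raw/", "person"))  # s3://bucket/omop/raw/person
```

With it, `spark.read.parquet(table_path(raw_data_path, table_name))` reads the same location whether or not `raw_data_path` ends in `/`.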

In [None]:
# Get the active Spark session, or create one if none exists
spark = SparkSession.builder.appName("OMOP_ETL_Cleaning").getOrCreate()

# Define base path for raw OMOP data
# REPLACE WITH YOUR ACTUAL PATH
raw_data_path = "/path/to/your/raw/omop/data/"

# List of tables to process
table_names = [
    "person",
    "concept",
    "condition_occurrence",
    "drug_exposure",
    "procedure_occurrence",
    "visit_occurrence",
    "measurement",
    "observation_period"
]

# Load all tables into a dictionary
raw_tables = {}

print("Loading raw OMOP tables...\n")
for table_name in table_names:
    try:
        df = spark.read.parquet(f"{raw_data_path}{table_name}")
        raw_tables[table_name] = df
        print(f"✓ Loaded '{table_name}' with {df.count():,} records")
    except Exception as e:
        print(f"✗ Error loading '{table_name}': {str(e)}")

print(f"\nSuccessfully loaded {len(raw_tables)} tables")

## 5. Initialize Data Cleaner

Create an instance of the `OMOPDataCleaner` with the schema information.

In [None]:
# Get schema information
schema_info = SchemaManager.get_schema_info()

# Initialize the data cleaner
cleaner = OMOPDataCleaner(spark, schema_info)

print("Data cleaner initialized successfully")

## 6. Analyze Data Quality (Before Cleaning)

Calculate quality metrics for all tables before cleaning to establish a baseline.

In [None]:
print("Calculating data quality metrics...\n")
print("=" * 80)

quality_metrics_before = {}

for table_name, df in raw_tables.items():
    print(f"\nAnalyzing {table_name}...")
    
    # Calculate quality metrics
    metrics = cleaner.calculate_quality_metrics(df, table_name)
    quality_metrics_before[table_name] = metrics
    
    # Print summary
    print(f"  Total Records: {metrics['total_rows']:,}")
    
    # Find columns with high null rates
    high_null_cols = [
        (col_name, col_metrics['null_rate']) 
        for col_name, col_metrics in metrics['column_metrics'].items()
        if col_metrics['null_rate'] > 50
    ]
    
    if high_null_cols:
        print(f"  Columns with >50% null values: {len(high_null_cols)}")
        for col_name, null_rate in sorted(high_null_cols, key=lambda x: x[1], reverse=True)[:5]:
            print(f"    - {col_name}: {null_rate}%")
    
    # Detect anomalies
    anomalies = cleaner.detect_anomalies(df, table_name)
    
    if anomalies['statistical_outliers']:
        print(f"  Statistical outliers detected in {len(anomalies['statistical_outliers'])} columns")
    
    if anomalies['date_inconsistencies']:
        print(f"  Date inconsistencies found: {sum(anomalies['date_inconsistencies'].values())} records")

print("\n" + "=" * 80)
print("Data quality analysis complete")
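The metrics dictionaries are nested (table → column → metric). Flattening them into one row per table and column makes them easy to sort, filter, or load into pandas with `pandas.DataFrame(rows)`. The sketch below assumes the structure returned by `calculate_quality_metrics`; `metrics_to_rows` is a hypothetical helper and the sample values are toy data.

```python
from typing import Dict, List


def metrics_to_rows(metrics_by_table: Dict[str, dict]) -> List[dict]:
    """Flatten {table: metrics} into one dict per (table, column)."""
    rows = []
    for table_name, metrics in metrics_by_table.items():
        for column, column_metrics in metrics["column_metrics"].items():
            rows.append({"table_name": table_name,
                         "total_rows": metrics["total_rows"],
                         "column": column,
                         **column_metrics})
    return rows


sample = {
    "person": {
        "table_name": "person",
        "total_rows": 4,
        "column_metrics": {
            "PERSON_ID": {"null_count": 0, "null_rate": 0.0, "unique_count": 4, "completeness": 100.0},
            "LOCATION_ID": {"null_count": 3, "null_rate": 75.0, "unique_count": 2, "completeness": 25.0},
        },
    }
}
rows = metrics_to_rows(sample)
worst_first = sorted(rows, key=lambda r: r["null_rate"], reverse=True)
print(worst_first[0]["column"], worst_first[0]["null_rate"])  # LOCATION_ID 75.0
```

Applied to `quality_metrics_before`, this gives one row per column across all tables, which is convenient for spotting the sparsest columns globally rather than per table.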

## 7. Clean OMOP Tables

Execute the cleaning pipeline on all tables. Tables are cleaned in dependency order, so that each table's foreign keys are validated against already-cleaned reference tables:

1. **Core tables** (no foreign keys): person, concept
2. **Tables referencing person**: visit_occurrence, observation_period
3. **Clinical event tables**: condition_occurrence, drug_exposure, procedure_occurrence, measurement (reference person and visit_occurrence)
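The phase order above is written out by hand. It can also be derived from the `foreign_keys` entries in `SchemaManager.get_schema_info()` with a topological sort. The sketch below uses the standard library's `graphlib` (Python 3.9+) and groups tables into phases whose referenced tables are all cleaned in earlier phases; `cleaning_phases` is a hypothetical helper, and circular foreign-key references make `prepare()` raise `graphlib.CycleError`.

```python
from graphlib import TopologicalSorter
from typing import Dict, List


def cleaning_phases(schema_info: Dict[str, dict]) -> List[List[str]]:
    """Group tables so that every table's FK targets are cleaned in an earlier phase."""
    sorter = TopologicalSorter()
    for table, spec in schema_info.items():
        # A table depends on every table its foreign keys reference
        sorter.add(table, *spec.get("foreign_keys", {}).values())
    sorter.prepare()
    phases = []
    while sorter.is_active():
        ready = sorted(sorter.get_ready())  # all tables whose dependencies are done
        phases.append(ready)
        sorter.done(*ready)
    return phases


schema = {
    "person": {"foreign_keys": {}},
    "concept": {"foreign_keys": {}},
    "visit_occurrence": {"foreign_keys": {"PERSON_ID": "person"}},
    "observation_period": {"foreign_keys": {"PERSON_ID": "person"}},
    "measurement": {"foreign_keys": {"PERSON_ID": "person",
                                     "VISIT_OCCURRENCE_ID": "visit_occurrence"}},
}
for number, tables in enumerate(cleaning_phases(schema), start=1):
    print(number, tables)
```

With the full schema, this yields the same three phases used in the cleaning cell below (up to ordering within a phase), and it keeps the order correct if tables or foreign keys are added later.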

In [None]:
print("Starting data cleaning process...\n")
print("=" * 80)

cleaned_tables = {}

# Phase 1: Clean core tables (no foreign keys)
print("\nPhase 1: Cleaning core tables...")
for table_name in ['person', 'concept']:
    if table_name in raw_tables:
        print(f"\nCleaning {table_name}...")
        cleaned_tables[table_name] = cleaner.clean_table(
            raw_tables[table_name], 
            table_name
        )

# Phase 2: Clean tables with foreign key to person
print("\n" + "-" * 80)
print("Phase 2: Cleaning tables with person reference...")
for table_name in ['visit_occurrence', 'observation_period']:
    if table_name in raw_tables:
        print(f"\nCleaning {table_name}...")
        cleaned_tables[table_name] = cleaner.clean_table(
            raw_tables[table_name],
            table_name,
            # Skip FK validation if the person table failed to load
            reference_tables={'person': cleaned_tables['person']} if 'person' in cleaned_tables else None
        )

# Phase 3: Clean clinical event tables (reference person and visit)
print("\n" + "-" * 80)
print("Phase 3: Cleaning clinical event tables...")
clinical_tables = ['condition_occurrence', 'drug_exposure', 'procedure_occurrence', 'measurement']

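# Only pass reference tables that were actually loaded and cleaned;
# a None entry would make clean_table() fail when it calls .select() on it
reference_dict = {
    name: cleaned_tables[name]
    for name in ('person', 'visit_occurrence')
    if name in cleaned_tables
}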

for table_name in clinical_tables:
    if table_name in raw_tables:
        print(f"\nCleaning {table_name}...")
        cleaned_tables[table_name] = cleaner.clean_table(
            raw_tables[table_name],
            table_name,
            reference_tables=reference_dict
        )

print("\n" + "=" * 80)
print(f"\nCleaning complete! Processed {len(cleaned_tables)} tables")

## 8. Analyze Data Quality (After Cleaning)

Recalculate quality metrics to assess the impact of the cleaning process.

In [None]:
print("Calculating post-cleaning quality metrics...\n")
print("=" * 80)

quality_metrics_after = {}

print("\n{:<30} {:>15} {:>15} {:>15}".format(
    "Table", "Before", "After", "Change (%)"
))
print("-" * 80)

for table_name, df in cleaned_tables.items():
    # Calculate metrics
    metrics = cleaner.calculate_quality_metrics(df, table_name)
    quality_metrics_after[table_name] = metrics
    
    # Compare before and after
    before_count = quality_metrics_before[table_name]['total_rows']
    after_count = metrics['total_rows']
    change_pct = ((after_count - before_count) / before_count * 100) if before_count > 0 else 0
    
    print("{:<30} {:>15,} {:>15,} {:>14.2f}%".format(
        table_name,
        before_count,
        after_count,
        change_pct
    ))

print("\n" + "=" * 80)
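Row counts are only one view of the change. A hypothetical helper like the one below compares per-column null rates between `quality_metrics_before` and `quality_metrics_after`, assuming both follow the structure returned by `calculate_quality_metrics`. A drop from the first percentage to the second means the column became more complete; the inputs in the example are toy values.

```python
from typing import Dict, List, Tuple


def null_rate_changes(before: Dict[str, dict], after: Dict[str, dict],
                      min_delta: float = 0.01) -> List[Tuple[str, str, float, float]]:
    """Return (table, column, before %, after %) where the null rate moved by >= min_delta points."""
    changes = []
    for table_name, metrics_after in after.items():
        columns_before = before.get(table_name, {}).get("column_metrics", {})
        for column, col_after in metrics_after["column_metrics"].items():
            if column not in columns_before:
                continue  # column absent before cleaning: nothing to compare
            rate_before = columns_before[column]["null_rate"]
            rate_after = col_after["null_rate"]
            if abs(rate_after - rate_before) >= min_delta:
                changes.append((table_name, column, rate_before, rate_after))
    return changes


example_before = {"visit_occurrence": {"column_metrics": {
    "PERSON_ID": {"null_rate": 2.5}, "CARE_SITE_ID": {"null_rate": 40.0}}}}
example_after = {"visit_occurrence": {"column_metrics": {
    "PERSON_ID": {"null_rate": 2.5}, "CARE_SITE_ID": {"null_rate": 38.2}}}}
print(null_rate_changes(example_before, example_after))
# [('visit_occurrence', 'CARE_SITE_ID', 40.0, 38.2)]
```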

## 9. Validate Foreign Key Integrity

After cleaning, verify that all foreign key relationships are valid.

In [None]:
print("Validating foreign key integrity...\n")
print("=" * 80)

validation_results = {}

for table_name, df in cleaned_tables.items():
    if schema_info[table_name].get('foreign_keys'):
        print(f"\nValidating {table_name}...")
        
        results = cleaner.validate_foreign_keys(
            df, 
            table_name, 
            cleaned_tables
        )
        validation_results[table_name] = results
        
        if results['foreign_key_violations']:
            print(f"  ✗ Foreign key violations found:")
            for fk, details in results['foreign_key_violations'].items():
                print(f"    - {fk} → {details['reference_table']}: "
                      f"{details['orphaned_count']} orphaned records")
        else:
            print(f"  ✓ All foreign keys valid")

print("\n" + "=" * 80)
print("Foreign key validation complete")

## 10. Export Cleaned Data

Save the cleaned tables to Parquet format for future use.

**Note**: Replace the `output_path` with your desired output directory.

In [None]:
# Define output path for cleaned data
# REPLACE WITH YOUR DESIRED OUTPUT PATH
output_path = "/path/to/your/cleaned/omop/data/"

print("Exporting cleaned tables...\n")
print("=" * 80)

def export_to_parquet(df, table_name):
    """Export DataFrame to Parquet format."""
    if df is not None:
        output_file = f"{output_path}{table_name}"
        print(f"Exporting '{table_name}' to {output_file}...")
        
        df.write \
          .mode("overwrite") \
          .parquet(output_file)
        
        print(f"  ✓ Successfully exported '{table_name}' ({df.count():,} records)\n")
    else:
        print(f"  ✗ DataFrame for '{table_name}' is None. Skipping.\n")

# Export all cleaned tables
for table_name, df in cleaned_tables.items():
    export_to_parquet(df, table_name)

print("=" * 80)
print(f"\nAll cleaned tables exported to: {output_path}")

## 11. Generate Summary Report

Create a comprehensive summary of the cleaning process, including:
- Record count changes
- Data quality improvements
- Issues identified and resolved

In [None]:
print("\n" + "=" * 80)
print(" " * 20 + "CLEANING SUMMARY REPORT")
print("=" * 80)

total_records_before = sum(m['total_rows'] for m in quality_metrics_before.values())
total_records_after = sum(m['total_rows'] for m in quality_metrics_after.values())
total_removed = total_records_before - total_records_after
removal_percentage = (total_removed / total_records_before * 100) if total_records_before > 0 else 0

print(f"\nOverall Statistics:")
print(f"  Total records before cleaning: {total_records_before:,}")
print(f"  Total records after cleaning:  {total_records_after:,}")
print(f"  Records removed:               {total_removed:,} ({removal_percentage:.2f}%)")
print(f"  Number of tables processed:    {len(cleaned_tables)}")

print("\n" + "-" * 80)
print("\nPer-Table Summary:")
print("\n{:<30} {:>12} {:>12} {:>12} {:>12}".format(
    "Table", "Before", "After", "Removed", "% Removed"
))
print("-" * 80)

for table_name in cleaned_tables.keys():
    before = quality_metrics_before[table_name]['total_rows']
    after = quality_metrics_after[table_name]['total_rows']
    removed = before - after
    pct_removed = (removed / before * 100) if before > 0 else 0
    
    print(f"{table_name:<30} {before:>12,} {after:>12,} {removed:>12,} {pct_removed:>11.2f}%")

print("\n" + "=" * 80)
print("\nData Cleaning Pipeline Complete!")
print("\nCleaned dataset is ready for analysis and research.")
print("=" * 80)
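To keep the report beyond the notebook session, the same per-table figures can be collected into plain dicts and serialized as JSON. The sketch assumes the `quality_metrics_before` / `quality_metrics_after` structure used above; `summary_rows` is a hypothetical helper and the counts in the example are toy values.

```python
import json
from typing import Dict, List


def summary_rows(before: Dict[str, dict], after: Dict[str, dict]) -> List[dict]:
    """One dict per cleaned table with before/after counts and percent removed."""
    rows = []
    for table_name, metrics_after in after.items():
        n_before = before[table_name]["total_rows"]
        n_after = metrics_after["total_rows"]
        removed = n_before - n_after
        rows.append({
            "table": table_name,
            "before": n_before,
            "after": n_after,
            "removed": removed,
            # Guard against empty tables, as the report cell does
            "pct_removed": round(removed / n_before * 100, 2) if n_before > 0 else 0.0,
        })
    return rows


report_rows = summary_rows({"person": {"total_rows": 200}}, {"person": {"total_rows": 190}})
print(json.dumps(report_rows))
# [{"table": "person", "before": 200, "after": 190, "removed": 10, "pct_removed": 5.0}]
```

In the notebook, `json.dump(summary_rows(quality_metrics_before, quality_metrics_after), fh, indent=2)` with an open file handle `fh` would save the report next to the cleaned data.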

## 12. Optional: Load and Verify Cleaned Data

As a final verification step, you can reload the cleaned data and perform spot checks.

In [None]:
# Uncomment to load and verify cleaned data

# print("Loading cleaned tables for verification...\n")

# loaded_tables = {}

# for table_name in table_names:
#     try:
#         df = spark.read.parquet(f"{output_path}{table_name}")
#         loaded_tables[table_name] = df
#         print(f"✓ Loaded '{table_name}' with {df.count():,} records")
#     except Exception as e:
#         print(f"✗ Error loading '{table_name}': {str(e)}")

# # Example: Display sample from person table
# if 'person' in loaded_tables:
#     print("\nSample from person table:")
#     loaded_tables['person'].show(5, truncate=False)

---

## Notes and Recommendations

### Cleaning Strategy
The cleaning pipeline follows these key principles:

1. **Preserve Data Integrity**: Foreign key relationships are maintained throughout the cleaning process
2. **Order of Operations**: Tables are cleaned in dependency order to prevent orphaned records
3. **Conservative Approach**: Only clearly invalid data is removed (duplicates, null primary keys, orphaned foreign keys)

### Data Quality Considerations

**What Was Removed:**
- Duplicate records (based on primary keys)
- Records with null primary keys
- Orphaned records (foreign keys pointing to non-existent records)

**What Was Preserved:**
- Records with high null rates in optional fields
- Statistical outliers (flagged but not removed)
- Records with date inconsistencies (flagged but not removed)



---