# OMOP Data Cleaning Process

## Overview

This notebook documents the data cleaning process applied to raw OMOP CDM data to create a clean, analysis-ready dataset.

### What Was Done

The cleaning process consisted of three main phases:

1. **Schema Validation**: Verified data types and column presence
2. **Data Quality Cleaning**: Removed duplicates, null keys, and orphaned records
3. **Referential Integrity**: Ensured foreign key relationships were valid

### Tables Processed

- person
- concept
- visit_occurrence
- condition_occurrence
- drug_exposure
- procedure_occurrence
- measurement
- observation_period

---

## 1. Import Libraries

In [None]:
from pyspark.sql import SparkSession, DataFrame
from pyspark.sql.functions import col
import logging
from typing import Dict

# Initialize Spark
spark = SparkSession.builder.appName("OMOP_Data_Cleaning").getOrCreate()

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

## 2. Define OMOP Schema

The schema definitions specify:
- **key_cols**: Primary key columns
- **foreign_keys**: Mapping from each foreign key column to the table it references

Date and timestamp columns are not included in this schema, so they are not validated here.

This schema information drives the cleaning process.

In [None]:
# Schema definition for OMOP tables
OMOP_SCHEMA = {
    "person": {
        "key_cols": ["PERSON_ID"],
        "foreign_keys": {}
    },
    "concept": {
        "key_cols": ["CONCEPT_ID"],
        "foreign_keys": {}
    },
    "visit_occurrence": {
        "key_cols": ["VISIT_OCCURRENCE_ID"],
        "foreign_keys": {
            "PERSON_ID": "person"
        }
    },
    "condition_occurrence": {
        "key_cols": ["CONDITION_OCCURRENCE_ID"],
        "foreign_keys": {
            "PERSON_ID": "person",
            "VISIT_OCCURRENCE_ID": "visit_occurrence"
        }
    },
    "drug_exposure": {
        "key_cols": ["DRUG_EXPOSURE_ID"],
        "foreign_keys": {
            "PERSON_ID": "person",
            "VISIT_OCCURRENCE_ID": "visit_occurrence"
        }
    },
    "procedure_occurrence": {
        "key_cols": ["PROCEDURE_OCCURRENCE_ID"],
        "foreign_keys": {
            "PERSON_ID": "person",
            "VISIT_OCCURRENCE_ID": "visit_occurrence"
        }
    },
    "measurement": {
        "key_cols": ["MEASUREMENT_ID"],
        "foreign_keys": {
            "PERSON_ID": "person",
            "VISIT_OCCURRENCE_ID": "visit_occurrence"
        }
    },
    "observation_period": {
        "key_cols": ["OBSERVATION_PERIOD_ID"],
        "foreign_keys": {
            "PERSON_ID": "person"
        }
    }
}

print(f"Schema defined for {len(OMOP_SCHEMA)} tables")
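The overview lists column presence as part of schema validation. However, the cleaning functions in this notebook silently skip key or foreign key columns that are missing (`if col_name in df.columns`). Below is a minimal pure-Python sketch of an explicit presence check. It assumes each table's column names are available as a list, for example from `df.columns`. The function name `find_missing_columns` and the two-table `SCHEMA` subset are hypothetical.

```python
from typing import Dict, List

# Hypothetical subset of OMOP_SCHEMA, in the same shape as the cell above
SCHEMA = {
    "person": {"key_cols": ["PERSON_ID"], "foreign_keys": {}},
    "visit_occurrence": {
        "key_cols": ["VISIT_OCCURRENCE_ID"],
        "foreign_keys": {"PERSON_ID": "person"},
    },
}

def find_missing_columns(table_columns: Dict[str, List[str]], schema: Dict) -> Dict[str, List[str]]:
    """Return, per table, the key and foreign key columns that are absent."""
    missing = {}
    for table_name, spec in schema.items():
        if table_name not in table_columns:
            continue  # table not loaded; the loading step already logs this
        present = set(table_columns[table_name])
        expected = spec["key_cols"] + list(spec["foreign_keys"])
        absent = [c for c in expected if c not in present]
        if absent:
            missing[table_name] = absent
    return missing

# Column lists as they might come from df.columns
columns = {
    "person": ["PERSON_ID", "YEAR_OF_BIRTH"],
    "visit_occurrence": ["VISIT_OCCURRENCE_ID", "VISIT_START_DATE"],
}
missing = find_missing_columns(columns, SCHEMA)
print(missing)  # {'visit_occurrence': ['PERSON_ID']}
```

Raising an error on a non-empty result, rather than printing it, would stop the pipeline before a table is cleaned without its foreign key check.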

## 3. Load Raw Data

Raw OMOP tables were loaded from Parquet files.

In [None]:
# Path to raw data (MODIFY THIS)
RAW_DATA_PATH = "/path/to/your/raw/omop/data/"

# Load raw tables
raw_tables = {}
table_names = list(OMOP_SCHEMA.keys())

for table_name in table_names:
    try:
        df = spark.read.parquet(f"{RAW_DATA_PATH}{table_name}")
        raw_tables[table_name] = df
        logger.info(f"Loaded {table_name}: {df.count():,} records")
    except Exception as e:
        logger.error(f"Error loading {table_name}: {str(e)}")

print(f"\nLoaded {len(raw_tables)} tables")

## 4. Data Quality Metrics (Before Cleaning)

Initial data quality metrics were calculated for each table:
- Total record count
- Column count
- Null rate of each primary key column
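In `calculate_quality_metrics`, the null rate is the number of null key values divided by the row count, times 100. It is 0 for an empty table. The sketch below applies the same arithmetic in plain Python, assuming rows are represented as dicts. The helper name `null_rate_pct` is hypothetical.

```python
from typing import Dict, List, Optional

def null_rate_pct(rows: List[Dict[str, Optional[int]]], column: str) -> float:
    """Percentage of rows where `column` is None; 0.0 for an empty table.

    A row without the column at all is also counted as null here.
    """
    total = len(rows)
    if total == 0:
        return 0.0
    nulls = sum(1 for row in rows if row.get(column) is None)
    return nulls / total * 100

rows = [{"PERSON_ID": 1}, {"PERSON_ID": None}, {"PERSON_ID": 3}, {"PERSON_ID": None}]
print(f"{null_rate_pct(rows, 'PERSON_ID'):.2f}%")  # 50.00%
```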

In [None]:
def calculate_quality_metrics(df: DataFrame, table_name: str) -> Dict:
    """Calculate basic quality metrics for a table."""
    total_rows = df.count()
    
    metrics = {
        'table_name': table_name,
        'total_rows': total_rows,
        'columns': len(df.columns)
    }
    
    # Calculate null rates for key columns
    if table_name in OMOP_SCHEMA:
        key_cols = OMOP_SCHEMA[table_name]['key_cols']
        for col_name in key_cols:
            if col_name in df.columns:
                null_count = df.filter(col(col_name).isNull()).count()
                metrics[f'{col_name}_null_rate'] = (null_count / total_rows * 100) if total_rows > 0 else 0
    
    return metrics

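# Calculate and report initial metrics
print("\nInitial Data Quality Metrics:")
print("=" * 80)
for table_name, df in raw_tables.items():
    metrics = calculate_quality_metrics(df, table_name)
    print(f"\n{table_name}:")
    print(f"  Records: {metrics['total_rows']:,}")
    print(f"  Columns: {metrics['columns']}")
    # Report the primary key null rates computed by calculate_quality_metrics
    for metric_name, value in metrics.items():
        if metric_name.endswith('_null_rate'):
            print(f"  {metric_name}: {value:.2f}%")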

## 5. Cleaning Process

### What Was Removed

The cleaning process removed:
1. **Duplicate records** based on primary keys
2. **Records with null primary keys**
3. **Orphaned records** with invalid foreign keys

### Cleaning Order

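Tables were cleaned in dependency order:
1. Core tables (person, concept) - no foreign keys
2. Person-dependent tables (visit_occurrence, observation_period) - reference person only
3. Clinical event tables - reference person and visit_occurrence

Each reference table was cleaned before the tables that point to it. As a result, foreign keys were validated against already-cleaned data, and a record referencing a removed person or visit was itself removed.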
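The phase order can also be derived from the `foreign_keys` entries instead of being hard-coded. The sketch below uses the standard library's `graphlib.TopologicalSorter` (Python 3.9+) on a copy of the schema's foreign key mappings. Each batch returned by `get_ready()` holds the tables whose referenced tables have all been processed. The helper name `cleaning_phases` is hypothetical.

```python
from graphlib import TopologicalSorter
from typing import Dict, List

# Foreign key mappings copied from OMOP_SCHEMA (table -> {fk_col: referenced table})
FOREIGN_KEYS = {
    "person": {},
    "concept": {},
    "visit_occurrence": {"PERSON_ID": "person"},
    "observation_period": {"PERSON_ID": "person"},
    "condition_occurrence": {"PERSON_ID": "person", "VISIT_OCCURRENCE_ID": "visit_occurrence"},
    "drug_exposure": {"PERSON_ID": "person", "VISIT_OCCURRENCE_ID": "visit_occurrence"},
    "procedure_occurrence": {"PERSON_ID": "person", "VISIT_OCCURRENCE_ID": "visit_occurrence"},
    "measurement": {"PERSON_ID": "person", "VISIT_OCCURRENCE_ID": "visit_occurrence"},
}

def cleaning_phases(foreign_keys: Dict[str, Dict[str, str]]) -> List[List[str]]:
    """Group tables into phases; every table comes after the tables it references."""
    # TopologicalSorter expects node -> set of predecessors (here: referenced tables)
    graph = {table: set(refs.values()) for table, refs in foreign_keys.items()}
    sorter = TopologicalSorter(graph)
    sorter.prepare()  # raises graphlib.CycleError on circular references
    phases = []
    while sorter.is_active():
        ready = sorted(sorter.get_ready())
        phases.append(ready)
        sorter.done(*ready)
    return phases

for i, phase in enumerate(cleaning_phases(FOREIGN_KEYS), start=1):
    print(f"Phase {i}: {phase}")
# Phase 1: ['concept', 'person']
# Phase 2: ['observation_period', 'visit_occurrence']
# Phase 3: ['condition_occurrence', 'drug_exposure', 'measurement', 'procedure_occurrence']
```

If a table is added to the schema later, its phase follows from its foreign keys, so the loop order does not need to be edited by hand.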

In [None]:
def remove_duplicates(df: DataFrame, key_cols: list) -> DataFrame:
    """Remove duplicate records based on primary keys (one arbitrary row is kept per key)."""
    initial_count = df.count()
    df_cleaned = df.dropDuplicates(key_cols)
    final_count = df_cleaned.count()
    
    if initial_count != final_count:
        logger.info(f"Removed {initial_count - final_count:,} duplicates")
    
    return df_cleaned

def remove_null_keys(df: DataFrame, key_cols: list) -> DataFrame:
    """Remove records with null primary keys."""
    initial_count = df.count()
    
    for key_col in key_cols:
        if key_col in df.columns:
            df = df.filter(col(key_col).isNotNull())
    
    final_count = df.count()
    if initial_count != final_count:
        logger.info(f"Removed {initial_count - final_count:,} records with null keys")
    
    return df

def clean_foreign_keys(df: DataFrame, foreign_keys: Dict, reference_tables: Dict) -> DataFrame:
    """Remove records with invalid foreign keys."""
    initial_count = df.count()
    
    for fk_col, ref_table_name in foreign_keys.items():
        ref_df = reference_tables.get(ref_table_name)
        # Skip FK columns that are absent or whose reference table was not provided
        if fk_col in df.columns and ref_df is not None:
            ref_key = OMOP_SCHEMA[ref_table_name]['key_cols'][0]
            
            # Keep records whose foreign key matches a key in the reference table...
            valid_fk_df = df.join(
                ref_df.select(col(ref_key).alias('ref_key')),
                df[fk_col] == col('ref_key'),
                'left_semi'
            )
            
            # ...plus records whose foreign key is null (optional relationship)
            null_fk_df = df.filter(col(fk_col).isNull())
            df = valid_fk_df.union(null_fk_df)
    
    final_count = df.count()
    if initial_count != final_count:
        logger.info(f"Removed {initial_count - final_count:,} records with invalid foreign keys")
    
    return df

def clean_table(df: DataFrame, table_name: str, reference_tables: Dict = None) -> DataFrame:
    """Apply full cleaning pipeline to a table."""
    logger.info(f"\nCleaning {table_name}...")
    initial_count = df.count()
    
    schema_info = OMOP_SCHEMA[table_name]
    
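    # Step 1: Remove null primary keys first, so that dropDuplicates does not
    # collapse several null-key rows into one and log them as duplicates
    df = remove_null_keys(df, schema_info['key_cols'])
    
    # Step 2: Remove duplicates on the primary key
    df = remove_duplicates(df, schema_info['key_cols'])
    
    # Step 3: Clean foreign keys against the already-cleaned reference tables
    if reference_tables and schema_info['foreign_keys']:
        df = clean_foreign_keys(df, schema_info['foreign_keys'], reference_tables)
    
    final_count = df.count()
    # Guard against division by zero for empty tables
    retained_pct = (final_count / initial_count * 100) if initial_count > 0 else 0.0
    logger.info(f"Completed: {initial_count:,} → {final_count:,} records ({retained_pct:.2f}% retained)")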
    
    return df

print("Cleaning functions defined")
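`clean_foreign_keys` keeps a record when its foreign key matches a reference key (the `left_semi` join) or when the foreign key is null (the `union`). Applying the same rule to plain Python rows makes the semantics easier to check. The helper name `keep_valid_or_null_fk` is hypothetical.

```python
from typing import Dict, Iterable, List

def keep_valid_or_null_fk(rows: List[Dict], fk_col: str, ref_keys: Iterable) -> List[Dict]:
    """Keep rows whose fk_col is null or present in ref_keys (semi-join semantics)."""
    valid = set(ref_keys)
    return [row for row in rows if row.get(fk_col) is None or row[fk_col] in valid]

visits = [
    {"VISIT_OCCURRENCE_ID": 10, "PERSON_ID": 1},
    {"VISIT_OCCURRENCE_ID": 11, "PERSON_ID": 99},    # orphan: person 99 does not exist
    {"VISIT_OCCURRENCE_ID": 12, "PERSON_ID": None},  # null foreign key is kept
]
kept = keep_valid_or_null_fk(visits, "PERSON_ID", ref_keys=[1, 2, 3])
print([row["VISIT_OCCURRENCE_ID"] for row in kept])  # [10, 12]
```

A set lookup, like a semi-join, never multiplies rows when the reference table repeats a key. A plain inner join would.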

## 6. Execute Cleaning Pipeline

### Phase 1: Core Tables

Clean tables with no foreign keys first (person, concept).

In [None]:
cleaned_tables = {}

print("\nPhase 1: Cleaning core tables")
print("=" * 80)

for table_name in ['person', 'concept']:
    if table_name in raw_tables:
        cleaned_tables[table_name] = clean_table(
            raw_tables[table_name],
            table_name
        )

### Phase 2: Person-Dependent Tables

Clean the tables that reference only person: visit_occurrence and observation_period.

In [None]:
print("\nPhase 2: Cleaning visit table")
print("=" * 80)

# Pass person as a reference only if it was loaded and cleaned in Phase 1
person_ref = {'person': cleaned_tables['person']} if 'person' in cleaned_tables else {}

for table_name in ['visit_occurrence', 'observation_period']:
    if table_name in raw_tables:
        cleaned_tables[table_name] = clean_table(
            raw_tables[table_name],
            table_name,
            reference_tables=person_ref
        )

### Phase 3: Clinical Event Tables

Clean clinical event tables that reference both person and visit.

In [None]:
print("\nPhase 3: Cleaning clinical event tables")
print("=" * 80)

clinical_tables = ['condition_occurrence', 'drug_exposure', 'procedure_occurrence', 'measurement']

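# Only include reference tables that were actually loaded and cleaned,
# so clean_foreign_keys never receives None or raises a KeyError
reference_dict = {
    name: cleaned_tables[name]
    for name in ('person', 'visit_occurrence')
    if name in cleaned_tables
}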

for table_name in clinical_tables:
    if table_name in raw_tables:
        cleaned_tables[table_name] = clean_table(
            raw_tables[table_name],
            table_name,
            reference_tables=reference_dict
        )

## 7. Cleaning Results Summary

Compare before and after record counts.
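The overall percentage in this summary is computed from summed counts, so large tables dominate it. It is not the average of the per-table percentages. The sketch below illustrates the difference with made-up counts; these numbers are not results from this dataset.

```python
# Illustrative counts only (not results from this dataset): table -> (before, after)
counts = {"person": (1000, 990), "measurement": (100000, 90000)}

def pct_removed(before: int, after: int) -> float:
    """Percentage of records removed, with 0.0 for an empty table."""
    return (before - after) / before * 100 if before > 0 else 0.0

per_table = {name: pct_removed(b, a) for name, (b, a) in counts.items()}

# Overall rate from summed counts: weighted by table size
total_before = sum(b for b, _ in counts.values())
total_after = sum(a for _, a in counts.values())
overall = pct_removed(total_before, total_after)

# Unweighted mean of the per-table rates, for comparison
mean_of_tables = sum(per_table.values()) / len(per_table)

print({name: f"{pct:.2f}%" for name, pct in per_table.items()})  # {'person': '1.00%', 'measurement': '10.00%'}
print(f"overall: {overall:.2f}%")                # overall: 9.91%
print(f"mean of tables: {mean_of_tables:.2f}%")  # mean of tables: 5.50%
```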

In [None]:
import pandas as pd

print("\nCleaning Results Summary")
print("=" * 80)

results = []
for table_name in cleaned_tables.keys():
    before = raw_tables[table_name].count()
    after = cleaned_tables[table_name].count()
    removed = before - after
    pct_removed = (removed / before * 100) if before > 0 else 0
    
    results.append({
        'Table': table_name,
        'Before': before,
        'After': after,
        'Removed': removed,
        '% Removed': f"{pct_removed:.2f}%"
    })

results_df = pd.DataFrame(results)
print(results_df.to_string(index=False))

total_before = results_df['Before'].sum()
total_after = results_df['After'].sum()
total_removed = results_df['Removed'].sum()
overall_pct = (total_removed / total_before * 100) if total_before > 0 else 0

print("\n" + "=" * 80)
print(f"\nOverall: {total_before:,} → {total_after:,} records")
print(f"Total removed: {total_removed:,} ({overall_pct:.2f}%)")

## 8. Validate Foreign Key Integrity

After cleaning, verify that no orphaned foreign keys remain.
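The check below counts orphans with a `left_anti` join and then excludes null foreign keys. This exclusion is needed because a null never matches in a join, so null foreign keys would otherwise be reported as orphans. The same counting rule in plain Python uses a hypothetical `count_orphans` helper:

```python
from typing import Dict, Iterable, List

def count_orphans(rows: List[Dict], fk_col: str, ref_keys: Iterable) -> int:
    """Count rows whose non-null fk_col has no match in ref_keys (anti-join minus nulls)."""
    valid = set(ref_keys)
    return sum(1 for row in rows if row.get(fk_col) is not None and row[fk_col] not in valid)

conditions = [
    {"CONDITION_OCCURRENCE_ID": 1, "VISIT_OCCURRENCE_ID": 10},
    {"CONDITION_OCCURRENCE_ID": 2, "VISIT_OCCURRENCE_ID": 77},    # orphan
    {"CONDITION_OCCURRENCE_ID": 3, "VISIT_OCCURRENCE_ID": None},  # null: not an orphan
]
print(count_orphans(conditions, "VISIT_OCCURRENCE_ID", ref_keys=[10, 11]))  # 1
```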

In [None]:
def validate_foreign_keys(tables: Dict) -> Dict:
    """Validate foreign key integrity across all tables."""
    results = {}
    
    for table_name, df in tables.items():
        if table_name in OMOP_SCHEMA:
            foreign_keys = OMOP_SCHEMA[table_name]['foreign_keys']
            
            if foreign_keys:
                violations = {}
                
                for fk_col, ref_table in foreign_keys.items():
                    if ref_table in tables:
                        ref_df = tables[ref_table]
                        ref_key = OMOP_SCHEMA[ref_table]['key_cols'][0]
                        
                        # Count orphaned records
                        orphaned = df.join(
                            ref_df.select(col(ref_key).alias('ref_key')),
                            df[fk_col] == col('ref_key'),
                            'left_anti'
                        ).filter(col(fk_col).isNotNull()).count()
                        
                        if orphaned > 0:
                            violations[fk_col] = orphaned
                
                if violations:
                    results[table_name] = violations
    
    return results

print("\nForeign Key Integrity Validation")
print("=" * 80)

violations = validate_foreign_keys(cleaned_tables)

if violations:
    print("\n⚠ Foreign key violations found:")
    for table_name, fk_violations in violations.items():
        print(f"\n{table_name}:")
        for fk_col, n_orphaned in fk_violations.items():
            print(f"  {fk_col}: {n_orphaned:,} orphaned records")
else:
    print("\n✓ All foreign key relationships are valid")

## 9. Export Cleaned Data

The cleaned tables were saved to Parquet format for downstream analysis.

In [None]:
# Output path (MODIFY THIS)
OUTPUT_PATH = "/path/to/your/cleaned/omop/data/"

print("\nExporting cleaned tables...")
print("=" * 80)

for table_name, df in cleaned_tables.items():
    output_file = f"{OUTPUT_PATH}{table_name}"
    print(f"Exporting {table_name} to {output_file}...")
    
    df.write.mode("overwrite").parquet(output_file)
    
    print(f"  ✓ Exported {df.count():,} records")

print(f"\nAll tables exported to: {OUTPUT_PATH}")

---

## Summary of Cleaning Process

### What Was Done

1. **Duplicate Removal**: Removed duplicate records based on primary keys
2. **Null Key Removal**: Removed records with null primary key values
3. **Foreign Key Validation**: Removed records with invalid foreign key references
4. **Dependency-Ordered Processing**: Cleaned tables in order of their dependencies

### Key Decisions

**What Was Removed:**
- Duplicate records sharing a primary key (one arbitrary record per key was kept, even if the records differed in other columns)
- Records with missing primary keys
- Orphaned records (foreign keys pointing to non-existent records)

**What Was Preserved:**
- Records with null foreign keys (assuming optional relationships)
- Records with null values in non-key columns
- All other data quality issues (for downstream handling)

### Processing Order Rationale

Tables were processed in three phases:

1. **Core tables** (person, concept): No dependencies
2. **Intermediate tables** (visit_occurrence, observation_period): Depend only on person
3. **Event tables** (conditions, drugs, etc.): Depend on person and visit

This ordering ensured that when validating foreign keys, the reference tables had already been cleaned.

### Output

The cleaned dataset maintains the OMOP CDM structure with:
- Referential integrity for the PERSON_ID and VISIT_OCCURRENCE_ID foreign keys defined in OMOP_SCHEMA
- No duplicate primary keys
- No null primary keys
- Parquet format for efficient storage and access

---